### 데이터 로드 및 전처리 클래스

In [166]:
import os as os
import torch
import torchaudio
import pandas as pd
import torch.utils.data as data

class CustomDataset(data.Dataset):
    def __init__(self, csv_file, root_dir, n_mels=32, n_fft=512):
        self.csv_file = csv_file
        self.root_dir = root_dir
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.train_data = []
        self.labels = []
        self.max_length = 0
        self.load_data()

    def load_data(self):
        data = pd.read_csv(self.csv_file)

        for _, row in data.iterrows():
            file_path = os.path.join(self.root_dir, 'train_wav', f"{row['id']}.wav")
            if os.path.exists(file_path):
                waveform, sample_rate = torchaudio.load(file_path)
                transform = torchaudio.transforms.MelSpectrogram(
                    sample_rate=sample_rate,
                    n_mels=self.n_mels,
                    n_fft=self.n_fft
                )
                mels = transform(waveform)
                mels = torchaudio.transforms.AmplitudeToDB()(mels)
                # print(mels)
                self.max_length = max(self.max_length, mels.shape[2])
                self.train_data.append(mels)
                label = 1 if row['label'] == 'fake' else 0
                self.labels.append(label)
            else:
                print(f"File not found: {file_path}")
        
        print(f"Loaded {len(self.train_data)} samples")
        print(f"Maximum length found: {self.max_length}")
    
    def __len__(self):
        return len(self.train_data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        def pad(mels, cut_length=64600):
             current_width = mels.shape[2]
             if current_width >= cut_length:
                 padded_mel = mels[:, :, :cut_length]
             else:
                 pad_amount = cut_length - current_width
                 padded_mel = torch.nn.functional.pad(mels, (0, pad_amount), 'constant', 0)
                 # num_repeats = int(cut_length / current_width) + 1
                 # padded_mel = torch.tile(mels, (1, num_repeats))[:, :cut_length]
             # padded_mel = padded_mel.unsqueeze(0)

             return padded_mel
        
        mels = self.train_data[idx]
        mels_padded = pad(mels, cut_length=4836)
        # mels_resized = resize_tensor(mels, self.max_length)

        sample = {'mel': mels_padded, 'label': self.labels[idx]}
        return sample


In [167]:
# 데이터셋 경로와 CSV 파일 경로 설정
DATASET_PATH = "/home/oem/sw/data"
train_csv = os.path.join(DATASET_PATH, 'train.csv')

### Data Load

In [168]:
import torch.optim as optim
from tqdm import tqdm

# 하이퍼파라미터 설정
batch_size = 16
hidden_dim = 128
lr = 5e-4
temperature = 0.07
weight_decay = 1e-4
max_epochs = 10

In [169]:
from torch.utils.data import DataLoader

# 데이터셋 생성
dataset = CustomDataset(csv_file=train_csv, root_dir=DATASET_PATH)

# 패딩된 데이터셋 생성
# padded_dataset = PadDataset(dataset, cut_length=64600)

# 데이터로더 생성
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


Loaded 55438 samples
Maximum length found: 4836


In [61]:
# # 데이터 확인
# data_iter = iter(dataloader)
# data_sample = next(data_iter)

# mels = data_sample['mel']
# labels = data_sample['label']

# print("Shape of mels:", mels.shape)
# print("Min value of mels:", mels.min().item())
# print("Max value of mels:", mels.max().item())
# print("Mean value of mels:", mels.mean().item())

### Mel-spectrogram visualization

In [170]:
import torch

device = torch.device("cuda:1")
print("Device: ", device)

Device:  cuda:1


### Modeling

In [230]:
import torchvision
import torchvision.models
from torchvision.models import ResNet18_Weights
import torch.nn as nn

# SimCLR 모델 정의

class SimCLR(nn.Module):
    def __init__(self, hidden_dim):
        super(SimCLR, self).__init__()

        self.convnet = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)
        self.convnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False)
        self.convnet.fc = nn.Sequential(
            nn.Linear(512, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, 2)
        )

            
    def forward(self, x):
        return self.convnet(x)

### Loss Function

In [231]:
import torch.nn.functional as F

def info_nce_loss(feats, temperature=0.07):
    cos_sim = F.cosine_similarity(feats[:, None, :], feats[None, :, :], dim=-1)
    self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
    cos_sim.masked_fill_(self_mask, -1e5)  # 대각선 부분을 큰 음수로 채움
    pos_mask = self_mask.roll(shifts=cos_sim.shape[0] // 2, dims=0)
    cos_sim = cos_sim / temperature
    # max_cos_sim = cos_sim.max(dim=-1, keepdim=True).values
    # cos_sim = cos_sim - max_cos_sim.detach()

    # 디버깅 코드 추가
    # print(f"cos_sim min: {cos_sim.min().item()}, max: {cos_sim.max().item()}")

    nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)
    nll = nll.mean()

    # 디버깅 코드 추가
    # print(f"nll: {nll.item()}")

    return nll

### Training

In [232]:
# 모델 및 기타 구성 요소 설정
model = SimCLR(hidden_dim)
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

In [233]:
# 모든 파라미터의 requires_grad 확인
for param in model.parameters():
    param.requires_grad = True

: 

In [220]:
num_epochs = 10  # 예시
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for batch_idx, data in enumerate(tqdm(train_loader)):
        inputs, labels = data['mel'], data['label']

        # print(f"Batch {batch_idx}: Inputs shape: {inputs.shape}, Labels shape: {labels.shape}")
        
        # 디버깅 코드 추가
        # print(f"Batch {batch_idx}:")
        # print(f"  Inputs shape: {inputs.shape}, min: {inputs.min()}, max: {inputs.max()}")
        # print(f"  Labels shape: {labels.shape}")
        
        # 입력 데이터에 requires_grad=True 설정
        inputs.requires_grad_()

        # 옵티마이저 초기화
        optimizer.zero_grad()

        # 모델 예측
        outputs = model(inputs)

        # 디버깅 코드 추가
        # print(f"  Outputs shape: {outputs.shape}, min: {outputs.min()}, max: {outputs.max()}")
        # print(f"  Outputs requires_grad: {outputs.requires_grad}")
        # print(f"  Outputs dtype: {outputs.dtype}")

        # 손실 계산
        loss = info_nce_loss(outputs, temperature)

        # 디버깅 코드 추가
        # print(f"  Loss: {loss.item()}")

        loss.requires_grad_(True)
        # 역전파
        loss.backward()
        
        # 옵티마이저 스텝
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader)}")


100%|██████████| 3465/3465 [17:36<00:00,  3.28it/s]


Epoch [1/10], Loss: 3.2281110312859562


100%|██████████| 3465/3465 [17:24<00:00,  3.32it/s]


Epoch [2/10], Loss: 3.228170268016116


  1%|          | 21/3465 [00:06<19:07,  3.00it/s]


KeyboardInterrupt: 

### Model save

In [221]:
model_save_path = "/home/oem/sw/save_model/SimCLR_pad.pth"

torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

Model saved to /home/oem/sw/save_model/SimCLR_pad.pth


### Evaluation

In [222]:
model_load_path = "/home/oem/sw/save_model/SimCLR_pad.pth"
model = SimCLR(hidden_dim)
model.load_state_dict(torch.load(model_load_path))
model.eval()

print(f"Model loaded from {model_load_path}")

Model loaded from /home/oem/sw/save_model/SimCLR_pad.pth


In [223]:
class TestDataset(Dataset):
    def __init__(self, csv_file, root_dir, n_mels=32, n_fft=512, cache_dir='./cache', is_test=False):
        self.csv_file = csv_file
        self.root_dir = root_dir
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.cache_dir = cache_dir
        self.is_test = is_test
        self.data = []
        self.labels = []
        self.ids = []
        self.max_length = 0
        self.load_data()

    def load_data(self):
        data = pd.read_csv(self.csv_file)
        
        os.makedirs(self.cache_dir, exist_ok=True)

        for _, row in tqdm(data.iterrows(), total=len(data)):
            file_path = os.path.join(self.root_dir, 'test_wav', f"{row['id']}.wav")
            cache_path = os.path.join(self.cache_dir, f"{row['id']}.pt")
            
            if os.path.exists(cache_path):
                mels = torch.load(cache_path)
            else:
                if os.path.exists(file_path):
                    waveform, sample_rate = torchaudio.load(file_path)
                    transform = torchaudio.transforms.MelSpectrogram(
                        sample_rate=sample_rate,
                        n_mels=self.n_mels,
                        n_fft=self.n_fft
                    )
                    mels = transform(waveform)
                    mels = torchaudio.transforms.AmplitudeToDB()(mels)
                    torch.save(mels, cache_path)
                else:
                    print(f"File not found: {file_path}")
                    continue

            self.max_length = max(self.max_length, mels.shape[2])
            self.data.append(mels)
            self.ids.append(row['id'])
           
        print(f"Loaded {len(self.data)} samples")
        print(f"Maximum length found: {self.max_length}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        if idx >= len(self.data):
            raise IndexError("Index out of range")

        mels = self.data[idx]
        mels_padded = self.pad(mels, cut_length=self.max_length)
        
        sample = {'mel': mels_padded, 'id': self.ids[idx]}
        if not self.is_test:
            sample['label'] = self.labels[idx]
        return sample

    def pad(self, mels, cut_length):
        current_width = mels.shape[2]
        if current_width >= cut_length:
            padded_mel = mels[:, :, :cut_length]
        else:
            pad_amount = cut_length - current_width
            padded_mel = torch.nn.functional.pad(mels, (0, pad_amount), 'constant', 0)
        return padded_mel


In [224]:
DATASET_PATH = "/home/oem/sw/data"
test_csv = os.path.join(DATASET_PATH, 'test.csv')

test_dataset = TestDataset(csv_file=test_csv, root_dir=DATASET_PATH, is_test=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

  0%|          | 0/50000 [00:00<?, ?it/s]

100%|██████████| 50000/50000 [00:13<00:00, 3791.42it/s]


Loaded 50000 samples
Maximum length found: 629


In [226]:
predictions = []

# 예측 수행
with torch.no_grad():
    for _, data in enumerate(tqdm(test_loader)):
        inputs = data['mel']
        ids = data['id']
        
        outputs = model(inputs)
        probs = torch.softmax(outputs, dim=1)
        
        # print(f"outputs shape: {outputs.shape}, probs shape: {probs.shape}")  # 출력 크기 확인
        
        for i in range(inputs.size(0)):
            pred = {'id': ids[i], 'fake': probs[i, 1].item(), 'real': probs[i, 0].item()}
            predictions.append(pred)


100%|██████████| 3125/3125 [02:28<00:00, 21.10it/s]


In [227]:
df_predictions = pd.DataFrame(predictions)

In [229]:
df_predictions.to_csv('SimCLR_submission.csv', index=False, columns = ['id', 'fake', 'real'])