# In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models
import torchaudio
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# MOUNT GOOGLE DRIVE (the dataset and the spectrogram cache are read from it)
from google.colab import drive

drive.mount('/content/drive', force_remount=True)

# ---- SEED ----
# Seeds PyTorch's global random number generator.
torch.manual_seed(42)

# ---- PATHS ----
# Both splits sit one level deeper than expected: <base>/train/train and <base>/test/test.
base_path = '/content/drive/MyDrive/the-frequency-quest'
train_dir, test_dir = (
    os.path.join(base_path, split, split) for split in ('train', 'test')
)

# ---- PARAMETERS ----
N_MELS = 64          # number of mel bands per spectrogram
TARGET_SR = 22050    # sample rate every clip is resampled to
MAX_FRAMES = 128     # spectrograms are padded / truncated to this many frames
BATCH_SIZE = 32
EPOCHS = 15

# DIRECTORY WHERE THE MEL SPECTROGRAM OF EACH AUDIO FILE IS CACHED
CACHE_DIR = os.path.join(base_path, "cache_mel")
os.makedirs(CACHE_DIR, exist_ok=True)

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ---- CLASSES ----
# Every sub-directory of train_dir is one class; sorting keeps the
# class -> index mapping stable between runs.
classes = sorted(d for d in os.listdir(train_dir) if os.path.isdir(os.path.join(train_dir, d)))
class_to_index = {cls: idx for idx, cls in enumerate(classes)}

# Collect every .wav file together with the index of the class folder it is in.
file_paths, labels = [], []
for cls in classes:
    cls_path = os.path.join(train_dir, cls)
    # os.listdir returns entries in arbitrary order; sort them so that the
    # seeded split below selects the same files on every run / machine.
    for f in sorted(os.listdir(cls_path)):
        if f.endswith('.wav'):
            file_paths.append(os.path.join(cls_path, f))
            labels.append(class_to_index[cls])

# SPLIT THE FILES AND CORRESPONDING LABELS INTO TRAIN (80%) AND VALIDATION (20%) DATA
train_files, val_files, train_labels, val_labels = train_test_split(
    file_paths, labels, test_size=0.2, random_state=42
)

# ---- DATASET ----
class AudioDataset(Dataset):
    """Audio files served as fixed-size, standardized mel-spectrogram "images".

    Each item is a 3 x 224 x 224 tensor: the single-channel mel spectrogram
    repeated on three channels and bilinearly resized. Spectrograms are
    computed once per file and cached as ``.pt`` files under ``CACHE_DIR``.

    Args:
        files: Sequence of audio file paths.
        labels: Optional sequence of labels aligned with ``files``. When it is
            ``None`` items are returned without a label.
        train: When true, random frequency and time masking is applied to
            every item that is fetched.
    """

    def __init__(self, files, labels=None, train=True):
        self.files = files
        self.labels = labels
        self.train = train
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=TARGET_SR, n_fft=1024, hop_length=512, n_mels=N_MELS
        )
        # Augmentations; only used by __getitem__ when self.train is true.
        self.freq_mask = torchaudio.transforms.FrequencyMasking(freq_mask_param=15)
        self.time_mask = torchaudio.transforms.TimeMasking(time_mask_param=35)

    @staticmethod
    def _cache_path(path, cache_dir):
        """Return the cache file path for the audio file at *path*.

        The name of the parent directory is part of the cache file name so
        that files sharing a base name in different folders (two class
        folders, or a train folder and the test folder) do not overwrite or
        read each other's cached spectrogram.
        """
        parent = os.path.basename(os.path.dirname(path))
        return os.path.join(cache_dir, f"{parent}__{os.path.basename(path)}.pt")

    @staticmethod
    def _fit_frames(mel, max_frames):
        """Zero-pad or truncate the last dimension of *mel* to *max_frames*."""
        n_frames = mel.shape[-1]
        if n_frames < max_frames:
            # F.pad takes the tensor first, then (left, right) padding for the
            # last dimension; only the right side (end of the clip) is padded.
            return F.pad(mel, (0, max_frames - n_frames))
        return mel[..., :max_frames]

    def _process(self, path):
        """Return the cached spectrogram for *path*, computing it on a miss."""
        cache_path = self._cache_path(path, CACHE_DIR)
        if os.path.exists(cache_path):
            mel = torch.load(cache_path)
        else:
            audio, sr = torchaudio.load(path)
            # RESAMPLE
            if sr != TARGET_SR:
                audio = torchaudio.transforms.Resample(sr, TARGET_SR)(audio)
            # CONVERT MULTI-CHANNEL AUDIO TO MONO (>1 -> 1)
            if audio.shape[0] > 1:
                audio = audio.mean(dim=0, keepdim=True)
            # GET MEL SPECTROGRAM FROM WAVEFORM
            mel = self.mel_spec(audio)
            # STANDARDIZE THE SPECTROGRAM (zero mean, unit variance per clip);
            # the epsilon avoids a division by zero on constant input.
            mel = (mel - mel.mean()) / (mel.std() + 1e-9)
            # PAD OR TRUNCATE TO MAX_FRAMES
            mel = self._fit_frames(mel, MAX_FRAMES)
            # A truncated spectrogram is a view of the full one; clone it so
            # that only the kept frames are written to the cache file.
            mel = mel.clone()
            torch.save(mel, cache_path)
        return mel

    def __len__(self):
        """Return the number of audio files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, label)``, or just ``image`` when there are no labels."""
        mel = self._process(self.files[idx])
        if self.train:
            mel = self.freq_mask(mel)
            mel = self.time_mask(mel)
        # Repeat the single channel three times and resize to 224 x 224.
        mel = mel.repeat(3, 1, 1)
        mel = F.interpolate(mel.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze(0)
        if self.labels is not None:
            return mel, self.labels[idx]
        return mel

# ---- DATALOADERS ----
# Masking augmentation is enabled for the training split only.
train_ds = AudioDataset(train_files, train_labels, train=True)
val_ds = AudioDataset(val_files, val_labels, train=False)

# Settings shared by both loaders; only the training loader shuffles.
_loader_opts = {"batch_size": BATCH_SIZE, "num_workers": 4, "pin_memory": True}
train_dl = DataLoader(train_ds, shuffle=True, **_loader_opts)
val_dl = DataLoader(val_ds, shuffle=False, **_loader_opts)

# ---- MODEL ----
# ResNet-34 initialised with ImageNet weights.
model = models.resnet34(weights="IMAGENET1K_V1")

# Freeze the pretrained weights ...
for param in model.parameters():
    param.requires_grad = False

# CLASSIFICATION HEAD: replaces the original fully-connected layer with a
# small MLP that outputs one score per class.
head_layers = [
    nn.Linear(model.fc.in_features, 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, len(classes)),
]
model.fc = nn.Sequential(*head_layers)

# ... then mark every parameter (backbone and head) as trainable again.
# NOTE(review): this undoes the freeze above, so the whole network is trained;
# confirm that full fine-tuning rather than feature extraction is intended.
for param in model.parameters():
    param.requires_grad = True

model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# Halve the learning rate after the monitored value stops improving for more
# than 2 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=2)

# ---- TRAINING LOOP ----
best_acc = 0
for epoch in range(EPOCHS):
    # -- one optimisation pass over the training data --
    model.train()
    batch_losses = []
    progress = tqdm(train_dl, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for xb, yb in progress:
        xb = xb.to(device)
        yb = yb.to(device)
        optimizer.zero_grad()
        loss = criterion(model(xb), yb)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    # Mean of the per-batch losses.
    avg_loss = sum(batch_losses) / len(train_dl)

    # -- validation accuracy, without gradient tracking --
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for xb, yb in val_dl:
            xb = xb.to(device)
            yb = yb.to(device)
            # Index of the highest score along dim 1 is the predicted class.
            pred_labels = model(xb).max(1)[1]
            correct += (pred_labels == yb).sum().item()
            total += yb.size(0)
    acc = correct / total

    # The scheduler monitors the average *training* loss of the epoch.
    scheduler.step(avg_loss)
    print(f"Epoch {epoch+1} | Loss: {avg_loss:.4f} | Val Acc: {acc:.4f}")

    # Keep the weights of the epoch with the best validation accuracy so far.
    if acc > best_acc:
        best_acc = acc
        torch.save(model.state_dict(), "best_model.pth")

print(f"Best Validation Accuracy: {best_acc:.4f}")

# ---- TEST ----
# Every .wav file directly inside test_dir, in the order os.listdir returns them.
test_names = [f for f in os.listdir(test_dir) if f.endswith('.wav')]
test_files = [os.path.join(test_dir, name) for name in test_names]
print(len(test_files))

# No labels and no augmentation; shuffle=False keeps predictions aligned
# with test_files.
test_ds = AudioDataset(test_files, train=False)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

# LOAD THE BEST CHECKPOINT SAVED DURING TRAINING BEFORE PREDICTING
best_state = torch.load("best_model.pth")
model.load_state_dict(best_state)
model.eval()

predictions = []
with torch.no_grad():
    for xb in tqdm(test_dl, desc="Predicting"):
        scores = model(xb.to(device))
        # Index of the highest score along dim 1 is the predicted class.
        pred_labels = scores.max(1)[1]
        predictions.extend(pred_labels.cpu().numpy())

# Map predicted indices back to class names.
idx_to_class = {idx: cls for cls, idx in class_to_index.items()}
predicted_classes = [idx_to_class[pred] for pred in predictions]

# One row per test file: file name and predicted class name.
submission = pd.DataFrame({'ID': test_names, 'Class': predicted_classes})
submission.to_csv('submission.csv', index=False)
print("Submission file created: submission.csv")
