# Model Training

## Imports

In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

from src.__00__paths import train_dir, test_dir, validation_dir, model_dir

from PIL import Image
from pathlib import Path
import random
import numpy as np
import matplotlib.pyplot as plt

## Device Setup

In [12]:
# Seed every RNG the notebook relies on (Python `random`, NumPy, PyTorch) so weight
# initialisation and DataLoader shuffling repeat across Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the Apple GPU backend (MPS) when available, otherwise fall back to CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)

Using device: mps


## Dataset Definition

In [13]:
class GenreSpectrogramDataset(Dataset):
    """Spectrogram PNGs stored as ``root_dir/<genre>/*.png``, labelled by genre folder.

    Args:
        root_dir: Directory whose immediate sub-directories are genre names.
        transform: Optional callable applied to each grayscale PIL image.
        class_to_idx: Optional explicit ``{genre_name: label}`` mapping. Pass the
            training split's mapping when building other splits so that every split
            uses identical integer labels. When omitted, labels are assigned in
            sorted order of the sub-directory names.

    Each sample is a ``(path, label)`` pair; ``__getitem__`` returns
    ``(image, label)`` where ``image`` is the grayscale image after ``transform``.
    """

    def __init__(self, root_dir, transform=None, class_to_idx=None):
        self.root_dir = Path(root_dir)
        self.transform = transform

        if class_to_idx is None:
            # Only directories are classes: stray files such as macOS ``.DS_Store``
            # would otherwise take an index and shift every genre's label by one.
            genres = sorted(entry.name for entry in self.root_dir.iterdir() if entry.is_dir())
            class_to_idx = {genre: idx for idx, genre in enumerate(genres)}
        self.class_to_idx = dict(class_to_idx)

        # Sort files so sample order does not depend on filesystem listing order.
        self.samples = [
            (file, label)
            for genre, label in self.class_to_idx.items()
            for file in sorted((self.root_dir / genre).glob("*.png"))
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_path, label = self.samples[idx]
        # ``with`` releases the file handle; 'L' yields a single-channel grayscale image.
        with Image.open(image_path) as img:
            image = img.convert('L')

        if self.transform is not None:
            image = self.transform(image)

        return image, label

## Transforms & DataLoaders

In [14]:
# Preprocessing applied to every spectrogram image.
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert PIL.Image -> PyTorch tensor
    transforms.Normalize(mean=[0.5], std=[0.5]),  # X_norm = (x - 0.5) / (0.5) = 2x - 1
])

train_data = GenreSpectrogramDataset(train_dir, transform=transform)
validation_data = GenreSpectrogramDataset(validation_dir, transform=transform)

# Both splits must map genres to the same integer labels; otherwise validation
# metrics are meaningless (e.g. a stray non-genre entry in one split's directory
# can shift that split's labels).
assert train_data.class_to_idx == validation_data.class_to_idx, (
    f"label mapping mismatch: train={train_data.class_to_idx} "
    f"validation={validation_data.class_to_idx}"
)

# Load Data
BATCH_SIZE = 32
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation metrics do not depend on sample order, so keep it fixed.
validation_loader = DataLoader(validation_data, batch_size=BATCH_SIZE, shuffle=False)

## Channel Attention

In [15]:
class ChannelAttention(nn.Module):
    """Channel-wise gate in the style of efficient channel attention (ECA).

    Each channel is summarised by its global spatial average. A bias-free 1-D
    convolution slides across the channel axis, so each channel's weight depends
    on its ``k_size`` neighbours. A sigmoid squashes the result to (0, 1), and the
    input is scaled by these per-channel weights.

    Args:
        channels: Number of input channels (accepted but not used).
        k_size: Width of the 1-D convolution over the channel axis.
    """

    def __init__(self, channels, k_size=3):
        super().__init__()
        same_padding = (k_size - 1) // 2
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=same_padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Global average per channel, laid out as a length-C sequence: (..., 1, C).
        channel_seq = self.avg_pool(x).squeeze(-1).transpose(-2, -1)
        # Mix neighbouring channels, then map to (0, 1).
        weights = self.sigmoid(self.conv(channel_seq))
        # Back to (..., C, 1, 1) so the weights align with x's channel axis.
        gate = weights.transpose(-2, -1).unsqueeze(-1)
        return gate.expand_as(x) * x


In [16]:
class Genre_CNN(nn.Module):
    """Convolutional genre classifier for single-channel spectrogram images.

    Five encoder stages widen the feature maps 1 -> 32 -> 64 -> 128 -> 256 -> 512
    channels, and each stage halves height and width with a 2x2 max-pool. Global
    average pooling then feeds a 512 -> 256 -> ``num_classes`` MLP head. The head
    returns raw logits (no softmax).

    Args:
        num_classes: Number of output logits.
    """

    # (in_channels, out_channels, dropout probability) for each encoder stage, in order.
    _STAGES = (
        (1, 32, 0.25),
        (32, 64, 0.25),
        (64, 128, 0.3),
        (128, 256, 0.3),
        (256, 512, 0.4),
    )

    @staticmethod
    def _conv_stage(in_c, out_c, drop):
        """Conv3x3 -> BatchNorm -> ReLU -> ChannelAttention -> MaxPool(2) -> Dropout."""
        layers = [
            nn.Conv2d(in_c, out_c, 3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(),
            ChannelAttention(out_c),
            nn.MaxPool2d(2),
            nn.Dropout(drop),
        ]
        return nn.Sequential(*layers)

    def __init__(self, num_classes=10):
        super().__init__()

        # Stages are created in table order; positional indices give state_dict
        # keys such as ``encoder.0.0.weight``.
        self.encoder = nn.Sequential(
            *[self._conv_stage(c_in, c_out, p) for c_in, c_out, p in self._STAGES]
        )

        head_layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        features = self.encoder(x)
        return self.classifier(features)

## Init Model, Optimizer, Loss

In [17]:
# 10 output logits — expected to match the number of genre folders (TODO confirm).
model = Genre_CNN(num_classes=10)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

## Training Function

In [18]:
def train_one_epoch(model, loader, criterion, optimizer, device=None):
    """Run one optimisation pass over ``loader`` and return epoch metrics.

    Args:
        model: Module whose output has class scores along dim 1.
        loader: Iterable of ``(inputs, integer_labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Where each batch is moved. Defaults to the device of the model's
            first parameter, or CPU for a parameter-less model.

    Returns:
        ``(mean_loss, accuracy)`` averaged over the samples actually yielded by
        ``loader``. This stays correct with ``drop_last`` or subset samplers.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    total_loss, correct, seen = 0.0, 0, 0

    for X, y in loader:
        X, y = X.to(device), y.to(device)

        optimizer.zero_grad()
        out = model(X)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

        batch_size = X.size(0)
        # Assumes a mean-reduced criterion (the nn.CrossEntropyLoss default), so
        # re-weight by batch size to get a per-sample average over the epoch.
        total_loss += loss.item() * batch_size
        correct += (out.argmax(1) == y).sum().item()
        seen += batch_size

    if seen == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")
    return total_loss / seen, correct / seen

## Validation Function

In [19]:
@torch.no_grad()
def evaluate(model, loader, criterion, device=None):
    """Score ``model`` on ``loader`` without gradient tracking or weight updates.

    Args:
        model: Module whose output has class scores along dim 1.
        loader: Iterable of ``(inputs, integer_labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Where each batch is moved. Defaults to the device of the model's
            first parameter, or CPU for a parameter-less model.

    Returns:
        ``(mean_loss, accuracy)`` averaged over the samples actually yielded by
        ``loader``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()  # inference behaviour for BatchNorm / Dropout
    total_loss, correct, seen = 0.0, 0, 0

    for X, y in loader:
        X, y = X.to(device), y.to(device)
        out = model(X)
        loss = criterion(out, y)

        batch_size = X.size(0)
        # Assumes a mean-reduced criterion; re-weight to a per-sample average.
        total_loss += loss.item() * batch_size
        correct += (out.argmax(1) == y).sum().item()
        seen += batch_size

    if seen == 0:
        raise ValueError("evaluate: loader yielded no samples")
    return total_loss / seen, correct / seen

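## Train Loop

> **Note:** in the logged run below, validation accuracy (≈4–14%) stays at or below the 10% chance level for 10 classes, while training accuracy climbs past 50%. That pattern suggests the train and validation splits may use mismatched label mappings (compare each split's `class_to_idx`), rather than ordinary overfitting.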

In [20]:
# Train for a fixed number of epochs, scoring the validation split after each one.
epochs = 20

for epoch_num in range(1, epochs + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer)
    val_loss, val_acc = evaluate(model, validation_loader, criterion)

    train_part = f"Train Loss={train_loss:.4f} Acc={train_acc:.4f}"
    val_part = f"Val Loss={val_loss:.4f} Acc={val_acc:.4f}"
    print(f"Epoch {epoch_num:02d}: {train_part} | {val_part}")

Epoch 01: Train Loss=2.2272 Acc=0.2082 | Val Loss=2.0991 Acc=0.0781
Epoch 02: Train Loss=2.0256 Acc=0.2803 | Val Loss=2.2431 Acc=0.0372
Epoch 03: Train Loss=1.8877 Acc=0.3363 | Val Loss=2.3723 Acc=0.0706
Epoch 04: Train Loss=1.7887 Acc=0.3864 | Val Loss=2.5178 Acc=0.0725
Epoch 05: Train Loss=1.6795 Acc=0.4394 | Val Loss=2.7851 Acc=0.0781
Epoch 06: Train Loss=1.6237 Acc=0.4274 | Val Loss=2.8841 Acc=0.0762
Epoch 07: Train Loss=1.5730 Acc=0.4494 | Val Loss=2.8494 Acc=0.1041
Epoch 08: Train Loss=1.5096 Acc=0.4755 | Val Loss=3.2117 Acc=0.0948
Epoch 09: Train Loss=1.4859 Acc=0.4775 | Val Loss=2.9637 Acc=0.1078
Epoch 10: Train Loss=1.4546 Acc=0.4855 | Val Loss=2.9826 Acc=0.1413
Epoch 11: Train Loss=1.4218 Acc=0.5145 | Val Loss=3.0738 Acc=0.1190
Epoch 12: Train Loss=1.3905 Acc=0.5085 | Val Loss=2.8982 Acc=0.1431
Epoch 13: Train Loss=1.3919 Acc=0.5035 | Val Loss=3.2243 Acc=0.1301
Epoch 14: Train Loss=1.3436 Acc=0.5385 | Val Loss=3.2017 Acc=0.1431
Epoch 15: Train Loss=1.2993 Acc=0.5666 | Val Los

## Save Model

In [25]:
# Persist only the learned weights (state_dict); rebuild Genre_CNN before loading them.
model_path = Path(model_dir) / "genre_cnn_model.pth"
model_path.parent.mkdir(parents=True, exist_ok=True)  # make sure the target directory exists
torch.save(model.state_dict(), model_path)
# A Path is not iterable, so '/'.join(path) raised TypeError; as_posix() gives the '/'-separated text.
print(f"Model saved at {model_path.as_posix()}.")

Model saved at outputs/models/genre_cnn_model.pth.
