In [None]:
# L 4-23-25
# notebooks/3.2_CNN_Spectrogram_Classifier.ipynb

In [None]:
# notebooks/3.2_CNN_Spectrogram_Classifier.ipynb

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support
import numpy as np

# Parameters
BATCH_SIZE = 32
IMG_SIZE = 128
EPOCHS = 5
SEED = 42  # fixed seed so the train/test split (and weight init) is reproducible
MODEL_SAVE_PATH = "../models/cnn_model.pth"
INFERENCE_BUNDLE_PATH = "../models/cnn_inference_bundle.pth"

# Determine source of data. RUN_MODE may already be defined in the notebook's
# global namespace by whoever runs this cell; otherwise it defaults to "train".
RUN_MODE = globals().get("RUN_MODE", "train")
if RUN_MODE == "songs":
    DATA_DIR = "../reports/4_Classify_New_Song/spectrograms/"
    REPORT_DIR = "../reports/4_Classify_New_Song/CNN/"
else:
    DATA_DIR = "../spectrograms"
    REPORT_DIR = "../reports/3_CNN_Spectrogram_Classifier"

os.makedirs(REPORT_DIR, exist_ok=True)

# Seed the global RNG so model initialisation and DataLoader shuffling repeat
# across runs.
torch.manual_seed(SEED)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"Loading spectrograms from: {DATA_DIR}")

# Data transforms: resize to a fixed square, convert to a [0, 1] tensor, then
# map each of the 3 channels to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

# Dataset and loader. ImageFolder expects one sub-directory per class and
# exposes the class names, in label-index order, as dataset.classes.
dataset = datasets.ImageFolder(DATA_DIR, transform=transform)
class_names = dataset.classes
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
# A seeded generator makes the 80/20 split identical on every run, so metrics
# from different runs are computed on the same held-out images.
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size], generator=split_generator)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Define model


class SimpleCNN(nn.Module):
    """Small three-stage convolutional classifier for square spectrogram images.

    Each stage is a 3x3 same-padding convolution, a ReLU and a 2x2 max-pool,
    so every spatial side is floor-divided by 8 before the fully connected head.

    Args:
        num_classes: Number of output logits (one per class).
        img_size: Side length of the square input images. The default of 128
            reproduces the original ``64 * 16 * 16`` classifier input.
        in_channels: Number of input image channels (3 for RGB).

    Raises:
        ValueError: If ``img_size`` is too small to survive three 2x pools.
    """

    def __init__(self, num_classes, img_size=128, in_channels=3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2),
        )
        # Padding=1 convs keep the size; three MaxPool2d(2) layers floor-divide
        # it by 2 three times, which equals img_size // 8.
        feature_side = img_size // 8
        if feature_side < 1:
            raise ValueError(f"img_size must be at least 8, got {img_size}")
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * feature_side * feature_side, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes).

        ``x`` must be (batch, in_channels, img_size, img_size) to match the
        flattened size the classifier head was built for.
        """
        x = self.features(x)
        return self.classifier(x)


# Initialize model, loss and optimizer
model = SimpleCNN(num_classes=len(class_names)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train
# Fail fast with a clear message when the split produced no training batches;
# otherwise the per-epoch average below would divide by zero.
if len(train_loader) == 0:
    raise ValueError(
        f"No training samples available from {DATA_DIR!r}; cannot train the CNN.")

train_losses = []  # one entry per epoch: mean of the per-batch losses
for epoch in range(EPOCHS):
    model.train()  # enable dropout for training
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Each batch is weighted equally, including a smaller final batch.
    avg_loss = running_loss / len(train_loader)
    train_losses.append(avg_loss)
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {avg_loss:.4f}")

# Save model
# torch.save does not create missing parent directories, so ensure they exist.
os.makedirs(os.path.dirname(MODEL_SAVE_PATH) or ".", exist_ok=True)
os.makedirs(os.path.dirname(INFERENCE_BUNDLE_PATH) or ".", exist_ok=True)
# NOTE(review): these paths do not depend on RUN_MODE, so when RUN_MODE ==
# "songs" the model just trained on the new-song spectrograms overwrites the
# saved files — confirm that is intended.
torch.save(model.state_dict(), MODEL_SAVE_PATH)
torch.save({
    'model_state_dict': model.state_dict(),
    'class_names': class_names,
    # NOTE: this pickles the torchvision Compose object itself; on recent
    # PyTorch releases loading the bundle needs torch.load(..., weights_only=False).
    'transform': transform,
    'img_size': IMG_SIZE,
}, INFERENCE_BUNDLE_PATH)

# Evaluation: gather true and predicted class indices over the held-out split.
model.eval()  # switch dropout off so predictions are deterministic
y_true, y_pred = [], []
with torch.no_grad():
    for batch_inputs, batch_labels in test_loader:
        logits = model(batch_inputs.to(device))
        # Index of the highest logit in each row is the predicted class id.
        batch_preds = logits.max(dim=1).indices
        y_true.extend(batch_labels.cpu().numpy())
        y_pred.extend(batch_preds.cpu().numpy())

# Report
# Always pass the full label set. If a class is missing from the test split,
# classification_report would otherwise raise (target_names longer than the
# labels inferred from y_true/y_pred). The confusion matrix would also shrink
# and no longer line up with the class-name tick labels.
label_ids = list(range(len(class_names)))
report = classification_report(
    y_true, y_pred, labels=label_ids, target_names=class_names, zero_division=0)
with open(os.path.join(REPORT_DIR, "cnn_classification_report.txt"), "w",
          encoding="utf-8") as f:
    f.write(report)
print(report)

# Confusion matrix: rows are true classes, columns are predicted classes.
cm = confusion_matrix(y_true, y_pred, labels=label_ids)
plt.figure(figsize=(10, 8))
plt.imshow(cm, cmap="Blues", interpolation="nearest")
plt.title("Confusion Matrix - CNN")
plt.colorbar()
plt.xticks(np.arange(len(class_names)), class_names, rotation=90)
plt.yticks(np.arange(len(class_names)), class_names)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.tight_layout()
plt.savefig(os.path.join(REPORT_DIR, "cnn_confusion_matrix.png"))
plt.close()

# Per-genre metrics: one bar chart per metric, one bar per class.
precision, recall, f1, support = precision_recall_fscore_support(
    y_true, y_pred, labels=label_ids, zero_division=0)

metrics = {"Precision": precision, "Recall": recall, "F1-Score": f1}
for metric_name, values in metrics.items():
    plt.figure(figsize=(10, 5))
    plt.bar(class_names, values)
    plt.title(f"{metric_name} per Genre - CNN")
    plt.ylabel(metric_name)
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    # e.g. "F1-Score" -> "cnn_f1_score_bar.png"
    fname = f"cnn_{metric_name.lower().replace('-', '_')}_bar.png"
    plt.savefig(os.path.join(REPORT_DIR, fname))
    plt.close()

Using device: cpu
Training mode
Epoch 1/5, Loss: 7.2071
Epoch 2/5, Loss: 6.3650
Epoch 3/5, Loss: 6.1387
Epoch 4/5, Loss: 5.8828
Epoch 5/5, Loss: 5.5996
Model and inference bundle saved.
               precision    recall  f1-score   support

   Electronic       0.54      0.45      0.49       192
 Experimental       0.39      0.26      0.31       178
         Folk       0.48      0.49      0.48       216
      Hip-Hop       0.42      0.77      0.54       188
 Instrumental       0.43      0.55      0.48       205
International       0.47      0.48      0.47       201
          Pop       0.40      0.15      0.21       225
         Rock       0.59      0.60      0.60       194

     accuracy                           0.46      1599
    macro avg       0.46      0.47      0.45      1599
 weighted avg       0.46      0.46      0.45      1599

