In [None]:
%matplotlib inline
import torch
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import Audio, display
import os

In [None]:
# Runtime configuration: global RNG seeds and compute device.
# Seeding torch and NumPy makes torch/NumPy-based randomness (weight init, shuffling,
# np.random sampling) repeatable across Restart & Run All. CUDA kernels may still be
# nondeterministic.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# A plain `if` statement: the CUDA version is only meaningful when running on a GPU.
if device.type == 'cuda':
    print(f"CUDA version: {torch.version.cuda}")

In [None]:
import sys
sys.path.append("..")  # Add parent directory to path
from python_helpers import get_project_root_dir
from datasets import SoundTracksDataset
import models


In [None]:
# Load the complete dataset, then report how many samples it holds and the shape of
# the first entry in its `melspecs` feature collection.
full_dataset = SoundTracksDataset()
n_samples = len(full_dataset)
print(f"Dataset size: {n_samples} samples")
first_melspec_shape = full_dataset.melspecs[0].shape
print(f"Sample features shape: {first_melspec_shape}")

In [None]:
# Two-stage split: hold out a test set first, then carve validation out of the rest.
# Assuming `split_size` is the fraction given to the second returned subset, this is
# 60% train / 20% val / 20% test overall. TODO confirm against SoundTracksDataset.
train_val, test = full_dataset.train_test_split(split_size=0.2)
train, val = train_val.train_test_split(split_size=0.25)
split_sizes = {"Train": len(train), "Val": len(val), "Test": len(test)}
print(", ".join(f"{name}: {count}" for name, count in split_sizes.items()))

In [None]:
# Architecture to train: one of 'nilscnn', 'vgg', 'resnet' (the keys of MODEL_CLASSES).
MODEL_TYPE = 'vgg'
FEATURE_TYPE = 'melspecs'

# Map names to classes (not instances) so that only the selected architecture is
# constructed. Building every model up front just to keep one wastes time and memory.
MODEL_CLASSES = {
    'nilscnn': models.NilsHMeierCNN,
    'vgg': models.VGGStyleCNN,
    'resnet': models.ResNetStyleCNN,
}
if MODEL_TYPE not in MODEL_CLASSES:
    raise ValueError(
        f"Unknown MODEL_TYPE {MODEL_TYPE!r}; expected one of {sorted(MODEL_CLASSES)}"
    )

model = MODEL_CLASSES[MODEL_TYPE](FEATURE_TYPE).to(device)

print(f"Selected model: {MODEL_TYPE.upper()}")
print(model)


In [None]:
from train import ModelTrainer

# Single-label classification over 4 classes, matching the 4 labels used in the
# confusion-matrix plot. Training runs on the device selected earlier.
trainer_config = dict(task='multiclass', num_classes=4, device=device)
trainer = ModelTrainer(**trainer_config)

In [None]:
# Hyperparameters for this run, gathered in one place for easy tuning.
# NOTE(review): lambda_val / l1_ratio look like a regularisation strength and an L1/L2
# mix (l1_ratio=0.0 presumably meaning pure L2), and take_best presumably keeps the
# best-validation weights. Confirm in train.ModelTrainer.train.
TRAIN_HPARAMS = dict(
    batch_size=16,
    max_epochs=20,
    lr=1e-4,
    lambda_val=0.01,
    l1_ratio=0.0,
    take_best=True,
)
train_on_device = train.to(device)
val_on_device = val.to(device)
trainer.train(model=model, train_dset=train_on_device, val_dset=val_on_device, **TRAIN_HPARAMS)

In [None]:
# Score the trained model once on the held-out test split.
test = test.to(device)
test_results = trainer.evaluate_performance(model, test)
test_loss, test_acc, test_cm = test_results
print("\nFinal Test Performance:")
print("Loss: {:.4f} | Accuracy: {:.2%}".format(test_loss, test_acc))

In [None]:
# Class names in label-index order (row/column i ↔ CLASS_NAMES[i]).
# NOTE(review): this ordering is carried over from the original tick labels. Confirm it
# matches the label encoding used by SoundTracksDataset.
CLASS_NAMES = ['Happy', 'Sad', 'Anger', 'Neutral']

cm = test_cm.cpu().numpy()
n_classes = len(CLASS_NAMES)
assert cm.shape == (n_classes, n_classes), (
    f"confusion matrix shape {cm.shape} does not match {n_classes} class names"
)

fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(cm, cmap='Blues')
ax.set_title('Confusion Matrix')
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_xticks(list(range(n_classes)))
ax.set_xticklabels(CLASS_NAMES)
ax.set_yticks(list(range(n_classes)))
ax.set_yticklabels(CLASS_NAMES)

# Print the value in every cell. Use white text on dark cells so it stays readable.
value_fmt = '.2f' if np.issubdtype(cm.dtype, np.floating) else 'd'
threshold = cm.max() / 2 if cm.size else 0
for i in range(n_classes):
    for j in range(n_classes):
        ax.text(
            j, i, format(cm[i, j], value_fmt),
            ha='center', va='center',
            color='white' if cm[i, j] > threshold else 'black',
        )

fig.colorbar(im, ax=ax)
plt.show()

In [None]:
# Listen to a few random test clips alongside the model's predicted class index.
N_PREVIEW = 3
PREVIEW_SEED = 0
# NOTE(review): 44100 Hz is assumed for the stored waveforms. Confirm against SoundTracksDataset.
SAMPLE_RATE = 44100

# Sample without replacement so no clip is shown twice, cap the count at the test-set
# size, and seed the draw so re-runs preview the same clips.
preview_rng = np.random.default_rng(PREVIEW_SEED)
sample_idx = preview_rng.choice(len(test), size=min(N_PREVIEW, len(test)), replace=False)

# Whether ModelTrainer leaves the model in eval mode is not visible from this notebook.
# Set it explicitly so Dropout/BatchNorm-style layers use inference behaviour.
model.eval()
for idx in sample_idx:
    features, true_label = test[idx]
    with torch.no_grad():
        # unsqueeze(0) adds a batch dimension of size 1 to every feature tensor.
        pred = model({k: v.unsqueeze(0).to(device) for k, v in features.items()})

    print(f"\nSample {idx}:")
    print(f"True: {true_label.item()} | Predicted: {pred.argmax().item()}")
    display(Audio(features['waveforms'].cpu().numpy(), rate=SAMPLE_RATE))