<h1>Training from Scratch with Various Architectures</h1>

This notebook can be used as a template for training models with a variety of architectures

Just need to set model_local to be the desired architecture. 

In the example here, the architecture is imported from an external file in the line:
from resnet18flex import ResNet18Gray
and then instantiated with the line
model_local = ResNet18Gray(num_classes=num_classes, input_height=128)


* Features: 15-second mel spectrograms, augmented with pitch shift and pink noise
* Architecture: ResNet18 (modified for our inputs)
* Inputs: A variety of spectrograms of dimension 12x256, 32x256, and 128x256 (also graphs of dimension 300x300)
    * We'll adjust the ResNet18 architecture use stride=1 in conv1 and skip maxpool when the height is 12 or 32
* Model: Tuning from scratch
* Loss function: Soft Labeling Loss
* Target: ~250 musicmap genres

In [1]:
# Run this command if having trouble with import statements
# %pip install scikit-learn seaborn beautifulsoup4 pyvis

In [2]:
import sys
print(sys.executable)

/opt/homebrew/anaconda3/envs/deep-learning/bin/python


<h2>Imports and Parameters</h2>

In [1]:
# Imports
import os

from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import one_hot
from torch import nn, optim
from torch.optim.swa_utils import AveragedModel
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms, datasets

import numpy as np

from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
from sklearn.utils.multiclass import unique_labels

import matplotlib.pyplot as plt
import seaborn as sns

import pandas as pd
import networkx as nx
from bs4 import BeautifulSoup
from collections import defaultdict
from pyvis.network import Network
from tabulate import tabulate

from tqdm import tqdm

from musicmap_graph_creator import create_musicmap_graph, compute_shortest_paths_between_classes_nx
from resnet18flex import ResNet18Gray
from musicmap_helpers import compute_image_dataset_stats, check_set_gpu, batch_soft_labeling_loss, top_k_accuracy, evaluate, plot_metrics, plot_cm_super


In [2]:
train_data_directory = "./15_second_features_augmented/musicmap_processed_output_splits_train/mel_db_gray"
val_data_directory = "./15_second_features_augmented/musicmap_processed_output_splits_val/mel_db_gray"

# Set the device to the local GPU if available
device = check_set_gpu()

num_epochs = 100  # Maximum number of epochs to try
beta = 5        # Hyperparameter for soft loss function
patience = 10   # Number of epochs in early stopping condition
swa_start_epoch = 5  # start averaging after this epoch

# Edge weights for the graph
primary_edge_weight = 1
secondary_edge_weight = 2
backlash_edge_weight = 3
supergenre_edge_weight = 4
world_supergenre_weight = 1
world_cluster_weight = 1
util_genre_weight = 2
cluster_weight = 1

Using MPS: True


<h2>Shortest Path Matrix</h2>

In [3]:
musicmap_graph = create_musicmap_graph(
    primary_edge_weight, 
    secondary_edge_weight,
    backlash_edge_weight,
    supergenre_edge_weight,
    world_supergenre_weight,
    world_cluster_weight,
    util_genre_weight,
    cluster_weight
)

shortest_graph, class_names = compute_shortest_paths_between_classes_nx(
    class_dir=train_data_directory,  # path to folder with class folders
    graph=musicmap_graph,
    return_tensor=True
)

<h2>Transform</h2>

We'll use the means and standard deviations to normalize the images.

Using ImageFolder might default to loading images as RGB, even if they're just greyscale. So we'll force greyscale as part of the transform for greyscale images.

For use with images that require color, there is also a RGB version of the transform that only strips out the alpha channel.

In [4]:
image_dir = train_data_directory
means, sds = compute_image_dataset_stats(image_dir)

In [5]:
transform_greyscale = transforms.Compose([
    transforms.ToTensor(), # Should always be the last step before feeding into a model
    transforms.Grayscale(num_output_channels=1),  # force grayscale
    transforms.Normalize(mean=means, std=sds)    # Normalize to imagenet mean and standard deviation
])

transform_rgb = transforms.Compose([
    transforms.ToTensor(), # Should always be the last step before feeding into a model
    transforms.Lambda(lambda x: x[:3, :, :]),  # Remove the alpha channel and keep only the first 3 channels (RGB)
    transforms.Normalize(mean=means, std=sds)    # Normalize to imagenet mean and standard deviation
])

<h2>Loading the Data</h2>

Define the datasets and create dataloaders.
ImageFolder automatically creates labels from the folder structure.

In [6]:
train_dataset = torchvision.datasets.ImageFolder(
    root=train_data_directory,
    transform=transform_greyscale
)

val_dataset = torchvision.datasets.ImageFolder(
    root=val_data_directory,
    transform=transform_greyscale
)

In [7]:
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,  # Adjust batch size as needed
    shuffle=True,
    num_workers=4,  # Adjust this to tweak multiprocessing
    pin_memory=True,
    prefetch_factor=2
)

val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=10,  # Adjust batch size as needed
    shuffle=False, #Setting to false keeps evaluation stable across epochs
    num_workers=4,
    pin_memory=True,
    prefetch_factor=2 
)

<h2>Architecture</h2>

Compute the number of classes and instantiate the model

In [8]:
num_classes = len(train_dataloader.dataset.classes) 

model_local = ResNet18Gray(num_classes=num_classes, input_height=128)
model_local = model_local.to(device)

<h2>Criterion, Optimizer, and Scheduler</h2>

In [9]:
criterion = lambda outputs, labels: batch_soft_labeling_loss(outputs, labels, shortest_graph, beta)

optimizer = optim.Adam(model_local.parameters(), lr=0.001, weight_decay=1e-4)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=2)#, verbose=True)

<h2>Training</h2>

This calls an evaluate function, which automatically saves a confusion matrix.

Incorporates auto-stopping

Automatically saves a best model, a SWA model, a plot of evaluation metrics, and a plot of a confusion matrix.

In [None]:
# Create index-to-supergenre lookup once
idx_to_supergenre = np.array([cls[:3] for cls in train_dataset.classes])  # shape (num_classes,)

# The line below might speed up training on MPS
torch.set_float32_matmul_precision("high")  # or "medium"

# Metric lists
val_accuracy_list, top3_acc_list, top5_acc_list = [], [], []
precision_list, recall_list, f1_list = [], [], []
mean_dist_list, super_acc_list, super_top3_acc_list = [], [], []
train_loss_list, train_accuracy_list, val_loss_list = [], [], []

# Early stopping and SWA setup
# Going to stop based on mean distance
best_mean_distance = float('inf')
best_model_state = None
counter = 0
swa_model = AveragedModel(model_local)  # model to store average weights


# === Training Loop ===
for epoch in range(num_epochs):
    model_local.train()
    train_loss, correct, total = 0.0, 0, 0


    # Train
    for inputs, labels in tqdm(train_dataloader, desc=f"Epoch {epoch+1} - Training", leave=True, ncols=100):
    #for batch_idx, (inputs, labels) in enumerate(train_dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        # Putting this inside autocast might speed up cuda computations
        if torch.cuda.is_available():
            with torch.amp.autocast('cuda'):
                outputs = model_local(inputs)

                # Take the log softmax of outputs
                outputs = torch.log_softmax(outputs, dim=1)
                loss = criterion(outputs, labels)
        else:
            outputs = model_local(inputs)
            outputs = torch.log_softmax(outputs, dim=1)
            loss = criterion(outputs, labels)

        
        # On Apple Silicon, just call backward and step
        loss.backward()
        optimizer.step()
        # On CUDA, substitute with the scaler
        #scaler.scale(loss).backward()
        #scaler.step(optimizer)
        #scaler.update()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

    avg_train_loss = train_loss / len(train_dataloader)
    train_accuracy = correct / total

    # Add train loss and accuracy to their lists
    train_loss_list.append(avg_train_loss)
    train_accuracy_list.append(train_accuracy)

    # --- Validation ---
    metrics = evaluate(idx_to_supergenre, model_local, val_dataloader, criterion, device, shortest_graph, train_dataset)#, epoch=epoch)

    # Early stopping based on mean distance
    if metrics['mean_distance'] < best_mean_distance:
        best_mean_distance = metrics['mean_distance']
        best_model_state = model_local.state_dict()
        counter = 0
        print(f"[Epoch {epoch+1}] New best mean_distance: {best_mean_distance:.4f}")
    else:
        counter += 1
        print(f"No improvement in mean_distance for {counter} epoch(s).")


    # SWA Averaging
    if epoch >= swa_start_epoch:
        swa_model.update_parameters(model_local)

    # Append metrics to lists
    for lst, key in zip(
        [val_loss_list, val_accuracy_list, top3_acc_list, top5_acc_list, precision_list, recall_list, f1_list,
        mean_dist_list, super_acc_list, super_top3_acc_list],
        ["val_loss", "val_accuracy", "top3_accuracy", "top5_accuracy", "precision", "recall", "f1",
        "mean_distance", "supergenre_accuracy", "supergenre_top3_accuracy"]
    ):
        lst.append(metrics[key])

    # Print metrics table
    headers = ["Epoch", "Train Loss", "Train Acc", "Val Loss", "Val Acc",
               "Top-3 Acc", "Top-5 Acc", "Precision", "Recall", "F1",
               "Supergenre Acc", "Top-3 Super", "Mean Dist"]

    row = [epoch + 1,
        f"{avg_train_loss:.4f}", f"{train_accuracy:.2%}",
        f"{metrics['val_loss']:.4f}", f"{metrics['val_accuracy']:.2%}",
        f"{metrics['top3_accuracy']:.2%}", f"{metrics['top5_accuracy']:.2%}",
        f"{metrics['precision']:.2%}", f"{metrics['recall']:.2%}", f"{metrics['f1']:.2%}",
        f"{metrics['supergenre_accuracy']:.2%}", f"{metrics['supergenre_top3_accuracy']:.2%}",
        f"{metrics['mean_distance']:.4f}"]

    try:
        print(tabulate([row], headers=headers, tablefmt="grid"))
    except Exception as e:
        print(f"Error printing metrics: {e}", flush=True)
        print(f"metrics dict: {metrics}")
        raise

    if counter >= patience:
        print(f"Early stopping triggered after {patience} epochs of no improvement.")
        break
    
    scheduler.step(metrics['mean_distance'])

# Save the best model
if best_model_state is not None:
    torch.save(best_model_state, 'best_model.pth')
    print(f"Best model saved with mean_distance: {best_mean_distance:.4f}")

# Save the SWA model
torch.save(swa_model.module.state_dict(), 'swa_model.pth')
print("SWA model saved from averaged checkpoints.")

metrics_df = pd.DataFrame({
    "epoch": list(range(1, len(train_loss_list) + 1)),
    "train_loss": train_loss_list,
    "train_accuracy": train_accuracy_list,
    "val_loss": val_loss_list,
    "val_accuracy": val_accuracy_list,
    "top3_accuracy": top3_acc_list,
    "top5_accuracy": top5_acc_list,
    "precision": precision_list,
    "recall": recall_list,
    "f1": f1_list,
    "supergenre_accuracy": super_acc_list,
    "supergenre_top3_accuracy": super_top3_acc_list,
    "mean_distance": mean_dist_list
})
metrics_df.to_csv("training_metrics.csv", index=False)

# Save graphs of the various metrics
plot_metrics(val_accuracy_list, top3_acc_list, top5_acc_list, precision_list, recall_list, f1_list, mean_dist_list, super_acc_list, super_top3_acc_list, "metrics.png")

# Save the final confusion matrices
cm_genre = metrics['genre_confusion_matrix']
cm_genre.to_csv("confusion_matrix_genre.csv")

cm_supergenre = metrics['supergenre_confusion_matrix']
cm_supergenre.to_csv("confusion_matrix_supergenre.csv")

# Load confusion_matrix.csv and make a heatmap of the confusion matrix
plot_cm_super(cm_supergenre, "supergenre_confusion_matrix.png")

Epoch 1 - Training: 100%|█████████████████████████████████████████| 630/630 [15:57<00:00,  1.52s/it]


[Epoch 1] New best mean_distance: 3.3128
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+-------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1    | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|       1 |       4.2548 | 9.24%       |     5.0472 | 5.68%     | 13.41%      | 19.27%      | 4.29%       | 6.02%    | 3.48% | 18.98%           | 40.60%        |      3.3128 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+-------+------------------+---------------+-------------+


Epoch 2 - Training: 100%|█████████████████████████████████████████| 630/630 [16:29<00:00,  1.57s/it]


[Epoch 2] New best mean_distance: 3.2298
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+-------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1    | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|       2 |       2.8437 | 32.33%      |     5.1704 | 8.47%     | 18.83%      | 25.55%      | 8.55%       | 8.80%    | 6.98% | 24.98%           | 45.11%        |      3.2298 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+-------+------------------+---------------+-------------+


Epoch 3 - Training: 100%|█████████████████████████████████████████| 630/630 [17:17<00:00,  1.65s/it]


[Epoch 3] New best mean_distance: 3.2115
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+-------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1    | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|       3 |       1.7389 | 60.18%      |     5.4897 | 8.38%     | 18.24%      | 26.03%      | 8.51%       | 8.51%    | 7.18% | 26.05%           | 48.20%        |      3.2115 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+-------+------------------+---------------+-------------+


Epoch 4 - Training: 100%|█████████████████████████████████████████| 630/630 [17:19<00:00,  1.65s/it]


[Epoch 4] New best mean_distance: 3.1643
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+-------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1    | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|       4 |       1.0515 | 78.77%      |     5.8329 | 8.65%     | 19.21%      | 26.52%      | 9.02%       | 8.91%    | 7.54% | 26.49%           | 47.96%        |      3.1643 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+-------+------------------+---------------+-------------+


Epoch 5 - Training: 100%|█████████████████████████████████████████| 630/630 [15:47<00:00,  1.50s/it]


[Epoch 5] New best mean_distance: 3.1562
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+-------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1    | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|       5 |       0.5076 | 94.31%      |     5.6501 | 8.60%     | 19.61%      | 27.73%      | 8.31%       | 8.89%    | 7.83% | 26.50%           | 49.42%        |      3.1562 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+-------+------------------+---------------+-------------+


Epoch 6 - Training: 100%|█████████████████████████████████████████| 630/630 [16:41<00:00,  1.59s/it]


No improvement in mean_distance for 1 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+-------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1    | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|       6 |       0.4014 | 96.83%      |     5.7863 | 8.72%     | 19.74%      | 27.09%      | 9.89%       | 9.06%    | 8.54% | 26.01%           | 48.19%        |      3.1779 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+-------+------------------+---------------+-------------+


Epoch 7 - Training: 100%|█████████████████████████████████████████| 630/630 [15:15<00:00,  1.45s/it]


[Epoch 7] New best mean_distance: 3.1086
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+-------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1    | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|       7 |        0.368 | 97.47%      |      5.681 | 9.14%     | 19.24%      | 26.37%      | 8.89%       | 9.10%    | 8.18% | 26.72%           | 48.11%        |      3.1086 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+-------+------------------+---------------+-------------+


Epoch 8 - Training: 100%|█████████████████████████████████████████| 630/630 [15:01<00:00,  1.43s/it]


[Epoch 8] New best mean_distance: 3.0422
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+-------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1    | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|       8 |        0.257 | 99.59%      |     5.2841 | 10.73%    | 20.21%      | 27.92%      | 9.72%       | 10.91%   | 9.58% | 28.52%           | 49.02%        |      3.0422 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+-------+------------------+---------------+-------------+


Epoch 9 - Training: 100%|█████████████████████████████████████████| 630/630 [18:00<00:00,  1.71s/it]


[Epoch 9] New best mean_distance: 3.0149
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+-------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1    | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|       9 |       0.2166 | 99.89%      |     5.2222 | 10.44%    | 21.45%      | 28.82%      | 9.74%       | 10.73%   | 9.60% | 28.34%           | 49.15%        |      3.0149 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+-------+------------------+---------------+-------------+


Epoch 10 - Training: 100%|████████████████████████████████████████| 630/630 [18:05<00:00,  1.72s/it]


No improvement in mean_distance for 1 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+-------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1    | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      10 |       0.2301 | 99.61%      |     5.5414 | 9.84%     | 19.71%      | 26.99%      | 9.50%       | 9.89%    | 8.75% | 26.93%           | 48.32%        |      3.1568 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+-------+------------------+---------------+-------------+


Epoch 11 - Training: 100%|████████████████████████████████████████| 630/630 [18:05<00:00,  1.72s/it]


[Epoch 11] New best mean_distance: 2.9641
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+-------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1    | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      11 |       0.2146 | 99.79%      |     5.1259 | 10.66%    | 22.22%      | 29.18%      | 9.99%       | 10.92%   | 9.89% | 29.24%           | 49.86%        |      2.9641 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+-------+------------------+---------------+-------------+


Epoch 12 - Training: 100%|████████████████████████████████████████| 630/630 [18:02<00:00,  1.72s/it]


No improvement in mean_distance for 1 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      12 |       0.1886 | 99.87%      |     5.0211 | 11.07%    | 21.61%      | 29.14%      | 10.31%      | 11.26%   | 10.11% | 29.17%           | 48.58%        |      2.9731 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 13 - Training: 100%|████████████████████████████████████████| 630/630 [17:19<00:00,  1.65s/it]


No improvement in mean_distance for 2 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+-------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1    | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      13 |       0.1835 | 99.84%      |     5.1692 | 10.23%    | 20.66%      | 27.75%      | 10.18%      | 10.53%   | 9.39% | 27.89%           | 47.64%        |      3.0234 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+-------+------------------+---------------+-------------+


Epoch 14 - Training: 100%|████████████████████████████████████████| 630/630 [17:56<00:00,  1.71s/it]


[Epoch 14] New best mean_distance: 2.9076
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      14 |       0.1767 | 99.86%      |     4.9152 | 11.42%    | 22.03%      | 28.84%      | 10.44%      | 11.61%   | 10.39% | 29.72%           | 48.88%        |      2.9076 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 15 - Training: 100%|████████████████████████████████████████| 630/630 [26:16<00:00,  2.50s/it]


No improvement in mean_distance for 1 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      15 |       0.1714 | 99.90%      |     4.8919 | 10.85%    | 21.19%      | 28.64%      | 10.39%      | 10.92%   | 10.04% | 29.82%           | 48.52%        |      2.9155 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 16 - Training: 100%|████████████████████████████████████████| 630/630 [25:25<00:00,  2.42s/it]


No improvement in mean_distance for 2 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      16 |       0.1702 | 99.89%      |     4.8723 | 11.24%    | 21.89%      | 29.45%      | 10.26%      | 11.36%   | 10.10% | 29.91%           | 48.51%        |      2.9252 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 17 - Training: 100%|████████████████████████████████████████| 630/630 [25:19<00:00,  2.41s/it]


No improvement in mean_distance for 3 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      17 |       0.1665 | 99.88%      |     4.8229 | 11.24%    | 22.09%      | 29.43%      | 10.39%      | 11.31%   | 10.21% | 29.89%           | 48.64%        |      2.9219 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 18 - Training: 100%|████████████████████████████████████████| 630/630 [25:23<00:00,  2.42s/it]


[Epoch 18] New best mean_distance: 2.8908
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      18 |       0.1647 | 99.90%      |     4.8216 | 11.54%    | 22.04%      | 29.55%      | 10.63%      | 11.68%   | 10.42% | 30.14%           | 48.47%        |      2.8908 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 19 - Training: 100%|████████████████████████████████████████| 630/630 [24:46<00:00,  2.36s/it]


No improvement in mean_distance for 1 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      19 |       0.1645 | 99.88%      |     4.7995 | 11.59%    | 22.13%      | 29.61%      | 10.71%      | 11.66%   | 10.47% | 30.59%           | 48.70%        |      2.8923 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 20 - Training: 100%|████████████████████████████████████████| 630/630 [17:21<00:00,  1.65s/it]


No improvement in mean_distance for 2 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      20 |       0.1624 | 99.91%      |     4.7563 | 11.48%    | 22.17%      | 29.56%      | 10.56%      | 11.55%   | 10.37% | 30.29%           | 48.90%        |      2.9025 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 21 - Training: 100%|████████████████████████████████████████| 630/630 [16:48<00:00,  1.60s/it]


No improvement in mean_distance for 3 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      21 |       0.1618 | 99.88%      |      4.727 | 11.57%    | 22.20%      | 29.65%      | 10.51%      | 11.58%   | 10.35% | 30.43%           | 48.83%        |      2.9008 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 22 - Training: 100%|████████████████████████████████████████| 630/630 [14:44<00:00,  1.40s/it]


No improvement in mean_distance for 4 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      22 |       0.1616 | 99.89%      |     4.7212 | 11.56%    | 22.40%      | 29.81%      | 10.80%      | 11.64%   | 10.50% | 30.61%           | 48.30%        |      2.9153 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 23 - Training: 100%|████████████████████████████████████████| 630/630 [14:44<00:00,  1.40s/it]


No improvement in mean_distance for 5 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      23 |       0.1606 | 99.90%      |      4.713 | 11.54%    | 22.13%      | 29.65%      | 10.67%      | 11.62%   | 10.44% | 30.68%           | 48.61%        |      2.8962 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 24 - Training: 100%|████████████████████████████████████████| 630/630 [16:48<00:00,  1.60s/it]


No improvement in mean_distance for 6 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      24 |       0.1604 | 99.88%      |     4.6954 | 11.49%    | 22.55%      | 29.99%      | 10.79%      | 11.59%   | 10.53% | 30.53%           | 48.87%        |      2.8941 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 25 - Training: 100%|████████████████████████████████████████| 630/630 [18:02<00:00,  1.72s/it]


No improvement in mean_distance for 7 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      25 |       0.1603 | 99.89%      |     4.6813 | 11.47%    | 22.46%      | 29.76%      | 10.74%      | 11.59%   | 10.51% | 30.53%           | 49.09%        |      2.9135 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 26 - Training: 100%|████████████████████████████████████████| 630/630 [14:27<00:00,  1.38s/it]


No improvement in mean_distance for 8 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      26 |       0.1599 | 99.88%      |     4.6749 | 11.59%    | 22.39%      | 29.57%      | 10.66%      | 11.65%   | 10.46% | 30.05%           | 48.65%        |      2.9116 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 27 - Training: 100%|████████████████████████████████████████| 630/630 [16:23<00:00,  1.56s/it]


[Epoch 27] New best mean_distance: 2.8899
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      27 |       0.1598 | 99.91%      |     4.6685 | 11.83%    | 22.53%      | 29.95%      | 10.66%      | 11.89%   | 10.61% | 30.59%           | 48.70%        |      2.8899 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 28 - Training: 100%|████████████████████████████████████████| 630/630 [18:28<00:00,  1.76s/it]


No improvement in mean_distance for 1 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      28 |       0.1596 | 99.90%      |     4.6629 | 11.66%    | 22.45%      | 29.78%      | 10.78%      | 11.74%   | 10.57% | 30.35%           | 48.78%        |      2.8939 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 29 - Training: 100%|████████████████████████████████████████| 630/630 [14:51<00:00,  1.41s/it]


No improvement in mean_distance for 2 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      29 |       0.1594 | 99.90%      |      4.657 | 11.61%    | 22.64%      | 30.03%      | 10.62%      | 11.77%   | 10.50% | 30.68%           | 48.94%        |      2.9013 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 30 - Training: 100%|████████████████████████████████████████| 630/630 [16:16<00:00,  1.55s/it]


No improvement in mean_distance for 3 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      30 |       0.1594 | 99.90%      |     4.6536 | 11.57%    | 22.60%      | 29.76%      | 10.62%      | 11.66%   | 10.44% | 30.48%           | 48.80%        |      2.9045 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 31 - Training: 100%|████████████████████████████████████████| 630/630 [18:35<00:00,  1.77s/it]


[Epoch 31] New best mean_distance: 2.8866
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      31 |       0.1592 | 99.90%      |     4.6564 | 11.60%    | 22.51%      | 29.92%      | 10.70%      | 11.65%   | 10.49% | 30.43%           | 48.58%        |      2.8866 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 32 - Training: 100%|████████████████████████████████████████| 630/630 [16:56<00:00,  1.61s/it]


No improvement in mean_distance for 1 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      32 |       0.1592 | 99.92%      |     4.6548 | 11.57%    | 22.40%      | 30.03%      | 10.64%      | 11.60%   | 10.46% | 30.41%           | 48.67%        |      2.9038 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 33 - Training: 100%|████████████████████████████████████████| 630/630 [16:46<00:00,  1.60s/it]


No improvement in mean_distance for 2 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      33 |       0.1592 | 99.90%      |     4.6525 | 11.56%    | 22.51%      | 29.97%      | 10.73%      | 11.66%   | 10.52% | 30.57%           | 48.62%        |      2.8884 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 34 - Training: 100%|████████████████████████████████████████| 630/630 [18:18<00:00,  1.74s/it]


No improvement in mean_distance for 3 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      34 |       0.1592 | 99.90%      |     4.6512 | 11.37%    | 22.27%      | 29.57%      | 10.61%      | 11.48%   | 10.37% | 30.31%           | 48.84%        |      2.9052 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 35 - Training: 100%|████████████████████████████████████████| 630/630 [16:35<00:00,  1.58s/it]


No improvement in mean_distance for 4 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      35 |       0.1591 | 99.91%      |     4.6435 | 11.47%    | 22.26%      | 29.88%      | 10.70%      | 11.52%   | 10.40% | 30.25%           | 48.48%        |      2.9099 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 36 - Training: 100%|████████████████████████████████████████| 630/630 [18:26<00:00,  1.76s/it]


No improvement in mean_distance for 5 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      36 |       0.1591 | 99.92%      |     4.6471 | 11.51%    | 22.45%      | 30.00%      | 10.80%      | 11.61%   | 10.51% | 30.45%           | 48.78%        |      2.8951 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 37 - Training: 100%|████████████████████████████████████████| 630/630 [17:09<00:00,  1.63s/it]


No improvement in mean_distance for 6 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      37 |       0.1591 | 99.91%      |     4.6481 | 11.68%    | 22.67%      | 30.44%      | 10.61%      | 11.79%   | 10.48% | 30.67%           | 48.82%        |       2.892 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 38 - Training: 100%|████████████████████████████████████████| 630/630 [18:10<00:00,  1.73s/it]


No improvement in mean_distance for 7 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      38 |        0.159 | 99.91%      |     4.6504 | 11.52%    | 22.42%      | 29.79%      | 10.52%      | 11.49%   | 10.33% | 30.28%           | 48.68%        |      2.8991 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 39 - Training: 100%|████████████████████████████████████████| 630/630 [17:23<00:00,  1.66s/it]


No improvement in mean_distance for 8 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      39 |       0.1589 | 99.91%      |     4.6491 | 11.61%    | 22.46%      | 29.82%      | 10.69%      | 11.73%   | 10.52% | 30.48%           | 48.78%        |      2.9024 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 40 - Training: 100%|████████████████████████████████████████| 630/630 [18:03<00:00,  1.72s/it]


No improvement in mean_distance for 9 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      40 |        0.159 | 99.93%      |     4.6559 | 11.65%    | 22.88%      | 30.23%      | 10.63%      | 11.72%   | 10.51% | 30.55%           | 48.94%        |      2.8872 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+


Epoch 41 - Training: 100%|████████████████████████████████████████| 630/630 [18:07<00:00,  1.73s/it]


No improvement in mean_distance for 10 epoch(s).
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
|   Epoch |   Train Loss | Train Acc   |   Val Loss | Val Acc   | Top-3 Acc   | Top-5 Acc   | Precision   | Recall   | F1     | Supergenre Acc   | Top-3 Super   |   Mean Dist |
|      41 |        0.159 | 99.91%      |     4.6412 | 11.54%    | 22.53%      | 29.98%      | 10.66%      | 11.53%   | 10.43% | 30.39%           | 48.45%        |      2.9085 |
+---------+--------------+-------------+------------+-----------+-------------+-------------+-------------+----------+--------+------------------+---------------+-------------+
Early stopping triggered after 10 epochs of no improvement.
Best model saved with mean_distance: 2.8866
SWA model saved from averaged checkpoints.


NameError: name 'output_dir' is not defined

In [11]:
metrics_df.to_csv("training_metrics.csv", index=False)

# Save graphs of the various metrics
plot_metrics(val_accuracy_list, top3_acc_list, top5_acc_list, precision_list, recall_list, f1_list, mean_dist_list, super_acc_list, super_top3_acc_list, "metrics.png")

# Save the final confusion matrices
cm_genre = metrics['genre_confusion_matrix']
cm_genre.to_csv("confusion_matrix_genre.csv")

cm_supergenre = metrics['supergenre_confusion_matrix']
cm_supergenre.to_csv("confusion_matrix_supergenre.csv")

# Load confusion_matrix.csv and make a heatmap of the confusion matrix
plot_cm_super(cm_supergenre, "supergenre_confusion_matrix.png")