In [None]:
# Imports
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torchvision.models as models
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score, roc_curve, classification_report
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from utils.audio_utils import MelSpectogramDataset


In [None]:
def calculate_metrics(y_true, y_pred_probs, num_classes):
    """
    Calculate AUC, precision, recall, and F1 score for multiclass classification.

    Also prints a per-class ``classification_report`` as a side effect.

    Args:
        y_true (array-like): True integer labels in ``range(num_classes)``,
            one per sample.
        y_pred_probs (array-like): Predicted class probabilities with one row
            per sample and one column per class. Each row must sum to 1
            (e.g. softmax output): multiclass ``roc_auc_score`` with
            ``multi_class="ovr"`` rejects raw logits.
        num_classes (int): Number of classes in the classification task.

    Returns:
        dict: Macro-averaged "AUC", "Precision", "Recall" and "F1 Score".
    """
    # Accept plain lists as well as numpy arrays (argmax needs an ndarray).
    y_true = np.asarray(y_true)
    y_pred_probs = np.asarray(y_pred_probs)

    # Convert predicted probabilities to hard class predictions
    y_pred = y_pred_probs.argmax(axis=1)

    metrics = {
        "AUC": roc_auc_score(y_true, y_pred_probs, multi_class="ovr", average="macro"),
        # zero_division=1: a class that is never predicted contributes precision 1
        "Precision": precision_score(y_true, y_pred, average="macro", zero_division=1),
        "Recall": recall_score(y_true, y_pred, average="macro"),
        "F1 Score": f1_score(y_true, y_pred, average="macro"),
    }

    # Pin the report's labels to range(num_classes) so they always line up with
    # target_names; without this sklearn raises a length-mismatch ValueError
    # whenever a class is absent from both y_true and y_pred.
    class_ids = list(range(num_classes))
    print(classification_report(
        y_true,
        y_pred,
        labels=class_ids,
        target_names=[f"Class {i}" for i in class_ids],
    ))
    return metrics


from sklearn.metrics import classification_report


def per_class_metrics(y_true, y_pred, num_classes):
    """
    Print precision, recall, F1 and support for each class.

    Args:
        y_true (array-like): True integer labels in ``range(num_classes)``.
        y_pred (array-like): Predicted integer labels.
        num_classes (int): Number of classes in the classification task.
    """
    class_ids = list(range(num_classes))
    report = classification_report(
        y_true,
        y_pred,
        # Keep one row per class so target_names always matches, even if a
        # class never occurs (sklearn raises ValueError on a length mismatch).
        labels=class_ids,
        zero_division=1,
        target_names=[f"Class {i}" for i in class_ids],
    )
    print(report)

In [None]:
# Evaluate the model on the test set: rebuild the single-channel ResNet-18
# architecture and load its trained weights.
model = models.resnet18()
# One input channel (spectrogram) instead of the three RGB channels.
model.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3, bias=True)
model.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# 8-way classification head on top of the backbone's 512-d output.
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=8)

model = model.to("mps")

model.load_state_dict(torch.load("models/resnet_model_v1_weighted_alphabetical_normalize_fma.pt", weights_only=True))

In [None]:
# Switch to inference mode (identical to .eval(); returns the model for display)
model.train(False)

In [None]:
# Spectrogram dataset, batched in its default (unshuffled) order for testing
mel_dataset = MelSpectogramDataset(data_path="mel_spectrogram")
dataset_loader = DataLoader(dataset=mel_dataset, batch_size=64)

In [None]:

# Run the model over the whole dataset, collecting per-sample class
# probabilities and labels for the metric and plotting cells below.
# Set eval mode here so the cell does not depend on another cell having done
# it: BatchNorm must use its running statistics at test time.
model.eval()

with torch.no_grad():
    correct, total = 0, 0
    all_labels, all_probs = [], []
    test_progress = tqdm(dataset_loader, desc="Testing", leave=False)

    for mel_spectrogram, label in test_progress:
        mel_spectrogram, label = mel_spectrogram.to("mps").float(), label.to("mps")
        # unsqueeze(1) inserts the single channel axis conv1 expects
        # (assumes batches arrive without a channel axis — TODO confirm)
        output = model(mel_spectrogram.unsqueeze(1))
        probabilities = torch.softmax(output, dim=1).cpu().numpy()
        all_probs.append(probabilities)
        all_labels.append(label.cpu().numpy())
        predicted = output.argmax(dim=1)
        total += label.size(0)
        correct += (predicted == label).sum().item()

    # Concatenate all predictions and true labels
    all_labels = np.concatenate(all_labels)
    all_probs = np.concatenate(all_probs)

    # Calculate metrics (class count taken from the model's output width)
    metrics = calculate_metrics(all_labels, all_probs, num_classes=all_probs.shape[1])
    print("OVERALL METRICS")

    print(f"Accuracy: {100 * correct / total:.2f}% | Metrics: {metrics}")

In [None]:
# Per-class one-vs-rest ROC curves (true/false positive rates) and AUCs
n_classes = all_probs.shape[1]
fpr, tpr, roc_auc = {}, {}, {}
for i in range(n_classes):
    is_class_i = (all_labels == i).astype(int)
    fpr[i], tpr[i], _ = roc_curve(is_class_i, all_probs[:, i])
    roc_auc[i] = roc_auc_score(is_class_i, all_probs[:, i])

# Save data to plot the ROC curve elsewhere. np.save stores these dicts as
# pickled 0-d object arrays: reload with np.load(path, allow_pickle=True).item().
np.save("fpr_resnet.npy", fpr)
np.save("tpr_resnet.npy", tpr)
np.save("roc_auc_resnet.npy", roc_auc)

# Save label names
np.save("labels_resnet.npy", mel_dataset.genres)

# Plot the ROC curve (usetex requires a working LaTeX installation)
plt.figure(figsize=(10, 8))
plt.rc('text', usetex=True)
plt.rc('font', family='serif')
for i in range(n_classes):
    plt.plot(fpr[i], tpr[i], label=f"{mel_dataset.genres[i]} (AUC = {roc_auc[i]:.2f})")
plt.plot([0, 1], [0, 1], "k--")  # chance-level diagonal
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
# The 'fast' style only adjusts line-simplification/chunking rcParams (it does
# not change the background colour) and stays active for later figures.
plt.style.use('fast')
# Add legend
plt.legend()
plt.savefig("auc_roc_resnet_3.eps", dpi=300)
plt.show()


In [None]:
# Spot-check: full probability row for sample 291
all_probs[291, :]

In [None]:
# Spot-check: true class index of the first test sample
all_labels[0]

In [None]:
# Hard predictions for the whole test set. y_pred otherwise only exists as a
# local inside calculate_metrics, so referencing it here raised NameError.
y_pred = all_probs.argmax(axis=1)
y_pred[121]

In [None]:
# Plot the confusion matrix with annotations on test set
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

n_classes = all_probs.shape[1]

# Pin labels so the matrix is always n_classes x n_classes; otherwise a class
# missing from both labels and predictions shrinks it and the loops below
# index out of range.
conf_matrix = confusion_matrix(all_labels, all_probs.argmax(axis=1), labels=np.arange(n_classes))

# Plot the confusion matrix with annotations
plt.figure(figsize=(10, 8))
plt.imshow(conf_matrix, cmap="Blues")
plt.colorbar()
plt.title("Confusion Matrix")
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.xticks(range(n_classes), mel_dataset.genres, rotation=45)
plt.yticks(range(n_classes), mel_dataset.genres)
# Switch text to white on the darker half of the colour scale so counts stay legible
text_threshold = conf_matrix.max() / 2
for i in range(n_classes):
    for j in range(n_classes):
        text_color = "white" if conf_matrix[i, j] > text_threshold else "black"
        plt.text(j, i, conf_matrix[i, j], ha="center", va="center", color=text_color)
plt.show()

In [None]:
from matplotlib.colors import LinearSegmentedColormap

# Two-stop linear gradient: white at 0, accent colour #455681 at 1
colors = ["white", "#455681"]
custom_cmap = LinearSegmentedColormap.from_list(
    "custom_white_to_blue",
    colors,
)

In [None]:
# Row-normalize the confusion matrix: each row (true class) is divided by its
# total, so a cell is the fraction of that class predicted as the column class.
# Rows with no samples stay 0 instead of becoming NaN from 0 / 0.
row_totals = conf_matrix.sum(axis=1, keepdims=True)
conf_matrix_norm = np.divide(
    conf_matrix,
    row_totals,
    out=np.zeros(conf_matrix.shape, dtype=float),
    where=row_totals > 0,
)

plt.rc('text', usetex=True)
plt.rc('font', family='serif')
# Plot the normalized confusion matrix; anchor the colour scale to [0, 1]
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix_norm, cmap=custom_cmap, annot=True, fmt=".2f", xticklabels=mel_dataset.genres,
            yticklabels=mel_dataset.genres, vmin=0.0, vmax=1.0)
plt.xlabel("Predicted", fontdict={"fontsize": 12})
plt.ylabel("True", fontdict={"fontsize": 12})
plt.tight_layout()
plt.savefig("confusion_matrix_normalized_resnet_fma.eps", dpi=300)
plt.show()

In [None]:
# Draw AUC-ROC curve
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt


In [None]:
# Load the spectrogram dataset and a second ResNet-18 variant (different
# conv1/maxpool strides, separate checkpoint) for comparison.

from utils.audio_utils import MelSpectogramDataset

mel_dataset = MelSpectogramDataset(data_path='mel_spectrogram')
dataset_loader = DataLoader(mel_dataset, batch_size=32)
# Evaluate the model on the test set
audio_model = models.resnet18()
audio_model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=(1, 2), bias=True)
audio_model.maxpool = nn.MaxPool2d((2, 3), stride=(1, 2))
audio_model.fc = nn.Linear(512, 8)

audio_model = audio_model.to("mps")

audio_model.load_state_dict(torch.load("models/resnet_model_v1_weighted_alphabetical.pt", weights_only=True))
# Inference only: without this, BatchNorm normalises with per-batch statistics
# and keeps updating its running estimates during testing.
audio_model.eval()

# The default (sequential, drop_last=False) DataLoader yields every sample
# exactly once, so the count is simply the dataset length — no need to load
# every batch from disk. Assumes a map-style Dataset with __len__.
num_samples = len(mel_dataset)
print(num_samples)

In [None]:
# Get predictions from the audio model
from tqdm import tqdm

criterion = nn.CrossEntropyLoss()

# Inference must run in eval mode: in train mode BatchNorm uses per-batch
# statistics (changing predictions) and updates its running estimates even
# under torch.no_grad().
audio_model.eval()

with torch.no_grad():
    correct, total = 0, 0
    audio_preds, audio_labels = [], []
    test_progress = tqdm(dataset_loader, desc="Testing", leave=False)

    for mel_spectrogram, label in test_progress:
        mel_spectrogram, label = mel_spectrogram.to("mps").float(), label.to("mps")
        # unsqueeze(1) inserts the single channel axis conv1 expects
        output = audio_model(mel_spectrogram.unsqueeze(1))
        probabilities = torch.softmax(output, dim=1).cpu().numpy()
        audio_preds.append(probabilities)
        audio_labels.append(label.cpu().numpy())
        predicted = output.argmax(dim=1)
        total += label.size(0)
        correct += (predicted == label).sum().item()

        # Per-batch loss, shown in the progress bar only
        loss = criterion(output, label)
        test_progress.set_postfix({"Loss": f"{loss.item():.4f}"})

    # Concatenate all predictions and true labels
    audio_labels = np.concatenate(audio_labels)
    audio_preds = np.concatenate(audio_preds)

    # Calculate metrics (class count taken from the model's output width)
    metrics = calculate_metrics(audio_labels, audio_preds, num_classes=audio_preds.shape[1])
    print("OVERALL METRICS")

    print(f"Accuracy: {100 * correct / total:.2f}% | Metrics: {metrics}")

# Extract features from the audio model on the multi-modal dataset

In [None]:
from utils.audio_utils import MelSpectogramDataset

# Same spectrogram data, batched in default order for feature extraction
mel_dataset = MelSpectogramDataset(data_path="mel_spectrogram")
multimodal_dataset_loader = DataLoader(dataset=mel_dataset, batch_size=64)

In [None]:
# Check how many samples are in the dataset. A default DataLoader (sequential
# sampler, drop_last=False) visits each sample once, so this equals the dataset
# length — no need to load every batch. Assumes a map-style Dataset with __len__.
num_samples = len(multimodal_dataset_loader.dataset)
print(num_samples)

In [None]:
# Extract features from the audio model: rebuild the ResNet-18, load its
# trained weights, then strip the final fc layer so the network returns the
# backbone's pooled activations instead of class logits.
model = models.resnet18()
model.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3, bias=True)
model.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
model.fc = nn.Linear(in_features=512, out_features=8)

model.load_state_dict(torch.load("models/resnet_model_v1_weighted_alphabetical_normalize_fma.pt", weights_only=True))
model = model.to("mps")

# Keep every child module except the last one (fc)
model = nn.Sequential(*list(model.children())[:-1])
model.eval()

# Per-batch feature arrays and their labels
features = []
labels = []

with torch.no_grad():
    for mel_spectrogram, label in tqdm(multimodal_dataset_loader, desc="Extracting Features", leave=False):
        mel_spectrogram = mel_spectrogram.to("mps").float()
        label = label.to("mps")
        output = model(mel_spectrogram.unsqueeze(1))
        features.append(output.cpu().numpy())
        labels.append(label.cpu().numpy())

In [None]:
# Shape of the first batch of extracted features
np.shape(features[0])

In [None]:
# Stack the per-batch arrays along the sample axis
features_concat = np.concatenate(features, axis=0)
labels_concat = np.concatenate(labels, axis=0)

In [None]:
# Shape of the stacked feature array
np.shape(features_concat)

In [None]:
# Shape of the stacked label array
np.shape(labels_concat)

In [None]:
# Flatten every trailing axis so each sample becomes one feature vector
features_2d = np.reshape(features_concat, (len(features_concat), -1))
np.shape(features_2d)

In [None]:
# Inspect one sample: its label and the shape of its flattened feature vector
example_idx = 0
print("Label: " + str(labels_concat[example_idx]))
print(np.shape(features_2d[example_idx]))

In [None]:
# Display the flattened feature matrix (numpy summarises large arrays in the repr)
features_2d

In [None]:
# Print the labels; numpy's str form is also summarised for large arrays
print(labels_concat)

In [None]:
# Save the features and labels
np.save("features/features_audio.npy", features_2d)
np.save("features/labels_audio.npy", labels_concat)