In [1]:
import datetime
import os
import torch
import torch.nn as nn
import torchvision
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report, roc_auc_score
import numpy as np
from torch.utils.data import DataLoader as DataLoader
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from melSpecDataset import MelSpecDataset
from matplotlib import pyplot as plt
import basic_model as net0
import ModMusicRedNet as net1
import RBFMusicNet as net2


In [2]:
# device
device = torch.device('cpu')

# set output directory
out_dir = './output/'
os.makedirs(out_dir, exist_ok=True)

test_dir = './splitdata/testing'
batch_size = 64


In [None]:
def calcMeanStd ():
    # Assuming MelSpecDataset is your dataset class and train_dir is your training directory
    resize_size = (258, 128)
    dataset = MelSpecDataset(test_dir, transform=Compose([Resize(resize_size), ToTensor()]))

    # Create a DataLoader with the desired batch size
    loader = DataLoader(dataset, batch_size=64, shuffle=False)

    # Variables to accumulate the sum and sum of squares
    mean_sum = 0.0
    sum_of_squares = 0.0
    nb_samples = 0

    # Loop through all the batches in the DataLoader
    for images, _ in loader:
        # Flatten the images to (batch_size, pixels)
        images = images.view(images.size(0), -1)
        # Sum up the mean and mean of squares
        mean_sum += images.mean(1).sum(0)
        sum_of_squares += (images ** 2).mean(1).sum(0)
        # Count the total number of samples (images) processed
        nb_samples += images.size(0)

    # Calculate the mean and standard deviation
    mean = mean_sum / nb_samples
    # For std, we need to take the square root of the variance (average of the squared differences from the mean)
    std = (sum_of_squares / nb_samples - mean ** 2) ** 0.5

    # Convert to scalar for single-channel image
    mean = mean.item()
    std = std.item()

    #print(f'Calculated mean: {mean}')
    #print(f'Calculated std: {std}')
    return mean, std

In [3]:
model = net1.MusicClassNet()
params = torch.load('./model_mod_2023-12-05_08-13-43.pth')
#model = net2.MusicGenreClassifierRBF((4*128*128), 10, 10)
#params = torch.load('./model_RBF_2023-12-04_19-32-56.pth')
model.load_state_dict(params)
model.to(device)
print('model loaded OK!')

mean, std = calcMeanStd()
resize_size = (258, 128)
transform = Compose([
    Resize(resize_size),
    ToTensor(),
    Normalize(mean=[mean], std=[std]) 
])

#training
testset = MelSpecDataset(test_dir, transform)
data_loader = DataLoader(dataset=testset, batch_size=batch_size, shuffle=False)


model loaded OK!


In [4]:
def test():
    model.eval()
    model.to(device)

    # initalize arrays for storing info needed for evaluation metrics
    all_labels = []
    all_outputs = []

    with torch.no_grad():
        for melspecs, labels in data_loader:
            audios = melspecs.to(device)
            labels = labels.to(device)
            #print(audios[20][2][20])
            # calculate losses and call call model
            output = model(audios)
            #print(output)
            # store outputs and labels
            all_outputs.append(output.cpu())
            all_labels.append(labels.cpu())

    # concatenate all outputs and labels into proper array format
    all_outputs = torch.cat(all_outputs).numpy()
    all_labels = torch.cat(all_labels).numpy()

    # calculate accuracy
    predictions = np.argmax(all_outputs, axis=1)
    accuracy = accuracy_score(all_labels, predictions)
    print(f'Accuracy: {accuracy}')

    # generate confusion matrix
    cm = confusion_matrix(all_labels, predictions)
    print(f'Confusion Matrix:\n{cm}')

    # print classification report (Precision, Recall, F1-Score)
    report = classification_report(all_labels, predictions, 
                                   target_names=["blues", 
                                                "country", 
                                                "classical", 
                                                "disco", 
                                                "jazz", 
                                                "hiphop",  
                                                "reggae",
                                                "pop",
                                                "metal",
                                                "rock"])
    print(f'Classification Report:\n{report}')

    # ROC and AUC (One-vs-Rest approach for multi-class)
    roc_auc = roc_auc_score(all_labels, all_outputs, multi_class='ovr')
    print(f'ROC AUC Score: {roc_auc}')

In [5]:
if __name__ == "__main__":
    # call test function
    test()

Accuracy: 0.17317317317317318
Confusion Matrix:
[[61  0  0  0  1  0  9 25  0  4]
 [ 5  0  0  0  0  0 16 61  0 18]
 [57  0  0  1  0  0 19 15  0  8]
 [ 1  0  0  0  0  0  8 49  0 42]
 [47  0  0  0  0  0 16 26  0 10]
 [20  0  0  0  0  0  2 36  0 42]
 [38  0  0  0  1  0 15 40  0  6]
 [ 1  0  0  0  0  0 20 63  0 16]
 [23  0  0  1  1  0  5  7  0 63]
 [28  0  0  0  0  0 14 24  0 34]]
Classification Report:
              precision    recall  f1-score   support

       blues       0.22      0.61      0.32       100
     country       0.00      0.00      0.00       100
   classical       0.00      0.00      0.00       100
       disco       0.00      0.00      0.00       100
        jazz       0.00      0.00      0.00        99
      hiphop       0.00      0.00      0.00       100
      reggae       0.12      0.15      0.13       100
         pop       0.18      0.63      0.28       100
       metal       0.00      0.00      0.00       100
        rock       0.14      0.34      0.20       100

  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
