In [9]:
%reload_ext autoreload
%autoreload 2

In [10]:
import autorootcwd
import torch
import os
import numpy as np
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from torch.utils.data.dataloader import DataLoader
from sklearn.metrics import confusion_matrix

from src.utils import chord_to_id_map, id_to_chord_map, get_torch_device, collate_fn, get_annotation_metadata
from src.models.ismir2017 import ISMIR2017ACR
from src.data.dataset import FullChordDataset

In [20]:

print(os.listdir('./data/experiments/ismir2017-lr-0.001'))

device = get_torch_device()

small_exp = 'ismir2017-lr-0.001'
small_model = ISMIR2017ACR(cr2=False, num_classes=25)
small_model.load_state_dict(torch.load(f'./data/experiments/{small_exp}/best_model.pth', map_location=device))
small_model.eval()

big_exp = 'large-vocab-fewer-X'
big_model = ISMIR2017ACR(cr2=False)
big_model.load_state_dict(torch.load(f'./data/experiments/{big_exp}/best_model.pth', map_location=device))
big_model.eval()

print('\nModels loaded')

['metrics.json', 'metadata.json', 'best_model.pth', 'training_args.json', 'training_history.json']

Models loaded


  small_model.load_state_dict(torch.load(f'./data/experiments/{small_exp}/best_model.pth', map_location=device))
  big_model.load_state_dict(torch.load(f'./data/experiments/{big_exp}/best_model.pth', map_location=device))


In [21]:
small_dataset = FullChordDataset(override_small_vocab=True)

small_all_preds = []
small_all_labels = []

with torch.no_grad():  # Use no_grad to speed up inference
    for i in tqdm(range(len(small_dataset))):
        cqt, label = small_dataset[i]
        pred = small_model(cqt.unsqueeze(0))
        preds = torch.argmax(pred, dim=2)
        small_all_preds.append(preds[0])  # Keep as tensors
        small_all_labels.append(label)    # Keep as tensors

# Concatenate all predictions and labels at the end
small_all_preds = torch.cat(small_all_preds)
small_all_labels = torch.cat(small_all_labels)

100%|██████████| 1213/1213 [04:22<00:00,  4.62it/s]


In [22]:
big_dataset = FullChordDataset()

big_all_preds = []

with torch.no_grad():  # Use no_grad to speed up inference
    for i in tqdm(range(len(big_dataset))):
        cqt, label = big_dataset[i]
        pred = big_model(cqt.unsqueeze(0))
        preds = torch.argmax(pred, dim=2)
        big_all_preds.append(preds[0])  # Keep as tensors

# Concatenate all predictions and labels at the end
big_all_preds = torch.cat(big_all_preds)

100%|██████████| 1213/1213 [04:34<00:00,  4.41it/s]


In [23]:
from functools import lru_cache
from src.utils import chord_to_id, id_to_chord

@lru_cache(maxsize=None)
def large_to_small_vocab_id(id: int) -> int:
    """
    Converts a large vocabulary chord id to a small vocabulary chord id.

    Args:
        id (int): The large vocabulary chord id.

    Returns:
        int: The small vocabulary chord id.
    """

    chord = id_to_chord(id, use_small_vocab=False)
    return chord_to_id(chord, use_small_vocab=True)

In [24]:
# Map large vocabulary predictions to small vocabulary
big_all_preds_small_vocab = torch.tensor([large_to_small_vocab_id(id.item()) for id in big_all_preds])

In [None]:
N_mask = small_all_labels != 0 # Mask out N chords

# Accuracy of small model on small dataset ignoring N (index 0)
small_all_preds_masked = small_all_preds[N_mask]
small_all_labels_masked = small_all_labels[N_mask]
small_correct = (small_all_preds_masked == small_all_labels_masked).sum().item()
small_total = small_all_labels_masked.size(0)
small_acc = small_correct / small_total
print(f'Small model accuracy on small dataset: {small_acc:.2f}')

# Accuracy of big model on small dataset ignoring N (index 0)
big_all_preds_masked = big_all_preds_small_vocab[N_mask]
big_correct = (big_all_preds_masked == small_all_labels_masked).sum().item()
big_total = small_all_labels_masked.size(0)
big_acc = big_correct / big_total
print(f'Big model accuracy on small dataset: {big_acc:.2f}')

Small model accuracy on small dataset: 0.81
Big model accuracy on small dataset: 0.78


In [29]:
# Accuracy of small model on small dataset without ignoring N
small_correct = (small_all_preds == small_all_labels).sum().item()
small_total = small_all_labels.size(0)
small_acc = small_correct / small_total
print(f'Small model accuracy on small dataset: {small_acc:.2f}')

# Accuracy of big model on small dataset without ignoring N
big_correct = (big_all_preds_small_vocab == small_all_labels).sum().item()
big_total = small_all_labels.size(0)
big_acc = big_correct / big_total
print(f'Big model accuracy on small dataset: {big_acc:.2f}')

Small model accuracy on small dataset: 0.77
Big model accuracy on small dataset: 0.73
