In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import autorootcwd
import torch
import os
import numpy as np
from tqdm import tqdm
from typing import Dict, List

import matplotlib.pyplot as plt
from torch.utils.data.dataloader import DataLoader
from sklearn.metrics import  accuracy_score

from src.utils import get_torch_device, collate_fn, NUM_CHORDS, get_split_filenames
from src.models.ismir2017 import ISMIR2017ACR
from src.models.base_model import BaseACR
from src.data.dataset import FullChordDataset
from src.eval import EvalMetric

In [3]:
from functools import lru_cache
from src.utils import chord_to_id, id_to_chord

@lru_cache(maxsize=None)
def large_to_small_vocab_id(id: int) -> int:
    """
    Converts a large vocabulary chord id to a small vocabulary chord id.

    Args:
        id (int): The large vocabulary chord id.

    Returns:
        int: The small vocabulary chord id.
    """

    chord = id_to_chord(id, use_small_vocab=False)
    return chord_to_id(chord, use_small_vocab=True)

## Load Models

In [None]:
def evaluate_model_large_vs_small(
    model: BaseACR,
    is_small: bool,
    dataset: FullChordDataset,
    evals: List[EvalMetric] = [
        EvalMetric.ROOT,
        EvalMetric.MAJMIN,
        EvalMetric.MIREX,
        EvalMetric.THIRD,
        EvalMetric.SEVENTH,
        EvalMetric.SONG_WISE_ACC,
    ],
    frame_wise_acc: bool = True,
    class_wise_acc: bool = True,
    batch_size: int = 8,
    device: torch.device = None,
) -> Dict[str, float]:
    """
    Evaluate a model on a dataset split using a list of evaluation metrics.

    Args:
        model (BaseACRModel): The model to evaluate.
        dataset (FullChordDataset): The dataset to evaluate on.
        evals (list[EvalMetrics]): The evaluation metrics to use. Defaults to [EvalMetrics.ROOT, EvalMetrics.MAJMIN, EvalMetrics.MIREX, EvalMetrics.THIRD, EvalMetrics.SEVENTH].
        batch_size (int): The batch size to use for evaluation. Defaults to 32.
        device (torch.device): The device to use for evaluation. Defaults to None.

    Returns:
        metrics (dict[str, float]): A dictionary of evaluation metrics and their values.
    """
    if not device:
        device = get_torch_device()

    model.to(device)
    model.eval()

    # Initialize metrics storage
    metrics = {"mean": {}, "median": {}}

    for eval in evals:
        metrics["mean"][eval.value] = 0.0
        metrics["median"][eval.value] = []

    data_loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
    )

    song_metrics = {eval.value: {} for eval in evals}

    all_hypotheses = []
    all_references = []

    for j, (batch_features, batch_labels) in enumerate(tqdm(data_loader)):
        batch_features, batch_labels = batch_features.to(device), batch_labels.to(
            device
        )
        predictions = model.predict(batch_features).to(device)

        for i in range(batch_labels.shape[0]):  # Iterate over songs in the batch
            valid_mask = batch_labels[i] != -1

            if not valid_mask.any():
                continue
            filtered_references = batch_labels[i][valid_mask].cpu().numpy()
            filtered_hypotheses = predictions[i][valid_mask].cpu().numpy()

            # Map predictions to small vocabulary if necessary
            if not is_small:
                filtered_hypotheses = np.array([large_to_small_vocab_id(id) for id in filtered_hypotheses])

            all_hypotheses.extend(filtered_hypotheses)
            all_references.extend(filtered_references)

            ref_labels = np.vectorize(lambda id: id_to_chord(id, use_small_vocab=True))(filtered_references)
            hyp_labels = np.vectorize(lambda id: id_to_chord(id, use_small_vocab=True))(filtered_hypotheses)

            for eval in evals:
                if eval.get_eval_input_type() == "int":
                    song_eval_scores = eval.evaluate(
                        filtered_references, filtered_hypotheses
                    )
                else:
                    song_eval_scores = eval.evaluate(hyp_labels, ref_labels)

                # Filter out invalid scores that are -1 (produced by mir_eval with 'X' labels for example)
                song_eval_scores = song_eval_scores[song_eval_scores != -1]

                if i not in song_metrics[eval.value]:
                    song_metrics[eval.value][i] = []

                song_metrics[eval.value][i].append(np.mean(song_eval_scores))

    for eval in evals:
        song_scores = [np.mean(scores) for scores in song_metrics[eval.value].values()]

        metrics["mean"][eval.value] = np.mean(song_scores)
        metrics["median"][eval.value] = np.median(song_scores)

    # Flatten along song dimension
    all_hypotheses = np.array(all_hypotheses)
    all_references = np.array(all_references)

    if frame_wise_acc:
        metrics["frame_wise_acc"] = accuracy_score(all_references, all_hypotheses)

    if class_wise_acc:
        class_accs = np.full(NUM_CHORDS, np.nan)
        for i in range(NUM_CHORDS):
            # Find all references of class i
            mask = all_references == i
            class_references = all_references[mask]
            class_hypotheses = all_hypotheses[mask]

            if class_references.size > 0:
                class_accs[i] = accuracy_score(class_references, class_hypotheses)

        metrics["class_wise_acc_mean"] = np.nanmean(class_accs)  # Ignore NaNs
        metrics["class_wise_acc_median"] = np.nanmedian(class_accs)

    return metrics

In [5]:
DIR = f'./results_archive/small_vs_large_vocab'

device = get_torch_device()

small_exp = 'small_vocab_defaults'
small_model = ISMIR2017ACR(num_classes=25)
small_model.load_state_dict(torch.load(f'{DIR}/{small_exp}/best_model.pth', map_location=device, weights_only=True))
small_model.eval()

big_exp = 'large_vocab_defaults'
big_model = ISMIR2017ACR()
big_model.load_state_dict(torch.load(f'{DIR}/{big_exp}/best_model.pth', map_location=device, weights_only=True))

print('Models loaded')

_, val_filenames, _ = get_split_filenames()
val_dataset = FullChordDataset(val_filenames, small_vocab=True)

print('Loaded dataset')

Models loaded
Loaded dataset


In [19]:
small_metrics = evaluate_model_large_vs_small(small_model, is_small=True, dataset=val_dataset)

100%|██████████| 31/31 [03:59<00:00,  7.73s/it]


In [20]:
big_metrics = evaluate_model_large_vs_small(big_model, is_small=False, dataset=val_dataset)

100%|██████████| 31/31 [04:06<00:00,  7.94s/it]


In [21]:
small_metrics

{'mean': {'root': 0.7521009331113848,
  'majmin': 0.7266224277446486,
  'mirex': 0.7266224277446486,
  'third': 0.7266224277446486,
  'seventh': 0.7266224277446486,
  'song_wise_acc': 0.7266224277446486},
 'median': {'root': 0.752433221002297,
  'majmin': 0.7266083548858546,
  'mirex': 0.7266083548858546,
  'third': 0.7266083548858546,
  'seventh': 0.7266083548858546,
  'song_wise_acc': 0.7266083548858546},
 'frame_wise_acc': 0.721178526092597,
 'class_wise_acc_mean': 0.7100064014497233,
 'class_wise_acc_median': 0.7055723581706534}

In [22]:
big_metrics

{'mean': {'root': 0.748228778171669,
  'majmin': 0.7228633715382331,
  'mirex': 0.7228633715382331,
  'third': 0.7228633715382331,
  'seventh': 0.7228633715382331,
  'song_wise_acc': 0.7228633715382331},
 'median': {'root': 0.747905636814356,
  'majmin': 0.7190395911639591,
  'mirex': 0.7190395911639591,
  'third': 0.7190395911639591,
  'seventh': 0.7190395911639591,
  'song_wise_acc': 0.7190395911639591},
 'frame_wise_acc': 0.7174313334356299,
 'class_wise_acc_mean': 0.7123980491226508,
 'class_wise_acc_median': 0.730242825607064}

In [18]:
bad_filename = val_filenames[104]
bad_filename

'weezer_hashpipe'

In [19]:
from src.utils import get_raw_chord_annotation

bad_annotation = get_raw_chord_annotation(bad_filename)
bad_annotation

SortedKeyList([Observation(time=0.0, duration=0.8390000000000001, value='N', confidence=1.0), Observation(time=0.8390000000000001, duration=22.926000000000002, value='A:maj(*3)', confidence=1.0), Observation(time=23.765, duration=3.7870000000000004, value='D:maj(*3)', confidence=1.0), Observation(time=27.552000000000003, duration=3.874, value='A:maj(*3)', confidence=1.0), Observation(time=31.426000000000002, duration=3.8120000000000003, value='D:maj(*3)', confidence=1.0), Observation(time=35.238, duration=3.813, value='F:maj(*3)', confidence=1.0), Observation(time=39.051, duration=3.8470000000000004, value='E:maj(*3)', confidence=1.0), Observation(time=42.898, duration=3.85, value='D:maj(*3)', confidence=1.0), Observation(time=46.748000000000005, duration=3.814, value='C:maj(*3)', confidence=1.0), Observation(time=50.562000000000005, duration=3.785, value='B:maj(*3)', confidence=1.0), Observation(time=54.347, duration=22.917, value='A:maj(*3)', confidence=1.0), Observation(time=77.2640

In [22]:
chord_to_id('D:maj(*3)', use_small_vocab=True)

0

In [None]:
small_dataset = FullChordDataset(override_small_vocab=True)

small_all_preds = []
small_all_labels = []

with torch.no_grad():  # Use no_grad to speed up inference
    for i in tqdm(range(len(small_dataset))):
        cqt, label = small_dataset[i]
        pred = small_model(cqt.unsqueeze(0))
        preds = torch.argmax(pred, dim=2)
        small_all_preds.append(preds[0])  # Keep as tensors
        small_all_labels.append(label)    # Keep as tensors

# Concatenate all predictions and labels at the end
small_all_preds = torch.cat(small_all_preds)
small_all_labels = torch.cat(small_all_labels)

100%|██████████| 1213/1213 [04:22<00:00,  4.62it/s]


In [22]:
big_dataset = FullChordDataset()

big_all_preds = []

with torch.no_grad():  # Use no_grad to speed up inference
    for i in tqdm(range(len(big_dataset))):
        cqt, label = big_dataset[i]
        pred = big_model(cqt.unsqueeze(0))
        preds = torch.argmax(pred, dim=2)
        big_all_preds.append(preds[0])  # Keep as tensors

# Concatenate all predictions and labels at the end
big_all_preds = torch.cat(big_all_preds)

100%|██████████| 1213/1213 [04:34<00:00,  4.41it/s]


In [24]:
# Map large vocabulary predictions to small vocabulary
big_all_preds_small_vocab = torch.tensor([large_to_small_vocab_id(id.item()) for id in big_all_preds])

In [None]:
N_mask = small_all_labels != 0 # Mask out N chords

# Accuracy of small model on small dataset ignoring N (index 0)
small_all_preds_masked = small_all_preds[N_mask]
small_all_labels_masked = small_all_labels[N_mask]
small_correct = (small_all_preds_masked == small_all_labels_masked).sum().item()
small_total = small_all_labels_masked.size(0)
small_acc = small_correct / small_total
print(f'Small model accuracy on small dataset: {small_acc:.2f}')

# Accuracy of big model on small dataset ignoring N (index 0)
big_all_preds_masked = big_all_preds_small_vocab[N_mask]
big_correct = (big_all_preds_masked == small_all_labels_masked).sum().item()
big_total = small_all_labels_masked.size(0)
big_acc = big_correct / big_total
print(f'Big model accuracy on small dataset: {big_acc:.2f}')

Small model accuracy on small dataset: 0.81
Big model accuracy on small dataset: 0.78


In [29]:
# Accuracy of small model on small dataset without ignoring N
small_correct = (small_all_preds == small_all_labels).sum().item()
small_total = small_all_labels.size(0)
small_acc = small_correct / small_total
print(f'Small model accuracy on small dataset: {small_acc:.2f}')

# Accuracy of big model on small dataset without ignoring N
big_correct = (big_all_preds_small_vocab == small_all_labels).sum().item()
big_total = small_all_labels.size(0)
big_acc = big_correct / big_total
print(f'Big model accuracy on small dataset: {big_acc:.2f}')

Small model accuracy on small dataset: 0.77
Big model accuracy on small dataset: 0.73
