In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import autorootcwd
import torch
import os
import numpy as np
from tqdm import tqdm
from typing import Dict, List, Optional

import matplotlib.pyplot as plt
from torch.utils.data.dataloader import DataLoader
from sklearn.metrics import  accuracy_score

from src.utils import get_torch_device, collate_fn, NUM_CHORDS, get_split_filenames
from src.models.crnn import CRNN
from src.models.base_model import BaseACR
from src.data.dataset import FullChordDataset
from src.eval import EvalMetric

In [3]:
from functools import lru_cache
from src.utils import chord_to_id, id_to_chord

@lru_cache(maxsize=None)
def large_to_small_vocab_id(id: int) -> int:
    """
    Map a chord id from the large vocabulary onto the small vocabulary.

    The id is decoded to its chord label using the large vocabulary and then
    re-encoded using the small vocabulary. Results are memoised in an unbounded
    cache, because this is called once per predicted frame elsewhere in the notebook.
    This assumes the mapping depends only on the id.

    Args:
        id (int): Chord id in the large vocabulary.

    Returns:
        int: The corresponding chord id in the small vocabulary.
    """
    return chord_to_id(id_to_chord(id, use_small_vocab=False), use_small_vocab=True)

## Evaluation Helper and Model Loading

In [5]:
from src.data.dataset import IndexedDataset
from src.utils import collate_fn_indexed, chord_to_id, get_chord_seq, get_synthetic_chord_seq 
import mir_eval
from src.eval import bootstrap_mean_ci, compute_aggregated_class_metric

def evaluate_model_large_vs_small(
    model: BaseACR,
    dataset: FullChordDataset,
    evals: Optional[List[EvalMetric]] = None,
    is_small: bool = False,
    batch_size: int = 32,
    device: Optional[torch.device] = None,
    log_calibration: Optional[torch.Tensor] = None,
) -> dict:
    """
    Evaluate a model using continuous, song-based metrics computed with mir_eval.

    All scoring happens in the small vocabulary. If ``is_small`` is False, the model's
    predictions are first mapped with ``large_to_small_vocab_id``. This lets a small- and
    a large-vocabulary model be compared on the same label set.

    Args:
        model: Model exposing ``predict(...)``. It is moved to ``device`` and put in eval mode.
        dataset: Dataset queried per index for filename, beats and synthetic flag. Its labels
            are compared against the small-vocabulary "X" id, so they are presumably
            small-vocabulary labels (TODO confirm for other datasets).
        evals: Metrics to compute. Defaults to ROOT, MIREX, THIRD, SEVENTH, MAJMIN and ACC.
        is_small: True if ``model`` already predicts small-vocabulary ids.
        batch_size: DataLoader batch size used for prediction.
        device: Inference device. Defaults to ``get_torch_device()``.
        log_calibration: Forwarded unchanged to ``model.predict``.

    Returns:
        A dict with these keys:
        - "mean", "median", "std", "boostrap-stde" and "bootstrap-95ci": each maps
          ``metric.value`` to a per-song statistic. The "boostrap-stde" spelling is kept
          for existing consumers.
        - "avg_transitions_per_song": mean number of label changes in the predictions.
        - "class_wise": class-aggregated mean/median per metric.
    """
    if evals is None:
        # Built per call instead of as a list default, so calls never share one mutable default.
        evals = [
            EvalMetric.ROOT,
            EvalMetric.MIREX,
            EvalMetric.THIRD,
            EvalMetric.SEVENTH,
            EvalMetric.MAJMIN,
            EvalMetric.ACC,
        ]
    if device is None:
        device = get_torch_device()
    model.to(device)
    model.eval()

    data_loader = DataLoader(
        IndexedDataset(dataset),
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn_indexed,
    )

    # Loop invariants, computed once instead of once per batch.
    x_id = chord_to_id("X", use_small_vocab=True)
    # Explicit otypes: without it, np.vectorize infers the dtype from the first element and
    # raises on size-0 inputs.
    to_small_vocab = np.vectorize(large_to_small_vocab_id, otypes=[np.int64])
    uses_gen_features = getattr(model, "use_generative_features", False)

    print("Evaluating model...")
    song_predictions = []

    # Scoped no_grad instead of torch.set_grad_enabled(False). The latter would leave autograd
    # disabled for the rest of the kernel session after this function returns.
    with torch.no_grad():
        for (batch_cqts, batch_gens, batch_labels), indices in tqdm(
            data_loader, desc="Predicting"
        ):
            batch_cqts = batch_cqts.to(device)
            if batch_gens is not None and batch_gens.nelement() > 0:
                batch_gens = batch_gens.to(device)
            batch_labels = batch_labels.to(device)

            # Frames given to the decoder: -1 (presumably padding) and "X" labels are masked out.
            valid_mask = (batch_labels != -1) & (batch_labels != x_id)
            # Frames kept for scoring: only -1 labels are dropped. Moved to CPU once per batch.
            keep_mask = (batch_labels != -1).cpu().numpy()

            if uses_gen_features:
                predictions = model.predict(
                    batch_cqts, batch_gens, mask=valid_mask, device=device, log_calibration=log_calibration
                )
            else:
                crf_mask = valid_mask.clone()
                crf_mask[:, 0] = True  # Ensure the first frame is always valid
                predictions = model.predict(
                    batch_cqts, mask=crf_mask, device=device, log_calibration=log_calibration
                )

            predictions = predictions.cpu().numpy()
            if not is_small:
                # Convert large vocabulary predictions to small vocabulary
                predictions = to_small_vocab(predictions)

            for i in range(predictions.shape[0]):
                song_predictions.append(
                    {"pred_ids": predictions[i][keep_mask[i]].tolist(), "idx": indices[i]}
                )

    song_metric_scores = {m: [] for m in evals}
    song_transition_counts = []
    song_agg_data = []

    for song in tqdm(song_predictions, desc="Evaluating"):
        idx = song["idx"]
        filename = dataset.get_filename(idx)
        is_synthetic = dataset.is_synthetic(idx)
        pred_labels = [id_to_chord(x, use_small_vocab=True) for x in song["pred_ids"]]

        # Beat boundaries on which the predictions are defined (consecutive pairs form the
        # intervals that the predicted labels are attached to below).
        est_beats = dataset.get_beats(idx)

        # Ground-truth chord labels together with their own beat boundaries.
        if is_synthetic:
            ref_labels, ref_beats = get_synthetic_chord_seq(
                filename, override_dir=f"{dataset.synthetic_input_dir}/chords", use_small_vocab=True
            )
        else:
            ref_labels, ref_beats = get_chord_seq(
                filename, override_dir=f"{dataset.input_dir}/chords", use_small_vocab=True
            )

        # Convert beat boundaries into intervals.
        est_intervals = np.column_stack((est_beats[:-1], est_beats[1:]))
        ref_intervals = np.column_stack((ref_beats[:-1], ref_beats[1:]))

        # Stretch/trim the estimated intervals to the reference time span, padding with N.
        adjusted_est_intervals, est_labels = mir_eval.util.adjust_intervals(
            est_intervals,
            pred_labels,
            ref_intervals.min(),
            ref_intervals.max(),
            mir_eval.chord.NO_CHORD,
            mir_eval.chord.NO_CHORD,
        )

        merged_intervals, merged_ref, merged_est = mir_eval.util.merge_labeled_intervals(
            ref_intervals, ref_labels, adjusted_est_intervals, est_labels
        )
        durations = np.array(mir_eval.util.intervals_to_durations(merged_intervals))
        merged_ref = np.array(merged_ref)
        merged_est = np.array(merged_est)

        # Mask out X chords (unknown references are not scored).
        mask_no_X = merged_ref != "X"
        merged_ref = merged_ref[mask_no_X]
        merged_est = merged_est[mask_no_X]
        durations = durations[mask_no_X]

        # Save aggregated data for class-wise metrics.
        song_agg_data.append(
            {"merged_ref": merged_ref, "merged_est": merged_est, "durations": durations}
        )

        for e in evals:
            comp = e.evaluate(hypotheses=merged_est, references=merged_ref)
            song_metric_scores[e].append(mir_eval.chord.weighted_accuracy(comp, durations))

        # Number of label changes between consecutive predicted frames.
        song_transition_counts.append(
            sum(prev != curr for prev, curr in zip(pred_labels, pred_labels[1:]))
        )

    results = {
        "mean": {},
        "median": {},
        "std": {},
        "boostrap-stde": {},  # (sic) misspelled key kept for existing consumers of the results
        "bootstrap-95ci": {},
    }
    for m in evals:
        scores = song_metric_scores[m]
        results["mean"][m.value] = np.mean(scores)
        results["median"][m.value] = np.median(scores)
        results["std"][m.value] = np.std(scores)
        _, se, ci = bootstrap_mean_ci(scores, num_bootstrap=10000, ci=95)
        results["boostrap-stde"][m.value] = se
        results["bootstrap-95ci"][m.value] = ci

    results["avg_transitions_per_song"] = np.mean(song_transition_counts)

    class_agg_results = {}
    for e in tqdm(evals, desc="Class-wise metrics"):
        # Overall metric aggregated per chord class, reduced with mean and with median.
        class_agg_results[e.value] = {
            "mean": compute_aggregated_class_metric(song_agg_data, e, np.mean),
            "median": compute_aggregated_class_metric(song_agg_data, e, np.median),
        }
    results["class_wise"] = class_agg_results

    return results

In [6]:
# Root folder with one sub-directory per experiment, each holding a best_model.pth checkpoint.
DIR = './results/small_vs_large_vocab'

device = get_torch_device()


def load_crnn_checkpoint(exp: str, **model_kwargs) -> CRNN:
    """
    Build a CRNN and load the checkpoint ``{DIR}/{exp}/best_model.pth`` into it.

    Args:
        exp: Experiment sub-directory name under ``DIR``.
        **model_kwargs: Forwarded to the ``CRNN`` constructor (e.g. ``num_classes``).

    Returns:
        The CRNN with its weights loaded, switched to eval mode.
    """
    model = CRNN(**model_kwargs)
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(f'{DIR}/{exp}/best_model.pth', map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    # Both models are called directly in later cells, so both must be in eval mode.
    model.eval()
    return model


small_exp = 'small'
# NOTE(review): 26 is presumably the small-vocabulary size (e.g. 12 maj + 12 min + N + X) — confirm against src.utils.
small_model = load_crnn_checkpoint(small_exp, num_classes=26)

big_exp = 'large'
# Uses CRNN's default num_classes, presumably the large vocabulary — TODO confirm.
big_model = load_crnn_checkpoint(big_exp)

print('Models loaded')

_, val_filenames, _ = get_split_filenames()
val_dataset = FullChordDataset(val_filenames, small_vocab=True, dev_mode=True)

print('Loaded dataset')

Models loaded
Loaded dataset


  return self.fget.__get__(instance, owner)()


In [7]:
# Song-level metrics for the small-vocabulary model; its predictions are already small-vocab ids.
small_metrics = evaluate_model_large_vs_small(
    model=small_model,
    dataset=val_dataset,
    is_small=True,
)

Evaluating model...


Predicting: 100%|██████████| 8/8 [00:39<00:00,  4.88s/it]
Evaluating: 100%|██████████| 241/241 [00:46<00:00,  5.20it/s]
Class-wise metrics: 100%|██████████| 6/6 [01:01<00:00, 10.18s/it]


In [9]:
# Large-vocabulary model: is_small=False maps its predictions to the small vocabulary before scoring.
big_metrics = evaluate_model_large_vs_small(
    model=big_model,
    dataset=val_dataset,
    is_small=False,
)

Evaluating model...


Predicting: 100%|██████████| 8/8 [00:52<00:00,  6.60s/it]
Evaluating: 100%|██████████| 241/241 [00:28<00:00,  8.33it/s]
Class-wise metrics: 100%|██████████| 6/6 [00:42<00:00,  7.10s/it]


In [8]:
# Summary for the small-vocabulary model: per-metric mean/median/std and bootstrap statistics,
# the average number of transitions per song, and class-wise aggregates.
# NOTE(review): in the output, third/seventh/majmin/acc coincide in mean/median/std. This is
# plausibly because the small vocabulary has only major/minor triads — confirm.
small_metrics

{'mean': {'root': 80.16192820434149,
  'mirex': 79.00230595404003,
  'third': 76.69456422683054,
  'seventh': 76.69456422683054,
  'majmin': 76.69456422683054,
  'acc': 76.69456422683054},
 'median': {'root': 83.31648483998393,
  'mirex': 81.433310167812,
  'third': 79.37455087138952,
  'seventh': 79.37455087138952,
  'majmin': 79.37455087138952,
  'acc': 79.37455087138952},
 'std': {'root': 12.542917903212135,
  'mirex': 13.114700775773693,
  'third': 14.876969519385401,
  'seventh': 14.876969519385401,
  'majmin': 14.876969519385401,
  'acc': 14.876969519385401},
 'boostrap-stde': {'root': 0.8031619667156519,
  'mirex': 0.8533863511352406,
  'third': 0.9473519619129844,
  'seventh': 0.9572588406109518,
  'majmin': 0.9551846338964388,
  'acc': 0.9460939939578605},
 'bootstrap-95ci': {'root': (78.52032294438204, 81.67142437076724),
  'mirex': (77.3137583783622, 80.617871971463),
  'third': (74.78449657570361, 78.49904217330506),
  'seventh': (74.77221771452572, 78.52058318445792),
  'm

In [11]:
# Same summary for the large-vocabulary model, scored after mapping its predictions to the
# small vocabulary, so it is directly comparable with small_metrics.
big_metrics

{'mean': {'root': 79.09213953275189,
  'mirex': 79.18775134956911,
  'third': 76.0492967531169,
  'seventh': 76.0492967531169,
  'majmin': 76.0492967531169,
  'acc': 76.0492967531169},
 'median': {'root': 81.60258834509975,
  'mirex': 81.5537335097427,
  'third': 78.69716067380483,
  'seventh': 78.69716067380483,
  'majmin': 78.69716067380483,
  'acc': 78.69716067380483},
 'std': {'root': 13.144994095227174,
  'mirex': 12.890755898531689,
  'third': 14.750633999945409,
  'seventh': 14.750633999945409,
  'majmin': 14.750633999945409,
  'acc': 14.750633999945409},
 'boostrap-stde': {'root': 0.844519794998979,
  'mirex': 0.8318806724249199,
  'third': 0.9462478501822312,
  'seventh': 0.95020127736229,
  'majmin': 0.9594151543115406,
  'acc': 0.9488888860724315},
 'bootstrap-95ci': {'root': (77.42131911805222, 80.73680403015216),
  'mirex': (77.53164221118799, 80.79368282629916),
  'third': (74.09498418290353, 77.8253978220342),
  'seventh': (74.12908848792135, 77.82635153842561),
  'majmi

In [None]:
# Frame-level predictions of the small model for every song in small_dataset.
# NOTE(review): no filename list is passed, so this iterates the dataset's default song set
# (1213 songs per the progress bar), not the validation split used above — confirm intended.
# NOTE(review): `override_small_vocab=True` differs from the `small_vocab=True` keyword used for
# val_dataset — verify it is still an accepted argument.
small_dataset = FullChordDataset(override_small_vocab=True)

# evaluate_model_large_vs_small moved the model to `device`; send inputs to wherever its weights live.
small_model_device = next(small_model.parameters()).device

small_all_preds = []
small_all_labels = []

with torch.no_grad():  # Use no_grad to speed up inference
    for i in tqdm(range(len(small_dataset))):
        cqt, label = small_dataset[i]
        pred = small_model(cqt.unsqueeze(0).to(small_model_device))
        # argmax over dim 2 — presumably (batch, frames, classes); TODO confirm the CRNN output layout.
        preds = torch.argmax(pred, dim=2)
        # Keep everything on CPU so predictions and labels can be concatenated and compared.
        small_all_preds.append(preds[0].cpu())
        small_all_labels.append(label.cpu())

# Concatenate all predictions and labels at the end
small_all_preds = torch.cat(small_all_preds)
small_all_labels = torch.cat(small_all_labels)

100%|██████████| 1213/1213 [04:22<00:00,  4.62it/s]


In [22]:
# Frame-level predictions of the large-vocabulary model.
# NOTE(review): assumed to yield the same songs, in the same order and with the same frame counts,
# as small_dataset — the accuracy cells compare the two frame by frame.
big_dataset = FullChordDataset()

# evaluate_model_large_vs_small moved the model to `device`; send inputs to wherever its weights live.
big_model_device = next(big_model.parameters()).device

big_all_preds = []

with torch.no_grad():  # Use no_grad to speed up inference
    for i in tqdm(range(len(big_dataset))):
        cqt, _ = big_dataset[i]  # labels unused: references come from small_dataset
        pred = big_model(cqt.unsqueeze(0).to(big_model_device))
        preds = torch.argmax(pred, dim=2)
        big_all_preds.append(preds[0].cpu())  # back to CPU for comparison with the labels

# Concatenate all per-song predictions into one frame sequence
big_all_preds = torch.cat(big_all_preds)

100%|██████████| 1213/1213 [04:34<00:00,  4.41it/s]


In [24]:
# Map large vocabulary predictions to small vocabulary.
# Each distinct id is mapped once, then the results are gathered back per frame through the
# inverse index. This avoids one Python-level `.item()` call per frame.
unique_large_ids, inverse_idx = torch.unique(big_all_preds.cpu(), return_inverse=True)
small_ids_for_unique = torch.tensor(
    [large_to_small_vocab_id(large_id) for large_id in unique_large_ids.tolist()],
    dtype=torch.long,
)
big_all_preds_small_vocab = small_ids_for_unique[inverse_idx]

In [None]:
# Frame-level accuracy of both models against the small-vocabulary labels, ignoring frames whose
# reference label is "N" (no chord). The id comes from chord_to_id instead of a hard-coded 0, the
# same way evaluate_model_large_vs_small looks up the "X" id.
no_chord_id = chord_to_id("N", use_small_vocab=True)

# The comparison is frame by frame, so both prediction streams must line up with the labels.
assert small_all_preds.shape == small_all_labels.shape, "small-model predictions and labels are misaligned"
assert big_all_preds_small_vocab.shape == small_all_labels.shape, "big-model predictions and small labels are misaligned"

N_mask = small_all_labels != no_chord_id  # Mask out N chords
small_all_labels_masked = small_all_labels[N_mask]

# Accuracy of small model on small dataset ignoring N
small_all_preds_masked = small_all_preds[N_mask]
small_correct = (small_all_preds_masked == small_all_labels_masked).sum().item()
small_total = small_all_labels_masked.size(0)
small_acc = small_correct / small_total if small_total else float('nan')  # every frame could be N
print(f'Small model accuracy on small dataset: {small_acc:.2f}')

# Accuracy of big model on small dataset ignoring N
big_all_preds_masked = big_all_preds_small_vocab[N_mask]
big_correct = (big_all_preds_masked == small_all_labels_masked).sum().item()
big_total = small_all_labels_masked.size(0)
big_acc = big_correct / big_total if big_total else float('nan')
print(f'Big model accuracy on small dataset: {big_acc:.2f}')

Small model accuracy on small dataset: 0.81
Big model accuracy on small dataset: 0.78


In [29]:
# Frame-level accuracy over *all* frames (N included) for both models, against the small-vocab labels.

# Small model, whose predictions are already small-vocabulary ids.
small_total = len(small_all_labels)
small_correct = small_all_preds.eq(small_all_labels).sum().item()
small_acc = small_correct / small_total
print(f'Small model accuracy on small dataset: {small_acc:.2f}')

# Big model, using its predictions mapped to the small vocabulary.
big_total = len(small_all_labels)
big_correct = big_all_preds_small_vocab.eq(small_all_labels).sum().item()
big_acc = big_correct / big_total
print(f'Big model accuracy on small dataset: {big_acc:.2f}')

Small model accuracy on small dataset: 0.77
Big model accuracy on small dataset: 0.73
