In [1]:
# reload
%load_ext autoreload
%autoreload 2

In [None]:
import autorootcwd
from tqdm import tqdm
from enum import Enum
import torch
from torch.utils.data import DataLoader
import mir_eval

from src.models.base import BaseACRModel
from src.data.dataset import ChordDataset
from src.utils import id_to_chord

class EvalMetric(Enum):
    """
    Defines an ENUM of evaluation metrics used in this project from the mir_eval library.

    Attributes:
        ROOT: Chord root evaluation metrics.
        MAJMIN: Major/minor evaluation metrics.
        MIREX: MIREX evaluation metric - returns 1 if at least three pitches are in common.
        THIRD: Third evaluation metric - returns 1 if the third is in common.
        SEVENTH: Seventh evaluation metric - returns 1 if the seventh is in common.
    """

    def __str__(self):
        return self.name
    
    ROOT = "root"
    MAJMIN = "majmin"
    MIREX = "mirex"
    THIRD = "third"
    SEVENTH = "seventh"

    def eval_func(self):
        if self == EvalMetric.ROOT:
            return mir_eval.chord.root
        elif self == EvalMetric.MAJMIN:
            return mir_eval.chord.majmin
        elif self == EvalMetric.MIREX:
            return mir_eval.chord.mirex
        elif self == EvalMetric.THIRD:
            return mir_eval.chord.thirds
        elif self == EvalMetric.SEVENTH:
            return mir_eval.chord.sevenths
        else:
            raise ValueError(f"Invalid evaluation metric: {self}")

    def evaluate(self, hypotheses: torch.Tensor, references: torch.Tensor):
        """
        Evaluate a model on a dataset split using a list of evaluation metrics.

        Args:
            hypotheses (torch.Tensor): The model's chord predictions as ids. Shape (B, frames)
            references (torch.Tensor): The ground truth chord labels as ids. Shape (B, frames)

        Returns:
            metrics (torch.Tensor): A tensor of evaluation metrics and their values. Shape (B, frames)
        """
        
        # Initialize the evaluation metrics tensor. Shape (num_batches, num_frames)
        metrics = torch.zeros_like(hypotheses, dtype=torch.float32)

        # Iterate over the batch of chord predictions and ground truth labels
        ref_labels = []
        hyp_labels = []
        for i in range(hypotheses.shape[0]):
            # Convert the chord labels from indices to strings
            ref_labels.extend([id_to_chord(id) for id in references[i]])
            hyp_labels.extend([id_to_chord(id) for id in hypotheses[i]])

        # Evaluate the chord labels using the evaluation metric
        metrics = torch.from_numpy(self.eval_func()(ref_labels, hyp_labels))
        
        # Reshape the metrics tensor to include the batch dimension again
        metrics = metrics.reshape(hypotheses.shape)
        
        return metrics

def evaluate_model(
        model: BaseACRModel,
        dataset: ChordDataset,
        evals: list[EvalMetric] = [EvalMetric.ROOT, EvalMetric.MAJMIN, EvalMetric.MIREX, EvalMetric.THIRD, EvalMetric.SEVENTH]
    ) -> dict[str, float]:
    """
    Evaluate a model on a dataset split using a list of evaluation metrics.

    Args:
        model (BaseACRModel): The model to evaluate.
        dataset (ChordDataset): The dataset to evaluate on.
        evals (list[EvalMetrics]): The evaluation metrics to use. Defaults to [EvalMetrics.ROOT, EvalMetrics.MAJMIN, EvalMetrics.MIREX, EvalMetrics.CHORD_OVERLAP, EvalMetrics.CHORD_LABEL].

    Returns:
        metrics (dict[str, float]): A dictionary of evaluation metrics and their values.
    """

    # Initialize the evaluation metrics dictionary
    metrics = {}

    data_loader = DataLoader(dataset, batch_size=1, shuffle=False)

    for eval in evals:
        # Initialize the evaluation metric
        metrics[eval.value] = 0.0

    # Evaluate the model on the data loader
    for batch_features, batch_labels in tqdm(data_loader):

        # Get the chord predictions from the model
        predictions = model.predict(batch_features)

        # Evaluate the model on the sample using the evaluation metrics
        for eval in evals:
            metrics[eval.value] += torch.mean(eval.evaluate(predictions, batch_labels))

    # Calculate the average evaluation metrics
    for eval in evals:
        metrics[eval.value] /= len(dataset)
    
    return metrics

In [95]:
from src.data.dataset import ChordDataset
from torch.utils.data import random_split

dataset = ChordDataset()

# Split the dataset into train and test
train_size = int(0.999 * len(dataset))
test_size = len(dataset) - train_size

torch.manual_seed(42)
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

In [96]:
from src.models.random_baseline import RandomACR

# Evaluate the random model on the test dataset
model = RandomACR()

metrics = evaluate_model(model, test_dataset)
metrics

  0%|          | 0/2 [00:00<?, ?it/s]

torch.Size([1, 1712])
torch.Size([1, 1712])
torch.Size([1, 1712])
torch.Size([1, 1712])
torch.Size([1, 1712])
torch.Size([1, 1712])


 50%|█████     | 1/2 [00:00<00:00,  2.22it/s]

torch.Size([1, 1712])
torch.Size([1, 1712])
torch.Size([1, 1712])
torch.Size([1, 1712])
torch.Size([1, 3880])
torch.Size([1, 3880])
torch.Size([1, 3880])
torch.Size([1, 3880])
torch.Size([1, 3880])
torch.Size([1, 3880])
torch.Size([1, 3880])
torch.Size([1, 3880])


100%|██████████| 2/2 [00:01<00:00,  1.19it/s]

torch.Size([1, 3880])
torch.Size([1, 3880])





{'root': tensor(0.0763, dtype=torch.float64),
 'majmin': tensor(0.0395, dtype=torch.float64),
 'mirex': tensor(0.0395, dtype=torch.float64),
 'third': tensor(0.0395, dtype=torch.float64),
 'seventh': tensor(0.0395, dtype=torch.float64)}