This notebook demonstrates the usage of the `musedetect` package with the MedleyDB dataset.


In [None]:
import logging
import sys
from collections import defaultdict
from datetime import timedelta
from pathlib import Path

import matplotlib.pyplot as plt
import medleydb_instruments as mdb
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchmetrics import Accuracy, ExactMatch, F1Score, Precision, Recall, Specificity

from musedetect.data import MedleyDBDataset, MedleyDBPreprocessor, get_all_instruments, train_test_split
from musedetect.data.preprocess_transforms import MFCCTransform
from musedetect.eval import compute_metrics
from musedetect.models import CnnAudioNet
from musedetect.training import FocalLossWithLogits, autodetect_device, train

## Dataset analysis for paper


In [None]:
# Directories holding the MedleyDB .wav data for the train and test splits.
# NOTE(review): absolute, machine-specific paths — point them at your local copy of the dataset.
MDB_WAV_PATH = "/media/data/linkaband_data/mdb_split/train"
MDB_WAV_PATH_TEST = "/media/data/linkaband_data/mdb_split/test"

In [None]:
# Every visible entry of the training directory (names starting with "." or "_" are
# skipped by the glob pattern) is looked up as a MedleyDB multitrack by its stem.
list_tracks = [entry.stem for entry in Path(MDB_WAV_PATH).glob("[!._]*")]
dataset = [mdb.MultiTrack(name) for name in list_tracks]

# Keep only the multitracks explicitly flagged as having no bleed.
new_dataset = [multitrack for multitrack in dataset if multitrack.has_bleed is False]

# Count, for each instrument, the number of kept tracks that list it.
track_counts = defaultdict(int)
for track in new_dataset:
    for instrument in track.instruments:
        track_counts[instrument] += 1

# Most frequent instrument first.
instruments = dict(sorted(track_counts.items(), key=lambda pair: pair[1], reverse=True))

sns.set_theme("paper")
sns.set_context("paper")
plt.figure(figsize=(10, 5))
instrument_names = list(instruments.keys())
tracks_per_instrument = np.fromiter(instruments.values(), dtype=int)
g = sns.barplot(x=instrument_names, y=tracks_per_instrument, color="b")
g.set_xticklabels(g.get_xticklabels(), rotation=90)
plt.ylabel("Number of tracks in which the instrument appears");

In [None]:
# Rebinds `instruments` (until here the per-instrument track counts used for the bar plot)
# to the value returned by get_all_instruments(); it is passed as `class_names` when the
# datasets are created further down.
instruments = get_all_instruments()
print(instruments)

In [None]:
# Send log records of level INFO and above to stdout so they appear in the cell outputs.
logging.basicConfig(format="%(levelname)s : %(message)s", level=logging.INFO, stream=sys.stdout)

## Create MFCC Dataset


To begin with, we transform the dataset of `.wav` audio files into a dataset of MFCC features. The preprocessing can be slow, so we compute the MFCC features once and write them to disk instead of computing them on the fly.


In [None]:
# Configuration of the .wav -> MFCC conversion that the preprocessor applies to each file.
transform = MFCCTransform(
    origin_sample_rate=44100,  # The sample rate of the .wav files data
    new_sample_rate=22050,  # Resample to this rate before generating the MFCC features
    window_size=timedelta(seconds=1),  # Duration of one data point cut out of a .wav file
    stride=timedelta(seconds=0.3),  # Presumably the offset between consecutive windows (< window_size, so they would overlap) — confirm
    n_mfcc=80,  # Number of MFCC bins
    melkwargs={
        "n_mels": 224,
        "n_fft": 2048,
        "f_max": 11025,  # Half of new_sample_rate (22050 / 2)
    },  # Arguments for the STFT and the Melspectrogram generation
)

In [None]:
# The preprocessor applies `transform` to the .wav data when `apply` is called on a directory.
preprocessor = MedleyDBPreprocessor(transform=transform)

Below, we indicate where the generated features (MFCCs) should be saved:


In [None]:
# Output directories for the generated MFCC features (train and test splits).
# NOTE(review): absolute, machine-specific paths — adjust to a writable location on your machine.
MDB_PATH = "/media/data/linkaband_data/mdb_split/train_features"
MDB_PATH_TEST = "/media/data/linkaband_data/mdb_split/test_features"

In [None]:
# Generate the MFCC features for the train split, then for the test split.
# A FileExistsError from `apply` (called with overwrite=False) is treated as
# "features already generated": a message is printed and the split is skipped.
for wav_dir, feature_dir, already_there_message in (
    (MDB_WAV_PATH, MDB_PATH, "Dataset already exists, not regenerating"),
    (MDB_WAV_PATH_TEST, MDB_PATH_TEST, "Test dataset already exists, not regenerating"),
):
    try:
        preprocessor.apply(wav_dir, feature_dir, overwrite=False)
    except FileExistsError:
        print(already_there_message)

## Create pytorch Dataset


Load the data into the pytorch dataset:


In [None]:
# `data` (later split into train/validation) and `test_data` are read from the pre-computed
# feature directories. Both receive the same `class_names` and hierarchy=True
# (presumably so that their label columns line up — confirm against MedleyDBDataset).
data = MedleyDBDataset(MDB_PATH, hierarchy=True, class_names=instruments)
test_data = MedleyDBDataset(MDB_PATH_TEST, hierarchy=True, class_names=instruments)

In [None]:
# Split `data` into a training and a validation part (presumably 80% / 20% — confirm the
# meaning of the list in train_test_split); the fixed seed keeps the split reproducible.
train_data, val_data = train_test_split(data, [0.8, 0.2], seed=42)

## Model


Create the model and move it to the GPU if one is available.


In [None]:
# Device used for the model, the loss weights, training and evaluation below.
device = autodetect_device()

In [None]:
# One network output per label: the instrument classes plus the aggregated classes.
n_outputs = len(data.class_names) + len(data.aggregated_class_names)
model = CnnAudioNet(class_num=n_outputs)
model.to(device)

# Report the model size (total number of parameter elements, in millions).
n_parameters = sum(p.numel() for p in model.parameters())
print(f"The model has {n_parameters / 1e6:.3f} million parameters")

## Training


In [None]:
batch_size = 64
num_workers = 8

# Training batches are reshuffled every epoch; without shuffle=True the DataLoader walks the
# dataset in the same fixed order each epoch.
# NOTE(review): shuffle=True requires a map-style dataset — confirm for train_test_split's output.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
# Evaluation loaders keep a fixed order.
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [None]:
# Adam over all model parameters with a learning rate of 5e-4.
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

Training loop:


In [None]:
# The weights can be used to re-weight a Cross-Entropy Loss, for instance
freq = torch.vstack(data.labels).float().sum(0) / torch.vstack(data.labels).float().sum()
weight = (1 / freq).to(device)
weight = torch.nan_to_num(weight, posinf=1.0)

In [None]:
# All validation metrics are computed over the full output vector
# (instrument classes followed by the aggregated classes).
num_labels = len(data.class_names) + len(data.aggregated_class_names)
val_metrics = [
    Accuracy(task="multilabel", num_labels=num_labels),
    ExactMatch(task="multilabel", num_labels=num_labels),
    F1Score(task="multilabel", average="micro", num_labels=num_labels),
]

# nn.BCEWithLogitsLoss(weight=weight) is an alternative - feel free to try out different loss functions
loss_fn = FocalLossWithLogits()

train(
    epochs=12,
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    device=device,
    train_loader=train_loader,
    val_loader=val_loader,
    val_metrics_freq=1,  # Compute metrics on the val set every 1 epochs
    metrics=val_metrics,
    log_dir="./logs/Medley",  # You can track the training progress using tensorboard --logdir ./logs/Medley
)

Export model


In [None]:
# Saves the whole module object (not only its state_dict) with pickle, so loading the file
# later requires the model's class to be importable.
torch.save(model, "my_model_name.pt")

## Analyze result


In [None]:
# Reload the exported model. map_location maps the stored tensors onto the selected device,
# so a checkpoint saved on a GPU machine can also be opened on a CPU-only one.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust. On
# PyTorch >= 2.6 the default is weights_only=True, which rejects a fully pickled module;
# pass weights_only=False there if the file comes from a trusted source.
model = torch.load("my_model_name.pt", map_location=device)
model.to(device);

Measure the model's performance on the test set:


In [None]:
n_instruments = len(data.class_names)
n_groups = len(data.aggregated_class_names)

# Index sets handed to compute_metrics for the "groups" and "instruments" levels.
# NOTE(review): both ranges leave out index 0, and the group range ends one short of
# n_instruments + n_groups + 1 — confirm this matches the label layout of MedleyDBDataset.
group_idx = list(range(n_instruments + 1, n_instruments + n_groups))
instrument_idx = list(range(1, n_instruments + 1))


def _micro_metrics(num_labels):
    """Return the micro-averaged multilabel metrics reported for one level."""
    return [
        F1Score(task="multilabel", average="micro", num_labels=num_labels),
        Precision(task="multilabel", average="micro", num_labels=num_labels),
        Recall(task="multilabel", average="micro", num_labels=num_labels),
        Accuracy(task="multilabel", average="micro", num_labels=num_labels),
        # ExactMatch has no `average` option (a sample either matches on all labels or not),
        # so it is built exactly like in the training cell.
        ExactMatch(task="multilabel", num_labels=num_labels),
    ]


# Number of labels scored at each level.
labels_per_level = {
    "flat": n_instruments + n_groups,
    "groups": len(group_idx),
    "instruments": len(instrument_idx),
}
metrics = {level: _micro_metrics(num_labels) for level, num_labels in labels_per_level.items()}

results = compute_metrics(
    model,
    test_loader,
    device,
    metrics=metrics,
    groups_idx=group_idx,
    instruments_idx=instrument_idx,
    show_progress=True,
)

# Print every metric value, level by level; results[level] is paired with metrics[level] by position.
for level in results:
    print(level)
    for metric, res in zip(metrics[level], results[level]):
        print(f"{metric.__class__.__name__}: {res.cpu().item()}")
    print()

## Plotting


In [None]:
n_instruments = len(data.class_names)
n_groups = len(data.aggregated_class_names)

# Index sets handed to compute_metrics for the "groups" and "instruments" levels.
group_idx = list(range(n_instruments + 1, n_instruments + n_groups))
instrument_idx = list(range(1, n_instruments + 1))


# Number of labels scored at each level.
labels_per_level = {
    "flat": n_instruments + n_groups,
    "groups": len(group_idx),
    "instruments": len(instrument_idx),
}

# average=None keeps one value per label instead of a single aggregate.
per_class_metric_types = (Accuracy, Precision, Recall, Specificity, F1Score)
metrics = {
    level: [
        metric_type(task="multilabel", average=None, num_labels=num_labels)
        for metric_type in per_class_metric_types
    ]
    for level, num_labels in labels_per_level.items()
}

# `metrics` is rebound to the computed results; the DataFrame cell reads metrics["instruments"].
metrics = compute_metrics(
    model,
    test_loader,
    device,
    metrics=metrics,
    groups_idx=group_idx,
    instruments_idx=instrument_idx,
)

In [None]:
# Re-read the training tracks from MDB_WAV_PATH (the same directory used for the dataset
# analysis plot) instead of repeating the path literal.
list_tracks = [file.stem for file in Path(MDB_WAV_PATH).glob("[!._]*")]
dataset = [mdb.MultiTrack(track_name) for track_name in list_tracks]
# Keep only the multitracks explicitly flagged as having no bleed.
new_dataset = [x for x in dataset if x.has_bleed is False]

# Number of kept tracks listing each instrument, keyed and ordered like data.class_names.
# NOTE(review): an instrument that is not a key of data.class_names raises KeyError here.
instrument_music_counts = dict.fromkeys(data.class_names, 0)
for track in new_dataset:
    for instrument in track.instruments:
        instrument_music_counts[instrument] += 1

In [None]:
n_instruments = len(data.class_names)


def _frames_per_instrument(loader):
    """Sum the first `n_instruments` label columns over every batch of `loader`."""
    totals = torch.zeros((n_instruments,))
    for _, labels in loader:
        totals += labels[:, :n_instruments].sum(0)
    return totals


instrument_frame_counts = _frames_per_instrument(train_loader)
test_instrument_frame_counts = _frames_per_instrument(test_loader)

In [None]:
# Column names follow the order in which the per-class metrics were created:
# Accuracy, Precision, Recall, Specificity, F1Score.
metric_names = ["accuracy", "precision", "recall", "specificity", "f1"]
per_instrument_values = [values.cpu().numpy() for values in metrics["instruments"]]

# One row per class; the first class name is displayed as "silence".
df = pd.DataFrame(
    dict(zip(metric_names, per_instrument_values)),
    index=np.array(["silence"] + list(data.class_names.keys())[1:]),
)

# Materialise the dict view and convert the tensors to NumPy explicitly, so the columns get
# a numeric dtype regardless of how the installed pandas version coerces these objects.
df["music_count"] = list(instrument_music_counts.values())
df["frame_count"] = instrument_frame_counts.cpu().numpy()
df["test_frame_count"] = test_instrument_frame_counts.cpu().numpy()
df["instrument"] = df.index
df = df.sort_values("frame_count", ascending=False)

In [None]:
sns.set_theme("paper")
sns.set_context("paper")
fig, ax = plt.subplots(figsize=(15, 5))

# Bars: number of training frames per instrument (left y-axis).
sns.barplot(x=df["instrument"], y=df["frame_count"], color="b", alpha=0.5, ax=ax)
# Rotate this axes' own tick labels; reusing the labels of the earlier bar plot would
# attach instrument names in a different order.
ax.tick_params(axis="x", rotation=90)
ax.grid(False)
# Axis labels are set on `ax`: the twin axes created below hides its x-axis.
ax.set_xlabel("Instrument")
ax.set_ylabel("Number of training samples")

# Lines: per-instrument precision and recall (right y-axis).
ax2 = ax.twinx()
df_plot = df[["instrument", "precision", "recall"]].melt("instrument", var_name="Metric", value_name="vals")
sns.lineplot(data=df_plot, x="instrument", y="vals", hue="Metric", ax=ax2, linewidth=1, marker="o")
ax2.set_ylabel("Metric value")
# NOTE(review): the title names a loss, while the training cell uses FocalLossWithLogits —
# confirm it matches the run being plotted.
ax2.set_title("Cross-entropy loss");