This notebook uses `musedetect` to train hierarchical models, with one model for each instrument group.
It is more advanced: if you're just getting started, check out the simpler [example_medleydb](./example_medleydb.ipynb) notebook first.


In [None]:
import logging
import sys
from datetime import timedelta
from pathlib import Path

import matplotlib.pyplot as plt
import medleydb_instruments as mdb
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchmetrics import Accuracy, ExactMatch, F1Score, Precision, Recall, Specificity, StatScores

from musedetect.data import MedleyDBDataset, MedleyDBPreprocessor, get_all_instruments, train_test_split
from musedetect.data.preprocess_transforms import MFCCTransform
from musedetect.eval import compute_metrics
from musedetect.models import CnnAudioNet
from musedetect.training import autodetect_device, train

In [None]:
# Directories containing the MedleyDB audio for the train and test splits.
# They are the inputs of the preprocessing step further down.
MDB_WAV_PATH = "/media/data/linkaband_data/mdb_split/train"
MDB_WAV_PATH_TEST = "/media/data/linkaband_data/mdb_split/test"

In [None]:
# Instrument names; this list is passed as `class_names` when the datasets are created.
instruments = get_all_instruments()

print(instruments)

In [None]:
# Emit log records of level INFO and above on stdout, formatted as "LEVEL : message".
logging.basicConfig(format="%(levelname)s : %(message)s", level=logging.INFO, stream=sys.stdout)

## Create MFCC Dataset


To begin with, we transform the dataset of `.wav` audio files into a dataset of MFCC features. The preprocessing can be slow, so we write the MFCC features to disk instead of computing them on the fly.


In [None]:
# MFCC feature extraction settings applied to every audio file by the preprocessor.
transform = MFCCTransform(
    origin_sample_rate=44100,  # The sample rate of the .wav files data
    new_sample_rate=22050,  # Resample to this rate before generating the MFCC features
    window_size=timedelta(seconds=1),  # How the .wav file is split into data points (presumably the length of one data point)
    stride=timedelta(seconds=1),  # How the .wav file is split into data points (presumably the step between data points; same value as window_size here)
    n_mfcc=80,  # Number of MFCC bins
    melkwargs={
        "n_mels": 224,
        "n_fft": 2048,
        "f_max": 11025,  # Half of new_sample_rate (22050 / 2)
    },  # Arguments for the STFT and the Melspectrogram generation
)

In [None]:
# Preprocessor that applies the MFCC transform; it is run on both splits below.
preprocessor = MedleyDBPreprocessor(transform=transform)

In [None]:
# Output directories of the preprocessing step (train / test); the datasets are loaded from here.
MDB_PATH = "/media/data/linkaband_data/mdb_split/train_features"
MDB_PATH_TEST = "/media/data/linkaband_data/mdb_split/test_features"

In [None]:
# Run the preprocessing for the train split, then the test split. A FileExistsError
# from `apply` is reported with a message and the split is left as it is.
for wav_dir, feature_dir, skip_message in (
    (MDB_WAV_PATH, MDB_PATH, "Dataset already exists, not regenerating"),
    (MDB_WAV_PATH_TEST, MDB_PATH_TEST, "Test dataset already exists, not regenerating"),
):
    try:
        preprocessor.apply(wav_dir, feature_dir, overwrite=False)
    except FileExistsError:
        print(skip_message)

## Create pytorch Dataset


Load the data into the pytorch dataset:


In [None]:
# Datasets backed by the precomputed feature directories. `hierarchy=True` presumably
# makes the labels include the instrument-group columns used later on — TODO confirm.
data = MedleyDBDataset(MDB_PATH, hierarchy=True, class_names=instruments)
test_data = MedleyDBDataset(MDB_PATH_TEST, hierarchy=True, class_names=instruments)

We want one dataset per model, with custom labels.


In [None]:
class CustomDataset:
    """Map-style dataset pairing feature files stored on disk with explicit labels.

    Args:
        root: directory the entries of `files` are relative to (joined with `/`).
        transform: optional callable applied to every loaded item, or None.
        files: sequence of file names, one per sample.
        labels: sequence of labels, indexed like `files`.
    """

    def __init__(self, root, transform, files, labels):
        self.root, self.transform = root, transform
        self.files, self.labels = files, labels

    def __len__(self):
        # The number of samples is driven by the file list
        return len(self.files)

    def __getitem__(self, index):
        # Load the stored tensor and add a leading dimension of size 1
        features = torch.load(self.root / self.files[index]).unsqueeze(0)
        if self.transform is None:
            return features, self.labels[index]
        return self.transform(features), self.labels[index]

In [None]:
# Convert from instrument group/category to the indices of the instruments in that category.
# `cat2instridx[g]` holds positions into `data.class_names[1:]` (the first class is skipped)
# for the group whose aggregated index is `g + 1` (aggregated index 0 is skipped as well).
# The number of groups is derived from the dataset instead of being hard-coded.
num_groups = len(data.aggregated_class_names) - 1
cat2instridx = [[] for _ in range(num_groups)]
for i, idx in enumerate(data.class_name_to_aggregated_idx[1:]):
    cat2instridx[idx - 1].append(i)

In [None]:
# Split the training dataset into train and validation parts with weights 0.8 / 0.2 and a fixed seed.
train_data, val_data = train_test_split(data, [0.8, 0.2], seed=42)

In [None]:
def generate_datasets(data):
    """
    Generate individual datasets for each instrumental group.

    Args:
        data: a subset object exposing `.dataset` (the underlying dataset, with `root`,
            `transform`, `files`, `labels`, `class_names` and `aggregated_class_names`)
            and `.indices` (the samples belonging to the subset).

    Returns:
        A dict with one `CustomDataset` per entry of `aggregated_class_names[1:]`, whose
        labels are restricted to the instrument columns listed in `cat2instridx` for that
        group, plus a "global" `CustomDataset` whose labels are the columns located after
        the instrument columns (the first tree level, aka instrument groups).
        All datasets share the same files.
    """
    base = data.dataset
    num_classes = len(base.class_names)

    # Only get labels from the subset
    labels = torch.vstack(base.labels)[data.indices]
    instr_labels = labels[:, 1:num_classes]
    class_labels = labels[:, num_classes:]

    # Only get files from the subset; every dataset receives the same plain list
    files = np.array(base.files)[data.indices].tolist()

    datasets = {
        category: CustomDataset(base.root, base.transform, files, instr_labels[:, cat2instridx[i]])
        for i, category in enumerate(base.aggregated_class_names[1:])
    }

    # Add dataset for the first tree level, aka instrument groups
    datasets["global"] = CustomDataset(base.root, base.transform, files=files, labels=class_labels)

    return datasets


# One dict of datasets (per instrument group, plus "global") for each side of the split.
train_datasets = generate_datasets(train_data)
val_datasets = generate_datasets(val_data)

## Model


Generate the model, and move it to the GPU if available


In [None]:
# Device used for training and evaluation, and the batch size shared by all data loaders.
device = autodetect_device()
batch_size = 32

In [None]:
# Seeding doesn't guarantee deterministic results, possibly because of nondeterministic GPU operations
# torch.random.manual_seed(42)

# Configure number of training epochs per model
epochs_per_model = {
    "11 Struck idiophones": 10,
    "21 Struck membranophones": 15,
    "31 Simple chordophones": 15,
    "32 Composite chordophones": 15,
    "41 Free aerophones": 15,
    "42 Non-free aerophones": 15,
    "53 Radioelectric instruments": 15,
    "global": 15,
}

# Train one independent model per entry and save each of them to `multi_<name>.pt`.
for cat_name, epoch_num in epochs_per_model.items():
    print(f"Training on {cat_name}")
    local_train_data = train_datasets[cat_name]
    local_val_data = val_datasets[cat_name] if val_datasets else None

    # One model output per label column of this dataset
    num_labels = len(local_train_data.labels[0])
    model = CnnAudioNet(num_labels)
    model.to(device)

    # shuffle=True: the training samples are reshuffled at every epoch
    train_loader = DataLoader(
        local_train_data,
        batch_size=batch_size,
        shuffle=True,
        prefetch_factor=2,
        num_workers=3,
    )
    val_loader = (
        DataLoader(local_val_data, batch_size=batch_size, shuffle=False, num_workers=3) if local_val_data else None
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # No metrics are tracked for single-label models
    if num_labels > 1:
        metrics = [
            F1Score(task="multilabel", average="micro", num_labels=num_labels),
            Precision(task="multilabel", num_labels=num_labels),
            Recall(task="multilabel", num_labels=num_labels),
        ]
    else:
        metrics = []

    train(
        epochs=epoch_num,
        model=model,
        loss_fn=nn.BCEWithLogitsLoss(),  # FocalLossWithLogits(),
        optimizer=optimizer,
        device=device,
        train_loader=train_loader,
        val_loader=val_loader,
        val_metrics_freq=1,
        metrics=metrics,
        log_dir="./logs/Medley",
        quiet=False,
    )
    torch.save(model, f"multi_{cat_name}.pt")

## Inference


In [None]:
# Reload the trained models. The files hold whole pickled modules (saved with
# `torch.save(model, ...)`), so `weights_only=False` is required on PyTorch versions where
# `torch.load` defaults to weights-only loading; `map_location` lets the models be loaded
# on a machine whose devices differ from the ones used at training time.
# NOTE: unpickling can execute arbitrary code — only load files you produced yourself.
models = {
    cat: torch.load(f"multi_{cat}.pt", map_location=device, weights_only=False)
    for cat in train_datasets
}

Build the giant hierarchical model by combining smaller models


In [None]:
class SuperModel(nn.Module):
    """Combine the "global" model and the per-group models into a single flat predictor.

    `forward` returns one row per input sample with `total_length` columns:
      * column 0: the initial silence instrument, copied from the first group column;
      * columns 1 .. instr_len + 1: instrument probabilities, written by the per-group
        models at the positions listed in `cat2idx` (shifted by one because of column 0);
      * columns instr_len + 2 ..: sigmoid of the outputs of the "global" model.

    Args:
        models: mapping from name to module. It must contain a "global" entry; the other
            entries must be ordered like `cat2idx`. The mapping itself is not modified.
        cat2idx: for each group, the list of instrument indices (0-based, not counting the
            silence instrument) predicted by the model of that group.
    """

    def __init__(self, models, cat2idx):
        super().__init__()
        self.models = dict(models)
        self.global_model = self.models.pop("global")
        # A plain dict does not register its values as sub-modules, which would make
        # `.to()`, `.eval()` and `.parameters()` silently skip the per-group models.
        # Mirror them in a ModuleList so that they are tracked like `global_model`.
        self._group_models = nn.ModuleList(list(self.models.values()))
        self.cat2idx = cat2idx
        # Largest instrument index found in any group
        self.instr_len = max(max(x) for x in self.cat2idx)
        # (instr_len + 1) instruments + len(cat2idx) groups + 2 silence columns
        self.total_length = self.instr_len + 1 + len(self.cat2idx) + 2

    def forward(self, x):
        res = torch.zeros(x.shape[0], self.total_length, device=x.device)

        # +1 for initial silence instrument, +1 to start the step after the last instrument
        group_start = self.instr_len + 2
        res[:, group_start:] = torch.sigmoid(self.global_model(x))

        for i, model in enumerate(self.models.values()):
            y = model(x)
            res[:, torch.tensor(self.cat2idx[i]) + 1] = torch.sigmoid(y)  # * global_res[:, i + 1, None]

        # The silence instrument mirrors the first group column. Index `instr_len + 1`
        # is the last instrument column, not a group column, so it must not be used here.
        res[:, 0] = res[:, group_start]
        return res

In [None]:
# Combine the reloaded models into one module, using the group -> instrument index mapping.
model = SuperModel(models, cat2instridx)

## Analyze result


In [None]:
# Loader over the held-out test dataset, iterated in a fixed order.
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=8)

In [None]:
# Column positions inside the flat label vector: the instrument columns come first,
# followed by the instrument-group columns.
group_idx = list(range(len(data.class_names), len(data.class_names) + len(data.aggregated_class_names)))
instrument_idx = list(range(len(data.class_names)))

# One list of micro-averaged metrics per evaluation level. The order of the list is the
# order in which the results are printed below.
metric_values = {
    k: [
        F1Score(task="multilabel", average="micro", num_labels=v),
        Precision(task="multilabel", average="micro", num_labels=v),
        Recall(task="multilabel", average="micro", num_labels=v),
        Accuracy(task="multilabel", average="micro", num_labels=v),
        # ExactMatch does not define an `average` argument, so none is passed
        ExactMatch(task="multilabel", num_labels=v),
    ]
    for k, v in (
        {
            "flat": len(data.class_names) + len(data.aggregated_class_names),
            "groups": len(group_idx),
            "instruments": len(instrument_idx),
        }
    ).items()
}

results = compute_metrics(
    model,
    test_loader,
    device,
    metrics=metric_values,
    groups_idx=group_idx,
    instruments_idx=instrument_idx,
    show_progress=True,
)

# Print every metric of every level, labelled with the class name of the metric
for level in results:
    print(level)
    for metric, res in zip(metric_values[level], results[level]):
        print(f"{metric.__class__.__name__}: {res.cpu().item()}")
    print()

## Plotting


In [None]:
# Column positions inside the flat label vector: instruments first, then groups.
n_instruments = len(data.class_names)
n_groups = len(data.aggregated_class_names)
instrument_idx = list(range(n_instruments))
group_idx = list(range(n_instruments, n_instruments + n_groups))

labels_per_level = {
    "flat": n_instruments + n_groups,
    "groups": len(group_idx),
    "instruments": len(instrument_idx),
}

# Per-label metrics first (average=None), then StatScores, then the three micro-averaged
# metrics. Later cells index into this list by position, so the order must not change.
metric_values = {
    level_name: [
        Accuracy(task="multilabel", average=None, num_labels=label_count),
        Precision(task="multilabel", average=None, num_labels=label_count),
        Recall(task="multilabel", average=None, num_labels=label_count),
        Specificity(task="multilabel", average=None, num_labels=label_count),
        F1Score(task="multilabel", average=None, num_labels=label_count),
        StatScores(task="multilabel", average=None, num_labels=label_count),
        Precision(task="multilabel", average="micro", num_labels=label_count),
        Recall(task="multilabel", average="micro", num_labels=label_count),
        F1Score(task="multilabel", average="micro", num_labels=label_count),
    ]
    for level_name, label_count in labels_per_level.items()
}

results = compute_metrics(
    model,
    test_loader,
    device,
    metrics=metric_values,
    groups_idx=group_idx,
    instruments_idx=instrument_idx,
)

# Only the last three (micro-averaged) metrics are single values that can be printed
for level in results:
    print(level)
    micro_metrics = metric_values[level][-3:]
    micro_results = results[level][-3:]
    for metric, res in zip(micro_metrics, micro_results):
        print(f"{type(metric).__name__}: {res.cpu().item()}")
    print()

In [None]:
# Fourth entry from the end of the "instruments" results, i.e. the StatScores metric
conf_mat = results["instruments"][-4].cpu()
stat_columns = ["TP", "FP", "TN", "FN", "Support"]
df_conf = pd.DataFrame(conf_mat, index=np.array(instruments), columns=stat_columns)
# df_conf = df_conf.drop(columns="Support")
# df_conf["Precision"] = results["instruments"][1].cpu()[in_train]
# df_conf["Recall"] = results["instruments"][2].cpu()[in_train]

# False positive rate: FP / (FP + TN)
negatives = df_conf["FP"] + df_conf["TN"]
df_conf["FP rate"] = df_conf["FP"] / negatives

In [None]:
from musedetect.data.medleydb import hornbostel_sachs

# Count, for every instrument, the number of training tracks without bleed it appears in.
# The tracks are listed from MDB_WAV_PATH (the same directory that was preprocessed above)
# rather than from a second copy of the literal path; names starting with "." or "_" are skipped.
list_tracks = [file.stem for file in Path(MDB_WAV_PATH).glob("[!._]*")]
dataset = [mdb.MultiTrack(track_name) for track_name in list_tracks]
new_dataset = [x for x in dataset if x.has_bleed is False]
instrument_music_counts = {k: 0 for k in data.class_names}
for track in new_dataset:
    for instrument in track.instruments:
        instrument_music_counts[instrument] += 1

In [None]:
# Sum the instrument columns of the labels over every batch, for the training split and
# for the test split, to get the number of positive frames per instrument.
n_label_columns = len(data.class_names)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False, num_workers=8)
instrument_frame_counts = torch.zeros((n_label_columns,))
for _, y in train_loader:
    instrument_frame_counts += y[:, :n_label_columns].sum(0)

test_instrument_frame_counts = torch.zeros((n_label_columns,))
for _, y in test_loader:
    test_instrument_frame_counts += y[:, :n_label_columns].sum(0)

In [None]:
# One row per instrument: the five per-label metrics (first five entries of the
# "instruments" results), the track/frame counts and the instrument group.
instrument_metric_names = ["accuracy", "precision", "recall", "specificity", "f1"]
instrument_metric_rows = np.array([x.cpu().numpy() for x in results["instruments"][:5]])

df = pd.DataFrame(
    dict(zip(instrument_metric_names, instrument_metric_rows)),
    index=np.array(["silence"] + data.class_names[1:]),
)
df["music_count"] = instrument_music_counts.values()
df["frame_count"] = instrument_frame_counts
df["test_frame_count"] = test_instrument_frame_counts
df["instrument"] = df.index
# "silence" gets its own group label; every other instrument is mapped with hornbostel_sachs
df["group"] = df.instrument.apply(lambda name: "Silence" if name == "silence" else hornbostel_sachs(name))
df = df.sort_values("frame_count", ascending=False)

In [None]:
import matplotlib as mpl
from matplotlib.lines import Line2D
from matplotlib.patches import Patch

# Per-instrument figure: train/test frame counts as bars, precision and recall as dots.
# NOTE(review): this cell relies on pyplot's implicit "current axes"; the statement order
# matters, so do not reorder the calls without checking which axes each `plt.*` call hits.

# Legend entries built by hand, since the seaborn calls below are made with legend=False
legend_elements = [
    Patch(facecolor=sns.color_palette()[0], label="Train", alpha=0.5),
    Patch(facecolor=sns.color_palette()[1], label="Test", alpha=0.5),
    Line2D([0], [0], marker="o", color="w", label="Precision", linewidth=0, markerfacecolor=sns.color_palette()[2]),
    Line2D([0], [0], marker="o", color="w", label="Recall", linewidth=0, markerfacecolor=sns.color_palette()[3]),
]


sns.set_theme("paper")
sns.set_context("paper")
mpl.rcParams["figure.facecolor"] = "white"
mpl.rcParams["axes.facecolor"] = "white"
mpl.rcParams["grid.color"] = "black"

# NOTE(review): `fit` looks like a typo for `fig`; the figure handle is not used below.
fit, ax = plt.subplots(figsize=(6, 15))
plt.xlim((-0.01, 1.0))
plt.ylabel("Instruments")
plt.xlabel("Score")
ax.grid(axis="y", visible=False)

# First twin axes: holds the frame-count bars (its x label is set right after creation)
ax_bar = ax.twiny()
plt.xlabel("Number of samples")

# Second twin axes: holds the precision/recall dots
ax_scatter = ax.twiny()


# Score axis labelled at the top, sample-count axis labelled at the bottom
ax.xaxis.set_label_position("top")
ax.xaxis.tick_top()

ax_bar.xaxis.set_label_position("bottom")
ax_bar.xaxis.tick_bottom()
ax_bar.grid(False)

# The scatter axes shows no x ticks or x tick labels of its own
ax_scatter.tick_params(axis="x", top=False, bottom=False, labeltop=False, labelbottom=False)
ax_scatter.grid(False)

# Plot frame count for training and test data
sns.barplot(
    data=pd.DataFrame(
        {
            "instrument": df.instrument.tolist() * 2,
            "frame_count": df.frame_count.tolist() + df.test_frame_count.tolist(),
            "Data split": ["train"] * len(df.frame_count) + ["test"] * len(df.test_frame_count),
        }
    ),
    y="instrument",
    x="frame_count",
    hue="Data split",
    alpha=0.5,
    dodge=False,
    legend=False,
    ax=ax_bar,
)


# Reuse the y limits of the bar axes (applied to the current axes, presumably ax_scatter)
ylim = ax_bar.get_ylim()
plt.ylim(ylim)

# Remember the x limits of the score axes; they are re-applied after the scatter plot
xlim = ax.get_xlim()
df_plot = df[["instrument", "precision", "recall"]].melt("instrument", var_name="Metric", value_name="vals")
# For each row i of df, draw two thin horizontal segments at y = i:
#   * from x = -1 to the smaller of precision/recall, in the colour of that smaller metric;
#   * from the smaller to the larger value, in the colour of the larger metric.
# This assumes row i of df is drawn at y position i — TODO confirm.
for i in range(len(df.precision)):
    colors = {"precision": sns.color_palette()[2], "recall": sns.color_palette()[3]}
    if df.precision.iloc[i] > df.recall.iloc[i]:
        col1, col2 = ("precision", "recall")
    else:
        col1, col2 = ("recall", "precision")
    plt.plot([-1, df[col2].iloc[i]], [i, i], color=colors[col2], linestyle="-", linewidth=0.5, zorder=0)
    plt.plot(
        [df[col2].iloc[i], df[col1].iloc[i]],
        [i, i],
        color=colors[col1],
        linestyle="-",
        linewidth=0.5,
        zorder=0,
    )
# Precision and recall dots, coloured with the same palette entries as the segments
sns.scatterplot(
    df_plot,
    y="instrument",
    x="vals",
    hue="Metric",
    ax=ax_scatter,
    linewidth=1,
    marker="o",
    palette=sns.color_palette()[2:],
    legend=False,
)
plt.xlim(xlim)


plt.legend(title="Category", handles=legend_elements, loc=(0.7, 0.01))
plt.show();

In [None]:
# One row per instrument group: the five per-label metrics (first five entries of the
# "groups" results), with the group name as a column, ordered by group name.
group_metric_names = ["accuracy", "precision", "recall", "specificity", "f1"]
group_metric_rows = np.array([x.cpu().numpy() for x in results["groups"][:5]])

df_group = pd.DataFrame(
    dict(zip(group_metric_names, group_metric_rows)),
    index=np.array(["Silence"] + data.aggregated_class_names[1:]),
)
df_group = df_group.reset_index(names="group")
df_group = df_group.sort_values("group").reset_index(drop=True)

# Frame counts summed per group from the per-instrument table; both tables are ordered
# by group name and re-indexed from 0 so that the assignment lines up row by row.
group_frame_counts = df.groupby("group")[["frame_count", "test_frame_count"]].sum()
group_frame_counts = group_frame_counts.sort_values("group").reset_index(drop=True)
df_group[["frame_count", "test_frame_count"]] = group_frame_counts

df_group = df_group.sort_values("frame_count", ascending=False)

In [None]:
from matplotlib.lines import Line2D
from matplotlib.patches import Patch

# Per-group figure: train/test frame counts as bars, precision and recall as dots.
# NOTE(review): `mpl` is not imported in this cell; it comes from the
# `import matplotlib as mpl` of the per-instrument plotting cell, which must run first.
# NOTE(review): this cell relies on pyplot's implicit "current axes"; the statement order
# matters, so do not reorder the calls without checking which axes each `plt.*` call hits.

# Legend entries built by hand, since the seaborn calls below are made with legend=False
legend_elements = [
    Patch(facecolor=sns.color_palette()[0], label="Train", alpha=0.5),
    Patch(facecolor=sns.color_palette()[1], label="Test", alpha=0.5),
    Line2D([0], [0], marker="o", color="w", label="Precision", linewidth=0, markerfacecolor=sns.color_palette()[2]),
    Line2D([0], [0], marker="o", color="w", label="Recall", linewidth=0, markerfacecolor=sns.color_palette()[3]),
]


sns.set_theme("paper")
sns.set_context("paper")
mpl.rcParams["figure.facecolor"] = "white"
mpl.rcParams["axes.facecolor"] = "white"
mpl.rcParams["grid.color"] = "f1f1f5"


# NOTE(review): `fit` looks like a typo for `fig`; the figure handle is not used below.
fit, ax = plt.subplots(figsize=(8, 5))
plt.ylabel("Groups")
plt.xlabel("Score")
ax.grid(axis="y", visible=False)

# First twin axes: holds the frame-count bars (its x label is set right after creation)
ax_bar = ax.twiny()
plt.xlabel("Number of samples")

# Second twin axes: holds the precision/recall dots
ax_scatter = ax.twiny()


# Score axis labelled at the top, sample-count axis labelled at the bottom
ax.xaxis.set_label_position("top")
ax.xaxis.tick_top()

ax_bar.xaxis.set_label_position("bottom")
ax_bar.xaxis.tick_bottom()
ax_bar.grid(False)

# The scatter axes shows no x ticks or x tick labels of its own
ax_scatter.tick_params(axis="x", top=False, bottom=False, labeltop=False, labelbottom=False)
ax_scatter.grid(False)

# Plot frame count for training and test data, one bar pair per group
sns.barplot(
    data=pd.DataFrame(
        {
            "group": df_group.group.tolist() * 2,
            "frame_count": df_group.frame_count.tolist() + df_group.test_frame_count.tolist(),
            "Data split": ["train"] * len(df_group.frame_count) + ["test"] * len(df_group.test_frame_count),
        }
    ),
    y="group",
    x="frame_count",
    hue="Data split",
    alpha=0.5,
    dodge=False,
    legend=False,
    ax=ax_bar,
)


# Reuse the y limits of the bar axes (applied to the current axes, presumably ax_scatter)
ylim = ax_bar.get_ylim()
plt.ylim(ylim)

# Remember the x limits of the score axes; they are re-applied after the scatter plot
xlim = ax.get_xlim()
df_group_plot = df_group[["group", "precision", "recall"]].melt("group", var_name="Metric", value_name="vals")
# For each row i of df_group, draw two thin horizontal segments at y = i:
#   * from x = -1 to the smaller of precision/recall, in the colour of that smaller metric;
#   * from the smaller to the larger value, in the colour of the larger metric.
# This assumes row i of df_group is drawn at y position i — TODO confirm.
for i in range(len(df_group.precision)):
    colors = {"precision": sns.color_palette()[2], "recall": sns.color_palette()[3]}
    if df_group.precision.iloc[i] > df_group.recall.iloc[i]:
        col1, col2 = ("precision", "recall")
    else:
        col1, col2 = ("recall", "precision")
    plt.plot([-1, df_group[col2].iloc[i]], [i, i], color=colors[col2], linestyle="-", linewidth=0.5, zorder=0)
    plt.plot(
        [df_group[col2].iloc[i], df_group[col1].iloc[i]],
        [i, i],
        color=colors[col1],
        linestyle="-",
        linewidth=0.5,
        zorder=0,
    )
# Precision and recall dots, coloured with the same palette entries as the segments
sns.scatterplot(
    df_group_plot,
    y="group",
    x="vals",
    hue="Metric",
    ax=ax_scatter,
    linewidth=1,
    marker="o",
    palette=sns.color_palette()[2:],
    legend=False,
)
plt.xlim(xlim)


plt.legend(title="Category", handles=legend_elements, loc=(0.8, 0.27))
plt.show();