In [1]:
import pandas as pd
%load_ext autoreload
%autoreload 2

from collections import Counter
import mirdata
import librosa
import stm
import numpy as np
import matplotlib.pyplot as plt
import IPython.display as ipd
from rhythmic_features import feature as f
from tqdm import tqdm
from pathlib import Path
from sklearn.preprocessing import LabelEncoder
from sklearn.neighbors import KNeighborsClassifier, RadiusNeighborsClassifier
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score, LeaveOneOut
from sklearn.metrics import accuracy_score, ConfusionMatrixDisplay

## Groove Midi

_Jon Gillick, Adam Roberts, Jesse Engel, Douglas Eck, and David Bamman.
"Learning to Groove with Inverse Sequence Transformations."
International Conference on Machine Learning (ICML), 2019._

The Groove MIDI Dataset (GMD) is composed of 13.6 hours of aligned MIDI and (synthesized) audio of human-performed, tempo-aligned expressive drumming. The dataset contains 1,150 MIDI files and over 22,000 measures of drumming.

It could be used to classify fills or beats. Keep in mind that while fills tend to have a short duration (few seconds), beats tend to be longer. Therefore beats should be segmented in shorter chunks?


track.style -> a string style for the performance formatted as “primary/secondary” (e.g. rock/halftime, funk/purdieshuffle). The primary style comes from the Genre List below.

Genre List: afrobeat, afrocuban, blues, country, dance, funk, gospel, highlife, hiphop, jazz, latin, middleeastern, neworleans, pop, punk, reggae, rock, soul

For the following experiment the label will consist of the primary style only.

In [2]:
metadata = pd.read_csv("~/mir_datasets/groove_midi/info.csv")
metadata

Unnamed: 0,drummer,session,id,style,bpm,beat_type,time_signature,midi_filename,audio_filename,duration,split
0,drummer1,drummer1/eval_session,drummer1/eval_session/1,funk/groove1,138,beat,4-4,drummer1/eval_session/1_funk-groove1_138_beat_...,drummer1/eval_session/1_funk-groove1_138_beat_...,27.872308,test
1,drummer1,drummer1/eval_session,drummer1/eval_session/10,soul/groove10,102,beat,4-4,drummer1/eval_session/10_soul-groove10_102_bea...,drummer1/eval_session/10_soul-groove10_102_bea...,37.691158,test
2,drummer1,drummer1/eval_session,drummer1/eval_session/2,funk/groove2,105,beat,4-4,drummer1/eval_session/2_funk-groove2_105_beat_...,drummer1/eval_session/2_funk-groove2_105_beat_...,36.351218,test
3,drummer1,drummer1/eval_session,drummer1/eval_session/3,soul/groove3,86,beat,4-4,drummer1/eval_session/3_soul-groove3_86_beat_4...,drummer1/eval_session/3_soul-groove3_86_beat_4...,44.716543,test
4,drummer1,drummer1/eval_session,drummer1/eval_session/4,soul/groove4,80,beat,4-4,drummer1/eval_session/4_soul-groove4_80_beat_4...,drummer1/eval_session/4_soul-groove4_80_beat_4...,47.987500,test
...,...,...,...,...,...,...,...,...,...,...,...
1145,drummer2,drummer2/session2,drummer2/session2/11,rock,130,beat,4-4,drummer2/session2/11_rock_130_beat_4-4.mid,,1.909613,train
1146,drummer2,drummer2/session2,drummer2/session2/12,rock,130,beat,4-4,drummer2/session2/12_rock_130_beat_4-4.mid,,1.808652,train
1147,drummer2,drummer2/session2,drummer2/session2/13,rock,130,beat,4-4,drummer2/session2/13_rock_130_beat_4-4.mid,,1.864421,train
1148,drummer2,drummer2/session2,drummer2/session2/14,rock,130,beat,4-4,drummer2/session2/14_rock_130_beat_4-4.mid,,1.875960,train


In [3]:
sum(metadata["audio_filename"].isna()) 

60

In [4]:
groove_dataset = mirdata.initialize("groove_midi")
# groove_dataset.download()
groove_dataset.validate()

100%|██████████| 1/1 [00:00<00:00, 862.67it/s]
100%|██████████| 1150/1150 [00:24<00:00, 46.80it/s] 
INFO: Success: the dataset is complete and all files are valid.
INFO: --------------------


({'metadata': {}, 'tracks': {}}, {'metadata': {}, 'tracks': {}})

In [7]:
features = []
labels = []
for _, track in tqdm(groove_dataset.load_tracks().items()):
    if track.beat_type == "fill":
        try:
            y, sr = librosa.load(track.audio_path, sr=None)
            stm_mean = stm.compute_stm(y=y, sr=sr, target_sr=8000, auto_cor_window_seconds=1, with_padding=True)
            features.append(stm_mean[:100])
            labels.append(track.style.split("/")[0])
        except Exception as e:
            # exception encoutered with invalid audio_path
            print("Error:", e)
            continue

c = Counter(labels)
print(c)
encoded_labels = LabelEncoder().fit_transform(labels)
print(len(features))

  0%|          | 0/1150 [00:00<?, ?it/s]

  segment_correlation = segment_correlation / segment_correlation[0]
  2%|▏         | 19/1150 [00:00<00:06, 169.62it/s]

Error: y must be finite everywhere
Error: y must be finite everywhere
Error: y must be finite everywhere
Error: y must be finite everywhere
Error: y must be finite everywhere
Error: y must be finite everywhere


  6%|▌         | 67/1150 [00:00<00:15, 69.54it/s] 

Error: y must be finite everywhere


 10%|█         | 118/1150 [00:01<00:13, 75.27it/s]

Error: y must be finite everywhere


 13%|█▎        | 155/1150 [00:01<00:12, 80.75it/s]

Error: y must be finite everywhere


 16%|█▌        | 182/1150 [00:02<00:11, 83.13it/s]

Error: y must be finite everywhere
Error: y must be finite everywhere
Error: y must be finite everywhere


 18%|█▊        | 212/1150 [00:02<00:10, 92.67it/s]

Error: y must be finite everywhere


 24%|██▍       | 278/1150 [00:03<00:09, 94.02it/s] 

Error: y must be finite everywhere
Error: y must be finite everywhere


 26%|██▌       | 300/1150 [00:03<00:09, 85.93it/s]


KeyboardInterrupt: 

In [6]:
x_train, x_test, y_train, y_test = train_test_split(
    np.array(features), encoded_labels, test_size=0.3, stratify=encoded_labels, random_state=42
)

knn = KNeighborsClassifier(n_neighbors=7, metric="cosine")
# rnn = RadiusNeighborsClassifier(n_neighbors=7, metric="cosine").fit(X=X_train, y=y_train)

k_fold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
cv_scores = cross_val_score(knn, features, encoded_labels, cv=k_fold, scoring="accuracy")

print("Cross-Validation Scores:", cv_scores)
print("Mean CV Accuracy:", cv_scores.mean())

knn = KNeighborsClassifier(n_neighbors=7, metric="cosine").fit(X=x_train, y=y_train)
ConfusionMatrixDisplay.from_estimator(knn, x_test, y_test, display_labels=set(labels))
plt.xticks(rotation=90)
plt.show()

Traceback (most recent call last):
  File "/home/stef/uni/internship/rythmic-pattern-analysis/.venv-rythm/lib/python3.11/site-packages/sklearn/metrics/_scorer.py", line 137, in __call__
    score = scorer._score(
            ^^^^^^^^^^^^^^
  File "/home/stef/uni/internship/rythmic-pattern-analysis/.venv-rythm/lib/python3.11/site-packages/sklearn/metrics/_scorer.py", line 345, in _score
    y_pred = method_caller(
             ^^^^^^^^^^^^^^
  File "/home/stef/uni/internship/rythmic-pattern-analysis/.venv-rythm/lib/python3.11/site-packages/sklearn/metrics/_scorer.py", line 87, in _cached_call
    result, _ = _get_response_values(
                ^^^^^^^^^^^^^^^^^^^^^
  File "/home/stef/uni/internship/rythmic-pattern-analysis/.venv-rythm/lib/python3.11/site-packages/sklearn/utils/_response.py", line 210, in _get_response_values
    y_pred = prediction_method(X)
             ^^^^^^^^^^^^^^^^^^^^
  File "/home/stef/uni/internship/rythmic-pattern-analysis/.venv-rythm/lib/python3.11/site-pac

Cross-Validation Scores: [nan nan nan nan nan]
Mean CV Accuracy: nan


ValueError: Input X contains NaN.
KNeighborsClassifier does not accept missing values encoded as NaN natively. For supervised learning, you might want to consider sklearn.ensemble.HistGradientBoostingClassifier and Regressor which accept missing values encoded as NaNs natively. Alternatively, it is possible to preprocess the data, for instance by using an imputer transformer in a pipeline or drop samples with missing values. See https://scikit-learn.org/stable/modules/impute.html You can find a list of all estimators that handle NaN values at the following page: https://scikit-learn.org/stable/modules/impute.html#estimators-that-handle-nan-values