# Clustering beats and fills in the Groove MIDI Datasets
*https://magenta.tensorflow.org/datasets/groove* (GMD) and *https://magenta.tensorflow.org/datasets/e-gmd* (E-GMD)

In [None]:
import sys
import os
# Make the parent of the current working directory importable, so the
# project-local `scripts.*` modules imported below can be found.
parent_dir = os.path.abspath(os.path.join('..'))
if parent_dir not in sys.path:  # avoid adding duplicate entries on cell re-runs
    sys.path.append(parent_dir)
    
# IPython magics: automatically reload edited modules (such as scripts.*)
# before executing code, so library changes are picked up without a restart.
%load_ext autoreload
%autoreload 2

import mirdata
import librosa
import itertools
import umap
import umap.plot
import numpy as np
import pandas as pd
from scripts.data_loaders import load_malian_jembe_dataset, load_candombe_dataset, load_cretan_dances_dataset, load_ballroom_dataset
from sklearn.preprocessing import LabelEncoder
from scripts.scale_transform_magnitude import compute_stm, compute_stm_multi_channel
from scripts.clusterers import select_best_num_clusters
from pathlib import Path
from tqdm import tqdm

## Groove MIDI Dataset (GMD): the smaller dataset, used for trying out new configurations

In [None]:
# Load the Groove MIDI metadata and keep 4/4 "beat" performances lasting at
# least 30 s that belong to the six most frequent genres.
metadata = pd.read_csv("~/mir_datasets/groove_midi/info.csv")
metadata_beats = metadata[metadata["beat_type"] == "beat"]  # select only beats

metadata_beats = metadata_beats[metadata_beats["duration"] >= 30]
# Exact match: a string `>=` comparison is lexicographic and would also keep
# e.g. "5-4" and "6-8".
metadata_beats = metadata_beats[metadata_beats["time_signature"] == "4-4"]

print("duration summary in seconds: \n", metadata_beats["duration"].describe())

# genre = part of `style` before the first "/". `assign` returns a new frame,
# avoiding pandas' SettingWithCopy warning on the filtered slice.
metadata_beats = metadata_beats.assign(
    genre=metadata_beats["style"].str.split("/").str[0]
)

metadata_beats = metadata_beats[["genre", "style", "audio_filename", "duration"]]  # get rid of unnecessary columns

print(metadata_beats.genre.unique())

# Counting occurrences of each genre and keeping the top n most frequent ones
genre_counts = metadata_beats["genre"].value_counts()
top_n_genres = genre_counts.head(6).index
metadata_beats = metadata_beats[metadata_beats["genre"].isin(top_n_genres)]

metadata_beats = metadata_beats.dropna()
# resetting the index (0..n-1) is needed for the interactive plot
metadata_beats = metadata_beats.reset_index(drop=True)

metadata_beats.genre.value_counts()


In [None]:
# Compute one STM feature vector per clip. Rows whose audio is missing or fails
# to load are dropped from `metadata_beats` as well, so that `features[i]`
# always describes `metadata_beats.iloc[i]` (the UMAP cell relies on this for
# labels and hover data).
groove_midi_path = Path("~/mir_datasets/groove_midi/").expanduser()
features = []
kept_rows = []  # positions in metadata_beats of clips with a computed feature
for pos, audio_filename in enumerate(
    tqdm(metadata_beats["audio_filename"], total=metadata_beats.shape[0])
):
    # pandas stores missing filenames as NaN, which `== None` never matches
    if pd.isna(audio_filename):
        continue
    try:
        y, sr = librosa.load(groove_midi_path / audio_filename, sr=None, duration=30)
        stm = compute_stm(y=y, sr=sr)
        # stm = np.mean(compute_stm_multi_channel(y=y, sr=sr, channels = [0,5,40], num_stm_coefs=200), axis=0)
    except Exception as e:
        print(f"Error: {e}")
        continue
    features.append(stm)
    kept_rows.append(pos)

# keep metadata aligned with `features` and re-index 0..n-1 for the interactive plot
metadata_beats = metadata_beats.iloc[kept_rows].reset_index(drop=True)

### Clustering Analysis

In [None]:
# Try k = 3, 4, 5 clusters with k-medoids on a t-SNE embedding of the features.
# The "best" k is chosen inside select_best_num_clusters (presumably by
# silhouette score, as the printout suggests; see scripts.clusterers).
num_of_clusters = list(range(3, 6))
results, optimal_k = select_best_num_clusters(
    X=np.array(features),
    n_clusters=num_of_clusters,
    cluster_method="kmedoids",
    dim_reduction="tsne",
)

print(
    f"Best number of clusters: {optimal_k}; "
    f"silhouette score: {results.get(optimal_k)}"
)

In [None]:
# Interactive UMAP projection of the beat features, coloured by genre.
# The interactive plot needs integer labels, so encode the genre strings.
labels = metadata_beats["genre"].factorize()[0]

# reduce dimensionality (cosine distance between feature vectors)
reducer = umap.UMAP(metric="cosine")
reducer.fit(features)

p = umap.plot.interactive(
    reducer,
    labels=labels,
    hover_data=metadata_beats,  # columns shown in the tooltip; can be customized
    point_size=3,
)

umap.plot.output_file("groove_midi_beats.html")  # also save the plot locally
umap.plot.output_notebook()  # and render it inline in the notebook
umap.plot.show(p)

<hr>

## Extended Groove MIDI Dataset (E-GMD)

In [27]:
# clean metadata
metadata = pd.read_csv(
    "../datasets/e-gmd-v1.0.0/e-gmd-v1.0.0.csv"
)  # read metadata of extended groove midi

# Select only fills. The variable keeps the name `metadata_beats` because the
# feature-extraction and plotting cells below read it under that name.
metadata_beats = metadata[metadata["beat_type"] == "fill"]

# keep clips lasting between 3 and 6 seconds (both ends inclusive)
metadata_beats = metadata_beats[metadata_beats["duration"].between(3, 6)]
print("duration summary in seconds: \n", metadata_beats["duration"].describe())

# genre = part of `style` before the first "/". `assign` returns a new frame,
# avoiding pandas' SettingWithCopy warning on the filtered slice.
metadata_beats = metadata_beats.assign(
    genre=metadata_beats["style"].str.split("/").str[0]
)

metadata_beats = metadata_beats[["genre", "style", "audio_filename", "duration"]]  # get rid of unnecessary columns
metadata_beats = metadata_beats.reset_index(drop=True)  # resetting the index is needed for the interactive plot

print(metadata_beats.genre.unique())
metadata_beats.genre.value_counts()

duration summary in seconds: 
 count    6621.000000
mean        4.097333
std         0.828416
min         3.000045
25%         3.428571
50%         3.772993
75%         4.999546
max         6.000000
Name: duration, dtype: float64


genre
rock          2109
hiphop        1462
neworleans     774
funk           731
reggae         550
jazz           473
soul           341
pop            129
afrocuban       43
country          9
Name: count, dtype: int64

In [28]:
# Counting occurrences of each genre
genre_counts = metadata_beats['genre'].value_counts()

# Identifying the top n most frequent genres; rarer genres are dropped
top_n_genres = genre_counts.head(6).index

metadata_beats = metadata_beats[metadata_beats['genre'].isin(top_n_genres)]
# Filtering leaves gaps in the index; the interactive plot needs a 0..n-1 index.
metadata_beats = metadata_beats.reset_index(drop=True)

metadata_beats.genre.value_counts()

genre
rock          2109
hiphop        1462
neworleans     774
funk           731
reggae         550
jazz           473
Name: count, dtype: int64

In [30]:
# Compute one multi-channel STM feature vector per clip. Rows whose audio
# fails to load or to be processed are dropped from `metadata_beats` as well,
# so that `features[i]` always describes `metadata_beats.iloc[i]` (the UMAP
# cell relies on this for labels and hover data).
groove_midi_path = Path("../datasets/e-gmd-v1.0.0")
features = []
kept_rows = []  # positions in metadata_beats of clips with a computed feature
for pos, audio_filename in enumerate(
    tqdm(metadata_beats["audio_filename"], total=metadata_beats.shape[0])
):
    try:
        # TODO: segment audio file?
        y, sr = librosa.load(groove_midi_path / audio_filename, sr=None, duration=30)
        # average the per-channel STMs over axis 0 into a single vector
        stm = np.mean(compute_stm_multi_channel(y=y, sr=sr, num_stm_coefs=100), axis=0)
    except Exception as e:
        print(f"Error: {e}")
        continue
    features.append(stm)
    kept_rows.append(pos)

# keep metadata aligned with `features` and re-index 0..n-1 for the interactive plot
metadata_beats = metadata_beats.iloc[kept_rows].reset_index(drop=True)

  "auto_cor_window_seconds is bigger than duration of audio file, setting it to duration"
  4%|▍         | 244/6099 [00:10<04:51, 20.07it/s]

### K-Medoids clustering and silhouette analysis
*https://scikit-learn.org/stable/auto_examples/cluster/plot_kmeans_silhouette_analysis.html#sphx-glr-auto-examples-cluster-plot-kmeans-silhouette-analysis-py*

The plot on the left-hand side shows the silhouette coefficient of every sample. The silhouette coefficient, which ranges from -1 to 1, measures how similar a sample is to its own cluster compared with the nearest other cluster.

A high average silhouette score indicates well-separated clusters, while a low score suggests overlapping clusters or samples assigned to the wrong cluster.

The thickness of each cluster's band in the silhouette plot reflects the size of that cluster.


In [None]:
# Try k = 3..7 clusters with k-medoids on a t-SNE embedding of the features.
# The "best" k is chosen inside select_best_num_clusters (presumably by
# silhouette score, as the printout suggests; see scripts.clusterers).
num_of_clusters = list(range(3, 8))
results, optimal_k = select_best_num_clusters(
    X=np.array(features),
    n_clusters=num_of_clusters,
    cluster_method="kmedoids",
    dim_reduction="tsne",
)

print(
    f"Best number of clusters: {optimal_k}; "
    f"silhouette score: {results.get(optimal_k)}"
)

### Visualize data with interactive UMAP plot

*https://umap-learn.readthedocs.io/en/latest/plotting.html#interactive-plotting-and-hover-tools*

In [None]:
# Interactive UMAP projection of the E-GMD fill features, coloured by genre.
# Labels and hover data come from metadata_beats, so it must have exactly one
# row per feature vector.
if len(features) != len(metadata_beats):
    raise ValueError(
        f"{len(features)} feature vectors but {len(metadata_beats)} metadata rows; "
        "re-run the feature extraction cell"
    )

labels = pd.factorize(metadata_beats["genre"])[0]  # integer labels needed for the interactive plot
reducer = umap.UMAP(metric="cosine").fit(features)  # reduce dimensionality

p = umap.plot.interactive(
    reducer,
    labels=labels,
    # the interactive plot needs a 0..n-1 index on hover_data
    hover_data=metadata_beats.reset_index(drop=True),
    point_size=3,
)  # interactive plot, hover_data can be customized

# Save under its own name: these are fills, and reusing "groove_midi_beats.html"
# would overwrite the Groove MIDI beats plot.
umap.plot.output_file("e_gmd_fills.html")
umap.plot.output_notebook()  # display inline in notebook
umap.plot.show(p)

<hr>

# Clustering on Candombe, Malian Jembe, Cretan Dances and Ballroom (Cuban Salsa not yet loaded)

In [None]:
# Parameters for the scale transform magnitude (STM), forwarded unchanged to
# every dataset loader below.
stm_params = dict(
    mel_flag=True,
    with_padding=True,
    n_mels=50,
    autocor_window_type="hamming",
    num_stm_coefs=150,
)

In [None]:
# Malian jembe: features, labels and a per-track table later used as hover data.
(
    features_mj,
    labels_mj,
    hover_data_mj,
) = load_malian_jembe_dataset(stm_params=stm_params)
hover_data_mj  # display the table

In [None]:
# Candombe: features, labels and a per-track table later used as hover data.
(
    features_candombe,
    labels_candombe,
    hover_data_candombe,
) = load_candombe_dataset(stm_params=stm_params)
hover_data_candombe  # display the table

In [None]:
# Cretan dances: features, labels and a per-track table later used as hover data.
(
    features_cretan,
    labels_cretan,
    hover_data_cretan,
) = load_cretan_dances_dataset(stm_params=stm_params)
hover_data_cretan  # display the table

In [None]:
# Ballroom: features, labels and a per-track table later used as hover data.
(
    features_ballroom,
    labels_ballroom,
    hover_data_ballroom,
) = load_ballroom_dataset(stm_params=stm_params)
hover_data_ballroom  # display the table

In [None]:
# Concatenate all datasets in a fixed order (jembe, candombe, cretan, ballroom)
# so that features, labels and hover rows stay aligned position by position.
combined_features = [
    feat
    for dataset_features in (features_mj, features_candombe, features_cretan, features_ballroom)
    for feat in dataset_features
]
combined_labels = [
    label
    for hover in (hover_data_mj, hover_data_candombe, hover_data_cretan, hover_data_ballroom)
    for label in hover["label"]
]
# fresh 0..n-1 index over the concatenated tables
combined_hover_data = pd.concat(
    (hover_data_mj, hover_data_candombe, hover_data_cretan, hover_data_ballroom)
).reset_index(drop=True)

### K-Medoids clustering and silhouette analysis


In [None]:
# Try k = 3 and k = 4 clusters with k-medoids on a t-SNE embedding of the
# combined features. The "best" k is chosen inside select_best_num_clusters
# (presumably by silhouette score, as the printout suggests).
num_of_clusters = list(range(3, 5))
results, optimal_k = select_best_num_clusters(
    X=np.array(combined_features),
    n_clusters=num_of_clusters,
    cluster_method="kmedoids",
    dim_reduction="tsne",
)

print(
    f"Best number of clusters: {optimal_k}; "
    f"silhouette score: {results.get(optimal_k)}"
)

### Visualize data with interactive UMAP plot

In [None]:
# The interactive plot needs integer labels, so encode the string labels.
encoded_labels = LabelEncoder().fit_transform(combined_labels)

# https://umap-learn.readthedocs.io/en/latest/parameters.html#basic-umap-parameters
reducer = umap.UMAP(metric="cosine")
reducer.fit(combined_features)

p = umap.plot.interactive(
    reducer,
    labels=encoded_labels,
    hover_data=combined_hover_data,  # columns shown in the tooltip; can be customized
    point_size=5,
)

# umap.plot.output_file("mj.html") # save the plot locally
umap.plot.output_notebook()  # render inline in the notebook
umap.plot.show(p)