In [None]:
import fcwt
import librosa
import matplotlib.pyplot as plt
import numpy as np

from debug import index

In [None]:
# Read the snare recording; sr=None asks librosa not to resample the file.
audio_path = 'data/snare.wav'
y, sr = librosa.load(audio_path, sr=None)

# Report the shape and sampling rate of the full recording before trimming.
print(y.shape, sr)

# Keep only the first second (sr samples) for the analysis below.
n_keep = int(sr * 1)
y = y[:n_keep]


In [None]:
# Waveform of the trimmed signal; the x-axis is the sample index.
fig, ax = plt.subplots()
ax.plot(y)
ax.set(
    xlabel='Time (samples)',
    ylabel='Amplitude',
    title='Audio Signal',
)
plt.show()

In [None]:
# Inline audio player for the trimmed excerpt (rendered as the cell output).
import IPython.display
IPython.display.Audio(data=y, rate=sr)

In [None]:
# Continuous wavelet transform of the excerpt, computed with fcwt.
fn = 500  # number of frequency bins requested from fcwt
freqs, cwt = fcwt.cwt(
    y,
    sr,
    20,     # lower bound of the analysed range handed to fcwt (presumably Hz)
    20000,  # upper bound of the analysed range handed to fcwt (presumably Hz)
    fn,
    scaling='log',
    fast=True,
    norm=True,
)
print(cwt.shape)

# From here on `cwt` holds the squared magnitude of the coefficients.
cwt = np.abs(cwt) ** 2

In [None]:
# Visualize the CWT as an image: rows follow `freqs`, columns are samples.
use_db_scale = True
cwt_dB = librosa.power_to_db(cwt, ref=np.max)
signal = cwt_dB if use_db_scale else cwt

fig, ax = plt.subplots()
img = ax.imshow(signal, aspect='auto', cmap='inferno')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Frequency (Hz)')
ax.set_title('CWT')

# imshow places one column per sample, so the default x ticks would be sample
# indices. Relabel them in seconds so the ticks agree with the axis label.
time_ticks = np.linspace(0, signal.shape[1] - 1, 6)
ax.set_xticks(time_ticks, np.round(time_ticks / sr, 2))

# Label roughly every tenth of the frequency rows with its frequency. Taking
# positions and labels from the same index array keeps both the same length
# even when fn is not a multiple of 10.
freq_ticks = np.arange(0, fn, max(fn // 10, 1))
ax.set_yticks(freq_ticks, np.round(freqs[freq_ticks]))

# The "dB" suffix only applies when the dB-scaled image is shown.
fig.colorbar(img, ax=ax, format='%+2.0f dB' if use_db_scale else None)
plt.show()

In [None]:
# Mel spectrogram of the same signal, for comparison with the CWT above.
hop_length = 256
# Same frequency range as the CWT. It is passed to BOTH the mel filterbank and
# the plot so that the y-axis labels describe the bins that were computed.
fmin, fmax = 20, 20000
S = librosa.feature.melspectrogram(
    y=y,
    sr=sr,
    n_fft=4096,
    hop_length=hop_length,
    n_mels=512,
    fmin=fmin,
    fmax=fmax,
)
fig, ax = plt.subplots()
S_dB = librosa.power_to_db(S, ref=np.max)
# Relies on `use_db_scale` set in the CWT visualisation cell.
signal = S_dB if use_db_scale else S
img = librosa.display.specshow(
    signal,
    x_axis='time',
    y_axis='mel',
    sr=sr,
    # Without this, specshow assumes its own default hop length and the time
    # axis no longer matches the frames computed above.
    hop_length=hop_length,
    fmin=fmin,
    fmax=fmax,
    ax=ax,
)
# The "dB" suffix only applies when the dB-scaled spectrogram is shown.
fig.colorbar(img, ax=ax, format='%+2.0f dB' if use_db_scale else None)
ax.set(title='Mel-frequency spectrogram')
plt.show()

# Dataset Stats

In [None]:
import pretty_midi
import numpy as np
from dataset.A2MD import get_tracks
from dataset import get_drums
from dataset.mapping import DrumMapping
import os
import polars as pl
import json

In [None]:
def get_drum_pitch_velocity(path) -> np.ndarray:
    """Collect the (pitch, velocity) pair of every drum note in a MIDI file.

    Args:
        path: MIDI file location, forwarded to ``pretty_midi.PrettyMIDI``.

    Returns:
        A ``uint8`` array with one row per note of every instrument flagged
        ``is_drum``; column 0 is the pitch, column 1 the velocity. The shape is
        always ``(n_notes, 2)``, i.e. ``(0, 2)`` when the file has no drum
        notes, so ``len(result) == 0`` still identifies the empty case.
    """
    midi = pretty_midi.PrettyMIDI(midi_file=path)
    pairs = [
        (note.pitch, note.velocity)
        for instrument in midi.instruments
        if instrument.is_drum
        for note in instrument.notes
    ]
    # np.array([]) alone would be 1-D; the reshape keeps the two-column layout
    # for the empty case too, so callers can always index columns.
    return np.array(pairs, dtype=np.uint8).reshape(-1, 2)

def get_mapped_drums(path, mapping: DrumMapping):
    """Parse the MIDI file at ``path`` and pass it to ``get_drums`` with ``mapping``.

    Args:
        path: MIDI file location, forwarded to ``pretty_midi.PrettyMIDI``.
        mapping: Drum mapping handed through to ``get_drums`` unchanged.

    Returns:
        Whatever ``get_drums`` returns for the parsed file and ``mapping``.
    """
    return get_drums(pretty_midi.PrettyMIDI(midi_file=path), mapping)

In [None]:
# For every folder key reported by get_tracks, turn its track identifiers into
# the paths of the corresponding aligned MIDI files.
tracks_per_alignment = {
    folder: [
        os.path.join("./data/a2md_public", "align_mid", folder, f"align_mid_{iden}.mid")
        for iden in lst
    ]
    for folder, lst in get_tracks("./data/a2md_public").items()
}

In [None]:
# Build one long table with a row (alignment, pitch, velocity) per drum note.
note_schema = {"alignment": pl.String, "pitch": pl.UInt8, "velocity": pl.UInt8}

# Collect one frame per file and concatenate once at the end instead of growing
# a frame inside the loop. The empty, typed seed frame keeps the concat valid
# (and the schema fixed) even when no file contributes any notes.
note_frames = [pl.DataFrame(schema=note_schema)]
for folder, paths in tracks_per_alignment.items():
    for file in paths:
        notes = get_drum_pitch_velocity(file)
        if len(notes) == 0:
            # Report files without drum notes and leave them out of the table.
            print(file)
            continue
        note_frames.append(
            pl.from_numpy(
                notes,
                schema={"pitch": pl.UInt8, "velocity": pl.UInt8},
                # Each array row is one note; state it explicitly rather than
                # relying on orientation inference from the array shape.
                orient="row",
            )
            .with_columns(pl.lit(folder).alias("alignment"))
            .select("alignment", "pitch", "velocity")
        )

note_df = pl.concat(note_frames)


In [None]:
# Velocity statistics per (alignment, pitch) for pitches 35..81, exported as CSV.
# The four statistics are folded into one text column "(min, mean, max, std)".
per_alignment_stats = (
    note_df.lazy()
    .filter((pl.col("pitch") >= 35) & (pl.col("pitch") <= 81))
    .group_by("alignment", "pitch")
    .agg(
        min=pl.col("velocity").min(),
        mean=pl.col("velocity").mean(),
        max=pl.col("velocity").max(),
        std=pl.col("velocity").std(),
        total=pl.col("pitch").count(),
    )
    .sort("alignment", "pitch")
    .with_columns(
        name=pl.col("pitch").map_elements(
            pretty_midi.note_number_to_drum_name, return_dtype=pl.String
        )
    )
    # Replace missing statistics by 0 before formatting.
    .fill_null(0)
    # Round the float statistics, keep min/max integral, then stringify all four
    # so they can be joined into a single column.
    .with_columns(
        pl.col("mean", "std").round(2).cast(pl.String),
        pl.col("min", "max").cast(pl.UInt8).cast(pl.String),
    )
    .with_columns(
        velocity="(" + pl.concat_list("min", "mean", "max", "std").list.join(", ") + ")"
    )
    .select("alignment", "name", "total", "velocity")
    .collect()
)
per_alignment_stats.write_csv("processed/A2MD_per_p_combined.csv", float_precision=2)

In [None]:
# Velocity statistics per pitch (35..81) over the whole table, without the
# per-alignment split. The frame is shown as the cell output; the CSV export to
# "processed/A2MD_total_combined.csv" is intentionally not performed here.
overall_stats = (
    note_df.lazy()
    .filter((pl.col("pitch") >= 35) & (pl.col("pitch") <= 81))
    .group_by("pitch")
    .agg(
        min=pl.col("velocity").min(),
        mean=pl.col("velocity").mean(),
        max=pl.col("velocity").max(),
        std=pl.col("velocity").std(),
        total=pl.col("pitch").count(),
    )
    .sort("pitch")
    # Replace missing statistics by 0 before formatting.
    .fill_null(0)
    .with_columns(
        name=pl.col("pitch").map_elements(
            pretty_midi.note_number_to_drum_name, return_dtype=pl.String
        )
    )
    # Round the float statistics, keep min/max integral, then stringify all four
    # so they can be joined into the single "(min, mean, max, std)" column.
    .with_columns(
        pl.col("mean", "std").round(2).cast(pl.String),
        pl.col("min", "max").cast(pl.UInt8).cast(pl.String),
    )
    .with_columns(
        velocity="(" + pl.concat_list("min", "mean", "max", "std").list.join(", ") + ")"
    )
    .select("name", "total", "velocity")
    .collect()
)
overall_stats

In [None]:
# Drum vocabularies to compare, listed from the smallest to the largest.
drum_mappings = [
    DrumMapping.THREE_CLASS_STANDARD,
    DrumMapping.EIGHT_CLASS,
    DrumMapping.EIGHTEEN_CLASS,
]

# Short display names for the classes of each vocabulary.
# NOTE(review): presumably listed in the class order of the matching mapping;
# confirm against dataset.mapping.
names_3_map = ['KD', 'SD', 'HH']
names_m_map = ['BD', 'SD', 'TT', 'HH', 'CY', 'RD', 'CB', 'CL']
names_l_map = [
    'BD', 'SD', 'SS', 'CLP', 'LT', 'MT', 'HT', 'CHH', 'PHH',
    'OHH', 'TB', 'RD', 'RB', 'CRC', 'SPC', 'CHC', 'CB', 'CL',
]

# Same order as drum_mappings.
class_names = [names_3_map, names_m_map, names_l_map]

In [None]:
# Accumulate, for each mapping, the number of entries per drum class over all
# tracks. mapped_counter[i] belongs to drum_mappings[i].
mapped_counter = [np.zeros(len(mapping)) for mapping in drum_mappings]
for paths in tracks_per_alignment.values():
    for file in paths:
        # Parse each file exactly once and reuse the parsed object for every
        # mapping, instead of re-reading the file from disk per mapping.
        # NOTE(review): assumes get_drums does not modify `midi`; confirm.
        midi = pretty_midi.PrettyMIDI(midi_file=file)
        for i, mapping in enumerate(drum_mappings):
            drums = get_drums(midi, mapping)
            if drums is None:
                # Nothing to count for this file under this mapping.
                continue
            mapped_counter[i] += np.array([len(drum) for drum in drums])

In [None]:
# Show the absolute per-class counts, then export each class's share of the
# total as (name, share) pairs per vocabulary.
print(mapped_counter)

relative_count = []
for count in mapped_counter:
    relative_count.append(count / sum(count))

shares_3, shares_m, shares_l = relative_count[0], relative_count[1], relative_count[2]
out = {
    "three": list(zip(names_3_map, shares_3)),
    "eight": list(zip(names_m_map, shares_m)),
    "eighteen": list(zip(names_l_map, shares_l)),
}

# Leaving the with-block closes the file, which also flushes it.
with open("processed/A2MD_relative_mapped.json", mode="wt") as f:
    f.write(json.dumps(out))

# Experiment plots

In [None]:
import polars as pl
import numpy as np
import ipywidgets as widgets
import polars.selectors as cs
import seaborn as sns
import matplotlib.pyplot as plt
from PIL import Image
# Plot configuration for SVG export: set matplotlib's 'svg.fonttype' to 'none'.
plt.rcParams.update({'svg.fonttype': 'none'})

In [None]:
# Lazily open the three experiment tables; nothing is read until a .collect().
hparams = (
    pl.scan_parquet("processed/BA/optuna.parquet")
    .select("dir_name", "config", "datetime_start", "seed")
    .sort("datetime_start")
)
tensors = pl.scan_parquet("processed/BA/tensors.parquet")
scalars = pl.scan_parquet("processed/BA/scores.parquet")

In [None]:
# One row per (dir_name, step): tensor and scalar logs joined 1:1, plus the
# hyper-parameter columns of the run.
master_df = (
    tensors.join(scalars, on=["dir_name", "step"], validate="1:1")
    .join(hparams, on="dir_name")
    .sort("dir_name", "step")
    # Forward-fill gaps of up to 5 rows, but only inside each run. A frame-wide
    # fill would copy the last values of the previous run into the first rows
    # of the next one.
    .with_columns(pl.all().fill_null(strategy="forward", limit=5).over("dir_name"))
    # Optional restriction to attention configs:
    # .filter(pl.col("config").str.contains("tention"))
)

In [None]:
# Preview of the merged table with all remaining gaps forward-filled. The fill
# is done per run (dir_name) so values never carry over from one run to the next.
(
    master_df
    .sort("dir_name", "step")
    .with_columns(pl.all().fill_null(strategy="forward").over("dir_name"))
    .collect()
)

In [None]:
# Long-format preview: one row per (run, step, metric) for every column whose
# name starts with "F-Score" or "Loss"; rows without a value are dropped.
(
    master_df
    .unpivot(
        on=[cs.starts_with("F-Score"), cs.starts_with("Loss")],
        index=["dir_name", "step", "config"],
    )
    .drop_nulls(subset="value")
    .collect()
)

In [None]:
# Loss curves over training steps for every config whose name starts with "Mamba".

# Long format: one row per (run, step, metric) for all F-Score*/Loss* columns.
run_progressions = (
    master_df
    .unpivot(
        on=[cs.starts_with("F-Score"), cs.starts_with("Loss")],
        index=["dir_name", "step", "config"],
    )
    .drop_nulls(subset="value")
    .collect()
)

loss_progressions = run_progressions.filter(pl.col("variable").str.starts_with("Loss"))
# Fixed order so that each loss series always gets the same line style.
style_order = ["Loss/Train", "Loss/Validation", "Loss/Test/MDB_full", "Loss/Test/RBMA_full"]

fig, ax = plt.subplots()
sns.lineplot(
    data=loss_progressions.filter(pl.col("config").str.starts_with("Mamba")),
    x="step",
    y="value",
    hue="config",
    style="variable",
    style_order=style_order,
    ax=ax,
)
ax.set_yscale("log")
# Label the figure so it can be read on its own (seaborn would only use the
# raw column names "step" / "value").
ax.set(xlabel="Step", ylabel="Loss", title="Loss progression (Mamba configs)")
plt.show()


In [None]:
# F-Score/Sum* curves over training steps for every config whose name starts
# with "Mamba"; the figure is also saved as SVG.
score_progressions = run_progressions.filter(pl.col("variable").str.starts_with("F-Score/Sum"))

fig, ax = plt.subplots()
sns.lineplot(
    data=score_progressions.filter(pl.col("config").str.starts_with("Mamba")),
    x="step",
    y="value",
    hue="config",
    style="variable",
    ax=ax,
)
# Label the figure so it can be read on its own (seaborn would only use the
# raw column names "step" / "value").
ax.set(xlabel="Step", ylabel="F-Score", title="F-Score progression (Mamba configs)")
# Save before plt.show(), while the figure is still open.
fig.savefig("processed/Mamba_scores.svg")
plt.show()



