In [None]:
# Automatically reload edited Python modules (e.g. the project's src.* code)
# before each cell executes, so changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2


In [None]:
import torch
from tqdm.auto import tqdm
from src.common import MaestroSplitType
from torch.utils.data import DataLoader
from src.maestro2 import MaestroDatasetSplit, FrameContextDataset, DynamicBatchIterableDataset2, custom_collate_fn
from torch.nn import MSELoss, BCEWithLogitsLoss, CrossEntropyLoss



In [None]:
# Build the TRAIN split of MAESTRO and report how many entries it contains.
train_split_type = MaestroSplitType.TRAIN
dataset = MaestroDatasetSplit(train_split_type)
num_entries = len(dataset.split.entries)
print(num_entries)

In [None]:
# Take the first entry of the split for single-example inspection below.
entry = dataset.split.entries[0]
# NOTE(review): start_ms / end_ms are not read by any visible cell; presumably
# a clip window in milliseconds — confirm whether they are still needed.
start_ms, end_ms = 1000.0, 15000.0

In [None]:
# Compute the log-mel spectrogram of the entry's audio and the piano roll of its
# MIDI, then display both shapes to check that they line up.
entry_audio = entry.load_audio
mel = entry_audio.compute_log_mel_spectrogram()
entry_midi = entry.load_midi
roll = entry_midi.get_piano_roll()
(mel.shape, roll.shape)

In [None]:
# Wrap the split in a frame-context dataset, then in an iterable that forms
# its own batches.
# NOTE(review): the meaning of the two 64s passed to FrameContextDataset and the
# 64 passed to DynamicBatchIterableDataset2 is not visible here (presumably
# context sizes and a batch size) — confirm against src.maestro2.
loader = FrameContextDataset(dataset, 64, 64)
iterable_loader = DynamicBatchIterableDataset2(loader, 64)

# batch_size=1 because the iterable already yields batches; custom_collate_fn
# unwraps them. 'spawn' workers plus pinned memory for CUDA transfer.
train_loader = DataLoader(
    iterable_loader,
    collate_fn=custom_collate_fn,
    batch_size=1,
    pin_memory=True,
    pin_memory_device='cuda',
    num_workers=2,
    prefetch_factor=2,
    multiprocessing_context='spawn',
)

In [None]:
from model import get_model2

# Explicit switch instead of a commented-out line: set LOAD_CHECKPOINT = True
# to evaluate saved weights. With False, every evaluation cell below uses the
# weights exactly as returned by get_model2 (presumably untrained).
LOAD_CHECKPOINT = False
CHECKPOINT_PATH = 'model2.pth'

model = get_model2(64, 'cuda')
if LOAD_CHECKPOINT:
    model.load_state_dict(torch.load(CHECKPOINT_PATH, weights_only=True))
else:
    print(f"NOTE: {CHECKPOINT_PATH} not loaded; using weights as returned by get_model2")
# Switch layers such as dropout/batch-norm to inference behaviour.
model.eval()

In [None]:
from src.maestro2 import custom_normalize_batch

# Pull one batch from the loader, normalise it and run the model on it.
x, y = next(iter(train_loader))
x, y = custom_normalize_batch(x, y)
print(x.shape, y.shape)
# Inference only: no_grad avoids building an autograd graph for this batch.
# Call the module itself (not .forward) so any registered hooks run.
with torch.no_grad():
    y_ = model(x)
print(y_.shape)

In [None]:
# Show the first sample's first-channel output as a CPU tensor.
y_[0, 0].detach().cpu()

In [None]:
from src.maestro2 import LOWEST_MIDI_NOTE
from src.common import MidiWrapper

# Plot the piano roll decoded from each predicted sample (y_) in the batch.
# The dynamically-batched loader may presumably yield fewer than 64 samples,
# so cap the loop at the real batch size instead of assuming 64.
# `midiw` is left holding the last plotted sample; a later cell reads it.
MAX_PLOTS = 64
for i in range(min(MAX_PLOTS, y_.shape[0])):
    midiw = MidiWrapper.from_piano_roll(
        y_[i, 0].transpose(0, 1).cpu().detach().numpy(),
        note_offset=LOWEST_MIDI_NOTE,
    )
    midiw.plot_piano_roll()

In [None]:
# Plot the piano roll decoded from each ground-truth sample (y) in the batch.
# Cap at the real batch size, since the loader may presumably yield fewer than 64.
# `midiw2` is left holding the last plotted sample; a later cell reads it.
MAX_PLOTS = 64
for i in range(min(MAX_PLOTS, y.shape[0])):
    midiw2 = MidiWrapper.from_piano_roll(
        y[i, 0].transpose(0, 1).cpu().detach().numpy(),
        note_offset=LOWEST_MIDI_NOTE,
    )
    midiw2.plot_piano_roll()

In [None]:
from matplotlib import pyplot as plt

# Show the model input for the first sample/channel on a fresh figure, rather
# than drawing onto whatever pyplot figure happens to be current.
# 'gray' is the canonical colormap name ('grey' is only a newer alias).
fig, ax = plt.subplots()
ax.imshow(x[0, 0].transpose(0, 1).cpu().detach().numpy(), cmap='gray')
ax.set_title('Model input x[0, 0] (transposed)');

In [None]:
# Render the selected entry's audio in the notebook (presumably an inline player).
audio_clip = entry.load_audio
audio_clip.display_ipython()

In [None]:
import mir_eval
import numpy as np
def midi_to_hz(midi):
    """Convert a MIDI note number to frequency in Hz (A4 = MIDI 69 = 440 Hz).

    Works element-wise on numpy arrays as well as on scalars.
    NOTE: mir_eval.util.midi_to_hz provides the same conversion.
    """
    semitones_from_a4 = midi - 69
    return 440.0 * 2.0 ** (semitones_from_a4 / 12.0)

# pm1: MIDI object decoded from the model prediction (y_),
# pm2: MIDI object decoded from the ground truth (y).
# Both come from the LAST iteration of their respective plotting loops.
# Presumably pretty_midi.PrettyMIDI-like (.instruments[i].notes) — confirm.
pm1, pm2 = midiw.midi, midiw2.midi

In [None]:
# Mean squared error between ground truth and prediction for the whole batch.
mse_fn = MSELoss()
computed_mse = mse_fn(y, y_)
computed_mse.item()

In [None]:
# (start_s, end_s, pitch_hz) tuples for the first instrument of each MIDI object.
# pm1 is decoded from the model prediction (y_), pm2 from the ground truth (y).
# Guard against objects with no instruments (e.g. an all-silent roll).
notes1 = [(note.start, note.end, midi_to_hz(note.pitch))
          for note in (pm1.instruments[0].notes if pm1.instruments else [])]
notes2 = [(note.start, note.end, midi_to_hz(note.pitch))
          for note in (pm2.instruments[0].notes if pm2.instruments else [])]

# mir_eval compares a REFERENCE (ground truth, notes2) against an ESTIMATE
# (model output, notes1). Passing them the other way round silently swaps the
# reported precision and recall.
# reshape(-1, 2) keeps intervals 2-D even when there are no notes; mir_eval
# rejects 1-D interval arrays.
ref_intervals = np.array([[start, end] for start, end, _ in notes2], dtype=float).reshape(-1, 2)
ref_pitches = np.array([hz for _, _, hz in notes2], dtype=float)
est_intervals = np.array([[start, end] for start, end, _ in notes1], dtype=float).reshape(-1, 2)
est_pitches = np.array([hz for _, _, hz in notes1], dtype=float)

# Note-level precision / recall / F1 with onset+offset matching (4th value is
# the average overlap ratio, unused here).
precision, recall, f1, _ = mir_eval.transcription.precision_recall_f1_overlap(
    ref_intervals, ref_pitches, est_intervals, est_pitches
)

print(f"Precision: {precision}, Recall: {recall}, F1 Score: {f1}")