# Demo
We first train the beat-prediction model (`BeatPredictorPL`).

In [None]:
# Set environment variables
%env OMP_NUM_THREADS=1
%env MKL_NUM_THREADS=1
%env CUDA_LAUNCH_BLOCKING = 1
%env TORCH_USE_CUDA_DSA = 1

In [None]:
# The entire training code:
import pytorch_lightning as pl
import torch
from midi_score import BeatPredictorPL


# Allow reduced internal precision for float32 matmuls in exchange for speed. This pays
# off on GPUs with tensor cores; where no faster kernel is available PyTorch falls back
# to the 'high' behaviour, so it is safe to leave on.
torch.set_float32_matmul_precision('medium')

# Number of training epochs (also used as the Trainer's max_epochs).
epochs = 50
model = BeatPredictorPL(
    "midi_score/dataset",  # dataset directory
    epochs,
)

In [None]:
# Debug settings: train on only 10% of the batches, with autograd anomaly detection
# (which slows every step). Set DEBUG_RUN = False for a full training run before saving
# weights you intend to reuse as "pretrained".
DEBUG_RUN = True

trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    gradient_clip_val=1.5,
    max_epochs=epochs,
    callbacks=[pl.callbacks.StochasticWeightAveraging(swa_lrs=1e-5)],
    log_every_n_steps=30,
    detect_anomaly=DEBUG_RUN,
    # 0.0 disables overfit mode, i.e. the whole dataset is used.
    overfit_batches=0.1 if DEBUG_RUN else 0.0,
)
trainer.fit(model)

In [None]:
# Save the trained weights. torch.save does not create missing parent directories,
# so make sure ./pretrained exists first.
from pathlib import Path

weights_path = Path("./pretrained/beatModel.pth")
weights_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), weights_path)

In [None]:
#Run this to see the training process
!tensorboard --logdir=lightning_logs/

Next, we load the trained weights into a fresh model and pass it to the `MuscribeModelWrapper` class:

In [None]:
import torch

def encode_notes(midi_data, interval, cutoff):
    """Rasterise a note list into a binary piano-roll matrix.

    Only notes whose onset is strictly before ``cutoff`` are encoded, but such a
    note may extend past ``cutoff``: the matrix is sized to the latest offset
    (onset + duration) among the kept notes.

    Args:
        midi_data: 2-D array/tensor with one row per note and four columns,
            unpacked as (pitch, onset, duration, <unused>). Presumably the
            fourth column is velocity — TODO confirm against midi_score.
            Pitch must be in 0..127 because it indexes the 128-wide second axis.
        interval: Frame length, in the same time unit as onset/duration
            (presumably seconds — verify).
        cutoff: Notes with ``onset >= cutoff`` are ignored.

    Returns:
        A float ``torch.Tensor`` of shape ``(length, 128)``, where
        ``length = int(latest_offset // interval)``. Entry ``[t, p]`` is 1 when
        pitch ``p`` is active in frame ``t``. If no note starts before ``cutoff``,
        an empty ``(0, 128)`` tensor is returned.
    """
    # Apply the cutoff filter once and reuse it for sizing and filling.
    notes = midi_data[midi_data[:, 1] < cutoff]
    if len(notes) == 0:
        # Without this guard, max() over zero notes raises an opaque ValueError.
        return torch.zeros(0, 128)

    # The latest note offset (onset + duration) decides how many frames are needed.
    total_duration = (notes[:, 1] + notes[:, 2]).max()
    length = int(total_duration // interval)
    encoding = torch.zeros(length, 128)

    for pitch, onset, duration, _ in notes:
        start_idx = int(onset.item() // interval)
        end_idx = int((onset.item() + duration.item()) // interval)
        # NOTE(review): a note that starts and ends inside the same frame yields
        # start_idx == end_idx and marks no frames at all — confirm this matches
        # the encoding used at training time.
        encoding[start_idx:end_idx, int(pitch.item())] = 1

    return encoding

In [None]:
from model_wrapper import MuscribeModelWrapper
import midi_score
import torch
from midi_score import BeatPredictorPL

FRAME_INTERVAL = 0.02  # frame length passed to encode_notes (same unit as note onsets)
NOTE_CUTOFF = 30  # only notes with onset < NOTE_CUTOFF are encoded

model = BeatPredictorPL("midi_score/dataset", 10)
# map_location="cpu" lets the checkpoint load regardless of the device it was saved
# from; the model is moved to the GPU explicitly on the next line.
model.load_state_dict(torch.load('./pretrained/beatModel.pth', map_location="cpu"))
# eval() switches train-only behaviour (e.g. dropout / batch-norm statistics, if the
# model has any) off so predictions are deterministic.
model = model.cuda().eval()

# Alternative input: transcribe audio first (get_midi is not imported in this notebook).
# midi = get_midi("example/sonatine.mp3", "example/sonatine.midi")
mmw = MuscribeModelWrapper(beat_model=model.forward)
midi = midi_score.midi_read.read_note_sequence("example/heartgrace.midi")
encoded_notes = encode_notes(midi, FRAME_INTERVAL, NOTE_CUTOFF).cuda()

# No autograd graph is needed for prediction.
with torch.no_grad():
    beat_output = model(encoded_notes.unsqueeze(0))  # add a leading (presumably batch) dim

# TODO: the MusicXML cell below needs `beats`, `key_change` and `hand_parts`;
# re-enable these wrapper calls once they work.
# beats = mmw.get_beats(use_midi=True, midi_notes=encoded_notes.unsqueeze(0))
# key_change = mmw.get_keysig(midi)
# hand_parts = mmw.get_hand_parts(midi)

beat_output

In [None]:
# Build and render the MusicXML score. This needs `beats`, `key_change` and
# `hand_parts` from the inference cell; fail early with an actionable message
# instead of a bare NameError when they were not produced.
_missing = [name for name in ("beats", "key_change", "hand_parts") if name not in globals()]
if _missing:
    raise NameError(
        f"Missing {', '.join(_missing)}: enable the get_beats / get_keysig / "
        "get_hand_parts calls in the inference cell and re-run it first."
    )

builder = midi_score.MusicXMLBuilder(beats)
builder.add_notes(midi.numpy(), hand_parts.numpy())
builder.add_key_changes(key_change)
builder.infer_bpm_changes(diff_size=2, log_bin_size=0.03)
# The notes come from example/heartgrace.midi, so name the output after it rather
# than overwriting the sonatine example.
builder.render("example/heartgrace.xml")