# Model Training
We first train the model 

In [1]:
# The entire training code:
import pytorch_lightning as pl
import torch
from midi_score import BeatPredictorPL

#trading-off precision for speed with tensor cores, enable if you have tensor cores
torch.set_float32_matmul_precision('medium')

epochs = 10

model = BeatPredictorPL("midi_score/dataset", epochs)
pl.Trainer(accelerator="gpu", devices = 1, max_epochs=epochs).fit(model)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  from .autonotebook import tqdm as notebook_tqdm

  | Name  | Type     | Params | In sizes       | Out sizes   
-------------------------------------------------------------------
0 | model | BeatCRNN | 1.3 M  | [1, 2000, 128] | [1, 2000, 3]
-------------------------------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.361     Total estimated model params size (MB)


Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00, 23.62it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 0:  72%|███████▏  | 323/448 [00:46<00:17,  7.02it/s, v_num=154, train/loss=nan.0]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [3]:
#save the model
torch.save(model.state_dict(), './pretrained/beatModel.pth')

In [None]:
#Run this to see the training process
!tensorboard --logdir=lightning_logs/

And then we throw it into the wrapper class:

In [2]:
import torch

def encode_notes(midi_data, interval, cutoff):
    # Find the total duration required
    total_duration = max(
        note[1] + note[2] for note in midi_data[midi_data[:, 1] < cutoff]
    )  # considering the note's offset
    length = int(total_duration // interval)
    # Create an encoding matrix filled with zeros
    encoding = torch.zeros(length, 128)
    # Populate the encoding for the notes from midi_data
    for idx, note in enumerate(midi_data[midi_data[:, 1] < cutoff]):
        pitch, onset, duration, _ = note
        start_idx = int((onset.item() // interval))
        end_idx = int(((onset.item() + duration.item()) // interval))
        encoding[start_idx:end_idx, int(pitch.item())] = 1

    return encoding

In [3]:
from model_wrapper import MuscribeModelWrapper
import midi_score
import torch
from midi_score import BeatPredictorPL

model = BeatPredictorPL("midi_score/dataset", 10)
model.load_state_dict(torch.load('./pretrained/beatModel.pth'))
# midi = get_midi("example/sonatine.mp3", "example/sonatine.midi")
mmw = MuscribeModelWrapper(beat_model = model.forward)
midi = midi_score.midi_read.read_note_sequence("example/heartgrace.midi")
#print(midi)
encoded_notes = encode_notes(midi,0.02, 30).cuda()
#print(encoded_notes)
print(model.cuda().forward(encoded_notes.unsqueeze(0).cuda()))
##print(encoded_notes[143])
# beats = mmw.get_beats(use_midi = True, midi_notes = encoded_notes.unsqueeze(0).cuda())
# print(beats.shape)
# print(beats)
#key_change = mmw.get_keysig(midi)
#hand_parts = mmw.get_hand_parts(midi)

tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         ...,
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.], device='cuda:0')


  input = module(input)


In [None]:
builder = midi_score.MusicXMLBuilder(beats)
builder.add_notes(midi.numpy(), hand_parts.numpy())
builder.add_key_changes(key_change)
builder.infer_bpm_changes(diff_size=2, log_bin_size=0.03)
builder.render("example/sonatine.xml")