In [None]:
# The entire training code:
import pytorch_lightning as pl

from midi_score import BeatPredictorPL

model = BeatPredictorPL("midi_score/dataset")
# Single-GPU run capped at 100 epochs; keeping the trainer in a name lets
# later cells inspect it (e.g. callback metrics) after fitting.
trainer = pl.Trainer(accelerator="gpu", devices=1, max_epochs=100)
trainer.fit(model)

In [4]:
# Inspecting dataset: peek at the first batch
from midi_score import BeatPredictorPL

model = BeatPredictorPL("midi_score/dataset")
train_loader = model.train_dataloader()
batch = next(iter(train_loader))
# The unpacking below expects the batch as (notes, (beats,)), i.e. the
# second item is a one-element tuple of targets.
notes, (beats,) = batch
(notes.shape, beats.shape)

(torch.Size([154, 200, 128]), torch.Size([154, 200]))

In [6]:
import torchinfo

# Parameter count / layer shapes, traced with the batch inputs unpacked above
# (`notes` is the same object as `batch[0]`).
torchinfo.summary(model=model, input_data=notes)

  action_fn=lambda data: sys.getsizeof(data.storage()),
  return super().__sizeof__() + self.nbytes()


Layer (type:depth-idx)                             Output Shape              Param #
BeatPredictorPL                                    --                        --
├─BeatTransformer: 1                               --                        --
│    └─TransformerEncoder: 2                       --                        --
│    │    └─ModuleList: 3-1                        --                        529,920
├─BeatTransformer: 1-1                             [170, 200, 2]             --
│    └─Linear: 2-1                                 [200, 170, 128]           16,512
│    └─PositionalEncoding: 2-2                     [170, 200, 128]           --
│    │    └─Dropout: 3-2                           [170, 200, 128]           --
│    └─TransformerEncoder: 2-3                     [170, 200, 128]           --
│    └─Linear: 2-4                                 [170, 200, 2]             258
Total params: 16,770
Trainable params: 16,770
Non-trainable params: 0
Total mult-adds (M): 3.35
Input siz

In [None]:
import pretty_midi as pm

midi = pm.PrettyMIDI("midi_score/dataset/asap/Bach/Italian_concerto/KyykhynenT03.mid")
piano = midi.instruments[0]
# One tab-separated row per note of the first instrument:
# start, end, pitch, velocity.
for note in piano.notes:
    print(*(note.start, note.end, note.pitch, note.velocity), sep="\t")

In [None]:
import bisect

import torch


def midi_to_encoded_with_annots(midi_data, annots, interval=0.05):
    """Encode notes plus score annotations as a (169, n_frames) frame matrix.

    Row layout:
        0-127    pitch activity (1 on every frame the note covers)
        128      beat markers
        129      downbeat markers
        130-145  time-signature numerator (1..16)
        146-154  time-signature denominator (1, 2, 4, ..., 256)
        155-166  key number (0..11)
        167      musical-onset markers
        168      hand label (the given value is written as-is)

    Args:
        midi_data: sequence of ``(pitch, onset, duration, velocity)`` notes.
        annots: 7-item sequence ``(beats, downbeats, time_signatures,
            key_signatures, onsets_musical, note_values, hands)``.
            ``onsets_musical`` and ``note_values`` are indexed in parallel
            with ``midi_data``; ``time_signatures`` holds
            ``(time, numerator, denominator)``, ``key_signatures`` holds
            ``(time, key_number)`` and ``hands`` holds ``(time, hand_type)``.
        interval: frame length; every time is mapped to its nearest frame.

    Returns:
        A ``torch.zeros``-allocated tensor of shape (169, n_frames), where
        ``n_frames = int(max(onset + duration) / interval) + 1`` unless a note
        span or annotation falls later, in which case it is widened to fit.

    Raises:
        ValueError: a note has no beat at/before or after its shifted onset,
            a pitch is outside 0..127, a numerator is outside 1..16, a key
            number is outside 0..11, or a time maps to a negative frame.
        KeyError: a denominator is not a power of two between 1 and 256.
    """
    (
        beats,
        downbeats,
        time_signatures,
        key_signatures,
        onsets_musical,
        note_values,
        hands,
    ) = annots

    # Row layout of the encoding matrix.
    n_pitches = 128
    beat_row = 128
    downbeat_row = 129
    numerator_row, n_numerators = 130, 16
    denominator_row = 146
    key_row, n_keys = 155, 12
    onset_row = 167
    hand_row = 168
    n_rows = n_pitches + 1 + 1 + n_numerators + 9 + n_keys + 1 + 1  # 169
    # Denominators are powers of two: 1 -> 0, 2 -> 1, ..., 256 -> 8.
    denominator_index = {2**power: power for power in range(9)}

    sorted_beats = sorted(beats)

    def to_frame(time):
        """Map a time to its nearest frame index."""
        # Round rather than truncate: 0.3 / 0.05 evaluates to 5.999... in
        # floating point, so int() would put a 0.3 s event in frame 5.
        frame = int(round(float(time) / interval))
        if frame < 0:
            # A negative index would silently wrap to the end of the matrix.
            raise ValueError(f"time {time} maps to negative frame {frame}")
        return frame

    def previous_beat(onset):
        """Latest beat at or before ``onset``."""
        position = bisect.bisect_right(sorted_beats, onset)
        if position == 0:
            raise ValueError(f"no beat at or before onset {onset}")
        return sorted_beats[position - 1]

    def next_beat(onset):
        """Earliest beat strictly after ``onset``."""
        position = bisect.bisect_right(sorted_beats, onset)
        if position == len(sorted_beats):
            raise ValueError(f"no beat after onset {onset}")
        return sorted_beats[position]

    # Notes: each becomes a (pitch, start_frame, end_frame) span.
    note_spans = []
    for idx, (pitch, onset, _, _) in enumerate(midi_data):
        pitch = int(pitch)
        if not 0 <= pitch < n_pitches:
            raise ValueError(f"pitch {pitch} outside 0..{n_pitches - 1}")
        # NOTE(review): shifts the performed onset by the note's musical
        # onset; confirm both are expressed in the same unit.
        onset = float(onset) - float(onsets_musical[idx])
        beat_duration = next_beat(onset) - previous_beat(onset)
        # Sounding length = local beat length scaled by the note value.
        adjusted_duration = beat_duration * float(note_values[idx])
        note_spans.append(
            (pitch, to_frame(onset), to_frame(onset + adjusted_duration))
        )

    # Annotations: (row, frame, value) writes, collected before allocation so
    # the matrix can be widened instead of raising IndexError past the notes.
    events = []
    events.extend((beat_row, to_frame(time), 1) for time in beats)
    events.extend((downbeat_row, to_frame(time), 1) for time in downbeats)
    for time, numerator, denominator in time_signatures:
        if not 1 <= numerator <= n_numerators:
            # Out-of-range numerators used to spill into the downbeat or
            # denominator rows.
            raise ValueError(
                f"time-signature numerator {numerator} outside 1..{n_numerators}"
            )
        frame = to_frame(time)
        events.append((numerator_row + numerator - 1, frame, 1))
        events.append((denominator_row + denominator_index[denominator], frame, 1))
    for time, key_number in key_signatures:
        if not 0 <= key_number < n_keys:
            # Only 12 key rows exist; larger numbers used to overwrite the
            # onset/hand rows.
            raise ValueError(f"key number {key_number} outside 0..{n_keys - 1}")
        events.append((key_row + key_number, to_frame(time), 1))
    for idx, musical_onset in enumerate(onsets_musical):
        # NOTE(review): beat time (from the unshifted note onset) plus the
        # musical onset value; confirm the intended units.
        relative_onset = previous_beat(float(midi_data[idx][1])) + musical_onset
        events.append((onset_row, to_frame(relative_onset), 1))
    events.extend((hand_row, to_frame(time), hand_type) for time, hand_type in hands)

    # Base length comes from the performed note offsets (onset + duration).
    total_duration = max(
        (float(note[1]) + float(note[2]) for note in midi_data), default=0.0
    )
    length = max(
        [int(total_duration / interval) + 1]
        + [end for _, _, end in note_spans]
        + [frame + 1 for _, frame, _ in events]
    )

    encoding = torch.zeros(n_rows, length)
    for pitch, start, end in note_spans:
        encoding[pitch, start:end] = 1
    for row, frame, value in events:
        encoding[row, frame] = value
    return encoding


# Example usage: one note plus one entry of each annotation type.
midi_data = [
    (60, 0.5, 0.2, 50),  # C4 note with onset at 0.5s, duration 0.2s, and velocity 50
    # ... (other notes)
]

annots = [
    # Beats. The beat at 0.0 gives the shifted onset (0.5 - 0.5 = 0.0) a
    # beat at or before it; without one the encoder cannot size the note.
    [0.0, 0.3, 0.5, 0.8, 1.5],
    [0.5],  # Downbeats
    [(0, 4, 4), (5, 3, 4)],  # Time signatures
    [(0, 5)],  # Key signatures
    [0.5],  # Onsets musical
    [2],  # Note values (this note will have twice its original length in beats)
    [(0.5, 1)],  # Hands
]

encoded_tensor = midi_to_encoded_with_annots(midi_data, annots)
encoded_tensor.shape

In [None]:
def inspect_tensor_structure(tensor_list):
    """Print the shape, dtype and first element of each tensor in a sequence.

    Args:
        tensor_list: iterable of tensor-like objects exposing ``shape`` and
            ``dtype`` (e.g. the label tuple of a batch).

    Tensors are numbered from 1. A zero-dimensional tensor has no first
    element, so its value is printed instead of indexing it.
    """
    separator = "-" * 50
    for position, tensor in enumerate(tensor_list, start=1):
        print(f"Tensor {position}:")
        print(f"  Shape: {tensor.shape}")
        print(f"  Data type: {tensor.dtype}")
        if len(tensor.shape) == 0:
            print(f"  Value: {tensor}")
        else:
            # Label fixed: this is the first element along dim 0, not a shape.
            print(f"  First element: {tensor[0]}")
        print(separator)


# `lables` was never defined; inspect the batch's label tuple, which is the
# `(beats,)` part unpacked from `batch` in the dataset-peek cell.
inspect_tensor_structure(batch[1])