In [1]:
import midi_score.dataset.script as ds
import torch.nn.functional as F
import torch

def midi_to_encoded(midi_data: torch.Tensor, iv_size=0.05):
    """
    Convert midi_data tensor to encoded format.
    
    midi_data: A tensor of shape (batch, notes, 4) where each row is a tuple of (pitch, onset, duration, velocity)
    interval: Time resolution of encoding
    
    Returns: A tensor of shape (batch, intvs, 128) where intvs is the number of 0.05-second intervals
    """
    # Find the maximum end time to determine the length of the encoded tensor
    starts = midi_data[:, :, 1]
    ends = starts + midi_data[:, :, 2]
    n_intervals = int(torch.ceil(ends.max() / 0.05))

    # Initialize a [batch x intervals x 128] tensor filled with zeros
    encoded = torch.zeros((midi_data.shape[0], n_intervals, 128), device=midi_data.device)
    istarts, iends = torch.floor(starts / iv_size).long(), torch.ceil(ends / iv_size).long()
    for i in range(istarts.shape[0]):
        for j in range(istarts.shape[1]):
            midi_value = int(midi_data[i, j, 0].item())
            encoded[i, istarts[i, j]:iends[i, j], midi_value] = 1

    # for row in midi_data:
    #     start_time = row[1]
    #     end_time = row[1] + row[2]
    #     midi_value = int(row[1])

    #     # Calculate start and end intervals for this note
    #     start_interval = int(torch.floor(start_time / interval))
    #     end_interval = int(torch.ceil(end_time / interval))

    #     # Set 1s in the appropriate positions in the encoded tensor
    #     encoded[start_interval:end_interval, midi_value] = 1

    return encoded


def midi_beat_encoder(beats, length, interval):
    # beats: [batch, notes]
    encoding = torch.zeros(beats.shape[0], length, device=beats.device, dtype=torch.long)
    for n in range(beats.shape[0]):
        for beat in beats[n]:
            idx = int(beat / interval)
            if idx >= length:
                continue
            encoding[n, idx] = 1
    return encoding


def generate_combined_dataset(dataset: ds.MIDIDataset, group_size: int, data_size = 10000):
    
    def dataset_splitter(notes, annots, group_size):
        note_chunks = torch.split(notes, group_size, dim = 0)
        dataset = []
        for val in note_chunks:
            # mask = annots >= val[0][1] & annots <= val[0][-1]
            # indices = torch.nonzero(mask)
            corresponding_annots = annots[annots.ge(val[0][1].item()) & annots.le(val[0][-1].item())]
            padded_annots = F.pad(corresponding_annots, (0, group_size + 20 - corresponding_annots.shape[0])) 
            #print(val.shape[0])
            if(val.shape[0] == group_size):
                dataset.append((val, list([padded_annots])))
            else:
                pass
        return dataset

    combined_dataset = []
    for i in dataset:
        notes, labels = i
        combined_dataset.extend(dataset_splitter(notes, labels[0], group_size))
        if (len(combined_dataset) > data_size):
            break
    
    return combined_dataset




In [2]:
import torch.nn as nn
import torch
import math
from torch.nn import TransformerEncoderLayer, TransformerEncoder
from torch import Tensor


class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class MusicTransformer(nn.Module):
    def __init__(self, d_model, nhead, d_hid, dropout, nlayers):
        super(MusicTransformer, self).__init__()

        self.embedding = nn.Linear(128, d_model)  # d_model is the model's input and output dimensionality
        self.positional_encoding = PositionalEncoding(d_model)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout, batch_first=True)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)

        self.beats_head = nn.Linear(d_model, 2)
  #      self.downbeats_head = nn.Linear(d_model, 1)
  #      self.time_signatures_head = nn.Linear(d_model, 2) # one for numerator, one for denominator
 #       self.key_signatures_head = nn.Linear(d_model, num_keys)
 #      self.onsets_musical_head = nn.Linear(d_model, 1)
 #       self.note_value_head = nn.Linear(d_model, 1)
 #       self.hands_head = nn.Linear(d_model, 1)

    def forward(self, x):
        x = self.embedding(x.permute(1, 0, 2)).permute(1, 0, 2)
        x += self.positional_encoding(x)
        x = self.transformer_encoder(x)
        beats = self.beats_head(x)
        #beats = torch.softmax(self.beats_head(x), dim=-1)
        #downbeats = torch.sigmoid(self.downbeats_head(x))
        #time_signatures = self.time_signatures_head(x)
        #key_signatures = self.key_signatures_head(x)
        #onsets_musical = self.onsets_musical_head(x)
        #note_value = self.note_value_head(x)
        #hands = torch.sigmoid(self.hands_head(x))

        return beats#, downbeats, time_signatures, key_signatures, onsets_musical, note_value, hands
    
ms_trans = MusicTransformer(128, 8, 256, 0.1, 4)
#ms_trans(encoded_data).shape

In [3]:
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import torchinfo

class ScoreModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = MusicTransformer(128, 8, 256, 0.1, 4)

    def forward(self, x):
        # x = x[:, :100]
        encoded = midi_to_encoded(x.float())
        return self.model(encoded).transpose(1, 2)

    def configure_optimizers(self):
        return torch.optim.SGD(self.model.parameters(), 0.01, 0.9)

    def training_step(self, batch, batch_idx):
        # Data
        x, (y, ) = batch
        y_pred = self(x)
        y = midi_beat_encoder(y.float(), y_pred.shape[2], 0.05)
        #print(y_pred.shape)
        # Loss
        loss = nn.NLLLoss()(y_pred, y)
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, (y, ) = batch
        y_pred = self(x)
        y = midi_beat_encoder(y.float(), y_pred.shape[2], 0.05)
        loss = nn.NLLLoss()(y_pred, y)
        accuracy = torch.sum(y_pred.argmax(dim=1) == y) / y.numel()
        self.log("val/loss", loss)
        self.log("val/accuracy", accuracy)
        return accuracy
    
    def train_dataloader(self):
        dataset = ds.MIDIDataset("midi_score/dataset", "midi_score/dataset/features.pkl", "train", ["beats"])
        dataset = generate_combined_dataset(dataset, 50, 20)
        return DataLoader(dataset, 4, True)

    def val_dataloader(self):
        dataset = ds.MIDIDataset("midi_score/dataset", "midi_score/dataset/features.pkl", "validation", ["beats"])
        dataset = generate_combined_dataset(dataset, 50, 20)
        return DataLoader(dataset, 4, False)


# TODO add collate function for dataloader
model = ScoreModule()
torchinfo.summary(model, input_size=(50, 1, 128))
pl.Trainer().fit(model)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params
-------------------------------------------
0 | model | MusicTransformer | 546 K 
-------------------------------------------
546 K     Trainable params
0         Non-trainable params
546 K     Total params
2.187     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  from .autonotebook import tqdm as notebook_tqdm


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(
  rank_zero_warn(


Epoch 0:   0%|          | 0/7 [00:00<?, ?it/s] 

OutOfMemoryError: CUDA out of memory. Tried to allocate 654.00 MiB (GPU 0; 6.00 GiB total capacity; 3.69 GiB already allocated; 0 bytes free; 4.06 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [16]:
dataset = ds.MIDIDataset("midi_score/dataset", "midi_score/dataset/features.pkl", "train")
# print(len(dataset))
# from midi_score.dataset.script.dataset import prepare_features
# features = prepare_features(Path("midi_score/dataset"), Path("midi_score/dataset/features.pkl"), 8)

import torch.nn.functional as fn
def generate_combined_dataset(dataset: ds.MIDIDataset):
    
    def dataset_splitter(notes, annots, group_size):
        note_chunks = torch.split(notes, group_size, dim = 0)
        dataset = []
        for val in note_chunks:
            # mask = annots >= val[0][1] & annots <= val[0][-1]
            # indices = torch.nonzero(mask)
            corresponding_annots = annots[annots.ge(val[0][1].item()) & annots.le(val[0][-1].item())]
            padded_annots = fn.pad(corresponding_annots, (0, group_size + 50 - corresponding_annots.shape[0])) 
            print(val.shape[0])
            if(val.shape[0] == group_size):
                dataset.append((val, list([padded_annots])))
            else:
                pass
        return dataset

    combined_dataset = []
    for i in dataset:
        notes, labels = i
        combined_dataset.extend(dataset_splitter(notes, labels[0], 200))
    
    return combined_dataset

dataset = ds.MIDIDataset("midi_score/dataset", "midi_score/dataset/features.pkl", "train", ["beats"])
split_dataset = generate_combined_dataset(dataset)



note_first, first_label = split_dataset[0]
note_first.shape

note_second, second_label = split_dataset[1]
print(second_label[0].shape[0]) 



#notes

200
200
200
79
200
200
200
79
200
200
200
79
200
200
200
79
200
200
200
200
200
184
200
200
200
200
129
200
200
200
200
200
200
200
89
200
200
200
200
200
200
200
200
200
200
200
200
200
200
99
200
200
200
200
200
200
200
200
200
200
200
200
200
200
99
200
200
200
200
200
200
200
200
200
198
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
175
200
200
200
200
200
200
200
200
200
200
200
6
200
200
200
200
200
200
200
200
200
200
188
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
177
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
177
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
200
22
200
200
200
200
200
200
82
200
200
200
200
200
200
82
200
200
200
200
200
200
200
200
200
45
200
200
200
200
200
200
200
200
200
45
200
200
200
200
20

In [5]:
import pretty_midi as pm

midi = pm.PrettyMIDI("midi_score/dataset/asap/Bach/Italian_concerto/KyykhynenT03.mid")
piano = midi.instruments[0]
for note in piano.notes:
    print(note.start, note.end, note.pitch, note.velocity, sep="\t")

0.5502131250000001	0.71447578125	62	41
0.51282	0.89342859375	65	52
1.5197634375	1.9417715625	50	43
2.5921448437500003	2.954056875	50	39
3.64315875	4.6661278125	65	56
3.68188734375	4.724888437500001	62	49
4.623392812500001	5.699780625000001	67	65
4.6447603125	5.7064579687500006	64	54
5.58626578125	6.6212540625	65	55
5.56222734375	6.63861515625	69	74
6.63460875	7.0646296875	67	47
6.5892028125	7.2676209375	70	69
7.6402167187500005	8.092940625	50	40
8.736636562500001	9.06916828125	50	34
9.711528750000001	10.81329046875	69	63
9.768953906250001	10.881399375	65	47
10.74651703125	11.8723171875	67	69
10.785245625	11.900362031250001	64	54
11.75613140625	12.91264734375	65	62
11.80554375	12.952711406250002	62	41
12.89395078125	13.385403281250001	61	38
12.8605640625	13.836791718750002	64	51
14.06916328125	14.471139375	50	35
15.1362028125	15.45270890625	50	39
16.203242343750002	16.94042109375	67	69
16.227280781250002	17.4171834375	64	56
17.366435625	18.46285546875	62	59
17.334384375000003	18.4815520

In [66]:
import torch

def midi_to_encoded_with_annots(midi_data, annots, interval=0.05):

    def get_previous_beat(onset, beats):
        """Helper function to get the previous beat for a given onset time"""
        return max(beat for beat in beats if beat <= onset)

    # Helper function to get beat duration around the given onset time
    def get_beat_duration(onset, beats):
        # Find the nearest beats before and after the onset
        previous_beat = max(beat for beat in beats if beat <= onset)
        next_beat = min(beat for beat in beats if beat > onset)

        return next_beat - previous_beat

    # Find the total duration required
    total_duration = max(note[1] + note[2] for note in midi_data)  # considering the note's offset
    length = int(total_duration / interval) + 1

    # Create an encoding matrix filled with zeros
    encoding = torch.zeros(128 + 1 + 1 + 16 + 9 + 12 + 1 + 1, length)

    # Populate the encoding for the notes from midi_data
    for idx, note in enumerate(midi_data):
        pitch, onset, _, _ = note
        onset -= annots[4][idx]
        beat_duration = get_beat_duration(onset, annots[0])  # get beat duration surrounding this note
        adjusted_duration = beat_duration * annots[5][idx]  # adjusting the duration with its note value
        start_idx = int(torch.round(onset / interval).item())
        end_idx = int(torch.round((onset + adjusted_duration) / interval).long().item())
        encoding[pitch, start_idx:end_idx] = 1

    # Populate the encoding for the annotations
    beats, downbeats, time_signatures, key_signatures, onsets_musical, _, hands = annots

    for beat in beats:
        idx = int(beat / interval)
        encoding[128, idx] = 1

    for downbeat in downbeats:
        idx = int(downbeat / interval)
        encoding[129, idx] = 1

    for ts in time_signatures:
        time, numerator, denominator = ts
        idx = int(time / interval)
        encoding[130 + numerator - 1, idx] = 1  # Numerator encoding
        
        # Denominator encoding
        denominator_indices = {1: 0, 2: 1, 4: 2, 8: 3, 16: 4, 32: 5, 64: 6, 128: 7, 256: 8}
        encoding[146 + denominator_indices[denominator], idx] = 1 

    for ks in key_signatures:
        time, key_number = ks
        idx = int(time / interval)
        encoding[155 + key_number, idx] = 1

    for idx, onset in enumerate(annots[4]):
        previous_beat = get_previous_beat(midi_data[idx][1], annots[0])
        relative_onset = previous_beat + onset
        idx = int(relative_onset / interval)
        encoding[167, idx] = 1

    for hand in hands:
        time, hand_type = hand
        idx = int(time / interval)
        encoding[168, idx] = hand_type

    return encoding

# Example usage
midi_data = [
    (60, 0.5, 0.2, 50),  # C4 note with onset at 0.5s, duration 0.2s, and velocity 50
    # ... (other notes)
]

annots = [
    [0.3, 0.5, 0.8, 1.5],  # Beats
    [0.5],  # Downbeats
    [(0, 4, 4), (5, 3, 4)],  # Time signatures
    [(0, 5)],  # Key signatures
    [0.5],  # Onsets musical
    [2],  # Note values (this note will have twice its original length in beats)
    [(0.5, 1)]  # Hands
]

encoded_tensor = midi_to_encoded_with_annots(notes, lables)


IndexError: tensors used as indices must be long, int, byte or bool tensors

In [58]:
def inspect_tensor_structure(tensor_list):
    for i, tensor in enumerate(tensor_list):
        print(f"Tensor {i + 1}:")
        print(f"  Shape: {tensor.shape}")
        print(f"  Data type: {tensor.dtype}")
        print(f"  first shape: {tensor[0]}")
        print('-' * 50)



inspect_tensor_structure(lables)

Tensor 1:
  Shape: torch.Size([125])
  Data type: torch.float64
  first shape: 0.0
--------------------------------------------------
Tensor 2:
  Shape: torch.Size([63])
  Data type: torch.float64
  first shape: 0.0
--------------------------------------------------
Tensor 3:
  Shape: torch.Size([2, 3])
  Data type: torch.float64
  first shape: tensor([0., 1., 4.], dtype=torch.float64)
--------------------------------------------------
Tensor 4:
  Shape: torch.Size([1, 2])
  Data type: torch.float64
  first shape: tensor([0., 2.], dtype=torch.float64)
--------------------------------------------------
Tensor 5:
  Shape: torch.Size([679])
  Data type: torch.int64
  first shape: 0
--------------------------------------------------
Tensor 6:
  Shape: torch.Size([679])
  Data type: torch.float64
  first shape: 1.5
--------------------------------------------------
Tensor 7:
  Shape: torch.Size([679])
  Data type: torch.int64
  first shape: 1
------------------------------------------------