In [11]:
import os
from pathlib import Path
from copy import deepcopy

import torch
if torch.cuda.is_available():
    device = torch.device("cuda")
    print('Using GPU via CUDA:', torch.cuda.get_device_name(0))
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print('Using GPU via MPS (Apple Silicon)')
else:
    device = torch.device("cpu")
    print('Using CPU')

# Use device like this:
# model.to(device)


from mido import MidiFile
from symusic import Score

from miditok import REMI, TokenizerConfig, TokSequence
from miditok.pytorch_data import DatasetMIDI, DataCollator
from miditok.utils import split_files_for_training

import miditoolkit
from miditoolkit import MidiFile
import json

import pretty_midi

from torch.utils.data import DataLoader

import torch
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
from torch import nn
from torch.utils.data import Dataset, DataLoader
import json
from pathlib import Path
from typing import List

# Load miditok tokenizer
from miditok import REMI, TokenizerConfig, TokSequence
from miditoolkit import MidiFile, Instrument, Note






Using GPU via MPS (Apple Silicon)


In [12]:
tokenizer = REMI.from_pretrained("tokenizer.json")

# ----- Dataset -----
class PairedMIDIDataset(Dataset):
    def __init__(self, right_dir: Path, left_dir: Path, max_len=1024):
        self.right_files = sorted(right_dir.glob("*.json"))
        self.left_files = sorted(left_dir.glob("*.json"))
        self.max_len = max_len

    def __len__(self):
        return len(self.right_files)

    def __getitem__(self, idx):
        with open(self.right_files[idx]) as f:
            right = json.load(f)[:self.max_len]
        with open(self.left_files[idx]) as f:
            left = json.load(f)[:self.max_len]

        return torch.tensor(right), torch.tensor(left)

# ----- Collate function -----
def collate_fn(batch):
    right_batch, left_batch = zip(*batch)
    right_batch = [torch.tensor(seq, dtype=torch.long) for seq in right_batch]
    left_batch = [torch.tensor(seq, dtype=torch.long) for seq in left_batch]

    pad_token_id = tokenizer["PAD_None"]  # Use string-based access here
    right_padded = nn.utils.rnn.pad_sequence(right_batch, batch_first=True, padding_value=pad_token_id)
    left_padded = nn.utils.rnn.pad_sequence(left_batch, batch_first=True, padding_value=pad_token_id)

    return right_padded, left_padded



# ----- Dataloader -----
right_json_dir = Path("tokenized_json/right_hand")
left_json_dir = Path("tokenized_json/left_hand")

dataset = PairedMIDIDataset(right_json_dir, left_json_dir)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

# ----- Model: Simple Transformer -----
class MusicTransformer(nn.Module):
    def __init__(self, vocab_size, emb_dim=256, n_heads=4, n_layers=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=n_heads)
        decoder_layer = nn.TransformerDecoderLayer(d_model=emb_dim, nhead=n_heads)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_layers)
        self.fc_out = nn.Linear(emb_dim, vocab_size)

    def forward(self, src, tgt):
        src_mask = self.generate_square_subsequent_mask(src.size(1)).to(src.device)
        tgt_mask = self.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)

        src_emb = self.embedding(src)
        tgt_emb = self.embedding(tgt)
        memory = self.encoder(src_emb.transpose(0, 1), src_mask)
        out = self.decoder(tgt_emb.transpose(0, 1), memory, tgt_mask)
        logits = self.fc_out(out.transpose(0, 1))
        return logits

    def generate_square_subsequent_mask(self, sz):
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

In [None]:
from miditoolkit import MidiFile, Instrument, Note

def score_to_midi(score_tick):
    midi = MidiFile()

    # Correctly access instruments/tracks from ScoreTick
    try:
        for track in score_tick.tracks:  # ← this is the fix
            midi_instr = Instrument(
                program=track.program,
                is_drum=track.is_drum,
                name=track.name
            )
            for note in track.notes:
                midi_instr.notes.append(Note(
                    pitch=note.pitch,
                    start=note.start,
                    end=note.end,
                    velocity=note.velocity
                ))
            midi.instruments.append(midi_instr)
    except AttributeError as e:
        raise ValueError("Provided object does not contain valid MIDI track info") from e

    return midi
from miditoolkit import MidiFile, Instrument
import torch
import numpy as np

def generate_left_hand_and_save_midi(
    right_hand_tokens,
    model,
    tokenizer,
    output_path,
    max_len=1024,
    device=device,
    valid_token_ids=None,  # ← NEW: pass a list or set of allowed IDs
):
    
    model.eval()

    # Ensure right_hand_tokens is a batched tensor
    if isinstance(right_hand_tokens, list):
        input_ids = torch.tensor([right_hand_tokens], dtype=torch.long, device=device)
    elif isinstance(right_hand_tokens, torch.Tensor):
        if right_hand_tokens.ndim == 1:
            input_ids = right_hand_tokens.unsqueeze(0).to(device)
        else:
            input_ids = right_hand_tokens.to(device)
    else:
        raise ValueError("right_hand_tokens must be a list of ints or a torch.Tensor")

    bos_token_id = tokenizer.vocab.get("BOS_None", tokenizer.vocab.get("BOS", 0))
    eos_token_id = tokenizer.vocab.get("EOS_None", tokenizer.vocab.get("EOS", -1))

    decoder_input = torch.tensor([[bos_token_id]], dtype=torch.long, device=device)

    vocab_size = len(tokenizer)
    mask_tensor = torch.full((vocab_size,), float('-inf'), device=device)

    mask_tensor[valid_token_ids] = 0.0


    # Autoregressive generation
    with torch.no_grad():
        for _ in range(max_len):
            output = model(input_ids, decoder_input)  # (batch, seq, vocab)
            next_token_logits = output[:, -1, :]      # (batch, vocab)

            # Apply mask
            next_token_logits = next_token_logits + mask_tensor

            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            decoder_input = torch.cat([decoder_input, next_token], dim=1)

            if next_token.item() == eos_token_id:
                break

    left_hand_tokens = decoder_input.squeeze(0).tolist()
    left_hand_tokens[0:2] = input_ids.squeeze(0).tolist()[0:2]

    # Decode to Score objects
    print('right tokens', input_ids.squeeze(0).tolist())
    print('left tokens', left_hand_tokens)
    

    right_score = tokenizer.decode(input_ids.squeeze(0).tolist())
    left_score = tokenizer.decode(left_hand_tokens)
    
    print('right score:', right_score)
    print('left score:', left_score)

    # Convert to MIDI
    right_midi = score_to_midi(right_score)
    left_midi = score_to_midi(left_score)

    #print(right_midi.instruments[0])


    print('left midi', left_midi)
    # Create MIDI
    # Create new MIDI and combine tracks
    midi = MidiFile()

    # Append notes from decoded MIDI objects
    for track, program, name in zip([right_midi, left_midi], [0, 0], ["RH-1", "LH-1"]):
        inst = Instrument(program=program, is_drum=False, name=name)
        # Take notes from the first instrument in the decoded track
        decoded_inst = track.instruments[0]
        inst.notes.extend(decoded_inst.notes)
        midi.instruments.append(inst)

    midi.dump(str(output_path))

    print(f"Saved MIDI to {output_path}")
sample_batch = next(iter(dataloader))

right_hand_sample = sample_batch[0][0]  # First sample of the right-hand batch
left_hand_sample = sample_batch[1][0] 

vocab_size = len(tokenizer)
all_ids = []
for right, left in dataloader.dataset:
    all_ids.extend(right)
    all_ids.extend(left)

print(f"Max token ID in dataset: {max(all_ids)}")
print(f"Vocab size from tokenizer: {len(tokenizer)}")

valid_token_id_set = valid_ids = set(t.item() for t in all_ids)
valid_token_ids = torch.tensor(list(valid_token_id_set), dtype=torch.long, device=device)

print(f"Num valid token IDs: {len(valid_token_id_set)}")


model = MusicTransformer(vocab_size=vocab_size)
model.load_state_dict(torch.load(Path('best_model.pth'), weights_only=True))
model.eval()
model = model.to(device)
# Call the generation function
generate_left_hand_and_save_midi(
    right_hand_tokens=right_hand_sample,
    model=model,
    tokenizer=tokenizer,
    output_path="generated_ragtime.mid",
    device=device,
    valid_token_ids=valid_token_ids,
)



  right_batch = [torch.tensor(seq, dtype=torch.long) for seq in right_batch]
  left_batch = [torch.tensor(seq, dtype=torch.long) for seq in left_batch]


In [None]:
# Using miditoolkit
path = os.path.join('generated_ragtime.mid')
midi_obj = miditoolkit.midi.parser.MidiFile(path)
midi_obj

ticks per beat: 480
max tick: 4018
tempo changes: 1
time sig: 1
key sig: 0
markers: 0
lyrics: False
instruments: 2

In [None]:
midi_obj.instruments

[Instrument(program=0, is_drum=False, name=RH-1) - 477 notes,
 Instrument(program=0, is_drum=False, name=LH-1) - 649 notes]

In [None]:

tokenizer = REMI.from_pretrained("tokenizer.json")
# Invert the vocab dictionary: id → token
id_to_token = {v: k for k, v in tokenizer.vocab.items()}

# Load token IDs
import json
with open("tokenized_json/right_hand/acsrnade.json") as f:
    token_ids = json.load(f)

# Map first few token IDs to strings
token_strs = [id_to_token[token_id] for token_id in token_ids[:20]]

print(token_strs)



KeyError: 3173

In [None]:
id_to_token

{0: 'PAD_None',
 1: 'BOS_None',
 2: 'EOS_None',
 3: 'MASK_None',
 4: 'Bar_None',
 5: 'Pitch_21',
 6: 'Pitch_22',
 7: 'Pitch_23',
 8: 'Pitch_24',
 9: 'Pitch_25',
 10: 'Pitch_26',
 11: 'Pitch_27',
 12: 'Pitch_28',
 13: 'Pitch_29',
 14: 'Pitch_30',
 15: 'Pitch_31',
 16: 'Pitch_32',
 17: 'Pitch_33',
 18: 'Pitch_34',
 19: 'Pitch_35',
 20: 'Pitch_36',
 21: 'Pitch_37',
 22: 'Pitch_38',
 23: 'Pitch_39',
 24: 'Pitch_40',
 25: 'Pitch_41',
 26: 'Pitch_42',
 27: 'Pitch_43',
 28: 'Pitch_44',
 29: 'Pitch_45',
 30: 'Pitch_46',
 31: 'Pitch_47',
 32: 'Pitch_48',
 33: 'Pitch_49',
 34: 'Pitch_50',
 35: 'Pitch_51',
 36: 'Pitch_52',
 37: 'Pitch_53',
 38: 'Pitch_54',
 39: 'Pitch_55',
 40: 'Pitch_56',
 41: 'Pitch_57',
 42: 'Pitch_58',
 43: 'Pitch_59',
 44: 'Pitch_60',
 45: 'Pitch_61',
 46: 'Pitch_62',
 47: 'Pitch_63',
 48: 'Pitch_64',
 49: 'Pitch_65',
 50: 'Pitch_66',
 51: 'Pitch_67',
 52: 'Pitch_68',
 53: 'Pitch_69',
 54: 'Pitch_70',
 55: 'Pitch_71',
 56: 'Pitch_72',
 57: 'Pitch_73',
 58: 'Pitch_74',
 59: '

In [None]:
# For trained tokenizer with BPE/multi-vocab setup
id_to_token = {v: k for k, v in tokenizer._vocab_base.items()}  # try this first

# Fallback if that doesn't work
if not id_to_token:
    print('doesnt work')
    id_to_token = {v: k for k, v in tokenizer.vocab_bpe.items()}


In [None]:
len(id_to_token.keys())

594