In [1]:
import torch

def _select_device():
    """Return the preferred torch.device (CUDA, then Apple MPS, then CPU) and print which one was chosen."""
    if torch.cuda.is_available():
        chosen = torch.device("cuda")
        print('Using GPU via CUDA:', torch.cuda.get_device_name(0))
        return chosen
    if torch.backends.mps.is_available():
        chosen = torch.device("mps")
        print('Using GPU via MPS (Apple Silicon)')
        return chosen
    chosen = torch.device("cpu")
    print('Using CPU')
    return chosen


# Move models and tensors onto this device with `.to(device)`, e.g. `model.to(device)`.
device = _select_device()

from torch import nn
from torch.utils.data import Dataset, DataLoader
import json
from pathlib import Path
from typing import List

# Load miditok tokenizer
from miditok import REMI, TokenizerConfig, TokSequence
from miditoolkit import MidiFile, Instrument, Note



Using GPU via MPS (Apple Silicon)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# REMI tokenizer whose vocabulary presumably produced the token JSON files loaded below.
TOKENIZER_FILE = "tokenizer.json"
tokenizer = REMI.from_pretrained(TOKENIZER_FILE)

# ----- Dataset -----
class PairedMIDIDataset(Dataset):
    """Dataset of (right-hand, left-hand) token-id sequences stored as JSON files.

    The ``*.json`` files of both directories are sorted by path and paired by
    position: item ``i`` is the i-th right-hand file with the i-th left-hand file.
    Each file is expected to hold a flat JSON list of integer token ids
    (assumed from how items are consumed downstream — confirm against the tokenizing step).

    Args:
        right_dir: directory of right-hand token files (``Path`` or ``str``).
        left_dir: directory of left-hand token files (``Path`` or ``str``).
        max_len: each sequence is truncated to its first ``max_len`` tokens.

    Raises:
        ValueError: if the two directories hold different numbers of JSON files, which
            would otherwise mis-pair files or raise IndexError part-way through an epoch.
    """

    def __init__(self, right_dir: Path, left_dir: Path, max_len=1024):
        self.right_files = sorted(Path(right_dir).glob("*.json"))
        self.left_files = sorted(Path(left_dir).glob("*.json"))
        if len(self.right_files) != len(self.left_files):
            raise ValueError(
                f"Found {len(self.right_files)} right-hand files in {right_dir} but "
                f"{len(self.left_files)} left-hand files in {left_dir}; they must pair one-to-one"
            )
        self.max_len = max_len

    def __len__(self):
        return len(self.right_files)

    def _load_tokens(self, path):
        """Read one JSON token file and return its first ``max_len`` ids as a 1-D LongTensor."""
        with open(path, encoding="utf-8") as f:
            tokens = json.load(f)[:self.max_len]
        # Explicit dtype: an empty list would otherwise become a float32 tensor,
        # which nn.Embedding cannot use as indices.
        return torch.tensor(tokens, dtype=torch.long)

    def __getitem__(self, idx):
        return self._load_tokens(self.right_files[idx]), self._load_tokens(self.left_files[idx])

# ----- Collate function -----
def collate_fn(batch, pad_token_id=None):
    """Pad a batch of (right, left) token sequences into two batch-first LongTensors.

    Args:
        batch: sequence of ``(right_tokens, left_tokens)`` pairs; each element may be a
            1-D tensor or a list of ints.
        pad_token_id: id used to fill positions past each sequence's end. When ``None``
            (the default, as used by DataLoader), it is looked up as ``tokenizer["PAD_None"]``
            on the module-level ``tokenizer``.

    Returns:
        ``(right_padded, left_padded)``, each of shape ``(batch, longest_len_in_batch)``
        and dtype ``torch.long``.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer["PAD_None"]

    right_seqs, left_seqs = zip(*batch)

    def _pad(seqs):
        # as_tensor skips the redundant clone (pad_sequence copies into a new tensor anyway)
        # and forces long dtype, which nn.Embedding requires for indices.
        tensors = [torch.as_tensor(seq, dtype=torch.long) for seq in seqs]
        return nn.utils.rnn.pad_sequence(tensors, batch_first=True, padding_value=pad_token_id)

    return _pad(right_seqs), _pad(left_seqs)



# ----- Dataloader -----
# Directories of per-hand token JSON files; PairedMIDIDataset pairs them by sorted order.
right_json_dir = Path("tokenized_json/right_hand")
left_json_dir = Path("tokenized_json/left_hand")

# Seeded generator so the shuffle order (and therefore the `next(iter(dataloader))`
# sample taken for generation) is reproducible across kernel restarts.
DATALOADER_SEED = 0

dataset = PairedMIDIDataset(right_json_dir, left_json_dir)
dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn,
    generator=torch.Generator().manual_seed(DATALOADER_SEED),
)

# ----- Model: Simple Transformer -----
class MusicTransformer(nn.Module):
    """Encoder-decoder Transformer that maps right-hand token ids to left-hand token logits.

    Source (right hand) and target (left hand) share one ``nn.Embedding(vocab_size, emb_dim)``;
    ``fc_out`` projects decoder states back to ``vocab_size`` logits. Inputs are batch-first
    ``(batch, seq)`` id tensors and are transposed to the sequence-first layout expected by the
    default (``batch_first=False``) PyTorch Transformer layers.

    NOTE(review): no positional encoding is added to the embeddings, so token order reaches the
    model only implicitly through the causal masks. Adding one would require retraining.
    NOTE(review): the encoder is causally masked as well, so each right-hand position attends
    only to itself and earlier right-hand positions — confirm this is intended.
    """

    def __init__(self, vocab_size, emb_dim=256, n_heads=4, n_layers=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=n_heads)
        decoder_layer = nn.TransformerDecoderLayer(d_model=emb_dim, nhead=n_heads)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_layers)
        self.fc_out = nn.Linear(emb_dim, vocab_size)

    def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None):
        """Return logits of shape ``(batch, tgt_len, vocab_size)``.

        Args:
            src: ``(batch, src_len)`` LongTensor of right-hand ids.
            tgt: ``(batch, tgt_len)`` LongTensor of (shifted) left-hand ids.
            src_key_padding_mask: optional ``(batch, src_len)`` mask that is ``True`` (bool) or
                ``-inf`` (float) at padding positions; hides source padding from encoder
                self-attention and from decoder cross-attention.
            tgt_key_padding_mask: optional ``(batch, tgt_len)`` mask for target padding.

        With both masks left as ``None`` the computation is the same as the unmasked model.
        A row whose attendable positions are all masked may produce NaN outputs.
        """
        src_mask = self.generate_square_subsequent_mask(src.size(1), device=src.device)
        tgt_mask = self.generate_square_subsequent_mask(tgt.size(1), device=tgt.device)
        # Match the float causal masks' dtype to avoid PyTorch's mismatched-mask-type path.
        src_pad = self._additive_padding_mask(src_key_padding_mask, src_mask.dtype)
        tgt_pad = self._additive_padding_mask(tgt_key_padding_mask, tgt_mask.dtype)

        src_emb = self.embedding(src).transpose(0, 1)
        tgt_emb = self.embedding(tgt).transpose(0, 1)
        memory = self.encoder(src_emb, mask=src_mask, src_key_padding_mask=src_pad)
        out = self.decoder(
            tgt_emb,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_pad,
            memory_key_padding_mask=src_pad,
        )
        return self.fc_out(out.transpose(0, 1))

    def generate_square_subsequent_mask(self, sz, device=None):
        """Return an ``(sz, sz)`` float mask: 0 on and below the diagonal, ``-inf`` above it."""
        return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)

    @staticmethod
    def _additive_padding_mask(mask, dtype):
        """Convert a bool padding mask (True = pad) into an additive float mask; pass others through."""
        if mask is None or mask.is_floating_point():
            return mask
        return torch.zeros(mask.shape, dtype=dtype, device=mask.device).masked_fill(mask, float('-inf'))

  super().__init__(tokenizer_config, params)


In [3]:
# Gather every token id that occurs in either hand anywhere in the dataset.
# `.tolist()` converts each sequence to plain ints in one call, instead of keeping one
# 0-d tensor per token and calling `.item()` on every one of them afterwards.
all_ids = []
for right, left in dataloader.dataset:
    all_ids.extend(right.tolist())
    all_ids.extend(left.tolist())

if not all_ids:
    raise ValueError("No tokens found in the dataset; check the tokenized_json directories")

max_token_id = max(all_ids)
print(f"Max token ID in dataset: {max_token_id}")
print(f"Vocab size from tokenizer: {len(tokenizer)}")
# Ids must be valid indices for an embedding table sized len(tokenizer).
assert max_token_id < len(tokenizer), "Dataset contains token ids outside the tokenizer vocabulary"

valid_token_id_set = set(all_ids)
# Sorted so the tensor's contents are in a deterministic order (set order is not guaranteed).
valid_token_ids = torch.tensor(sorted(valid_token_id_set), dtype=torch.long, device=device)

print(f"Num valid token IDs: {len(valid_token_id_set)}")


Max token ID in dataset: 897
Vocab size from tokenizer: 898
Num valid token IDs: 249


In [4]:
# Train with teacher forcing: the decoder receives left-hand tokens [0, n-1) and is scored
# on predicting tokens [1, n). Reuses `device` from the setup cell, so a CUDA machine is not
# forced onto the CPU.
NUM_EPOCHS = 1000
LEARNING_RATE = 1e-4
BEST_MODEL_PATH = "best_model.pth"
FINAL_MODEL_PATH = "music_transformer_weights.pth"

vocab_size = len(tokenizer)
model = MusicTransformer(vocab_size=vocab_size).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Padded target positions must not contribute to the loss. -100 (CrossEntropyLoss's default
# ignore_index) is the fallback if the tokenizer has no PAD token.
pad_token = tokenizer.vocab.get("PAD_None", -100)
criterion = nn.CrossEntropyLoss(ignore_index=pad_token)

if len(dataloader) == 0:
    raise ValueError("dataloader is empty; nothing to train on")

best_loss = float("inf")  # best average training loss seen so far

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0

    for src, tgt in dataloader:
        src, tgt = src.to(device), tgt.to(device)

        optimizer.zero_grad()
        output = model(src, tgt[:, :-1])  # teacher forcing: decoder input is the target minus its last token
        loss = criterion(output.reshape(-1, output.size(-1)), tgt[:, 1:].reshape(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch + 1} - Loss: {avg_loss:.4f}")

    # Checkpoint whenever the epoch's average training loss improves by any amount.
    if avg_loss < best_loss:
        print(f"Loss improved from {best_loss:.4f} to {avg_loss:.4f}. Saving model...")
        best_loss = avg_loss
        torch.save(model.state_dict(), BEST_MODEL_PATH)

torch.save(model.state_dict(), FINAL_MODEL_PATH)





Epoch 1 - Loss: 4.1180
Loss improved from inf to 4.1180. Saving model...
Epoch 2 - Loss: 2.5016
Loss improved from 4.1180 to 2.5016. Saving model...
Epoch 3 - Loss: 2.0830
Loss improved from 2.5016 to 2.0830. Saving model...
Epoch 4 - Loss: 1.9177
Loss improved from 2.0830 to 1.9177. Saving model...
Epoch 5 - Loss: 1.8210
Loss improved from 1.9177 to 1.8210. Saving model...
Epoch 6 - Loss: 1.7517
Loss improved from 1.8210 to 1.7517. Saving model...
Epoch 7 - Loss: 1.6809
Loss improved from 1.7517 to 1.6809. Saving model...
Epoch 8 - Loss: 1.6677
Loss improved from 1.6809 to 1.6677. Saving model...
Epoch 9 - Loss: 1.6440
Loss improved from 1.6677 to 1.6440. Saving model...
Epoch 10 - Loss: 1.6152
Loss improved from 1.6440 to 1.6152. Saving model...
Epoch 11 - Loss: 1.5966
Loss improved from 1.6152 to 1.5966. Saving model...
Epoch 12 - Loss: 1.5770
Loss improved from 1.5966 to 1.5770. Saving model...
Epoch 13 - Loss: 1.5686
Loss improved from 1.5770 to 1.5686. Saving model...
Epoch 14 - 

In [5]:
# Persist the current weights, e.g. after stopping the long training loop early.
final_state = model.state_dict()
torch.save(final_state, "music_transformer_weights.pth")


In [9]:
from miditoolkit import MidiFile, Instrument, Note

def score_to_midi(score_tick):
    """Convert a decoded score (as returned by ``tokenizer.decode``) into a miditoolkit ``MidiFile``.

    Each track of ``score_tick`` becomes one ``Instrument`` with the same program, drum flag,
    name and notes (pitch, start, end, velocity). When the score exposes a
    ``ticks_per_quarter`` resolution it is copied to ``ticks_per_beat``, so note tick times
    keep their meaning; otherwise miditoolkit's default resolution is kept. Tempos, time
    signatures and other events are not copied.

    Raises:
        ValueError: if ``score_tick`` or one of its tracks/notes lacks the expected attributes.
    """
    midi = MidiFile()

    # NOTE(review): `ticks_per_quarter` is assumed to be the resolution attribute of the
    # decoded score object — confirm for the installed miditok/symusic version.
    resolution = getattr(score_tick, "ticks_per_quarter", None)
    if resolution:
        midi.ticks_per_beat = int(resolution)

    try:
        for track in score_tick.tracks:
            midi_instr = Instrument(program=track.program, is_drum=track.is_drum, name=track.name)
            midi_instr.notes.extend(
                Note(pitch=n.pitch, start=n.start, end=n.end, velocity=n.velocity)
                for n in track.notes
            )
            midi.instruments.append(midi_instr)
    except AttributeError as e:
        raise ValueError("Provided object does not contain valid MIDI track info") from e

    return midi


In [10]:
from miditoolkit import MidiFile, Instrument
import torch
import numpy as np

def _as_batched_ids(tokens, device):
    """Return ``tokens`` as a ``(batch, seq)`` tensor on ``device``; a list or 1-D tensor gets batch size 1."""
    if isinstance(tokens, list):
        return torch.tensor([tokens], dtype=torch.long, device=device)
    if isinstance(tokens, torch.Tensor):
        return (tokens.unsqueeze(0) if tokens.ndim == 1 else tokens).to(device)
    raise ValueError("right_hand_tokens must be a list of ints or a torch.Tensor")


def _allowed_token_mask(vocab_size, valid_token_ids, device):
    """Additive logit mask: 0 for allowed ids and -inf elsewhere; all zeros when ``valid_token_ids`` is None."""
    if valid_token_ids is None:
        return torch.zeros(vocab_size, device=device)
    if not isinstance(valid_token_ids, torch.Tensor):
        # A set cannot index a tensor directly; sort so the index order is deterministic.
        valid_token_ids = torch.tensor(sorted(valid_token_ids), dtype=torch.long)
    if valid_token_ids.numel() == 0:
        raise ValueError("valid_token_ids must contain at least one token id")
    mask = torch.full((vocab_size,), float('-inf'), device=device)
    mask[valid_token_ids.to(device)] = 0.0
    return mask


def _sample_top_k(logits, top_k):
    """Sample one id per row among the ``top_k`` largest logits; returns a ``(batch, 1)`` LongTensor."""
    k = min(top_k, logits.size(-1))  # torch.topk raises if k exceeds the vocabulary size
    top_logits, top_indices = torch.topk(logits, k=k, dim=-1)
    probs = torch.softmax(top_logits, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1)
    return top_indices.gather(1, sampled)


def generate_left_hand_and_save_midi(
    right_hand_tokens,
    model,
    tokenizer,
    output_path,
    max_len=1024,
    device="cpu",
    valid_token_ids=None,
    top_k=50,
    ticks_per_beat=32,
    verbose=True,
):
    """Autoregressively generate left-hand tokens for a right-hand sequence and save both hands as MIDI.

    Args:
        right_hand_tokens: right-hand token ids as a list of ints or a tensor. Only a single
            sequence is supported: the decoder is seeded with one BOS row and EOS is checked
            with ``.item()``.
        model: encoder-decoder called as ``model(src_ids, tgt_ids)`` returning
            ``(batch, tgt_len, vocab)`` logits.
        tokenizer: miditok tokenizer used for the BOS/EOS lookup and to decode tokens.
        output_path: destination ``.mid`` path.
        max_len: maximum number of tokens to generate.
        device: device for the input ids and the logit mask.
        valid_token_ids: optional list, set or tensor of ids the sampler may emit; every other
            id gets ``-inf`` logits. ``None`` allows the whole vocabulary.
        top_k: sample only among the ``top_k`` highest-scoring ids (clamped to the vocab size).
        ticks_per_beat: resolution written into the combined output MIDI file.
        verbose: print the intermediate token lists, scores and left-hand MIDI.

    Raises:
        ValueError: for an unsupported ``right_hand_tokens`` type or an empty ``valid_token_ids``.
    """
    model.eval()
    input_ids = _as_batched_ids(right_hand_tokens, device)

    bos_token_id = tokenizer.vocab.get("BOS_None", tokenizer.vocab.get("BOS", 0))
    eos_token_id = tokenizer.vocab.get("EOS_None", tokenizer.vocab.get("EOS", -1))

    logit_mask = _allowed_token_mask(len(tokenizer), valid_token_ids, device)
    decoder_input = torch.tensor([[bos_token_id]], dtype=torch.long, device=device)

    # Autoregressive generation: append one sampled token per step until EOS or max_len.
    with torch.no_grad():
        for _ in range(max_len):
            next_token_logits = model(input_ids, decoder_input)[:, -1, :] + logit_mask
            next_token = _sample_top_k(next_token_logits, top_k)
            decoder_input = torch.cat([decoder_input, next_token], dim=1)
            if next_token.item() == eos_token_id:
                break

    right_tokens = input_ids.squeeze(0).tolist()
    left_hand_tokens = decoder_input.squeeze(0).tolist()
    # Overwrite BOS and the first generated token with the first two right-hand tokens.
    # NOTE(review): presumably copies leading header tokens so both hands decode with the
    # same context — confirm this is intended.
    left_hand_tokens[0:2] = right_tokens[0:2]

    if verbose:
        print('right tokens', right_tokens)
        print('left tokens', left_hand_tokens)

    right_score = tokenizer.decode(right_tokens)
    left_score = tokenizer.decode(left_hand_tokens)
    if verbose:
        print('right score:', right_score)
        print('left score:', left_score)

    right_midi = score_to_midi(right_score)
    left_midi = score_to_midi(left_score)
    if verbose:
        print('left midi', left_midi)

    # Combine the first decoded track of each hand into one two-track MIDI file.
    midi = MidiFile()
    for decoded, name in ((right_midi, "RH-1"), (left_midi, "LH-1")):
        inst = Instrument(program=0, is_drum=False, name=name)
        if decoded.instruments:  # a decode with no tracks yields an empty track instead of IndexError
            inst.notes.extend(decoded.instruments[0].notes)
        midi.instruments.append(inst)
    # NOTE(review): this fixed resolution replaces whatever resolution the decoded scores
    # used — confirm it matches the tokenizer's time resolution.
    midi.ticks_per_beat = ticks_per_beat

    midi.dump(str(output_path))

    print(f"Saved MIDI to {output_path}")


In [11]:
# Use the first right-hand sequence of one batch as the generation prompt.
sample_batch = next(iter(dataloader))

# collate_fn right-pads every sequence to the longest one in its batch; drop those PAD
# tokens so neither the encoder nor the decoded right-hand MIDI sees them.
pad_token_id = tokenizer["PAD_None"]
right_hand_sample = sample_batch[0][0]  # first sample of the right-hand batch
right_hand_sample = right_hand_sample[right_hand_sample != pad_token_id]
left_hand_sample = sample_batch[1][0]  # its paired left-hand reference
left_hand_sample = left_hand_sample[left_hand_sample != pad_token_id]

vocab_size = len(tokenizer)

model = MusicTransformer(vocab_size=vocab_size)
# Load onto the CPU first so a checkpoint written on CUDA/MPS also loads on other machines,
# then move the model to the current device.
state_dict = torch.load(Path('best_model.pth'), map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()

generate_left_hand_and_save_midi(
    right_hand_tokens=right_hand_sample,
    model=model,
    tokenizer=tokenizer,
    output_path="generated_ragtime.mid",
    device=device,
    valid_token_ids=valid_token_ids,
)



right tokens [4, 897, 557, 868, 888, 36, 149, 157, 562, 888, 37, 149, 157, 824, 568, 888, 38, 149, 157, 573, 888, 39, 149, 158, 824, 584, 888, 38, 149, 157, 589, 888, 39, 149, 158, 824, 599, 888, 48, 149, 157, 824, 605, 888, 49, 149, 157, 610, 888, 50, 149, 157, 615, 888, 51, 149, 158, 824, 561, 888, 50, 149, 157, 824, 567, 888, 51, 149, 158, 824, 577, 888, 48, 149, 158, 888, 54, 149, 158, 888, 56, 149, 158, 824, 587, 888, 48, 149, 157, 888, 54, 149, 157, 888, 57, 149, 157, 592, 888, 48, 149, 158, 888, 54, 149, 158, 888, 58, 149, 158, 824, 602, 888, 48, 149, 157, 888, 54, 149, 157, 888, 57, 149, 157, 607, 888, 48, 149, 158, 888, 54, 149, 158, 888, 58, 149, 158, 824, 618, 888, 48, 149, 158, 888, 54, 149, 158, 888, 56, 149, 158, 829, 574, 888, 56, 156, 157, 888, 60, 156, 157, 888, 68, 156, 157, 579, 888, 56, 149, 157, 888, 60, 149, 157, 888, 68, 149, 157, 583, 888, 58, 149, 157, 888, 61, 149, 157, 888, 70, 149, 157, 588, 888, 60, 149, 157, 888, 63, 149, 157, 888, 72, 149, 157, 593, 888, 