In [2]:
import torch
if torch.cuda.is_available():
    device = torch.device("cuda")
    print('Using GPU (Windows)')
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print('Using GPU (Mac)')
else:
    device = torch.device("cpu")
    print('Using CPU')

from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence


from transformers import GPT2LMHeadModel, GPT2Config
from transformers import Trainer, TrainingArguments



import json
from pathlib import Path
import os
from typing import List

# Load miditok tokenizer
from miditok import REMI, TokenizerConfig, TokSequence
from miditoolkit import MidiFile, Instrument, Note
from miditok.pytorch_data import DatasetJSON

from symusic.midi import MidiFile





Using GPU (Windows)


ModuleNotFoundError: No module named 'symusic.midi'

In [35]:
import torch
print(torch.version.cuda)

11.8


In [36]:
tokenizer = REMI.from_pretrained("tokenizer.json")



In [37]:
class MIDITokenDataset(Dataset):
    def __init__(self, files_paths, bos_token_id=None, eos_token_id=None, max_seq_len=1024):
        self.paths = files_paths
        self.bos = bos_token_id
        self.eos = eos_token_id
        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Load list of ids
        with open(self.paths[idx], "r") as f:
            ids = json.load(f)

        # Wrap in TokSequence (optional — just to stay consistent)
        seq = TokSequence(ids=ids)

        # Add BOS and EOS if specified
        tokens = []
        if self.bos is not None:
            tokens.append(self.bos)
        tokens += seq.ids
        if self.eos is not None:
            tokens.append(self.eos)

        # Truncate or pad as needed
        tokens = tokens[:self.max_seq_len]

        return torch.tensor(tokens)

    
# ----- Collate function -----
def collate_fn(batch):
    input_ids = [item["input_ids"] for item in batch]
    labels = [item["labels"] for item in batch]

    pad_token_id = tokenizer["PAD_None"]
    input_ids_padded = nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
    labels_padded = nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)

    return {
        "input_ids": input_ids_padded,
        "labels": labels_padded,
    }
class MIDIDataCollator:
    def __init__(self, pad_token_id):
        self.pad_token_id = pad_token_id

    def __call__(self, batch):
        # batch: list of 1D tensors
        input_ids_padded = pad_sequence(batch, batch_first=True, padding_value=self.pad_token_id)
        labels_padded = input_ids_padded.clone()  # copy for labels

        return {
            "input_ids": input_ids_padded,
            "labels": labels_padded,
            "attention_mask": (input_ids_padded != self.pad_token_id).long(),
        }



data_collator = MIDIDataCollator(tokenizer['PAD_None'])


In [38]:
right_hand_jsons = list(Path("tokenized_json/right_hand").glob("*.json"))

dataset = MIDITokenDataset(
    files_paths=right_hand_jsons,
    bos_token_id = tokenizer["BOS"] if "BOS" in tokenizer.special_tokens else tokenizer.vocab["BOS_None"],
    eos_token_id = tokenizer["EOS"] if "EOS" in tokenizer.special_tokens else tokenizer.vocab["EOS_None"],
    max_seq_len=1024
)


In [39]:
config = GPT2Config(
    vocab_size=tokenizer.vocab_size,
    n_positions=1024,
    n_layer=6,
    n_head=8,
    n_embd=512,
)
model = GPT2LMHeadModel(config)


In [None]:
training_args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=4,
    num_train_epochs=1000,
    save_steps=500,
    logging_steps=100,
    warmup_steps=100,
    logging_dir="logs",
    report_to="none",  # or "tensorboard" if you use it
    save_total_limit=2,
    
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=data_collator,
    
)
#trainer.train()
#model.save_pretrained("model_weights")


Step,Training Loss
100,9.2419
200,8.4305
300,8.3117
400,8.1301
500,7.9597
600,7.8266
700,7.6475
800,7.5015
900,7.3993
1000,7.2009


TrainOutput(global_step=107000, training_loss=0.44771192796542264, metrics={'train_runtime': 40135.8571, 'train_samples_per_second': 10.589, 'train_steps_per_second': 2.666, 'total_flos': 4.938660232868659e+16, 'train_loss': 0.44771192796542264, 'epoch': 1000.0})

In [32]:
import torch
from miditoolkit import MidiFile
from miditok import TokSequence

# Prepare the mask outside the function (keep this as is)
valid_token_ids = list(tokenizer.vocab.values())  # get all integer token IDs
valid_token_ids_tensor = torch.tensor(valid_token_ids, device=device)

vocab_size = model.config.vocab_size  # 15000
mask = torch.full((vocab_size,), float('-inf'), device=device)
mask[valid_token_ids_tensor] = 0

def generate_unconditional_midi(
    model,
    tokenizer,
    output_path="generated.mid",
    max_len=512,
    device="cpu"
):
    model.eval()

    # Get BOS and EOS token ids correctly from tokenizer.vocab dict
    bos_token_id = tokenizer.vocab.get("BOS_None", tokenizer.vocab.get("BOS"))
    eos_token_id = tokenizer.vocab.get("EOS_None", tokenizer.vocab.get("EOS"))

    generated = [bos_token_id]
    input_ids = torch.tensor([generated], dtype=torch.long, device=device)

    with torch.no_grad():
        for _ in range(max_len):
            outputs = model(input_ids=input_ids)
            logits = outputs.logits[0, -1, :]  # (vocab_size,)

            # Apply mask to block invalid tokens
            logits = logits + mask

            #next_token_id = torch.argmax(logits).item()
            top_k = 50
            logits = logits + mask  # apply vocab mask
            top_logits, top_indices = torch.topk(logits, top_k)
            probs = torch.nn.functional.softmax(top_logits, dim=-1)
            next_token_id = top_indices[torch.multinomial(probs, num_samples=1)].item()

            if next_token_id == eos_token_id:
                break

            generated.append(next_token_id)
            # Update input_ids with newly generated tokens (avoid recreating tensor every time by slicing)
            input_ids = torch.tensor([generated], dtype=torch.long, device=device)

    print(f"Generated {len(generated)} tokens.")
    print("Tokens:", generated)
    score = tokenizer.decode(generated)  # symusic ScoreTick object
    print("Score:",score)
    score.dump_midi(output_path)         # save midi directly via symusic method
    print(f"Saved generated MIDI to {output_path}")



In [33]:
print(f"Tokenizer vocab size: {len(tokenizer)}")
print(f"Model vocab size: {model.config.vocab_size}")
print("Unique vocab size:", len(set(tokenizer.vocab)))

Tokenizer vocab size: 15000
Model vocab size: 15000
Unique vocab size: 866


In [34]:
# Load trained model if not in memory already
# model = GPT2LMHeadModel.from_pretrained("out")  # if needed
model = GPT2LMHeadModel.from_pretrained("model_weights").to(device)
tokenizer = REMI.from_pretrained("tokenizer.json")

generate_unconditional_midi(
    model=model,
    tokenizer=tokenizer,
    output_path="unconditional_generation.mid",
    max_len=512,
    device=device
)



Generated 355 tokens.
Tokens: [1, 1, 1, 1, 1, 1, 533, 533, 533, 1, 800, 800, 48, 48, 80, 1, 48, 48, 48, 48, 1, 1, 577, 565, 565, 565, 565, 565, 565, 565, 565, 565, 565, 565, 565, 565, 565, 565, 565, 1, 560, 560, 560, 560, 582, 582, 582, 564, 557, 557, 557, 557, 557, 557, 557, 793, 793, 793, 793, 793, 793, 793, 577, 560, 565, 565, 565, 565, 565, 565, 565, 565, 565, 565, 565, 565, 565, 565, 565, 565, 565, 565, 565, 565, 565, 565, 527, 527, 529, 529, 529, 529, 560, 560, 560, 560, 560, 560, 560, 560, 560, 560, 560, 560, 560, 560, 560, 560, 793, 563, 563, 563, 563, 563, 563, 563, 563, 563, 563, 563, 563, 563, 563, 563, 563, 563, 797, 567, 799, 570, 570, 570, 570, 570, 570, 570, 552, 552, 552, 552, 531, 531, 531, 531, 534, 534, 534, 534, 534, 534, 797, 797, 567, 805, 535, 535, 535, 535, 535, 535, 535, 535, 535, 535, 535, 535, 535, 535, 535, 535, 535, 535, 550, 550, 550, 536, 565, 565, 565, 565, 565, 565, 565, 565, 565, 565, 565, 797, 538, 805, 566, 577, 577, 577, 577, 577, 577, 577, 577, 577

In [35]:
ids = [
    tokenizer["BOS_None"],
    tokenizer["Bar_None"],
    tokenizer["Position_0"],
    tokenizer["Program_0"],
    tokenizer["Pitch_60"],
    tokenizer["Velocity_80"],
    tokenizer["Duration_4"],
    tokenizer["EOS_None"],
]
score = tokenizer.decode(ids)
print(score)
score.dump_midi("test_valid.mid")

KeyError: 'Velocity_80'