In [1]:
#import required packages
import random
random.seed(42)
from glob import glob
from collections import defaultdict

import numpy as np
from numpy.random import choice

from symusic import Score
from miditok import REMI, TokenizerConfig
from midiutil import MIDIFile

from datasets import load_dataset

In [2]:
import pretty_midi
from IPython.display import Audio

In [3]:
# get our pokemon diamond midi files
midi_files = glob('./midi_dataset/*.mid')
#midi_files = glob('./drive/MyDrive/midi-classical-music/data/*.mid')
len(midi_files)

200

In [5]:
midi_data = pretty_midi.PrettyMIDI(midi_files[0])
audio = midi_data.synthesize(fs=22050)
Audio(audio, rate=22050)

In [6]:
from music21 import converter, stream, meter
def split_score(score, bars_per_split=16):
    # Get initial time signature
    ts = score.recurse().getElementsByClass(meter.TimeSignature)
    beats_per_bar, beat_type = (4, 4)
    if ts:
        beats_per_bar = ts[0].numerator
        beat_type = ts[0].denominator

    bar_duration = beats_per_bar * (4 / beat_type)
    total_duration = score.highestTime
    split_duration = bars_per_split * bar_duration
    num_splits = int(total_duration // split_duration) + 1

    flattened_parts = [part.flat.notesAndRests for part in score.parts]

    # Pre-filtered metadata
    key_sigs = [el for el in score.flat.getElementsByClass('KeySignature')]
    time_sigs = [el for el in score.flat.getElementsByClass(meter.TimeSignature)]

    chunks = []

    for i in range(num_splits):
        start_time = i * split_duration
        end_time = start_time + split_duration
        chunk_score = stream.Score()

        # Add relevant global metadata
        for ks in key_sigs:
            if ks.offset <= start_time:
                chunk_score.insert(0, ks)
        for ts in time_sigs:
            if ts.offset <= start_time:
                chunk_score.insert(0, ts)

        # Process each part
        for part_stream in flattened_parts:
            new_part = stream.Part()
            # Insert elements directly without calling .clone()
            for el in part_stream.getElementsByOffset(start_time, end_time):
                new_part.insert(el.offset - start_time, el) # Changed line
            chunk_score.append(new_part)

        has_notes = False
        for el in chunk_score.flat.notesAndRests:
            if el.offset <= end_time:
                has_notes = True
                break

        # filter out anything less than 2 bars
        if has_notes and (end_time - start_time) >= 2 * bar_duration:
            chunks.append(chunk_score)

    return chunks

In [86]:
from tqdm import tqdm
for file_path in tqdm(midi_files, desc="Processing MIDI files"):
  score = converter.parse(file_path)
  chunks = split_score(score, bars_per_split=8) # for pokemon music, splitting to 8 bars because of its loopiness
  for i, chunk in enumerate(chunks):
    filename = file_path.split("/")[-1].split(".")[0].split("\\")[1]
    output_path = "./midi_chunks/" + filename + "_part_" + str(i + 1) + ".mid"
    chunk.write("midi", fp=output_path)

Processing MIDI files:   0%|          | 0/100 [00:01<?, ?it/s]


StreamException: cannot process repeats on Stream that does not contain measures

In [12]:
from miditok import REMI, TokenizerConfig
from pathlib import Path
import json


# Use REMIPlus tokenizer with appropriate config
config = TokenizerConfig()
config.use_programs = True
config.one_token_stream_for_programs = True
config.use_time_signatures = True

tokenizer = REMI(config)

  super().__init__(tokenizer_config, params)


In [38]:
midi_dir = Path("./midi_chunks")
token_dir = Path("./remi_tokens")
token_dir.mkdir(exist_ok=True)

# Tokenize all MIDI files
for midi_path in midi_dir.glob("*.mid"):
    tokens = tokenizer(midi_path)  # REMIPlus will return a single TokenSequence

    print(f"{midi_path.name}: {len(tokens.ids)} tokens")

    # Save tokens to JSON
    with open(token_dir / (midi_path.stem + ".json"), "w") as f:
        json.dump({"ids": tokens.ids}, f)

  super().__init__(tokenizer_config, params)


A New Adventure!_part_1.mid: 1721 tokens
A New Adventure!_part_2.mid: 2737 tokens
A New Adventure!_part_3.mid: 1312 tokens
A New Adventure!_part_4.mid: 598 tokens
Amity Square_part_1.mid: 1641 tokens
Amity Square_part_2.mid: 2050 tokens
Amity Square_part_3.mid: 2938 tokens
Amity Square_part_4.mid: 1491 tokens
Battle Arcade_part_1.mid: 2491 tokens
Battle Arcade_part_2.mid: 2883 tokens
Battle Arcade_part_3.mid: 3123 tokens
Battle Arcade_part_4.mid: 3400 tokens
Battle Castle_part_1.mid: 1312 tokens
Battle Castle_part_2.mid: 2492 tokens
Battle Castle_part_3.mid: 2333 tokens
Battle Castle_part_4.mid: 2123 tokens
Battle Castle_part_5.mid: 70 tokens
Battle Factory_part_1.mid: 1741 tokens
Battle Factory_part_2.mid: 2344 tokens
Battle Factory_part_3.mid: 2049 tokens
Battle Factory_part_4.mid: 1925 tokens
Battle Factory_part_5.mid: 46 tokens
Battle Hall_part_1.mid: 1577 tokens
Battle Hall_part_2.mid: 2161 tokens
Battle Hall_part_3.mid: 1960 tokens
Battle Hall_part_4.mid: 2480 tokens
Battle Hall_

In [7]:
import torch

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [8]:
import torch.nn as nn
class MusicTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, max_len=2048):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = nn.Parameter(torch.zeros(1, max_len, d_model))  # learnable pos encoding
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        x = self.embedding(x) + self.pos_encoding[:, :x.size(1), :]
        x = self.transformer(x)
        return self.fc_out(x)

In [9]:
from torch.utils.data import Dataset, DataLoader
from pathlib import Path

class MIDITokenDataset(Dataset):
    def __init__(self, token_dir, max_seq_len=2048):
        self.token_paths = list(Path(token_dir).glob("*.json"))
        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self.token_paths)

    def __getitem__(self, idx):
        with open(self.token_paths[idx], "r") as f:
            tokens = json.load(f)["ids"]
        tokens = tokens[:self.max_seq_len]
        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        target_ids = torch.tensor(tokens[1:], dtype=torch.long)
        return input_ids, target_ids

In [10]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    input_ids, target_ids = zip(*batch)
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    target_ids = pad_sequence(target_ids, batch_first=True, padding_value=0)
    return input_ids, target_ids


In [25]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import math

# Model, optimizer, criterion
vocab_size = len(tokenizer.vocab)
model = MusicTransformer(vocab_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# Dataset
dataset = MIDITokenDataset("./remi_tokens", max_seq_len=2048)
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)

def evaluate(model, dataloader, criterion):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for input_ids, target_ids in dataloader:
            input_ids = input_ids.to(device)
            target_ids = target_ids.to(device)

            output = model(input_ids)
            loss = criterion(output.view(-1, vocab_size), target_ids.view(-1))
            total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    perplexity = math.exp(avg_loss)
    return avg_loss, perplexity

# Training loop
for epoch in range(10):
    model.train()
    total_loss = 0

    for input_ids, target_ids in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)

        optimizer.zero_grad()
        output = model(input_ids)
        loss = criterion(output.view(-1, vocab_size), target_ids.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)
    train_perplexity = math.exp(avg_train_loss)

    val_loss, val_perplexity = evaluate(model, val_loader, criterion)

    print(f"Epoch {epoch+1}")
    print(f"  Train Loss: {avg_train_loss:.4f}, Perplexity: {train_perplexity:.2f}")
    print(f"  Val   Loss: {val_loss:.4f}, Perplexity: {val_perplexity:.2f}")


Epoch 1: 100%|██████████| 58/58 [00:23<00:00,  2.45it/s]


Epoch 1
  Train Loss: 2.4554, Perplexity: 11.65
  Val   Loss: 1.6340, Perplexity: 5.12


Epoch 2: 100%|██████████| 58/58 [00:23<00:00,  2.45it/s]


Epoch 2
  Train Loss: 1.6345, Perplexity: 5.13
  Val   Loss: 1.4798, Perplexity: 4.39


Epoch 3: 100%|██████████| 58/58 [00:23<00:00,  2.44it/s]


Epoch 3
  Train Loss: 1.5538, Perplexity: 4.73
  Val   Loss: 1.4550, Perplexity: 4.28


Epoch 4: 100%|██████████| 58/58 [00:23<00:00,  2.42it/s]


Epoch 4
  Train Loss: 1.5321, Perplexity: 4.63
  Val   Loss: 1.4547, Perplexity: 4.28


Epoch 5: 100%|██████████| 58/58 [00:23<00:00,  2.43it/s]


Epoch 5
  Train Loss: 1.5283, Perplexity: 4.61
  Val   Loss: 1.4514, Perplexity: 4.27


Epoch 6: 100%|██████████| 58/58 [00:23<00:00,  2.43it/s]


Epoch 6
  Train Loss: 1.5215, Perplexity: 4.58
  Val   Loss: 1.4492, Perplexity: 4.26


Epoch 7: 100%|██████████| 58/58 [00:23<00:00,  2.42it/s]


Epoch 7
  Train Loss: 1.5045, Perplexity: 4.50
  Val   Loss: 1.4310, Perplexity: 4.18


Epoch 8: 100%|██████████| 58/58 [00:24<00:00,  2.41it/s]


Epoch 8
  Train Loss: 1.5050, Perplexity: 4.50
  Val   Loss: 1.4369, Perplexity: 4.21


Epoch 9: 100%|██████████| 58/58 [00:29<00:00,  1.96it/s]


Epoch 9
  Train Loss: 1.4940, Perplexity: 4.46
  Val   Loss: 1.4278, Perplexity: 4.17


Epoch 10: 100%|██████████| 58/58 [00:25<00:00,  2.31it/s]


Epoch 10
  Train Loss: 1.4949, Perplexity: 4.46
  Val   Loss: 1.4210, Perplexity: 4.14


In [28]:
from miditok import TokSequence
from collections import Counter
import torch.nn.functional as F

model.eval()
vocab = tokenizer.vocab
inverse_vocab = {v: k for k, v in vocab.items()}

max_context = 2048  # context window of the model
generate_length = 10000  # total tokens to generate

# Seed tokens
seed_tokens = [vocab['Bar_None']]
generated = seed_tokens.copy()  # this stores the full sequence

with torch.no_grad():
    for _ in range(generate_length - len(seed_tokens)):
        # Feed only the last `max_context` tokens
        input_seq = torch.tensor(
            [generated[-max_context:]], dtype=torch.long
        ).to(device)

        output = model(input_seq)
        logits = output[:, -1, :]  # take the logits for the last token

        # Sample next token using top-k sampling
        temperature = 1.0
        top_k = 50
        logits = logits / temperature
        topk_logits, topk_indices = torch.topk(logits, top_k)
        probs = F.softmax(topk_logits, dim=-1)
        next_token = topk_indices.gather(-1, torch.multinomial(probs, 1)).item()

        # Append the sampled token to the full sequence
        generated.append(next_token)

# Convert to TokSequence and MIDI
sequence = TokSequence(ids=generated)
token_strings = [inverse_vocab[id] for id in sequence]

print(f"Generated {len(sequence)} tokens")
print("Sample tokens as strings:", token_strings[:50])
print("Top 20 most common tokens:", Counter(token_strings).most_common(20))

midi = tokenizer.decode(sequence)
midi.dump_midi("task1.mid")

Generated 10000 tokens
Sample tokens as strings: ['Bar_None', 'TimeSig_4/4', 'Position_0', 'Program_0', 'Pitch_37', 'Velocity_111', 'Duration_0.4.8', 'Position_8', 'Program_0', 'Pitch_68', 'Velocity_127', 'Duration_0.4.8', 'Program_0', 'Pitch_42', 'Velocity_75', 'Duration_0.3.8', 'Program_0', 'Pitch_73', 'Velocity_111', 'Duration_1.3.8', 'Program_0', 'Pitch_39', 'Velocity_55', 'Duration_0.3.8', 'Program_0', 'Pitch_41', 'Velocity_111', 'Duration_0.4.8', 'Program_0', 'Pitch_42', 'Velocity_71', 'Duration_0.4.8', 'Program_0', 'Pitch_83', 'Velocity_39', 'Duration_0.2.8', 'Program_0', 'Pitch_70', 'Velocity_111', 'Duration_0.5.8', 'Program_0', 'Pitch_48', 'Velocity_111', 'Duration_0.2.8', 'Bar_None', 'TimeSig_4/4', 'Position_0', 'Program_0', 'Pitch_72', 'Velocity_111']
Top 20 most common tokens: [('PAD_None', 3785), ('Program_0', 1476), ('Velocity_111', 633), ('Duration_0.4.8', 533), ('Duration_0.2.8', 429), ('Velocity_127', 306), ('Duration_0.3.8', 239), ('Velocity_99', 113), ('Duration_1.0.

In [29]:
midi_data = pretty_midi.PrettyMIDI("task1.mid")
audio = midi_data.synthesize(fs=22050)
Audio(audio, rate=22050)

In [21]:
test_data = pretty_midi.PrettyMIDI("midi_chunks/A New Adventure!_part_1.mid")
audio = test_data.synthesize(fs=22050)
Audio(audio, rate=22050)

In [32]:
from miditok import TokSequence
from collections import Counter
import torch.nn.functional as F

# Load and tokenize the existing MIDI file
midi_path = "midi_chunks/A New Adventure!_part_1.mid"
input_sequence = tokenizer(midi_path)
seed_ids = input_sequence.ids

# Use last 2048 tokens or fewer if shorter
max_context = 2048
seed_tokens = seed_ids[-max_context:]

# Store full generated sequence
generated = seed_tokens.copy()
generate_length = 10000

model.eval()
vocab = tokenizer.vocab
inverse_vocab = {v: k for k, v in vocab.items()}

with torch.no_grad():
    for _ in range(generate_length - len(seed_tokens)):
        input_seq = torch.tensor(
            [generated[-max_context:]], dtype=torch.long
        ).to(device)

        output = model(input_seq)
        logits = output[:, -1, :]

        # Sample next token
        temperature = 1.0
        top_k = 50
        logits = logits / temperature
        topk_logits, topk_indices = torch.topk(logits, top_k)
        probs = F.softmax(topk_logits, dim=-1)
        next_token = topk_indices.gather(-1, torch.multinomial(probs, 1)).item()

        generated.append(next_token)

# Convert to TokSequence and MIDI
sequence = TokSequence(ids=generated)
token_strings = [inverse_vocab[id] for id in sequence]

print(f"Generated {len(sequence)} tokens")
print("Sample tokens as strings:", token_strings[:50])
print("Top 20 most common tokens:", Counter(token_strings).most_common(20))

midi = tokenizer.decode(sequence)
midi.dump_midi("task2.mid")

Generated 10000 tokens
Sample tokens as strings: ['Bar_None', 'TimeSig_4/4', 'Position_0', 'Program_48', 'Pitch_45', 'Velocity_111', 'Duration_0.3.8', 'Program_48', 'Pitch_64', 'Velocity_111', 'Duration_0.3.8', 'Position_3', 'Program_48', 'Pitch_52', 'Velocity_111', 'Duration_0.3.8', 'Program_48', 'Pitch_69', 'Velocity_111', 'Duration_0.3.8', 'Position_5', 'Program_48', 'Pitch_57', 'Velocity_111', 'Duration_0.3.8', 'Program_48', 'Pitch_71', 'Velocity_111', 'Duration_0.3.8', 'Position_8', 'Program_48', 'Pitch_59', 'Velocity_111', 'Duration_0.3.8', 'Program_48', 'Pitch_64', 'Velocity_111', 'Duration_0.3.8', 'Position_11', 'Program_48', 'Pitch_57', 'Velocity_111', 'Duration_0.3.8', 'Program_48', 'Pitch_69', 'Velocity_111', 'Duration_0.3.8', 'Position_13', 'Program_48', 'Pitch_59']
Top 20 most common tokens: [('PAD_None', 5496), ('Program_0', 670), ('Velocity_111', 648), ('Duration_0.3.8', 382), ('Program_48', 271), ('Duration_0.4.8', 259), ('Duration_0.2.8', 249), ('Velocity_127', 143), (

In [34]:
test_data = pretty_midi.PrettyMIDI("task2.mid")
audio = test_data.synthesize(fs=22050)
Audio(audio, rate=22050)

In [22]:
with open("remi_tokens/A New Adventure!_part_2.json", "r") as f:
    tokens = json.load(f)["ids"]

    token_strings = [inverse_vocab[id] for id in tokens]
    print(token_strings[:50])
    token_counter = Counter(token_strings)
    for t in token_counter:
        if t.startswith("Bar") or t.startswith("Position"):
            print(f"{t}: {token_counter[t]}")
midi = tokenizer.decode(TokSequence(ids=tokens))
midi.dump_midi("task2.mid")
midi_data = pretty_midi.PrettyMIDI("task2.mid")
audio = midi_data.synthesize(fs=22050)
Audio(audio, rate=22050)

['Bar_None', 'TimeSig_4/4', 'Position_0', 'Program_0', 'Pitch_54', 'Velocity_111', 'Duration_0.3.8', 'Program_0', 'Pitch_57', 'Velocity_111', 'Duration_0.3.8', 'Program_0', 'Pitch_57', 'Velocity_111', 'Duration_0.3.8', 'Program_0', 'Pitch_62', 'Velocity_111', 'Duration_0.3.8', 'Program_0', 'Pitch_62', 'Velocity_111', 'Duration_0.7.8', 'Program_0', 'Pitch_45', 'Velocity_127', 'Duration_1.0.8', 'Program_0', 'Pitch_74', 'Velocity_111', 'Duration_1.3.8', 'Program_0', 'Pitch_74', 'Velocity_111', 'Duration_1.4.8', 'Position_3', 'Program_0', 'Pitch_54', 'Velocity_111', 'Duration_0.3.8', 'Program_0', 'Pitch_57', 'Velocity_111', 'Duration_0.3.8', 'Program_0', 'Pitch_59', 'Velocity_111', 'Duration_0.3.8', 'Program_0', 'Pitch_66']
Bar_None: 9
Position_0: 9
Position_3: 8
Position_5: 8
Position_6: 8
Position_8: 8
Position_10: 3
Position_11: 8
Position_12: 3
Position_13: 8
Position_16: 8
Position_19: 8
Position_21: 8
Position_22: 7
Position_24: 8
Position_26: 8
Position_27: 8
Position_28: 8
Position

In [104]:
testtoks = tokenizer("./midi_dataset/A New Adventure!.mid")
print(testtoks)
token_counter = Counter(testtoks.tokens)
for t in token_counter:
    if t.startswith("Bar") or t.startswith("Position"):
        print(f"{t}: {token_counter[t]}")

TokSequence(tokens=['Bar_None', 'TimeSig_4/4', 'Position_0', 'Program_48', 'Pitch_45', 'Velocity_111', 'Duration_0.3.8', 'Program_48', 'Pitch_64', 'Velocity_111', 'Duration_0.3.8', 'Position_3', 'Program_48', 'Pitch_52', 'Velocity_111', 'Duration_0.3.8', 'Program_48', 'Pitch_69', 'Velocity_111', 'Duration_0.3.8', 'Position_5', 'Program_48', 'Pitch_57', 'Velocity_111', 'Duration_0.3.8', 'Program_48', 'Pitch_71', 'Velocity_111', 'Duration_0.3.8', 'Position_8', 'Program_48', 'Pitch_59', 'Velocity_111', 'Duration_0.3.8', 'Program_48', 'Pitch_64', 'Velocity_111', 'Duration_0.3.8', 'Position_11', 'Program_48', 'Pitch_57', 'Velocity_111', 'Duration_0.3.8', 'Program_48', 'Pitch_69', 'Velocity_111', 'Duration_0.3.8', 'Position_13', 'Program_48', 'Pitch_59', 'Velocity_111', 'Duration_0.3.8', 'Program_48', 'Pitch_71', 'Velocity_111', 'Duration_0.3.8', 'Position_16', 'Program_48', 'Pitch_64', 'Velocity_111', 'Duration_0.3.8', 'Program_48', 'Pitch_64', 'Velocity_111', 'Duration_0.3.8', 'Position_19