In [1]:
#import required packages
import random
random.seed(42)
from glob import glob
from collections import defaultdict

import numpy as np
from numpy.random import choice

from symusic import Score
from miditok import REMI, TokenizerConfig
from midiutil import MIDIFile

from datasets import load_dataset

In [2]:
import pretty_midi
from IPython.display import Audio

In [3]:
# get our pokemon diamond midi files
midi_files = glob('./classical_midi_dataset/*.mid')
len(midi_files)

4796

In [5]:
midi_data = pretty_midi.PrettyMIDI(midi_files[0])
audio = midi_data.synthesize(fs=22050)
Audio(audio, rate=22050)

In [6]:
from music21 import converter, stream, meter
def split_score(score, bars_per_split=16):
    # Get initial time signature
    ts = score.recurse().getElementsByClass(meter.TimeSignature)
    beats_per_bar, beat_type = (4, 4)
    if ts:
        beats_per_bar = ts[0].numerator
        beat_type = ts[0].denominator

    bar_duration = beats_per_bar * (4 / beat_type)
    total_duration = score.highestTime
    split_duration = bars_per_split * bar_duration
    num_splits = int(total_duration // split_duration) + 1

    flattened_parts = [part.flat.notesAndRests for part in score.parts]

    # Pre-filtered metadata
    key_sigs = [el for el in score.flat.getElementsByClass('KeySignature')]
    time_sigs = [el for el in score.flat.getElementsByClass(meter.TimeSignature)]

    chunks = []

    for i in range(num_splits):
        start_time = i * split_duration
        end_time = start_time + split_duration
        chunk_score = stream.Score()

        # Add relevant global metadata
        for ks in key_sigs:
            if ks.offset <= start_time:
                chunk_score.insert(0, ks)
        for ts in time_sigs:
            if ts.offset <= start_time:
                chunk_score.insert(0, ts)

        # Process each part
        for part_stream in flattened_parts:
            new_part = stream.Part()
            # Insert elements directly without calling .clone()
            for el in part_stream.getElementsByOffset(start_time, end_time):
                new_part.insert(el.offset - start_time, el) # Changed line
            chunk_score.append(new_part)

        has_notes = False
        for el in chunk_score.flat.notesAndRests:
            if el.offset <= end_time:
                has_notes = True
                break

        # filter out anything less than 2 bars
        if has_notes and (end_time - start_time) >= 2 * bar_duration:
            chunks.append(chunk_score)

    return chunks

In [7]:
# already generated midi_chunks -- don't worry about this cell
from tqdm import tqdm

for file_path in tqdm(midi_files, desc="Processing MIDI files"):
    try:
        score = converter.parse(file_path)
        chunks = split_score(score, bars_per_split=16)
        for i, chunk in enumerate(chunks):
            filename = file_path.split("/")[-1].split(".")[0]
            output_path = f"./classical_midi_chunks/{filename}_part_{i + 1}.mid"
            chunk.write("midi", fp=output_path)
    except Exception as e:
        print(f"Skipping {file_path}: {e}")

  return self.iter().getElementsByClass(classFilterList)
Processing MIDI files:   4%|▎         | 169/4796 [10:27<3:30:14,  2.73s/it] 

Skipping ./classical_midi_dataset/unknown_artist-p_z-tango.mid: badly formatted midi bytes, got: b'\x00\x05Tango\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'


Processing MIDI files:   5%|▍         | 225/4796 [17:40<14:44:51, 11.61s/it]

Skipping ./classical_midi_dataset/unknown_artist-a_h-canon_2.mid: badly formatted midi bytes, got: b'\x00/\x00\x00\x80=\x00\x00@$\x05\x00 \x00\x00\x00\x8e\x1bg\x11'


Processing MIDI files:   5%|▌         | 260/4796 [25:18<6:52:56,  5.46s/it] 

Skipping ./classical_midi_dataset/unknown_artist-a_h-beet_51_s.mid: badly formed midi string: missing leading MTrk


Processing MIDI files:   6%|▌         | 269/4796 [28:08<27:45:03, 22.07s/it]

Skipping ./classical_midi_dataset/mendelsonn-organ_sonata_n1.mid: list index out of range


Processing MIDI files:   7%|▋         | 331/4796 [36:32<20:24:06, 16.45s/it]

Skipping ./classical_midi_dataset/unknown_artist-a_h-f_20_a_min.mid: badly formed midi string: missing leading MTrk


Processing MIDI files:   7%|▋         | 358/4796 [41:42<6:49:28,  5.54s/it] 

Skipping ./classical_midi_dataset/unknown_artist-p_z-stravinsky_2.mid: badly formatted midi bytes, got: b'\x00\x06Strav2\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'


Processing MIDI files:   9%|▊         | 410/4796 [47:46<7:05:29,  5.82s/it] 

Skipping ./classical_midi_dataset/unknown_artist-a_h-clarinet_1.mid: badly formatted midi bytes, got: b's\n\xb8\x00\x00\x00\x80\xeb\x07\x90\x90\x90\x0f\xac\xd0\x10\x89G\x04['


Processing MIDI files:   9%|▉         | 420/4796 [47:58<1:36:44,  1.33s/it]

Skipping ./classical_midi_dataset/unknown_artist-i_o-monteverdi.mid: badly formatted midi bytes, got: b'\x00\x0cMntevrdi.mid\n1788\x00'


Processing MIDI files:  10%|▉         | 476/4796 [58:33<13:32:55, 11.29s/it]

Skipping ./classical_midi_dataset/unknown_artist-a_h-beet_51_a.mid: badly formed midi string: missing leading MTrk


Processing MIDI files:  10%|█         | 482/4796 [1:00:37<16:20:03, 13.63s/it]

Skipping ./classical_midi_dataset/maier-atalanta_fugiens_no15.mid: badly formatted midi bytes, got: b'RIFF\xc0\x03\x00\x00RMIDdata\xb4\x03\x00\x00'


Processing MIDI files:  10%|█         | 484/4796 [1:01:09<17:33:21, 14.66s/it]

Skipping ./classical_midi_dataset/unknown_artist-a_h-gershwin_3.mid: badly formatted midi bytes, got: b'\x00\x06GERSH3\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'


Processing MIDI files:  10%|█         | 496/4796 [1:02:42<9:03:37,  7.59s/it] 


KeyboardInterrupt: 

In [8]:
# just going to stop at ~10% of the dataset being chunked into 16 bars for performance and time constraint
midi_chunks = glob('./classical_midi_chunks/*.mid')
len(midi_chunks)

4487

In [9]:
from miditok import REMI, TokenizerConfig
from pathlib import Path
import json


# Use REMIPlus tokenizer with appropriate config
config = TokenizerConfig()
config.use_programs = True
config.one_token_stream_for_programs = True
config.use_time_signatures = True

tokenizer = REMI(config)

midi_dir = Path("./classical_midi_chunks")
token_dir = Path("./classical_remi_tokens")
token_dir.mkdir(exist_ok=True)

# Tokenize all MIDI files
for midi_path in midi_dir.glob("*.mid"):
    tokens = tokenizer(midi_path)  # REMIPlus will return a single TokenSequence

    print(f"{midi_path.name}: {len(tokens.ids)} tokens")

    # Save tokens to JSON
    with open(token_dir / (midi_path.stem + ".json"), "w") as f:
        json.dump({"ids": tokens.ids}, f)

  super().__init__(tokenizer_config, params)


moszkowski-etude_de_virtuosite_op72_no13_in_ab_minor_part_15.mid: 691 tokens
haendel-concertos_grossos_hwv319-330_op06-concerto_grosso_op6_n04_4mov_part_6.mid: 2181 tokens
unknown_artist-p_z-p_01_c_maj_part_1.mid: 1319 tokens
unknown_artist-p_z-rach_33_n6_part_2.mid: 1853 tokens
unknown_artist-a_h-bach_g_part_2.mid: 729 tokens
brahms-sonata_2_pianos_n34b_4mov_part_10.mid: 289 tokens
unknown_artist-p_z-prelude_13_part_1.mid: 1251 tokens
maier-atalanta_fugiens_no20_part_1.mid: 501 tokens
mozart-piano_sonatas-nueva_carpeta-k331_piano_sonata_n11__part_122.mid: 332 tokens
chopin-scherzo_op31_part_29.mid: 665 tokens
mozart-piano_sonatas-nueva_carpeta-k331_piano_sonata_n11__part_37.mid: 557 tokens
maier-atalanta_fugiens_no4_part_1.mid: 571 tokens
beethoven-piano_concerto_no1_op15_1mov_part_7.mid: 2653 tokens
beethoven-piano_sonatas-piano_sonata_n16_part_62.mid: 816 tokens
bach-bwv0811_english_suite_n6_3mov_part_35.mid: 152 tokens
chopin-19_polish_songs_for_solo_voice_and_piano_accomplements_n

In [10]:
import torch

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [11]:
import torch.nn as nn
class MusicTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, max_len=2048):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = nn.Parameter(torch.zeros(1, max_len, d_model))  # learnable pos encoding
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        x = self.embedding(x) + self.pos_encoding[:, :x.size(1), :]
        x = self.transformer(x)
        return self.fc_out(x)

In [12]:
from torch.utils.data import Dataset, DataLoader
from pathlib import Path

class MIDITokenDataset(Dataset):
    def __init__(self, token_dir, max_seq_len=2048):
        self.token_paths = list(Path(token_dir).glob("*.json"))
        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self.token_paths)

    def __getitem__(self, idx):
        with open(self.token_paths[idx], "r") as f:
            tokens = json.load(f)["ids"]
        tokens = tokens[:self.max_seq_len]
        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        target_ids = torch.tensor(tokens[1:], dtype=torch.long)
        return input_ids, target_ids

In [13]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    input_ids, target_ids = zip(*batch)
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    target_ids = pad_sequence(target_ids, batch_first=True, padding_value=0)
    return input_ids, target_ids


In [14]:
from tqdm import tqdm

In [19]:
vocab_size = len(tokenizer.vocab)

model = MusicTransformer(vocab_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# Data
dataset = MIDITokenDataset("./classical_remi_tokens", max_seq_len=2048)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

# Train
for epoch in range(10):
    model.train()
    total_loss = 0

    for batch in tqdm(dataloader, desc="Training"):
        input_ids, target_ids = batch
        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)

        optimizer.zero_grad()
        output = model(input_ids)

        # output: [B, T, vocab], target_ids: [B, T] → flatten for CrossEntropyLoss
        loss = criterion(output.view(-1, vocab_size), target_ids.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss / len(dataloader):.4f}")

Training: 100%|██████████| 561/561 [04:48<00:00,  1.94it/s]


Epoch 1, Loss: 1.6041


Training: 100%|██████████| 561/561 [03:48<00:00,  2.45it/s]


Epoch 2, Loss: 1.4456


Training: 100%|██████████| 561/561 [03:48<00:00,  2.46it/s]


Epoch 3, Loss: 1.4383


Training: 100%|██████████| 561/561 [04:26<00:00,  2.10it/s]


Epoch 4, Loss: 1.4366


Training: 100%|██████████| 561/561 [03:52<00:00,  2.41it/s]


Epoch 5, Loss: 1.4285


Training: 100%|██████████| 561/561 [03:46<00:00,  2.47it/s]


Epoch 6, Loss: 1.4397


Training: 100%|██████████| 561/561 [03:49<00:00,  2.45it/s]


Epoch 7, Loss: 1.4262


Training: 100%|██████████| 561/561 [03:49<00:00,  2.44it/s]


Epoch 8, Loss: 1.4223


Training: 100%|██████████| 561/561 [03:46<00:00,  2.47it/s]


Epoch 9, Loss: 1.4388


Training: 100%|██████████| 561/561 [03:47<00:00,  2.46it/s]

Epoch 10, Loss: 1.4293





In [38]:
from miditok import TokSequence
from collections import Counter
import torch.nn.functional as F

model.eval()
vocab = tokenizer.vocab
inverse_vocab = {v: k for k, v in vocab.items()}
with torch.no_grad():
    seed_tokens = [tokenizer.vocab['Bar_None']] #+ \
              #[tokenizer.vocab[f'Position_{i}'] for i in range(32)]
    seed = torch.tensor([seed_tokens], dtype=torch.long).to(device)
    generated = seed

    for _ in range(1024):
        output = model(generated)
        temperature = 1.0  # tweak: 0.7 - 1.0 usually works well
        top_k = 50  # restrict to top 50 candidates

        logits = output[:, -1, :] / temperature
        topk_logits, topk_indices = torch.topk(logits, top_k)

        probs = F.softmax(topk_logits, dim=-1)
        next_token = topk_indices.gather(-1, torch.multinomial(probs, 1))
        generated = torch.cat((generated, next_token), dim=1)
        
    sequence = TokSequence(ids=generated.squeeze().tolist())
    print(len(sequence))
    print("Sample tokens as strings:", [inverse_vocab[id] for id in sequence[:50]])
    token_strings = [inverse_vocab[id] for id in sequence]
    print(Counter(token_strings).most_common(20))
    midi = tokenizer.decode(sequence)
    midi.dump_midi("test_generated.mid")

1025
Sample tokens as strings: ['Bar_None', 'TimeSig_4/4', 'Position_0', 'Program_0', 'Pitch_69', 'Velocity_99', 'Duration_0.2.8', 'Program_0', 'Pitch_45', 'Velocity_91', 'Duration_0.6.8', 'Program_0', 'Pitch_69', 'Velocity_63', 'Duration_0.6.8', 'Program_0', 'Pitch_69', 'Velocity_127', 'Duration_0.2.8', 'Position_6', 'Program_0', 'Pitch_78', 'Velocity_111', 'Duration_0.4.8', 'Program_0', 'Pitch_56', 'Velocity_71', 'Duration_1.0.8', 'Program_0', 'Pitch_81', 'Duration_0.3.8', 'Program_0', 'Pitch_43', 'Velocity_111', 'Duration_1.0.8', 'Program_0', 'Pitch_61', 'Velocity_91', 'Duration_0.2.8', 'Program_0', 'Pitch_72', 'Velocity_63', 'Duration_0.6.8', 'Bar_None', 'TimeSig_3/4', 'Position_0', 'Program_0', 'Pitch_68', 'Velocity_31', 'Duration_0.2.8']
[('Program_0', 230), ('Duration_0.2.8', 80), ('Duration_0.4.8', 61), ('Velocity_127', 31), ('Duration_1.0.8', 26), ('Velocity_111', 24), ('Duration_0.3.8', 20), ('Velocity_63', 18), ('Duration_0.6.8', 17), ('Velocity_95', 17), ('Velocity_99', 15)

In [39]:
midi_data = pretty_midi.PrettyMIDI("test_generated.mid")
audio = midi_data.synthesize(fs=22050)
Audio(audio, rate=22050)

In [24]:
test_data = pretty_midi.PrettyMIDI("classical_midi_chunks/beethoven-136_part_2.mid")
audio = test_data.synthesize(fs=22050)
Audio(audio, rate=22050)

In [35]:
with open("classical_remi_tokens/beethoven-136_part_2.json", "r") as f:
    tokens = json.load(f)["ids"]
    
    token_strings = [inverse_vocab[id] for id in tokens]
    print(token_strings[:50])
    token_counter = Counter(token_strings)
    for t in token_counter:
        if t.startswith("Bar") or t.startswith("Position"):
            print(f"{t}: {token_counter[t]}")
midi = tokenizer.decode(TokSequence(ids=tokens))
midi.dump_midi("test_reconstructed.mid")
midi_data = pretty_midi.PrettyMIDI("test_reconstructed.mid")
audio = midi_data.synthesize(fs=22050)
Audio(audio, rate=22050)

['Bar_None', 'TimeSig_3/4', 'Position_0', 'Program_0', 'Pitch_53', 'Velocity_63', 'Duration_0.2.8', 'Program_0', 'Pitch_58', 'Velocity_63', 'Duration_0.2.8', 'Program_0', 'Pitch_62', 'Velocity_67', 'Duration_0.2.8', 'Position_4', 'Program_0', 'Pitch_56', 'Velocity_47', 'Duration_0.2.8', 'Program_0', 'Pitch_58', 'Velocity_47', 'Duration_0.2.8', 'Program_0', 'Pitch_65', 'Velocity_51', 'Duration_0.2.8', 'Position_6', 'Program_0', 'Pitch_74', 'Velocity_103', 'Duration_1.0.8', 'Program_0', 'Pitch_77', 'Velocity_107', 'Duration_1.0.8', 'Program_0', 'Pitch_86', 'Velocity_107', 'Duration_1.0.8', 'Position_8', 'Program_0', 'Pitch_54', 'Velocity_39', 'Duration_0.2.8', 'Program_0', 'Pitch_58', 'Velocity_43', 'Duration_0.2.8']
Bar_None: 17
Position_0: 16
Position_4: 10
Position_6: 11
Position_8: 16
Position_12: 10
Position_14: 12
Position_16: 15
Position_18: 15
Position_20: 10
Position_22: 12
Position_2: 15
Position_10: 14


In [40]:
from miditok import TokSequence
from collections import Counter
import torch.nn.functional as F

model.eval()
vocab = tokenizer.vocab
inverse_vocab = {v: k for k, v in vocab.items()}
with torch.no_grad():
    # get last 1024 tokens from song to generated from
    seed_tokens = tokens[-1024:] if len(tokens) > 1024 else tokens
    seed = torch.tensor([seed_tokens], dtype=torch.long).to(device)
    generated = seed

    for _ in range(1024):
        output = model(generated)
        temperature = 1.0  # tweak: 0.7 - 1.0 usually works well
        top_k = 50  # restrict to top 50 candidates

        logits = output[:, -1, :] / temperature
        topk_logits, topk_indices = torch.topk(logits, top_k)

        probs = F.softmax(topk_logits, dim=-1)
        next_token = topk_indices.gather(-1, torch.multinomial(probs, 1))
        generated = torch.cat((generated, next_token), dim=1)
        
    sequence = TokSequence(ids=generated.squeeze().tolist())
    print(len(sequence))
    print("Sample tokens as strings:", [inverse_vocab[id] for id in sequence[:50]])
    token_strings = [inverse_vocab[id] for id in sequence]
    print(Counter(token_strings).most_common(20))
    midi = tokenizer.decode(sequence)
    midi.dump_midi("test_extended.mid")

2048
Sample tokens as strings: ['Program_0', 'Pitch_77', 'Velocity_99', 'Duration_0.2.8', 'Position_12', 'Program_0', 'Pitch_36', 'Velocity_87', 'Duration_0.2.8', 'Program_0', 'Pitch_48', 'Velocity_91', 'Duration_0.2.8', 'Position_14', 'Program_0', 'Pitch_63', 'Velocity_111', 'Duration_0.2.8', 'Program_0', 'Pitch_75', 'Velocity_115', 'Duration_0.2.8', 'Position_16', 'Program_0', 'Pitch_34', 'Velocity_83', 'Duration_0.2.8', 'Program_0', 'Pitch_46', 'Velocity_87', 'Duration_0.2.8', 'Position_18', 'Program_0', 'Pitch_62', 'Velocity_91', 'Duration_0.2.8', 'Program_0', 'Pitch_74', 'Velocity_99', 'Duration_0.2.8', 'Position_20', 'Program_0', 'Pitch_31', 'Velocity_87', 'Duration_0.2.8', 'Program_0', 'Pitch_43', 'Velocity_91', 'Duration_0.2.8', 'Position_22']
[('Program_0', 457), ('Duration_0.4.8', 172), ('Duration_0.2.8', 170), ('Velocity_99', 45), ('Duration_0.6.8', 41), ('Velocity_111', 34), ('Velocity_91', 32), ('Pitch_74', 32), ('Velocity_127', 32), ('Velocity_103', 30), ('Pitch_50', 30),

In [41]:
midi_data = pretty_midi.PrettyMIDI("test_extended.mid")
audio = midi_data.synthesize(fs=22050)
Audio(audio, rate=22050)

In [22]:
testtoks = tokenizer("./midi_dataset/A New Adventure!.mid")
print(testtoks)
token_counter = Counter(testtoks.tokens)
for t in token_counter:
    if t.startswith("Bar") or t.startswith("Position"):
        print(f"{t}: {token_counter[t]}")

TokSequence(tokens=['Bar_None', 'TimeSig_4/4', 'Position_0', 'Program_48', 'Pitch_45', 'Velocity_111', 'Duration_0.3.8', 'Program_48', 'Pitch_64', 'Velocity_111', 'Duration_0.3.8', 'Position_3', 'Program_48', 'Pitch_52', 'Velocity_111', 'Duration_0.3.8', 'Program_48', 'Pitch_69', 'Velocity_111', 'Duration_0.3.8', 'Position_5', 'Program_48', 'Pitch_57', 'Velocity_111', 'Duration_0.3.8', 'Program_48', 'Pitch_71', 'Velocity_111', 'Duration_0.3.8', 'Position_8', 'Program_48', 'Pitch_59', 'Velocity_111', 'Duration_0.3.8', 'Program_48', 'Pitch_64', 'Velocity_111', 'Duration_0.3.8', 'Position_11', 'Program_48', 'Pitch_57', 'Velocity_111', 'Duration_0.3.8', 'Program_48', 'Pitch_69', 'Velocity_111', 'Duration_0.3.8', 'Position_13', 'Program_48', 'Pitch_59', 'Velocity_111', 'Duration_0.3.8', 'Program_48', 'Pitch_71', 'Velocity_111', 'Duration_0.3.8', 'Position_16', 'Program_48', 'Pitch_64', 'Velocity_111', 'Duration_0.3.8', 'Program_48', 'Pitch_64', 'Velocity_111', 'Duration_0.3.8', 'Position_19