In [None]:
!7z x Dataset.7z


7-Zip [64] 16.02 : Copyright (c) 1999-2016 Igor Pavlov : 2016-05-21
p7zip Version 16.02 (locale=en_US.UTF-8,Utf16=on,HugeFiles=on,64 bits,2 CPUs Intel(R) Xeon(R) CPU @ 2.20GHz (406F0),ASM,AES-NI)

Scanning the drive for archives:
  0M Scan         1 file, 1345976 bytes (1315 KiB)

Extracting archive: Dataset.7z
--
Path = Dataset.7z
Type = 7z
Physical Size = 1345976
Headers Size = 41547
Method = LZMA2:12m
Solid = +
Blocks = 1

  0%      0% 485 - Dataset/3-vidi_track_4.mid                                       0% 858 - Dataset/brandenburg4-presto_track_4.mid                                                    0% 1145 - Dataset/canzonottava_track_3.mid                                              0% 1366 -

In [None]:
!pip install miditoolkit mido pandas seaborn matplotlib

Collecting miditoolkit
  Downloading miditoolkit-1.0.1-py3-none-any.whl.metadata (4.9 kB)
Collecting mido
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Downloading miditoolkit-1.0.1-py3-none-any.whl (24 kB)
Downloading mido-1.3.3-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: mido, miditoolkit
Successfully installed miditoolkit-1.0.1 mido-1.3.3


In [None]:
import miditoolkit
from mido import MidiFile, MidiTrack, MetaMessage, Message
import os
import pandas as pd
import json
import numpy as np
from pathlib import Path
import torch

In [None]:
def extract_note_attributes(note):
    """Return ``(pitch, start, duration)`` for one note.

    The duration is ``note.end - note.start`` rounded to the nearest multiple
    of 60 and never smaller than 60. The start time is passed through
    unchanged.

    Args:
        note: Object exposing ``pitch``, ``start`` and ``end`` attributes.

    Returns:
        A 3-tuple ``(pitch, start, quantized_duration)``.
    """
    raw_length = note.end - note.start
    snapped = round(raw_length / 60) * 60
    if snapped < 60:
        # Very short notes are kept at the minimum grid step instead of vanishing.
        snapped = 60
    return (note.pitch, note.start, snapped)


def insert_silence_tokens(note_events):
    """Turn ``(pitch, start, duration)`` events into ``(pitch, duration)`` tokens.

    Whenever an event starts after the end of the previous one
    (``previous start + previous duration``), a silence token ``(0, gap)`` is
    emitted first. The gap is rounded to the nearest multiple of 60 with a
    minimum of 60. Events are processed in the order given.

    Args:
        note_events: Iterable of ``(pitch, start, duration)`` tuples.

    Returns:
        List of ``(pitch, duration)`` tuples, where pitch 0 denotes silence.
    """
    tokens = []
    cursor = None  # end of the previously emitted note; None before the first note

    for pitch, start, duration in note_events:
        gap = None if cursor is None else start - cursor
        if gap is not None and gap > 0:
            tokens.append((0, max(60, round(gap / 60) * 60)))

        tokens.append((pitch, duration))
        cursor = start + duration

    return tokens


def extract_metadata(filename):
    """Collect the header information needed to re-create a MIDI file.

    The result starts from defaults (4/4 time, key of C, tempo 500000) and is
    overridden by any ``time_signature``, ``key_signature`` or ``set_tempo``
    message found in the file. If a message type occurs more than once, the
    last occurrence over all tracks wins.

    Args:
        filename: Path of the MIDI file to inspect.

    Returns:
        Dict with the keys ``ticks_per_beat`` (taken from the file),
        ``time_signature``, ``key_signature`` and ``set_tempo`` (mido meta
        messages whose ``time`` is 0).
    """
    midi = MidiFile(filename)
    metadata = {
        'ticks_per_beat': midi.ticks_per_beat,
        'time_signature': MetaMessage('time_signature', numerator=4, denominator=4, clocks_per_click=24, notated_32nd_notes_per_beat=8),
        'key_signature': MetaMessage('key_signature', key='C'),
        'set_tempo': MetaMessage('set_tempo', tempo=500000)  # Default 120 BPM
    }

    for track in midi.tracks:
        for msg in track:
            if msg.type in metadata:
                # Store a copy with the delta time cleared: these messages are
                # re-emitted at the very start of a new track by decode_midi, where
                # the original delta would otherwise show up as an unintended delay.
                metadata[msg.type] = msg.copy(time=0)

    return metadata


def encode_midi(filename):
    """Encode a MIDI file as a list of ``(pitch, duration)`` tokens.

    Notes of all instruments are gathered, ordered by start time and passed to
    ``insert_silence_tokens``, which adds ``(0, duration)`` silence tokens for
    the gaps between consecutive notes.

    Args:
        filename: Path of the MIDI file to encode.

    Returns:
        List of ``(pitch, duration)`` tuples.
    """
    midi = miditoolkit.MidiFile(filename)
    note_events = [
        extract_note_attributes(note)
        for instrument in midi.instruments
        for note in instrument.notes
    ]

    # insert_silence_tokens derives gaps from consecutive events, so the events
    # must be chronological; do not rely on the parser's note order or on there
    # being a single instrument. The sort is stable, so notes with equal start
    # times keep their original relative order.
    note_events.sort(key=lambda event: event[1])

    tokens = insert_silence_tokens(note_events)
    return tokens


def tokens_to_ids(tokens, token_dict):
    """Map each token to its integer ID.

    Args:
        tokens: Iterable of tokens (keys of ``token_dict``).
        token_dict: Mapping from token to ID.

    Returns:
        List of IDs in the same order as ``tokens``.

    Raises:
        KeyError: If a token is missing from ``token_dict``.
    """
    ids = []
    for token in tokens:
        ids.append(token_dict[token])
    return ids


def ids_to_tokens(ids, id_to_token):
    """Map each integer ID back to its token.

    Args:
        ids: Iterable of IDs.
        id_to_token: Lookup (indexable by ID) returning the token for an ID.

    Returns:
        List of tokens in the same order as ``ids``.
    """
    return list(map(lambda token_id: id_to_token[token_id], ids))


def decode_midi(notes, output_filename, metadata):
    """Write a sequence of ``(pitch, duration)`` tokens to a single-track MIDI file.

    The track starts with the time signature, key signature and tempo messages
    from ``metadata``, followed by a program change and a volume controller
    message on channel 0. Every token with a non-zero pitch becomes a
    ``note_on`` (velocity 100) followed, ``duration`` later, by a ``note_on``
    with velocity 0. Tokens with pitch 0 are rests: their durations are added
    up and used as the delta time of the next sounding note. Rests after the
    last sounding note are not written.

    Args:
        notes: Iterable of ``(pitch, duration)`` tuples.
        output_filename: Destination path of the MIDI file.
        metadata: Dict with the keys ``ticks_per_beat``, ``time_signature``,
            ``key_signature`` and ``set_tempo``.
    """
    track = MidiTrack()
    mid = MidiFile(ticks_per_beat=metadata['ticks_per_beat'])
    mid.tracks.append(track)

    # Header: song-level meta messages first, then instrument and volume setup.
    header = [metadata[key] for key in ('time_signature', 'key_signature', 'set_tempo')]
    header.append(Message('program_change', channel=0, program=1, time=0))
    header.append(Message('control_change', channel=0, control=7, value=80, time=0))
    track.extend(header)

    pending_rest = 0  # silence accumulated since the last sounding note
    for pitch, duration in notes:
        if pitch == 0:
            pending_rest += duration
            continue

        track.append(Message('note_on', note=pitch, velocity=100, time=pending_rest))
        # A note_on with velocity 0 ends the note after `duration`.
        track.append(Message('note_on', note=pitch, velocity=0, time=duration))
        pending_rest = 0

    mid.save(output_filename)


def decode_ids_to_midi(id_sequence, output_filename, metadata, id_to_token):
    """Convert token IDs to tokens and write them out as a MIDI file.

    Args:
        id_sequence: Iterable of token IDs.
        output_filename: Destination path of the MIDI file.
        metadata: Header information, passed through to ``decode_midi``.
        id_to_token: Lookup from ID to ``(pitch, duration)`` token.
    """
    decode_midi(ids_to_tokens(id_sequence, id_to_token), output_filename, metadata)



In [None]:
DATASET_DIR = "/content/Dataset"
all_tokens = []    # tokens of every file, concatenated into one long sequence
token_set = set()  # distinct tokens, i.e. the vocabulary

# os.listdir returns entries in arbitrary order. Sorting makes the concatenated
# corpus (and therefore the later train/test split) the same on every run.
for filename in sorted(os.listdir(DATASET_DIR)):
    if filename.lower().endswith('.mid'):
        filepath = os.path.join(DATASET_DIR, filename)
        tokens = encode_midi(filepath)

        all_tokens.extend(tokens)
        token_set.update(tokens)

        print(f"Encoded {filename}")


Encoded 12-do_what_you_can_track_2.mid
Encoded 7-triumphavi_track_3.mid
Encoded micor-fantasia19-a4_track_1.mid
Encoded brewert-03-almaine_track_3.mid
Encoded jenkinsj-fantasia8-a5_track_2.mid
Encoded tomkinst-pavan4-a5_track_3.mid
Encoded twv41-d4_track_1.mid
Encoded ferraboscoa-fantasia08-a6_track_5.mid
Encoded qv3.26_track_2.mid
Encoded micor-fantasia11-a4_track_4.mid
Encoded carwardenj-suite5-31-a_daunce_track_1.mid
Encoded twv42-c2_track_4.mid
Encoded twv54-b2_1_track_4.mid
Encoded gibbonso-fantasia1-a6_track_1.mid
Encoded 53_last_will_and_test_track_1.mid
Encoded 10_galliard_track_5.mid
Encoded twv42c1_track_1.mid
Encoded 27_the_image_of_melan_track_4.mid
Encoded tomkinst-fantasia06-a3_track_2.mid
Encoded 6-follow_me_track_2.mid
Encoded hingestonj-almande5-a3_track_3.mid
Encoded triosonataop2nr6_1_track_2.mid
Encoded twv52-a2_track_5.mid
Encoded twv52-f1_track_1.mid
Encoded canzon199_track_7.mid
Encoded 13-terpsicore_track_3.mid
Encoded naudot-suite6_track_1.mid
Encoded micor-pav

In [None]:
print(len(all_tokens)) # 1641743 tokens in the whole corpus (total sequence length, not the vocabulary size)

1641743


In [None]:
# Show a small, sorted preview instead of dumping the whole vocabulary into the
# cell output. Each entry is a (pitch, duration) tuple; pitch 0 marks a silence.
print(sorted(token_set)[:20])
print(len(token_set)) # vocab size = 2788

[(36, 480), (69, 1560), (60, 1740), (51, 1920), (0, 2220), (0, 13740), (44, 420), (81, 6000), (71, 2280), (79, 8520), (0, 59640), (0, 36600), (86, 3720), (67, 180), (58, 360), (70, 2400), (75, 300), (83, 6540), (66, 480), (52, 2760), (57, 660), (0, 9660), (0, 32700), (0, 21180), (48, 840), (81, 1920), (39, 1020), (72, 2100), (84, 2340), (56, 780), (67, 7620), (59, 1380), (45, 3660), (47, 960), (50, 1560), (36, 3840), (83, 2640), (74, 2820), (0, 5760), (76, 120), (88, 360), (79, 540), (82, 2760), (40, 1860), (67, 3720), (78, 840), (69, 1020), (60, 1200), (51, 1380), (0, 1680), (76, 7560), (66, 3840), (0, 13200), (0, 24720), (35, 60), (38, 660), (71, 1740), (29, 840), (62, 1920), (53, 2100), (70, 1860), (91, 900), (28, 960), (83, 6000), (74, 6180), (57, 120), (0, 9120), (43, 2400), (48, 300), (81, 1380), (39, 480), (72, 1560), (63, 2760), (84, 1800), (56, 240), (59, 840), (50, 1020), (47, 420), (41, 1200), (74, 2280), (0, 5220), (0, 16560), (0, 28080), (87, 120), (78, 300), (69, 480), (6

In [None]:
# Token -> ID lookup: IDs follow the sorted order of the (pitch, duration) tuples.
token_dict = {token: idx for idx, token in enumerate(sorted(token_set))}
# Preview only the first entries; printing the full mapping floods the output.
print(list(token_dict.items())[:20])
print(len(token_dict))

{(0, 60): 0, (0, 120): 1, (0, 180): 2, (0, 240): 3, (0, 300): 4, (0, 360): 5, (0, 420): 6, (0, 480): 7, (0, 540): 8, (0, 600): 9, (0, 660): 10, (0, 720): 11, (0, 780): 12, (0, 840): 13, (0, 900): 14, (0, 960): 15, (0, 1020): 16, (0, 1080): 17, (0, 1140): 18, (0, 1200): 19, (0, 1260): 20, (0, 1320): 21, (0, 1380): 22, (0, 1440): 23, (0, 1500): 24, (0, 1560): 25, (0, 1620): 26, (0, 1680): 27, (0, 1740): 28, (0, 1800): 29, (0, 1860): 30, (0, 1920): 31, (0, 1980): 32, (0, 2040): 33, (0, 2100): 34, (0, 2160): 35, (0, 2220): 36, (0, 2280): 37, (0, 2340): 38, (0, 2400): 39, (0, 2460): 40, (0, 2520): 41, (0, 2580): 42, (0, 2640): 43, (0, 2700): 44, (0, 2760): 45, (0, 2820): 46, (0, 2880): 47, (0, 2940): 48, (0, 3000): 49, (0, 3060): 50, (0, 3120): 51, (0, 3180): 52, (0, 3240): 53, (0, 3300): 54, (0, 3360): 55, (0, 3420): 56, (0, 3480): 57, (0, 3540): 58, (0, 3600): 59, (0, 3660): 60, (0, 3720): 61, (0, 3780): 62, (0, 3840): 63, (0, 3900): 64, (0, 3960): 65, (0, 4020): 66, (0, 4080): 67, (0, 41

In [None]:
# Vocabulary lookup: every distinct token gets an integer ID by its sorted position.
token2id = dict(zip(sorted(token_set), range(len(token_set))))
# Inverse lookup, used to turn generated IDs back into tokens.
id2token = dict(zip(token2id.values(), token2id.keys()))

In [None]:
# Encode the whole corpus as integer IDs.
all_token_ids = tokens_to_ids(tokens=all_tokens, token_dict=token2id)

In [None]:
print(len(all_token_ids)) # one ID per corpus token, so this equals len(all_tokens)
print(all_token_ids[:100]) # first 100 IDs as a sanity check

1641743
[1856, 1959, 2010, 2111, 2013, 1962, 1850, 1735, 1856, 1735, 1741, 1692, 5, 2111, 2221, 2277, 2380, 2333, 2374, 2111, 2224, 2227, 2108, 2013, 1968, 1850, 2469, 2371, 2280, 2224, 2224, 2280, 2380, 2280, 2230, 2111, 2111, 2013, 1962, 1856, 5, 2111, 2221, 2277, 2380, 2333, 2374, 2111, 2224, 2227, 2108, 2013, 1968, 1850, 2469, 2371, 2280, 2224, 2224, 2280, 2380, 2280, 2230, 2111, 1847, 1959, 2016, 2108, 2224, 2013, 2019, 5, 1735, 1853, 1959, 2010, 2108, 2224, 2117, 2068, 1856, 1735, 11, 2013, 1919, 1850, 1741, 1850, 2280, 2224, 2111, 2068, 2466, 2371, 2221, 2277, 2108, 2224, 2280, 2224]


In [None]:
import torch

# Pack the ID sequence into a single 1-D int64 tensor; the train/test split and
# the batch sampling below slice directly into it.
data = torch.tensor(all_token_ids, dtype=torch.long)
print(data.shape, data.dtype)
print(data[:100])

torch.Size([1641743]) torch.int64
tensor([1856, 1959, 2010, 2111, 2013, 1962, 1850, 1735, 1856, 1735, 1741, 1692,
           5, 2111, 2221, 2277, 2380, 2333, 2374, 2111, 2224, 2227, 2108, 2013,
        1968, 1850, 2469, 2371, 2280, 2224, 2224, 2280, 2380, 2280, 2230, 2111,
        2111, 2013, 1962, 1856,    5, 2111, 2221, 2277, 2380, 2333, 2374, 2111,
        2224, 2227, 2108, 2013, 1968, 1850, 2469, 2371, 2280, 2224, 2224, 2280,
        2380, 2280, 2230, 2111, 1847, 1959, 2016, 2108, 2224, 2013, 2019,    5,
        1735, 1853, 1959, 2010, 2108, 2224, 2117, 2068, 1856, 1735,   11, 2013,
        1919, 1850, 1741, 1850, 2280, 2224, 2111, 2068, 2466, 2371, 2221, 2277,
        2108, 2224, 2280, 2224])


Splitting the token sequence into training and test sets with a 0.9:0.1 ratio

In [None]:
# Train on the first 90% of the token sequence and hold out the last 10% for testing.
n = int(len(data) * 0.9)
train_data, test_data = data[:n], data[n:]

print(len(train_data)) # 0.9 * 1641743 = 1477568
print(len(test_data)) # 0.1 * 1641743 = 164175

1477568
164175


In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# data loading
def get_batch(split):
    """Sample a random batch of input/target windows.

    Reads the notebook-level ``train_data``, ``test_data``, ``block_size``,
    ``batch_size`` and ``device``.

    Args:
        split: ``'train'`` selects ``train_data``; any other value selects
            ``test_data``.

    Returns:
        Tuple ``(x, y)`` of tensors on ``device``, each stacking ``batch_size``
        windows of ``block_size`` tokens. ``y`` is ``x`` shifted one position
        to the right in the source sequence.
    """
    source = train_data if split == 'train' else test_data
    # Random window starts; the upper bound leaves room for the shifted target window.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s:s + block_size] for s in starts])
    targets = torch.stack([source[s + 1:s + block_size + 1] for s in starts])
    return inputs.to(device), targets.to(device)

In [None]:
from torch.utils.data import Dataset

class MIDITokenDataset(Dataset):
    """Sliding-window view over a token sequence for next-token prediction.

    Item ``i`` is the window ``data[i : i + block_size]`` together with the
    same window shifted by one position as labels.

    Args:
        data: Indexable sequence of token IDs that supports slicing and ``len``.
        block_size: Number of tokens per window.
    """

    def __init__(self, data, block_size):
        self.data = data
        self.block_size = block_size

    def __len__(self):
        # A sequence that is too short to hold one window plus its shifted
        # target yields an empty dataset instead of an invalid negative length.
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx):
        """Return ``{"input_ids": x, "labels": y}`` for window ``idx``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
                Without this, plain iteration over the dataset would never
                stop, and out-of-range indices would return truncated windows
                whose inputs and labels differ in length.
        """
        length = len(self)
        if idx < 0:
            idx += length
        if not 0 <= idx < length:
            raise IndexError(f"index out of range for dataset of length {length}")

        x = self.data[idx : idx + self.block_size]
        y = self.data[idx + 1 : idx + 1 + self.block_size]
        return {"input_ids": x, "labels": y}


In [None]:
from torch.utils.data import DataLoader

# Model and batching hyper-parameters shared by the cells below.
block_size = 128  # tokens per training window (also the size of the position embedding table)
batch_size = 32   # windows per batch
n_embd = 128      # embedding width
n_head = 4        # attention heads per block
n_layer = 4       # number of transformer blocks
dropout = 0.1

eval_iters = 200  # batches averaged per split in estimate_loss

# NOTE(review): the training loop in this notebook draws its batches with
# get_batch(); train_loader is built here but not consumed by that loop.
train_dataset = MIDITokenDataset(train_data, block_size=block_size)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


In [None]:
import torch.nn as nn
import torch.nn.functional as F

class Head(nn.Module):
    """One head of causal (masked) self-attention.

    Args:
        head_size: Output width of the key, query and value projections.
        n_embd: Width of the incoming embeddings.
        block_size: Largest sequence length the causal mask supports.
        dropout: Dropout probability applied to the head output.
    """

    def __init__(self, head_size, n_embd, block_size, dropout=0.1):
        super().__init__()
        # The layers are created in the order key, query, value.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular ones: position t may attend to positions <= t only.
        lower_triangle = torch.ones(block_size, block_size).tril()
        self.register_buffer("tril", lower_triangle)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply masked attention to ``x`` of shape (batch, time, channels)."""
        _, seq_len, _ = x.shape
        keys, queries, values = self.key(x), self.query(x), self.value(x)
        # Dot-product scores scaled by the square root of the head size.
        scores = queries @ keys.transpose(-2, -1) / (keys.size(-1) ** 0.5)

        # Future positions get -inf so that softmax gives them zero weight.
        future = self.tril[:seq_len, :seq_len] == 0
        weights = scores.masked_fill(future, float("-inf")).softmax(dim=-1)
        return self.dropout(weights @ values)

class MultiHeadAttention(nn.Module):
    """Several attention heads run side by side, followed by a linear projection.

    Each head has size ``n_embd // num_heads``. The head outputs are
    concatenated along the channel axis, projected by an
    ``n_embd -> n_embd`` linear layer and passed through dropout.
    """

    def __init__(self, num_heads, n_embd, block_size, dropout=0.1):
        super().__init__()
        head_size = n_embd // num_heads
        heads = []
        for _ in range(num_heads):
            heads.append(Head(head_size, n_embd, block_size, dropout))
        self.heads = nn.ModuleList(heads)
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run every head on ``x`` and merge the results."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        return self.dropout(self.proj(merged))

class FeedForward(nn.Module):
    """Position-wise MLP: expand to ``4 * n_embd``, GELU, project back, dropout."""

    def __init__(self, n_embd, dropout=0.1):
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x``; the output has the same shape as the input."""
        return self.net(x)

class Block(nn.Module):
    """Transformer block: self-attention, then a feed-forward network.

    Both sub-layers receive a layer-normalised input and are added back onto
    their unnormalised input (residual connections).
    """

    def __init__(self, n_embd, n_head, block_size, dropout=0.1):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd, block_size, dropout)
        self.ffwd = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attended = x + self.sa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))

class TransformerMIDILanguageModel(nn.Module):
    """Decoder-only transformer over token IDs.

    Args:
        vocab_size: Number of distinct token IDs.
        n_embd: Embedding width.
        n_head: Attention heads per block.
        n_layer: Number of transformer blocks.
        block_size: Longest sequence the model can process in one forward pass.
        dropout: Dropout probability used inside the blocks.
    """

    def __init__(self, vocab_size, n_embd=128, n_head=4, n_layer=6, block_size=128, dropout=0.1):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.block_size = block_size
        self.vocab_size = vocab_size

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: Tensor of token IDs with shape (batch, time); ``time`` must
                not exceed ``block_size``.
            targets: Optional tensor of target IDs with the same shape as
                ``idx``. Positions equal to -100 are ignored by the loss.

        Returns:
            Tuple ``(logits, loss)``. ``logits`` has shape
            (batch, time, vocab_size); ``loss`` is the cross-entropy over all
            positions, or ``None`` when ``targets`` is ``None``.
        """
        _, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))

        x = self.blocks(tok_emb + pos_emb)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # reshape (rather than view) also accepts non-contiguous targets.
            loss = F.cross_entropy(logits.reshape(-1, self.vocab_size), targets.reshape(-1), ignore_index=-100)
        return logits, loss

    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, eos_token_id=None):
        """Autoregressively sample up to ``max_new_tokens`` continuation tokens.

        Sampling runs in evaluation mode (dropout disabled) and without
        gradient tracking; the module's previous training/eval mode is
        restored before returning.

        Args:
            idx: Tensor of shape (batch, time) holding the prompt token IDs.
            max_new_tokens: Maximum number of tokens to append.
            temperature: Divisor applied to the logits before sampling.
            top_k: If given, sampling is restricted to the ``top_k`` most
                likely tokens at every step.
            eos_token_id: If given, generation stops once every sequence in
                the batch has just produced this token.

        Returns:
            Tensor of shape (batch, time + generated) with the prompt followed
            by the sampled tokens.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(max_new_tokens):
                    # The position embedding only covers block_size positions.
                    idx_cond = idx[:, -self.block_size:]
                    logits, _ = self(idx_cond)
                    logits = logits[:, -1, :] / temperature

                    if top_k is not None:
                        # Keep the k largest logits per row and mask out the rest.
                        k = max(1, min(top_k, logits.size(-1)))
                        kth_best = torch.topk(logits, k).values[:, -1:]
                        logits = logits.masked_fill(logits < kth_best, float("-inf"))

                    probs = F.softmax(logits, dim=-1)
                    idx_next = torch.multinomial(probs, num_samples=1)
                    idx = torch.cat((idx, idx_next), dim=1)

                    # .all() instead of .item() so that batches larger than one work too.
                    if eos_token_id is not None and bool((idx_next == eos_token_id).all()):
                        break
        finally:
            self.train(was_training)

        return idx

In [None]:
vocab_size = len(token2id)

# Seed before the model is built: the parameters are initialised from the
# global RNG at construction time, so seeding afterwards would leave the
# initial weights different on every run.
torch.manual_seed(1337)

model = TransformerMIDILanguageModel(
    vocab_size=vocab_size,
    n_embd=n_embd,
    n_head=n_head,
    n_layer=n_layer,
    block_size=block_size,
    dropout=dropout
).to(device)

print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")


1.524708 M parameters


In [None]:
@torch.no_grad()
def estimate_loss():
    """Average the model loss over ``eval_iters`` random batches per split.

    The model is put into eval mode for the measurement and back into
    training mode afterwards. Batches come from ``get_batch``, which serves
    ``train_data`` for ``'train'`` and ``test_data`` for ``'val'``.

    Returns:
        Dict with the keys ``'train'`` and ``'val'``, each holding the mean
        loss as a 0-dimensional tensor.
    """
    model.eval()
    results = {}
    for split in ('train', 'val'):
        batch_losses = torch.zeros(eval_iters)
        for step in range(eval_iters):
            inputs, targets = get_batch(split)
            loss = model(inputs, targets)[1]
            batch_losses[step] = loss.item()
        results[split] = batch_losses.mean()
    model.train()
    return results

In [None]:
learning_rate = 1e-4
max_iters = 20000
eval_interval = 100  # report the estimated losses every this many steps

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# The loop variable is called `step`, not `iter`: at notebook level, `iter`
# would shadow the builtin iter() for every cell executed afterwards.
for step in range(max_iters):
    if step % eval_interval == 0 or step == max_iters - 1:
        # "val" is measured on test_data (see get_batch).
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch('train')
    _, loss = model(xb, yb)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()


step 0: train loss 3.3903, val loss 3.4575
step 100: train loss 3.3839, val loss 3.4364
step 200: train loss 3.3597, val loss 3.4281
step 300: train loss 3.3590, val loss 3.4236
step 400: train loss 3.3357, val loss 3.4017
step 500: train loss 3.3273, val loss 3.3985
step 600: train loss 3.3070, val loss 3.3865
step 700: train loss 3.2937, val loss 3.3872
step 800: train loss 3.2933, val loss 3.3660
step 900: train loss 3.2923, val loss 3.3597
step 1000: train loss 3.2810, val loss 3.3353
step 1100: train loss 3.2685, val loss 3.3419
step 1200: train loss 3.2644, val loss 3.3359
step 1300: train loss 3.2531, val loss 3.3153
step 1400: train loss 3.2480, val loss 3.3177
step 1500: train loss 3.2180, val loss 3.3188
step 1600: train loss 3.2186, val loss 3.3151
step 1700: train loss 3.2235, val loss 3.2924
step 1800: train loss 3.2229, val loss 3.2807
step 1900: train loss 3.1973, val loss 3.2842
step 2000: train loss 3.1988, val loss 3.2885
step 2100: train loss 3.2044, val loss 3.2668


In [None]:
# Seed the generation with a single token: (0, 480) is a silence of 480 time units.
start_token = token2id[(0, 480)]  # or any common starting token
# Shape (1, 1): a batch holding one sequence of one token.
context = torch.tensor([[start_token]], dtype=torch.long, device=device)

# Sample 500 more tokens and flatten the single sequence into a list of IDs.
generated_ids = model.generate(context, max_new_tokens=500)[0].tolist()

In [None]:
# Start token followed by the sampled token IDs.
print(generated_ids)

[7, 2226, 2070, 2121, 2070, 2113, 2070, 2121, 2113, 2226, 2384, 2376, 2226, 2113, 2234, 2121, 2023, 1876, 2113, 2226, 2282, 2226, 2113, 2250, 23, 2230, 2278, 2222, 2109, 1964, 1852, 11, 2476, 2376, 2282, 1964, 11, 2464, 2372, 2276, 2226, 2113, 7, 2468, 2370, 2220, 2464, 2066, 2109, 2070, 2224, 2107, 1848, 1960, 2222, 2109, 2222, 2226, 2109, 1733, 2278, 2222, 2113, 2019, 2011, 2222, 2109, 1733, 1848, 2011, 1733, 1848, 1848, 2011, 2011, 1962, 1845, 1733, 7, 2222, 2276, 2370, 2220, 2109, 1852, 1964, 1733, 1848, 1960, 1735, 1731, 1850, 1846, 1960, 2011, 1958, 2064, 2107, 2220, 2109, 2109, 2226, 2113, 2226, 2431, 2278, 2226, 11, 1850, 1958, 2064, 2109, 2220, 2329, 2222, 1846, 2107, 2220, 2278, 2222, 2111, 2219, 2328, 2372, 2335, 3, 2113, 2066, 1846, 1958, 2066, 2109, 2374, 2220, 2278, 2224, 2107, 2015, 7, 1850, 1958, 2011, 1848, 2464, 2372, 2462, 2370, 1958, 2276, 2372, 2464, 2464, 2468, 3, 2372, 2224, 2276, 2109, 1850, 2009, 3, 2466, 2370, 2329, 2374, 2220, 2107, 2226, 1960, 1585, 2370, 24

In [None]:
# Header information for the generated file: 4/4 time, key of C, tempo 500000.
# NOTE(review): ticks_per_beat is fixed at 480 here, while the token durations
# come from the dataset files in their own tick resolution - confirm that the
# dataset uses 480 ticks per beat, or take the value from extract_metadata().
metadata={
    "ticks_per_beat": 480,
    "time_signature": MetaMessage("time_signature", numerator=4, denominator=4),
    "key_signature": MetaMessage("key_signature", key="C"),
    "set_tempo": MetaMessage("set_tempo", tempo=500000)
}

# Convert the sampled IDs back to (pitch, duration) tokens and write them to a MIDI file.
decode_ids_to_midi(generated_ids, "generated_output.mid", metadata, id2token)