In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from dataclasses import dataclass
import numpy as np
import mido
from pathlib import Path
import os
from miditok import MIDILike, TokenizerConfig, REMI
from miditok.pytorch_data import DataCollator
import math
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR, CosineAnnealingWarmRestarts
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm
from miditok.pytorch_data import DatasetMIDI, DataCollator

In [2]:
@dataclass
class TransformerConfig:
    block_size: int = 2048
    vocab_size: int = 478
    n_layers: int = 6
    n_heads: int = 8
    n_embed: int = 512

In [None]:
TOKENIZER_CONFIG = TokenizerConfig(
    num_velocities=32,
    use_programs=False,
    use_time_signatures=False,
    use_chords=False,
    use_sustain_pedals=True,
)

remi_config = TokenizerConfig(
    use_programs=False,
    num_velocities=32,
    use_sustain_pedals=True,
    use_time_signatures=False,
    use_chords=False,
    use_tempos=True,
    use_rests=True,
    beat_res={(0, 4): 8, (4, 12): 16},
    use_bars=True
)

In [4]:
tokenizer = REMI(tokenizer_config=remi_config)

In [5]:
MIDI_DIR = Path("maestro-v3.0.0")
TOKEN_DIR = Path("tokens")

In [6]:
midi_paths = list(MIDI_DIR.glob("**/*.mid")) + list(MIDI_DIR.glob("**/*.midi"))
if midi_paths and not any(TOKEN_DIR.iterdir()):
    tokenizer.tokenize_dataset(midi_paths, TOKEN_DIR)

In [7]:
VOCAB_SIZE = len(tokenizer)
VOCAB_SIZE

478

In [None]:
class MidiDataset(Dataset):
    def __init__(self, tokens_dir: Path, block_size: int, tokenizer, stride=None, augment=True):
        self.block_size = block_size
        self.stride = stride if stride is not None else block_size
        self.tokenizer = tokenizer
        self.augment = augment

        # special tokens
        self.pad_id = tokenizer.special_tokens_ids[0]  # PAD_None
        self.bos_id = tokenizer.special_tokens_ids[1]  # BOS_None
        self.eos_id = tokenizer.special_tokens_ids[2]  # EOS_None
        self.mask_id = tokenizer.special_tokens_ids[3]

        self.samples = []

        for path in tokens_dir.glob("**/*.json"):
            list_of_tracks = tokenizer.load_tokens(path)
            for ids in list_of_tracks:
                ids = [self.bos_id] + list(ids) + [self.eos_id]

                for i in range(0, len(ids), self.stride):
                    chunk = ids[i : i + block_size + 1]
                    if len(chunk) < block_size + 1:
                        pad_chunk = torch.full((block_size + 1,), self.pad_id, dtype=torch.long)
                        pad_chunk[: len(chunk)] = torch.tensor(chunk, dtype=torch.long)
                        chunk = pad_chunk
                    else:
                        chunk = torch.tensor(chunk, dtype=torch.long)
                    self.samples.append(chunk)

        print(f"Loaded {len(self.samples)} chunks (stride={self.stride}) from {tokens_dir}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        chunk = self.samples[idx]

        if self.augment:
            chunk = self.apply_augmentation(chunk.clone())

        return chunk[:-1], chunk[1:]

    def apply_augmentation(self, chunk: torch.Tensor) -> torch.Tensor:
        tokens = [self.tokenizer[t] for t in chunk.tolist()]

        if np.random.rand() < 0.5:
            transpose = np.random.randint(-6, 7)
            tokens = self.transpose(tokens, transpose)

        if np.random.rand() < 0.3:
            jitter = np.random.randint(-2, 3)
            tokens = self.jitter_velocity(tokens, jitter)

        if np.random.rand() < 0.3:
            tokens = self.jitter_timing(tokens)

        return torch.tensor([self.tokenizer.vocab_model.get(t, self.tokenizer.special_tokens_ids[0]) for t in tokens], dtype=torch.long)

    def transpose(self, tokens, semitones):
        new_tokens = []
        for t in tokens:
            if t.startswith("Note-On") or t.startswith("Note-Off"):
                note_type, pitch = t.split("_")
                pitch = int(pitch)
                pitch = max(0, min(127, pitch + semitones))
                new_tokens.append(f"{note_type}_{pitch}")
            else:
                new_tokens.append(t)
        return new_tokens

    def jitter_velocity(self, tokens, jitter):
        new_tokens = []
        nb_velocities = self.tokenizer.config.num_velocities
        for t in tokens:
            if t.startswith("Velocity"):
                _, vel = t.split("_")
                vel = int(vel)
                vel = max(0, min(nb_velocities-1, vel + jitter))
                new_tokens.append(f"Velocity_{vel}")
            else:
                new_tokens.append(t)
        return new_tokens

    def jitter_timing(self, tokens):
        new_tokens = []
        for t in tokens:
            if t.startswith("Time-Shift"):
                _, shift = t.split("_")
                shift = int(shift)
                shift = max(1, shift + np.random.choice([-1, 0, 1]))
                new_tokens.append(f"Time-Shift_{shift}")
            else:
                new_tokens.append(t)
        return new_tokens


In [None]:
class RoPE(nn.Module):
    def __init__(self, head_dim: int, max_len: int):
        super().__init__()
        self.head_dim = head_dim

        theta = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        m = torch.arange(max_len)
        freqs = torch.outer(m, theta)
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)

    def forward(self, q, k):
        q_ = q.float().reshape(*q.shape[:-1], -1, 2)
        k_ = k.float().reshape(*k.shape[:-1], -1, 2)
        
        q_complex = torch.view_as_complex(q_)
        k_complex = torch.view_as_complex(k_)
        
        freqs_cis = self.freqs_cis[None, :q.shape[1], None, :]
        
        q_rotated = torch.view_as_real(q_complex * freqs_cis).flatten(3)
        k_rotated = torch.view_as_real(k_complex * freqs_cis).flatten(3)
        
        return q_rotated.type_as(q), k_rotated.type_as(k)

In [10]:
class SelfAttention(nn.Module):
    def __init__(self, config: TransformerConfig, rope: RoPE):
        super().__init__()

        self.qkv_proj = nn.Linear(config.n_embed, 3 * config.n_embed)
        self.o_proj = nn.Linear(config.n_embed, config.n_embed)
        self.o_proj.SCALE = 1

        self.d_model = config.n_embed
        self.n_heads = config.n_heads
        self.d_head = self.d_model // self.n_heads

        self.rope = rope

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.qkv_proj(x)
        q, k, v = qkv.split(self.d_model, dim=2)
        q = q.view(B, T, self.n_heads, self.d_head).transpose(1,2)
        k = k.view(B, T, self.n_heads, self.d_head).transpose(1,2)
        v = v.view(B, T, self.n_heads, self.d_head).transpose(1,2)

        q, k = self.rope(q, k)

        attn_score = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn_score = attn_score.transpose(1,2).contiguous().view(B, T, C)

        return self.o_proj(attn_score)

In [11]:
class MLP(nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()

        self.ff1 = nn.Linear(config.n_embed, 4*config.n_embed)
        self.gelu = nn.GELU(approximate='tanh')
        self.ff2 = nn.Linear(config.n_embed*4, config.n_embed)
        self.ff2.SCALE = 1

    def forward(self, x):
        x = self.ff1(x)
        x = self.gelu(x)
        x = self.ff2(x)
        return x

In [12]:
class Layer(nn.Module):
    def __init__(self, config: TransformerConfig, rope: RoPE):
        super().__init__()

        self.ln1 = nn.LayerNorm(config.n_embed)
        self.attn = SelfAttention(config, rope)
        self.ln2 = nn.LayerNorm(config.n_embed)
        self.ff = MLP(config)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        x = x + self.dropout(self.attn(self.ln1(x)))
        x = x + self.dropout(self.ff(self.ln2(x)))
        return x

In [None]:
class MusicTransformer(nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()

        self.config = config
        self.pad_id = tokenizer.special_tokens_ids[0]

        head_dim = config.n_embed // config.n_heads
        rope = RoPE(head_dim, config.block_size)

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embed),
            h = nn.ModuleList([Layer(config, rope) for _ in range(config.n_layers)]),
            ln_final = nn.LayerNorm(config.n_embed)
        ))

        self.lm_head = nn.Linear(config.n_embed, config.vocab_size)

        self.transformer.wte.weight = self.lm_head.weight

        self.apply(self.init_weights)

    def init_weights(self, module):
        std = 0.02
        if isinstance(module, nn.Linear):
            if hasattr(module, 'SCALE'):
                std *= (2 * self.config.n_layers) ** -0.5 
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
        

    def forward(self, x, targets = None):
        B, T = x.size()
        assert T <= self.config.block_size, "cannot exceed block size"

        x = self.transformer.wte(x)

        for block in self.transformer.h:
            x = block(x)

        x = self.transformer.ln_final(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=self.pad_id, label_smoothing=0.1)

        return logits, loss

In [14]:
BLOCK_SIZE = TransformerConfig.block_size
BLOCK_SIZE

2048

In [15]:
BATCH_SIZE = 8
EFFECTIVE_BATCH_SIZE = 32
accum_steps = EFFECTIVE_BATCH_SIZE // BATCH_SIZE

In [16]:
from torch.utils.data import random_split

dataset = MidiDataset(TOKEN_DIR, BLOCK_SIZE, tokenizer, stride=BLOCK_SIZE//2, augment=False)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

train_dataset, val_dataset = random_split(dataset, [train_size, test_size])

dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
test_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)

Loaded 28549 chunks (stride=1024) from tokens


In [17]:
torch.set_float32_matmul_precision("high")
device = torch.device("cuda")
model = MusicTransformer(config=TransformerConfig).to(device)
model = torch.compile(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.999), fused=True, weight_decay=0.01)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=15, T_mult=2, eta_min=1e-5)

In [18]:
@torch.no_grad()
def generate_sequence(model, tokenizer, device, prompt_tokens, max_new_tokens, temperature=0.8, top_k=50):
    model.eval()
    idx = torch.tensor([prompt_tokens], dtype=torch.long, device=device)
    
    for _ in tqdm(range(max_new_tokens), desc="Generating"):
        idx_cond = idx if idx.size(1) <= model.config.block_size else idx[:, -model.config.block_size:]
        logits, _ = model(idx_cond)
        logits = logits[:, -1, :] / temperature
        
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')
            
        probs = torch.nn.functional.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        
        if idx_next.item() == tokenizer['EOS_None']:
            break
            
        idx = torch.cat((idx, idx_next), dim=1)
        
    return idx[0].tolist()

In [19]:
def tokens_to_midi(tokens, tokenizer, output_midi_path="output(1).mid"):
    score_object = tokenizer.decode([tokens])
    output_path = Path(output_midi_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    score_object.dump_midi(str(output_path))

In [20]:
epochs = 200

In [21]:
scaler = torch.amp.GradScaler('cuda')

In [None]:
for epoch in range(epochs):
        total_loss = 0
        model.train()

        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)
        
        for i, (input_ids, labels) in enumerate(progress_bar):
            input_ids = input_ids.to(device)
            labels = labels.to(device)

            with torch.amp.autocast('cuda'):
                _, loss = model(input_ids, targets=labels)
                loss = loss / accum_steps

            scaler.scale(loss).backward()

            if not torch.isfinite(loss):
                print(f"[WARNING] Loss is nan/inf, skipping step")
                optimizer.zero_grad(set_to_none=True)
                continue

            if (i+1) % accum_steps == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
            
            total_loss += loss.item() * accum_steps
            progress_bar.set_postfix(loss=loss.item() * accum_steps)

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for (inputs, labels) in tqdm(test_dataloader):
                inputs = inputs.to(device)
                labels = labels.to(device)
                with torch.amp.autocast('cuda'):
                    _, loss = model(inputs, targets=labels)
                val_loss += loss.item()
    
        avg_train_loss = total_loss / len(dataloader)
        avg_val_loss = val_loss / len(test_dataloader)

        scheduler.step()

        if (epoch) % 15 == 0:
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
            }, f"checkpoint_epoch_{epoch+1}.pt")

        if (epoch+1) % 5 == 0 or epoch == 0:
            prompt = [tokenizer['BOS_None']]
            generated_tokens = generate_sequence(
                model,
                tokenizer,
                device,
                prompt_tokens=prompt,
                max_new_tokens=2048
            )
            tokens_to_midi(generated_tokens[1:], tokenizer, output_midi_path=f"output_{epoch+1}.mid")
        
        print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

In [None]:
@torch.no_grad()
def generate_sequence(model, tokenizer, device, prompt_tokens, max_new_tokens, temperature=0.8, top_k=50):
    model.eval()
    idx = torch.tensor([prompt_tokens], dtype=torch.long, device=device)
    
    for _ in tqdm(range(max_new_tokens), desc="Generating"):
        idx_cond = idx if idx.size(1) <= model.config.block_size else idx[:, -model.config.block_size:]
        logits, _ = model(idx_cond)
        logits = logits[:, -1, :] / temperature
        
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')
            
        probs = torch.nn.functional.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        
        if idx_next.item() == tokenizer['EOS_None']:
            break
            
        idx = torch.cat((idx, idx_next), dim=1)
        
    return idx[0].tolist()

In [23]:
def tokens_to_midi(tokens, tokenizer, output_midi_path="output_test(1).mid"):
    score_object = tokenizer.decode([tokens])
    output_path = Path(output_midi_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    score_object.dump_midi(str(output_path))