In [None]:
import kagglehub

# Fetch the newest published version of the MAESTRO piano MIDI dataset and
# report where it was stored.
path = kagglehub.dataset_download(
    "kritanjalijain/maestropianomidi"
)

print(f"Path to dataset files: {path}")

Downloading from https://www.kaggle.com/api/v1/datasets/download/kritanjalijain/maestropianomidi?dataset_version_number=1...


100%|██████████| 55.8M/55.8M [00:00<00:00, 152MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/kritanjalijain/maestropianomidi/versions/1


In [None]:
pip install mido pretty_midi midi-neural-processor

Collecting mido
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Collecting pretty_midi
  Downloading pretty_midi-0.2.11.tar.gz (5.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m67.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting midi-neural-processor
  Downloading midi_neural_processor-1.0.3-py3-none-any.whl.metadata (3.0 kB)
Downloading mido-1.3.3-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading midi_neural_processor-1.0.3-py3-none-any.whl (5.6 kB)
Building wheels for collected packages: pretty_midi
  Building wheel for pretty_midi (setup.py) ... [?25l[?25hdone
  Created wheel for pretty_midi: filename=pretty_midi-0.2.11-py3-none-any.whl size=5595886 sha256=bc8a551086f8d237aad9d1a56d0217bc350f5de72eaabf909b1d663d3b7e14b5
  Stored in directory: /root/.cache/pip/wheels

In [None]:
import requests
import torch
import torch.nn as nn
from torch.nn import functional as F
import random
import numpy as np
import midi_neural_processor.processor as midi_tokenizer
import os

In [None]:
# Hyperparameters
vocab_size = 512   # size of the token embedding table and of the output layer
batch_size = 64    # number of independent sequences processed in parallel
block_size = 512   # the maximum context length for prediction
max_iters = 10000
eval_interval = 500
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 768
n_layer = 12
n_head = 8
dropout = 0.1

# Sparse attention parameters for MIDI
window_size = 32  # Local sliding window size
stride_size = 4   # Strided attention every 4 positions (beat level)
num_global_tokens = 4  # Number of global attention positions

seed = 555

print(device)
#----------------------------------------------------------

# Seed every RNG the notebook draws from. Batches are sampled with both Python's
# `random` (file choice) and torch (window offsets), so seeding torch alone does
# not make a run reproducible.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed);  # trailing ';' keeps the Generator repr out of the cell output

#----------------------------------------------------------

cuda


<torch._C.Generator at 0x7a37f3ef2cd0>

In [None]:
folder_path = '/kaggle/input/maestropianomidi'

# Pieces at or below this many tokens are dropped. It must stay above block_size:
# get_batch slices windows of block_size + 1 tokens out of a single piece.
min_tokens = 1000

data = []
skipped = []  # (path, error) for every file that could not be tokenized

# Recursively walk through all subdirectories. os.walk yields entries in
# arbitrary order, so sort directories (in place, which steers the walk) and
# files to make the train/val split below deterministic across runs.
for root, dirs, files in os.walk(folder_path):
    dirs.sort()
    for file in sorted(files):
        full_path = os.path.join(root, file)
        try:
            tokens_cur = midi_tokenizer.encode_midi(full_path)
            if len(tokens_cur) > min_tokens:
                data.append(torch.tensor(tokens_cur))
        except Exception as exc:
            # Best effort: non-MIDI or corrupt files are skipped, but recorded
            # instead of being silently swallowed. KeyboardInterrupt is no
            # longer caught, so the cell can be interrupted.
            skipped.append((full_path, repr(exc)))

print(f"Loaded {len(data)} pieces from {folder_path}; skipped {len(skipped)} unreadable files")
if len(data) < 2:
    raise RuntimeError(
        f"Need at least 2 usable MIDI files under {folder_path!r} to build a "
        f"train/val split, found {len(data)}"
    )

# Split to train and val (first 90% of pieces / last 10%)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    """Draw one batch of (input, target) windows from a single random piece.

    `split` selects the training pieces when it equals 'train', otherwise the
    validation pieces. All `batch_size` windows come from the same piece; the
    targets are the inputs shifted one token to the right. Both tensors are
    moved to `device` before being returned.
    """
    pieces = val_data if split != 'train' else train_data
    piece = pieces[random.randint(0, len(pieces)-1)]
    # Window start offsets; the upper bound leaves room for the shifted target.
    starts = torch.randint(len(piece)-block_size, (batch_size, ))
    inputs = torch.stack([piece[s:s+block_size] for s in starts])
    targets = torch.stack([piece[s+1:s+block_size+1] for s in starts])
    return inputs.to(device), targets.to(device)

@torch.no_grad()
def estimate_loss():
    """Average the loss over `eval_iters` random batches for each split.

    Switches the global `model` to eval mode while measuring and back to train
    mode afterwards. Returns a dict with keys 'train' and 'val' whose values are
    scalar tensors holding the mean loss.
    """
    model.eval()
    averages = {}
    for split in ('train', 'val'):
        batch_losses = []
        for _ in range(eval_iters):
            X, Y = get_batch(split)
            _, batch_loss = model(X, Y)
            batch_losses.append(batch_loss.item())
        averages[split] = torch.tensor(batch_losses).mean()
    model.train()
    return averages

In [None]:
# Number of training pieces (one token tensor per kept MIDI file) after the 90/10 split
len(train_data)

1145

In [None]:
def create_sparse_attention_mask(seq_len, window_size=32, stride_size=4, num_global=4, device=None):
    """
    Create sparse attention mask for MIDI generation:
    - Local sliding window for note-to-note relationships
    - Strided pattern for beat/measure relationships
    - Global tokens for long-range dependencies

    Args:
        seq_len: number of positions; the mask is (seq_len, seq_len).
        window_size: a query may attend to keys at most this many positions away.
        stride_size: every key whose index is a multiple of this is visible to
            all queries. Must be positive.
        num_global: the first `num_global` positions attend to, and are attended
            by, every position.
        device: optional device on which to build the mask (default: torch's
            default device, as before).

    Returns:
        Bool tensor of shape (seq_len, seq_len); entry [i, j] is True when query
        position i may attend to key position j. The result is always causal
        (j <= i), so the "future" half of every pattern above is discarded.

    Raises:
        ValueError: if `stride_size` is not positive.
    """
    if stride_size <= 0:
        raise ValueError(f"stride_size must be positive, got {stride_size}")

    # Built with broadcasting instead of a Python loop over rows: the mask is a
    # pure function of (i, j), so one vectorised expression yields the same
    # tensor without seq_len rounds of indexed writes.
    positions = torch.arange(seq_len, device=device)
    query_pos = positions.unsqueeze(1)  # (seq_len, 1)
    key_pos = positions.unsqueeze(0)    # (1, seq_len)

    # 1. Local sliding window (±window_size positions)
    local = (query_pos - key_pos).abs() <= window_size

    # 2. Strided attention (every stride_size positions)
    strided = (key_pos % stride_size) == 0

    # 3. Global tokens (first few positions attend to/from everything)
    global_tokens = (query_pos < num_global) | (key_pos < num_global)

    # Ensure causal masking (no future attention)
    causal = key_pos <= query_pos

    return (local | strided | global_tokens) & causal

In [None]:
# Rotary Position Embedding (RoPE) Implementation
def create_rope_cache(seq_len, dim, theta=10000.0, device='cpu'):
    """
    Precompute the cosine/sine tables used by rotary position embeddings.

    For every position p in [0, seq_len) and every even channel index c in
    [0, dim), the rotation angle is p / theta ** (c / dim).

    Returns a pair (cos, sin) of float32 tensors, each (seq_len, dim//2) for
    even `dim`, created on `device`.
    """
    positions = torch.arange(seq_len, device=device, dtype=torch.float32)
    exponents = torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim
    inverse_freqs = 1.0 / (theta ** exponents)

    # Outer product gives one angle per (position, frequency) pair.
    angles = torch.outer(positions, inverse_freqs)

    return torch.cos(angles), torch.sin(angles)

def apply_rope(x, cos_cached, sin_cached):
    """
    Apply rotary position embedding to input tensor x.

    Args:
        x: tensor of shape (..., seq_len, head_dim). The sequence axis must be
            the second-to-last one, e.g. (batch, seq_len, head_dim) or
            (batch, n_heads, seq_len, head_dim). A (batch, seq_len, n_heads,
            head_dim) layout is NOT supported: its heads axis would be read as
            the sequence axis.
        cos_cached, sin_cached: tables of shape (cache_len, head_dim // 2), as
            returned by create_rope_cache, with cache_len >= seq_len.

    Returns:
        Tensor with the same shape as x in which every (even, odd) channel pair
        has been rotated by the angle for its position.

    Raises:
        ValueError: if head_dim is odd, or the cache covers fewer positions than
            x has. Previously the sequence length was clamped to the cache
            length, which either failed later with an opaque broadcast error or,
            for a length-1 cache, silently applied position 0 to every token.
    """
    *batch_dims, seq_len, d = x.shape

    if d % 2 != 0:
        raise ValueError(f"RoPE needs an even feature dimension, got {d}")
    cache_len = min(cos_cached.shape[0], sin_cached.shape[0])
    if seq_len > cache_len:
        raise ValueError(
            f"RoPE cache covers {cache_len} positions but the input has {seq_len}"
        )

    # Get the cos/sin values for this sequence length
    cos = cos_cached[:seq_len]  # (seq_len, d//2)
    sin = sin_cached[:seq_len]  # (seq_len, d//2)

    # Separate even/odd channels; each (even, odd) pair is one 2-D rotation plane
    x1 = x[..., ::2]   # Even indices: (..., seq_len, d//2)
    x2 = x[..., 1::2]  # Odd indices: (..., seq_len, d//2)

    # Expand cos/sin with leading singleton dims so they broadcast over batch dims
    cos_expanded = cos.view(*([1] * len(batch_dims)), seq_len, -1)
    sin_expanded = sin.view(*([1] * len(batch_dims)), seq_len, -1)

    # Rotary transformation
    rotated_x1 = x1 * cos_expanded - x2 * sin_expanded
    rotated_x2 = x1 * sin_expanded + x2 * cos_expanded

    # Interleave back to the original channel order: stack pairs, then merge
    # the last two dims.
    return torch.stack([rotated_x1, rotated_x2], dim=-1).flatten(start_dim=-2)

In [None]:
class SparseRoPEHead(nn.Module):
    """Sparse attention head with RoPE (Rotary Position Embedding) for MIDI generation.

    Projects the input to query/key/value vectors of width `head_size`, rotates
    queries and keys with RoPE, and attends only where the sparse causal mask
    from create_sparse_attention_mask allows.
    """

    # Smallest RoPE table ever built, so short sequences do not trigger rebuilds.
    _MIN_ROPE_CACHE_LEN = 512

    def __init__(self, head_size):
        super().__init__()
        self.head_size = head_size
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)

        # The caches are plain attributes, not registered buffers: they stay out
        # of the state_dict, but are also NOT moved by .to()/.cuda(). They are
        # therefore validated against the current weight device on every lookup.
        self._mask_cache = {}  # (seq_len, device) -> bool mask

        # Initialize RoPE cache
        self._rope_cos_cache = None
        self._rope_sin_cache = None

        self.dropout = nn.Dropout(dropout)

    def get_sparse_mask(self, seq_len):
        """Get or create sparse attention mask for given sequence length"""
        target_device = self.key.weight.device
        cache_key = (seq_len, target_device)
        mask = self._mask_cache.get(cache_key)
        if mask is None:
            mask = create_sparse_attention_mask(
                seq_len,
                window_size=window_size,
                stride_size=stride_size,
                num_global=num_global_tokens
            ).to(target_device)
            self._mask_cache[cache_key] = mask
        return mask

    def get_rope_cache(self, seq_len):
        """Get or create RoPE cache for given sequence length"""
        target_device = self.key.weight.device
        needs_rebuild = (
            self._rope_cos_cache is None
            or self._rope_sin_cache is None
            or self._rope_cos_cache.shape[0] < seq_len
            # Rebuild if the module was moved after the tables were created.
            or self._rope_cos_cache.device != target_device
        )
        if needs_rebuild:
            self._rope_cos_cache, self._rope_sin_cache = create_rope_cache(
                max(seq_len, self._MIN_ROPE_CACHE_LEN),  # Cache a bit more for efficiency
                self.head_size,
                device=target_device
            )

        return self._rope_cos_cache, self._rope_sin_cache

    def forward(self, x):
        """Attend over x of shape (B, T, n_embd); returns (B, T, head_size)."""
        B, T, C = x.shape
        k = self.key(x)  # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)

        # Apply RoPE to queries and keys
        cos_cached, sin_cached = self.get_rope_cache(T)
        q = apply_rope(q, cos_cached, sin_cached)
        k = apply_rope(k, cos_cached, sin_cached)

        # Compute attention scores
        wei = q @ k.transpose(-2, -1) * (self.head_size ** -0.5)  # (B, T, T)

        # Apply sparse attention mask. Every row keeps at least its own position
        # (the local window includes the diagonal), so no row is all -inf.
        sparse_mask = self.get_sparse_mask(T)
        wei = wei.masked_fill(~sparse_mask, float('-inf'))

        # Softmax and dropout
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        # Weighted aggregation (no RoPE needed for values)
        out = wei @ v
        return out

class SparseHead(nn.Module):
    """Original Sparse attention head (without RoPE) for comparison.

    Same sparse causal masking as the RoPE variant, but queries and keys are
    used as projected, so position information must come from elsewhere.
    """

    def __init__(self, head_size):
        super().__init__()
        self.head_size = head_size
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)

        # The cache is a plain attribute, not a registered buffer: it stays out
        # of the state_dict, but is also NOT moved by .to()/.cuda(). Keying on
        # the device as well as the length prevents handing back a mask that
        # lives on the device the module used to be on.
        self._mask_cache = {}  # (seq_len, device) -> bool mask

        self.dropout = nn.Dropout(dropout)

    def get_sparse_mask(self, seq_len):
        """Get or create sparse attention mask for given sequence length"""
        target_device = self.key.weight.device
        cache_key = (seq_len, target_device)
        mask = self._mask_cache.get(cache_key)
        if mask is None:
            mask = create_sparse_attention_mask(
                seq_len,
                window_size=window_size,
                stride_size=stride_size,
                num_global=num_global_tokens
            ).to(target_device)
            self._mask_cache[cache_key] = mask
        return mask

    def forward(self, x):
        """Attend over x of shape (B, T, n_embd); returns (B, T, head_size)."""
        B, T, C = x.shape
        k = self.key(x)  # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)

        # Compute attention scores
        wei = q @ k.transpose(-2, -1) * (self.head_size ** -0.5)  # (B, T, T)

        # Apply sparse attention mask
        sparse_mask = self.get_sparse_mask(T)
        wei = wei.masked_fill(~sparse_mask, float('-inf'))

        # Softmax and dropout
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        # Weighted aggregation
        out = wei @ v
        return out

class Head(nn.Module):
    """Original dense attention head for comparison.

    Plain causal self-attention: every position attends to itself and to all
    earlier positions, up to `block_size` tokens.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular causal mask; a buffer so it follows the module across devices.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over x of shape (B, T, n_embd) with T <= block_size; returns (B, T, head_size)."""
        B, T, C = x.shape
        k = self.key(x) # (B, T, head_size)
        q = self.query(x) # (B, T, head_size)

        # Compute attention scores, scaled by 1/sqrt(head_size). The scale must
        # use the width of q/k, not the input width C (= n_embd): q.k sums
        # head_size products, so scaling by C shrinks the logits by a further
        # sqrt(n_head) and disagrees with the sparse heads.
        # NOTE: weights trained with the previous C**-0.5 scaling will see
        # different logits after this change.
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) ---> (B, T, T)
        # Remove future tokens so positions cannot communicate with them
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        # Softmax to get values that sum up to 1 - normalization
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        # Perform weighted aggregation of values
        v = self.value(x)
        out = wei @ v
        return out

class MultiHeadAttention(nn.Module):
    """Runs several attention heads in parallel and projects their concatenation.

    Args:
        num_heads: number of heads.
        head_size: output width of each head.
        use_sparse: use the sparse-mask heads instead of dense causal attention.
        use_rope: with `use_sparse`, use the RoPE head variant. Ignored when
            `use_sparse` is False (the dense head has no RoPE variant).
    """

    def __init__(self, num_heads, head_size, use_sparse=True, use_rope=True):
        super().__init__()

        # Choose attention head type based on configuration
        if use_sparse and use_rope:
            HeadClass = SparseRoPEHead
        elif use_sparse:
            HeadClass = SparseHead
        else:
            HeadClass = Head

        self.heads = nn.ModuleList([HeadClass(head_size) for _ in range(num_heads)])
        # The concatenated heads are num_heads * head_size wide. That equals
        # n_embd only when n_embd is divisible by num_heads, so size the
        # projection from the actual width instead of assuming it.
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map (B, T, n_embd) to (B, T, n_embd)."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedForward(nn.Module):
    """Position-wise MLP: widen to 4x, apply ReLU, project back, then dropout."""

    def __init__(self, n_embd):
        super().__init__()
        hidden_width = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """Pre-norm Transformer block: self-attention, then a feed-forward MLP, each wrapped in a residual connection."""

    def __init__(self, n_embd, n_head, use_sparse=True, use_rope=True):
        super().__init__()
        self.sa = MultiHeadAttention(
            n_head,
            n_embd // n_head,  # per-head width
            use_sparse=use_sparse,
            use_rope=use_rope,
        )
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # Each sub-layer sees a normalised copy of its input and adds its
        # output back onto the residual stream.
        after_attention = x + self.sa(self.ln1(x))
        return after_attention + self.ffwd(self.ln2(after_attention))

# Transformer Model with RoPE
class TransformerModel(nn.Module):
    """Decoder-only Transformer language model over MIDI event tokens.

    Args:
        use_sparse: use sparse-mask attention heads instead of dense causal ones.
        use_rope: encode position with rotary embeddings inside attention. When
            False, a learned absolute position embedding is added to the token
            embedding instead.

    The attribute names below (including their historical spellings) are
    state_dict keys; renaming them would break loading of existing checkpoints.
    """

    def __init__(self, use_sparse=True, use_rope=True):
        super().__init__()
        self.use_rope = use_rope

        # Token embeddings
        self.token_enbedding_table = nn.Embedding(vocab_size, n_embd)

        # Positional embeddings (only used if not using RoPE)
        if not use_rope:
            self.positiion_embedding_table = nn.Embedding(block_size, n_embd)
        else:
            self.positiion_embedding_table = None

        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head, use_sparse=use_sparse, use_rope=use_rope) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # Final norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: (B, T) integer token ids.
            targets: optional (B, T) integer token ids to predict.

        Returns:
            (logits, loss). Without targets, logits is (B, T, vocab_size) and
            loss is None. With targets, logits is flattened to
            (B*T, vocab_size) and loss is the mean cross-entropy.
        """
        B, T = idx.shape

        # Token embeddings
        token_emb = self.token_enbedding_table(idx) # (B,T,C)

        # Add positional embeddings only if not using RoPE
        if self.use_rope:
            x = token_emb  # RoPE handles position encoding in attention
        else:
            # Build the position ids on the input's device rather than the
            # notebook-global `device`, so the model also works after being
            # moved elsewhere.
            pos_emb = self.positiion_embedding_table(torch.arange(T, device=idx.device)) # (T,C)
            x = token_emb + pos_emb # (B,T,C)

        x = self.blocks(x) # apply self attention
        x = self.ln_f(x)
        logits = self.lm_head(x) # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx=None, max_new_tokens=256, temperature=1.0, top_k=None, top_p=0.9):
        """
        Autoregressively sample MIDI tokens.

        Args:
            idx: Initial sequence, shape (T,) or (B, T) (optional). If None, a
                fixed 7-token seed sequence is used.
            max_new_tokens: Number of tokens to generate
            temperature: Sampling temperature (higher = more random); must be > 0
            top_k: Keep only top k tokens for sampling (None disables)
            top_p: Nucleus sampling threshold (None or >= 1.0 disables)

        Returns:
            (B, T + max_new_tokens) tensor: the prompt followed by the samples.

        Raises:
            ValueError: if temperature is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        # Follow the model's own device instead of the notebook-global `device`.
        model_device = next(self.parameters()).device

        if idx is None:
            # Default seed sequence.
            # NOTE(review): what these ids mean depends on the tokenizer's
            # vocabulary; the labels are the original author's and should be
            # confirmed against midi_neural_processor before relying on them.
            start_tokens = [
                1,   # labelled "start of sequence"
                64,  # labelled "common tempo"
                32,  # labelled "time signature"
                60,  # labelled "middle C"
                65, 67, 69,  # labelled "C major chord notes"
            ]
            idx = torch.tensor(start_tokens, device=model_device).unsqueeze(0)  # (1, start_length)

        # Ensure idx is on the correct device and has batch dimension
        if idx.dim() == 1:
            idx = idx.unsqueeze(0)
        idx = idx.to(model_device)

        # Sample in eval mode (dropout off), then restore whichever mode the
        # caller had set, rather than unconditionally switching to train mode.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(max_new_tokens):
                    # Crop context if it exceeds block_size
                    idx_cond = idx[:, -block_size:]

                    # Get predictions
                    logits, _ = self(idx_cond)

                    # Focus only on the last time step
                    logits = logits[:, -1, :] / temperature  # (B, C)

                    # Apply top-k filtering: drop everything below the k-th best logit
                    if top_k is not None:
                        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                        logits[logits < v[:, [-1]]] = -float('inf')

                    # Apply top-p (nucleus) filtering
                    if top_p is not None and top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                        # Remove tokens with cumulative probability above the threshold
                        sorted_indices_to_remove = cumulative_probs > top_p
                        # Shift right so the first token crossing the threshold is
                        # kept; the best token is therefore never removed.
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = 0

                        # Scatter sorted tensors to original indexing
                        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                        logits[indices_to_remove] = -float('inf')

                    # Convert to probabilities
                    probs = F.softmax(logits, dim=-1)

                    # Sample from the distribution
                    idx_next = torch.multinomial(probs, num_samples=1)

                    # Append to the sequence
                    idx = torch.cat((idx, idx_next), dim=1)
        finally:
            self.train(was_training)
        return idx

In [None]:
# Create model with sparse attention and rotary position embeddings
model = TransformerModel(use_sparse=True, use_rope=True)
# `m` and `model` refer to the same module object (Module.to returns the module itself)
m = model.to(device)
best_metric = float('inf')  # Lowest validation loss seen so far; starts at +inf so the first evaluation always wins
best_model_state = None  # Checkpoint associated with best_metric; filled in by the training loop

# Optimizer plus gradient scaler used by the mixed-precision training step
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, weight_decay=0.001)
# NOTE(review): loss scaling is normally needed for float16; the training loop
# autocasts to bfloat16, so confirm whether the scaler is actually required.
scaler = torch.GradScaler()

# Cosine learning-rate schedule with warm restarts: the first cycle lasts
# T_0 = 1000 scheduler steps and each following cycle is T_mult = 2 times
# longer; the rate decays towards eta_min within a cycle and jumps back up at
# each restart.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=1000,
    T_mult=2,
    eta_min=1e-6
)

# Count trainable parameters
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {total_params:,}")

Total trainable parameters: 85,815,296


In [None]:
# Training loop: every `eval_interval` steps estimate train/val loss and keep a
# snapshot of the weights with the lowest validation loss; otherwise take one
# mixed-precision optimisation step per iteration.
for step in range(max_iters):  # `step`, not `iter`, to avoid shadowing the builtin

    if step % eval_interval == 0:
        losses = estimate_loss()
        val_loss = float(losses['val'])
        if val_loss < best_metric:
            best_metric = val_loss
            # state_dict() hands back references to the live parameters and
            # dict.copy() is shallow, so a plain copy keeps tracking training
            # and ends up holding the FINAL weights. Clone every tensor to
            # freeze a real snapshot; it is kept on the CPU so it does not
            # occupy accelerator memory for the rest of the run.
            best_model_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch('train')
    optimizer.zero_grad()

    # Autocast on whichever device was selected instead of assuming CUDA.
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        logits, loss = m(xb, yb)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    scheduler.step()
#-----------------------------------------------------------------------

step 0: train loss 6.3803, val loss 6.3864
step 500: train loss 3.6963, val loss 3.7382
step 1000: train loss 3.2139, val loss 3.2182
step 1500: train loss 2.8565, val loss 2.8594
step 2000: train loss 2.5657, val loss 2.6215
step 2500: train loss 2.4466, val loss 2.5107
step 3000: train loss 2.3973, val loss 2.4850
step 3500: train loss 2.4105, val loss 2.4634
step 4000: train loss 2.3303, val loss 2.4135
step 4500: train loss 2.2319, val loss 2.3577
step 5000: train loss 2.1753, val loss 2.3039
step 5500: train loss 2.1036, val loss 2.2530
step 6000: train loss 2.0514, val loss 2.2427
step 6500: train loss 2.0422, val loss 2.2109
step 7000: train loss 2.0420, val loss 2.2050
step 7500: train loss 2.1124, val loss 2.2887
step 8000: train loss 2.1193, val loss 2.2679
step 8500: train loss 2.0359, val loss 2.2173
step 9000: train loss 2.0522, val loss 2.2558
step 9500: train loss 1.9937, val loss 2.1891


In [None]:
# Write the checkpoint held in `best_model_state` to the current working directory
torch.save(best_model_state, 'best-midi-classical-sound.pth')

In [None]:
# Reload the checkpoint from the same relative path the save cell writes to,
# rather than an absolute /content/... path that only exists on Colab.
# map_location lets a checkpoint saved on one device load on another.
model.load_state_dict(torch.load("best-midi-classical-sound.pth", map_location=device))

<All keys matched successfully>

In [None]:
# Primed generation: use the first 256 tokens of the first validation piece as
# the prompt and sample 256 further tokens with the default sampling settings.
start = val_data[0][:256].unsqueeze(0).to(device)
tokens = model.generate(start, 256)
# Cap token ids at 383 before decoding.
# NOTE(review): presumably this keeps ids inside the range the tokenizer can
# decode (the model's vocab_size is 512) — confirm the exact upper bound.
tokens = torch.clamp(tokens, max=383)

# Drop the batch dimension and hand a plain list of ids to the decoder,
# with 'output.mid' given as the destination path.
tokens = tokens.squeeze(0).tolist()
midi_tokenizer.decode_midi(tokens, 'output.mid')

<pretty_midi.pretty_midi.PrettyMIDI at 0x7ab271ae8d10>

In [None]:
# Unprimed generation: no prompt is passed, so generate() falls back to its
# built-in seed tokens; sampling is hotter (temperature 1.2) with top-k and
# nucleus filtering applied.
tokens = model.generate(max_new_tokens=512, temperature=1.2, top_k=80, top_p=0.95)

In [None]:
# Cap token ids at 383, drop the batch dimension and decode to a MIDI file.
# NOTE(review): this passes the same 'output.mid' path as the primed-generation
# cell, so it presumably replaces that file — confirm that is intended.
tokens = torch.clamp(tokens, max=383)
tokens = tokens.squeeze(0).tolist()
midi_tokenizer.decode_midi(tokens, 'output.mid')

info removed pitch: 84
info removed pitch: 41
info removed pitch: 59
info removed pitch: 52
info removed pitch: 81
info removed pitch: 76
info removed pitch: 77
info removed pitch: 77
info removed pitch: 73


<pretty_midi.pretty_midi.PrettyMIDI at 0x7a37f2c64590>