In [86]:
import os
from tqdm import tqdm
import pretty_midi
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import pandas as pd
import music21

## DATA ANALYSIS

In [None]:
csv_path="maestro-v3.0.0.csv",
midi_root_dir="maestro-v3.0.0"

## DATA PROCESSING

In [41]:
from torch.nn.utils.rnn import pad_sequence

class MusicDataset(Dataset):
    def __init__(self, sequences, max_seq_len=512):
        self.data = []
        for seq in sequences:
            if max_seq_len is not None:
                # Break long sequences into chunks of max_seq_len
                for i in range(0, len(seq), max_seq_len):
                    chunk = seq[i:i + max_seq_len]
                    if len(chunk) > 1:
                        self.data.append(torch.tensor(chunk, dtype=torch.long))
            else:
                self.data.append(torch.tensor(seq, dtype=torch.long))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch):
    batch = pad_sequence(batch, batch_first=False, padding_value=0)
    return batch


class NoteTokenizer:
    def __init__(self):
        self.token_to_id = {}
        self.id_to_token = {}

    def build_vocab(self, sequences):
        unique_tokens = sorted(set(token for seq in sequences for token in seq))
        self.token_to_id = {token: i+1 for i, token in enumerate(unique_tokens)}  # 0 = padding
        self.id_to_token = {i: token for token, i in self.token_to_id.items()}

    def encode(self, sequence):
        return [self.token_to_id[token] for token in sequence if token in self.token_to_id]

    def decode(self, ids):
        return [self.id_to_token[i] for i in ids if i in self.id_to_token]

    def vocab_size(self):
        return len(self.token_to_id) + 1  # add padding token

In [88]:
def detect_key_signature(midi_path):
        """Detect key signature using music21."""
        try:
            score = music21.converter.parse(midi_path)
            key_sig = score.analyze('key')
            return f"{key_sig.tonic.name}_{key_sig.mode}"
        except:
            return None

In [89]:
def notes_to_tokens(notes, key_sig, dur_step=0.05, time_step=0.05):
    """
    notes: List[(start, pitch, duration)]
    returns: List[str] of tokens
    """
    tokens = []
    prev_start = 0.0
    for start, pitch, dur in notes:
        dt = round((start - prev_start) / time_step) * time_step
        d = round(dur / dur_step) * dur_step
        prev_start = start

        tokens.append(f"TIME_SHIFT_{dt:.2f}")
        tokens.append(f"NOTE_ON_{pitch}")
        tokens.append(f"DURATION_{d:.2f}")
        tokens.append(f"KEY_{key_sig}")
    return tokens


def extract_notes_from_midi(midi_path):
    try:
        pm = pretty_midi.PrettyMIDI(midi_path)
        notes = []
        for instr in pm.instruments:
            if instr.is_drum: continue
            for n in instr.notes:
                notes.append((n.start, n.pitch, n.end - n.start))
        notes.sort(key=lambda x: x[0])
        return notes
    except Exception as e:
        print(f"Error loading {midi_path}: {e}")
        return []


def prepare_token_seqs(csv_path, midi_root_dir, tokenizer):
    df = pd.read_csv(csv_path)
    splits = {'train': [], 'validation': [], 'test': []}

    for split in splits:
        files = df[df['split'] == split]['midi_filename']
        for fname in files:
            path = os.path.join(midi_root_dir, fname)
            notes = extract_notes_from_midi(path)
            key_sig = detect_key_signature(path)
            if len(notes) > 20:
                splits[split].append((notes, key_sig))
    
    train_tok_seqs = [
        notes_to_tokens(notes, key_sig)
        for notes, key_sig in splits['train']
    ]
    tokenizer.build_vocab(train_tok_seqs)

    def encode_split(raw_seqs):
        out = []
        for notes in raw_seqs:
            toks = notes_to_tokens(notes)
            ids = tokenizer.encode(toks)
            if len(ids) > 1:
                out.append(ids)
        return out

    return (
        encode_split(splits['train']),
        encode_split(splits['validation']),
        encode_split(splits['test'])
    )


def make_loaders(train_seqs, val_seqs, test_seqs, batch_size=16, max_seq_len=512):
    train_ds = MusicDataset(train_seqs, max_seq_len)
    val_ds   = MusicDataset(val_seqs,   max_seq_len)
    test_ds  = MusicDataset(test_seqs,  max_seq_len)
    return (
        DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn),
        DataLoader(val_ds,   batch_size=batch_size, shuffle=False, collate_fn=collate_fn),
        DataLoader(test_ds,  batch_size=batch_size, shuffle=False, collate_fn=collate_fn),
    )


## MODEL

In [43]:
class SimpleMusicTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=64, nhead=2, num_layers=2, dim_ff=128, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, dropout)

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                                   dim_feedforward=dim_ff, dropout=dropout)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(self, src):
        # src: (seq_len, batch)
        x = self.token_embedding(src)  # (seq_len, batch, d_model)
        x = self.positional_encoding(x)  # (seq_len, batch, d_model)
        x = self.transformer(x)  # (seq_len, batch, d_model)
        return self.output_layer(x)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(1))  # (max_len, 1, d_model)

    def forward(self, x):
        # x: (seq_len, batch, d_model)
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [6]:


class MusicTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=6, dim_feedforward=512, dropout=0.1):
        super(MusicTransformer, self).__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # x: (seq_len, batch_size)
        x = self.token_embedding(x)  # (seq_len, batch_size, d_model)
        x = self.positional_encoding(x)
        x = self.transformer(x)
        logits = self.output_layer(x)
        return logits


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1)  # (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


In [7]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=2048):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div)
        pe[:, 1::2] = torch.cos(position * div)
        self.register_buffer('pe', pe.unsqueeze(1))  # (max_len, 1, d_model)

    def forward(self, x):  # x: (seq_len, batch, d_model)
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class ConformerBlock(nn.Module):
    def __init__(self, d_model, nhead, dim_ff, conv_kernel_size=31, dropout=0.1):
        super().__init__()
        # Feed-forward module 1
        self.ffn1 = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, dim_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_ff, d_model),
            nn.Dropout(dropout)
        )
        # Multi-head self-attention
        self.mha = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.ln_attn = nn.LayerNorm(d_model)
        # Convolution module
        self.conv = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Conv1d(d_model, 2 * d_model, kernel_size=1),
            nn.GLU(dim=1),
            nn.Conv1d(d_model, d_model, kernel_size=conv_kernel_size, padding=conv_kernel_size//2, groups=1),
            nn.BatchNorm1d(d_model),
            nn.GELU(),
            nn.Conv1d(d_model, d_model, kernel_size=1),
            nn.Dropout(dropout)
        )
        # Feed-forward module 2
        self.ffn2 = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, dim_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_ff, d_model),
            nn.Dropout(dropout)
        )
        # Final layer norm
        self.ln_final = nn.LayerNorm(d_model)

    def forward(self, x):  # x: (seq_len, batch, d_model)
        # FFN1 with residual
        x = x + 0.5 * self.ffn1(x)
        # Self-attention with residual
        residual = x
        x = self.ln_attn(x)
        x2, _ = self.mha(x, x, x)
        x = residual + x2
        # Convolution module with residual
        residual = x
        # reshape for conv: (batch, d_model, seq_len)
        x_conv = x.permute(1, 2, 0)
        x_conv = self.conv(x_conv)
        x = residual + x_conv.permute(2, 0, 1)
        # FFN2 with residual
        x = x + 0.5 * self.ffn2(x)
        # final norm
        return self.ln_final(x)

class MusicConformer(nn.Module):
    def __init__(self, vocab_size, d_model=128, nhead=4, num_layers=4, dim_ff=256,
                 conv_kernel_size=31, dropout=0.1, max_len=2048):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_enc = PositionalEncoding(d_model, dropout, max_len)
        self.layers = nn.ModuleList([
            ConformerBlock(d_model, nhead, dim_ff, conv_kernel_size, dropout)
            for _ in range(num_layers)
        ])
        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, src):  # src: (seq_len, batch)
        x = self.embedding(src)  # (seq_len, batch, d_model)
        x = self.pos_enc(x)
        for layer in self.layers:
            x = layer(x)
        return self.output(x)  # (seq_len, batch, vocab_size)


In [90]:
import copy

def train_model(model, train_loader, val_loader, device, epochs=10, lr=1e-4):
    model.to(device)
    loss_fn = nn.CrossEntropyLoss(ignore_index=0)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_model = copy.deepcopy(model)
    best_val_loss = 1e8
    for epoch in range(epochs):
        # train
        model.train()
        total_train_loss = 0
        for batch in tqdm(train_loader):
            batch = batch.to(device)
            src = batch[:-1] 
            tgt = batch[1:] 

            optimizer.zero_grad()
            output = model(src)
            loss = loss_fn(output.view(-1, output.size(-1)), tgt.view(-1))
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            #print(f"Epoch {epoch+1}/{epochs} | Batch {i}/{len(train_loader)} | Train Loss: {loss.item():.4f}")
        avg_train_loss = total_train_loss / len(train_loader)
        
        # validation
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                src = batch[:-1]
                tgt = batch[1:]

                output = model(src)
                loss = loss_fn(output.view(-1, output.size(-1)), tgt.view(-1))
                total_val_loss += loss.item()
        avg_val_loss = total_val_loss / len(val_loader)

        # save best model
        if avg_val_loss < best_val_loss:
            best_model = copy.deepcopy(model)

        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
    
    return best_model


## TRAINING PIPELINE

In [None]:
# DATA PROCESSING
print("Preparing data...")
tokenizer = NoteTokenizer()
train_ids, val_ids, test_ids = prepare_token_seqs(
    csv_path="maestro-v3.0.0.csv",
    midi_root_dir="maestro-v3.0.0",
    tokenizer=tokenizer
)

Preparing data...


In [None]:
# DATALOADERS
BATCH_SIZE = 128
MAX_SEQ_LEN = 128

TRAIN_SIZE = 100 # for faster training

train_loader, val_loader, test_loader = make_loaders(train_ids[:TRAIN_SIZE], val_ids, test_ids, max_seq_len=MAX_SEQ_LEN, batch_size=BATCH_SIZE)

print(tokenizer.vocab_size())

806


In [None]:
# MODEL TRAINING
EPOCHS = 10
LEARNING_RATE = 5e-4

print("Initializing model...")
model = SimpleMusicTransformer(vocab_size=tokenizer.vocab_size())
#model = MusicConformer(vocab_size=tokenizer.vocab_size())

print("Training...")
print(f"train size: {len(train_loader)}")
print(f"validation size: {len(val_loader)}")
print(f"test size: {len(test_loader)}")
model = train_model(model, train_loader, val_loader, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
            epochs=EPOCHS, lr=LEARNING_RATE)

Initializing model...
Training...
train size: 234
validation size: 118
test size: 137


  0%|          | 0/234 [00:00<?, ?it/s]

  3%|▎         | 7/234 [00:03<02:00,  1.88it/s]


KeyboardInterrupt: 

## DECODING

In [None]:
def decode_tokens(ids, tokenizer):
    return tokenizer.decode(ids)


def sample_from_model(model, seed_ids, length=100, device="cpu", temperature=1.0):
    """
    Generate sample from model using a seed.
    
    Args:
        model: model
        seed_ids: list of ids referring to tokens
        length: number of notes played
        device: device used to generate sample
        temperature: temperature for logits 
    
    Returns:
        List of generated tokens
    """
    model.eval()
    generated = seed_ids[:]
    input_seq = torch.tensor(generated, dtype=torch.long).unsqueeze(1).to(device)  # (seq_len, 1)

    for _ in range(length):
        with torch.no_grad():
            output = model(input_seq)
            logits = output[-1, 0] / temperature
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()

        generated.append(next_token)
        input_seq = torch.tensor(generated, dtype=torch.long).unsqueeze(1).to(device)

    return generated


In [77]:

def tokens_to_notes(tokens):
    """
    Convert token sequence back to notes format.
    
    Args:
        tokens: List of token strings
    
    Returns:
        List of (start_time, pitch, duration) tuples
    """
    notes = []
    current_time = 0.0
    i = 0
    
    while i < len(tokens):
        token = tokens[i]
        if token.startswith("TIME_SHIFT_"):
            try:
                time_shift = float(token.replace("TIME_SHIFT_", ""))
                current_time += time_shift
            except ValueError:
                pass
            i += 1
        elif token.startswith("NOTE_ON_") and i + 1 < len(tokens):
            try:
                pitch = int(token.replace("NOTE_ON_", ""))
                duration_token = tokens[i + 1]
                
                if duration_token.startswith("DURATION_"):
                    duration = float(duration_token.replace("DURATION_", ""))
                    
                    # Ensure valid ranges
                    if 0 <= pitch <= 127 and duration > 0:
                        notes.append((current_time, pitch, duration))
                    
                    i += 2
                else:
                    i += 1
            except (ValueError, IndexError):
                i += 1
        else:
            i += 1
    return notes


def notes_to_midi(notes , 
                  output_path , 
                  program = 0,
                  tempo = 120.0):
    """
    Convert notes to MIDI file using pretty_midi.
    
    Args:
        notes: List of (start_time, pitch, duration) tuples
        output_path: Path to save MIDI file
        program: MIDI program number (instrument)
        tempo: Tempo in BPM
    """
    midi = pretty_midi.PrettyMIDI(initial_tempo=tempo)
    instrument = pretty_midi.Instrument(program=program)
    
    for start_time, pitch, duration in notes:
        pitch = max(0, min(127, int(pitch)))
        
        note = pretty_midi.Note(
            velocity=80,
            pitch=pitch,
            start=float(start_time),
            end=float(start_time + duration)
        )
        instrument.notes.append(note)
    
    # Add instrument to MIDI
    midi.instruments.append(instrument)
    
    # Save MIDI file
    midi.write(output_path)
    print(f"MIDI file saved to: {output_path}")

In [None]:
# DECODING PARAMS
TEMPERATURE = 1.05

model.eval()
generated_ids = sample_from_model(model, seed_ids=[60,62,64], length=200, temperature=TEMPERATURE)
tokens = decode_tokens(generated_ids, tokenizer)
events = tokens_to_notes(tokens)
print(events[:10])
notes_to_midi(events, "generated.mid")

[(0.0, 66, 1.2), (0.0, 42, 1.2), (2.55, 73, 0.6), (3.3499999999999996, 61, 0.6), (3.8999999999999995, 61, 0.6), (4.449999999999999, 61, 0.6), (5.249999999999999, 66, 1.2), (5.799999999999999, 66, 1.2), (6.349999999999999, 66, 1.2), (6.899999999999999, 66, 1.2)]
MIDI file saved to: generated.mid
