# Piano MIDI Generation - Complete Pipeline (Kaggle)

This notebook contains the complete pipeline for training a Transformer model on the ARIA MIDI dataset:

1. **Preprocessing** - Metadata analysis, tokenization, dataset creation
2. **Model Architecture** - GPT-style decoder-only transformer
3. **Training** - Full training loop with validation and checkpointing

## Optimized for Kaggle P100 GPU
- Small micro-batches (4) with gradient accumulation (effective batch size 16) so 2048-token sequences fit in 16 GB
- Memory-efficient data loading
- Full preprocessing + training pipeline

## Dataset Location
- Input: `/kaggle/input/aria-midi-v1-deduped-ext`
- Output: `/kaggle/working/` (checkpoints, processed data)


## Step 1: Setup and Imports


In [1]:
# Install required packages (Kaggle has most packages, but just in case)
import sys
!{sys.executable} -m pip install mido pretty_midi tqdm -q


'c:\Users\Vikas' is not recognized as an internal or external command,
operable program or batch file.


In [None]:
# Import all libraries
import json
import math
import os
import random
import re
import unicodedata
from collections import Counter, defaultdict
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import mido
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Reaching this line means every import above succeeded.
print("‚úÖ All libraries imported")


In [None]:
# Kaggle paths: read-only competition input, writable working area.
INPUT_DIR = Path("/kaggle/input/aria-midi-v1-deduped-ext")
WORKING_DIR = Path("/kaggle/working")

# Make the working directory first, then its subfolders for processed data and checkpoints.
WORKING_DIR.mkdir(exist_ok=True)
for subdir_name in ("processed_data", "checkpoints"):
    (WORKING_DIR / subdir_name).mkdir(exist_ok=True)

print(f"‚úÖ Input directory: {INPUT_DIR}")
print(f"‚úÖ Working directory: {WORKING_DIR}")
print(f"‚úÖ Dataset exists: {INPUT_DIR.exists()}")


In [None]:
# Check GPU availability and pick the device used by the rest of the notebook.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

has_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if has_cuda else 'cpu')

if not has_cuda:
    print("‚ö†Ô∏è  Using CPU (CUDA not available)")
else:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9  # bytes -> GB
    print(f"‚úÖ Using GPU: {gpu_name}")
    print(f"   Total memory: {gpu_memory:.2f} GB")


## Step 2: Model Architecture

Define the Transformer model classes inline (from model.py)


In [None]:
# Model Architecture (from model.py)

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to token embeddings.

    The table is stored as a non-trainable buffer ``pe`` of shape
    (1, max_len, d_model): even feature columns hold sin, odd columns cos.

    Args:
        d_model: embedding width (odd widths are supported).
        max_len: longest sequence the table covers.
    """
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are ceil(d/2) even columns but only floor(d/2) odd ones; slicing
        # div_term keeps odd d_model from failing with a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the encoding for the first ``x.size(1)`` positions to ``x`` (batch-first)."""
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]


class TransformerBlock(nn.Module):
    """Single post-LayerNorm transformer decoder block: self-attention then feed-forward.

    Activations are passed through ``torch.nan_to_num`` (nan->0, +inf->1, -inf->-1)
    at every stage where a non-finite value appears.
    """
    def __init__(self, d_model, n_heads, d_ff, dropout):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _sanitize(t):
        """Replace NaN/Inf in ``t``; a single isfinite().all() costs one host sync instead of two."""
        if not torch.isfinite(t).all():
            t = torch.nan_to_num(t, nan=0.0, posinf=1.0, neginf=-1.0)
        return t

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        """Apply the block to batch-first ``x``.

        Args:
            x: (batch, seq, d_model) activations.
            attn_mask: boolean (seq, seq) mask, True = do not attend (e.g. causal).
            key_padding_mask: boolean (batch, seq) mask, True = padding key.
        """
        x = self._sanitize(x)

        # need_weights=False: the averaged attention map was discarded anyway, and
        # not requesting it lets PyTorch use the fused scaled_dot_product_attention
        # path instead of materialising full attention weights (memory + speed).
        attn_out, _ = self.attention(
            x, x, x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        attn_out = self._sanitize(attn_out)

        x = self._sanitize(self.norm1(x + self.dropout(attn_out)))
        ff_out = self._sanitize(self.ff(x))
        x = self._sanitize(self.norm2(x + ff_out))
        return x


class PianoMIDIGenerator(nn.Module):
    """GPT-style decoder-only transformer for conditional MIDI generation.

    ``config`` keys read here: vocab_size, d_model, max_seq_length, n_heads,
    d_ff, dropout, n_layers.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.d_model = config['d_model']

        self.embedding = nn.Embedding(config['vocab_size'], config['d_model'])
        self.pos_encoding = PositionalEncoding(config['d_model'], config['max_seq_length'])

        self.blocks = nn.ModuleList([
            TransformerBlock(config['d_model'], config['n_heads'], config['d_ff'], config['dropout'])
            for _ in range(config['n_layers'])
        ])

        self.ln_f = nn.LayerNorm(config['d_model'])
        self.head = nn.Linear(config['d_model'], config['vocab_size'], bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Small-scale init: Xavier-uniform (gain 0.02) for Linear, N(0, 0.01) for Embedding."""
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight, gain=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.01)

    @staticmethod
    def _sanitize(t):
        """Replace NaN/Inf (nan->0, +inf->1, -inf->-1); one host sync instead of two."""
        if not torch.isfinite(t).all():
            t = torch.nan_to_num(t, nan=0.0, posinf=1.0, neginf=-1.0)
        return t

    def forward(self, input_ids, attention_mask=None):
        """Return next-token logits of shape (batch, seq, vocab_size).

        Args:
            input_ids: (batch, seq) token ids.
            attention_mask: optional (batch, seq); 0 marks padding keys to ignore.
        """
        seq_len = input_ids.size(1)

        x = self._sanitize(self.embedding(input_ids))
        x = self._sanitize(self.pos_encoding(x))

        # Causal mask: True = masked (don't attend to future positions)
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=input_ids.device), diagonal=1).bool()
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None

        for block in self.blocks:
            x = self._sanitize(block(x, attn_mask=causal_mask, key_padding_mask=key_padding_mask))

        x = self._sanitize(self.ln_f(x))

        logits = self.head(x)
        if not torch.isfinite(logits).all():
            # Only when something is non-finite: bound all logits to [-50, 50].
            logits = torch.clamp(logits, min=-50.0, max=50.0)
            logits = torch.nan_to_num(logits, nan=0.0, posinf=50.0, neginf=-50.0)

        return logits

print("‚úÖ Model architecture defined")


In [None]:
# Load metadata: maps entry ID -> {'metadata': {...}, 'audio_scores': {...}} as consumed below.
metadata_path = INPUT_DIR / "metadata.json"

print(f"Loading metadata from: {metadata_path}")
metadata = json.loads(metadata_path.read_text(encoding='utf-8'))

print(f"‚úÖ Loaded {len(metadata):,} entries")
print(f"Sample entry keys: {list(metadata)[:5]}")


In [None]:
# Metadata tokenizer
class MetadataTokenizer:
    """Turn an entry's metadata dict into conditioning tokens.

    Emits (in order) an optional "START", then "GENRE:<g>", "PERIOD:<p>" and
    "COMPOSER:<c>" for values in the fixed whitelists below.
    """

    # Candidate composers in the order `top_n_composers` truncates them.
    _DEFAULT_TOP_COMPOSERS = (
        'hisaishi', 'satie', 'yiruma', 'einaudi', 'joplin', 'chopin', 'beethoven', 'bach', 'mozart', 'debussy',
        'schubert', 'schumann', 'liszt', 'rachmaninoff', 'tchaikovsky', 'ravel', 'poulenc', 'faure', 'bartok',
    )

    def __init__(self, include_composer=True, top_n_composers=100):
        self.include_composer = include_composer
        self.valid_genres = {'classical', 'pop', 'soundtrack', 'jazz', 'rock', 'folk', 'ambient', 'ragtime', 'blues', 'atonal'}
        self.valid_periods = {'contemporary', 'modern', 'romantic', 'classical', 'baroque', 'impressionist'}
        self.top_composers = self._load_top_composers(top_n_composers)

    def _load_top_composers(self, n):
        """Return the normalized set of the first ``n`` whitelisted composers."""
        return {self._normalize_composer(c) for c in self._DEFAULT_TOP_COMPOSERS[:n]}

    def _normalize_composer(self, composer):
        """Lower-case, strip accents, keep [a-z0-9 -], and collapse whitespace.

        Accents are removed via Unicode NFKD decomposition (dropping combining
        marks), so 'Fauré' -> 'faure' and 'Bartók' -> 'bartok'.
        """
        if not composer:
            return ""
        decomposed = unicodedata.normalize('NFKD', composer)
        folded = ''.join(ch for ch in decomposed if not unicodedata.combining(ch))
        normalized = folded.lower().strip()
        normalized = re.sub(r'[^a-z0-9\s-]', '', normalized)
        normalized = re.sub(r'\s+', ' ', normalized).strip()
        return normalized

    def metadata_to_tokens(self, metadata, include_start=True):
        """Build the conditioning prefix for one metadata dict.

        Args:
            metadata: dict possibly containing 'genre', 'music_period', 'composer'.
            include_start: prepend the "START" token.
        Returns:
            List of token strings; values outside the whitelists are skipped.
        """
        tokens = []
        if include_start:
            # NOTE(review): the vocabulary reserves '<START>' (id 2) but this emits
            # "START", which therefore becomes a separate vocab entry.
            tokens.append("START")

        if metadata.get('genre'):
            genre = metadata['genre'].lower().strip()
            if genre in self.valid_genres:
                tokens.append(f"GENRE:{genre}")

        if metadata.get('music_period'):
            period = metadata['music_period'].lower().strip()
            if period in self.valid_periods:
                tokens.append(f"PERIOD:{period}")

        if self.include_composer and metadata.get('composer'):
            composer = self._normalize_composer(metadata['composer'])
            if composer in self.top_composers:
                tokens.append(f"COMPOSER:{composer}")

        return tokens

meta_tokenizer = MetadataTokenizer(include_composer=True)
print("‚úÖ Metadata tokenizer created")


In [None]:
# MIDI tokenizer
class MIDITokenizer:
    """Event tokenizer: MIDI file -> TIME_SHIFT / NOTE_ON / VELOCITY / NOTE_OFF strings.

    Time is in MIDI ticks, floored to multiples of ``time_quantization``.
    """
    def __init__(self, time_quantization=10):
        self.time_quantization = time_quantization

    def midi_to_tokens(self, midi_path: Path) -> List[str]:
        """Tokenize a MIDI file; returns [] if the file cannot be read or parsed."""
        try:
            mid = mido.MidiFile(midi_path)
            # merge_tracks interleaves all tracks in playback order (delta times
            # preserved) instead of appending track 2's events after track 1's.
            return self._messages_to_tokens(mido.merge_tracks(mid.tracks))
        except Exception:
            # Deliberate best effort: unreadable files yield [] and the caller skips them.
            return []

    def _messages_to_tokens(self, messages) -> List[str]:
        """Convert delta-timed messages (objects with .type/.time/.note/.velocity) to tokens."""
        step = self.time_quantization
        tokens = []
        pending = 0  # ticks elapsed since the last emitted TIME_SHIFT boundary

        for msg in messages:
            pending += int(msg.time)

            is_on = msg.type == 'note_on' and msg.velocity > 0
            is_off = msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0)
            if not (is_on or is_off):
                continue  # other events only advance the clock

            shift = (pending // step) * step
            if shift > 0:
                tokens.append(f"TIME_SHIFT:{shift}")
            if is_on:
                tokens.append(f"NOTE_ON:{msg.note}")
                tokens.append(f"VELOCITY:{msg.velocity}")
            else:
                tokens.append(f"NOTE_OFF:{msg.note}")
            # Carry the sub-quantum remainder forward; resetting to 0 made the
            # timing drift later by up to step-1 ticks per event.
            pending -= shift

        return tokens

midi_tokenizer = MIDITokenizer(time_quantization=10)
print("‚úÖ MIDI tokenizer created")


In [None]:
# Find MIDI file helper
def find_midi_file(file_id: str, audio_index: str, data_root: Path) -> Optional[Path]:
    """Locate ``<file_id zero-padded to 6>_<audio_index>.mid`` under ``data_root``.

    Only immediate subdirectories with two-character names are searched, in
    ``iterdir`` order; the first existing match is returned, else None.
    """
    padded_id = file_id.zfill(6)
    target_name = f"{padded_id}_{audio_index}.mid"
    shard_dirs = (d for d in data_root.iterdir() if d.is_dir() and len(d.name) == 2)
    return next((d / target_name for d in shard_dirs if (d / target_name).exists()), None)

print("‚úÖ File finder helper defined")


## Step 4: Dataset Processing - Balanced Sampling and Tokenization


In [None]:
# Preprocessing configuration
PREPROCESS_CONFIG = dict(
    min_quality_score=0.97,        # entries whose best audio score is below this are skipped
    max_sequence_length=2048,      # cap on metadata + MIDI + "<END>" tokens per sequence
    time_quantization=10,          # NOTE(review): not read; midi_tokenizer is built with a literal 10
    data_root=INPUT_DIR / "data",
)

# Balanced sampling config
SAMPLING_CONFIG = dict(
    max_per_composer=500,
    max_empty_metadata_ratio=0.05,  # NOTE(review): not read by create_balanced_sample
    composer_strategy='balanced',   # NOTE(review): not read by create_balanced_sample
)

def analyze_metadata_distribution(metadata_dict):
    """Bucket quality-filtered entries by composer / genre.

    Entries with no audio scores, or whose best score is below
    PREPROCESS_CONFIG['min_quality_score'], are ignored.

    Returns a dict with:
        by_composer / by_genre: counts keyed by the lower-cased raw value;
        empty_metadata: IDs whose metadata dict is empty;
        with_composer / no_composer: IDs of non-empty entries with / without a composer.
    """
    threshold = PREPROCESS_CONFIG['min_quality_score']
    stats = {
        'by_composer': defaultdict(int),
        'by_genre': defaultdict(int),
        'empty_metadata': [],
        'with_composer': [],
        'no_composer': [],
    }

    for entry_id, entry in metadata_dict.items():
        meta = entry.get('metadata', {})
        scores = entry.get('audio_scores', {})

        if not scores or max(scores.values()) < threshold:
            continue

        genre = meta.get('genre')
        composer = meta.get('composer')

        if not meta:
            stats['empty_metadata'].append(entry_id)
            continue

        if genre:
            stats['by_genre'][genre.lower()] += 1
        if not composer:
            stats['no_composer'].append(entry_id)
            continue
        stats['by_composer'][composer.lower()] += 1
        stats['with_composer'].append(entry_id)

    return stats

def create_balanced_sample(metadata_dict, sampling_config, tokenizer, seed=42):
    """Pick a composer-balanced, shuffled list of entry IDs.

    Of the 20 most frequent raw composer strings, those whose normalized form is in
    ``tokenizer.top_composers`` each contribute up to
    ``sampling_config['max_per_composer']`` random entries. The result is then
    topped up with at most the same number of no-composer entries.

    Args:
        metadata_dict: entry ID -> entry dict (see analyze_metadata_distribution).
        sampling_config: dict providing 'max_per_composer'.
        tokenizer: provides ``_normalize_composer`` and ``top_composers``.
        seed: seed for the global ``random`` module (default keeps prior results).
    Returns:
        List of entry IDs without duplicates.
    """
    random.seed(seed)
    distribution = analyze_metadata_distribution(metadata_dict)

    # Group composer-labelled IDs by normalized composer in one pass
    # (previously every candidate composer rescanned all entries).
    ids_by_composer = defaultdict(list)
    for eid in distribution['with_composer']:
        raw = metadata_dict[eid].get('metadata', {}).get('composer', '')
        ids_by_composer[tokenizer._normalize_composer(raw)].append(eid)

    balanced_ids = []
    sampled = set()

    # Sample top composers
    top_composers = sorted(distribution['by_composer'].items(), key=lambda x: x[1], reverse=True)[:20]
    for composer, _ in top_composers:
        normalized = tokenizer._normalize_composer(composer)
        # Different raw spellings can normalize to the same composer; sampling it
        # again would add duplicate IDs (duplicate training rows, val leakage).
        if normalized in sampled or normalized not in tokenizer.top_composers:
            continue
        sampled.add(normalized)
        composer_ids = list(ids_by_composer.get(normalized, []))
        random.shuffle(composer_ids)
        balanced_ids.extend(composer_ids[:sampling_config['max_per_composer']])

    # Add no-composer samples, matching the composer count
    random.shuffle(distribution['no_composer'])
    no_composer_count = len(balanced_ids)
    balanced_ids.extend(distribution['no_composer'][:no_composer_count])

    random.shuffle(balanced_ids)
    return balanced_ids

print("‚úÖ Preprocessing functions defined")


In [None]:
# Process full dataset
def process_dataset(metadata_dict, data_root, balanced_ids=None):
    """Tokenize selected entries into ``metadata tokens + MIDI tokens + "<END>"``.

    For each entry, the audio index with the highest score is used. The entry is
    skipped if that score is below PREPROCESS_CONFIG['min_quality_score'], if the
    MIDI file is not found, or if it yields no tokens. Sequences longer than
    PREPROCESS_CONFIG['max_sequence_length'] have their MIDI part truncated so
    the metadata prefix and "<END>" are kept.

    Args:
        metadata_dict: entry ID -> entry dict.
        data_root: directory searched by find_midi_file.
        balanced_ids: IDs to process (in order); falsy means every key of metadata_dict.
    Returns:
        List of {'entry_id': ..., 'tokens': [...]} dicts.
    """
    entries_to_process = balanced_ids if balanced_ids else list(metadata_dict.keys())
    # Drop repeated IDs (keeping first occurrence) so no piece is tokenized twice,
    # which would duplicate training examples and could leak into the val split.
    entries_to_process = list(dict.fromkeys(entries_to_process))

    min_score = PREPROCESS_CONFIG['min_quality_score']
    max_len = PREPROCESS_CONFIG['max_sequence_length']
    sequences = []

    for entry_id in tqdm(entries_to_process, desc="Processing entries"):
        entry_data = metadata_dict.get(entry_id)
        if not entry_data:
            continue

        audio_scores = entry_data.get('audio_scores', {})
        if not audio_scores:
            continue

        best_idx, score = max(audio_scores.items(), key=lambda kv: kv[1])
        if score < min_score:
            continue

        metadata_tokens = meta_tokenizer.metadata_to_tokens(entry_data.get('metadata', {}), include_start=True)

        midi_path = find_midi_file(entry_id, best_idx, data_root)
        if not midi_path or not midi_path.exists():
            continue

        midi_tokens = midi_tokenizer.midi_to_tokens(midi_path)
        if not midi_tokens:
            continue

        full_sequence = metadata_tokens + midi_tokens + ["<END>"]
        if len(full_sequence) > max_len:
            # Keep the conditioning prefix and "<END>"; cut only MIDI tokens.
            max_midi_len = max_len - len(metadata_tokens) - 1
            full_sequence = metadata_tokens + midi_tokens[:max_midi_len] + ["<END>"]

        sequences.append({
            'entry_id': entry_id,
            'tokens': full_sequence
        })

    return sequences

# Reuse processed data from an earlier run when all three artifacts exist.
# NOTE(review): this cache is not keyed on PREPROCESS_CONFIG / SAMPLING_CONFIG;
# delete `output_dir` after changing them, otherwise stale sequences are reused.
output_dir = WORKING_DIR / "processed_data"
sequences_file = output_dir / "sequences.json"
vocab_file = output_dir / "vocab.json"
id_to_token_file = output_dir / "id_to_token.json"

if sequences_file.exists() and vocab_file.exists() and id_to_token_file.exists():
    print("‚úÖ Found existing processed data - loading from files...")
    print(f"   Loading from: {output_dir}")

    # Saved sequences contain only 'entry_id' and 'token_ids' (no raw 'tokens').
    with open(sequences_file, 'r') as f:
        all_sequences = json.load(f)

    print(f"‚úÖ Loaded {len(all_sequences):,} sequences from existing files")
    print("‚ö†Ô∏è  Note: Vocabulary and token IDs will be loaded from the same directory in the next step")
else:
    print("üìù No existing processed data found - starting preprocessing...")

    # Create balanced sample
    print("Creating balanced dataset sample...")
    balanced_ids = create_balanced_sample(metadata, SAMPLING_CONFIG, meta_tokenizer)
    balanced_metadata = {eid: metadata[eid] for eid in balanced_ids if eid in metadata}

    print(f"‚úÖ Balanced sample: {len(balanced_metadata):,} entries")

    # Process dataset
    print("\nProcessing MIDI files...")
    all_sequences = process_dataset(balanced_metadata, PREPROCESS_CONFIG['data_root'], balanced_ids)

    print(f"\n‚úÖ Processed {len(all_sequences):,} sequences")


In [None]:
# Build vocabulary or load existing
output_dir = WORKING_DIR / "processed_data"
vocab_file = output_dir / "vocab.json"
id_to_token_file = output_dir / "id_to_token.json"
sequences_file = output_dir / "sequences.json"

if vocab_file.exists() and id_to_token_file.exists() and sequences_file.exists():
    print("‚úÖ Loading existing vocabulary and sequences...")

    with open(vocab_file, 'r') as f:
        vocab = json.load(f)
    with open(id_to_token_file, 'r') as f:
        id_to_token = json.load(f)
    # Saved sequences carry 'entry_id' and 'token_ids' only.
    with open(sequences_file, 'r') as f:
        all_sequences = json.load(f)

    # JSON object keys are always strings: restore token -> int id and int id -> token.
    vocab = {str(k): int(v) for k, v in vocab.items()}
    id_to_token = {int(k): str(v) for k, v in id_to_token.items()}

    # MIDIDataset reads seq['token_ids'] directly, so a sequence without ids would
    # raise KeyError mid-training; drop such sequences instead of only warning.
    missing = [seq for seq in all_sequences if not seq.get('token_ids')]
    for seq in missing:
        print(f"‚ö†Ô∏è  Warning: Sequence {seq.get('entry_id', 'unknown')} missing token_ids - skipping")
    if missing:
        all_sequences = [seq for seq in all_sequences if seq.get('token_ids')]

    vocab_size = len(vocab)
    pad_token_id = vocab.get('<PAD>', 0)

    print(f"‚úÖ Loaded vocabulary: {vocab_size:,} tokens")
    print(f"‚úÖ Loaded sequences: {len(all_sequences):,}")
    print(f"‚úÖ Data loaded from: {output_dir}")
else:
    print("üìù Building vocabulary from processed sequences...")

    # Special tokens take ids 0-3; every other token gets the next id in sorted order.
    # NOTE(review): MetadataTokenizer emits "START", not '<START>', so id 2 is unused.
    all_tokens = set()
    for seq in all_sequences:
        all_tokens.update(seq['tokens'])

    vocab = {
        '<PAD>': 0,
        '<UNK>': 1,
        '<START>': 2,
        '<END>': 3,
    }
    for token in sorted(all_tokens):
        if token not in vocab:
            vocab[token] = len(vocab)

    # Convert sequences to token IDs
    for seq in all_sequences:
        seq['token_ids'] = [vocab.get(token, vocab['<UNK>']) for token in seq['tokens']]

    id_to_token = {v: k for k, v in vocab.items()}
    vocab_size = len(vocab)
    pad_token_id = vocab['<PAD>']

    print(f"‚úÖ Vocabulary size: {vocab_size:,}")
    print(f"‚úÖ Total sequences: {len(all_sequences):,}")

    # Save processed data
    output_dir.mkdir(exist_ok=True)

    with open(vocab_file, 'w') as f:
        json.dump(vocab, f)

    with open(id_to_token_file, 'w') as f:
        json.dump(id_to_token, f)

    sequences_to_save = [{'entry_id': s['entry_id'], 'token_ids': s['token_ids']} for s in all_sequences]
    with open(sequences_file, 'w') as f:
        json.dump(sequences_to_save, f)

    print(f"\n‚úÖ Saved processed data to: {output_dir}")


## Step 5: Model Configuration and Dataset Setup


In [None]:
# Model configuration (optimized for P100)
MODEL_CONFIG = dict(
    vocab_size=vocab_size,
    d_model=512,
    n_heads=8,
    n_layers=6,
    d_ff=2048,
    dropout=0.1,
    max_seq_length=2048,
)

# Training configuration (optimized for P100 - memory efficient)
TRAIN_CONFIG = dict(
    num_epochs=50,
    learning_rate=6e-5,
    weight_decay=0.1,
    warmup_steps=500,        # counted in optimizer steps (scheduler steps once per accumulation window)
    max_grad_norm=1.0,
    batch_size=4,            # Reduced for 16GB GPU memory
    accumulation_steps=4,    # Effective batch size: 16
    eval_steps=100,          # counted in micro-batches (global_step)
    save_steps=500,          # counted in micro-batches (global_step)
    patience=5,              # epochs' worth of evaluations without improvement
    checkpoint_dir=WORKING_DIR / "checkpoints",
)

effective_batch = TRAIN_CONFIG['batch_size'] * TRAIN_CONFIG['accumulation_steps']
print("‚úÖ Model and training config set")
print(f"   Model params: {MODEL_CONFIG['d_model']}d_model, {MODEL_CONFIG['n_layers']} layers, {MODEL_CONFIG['n_heads']} heads")
print(f"   Training: batch_size={TRAIN_CONFIG['batch_size']}, effective={effective_batch}")


In [None]:
# Dataset class
class MIDIDataset(Dataset):
    """Next-token-prediction dataset over pre-tokenized sequences.

    Each item truncates ``token_ids`` to ``max_length``, shifts by one into
    input/target, and right-pads both to ``max_length`` with the '<PAD>' id.
    """
    def __init__(self, sequences, vocab, max_length=2048):
        self.sequences = sequences
        self.vocab = vocab
        self.pad_token_id = vocab.get('<PAD>', 0)
        self.max_length = max_length

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        ids = self.sequences[idx]['token_ids'][:self.max_length]
        source, shifted = ids[:-1], ids[1:]
        n_real = len(source)
        n_pad = self.max_length - n_real
        # list * non-positive == [], so no explicit pad_len > 0 check is needed
        padding = [self.pad_token_id] * n_pad
        as_long = lambda values: torch.tensor(values, dtype=torch.long)
        return {
            'input_ids': as_long(source + padding),
            'target_ids': as_long(shifted + padding),
            'attention_mask': as_long([1] * n_real + [0] * n_pad),
        }

# Create dataset and a seeded 90/10 train/validation split
dataset = MIDIDataset(all_sequences, vocab, max_length=MODEL_CONFIG['max_seq_length'])
n_total = len(dataset)
train_size = int(0.9 * n_total)
val_size = n_total - train_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Shared loader settings; num_workers=0 keeps loading in-process to save memory
loader_kwargs = dict(
    batch_size=TRAIN_CONFIG['batch_size'],
    num_workers=0,
    pin_memory=True,
    persistent_workers=False,
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print(f"‚úÖ Dataset: Train={len(train_dataset):,}, Val={len(val_dataset):,}")


In [None]:
# Build the model on `device`, releasing cached GPU memory before and after.
cuda_ok = torch.cuda.is_available()
if cuda_ok:
    torch.cuda.empty_cache()

model = PianoMIDIGenerator(MODEL_CONFIG).to(device)

param_count = sum(p.numel() for p in model.parameters())
print(f"‚úÖ Model created: {param_count/1e6:.1f}M parameters")
print(f"‚úÖ Model on device: {next(model.parameters()).device}")

if cuda_ok:
    torch.cuda.empty_cache()
    bytes_per_gb = 1e9
    allocated = torch.cuda.memory_allocated(0) / bytes_per_gb
    reserved = torch.cuda.memory_reserved(0) / bytes_per_gb
    print(f"‚úÖ GPU memory (model loaded):")
    print(f"   Allocated: {allocated:.2f} GB")
    print(f"   Reserved: {reserved:.2f} GB")


In [None]:
# Setup optimizer (AdamW with GPT-style betas: beta2 = 0.95)
adamw_betas = (0.9, 0.95)
optimizer = optim.AdamW(
    model.parameters(),
    lr=TRAIN_CONFIG['learning_rate'],
    betas=adamw_betas,
    weight_decay=TRAIN_CONFIG['weight_decay'],
)

def get_lr_scheduler(optimizer, num_training_steps, warmup_steps):
    """Linear warmup from 0 to the base LR, then cosine decay towards 0.

    Returns a LambdaLR whose multiplier is step/warmup during warmup and
    0.5 * (1 + cos(pi * progress)) afterwards, with progress measured over
    the remaining ``num_training_steps - warmup_steps`` steps.
    """
    warmup_span = float(max(1, warmup_steps))
    decay_span = float(max(1, num_training_steps - warmup_steps))

    def lr_lambda(step):
        if step >= warmup_steps:
            progress = float(step - warmup_steps) / decay_span
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
        return float(step) / warmup_span

    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# The scheduler steps once per optimizer step, i.e. once per accumulation window.
micro_batches_total = len(train_loader) * TRAIN_CONFIG['num_epochs']
total_steps = micro_batches_total // TRAIN_CONFIG['accumulation_steps']
scheduler = get_lr_scheduler(optimizer, total_steps, TRAIN_CONFIG['warmup_steps'])

# Padding targets contribute nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id)

print(f"‚úÖ Training setup complete")
print(f"   Optimizer: AdamW")
print(f"   Scheduler: Cosine with {TRAIN_CONFIG['warmup_steps']} step warmup")
print(f"   Total training steps: {total_steps:,}")


## Step 6: Training Loop with Validation


In [None]:
# Training functions
# History buffers (appended by the training loop, stored in checkpoints, plotted later)
# and early-stopping state.
train_losses, val_losses, learning_rates = [], [], []
best_val_loss = float('inf')
steps_without_improvement = 0

def train_step(model, batch, optimizer, criterion, accumulation_steps, global_step):
    """Forward/backward one micro-batch; update weights every ``accumulation_steps`` calls.

    Uses the module-level ``device``, ``vocab_size``, ``scheduler`` and
    ``TRAIN_CONFIG['max_grad_norm']``.

    Args:
        global_step: count of completed micro-batches; the optimizer/scheduler
            step when ``(global_step + 1) % accumulation_steps == 0``.
    Returns:
        The un-scaled loss as a float, or nan if the loss was non-finite
        (no backward pass is done in that case).
    """
    model.train()

    input_ids = batch['input_ids'].to(device, non_blocking=True)
    attention_mask = batch['attention_mask'].to(device, non_blocking=True)
    target_ids = batch['target_ids'].to(device, non_blocking=True)
    # (No NaN check on ids: they are integer tensors, which cannot hold NaN.)

    logits = model(input_ids, attention_mask=attention_mask)
    if not torch.isfinite(logits).all():
        logits = torch.nan_to_num(logits, nan=0.0, posinf=10.0, neginf=-10.0)

    loss = criterion(logits.view(-1, vocab_size), target_ids.view(-1))
    if not torch.isfinite(loss):
        return float('nan')

    loss = loss / accumulation_steps
    loss.backward()

    # In place and unconditional: the same values as check-then-replace, without
    # two GPU->CPU syncs per parameter per micro-batch.
    for param in model.parameters():
        if param.grad is not None:
            param.grad.nan_to_num_(nan=0.0, posinf=1.0, neginf=-1.0)

    if (global_step + 1) % accumulation_steps == 0:
        # Clip the fully accumulated gradient once, right before the update.
        # Clipping after every micro-batch repeatedly shrank the partial sums.
        torch.nn.utils.clip_grad_norm_(model.parameters(), TRAIN_CONFIG['max_grad_norm'])
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        # No empty_cache()/synchronize(): they only slowed training. Cached blocks
        # are reused by PyTorch's allocator, and loss.item() below already syncs.

    return loss.item() * accumulation_steps

def validate(model, val_loader, criterion):
    """Token-weighted mean loss over ``val_loader``, ignoring padding targets.

    Batches whose loss is non-finite, or that contain only padding, are skipped.
    Returns inf when nothing could be scored or the average is non-finite.
    Uses the module-level ``device``, ``vocab_size`` and ``pad_token_id``.
    """
    model.eval()
    loss_sum = 0.0
    token_count = 0

    with torch.no_grad():
        for batch in val_loader:
            ids = batch['input_ids'].to(device, non_blocking=True)
            mask_in = batch['attention_mask'].to(device, non_blocking=True)
            targets = batch['target_ids'].to(device, non_blocking=True)

            logits = model(ids, attention_mask=mask_in)
            if torch.isnan(logits).any() or torch.isinf(logits).any():
                logits = torch.nan_to_num(logits, nan=0.0, posinf=10.0, neginf=-10.0)

            targets_flat = targets.view(-1)
            n_real = (targets_flat != pad_token_id).sum()
            if not n_real > 0:
                continue

            # criterion averages over non-pad tokens; re-weight by their count.
            batch_loss = criterion(logits.view(-1, vocab_size), targets_flat)
            if torch.isnan(batch_loss) or torch.isinf(batch_loss):
                continue
            loss_sum += batch_loss.item() * n_real.item()
            token_count += n_real.item()

    if token_count == 0:
        return float('inf')
    avg = loss_sum / token_count
    return float('inf') if (math.isnan(avg) or math.isinf(avg)) else avg

def _atomic_torch_save(obj, path):
    """torch.save to a sibling temp file, then rename it over ``path``.

    An interrupted write (e.g. the session being killed) leaves the previous
    file intact instead of a truncated checkpoint.
    """
    tmp_path = path.with_name(path.name + '.tmp')
    torch.save(obj, tmp_path)
    tmp_path.replace(path)


def save_checkpoint(model, optimizer, scheduler, epoch, step, val_loss, is_best=False):
    """Write checkpoint_latest.pt (always) and checkpoint_best.pt (if ``is_best``).

    The checkpoint holds model/optimizer/scheduler state, epoch, step, val_loss
    and the module-level train_losses / val_losses / learning_rates histories.
    Files go to ``TRAIN_CONFIG['checkpoint_dir']``.
    """
    checkpoint = {
        'epoch': epoch,
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_loss': val_loss,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'learning_rates': learning_rates,
    }

    checkpoint_dir = TRAIN_CONFIG['checkpoint_dir']
    _atomic_torch_save(checkpoint, checkpoint_dir / 'checkpoint_latest.pt')

    if is_best:
        _atomic_torch_save(checkpoint, checkpoint_dir / 'checkpoint_best.pt')
        print(f"  üíæ Saved best model (val_loss: {val_loss:.4f})")

print("‚úÖ Training functions defined")


In [None]:
# Main training loop
print("üöÄ Starting training...")
print("=" * 60)

# Clear GPU cache before training
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

# GPU verification with smaller test batch
if torch.cuda.is_available():
    print(f"‚úÖ Training will use GPU: {torch.cuda.get_device_name(0)}")
    print(f"‚úÖ Model parameters on: {next(model.parameters()).device}")

    # Test forward pass with minimal batch
    print("\nüîç Testing GPU computation...")
    model.eval()
    with torch.no_grad():
        # Create a minimal test input instead of using actual batch
        test_input = torch.randint(0, vocab_size, (1, 512), device=device)  # Smaller seq len for test
        test_mask = torch.ones(1, 512, dtype=torch.long, device=device)
        test_output = model(test_input, attention_mask=test_mask)
        print(f"‚úÖ GPU test passed! Output shape: {test_output.shape}")

        allocated = torch.cuda.memory_allocated(0) / 1e9
        reserved = torch.cuda.memory_reserved(0) / 1e9
        print(f"‚úÖ GPU memory (after test):")
        print(f"   Allocated: {allocated:.2f} GB")
        print(f"   Reserved: {reserved:.2f} GB")

    torch.cuda.empty_cache()  # Clear test cache
    model.train()
    print("")

# Early stopping counts *evaluations* without improvement and tolerates `patience`
# epochs' worth of them. The per-epoch count is floored at 1: with fewer than
# `eval_steps` batches per epoch it used to be 0, making the threshold 0, so
# `steps_without_improvement >= 0` ended training after the first epoch.
evals_per_epoch = max(1, len(train_loader) // TRAIN_CONFIG['eval_steps'])
early_stop_after = TRAIN_CONFIG['patience'] * evals_per_epoch

global_step = 0

for epoch in range(TRAIN_CONFIG['num_epochs']):
    print(f"\nEpoch {epoch + 1}/{TRAIN_CONFIG['num_epochs']}")
    print("-" * 60)

    epoch_losses = []
    pbar = tqdm(train_loader, desc=f"Training Epoch {epoch + 1}")

    for batch in pbar:
        loss = train_step(model, batch, optimizer, criterion, TRAIN_CONFIG['accumulation_steps'], global_step)

        if math.isnan(loss) or math.isinf(loss):
            # Batch skipped: drop any partially accumulated gradients.
            optimizer.zero_grad()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            continue

        # No per-batch torch.cuda.synchronize()/empty_cache() here. train_step's
        # loss.item() already waits for the GPU, and empty_cache() cannot increase
        # the memory available to PyTorch itself; it only forced slow re-allocation.

        epoch_losses.append(loss)
        train_losses.append(loss)

        current_lr = scheduler.get_last_lr()[0]
        learning_rates.append(current_lr)

        avg_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float('nan')
        loss_str = f'{loss:.4f}' if not (math.isnan(loss) or math.isinf(loss)) else 'nan'
        avg_loss_str = f'{avg_loss:.4f}' if not (math.isnan(avg_loss) or math.isinf(avg_loss)) else 'nan'

        pbar.set_postfix({
            'loss': loss_str,
            'lr': f'{current_lr:.2e}',
            'avg_loss': avg_loss_str
        })

        global_step += 1

        # Validation
        if global_step % TRAIN_CONFIG['eval_steps'] == 0:
            val_loss = validate(model, val_loader, criterion)

            if math.isnan(val_loss) or math.isinf(val_loss):
                val_loss = float('inf')

            val_losses.append(val_loss)
            print(f"\n  Step {global_step}: Val Loss = {val_loss:.4f}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                steps_without_improvement = 0
                save_checkpoint(model, optimizer, scheduler, epoch, global_step, val_loss, is_best=True)
            else:
                steps_without_improvement += 1

            if steps_without_improvement >= early_stop_after:
                print(f"\n‚ö†Ô∏è  Early stopping triggered!")
                break

        # Save checkpoint periodically
        if global_step % TRAIN_CONFIG['save_steps'] == 0:
            current_val_loss = val_losses[-1] if val_losses else best_val_loss
            save_checkpoint(model, optimizer, scheduler, epoch, global_step, current_val_loss, is_best=False)

    # Epoch summary
    avg_epoch_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float('inf')
    print(f"\nEpoch {epoch + 1} Summary:")
    print(f"  Avg Train Loss: {avg_epoch_loss:.4f}")
    if val_losses:
        print(f"  Best Val Loss: {best_val_loss:.4f}")
    print(f"  Learning Rate: {scheduler.get_last_lr()[0]:.2e}")

    if steps_without_improvement >= early_stop_after:
        break

print("\n‚úÖ Training complete!")
print(f"Best validation loss: {best_val_loss:.4f}")


## Step 7: Visualization


In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Shared x-axis = global step. The training loop appends to train_losses and
# learning_rates before incrementing global_step, so entry i belongs to step
# i + 1. The k-th validation (0-based) runs when global_step reaches
# (k + 1) * eval_steps; plotting it at k * eval_steps shifted it one interval left.
train_steps = range(1, len(train_losses) + 1)
lr_steps = range(1, len(learning_rates) + 1)

# Loss curves
axes[0].plot(train_steps, train_losses, label='Train Loss', alpha=0.7)
if val_losses:
    val_steps = [(i + 1) * TRAIN_CONFIG['eval_steps'] for i in range(len(val_losses))]
    axes[0].plot(val_steps, val_losses, label='Val Loss', marker='o', markersize=3)
axes[0].set_xlabel('Step')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Learning rate schedule
axes[1].plot(lr_steps, learning_rates, label='Learning Rate', color='green', alpha=0.7)
axes[1].set_xlabel('Step')
axes[1].set_ylabel('Learning Rate')
axes[1].set_title('Learning Rate Schedule')
axes[1].set_yscale('log')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(WORKING_DIR / 'training_curves.png', dpi=150)
print(f"‚úÖ Training curves saved to: {WORKING_DIR / 'training_curves.png'}")
