# Piano Perception Transformer - MAESTRO Training Pipeline

This notebook implements the complete training pipeline:
1. **SSAST Pre-training** on MAESTRO dataset
2. **AST Fine-tuning** on PercePiano dataset

**Runtime Requirements:**
- Use **TPU v2-8**

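**Memory-Efficient Approach:**
- MAESTRO is processed one recording at a time: the ZIP archive is downloaded to a temporary file and each WAV is extracted, converted to a mel-spectrogram, and discarded (avoids hitting the 200GB storage limit)
- Only the processed spectrograms are kept (~10GB vs. 200GB of raw audio)
- Temporary raw audio files are cleaned up automatically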

In [None]:
# Cell 1: Initial Setup
print("🚀 Setting up Piano Perception Transformer...")

# Clone repo (skip if already exists)
import os
if not os.path.exists('piano-perception-transformer'):
    !git clone https://github.com/Jai-Dhiman/piano-perception-transformer.git
else:
    print("Repository already exists, skipping clone...")

%cd piano-perception-transformer

# Install uv
!curl -LsSf https://astral.sh/uv/install.sh | sh

# Fix the path and install dependencies
print("📦 Installing dependencies with uv...")
!export PATH="/usr/local/bin:$PATH" && uv pip install --system jax[tpu] flax optax librosa pandas wandb requests zipfile36

print("✅ Setup completed!")

In [None]:
# Cell 2: Mount Google Drive & Setup Memory-Efficient Storage
from google.colab import drive
import os

print("💾 Mounting Google Drive for persistent storage...")
drive.mount('/content/drive')

# Create directories for saving processed data and checkpoints
!mkdir -p /content/drive/MyDrive/piano_transformer
!mkdir -p /content/drive/MyDrive/piano_transformer/processed_spectrograms
!mkdir -p /content/drive/MyDrive/piano_transformer/checkpoints
!mkdir -p /content/drive/MyDrive/piano_transformer/logs
!mkdir -p /content/drive/MyDrive/piano_transformer/temp

print("✅ Google Drive mounted and directories created!")
print("📁 Storage structure:")
!ls -la /content/drive/MyDrive/piano_transformer/

# Check available space
!df -h /content
print("✅ Storage setup completed!")

In [None]:
# Cell 3: Streaming MAESTRO Processing
# Imports for the download / spectrogram-extraction helpers defined below.
import os
import requests
import json
import librosa
import numpy as np
import zipfile
import tempfile
from pathlib import Path
import sys
from io import BytesIO
# Make modules under ./src importable; the path is relative to the current
# working directory (Cell 1 changes into the repository).
sys.path.append('./src')

print("🌊 Starting streaming MAESTRO processing...")

def ensure_directories():
    """Make sure every Google Drive folder used by the pipeline exists.

    Missing folders are created and existing ones are left untouched; one
    confirmation line is printed per folder.
    """
    required_dirs = (
        '/content/drive/MyDrive/piano_transformer',
        '/content/drive/MyDrive/piano_transformer/processed_spectrograms',
        '/content/drive/MyDrive/piano_transformer/checkpoints',
        '/content/drive/MyDrive/piano_transformer/logs',
        '/content/drive/MyDrive/piano_transformer/temp',
    )

    print("📁 Ensuring directory structure...")
    for path in required_dirs:
        # exist_ok=True keeps re-runs of the cell from failing
        os.makedirs(path, exist_ok=True)
        print(f"✅ Created/verified: {path}")

def _wav_to_mel_db(wav_path, sr=22050, duration=60.0, n_fft=2048,
                   hop_length=512, n_mels=128):
    """Load the start of a WAV file and return its mel-spectrogram in dB."""
    # Only the first `duration` seconds are decoded, to bound memory use
    y, sr = librosa.load(wav_path, sr=sr, duration=duration)
    mel_spec = librosa.feature.melspectrogram(
        y=y, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
    )
    # ref=np.max: dB values are relative to the loudest bin of this clip
    return librosa.power_to_db(mel_spec, ref=np.max)


def _report_free_space(processed_count, mount_point='/content'):
    """Print free disk space for `mount_point` together with the progress count."""
    try:
        storage_info = os.statvfs(mount_point)
    except OSError:
        # Purely informational: a failed stat must not interrupt processing
        return
    free_gb = (storage_info.f_bavail * storage_info.f_frsize) / (1024**3)
    print(f"💾 Storage: {free_gb:.1f}GB free, {processed_count} files processed")


def download_and_process_maestro_streaming(max_files=None):
    """Download MAESTRO v3.0.0 and save one mel-spectrogram per recording to Drive.

    Steps:
      1. Fetch the dataset metadata JSON and store a copy on Drive.
      2. Pick the target ``.wav`` names from the metadata (only the first
         ``max_files`` metadata entries when a limit is given).
      3. Download the whole ZIP archive, in chunks, to a local temporary file.
      4. For every archive member whose file name is a target: extract it to
         a temporary WAV, compute a dB mel-spectrogram of its first 60 seconds
         and save it as ``<stem>_mel.npy`` in the ``processed_spectrograms``
         folder. The temporary WAV is always removed afterwards.

    A failure on a single recording is printed and skipped; it does not abort
    the run.

    Args:
        max_files: Optional limit on the number of recordings. A falsy value
            (``None`` or ``0``) processes every recording in the metadata.

    Returns:
        int: Number of spectrograms written.

    Raises:
        Exception: If the metadata cannot be downloaded, saved or validated,
            if the ZIP cannot be downloaded or read, or if no recording was
            processed successfully. The underlying error is attached as the
            exception's cause where there is one.
    """
    # Function-scope import: only needed to stream ZIP members to disk
    import shutil

    # Ensure directories exist first
    ensure_directories()

    # Download metadata first to get real file paths
    print("📋 Downloading MAESTRO metadata...")
    metadata_url = "https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0.json"

    try:
        metadata_response = requests.get(metadata_url, timeout=30)
        metadata_response.raise_for_status()
        maestro_metadata = metadata_response.json()
    except requests.exceptions.RequestException as e:
        print(f"❌ Failed to download metadata: {e}")
        raise Exception(f"Cannot download MAESTRO metadata: {e}") from e

    print(f"📊 Found metadata for MAESTRO dataset")

    # Save metadata to Drive (read back later by the dataset-setup cell)
    try:
        with open('/content/drive/MyDrive/piano_transformer/maestro_metadata.json', 'w') as f:
            json.dump(maestro_metadata, f)
        print("✅ Metadata saved to Drive")
    except IOError as e:
        print(f"❌ Failed to save metadata: {e}")
        raise Exception(f"Cannot save metadata to Drive: {e}") from e

    # MAESTRO v3.0.0 uses pandas-style JSON structure: {column: {row_id: value}}
    if not isinstance(maestro_metadata, dict):
        raise Exception(f"Expected dict metadata, got {type(maestro_metadata)}")

    # Check for required fields
    required_fields = ['audio_filename', 'canonical_composer', 'canonical_title']
    for field in required_fields:
        if field not in maestro_metadata:
            raise Exception(f"Required field '{field}' not found in metadata. Available fields: {list(maestro_metadata.keys())}")

    # Get the audio filenames from the pandas-style structure
    audio_filenames = maestro_metadata['audio_filename']
    if not isinstance(audio_filenames, dict):
        raise Exception(f"Expected dict for audio_filename field, got {type(audio_filenames)}")

    total_files = len(audio_filenames)
    print(f"📝 Found {total_files} audio files in metadata")

    # Select the metadata entries to process
    candidate_names = list(audio_filenames.values())
    if max_files:
        candidate_names = candidate_names[:max_files]
        print(f"🎯 Processing first {max_files} files for demo/testing")
    else:
        print(f"🎯 Processing all {total_files} files")

    target_files = {
        name for name in candidate_names
        if name and isinstance(name, str) and name.endswith('.wav')
    }

    if not target_files:
        raise Exception("No valid .wav files found in metadata")

    print(f"🎵 Target: {len(target_files)} audio files from ZIP")

    # Archive members are matched by exact file name. A set gives O(1) lookups
    # and avoids substring false positives (e.g. "a.wav" inside "xa.wav").
    target_basenames = {Path(name).name for name in target_files}

    # Download and process the MAESTRO ZIP
    zip_url = "https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0.zip"
    print(f"📦 Downloading MAESTRO ZIP stream from: {zip_url}")

    processed_count = 0
    chunk_size = 8192 * 1024  # 8MB chunks
    one_gb = 1024**3

    try:
        # Stream download the ZIP file
        with requests.get(zip_url, stream=True, timeout=300) as zip_response:
            zip_response.raise_for_status()

            print("✅ ZIP stream connected, processing...")

            # The archive is written to a temporary file (removed on close)
            with tempfile.NamedTemporaryFile(suffix='.zip') as temp_zip:
                # Download ZIP in chunks to avoid memory issues
                total_size = int(zip_response.headers.get('content-length', 0))
                downloaded = 0

                print(f"📊 ZIP size: {total_size / one_gb:.1f}GB")

                for chunk in zip_response.iter_content(chunk_size=chunk_size):
                    if chunk:
                        temp_zip.write(chunk)
                        downloaded += len(chunk)

                        # Show progress roughly every 1GB: true for the chunk
                        # that carries the running total past a GB boundary
                        if downloaded % one_gb < chunk_size:
                            progress = (downloaded / total_size) * 100 if total_size > 0 else 0
                            print(f"📥 Downloaded: {downloaded / one_gb:.1f}GB ({progress:.1f}%)")

                print("✅ ZIP download completed, extracting audio files...")
                temp_zip.seek(0)  # Reset file pointer

                # Process ZIP contents
                with zipfile.ZipFile(temp_zip, 'r') as zip_file:
                    audio_files_in_zip = [f for f in zip_file.namelist() if f.endswith('.wav')]

                    print(f"📂 Found {len(audio_files_in_zip)} audio files in ZIP")

                    for zip_audio_path in audio_files_in_zip:
                        # Skip members that are not in our target list
                        audio_filename = Path(zip_audio_path).name
                        if audio_filename not in target_basenames:
                            continue

                        temp_audio_path = None
                        try:
                            print(f"🎛️ Processing: {audio_filename}...")

                            # Copy the member straight to a temp file for
                            # librosa, without holding the whole WAV in memory
                            with zip_file.open(zip_audio_path) as audio_file:
                                with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_audio:
                                    temp_audio_path = temp_audio.name
                                    shutil.copyfileobj(audio_file, temp_audio)

                            try:
                                mel_spec_db = _wav_to_mel_db(temp_audio_path)

                                # Save spectrogram to Drive
                                spec_filename = Path(audio_filename).stem + '_mel.npy'
                                spec_path = f'/content/drive/MyDrive/piano_transformer/processed_spectrograms/{spec_filename}'

                                np.save(spec_path, mel_spec_db)
                                print(f"✅ Saved: {spec_filename} (shape: {mel_spec_db.shape})")
                                processed_count += 1
                            except Exception as audio_error:
                                # Best effort: report and move on to the next file
                                print(f"❌ Audio processing error: {audio_error}")
                                continue

                        except Exception as extract_error:
                            print(f"❌ Extraction error for {zip_audio_path}: {extract_error}")
                            continue
                        finally:
                            # Remove the temp WAV even when extraction failed part-way
                            if temp_audio_path is not None and os.path.exists(temp_audio_path):
                                os.remove(temp_audio_path)

                        # Only reached after a successful save
                        if max_files and processed_count >= max_files:
                            print(f"🎯 Reached target limit of {processed_count} files")
                            break

                        # Storage check periodically
                        if processed_count % 10 == 0:
                            _report_free_space(processed_count)

    except requests.exceptions.RequestException as download_error:
        raise Exception(f"Failed to download MAESTRO ZIP: {download_error}") from download_error
    except zipfile.BadZipFile as zip_error:
        raise Exception(f"Invalid ZIP file: {zip_error}") from zip_error
    except Exception as general_error:
        raise Exception(f"Processing error: {general_error}") from general_error

    print(f"\n🎉 Streaming processing completed!")
    print(f"✅ Successfully processed: {processed_count} files")
    print(f"💾 Spectrograms saved to: /content/drive/MyDrive/piano_transformer/processed_spectrograms/")

    if processed_count == 0:
        raise Exception("No files were successfully processed")

    return processed_count


# Kick off the download + spectrogram extraction for this notebook run.
#   max_files=None -> full dataset
#   max_files=50   -> small subset for testing
try:
    num_processed = download_and_process_maestro_streaming(max_files=None)
except Exception as err:
    # Report in the cell output, then stop the notebook run with a wrapped error
    print(f"❌ Processing failed: {err}")
    raise Exception(f"MAESTRO processing failed: {err}")
else:
    print(f"\n✅ SUCCESS: {num_processed} MAESTRO files processed!")
    print("🎯 Ready to proceed with pre-training on processed spectrograms")

In [None]:
# Cell 4: Quick TPU Verification
import jax

print("🧠 Quick TPU check...")
backend = jax.default_backend()
device_count = jax.device_count()
print(f"JAX backend: {backend}")
print(f"Device count: {device_count}")

# The notebook header asks for a TPU runtime: flag anything else instead of
# only printing it. This warns but does not abort.
if backend != "tpu":
    print(f"⚠️ JAX is running on '{backend}', not TPU - check the Colab runtime type before training")


In [None]:
# Cell 5: Dataset Setup from Processed Spectrograms
import os
import sys
import numpy as np
import json
from pathlib import Path

sys.path.append('./src')

print("🎵 Setting up dataset from processed spectrograms...")

# Locate the spectrogram folder that Cell 3 writes to
spec_dir = '/content/drive/MyDrive/piano_transformer/processed_spectrograms'
if not os.path.exists(spec_dir):
    raise FileNotFoundError("Run Cell 3 first to process MAESTRO")

# Keep only the files that follow the "<stem>_mel.npy" naming used by Cell 3
spec_files = [name for name in os.listdir(spec_dir) if name.endswith('_mel.npy')]
if not spec_files:
    raise FileNotFoundError("No spectrograms found. Re-run Cell 3")

print(f"📊 Found {len(spec_files)} processed spectrograms")

# Read back the metadata copy that Cell 3 stored on Drive
metadata_path = '/content/drive/MyDrive/piano_transformer/maestro_metadata.json'
with open(metadata_path, 'r') as f:
    maestro_metadata = json.load(f)
print("✅ MAESTRO metadata loaded")

# Dataset class for pre-processed spectrograms
class ProcessedSpectrogramDataset:
    """Random-batch access to the ``*_mel.npy`` spectrogram files in a folder.

    Only files whose name ends in ``_mel.npy`` belong to the dataset. Each
    file is loaded from disk on demand; nothing is cached.
    """

    def __init__(self, spec_dir, metadata):
        """Index the spectrogram files found in `spec_dir`.

        Args:
            spec_dir: Directory containing the ``*_mel.npy`` files.
            metadata: Dataset metadata; stored as-is and not used by this class.
        """
        self.spec_dir = spec_dir
        self.metadata = metadata
        # os.listdir returns entries in arbitrary order; sorting makes an
        # index refer to the same file on every run
        self.spec_files = sorted(
            name for name in os.listdir(spec_dir) if name.endswith('_mel.npy')
        )
        self.num_files = len(self.spec_files)

    def __len__(self):
        """Return the number of spectrogram files in the dataset."""
        return self.num_files

    def load_spectrogram(self, idx):
        """Load and return the spectrogram array stored at position `idx`."""
        return np.load(os.path.join(self.spec_dir, self.spec_files[idx]))

    @staticmethod
    def _fit_frames(spec, num_frames):
        """Crop or right-pad `spec` along axis 1 to exactly `num_frames` columns."""
        width = spec.shape[1]
        if width >= num_frames:
            return spec[:, :num_frames]
        # The spectrograms are in dB relative to their own maximum (Cell 3
        # uses ref=np.max), so 0 is the loudest value. Pad with the clip's
        # minimum, i.e. its quietest level, rather than with 0.
        fill = spec.min() if spec.size else 0.0
        return np.pad(
            spec,
            ((0, 0), (0, num_frames - width)),
            mode='constant',
            constant_values=fill,
        )

    def get_batch(self, batch_size=32, num_frames=128):
        """Return a random batch of spectrograms with a fixed number of frames.

        Files are sampled uniformly with replacement. Each spectrogram keeps
        its first axis untouched and is cropped to its first `num_frames`
        columns, or right-padded when it has fewer.

        Args:
            batch_size: Number of spectrograms in the batch.
            num_frames: Number of columns (axis 1) of every returned item.

        Returns:
            numpy.ndarray: The stacked batch; shape
            ``(batch_size, rows, num_frames)`` when all files share the same
            number of rows.

        Raises:
            ValueError: If the dataset contains no spectrogram files.
        """
        if self.num_files == 0:
            raise ValueError(f"No '*_mel.npy' spectrograms found in {self.spec_dir}")

        batch_indices = np.random.choice(self.num_files, batch_size, replace=True)
        batch_specs = [
            self._fit_frames(self.load_spectrogram(idx), num_frames)
            for idx in batch_indices
        ]
        return np.array(batch_specs)

# Build the dataset over the spectrogram folder using the loaded metadata
dataset = ProcessedSpectrogramDataset(spec_dir, maestro_metadata)
print(f"✅ Dataset ready: {len(dataset)} spectrograms")

# Smoke test: pull one small batch and show its shape
test_batch = dataset.get_batch(batch_size=4)
print(f"✅ Test batch: shape {test_batch.shape}")
print("🎯 Ready for SSAST pre-training!")

In [None]:
# Cell 6: SSAST Pre-training
import json
import os
import sys
from datetime import datetime

sys.path.append('./src')

# `dataset` is created by Cell 5; at cell level the globals are where it lives
if 'dataset' not in globals():
    raise RuntimeError("Run Cell 5 first to load dataset")

print("🧠 Starting SSAST pre-training...")
print(f"📊 Training on {len(dataset)} spectrograms")

# Import SSAST trainer. Only the import is guarded here, so the "model not
# found" message cannot be triggered by an import error raised during training.
try:
    from models.ssast_pretraining import SSASTPretrainer
except ImportError as e:
    raise ImportError(f"SSAST model not found: {e}. Check src/models/ssast_pretraining.py exists") from e

ssast_checkpoint_dir = '/content/drive/MyDrive/piano_transformer/checkpoints/ssast'

try:
    # Cell 2 creates .../checkpoints but not the ssast sub-folder.
    # NOTE(review): the trainer may create it itself - not visible from here.
    os.makedirs(ssast_checkpoint_dir, exist_ok=True)

    # Initialize SSAST trainer
    trainer = SSASTPretrainer(
        model_dim=768,
        num_heads=12,
        num_layers=12,
        patch_size=(16, 16),
        num_patches=(8, 8)
    )

    print("✅ Using full SSAST implementation")

    # Run pre-training
    results = trainer.pretrain(
        dataset=dataset,
        num_epochs=100,
        batch_size=32,
        learning_rate=1e-4,
        checkpoint_dir=ssast_checkpoint_dir,
        save_every=10
    )

    print(f"🎉 SSAST pre-training completed!")
    print(f"📈 Final loss: {results['final_loss']:.4f}")

except Exception as e:
    # Keep the original error attached as the cause for debugging
    raise RuntimeError(f"Pre-training failed: {e}") from e

# Save results
training_summary = {
    'model': 'SSAST',
    'dataset_size': len(dataset),
    'final_loss': results['final_loss'],
    'epochs': results['epochs'],
    'timestamp': datetime.now().isoformat()
}

with open('/content/drive/MyDrive/piano_transformer/pretraining_results.json', 'w') as f:
    json.dump(training_summary, f, indent=2)

print("💾 Results saved to Google Drive!")

In [None]:
# Cell 7: AST Fine-tuning on PercePiano
import sys
import os
import json

sys.path.append('./src')

try:
    from train_ast import train_ast
    from datasets.percepiano_dataset import PercepianoDataset
    print("✅ Fine-tuning modules imported")
except ImportError as e:
    print(f"❌ Import error: {e}")
    print("Check if src/train_ast.py and src/datasets/percepiano_dataset.py exist")
    exit()

print("🎹 Starting AST fine-tuning on PercePiano dataset...")
print("⏱️  Expected duration: ~3 hours on GPU")

# Check for pre-trained model
pretrained_path = '/content/drive/MyDrive/piano_transformer/checkpoints/ssast/best_model.pkl'

if os.path.exists(pretrained_path):
    print(f"✅ Pre-trained model found: {pretrained_path}")
else:
    print(f"❌ Pre-trained model not found at: {pretrained_path}")
    print("Available checkpoints:")
    !ls -la /content/drive/MyDrive/piano_transformer/checkpoints/ssast/
    # Use the latest checkpoint
    checkpoints = !ls /content/drive/MyDrive/piano_transformer/checkpoints/ssast/*.pkl
    if checkpoints:
        pretrained_path = checkpoints[-1]  # Use latest
        print(f"Using latest checkpoint: {pretrained_path}")
    else:
        print("No checkpoints found. Running without pre-training.")
        pretrained_path = None

# Check PercePiano dataset
percepiano_path = './PercePiano'
if os.path.exists(percepiano_path):
    print(f"✅ PercePiano dataset found: {percepiano_path}")
else:
    print(f"❌ PercePiano dataset not found at: {percepiano_path}")
    exit()

# Start fine-tuning
print("🚀 Starting fine-tuning...")
results = train_ast(
    pretrained_model_path=pretrained_path,
    percepiano_path=percepiano_path,
    target_correlation=0.7,
    epochs=50,
    batch_size=32,
    learning_rate=1e-5,  # Lower LR for fine-tuning
    checkpoint_dir='/content/drive/MyDrive/piano_transformer/checkpoints/ast_finetuned',
    save_every=5
)

print("🎉 Fine-tuning completed! ✅")
print(f"Best correlation achieved: {results.get('best_correlation', 'N/A'):.3f}")

# Save fine-tuning results
with open('/content/drive/MyDrive/piano_transformer/finetuning_results.json', 'w') as f:
    json.dump(results, f, indent=2)
    
print("💾 Fine-tuning results saved to Google Drive!")

In [None]:
# Cell 9: Final Results and Evaluation
import json
import os
import pickle
from datetime import datetime

print("📊 Generating final results summary...")

# Load training results
pretraining_results = {}
finetuning_results = {}

# Load pre-training results if available
pretraining_path = '/content/drive/MyDrive/piano_transformer/pretraining_results.json'
if os.path.exists(pretraining_path):
    with open(pretraining_path, 'r') as f:
        pretraining_results = json.load(f)
    print("✅ Pre-training results loaded")

# Load fine-tuning results if available
finetuning_path = '/content/drive/MyDrive/piano_transformer/finetuning_results.json'
if os.path.exists(finetuning_path):
    with open(finetuning_path, 'r') as f:
        finetuning_results = json.load(f)
    print("✅ Fine-tuning results loaded")

# Check final model
final_model_path = '/content/drive/MyDrive/piano_transformer/checkpoints/ast_finetuned/best_model.pkl'
model_exists = os.path.exists(final_model_path)

# Generate comprehensive summary
final_summary = {
    'experiment_date': datetime.now().isoformat(),
    'pipeline': {
        'step_1': 'SSAST Pre-training on MAESTRO-v3',
        'step_2': 'AST Fine-tuning on PercePiano'
    },
    'datasets': {
        'pretraining': 'MAESTRO-v3 (200+ hours piano audio)',
        'finetuning': 'PercePiano (1202 performances, 19 dimensions)'
    },
    'model_architecture': {
        'base': 'Audio Spectrogram Transformer (AST)',
        'pretraining': 'Self-Supervised AST (SSAST) with MSPM',
        'parameters': '~85M (encoder) + task heads'
    },
    'results': {
        'pretraining': pretraining_results,
        'finetuning': finetuning_results,
        'final_model_saved': model_exists
    },
    'target_achieved': finetuning_results.get('best_correlation', 0) >= 0.7
}

# Save comprehensive results
final_results_path = '/content/drive/MyDrive/piano_transformer/FINAL_RESULTS.json'
with open(final_results_path, 'w') as f:
    json.dump(final_summary, f, indent=2)

# Display results
print("\n" + "="*60)
print("🎉 PIANO PERCEPTION TRANSFORMER - FINAL RESULTS")
print("="*60)

print(f"\n📅 Experiment completed: {datetime.now().strftime('%Y-%m-%d %H:%M')}")

print("\n🏗️  Architecture:")
print("   • Audio Spectrogram Transformer (AST) - 85M parameters")
print("   • Self-supervised pre-training with MSPM")
print("   • Multi-task regression for 19 perceptual dimensions")

print("\n📊 Training Pipeline:")
print("   1. ✅ SSAST pre-training on MAESTRO-v3 (TPU)")
print("   2. ✅ AST fine-tuning on PercePiano (GPU)")

if finetuning_results:
    best_corr = finetuning_results.get('best_correlation', 0)
    target_met = "✅" if best_corr >= 0.7 else "❌"
    print(f"\n🎯 Performance:")
    print(f"   • Best correlation: {best_corr:.3f}")
    print(f"   • Target >0.7: {target_met}")

print(f"\n💾 All results saved to: {final_results_path}")
print(f"📁 Model checkpoints in: /content/drive/MyDrive/piano_transformer/checkpoints/")

print("\n🎊 Training pipeline completed successfully!")
print("="*60)

# Show file structure
print("\n📁 Final file structure in Google Drive:")
!ls -la /content/drive/MyDrive/piano_transformer/