In [None]:
# =============================================================================
# MODEL EXPORT FOR WEBSITE INTEGRATION  
# =============================================================================

def save_model_for_website(model, model_name="piano_transcription_weights.pth"):
    """
    Save the trained model in a format suitable for the Flask website.

    The state dict is written twice (project Drive folder and current working
    directory), a short summary is printed, and a JSON config describing the
    model and the preprocessing constants is written next to the local copy.

    Args:
        model: The trained CRNN_OnsetsAndFrames model. The function reads
            ``state_dict()``, ``parameters()``, ``name`` and ``num_pitches``.
        model_name: Name for the saved model file

    Returns:
        True when every step succeeded; False when any step raised (the
        exception is printed rather than propagated).
    """
    try:
        # Save model state dict (recommended approach)
        model_path = f"/content/drive/MyDrive/APS360_Team_2_Project/models/{model_name}"
        torch.save(model.state_dict(), model_path)
        print(f"✅ Model saved to: {model_path}")

        # Also save to current directory for easy download
        local_path = model_name
        torch.save(model.state_dict(), local_path)
        print(f"✅ Model also saved locally: {local_path}")

        # Print model info
        total_params = sum(p.numel() for p in model.parameters())
        model_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)

        print(f"📊 Model Info:")
        print(f"   Total parameters: {total_params:,}")
        print(f"   Model size: {model_size_mb:.1f} MB")
        print(f"   Architecture: {model.name}")

        # Save model configuration as well
        config = {
            "model_name": model.name,
            "num_pitches": model.num_pitches,
            "total_parameters": total_params,
            "model_size_mb": model_size_mb,
            "training_info": {
                "sample_rate": SAMPLE_RATE,
                "n_mels": N_MELS,
                "window_size_seconds": WINDOW_SIZE_SECONDS,
                "min_midi_note": MIN_MIDI_NOTE,
                "max_midi_note": MAX_MIDI_NOTE
            }
        }

        # Derive the config name from the weights name. Only a trailing
        # ".pth" is stripped: with str.replace, a name that did not contain
        # ".pth" mapped to itself and the JSON dump overwrote the weights
        # file that had just been written.
        weights_suffix = '.pth'
        if model_name.endswith(weights_suffix):
            config_stem = model_name[:-len(weights_suffix)]
        else:
            config_stem = model_name
        config_path = f"{config_stem}_config.json"
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=2)
        print(f"✅ Model config saved: {config_path}")

        return True

    except Exception as e:
        print(f"❌ Failed to save model: {e}")
        return False

def download_model_from_colab():
    """
    Print step-by-step instructions for getting the exported weights out of
    Google Colab and into the website backend.
    """
    instructions = (
        "📥 To download the model to your local machine:",
        "1. Run the save_model_for_website() function above",
        "2. In Colab, go to Files tab on the left",
        "3. Find 'piano_transcription_weights.pth' and download it",
        "4. Place it in your website's backend/weights/ directory",
        "5. Restart your Flask server to load the new model",
        "",
        "💡 Alternative: Use Google Drive sync",
        "   The model is also saved to your Drive at:",
        "   /content/drive/MyDrive/APS360_Team_2_Project/models/",
    )
    # One print call; the joined text is line-for-line what ten prints emit
    print("\n".join(instructions))

# Usage reminder printed when this cell runs (call the functions after training)
print("\n".join((
    "📝 Model export functions ready!",
    "After training completes, run:",
    "   save_model_for_website(model)",
    "   download_model_from_colab()",
)))

## 7. Model Export for Website Integration

This section handles saving the trained model for use in the Flask website.

# APS360 Piano Transcription Project

## 1. Setup and Imports
All imports needed for the project are listed below:

In [None]:
# We setup the environment
try:
    import google.colab
    IN_COLAB = True
    from google.colab import drive
    drive.mount('/content/drive')
    !pip install librosa soundfile pretty_midi -q
    print("Running in Google Colab")
except ImportError:
    IN_COLAB = False
    print("Running locally")

# Standard imports
import os, json, warnings, zipfile, urllib.request, random
from pathlib import Path
from collections import defaultdict
from typing import Tuple, Dict, List

# Core libraries
import numpy as np
import pandas as pd
import librosa
import soundfile as sf
import pretty_midi
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm

warnings.filterwarnings('ignore')

# Pick the compute device and the matching DataLoader settings
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
    NUM_WORKERS = 2
    PIN_MEMORY = True
    print(f"GPU available: {torch.cuda.get_device_name(0)}")
else:
    DEVICE = torch.device('cpu')
    # os.cpu_count() returns None when the count cannot be determined;
    # treat that as a single core so the arithmetic below cannot raise.
    CPU_COUNT = os.cpu_count() or 1
    # Use half the cores (at most 4) as workers; no workers on <= 2 cores
    NUM_WORKERS = min(4, CPU_COUNT // 2) if CPU_COUNT > 2 else 0
    PIN_MEMORY = False
    print(f"Using CPU with {CPU_COUNT} cores")

print(f"Workers: {NUM_WORKERS}")

Mounted at /content/drive
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for pretty_midi (setup.py) ... [?25l[?25hdone
Running in Google Colab
Using CPU with 2 cores
Workers: 0


## 2. Configuration and Parameters
Modify the values in this cell to change the data configuration.

In [None]:
# ---- Audio front-end ----
SAMPLE_RATE = 16000  # Hz
N_MELS = 128
N_FFT = 2048
HOP_LENGTH = 512

# ---- Piano range (MIDI note numbers) ----
MIN_MIDI_NOTE, MAX_MIDI_NOTE = 21, 108  # A0 .. C8
N_PIANO_KEYS = MAX_MIDI_NOTE - MIN_MIDI_NOTE + 1  # 88 keys
MIDDLE_C = 60

# ---- Sliding window ----
WINDOW_SIZE_SECONDS = 10.0
MIN_RECORDING_LENGTH = 5.0
WINDOW_OVERLAP = 0.25
# Window length and hop between windows, both measured in spectrogram frames
WINDOW_SIZE_FRAMES = int(WINDOW_SIZE_SECONDS * SAMPLE_RATE / HOP_LENGTH)
STRIDE_FRAMES = int(WINDOW_SIZE_FRAMES * (1 - WINDOW_OVERLAP))

# ---- Train / validation / test split ----
TRAIN_RATIO, VAL_RATIO, TEST_RATIO = 0.7, 0.15, 0.15
RANDOM_SEED = 1000

# ---- Batch sizes: larger on GPU, smaller on CPU ----
BATCH_SIZE, PROCESSING_BATCH_SIZE = (16, 4) if DEVICE.type == 'cuda' else (8, 2)

# ---- Datasets ----
DATASETS_TO_USE = dict(
    maestro=True,             # Clean piano performances
    musdb_augmentation=True,  # Songs for mixed audio training
)

# ---- Augmentation ----
AUGMENTATION_RATIO = 0.5  # 50% will be augmented (clean + mixed)
MIX_VOLUME_RANGE = (0.3, 0.8)  # Volume range for mixing piano with accompaniment

# ---- Caching ----
CACHE_PROCESSED_DATA = True  # reuse previously cached data when available
REGENERATE_CACHE = False  # delete the old cached data and re-cache
CACHE_VERSION = "v2.0"
CACHE_SUFFIX = f"maestro_musdb_win{WINDOW_SIZE_SECONDS}s_mel{N_MELS}_{CACHE_VERSION}"

# ---- Summary of the active configuration ----
print("\n".join([
    f"Device: {DEVICE}",
    f"Audio: {SAMPLE_RATE}Hz, {N_MELS} mel bins",
    f"Piano: {N_PIANO_KEYS} keys (A0-C8)",
    f"Window: {WINDOW_SIZE_SECONDS}s ({WINDOW_SIZE_FRAMES} frames)",
    f"Batch size: {BATCH_SIZE}",
    f"Augmentation ratio: {AUGMENTATION_RATIO:.0%} mixed audio",
    f"Caching: {CACHE_PROCESSED_DATA}",
    f"Cache regen: {REGENERATE_CACHE}",
]))

Device: cpu
Audio: 16000Hz, 128 mel bins
Piano: 88 keys (A0-C8)
Window: 10.0s (312 frames)
Batch size: 8
Augmentation ratio: 50% mixed audio
Caching: True
Cache regen: False


## 3. Data Loading Functions
The functions in this section set up the project directories, download the raw data, and read the dataset metadata.

In [None]:
def setup_project_directories():
    """Create the project folder tree and return a name -> Path mapping.

    The root is the project's Google Drive folder when IN_COLAB is true and a
    local ./APS360_Team_2_Project folder otherwise. When both
    CACHE_PROCESSED_DATA and REGENERATE_CACHE are set, the cache folder for
    the current CACHE_SUFFIX is wiped and recreated empty.
    """
    base_path = (
        Path('/content/drive/MyDrive/APS360_Team_2_Project')
        if IN_COLAB
        else Path('./APS360_Team_2_Project')
    )
    data_root = base_path / 'data'

    directories = {
        'data_raw': data_root / 'raw',
        'data_processed': data_root / 'processed',
        'data_cached': data_root / 'cached' / CACHE_SUFFIX,
        'models': base_path / 'models',
        'logs': base_path / 'logs',
        'results': base_path / 'results',
    }

    for folder in directories.values():
        folder.mkdir(parents=True, exist_ok=True)

    # Optional cache reset: remove everything cached, then start empty
    wipe_cache = CACHE_PROCESSED_DATA and REGENERATE_CACHE
    cached_folder = directories['data_cached']
    if wipe_cache and cached_folder.exists():
        import shutil
        shutil.rmtree(cached_folder)
        cached_folder.mkdir(parents=True, exist_ok=True)
        print("Cache directory cleaned")

    print(f"Project directories created at: {base_path}")
    return directories

def get_cache_paths(split_name: str, cache_dir: Path) -> Dict[str, Path]:
    """Build the cache file locations for one split.

    Returns a dict with one '<split_name>_...' path per cached artifact plus
    the split-independent 'config' path, all inside cache_dir.
    """
    per_split_files = (
        ('metadata', 'metadata.json'),
        ('audio', 'audio_features.npy'),
        ('piano_roll', 'piano_roll.npy'),
        ('left_hand', 'left_hand.npy'),
        ('right_hand', 'right_hand.npy'),
    )
    paths = {key: cache_dir / f'{split_name}_{tail}' for key, tail in per_split_files}
    # The processing config is shared by every split
    paths['config'] = cache_dir / 'processing_config.json'
    return paths

def save_processing_config(cache_dir: Path):
    """Write the preprocessing settings to cache_dir/processing_config.json.

    The file records the constants the cached arrays were produced with so a
    later run can tell whether the cache still matches the notebook settings.
    N_FFT, HOP_LENGTH and MAX_MIDI_NOTE are recorded as well because they
    determine the spectrogram frames and the piano-roll width.

    Args:
        cache_dir: Existing directory that holds the cached arrays.
    """
    config = {
        'sample_rate': SAMPLE_RATE,
        'n_mels': N_MELS,
        'n_fft': N_FFT,
        'hop_length': HOP_LENGTH,
        'window_size_seconds': WINDOW_SIZE_SECONDS,
        'cache_version': CACHE_VERSION,
        'n_piano_keys': N_PIANO_KEYS,
        'min_midi_note': MIN_MIDI_NOTE,
        'max_midi_note': MAX_MIDI_NOTE,
        'datasets_used': DATASETS_TO_USE,
        'augmentation_ratio': AUGMENTATION_RATIO,
        'mix_volume_range': MIX_VOLUME_RANGE  # tuple is serialized as a JSON list
    }

    with open(cache_dir / 'processing_config.json', 'w') as f:
        json.dump(config, f, indent=2)

def check_cache_compatibility(cache_dir: Path) -> bool:
    """Check whether the cache in cache_dir matches the current settings.

    Args:
        cache_dir: Directory expected to contain processing_config.json.

    Returns:
        True only when the saved config exists, is readable, and agrees with
        the current constants. A missing, unreadable or malformed config
        yields False so the caller reprocesses the data.
    """
    config_path = cache_dir / 'processing_config.json'
    if not config_path.exists():
        return False

    try:
        with open(config_path, 'r') as f:
            saved_config = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError. Catching specific errors
        # (rather than a bare except) lets KeyboardInterrupt propagate.
        print(f"Ignoring unreadable cache config {config_path}: {e}")
        return False

    if not isinstance(saved_config, dict):
        return False

    # Keys that every saved config is required to match
    required = {
        'cache_version': CACHE_VERSION,
        'sample_rate': SAMPLE_RATE,
        'n_mels': N_MELS,
        'window_size_seconds': WINDOW_SIZE_SECONDS,
        'datasets_used': DATASETS_TO_USE,
        'augmentation_ratio': AUGMENTATION_RATIO,
    }
    if any(saved_config.get(key) != value for key, value in required.items()):
        return False

    # Keys compared only when present, so configs written without them
    # are still accepted.
    optional = {
        'n_fft': N_FFT,
        'hop_length': HOP_LENGTH,
    }
    return all(saved_config.get(key, value) == value for key, value in optional.items())

def download_maestro_dataset(data_dir: Path) -> Path:
    """Download and extract the MAESTRO v3.0.0 dataset if it is not present.

    Args:
        data_dir: Directory that holds (or will hold) the archive and the
            extracted 'maestro-v3.0.0' folder.

    Returns:
        Path of the 'maestro-v3.0.0' folder inside data_dir.

    Raises:
        Whatever urllib / zipfile raise on download or extraction failure.
    """
    maestro_dir = data_dir / 'maestro-v3.0.0'

    if maestro_dir.exists():
        print("MAESTRO dataset already exists")
        return maestro_dir

    print("Downloading MAESTRO dataset...")
    maestro_zip = data_dir / 'maestro-v3.0.0.zip'
    maestro_url = 'https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0.zip'

    if not maestro_zip.exists():
        # Download under a temporary name and rename on success, so an
        # interrupted transfer never leaves a truncated file at the final
        # path (which the exists() check above would accept next run).
        partial_zip = data_dir / 'maestro-v3.0.0.zip.part'
        try:
            urllib.request.urlretrieve(maestro_url, partial_zip)
            partial_zip.replace(maestro_zip)
        finally:
            # Only present here if the download or rename failed
            if partial_zip.exists():
                partial_zip.unlink()
        print("Download complete")

    print("Extracting MAESTRO dataset...")
    with zipfile.ZipFile(maestro_zip, 'r') as zip_ref:
        zip_ref.extractall(data_dir)
    print("Extraction complete")

    return maestro_dir

def find_musdb_dataset(data_dir: Path) -> Path:
    """Locate the MUSDB18-HQ folder under data_dir.

    Returns the folder when it exists and contains a 'train' subfolder;
    otherwise prints the expected location and returns None.
    """
    musdb_dir = data_dir / 'MUSDB18-HQ'
    train_split = musdb_dir / 'train'

    if not (musdb_dir.exists() and train_split.exists()):
        print(f"MUSDB18-HQ not found, expected location: {musdb_dir}")
        return None

    print("MUSDB18-HQ dataset found")
    return musdb_dir

def load_maestro_metadata(maestro_path: Path) -> List[Dict]:
    """Read the MAESTRO CSV index and return one dict per usable recording.

    Rows shorter than MIN_RECORDING_LENGTH are dropped, as are rows whose
    audio or MIDI file is missing under maestro_path.
    """
    csv_file = maestro_path / 'maestro-v3.0.0.csv'
    print(f"Loading MAESTRO metadata from: {csv_file.name}")

    index_table = pd.read_csv(csv_file)
    long_enough = index_table[index_table['duration'] >= MIN_RECORDING_LENGTH]

    metadata_list = []
    for _, row in long_enough.iterrows():
        audio_path = maestro_path / row['audio_filename']
        midi_path = maestro_path / row['midi_filename']

        # Skip entries whose files are not on disk
        if not (audio_path.exists() and midi_path.exists()):
            continue

        record = dict(
            audio_path=audio_path,
            midi_path=midi_path,
            duration=float(row['duration']),
            title=str(row['canonical_title']),
            composer=str(row['canonical_composer']),
            year=int(row['year']),
            split=str(row['split']),
            dataset='maestro',
        )
        metadata_list.append(record)

    print(f"Loaded {len(metadata_list)} MAESTRO recordings")
    return metadata_list

def load_musdb_stems(musdb_path: Path) -> List[Dict]:
    """Load MUSDB18-HQ stems from direct file structure for audio augmentation.

    Scans every track folder under musdb_path/'train'. A track is kept only
    when all five stem files exist and the mixture header can be read.

    Args:
        musdb_path: Root of the MUSDB18-HQ dataset.

    Returns:
        One dict per usable track with the track name/dir, the duration (in
        seconds) and sample rate read from 'mixture.wav', and the path of each
        stem. Empty list when the 'train' folder is missing.
    """
    print(f"Loading MUSDB18-HQ stems from: {musdb_path.name}")

    train_dir = musdb_path / 'train'
    if not train_dir.exists():
        print(f"Training directory not found: {train_dir}")
        return []

    # Each song comes with these stems
    required_stems = ['mixture.wav', 'drums.wav', 'bass.wav', 'other.wav', 'vocals.wav']
    stems_list = []

    # Scan all track directories in train/
    track_folders = [d for d in train_dir.iterdir() if d.is_dir()]

    print(f"Found {len(track_folders)} track folders")

    for track_dir in track_folders:
        track_name = track_dir.name
        stem_paths = {stem_name: track_dir / stem_name for stem_name in required_stems}

        # Report only the first missing stem, then skip the track
        first_missing = next(
            (stem_name for stem_name in required_stems if not stem_paths[stem_name].exists()),
            None,
        )
        if first_missing is not None:
            print(f"Missing {first_missing} in {track_name}")
            continue

        try:
            # Read only the header: no audio data is loaded here
            import soundfile as sf
            with sf.SoundFile(stem_paths['mixture.wav']) as f:
                # .frames is the supported frame count; len(SoundFile) is deprecated
                duration = f.frames / f.samplerate
                sample_rate = f.samplerate

            stems_list.append({
                'track_name': track_name,
                'track_dir': track_dir,
                'duration': duration,
                'sample_rate': sample_rate,
                'mixture_path': stem_paths['mixture.wav'],
                'drums_path': stem_paths['drums.wav'],
                'bass_path': stem_paths['bass.wav'],
                'other_path': stem_paths['other.wav'],
                'vocals_path': stem_paths['vocals.wav'],
                'dataset': 'musdb_stems'
            })

        except Exception as e:
            print(f"Error reading {track_name}: {e}")

    print(f"Loaded {len(stems_list)} MUSDB tracks for augmentation")

    return stems_list

def load_combined_metadata(data_dir: Path) -> Tuple[List[Dict], List[Dict]]:
    """Collect MAESTRO recording metadata and MUSDB stem info.

    Each source is loaded only when enabled in DATASETS_TO_USE; a disabled or
    missing source contributes an empty list.

    Returns:
        (maestro_metadata, musdb_stems)
    """
    # MAESTRO (pure piano)
    if DATASETS_TO_USE.get('maestro', False):
        print("\nLoading MAESTRO dataset...")
        maestro_metadata = load_maestro_metadata(download_maestro_dataset(data_dir))
    else:
        maestro_metadata = []

    # MUSDB stems (accompaniment used for augmentation)
    musdb_stems = []
    if DATASETS_TO_USE.get('musdb_augmentation', False):
        print("\nLoading MUSDB18-HQ stems...")
        musdb_dir = find_musdb_dataset(data_dir)
        if not musdb_dir:
            print("MUSDB dataset not found, continuing without augmentation")
        else:
            musdb_stems = load_musdb_stems(musdb_dir)

    print(f"\nCombined dataset summary:")
    print(f"   MAESTRO: {len(maestro_metadata):,} recordings")
    print(f"   MUSDB stems: {len(musdb_stems):,} tracks")

    return maestro_metadata, musdb_stems

## 4. Data Processing Functions
This section defines the audio and MIDI processing functions (standardization, feature extraction, mixing, and windowing) that we use when building the dataset.

In [None]:
def standardize_audio(audio_path: Path) -> Tuple[np.ndarray, int]:
    """Load a file as mono, bring it to SAMPLE_RATE, and level its volume.

    Returns (audio, SAMPLE_RATE), or (None, None) if anything fails (the
    error is printed).
    """
    try:
        # Mono at the file's native rate first; resample only when required
        waveform, native_sr = librosa.load(audio_path, sr=None, mono=True)
        if native_sr != SAMPLE_RATE:
            waveform = librosa.resample(waveform, orig_sr=native_sr, target_sr=SAMPLE_RATE)

        rms = np.sqrt(np.mean(waveform**2))
        if not rms > 1e-6:
            # (Near-)silent input is returned without rescaling
            return waveform, SAMPLE_RATE

        # Scale to an RMS of 0.1, then bound the peaks
        scaled = waveform * (0.1 / rms)
        return np.clip(scaled, -0.95, 0.95), SAMPLE_RATE
    except Exception as e:
        print(f"Audio error {audio_path}: {e}")
        return None, None

def extract_audio_features(audio: np.ndarray, sr: int, verbose: bool = False) -> np.ndarray:
    """Extract mel spectrogram features.

    The mel power spectrogram is converted to dB relative to its own maximum,
    standardized by the mean and standard deviation of the whole spectrogram,
    and transposed so time is the first axis.

    Args:
        audio: Audio samples passed to librosa as ``y``.
        sr: Sample rate of ``audio``.
        verbose: When True, print the resulting feature shape. Off by default
            so that processing many recordings does not flood the output.

    Returns:
        Array of shape (time, N_MELS), or None if extraction fails (the error
        is printed).
    """
    try:
        mel_spec = librosa.feature.melspectrogram(
            y=audio, sr=sr, n_mels=N_MELS, n_fft=N_FFT, hop_length=HOP_LENGTH
        )
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
        # The small epsilon guards against a zero std on constant input
        mel_spec_norm = (mel_spec_db - mel_spec_db.mean()) / (mel_spec_db.std() + 1e-8)
        features = mel_spec_norm.T  # (time, features)

        if verbose:
            print(f"🔍 DEBUG - Audio features shape: {features.shape} (should be [time, {N_MELS}])")

        return features
    except Exception as e:
        print(f"Feature extraction error: {e}")
        return None

def process_midi_to_piano_roll(midi_path: Path, verbose: bool = False) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Convert MIDI to piano roll with hand separation.

    Args:
        midi_path: MIDI file to read.
        verbose: When True, print the resulting piano-roll shape.

    Returns:
        (piano_roll, left_hand, right_hand), each a float32 array of shape
        (time, N_PIANO_KEYS) holding 0/1 note activity. ``left_hand`` keeps
        only the keys below MIDDLE_C and ``right_hand`` the keys from MIDDLE_C
        upward (a pitch-based split). Returns (None, None, None) on failure
        (the error is printed).
    """
    try:
        midi_data = pretty_midi.PrettyMIDI(str(midi_path))
        # Label frame rate must equal the spectrogram frame rate. True
        # division is required: 16000 // 512 would give 31 instead of 31.25,
        # and the labels would drift further from the audio frames the
        # longer the recording.
        fps = SAMPLE_RATE / HOP_LENGTH

        # Get piano roll and extract piano range
        piano_roll = midi_data.get_piano_roll(fs=fps)
        piano_roll = piano_roll[MIN_MIDI_NOTE:MAX_MIDI_NOTE+1, :]
        piano_roll = (piano_roll > 0).astype(np.float32)  # velocities -> on/off
        piano_roll = piano_roll.T  # (time, keys)

        if verbose:
            print(f"🔍 DEBUG - Piano roll shape: {piano_roll.shape} (should be [time, {N_PIANO_KEYS}])")

        # Hand separation
        middle_c_index = MIDDLE_C - MIN_MIDI_NOTE
        left_hand = np.zeros_like(piano_roll)
        right_hand = np.zeros_like(piano_roll)

        left_hand[:, :middle_c_index] = piano_roll[:, :middle_c_index]
        right_hand[:, middle_c_index:] = piano_roll[:, middle_c_index:]

        return piano_roll, left_hand, right_hand
    except Exception as e:
        print(f"MIDI error {midi_path}: {e}")
        return None, None, None

def extract_matching_audio_segment(audio_path: Path, start_time: float, duration: float, target_sr: int) -> np.ndarray:
    """Extract exact length segment from MUSDB stems with proper resampling.

    Args:
        audio_path: Audio file to read from.
        start_time: Offset into the file, in seconds.
        duration: Length of the segment, in seconds.
        target_sr: Sample rate the segment is loaded at.

    Returns:
        Mono segment of exactly ``int(duration * target_sr)`` samples, trimmed
        or zero-padded as needed and scaled to an RMS of 0.1 (clipped to
        ±0.95). If loading fails, the error is printed and float32 silence of
        the same length is returned.
    """
    target_samples = int(duration * target_sr)  # 160,000 for 10s at 16kHz

    try:
        # Load segment with exact timing; librosa resamples to target_sr, so
        # the returned rate carries no extra information.
        audio, _ = librosa.load(
            audio_path,
            sr=target_sr,
            offset=start_time,
            duration=duration,
            mono=True
        )

        # Ensure exact sample count
        if len(audio) > target_samples:
            audio = audio[:target_samples]  # Trim if too long
        elif len(audio) < target_samples:
            # Pad with silence if too short
            padding = target_samples - len(audio)
            audio = np.pad(audio, (0, padding), mode='constant', constant_values=0.0)

        # Normalize volume
        rms = np.sqrt(np.mean(audio**2))
        if rms > 1e-6:
            audio = audio * (0.1 / rms)
            audio = np.clip(audio, -0.95, 0.95)

        return audio

    except Exception as e:
        print(f"Segment extraction error {audio_path}: {e}")
        # float32 silence: np.zeros defaults to float64, which would upcast
        # every array this fallback is later summed with.
        return np.zeros(target_samples, dtype=np.float32)

def create_mixed_audio(maestro_audio: np.ndarray, musdb_stems_info: Dict) -> np.ndarray:
    """Overlay a MUSDB accompaniment on a MAESTRO piano recording.

    A random excerpt of the drums, bass and vocals stems, as long as the
    piano audio, is summed, scaled by a random factor from MIX_VOLUME_RANGE
    and added to the piano. On any error the piano audio is returned as-is.
    """
    try:
        target_sr = SAMPLE_RATE
        target_samples = len(maestro_audio)
        target_duration = target_samples / target_sr

        # Pick where in the MUSDB track the excerpt starts, keeping a margin
        # before the end of the track
        safe_margin = 2.0
        latest_start = max(0.0, musdb_stems_info['duration'] - target_duration - safe_margin)
        if latest_start > 0:
            start_time = random.uniform(0.0, latest_start)
        else:
            start_time = 0.0

        # Only drums, bass and vocals: the 'other' stem might contain piano
        accompaniment = None
        for path_key in ('drums_path', 'bass_path', 'vocals_path'):
            stem_audio = extract_matching_audio_segment(
                musdb_stems_info[path_key], start_time, target_duration, target_sr
            )
            accompaniment = stem_audio if accompaniment is None else accompaniment + stem_audio

        # Force the accompaniment to the piano's exact length
        surplus = len(accompaniment) - target_samples
        if surplus > 0:
            accompaniment = accompaniment[:target_samples]
        elif surplus < 0:
            accompaniment = np.pad(accompaniment, (0, -surplus), mode='constant', constant_values=0.0)

        # Accompaniment is attenuated so the piano stays in front
        mix_ratio = random.uniform(*MIX_VOLUME_RANGE)
        return np.clip(maestro_audio + (accompaniment * mix_ratio), -0.95, 0.95)

    except Exception as e:
        print(f"Audio mixing error: {e}")
        return maestro_audio  # Return original if mixing fails

def extract_sliding_windows(audio_features: np.ndarray, piano_roll: np.ndarray,
                          left_hand: np.ndarray, right_hand: np.ndarray,
                          metadata_item: Dict, musdb_stems: List[Dict] = None,
                          is_training: bool = True, verbose: bool = False) -> List[Dict]:
    """Cut aligned feature/label arrays into fixed-length windows.

    All four arrays are truncated to their common length, then windows of
    WINDOW_SIZE_FRAMES frames are taken at random start positions (training)
    or every STRIDE_FRAMES frames (validation/test).

    Args:
        audio_features: Per-frame audio features, time on the first axis.
        piano_roll: Per-frame labels, time on the first axis.
        left_hand: Left-hand labels, same layout as ``piano_roll``.
        right_hand: Right-hand labels, same layout as ``piano_roll``.
        metadata_item: Recording metadata; ``'title'`` is used for window ids.
        musdb_stems: MUSDB stem list; only its presence influences the
            ``window_type`` tag (see the note in the loop).
        is_training: Selects random vs. sequential window placement.
        verbose: When True, print shape diagnostics. Off by default so that
            processing many recordings does not flood the output.

    Returns:
        List of window dicts, or an empty list when the recording is shorter
        than one window.
    """
    min_length = min(len(audio_features), len(piano_roll), len(left_hand), len(right_hand))

    if verbose:
        print(f"🔍 DEBUG - extract_sliding_windows input shapes:")
        print(f"   Audio features: {audio_features.shape}")
        print(f"   Piano roll: {piano_roll.shape}")
        print(f"   Left hand: {left_hand.shape}")
        print(f"   Right hand: {right_hand.shape}")
        print(f"   Min length: {min_length}, Window size needed: {WINDOW_SIZE_FRAMES}")

    if min_length < WINDOW_SIZE_FRAMES:
        print(f"❌ WARNING: Min length {min_length} < window size {WINDOW_SIZE_FRAMES}, returning empty list")
        return []

    # Truncate to same length
    audio_features = audio_features[:min_length]
    piano_roll = piano_roll[:min_length]
    left_hand = left_hand[:min_length]
    right_hand = right_hand[:min_length]

    if verbose:
        print(f"🔍 DEBUG - After truncation:")
        print(f"   Audio features: {audio_features.shape}")
        print(f"   Piano roll: {piano_roll.shape}")

    windows = []

    if is_training:
        # Random windows for training: the count is capped both by the
        # number of stride positions and by the number of non-overlapping
        # windows that fit in the recording.
        max_start = min_length - WINDOW_SIZE_FRAMES
        num_windows = min(max_start // STRIDE_FRAMES + 1, max(1, min_length // WINDOW_SIZE_FRAMES))
        start_frames = np.random.randint(0, max_start + 1, size=num_windows)
    else:
        # Sequential windows for validation/test
        start_frames = np.arange(0, min_length - WINDOW_SIZE_FRAMES + 1, STRIDE_FRAMES)

    for i, start_frame in enumerate(start_frames):
        end_frame = start_frame + WINDOW_SIZE_FRAMES

        # Extract window (copies, so windows do not alias the full arrays)
        window_audio = audio_features[start_frame:end_frame].copy()
        window_piano = piano_roll[start_frame:end_frame].copy()
        window_left = left_hand[start_frame:end_frame].copy()
        window_right = right_hand[start_frame:end_frame].copy()

        if verbose and i == 0:
            print(f"🔍 DEBUG - First window shapes:")
            print(f"   Window audio: {window_audio.shape} (should be [{WINDOW_SIZE_FRAMES}, {N_MELS}])")
            print(f"   Window piano: {window_piano.shape} (should be [{WINDOW_SIZE_FRAMES}, {N_PIANO_KEYS}])")

        # NOTE(review): this function never modifies the audio. The tag below
        # is an independent per-window coin flip, so it does not tell whether
        # the features passed in were computed from mixed audio.
        use_augmentation = (
            is_training and
            musdb_stems and
            len(musdb_stems) > 0 and
            random.random() < AUGMENTATION_RATIO
        )
        window_type = 'mixed' if use_augmentation else 'clean'

        window = {
            'audio': window_audio,
            'piano_roll': window_piano,
            'left_hand': window_left,
            'right_hand': window_right,
            'metadata': metadata_item,
            'window_start_time': start_frame * HOP_LENGTH / SAMPLE_RATE,
            'window_id': f"{metadata_item['title'][:20]}_{start_frame}",
            'window_type': window_type
        }
        windows.append(window)

    if verbose:
        print(f"🔍 DEBUG - Created {len(windows)} windows from {metadata_item['title'][:30]}")
    return windows

def process_audio_midi_pair(metadata_item: Dict, musdb_stems: List[Dict] = None, is_training: bool = True,
                            verbose: bool = False) -> List[Dict]:
    """Process a single audio-MIDI pair into windows with optional MUSDB augmentation.

    Args:
        metadata_item: Recording metadata with ``'audio_path'``,
            ``'midi_path'`` and ``'title'``.
        musdb_stems: MUSDB stem list used for mixing; no mixing when empty.
        is_training: Mixing is only considered for training recordings.
        verbose: When True, print this function's diagnostics. Off by default
            so that processing many recordings does not flood the output.

    Returns:
        List of window dicts; empty when any processing step fails. Each
        window's ``'window_type'`` is ``'mixed'`` when the recording's audio
        was actually mixed with accompaniment and ``'clean'`` otherwise.
    """
    try:
        if verbose:
            print(f"🔍 DEBUG - Processing: {metadata_item['title'][:50]}")

        # Process audio
        audio, sr = standardize_audio(metadata_item['audio_path'])
        if audio is None:
            return []

        if verbose:
            print(f"🔍 DEBUG - Raw audio length: {len(audio)} samples ({len(audio)/SAMPLE_RATE:.1f}s)")

        # One decision per recording: the whole recording is mixed or not
        should_augment = (
            is_training and
            musdb_stems and
            len(musdb_stems) > 0 and
            random.random() < AUGMENTATION_RATIO
        )

        was_mixed = False
        if should_augment:
            # Mix with random MUSDB track
            random_musdb = random.choice(musdb_stems)
            mixed_audio = create_mixed_audio(audio, random_musdb)
            # create_mixed_audio hands back its input object when mixing fails
            was_mixed = mixed_audio is not audio
            audio = mixed_audio

        # Extract audio features from final audio
        audio_features = extract_audio_features(audio, sr)
        if audio_features is None:
            return []

        # Process MIDI, we keep the same labels regardless of audio mixing
        piano_roll, left_hand, right_hand = process_midi_to_piano_roll(metadata_item['midi_path'])
        if piano_roll is None:
            return []

        if verbose:
            print(f"🔍 DEBUG - Pre-window alignment check:")
            print(f"   Audio features: {audio_features.shape[0]} frames")
            print(f"   Piano roll: {piano_roll.shape[0]} frames")
            print(f"   Expected frames from audio length: {len(audio) // HOP_LENGTH}")

        # Extract windows
        windows = extract_sliding_windows(
            audio_features, piano_roll, left_hand, right_hand,
            metadata_item, musdb_stems, is_training
        )

        # extract_sliding_windows tags each window by an independent coin
        # flip; overwrite the tag with what happened to this recording's audio.
        actual_type = 'mixed' if was_mixed else 'clean'
        for window in windows:
            window['window_type'] = actual_type

        return windows

    except Exception as e:
        print(f"Error processing {metadata_item['title']}: {e}")
        return []

## 5. Dataset Creation
These are the dataset creation and loading functions: they create the train/val/test splits, the Dataset class, and the DataLoader.

In [None]:
def create_maestro_dataset_splits(metadata_list: List[Dict]) -> Dict[str, List[Dict]]:
    """Partition MAESTRO recordings into train/val/test by composer.

    Composers are shuffled with a fixed seed (RANDOM_SEED) and assigned whole
    to one split, so no composer appears in more than one split.

    Raises:
        ValueError: when metadata_list holds no 'maestro' entries.
    """
    np.random.seed(RANDOM_SEED)

    # Keep only MAESTRO entries
    maestro_items = list(filter(lambda entry: entry['dataset'] == 'maestro', metadata_list))
    if len(maestro_items) == 0:
        raise ValueError("No MAESTRO recordings found in metadata")

    print(f"Splitting MAESTRO by composer...")
    recordings_by_composer = defaultdict(list)
    for entry in maestro_items:
        recordings_by_composer[entry['composer']].append(entry)

    shuffled_composers = list(recordings_by_composer.keys())
    np.random.shuffle(shuffled_composers)

    # Boundaries in the shuffled composer list: train first, val next, test last
    composer_count = len(shuffled_composers)
    train_end = int(composer_count * TRAIN_RATIO)
    val_end = train_end + int(composer_count * VAL_RATIO)
    composers_per_split = {
        'train': shuffled_composers[:train_end],
        'val': shuffled_composers[train_end:val_end],
        'test': shuffled_composers[val_end:],
    }

    splits = {
        split_name: [
            recording
            for composer in split_composers
            for recording in recordings_by_composer[composer]
        ]
        for split_name, split_composers in composers_per_split.items()
    }

    # Report how the recordings were distributed
    print(f"MAESTRO dataset splits:")
    for split_name, split_data in splits.items():
        distinct_composers = {item['composer'] for item in split_data}
        print(f"   {split_name.upper()}: {len(split_data)} recordings from {len(distinct_composers)} composers")

    return splits

def save_windows_to_cache(windows: List[Dict], split_name: str, cache_dir: Path):
    """Save processed windows to cache.

    The per-window arrays are stacked and written as one .npy file each, and
    the per-window metadata (including ``window_type``) is written as JSON in
    the same order.

    Args:
        windows: Window dicts to persist. Nothing is written when empty.
        split_name: Split the windows belong to; used in the file names.
        cache_dir: Directory the cache files are written to.
    """
    if not windows:
        return

    cache_paths = get_cache_paths(split_name, cache_dir)
    print(f"Saving {len(windows):,} {split_name} windows to cache...")

    # Stack each field into a single array and write it out
    for array_key, window_key in (('audio', 'audio'), ('piano_roll', 'piano_roll'),
                                  ('left_hand', 'left_hand'), ('right_hand', 'right_hand')):
        np.save(cache_paths[array_key], np.stack([w[window_key] for w in windows]))

    # Save metadata
    metadata_for_json = []
    for w in windows:
        metadata = w['metadata'].copy()
        # Path objects are not JSON serializable
        metadata['audio_path'] = str(metadata['audio_path'])
        metadata['midi_path'] = str(metadata['midi_path'])

        window_info = {
            'metadata': metadata,
            'window_start_time': float(w['window_start_time']),
            'window_id': w['window_id'],
            # Persist the clean/mixed tag so it survives a cache round trip
            'window_type': w.get('window_type', 'unknown')
        }
        metadata_for_json.append(window_info)

    with open(cache_paths['metadata'], 'w') as f:
        json.dump(metadata_for_json, f, indent=2)

    print(f"{split_name} cache saved")

def load_windows_from_cache(split_name: str, cache_dir: Path) -> List[Dict]:
    """Load windows from cache.

    Args:
        split_name: Split to load.
        cache_dir: Directory holding the cache files.

    Returns:
        List of window dicts with the same keys as freshly processed windows.
        Empty list when any cache file is missing or loading fails (the error
        is printed), which makes the caller reprocess the data.
    """
    cache_paths = get_cache_paths(split_name, cache_dir)

    # The split is usable only when every one of its files is present
    required_files = ['metadata', 'audio', 'piano_roll', 'left_hand', 'right_hand']
    if not all(cache_paths[file_type].exists() for file_type in required_files):
        return []

    try:
        print(f"Loading {split_name} windows from cache...")

        # Arrays are memory-mapped (read from disk on access). To hold them
        # fully in RAM instead, drop the mmap_mode argument.
        audio_features = np.load(cache_paths['audio'], mmap_mode='r')
        piano_rolls = np.load(cache_paths['piano_roll'], mmap_mode='r')
        left_hands = np.load(cache_paths['left_hand'], mmap_mode='r')
        right_hands = np.load(cache_paths['right_hand'], mmap_mode='r')

        # Load metadata
        with open(cache_paths['metadata'], 'r') as f:
            metadata_list = json.load(f)

        # Reconstruct windows; entry i of the metadata describes row i of each array
        windows = []
        for i, window_info in enumerate(metadata_list):
            metadata = window_info['metadata'].copy()
            metadata['audio_path'] = Path(metadata['audio_path'])
            metadata['midi_path'] = Path(metadata['midi_path'])

            window = {
                'audio': audio_features[i],
                'piano_roll': piano_rolls[i],
                'left_hand': left_hands[i],
                'right_hand': right_hands[i],
                'metadata': metadata,
                'window_start_time': window_info['window_start_time'],
                'window_id': window_info['window_id'],
                # Cache entries without the tag still get the key, so cached
                # and freshly processed windows share one schema
                'window_type': window_info.get('window_type', 'unknown')
            }
            windows.append(window)

        print(f"Loaded {len(windows):,} {split_name} windows from cache")
        return windows
    except Exception as e:
        print(f"Error loading cache for {split_name}: {e}")
        return []

class PianoTranscriptionDataset(Dataset):
    """Windowed piano-transcription dataset (MAESTRO recordings, optional MUSDB stems).

    On construction the windows are taken from the on-disk cache when caching
    is enabled and the cache is reported compatible. Otherwise every recording
    in ``metadata_list`` is processed from scratch and, if caching is enabled,
    the result is written back to the cache.
    """

    def __init__(self, metadata_list: List[Dict], split_name: str = 'train', musdb_stems: List[Dict] = None, cache_dir: Path = None):
        self.metadata_list = metadata_list
        self.split_name = split_name
        self.is_training = split_name == 'train'
        self.musdb_stems = musdb_stems if musdb_stems else []

        print(f"Creating {split_name} dataset with {len(metadata_list)} MAESTRO recordings...")

        # Caching is only in play when the global switch is on AND a directory was given
        caching_enabled = CACHE_PROCESSED_DATA and cache_dir

        self.windows = []
        if caching_enabled and check_cache_compatibility(cache_dir):
            self.windows = load_windows_from_cache(split_name, cache_dir)

        # An empty list means "no usable cache": rebuild and (optionally) persist
        if not self.windows:
            print(f"Processing {len(metadata_list)} recordings...")
            self._process_from_scratch()

            if caching_enabled:
                save_windows_to_cache(self.windows, split_name, cache_dir)
                save_processing_config(cache_dir)

        print(f"{split_name} dataset: {len(self.windows):,} windows")

    def _process_from_scratch(self):
        """Build ``self.windows`` from raw recordings, counting the ones that fail.

        A recording counts as failed when processing raises or yields no windows.
        """
        failures = 0

        for record in tqdm(self.metadata_list, desc=f"Processing {self.split_name}"):
            try:
                extracted = process_audio_midi_pair(record, self.musdb_stems, self.is_training)

                if len(extracted) > 0:
                    self.windows.extend(extracted)
                else:
                    failures += 1
            except Exception as e:
                print(f"Error processing {record['title']}: {e}")
                failures += 1

        if failures > 0:
            total = len(self.metadata_list)
            success_rate = (total - failures) / total * 100
            print(f"{failures} recordings failed, {success_rate:.1f}% success rate")

    def __len__(self) -> int:
        """Number of windows available in this split."""
        return len(self.windows)

    def __getitem__(self, idx: int) -> Dict:
        """Return window ``idx`` with its array fields converted to float tensors."""
        window = self.windows[idx]

        # Array-valued fields become FloatTensors; everything else passes through
        sample = {
            field: torch.FloatTensor(window[field])
            for field in ('audio', 'piano_roll', 'left_hand', 'right_hand')
        }
        sample['metadata'] = window['metadata']
        sample['window_info'] = {
            'window_id': window['window_id'],
            'start_time': window['window_start_time'],
        }
        return sample

def collate_batch(batch: List[Dict]) -> Dict:
    """Merge per-window samples into one batch dict for the DataLoader.

    Tensor fields are stacked along a new leading batch axis; ``metadata`` and
    ``window_info`` are kept as plain per-sample lists. An empty batch yields
    ``None``.
    """
    if not batch:
        return None

    collated = {
        field: torch.stack([sample[field] for sample in batch])
        for field in ('audio', 'piano_roll', 'left_hand', 'right_hand')
    }
    # Non-tensor payloads cannot be stacked, so they stay as lists
    for field in ('metadata', 'window_info'):
        collated[field] = [sample[field] for sample in batch]
    return collated

def create_data_loaders(dataset_splits: Dict[str, List[Dict]], cache_dir: Path = None, musdb_stems: List[Dict] = None) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build the train / val / test DataLoaders from the MAESTRO splits.

    One ``PianoTranscriptionDataset`` is created per split (in train, val, test
    order). Only the training loader shuffles and drops the last incomplete
    batch; all loaders share BATCH_SIZE, NUM_WORKERS, PIN_MEMORY and the
    ``collate_batch`` collate function.

    Returns:
        (train_loader, val_loader, test_loader)
    """
    print("Creating datasets...")
    print(f"Batch size: {BATCH_SIZE}, Workers: {NUM_WORKERS}")

    # Dict comprehension preserves the train -> val -> test construction order
    datasets = {
        split: PianoTranscriptionDataset(
            dataset_splits[split],
            split_name=split,
            musdb_stems=musdb_stems,
            cache_dir=cache_dir,
        )
        for split in ('train', 'val', 'test')
    }

    def _make_loader(split, **overrides):
        # Settings common to every split live here; callers add the differences
        return DataLoader(
            datasets[split],
            batch_size=BATCH_SIZE,
            collate_fn=collate_batch,
            num_workers=NUM_WORKERS,
            pin_memory=PIN_MEMORY,
            **overrides,
        )

    train_loader = _make_loader('train', shuffle=True, drop_last=True)
    val_loader = _make_loader('val', shuffle=False)
    test_loader = _make_loader('test', shuffle=False)

    print("Data loaders created:")
    for label, loader in (('Train', train_loader), ('Val', val_loader), ('Test', test_loader)):
        print(f"{label}:{len(loader)} batches ({len(loader.dataset):,} windows)")

    return train_loader, val_loader, test_loader

## 6. Execute Data Pipeline
This cell runs the complete data pipeline (directory setup, metadata loading, dataset splits, and data loaders) and creates everything that's needed for training.

In [None]:
# Set up the project directory layout
directories = setup_project_directories()
print("Data loading functions ready")

# Load MAESTRO recording metadata together with the MUSDB stems
print("\nLoading MAESTRO and MUSDB datasets")
maestro_metadata, musdb_stems = load_combined_metadata(directories['data_raw'])

# Nothing downstream can work without at least one MAESTRO recording
if not len(maestro_metadata):
    print("No datasets loaded successfully!")
    raise RuntimeError("No valid datasets found")

# Partition the recordings into train / val / test
print("\nCreating MAESTRO dataset splits")
dataset_splits = create_maestro_dataset_splits(maestro_metadata)

# Build the loaders; the cache directory is only passed on when caching is enabled
print("\nCreating data loaders")
cache_dir = None
if CACHE_PROCESSED_DATA:
    cache_dir = directories['data_cached']

train_loader, val_loader, test_loader = create_data_loaders(
    dataset_splits,
    cache_dir=cache_dir,
    musdb_stems=musdb_stems,
)

print("\n\nData loaded successfully!")

# Tally recordings per source dataset for the summary below
dataset_counts = defaultdict(int)
for item in maestro_metadata:
    dataset_counts[item['dataset']] += 1

print("\nFinal Summary:")
for dataset, count in dataset_counts.items():
    print(f"{dataset.upper()}: {count:,} recordings")
print(f"MUSDB stems available: {len(musdb_stems)} tracks for augmentation")

print("\nVariables: train_loader, val_loader, test_loader")

Project directories created at: /content/drive/MyDrive/APS360_Team_2_Project
Data loading functions ready

Loading MAESTRO and MUSDB datasets

Loading MAESTRO dataset...
MAESTRO dataset already exists
Loading MAESTRO metadata from: maestro-v3.0.0.csv
Loaded 1276 MAESTRO recordings

Loading MUSDB18-HQ stems...
MUSDB18-HQ dataset found
Loading MUSDB18-HQ stems from: MUSDB18-HQ
Found 100 track folders
Loaded 100 MUSDB tracks for augmentation

Combined dataset summary:
   MAESTRO: 1,276 recordings
   MUSDB stems: 100 tracks

Creating MAESTRO dataset splits
Splitting MAESTRO by composer...
MAESTRO dataset splits:
   TRAIN: 927 recordings from 42 composers
   VAL: 211 recordings from 9 composers
   TEST: 138 recordings from 9 composers

Creating data loaders
Creating datasets...
Batch size: 8, Workers: 0
Creating train dataset with 927 MAESTRO recordings...
Loading train windows from cache...
Loaded 49,316 train windows from cache
train dataset: 49,316 windows
Creating val dataset with 211 M

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CRNN_OnsetsAndFrames(nn.Module):
    """
    Convolutional-recurrent frame classifier for piano transcription.

    Input  (batch):  audio features shaped [B, n_mels, T]      (mel bins x T frames)
    Output (batch):  logits shaped         [B, num_pitches, T] (piano keys x T frames)

    Notes
    -----
    - We keep time resolution by pooling only along the frequency axis (2,1).
    - We return *logits* (NO sigmoid) so that BCEWithLogitsLoss can be used correctly.
    - Only frame-wise logits are produced; there is no separate onset head.

    Args
    ----
    num_pitches:      number of output classes per frame (default 88).
    lstm_hidden_size: hidden size of each LSTM direction.
    cnn_out_channels: channel count of the last convolution.
    n_mels:           number of mel bins on dim 1 of the input (default 128).
                      Must be at least 8 so that three halvings leave >= 1 bin.
    """
    def __init__(self, num_pitches: int = 88, lstm_hidden_size: int = 256, cnn_out_channels: int = 128, n_mels: int = 128):
        super().__init__()
        if n_mels < 8:
            raise ValueError(f"n_mels must be >= 8 (three 2x frequency pools), got {n_mels}")

        self.name = "crnn_onsets_frames"
        self.num_pitches = num_pitches
        self.n_mels = n_mels

        # Convolutional feature extractor (freq pooling only)
        self.conv_block = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(3,3), padding=(1,1)),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2,1)),  # n_mels -> n_mels // 2

            nn.Conv2d(32, 64, kernel_size=(3,3), padding=(1,1)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2,1)),  # -> n_mels // 4

            nn.Conv2d(64, cnn_out_channels, kernel_size=(3,3), padding=(1,1)),
            nn.BatchNorm2d(cnn_out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2,1)),  # -> n_mels // 8
        )

        # Three floor-halvings are equivalent to one floor division by 8
        # (128 -> 64 -> 32 -> 16 for the default).
        freq_out = n_mels // 8
        lstm_input_size = cnn_out_channels * freq_out

        self.lstm = nn.LSTM(
            input_size=lstm_input_size,
            hidden_size=lstm_hidden_size,
            num_layers=2,
            bidirectional=True,
            batch_first=False,
        )

        # Residual projection to match BiLSTM output size (2 * hidden)
        self.res_fc = nn.Linear(lstm_input_size, 2 * lstm_hidden_size)

        # Frame-wise classifier to the piano keys
        self.fc_frame = nn.Linear(2 * lstm_hidden_size, num_pitches)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: [B, n_mels, T]  -> returns logits [B, num_pitches, T]

        Raises:
            ValueError: if ``x`` is not 3-D or its dim 1 differs from ``n_mels``.
        """
        if x.dim() != 3:
            raise ValueError(f"Expected input [B, {self.n_mels}, T], got {tuple(x.shape)}")
        # Fail here with a readable message rather than inside the LSTM with a
        # feature-size error when the caller forgot to transpose [B, T, F] input.
        if x.shape[1] != self.n_mels:
            raise ValueError(
                f"Expected {self.n_mels} mel bins on dim 1 ([B, {self.n_mels}, T]), got {tuple(x.shape)}"
            )

        # Add channel dim for convs: [B, 1, n_mels, T]
        x = x.unsqueeze(1)

        # [B, C, F', T] with F' = n_mels // 8
        x = self.conv_block(x)
        B, C, Freq, T = x.shape

        # Prepare sequence for LSTM: [T, B, C*F']
        x_seq = x.permute(3, 0, 1, 2).contiguous().view(T, B, C * Freq)

        # BiLSTM
        lstm_out, _ = self.lstm(x_seq)  # [T, B, 2*hidden]

        # Residual connection (project x_seq to BiLSTM size and add)
        lstm_out = lstm_out + self.res_fc(x_seq)

        # Frame-wise logits -> [T, B, num_pitches] then -> [B, num_pitches, T]
        frame_logits = self.fc_frame(lstm_out)
        return frame_logits.permute(1, 2, 0)

In [None]:
import numpy as np
import time
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

def get_model_name(name, batch_size, learning_rate, epoch):
    """Build the checkpoint file name for one model / hyper-parameter / epoch combination."""
    template = "model_{}_bs{}_lr{}_epoch{}"
    return template.format(name, batch_size, learning_rate, epoch)

def evaluate(model, data_loader, criterion, verbose=False):
    """
    Compute the average loss and element-wise error of `model` over `data_loader`.

    The model is switched to eval mode (and left there) and gradients are
    disabled. Predictions are `sigmoid(logits) > 0.5`; the error is the
    fraction of target cells whose thresholded prediction differs.

    Args:
        model: module mapping audio features to logits.
        data_loader: iterable of dict batches holding "audio" and "piano_roll"
            tensors (the layout produced by `collate_batch`).
        criterion: loss callable taking (logits, target).
        verbose: when True, print tensor shapes for the first batch only.
            Off by default so that per-epoch validation does not print
            several debug lines for every batch.

    Returns:
        (avg_loss, error): mean of the per-batch losses and the misclassified
        fraction. Both are 0.0 when the loader yields no batches.
    """
    # Fall back to CPU for a parameter-free model instead of raising StopIteration
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    total_loss = 0.0
    total_incorrect = 0
    total_elements = 0
    batch_count = 0
    mismatch_reported = False

    with torch.no_grad():
        for batch_idx, batch in enumerate(data_loader):
            show_debug = verbose and batch_idx == 0

            # Expect a dict from collate_fn
            audio = batch["audio"].to(device)            # [B, T, 128]
            target = batch["piano_roll"].to(device)      # [B, T, 88]

            if show_debug:
                print(f"🔍 DEBUG - Input shapes: audio={audio.shape}, target={target.shape}")

            # Transpose audio for model: [B, T, 128] -> [B, 128, T]
            audio = audio.transpose(1, 2)

            # Align target to logits shape [B, 88, T]; cast to float like the
            # training loop does so the loss sees the same dtype in both places.
            target = target.permute(0, 2, 1).contiguous().float()

            if show_debug:
                print(f"🔍 DEBUG - After transpose: audio={audio.shape}, target={target.shape}")

            logits = model(audio)                        # [B, 88, T]
            if show_debug:
                print(f"🔍 DEBUG - Model output: {logits.shape}")

            # Ensure shapes match before loss calculation
            if logits.shape != target.shape:
                # Truncate both to the shorter time axis; warn once per call
                # because a mismatch points at a data-preparation problem.
                min_time = min(logits.shape[2], target.shape[2])
                if not mismatch_reported:
                    print(f"❌ SHAPE MISMATCH: logits={logits.shape}, target={target.shape}")
                    print(f"🔧 Truncating time axis to {min_time} frames (reported once per evaluation)")
                    mismatch_reported = True
                logits = logits[:, :, :min_time]
                target = target[:, :, :min_time]

            loss = criterion(logits, target)
            total_loss += loss.item()
            batch_count += 1

            # Metrics: thresholded predictions vs targets
            probs = torch.sigmoid(logits)
            preds = (probs > 0.5)
            incorrect = (preds != target.bool())
            total_incorrect += incorrect.sum().item()
            total_elements += target.numel()

    # max(1, ...) keeps an empty loader from dividing by zero
    error = total_incorrect / max(1, total_elements)
    avg_loss = total_loss / max(1, batch_count)
    return avg_loss, error

def plot_training_curves(save_path, num_epochs):
    """Plot training vs. validation loss and error from the saved metric CSVs.

    Reads ``{save_path}_train_loss.csv``, ``_train_err.csv``, ``_val_loss.csv``
    and ``_val_err.csv`` and shows two figures (loss, then error) with epochs
    numbered 1..num_epochs on the x-axis.
    """
    # Load in the same order the files are written: train first, then val
    curves = {
        suffix: np.loadtxt(f"{save_path}_{suffix}.csv")
        for suffix in ("train_loss", "train_err", "val_loss", "val_err")
    }

    epochs = np.arange(1, num_epochs + 1)

    # One figure per metric; both share the same layout
    for metric, label in (("loss", "Loss"), ("err", "Error")):
        plt.figure()
        plt.plot(epochs, curves[f"train_{metric}"], label="Training")
        plt.plot(epochs, curves[f"val_{metric}"], label="Validation")
        plt.title(f"Training and Validation {label}")
        plt.xlabel("Epoch")
        plt.ylabel(label)
        plt.legend()
        plt.grid(True)
        plt.show()

def train(model, train_loader, val_loader, batch_size=4, learning_rate=1e-3, num_epochs=30, pos_weight_value=5.0):
    """
    Train `model` with positively-weighted BCE-with-logits, validating every epoch.

    Each epoch runs one pass over `train_loader`, calls `evaluate` on
    `val_loader`, prints a summary line and saves a state-dict checkpoint
    named by `get_model_name`. After the last epoch the four metric arrays are
    written to CSV files and plotted with `plot_training_curves`.

    Args:
        model: module exposing `.name` and `.num_pitches` and returning logits
            shaped [B, num_pitches, T].
        train_loader: iterable of dict batches with "audio" and "piano_roll".
        val_loader: same layout, used for per-epoch validation.
        batch_size: only used to build checkpoint / metric file names.
        learning_rate: Adam learning rate (also part of the file names).
        num_epochs: number of passes over `train_loader`.
        pos_weight_value: weight applied to positive targets, identical for
            every pitch.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # BCEWithLogitsLoss expects LOGITS; model.forward returns logits.
    # Use a per-class pos_weight to handle class imbalance across the keys.
    # Logits/targets are [B, num_pitches, T], so the weight has to be shaped
    # [num_pitches, 1] to line up with the pitch axis. A flat [num_pitches]
    # vector would be broadcast against the trailing time axis instead, which
    # fails whenever T != num_pitches.
    pos_weight = torch.full((model.num_pitches, 1), float(pos_weight_value), device=device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Tracking
    train_losses = np.zeros(num_epochs, dtype=np.float32)
    train_errs   = np.zeros(num_epochs, dtype=np.float32)
    val_losses   = np.zeros(num_epochs, dtype=np.float32)
    val_errs     = np.zeros(num_epochs, dtype=np.float32)

    start_time = time.time()

    for epoch in range(num_epochs):
        # evaluate() leaves the model in eval mode, so re-enable training each epoch
        model.train()
        running_loss = 0.0
        running_incorrect = 0
        running_total = 0
        batch_count = 0

        for batch_idx, batch in enumerate(train_loader):
            audio = batch["audio"].to(device)           # [B, T, 128]
            target = batch["piano_roll"].to(device)     # [B, T, 88]

            # DEBUG: Print shapes for first batch of first epoch
            if epoch == 0 and batch_idx == 0:
                print(f"🔍 DEBUG - Training batch shapes: audio={audio.shape}, target={target.shape}")

            # Transpose for model: [B, T, 128] -> [B, 128, T]
            audio = audio.transpose(1, 2)
            target = target.permute(0, 2, 1).contiguous()  # -> [B, 88, T]

            if epoch == 0 and batch_idx == 0:
                print(f"🔍 DEBUG - After transpose: audio={audio.shape}, target={target.shape}")

            optimizer.zero_grad()
            logits = model(audio)                        # [B, 88, T]

            if epoch == 0 and batch_idx == 0:
                print(f"🔍 DEBUG - Model output: {logits.shape}")

            # Ensure shapes match before loss calculation
            if logits.shape != target.shape:
                print(f"❌ SHAPE MISMATCH in batch {batch_idx}: logits={logits.shape}, target={target.shape}")
                # Fix by truncating to min size
                min_time = min(logits.shape[2], target.shape[2])
                logits = logits[:, :, :min_time]
                target = target[:, :, :min_time]
                print(f"🔧 FIXED to: logits={logits.shape}, target={target.shape}")

            loss = criterion(logits, target.float())
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            batch_count += 1

            # Training error uses the same 0.5 threshold as evaluate()
            with torch.no_grad():
                probs = torch.sigmoid(logits)
                preds = (probs > 0.5)
                incorrect = (preds != target.bool())
                running_incorrect += incorrect.sum().item()
                running_total += target.numel()

        # Epoch metrics (max(1, ...) guards an empty loader)
        train_losses[epoch] = running_loss / max(1, batch_count)
        train_errs[epoch]   = (running_incorrect / max(1, running_total))

        # Validation
        val_loss, val_err = evaluate(model, val_loader, criterion)
        val_losses[epoch] = val_loss
        val_errs[epoch]   = val_err

        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Train Loss: {train_losses[epoch]:.4f}  Err: {train_errs[epoch]:.4f} | "
              f"Val Loss: {val_losses[epoch]:.4f}  Err: {val_errs[epoch]:.4f}")

        # Save model checkpoint each epoch
        model_path = get_model_name(model.name, batch_size, learning_rate, epoch)
        torch.save(model.state_dict(), model_path)

    elapsed = time.time() - start_time
    print(f"\nTraining completed in {elapsed:.1f}s")

    # Save metrics and plot
    save_path = f"model_{model.name}_bs{batch_size}_lr{learning_rate}"
    np.savetxt(f"{save_path}_train_loss.csv", train_losses)
    np.savetxt(f"{save_path}_train_err.csv",  train_errs)
    np.savetxt(f"{save_path}_val_loss.csv",   val_losses)
    np.savetxt(f"{save_path}_val_err.csv",    val_errs)

    plot_training_curves(save_path, num_epochs)

In [None]:
# --- Overfit-on-a-tiny-slice demo ---
# This cell builds a very small, deterministic subset of the training set
# (the first `overfit_size` windows, capped at MAX_EXAMPLES) and wraps it in a
# DataLoader with a small batch size, so each epoch is only a few steps.
# Later cells train on this loader to intentionally overfit.

import torch
from torch.utils.data import DataLoader, Subset

# Reproducibility & deterministic ops
torch.manual_seed(42)
try:
    torch.use_deterministic_algorithms(True)
except Exception as exc:
    # Best-effort only: keep going, but say why deterministic mode is off
    # instead of silently ignoring the failure.
    print(f"⚠️ Could not enable deterministic algorithms: {exc}")
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# =============================================================================
# STEP 1: SETUP OVERFIT DATASET
# =============================================================================
print("="*80)
print("STEP 1: SETTING UP OVERFIT DATASET")
print("="*80)

# Build the training dataset
train_dataset = PianoTranscriptionDataset(
    dataset_splits['train'],
    split_name='train',
    musdb_stems=musdb_stems,
    cache_dir=cache_dir
)

# Pick a tiny, contiguous slice from the start for stability
OVERFIT_FRACTION = 0.05     # target 5% of train (before the MAX_EXAMPLES cap)
MIN_EXAMPLES = 8            # ensure we have enough examples to learn
MAX_EXAMPLES = 16           # Reduced for memory - was 64
# Clamp to [MIN_EXAMPLES, MAX_EXAMPLES] and never exceed the dataset itself
overfit_size = max(MIN_EXAMPLES, int(len(train_dataset) * OVERFIT_FRACTION))
overfit_size = min(overfit_size, MAX_EXAMPLES, len(train_dataset))

overfit_indices = list(range(overfit_size))
overfit_subset = Subset(train_dataset, overfit_indices)

# Use smaller batch size to prevent memory crashes.
# NOTE: this rebinds the notebook-level BATCH_SIZE that create_data_loaders
# reads, so loaders created after this cell will also use this value.
BATCH_SIZE = 4  # Much smaller batch size for CPU
overfit_loader = DataLoader(
    overfit_subset,
    batch_size=BATCH_SIZE,
    shuffle=False,      # deterministic order
    num_workers=0,
    drop_last=False,
    collate_fn=collate_batch  # Use the custom collate function
)

print(f"✅ Dataset: {len(train_dataset):,} total windows")
print(f"✅ Overfit subset size: {len(overfit_subset)}")
print(f"✅ Batch size: {BATCH_SIZE}")
print(f"✅ DataLoader using custom collate: {overfit_loader.collate_fn == collate_batch}")
print("="*80)

Creating train dataset with 927 MAESTRO recordings...
Loading train windows from cache...
Loaded 49,316 train windows from cache
train dataset: 49,316 windows
Overfit subset size: 64
Batch size set to:   64 (1 step per epoch)


TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'pathlib.PosixPath'>

In [None]:
# =============================================================================
# STEP 2: VERIFY DATALOADER AND SHAPES
# =============================================================================
# Smoke test: pull one batch from the overfit loader, push it through a freshly
# initialised model and compute an unweighted BCE loss. Every stage re-raises
# on failure so the notebook stops at the first broken step.
print("=" * 80, "STEP 2: VERIFYING DATALOADER AND TENSOR SHAPES", "=" * 80, sep="\n")

# --- 2.1: fetch a single batch ---
print("2.1 Testing DataLoader...")
try:
    test_batch = next(iter(overfit_loader))
    print(f"✅ DataLoader works! Batch keys: {test_batch.keys()}")

    # Shapes exactly as delivered by the collate function
    audio_shape = test_batch['audio'].shape
    piano_shape = test_batch['piano_roll'].shape
    print(
        f"🔍 DEBUG - Audio shape from DataLoader: {audio_shape}",
        f"🔍 DEBUG - Piano roll shape from DataLoader: {piano_shape}",
        "🔍 DEBUG - Expected: Audio [4, 312, 128], Piano [4, 312, 88]",
        sep="\n",
    )

    # Features and labels have to agree on the time axis (dim 1 at this point)
    if audio_shape[1] == piano_shape[1]:
        print(f"✅ Time dimensions match: {audio_shape[1]}")
    else:
        print(f"❌ DIMENSION MISMATCH: Audio time={audio_shape[1]}, Piano time={piano_shape[1]}")
        print("❌ This is likely the root cause of the tensor size error!")

except Exception as exc:
    print(f"❌ DataLoader failed: {exc}")
    print("STOP: DataLoader issue - check collate function")
    raise

# --- 2.2: move the batch to the device and reorder axes for the model ---
print("\n2.2 Testing model with corrected input...")
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = CRNN_OnsetsAndFrames().to(device)

test_audio = test_batch["audio"].to(device)           # [B, T, 128]
test_target = test_batch["piano_roll"].to(device)     # [B, T, 88]

print(f"🔍 DEBUG - Before transpose - Audio: {test_audio.shape}, Target: {test_target.shape}")

# The model consumes [B, 128, T] and emits [B, 88, T], so both tensors are
# flipped from time-major to feature-major layout.
test_audio_transposed = test_audio.transpose(1, 2)
test_target_transposed = test_target.permute(0, 2, 1).contiguous()

print(
    f"🔍 DEBUG - After transpose - Audio: {test_audio_transposed.shape}, Target: {test_target_transposed.shape}",
    "✅ Expected: Audio [4, 128, 312], Target [4, 88, 312]",
    sep="\n",
)

# After the flip the time axis is dim 2 on both tensors
if test_audio_transposed.shape[2] == test_target_transposed.shape[2]:
    print(f"✅ Transposed time dimensions match: {test_audio_transposed.shape[2]}")
else:
    print(f"❌ TRANSPOSED DIMENSION MISMATCH: Audio time={test_audio_transposed.shape[2]}, Target time={test_target_transposed.shape[2]}")

# --- 2.3: forward pass without gradients ---
print("\n2.3 Testing model forward pass...")
try:
    with torch.no_grad():
        test_logits = model(test_audio_transposed)
        print(
            f"🔍 DEBUG - Model output shape: {test_logits.shape}",
            "🔍 DEBUG - Expected output shape: [4, 88, 312]",
            sep="\n",
        )

        # Logits must line up element-for-element with the targets
        if test_logits.shape == test_target_transposed.shape:
            print("✅ Model output matches target shape!")
        else:
            print(f"❌ MODEL OUTPUT MISMATCH: Model={test_logits.shape}, Target={test_target_transposed.shape}")

except Exception as exc:
    print(f"❌ Model forward pass failed: {exc}")
    print("STOP: Model architecture issue")
    raise

# --- 2.4: loss on the verified tensors ---
print("\n2.4 Testing loss calculation...")
try:
    criterion = torch.nn.BCEWithLogitsLoss()
    test_loss = criterion(test_logits, test_target_transposed.float())
    print(f"✅ Loss calculation works! Loss: {test_loss.item():.4f}")
except Exception as exc:
    print(f"❌ Loss calculation failed: {exc}")
    print(f"🔍 DEBUG - Logits shape: {test_logits.shape}")
    print(f"🔍 DEBUG - Target shape: {test_target_transposed.shape}")
    print("STOP: Loss calculation issue")
    raise

print("\n✅ ALL VERIFICATION TESTS PASSED!", "=" * 80, sep="\n")

In [None]:
# =============================================================================
# STEP 3: RUN TRAINING WITH DIMENSION FIX
# =============================================================================
print("=" * 80, "STEP 3: STARTING TRAINING WITH DIMENSION FIX", "=" * 80, sep="\n")

# Hyper-parameters for the short overfit run
NUM_EPOCHS = 5  # Short training for testing
LR = 1e-3

print(
    "Training parameters:",
    f"  • Epochs: {NUM_EPOCHS}",
    f"  • Learning rate: {LR}",
    sep="\n",
)
# BATCH_SIZE and device are defined by earlier cells
print("  • Batch size: {}".format(BATCH_SIZE))
print("  • Device: {}".format(device))
print("=" * 80)

# CUSTOM TRAINING FUNCTION WITH DIMENSION FIX
def train_with_fix(model, train_loader, val_loader, num_epochs=5, learning_rate=None):
    """
    Minimal training loop that tolerates a time-axis mismatch between logits and targets.

    Uses unweighted BCE-with-logits and Adam. Before the loss is computed both
    tensors are cut to the shorter of their two time axes (dim 2).

    Args:
        model: module mapping [B, F, T] features to [B, P, T] logits.
        train_loader: iterable of dict batches with "audio" and "piano_roll"
            tensors in time-major layout ([B, T, ...]).
        val_loader: accepted for signature compatibility; not used here.
        num_epochs: number of passes over `train_loader`.
        learning_rate: Adam learning rate. When None, the notebook-level `LR`
            constant is used.

    Returns:
        List with the average training loss of each epoch.
    """
    if learning_rate is None:
        # Fall back to the notebook constant defined next to NUM_EPOCHS
        learning_rate = LR

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    epoch_losses = []

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        batch_count = 0

        for batch_idx, batch in enumerate(train_loader):
            # Shape diagnostics are only printed once, for the very first batch
            first_batch = epoch == 0 and batch_idx == 0

            audio = batch["audio"].to(device)           # [B, T, 128]
            target = batch["piano_roll"].to(device)     # [B, T, 88]

            if first_batch:
                print(f"🔍 Original shapes: audio={audio.shape}, target={target.shape}")

            # Transpose to model format
            audio = audio.transpose(1, 2)               # [B, 128, T]
            target = target.permute(0, 2, 1).contiguous()  # [B, 88, T]

            if first_batch:
                print(f"🔍 After transpose: audio={audio.shape}, target={target.shape}")

            optimizer.zero_grad()
            logits = model(audio)                       # [B, 88, T]

            if first_batch:
                print(f"🔍 Model output: {logits.shape}")

            # Align the time axes: slicing is a no-op when they already match
            min_time = min(logits.shape[2], target.shape[2])
            if logits.shape[2] != target.shape[2]:
                print(f"🔧 Fixing dimension mismatch: {logits.shape[2]} vs {target.shape[2]} → {min_time}")

            logits = logits[:, :, :min_time]
            target = target[:, :, :min_time]

            if first_batch:
                print(f"🔍 Final shapes: logits={logits.shape}, target={target.shape}")

            loss = criterion(logits, target.float())
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            batch_count += 1

            # Print progress every 100 batches
            if batch_idx % 100 == 0:
                print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}, Loss: {loss.item():.4f}")

        # max(1, ...) guards an empty loader
        avg_loss = total_loss / max(1, batch_count)
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch+1}/{num_epochs} completed - Average Loss: {avg_loss:.4f}")

    print("🎉 Training completed successfully!")
    return epoch_losses

# Overfit run: the same tiny loader is passed for both training and validation
train_with_fix(
    model=model,
    train_loader=overfit_loader,
    val_loader=overfit_loader,
    num_epochs=NUM_EPOCHS,
)

print("=" * 80, "🎉 TRAINING COMPLETED SUCCESSFULLY!", "=" * 80, sep="\n")

In [None]:
# =============================================================================
# STEP 4: SAVE THE TRAINED MODEL FOR WEBSITE INTEGRATION
# =============================================================================
print("=" * 80, "STEP 4: SAVING MODEL FOR WEBSITE", "=" * 80, sep="\n")

# Export the weights right after training; any exception is reported, not raised
try:
    print("💾 Saving model for website integration...")

    # NOTE(review): relies on save_model_for_website returning a truthy value
    # on success — confirm against its definition.
    success = save_model_for_website(model, "piano_transcription_overfitted.pth")

    if not success:
        print("❌ Model saving failed!")
    else:
        print("\n".join([
            "✅ Model saved successfully!",
            "",
            "🚀 NEXT STEPS:",
            "1. Download 'piano_transcription_overfitted.pth' from Colab Files",
            "2. Place it in your website's backend/weights/ directory",
            "3. Rename it to 'piano_transcription_weights.pth'",
            "4. Restart your Flask server",
            "5. Test with a piano audio file!",
            "",
            "📝 Note: This model is overfitted on a small dataset for testing.",
            "   For production, train on the full dataset with more epochs.",
        ]))

except Exception as exc:
    print(f"❌ Error saving model: {exc}")

print("=" * 80)

In [None]:
# =============================================================================
# OPTIONAL: QUICK VERIFICATION CHECK
# =============================================================================
# Prints a timestamp and the overfit settings currently held in the kernel.
# Purely informational: it can be skipped without affecting the training cells.

import time

print("=" * 80, "OPTIONAL VERIFICATION CHECK", "=" * 80, sep="\n")
print("Timestamp: " + time.strftime('%Y-%m-%d %H:%M:%S'))
print("File has all fixes: BATCH_SIZE=4, transpose fixes, custom collate")
# The values below come from the overfit-setup cell
print("Settings: BATCH_SIZE={}, MAX_EXAMPLES={}".format(BATCH_SIZE, MAX_EXAMPLES))
print("Overfit subset: {} samples".format(len(overfit_subset)))
print("=" * 80)
