# Model A: Curriculum Learning - Other/Piano Extraction

**Curriculum Learning Strategy:**
1. **Stage 1:** Extract "other" from simplified mixture (vocals + other only)
2. **Stage 2:** Extract "other" from full mixture (drums + bass + vocals + other)

**MUSDB18 Dataset:** 4 stems per track (drums, bass, other, vocals)

**Workflow:**
- Load MUSDB18 → prepare curriculum batches
- Train Stage 1 on simpler 2-source task
- Train Stage 2 on full 4-source mixture using Stage 1 weights
- Test on uploaded song (10 seconds from 1:00-1:10)

In [1]:
import sys
import os
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
import librosa
from tqdm import tqdm
from IPython.display import Audio, display

# Locate the project root: the current working directory, or its parent when
# the notebook is launched from inside a "notebooks" folder.
project_root = Path.cwd().resolve()
project_root = project_root.parent if project_root.name.lower() == "notebooks" else project_root

# Put the project root first on the import path so project-local packages
# (imported right after this block) can be found.
sys.path.insert(0, str(project_root))

# Output / input directories used by the rest of the notebook; created on demand.
checkpoints_dir = project_root / "checkpoints"
data_dir = project_root / "data"
checkpoints_dir.mkdir(parents=True, exist_ok=True)
data_dir.mkdir(parents=True, exist_ok=True)

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"‚úì Setup complete | Device: {device} | Project: {project_root}")

# Import model components
from models.model_a_unet_freq import (
    STFTProcessor, FrequencyDomainUNet, 
    SourceSeparationDataset, ModelATrainer, ModelAInference
)

✓ Setup complete | Device: cpu | Project: C:\Users\amita\source\repos\Deep learning on computational accelerators\Final_Project_Deep_Learning


In [2]:
# Load MUSDB18 dataset - Auto download if not present
#
# Defines three names for later cells:
#   mus            - the musdb.DB handle, or None when loading failed
#   tracks         - list of tracks (empty when loading failed)
#   use_real_musdb - True only when the dataset was loaded successfully
print("\n" + "="*70)
print("LOADING MUSDB18 DATASET")
print("="*70)

# Import musdb, installing it into the running interpreter on first use.
try:
    import musdb
    print("‚úì musdb library found")
except ImportError:
    print("Installing musdb...")
    import subprocess
    # sys.executable targets the kernel's own environment rather than
    # whichever `pip` happens to be first on PATH.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "musdb", "-q"])
    import musdb

print("\nLoading MUSDB18 (auto-downloading if needed)...")
print("Note: First run may take time. Dataset will be cached for future use.\n")

try:
    # NOTE(review): per the musdb documentation, download=True fetches the
    # 7-second preview excerpts rather than the full-length dataset - confirm
    # this is the data the training run is meant to use.
    mus = musdb.DB(download=True)
    tracks = mus.tracks
    print("‚úì MUSDB18 loaded successfully!")
    print(f"‚úì Available tracks: {len(tracks)}")
    use_real_musdb = True
except Exception as e:
    # Broad on purpose: download, unpacking and decoding can each fail in
    # their own way, and every failure is handled the same (report + flag).
    print(f"‚ö†Ô∏è Could not load MUSDB18: {e}")
    # There is no synthetic-data fallback in this notebook: the data
    # preparation cell raises when the dataset is missing, so say so here.
    print("MUSDB18 is required for training - see the next cell for manual download instructions")
    mus = None
    tracks = []  # keep the name defined on the failure path as well
    use_real_musdb = False

# MUSDB18 stems: [0]=drums, [1]=bass, [2]=other, [3]=vocals
# NOTE(review): musdb's own `track.stems` array is documented with the mixture
# at index 0 - confirm this 0-based ordering is what consumers of STEM_NAMES expect.
STEM_NAMES = {0: 'drums', 1: 'bass', 2: 'other', 3: 'vocals'}


LOADING MUSDB18 DATASET
✓ musdb library found

Loading MUSDB18 (auto-downloading if needed)...
Note: First run may take time. Dataset will be cached for future use.

✓ MUSDB18 loaded successfully!
✓ Available tracks: 144


In [3]:
# Manual MUSDB18 Download Instructions
# Printed only when the automatic load in the previous cell failed.
if not use_real_musdb:
    # `musdb_root` was never assigned anywhere in this notebook, so this branch
    # used to raise NameError exactly when the instructions were needed.
    # Respect an existing definition, then the MUSDB_PATH environment variable,
    # and finally fall back to a folder under the project's data directory.
    # NOTE(review): the loading cell calls musdb.DB(download=True) without a
    # root, so it may need `root=musdb_root` to pick up a manual download.
    if 'musdb_root' not in globals():
        musdb_root = Path(os.environ.get("MUSDB_PATH", str(data_dir / "musdb18")))

    print("\n" + "="*70)
    print("MUSDB18 DATASET REQUIRED")
    print("="*70)
    print("\nThe musdb library doesn't support automatic downloads.")
    print("Please follow these steps to download MUSDB18:\n")
    print("1. Visit: https://sigsep.github.io/datasets/musdb.html")
    print("2. Download the MUSDB18-HQ dataset (~23GB)")
    print("3. Extract the ZIP file to:")
    print(f"   {musdb_root}")
    print("\n4. After extraction, the structure should be:")
    print(f"   {musdb_root}/")
    print(f"     ‚îú‚îÄ‚îÄ train/")
    print(f"     ‚îÇ   ‚îú‚îÄ‚îÄ A Classic Education - NightOwl/")
    print(f"     ‚îÇ   ‚îú‚îÄ‚îÄ ...")
    print(f"     ‚îî‚îÄ‚îÄ test/")
    print(f"         ‚îú‚îÄ‚îÄ ...")
    print("\n5. Then re-run cells 2-3 to detect the dataset")
    print("\n" + "="*70)
else:
    print("‚úì MUSDB18 already available")

✓ MUSDB18 already available


In [4]:
# Prepare curriculum learning data
print("\n" + "="*70)
print("CURRICULUM LEARNING DATA PREPARATION")
print("="*70)

def prepare_curriculum_data(num_tracks=50):
    """Build and cache the Stage 1 / Stage 2 training pairs from MUSDB18.

    For each of the first ``num_tracks`` tracks the function builds two
    mixtures (Stage 1: vocals + other, Stage 2: drums + bass + other + vocals),
    resamples to 22050 Hz, downmixes to mono, normalises, and stores mixture
    and "other" target as float32 ``.npy`` files under
    ``data_dir / "curriculum_cache"``.

    Each (mixture, target) pair is scaled by the *mixture's* peak. Scaling the
    target by its own peak (as this function used to) changes the
    target/mixture ratio per track, so a ratio mask derived from the pair is
    off by an arbitrary per-track factor. Because the Stage 1 and Stage 2
    mixtures have different peaks, the two stages now store differently scaled
    copies of the same target.

    Caches written by the previous version keep the old scaling; delete the
    cache directory to regenerate them.

    Args:
        num_tracks: Number of tracks taken from the start of ``mus.tracks``.

    Returns:
        Tuple ``(stage1_mixture_paths, stage1_target_paths,
        stage2_mixture_paths, stage2_target_paths)`` of equally long lists of
        file paths (as ``str``), in track order.

    Raises:
        ValueError: If MUSDB18 is not loaded, or no track could be processed.
    """

    if not use_real_musdb or mus is None:
        raise ValueError(
            "MUSDB18 dataset could not be loaded.\n"
            "Please check your internet connection and try again.\n"
            "The musdb library will attempt to download automatically."
        )

    target_sr = 22050

    tracks = mus.tracks[:num_tracks]
    print(f"\nProcessing {len(tracks)} MUSDB18 tracks for curriculum learning...")

    stage1_mixture_paths = []
    stage1_target_paths = []
    stage2_mixture_paths = []
    stage2_target_paths = []

    cache_dir = data_dir / "curriculum_cache"
    cache_dir.mkdir(exist_ok=True, parents=True)

    for idx, track in enumerate(tracks):
        try:
            # Extract stems
            drums = track.targets['drums'].audio
            bass = track.targets['bass'].audio
            other = track.targets['other'].audio
            vocals = track.targets['vocals'].audio

            # Create mixtures
            # Stage 1: vocals + other (simplified)
            mixture_s1 = vocals + other
            # Stage 2: drums + bass + other + vocals (full)
            mixture_s2 = drums + bass + other + vocals

            # musdb documents the sample rate as `track.rate`; `sample_rate`
            # is kept as a fallback for other track-like objects, and 44100 Hz
            # (the MUSDB18 rate) is the last resort.
            sr = (
                getattr(track, 'rate', None)
                or getattr(track, 'sample_rate', None)
                or 44100
            )

            if sr != target_sr:
                from scipy import signal
                n_samples = int(len(other) * target_sr / sr)
                # Resampling acts along axis 0 (samples); channels are kept.
                other = signal.resample(other, n_samples)
                mixture_s1 = signal.resample(mixture_s1, n_samples)
                mixture_s2 = signal.resample(mixture_s2, n_samples)

            # Convert stereo to mono
            if other.ndim > 1:
                other = np.mean(other, axis=1)
            if mixture_s1.ndim > 1:
                mixture_s1 = np.mean(mixture_s1, axis=1)
            if mixture_s2.ndim > 1:
                mixture_s2 = np.mean(mixture_s2, axis=1)

            # Normalize: one gain per (mixture, target) pair, taken from the
            # mixture, so the pair keeps its true relative level.
            s1_peak = np.max(np.abs(mixture_s1)) + 1e-8
            s2_peak = np.max(np.abs(mixture_s2)) + 1e-8
            target_s1 = other / s1_peak
            target_s2 = other / s2_peak
            mixture_s1 = mixture_s1 / s1_peak
            mixture_s2 = mixture_s2 / s2_peak

            # Save files
            s1_mix_path = cache_dir / f"stage1_mixture_{idx:03d}.npy"
            s1_tgt_path = cache_dir / f"stage1_target_{idx:03d}.npy"
            s2_mix_path = cache_dir / f"stage2_mixture_{idx:03d}.npy"
            s2_tgt_path = cache_dir / f"stage2_target_{idx:03d}.npy"

            np.save(s1_mix_path, mixture_s1.astype(np.float32))
            np.save(s1_tgt_path, target_s1.astype(np.float32))
            np.save(s2_mix_path, mixture_s2.astype(np.float32))
            np.save(s2_tgt_path, target_s2.astype(np.float32))

            # Paths are recorded only after all four files were written, so
            # the four returned lists always stay index-aligned.
            stage1_mixture_paths.append(str(s1_mix_path))
            stage1_target_paths.append(str(s1_tgt_path))
            stage2_mixture_paths.append(str(s2_mix_path))
            stage2_target_paths.append(str(s2_tgt_path))

            if (idx + 1) % 10 == 0:
                print(f"  Processed {idx + 1}/{len(tracks)} tracks...")

        except Exception as e:
            # Best effort per track: one unreadable track must not abort the
            # whole preparation run.
            print(f"  ‚ö†Ô∏è Skipping track {idx} ({track.name}): {str(e)[:50]}")
            continue

    if not stage1_mixture_paths:
        raise ValueError("No tracks could be processed from MUSDB18")

    return (stage1_mixture_paths, stage1_target_paths,
            stage2_mixture_paths, stage2_target_paths)

# Check if cache already exists
cache_dir = data_dir / "curriculum_cache"
cached_files = sorted(cache_dir.glob("stage1_mixture_*.npy")) if cache_dir.exists() else []


def _cached_paths(prefix):
    """Return the sorted cache file paths (as str) whose name starts with `prefix`."""
    if not cache_dir.exists():
        return []
    return sorted(str(p) for p in cache_dir.glob(f"{prefix}_*.npy"))


def _cache_indices(paths):
    """Return the trailing track index of each cache file name (e.g. '007')."""
    return [Path(p).stem.rsplit("_", 1)[-1] for p in paths]


s1_mix = _cached_paths("stage1_mixture")
s1_tgt = _cached_paths("stage1_target")
s2_mix = _cached_paths("stage2_mixture")
s2_tgt = _cached_paths("stage2_target")

# The cache is usable only when all four file families cover exactly the same
# track indices. An interrupted earlier run can leave a mixture without its
# target, which would silently mis-pair samples further down.
cache_is_complete = bool(s1_mix) and (
    _cache_indices(s1_mix)
    == _cache_indices(s1_tgt)
    == _cache_indices(s2_mix)
    == _cache_indices(s2_tgt)
)

if cache_is_complete:
    print("‚úì Loading cached curriculum data (skipping MUSDB18 processing)...")
    print(f"‚úì Loaded {len(s1_mix)} cached samples per stage")
else:
    if cached_files:
        print("WARNING: curriculum cache is incomplete or inconsistent - rebuilding it")
    print("‚è≥ No cache found. Processing 50 MUSDB18 tracks (first time only)...")
    s1_mix, s1_tgt, s2_mix, s2_tgt = prepare_curriculum_data(num_tracks=50)

print(f"\n‚úì Stage 1 (Vocals + Other ‚Üí Other): {len(s1_mix)} samples")
print(f"‚úì Stage 2 (Full Mixture ‚Üí Other): {len(s2_mix)} samples")


CURRICULUM LEARNING DATA PREPARATION
✓ Loading cached curriculum data (skipping MUSDB18 processing)...
✓ Loaded 50 cached samples per stage

✓ Stage 1 (Vocals + Other → Other): 50 samples
✓ Stage 2 (Full Mixture → Other): 50 samples


In [5]:
# Create dataloaders for both stages
print("\n" + "="*70)
print("CREATING DATALOADERS FOR CURRICULUM LEARNING")
print("="*70)

stft_processor = STFTProcessor(n_fft=2048, hop_length=512)

# Both stages are split with generators seeded identically. For datasets of
# equal length this yields the same train/val indices in both stages, so a
# track held out for Stage 2 validation was not already trained on in Stage 1
# (Stage 2 starts from the Stage 1 weights and shares the same targets).
# It also makes the split reproducible across kernel restarts, which matters
# when a stage is resumed from a checkpoint.
# NOTE(review): this assumes sample i of both datasets comes from the same
# track, which holds when the path lists are built/sorted by track index.
SPLIT_SEED = 42

# Stage 1: "other" extraction from the simplified (vocals + other) mixture
print("\nStage 1: Other Extraction (Vocals + Other mixture)")
stage1_dataset = SourceSeparationDataset(
    mixture_paths=s1_mix,
    target_paths=s1_tgt,
    stft_processor=stft_processor,
    normalize=False  # Disable normalization to preserve log->linear consistency
)

s1_train_size = int(0.8 * len(stage1_dataset))
s1_val_size = len(stage1_dataset) - s1_train_size
s1_train_data, s1_val_data = random_split(
    stage1_dataset,
    [s1_train_size, s1_val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

s1_train_loader = DataLoader(s1_train_data, batch_size=4, shuffle=True, num_workers=0)
s1_val_loader = DataLoader(s1_val_data, batch_size=4, shuffle=False, num_workers=0)

print(f"  Train: {len(s1_train_data)} | Val: {len(s1_val_data)}")

# Stage 2: Other (piano) extraction from the full 4-source mixture
print("\nStage 2: Other/Piano Extraction")
stage2_dataset = SourceSeparationDataset(
    mixture_paths=s2_mix,
    target_paths=s2_tgt,
    stft_processor=stft_processor,
    normalize=False  # Disable normalization to preserve log->linear consistency
)

s2_train_size = int(0.8 * len(stage2_dataset))
s2_val_size = len(stage2_dataset) - s2_train_size
s2_train_data, s2_val_data = random_split(
    stage2_dataset,
    [s2_train_size, s2_val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

s2_train_loader = DataLoader(s2_train_data, batch_size=4, shuffle=True, num_workers=0)
s2_val_loader = DataLoader(s2_val_data, batch_size=4, shuffle=False, num_workers=0)

print(f"  Train: {len(s2_train_data)} | Val: {len(s2_val_data)}")

print("\n‚úì All dataloaders created")


CREATING DATALOADERS FOR CURRICULUM LEARNING

Stage 1: Vocals Extraction
  Train: 40 | Val: 10

Stage 2: Other/Piano Extraction
  Train: 40 | Val: 10

✓ All dataloaders created


In [6]:
# Build the shared U-Net instance that the training cells update in place.
print("\n" + "="*70)
print("MODEL INITIALIZATION & CHECKPOINT MANAGEMENT")
print("="*70)

import matplotlib.pyplot as plt

# Constructor arguments for FrequencyDomainUNet, kept in one place.
model_config = dict(
    in_channels=1,
    base_channels=32,
    depth=4,
    use_batch_norm=True,
)

model = FrequencyDomainUNet(**model_config)
model = model.to(device)
print(f"\n‚úì Model created: {sum(param.numel() for param in model.parameters()):,} parameters")

# Checkpoint management
def train_stage(stage_num, train_loader, val_loader, num_epochs=20):
    """Train one curriculum stage, or restore it from its checkpoint.

    Operates on the module-level ``model``. If
    ``checkpoints_dir / f"stage{stage_num}_modelA.pt"`` exists, its weights are
    loaded into ``model`` and no training happens. Otherwise ``model`` is
    trained with ``ModelATrainer`` and the checkpoint is written afterwards.

    Args:
        stage_num: Curriculum stage number; only used for the checkpoint file
            name and the printed banner.
        train_loader: Training DataLoader handed to the trainer.
        val_loader: Validation DataLoader handed to the trainer.
        num_epochs: Number of epochs to train when no checkpoint exists.

    Returns:
        A dict with ``'train_loss'`` and ``'val_loss'`` lists. When restoring
        from a checkpoint that holds no equally long per-epoch histories,
        ``None`` is returned instead.
    """

    checkpoint_path = checkpoints_dir / f"stage{stage_num}_modelA.pt"

    print(f"\n{'='*70}")
    print(f"STAGE {stage_num}: CURRICULUM LEARNING")
    print(f"{'='*70}\n")

    # Check if checkpoint exists
    if checkpoint_path.exists():
        print(f"‚úì Checkpoint weights loaded: {checkpoint_path.name}")
        # NOTE: weights_only=False unpickles arbitrary objects. It is kept so
        # checkpoints written by earlier versions of this cell (which stored a
        # NumPy scalar for 'epoch') still load; only use it on files this
        # notebook produced.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])

        # A missing or non-scalar 'val_loss' used to crash the ':.6f' format
        # (the '?' placeholder is a str); fall back to '?' instead.
        stored_val_loss = checkpoint.get('val_loss')
        try:
            val_loss_text = f"{stored_val_loss:.6f}"
        except (TypeError, ValueError):
            val_loss_text = "?"
        print(f"  Epoch: {checkpoint.get('epoch', '?')} | Val Loss: {val_loss_text}")

        # Try to recover history for plotting/printing
        train_hist = checkpoint.get('train_loss_history') or checkpoint.get('train_loss')
        val_hist = checkpoint.get('val_loss_history') or checkpoint.get('val_loss')
        history = None
        if isinstance(train_hist, (list, tuple)) and isinstance(val_hist, (list, tuple)) and len(train_hist) == len(val_hist):
            history = {
                'train_loss': list(train_hist),
                'val_loss': list(val_hist)
            }
            print("\nEpoch losses (train | val) from checkpoint:")
            for epoch_idx, (tr, va) in enumerate(zip(history['train_loss'], history['val_loss']), start=1):
                print(f"  Epoch {epoch_idx:02d}: {tr:.6f} | {va:.6f}")
        return history

    # Initialize trainer
    trainer = ModelATrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=1e-4,
        device=device,
        use_energy_weighted_loss=False  # Use standard MSE loss
    )

    # NOTE(review): a StepLR(step_size=5, gamma=0.5) used to be constructed
    # here but was never stepped or handed to the trainer, so it had no effect
    # on the learning rate. It was removed; if decay is wanted it has to be
    # wired into ModelATrainer's epoch loop.

    # Train
    print(f"Starting training... (no checkpoint found)")
    history = trainer.train(num_epochs=num_epochs, save_dir=str(checkpoints_dir))

    # Print per-epoch losses
    print("\nEpoch losses (train | val):")
    for epoch_idx, (tr, va) in enumerate(zip(history['train_loss'], history['val_loss']), start=1):
        print(f"  Epoch {epoch_idx:02d}/{num_epochs}: {tr:.6f} | {va:.6f}")

    # Save checkpoint
    # NOTE(review): the weights saved are whatever `model` holds after
    # trainer.train() returns, while 'epoch'/'val_loss' describe the best
    # validation epoch - confirm the trainer restores the best weights,
    # otherwise the two can disagree.
    best_epoch = int(np.argmin(history['val_loss'])) + 1  # plain int, not a NumPy scalar
    torch.save({
        'epoch': best_epoch,
        'model_state_dict': model.state_dict(),
        'val_loss': float(np.min(history['val_loss'])),
        'train_loss': [float(x) for x in history['train_loss']],
        'train_loss_history': [float(x) for x in history['train_loss']],
        'val_loss_history': [float(x) for x in history['val_loss']]
    }, checkpoint_path)

    print(f"\n‚úì Checkpoint saved: {checkpoint_path.name}")
    print(f"  Best epoch: {best_epoch} | Val Loss: {np.min(history['val_loss']):.6f}")
    return history


MODEL INITIALIZATION & CHECKPOINT MANAGEMENT

✓ Model created: 7,765,409 parameters


In [7]:
# =============================================================================
# SANITY CHECK: Overfit on Small Dataset
# =============================================================================
# This cell tests whether the model can memorize a tiny dataset (one sample).
# The same sample is used for training and validation, so both losses should
# fall together if the training pipeline is sound.

print("\n" + "="*70)
print("SANITY CHECK: OVERFITTING ON SMALL DATASET (LR=5e-3 with Grad Clipping)")
print("="*70)

# Create tiny dataset (just 1 sample)
tiny_train_data = torch.utils.data.Subset(stage1_dataset, [0])
tiny_val_data = torch.utils.data.Subset(stage1_dataset, [0])

tiny_train_loader = DataLoader(tiny_train_data, batch_size=1, shuffle=False)
tiny_val_loader = DataLoader(tiny_val_data, batch_size=1, shuffle=False)

print(f"\n‚ö†Ô∏è  Testing overfitting on 1 sample...")
print("Previous attempts:")
print("  LR=1e-2:  Epoch 1 loss=0.130 ‚Üí Epoch 2 plateaus at 0.228 (stable but slow)")
print("  LR=1e-1:  Epoch 1 loss=0.300 ‚Üí Explodes to NaN by epoch 10 (diverges)")
print("\nNew strategy: Use intermediate LR (5e-3) with modest gradient clipping (5.0)")
print("This prevents divergence while allowing faster convergence than LR=1e-2")
print("BatchNorm: disabled | Loss: MSE | Grad clipping: 5.0")

# Instantiate a separate model (BatchNorm off) so the shared `model` is untouched
tiny_model_config = {
    'in_channels': 1,
    'base_channels': 32,
    'depth': 4,
    'use_batch_norm': False
}

tiny_model = FrequencyDomainUNet(**tiny_model_config).to(device)

# Train with INTERMEDIATE learning rate + gradient clipping
tiny_trainer = ModelATrainer(
    model=tiny_model,
    train_loader=tiny_train_loader,
    val_loader=tiny_val_loader,
    learning_rate=5e-3,  # Between 1e-2 (stable) and 1e-1 (diverges)
    device=device,
    use_energy_weighted_loss=False,
    grad_clip_max_norm=5.0  # Prevent gradient explosion
)

print("\n[Overfitting run: 50 epochs with LR=5e-3 + grad_clip=5.0]")
tiny_history = tiny_trainer.train(num_epochs=50, save_dir=None)

# Number of epochs actually recorded; everything below is derived from it so
# a shorter history (e.g. the trainer stopping early) still reports and plots.
n_epochs_run = len(tiny_history['train_loss'])

# Print losses
final_train = tiny_history['train_loss'][-1]
final_val = tiny_history['val_loss'][-1]

print(f"\nFirst 5 epochs:")
for i in range(min(5, n_epochs_run)):
    print(f"  Epoch {i+1}: Train {tiny_history['train_loss'][i]:.6f} | Val {tiny_history['val_loss'][i]:.6f}")

print(f"\nEpochs 10, 20, 30, 40, 50:")
for ep in [9, 19, 29, 39, 49]:
    if ep < n_epochs_run:
        print(f"  Epoch {ep+1:2d}: Train {tiny_history['train_loss'][ep]:.6f} | Val {tiny_history['val_loss'][ep]:.6f}")

print(f"\nFinal loss (epoch {n_epochs_run}):")
print(f"  Train: {final_train:.6f}")
print(f"  Val:   {final_val:.6f}")

# Plot with enhanced diagnostics
fig = plt.figure(figsize=(14, 5))

# Linear scale
ax1 = plt.subplot(1, 3, 1)
ax1.plot(tiny_history['train_loss'], marker='o', markersize=3, label='Train', alpha=0.7)
ax1.plot(tiny_history['val_loss'], marker='s', markersize=3, label='Val', alpha=0.7)
ax1.set_title('Loss vs Epoch (Linear scale)', fontsize=11, fontweight='bold')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('MSE Loss')
ax1.grid(True, alpha=0.3)
ax1.axhline(y=1e-4, color='g', linestyle='--', alpha=0.5, label='Excellent (<1e-4)')
ax1.axhline(y=1e-2, color='b', linestyle='--', alpha=0.5, label='Good (<1e-2)')
ax1.axhline(y=0.1, color='orange', linestyle='--', alpha=0.5, label='Fair (<0.1)')
# The legend is built after the threshold lines so their labels are included
# (it used to be created before them, which silently dropped those entries).
ax1.legend()

# Log scale
ax2 = plt.subplot(1, 3, 2)
ax2.semilogy(tiny_history['train_loss'], marker='o', markersize=3, label='Train', alpha=0.7)
ax2.semilogy(tiny_history['val_loss'], marker='s', markersize=3, label='Val', alpha=0.7)
ax2.set_title('Loss vs Epoch (Log scale)', fontsize=11, fontweight='bold')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('MSE Loss (log)')
ax2.legend()
ax2.grid(True, alpha=0.3, which='both')

# Convergence speed (last 20 epochs)
ax3 = plt.subplot(1, 3, 3)
if n_epochs_run > 20:
    # Window over the last 20 recorded epochs. The previous hard-coded
    # range(30, 50) raised a shape mismatch for any history of 21-49 epochs.
    tail_start = n_epochs_run - 20
    tail_epochs = range(tail_start, n_epochs_run)
    ax3.plot(tail_epochs, tiny_history['train_loss'][tail_start:n_epochs_run], marker='o', markersize=4, label='Train (last 20)', color='C0')
    ax3.plot(tail_epochs, tiny_history['val_loss'][tail_start:n_epochs_run], marker='s', markersize=4, label='Val (last 20)', color='C1')
    ax3.set_title(f'Convergence (Epochs {tail_start}-{n_epochs_run})', fontsize=11, fontweight='bold')
else:
    ax3.plot(tiny_history['train_loss'], marker='o', markersize=3, label='Train', alpha=0.7)
    ax3.plot(tiny_history['val_loss'], marker='s', markersize=3, label='Val', alpha=0.7)
    ax3.set_title('Full Training Curve', fontsize=11, fontweight='bold')
ax3.set_xlabel('Epoch')
ax3.set_ylabel('MSE Loss')
ax3.legend()
ax3.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\n" + "="*70)

# Verdict thresholds match the horizontal guide lines drawn on the first panel.
if final_train < 1e-4:
    print("‚úÖ EXCELLENT: Model memorized (<1e-4). Ready for Stage 1/2 training.")
elif final_train < 1e-2:
    print("‚úÖ GOOD: Model learning (<1e-2). Proceed to Stage 1/2 training.")
elif final_train < 0.1:
    print("‚ö†Ô∏è  FAIR: Model improving but slow (<0.1). May still be usable.")
else:
    print("‚ùå FAILED: Loss > 0.1. Training pipeline has fundamental issue.")
    print("\nDiagnosis needed:")
    print("- Check if mask computation (target_mask = expm1(tgt)/expm1(mix)) is numerically stable")
    print("- Verify loss landscape doesn't have sharp discontinuities")
    print("- Consider alternative mask formulation or loss function")


SANITY CHECK: OVERFITTING ON SMALL DATASET (LR=5e-3 with Grad Clipping)

‚ö†Ô∏è  Testing overfitting on 1 sample...
Previous attempts:
  LR=1e-2:  Epoch 1 loss=0.130 ‚Üí Epoch 2 plateaus at 0.228 (stable but slow)
  LR=1e-1:  Epoch 1 loss=0.300 ‚Üí Explodes to NaN by epoch 10 (diverges)

New strategy: Use intermediate LR (5e-3) with modest gradient clipping (5.0)
This prevents divergence while allowing faster convergence than LR=1e-2
BatchNorm: disabled | Loss: MSE | Grad clipping: 5.0

[Overfitting run: 50 epochs with LR=5e-3 + grad_clip=5.0]


Epochs:   0%|          | 0/50 [00:00<?, ?it/s]


[Epoch 0 Diagnostics]
  Mix mag range: [0.0000, 4.7291]
  Tgt mag range: [0.0000, 4.8964]
  Mix lin range: [0.0000, 112.1943]
  Tgt lin range: [0.0000, 132.8131]
  Target mask range: [0.0000, 1.0000]
  Pred mask range: [0.0419, 0.9994]




  Grad max (pre-clip): 0.143332
  Grad max (post-clip): 0.143332


Epochs:   2%|‚ñè         | 1/50 [00:03<02:59,  3.66s/it, train_loss=0.149693, val_loss=0.228314]

  Epoch 01/50: Train Loss 0.149693 | Val Loss 0.228314


Epochs:   4%|‚ñç         | 2/50 [00:05<02:04,  2.59s/it, train_loss=0.228314, val_loss=0.228318]

  Epoch 02/50: Train Loss 0.228314 | Val Loss 0.228318


Epochs:   6%|‚ñå         | 3/50 [00:07<01:45,  2.24s/it, train_loss=0.228318, val_loss=0.228318]

  Epoch 03/50: Train Loss 0.228318 | Val Loss 0.228318


Epochs:   8%|‚ñä         | 4/50 [00:09<01:36,  2.11s/it, train_loss=0.228318, val_loss=0.228318]

  Epoch 04/50: Train Loss 0.228318 | Val Loss 0.228318


Epochs:  10%|‚ñà         | 5/50 [00:11<01:30,  2.00s/it, train_loss=0.228318, val_loss=0.228318]

  Epoch 05/50: Train Loss 0.228318 | Val Loss 0.228318


Epochs:  12%|‚ñà‚ñè        | 6/50 [00:12<01:26,  1.96s/it, train_loss=0.228318, val_loss=0.228318]

  Epoch 06/50: Train Loss 0.228318 | Val Loss 0.228318


Epochs:  14%|‚ñà‚ñç        | 7/50 [00:14<01:21,  1.89s/it, train_loss=0.228318, val_loss=0.228318]

  Epoch 07/50: Train Loss 0.228318 | Val Loss 0.228318


Epochs:  16%|‚ñà‚ñå        | 8/50 [00:16<01:17,  1.84s/it, train_loss=0.228318, val_loss=0.228318]

  Epoch 08/50: Train Loss 0.228318 | Val Loss 0.228318


Epochs:  18%|‚ñà‚ñä        | 9/50 [00:18<01:14,  1.82s/it, train_loss=0.228318, val_loss=0.228318]

  Epoch 09/50: Train Loss 0.228318 | Val Loss 0.228318


Epochs:  20%|‚ñà‚ñà        | 10/50 [00:19<01:12,  1.80s/it, train_loss=0.228318, val_loss=0.228318]

  Epoch 10/50: Train Loss 0.228318 | Val Loss 0.228318


Epochs:  20%|‚ñà‚ñà        | 10/50 [00:21<01:24,  2.12s/it, train_loss=0.228318, val_loss=0.228318]


KeyboardInterrupt: 

In [None]:
# Train (or restore) Stage 1: "other" extraction from the vocals + other mixture.
stage1_history = train_stage(
    stage_num=1,
    train_loader=s1_train_loader,
    val_loader=s1_val_loader,
    num_epochs=20
)

# Plot Stage 1 loss curves (train_stage returns None when a checkpoint was
# restored without a stored per-epoch history).
if stage1_history is not None:
    stage1_train_curve = stage1_history['train_loss']
    stage1_val_curve = stage1_history['val_loss']
    plt.figure(figsize=(6, 4))
    # Epochs are plotted 1-based so the x axis matches the printed epoch log.
    plt.plot(range(1, len(stage1_train_curve) + 1), stage1_train_curve, label='Train Loss')
    plt.plot(range(1, len(stage1_val_curve) + 1), stage1_val_curve, label='Val Loss')
    plt.title('Stage 1 Loss')
    plt.xlabel('Epoch')
    # The trainer is configured for the standard MSE loss
    # (use_energy_weighted_loss=False), so the axis must not say L1.
    plt.ylabel('MSE Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

In [None]:
# Evaluate Stage 1 Performance
# Loads the Stage 1 checkpoint into the shared `model`, reports a validation
# loss, then separates one cached sample and plays mixture / target / output.
print("\n" + "="*70)
print("STAGE 1 EVALUATION")
print("="*70)

# Load Stage 1 checkpoint
stage1_checkpoint = checkpoints_dir / 'stage1_modelA.pt'
if stage1_checkpoint.exists():
    # NOTE: weights_only=False unpickles arbitrary objects; only load
    # checkpoints written by this notebook.
    checkpoint = torch.load(stage1_checkpoint, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"\n‚úì Loaded Stage 1 checkpoint")
    print(f"  Training epoch: {checkpoint['epoch']}")
    print(f"  Validation loss: {checkpoint['val_loss']:.6f}")

    # Evaluate on validation set
    model.eval()
    val_loss_total = 0
    num_batches = 0
    loss_fn = torch.nn.L1Loss()  # built once instead of once per batch

    print("\nEvaluating on validation set...")
    with torch.no_grad():
        for batch_data in s1_val_loader:
            # Extract mixture magnitude and target magnitude from batch
            mixture = batch_data['mixture_mag'].to(device)
            target = batch_data['target_mag'].to(device)

            # Forward pass
            output = model(mixture)

            # Compute loss
            # NOTE(review): the training diagnostics print a "Pred mask range",
            # which suggests the model outputs a mask; this compares the raw
            # output with the target magnitude - confirm that is intended.
            loss = loss_fn(output, target)
            val_loss_total += loss.item()
            num_batches += 1

    # Guard against an empty validation loader (previously ZeroDivisionError).
    if num_batches:
        avg_val_loss = val_loss_total / num_batches
        print(f"\n‚úì Average validation loss: {avg_val_loss:.6f}")
    else:
        avg_val_loss = float('nan')
        print("\nValidation loader is empty - no average validation loss computed")

    # Test on a sample
    print("\nTesting on sample audio...")
    test_idx = 0
    test_mix = np.load(s1_mix[test_idx])
    test_tgt = np.load(s1_tgt[test_idx])

    # Create inference engine
    inference_engine = ModelAInference(
        model=model,
        stft_processor=stft_processor,
        device=device
    )

    # Separate
    separated = inference_engine.separate(test_mix)

    # Compute metrics over the overlapping samples only: the separated signal
    # is not guaranteed to have exactly the input length (other cells trim it
    # with separated[:len(...)]), and a length mismatch used to raise here.
    n_common = min(len(test_tgt), len(separated))
    reference = np.asarray(test_tgt[:n_common], dtype=np.float64)
    estimate = np.asarray(separated[:n_common], dtype=np.float64)
    residual = reference - estimate
    mse = float(np.mean(residual ** 2))
    mae = float(np.mean(np.abs(residual)))

    print(f"  MSE: {mse:.6f}")
    print(f"  MAE: {mae:.6f}")

    # Audio playback
    print("\nüìä Listen to Stage 1 results:")
    sr = 22050

    def norm_audio(x):
        """Peak-normalise `x` for playback, leaving 5% headroom."""
        return x / (np.max(np.abs(x)) + 1e-8) * 0.95

    print("\n1. Input (Vocals + Other):")
    display(Audio(norm_audio(test_mix), rate=sr))

    print("\n2. Target (Other/Piano):")
    display(Audio(norm_audio(test_tgt), rate=sr))

    print("\n3. Separated (Stage 1 Output):")
    display(Audio(norm_audio(separated), rate=sr))

    print("\n" + "="*70)
    print("Stage 1 evaluation complete. Ready for Stage 2 training.")
    print("="*70)

else:
    print("\n‚ö†Ô∏è Stage 1 checkpoint not found. Please run Stage 1 training first.")

In [None]:
# Train (or restore) Stage 2: "other" extraction from the full 4-source
# mixture. The shared `model` carries over whatever weights it currently holds.
stage2_history = train_stage(
    stage_num=2,
    train_loader=s2_train_loader,
    val_loader=s2_val_loader,
    num_epochs=20
)

# Plot Stage 2 loss curves (train_stage returns None when a checkpoint was
# restored without a stored per-epoch history).
if stage2_history is not None:
    stage2_train_curve = stage2_history['train_loss']
    stage2_val_curve = stage2_history['val_loss']
    plt.figure(figsize=(6, 4))
    # Epochs are plotted 1-based so the x axis matches the printed epoch log.
    plt.plot(range(1, len(stage2_train_curve) + 1), stage2_train_curve, label='Train Loss')
    plt.plot(range(1, len(stage2_val_curve) + 1), stage2_val_curve, label='Val Loss')
    plt.title('Stage 2 Loss')
    plt.xlabel('Epoch')
    # The trainer is configured for the standard MSE loss
    # (use_energy_weighted_loss=False), so the axis must not say L1.
    plt.ylabel('MSE Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

In [None]:
# Restore the Stage 2 weights (when a checkpoint exists) and wrap the shared
# model in an inference engine for the test cells that follow.
print("\n" + "="*70)
print("LOADING TRAINED MODEL FOR INFERENCE")
print("="*70)

best_checkpoint = checkpoints_dir / 'stage2_modelA.pt'
if not best_checkpoint.exists():
    # Nothing on disk: keep whatever weights `model` currently holds.
    print("‚úì Using current model (no checkpoint)")
else:
    checkpoint = torch.load(
        best_checkpoint,
        map_location=device,
        weights_only=False,
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"\n‚úì Loaded: {best_checkpoint.name}")
    print(f"  Best epoch: {checkpoint['epoch']} | Val Loss: {checkpoint['val_loss']:.6f}")

inference_engine = ModelAInference(model=model, stft_processor=stft_processor, device=device)

print("‚úì Inference engine ready")

In [None]:
%matplotlib inline

import numpy as np
import librosa
from IPython.display import Audio, display
from pathlib import Path
import glob
import gc

# Clear matplotlib cache BEFORE importing
# NOTE(review): an earlier cell already imports matplotlib.pyplot, so when the
# notebook is run top to bottom matplotlib is loaded before this point.
# NOTE(review): on platforms where ~/.matplotlib is also the user's config
# directory this deletes their matplotlibrc and style sheets - confirm that
# wiping the whole directory is really intended.
import os
import shutil
# A dedicated name is used here: this cell used to assign `cache_dir`, which
# clobbered the curriculum-cache path defined by the data preparation cell.
mpl_cache_dir = os.path.expanduser('~/.matplotlib')
if os.path.exists(mpl_cache_dir):
    try:
        shutil.rmtree(mpl_cache_dir)
        print("‚úì Cleared matplotlib cache")
    except OSError as err:
        # Best effort: a locked or read-only cache must not stop the cell.
        print(f"Could not clear matplotlib cache: {err}")

# NOW import matplotlib
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm

# Force matplotlib to rebuild font cache
# `_rebuild` is a private helper that not every matplotlib version provides;
# call it only when present instead of hiding the failure behind a bare except.
_rebuild_fonts = getattr(matplotlib.font_manager, "_rebuild", None)
if callable(_rebuild_fonts):
    try:
        _rebuild_fonts()
        print("‚úì Rebuilt font cache")
    except Exception as err:
        # Best effort: plotting still works with the existing font cache.
        print(f"Could not rebuild font cache: {err}")

# Use simple, safe backend and minimal text
matplotlib.use('agg')
matplotlib.rcParams.update({
    'font.size': 9,
    'font.family': 'sans-serif',
    'figure.dpi': 80,
    'savefig.dpi': 80,
    'text.usetex': False,
    'axes.unicode_minus': False
})

print("\n" + "="*70)
print("TESTING ON UPLOADED SONG & DATABASE SAMPLES WITH STFT VISUALIZATION")
print("="*70)

def visualize_stft_masking(mixture, separated, stft_processor, sr, title_prefix=""):
    """Show mixture, separated output and the implied mask as spectrograms.

    Both waveforms are converted with
    ``stft_processor.waveform_to_magnitude_phase``; the mask shown is
    ``separated / mixture`` (clipped to [0, 1]) on those magnitudes.

    Args:
        mixture: Input mixture waveform.
        separated: Separated output waveform.
        stft_processor: Object providing ``waveform_to_magnitude_phase``.
        sr: Sample rate. Currently unused; kept for interface compatibility.
        title_prefix: Text prepended to each panel title.

    Returns:
        ``(mix_mag, sep_mag, mask)``, cropped to a common number of frames,
        or ``(None, None, None)`` if anything failed while plotting.
    """

    try:
        # Compute STFT (the phase is not needed for the plots)
        mix_mag, _ = stft_processor.waveform_to_magnitude_phase(mixture)
        sep_mag, _ = stft_processor.waveform_to_magnitude_phase(separated)

        # The two waveforms can differ in length by a few samples, giving a
        # different number of frames; the mask division below would then fail
        # and the whole visualization would be skipped. Crop both to the
        # shared length first.
        # NOTE(review): assumes time is the last axis of the magnitude arrays
        # (consistent with imshow showing frequency on the vertical axis) -
        # confirm against STFTProcessor.
        n_frames = min(mix_mag.shape[-1], sep_mag.shape[-1])
        mix_mag = mix_mag[..., :n_frames]
        sep_mag = sep_mag[..., :n_frames]

        # Create simple figure WITH titles
        fig, axes = plt.subplots(1, 3, figsize=(15, 4), dpi=80)

        # 1. Mixture
        mix_db = 20 * np.log10(mix_mag + 1e-8)
        axes[0].imshow(mix_db, aspect='auto', origin='lower', cmap='viridis')
        axes[0].set_title(f'{title_prefix}Mixture', fontsize=10, pad=5)
        axes[0].axis('off')

        # 2. Separated
        sep_db = 20 * np.log10(sep_mag + 1e-8)
        axes[1].imshow(sep_db, aspect='auto', origin='lower', cmap='viridis')
        axes[1].set_title(f'{title_prefix}Separated', fontsize=10, pad=5)
        axes[1].axis('off')

        # 3. Mask (purple-orange plasma)
        mask = sep_mag / (mix_mag + 1e-8)
        mask = np.clip(mask, 0, 1)
        axes[2].imshow(mask, aspect='auto', origin='lower', cmap='plasma', vmin=0, vmax=1)
        axes[2].set_title(f'{title_prefix}Mask', fontsize=10, pad=5)
        axes[2].axis('off')

        plt.subplots_adjust(left=0.02, right=0.98, top=0.88, bottom=0.02, wspace=0.05)

        # Display directly, then close so figures do not pile up in memory
        from IPython.display import display as ipy_display
        ipy_display(fig)
        plt.close(fig)

        print(f"‚úì Spectrograms: Mixture (blue-green) | Separated (blue-green) | Mask (purple-orange)")

        return mix_mag, sep_mag, mask
    except Exception as e:
        # Visualization is optional: report the problem and let the caller
        # carry on with audio playback.
        print(f"‚ö†Ô∏è Visualization failed: {e}")
        plt.close('all')
        return None, None, None

def norm_audio(x):
    """Scale `x` so its absolute peak sits just under 0.95 for playback."""
    peak = np.max(np.abs(x))
    # The small epsilon keeps an all-zero signal from dividing by zero.
    unit_peak = x / (peak + 1e-8)
    return unit_peak * 0.95

def process_long_audio(audio_path, inference_engine, max_chunk_duration=30, sr=22050):
    """Separate a long audio file chunk by chunk to bound memory use.

    The file is loaded ``max_chunk_duration`` seconds at a time (mono, at
    ``sr``). Each chunk is peak-normalised before being passed to
    ``inference_engine.separate`` and the result is scaled back by that peak,
    so all chunks of the returned signal share the level of the source file.
    Previously each chunk was left at its own normalised level, which produced
    level jumps at chunk boundaries.

    Args:
        audio_path: Path of the audio file to process.
        inference_engine: Object providing ``separate(waveform)``.
        max_chunk_duration: Chunk length in seconds.
        sr: Sample rate used for loading.

    Returns:
        The separated chunks concatenated into one array (an empty float32
        array when no audio could be read).
    """
    print(f"\n‚ö†Ô∏è Long audio detected. Processing in chunks ({max_chunk_duration}s each)...")

    # librosa renamed `filename` to `path`; prefer the new keyword and fall
    # back for older releases that only accept `filename`.
    try:
        duration = librosa.get_duration(path=str(audio_path))
    except TypeError:
        duration = librosa.get_duration(filename=str(audio_path))
    print(f"Total duration: {duration:.1f}s")

    separated_chunks = []
    num_chunks = int(np.ceil(duration / max_chunk_duration))

    for chunk_idx in range(num_chunks):
        offset = chunk_idx * max_chunk_duration
        y_chunk, _ = librosa.load(str(audio_path), sr=sr, mono=True, offset=offset, duration=max_chunk_duration)

        # Rounding of the reported duration can leave an empty final chunk;
        # np.max would raise on it, so skip it.
        if y_chunk.size == 0:
            continue

        chunk_peak = np.max(np.abs(y_chunk)) + 1e-8
        y_chunk = y_chunk / chunk_peak

        chunk_end = min(offset + max_chunk_duration, duration)
        print(f"  Processing chunk {chunk_idx + 1}/{num_chunks} ({offset:.0f}s - {chunk_end:.0f}s)...")
        separated_chunk = inference_engine.separate(y_chunk)
        # Undo the per-chunk normalisation so levels stay continuous.
        separated_chunks.append(separated_chunk * chunk_peak)

        del y_chunk
        gc.collect()

    if not separated_chunks:
        # np.concatenate raises on an empty list.
        print("‚úì Chunked processing complete")
        return np.zeros(0, dtype=np.float32)

    separated = np.concatenate(separated_chunks)
    print("‚úì Chunked processing complete")
    return separated

# ============================================================================
# Test 1: Uploaded Song
# ============================================================================
# Picks the first audio file found under the project, separates it and shows
# spectrograms plus audio players for the mixture and the separated source.
print("\n" + "-"*70)
print("TEST 1: UPLOADED SONG")
print("-"*70)

audio_files = []
search_dirs = [project_root / "data", project_root, Path(".")]
scanned_dirs = set()
for search_dir in search_dirs:
    if not search_dir.exists():
        continue
    # Path(".") frequently resolves to project_root; do not walk the same
    # tree twice.
    resolved_dir = search_dir.resolve()
    if resolved_dir in scanned_dirs:
        continue
    scanned_dirs.add(resolved_dir)
    for ext in ['*.mp3', '*.wav', '*.flac', '*.m4a', '*.ogg']:
        audio_files.extend(glob.glob(str(search_dir / '**' / ext), recursive=True))

# data/ lies inside project_root, so its files are found twice; drop the
# duplicates while keeping discovery order (data/ keeps priority).
audio_files = list(dict.fromkeys(audio_files))

if audio_files:
    test_audio_path = audio_files[0]
    print(f"\n‚úì Found audio: {Path(test_audio_path).name}")

    file_size_mb = Path(test_audio_path).stat().st_size / (1024 * 1024)
    # librosa renamed `filename` to `path`; prefer the new keyword and fall
    # back for older releases that only accept `filename`.
    try:
        duration = librosa.get_duration(path=str(test_audio_path))
    except TypeError:
        duration = librosa.get_duration(filename=str(test_audio_path))
    print(f"File size: {file_size_mb:.1f}MB | Duration: {duration:.1f}s")

    if file_size_mb > 30 or duration > 120:
        print("‚Üí Using chunked processing (memory-efficient)")
        separated = process_long_audio(test_audio_path, inference_engine, max_chunk_duration=30, sr=22050)
        # Only the first 30 s are kept for visualisation and playback.
        y, sr = librosa.load(test_audio_path, sr=22050, mono=True, duration=30)
        test_segment = y
    else:
        y, sr = librosa.load(test_audio_path, sr=22050, mono=True)
        test_segment = y
        print("‚Üí Processing full audio")
        test_segment = test_segment / (np.max(np.abs(test_segment)) + 1e-8)

        print("\nRunning source separation...")
        separated = inference_engine.separate(test_segment)

    duration_sec = len(test_segment) / sr
    print(f"‚úì Processing song ({duration_sec:.1f}s)")

    # The separated signal is trimmed to the segment length before use.
    mix_norm = norm_audio(test_segment)
    sep_norm = norm_audio(separated[:len(test_segment)])

    print("\nüìä STFT Visualization - Uploaded Song (First 30s):")
    visualize_stft_masking(test_segment, separated[:len(test_segment)], stft_processor, sr, "Song - ")

    print("\nüìä ORIGINAL MIXTURE (First 30s):")
    display(Audio(mix_norm, rate=sr))

    print("\n‚ú® SEPARATED SOURCE (First 30s):")
    display(Audio(sep_norm, rate=sr))

    # Free the large arrays before the next test runs.
    del y, test_segment, separated
    gc.collect()

else:
    print("\n‚ö†Ô∏è No audio files found in data/ or current directory")

# ============================================================================
# Test 2: Database Sample (from curriculum cache)
# ============================================================================
# Separates one cached Stage 2 mixture and compares it with its ground truth.
print("\n\n" + "-"*70)
print("TEST 2: DATABASE SAMPLE (FROM CURRICULUM CACHE)")
print("-"*70)

if 's2_mix' in locals() and s2_mix and s2_tgt:
    # NOTE(review): the sample is drawn from the first 10 cached tracks, which
    # may well belong to the training split - the metrics below are then not
    # a held-out estimate.
    sample_idx = np.random.randint(0, min(10, len(s2_mix)))

    db_mixture_path = s2_mix[sample_idx]
    db_target_path = s2_tgt[sample_idx]

    print(f"\n‚úì Selected sample: {Path(db_mixture_path).name}")

    db_mixture = np.load(db_mixture_path).astype(np.float32)
    db_target = np.load(db_target_path).astype(np.float32)

    print(f"‚úì Sample duration: {len(db_mixture) / 22050:.2f}s")

    db_mixture = db_mixture / (np.max(np.abs(db_mixture)) + 1e-8)
    db_target = db_target / (np.max(np.abs(db_target)) + 1e-8)

    print("\nRunning source separation...")
    db_separated = inference_engine.separate(db_mixture)

    mix_norm_db = norm_audio(db_mixture)
    tgt_norm_db = norm_audio(db_target)
    sep_norm_db = norm_audio(db_separated)

    print("\nüìä STFT Visualization - Database Sample:")
    visualize_stft_masking(db_mixture, db_separated, stft_processor, sr=22050, title_prefix="DB - ")

    # Metrics over the overlapping samples only: the separated signal is not
    # guaranteed to have exactly the input length (Test 1 trims it with
    # separated[:len(...)]), and a length mismatch used to raise here.
    n_common = min(len(db_target), len(db_separated))
    db_reference = np.asarray(db_target[:n_common], dtype=np.float64)
    db_estimate = np.asarray(db_separated[:n_common], dtype=np.float64)
    db_residual = db_reference - db_estimate
    db_mse = float(np.mean(db_residual ** 2))
    db_mae = float(np.mean(np.abs(db_residual)))

    print(f"\nüìà Performance Metrics:")
    print(f"  MSE: {db_mse:.6f}")
    print(f"  MAE: {db_mae:.6f}")

    print("\nüìä INPUT MIXTURE (Full 4-source):")
    display(Audio(mix_norm_db, rate=22050))

    print("\n‚úì GROUND TRUTH TARGET (Other/Piano):")
    display(Audio(tgt_norm_db, rate=22050))

    print("\n‚ú® MODEL OUTPUT (Separated):")
    display(Audio(sep_norm_db, rate=22050))

else:
    print("\n‚ö†Ô∏è No database samples available. Please run curriculum data preparation first.")