# Model A: Curriculum Learning - Other/Piano Extraction

**Curriculum Learning Strategy:**
1. **Stage 1:** Extract "other" from simplified mixture (vocals + other only)
2. **Stage 2:** Extract "other" from full mixture (drums + bass + vocals + other)

**MUSDB18 Dataset:** 4 stems per track (drums, bass, other, vocals)

**Workflow:**
- Load MUSDB18 → prepare curriculum batches
- Train Stage 1 on simpler 2-source task
- Train Stage 2 on full 4-source mixture using Stage 1 weights
- Test on uploaded song (10 seconds from 1:00-1:10)

In [None]:
import sys
import os
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
import librosa
from tqdm import tqdm
from IPython.display import Audio, display

# Resolve the project root: the current working directory, or its parent when
# the notebook is being run from inside a "notebooks" folder.
project_root = Path.cwd().resolve()
if project_root.name.lower() == "notebooks":
    project_root = project_root.parent

# Make the project root the first import location so local packages resolve.
sys.path.insert(0, str(project_root))

# Folders for saved weights and prepared data; created on first run.
checkpoints_dir = project_root / "checkpoints"
data_dir = project_root / "data"

checkpoints_dir.mkdir(parents=True, exist_ok=True)
data_dir.mkdir(parents=True, exist_ok=True)

# Prefer the GPU when PyTorch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"‚úì Setup complete | Device: {device} | Project: {project_root}")

# Import model components
from models.model_a_unet_freq import (
    STFTProcessor, FrequencyDomainUNet, 
    SourceSeparationDataset, ModelATrainer, ModelAInference
)

In [None]:
# Load MUSDB18 dataset - Auto download if not present
print("\n" + "=" * 70, "LOADING MUSDB18 DATASET", "=" * 70, sep="\n")

# Import musdb; when it is missing, pip-install it into the interpreter that is
# running this notebook and import again.
try:
    import musdb
except ImportError:
    print("Installing musdb...")
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "musdb", "-q"])
    import musdb
else:
    print("‚úì musdb library found")

print(
    "\nLoading MUSDB18 (auto-downloading if needed)...\n"
    "Note: First run may take time. Dataset will be cached for future use.\n"
)

# Any failure while opening the dataset is reported and recorded in
# `use_real_musdb` instead of aborting the notebook; later cells check the flag.
try:
    mus = musdb.DB(download=True)
    tracks = mus.tracks
    print("‚úì MUSDB18 loaded successfully!")
    print(f"‚úì Available tracks: {len(tracks)}")
    use_real_musdb = True
except Exception as load_error:
    print(f"‚ö†Ô∏è Could not load MUSDB18: {load_error}")
    print("Will use synthetic data instead")
    mus = None
    use_real_musdb = False

# MUSDB18 stems: [0]=drums, [1]=bass, [2]=other, [3]=vocals
STEM_NAMES = dict(enumerate(('drums', 'bass', 'other', 'vocals')))

In [None]:
# Manual MUSDB18 Download Instructions
# Printed only when the loading cell could not provide the dataset
# (`use_real_musdb` is False); otherwise a one-line confirmation is shown.
if not use_real_musdb:
    # `musdb_root` is not assigned anywhere in the visible notebook, so using it
    # directly raised NameError in exactly the situation these instructions are
    # meant for. Reuse a value defined elsewhere when there is one; otherwise
    # suggest a folder under the project's data directory.
    # NOTE(review): the loading cell calls musdb.DB(download=True) without a
    # `root` argument — confirm it is pointed at this directory after a manual
    # extraction, otherwise re-running it will not pick the files up.
    musdb_root = globals().get("musdb_root", data_dir / "musdb18hq")

    print("\n" + "="*70)
    print("MUSDB18 DATASET REQUIRED")
    print("="*70)
    print("\nThe musdb library doesn't support automatic downloads.")
    print("Please follow these steps to download MUSDB18:\n")
    print("1. Visit: https://sigsep.github.io/datasets/musdb.html")
    print("2. Download the MUSDB18-HQ dataset (~23GB)")
    print("3. Extract the ZIP file to:")
    print(f"   {musdb_root}")
    print("\n4. After extraction, the structure should be:")
    print(f"   {musdb_root}/")
    print(f"     ‚îú‚îÄ‚îÄ train/")
    print(f"     ‚îÇ   ‚îú‚îÄ‚îÄ A Classic Education - NightOwl/")
    print(f"     ‚îÇ   ‚îú‚îÄ‚îÄ ...")
    print(f"     ‚îî‚îÄ‚îÄ test/")
    print(f"         ‚îú‚îÄ‚îÄ ...")
    print("\n5. Then re-run cells 2-3 to detect the dataset")
    print("\n" + "="*70)
else:
    print("‚úì MUSDB18 already available")

In [None]:
# Section banner for the curriculum data-preparation cell.
print("\n" + "=" * 70, "CURRICULUM LEARNING DATA PREPARATION", "=" * 70, sep="\n")

def prepare_curriculum_data(num_tracks=50):
    """Build and cache the Stage 1 / Stage 2 training pairs from MUSDB18.

    For each of the first ``num_tracks`` entries of ``mus.tracks`` the four
    stems are read, two mixtures are formed (Stage 1: vocals + other,
    Stage 2: drums + bass + other + vocals) and, together with the "other"
    stem as the target, each signal is downmixed to mono, resampled to
    22050 Hz, peak-normalized and saved as a float32 ``.npy`` file under
    ``data_dir / "curriculum_cache"``.

    Args:
        num_tracks: Maximum number of tracks to take from the start of
            ``mus.tracks``.

    Returns:
        A 4-tuple of lists of file paths (as strings):
        ``(stage1_mixture_paths, stage1_target_paths,
        stage2_mixture_paths, stage2_target_paths)``. The lists are
        index-aligned; tracks that failed to process are absent from all four.

    Raises:
        ValueError: If MUSDB18 was not loaded, or if no track could be
            processed.
    """

    if not use_real_musdb or mus is None:
        raise ValueError(
            "MUSDB18 dataset could not be loaded.\n"
            "Please check your internet connection and try again.\n"
            "The musdb library will attempt to download automatically."
        )

    # Imported once per call rather than once per track inside the loop.
    from scipy import signal

    target_sr = 22050

    def to_cached_signal(audio, sr):
        """Return ``audio`` as a mono, ``target_sr``, peak-normalized float32 array."""
        # Downmix before resampling: channel averaging and resampling are both
        # linear, so the result matches resample-then-average up to rounding
        # while the FFT resampler only has to process one channel.
        if audio.ndim > 1:
            audio = np.mean(audio, axis=1)
        if sr != target_sr:
            n_samples = int(len(audio) * target_sr / sr)
            audio = signal.resample(audio, n_samples)
        # NOTE(review): each signal is normalized to its own peak, so a target
        # is not at the same gain as its contribution to the mixture — confirm
        # this is what the dataset/trainer expect.
        audio = audio / (np.max(np.abs(audio)) + 1e-8)
        return audio.astype(np.float32)

    tracks = mus.tracks[:num_tracks]
    print(f"\nProcessing {len(tracks)} MUSDB18 tracks for curriculum learning...")

    stage1_mixture_paths = []
    stage1_target_paths = []
    stage2_mixture_paths = []
    stage2_target_paths = []

    cache_dir = data_dir / "curriculum_cache"
    cache_dir.mkdir(exist_ok=True, parents=True)

    for idx, track in enumerate(tracks):
        try:
            # Extract stems
            drums = track.targets['drums'].audio
            bass = track.targets['bass'].audio
            other = track.targets['other'].audio
            vocals = track.targets['vocals'].audio

            # Stage 1: vocals + other (simplified)
            mixture_s1 = vocals + other
            # Stage 2: drums + bass + other + vocals (full)
            mixture_s2 = drums + bass + other + vocals

            # Fall back to 44100 Hz when the track does not report a rate.
            sr = getattr(track, 'sample_rate', None) or 44100

            other = to_cached_signal(other, sr)
            mixture_s1 = to_cached_signal(mixture_s1, sr)
            mixture_s2 = to_cached_signal(mixture_s2, sr)

            # Save files
            s1_mix_path = cache_dir / f"stage1_mixture_{idx:03d}.npy"
            s1_tgt_path = cache_dir / f"stage1_target_{idx:03d}.npy"
            s2_mix_path = cache_dir / f"stage2_mixture_{idx:03d}.npy"
            s2_tgt_path = cache_dir / f"stage2_target_{idx:03d}.npy"

            np.save(s1_mix_path, mixture_s1)
            np.save(s1_tgt_path, other)
            np.save(s2_mix_path, mixture_s2)
            np.save(s2_tgt_path, other)

            # Append to all four lists together so they stay index-aligned.
            stage1_mixture_paths.append(str(s1_mix_path))
            stage1_target_paths.append(str(s1_tgt_path))
            stage2_mixture_paths.append(str(s2_mix_path))
            stage2_target_paths.append(str(s2_tgt_path))

            if (idx + 1) % 10 == 0:
                print(f"  Processed {idx + 1}/{len(tracks)} tracks...")

        except Exception as e:
            # Best effort: one unreadable track must not abort the whole run.
            print(f"  ‚ö†Ô∏è Skipping track {idx} ({track.name}): {str(e)[:50]}")
            continue

    if not stage1_mixture_paths:
        raise ValueError("No tracks could be processed from MUSDB18")

    return (stage1_mixture_paths, stage1_target_paths,
            stage2_mixture_paths, stage2_target_paths)

# Build the cached Stage 1 / Stage 2 file lists from up to 50 MUSDB18 tracks
# and report how many samples each stage ended up with.
s1_mix, s1_tgt, s2_mix, s2_tgt = prepare_curriculum_data(num_tracks=50)

print(
    f"\n‚úì Stage 1 (Vocals + Other ‚Üí Other): {len(s1_mix)} samples\n"
    f"‚úì Stage 2 (Full Mixture ‚Üí Other): {len(s2_mix)} samples"
)

In [None]:
# Create dataloaders for both stages
print("\n" + "="*70)
print("CREATING DATALOADERS FOR CURRICULUM LEARNING")
print("="*70)

stft_processor = STFTProcessor(n_fft=2048, hop_length=512)

# Both stages are built from the same tracks in the same order, so splitting
# each with an identically seeded generator yields the same train/val track
# indices. Without this, a track held out for Stage 2 validation could already
# have been trained on in Stage 1, since Stage 2 starts from Stage 1 weights.
SPLIT_SEED = 42

# Stage 1: extract "other" from the simplified vocals + other mixture
print("\nStage 1: Other Extraction (Vocals + Other mixture)")
stage1_dataset = SourceSeparationDataset(
    mixture_paths=s1_mix,
    target_paths=s1_tgt,
    stft_processor=stft_processor,
    normalize=True
)

# 80/20 train/validation split.
s1_train_size = int(0.8 * len(stage1_dataset))
s1_val_size = len(stage1_dataset) - s1_train_size
s1_train_data, s1_val_data = random_split(
    stage1_dataset,
    [s1_train_size, s1_val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

s1_train_loader = DataLoader(s1_train_data, batch_size=4, shuffle=True, num_workers=0)
s1_val_loader = DataLoader(s1_val_data, batch_size=4, shuffle=False, num_workers=0)

print(f"  Train: {len(s1_train_data)} | Val: {len(s1_val_data)}")

# Stage 2: extract "other" (piano) from the full four-stem mixture
print("\nStage 2: Other/Piano Extraction")
stage2_dataset = SourceSeparationDataset(
    mixture_paths=s2_mix,
    target_paths=s2_tgt,
    stft_processor=stft_processor,
    normalize=True
)

s2_train_size = int(0.8 * len(stage2_dataset))
s2_val_size = len(stage2_dataset) - s2_train_size
s2_train_data, s2_val_data = random_split(
    stage2_dataset,
    [s2_train_size, s2_val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

s2_train_loader = DataLoader(s2_train_data, batch_size=4, shuffle=True, num_workers=0)
s2_val_loader = DataLoader(s2_val_data, batch_size=4, shuffle=False, num_workers=0)

print(f"  Train: {len(s2_train_data)} | Val: {len(s2_val_data)}")

print("\n‚úì All dataloaders created")

In [None]:
# Initialize model and trainer
print("\n" + "=" * 70, "MODEL INITIALIZATION & CHECKPOINT MANAGEMENT", "=" * 70, sep="\n")

# Hyperparameters forwarded verbatim as keyword arguments to the U-Net.
model_config = dict(
    in_channels=1,
    base_channels=32,
    depth=4,
    use_batch_norm=True,
)

# Build the network on the selected device and report its parameter count.
model = FrequencyDomainUNet(**model_config).to(device)
print(f"\n‚úì Model created: {sum(param.numel() for param in model.parameters()):,} parameters")

# Checkpoint management
def train_stage(stage_num, train_loader, val_loader, num_epochs=20):
    """Train one curriculum stage, or restore it from an existing checkpoint.

    The notebook-level ``model`` is updated in place. If
    ``checkpoints_dir / "stage{stage_num}_modelA.pt"`` already exists its
    weights are loaded into ``model`` and no training happens; otherwise a
    ``ModelATrainer`` is run for ``num_epochs`` and a checkpoint is written to
    that path.

    Args:
        stage_num: Curriculum stage number; used in the banner and in the
            checkpoint file name.
        train_loader: Training loader handed to ``ModelATrainer``.
        val_loader: Validation loader handed to ``ModelATrainer``.
        num_epochs: Number of epochs passed to ``trainer.train``.

    Returns:
        None.

    Raises:
        ValueError: If training returns an empty ``val_loss`` history, so no
            best epoch can be determined.
    """

    checkpoint_path = checkpoints_dir / f"stage{stage_num}_modelA.pt"

    print(f"\n{'='*70}")
    print(f"STAGE {stage_num}: CURRICULUM LEARNING")
    print(f"{'='*70}\n")

    # Check if checkpoint exists
    if checkpoint_path.exists():
        # weights_only=False unpickles arbitrary objects; only acceptable
        # because this file is expected to have been written by this notebook.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"‚úì Checkpoint weights loaded: {checkpoint_path.name}")

        # A checkpoint without a numeric 'val_loss' used to crash here, because
        # the '?' placeholder cannot be formatted with ':.6f'.
        try:
            val_loss_text = f"{float(checkpoint.get('val_loss')):.6f}"
        except (TypeError, ValueError):
            val_loss_text = '?'
        print(f"  Epoch: {checkpoint.get('epoch', '?')} | Val Loss: {val_loss_text}")
        return

    # Initialize trainer
    trainer = ModelATrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=1e-3,
        device=device
    )

    # NOTE(review): a StepLR(step_size=5, gamma=0.5) scheduler used to be
    # constructed here but was never stepped or passed to the trainer, so it had
    # no effect on the learning rate. It has been removed; wire a scheduler into
    # ModelATrainer if learning-rate decay is wanted.

    # Train
    print(f"Starting training... (no checkpoint found)")
    history = trainer.train(num_epochs=num_epochs, save_dir=str(checkpoints_dir))

    train_history = [float(x) for x in history['train_loss']]
    val_history = [float(x) for x in history['val_loss']]
    if not val_history:
        raise ValueError(
            f"Stage {stage_num} training produced no validation losses; "
            "cannot determine the best epoch."
        )

    # Plain Python numbers keep the checkpoint free of NumPy scalar objects.
    best_epoch = int(np.argmin(val_history)) + 1
    best_val_loss = min(val_history)

    # Save checkpoint
    # NOTE(review): the weights saved are the model's state after the last
    # epoch, while 'epoch' / 'val_loss' describe the best epoch — confirm
    # whether trainer.train restores the best weights before returning.
    torch.save({
        'epoch': best_epoch,
        'model_state_dict': model.state_dict(),
        'val_loss': best_val_loss,
        'train_loss': train_history,
        'val_loss_history': val_history
    }, checkpoint_path)

    print(f"\n‚úì Checkpoint saved: {checkpoint_path.name}")
    print(f"  Best epoch: {best_epoch} | Val Loss: {best_val_loss:.6f}")

In [None]:
# Train Stage 1: extract "other" from the simplified vocals + other mixture
# (restored from its checkpoint instead when one already exists).
train_stage(stage_num=1, train_loader=s1_train_loader, val_loader=s1_val_loader, num_epochs=20)

In [None]:
# Evaluate Stage 1 Performance
# Reloads the Stage 1 checkpoint, reports the mean L1 loss over the Stage 1
# validation loader, then separates the first cached Stage 1 sample and offers
# mixture / target / output for listening.
print("\n" + "="*70)
print("STAGE 1 EVALUATION")
print("="*70)

# Load Stage 1 checkpoint
stage1_checkpoint = checkpoints_dir / 'stage1_modelA.pt'
if stage1_checkpoint.exists():
    # weights_only=False unpickles arbitrary objects; only acceptable because
    # this file is expected to have been written by this notebook.
    checkpoint = torch.load(stage1_checkpoint, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"\n‚úì Loaded Stage 1 checkpoint")
    print(f"  Training epoch: {checkpoint['epoch']}")
    print(f"  Validation loss: {checkpoint['val_loss']:.6f}")

    # Evaluate on validation set
    model.eval()
    val_loss_total = 0
    num_batches = 0
    # Built once instead of once per batch.
    loss_fn = torch.nn.L1Loss()

    print("\nEvaluating on validation set...")
    with torch.no_grad():
        for batch_data in s1_val_loader:
            # Extract mixture magnitude and target magnitude from batch
            mixture = batch_data['mixture_mag'].to(device)
            target = batch_data['target_mag'].to(device)

            # Forward pass
            output = model(mixture)

            # Compute loss
            loss = loss_fn(output, target)
            val_loss_total += loss.item()
            num_batches += 1

    # An empty validation split used to raise ZeroDivisionError here.
    if num_batches:
        avg_val_loss = val_loss_total / num_batches
        print(f"\n‚úì Average validation loss: {avg_val_loss:.6f}")
    else:
        avg_val_loss = float('nan')
        print("\n‚ö†Ô∏è Validation loader is empty; average validation loss not computed.")

    # Test on a sample
    print("\nTesting on sample audio...")
    test_idx = 0
    test_mix = np.load(s1_mix[test_idx])
    test_tgt = np.load(s1_tgt[test_idx])

    # Create inference engine
    inference_engine = ModelAInference(
        model=model,
        stft_processor=stft_processor,
        device=device
    )

    # Separate
    separated = inference_engine.separate(test_mix)

    # Compute metrics
    # The separated signal is not guaranteed to have exactly the target's
    # length (other cells already trim it to the input length), and the sklearn
    # metrics reject inputs of different lengths — compare the common prefix.
    from sklearn.metrics import mean_squared_error, mean_absolute_error
    n_common = min(len(test_tgt), len(separated))
    mse = mean_squared_error(test_tgt[:n_common], separated[:n_common])
    mae = mean_absolute_error(test_tgt[:n_common], separated[:n_common])

    print(f"  MSE: {mse:.6f}")
    print(f"  MAE: {mae:.6f}")

    # Audio playback
    print("\nüìä Listen to Stage 1 results:")
    sr = 22050

    def norm_audio(x):
        """Scale ``x`` so its peak absolute value is about 0.95 for playback."""
        return x / (np.max(np.abs(x)) + 1e-8) * 0.95

    print("\n1. Input (Vocals + Other):")
    display(Audio(norm_audio(test_mix), rate=sr))

    print("\n2. Target (Other/Piano):")
    display(Audio(norm_audio(test_tgt), rate=sr))

    print("\n3. Separated (Stage 1 Output):")
    display(Audio(norm_audio(separated), rate=sr))

    print("\n" + "="*70)
    print("Stage 1 evaluation complete. Ready for Stage 2 training.")
    print("="*70)

else:
    print("\n‚ö†Ô∏è Stage 1 checkpoint not found. Please run Stage 1 training first.")

In [None]:
# Train Stage 2: extract "other"/piano from the full four-stem mixture,
# continuing from whatever weights `model` currently holds.
train_stage(stage_num=2, train_loader=s2_train_loader, val_loader=s2_val_loader, num_epochs=20)

In [None]:
# Load best model (Stage 2)
print("\n" + "=" * 70, "LOADING TRAINED MODEL FOR INFERENCE", "=" * 70, sep="\n")

# Prefer the saved Stage 2 weights; without a checkpoint the model is used with
# whatever weights it currently holds.
best_checkpoint = checkpoints_dir / 'stage2_modelA.pt'
if not best_checkpoint.exists():
    print("‚úì Using current model (no checkpoint)")
else:
    checkpoint = torch.load(best_checkpoint, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"\n‚úì Loaded: {best_checkpoint.name}")
    print(f"  Best epoch: {checkpoint['epoch']} | Val Loss: {checkpoint['val_loss']:.6f}")

# Wrap the model and STFT settings for waveform-in / waveform-out separation.
inference_engine = ModelAInference(model=model, stft_processor=stft_processor, device=device)

print("‚úì Inference engine ready")

In [None]:
# Test on uploaded song & database samples - with STFT visualization
%matplotlib inline

import numpy as np
import librosa
from IPython.display import Audio, display
from pathlib import Path
import glob
import gc

# Clear matplotlib cache BEFORE importing
import os
import shutil
cache_dir = os.path.expanduser('~/.matplotlib')
# NOTE(review): this removes the whole ~/.matplotlib directory. Where that
# directory also holds user configuration (e.g. a matplotlibrc), it is deleted
# too — confirm this is intended before running on a shared or personal machine.
if os.path.exists(cache_dir):
    try:
        shutil.rmtree(cache_dir)
        print("‚úì Cleared matplotlib cache")
    except OSError as exc:
        # Best effort: a locked or read-only cache must not stop the notebook,
        # but report it rather than hiding it behind a bare `except`.
        print(f"Could not clear matplotlib cache: {exc}")

# NOW import matplotlib
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm

# Force matplotlib to rebuild font cache
# `_rebuild` is a private helper that not every matplotlib version provides, so
# failure here is expected and non-fatal. Catch Exception (not a bare `except`)
# so KeyboardInterrupt / SystemExit still propagate.
try:
    matplotlib.font_manager._rebuild()
    print("‚úì Rebuilt font cache")
except Exception:
    pass

# Use simple, safe backend and minimal text
matplotlib.use('agg')
matplotlib.rcParams.update({
    'font.size': 9,
    'font.family': 'sans-serif',
    'figure.dpi': 80,
    'savefig.dpi': 80,
    'text.usetex': False,
    'axes.unicode_minus': False
})

print("\n" + "="*70)
print("TESTING ON UPLOADED SONG & DATABASE SAMPLES WITH STFT VISUALIZATION")
print("="*70)

def visualize_stft_masking(mixture, separated, stft_processor, sr, title_prefix=""):
    """Show mixture, separated output and their ratio mask as three spectrogram panels.

    Magnitudes come from ``stft_processor.waveform_to_magnitude_phase``. The
    first two panels are drawn in dB; the third is ``separated / mixture``
    magnitude clipped to [0, 1]. ``sr`` is accepted but not used.

    Returns:
        ``(mix_mag, sep_mag, mask)`` on success, or ``(None, None, None)`` when
        anything fails (the error is printed and all figures are closed).
    """
    eps = 1e-8

    def draw_panel(ax, image, label, **imshow_kwargs):
        # One axis-free image panel with a small title.
        ax.imshow(image, aspect='auto', origin='lower', **imshow_kwargs)
        ax.set_title(f'{title_prefix}{label}', fontsize=10, pad=5)
        ax.axis('off')

    try:
        # Only the magnitudes are needed; the phases are discarded.
        mix_mag, _ = stft_processor.waveform_to_magnitude_phase(mixture)
        sep_mag, _ = stft_processor.waveform_to_magnitude_phase(separated)

        fig, axes = plt.subplots(1, 3, figsize=(15, 4), dpi=80)

        # 1. Mixture (dB)
        draw_panel(axes[0], 20 * np.log10(mix_mag + eps), 'Mixture', cmap='viridis')

        # 2. Separated (dB)
        draw_panel(axes[1], 20 * np.log10(sep_mag + eps), 'Separated', cmap='viridis')

        # 3. Ratio mask, limited to [0, 1]
        mask = np.clip(sep_mag / (mix_mag + eps), 0, 1)
        draw_panel(axes[2], mask, 'Mask', cmap='plasma', vmin=0, vmax=1)

        plt.subplots_adjust(left=0.02, right=0.98, top=0.88, bottom=0.02, wspace=0.05)

        # Hand the figure to IPython explicitly, then release it.
        from IPython.display import display as ipy_display
        ipy_display(fig)
        plt.close(fig)

        print(f"‚úì Spectrograms: Mixture (blue-green) | Separated (blue-green) | Mask (purple-orange)")

        return mix_mag, sep_mag, mask
    except Exception as err:
        print(f"‚ö†Ô∏è Visualization failed: {err}")
        plt.close('all')
        return None, None, None

def norm_audio(x):
    """Scale ``x`` so its largest absolute sample is about 0.95 (for playback)."""
    # The small epsilon keeps an all-zero signal from dividing by zero.
    peak = np.max(np.abs(x))
    return x / (peak + 1e-8) * 0.95

def process_long_audio(audio_path, inference_engine, max_chunk_duration=30, sr=22050):
    """Separate a long audio file chunk by chunk to limit memory use.

    The file is read in consecutive windows of ``max_chunk_duration`` seconds
    (mono, resampled to ``sr``). Each window is peak-normalized, passed through
    ``inference_engine.separate`` and the outputs are concatenated.

    Args:
        audio_path: Path to the audio file.
        inference_engine: Object exposing ``separate(waveform)``.
        max_chunk_duration: Window length in seconds.
        sr: Sample rate the audio is loaded at.

    Returns:
        The concatenated separated signal; an empty float32 array when no
        samples could be read.
    """
    print(f"\n‚ö†Ô∏è Long audio detected. Processing in chunks ({max_chunk_duration}s each)...")

    # `path=` replaces the `filename=` keyword, which librosa has deprecated;
    # fall back for older releases that only accept `filename=`.
    try:
        duration = librosa.get_duration(path=str(audio_path))
    except TypeError:
        duration = librosa.get_duration(filename=str(audio_path))
    print(f"Total duration: {duration:.1f}s")

    separated_chunks = []
    num_chunks = int(np.ceil(duration / max_chunk_duration))

    for chunk_idx in range(num_chunks):
        offset = chunk_idx * max_chunk_duration
        y_chunk, _ = librosa.load(str(audio_path), sr=sr, mono=True, offset=offset, duration=max_chunk_duration)

        # The reported duration can exceed what actually decodes; an empty
        # window would make np.max raise, so stop at the end of the data.
        if y_chunk.size == 0:
            break

        # NOTE(review): every chunk is normalized to its own peak, so gain can
        # jump at chunk boundaries in the concatenated output.
        y_chunk = y_chunk / (np.max(np.abs(y_chunk)) + 1e-8)

        # Clamp the displayed end time so the last chunk does not overshoot.
        chunk_end = min(offset + max_chunk_duration, duration)
        print(f"  Processing chunk {chunk_idx + 1}/{num_chunks} ({offset:.0f}s - {chunk_end:.0f}s)...")
        separated_chunk = inference_engine.separate(y_chunk)
        # Trim to the input length so any extra samples the separator appends
        # do not accumulate into a timing drift across chunks.
        separated_chunks.append(separated_chunk[:len(y_chunk)])

        del y_chunk
        gc.collect()

    # np.concatenate rejects an empty list.
    if not separated_chunks:
        print("‚úì Chunked processing complete")
        return np.zeros(0, dtype=np.float32)

    separated = np.concatenate(separated_chunks)
    print("‚úì Chunked processing complete")
    return separated

# ============================================================================
# Test 1: Uploaded Song
# ============================================================================
# Finds the first audio file under the project, separates it (chunked when the
# file is large or long) and shows spectrograms plus playable audio.
print("\n" + "-"*70)
print("TEST 1: UPLOADED SONG")
print("-"*70)


def _get_audio_duration(audio_path):
    """Return the duration in seconds of the audio file at ``audio_path``."""
    # `path=` replaces the `filename=` keyword, which librosa has deprecated;
    # fall back for older releases that only accept `filename=`.
    try:
        return librosa.get_duration(path=str(audio_path))
    except TypeError:
        return librosa.get_duration(filename=str(audio_path))


audio_files = []
search_dirs = [project_root / "data", project_root, Path(".")]
for search_dir in search_dirs:
    if search_dir.exists():
        for ext in ['*.mp3', '*.wav', '*.flac', '*.m4a', '*.ogg']:
            audio_files.extend(glob.glob(str(search_dir / '**' / ext), recursive=True))

# The search directories overlap (data/ lies inside the project root), so the
# same file is found several times; keep the first occurrence of each.
audio_files = list(dict.fromkeys(str(Path(p).resolve()) for p in audio_files))

if audio_files:
    test_audio_path = audio_files[0]
    print(f"\n‚úì Found audio: {Path(test_audio_path).name}")

    file_size_mb = Path(test_audio_path).stat().st_size / (1024 * 1024)
    duration = _get_audio_duration(test_audio_path)
    print(f"File size: {file_size_mb:.1f}MB | Duration: {duration:.1f}s")

    if file_size_mb > 30 or duration > 120:
        print("‚Üí Using chunked processing (memory-efficient)")
        separated = process_long_audio(test_audio_path, inference_engine, max_chunk_duration=30, sr=22050)
        # Only the first 30 s are kept for display. Normalize them exactly as
        # process_long_audio normalized its first chunk, so the mixture shown
        # is the signal the model actually received.
        y, sr = librosa.load(test_audio_path, sr=22050, mono=True, duration=30)
        test_segment = y / (np.max(np.abs(y)) + 1e-8)
    else:
        y, sr = librosa.load(test_audio_path, sr=22050, mono=True)
        test_segment = y
        print("‚Üí Processing full audio")
        test_segment = test_segment / (np.max(np.abs(test_segment)) + 1e-8)

        print("\nRunning source separation...")
        separated = inference_engine.separate(test_segment)

    duration_sec = len(test_segment) / sr
    print(f"‚úì Processing song ({duration_sec:.1f}s)")

    # Trim the output to the displayed segment before normalizing for playback.
    mix_norm = norm_audio(test_segment)
    sep_norm = norm_audio(separated[:len(test_segment)])

    print("\nüìä STFT Visualization - Uploaded Song (First 30s):")
    visualize_stft_masking(test_segment, separated[:len(test_segment)], stft_processor, sr, "Song - ")

    print("\nüìä ORIGINAL MIXTURE (First 30s):")
    display(Audio(mix_norm, rate=sr))

    print("\n‚ú® SEPARATED SOURCE (First 30s):")
    display(Audio(sep_norm, rate=sr))

    # Release the large arrays before the next test.
    del y, test_segment, separated
    gc.collect()

else:
    print("\n‚ö†Ô∏è No audio files found in data/ or current directory")

# ============================================================================
# Test 2: Database Sample (from curriculum cache)
# ============================================================================
# Picks one of the first (up to) ten cached Stage 2 samples at random, separates
# it, and compares the output with the cached ground-truth "other" stem.
print("\n\n" + "-"*70)
print("TEST 2: DATABASE SAMPLE (FROM CURRICULUM CACHE)")
print("-"*70)

if 's2_mix' in locals() and s2_mix and s2_tgt:
    sample_idx = np.random.randint(0, min(10, len(s2_mix)))

    db_mixture_path = s2_mix[sample_idx]
    db_target_path = s2_tgt[sample_idx]

    print(f"\n‚úì Selected sample: {Path(db_mixture_path).name}")

    db_mixture = np.load(db_mixture_path).astype(np.float32)
    db_target = np.load(db_target_path).astype(np.float32)

    print(f"‚úì Sample duration: {len(db_mixture) / 22050:.2f}s")

    db_mixture = db_mixture / (np.max(np.abs(db_mixture)) + 1e-8)
    db_target = db_target / (np.max(np.abs(db_target)) + 1e-8)

    print("\nRunning source separation...")
    db_separated = inference_engine.separate(db_mixture)

    # The separated signal is not guaranteed to match the input length exactly
    # (the uploaded-song test trims it for the same reason). The sklearn metrics
    # reject inputs of different lengths, so compare and plot the common prefix.
    n_common = min(len(db_mixture), len(db_target), len(db_separated))
    db_mixture = db_mixture[:n_common]
    db_target = db_target[:n_common]
    db_separated = db_separated[:n_common]

    mix_norm_db = norm_audio(db_mixture)
    tgt_norm_db = norm_audio(db_target)
    sep_norm_db = norm_audio(db_separated)

    print("\nüìä STFT Visualization - Database Sample:")
    visualize_stft_masking(db_mixture, db_separated, stft_processor, sr=22050, title_prefix="DB - ")

    from sklearn.metrics import mean_squared_error, mean_absolute_error
    db_mse = mean_squared_error(db_target, db_separated)
    db_mae = mean_absolute_error(db_target, db_separated)

    print(f"\nüìà Performance Metrics:")
    print(f"  MSE: {db_mse:.6f}")
    print(f"  MAE: {db_mae:.6f}")

    print("\nüìä INPUT MIXTURE (Full 4-source):")
    display(Audio(mix_norm_db, rate=22050))

    print("\n‚úì GROUND TRUTH TARGET (Other/Piano):")
    display(Audio(tgt_norm_db, rate=22050))

    print("\n‚ú® MODEL OUTPUT (Separated):")
    display(Audio(sep_norm_db, rate=22050))

else:
    print("\n‚ö†Ô∏è No database samples available. Please run curriculum data preparation first.")