# 🎵 Stems Separation PoC: Accuracy Comparison

**Proof of Concept** to validate Idea 1: Stems-based transcription approach.

**Hypothesis**: Separating audio into stems (bass, drums, other, vocals) before transcription improves accuracy.

**Test**:
1. Transcribe **full mix** with YourMT3
2. Separate audio into **4 stems** with Demucs
3. Transcribe **each stem** with YourMT3
4. Compare results (detected note counts, used here as a proxy for accuracy) and analyze improvements

**Expected**: 10-15% accuracy improvement on stems (without fine-tuning yet)

## ⚠️ IMPORTANT: Run cells in order (sections 1 → 9)

## 1. Setup and Imports

In [None]:
import sys
import os
from pathlib import Path
import glob

# Change to yourmt3_space directory.
# This cell must be safe to re-run: after the first run the kernel is already
# inside yourmt3_space, so a naive os.getcwd() would record the wrong
# original_dir (later cells look for audio files and write results there) and
# the existence check below would fail.
_cwd = os.getcwd()
if os.path.isdir('yourmt3_space'):
    original_dir = _cwd
    os.chdir('yourmt3_space')
    _in_yourmt3_space = True
elif os.path.basename(_cwd) == 'yourmt3_space':
    # Re-run: we already changed directory, the notebook lives one level up.
    original_dir = os.path.dirname(_cwd)
    _in_yourmt3_space = True
else:
    original_dir = _cwd
    _in_yourmt3_space = False
    print("❌ Error: yourmt3_space directory not found!")
    print("   Run setup_yourmt3_brev.sh first")

if _in_yourmt3_space:
    # Insert in the same order as before ('amt/src' ends up first), but only
    # once so repeated runs do not pile up duplicate sys.path entries.
    for _entry in ('.', 'amt/src'):
        if _entry not in sys.path:
            sys.path.insert(0, _entry)

import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
import pretty_midi
from IPython.display import Audio, display, HTML, Markdown
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed
import json
from collections import defaultdict

from model_helper import load_model_checkpoint, transcribe

# Report the runtime environment so a broken setup is visible immediately.
_cuda_ready = torch.cuda.is_available()
for _line in (
    "✅ Imports successful!",
    f"   Working directory: {os.getcwd()}",
    f"   CUDA available: {_cuda_ready}",
):
    print(_line)
if _cuda_ready:
    print(f"   GPU: {torch.cuda.get_device_name(0)}")

## 2. Load Demucs (as Python library)

**Run this cell once** - it will import and prepare Demucs

In [None]:
# Import Demucs as a Python library (no subprocess needed!)
#
# demucs_model / demucs_device are defined up front so that later cells can
# always test `demucs_model is None`, whatever failure path is taken below.
demucs_model = None
demucs_device = 'cuda' if torch.cuda.is_available() else 'cpu'

try:
    from demucs.pretrained import get_model
    from demucs.apply import apply_model
    print("✅ Demucs imported successfully!")
    
    # Pre-load the htdemucs model
    print("📦 Loading htdemucs model...")
    print("   (This will download ~80MB on first run)")
    demucs_model = get_model('htdemucs')
    demucs_model.to(demucs_device)
    print(f"✅ Demucs model loaded on {demucs_device}")
    
except ImportError:
    print("❌ Demucs not installed!")
    print("   Installing demucs...")
    import subprocess
    import sys
    try:
        # Use the kernel's own interpreter: a bare `pip` on PATH may belong to
        # a different environment than the one this notebook runs in.
        subprocess.run(
            [sys.executable, '-m', 'pip', 'install', '-q', 'demucs'],
            check=True,
        )
    except (subprocess.CalledProcessError, OSError) as install_error:
        print(f"❌ Automatic installation failed: {install_error}")
        print("   Install it manually (pip install demucs) and re-run this cell.")
    else:
        print("✅ Demucs installed! Please restart kernel and re-run this cell.")
    demucs_model = None
except Exception as e:
    print(f"❌ Error loading Demucs: {e}")
    demucs_model = None

## 3. Load YourMT3 Model

In [None]:
print("Loading YourMT3 model...")
print("This may take 10-15 seconds...")

# Model configuration
checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt"
project = '2024'
precision = '16'

# Command-line style options handed to the loader, kept as (flag, value)
# pairs and flattened below into the list the loader receives.
_option_pairs = (
    ('-p', project),
    ('-tk', 'mc13_full_plus_256'),
    ('-dec', 'multi-t5'),
    ('-nl', '26'),
    ('-enc', 'perceiver-tf'),
    ('-sqr', '1'),
    ('-ff', 'moe'),
    ('-wf', '4'),
    ('-nmoe', '8'),
    ('-kmoe', '2'),
    ('-act', 'silu'),
    ('-epe', 'rope'),
    ('-rp', '1'),
    ('-ac', 'spec'),
    ('-hop', '300'),
    ('-atc', '1'),
    ('-pr', precision),
)
args = [checkpoint]
for _flag, _value in _option_pairs:
    args += [_flag, _value]

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = load_model_checkpoint(args=args, device=device)

print("\n✅ Model loaded successfully!")
print(f"   Device: {device}")
print(f"   Model: YPTF.MoE+Multi (noPS)")

## 4. Helper Functions for PoC

In [None]:
def separate_stems_demucs(audio_path, output_dir="separated"):
    """Separate an audio file into stems using the Demucs Python API.

    Uses the module-level ``demucs_model`` / ``demucs_device`` prepared by the
    Demucs loading cell.

    Args:
        audio_path: Path of the audio file to separate.
        output_dir: Root output directory; stems are written to
            ``<output_dir>/htdemucs/<track name>/<stem>.wav``.

    Returns:
        Dict mapping stem name to the path of the written WAV file, or
        ``None`` when the model is not loaded or the separation fails (the
        error and traceback are printed, not raised).
    """
    print("🎵 Separating stems with Demucs...")
    print(f"   Input: {audio_path}")
    print(f"   Output: {output_dir}/")
    print("   This may take 30-60 seconds...")
    
    if demucs_model is None:
        print("\n❌ Demucs model not loaded!")
        print("   Please run Cell 2 first")
        return None
    
    try:
        # Load audio
        wav, sr = torchaudio.load(audio_path)
        
        # Resample to the rate the model expects. Read it from the model when
        # it exposes one; 44.1kHz is the value this notebook used originally.
        target_sr = getattr(demucs_model, 'samplerate', 44100)
        if sr != target_sr:
            resampler = torchaudio.transforms.Resample(sr, target_sr)
            wav = resampler(wav)
            sr = target_sr
        
        # Match the channel count the model expects (2 unless the model says
        # otherwise). Mono input is duplicated to every channel, which the
        # original code did not handle; extra channels are dropped.
        target_channels = getattr(demucs_model, 'audio_channels', 2)
        if wav.shape[0] == 1 and target_channels > 1:
            wav = wav.repeat(target_channels, 1)
        elif wav.shape[0] > target_channels:
            wav = wav[:target_channels]
        
        # Apply Demucs model
        print("   Running separation...")
        wav = wav.to(demucs_device)
        
        with torch.no_grad():
            # wav[None] adds the batch dimension; [0] removes it again.
            sources = apply_model(demucs_model, wav[None], device=demucs_device)[0]
        
        # sources is indexed by stem first. Take the stem order from the model
        # itself when available so names can never drift from the tensor
        # layout; the fallback is the order this notebook assumed originally.
        stems_order = list(
            getattr(demucs_model, 'sources', ['drums', 'bass', 'other', 'vocals'])
        )
        
        # Create output directory
        track_name = Path(audio_path).stem
        stem_dir = Path(output_dir) / 'htdemucs' / track_name
        stem_dir.mkdir(parents=True, exist_ok=True)
        
        stems = {}
        
        # Save each stem
        for i, stem_name in enumerate(stems_order):
            stem_path = stem_dir / f"{stem_name}.wav"
            stem_audio = sources[i].cpu()
            torchaudio.save(str(stem_path), stem_audio, sr)
            stems[stem_name] = str(stem_path)
        
        print("✅ Stems separated successfully!")
        for stem_name in stems_order:
            print(f"   - {stem_name}.wav")
        
        return stems
        
    except Exception as e:
        # Best-effort notebook helper: report the failure and let the caller
        # decide what to do with the None result.
        print(f"❌ Demucs separation failed: {e}")
        import traceback
        traceback.print_exc()
        return None

def analyze_midi(midi_path):
    """Load a MIDI file and summarise it.

    Returns:
        A ``(stats, midi)`` tuple. ``stats`` holds the total note count, the
        number of instrument tracks, the end time, and one entry per
        instrument that has at least one note; ``midi`` is the loaded
        ``pretty_midi.PrettyMIDI`` object.
    """
    midi = pretty_midi.PrettyMIDI(midi_path)
    tracks = midi.instruments

    stats = {
        'total_notes': sum(len(track.notes) for track in tracks),
        'num_instruments': len(tracks),
        'duration': midi.get_end_time(),
        'instruments': [],
    }

    # Only tracks that actually contain notes are listed individually.
    stats['instruments'].extend(
        {
            'index': position,
            'program': track.program,
            'name': pretty_midi.program_to_instrument_name(track.program),
            'notes': len(track.notes),
            'is_drum': track.is_drum,
        }
        for position, track in enumerate(tracks)
        if track.notes
    )

    return stats, midi

def transcribe_audio(audio_path, output_name):
    """Transcribe an audio file with the loaded model.

    Builds the metadata dict passed to ``transcribe`` from
    ``torchaudio.info`` and analyses the resulting MIDI file.

    Returns:
        A ``(midi_path, stats)`` tuple, where ``stats`` comes from
        ``analyze_midi``.
    """
    meta = torchaudio.info(audio_path)
    seconds = meta.num_frames / meta.sample_rate

    # Fall back to 16 when the reported bit depth is missing or zero.
    if meta.bits_per_sample:
        bit_depth = int(meta.bits_per_sample)
    else:
        bit_depth = 16

    audio_info = dict(
        filepath=audio_path,
        track_name=output_name,
        sample_rate=int(meta.sample_rate),
        bits_per_sample=bit_depth,
        num_channels=int(meta.num_channels),
        num_frames=int(meta.num_frames),
        duration=int(seconds),
        encoding='unknown',
    )

    midi_path = transcribe(model, audio_info)
    stats = analyze_midi(midi_path)[0]

    return midi_path, stats

def midi_to_audio(midi_path, sample_rate=16000):
    """Render a MIDI file to audio samples for playback.

    Returns:
        ``(audio, sample_rate)`` on success, or ``(None, None)`` when loading
        or synthesis fails (a warning is printed instead of raising).
    """
    try:
        rendered = pretty_midi.PrettyMIDI(midi_path).fluidsynth(fs=sample_rate)
    except Exception as e:
        print(f"⚠️  MIDI synthesis failed: {e}")
        print("   Note: FluidSynth may not be installed")
        return None, None
    return rendered, sample_rate

def merge_midi_files(midi_paths_dict, output_path="output_midi/poc_recomposed.mid"):
    """Merge multiple MIDI files into one recomposed MIDI
    
    Every instrument of every input file becomes its own instrument in the
    merged file, named after the stem it came from. A stem whose MIDI cannot
    be loaded is reported and skipped; the remaining stems are still merged.
    
    Args:
        midi_paths_dict: Dictionary of {stem_name: midi_path}
        output_path: Where to save the merged MIDI file
    
    Returns:
        Path to merged MIDI file
    """
    print(f"🔄 Merging {len(midi_paths_dict)} MIDI files into recomposed mix...")
    
    # Create output directory
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    
    # Create new MIDI object
    merged_midi = pretty_midi.PrettyMIDI()
    
    # Copy all instruments from each stem
    for stem_name, midi_path in midi_paths_dict.items():
        try:
            stem_midi = pretty_midi.PrettyMIDI(midi_path)
        except Exception as e:
            print(f"   ⚠️  Failed to merge {stem_name}: {e}")
            continue
        
        for instrument in stem_midi.instruments:
            # Prefix the stem name so instruments stay traceable after merging
            merged_instrument = pretty_midi.Instrument(
                program=instrument.program,
                is_drum=instrument.is_drum,
                name=f"{stem_name}_{instrument.name}" if instrument.name else stem_name
            )
            
            # Copy every event list. Pitch bends were previously dropped, so
            # the merged file could sound different from the stem MIDIs.
            merged_instrument.notes = list(instrument.notes)
            merged_instrument.pitch_bends = list(instrument.pitch_bends)
            merged_instrument.control_changes = list(instrument.control_changes)
            
            merged_midi.instruments.append(merged_instrument)
        
        print(f"   ✅ {stem_name}: {len(stem_midi.instruments)} instruments added")
    
    # Save merged MIDI
    merged_midi.write(output_path)
    
    print(f"✅ Recomposed MIDI saved: {output_path}")
    print(f"   Total instruments: {len(merged_midi.instruments)}")
    print(f"   Total notes: {sum(len(inst.notes) for inst in merged_midi.instruments)}")
    
    return output_path

print("✅ Helper functions loaded")

## 5. Select Audio File for PoC Test

In [None]:
# Find audio files
audio_extensions = ['*.mp3', '*.wav', '*.flac', '*.m4a', '*.ogg']
audio_files = []
for ext in audio_extensions:
    audio_files.extend(glob.glob(os.path.join(original_dir, ext)))

# Probe each file exactly once (the listing and the dropdown share the
# result) and skip files that cannot be read, so a single bad file does not
# take down the whole cell.
audio_durations = {}
for f in audio_files:
    try:
        info = torchaudio.info(f)
        audio_durations[f] = info.num_frames / info.sample_rate
    except Exception as e:
        print(f"⚠️  Skipping unreadable file {os.path.basename(f)}: {e}")
audio_files = [f for f in audio_files if f in audio_durations]

if not audio_files:
    print("❌ No audio files found!")
    print("   Please upload audio files to the MT3 directory")
else:
    print(f"✅ Found {len(audio_files)} audio files:")
    for i, f in enumerate(audio_files):
        print(f"   {i+1}. {os.path.basename(f)} ({audio_durations[f]:.1f}s)")

# File selector: the label shows name and duration, the value is the full path
file_selector = widgets.Dropdown(
    options=[(f"{os.path.basename(f)} ({audio_durations[f]:.1f}s)", f) for f in audio_files],
    description='Test File:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='700px')
)

display(file_selector)

## 6. Run PoC Test: Full Mix vs Stems

**Click the button below to start the test** (takes 3-5 minutes)

In [None]:
# Global results storage
# Filled by run_poc_test once all steps have run; the later cells read it and
# treat an empty dict as "test not run yet".
poc_results = {}

def run_poc_test(button):
    """Button callback: run the full-mix vs stems comparison end to end.

    Steps: transcribe the selected file directly, separate it into stems,
    transcribe each stem, merge the stem MIDIs into one recomposed MIDI, then
    print a note-count comparison. All printing happens inside the
    module-level ``output`` widget (created after this definition, before the
    button can be clicked).

    Args:
        button: The clicked widget, as passed by ``on_click``; not used.

    Side effects:
        Rebinds the global ``poc_results`` once Step 4 is reached. Returns
        early, leaving ``poc_results`` untouched, when no file is selected,
        the full-mix transcription fails, or the stem separation fails.
    """
    global poc_results
    
    output.clear_output(wait=True)
    
    with output:
        audio_path = file_selector.value
        
        if not audio_path:
            print("❌ Please select an audio file first!")
            return
        
        print("="*80)
        print("🎯 PROOF OF CONCEPT TEST: Full Mix vs Stems")
        print("="*80)
        print(f"\n📁 Test File: {os.path.basename(audio_path)}")
        
        # Step 1: Transcribe full mix
        print("\n" + "="*80)
        print("STEP 1: Transcribe Full Mix (Baseline)")
        print("="*80)
        
        try:
            fullmix_midi, fullmix_stats = transcribe_audio(audio_path, "poc_fullmix")
            
            print(f"\n✅ Full Mix Transcription Complete:")
            print(f"   Total notes: {fullmix_stats['total_notes']}")
            print(f"   Instruments: {fullmix_stats['num_instruments']}")
            print(f"   Duration: {fullmix_stats['duration']:.2f}s")
        except Exception as e:
            # Without a baseline there is nothing to compare against: stop.
            print(f"\n❌ Full mix transcription failed: {e}")
            import traceback
            traceback.print_exc()
            return
        
        # Step 2: Separate stems
        print("\n" + "="*80)
        print("STEP 2: Separate Stems with Demucs")
        print("="*80)
        
        stems = separate_stems_demucs(audio_path, output_dir="separated")
        
        if not stems:
            print("\n❌ Stem separation failed! Stopping test.")
            print("   Make sure Demucs is loaded (run Cell 2)")
            return
        
        # Step 3: Transcribe each stem
        print("\n" + "="*80)
        print("STEP 3: Transcribe Each Stem")
        print("="*80)
        
        # Per stem: either {'midi_path', 'stats'} on success or {'error'} on
        # failure; a failing stem does not abort the remaining ones.
        stem_results = {}
        
        for stem_name, stem_path in stems.items():
            print(f"\n🎵 Transcribing {stem_name} stem...")
            
            try:
                stem_midi, stem_stats = transcribe_audio(stem_path, f"poc_{stem_name}")
                stem_results[stem_name] = {
                    'midi_path': stem_midi,
                    'stats': stem_stats
                }
                
                print(f"   ✅ {stem_name}: {stem_stats['total_notes']} notes, {stem_stats['num_instruments']} instruments")
                
            except Exception as e:
                print(f"   ❌ {stem_name} transcription failed: {e}")
                stem_results[stem_name] = {'error': str(e)}
        
        # Step 3.5: Create Recomposed MIDI from stems
        print("\n" + "="*80)
        print("STEP 3.5: Recompose Full Mix from Stem MIDIs")
        print("="*80)
        
        recomposed_midi = None
        recomposed_stats = None
        
        try:
            # Get all stem MIDI paths
            stem_midi_paths = {
                stem_name: result['midi_path']
                for stem_name, result in stem_results.items()
                if 'midi_path' in result
            }
            
            if len(stem_midi_paths) > 0:
                # Merge all stem MIDIs
                recomposed_midi = merge_midi_files(stem_midi_paths, "output_midi/poc_recomposed.mid")
                
                # Analyze recomposed MIDI
                # NOTE: if this raises, recomposed_midi is already set while
                # recomposed_stats stays None.
                recomposed_stats, _ = analyze_midi(recomposed_midi)
                
                print(f"\n✅ Recomposed MIDI Created:")
                print(f"   Total notes: {recomposed_stats['total_notes']}")
                print(f"   Instruments: {recomposed_stats['num_instruments']}")
            else:
                print("\n⚠️  No stem MIDIs to merge")
                
        except Exception as e:
            # Recomposition is optional: report and continue with the comparison.
            print(f"\n⚠️  Recomposition failed: {e}")
            import traceback
            traceback.print_exc()
        
        # Step 4: Compare results
        print("\n" + "="*80)
        print("STEP 4: Comparison Analysis")
        print("="*80)
        
        # Store results
        # 'recomposed' is None when no merged MIDI was produced.
        poc_results = {
            'audio_path': audio_path,
            'fullmix': {
                'midi_path': fullmix_midi,
                'stats': fullmix_stats
            },
            'stems': stems,
            'stem_results': stem_results,
            'recomposed': {
                'midi_path': recomposed_midi,
                'stats': recomposed_stats
            } if recomposed_midi else None
        }
        
        # Overall comparison
        # Stems whose transcription failed have no 'stats' and are left out.
        total_stem_notes = sum(r['stats']['total_notes'] for r in stem_results.values() if 'stats' in r)
        
        print(f"\n📊 Overall Note Count:")
        print(f"   Full Mix (Direct): {fullmix_stats['total_notes']} notes")
        print(f"   Stems Combined (Sum): {total_stem_notes} notes")
        if recomposed_stats:
            print(f"   Recomposed Mix (Merged): {recomposed_stats['total_notes']} notes")
        
        # Calculate improvements
        # Relative change in note count versus the full mix; reported as 0
        # when the full mix produced no notes, to avoid dividing by zero.
        improvement_sum = ((total_stem_notes - fullmix_stats['total_notes']) / fullmix_stats['total_notes'] * 100) if fullmix_stats['total_notes'] > 0 else 0
        print(f"\n   Improvement (Sum): {total_stem_notes - fullmix_stats['total_notes']:+d} notes ({improvement_sum:+.1f}%)")
        
        if recomposed_stats:
            improvement_recomposed = ((recomposed_stats['total_notes'] - fullmix_stats['total_notes']) / fullmix_stats['total_notes'] * 100) if fullmix_stats['total_notes'] > 0 else 0
            print(f"   Improvement (Recomposed): {recomposed_stats['total_notes'] - fullmix_stats['total_notes']:+d} notes ({improvement_recomposed:+.1f}%)")
        
        # Per-stem comparison
        print(f"\n🎸 Per-Stem Analysis:")
        for stem_name, result in stem_results.items():
            if 'stats' in result:
                print(f"\n   {stem_name.upper()}:")
                print(f"      Notes detected: {result['stats']['total_notes']}")
                print(f"      Instruments: {result['stats']['num_instruments']}")
                if result['stats']['instruments']:
                    # The instrument with the most notes in this stem.
                    top_inst = sorted(result['stats']['instruments'], key=lambda x: x['notes'], reverse=True)[0]
                    print(f"      Top instrument: {top_inst['name']} ({top_inst['notes']} notes)")
        
        print("\n" + "="*80)
        print("✅ PoC TEST COMPLETE!")
        print("="*80)
        print("\nRun the cells below to:")
        print("   - View detailed comparison tables")
        print("   - Listen to full mix vs stems vs recomposed")
        print("   - Get recommendation based on results")

# Create button and output
poc_button = widgets.Button(
    description='🚀 Run PoC Test',
    button_style='success',
    layout=widgets.Layout(width='200px', height='50px')
)
poc_button.on_click(run_poc_test)

# run_poc_test prints into this widget; it is shown underneath the button.
output = widgets.Output()

print("⚠️  Note: This test will take 3-5 minutes depending on audio length")
print("   - Full mix transcription: ~30s")
print("   - Stem separation: ~30-60s")
print("   - 4 stem transcriptions: ~2min")
print("   - MIDI recomposition: ~1s")
# The button is displayed after these messages, so it appears below them
# (the message used to say "above").
print("\n✅ Ready! Click the button below to start the test")

display(poc_button)
display(output)

## 7. View Detailed Results

*Run this after the PoC test completes*

In [None]:
if poc_results:
    fullmix_notes = poc_results['fullmix']['stats']['total_notes']
    # Stems whose transcription failed carry no 'stats' and are left out.
    total_stem_notes = sum(r['stats']['total_notes'] for r in poc_results['stem_results'].values() if 'stats' in r)
    
    # Get recomposed stats if available. The PoC test stores 'recomposed' as
    # None when no merge happened, and its 'stats' can be None when the merged
    # file could not be analysed; both mean "no recomposed count".
    _recomposed_stats = (poc_results.get('recomposed') or {}).get('stats')
    recomposed_notes = _recomposed_stats['total_notes'] if _recomposed_stats else None
    has_recomposed = recomposed_notes is not None
    
    # Percent change versus the full mix. Compare against None explicitly: a
    # recomposed count of 0 is a real result, not a missing one.
    improvement_sum = ((total_stem_notes - fullmix_notes) / fullmix_notes * 100) if fullmix_notes > 0 else 0
    improvement_recomposed = ((recomposed_notes - fullmix_notes) / fullmix_notes * 100) if has_recomposed and fullmix_notes > 0 else None
    
    print("="*80)
    print("📊 PoC RESULTS SUMMARY")
    print("="*80)
    
    print(f"\n📈 Note Count Comparison:")
    print(f"   Full Mix (Direct): {fullmix_notes} notes")
    print(f"   Stems Combined (Sum): {total_stem_notes} notes")
    if has_recomposed:
        print(f"   Recomposed Mix (Merged): {recomposed_notes} notes")
    
    print(f"\n🎯 Improvement Analysis:")
    print(f"   Sum of Stems: {total_stem_notes - fullmix_notes:+d} notes ({improvement_sum:+.1f}%)")
    if improvement_recomposed is not None:
        print(f"   Recomposed Mix: {recomposed_notes - fullmix_notes:+d} notes ({improvement_recomposed:+.1f}%)")
    
    # Use recomposed improvement if available, otherwise use sum
    primary_improvement = improvement_recomposed if improvement_recomposed is not None else improvement_sum
    
    print(f"\n💡 Key Insight:")
    if has_recomposed:
        if recomposed_notes == total_stem_notes:
            print("   ✅ Recomposed MIDI matches sum of stems (no overlap/duplicate notes)")
        else:
            print(f"   ⚠️  Recomposed MIDI differs from sum ({total_stem_notes - recomposed_notes:+d} notes)")
            print("       This could indicate overlapping notes or timing differences")
    
    print("\n" + "="*80)
    
    # Thresholds: above 10% proceed, above 5% investigate, otherwise reconsider.
    if primary_improvement > 10:
        print("✅ RECOMMENDATION: PROCEED WITH IDEA 1 (Fine-tuning)")
        print("\n   Stems approach validates the hypothesis!")
        print("   Fine-tuning will likely add another 10-15% improvement.")
        print("   Expected total improvement: 20-30%\n")
        print("   Next steps:")
        print("   1. Read YOURMT3_FINETUNING_GUIDE.md")
        print("   2. Download Slakh2100 dataset (~1TB)")
        print("   3. Start fine-tuning bass model (3-5 days)")
        print("   4. Continue with other stems if bass succeeds")
    elif primary_improvement > 5:
        print("⚠️  RECOMMENDATION: INVESTIGATE FURTHER")
        print("\n   Moderate improvement detected.")
        print("   Check stem quality and per-stem results.\n")
        print("   Action items:")
        print("   1. Listen to separated stems (next cell)")
        print("   2. Check which stems work best")
        print("   3. Try different Demucs model: mdx_extra")
        print("   4. Re-run PoC with better separation")
    else:
        print("❌ RECOMMENDATION: RECONSIDER APPROACH")
        print("\n   Stems not providing expected benefit.")
        print("   Fine-tuning may not help.\n")
        print("   Alternatives:")
        print("   1. Investigate Demucs quality (listen to stems)")
        print("   2. Try Idea 2 (instrument matching)")
        print("   3. Use different stem separation (Spleeter, Open-Unmix)")
        print("   4. Hybrid approach: stems for specific instruments only")
    
    print("\n" + "="*80)
else:
    print("⚠️  Run the PoC test first (Cell 6)!")

## 7.5 Piano Roll Visualization

*Visual comparison of MIDI transcriptions*

In [None]:
if poc_results:
    print("🎹 Piano Roll Visualizations with MIDI Playback")
    print("="*80)
    
    def plot_piano_roll(midi_path, title, ax=None):
        """Plot piano roll from MIDI file

        Each note is drawn as a rectangle (x = time in seconds, y = MIDI
        pitch), coloured by the index of the instrument it belongs to.

        Args:
            midi_path: Path of the MIDI file to draw.
            title: Title shown above the axes.
            ax: Matplotlib axes to draw on; a new figure is created if omitted.

        Returns:
            The axes that were drawn on.
        """
        midi = pretty_midi.PrettyMIDI(midi_path)
        
        if ax is None:
            _, ax = plt.subplots(figsize=(16, 6))
        
        palette = plt.cm.tab20.colors
        total_notes = 0
        
        for inst_idx, inst in enumerate(midi.instruments):
            # The palette is cycled when there are more instruments than colours
            color = palette[inst_idx % len(palette)]
            for note in inst.notes:
                ax.add_patch(
                    plt.Rectangle(
                        (note.start, note.pitch),
                        note.end - note.start,
                        1,
                        facecolor=color,
                        edgecolor='black',
                        linewidth=0.5,
                        alpha=0.7
                    )
                )
            total_notes += len(inst.notes)
        
        # A transcription without notes has end time 0; fall back to a 1s
        # axis instead of handing matplotlib identical limits.
        end_time = midi.get_end_time()
        ax.set_xlim(0, end_time if end_time > 0 else 1.0)
        ax.set_ylim(20, 108)  # Piano range
        ax.set_xlabel('Time (seconds)', fontsize=12)
        ax.set_ylabel('MIDI Pitch', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3)
        
        # Add note count
        ax.text(0.98, 0.02, f'{total_notes} notes', 
                transform=ax.transAxes,
                ha='right', va='bottom',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5),
                fontsize=10)
        
        return ax
    
    def _show_transcription(midi_path, title):
        """Show the piano roll of one MIDI file followed by an audio player."""
        _, ax = plt.subplots(figsize=(18, 4))
        plot_piano_roll(midi_path, title, ax)
        plt.tight_layout()
        plt.show()
        
        print("\n🔊 Audio Playback:")
        audio, sr = midi_to_audio(midi_path)
        if audio is not None:
            display(Audio(audio, rate=sr))
    
    print("\n📊 Generating visualizations...")
    
    # 1. Full Mix
    print("\n" + "="*80)
    print("🎵 FULL MIX (Baseline - Direct Transcription)")
    print("="*80)
    
    _show_transcription(poc_results['fullmix']['midi_path'],
                        '🎵 Full Mix Transcription (Direct)')
    
    # 2. Each Stem
    stem_names = ['drums', 'bass', 'other', 'vocals']
    stem_emojis = {'drums': '🥁', 'bass': '🎸', 'other': '🎹', 'vocals': '🎤'}
    
    for stem_name in stem_names:
        # Stems that failed to transcribe have no 'midi_path' and are skipped
        stem_result = poc_results['stem_results'].get(stem_name)
        if stem_result is None or 'midi_path' not in stem_result:
            continue
        
        print("\n" + "="*80)
        print(f"{stem_emojis[stem_name]} {stem_name.upper()} STEM")
        print("="*80)
        
        _show_transcription(stem_result['midi_path'],
                            f'{stem_emojis[stem_name]} {stem_name.title()} Stem Transcription')
    
    # 3. Recomposed Mix
    recomposed = poc_results.get('recomposed')
    if recomposed and recomposed['midi_path']:
        print("\n" + "="*80)
        print("🎼 RECOMPOSED MIX (Stems Merged Back Together)")
        print("="*80)
        
        _show_transcription(recomposed['midi_path'],
                            '🎼 Recomposed Mix (All Stems Merged)')
        
        # 'stats' can be None when the merged file was written but could not
        # be analysed; the comparison is skipped in that case.
        recomposed_stats = recomposed.get('stats')
        if recomposed_stats:
            print("\n💡 Comparison:")
            fullmix_notes = poc_results['fullmix']['stats']['total_notes']
            recomposed_notes = recomposed_stats['total_notes']
            diff = recomposed_notes - fullmix_notes
            # Guard the percentage like the other cells do: a full mix with
            # zero notes previously raised ZeroDivisionError here.
            diff_pct = (diff / fullmix_notes * 100) if fullmix_notes > 0 else 0
            print(f"   Full Mix (Direct): {fullmix_notes} notes")
            print(f"   Recomposed (Stems Merged): {recomposed_notes} notes")
            print(f"   Difference: {diff:+d} notes ({diff_pct:+.1f}%)")
            
            if diff > 0:
                print("\n   ✅ Recomposed mix captured MORE notes than direct transcription")
                print("      This suggests stem separation helps the model detect more musical content")
            elif diff < 0:
                print("\n   ⚠️  Recomposed mix has FEWER notes than direct transcription")
                print("      This might indicate stem separation quality issues")
            else:
                print("\n   ➡️  Same number of notes, but possibly different timing or instruments")
    
    print("\n" + "="*80)
    print("✅ All visualizations generated!")
    print("="*80)
    print("\n💡 Visual Insights:")
    print("   - Vertical position = pitch (higher = higher note)")
    print("   - Horizontal position = time")
    print("   - Rectangle width = note duration")
    print("   - Colors = different instruments")
    print("   - Compare density to see which stem has more detected notes")
    print("   - Listen to MIDI playback to verify transcription quality")
    print("\n📊 Comparison Strategy:")
    print("   1. Full Mix (Direct) = baseline transcription of original audio")
    print("   2. Individual Stems = isolated instrument transcriptions")
    print("   3. Recomposed Mix = all stems merged back together")
    print("   → Recomposed should capture more musical information than Full Mix")
    
else:
    print("⚠️  Run the PoC test first (Cell 6)!")

## 8. Listen to Stems Quality

*Verify stem separation quality*

In [None]:
if not poc_results:
    print("⚠️  Run the PoC test first!")
else:
    print("🎧 Audio Playback: Original vs Stems")
    print("="*80)

    # One (heading, file) pair per player: the original mix first, then every
    # separated stem in the order the separation step stored them.
    _playlist = [("\n🎵 Original Full Mix:", poc_results['audio_path'])]
    for stem_name, stem_path in poc_results['stems'].items():
        _playlist.append((f"\n🎸 {stem_name.title()} Stem:", stem_path))

    for _heading, _clip in _playlist:
        print(_heading)
        waveform, sr = torchaudio.load(_clip)
        # Multi-channel audio is averaged down to a single channel for playback.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        display(Audio(waveform.numpy(), rate=sr))

## 9. Save Results

*Export PoC results to JSON*

In [None]:
if poc_results:
    # Save results as JSON
    results_summary = {
        'test_file': os.path.basename(poc_results['audio_path']),
        'fullmix': {
            'notes': poc_results['fullmix']['stats']['total_notes'],
            'instruments': poc_results['fullmix']['stats']['num_instruments'],
            'duration': poc_results['fullmix']['stats']['duration']
        },
        'stems': {}
    }
    
    # Only stems that were transcribed successfully carry 'stats'.
    for stem_name, result in poc_results['stem_results'].items():
        if 'stats' in result:
            results_summary['stems'][stem_name] = {
                'notes': result['stats']['total_notes'],
                'instruments': result['stats']['num_instruments']
            }
    
    # Add recomposed results if available. 'recomposed' is None when nothing
    # was merged, and its 'stats' can be None when the merged file could not
    # be analysed; indexing it blindly raised TypeError in that case.
    _recomposed_stats = (poc_results.get('recomposed') or {}).get('stats')
    if _recomposed_stats:
        results_summary['recomposed'] = {
            'notes': _recomposed_stats['total_notes'],
            'instruments': _recomposed_stats['num_instruments']
        }
    
    # Calculate overall improvement (reported as 0% when the full mix has no
    # notes, to avoid dividing by zero)
    fullmix_notes = results_summary['fullmix']['notes']
    total_stem_notes = sum(s['notes'] for s in results_summary['stems'].values())
    improvement_sum = ((total_stem_notes - fullmix_notes) / fullmix_notes * 100) if fullmix_notes > 0 else 0
    
    results_summary['improvement'] = {
        'absolute_sum': total_stem_notes - fullmix_notes,
        'percentage_sum': improvement_sum
    }
    
    # Add recomposed improvement if available
    if 'recomposed' in results_summary:
        recomposed_notes = results_summary['recomposed']['notes']
        improvement_recomposed = ((recomposed_notes - fullmix_notes) / fullmix_notes * 100) if fullmix_notes > 0 else 0
        results_summary['improvement']['absolute_recomposed'] = recomposed_notes - fullmix_notes
        results_summary['improvement']['percentage_recomposed'] = improvement_recomposed
    
    # Save to file, next to the notebook rather than inside yourmt3_space.
    # The encoding is explicit so the file does not depend on the platform
    # default.
    output_path = os.path.join(original_dir, 'poc_results.json')
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results_summary, f, indent=2)
    
    print(f"✅ PoC results saved to: {output_path}")
    print(f"\n📄 Summary:")
    print(json.dumps(results_summary, indent=2))
else:
    print("⚠️  Run the PoC test first!")

---

## 🎉 PoC Complete!

**If improvement >10%**: Proceed with fine-tuning (see `YOURMT3_FINETUNING_GUIDE.md`)

**If improvement 5-10%**: Investigate stem quality, try different Demucs model

**If improvement <5%**: Consider Idea 2 (instrument matching) or improve stem separation

---

*Assisted by Claude Code*