# 🎵 Stems Separation PoC: Accuracy Comparison

**Proof of Concept** to validate Idea 1: Stems-based transcription approach.

**Hypothesis**: Separating audio into stems (bass, drums, other, vocals) before transcription improves accuracy.

**Test**:
1. Transcribe **full mix** with YourMT3
2. Separate audio into **4 stems** with Demucs
3. Transcribe **each stem** with YourMT3
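4. Compare the transcriptions (total note counts are used as a rough proxy for accuracy; no ground-truth MIDI is involved) and analyze improvements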

**Expected**: 10-15% accuracy improvement on stems (without fine-tuning yet)

## ⚠️ IMPORTANT: Run cells in order (1→2→3→…→9)

## 1. Setup and Imports

In [None]:
import sys
import os
from pathlib import Path
import glob
import subprocess
import shutil

# Change to yourmt3_space directory.
# Keep the first recorded original_dir when this cell is re-run: after the first
# run the working directory is already yourmt3_space, so calling os.getcwd()
# again would point original_dir at the wrong folder and the relative
# 'yourmt3_space' lookup would wrongly report the directory as missing.
if 'original_dir' not in globals():
    original_dir = os.getcwd()
space_dir = os.path.join(original_dir, 'yourmt3_space')
if not os.path.isdir(space_dir):
    print("❌ Error: yourmt3_space directory not found!")
    print("   Run setup_yourmt3_brev.sh first")
else:
    os.chdir(space_dir)
    # Put 'amt/src' ahead of '.' at the front of sys.path (same precedence as
    # two plain insert(0, ...) calls) without piling up duplicates on re-runs.
    for entry in ('.', 'amt/src'):
        if entry in sys.path:
            sys.path.remove(entry)
        sys.path.insert(0, entry)

import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
import pretty_midi
from IPython.display import Audio, display, HTML, Markdown
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed
import json
from collections import defaultdict

from model_helper import load_model_checkpoint, transcribe

# Report where the notebook is running and whether a GPU is visible to torch.
cuda_available = torch.cuda.is_available()
print("✅ Imports successful!")
print(f"   Working directory: {os.getcwd()}")
print(f"   CUDA available: {cuda_available}")
if cuda_available:
    print(f"   GPU: {torch.cuda.get_device_name(0)}")

## 2. Check Demucs Installation

**Run this cell once** - it verifies that Demucs is working and installs it if it is missing

In [None]:
# Check if Demucs is installed and working
def ensure_demucs_installed():
    """Check that Demucs runs under this kernel's interpreter; install it if not.

    Every subprocess is launched through ``sys.executable`` so that the check
    and the ``pip install`` both target the interpreter running this notebook
    rather than whichever ``python3`` / ``pip`` happens to be first on PATH.

    Returns:
        bool: True if ``-m demucs --help`` exits with status 0, either right
        away or after a successful install; False if the install fails, times
        out, or the module still cannot be run afterwards.
    """
    demucs_help = [sys.executable, '-m', 'demucs', '--help']
    pip = [sys.executable, '-m', 'pip']
    # Generous limit: a slow first start must not be mistaken for "not installed".
    help_timeout = 60

    try:
        result = subprocess.run(demucs_help, capture_output=True, text=True, timeout=help_timeout)
        if result.returncode == 0:
            print("✅ Demucs is already installed")
            # Get version
            version_result = subprocess.run(pip + ['show', 'demucs'], capture_output=True, text=True)
            for line in version_result.stdout.splitlines():
                if line.startswith('Version:'):
                    print(f"   {line}")
            return True
    except (OSError, subprocess.TimeoutExpired):
        # Treat "could not run the check" the same as "not installed".
        pass

    # Not installed, install it
    print("⚠️  Demucs not found! Installing...")
    print("   This will take 2-5 minutes...")
    print("   (Installing torch, torchaudio, demucs)")

    try:
        install = subprocess.run(
            pip + ['install', '-q', 'demucs'],
            capture_output=True,
            text=True,
            timeout=300  # 5 minutes timeout
        )
    except subprocess.TimeoutExpired:
        print("❌ Installation timed out (>5 minutes)")
        print("   Try manually: !pip install demucs")
        return False
    except OSError as e:
        print(f"❌ Installation error: {e}")
        print("   Try manually: !pip install demucs")
        return False

    if install.returncode != 0:
        print(f"❌ Installation failed: {install.stderr}")
        return False

    print("✅ Demucs installed successfully!")
    # Verify separately so a slow or failed verification is not reported as an
    # installation timeout.
    try:
        verify = subprocess.run(demucs_help, capture_output=True, text=True, timeout=help_timeout)
    except (OSError, subprocess.TimeoutExpired) as e:
        print(f"   ⚠️  Could not verify the installation: {e}")
        print("   Try restarting the kernel: Kernel → Restart")
        return False

    if verify.returncode == 0:
        print("   ✅ Verified: Demucs is working")
        return True

    print("   ⚠️  Installation succeeded but demucs module not found")
    print("   Try restarting the kernel: Kernel → Restart")
    return False

# Run installation check
demucs_ready = ensure_demucs_installed()

# Point the user at the manual route when the automatic check/install failed.
if not demucs_ready:
    print(
        "\n⚠️  Manual installation required:",
        "   Run this in a new cell: !pip install demucs",
        "   Then restart kernel and re-run this notebook",
        sep="\n",
    )

## 3. Load YourMT3 Model

In [None]:
print("Loading YourMT3 model...")
print("This may take 10-15 seconds...")

# Model configuration
checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt"
project = '2024'
precision = '16'

# Options handed to load_model_checkpoint, kept as (flag, value) pairs so each
# option reads as one unit; they are flattened into the argv-style list below.
option_pairs = (
    ('-p', project),
    ('-tk', 'mc13_full_plus_256'),
    ('-dec', 'multi-t5'),
    ('-nl', '26'),
    ('-enc', 'perceiver-tf'),
    ('-sqr', '1'),
    ('-ff', 'moe'),
    ('-wf', '4'),
    ('-nmoe', '8'),
    ('-kmoe', '2'),
    ('-act', 'silu'),
    ('-epe', 'rope'),
    ('-rp', '1'),
    ('-ac', 'spec'),
    ('-hop', '300'),
    ('-atc', '1'),
    ('-pr', precision),
)
args = [checkpoint, *(token for pair in option_pairs for token in pair)]

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = load_model_checkpoint(args=args, device=device)

print("\n✅ Model loaded successfully!")
print(f"   Device: {device}")
print("   Model: YPTF.MoE+Multi (noPS)")

## 4. Helper Functions for PoC

In [None]:
def separate_stems_demucs(audio_path, output_dir="separated"):
    """Separate audio into 4 stems using Demucs.

    Runs ``-m demucs -n htdemucs --out <output_dir> <audio_path>`` under the
    interpreter executing this notebook, then looks for the four stem files
    under ``<output_dir>/htdemucs/<audio file name without extension>/``.

    Args:
        audio_path: Path of the audio file to separate.
        output_dir: Directory passed to Demucs via ``--out``.

    Returns:
        dict | None: Mapping of ``'bass'``, ``'drums'``, ``'other'`` and
        ``'vocals'`` to the path (str) of the matching ``.wav`` file, or None
        when Demucs cannot be run, the separation command exits with an error,
        or any expected stem file is missing afterwards.
    """
    print(f"🎵 Separating stems with Demucs...")
    print(f"   Input: {audio_path}")
    print(f"   Output: {output_dir}/")
    print("   This may take 30-60 seconds...")

    # Use the kernel's own interpreter so the module is looked up where it was
    # installed, not under whichever python3 is first on PATH.
    demucs = [sys.executable, '-m', 'demucs']

    # Double-check Demucs is available. A timeout is handled like any other
    # failure here instead of escaping as an uncaught exception.
    try:
        subprocess.run(demucs + ['--help'], capture_output=True, check=True, timeout=60)
    except (OSError, subprocess.CalledProcessError, subprocess.TimeoutExpired):
        print("\n❌ Demucs not found!")
        print("   Please run Cell 2 to install Demucs first")
        print("   Or install manually: !pip install demucs")
        return None

    cmd = demucs + [
        '-n', 'htdemucs',      # Use htdemucs model (4-stem: bass, drums, other, vocals)
        '--out', output_dir,
        audio_path
    ]

    try:
        subprocess.run(cmd, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ Demucs separation failed: {e}")
        print(f"   stderr: {e.stderr}")
        return None

    # Find output stems
    track_name = Path(audio_path).stem
    stem_dir = Path(output_dir) / 'htdemucs' / track_name
    stems = {
        stem_name: str(stem_dir / f'{stem_name}.wav')
        for stem_name in ('bass', 'drums', 'other', 'vocals')
    }

    # Verify all stems exist. Report and return None, like every other failure
    # path of this function, rather than raising past the caller's `if not stems`.
    missing = [stem_path for stem_path in stems.values() if not Path(stem_path).exists()]
    if missing:
        print("❌ Demucs finished but expected stem files are missing:")
        for stem_path in missing:
            print(f"   - {stem_path}")
        return None

    print("✅ Stems separated successfully!")
    for stem_name in stems:
        print(f"   - {stem_name}.wav")

    return stems

def analyze_midi(midi_path):
    """Load a MIDI file and summarise its note content.

    Returns:
        tuple: ``(stats, midi)`` where ``midi`` is the loaded
        ``pretty_midi.PrettyMIDI`` object and ``stats`` is a dict with:
        ``total_notes`` (notes summed over all instruments),
        ``num_instruments`` (all instruments, including ones without notes),
        ``duration`` (``midi.get_end_time()``) and ``instruments`` (one entry
        per instrument that has at least one note).
    """
    midi = pretty_midi.PrettyMIDI(midi_path)
    tracks = midi.instruments

    # Only instruments that actually contain notes get a detail entry; 'index'
    # still refers to the position in the full instrument list.
    active_tracks = [
        {
            'index': position,
            'program': track.program,
            'name': pretty_midi.program_to_instrument_name(track.program),
            'notes': len(track.notes),
            'is_drum': track.is_drum,
        }
        for position, track in enumerate(tracks)
        if track.notes
    ]

    stats = {
        'total_notes': sum(len(track.notes) for track in tracks),
        'num_instruments': len(tracks),
        'duration': midi.get_end_time(),
        'instruments': active_tracks,
    }
    return stats, midi

def transcribe_audio(audio_path, output_name):
    """Transcribe one audio file with the loaded model and summarise the MIDI.

    Args:
        audio_path: Path of the audio file to transcribe.
        output_name: Value stored as ``track_name`` in the metadata dict that
            is handed to ``transcribe``.

    Returns:
        tuple: ``(midi_path, stats)`` where ``midi_path`` is whatever
        ``transcribe`` returns and ``stats`` is the statistics dict produced by
        ``analyze_midi`` for that path.
    """
    meta = torchaudio.info(audio_path)
    seconds = meta.num_frames / meta.sample_rate

    # Fall back to 16 when the reported bit depth is missing or zero.
    if meta.bits_per_sample:
        bit_depth = int(meta.bits_per_sample)
    else:
        bit_depth = 16

    audio_info = {
        "filepath": audio_path,
        "track_name": output_name,
        "sample_rate": int(meta.sample_rate),
        "bits_per_sample": bit_depth,
        "num_channels": int(meta.num_channels),
        "num_frames": int(meta.num_frames),
        "duration": int(seconds),  # whole seconds (fraction truncated)
        "encoding": 'unknown',
    }

    # Transcribe, then analyze; the parsed MIDI object itself is not needed here.
    midi_path = transcribe(model, audio_info)
    stats, _ = analyze_midi(midi_path)

    return midi_path, stats

print("✅ Helper functions loaded")

## 5. Select Audio File for PoC Test

In [None]:
# Find audio files in original_dir (the working directory recorded before the
# notebook changed into yourmt3_space).
audio_extensions = ['*.mp3', '*.wav', '*.flac', '*.m4a', '*.ogg']
audio_files = []
for ext in audio_extensions:
    audio_files.extend(glob.glob(os.path.join(original_dir, ext)))

# Read each file's metadata once and build its "name (duration)" label; the
# printed listing and the dropdown below both reuse these labels instead of
# calling torchaudio.info again for every use.
file_labels = []
for f in audio_files:
    info = torchaudio.info(f)
    duration = info.num_frames / info.sample_rate
    file_labels.append(f"{os.path.basename(f)} ({duration:.1f}s)")

if not audio_files:
    print("❌ No audio files found!")
    print("   Please upload audio files to the MT3 directory")
else:
    print(f"✅ Found {len(audio_files)} audio files:")
    for i, label in enumerate(file_labels, start=1):
        print(f"   {i}. {label}")

# File selector: each option is (label shown to the user, full path as value).
file_selector = widgets.Dropdown(
    options=list(zip(file_labels, audio_files)),
    description='Test File:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='700px')
)

display(file_selector)

## 6. Run PoC Test: Full Mix vs Stems

**Click the button below to start the test** (takes 3-5 minutes)

In [None]:
# Global results storage
poc_results = {}

def run_poc_test(button):
    """Run the full-mix vs. stems comparison for the selected audio file.

    Button click handler: ``button`` is not used. All progress is printed into
    the ``output`` widget. The steps are: transcribe the full mix, separate the
    stems with Demucs, transcribe every stem, then compare total note counts.
    On success the global ``poc_results`` is replaced with the collected paths
    and statistics. The function returns early (leaving ``poc_results``
    untouched) when no file is selected, the full-mix transcription fails, or
    stem separation fails; a failing stem transcription is recorded under that
    stem as ``{'error': ...}`` and the run continues.
    """
    global poc_results

    rule = "=" * 80

    def banner(title, lead="\n"):
        # Three-line section header; `lead` is the text printed before the top rule.
        print(lead + rule)
        print(title)
        print(rule)

    output.clear_output(wait=True)

    with output:
        audio_path = file_selector.value

        if not audio_path:
            print("❌ Please select an audio file first!")
            return

        banner("🎯 PROOF OF CONCEPT TEST: Full Mix vs Stems", lead="")
        print(f"\n📁 Test File: {os.path.basename(audio_path)}")

        # Step 1: Transcribe full mix
        banner("STEP 1: Transcribe Full Mix (Baseline)")

        try:
            fullmix_midi, fullmix_stats = transcribe_audio(audio_path, "poc_fullmix")

            print("\n✅ Full Mix Transcription Complete:")
            print(f"   Total notes: {fullmix_stats['total_notes']}")
            print(f"   Instruments: {fullmix_stats['num_instruments']}")
            print(f"   Duration: {fullmix_stats['duration']:.2f}s")
        except Exception as e:
            # Without a baseline there is nothing to compare against: stop here.
            print(f"\n❌ Full mix transcription failed: {e}")
            import traceback
            traceback.print_exc()
            return

        # Step 2: Separate stems
        banner("STEP 2: Separate Stems with Demucs")

        stems = separate_stems_demucs(audio_path, output_dir="separated")

        if not stems:
            print("\n❌ Stem separation failed! Stopping test.")
            print("   Make sure Demucs is installed (run Cell 2)")
            return

        # Step 3: Transcribe each stem
        banner("STEP 3: Transcribe Each Stem")

        stem_results = {}

        for stem_name, stem_path in stems.items():
            print(f"\n🎵 Transcribing {stem_name} stem...")

            try:
                stem_midi, stem_stats = transcribe_audio(stem_path, f"poc_{stem_name}")
                stem_results[stem_name] = {
                    'midi_path': stem_midi,
                    'stats': stem_stats
                }

                print(f"   ✅ {stem_name}: {stem_stats['total_notes']} notes, {stem_stats['num_instruments']} instruments")

            except Exception as e:
                # One failed stem does not abort the run; record it and move on.
                print(f"   ❌ {stem_name} transcription failed: {e}")
                stem_results[stem_name] = {'error': str(e)}

        # Step 4: Compare results
        banner("STEP 4: Comparison Analysis")

        # Store results
        poc_results = {
            'audio_path': audio_path,
            'fullmix': {
                'midi_path': fullmix_midi,
                'stats': fullmix_stats
            },
            'stems': stems,
            'stem_results': stem_results
        }

        # Overall comparison (stems whose transcription failed contribute nothing)
        baseline_notes = fullmix_stats['total_notes']
        total_stem_notes = sum(
            entry['stats']['total_notes']
            for entry in stem_results.values()
            if 'stats' in entry
        )
        note_delta = total_stem_notes - baseline_notes
        if baseline_notes > 0:
            improvement = note_delta / baseline_notes * 100
        else:
            improvement = 0

        print("\n📊 Overall Note Count:")
        print(f"   Full Mix: {baseline_notes} notes")
        print(f"   Stems Combined: {total_stem_notes} notes")
        print(f"   Difference: {note_delta:+d} notes ({improvement:+.1f}%)")

        # Per-stem comparison
        print("\n🎸 Per-Stem Analysis:")
        for stem_name, entry in stem_results.items():
            if 'stats' not in entry:
                continue
            stem_stats = entry['stats']
            print(f"\n   {stem_name.upper()}:")
            print(f"      Notes detected: {stem_stats['total_notes']}")
            print(f"      Instruments: {stem_stats['num_instruments']}")
            if stem_stats['instruments']:
                # Instrument with the most notes (first one wins on ties).
                top_inst = max(stem_stats['instruments'], key=lambda item: item['notes'])
                print(f"      Top instrument: {top_inst['name']} ({top_inst['notes']} notes)")

        banner("✅ PoC TEST COMPLETE!")
        print(
            "\nRun the cells below to:",
            "   - View detailed comparison tables",
            "   - Listen to full mix vs stems",
            "   - Get recommendation based on results",
            sep="\n",
        )

# Create button and output
poc_button = widgets.Button(
    description='🚀 Run PoC Test',
    button_style='success',
    layout=widgets.Layout(width='200px', height='50px')
)
poc_button.on_click(run_poc_test)

# Widget that run_poc_test clears and prints into on every click.
output = widgets.Output()

print("⚠️  Note: This test will take 3-5 minutes depending on audio length")
print("   - Full mix transcription: ~30s")
print("   - Stem separation: ~30-60s")
print("   - 4 stem transcriptions: ~2min")
# The button is displayed after these prints, so it appears below this text.
print("\n✅ Ready! Click the button below to start the test")

display(poc_button)
display(output)

## 7. View Detailed Results

*Run this after the PoC test completes*

In [None]:
# Summarise the last PoC run and print a recommendation based on how many more
# (or fewer) notes the stems produced than the full mix.
if not poc_results:
    print("⚠️  Run the PoC test first (Cell 6)!")
else:
    fullmix_notes = poc_results['fullmix']['stats']['total_notes']
    total_stem_notes = sum(
        entry['stats']['total_notes']
        for entry in poc_results['stem_results'].values()
        if 'stats' in entry
    )
    # Relative change in note count; defined as 0 when the full mix has no notes.
    if fullmix_notes > 0:
        improvement = (total_stem_notes - fullmix_notes) / fullmix_notes * 100
    else:
        improvement = 0

    rule = "=" * 80
    print(rule)
    print("📊 PoC RESULTS SUMMARY")
    print(rule)
    print(f"\n🎯 Improvement: {improvement:+.1f}%")
    print(f"   Full Mix: {fullmix_notes} notes")
    print(f"   Stems Combined: {total_stem_notes} notes")
    print(f"   Difference: {total_stem_notes - fullmix_notes:+d} notes\n")

    # Pick the advice text by threshold: >10% proceed, >5% investigate, else reconsider.
    if improvement > 10:
        advice = (
            "✅ RECOMMENDATION: PROCEED WITH IDEA 1 (Fine-tuning)",
            "\n   Stems approach validates the hypothesis!",
            "   Fine-tuning will likely add another 10-15% improvement.",
            "   Expected total improvement: 20-30%\n",
            "   Next steps:",
            "   1. Read YOURMT3_FINETUNING_GUIDE.md",
            "   2. Download Slakh2100 dataset (~1TB)",
            "   3. Start fine-tuning bass model (3-5 days)",
            "   4. Continue with other stems if bass succeeds",
        )
    elif improvement > 5:
        advice = (
            "⚠️  RECOMMENDATION: INVESTIGATE FURTHER",
            "\n   Moderate improvement detected.",
            "   Check stem quality and per-stem results.\n",
            "   Action items:",
            "   1. Listen to separated stems (next cell)",
            "   2. Check which stems work best",
            "   3. Try different Demucs model: demucs -n mdx_extra",
            "   4. Re-run PoC with better separation",
        )
    else:
        advice = (
            "❌ RECOMMENDATION: RECONSIDER APPROACH",
            "\n   Stems not providing expected benefit.",
            "   Fine-tuning may not help.\n",
            "   Alternatives:",
            "   1. Investigate Demucs quality (listen to stems)",
            "   2. Try Idea 2 (instrument matching)",
            "   3. Use different stem separation (Spleeter, Open-Unmix)",
            "   4. Hybrid approach: stems for specific instruments only",
        )
    print(*advice, sep="\n")

    print("\n" + rule)

## 8. Listen to Stems Quality

*Verify stem separation quality*

In [None]:
# Play back the original file followed by every separated stem, each mixed
# down to a single channel before it is handed to the audio widget.
if not poc_results:
    print("⚠️  Run the PoC test first!")
else:
    print("🎧 Audio Playback: Original vs Stems")
    print("="*80)

    # (heading, path) pairs in playback order: the full mix first, then the stems.
    clips = [("\n🎵 Original Full Mix:", poc_results['audio_path'])]
    clips += [
        (f"\n🎸 {stem_name.title()} Stem:", stem_path)
        for stem_name, stem_path in poc_results['stems'].items()
    ]

    for heading, clip_path in clips:
        print(heading)
        waveform, sr = torchaudio.load(clip_path)
        # More than one channel: average the channels into one.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        display(Audio(waveform.numpy(), rate=sr))

## 9. Save Results

*Export PoC results to JSON*

In [None]:
# Export a compact summary of the last PoC run to poc_results.json in original_dir.
if not poc_results:
    print("⚠️  Run the PoC test first!")
else:
    baseline_stats = poc_results['fullmix']['stats']

    # Save results as JSON; stems whose transcription failed have no 'stats'
    # entry and are left out of the summary.
    results_summary = {
        'test_file': os.path.basename(poc_results['audio_path']),
        'fullmix': {
            'notes': baseline_stats['total_notes'],
            'instruments': baseline_stats['num_instruments'],
            'duration': baseline_stats['duration']
        },
        'stems': {
            stem_name: {
                'notes': result['stats']['total_notes'],
                'instruments': result['stats']['num_instruments']
            }
            for stem_name, result in poc_results['stem_results'].items()
            if 'stats' in result
        }
    }

    # Calculate overall improvement (0 when the full mix produced no notes)
    fullmix_notes = results_summary['fullmix']['notes']
    total_stem_notes = sum(entry['notes'] for entry in results_summary['stems'].values())
    if fullmix_notes > 0:
        improvement = (total_stem_notes - fullmix_notes) / fullmix_notes * 100
    else:
        improvement = 0

    results_summary['improvement'] = {
        'absolute': total_stem_notes - fullmix_notes,
        'percentage': improvement
    }

    # Save to file
    output_path = os.path.join(original_dir, 'poc_results.json')
    with open(output_path, 'w') as f:
        json.dump(results_summary, f, indent=2)

    print(f"✅ PoC results saved to: {output_path}")
    print("\n📄 Summary:")
    print(json.dumps(results_summary, indent=2))

---

## 🎉 PoC Complete!

**If improvement >10%**: Proceed with fine-tuning (see `YOURMT3_FINETUNING_GUIDE.md`)

**If improvement 5-10%**: Investigate stem quality, try different Demucs model

**If improvement <5%**: Consider Idea 2 (instrument matching) or improve stem separation

---

*Assisted by Claude Code*