# 🎵 Stems Separation PoC: Accuracy Comparison

**Proof of Concept** to validate Idea 1: Stems-based transcription approach.

**Hypothesis**: Separating audio into stems (bass, drums, other, vocals) before transcription improves accuracy.

**Test**:
1. Transcribe **full mix** with YourMT3
2. Separate audio into **4 stems** with Demucs
3. Transcribe **each stem** with YourMT3
4. Compare the detected note counts (used here as a proxy for accuracy; no ground-truth MIDI is involved) and analyze the differences

**Expected**: 10-15% accuracy improvement on stems (without fine-tuning yet)

## 1. Setup and Imports

In [None]:
import sys
import os
from pathlib import Path
import glob
import subprocess
import shutil

# Change to the yourmt3_space directory.
#
# This cell must be safe to re-run. After the first run the kernel's working
# directory is already yourmt3_space, so taking original_dir from os.getcwd()
# again would point it at the checkout itself and the existence check below
# would wrongly report the checkout as missing.
if os.path.basename(os.getcwd()) == 'yourmt3_space':
    # Already inside the checkout (cell re-run): the notebook directory is its parent.
    original_dir = os.path.dirname(os.getcwd())
else:
    original_dir = os.getcwd()
    if os.path.isdir('yourmt3_space'):
        os.chdir('yourmt3_space')

if os.path.basename(os.getcwd()) != 'yourmt3_space':
    print("❌ Error: yourmt3_space directory not found!")
    print("   Run setup_yourmt3_brev.sh first")
else:
    # Make the checkout's modules importable; skip entries a previous run added.
    for extra_path in ('.', 'amt/src'):
        if extra_path not in sys.path:
            sys.path.insert(0, extra_path)

import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
import pretty_midi
from IPython.display import Audio, display, HTML, Markdown
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed
import json
from collections import defaultdict

from model_helper import load_model_checkpoint, transcribe

# Report where the notebook is running and whether a GPU is usable.
cuda_ready = torch.cuda.is_available()
for report_line in (
    "✅ Imports successful!",
    f"   Working directory: {os.getcwd()}",
    f"   CUDA available: {cuda_ready}",
):
    print(report_line)
if cuda_ready:
    print(f"   GPU: {torch.cuda.get_device_name(0)}")

## 2. Check Demucs Installation

In [None]:
# Check that the Demucs command-line tool is available; install it if not.
#
# pip is invoked as "<kernel python> -m pip" rather than a bare "pip" so that
# the version query and the install target the interpreter this notebook runs
# on, not whichever "pip" happens to come first on PATH.
pip_cmd = [sys.executable, '-m', 'pip']

try:
    result = subprocess.run(['demucs', '--help'], capture_output=True, text=True)
except FileNotFoundError:
    # The demucs executable does not exist yet: try to install it.
    print("⚠️  Demucs not found! Installing...")
    print("   This will take 2-5 minutes...")
    try:
        subprocess.run(pip_cmd + ['install', 'demucs'], check=True)
    except subprocess.CalledProcessError as e:
        # Report a failed install instead of letting the traceback escape the cell.
        print(f"❌ Demucs check failed: {e}")
        print("   Install manually: pip install demucs")
    else:
        print("✅ Demucs installed successfully!")
except OSError as e:
    # The executable exists but could not be launched (e.g. permissions).
    print(f"❌ Demucs check failed: {e}")
    print("   Install manually: pip install demucs")
else:
    if result.returncode == 0:
        print("✅ Demucs is installed and ready")
        # Get version
        version_result = subprocess.run(pip_cmd + ['show', 'demucs'], capture_output=True, text=True)
        for line in version_result.stdout.split('\n'):
            if line.startswith('Version:'):
                print(f"   {line}")
    else:
        print("❌ Demucs check failed: Demucs command failed")
        print("   Install manually: pip install demucs")

## 3. Load YourMT3 Model

In [None]:
print("Loading YourMT3 model...")
print("This may take 10-15 seconds...")

# Model configuration
checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt"
project = '2024'
precision = '16'

# Options handed to load_model_checkpoint, written as (flag, value) pairs so
# each option reads as one unit. The flattened list starts with the checkpoint
# name and keeps the pairs in this order.
option_pairs = (
    ('-p', project),
    ('-tk', 'mc13_full_plus_256'),
    ('-dec', 'multi-t5'),
    ('-nl', '26'),
    ('-enc', 'perceiver-tf'),
    ('-sqr', '1'),
    ('-ff', 'moe'),
    ('-wf', '4'),
    ('-nmoe', '8'),
    ('-kmoe', '2'),
    ('-act', 'silu'),
    ('-epe', 'rope'),
    ('-rp', '1'),
    ('-ac', 'spec'),
    ('-hop', '300'),
    ('-atc', '1'),
    ('-pr', precision),
)
args = [checkpoint]
for flag, value in option_pairs:
    args += [flag, value]

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = load_model_checkpoint(args=args, device=device)

print("\n✅ Model loaded successfully!")
print(f"   Device: {device}")
print("   Model: YPTF.MoE+Multi (noPS)")

## 4. Helper Functions for PoC

In [None]:
def separate_stems_demucs(audio_path, output_dir="separated", demucs_model="htdemucs"):
    """Separate audio into 4 stems (bass, drums, other, vocals) using Demucs.

    Runs the ``demucs`` command-line tool and looks for the resulting WAV
    files under ``<output_dir>/<demucs_model>/<audio file name without suffix>/``.

    Args:
        audio_path: Path of the audio file to separate.
        output_dir: Directory passed to Demucs via ``--out``.
        demucs_model: Model name passed to Demucs via ``-n``.

    Returns:
        Dict mapping stem name to the path of its WAV file, or ``None`` when
        Demucs cannot be started, exits with an error, or any of the four
        expected stem files is missing afterwards.
    """
    print(f"🎵 Separating stems with Demucs...")
    print(f"   Input: {audio_path}")
    print(f"   Output: {output_dir}/")
    print("   This may take 30-60 seconds...")

    # No --two-stems option here: that mode writes only the chosen stem and
    # its complement, so bass/drums/other would never be produced.
    cmd = [
        'demucs',
        '-n', demucs_model,
        '--out', output_dir,
        audio_path
    ]

    try:
        subprocess.run(cmd, capture_output=True, text=True, check=True)
    except FileNotFoundError:
        # The demucs executable itself is missing.
        print("❌ Demucs separation failed: 'demucs' command not found")
        print("   Install manually: pip install demucs")
        return None
    except subprocess.CalledProcessError as e:
        print(f"❌ Demucs separation failed: {e}")
        print(f"   stderr: {e.stderr}")
        return None

    # Find output stems
    track_name = Path(audio_path).stem
    stem_dir = Path(output_dir) / demucs_model / track_name
    stems = {
        stem_name: str(stem_dir / f'{stem_name}.wav')
        for stem_name in ('bass', 'drums', 'other', 'vocals')
    }

    # Verify all stems exist; report every missing file, then use the same
    # None result as the other failure paths so callers need a single check.
    missing = [stem_path for stem_path in stems.values() if not Path(stem_path).exists()]
    if missing:
        print("❌ Demucs separation failed: expected stem files not found")
        for stem_path in missing:
            print(f"   Stem not found: {stem_path}")
        return None

    print("✅ Stems separated successfully!")
    for stem_name in stems:
        print(f"   - {stem_name}.wav")

    return stems

def analyze_midi(midi_path):
    """Load a MIDI file and summarise its contents.

    Returns a ``(stats, midi)`` tuple. ``stats`` holds the total note count,
    the number of instrument tracks, the end time reported by pretty_midi and
    one detail entry per track that contains at least one note. ``midi`` is
    the loaded ``pretty_midi.PrettyMIDI`` object.
    """
    midi = pretty_midi.PrettyMIDI(midi_path)
    tracks = midi.instruments

    note_total = 0
    for track in tracks:
        note_total += len(track.notes)

    # Tracks without notes get no detail entry; 'index' is still the track's
    # position in the full instrument list.
    active_tracks = [
        {
            'index': position,
            'program': track.program,
            'name': pretty_midi.program_to_instrument_name(track.program),
            'notes': len(track.notes),
            'is_drum': track.is_drum
        }
        for position, track in enumerate(tracks)
        if len(track.notes) > 0
    ]

    stats = {
        'total_notes': note_total,
        'num_instruments': len(tracks),
        'duration': midi.get_end_time(),
        'instruments': active_tracks
    }
    return stats, midi

def transcribe_audio(audio_path, output_name):
    """Transcribe one audio file with the loaded model.

    Builds the metadata dict passed to ``transcribe`` from
    ``torchaudio.info`` and returns ``(midi_path, stats)``, where ``stats``
    is the summary produced by ``analyze_midi`` for the written MIDI file.
    """
    meta = torchaudio.info(audio_path)
    seconds = meta.num_frames / meta.sample_rate

    # Fall back to 16 when no (or a zero) bit depth is reported.
    if meta.bits_per_sample:
        bit_depth = int(meta.bits_per_sample)
    else:
        bit_depth = 16

    audio_info = {
        "filepath": audio_path,
        "track_name": output_name,
        "sample_rate": int(meta.sample_rate),
        "bits_per_sample": bit_depth,
        "num_channels": int(meta.num_channels),
        "num_frames": int(meta.num_frames),
        "duration": int(seconds),  # truncated to whole seconds
        "encoding": 'unknown',
    }

    midi_path = transcribe(model, audio_info)

    # Only the statistics are needed; the loaded MIDI object is discarded.
    stats, _ = analyze_midi(midi_path)
    return midi_path, stats

def compare_instruments(fullmix_stats, stem_stats):
    """Compare per-program note counts between a full-mix and a stem transcription.

    Args:
        fullmix_stats: Stats dict with an ``'instruments'`` list whose entries
            carry ``'program'`` and ``'notes'``.
        stem_stats: Stats dict of the same shape for the stem.

    Returns:
        Dict with:
          - ``'programs_in_both'``: programs with notes on both sides
          - ``'programs_only_fullmix'``: programs with notes only in the full mix
          - ``'programs_only_stem'``: programs with notes only in the stem
          - ``'note_differences'``: for each program in both, the two note
            counts, their difference and the percentage change relative to
            the full mix
        The three lists are in ascending program order.
    """
    def _notes_per_program(stats):
        # Sum the notes of all entries that share a program number, so a
        # later entry does not overwrite an earlier one.
        totals = {}
        for inst in stats['instruments']:
            totals[inst['program']] = totals.get(inst['program'], 0) + inst['notes']
        return totals

    fullmix_programs = _notes_per_program(fullmix_stats)
    stem_programs = _notes_per_program(stem_stats)

    comparison = {
        'programs_in_both': [],
        'programs_only_fullmix': [],
        'programs_only_stem': [],
        'note_differences': {}
    }

    # Sorted so the output order does not depend on set iteration order.
    for program in sorted(set(fullmix_programs) | set(stem_programs)):
        fullmix_notes = fullmix_programs.get(program, 0)
        stem_notes = stem_programs.get(program, 0)

        if fullmix_notes > 0 and stem_notes > 0:
            comparison['programs_in_both'].append(program)
            comparison['note_differences'][program] = {
                'fullmix': fullmix_notes,
                'stem': stem_notes,
                'diff': stem_notes - fullmix_notes,
                'improvement': (stem_notes - fullmix_notes) / fullmix_notes * 100
            }
        elif fullmix_notes > 0:
            comparison['programs_only_fullmix'].append(program)
        elif stem_notes > 0:
            comparison['programs_only_stem'].append(program)
        # A program with zero notes on both sides was detected by neither
        # transcription and is left out.

    return comparison

# Confirmation that this cell ran and the helper definitions above exist.
print("✅ Helper functions loaded")

## 5. Select Audio File for PoC Test

In [None]:
# Find audio files in the directory the notebook was started from.
# glob does not guarantee an order, so each extension's matches are sorted to
# keep the numbered listing and the dropdown stable between runs.
audio_extensions = ['*.mp3', '*.wav', '*.flac', '*.m4a', '*.ogg']
audio_files = []
for ext in audio_extensions:
    audio_files.extend(sorted(glob.glob(os.path.join(original_dir, ext))))

# Probe every file once; the durations are reused for the listing and for the
# dropdown labels instead of calling torchaudio.info repeatedly per file.
audio_durations = {}
for f in audio_files:
    info = torchaudio.info(f)
    audio_durations[f] = info.num_frames / info.sample_rate

if not audio_files:
    print("❌ No audio files found!")
    print("   Please upload audio files to the MT3 directory")
else:
    print(f"✅ Found {len(audio_files)} audio files:")
    for i, f in enumerate(audio_files, start=1):
        print(f"   {i}. {os.path.basename(f)} ({audio_durations[f]:.1f}s)")

# File selector: the label shows name and duration, the value is the full path.
file_selector = widgets.Dropdown(
    options=[(f"{os.path.basename(f)} ({audio_durations[f]:.1f}s)", f) for f in audio_files],
    description='Test File:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='700px')
)

display(file_selector)

## 6. Run PoC Test: Full Mix vs Stems

**This will**:
1. Transcribe the full mix
2. Separate into 4 stems with Demucs
3. Transcribe each stem separately
4. Compare results

In [None]:
# Global results storage
poc_results = {}

def run_poc_test(button):
    """Button click handler that runs the full-mix vs. stems comparison.

    Transcribes the selected file, separates it into stems, transcribes each
    stem, prints a note-count comparison into the ``output`` widget and stores
    everything in the global ``poc_results`` for the cells that follow.
    """
    global poc_results

    rule = "=" * 80

    def banner(title, gap=True):
        # A title framed by two rules; every banner but the first one is
        # preceded by a blank line.
        print("\n" + rule if gap else rule)
        print(title)
        print(rule)

    output.clear_output(wait=True)

    with output:
        audio_path = file_selector.value

        if not audio_path:
            print("❌ Please select an audio file first!")
            return

        banner("🎯 PROOF OF CONCEPT TEST: Full Mix vs Stems", gap=False)
        print(f"\n📁 Test File: {os.path.basename(audio_path)}")

        # Step 1: the untouched mix is the baseline.
        banner("STEP 1: Transcribe Full Mix (Baseline)")

        fullmix_midi, fullmix_stats = transcribe_audio(audio_path, "poc_fullmix")
        baseline_notes = fullmix_stats['total_notes']

        print("\n✅ Full Mix Transcription Complete:")
        print(f"   Total notes: {baseline_notes}")
        print(f"   Instruments: {fullmix_stats['num_instruments']}")
        print(f"   Duration: {fullmix_stats['duration']:.2f}s")

        # Step 2: split the mix into stems.
        banner("STEP 2: Separate Stems with Demucs")

        stems = separate_stems_demucs(audio_path, output_dir="separated")

        if not stems:
            print("❌ Stem separation failed! Stopping test.")
            return

        # Step 3: transcribe every stem; one failing stem does not stop the rest.
        banner("STEP 3: Transcribe Each Stem")

        stem_results = {}

        for stem_name, stem_path in stems.items():
            print(f"\n🎵 Transcribing {stem_name} stem...")

            try:
                stem_midi, stem_stats = transcribe_audio(stem_path, f"poc_{stem_name}")
                stem_results[stem_name] = {
                    'midi_path': stem_midi,
                    'stats': stem_stats
                }

                print(f"   ✅ {stem_name}: {stem_stats['total_notes']} notes, {stem_stats['num_instruments']} instruments")

            except Exception as e:
                print(f"   ❌ {stem_name} transcription failed: {e}")
                stem_results[stem_name] = {'error': str(e)}

        # Step 4: compare note counts.
        banner("STEP 4: Comparison Analysis")

        # Keep everything for the follow-up cells.
        poc_results = {
            'audio_path': audio_path,
            'fullmix': {
                'midi_path': fullmix_midi,
                'stats': fullmix_stats
            },
            'stems': stems,
            'stem_results': stem_results
        }

        # Only stems that were transcribed successfully count towards the total.
        total_stem_notes = 0
        for entry in stem_results.values():
            if 'stats' in entry:
                total_stem_notes += entry['stats']['total_notes']

        note_delta = total_stem_notes - baseline_notes
        if baseline_notes > 0:
            improvement = note_delta / baseline_notes * 100
        else:
            improvement = 0

        print("\n📊 Overall Note Count:")
        print(f"   Full Mix: {baseline_notes} notes")
        print(f"   Stems Combined: {total_stem_notes} notes")
        print(f"   Difference: {note_delta:+d} notes ({improvement:+.1f}%)")

        # Per-stem breakdown.
        print("\n🎸 Per-Stem Analysis:")
        for stem_name, entry in stem_results.items():
            if 'stats' not in entry:
                continue
            stem_stats = entry['stats']
            print(f"\n   {stem_name.upper()}:")
            print(f"      Notes detected: {stem_stats['total_notes']}")
            print(f"      Instruments: {stem_stats['num_instruments']}")
            if stem_stats['instruments']:
                # First instrument with the highest note count.
                top_inst = max(stem_stats['instruments'], key=lambda x: x['notes'])
                print(f"      Top instrument: {top_inst['name']} ({top_inst['notes']} notes)")

        banner("✅ PoC TEST COMPLETE!")
        print("\nRun the cells below to:")
        print("   - View detailed comparison tables")
        print("   - Listen to full mix vs stems")
        print("   - Analyze accuracy improvements")

# Output area the click handler prints into.
output = widgets.Output()

# Button that starts the test.
poc_button = widgets.Button(
    description='🚀 Run PoC Test',
    button_style='success',
    layout=widgets.Layout(width='200px', height='50px')
)
poc_button.on_click(run_poc_test)

# Timing note shown above the button.
for timing_note in (
    "⚠️  Note: This test will take 2-5 minutes depending on audio length",
    "   - Full mix transcription: ~30s",
    "   - Stem separation: ~30-60s",
    "   - 4 stem transcriptions: ~2min",
):
    print(timing_note)

display(poc_button)
display(output)

## 7. Detailed Comparison Table

In [None]:
# Detailed comparison table: full mix vs. each stem vs. all stems combined.
# NOTE(review): "improvement" here is the change in detected note count; no
# ground-truth transcription is compared, so more notes is only a proxy for
# better accuracy.
if poc_results:
    print("📊 Detailed Comparison: Full Mix vs Stems")
    print("="*80)

    fullmix_stats = poc_results['fullmix']['stats']

    def _density(notes, duration):
        """Format notes per second; an empty transcription (duration 0) gives 0.0."""
        # Without this guard a transcription with no notes, whose duration
        # is 0, raises ZeroDivisionError.
        return f"{notes / duration:.1f}" if duration > 0 else "0.0"

    # Stats of the stems that were transcribed successfully.
    stem_stats = [
        (stem_name, result['stats'])
        for stem_name, result in poc_results['stem_results'].items()
        if 'stats' in result
    ]

    # Create comparison table: baseline first, then one row per stem.
    comparison_data = [{
        'Source': 'Full Mix (Baseline)',
        'Notes': fullmix_stats['total_notes'],
        'Instruments': fullmix_stats['num_instruments'],
        'Density': _density(fullmix_stats['total_notes'], fullmix_stats['duration'])
    }]

    for stem_name, stats in stem_stats:
        comparison_data.append({
            'Source': f"{stem_name.title()} Stem",
            'Notes': stats['total_notes'],
            'Instruments': stats['num_instruments'],
            'Density': _density(stats['total_notes'], stats['duration'])
        })

    # Add combined row; its density uses the full-mix duration.
    total_stem_notes = sum(stats['total_notes'] for _, stats in stem_stats)
    total_stem_insts = sum(stats['num_instruments'] for _, stats in stem_stats)

    comparison_data.append({
        'Source': 'Stems Combined',
        'Notes': total_stem_notes,
        'Instruments': total_stem_insts,
        'Density': _density(total_stem_notes, fullmix_stats['duration'])
    })

    # Display as HTML table, assembled from parts and joined once.
    html_parts = [
        "<table style='border-collapse: collapse; width: 100%;'>\n",
        "<tr style='background-color: #f0f0f0;'>\n",
    ]
    html_parts.extend(
        f"<th style='border: 1px solid #ddd; padding: 12px;'>{title}</th>"
        for title in ('Source', 'Notes', 'Instruments', 'Density (notes/s)')
    )
    html_parts.append("</tr>\n")

    for row in comparison_data:
        html_parts.append("<tr>\n")
        html_parts.extend(
            f"<td style='border: 1px solid #ddd; padding: 10px;'>{row[column]}</td>"
            for column in ('Source', 'Notes', 'Instruments', 'Density')
        )
        html_parts.append("</tr>\n")

    html_parts.append("</table>")
    html = "".join(html_parts)

    display(HTML(html))

    # Calculate improvement relative to the full-mix note count.
    improvement = ((total_stem_notes - fullmix_stats['total_notes']) / fullmix_stats['total_notes'] * 100) if fullmix_stats['total_notes'] > 0 else 0

    print(f"\n🎯 Key Finding:")
    if improvement > 10:
        print(f"   ✅ Stems approach shows {improvement:.1f}% more notes detected!")
        print(f"   This validates the hypothesis: stem separation improves transcription.")
        print(f"   ✨ Recommendation: Proceed with fine-tuning (Idea 1)")
    elif improvement > 5:
        print(f"   ⚠️  Stems approach shows {improvement:.1f}% more notes detected.")
        print(f"   Moderate improvement - fine-tuning may add 10-15% more.")
        print(f"   📊 Recommendation: Review per-stem quality before deciding")
    else:
        print(f"   ⚠️  Stems approach shows only {improvement:.1f}% improvement.")
        print(f"   This is lower than expected (10-15% target).")
        print(f"   🔍 Recommendation: Investigate Demucs quality or try Idea 2")

else:
    print("⚠️  Run the PoC test first (cell above)!")

## 8. Play Audio Comparison

Listen to the original vs each separated stem

In [None]:
# Audio players for the original mix followed by every separated stem.
if not poc_results:
    print("⚠️  Run the PoC test first!")
else:
    print("🎧 Audio Playback: Original vs Stems")
    print("="*80)

    # (heading, file) for every clip, original first.
    clips = [("\n🎵 Original Full Mix:", poc_results['audio_path'])]
    for stem_name, stem_path in poc_results['stems'].items():
        clips.append((f"\n🎸 {stem_name.title()} Stem:", stem_path))

    for heading, clip_path in clips:
        print(heading)
        waveform, sr = torchaudio.load(clip_path)
        # Average multi-channel audio down to a single channel for playback.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        display(Audio(waveform.numpy(), rate=sr))

## 9. Decision Matrix

Based on PoC results, should you proceed with Idea 1 (fine-tuning)?

In [None]:
# Decision matrix: turn the note-count change into a recommendation.
# NOTE(review): "improvement" is the relative change in detected note count
# between the combined stems and the full mix; no ground truth is compared.
if poc_results:
    fullmix_notes = poc_results['fullmix']['stats']['total_notes']
    # Only stems that were transcribed successfully carry 'stats'.
    total_stem_notes = sum(r['stats']['total_notes'] for r in poc_results['stem_results'].values() if 'stats' in r)
    improvement = ((total_stem_notes - fullmix_notes) / fullmix_notes * 100) if fullmix_notes > 0 else 0

    print("🎯 DECISION MATRIX")
    print("="*80)
    print(f"\nPoC Result: {improvement:+.1f}% note count improvement with stems\n")

    # "<" and ">" in the headings are written as HTML entities so the markup
    # is valid and cannot be read as the start of a tag.
    # NOTE(review): the dataset size and download URL quoted below are not
    # verifiable from this notebook - confirm them before relying on them.
    decision_html = """
    <div style='font-family: Arial, sans-serif;'>
        <h3>Interpretation Guide:</h3>

        <div style='background-color: #d4edda; padding: 15px; margin: 10px 0; border-left: 4px solid #28a745;'>
            <strong>✅ &gt;10% Improvement → PROCEED WITH IDEA 1</strong>
            <ul>
                <li>Stems approach validates the hypothesis</li>
                <li>Fine-tuning likely to add another 10-15% improvement</li>
                <li>Expected final improvement: 20-30%</li>
                <li><strong>Action</strong>: Download Slakh2100, start bass model fine-tuning</li>
            </ul>
        </div>

        <div style='background-color: #fff3cd; padding: 15px; margin: 10px 0; border-left: 4px solid #ffc107;'>
            <strong>⚠️ 5-10% Improvement → INVESTIGATE FURTHER</strong>
            <ul>
                <li>Moderate improvement, below 10-15% target</li>
                <li>Check Demucs stem quality (listen to stems)</li>
                <li>Verify per-stem transcription quality</li>
                <li><strong>Action</strong>: Analyze per-stem results, try different Demucs model</li>
            </ul>
        </div>

        <div style='background-color: #f8d7da; padding: 15px; margin: 10px 0; border-left: 4px solid #dc3545;'>
            <strong>❌ &lt;5% Improvement → RECONSIDER APPROACH</strong>
            <ul>
                <li>Stems not providing expected benefit</li>
                <li>Possible issues: poor stem separation, model limitations</li>
                <li>Fine-tuning may not help if stems themselves don't improve accuracy</li>
                <li><strong>Action</strong>: Try Idea 2 (instrument matching), or investigate Demucs alternatives</li>
            </ul>
        </div>

        <h3>Next Steps if Proceeding:</h3>
        <ol>
            <li><strong>Download Slakh2100</strong>: ~1TB, contains stem-level training data</li>
            <li><strong>Prepare fine-tuning environment</strong>: See YOURMT3_FINETUNING_GUIDE.md</li>
            <li><strong>Start with bass model</strong>: 3-5 days fine-tuning on A10G</li>
            <li><strong>Evaluate bass results</strong>: Measure accuracy improvement</li>
            <li><strong>Continue with other stems</strong>: Drums, other, vocals (if bass succeeds)</li>
        </ol>
    </div>
    """

    display(HTML(decision_html))

    # Specific recommendation based on actual result
    print("\n" + "="*80)
    print("📋 YOUR SPECIFIC RECOMMENDATION:")
    print("="*80)

    # Thresholds are strict: exactly 10% falls into the middle band and
    # exactly 5% into the lowest one.
    if improvement > 10:
        print(f"\n✅ Result: {improvement:.1f}% improvement")
        print(f"\n🎯 RECOMMENDATION: PROCEED WITH IDEA 1 (Fine-tuning)")
        print(f"\nNext steps:")
        print(f"   1. Review YOURMT3_FINETUNING_GUIDE.md for setup instructions")
        print(f"   2. Download Slakh2100 dataset (~1TB): wget https://zenodo.org/record/4599666/files/slakh2100_flac_16k.tar.gz")
        print(f"   3. Start fine-tuning bass model (3-5 days on A10G GPU)")
        print(f"   4. Expected final accuracy improvement: 20-30%")
    elif improvement > 5:
        print(f"\n⚠️  Result: {improvement:.1f}% improvement")
        print(f"\n📊 RECOMMENDATION: INVESTIGATE BEFORE PROCEEDING")
        print(f"\nAction items:")
        print(f"   1. Listen to separated stems quality (cell above)")
        print(f"   2. Check if specific stems (bass/drums) show better improvement")
        print(f"   3. Try different Demucs model: demucs -n mdx_extra")
        print(f"   4. Re-run PoC with better stem separation")
    else:
        print(f"\n❌ Result: {improvement:.1f}% improvement (below target)")
        print(f"\n🔍 RECOMMENDATION: RECONSIDER APPROACH")
        print(f"\nAlternatives to explore:")
        print(f"   1. Investigate why stems don't help (check Demucs output quality)")
        print(f"   2. Try Idea 2 (audio-based instrument matching) as alternative")
        print(f"   3. Use different stem separation model (Spleeter, Open-Unmix)")
        print(f"   4. Consider hybrid approach: stems for specific instruments only")

else:
    print("⚠️  Run the PoC test first!")

## 10. Save PoC Results

In [None]:
# Write a compact JSON summary of the PoC run next to the notebook.
if not poc_results:
    print("⚠️  Run the PoC test first!")
else:
    baseline_stats = poc_results['fullmix']['stats']

    # Stems whose transcription failed have no 'stats' and are left out.
    results_summary = {
        'test_file': os.path.basename(poc_results['audio_path']),
        'fullmix': {
            'notes': baseline_stats['total_notes'],
            'instruments': baseline_stats['num_instruments'],
            'duration': baseline_stats['duration']
        },
        'stems': {
            stem_name: {
                'notes': result['stats']['total_notes'],
                'instruments': result['stats']['num_instruments']
            }
            for stem_name, result in poc_results['stem_results'].items()
            if 'stats' in result
        }
    }

    # Overall change in note count, relative to the full mix.
    fullmix_notes = results_summary['fullmix']['notes']
    total_stem_notes = 0
    for stem_summary in results_summary['stems'].values():
        total_stem_notes += stem_summary['notes']

    if fullmix_notes > 0:
        improvement = (total_stem_notes - fullmix_notes) / fullmix_notes * 100
    else:
        improvement = 0

    results_summary['improvement'] = {
        'absolute': total_stem_notes - fullmix_notes,
        'percentage': improvement
    }

    # Serialise once; the same text is written to disk and echoed below.
    summary_text = json.dumps(results_summary, indent=2)
    output_path = os.path.join(original_dir, 'poc_results.json')
    with open(output_path, 'w') as f:
        f.write(summary_text)

    print(f"✅ PoC results saved to: {output_path}")
    print(f"\n📄 Summary:")
    print(summary_text)

---

## 🎉 PoC Complete!

**What we tested**:
- ✅ Full mix transcription (baseline)
- ✅ Stem separation with Demucs
- ✅ Individual stem transcription
- ✅ Note-count comparison (a proxy for accuracy; no ground truth is used)

**What this tells us**:
- If improvement >10%: **Proceed with fine-tuning** (Idea 1)
- If improvement 5-10%: **Investigate further** before committing
- If improvement <5%: **Reconsider approach** or try Idea 2

**Next steps if proceeding**:
1. Read `YOURMT3_FINETUNING_GUIDE.md`
2. Download Slakh2100 dataset
3. Fine-tune bass model first
4. Measure improvements
5. Continue with other stems

---

*Assisted by Claude Code*