# 🎵 Instrument Recognition PoC

**Goal**: Identify specific instruments present in audio using MIDI transcription + timbre matching

**Pipeline**:
- **Phase 0** (This notebook): MIDI transcription with YourMT3 ✅
- **Phase 1**: YAMNet setup and instrument mapping (521 → 25 categories) 🚧
- **Phase 2**: Note isolation from MIDI piano roll 🚧
- **Phase 3**: Timbre matching with YAMNet 🚧
- **Phase 4**: Output generation (aggregated % + timeline) 🚧
- **Phase 5**: Accuracy evaluation 🚧

**Reference**: See `instrument_recognition/SPECIFICATION.md` for full technical details

## ⚠️ IMPORTANT: Run cells in order (1→2→3→4→5)

## 1. Setup and Imports

In [None]:
import sys
import os
from pathlib import Path
import glob

# Change to yourmt3_space directory
original_dir = os.getcwd()
if not os.path.exists('yourmt3_space'):
    print("❌ Error: yourmt3_space directory not found!")
    print("   Run setup_yourmt3_brev.sh first")
else:
    os.chdir('yourmt3_space')
    sys.path.insert(0, '.')
    sys.path.insert(0, 'amt/src')

import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
import pretty_midi
from IPython.display import Audio, display, HTML, Markdown
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed
import json
from collections import defaultdict

from model_helper import load_model_checkpoint, transcribe

print("✅ Imports successful!")
print(f"   Working directory: {os.getcwd()}")
print(f"   CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")

## 2. Load YourMT3 Model

**Model**: YPTF.MoE+Multi (noPS) - 536M parameters

In [None]:
print("Loading YourMT3 model...")
print("This may take 10-15 seconds...")

# Model configuration
checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt"
project = '2024'
precision = '16'

args = [
    checkpoint,
    '-p', project,
    '-tk', 'mc13_full_plus_256',
    '-dec', 'multi-t5',
    '-nl', '26',
    '-enc', 'perceiver-tf',
    '-sqr', '1',
    '-ff', 'moe',
    '-wf', '4',
    '-nmoe', '8',
    '-kmoe', '2',
    '-act', 'silu',
    '-epe', 'rope',
    '-rp', '1',
    '-ac', 'spec',
    '-hop', '300',
    '-atc', '1',
    '-pr', precision
]

device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_model_checkpoint(args=args, device=device)

print("\n✅ Model loaded successfully!")
print(f"   Device: {device}")
print(f"   Model: YPTF.MoE+Multi (noPS)")

## 3. Helper Functions - Phase 0 (MIDI Transcription)

In [None]:
def analyze_midi(midi_path):
    """Analyze MIDI file and return statistics"""
    midi = pretty_midi.PrettyMIDI(midi_path)
    
    stats = {
        'total_notes': sum(len(inst.notes) for inst in midi.instruments),
        'num_instruments': len(midi.instruments),
        'duration': midi.get_end_time(),
        'instruments': []
    }
    
    for i, inst in enumerate(midi.instruments):
        if len(inst.notes) > 0:
            stats['instruments'].append({
                'index': i,
                'program': inst.program,
                'name': pretty_midi.program_to_instrument_name(inst.program),
                'notes': len(inst.notes),
                'is_drum': inst.is_drum
            })
    
    return stats, midi

def transcribe_audio(audio_path, output_name):
    """Transcribe audio and return MIDI path + stats"""
    info = torchaudio.info(audio_path)
    duration = info.num_frames / info.sample_rate
    
    audio_info = {
        "filepath": audio_path,
        "track_name": output_name,
        "sample_rate": int(info.sample_rate),
        "bits_per_sample": int(info.bits_per_sample) if info.bits_per_sample else 16,
        "num_channels": int(info.num_channels),
        "num_frames": int(info.num_frames),
        "duration": int(duration),
        "encoding": 'unknown',
    }
    
    # Transcribe
    midi_path = transcribe(model, audio_info)
    
    # Analyze
    stats, midi = analyze_midi(midi_path)
    
    return midi_path, stats, midi

def midi_to_audio(midi_path, sample_rate=16000):
    """Convert MIDI to audio for playback using FluidSynth"""
    try:
        midi = pretty_midi.PrettyMIDI(midi_path)
        audio = midi.fluidsynth(fs=sample_rate)
        return audio, sample_rate
    except Exception as e:
        print(f"⚠️  MIDI synthesis failed: {e}")
        print("   Note: FluidSynth may not be installed")
        return None, None

def plot_piano_roll(midi, title, ax=None):
    """Plot piano roll from MIDI object"""
    if ax is None:
        fig, ax = plt.subplots(figsize=(16, 6))
    
    # Get all notes
    notes_to_plot = []
    colors = plt.cm.tab20.colors
    
    for inst_idx, inst in enumerate(midi.instruments):
        color = colors[inst_idx % len(colors)]
        for note in inst.notes:
            notes_to_plot.append({
                'start': note.start,
                'end': note.end,
                'pitch': note.pitch,
                'velocity': note.velocity,
                'color': color,
                'instrument': inst.program,
                'is_drum': inst.is_drum
            })
    
    # Plot notes
    for note_info in notes_to_plot:
        ax.add_patch(
            plt.Rectangle(
                (note_info['start'], note_info['pitch']),
                note_info['end'] - note_info['start'],
                1,
                facecolor=note_info['color'],
                edgecolor='black',
                linewidth=0.5,
                alpha=0.7
            )
        )
    
    ax.set_xlim(0, midi.get_end_time())
    ax.set_ylim(21, 109)  # Piano range is MIDI 21-108; each note is drawn 1 unit tall
    ax.set_xlabel('Time (seconds)', fontsize=12)
    ax.set_ylabel('MIDI Pitch', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)
    
    # Add note count
    total_notes = len(notes_to_plot)
    ax.text(0.98, 0.02, f'{total_notes} notes', 
            transform=ax.transAxes,
            ha='right', va='bottom',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5),
            fontsize=10)
    
    return ax

print("✅ Helper functions loaded")

## 4. Select Audio File

In [None]:
# Find audio files
audio_extensions = ['*.mp3', '*.wav', '*.flac', '*.m4a', '*.ogg']
audio_files = []
for ext in audio_extensions:
    audio_files.extend(glob.glob(os.path.join(original_dir, ext)))

if len(audio_files) == 0:
    print("❌ No audio files found!")
    print(f"   Please upload audio files to the MT3 directory: {original_dir}")
else:
    print(f"✅ Found {len(audio_files)} audio files:")
    for i, f in enumerate(audio_files):
        info = torchaudio.info(f)
        duration = info.num_frames / info.sample_rate
        print(f"   {i+1}. {os.path.basename(f)} ({duration:.1f}s)")

# File selector
file_selector = widgets.Dropdown(
    options=[(f"{os.path.basename(f)} ({torchaudio.info(f).num_frames/torchaudio.info(f).sample_rate:.1f}s)", f) for f in audio_files],
    description='Test File:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='700px')
)

display(file_selector)

## 5. Phase 0: MIDI Transcription

**Goal**: Generate MIDI transcription from audio using YourMT3

**Output**: MIDI file with multi-instrument transcription

In [None]:
# Global results storage
phase0_results = {}

def run_phase0(button):
    global phase0_results
    
    output.clear_output(wait=True)
    
    with output:
        audio_path = file_selector.value
        
        if not audio_path:
            print("❌ Please select an audio file first!")
            return
        
        print("="*80)
        print("🎯 PHASE 0: MIDI Transcription")
        print("="*80)
        print(f"\n📁 Input File: {os.path.basename(audio_path)}")
        
        try:
            # Transcribe
            print("\n🎵 Running YourMT3 transcription...")
            print("   This may take 30-60 seconds...")
            
            midi_path, stats, midi = transcribe_audio(audio_path, "instrument_recognition")
            
            # Store results
            phase0_results = {
                'audio_path': audio_path,
                'midi_path': midi_path,
                'stats': stats,
                'midi': midi
            }
            
            # Display results
            print(f"\n✅ Transcription Complete:")
            print(f"   MIDI file: {midi_path}")
            print(f"   Total notes: {stats['total_notes']}")
            print(f"   Instruments detected: {stats['num_instruments']}")
            print(f"   Duration: {stats['duration']:.2f}s")
            
            # Show instrument breakdown
            print(f"\n🎸 Detected Instruments:")
            for inst in stats['instruments']:
                drum_label = " (Drums)" if inst['is_drum'] else ""
                print(f"   - {inst['name']}{drum_label}: {inst['notes']} notes (Program {inst['program']})")
            
            print("\n" + "="*80)
            print("✅ PHASE 0 COMPLETE!")
            print("="*80)
            print("\nRun the cells below to:")
            print("   - View piano roll visualization")
            print("   - Listen to MIDI playback")
            print("   - Proceed to Phase 1 (YAMNet setup)")
            
        except Exception as e:
            print(f"\n❌ Transcription failed: {e}")
            import traceback
            traceback.print_exc()

# Create button and output
phase0_button = widgets.Button(
    description='🚀 Run Phase 0',
    button_style='success',
    layout=widgets.Layout(width='200px', height='50px')
)
phase0_button.on_click(run_phase0)

output = widgets.Output()

print("⚠️  Note: Transcription takes ~30-60 seconds depending on audio length")
print("\n✅ Ready! Click the button below to start Phase 0")

display(phase0_button)
display(output)

## 6. View MIDI Results

*Run this after Phase 0 completes*

In [None]:
if phase0_results:
    print("🎹 MIDI Transcription Results")
    print("="*80)
    
    stats = phase0_results['stats']
    midi = phase0_results['midi']
    
    # Summary
    print(f"\n📊 Summary:")
    print(f"   Audio: {os.path.basename(phase0_results['audio_path'])}")
    print(f"   MIDI: {phase0_results['midi_path']}")
    print(f"   Total notes: {stats['total_notes']}")
    print(f"   Instruments: {stats['num_instruments']}")
    print(f"   Duration: {stats['duration']:.2f}s")
    
    # Piano roll visualization
    print(f"\n🎹 Piano Roll:")
    fig, ax = plt.subplots(figsize=(18, 6))
    plot_piano_roll(midi, '🎵 MIDI Transcription - Phase 0', ax)
    plt.tight_layout()
    plt.show()
    
    # MIDI playback
    print(f"\n🔊 MIDI Playback:")
    audio, sr = midi_to_audio(phase0_results['midi_path'])
    if audio is not None:
        display(Audio(audio, rate=sr))
    
    print("\n" + "="*80)
    print("✅ Phase 0 complete!")
    print("\n💡 Next Steps:")
    print("   → Phase 1: Set up YAMNet for timbre matching")
    print("   → Phase 2: Extract isolated notes from MIDI piano roll")
    print("   → Phase 3: Match notes to instrument timbres")
    print("   → Phase 4: Generate aggregated % and timeline outputs")
    
else:
    print("⚠️  Run Phase 0 first (Cell 5)!")

---

## 🎯 Phase 1: YAMNet Setup & Instrument Mapping

**Goal**: Load YAMNet model and create instrument mapping

**Tasks**:
1. Install TensorFlow and YAMNet dependencies ✅
2. Load pretrained YAMNet model ✅
3. Create mapping: 521 AudioSet classes → 25 Level 2 instrument categories
4. Test YAMNet on sample audio segments ✅

### 7.1. Import Phase 1 Dependencies

⚠️ **Prerequisite**: Run `pip install -r instrument_recognition/requirements.txt` on the instance first!

In [None]:
# Import TensorFlow and YAMNet dependencies
try:
    import tensorflow as tf
    import tensorflow_hub as hub
    import librosa
    
    print("✅ TensorFlow imports successful!")
    print(f"   TensorFlow version: {tf.__version__}")
    print(f"   TensorFlow Hub version: {hub.__version__}")
    print(f"   Librosa version: {librosa.__version__}")
    
    # Check for GPU
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        print(f"   TensorFlow GPU: {len(gpus)} GPU(s) available")
    else:
        print("   TensorFlow GPU: Running on CPU")
        
except ImportError as e:
    print(f"❌ Import failed: {e}")
    print("\n📦 Please install dependencies first:")
    print("   Run on instance terminal:")
    print("   pip install -r instrument_recognition/requirements.txt")

### 7.2. Load YAMNet Model

Load pretrained YAMNet model from TensorFlow Hub (auto-downloads ~4MB)

In [None]:
# pandas is needed to read the class map CSV
import pandas as pd

print("📥 Loading YAMNet model from TensorFlow Hub...")
print("   This will download ~4MB on first run...")

try:
    # Load YAMNet model
    yamnet_model = hub.load('https://tfhub.dev/google/yamnet/1')
    
    # Load class names (class_map_path() returns a bytes tensor, so decode it to a str path)
    class_map_path = yamnet_model.class_map_path().numpy().decode('utf-8')
    class_names = list(pd.read_csv(class_map_path)['display_name'])
    
    print(f"\n✅ YAMNet loaded successfully!")
    print(f"   Total AudioSet classes: {len(class_names)}")
    print(f"   Model inputs: 16kHz mono audio (variable length)")
    print(f"   Model outputs: 521-dim probability vector + 1024-dim embeddings")
    
except Exception as e:
    print(f"\n❌ YAMNet loading failed: {e}")
    import traceback
    traceback.print_exc()

### 7.3. Browse YAMNet AudioSet Classes

Explore the 521 AudioSet classes to understand what YAMNet can detect.

In [None]:
# Browse instrument-related classes
print("🎸 Instrument-Related AudioSet Classes (Sample):\n")

# Common instrument keywords
instrument_keywords = ['guitar', 'piano', 'drum', 'bass', 'synth', 'kick', 'snare', 'cymbal', 
                       'violin', 'trumpet', 'saxophone', 'flute', 'vocal', 'voice']

instrument_classes = []
for idx, name in enumerate(class_names):
    for keyword in instrument_keywords:
        if keyword.lower() in name.lower():
            instrument_classes.append((idx, name))
            break

print(f"Found {len(instrument_classes)} instrument-related classes out of {len(class_names)} total:\n")

# Show first 30
for idx, name in instrument_classes[:30]:
    print(f"   [{idx:3d}] {name}")

if len(instrument_classes) > 30:
    print(f"\n   ... and {len(instrument_classes) - 30} more")
    
print(f"\n💡 These classes will be mapped to 25 Level 2 categories in the next step")

### 7.4. Test YAMNet on Sample Audio

Test YAMNet classification on a short audio segment from Phase 0.

In [None]:
if phase0_results:
    print("🧪 Testing YAMNet on audio from Phase 0...")
    
    # Load audio
    audio_path = phase0_results['audio_path']
    waveform, sr = librosa.load(audio_path, sr=16000, mono=True, duration=5.0)  # First 5 seconds
    
    print(f"   Audio: {os.path.basename(audio_path)}")
    print(f"   Sample rate: {sr} Hz")
    print(f"   Duration: {len(waveform)/sr:.2f}s")
    
    # Run YAMNet inference
    print("\n🔍 Running YAMNet classification...")
    scores, embeddings, spectrogram = yamnet_model(waveform)
    
    # Average scores across time
    mean_scores = np.mean(scores.numpy(), axis=0)
    
    # Get top-10 predictions
    top_k = 10
    top_indices = np.argsort(mean_scores)[-top_k:][::-1]
    
    print(f"\n✅ Top {top_k} predicted classes:")
    for i, idx in enumerate(top_indices, 1):
        print(f"   {i:2d}. {class_names[idx]:30s} (confidence: {mean_scores[idx]:.3f})")
    
    print("\n💡 These raw YAMNet predictions will be mapped to Level 2 categories")
    
else:
    print("⚠️  Run Phase 0 first to get audio for testing!")

### 7.5. Create Instrument Mapping File

**User Task**: Create `yamnet_to_level2_mapping.json` to map 521 YAMNet classes → 25 Level 2 categories.

**Instructions**:
1. Review YAMNet classes above and SPECIFICATION.md for Level 2 categories
2. Create mapping file: `instrument_recognition/yamnet_to_level2_mapping.json`
3. Format: `{"YAMNet_Class_Name": "Level2_Category", ...}`

**Example mapping**:
```json
{
  "Bass drum": "Kick Drum",
  "Snare drum": "Snare Drum",
  "Hi-hat": "Hi-hat (Closed)",
  "Cymbal": "Crash Cymbal",
  "Electric guitar": "Electric Guitar (Clean)",
  "Distortion": "Electric Guitar (Distorted)",
  "Acoustic guitar": "Acoustic Guitar",
  "Bass guitar": "Electric Bass",
  "Synthesizer": "Synthesizer (Lead)",
  "Piano": "Piano (Acoustic)",
  "Electric piano": "Electric Piano",
  ...
}
```

**Level 2 Categories (25 total)**:
- **Drums**: Kick Drum, Snare Drum, Hi-hat (Closed), Hi-hat (Open), Crash Cymbal, Ride Cymbal, Tom Drum, Electronic Drum
- **Bass**: Electric Bass, Synth Bass, Acoustic Bass
- **Guitar**: Electric Guitar (Clean), Electric Guitar (Distorted), Acoustic Guitar
- **Keys**: Piano (Acoustic), Electric Piano, Synthesizer (Lead), Synthesizer (Pad)
- **Other**: Orchestral Strings, Brass, Woodwinds, Vocals, Vocal Sample, Sound Effects, Unknown

Once created, the mapping will be loaded in Phase 3 for timbre matching.

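One way to catch typos in the mapping file from 7.5 before Phase 3 is to check every target against the 25 Level 2 categories. The sketch below is not part of the notebook: `validate_mapping` is a hypothetical helper, and it assumes the flat `{"YAMNet_Class_Name": "Level2_Category"}` format described in 7.5.

```python
import json

# The 25 Level 2 categories listed in section 7.5.
LEVEL2_CATEGORIES = {
    "Kick Drum", "Snare Drum", "Hi-hat (Closed)", "Hi-hat (Open)",
    "Crash Cymbal", "Ride Cymbal", "Tom Drum", "Electronic Drum",
    "Electric Bass", "Synth Bass", "Acoustic Bass",
    "Electric Guitar (Clean)", "Electric Guitar (Distorted)", "Acoustic Guitar",
    "Piano (Acoustic)", "Electric Piano", "Synthesizer (Lead)", "Synthesizer (Pad)",
    "Orchestral Strings", "Brass", "Woodwinds", "Vocals", "Vocal Sample",
    "Sound Effects", "Unknown",
}

def validate_mapping(mapping, yamnet_class_names):
    """Hypothetical helper: return a list of problems found in the mapping."""
    problems = []
    known = set(yamnet_class_names)
    for yamnet_name, category in mapping.items():
        if yamnet_name not in known:
            problems.append(f"unknown YAMNet class: {yamnet_name!r}")
        if category not in LEVEL2_CATEGORIES:
            problems.append(f"unknown Level 2 category: {category!r}")
    return problems

# Small in-memory example with a deliberate typo. In the notebook, the mapping
# would be read from the JSON file and the names from the `class_names` list.
example = json.loads('{"Bass drum": "Kick Drum", "Piano": "Piano (Acustic)"}')
issues = validate_mapping(example, ["Bass drum", "Piano", "Snare drum"])
print(issues)
```

An empty list means every key is a known class name and every value is a known category. How unmapped YAMNet classes should be handled is left to the specification.
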
## 🚧 Phase 2: Note Isolation (Coming Next)

**Goal**: Extract isolated notes from MIDI for timbre matching

**Tasks**:
1. Parse MIDI piano roll to identify individual notes
2. Extract corresponding audio segments from original audio
3. Score isolation quality (temporal, pitch, energy)
4. Select top N isolated notes per MIDI instrument

**Status**: Not implemented yet
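
As an illustration of tasks 1 and 2, one simple definition of an isolated note is a note whose time span, padded by a small margin, overlaps no other note. This definition is an assumption, not the scoring scheme from SPECIFICATION.md. `Note`, `find_isolated_notes`, and `to_sample_range` are hypothetical names; the `start`, `end`, and `pitch` fields mirror `pretty_midi.Note`.

```python
from typing import List, NamedTuple, Tuple

class Note(NamedTuple):
    start: float      # seconds
    end: float        # seconds
    pitch: int        # MIDI pitch number
    instrument: int   # index of the MIDI instrument that plays the note

def find_isolated_notes(notes: List[Note], margin: float = 0.05) -> List[Note]:
    """Return notes whose padded span [start - margin, end + margin] overlaps no other note."""
    isolated = []
    for i, note in enumerate(notes):
        lo, hi = note.start - margin, note.end + margin
        overlaps = any(
            other.start < hi and other.end > lo
            for j, other in enumerate(notes)
            if j != i
        )
        if not overlaps:
            isolated.append(note)
    return isolated

def to_sample_range(note: Note, sample_rate: int) -> Tuple[int, int]:
    """Convert a note's time span to (first_sample, end_sample) indices for slicing audio."""
    return int(note.start * sample_rate), int(note.end * sample_rate)

notes = [
    Note(0.0, 0.5, 60, 0),   # alone
    Note(1.0, 1.5, 64, 0),   # overlaps the next note
    Note(1.2, 1.8, 36, 1),
    Note(3.0, 3.4, 67, 0),   # alone
]
isolated = find_isolated_notes(notes)
print([n.pitch for n in isolated])
print(to_sample_range(isolated[0], 16000))
```

The pairwise comparison is quadratic in the number of notes, which is fine for a sketch; a sweep over notes sorted by start time would scale better on long transcriptions.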

## 🚧 Phase 3: Timbre Matching (Coming Next)

**Goal**: Match isolated notes to real instrument timbres using YAMNet

**Tasks**:
1. Run YAMNet inference on isolated audio segments
2. Aggregate predictions per MIDI instrument
3. Vote on most likely real instrument match
4. Generate confidence scores

**Status**: Not implemented yet
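
A minimal sketch of tasks 2 to 4, assuming each isolated note already has a dict of Level 2 category scores (YAMNet scores after the 521→25 mapping). Summing scores and normalizing is one plausible voting scheme, not necessarily the one in the specification; `vote_instrument` is a hypothetical name and the scores are made up.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

def vote_instrument(note_scores: List[Dict[str, float]]) -> Tuple[str, float]:
    """Sum per-note category scores and return (winning category, share of total score)."""
    totals = defaultdict(float)
    for scores in note_scores:
        for category, score in scores.items():
            totals[category] += score
    grand_total = sum(totals.values())
    if grand_total == 0:
        # No evidence at all: fall back to the "Unknown" category.
        return "Unknown", 0.0
    winner = max(totals, key=totals.get)
    return winner, totals[winner] / grand_total

# Three isolated notes from one MIDI instrument (made-up scores).
scores_for_instrument = [
    {"Electric Guitar (Clean)": 0.6, "Acoustic Guitar": 0.3},
    {"Electric Guitar (Clean)": 0.5, "Piano (Acoustic)": 0.2},
    {"Acoustic Guitar": 0.4},
]
category, confidence = vote_instrument(scores_for_instrument)
print(category, round(confidence, 2))
```

Here the confidence is the winner's share of the summed scores, so it is comparable across MIDI instruments with different numbers of isolated notes.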

## 🚧 Phase 4: Output Generation (Coming Next)

**Goal**: Create final outputs - aggregated % and timeline

**Tasks**:
1. **Output B**: Aggregated instrument percentages (e.g., "Electric Guitar: 35%, Piano: 25%, ...")
2. **Output C**: Timeline of instrument presence (e.g., "0-30s: Piano, 30-60s: Guitar + Drums")
3. Visualization: Piano roll with identified instruments
4. Save results to JSON/CSV

**Status**: Not implemented yet
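
Outputs B and C can be sketched from a list of `(start, end, category)` note tuples. Weighting percentages by note duration and using fixed 30-second windows are assumptions for illustration, and the function names are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

LabeledNote = Tuple[float, float, str]  # (start seconds, end seconds, Level 2 category)

def aggregate_percentages(notes: List[LabeledNote]) -> Dict[str, float]:
    """Output B: share of total note duration per category, in percent."""
    durations = defaultdict(float)
    for start, end, category in notes:
        durations[category] += end - start
    total = sum(durations.values())
    if total == 0:
        return {}
    return {cat: 100.0 * d / total for cat, d in durations.items()}

def build_timeline(notes: List[LabeledNote], window: float = 30.0) -> List[Tuple[float, float, List[str]]]:
    """Output C: for each fixed-length window, the sorted categories with a note in it."""
    if not notes:
        return []
    end_time = max(end for _, end, _ in notes)
    timeline = []
    t = 0.0
    while t < end_time:
        active = sorted({cat for s, e, cat in notes if s < t + window and e > t})
        timeline.append((t, t + window, active))
        t += window
    return timeline

notes = [
    (0.0, 20.0, "Piano (Acoustic)"),
    (30.0, 50.0, "Electric Guitar (Clean)"),
    (35.0, 55.0, "Kick Drum"),
]
print(aggregate_percentages(notes))
print(build_timeline(notes))
```

Both results are plain dicts, lists, and tuples, so they can be passed to `json.dump` for task 4.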

## 🚧 Phase 5: Evaluation (Coming Next)

**Goal**: Assess accuracy and tune thresholds

**Tasks**:
1. Manual evaluation against ground truth
2. Calculate accuracy metrics
3. Error analysis (common misclassifications)
4. Threshold tuning for optimal performance

**Target**: 85% accuracy, 80% instrument coverage

**Status**: Not implemented yet
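
The two targets can be computed from per-instrument predictions and manual labels. The definitions below are assumptions, since the notebook does not define them: coverage is the fraction of ground-truth instruments that received a prediction other than `Unknown`, and accuracy is the fraction of those covered instruments whose predicted category matches the label. `evaluate` and the track ids are hypothetical.

```python
from collections import Counter
from typing import Dict, Tuple

def evaluate(predicted: Dict[str, str], truth: Dict[str, str]) -> Tuple[float, float, Counter]:
    """Return (accuracy, coverage, confusion counts) for track-id -> category dicts."""
    # An instrument is "covered" if it has a prediction other than "Unknown".
    covered = {k for k in truth if predicted.get(k, "Unknown") != "Unknown"}
    correct = sum(1 for k in covered if predicted[k] == truth[k])
    accuracy = correct / len(covered) if covered else 0.0
    coverage = len(covered) / len(truth) if truth else 0.0
    # Count (true category, predicted category) pairs for error analysis.
    confusions = Counter(
        (truth[k], predicted[k]) for k in covered if predicted[k] != truth[k]
    )
    return accuracy, coverage, confusions

truth = {"inst0": "Piano (Acoustic)", "inst1": "Electric Bass",
         "inst2": "Snare Drum", "inst3": "Vocals"}
predicted = {"inst0": "Piano (Acoustic)", "inst1": "Synth Bass",
             "inst2": "Snare Drum", "inst3": "Unknown"}
accuracy, coverage, confusions = evaluate(predicted, truth)
print(f"accuracy={accuracy:.2f} coverage={coverage:.2f}")
print(confusions.most_common(1))
```

The confusion counter supports task 3 directly: its most common pairs are the most frequent misclassifications.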

---

## 📋 Implementation Checklist

**Phase 0** (Current):
- ✅ YourMT3 model loading
- ✅ Audio file selection
- ✅ MIDI transcription
- ✅ MIDI analysis and visualization

**Phase 1** (In progress):
- ✅ Install TensorFlow + YAMNet
- ✅ Load pretrained YAMNet model
- ⬜ Create 521→25 instrument mapping
- ✅ Test YAMNet inference

**Phase 2-5**:
- ⬜ Note isolation algorithm
- ⬜ Isolation quality scoring
- ⬜ YAMNet timbre matching
- ⬜ Aggregated output generation
- ⬜ Timeline output generation
- ⬜ Visualization
- ⬜ Accuracy evaluation

**Reference**: See `instrument_recognition/SPECIFICATION.md` for detailed technical specifications

---

*Assisted by Claude Code*