# 🎵 YourMT3 Interactive Music Transcription Test

Interactive notebook for testing the YourMT3 model on audio files and comparing the results.

**Features:**
- Upload or select audio files
- Transcribe with YourMT3
- Play the original audio and the synthesized MIDI for comparison
- Visualize transcription results
- Compare quality

## 1. Setup and Imports

In [None]:
import sys
import os
from pathlib import Path
import glob

# Change to the yourmt3_space directory.
# Capture the launch directory only once: after the first run of this cell the
# cwd is already yourmt3_space, so re-running it must not overwrite
# original_dir (the audio-file search below uses it).
if 'original_dir' not in globals():
    original_dir = os.getcwd()

_space_dir = os.path.join(original_dir, 'yourmt3_space')
if not os.path.isdir(_space_dir):
    print("❌ Error: yourmt3_space directory not found!")
    print("   Run setup_yourmt3_brev.sh first")
else:
    os.chdir(_space_dir)
    # Absolute entries keep imports working regardless of later cwd changes;
    # skip ones already added by a previous run of this cell. Inserting in
    # this order leaves amt/src ahead of the space root, as before.
    for _path in (_space_dir, os.path.join(_space_dir, 'amt', 'src')):
        if _path not in sys.path:
            sys.path.insert(0, _path)

import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
import pretty_midi
from IPython.display import Audio, display, HTML, Markdown
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed, interact_manual

from model_helper import load_model_checkpoint, transcribe

# Environment report: printed once the imports above have run, showing the
# current working directory and whether torch can see a CUDA device.
print("✅ Imports successful!")
print(f"   Working directory: {os.getcwd()}")
_has_cuda = torch.cuda.is_available()
print(f"   CUDA available: {_has_cuda}")
if _has_cuda:
    print(f"   GPU: {torch.cuda.get_device_name(0)}")

## 2. Load YourMT3 Model

Load the model once — this takes about 10–15 seconds.

In [None]:
print("Loading YourMT3 model...")
print("This may take 10-15 seconds...")

# Model configuration (same as demo)
checkpoint = "mc13_256_g4_all_v7_mt3f_sqr_rms_moe_wf4_n8k2_silu_rope_rp_b36_nops@last.ckpt"
project = '2024'
precision = '16'

# Command-line style (flag, value) pairs passed to load_model_checkpoint.
# They are flattened below into the argv-like list it receives, with the
# checkpoint name first.
_model_flags = [
    ('-p', project),
    ('-tk', 'mc13_full_plus_256'),
    ('-dec', 'multi-t5'),
    ('-nl', '26'),
    ('-enc', 'perceiver-tf'),
    ('-sqr', '1'),
    ('-ff', 'moe'),
    ('-wf', '4'),
    ('-nmoe', '8'),
    ('-kmoe', '2'),
    ('-act', 'silu'),
    ('-epe', 'rope'),
    ('-rp', '1'),
    ('-ac', 'spec'),
    ('-hop', '300'),
    ('-atc', '1'),
    ('-pr', precision),
]
args = [checkpoint] + [token for pair in _model_flags for token in pair]

# Prefer the GPU when torch reports one.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_model_checkpoint(args=args, device=device)

print("\n✅ Model loaded successfully!")
print(f"   Device: {device}")
print("   Model: YPTF.MoE+Multi (noPS)")

## 3. Helper Functions

In [None]:
def get_audio_files(directory=None, extensions=None):
    """Return the sorted paths of audio files directly inside ``directory``.

    Args:
        directory: Folder to scan (not recursive). Defaults to the global
            ``original_dir`` captured in the setup cell.
        extensions: Iterable of accepted file extensions such as ``('.wav',)``
            (a leading dot is optional). Matching is case-insensitive.
            Defaults to mp3, wav, flac, m4a and ogg.

    Returns:
        Sorted list of ``os.path.join(directory, name)`` paths. Hidden files
        and sub-directories are skipped; a missing directory yields ``[]``.
    """
    if directory is None:
        directory = original_dir
    if extensions is None:
        extensions = ('.mp3', '.wav', '.flac', '.m4a', '.ogg')
    wanted = {('.' + ext.lstrip('.')).lower() for ext in extensions}

    # List the directory instead of globbing: glob patterns are case-sensitive
    # on Linux (so "SONG.MP3" was missed), and glob would treat "[", "*" or
    # "?" inside the directory path as wildcards.
    if not os.path.isdir(directory):
        return []

    audio_files = []
    for name in os.listdir(directory):
        if name.startswith('.'):  # a glob "*" never matched dotfiles either
            continue
        path = os.path.join(directory, name)
        if os.path.splitext(name)[1].lower() in wanted and os.path.isfile(path):
            audio_files.append(path)

    return sorted(audio_files)

def analyze_midi(midi_path):
    """Load a MIDI file and summarize its contents.

    Args:
        midi_path: Path to a MIDI file readable by ``pretty_midi``.

    Returns:
        ``(stats, midi)`` where ``midi`` is the parsed ``PrettyMIDI`` object and
        ``stats`` is a dict with:

        - ``total_notes``: number of notes across all instruments;
        - ``num_instruments``: number of instruments that contain notes
          (always equal to ``len(stats['instruments'])``);
        - ``duration``: ``midi.get_end_time()`` in seconds;
        - ``instruments``: one dict per non-empty instrument with its
          ``index`` in ``midi.instruments``, ``program``, display ``name``,
          ``notes`` count and ``is_drum`` flag.
    """
    midi = pretty_midi.PrettyMIDI(midi_path)

    instruments = [
        {
            'index': i,
            'program': inst.program,
            # A drum track's program number does not select a GM instrument,
            # so the GM lookup would mislabel it (program 0 reads as
            # "Acoustic Grand Piano").
            'name': ('Drums' if inst.is_drum
                     else pretty_midi.program_to_instrument_name(inst.program)),
            'notes': len(inst.notes),
            'is_drum': inst.is_drum,
        }
        for i, inst in enumerate(midi.instruments)
        if inst.notes
    ]

    stats = {
        'total_notes': sum(entry['notes'] for entry in instruments),
        # Count only instruments that actually carry notes so this number
        # agrees with the per-instrument breakdown; empty tracks are not
        # "detected" instruments.
        'num_instruments': len(instruments),
        'duration': midi.get_end_time(),
        'instruments': instruments,
    }

    return stats, midi

def plot_piano_roll(midi, max_time=None):
    """Draw a piano-roll view of ``midi`` and return the figure.

    Each instrument that has notes in the window is drawn in its own colour
    from the ``tab10`` colormap and gets one legend entry (instruments with
    the same name share an entry). Notes starting after ``max_time`` are
    omitted; notes running past it are clipped at ``max_time``.

    Args:
        midi: A ``pretty_midi.PrettyMIDI`` object.
        max_time: Right edge of the plot in seconds. Defaults to
            ``midi.get_end_time()``.

    Returns:
        The matplotlib ``Figure`` containing the plot.
    """
    fig, ax = plt.subplots(figsize=(14, 6))

    if max_time is None:
        max_time = midi.get_end_time()

    # Plot each instrument with a different color
    colors = plt.cm.tab10(np.linspace(0, 1, len(midi.instruments)))

    for inst, color in zip(midi.instruments, colors):
        visible = [note for note in inst.notes if note.start <= max_time]
        if not visible:
            continue

        # Drum tracks' program numbers are not GM instruments; label them
        # explicitly instead of e.g. "Acoustic Grand Piano" for program 0.
        inst_name = ('Drums' if inst.is_drum
                     else pretty_midi.program_to_instrument_name(inst.program))

        # One barh call per instrument rather than per note: much faster for
        # dense transcriptions, and the legend label is attached exactly once
        # no matter which notes fall inside the window.
        ax.barh([note.pitch for note in visible],
                [min(note.end, max_time) - note.start for note in visible],
                left=[note.start for note in visible],
                height=0.8, color=color, alpha=0.7, label=inst_name)

    ax.set_xlabel('Time (seconds)', fontsize=12)
    ax.set_ylabel('MIDI Pitch', fontsize=12)
    ax.set_title('Piano Roll Visualization', fontsize=14, fontweight='bold')
    ax.set_xlim(0, max_time)
    ax.grid(True, alpha=0.3)

    # Merge entries of instruments that share a display name.
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    if by_label:
        ax.legend(by_label.values(), by_label.keys(),
                  loc='upper right', fontsize=10)

    plt.tight_layout()
    return fig

def midi_to_audio(midi_path, output_path=None, sample_rate=16000):
    """Synthesize a MIDI file to audio for playback.

    Uses ``PrettyMIDI.fluidsynth``, so FluidSynth must be installed.

    Args:
        midi_path: Path to the MIDI file.
        output_path: If given, the synthesized audio is also written there as
            a WAV file.
        sample_rate: Synthesis sample rate in Hz.

    Returns:
        ``(audio, sample_rate)``: the synthesized samples as returned by
        ``fluidsynth`` and the rate they were rendered at.
    """
    midi = pretty_midi.PrettyMIDI(midi_path)
    audio = midi.fluidsynth(fs=sample_rate)

    if output_path:
        from scipy.io import wavfile
        # scipy writes the array's dtype as-is. A float64 array (presumably
        # what fluidsynth returns) becomes a 64-bit float WAV, which many
        # players and DAWs cannot open. 32-bit float WAV is widely supported.
        wavfile.write(output_path, sample_rate, np.asarray(audio, dtype=np.float32))

    return audio, sample_rate

print("✅ Helper functions loaded")

## 4. Interactive Audio File Selector

In [None]:
# Collect candidate audio files from the directory the notebook started in
audio_files = get_audio_files(original_dir)

if audio_files:
    print(f"✅ Found {len(audio_files)} audio files:")
    for number, path in enumerate(audio_files, start=1):
        print(f"   {number}. {os.path.basename(path)}")
else:
    print("❌ No audio files found in directory!")
    print("   Please upload audio files (.mp3, .wav, .flac, etc.)")

# The dropdown shows file names; its value is the full path of the choice
file_selector = widgets.Dropdown(
    options=[(os.path.basename(path), path) for path in audio_files],
    description='Audio File:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='600px'),
)

display(file_selector)

## 5. Transcribe Audio

Click the button to transcribe the selected audio file. Progress and a summary of the results appear below the button.

In [None]:
# Results of the most recent successful transcription; empty until one
# succeeds. The visualization, playback and comparison cells below read it.
transcription_results = {}


def _build_audio_info(audio_path, info):
    """Build the metadata dict handed to ``transcribe`` for ``audio_path``.

    Args:
        audio_path: Path of the audio file being transcribed.
        info: The object returned by ``torchaudio.info(audio_path)``.

    Returns:
        Dict with the file path, track name (file stem), sample rate, bit
        depth (16 when none is reported), channel and frame counts, duration
        truncated to whole seconds, and a placeholder encoding.
    """
    return {
        "filepath": audio_path,
        "track_name": os.path.splitext(os.path.basename(audio_path))[0],
        "sample_rate": int(info.sample_rate),
        "bits_per_sample": int(info.bits_per_sample) if info.bits_per_sample else 16,
        "num_channels": int(info.num_channels),
        "num_frames": int(info.num_frames),
        "duration": int(info.num_frames / info.sample_rate),
        "encoding": 'unknown',
    }


def _print_midi_summary(stats):
    """Print overall counts and a per-instrument breakdown (most notes first)."""
    print()
    print("📊 Transcription Results:")
    print(f"   Total notes: {stats['total_notes']}")
    print(f"   Instruments detected: {stats['num_instruments']}")
    print(f"   Duration: {stats['duration']:.2f}s")
    print()
    print("🎼 Instruments breakdown:")

    for inst in sorted(stats['instruments'], key=lambda x: x['notes'], reverse=True):
        drum_mark = "🥁" if inst['is_drum'] else "🎹"
        print(f"   {drum_mark} {inst['name']}: {inst['notes']} notes (program {inst['program']})")


def transcribe_audio(button):
    """Click handler for the Transcribe button.

    Transcribes the file selected in ``file_selector`` with the loaded
    ``model`` and prints progress and a summary into the ``output`` widget.
    On success, the global ``transcription_results`` gets ``audio_path``,
    ``midi_path``, ``stats``, ``midi`` and ``duration``. On failure it is left
    empty.

    Args:
        button: The clicked ``widgets.Button``. It is disabled while the
            transcription runs.
    """
    global transcription_results

    output.clear_output(wait=True)
    # Drop the previous file's results first. If this run fails, the later
    # cells then report "transcribe first" instead of showing stale data.
    transcription_results = {}

    with output:
        audio_path = file_selector.value

        if not audio_path:
            print("❌ Please select an audio file first!")
            return

        print(f"🎵 Transcribing: {os.path.basename(audio_path)}")
        print("=" * 60)

        button.disabled = True  # discourage repeated clicks during a long run
        try:
            # Metadata is read inside the try so unreadable or unsupported
            # files get the same failure report as transcription errors.
            info = torchaudio.info(audio_path)
            duration = info.num_frames / info.sample_rate

            print("Audio info:")
            print(f"  Duration: {duration:.2f}s")
            print(f"  Sample rate: {info.sample_rate} Hz")
            print(f"  Channels: {info.num_channels}")
            print()

            audio_info = _build_audio_info(audio_path, info)

            # Estimate time
            estimated_time = duration * 0.12  # ~0.12s per second of audio on A10G
            print(f"⏱️  Estimated time: ~{estimated_time:.1f}s")
            print("Transcribing...")

            midi_path = transcribe(model, audio_info)

            print()
            print("✅ Transcription complete!")
            print(f"   MIDI file: {midi_path}")

            stats, midi = analyze_midi(midi_path)
            _print_midi_summary(stats)

            transcription_results = {
                'audio_path': audio_path,
                'midi_path': midi_path,
                'stats': stats,
                'midi': midi,
                'duration': duration,
            }

            print()
            print("=" * 60)
            print("✅ Ready to play and compare!")
            print("   Run the cells below to visualize and play the results")

        except Exception as e:
            print(f"❌ Transcription failed: {e}")
            import traceback
            traceback.print_exc()
        finally:
            button.disabled = False

# Create button and output
transcribe_button = widgets.Button(
    description='🎵 Transcribe Audio',
    button_style='success',
    layout=widgets.Layout(width='200px', height='40px')
)
transcribe_button.on_click(transcribe_audio)

output = widgets.Output()

display(transcribe_button)
display(output)

## 6. Visualize Piano Roll

Visual representation of the transcribed notes

In [None]:
if transcription_results:
    print("🎼 Piano Roll Visualization")
    print("=" * 60)

    max_duration = transcription_results['stats']['duration']

    if transcription_results['stats']['total_notes'] == 0 or max_duration <= 0:
        # Nothing to draw, and a zero-length slider range is not meaningful.
        print("⚠️  The transcription contains no notes to plot.")
    else:
        # Create interactive time range selector.
        # Clamp the lower bound so min <= max for clips shorter than 5 s; a
        # fixed min of 5 is invalid for those. Short clips use a finer step.
        @interact(max_time=widgets.FloatSlider(
            value=min(30, max_duration),
            min=min(5.0, max_duration),
            max=max_duration,
            step=5 if max_duration >= 30 else 1,
            description='Display (s):',
            style={'description_width': 'initial'}
        ))
        def plot_roll(max_time):
            """Redraw the piano roll up to ``max_time`` seconds."""
            plot_piano_roll(transcription_results['midi'], max_time)
            plt.show()
else:
    print("⚠️  Transcribe an audio file first!")

## 7. Play Original Audio

Listen to the original audio file

In [None]:
if not transcription_results:
    print("⚠️  Transcribe an audio file first!")
else:
    print("🎧 Original Audio")
    print("=" * 60)

    audio_path = transcription_results['audio_path']
    waveform, sample_rate = torchaudio.load(audio_path)

    # Average all channels into one when the file is not already mono
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    display(Audio(waveform.numpy(), rate=sample_rate))

## 8. Play Generated MIDI

Listen to the transcribed MIDI (synthesized to audio)

In [None]:
if transcription_results:
    print("🎹 Generated MIDI (Synthesized)")
    print("=" * 60)
    print("Converting MIDI to audio...")

    # Look the path up before the try so the failure branch can always
    # report where the MIDI file is.
    midi_path = transcription_results['midi_path']

    try:
        audio, sample_rate = midi_to_audio(midi_path)

        if len(audio) == 0:
            # An empty waveform (no notes to synthesize) makes IPython's Audio
            # raise while normalizing. That would wrongly trigger the
            # "FluidSynth may not be installed" hint below.
            print("⚠️  Nothing to play: the transcription contains no notes.")
        else:
            print("✅ Synthesis complete!")
            display(Audio(audio, rate=sample_rate))

    except Exception as e:
        print(f"❌ MIDI playback failed: {e}")
        print("   Note: FluidSynth may not be installed")
        print("   Install with: apt-get install fluidsynth")
        print("   pretty_midi also needs the Python binding: pip install pyfluidsynth")
        print(f"   MIDI file saved at: {midi_path}")

else:
    print("⚠️  Transcribe an audio file first!")

## 9. Side-by-Side Comparison

Compare statistics and quality

In [None]:
if transcription_results:
    print("📊 Comparison Summary")
    print("=" * 60)

    stats = transcription_results['stats']
    duration = transcription_results['duration']

    # An empty transcription has an end time of 0; avoid ZeroDivisionError
    # and show "-" like the other not-applicable cells.
    if stats['duration'] > 0:
        density_cell = f"<strong>{stats['total_notes'] / stats['duration']:.1f}</strong> notes/sec"
    else:
        density_cell = "-"

    # Create comparison table
    comparison_html = f"""
    <table style="border-collapse: collapse; width: 100%; margin: 20px 0;">
        <tr style="background-color: #f0f0f0;">
            <th style="border: 1px solid #ddd; padding: 12px; text-align: left;">Metric</th>
            <th style="border: 1px solid #ddd; padding: 12px; text-align: left;">Original Audio</th>
            <th style="border: 1px solid #ddd; padding: 12px; text-align: left;">Transcribed MIDI</th>
        </tr>
        <tr>
            <td style="border: 1px solid #ddd; padding: 10px;">Duration</td>
            <td style="border: 1px solid #ddd; padding: 10px;">{duration:.2f}s</td>
            <td style="border: 1px solid #ddd; padding: 10px;">{stats['duration']:.2f}s</td>
        </tr>
        <tr>
            <td style="border: 1px solid #ddd; padding: 10px;">Notes</td>
            <td style="border: 1px solid #ddd; padding: 10px;">-</td>
            <td style="border: 1px solid #ddd; padding: 10px;"><strong>{stats['total_notes']}</strong></td>
        </tr>
        <tr>
            <td style="border: 1px solid #ddd; padding: 10px;">Instruments</td>
            <td style="border: 1px solid #ddd; padding: 10px;">-</td>
            <td style="border: 1px solid #ddd; padding: 10px;"><strong>{stats['num_instruments']}</strong></td>
        </tr>
        <tr>
            <td style="border: 1px solid #ddd; padding: 10px;">Note Density</td>
            <td style="border: 1px solid #ddd; padding: 10px;">-</td>
            <td style="border: 1px solid #ddd; padding: 10px;">{density_cell}</td>
        </tr>
    </table>
    """

    display(HTML(comparison_html))

    # Quality assessment questions
    print("\n❓ Quality Assessment Questions:")
    print("   1. Do the main melodies match?")
    print("   2. Are the rhythms accurate?")
    print("   3. Are the detected instruments correct?")
    print("   4. Are there missing or extra notes?")
    print("   5. Overall quality score (1-10): ___")

else:
    print("⚠️  Transcribe an audio file first!")

## 10. Download MIDI File

Download the generated MIDI for use in your DAW

In [None]:
if transcription_results:
    # The setup cell chdir's into yourmt3_space, so a relative MIDI path is
    # relative to that directory, not to the remote home directory. Print an
    # absolute path so the scp command works as shown.
    midi_path = os.path.abspath(transcription_results['midi_path'])

    print("💾 MIDI File Location:")
    print("=" * 60)
    print(f"   {midi_path}")
    print()
    print("To download from Brev:")
    print(f"   scp ubuntu@brev-xxx:{midi_path} .")

else:
    print("⚠️  Transcribe an audio file first!")

## 11. Batch Testing (Optional)

Test multiple files and compare results

In [None]:
def batch_transcribe(audio_files_list):
    """Transcribe several audio files one after another.

    A failure on one file is printed and recorded, and processing continues
    with the next file.

    Args:
        audio_files_list: Iterable of audio file paths.

    Returns:
        A list with one dict per input path, in input order. Successful
        entries have ``file`` (base name), ``duration`` (seconds),
        ``notes``, ``instruments`` and ``midi_path``. Failed entries have
        only ``file`` and ``error`` (the exception text).
    """
    summaries = []

    for path in audio_files_list:
        file_name = os.path.basename(path)
        print(f"\n🎵 Processing: {file_name}")
        print("=" * 60)

        try:
            meta = torchaudio.info(path)
            seconds = meta.num_frames / meta.sample_rate
            bit_depth = meta.bits_per_sample

            midi_path = transcribe(model, {
                "filepath": path,
                "track_name": os.path.splitext(file_name)[0],
                "sample_rate": int(meta.sample_rate),
                "bits_per_sample": int(bit_depth) if bit_depth else 16,
                "num_channels": int(meta.num_channels),
                "num_frames": int(meta.num_frames),
                "duration": int(seconds),
                "encoding": 'unknown',
            })

            stats = analyze_midi(midi_path)[0]
            note_count = stats['total_notes']
            instrument_count = stats['num_instruments']

            summaries.append({
                'file': file_name,
                'duration': seconds,
                'notes': note_count,
                'instruments': instrument_count,
                'midi_path': midi_path,
            })
            print(f"   ✅ Complete: {note_count} notes, {instrument_count} instruments")

        except Exception as exc:
            print(f"   ❌ Failed: {exc}")
            summaries.append({'file': file_name, 'error': str(exc)})

    return summaries

# Example usage: uncomment to run a batch test on the first three files.
# batch_results = batch_transcribe(audio_files[:3])
# print("\n📊 Batch Results:")
# for r in batch_results:
#     if 'error' not in r:
#         print(f"   {r['file']}: {r['notes']} notes, {r['instruments']} inst")

print("✅ Batch testing function ready\n   Uncomment the code above to run batch tests")