# MT3 PyTorch - Test Notebook

End-to-end test notebook for MT3 audio-to-MIDI transcription on Brev NVIDIA GPU instances.

## Features
- ✅ Dependency verification
- ✅ Model loading and validation
- ✅ Audio transcription examples
- ✅ Multiple generation strategies
- ✅ Result analysis

## 1. Setup and Installation

First, install any missing dependencies and check whether a GPU is available.

In [None]:
# Install dependencies if needed (ipywidgets is only required for the upload widget in section 5)
!pip install -q torch librosa soundfile note-seq pretty-midi absl-py tqdm numpy scikit-learn ipywidgets

In [None]:
# Check GPU availability
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("Running on CPU")
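
The install cell does not by itself confirm that the packages can be imported. A minimal check could look like the sketch below. `missing_modules` and `REQUIRED_MODULES` are hypothetical names introduced here, not part of MT3, and the list assumes the usual import names for the pip packages above (for example `note-seq` installs `note_seq` and `scikit-learn` installs `sklearn`).

```python
import importlib.util

# Import names for the pip packages installed above. Several differ from the
# package name: note-seq -> note_seq, pretty-midi -> pretty_midi,
# absl-py -> absl, scikit-learn -> sklearn.
REQUIRED_MODULES = [
    "torch", "librosa", "soundfile", "note_seq", "pretty_midi",
    "absl", "tqdm", "numpy", "sklearn",
]

def missing_modules(names):
    """Return the module names that cannot be located in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]

missing = missing_modules(REQUIRED_MODULES)
if missing:
    print(f"Missing modules: {', '.join(missing)}")
else:
    print("All dependencies found")
```

`find_spec` only locates a module without importing it, so this check is fast but will not catch a package that is installed and broken.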

## 2. Import MT3 Modules

Test that all MT3 modules can be imported correctly.

In [None]:
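# Import MT3 modules, recording failures so the final summary is accurate
import_errors = []

try:
    from models import MT3Model, MT3Config
    from models.checkpoint_utils import load_mt3_checkpoint
    print("✅ Models module imported")
except Exception as e:
    import_errors.append("models")
    print(f"❌ Failed to import models: {e}")

try:
    from preprocessing import AudioPreprocessor, AudioPreprocessingConfig
    print("✅ Preprocessing module imported")
except Exception as e:
    import_errors.append("preprocessing")
    print(f"❌ Failed to import preprocessing: {e}")

try:
    from decoder.decoder import MT3TokenDecoder
    print("✅ Decoder module imported")
except Exception as e:
    import_errors.append("decoder")
    print(f"❌ Failed to import decoder: {e}")

try:
    from inference import MT3Inference
    print("✅ Inference module imported")
except Exception as e:
    import_errors.append("inference")
    print(f"❌ Failed to import inference: {e}")

# Only report success if every import actually succeeded
if import_errors:
    print(f"\n⚠️ {len(import_errors)} module(s) failed to import: {', '.join(import_errors)}")
else:
    print("\n🎉 All modules imported successfully!")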

## 3. Configuration

Set up paths and device configuration.

In [None]:
import os

# Configuration
CHECKPOINT_PATH = "mt3_converted.pth"  # Update this path if needed
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
OUTPUT_DIR = "output_midi"

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"Checkpoint path: {CHECKPOINT_PATH}")
print(f"Device: {DEVICE}")
print(f"Output directory: {OUTPUT_DIR}")

# Check if checkpoint exists
if os.path.exists(CHECKPOINT_PATH):
    print(f"✅ Checkpoint found ({os.path.getsize(CHECKPOINT_PATH) / 1e6:.1f} MB)")
else:
    print(f"⚠️ Checkpoint not found at {CHECKPOINT_PATH}")
    print("   Please download or specify correct path")

## 4. Initialize MT3 Inference

Load the model and prepare for inference.

In [None]:
# Initialize inference handler
print("Loading MT3 model...")

inference = MT3Inference(
    checkpoint_path=CHECKPOINT_PATH,
    device=DEVICE,
    num_velocity_bins=1  # Use 1 for simple velocity, 127 for full range
)

print("\n✅ Model loaded successfully!")

# Display model information
info = inference.get_model_info()
print(f"\nModel Information:")
print(f"  Parameters: {info['model']['parameters']['total']:,}")
print(f"  Vocab size: {info['decoder']['num_classes']}")
print(f"  Sample rate: {info['preprocessor']['sample_rate']} Hz")
print(f"  Mel bins: {info['preprocessor']['n_mels']}")
print(f"  Device: {info['device']}")

## 5. Test with Audio File

Choose the audio file to transcribe, either by uploading it (Option A) or by giving its path (Option B).

### Option A: Upload audio file

In [None]:
# For Jupyter environments with file upload widget
from IPython.display import display
import ipywidgets as widgets

uploader = widgets.FileUpload(
    accept='.wav,.mp3,.flac,.m4a',
    multiple=False,
    description='Upload Audio'
)

display(uploader)
print("Upload an audio file (.wav, .mp3, .flac, .m4a)")
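
The widget only holds the uploaded bytes in memory, so they still need to be written to disk before a path can be passed to the transcription cells. The sketch below is one way to do that. `save_upload` is a hypothetical helper, and it assumes the two `FileUpload.value` layouts used by ipywidgets 7.x (a dict keyed by filename) and 8.x (a tuple of per-file records).

```python
import os

def save_upload(value, dest_dir="uploads"):
    """Write the first uploaded file to dest_dir and return its path.

    Returns None if nothing has been uploaded yet.
    """
    if not value:
        return None
    if isinstance(value, dict):
        # ipywidgets 7.x: {filename: {"content": bytes, ...}}
        name, item = next(iter(value.items()))
    else:
        # ipywidgets 8.x: ({"name": ..., "content": memoryview, ...},)
        item = value[0]
        name = item["name"]
    os.makedirs(dest_dir, exist_ok=True)
    # basename() guards against a filename that contains directory parts
    path = os.path.join(dest_dir, os.path.basename(name))
    with open(path, "wb") as f:
        f.write(bytes(item["content"]))
    return path

# Usage, once a file has been selected in the widget:
# AUDIO_FILE = save_upload(uploader.value)
```

If you set `AUDIO_FILE` this way, skip the assignment in Option B so the uploaded path is not overwritten.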

### Option B: Specify audio file path

In [None]:
# Or specify path directly
AUDIO_FILE = "path/to/your/audio.wav"  # Update this path

# Check if file exists
if os.path.exists(AUDIO_FILE):
    print(f"✅ Audio file found: {AUDIO_FILE}")
    print(f"   Size: {os.path.getsize(AUDIO_FILE) / 1e6:.2f} MB")
else:
    print(f"⚠️ Audio file not found: {AUDIO_FILE}")
    print("   Please update AUDIO_FILE path or use upload widget above")

## 6. Transcribe Audio → MIDI

### Basic Transcription (Greedy Decoding)

In [None]:
import time

# Basic transcription with greedy decoding
output_path = os.path.join(OUTPUT_DIR, "transcription_greedy.mid")

print("🎵 Transcribing audio (greedy decoding)...")
start_time = time.time()

result = inference.transcribe(
    audio_path=AUDIO_FILE,
    output_path=output_path,
    max_length=1024,
    do_sample=False  # Greedy decoding
)

elapsed = time.time() - start_time

print(f"\n✅ Transcription complete!")
print(f"   Output: {result['midi_path']}")
print(f"   Notes: {result['num_notes']}")
print(f"   Duration: {result['duration']:.2f}s")
print(f"   Processing time: {elapsed:.2f}s")
print(f"   Speed: {result['duration']/elapsed:.2f}x realtime")

### Transcription with Sampling (Non-Deterministic)

In [None]:
# Transcription with temperature sampling
output_path = os.path.join(OUTPUT_DIR, "transcription_sampling.mid")

print("🎵 Transcribing audio (temperature sampling)...")
start_time = time.time()

result = inference.transcribe(
    audio_path=AUDIO_FILE,
    output_path=output_path,
    max_length=1024,
    do_sample=True,
    temperature=0.8,
    top_p=0.9
)

elapsed = time.time() - start_time

print(f"\n✅ Transcription complete!")
print(f"   Output: {result['midi_path']}")
print(f"   Notes: {result['num_notes']}")
print(f"   Processing time: {elapsed:.2f}s")
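
For intuition about what `temperature` and `top_p` control, here is a generic, pure-Python sketch of temperature scaling followed by nucleus (top-p) filtering over a toy distribution. It illustrates the standard technique only. How `MT3Inference` applies these parameters internally is not shown in this notebook, and the function names are hypothetical.

```python
import math

def temperature_softmax(logits, temperature=1.0):
    """Convert logits to probabilities; temperature < 1 sharpens the distribution."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def top_p_filter(probs, top_p=0.9):
    """Keep the smallest set of most likely tokens whose mass reaches top_p, renormalised."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    filtered = [0.0] * len(probs)
    for i in kept:
        filtered[i] = probs[i] / mass
    return filtered

toy_logits = [2.0, 1.0, 0.0, -1.0]
probs = temperature_softmax(toy_logits, temperature=0.8)
print([round(p, 3) for p in top_p_filter(probs, top_p=0.9)])
```

Lower temperatures and smaller `top_p` values both push sampling toward the greedy result, so they trade variety for repeatability.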

### Long Audio Transcription (with Chunking)

In [None]:
# For audio files longer than 30 seconds
output_path = os.path.join(OUTPUT_DIR, "transcription_long.mid")

print("🎵 Transcribing long audio (with chunking)...")
start_time = time.time()

result = inference.transcribe_long_audio(
    audio_path=AUDIO_FILE,
    output_path=output_path,
    chunk_length=256,  # ~30 seconds per chunk
    max_length=1024
)

elapsed = time.time() - start_time

print(f"\n✅ Transcription complete!")
print(f"   Output: {result['midi_path']}")
print(f"   Notes: {result['num_notes']}")
print(f"   Chunks: {result['num_chunks']}")
print(f"   Processing time: {elapsed:.2f}s")
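
The unit of `chunk_length` is not stated in this notebook. If it counts spectrogram frames, the duration of a chunk depends on the preprocessor's hop length and sample rate, so it is worth checking the arithmetic against your configuration. The sketch below shows that calculation and how a frame sequence splits into chunks. Both helpers are hypothetical, and the hop length of 128 samples is an example value, not one read from this model.

```python
def chunk_bounds(n_frames, chunk_length):
    """Return (start, end) frame indices for consecutive non-overlapping chunks."""
    return [
        (start, min(start + chunk_length, n_frames))
        for start in range(0, n_frames, chunk_length)
    ]

def chunk_seconds(chunk_length, hop_length, sample_rate):
    """Duration of one chunk in seconds, assuming chunk_length counts frames."""
    return chunk_length * hop_length / sample_rate

# Example values only: 600 frames, hop of 128 samples, 16 kHz audio
print(chunk_bounds(600, 256))
print(f"{chunk_seconds(256, hop_length=128, sample_rate=16000):.3f} s per chunk")
```

The sample rate is available from `inference.get_model_info()` as shown in section 4. The hop length would have to come from the preprocessing configuration.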

## 7. Analyze Results

Inspect the generated MIDI file.

In [None]:
import numpy as np
import note_seq
import pretty_midi

# Inspect the result of the most recent transcription cell
midi_path = result['midi_path']
ns = result['note_sequence']

print(f"📊 MIDI Analysis: {os.path.basename(midi_path)}")
print(f"\nBasic Statistics:")
print(f"  Total notes: {len(ns.notes)}")
print(f"  Duration: {ns.total_time:.2f}s")

# Count notes per program; drums are tracked separately
programs = {}
drums = []
for note in ns.notes:
    if note.is_drum:
        drums.append(note.pitch)
    else:
        programs[note.program] = programs.get(note.program, 0) + 1

print(f"\nInstruments detected:")
if programs:
    for program, count in sorted(programs.items()):
        print(f"  Program {program}: {count} notes")
if drums:
    print(f"  Drums: {len(drums)} hits")
if not programs and not drums:
    print("  None")

# Pitch range (drum pitches identify the drum sound, so they are excluded)
pitches = [note.pitch for note in ns.notes if not note.is_drum]
if pitches:
    print(f"\nPitch range:")
    print(f"  Lowest: {min(pitches)} ({pretty_midi.note_number_to_name(min(pitches))})")
    print(f"  Highest: {max(pitches)} ({pretty_midi.note_number_to_name(max(pitches))})")

# Note duration statistics
durations = [note.end_time - note.start_time for note in ns.notes]
if durations:
    print(f"\nNote durations:")
    print(f"  Mean: {np.mean(durations):.3f}s")
    print(f"  Min: {min(durations):.3f}s")
    print(f"  Max: {max(durations):.3f}s")
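
Because this is a test notebook, it can also help to flag notes that look malformed. The sketch below is a hypothetical helper, not part of MT3. It assumes each note exposes `start_time`, `end_time`, and `pitch` attributes, as the notes used in the cell above do.

```python
from types import SimpleNamespace

def find_suspect_notes(notes, min_pitch=0, max_pitch=127):
    """Return (index, reason) pairs for notes that look malformed."""
    suspects = []
    for i, note in enumerate(notes):
        if note.end_time <= note.start_time:
            suspects.append((i, "non-positive duration"))
        elif not (min_pitch <= note.pitch <= max_pitch):
            suspects.append((i, "pitch outside MIDI range"))
    return suspects

# Stand-in notes for illustration; in the notebook you would pass ns.notes
example_notes = [
    SimpleNamespace(start_time=0.0, end_time=0.5, pitch=60),
    SimpleNamespace(start_time=1.0, end_time=1.0, pitch=64),
]
for index, reason in find_suspect_notes(example_notes):
    print(f"Note {index}: {reason}")
```

An empty result does not prove the transcription is musically correct. It only rules out these two structural problems.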

## 8. Batch Processing

Process multiple audio files at once.

In [None]:
# Example: Process multiple files
audio_files = [
    "path/to/audio1.wav",
    "path/to/audio2.wav",
    "path/to/audio3.wav"
]

# Filter existing files
audio_files = [f for f in audio_files if os.path.exists(f)]

if audio_files:
    print(f"Processing {len(audio_files)} files...\n")
    
    results = inference.transcribe_batch(
        audio_files=audio_files,
        output_dir=OUTPUT_DIR
    )
    
    # Summary
    print(f"\n📊 Batch Processing Summary:")
    successful = [r for r in results if 'num_notes' in r]
    failed = [r for r in results if 'error' in r]
    
    print(f"  Successful: {len(successful)}/{len(audio_files)}")
    print(f"  Failed: {len(failed)}/{len(audio_files)}")
    
    if successful:
        total_notes = sum(r['num_notes'] for r in successful)
        print(f"  Total notes: {total_notes}")
else:
    print("⚠️ No audio files found. Update the audio_files list.")

## 9. Compare Generation Strategies

Test different decoding strategies on the same audio.

In [None]:
strategies = [
    {"name": "Greedy", "do_sample": False},
    {"name": "Temperature 0.5", "do_sample": True, "temperature": 0.5},
    {"name": "Temperature 0.8", "do_sample": True, "temperature": 0.8},
    {"name": "Top-k 50", "do_sample": True, "top_k": 50},
    {"name": "Top-p 0.9", "do_sample": True, "top_p": 0.9},
]

print("🎯 Testing different generation strategies...\n")

comparison = []
for strategy in strategies:
    name = strategy["name"]
    # Copy the decoding parameters instead of popping "name", so this cell can be re-run
    params = {k: v for k, v in strategy.items() if k != "name"}
    output_path = os.path.join(OUTPUT_DIR, f"strategy_{name.replace(' ', '_').lower()}.mid")
    
    print(f"Testing: {name}...")
    start = time.time()
    
    result = inference.transcribe(
        audio_path=AUDIO_FILE,
        output_path=output_path,
        max_length=512,  # Shorter for comparison
        **params
    )
    
    elapsed = time.time() - start
    comparison.append({
        "strategy": name,
        "notes": result['num_notes'],
        "time": elapsed
    })
    
print(f"\n📊 Strategy Comparison:")
print(f"{'Strategy':<20} {'Notes':<10} {'Time (s)':<10}")
print("-" * 40)
for c in comparison:
    print(f"{c['strategy']:<20} {c['notes']:<10} {c['time']:<10.2f}")

## 10. Download Results

List all generated MIDI files.

In [None]:
# List all MIDI files in output directory
import glob

midi_files = glob.glob(os.path.join(OUTPUT_DIR, "*.mid"))

print(f"📁 Generated MIDI files ({len(midi_files)}):")
for midi_file in midi_files:
    size = os.path.getsize(midi_file) / 1024
    print(f"  • {os.path.basename(midi_file)} ({size:.1f} KB)")

print(f"\n💾 Files saved in: {OUTPUT_DIR}/")
print("   You can download them from the file browser or use:")
print(f"   !zip -r output_midi.zip {OUTPUT_DIR}")

## 11. Performance Benchmarking

Test model performance on different audio lengths.

In [None]:
# Benchmark inference speed
import librosa
import numpy as np  # imported here so this cell does not depend on the analysis cell

if os.path.exists(AUDIO_FILE):
    # Load audio to get duration
    audio, sr = librosa.load(AUDIO_FILE, sr=16000)
    audio_duration = len(audio) / sr
    
    print(f"📊 Performance Benchmark")
    print(f"Audio duration: {audio_duration:.2f}s\n")
    
    # Run multiple times for average
    times = []
    for i in range(3):
        start = time.time()
        result = inference.transcribe(
            audio_path=AUDIO_FILE,
            output_path=os.path.join(OUTPUT_DIR, f"bench_{i}.mid"),
            max_length=512,
            do_sample=False
        )
        elapsed = time.time() - start
        times.append(elapsed)
        print(f"  Run {i+1}: {elapsed:.2f}s ({audio_duration/elapsed:.2f}x realtime)")
    
    avg_time = np.mean(times)
    print(f"\nAverage: {avg_time:.2f}s ({audio_duration/avg_time:.2f}x realtime)")
else:
    print("⚠️ Audio file not found for benchmarking")
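
Two things can skew timings like these on a GPU: the first run often includes one-time warm-up cost, and CUDA work is launched asynchronously, so the clock may be read before the device has finished. The sketch below is a generic timing helper that addresses both. `benchmark` is a hypothetical name, and whether `inference.transcribe` already synchronizes internally is not known from this notebook.

```python
import statistics
import time

def benchmark(fn, runs=3, warmup=1, sync=None):
    """Time fn() over `runs` calls after `warmup` untimed calls.

    sync, if given, is called before each clock read, for example
    torch.cuda.synchronize when running on a GPU.
    """
    for _ in range(warmup):
        fn()
    times = []
    for _ in range(runs):
        if sync is not None:
            sync()
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        times.append(time.perf_counter() - start)
    return {"mean": statistics.mean(times), "min": min(times), "max": max(times), "times": times}

# Usage with the objects defined earlier in this notebook:
# stats = benchmark(
#     lambda: inference.transcribe(audio_path=AUDIO_FILE,
#                                  output_path=os.path.join(OUTPUT_DIR, "bench.mid"),
#                                  max_length=512, do_sample=False),
#     sync=torch.cuda.synchronize if torch.cuda.is_available() else None,
# )
# print(f"Mean: {stats['mean']:.2f}s")
```

`time.perf_counter` is used instead of `time.time` because it is monotonic and intended for measuring short intervals.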

## 12. Cleanup

Optional: Clean up test files.

In [None]:
# Uncomment to clean up test files
# import shutil
# shutil.rmtree(OUTPUT_DIR)
# print(f"✅ Cleaned up {OUTPUT_DIR}/")

## Summary

This notebook tested:
- ✅ MT3 model loading and initialization
- ✅ Audio transcription with multiple strategies
- ✅ Long audio processing with chunking
- ✅ Batch processing capabilities
- ✅ Performance benchmarking

### Next Steps
1. Try with your own audio files
2. Experiment with different generation parameters
3. Compare results with different velocity bins (1 vs 127)
4. Fine-tune for your specific use case

### Resources
- GitHub: https://github.com/Pyzeur-ColonyLab/MT3_2025
- Documentation: See README.md and module-specific READMEs
- MT3 Paper: https://arxiv.org/abs/2111.03017