# Piano MIDI Generation - Data Preprocessing

This notebook handles preprocessing of the ARIA MIDI dataset for training a transformer model.

## Overview

1. **Load and analyze metadata** - Understand field coverage
2. **Load MIDI files** - Read and parse MIDI data
3. **Convert MIDI to tokens** - Tokenize MIDI events
4. **Combine with metadata** - Create full training sequences
5. **Save processed data** - Prepare for training

## Dataset Structure

- Metadata: `aria-midi-v1-deduped-ext/metadata.json`
- MIDI files: `aria-midi-v1-deduped-ext/data/{aa-zz}/*.mid`
- File naming: `{ID}_{audio_index}.mid` (e.g., `000002_0.mid`)


In [None]:
# Install required packages

%pip install mido pretty_midi tqdm


Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 25.2 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [13]:
# Import libraries
import json
import mido
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from collections import defaultdict, Counter
import re

# Optional: tqdm for progress bars (fallback if not installed)
try:
    from tqdm import tqdm
except ImportError:
    # Simple fallback progress indicator
    def tqdm(iterable, desc=None, **kwargs):
        if desc:
            print(f"{desc}...")
        return iterable

print("‚úÖ Libraries imported successfully")


‚úÖ Libraries imported successfully


## Step 1: Metadata Analysis

First, we analyze metadata fields to understand which fields have sufficient coverage.


In [14]:
# Load metadata
metadata_path = Path("aria-midi-v1-deduped-ext/metadata.json")

print(f"Loading metadata from: {metadata_path}")
with open(metadata_path, 'r', encoding='utf-8') as f:
    metadata = json.load(f)

print(f"‚úÖ Loaded {len(metadata):,} entries")
print(f"Sample entry keys: {list(metadata.keys())[:5]}")

# Show sample entry
sample_id = list(metadata.keys())[0]
print(f"\nSample entry (ID: {sample_id}):")
print(json.dumps(metadata[sample_id], indent=2))


Loading metadata from: aria-midi-v1-deduped-ext\metadata.json
‚úÖ Loaded 371,053 entries
Sample entry keys: ['2', '3', '4', '6', '8']

Sample entry (ID: 2):
{
  "metadata": {
    "composer": "strauss",
    "form": "waltz",
    "performer": "cziffra",
    "genre": "classical",
    "music_period": "classical"
  },
  "audio_scores": {
    "0": 0.9902
  }
}


### Field Inclusion Decision

Based on coverage threshold (‚â•30%):


In [15]:
# Determine which fields to include/exclude based on coverage
COVERAGE_THRESHOLD = 30.0

included_fields = []
excluded_fields = []

for field in sorted(analysis['fields']):
    count = analysis['presence'][field]
    coverage = (count / analysis['total']) * 100
    
    if coverage >= COVERAGE_THRESHOLD:
        included_fields.append((field, coverage, count))
    else:
        excluded_fields.append((field, coverage, count))

print("=" * 60)
print("FIELD INCLUSION DECISION (Threshold: ‚â•30% coverage)")
print("=" * 60)

print(f"\n‚úÖ INCLUDED FIELDS ({len(included_fields)} fields):")
print(f"{'Field':<20} {'Coverage':<12} {'Count':<15} {'Decision'}")
print("-" * 60)
for field, coverage, count in sorted(included_fields, key=lambda x: x[1], reverse=True):
    decision = "REQUIRED" if coverage >= 70 else "RECOMMENDED" if coverage >= 50 else "OPTIONAL"
    print(f"{field:<20} {coverage:>6.2f}%     {count:>12,}     {decision}")

print(f"\n‚ùå EXCLUDED FIELDS ({len(excluded_fields)} fields - too sparse):")
print(f"{'Field':<20} {'Coverage':<12} {'Count':<15}")
print("-" * 60)
for field, coverage, count in sorted(excluded_fields, key=lambda x: x[1], reverse=True):
    print(f"{field:<20} {coverage:>6.2f}%     {count:>12,}")

print(f"\nüìä Summary:")
print(f"  Total fields analyzed: {len(analysis['fields'])}")
print(f"  Included: {len(included_fields)} ({len(included_fields)/len(analysis['fields'])*100:.1f}%)")
print(f"  Excluded: {len(excluded_fields)} ({len(excluded_fields)/len(analysis['fields'])*100:.1f}%)")


FIELD INCLUSION DECISION (Threshold: ‚â•30% coverage)

‚úÖ INCLUDED FIELDS (3 fields):
Field                Coverage     Count           Decision
------------------------------------------------------------
genre                 74.64%          276,948     REQUIRED
composer              39.13%          145,186     OPTIONAL
music_period          38.99%          144,691     OPTIONAL

‚ùå EXCLUDED FIELDS (6 fields - too sparse):
Field                Coverage     Count          
------------------------------------------------------------
form                  16.40%           60,848
difficulty            11.21%           41,587
performer              6.85%           25,401
opus                   6.58%           24,399
key_signature          6.04%           22,403
piece_number           4.70%           17,457

üìä Summary:
  Total fields analyzed: 9
  Included: 3 (33.3%)
  Excluded: 6 (66.7%)


In [16]:
# Analyze metadata field coverage
def analyze_metadata_fields(metadata_dict, sample_size=None):
    """Analyze which metadata fields exist and their coverage"""
    all_fields = set()
    field_presence = defaultdict(int)
    field_values = defaultdict(Counter)
    total_entries = 0
    empty_count = 0
    
    entries = list(metadata_dict.items())
    if sample_size:
        entries = entries[:sample_size]
    
    for entry_id, entry_data in entries:
        total_entries += 1
        metadata_fields = entry_data.get('metadata', {})
        
        if not metadata_fields:
            empty_count += 1
            continue
        
        for field, value in metadata_fields.items():
            all_fields.add(field)
            field_presence[field] += 1
            
            if isinstance(value, (int, float)):
                field_values[field][str(value)] += 1
            elif isinstance(value, str):
                field_values[field][value.lower()] += 1
    
    return {
        'total': total_entries,
        'empty': empty_count,
        'fields': all_fields,
        'presence': dict(field_presence),
        'values': dict(field_values)
    }

# Run analysis
print("Analyzing metadata fields...")
analysis = analyze_metadata_fields(metadata)

print(f"\nüìä Metadata Analysis Results:")
print(f"Total entries: {analysis['total']:,}")
print(f"Empty metadata: {analysis['empty']:,} ({analysis['empty']/analysis['total']*100:.2f}%)")
print(f"Unique fields: {len(analysis['fields'])}")

# Print coverage for each field
print(f"\n{'Field':<20} {'Count':<12} {'Coverage':<12}")
print("-" * 44)
for field in sorted(analysis['fields']):
    count = analysis['presence'][field]
    coverage = (count / analysis['total']) * 100
    print(f"{field:<20} {count:<12,} {coverage:>6.2f}%")


Analyzing metadata fields...

üìä Metadata Analysis Results:
Total entries: 371,053
Empty metadata: 52,060 (14.03%)
Unique fields: 9

Field                Count        Coverage    
--------------------------------------------
composer             145,186       39.13%
difficulty           41,587        11.21%
form                 60,848        16.40%
genre                276,948       74.64%
key_signature        22,403         6.04%
music_period         144,691       38.99%
opus                 24,399         6.58%
performer            25,401         6.85%
piece_number         17,457         4.70%


## Step 2: Metadata Tokenizer

Create a tokenizer that converts metadata dictionaries to tokens, handling missing fields gracefully.


In [17]:
# Metadata Tokenizer Class
# Based on coverage analysis - only fields with ‚â•30% coverage are included

FIELD_COVERAGE_THRESHOLD = 30.0  # Minimum coverage % to include a field

# Field Decision Matrix (based on analysis):
# ‚úÖ INCLUDE (>30% coverage):
#   - genre: 74.64% - REQUIRED (highest coverage)
#   - music_period: 38.99% - RECOMMENDED
#   - composer: 39.13% - OPTIONAL (requires normalization, high cardinality)
#
# ‚ùå EXCLUDE (<30% coverage - too sparse):
#   - form: 16.40%
#   - difficulty: 11.21%
#   - key_signature: 6.04%
#   - opus: 6.58%
#   - performer: 6.85%
#   - piece_number: 4.70%

class MetadataTokenizer:
    """
    Converts metadata to tokens based on field coverage analysis.
    
    Only includes fields with ‚â•30% coverage to ensure sufficient training signal.
    Fields below threshold are automatically excluded.
    
    Included fields:
    - genre (74.64%) - Always included if present
    - music_period (38.99%) - Included if present
    - composer (39.13%) - Optional, only top-N composers (requires normalization)
    """
    
    def __init__(self, include_composer=True, top_n_composers=100):
        self.include_composer = include_composer
        
        # Valid genres (74.64% coverage)
        self.valid_genres = {
            'classical', 'pop', 'soundtrack', 'jazz', 'rock', 
            'folk', 'ambient', 'ragtime', 'blues', 'atonal'
        }
        
        # Valid music periods (38.99% coverage)
        self.valid_periods = {
            'contemporary', 'modern', 'romantic', 'classical', 
            'baroque', 'impressionist'
        }
        
        # Top composers (load from actual analysis)
        self.top_composers = self._load_top_composers(top_n_composers)
    
    def _load_top_composers(self, n):
        """Load top N composers"""
        # Based on dataset analysis - you can extend this
        top = {
            'hisaishi', 'satie', 'yiruma', 'einaudi', 'joplin',
            'chopin', 'beethoven', 'bach', 'mozart', 'debussy',
            'schubert', 'schumann', 'liszt', 'rachmaninoff', 'tchaikovsky',
            'ravel', 'poulenc', 'faure', 'bartok'
        }
        return {self._normalize_composer(c) for c in top}
    
    def _normalize_composer(self, composer):
        """Normalize composer name"""
        if not composer:
            return ""
        normalized = composer.lower().strip()
        # Remove accents
        normalized = normalized.replace('√©', 'e').replace('√®', 'e')
        normalized = normalized.replace('√°', 'a').replace('√†', 'a')
        normalized = normalized.replace('√≠', 'i').replace('√¨', 'i')
        normalized = normalized.replace('√≥', 'o').replace('√≤', 'o')
        normalized = normalized.replace('√∫', 'u').replace('√π', 'u')
        normalized = normalized.replace('√±', 'n')
        # Remove special chars
        normalized = re.sub(r'[^a-z0-9\s-]', '', normalized)
        normalized = re.sub(r'\s+', ' ', normalized).strip()
        return normalized
    
    def metadata_to_tokens(self, metadata, include_start=True):
        """Convert metadata dict to token list"""
        tokens = []
        if include_start:
            tokens.append("START")
        
        # Genre (74.64% coverage)
        if metadata.get('genre'):
            genre = metadata['genre'].lower().strip()
            if genre in self.valid_genres:
                tokens.append(f"GENRE:{genre}")
        
        # Music period (38.99% coverage)
        if metadata.get('music_period'):
            period = metadata['music_period'].lower().strip()
            if period in self.valid_periods:
                tokens.append(f"PERIOD:{period}")
        
        # Composer (39.13% coverage) - optional
        if self.include_composer and metadata.get('composer'):
            composer = self._normalize_composer(metadata['composer'])
            if composer in self.top_composers:
                tokens.append(f"COMPOSER:{composer}")
        
        return tokens

# Test the tokenizer
tokenizer = MetadataTokenizer(include_composer=True)

test_metadata = [
    {"genre": "classical", "music_period": "romantic", "composer": "Chopin"},
    {"genre": "jazz", "music_period": "modern"},  # No composer
    {"genre": "pop"},  # Minimal
    {}  # Empty
]

print("Testing MetadataTokenizer:")
for i, meta in enumerate(test_metadata, 1):
    tokens = tokenizer.metadata_to_tokens(meta)
    print(f"  Test {i}: {meta}")
    print(f"    ‚Üí Tokens: {tokens}")
    print()


Testing MetadataTokenizer:
  Test 1: {'genre': 'classical', 'music_period': 'romantic', 'composer': 'Chopin'}
    ‚Üí Tokens: ['START', 'GENRE:classical', 'PERIOD:romantic', 'COMPOSER:chopin']

  Test 2: {'genre': 'jazz', 'music_period': 'modern'}
    ‚Üí Tokens: ['START', 'GENRE:jazz', 'PERIOD:modern']

  Test 3: {'genre': 'pop'}
    ‚Üí Tokens: ['START', 'GENRE:pop']

  Test 4: {}
    ‚Üí Tokens: ['START']



## Step 3: MIDI to Tokens Conversion

Convert MIDI files to token sequences using event-based representation.


In [7]:
# MIDI to Tokens Converter
class MIDITokenizer:
    """
    Converts MIDI files to token sequences.
    Uses event-based representation: TIME_SHIFT, NOTE_ON, NOTE_OFF, VELOCITY
    """
    
    def __init__(self, time_quantization=10):
        """
        Args:
            time_quantization: Quantize time to this many ticks (smaller = finer resolution)
        """
        self.time_quantization = time_quantization
    
    def midi_to_tokens(self, midi_path: Path) -> List[str]:
        """
        Convert MIDI file to token sequence
        
        Returns list of tokens like: ["TIME_SHIFT:0", "NOTE_ON:60", "VELOCITY:80", ...]
        """
        try:
            mid = mido.MidiFile(midi_path)
            tokens = []
            current_time = 0
            
            # Process all tracks
            for track in mid.tracks:
                for msg in track:
                    # Accumulate time
                    current_time += int(msg.time)
                    
                    # Quantize time
                    quantized_time = (current_time // self.time_quantization) * self.time_quantization
                    
                    # Note On (velocity > 0)
                    if msg.type == 'note_on' and msg.velocity > 0:
                        if quantized_time > 0:
                            tokens.append(f"TIME_SHIFT:{quantized_time}")
                        tokens.append(f"NOTE_ON:{msg.note}")
                        tokens.append(f"VELOCITY:{msg.velocity}")
                        current_time = 0
                    
                    # Note Off (or Note On with velocity 0)
                    elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0):
                        if quantized_time > 0:
                            tokens.append(f"TIME_SHIFT:{quantized_time}")
                        tokens.append(f"NOTE_OFF:{msg.note}")
                        current_time = 0
            
            return tokens
        
        except Exception as e:
            print(f"Error processing {midi_path}: {e}")
            return []
    
    def tokens_to_vocab_mapping(self, tokens_list: List[List[str]]) -> Dict[str, int]:
        """Create vocabulary mapping from all tokens"""
        all_tokens = set()
        for tokens in tokens_list:
            all_tokens.update(tokens)
        
        # Special tokens
        vocab = {
            "<PAD>": 0,
            "<UNK>": 1,
            "<START>": 2,
            "<END>": 3,
        }
        
        # Add all unique tokens
        for token in sorted(all_tokens):
            if token not in vocab:
                vocab[token] = len(vocab)
        
        return vocab

# Test MIDI tokenizer
midi_tokenizer = MIDITokenizer(time_quantization=10)

# Test with a sample MIDI file (if available)
data_path = Path("aria-midi-v1-deduped-ext/data")
if data_path.exists():
    # Find first MIDI file
    midi_files = list(data_path.glob("**/*.mid"))
    if midi_files:
        test_midi = midi_files[0]
        print(f"Testing with: {test_midi.name}")
        tokens = midi_tokenizer.midi_to_tokens(test_midi)
        print(f"  Generated {len(tokens)} tokens")
        print(f"  First 20 tokens: {tokens[:20]}")
    else:
        print("No MIDI files found for testing")
else:
    print("Data directory not found - will process during full preprocessing")


Testing with: 000002_0.mid
  Generated 37741 tokens
  First 20 tokens: ['TIME_SHIFT:1280', 'NOTE_ON:56', 'VELOCITY:60', 'TIME_SHIFT:30', 'NOTE_ON:58', 'VELOCITY:70', 'TIME_SHIFT:20', 'NOTE_OFF:56', 'TIME_SHIFT:10', 'NOTE_ON:32', 'VELOCITY:85', 'NOTE_ON:44', 'VELOCITY:100', 'TIME_SHIFT:10', 'NOTE_ON:39', 'VELOCITY:65', 'TIME_SHIFT:10', 'NOTE_ON:56', 'VELOCITY:90', 'TIME_SHIFT:800']


## Step 5: Balanced Dataset Sampling

**Problem:** Using full dataset creates bias:
1. **Composer bias**: Only top-N composers get tokens, rest get no composer ‚Üí most files have no composer token
2. **Empty metadata bias**: 14% of files have empty metadata ‚Üí too many unconditional examples
3. **Genre bias**: Some genres may dominate

**Solution:** Create a balanced subset that ensures:
- Representative composer distribution
- Limited empty metadata samples
- Balanced genre distribution
- Similar number of examples per category


In [19]:
# Balanced sampling configuration
SAMPLING_CONFIG = {
    'target_samples_per_category': 1000,  # Target samples per composer/genre combination
    'max_empty_metadata_ratio': 0.05,  # Max 5% of dataset with empty metadata
    'max_per_composer': 500,  # Max samples per composer (top-N)
    'max_per_genre': None,  # None = no limit, or set number to balance
    'composer_strategy': 'balanced',  # 'balanced' or 'exclude'
    # If 'exclude': Don't use composer field at all (avoids bias completely)
    # If 'balanced': Include top-N composers with balanced samples
}

def analyze_metadata_distribution(metadata_dict):
    """Analyze distribution of metadata fields for balancing"""
    stats = {
        'by_composer': defaultdict(int),
        'by_genre': defaultdict(int),
        'by_period': defaultdict(int),
        'by_composer_genre': defaultdict(int),
        'empty_metadata': [],
        'with_composer': [],
        'no_composer': [],
    }
    
    for entry_id, entry_data in metadata_dict.items():
        metadata = entry_data.get('metadata', {})
        audio_scores = entry_data.get('audio_scores', {})
        
        # Check quality
        if not audio_scores:
            continue
        best_score = max(audio_scores.values())
        if best_score < 0.97:  # Quality threshold
            continue
        
        genre = metadata.get('genre', '').lower() if metadata.get('genre') else None
        composer = metadata.get('composer', '').lower() if metadata.get('composer') else None
        period = metadata.get('music_period', '').lower() if metadata.get('music_period') else None
        
        if not metadata:
            stats['empty_metadata'].append(entry_id)
        else:
            if genre:
                stats['by_genre'][genre] += 1
            if composer:
                stats['by_composer'][composer] += 1
                stats['with_composer'].append(entry_id)
            else:
                stats['no_composer'].append(entry_id)
            if period:
                stats['by_period'][period] += 1
            if composer and genre:
                stats['by_composer_genre'][(composer, genre)] += 1
    
    return stats

# Analyze distribution
print("Analyzing metadata distribution for balancing...")
distribution = analyze_metadata_distribution(metadata)

print(f"\nüìä Distribution Analysis:")
print(f"  Files with composer: {len(distribution['with_composer']):,}")
print(f"  Files without composer: {len(distribution['no_composer']):,}")
print(f"  Empty metadata: {len(distribution['empty_metadata']):,}")

print(f"\n  Top 10 composers:")
for composer, count in sorted(distribution['by_composer'].items(), key=lambda x: x[1], reverse=True)[:10]:
    print(f"    {composer}: {count:,}")

print(f"\n  Genre distribution:")
for genre, count in sorted(distribution['by_genre'].items(), key=lambda x: x[1], reverse=True):
    print(f"    {genre}: {count:,}")


Analyzing metadata distribution for balancing...

üìä Distribution Analysis:
  Files with composer: 91,518
  Files without composer: 103,338
  Empty metadata: 28,507

  Top 10 composers:
    hisaishi: 2,013
    satie: 1,410
    yiruma: 1,040
    bach: 1,018
    joplin: 922
    handel: 910
    einaudi: 902
    uematsu: 788
    gershwin: 691
    sakamoto: 682

  Genre distribution:
    classical: 70,933
    pop: 44,616
    soundtrack: 32,726
    jazz: 10,937
    rock: 3,636
    folk: 3,120
    ragtime: 2,233
    ambient: 1,974
    blues: 627
    atonal: 69


In [21]:
# Balanced sampling function
import random

def create_balanced_sample(metadata_dict, sampling_config, tokenizer):
    """
    Create a balanced, unbiased subset of the dataset
    
    Strategy:
    1. If composer_strategy == 'exclude': Don't use composer at all (eliminates bias)
    2. If composer_strategy == 'balanced': Sample equally from top-N composers + no-composer
    3. Limit empty metadata to max ratio
    4. Balance genres proportionally
    """
    random.seed(42)  # For reproducibility
    
    # Analyze distribution
    distribution = analyze_metadata_distribution(metadata_dict)
    
    # Strategy decision
    if sampling_config['composer_strategy'] == 'exclude':
        print("‚ö†Ô∏è  Composer field will be EXCLUDED to avoid bias")
        print("   All composer tokens will be skipped during tokenization")
        # Don't filter by composer - just balance other fields
        composer_sampling = None
    else:
        print("‚úÖ Using BALANCED composer sampling")
        
        # Get top composers (based on tokenizer's top_composers)
        top_composers = set(tokenizer.top_composers)
        
        # Group by composer category
        composer_groups = {
            'top_composer': defaultdict(list),  # Files with top-N composers
            'other_composer': [],  # Files with other composers
            'no_composer': []  # Files without composer
        }
        
        for entry_id, entry_data in metadata_dict.items():
            metadata = entry_data.get('metadata', {})
            audio_scores = entry_data.get('audio_scores', {})
            
            if not audio_scores:
                continue
            if max(audio_scores.values()) < 0.97:
                continue
            
            composer = metadata.get('composer', '').lower() if metadata.get('composer') else None
            normalized_composer = tokenizer._normalize_composer(composer) if composer else None
            
            if normalized_composer and normalized_composer in top_composers:
                composer_groups['top_composer'][normalized_composer].append(entry_id)
            elif composer:
                composer_groups['other_composer'].append(entry_id)
            else:
                composer_groups['no_composer'].append(entry_id)
        
        # Sample strategy: equal number from each top composer + no-composer group
        max_per_category = sampling_config['max_per_composer']
        samples_per_composer = min(
            max_per_category,
            min(len(files) for files in composer_groups['top_composer'].values()) if composer_groups['top_composer'] else 0
        )
        
        # Ensure no-composer group has similar size
        no_composer_limit = samples_per_composer * len(composer_groups['top_composer'])
        
        sampled_ids = set()
        
        # Sample from each top composer
        for composer, file_ids in composer_groups['top_composer'].items():
            sampled = random.sample(file_ids, min(samples_per_composer, len(file_ids)))
            sampled_ids.update(sampled)
            print(f"  Composer '{composer}': {len(sampled)}/{len(file_ids)} samples")
        
        # Sample from no-composer group
        no_composer_sample = random.sample(
            composer_groups['no_composer'], 
            min(no_composer_limit, len(composer_groups['no_composer']))
        )
        sampled_ids.update(no_composer_sample)
        print(f"  No composer: {len(no_composer_sample)} samples")
        
        # Skip other_composer (would create bias)
        print(f"  Other composers (excluded to avoid bias): {len(composer_groups['other_composer']):,}")
        
        composer_sampling = sampled_ids
    
    # Handle empty metadata limitation
    max_empty = int(sampling_config['max_empty_metadata_ratio'] * len(distribution['empty_metadata']))
    empty_sample = random.sample(distribution['empty_metadata'], min(max_empty, len(distribution['empty_metadata'])))
    
    if composer_sampling:
        final_ids = list(composer_sampling) + empty_sample
    else:
        # If excluding composer, sample from all (with empty metadata limit)
        all_ids = distribution['with_composer'] + distribution['no_composer']
        final_ids = random.sample(all_ids, min(50000, len(all_ids))) + empty_sample
    
    # Remove duplicates
    final_ids = list(set(final_ids))
    
    print(f"\n‚úÖ Balanced sample created:")
    print(f"  Total samples: {len(final_ids):,}")
    print(f"  Empty metadata: {len(empty_sample):,} ({len(empty_sample)/len(final_ids)*100:.1f}%)")
    
    return final_ids

# Create balanced sample
print("Creating balanced dataset sample...")
print("=" * 60)

# Option 1: Balanced composer sampling (recommended)
# - Equal samples from top-N composers
# - Equal samples from no-composer group
# - Excludes "other" composers to avoid bias

# Option 2: Exclude composer entirely (alternative - eliminates composer bias completely)
# SAMPLING_CONFIG['composer_strategy'] = 'exclude'

balanced_ids = create_balanced_sample(metadata, SAMPLING_CONFIG, tokenizer)

# Create filtered metadata dict
balanced_metadata = {entry_id: metadata[entry_id] for entry_id in balanced_ids if entry_id in metadata}

print(f"\nüìä Final balanced dataset:")
print(f"  Original size: {len(metadata):,}")
print(f"  Balanced size: {len(balanced_metadata):,}")
print(f"  Reduction: {(1 - len(balanced_metadata)/len(metadata))*100:.1f}%")

# Analyze final distribution
print("\nüìà Final dataset distribution:")
empty_count = sum(1 for entry in balanced_metadata.values() if not entry.get('metadata', {}))
print(f"  Empty metadata: {empty_count:,} ({empty_count/len(balanced_metadata)*100:.1f}%)")
print(f"  Target was: <{SAMPLING_CONFIG['max_empty_metadata_ratio']*100:.0f}%")

# Show composer distribution
composer_counts = defaultdict(int)
for entry in balanced_metadata.values():
    composer = entry.get('metadata', {}).get('composer', '').lower() if entry.get('metadata', {}).get('composer') else None
    if composer:
        normalized = tokenizer._normalize_composer(composer)
        if normalized in tokenizer.top_composers:
            composer_counts[normalized] += 1

print(f"\n  Composer distribution in balanced set:")
for composer, count in sorted(composer_counts.items(), key=lambda x: x[1], reverse=True):
    print(f"    {composer}: {count:,}")

no_composer_count = sum(1 for entry in balanced_metadata.values() 
                       if not entry.get('metadata', {}).get('composer'))
print(f"    No composer: {no_composer_count:,}")


Creating balanced dataset sample...
‚úÖ Using BALANCED composer sampling
  Composer 'chopin': 64/388 samples
  Composer 'rachmaninoff': 64/175 samples
  Composer 'satie': 64/1410 samples
  Composer 'bach': 64/1018 samples
  Composer 'beethoven': 64/314 samples
  Composer 'yiruma': 64/1040 samples
  Composer 'hisaishi': 64/2013 samples
  Composer 'poulenc': 64/577 samples
  Composer 'debussy': 64/170 samples
  Composer 'joplin': 64/922 samples
  Composer 'mozart': 64/398 samples
  Composer 'schubert': 64/495 samples
  Composer 'bartok': 64/167 samples
  Composer 'einaudi': 64/902 samples
  Composer 'schumann': 64/363 samples
  Composer 'liszt': 64/567 samples
  Composer 'ravel': 64/64 samples
  Composer 'faure': 64/208 samples
  Composer 'tchaikovsky': 64/168 samples
  No composer: 1216 samples
  Other composers (excluded to avoid bias): 80,159

‚úÖ Balanced sample created:
  Total samples: 3,843
  Empty metadata: 1,425 (37.1%)

üìä Final balanced dataset:
  Original size: 371,053
  Ba

## Step 4: Find MIDI Files from Metadata

Given a metadata entry ID and audio index, find the corresponding MIDI file.


In [22]:
# Function to find MIDI file path
def find_midi_file(file_id: str, audio_index: str, data_root: Path) -> Optional[Path]:
    """
    Find MIDI file given its ID and audio index
    
    Args:
        file_id: Numeric ID as string (e.g., "2", "647")
        audio_index: Audio score index (e.g., "0", "1")
        data_root: Root directory containing data/ folder
    
    Returns:
        Path to MIDI file if found, None otherwise
    """
    # Format filename: 000002_0.mid
    padded_id = file_id.zfill(6)
    filename = f"{padded_id}_{audio_index}.mid"
    
    # Search in all subdirectories (aa, ab, ac, etc.)
    for subfolder in data_root.iterdir():
        if subfolder.is_dir() and len(subfolder.name) == 2:
            filepath = subfolder / filename
            if filepath.exists():
                return filepath
    
    return None

# Test file finding
data_root = Path("aria-midi-v1-deduped-ext/data")
if data_root.exists():
    # Test with first few entries
    test_ids = list(metadata.keys())[:5]
    print("Testing file finding:")
    for entry_id in test_ids:
        entry = metadata[entry_id]
        audio_scores = entry.get('audio_scores', {})
        if audio_scores:
            audio_idx = list(audio_scores.keys())[0]
            filepath = find_midi_file(entry_id, audio_idx, data_root)
            if filepath:
                print(f"  ‚úÖ ID {entry_id}: {filepath.name}")
            else:
                print(f"  ‚ùå ID {entry_id}: Not found")
else:
    print("Data directory not found")


Testing file finding:
  ‚úÖ ID 2: 000002_0.mid
  ‚úÖ ID 3: 000003_0.mid
  ‚úÖ ID 4: 000004_0.mid
  ‚úÖ ID 6: 000006_0.mid
  ‚úÖ ID 8: 000008_0.mid


## Step 5: Full Preprocessing Pipeline

Process all entries: combine metadata tokens + MIDI tokens into training sequences.


In [23]:
# Preprocessing configuration
CONFIG = {
    'min_quality_score': 0.97,  # Only use high-quality transcriptions
    'max_sequence_length': 2048,  # Maximum tokens per sequence
    'time_quantization': 10,  # MIDI time quantization
    'data_root': Path("aria-midi-v1-deduped-ext/data"),
    'output_dir': Path("processed_data"),
    'sample_size': None,  # Set to number for testing, None for full dataset
}

# Create output directory
CONFIG['output_dir'].mkdir(exist_ok=True)

print("Preprocessing Configuration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")


Preprocessing Configuration:
  min_quality_score: 0.97
  max_sequence_length: 2048
  time_quantization: 10
  data_root: aria-midi-v1-deduped-ext\data
  output_dir: processed_data
  sample_size: None


In [24]:
# Initialize tokenizers
meta_tokenizer = MetadataTokenizer(include_composer=True)
midi_tokenizer = MIDITokenizer(time_quantization=CONFIG['time_quantization'])

def process_entry(entry_id: str, entry_data: Dict, data_root: Path) -> Optional[List[str]]:
    """
    Process a single entry: combine metadata + MIDI tokens
    
    Returns:
        Full token sequence or None if processing fails
    """
    # 1. Filter by quality
    audio_scores = entry_data.get('audio_scores', {})
    if not audio_scores:
        return None
    
    # Get best quality audio index
    best_idx = max(audio_scores.items(), key=lambda x: x[1])[0]
    score = audio_scores[best_idx]
    
    if score < CONFIG['min_quality_score']:
        return None
    
    # 2. Get metadata tokens
    metadata_dict = entry_data.get('metadata', {})
    metadata_tokens = meta_tokenizer.metadata_to_tokens(metadata_dict, include_start=True)
    
    # 3. Find and load MIDI file
    midi_path = find_midi_file(entry_id, best_idx, data_root)
    if not midi_path or not midi_path.exists():
        return None
    
    # 4. Convert MIDI to tokens
    midi_tokens = midi_tokenizer.midi_to_tokens(midi_path)
    if not midi_tokens:
        return None
    
    # 5. Combine: metadata + MIDI + END token
    full_sequence = metadata_tokens + midi_tokens + ["<END>"]
    
    # 6. Truncate if too long
    if len(full_sequence) > CONFIG['max_sequence_length']:
        # Keep all metadata, truncate MIDI tokens
        metadata_len = len(metadata_tokens)
        max_midi_len = CONFIG['max_sequence_length'] - metadata_len - 1  # -1 for END
        full_sequence = metadata_tokens + midi_tokens[:max_midi_len] + ["<END>"]
    
    return full_sequence

# Test processing on a few entries
print("Testing preprocessing pipeline on sample entries...")
test_entries = list(metadata.items())[:10]
processed_count = 0

for entry_id, entry_data in test_entries:
    sequence = process_entry(entry_id, entry_data, CONFIG['data_root'])
    if sequence:
        processed_count += 1
        print(f"  ‚úÖ ID {entry_id}: {len(sequence)} tokens")
        print(f"     Metadata tokens: {len([t for t in sequence if t.startswith(('START', 'GENRE', 'PERIOD', 'COMPOSER'))])}")
        print(f"     MIDI tokens: {len([t for t in sequence if not t.startswith(('START', 'GENRE', 'PERIOD', 'COMPOSER')) and t != '<END>'])}")
    else:
        print(f"  ‚ùå ID {entry_id}: Failed (quality too low or file not found)")

print(f"\nProcessed {processed_count}/{len(test_entries)} test entries successfully")


Testing preprocessing pipeline on sample entries...
  ‚úÖ ID 2: 2048 tokens
     Metadata tokens: 3
     MIDI tokens: 2044
  ‚úÖ ID 3: 2048 tokens
     Metadata tokens: 3
     MIDI tokens: 2044
  ‚ùå ID 4: Failed (quality too low or file not found)
  ‚úÖ ID 6: 2048 tokens
     Metadata tokens: 1
     MIDI tokens: 2046
  ‚úÖ ID 8: 2048 tokens
     Metadata tokens: 3
     MIDI tokens: 2044
  ‚úÖ ID 9: 2048 tokens
     Metadata tokens: 2
     MIDI tokens: 2045
  ‚úÖ ID 10: 2048 tokens
     Metadata tokens: 3
     MIDI tokens: 2044
  ‚úÖ ID 11: 2048 tokens
     Metadata tokens: 2
     MIDI tokens: 2045
  ‚úÖ ID 12: 2048 tokens
     Metadata tokens: 1
     MIDI tokens: 2046
  ‚ùå ID 13: Failed (quality too low or file not found)

Processed 8/10 test entries successfully


## Step 6: Process Full Dataset

Now process all entries and build vocabulary.


In [25]:
# Process full dataset
def process_dataset(metadata_dict, data_root, sample_size=None):
    """Process entire dataset and build vocabulary"""
    all_sequences = []
    failed_count = 0
    quality_filtered = 0
    file_not_found = 0
    
    entries = list(metadata_dict.items())
    if sample_size:
        entries = entries[:sample_size]
    
    print(f"Processing {len(entries):,} entries...")
    
    for entry_id, entry_data in tqdm(entries, desc="Processing"):
        sequence = process_entry(entry_id, entry_data, data_root)
        
        if sequence is None:
            failed_count += 1
            # Track failure reasons (simplified)
            audio_scores = entry_data.get('audio_scores', {})
            if audio_scores:
                best_score = max(audio_scores.values())
                if best_score < CONFIG['min_quality_score']:
                    quality_filtered += 1
                else:
                    file_not_found += 1
            else:
                quality_filtered += 1
        else:
            all_sequences.append({
                'entry_id': entry_id,
                'sequence': sequence,
                'length': len(sequence)
            })
    
    print(f"\n‚úÖ Processing complete!")
    print(f"  Successful: {len(all_sequences):,}")
    print(f"  Failed: {failed_count:,}")
    print(f"    - Quality filtered: {quality_filtered:,}")
    print(f"    - File not found: {file_not_found:,}")
    
    return all_sequences

# Run processing on BALANCED dataset (not full dataset!)
print("Starting dataset processing on BALANCED sample...")
print(f"Using balanced subset: {len(balanced_metadata):,} entries")
print(f"(Original dataset: {len(metadata):,} entries)")
print()

# Use balanced_metadata instead of full metadata to avoid bias
all_sequences = process_dataset(balanced_metadata, CONFIG['data_root'], sample_size=None)


Starting dataset processing on BALANCED sample...
Using balanced subset: 3,843 entries
(Original dataset: 371,053 entries)

Processing 3,843 entries...


Processing: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3843/3843 [02:38<00:00, 24.17it/s]



‚úÖ Processing complete!
  Successful: 3,843
  Failed: 0
    - Quality filtered: 0
    - File not found: 0


## Step 7: Build Vocabulary

Create vocabulary mapping from all processed tokens.


In [27]:
# Build vocabulary
def build_vocabulary(sequences):
    """Build vocabulary from all token sequences"""
    all_tokens = set()
    
    for seq_data in sequences:
        all_tokens.update(seq_data['sequence'])
    
    # Special tokens first
    vocab = {
        "<PAD>": 0,
        "<UNK>": 1,
        "<START>": 2,
        "<END>": 3,
    }
    
    # Add all unique tokens
    for token in sorted(all_tokens):
        if token not in vocab:
            vocab[token] = len(vocab)
    
    # Reverse mapping (id -> token)
    id_to_token = {v: k for k, v in vocab.items()}
    
    return vocab, id_to_token

vocab, id_to_token = build_vocabulary(all_sequences)

print(f"Vocabulary size: {len(vocab):,} tokens")
print(f"\nSpecial tokens:")
for token, idx in sorted(vocab.items(), key=lambda x: x[1])[:10]:
    print(f"  {idx:4d}: {token}")

print(f"\nSample tokens:")
for token, idx in sorted(vocab.items(), key=lambda x: x[1])[10:25]:
    print(f"  {idx:4d}: {token}")


Vocabulary size: 746 tokens

Special tokens:
     0: <PAD>
     1: <UNK>
     2: <START>
     3: <END>
     4: COMPOSER:bach
     5: COMPOSER:bartok
     6: COMPOSER:beethoven
     7: COMPOSER:chopin
     8: COMPOSER:debussy
     9: COMPOSER:einaudi

Sample tokens:
    10: COMPOSER:faure
    11: COMPOSER:hisaishi
    12: COMPOSER:joplin
    13: COMPOSER:liszt
    14: COMPOSER:mozart
    15: COMPOSER:poulenc
    16: COMPOSER:rachmaninoff
    17: COMPOSER:ravel
    18: COMPOSER:satie
    19: COMPOSER:schubert
    20: COMPOSER:schumann
    21: COMPOSER:tchaikovsky
    22: COMPOSER:yiruma
    23: GENRE:ambient
    24: GENRE:blues


## Step 8: Convert Sequences to Token IDs

Convert text tokens to numerical IDs for model training.


In [28]:
# Convert sequences to token IDs
def tokenize_sequences(sequences, vocab):
    """Convert token sequences to ID sequences"""
    tokenized = []
    
    for seq_data in sequences:
        token_ids = [vocab.get(token, vocab["<UNK>"]) for token in seq_data['sequence']]
        tokenized.append({
            'entry_id': seq_data['entry_id'],
            'token_ids': token_ids,
            'length': len(token_ids)
        })
    
    return tokenized

# Convert to token IDs
tokenized_sequences = tokenize_sequences(all_sequences, vocab)

print(f"Converted {len(tokenized_sequences):,} sequences to token IDs")
print(f"\nSample sequence (first 20 tokens):")
sample = tokenized_sequences[0]
token_ids = sample['token_ids'][:20]
tokens = [id_to_token[tid] for tid in token_ids]
print(f"  Token IDs: {token_ids}")
print(f"  Tokens:    {tokens}")

# Statistics
lengths = [s['length'] for s in tokenized_sequences]
print(f"\nSequence length statistics:")
print(f"  Min: {min(lengths)}")
print(f"  Max: {max(lengths)}")
print(f"  Mean: {sum(lengths)/len(lengths):.1f}")
print(f"  Median: {sorted(lengths)[len(lengths)//2]}")


Converted 3,843 sequences to token IDs

Sample sequence (first 20 tokens):
  Token IDs: [220, 28, 362, 169, 734, 334, 204, 739, 209, 740, 221, 200, 738, 446, 176, 736, 222, 181, 737, 312]
  Tokens:    ['START', 'GENRE:pop', 'TIME_SHIFT:2240', 'NOTE_ON:55', 'VELOCITY:45', 'TIME_SHIFT:20', 'NOTE_ON:90', 'VELOCITY:65', 'NOTE_ON:95', 'VELOCITY:70', 'TIME_SHIFT:10', 'NOTE_ON:86', 'VELOCITY:60', 'TIME_SHIFT:300', 'NOTE_ON:62', 'VELOCITY:50', 'TIME_SHIFT:100', 'NOTE_ON:67', 'VELOCITY:55', 'TIME_SHIFT:180']

Sequence length statistics:
  Min: 722
  Max: 2048
  Mean: 2021.1
  Median: 2048


## Step 9: Save Processed Data

Save processed sequences, vocabulary, and tokenizer config for training.


In [29]:
# Save processed data
output_dir = CONFIG['output_dir']

# 1. Save vocabulary
vocab_path = output_dir / "vocab.json"
with open(vocab_path, 'w') as f:
    json.dump(vocab, f, indent=2)
print(f"‚úÖ Saved vocabulary to: {vocab_path}")
print(f"   Size: {len(vocab):,} tokens")

# 2. Save ID to token mapping
id_to_token_path = output_dir / "id_to_token.json"
with open(id_to_token_path, 'w') as f:
    json.dump(id_to_token, f, indent=2)
print(f"‚úÖ Saved ID mapping to: {id_to_token_path}")

# 3. Save tokenized sequences
sequences_path = output_dir / "sequences.json"
sequences_to_save = [
    {
        'entry_id': s['entry_id'],
        'token_ids': s['token_ids'],
        'length': s['length']
    }
    for s in tokenized_sequences
]
with open(sequences_path, 'w') as f:
    json.dump(sequences_to_save, f)
print(f"‚úÖ Saved {len(sequences_to_save):,} sequences to: {sequences_path}")

# 4. Save preprocessing config
config_to_save = {
    'min_quality_score': CONFIG['min_quality_score'],
    'max_sequence_length': CONFIG['max_sequence_length'],
    'time_quantization': CONFIG['time_quantization'],
    'vocab_size': len(vocab),
    'num_sequences': len(sequences_to_save),
    'total_tokens': sum(s['length'] for s in sequences_to_save)
}
config_path = output_dir / "preprocessing_config.json"
with open(config_path, 'w') as f:
    json.dump(config_to_save, f, indent=2)
print(f"‚úÖ Saved config to: {config_path}")

print(f"\nüìä Summary:")
print(f"  Total sequences: {len(sequences_to_save):,}")
print(f"  Total tokens: {sum(s['length'] for s in sequences_to_save):,}")
print(f"  Vocabulary size: {len(vocab):,}")
print(f"  Output directory: {output_dir}")


‚úÖ Saved vocabulary to: processed_data\vocab.json
   Size: 746 tokens
‚úÖ Saved ID mapping to: processed_data\id_to_token.json
‚úÖ Saved 3,843 sequences to: processed_data\sequences.json
‚úÖ Saved config to: processed_data\preprocessing_config.json

üìä Summary:
  Total sequences: 3,843
  Total tokens: 7,766,992
  Vocabulary size: 746
  Output directory: processed_data


## Step 10: Verify Saved Data

Load and verify the saved processed data.


In [30]:
# Verify saved data
print("Verifying saved data...")

# Load vocabulary
with open(vocab_path, 'r') as f:
    loaded_vocab = json.load(f)
print(f"‚úÖ Loaded vocabulary: {len(loaded_vocab):,} tokens")

# Load sequences (sample)
with open(sequences_path, 'r') as f:
    loaded_sequences = json.load(f)
print(f"‚úÖ Loaded sequences: {len(loaded_sequences):,}")

# Load config
with open(config_path, 'r') as f:
    loaded_config = json.load(f)
print(f"‚úÖ Loaded config:")
for key, value in loaded_config.items():
    print(f"   {key}: {value}")

# Show sample sequence
print(f"\nSample sequence (ID: {loaded_sequences[0]['entry_id']}):")
sample_ids = loaded_sequences[0]['token_ids'][:30]
sample_tokens = [id_to_token[tid] for tid in sample_ids]
print(f"  First 30 tokens: {sample_tokens}")

print("\n‚úÖ Preprocessing complete! Data ready for training.")


Verifying saved data...
‚úÖ Loaded vocabulary: 746 tokens
‚úÖ Loaded sequences: 3,843
‚úÖ Loaded config:
   min_quality_score: 0.97
   max_sequence_length: 2048
   time_quantization: 10
   vocab_size: 746
   num_sequences: 3843
   total_tokens: 7766992

Sample sequence (ID: 79863):
  First 30 tokens: ['START', 'GENRE:pop', 'TIME_SHIFT:2240', 'NOTE_ON:55', 'VELOCITY:45', 'TIME_SHIFT:20', 'NOTE_ON:90', 'VELOCITY:65', 'NOTE_ON:95', 'VELOCITY:70', 'TIME_SHIFT:10', 'NOTE_ON:86', 'VELOCITY:60', 'TIME_SHIFT:300', 'NOTE_ON:62', 'VELOCITY:50', 'TIME_SHIFT:100', 'NOTE_ON:67', 'VELOCITY:55', 'TIME_SHIFT:180', 'NOTE_ON:91', 'VELOCITY:70', 'TIME_SHIFT:10', 'NOTE_ON:83', 'VELOCITY:55', 'NOTE_ON:79', 'VELOCITY:65', 'TIME_SHIFT:320', 'NOTE_ON:71', 'VELOCITY:60']

‚úÖ Preprocessing complete! Data ready for training.
