# MusicXML Tokenizer v0 — Baseline Scaffold

**Goal:** Minimal working tokenizer + detokenizer + evaluation harness for SymbTr v3 Turkish makam MusicXML files.

**Token types:** `<BOS>`, `<EOS>`, `PART_`, `TIME_SIG_`, `KEY_`, `CLEF_`, `BAR_`, `POS_BAR_`, `POS_ABS_`, `PITCH_`, `DUR_`, `REST`

**v0 limitations (to be fixed in later iterations):**
- Pitch: uses `<step>` + `<octave>` only — microtonal `<alter>` values are ignored
- No tie, grace note, repeat, fermata, lyric, or dynamics tokens
- Duration computed from raw `<duration>/<divisions>` only

---

## 1. Setup and Dependencies

In [None]:
# ── Install dependencies (for Google Colab) ──
# Uncomment the following lines if running on Colab:
# !pip install lxml music21 python-Levenshtein pandas matplotlib

import os
import glob
import random
import warnings
from pathlib import Path
from collections import Counter, defaultdict
from fractions import Fraction

from lxml import etree
import pandas as pd
import matplotlib.pyplot as plt

# Silence UserWarnings (music21 emits many) so notebook output stays readable.
# This must run before music21 is imported to cover import-time warnings.
warnings.filterwarnings("ignore", category=UserWarning)

# music21 is only needed by the detokenizer and the evaluation harness.
import music21

# Seed the RNG so the random evaluation sample is reproducible across runs.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)

for _label, _version in (("lxml version", etree.LXML_VERSION),
                         ("music21 version", music21.VERSION_STR)):
    print(f"{_label}: {_version}")
print("Setup complete.")

## 2. Load Dataset File Paths

In [None]:
# ── Load all XML file paths from xml_v3/ ──
XML_DIR = Path("xml_v3")

# Fall back to the Colab checkout location when the local folder is absent.
if not XML_DIR.exists():
    XML_DIR = Path("/content/musicxml_tokenizer/xml_v3")

all_xml_files = sorted(XML_DIR.glob("*.xml"))
print(f"Total XML files found: {len(all_xml_files)} (in {XML_DIR})")
if not all_xml_files:
    # Fail here with a clear message instead of an opaque IndexError in later
    # cells (which all index all_xml_files[0]).
    raise FileNotFoundError(
        f"No .xml files found in {XML_DIR}; checked './xml_v3' and the Colab path."
    )

print("\nFirst 5 files:")
for f in all_xml_files[:5]:
    print(f"  {f.name}")
print("\nLast 5 files:")
for f in all_xml_files[-5:]:
    print(f"  {f.name}")

# Select a random sample (seeded in the setup cell) for evaluation.
EVAL_SAMPLE_SIZE = 100
eval_sample = random.sample(all_xml_files, min(EVAL_SAMPLE_SIZE, len(all_xml_files)))
print(f"\nEvaluation sample size: {len(eval_sample)}")

## 3. Parse MusicXML Structure with lxml

Explore the XML tree structure of a sample file to confirm the parsing approach.

In [None]:
# ── Explore MusicXML structure of a sample file ──
sample_file = all_xml_files[0]
print(f"Exploring: {sample_file.name}\n")

tree = etree.parse(str(sample_file))
root = tree.getroot()

# Top-level layout of the document.
print("Root tag:", root.tag)
print("Top-level children:")
for top_child in root:
    print(f"  <{top_child.tag}>")

# Part IDs and names declared in <part-list>.
part_list = root.find("part-list")
score_parts = part_list.findall("score-part") if part_list is not None else []
for score_part in score_parts:
    print(f"\n  Part ID={score_part.get('id')}, Name='{score_part.findtext('part-name', default='?')}'")

# Children (and grandchildren) of the first measure of the first part.
parts = root.findall("part")
print(f"\nNumber of parts: {len(parts)}")
first_part = parts[0]
first_measure = first_part.find("measure")
print(f"\nFirst measure (number={first_measure.get('number')}) children:")
for measure_child in first_measure:
    attrib_repr = dict(measure_child.attrib) if measure_child.attrib else ""
    print(f"  <{measure_child.tag}>", attrib_repr)
    for grandchild in measure_child:
        print(f"    <{grandchild.tag}> {(grandchild.text or '').strip()}")

# <divisions> = number of duration units per quarter note.
attrs = first_measure.find("attributes")
if attrs is not None:
    div = attrs.findtext("divisions")
    print(f"\nDivisions per quarter note: {div}")

## 4. Helper Functions for Pitch, Duration, Key, and Time Signature Extraction

In [None]:
# ── Helper functions for tokenization ──

def get_part_name(root, part_id):
    """Return the display name of the part whose <score-part id> equals ``part_id``.

    Uses <part-name>; if that is missing or blank, falls back to the first
    nested <instrument-name>. Returns "Unknown" when there is no <part-list>
    or no matching <score-part>. The name is returned whitespace-stripped.
    """
    part_list = root.find("part-list")
    if part_list is None:
        return "Unknown"

    candidates = (sp for sp in part_list.findall("score-part") if sp.get("id") == part_id)
    score_part = next(candidates, None)
    if score_part is None:
        return "Unknown"

    label = score_part.findtext("part-name", default="")
    if not label.strip():
        instrument = score_part.find(".//instrument-name")
        if instrument is not None and instrument.text:
            label = instrument.text
    return label.strip() if label else "Unknown"


def extract_pitch(note_elem):
    """Return the 'PITCH_<step><octave>' token for a <note>, or None if it has no <pitch>.

    v0 deliberately ignores <alter>, so microtonal inflections collapse onto
    the natural step. A missing <step> / <octave> falls back to 'C' / '4'.
    """
    pitch = note_elem.find("pitch")
    if pitch is not None:
        step = pitch.findtext("step", default="C")
        octave = pitch.findtext("octave", default="4")
        return "PITCH_" + step + octave
    return None


def extract_duration_ql(note_elem, divisions):
    """Convert an element's <duration> to a quarter-note length.

    Args:
        note_elem: element with an optional <duration> child (e.g. <note>).
        divisions: duration units per quarter note (from <attributes>).

    Returns:
        ``duration / divisions`` rounded to 6 decimals, or 0.0 when <duration>
        is absent or ``divisions`` is 0/None.

    Raises:
        ValueError: if <duration> text is not a number.
    """
    dur_text = note_elem.findtext("duration")
    if dur_text is None or not divisions:
        return 0.0
    # MusicXML <duration> is a decimal type, so parse with float() rather than
    # int(); integer-valued text gives exactly the same result as before.
    ql = float(dur_text) / divisions
    return round(ql, 6)  # Avoid floating point noise


def _has_child(elem, tag):
    """Return True if ``elem`` has at least one direct child named ``tag``."""
    return elem.find(tag) is not None


def is_rest(note_elem):
    """Return True when the <note> contains a <rest> child."""
    return _has_child(note_elem, "rest")


def is_grace_note(note_elem):
    """Return True when the <note> contains a <grace> child.

    Grace notes carry no <duration>, so callers skip them when advancing time.
    """
    return _has_child(note_elem, "grace")


# Tonic name for each <fifths> value (circle-of-fifths position), covering the
# full -7..+7 range allowed by MusicXML (e.g. 7 sharps = C# major / A# minor).
_MAJOR_TONIC_BY_FIFTHS = {
    -7: "Cb", -6: "Gb", -5: "Db", -4: "Ab", -3: "Eb", -2: "Bb", -1: "F",
    0: "C", 1: "G", 2: "D", 3: "A", 4: "E", 5: "B", 6: "F#", 7: "C#",
}
_MINOR_TONIC_BY_FIFTHS = {
    -7: "Ab", -6: "Eb", -5: "Bb", -4: "F", -3: "C", -2: "G", -1: "D",
    0: "A", 1: "E", 2: "B", 3: "F#", 4: "C#", 5: "G#", 6: "D#", 7: "A#",
}


def extract_key_signature(key_elem):
    """Parse a <key> element and return a KEY_ token string (or None).

    Handles two formats:
    1. Standard: <fifths> (+ optional <mode>, default "major") ->
       'KEY_<tonic>_<mode>'. Out-of-range fifths fall back to C (A for minor).
    2. Non-standard (SymbTr): one or more <key-step>/<key-alter>/<key-accidental>
       groups -> 'KEY_<step>-<accidental>_...' (or '<step>(<alter>)' when no
       accidental is given).

    Returns None for a missing element or a <key> with neither format.
    """
    if key_elem is None:
        return None

    # Standard format
    fifths = key_elem.findtext("fifths")
    if fifths is not None:
        fifths_int = int(fifths)
        # Treat a missing or blank <mode> as "major".
        mode = (key_elem.findtext("mode") or "").strip() or "major"
        if mode == "minor":
            tonic = _MINOR_TONIC_BY_FIFTHS.get(fifths_int, "A")
        else:
            tonic = _MAJOR_TONIC_BY_FIFTHS.get(fifths_int, "C")
        return f"KEY_{tonic}_{mode}"

    # Non-standard SymbTr format: one or more <key-step> + <key-alter> pairs
    key_steps = key_elem.findall("key-step")
    key_alters = key_elem.findall("key-alter")
    key_accidentals = key_elem.findall("key-accidental")

    if key_steps:
        # Build a composite key signature description
        parts = []
        for i, step_elem in enumerate(key_steps):
            step = step_elem.text
            alter = key_alters[i].text if i < len(key_alters) else "0"
            accidental = key_accidentals[i].text if i < len(key_accidentals) else ""
            if accidental:
                parts.append(f"{step}-{accidental}")
            else:
                parts.append(f"{step}({alter})")
        return "KEY_" + "_".join(parts)

    return None


def extract_time_signature(time_elem):
    """Return 'TIME_SIG_<beats>/<beat-type>' for a <time> element, or None if absent.

    Missing <beats> or <beat-type> children each default to '4'.
    """
    if time_elem is not None:
        beats, beat_type = (time_elem.findtext(tag, default="4") for tag in ("beats", "beat-type"))
        return "TIME_SIG_" + beats + "/" + beat_type
    return None


def format_position(pos):
    """Format a position/duration (in quarter notes) as a compact token suffix.

    The value is rounded to 4 decimals, trailing zeros are stripped, and at
    least one decimal place is kept, so every value that rounds to a whole
    number is spelled the same way ("1.0"). This matters because accumulated
    float drift (e.g. 0.9999999) must not produce a second spelling ("1") of
    the same position token. Negative zero from drift is normalized to "0.0".

    Examples: 0 -> "0.0", 1.5 -> "1.5", 0.25 -> "0.25", 1/3 -> "0.3333".
    """
    value = round(pos, 4)
    if value == 0:
        value = 0.0  # drop the sign of -0.0
    text = f"{value:.4f}".rstrip("0")
    if text.endswith("."):
        text += "0"
    return text


# Quick test of helpers
print("Helper functions defined successfully.")
print(f"format_position(0) = '{format_position(0)}'")
print(f"format_position(1.5) = '{format_position(1.5)}'")
print(f"format_position(0.25) = '{format_position(0.25)}'")

## 5. Implement `tokenize(xml_path) → list[str]`

The main tokenizer function. Parses MusicXML with `lxml`, walks the tree, and emits a flat list of string tokens.

**Partwise tokenization:** Each part's bars are emitted contiguously under a `PART_` header before moving to the next part.

In [None]:
def tokenize(xml_path):
    """Tokenize a MusicXML file into a flat list of string tokens.

    Parts are emitted one after another (partwise) under a PART_<name> header.
    Within a part, each measure contributes any changed TIME_SIG_/KEY_/CLEF_
    tokens, then BAR_<number>, then one POS_BAR_/POS_ABS_/(PITCH_|REST)/DUR_
    group per non-grace note. Positions are in quarter notes.

    Args:
        xml_path: Path to a MusicXML (.xml) file.

    Returns:
        List of string tokens, starting with "<BOS>" and ending with "<EOS>".
    """
    tree = etree.parse(str(xml_path))
    root = tree.getroot()

    tokens = ["<BOS>"]

    # Iterate over each part (partwise tokenization)
    for part_elem in root.findall("part"):
        part_id = part_elem.get("id")
        tokens.append(f"PART_{get_part_name(root, part_id)}")

        # Per-part state: attribute tokens are only re-emitted when they change.
        current_time_sig = None
        current_key = None
        current_clef = None
        divisions = 1  # default until <attributes><divisions> is seen
        pos_abs = 0.0  # absolute position in quarter notes

        for measure in part_elem.findall("measure"):
            measure_number = measure.get("number", "0")

            # NOTE: every <attributes> in the measure is handled before its notes,
            # so a mid-measure change is reported at the start of the bar.
            for attrs in measure.findall("attributes"):
                div_text = attrs.findtext("divisions")
                if div_text is not None:
                    divisions = int(div_text)

                new_time_sig = extract_time_signature(attrs.find("time"))
                if new_time_sig and new_time_sig != current_time_sig:
                    tokens.append(new_time_sig)
                    current_time_sig = new_time_sig

                new_key = extract_key_signature(attrs.find("key"))
                if new_key and new_key != current_key:
                    tokens.append(new_key)
                    current_key = new_key

                # Clef: de-duplicated like time and key signatures.
                clef_elem = attrs.find("clef")
                if clef_elem is not None:
                    clef_sign = clef_elem.findtext("sign", default="G")
                    clef_line = clef_elem.findtext("line", default="2")
                    new_clef = f"CLEF_{clef_sign}_{clef_line}"
                    if new_clef != current_clef:
                        tokens.append(new_clef)
                        current_clef = new_clef

            # Emit a default clef once per part if none was found
            # (dataset typically lacks <clef>).
            if current_clef is None:
                current_clef = "CLEF_G_2"
                tokens.append(current_clef)

            tokens.append(f"BAR_{measure_number}")

            pos_bar = 0.0  # position within current bar
            # Onset of the most recent non-chord note; <chord/> tones reuse it,
            # because a chord tone does not move the musical position.
            chord_onset = (pos_bar, pos_abs)

            for elem in measure:
                if elem.tag == "note":
                    # Skip grace notes in v0 (they have no <duration>)
                    if is_grace_note(elem):
                        continue

                    dur_ql = extract_duration_ql(elem, divisions)
                    is_chord_tone = elem.find("chord") is not None
                    if is_chord_tone:
                        onset_bar, onset_abs = chord_onset
                    else:
                        onset_bar, onset_abs = pos_bar, pos_abs
                        chord_onset = (pos_bar, pos_abs)

                    event_token = "REST" if is_rest(elem) else extract_pitch(elem)

                    # Notes with neither <rest> nor <pitch> (e.g. unpitched) emit
                    # nothing, so no orphan POS_ tokens; they still occupy time.
                    if event_token:
                        tokens.append(f"POS_BAR_{format_position(onset_bar)}")
                        tokens.append(f"POS_ABS_{format_position(onset_abs)}")
                        tokens.append(event_token)
                        tokens.append(f"DUR_{format_position(dur_ql)}")

                    if not is_chord_tone:
                        pos_bar += dur_ql
                        pos_abs += dur_ql

                elif elem.tag in ("forward", "backup"):
                    # <forward> advances / <backup> rewinds the position without a note.
                    dur_text = elem.findtext("duration")
                    if dur_text and divisions:
                        shift_ql = int(dur_text) / divisions
                        if elem.tag == "backup":
                            shift_ql = -shift_ql
                        pos_bar += shift_ql
                        pos_abs += shift_ql

    tokens.append("<EOS>")
    return tokens


# ── Quick test: tokenize the first file and show the opening tokens ──
test_file = all_xml_files[0]
print(f"Tokenizing: {test_file.name}\n")
test_tokens = tokenize(test_file)
print(f"Total tokens: {len(test_tokens)}")
print("\nFirst 40 tokens:")
for idx, tok in enumerate(test_tokens[:40]):
    print(f"  [{idx:3d}] {tok}")

## 6. Implement `detokenize(tokens) → music21.Score`

Reconstructs a `music21.stream.Score` from a list of tokens. Used for round-trip evaluation.

In [None]:
def _time_signature_from_token(token):
    """Build a music21 TimeSignature from a 'TIME_SIG_<beats>/<beat-type>' token."""
    num, denom = token[len("TIME_SIG_"):].split("/")
    return music21.meter.TimeSignature(f"{num}/{denom}")


def _key_from_token(token):
    """Build a music21 key object from a KEY_ token, or None if it cannot be parsed.

    'KEY_<tonic>_<major|minor>' becomes a music21 Key. Any other KEY_ token
    (e.g. a composite SymbTr signature) becomes a placeholder KeySignature(0),
    because v0 does not model microtonal key signatures.
    """
    fields = token[len("KEY_"):].split("_")
    try:
        if len(fields) == 2 and fields[1] in ("major", "minor"):
            return music21.key.Key(fields[0], fields[1])
        return music21.key.KeySignature(0)
    except Exception:
        # Best effort, as in the original: an unparseable key is dropped.
        return None


def _clef_from_token(token):
    """Build a music21 clef from a 'CLEF_<sign>_<line>' token (defaults G / 2)."""
    fields = token[len("CLEF_"):].split("_")
    sign = fields[0] if len(fields) >= 1 else "G"
    line = int(fields[1]) if len(fields) >= 2 else 2
    if sign == "G" and line == 2:
        return music21.clef.TrebleClef()
    if sign == "F" and line == 4:
        return music21.clef.BassClef()
    if sign == "C" and line == 3:
        return music21.clef.AltoClef()
    clef = music21.clef.Clef()
    clef.sign = sign
    clef.line = line
    return clef


def _bar_number(token):
    """Return the integer measure number of a BAR_ token.

    MusicXML measure numbers are strings and may carry suffixes ("12a") or
    prefixes ("X1"); the first integer found is used, or 0 if there is none.
    """
    import re

    match = re.search(r"-?\d+", token[len("BAR_"):])
    return int(match.group()) if match else 0


def _note_from_pitch(pitch_str, dur_ql):
    """Build a music21 Note from a v0 pitch string such as 'C5'.

    v0 pitch strings are a step letter followed by a single octave digit
    (MusicXML octaves are 0-9).
    """
    note = music21.note.Note()
    note.pitch.step = pitch_str[0]
    note.pitch.octave = int(pitch_str[-1])
    note.quarterLength = dur_ql
    return note


def _finish_part(score, part, measure, pending_attrs):
    """Close out ``part`` and place it in ``score``.

    Any attribute objects still pending (no BAR_ followed them) go into the
    open measure. The part is inserted at offset 0 so that all parts sound
    simultaneously; ``Score.append`` would instead chain them end-to-end.
    """
    if part is None:
        return
    if measure is not None:
        for obj in pending_attrs:
            measure.append(obj)
        part.append(measure)
    score.insert(0, part)


def detokenize(tokens):
    """Reconstruct a music21.Score from a list of tokens.

    The tokenizer emits a bar's TIME_SIG_/KEY_/CLEF_ tokens *before* its BAR_
    token. Those attributes are therefore buffered and attached to the measure
    opened by the next BAR_ token. POS_ tokens are ignored: v0 reconstruction
    appends notes and rests sequentially in token order.

    Args:
        tokens: List of string tokens (as produced by tokenize()).

    Returns:
        music21.stream.Score with one Part per PART_ token.
    """
    score = music21.stream.Score()
    current_part = None
    current_measure = None
    pending_attrs = []    # attribute objects waiting for the next BAR_
    pending_pitch = None  # pitch string or "REST" waiting for its DUR_

    for token in tokens:
        if token in ("<BOS>", "<EOS>"):
            continue

        if token.startswith("PART_"):
            _finish_part(score, current_part, current_measure, pending_attrs)
            current_part = music21.stream.Part()
            current_part.partName = token[len("PART_"):]
            current_measure = None
            pending_attrs = []

        elif token.startswith("TIME_SIG_"):
            pending_attrs.append(_time_signature_from_token(token))

        elif token.startswith("KEY_"):
            key_obj = _key_from_token(token)
            if key_obj is not None:
                pending_attrs.append(key_obj)

        elif token.startswith("CLEF_"):
            pending_attrs.append(_clef_from_token(token))

        elif token.startswith("BAR_"):
            # Measures outside any part (no PART_ seen yet) are discarded.
            if current_measure is not None and current_part is not None:
                current_part.append(current_measure)
            current_measure = music21.stream.Measure(number=_bar_number(token))
            for obj in pending_attrs:
                current_measure.append(obj)
            pending_attrs = []

        elif token.startswith(("POS_BAR_", "POS_ABS_")):
            continue  # reconstruction uses sequential append

        elif token.startswith("PITCH_"):
            pending_pitch = token[len("PITCH_"):]  # e.g. "C5"

        elif token == "REST":
            pending_pitch = "REST"

        elif token.startswith("DUR_"):
            try:
                dur_ql = float(token[len("DUR_"):])
            except ValueError:
                dur_ql = 1.0

            element = None
            if pending_pitch == "REST":
                element = music21.note.Rest()
                element.quarterLength = dur_ql
            elif pending_pitch is not None:
                try:
                    element = _note_from_pitch(pending_pitch, dur_ql)
                except Exception:
                    element = None  # best effort: skip unparseable pitches
            if element is not None and current_measure is not None:
                current_measure.append(element)
            pending_pitch = None

        # Any other token is unknown and skipped.

    _finish_part(score, current_part, current_measure, pending_attrs)
    return score


# ── Quick round-trip test ──
print("Round-trip test:")
print(f"  Tokenizing {test_file.name}...")
test_tokens_rt = tokenize(test_file)
print(f"  Got {len(test_tokens_rt)} tokens")
print("  Detokenizing...")
reconstructed_score = detokenize(test_tokens_rt)
print(f"  Reconstructed score: {len(reconstructed_score.parts)} part(s)")
for p in reconstructed_score.parts:
    measures = list(p.getElementsByClass(music21.stream.Measure))
    # Count notes AND rests so the number matches the "notes/rests" label
    # (`.notes` alone excludes rests).
    notes_and_rests = list(p.recurse().notesAndRests)
    print(f"    Part '{p.partName}': {len(measures)} measures, {len(notes_and_rests)} notes/rests")

## 7. Evaluation Harness

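Three evaluation axes:
1. **Note-level accuracy** — compare the pitch and duration of aligned note/rest events extracted from the original XML vs. the round-trip reconstruction (`combined_accuracy` requires both to match)
2. **Measure-duration integrity** — does each measure's content add up to the length implied by its time signature?
3. **Sequence edit distance** — Levenshtein distance between original and round-trip token sequences (`levenshtein_distance` is implemented, but v0 does not yet re-tokenize the reconstruction, so this metric is not reported)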

In [None]:
# ── Evaluation Functions ──

def extract_note_events_from_xml(xml_path):
    """Parse original MusicXML with lxml and extract note/rest events.

    Grace notes are skipped. <chord/> tones take the onset of the preceding
    note and do not advance time. <forward>/<backup> move the position.
    Pitches use the same v0 spelling as the tokenizer (step + octave, no alter).

    Returns:
        list of dicts {'pitch': str, 'duration': float, 'onset': float,
        'is_rest': bool}, with durations and onsets in quarter notes and
        onsets absolute within each part (each part restarts at 0).
    """
    tree = etree.parse(str(xml_path))
    root = tree.getroot()
    events = []

    for part_elem in root.findall("part"):
        divisions = 1
        pos_abs = 0.0

        for measure in part_elem.findall("measure"):
            # Check for divisions update
            for attrs in measure.findall("attributes"):
                div_text = attrs.findtext("divisions")
                if div_text:
                    divisions = int(div_text)

            chord_onset = pos_abs  # onset of the most recent non-chord note
            for elem in measure:
                if elem.tag == "note":
                    if is_grace_note(elem):
                        continue

                    dur_ql = extract_duration_ql(elem, divisions)
                    is_chord_tone = elem.find("chord") is not None
                    if is_chord_tone:
                        onset = chord_onset
                    else:
                        onset = chord_onset = pos_abs

                    note_is_rest = is_rest(elem)
                    if note_is_rest:
                        pitch_str = 'REST'
                    else:
                        pitch_token = extract_pitch(elem)
                        pitch_str = pitch_token[len("PITCH_"):] if pitch_token else None

                    # Unpitched notes are not recorded but still occupy time.
                    if pitch_str is not None:
                        events.append({
                            'pitch': pitch_str,
                            'duration': dur_ql,
                            'onset': round(onset, 6),
                            'is_rest': note_is_rest,
                        })

                    if not is_chord_tone:
                        pos_abs += dur_ql

                elif elem.tag in ("forward", "backup"):
                    dur_text = elem.findtext("duration")
                    if dur_text and divisions:
                        shift_ql = int(dur_text) / divisions
                        pos_abs += shift_ql if elem.tag == "forward" else -shift_ql

    return events


def extract_note_events_from_score(score):
    """Extract note/rest events from a music21.Score.

    Only Note and Rest objects are recorded (chords and other note types are
    skipped). Onsets are taken relative to the enclosing Part via
    ``getOffsetInHierarchy``. Plain ``.offset`` would be relative to the
    Measure, and would not match the part-absolute onsets produced by
    extract_note_events_from_xml. Durations are returned as floats, since
    music21 may store them as Fractions.

    Returns:
        list of dicts matching the extract_note_events_from_xml format.
    """
    events = []
    for part in score.parts:
        for element in part.recurse().notesAndRests:
            if isinstance(element, music21.note.Rest):
                pitch_str = 'REST'
            elif isinstance(element, music21.note.Note):
                pitch_str = f"{element.pitch.step}{element.pitch.octave}"
            else:
                continue
            events.append({
                'pitch': pitch_str,
                'duration': float(element.quarterLength),
                'onset': round(float(element.getOffsetInHierarchy(part)), 6),
                'is_rest': pitch_str == 'REST',
            })
    return events


def compute_note_accuracy(original_events, reconstructed_events, tolerance=0.01):
    """Compare two event lists position-by-position.

    Events are dicts with 'pitch' and 'duration' keys and, optionally, an
    'onset' key, as produced by extract_note_events_from_xml /
    extract_note_events_from_score. Only the first
    min(len(original), len(reconstructed)) pairs are compared; a length
    mismatch is reported through 'length_match' rather than in the ratios.

    Args:
        original_events: events from the source MusicXML.
        reconstructed_events: events from the round-tripped score.
        tolerance: maximum absolute difference (quarter notes) for durations
            and onsets to count as equal.

    Returns:
        dict with pitch_accuracy, duration_accuracy, onset_accuracy,
        combined_accuracy (pitch AND duration match), n_original,
        n_reconstructed and length_match. Ratios are 0.0 when nothing can be
        compared; an event pair with a missing onset counts as an onset mismatch.
    """
    n_orig = len(original_events)
    n_recon = len(reconstructed_events)
    n_compare = min(n_orig, n_recon)

    pitch_match = dur_match = onset_match = combined_match = 0
    for orig, recon in zip(original_events, reconstructed_events):
        p_match = orig['pitch'] == recon['pitch']
        d_match = abs(orig['duration'] - recon['duration']) < tolerance
        orig_onset, recon_onset = orig.get('onset'), recon.get('onset')
        o_match = (
            orig_onset is not None
            and recon_onset is not None
            and abs(orig_onset - recon_onset) < tolerance
        )
        pitch_match += p_match
        dur_match += d_match
        onset_match += o_match
        combined_match += p_match and d_match

    def ratio(count):
        return count / n_compare if n_compare else 0.0

    return {
        'pitch_accuracy': ratio(pitch_match),
        'duration_accuracy': ratio(dur_match),
        'onset_accuracy': ratio(onset_match),
        'combined_accuracy': ratio(combined_match),
        'n_original': n_orig,
        'n_reconstructed': n_recon,
        'length_match': n_orig == n_recon,
    }


def _measure_record(measure_num, expected, actual):
    """Build one entry of check_measure_durations' 'details' list."""
    return {
        'measure': measure_num,
        'expected': expected,
        'actual': round(actual, 6),
        'correct': abs(actual - expected) < 0.01,
    }


def check_measure_durations(tokens):
    """Check that each measure's content spans the length implied by its time signature.

    A measure's actual length is the furthest note/rest end reached in it
    (onset + duration). The onset comes from the preceding POS_BAR_ token when
    present; otherwise it is the end of the previous event in the bar. Using
    the furthest end instead of a plain sum of DUR_ values keeps chord tones
    and overlapping voices from being double-counted.

    Returns:
        dict with total_measures, correct_measures, integrity_pct (0.0 when
        there are no measures), and per-measure details.
    """
    current_bar_dur = 4.0  # default 4/4 until a TIME_SIG_ token is seen
    measure_results = []
    measure_num = None     # label of the open measure (None before the first BAR_)
    expected = current_bar_dur
    measure_end = 0.0      # furthest event end reached in the open measure
    cursor = 0.0           # end of the previous event in the open measure
    onset = None           # POS_BAR_ value waiting for its DUR_

    for token in tokens:
        if token.startswith("TIME_SIG_"):
            try:
                num, denom = token[len("TIME_SIG_"):].split("/")
                current_bar_dur = float(num) * (4.0 / float(denom))
            except (ValueError, ZeroDivisionError):
                pass

        elif token.startswith("BAR_"):
            if measure_num is not None:
                measure_results.append(_measure_record(measure_num, expected, measure_end))
            measure_num = token[len("BAR_"):]
            # The tokenizer emits a bar's TIME_SIG_ before its BAR_, so a later
            # TIME_SIG_ belongs to a later bar: freeze this bar's length now.
            expected = current_bar_dur
            measure_end = cursor = 0.0
            onset = None

        elif token.startswith("POS_BAR_"):
            try:
                onset = float(token[len("POS_BAR_"):])
            except ValueError:
                onset = None

        elif token.startswith("DUR_"):
            try:
                dur = float(token[len("DUR_"):])
            except ValueError:
                continue
            start = cursor if onset is None else onset
            cursor = start + dur
            measure_end = max(measure_end, cursor)
            onset = None

    # Save last measure
    if measure_num is not None:
        measure_results.append(_measure_record(measure_num, expected, measure_end))

    total = len(measure_results)
    correct = sum(1 for m in measure_results if m['correct'])

    return {
        'total_measures': total,
        'correct_measures': correct,
        'integrity_pct': (correct / total * 100) if total > 0 else 0.0,
        'details': measure_results,
    }


def levenshtein_distance(seq1, seq2):
    """Return the edit distance between two sequences.

    The result is the minimum number of single-element insertions, deletions
    or substitutions that turn ``seq1`` into ``seq2``. Works on any sized,
    iterable sequences (strings, token lists). Only two DP rows are kept, so
    extra memory is O(len(seq2)).
    """
    if len(seq1) == 0:
        return len(seq2)
    if len(seq2) == 0:
        return len(seq1)

    previous_row = list(range(len(seq2) + 1))
    for row_index, item_a in enumerate(seq1, start=1):
        current_row = [row_index]
        for col_index, item_b in enumerate(seq2, start=1):
            substitution_cost = 0 if item_a == item_b else 1
            current_row.append(min(
                current_row[-1] + 1,                                # insertion
                previous_row[col_index] + 1,                        # deletion
                previous_row[col_index - 1] + substitution_cost,    # substitution
            ))
        previous_row = current_row

    return previous_row[-1]


def evaluate_file(xml_path):
    """Run tokenize -> detokenize -> metrics on a single file.

    Never raises. Any exception is recorded as its message under 'error',
    and the numeric metrics are zeroed so the row can still be tabulated.
    On success, 'error' is None.
    """
    result = {'file': str(xml_path), 'filename': Path(xml_path).name}

    try:
        tokens = tokenize(xml_path)
        result['n_tokens'] = len(tokens)

        recon_score = detokenize(tokens)

        # Note-level accuracy: original XML events vs. reconstructed-score events.
        result.update(compute_note_accuracy(
            extract_note_events_from_xml(xml_path),
            extract_note_events_from_score(recon_score),
        ))

        integrity = check_measure_durations(tokens)
        result.update(
            measure_integrity_pct=integrity['integrity_pct'],
            total_measures=integrity['total_measures'],
            correct_measures=integrity['correct_measures'],
        )

        # Token edit distance would require serializing and re-tokenizing the
        # reconstruction; v0 skips it and relies on note-level accuracy.
        result['error'] = None

    except Exception as exc:
        result.update({
            'error': str(exc),
            'n_tokens': 0,
            'pitch_accuracy': 0.0,
            'duration_accuracy': 0.0,
            'combined_accuracy': 0.0,
            'measure_integrity_pct': 0.0,
            'n_original': 0,
            'n_reconstructed': 0,
            'length_match': False,
            'total_measures': 0,
            'correct_measures': 0,
        })

    return result


def evaluate_batch(file_paths, verbose=True, *, progress_every=20):
    """Run ``evaluate_file`` on every path and collect the per-file results.

    Args:
        file_paths: sized sequence of MusicXML paths (``len`` is used for the
            progress message).
        verbose: if True, print progress lines, a completion summary and up
            to five failing files.
        progress_every: when ``verbose``, print a progress line every this
            many files; 0 or None disables progress lines. Defaults to 20.

    Returns:
        pandas DataFrame with one row per input file (the dicts returned by
        ``evaluate_file``), in input order.
    """
    results = []
    errors = []
    n_files = len(file_paths)

    for i, fp in enumerate(file_paths, start=1):
        if verbose and progress_every and i % progress_every == 0:
            print(f"  Evaluating file {i}/{n_files}...")
        r = evaluate_file(fp)
        results.append(r)
        # Compare against None rather than relying on truthiness: a failure
        # whose recorded message is an empty string must still be counted.
        if r['error'] is not None:
            errors.append((fp, r['error']))

    df = pd.DataFrame(results)

    if verbose:
        print(f"\nEvaluation complete: {len(results)} files")
        if errors:
            print(f"  Errors: {len(errors)} files failed")
            for fp, err in errors[:5]:
                print(f"    {Path(fp).name}: {err}")

    return df


print("Evaluation functions defined successfully.")

## 8. Run Evaluation on 100-File Sample

In [None]:
# ── Run the evaluation pipeline over the sampled files ──
print(f"Evaluating {len(eval_sample)} files...\n")
eval_df = evaluate_batch(eval_sample)

# Rows whose 'error' is missing (None/NaN) are the successful runs; copy so
# later column edits on eval_ok never write through to eval_df.
eval_ok = eval_df.loc[eval_df['error'].isna()].copy()
print(f"\nSuccessful evaluations: {eval_ok.shape[0]} / {eval_df.shape[0]}")

# ── Summary statistics ──
metrics = ['pitch_accuracy', 'duration_accuracy', 'combined_accuracy', 'measure_integrity_pct']
summary = eval_ok[metrics].describe().T
summary['median'] = eval_ok[metrics].median()

# Label the banner with the real sample size. The sample is capped at
# EVAL_SAMPLE_SIZE but is smaller when fewer files exist, so a hard-coded
# "100-file" label could be wrong. A fixed inner width keeps the box aligned.
BANNER_WIDTH = 58
banner_title = f"v0 EVALUATION SUMMARY ({len(eval_sample)}-file sample)"
print("\n╔" + "═" * BANNER_WIDTH + "╗")
print(f"║{banner_title:^{BANNER_WIDTH}}║")
print("╠" + "═" * BANNER_WIDTH + "╣")
print(summary[['mean', 'median', 'min', 'max', 'std']].to_string())
print("╚" + "═" * BANNER_WIDTH + "╝")

# Percentages are guarded: if every file failed, eval_ok is empty, and a bare
# division would raise ZeroDivisionError before anything is reported.
n_ok = len(eval_ok)

# Perfect reconstruction count
perfect = len(eval_ok[eval_ok['combined_accuracy'] == 1.0])
perfect_pct = perfect / n_ok * 100 if n_ok else 0.0
print(f"\nPerfect reconstruction (combined=1.0): {perfect}/{n_ok} ({perfect_pct:.1f}%)")

# Length match stats
length_match = eval_ok['length_match'].sum()
length_match_pct = length_match / n_ok * 100 if n_ok else 0.0
print(f"Length match (same # events): {length_match}/{n_ok} ({length_match_pct:.1f}%)")

# Token stats
print(f"\nTotal tokens across sample: {eval_ok['n_tokens'].sum():,}")
print(f"Mean tokens per file: {eval_ok['n_tokens'].mean():.0f}")

# Collect unique tokens for vocabulary size. Only a small subset of the sample
# is re-tokenized, for speed.
VOCAB_SUBSET_SIZE = 20
vocab_files = eval_sample[:VOCAB_SUBSET_SIZE]
all_tokens_sample = []
vocab_failures = 0
for fp in vocab_files:
    try:
        all_tokens_sample.extend(tokenize(fp))
    except Exception:
        # Best-effort estimate: skip files that fail to tokenize, but count
        # them instead of hiding them. (A bare `except:` would also have
        # swallowed KeyboardInterrupt.)
        vocab_failures += 1
vocab = set(all_tokens_sample)
print(f"Vocabulary size (from {len(vocab_files)}-file subset): {len(vocab)}")
if vocab_failures:
    print(f"  ({vocab_failures} file(s) failed to tokenize and were skipped)")

# Snapshot the v0 aggregate metrics so later tokenizer versions can be
# compared against this baseline: each entry is the mean of one eval_ok column.
v0_results = {
    result_key: eval_ok[column].mean()
    for result_key, column in (
        ('pitch_accuracy_mean', 'pitch_accuracy'),
        ('duration_accuracy_mean', 'duration_accuracy'),
        ('combined_accuracy_mean', 'combined_accuracy'),
        ('measure_integrity_mean', 'measure_integrity_pct'),
    )
}
# Share of perfectly reconstructed files; 0 when nothing evaluated successfully.
v0_results['perfect_reconstruction_pct'] = 0 if len(eval_ok) == 0 else perfect / len(eval_ok) * 100
print("\nv0 results saved for comparison with future versions.")

In [None]:
# ── Visualizations ──
# One histogram per metric, each with a dashed red line at the mean.
fig, axes = plt.subplots(1, 3, figsize=(16, 4))

# Per panel: (column, bar colour, title, x-label, y-label or None, legend format)
for ax, (column, colour, title, xlabel, ylabel, label_fmt) in zip(axes, (
    ('pitch_accuracy', 'steelblue', 'Pitch Accuracy Distribution', 'Pitch Accuracy', 'Count', "mean={:.3f}"),
    ('duration_accuracy', 'seagreen', 'Duration Accuracy Distribution', 'Duration Accuracy', None, "mean={:.3f}"),
    ('measure_integrity_pct', 'goldenrod', 'Measure Duration Integrity (%)', '% Correct Measures', None, "mean={:.1f}%"),
)):
    values = eval_ok[column]
    mean_value = values.mean()
    ax.hist(values, bins=20, edgecolor='black', alpha=0.7, color=colour)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.axvline(mean_value, color='red', linestyle='--', label=label_fmt.format(mean_value))
    ax.legend()

plt.tight_layout()
plt.suptitle('v0 Baseline Evaluation', y=1.02, fontsize=14, fontweight='bold')
plt.show()

## 9. Inspect Worst-Scoring Files and Error Analysis

In [None]:
# ── Error analysis: inspect worst-scoring files ──
N_WORST_FILES = 5       # how many lowest combined-accuracy files to show
N_PREVIEW_TOKENS = 30   # how many leading tokens to print for each of them
worst_files = eval_ok.nsmallest(N_WORST_FILES, 'combined_accuracy')

print("=" * 70)
print(f"TOP {N_WORST_FILES} WORST-SCORING FILES (by combined accuracy)")
print("=" * 70)

for _, row in worst_files.iterrows():
    fp = row['file']
    print(f"\n{'─' * 60}")
    print(f"File: {row['filename']}")
    print(f"  Pitch accuracy:    {row['pitch_accuracy']:.4f}")
    print(f"  Duration accuracy: {row['duration_accuracy']:.4f}")
    print(f"  Combined accuracy: {row['combined_accuracy']:.4f}")
    print(f"  Measure integrity: {row['measure_integrity_pct']:.1f}%")
    print(f"  Notes: {row['n_original']} original → {row['n_reconstructed']} reconstructed")
    print(f"  Length match: {row['length_match']}")

    # Re-tokenize to preview the leading tokens (evaluate_file keeps only the
    # token count, not the tokens themselves).
    try:
        tokens = tokenize(fp)
        print(f"\n  First {N_PREVIEW_TOKENS} tokens:")
        for i, t in enumerate(tokens[:N_PREVIEW_TOKENS]):
            print(f"    [{i:3d}] {t}")
    except Exception as e:
        print(f"  Error re-tokenizing: {e}")

# ── Categorize error patterns across all evaluated files ──
print("\n\n" + "=" * 70)
print("ERROR PATTERN ANALYSIS")
print("=" * 70)

# (category, predicate) pairs, checked in this fixed order for every row so
# the Counter's insertion order (which breaks ties in most_common()) is the
# same as a plain if-chain would produce.
category_checks = (
    ('length_mismatch', lambda r: not r['length_match']),
    ('pitch_errors', lambda r: r['pitch_accuracy'] < 1.0),
    ('duration_errors', lambda r: r['duration_accuracy'] < 1.0),
    ('measure_integrity_errors', lambda r: r['measure_integrity_pct'] < 100.0),
)
error_categories = Counter()
for _, row in eval_ok.iterrows():
    error_categories.update(name for name, is_error in category_checks if is_error(row))

n_evaluated = len(eval_ok)
print(f"\n  Total files evaluated: {n_evaluated}")
for cat, count in error_categories.most_common():
    print(f"  {cat}: {count}/{n_evaluated} ({count/n_evaluated*100:.1f}%)")

# ── Analyze which features cause issues ──
# Emit the known v0 gaps as a single block; each tuple entry is one output line.
print("\n".join((
    "\n\nKnown v0 limitations causing errors:",
    "  1. Microtonal <alter> values ignored → pitch tokens lose inflection",
    "     (pitch accuracy measures step+octave only, so this won't show yet)",
    "  2. Grace notes skipped → may cause position/count mismatches",
    "  3. No tie tokens → tied notes counted as separate events",
    "  4. Pickup measures (anacrusis) → measure duration != time sig",
    "  5. Repeats and endings → not tokenized, may affect measure counting",
)))

# Check the sample for features v0 does not tokenize (grace notes, ties).
# Each file is parsed once and checked for both features.
grace_note_files = 0
tie_files = 0
unparsable_files = 0
for fp in eval_sample:
    try:
        root = etree.parse(str(fp)).getroot()
    except (OSError, etree.LxmlError):
        # Unreadable or malformed files are skipped (best-effort scan) but
        # counted, instead of being hidden by a bare `except:`.
        unparsable_files += 1
        continue
    # find() stops at the first match; only existence matters here.
    if root.find(".//grace") is not None:
        grace_note_files += 1
    if root.find(".//tie") is not None:
        tie_files += 1

print(f"\n  Files with grace notes in sample: {grace_note_files}/{len(eval_sample)}")
print(f"  Files with ties in sample: {tie_files}/{len(eval_sample)}")
if unparsable_files:
    print(f"  Files that could not be parsed: {unparsable_files}/{len(eval_sample)}")