In [None]:
#@markdown # UST + LAB CORE GENERATOR

import os
import random
import math
import shutil
from pathlib import Path
from collections import Counter


#@markdown ## Batch Control


#@markdown <font size="-1.5"> _BATCH_COUNT = Number of UST + LAB files to generate_
BATCH_COUNT = 80  #@param {type:"integer", min:1, max:150}

#@markdown <font size="-1.5"> _PREFIX = String prepended to the names of the generated files_
PREFIX = 'default'#@param {type:"string"}


#@markdown ## Output Settings
MOUNTED_DRIVE_DIR = "/content"
# No separator is inserted: PREFIX "default" gives "defaultdataset".
FILE_PREFIX = PREFIX+"dataset"

#@markdown ## Randomization Ranges

#@markdown <font size="-1.5"> _BASE_NOTE = Root note of each generated UST, as a MIDI note number (36 = C2, 60 = C4, 72 = C5)_
BASE_NOTE_MIN = 36 #@param {type:"integer", min:0, max:127}
BASE_NOTE_MAX = 72 #@param {type:"integer", min:0, max:127}
#@markdown <font size="-1.5"> _SIGMA_VALUE = Sigma scale of the normal (Gaussian) distribution of BASE_NOTE; larger values spread the root notes more widely_
SIGMA_VALUE = 0.25 #@param {type:"number", min:0,max:1}

#@markdown <font size="-1.5"> _NOTE_RANGE = Maximum distance in semitones between BASE_NOTE and the other notes generated in each UST_
NOTE_RANGE_MIN = 4 #@param {type:"integer", min:1, max:20}
NOTE_RANGE_MAX = 5 #@param {type:"integer", min:1, max:20}


LENGTH_VARIANT_RANGE = (0.0, 1.0)

#@markdown <font size="-1.5"> _EXTRA_SP_RANGE = Range used to insert random SP (silence) notes between the notes of each generated UST_
EXTRA_SP_RANGE = (1, 3)
LENGTH_MODES = ["normal", "short", "long"]  # Length distribution modes

BASE_NOTE_RANGE = (BASE_NOTE_MIN, BASE_NOTE_MAX)
NOTE_RANGE_RANGE = (NOTE_RANGE_MIN, NOTE_RANGE_MAX)

#@markdown ## UST Settings
UST_FLAGS = "g0B0H0P86" #@param {type:"string"}

# ------------------------
# Fixed Global Constants
# ------------------------
BPM = 120
TICKS_PER_BEAT = 480
TOTAL_TICKS = 24 * TICKS_PER_BEAT   # 6 bars * 4 beats
SP_TICK = 240                       # half beat
# Allowed note lengths in ticks (0.5 to 1.5 beats at 480 ticks per beat).
NOTE_LENGTHS = [240, 360, 480,600,720]

# HTK Timing Constants (100ns units)
# 1 beat = 60/BPM seconds
# 1 tick = (60/BPM)/TICKS_PER_BEAT seconds
# 1 sec = 10,000,000 units (100ns)
# At 120 BPM and 480 ticks per beat, one tick is ~1.042 ms (~10416.67 units).
TICK_TO_100NS = (60 / BPM / TICKS_PER_BEAT) * 10_000_000

# Pitch Bend Timing Constants (milliseconds)
# 1 tick = (60/BPM)/TICKS_PER_BEAT seconds = (60/BPM)/TICKS_PER_BEAT * 1000 ms
TICK_TO_MS = (60 / BPM / TICKS_PER_BEAT) * 1000

# -----------------------
# Japanese Hiragana Pools
# ------------------------
# Kana candidates the lyric pickers choose from, grouped by pool.
VOWELS = ["あ","い","う","え","お"]

# Single-character consonant+vowel kana (plain, voiced, semi-voiced,
# plus や/ゆ/よ/わ/を).
SINGLE_CV = [
    "か","き","く","け","こ","さ","し","す","せ","そ",
    "た","ち","つ","て","と","な","に","ぬ","ね","の",
    "は","ひ","ふ","へ","ほ","ま","み","む","め","も",
    "ら","り","る","れ","ろ","が","ぎ","ぐ","げ","ご",
    "ざ","じ","ず","ぜ","ぞ","だ","ぢ","づ","で","ど",
    "ば","び","ぶ","べ","ぼ","ぱ","ぴ","ぷ","ぺ","ぽ",
    "や","ゆ","よ","わ","を"
]

# Yoon: a consonant kana followed by a small ゃ/ゅ/ょ (two characters each).
YOON = [
    "きゃ","きゅ","きょ","しゃ","しゅ","しょ","ちゃ","ちゅ","ちょ",
    "にゃ","にゅ","にょ","ひゃ","ひゅ","ひょ","みゃ","みゅ","みょ",
    "りゃ","りゅ","りょ","ぎゃ","ぎゅ","ぎょ","じゃ","じゅ","じょ",
    "びゃ","びゅ","びょ","ぴゃ","ぴゅ","ぴょ"
]

# Moraic nasal; decompose_lyric() turns it into the phoneme "N".
SPECIAL = ["ん"]

# Distribution weights
# pick_lyric() draws uniformly from this list, so the pool names are chosen
# in the ratio vowel 1 : single 10 : double 6 : special 1.
LYRIC_WEIGHTS = (
    ["vowel"] * 1 +
    ["single"] * 10 +
    ["double"] * 6 +
    ["special"] * 1
)

# -----------------------
# Target Phoneme Ratios
# ------------------------
# Target share of each phoneme, in percent of all counted phonemes. The
# weighting and validation helpers compare these values against
# count / total * 100.
# NOTE(review): the values below sum to 114.9, not 100, so not every target
# can be met at the same time. Phonemes that the decomposition produces but
# that are absent here (e.g. "f", "ts", "j") have no target and are left
# unweighted.
TARGET_RATIO = {
    # Vowels (60 of the 114.9 total, i.e. ~52% once normalised)
    "a": 13.0,
    "i": 12.0,
    "u": 11.0,
    "e": 11.0,
    "o": 13.0,

    # Special
    "N": 4.0,
    "SP": 7.0,

    # Core consonants
    "k": 4.0,
    "t": 4.0,
    "n": 5.0,
    "r": 4.0,
    "s": 2.0,
    "sh": 2.5,
    "m": 2.0,
    "p": 2.0,
    "g": 2.0,
    "d": 2.0,
    "z": 1.5,
    "b": 1.0,
    "h": 1.0,
    "y": 1.0,
    "w": 1.5,

    # Palatalized (low but non-zero)
    "ky": 1.2,
    "gy": 1.0,
    "ch": 1.2,
    "ry": 1.0,
    "my": 0.8,
    "py": 0.8,
    "hy": 0.8,
    "ny": 0.8,
    "by": 0.8,
}

# -------------------------------
# Hiragana to Romaji Mapping Logic
# ---------------------------------
# Romaji spellings written to the UST "Lyric=" field. decompose_lyric() also
# takes the vowel of a single CV kana from the last letter of its spelling,
# so every CV entry here must end in its vowel.
HIRAGANA_TO_ROMAJI = {
    # Vowels
    "あ": "a", "い": "i", "う": "u", "え": "e", "お": "o",
    # Single CV - K
    "か": "ka", "き": "ki", "く": "ku", "け": "ke", "こ": "ko",
    # Single CV - S
    "さ": "sa", "し": "shi", "す": "su", "せ": "se", "そ": "so",
    # Single CV - T
    "た": "ta", "ち": "chi", "つ": "tsu", "て": "te", "と": "to",
    # Single CV - N
    "な": "na", "に": "ni", "ぬ": "nu", "ね": "ne", "の": "no",
    # Single CV - H
    "は": "ha", "ひ": "hi", "ふ": "fu", "へ": "he", "ほ": "ho", # ふ is spelled "fu" here; its LAB consonant is "f"
    # Single CV - M
    "ま": "ma", "み": "mi", "む": "mu", "め": "me", "も": "mo",
    # Single CV - R
    "ら": "ra", "り": "ri", "る": "ru", "れ": "re", "ろ": "ro",
    # Single CV - G
    "が": "ga", "ぎ": "gi", "ぐ": "gu", "げ": "ge", "ご": "go",
    # Single CV - Z
    "ざ": "za", "じ": "ji", "ず": "zu", "ぜ": "ze", "ぞ": "zo",
    # Single CV - D
    "だ": "da", "ぢ": "ji", "づ": "zu", "で": "de", "ど": "do",
    # Single CV - B
    "ば": "ba", "び": "bi", "ぶ": "bu", "べ": "be", "ぼ": "bo",
    # Single CV - P
    "ぱ": "pa", "ぴ": "pi", "ぷ": "pu", "ぺ": "pe", "ぽ": "po",
    # Single CV - Y/W
    "や": "ya", "ゆ": "yu", "よ": "yo", "わ": "wa", "を": "wo",
    # YOON - K
    "きゃ": "kya", "きゅ": "kyu", "きょ": "kyo",
    # YOON - S
    "しゃ": "sha", "しゅ": "shu", "しょ": "sho",
    # YOON - T
    "ちゃ": "cha", "ちゅ": "chu", "ちょ": "cho",
    # YOON - N
    "にゃ": "nya", "にゅ": "nyu", "にょ": "nyo",
    # YOON - H
    "ひゃ": "hya", "ひゅ": "hyu", "ひょ": "hyo",
    # YOON - M
    "みゃ": "mya", "みゅ": "myu", "みょ": "myo",
    # YOON - R
    "りゃ": "rya", "りゅ": "ryu", "りょ": "ryo",
    # YOON - G
    "ぎゃ": "gya", "ぎゅ": "gyu", "ぎょ": "gyo",
    # YOON - J
    "じゃ": "ja", "じゅ": "ju", "じょ": "jo",
    # YOON - B
    "びゃ": "bya", "びゅ": "byu", "びょ": "byo",
    # YOON - P
    "ぴゃ": "pya", "ぴゅ": "pyu", "ぴょ": "pyo",
    # Special
    "ん": "n"
}

def hiragana_to_romaji(hiragana):
    """Return the romaji spelling of *hiragana*.

    Strings without an entry in HIRAGANA_TO_ROMAJI (for example "SP") are
    passed through unchanged.
    """
    try:
        return HIRAGANA_TO_ROMAJI[hiragana]
    except KeyError:
        return hiragana

# ---------------------------
# Phoneme Decomposition Logic
# ---------------------------
# Kana -> LAB phoneme. Vowel kana map to their vowel. Consonant kana map only
# to their onset consonant; decompose_lyric() takes the vowel from the romaji
# spelling. Yoon kana are looked up here by their first character.
HIRAGANA_TO_PHONEME = {
    # Vowels (keep as is)
    "あ": "a", "い": "i", "う": "u", "え": "e", "お": "o",
    # Single CV - K
    "か": "k", "き": "k", "く": "k", "け": "k", "こ": "k",
    # Single CV - S
    "さ": "s", "し": "sh", "す": "s", "せ": "s", "そ": "s",
    # Single CV - T
    "た": "t", "ち": "ch", "つ": "ts", "て": "t", "と": "t",
    # Single CV - N
    "な": "n", "に": "n", "ぬ": "n", "ね": "n", "の": "n",
    # Single CV - H
    "は": "h", "ひ": "h", "ふ": "f", "へ": "h", "ほ": "h",
    # Single CV - M
    "ま": "m", "み": "m", "む": "m", "め": "m", "も": "m",
    # Single CV - R
    "ら": "r", "り": "r", "る": "r", "れ": "r", "ろ": "r",
    # Single CV - G
    "が": "g", "ぎ": "g", "ぐ": "g", "げ": "g", "ご": "g",
    # Single CV - Z
    "ざ": "z", "じ": "j", "ず": "z", "ぜ": "z", "ぞ": "z",
    # Single CV - D
    "だ": "d", "ぢ": "j", "づ": "z", "で": "d", "ど": "d",
    # Single CV - B
    "ば": "b", "び": "b", "ぶ": "b", "べ": "b", "ぼ": "b",
    # Single CV - P
    "ぱ": "p", "ぴ": "p", "ぷ": "p", "ぺ": "p", "ぽ": "p",
    # Single CV - Y/W
    "や": "y", "ゆ": "y", "よ": "y", "わ": "w", "を": "w",
    # Special
    "ん": "N"
}

# Vowel mapping for yoon: the small kana (second character) gives the vowel.
YOON_VOWEL_MAP = {
    "ゃ": "a", "ゅ": "u", "ょ": "o"
}

# Pitch-class names indexed by MIDI note number % 12 (index 0 = C).
NOTE_NAMES = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]


def extract_track_name_from_voicepart(voicepart_name: str) -> str:
    """
    Derive a short track name by keeping the last three underscore-separated
    fields of a voice_part name.

    Example:
        "mzo_neutral_001_A3_9" -> "001_A3_9"

    Names with fewer than three fields are returned unchanged.
    """
    fields = voicepart_name.split("_")
    if len(fields) < 3:
        return voicepart_name
    return "_".join(fields[-3:])



def generate_singer_tracks_from_voiceparts(
    voice_parts: list,
    default_singer: str,
    phonemizer: str,
    renderer_settings: dict
) -> list:
    """
    Build one USTX singer-track entry per distinct ``track_no`` in *voice_parts*.

    The track name comes from the ``name`` of the first voice part seen for
    each track number (via extract_track_name_from_voicepart). Later voice
    parts that repeat a track number are ignored.

    Returns:
        list: the track dicts, ordered by ascending ``track_no``.

    NOTE(review): every returned track references the same
    *renderer_settings* object; no copy is made.
    """
    tracks_by_number = {}

    for part in voice_parts:
        number = part["track_no"]  # 1-based
        if number in tracks_by_number:
            continue
        tracks_by_number[number] = {
            "singer": default_singer,
            "phonemizer": phonemizer,
            "renderer_settings": renderer_settings,
            "track_name": extract_track_name_from_voicepart(part["name"]),
            "track_color": "Blue",
            "mute": False,
            "solo": False,
            "volume": 0,
            "pan": 0,
            "track_expressions": [],
            "voice_color_names": [""],
        }

    return [tracks_by_number[number] for number in sorted(tracks_by_number)]

def pick_base_note_gaussian(
    base_note_min: int,
    base_note_max: int,
    batch_count: int,
    sigma_scale: float = 0.25
) -> int:
    """
    Pick a base note from a normal (Gaussian) distribution centred in the
    allowed range, re-drawing samples that fall outside it.

    Args:
        base_note_min: lower MIDI note bound (inclusive)
        base_note_max: upper MIDI note bound (inclusive). If the two bounds
            are given in reverse order they are swapped rather than failing.
        batch_count: total number of files being generated. Currently
            unused; kept for interface compatibility.
        sigma_scale: standard deviation as a fraction of the note range
            (0.15 = tight, 0.3 = wide). The deviation is never below one
            semitone.

    Returns:
        int: selected base note (MIDI), always within the bounds.
    """
    # Tolerate reversed bounds (e.g. form input with min > max); otherwise
    # no sample could ever be accepted and random.randint would raise.
    low, high = sorted((base_note_min, base_note_max))

    center = (low + high) / 2
    # Spread depends only on the range and sigma_scale (not on batch_count).
    sigma = max(1.0, (high - low) * sigma_scale)

    for _ in range(10):  # rejection sampling: retry to stay in bounds
        value = int(round(random.gauss(center, sigma)))
        if low <= value <= high:
            return value

    # Fallback (uniform random if gaussian fails repeatedly)
    return random.randint(low, high)


def midi_to_note_name(midi_note: int) -> str:
    """
    Return the scientific-pitch name of a MIDI note number (MIDI 60 -> C4).

    Examples:
        0  -> C-1
        12 -> C0
        24 -> C1
        30 -> F#1
        47 -> B2
    """
    octave_index, pitch_class = divmod(midi_note, 12)
    return NOTE_NAMES[pitch_class] + str(octave_index - 1)

def decompose_lyric(lyric):
    """
    Decompose a hiragana lyric into the phoneme list written to the LAB file.

    - "SP"            -> ["SP"]
    - vowel kana      -> [vowel], e.g. "あ" -> ["a"]
    - "ん"            -> ["N"]
    - single CV kana  -> [consonant, vowel], e.g. "た" -> ["t", "a"]
    - yoon (kana + small ゃ/ゅ/ょ) -> [palatal consonant, vowel],
      e.g. "きゃ" -> ["ky", "a"], "しゃ" -> ["sh", "a"], "じゃ" -> ["j", "a"]

    Anything that cannot be mapped is returned unchanged as a one-item list.
    """
    if lyric == "SP":
        return ["SP"]

    if lyric in VOWELS:
        # Vowels stay as a single phoneme
        return [HIRAGANA_TO_PHONEME.get(lyric, lyric)]

    if lyric == "ん":
        return ["N"]

    if len(lyric) == 1:
        consonant = HIRAGANA_TO_PHONEME.get(lyric, "")
        if not consonant:
            return [lyric]

        # The vowel is the last letter of the romaji spelling
        # ("ka" -> "a", "tsu" -> "u").
        romaji = HIRAGANA_TO_ROMAJI.get(lyric, "")
        if len(romaji) > 1:
            return [consonant, romaji[-1]]
        # Single-letter romaji: vowel or special case
        return [consonant]

    if len(lyric) == 2:  # yoon, e.g. "きゃ", "しょ"
        consonant = HIRAGANA_TO_PHONEME.get(lyric[0], "")
        vowel = YOON_VOWEL_MAP.get(lyric[1], "")
        if not (consonant and vowel):
            return [lyric]

        # Consonants that are already palatal keep their form: the
        # multi-letter ones (sh, ch, ts) and the single-letter "j" (じゃ is
        # spelled "ja", not "jya"). Plain consonants gain a "y".
        if len(consonant) > 1 or consonant == "j":
            return [consonant, vowel]
        return [consonant + "y", vowel]

    # Fallback
    return [lyric]

def get_hiragana_phoneme_map():
    """Map every kana in the lyric pools to its phoneme decomposition.

    Keys follow pool order: VOWELS, SINGLE_CV, YOON, then SPECIAL.
    """
    pools = (VOWELS, SINGLE_CV, YOON, SPECIAL)
    return {kana: decompose_lyric(kana) for pool in pools for kana in pool}

# Pre-compute the hiragana -> phoneme-list mapping once at import time; the
# lyric-weighting helpers look phonemes up here instead of calling
# decompose_lyric() for every candidate.
HIRAGANA_PHONEME_MAP = get_hiragana_phoneme_map()

# ---------------------------
# Template / Sample Config Loading
# ---------------------------
def load_ust_header(path=None):
    """
    Return the header lines of a UST file (everything before the first note).

    Reading stops at the first note section (``[#0000]``, ``[#0001]``, ...)
    or at a ``[#TRACK...`` section. Each line is stripped and blank lines
    are dropped.

    UTAU traditionally saves USTs as Shift-JIS, so the file is decoded as
    UTF-8 (with an optional BOM) first and then as cp932.

    Args:
        path: path to a template UST file, or None.

    Returns:
        list[str]: the header lines. A default ``[#SETTING]`` / ``Tempo=120``
        header is returned if *path* is None or missing, if it cannot be
        read or decoded, or if it contains no header lines.
    """
    default_header = ["[#SETTING]", "Tempo=120"]
    if path is None or not os.path.exists(path):
        return default_header

    text = None
    last_error = None
    for encoding in ("utf-8-sig", "cp932"):
        try:
            with open(path, 'r', encoding=encoding) as f:
                text = f.read()
            break
        except UnicodeDecodeError as e:
            # Try the next encoding
            last_error = e
        except OSError as e:
            print(f"Warning: Could not read UST header: {e}")
            return default_header

    if text is None:
        print(f"Warning: Could not read UST header: {last_error}")
        return default_header

    header = []
    # Text mode already normalised newlines to "\n".
    for line in text.split("\n"):
        line = line.strip()
        is_note_section = line.startswith('[#') and line[2:3].isdigit()
        if is_note_section or line.startswith('[#TRACK'):
            break
        if line:
            header.append(line)

    # An empty template would leave the UST without [#SETTING].
    return header or default_header

def detect_lab_format(path=None):
    """
    Detect the LAB layout of a sample file.

    Returns:
        'htk'    if the first non-blank line looks timed
                 (``start end phoneme`` with a numeric first field), or
                 when *path* is None or does not exist.
        'simple' otherwise (untimed ``phn phn ...``), including an empty
                 file or one that cannot be read or decoded.
    """
    if path is None or not os.path.exists(path):
        return 'htk'  # Default to HTK format
    try:
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.split()
                if not parts:
                    continue  # skip blank lines
                # HTK format: start_time end_time phoneme
                if len(parts) >= 3 and parts[0].replace('.', '', 1).isdigit():
                    return 'htk'
                break  # only the first non-blank line is inspected
    except (OSError, UnicodeDecodeError) as e:
        # Best effort: an unreadable sample still falls back to 'simple',
        # but no longer silently (a bare except also hid KeyboardInterrupt).
        print(f"Warning: Could not read LAB sample: {e}")
    return 'simple'

# Pre-load configuration
# NOTE(review): no template paths are passed, so these are always the
# built-in defaults (["[#SETTING]", "Tempo=120"] and 'htk'). Pass sample
# file paths here to copy a real UST header / LAB layout instead.
UST_HEADER_TEMPLATE = load_ust_header()
LAB_OUTPUT_FORMAT = detect_lab_format()
print(f"Generator Configured: UST Header lines={len(UST_HEADER_TEMPLATE)}, LAB Format={LAB_OUTPUT_FORMAT}")

# -----------------------------------
# OpenUTAU Pitch Curves & Vibrato Logic
# -----------------------------------
def generate_pitch_curve(note_length_ticks, current_note_pitch, next_note_length_ticks=None, next_note_pitch=None):
    """
    Generate PBS, PBW, PBY, and PBM following UTAU Mode2 pitchbend rules.

    PBS: start_time;start_pitch
      - start_time: 10-25% of current note length, negative value (ms)
      - start_pitch: based on pitch difference between current and next note (decacents)

    PBW: comma-separated widths in ms between control points
      - Final PBW should not exceed 25% of next note's length
      - Calculated based on pitchBendTotalWidth

    PBY: comma-separated pitch offsets in decacents for control points after first
      - Final point must have PBY = 0
      - All values between -200 and 200

    PBM: interpolation mode string (empty = S-curve, 's' = linear, 'j' = J-curve, 'r' = R-curve)

    All randomness comes from the global ``random`` module, so the output is
    reproducible only under a fixed seed and a fixed call order.

    When the available bend width is under 200 ms, num_points is 1 and PBW,
    PBY and PBM are all returned as empty strings; only PBS carries data.

    NOTE(review): start_pitch is derived from the *next* note and is not
    clamped, so an interval above 20 semitones would exceed the +-200
    decacent range. In UTAU, PBS usually describes the approach from the
    previous note; confirm that using the next note is intended.

    Args:
        note_length_ticks: Length of the current note in ticks
        current_note_pitch: MIDI note number of current note
        next_note_length_ticks: Length of next note in ticks (None if last note)
        next_note_pitch: MIDI note number of next note (None if last note)

    Returns:
        tuple: (PBS, PBW, PBY, PBM) strings
    """
    # Convert note length to milliseconds
    note_length_ms = note_length_ticks * TICK_TO_MS

    # Generate PBS: start_time;start_pitch
    # a. start_time: 10-25% of current note's length, negative value
    start_time_percent = random.uniform(0.10, 0.25)
    start_time = -(note_length_ms * start_time_percent)

    # b. start_pitch: based on pitch difference between current and next note
    # 1 semitone = 10 decacents
    # If next note is higher, start_pitch is negative (slide up)
    # If next note is lower, start_pitch is positive (slide down)
    if next_note_pitch is not None:
        pitch_diff_semitones = next_note_pitch - current_note_pitch
        start_pitch = -pitch_diff_semitones * 10  # Convert to decacents, negate for slide direction
    else:
        # Last note: use small random value
        start_pitch = random.uniform(-20, 20)

    # Format PBS with appropriate precision
    # (near-integers are written as ints; otherwise a random precision is used)
    if abs(start_time - round(start_time)) < 0.001:
        start_time_str = str(int(round(start_time)))
    else:
        start_time_str = f"{start_time:.5f}"

    if abs(start_pitch - round(start_pitch)) < 0.001:
        start_pitch_str = str(int(round(start_pitch)))
    else:
        precision = random.choice([1, 2, 5])
        start_pitch_str = f"{start_pitch:.{precision}f}"

    pbs = f"{start_time_str};{start_pitch_str}"

    # Generate PBW & PBY
    # 1. Final point should have PBY = 0, PBW should not be more than 25% of next note's length
    if next_note_length_ticks is not None:
        next_note_length_ms = next_note_length_ticks * TICK_TO_MS
        max_final_position = next_note_length_ms * 0.25  # Maximum absolute position (ms from note start)
    else:
        # Last note: use note_length_ms as max
        max_final_position = note_length_ms * 0.25

    # 2. Define pitchBendTotalWidth = last_pbw_point - start_time
    # last_pbw_point is the final absolute position, must be <= max_final_position
    # start_time is negative, so we need to ensure: start_time + pitchBendTotalWidth <= max_final_position
    # Therefore: pitchBendTotalWidth <= max_final_position - start_time

    # Calculate maximum allowed pitchBendTotalWidth
    max_pitchBendTotalWidth = max_final_position - start_time

    # Choose a final position within the allowed range
    # Final position should be between start_time and max_final_position
    # We want it to be reasonable, so let's use a portion of the available range
    # But we must ensure it never exceeds max_final_position
    target_final_position = note_length_ms * random.uniform(0.5, 0.9)

    # Clamp to maximum allowed position
    final_position = min(target_final_position, max_final_position)

    # Ensure final_position is reasonable (at least some distance from start_time)
    # Since start_time is negative, final_position should be positive
    final_position = max(final_position, abs(start_time) + 20)  # At least |start_time| + 20 ms after the note start

    # Re-clamp after adjustment
    final_position = min(final_position, max_final_position)

    # Calculate pitchBendTotalWidth
    pitchBendTotalWidth = final_position - start_time

    # Ensure we don't exceed the maximum allowed
    pitchBendTotalWidth = min(pitchBendTotalWidth, max_pitchBendTotalWidth)

    # Recalculate final_position based on clamped pitchBendTotalWidth
    final_position = start_time + pitchBendTotalWidth

    # 3. PBW values should be between PBS start_time and final PBW point
    # 4. num_points = floor(pitchBendTotalWidth / 100), i.e. one point per 100 ms (minimum 1)
    num_points = max(1, int(math.floor(pitchBendTotalWidth / 100)))

    # 5. Calculate PBW using formula: (pitchBendTotalWidth * num_points_index) / (num_points + 1)
    # This gives cumulative positions, we need to convert to widths (differences)
    if num_points == 1:
        # Single point: no widths needed
        pbw_values = []
    else:
        # Calculate cumulative positions
        positions = []
        positions.append(start_time)  # Position 0 = start_time
        for i in range(1, num_points + 1):  # i from 1 to num_points
            # Cumulative position from start_time
            # Note: the last position is start_time + W * n / (n + 1), so the
            # last control point falls short of final_position.
            position = start_time + (pitchBendTotalWidth * i) / (num_points + 1)
            positions.append(position)

        # Convert positions to widths (differences between consecutive positions)
        pbw_values = []
        for i in range(len(positions) - 1):
            width = positions[i + 1] - positions[i]
            pbw_values.append(width)

        # Critical check: Verify final cumulative position is within limit
        # The final position is: start_time + sum of all PBW values
        # NOTE(review): for non-negative note lengths,
        # sum(pbw_values) = W * n / (n + 1) <= W <= max_pitchBendTotalWidth,
        # so this branch and the "last resort" below appear unreachable.
        final_cumulative = start_time + sum(pbw_values)

        if final_cumulative > max_final_position:
            # We exceeded the limit - need to adjust
            excess = final_cumulative - max_final_position

            # Option 1: Reduce the last PBW value
            if len(pbw_values) > 0 and pbw_values[-1] > excess:
                pbw_values[-1] = pbw_values[-1] - excess
            else:
                # Option 2: Scale down all PBW values proportionally
                if sum(pbw_values) > 0:
                    scale_factor = (max_final_position - start_time) / sum(pbw_values)
                    pbw_values = [w * scale_factor for w in pbw_values]

        # Final verification: double-check the cumulative position
        final_cumulative_check = start_time + sum(pbw_values)
        if final_cumulative_check > max_final_position:
            # Last resort: clamp to exactly the limit
            total_width = sum(pbw_values)
            if total_width > 0:
                scale = (max_final_position - start_time) / total_width
                pbw_values = [w * scale for w in pbw_values]

    # Format PBW values
    if not pbw_values:
        pbw = ""
    else:
        pbw_strs = []
        for val in pbw_values:
            if abs(val - round(val)) < 0.001:
                pbw_strs.append(str(int(round(val))))
            else:
                precision = random.choice([0, 1, 2, 5])
                pbw_strs.append(f"{val:.{precision}f}")
        pbw = ",".join(pbw_strs)

    # Generate PBY: pitch offsets in decacents for control points after first
    # PBY has num_points values (for points 1 through num_points)
    # 1. Final point must have PBY = 0
    # 4. PBY can only be between -200 and 200
    if num_points == 1:
        # Single point: no PBY values needed
        pby = ""
    else:
        pby_values = []
        current_pitch = start_pitch  # updated below but never read

        # Generate pitches for points 1 through num_points-1 (not including final)
        for i in range(num_points - 1):
            # Interpolate towards 0 (final point)
            # Gradually move from start_pitch towards 0
            progress = (i + 1) / num_points  # 0 to 1
            target_pitch = start_pitch * (1 - progress)

            # Add some randomness but keep within -200 to 200
            variation = random.uniform(-20, 20)
            new_pitch = target_pitch + variation
            new_pitch = max(-200, min(200, new_pitch))

            pby_values.append(new_pitch)
            current_pitch = new_pitch

        # Final point: PBY = 0
        pby_values.append(0)

        # Format PBY values
        pby_strs = []
        for val in pby_values:
            if abs(val - round(val)) < 0.001:
                pby_strs.append(str(int(round(val))))
            else:
                precision = random.choice([1, 2, 5])
                pby_strs.append(f"{val:.{precision}f}")
        pby = ",".join(pby_strs)

    # Generate PBM: interpolation mode string
    # PBM has num_points characters (one for each interval)
    # Empty/blank = S-curve (default), 's' = linear, 'j' = J-curve, 'r' = R-curve
    if num_points == 1:
        # Single point: no intervals, so no PBM
        pbm = ""
    elif random.random() < 0.8:  # 80% chance for default S-curve (empty)
        pbm = ""
    else:  # 20% chance for other modes
        modes = ['s', 'j', 'r']
        pbm = "".join(random.choice(modes) for _ in range(num_points))

    return pbs, pbw, pby, pbm

def vibrato_params():
    """
    Return a random UST ``VBR`` value, or "" when no vibrato is chosen.

    The mode is drawn with weights none 30 / medium 50 / high 20. For the
    medium and high presets, the first three fields are randomised below the
    preset maxima. The fade-in and fade-out fields come straight from the
    preset, and the last two fields are always 0.
    """
    chosen, = random.choices(["none", "medium", "high"], weights=[30, 50, 20])
    if chosen == "none":
        return ""

    # preset -> (len, freq, depth, in, out)
    presets = {
        "medium": (60, 150, 25.0, 30, 30),
        "high": (75, 220, 45.0, 40, 40),
    }
    max_len, max_freq, max_depth, fade_in, fade_out = presets[chosen]

    length = random.randint(int(max_len * 0.6), max_len)
    freq = random.randint(100, min(max_freq, 250))
    depth = round(random.uniform(max_depth * 0.6, max_depth), 2)

    fields = (length, freq, depth, fade_in, fade_out, 0, 0)
    return ",".join(str(field) for field in fields)

# ---------------------------
# Random Selection & Pickers
# ---------------------------
def pick_lyric():
    """
    Pick a random lyric according to the LYRIC_WEIGHTS pool distribution.

    LYRIC_WEIGHTS repeats pool names, so a uniform choice from it selects
    pools in proportion to their counts. Pools map as follows:
    "vowel" -> VOWELS, "single" -> SINGLE_CV, "special" -> SPECIAL,
    anything else ("double") -> YOON.
    """
    pool = random.choice(LYRIC_WEIGHTS)
    if pool == "vowel":
        return random.choice(VOWELS)
    if pool == "single":
        return random.choice(SINGLE_CV)
    if pool == "special":
        # Without this branch "special" fell through to YOON, so "ん" could
        # never be produced by this picker.
        return random.choice(SPECIAL)
    return random.choice(YOON)

def calculate_phoneme_weights(current_counts, total_count, target_ratio):
    """
    Compute a selection weight for every kana from how far its phonemes are
    from their target share.

    For each phoneme of a kana, ``ratio`` is its count as a percentage of
    *total_count*. "SP" and phonemes without a positive target are ignored.

    - If ratio < 70% of target, the boost is
      1 + 3 * (0.7 * target - ratio) / target, i.e. between 1x and 3.1x.
    - If ratio > 150% of target, the reduction is
      1 - 0.3 * (ratio - target) / target, floored at 0.1,
      i.e. between 0.85x and 0.1x.

    Args:
        current_counts: mapping phoneme -> count generated so far
        total_count: total number of phonemes generated so far
        target_ratio: mapping phoneme -> target percentage

    Returns:
        dict: kana -> weight. Every weight is 1.0 when nothing has been
        counted yet.
    """
    all_hiragana = VOWELS + SINGLE_CV + YOON + SPECIAL

    if total_count <= 0:
        # Equal weights at start
        return {h: 1.0 for h in all_hiragana}

    weights = {}
    for hiragana in all_hiragana:
        phonemes = HIRAGANA_PHONEME_MAP.get(hiragana, [])
        if not phonemes:
            weights[hiragana] = 1.0
            continue

        boost_factor = 1.0
        for phoneme in phonemes:
            if phoneme == "SP":  # Skip SP in ratio calculations
                continue

            current_ratio = current_counts.get(phoneme, 0) / total_count * 100
            target = target_ratio.get(phoneme, 0)
            if target <= 0:
                continue

            threshold = target * 0.7
            if current_ratio < threshold:
                # Boost more the further below 70% of target (max 3.1x).
                deficit = threshold - current_ratio
                boost = 1.0 + (deficit / target) * 3.0
                boost_factor = max(boost_factor, boost)
            elif current_ratio > target * 1.5:
                # Reduce when well above target (down to 0.1x).
                excess = current_ratio - target
                reduction = 1.0 - (excess / target) * 0.3
                # NOTE(review): boosts combine with max() and reductions with
                # min(), so for a kana with one under- and one
                # over-represented phoneme the result depends on phoneme
                # order. Confirm this is intended.
                boost_factor = min(boost_factor, max(0.1, reduction))

        weights[hiragana] = boost_factor

    return weights

def pick_lyric_weighted(
    current_counts,
    total_count,
    target_ratio,
    blocked_phonemes=None
):
    """
    Pick a lyric, favouring kana whose phonemes are below their target share.

    A pool (vowel / single CV / yoon / "ん") is chosen first, weighted by the
    mean weight of its kana. A kana is then chosen within that pool.

    Args:
        current_counts: dict of current phoneme counts, or None
        total_count: total number of phonemes generated
        target_ratio: dict of target ratios, or None
        blocked_phonemes: optional collection of phonemes. Any kana that
            produces one of them gets weight 0 and is never picked.

    Returns:
        str: selected hiragana

    Raises:
        ValueError: if every candidate kana is blocked.

    NOTE(review): when current_counts or target_ratio is None this falls back
    to pick_lyric(), which ignores blocked_phonemes.
    """
    if current_counts is None or target_ratio is None:
        # Fallback to original random selection
        return pick_lyric()

    if blocked_phonemes is None:
        blocked_phonemes = set()

    weights = calculate_phoneme_weights(current_counts, total_count, target_ratio)

    for h, phs in HIRAGANA_PHONEME_MAP.items():
        if any(p in blocked_phonemes for p in phs):
            weights[h] = 0.0

    # Separate by pool
    vowel_weights = [weights.get(v, 1.0) for v in VOWELS]
    single_weights = [weights.get(s, 1.0) for s in SINGLE_CV]
    yoon_weights = [weights.get(y, 1.0) for y in YOON]
    special_weight = weights.get("ん", 1.0)

    def _mean(values):
        # Mean weight of a pool (1.0 for an empty pool)
        return sum(values) / len(values) if values else 1.0

    pool_weights = [
        _mean(vowel_weights),
        _mean(single_weights),
        _mean(yoon_weights),
        special_weight,
    ]

    # random.choices would otherwise raise an opaque "Total of weights must
    # be greater than zero" (Python 3.9+) or, on older versions, silently
    # return a blocked kana.
    if sum(pool_weights) <= 0:
        raise ValueError(
            f"pick_lyric_weighted: every candidate lyric is blocked "
            f"(blocked_phonemes={blocked_phonemes!r})"
        )

    pool = random.choices(["vowel", "single", "yoon", 'special'], weights=pool_weights)[0]

    if pool == "vowel":
        return random.choices(VOWELS, weights=vowel_weights)[0]
    elif pool == "single":
        return random.choices(SINGLE_CV, weights=single_weights)[0]
    elif pool == "special":
        return "ん"
    else:
        return random.choices(YOON, weights=yoon_weights)[0]

def pick_length(length_mode="normal"):
    """
    Pick a note length (in ticks) from NOTE_LENGTHS with mode-specific weights.

    Weights are given in NOTE_LENGTHS order ([240, 360, 480, 600, 720]):
        "normal": 1, 1, 1, 1, 1  -> uniform
        "short":  2, 4, 2, 1, 0  -> mostly 360, never 720
        "long":   0, 3, 4, 3, 2  -> mostly 480, never 240
    Any other mode falls back to "normal".

    Args:
        length_mode: "normal", "short" or "long"

    Returns:
        int: Note length in ticks
    """
    mode_weights = {
        "normal": [1, 1, 1, 1, 1],
        "short": [2, 4, 2, 1, 0],
        "long": [0, 3, 4, 3, 2],
    }
    weights = mode_weights.get(length_mode, mode_weights["normal"])
    return random.choices(NOTE_LENGTHS, weights=weights)[0]

def pick_note(base_note, note_range, max_note=106):
    """
    Pick a random MIDI note in [base_note, base_note + note_range], capped at
    *max_note*.

    If *base_note* itself is above the cap, *max_note* is returned instead
    of letting random.randint raise ValueError on an empty range.

    Args:
        base_note: lowest note (MIDI) of the range
        note_range: width of the range in semitones (expected >= 0)
        max_note: highest note that may be returned (default 106, the
            previously hard-coded cap)

    Returns:
        int: the chosen MIDI note
    """
    high = min(base_note + note_range, max_note)
    low = min(base_note, high)
    return random.randint(low, high)

def validate_phoneme_ratios(
    phoneme_counts,
    target_ratio,
    under_threshold=0.7,
    over_threshold=1.1
):
    """
    Compare each phoneme's share of all counted phonemes with its target.

    A phoneme is underrepresented when its percentage is below
    ``target * under_threshold``. It is overrepresented when its percentage
    is above ``target * over_threshold``. "SP" and "trash" are skipped.
    Each underrepresented phoneme is also printed as it is found.

    Returns:
        tuple: (is_valid, under, over, report). ``under`` and ``over`` map
        phoneme -> {"current", "target", "threshold"}, and ``report`` is a
        human-readable summary. With nothing counted the result is
        ``(False, {}, {}, "No phonemes counted")``.
    """
    total = sum(phoneme_counts.values())
    if total == 0:
        return False, {}, {}, "No phonemes counted"

    under, over = {}, {}
    for phoneme, target in target_ratio.items():
        if phoneme in ("SP", "trash"):
            continue

        ratio = phoneme_counts.get(phoneme, 0) / total * 100
        low_limit = target * under_threshold
        high_limit = target * over_threshold

        if ratio < low_limit:
            print(f"phoneme {phoneme} is underepresented (ratio: {ratio},threshold: {low_limit})")
            under[phoneme] = {"current": ratio, "target": target, "threshold": low_limit}
        elif ratio > high_limit:
            over[phoneme] = {"current": ratio, "target": target, "threshold": high_limit}

    is_valid = not (under or over)

    lines = [
        f"Validation: {'PASS' if is_valid else 'FAIL'}",
        f"Total phonemes: {total}",
        f"Underrepresented: {len(under)}",
        f"Overrepresented: {len(over)}",
    ]
    sections = (
        ("\nUnderrepresented phonemes:", under, "<"),
        ("\nOverrepresented phonemes:", over, ">"),
    )
    for title, bucket, sign in sections:
        if not bucket:
            continue
        lines.append(title)
        lines.extend(
            f"  {name}: {info['current']:.2f}% {sign} {info['threshold']:.2f}%"
            for name, info in bucket.items()
        )

    return is_valid, under, over, "\n".join(lines)


# -----------------------
# Generator Function
# -----------------------
def generate_ust_lab(
    out_dir,
    filename,
    base_note,
    note_range,
    length_variant,
    extra_sp,
    length_mode="normal",
    use_ratio_targets=False,
    blocked_phonemes=None,
    ust_flags="g0B0H0P86"
):
    """
    Generate one random UST file and its matching LAB file.

    The note sequence is framed by an SP note at the start and at the end. In
    between, notes of random length (``pick_length(length_mode)``) and pitch
    (``pick_note(base_note, note_range)``) are added, starting after the first
    SP. Generation stops at the first note that would push past
    ``TOTAL_TICKS - SP_TICK``. Then up to ``extra_sp`` SP rests are inserted at
    random positions between notes. These rests come on top of that budget, so
    the file can run longer than TOTAL_TICKS.

    Args:
        out_dir: Directory that receives ``<filename>.ust`` and
            ``<filename>.lab``; created if missing.
        filename: Base file name without extension.
        base_note: MIDI note passed to ``pick_note``; also the pitch of SP notes.
        note_range: Semitone range passed to ``pick_note``.
        length_variant: Currently unused; kept for interface compatibility.
        extra_sp: Maximum number of SP rests inserted between notes.
        length_mode: Passed to ``pick_length``.
        use_ratio_targets: When True, lyrics come from ``pick_lyric_weighted``,
            steered by TARGET_RATIO and this file's running phoneme counts.
            Otherwise they come from ``pick_lyric``.
        blocked_phonemes: Phonemes passed to ``pick_lyric_weighted`` as blocked.
        ust_flags: Value written to every note's ``Flags=`` line.
    """
    ust = list(UST_HEADER_TEMPLATE)  # Start with sample header
    lab = []
    if blocked_phonemes is None:
        blocked_phonemes = set()

    note_idx = 0
    # Running LAB clock in HTK units (TICK_TO_100NS scale). It is a float so
    # equal per-phoneme splits do not accumulate truncation error; each
    # boundary is truncated only when written.
    current_time_ns = 0.0

    # Per-file phoneme counts that steer weighted lyric selection.
    phoneme_counts = Counter()
    total_phoneme_count = 0

    def add_note_data(length, lyric, note_num, next_length=None, next_note=None):
        """Append one UST note entry and this note's phonemes to the LAB list."""
        nonlocal note_idx, current_time_ns

        # 1. UST entry (lyric converted from hiragana to romaji; SP kept as-is).
        pbs, pbw, pby, pbm = generate_pitch_curve(
            length,
            note_num,
            next_length,
            next_note
        )
        vbr = vibrato_params()  # Random vibrato per note
        ust_lyric = hiragana_to_romaji(lyric) if lyric != "SP" else "SP"

        # Each line is a separate f-string. The entry must not use
        # %-formatting (a '%' in the flags or pitch data would break it), and
        # no comment text may end up inside the Flags= value.
        entry_lines = [
            f"[#{note_idx:04d}]",
            f"Length={length}",
            f"Lyric={ust_lyric}",
            f"NoteNum={note_num}",
            "PreUtterance=",
            "Velocity=100",
            "Intensity=100",
            "Modulation=0",
            f"Flags={ust_flags}",
            f"PBS={pbs}",
            f"PBW={pbw}",
            f"PBY={pby}",
            f"PBM={pbm}",
        ]
        # Add VBR only if vibrato is enabled
        if vbr:
            entry_lines.append(f"VBR={vbr}")
        ust.append("\n".join(entry_lines) + "\n")

        note_idx += 1

        # 2. LAB entries (decomposed from the original hiragana lyric).
        phons = decompose_lyric(lyric)
        note_duration_ns = length * TICK_TO_100NS

        if LAB_OUTPUT_FORMAT == 'htk':
            # Equal split of the note duration across its phonemes.
            part_dur = note_duration_ns / len(phons)

            for ph in phons:
                start_t = int(current_time_ns)
                end_t = int(current_time_ns + part_dur)
                lab.append(f"{start_t} {end_t} {ph}")
                current_time_ns += part_dur
        else:
            # Simple list format
            lab.extend(phons)

    # Collect all sung notes first (the framing SPs are added separately).
    notes = []
    temp_tick = SP_TICK  # Account for initial SP

    while temp_tick < TOTAL_TICKS - SP_TICK:
        length = pick_length(length_mode)
        if temp_tick + length > TOTAL_TICKS - SP_TICK:
            break

        if use_ratio_targets:
            lyric = pick_lyric_weighted(
                current_counts=dict(phoneme_counts),
                total_count=total_phoneme_count,
                target_ratio=TARGET_RATIO,
                blocked_phonemes=blocked_phonemes
            )
            # Count this lyric's phonemes now so the next pick sees them.
            for ph in decompose_lyric(lyric):
                if ph != "SP":
                    phoneme_counts[ph] += 1
                    total_phoneme_count += 1
        else:
            lyric = pick_lyric()

        note = pick_note(base_note, note_range)
        notes.append((length, lyric, note))
        temp_tick += length

    # Insert extra SP rests between notes (never before the first or after the last).
    if extra_sp > 0 and len(notes) > 0:
        num_insert_positions = len(notes) - 1
        if num_insert_positions > 0:
            sp_positions = random.sample(
                range(num_insert_positions),
                min(extra_sp, num_insert_positions)
            )
            # Insert from the end so earlier positions stay valid.
            sp_positions.sort(reverse=True)
            for pos in sp_positions:
                notes.insert(pos + 1, (SP_TICK, "SP", base_note))

    # Initial SP, with the first note's info for the pitch transition.
    if notes:
        next_length, _, next_note = notes[0]
    else:
        next_length, next_note = None, None
    add_note_data(SP_TICK, "SP", base_note, next_length, next_note)

    # All notes (regular + inserted SP).
    for i, (length, lyric, note) in enumerate(notes):
        if i + 1 < len(notes):
            next_length, _, next_note = notes[i + 1]
        else:
            # Last note transitions into the final SP.
            next_length, next_note = SP_TICK, base_note
        add_note_data(length, lyric, note, next_length, next_note)

    # Final SP (no next note)
    add_note_data(SP_TICK, "SP", base_note, None, None)

    ust.append("[#TRACKEND]")

    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    (out_dir / f"{filename}.ust").write_text("\n".join(ust), encoding="utf-8")

    if LAB_OUTPUT_FORMAT == 'htk':
        if lab:
            # Trailing "trash" segment that starts where the last phoneme ends.
            # Its length is 100 HTK units (100 x 100 ns = 10 µs).
            last_parts = lab[-1].strip().split()
            if len(last_parts) >= 2:
                trash_start = int(last_parts[1])
                trash_end = trash_start + 100
                lab.append(f"{trash_start} {trash_end} trash")
        (out_dir / f"{filename}.lab").write_text("\n".join(lab), encoding="utf-8")
    else:
        (out_dir / f"{filename}.lab").write_text(" ".join(lab), encoding="utf-8")

# -----------------------
# Phoneme Counter Function
# -----------------------
def count_phonemes_from_lab_dir(lab_dir):
    """
    Tally phoneme labels across every ``*.lab`` file directly inside ``lab_dir``.

    A file is read as HTK when its first line has at least three fields and the
    first field is numeric; the phoneme is then the third field of every line
    with at least three fields. Any other file is treated as whitespace-separated
    phonemes. ``SP`` is never counted, while every other label (including
    ``trash``) is. Empty files are skipped. A file that cannot be read produces
    a printed warning instead of an exception.

    Args:
        lab_dir: Directory path (str or Path).

    Returns:
        dict: phoneme -> count. Empty if the directory does not exist or holds
        no ``.lab`` files.
    """
    folder = Path(lab_dir)
    if not folder.exists():
        return {}

    paths = list(folder.glob("*.lab"))
    if not paths:
        return {}

    tally = Counter()
    for path in paths:
        try:
            with open(path, 'r', encoding='utf-8') as handle:
                text = handle.read().strip()
                if not text:
                    continue

                rows = text.split('\n')
                head = rows[0].strip().split()
                htk = len(head) >= 3 and head[0].replace('.', '', 1).isdigit()

                if htk:
                    # "start end phoneme" rows; take the third column.
                    split_rows = (row.strip().split() for row in rows)
                    labels = (fields[2] for fields in split_rows if len(fields) >= 3)
                else:
                    labels = text.split()

                tally.update(label for label in labels if label != "SP")
        except Exception as exc:
            print(f"Warning: Could not read {path}: {exc}")

    return dict(tally)


# -----------------------
# Prepare Output Folder
# -----------------------
output_root = Path(MOUNTED_DRIVE_DIR) / FILE_PREFIX

# Start from an empty output directory so .ust/.lab files left over from a
# previous run do not leak into this batch or its phoneme counts.
# (is_dir() is False for a missing path, so no separate exists() check is needed.)
if output_root.is_dir():
    shutil.rmtree(output_root)
output_root.mkdir(parents=True, exist_ok=True)

# Phonemes to block during lyric selection; replaced after every generated file.
blocked_phonemes = set()

# -----------------------
# Generate UST + LAB
# -----------------------
for i in range(1, BATCH_COUNT + 1):
    # Per-file random parameters (draw order kept stable for reproducibility).
    base_note = pick_base_note_gaussian(
        base_note_min=BASE_NOTE_MIN,
        base_note_max=BASE_NOTE_MAX,
        batch_count=BATCH_COUNT,
        sigma_scale=SIGMA_VALUE,
    )
    note_range = random.randint(*NOTE_RANGE_RANGE)
    length_variant = round(random.uniform(*LENGTH_VARIANT_RANGE), 2)
    extra_sp = random.randint(*EXTRA_SP_RANGE)
    length_mode = random.choice(LENGTH_MODES)

    base_note_name = midi_to_note_name(base_note)
    filename = f"{PREFIX}{i:03d}_{base_note_name}_{note_range}"

    generate_ust_lab(
        out_dir=output_root,
        filename=filename,
        base_note=base_note,
        note_range=note_range,
        length_variant=length_variant,
        extra_sp=extra_sp,
        length_mode=length_mode,
        use_ratio_targets=True,
        blocked_phonemes=blocked_phonemes,
        ust_flags=UST_FLAGS,
    )

    # This validates the CUMULATIVE dataset, not just the file written above:
    # count_phonemes_from_lab_dir re-reads every .lab in the output folder.
    counts = count_phonemes_from_lab_dir(output_root)
    valid, under, over, report = validate_phoneme_ratios(counts, TARGET_RATIO)

    print(f"\nCumulative validation after {filename}:")
    print(report)

    # Block phonemes that are overrepresented dataset-wide in the next file.
    blocked_phonemes = set(over)

# -----------------------
# Zip Output
# -----------------------
zip_path = Path(MOUNTED_DRIVE_DIR) / f"{FILE_PREFIX}.zip"

if zip_path.exists():
    zip_path.unlink()

# make_archive appends ".zip" itself, so pass the path without that suffix.
# with_suffix("") removes only the trailing ".zip". The former
# str.replace(".zip", "") also removed ".zip" anywhere else in the path.
shutil.make_archive(
    base_name=str(zip_path.with_suffix("")),
    format="zip",
    root_dir=output_root,
)

# -----------------------
# Count Phonemes
# -----------------------
print("\nBatch generation complete.")
print(f"Zipped dataset saved to:\n{zip_path}")

phoneme_counts = count_phonemes_from_lab_dir(output_root)
if phoneme_counts:
    total_phonemes = sum(phoneme_counts.values())
    divider = "=" * 50
    print(f"\n{divider}")
    print("PHONEME STATISTICS")
    print(divider)
    print(f"Total phonemes: {total_phonemes}")
    print(f"Unique phonemes: {len(phoneme_counts)}")
    print("\nPhoneme counts (sorted by frequency):")
    print("-" * 50)
    for phoneme, count in sorted(phoneme_counts.items(), key=lambda x: x[1], reverse=True):
        percentage = (count / total_phonemes) * 100
        print(f"  {phoneme:10s}: {count:6d} ({percentage:5.2f}%)")
    print(divider)

    # Final ratio validation over the whole dataset.
    print(f"\n{divider}")
    print("PHONEME RATIO VALIDATION")
    print(divider)
    is_valid, under, over, report = validate_phoneme_ratios(phoneme_counts, TARGET_RATIO, under_threshold=0.7)
    # `report` already begins with "Validation: PASS/FAIL"; printing that line
    # separately as well would duplicate it.
    print(report)
    print(divider)
else:
    print("\nNo LAB files found to count phonemes.")

Generator Configured: UST Header lines=2, LAB Format=htk
phoneme u is underepresented (ratio: 6.0606060606060606,threshold: 7.699999999999999)
phoneme e is underepresented (ratio: 6.0606060606060606,threshold: 7.699999999999999)
phoneme k is underepresented (ratio: 0.0,threshold: 2.8)
phoneme n is underepresented (ratio: 3.0303030303030303,threshold: 3.5)
phoneme s is underepresented (ratio: 0.0,threshold: 1.4)
phoneme m is underepresented (ratio: 0.0,threshold: 1.4)
phoneme p is underepresented (ratio: 0.0,threshold: 1.4)
phoneme g is underepresented (ratio: 0.0,threshold: 1.4)
phoneme d is underepresented (ratio: 0.0,threshold: 1.4)
phoneme z is underepresented (ratio: 0.0,threshold: 1.0499999999999998)
phoneme y is underepresented (ratio: 0.0,threshold: 0.7)
phoneme w is underepresented (ratio: 0.0,threshold: 1.0499999999999998)
phoneme ky is underepresented (ratio: 0.0,threshold: 0.84)
phoneme gy is underepresented (ratio: 0.0,threshold: 0.7)
phoneme ry is underepresented (ratio: 0

In [None]:
#@markdown # 🛠 USTX Transform
#@markdown
#@markdown This tool fixes and normalizes **USTX files** generated from imported USTs,
#@markdown making them **dataset-ready** for easier WAV export and DiffSinger / OpenUtau workflows.
#@markdown
#@markdown ---
#@markdown
#@markdown ## ✨ What This Script Does
#@markdown
#@markdown When executed, this script will:
#@markdown
#@markdown - ✅ **Automatically rename singer tracks**
#@markdown   - Track names are derived from associated `voice_parts`
#@markdown   - Example:
#@markdown     - `singer_neutral_001_A3_9` → `001_A3_9`
#@markdown
#@markdown - 🔁 **Propagate singer settings forward**
#@markdown   - Tracks without a `singer` will inherit:
#@markdown     - singer
#@markdown     - phonemizer
#@markdown     - renderer_settings
#@markdown     - voice_color_names
#@markdown     - track_expressions
#@markdown   - Source is the nearest previous singer track
#@markdown
#@markdown - 🎛 **Replicate phoneme expressions**
#@markdown   - `phoneme_expressions` from:
#@markdown     - **Track 1 (index 0) → First Note**
#@markdown   - Are copied to:
#@markdown     - all notes
#@markdown     - all voice parts
#@markdown     - all tracks
#@markdown
#@markdown This ensures **consistent expression behavior** and **clean track naming**
#@markdown for batch WAV export.
#@markdown
#@markdown ---
#@markdown
#@markdown ## 📦 Prerequisites
#@markdown
#@markdown Before running the script, make sure:
#@markdown
#@markdown 1. You have a valid **`.ustx` file**
#@markdown    - Generated by importing UST files
#@markdown
#@markdown 2. **Track 1 (index 0) is fully configured**
#@markdown    - Must already contain:
#@markdown      - `singer`
#@markdown      - `phonemizer`
#@markdown      - renderer settings
#@markdown
#@markdown 3. Expression replication (optional)
#@markdown    - If you want expressions copied, set them on:
#@markdown      - **Track 1 → First Note**
#@markdown    - Examples:
#@markdown      - Gender
#@markdown      - Velocity
#@markdown      - Dynamics
#@markdown      - moresampler parameters (`Mt`, `Mb`, `Mo`, etc.)
#@markdown
#@markdown ⚠️ Only expressions present on the **first note of Track 1**
#@markdown will be propagated.
#@markdown
#@markdown ---
#@markdown
#@markdown ## ▶️ How to Use
#@markdown
#@markdown ### Step 1 — Run the Cell
#@markdown Run this cell (the **USTX Transform** script).
#@markdown
#@markdown ---
#@markdown
#@markdown ### Step 2 — Upload Your USTX
#@markdown A file upload prompt will appear.
#@markdown
#@markdown - Select the `.ustx` file you want to transform
#@markdown
#@markdown ---
#@markdown
#@markdown ### Step 3 — Automatic Processing
#@markdown The script will:
#@markdown
#@markdown - Load the uploaded USTX
#@markdown - Apply all transformations:
#@markdown   - track renaming
#@markdown   - singer propagation
#@markdown   - phoneme expression replication
#@markdown
#@markdown Progress messages will be printed in the output cell.
#@markdown
#@markdown ---
#@markdown
#@markdown ### Step 4 — Download Result
#@markdown After processing completes:
#@markdown
#@markdown - A new file will be generated:
#@markdown
#@markdown   `<original_name>_dataset.ustx`
#@markdown
#@markdown - The browser will automatically download the transformed file
#@markdown
#@markdown ---
#@markdown
#@markdown ## ✅ Result
#@markdown
#@markdown You will get a **clean, dataset-ready USTX** with:
#@markdown
#@markdown - Properly named tracks
#@markdown - Consistent singer and expression settings
#@markdown - Ready for batch WAV export or DiffSinger processing
#@markdown
#@markdown ---
#@markdown
#@markdown ## 🧩 Notes & Tips
#@markdown
#@markdown - Each singer track is renamed **only once**
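#@markdown - Tracks without a `singer` inherit one from the nearest previous singer track; only tracks placed before the first singer track have nothing to inherit and are left unchanged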
#@markdown - Safe to re-run on the same file (mostly idempotent)
#@markdown
#@markdown ---
#@markdown

from google.colab import files
import yaml
import copy
import os

# ---------------------------
# Helper: extract track name
# ---------------------------
def extract_track_name_from_voicepart(name: str) -> str:
    """
    Return the last three underscore-separated fields of a voice-part name.

    Example: ``mzo_neutral_001_A3_9`` -> ``001_A3_9``. A name with fewer than
    three fields is returned unchanged.
    """
    fields = name.split("_")
    if len(fields) < 3:
        return name
    return "_".join(fields[-3:])

def propagate_phoneme_expressions_from_first_note(voice_parts):
    """
    Copy voice_parts[0].notes[0].phoneme_expressions onto every other note.

    Every note in every voice part receives its own deep copy of the template,
    replacing whatever phoneme_expressions it had. This includes the remaining
    notes of voice_parts[0], which were previously skipped. The source note is
    left as is. Voice parts whose ``notes`` is missing or null are skipped.

    Args:
        voice_parts: List of USTX voice-part dicts; modified in place.
    """
    if (
        not voice_parts
        or not voice_parts[0].get("notes")
        or not voice_parts[0]["notes"][0].get("phoneme_expressions")
    ):
        print("ℹ No phoneme_expressions to propagate")
        return

    source_note = voice_parts[0]["notes"][0]
    template = copy.deepcopy(source_note["phoneme_expressions"])

    for vp_index, vp in enumerate(voice_parts):
        # `or []` also covers `notes: null` in the YAML.
        for note in vp.get("notes") or []:
            if note is source_note:
                continue
            # Fresh copy per note so later edits to one note do not alias others.
            note["phoneme_expressions"] = copy.deepcopy(template)

        print(f"✔ Applied phoneme_expressions → voice_parts[{vp_index}]")

def propagate_singers_forward(tracks):
    """
    Fill in singer configuration on tracks that have no ``singer`` key.

    Walks ``tracks`` in order and remembers the most recent track that has a
    ``singer`` key (the donor). Each later track without one receives deep
    copies of whichever of the donor's singer, phonemizer, renderer_settings,
    voice_color_names and track_expressions are present. Tracks before the
    first singer track are left untouched. ``tracks`` is modified in place.
    """
    inherited_keys = (
        "singer",
        "phonemizer",
        "renderer_settings",
        "voice_color_names",
        "track_expressions",
    )
    donor = None

    for idx, track in enumerate(tracks):
        if "singer" in track:
            donor = track
        elif donor is not None:
            track.update(
                {key: copy.deepcopy(donor[key]) for key in inherited_keys if key in donor}
            )
            print(f"➕ Inherited singer for tracks[{idx}] from previous singer")





# ---------------------------
# Core processing function
# ---------------------------
def update_track_names_from_voiceparts(ustx: dict) -> dict:
    """
    Return a dataset-ready deep copy of a parsed USTX document.

    The input dict is not modified. On the copy this:
      1. fills singer settings into tracks that lack a ``singer`` key
         (``propagate_singers_forward``);
      2. replicates the first note's phoneme_expressions
         (``propagate_phoneme_expressions_from_first_note``);
      3. renames tracks. For each voice part, the nearest track at or before
         its ``track_no`` that has a ``singer`` gets a ``track_name`` derived
         from the voice part's name. Each track is renamed at most once, and
         the first voice part that maps to it wins.

    Voice parts with a missing, non-integer or out-of-range ``track_no``, or a
    missing or non-string ``name``, are skipped with a message instead of
    raising.

    Args:
        ustx: USTX document as loaded by ``yaml.safe_load``.

    Returns:
        dict: The transformed copy.
    """
    new_ustx = copy.deepcopy(ustx)

    def find_previous_singer_track_index(tracks, start_index):
        """Return the index of the closest track at/before start_index with a 'singer' key, else None."""
        for i in range(start_index, -1, -1):
            if "singer" in tracks[i]:
                return i
        return None

    # `or []` also tolerates `tracks: null` / `voice_parts: null` in the YAML.
    tracks = new_ustx.get("tracks") or []
    voice_parts = new_ustx.get("voice_parts") or []
    print(f"Found {len(tracks)} tracks and {len(voice_parts)} voice parts.")

    # Must run before renaming. A track without a singer only becomes a rename
    # target (and gets usable singer settings) after it has inherited one.
    propagate_singers_forward(tracks)

    propagate_phoneme_expressions_from_first_note(voice_parts)

    updated_singer_tracks = set()

    for vp in voice_parts:
        track_no = vp.get("track_no")
        vp_name = vp.get("name")

        if not isinstance(track_no, int) or not 0 <= track_no < len(tracks):
            print(f"⚠ Skipping voice part {vp_name!r}: invalid track_no {track_no!r}")
            continue
        if not isinstance(vp_name, str) or not vp_name:
            print(f"⚠ Skipping voice part on track {track_no}: missing name")
            continue

        singer_index = find_previous_singer_track_index(tracks, track_no)

        # No singer track to rename, or it was already renamed by an earlier part.
        if singer_index is None or singer_index in updated_singer_tracks:
            continue

        new_name = extract_track_name_from_voicepart(vp_name)
        tracks[singer_index]["track_name"] = new_name
        updated_singer_tracks.add(singer_index)

        print(f"✔ Updated tracks[{singer_index}].track_name → {new_name}")

    return new_ustx


# ---------------------------
# Upload file
# ---------------------------
uploaded = files.upload()
if not uploaded:
    # Upload cancelled or nothing selected: fail with a clear message
    # instead of an IndexError.
    raise ValueError("No .ustx file was uploaded; nothing to transform.")

ustx_path = next(iter(uploaded))
if len(uploaded) > 1:
    print(f"⚠ {len(uploaded)} files uploaded; only {ustx_path} will be processed.")

print(f"Loaded: {ustx_path}")

# ---------------------------
# Read, process, write
# ---------------------------
# safe_load (never yaml.load): the uploaded file is untrusted input.
with open(ustx_path, "r", encoding="utf-8") as f:
    ustx_data = yaml.safe_load(f)

fixed_ustx = update_track_names_from_voiceparts(ustx_data)

output_path = os.path.splitext(ustx_path)[0] + "_dataset.ustx"

with open(output_path, "w", encoding="utf-8") as f:
    yaml.dump(
        fixed_ustx,
        f,
        allow_unicode=True,
        sort_keys=False
    )

print(f"Saved: {output_path}")

# ---------------------------
# Download result
# ---------------------------
files.download(output_path)

Saving mzo_soft.ustx to mzo_soft.ustx
Loaded: mzo_soft.ustx
Found 80 tracks and 80 voice parts.
➕ Inherited singer for tracks[1] from previous singer
➕ Inherited singer for tracks[2] from previous singer
➕ Inherited singer for tracks[3] from previous singer
➕ Inherited singer for tracks[4] from previous singer
➕ Inherited singer for tracks[5] from previous singer
➕ Inherited singer for tracks[6] from previous singer
➕ Inherited singer for tracks[7] from previous singer
➕ Inherited singer for tracks[8] from previous singer
➕ Inherited singer for tracks[9] from previous singer
➕ Inherited singer for tracks[10] from previous singer
➕ Inherited singer for tracks[11] from previous singer
➕ Inherited singer for tracks[12] from previous singer
➕ Inherited singer for tracks[13] from previous singer
➕ Inherited singer for tracks[14] from previous singer
➕ Inherited singer for tracks[15] from previous singer
➕ Inherited singer for tracks[16] from previous singer
➕ Inherited singer for tracks[17]

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>