In [None]:
"""
MARKOV-LSTM-MARKOV FILTER: MORPHOLOGY PARSER WITH PRIORS
========================================================

This notebook implements a morphology parser for Quechua that combines:
1. BiLSTM neural network for boundary prediction
2. HMM (Hidden Markov Model) priors based on suffix patterns
3. Privileged knowledge (K-teacher) for regularization

The parser segments Quechua words into morphemes by predicting boundary positions
between tokens. It uses:
- Gold standard data (Sue Kalt dataset) as the base training data
- Optional synthetic data augmentation from GPT models (gpt4o or gpt5mini)
- HMM priors trained on suffix patterns to guide segmentation
- K-teacher regularization to improve generalization

Key Features:
- Configurable synthetic data augmentation (none, gpt4o, gpt5mini)
- Model checkpointing: saves/loads models to avoid retraining
- Comprehensive evaluation metrics (precision, recall, F1, exact match)
- Suffix validation to filter invalid segmentations

All data is read from the 'data' folder; models are saved to a model-specific
'models_<MODEL_NAME>' folder (here 'models_Markov-LSTM-MarkovFilter').
"""

import pandas as pd
import os
import json
import hashlib
import pickle
import torch
import torch.nn as nn

In [None]:
# =========================
# DATA FOLDER CONFIGURATION
# =========================
# Input data is read from (and derived data written to) DATA_FOLDER;
# model artifacts go to a folder named after the model.
DATA_FOLDER = "data"
MODEL_NAME = "Markov-LSTM-MarkovFilter"
MODELS_FOLDER = f"models_{MODEL_NAME}"

# Create models folder if it doesn't exist
os.makedirs(MODELS_FOLDER, exist_ok=True)

# =========================
# CONFIGURATION: SYNTHETIC DATA AUGMENTATION
# =========================
# Choose which synthetic data to use for augmentation:
#   "none"     - Use only gold standard data (no augmentation)
#   "gpt4o"    - Augment with GPT-4o synthetic segmentations
#   "gpt5mini" - Augment with GPT-5-mini synthetic segmentations
SYNTHETIC_DATA_CHOICE = "none"  # Change this to "none", "gpt4o", or "gpt5mini"

# =========================
# CONFIGURATION: WORD SELECTION FOR AUGMENTATION
# =========================
# Choose how to select words from common words for augmentation:
#   "all"      - Use all common words (default)
#   "first"    - Use the first n common words (sorted alphabetically)
#   "random"   - Use n randomly selected common words
AUGMENTATION_WORD_SELECTION = "random"  # Change this to "all", "first", or "random"
AUGMENTATION_N_WORDS = 100  # Number of words to use when selection is "first" or "random" (ignored if "all")

# =========================
# LOAD GOLD STANDARD DATA
# =========================
# The gold standard dataset contains high-quality morphological segmentations
# This is the base training data that will always be used
print("Loading gold standard data...")
gold_df = pd.read_parquet(os.path.join(DATA_FOLDER, "Sue_kalt.parquet"))
gold_df['Word'] = gold_df['word']
# Normalize separators: morphemes are hyphen-separated in the raw data
gold_df['morph'] = gold_df['morph'].str.replace('-', ' ', regex=False)
gold_df['Morph_split_str'] = gold_df['morph']  # String version
gold_df['Morph_split'] = gold_df['morph'].str.split(' ')  # List version
# Build the final frame with a method chain instead of in-place edits on a
# column slice (avoids pandas' copy-vs-view ambiguity). Rows lacking either the
# word or its segmentation are removed first: a missing segmentation would
# otherwise reach the tokenizer as NaN and fail there.
gold_df = (
    gold_df[['Word', 'Morph_split', 'Morph_split_str']]
    .dropna(subset=['Word', 'Morph_split_str'])
    .drop_duplicates(subset='Word', keep='first')
)
print(f"Loaded {len(gold_df):,} gold standard examples")

In [None]:
# Display (rows, columns) of the de-duplicated gold-standard frame.
gold_df.shape

In [None]:
# =========================
# LOAD AND PROCESS SYNTHETIC DATA (if augmentation is enabled)
# =========================
# This cell loads synthetic segmentations from GPT models if augmentation is enabled
# The synthetic data is filtered to remove low-quality segmentations and formatted
# to match the gold standard data structure

def load_synthetic_data(choice):
    """
    Load synthetic segmentation data based on the chosen augmentation method.

    The CSV is expected to provide the columns 'Original_Word' and
    'Segmented_Morphemes' (space-separated morphemes). Rows are cleaned in this
    order: rows missing either column are dropped, rows whose segmentation looks
    like a model refusal are dropped, and only then are duplicate words removed.
    Filtering before de-duplicating keeps a valid later row for a word whose
    first row was unusable.

    Args:
        choice: One of "none", "gpt4o", or "gpt5mini"

    Returns:
        DataFrame with columns 'Word', 'Morph_split' (list of morphemes) and
        'Morph_split_str' (space-separated string), or None if choice is "none",
        is not recognised, or the corresponding file does not exist.
    """
    if choice == "none":
        print("No synthetic data augmentation selected.")
        return None

    # Map choice to file name
    file_map = {
        "gpt4o": "gpt4o_synthetic_segmentations.csv",
        "gpt5mini": "gpt5mini_synthetic_segmentations.csv"
    }

    if choice not in file_map:
        print(f"Warning: Unknown synthetic data choice '{choice}'. Using 'none' instead.")
        return None

    file_path = os.path.join(DATA_FOLDER, file_map[choice])

    if not os.path.exists(file_path):
        print(f"Warning: Synthetic data file not found: {file_path}")
        print(f"Falling back to no augmentation.")
        return None

    print(f"Loading synthetic data from {file_path}...")
    df = pd.read_csv(file_path)

    # A row without a word or without a segmentation cannot be used for training
    df = df.dropna(subset=['Original_Word', 'Segmented_Morphemes'])

    # Filter out low-quality segmentations that contain invalid strings
    # These strings indicate the model failed or produced invalid output
    strings_to_drop = ['can\'t', 'quechua', 'sorry', 'could']
    is_invalid = df['Segmented_Morphemes'].astype(str).str.contains(
        '|'.join(strings_to_drop), case=False, na=False
    )
    df = df[~is_invalid]

    # Remove duplicates (keep first remaining occurrence)
    df = df.drop_duplicates(subset=['Original_Word']).reset_index(drop=True)

    # Rename and format columns to match gold standard structure
    df = df.rename(columns={'Original_Word': 'Word'})
    df['Morph_split_str'] = df['Segmented_Morphemes'].astype(str)
    df['Morph_split'] = df['Morph_split_str'].str.split(' ')
    df = df[['Word', 'Morph_split', 'Morph_split_str']]

    print(f"Loaded {len(df):,} synthetic segmentations from {choice}")
    return df

# Load synthetic data based on configuration.
# `synthetic_df` is None whenever load_synthetic_data returns None
# (choice "none", unrecognised choice, or missing file).
synthetic_df = load_synthetic_data(SYNTHETIC_DATA_CHOICE)

In [None]:
# =========================
# LOAD GPT-5-MINI DATA (for comparison/analysis)
# =========================
# This cell loads GPT-5-mini data separately for comparison purposes
# Note: This is separate from the augmentation choice above

gpt_5_mini_path = os.path.join(DATA_FOLDER, "gpt5mini_synthetic_segmentations.csv")
if os.path.exists(gpt_5_mini_path):
    gpt_5_mini_df = pd.read_csv(gpt_5_mini_path)
else:
    # This file is only needed for comparison/augmentation; fall back to an
    # empty frame so a gold-only run does not crash when it is absent.
    print(f"Warning: GPT-5-mini data file not found: {gpt_5_mini_path}")
    gpt_5_mini_df = pd.DataFrame(columns=['Original_Word', 'Segmented_Morphemes'])

# Rows without a word or a segmentation are unusable downstream
gpt_5_mini_df = gpt_5_mini_df.dropna(subset=['Original_Word', 'Segmented_Morphemes'])

# Drop segmentations that look like model refusals BEFORE de-duplicating, so a
# valid later row for the same word is not discarded with the bad first one.
strings_to_drop = ['can\'t', 'quechua', 'sorry', 'could']
gpt_5_mini_df = gpt_5_mini_df[
    ~gpt_5_mini_df['Segmented_Morphemes'].astype(str).str.contains(
        '|'.join(strings_to_drop), case=False, na=False
    )
]
gpt_5_mini_df = gpt_5_mini_df.drop_duplicates(subset=['Original_Word']).reset_index(drop=True)

# Rename and format columns
gpt_5_mini_df = gpt_5_mini_df.rename(columns={'Original_Word': 'Word'})
gpt_5_mini_df['Morph_split_str'] = gpt_5_mini_df['Segmented_Morphemes'].astype(str)
gpt_5_mini_df['Morph_split'] = gpt_5_mini_df['Morph_split_str'].str.split(' ')
gpt_5_mini_df = gpt_5_mini_df[['Word', 'Morph_split', 'Morph_split_str']]

In [None]:
# Display (rows, columns) of the filtered GPT-5-mini frame.
gpt_5_mini_df.shape

In [None]:
# =========================
# LOAD GPT-4O DATA (for comparison/analysis)
# =========================
# This cell loads GPT-4o data separately for comparison purposes
# Note: This is separate from the augmentation choice above

gpt_4o_path = os.path.join(DATA_FOLDER, "gpt4o_synthetic_segmentations.csv")
if os.path.exists(gpt_4o_path):
    gpt_4o_df = pd.read_csv(gpt_4o_path)
else:
    # This file is only needed for comparison/augmentation; fall back to an
    # empty frame so a gold-only run does not crash when it is absent.
    print(f"Warning: GPT-4o data file not found: {gpt_4o_path}")
    gpt_4o_df = pd.DataFrame(columns=['Original_Word', 'Segmented_Morphemes'])

# Rows without a word or a segmentation are unusable downstream
gpt_4o_df = gpt_4o_df.dropna(subset=['Original_Word', 'Segmented_Morphemes'])

# Drop segmentations that look like model refusals BEFORE de-duplicating, so a
# valid later row for the same word is not discarded with the bad first one.
strings_to_drop = ['can\'t', 'quechua', 'sorry', 'could']
gpt_4o_df = gpt_4o_df[
    ~gpt_4o_df['Segmented_Morphemes'].astype(str).str.contains(
        '|'.join(strings_to_drop), case=False, na=False
    )
]
gpt_4o_df = gpt_4o_df.drop_duplicates(subset=['Original_Word']).reset_index(drop=True)

# Rename and format columns
gpt_4o_df = gpt_4o_df.rename(columns={'Original_Word': 'Word'})
gpt_4o_df['Morph_split_str'] = gpt_4o_df['Segmented_Morphemes'].astype(str)
gpt_4o_df['Morph_split'] = gpt_4o_df['Morph_split_str'].str.split(' ')
gpt_4o_df = gpt_4o_df[['Word', 'Morph_split', 'Morph_split_str']]

In [None]:
# Display (rows, columns) of the filtered GPT-4o frame.
gpt_4o_df.shape

In [None]:
# Vocabulary covered by each GPT model, and the words both of them segmented.
gpt_4o_words = set(gpt_4o_df['Word'])
gpt_5_mini_words = set(gpt_5_mini_df['Word'])

common_words = gpt_4o_words & gpt_5_mini_words

print("Number of common words:", len(common_words))

In [None]:
# =========================
# COMBINE GOLD AND SYNTHETIC DATA
# =========================
# Combine the gold standard data with synthetic data (if augmentation is enabled)
# Only words that appear in both GPT models are used to ensure quality

if synthetic_df is not None:
    # Find common words between GPT-4o and GPT-5-mini for quality control
    gpt_5_mini_words = set(gpt_5_mini_df['Word'])
    gpt_4o_words = set(gpt_4o_df['Word'])
    common_words = gpt_4o_words.intersection(gpt_5_mini_words)
    print(f"Number of common words between GPT models: {len(common_words):,}")

    # Select words based on AUGMENTATION_WORD_SELECTION configuration
    if AUGMENTATION_WORD_SELECTION == "all":
        selected_words = common_words
        print(f"Using all {len(selected_words):,} common words for augmentation")
    elif AUGMENTATION_WORD_SELECTION == "first":
        # Sort words alphabetically and take first n
        sorted_words = sorted(common_words)
        n = min(AUGMENTATION_N_WORDS, len(sorted_words))
        selected_words = set(sorted_words[:n])
        print(f"Using first {n:,} common words (alphabetically sorted) for augmentation")
    elif AUGMENTATION_WORD_SELECTION == "random":
        # Randomly sample n words
        import random
        # Use RNG if defined, otherwise use default seed of 42
        seed = RNG if 'RNG' in globals() else 42
        random.seed(seed)  # Use the same RNG seed for reproducibility
        # Sample from a SORTED list: iteration order of a set of strings changes
        # between interpreter sessions (string hash randomisation), so sampling
        # from list(set) is not reproducible even with a fixed seed.
        sampling_pool = sorted(common_words, key=str)
        n = min(AUGMENTATION_N_WORDS, len(sampling_pool))
        selected_words = set(random.sample(sampling_pool, n))
        print(f"Using {n:,} randomly selected common words for augmentation")
    else:
        print(f"Warning: Unknown AUGMENTATION_WORD_SELECTION '{AUGMENTATION_WORD_SELECTION}'. Using 'all' instead.")
        selected_words = common_words

    # Use only selected words from the chosen synthetic data
    # (both supported choices are filtered the same way)
    if SYNTHETIC_DATA_CHOICE in ("gpt4o", "gpt5mini"):
        df_sampled = synthetic_df[synthetic_df['Word'].isin(selected_words)]
    else:
        df_sampled = None

    if df_sampled is not None and len(df_sampled) > 0:
        # Combine with gold data (synthetic rows first, fresh 0..n-1 index)
        n_gold = len(gold_df)
        gold_df = pd.concat([df_sampled, gold_df], ignore_index=True)
        print(f"Combined dataset: {len(gold_df):,} examples ({len(df_sampled):,} synthetic + {n_gold:,} gold)")
    else:
        print("No synthetic data to add (no common words found)")
else:
    print("Using only gold standard data (no augmentation)")

In [None]:
# =========================
# SAVE COMMON WORDS (if synthetic data was used)
# =========================
# Save the common words used for augmentation to the data folder for reference

# Only runs when augmentation is active and the combine cell produced a sample.
if synthetic_df is not None:
    if 'df_sampled' in locals() and df_sampled is not None:
        output_file = os.path.join(DATA_FOLDER, f"{SYNTHETIC_DATA_CHOICE}_common.parquet")
        # Persist the sampled synthetic rows ordered by word
        df_sampled = df_sampled.sort_values(by="Word")
        df_sampled.to_parquet(output_file, index=False)
        print(f"Saved common words to {output_file}")

In [None]:
# Preview the first 50 rows of the training frame (when synthetic rows were
# concatenated in the combine cell, they come first).
gold_df.head(50)

In [None]:
# =========================
# LOAD TEST DATA
# =========================
# Load the test/accuracy evaluation dataset
# This dataset is used for final evaluation of the trained model

acc_df = pd.read_parquet(os.path.join(DATA_FOLDER, "cleaned_data_df.parquet"))

# Summary banner: training frame vs. test frame and the augmentation mode in use
_summary_rule = "=" * 60
print("\n".join([
    _summary_rule,
    "DATASET SUMMARY",
    _summary_rule,
    f"Training data shape: {gold_df.shape}",
    f"Test data shape: {acc_df.shape}",
    f"Synthetic augmentation: {SYNTHETIC_DATA_CHOICE}",
    _summary_rule,
]))

In [None]:
# Preview the first 50 rows of the test frame.
acc_df.head(50)

In [None]:
# Grapheme inventory used to tokenize words. Multi-letter graphemes are listed
# first for readability only: the tokenizer regex built from this list is sorted
# by length elsewhere, so longer graphemes always take precedence.
# "ñ" was previously stored mis-encoded (UTF-8 bytes decoded with the wrong
# codec), which meant the real character could never match.
graphemes = [
    "ch","ll","rr","tr","kw","ph",  # digraphs
    "a","b","d","e","f","g","h","i","k","l","m","n","ñ","o","p","q",
    "r","s","t","u","v","w","x","y"
]

In [None]:
import re

In [None]:
# Alternation of all graphemes, longest first, so that a digraph such as "ch"
# wins over the single letters it is made of.
pattern = re.compile("|".join(sorted(graphemes, key=len, reverse=True)))

def tokenize_morphemes(morphs):
    """
    Split every morpheme string into its graphemes.

    Each morpheme is lower-cased and scanned with `pattern`; only matches are
    returned, so any character that is not part of a known grapheme is dropped
    without notice.

    Args:
        morphs: iterable of morpheme strings.

    Returns:
        list of lists: one list of grapheme strings per morpheme.
    """
    tokenized = []
    for morph in morphs:
        tokenized.append(pattern.findall(morph.lower()))
    return tokenized

In [None]:
# Grapheme tokens per morpheme: one list of graphemes for each entry of Morph_split.
gold_df["Char_split"] = gold_df["Morph_split"].apply(tokenize_morphemes)

In [None]:
# Graphemes counted as vowels; everything else maps to "C".
vowels = {"a", "i", "e", "o", "u"}

def grapheme_to_cv(grapheme):
    """Return "V" when `grapheme` is one of `vowels`, otherwise "C"."""
    if grapheme in vowels:
        return "V"
    return "C"

def morphs_to_cv(morphs):
    """
    Convert tokenized morphemes to their consonant/vowel skeletons.

    Args:
        morphs: iterable of morphemes, each an iterable of graphemes.

    Returns:
        list of lists with "C"/"V" in place of every grapheme.
    """
    skeletons = []
    for morph in morphs:
        skeletons.append(list(map(grapheme_to_cv, morph)))
    return skeletons

In [None]:
# Consonant/vowel skeleton per morpheme, derived from the grapheme tokens.
gold_df["CV_split"] = gold_df["Char_split"].apply(morphs_to_cv)

In [None]:
def cv_to_string(cv_split):
    """
    Render a nested C/V list as one string.

    The symbols of each morpheme are concatenated and the morphemes are then
    joined with "-", e.g. [["C", "V"], ["C"]] -> "CV-C".
    """
    per_morph = ["".join(symbols) for symbols in cv_split]
    return "-".join(per_morph)

In [None]:
# Start from an empty frame; its columns are filled in by the following cells.
str_df = pd.DataFrame()

In [None]:
import numpy as np

In [None]:
# CV skeleton of the whole word, morphemes separated by "-"
str_df["Full_chain"] = gold_df["CV_split"].apply(cv_to_string)

# Skeleton of everything after the first morpheme; NaN marks words that have
# no "-" at all (single-morpheme words)
str_df["Trimmed_chain"] = str_df["Full_chain"].apply(
    lambda chain: np.nan if "-" not in chain else chain.split("-", 1)[1]
)

# Carry over the source columns (aligned on gold_df's index)
for _col in ("Word", "Char_split", "Morph_split"):
    str_df[_col] = gold_df[_col]

# Keep only words with at least one boundary and renumber the rows
str_df = str_df.dropna(subset=["Trimmed_chain"]).reset_index(drop=True)

In [None]:
# Numeric per-word features. These column names are the ones listed in
# NEW_NUM_FEATS further down in the notebook.

# word length: number of characters in the raw Word string (not graphemes)
str_df["Word_len"] = str_df["Word"].str.len()

# consonant and vowel count from Full_chain (occurrences of "V" / "C")
str_df["Vowel_no"] = str_df["Full_chain"].str.count("V")
str_df["Cons_no"] = str_df["Full_chain"].str.count("C")

# tail consonant and vowel counts: Trimmed_chain is everything after the FIRST
# '-' in Full_chain, i.e. all morphemes except the first one
str_df["Tail_cons_no"] = str_df["Trimmed_chain"].str.count("C")
str_df["Tail_vowel_no"] = str_df["Trimmed_chain"].str.count("V")

# length of the Morph_split list, i.e. the number of morphemes
# (one more than the number of boundaries)
str_df["No_splits"] = str_df["Morph_split"].str.len()

# total y/w count in word (regex "[yw]" is case-sensitive: lowercase only)
str_df["YW_count"] = str_df["Word"].str.count("[yw]")

# tail y/w count (all morphs except first; also lowercase only)
str_df["Tail_YW_count"] = str_df["Morph_split"].apply(
    lambda ms: sum(m.count("y") + m.count("w") for m in ms[1:])
)

In [None]:
# Preview the first rows of the feature frame.
str_df.head()

In [None]:
import ast
import re
import math
import numpy as np
import pandas as pd
from collections import Counter, defaultdict
from sklearn.model_selection import GroupShuffleSplit
from sklearn.feature_extraction import DictVectorizer
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score

In [None]:
import ast, re, numpy as np, pandas as pd, torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.feature_extraction import DictVectorizer
from sklearn.metrics import precision_recall_fscore_support

In [None]:
# Global seed; applied to numpy and torch here and reused as `random_state`
# by estimators defined later in the notebook.
RNG = 42
np.random.seed(RNG)
torch.manual_seed(RNG)

# Names of the numeric "privileged" feature columns, in the order in which
# they are packed into each sample's feature vector.
NEW_NUM_FEATS = [
    "Word_len",
    "Vowel_no",
    "Cons_no",
    "Tail_cons_no",
    "Tail_vowel_no",
    "No_splits",
    "YW_count",
    "Tail_YW_count",
]


In [None]:
def safe_list(x):
    """
    Return `x` as a Python list-like literal.

    Lists are returned unchanged. Anything else is converted to its string
    form and parsed with `ast.literal_eval`. If that fails, the string is
    patched up as if it were a nested list written without quotes
    (e.g. "[[a, b], [c]]") and parsed again; a failure of this second attempt
    propagates to the caller.
    """
    if isinstance(x, list):
        return x

    text = str(x)
    try:
        return ast.literal_eval(text)
    except Exception:
        # Insert the missing quotes around every element, then retry
        quoted = text
        for old, new in (("[[", "[['"), ("]]", "']]"), ("], [", "'],['"), (", ", "','")):
            quoted = quoted.replace(old, new)
        return ast.literal_eval(quoted)

def flatten(list_of_lists):
    """Concatenate all inner sequences into one flat list of strings (each item passed through `str`)."""
    return [str(item) for segment in list_of_lists for item in segment]

def extract_priv_features_from_row(row, feat_names):
    """
    Build the numeric feature vector for one row.

    For every name in `feat_names` the value `row[name]` is converted to float.
    A name that is absent from `row`, a missing value (per `pd.notna`), or a
    value that cannot be converted all yield 0.0.

    Returns:
        list of floats, one per entry of `feat_names`, in the same order.
    """
    def _as_float(name):
        present = name in row and pd.notna(row[name])
        raw = row[name] if present else 0.0
        try:
            return float(raw)
        except Exception:
            return 0.0

    return [_as_float(name) for name in feat_names]

In [None]:
# %%
# ===================================================================
# NEW CODE: Suffix HMM Prior Model (Replaces Decision Tree)
# ===================================================================
import math
from collections import Counter

class SuffixHMMPrior:
    """
    Boundary prior derived from a suffix inventory via forward-backward.

    A word is viewed as a concatenation of segments of at most `max_len`
    characters. A segment found in the suffix table scores its stored log
    probability; any other segment scores `unk_penalty`. Summing over all such
    segmentations yields, for every internal position, the share of the total
    score carried by segmentations that place a boundary there.
    """
    def __init__(self, suffix_log_probs, max_suffix_len, unk_penalty=-15.0):
        """
        Args:
            suffix_log_probs (dict): suffix string -> log probability.
            max_suffix_len (int): longest segment (in characters) ever considered.
            unk_penalty (float): log score of any segment missing from the table
                                 (e.g. a piece of the root).
        """
        self.log_probs = suffix_log_probs
        self.max_len = max_suffix_len
        self.unk_penalty = unk_penalty
        self.LOG_ZERO = -1e9 # stand-in for log(0)

    def _get_log_prob(self, segment):
        """Log score of one segment: table value if known, else the unknown penalty."""
        if segment in self.log_probs:
            return self.log_probs[segment]
        return self.unk_penalty

    def _forward_pass(self, word):
        """Return alpha, where alpha[i] is the log-sum of scores of all segmentations of word[:i]."""
        length = len(word)
        alpha = [self.LOG_ZERO] * (length + 1)
        alpha[0] = 0.0  # the empty prefix has exactly one (empty) segmentation

        for end in range(1, length + 1):
            # The last segment of the prefix may start anywhere within max_len of `end`
            earliest_start = max(0, end - self.max_len)
            candidates = [
                alpha[start] + self._get_log_prob(word[start:end])
                for start in range(earliest_start, end)
            ]
            if candidates:
                alpha[end] = torch.logsumexp(torch.tensor(candidates), dim=0).item()
        return alpha

    def _backward_pass(self, word):
        """Return beta, where beta[i] is the log-sum of scores of all segmentations of word[i:]."""
        length = len(word)
        beta = [self.LOG_ZERO] * (length + 1)
        beta[length] = 0.0  # the empty suffix has exactly one (empty) segmentation

        for start in range(length - 1, -1, -1):
            # The first segment of the remainder may end anywhere within max_len of `start`
            stop = min(length + 1, start + self.max_len + 1)
            candidates = [
                beta[end] + self._get_log_prob(word[start:end])
                for end in range(start + 1, stop)
            ]
            if candidates:
                beta[start] = torch.logsumexp(torch.tensor(candidates), dim=0).item()
        return beta

    def get_boundary_priors(self, word):
        """
        Probability of a boundary at each internal position of `word`.

        Position i (1 <= i < len(word)) gets exp(alpha[i] + beta[i] - alpha[n]).

        Returns:
            list of len(word) - 1 floats; [] for words of length <= 1; all
            zeros when the forward total is still LOG_ZERO (nothing could be
            segmented).
        """
        n = len(word)
        if n <= 1:
            return []

        alpha = self._forward_pass(word)
        beta = self._backward_pass(word)

        log_total = alpha[n]
        if log_total == self.LOG_ZERO:
            return [0.0] * (n - 1)

        # Joint log score of "some segmentation with a cut at i", for every i
        joint = torch.tensor([alpha[i] + beta[i] for i in range(1, n)])
        # Dividing by the total (subtracting in log space) turns it into a probability
        return torch.exp(joint - log_total).tolist()

def train_hmm_prior(samples):
    """
    Creates the SuffixHMMPrior by calculating suffix frequencies from training data.
    This replaces `train_dt_prior`.

    Every morpheme after the first one (the root) of each sample counts as one
    suffix observation. The suffix string is taken directly from `y_morphs`
    rather than being re-sliced out of the grapheme tokens: morpheme lengths
    are measured in characters while `tokens` holds graphemes, so slicing the
    token list by character length returns the wrong span whenever a word
    contains a multi-character grapheme such as "ch".

    Args:
        samples: iterable of dicts with a 'y_morphs' entry holding the gold
            morphemes (strings, or sequences of grapheme strings), root first.

    Returns:
        SuffixHMMPrior whose table holds Laplace-smoothed log probabilities of
        the observed suffixes, whose maximum segment length is the longest
        observed suffix (in characters) and whose unknown penalty is 1.5 times
        the mean suffix log probability.
    """
    suffix_counts = Counter()

    for s in samples:
        morphs = safe_list(s['y_morphs'])
        for morph in morphs[1:]:  # Skip the root
            # "".join handles both a plain string and a list of graphemes
            suffix_str = "".join(morph)
            if suffix_str:  # an empty morpheme carries no suffix information
                suffix_counts[suffix_str] += 1

    max_suffix_len = max((len(suffix) for suffix in suffix_counts), default=0)
    total_suffix_obs = sum(suffix_counts.values())

    # Calculate log probabilities with Laplace smoothing
    log_probs = {
        suffix: math.log((count + 1) / (total_suffix_obs + len(suffix_counts)))
        for suffix, count in suffix_counts.items()
    }

    # Heuristic penalty for unknown segments (roots). Should be lower than rare suffixes.
    avg_log_prob = sum(log_probs.values()) / len(log_probs) if log_probs else 0
    unk_penalty = avg_log_prob * 1.5

    print(f"HMM Prior: Found {len(log_probs)} unique suffixes. Max length: {max_suffix_len}. Unk penalty: {unk_penalty:.2f}")

    return SuffixHMMPrior(log_probs, max_suffix_len, unk_penalty=unk_penalty)

# ===================================================================
# REVISED CODE: Create HMM Prior from a user-provided suffix list
# ===================================================================
import math

def create_hmm_prior_from_list(allowed_suffixes: list, unk_penalty: float = -15.0):
    """
    Build a SuffixHMMPrior from a fixed suffix inventory instead of training data.

    Every listed suffix receives the same log probability 0.0 (i.e. log(1)),
    which expresses a strong preference for these segments over unknown ones.

    Args:
        allowed_suffixes (list): valid Quechua suffix strings; must not be empty.
        unk_penalty (float): log score for any segment NOT in the list (e.g.
            part of a root); should be a reasonably large negative number.

    Returns:
        SuffixHMMPrior whose maximum segment length is the longest listed suffix.

    Raises:
        ValueError: if `allowed_suffixes` is empty.
    """
    if not allowed_suffixes:
        raise ValueError("The provided suffix list cannot be empty.")

    # Uniform, maximal score for every known suffix
    suffix_log_probs = {}
    for suffix in allowed_suffixes:
        suffix_log_probs[suffix] = 0.0

    # The longest entry bounds how far the HMM looks for a segment
    longest_suffix = max(allowed_suffixes, key=len)
    max_suffix_len = len(longest_suffix)

    print(f"HMM Prior: Initialized with {len(allowed_suffixes)} provided suffixes. Max length: {max_suffix_len}.")

    return SuffixHMMPrior(suffix_log_probs, max_suffix_len, unk_penalty=unk_penalty)

# We need to add the gold morphemes to the sample builder for training the HMM
def build_samples_with_priv(df, feat_names=NEW_NUM_FEATS):
    """
    Turn every row of `df` into a training sample.

    Args:
        df: frame with a 'Char_split' column (grapheme tokens per morpheme)
            plus the numeric columns named in `feat_names`.
        feat_names: feature columns to pack into the 'priv' vector.

    Returns:
        list of dicts with keys
          'tokens'   - flat list of grapheme tokens,
          'y'        - 0/1 label per gap between consecutive tokens
                       (1 = morpheme boundary),
          'priv'     - numeric feature vector,
          'y_morphs' - gold morphemes as strings (needed for HMM training).
    """
    samples = []
    for _, record in df.iterrows():
        segments = safe_list(record["Char_split"])
        tokens = flatten(segments)

        # Cumulative segment lengths give the token offsets where a morpheme
        # ends; the last one is the end of the word and is not a boundary.
        segment_lengths = [len(segment) for segment in segments]
        boundary_offsets = set(np.cumsum(segment_lengths)[:-1].tolist())
        labels = [1 if offset in boundary_offsets else 0 for offset in range(1, len(tokens))]

        samples.append({
            "tokens": tokens,
            "y": labels,
            "priv": extract_priv_features_from_row(record, feat_names),
            "y_morphs": ["".join(segment) for segment in segments],
        })
    return samples

In [None]:
def featurize_window(tokens, i, k_left=2, k_right=2):
    """
    Context features for the gap between tokens[i] and tokens[i+1].

    Args:
        tokens: list of grapheme tokens.
        i: index of the token immediately left of the gap.
        k_left: number of left-context tokens (L1 = tokens[i], L2 = tokens[i-1], ...).
        k_right: number of right-context tokens (R1 = tokens[i+1], R2 = tokens[i+2], ...).

    Returns:
        dict with
          'L1'..'Lk'  - left tokens, "<BOS>" when before the start,
          'R1'..'Rk'  - right tokens, "<EOS>" when past the end,
          'L1_cv' / 'R1_cv'     - 'V' or 'C' for the last char of L1 / first
                                  char of R1 ('C' when R1 is "<EOS>"),
          'L1_last' / 'R1_first' - those characters themselves ("<EOS>" for a
                                  missing R1).
    """
    # Plain and acute-accented vowels. The accented characters were previously
    # stored mis-encoded, so accented vowels were classified as consonants.
    vowel_chars = frozenset("aeiouáéíóú")

    def is_vowel(ch): return ch.lower() in vowel_chars

    feats = {}
    for k in range(1, k_left+1):
        idx = i-(k-1)
        feats[f"L{k}"] = tokens[idx] if idx >= 0 else "<BOS>"
    for k in range(1, k_right+1):
        idx = i+k
        feats[f"R{k}"] = tokens[idx] if idx < len(tokens) else "<EOS>"

    L1 = feats["L1"]; R1 = feats["R1"]
    has_right = R1 != "<EOS>"
    feats["L1_cv"] = 'V' if is_vowel(L1[-1]) else 'C'
    feats["R1_cv"] = 'V' if (has_right and is_vowel(R1[0])) else 'C'
    feats["L1_last"] = L1[-1]
    feats["R1_first"] = R1[0] if has_right else "<EOS>"
    return feats

# %%
# ===================================================================
# MODIFIED CODE: Use HMM for prior calculation
# ===================================================================

def prior_probs_for_sample(hmm_prior, tokens):
    """
    Boundary prior for every gap between consecutive tokens of one sample.

    The HMM works on the character string obtained by joining the tokens and
    returns one probability per character gap; this function picks out the
    gaps that coincide with token ends.

    Args:
        hmm_prior: object providing `get_boundary_priors(word)`, or None.
        tokens: list of grapheme tokens.

    Returns:
        list of len(tokens) - 1 floats. All entries are 0.5 when no prior is
        given or there is at most one token; an individual entry is 0.5 when
        its character position is outside the HMM output.
    """
    if hmm_prior is None or len(tokens) <= 1:
        n_gaps = max(len(tokens) - 1, 0)
        return [0.5] * n_gaps

    # HMM gives character-level boundary probabilities
    char_priors = hmm_prior.get_boundary_priors("".join(tokens))

    # Entry k of char_priors belongs to the gap after character k+1, hence the
    # "- 1" when converting cumulative token lengths to list positions.
    token_end_positions = np.cumsum([len(tok) for tok in tokens[:-1]]) - 1

    return [
        char_priors[pos] if 0 <= pos < len(char_priors) else 0.5
        for pos in token_end_positions
    ]

In [None]:
def train_k_teacher_priv(samples, feat_dim):
    """
    Fit the K-teacher: a decision-tree regressor that predicts K, the number
    of cuts in a word, from the sample's 'priv' feature vector.

    Args:
        samples: list of sample dicts with 'priv' (feature vector) and 'y'
            (0/1 boundary labels; their sum is the target K).
        feat_dim: not used by this function.

    Returns:
        the fitted DecisionTreeRegressor.
    """
    features = np.array([sample["priv"] for sample in samples], dtype=float)   # (N, F)
    cut_counts = np.array([int(np.sum(sample["y"])) for sample in samples], dtype=float)

    teacher = DecisionTreeRegressor(max_depth=6, min_samples_leaf=10, random_state=RNG)
    teacher.fit(features, cut_counts)
    return teacher

def predict_k_hat_priv(reg, priv_batch):
    """
    Predict K for a batch of feature vectors.

    Args:
        reg: fitted regressor exposing `predict` on a numpy array.
        priv_batch: (B, F) float tensor.

    Returns:
        float32 tensor of predictions placed on the same device as `priv_batch`.
    """
    target_device = priv_batch.device
    with torch.no_grad():
        # The regressor works on numpy, so the batch goes through the CPU
        features = priv_batch.cpu().numpy()
        k_hat = reg.predict(features)
    return torch.tensor(k_hat, dtype=torch.float32, device=target_device)

In [None]:
# ===================================================================
# DEMONSTRATION: HMM Prior Processing with Actual Model
# ===================================================================
import torch
import math

# Add verbose method to SuffixHMMPrior class for demonstration
def get_boundary_priors_verbose(self, word):
    """
    Calculate boundary priors with detailed verbose output showing
    forward pass, backward pass, and intermediate calculations.

    Meant to be attached to SuffixHMMPrior as a method. Prints the prior's
    configuration, the forward (α) and backward (β) tables together with the
    three best-scoring candidate segments per position, and the probability
    exp(α[i] + β[i] - α[n]) for every internal position i.

    Args:
        self: object exposing `max_len`, `unk_penalty`, `log_probs`,
            `LOG_ZERO`, `_forward_pass`, `_backward_pass` and `_get_log_prob`.
        word (str): the word to analyse.

    Returns:
        list of len(word) - 1 floats; [] when the word has at most one
        character; all zeros when the forward total equals `LOG_ZERO`.
    """
    n = len(word)
    if n <= 1:
        return []

    heavy_rule = '=' * 70
    light_rule = '─' * 70

    print(f"\n{heavy_rule}")
    print(f"HMM PRIOR PROCESSING: '{word}' (length={n})")
    print(f"{heavy_rule}")

    # Show known suffixes (sample of them)
    print(f"\nHMM Prior Configuration:")
    print(f"  Max suffix length: {self.max_len}")
    print(f"  Unknown penalty: {self.unk_penalty:.4f}")
    print(f"  Number of known suffixes: {len(self.log_probs)}")
    if len(self.log_probs) > 0:
        sample_suffixes = list(self.log_probs.items())[:10]
        print(f"  Sample suffixes (first 10):")
        for suffix, log_prob in sample_suffixes:
            print(f"    '{suffix}': log P = {log_prob:.4f} (P = {math.exp(log_prob):.6f})")
        if len(self.log_probs) > 10:
            print(f"    ... and {len(self.log_probs) - 10} more")

    # Forward pass
    print(f"\n{light_rule}")
    print("FORWARD PASS (α) - Computing log probabilities of all prefixes")
    print(f"{light_rule}")
    alpha = self._forward_pass(word)

    for i in range(n + 1):
        prefix = word[:i] if i > 0 else "<empty>"
        if i == 0:
            print(f"  α[{i}] = {alpha[i]:.4f}  (empty prefix, base case)")
        else:
            # Re-enumerate the segments the forward pass considered for position i
            candidates = []
            for j in range(max(0, i - self.max_len), i):
                segment = word[j:i]
                log_p_seg = self._get_log_prob(segment)
                candidates.append((j, segment, log_p_seg, alpha[j]))

            # Show top 3 candidates (ranked by previous score + segment score)
            candidates_sorted = sorted(candidates, key=lambda x: x[3] + x[2], reverse=True)[:3]
            print(f"  α[{i}] = {alpha[i]:.4f}  (prefix: '{prefix}')")
            for j, seg, log_p, prev_alpha in candidates_sorted:
                total = prev_alpha + log_p
                print(f"      Candidate: j={j}, segment='{seg}', log P(seg)={log_p:.4f}, "
                      f"α[{j}]={prev_alpha:.4f}, total={total:.4f}")

    # Backward pass
    print(f"\n{light_rule}")
    print("BACKWARD PASS (β) - Computing log probabilities of all suffixes")
    print(f"{light_rule}")
    beta = self._backward_pass(word)

    for i in range(n, -1, -1):
        suffix = word[i:] if i < n else "<empty>"
        if i == n:
            print(f"  β[{i}] = {beta[i]:.4f}  (empty suffix, base case)")
        else:
            # Re-enumerate the segments the backward pass considered for position i
            candidates = []
            for j in range(i + 1, min(n + 1, i + self.max_len + 1)):
                segment = word[i:j]
                log_p_seg = self._get_log_prob(segment)
                candidates.append((j, segment, log_p_seg, beta[j]))

            # Show top 3 candidates (ranked by following score + segment score)
            candidates_sorted = sorted(candidates, key=lambda x: x[3] + x[2], reverse=True)[:3]
            print(f"  β[{i}] = {beta[i]:.4f}  (suffix: '{suffix}')")
            for j, seg, log_p, next_beta in candidates_sorted:
                total = next_beta + log_p
                print(f"      Candidate: j={j}, segment='{seg}', log P(seg)={log_p:.4f}, "
                      f"β[{j}]={next_beta:.4f}, total={total:.4f}")

    # Boundary prior computation
    print(f"\n{light_rule}")
    print("BOUNDARY PRIOR COMPUTATION")
    print(f"{light_rule}")
    log_total_prob = alpha[n]
    print(f"  Total log probability: α[{n}] = {log_total_prob:.4f}")
    print(f"  Total probability: P(word) = {math.exp(log_total_prob):.6f}")

    if log_total_prob == self.LOG_ZERO:
        print("  WARNING: No valid segmentation found!")
        return [0.0] * (n - 1)

    print(f"\n  Computing P(boundary at i | word) = exp(α[i] + β[i] - α[n])")
    print(f"\n  Position-by-position boundary probabilities:")

    priors = []
    for i in range(1, n):
        log_p_boundary = alpha[i] + beta[i]
        normalized_log_prior = log_p_boundary - log_total_prob
        prior = math.exp(normalized_log_prior)
        priors.append(prior)

        char_before = word[i-1]
        print(f"    Position {i} (after '{char_before}'):")
        print(f"      α[{i}] + β[{i}] = {alpha[i]:.4f} + {beta[i]:.4f} = {log_p_boundary:.4f}")
        print(f"      Normalized: {log_p_boundary:.4f} - {log_total_prob:.4f} = {normalized_log_prior:.4f}")
        print(f"      P(boundary) = exp({normalized_log_prior:.4f}) = {prior:.4f}")

    # Final summary
    print(f"\n{light_rule}")
    print("FINAL OUTPUT")
    print(f"{light_rule}")
    print(f"Word: '{word}'")
    print(f"\nCharacter sequence: {' '.join(word)}")
    print(f"Boundary probabilities:")
    print(f"  {' '.join([f'{p:.3f}' for p in priors])}")
    print(f"\nVisualization:")
    print(f"  {' '.join(word)}")
    # ' ' below 0.3, '|' below 0.7, '||' otherwise
    print(f"  {' '.join([' ' if p < 0.3 else '|' if p < 0.7 else '||' for p in priors])}")
    print(f"  {' '.join([f'{p:.2f}' for p in priors])}")

    return priors

# Monkey-patch the verbose method to SuffixHMMPrior: from here on every
# instance (including ones created earlier) can call
# `prior.get_boundary_priors_verbose(word)`.
SuffixHMMPrior.get_boundary_priors_verbose = get_boundary_priors_verbose

# Demonstration cell: run the loaded segmenter over dataset words and print the
# verbose HMM-prior breakdown for the first few words it segments exactly right.
# It relies on names created by earlier cells (out, model, vocab, thr and,
# optionally, acc_df / gold_df). If a required name is missing, the outer
# NameError handler prints setup instructions.
try:
    # 'out' is the artifacts dict returned by the training/loading cell
    if 'out' in globals() and 'hmm_prior' in out:
        hmm_prior = out['hmm_prior']
        print("‚úÖ Using existing HMM prior from loaded model")

        max_words_to_show = 3  # Maximum number of correct segmentations to display
        correct_count = 0
        total_count = 0

        # The dataframes are optional notebook globals; resolve their presence once
        has_acc_df = 'acc_df' in globals()
        has_gold_df = 'gold_df' in globals()

        # Candidate words: acc_df first, then gold_df words not already taken from acc_df
        words_to_try = []
        if has_acc_df and len(acc_df) > 0:
            words_to_try.extend(acc_df['Word'].tolist())
        if has_gold_df and len(gold_df) > 0:
            # Set membership keeps this de-duplication linear instead of quadratic
            already_listed = set(words_to_try)
            words_to_try.extend(w for w in gold_df['Word'].tolist() if w not in already_listed)

        if len(words_to_try) == 0:
            print("‚ö†Ô∏è  No words found in acc_df or gold_df")
        else:
            print(f"üîç Searching through {len(words_to_try)} words for correct segmentations...")

            for word in words_to_try:
                try:
                    total_count += 1

                    # Predicted segmentation, normalized to lowercase morphs
                    tokens = tokenize_with_vocab(word, vocab, max_token_len=4)
                    seg_string, probs = segment_tokens(model, vocab, tokens, hmm_prior=hmm_prior, thr=thr)
                    predicted_morphs = seg_string.split('-')
                    pred_normalized = [m.lower().strip() for m in predicted_morphs if m.strip()]

                    # Gold segmentation: prefer acc_df (list of variants), fall back to gold_df
                    gold_row = acc_df[acc_df['Word'] == word] if has_acc_df else pd.DataFrame()
                    if len(gold_row) == 0:
                        gold_row = gold_df[gold_df['Word'] == word] if has_gold_df else pd.DataFrame()
                        if len(gold_row) == 0:
                            continue  # Skip silently if word not found
                        # gold_df stores a single segmentation in Morph_split
                        gold_morphs = gold_row['Morph_split'].iloc[0]
                        if not isinstance(gold_morphs, list):
                            gold_morphs = list(gold_morphs) if hasattr(gold_morphs, '__iter__') else [str(gold_morphs)]
                        gold_variants = [gold_morphs]
                    else:
                        # acc_df stores a collection of acceptable variants in Gold
                        gold_variants_raw = gold_row['Gold'].iloc[0]
                        gold_variants = normalize_gold_variants(gold_variants_raw)
                        if not isinstance(gold_variants, list) or len(gold_variants) == 0:
                            continue  # Skip silently if no valid gold variants

                    # The prediction is correct when it equals any gold variant exactly
                    is_correct = False
                    matched_gold = None
                    for gold_variant in gold_variants:
                        if not isinstance(gold_variant, list):
                            gold_variant = list(gold_variant) if hasattr(gold_variant, '__iter__') else [str(gold_variant)]
                        gold_normalized = [m.lower().strip() for m in gold_variant if m.strip()]
                        if pred_normalized == gold_normalized:
                            is_correct = True
                            matched_gold = gold_variant
                            break

                    if is_correct:
                        correct_count += 1
                        # Only output verbose information for correct segmentations
                        priors = hmm_prior.get_boundary_priors_verbose(word)
                        print(f"\n‚úÖ CORRECT SEGMENTATION: '{word}'")
                        print(f"   Predicted: {seg_string}")
                        print(f"   Gold:      {'-'.join(matched_gold)}")
                        print("\n" + "="*70 + "\n")

                        # Stop after finding max_words_to_show correct segmentations
                        if correct_count >= max_words_to_show:
                            break
                    # Incorrect segmentations are skipped without output

                except NameError:
                    # A missing notebook global (model, vocab, thr, ...) fails the
                    # same way for every word; hand it to the outer handler so the
                    # setup instructions are shown instead of "0 correct".
                    raise
                except Exception:
                    # Best-effort demo: skip words that cannot be processed
                    continue

            print(f"\nüìä Summary: Found {correct_count} correct segmentation(s) out of {total_count} words checked.")
            if correct_count == 0:
                print("   No correct segmentations found. Try checking more words or adjusting the threshold.")
    else:
        print("‚ö†Ô∏è  No model found in memory.")
        print("\nTo use this demonstration:")
        print("1. First run your model training/loading cell (e.g., run_segmentation_with_privK)")
        print("2. Then run this cell again")
        print("\nAlternatively, you can manually specify:")
        print("  hmm_prior = out['hmm_prior']")
        print("  priors = hmm_prior.get_boundary_priors_verbose('pikunas')")
except NameError as e:
    print(f"‚ùå {e}")
    print("\nTo use this demonstration:")
    print("1. First run your model training/loading cell (e.g., run_segmentation_with_privK)")
    print("2. Then run this cell again")
    print("\nAlternatively, you can manually specify:")
    print("  hmm_prior = out['hmm_prior']")
    print("  priors = hmm_prior.get_boundary_priors_verbose('pikunas')")


In [None]:
def build_vocab(samples, min_freq=1):
    """Build a token -> integer id mapping from the samples' token lists.

    Ids 0 and 1 are reserved for "<PAD>" and "<UNK>". The remaining tokens
    get consecutive ids in order of decreasing frequency, with ties broken
    by the token's own sort order.

    Args:
        samples: Iterable of dicts, each with a "tokens" entry.
        min_freq: Tokens seen fewer than this many times are left out.

    Returns:
        dict mapping token to id.
    """
    from collections import Counter

    freq = Counter()
    for sample in samples:
        freq.update(sample["tokens"])

    vocab = {"<PAD>": 0, "<UNK>": 1}
    ranked = sorted(freq.items(), key=lambda item: (-item[1], item[0]))
    for token, count in ranked:
        # Skip rare tokens and anything that already has an id (the specials)
        if count < min_freq or token in vocab:
            continue
        vocab[token] = len(vocab)
    return vocab

class SegDataset(Dataset):
    """Dataset of tokenized words with boundary labels and boundary priors.

    Each item is a plain dict of Python values; padding and tensor
    conversion are left to the collate function.
    """

    def __init__(self, samples, vocab, hmm_prior=None, feat_dim=0):
        """Store the samples and the objects needed to featurize them.

        Args:
            samples: Sequence of dicts with "tokens", "y" and (when
                feat_dim > 0) "priv" entries.
            vocab: Token -> id mapping; must contain "<UNK>".
            hmm_prior: Object handed to prior_probs_for_sample; may be None.
            feat_dim: Number of privileged features; 0 disables them.
        """
        self.samples, self.vocab = samples, vocab
        self.hmm_prior = hmm_prior
        self.feat_dim = feat_dim

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the idx-th sample as a dict with ids, y, prior, priv, tokens."""
        sample = self.samples[idx]
        toks = sample["tokens"]
        return {
            # Out-of-vocabulary tokens share the <UNK> id
            "ids": [self.vocab.get(t, self.vocab["<UNK>"]) for t in toks],
            "y": sample["y"],
            # Boundary priors are computed on the fly from the token sequence
            "prior": prior_probs_for_sample(self.hmm_prior, toks),
            # Privileged features are only exposed when feat_dim > 0
            "priv": sample["priv"] if self.feat_dim > 0 else [],
            "tokens": toks,
        }

def collate(batch):
    """Pad a list of dataset items into batch tensors.

    Args:
        batch: Non-empty list of dicts with "ids", "y", "prior" and "priv"
            entries (the items produced by SegDataset).

    Returns:
        dict with
          "ids":      (B, maxT) long, padded with 0
          "mask_tok": (B, maxT) bool, True on real tokens
          "y":        (B, maxB) long, padded with -100
          "prior":    (B, maxB) float32, padded with 0
          "mask_b":   (B, maxB) bool, True on real boundary positions
          "priv":     (B, F) float32, or None when there are no features
        where maxT is the longest token sequence and maxB = max(maxT - 1, 0).
    """
    B = len(batch)
    maxT = max(len(b["ids"]) for b in batch)
    # A batch made only of empty token lists must not yield a negative size
    maxB = max(maxT - 1, 0)

    ids = torch.full((B, maxT), 0, dtype=torch.long)
    mask_tok = torch.zeros((B, maxT), dtype=torch.bool)
    y = torch.full((B, maxB), -100, dtype=torch.long)
    prior = torch.zeros((B, maxB), dtype=torch.float32)
    mask_b = torch.zeros((B, maxB), dtype=torch.bool)

    # The feature width is taken from the first item. Any sized sequence
    # (list, tuple, array) counts, so features stored as something other than
    # a list are not silently dropped.
    first_priv = batch[0]["priv"]
    if (first_priv is None
            or isinstance(first_priv, (str, bytes, dict))
            or not hasattr(first_priv, "__len__")):
        feat_dim = 0
    else:
        feat_dim = len(first_priv)
    priv = torch.zeros((B, feat_dim), dtype=torch.float32) if feat_dim > 0 else None

    for i, b in enumerate(batch):
        T = len(b["ids"])
        ids[i, :T] = torch.tensor(b["ids"], dtype=torch.long)
        mask_tok[i, :T] = True
        if T > 1:
            L = T - 1
            y[i, :L] = torch.tensor(b["y"], dtype=torch.long)
            # Fall back to an uninformative 0.5 prior when the supplied prior
            # does not have one value per boundary position
            p = b["prior"] if len(b["prior"]) == L else [0.5] * L
            prior[i, :L] = torch.tensor(p, dtype=torch.float32)
            mask_b[i, :L] = True
        if feat_dim > 0:
            priv[i] = torch.tensor(b["priv"], dtype=torch.float32)

    return {
        "ids": ids, "mask_tok": mask_tok,
        "y": y, "prior": prior, "mask_b": mask_b,
        "priv": priv  # (B, F) or None
    }

In [None]:
class BiLSTMTagger(nn.Module):
    """BiLSTM tagger that scores every gap between adjacent tokens.

    For each of the T-1 positions between neighbouring tokens the model
    emits two logits (index 1 is the "boundary" class). An external
    per-position prior probability can be fused in one of two ways:
      * "concat":    the prior is appended to the LSTM state fed to the MLP;
      * "logit_add": the prior's log-odds, scaled by the learnable scalar
                     ``alpha``, is added to the boundary logit.
    Any other fuse_mode, or use_prior=False, ignores the prior.
    """

    def __init__(self, vocab_size, emb_dim=16, hidden_size=64, num_layers=2,
                 use_prior=True, dropout=0.1, freeze_emb=False, fuse_mode="logit_add"):
        """Create the embedding, BiLSTM and boundary MLP.

        Args:
            vocab_size: Number of token ids (id 0 is the padding id).
            emb_dim: Embedding width.
            hidden_size: Nominal BiLSTM output width; each direction gets
                hidden_size // 2 units.
            num_layers: Number of stacked LSTM layers.
            use_prior: Whether the prior passed to forward() is used.
            dropout: Dropout in the MLP and, when num_layers > 1, between
                LSTM layers.
            freeze_emb: If True the embedding weights are not trained.
            fuse_mode: "logit_add" or "concat" (see class docstring).
        """
        super().__init__()
        self.use_prior = use_prior
        self.fuse_mode = fuse_mode
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        if freeze_emb:
            for p in self.emb.parameters():
                p.requires_grad = False
        # nn.LSTM only applies dropout between layers, so it is disabled for 1 layer
        lstm_dropout = dropout if num_layers > 1 else 0.0
        per_direction = hidden_size // 2
        self.lstm = nn.LSTM(
            input_size=emb_dim, hidden_size=per_direction,
            num_layers=num_layers, dropout=lstm_dropout,
            bidirectional=True, batch_first=True
        )
        # The BiLSTM emits 2 * per_direction features. Sizing the MLP from that
        # (rather than hidden_size) keeps odd hidden_size values working; for
        # even values the layer shapes are unchanged.
        lstm_out = 2 * per_direction
        in_mlp = lstm_out + (1 if (use_prior and fuse_mode == "concat") else 0)
        self.boundary_mlp = nn.Sequential(
            nn.Linear(in_mlp, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, 2)
        )
        if use_prior and fuse_mode == "logit_add":
            self.alpha = nn.Parameter(torch.tensor(1.0))

    def forward(self, ids, prior, mask_tok):
        """Return boundary logits of shape (B, T-1, 2).

        Args:
            ids: (B, T) token ids.
            prior: (B, T-1) prior boundary probabilities.
            mask_tok: Accepted for interface compatibility; not used here,
                so padded positions are processed by the LSTM like any
                other input.
        """
        emb = self.emb(ids)
        h, _ = self.lstm(emb)          # (B,T,H)
        left = h[:, :-1, :]            # (B,T-1,H): state of the token left of each gap
        if self.use_prior and self.fuse_mode == "concat":
            feat = torch.cat([left, prior.unsqueeze(-1)], dim=-1)
            return self.boundary_mlp(feat)
        logits = self.boundary_mlp(left)
        if self.use_prior and self.fuse_mode == "logit_add":
            # Clamp so the log-odds stay finite for priors of exactly 0 or 1
            eps = 1e-6
            p = prior.clamp(eps, 1 - eps)
            prior_logit = torch.log(p) - torch.log(1 - p)
            # Build the fused logits out of place instead of mutating the MLP output
            boundary_logit = logits[..., 1] + self.alpha * prior_logit
            logits = torch.stack((logits[..., 0], boundary_logit), dim=-1)
        return logits


In [None]:
def boundary_metrics_from_lists(probs_list, gold_list, thr=0.5):
    """Compute boundary-level precision, recall and F1 pooled over all words.

    Args:
        probs_list: List of 1-D tensors of boundary probabilities, one per word.
        gold_list: List of 1-D tensors of 0/1 gold labels, aligned with probs_list.
        thr: A position is predicted as a boundary when its probability >= thr.

    Returns:
        (precision, recall, f1). All three are 0.0 when there is nothing to
        score, i.e. the lists are empty or every word has zero boundary
        positions.
    """
    if not probs_list:
        return 0.0, 0.0, 0.0
    # Words without boundary positions contribute empty tensors; drop them
    scored_probs = [t for t in probs_list if t.numel() > 0]
    scored_gold = [t for t in gold_list if t.numel() > 0]
    if not scored_probs or not scored_gold:
        # torch.cat refuses an empty list, so report "no score" instead
        return 0.0, 0.0, 0.0
    p = torch.cat(scored_probs, dim=0).numpy()
    g = torch.cat(scored_gold, dim=0).numpy()
    pred = (p >= thr).astype(int)
    P, R, F1, _ = precision_recall_fscore_support(g, pred, average='binary', zero_division=0)
    return P, R, F1

def exact_match_rate_from_lists(probs_list, gold_list, thr=0.5):
    """Return the fraction of words whose thresholded boundaries all match gold.

    A word with no boundary positions (empty gold tensor) counts as a match.
    Returns 0.0 when probs_list is empty.

    Args:
        probs_list: List of 1-D tensors of boundary probabilities, one per word.
        gold_list: List of 1-D tensors of gold labels, aligned with probs_list.
        thr: A position is predicted as a boundary when its probability >= thr.
    """
    if not probs_list:
        return 0.0

    def _word_hit(word_probs, word_gold):
        # 1.0 when every predicted boundary decision equals the gold label
        if word_gold.numel() == 0:
            return 1.0
        decisions = (word_probs.numpy() >= thr).astype(int)
        return float(np.array_equal(decisions, word_gold.numpy()))

    hits = [_word_hit(p, g) for p, g in zip(probs_list, gold_list)]
    return float(np.mean(hits))

@torch.no_grad()
def predict(model, loader):
    """Run the model over a loader and collect per-word boundary probabilities.

    The model is switched to eval mode. For every word in every batch the
    padded positions are stripped using the batch's "mask_b" row, so each
    returned tensor has one entry per real boundary position.

    Args:
        model: Callable as model(ids, prior, mask_tok) returning logits whose
            last dimension has two classes (index 1 = boundary).
        loader: Iterable of batch dicts with "ids", "prior", "mask_tok", "y"
            and "mask_b" entries.

    Returns:
        (probs_list, gold_list): parallel lists of 1-D tensors. Words without
        boundary positions yield empty tensors.
    """
    model.eval()
    probs_list, gold_list = [], []
    for batch in loader:
        logits = model(batch["ids"], batch["prior"], batch["mask_tok"])
        cut_probs = torch.softmax(logits, dim=-1)[..., 1]      # (B,T-1)
        labels, valid = batch["y"], batch["mask_b"]
        for row in range(cut_probs.shape[0]):
            n_valid = int(valid[row].sum().item())
            if n_valid:
                probs_list.append(cut_probs[row, :n_valid].cpu())
                gold_list.append(labels[row, :n_valid].cpu())
            else:
                # Keep list positions aligned with the input words
                probs_list.append(torch.empty(0))
                gold_list.append(torch.empty(0, dtype=torch.long))
    return probs_list, gold_list

In [None]:
# Shared loss objects used by train_epoch
criterion_ce  = nn.CrossEntropyLoss()
criterion_bce = nn.BCEWithLogitsLoss(reduction="mean")
mse = nn.MSELoss(reduction="mean")

def train_epoch(model, loader, opt, lambda_prior=0.1, lambda_k=0.1, k_reg=None):
    """Train the model for one pass over the loader and return the mean loss.

    The per-batch loss is the sum of
      (1) cross-entropy on the gold boundary labels,
      (2) lambda_prior * BCE between the boundary logit and the prior, and
      (3) lambda_k * MSE between the expected number of cuts per word and the
          count predicted by the privileged K-teacher (only when k_reg and the
          batch's "priv" features are available).

    Args:
        model: Callable as model(ids, prior, mask_tok) -> (B, T-1, 2) logits.
        loader: Iterable of batch dicts as produced by collate.
        opt: Optimizer over the model's parameters.
        lambda_prior: Weight of term (2); 0 disables it.
        lambda_k: Weight of term (3); 0 disables it.
        k_reg: K-teacher passed to predict_k_hat_priv, or None.

    Returns:
        Mean loss over the batches that were trained on (0.0 if none).
    """
    # NOTE(review): predict() turns logits into probabilities with a 2-way
    # softmax, while terms (2) and (3) apply a sigmoid to the class-1 logit
    # alone; confirm that this difference is intended.
    model.train()
    tot = 0.0
    n = 0
    for batch in loader:
        ids, prior, y, mask_b = batch["ids"], batch["prior"], batch["y"], batch["mask_b"]
        priv = batch["priv"]  # (B,F) or None

        # A batch with no boundary positions (e.g. only single-token words) has
        # nothing to supervise; mean-reduced losses over zero elements are NaN
        # and would corrupt the parameters, so skip it.
        if int(mask_b.sum().item()) == 0:
            continue

        logits = model(ids, prior, batch["mask_tok"])    # (B,T-1,2)
        cut_logit = logits[..., 1]                       # (B,T-1)

        # (1) CE on gold boundaries (real positions only)
        loss = criterion_ce(logits[mask_b], y[mask_b])

        # (2) Optional: distill toward the prior on the cut-logit
        if lambda_prior > 0:
            loss_pr = criterion_bce(cut_logit[mask_b], prior[mask_b])
            loss = loss + lambda_prior * loss_pr

        # (3) K-regularizer using privileged K-hat
        if (lambda_k > 0) and (k_reg is not None) and (priv is not None):
            with torch.no_grad():
                k_hat = predict_k_hat_priv(k_reg, priv)  # (B,)
            p_cut = torch.sigmoid(cut_logit)             # (B,T-1)
            # Zero out padded boundary slots so shorter words in the batch are
            # not charged for positions that do not exist
            p_cut = p_cut * mask_b.to(p_cut.dtype)
            exp_K = p_cut.sum(dim=1)                     # (B,)
            loss_k = mse(exp_K, k_hat)
            loss = loss + lambda_k * loss_k

        opt.zero_grad()
        loss.backward()
        opt.step()
        tot += loss.item()
        n += 1
    return tot / max(n, 1)

def split_train_test(samples, test_ratio=0.2):
    """Randomly split samples into (train, test) lists.

    The order is shuffled with NumPy's global random state, so results are
    reproducible only if the caller seeds np.random beforehand.

    Args:
        samples: Indexable sequence of samples.
        test_ratio: Fraction of the samples assigned to the test list.

    Returns:
        (train, test): the first int(n * (1 - test_ratio)) shuffled samples
        and the remainder.
    """
    total = len(samples)
    order = np.arange(total)
    np.random.shuffle(order)
    boundary = int(total * (1 - test_ratio))

    def pick(positions):
        return [samples[j] for j in positions]

    return pick(order[:boundary]), pick(order[boundary:])

def best_threshold_for_exact(probs_list, gold_list, grid=None):
    """Pick the decision threshold that maximizes word-level exact match.

    Every threshold in the grid is scored by exact-match rate; ties are
    broken by the pooled boundary F1. The winner is printed and returned.

    Args:
        probs_list: List of 1-D tensors of boundary probabilities, one per word.
        gold_list: List of 1-D tensors of 0/1 gold labels, aligned with probs_list.
        grid: Candidate thresholds; defaults to 61 evenly spaced values in
            [0.3, 0.9].

    Returns:
        The best threshold, or the default 0.5 when there is nothing to tune
        on (missing/empty lists, or no word has a boundary position).
    """
    if grid is None:
        grid = np.linspace(0.3, 0.9, 61)
    best_thr, best_em, best_f1 = 0.5, -1.0, 0.0

    # Convert once up front instead of once per threshold
    pairs = [(p.numpy(), g.numpy()) for p, g in zip(probs_list or [], gold_list or [])]
    pooled_probs = [p for p, _ in pairs if p.size > 0]
    pooled_gold = [g for _, g in pairs if g.size > 0]
    if not pooled_probs or not pooled_gold:
        # np.concatenate refuses an empty list; keep the default threshold
        print(f"[Exact-opt threshold] no boundary positions to tune on; keeping thr={best_thr:.3f}")
        return best_thr
    p_all = np.concatenate(pooled_probs, axis=0)
    g_all = np.concatenate(pooled_gold, axis=0)

    for thr in grid:
        # Words without boundary positions are trivially exact matches
        ems = [
            1.0 if g.size == 0 else float(np.array_equal((p >= thr).astype(int), g))
            for p, g in pairs
        ]
        em = float(np.mean(ems))
        pred_all = (p_all >= thr).astype(int)
        P, R, F1, _ = precision_recall_fscore_support(g_all, pred_all, average='binary', zero_division=0)
        if em > best_em or (np.isclose(em, best_em) and F1 > best_f1):
            best_thr, best_em, best_f1 = thr, em, F1
    print(f"[Exact-opt threshold] thr={best_thr:.3f} | exact={best_em:.3f} | boundaryF1={best_f1:.3f}")
    return best_thr

In [None]:
# =========================
# MODEL SAVING AND LOADING FUNCTIONS
# =========================
# These functions handle saving and loading trained models to avoid retraining
# Models are saved to the models folder with a unique identifier based on parameters

def generate_model_id(df, provided_suffix_list, use_suffix_list, unk_penalty, epochs,
                     use_prior, fuse_mode, lambda_prior, lambda_k, batch_size, hparams, synthetic_choice,
                     augmentation_word_selection=None, augmentation_n_words=None):
    """
    Derive a short, deterministic identifier from a training configuration.

    The identifier is the first 16 hex characters of the MD5 digest of the
    configuration serialized as key-sorted JSON, so equal configurations map
    to the same id and a previously saved model can be found again. Note that
    the data only enters through the DataFrame's shape and the length of the
    suffix list, not through their contents.

    Args:
        df: Training DataFrame (only its shape is hashed); may be None
        provided_suffix_list: Suffix list (only its length is hashed)
        use_suffix_list, unk_penalty, epochs, use_prior, fuse_mode,
        lambda_prior, lambda_k, batch_size, hparams, synthetic_choice:
            Training parameters, hashed as given (must be JSON-serializable)
        augmentation_word_selection: How words were selected for augmentation;
            read from the AUGMENTATION_WORD_SELECTION global (default 'all')
            when None
        augmentation_n_words: Number of augmentation words; read from the
            AUGMENTATION_N_WORDS global (default None) when None

    Returns:
        A 16-character hexadecimal string
    """
    # Fall back to the notebook-level configuration for omitted arguments
    if augmentation_word_selection is None:
        augmentation_word_selection = globals().get('AUGMENTATION_WORD_SELECTION', 'all')
    if augmentation_n_words is None:
        augmentation_n_words = globals().get('AUGMENTATION_N_WORDS', None)

    n_suffixes = len(provided_suffix_list) if provided_suffix_list else 0
    data_shape = (0, 0) if df is None else df.shape

    fingerprint = dict(
        synthetic_choice=synthetic_choice,
        use_suffix_list=use_suffix_list,
        unk_penalty=unk_penalty,
        epochs=epochs,
        use_prior=use_prior,
        fuse_mode=fuse_mode,
        lambda_prior=lambda_prior,
        lambda_k=lambda_k,
        batch_size=batch_size,
        hparams=hparams,
        suffix_list_len=n_suffixes,
        df_shape=data_shape,
        augmentation_word_selection=augmentation_word_selection,
        augmentation_n_words=augmentation_n_words,
    )

    # Key-sorted JSON gives a canonical text form to hash
    canonical = json.dumps(fingerprint, sort_keys=True)
    digest = hashlib.md5(canonical.encode()).hexdigest()
    return digest[:16]

def save_model(model, vocab, out, model_id, models_folder=MODELS_FOLDER, 
               synthetic_choice=None, augmentation_word_selection=None, augmentation_n_words=None):
    """
    Persist a trained model and its artifacts under <models_folder>/<model_id>/.

    Four files are written:
        model.pt       - the model's state_dict (torch.save)
        vocab.pkl      - the vocabulary (pickle)
        artifacts.pkl  - the `out` dictionary (pickle)
        metadata.json  - model id, vocabulary size and augmentation settings

    Args:
        model: The trained PyTorch model
        vocab: Vocabulary dictionary
        out: Dictionary of training artifacts (hmm_prior, k_teacher, best_thr, etc.)
        model_id: Unique identifier for this model; used as the folder name
        models_folder: Folder to save models in
        synthetic_choice: Which synthetic data was used; read from the
            SYNTHETIC_DATA_CHOICE global (default 'none') when None
        augmentation_word_selection: How words were selected for augmentation;
            read from the AUGMENTATION_WORD_SELECTION global (default 'all') when None
        augmentation_n_words: Number of words used for augmentation; read from the
            AUGMENTATION_N_WORDS global when None, and omitted from the
            metadata if still None

    Returns:
        Path of the folder the model was saved to
    """
    model_dir = os.path.join(models_folder, model_id)
    os.makedirs(model_dir, exist_ok=True)

    # Model weights
    torch.save(model.state_dict(), os.path.join(model_dir, "model.pt"))

    # Pickled companions: vocabulary first, then the artifacts dictionary
    for filename, payload in (("vocab.pkl", vocab), ("artifacts.pkl", out)):
        with open(os.path.join(model_dir, filename), "wb") as fh:
            pickle.dump(payload, fh)

    # Fall back to the notebook-level configuration for omitted arguments
    if synthetic_choice is None:
        synthetic_choice = globals().get('SYNTHETIC_DATA_CHOICE', 'none')
    if augmentation_word_selection is None:
        augmentation_word_selection = globals().get('AUGMENTATION_WORD_SELECTION', 'all')
    if augmentation_n_words is None:
        augmentation_n_words = globals().get('AUGMENTATION_N_WORDS', None)

    # Human-readable summary of what this folder contains
    metadata = {
        'model_id': model_id,
        'vocab_size': len(vocab),
        'synthetic_choice': synthetic_choice,
        'augmentation_word_selection': augmentation_word_selection,
    }
    if augmentation_n_words is not None:
        metadata['augmentation_n_words'] = augmentation_n_words
    with open(os.path.join(model_dir, "metadata.json"), "w") as fh:
        json.dump(metadata, fh, indent=2)

    print(f"Model saved to {model_dir}")
    return model_dir

def load_model(model_id, models_folder=MODELS_FOLDER, vocab_size=None):
    """
    Load the saved vocabulary and artifacts of a model and locate its weights.

    The model itself is not rebuilt here: the caller receives the path of the
    state-dict file and is expected to construct the architecture and load
    the weights.

    Args:
        model_id: Unique identifier for the model (its folder name)
        models_folder: Folder where models are saved
        vocab_size: Not used by this function

    Returns:
        Dictionary with 'vocab', 'out', 'model_state_path' and 'model_dir',
        or None if the folder or any of vocab.pkl, artifacts.pkl, model.pt
        is missing
    """
    # NOTE: pickle.load can execute arbitrary code from the file being read;
    # only load model folders that come from a trusted source.
    model_dir = os.path.join(models_folder, model_id)
    if not os.path.exists(model_dir):
        return None

    # Vocabulary
    vocab_file = os.path.join(model_dir, "vocab.pkl")
    if not os.path.exists(vocab_file):
        return None
    with open(vocab_file, "rb") as fh:
        vocab = pickle.load(fh)

    # Training artifacts (the `out` dictionary written by save_model)
    artifacts_file = os.path.join(model_dir, "artifacts.pkl")
    if not os.path.exists(artifacts_file):
        return None
    with open(artifacts_file, "rb") as fh:
        out = pickle.load(fh)

    # Weights are only located here, not read
    weights_file = os.path.join(model_dir, "model.pt")
    if not os.path.exists(weights_file):
        return None

    print(f"Model artifacts loaded from {model_dir}")
    return {
        'vocab': vocab,
        'out': out,
        'model_state_path': weights_file,
        'model_dir': model_dir
    }


In [None]:
# ===================================================================
# MAIN TRAINING FUNCTION WITH MODEL CHECKPOINTING
# ===================================================================
# This function trains a morphology parser model. It checks if a model with
# the same parameters already exists and loads it instead of retraining.

def run_segmentation_with_privK(
    df,
    provided_suffix_list,
    use_suffix_list=True,
    unk_penalty=-15.0,
    epochs=15,
    use_prior=True, # This now controls the HMM prior
    fuse_mode="logit_add",
    lambda_prior=0.1,
    lambda_k=0.2,
    batch_size=64,
    hparams=None,
    synthetic_choice=None  # Added to track which synthetic data was used
):
    """
    Return a morphology parser for the given configuration, training it if needed.

    A model id is derived from the configuration. If a model saved under that
    id can be loaded, it is rebuilt from disk and returned without training.
    Otherwise the data is split 80/20, a new model is trained on the larger
    part, evaluated on the smaller part after every epoch, and saved.

    Args:
        df: Training DataFrame
        provided_suffix_list: List of valid suffixes for the HMM prior
        use_suffix_list: If True the prior is created from provided_suffix_list,
            otherwise it is trained on the training split
        unk_penalty: Penalty for unknown suffixes in the HMM
        epochs: Number of training epochs
        use_prior: Whether to use an HMM prior at all
        fuse_mode: How the prior is fused with the model ("none" disables fusion)
        lambda_prior: Weight for the prior distillation loss
        lambda_k: Weight for the K-regularizer loss
        batch_size: Training batch size
        hparams: Model hyperparameters dictionary (emb_dim, hidden_size,
            num_layers, dropout, lr, weight_decay, freeze_emb); defaults are
            used when None
        synthetic_choice: Which synthetic data was used ("none", "gpt4o",
            "gpt5mini"); read from the SYNTHETIC_DATA_CHOICE global when None

    Returns:
        Tuple of (model, vocab, out_dict). For a freshly trained model
        out_dict holds "probs_list", "gold_list", "hmm_prior", "k_teacher"
        and "best_thr"; for a loaded model it is whatever was saved.
    """
    if hparams is None:
        hparams = dict(emb_dim=16, hidden_size=64, num_layers=2,
                       dropout=0.25, lr=1e-3, weight_decay=1e-4, freeze_emb=False)

    if synthetic_choice is None:
        synthetic_choice = globals().get('SYNTHETIC_DATA_CHOICE', "none")

    # Word-selection settings come from the notebook-level configuration
    augmentation_word_selection = globals().get('AUGMENTATION_WORD_SELECTION', 'all')
    augmentation_n_words = globals().get('AUGMENTATION_N_WORDS', None)

    def _new_tagger(n_tokens):
        # Single place that maps hparams onto the architecture, shared by the
        # load path and the train path
        return BiLSTMTagger(
            vocab_size=n_tokens,
            emb_dim=hparams.get("emb_dim", 16),
            hidden_size=hparams.get("hidden_size", 64),
            num_layers=hparams.get("num_layers", 2),
            use_prior=(use_prior and fuse_mode!="none"),
            dropout=hparams.get("dropout", 0.25),
            freeze_emb=hparams.get("freeze_emb", False),
            fuse_mode=fuse_mode
        )

    model_id = generate_model_id(
        df, provided_suffix_list, use_suffix_list, unk_penalty, epochs,
        use_prior, fuse_mode, lambda_prior, lambda_k, batch_size, hparams, synthetic_choice,
        augmentation_word_selection=augmentation_word_selection,
        augmentation_n_words=augmentation_n_words
    )

    # ---- Reuse a saved model when one exists for this configuration ----
    print(f"Checking for existing model with ID: {model_id}")
    loaded = load_model(model_id, models_folder=MODELS_FOLDER)

    if loaded is not None:
        print(f"‚úÖ Found existing model! Loading from {loaded['model_dir']}")
        vocab, out = loaded['vocab'], loaded['out']

        # Rebuild the architecture, then restore the saved weights
        model = _new_tagger(len(vocab))
        model.load_state_dict(torch.load(loaded['model_state_path']))
        model.eval()

        print("Model loaded successfully. Skipping training.")
        return model, vocab, out

    # ---- Otherwise train from scratch ----
    print(f"No existing model found. Training new model...")

    samples = build_samples_with_priv(df, feat_names=NEW_NUM_FEATS)
    train_s, test_s = split_train_test(samples, 0.2)

    # HMM prior: built from the supplied suffix list, or trained on the
    # training split only
    if not use_prior:
        hmm_prior = None
    elif use_suffix_list:
        hmm_prior = create_hmm_prior_from_list(provided_suffix_list, unk_penalty)
    else:
        hmm_prior = train_hmm_prior(train_s)

    # Privileged K-teacher and vocabulary also see the training split only
    feat_dim = len(NEW_NUM_FEATS)
    k_reg = train_k_teacher_priv(train_s, feat_dim=feat_dim)
    vocab = build_vocab(train_s, min_freq=1)

    train_ds = SegDataset(train_s, vocab, hmm_prior=hmm_prior, feat_dim=feat_dim)
    test_ds  = SegDataset(test_s,  vocab, hmm_prior=hmm_prior, feat_dim=feat_dim)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  collate_fn=collate)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, collate_fn=collate)

    model = _new_tagger(len(vocab))
    opt = torch.optim.AdamW(
        model.parameters(),
        lr=hparams.get("lr", 1e-3),
        weight_decay=hparams.get("weight_decay", 1e-4)
    )

    # Train, reporting held-out metrics at a fixed 0.5 threshold each epoch;
    # the predictions of the last epoch are kept for threshold tuning
    final_probs_list = final_gold_list = None
    for ep in range(1, epochs+1):
        loss = train_epoch(model, train_loader, opt, lambda_prior=lambda_prior, lambda_k=lambda_k, k_reg=k_reg)
        final_probs_list, final_gold_list = predict(model, test_loader)
        P,R,F1 = boundary_metrics_from_lists(final_probs_list, final_gold_list, thr=0.5)
        EM = exact_match_rate_from_lists(final_probs_list, final_gold_list, thr=0.5)
        print(f"Epoch {ep:02d} | loss={loss:.4f} | boundary P/R/F1={P:.3f}/{R:.3f}/{F1:.3f} | exact={EM:.3f}")

    best_thr = best_threshold_for_exact(final_probs_list, final_gold_list)

    out = {
        "probs_list": final_probs_list,
        "gold_list": final_gold_list,
        "hmm_prior": hmm_prior,
        "k_teacher": k_reg,
        "best_thr": best_thr
    }

    # Persist so the next call with the same configuration skips training
    print(f"\nSaving trained model with ID: {model_id}")
    save_model(model, vocab, out, model_id, models_folder=MODELS_FOLDER,
               synthetic_choice=synthetic_choice,
               augmentation_word_selection=augmentation_word_selection,
               augmentation_n_words=augmentation_n_words)

    return model, vocab, out

In [None]:
# ===================================================================
# K-FOLD CROSS-VALIDATION FUNCTION
# ===================================================================
# This function performs k-fold cross-validation on the training data
# It splits the data into k folds and trains/evaluates on each fold

from sklearn.model_selection import KFold

def run_kfold_cross_validation(
    df,
    provided_suffix_list,
    n_folds=5,
    use_suffix_list=True,
    unk_penalty=-15.0,
    epochs=15,
    use_prior=True,
    fuse_mode="logit_add",
    lambda_prior=0.1,
    lambda_k=0.2,
    batch_size=64,
    hparams=None,
    synthetic_choice=None,
    random_state=42
):
    """
    Estimate segmentation quality with k-fold cross-validation.

    For every fold a fresh prior, K-teacher, vocabulary and model are built
    from the training part only. The model is evaluated on the validation
    part after each epoch; the predictions of the best epoch (by exact match
    at threshold 0.5, ties broken by boundary F1) are then used to tune the
    decision threshold and to compute the fold's final metrics. Nothing is
    saved to disk.

    Args:
        df: Training DataFrame
        provided_suffix_list: List of valid suffixes for the HMM prior
        n_folds: Number of folds (default: 5)
        use_suffix_list: If True the prior is created from provided_suffix_list,
            otherwise it is trained on each training fold
        unk_penalty: Penalty for unknown suffixes in the HMM
        epochs: Number of training epochs per fold
        use_prior: Whether to use an HMM prior at all
        fuse_mode: How the prior is fused with the model ("none" disables fusion)
        lambda_prior: Weight for the prior distillation loss
        lambda_k: Weight for the K-regularizer loss
        batch_size: Training batch size
        hparams: Model hyperparameters dictionary; defaults are used when None
        synthetic_choice: Which synthetic data was used ("none", "gpt4o", "gpt5mini")
        random_state: Seed for the fold assignment

    Returns:
        Dictionary containing:
        - fold_results: List of per-fold result dictionaries
        - mean_metrics: Mean of each metric across folds
        - std_metrics: Standard deviation of each metric across folds
        - best_fold_idx: Position in fold_results of the fold with the
          highest exact match rate
        - all_metrics: The raw per-fold values of each metric
    """
    if hparams is None:
        hparams = dict(emb_dim=16, hidden_size=64, num_layers=2,
                       dropout=0.25, lr=1e-3, weight_decay=1e-4, freeze_emb=False)

    if synthetic_choice is None:
        synthetic_choice = globals().get('SYNTHETIC_DATA_CHOICE', "none")

    print(f"\n{'='*80}")
    print(f"K-FOLD CROSS-VALIDATION (k={n_folds})")
    print(f"{'='*80}")

    samples = build_samples_with_priv(df, feat_names=NEW_NUM_FEATS)
    n_samples = len(samples)

    splitter = KFold(n_splits=n_folds, shuffle=True, random_state=random_state)

    # Metrics tracked per fold, in reporting order
    metric_names = ('boundary_precision', 'boundary_recall', 'boundary_f1',
                    'exact_match', 'best_threshold')
    all_metrics = {name: [] for name in metric_names}
    fold_results = []

    for fold_idx, (train_indices, val_indices) in enumerate(splitter.split(samples), 1):
        print(f"\n{'‚îÄ'*80}")
        print(f"FOLD {fold_idx}/{n_folds}")
        print(f"{'‚îÄ'*80}")
        print(f"Train samples: {len(train_indices)}, Validation samples: {len(val_indices)}")

        train_samples = [samples[i] for i in train_indices]
        val_samples = [samples[i] for i in val_indices]

        # Everything fitted below sees the training part of the fold only
        if not use_prior:
            hmm_prior = None
        elif use_suffix_list:
            hmm_prior = create_hmm_prior_from_list(provided_suffix_list, unk_penalty)
        else:
            hmm_prior = train_hmm_prior(train_samples)

        feat_dim = len(NEW_NUM_FEATS)
        k_reg = train_k_teacher_priv(train_samples, feat_dim=feat_dim)
        vocab = build_vocab(train_samples, min_freq=1)

        train_ds = SegDataset(train_samples, vocab, hmm_prior=hmm_prior, feat_dim=feat_dim)
        val_ds = SegDataset(val_samples, vocab, hmm_prior=hmm_prior, feat_dim=feat_dim)
        train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate)
        val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=collate)

        model = BiLSTMTagger(
            vocab_size=len(vocab),
            emb_dim=hparams.get("emb_dim", 16),
            hidden_size=hparams.get("hidden_size", 64),
            num_layers=hparams.get("num_layers", 2),
            use_prior=(use_prior and fuse_mode!="none"),
            dropout=hparams.get("dropout", 0.25),
            freeze_emb=hparams.get("freeze_emb", False),
            fuse_mode=fuse_mode
        )
        opt = torch.optim.AdamW(
            model.parameters(),
            lr=hparams.get("lr", 1e-3),
            weight_decay=hparams.get("weight_decay", 1e-4)
        )

        # Train, remembering the validation predictions of the best epoch
        best_val_em = -1.0
        best_val_f1 = -1.0
        best_epoch = 0

        for ep in range(1, epochs+1):
            loss = train_epoch(
                model, train_loader, opt,
                lambda_prior=lambda_prior,
                lambda_k=lambda_k,
                k_reg=k_reg
            )

            probs_list, gold_list = predict(model, val_loader)
            P, R, F1 = boundary_metrics_from_lists(probs_list, gold_list, thr=0.5)
            EM = exact_match_rate_from_lists(probs_list, gold_list, thr=0.5)

            print(f"  Epoch {ep:02d} | loss={loss:.4f} | boundary P/R/F1={P:.3f}/{R:.3f}/{F1:.3f} | exact={EM:.3f}")

            # Higher exact match wins; boundary F1 breaks ties
            improved = EM > best_val_em or (np.isclose(EM, best_val_em) and F1 > best_val_f1)
            if improved:
                best_val_em, best_val_f1, best_epoch = EM, F1, ep
                best_probs_list, best_gold_list = probs_list, gold_list

        # Tune the threshold on the best epoch's predictions and re-score with it
        best_thr = best_threshold_for_exact(best_probs_list, best_gold_list)
        P_final, R_final, F1_final = boundary_metrics_from_lists(best_probs_list, best_gold_list, thr=best_thr)
        EM_final = exact_match_rate_from_lists(best_probs_list, best_gold_list, thr=best_thr)

        print(f"\n  Best epoch: {best_epoch}")
        print(f"  Final (thr={best_thr:.3f}): boundary P/R/F1={P_final:.3f}/{R_final:.3f}/{F1_final:.3f} | exact={EM_final:.3f}")

        fold_result = {
            'fold': fold_idx,
            'boundary_precision': P_final,
            'boundary_recall': R_final,
            'boundary_f1': F1_final,
            'exact_match': EM_final,
            'best_threshold': best_thr,
            'best_epoch': best_epoch
        }
        fold_results.append(fold_result)
        for name in metric_names:
            all_metrics[name].append(fold_result[name])

    # Aggregate across folds
    mean_metrics = {name: np.mean(values) for name, values in all_metrics.items()}
    std_metrics = {name: np.std(values) for name, values in all_metrics.items()}

    # Best fold = highest exact match rate (first one on ties)
    best_fold_idx = max(range(len(fold_results)), key=lambda i: fold_results[i]['exact_match'])
    best_fold = fold_results[best_fold_idx]

    print(f"\n{'='*80}")
    print(f"K-FOLD CROSS-VALIDATION SUMMARY")
    print(f"{'='*80}")
    print(f"\nPer-fold results:")
    for result in fold_results:
        print(f"  Fold {result['fold']}: "
              f"P={result['boundary_precision']:.3f}, "
              f"R={result['boundary_recall']:.3f}, "
              f"F1={result['boundary_f1']:.3f}, "
              f"EM={result['exact_match']:.3f}")

    print(f"\nMean ¬± Std across {n_folds} folds:")
    print(f"  Boundary Precision: {mean_metrics['boundary_precision']:.3f} ¬± {std_metrics['boundary_precision']:.3f}")
    print(f"  Boundary Recall:    {mean_metrics['boundary_recall']:.3f} ¬± {std_metrics['boundary_recall']:.3f}")
    print(f"  Boundary F1:        {mean_metrics['boundary_f1']:.3f} ¬± {std_metrics['boundary_f1']:.3f}")
    print(f"  Exact Match:        {mean_metrics['exact_match']:.3f} ¬± {std_metrics['exact_match']:.3f}")
    print(f"  Best Threshold:    {mean_metrics['best_threshold']:.3f} ¬± {std_metrics['best_threshold']:.3f}")
    print(f"\nBest fold: Fold {best_fold['fold']} "
          f"(Exact Match: {best_fold['exact_match']:.3f})")
    print(f"{'='*80}\n")

    return {
        'fold_results': fold_results,
        'mean_metrics': mean_metrics,
        'std_metrics': std_metrics,
        'best_fold_idx': best_fold_idx,
        'all_metrics': all_metrics
    }


In [None]:
def tokenize_with_vocab(word: str, vocab: dict, max_token_len: int = 4):
    """
    Split `word` into tokens by greedy longest-match against `vocab`.

    At each position the longest substring (at most `max_token_len` characters)
    that is a key of `vocab` is taken as the next token. When no candidate is
    in `vocab`, the single character at that position becomes the token.

    Args:
        word: The string to tokenize.
        vocab: Container of known tokens (only membership is checked).
        max_token_len: Longest candidate length tried at each position.

    Returns:
        List of tokens whose concatenation equals `word`.
    """
    pieces = []
    pos = 0
    total = len(word)
    while pos < total:
        window = min(max_token_len, total - pos)
        # Try the longest candidate first; fall back to the raw character.
        piece = next(
            (
                word[pos:pos + size]
                for size in range(window, 0, -1)
                if word[pos:pos + size] in vocab
            ),
            word[pos],
        )
        pieces.append(piece)
        pos += len(piece)
    return pieces

# Inference helper: segment one tokenized word with the trained model and the HMM prior.
@torch.no_grad()
def segment_tokens(model, vocab, tokens, hmm_prior=None, thr=0.5):
    """
    Predict morpheme boundaries for one tokenized word.

    Tokens are mapped to ids (unknown tokens use `vocab["<UNK>"]`), the prior
    for the word is obtained from `prior_probs_for_sample(hmm_prior, tokens)`,
    and the model is called as `model(ids, prior, mask)`. A boundary is placed
    after token i (for every token but the last) when the softmax probability
    of class index 1 at position i is >= `thr`.

    Args:
        model: Callable taking (ids, prior, mask) and returning logits.
            NOTE(review): the `[0, :, 1]` indexing assumes logits shaped
            (1, T, C) with C >= 2 — confirm against the model definition.
        vocab: Token -> id mapping containing "<UNK>".
        tokens: List of string tokens for a single word.
        hmm_prior: Object forwarded to `prior_probs_for_sample`.
        thr: Probability threshold for emitting a boundary.

    Returns:
        (segmentation, probs): the tokens joined with "-" at predicted
        boundaries, and the per-position boundary probabilities as a NumPy
        array. With one token or fewer, the joined tokens and an empty array.
    """
    ids = torch.tensor(
        [[vocab.get(tok, vocab["<UNK>"]) for tok in tokens]], dtype=torch.long
    )
    mask_tok = torch.ones_like(ids, dtype=torch.bool)
    n_tokens = len(tokens)
    if n_tokens <= 1:
        # Nothing to split: no boundary positions exist.
        return "".join(tokens), np.array([])

    prior_values = prior_probs_for_sample(hmm_prior, tokens)
    prior = torch.tensor([prior_values], dtype=torch.float32)

    logits = model(ids, prior, mask_tok)
    probs = torch.softmax(logits, dim=-1)[0, :, 1].cpu().numpy()
    cut_after = probs >= thr

    pieces = []
    for idx, tok in enumerate(tokens):
        pieces.append(tok)
        # The last token never gets a trailing separator.
        if idx < n_tokens - 1 and cut_after[idx]:
            pieces.append("-")
    return "".join(pieces), probs

In [None]:
# =========================
# LOAD SUFFIX LIST FOR HMM PRIOR
# =========================
# The suffix list is used to create the HMM prior that guides segmentation
# This file should be in the data folder

def read_suffixes(filename):
    """
    Read a list of suffixes from a file.

    Expected format: one entry per line, "number suffix" (e.g., "1 -ta").
    Blank lines and lines without a second field are ignored.

    A single leading dash is removed from each suffix when present. Entries
    written without a dash are kept unchanged (previously their first
    character was dropped), and entries that are empty after removing the
    dash are skipped.

    Args:
        filename: Path to the UTF-8 encoded suffix file.

    Returns:
        List of suffix strings without the leading dash, in file order.
    """
    suffixes = []
    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            # Split into the leading number and the remainder of the line.
            parts = line.strip().split(maxsplit=1)
            if len(parts) != 2:
                continue
            # Only strip an actual dash; never eat the first letter of a suffix.
            suffix = parts[1].removeprefix("-")
            if suffix:
                suffixes.append(suffix)
    return suffixes

# Load suffix list: prefer the data folder, fall back to the working directory
suffix_filename = os.path.join(DATA_FOLDER, "suffixesCQ-Anettte-Rios_LS.txt")
if not os.path.exists(suffix_filename):
    # Try in root directory as fallback
    suffix_filename = "suffixesCQ-Anettte-Rios_LS.txt"

if os.path.exists(suffix_filename):
    suffix_list = read_suffixes(suffix_filename)
    print(f"Loaded {len(suffix_list)} suffixes from {suffix_filename}")
else:
    # Neither location has the file: continue with an empty suffix list.
    print(f"Warning: Suffix file not found at {os.path.join(DATA_FOLDER, 'suffixesCQ-Anettte-Rios_LS.txt')}")
    print("Please ensure the suffix file is in the data folder.")
    suffix_list = []

In [None]:
import numpy as np
from typing import List, Set, Tuple

# ---------- helpers to turn segs into boundary sets (char offsets) ----------
def offsets_from_morphemes(morphs: List[str]) -> Set[int]:
    """
    Return the character offsets of the boundaries between morphemes.

    A boundary is recorded after every morpheme except the last one, measured
    as the cumulative length of the morphemes up to and including it.
    """
    ends = []
    running = 0
    for morph in morphs:
        running += len(morph)
        ends.append(running)
    # The final cumulative length is the end of the word, not a boundary.
    return set(ends[:-1])

def offsets_from_tokens_and_mask(tokens: List[str], mask01: np.ndarray) -> Set[int]:
    """
    Return the character offsets of predicted boundaries.

    A boundary is recorded after token i when `mask01[i] == 1`; the offset is
    the cumulative character length of tokens 0..i. The last token is never
    followed by a boundary, so `mask01` is not read at that index.
    """
    final_idx = len(tokens) - 1
    boundaries = set()
    end = 0
    for idx, tok in enumerate(tokens):
        end += len(tok)
        if idx == final_idx:
            break  # end of word: no boundary, and mask01 may be shorter
        if mask01[idx] == 1:
            boundaries.add(end)
    return boundaries

def f1_from_sets(pred: Set[int], gold: Set[int]) -> Tuple[float, float, float, int, int, int]:
    """
    Compare predicted and gold boundary sets.

    Returns:
        (precision, recall, f1, tp, fp, fn). Any ratio whose denominator is
        zero is reported as 0.0.
    """
    def _ratio(num, den):
        return num / den if den > 0 else 0.0

    tp = len(pred & gold)
    fp = len(pred) - tp  # predicted but not gold
    fn = len(gold) - tp  # gold but not predicted

    P = _ratio(tp, tp + fp)
    R = _ratio(tp, tp + fn)
    F1 = _ratio(2 * P * R, P + R)
    return P, R, F1, tp, fp, fn

def normalize_gold_variants(gold_variants):
    """
    Return `gold_variants` as a plain list, converting NumPy arrays to lists.

    The outer container and each variant are converted with `tolist()` when
    they are NumPy arrays; array items inside list variants are converted as
    well. Any input that is neither a list nor an array that converts to a
    list (including None) yields an empty list.
    """
    def _plain(variant):
        if isinstance(variant, np.ndarray):
            return variant.tolist()
        if isinstance(variant, list):
            return [
                item.tolist() if isinstance(item, np.ndarray) else item
                for item in variant
            ]
        return variant

    if isinstance(gold_variants, np.ndarray):
        gold_variants = gold_variants.tolist()

    if not isinstance(gold_variants, list):
        return []

    return [_plain(variant) for variant in gold_variants]

# ---------- main evaluation ----------
def evaluate_on_gold_df(df, model, vocab, out, max_token_len=4, use_tuned_thr=True, show_sample=5):
    """
    Score the model's boundary predictions against gold segmentations.

    Each row of `df` supplies a "Word" and a "Gold" entry (one or more gold
    segmentations, each a sequence of morphemes). Rows whose gold entry
    normalizes to an empty list are skipped. A word counts as an exact match
    when its predicted boundary set equals the boundary set of any gold
    variant; boundary TP/FP/FN are taken from the variant with the best
    per-word F1 and pooled into micro-averaged precision/recall/F1.

    Args:
        df: DataFrame iterated with `iterrows()`; needs "Word" and "Gold" columns.
        model, vocab: Passed through to `tokenize_with_vocab` / `segment_tokens`.
        out: Dict holding "hmm_prior" and, optionally, "best_thr".
        max_token_len: Longest vocabulary token tried by the tokenizer.
        use_tuned_thr: Use `out["best_thr"]` (0.5 if absent) instead of a fixed 0.5.
        show_sample: Maximum number of per-word examples to collect and print.

    Returns:
        Dict with "n_eval", "micro_precision", "micro_recall", "micro_f1",
        "exact_match_rate" and "examples".
    """
    def _ratio(num, den):
        return num / den if den > 0 else 0.0

    hmm_prior = out["hmm_prior"]
    if use_tuned_thr:
        thr = float(out.get("best_thr", 0.5))
    else:
        thr = 0.5

    tp_sum, fp_sum, fn_sum = 0, 0, 0
    n_exact = 0
    n_eval = 0
    examples = []

    for _, row in df.iterrows():
        word = str(row["Word"])
        # e.g., [['pi','kuna','s'], ['pi','ku','nas']] once converted to plain lists
        gold_variants = normalize_gold_variants(row["Gold"])

        # Rows without any gold segmentation cannot be scored.
        if not (isinstance(gold_variants, list) and len(gold_variants) > 0):
            continue

        # Tokenize, predict, and express the prediction as character offsets.
        toks = tokenize_with_vocab(word, vocab, max_token_len=max_token_len)
        seg_string, probs = segment_tokens(model, vocab, toks, hmm_prior=hmm_prior, thr=thr)
        pred_set = offsets_from_tokens_and_mask(toks, (probs >= thr).astype(int))

        gold_sets = [offsets_from_morphemes(gv) for gv in gold_variants]

        # Matching any one gold variant is enough for an exact hit.
        if pred_set in gold_sets:
            n_exact += 1

        # Score against the variant with the highest F1 (first one on ties).
        scored = [f1_from_sets(pred_set, gs) + (gs,) for gs in gold_sets]
        P, R, F1, tp, fp, fn, best_gs = max(scored, key=lambda item: item[2])

        tp_sum += tp
        fp_sum += fp
        fn_sum += fn
        n_eval += 1

        if len(examples) < show_sample:
            # Recover the morphemes of the best-scoring variant for display.
            best_morphs = next(
                (gv for gv in gold_variants if offsets_from_morphemes(gv) == best_gs),
                None,
            )
            gold_str = "-".join(best_morphs) if best_morphs else "(ambig)"
            examples.append({
                "word": word,
                "tokens": toks,
                "pred_seg": seg_string,
                "gold_best": gold_str,
                "P": round(P,3), "R": round(R,3), "F1": round(F1,3)
            })

    # Micro-averaged metrics over all evaluated words.
    micro_P = _ratio(tp_sum, tp_sum + fp_sum)
    micro_R = _ratio(tp_sum, tp_sum + fn_sum)
    micro_F1 = _ratio(2 * micro_P * micro_R, micro_P + micro_R)
    exact_rate = _ratio(n_exact, n_eval)

    print(f"Evaluated {n_eval} words")
    print(f"Boundary (micro)  P/R/F1 = {micro_P:.3f}/{micro_R:.3f}/{micro_F1:.3f}")
    print(f"Exact-match rate  = {exact_rate:.3f}")
    if examples:
        print("\nSample predictions:")
        for ex in examples:
            print(f"- {ex['word']}\n  tokens: {ex['tokens']}\n  pred  : {ex['pred_seg']}\n  gold  : {ex['gold_best']}\n  P/R/F1: {ex['P']}/{ex['R']}/{ex['F1']}\n")

    return {
        "n_eval": n_eval,
        "micro_precision": micro_P,
        "micro_recall": micro_R,
        "micro_f1": micro_F1,
        "exact_match_rate": exact_rate,
        "examples": examples
    }

In [None]:
import optuna

# ===================================================================
# Hyperparameter Tuning with Optuna
# ===================================================================

def objective(trial: optuna.Trial) -> float:
    """
    Run one Optuna trial: sample hyperparameters, train, and score the model.

    The model is trained on `str_df` with the sampled hyperparameters and a
    sampled `lambda_prior`, then evaluated on `acc_df` with
    `evaluate_on_gold_df`. The exact-match rate from that evaluation is
    returned as the value to maximize.

    NOTE(review): the inline comments describe `acc_df` as the final hold-out
    test set. Selecting hyperparameters on it makes later results reported on
    the same data optimistic — consider scoring trials on a separate
    validation split.

    Args:
        trial: The Optuna trial used to sample hyperparameter values.

    Returns:
        Exact-match rate of the trained model on `acc_df`.
    """
    # 1. Sample architecture and optimizer hyperparameters for this trial.
    hparams = dict(
        emb_dim=trial.suggest_categorical("emb_dim", [16, 32, 64]),
        hidden_size=trial.suggest_categorical("hidden_size", [32, 64, 128]),
        num_layers=trial.suggest_int("num_layers", 1, 3),
        dropout=trial.suggest_float("dropout", 0.1, 0.5, step=0.05),
        lr=trial.suggest_float("lr", 1e-4, 1e-2, log=True),
        weight_decay=trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True),
        freeze_emb=False,
    )
    # The weight of the prior term is tuned alongside the model hyperparameters.
    lambda_prior_val = trial.suggest_float("lambda_prior", 0.0, 0.5)

    # 2. Train on the training data (str_df) with the sampled configuration.
    print(f"\n--- Starting Trial {trial.number} with params: {hparams} | lambda_prior: {lambda_prior_val:.4f} ---")
    model, vocab, out = run_segmentation_with_privK(
        df=str_df,
        provided_suffix_list=suffix_list,
        use_suffix_list=True,
        unk_penalty=-15,
        epochs=15,  # fewer epochs during tuning for speed
        use_prior=True,
        lambda_prior=lambda_prior_val,
        lambda_k=0.2,
        hparams=hparams,
        synthetic_choice=SYNTHETIC_DATA_CHOICE,  # synthetic data choice from the config cell
    )

    # 3. Score the trained model on acc_df; no sample predictions are printed.
    test_set_results = evaluate_on_gold_df(
        acc_df,
        model,
        vocab,
        out,
        max_token_len=4,
        use_tuned_thr=True,
        show_sample=0,
    )
    test_exact_match = test_set_results["exact_match_rate"]

    print(f"--- Finished Trial {trial.number} | Test Set Exact Match: {test_exact_match:.4f} ---")

    # 4. Optuna maximizes this value.
    return test_exact_match


# # Create a study object and specify the optimization direction.
# study = optuna.create_study(direction="maximize")

# # Start the optimization. Optuna will run `n_trials` of the `objective` function.
# # You can increase n_trials for a more thorough search.
# study.optimize(objective, n_trials=50)

# # Print the results of the best trial.
# print("\n\n==========================================================")
# print("             Hyperparameter Tuning Finished             ")
# print("==========================================================")
# print("Number of finished trials: ", len(study.trials))
# print("Best trial:")
# best_trial = study.best_trial

# print(f"  Value (Test Set Exact Match Rate): {best_trial.value:.4f}")
# print("  Params: ")
# for key, value in best_trial.params.items():
#     print(f"    {key}: {value}")

In [None]:
# Fixed hyperparameter set passed as `hparams` to the training / k-fold calls below.
# NOTE(review): values look like the result of the Optuna search above — confirm.
best = dict(
    emb_dim=32,
    hidden_size=128,
    num_layers=3,
    dropout=0.4,
    lr=0.009213045798657327,
    weight_decay=0.0001132283214088801,
    freeze_emb=False,
)

In [None]:
# model, vocab, out = run_segmentation_with_privK(
#     df=str_df,
#     provided_suffix_list=suffix_list,
#     use_suffix_list=False,
#     unk_penalty=-15.0,
#     epochs=15,
#     use_prior=True,
#     lambda_prior=0.15289202508573396, # Weight for the HMM prior
#     lambda_k=0.2, 
#     hparams=best,
#     synthetic_choice=SYNTHETIC_DATA_CHOICE  # Pass the synthetic data choice
# )

In [None]:
# ===================================================================
# K-FOLD CROSS-VALIDATION DEMONSTRATION
# ===================================================================
# This cell demonstrates how to use k-fold cross-validation to evaluate
# model performance more robustly by training on multiple train/val splits

# Run 5-fold cross-validation on the training data
kfold_results = run_kfold_cross_validation(
    df=str_df,
    provided_suffix_list=suffix_list,
    n_folds=5,  # Number of folds
    use_suffix_list=False,
    unk_penalty=-15.0,
    epochs=15,
    use_prior=True,
    lambda_prior=0.15289202508573396,
    lambda_k=0.2,
    hparams=best,
    synthetic_choice=SYNTHETIC_DATA_CHOICE,
    random_state=RNG  # Use the same random seed for reproducibility
)

# The results dictionary contains:
# - fold_results: List of results for each fold
# - mean_metrics: Average metrics across all folds
# - std_metrics: Standard deviation of metrics across folds
# - best_fold_idx: Index of the fold with best exact match rate
# - all_metrics: Raw metrics from all folds

# Report mean ± standard deviation across folds (the separator was previously
# printed as the mis-encoded sequence "¬±").
print("\nK-fold cross-validation completed!")
print(f"Average exact match rate: {kfold_results['mean_metrics']['exact_match']:.3f} ± {kfold_results['std_metrics']['exact_match']:.3f}")
print(f"Average boundary F1: {kfold_results['mean_metrics']['boundary_f1']:.3f} ± {kfold_results['std_metrics']['boundary_f1']:.3f}")


In [None]:
# Quick qualitative check: segment a single word with the trained model.
# NOTE(review): `model`, `vocab` and `out` are not assigned in the cells shown
# above (the training call is commented out) — confirm they exist before running.
word = "pikunas"
tokens = tokenize_with_vocab(word, vocab, max_token_len=4)
thr = out.get("best_thr", 0.5)  # falls back to 0.5 when no tuned threshold is stored

# segment_tokens takes the tokens and the HMM prior stored under out["hmm_prior"]
seg_string, boundary_probs = segment_tokens(
    model, vocab, tokens, hmm_prior=out["hmm_prior"], thr=thr
)

print("Tokens:", tokens)
print("Boundary probs:", np.round(boundary_probs, 3).tolist())
print(f"Segmentation (thr={thr:.3f}):", seg_string)

In [None]:
import numpy as np
from typing import List, Set, Tuple

# ---------- helpers to turn segs into boundary sets (char offsets) ----------
def offsets_from_morphemes(morphs: List[str]) -> Set[int]:
    """
    Return the character offsets of the boundaries between morphemes.

    A boundary is recorded after every morpheme except the last one, measured
    as the cumulative length of the morphemes up to and including it.
    """
    ends = []
    running = 0
    for morph in morphs:
        running += len(morph)
        ends.append(running)
    # The final cumulative length is the end of the word, not a boundary.
    return set(ends[:-1])

def offsets_from_tokens_and_mask(tokens: List[str], mask01: np.ndarray) -> Set[int]:
    """
    Return the character offsets of predicted boundaries.

    A boundary is recorded after token i when `mask01[i] == 1`; the offset is
    the cumulative character length of tokens 0..i. The last token is never
    followed by a boundary, so `mask01` is not read at that index.
    """
    final_idx = len(tokens) - 1
    boundaries = set()
    end = 0
    for idx, tok in enumerate(tokens):
        end += len(tok)
        if idx == final_idx:
            break  # end of word: no boundary, and mask01 may be shorter
        if mask01[idx] == 1:
            boundaries.add(end)
    return boundaries

def f1_from_sets(pred: Set[int], gold: Set[int]) -> Tuple[float, float, float, int, int, int]:
    """
    Compare predicted and gold boundary sets.

    Returns:
        (precision, recall, f1, tp, fp, fn). Any ratio whose denominator is
        zero is reported as 0.0.
    """
    def _ratio(num, den):
        return num / den if den > 0 else 0.0

    tp = len(pred & gold)
    fp = len(pred) - tp  # predicted but not gold
    fn = len(gold) - tp  # gold but not predicted

    P = _ratio(tp, tp + fp)
    R = _ratio(tp, tp + fn)
    F1 = _ratio(2 * P * R, P + R)
    return P, R, F1, tp, fp, fn

def normalize_gold_variants(gold_variants):
    """
    Return `gold_variants` as a plain list, converting NumPy arrays to lists.

    The outer container and each variant are converted with `tolist()` when
    they are NumPy arrays; array items inside list variants are converted as
    well. Any input that is neither a list nor an array that converts to a
    list (including None) yields an empty list.
    """
    def _plain(variant):
        if isinstance(variant, np.ndarray):
            return variant.tolist()
        if isinstance(variant, list):
            return [
                item.tolist() if isinstance(item, np.ndarray) else item
                for item in variant
            ]
        return variant

    if isinstance(gold_variants, np.ndarray):
        gold_variants = gold_variants.tolist()

    if not isinstance(gold_variants, list):
        return []

    return [_plain(variant) for variant in gold_variants]

# ---------- main evaluation ----------
def evaluate_on_gold_df(df, model, vocab, out, max_token_len=4, use_tuned_thr=True, show_sample=5):
    """
    Score the model's boundary predictions against gold segmentations.

    Each row of `df` supplies a "Word" and a "Gold" entry (one or more gold
    segmentations, each a sequence of morphemes). Rows whose gold entry
    normalizes to an empty list are skipped. A word counts as an exact match
    when its predicted boundary set equals the boundary set of any gold
    variant; boundary TP/FP/FN are taken from the variant with the best
    per-word F1 and pooled into micro-averaged precision/recall/F1.

    Args:
        df: DataFrame iterated with `iterrows()`; needs "Word" and "Gold" columns.
        model, vocab: Passed through to `tokenize_with_vocab` / `segment_tokens`.
        out: Dict holding "hmm_prior" and, optionally, "best_thr".
        max_token_len: Longest vocabulary token tried by the tokenizer.
        use_tuned_thr: Use `out["best_thr"]` (0.5 if absent) instead of a fixed 0.5.
        show_sample: Maximum number of per-word examples to collect and print.

    Returns:
        Dict with "n_eval", "micro_precision", "micro_recall", "micro_f1",
        "exact_match_rate" and "examples".
    """
    def _ratio(num, den):
        return num / den if den > 0 else 0.0

    hmm_prior = out["hmm_prior"]
    if use_tuned_thr:
        thr = float(out.get("best_thr", 0.5))
    else:
        thr = 0.5

    tp_sum, fp_sum, fn_sum = 0, 0, 0
    n_exact = 0
    n_eval = 0
    examples = []

    for _, row in df.iterrows():
        word = str(row["Word"])
        # e.g., [['pi','kuna','s'], ['pi','ku','nas']] once converted to plain lists
        gold_variants = normalize_gold_variants(row["Gold"])

        # Rows without any gold segmentation cannot be scored.
        if not (isinstance(gold_variants, list) and len(gold_variants) > 0):
            continue

        # Tokenize, predict, and express the prediction as character offsets.
        toks = tokenize_with_vocab(word, vocab, max_token_len=max_token_len)
        seg_string, probs = segment_tokens(model, vocab, toks, hmm_prior=hmm_prior, thr=thr)
        pred_set = offsets_from_tokens_and_mask(toks, (probs >= thr).astype(int))

        gold_sets = [offsets_from_morphemes(gv) for gv in gold_variants]

        # Matching any one gold variant is enough for an exact hit.
        if pred_set in gold_sets:
            n_exact += 1

        # Score against the variant with the highest F1 (first one on ties).
        scored = [f1_from_sets(pred_set, gs) + (gs,) for gs in gold_sets]
        P, R, F1, tp, fp, fn, best_gs = max(scored, key=lambda item: item[2])

        tp_sum += tp
        fp_sum += fp
        fn_sum += fn
        n_eval += 1

        if len(examples) < show_sample:
            # Recover the morphemes of the best-scoring variant for display.
            best_morphs = next(
                (gv for gv in gold_variants if offsets_from_morphemes(gv) == best_gs),
                None,
            )
            gold_str = "-".join(best_morphs) if best_morphs else "(ambig)"
            examples.append({
                "word": word,
                "tokens": toks,
                "pred_seg": seg_string,
                "gold_best": gold_str,
                "P": round(P,3), "R": round(R,3), "F1": round(F1,3)
            })

    # Micro-averaged metrics over all evaluated words.
    micro_P = _ratio(tp_sum, tp_sum + fp_sum)
    micro_R = _ratio(tp_sum, tp_sum + fn_sum)
    micro_F1 = _ratio(2 * micro_P * micro_R, micro_P + micro_R)
    exact_rate = _ratio(n_exact, n_eval)

    print(f"Evaluated {n_eval} words")
    print(f"Boundary (micro)  P/R/F1 = {micro_P:.3f}/{micro_R:.3f}/{micro_F1:.3f}")
    print(f"Exact-match rate  = {exact_rate:.3f}")
    if examples:
        print("\nSample predictions:")
        for ex in examples:
            print(f"- {ex['word']}\n  tokens: {ex['tokens']}\n  pred  : {ex['pred_seg']}\n  gold  : {ex['gold_best']}\n  P/R/F1: {ex['P']}/{ex['R']}/{ex['F1']}\n")

    return {
        "n_eval": n_eval,
        "micro_precision": micro_P,
        "micro_recall": micro_R,
        "micro_f1": micro_F1,
        "exact_match_rate": exact_rate,
        "examples": examples
    }

# ===================================================================
# NEW CODE: Suffix Validator Function
# ===================================================================

def is_segmentation_valid(
    segmentation: list[str],
    allowed_suffixes: set[str]
) -> bool:
    """
    Tell whether every non-root morpheme of a segmentation is an allowed suffix.

    The first morpheme is treated as the root and is never checked, so a
    segmentation with zero or one morpheme is always valid.

    Args:
        segmentation (list[str]): The predicted segmentation, e.g., ['pay', 'kunaq'].
        allowed_suffixes (set[str]): A set of valid suffix strings.

    Returns:
        bool: True if all morphemes after the first are in `allowed_suffixes`,
        False as soon as one is not.
    """
    # An empty tail (unsplit word) makes `all` return True.
    return all(morpheme in allowed_suffixes for morpheme in segmentation[1:])

# ===================================================================
# REVISED CODE: Evaluation function that ignores rejected predictions
# ===================================================================

def evaluate_and_ignore_rejected(
    df, model, vocab, out,
    allowed_suffixes: list[str], # Required argument for the validator
    max_token_len=4,
    use_tuned_thr=True,
    show_sample=5
):
    """
    Evaluate boundary predictions, scoring only those that pass the suffix filter.

    For every row with gold data the word is segmented; the prediction is
    rejected when `is_segmentation_valid` finds a non-root morpheme outside
    `allowed_suffixes`. Rejected predictions are excluded from the boundary
    and exact-match scores. The function also reports how many correct
    predictions (matching any gold variant) were kept or rejected.

    Args:
        df: DataFrame iterated with `iterrows()`; needs "Word" and "Gold" columns.
        model, vocab: Passed through to `tokenize_with_vocab` / `segment_tokens`.
        out: Dict holding "hmm_prior" and, optionally, "best_thr".
        allowed_suffixes: Suffix strings accepted by the validator.
        max_token_len: Longest vocabulary token tried by the tokenizer.
        use_tuned_thr: Use `out["best_thr"]` (0.5 if absent) instead of a fixed 0.5.
        show_sample: Maximum number of kept examples to collect and print.

    Returns:
        Dict with "micro_f1", "exact_match_rate", "rejection_count",
        "false_rejection_count", "filter_precision" and "false_rejection_rate".
        All rates are 0.0 when their denominator is zero, including the case
        where no row has gold data.
    """
    hmm_prior = out["hmm_prior"]
    thr = float(out.get("best_thr", 0.5)) if use_tuned_thr else 0.5
    allowed_suffixes_set = set(allowed_suffixes)

    total_tp = total_fp = total_fn = 0
    exact_hits = 0

    n_total_words = 0      # Counts all words we attempt to evaluate
    n_evaluated_words = 0  # Counts only words with valid, scored predictions
    rejection_count = 0
    false_rejection_count = 0  # Count of CORRECT segmentations that were rejected
    correct_kept_count = 0     # Count of CORRECT segmentations that were kept
    examples = []

    for _, row in df.iterrows():
        word = str(row["Word"])
        gold_variants = row["Gold"]

        # Normalize gold_variants (convert numpy arrays to lists)
        gold_variants = normalize_gold_variants(gold_variants)

        if not isinstance(gold_variants, list) or len(gold_variants) == 0:
            continue

        n_total_words += 1

        # 1. Get the model's prediction
        toks = tokenize_with_vocab(word, vocab, max_token_len=max_token_len)
        seg_string, probs = segment_tokens(model, vocab, toks, hmm_prior=hmm_prior, thr=thr)
        # NOTE: splitting on '-' assumes the word itself contains no literal hyphen.
        predicted_morphs = seg_string.split('-')

        # 2. Check if prediction is correct BEFORE validating (for false rejection analysis)
        mask01 = (probs >= thr).astype(int)
        pred_set = offsets_from_tokens_and_mask(toks, mask01)
        gold_sets = [offsets_from_morphemes(gv) for gv in gold_variants]
        is_correct = any(pred_set == gs for gs in gold_sets)

        # 3. Validate the prediction. If invalid, check if it was correct (false rejection)
        if not is_segmentation_valid(predicted_morphs, allowed_suffixes_set):
            rejection_count += 1
            if is_correct:
                false_rejection_count += 1
            continue  # Rejected predictions are not scored

        # --- If we reach this point, the prediction is valid and will be scored ---
        n_evaluated_words += 1

        # 4. Track if kept prediction is correct
        if is_correct:
            correct_kept_count += 1
            exact_hits += 1

        # Score against the gold variant with the highest F1 (first one on ties).
        best = max((f1_from_sets(pred_set, gs) + (gs,) for gs in gold_sets), key=lambda z: z[2])
        P, R, F1, tp, fp, fn, best_gs = best

        total_tp += tp
        total_fp += fp
        total_fn += fn

        if len(examples) < show_sample:
            # reconstruct a nice gold string for the best variant
            best_morphs = None
            for gv in gold_variants:
                if offsets_from_morphemes(gv) == best_gs:
                    best_morphs = gv; break
            gold_str = "-".join(best_morphs) if best_morphs else "(ambig)"
            examples.append({
                "word": word,
                "tokens": toks,
                "pred_seg": seg_string,
                "gold_best": gold_str,
                "P": round(P,3), "R": round(R,3), "F1": round(F1,3)
            })

    # --- Final Metrics ---
    # Note: the exact-match denominator is n_evaluated_words (kept predictions only)
    micro_P = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0.0
    micro_R = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0.0
    micro_F1 = 2 * micro_P * micro_R / (micro_P + micro_R) if (micro_P + micro_R) > 0 else 0.0
    exact_rate = exact_hits / n_evaluated_words if n_evaluated_words > 0 else 0.0

    # --- Rejection Filter Analysis ---
    # Filter precision: of the kept predictions, how many are correct?
    filter_precision = correct_kept_count / n_evaluated_words if n_evaluated_words > 0 else 0.0

    # False rejection rate: of all correct predictions, how many were rejected?
    total_correct = correct_kept_count + false_rejection_count
    false_rejection_rate = false_rejection_count / total_correct if total_correct > 0 else 0.0

    # Guarded like every other ratio: an empty / gold-less df must not raise.
    rejection_rate = rejection_count / n_total_words if n_total_words > 0 else 0.0

    print(f"Attempted to evaluate {n_total_words} words")
    print(f"Predictions Rejected by Suffix Validator: {rejection_count} ({rejection_rate:.2%})")
    print(f"Final scores are based on the remaining {n_evaluated_words} valid predictions.")
    print("\n--- Rejection Filter False Rejection Analysis ---")
    print(f"Filter achieves {filter_precision:.1%} precision but rejects {false_rejection_rate:.1%} valid segmentations")
    print(f"  - Correct predictions kept: {correct_kept_count}")
    print(f"  - Correct predictions rejected (false rejections): {false_rejection_count}")
    print(f"  - Total correct predictions: {total_correct}")
    print("\n--- Final Scores (on non-rejected predictions only) ---")
    print(f"Boundary (micro)  P/R/F1 = {micro_P:.3f}/{micro_R:.3f}/{micro_F1:.3f}")
    print(f"Exact-match rate  = {exact_rate:.3f}")

    if examples:
        print("\nSample predictions:")
        for ex in examples:
            print(f"- {ex['word']}\n  tokens: {ex['tokens']}\n  pred  : {ex['pred_seg']}\n  gold  : {ex['gold_best']}\n  P/R/F1: {ex['P']}/{ex['R']}/{ex['F1']}\n")
    return {
        "micro_f1": micro_F1,
        "exact_match_rate": exact_rate,
        "rejection_count": rejection_count,
        "false_rejection_count": false_rejection_count,
        "filter_precision": filter_precision,
        "false_rejection_rate": false_rejection_rate
    }

In [None]:
# # Remove words with length > 14
# acc_df = acc_df[acc_df['Word'].str.len() <= 14].reset_index(drop=True)

# # Remove rows where all gold variants have only one morpheme
# acc_df = acc_df[acc_df['Gold'].apply(lambda variants: any(len(variant) > 1 for variant in variants))].reset_index(drop=True)

In [None]:
# Preview the first rows of the evaluation DataFrame (notebook cell output).
acc_df.head()

In [None]:
# Evaluate the trained model on acc_df without the suffix rejection filter.
results = evaluate_on_gold_df(
    acc_df,                     # DataFrame with "Word" and "Gold" (list of gold variants) columns
    model, vocab, out,      # trained model, its vocabulary, and the training artifacts dict
    max_token_len=4,        # same value the other tokenize_with_vocab calls in this file use
    use_tuned_thr=True,     # use out["best_thr"] when present, else 0.5
    show_sample=8           # print up to 8 qualitative examples
)

In [None]:
# Evaluate again, this time dropping predictions whose suffixes are not in suffix_list.
print("\n--- Evaluating with Post-Processing Rejection Filter ---")
results_with_rejection = evaluate_and_ignore_rejected(
    acc_df,              # The test dataframe
    model, vocab, out,   # The trained model and its artifacts
    allowed_suffixes=suffix_list, # Suffixes accepted by is_segmentation_valid
    show_sample=8
)

In [None]:
# # ===================================================================
# # SYNTHETIC DATA AUGMENTATION COMPARISON TABLE
# # ===================================================================
# # This cell creates a comprehensive comparison table showing the impact
# # of different augmentation strategies on model performance

# import pandas as pd
# import numpy as np
# from IPython.display import display, HTML

# # Show the best hyperparameters being used
# print("="*80)
# print("BEST HYPERPARAMETERS")
# print("="*80)
# print(f"emb_dim: {best['emb_dim']}")
# print(f"hidden_size: {best['hidden_size']}")
# print(f"num_layers: {best['num_layers']}")
# print(f"dropout: {best['dropout']}")
# print(f"lr: {best['lr']}")
# print(f"weight_decay: {best['weight_decay']}")
# print(f"freeze_emb: {best['freeze_emb']}")
# print(f"lambda_prior: 0.15289202508573396")
# print(f"lambda_k: 0.2")
# print("="*80)
# print("\n")

# # Function to run evaluation for a specific augmentation configuration
# def evaluate_augmentation_config(
#     synthetic_choice,
#     word_selection,
#     n_words,
#     str_df_base,
#     acc_df,
#     suffix_list,
#     best_hparams,
#     lambda_prior=0.15289202508573396,
#     lambda_k=0.2
# ):
#     """
#     Evaluate model with a specific augmentation configuration.
    
#     Args:
#         synthetic_choice: "none", "gpt4o", or "gpt5mini"
#         word_selection: "first" or "random"
#         n_words: number of words (None for "none")
#         str_df_base: base training dataframe (gold data)
#         acc_df: test dataframe
#         suffix_list: list of allowed suffixes
#         best_hparams: hyperparameters dictionary
#         lambda_prior: weight for prior distillation
#         lambda_k: weight for K-regularizer
    
#     Returns:
#         Dictionary with results
#     """
#     print(f"\n{'='*80}")
#     print(f"Evaluating: {synthetic_choice} | {word_selection} | {n_words}")
#     print(f"{'='*80}")
    
#     # Prepare training data based on configuration
#     if synthetic_choice == "none":
#         # For "none", use the already processed str_df_base
#         train_str_df = str_df_base.copy()
#     else:
#         # Load synthetic data
#         synthetic_df = load_synthetic_data(synthetic_choice)
#         if synthetic_df is None:
#             print(f"Warning: Could not load {synthetic_choice} data")
#             return None
        
#         # Get common words
#         gpt_5_mini_words = set(gpt_5_mini_df['Word'])
#         gpt_4o_words = set(gpt_4o_df['Word'])
#         common_words = gpt_4o_words.intersection(gpt_5_mini_words)
        
#         # Select words based on strategy
#         if word_selection == "first":
#             sorted_words = sorted(common_words)
#             n = min(n_words, len(sorted_words))
#             selected_words = set(sorted_words[:n])
#         elif word_selection == "random":
#             import random
#             # Use same seed for reproducibility (42 for all random selections)
#             random.seed(42)
#             n = min(n_words, len(common_words))
#             selected_words = set(random.sample(list(common_words), n))
#         else:
#             selected_words = common_words
        
#         # Filter synthetic data
#         df_sampled = synthetic_df[synthetic_df['Word'].isin(selected_words)]
        
#         # Combine with gold data (need to use the base gold_df format for processing)
#         # Load base gold data
#         gold_df_temp = pd.read_parquet(os.path.join(DATA_FOLDER, "Sue_kalt.parquet"))
#         gold_df_temp['Word'] = gold_df_temp['word']
#         gold_df_temp['morph'] = gold_df_temp['morph'].str.replace('-', ' ')
#         gold_df_temp['Morph_split_str'] = gold_df_temp['morph']
#         gold_df_temp['Morph_split'] = gold_df_temp['morph'].str.split(' ')
#         gold_df_temp = gold_df_temp[['Word', 'Morph_split', 'Morph_split_str']]
#         gold_df_temp.drop_duplicates(subset='Word', keep='first', inplace=True)
#         gold_df_temp.dropna(subset=['Word'], inplace=True)
        
#         train_df = pd.concat([df_sampled, gold_df_temp], ignore_index=True)
        
#         # Process training data (same as in main notebook)
#         train_df["Char_split"] = train_df["Morph_split"].apply(tokenize_morphemes)
#         train_df["CV_split"] = train_df["Char_split"].apply(morphs_to_cv)
        
#         # Create str_df format
#         train_str_df = pd.DataFrame()
#         train_str_df["Full_chain"] = train_df["CV_split"].apply(cv_to_string)
#         train_str_df["Trimmed_chain"] = train_str_df["Full_chain"].apply(
#             lambda x: x.split("-", 1)[1] if "-" in x else np.nan
#         )
#         train_str_df["Word"] = train_df["Word"]
#         train_str_df["Char_split"] = train_df["Char_split"]
#         train_str_df["Morph_split"] = train_df["Morph_split"]
#         train_str_df = train_str_df.dropna(subset=["Trimmed_chain"]).reset_index(drop=True)
        
#         # Add features
#         train_str_df["Word_len"] = train_str_df["Word"].str.len()
#         train_str_df["Vowel_no"] = train_str_df["Full_chain"].str.count("V")
#         train_str_df["Cons_no"] = train_str_df["Full_chain"].str.count("C")
#         train_str_df["Tail_cons_no"] = train_str_df["Trimmed_chain"].str.count("C")
#         train_str_df["Tail_vowel_no"] = train_str_df["Trimmed_chain"].str.count("V")
#         train_str_df["No_splits"] = train_str_df["Morph_split"].str.len()
#         train_str_df["YW_count"] = train_str_df["Word"].str.count("[yw]")
#         train_str_df["Tail_YW_count"] = train_str_df["Morph_split"].apply(
#             lambda ms: sum(m.count("y") + m.count("w") for m in ms[1:])
#         )
    
#     # Train or load model
#     model, vocab, out = run_segmentation_with_privK(
#         df=train_str_df,
#         provided_suffix_list=suffix_list,
#         use_suffix_list=False,
#         unk_penalty=-15.0,
#         epochs=15,
#         use_prior=True,
#         lambda_prior=lambda_prior,
#         lambda_k=lambda_k,
#         hparams=best_hparams,
#         synthetic_choice=synthetic_choice
#     )
    
#     # Evaluate without rejection filter
#     results_base = evaluate_on_gold_df(
#         acc_df, model, vocab, out,
#         max_token_len=4,
#         use_tuned_thr=True,
#         show_sample=0
#     )
    
#     # Evaluate with rejection filter
#     results_filtered = evaluate_and_ignore_rejected(
#         acc_df, model, vocab, out,
#         allowed_suffixes=suffix_list,
#         max_token_len=4,
#         use_tuned_thr=True,
#         show_sample=0
#     )
    
#     return {
#         "base_exact_match": results_base["exact_match_rate"],
#         "base_f1": results_base["micro_f1"],
#         "filtered_exact_match": results_filtered["exact_match_rate"],
#         "filtered_f1": results_filtered["micro_f1"],
#         "filter_precision": results_filtered["filter_precision"],
#         "false_rejection_rate": results_filtered["false_rejection_rate"],
#         "rejection_count": results_filtered["rejection_count"],
#         "false_rejection_count": results_filtered["false_rejection_count"]
#     }

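# # Load base gold data for training: hyphen-separated morphs become space-separated,
# # only the first row per Word is kept, and rows with a missing Word are dropped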
# print("Loading base gold data...")
# gold_df_base = pd.read_parquet(os.path.join(DATA_FOLDER, "Sue_kalt.parquet"))
# gold_df_base['Word'] = gold_df_base['word']
# gold_df_base['morph'] = gold_df_base['morph'].str.replace('-', ' ')
# gold_df_base['Morph_split_str'] = gold_df_base['morph']
# gold_df_base['Morph_split'] = gold_df_base['morph'].str.split(' ')
# gold_df_base = gold_df_base[['Word', 'Morph_split', 'Morph_split_str']]
# gold_df_base.drop_duplicates(subset='Word', keep='first', inplace=True)
# gold_df_base.dropna(subset=['Word'], inplace=True)

# # Process base data
# gold_df_base["Char_split"] = gold_df_base["Morph_split"].apply(tokenize_morphemes)
# gold_df_base["CV_split"] = gold_df_base["Char_split"].apply(morphs_to_cv)

# str_df_base = pd.DataFrame()
# str_df_base["Full_chain"] = gold_df_base["CV_split"].apply(cv_to_string)
# str_df_base["Trimmed_chain"] = str_df_base["Full_chain"].apply(
#     lambda x: x.split("-", 1)[1] if "-" in x else np.nan
# )
# str_df_base["Word"] = gold_df_base["Word"]
# str_df_base["Char_split"] = gold_df_base["Char_split"]
# str_df_base["Morph_split"] = gold_df_base["Morph_split"]
# str_df_base = str_df_base.dropna(subset=["Trimmed_chain"]).reset_index(drop=True)

# str_df_base["Word_len"] = str_df_base["Word"].str.len()
# str_df_base["Vowel_no"] = str_df_base["Full_chain"].str.count("V")
# str_df_base["Cons_no"] = str_df_base["Full_chain"].str.count("C")
# str_df_base["Tail_cons_no"] = str_df_base["Trimmed_chain"].str.count("C")
# str_df_base["Tail_vowel_no"] = str_df_base["Trimmed_chain"].str.count("V")
# str_df_base["No_splits"] = str_df_base["Morph_split"].str.len()
# str_df_base["YW_count"] = str_df_base["Word"].str.count("[yw]")
# str_df_base["Tail_YW_count"] = str_df_base["Morph_split"].apply(
#     lambda ms: sum(m.count("y") + m.count("w") for m in ms[1:])
# )

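# # Define all configurations to test as (synthetic_choice, word_selection, n_words) tuples;
# # the "none" baseline carries None for the two word-selection fields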
# configs = [
#     ("none", None, None),
#     ("gpt4o", "first", 100),
#     ("gpt5mini", "first", 100),
#     ("gpt4o", "first", 200),
#     ("gpt5mini", "first", 200),
#     ("gpt4o", "first", 300),
#     ("gpt5mini", "first", 300),
#     ("gpt4o", "random", 100),
#     ("gpt5mini", "random", 100),
#     ("gpt4o", "random", 200),
#     ("gpt5mini", "random", 200),
#     ("gpt4o", "random", 300),
#     ("gpt5mini", "random", 300),
# ]

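# # Run evaluations for all configurations; a configuration that raises is reported with its
# # traceback and skipped, and a None result is left out of the table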
# results_list = []
# for synthetic_choice, word_selection, n_words in configs:
#     try:
#         result = evaluate_augmentation_config(
#             synthetic_choice=synthetic_choice,
#             word_selection=word_selection,
#             n_words=n_words,
#             str_df_base=str_df_base,
#             acc_df=acc_df,
#             suffix_list=suffix_list,
#             best_hparams=best,
#             lambda_prior=0.15289202508573396,
#             lambda_k=0.2
#         )
        
#         if result is not None:
#             # Create config name
#             if synthetic_choice == "none":
#                 config_name = "No augmentation"
#             else:
#                 config_name = f"{synthetic_choice} {word_selection} {n_words}"
            
#             results_list.append({
#                 "Configuration": config_name,
#                 "Base Exact Match": f"{result['base_exact_match']:.3f}",
#                 "Base F1": f"{result['base_f1']:.3f}",
#                 "Filtered Exact Match": f"{result['filtered_exact_match']:.3f}",
#                 "Filtered F1": f"{result['filtered_f1']:.3f}",
#                 "Filter Precision": f"{result['filter_precision']:.1%}",
#                 "False Rejection Rate": f"{result['false_rejection_rate']:.1%}",
#                 "Rejections": result['rejection_count'],
#                 "False Rejections": result['false_rejection_count']
#             })
#     except Exception as e:
#         print(f"Error evaluating {synthetic_choice} {word_selection} {n_words}: {e}")
#         import traceback
#         traceback.print_exc()

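# # Create and display the table (the rate/score columns were already formatted as strings
# # above, so the saved CSV holds text such as "0.123" / "12.3%" rather than raw floats)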
# if results_list:
#     results_df = pd.DataFrame(results_list)
#     print("\n" + "="*80)
#     print("SYNTHETIC DATA AUGMENTATION COMPARISON TABLE")
#     print("="*80)
#     display(results_df)
    
#     # Also save to CSV
#     output_file = os.path.join(DATA_FOLDER, "augmentation_comparison_table.csv")
#     results_df.to_csv(output_file, index=False)
#     print(f"\nTable saved to: {output_file}")
# else:
#     print("No results to display.")
