# 🎙️ Hindi Disfluency Restoration Pipeline

This notebook restores disfluencies (filler words like "हम्म", "हां", "उम्म") to clean Hindi transcripts using:

1. **Whisper ASR** - Transcribes audio to detect spoken disfluencies
2. **Sequence Alignment** - Compares clean text with ASR output to find insertions
3. **N-gram Language Model** - Validates that insertions sound natural
4. **Position Prior** - Disfluencies usually occur at the start of utterances

---

## 1️⃣ Setup & Configuration

In [None]:
# =============================================================================
# CONFIGURATION
# =============================================================================
# These paths point to the competition data on Kaggle

# --- Competition inputs (read-only Kaggle dataset) ---------------------------
# Directory holding one .wav file per sample id.
AUDIO_DIR = "/kaggle/input/nppe-2-automatic-disfluency-restoration/downloaded_audios"
# Test set whose transcripts get disfluencies restored.
TEST_CSV = "/kaggle/input/nppe-2-automatic-disfluency-restoration/test.csv"
# Training transcripts used to build the n-gram language model.
TRAIN_CSV = "/kaggle/input/nppe-2-automatic-disfluency-restoration/train.csv"
# List of known disfluency tokens.
DISF_CSV_PATH = "/kaggle/input/nppe-2-automatic-disfluency-restoration/unique_disfluencies.csv"

# --- Outputs -------------------------------------------------------------------
# Writable directory for the submission file and the ASR cache.
OUTPUT_DIR = "/kaggle/working/"

# --- Logging -------------------------------------------------------------------
# When True, log() prints its messages; when False it stays silent.
VERBOSE = True

def log(msg):
    """Print ``msg`` with a ``[LOG]`` prefix, but only when VERBOSE is truthy."""
    if not VERBOSE:
        return
    print(f"[LOG] {msg}")

# Status banner for the configuration cell (check-mark emoji restored from
# mis-decoded UTF-8).
print("✅ Configuration loaded")
print(f"   Audio directory: {AUDIO_DIR}")
print(f"   Output directory: {OUTPUT_DIR}")

In [None]:
# =============================================================================
# IMPORTS
# =============================================================================
# Standard library
import os              # File path operations
import re              # Regular expressions for text cleaning
import gc              # Garbage collection to free memory
import pickle          # Save/load cache to disk
import unicodedata     # Normalize Hindi text (NFKC form)
from difflib import SequenceMatcher  # Find differences between two sequences
from collections import Counter      # Count n-gram frequencies

# Data science
import pandas as pd    # DataFrames for CSV handling
import numpy as np     # Numerical operations

# Deep learning
import torch           # PyTorch for GPU acceleration
from transformers import WhisperProcessor, WhisperForConditionalGeneration  # Whisper ASR model
import librosa         # Audio loading and processing

# Evaluation
from jiwer import wer  # Word Error Rate metric

# Status banner for the imports cell; also reports the PyTorch build and
# whether CUDA is visible (check-mark emoji restored from mis-decoded UTF-8).
print("✅ All imports successful")
print(f"   PyTorch version: {torch.__version__}")
print(f"   CUDA available: {torch.cuda.is_available()}")

## 2️⃣ Disfluency Set & Thresholds

Disfluencies are filler words people say when thinking ("umm", "uh", etc.).
In Hindi, common ones include:
- **हम्म** (hmm)
- **हां/हाँ** (yes, often used as filler)
- **उम्म** (umm)
- **तो/वो** (so/that - can be filler or real word)

Each disfluency has its own **confidence threshold** - words that could be real words (like "हां" = yes) need higher ASR confidence to be inserted.

In [None]:
# =============================================================================
# DISFLUENCY SET
# =============================================================================

def norm(x):
    """
    Return a canonical form of ``x`` for equality comparison.

    The value is converted to ``str``, stripped of surrounding whitespace,
    lowercased, and finally NFKC-normalized so that different Unicode
    encodings of the same text compare equal.
    """
    text = str(x).strip().lower()
    return unicodedata.normalize('NFKC', text)

# Load disfluencies from CSV file
# Build DISFLUENCY_SET from the 'disfluency' column of the CSV. Loading is
# best-effort: on any failure we warn and fall back to an empty set.
try:
    disf_df = pd.read_csv(DISF_CSV_PATH)
    # dropna() must come before astype(str): astype(str) turns missing cells
    # into the literal string "nan", which would otherwise be registered as a
    # disfluency.
    DISFLUENCY_SET = {
        norm(x)
        for x in disf_df['disfluency'].dropna().astype(str)
        if x.strip()
    }
    log(f"Loaded {len(DISFLUENCY_SET)} disfluencies from CSV")
except Exception as e:
    print(f"⚠️ Could not load disfluency CSV: {e}")
    DISFLUENCY_SET = set()

# Add common Hindi fillers that might not be in the CSV
# Hand-picked Hindi fillers that are merged into DISFLUENCY_SET.
# The literals were stored as mojibake (UTF-8 decoded as Mac Roman) and could
# never match real Devanagari tokens; they are restored to proper Devanagari.
COMMON_FILLERS = {
    'अं', 'उं', 'ऊं', 'आं', 'एं', 'ओं',  # Short vowel sounds
    'हम्म', 'हां', 'हाँ',                 # Hmm, yes
    'उम्म', 'अम्म',                       # Umm
    'ह', 'अ', 'ए',                        # Single-letter fillers
    'तो', 'वो', 'जो',                     # Conjunctions often used as fillers
    'मतलब', 'बस', 'अच्छा',                # "I mean", "just", "okay"
}
# Normalize with norm() so the entries compare equal to normalized tokens.
DISFLUENCY_SET |= {norm(x) for x in COMMON_FILLERS}

# Build the regex used to strip disfluencies from text.
#
# Alternatives are ordered longest-first so a longer filler wins over a
# shorter one that is its prefix.
#
# NOTE: `\b` is not usable for Devanagari in Python's `re`: vowel signs,
# anusvara/chandrabindu and virama are combining marks, which `\w` does not
# match, so `\b` reports a "boundary" in the middle of a word and a
# single-letter filler would match inside longer words. Explicit lookarounds
# are used instead, treating `\w` plus the Devanagari block (minus the danda
# and double danda punctuation, U+0964/U+0965) as word characters.
_DISF_WORD_CHAR = r'[\w\u0900-\u0963\u0966-\u097F]'
_disf_alternatives = '|'.join(
    re.escape(x) for x in sorted(DISFLUENCY_SET, key=lambda s: (-len(s), s))
)
if _disf_alternatives:
    pattern = (
        r'(?<!' + _DISF_WORD_CHAR + r')(?:' + _disf_alternatives + r')(?!' + _DISF_WORD_CHAR + r')'
    )
else:
    pattern = r'(?!)'  # Empty set: a pattern that never matches
RE_DISF = re.compile(pattern, flags=re.IGNORECASE)

# Report the final set size and a small sample (check-mark emoji restored from
# mis-decoded UTF-8). Set iteration order is arbitrary, so the sample varies.
print(f"✅ Loaded {len(DISFLUENCY_SET)} total disfluencies")
print(f"   Sample: {list(DISFLUENCY_SET)[:8]}...")

In [None]:
# =============================================================================
# PER-DISFLUENCY CONFIDENCE THRESHOLDS
# =============================================================================
# Each disfluency has its own threshold based on how often it's used as a real word.
# More negative = more lenient (insert even with low confidence)
# Less negative = stricter (only insert if ASR is very confident)

# Per-token score thresholds. A candidate is inserted when its score is
# greater than the threshold, so a more negative value is more lenient.
# The keys were stored as mojibake (UTF-8 decoded as Mac Roman) and could
# never match a real token; they are restored to proper Devanagari.
DISFLUENCY_THRESHOLDS = {
    # Very common fillers - be lenient
    'हम्म': -8.0,   # "Hmm" - almost always a filler
    'उम्म': -7.0,   # "Umm" - almost always a filler
    'अं': -7.0,     # Short sound - usually a filler
    'ह': -7.0,      # Single letter filler
    'अह': -7.0,     # "Ah"
    'उह': -7.0,     # "Uh"
    'ओ': -6.0,      # Can be interjection

    # Words that can be real - be stricter
    'हां': -5.0,    # "Yes" - could be real acknowledgment
    'हाँ': -5.0,    # Same word, chandrabindu spelling
    'तो': -4.0,     # "So" - often a real conjunction
    'वो': -4.0,     # "That" - often a real pronoun
    'और': -3.0,     # "And" - almost always a real word

    'default': -6.0,  # Fallback for tokens not listed above
}

def get_disfluency_threshold(token):
    """
    Look up the score threshold for ``token``.

    The token is normalized with norm() before the lookup; tokens without an
    entry of their own receive the 'default' threshold.
    """
    fallback = DISFLUENCY_THRESHOLDS['default']
    key = norm(token)
    return DISFLUENCY_THRESHOLDS.get(key, fallback)

# Status banner summarizing a few thresholds (emoji and Devanagari restored
# from mis-decoded UTF-8).
print("✅ Disfluency thresholds configured")
print("   Lenient: हम्म (-8.0), उम्म (-7.0)")
print("   Strict: हां (-5.0), तो (-4.0), और (-3.0)")

## 3️⃣ Whisper ASR Model

We use **Whisper Large V3** fine-tuned on Hindi data by ARTPARK-IISc.
The model:
1. Takes audio as input
2. Outputs Hindi text transcription
3. Provides **confidence scores** for each token (log probabilities)

In [None]:
# =============================================================================
# LOAD WHISPER MODEL
# =============================================================================

# Loads the Whisper processor and model once and leaves `device`, `processor`,
# `model` and `forced_decoder_ids` as module-level names for later cells.
# (Status emoji restored from mis-decoded UTF-8; the loading logic is unchanged.)
print("🔄 Loading Whisper model (this may take 1-2 minutes)...")

# Use the GPU when PyTorch can see one, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"   Using device: {device}")

# Model ID - ARTPARK's Hindi-tuned Whisper
model_id = "ARTPARK-IISc/whisper-large-v3-vaani-hindi"

# Load processor (handles audio preprocessing and text decoding)
processor = WhisperProcessor.from_pretrained(model_id)
log("Processor loaded")

# Load model with memory optimizations
model = WhisperForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,  # FP16 on GPU
    low_cpu_mem_usage=True  # Load weights incrementally
)
model.to(device)  # Move to the selected device
model.eval()      # Set to inference mode (no dropout)
# NOTE: this disables autograd globally for the whole process, not only for
# this model.
torch.set_grad_enabled(False)

# Decoder prompt that forces Hindi transcription (as opposed to translation)
forced_decoder_ids = processor.get_decoder_prompt_ids(language="hi", task="transcribe")

# Clean up memory left over from loading
gc.collect()
if device == "cuda":
    torch.cuda.empty_cache()

print(f"✅ Model loaded successfully on {device}")
if device == "cuda":
    print(f"   GPU memory used: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

## 4️⃣ Text Processing Utilities

In [None]:
# =============================================================================
# TEXT PROCESSING FUNCTIONS
# =============================================================================

def normalize_text(text):
    """
    Normalize text for alignment comparison.

    Steps, in order:
    - NFKC-normalize ``str(text)``
    - Replace runs of punctuation (danda, double danda, , . ! ? ; : ' " ( ) [ ] -)
      with a single space
    - Collapse whitespace, strip, and lowercase

    Returns: the normalized string (possibly empty).
    """
    text = unicodedata.normalize('NFKC', str(text))
    # The danda (U+0964) and double danda (U+0965) are written as escapes so
    # the pattern survives a mis-decoded source file; the original literal
    # characters had been corrupted into mojibake and never matched.
    text = re.sub(r'[\u0964\u0965,.!?;:\'"()\[\]-]+', ' ', text)  # Remove punctuation
    return re.sub(r'\s+', ' ', text).strip().lower()

def tokenize(text):
    """Return the whitespace-separated words of ``normalize_text(text)``."""
    # str.split() with no argument never yields empty strings, so no extra
    # filtering is required.
    return normalize_text(text).split()

def make_clean(text):
    """
    Produce the 'clean' variant of ``text`` with every disfluency removed.

    Missing values and non-string inputs yield an empty string. Otherwise the
    text is NFKC-normalized, each RE_DISF match is replaced by a space, and
    whitespace is collapsed and stripped.
    """
    if pd.isna(text) or not isinstance(text, str):
        return ""
    normalized = unicodedata.normalize('NFKC', text)
    without_disfluencies = RE_DISF.sub(' ', normalized)
    collapsed = re.sub(r'\s+', ' ', without_disfluencies)
    return collapsed.strip()

def is_disfluency(token):
    """Return True when the norm()-normalized ``token`` is in DISFLUENCY_SET."""
    key = norm(token)
    return key in DISFLUENCY_SET

def is_repetition(token, context_tokens, position):
    """
    Tell whether ``token`` duplicates a word adjacent to ``position``.

    ``position`` is an insertion point in ``context_tokens``. The token counts
    as a repetition when it equals the word currently at that index or the
    word just before it (e.g. the second "मैं" in "मैं मैं").
    """
    size = len(context_tokens)
    same_as_next = position < size and token == context_tokens[position]
    if same_as_next:
        return True
    same_as_previous = 0 < position <= size and token == context_tokens[position - 1]
    return bool(same_as_previous)

print("✅ Text processing utilities ready")

# Demo: show normalize_text() and make_clean() on one sentence.
# (The sample sentence and emoji are restored from mis-decoded UTF-8.)
demo_text = "हम्म, मैं सोचता हूं कि यह अच्छा है।"
print(f"\n   Demo input: {demo_text}")
print(f"   Normalized: {normalize_text(demo_text)}")
print(f"   Clean (no disfluencies): {make_clean(demo_text)}")

## 5️⃣ Audio Transcription with Confidence Scores

In [None]:
# =============================================================================
# TOKEN-TO-WORD MAPPING
# =============================================================================

def map_tokens_to_words(token_ids, token_logprobs, tokenizer, decoded_text):
    """
    Map sub-word token log-probabilities to word-level confidence.

    Whisper uses sub-word tokenization (e.g., "कहानी" might be split into
    "कहा" + "नी"). This function walks the tokens left to right, greedily
    groups them into the whitespace-separated words of ``decoded_text``, and
    averages the log-probabilities of each group.

    Args:
        token_ids: Token ids; ``token_logprobs[i]`` is taken as the
            log-probability of ``token_ids[i]``.
        token_logprobs: Per-token log-probabilities (may be shorter than
            ``token_ids``; tokens past its end contribute no score).
        tokenizer: Object with ``decode(ids, skip_special_tokens=True)``.
        decoded_text: Full decoded transcription for these tokens.

    Returns: List of {'word': str, 'avg_logprob': float or None}, one entry
        per word. ``avg_logprob`` is None when no scored token was matched.
    """
    words = decoded_text.strip().split()
    if not words or not token_ids:
        return [{'word': w, 'avg_logprob': None} for w in words]

    # Decode each token ID to its text representation. A token that cannot be
    # decoded is treated as empty text instead of aborting the whole mapping.
    token_texts = []
    for tid in token_ids:
        try:
            token_texts.append(tokenizer.decode([tid], skip_special_tokens=True))
        except Exception:
            token_texts.append("")

    word_infos = []
    token_idx = 0
    n_tokens = len(token_texts)
    n_logprobs = len(token_logprobs)

    for word in words:
        word_norm = normalize_text(word)
        current_text = ""
        current_logprobs = []

        # Consume tokens until we've matched this word
        while token_idx < n_tokens:
            piece = token_texts[token_idx].strip()
            piece_idx = token_idx
            token_idx += 1

            # Tokens that decode to nothing (special tokens are dropped by
            # skip_special_tokens=True) carry no word text; skip them so their
            # scores do not dilute the word's average.
            if not piece:
                continue

            current_text += piece
            if piece_idx < n_logprobs:
                current_logprobs.append(token_logprobs[piece_idx])

            current_norm = normalize_text(current_text)
            if word_norm in current_norm:
                break
            if len(current_norm) >= len(word_norm) * 2:
                break  # Mismatch: give up on this word and move on

        # Calculate average log-probability for this word
        avg_lp = float(np.mean(current_logprobs)) if current_logprobs else None
        word_infos.append({'word': word, 'avg_logprob': avg_lp})

    return word_infos

# Status banner (check-mark emoji restored from mis-decoded UTF-8).
print("✅ Token-to-word mapping function ready")

In [None]:
# =============================================================================
# AUDIO TRANSCRIPTION
# =============================================================================

def transcribe_audio(audio_path, max_length=448, chunk_s=30):
    """
    Transcribe an audio file using Whisper with per-word confidence scores.

    The audio is loaded at 16 kHz, split into consecutive ``chunk_s``-second
    chunks, and each chunk is decoded independently; chunk results are
    concatenated.

    Args:
        audio_path: Path to the audio file (.wav)
        max_length: ``max_length`` passed to ``model.generate`` for each chunk
        chunk_s: Audio chunk size in seconds

    Returns:
        (text, tokens_info) where ``text`` is the full transcription and
        ``tokens_info`` is a list of {'word': str, 'avg_logprob': float or None}.
        On any error a warning is printed and ("", []) is returned.
    """
    try:
        # Load audio at 16kHz (the rate passed to the processor below)
        audio, sr = librosa.load(audio_path, sr=16000)
        chunk_len = int(chunk_s * sr)  # Samples per chunk

        all_texts = []
        all_tokens = []

        # Process audio in chunks (handles long recordings)
        for start in range(0, len(audio), chunk_len):
            chunk = audio[start:start + chunk_len]

            # Convert audio to model input features
            inputs = processor(chunk, sampling_rate=sr, return_tensors="pt")
            input_features = inputs.input_features.to(device)
            if device == "cuda":
                input_features = input_features.half()  # Model is FP16 on GPU

            # Generate transcription with per-step scores
            with torch.no_grad():
                out = model.generate(
                    input_features,
                    forced_decoder_ids=forced_decoder_ids,
                    max_length=max_length,
                    return_dict_in_generate=True,
                    output_scores=True
                )

            # Decode output tokens to text
            seq = out.sequences[0]
            decoded = processor.batch_decode(seq.unsqueeze(0), skip_special_tokens=True)[0]
            all_texts.append(decoded.strip())

            scores = out.scores
            if scores:
                token_ids = seq.tolist()
                # There is one score tensor per generated step, while the
                # sequence also contains the decoder prompt. Assumes the
                # scores belong to the LAST len(scores) tokens of the
                # sequence; the previous fixed "+1" lookup then passed the
                # un-shifted ids on, so every word was scored with its
                # neighbor's log-probability.
                offset = max(0, len(token_ids) - len(scores))
                generated_ids = token_ids[offset:]
                token_logprobs = [
                    # float32 log-softmax: FP16 logits lose precision here
                    float(torch.log_softmax(step_logits[0].float(), dim=-1)[chosen_id])
                    for step_logits, chosen_id in zip(scores, generated_ids)
                ]
                generated_ids = generated_ids[:len(token_logprobs)]

                word_infos = map_tokens_to_words(
                    generated_ids, token_logprobs, processor.tokenizer, decoded
                )
                all_tokens.extend(word_infos)
            else:
                # Fallback: no confidence scores available
                all_tokens.extend(
                    {'word': w, 'avg_logprob': None} for w in decoded.strip().split()
                )

            del input_features, out

        return " ".join(all_texts).strip(), all_tokens

    except Exception as e:
        # Deliberate best-effort: one bad file must not stop the whole run
        print(f"⚠️ Transcription error: {e}")
        return "", []

    finally:
        # Release memory on both the success and the error path
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()

# Status banner (check-mark emoji restored from mis-decoded UTF-8).
print("✅ Transcription function ready")

## 6️⃣ ASR Cache (Avoid Re-transcribing)

In [None]:
# =============================================================================
# ASR CACHE
# =============================================================================
# Transcription is slow (~5-10s per audio). Cache results to avoid redoing work.

# In-memory cache keyed by audio id (the audio file name without extension).
ASR_CACHE = dict()
# Pickle file inside OUTPUT_DIR where the cache is persisted between runs.
CACHE_PATH = os.path.join(OUTPUT_DIR, "asr_cache.pkl")

def load_cache():
    """
    Load previously cached ASR results from CACHE_PATH into ASR_CACHE.

    If the file does not exist, or cannot be unpickled into a dict, the
    current in-memory cache is left untouched and a message is printed.
    """
    global ASR_CACHE
    if not os.path.exists(CACHE_PATH):
        print("📂 No cache found, starting fresh")
        return

    try:
        with open(CACHE_PATH, 'rb') as f:
            # SECURITY: pickle.load can execute arbitrary code. Only load a
            # cache file this notebook wrote itself.
            loaded = pickle.load(f)
    except (pickle.UnpicklingError, EOFError, AttributeError, ImportError, IndexError) as e:
        # A truncated or corrupt cache should not abort the run
        print(f"⚠️ Could not read ASR cache ({e}); starting fresh")
        return

    if not isinstance(loaded, dict):
        print("⚠️ ASR cache has an unexpected format; starting fresh")
        return

    ASR_CACHE = loaded
    print(f"📂 Loaded {len(ASR_CACHE)} cached ASR results")

def save_cache():
    """
    Persist ASR_CACHE to CACHE_PATH.

    The data is written to a temporary sibling file and then moved into place
    with os.replace, so an interrupted save cannot leave a truncated cache
    file behind.
    """
    cache_dir = os.path.dirname(CACHE_PATH)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)

    tmp_path = CACHE_PATH + ".tmp"
    with open(tmp_path, 'wb') as f:
        pickle.dump(ASR_CACHE, f)
    os.replace(tmp_path, CACHE_PATH)
    log(f"Cache saved ({len(ASR_CACHE)} entries)")

def get_asr_transcript(audio_path):
    """
    Get the ASR transcription for ``audio_path``, using ASR_CACHE if possible.

    The cache key is the file name without directory or extension. Cached
    entries may be dicts ({'text', 'tokens'}) or bare values, which are
    returned as the text with an empty token list.

    Returns: (text, tokens_info)
    """
    audio_id = os.path.splitext(os.path.basename(audio_path))[0]

    # Check cache
    if audio_id in ASR_CACHE:
        cached = ASR_CACHE[audio_id]
        if isinstance(cached, dict):
            return cached.get('text', ''), cached.get('tokens', [])
        return cached, []

    # Transcribe and cache
    log(f"Transcribing {audio_id}...")
    text, tokens = transcribe_audio(audio_path)
    # transcribe_audio signals failure with ("", []). Do not cache that, or a
    # transient error would be remembered forever and never retried.
    if text or tokens:
        ASR_CACHE[audio_id] = {'text': text, 'tokens': tokens}
    return text, tokens

# Status banner (check-mark emoji restored from mis-decoded UTF-8).
print("✅ Cache system ready")

## 7️⃣ N-Gram Language Model

The language model checks if an insertion sounds natural.
- Built from training transcripts
- Uses trigrams (3-word sequences)
- Assigns probability to word sequences

In [None]:
# =============================================================================
# N-GRAM LANGUAGE MODEL
# =============================================================================

# Language-model state, filled in by build_ngram_model():
#   NGRAM_COUNTS     - Counter mapping an n-gram tuple to its count
#   PREFIX_COUNTS    - Counter mapping an (n-1)-gram prefix tuple to its count
#   NGRAM_VOCAB_SIZE - number of distinct tokens seen in the counted n-grams
# None means "no model built yet".
NGRAM_COUNTS = PREFIX_COUNTS = None
NGRAM_VOCAB_SIZE = 0

def build_ngram_model(n=3):
    """
    Build the n-gram language model from the training transcripts.

    Reads the 'transcript' column of TRAIN_CSV, counts every n-gram and its
    (n-1)-token prefix over the tokenized transcripts (wrapped in <s> / </s>
    markers), and stores the result in NGRAM_COUNTS, PREFIX_COUNTS and
    NGRAM_VOCAB_SIZE. Smoothing is applied later, in sentence_logprob().

    Args:
        n: n-gram order.

    If the training data cannot be loaded, a warning is printed and the
    globals are left unchanged.
    """
    global NGRAM_COUNTS, PREFIX_COUNTS, NGRAM_VOCAB_SIZE

    print("🔄 Building n-gram language model from training data...")

    try:
        train_df = pd.read_csv(TRAIN_CSV)
        # Inside the try so a missing 'transcript' column is reported like any
        # other load failure instead of raising KeyError.
        transcripts = train_df['transcript'].dropna()
    except Exception as e:
        print(f"⚠️ Could not load training data: {e}")
        return

    ngram_counts = Counter()
    prefix_counts = Counter()

    # Count n-grams in all transcripts
    for text in transcripts:
        tokens = ['<s>'] + tokenize(str(text)) + ['</s>']  # Add start/end markers
        for i in range(len(tokens) - n + 1):
            ngram = tuple(tokens[i:i + n])
            ngram_counts[ngram] += 1
            prefix_counts[ngram[:-1]] += 1

    NGRAM_COUNTS = ngram_counts
    PREFIX_COUNTS = prefix_counts
    NGRAM_VOCAB_SIZE = len({t for ng in ngram_counts for t in ng})

    print(f"✅ Built LM: {len(ngram_counts):,} n-grams, vocab size {NGRAM_VOCAB_SIZE:,}")

def sentence_logprob(tokens, n=3, alpha=0.1):
    """
    Score a token sequence under the n-gram model.

    The sequence is wrapped in <s> / </s> markers and the add-alpha smoothed
    log-probabilities of all its n-grams are summed. Returns 0.0 when no
    model has been built.
    """
    if NGRAM_COUNTS is None:
        return 0.0

    padded = ['<s>', *tokens, '</s>']
    vocab = max(1, NGRAM_VOCAB_SIZE)
    total = 0.0

    for start in range(len(padded) - n + 1):
        gram = tuple(padded[start:start + n])
        seen = NGRAM_COUNTS.get(gram, 0)
        context_seen = PREFIX_COUNTS.get(gram[:-1], 0)
        # Small epsilons guard against division by zero and log(0)
        smoothed = (seen + alpha) / (context_seen + alpha * vocab + 1e-10)
        total += np.log(smoothed + 1e-10)

    return total

def check_insertion_plausibility(clean_tokens, position, token, lm_threshold=-2.0):
    """
    Decide whether inserting ``token`` at ``position`` keeps the sentence plausible.

    Compares sentence_logprob() of the sentence with and without the
    insertion; the insertion is accepted when the change in log-probability
    is greater than ``lm_threshold``. Always True when no model is built.
    """
    if NGRAM_COUNTS is None:
        return True

    candidate = list(clean_tokens)
    insert_at = min(position, len(candidate))
    candidate.insert(insert_at, token)

    change = sentence_logprob(candidate) - sentence_logprob(clean_tokens)
    return change > lm_threshold

# Status banner (check-mark emoji restored from mis-decoded UTF-8).
print("✅ Language model functions ready")

## 8️⃣ Alignment & Insertion Logic

The core algorithm:
1. Compare clean text with ASR output using `SequenceMatcher`
2. Find words that ASR detected but aren't in clean text (insertions)
3. For each candidate insertion, check:
   - Is it a known disfluency?
   - Does the ASR confidence exceed the threshold?
   - Does the LM approve the insertion?

In [None]:
# =============================================================================
# POSITION PRIOR
# =============================================================================

def position_prior(token_index, n_tokens, exponent=1.5):
    """
    Bias insertions toward earlier positions in the sentence.

    Args:
        token_index: Insertion point; may equal ``n_tokens`` for an insertion
            after the last token.
        n_tokens: Number of tokens in the sentence.
        exponent: Exponent applied to the relative position.

    Returns: Score between 0 and 1 (1.0 at the start, 0.0 at or beyond the
        last token). Sentences with at most one token always score 1.0.
    """
    if n_tokens <= 1:
        return 1.0
    frac = token_index / float(n_tokens - 1)
    # Clamp to [0, 1]: an insertion after the last token has
    # token_index == n_tokens, which would otherwise give a negative score
    # (and NaN once the caller takes its logarithm).
    frac = min(1.0, max(0.0, frac))
    return 1.0 - frac ** exponent

# Status banner plus three sample values for a 10-token sentence
# (check-mark emoji restored from mis-decoded UTF-8).
print("✅ Position prior function ready")
print(f"   Position 0 (start): {position_prior(0, 10):.2f}")
print(f"   Position 5 (middle): {position_prior(5, 10):.2f}")
print(f"   Position 9 (end): {position_prior(9, 10):.2f}")

In [None]:
# =============================================================================
# FIND INSERTIONS (CORE ALGORITHM)
# =============================================================================

def find_insertions(clean_tokens, asr_tokens, asr_tokens_info=None,
                    pos_exponent=1.5, use_lm=True, lm_threshold=-2.0,
                    non_disf_threshold=-4.0):
    """
    Find disfluencies to insert from ASR output into clean text.

    Algorithm:
    1. Align clean and ASR tokens using SequenceMatcher
    2. For 'insert' operations (tokens in ASR but not clean):
       - Known disfluency: insert when ASR confidence + log position prior
         is greater than the per-disfluency threshold (or when no confidence
         is available)
       - Repetition of an adjacent clean word: insert
       - Anything else: insert only when confidence + log position prior is
         greater than ``non_disf_threshold``
       - If the LM is enabled and built, it can veto the insertion
    3. For 'replace' operations, only known disfluencies are inserted
       (subject to the LM check; no confidence check)

    Args:
        clean_tokens: Tokens of the clean text.
        asr_tokens: Tokens of the ASR transcription.
        asr_tokens_info: Optional list of {'word', 'avg_logprob'} dicts.
        pos_exponent: Exponent forwarded to position_prior().
        use_lm: Whether to apply the language-model check.
        lm_threshold: Threshold forwarded to check_insertion_plausibility().
        non_disf_threshold: Score a non-disfluency token must exceed.

    Returns: List of (position, token) tuples; ``position`` indexes
        ``clean_tokens`` and tuples are in ASR order.
    """
    from collections import defaultdict, deque

    if not asr_tokens:
        return []

    # Queue the confidences of each normalized word in ASR order.
    pending = defaultdict(deque)
    for info in asr_tokens_info or []:
        if info and 'word' in info:
            for piece in tokenize(info['word']):
                pending[piece].append(info.get('avg_logprob'))

    # Give the k-th occurrence of a word in asr_tokens the k-th queued
    # confidence for that word. Consuming the queues for EVERY ASR token (not
    # only inserted ones) keeps repeated words paired with their own score;
    # previously an inserted token received the score of the first
    # same-spelled word in the utterance.
    logprob_at = []
    for tok in asr_tokens:
        queue = pending.get(normalize_text(tok))
        logprob_at.append(queue.popleft() if queue else None)

    # Find differences between clean and ASR
    sm = SequenceMatcher(a=clean_tokens, b=asr_tokens, autojunk=False)
    insertions = []
    n_clean = len(clean_tokens)
    lm_active = use_lm and NGRAM_COUNTS is not None

    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == 'insert':
            # ASR has tokens that clean text doesn't. The position term is the
            # same for the whole run; clamp at 0 so np.log never sees a
            # negative value.
            pos_score = position_prior(i1, n_clean, pos_exponent)
            log_pos = np.log(max(pos_score, 0.0) + 1e-6)

            for j in range(j1, j2):
                token = asr_tokens[j]
                avg_lp = logprob_at[j]

                if is_disfluency(token):
                    if avg_lp is None:
                        should_insert = True  # No confidence, trust disfluency set
                    else:
                        should_insert = (avg_lp + log_pos) > get_disfluency_threshold(token)
                elif is_repetition(token, clean_tokens, i1):
                    should_insert = True  # Repetitions are disfluent
                elif avg_lp is not None:
                    # Non-disfluency: needs a high score
                    should_insert = (avg_lp + log_pos) > non_disf_threshold
                else:
                    should_insert = False

                # LM veto
                if should_insert and lm_active:
                    if not check_insertion_plausibility(clean_tokens, i1, token, lm_threshold):
                        should_insert = False

                if should_insert:
                    insertions.append((i1, token))

        elif tag == 'replace':
            # ASR has different tokens - only insert disfluencies
            for j in range(j1, j2):
                token = asr_tokens[j]
                if not is_disfluency(token):
                    continue
                if lm_active and not check_insertion_plausibility(
                        clean_tokens, i1, token, lm_threshold):
                    continue
                insertions.append((i1, token))

    return insertions

# Status banner (check-mark emoji restored from mis-decoded UTF-8).
print("✅ Insertion detection function ready")

In [None]:
# =============================================================================
# APPLY INSERTIONS
# =============================================================================

def apply_insertions(original_words, insertions, max_consecutive=4):
    """
    Insert disfluencies into the original text.

    Args:
        original_words: List of original words (not modified)
        insertions: List of (position, token) tuples; ``position`` is an
            index into ``original_words`` (values past the end append).
            Tokens sharing a position keep their order from ``insertions``.
        max_consecutive: Maximum length of a run of identical tokens after
            an insertion (prevents "हम्म हम्म हम्म हम्म हम्म")

    Returns: New list of words with the insertions applied
    """
    result = list(original_words)

    # Apply from the highest position down so earlier positions stay valid.
    # Within one position, apply in REVERSE input order: each insert lands in
    # front of the previous one, so this keeps the tokens in their original
    # (ASR) order instead of reversing them. Sorting on the enumerate index
    # also avoids the quadratic insertions.index() lookup.
    ordered = sorted(
        enumerate(insertions),
        key=lambda item: (-item[1][0], -item[0]),
    )

    for _, (pos, token) in ordered:
        pos = max(0, min(pos, len(result)))

        # Length of the run of identical tokens this insertion would create
        consec = 1
        i = pos - 1
        while i >= 0 and result[i] == token:
            consec += 1
            i -= 1
        i = pos
        while i < len(result) and result[i] == token:
            consec += 1
            i += 1

        # Only insert if under limit
        if consec <= max_consecutive:
            result.insert(pos, token)

    return result

# Status banner (check-mark emoji restored from mis-decoded UTF-8).
print("✅ Insertion application function ready")

In [None]:
# =============================================================================
# MAIN RESTORATION FUNCTION
# =============================================================================

def restore_disfluencies(clean_text, audio_path, pos_exponent=1.5, use_lm=True, lm_threshold=-2.0):
    """
    Main function: Restore disfluencies to clean text using audio.

    Args:
        clean_text: Text without disfluencies (missing / non-string values
            are treated as empty)
        audio_path: Path to audio file
        pos_exponent: Exponent forwarded to the position prior
        use_lm: Whether to use language model filtering
        lm_threshold: LM threshold (more negative = more lenient)

    Returns: Text with disfluencies restored. Falls back to the clean text
        when ASR produced nothing or no insertion was accepted, and to the
        raw ASR text when the clean text is empty.
    """
    if pd.isna(clean_text) or not isinstance(clean_text, str):
        clean_text = ""

    # Get ASR transcription
    asr_text, asr_tokens_info = get_asr_transcript(audio_path)

    if not asr_text:
        return clean_text if clean_text else ""

    if not clean_text:
        return asr_text

    # Tokenize word by word and remember which original word each normalized
    # token came from. normalize_text() can drop a punctuation-only word or
    # split one word into several tokens, so token positions and word
    # positions are not interchangeable.
    original_words = clean_text.split()
    clean_tokens = []
    token_to_word = []
    for word_idx, word in enumerate(original_words):
        pieces = tokenize(word)
        clean_tokens.extend(pieces)
        token_to_word.extend([word_idx] * len(pieces))
    token_to_word.append(len(original_words))  # "after the last token"

    asr_tokens = tokenize(asr_text)

    # Find insertions (positions are indices into clean_tokens)
    insertions = find_insertions(
        clean_tokens, asr_tokens,
        asr_tokens_info=asr_tokens_info,
        pos_exponent=pos_exponent,
        use_lm=use_lm,
        lm_threshold=lm_threshold
    )

    if not insertions:
        return clean_text

    # Translate token positions to word positions before editing the words
    last = len(token_to_word) - 1
    word_insertions = [
        (token_to_word[min(max(pos, 0), last)], token)
        for pos, token in insertions
    ]

    restored_words = apply_insertions(original_words, word_insertions)

    return ' '.join(restored_words)

# Status banner (check-mark emoji restored from mis-decoded UTF-8).
print("✅ Main restoration function ready")

## 9️⃣ Main Inference Loop

In [None]:
# =============================================================================
# INFERENCE FUNCTION
# =============================================================================

def run_inference(pos_exponent=1.5, use_lm=True, lm_threshold=-2.0):
    """
    Process all test samples and generate the submission file.

    Loads the ASR cache, builds the language model if it is needed and not
    yet built, restores disfluencies for every row of TEST_CSV, and writes
    ``submission.csv`` (columns: id, transcript) to OUTPUT_DIR. The ASR cache
    is saved every 10 samples and once more at the end.

    Args:
        pos_exponent: Exponent forwarded to the position prior
        use_lm: Whether to use language model filtering
        lm_threshold: LM threshold forwarded to restore_disfluencies()

    Returns: List of {'id', 'transcript'} dicts, one per test row.
    """
    print("\n" + "=" * 60)
    print("🚀 DISFLUENCY RESTORATION PIPELINE")
    print("=" * 60)
    print(f"   Position exponent: {pos_exponent}")
    print(f"   Use LM: {use_lm}")
    print(f"   LM threshold: {lm_threshold}")
    print("=" * 60)

    # Load cache
    load_cache()

    # Build LM if needed
    if use_lm and NGRAM_COUNTS is None:
        build_ngram_model()

    # Load test data
    test_df = pd.read_csv(TEST_CSV)
    total = len(test_df)
    print(f"\n📊 Processing {total} test samples...")

    results = []
    # Count rows with enumerate: the DataFrame index label is not guaranteed
    # to be a 0-based integer position.
    for done, (_, row) in enumerate(test_df.iterrows(), start=1):
        audio_path = f"{AUDIO_DIR}/{row['id']}.wav"

        if os.path.exists(audio_path):
            restored = restore_disfluencies(
                row['transcript'], audio_path,
                pos_exponent=pos_exponent,
                use_lm=use_lm,
                lm_threshold=lm_threshold
            )
        else:
            # No audio: pass the transcript through, mapping a missing value
            # to "" as restore_disfluencies() does.
            transcript = row['transcript']
            restored = "" if pd.isna(transcript) else transcript
            log(f"⚠️ Audio not found: {audio_path}")

        results.append({'id': row['id'], 'transcript': restored})

        # Progress update (and cache checkpoint) every 10 samples
        if done % 10 == 0:
            save_cache()
            print(f"   ✓ Processed {done}/{total}")

    # Save final results
    save_cache()
    output_path = os.path.join(OUTPUT_DIR, "submission.csv")
    pd.DataFrame(results).to_csv(output_path, index=False)

    print("\n" + "=" * 60)
    print(f"✅ COMPLETE! Saved to: {output_path}")
    print(f"   Total samples: {len(results)}")
    print(f"   Output files: {os.listdir(OUTPUT_DIR)}")
    print("=" * 60)

    return results

# Status banner (check-mark emoji restored from mis-decoded UTF-8).
print("✅ Inference function ready")

## 🏃 Execute Pipeline

In [None]:
# =============================================================================
# RUN THE PIPELINE
# =============================================================================

# Run the full pipeline with the position prior exponent at 1.5, the language
# model enabled, and an LM threshold of -2.0. `results` holds one
# {'id', 'transcript'} dict per test row.
results = run_inference(pos_exponent=1.5, use_lm=True, lm_threshold=-2.0)