# Language Model Phoneme Evaluation

This notebook evaluates trained models using a phoneme language model for improved accuracy.

**What this does:**
- Loads a trained model selected by `MODEL_TYPE`: GRU-Opt (`"final"`), Final2/Mamba (`"final2"`), or Transformer (`"advanced"`)
- Runs baseline greedy decoding
- Runs CTC beam search rescored with a KenLM phoneme n-gram language model
- Compares results and shows improvement

**Expected improvement:** roughly 2-8 percentage points lower PER/CER; the actual gain depends on the model and on `LM_WEIGHT` / `BEAM_WIDTH`.

In [None]:
import sys
sys.path.insert(0, 'src')

In [None]:
import os
import pickle
import torch
import numpy as np
from edit_distance import SequenceMatcher
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

## Configuration

Update these paths and settings for your setup:

In [None]:
# ===== UPDATE THESE PATHS =====
# Which architecture to build in the model-loading cell: "final", "final2" or "advanced".
MODEL_TYPE = "final"
MODEL_WEIGHTS = os.path.expanduser("~/results/final_training/modelWeights.pt")
DATASET_PATH = os.path.expanduser("~/competitionData/ptDecoder_ctc")
# Relative paths: resolved against the working directory (the notebooks folder).
LM_PATH = "phoneme_lm.arpa"
PHONEME_MAP_PATH = "phoneme_map.txt"

if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# LM settings (can adjust)
LM_WEIGHT = 0.8  # typical range 0.4-1.2; larger values trust the LM more
BEAM_WIDTH = 10  # typical range 5-50; wider beams are slower but search more hypotheses

for label, value in (("Model", MODEL_TYPE), ("Weights", MODEL_WEIGHTS), ("Device", DEVICE)):
    print(f"{label}: {value}")

## Install KenLM

In [None]:
try:
    import kenlm
    print("✓ KenLM already installed")
except ImportError:
    print("Installing KenLM...")
    import importlib
    import subprocess

    # Install with the kernel's own interpreter so the package lands in the
    # environment this notebook runs in (a bare `!pip` may resolve to a
    # different pip). This also keeps the cell valid plain Python.
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "https://github.com/kpu/kenlm/archive/master.zip"]
    )
    # Refresh import finders so the freshly installed package is visible.
    importlib.invalidate_caches()
    import kenlm
    print("✓ KenLM installed successfully")

## Load Dataset

In [None]:
from neural_decoder.dataset import SpeechDataset

# NOTE(review): pickle.load can execute arbitrary code stored in the file —
# only point DATASET_PATH at a dataset you trust.
with open(DATASET_PATH, mode="rb") as dataset_file:
    data = pickle.load(dataset_file)

# The "test" split of this pickle is what this notebook evaluates on
# (it serves as the validation set).
test_ds = SpeechDataset(data["test"])

def collate_fn(batch):
    """Collate (X, y, X_len, y_len, day) samples into one padded batch.

    Variable-length ``X`` and ``y`` are right-padded (batch-first) with 0.0 and
    0 respectively; the per-sample length and day entries are stacked with
    ``torch.stack``, so they must be tensors.

    Returns:
        Tuple ``(X_padded, y_padded, X_lens, y_lens, days)``.
    """
    features, labels, feature_lens, label_lens, day_ids = zip(*batch)
    padded_features = pad_sequence(features, batch_first=True, padding_value=0.0)
    padded_labels = pad_sequence(labels, batch_first=True, padding_value=0)
    stacked_meta = [torch.stack(group) for group in (feature_lens, label_lens, day_ids)]
    return (padded_features, padded_labels, *stacked_meta)

# Deterministic order (no shuffling) and in-process loading (no worker processes).
test_loader = DataLoader(
    dataset=test_ds,
    collate_fn=collate_fn,
    batch_size=64,
    num_workers=0,
    shuffle=False,
)

sample_count = len(test_ds)
print(f"✓ Loaded {sample_count} validation samples")

## Load Model

In [None]:
# Build the architecture selected by MODEL_TYPE. `stride_len` / `kernel_len`
# are used by the evaluation loop to derive output lengths for the
# "final" and "final2" models.
if MODEL_TYPE == "final":
    from neural_decoder.model import GRUDecoder
    model = GRUDecoder(
        neural_dim=256,
        n_classes=40,
        hidden_dim=1024,
        layer_dim=5,
        nDays=len(data["train"]),
        dropout=0.4,
        device=DEVICE,
        strideLen=4,
        kernelLen=32,
        gaussianSmoothWidth=2.0,
        bidirectional=True,
    ).to(DEVICE)
    stride_len = 4
    kernel_len = 32

elif MODEL_TYPE == "final2":
    from neural_decoder.final2_model import NeuralDecoder
    model = NeuralDecoder(
        conv_size=1024,
        hidden_size=1024,
        encoder_n_layer=8,
        decoder_n_layer=8,
        decoders=['ph'],
    ).to(DEVICE)
    stride_len = 4
    kernel_len = 8

elif MODEL_TYPE == "advanced":
    from neural_decoder.advanced_models import StreamingTransformerDecoder
    model = StreamingTransformerDecoder(
        neural_dim=256,
        n_phonemes=40,
        d_model=512,
        nhead=8,
        num_layers=6,
        dropout=0.2,
        stride_len=4,
        kernel_len=8,
        gaussian_smooth_width=2.0,
        intermediate_layer=3,
        day_count=len(data["train"]),
        diphone_context=40,
        device=DEVICE,
    ).to(DEVICE)
    # NOTE(review): the constructor above receives stride_len=4 while this
    # variable is 8. For "advanced" the evaluation loop reads
    # output['eff_lengths'] instead of using stride_len — confirm the value.
    stride_len = 8  # Transformer uses patch embedding
    kernel_len = 8

else:
    # Fail here with a clear message instead of a NameError on `model` below.
    raise ValueError(
        f"Unknown MODEL_TYPE {MODEL_TYPE!r}; expected 'final', 'final2', or 'advanced'"
    )

# Load weights.
# NOTE(review): torch.load unpickles the checkpoint — only load files you trust.
model.load_state_dict(torch.load(MODEL_WEIGHTS, map_location=DEVICE))
model.eval()

total_params = sum(p.numel() for p in model.parameters())
print(f"✓ Loaded {MODEL_TYPE} model ({total_params/1e6:.1f}M parameters)")

## Load Language Model

In [None]:
class PhonemeLM:
    """KenLM wrapper that scores sequences of phoneme class IDs.

    Each ID is turned into an LM token through ``phoneme_map``; IDs with no
    (or a ``None``) entry become ``"PH<id>"``. The tokens are joined with
    spaces and scored with ``bos=False, eos=False``. Scores are memoised per
    ID tuple until ``clear_cache`` is called.
    """

    def __init__(self, lm_path: str, phoneme_map: dict = None):
        """Open the KenLM model at ``lm_path``.

        Raises:
            FileNotFoundError: if ``lm_path`` does not exist.
        """
        if not os.path.exists(lm_path):
            raise FileNotFoundError(f"LM file not found: {lm_path}")
        self.model = kenlm.Model(lm_path)
        self.phoneme_map = phoneme_map or {}
        # tuple(ids) -> float LM score
        self._score_cache = {}

    def _token_for(self, phoneme_id):
        """Return the LM token for one integer phoneme ID."""
        token = self.phoneme_map.get(phoneme_id)
        return f"PH{phoneme_id}" if token is None else token

    def id_sequence_to_tokens(self, id_seq):
        """Map an iterable of phoneme IDs to a list of LM token strings."""
        return [self._token_for(int(raw_id)) for raw_id in id_seq]

    def score(self, id_seq):
        """Return the (cached) KenLM score of the token sentence for ``id_seq``."""
        cache_key = tuple(id_seq)
        cached = self._score_cache.get(cache_key)
        if cached is None:
            sentence = " ".join(self.id_sequence_to_tokens(id_seq))
            cached = float(self.model.score(sentence, bos=False, eos=False))
            self._score_cache[cache_key] = cached
        return cached

    def clear_cache(self):
        """Drop all memoised scores."""
        self._score_cache.clear()

def load_phoneme_map(path, n_classes):
    """Build a mapping from phoneme class ID (1..n_classes) to LM token.

    The file at ``path`` is read as UTF-8, one token per non-blank line
    (surrounding whitespace stripped). The first token maps to ID 1, the
    second to ID 2, and so on. ID 0 is never assigned; elsewhere in this
    notebook ID 0 is treated as the CTC blank. IDs not covered by the file
    fall back to ``"PH<id>"``.

    Args:
        path: Path to the token file, or None.
        n_classes: Number of non-blank phoneme classes to cover.

    Returns:
        dict[int, str] covering at least IDs 1..n_classes. Extra lines in the
        file beyond n_classes are kept.

    If ``path`` is None or does not exist, the all-fallback map is returned.
    If it exists but cannot be read or decoded, a warning is emitted and the
    all-fallback map is returned.
    """
    import warnings

    fallback = {i: f"PH{i}" for i in range(1, n_classes + 1)}
    if path is None or not os.path.exists(path):
        return fallback
    try:
        with open(path, "r", encoding="utf-8") as f:
            tokens = [line.strip() for line in f if line.strip()]
    except (OSError, UnicodeDecodeError) as err:
        # Best-effort: fall back to synthetic tokens, but don't hide the problem.
        warnings.warn(f"Could not read phoneme map {path!r} ({err}); using PH<id> tokens")
        return fallback
    mapping = dict(enumerate(tokens, start=1))
    for phoneme_id, default_token in fallback.items():
        mapping.setdefault(phoneme_id, default_token)
    return mapping

# Map the 40 phoneme class IDs to LM tokens, then open the KenLM model.
phoneme_map = load_phoneme_map(PHONEME_MAP_PATH, 40)
lm = PhonemeLM(LM_PATH, phoneme_map=phoneme_map)

summary_lines = (
    f"✓ Language model loaded from: {LM_PATH}",
    f"  LM Weight: {LM_WEIGHT}",
    f"  Beam Width: {BEAM_WIDTH}",
)
for summary_line in summary_lines:
    print(summary_line)

## Beam Search Decoder

In [None]:
def _ctc_accumulate(table, prefix, slot, value):
    """Log-add ``value`` into ``table[prefix][slot]`` (slot 0: ends in blank, 1: ends in non-blank)."""
    if value == -np.inf:
        return  # impossible path; don't create an empty hypothesis for it
    entry = table.setdefault(prefix, [-np.inf, -np.inf])
    entry[slot] = float(np.logaddexp(entry[slot], value))


def beam_search_lm(log_probs, lm_wrapper=None, lm_weight=0.8, beam_width=10, blank_id=0, topk_acoustic=5):
    """CTC prefix beam search with optional language-model rescoring.

    Each hypothesis is a *collapsed* label prefix (no blanks, no frame
    duplicates). For it we track the log-probability of all alignments ending
    in a blank and ending in a non-blank. This keeps a blank-separated repeat
    (frames ``A ∅ A``) as two labels while merging ``A A`` into one. It also
    means the LM only ever sees real label sequences.

    Args:
        log_probs: ``[T, V]`` per-frame log-probabilities (torch tensor or
            array-like).
        lm_wrapper: Object with ``score(seq_of_ids) -> float``, or None for a
            purely acoustic search.
        lm_weight: Multiplier applied to the LM score.
        beam_width: Number of prefixes kept after each frame.
        blank_id: Index of the CTC blank.
        topk_acoustic: Only the ``topk_acoustic`` most likely symbols of each
            frame are expanded.

    Returns:
        The best prefix as a list of Python ints (blank-free, CTC-collapsed).

    NOTE(review): KenLM scores are typically log10 while acoustic scores here
    are natural-log; ``lm_weight`` absorbs that scale difference — tune it.
    """
    if isinstance(log_probs, torch.Tensor):
        # .float() so half/bfloat16 outputs convert to numpy cleanly.
        lp = log_probs.detach().float().cpu().numpy()
    else:
        lp = np.asarray(log_probs, dtype=np.float64)
    T, V = lp.shape
    k = max(1, min(int(topk_acoustic), V))

    def ranking_score(prefix, probs):
        # Acoustic log-prob of the prefix (all alignments) + weighted LM score.
        acoustic = float(np.logaddexp(probs[0], probs[1]))
        if lm_wrapper is None or not prefix:
            return acoustic  # the empty prefix carries no LM evidence
        return acoustic + lm_weight * float(lm_wrapper.score(prefix))

    # prefix (tuple of ids) -> [log P(ends in blank), log P(ends in non-blank)]
    beams = {(): [0.0, -np.inf]}
    for t in range(T):
        step = lp[t]
        candidates = [int(c) for c in np.argsort(step)[-k:][::-1]]
        next_beams = {}
        for prefix, (p_blank, p_non_blank) in beams.items():
            p_total = float(np.logaddexp(p_blank, p_non_blank))
            last = prefix[-1] if prefix else None
            for c in candidates:
                p_c = float(step[c])
                if c == blank_id:
                    # Blank: the prefix is unchanged and now ends in a blank.
                    _ctc_accumulate(next_beams, prefix, 0, p_total + p_c)
                elif c == last:
                    # Same symbol again: without a blank in between it merges
                    # into the current label; after a blank it is a new label.
                    _ctc_accumulate(next_beams, prefix, 1, p_non_blank + p_c)
                    _ctc_accumulate(next_beams, prefix + (c,), 1, p_blank + p_c)
                else:
                    _ctc_accumulate(next_beams, prefix + (c,), 1, p_total + p_c)
        ranked = sorted(
            next_beams.items(),
            key=lambda item: ranking_score(item[0], item[1]),
            reverse=True,
        )
        beams = dict(ranked[:beam_width])

    if not beams:
        return []
    best_prefix = max(beams, key=lambda prefix: ranking_score(prefix, beams[prefix]))
    return list(best_prefix)

## Run Evaluation

In [None]:
print("=" * 70)
print("EVALUATING WITH LANGUAGE MODEL")
print("=" * 70)

all_predictions_baseline = []
all_predictions_lm = []
all_targets = []

# Decode every validation utterance twice: greedy CTC (baseline) and
# LM-rescored beam search. Phoneme ID 0 is treated as the CTC blank.
with torch.no_grad():
    for batch_idx, (X, y, X_len, y_len, day_idx) in enumerate(test_loader):
        X = X.to(DEVICE)
        X_len = X_len.to(DEVICE)
        day_idx = day_idx.to(DEVICE)

        # Forward pass (model-specific)
        if MODEL_TYPE == "final":
            logits = model(X, day_idx)
            # Output length after the strided input convolution.
            lengths = ((X_len - kernel_len) / stride_len).long()
            log_probs = torch.log_softmax(logits, dim=-1)
        elif MODEL_TYPE == "final2":
            logits = model(X, day_idx)
            lengths = model._last_output_lens if hasattr(model, '_last_output_lens') else ((X_len - kernel_len) / stride_len).long()
            log_probs = torch.log_softmax(logits, dim=-1)
        elif MODEL_TYPE == "advanced":
            output = model(X, X_len, day_idx)
            log_probs = output['log_probs']
            lengths = output['eff_lengths']
        else:
            raise ValueError(f"Unknown MODEL_TYPE: {MODEL_TYPE!r}")

        n_frames = log_probs.shape[1]
        # Decode each sample
        for i in range(len(y)):
            # Clamp to [0, n_frames]: (X_len - kernel_len) can be negative for
            # very short trials, and a negative slice end would keep all but
            # the last |seq_len| frames instead of none.
            seq_len = min(max(int(lengths[i]), 0), n_frames)
            lp = log_probs[i, :seq_len, :]  # [T, V]

            # Baseline: greedy CTC decoding (skip repeated frames and blanks).
            greedy = torch.argmax(lp, dim=-1).cpu().tolist()
            decoded_baseline = []
            prev = None
            for tok in greedy:
                if tok != prev and tok != 0:
                    decoded_baseline.append(tok)
                prev = tok
            all_predictions_baseline.append(decoded_baseline)

            # With LM: Beam search
            decoded_lm = beam_search_lm(
                lp,
                lm_wrapper=lm,
                lm_weight=LM_WEIGHT,
                beam_width=BEAM_WIDTH,
                blank_id=0,
                topk_acoustic=5
            )
            all_predictions_lm.append(decoded_lm)

            # Get target
            target = y[i, :y_len[i]].cpu().numpy().tolist()
            all_targets.append(target)

        if (batch_idx + 1) % 5 == 0:
            print(f"  Processed {batch_idx + 1}/{len(test_loader)} batches")

print("✓ Evaluation complete!")

## Compute Results

In [None]:
def compute_error_rate(predictions, targets):
    """Return the corpus-level edit-distance error rate of predictions vs targets.

    The rate is the summed edit distance (``SequenceMatcher(a=target,
    b=pred).distance()``) over all pairs, divided by the summed target length.
    Pairs are formed with ``zip``, so extra items in the longer list are
    ignored. Returns 0.0 when the total target length is zero.
    """
    pairs = list(zip(predictions, targets))
    edits = sum(SequenceMatcher(a=target, b=pred).distance() for pred, target in pairs)
    reference_length = sum(len(target) for _, target in pairs)
    if reference_length > 0:
        return edits / reference_length
    return 0.0

cer_baseline = compute_error_rate(all_predictions_baseline, all_targets)
cer_lm = compute_error_rate(all_predictions_lm, all_targets)
absolute_gain = cer_baseline - cer_lm
# Relative improvement is undefined when the baseline is already perfect.
improvement = absolute_gain / cer_baseline * 100 if cer_baseline > 0 else 0.0

print("\n" + "=" * 70)
print("FINAL RESULTS")
print("=" * 70)
print(f"\nModel: {MODEL_TYPE}")
print(f"\nBaseline (Greedy Decoding):")
print(f"  PER/CER: {cer_baseline:.4f} ({cer_baseline*100:.2f}%)")
print(f"\nWith Language Model:")
print(f"  PER/CER: {cer_lm:.4f} ({cer_lm*100:.2f}%)")
print(f"  Improvement: {improvement:.2f}% relative")
print(f"  Absolute gain: {absolute_gain*100:.2f} percentage points")
print("\n" + "=" * 70)
# Only claim an improvement when the LM actually lowered the error rate.
if absolute_gain > 0:
    print(f"Language Model improved accuracy by {improvement:.1f}%!")
elif absolute_gain < 0:
    print(f"Language Model increased PER/CER by {-absolute_gain*100:.2f} percentage points "
          f"(try adjusting LM_WEIGHT / BEAM_WIDTH)")
else:
    print("Language Model did not change PER/CER.")
print("=" * 70)