In [2]:
# importing the zipfile module 
from zipfile import ZipFile 
  
# loading the temp.zip and creating a zip object 
with ZipFile('train_sample.zip', 'r') as zObject: 
  
    # Extracting all the members of the zip  
    # into a specific location. 
    zObject.extractall( 
        path="./") 

In [10]:
!pip install transformers datasets jiwer torchaudio sentencepiece

Defaulting to user installation because normal site-packages is not writeable
Collecting datasets
  Downloading datasets-4.0.0-py3-none-any.whl.metadata (19 kB)
Collecting jiwer
  Downloading jiwer-4.0.0-py3-none-any.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting requests (from transformers)
  Downloading requests-2.32.4-py3-none-any.whl.metadata (4.9 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting click>=8.1.8 (from jiwer)
  Downloading click-8.2.1-py3-none-any.whl.metadata (2.5 kB)
Collecting rapidfuzz>=3.9.7 (from jiwer)
  Downloading rapidfuzz-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading datasets-4.0.0-py3-none-any.whl (494

In [3]:
pip install soundfile

Defaulting to user installation because normal site-packages is not writeable
Collecting soundfile
  Downloading soundfile-0.13.1-py2.py3-none-manylinux_2_28_x86_64.whl.metadata (16 kB)
Downloading soundfile-0.13.1-py2.py3-none-manylinux_2_28_x86_64.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m22.9 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: soundfile
Successfully installed soundfile-0.13.1

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


# Full hybrid pipeline v1 (Whisper encoder + MoE + CTC) → evaluate → LLM few-shot refine (Mistral-7B)
- Utilities for Kaldi-style inputs (wav.scp, segments, text) and robust audio loading (torchaudio → soundfile fallback).
- Dataset + collate: packs Whisper processor features and label tensors.
- Model: WhisperMoEModel (Whisper encoder → small Mixture-of-Experts block → projection to vocab) trained with CTC.
- Training: Adam, prints per-epoch train/val loss; saves best.pth on improvement.
- Greedy CTC decode → WER/CER calculation.
- LLM refinement: loads mistralai/Mistral-7B-Instruct-v0.3 with a few-shot prompt; batches generation; cleans output; re-scores.

In [2]:
# %%
# Cell 1: Install dependencies
import os
# Ensure sentencepiece is available for tokenizer
#os.system("pip install --quiet sentencepiece")

# %%
# Cell 2: Imports & Setup
import torchaudio
import soundfile as sf  # fallback for wav loading
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    WhisperProcessor,
    WhisperModel,
    AutoTokenizer,
    AutoModelForCausalLM,
    GenerationConfig
)
import numpy as np
import re

# Hugging Face token for gated LLM models
HF_TOKEN = "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"

# %%
# Cell 3: Utility functions to read Kaldi files
def read_wav_scp(path):
    wav_dict = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            rec_id, wav_path = line.strip().split(maxsplit=1)
            wav_dict[rec_id] = wav_path
    return wav_dict


def read_segments(path):
    segments = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split()
            utt_id, rec_id, start, end = parts[0], parts[1], parts[2], parts[3]
            segments[utt_id] = (rec_id, float(start), float(end))
    return segments


def read_text(path):
    texts = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            utt_id, transcript = line.strip().split(maxsplit=1)
            texts[utt_id] = transcript
    return texts

# %%
# Cell 4: KaldiDataset and collate_fn
class KaldiDataset(Dataset):
    def __init__(self, data_dir, processor, sample_rate=16000):
        self.data_dir = data_dir
        self.processor = processor
        self.sample_rate = sample_rate
        md = f"{data_dir}/transcripts"
        self.wav_dict = read_wav_scp(f"{md}/wav.scp")
        self.segments = read_segments(f"{md}/segments")
        self.texts = read_text(f"{md}/text")
        self.utt_ids = list(self.texts.keys())

    def __len__(self):
        return len(self.utt_ids)

    def __getitem__(self, idx):
        utt_id = self.utt_ids[idx]
        rec_id, start, end = self.segments[utt_id]
        rel_path = self.wav_dict[rec_id]
        # Try locating wav under data_dir, else metadata_dir
        wav_path1 = os.path.join(self.data_dir, rel_path)
        wav_path2 = os.path.join(os.path.dirname(wav_path1), rel_path)
        if os.path.isfile(wav_path1):
            wav_path = wav_path1
        elif os.path.isfile(wav_path2):
            wav_path = wav_path2
        else:
            raise FileNotFoundError(f"Audio file not found: {wav_path1} or {wav_path2}")
        try:
            waveform, sr = torchaudio.load(wav_path)
        except Exception:
            # Fallback to soundfile
            waveform_np, sr = sf.read(wav_path)
            # Convert numpy to torch tensor with shape [channels, time]
            if waveform_np.ndim > 1:
                waveform = torch.from_numpy(wavfile_np.T)
            else:
                waveform = torch.from_numpy(wavfile_np).unsqueeze(0)
        if sr != self.sample_rate:
            waveform = torchaudio.transforms.Resample(sr, self.sample_rate)(waveform)
        segment = waveform[0, int(start*self.sample_rate):int(end*self.sample_rate)]
        feat = self.processor.feature_extractor(
            segment.numpy(), sampling_rate=self.sample_rate, return_tensors="pt"
        ).input_features[0]
        labels = self.processor.tokenizer(
            self.texts[utt_id], return_tensors="pt", add_special_tokens=False
        ).input_ids[0]
        return {"utt_id": utt_id, "input_features": feat, "labels": labels}


def collate_fn(batch):
    feats = [b["input_features"] for b in batch]
    labs = [b["labels"] for b in batch]
    ids = [b["utt_id"] for b in batch]
    feats_p = nn.utils.rnn.pad_sequence(feats, batch_first=True)
    labs_p = nn.utils.rnn.pad_sequence(labs, batch_first=True, padding_value=-100)
    return {"utt_ids": ids, "input_features": feats_p, "labels": labs_p}

# %%
# Cell 5: Define WhisperMoEModel
class WhisperMoEModel(nn.Module):
    def __init__(self, whisper_encoder, d_model, num_classes):
        super().__init__()
        self.whisper_encoder = whisper_encoder
        self.expert_m = nn.Linear(d_model, d_model)
        self.expert_e = nn.Linear(d_model, d_model)
        self.gate = nn.Linear(2*d_model, 2)
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, input_features):
        out = self.whisper_encoder(input_features=input_features)
        shared = out.last_hidden_state
        m = self.expert_m(shared)
        e = self.expert_e(shared)
        cat = torch.cat([m, e], dim=-1)
        g = self.gate(cat)
        w = torch.softmax(g, dim=-1)
        wm, we = w[..., 0:1], w[..., 1:2]
        mix = wm*m + we*e
        logits = self.classifier(mix)
        return logits, w

# %%
# Cell 6: CTC decode and metrics
def ctc_greedy_decode(logits, blank_id):
    ids = logits.argmax(dim=-1).cpu().tolist()
    prev, out = None, []
    for i in ids:
        if i != prev and i != blank_id:
            out.append(i)
        prev = i
    return out


def edit_distance(r, h):
    m, n = len(r), len(h)
    dp = [[0]*(n+1) for _ in range(m+1)]
    for i in range(m+1): dp[i][0] = i
    for j in range(n+1): dp[0][j] = j
    for i in range(1, m+1):
        for j in range(1, n+1):
            dp[i][j] = dp[i-1][j-1] if r[i-1]==h[j-1] else 1+min(dp[i-1][j], dp[i][j-1], dp[i-1][j-1])
    return dp[m][n]


def compute_wer(preds, refs):
    te, tw = 0, 0
    for r, h in zip(refs, preds):
        rw, hw = r.split(), h.split()
        te += edit_distance(rw, hw); tw += len(rw)
    return te/tw*100 if tw>0 else 0


def compute_cer(preds, refs):
    te, tc = 0, 0
    for r, h in zip(refs, preds):
        rc, hc = list(r.replace(" ","")), list(h.replace(" ",""))
        te += edit_distance(rc, hc); tc += len(rc)
    return te/tc*100 if tc>0 else 0

# %%
# Cell 7: LLM refine function with batching, filtering & progress
def refine_with_llm(raw_texts, tokenizer_llm, model_llm, gen_conf, few_shot=None, batch_size=16, min_length=6):
    """
    Only refine utterances longer than min_length tokens.
    Short utterances are left unchanged.
    """
    device = next(model_llm.parameters()).device
    refined = list(raw_texts)
    to_refine = [i for i, txt in enumerate(raw_texts) if len(txt.split())>=min_length]
    N = len(to_refine)
    for start in range(0, N, batch_size):
        batch_idxs = to_refine[start:start+batch_size]
        prompts = []
        for idx in batch_idxs:
            raw = raw_texts[idx]
            lines = [
                "You are an ASR post-processor. Correct recognition errors in a code-switched transcript.",
                "Output only the corrected transcript without any tags or symbols."
            ]
            if few_shot:
                for gold, example in few_shot:
                    lines.append(f"Example Raw: {example}")
                    lines.append(f"Example Corrected: {gold}")
            lines.append(f"Raw: {raw}")
            prompts.append("\n".join(lines))
        print(f"Refining utterances {start+1}-{start+len(batch_idxs)}/{N}")
        if tokenizer_llm.pad_token_id is None:
            tokenizer_llm.pad_token = tokenizer_llm.eos_token
        inputs = tokenizer_llm(prompts, return_tensors="pt", padding=True).to(device)
        outputs = model_llm.generate(
            **inputs,
            generation_config=gen_conf,
            pad_token_id=tokenizer_llm.eos_token_id
        )
        for i_out, output in enumerate(outputs):
            text = tokenizer_llm.decode(output, skip_special_tokens=True).strip()
            refined[batch_idxs[i_out]] = text
    return refined

# %%
# Cell 8: Train/validate/evaluate loops
def train_epoch(model, loader, optimizer, criterion, device, blank_id):
    model.train(); total=0
    for b in loader:
        feats=b["input_features"].to(device)
        labs=b["labels"].to(device)
        logits,_=model(feats)
        logp=F.log_softmax(logits,dim=-1).transpose(0,1)
        B=logp.size(1)
        in_l=torch.full((B,), logp.size(0), dtype=torch.long).to(device)
        tgt_l=(labs!=-100).sum(dim=1).to(device)
        labs_ctc=labs.masked_fill(labs==-100, blank_id)
        loss=criterion(logp, labs_ctc, in_l, tgt_l)
        optimizer.zero_grad(); loss.backward(); optimizer.step()
        total+=loss.item()
    return total/len(loader)

def validate_epoch(model, loader, criterion, device, blank_id):
    model.eval(); total=0
    with torch.no_grad():
        for b in loader:
            feats=b["input_features"].to(device)
            labs=b["labels"].to(device)
            logits,_=model(feats)
            logp=F.log_softmax(logits,dim=-1).transpose(0,1)
            B=logp.size(1)
            in_l=torch.full((B,), logp.size(0), dtype=torch.long).to(device)
            tgt_l=(labs!=-100).sum(dim=1).to(device)
            labs_ctc=labs.masked_fill(labs==-100, blank_id)
            total+=criterion(logp, labs_ctc, in_l, tgt_l).item()
    return total/len(loader)

def evaluate(model, loader, processor, device, blank_id):
    model.eval(); preds, refs = [], []
    with torch.no_grad():
        for b in loader:
            feats=b["input_features"].to(device)
            labs=b["labels"]
            logits,_=model(feats)
            for i in range(logits.size(0)):
                ids=ctc_greedy_decode(logits[i], blank_id)
                preds.append(processor.tokenizer.decode(ids, skip_special_tokens=True))
            for l in labs:
                l=l.clone().masked_fill(l==-100, blank_id)
                refs.append(processor.tokenizer.decode(l.cpu().tolist(), skip_special_tokens=True))
    return preds, refs

# %%
# Cell 9: Main + LLM post-processing
def main():
    train_dir, test_dir = "train_split", "test_split"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    processor = WhisperProcessor.from_pretrained("openai/whisper-base")
    blank_id = processor.tokenizer.pad_token_id
    whisper = WhisperModel.from_pretrained("openai/whisper-base")
    d_model = whisper.config.d_model
    num_classes = processor.tokenizer.vocab_size

    tr_ds = KaldiDataset(train_dir, processor)
    te_ds = KaldiDataset(test_dir, processor)
    tr_ld = DataLoader(tr_ds, batch_size=8, shuffle=True, collate_fn=collate_fn)
    te_ld = DataLoader(te_ds, batch_size=8, shuffle=False, collate_fn=collate_fn)

    model = WhisperMoEModel(whisper.encoder, d_model, num_classes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CTCLoss(blank=blank_id, zero_infinity=True)

    # Train Whisper+MoE model
    best_val = float('inf')
    for ep in range(1, 76):
        # Freeze encoder first 3 epochs
        for p in model.whisper_encoder.parameters():
            p.requires_grad = (ep > 3)
        tr_loss = train_epoch(model, tr_ld, optimizer, criterion, device, blank_id)
        val_loss = validate_epoch(model, te_ld, criterion, device, blank_id)
        print(f"Epoch {ep}/75  Train Loss: {tr_loss:.3f}  Val Loss: {val_loss:.3f}")
        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), "best.pth")
            print("Saved new best model")

    # Load the best trained model
    model.load_state_dict(torch.load("best.pth"))

    # Evaluate base
    preds, refs = evaluate(model, te_ld, processor, device, blank_id)
    print("Base WER,CER:", compute_wer(preds, refs), compute_cer(preds, refs))

    # LLM few-shot refinement only
    llm_nm = "mistralai/Mistral-7B-Instruct-v0.3"
    tok_llm = AutoTokenizer.from_pretrained(llm_nm, trust_remote_code=True, token=HF_TOKEN)
    mlm = AutoModelForCausalLM.from_pretrained(
        llm_nm, device_map="auto", torch_dtype=torch.float16,
        trust_remote_code=True, token=HF_TOKEN
    )
    gen_cfg = GenerationConfig(max_new_tokens=512, do_sample=True, temperature=0.7)
    examples = [("I EN | am EN | going EN | घर HI | आज HI", "I am going घर आज")]

    print("Starting few-shot LLM refinement...")
    fs = refine_with_llm(preds, tok_llm, mlm, gen_cfg, few_shot=examples)
    print("Few-shot done.")

    # Strip tags & compute final metrics
    def strip_llm_tags(text):
        cleaned = re.sub(r"(\bEN\b|\bHI\b|\||,)", "", text)
        return " ".join(cleaned.split())
    clean_fs = [strip_llm_tags(t) for t in fs]
    print("Clean Few-shot WER,CER:", compute_wer(clean_fs, refs), compute_cer(clean_fs, refs))

    # Show sample outputs
    for i in range(min(5, len(preds))):
        print(f"\nSample {i+1}:")
        print("ASR      :", preds[i])
        print("Few-shot :", clean_fs[i])
        print("Reference:", refs[i])

if __name__ == "__main__":
    main()


Epoch 1/75  Train Loss: 299.723  Val Loss: 221.473
Saved new best model
Epoch 2/75  Train Loss: 112.520  Val Loss: 77.578
Saved new best model
Epoch 3/75  Train Loss: 54.269  Val Loss: 57.761
Saved new best model
Epoch 4/75  Train Loss: 13.481  Val Loss: 5.073
Saved new best model
Epoch 5/75  Train Loss: 4.806  Val Loss: 4.987
Saved new best model
Epoch 6/75  Train Loss: 4.559  Val Loss: 4.612
Saved new best model
Epoch 7/75  Train Loss: 4.456  Val Loss: 4.524
Saved new best model
Epoch 8/75  Train Loss: 4.363  Val Loss: 4.417
Saved new best model
Epoch 9/75  Train Loss: 4.287  Val Loss: 4.419
Epoch 10/75  Train Loss: 4.214  Val Loss: 4.263
Saved new best model
Epoch 11/75  Train Loss: 4.175  Val Loss: 4.304
Epoch 12/75  Train Loss: 4.119  Val Loss: 4.093
Saved new best model
Epoch 13/75  Train Loss: 4.026  Val Loss: 4.007
Saved new best model
Epoch 14/75  Train Loss: 3.935  Val Loss: 3.912
Saved new best model
Epoch 15/75  Train Loss: 3.866  Val Loss: 3.812
Saved new best model
Epoch 

  model.load_state_dict(torch.load("best.pth"))


Base WER,CER: 15.096668037844507 6.0834029906194855


tokenizer_config.json:   0%|          | 0.00/141k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/587k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/601 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.55G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


Starting few-shot LLM refinement...
Refining utterances 1-16/215
Refining utterances 17-32/215
Refining utterances 33-48/215
Refining utterances 49-64/215
Refining utterances 65-80/215
Refining utterances 81-96/215
Refining utterances 97-112/215
Refining utterances 113-128/215
Refining utterances 129-144/215
Refining utterances 145-160/215
Refining utterances 161-176/215
Refining utterances 177-192/215
Refining utterances 193-208/215
Refining utterances 209-215/215
Few-shot done.
Clean Few-shot WER,CER: 1390.826820238585 1328.0765301383858

Sample 1:
ASR      : दोस्तों bashें nested और multvel if statementे spoken tutorial में आपक स्वागत है
Few-shot : You are an ASR post-processor. Correct recognition errors in a code-switched transcript. Output only the corrected transcript without any tags or symbols. Example Raw: I am going घर आज Example Corrected: I am going घर आज Raw: दोस्तों bashें nested और multvel if statementे spoken tutorial में आपक स्वागत है। Corrected: FRIENDS bash nested A

# Full hybrid pipeline v2 (tighter, 50 epochs) → start LLM few-shot

In [5]:
# %%
# Cell 1: Install dependencies
import os
# Ensure sentencepiece is available for tokenizer
os.system("pip install --quiet sentencepiece")

# %%
# Cell 2: Imports & Setup
import torchaudio
import soundfile as sf  # fallback for wav loading
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    WhisperProcessor,
    WhisperModel,
    AutoTokenizer,
    AutoModelForCausalLM,
    GenerationConfig
)
import numpy as np
import re

# Hugging Face token for gated LLM models
HF_TOKEN = "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"

# %%
# Cell 3: Utility functions to read Kaldi files
def read_wav_scp(path):
    wav_dict = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            rec_id, wav_path = line.strip().split(maxsplit=1)
            wav_dict[rec_id] = wav_path
    return wav_dict


def read_segments(path):
    segments = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split()
            utt_id, rec_id, start, end = parts[0], parts[1], parts[2], parts[3]
            segments[utt_id] = (rec_id, float(start), float(end))
    return segments


def read_text(path):
    texts = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            utt_id, transcript = line.strip().split(maxsplit=1)
            texts[utt_id] = transcript
    return texts

# %%
# Cell 4: KaldiDataset and collate_fn
class KaldiDataset(Dataset):
    def __init__(self, data_dir, processor, sample_rate=16000):
        self.data_dir = data_dir
        self.processor = processor
        self.sample_rate = sample_rate
        md = f"{data_dir}/transcripts"
        self.wav_dict = read_wav_scp(f"{md}/wav.scp")
        self.segments = read_segments(f"{md}/segments")
        self.texts = read_text(f"{md}/text")
        self.utt_ids = list(self.texts.keys())

    def __len__(self):
        return len(self.utt_ids)

    def __getitem__(self, idx):
        utt_id = self.utt_ids[idx]
        rec_id, start, end = self.segments[utt_id]
        rel = self.wav_dict[rec_id]

        # Try primary and fallback paths
        p1 = os.path.join(self.data_dir, rel)
        p2 = os.path.join(self.data_dir, "transcripts", rel)
        if os.path.isfile(p1):
            wav_path = p1
        elif os.path.isfile(p2):
            wav_path = p2
        else:
            raise FileNotFoundError(f"Missing audio: {p1} or {p2}")

        # Load with torchaudio or fallback to soundfile
        try:
            waveform, sr = torchaudio.load(wav_path)
        except:
            arr, sr = sf.read(wav_path)
            arr = np.asarray(arr, dtype=np.float32)
            if arr.ndim == 1:
                waveform = torch.from_numpy(arr).unsqueeze(0)
            else:
                waveform = torch.from_numpy(arr.T)

        # Resample if needed
        if sr != self.sample_rate:
            waveform = torchaudio.transforms.Resample(sr, self.sample_rate)(waveform)

        # Slice segment, extract features, tokenize
        segment = waveform[0, int(start*self.sample_rate):int(end*self.sample_rate)]
        feat = self.processor.feature_extractor(
            segment.numpy(),
            sampling_rate=self.sample_rate,
            return_tensors="pt"
        ).input_features[0]
        labels = self.processor.tokenizer(
            self.texts[utt_id],
            return_tensors="pt",
            add_special_tokens=False
        ).input_ids[0]

        return {
            "utt_id": utt_id,
            "input_features": feat,
            "labels": labels
        }

def collate_fn(batch):
    feats = [b["input_features"] for b in batch]
    labs = [b["labels"] for b in batch]
    ids = [b["utt_id"] for b in batch]
    feats_p = nn.utils.rnn.pad_sequence(feats, batch_first=True)
    labs_p = nn.utils.rnn.pad_sequence(labs, batch_first=True, padding_value=-100)
    return {"utt_ids": ids, "input_features": feats_p, "labels": labs_p}

# %%
# Cell 5: Define WhisperMoEModel
class WhisperMoEModel(nn.Module):
    def __init__(self, whisper_encoder, d_model, num_classes):
        super().__init__()
        self.whisper_encoder = whisper_encoder
        self.expert_m = nn.Linear(d_model, d_model)
        self.expert_e = nn.Linear(d_model, d_model)
        self.gate = nn.Linear(2*d_model, 2)
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, input_features):
        out = self.whisper_encoder(input_features=input_features)
        shared = out.last_hidden_state
        m = self.expert_m(shared)
        e = self.expert_e(shared)
        cat = torch.cat([m, e], dim=-1)
        g = self.gate(cat)
        w = torch.softmax(g, dim=-1)
        wm, we = w[..., 0:1], w[..., 1:2]
        mix = wm*m + we*e
        logits = self.classifier(mix)
        return logits, w

# %%
# Cell 6: CTC decode and metrics
def ctc_greedy_decode(logits, blank_id):
    ids = logits.argmax(dim=-1).cpu().tolist()
    prev, out = None, []
    for i in ids:
        if i != prev and i != blank_id:
            out.append(i)
        prev = i
    return out


def edit_distance(r, h):
    m, n = len(r), len(h)
    dp = [[0]*(n+1) for _ in range(m+1)]
    for i in range(m+1): dp[i][0] = i
    for j in range(n+1): dp[0][j] = j
    for i in range(1, m+1):
        for j in range(1, n+1):
            dp[i][j] = dp[i-1][j-1] if r[i-1]==h[j-1] else 1+min(dp[i-1][j], dp[i][j-1], dp[i-1][j-1])
    return dp[m][n]


def compute_wer(preds, refs):
    te, tw = 0, 0
    for r, h in zip(refs, preds):
        rw, hw = r.split(), h.split()
        te += edit_distance(rw, hw); tw += len(rw)
    return te/tw*100 if tw>0 else 0


def compute_cer(preds, refs):
    te, tc = 0, 0
    for r, h in zip(refs, preds):
        rc, hc = list(r.replace(" ","")), list(h.replace(" ",""))
        te += edit_distance(rc, hc); tc += len(rc)
    return te/tc*100 if tc>0 else 0

# %%
# Cell 7: LLM refine function with batching, filtering & progress
def refine_with_llm(raw_texts, tokenizer_llm, model_llm, gen_conf,
                    few_shot=None, batch_size=16, min_length=6):
    device = next(model_llm.parameters()).device
    refined = list(raw_texts)
    to_refine = [i for i, t in enumerate(raw_texts) if len(t.split()) >= min_length]

    for start in range(0, len(to_refine), batch_size):
        batch_idxs = to_refine[start:start + batch_size]
        prompts = []
        for idx in batch_idxs:
            lines = [
                "You are an ASR post-processor. Correct recognition errors and output just the corrected sentence.",
            ]
            if few_shot:
                for gold, example in few_shot:
                    lines.append(f"Example Raw: {example}")
                    lines.append(f"Example Corrected: {gold}")
            lines.append(f"Raw: {raw_texts[idx]}")
            lines.append("Corrected:")
            prompts.append("\n".join(lines))

        print(f"Refining utterances {start+1}-{start+len(batch_idxs)}")
        if tokenizer_llm.pad_token_id is None:
            tokenizer_llm.pad_token = tokenizer_llm.eos_token

        inputs = tokenizer_llm(prompts, return_tensors="pt", padding=True).to(device)
        outputs = model_llm.generate(
            **inputs,
            generation_config=gen_conf,
            pad_token_id=tokenizer_llm.eos_token_id,
        )
        for i, out in enumerate(outputs):
            text = tokenizer_llm.decode(out, skip_special_tokens=True).strip()
            refined[batch_idxs[i]] = text

    return refined

# %%
# Cell 8: Train/validate/evaluate loops
def train_epoch(model, loader, optimizer, criterion, device, blank_id):
    model.train(); total=0
    for b in loader:
        feats=b["input_features"].to(device)
        labs=b["labels"].to(device)
        logits,_=model(feats)
        logp=F.log_softmax(logits,dim=-1).transpose(0,1)
        B=logp.size(1)
        in_l=torch.full((B,), logp.size(0), dtype=torch.long).to(device)
        tgt_l=(labs!=-100).sum(dim=1).to(device)
        labs_ctc=labs.masked_fill(labs==-100, blank_id)
        loss=criterion(logp, labs_ctc, in_l, tgt_l)
        optimizer.zero_grad(); loss.backward(); optimizer.step()
        total+=loss.item()
    return total/len(loader)

def validate_epoch(model, loader, criterion, device, blank_id):
    model.eval(); total=0
    with torch.no_grad():
        for b in loader:
            feats=b["input_features"].to(device)
            labs=b["labels"].to(device)
            logits,_=model(feats)
            logp=F.log_softmax(logits,dim=-1).transpose(0,1)
            B=logp.size(1)
            in_l=torch.full((B,), logp.size(0), dtype=torch.long).to(device)
            tgt_l=(labs!=-100).sum(dim=1).to(device)
            labs_ctc=labs.masked_fill(labs==-100, blank_id)
            total+=criterion(logp, labs_ctc, in_l, tgt_l).item()
    return total/len(loader)

def evaluate(model, loader, processor, device, blank_id):
    model.eval(); preds, refs = [], []
    with torch.no_grad():
        for b in loader:
            feats=b["input_features"].to(device)
            labs=b["labels"]
            logits,_=model(feats)
            for i in range(logits.size(0)):
                ids=ctc_greedy_decode(logits[i], blank_id)
                preds.append(processor.tokenizer.decode(ids, skip_special_tokens=True))
            for l in labs:
                l=l.clone().masked_fill(l==-100, blank_id)
                refs.append(processor.tokenizer.decode(l.cpu().tolist(), skip_special_tokens=True))
    return preds, refs

# %%
# Cell 9: Main + LLM post-processing
def main():
    train_dir, test_dir = "train_split", "test_split"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    processor = WhisperProcessor.from_pretrained("openai/whisper-base")
    blank_id = processor.tokenizer.pad_token_id
    whisper = WhisperModel.from_pretrained("openai/whisper-base")
    d_model = whisper.config.d_model
    num_classes = processor.tokenizer.vocab_size

    tr_ds = KaldiDataset(train_dir, processor)
    te_ds = KaldiDataset(test_dir, processor)
    tr_ld = DataLoader(tr_ds, batch_size=8, shuffle=True, collate_fn=collate_fn)
    te_ld = DataLoader(te_ds, batch_size=8, shuffle=False, collate_fn=collate_fn)

    model = WhisperMoEModel(whisper.encoder, d_model, num_classes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CTCLoss(blank=blank_id, zero_infinity=True)

    # Train Whisper+MoE model
    best_val = float('inf')
    for ep in range(1, 51):
        # Freeze encoder first 3 epochs
        for p in model.whisper_encoder.parameters():
            p.requires_grad = (ep > 3)
        tr_loss = train_epoch(model, tr_ld, optimizer, criterion, device, blank_id)
        val_loss = validate_epoch(model, te_ld, criterion, device, blank_id)
        print(f"Epoch {ep}/50  Train Loss: {tr_loss:.3f}  Val Loss: {val_loss:.3f}")
        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), "best.pth")
            print("Saved new best model")

    # Load the best trained model
    model.load_state_dict(torch.load("best.pth"))

    # Evaluate base
    preds, refs = evaluate(model, te_ld, processor, device, blank_id)
    print("Base WER,CER:", compute_wer(preds, refs), compute_cer(preds, refs))

    # LLM few-shot refinement only
    llm_nm = "mistralai/Mistral-7B-Instruct-v0.3"
    tok_llm = AutoTokenizer.from_pretrained(llm_nm, trust_remote_code=True, token=HF_TOKEN)
    mlm = AutoModelForCausalLM.from_pretrained(
        llm_nm, device_map="auto", torch_dtype=torch.float16,
        trust_remote_code=True, token=HF_TOKEN
    )
    gen_cfg = GenerationConfig(max_new_tokens=512, do_sample=True, temperature=0.7)
    examples = [("I EN | am EN | going EN | घर HI | आज HI", "I am going घर आज")]

    print("Starting few-shot LLM refinement...")
    fs = refine_with_llm(preds, tok_llm, mlm, gen_cfg, few_shot=examples)
    print("Few-shot done.")

    # Strip tags & compute final metrics
    def strip_llm_tags(text):
        cleaned = re.sub(r"(\bEN\b|\bHI\b|\||,)", "", text)
        return " ".join(cleaned.split())
    clean_fs = [strip_llm_tags(t) for t in fs]
    print("Clean Few-shot WER,CER:", compute_wer(clean_fs, refs), compute_cer(clean_fs, refs))

    # Show sample outputs
    for i in range(min(5, len(preds))):
        print(f"\nSample {i+1}:")
        print("ASR      :", preds[i])
        print("Few-shot :", clean_fs[i])
        print("Reference:", refs[i])

if __name__ == "__main__":
    main()



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m


Epoch 1/50  Train Loss: 292.826  Val Loss: 131.995
Saved new best model
Epoch 2/50  Train Loss: 78.031  Val Loss: 58.970
Saved new best model
Epoch 3/50  Train Loss: 37.485  Val Loss: 33.101
Saved new best model
Epoch 4/50  Train Loss: 10.436  Val Loss: 5.463
Saved new best model
Epoch 5/50  Train Loss: 4.972  Val Loss: 4.843
Saved new best model
Epoch 6/50  Train Loss: 4.621  Val Loss: 4.649
Saved new best model
Epoch 7/50  Train Loss: 4.454  Val Loss: 4.627
Saved new best model
Epoch 8/50  Train Loss: 4.346  Val Loss: 4.389
Saved new best model
Epoch 9/50  Train Loss: 4.232  Val Loss: 4.230
Saved new best model
Epoch 10/50  Train Loss: 4.104  Val Loss: 4.093
Saved new best model
Epoch 11/50  Train Loss: 3.962  Val Loss: 3.868
Saved new best model
Epoch 12/50  Train Loss: 3.645  Val Loss: 3.320
Saved new best model
Epoch 13/50  Train Loss: 3.039  Val Loss: 2.674
Saved new best model
Epoch 14/50  Train Loss: 2.473  Val Loss: 2.282
Saved new best model
Epoch 15/50  Train Loss: 2.009  Va

  model.load_state_dict(torch.load("best.pth"))


Base WER,CER: 0.20567667626491154 0.07430110522894028


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the cpu.


Starting few-shot LLM refinement...
Refining utterances 1-16
Refining utterances 17-32
Refining utterances 33-48
Refining utterances 49-64
Refining utterances 65-80
Refining utterances 81-96
Refining utterances 97-112
Refining utterances 113-128
Refining utterances 129-144
Refining utterances 145-160
Refining utterances 161-176
Refining utterances 177-192
Refining utterances 193-208
Refining utterances 209-214
Few-shot done.
Clean Few-shot WER,CER: 1295.7630604689427 1282.9571839881119

Sample 1:
ASR      : दोस्तों bash में nested और multilevel if statement के spoken tutorial में आपका स्वागत है
Few-shot : You are an ASR post-processor. Correct recognition errors and output just the corrected sentence. Example Raw: I am going घर आज Example Corrected: I am going घर आज Raw: दोस्तों bash में nested और multilevel if statement के spoken tutorial में आपका स्वागत है Corrected: FRIENDS in BASH nested AND multilevel IF STATEMENT spoken TUTORIAL is your welcome Raw: सब कुछ कैसे हो रहा है Correcte

# LLM refine helper (variant)

- Defines a stricter refine_with_llm batched generator (min length filter, prompt that asks for compact outputs).

In [7]:
# Cell 7: LLM refine function with batching, filtering & progress
def refine_with_llm(raw_texts, tokenizer_llm, model_llm, gen_conf, few_shot=None, batch_size=16, min_length=6):
    """
    Only refine utterances longer than min_length tokens.
    Short utterances are left unchanged.
    Uses a tight prompt to only output tags and language labels.
    """
    device = next(model_llm.parameters()).device
    refined = list(raw_texts)
    to_refine = [i for i, t in enumerate(raw_texts) if len(t.split()) >= min_length]
    N = len(to_refine)
    for start in range(0, N, batch_size):
        batch_idxs = to_refine[start:start+batch_size]
        prompts = []
        for idx in batch_idxs:
            raw = raw_texts[idx]
            # Strong prompt: only tagging, no translation or paraphrase
            lines = [
                "You are an ASR post-processor for Hindi–English code-switched speech.",
                "Given the raw ASR output, preserve every original word and insert '|' at language boundaries.",
                "After each word, append 'EN' or 'HI' to label its language.",
                "Do NOT translate, paraphrase, or add any extra words.",
                f"Raw: {raw}",
                "Corrected:"
            ]
            prompts.append("\n".join(lines))
        print(f"Refining utterances {start+1}-{start+len(batch_idxs)} of {N}")
        # ensure pad token
        if tokenizer_llm.pad_token_id is None:
            tokenizer_llm.pad_token = tokenizer_llm.eos_token
        inputs = tokenizer_llm(prompts, return_tensors="pt", padding=True).to(device)
        outputs = model_llm.generate(
            **inputs,
            generation_config=gen_conf,
            pad_token_id=tokenizer_llm.eos_token_id,
            max_new_tokens=128,
            eos_token_id=tokenizer_llm.eos_token_id
        )
        for i_out, out in enumerate(outputs):
            # slice off prompt tokens
            new_tokens = out[inputs.input_ids.shape[-1]:]
            text = tokenizer_llm.decode(new_tokens, skip_special_tokens=True).strip()
            refined[batch_idxs[i_out]] = text
    return refined


# Evaluate few-shot run from Cell 4 + strip tags

What it does:

- Reloads best.pth, recomputes Base WER/CER for reference.
- Runs the few-shot Mistral refiner (batched) and applies a post-cleaner (strip_llm_tags) before scoring.
- Prints a few ASR vs Few-shot vs Reference samples.

In [8]:
# Cell 9: Main + LLM post-processing
def main():
    train_dir, test_dir = "train_split", "test_split"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    processor = WhisperProcessor.from_pretrained("openai/whisper-base")
    blank_id = processor.tokenizer.pad_token_id
    whisper = WhisperModel.from_pretrained("openai/whisper-base")
    d_model = whisper.config.d_model
    num_classes = processor.tokenizer.vocab_size

    tr_ds = KaldiDataset(train_dir, processor)
    te_ds = KaldiDataset(test_dir, processor)
    tr_ld = DataLoader(tr_ds, batch_size=8, shuffle=True, collate_fn=collate_fn)
    te_ld = DataLoader(te_ds, batch_size=8, shuffle=False, collate_fn=collate_fn)

    model = WhisperMoEModel(whisper.encoder, d_model, num_classes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CTCLoss(blank=blank_id, zero_infinity=True)

    # Train Whisper+MoE model
    best_val = float('inf')
    '''for ep in range(1, 51):
        # Freeze encoder first 3 epochs
        for p in model.whisper_encoder.parameters():
            p.requires_grad = (ep > 3)
        tr_loss = train_epoch(model, tr_ld, optimizer, criterion, device, blank_id)
        val_loss = validate_epoch(model, te_ld, criterion, device, blank_id)
        print(f"Epoch {ep}/50  Train Loss: {tr_loss:.3f}  Val Loss: {val_loss:.3f}")
        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), "best.pth")
            print("Saved new best model")
    '''
    # Load the best trained model
    model.load_state_dict(torch.load("best.pth"))

    # Evaluate base
    preds, refs = evaluate(model, te_ld, processor, device, blank_id)
    print("Base WER,CER:", compute_wer(preds, refs), compute_cer(preds, refs))

    # LLM few-shot refinement only
    llm_nm = "mistralai/Mistral-7B-Instruct-v0.3"
    tok_llm = AutoTokenizer.from_pretrained(llm_nm, trust_remote_code=True, token=HF_TOKEN)
    mlm = AutoModelForCausalLM.from_pretrained(
        llm_nm, device_map="auto", torch_dtype=torch.float16,
        trust_remote_code=True, token=HF_TOKEN
    )
    gen_cfg = GenerationConfig(max_new_tokens=512, do_sample=True, temperature=0.7)
    examples = [("I EN | am EN | going EN | घर HI | आज HI", "I am going घर आज")]

    print("Starting few-shot LLM refinement...")
    fs = refine_with_llm(preds, tok_llm, mlm, gen_cfg, few_shot=examples)
    print("Few-shot done.")

    # Strip tags & compute final metrics
    def strip_llm_tags(text):
        cleaned = re.sub(r"(\bEN\b|\bHI\b|\||,)", "", text)
        return " ".join(cleaned.split())
    clean_fs = [strip_llm_tags(t) for t in fs]
    print("Clean Few-shot WER,CER:", compute_wer(clean_fs, refs), compute_cer(clean_fs, refs))

    # Show sample outputs
    for i in range(min(5, len(preds))):
        print(f"\nSample {i+1}:")
        print("ASR      :", preds[i])
        print("Few-shot :", clean_fs[i])
        print("Reference:", refs[i])

if __name__ == "__main__":
    main()


  model.load_state_dict(torch.load("best.pth"))


Base WER,CER: 0.20567667626491154 0.07430110522894028


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Starting few-shot LLM refinement...
Refining utterances 1-16 of 214
Refining utterances 17-32 of 214
Refining utterances 33-48 of 214
Refining utterances 49-64 of 214
Refining utterances 65-80 of 214
Refining utterances 81-96 of 214
Refining utterances 97-112 of 214
Refining utterances 113-128 of 214
Refining utterances 129-144 of 214
Refining utterances 145-160 of 214
Refining utterances 161-176 of 214
Refining utterances 177-192 of 214
Refining utterances 193-208 of 214
Refining utterances 209-214 of 214
Few-shot done.
Clean Few-shot WER,CER: 188.52324146441794 186.43076065756478

Sample 1:
ASR      : दोस्तों bash में nested और multilevel if statement के spoken tutorial में आपका स्वागत है
Few-shot : दोस्तों बash में nested और multilevel if statement के spoken tutorial में आपका स्वागत है In this example we have a Hindi–English code-switched speech and the task is to post-process the raw ASR output to preserve every original word and insert '' at language boundaries. After each word
Re

# Compact hybrid pipeline v3 (same idea, tighter code) → evaluate → LLM refine

In [15]:
# %%
# Cell 1: Install dependencies
import os
# Ensure required packages
os.system("pip install --quiet sentencepiece soundfile transformers torchaudio")

# %%
# Cell 2: Imports & Setup
import soundfile as sf
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader
from transformers import (
    WhisperProcessor,
    WhisperModel,
    AutoTokenizer,
    AutoModelForCausalLM,
    GenerationConfig
)
import numpy as np
import re

# HF token
HF_TOKEN = "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"

# %%
# Cell 3: Kaldi utilities
def read_wav_scp(path):
    wav_dict = {}
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            rec, p = line.strip().split(maxsplit=1)
            wav_dict[rec] = p
    return wav_dict

def read_segments(path):
    seg = {}
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            u, r, s, e = line.strip().split()[:4]
            seg[u] = (r, float(s), float(e))
    return seg

def read_text(path):
    txt = {}
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            u, t = line.strip().split(maxsplit=1)
            txt[u] = t
    return txt

# %%
# Cell 4: Dataset
class KaldiDataset(Dataset):
    def __init__(self, data_dir, processor, sr=16000):
        self.dir = data_dir
        self.proc = processor
        self.sr = sr
        m = os.path.join(data_dir, "transcripts")
        self.wav = read_wav_scp(os.path.join(m, "wav.scp"))
        self.seg = read_segments(os.path.join(m, "segments"))
        self.txt = read_text(os.path.join(m, "text"))
        self.ids = list(self.txt.keys())

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        u = self.ids[i]
        r, st, en = self.seg[u]
        # resolve audio path
        p1 = os.path.join(self.dir, self.wav[r])
        p2 = os.path.join(self.dir, "transcripts", self.wav[r])
        if os.path.isfile(p1):
            path = p1
        elif os.path.isfile(p2):
            path = p2
        else:
            raise FileNotFoundError(f"Missing audio file: {p1} or {p2}")
        # load waveform
        try:
            waveform, sr = torchaudio.load(path)
        except Exception:
            arr, sr = sf.read(path)
            arr = np.asarray(arr, dtype=np.float32)
            if arr.ndim > 1:
                waveform = torch.from_numpy(arr.T)
            else:
                waveform = torch.from_numpy(arr).unsqueeze(0)
        # resample
        if sr != self.sr:
            waveform = torchaudio.transforms.Resample(sr, self.sr)(waveform)
        # slice segment
        segment = waveform[0, int(st*self.sr):int(en*self.sr)]
        # feature extraction
        feat = self.proc.feature_extractor(
            segment.numpy(), sampling_rate=self.sr, return_tensors='pt'
        ).input_features[0]
        # labels
        lbl = self.proc.tokenizer(
            self.txt[u], return_tensors='pt', add_special_tokens=False
        ).input_ids[0]
        return {"id": u, "feat": feat, "lbl": lbl}

# Cell 5: Model definition
class WhisperMoE(nn.Module):
    def __init__(self, encoder, d_model, vocab_size):
        super().__init__()
        self.enc = encoder
        self.exp_m = nn.Linear(d_model, d_model)
        self.exp_e = nn.Linear(d_model, d_model)
        self.gate = nn.Linear(2*d_model,2)
        self.cls = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        h = self.enc(input_features=x).last_hidden_state
        m = self.exp_m(h);
        e = self.exp_e(h)
        g = torch.softmax(self.gate(torch.cat([m,e],-1)), -1)
        mix = g[...,0:1]*m + g[...,1:2]*e
        return self.cls(mix), g

# %%
# Cell 6: CTC utilities
def decode_ctc(logits, blank):
    ids = logits.argmax(-1).tolist()
    out=[]; prev=None
    for i in ids:
        if i!=prev and i!=blank: out.append(i)
        prev=i
    return out

def edit_dist(a,b):
    m,n=len(a),len(b)
    dp=[[0]*(n+1) for _ in range(m+1)]
    for i in range(m+1): dp[i][0]=i
    for j in range(n+1): dp[0][j]=j
    for i in range(1,m+1):
        for j in range(1,n+1):
            dp[i][j]=dp[i-1][j-1] if a[i-1]==b[j-1] else 1+min(dp[i-1][j],dp[i][j-1],dp[i-1][j-1])
    return dp[m][n]

def wer(p,r):
    tw=0;te=0
    for pr,rr in zip(p,r):
        pw, rw=pr.split(), rr.split()
        te+=edit_dist(rw,pw); tw+=len(rw)
    return te/tw*100

def cer(p,r):
    tc=0;te=0
    for pr,rr in zip(p,r):
        pc=list(pr.replace(' ','')); rc=list(rr.replace(' ',''))
        te+=edit_dist(rc,pc); tc+=len(rc)
    return te/tc*100

# %%
# Cell 7: Single-utterance LLM refine
def refine_one(raw, tokenizer_llm, model_llm, gen_conf, device):
    prompt = f"Raw: {raw}\nTag each word with EN or HI, separated by '|'. Do not translate.\nCorrected:"
    inp = tokenizer_llm(prompt, return_tensors='pt').to(device)
    out = model_llm.generate(
        **inp,
        max_new_tokens=64,
        eos_token_id=tokenizer_llm.eos_token_id,
        pad_token_id=tokenizer_llm.eos_token_id
    )
    new = out[0][inp.input_ids.size(1):]
    text = tokenizer_llm.decode(new, skip_special_tokens=True).split('\n')[0].strip()
    return text

# %%
# Cell 8: Training loops
def train_epoch(m,dl,opt,crit,dev,blank):
    m.train();loss=0
    for b in dl:
        x=b['feat'].to(dev); y=b['lbl'].to(dev)
        logit,_=m(x)
        lp=F.log_softmax(logit,-1).transpose(0,1)
        L=lp.size(0)
        inl=torch.full((lp.size(1),),L,dtype=torch.long,device=dev)
        tl=(y!=-100).sum(1).to(dev)
        y2=y.masked_fill(y==-100,blank)
        l=crit(lp,y2,inl,tl)
        opt.zero_grad();l.backward();opt.step();loss+=l.item()
    return loss/len(dl)

def val_epoch(m,dl,crit,dev,blank):
    m.eval();loss=0
    with torch.no_grad():
        for b in dl:
            x=b['feat'].to(dev); y=b['lbl'].to(dev)
            logit,_=m(x)
            lp=F.log_softmax(logit,-1).transpose(0,1)
            L=lp.size(0)
            inl=torch.full((lp.size(1),),L,dtype=torch.long,device=dev)
            tl=(y!=-100).sum(1).to(dev)
            y2=y.masked_fill(y==-100,blank)
            loss+=crit(lp,y2,inl,tl).item()
    return loss/len(dl)

def evaluate(m,dl,proc,dev,blank):
    m.eval();preds,refs=[],[]
    with torch.no_grad():
        for b in dl:
            x=b['feat'].to(dev); y=b['lbl']
            logit,_=m(x)
            for i in range(logit.size(0)):
                ids=decode_ctc(logit[i],blank)
                preds.append(proc.tokenizer.decode(ids,skip_special_tokens=True))
            for yy in y:
                t=yy.clone().masked_fill(yy==-100,blank).tolist()
                refs.append(proc.tokenizer.decode(t,skip_special_tokens=True))
    return preds,refs

# %%
# Cell 9: Main
def main():
    train_dir, test_dir="train_split","test_split"
    dev=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    proc=WhisperProcessor.from_pretrained('openai/whisper-base')
    blank=proc.tokenizer.pad_token_id
    w=WhisperModel.from_pretrained('openai/whisper-base')
    d=w.config.d_model; vocab=proc.tokenizer.vocab_size
    ds_tr=KaldiDataset(train_dir,proc);ds_te=KaldiDataset(test_dir,proc)
    dl_tr=DataLoader(ds_tr,batch_size=8,shuffle=True,collate_fn=collate_fn)
    dl_te=DataLoader(ds_te,batch_size=8,shuffle=False,collate_fn=collate_fn)

    model=WhisperMoE(w.encoder,d,vocab).to(dev)
    opt=optim.Adam(model.parameters(),1e-4)
    crit=nn.CTCLoss(blank=blank,zero_infinity=True)

    # Train
    best=1e9
    for ep in range(1,51):
        for p in model.enc.parameters(): p.requires_grad=(ep>3)
        tr=train_epoch(model,dl_tr,opt,crit,dev,blank)
        vl=val_epoch(model,dl_te,crit,dev,blank)
        print(f"Epoch {ep} TL:{tr:.3f} VL:{vl:.3f}")
        if vl<best: best=vl;torch.save(model.state_dict(),'best.pth')

    # Load & base eval
    model.load_state_dict(torch.load('best.pth'))
    preds,refs=evaluate(model,dl_te,proc,dev,blank)
    print('Base WER,CER:',wer(preds,refs),cer(preds,refs))

    # LLM refine one by one
    llm_nm='mistralai/Mistral-7B-Instruct-v0.3'
    tok=AutoTokenizer.from_pretrained(llm_nm,token=HF_TOKEN,trust_remote_code=True)
    mlm=AutoModelForCausalLM.from_pretrained(llm_nm,device_map='auto',torch_dtype=torch.float16,token=HF_TOKEN,trust_remote_code=True)
    gen=GenerationConfig(max_new_tokens=64,do_sample=False)

    print('Starting LLM refinement...')
    refined=[]
    for utt in preds:
        if len(utt.split())<6: refined.append(utt)
        else: refined.append(refine_one(utt,tok,mlm,gen,dev))

    # Clean tags & eval on actual text
    def clean(t): return ' '.join(w for w in t.split() if w not in {'EN','HI','|'})
    clean_ref=[clean(t) for t in refined]
    print('LLM WER,CER:',wer(clean_ref,refs),cer(clean_ref,refs))

    # Samples
    for i in range(5): print(f"Sample {i+1}: ASR={preds[i]}\nRefined={clean_ref[i]}\nRef={refs[i]}\n")

if __name__=='__main__': main()



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m


Epoch 1 TL:340.697 VL:185.992
Epoch 2 TL:83.089 VL:42.096
Epoch 3 TL:28.839 VL:26.404
Epoch 4 TL:6.243 VL:5.460
Epoch 5 TL:4.823 VL:4.831
Epoch 6 TL:4.604 VL:4.734
Epoch 7 TL:4.495 VL:4.521
Epoch 8 TL:4.404 VL:4.469
Epoch 9 TL:4.321 VL:4.385
Epoch 10 TL:4.236 VL:4.374
Epoch 11 TL:4.221 VL:4.389
Epoch 12 TL:4.229 VL:4.179
Epoch 13 TL:4.092 VL:4.139
Epoch 14 TL:4.029 VL:3.957
Epoch 15 TL:3.911 VL:3.929
Epoch 16 TL:3.845 VL:4.091
Epoch 17 TL:3.866 VL:3.680
Epoch 18 TL:3.711 VL:3.610
Epoch 19 TL:3.682 VL:3.589
Epoch 20 TL:3.613 VL:3.399
Epoch 21 TL:3.555 VL:3.283
Epoch 22 TL:3.400 VL:3.241
Epoch 23 TL:3.349 VL:3.122
Epoch 24 TL:3.297 VL:3.119
Epoch 25 TL:3.174 VL:2.971
Epoch 26 TL:3.093 VL:2.901
Epoch 27 TL:3.021 VL:2.723
Epoch 28 TL:2.919 VL:2.689
Epoch 29 TL:2.805 VL:2.501
Epoch 30 TL:2.716 VL:2.483
Epoch 31 TL:2.684 VL:2.333
Epoch 32 TL:2.562 VL:2.220
Epoch 33 TL:2.512 VL:2.207
Epoch 34 TL:2.382 VL:2.101
Epoch 35 TL:2.301 VL:2.129
Epoch 36 TL:2.280 VL:1.975
Epoch 37 TL:2.175 VL:1.883
Ep

  model.load_state_dict(torch.load('best.pth'))


Base WER,CER: 65.61085972850678 47.673446642518805


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Starting LLM refinement...
LLM WER,CER: 87.08350473056356 73.87387387387388
Sample 1: ASR=दो्ंेilevelे प�ा्वागत है
Refined=दो्ंेilevelे प�ा्वागत है
Ref=दोस्तों bash में nested और multilevel if statement के spoken tutorial में आपका स्वागत है

Sample 2: ASR=इस म ह नमने�रे मं सीखेंगे
Refined=Is M| this man| namane|ere| man| sikh|enge|g|e|
Ref=इस tutorial में हम निम्न के बारे में सीखेंगे

Sample 3: ASR=nested और
Refined=nested और
Ref=nested ifelse और

Sample 4: ASR=multileelse statement
Refined=multileelse statement
Ref=multilevel ifelse statement

Sample 5: ASR=हमु� ��ार ���यग करके करेंगे
Refined=हमु� ��ार ���यग करके करेंगे
Ref=हम यह कुछ उदाहरण उपयोग करके करेंगे



# Compact hybrid pipeline v4 (alt schedule) → evaluate → LLM refine

In [18]:
# %%
# Cell 1: Install dependencies
import os
# Ensure required packages
os.system("pip install --quiet sentencepiece soundfile transformers torchaudio")

# %%
# Cell 2: Imports & Setup
import soundfile as sf
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader
from transformers import (
    WhisperProcessor,
    WhisperModel,
    AutoTokenizer,
    AutoModelForCausalLM,
    GenerationConfig
)
import numpy as np
import re

# HF token
HF_TOKEN = "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"

# %%
# Cell 3: Kaldi utilities
def read_wav_scp(path):
    wav_dict = {}
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            rec, p = line.strip().split(maxsplit=1)
            wav_dict[rec] = p
    return wav_dict

def read_segments(path):
    seg = {}
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            u, r, s, e = line.strip().split()[:4]
            seg[u] = (r, float(s), float(e))
    return seg

def read_text(path):
    txt = {}
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            u, t = line.strip().split(maxsplit=1)
            txt[u] = t
    return txt

# %%
# Cell 4: Dataset
class KaldiDataset(Dataset):
    def __init__(self, data_dir, processor, sr=16000):
        self.dir = data_dir
        self.proc = processor
        self.sr = sr
        m = os.path.join(data_dir, "transcripts")
        self.wav = read_wav_scp(os.path.join(m, "wav.scp"))
        self.seg = read_segments(os.path.join(m, "segments"))
        self.txt = read_text(os.path.join(m, "text"))
        self.ids = list(self.txt.keys())

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        u = self.ids[i]
        r, st, en = self.seg[u]
        p1 = os.path.join(self.dir, self.wav[r])
        p2 = os.path.join(self.dir, "transcripts", self.wav[r])
        if os.path.isfile(p1):
            path = p1
        elif os.path.isfile(p2):
            path = p2
        else:
            raise FileNotFoundError(f"Missing audio file: {p1} or {p2}")
        try:
            waveform, sr = torchaudio.load(path)
        except Exception:
            arr, sr = sf.read(path)
            arr = np.asarray(arr, dtype=np.float32)
            if arr.ndim > 1:
                waveform = torch.from_numpy(arr.T)
            else:
                waveform = torch.from_numpy(arr).unsqueeze(0)
        if sr != self.sr:
            waveform = torchaudio.transforms.Resample(sr, self.sr)(waveform)
        segment = waveform[0, int(st*self.sr):int(en*self.sr)]
        input_features = self.proc.feature_extractor(
            segment.numpy(), sampling_rate=self.sr, return_tensors='pt'
        ).input_features[0]
        labels = self.proc.tokenizer(
            self.txt[u], return_tensors='pt', add_special_tokens=False
        ).input_ids[0]
        return {"utt_id": u, "feat": input_features, "lbl": labels}

# Collate function for DataLoader
def collate_fn(batch):
    feats = [b['feat'] for b in batch]
    labs = [b['lbl'] for b in batch]
    utt_ids = [b['utt_id'] for b in batch]
    feats_p = nn.utils.rnn.pad_sequence(feats, batch_first=True)
    labs_p = nn.utils.rnn.pad_sequence(labs, batch_first=True, padding_value=-100)
    return {"utt_ids": utt_ids, "feat": feats_p, "lbl": labs_p}

# Cell 5: Model definition
class WhisperMoE(nn.Module):
    def __init__(self, encoder, d_model, vocab_size):
        super().__init__()
        self.enc = encoder
        self.exp_m = nn.Linear(d_model, d_model)
        self.exp_e = nn.Linear(d_model, d_model)
        self.gate = nn.Linear(2*d_model,2)
        self.cls = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        h = self.enc(input_features=x).last_hidden_state
        m = self.exp_m(h);
        e = self.exp_e(h)
        g = torch.softmax(self.gate(torch.cat([m,e],-1)), -1)
        mix = g[...,0:1]*m + g[...,1:2]*e
        return self.cls(mix), g

# %%
# Cell 6: CTC utilities
def decode_ctc(logits, blank):
    ids = logits.argmax(-1).tolist()
    out=[]; prev=None
    for i in ids:
        if i!=prev and i!=blank: out.append(i)
        prev=i
    return out

def edit_dist(a,b):
    m,n=len(a),len(b)
    dp=[[0]*(n+1) for _ in range(m+1)]
    for i in range(m+1): dp[i][0]=i
    for j in range(n+1): dp[0][j]=j
    for i in range(1,m+1):
        for j in range(1,n+1):
            dp[i][j]=dp[i-1][j-1] if a[i-1]==b[j-1] else 1+min(dp[i-1][j],dp[i][j-1],dp[i-1][j-1])
    return dp[m][n]

def wer(p,r):
    tw=0;te=0
    for pr,rr in zip(p,r):
        pw, rw=pr.split(), rr.split()
        te+=edit_dist(rw,pw); tw+=len(rw)
    return te/tw*100

def cer(p,r):
    tc=0;te=0
    for pr,rr in zip(p,r):
        pc=list(pr.replace(' ','')); rc=list(rr.replace(' ',''))
        te+=edit_dist(rc,pc); tc+=len(rc)
    return te/tc*100

# %%
# Cell 7: Single-utterance LLM refine
def refine_one(raw, tokenizer_llm, model_llm, gen_conf, device):
    prompt = f"Raw: {raw}\nTag each word with EN or HI, separated by '|'. Do not translate.\nCorrected:"
    inp = tokenizer_llm(prompt, return_tensors='pt').to(device)
    out = model_llm.generate(
        **inp,
        max_new_tokens=64,
        eos_token_id=tokenizer_llm.eos_token_id,
        pad_token_id=tokenizer_llm.eos_token_id
    )
    new = out[0][inp.input_ids.size(1):]
    text = tokenizer_llm.decode(new, skip_special_tokens=True).split('\n')[0].strip()
    return text

# %%
# Cell 8: Training loops
def train_epoch(m,dl,opt,crit,dev,blank):
    m.train();loss=0
    for b in dl:
        x=b['feat'].to(dev); y=b['lbl'].to(dev)
        logit,_=m(x)
        lp=F.log_softmax(logit,-1).transpose(0,1)
        L=lp.size(0)
        inl=torch.full((lp.size(1),),L,dtype=torch.long,device=dev)
        tl=(y!=-100).sum(1).to(dev)
        y2=y.masked_fill(y==-100,blank)
        l=crit(lp,y2,inl,tl)
        opt.zero_grad();l.backward();opt.step();loss+=l.item()
    return loss/len(dl)

def val_epoch(m,dl,crit,dev,blank):
    m.eval();loss=0
    with torch.no_grad():
        for b in dl:
            x=b['feat'].to(dev); y=b['lbl'].to(dev)
            logit,_=m(x)
            lp=F.log_softmax(logit,-1).transpose(0,1)
            L=lp.size(0)
            inl=torch.full((lp.size(1),),L,dtype=torch.long,device=dev)
            tl=(y!=-100).sum(1).to(dev)
            y2=y.masked_fill(y==-100,blank)
            loss+=crit(lp,y2,inl,tl).item()
    return loss/len(dl)

def evaluate(m,dl,proc,dev,blank):
    m.eval();preds,refs=[],[]
    with torch.no_grad():
        for b in dl:
            x=b['feat'].to(dev); y=b['lbl']
            logit,_=m(x)
            for i in range(logit.size(0)):
                ids=decode_ctc(logit[i],blank)
                preds.append(proc.tokenizer.decode(ids,skip_special_tokens=True))
            for yy in y:
                t=yy.clone().masked_fill(yy==-100,blank).tolist()
                refs.append(proc.tokenizer.decode(t,skip_special_tokens=True))
    return preds,refs

# %%
# Cell 9: Main
def main():
    train_dir, test_dir="train_split","test_split"
    dev=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    proc=WhisperProcessor.from_pretrained('openai/whisper-base')
    blank=proc.tokenizer.pad_token_id
    w=WhisperModel.from_pretrained('openai/whisper-base')
    d=w.config.d_model; vocab=proc.tokenizer.vocab_size
    ds_tr=KaldiDataset(train_dir,proc);ds_te=KaldiDataset(test_dir,proc)
    dl_tr=DataLoader(ds_tr,batch_size=8,shuffle=True,collate_fn=collate_fn)
    dl_te=DataLoader(ds_te,batch_size=8,shuffle=False,collate_fn=collate_fn)

    model=WhisperMoE(w.encoder,d,vocab).to(dev)
    opt=optim.Adam(model.parameters(),1e-4)
    crit=nn.CTCLoss(blank=blank,zero_infinity=True)

    # Train
    best=1e9
    for ep in range(1,51):
        for p in model.enc.parameters(): p.requires_grad=(ep>3)
        tr=train_epoch(model,dl_tr,opt,crit,dev,blank)
        vl=val_epoch(model,dl_te,crit,dev,blank)
        print(f"Epoch {ep} TL:{tr:.3f} VL:{vl:.3f}")
        if vl<best: best=vl;torch.save(model.state_dict(),'best.pth')

    # Load & base eval
    model.load_state_dict(torch.load('best.pth'))
    preds,refs=evaluate(model,dl_te,proc,dev,blank)
    print('Base WER,CER:',wer(preds,refs),cer(preds,refs))

    # LLM refine one by one
    llm_nm='mistralai/Mistral-7B-Instruct-v0.3'
    tok=AutoTokenizer.from_pretrained(llm_nm,token=HF_TOKEN,trust_remote_code=True)
    mlm=AutoModelForCausalLM.from_pretrained(llm_nm,device_map='auto',torch_dtype=torch.float16,token=HF_TOKEN,trust_remote_code=True)
    gen=GenerationConfig(max_new_tokens=64,do_sample=False)

    print('Starting LLM refinement...')
    refined=[]
    for utt in preds:
        if len(utt.split())<6: refined.append(utt)
        else: refined.append(refine_one(utt,tok,mlm,gen,dev))

    # Clean tags & eval on actual text
    def clean(t): return ' '.join(w for w in t.split() if w not in {'EN','HI','|'})
    clean_ref=[clean(t) for t in refined]
    print('LLM WER,CER:',wer(clean_ref,refs),cer(clean_ref,refs))

    # Samples
    for i in range(5): print(f"Sample {i+1}: ASR={preds[i]}\nRefined={clean_ref[i]}\nRef={refs[i]}\n")

if __name__=='__main__': main()



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m


Epoch 1 TL:286.275 VL:223.366
Epoch 2 TL:123.612 VL:117.304
Epoch 3 TL:69.171 VL:41.618
Epoch 4 TL:6.796 VL:5.301
Epoch 5 TL:4.878 VL:5.526
Epoch 6 TL:4.776 VL:4.720
Epoch 7 TL:4.545 VL:4.610
Epoch 8 TL:4.434 VL:4.503
Epoch 9 TL:4.349 VL:4.443
Epoch 10 TL:4.289 VL:4.365
Epoch 11 TL:4.257 VL:4.321
Epoch 12 TL:4.190 VL:4.212
Epoch 13 TL:4.129 VL:4.144
Epoch 14 TL:4.046 VL:4.068
Epoch 15 TL:3.998 VL:3.952
Epoch 16 TL:3.925 VL:3.886
Epoch 17 TL:3.855 VL:3.863
Epoch 18 TL:3.824 VL:3.714
Epoch 19 TL:3.718 VL:3.604
Epoch 20 TL:3.644 VL:3.552
Epoch 21 TL:3.547 VL:3.409
Epoch 22 TL:3.456 VL:3.280
Epoch 23 TL:3.372 VL:3.219
Epoch 24 TL:3.278 VL:3.160
Epoch 25 TL:3.261 VL:3.013
Epoch 26 TL:3.160 VL:2.981
Epoch 27 TL:3.086 VL:2.997
Epoch 28 TL:2.986 VL:2.711
Epoch 29 TL:2.875 VL:2.624
Epoch 30 TL:2.775 VL:2.526
Epoch 31 TL:2.742 VL:2.569
Epoch 32 TL:2.674 VL:2.407
Epoch 33 TL:2.607 VL:2.354
Epoch 34 TL:2.516 VL:2.266
Epoch 35 TL:2.437 VL:2.160
Epoch 36 TL:2.360 VL:2.105
Epoch 37 TL:2.307 VL:1.977


  model.load_state_dict(torch.load('best.pth'))


Base WER,CER: 77.33443027560675 65.2456580291632


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Starting LLM refinement...
LLM WER,CER: 79.5146030440148 67.70688213987182
Sample 1: ASR=� mult if statement� सवागत है
Refined=Welcome
Ref=दोस्तों bash में nested और multilevel if statement के spoken tutorial में आपका स्वागत है

Sample 2: ASR=इरे सीखेगे
Refined=इरे सीखेगे
Ref=इस tutorial में हम निम्न के बारे में सीखेंगे

Sample 3: ASR=ested ifelse और
Refined=ested ifelse और
Ref=nested ifelse और

Sample 4: ASR=mvel ifvelelse statement
Refined=mvel ifvelelse statement
Ref=multilevel ifelse statement

Sample 5: ASR=र�क करेंगे
Refined=र�क करेंगे
Ref=हम यह कुछ उदाहरण उपयोग करके करेंगे

