In [None]:
# ============================================================
# FULL ASR + LM EVAL SCRIPT
# - Fine-tuned Whisper (HF) -> RAW
# - 3-gram Kneser–Ney LM -> FIXED
# ============================================================

!pip -q install -U transformers datasets librosa soundfile accelerate
!apt-get -y install ffmpeg

import os, re, math, random
import pandas as pd
import numpy as np
import librosa
import torch
from google.colab import drive
import pickle

from transformers import WhisperProcessor, WhisperForConditionalGeneration

# -------------------------
# 0) Mount Drive
# -------------------------
# Mount Google Drive so the audio folder, manifest, LM pickle and output CSV
# paths below resolve. Must run before any of those paths are read or written.
drive.mount("/content/drive")

# -------------------------
# 1) Directory Paths
# -------------------------
# Folder whose audio files are listed (non-recursively) in section 7.
AUDIO_DIR = "/content/drive/MyDrive/twi audio"
# CSV read in section 7; needs a 'sentence' column plus 'audio_path' or 'path'.
MANIFEST_PATH = "/content/drive/MyDrive/twi_audios_manifest.csv"
# Pickled n-gram language model used for post-correction in section 5.
LM_PATH = "/content/drive/MyDrive/twi_kneser_ney_3gram.pkl"

# Per-file results CSV written at the end of section 9.
OUT_PATH = "/content/drive/MyDrive/twi_asr_results_10.csv"

# Choose the finetuned model
# Try: "zirri23/whisper-akan-finetuned"
# Hugging Face Hub repo id passed to both from_pretrained() calls below.
MODEL_ID = "zirri23/whisper-akan-finetuned"

# -------------------------
# 2) Load LM
# -------------------------
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load LM pickles you produced yourself or otherwise trust.
with open(LM_PATH, "rb") as f:
    lm = pickle.load(f)

# The correction code relies on lm.order, lm.vocab (lookup() and `in`) and
# lm.score(word, context); print them as a quick sanity check.
print("LM loaded:", type(lm))
print("Order:", lm.order)
print("Vocab size:", len(lm.vocab))

# -------------------------
# 3) Load finetuned Whisper model (HF)
# -------------------------
# Prefer the GPU when available; input tensors are moved to `device` before generate().
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = WhisperProcessor.from_pretrained(MODEL_ID)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID).to(device)
# Inference only: eval() switches off training-time behaviour such as dropout.
model.eval()

print("Loaded model:", MODEL_ID)
print("Device:", device)

# -------------------------
# 4) Helpers: normalization/tokenization + metrics
# -------------------------
def normalize_text(text: str) -> str:
    """Normalize text for scoring and LM lookup.

    Steps, in order:
      1. Unicode-decompose (NFKD) and drop combining marks (category ``Mn``),
         e.g. tone accents, so ``"bɛ̀kyerɛ"`` becomes ``"bɛkyerɛ"``.
      2. Lowercase and map curly quotes to their ASCII forms.
      3. Replace every character other than ``a-z``, digits, ``ɛ``, ``ɔ``,
         apostrophe, whitespace and hyphen with a space.
      4. Collapse runs of whitespace and strip.

    Step 1 must precede step 3: otherwise a combining accent is itself replaced
    by a space, splitting a single word into two tokens and inflating WER.
    """
    import unicodedata  # local import keeps this helper self-contained

    decomposed = unicodedata.normalize("NFKD", str(text))
    text = "".join(ch for ch in decomposed if unicodedata.category(ch) != "Mn")
    text = text.lower()
    text = text.replace("’", "'").replace("“", '"').replace("”", '"')
    # keep Twi chars ɛ ɔ and basic punctuation removal
    text = re.sub(r"[^a-z0-9ɛɔ'\s\-]", " ", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text


def tokenize(text: str):
    """Normalize *text*, then split on whitespace and on hyphens.

    Empty pieces (e.g. from a leading/trailing hyphen) are dropped.
    """
    return [
        piece
        for tok in normalize_text(text).split()
        for piece in tok.split("-")
        if piece
    ]


def _edit_distance(ref_seq, hyp_seq) -> int:
    """Levenshtein distance (unit-cost insert/delete/substitute) between two sequences.

    Keeps only two DP rows, so memory is O(len(hyp_seq)) instead of a full matrix.
    """
    prev = list(range(len(hyp_seq) + 1))
    for i, r_item in enumerate(ref_seq, start=1):
        curr = [i] + [0] * len(hyp_seq)
        for j, h_item in enumerate(hyp_seq, start=1):
            curr[j] = min(
                prev[j] + 1,                         # deletion
                curr[j - 1] + 1,                     # insertion
                prev[j - 1] + (r_item != h_item),    # substitution
            )
        prev = curr
    return prev[-1]


def wer(ref: str, hyp: str) -> float:
    """Word error rate of *hyp* against *ref*, after ``tokenize`` on both.

    Returns edit distance divided by the number of reference tokens, floored at
    1, so an empty reference yields the raw number of inserted words.
    """
    r = tokenize(ref)
    return _edit_distance(r, tokenize(hyp)) / max(1, len(r))


def cer(ref: str, hyp: str) -> float:
    """Character error rate of *hyp* against *ref*, after ``normalize_text`` on both.

    Spaces count as characters. The denominator is the normalized reference
    length, floored at 1.
    """
    r = normalize_text(ref)
    return _edit_distance(r, normalize_text(hyp)) / max(1, len(r))

# -------------------------
# 5) LM correction helpers
# -------------------------
def vocab_lookup(token: str) -> str:
    """Return the LM-vocabulary form of a single *token* (OOV -> <unk>).

    The vocabulary's lookup() is given a one-element list, and the single
    mapped entry is returned.
    """
    mapped_tokens = lm.vocab.lookup([token])
    return mapped_tokens[0]

def twi_variants(word: str, vocab=None):
    """Return spelling variants of *word* that swap e/ɛ and o/ɔ.

    Each swap (e→ɛ, ɛ→e, o→ɔ, ɔ→o) replaces *every* occurrence of that vowel.
    Swaps are applied in two rounds, so both vowel pairs can change together.
    Only variants that are members of *vocab* are kept.

    Args:
        word: Token to expand; it is lowercased first.
        vocab: Any container supporting ``in`` for vocabulary membership.
            Defaults to ``lm.vocab`` of the module-level language model.

    Returns:
        List of unique spellings. The lowercased original word always comes
        first (even if out of vocabulary), followed by the other in-vocabulary
        variants in sorted order.
    """
    if vocab is None:
        vocab = lm.vocab
    w = word.lower()
    swaps = (("e", "ɛ"), ("ɛ", "e"), ("o", "ɔ"), ("ɔ", "o"))

    def _one_round(forms):
        # Keep every input form and add each single whole-vowel swap of it.
        grown = set(forms)
        for form in forms:
            for a, b in swaps:
                if a in form:
                    grown.add(form.replace(a, b))
        return grown

    variants = _one_round(_one_round({w}))
    # Sorted, not set-iteration order: set order of str varies between Python
    # processes (hash randomization), and callers truncate this list, so an
    # unordered result made corrections non-reproducible and could drop `w`.
    others = sorted(v for v in variants if v != w and v in vocab)
    return [w] + others

def correct_transcript_beam(asr_text: str, beam_width: int = 10, max_variants_per_word: int = 6):
    """Pick e/ɛ, o/ɔ spellings for an ASR transcript by beam search under the n-gram LM.

    Steps:
      1. ``tokenize(asr_text)``.
      2. For each token, take the first ``max_variants_per_word`` candidates
         from ``twi_variants``.
      3. Score each candidate as ``log lm.score(vocab_lookup(cand), context)``,
         where context is the last ``lm.order - 1`` LM-mapped tokens (initially
         ``<s>`` padding).
      4. Keep the best ``beam_width`` partial hypotheses after each token.
      5. Add the ``</s>`` score to every surviving beam and return the best one.

    Output tokens are the candidate surface spellings; the LM context uses the
    vocabulary-mapped form (OOV -> <unk>).

    Args:
        asr_text: Raw transcript to correct.
        beam_width: Beams kept per step (values < 1 are treated as 1).
        max_variants_per_word: Candidate cap per token (values < 1 are treated as 1).

    Returns:
        ``(corrected_text, log_score)``: tokens joined by single spaces, and the
        natural-log LM score including the end-of-sentence term. Returns
        ``("", 0.0)`` when *asr_text* has no tokens.
    """
    tokens = tokenize(asr_text)
    if not tokens:
        return "", 0.0

    beam_width = max(1, beam_width)
    max_variants_per_word = max(1, max_variants_per_word)
    ctx_len = max(lm.order - 1, 0)
    zero_prob_logp = -50.0  # stand-in for log(0), which would raise

    def _recent(ctx):
        # Explicit slice: ctx[-(order - 1):] is ctx[-0:] == whole list when order == 1.
        return ctx[max(len(ctx) - ctx_len, 0):] if ctx_len else []

    def _logp(word, ctx):
        p = lm.score(word, tuple(_recent(ctx)))
        return zero_prob_logp if p <= 0.0 else math.log(p)

    beams = [(0.0, [], ["<s>"] * ctx_len)]  # (logp, out_tokens, ctx_tokens)
    for w in tokens:
        cands = twi_variants(w)[:max_variants_per_word]
        new_beams = []
        for logp, out_tokens, ctx in beams:
            for cand in cands:
                cand2 = vocab_lookup(cand)
                new_beams.append(
                    (logp + _logp(cand2, ctx), out_tokens + [cand], _recent(ctx + [cand2]))
                )
        new_beams.sort(key=lambda x: x[0], reverse=True)
        beams = new_beams[:beam_width]

    # Add </s> to *every* surviving beam before choosing. Scoring only the
    # pre-</s> leader can return a hypothesis that is not the best complete sentence.
    finished = [(logp + _logp("</s>", ctx), out_tokens) for logp, out_tokens, ctx in beams]
    best_logp, best_tokens = max(finished, key=lambda b: b[0])
    return " ".join(best_tokens), best_logp

# -------------------------
# 6) Transcription (HF Whisper) — stable version (no num_frames issues)
# -------------------------
def transcribe_twi(audio_path: str) -> str:
    """Transcribe one audio file with the loaded fine-tuned Whisper model.

    The file is decoded and resampled to 16 kHz by ``librosa.load`` (librosa's
    default ``mono=True`` down-mixes). It is turned into log-mel features by
    ``processor`` and decoded with ``model.generate(task="transcribe")``.

    Args:
        audio_path: Path to any audio file librosa can read.

    Returns:
        The decoded text with special tokens removed and surrounding
        whitespace stripped.
    """
    audio, _ = librosa.load(audio_path, sr=16000)
    inputs = processor(
        audio,
        sampling_rate=16000,
        return_tensors="pt",
        # Without this the Whisper feature extractor returns no mask, so no mask
        # reaches generate(), which then warns "The attention mask is not set".
        return_attention_mask=True,
    )

    # Match the model's dtype so a half-precision model also works.
    input_features = inputs.input_features.to(device, dtype=model.dtype)

    gen_kwargs = {"task": "transcribe"}
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        gen_kwargs["attention_mask"] = attention_mask.to(device)

    with torch.no_grad():
        predicted_ids = model.generate(input_features, **gen_kwargs)

    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()

# -------------------------
# 7) Load audios + manifest -> ref_map
# -------------------------
audio_exts = (".wav", ".mp3", ".m4a", ".flac", ".ogg")
# Non-recursive scan of AUDIO_DIR, sorted for a stable processing order.
audio_files = sorted(
    os.path.join(AUDIO_DIR, f)
    for f in os.listdir(AUDIO_DIR)
    if f.lower().endswith(audio_exts)
)

print("Found", len(audio_files), "audio files in", AUDIO_DIR)
print("Example:", audio_files[:3])

manifest = pd.read_csv(MANIFEST_PATH)
print("Manifest columns:", list(manifest.columns))
print("Manifest rows:", len(manifest))

# Reference text must be in a 'sentence' column. The audio location comes from
# 'audio_path' or, failing that, 'path'.
text_col = "sentence" if "sentence" in manifest.columns else None
path_col = "audio_path" if "audio_path" in manifest.columns else None

if text_col is None:
    raise ValueError("Manifest must have a 'sentence' column (or update text_col logic).")
if path_col is None:
    # sometimes they store just 'path' then you build audio_path yourself
    if "path" in manifest.columns:
        # Relative entries are treated as filenames inside AUDIO_DIR.
        manifest["audio_path"] = manifest["path"].apply(
            lambda p: p if os.path.isabs(str(p)) else os.path.join(AUDIO_DIR, str(p))
        )
        path_col = "audio_path"
    else:
        raise ValueError("Manifest must have 'audio_path' or 'path' column.")

# Rows with missing reference text would otherwise become the literal string
# "nan" after astype(str) and be scored as a one-word reference.
missing_text = manifest[text_col].isna()
if missing_text.any():
    print(f"⚠️ Dropping {int(missing_text.sum())} manifest row(s) with missing '{text_col}'")
    manifest = manifest.loc[~missing_text].copy()

manifest["audio_name"] = manifest[path_col].apply(lambda p: os.path.basename(str(p)))

# Audio files are matched by basename only. Duplicate names collide, and dict() keeps the last row.
dup_names = manifest.loc[manifest["audio_name"].duplicated(), "audio_name"]
if not dup_names.empty:
    print("⚠️ Duplicate audio names in manifest (last row wins):", sorted(set(dup_names)))

ref_map = dict(zip(manifest["audio_name"].astype(str), manifest[text_col].astype(str)))

print("✅ ref_map built with", len(ref_map), "entries")
print("Example audio files:", [os.path.basename(x) for x in audio_files[:3]])
for k in list(ref_map.keys())[:3]:
    print("Audio:", k)
    print("Ref  :", ref_map[k])
    print("---")

# -------------------------
# 8) QUICK sanity check (first 3)
# -------------------------
print("\nSanity check on first 3 files:")
# Transcribe the first few files and show each raw output next to its reference.
for idx, path in enumerate(audio_files[:3]):
    name = os.path.basename(path)
    raw = transcribe_twi(path)
    print(
        f"\n--- {idx} ---",
        f"AUDIO: {name}",
        f"REF  : {ref_map.get(name, '<missing ref>')}",
        f"RAW  : {raw}",
        sep="\n",
    )

# -------------------------
# 9) Full evaluation: RAW vs LM-FIXED + print 2 qualitative examples
# -------------------------
rows = []
for i, path in enumerate(audio_files, start=1):
    audio_name = os.path.basename(path)
    if audio_name not in ref_map:
        print(f"Skipping (no ref found): {audio_name}")
        continue

    ref_text = str(ref_map[audio_name])

    print(f"\n[{i}/{len(audio_files)}] Transcribing: {audio_name}")
    raw = transcribe_twi(path)

    fixed, lm_score = correct_transcript_beam(raw, beam_width=10)

    rows.append({
        "audio": audio_name,
        "audio_path": path,
        "ref": ref_text,
        "whisper_raw": raw,
        "lm_fixed": fixed,
        "wer_raw": wer(ref_text, raw),
        "wer_fixed": wer(ref_text, fixed),
        "cer_raw": cer(ref_text, raw),
        "cer_fixed": cer(ref_text, fixed),
        "lm_logscore": lm_score
    })

df = pd.DataFrame(rows)
display(df)

# Summary: macro averages (every file counts equally, regardless of length).
avg_wer_raw = df["wer_raw"].mean() if len(df) else None
avg_wer_fix = df["wer_fixed"].mean() if len(df) else None
delta = (avg_wer_raw - avg_wer_fix) if (avg_wer_raw is not None and avg_wer_fix is not None) else None

print("\nAverage WER (raw)  :", avg_wer_raw)
print("Average WER (fixed):", avg_wer_fix)
print("ΔWER (raw-fixed)   :", delta)

print("\nAverage CER (raw)  :", df['cer_raw'].mean() if len(df) else None)
print("Average CER (fixed):", df['cer_fixed'].mean() if len(df) else None)

# Corpus-level rates: per-file rates weighted by reference length, i.e.
# total edits / total reference length (the conventional way WER/CER are reported).
# Lengths use the same max(1, ...) floor as the wer()/cer() denominators.
if len(df):
    ref_word_counts = df["ref"].map(lambda t: max(1, len(tokenize(t))))
    ref_char_counts = df["ref"].map(lambda t: max(1, len(normalize_text(t))))
    corpus_wer_raw = (df["wer_raw"] * ref_word_counts).sum() / ref_word_counts.sum()
    corpus_wer_fix = (df["wer_fixed"] * ref_word_counts).sum() / ref_word_counts.sum()
    corpus_cer_raw = (df["cer_raw"] * ref_char_counts).sum() / ref_char_counts.sum()
    corpus_cer_fix = (df["cer_fixed"] * ref_char_counts).sum() / ref_char_counts.sum()
else:
    corpus_wer_raw = corpus_wer_fix = corpus_cer_raw = corpus_cer_fix = None

print("\nCorpus WER (raw)   :", corpus_wer_raw)
print("Corpus WER (fixed) :", corpus_wer_fix)
print("Corpus CER (raw)   :", corpus_cer_raw)
print("Corpus CER (fixed) :", corpus_cer_fix)

# Qualitative: print 2 examples (before/after)
print("\nQualitative examples (2 files):")
QUALITATIVE_SEED = 0  # local, seeded RNG so reruns show the same examples
sample_idx = list(range(len(df)))
random.Random(QUALITATIVE_SEED).shuffle(sample_idx)
for j in sample_idx[:2]:
    r = df.iloc[j]
    print("\n==============================")
    print("AUDIO:", r["audio"])
    print("REF  :", r["ref"])
    print("RAW  :", r["whisper_raw"])
    print("FIX  :", r["lm_fixed"])
    print("WER raw/fix:", r["wer_raw"], "/", r["wer_fixed"])

# Save results
df.to_csv(OUT_PATH, index=False)
print("\n✅ Saved:", OUT_PATH)


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.1/10.1 MB[0m [31m87.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m515.2/515.2 kB[0m [31m30.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m536.7/536.7 kB[0m [31m30.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.6/47.6 MB[0m [31m18.7 MB/s[0m eta [36m0:00:00[0m
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
ffmpeg is already the newest version (7:4.4.2-0ubuntu0.22.04.1).
0 upgraded, 0 newly installed, 0 to remove and 41 not upgraded.
Mounted at /content/drive
LM loaded: <class 'nltk.lm.models.KneserNeyInterpolated'>
Order: 3
Vocab size: 2986


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


processor_config.json:   0%|          | 0.00/409 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/315 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]



adapter_config.json: 0.00B [00:00, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/967M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/479 [00:00<?, ?it/s]

generation_config.json: 0.00B [00:00, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/104M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/384 [00:00<?, ?it/s]

Loaded model: zirri23/whisper-akan-finetuned
Device: cuda
Found 10 audio files in /content/drive/MyDrive/twi audio
Example: ['/content/drive/MyDrive/twi audio/common_voice_tw_34745954.mp3', '/content/drive/MyDrive/twi audio/common_voice_tw_34997393.mp3', '/content/drive/MyDrive/twi audio/common_voice_tw_34997394.mp3']
Manifest columns: ['path', 'sentence', 'audio_path']
Manifest rows: 10
✅ ref_map built with 10 entries
Example audio files: ['common_voice_tw_34745954.mp3', 'common_voice_tw_34997393.mp3', 'common_voice_tw_34997394.mp3']
Audio: common_voice_tw_34745954.mp3
Ref  : Dabi, ɛnte saa
---
Audio: common_voice_tw_34997393.mp3
Ref  : • Ma wo yere nhu sɛ wopene nufuma so.
---
Audio: common_voice_tw_34997394.mp3
Ref  : Dɛn na ɛno bɛkyerɛ?
---

Sanity check on first 3 files:


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
A custom logits processor of type <class 'transformers.generation.logits_process.SuppressTokensLogitsProcessor'> has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom <class 'transformers.generation.logits_process.SuppressTokensLogitsProcessor'> will take precedence. Please check the docstring of <class 'transformers.generation.logits_process.SuppressTokensLogitsProcessor'> to see related `.generate()` flags.
A custom logits processor of type <class 'transformers.generation.logits_process.SuppressTokensAtBeginLogitsProcessor'> has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom <class 'transformers.generation.logits_process.SuppressTokensA


--- 0 ---
AUDIO: common_voice_tw_34745954.mp3
REF  : Dabi, ɛnte saa
RAW  : dɛɛbi nte saa

--- 1 ---
AUDIO: common_voice_tw_34997393.mp3
REF  : • Ma wo yere nhu sɛ wopene nufuma so.
RAW  : maoyerɛ nnusɛ ɔpene ne famu aso

--- 2 ---
AUDIO: common_voice_tw_34997394.mp3
REF  : Dɛn na ɛno bɛkyerɛ?
RAW  : den na ɛnobɛkyerɛ

[1/10] Transcribing: common_voice_tw_34745954.mp3

[2/10] Transcribing: common_voice_tw_34997393.mp3

[3/10] Transcribing: common_voice_tw_34997394.mp3

[4/10] Transcribing: common_voice_tw_34997398.mp3

[5/10] Transcribing: common_voice_tw_34997400.mp3

[6/10] Transcribing: common_voice_tw_34997402.mp3

[7/10] Transcribing: common_voice_tw_35280404.mp3

[8/10] Transcribing: common_voice_tw_35280405.mp3

[9/10] Transcribing: common_voice_tw_35280406.mp3

[10/10] Transcribing: common_voice_tw_35280407.mp3


Unnamed: 0,audio,audio_path,ref,whisper_raw,lm_fixed,wer_raw,wer_fixed,cer_raw,cer_fixed,lm_logscore
0,common_voice_tw_34745954.mp3,/content/drive/MyDrive/twi audio/common_voice_...,"Dabi, ɛnte saa",dɛɛbi nte saa,dɛɛbi nte saa,0.666667,0.666667,0.230769,0.230769,-61.332504
1,common_voice_tw_34997393.mp3,/content/drive/MyDrive/twi audio/common_voice_...,• Ma wo yere nhu sɛ wopene nufuma so.,maoyerɛ nnusɛ ɔpene ne famu aso,maoyerɛ nnusɛ ɔpene ne famu aso,1.0,1.0,0.382353,0.382353,-214.713225
2,common_voice_tw_34997394.mp3,/content/drive/MyDrive/twi audio/common_voice_...,Dɛn na ɛno bɛkyerɛ?,den na ɛnobɛkyerɛ,dɛn na ɛnobɛkyerɛ,0.75,0.5,0.111111,0.055556,-60.31426
3,common_voice_tw_34997398.mp3,/content/drive/MyDrive/twi audio/common_voice_...,Ɛkaa abɔde ho nsɛm pii a ɛma yenya Yehowa dɔ h...,ɛka abɔde ho nsɛm pii a mma yɛnyɛ yi ho wɔ adɔ...,ɛka abɔde ho nsɛm pii a mma yɛnyɛ yi ho wɔ adɔ...,0.692308,0.769231,0.224138,0.241379,-107.068759
4,common_voice_tw_34997400.mp3,/content/drive/MyDrive/twi audio/common_voice_...,Mekae saa asɛmfua yi.,mikaisa asɛ mfuo yi,mikaisa ase mfuo yi,0.75,0.75,0.3,0.35,-115.566256
5,common_voice_tw_34997402.mp3,/content/drive/MyDrive/twi audio/common_voice_...,Dɛn na wopɛ sɛ wunya?,dan na ɔpɛ sɛ wunya,dan na ɔpɛ sɛ wunya,0.4,0.4,0.15,0.15,-42.929132
6,common_voice_tw_35280404.mp3,/content/drive/MyDrive/twi audio/common_voice_...,Ɛyɛɛ no nwonwa ma obisaa me sɛ,ɛyɛ no hɔn hɔn ma obisa msɛ,ɛyɛ no hɔn hɔn ma obisa msɛ,0.857143,0.857143,0.366667,0.366667,-223.013923
7,common_voice_tw_35280405.mp3,/content/drive/MyDrive/twi audio/common_voice_...,Sɛnea Asamoah gye di sɛ biribiara wɔ ne bere n...,sɛn a asɛmwa gye de sɛ beebiara wɔ ne berɛ ne ...,sen a asɛmwa gye dɛ sɛ beebiara wɔ ne bere ne ...,0.545455,0.454545,0.169811,0.169811,-159.604399
8,common_voice_tw_35280406.mp3,/content/drive/MyDrive/twi audio/common_voice_...,Mekyerɛkyerɛɛ mu sɛ Yehowa Adansefo nso hyɛ ey...,m'akyerakyeramu si ɛho w'adansafo nso hyɛ yi h...,m'akyerakyeramu si ɛho w'adansafo nso hyɛ yi h...,0.7,0.7,0.263158,0.263158,-201.019729
9,common_voice_tw_35280407.mp3,/content/drive/MyDrive/twi audio/common_voice_...,"Bere a midii awia aduan wiei no, mesan kɔɔ adw...",brɛ a ɔdi yɛ wiaduan wiɛ no ɛsan kɔ adwumayɛbe...,brɛ a odi yɛ wiaduan wie no ɛsan kɔ adwumayɛbe...,0.636364,0.636364,0.267857,0.25,-126.210014



Average WER (raw)  : 0.6997935397935398
Average WER (fixed): 0.6733949383949385
ΔWER (raw-fixed)   : 0.026398601398601285

Average CER (raw)  : 0.24658642391066637
Average CER (fixed): 0.24596929200043105

Qualitative examples (2 files):

AUDIO: common_voice_tw_35280407.mp3
REF  : Bere a midii awia aduan wiei no, mesan kɔɔ adwumayɛbea hɔ.
RAW  : brɛ a ɔdi yɛ wiaduan wiɛ no ɛsan kɔ adwumayɛbea hɔ
FIX  : brɛ a odi yɛ wiaduan wie no ɛsan kɔ adwumayɛbea hɔ
WER raw/fix: 0.6363636363636364 / 0.6363636363636364

AUDIO: common_voice_tw_34997400.mp3
REF  : Mekae saa asɛmfua yi.
RAW  : mikaisa asɛ mfuo yi
FIX  : mikaisa ase mfuo yi
WER raw/fix: 0.75 / 0.75

✅ Saved: /content/drive/MyDrive/twi_asr_results_10.csv
