In [None]:
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import csv


In [None]:
# Define paths (relative to the script's directory).
# Inside a notebook `__file__` is not defined, so fall back to the working directory.
if '__file__' in globals():
    script_dir = os.path.dirname(os.path.abspath(__file__))
else:
    script_dir = os.getcwd()

# Destination for model outputs; created up front so later cells can write into it.
output_folder = os.path.join(script_dir, "output_transcripts")
os.makedirs(output_folder, exist_ok=True)

# Inputs read by later cells (not created here).
input_folder = os.path.join(script_dir, "Whisper turbo FT")    # ASR transcriptions to be corrected
gold_folder = os.path.join(script_dir, "Gold Transcriptions")  # reference (gold) transcriptions
mapping_csv = os.path.join(script_dir, "mapping.csv")          # CSV with 'Name' and 'Gold_path' columns

# Where the per-file metrics table is saved.
csv_path = os.path.join(script_dir, "results.csv")


In [None]:
# Model details
# NOTE(review): "google/gemma-2-9b" looks like the base (non "-it") checkpoint, while
# prompt_template is written as an instruction. An instruction-tuned variant with its
# chat template may follow the rules better — confirm which checkpoint was intended.
model_name = "google/gemma-2-9b"

# Never hardcode the access token in the notebook (it ends up in version control and
# shared copies). Read it from the environment; when unset this is None and
# transformers falls back to the token saved by `huggingface-cli login`, if any.
HF_TOKEN = os.environ.get("HF_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # Automatically places model on GPU if available
    token=HF_TOKEN
)
print("Model loaded on:", next(model.parameters()).device)


Loading checkpoint shards: 100%|██████████| 8/8 [00:07<00:00,  1.01it/s]


Model loaded on: cuda:0


In [None]:
prompt_template = """You are an Urdu ASR error correction expert. Your ONLY task is to replace incorrectly transcribed Urdu words with their correct forms.
CRITICAL RULES:
- ABSOLUTELY NO punctuation (no ۔ ، ؟ ! . , : ; " ' or any symbols)
- NO new words or phrases
- NO reordering
- NO grammar changes
- ONLY replace wrong words with correct ones
- If unsure about a word, leave it unchanged
Think of this as a word-by-word dictionary replacement, not a rewrite.

Fix ONLY the incorrectly transcribed words in this Urdu text. Replace wrong words with correct ones based on context. Add NO punctuation.

Urdu text: {text}

Corrected text:"""


In [None]:
# Read the mapping CSV ('Name' = transcript file, 'Gold_path' = gold file).
# 'utf-8-sig' drops a leading BOM (common in Excel-exported CSVs) that would otherwise
# turn the first header into '\ufeffName' and make mapping['Name'] fail;
# newline='' is what the csv module expects for the file object it reads.
with open(mapping_csv, 'r', encoding='utf-8-sig', newline='') as f:
    mappings = list(csv.DictReader(f))

# Prompts longer than this many tokens are truncated (a warning is printed).
MAX_PROMPT_TOKENS = 4096

# Labels used in prompt_template. If they appear in the generated text, the model has
# started a new "example" instead of finishing the corrected transcript.
_CONTINUATION_MARKERS = ("Urdu text:", "Corrected text:")


def _clean_generation(raw_text):
    """Turn raw model output into a single-line corrected transcript.

    Cuts the text at the first echoed prompt label, collapses newlines to spaces,
    and strips surrounding whitespace.
    """
    for marker in _CONTINUATION_MARKERS:
        raw_text = raw_text.split(marker, 1)[0]
    return raw_text.replace("\n", " ").strip()


def correct_transcript(text):
    """Run the correction prompt on one transcript and return the cleaned model output.

    Uses the notebook-level `tokenizer`, `model` and `prompt_template`.
    """
    prompt = prompt_template.format(text=text)

    n_prompt_tokens = len(tokenizer(prompt).input_ids)
    if n_prompt_tokens > MAX_PROMPT_TOKENS:
        # With the tokenizer's default (right-side) truncation, the end of the prompt,
        # including the "Corrected text:" cue, is dropped. Output for this input is
        # therefore unreliable.
        print(f"Warning: prompt has {n_prompt_tokens} tokens; truncating to {MAX_PROMPT_TOKENS}")

    model_inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_PROMPT_TOKENS,
    ).to(model.device)
    prompt_len = model_inputs.input_ids.shape[1]

    # Greedy decoding (no sampling, single beam) gives deterministic output.
    # - No repetition_penalty: HF applies it to every token already in the sequence,
    #   prompt included, which steers the model AWAY from copying the (mostly correct)
    #   input words. That is the opposite of what error correction needs.
    # - No temperature: it is ignored (and warned about) when do_sample=False.
    outputs = model.generate(
        **model_inputs,
        max_new_tokens=int(prompt_len * 1.2),
        do_sample=False,
        num_beams=1,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    output_ids = outputs[0][prompt_len:].tolist()
    raw_text = tokenizer.decode(output_ids, skip_special_tokens=True)
    return _clean_generation(raw_text)


# Process each file based on mapping
for mapping in mappings:
    # Strip to match the metrics cell, which strips names read from the same CSV.
    # `or ''` covers short rows, where DictReader fills missing cells with None.
    filename = (mapping['Name'] or '').strip()
    if not filename.endswith(".txt"):
        print(f"Skipping non-txt file: {filename}")
        continue

    input_path = os.path.join(input_folder, filename)
    if not os.path.exists(input_path):
        print(f"Input file not found: {input_path}")
        continue

    with open(input_path, 'r', encoding='utf-8') as f:
        input_text = f.read().strip()

    generated_text = correct_transcript(input_text)

    # Save to output folder with the same filename
    output_path = os.path.join(output_folder, filename)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(generated_text)

    print(f"Processed {filename}")


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 3Alleged Gold.txt
Processed 11Deaths Gold.txt
Processed AfghanCricketGPT.txt
Processed BuildingCollapse Gold.txt
Processed BullyingGPT Gold.txt
Processed ConstructionHalt_Gold.txt
Processed CTD_Gold.txt
Processed GasTheft_Gold.txt
Processed Hamid Mir Imran Khan Key.txt
Processed Inflation Gold.txt
Processed KarachiKings Gold.txt
Processed Kidney Gold.txt
Processed MobileTheft Gold.txt
Processed Murree Gold.txt
Processed PakvsInd2 Gold.txt
Processed PakVsInd Gold.txt
Processed PassengerGPT.txt
Processed Petrol Gold.txt
Processed PunjabGovt Gold.txt
Processed Quetta Gold.txt
Processed RamadanGas Gold.txt
Processed RamadanMoon Gold.txt
Processed RedLine_Gold.txt
Processed Sama Electricity Relief Key.txt
Processed Sama FM Egypt Visit Key.txt
Processed Sama PSL Multan Key.txt
Processed SindhTax_Gold.txt
Processed Traffic_Accident_Geo.txt
Processed VehicleCollisionGPT.txt
Processed Women Gold.txt


In [None]:
from jiwer import wer, cer, mer, wil

# Now compute WER, CER, MER, WIL for each pair (before and after SLM) based on mapping.
# Every metric is called as metric(reference=gold, hypothesis=...). "before" scores the
# raw ASR transcript and "after" scores the LLM-corrected one.
METRIC_FUNCS = {'wer': wer, 'cer': cer, 'mer': mer, 'wil': wil}


def _read_stripped(path):
    """Return the UTF-8 contents of `path` with surrounding whitespace removed."""
    with open(path, 'r', encoding='utf-8') as f:
        return f.read().strip()


def _mean_of(rows, key):
    """Unweighted mean of rows[i][key] over a non-empty list of result dicts."""
    return sum(r[key] for r in rows) / len(rows)


results = []
# Parallel lists of texts for corpus-level (length-weighted) scoring.
corpus_gold, corpus_before, corpus_after = [], [], []

for mapping in mappings:
    filename = (mapping.get('Name') or '').strip()
    gold_filename = (mapping.get('Gold_path') or '').strip()

    if not filename.endswith(".txt") or not gold_filename.endswith(".txt"):
        print(f"Skipping non-txt pair: {filename} / {gold_filename}")
        continue

    input_path = os.path.join(input_folder, filename)
    output_path = os.path.join(output_folder, filename)
    gold_path = os.path.join(gold_folder, gold_filename)

    missing = [p for p in (gold_path, input_path, output_path) if not os.path.exists(p)]
    if missing:
        print(f"Files not found for pair: {filename} / {gold_filename} (missing: {', '.join(missing)})")
        continue

    input_text = _read_stripped(input_path)
    output_text = _read_stripped(output_path)
    gold_text = _read_stripped(gold_path)

    # jiwer raises ValueError on an empty reference. Skip the pair instead of
    # aborting the whole evaluation.
    if not gold_text:
        print(f"Empty gold transcription, skipping pair: {filename} / {gold_filename}")
        continue

    row = {'filename': filename, 'gold_filename': gold_filename}
    for name, metric in METRIC_FUNCS.items():
        row[f'{name}_before'] = metric(gold_text, input_text)
        row[f'{name}_after'] = metric(gold_text, output_text)
    results.append(row)

    corpus_gold.append(gold_text)
    corpus_before.append(input_text)
    corpus_after.append(output_text)

    print(f"\nResults for {filename}")
    for name in METRIC_FUNCS:
        before, after = row[f'{name}_before'], row[f'{name}_after']
        print(f"{name.upper()} before SLM: {before}")
        print(f"{name.upper()} after  SLM: {after}")

# Write to CSV
if results:
    # Same column order as before: filename, gold_filename, then <metric>_before/_after.
    fieldnames = ['filename', 'gold_filename'] + [
        f'{name}_{stage}' for name in METRIC_FUNCS for stage in ('before', 'after')
    ]
    with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)

    # --- Averages ---
    # Macro average: every file counts equally, regardless of its length.
    # Kept as named variables so later cells can keep referring to them.
    n = len(results)
    avg_wer_before = _mean_of(results, 'wer_before')
    avg_wer_after = _mean_of(results, 'wer_after')
    avg_cer_before = _mean_of(results, 'cer_before')
    avg_cer_after = _mean_of(results, 'cer_after')
    avg_mer_before = _mean_of(results, 'mer_before')
    avg_mer_after = _mean_of(results, 'mer_after')
    avg_wil_before = _mean_of(results, 'wil_before')
    avg_wil_after = _mean_of(results, 'wil_after')

    print("\n===== AVERAGE RESULTS =====")
    for name in METRIC_FUNCS:
        print(f"Average {name.upper()} before SLM: {_mean_of(results, f'{name}_before')}")
        print(f"Average {name.upper()} after  SLM: {_mean_of(results, f'{name}_after')}")

    # Corpus-level scores: jiwer aggregates edit counts over all pairs, so longer
    # files weigh proportionally more (the usual way ASR results are reported).
    print("\n===== CORPUS-LEVEL RESULTS =====")
    for name, metric in METRIC_FUNCS.items():
        print(f"Corpus {name.upper()} before SLM: {metric(corpus_gold, corpus_before)}")
        print(f"Corpus {name.upper()} after  SLM: {metric(corpus_gold, corpus_after)}")
else:
    print("No WER/CER/MER/WIL calculations performed.")



Results for 3Alleged Gold.txt
WER before SLM: 0.22145328719723184
WER after  SLM: 0.8304498269896193
CER before SLM: 0.07362885048835462
CER after  SLM: 0.6070623591284748
MER before SLM: 0.21548821548821548
MER after  SLM: 0.8304498269896193
WIL before SLM: 0.34546616350988024
WIL after  SLM: 0.9696789836587275

Results for 11Deaths Gold.txt
WER before SLM: 0.282798833819242
WER after  SLM: 1.0
CER before SLM: 0.0865992414664981
CER after  SLM: 0.9367888748419722
MER before SLM: 0.27019498607242337
MER after  SLM: 1.0
WIL before SLM: 0.4362583665256847
WIL after  SLM: 1.0

Results for AfghanCricketGPT.txt
WER before SLM: 0.23371647509578544
WER after  SLM: 1.0
CER before SLM: 0.07410636442894507
CER after  SLM: 0.8866608544027899
MER before SLM: 0.23106060606060605
MER after  SLM: 1.0
WIL before SLM: 0.39737065309584396
WIL after  SLM: 1.0

Results for BuildingCollapse Gold.txt
WER before SLM: 0.20892018779342722
WER after  SLM: 0.8708920187793427
CER before SLM: 0.07146695325094035
