In [None]:
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from jiwer import wer
import csv


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Define paths (relative to the script's directory).
# Inside a notebook `__file__` is usually undefined, so the current working
# directory is used as the base instead.
if '__file__' in globals():
    script_dir = os.path.dirname(os.path.abspath(__file__))
else:
    script_dir = os.getcwd()

# Inputs
input_folder = os.path.join(script_dir, "Whisper turbo FT")     # ASR transcriptions to be corrected
gold_folder = os.path.join(script_dir, "Gold Transcriptions")   # reference (gold) transcriptions
mapping_csv = os.path.join(script_dir, "mapping.csv")           # CSV with 'Name' and 'Gold_path' columns

# Outputs
output_folder = os.path.join(script_dir, "output_transcripts")  # model outputs (created below if missing)
csv_path = os.path.join(script_dir, "results.csv")              # per-file WER results

os.makedirs(output_folder, exist_ok=True)


In [None]:
# Model details
model_name = "Qwen/Qwen3-14B"

# Read the Hugging Face token from the environment rather than hard-coding it.
# The previous literal "HF_TOKEN" was a placeholder string that was sent as the
# token itself. When the variable is unset this is None, and transformers falls
# back to any credentials cached by `huggingface-cli login`.
HF_TOKEN = os.environ.get("HF_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # `torch_dtype` is deprecated in the installed transformers (it warns
    # "Use `dtype` instead!"), so pass the new keyword.
    dtype=torch.bfloat16,
    device_map="auto",  # let accelerate place weights on the available GPU(s)/CPU
    token=HF_TOKEN,
)
# Reports the device of the first parameter only. With device_map="auto" a
# large model may be sharded across several devices.
print("Model loaded on:", next(model.parameters()).device)


`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 8/8 [00:06<00:00,  1.15it/s]


Model loaded on: cuda:0


In [None]:
# Correction prompt, sent to the model as the content of a single user message.
# `{text}` is filled with the raw ASR transcript via
# `prompt_template.format(text=...)`. Any literal `{` or `}` added to this
# template must therefore be escaped as `{{` / `}}`, or `.format` will fail.
prompt_template = """You are an Urdu ASR error correction expert. Your ONLY task is to replace incorrectly transcribed Urdu words with their correct forms.
CRITICAL RULES:
- ABSOLUTELY NO punctuation (no ۔ ، ؟ ! . , : ; " ' or any symbols)
- NO new words or phrases
- NO reordering
- NO grammar changes
- ONLY replace wrong words with correct ones
- If unsure about a word, leave it unchanged
Think of this as a word-by-word dictionary replacement, not a rewrite.

Fix ONLY the incorrectly transcribed words in this Urdu text. Replace wrong words with correct ones based on context. Add NO punctuation.

Urdu text: {text}

Corrected text:"""


In [None]:
# Read the mapping CSV.
# 'utf-8-sig' strips a leading byte-order mark (e.g. from CSVs saved by Excel).
# Without it, the first header reads '\ufeffName' and mapping['Name'] raises
# KeyError. Plain UTF-8 files are read unchanged.
with open(mapping_csv, 'r', encoding='utf-8-sig', newline='') as f:
    reader = csv.DictReader(f)
    mappings = list(reader)

# Token id of Qwen3's closing `</think>` tag. It is looked up from the tokenizer
# instead of hard-coding 151668, so the split below tracks the actual vocabulary.
think_end_id = tokenizer.convert_tokens_to_ids("</think>")

# generate() below passes neither `do_sample` nor `num_beams`, so decoding
# follows the checkpoint's generation_config. For Qwen3 this is presumably
# sampling (confirm in generation_config.json), so reseed before each file.
# Each transcript's output is then repeatable and independent of processing order.
GENERATION_SEED = 0

# Process each file based on mapping
for mapping in mappings:
    # DictReader fills cells missing from a short row with None.
    filename = mapping['Name'] or ''
    if not filename.endswith(".txt"):
        print(f"Skipping non-txt file: {filename}")
        continue

    input_path = os.path.join(input_folder, filename)
    if not os.path.exists(input_path):
        print(f"Input file not found: {input_path}")
        continue

    with open(input_path, 'r', encoding='utf-8') as f:
        input_text = f.read()

    # Render the chat prompt to a string (tokenize=False), then tokenize it.
    messages = [{"role": "user", "content": prompt_template.format(text=input_text)}]
    prompt_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=True,  # the model reasons inside <think>...</think> before answering
    )
    model_inputs = tokenizer([prompt_text], return_tensors="pt").to(model.device)

    torch.manual_seed(GENERATION_SEED)
    outputs = model.generate(
        **model_inputs,
        max_new_tokens=32768,  # must cover the whole thinking trace plus the answer
    )
    # Keep only the newly generated tokens (drop the echoed prompt).
    output_ids = outputs[0][len(model_inputs.input_ids[0]):].tolist()

    if think_end_id in output_ids:
        # Split just after the LAST </think>: reasoning before it, answer after it.
        index = len(output_ids) - output_ids[::-1].index(think_end_id)
        thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
        generated_text = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
    else:
        # No closing tag: generation stopped before the reasoning finished (for
        # example, the max_new_tokens budget ran out). The decoded text is
        # reasoning, not a correction, so keep the ASR input unchanged rather than
        # saving the reasoning as the "corrected" transcript.
        print(f"WARNING: no </think> in output for {filename}; keeping input text unchanged")
        thinking_content = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
        generated_text = input_text

    # Save to output folder with the same filename
    output_path = os.path.join(output_folder, filename)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(generated_text.strip())

    print(f"Processed {filename}")


Processed 3Alleged Gold.txt
Processed 11Deaths Gold.txt
Processed AfghanCricketGPT.txt
Processed BuildingCollapse Gold.txt
Processed BullyingGPT Gold.txt
Processed ConstructionHalt_Gold.txt
Processed CTD_Gold.txt
Processed GasTheft_Gold.txt
Processed Hamid Mir Imran Khan Key.txt
Processed Inflation Gold.txt
Processed KarachiKings Gold.txt
Processed Kidney Gold.txt
Processed MobileTheft Gold.txt
Processed Murree Gold.txt
Processed PakvsInd2 Gold.txt
Processed PakVsInd Gold.txt
Processed PassengerGPT.txt
Processed Petrol Gold.txt
Processed PunjabGovt Gold.txt
Processed Quetta Gold.txt
Processed RamadanGas Gold.txt
Processed RamadanMoon Gold.txt
Processed RedLine_Gold.txt
Processed Sama Electricity Relief Key.txt
Processed Sama FM Egypt Visit Key.txt
Processed Sama PSL Multan Key.txt
Processed SindhTax_Gold.txt
Processed Traffic_Accident_Geo.txt
Processed VehicleCollisionGPT.txt
Processed Women Gold.txt


In [None]:
# Now compute WER for each pair (before and after SLM) based on mapping.
# Two aggregates are reported:
#   - "Average WER": the unweighted mean of the per-file WERs (macro average).
#   - "Corpus WER": total word errors / total reference words over all files,
#     so longer transcripts weigh more.
results = []
gold_texts, input_texts, output_texts = [], [], []  # collected for corpus-level WER
for mapping in mappings:
    # DictReader fills cells missing from a short row with None.
    filename = mapping['Name'] or ''
    gold_filename = mapping['Gold_path'] or ''
    if not filename.endswith(".txt") or not gold_filename.endswith(".txt"):
        print(f"Skipping non-txt pair: {filename} / {gold_filename}")
        continue

    input_path = os.path.join(input_folder, filename)
    output_path = os.path.join(output_folder, filename)
    gold_path = os.path.join(gold_folder, gold_filename)

    missing = [p for p in (gold_path, input_path, output_path) if not os.path.exists(p)]
    if missing:
        print(f"Files not found for pair: {filename} / {gold_filename} (missing: {missing})")
        continue

    with open(input_path, 'r', encoding='utf-8') as f:
        input_text = f.read().strip()
    with open(output_path, 'r', encoding='utf-8') as f:
        output_text = f.read().strip()
    with open(gold_path, 'r', encoding='utf-8') as f:
        gold_text = f.read().strip()

    if not gold_text:
        # WER is undefined with zero reference words, and jiwer raises on an
        # empty reference. Skip the pair instead of aborting the whole cell.
        print(f"Empty gold transcription, skipping pair: {filename} / {gold_filename}")
        continue

    wer_before = wer(gold_text, input_text)  # WER between gold and input (before SLM)
    wer_after = wer(gold_text, output_text)   # WER between gold and output (after SLM)

    results.append({
        'filename': filename,
        'gold_filename': gold_filename,
        'wer_before': wer_before,
        'wer_after': wer_after
    })
    gold_texts.append(gold_text)
    input_texts.append(input_text)
    output_texts.append(output_text)

    print(f"WER for {filename} (before SLM): {wer_before}")
    print(f"WER for {filename} (after SLM): {wer_after}")

# Write to CSV
if results:
    with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:
        fieldnames = ['filename', 'gold_filename', 'wer_before', 'wer_after']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)

    # Average WER (unweighted mean over files)
    avg_wer_before = sum(r['wer_before'] for r in results) / len(results)
    avg_wer_after = sum(r['wer_after'] for r in results) / len(results)
    print(f"Average WER before SLM: {avg_wer_before}")
    print(f"Average WER after SLM: {avg_wer_after}")

    # Corpus-level WER: jiwer pools errors and reference words across the lists.
    corpus_wer_before = wer(gold_texts, input_texts)
    corpus_wer_after = wer(gold_texts, output_texts)
    print(f"Corpus WER before SLM: {corpus_wer_before}")
    print(f"Corpus WER after SLM: {corpus_wer_after}")
else:
    print("No WER calculations performed.")


WER for 3Alleged Gold.txt (before SLM): 0.22145328719723184
WER for 3Alleged Gold.txt (after SLM): 0.2179930795847751
WER for 11Deaths Gold.txt (before SLM): 0.282798833819242
WER for 11Deaths Gold.txt (after SLM): 0.24198250728862974
WER for AfghanCricketGPT.txt (before SLM): 0.23371647509578544
WER for AfghanCricketGPT.txt (after SLM): 0.19157088122605365
WER for BuildingCollapse Gold.txt (before SLM): 0.20892018779342722
WER for BuildingCollapse Gold.txt (after SLM): 0.19953051643192488
WER for BullyingGPT Gold.txt (before SLM): 0.2542372881355932
WER for BullyingGPT Gold.txt (after SLM): 0.2245762711864407
WER for ConstructionHalt_Gold.txt (before SLM): 0.20085470085470086
WER for ConstructionHalt_Gold.txt (after SLM): 0.20085470085470086
WER for CTD_Gold.txt (before SLM): 0.33796296296296297
WER for CTD_Gold.txt (after SLM): 0.33796296296296297
WER for GasTheft_Gold.txt (before SLM): 0.2364217252396166
WER for GasTheft_Gold.txt (after SLM): 0.22044728434504793
WER for Hamid Mir Im