In [None]:
import transformers, sys, importlib.util
print("Transformers version:", transformers.__version__)
print("Python version:", sys.version)
print("Transformers file:", importlib.util.find_spec("transformers").origin)


Transformers version: 4.44.0
Python version: 3.12.8 (tags/v3.12.8:2dc476b, Dec  3 2024, 19:30:04) [MSC v.1942 64 bit (AMD64)]
Transformers file: c:\Users\ashaikh\AppData\Local\Programs\Python\Python312\Lib\site-packages\transformers\__init__.py


In [None]:
# ================================
# IMPORTS
# ================================

import os, csv, tarfile, glob, time, datetime, random, torch, re, evaluate, unicodedata, json
import numpy as np
import pandas as pd
import papermill as pm

from tqdm import tqdm
from peft import PeftModel
from datasets import load_dataset, Dataset, Audio, DatasetDict, concatenate_datasets, Features, Value
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from peft import get_peft_model, LoraConfig
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW
from jiwer import wer as jiwer_wer
from huggingface_hub import login
from typing import List, Callable
from itertools import product


In [None]:
# ================================
# CONFIGURATION
# ================================

# Experiment settings
EXPERIMENT_NAME = "finetuning-29"
RANDOM_SEED = 42

# Model and LoRA config
BASE_MODEL_NAME = "openai/whisper-large-v3-turbo" # "openai/whisper-large-v3" # "openai/whisper-large-v2" 
LORA_R = 16 #32
LORA_ALPHA = 32 #64
LORA_DROPOUT = 0.05
LORA_TARGET_MODULES = ["q_proj", "v_proj"] #, "k_proj", "out_proj"]

# Training config
LEARNING_RATE = 1e-4
BATCH_SIZE = 8
NUM_EPOCHS = 2 #12
FP16 = True
MAX_LABEL_LENGTH = 128

# Dataset config
TARGET_SR = 16000
AUDIO_COL = "audio"
TEXT_COL = "transcription"
TRAIN_NUM_SAMPLES = 8000  # None = full set
TEST_NUM_SAMPLES = None   # None = full set
EVAL_FROM_TRAIN_PCT = 0  # 0.05 = 5% validation from train

# Output files
PREDICTIONS_CSV = f"{EXPERIMENT_NAME}_predictions.csv"
SUMMARY_CSV = f"{EXPERIMENT_NAME}_summary.csv"

# Set random seeds
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)


In [None]:
# Pick from: "commonvoice", "fleurs", "csalt" or None
# At minimum, train_1 and test_1 must be non-None
train_1 = "commonvoice"
train_2 = "fleurs"
train_3 = None

test_1  = "csalt"
test_2  = None
test_3  = None


# helper functions

In [None]:
# ================================
# TEXT NORMALIZATION
# ================================

# -----------------------------
# Core normalization utilities
# -----------------------------

_ARABIC_DIACRITICS = re.compile(
    "["                             # Arabic diacritics range
    "\u0610-\u061A"                 # honorifics, small high
    "\u064B-\u065F"                 # tanwin/harakat
    "\u0670"                        # superscript alef
    "\u06D6-\u06ED"                 # Quranic marks
    "]"
)

# Zero-width & elongation
_ZW_CHARS = re.compile("[\u200B-\u200F\u202A-\u202E\u2066-\u2069]")
_KASHIDA  = re.compile("\u0640")  # tatweel

# Arabic presentation forms (NFKC will canonicalize most)
def _compat_normalize(s: str) -> str:
    # Normalize compatibility forms and spacing
    s = unicodedata.normalize("NFKC", s)
    # Remove bidi/zero-width and kashida
    s = _ZW_CHARS.sub("", s)
    s = _KASHIDA.sub("", s)
    # Remove diacritics
    s = _ARABIC_DIACRITICS.sub("", s)
    return s

# Map Arabic/Urdu codepoints to a single canonical set often used in Urdu
# (Farsi Yeh, Heh goal, etc.)
def _canonical_codepoints(s: str) -> str:
    # Unify Yeh forms: U+064A (Arabic Yeh), U+06CC (Farsi Yeh) -> choose U+06CC
    s = s.replace("\u064A", "\u06CC")
    # Unify Alef Maksura (rare in Urdu) to Farsi Yeh as well (defensive)
    s = s.replace("\u0649", "\u06CC")
    # Unify Heh goal variants: ÿ©/Ÿá/€Å/€Ç ‚Üí €Å (U+06C1) when appropriate
    # Keep it simple/robust for scoring:
    s = s.replace("\u06C0", "\u06C1")  # heh with hamza above ‚Üí heh goal
    # Don't over-aggressively rewrite 'Ÿá' to '€Å' (Arabic heh to Urdu heh goal),
    # but we can do a light pass:
    s = re.sub(r"(?<=\S)\u0647(?=\b)", "\u06C1", s)  # word-final Arabic heh ‚Üí Urdu heh goal
    return s

# Digits: normalize both Latin and Arabic-Indic to Arabic-Indic (or remove)
_ARABIC_INDIC_DIGITS = str.maketrans(
    "0123456789"
    "Ÿ†Ÿ°Ÿ¢Ÿ£Ÿ§Ÿ•Ÿ¶ŸßŸ®Ÿ©"
    "€∞€±€≤€≥€¥€µ€∂€∑€∏€π",
    "€∞€±€≤€≥€¥€µ€∂€∑€∏€π" * 3  # map Latin + Arabic-Indic + Extended to Extended Arabic-Indic
)
def _normalize_digits(s: str) -> str:
    return s.translate(_ARABIC_INDIC_DIGITS)

# Remove punctuation & special markers (keep intra-word apostrophes if you want)
_PUNCT = re.compile(r"[^\w\s\u0600-\u06FF\u0750-\u077F\u08A0-\u08FF]")  # drop non-Arabic/word chars
# Seamless-style disfluencies: remove tokens like #um #uh #laugh
_SEAMLESS_DISFL = re.compile(r"(?<!\w)#\w+")

def _strip_punct_and_disfluencies(s: str) -> str:
    s = _SEAMLESS_DISFL.sub(" ", s)
    # Convert underscores/odd joins to space first (defensive)
    s = s.replace("_", " ")
    s = _PUNCT.sub(" ", s)
    return s

def _squash_spaces(s: str) -> str:
    return re.sub(r"\s+", " ", s).strip()

# -------------------------------------------
# Orthographic + token-segmentation variants
# -------------------------------------------

# Frequent variants noted in paper: "⁄Üÿß€Å€å€í" spellings; "€ÅŸà ⁄Øÿß/€ÅŸà⁄Øÿß" etc.
_VARIANT_CANON = [
    # --- ⁄Üÿß€Å€å€í (imperative/necessity) canonicalization ---
    # Variants: ⁄Üÿß€Å€åÿ¶€í / ⁄Üÿß⁄æ€å€í / ⁄Üÿß€Åÿ¶€í / ⁄Üÿß€Å€å€ì, etc ‚Üí ⁄Üÿß€Å€å€í
    (re.compile(r"\b⁄Üÿß€Å€å[ÿ¶€í€ì]\b"), "⁄Üÿß€Å€å€í"),
    (re.compile(r"\b⁄Üÿß⁄æ€å[ÿ¶€í€ì]\b"), "⁄Üÿß€Å€å€í"),
    (re.compile(r"\b⁄Üÿß€Å[ÿ¶€í€ì]\b"), "⁄Üÿß€Å€å€í"),
    # common stem "chahie" unvoweled variants
    (re.compile(r"\b⁄Üÿß€Å€å?€í\b"), "⁄Üÿß€Å€å€í"),

    # --- €ÅŸà⁄Øÿß family: space-insensitive joining ---
    (re.compile(r"\b€ÅŸà\s+⁄Øÿß\b"), "€ÅŸà⁄Øÿß"),
    (re.compile(r"\b€ÅŸà\s+⁄Ø€å\b"), "€ÅŸà⁄Ø€å"),
    (re.compile(r"\b€ÅŸà\s+⁄Ø€í\b"), "€ÅŸà⁄Ø€í"),
    # The reverse (split) hardly needed if we canonicalize to joined forms

    # Misc. common merges/splits seen in practice (add as you observe)
    (re.compile(r"\b⁄©Ÿà ÿ¶€å\b"), "⁄©Ÿàÿ¶€å"),
    (re.compile(r"\b⁄©€Å\b"), "⁄©€Å"),  # noop example; placeholders for future
]

def _apply_variant_canon(s: str) -> str:
    for pat, rep in _VARIANT_CANON:
        s = pat.sub(rep, s)
    return s

# -----------------------------
# Public normalizer
# -----------------------------
def normalize_urdu_text(text: str) -> str:
    """
    Robust normalizer for Urdu ASR scoring:
    - Unicode compatibility & diacritics removal
    - Canonical Urdu codepoints (Yeh/Heh goal)
    - Remove Seamless-style '#um' disfluencies
    - Remove punctuation
    - Normalize digits (Latin/Arabic to Eastern Arabic-Indic)
    - Canonicalize frequent orthographic variants (⁄Üÿß€Å€å€í, €ÅŸà⁄Øÿß~€ÅŸà ⁄Øÿß)
    - Space squashing
    """
    if not text:
        return ""

    s = text

    # 1) Unicode & presentation forms ‚Üí canonical, drop tatweel/ZW & diacritics
    s = _compat_normalize(s)

    # 2) Canonical Urdu codepoints
    s = _canonical_codepoints(s)

    # 3) Disfluencies + punctuation
    s = _strip_punct_and_disfluencies(s)

    # 4) Digits (optional; or drop all digits if your refs omit numbers)
    s = _normalize_digits(s)

    # 5) Orthographic canonicalizations & token segmentation fixes
    s = _apply_variant_canon(s)

    # 6) Collapse spaces
    s = _squash_spaces(s)

    return s

# ---------------------------------------------------------
# Optional: "lenient" comparison for WER with variants
# ---------------------------------------------------------

# Define lightweight variant generators for lattice expansion on very frequent cases.
# Keep these sets tight to avoid combinatorial blow-up.
_VARIANT_RULES = {
    "⁄Üÿß€Å€å€í": {"⁄Üÿß€Å€å€í", "⁄Üÿß€Åÿ¶€í", "⁄Üÿß€Å€åÿ¶€í", "⁄Üÿß⁄æ€å€í", "⁄Üÿß€Å€å€ì"},
    "€ÅŸà⁄Øÿß": {"€ÅŸà⁄Øÿß", "€ÅŸà ⁄Øÿß"},
    "€ÅŸà⁄Ø€å": {"€ÅŸà⁄Ø€å", "€ÅŸà ⁄Ø€å"},
    "€ÅŸà⁄Ø€í": {"€ÅŸà⁄Ø€í", "€ÅŸà ⁄Ø€í"},
}

def _expand_variants(tokens: List[str]) -> List[List[str]]:
    expanded_per_token = []
    for tok in tokens:
        expanded_per_token.append(list(_VARIANT_RULES.get(tok, {tok})))
    # Cartesian product over tokens to build candidate sequences
    return [list(prod) for prod in product(*expanded_per_token)]

def generate_lenient_variants(s: str) -> List[str]:
    """
    Given a normalized string, produce a small set of alternative strings
    accounting for the most common spelling/spacing variants.
    """
    toks = s.split()
    seqs = _expand_variants(toks)
    return [" ".join(seq) for seq in seqs]

# Example of usage with jiwer:
# def lenient_min_wer(ref: str, hyp: str, normalizer: Callable[[str], str] = normalize_urdu_text) -> float:
#     r = normalizer(ref)
#     h = normalizer(hyp)
#     r_cands = generate_lenient_variants(r)
#     h_cands = generate_lenient_variants(h)
#     # Compute min WER across small lattice of variants
#     scores = []
#     for rc in r_cands:
#         for hc in h_cands:
#             scores.append(jiwer_wer(rc, hc))
#     return min(scores) if scores else jiwer_wer(r, h)

print("‚úÖ Text normalization function loaded")


‚úÖ Text normalization function loaded


In [None]:
# ================================
# SETUP
# ================================

overall_start_time = time.time()
print(f"üïê Experiment started: {datetime.datetime.fromtimestamp(overall_start_time).strftime('%Y-%m-%d %H:%M:%S')}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"‚úÖ Using device: {device}")
if torch.cuda.is_available():
    print(f"‚úÖ GPU: {torch.cuda.get_device_name(0)}")
    print(f"‚úÖ Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Login to HuggingFace
login(token="HF_TOKEN")


üïê Experiment started: 2025-11-15 21:06:35
‚úÖ Using device: cuda
‚úÖ GPU: NVIDIA A40
‚úÖ Available GPU memory: 48.31 GB


# data laoding

In [20]:
# ================================
# DATA LOADING HELPERS
# ================================

def ensure_audio_and_text(ds, text_keys=("transcription", "sentence", "text", "label")):
    """Standardize column names to 'audio' and 'transcription'"""
    # Ensure TEXT_COL
    if TEXT_COL not in ds.column_names:
        for k in text_keys:
            if k in ds.column_names:
                ds = ds.rename_column(k, TEXT_COL)
                break
    if TEXT_COL not in ds.column_names:
        raise ValueError("Could not find transcript column")

    # Ensure AUDIO_COL and cast
    if AUDIO_COL not in ds.column_names:
        cand = next((c for c in ds.column_names if c.lower() in ("audio", "path", "file")), None)
        if cand:
            ds = ds.rename_column(cand, AUDIO_COL)
    
    # Always cast audio to ensure consistent sampling rate and format
    ds = ds.cast_column(AUDIO_COL, Audio(sampling_rate=TARGET_SR, mono=True, decode=True))
    
    return ds

def subsample_after_shuffle(ds, n, seed=RANDOM_SEED):
    """Shuffle and subsample dataset"""
    if n is None or n <= 0 or n >= len(ds):
        return ds
    return ds.shuffle(seed=seed).select(range(n))

def load_csalt_raw():
    ds_all = load_dataset("urdu-asr/csalt-voice", token=False)
    train_like = ensure_audio_and_text(ds_all["validation"])
    return DatasetDict({"train": train_like})

def load_fleurs_raw():
    """Load FLEURS Urdu (ur_pk + ur_in + ur) and merge all splits"""
    all_langs = []
    for lang_code in ["ur_pk", "ur_in", "ur"]:
        try:
            dataset = load_dataset("google/fleurs", lang_code, trust_remote_code=True)
            all_langs.append(dataset)
            print(f"‚úÖ Loaded FLEURS split for {lang_code} with splits: {list(dataset.keys())}")
        except Exception:
            print(f"‚ö†Ô∏è Could not load FLEURS language code: {lang_code}")
            continue

    if not all_langs:
        raise ValueError("Could not load any FLEURS Urdu variants")

    # Merge all language variants together
    merged = {}
    for split in ["train", "validation", "test"]:
        merged_splits = [
            ensure_audio_and_text(ds[split]) for ds in all_langs if split in ds
        ]
        if merged_splits:
            merged[split] = concatenate_datasets(merged_splits)

    print(f"‚úÖ Combined FLEURS Urdu splits: {', '.join(merged.keys())}")
    return DatasetDict(merged)

def load_commonvoice_v23_local(src_path):
    r"""
    src_path: path to mcv-scripted-ur-v23.0.tar.gz OR to an extracted folder OR directly to ...\ur
    Returns DatasetDict with {train, validation, test} (or train-only via validated.tsv) with AUDIO_COL/TEXT_COL ready.
    """
    # 0) Decide root_dir
    if src_path.lower().endswith(".tar.gz"):
        base_dir = os.path.splitext(os.path.splitext(src_path)[0])[0]  # strip .tar.gz
        if not os.path.isdir(base_dir):
            print(f"üì¶ Extracting {os.path.basename(src_path)} ...")
            with tarfile.open(src_path, "r:gz") as tf:
                tf.extractall(os.path.dirname(src_path))
        root_dir = os.path.dirname(src_path)  # search under Downloads after extraction
    else:
        root_dir = src_path

    # 1) Find the folder that has clips/ and tsvs (search any depth)
    def has_cv_files(d):
        clips_ok = os.path.isdir(os.path.join(d, "clips"))
        files = {f.lower() for f in os.listdir(d) if os.path.isfile(os.path.join(d, f))}
        tsv_ok = (
            ("train.tsv" in files)
            and (("dev.tsv" in files) or ("validation.tsv" in files))
            and ("test.tsv" in files)
        ) or ("validated.tsv" in files)  # fallback
        return clips_ok and tsv_ok

    cv_dir = None
    if os.path.isdir(root_dir) and has_cv_files(root_dir):
        cv_dir = root_dir
    else:
        for d, dirs, files in os.walk(root_dir):
            if has_cv_files(d):
                cv_dir = d
                break

    if cv_dir is None:
        raise FileNotFoundError(
            "Could not locate a folder that contains clips/ and train/dev(test)/validated TSVs."
        )

    data_dir  = cv_dir
    clips_dir = os.path.join(cv_dir, "clips")
    print(f"üìÅ Using data_dir: {data_dir}")
    print(f"üéß Using clips_dir: {clips_dir}")

    def has_cv_files(d):
        return (
            os.path.isdir(os.path.join(d, "clips")) and
            (
                (os.path.exists(os.path.join(d, "train.tsv")) and
                 (os.path.exists(os.path.join(d, "dev.tsv")) or os.path.exists(os.path.join(d, "validation.tsv"))) and
                 os.path.exists(os.path.join(d, "test.tsv")))
                or os.path.exists(os.path.join(d, "validated.tsv"))
            )
        )

    cv_dir = None
    if os.path.isdir(root_dir) and has_cv_files(root_dir):
        cv_dir = root_dir
    else:
        for d, _, _ in os.walk(root_dir):
            if has_cv_files(d):
                cv_dir = d
                break
    if cv_dir is None:
        raise FileNotFoundError("Could not locate Common Voice 'ur' folder with clips/ and TSVs.")

    print(f"üìÅ Using data_dir: {cv_dir}")
    clips_dir = os.path.join(cv_dir, "clips")
    print(f"üéß Using clips_dir: {clips_dir}")

    # --- build data_files map (dev vs validation) ---
    train_tsv = os.path.join(cv_dir, "train.tsv")
    dev_tsv   = os.path.join(cv_dir, "dev.tsv")
    val_tsv   = os.path.join(cv_dir, "validation.tsv")
    test_tsv  = os.path.join(cv_dir, "test.tsv")
    validated = os.path.join(cv_dir, "validated.tsv")

    if os.path.exists(train_tsv) and (os.path.exists(dev_tsv) or os.path.exists(val_tsv)) and os.path.exists(test_tsv):
        data_files = {
            "train": train_tsv,
            "validation": dev_tsv if os.path.exists(dev_tsv) else val_tsv,
            "test": test_tsv,
        }
    elif os.path.exists(validated):
        data_files = {"train": validated}
    else:
        raise FileNotFoundError("Expected train/dev(or validation)/test TSVs or validated.tsv not found.")

    # --- NEW: force all TSV columns to string to avoid Arrow casting issues ---
    # read header columns from a representative TSV (train preferred)
    header_probe = data_files.get("train") or next(iter(data_files.values()))
    with open(header_probe, "r", encoding="utf-8") as f:
        reader = csv.reader(f, delimiter="\t")
        header = next(reader)
    string_features = Features({col: Value("string") for col in header})

    # load TSVs with forced string schema
    ds_all = load_dataset(
        "csv",
        data_files=data_files,
        delimiter="\t",
        features=string_features,   # <- key line
    )

    # work out text column
    first_split = next(iter(ds_all.keys()))
    cols = ds_all[first_split].column_names
    text_col_name = "sentence" if "sentence" in cols else ("text" if "text" in cols else None)
    if text_col_name is None:
        raise ValueError(f"No text column found. Columns: {cols}")

    # expand audio paths
    audio_col_name = "path"
    def add_full_audio_path(batch):
        batch[audio_col_name] = [os.path.join(clips_dir, fn) for fn in batch[audio_col_name]]
        return batch
    ds_all = ds_all.map(add_full_audio_path, batched=True)

    # filter rows with missing audio
    def file_exists(batch):
        return [os.path.exists(p) for p in batch[audio_col_name]]
    for split in list(ds_all.keys()):
        before = len(ds_all[split])
        ds_all[split] = ds_all[split].filter(file_exists, batched=True)
        after = len(ds_all[split])
        print(f"‚úÖ {split}: kept {after}/{before} rows (dropped {before - after} missing files)")

    # standardize + cast audio
    ds_all = ds_all.rename_column(audio_col_name, AUDIO_COL)
    if text_col_name != TEXT_COL:
        ds_all = ds_all.rename_column(text_col_name, TEXT_COL)
    ds_all = ds_all.cast_column(AUDIO_COL, Audio(sampling_rate=TARGET_SR, mono=True, decode=True))

    # final dict
    dd = {}
    for split in ["train", "validation", "test"]:
        if split in ds_all:
            dd[split] = ensure_audio_and_text(ds_all[split])
    return DatasetDict(dd)

# def load_commonvoice_raw():
#     ds_all = load_dataset("mozilla-foundation/common_voice_23_0", "ur", revision="9d10386a731ff6e6ed4ec973a4dc204a9820e8c842fbe388bdba0dd205ed5016", trust_remote_code=True, token=True)
#     dd = {}
#     for split in ["train", "validation", "test"]:
#         if split in ds_all:
#             ds = ds_all[split]
#             if "sentence" in ds.column_names:
#                 ds = ds.rename_column("sentence", TEXT_COL)
#             dd[split] = ensure_audio_and_text(ds)
#     return DatasetDict(dd)

# def load_commonvoice_raw():
#     # Root folder with TSVs and clips/
#     data_dir = r"C:\Users\shaider\Downloads\ur-20251104T134315Z-1-001\ur"
#     clips_dir = os.path.join(data_dir, "clips")

#     # Load local Common Voice TSVs
#     ds_all = load_dataset(
#         "csv",
#         data_files={
#             "train": os.path.join(data_dir, "train.tsv"),
#             "validation": os.path.join(data_dir, "dev.tsv"),
#             "test": os.path.join(data_dir, "test.tsv"),
#         },
#         delimiter="\t",
#     )

#     # Common Voice local TSV columns
#     audio_col_name = "path"       # filename only in TSV
#     text_col_name = "sentence"    # text column in TSV

#     # 1) Expand to absolute file paths
#     def add_full_audio_path(batch):
#         batch[audio_col_name] = [os.path.join(clips_dir, fname) for fname in batch[audio_col_name]]
#         return batch

#     ds_all = ds_all.map(add_full_audio_path, batched=True)

#     # 2) Filter out rows where the audio file does NOT exist
#     def file_exists(batch):
#         return [os.path.exists(p) for p in batch[audio_col_name]]

#     for split in list(ds_all.keys()):
#         before = len(ds_all[split])
#         ds_all[split] = ds_all[split].filter(file_exists, batched=True)
#         after = len(ds_all[split])
#         print(f"‚úÖ {split}: kept {after}/{before} rows (dropped {before - after} missing files)")

#     # 3) Standardize column names
#     ds_all = ds_all.rename_column(audio_col_name, AUDIO_COL)
#     ds_all = ds_all.rename_column(text_col_name, TEXT_COL)

#     # 4) Cast audio AFTER filtering
#     ds_all = ds_all.cast_column(AUDIO_COL, Audio(sampling_rate=TARGET_SR, mono=True, decode=True))

#     # 5) Ensure column names for safety (idempotent)
#     dd = {}
#     for split in ["train", "validation", "test"]:
#         if split in ds_all:
#             dd[split] = ensure_audio_and_text(ds_all[split])

#     return DatasetDict(dd)


In [None]:
# ================================
# LOAD AND PREPARE DATASETS (DYNAMIC)
# ================================

print("\n" + "="*50)
print("üìä LOADING DATASETS (dynamic)")
print("="*50)

# 1) Load raw DatasetDicts (unchanged)
print("Loading CommonVoice...")
commonvoice = load_commonvoice_v23_local(r"C:\Users\shaider\Downloads\cv-corpus-23.0-2025-09-05\ur")

print("Loading FLEURS...")
fleurs = load_fleurs_raw()

print("Loading CSaLT...")
csalt = load_csalt_raw()

def merge_all_splits(ds_dict):
    """
    Concatenate all available splits from a DatasetDict.
    This mirrors your previous logic (train+validation+test).
    """
    available = [ds_dict[s] for s in ["train", "validation", "test"] if s in ds_dict]
    if not available:
        raise ValueError("No splits found to merge in provided DatasetDict.")
    return concatenate_datasets(available)

def safe_select_columns(ds, wanted_cols):
    """
    Select only the columns that actually exist to avoid KeyError
    if a source is missing one. (Typically both AUDIO_COL and TEXT_COL exist.)
    """
    keep = [c for c in wanted_cols if c in ds.column_names]
    if not keep:
        raise ValueError(
            f"None of the requested columns {wanted_cols} are present in {ds.column_names}"
        )
    return ds.select_columns(keep)

# 2) Build a prepared (merged + column-selected) registry for each dataset name
prepared_registry = {
    "commonvoice": safe_select_columns(merge_all_splits(commonvoice), [AUDIO_COL, TEXT_COL]),
    "fleurs":      safe_select_columns(merge_all_splits(fleurs),      [AUDIO_COL, TEXT_COL]),
    "csalt":       safe_select_columns(merge_all_splits(csalt),       [AUDIO_COL, TEXT_COL]),
}

# 3) Helpers to resolve user choices into a list of prepared datasets
def resolve_choice(name: str | None):
    if name is None:
        return None
    key = name.strip().lower()
    if key not in prepared_registry:
        valid = ", ".join(sorted(prepared_registry.keys()))
        raise ValueError(f"Unknown dataset '{name}'. Valid options: {valid} or None.")
    return prepared_registry[key]

def build_pool(*names):
    """
    Given up to three names/None, return a concatenated dataset
    of all non-None selections. Requires at least one non-None.
    """
    selected = [resolve_choice(n) for n in names if n is not None]
    if not selected:
        raise ValueError("At least one dataset must be selected to build a pool.")
    if len(selected) == 1:
        return selected[0]
    return concatenate_datasets(selected)

# 4) Resolve TRAIN and TEST pools from the six choices
#    (Shuffle + optional subsample mirrors your original behavior)
print("\n" + "-"*50)
print("üß© Building TRAIN pool from user choices...")
train_pool = build_pool(train_1, train_2, train_3).shuffle(seed=RANDOM_SEED)

# Optional subsampling (disabled if TRAIN_NUM_SAMPLES=None)
train_ds = subsample_after_shuffle(train_pool, TRAIN_NUM_SAMPLES, seed=RANDOM_SEED)

# Optional: carve validation from train (unchanged)
validation_ds = None
if EVAL_FROM_TRAIN_PCT > 0.0:
    n_eval = int(len(train_ds) * EVAL_FROM_TRAIN_PCT)
    if n_eval > 0:
        validation_ds = train_ds.select(range(n_eval))
        train_ds = train_ds.select(range(n_eval, len(train_ds)))
        print(f"‚úÖ Validation carved from train: {len(validation_ds)}")

print("\n" + "-"*50)
print("üß™ Building TEST pool from user choices...")
test_pool = build_pool(test_1, test_2, test_3).shuffle(seed=RANDOM_SEED)

# Optional subsampling for test (same helper you already have)
test_ds = subsample_after_shuffle(test_pool, TEST_NUM_SAMPLES, seed=RANDOM_SEED)

# 5) Summaries
def _fmt(x): return x if x is not None else "-"
print("\n" + "="*50)
print("‚úÖ FINAL DATASET SIZES")
print("="*50)
print(f"Train set: {len(train_ds)} samples")
if validation_ds is not None:
    print(f"Validation set: {len(validation_ds)} samples")
print(f"Test set:  {len(test_ds)} samples")

print("\n" + "="*50)
print("üìù DATASET SOURCES (for this run)")
print("="*50)
print(f"train_1: {_fmt(train_1)} | train_2: {_fmt(train_2)} | train_3: {_fmt(train_3)}")
print(f"test_1:  {_fmt(test_1)}  | test_2:  {_fmt(test_2)}  | test_3:  {_fmt(test_3)}")



üìä LOADING DATASETS (dynamic)
Loading CommonVoice...
üìÅ Using data_dir: C:\Users\shaider\Downloads\cv-corpus-23.0-2025-09-05\ur
üéß Using clips_dir: C:\Users\shaider\Downloads\cv-corpus-23.0-2025-09-05\ur\clips
üìÅ Using data_dir: C:\Users\shaider\Downloads\cv-corpus-23.0-2025-09-05\ur
üéß Using clips_dir: C:\Users\shaider\Downloads\cv-corpus-23.0-2025-09-05\ur\clips
‚úÖ train: kept 7336/7336 rows (dropped 0 missing files)
‚úÖ validation: kept 5045/5045 rows (dropped 0 missing files)
‚úÖ test: kept 5088/5088 rows (dropped 0 missing files)
Loading FLEURS...
‚úÖ Loaded FLEURS split for ur_pk with splits: ['train', 'validation', 'test']
‚ö†Ô∏è Could not load FLEURS language code: ur_in
‚ö†Ô∏è Could not load FLEURS language code: ur
‚úÖ Combined FLEURS Urdu splits: train, validation, test
Loading CSaLT...

--------------------------------------------------
üß© Building TRAIN pool from user choices...

--------------------------------------------------
üß™ Building TEST pool from 

# model

In [None]:
# ================================
# MODEL SETUP
# ================================

print("\n" + "="*50)
print("üîß MODEL SETUP")
print("="*50)

# Load processor
processor = WhisperProcessor.from_pretrained(BASE_MODEL_NAME)
tokenizer = processor.tokenizer
feature_extractor = processor.feature_extractor
tokenizer.pad_token = tokenizer.eos_token

print(f"‚úÖ Loaded processor from {BASE_MODEL_NAME}")

# Load base model
print(f"Loading model in {'FP16' if FP16 else 'FP32'} precision...")
model = WhisperForConditionalGeneration.from_pretrained(
    BASE_MODEL_NAME,
    torch_dtype=torch.float16 if FP16 else torch.float32
)

# ‚úÖ Force Urdu-only transcription mode (no English translation)
model.config.forced_decoder_ids = None
model.generation_config.forced_decoder_ids = None
model.config.language = "ur"
model.config.task = "transcribe"
model.generation_config.language = "ur"
model.generation_config.task = "transcribe"

print("‚úÖ Configured model for Urdu transcription only (no English translation)")

# Apply LoRA
lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
    target_modules=LORA_TARGET_MODULES
)

model = get_peft_model(model, lora_config)
model.forward = model.base_model.forward

print("\nüìä Trainable Parameters:")
model.print_trainable_parameters()

model = model.to(device)



üîß MODEL SETUP
‚úÖ Loaded processor from openai/whisper-large-v3-turbo
Loading model in FP16 precision...
‚úÖ Configured model for Urdu transcription only (no English translation)

üìä Trainable Parameters:
trainable params: 3,276,800 || all params: 812,154,880 || trainable%: 0.4035


In [None]:
# ================================
# DATA PREPROCESSING
# ================================

print("\n" + "="*50)
print("üîÑ PREPROCESSING DATA")
print("="*50)

def prepare_dataset(batch):
    """Preprocess audio and text for Whisper"""
    audio = batch[AUDIO_COL]
    
    # Process audio
    inputs = processor(
        audio["array"],
        sampling_rate=audio["sampling_rate"],
        return_tensors="pt"
    )
    batch["input_features"] = inputs.input_features[0]
    
    # Process text
    tokenized = tokenizer(
        batch[TEXT_COL],
        padding="max_length",
        max_length=MAX_LABEL_LENGTH,
        truncation=True,
        return_tensors="pt"
    )
    batch["labels"] = tokenized.input_ids[0]
    
    return batch

# Preprocess datasets
train_ds = train_ds.map(
    prepare_dataset,
    remove_columns=train_ds.column_names,
    desc="Preprocessing train set"
)

if validation_ds:
    validation_ds = validation_ds.map(
        prepare_dataset,
        remove_columns=validation_ds.column_names,
        desc="Preprocessing validation set"
    )

test_ds = test_ds.map(
    prepare_dataset,
    remove_columns=test_ds.column_names,
    desc="Preprocessing test set"
)

print(f"‚úÖ Preprocessing complete")



üîÑ PREPROCESSING DATA
‚úÖ Preprocessing complete


In [None]:
# ================================
# PRE-TRAINING EVALUATION
# ================================

print("\n" + "="*50)
print("üîç PRE-TRAINING WER EVALUATION")
print("="*50)

def evaluate_model(model, test_dataset, device, desc="Evaluating"):
    """Evaluate model and return WER metrics"""
    model.eval()
    
    predictions = []
    references = []
    predictions_raw = []  # Store raw predictions for debugging
    references_raw = []   # Store raw references for debugging
    
    with torch.no_grad():
        for sample in tqdm(test_dataset, desc=desc):
            input_features = torch.tensor(sample["input_features"]).unsqueeze(0).to(device)
            
            if FP16:
                input_features = input_features.half()
            
            pred_ids = model.generate(input_features=input_features)
            pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)[0].strip()
            
            # Get reference from the preprocessed labels
            label_ids = sample["labels"]
            # Remove padding tokens
            label_ids = [id for id in label_ids if id != tokenizer.pad_token_id]
            label_str = tokenizer.decode(label_ids, skip_special_tokens=True).strip()
            
            # Store raw versions
            predictions_raw.append(pred_str)
            references_raw.append(label_str)
            
            # *** APPLY TEXT NORMALIZATION HERE ***
            pred_str_normalized = normalize_urdu_text(pred_str)
            label_str_normalized = normalize_urdu_text(label_str)
            
            predictions.append(pred_str_normalized)
            references.append(label_str_normalized)
    
    # Calculate WER on normalized text
    sample_wers = [jiwer_wer(ref, pred) for ref, pred in zip(references, predictions)]
    overall_wer = np.mean(sample_wers)
    
    return {
        "predictions": predictions,
        "references": references,
        "predictions_raw": predictions_raw,  # Include raw for debugging
        "references_raw": references_raw,
        "sample_wers": sample_wers,
        "overall_wer": overall_wer
    }

# Evaluate before fine-tuning
pre_results = evaluate_model(model, test_ds, device, desc="Pre-training evaluation")
pre_training_wer = pre_results["overall_wer"]

print(f"\nüìä PRE-TRAINING WER: {pre_training_wer:.4f} ({pre_training_wer*100:.2f}%)")

# Optional: Show some examples to verify normalization is working
print("\nüîç Sample Normalization Examples:")
for i in range(min(3, len(pre_results["predictions"]))):
    print(f"\nExample {i+1}:")
    print(f"  Raw Reference:  {pre_results['references_raw'][i][:100]}")
    print(f"  Norm Reference: {pre_results['references'][i][:100]}")
    print(f"  Raw Prediction: {pre_results['predictions_raw'][i][:100]}")
    print(f"  Norm Prediction: {pre_results['predictions'][i][:100]}")



üîç PRE-TRAINING WER EVALUATION


Pre-training evaluation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 471/471 [14:35<00:00,  1.86s/it]


üìä PRE-TRAINING WER: 0.3101 (31.01%)

üîç Sample Normalization Examples:

Example 1:
  Raw Reference:  ⁄©ÿ±€å⁄∫ ÿßÿ≥ÿ™ÿπŸÖÿßŸÑ ÿßŸàÿ± €å€Å ÿ®ÿß€Åÿ± ⁄©ÿß ⁄©⁄æÿßŸÜÿß ÿ®⁄æ€å ⁄Ü⁄æŸà⁄ë ÿØ€å⁄∫ €ÅŸÖ ÿßÿµŸÑ ŸÖ€å⁄∫ €ÅŸÖ ŸÑŸà⁄Ø ÿ±€Åÿ™€í ÿ®⁄æ€å ÿ®ÿß€Åÿ± €Å€å⁄∫ ÿ™Ÿà Ÿæ⁄æÿ± €ÅŸÖÿßÿ±ÿß ÿßÿ≥ÿ™ÿπ
  Norm Reference: ⁄©ÿ±€å⁄∫ ÿßÿ≥ÿ™ÿπŸÖÿßŸÑ ÿßŸàÿ± €å€Å ÿ®ÿß€Åÿ± ⁄©ÿß ⁄©⁄æÿßŸÜÿß ÿ®⁄æ€å ⁄Ü⁄æŸà⁄ë ÿØ€å⁄∫ €ÅŸÖ ÿßÿµŸÑ ŸÖ€å⁄∫ €ÅŸÖ ŸÑŸà⁄Ø ÿ±€Åÿ™€í ÿ®⁄æ€å ÿ®ÿß€Åÿ± €Å€å⁄∫ ÿ™Ÿà Ÿæ⁄æÿ± €ÅŸÖÿßÿ±ÿß ÿßÿ≥ÿ™ÿπ
  Raw Prediction: ⁄©ÿ±€å⁄∫ ÿßÿ≥ÿ™ÿπŸÖÿßŸÑ ÿßŸàÿ± €å€Å ÿ®ÿß€Åÿ± ⁄©ÿß ÿÆÿßŸÜ€Å ÿ®⁄æ€å ⁄Ü⁄æŸà⁄ëÿ™€í €Å€å⁄∫ €ÅŸÖ ŸÖÿ´ŸÑ ŸÖ€å⁄∫ €ÅŸÖ ŸÑŸà⁄Ø ÿ±€Åÿ™€í €Å€å⁄∫ ÿ®⁄æ€å ÿ®ÿß€Åÿ± €Å€å⁄∫ ÿ™Ÿà Ÿæ⁄æÿ± €ÅŸÖÿßÿ±ÿß
  Norm Prediction: ⁄©ÿ±€å⁄∫ ÿßÿ≥ÿ™ÿπŸÖÿßŸÑ ÿßŸàÿ± €å€Å ÿ®ÿß€Åÿ± ⁄©ÿß ÿÆÿßŸÜ€Å ÿ®⁄æ€å ⁄Ü⁄æŸà⁄ëÿ™€í €Å€å⁄∫ €ÅŸÖ ŸÖÿ´ŸÑ ŸÖ€å⁄∫ €ÅŸÖ ŸÑŸà⁄Ø ÿ±€Åÿ™€í €Å€å⁄∫ ÿ®⁄æ€å ÿ®ÿß€Åÿ± €Å€å⁄∫ ÿ™Ÿà Ÿæ⁄æÿ± €ÅŸÖÿßÿ±ÿß

Example 2:
  Raw Reference:  ÿßÿ≥ŸÑÿßŸÖ ÿπŸÑ€å⁄©ŸÖ ÿπŸÑ€åŸÜÿß
  Norm Reference: ÿßÿ≥ŸÑÿßŸÖ ÿπŸÑ€å⁄©ŸÖ ÿπŸÑ€åŸÜÿß
  Raw Pr




In [None]:
# ================================
# TRAINING SETUP
# ================================

print("\n" + "="*50)
print("üèãÔ∏è TRAINING SETUP")
print("="*50)

def collate_fn(batch):
    """Collate function for DataLoader"""
    input_feats = torch.stack([
        torch.tensor(item["input_features"], dtype=torch.float32)
        for item in batch
    ])
    
    label_tensors = pad_sequence(
        [torch.tensor(item["labels"], dtype=torch.long) for item in batch],
        batch_first=True,
        padding_value=tokenizer.pad_token_id
    )

    return {
        "input_features": input_feats,
        "labels": label_tensors
    }

# Create DataLoader
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
    num_workers=0,
    collate_fn=collate_fn
)

# Setup optimizer
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

# Setup gradient scaler for FP16
scaler = torch.cuda.amp.GradScaler() if FP16 and torch.cuda.is_available() else None

print(f"‚úÖ Optimizer: AdamW (lr={LEARNING_RATE})")
print(f"‚úÖ Batch size: {BATCH_SIZE}")
print(f"‚úÖ Total batches per epoch: {len(train_loader)}")
print(f"‚úÖ Mixed precision (FP16): {FP16}")



üèãÔ∏è TRAINING SETUP
‚úÖ Optimizer: AdamW (lr=0.0001)
‚úÖ Batch size: 8
‚úÖ Total batches per epoch: 1000
‚úÖ Mixed precision (FP16): True


  scaler = torch.cuda.amp.GradScaler() if FP16 and torch.cuda.is_available() else None


## training

In [None]:
# ================================
# TRAINING
# ================================

print("\n" + "="*50)
print("üöÄ STARTING TRAINING")
print("="*50)

train_start_time = time.time()
model.train()
validation_wers = []  # <---- ADD THIS

for epoch in range(NUM_EPOCHS):
    total_loss = 0.0
    print(f"\nüéØ Epoch {epoch+1}/{NUM_EPOCHS}")
    
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        try:
            input_feats = batch["input_features"].to(device)
            labels = batch["labels"].to(device)
            
            # Handle FP16 training
            if FP16 and scaler is not None:
                with torch.cuda.amp.autocast():
                    outputs = model(input_features=input_feats, labels=labels)
                    loss = outputs.loss
                
                optimizer.zero_grad()
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                if FP16:
                    input_feats = input_feats.half()
                
                outputs = model(input_features=input_feats, labels=labels)
                loss = outputs.loss
                
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            
            total_loss += loss.item()
            
        except Exception as e:
            print(f"‚ö†Ô∏è Error processing batch: {e}")
            continue
    
    avg_loss = total_loss / len(train_loader)
    print(f"‚úÖ Epoch {epoch+1} complete ‚Äî Avg Loss: {avg_loss:.4f}")
    
    # Validation if available
    if validation_ds:
        val_results = evaluate_model(model, validation_ds, device, desc="Validation")
        val_wer = round(val_results["overall_wer"], 4)
        print(f"üîé Validation WER: {val_wer:.4f}")
        validation_wers.append(val_wer)  # <---- ADD THIS
        model.train()  # Back to training mode

train_end_time = time.time()
train_duration_secs = int(train_end_time - train_start_time)
train_duration_hms = str(datetime.timedelta(seconds=train_duration_secs))

print(f"\n‚úÖ Training complete! Duration: {train_duration_hms}")



üöÄ STARTING TRAINING

üéØ Epoch 1/2


  with torch.cuda.amp.autocast():
Epoch 1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [56:33<00:00,  3.39s/it]


‚úÖ Epoch 1 complete ‚Äî Avg Loss: 0.1421

üéØ Epoch 2/2


Epoch 2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [55:53<00:00,  3.35s/it]

‚úÖ Epoch 2 complete ‚Äî Avg Loss: 0.0965

‚úÖ Training complete! Duration: 1:52:27





## eval

In [None]:
# ================================
# POST-TRAINING EVALUATION
# ================================

print("\n" + "="*50)
print("üìè POST-TRAINING WER EVALUATION")
print("="*50)

# Evaluate after fine-tuning
post_results = evaluate_model(model, test_ds, device, desc="Post-training evaluation")
post_training_wer = post_results["overall_wer"]

print(f"\nüìä POST-TRAINING WER: {post_training_wer:.4f} ({post_training_wer*100:.2f}%)")

# Calculate improvement
wer_improvement = pre_training_wer - post_training_wer
wer_improvement_pct = (wer_improvement / pre_training_wer) * 100

print(f"\nüéâ WER IMPROVEMENT: {wer_improvement:.4f} ({wer_improvement_pct:.2f}%)")
print(f"   Pre-training:  {pre_training_wer:.4f}")
print(f"   Post-training: {post_training_wer:.4f}")



üìè POST-TRAINING WER EVALUATION


Post-training evaluation: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 471/471 [14:39<00:00,  1.87s/it]


üìä POST-TRAINING WER: 0.3085 (30.85%)

üéâ WER IMPROVEMENT: 0.0016 (0.50%)
   Pre-training:  0.3101
   Post-training: 0.3085





In [28]:
# ================================
# DECODER SWEEP
# ================================
print("\n" + "="*50)
print("üî¨ DECODER SWEEP: Evaluating generation settings")
print("="*50)

# Subsample test set for faster sweep
SWEEP_SAMPLE_SIZE = min(500, len(test_ds))
sweep_test_ds = test_ds.shuffle(seed=RANDOM_SEED).select(range(SWEEP_SAMPLE_SIZE))
print(f"üìä Using {SWEEP_SAMPLE_SIZE} samples for decoder sweep")

from itertools import product

# --- keep this grid small / meaningful; remove params not accepted by generate()
param_grid = {
    "num_beams": [1, 2, 3],                   # keep small for quick test (increase on A40)
    "length_penalty": [None, 0.9, 1.0, 1.1],
    "no_repeat_ngram_size": [None, 2, 3],
    "repetition_penalty": [None, 1.05, 1.1],
    "do_sample": [False],                     # deterministic for ASR
    "max_new_tokens": [225],
}

# Build configs, normalize and dedupe
decoder_configs = []
seen = set()
idx = 0

print("\n‚öôÔ∏è Building smart decoder grid...")

for combo in product(*param_grid.values()):
    cfg = dict(zip(param_grid.keys(), combo))

    # Normalize greedy: if num_beams==1, strip beam-only params
    if cfg["num_beams"] == 1:
        cfg["length_penalty"] = None
        cfg["no_repeat_ngram_size"] = None
        cfg["repetition_penalty"] = None

    # Skip logically invalid combos (no-repeat only makes sense with beams>1)
    if cfg["num_beams"] == 1 and cfg["no_repeat_ngram_size"] is not None:
        continue
    if cfg["do_sample"] and cfg["num_beams"] > 1:
        continue

    # Create deterministic signature for deduplication (ignore name)
    sig_items = tuple(sorted((k, v) for k, v in cfg.items()))
    if sig_items in seen:
        continue
    seen.add(sig_items)

    cfg["name"] = f"cfg_{idx}"
    decoder_configs.append(cfg)
    idx += 1

print(f"üß™ Total decoder configs generated: {len(decoder_configs)}")

# Helper: sanitize gen kwargs (remove None / pandas NA and convert numpy types)
def sanitize_gen_kwargs(cfg):
    gen = {}
    for k, v in cfg.items():
        if k == "name":
            continue
        # drop None/NaN
        if v is None:
            continue
        if (isinstance(v, float) and np.isnan(v)) or pd.isna(v):
            continue
        # convert numpy scalar types to native python
        if isinstance(v, (np.integer, np.int64)):
            v = int(v)
        elif isinstance(v, (np.floating, np.float64)):
            v = float(v)
        elif isinstance(v, (np.bool_,)):
            v = bool(v)
        gen[k] = v
    return gen



üî¨ DECODER SWEEP: Evaluating generation settings
üìä Using 471 samples for decoder sweep

‚öôÔ∏è Building smart decoder grid...
üß™ Total decoder configs generated: 73


In [None]:
# ===============================
# EVALUATION CODE (per-config)
# ===============================
def evaluate_decoder_config(model, test_dataset, device, config, desc=""):
    model.eval()
    predictions, references, predictions_raw, references_raw = [], [], [], []

    gen_kwargs = sanitize_gen_kwargs(config)

    # default
    if "num_beams" not in gen_kwargs:
        gen_kwargs["num_beams"] = 1

    with torch.no_grad():
        for sample in tqdm(test_dataset, desc=desc, leave=False):
            # ensure tensor on CPU->GPU
            input_features = torch.tensor(sample["input_features"]).unsqueeze(0).to(device)
            if FP16:
                input_features = input_features.half()

            pred_ids = model.generate(input_features=input_features, **gen_kwargs)
            pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)[0].strip()

            label_ids = [id for id in sample["labels"] if id != tokenizer.pad_token_id]
            label_str = tokenizer.decode(label_ids, skip_special_tokens=True).strip()

            # store raw
            predictions_raw.append(pred_str)
            references_raw.append(label_str)

            # normalize for scoring
            pred_str_norm = normalize_urdu_text(pred_str)
            label_str_norm = normalize_urdu_text(label_str)

            predictions.append(pred_str_norm)
            references.append(label_str_norm)

    sample_wers = [jiwer_wer(r, p) for r, p in zip(references, predictions)]
    mean_wer = float(np.mean(sample_wers))
    return {
        "overall_wer": mean_wer,
        "predictions": predictions,
        "references": references,
        "predictions_raw": predictions_raw,
        "references_raw": references_raw,
        "sample_wers": sample_wers,
    }


In [None]:
# ===============================
# RUN SWEEP
# ===============================
decoder_results = []
decoder_details = {}  # keep per-config detailed outputs if you want

print("\nüîÑ Running decoder sweep...")
for i, config in enumerate(decoder_configs, 1):
    config_name = config["name"]
    desc = f"[{i}/{len(decoder_configs)}] {config_name}"
    print(f"\n‚öôÔ∏è Testing: {config_name}")

    try:
        out = evaluate_decoder_config(model, sweep_test_ds, device, config, desc=desc)
        wer_score = out["overall_wer"]

        result = {
            "config_name": config_name,
            "wer": round(wer_score, 4),
            **{k: v for k, v in config.items() if k != "name"}
        }
        decoder_results.append(result)
        decoder_details[config_name] = out  # store details

        print(f"   ‚úÖ WER: {wer_score:.4f}")

    except Exception as e:
        print(f"   ‚ùå Failed: {e}")
        continue

# Process results
if not decoder_results:
    raise RuntimeError("Decoder sweep yielded no successful results.")

decoder_df = pd.DataFrame(decoder_results).sort_values("wer").reset_index(drop=True)
print("\n" + "="*50)
print("üìä DECODER SWEEP RESULTS (sorted by WER)")
print("="*50)
print(decoder_df.to_string(index=False))

best_row = decoder_df.iloc[0]
best_cfg = best_row  # pandas Series

# Build best_gen_kwargs (cleaned)
best_gen_kwargs = sanitize_gen_kwargs(best_cfg.to_dict())

# Get detailed optimized_results from decoder_details
best_name = best_cfg["config_name"]
optimized_results = decoder_details.get(best_name, None)
if optimized_results is None:
    # fallback: re-evaluate and capture full outputs
    optimized_results = evaluate_decoder_config(model, test_ds, device, best_cfg.to_dict(), desc="Recompute best config")

optimized_wer = optimized_results["overall_wer"]

print("\n" + "="*50)
print("üèÜ BEST DECODER CONFIGURATION")
print("="*50)
print(f"Config: {best_name}")
print(f"WER: {optimized_wer:.4f}")
print("Parameters:")
for k, v in best_gen_kwargs.items():
    print(f"  - {k}: {v}")

# comparison to greedy (robust)
greedy_row = decoder_df[decoder_df.get("num_beams") == 1]
if not greedy_row.empty:
    greedy_wer = greedy_row.iloc[0]["wer"]
    sweep_improvement = greedy_wer - optimized_wer
    sweep_improvement_pct = (sweep_improvement / greedy_wer) * 100 if greedy_wer != 0 else 0.0
    print("\nüìà Improvement over greedy decoding:")
    print(f"   Greedy WER: {greedy_wer:.4f}")
    print(f"   Best WER: {optimized_wer:.4f}")
    print(f"   Improvement: {sweep_improvement:.4f} ({sweep_improvement_pct:.2f}%)")

# Build decoder_sweep_data for saving/summary
decoder_improvement = post_training_wer - optimized_wer
decoder_improvement_pct = (decoder_improvement / post_training_wer) * 100 if post_training_wer != 0 else 0.0

decoder_sweep_data = {
    "best_decoder_config": best_name,
    "best_decoder_config_params": best_gen_kwargs,
    "decoder_sweep_wer": round(optimized_wer, 4),
    "decoder_improvement": round(decoder_improvement, 4),
    "decoder_improvement_percent": round(decoder_improvement_pct, 2),
    "decoder_sweep_df": decoder_df,
    "optimized_results": optimized_results
}

print("\nüéâ Decoder sweep complete!")



üîÑ Running decoder sweep...

‚öôÔ∏è Testing: cfg_0


                                                               

   ‚úÖ WER: 0.2851

‚öôÔ∏è Testing: cfg_1


                                                               

   ‚úÖ WER: 0.2694

‚öôÔ∏è Testing: cfg_2


                                                               

   ‚úÖ WER: 0.2560

‚öôÔ∏è Testing: cfg_3


                                                               

   ‚úÖ WER: 0.2560

‚öôÔ∏è Testing: cfg_4


                                                               

   ‚úÖ WER: 0.3627

‚öôÔ∏è Testing: cfg_5


                                                               

   ‚úÖ WER: 0.3602

‚öôÔ∏è Testing: cfg_6


                                                               

   ‚úÖ WER: 0.3583

‚öôÔ∏è Testing: cfg_7


                                                               

   ‚úÖ WER: 0.3019

‚öôÔ∏è Testing: cfg_8


                                                               

   ‚úÖ WER: 0.2994

‚öôÔ∏è Testing: cfg_9


                                                                

   ‚úÖ WER: 0.2980

‚öôÔ∏è Testing: cfg_10


                                                                 

   ‚úÖ WER: 0.2684

‚öôÔ∏è Testing: cfg_11


                                                                 

   ‚úÖ WER: 0.2558

‚öôÔ∏è Testing: cfg_12


                                                                 

   ‚úÖ WER: 0.2556

‚öôÔ∏è Testing: cfg_13


                                                                 

   ‚úÖ WER: 0.3624

‚öôÔ∏è Testing: cfg_14


                                                                 

   ‚úÖ WER: 0.3591

‚öôÔ∏è Testing: cfg_15


                                                                 

   ‚úÖ WER: 0.3579

‚öôÔ∏è Testing: cfg_16


                                                                 

   ‚úÖ WER: 0.3003

‚öôÔ∏è Testing: cfg_17


                                                                 

   ‚úÖ WER: 0.2983

‚öôÔ∏è Testing: cfg_18


                                                                 

   ‚úÖ WER: 0.2976

‚öôÔ∏è Testing: cfg_19


                                                                 

   ‚úÖ WER: 0.2694

‚öôÔ∏è Testing: cfg_20


                                                                 

   ‚úÖ WER: 0.2560

‚öôÔ∏è Testing: cfg_21


                                                                 

   ‚úÖ WER: 0.2560

‚öôÔ∏è Testing: cfg_22


                                                                 

   ‚úÖ WER: 0.3627

‚öôÔ∏è Testing: cfg_23


                                                                 

   ‚úÖ WER: 0.3602

‚öôÔ∏è Testing: cfg_24


                                                                 

   ‚úÖ WER: 0.3583

‚öôÔ∏è Testing: cfg_25


                                                                 

   ‚úÖ WER: 0.3019

‚öôÔ∏è Testing: cfg_26


                                                                 

   ‚úÖ WER: 0.2994

‚öôÔ∏è Testing: cfg_27


                                                                 

   ‚úÖ WER: 0.2980

‚öôÔ∏è Testing: cfg_28


                                                                 

   ‚úÖ WER: 0.2702

‚öôÔ∏è Testing: cfg_29


                                                                 

   ‚úÖ WER: 0.2567

‚öôÔ∏è Testing: cfg_30


                                                                 

   ‚úÖ WER: 0.2570

‚öôÔ∏è Testing: cfg_31


                                                                 

   ‚úÖ WER: 0.3635

‚öôÔ∏è Testing: cfg_32


                                                                 

   ‚úÖ WER: 0.3604

‚öôÔ∏è Testing: cfg_33


                                                                 

   ‚úÖ WER: 0.3587

‚öôÔ∏è Testing: cfg_34


                                                                 

   ‚úÖ WER: 0.3034

‚öôÔ∏è Testing: cfg_35


                                                                 

   ‚úÖ WER: 0.3020

‚öôÔ∏è Testing: cfg_36


                                                                 

   ‚úÖ WER: 0.2998

‚öôÔ∏è Testing: cfg_37


                                                                 

   ‚úÖ WER: 0.2661

‚öôÔ∏è Testing: cfg_38


                                                                 

   ‚úÖ WER: 0.2541

‚öôÔ∏è Testing: cfg_39


                                                                 

   ‚úÖ WER: 0.2551

‚öôÔ∏è Testing: cfg_40


                                                                 

   ‚úÖ WER: 0.3592

‚öôÔ∏è Testing: cfg_41


                                                                 

   ‚úÖ WER: 0.3569

‚öôÔ∏è Testing: cfg_42


                                                                 

   ‚úÖ WER: 0.3569

‚öôÔ∏è Testing: cfg_43


                                                                 

   ‚úÖ WER: 0.2984

‚öôÔ∏è Testing: cfg_44


                                                                 

   ‚úÖ WER: 0.2999

‚öôÔ∏è Testing: cfg_45


                                                                 

   ‚úÖ WER: 0.2989

‚öôÔ∏è Testing: cfg_46


                                                                 

   ‚úÖ WER: 0.2652

‚öôÔ∏è Testing: cfg_47


                                                                 

   ‚úÖ WER: 0.2540

‚öôÔ∏è Testing: cfg_48


                                                                 

   ‚úÖ WER: 0.2539

‚öôÔ∏è Testing: cfg_49


                                                                 

   ‚úÖ WER: 0.3585

‚öôÔ∏è Testing: cfg_50


                                                                 

   ‚úÖ WER: 0.3562

‚öôÔ∏è Testing: cfg_51


                                                                 

   ‚úÖ WER: 0.3562

‚öôÔ∏è Testing: cfg_52


                                                                 

   ‚úÖ WER: 0.2981

‚öôÔ∏è Testing: cfg_53


                                                                 

   ‚úÖ WER: 0.2997

‚öôÔ∏è Testing: cfg_54


                                                                 

   ‚úÖ WER: 0.2963

‚öôÔ∏è Testing: cfg_55


                                                                 

   ‚úÖ WER: 0.2661

‚öôÔ∏è Testing: cfg_56


                                                                 

   ‚úÖ WER: 0.2541

‚öôÔ∏è Testing: cfg_57


                                                                 

   ‚úÖ WER: 0.2551

‚öôÔ∏è Testing: cfg_58


                                                                 

   ‚úÖ WER: 0.3592

‚öôÔ∏è Testing: cfg_59


                                                                 

   ‚úÖ WER: 0.3569

‚öôÔ∏è Testing: cfg_60


                                                                 

   ‚úÖ WER: 0.3569

‚öôÔ∏è Testing: cfg_61


                                                                 

   ‚úÖ WER: 0.2984

‚öôÔ∏è Testing: cfg_62


                                                                 

   ‚úÖ WER: 0.2999

‚öôÔ∏è Testing: cfg_63


                                                                 

   ‚úÖ WER: 0.2989

‚öôÔ∏è Testing: cfg_64


                                                                 

   ‚úÖ WER: 0.2663

‚öôÔ∏è Testing: cfg_65


                                                                 

   ‚úÖ WER: 0.2544

‚öôÔ∏è Testing: cfg_66


                                                                 

   ‚úÖ WER: 0.2556

‚öôÔ∏è Testing: cfg_67


                                                                 

   ‚úÖ WER: 0.3604

‚öôÔ∏è Testing: cfg_68


                                                                 

   ‚úÖ WER: 0.3569

‚öôÔ∏è Testing: cfg_69


                                                                 

   ‚úÖ WER: 0.3570

‚öôÔ∏è Testing: cfg_70


                                                                 

   ‚úÖ WER: 0.2993

‚öôÔ∏è Testing: cfg_71


                                                                 

   ‚úÖ WER: 0.3014

‚öôÔ∏è Testing: cfg_72


                                                                 

   ‚úÖ WER: 0.2997

üìä DECODER SWEEP RESULTS (sorted by WER)
config_name    wer  num_beams  length_penalty  no_repeat_ngram_size  repetition_penalty  do_sample  max_new_tokens
     cfg_48 0.2539          3             0.9                   NaN                1.10      False             225
     cfg_47 0.2540          3             0.9                   NaN                1.05      False             225
     cfg_38 0.2541          3             NaN                   NaN                1.05      False             225
     cfg_56 0.2541          3             1.0                   NaN                1.05      False             225
     cfg_65 0.2544          3             1.1                   NaN                1.05      False             225
     cfg_57 0.2551          3             1.0                   NaN                1.10      False             225
     cfg_39 0.2551          3             NaN                   NaN                1.10      False             225
     cfg_12 0.255



In [None]:
# ================================
# SAVE RESULTS
# ================================

print("\\n" + "="*50)
print("üíæ SAVING RESULTS")
print("="*50)

overall_end_time = time.time()
overall_duration_secs = int(overall_end_time - overall_start_time)
overall_duration_hms = str(datetime.timedelta(seconds=overall_duration_secs))

# Create output directory
CSV_OUTPUT_DIR = f"./experiments/{EXPERIMENT_NAME}"
os.makedirs(CSV_OUTPUT_DIR, exist_ok=True)
print(f"üìÅ Created directory: {CSV_OUTPUT_DIR}")

# Save original post-training predictions
samplewise_data = []
for i in range(len(post_results["predictions"])):
    samplewise_data.append({
        "reference_raw": post_results["references_raw"][i],
        "reference_normalized": post_results["references"][i],
        "prediction_raw": post_results["predictions_raw"][i],
        "prediction_normalized": post_results["predictions"][i],
        "wer": round(post_results["sample_wers"][i], 4)
    })

NEW_PREDICTIONS_CSV = f"{CSV_OUTPUT_DIR}/{PREDICTIONS_CSV}"
pd.DataFrame(samplewise_data).to_csv(NEW_PREDICTIONS_CSV, index=False)
print(f"üìÑ Saved predictions: {NEW_PREDICTIONS_CSV}")

# Save optimized predictions (from decoder sweep)
if 'decoder_sweep_data' in locals():
    optimized_predictions_csv = f"{CSV_OUTPUT_DIR}/{EXPERIMENT_NAME}_predictions_optimized.csv"
    optimized_data = [{
        "reference_raw": decoder_sweep_data["optimized_results"]["references_raw"][i],
        "reference_normalized": decoder_sweep_data["optimized_results"]["references"][i],
        "prediction_raw": decoder_sweep_data["optimized_results"]["predictions_raw"][i],
        "prediction_normalized": decoder_sweep_data["optimized_results"]["predictions"][i],
        "wer": round(decoder_sweep_data["optimized_results"]["sample_wers"][i], 4)
    } for i in range(len(decoder_sweep_data["optimized_results"]["predictions"]))]
    
    pd.DataFrame(optimized_data).to_csv(optimized_predictions_csv, index=False)
    print(f"üìÑ Saved optimized predictions: {optimized_predictions_csv}")
    
    # Save decoder sweep results
    SWEEP_CSV = f"{CSV_OUTPUT_DIR}/decoder_sweep_results.csv"
    decoder_sweep_data["decoder_sweep_df"].to_csv(SWEEP_CSV, index=False)
    print(f"üìÑ Saved decoder sweep results: {SWEEP_CSV}")

# Save run summary
summary_data = {
    "experiment_name": EXPERIMENT_NAME,
    "base_model": BASE_MODEL_NAME,
    "lora_r": LORA_R,
    "lora_alpha": LORA_ALPHA,
    "lora_dropout": LORA_DROPOUT,
    "target_modules": str(LORA_TARGET_MODULES),
    "learning_rate": LEARNING_RATE,
    "batch_size": BATCH_SIZE,
    "num_train_epochs": NUM_EPOCHS,
    "train_1": train_1 if train_1 is not None else "-",
    "train_2": train_2 if train_2 is not None else "-",
    "train_3": train_3 if train_3 is not None else "-",
    "test_1":  test_1 if test_1 is not None else "-",
    "test_2":  test_2 if test_2 is not None else "-",
    "test_3":  test_3 if test_3 is not None else "-",
    "train_num_samples_cap": TRAIN_NUM_SAMPLES if TRAIN_NUM_SAMPLES else "full",
    "test_num_samples_cap": TEST_NUM_SAMPLES if TEST_NUM_SAMPLES else "full",
    "eval_from_train_pct": EVAL_FROM_TRAIN_PCT,
    "train_set_size": len(train_ds),
    "validation_set_size": len(validation_ds) if validation_ds else 0,
    "test_set_size": len(test_ds),
    "total_start_time": datetime.datetime.fromtimestamp(overall_start_time).strftime("%Y-%m-%d %H:%M:%S"),
    "total_end_time": datetime.datetime.fromtimestamp(overall_end_time).strftime("%Y-%m-%d %H:%M:%S"),
    "total_duration": overall_duration_hms,
    "train_start_time": datetime.datetime.fromtimestamp(train_start_time).strftime("%Y-%m-%d %H:%M:%S"),
    "train_end_time": datetime.datetime.fromtimestamp(train_end_time).strftime("%Y-%m-%d %H:%M:%S"),
    "train_duration": train_duration_hms,
    "fp16_enabled": FP16,
    "pre_training_wer": round(pre_training_wer, 4),
    "post_training_wer": round(post_training_wer, 4),
    "wer_improvement": round(wer_improvement, 4),
    "wer_improvement_percent": round(wer_improvement_pct, 2) if BASE_MODEL_NAME == "openai/whisper-large-v2-turbo" or BASE_MODEL_NAME == "openai/whisper-large-v3-turbo" else '-',
    "wer_improvement_large": round(wer_improvement_pct, 2) if BASE_MODEL_NAME == "openai/whisper-large-v2" or BASE_MODEL_NAME == "openai/whisper-large-v3" else '-'
}

# Add validation WERs
for i, wer in enumerate(validation_wers, start=1):
    summary_data[f"validation_{i}"] = wer

# Add decoder sweep data if available
if 'decoder_sweep_data' in locals():
    summary_data.update({
        "best_decoder_config": decoder_sweep_data["best_decoder_config"],
        "best_decoder_params": str(decoder_sweep_data["best_decoder_config_params"]),
        "decoder_optimized_wer": decoder_sweep_data["decoder_sweep_wer"],
        "decoder_improvement": decoder_sweep_data["decoder_improvement"],
        "decoder_improvement_percent": decoder_sweep_data["decoder_improvement_percent"]
    })

NEW_SUMMARY_CSV = f"{CSV_OUTPUT_DIR}/{SUMMARY_CSV}"
pd.DataFrame([summary_data]).to_csv(NEW_SUMMARY_CSV, index=False)
print(f"üìÑ Saved summary: {NEW_SUMMARY_CSV}")


üíæ SAVING RESULTS
üìÅ Created directory: ./experiments/finetuning-29
üìÑ Saved predictions: ./experiments/finetuning-29/finetuning-29_predictions.csv
üìÑ Saved optimized predictions: ./experiments/finetuning-29/finetuning-29_predictions_optimized.csv
üìÑ Saved decoder sweep results: ./experiments/finetuning-29/decoder_sweep_results.csv
üìÑ Saved summary: ./experiments/finetuning-29/finetuning-29_summary.csv


In [None]:
# ================================
# FINAL SUMMARY
# ================================

print("\\n" + "="*50)
print("üéâ EXPERIMENT COMPLETE")
print("="*50)
print(f"Total duration: {overall_duration_hms}")
print(f"\\nüìä Results:")
print(f"   Pre-training WER:  {pre_training_wer:.4f} ({pre_training_wer*100:.2f}%)")
print(f"   Post-training WER: {post_training_wer:.4f} ({post_training_wer*100:.2f}%)")
print(f"   Fine-tuning Improvement: {wer_improvement:.4f} ({wer_improvement_pct:.2f}%)")

# Add decoder sweep results if available
if 'decoder_sweep_data' in locals():
    print(f"\\nüî¨ Decoder Optimization:")
    print(f"   Best Config: {decoder_sweep_data['best_decoder_config']}")
    print(f"   Optimized WER: {decoder_sweep_data['decoder_sweep_wer']:.4f} ({decoder_sweep_data['decoder_sweep_wer']*100:.2f}%)")
    print(f"   Decoder Improvement: {decoder_sweep_data['decoder_improvement']:.4f} ({decoder_sweep_data['decoder_improvement_percent']:.2f}%)")
    
    # Total improvement from baseline
    total_improvement = pre_training_wer - decoder_sweep_data['decoder_sweep_wer']
    total_improvement_pct = (total_improvement / pre_training_wer) * 100
    print(f"\\nüöÄ TOTAL IMPROVEMENT (Fine-tuning + Decoder):")
    print(f"   Baseline WER: {pre_training_wer:.4f}")
    print(f"   Final WER: {decoder_sweep_data['decoder_sweep_wer']:.4f}")
    print(f"   Total Improvement: {total_improvement:.4f} ({total_improvement_pct:.2f}%)")

print(f"\\nüìÅ Output files:")
print(f"   - {PREDICTIONS_CSV}")
if 'decoder_sweep_data' in locals():
    print(f"   - {EXPERIMENT_NAME}_predictions_optimized.csv")
    print(f"   - decoder_sweep_results.csv")
print(f"   - {SUMMARY_CSV}")
print("="*50)


üéâ EXPERIMENT COMPLETE
Total duration: 1 day, 1:45:46
\nüìä Results:
   Pre-training WER:  0.3101 (31.01%)
   Post-training WER: 0.3085 (30.85%)
   Fine-tuning Improvement: 0.0016 (0.50%)
\nüî¨ Decoder Optimization:
   Best Config: cfg_48
   Optimized WER: 0.2539 (25.39%)
   Decoder Improvement: 0.0546 (17.70%)
\nüöÄ TOTAL IMPROVEMENT (Fine-tuning + Decoder):
   Baseline WER: 0.3101
   Final WER: 0.2539
   Total Improvement: 0.0562 (18.12%)
\nüìÅ Output files:
   - finetuning-29_predictions.csv
   - finetuning-29_predictions_optimized.csv
   - decoder_sweep_results.csv
   - finetuning-29_summary.csv


In [None]:
# ================================
# SAVE FINE-TUNED MODEL
# ================================

print("\\n" + "="*50)
print("üíæ SAVING FINE-TUNED MODEL")
print("="*50)

# Define output directory
OUTPUT_DIR = f"./saved_models/{EXPERIMENT_NAME}"
LORA_ADAPTER_DIR = f"{OUTPUT_DIR}/lora_adapter"
MERGED_MODEL_DIR = f"{OUTPUT_DIR}/merged_model"

# Create directories
os.makedirs(LORA_ADAPTER_DIR, exist_ok=True)
print(f"üìÅ Created directory: {LORA_ADAPTER_DIR}")

# 1. Save LoRA adapter weights (lightweight, recommended)
print("\\nüîß Saving LoRA adapter weights...")
model.save_pretrained(LORA_ADAPTER_DIR)
processor.save_pretrained(LORA_ADAPTER_DIR)
print(f"‚úÖ LoRA adapter saved to: {LORA_ADAPTER_DIR}")

# 2. Save configuration info
config_info = {
    "base_model": BASE_MODEL_NAME,
    "lora_r": LORA_R,
    "lora_alpha": LORA_ALPHA,
    "lora_dropout": LORA_DROPOUT,
    "target_modules": LORA_TARGET_MODULES,
    "training_epochs": NUM_EPOCHS,
    "learning_rate": LEARNING_RATE,
    "batch_size": BATCH_SIZE,
    "train_datasets": [train_1, train_2, train_3],
    "test_datasets": [test_1, test_2, test_3],
    "train_samples": len(train_ds),
    "test_samples": len(test_ds),
    "pre_training_wer": round(pre_training_wer, 4),
    "post_training_wer": round(post_training_wer, 4),
    "wer_improvement": round(wer_improvement, 4)
}

# Add decoder sweep info if available
if 'decoder_sweep_data' in locals():
    # Convert numpy/pandas types to native Python types for JSON serialization
    best_decoder_params_clean = {}
    for k, v in decoder_sweep_data["best_decoder_config_params"].items():
        if isinstance(v, (np.integer, np.int64)):
            best_decoder_params_clean[k] = int(v)
        elif isinstance(v, (np.floating, np.float64)):
            best_decoder_params_clean[k] = float(v)
        elif isinstance(v, np.bool_):
            best_decoder_params_clean[k] = bool(v)
        else:
            best_decoder_params_clean[k] = v
    
    config_info.update({
        "decoder_optimization": True,
        "best_decoder_config": decoder_sweep_data["best_decoder_config"],
        "best_decoder_params": best_decoder_params_clean,  # Use cleaned version
        "optimized_wer": decoder_sweep_data["decoder_sweep_wer"],
        "total_wer_improvement": round(pre_training_wer - decoder_sweep_data["decoder_sweep_wer"], 4)
    })
    
    # Save best generation config for inference (with cleaned types)
    generation_config_path = f"{LORA_ADAPTER_DIR}/best_generation_config.json"
    with open(generation_config_path, "w") as f:
        json.dump(best_decoder_params_clean, f, indent=2)  # Use cleaned version
    print(f"‚úÖ Best generation config saved to: {generation_config_path}")
else:
    config_info["decoder_optimization"] = False

with open(f"{LORA_ADAPTER_DIR}/training_config.json", "w") as f:
    json.dump(config_info, f, indent=2)
print(f"‚úÖ Training config saved to: {LORA_ADAPTER_DIR}/training_config.json")

print("\\n" + "="*50)
print("üéâ MODEL SAVING COMPLETE")
print("="*50)
print(f"\\nüì¶ Saved files:")
print(f"   LoRA Adapter: {LORA_ADAPTER_DIR}")
print(f"   - adapter_model.safetensors (LoRA weights)")
print(f"   - adapter_config.json (LoRA configuration)")
print(f"   - preprocessor_config.json & tokenizer files")
print(f"   - training_config.json (your training settings)")
if 'decoder_sweep_data' in locals():
    print(f"   - best_generation_config.json (optimized decoder params)")

print(f"\\nüìù To load the model later, use:")
print(f"""

# Load base model
base_model = WhisperForConditionalGeneration.from_pretrained("{BASE_MODEL_NAME}")

# Load LoRA adapter
model = PeftModel.from_pretrained(base_model, "{LORA_ADAPTER_DIR}")

# Load processor
processor = WhisperProcessor.from_pretrained("{LORA_ADAPTER_DIR}")
""")

if 'decoder_sweep_data' in locals():
    print(f"""
# Load best generation config
with open("{LORA_ADAPTER_DIR}/best_generation_config.json", "r") as f:
    best_gen_config = json.load(f)

# Use for inference
pred_ids = model.generate(input_features=input_features, **best_gen_config)
""")


üíæ SAVING FINE-TUNED MODEL
üìÅ Created directory: ./saved_models/finetuning-29/lora_adapter
\nüîß Saving LoRA adapter weights...
‚úÖ LoRA adapter saved to: ./saved_models/finetuning-29/lora_adapter
‚úÖ Best generation config saved to: ./saved_models/finetuning-29/lora_adapter/best_generation_config.json
‚úÖ Training config saved to: ./saved_models/finetuning-29/lora_adapter/training_config.json
üéâ MODEL SAVING COMPLETE
\nüì¶ Saved files:
   LoRA Adapter: ./saved_models/finetuning-29/lora_adapter
   - adapter_model.safetensors (LoRA weights)
   - adapter_config.json (LoRA configuration)
   - preprocessor_config.json & tokenizer files
   - training_config.json (your training settings)
   - best_generation_config.json (optimized decoder params)
\nüìù To load the model later, use:


# Load base model
base_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3-turbo")

# Load LoRA adapter
model = PeftModel.from_pretrained(base_model, "./saved_models/finetun