In [None]:
import pandas as pd

In [None]:
# Load the cleaned recordings table from an absolute path (later cells read the
# same file name relative to the working directory).
df=pd.read_csv("/content/josh_data_cleaned.csv")

In [None]:
df

Unnamed: 0,user_id,recording_id,language,duration,rec_url_gcp,transcription_url_gcp,metadata_url_gcp
0,245746,825780,hi,443,https://storage.googleapis.com/upload_goai/967...,https://storage.googleapis.com/upload_goai/967...,https://storage.googleapis.com/upload_goai/967...
1,291038,825727,hi,443,https://storage.googleapis.com/upload_goai/967...,https://storage.googleapis.com/upload_goai/967...,https://storage.googleapis.com/upload_goai/967...
2,246004,988596,hi,475,https://storage.googleapis.com/upload_goai/114...,https://storage.googleapis.com/upload_goai/114...,https://storage.googleapis.com/upload_goai/114...
3,93626,990175,hi,475,https://storage.googleapis.com/upload_goai/114...,https://storage.googleapis.com/upload_goai/114...,https://storage.googleapis.com/upload_goai/114...
4,286851,526266,hi,522,https://storage.googleapis.com/upload_goai/639...,https://storage.googleapis.com/upload_goai/639...,https://storage.googleapis.com/upload_goai/639...
...,...,...,...,...,...,...,...
99,278010,753435,hi,589,https://storage.googleapis.com/upload_goai/887...,https://storage.googleapis.com/upload_goai/887...,https://storage.googleapis.com/upload_goai/887...
100,413240,1021370,hi,1194,https://storage.googleapis.com/upload_goai/118...,https://storage.googleapis.com/upload_goai/118...,https://storage.googleapis.com/upload_goai/118...
101,11057,1020918,hi,1194,https://storage.googleapis.com/upload_goai/118...,https://storage.googleapis.com/upload_goai/118...,https://storage.googleapis.com/upload_goai/118...
102,93299,840793,hi,1146,https://storage.googleapis.com/upload_goai/983...,https://storage.googleapis.com/upload_goai/983...,https://storage.googleapis.com/upload_goai/983...


In [None]:
# Run once
!pip install -q soundfile librosa requests tqdm pandas scikit-learn

In [None]:
import os, time, math, csv
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
import requests
from tqdm import tqdm
import pandas as pd

# === USER CONFIGURE ===
# Input: the cleaned CSV. Edit this if your file has a different name.
CSV_PATH = "josh_data_cleaned.csv"

# Output layout: every downloaded file and the report live beneath BASE_DIR.
BASE_DIR = Path("data_josh")
AUDIO_DIR = BASE_DIR.joinpath("audio_orig")
TRANS_JSON_DIR = BASE_DIR.joinpath("transcript_json")
META_DIR = BASE_DIR.joinpath("metadata")
REPORT_CSV = BASE_DIR.joinpath("download_report.csv")

# Download behaviour.
MAX_WORKERS = 8        # size of the download thread pool
TIMEOUT = 60           # seconds
MAX_RETRIES = 4        # attempts per URL
RETRY_BACKOFF = 2.0    # base of the exponential wait between attempts

# Make sure each output directory exists before anything is downloaded.
for out_dir in [AUDIO_DIR, TRANS_JSON_DIR, META_DIR]:
    out_dir.mkdir(parents=True, exist_ok=True)

# Fail fast when the input CSV is absent.
csv_file = Path(CSV_PATH)
if not csv_file.exists():
    raise FileNotFoundError(f"CSV not found: {CSV_PATH}. Put cleaned CSV in the working dir or change CSV_PATH.")

# Read every column as text; empty cells become "" instead of NaN.
df = pd.read_csv(CSV_PATH, dtype=str).fillna("")

# Columns the download step cannot work without.
required_cols = {"recording_id", "rec_url_gcp", "transcription_url_gcp"}
present_cols = df.columns
missing = [col for col in required_cols if col not in present_cols]
if len(missing) > 0:
    raise ValueError(f"CSV missing required columns: {missing}. Expected at least {required_cols}")

# helper: safe filename for each URL type
def dst_for(rec_id, url, kind):
    """Return the local destination path for one downloadable file.

    `kind` selects both directory and extension:
      "audio"      -> AUDIO_DIR / "<rec_id>.wav"
      "transcript" -> TRANS_JSON_DIR / "<rec_id>.json"
      "meta"       -> META_DIR / "<rec_id>.json"
    `url` is accepted for call-site symmetry but is not used.

    Raises:
        ValueError: carrying `kind`, when `kind` is none of the three above.
    """
    # Reject unknown kinds up front so the branches below are exhaustive.
    if kind not in ("audio", "transcript", "meta"):
        raise ValueError(kind)
    if kind == "audio":
        folder, extension = AUDIO_DIR, "wav"
    elif kind == "transcript":
        folder, extension = TRANS_JSON_DIR, "json"
    else:
        folder, extension = META_DIR, "json"
    return folder / f"{rec_id}.{extension}"

# download worker
def download_with_retries(url, dst_path, max_retries=MAX_RETRIES, timeout=TIMEOUT):
    """Stream `url` into `dst_path`, retrying with exponential backoff.

    The body is written to a sibling "<name>.part" file and renamed onto
    `dst_path` only after the whole response has been written, so an
    interrupted download never leaves a truncated file under the final name.

    Args:
        url: address to fetch with requests.get.
        dst_path: final location (str or Path); parent directories are created.
        max_retries: maximum number of attempts; values <= 0 make no attempt.
        timeout: timeout in seconds handed to requests.get.

    Returns:
        (ok, status) where status is "exists" when dst_path was already there,
        "ok" after a successful download, or "failed:<last error>" once every
        attempt has failed.
    """
    dst_path = Path(dst_path)
    if dst_path.exists():
        return True, "exists"

    # Append ".part" to the full file name instead of swapping the suffix, so
    # two targets in one directory that differ only by extension never share a
    # temporary file.
    tmp = dst_path.with_name(dst_path.name + ".part")
    # Pre-seeded so the failure message is defined even when no attempt runs.
    last_err = "no attempts made"

    for attempt in range(1, max_retries + 1):
        try:
            # stream to file
            with requests.get(url, stream=True, timeout=timeout) as r:
                r.raise_for_status()
                dst_path.parent.mkdir(parents=True, exist_ok=True)
                with open(tmp, "wb") as fh:
                    for chunk in r.iter_content(chunk_size=1 << 20):
                        if chunk:
                            fh.write(chunk)
                tmp.replace(dst_path)
            return True, "ok"
        except Exception as e:
            last_err = str(e)
            # Best effort: drop the partial file so failures leave no debris.
            try:
                if tmp.exists():
                    tmp.unlink()
            except OSError:
                pass
            # Back off only when another attempt follows; waiting after the
            # final failure just delays the error report.
            if attempt < max_retries:
                time.sleep(RETRY_BACKOFF ** attempt)
    return False, f"failed:{last_err}"

# Build the work list: one 4-tuple (recording_id, kind, url, destination) per file.
tasks = []
for _, row in df.iterrows():
    rec = str(row["recording_id"]).strip()
    if len(rec) == 0:
        continue  # no recording id means there is nothing to name the files after
    # Audio and transcript are always queued, even when their URL is blank.
    for kind, column in (("audio", "rec_url_gcp"), ("transcript", "transcription_url_gcp")):
        link = str(row.get(column, "")).strip()
        tasks.append((rec, kind, link, dst_for(rec, link, kind)))
    # Metadata is optional: queue it only when a URL is present.
    meta_link = str(row.get("metadata_url_gcp", "")).strip()
    if meta_link:
        tasks.append((rec, "metadata", meta_link, dst_for(rec, meta_link, "meta")))

print(f"Prepared {len(tasks)} download tasks for {len(df)} recordings.")

# Run the downloads on a thread pool, collecting one result row per task.
results = []  # rows of (rec, kind, url, dst, ok, status)
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
    future_to_task = {}
    for task in tasks:
        rec, kind, url, dst = task
        if url:
            future_to_task[ex.submit(download_with_retries, url, dst)] = task
        else:
            # Nothing to fetch: record the failure without scheduling any work.
            results.append((rec, kind, url, str(dst), False, "no_url"))
    # The bar counts scheduled downloads only, not the "no_url" rows above.
    with tqdm(total=len(future_to_task), desc="Downloading") as pbar:
        for fut in as_completed(future_to_task):
            rec, kind, url, dst = future_to_task[fut]
            try:
                ok, status = fut.result()
            except Exception as e:
                ok, status = False, f"error:{str(e)}"
            results.append((rec, kind, url, str(dst), ok, status))
            pbar.update(1)

# Persist one CSV row per attempted task.
BASE_DIR.mkdir(parents=True, exist_ok=True)
with open(REPORT_CSV, "w", newline="", encoding="utf-8") as fh:
    writer = csv.writer(fh)
    writer.writerow(["recording_id","kind","url","dst","ok","status"])
    writer.writerows(results)

# Tally outcomes; the "ok"/"failed" keys appear only once a task lands in them.
from collections import defaultdict
summary = defaultdict(int)
for result in results:
    summary["total"] += 1
    summary["ok" if result[4] else "failed"] += 1

print("Download summary:", dict(summary))
print("Report saved to:", REPORT_CSV)

# Show a handful of failures, if there were any (index 4 is the ok flag).
failed_examples = [result for result in results if not result[4]]
if not failed_examples:
    print("All downloads completed or files already existed.")
else:
    print("Sample failures (first 10):")
    for row in failed_examples[:10]:
        print(row)


Prepared 312 download tasks for 104 recordings.


Downloading: 100%|██████████| 312/312 [03:16<00:00,  1.59it/s]

Download summary: {'total': 312, 'ok': 312}
Report saved to: data_josh/download_report.csv
All downloads completed or files already existed.





## Train/validation split (speaker-level)

In [None]:
# Cell 1 — speaker-level split (train 80% / validation 20%)
# Produces: splits/recording_split_map.csv, splits/train_recordings.txt, splits/validation_recordings.txt

import pandas as pd
from pathlib import Path
from sklearn.model_selection import GroupShuffleSplit

CSV_PATH = "josh_data_cleaned.csv"   # change if needed
OUTDIR = Path("splits")
OUTDIR.mkdir(parents=True, exist_ok=True)

# Everything is read as text; blanks become "".
df = pd.read_csv(CSV_PATH, dtype=str).fillna("")

# Pick the speaker column: prefer "user_id", otherwise an existing "speaker_id".
speaker_source = next((name for name in ("user_id", "speaker_id") if name in df.columns), None)
if speaker_source is None:
    raise ValueError("CSV must contain 'user_id' or 'speaker_id' column.")
df["speaker_id"] = df[speaker_source].astype(str)
df["recording_id"] = df["recording_id"].astype(str)

# speaker -> that speaker's distinct recording ids, plus a one-row-per-speaker table
speaker_to_recs = df.groupby("speaker_id")["recording_id"].unique().to_dict()
spk_df = pd.DataFrame(
    [{"speaker_id": spk, "n_recs": len(recs)} for spk, recs in speaker_to_recs.items()]
)

# spk_df has one row per speaker, so splitting its rows splits speakers.
# random_state is fixed so the split is repeatable.
gss = GroupShuffleSplit(n_splits=1, test_size=0.20, random_state=42)
train_idx, val_idx = next(gss.split(spk_df, groups=spk_df["speaker_id"].values))
train_speakers = spk_df.loc[train_idx, "speaker_id"].tolist()
val_speakers = spk_df.loc[val_idx, "speaker_id"].tolist()

# Label every recording with its speaker's split (anything not in train is
# validation); when a recording id repeats, the first row wins.
rows = [
    {
        "recording_id": rec,
        "speaker_id": spk,
        "split": "train" if spk in train_speakers else "validation",
    }
    for rec, spk in zip(df["recording_id"], df["speaker_id"])
]
rec_split_df = pd.DataFrame(rows).drop_duplicates(subset=["recording_id"])

# Save the full map plus one newline-separated id list per split.
rec_split_df.to_csv(OUTDIR / "recording_split_map.csv", index=False)
for split_name, list_name in (("train", "train_recordings.txt"), ("validation", "validation_recordings.txt")):
    in_split = rec_split_df["split"] == split_name
    ids = rec_split_df.loc[in_split, "recording_id"].unique().tolist()
    (OUTDIR / list_name).write_text("\n".join(ids))

print("Saved recording_split_map.csv and train/validation recording lists to", OUTDIR.resolve())
print("Speakers:", len(speaker_to_recs), "Train speakers:", len(train_speakers), "Validation speakers:", len(val_speakers))


Saved recording_split_map.csv and train/validation recording lists to /content/splits
Speakers: 102 Train speakers: 81 Validation speakers: 21


## Check for repeated speakers and duplicate recordings

In [None]:
import pandas as pd

# Load your cleaned CSV (change filename if needed)
df = pd.read_csv("josh_data_cleaned.csv")

# Row count versus number of distinct speakers.
print("Total rows:", len(df))
unique_users = df["user_id"].nunique()
print("Unique user_id values:", unique_users)

# Rows per user_id, then only the users that occur more than once.
user_counts = df["user_id"].value_counts()
multi_record_users = user_counts.loc[user_counts.gt(1)]

if not multi_record_users.empty:
    print("\n⚠ Some user_id values appear multiple times:")
    print(multi_record_users.head(20))
    print(f"\nTotal users with multiple recordings: {len(multi_record_users)}")
else:
    print("\n✔ Every user_id appears exactly once (unique speakers).")


Total rows: 104
Unique user_id values: 102

⚠ Some user_id values appear multiple times:
user_id
93626     2
147403    2
Name: count, dtype: int64

Total users with multiple recordings: 2


In [None]:
# Compare as text. The preceding read_csv call passes no dtype, so user_id is
# parsed as integers and a direct `== '93626'` is False for every row, including
# the rows that do belong to this user.
df['user_id'].astype(str) == '93626'

Unnamed: 0,user_id
0,False
1,False
2,False
3,False
4,False
...,...
99,False
100,False
101,False
102,False


In [None]:
import pandas as pd

df = pd.read_csv("josh_data_cleaned.csv")

# Keep every row whose user_id occurs more than once. keep=False flags all
# occurrences (not only the second and later ones), and deriving the ids from
# the data avoids hand-copying them from an earlier cell's output.
dupes = df[df["user_id"].duplicated(keep=False)]

dupes


Unnamed: 0,user_id,recording_id,language,duration,rec_url_gcp,transcription_url_gcp,metadata_url_gcp
3,93626,990175,hi,475,https://storage.googleapis.com/upload_goai/114...,https://storage.googleapis.com/upload_goai/114...,https://storage.googleapis.com/upload_goai/114...
21,147403,270150,hi,584,https://storage.googleapis.com/upload_goai/378...,https://storage.googleapis.com/upload_goai/378...,https://storage.googleapis.com/upload_goai/378...
25,147403,255381,hi,780,https://storage.googleapis.com/upload_goai/362...,https://storage.googleapis.com/upload_goai/362...,https://storage.googleapis.com/upload_goai/362...
73,93626,238079,hi,478,https://storage.googleapis.com/upload_goai/344...,https://storage.googleapis.com/upload_goai/344...,https://storage.googleapis.com/upload_goai/344...


In [None]:
# Check if any duplicated rows exist based on URL columns
url_cols = ["rec_url_gcp", "transcription_url_gcp", "metadata_url_gcp"]

# For each repeated user, test column by column whether all rows share one URL.
for uid, group in dupes.groupby("user_id"):
    print("\n=== USER:", uid, "===")
    print("Number of rows:", len(group))

    # One flag per entry of url_cols, in the same order.
    same_rec, same_trans, same_meta = (group[col].nunique() == 1 for col in url_cols)

    print("Same audio URL?        ", same_rec)
    print("Same transcription URL?", same_trans)
    print("Same metadata URL?     ", same_meta)

    # Only rows that agree on all three URLs count as exact duplicates.
    all_identical = same_rec and same_trans and same_meta
    verdict = "➡ These are EXACT DUPLICATES." if all_identical else "➡ These are DIFFERENT recordings."
    print(verdict)



=== USER: 93626 ===
Number of rows: 2
Same audio URL?         False
Same transcription URL? False
Same metadata URL?      False
➡ These are DIFFERENT recordings.

=== USER: 147403 ===
Number of rows: 2
Same audio URL?         False
Same transcription URL? False
Same metadata URL?      False
➡ These are DIFFERENT recordings.


## Sanity check

In [None]:
# Cell A — sanity: list counts and sample filenames
from pathlib import Path
base = Path("/content/drive/MyDrive/josh_talks/data_josh")
audio_dir = base / "audio_orig"
json_dir  = base / "transcript_json"
splits_path = Path("/content/drive/MyDrive/josh_talks/splits/recording_split_map.csv")

# Report which of the expected inputs are present.
print("Base path:", base)
for label, location in (
    ("Audio dir exists:", audio_dir),
    ("JSON dir exists:", json_dir),
    ("Splits file exists:", splits_path),
):
    print(label, location.exists())

# Audio: count every entry in the folder and show the first five names.
if not audio_dir.exists():
    print("No audio directory found at", audio_dir)
else:
    auds = sorted(audio_dir.glob("*"))
    print("Audio files:", len(auds))
    print("Some audio files:", [p.name for p in auds[:5]])

# Transcripts: same listing, restricted to *.json.
if not json_dir.exists():
    print("No JSON directory found at", json_dir)
else:
    js = sorted(json_dir.glob("*.json"))
    print("JSON transcript files:", len(js))
    print("Some json files:", [p.name for p in js[:5]])

# Split map: row count and a preview, read as text.
if not splits_path.exists():
    print("Split map not found at", splits_path)
else:
    import pandas as pd
    sdf = pd.read_csv(splits_path, dtype=str)
    print("Splits rows:", len(sdf))
    print("Sample splits:\n", sdf.head())


Base path: /content/drive/MyDrive/josh_talks/data_josh
Audio dir exists: True
JSON dir exists: True
Splits file exists: True
Audio files: 104
Some audio files: ['1020918.wav', '1021370.wav', '238079.wav', '238123.wav', '239492.wav']
JSON transcript files: 104
Some json files: ['1020918.json', '1021370.json', '238079.json', '238123.json', '239492.json']
Splits rows: 104
Sample splits:
   recording_id speaker_id  split
0       825780     245746  train
1       825727     291038  train
2       988596     246004  train
3       990175      93626  train
4       526266     286851  train


## Segmentation + manifest

In [None]:
# Cell B — segmentation: slice by timestamps, save 16k mono segments + .txt and manifest
# Adjust paths if your structure differs.

!pip install -q librosa soundfile tqdm pandas


In [None]:

import json, os, librosa, soundfile as sf
from pathlib import Path
from tqdm import tqdm
import pandas as pd
import numpy as np

# Inputs: original audio and transcript JSON beneath the project folder.
BASE = Path("/content/drive/MyDrive/josh_talks")
DATA_ROOT = BASE.joinpath("data_josh")
AUDIO_ORIG_DIR = DATA_ROOT.joinpath("audio_orig")
JSON_DIR = DATA_ROOT.joinpath("transcript_json")

# Outputs: everything this step produces goes beneath OUTPUT_BASE.
OUTPUT_BASE = BASE.joinpath("processed")
SEGMENTS_DIR = OUTPUT_BASE.joinpath("segments")
TRANS_DIR = OUTPUT_BASE.joinpath("transcripts")
MANIFESTS_DIR = OUTPUT_BASE.joinpath("manifests")
SPLITS_MAP = BASE.joinpath("splits", "recording_split_map.csv")

for out_dir in (SEGMENTS_DIR, TRANS_DIR, MANIFESTS_DIR):
    out_dir.mkdir(parents=True, exist_ok=True)

# The split map is mandatory; it is read as text and indexed by recording id.
if SPLITS_MAP.exists():
    split_df = pd.read_csv(SPLITS_MAP, dtype=str).set_index("recording_id")
else:
    raise FileNotFoundError(f"Split map not found at {SPLITS_MAP}. Please ensure the speaker split CSV is present.")

def write_wav_16k(y, sr, outpath):
    """Write `y` to `outpath` as 16 kHz mono float32 audio.

    Args:
        y: samples as a numpy array, either 1-D or multi-channel in librosa's
           layout with samples on the LAST axis (what
           librosa.load(..., mono=False) returns).
        sr: sample rate of `y` in Hz.
        outpath: destination Path; missing parent directories are created.

    Returns:
        Duration of the written audio in seconds.
    """
    # Down-mix by averaging over the channel axes. Samples sit on the last
    # axis, so the previous axis=1 average collapsed time rather than channels
    # and left one value per channel.
    if y.ndim > 1:
        y = np.mean(y, axis=tuple(range(y.ndim - 1)))
    if sr != 16000:
        y = librosa.resample(y, orig_sr=sr, target_sr=16000)
        sr = 16000
    y = y.astype("float32")
    outpath.parent.mkdir(parents=True, exist_ok=True)
    sf.write(str(outpath), y, sr)
    return len(y) / sr

# Slice every recording into per-utterance 16 kHz mono files, write the matching
# .txt transcript, and collect one manifest row per kept segment.
manifest_rows = []
json_files = sorted(JSON_DIR.glob("*.json"))
print("JSON files discovered:", len(json_files))

for jf in tqdm(json_files):
    rec_id = jf.stem
    # lookup split (default to train when the recording is absent from the map)
    in_split_map = rec_id in split_df.index
    split = split_df.loc[rec_id, "split"] if in_split_map else "train"

    # find audio: prefer .wav, fall back to .mp3
    audio_path = AUDIO_ORIG_DIR / f"{rec_id}.wav"
    if not audio_path.exists():
        audio_path = AUDIO_ORIG_DIR / f"{rec_id}.mp3"
        if not audio_path.exists():
            print(f"WARNING: audio for {rec_id} not found in {AUDIO_ORIG_DIR}; skipping")
            continue

    # Parse the transcript before touching the audio: it is cheap, and an
    # unusable transcript then avoids decoding the recording at all.
    try:
        with open(jf, encoding="utf-8") as fh:  # closed deterministically
            j = json.load(fh)
    except Exception as e:
        print(f"Failed to load JSON {jf}: {e}")
        continue

    # segments might be the top-level list or sit under a well-known key
    segments = None
    if isinstance(j, list):
        segments = j
    elif isinstance(j, dict):
        for k in ("segments", "results", "utterances", "transcripts"):
            if isinstance(j.get(k), list):
                segments = j[k]
                break

    if not segments:
        print(f"No segments found in {jf}; skipping.")
        continue

    # load audio once; mono=True down-mixes at load time so y_full is always
    # 1-D and the index slice below cuts along time. (With mono=False a
    # multi-channel file is (channels, samples) and y_full[a:b] would slice
    # channels instead.)
    try:
        y_full, sr_full = librosa.load(str(audio_path), sr=None, mono=True)
    except Exception as e:
        print(f"Failed to load audio {audio_path}: {e}")
        continue

    seg_idx = 0
    for seg in segments:
        if not isinstance(seg, dict):
            continue
        # Timestamps must be numeric; missing or unparsable values skip the segment.
        try:
            start = float(seg.get("start"))
            end = float(seg.get("end"))
        except (TypeError, ValueError):
            continue
        # `not (end > start)` also rejects NaN timestamps.
        if not (end > start):
            continue
        text = str(seg.get("text") or seg.get("transcript") or "").strip()
        if not text:
            continue
        speaker = seg.get("speaker_id") or seg.get("speaker") or ""

        # slice, clamped to the recording so a slightly negative start or an
        # end past the last sample cannot wrap around or be silently dropped
        s_idx = max(0, int(round(start * sr_full)))
        e_idx = min(len(y_full), int(round(end * sr_full)))
        seg_audio = y_full[s_idx:e_idx]
        if seg_audio.size == 0:
            continue

        seg_id = f"{rec_id}_seg{seg_idx}"
        wav_out = SEGMENTS_DIR / f"{seg_id}.wav"
        txt_out = TRANS_DIR / f"{seg_id}.txt"
        try:
            duration = write_wav_16k(seg_audio, sr_full, wav_out)
        except Exception as e:
            print(f"Failed to write segment {seg_id}: {e}")
            continue
        txt_out.parent.mkdir(parents=True, exist_ok=True)
        with open(txt_out, "w", encoding="utf-8") as fh:
            fh.write(text)
        manifest_rows.append({
            "id": seg_id,
            "audio_filepath": str(wav_out),
            "text": text,
            "duration": duration,
            # per-segment speaker wins; otherwise fall back to the split map
            "speaker_id": speaker if speaker else (split_df.loc[rec_id, "speaker_id"] if in_split_map else ""),
            "recording_id": rec_id,
            "split": split
        })
        # Only kept segments consume an index, so segment ids stay contiguous.
        seg_idx += 1

manifest_df = pd.DataFrame(manifest_rows)
manifest_out = MANIFESTS_DIR / "manifest_segments_split.csv"
manifest_df.to_csv(manifest_out, index=False)
print("Segments created:", len(manifest_df))
print("Manifest saved to:", manifest_out)


JSON files discovered: 104


100%|██████████| 104/104 [05:34<00:00,  3.22s/it]

Segments created: 5941
Manifest saved to: /content/drive/MyDrive/josh_talks/processed/manifests/manifest_segments_split.csv





In [None]:
import pandas as pd
from pathlib import Path
# Quick look at the raw segment manifest: size, split balance, columns, first rows.
M = Path("/content/drive/MyDrive/josh_talks/processed/manifests/manifest_segments_split.csv")
df = pd.read_csv(M)
print("Total segments:", df.shape[0])
print(df["split"].value_counts())
print("\nColumns:", list(df.columns))
print("\nSample rows:")
display(df.iloc[:10])


Total segments: 5941
split
train         4735
validation    1206
Name: count, dtype: int64

Columns: ['id', 'audio_filepath', 'text', 'duration', 'speaker_id', 'recording_id', 'split']

Sample rows:


Unnamed: 0,id,audio_filepath,text,duration,speaker_id,recording_id,split
0,1020918_seg0,/content/drive/MyDrive/josh_talks/processed/se...,जो टाइम था वो गोल्डन टाइम होता है स्कूल के दिन...,14.31,11057,1020918,train
1,1020918_seg1,/content/drive/MyDrive/josh_talks/processed/se...,वहां तो जिम्मेदारियां भी नहीं होती हैं और कम क...,9.39,11057,1020918,train
2,1020918_seg2,/content/drive/MyDrive/josh_talks/processed/se...,की नहीं आपको ये समान लेके आना है कि नहीं आपको ...,14.4,11057,1020918,train
3,1020918_seg3,/content/drive/MyDrive/josh_talks/processed/se...,लेकिन मस्ती भरपूर होती है मस्ती भरपूर,2.43,11057,1020918,train
4,1020918_seg4,/content/drive/MyDrive/josh_talks/processed/se...,जी,1.41,11057,1020918,train
5,1020918_seg5,/content/drive/MyDrive/josh_talks/processed/se...,बोल सकते,5.76,11057,1020918,train
6,1020918_seg6,/content/drive/MyDrive/josh_talks/processed/se...,जी हां,8.82,11057,1020918,train
7,1020918_seg7,/content/drive/MyDrive/josh_talks/processed/se...,हा,6.36,11057,1020918,train
8,1020918_seg8,/content/drive/MyDrive/josh_talks/processed/se...,जी,0.72,11057,1020918,train
9,1020918_seg9,/content/drive/MyDrive/josh_talks/processed/se...,जी सर बिल्कुल बहुत ऐसी तो हमारी शरारतें तो बहु...,14.76,11057,1020918,train


In [None]:
import pandas as pd
from pathlib import Path

M = Path("/content/drive/MyDrive/josh_talks/processed/manifests/manifest_segments_split.csv")
df = pd.read_csv(M)

# Stats
print("Duration stats (sec):")
duration_summary = df["duration"].describe(percentiles=[0.25,0.5,0.75,0.9])
print(duration_summary.to_string())

# Filter: keep segments whose duration lies in [min_dur, max_dur], bounds included.
min_dur = 0.5
max_dur = 30.0
in_range = df["duration"].between(min_dur, max_dur)
keep = df.loc[in_range].copy()
print(f"Kept {len(keep)} / {len(df)} segments after filtering durations [{min_dur},{max_dur}]")

# The cleaned manifest is written next to the raw one.
OUT = M.parent / "manifest_segments_split_cleaned.csv"
keep.to_csv(OUT, index=False)
print("Saved cleaned manifest to:", OUT)


Duration stats (sec):
count    5941.000000
mean        7.476182
std         5.419152
min         0.120000
25%         2.070000
50%         6.750000
75%        13.320000
90%        14.490000
max        15.000000
Kept 5578 / 5941 segments after filtering durations [0.5,30.0]
Saved cleaned manifest to: /content/drive/MyDrive/josh_talks/processed/manifests/manifest_segments_split_cleaned.csv


In [None]:
# install datasets if not present
!pip install -q datasets

from datasets import Dataset, DatasetDict
import pandas as pd
from pathlib import Path

MAN = Path("/content/drive/MyDrive/josh_talks/processed/manifests/manifest_segments_split_cleaned.csv")
df = pd.read_csv(MAN)

train_df = df[df["split"] == "train"].reset_index(drop=True)
val_df   = df[df["split"] == "validation"].reset_index(drop=True)

train_ds = Dataset.from_pandas(train_df)
val_ds   = Dataset.from_pandas(val_df)

ds = DatasetDict({"train": train_ds, "validation": val_ds})
ds.save_to_disk("/content/drive/MyDrive/josh_talks/processed/hf_dataset_segments")
print("Saved HF dataset to /content/drive/MyDrive/josh_talks/processed/hf_dataset_segments")


Saving the dataset (0/1 shards):   0%|          | 0/4445 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1133 [00:00<?, ? examples/s]

Saved HF dataset to /content/drive/MyDrive/josh_talks/processed/hf_dataset_segments


## Tests: verify the processed dataset is correct and consistent

In [None]:
# Cell 1 — verify dataset exists and basic stats
from pathlib import Path
import pandas as pd
from datasets import load_from_disk

HF_DATASET_PATH = "/content/drive/MyDrive/josh_talks/processed/hf_dataset_segments"
MANIFEST_CLEAN = "/content/drive/MyDrive/josh_talks/processed/manifests/manifest_segments_split_cleaned.csv"

print("HF dataset path:", HF_DATASET_PATH)
print("Manifest path:", MANIFEST_CLEAN)

# Saved dataset: available splits and their sizes.
ds = load_from_disk(HF_DATASET_PATH)
print("Dataset keys:", list(ds.keys()))
for label, split_name in (("Train examples:", "train"), ("Validation examples:", "validation")):
    print(label, len(ds[split_name]))

# Cleaned manifest: the row and per-split counts should agree with the dataset above.
m = pd.read_csv(MANIFEST_CLEAN)
print("Manifest rows (cleaned):", len(m))
print("Split counts in manifest:")
print(m["split"].value_counts())

# quick column check (we expect audio_filepath, text, duration, speaker_id, recording_id, split)
print("Manifest columns:", list(m.columns))


HF dataset path: /content/drive/MyDrive/josh_talks/processed/hf_dataset_segments
Manifest path: /content/drive/MyDrive/josh_talks/processed/manifests/manifest_segments_split_cleaned.csv
Dataset keys: ['train', 'validation']
Train examples: 4445
Validation examples: 1133
Manifest rows (cleaned): 5578
Split counts in manifest:
split
train         4445
validation    1133
Name: count, dtype: int64
Manifest columns: ['id', 'audio_filepath', 'text', 'duration', 'speaker_id', 'recording_id', 'split']


In [None]:
# Cell 2 — assert no speaker appears in both splits and recording-split consistency
import pandas as pd
from pathlib import Path

MANIFEST = "/content/drive/MyDrive/josh_talks/processed/manifests/manifest_segments_split_cleaned.csv"
SPLIT_MAP = "/content/drive/MyDrive/josh_talks/splits/recording_split_map.csv"

# Both tables are read as text so ids compare as strings on either side.
m = pd.read_csv(MANIFEST, dtype=str)
split_map = pd.read_csv(SPLIT_MAP, dtype=str).set_index("recording_id")

# 1) every recording in manifest should be in split_map
missing_recs = set(m["recording_id"].unique()).difference(split_map.index)
print("Recordings in manifest not in split_map (should be 0):", len(missing_recs))
if missing_recs:
    print(list(missing_recs)[:20])

# 2) every recording has single split in manifest
rec_split_counts = m.groupby("recording_id")["split"].nunique()
bad = rec_split_counts[rec_split_counts.ne(1)]
print("Recordings with inconsistent split in manifest (should be 0):", len(bad))
if not bad.empty:
    print(bad.head())

# 3) no speaker across splits
spk_splits = m.groupby("speaker_id")["split"].nunique()
multi_spk = spk_splits[spk_splits.gt(1)]
print("Speakers appearing in >1 split (should be 0):", len(multi_spk))
if not multi_spk.empty:
    print(multi_spk.head())

# 4) quick assert (will raise if problem found)
assert len(missing_recs) == 0, "Some recordings missing in split_map!"
assert len(bad) == 0, "Some recordings have inconsistent split!"
assert len(multi_spk) == 0, "Some speakers appear in multiple splits!"
print("All split/speaker consistency checks PASSED ✅")


Recordings in manifest not in split_map (should be 0): 0
Recordings with inconsistent split in manifest (should be 0): 0
Speakers appearing in >1 split (should be 0): 0
All split/speaker consistency checks PASSED ✅


In [None]:
# Cell 3 — duration stats and play a few random samples (manual check)
import pandas as pd, random
from IPython.display import Audio, display

MANIFEST = "/content/drive/MyDrive/josh_talks/processed/manifests/manifest_segments_split_cleaned.csv"
df = pd.read_csv(MANIFEST)
print("Total segments:", len(df))
print(df["split"].value_counts())
print("\nDuration percentiles (s):")
duration_summary = df["duration"].describe(percentiles=[0.25,0.5,0.75,0.9])
print(duration_summary.to_string())

# Play up to 6 random segments (listen and compare against the printed transcript).
# The seed is fixed so the same segments are drawn on every run.
nplay = min(6, len(df))
sample = df.sample(nplay, random_state=42).reset_index(drop=True)
for seg in sample.itertuples():
    print("ID:", seg.id, "| rec:", seg.recording_id, "| split:", seg.split, "| dur:", round(seg.duration,2))
    # Long transcripts are cut at 300 characters for display only.
    shown_text = seg.text if len(seg.text) <= 300 else seg.text[:300] + "..."
    print("Text:", shown_text)
    try:
        display(Audio(filename=seg.audio_filepath, autoplay=False))
    except Exception as e:
        print("Audio playback error:", e)
    print("-"*60)


Total segments: 5578
split
train         4445
validation    1133
Name: count, dtype: int64

Duration percentiles (s):
count    5578.000000
mean        7.936304
std         5.273785
min         0.510000
25%         2.700000
50%         7.770000
75%        13.470000
90%        14.550000
max        15.000000
ID: 494019_seg29 | rec: 494019 | split: train | dur: 3.0
Text: हां सही है


------------------------------------------------------------
ID: 305347_seg76 | rec: 305347 | split: train | dur: 6.63
Text: हम अच्छा ऐसा क्यों


------------------------------------------------------------
ID: 269794_seg17 | rec: 269794 | split: train | dur: 0.69
Text: हाँ


------------------------------------------------------------
ID: 526266_seg28 | rec: 526266 | split: train | dur: 6.66
Text: दुखद अंत और खुशाल अंत दुख और खुशी के बारे में बात करना है फिलींग के


------------------------------------------------------------
ID: 253253_seg25 | rec: 253253 | split: train | dur: 14.52
Text: जी हम भी गए थे जी एक वाराणसी गए हैं वाराणसी बहुत अच्छा स्थान है जी वाराणसी गए थे गंगा आरती उसका बहुत अच्छा वहा काशी विश्वनाथ मंदिर था वहाँ भी


------------------------------------------------------------
ID: 989901_seg79 | rec: 989901 | split: train | dur: 2.13
Text: काम नहीं करेंगे तो आप बड़ा नहीं कहलायेंगे.!


------------------------------------------------------------


In [None]:
# DIAGNOSTIC CELL — run this and paste the full output
import torch, numpy as np, soundfile as sf, librosa
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_from_disk
from pathlib import Path

# Load the segment dataset that was saved to disk earlier in the notebook.
HF_DATASET_PATH = "/content/drive/MyDrive/josh_talks/processed/hf_dataset_segments"
ds = load_from_disk(HF_DATASET_PATH)

print("Loading processor + model (pretrained openai/whisper-small)...")
model_name = "openai/whisper-small"
# Use the GPU when torch can see one, otherwise run on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = WhisperProcessor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name).to(device)

# Configure for Hindi transcription by storing the decoder prompt ids on model.config.
# NOTE(review): this was labelled "new API", but recent transformers releases
# appear to prefer setting language/task via generation_config — confirm
# against the installed version before relying on it for generation.
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="hi", task="transcribe")
# Empty the suppressed-token list on the config.
model.config.suppress_tokens = []

def read_audio(path):
    """Load an audio file as mono float32 samples at 16 kHz.

    Args:
        path: file path handed to `soundfile.read`.

    Returns:
        Tuple `(samples, 16000)`: the (possibly resampled) signal and its sample rate.
    """
    target_sr = 16000
    samples, native_sr = sf.read(path, dtype="float32")
    # More than one dimension means several channels: average over axis 1 to get mono.
    if samples.ndim > 1:
        samples = np.mean(samples, axis=1)
    # Resample only when the file is not already at the target rate.
    if native_sr != target_sr:
        samples = librosa.resample(samples, orig_sr=native_sr, target_sr=target_sr)
    return samples, target_sr

# pick up to 5 validation examples
n = min(5, len(ds["validation"]))
val = ds["validation"].select(range(n))

print(f"Running feature_extractor on {n} validation examples...\n")
feats = []
for i, ex in enumerate(val):
    path = ex["audio_filepath"]
    try:
        audio, sr = read_audio(path)
    except Exception as e:
        # Report and skip unreadable files so one bad file does not abort the diagnostic.
        print(f"[{i}] ERROR reading audio {path}: {e}")
        continue
    # use processor.feature_extractor (numpy) for inspection
    out = processor.feature_extractor(audio, sampling_rate=16000, return_tensors="np")
    # input_features is (batch, n_mels, time): axis 0 of `arr` is the mel axis, axis 1 is time.
    # (The earlier version of this cell read the axes the other way round and therefore
    # always reported a mismatch against the model.)
    arr = out["input_features"][0]
    print(f"[{i}] file: {Path(path).name}  audio_len: {audio.shape[0]}  feature shape (n_mels,time): {arr.shape}")
    print(f"     dtype: {arr.dtype}  n_mels: {arr.shape[0]}  time_frames: {arr.shape[1]}")
    # print a tiny excerpt of values for sanity
    print("     feature sample (first row, first 8 values):", np.array2string(arr[0,:8], precision=4, max_line_width=200))
    feats.append(arr)

# Quick global check
if len(feats) == 0:
    raise SystemExit("No features computed — audio reading failed.")

# Check that all feats share same n_mels (axis 0)
n_mels_set = {f.shape[0] for f in feats}
print("\nUnique n_mels across examples:", n_mels_set)

# compute expected seq len from model conv stride product (informational)
conv1_stride = model.model.encoder.conv1.stride[0]
conv2_stride = model.model.encoder.conv2.stride[0]
expected_seq_len = int(model.config.max_source_positions * conv1_stride * conv2_stride)
print("Model expected_seq_len (time frames):", expected_seq_len)
print("Model expected channels (should be n_mels): model conv weight shape:",
      list(model.model.encoder.conv1.weight.shape)[:2], "(out_chan, in_chan)")

# If n_mels matches conv1's input channels, pad/truncate the time axis and generate
in_chan = model.model.encoder.conv1.weight.shape[1]
if len(n_mels_set) == 1 and in_chan in n_mels_set:
    print("\nn_mels matches model conv1.in_channels → attempting pad + generate (fast)")
    padded = []
    for f in feats:
        cur = f.shape[1]
        if cur > expected_seq_len:
            f2 = f[:, :expected_seq_len]
        elif cur < expected_seq_len:
            pad_amount = expected_seq_len - cur
            f2 = np.pad(f, ((0, 0), (0, pad_amount)), mode="constant", constant_values=0.0)
        else:
            f2 = f
        # already (n_mels, time): no transpose needed
        padded.append(f2)
    input_features = torch.tensor(np.stack(padded), dtype=torch.float32).to(device)  # (batch, n_mels, time)
    print("Padded input_features.shape:", input_features.shape)
    try:
        gen = model.generate(input_features, max_length=256)
        preds = processor.batch_decode(gen, skip_special_tokens=True)
        print("\nGeneration SUCCESS. Sample predictions:")
        for p in preds:
            print("->", p[:200])
    except Exception as e:
        print("\nGeneration FAILED with exception:", e)
        raise
else:
    print("\n⚠️ MISMATCH: n_mels from feature_extractor does not equal model conv input channels.")
    print("    feature_extractor produced n_mels:", n_mels_set)
    print("    model.conv1 expected in_channels:", in_chan)
    print("Action: if n_mels != in_chan, we must either (A) ensure the processor.feature_extractor is the correct one,")
    print("or (B) compute mel spectrograms ourselves with n_mels=", in_chan, "so they match the model.")
    print("Run `print(n_mels_set)` output above and paste here; I will provide the exact fix (fast).")


Loading processor + model (pretrained openai/whisper-small)...
Running feature_extractor on 5 validation examples...

[0] file: 238123_seg0.wav  audio_len: 175680  feature shape (time,n_mels): (80, 3000)
     dtype: float32  time_frames: 80  n_mels: 3000
     feature sample (first row, first 8 values): [ 0.4665  0.1219 -0.1797 -0.1602  0.0384 -0.0817  0.0269 -0.037 ]
[1] file: 238123_seg1.wav  audio_len: 237600  feature shape (time,n_mels): (80, 3000)
     dtype: float32  time_frames: 80  n_mels: 3000
     feature sample (first row, first 8 values): [-0.3124 -0.0652 -0.0828 -0.0036  0.0845  0.0612  0.0265  0.0265]
[2] file: 238123_seg2.wav  audio_len: 99360  feature shape (time,n_mels): (80, 3000)
     dtype: float32  time_frames: 80  n_mels: 3000
     feature sample (first row, first 8 values): [ 0.5897  0.0235 -0.1093 -0.0912 -0.138  -0.1147 -0.0996 -0.1519]
[3] file: 238123_seg3.wav  audio_len: 83040  feature shape (time,n_mels): (80, 3000)
     dtype: float32  time_frames: 80  n_me

In [None]:
# Corrected generation cell — handles either (time,n_mels) or (n_mels,time) outputs
import torch, numpy as np, soundfile as sf, librosa
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_from_disk
from pathlib import Path

# Dataset directory handed to load_from_disk; the code below reads its "validation" split.
HF_DATASET_PATH = "/content/drive/MyDrive/josh_talks/processed/hf_dataset_segments"
ds = load_from_disk(HF_DATASET_PATH)

print("Loading processor + model (pretrained openai/whisper-small)...")
model_name = "openai/whisper-small"
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = WhisperProcessor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name).to(device)

# Force Hindi transcription through generation_config.
# Recent transformers releases take the decoder prompt from generation_config, so assigning
# model.config.forced_decoder_ids alone leaves language auto-detection switched on (the
# predictions printed by this cell included a non-Hindi line, which fits that).
model.generation_config.language = "hi"
model.generation_config.task = "transcribe"
# Clear both legacy prompt slots so they cannot conflict with language/task above.
model.generation_config.forced_decoder_ids = None
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

def read_audio(path):
    """Load an audio file as mono float32 samples at 16 kHz.

    Args:
        path: file path handed to `soundfile.read`.

    Returns:
        The (possibly resampled) 1-D signal only; the sample rate is not returned.
    """
    target_sr = 16000
    samples, native_sr = sf.read(path, dtype="float32")
    # More than one dimension means several channels: average over axis 1 to get mono.
    if samples.ndim > 1:
        samples = np.mean(samples, axis=1)
    # Resample only when the file is not already at the target rate.
    if native_sr != target_sr:
        samples = librosa.resample(samples, orig_sr=native_sr, target_sr=target_sr)
    return samples

# pick up to 5 validation examples
n = min(5, len(ds["validation"]))
val = ds["validation"].select(range(n))
examples = list(val)

# extract features
feats = []
shapes = []
for ex in examples:
    audio = read_audio(ex["audio_filepath"])
    out = processor.feature_extractor(audio, sampling_rate=16000, return_tensors="np")
    arr = out["input_features"][0]  # shape either (time, n_mels) OR (n_mels, time)
    shapes.append(arr.shape)
    feats.append(arr)

print("Feature shapes from extractor (per example):", shapes)

# determine model expected time length and conv in_channels
conv1_in = model.model.encoder.conv1.weight.shape[1]
conv1_out = model.model.encoder.conv1.weight.shape[0]
conv1_stride = model.model.encoder.conv1.stride[0]
conv2_stride = model.model.encoder.conv2.stride[0]
expected_seq_len = int(model.config.max_source_positions * conv1_stride * conv2_stride)
print("Model expects in_channels (n_mels)=", conv1_in, "  expected time frames =", expected_seq_len)
print("Model conv1 weight shape (out_chan, in_chan):", [conv1_out, conv1_in])

# Normalize each feat to shape (n_mels, time).
# The axis whose length equals conv1_in is the mel axis; one pass replaces the two duplicated
# branches (one of them unreachable) the earlier version used for the same decision.
normalized = []
for f in feats:
    if f.ndim != 2:
        raise ValueError(f"Unexpected feature dim: {f.shape}")
    if f.shape[0] == conv1_in:
        feat_nmels_time = f  # already (n_mels, time)
    elif f.shape[1] == conv1_in:
        feat_nmels_time = f.T  # (time, n_mels) -> (n_mels, time)
    else:
        # Neither axis matches the model: guess that the shorter axis is n_mels and warn.
        # The final sanity check below will then stop the run.
        feat_nmels_time = f if f.shape[0] < f.shape[1] else f.T
        print("WARNING: inferred n_mels does not match model conv1.in_channels!", feat_nmels_time.shape)
    normalized.append(feat_nmels_time)

# Pad/truncate along time-axis (axis=1) to expected_seq_len, then stack
padded = []
for f in normalized:
    n_mels, cur_time = f.shape
    if cur_time > expected_seq_len:
        f2 = f[:, :expected_seq_len]
    elif cur_time < expected_seq_len:
        pad_amount = expected_seq_len - cur_time
        f2 = np.pad(f, ((0, 0), (0, pad_amount)), mode="constant", constant_values=0.0)
    else:
        f2 = f
    padded.append(f2)

# stack -> shape (batch, n_mels, time)
input_features = torch.tensor(np.stack(padded), dtype=torch.float32).to(device)
print("Final input_features.shape:", input_features.shape, "(batch, n_mels, time)")

# final sanity: n_mels should match conv1.in_channels
if input_features.shape[1] != conv1_in:
    raise RuntimeError(f"n_mels ({input_features.shape[1]}) != model.conv1.in_channels ({conv1_in})")

# generate
generated_ids = model.generate(input_features, max_length=256)
preds = processor.batch_decode(generated_ids, skip_special_tokens=True)

print("\nPredictions:")
for p in preds:
    print(p[:400])


Loading processor + model (pretrained openai/whisper-small)...
Feature shapes from extractor (per example): [(80, 3000), (80, 3000), (80, 3000), (80, 3000), (80, 3000)]
Model expects in_channels (n_mels)= 80   expected time frames = 3000
Model conv1 weight shape (out_chan, in_chan): [768, 80]
Final input_features.shape: torch.Size([5, 80, 3000]) (batch, n_mels, time)


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



Predictions:
 तो मैंने पर्वार में उसके बारे आँ एक बड़ा है कि मैं ये नोकरी चोर रहा हों तुकि उस्वर में काफी समय से वर्क रहा था अग
 अगर बादिया देखाना तो बादिया तो बादिया तो बादिया तो बादिया तो बादिया तो बादिया तो बादिया तो बादिया तो बादिया तो बादिया तो बादिया तो बादिया तो बादिया तो बादिया तो बादिया तो बादिया तो बादिया तो बादिया तो बादिया तो बादिया तो बादिया
 और शार्शात में अपने बोस वोते ते ते उनको भी इसके बारे में जानकारी दी
 तो आप अगले करीर के कदम में क्या योजना बनारही हैं?
 Ciao.


In [None]:
!pip install -U transformers==4.47.0 datasets accelerate peft==0.7.1 jiwer librosa soundfile




In [None]:
# WORKING WHISPER LORA TRAINING - Based on official PEFT examples
import os
import numpy as np
import torch
import soundfile as sf
import librosa
from datasets import load_from_disk
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from dataclasses import dataclass
from typing import Any, Dict, List, Union

# Input dataset directory and the output directory for the smoke-run artifacts.
HF_DATASET_PATH = "/content/drive/MyDrive/josh_talks/processed/hf_dataset_segments"
OUT_DIR = "/content/drive/MyDrive/josh_talks/whisper_lora_smoke"
os.makedirs(OUT_DIR, exist_ok=True)

# Load dataset (small smoke subsets)
ds_all = load_from_disk(HF_DATASET_PATH)
print("Loaded splits:", list(ds_all.keys()))
print("Sizes:", len(ds_all["train"]), len(ds_all["validation"]))
# Cap the subsets so the smoke run stays short; the first N rows of each split are used.
N_TRAIN = min(500, len(ds_all["train"]))
N_VAL = min(100, len(ds_all["validation"]))
train_small = ds_all["train"].select(range(N_TRAIN))
val_small = ds_all["validation"].select(range(N_VAL))

# Processor + model meta
model_name = "openai/whisper-small"
# Pass language/task so the tokenizer prefixes every label sequence with the Hindi and
# transcribe tokens. Without them the labels carry no language/task tokens and the model
# is trained on a prompt that differs from the one used at inference.
processor = WhisperProcessor.from_pretrained(model_name, language="hi", task="transcribe")
forced_ids = processor.get_decoder_prompt_ids(language="hi", task="transcribe")

def read_audio_16k(path):
    """Load an audio file as mono float32 samples at 16 kHz.

    Args:
        path: file path handed to `soundfile.read`.

    Returns:
        The (possibly resampled) 1-D signal.
    """
    target_sr = 16000
    samples, native_sr = sf.read(path, dtype="float32")
    # More than one dimension means several channels: average over axis 1 to get mono.
    if samples.ndim > 1:
        samples = np.mean(samples, axis=1)
    # Resample only when the file is not already at the target rate.
    if native_sr != target_sr:
        samples = librosa.resample(samples, orig_sr=native_sr, target_sr=target_sr)
    return samples

# Read the encoder's input geometry from a throwaway copy of the model:
#   expected_seq_len = max_source_positions * conv1 stride * conv2 stride
#   conv1_in         = number of input channels of conv1 (the mel-bin count)
probe_model = WhisperForConditionalGeneration.from_pretrained(model_name)
probe_encoder = probe_model.model.encoder
conv1_stride, conv2_stride = probe_encoder.conv1.stride[0], probe_encoder.conv2.stride[0]
conv1_in = probe_encoder.conv1.weight.shape[1]
expected_seq_len = int(probe_model.config.max_source_positions * conv1_stride * conv2_stride)
# Drop every reference to the throwaway model before the training model is loaded.
del probe_encoder, probe_model
torch.cuda.empty_cache()
print("expected_seq_len:", expected_seq_len, "n_mels:", conv1_in)

# preprocess -> input_features, attention_mask, labels
def prepare_example_for_training(ex):
    """Turn one dataset row into fixed-size model inputs and tokenized labels.

    Args:
        ex: mapping with an "audio_filepath" (audio file to read) and a "text"
            (transcript to tokenize) entry.

    Returns:
        Dict with
          "input_features": float32 array (n_mels, expected_seq_len),
          "attention_mask": int64 array of length expected_seq_len, 1 for frames that came
                            from the extractor output and 0 for frames added here as padding,
          "labels":         int64 array of token ids for the transcript.
    """
    waveform = read_audio_16k(ex["audio_filepath"])
    mel = processor.feature_extractor(waveform, sampling_rate=16000, return_tensors="np")["input_features"][0]

    # Bring the features to (n_mels, time): keep as-is when axis 0 already has conv1_in
    # entries; otherwise transpose when axis 1 does, or when axis 0 is not the shorter axis.
    if mel.shape[0] != conv1_in and (mel.shape[1] == conv1_in or mel.shape[0] >= mel.shape[1]):
        mel = mel.T

    n_frames = mel.shape[1]
    n_valid = min(n_frames, expected_seq_len)
    if n_frames > expected_seq_len:
        fixed = mel[:, :expected_seq_len]
    else:
        # Zero-pad the time axis on the right up to the expected length.
        fixed = np.pad(
            mel,
            ((0, 0), (0, expected_seq_len - n_frames)),
            mode="constant",
            constant_values=0.0,
        )
    frame_mask = np.zeros(expected_seq_len, dtype=np.int64)
    frame_mask[:n_valid] = 1

    token_ids = processor.tokenizer(ex["text"], return_tensors="np", add_special_tokens=True).input_ids[0]

    return dict(
        input_features=fixed.astype(np.float32),
        attention_mask=frame_mask.astype(np.int64),
        labels=token_ids.astype(np.int64),
    )

def _preprocess_split(split, label):
    """Announce the split, then map `prepare_example_for_training` over it, dropping its original columns."""
    print(f"Mapping preprocess ({label})...")
    return split.map(
        prepare_example_for_training,
        remove_columns=split.column_names,
        num_proc=1,
    )


train_proc = _preprocess_split(train_small, "train")
val_proc = _preprocess_split(val_small, "val")
print("Processed sizes:", len(train_proc), len(val_proc))

# Data Collator - EXACTLY like the official PEFT example
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate preprocessed examples into one padded batch for speech seq2seq training.

    Audio features and label ids are padded separately because they have different
    lengths and need different padding methods.

    Fields:
        processor: object exposing `feature_extractor.pad(...)` and `tokenizer.pad(...)`.
        decoder_start_token_id: id of the token that opens every label sequence and is
            cut from the labels. Optional; when None it is looked up from the tokenizer
            (see `_label_start_id`).
    """
    processor: Any
    decoder_start_token_id: Union[int, None] = None

    def _label_start_id(self):
        """Return the id of the token expected at position 0 of every label sequence."""
        if self.decoder_start_token_id is not None:
            return self.decoder_start_token_id
        tokenizer = self.processor.tokenizer
        # Whisper label sequences open with <|startoftranscript|>, while the tokenizer's
        # bos token is a different token, so comparing against bos_token_id never matched.
        start_id = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
        unk_id = getattr(tokenizer, "unk_token_id", None)
        if start_id is None or start_id == unk_id:
            # Token unknown to this tokenizer: keep the previous bos-based behaviour.
            return tokenizer.bos_token_id
        return start_id

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Build the batch.

        Args:
            features: list of examples, each with "input_features" and "labels".

        Returns:
            The padded feature batch with a "labels" tensor added; label padding is -100.
        """
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences and pad them to the longest one in the batch
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if every sequence opens with the start token, cut it here because it is
        # appended again later anyway
        if labels.shape[1] > 0 and (labels[:, 0] == self._label_start_id()).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

# Collator instance handed to the trainer further down
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

# Base model for LoRA, placed on the GPU when one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = WhisperForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)

# Config overrides for training: no forced decoder prompt, no suppressed tokens, no KV cache
for _cfg_key, _cfg_value in (
    ("forced_decoder_ids", None),
    ("suppress_tokens", []),
    ("use_cache", False),
):
    setattr(model.config, _cfg_key, _cfg_value)

# Gradient checkpointing first, then the PEFT preparation helper
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

# Make input require grad for gradient checkpointing
def make_inputs_require_grad(module, input, output):
    """Forward hook that marks the hooked module's output as requiring gradients.

    Takes the standard forward-hook arguments; only `output` is used, and it is
    modified in place via `requires_grad_(True)`. Nothing is returned.
    """
    output.requires_grad_(True)

# Attach the requires-grad hook to the encoder's first conv layer (for gradient checkpointing)
model.model.encoder.conv1.register_forward_hook(make_inputs_require_grad)

# LoRA adapters on the attention query/value projections only
target_modules = ["q_proj", "v_proj"]
lora_config = LoraConfig(
    target_modules=target_modules,
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# Training arguments for the one-epoch smoke run
training_args = Seq2SeqTrainingArguments(
    output_dir=OUT_DIR,
    # optimisation
    num_train_epochs=1,
    learning_rate=1e-3,
    warmup_steps=50,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    fp16=torch.cuda.is_available(),
    # evaluation: loss only, no generation during the smoke test
    evaluation_strategy="epoch",
    per_device_eval_batch_size=4,
    predict_with_generate=False,
    generation_max_length=128,
    # bookkeeping
    logging_steps=25,
    save_strategy="epoch",
    report_to="none",
    dataloader_pin_memory=False,
    # keep the dataset columns untouched and tell the trainer which key holds the labels
    remove_unused_columns=False,
    label_names=["labels"],
)

# The feature extractor is passed through the `tokenizer=` argument.
# NOTE(review): the recorded output of this cell flags this call with a warning; confirm
# the argument name against the installed transformers version before upgrading.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_proc,
    eval_dataset=val_proc,
    tokenizer=processor.feature_extractor,
)

print("Starting smoke-train (1 epoch)...")
trainer.train()

# Adapter weights and processor files are written side by side under OUT_DIR/final
final_dir = os.path.join(OUT_DIR, "final")
print("Saving smoke model...")
model.save_pretrained(final_dir)
processor.save_pretrained(final_dir)
print("Smoke-train done, saved to", final_dir)

Loaded splits: ['train', 'validation']
Sizes: 4445 1133
expected_seq_len: 3000 n_mels: 80
Mapping preprocess (train)...


Map (num_proc=1):   0%|          | 0/500 [00:00<?, ? examples/s]

Mapping preprocess (val)...


Map (num_proc=1):   0%|          | 0/100 [00:00<?, ? examples/s]

Processed sizes: 500 100
trainable params: 3,538,944 || all params: 245,273,856 || trainable%: 1.442854145857274
Starting smoke-train (1 epoch)...


  trainer = Seq2SeqTrainer(


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
# Dependencies for the smoke-test LoRA training cell below (very small, 1 epoch)
!pip install -q transformers==4.35.0 datasets accelerate peft==0.6.0 evaluate jiwer


In [None]:

import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
from peft import LoraConfig, get_peft_model
from datasets import load_from_disk
import numpy as np

# load small datasets
train_ds = load_from_disk("/content/train_small_ds")
val_ds = load_from_disk("/content/val_small_ds")

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
# The Whisper tokenizer selects its language/task prompt with set_prefix_tokens;
# it has no set_target_language method, so the previous call raised AttributeError.
processor.tokenizer.set_prefix_tokens(language="hi", task="transcribe")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(device)
model.gradient_checkpointing_enable()
model.config.use_cache = False


def _conv1_output_requires_grad(module, inputs, output):
    """Forward hook: mark conv1's output as requiring grad so checkpointed layers get gradients."""
    output.requires_grad_(True)


# Same hook as the other LoRA training cell in this notebook: with gradient checkpointing on
# and the base weights frozen by PEFT, something upstream of the adapters must require grad.
model.model.encoder.conv1.register_forward_hook(_conv1_output_requires_grad)

# PEFT LoRA config.
# No task_type: the SEQ_2_SEQ_LM wrapper forwards an `input_ids` argument that Whisper's
# forward() does not accept. Target modules are spelled out to match the other LoRA cell.
lora_config = LoraConfig(
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
)
model = get_peft_model(model, lora_config)

# simple collator for small run
def collate_fn(features):
    """Pad a list of examples into one training batch.

    Args:
        features: list of examples, each with "input_features" and "labels".

    Returns:
        Dict with "input_features" and "labels" tensors. Label padding is set to -100 so it
        is ignored by the loss. Tensors are left on the CPU: the Trainer moves each batch to
        the training device itself, and returning CUDA tensors from a collator conflicts
        with pinned-memory data loading.
    """
    input_feats = [f["input_features"] for f in features]
    labels = [f["labels"] for f in features]
    batch_inputs = processor.feature_extractor.pad({"input_features": input_feats}, return_tensors="pt")
    batch_labels = processor.tokenizer.pad({"input_ids": labels}, return_tensors="pt")

    # Mask padded label positions so they do not contribute to the loss.
    label_ids = batch_labels["input_ids"].masked_fill(batch_labels["attention_mask"].ne(1), -100)

    # If every sequence opens with the start-of-transcript token, drop it: the model
    # prepends its decoder start token again when it builds the decoder inputs.
    start_id = processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
    if label_ids.shape[1] > 0 and (label_ids[:, 0] == start_id).all().item():
        label_ids = label_ids[:, 1:]

    return {"input_features": batch_inputs["input_features"], "labels": label_ids}

# Arguments for the one-epoch smoke run, grouped by purpose
training_args = Seq2SeqTrainingArguments(
    output_dir="/content/drive/MyDrive/josh_talks/whisper_lora_smoke",
    # optimisation
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    fp16=torch.cuda.is_available(),
    # evaluation: loss only, no generation
    evaluation_strategy="epoch",
    per_device_eval_batch_size=4,
    predict_with_generate=False,
    # bookkeeping
    logging_steps=10,
    save_strategy="epoch",
    report_to="none",
    remove_unused_columns=False,
)

# WER metric object; loaded here, not consumed by the placeholder metrics function below
from evaluate import load as load_metric
wer_metric = load_metric("wer")

def compute_metrics(eval_pred):
    """Placeholder metrics hook for the smoke run: ignores `eval_pred` and reports nothing.

    Returns:
        An empty dict, whatever the input.
    """
    return dict()

# compute_metrics is deliberately not passed: the placeholder returns {}, and handing any
# metrics function to the Trainer makes it keep the logits of the whole evaluation set in
# memory. Without it, evaluation only computes the loss.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collate_fn,
    tokenizer=processor.tokenizer,
)

print("Starting smoke training (should be quick)...")
trainer.train()

# Single definition of the save location, used for both saving and the final message.
smoke_final_dir = "/content/drive/MyDrive/josh_talks/whisper_lora_smoke_final"
trainer.save_model(smoke_final_dir)
print(f"Smoke training finished and saved to {smoke_final_dir}")


In [None]:
import shutil

# Source folder to archive and the zip file it ends up in
folder_path = "/content/drive/MyDrive/josh_talks/processed"
output_zip = "/content/processed.zip"

# make_archive appends ".zip" itself, so it is given the target path without the extension
archive_base = output_zip[: -len(".zip")]
shutil.make_archive(archive_base, "zip", root_dir=folder_path)

print("Zipped successfully:", output_zip)


Zipped successfully: /content/processed.zip
