# EXP-001 — Reducing Overfitting in CHIME Text Classification

This notebook runs a controlled experiment to reduce overfitting in CHIME framework text classification using BERT-based models. It compares three runs: a baseline BERT, a regularized BERT with early stopping and weight decay, and a smaller DistilBERT model. Outputs include configuration files, metrics, confusion matrices, learning curves, and optional test predictions.


**Expected runtime:** depends on CPU/GPU. On CPU, BERT may take several minutes per run.

## Cell 1 — Setup & dependency check
This notebook expects `torch`, `transformers`, `datasets`, `pandas`, `scikit-learn`, `matplotlib`, `seaborn`.

**Expected output:** a quick confirmation of versions and whether CUDA is available.

In [5]:
import os
import json
import time
import random
from dataclasses import asdict, dataclass
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns

from datasets import load_dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report

from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
)

# Report the installed torch build and the device training will run on.
has_cuda = torch.cuda.is_available()
print('torch:', torch.__version__)
print('cuda available:', has_cuda)
print('device:', 'cuda' if has_cuda else 'cpu')

## Cell 2 — Experiment constants
Adjust these if needed. Keeping them centralized makes runs consistent.

In [6]:
# Cell 2 — experiment constants
# Hugging Face dataset identifier (Option A: load straight from the Hub).
HF_DATASET_NAME = 'ashh007/DREAMS-CHIME-dataset'

# The three configurations under comparison. `dropout=None` keeps the
# checkpoint's own dropout settings.
RUNS = [
    # A: plain BERT with no regularization (the overfitting reference point).
    dict(
        run_name='A_bert_baseline',
        model_ckpt='bert-base-uncased',
        use_early_stopping=False,
        weight_decay=0.0,
        dropout=None,
    ),
    # B: same checkpoint with early stopping, weight decay and higher dropout.
    dict(
        run_name='B_bert_regularized',
        model_ckpt='bert-base-uncased',
        use_early_stopping=True,
        weight_decay=0.01,
        dropout=0.2,
    ),
    # C: smaller model with early stopping and weight decay.
    dict(
        run_name='C_distilbert_regularized',
        model_ckpt='distilbert-base-uncased',
        use_early_stopping=True,
        weight_decay=0.01,
        dropout=None,
    ),
]

# Reproducibility: the split seed is kept apart from the training seed so the
# same train/val/test partition can be reused while the training seed varies.
SPLIT_SEED = 42
SEED = 42

# Tokenization and split proportions (fractions of the full dataset).
MAX_LENGTH = 128
TEST_SIZE = 0.10
VAL_SIZE = 0.10

# Optimization settings, chosen to stay workable on CPU.
EPOCHS = 5
BATCH_SIZE = 8
LR = 2e-5

# Root folder for per-run artifacts; created eagerly so later cells can write.
BASE_OUTPUT_DIR = os.path.join(
    'ml_experiments_anish',
    'experiment1_chime_text_overfitting',
    'runs',
)
os.makedirs(BASE_OUTPUT_DIR, exist_ok=True)

def set_seed(seed: int):
    """Seed Python's `random`, NumPy and torch (CPU, plus all GPUs if present).

    Args:
        seed: Integer seed applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Seed every RNG once up front, then echo the reproducibility settings.
set_seed(SEED)
for caption, value in (
    ('Split seed:', SPLIT_SEED),
    ('Training seed set to:', SEED),
    ('Outputs will be saved under:', BASE_OUTPUT_DIR),
):
    print(caption, value)

## Cell 3 — Filesystem diagnosis
Confirms the runtime environment; especially useful when running on Colab VMs.

In [7]:
import os
import sys
import platform
from pathlib import Path

# Describe the interpreter and where artifacts will be written.
print('=== Runtime filesystem diagnostic ===')
for caption, value in (
    ('Platform:', platform.platform()),
    ('Python:', sys.version.split()[0]),
    ('CWD:', os.getcwd()),
    ('BASE_OUTPUT_DIR (as set):', BASE_OUTPUT_DIR),
    ('BASE_OUTPUT_DIR (absolute):', str(Path(BASE_OUTPUT_DIR).resolve())),
):
    print(caption, value)

# On a Colab kernel, files land on the Colab VM filesystem rather than in the
# local VS Code workspace; this flag tells you which one you are looking at.
in_colab = 'google.colab' in sys.modules
print('Detected Colab kernel:', in_colab)

# List what is already in the output folder (creating it first if needed).
try:
    p = Path(BASE_OUTPUT_DIR)
    p.mkdir(parents=True, exist_ok=True)
    entries = sorted(child.name for child in p.iterdir())
    print(f'Entries in BASE_OUTPUT_DIR ({len(entries)}):', entries[:50])
except Exception as e:
    print('Could not list BASE_OUTPUT_DIR due to:', repr(e))

# Drop a tiny probe file to prove the active kernel filesystem is writable.
try:
    probe = Path(BASE_OUTPUT_DIR) / '_write_probe.txt'
    probe.write_text('ok\n', encoding='utf-8')
    print('Wrote probe file:', str(probe))
except Exception as e:
    print('Could not write probe file due to:', repr(e))

print('Tip: if you want artifacts to persist outside the Colab VM, set BASE_OUTPUT_DIR to a Google Drive path (after mounting Drive).')

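# Cell 4 - Disk usage check and optional cleanup
Reports free disk space. When `DO_CLEANUP` is `True`, it also **deletes** old run folders under `BASE_OUTPUT_DIR` (and, optionally, the Hugging Face/torch caches), so review the flags before running this cell.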

In [4]:
import shutil
from pathlib import Path

# Print free/used/total space for the root, the Colab content dir and the
# output folder; a path that cannot be inspected is reported, not fatal.
print('=== Disk usage (Colab VM) ===')
for p in ('/', '/content', str(Path(BASE_OUTPUT_DIR).resolve())):
    try:
        du = shutil.disk_usage(p)
    except Exception as e:
        print(f'{p:>30}  (error: {e!r})')
        continue
    free_gb, used_gb, total_gb = (
        amount / (1024**3) for amount in (du.free, du.used, du.total)
    )
    print(f'{p:>30}  free={free_gb:6.2f} GB  used={used_gb:6.2f} GB  total={total_gb:6.2f} GB')

# ---- Optional cleanup helpers ----
# When the Colab VM runs out of space, you can delete old run folders and/or HF caches.
# This does NOT affect your local Windows repo unless you explicitly copy things back.
#
# Defaults are deliberately non-destructive: with DO_CLEANUP enabled, a plain
# "Restart & Run All" would otherwise wipe previous run folders before the
# experiments start. Opt in explicitly when you actually need the space.

DO_CLEANUP = False  # set True to actually delete files
KEEP_BASELINE_RUNS = True  # set False to also delete A/B/C runs
KEEP_LAST_N_RUN_FOLDERS = 0  # keep the newest N run folders (after keeping baseline)
DELETE_SWEEP_RUNS_FIRST = True  # delete only sweep runs (S_bert_*) before deleting other runs
DELETE_HF_CACHE = False  # last resort: redownloads models/datasets later

# Common HF/torch cache locations on Colab. The home-relative and /root
# entries can point at the same directory; cleanup only touches ones that exist.
HF_CACHE_DIRS = [
    Path.home() / '.cache' / 'huggingface' / 'hub',
    Path.home() / '.cache' / 'huggingface' / 'datasets',
    Path('/root/.cache/huggingface/hub'),
    Path('/root/.cache/huggingface/datasets'),
    Path('/root/.cache/torch'),
    Path.home() / '.cache' / 'torch',
]

def _is_run_dir(p: Path) -> bool:
    if not p.is_dir():
        return False
    name = p.name
    return ('_seed' in name) and (name[0].isdigit())

def _is_sweep_run_dir(p: Path) -> bool:
    """Return True for run folders whose name carries a sweep marker."""
    if not _is_run_dir(p):
        return False
    return any(marker in p.name for marker in ('_S_', 'S_bert_'))

def _is_baseline_run_dir(p: Path) -> bool:
    """Return True for run folders belonging to the A/B/C comparison runs."""
    baseline_names = (
        'A_bert_baseline',
        'B_bert_regularized',
        'C_distilbert_regularized',
    )
    return _is_run_dir(p) and any(tag in p.name for tag in baseline_names)

def _delete_path(p: Path):
    if p.is_dir():
        shutil.rmtree(p, ignore_errors=True)
    elif p.exists():
        try:
            p.unlink()
        except Exception:
            pass

def cleanup_disk():
    """Delete old run folders (and optionally caches) under BASE_OUTPUT_DIR.

    Behaviour is driven by the module-level flags KEEP_BASELINE_RUNS,
    DELETE_SWEEP_RUNS_FIRST, KEEP_LAST_N_RUN_FOLDERS and DELETE_HF_CACHE.
    Prints a summary of what was deleted and kept, then the free space on '/'.
    """
    base = Path(BASE_OUTPUT_DIR)
    if not base.exists():
        print('BASE_OUTPUT_DIR does not exist:', str(base))
        return

    def _list_run_dirs():
        # Names start with a timestamp, so sorting by name is oldest -> newest.
        return sorted((c for c in base.iterdir() if _is_run_dir(c)), key=lambda c: c.name)

    run_dirs = _list_run_dirs()
    if not run_dirs:
        print('No run folders found under:', str(base))
        return

    # Baseline runs are protected only when requested.
    kept = {d for d in run_dirs if _is_baseline_run_dir(d)} if KEEP_BASELINE_RUNS else set()

    deleted = []

    # Pass 1 (optional): sweep runs go first, since they are usually the bulk.
    if DELETE_SWEEP_RUNS_FIRST:
        for candidate in list(run_dirs):
            if candidate not in kept and _is_sweep_run_dir(candidate):
                deleted.append(candidate)
                _delete_path(candidate)

        # Re-scan so pass 2 only sees folders that still exist.
        run_dirs = _list_run_dirs()

    # Pass 2: drop everything unprotected except the newest N folders.
    remaining = [d for d in run_dirs if d not in kept]
    if KEEP_LAST_N_RUN_FOLDERS:
        to_keep_newest = set(remaining[-KEEP_LAST_N_RUN_FOLDERS:])
    else:
        to_keep_newest = set()
    for candidate in remaining:
        if candidate not in to_keep_newest:
            deleted.append(candidate)
            _delete_path(candidate)

    print(f'Deleted {len(deleted)} run folders.')
    if deleted:
        print('Examples deleted:', [d.name for d in deleted[:5]])
    print('Kept baseline runs:', [d.name for d in sorted(kept, key=lambda x: x.name)])
    print('Kept newest runs:', [d.name for d in sorted(to_keep_newest, key=lambda x: x.name)])

    if DELETE_HF_CACHE:
        cache_deleted = 0
        for cache_dir in HF_CACHE_DIRS:
            # Checked one at a time: two entries may point at the same folder.
            if cache_dir.exists():
                _delete_path(cache_dir)
                cache_deleted += 1
        print(f'Deleted {cache_deleted} HF/torch cache dirs (if existed).')

    # Report the free space left on the root filesystem.
    du2 = shutil.disk_usage('/')
    print(f"Free space after cleanup: {du2.free / (1024**3):.2f} GB")

# Only touch the disk when cleanup was explicitly enabled.
if not DO_CLEANUP:
    print('Cleanup is disabled. Set DO_CLEANUP=True to delete old artifacts.')
else:
    cleanup_disk()

## Cell 4 — Load dataset from Hugging Face

**Expected output:** dataset columns, row count, and label distribution.

In [8]:
ds = load_dataset(HF_DATASET_NAME)
print(ds)

# Flatten to one pandas DataFrame so the notebook controls the splitting.
# A dict-like result exposes named splits: prefer 'train', else the first one.
# Anything without splits is used as-is.
available_splits = list(ds.keys()) if hasattr(ds, 'keys') else []
if available_splits:
    split_name = 'train' if 'train' in ds else available_splits[0]
    base = ds[split_name]
else:
    base = ds

df = base.to_pandas()
print('Rows:', len(df))
print('Columns:', list(df.columns))
display(df.head(3))

## Cell 5 — Normalize columns (text + label)
The legacy EXP-001 notebook used `CAPTIONS` and `labels`. Here we detect likely column names and standardize them to:
- `text`
- `label`

**Expected output:** confirmation of chosen columns + label distribution.

In [9]:
def pick_first_existing(columns, candidates):
    """Return the first name in `candidates` that is present in `columns`.

    Args:
        columns: Container of available column names.
        candidates: Names to try, in priority order.

    Returns:
        The first matching candidate, or None when none match.
    """
    return next((name for name in candidates if name in columns), None)

text_col = pick_first_existing(df.columns, ['CAPTIONS', 'caption', 'captions', 'text', 'sentence', 'content'])
label_col = pick_first_existing(df.columns, ['labels', 'label', 'category', 'class', 'target'])

if text_col is None or label_col is None:
    raise ValueError(
        f'Could not infer text/label columns. Found columns={list(df.columns)}. '
        'Please update the candidates list.'
    )

# Standardize to `text` / `label`.
# Missing values must be handled BEFORE casting to str: `astype(str)` turns NaN
# into the literal string 'nan', after which `fillna` has nothing left to fill.
# That would let missing texts survive the empty-text filter below and would
# turn missing labels into a spurious 'nan' class.
df = (
    df[[text_col, label_col]]
    .rename(columns={text_col: 'text', label_col: 'label'})
    .dropna(subset=['label'])  # a row without a label cannot be used for training
    .assign(
        text=lambda d: d['text'].fillna('').astype(str),
        label=lambda d: d['label'].astype(str),
    )
)

# Drop empty / whitespace-only texts (this now includes originally-missing ones)
df = df[df['text'].str.strip().astype(bool)].reset_index(drop=True)

print('Using text column:', text_col)
print('Using label column:', label_col)
print('Rows after cleanup:', len(df))

label_counts = df['label'].value_counts()
display(label_counts)

plt.figure(figsize=(8, 4))
sns.barplot(x=label_counts.index, y=label_counts.values)
plt.title('Label distribution')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

## Dataset duplicate / near-duplicate audit
This checks for **exact duplicates** and **near-duplicates** in:
- the in-memory `df` loaded from Hugging Face, and
- the local CSV `ml_experiments_anish/DREAMS_CHIME_dataset.csv` (if present).

High duplicate rates can inflate train/val/test metrics because similar texts can land in different splits.

In [None]:
from pathlib import Path

def _normalize_text_simple(t: str) -> str:
    import re
    t = str(t).lower().strip()
    t = re.sub(r"\s+", " ", t)
    t = re.sub(r"[^a-z0-9\s]", "", t)
    t = re.sub(r"\s+", " ", t).strip()
    return t

def audit_duplicates_in_df(df_in: pd.DataFrame, text_col: str = 'text', label_col: str | None = 'label', thr: float = 0.92):
    """Report exact and near-duplicate texts in a DataFrame.

    Exact duplicates are counted on the raw text and on the normalized text
    (see `_normalize_text_simple`). Near-duplicates are pairs of rows whose
    character n-gram TF-IDF cosine similarity is at least `thr`, searched
    among each row's nearest neighbours; pairs whose normalized text is
    identical are excluded because they already count as exact duplicates.

    Args:
        df_in: Input frame; it is not modified.
        text_col: Name of the text column.
        label_col: Optional label column, added to the pair report when present.
        thr: Minimum cosine similarity for a pair to be reported.

    Returns:
        Dict with `raw_dup_rows`, `norm_dup_rows` and `near_pairs` (list of
        dicts sorted by decreasing similarity). `row_a` / `row_b` are 0-based
        row positions in `df_in`.
    """
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.neighbors import NearestNeighbors

    # Row *positions* from the neighbour search are used with `.loc` below, so
    # force a 0..n-1 index. Without this, a frame that was filtered but not
    # re-indexed would raise KeyError or silently report the wrong rows.
    tmp = df_in.reset_index(drop=True)
    # fillna must precede astype(str); otherwise NaN becomes the string 'nan'
    # and every missing text is reported as a duplicate of the others.
    tmp[text_col] = tmp[text_col].fillna('').astype(str)
    tmp['__norm__'] = tmp[text_col].map(_normalize_text_simple)

    raw_dup_rows = int(tmp[text_col].duplicated(keep=False).sum())
    norm_dup_rows = int(tmp['__norm__'].duplicated(keep=False).sum())
    print('Rows:', len(tmp))
    print('Exact duplicate rows (raw):', raw_dup_rows)
    print('Exact duplicate rows (normalized):', norm_dup_rows)

    # show biggest normalized dup groups
    if norm_dup_rows > 0:
        g = (
            tmp.groupby('__norm__')
            .size()
            .sort_values(ascending=False)
            .reset_index(name='count')
        )
        g = g[g['count'] >= 2]
        print('\nTop normalized duplicate groups:')
        display(g.head(10))

    # near-duplicates (only rows whose normalized text is non-empty take part)
    nonempty = tmp['__norm__'].str.len() > 0
    texts = tmp.loc[nonempty, '__norm__'].tolist()
    # idx_map[i] is the row position in `tmp` of the i-th non-empty text
    idx_map = np.flatnonzero(nonempty.to_numpy())
    near_pairs = []
    if len(texts) >= 2:
        vec = TfidfVectorizer(min_df=2, ngram_range=(3, 5), analyzer='char_wb')
        X = vec.fit_transform(texts)
        n_neighbors = min(6, len(texts))
        nn = NearestNeighbors(n_neighbors=n_neighbors, metric='cosine', algorithm='brute').fit(X)
        dists, idxs = nn.kneighbors(X)
        seen = set()
        for i in range(len(texts)):
            # Column 0 is skipped as the presumed self-match; if a row does
            # show up as its own neighbour later, the identical-text check
            # below filters it out.
            for jpos in range(1, idxs.shape[1]):
                j = int(idxs[i, jpos])
                a = int(idx_map[i]); b = int(idx_map[j])
                lo, hi = (a, b) if a < b else (b, a)
                if (lo, hi) in seen:
                    continue
                sim = 1.0 - float(dists[i, jpos])
                if sim < thr:
                    continue
                if tmp.loc[lo, '__norm__'] == tmp.loc[hi, '__norm__']:
                    continue
                seen.add((lo, hi))
                item = {
                    'row_a': lo,
                    'row_b': hi,
                    'similarity': sim,
                    'text_a': tmp.loc[lo, text_col][:200],
                    'text_b': tmp.loc[hi, text_col][:200],
                }
                if label_col and label_col in tmp.columns:
                    item['label_a'] = tmp.loc[lo, label_col]
                    item['label_b'] = tmp.loc[hi, label_col]
                near_pairs.append(item)
        near_pairs.sort(key=lambda x: -x['similarity'])

    print(f"\nNear-duplicate pairs (similarity >= {thr}):", len(near_pairs))
    if near_pairs:
        display(pd.DataFrame(near_pairs[:10]))
    return {'raw_dup_rows': raw_dup_rows, 'norm_dup_rows': norm_dup_rows, 'near_pairs': near_pairs}

# Audit 1: the in-memory frame loaded from Hugging Face.
print('=== Audit: HF-loaded df (after column normalization) ===')
_ = audit_duplicates_in_df(df, text_col='text', label_col='label', thr=0.92)

# Audit 2: the local CSV copy, when it is available next to the notebook.
local_csv = Path('ml_experiments_anish') / 'DREAMS_CHIME_dataset.csv'
if not local_csv.exists():
    print('\nLocal CSV not found at:', str(local_csv.resolve()))
else:
    print('\n=== Audit: local CSV (ml_experiments_anish/DREAMS_CHIME_dataset.csv) ===')
    local_df = pd.read_csv(local_csv)

    # Text column: 'CAPTIONS', else 'text', else whatever comes first.
    if 'CAPTIONS' in local_df.columns:
        tcol = 'CAPTIONS'
    elif 'text' in local_df.columns:
        tcol = 'text'
    else:
        tcol = local_df.columns[0]
    # Label column is optional and only recognized under the name 'labels'.
    lcol = 'labels' if 'labels' in local_df.columns else None

    column_renames = {tcol: 'text'}
    if lcol:
        column_renames[lcol] = 'label'
    local_df = local_df.rename(columns=column_renames)

    has_label = 'label' in local_df.columns
    _ = audit_duplicates_in_df(local_df, text_col='text', label_col=('label' if has_label else None), thr=0.92)

## Build duplicate/near-duplicate groups (`group_id`)
This assigns each row a `group_id` so that **near-duplicate text variants are treated as one group**.
We will use `group_id` in the split step so groups can’t leak across train/val/test.

In [10]:
import re
import numpy as np
import pandas as pd

# Switch for the leakage-safe, group-aware split in the split cell.
USE_GROUP_SPLIT = True
# Cosine-similarity cutoff for near-duplicates (matches the audit threshold).
GROUP_SIM_THRESHOLD = 0.92
# Number of nearest neighbours inspected per text while grouping.
GROUP_N_NEIGHBORS = 6

def _norm_for_grouping(t: str) -> str:
    t = str(t).lower().strip()
    t = re.sub(r"\s+", " ", t)
    t = re.sub(r"[^a-z0-9\s]", "", t)
    t = re.sub(r"\s+", " ", t).strip()
    return t

class _UnionFind:
    def __init__(self, n: int):
        self.parent = list(range(n))
        self.rank = [0] * n
    def find(self, x: int) -> int:
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x
    def union(self, a: int, b: int) -> None:
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return
        if self.rank[ra] < self.rank[rb]:
            self.parent[ra] = rb
        elif self.rank[ra] > self.rank[rb]:
            self.parent[rb] = ra
        else:
            self.parent[rb] = ra
            self.rank[ra] += 1

def build_group_ids(texts: pd.Series, thr: float = 0.92, n_neighbors: int = 6) -> np.ndarray:
    """Assign a group id to every text so near-duplicates share one id.

    Rows are linked when their normalized text is identical, or when their
    character n-gram TF-IDF cosine similarity is at least `thr` among the
    `n_neighbors` nearest neighbours. Links are transitive (union-find).
    Rows whose normalized text is empty (including missing values) are never
    linked and each gets its own group.

    Args:
        texts: Series of raw texts.
        thr: Minimum cosine similarity for two texts to be linked.
        n_neighbors: Neighbours queried per text (the query itself included).

    Returns:
        Integer array of length `len(texts)` with ids in `0..G-1`, aligned
        with the row order of `texts`.
    """
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.neighbors import NearestNeighbors

    # fillna must come before astype(str): casting first turns NaN into the
    # string 'nan', which would merge all missing texts into one bogus group.
    raw = texts.fillna('').astype(str)
    norm = raw.map(_norm_for_grouping)
    n = len(norm)
    uf = _UnionFind(n)

    # 1) Union exact duplicates by normalized text
    buckets = {}
    for i, key in enumerate(norm.tolist()):
        if not key:
            continue
        if key in buckets:
            uf.union(i, buckets[key])
        else:
            buckets[key] = i

    # 2) Union near-duplicates using char-ngrams + cosine similarity
    nonempty_mask = norm.str.len() > 0
    texts_nonempty = norm.loc[nonempty_mask].tolist()
    # idx_map[i] is the row position of the i-th non-empty text
    idx_map = np.flatnonzero(nonempty_mask.to_numpy())
    if len(texts_nonempty) >= 2:
        vec = TfidfVectorizer(min_df=2, ngram_range=(3, 5), analyzer='char_wb')
        X = vec.fit_transform(texts_nonempty)
        k = min(int(n_neighbors), len(texts_nonempty))
        nn = NearestNeighbors(n_neighbors=k, metric='cosine', algorithm='brute').fit(X)
        dists, idxs = nn.kneighbors(X)
        for i in range(len(texts_nonempty)):
            # Column 0 is skipped as the presumed self-match; a stray
            # self-union would be a harmless no-op anyway.
            for jpos in range(1, idxs.shape[1]):
                j = int(idxs[i, jpos])
                sim = 1.0 - float(dists[i, jpos])
                if sim < thr:
                    continue
                a = int(idx_map[i])
                b = int(idx_map[j])
                uf.union(a, b)

    roots = np.array([uf.find(i) for i in range(n)], dtype=int)
    # compress roots into 0..G-1 ids for readability
    _, group_ids = np.unique(roots, return_inverse=True)
    return group_ids

# Add the `group_id` column on a fresh frame (the split cell reads it).
df = df.assign(
    group_id=build_group_ids(df['text'], thr=GROUP_SIM_THRESHOLD, n_neighbors=GROUP_N_NEIGHBORS)
)

# Quick summary of how much duplication the grouping found.
group_sizes = df['group_id'].value_counts()
multi_row_groups = group_sizes[group_sizes >= 2]
print('USE_GROUP_SPLIT:', USE_GROUP_SPLIT)
print('Groups:', int(group_sizes.shape[0]), 'Rows:', len(df))
print('Largest group size:', int(group_sizes.max()))
print('Groups with size>=2:', int(len(multi_row_groups)))
print('Rows in groups size>=2:', int(multi_row_groups.sum()))
display(group_sizes.head(10))

## Cell 6 — Train/Val/Test split
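Creates reproducible, stratified train/val/test splits. When `USE_GROUP_SPLIT` is enabled, all rows sharing a `group_id` are kept in the same split so near-duplicates cannot leak across splits; if the group-aware split fails, the cell falls back to a plain stratified random split.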

**Expected output:** split sizes and per-split label distribution sanity checks.

In [11]:
# Group-aware split (recommended): prevents duplicate/near-duplicate leakage
# Two-stage procedure: carve out the test fold first, then carve the validation
# fold out of what is left. Each stage takes the FIRST fold of a
# StratifiedGroupKFold whose fold count approximates the requested fraction
# (n_splits ~= 1 / fraction), so the resulting sizes are approximate.
if bool(globals().get('USE_GROUP_SPLIT', False)) and ('group_id' in df.columns):
    try:
        from sklearn.model_selection import StratifiedGroupKFold
        y = df['label']
        groups = df['group_id']
        n_splits_test = max(2, int(round(1.0 / float(TEST_SIZE))))
        sgkf_test = StratifiedGroupKFold(n_splits=n_splits_test, shuffle=True, random_state=SPLIT_SEED)
        train_val_idx, test_idx = next(sgkf_test.split(df, y=y, groups=groups))
        train_val_df = df.iloc[train_val_idx].reset_index(drop=True)
        test_df = df.iloc[test_idx].reset_index(drop=True)

        # VAL_SIZE is a fraction of the full dataset; rescale it to a fraction
        # of the remaining train+val pool.
        val_relative = VAL_SIZE / (1.0 - TEST_SIZE)
        n_splits_val = max(2, int(round(1.0 / float(val_relative))))
        sgkf_val = StratifiedGroupKFold(n_splits=n_splits_val, shuffle=True, random_state=SPLIT_SEED)
        y_tv = train_val_df['label']
        g_tv = train_val_df['group_id']
        train_idx, val_idx = next(sgkf_val.split(train_val_df, y=y_tv, groups=g_tv))
        train_df = train_val_df.iloc[train_idx].reset_index(drop=True)
        val_df = train_val_df.iloc[val_idx].reset_index(drop=True)

        # Safety: ensure no group leaks across splits
        tr_g = set(train_df['group_id'].tolist())
        va_g = set(val_df['group_id'].tolist())
        te_g = set(test_df['group_id'].tolist())
        assert tr_g.isdisjoint(va_g) and tr_g.isdisjoint(te_g) and va_g.isdisjoint(te_g)
        print('Used StratifiedGroupKFold (leakage-safe).')
    except Exception as e:
        # Any failure above (import error, too few groups per class, or the
        # leakage assertion) lands here and triggers the fallback below.
        print('Group-aware split requested, but failed; falling back to random stratified split. Error:', repr(e))
        # NOTE(review): this overwrites the config flag set in the grouping cell,
        # so re-running only this cell will skip the group-aware path until that
        # cell is re-run. Consider a local "split succeeded" flag instead.
        USE_GROUP_SPLIT = False

# Fallback: classic stratified random split
# NOTE(review): unlike the group-aware path, these frames keep their original
# (non-reset) index.
if not bool(globals().get('USE_GROUP_SPLIT', False)) or ('group_id' not in df.columns):
    train_val_df, test_df = train_test_split(
        df,
        test_size=TEST_SIZE,
        random_state=SPLIT_SEED,
        stratify=df['label'],
    )

    val_relative = VAL_SIZE / (1.0 - TEST_SIZE)
    train_df, val_df = train_test_split(
        train_val_df,
        test_size=val_relative,
        random_state=SPLIT_SEED,
        stratify=train_val_df['label'],
    )

# Final split sizes, whichever path produced them.
print('Train:', len(train_df), 'Val:', len(val_df), 'Test:', len(test_df))

def show_split_stats(name, split_df):
    """Print the per-label share of rows in one split.

    Args:
        name: Split name used in the header line.
        split_df: DataFrame with a `label` column.
    """
    label_fractions = split_df['label'].value_counts(normalize=True)
    label_fractions = label_fractions.sort_index()
    header = f'-- {name} label distribution (fraction) --'
    print(header)
    print(label_fractions)

# Label balance per split, as a stratification sanity check.
for split_label, split_frame in (('train', train_df), ('val', val_df), ('test', test_df)):
    show_split_stats(split_label, split_frame)

## Cell 7 — Label mapping
Creates a stable `label2id` and `id2label`.

**Expected output:** the mapping table.

In [12]:
# Alphabetical ordering keeps the label -> id assignment stable across runs.
labels_sorted = sorted(df['label'].unique().tolist())
label2id = dict(zip(labels_sorted, range(len(labels_sorted))))
id2label = dict(enumerate(labels_sorted))

print('Labels:', labels_sorted)
print('label2id:', label2id)

def encode_labels(split_df):
    """Return a copy of `split_df` with an integer `label_id` column.

    Ids come from the module-level `label2id` mapping.

    Raises:
        ValueError: If a label is not present in `label2id`.
    """
    label_ids = split_df['label'].map(label2id)
    # `map` yields NaN for any label missing from the mapping.
    if label_ids.isna().any():
        raise ValueError('Found unknown labels during encoding.')
    return split_df.assign(label_id=label_ids.astype(int))

# Add `label_id` to every split, then show the mapping as a table.
train_df, val_df, test_df = (
    encode_labels(split_frame) for split_frame in (train_df, val_df, test_df)
)

display(pd.DataFrame({'label': labels_sorted, 'id': list(map(label2id.get, labels_sorted))}))

## Cell 8 — Metrics + plotting helpers
**Expected output:** none (helpers only).

In [13]:
def compute_metrics(eval_pred):
    """Turn (logits, labels) into accuracy, macro-F1 and weighted-F1.

    Args:
        eval_pred: Pair `(logits, labels)`; the predicted class is the
            argmax of the logits along axis 1.

    Returns:
        Dict with keys `accuracy`, `macro_f1` and `weighted_f1`.
    """
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=1)
    metrics = {'accuracy': accuracy_score(labels, preds)}
    for averaging in ('macro', 'weighted'):
        metrics[f'{averaging}_f1'] = f1_score(labels, preds, average=averaging)
    return metrics

def save_json(path, obj):
    """Write `obj` to `path` as UTF-8 JSON with 2-space indentation.

    Non-ASCII characters are written verbatim rather than escaped.
    """
    with open(path, mode='w', encoding='utf-8') as handle:
        json.dump(obj, handle, ensure_ascii=False, indent=2)

def plot_and_save_confusion_matrix(y_true, y_pred, labels, out_path):
    """Plot the test confusion matrix as a heatmap, save it, and show it.

    Args:
        y_true: True class ids. Assumed to be integers in
            `0..len(labels)-1`, which is how this notebook calls it.
        y_pred: Predicted class ids, same encoding as `y_true`.
        labels: Class names ordered by id; used as tick labels.
        out_path: Destination image path.
    """
    # Pin the matrix to the full label set. Left to infer the classes,
    # `confusion_matrix` drops any class missing from both y_true and y_pred,
    # which shrinks the matrix and attaches tick labels to the wrong rows/columns.
    class_ids = list(range(len(labels)))
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', xticklabels=labels, yticklabels=labels, cmap='Blues')
    plt.title('Confusion Matrix (Test)')
    plt.ylabel('True')
    plt.xlabel('Predicted')
    plt.tight_layout()
    plt.savefig(out_path, dpi=150)
    plt.show()

def plot_and_save_learning_curves(trainer_state_log_history, out_path):
    """Plot train and eval loss against the training step, save, and show.

    Args:
        trainer_state_log_history: Iterable of log dicts. Entries with
            `eval_loss` feed the eval curve; entries with `loss` but no
            `eval_loss` feed the train curve. Entries without `step` are
            ignored.
        out_path: Destination image path.
    """
    steps, train_loss = [], []
    eval_steps, eval_loss = [], []

    for entry in trainer_state_log_history:
        if 'step' not in entry:
            continue
        if 'eval_loss' in entry:
            eval_steps.append(entry['step'])
            eval_loss.append(entry['eval_loss'])
        elif 'loss' in entry:
            steps.append(entry['step'])
            train_loss.append(entry['loss'])

    plt.figure(figsize=(8, 5))
    for xs, ys, curve_name in (
        (steps, train_loss, 'train_loss'),
        (eval_steps, eval_loss, 'eval_loss'),
    ):
        # Skip a curve entirely when there is nothing to draw for it.
        if xs:
            plt.plot(xs, ys, label=curve_name)
    plt.title('Learning Curves')
    plt.xlabel('Step')
    plt.ylabel('Loss')
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path, dpi=150)
    plt.show()

## Cell 9 — Tokenization + dataset preparation helper
We tokenize separately per model (because tokenizers differ across model families).

**Expected output:** none (helpers only).

In [14]:
def df_to_hf_dataset(split_df):
    """Convert a split DataFrame into a Hugging Face `Dataset`.

    Only `text` and `label_id` are kept, and `label_id` is renamed to
    `labels`.
    """
    from datasets import Dataset

    model_frame = split_df[['text', 'label_id']]
    model_frame = model_frame.rename(columns={'label_id': 'labels'})
    return Dataset.from_pandas(model_frame.reset_index(drop=True))

# Untokenized HF datasets, shared by every model (tokenization is per model).
train_hf, val_hf, test_hf = (
    df_to_hf_dataset(split_frame) for split_frame in (train_df, val_df, test_df)
)

def make_tokenized_datasets(model_ckpt: str):
    """Tokenize the train/val/test datasets with the checkpoint's tokenizer.

    Texts are truncated to MAX_LENGTH tokens; padding is left to the returned
    collator. Only `input_ids`, `attention_mask` and `labels` are kept.

    Args:
        model_ckpt: Checkpoint name used to load the tokenizer.

    Returns:
        Tuple `(tokenizer, data_collator, train_tok, val_tok, test_tok)`.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt, use_fast=True)
    wanted_columns = ('input_ids', 'attention_mask', 'labels')

    def tokenize_batch(batch):
        return tokenizer(batch['text'], truncation=True, max_length=MAX_LENGTH)

    def _prepare(split):
        # Tokenize, then drop every column that is not in `wanted_columns`.
        tokenized = split.map(tokenize_batch, batched=True)
        surplus = [col for col in tokenized.column_names if col not in wanted_columns]
        return tokenized.remove_columns(surplus)

    train_tok, val_tok, test_tok = (_prepare(split) for split in (train_hf, val_hf, test_hf))
    return (
        tokenizer,
        DataCollatorWithPadding(tokenizer=tokenizer),
        train_tok,
        val_tok,
        test_tok,
    )

## Cell 10 — Training helper (Trainer)
This function runs a training configuration and saves artifacts into a per-run folder.

**Expected output:** training logs, evaluation metrics, and saved plots/files under `runs/`.

In [15]:
import inspect
import shutil
from dataclasses import dataclass
from datetime import datetime

@dataclass
class RunResult:
    """Outcome of one training run, as returned by `run_training`."""

    # Name of the run configuration (e.g. 'A_bert_baseline').
    run_name: str
    # Checkpoint the model was initialized from.
    model_ckpt: str
    # Folder where the run's artifacts were written.
    output_dir: str
    # Metrics returned by the final validation evaluation.
    eval_metrics: dict
    # Metrics computed on the held-out test split.
    test_metrics: dict

def run_training(run_cfg: dict) -> RunResult:
    """Fine-tune one configuration, evaluate it, and persist its artifacts.

    Reads the module-level settings (SEED, LR, EPOCHS, BATCH_SIZE, MAX_LENGTH,
    BASE_OUTPUT_DIR) and the prepared splits and label mappings.

    Args:
        run_cfg: Run configuration. Required keys: `run_name`, `model_ckpt`.
            Optional keys (default in parentheses): `dropout` (None = keep the
            checkpoint's values), `weight_decay` (0.0), `use_early_stopping`
            (False), `save_strategy` ("epoch"), `load_best_model_at_end`
            (None = enabled when early stopping is on and checkpoints are
            saved), `save_only_model` (True), `save_predictions` (True),
            `save_plots` (True), `save_final_model` (True),
            `cleanup_trainer_dir` (False), `cleanup_run_dir` (False).

    Returns:
        RunResult with the validation and test metrics and the run folder
        path (the folder itself may have been deleted when `cleanup_run_dir`
        is set).
    """
    run_name = run_cfg["run_name"]
    model_ckpt = run_cfg["model_ckpt"]

    # Re-seed before the model is built. The classification head is randomly
    # initialized inside `from_pretrained`, i.e. before the Trainer applies
    # its own seed, so without this the result of a run would depend on
    # whatever consumed random numbers earlier (e.g. previous runs).
    set_seed(SEED)

    timestamp = datetime.now().strftime("%Y-%m-%d_%H%M%S")
    out_dir = os.path.join(BASE_OUTPUT_DIR, f"{timestamp}_{run_name}_seed{SEED}")
    os.makedirs(out_dir, exist_ok=True)

    tokenizer, data_collator, train_tok, val_tok, test_tok = make_tokenized_datasets(model_ckpt)

    config = AutoConfig.from_pretrained(
        model_ckpt,
        num_labels=len(labels_sorted),
        id2label=id2label,
        label2id=label2id,
    )

    # Dropout attribute names differ between model families; override every
    # one that this config actually defines.
    if run_cfg.get("dropout") is not None:
        d = float(run_cfg["dropout"])
        for attr in (
            "hidden_dropout_prob",
            "attention_probs_dropout_prob",
            "dropout",
            "attention_dropout",
            "classifier_dropout",
        ):
            if hasattr(config, attr):
                setattr(config, attr, d)

    model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, config=config)

    trainer_dir = os.path.join(out_dir, "trainer")
    os.makedirs(trainer_dir, exist_ok=True)

    save_strategy = str(run_cfg.get("save_strategy", "epoch"))

    ta_kwargs = dict(
        output_dir=trainer_dir,
        learning_rate=LR,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=EPOCHS,
        weight_decay=float(run_cfg.get("weight_decay", 0.0)),
        save_strategy=save_strategy,
        load_best_model_at_end=False,
        metric_for_best_model="macro_f1",
        greater_is_better=True,
        logging_strategy="steps",
        logging_steps=10,
        save_total_limit=1,
        report_to=[],
        seed=SEED,
    )

    # The best checkpoint can only be restored when checkpoints are saved, so
    # the default follows early stopping unless saving is turned off.
    load_best = run_cfg.get("load_best_model_at_end", None)
    if load_best is None:
        load_best = bool(run_cfg.get("use_early_stopping", False)) and save_strategy != "no"
    ta_kwargs["load_best_model_at_end"] = bool(load_best)

    # Pass only the keyword names that the installed TrainingArguments accepts.
    ta_sig = inspect.signature(TrainingArguments.__init__)
    if "eval_strategy" in ta_sig.parameters:
        ta_kwargs["eval_strategy"] = "epoch"
    else:
        ta_kwargs["evaluation_strategy"] = "epoch"

    if "save_only_model" in ta_sig.parameters:
        ta_kwargs["save_only_model"] = bool(run_cfg.get("save_only_model", True))
    if "save_safetensors" in ta_sig.parameters:
        ta_kwargs["save_safetensors"] = True

    args = TrainingArguments(**ta_kwargs)

    callbacks = []
    if run_cfg.get("use_early_stopping", False):
        callbacks.append(EarlyStoppingCallback(early_stopping_patience=2))

    trainer_kwargs = dict(
        model=model,
        args=args,
        train_dataset=train_tok,
        eval_dataset=val_tok,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
    )
    # Same idea for Trainer: pass `tokenizer` only where it is still accepted.
    trainer_sig = inspect.signature(Trainer.__init__)
    if "tokenizer" in trainer_sig.parameters:
        trainer_kwargs["tokenizer"] = tokenizer

    trainer = Trainer(**trainer_kwargs)

    t0 = time.time()
    _train_output = trainer.train()
    train_time = time.time() - t0

    eval_metrics = trainer.evaluate()

    test_pred = trainer.predict(test_tok)
    test_logits = test_pred.predictions
    test_labels = test_pred.label_ids
    test_preds = np.argmax(test_logits, axis=1)

    # Pin the report to the full label set. Without `labels=`, a class absent
    # from both the test labels and the predictions makes the inferred class
    # count differ from `target_names`, and `classification_report` raises.
    class_ids = list(range(len(labels_sorted)))
    test_metrics = {
        "accuracy": float(accuracy_score(test_labels, test_preds)),
        "macro_f1": float(f1_score(test_labels, test_preds, average="macro")),
        "weighted_f1": float(f1_score(test_labels, test_preds, average="weighted")),
        "classification_report": classification_report(
            test_labels,
            test_preds,
            labels=class_ids,
            target_names=labels_sorted,
            output_dict=True,
            zero_division=0,
        ),
    }

    run_info = {
        "run_name": run_name,
        "model_ckpt": model_ckpt,
        "seed": SEED,
        "max_length": MAX_LENGTH,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "learning_rate": LR,
        "weight_decay": float(run_cfg.get("weight_decay", 0.0)),
        "use_early_stopping": bool(run_cfg.get("use_early_stopping", False)),
        "dropout": run_cfg.get("dropout", None),
        "train_time_sec": float(train_time),
        "label2id": label2id,
    }
    save_json(os.path.join(out_dir, "config.json"), run_info)
    save_json(os.path.join(out_dir, "eval_metrics.json"), eval_metrics)
    save_json(os.path.join(out_dir, "test_metrics.json"), test_metrics)

    if bool(run_cfg.get("save_predictions", True)):
        pred_df = test_df[["text", "label", "label_id"]].copy()
        # A NumPy array is assigned by position, so this relies on the
        # prediction order matching the row order of test_df.
        pred_df["pred_id"] = test_preds
        pred_df["pred_label"] = pred_df["pred_id"].map(id2label)
        pred_df.to_csv(os.path.join(out_dir, "test_predictions.csv"), index=False, encoding="utf-8")

    if bool(run_cfg.get("save_plots", True)):
        plot_and_save_confusion_matrix(
            test_labels, test_preds, labels_sorted, os.path.join(out_dir, "confusion_matrix.png")
        )
        plot_and_save_learning_curves(
            trainer.state.log_history, os.path.join(out_dir, "learning_curves.png")
        )

    if bool(run_cfg.get("save_final_model", True)):
        model_dir = os.path.join(out_dir, "model")
        trainer.save_model(model_dir)
        tokenizer.save_pretrained(model_dir)

    # Optional disk-saving steps (used by sweeps): drop the trainer's working
    # directory and/or the entire run folder.
    if bool(run_cfg.get("cleanup_trainer_dir", False)):
        try:
            shutil.rmtree(trainer_dir, ignore_errors=True)
        except Exception:
            pass

    cleaned_run_dir = False
    if bool(run_cfg.get("cleanup_run_dir", False)):
        try:
            shutil.rmtree(out_dir, ignore_errors=True)
            cleaned_run_dir = True
        except Exception:
            cleaned_run_dir = False

    if cleaned_run_dir:
        print("Sweep mode: deleted run folder:", out_dir)
    else:
        print("Saved run artifacts to:", out_dir)

    return RunResult(
        run_name=run_name,
        model_ckpt=model_ckpt,
        output_dir=out_dir,
        eval_metrics=eval_metrics,
        test_metrics=test_metrics,
    )

## Cell 11 — Run all experiments
This will train and evaluate each run in `RUNS`.

**Expected output:** training logs per run + saved artifacts under `runs/`.

In [None]:
# Train and evaluate each configuration in order, collecting the results.
results = []
banner = '=' * 90
for cfg in RUNS:
    print('\n' + banner)
    print('Starting:', cfg['run_name'], 'Model:', cfg['model_ckpt'])
    print(banner)
    res = run_training(cfg)
    results.append(res)

completed_names = [r.run_name for r in results]
print('\nDone. Completed runs:', completed_names)

## Cell 12 — Summarize results
We summarize eval/test metrics across runs in a table.

**Expected output:** a table comparing `accuracy`, `macro_f1`, and `weighted_f1`.

In [None]:
# One row per run. Validation metrics are stored under an 'eval_' prefix,
# test metrics under their plain names; anything missing becomes NaN.
metric_names = ('accuracy', 'macro_f1', 'weighted_f1')
summary_rows = []
for r in results:
    row = {'run_name': r.run_name, 'model': r.model_ckpt}
    for metric in metric_names:
        row[f'eval_{metric}'] = float(r.eval_metrics.get(f'eval_{metric}', np.nan))
    for metric in metric_names:
        row[f'test_{metric}'] = float(r.test_metrics.get(metric, np.nan))
    row['output_dir'] = r.output_dir
    summary_rows.append(row)

# Best test macro-F1 first.
summary_df = pd.DataFrame(summary_rows).sort_values(by='test_macro_f1', ascending=False)
display(summary_df)

# Persist the (unsorted) rows next to the run folders.
summary_path = os.path.join(BASE_OUTPUT_DIR, 'summary_latest.json')
save_json(summary_path, summary_rows)
print('Saved summary to:', summary_path)

## Cell 13 — Hyperparameter sweep runner
This runs a **small grid search** over a few values of `(dropout, weight_decay, learning_rate)` and repeats each setting across multiple seeds.

**Selection uses VALIDATION macro-F1** (not test), so we keep test as an unbiased final evaluation.

**Tip:** keep this small. Each trial is a full fine-tune run.

In [None]:
import itertools
import math
import shutil

# ----- Sweep hyperparameter grid -----
# Total trials = len(SWEEP_DROPOUT) * len(SWEEP_WEIGHT_DECAY) * len(SWEEP_LR) * len(SWEEP_SEEDS)
# With the values below that is 2 * 2 * 2 * 3 = 24 trials.
SWEEP_DROPOUT = [0.0, 0.1]
SWEEP_WEIGHT_DECAY = [0.0, 0.001]
SWEEP_LR = [1e-5, 2e-5]
SWEEP_SEEDS = [13, 42, 123]  # 3 seeds minimum for stability

# Upper bound on the grid size: this cell raises ValueError before any training
# starts if the grid contains more trials than this.
MAX_TRIALS = 50  # safety guard

# Disk safety
# Old sweep folders (directories under BASE_OUTPUT_DIR whose name contains '_S_')
# are deleted only when the flag is True AND free space is below SWEEP_MIN_FREE_GB.
SWEEP_DELETE_OLD_SWEEP_RUNS = True
SWEEP_MIN_FREE_GB = 1.0

# For sweeps, skip heavy artifacts per trial
# These flags are forwarded to run_training through base_cfg (keys 'save_final_model',
# 'save_predictions', 'save_plots', 'cleanup_trainer_dir', 'cleanup_run_dir').
# NOTE(review): how run_training acts on each key is defined outside this cell — confirm there.
SWEEP_SAVE_FINAL_MODEL = False
SWEEP_SAVE_PREDICTIONS = False
SWEEP_SAVE_PLOTS = False
SWEEP_CLEANUP_TRAINER_DIR = True
# When True, the sweep table records output_dir as None for every trial.
SWEEP_CLEANUP_RUN_DIR = True

def _disk_free_gb(path: str) -> float:
    du = shutil.disk_usage(path)
    return float(du.free) / (1024 ** 3)

def _delete_old_sweep_runs(base_dir: str) -> int:
    if not os.path.isdir(base_dir):
        return 0
    deleted = 0
    for name in os.listdir(base_dir):
        full = os.path.join(base_dir, name)
        if os.path.isdir(full) and '_S_' in name:
            shutil.rmtree(full, ignore_errors=True)
            deleted += 1
    return deleted

def _fmt_float(x) -> str:
    if x is None:
        return 'none'
    if isinstance(x, (int, np.integer)):
        return str(int(x))
    if math.isclose(float(x), 0.0):
        return '0'
    return f'{float(x):g}'.replace('.', 'p')

def _make_sweep_run_name(dropout, weight_decay, lr) -> str:
    """Build the run name for one sweep setting: ``S_bert_do<..>_wd<..>_lr<..>``.

    The seed is not part of the name, so all seeds of one setting share it.
    """
    parts = (
        'S',
        'bert',
        'do' + _fmt_float(dropout),
        'wd' + _fmt_float(weight_decay),
        'lr' + _fmt_float(lr),
    )
    return '_'.join(parts)

# Settings shared by every sweep trial; run_name, dropout and weight_decay are
# overridden per trial inside the loop below.
base_cfg = {
    'run_name': 'SWEEP',
    'model_ckpt': 'bert-base-uncased',
    'use_early_stopping': True,
    'weight_decay': 0.0,
    'dropout': None,
    'save_strategy': 'no',
    'load_best_model_at_end': False,
    'save_only_model': True,
    'save_final_model': SWEEP_SAVE_FINAL_MODEL,
    'save_predictions': SWEEP_SAVE_PREDICTIONS,
    'save_plots': SWEEP_SAVE_PLOTS,
    'cleanup_trainer_dir': SWEEP_CLEANUP_TRAINER_DIR,
    'cleanup_run_dir': SWEEP_CLEANUP_RUN_DIR,
}

# One trial per (dropout, weight_decay, lr, seed) combination.
grid = list(itertools.product(SWEEP_DROPOUT, SWEEP_WEIGHT_DECAY, SWEEP_LR, SWEEP_SEEDS))

# Best-effort disk check: a failure here (e.g. BASE_OUTPUT_DIR missing) must not block the sweep.
try:
    free_gb = _disk_free_gb(BASE_OUTPUT_DIR)
    print(f'Free disk space: {free_gb:.2f} GB')
    if SWEEP_DELETE_OLD_SWEEP_RUNS and free_gb < SWEEP_MIN_FREE_GB:
        n_del = _delete_old_sweep_runs(BASE_OUTPUT_DIR)
        print(f'Deleted {n_del} old sweep folders. Free now: {_disk_free_gb(BASE_OUTPUT_DIR):.2f} GB')
except Exception as e:
    print('Disk check skipped:', repr(e))

print(f'Planned trials: {len(grid)}')
if len(grid) > MAX_TRIALS:
    raise ValueError(f'Sweep too large ({len(grid)} > {MAX_TRIALS}). Reduce grid or raise MAX_TRIALS.')

# The learning rate and seed are not part of trial_cfg; they are passed by
# temporarily overwriting the module-level SEED / LR before each run_training call
# (presumably run_training reads those globals — confirm in its definition).
_orig_seed = SEED
_orig_lr = LR

sweep_rows = []
sweep_results = []

# try/finally guarantees the globals are restored even when a trial raises or the
# cell is interrupted; otherwise later cells would silently run with a trial's SEED/LR.
try:
    for dropout, weight_decay, lr, seed in grid:
        SEED = int(seed)
        LR = float(lr)
        set_seed(SEED)

        trial_cfg = dict(base_cfg)
        trial_cfg['run_name'] = _make_sweep_run_name(dropout, weight_decay, lr)
        trial_cfg['dropout'] = float(dropout)
        trial_cfg['weight_decay'] = float(weight_decay)

        print('\n' + '-'*90)
        print('Trial:', trial_cfg['run_name'], 'seed=', SEED)
        print('Params:', {'dropout': dropout, 'weight_decay': weight_decay, 'lr': lr})
        print('-'*90)

        rr = run_training(trial_cfg)
        sweep_results.append(rr)

        # Record VALIDATION metrics for selection (not test).
        sweep_rows.append({
            'run_name': rr.run_name,
            'model': rr.model_ckpt,
            'dropout': float(dropout),
            'weight_decay': float(weight_decay),
            'lr': float(lr),
            'seed': int(seed),
            # Selection metrics (validation)
            'val_accuracy': float(rr.eval_metrics.get('eval_accuracy', np.nan)),
            'val_macro_f1': float(rr.eval_metrics.get('eval_macro_f1', np.nan)),
            'val_weighted_f1': float(rr.eval_metrics.get('eval_weighted_f1', np.nan)),
            # Also record test for reporting (but NOT used to pick best)
            'test_accuracy': float(rr.test_metrics.get('accuracy', np.nan)),
            'test_macro_f1': float(rr.test_metrics.get('macro_f1', np.nan)),
            'test_weighted_f1': float(rr.test_metrics.get('weighted_f1', np.nan)),
            'output_dir': None if SWEEP_CLEANUP_RUN_DIR else rr.output_dir,
        })
finally:
    SEED = _orig_seed
    LR = _orig_lr
    set_seed(SEED)

sweep_df = pd.DataFrame(sweep_rows)
display(sweep_df.sort_values(by='val_macro_f1', ascending=False))

# --- Aggregate across seeds using VALIDATION macro-F1 ---
group_cols = ['dropout', 'weight_decay', 'lr']
agg_df = (
    sweep_df.groupby(group_cols)
    .agg(
        n_trials=('val_macro_f1', 'count'),
        mean_val_macro_f1=('val_macro_f1', 'mean'),
        std_val_macro_f1=('val_macro_f1', 'std'),
        mean_val_accuracy=('val_accuracy', 'mean'),
        # Also aggregate test for later reference
        mean_test_macro_f1=('test_macro_f1', 'mean'),
        std_test_macro_f1=('test_macro_f1', 'std'),
    )
    # Ties on mean validation macro-F1 are broken by mean validation accuracy.
    .sort_values(by=['mean_val_macro_f1', 'mean_val_accuracy'], ascending=False)
    .reset_index()
)
display(agg_df)

# Top row of the aggregated table, or None when the table is empty.
best = agg_df.iloc[0].to_dict() if len(agg_df) else None
print('Best setting by mean VALIDATION macro-F1:')
print(best)

## Final best run 

In [16]:
# Final run with the hyperparameters selected by the sweep.
# The values are pinned here so this cell can run without repeating the (expensive)
# sweep. If the sweep cell has been run in this kernel, its `best` setting is
# compared against the pinned values and any mismatch is reported.
import math

FINAL_DROPOUT = 0.1
FINAL_WEIGHT_DECAY = 0.001
FINAL_LR = 2e-5
FINAL_SEED = 42

final_cfg = {
    'run_name': 'FINAL_best_from_sweep',
    'model_ckpt': 'bert-base-uncased',
    'use_early_stopping': True,
    'dropout': float(FINAL_DROPOUT),
    'weight_decay': float(FINAL_WEIGHT_DECAY),
    'save_predictions': True,
    'save_plots': True,
    'save_final_model': True,
    'save_strategy': 'epoch',
    'load_best_model_at_end': True,
    'save_only_model': True,
    'cleanup_trainer_dir': True,
    'cleanup_run_dir': False,
}

# Surface drift between the pinned values and the sweep result instead of letting
# the two silently diverge.
_sweep_best = globals().get('best')
if isinstance(_sweep_best, dict):
    _pinned = {'dropout': FINAL_DROPOUT, 'weight_decay': FINAL_WEIGHT_DECAY, 'lr': FINAL_LR}
    _mismatch = {
        key: {'sweep_best': float(_sweep_best[key]), 'pinned': float(value)}
        for key, value in _pinned.items()
        if key in _sweep_best and not math.isclose(float(_sweep_best[key]), float(value))
    }
    if _mismatch:
        print('WARNING: pinned final hyperparameters differ from the sweep `best` setting:', _mismatch)
else:
    print('Note: sweep result `best` not found in this kernel; using the pinned final hyperparameters.')

# The seed and learning rate are passed by temporarily overwriting the module-level
# SEED / LR (presumably read by run_training — confirm in its definition).
_orig_seed = SEED
_orig_lr = LR
SEED = FINAL_SEED
LR = float(FINAL_LR)
set_seed(SEED)

print('Final run config:', {'dropout': final_cfg['dropout'], 'weight_decay': final_cfg['weight_decay']}, 'LR=', LR, 'SEED=', SEED)
try:
    final_result = run_training(final_cfg)
finally:
    # Restore globals even if training fails, so later cells do not inherit this run's SEED/LR.
    SEED = _orig_seed
    LR = _orig_lr
    set_seed(SEED)

print('\n=== Final Test Metrics (real-world estimate) ===')
print('test_accuracy:', final_result.test_metrics.get('accuracy'))
print('test_macro_f1:', final_result.test_metrics.get('macro_f1'))
print('test_weighted_f1:', final_result.test_metrics.get('weighted_f1'))
print('\nFinal run saved at:', final_result.output_dir)

## Re-run FINAL config across 3 seeds (report mean/std)
This runs the *same* final hyperparameters across multiple **training seeds** on the **same fixed split** (controlled by `SPLIT_SEED`).

Report the stability as mean/std of **test macro-F1** across seeds.

In [14]:
# Run the final config across multiple training seeds and report mean/std of the test metrics.
import math
from datetime import datetime
from pathlib import Path  # used for the CSV path below; not imported by the notebook's top import cell

FINAL_SEEDS = [13, 42, 123]  # adjust if you want more

# final_cfg is defined by the "Final best run" cell; fail early with a clear message if it is missing.
if 'final_cfg' not in globals():
    raise RuntimeError('final_cfg not found. Run the Final config cell first.')

_orig_seed = SEED
_orig_lr = LR

rows = []
results_multi = []
# try/finally restores the module-level SEED / LR even when one of the runs raises.
try:
    for s in FINAL_SEEDS:
        SEED = int(s)
        LR = float(2e-5)  # keep fixed for fair seed comparison
        set_seed(SEED)

        cfg = dict(final_cfg)
        cfg['run_name'] = f"FINAL_seed{SEED}"
        print('\n' + '='*90)
        print('Running:', cfg['run_name'], 'SPLIT_SEED=', SPLIT_SEED, 'TRAIN_SEED=', SEED, 'LR=', LR)
        print('='*90)
        rr = run_training(cfg)
        results_multi.append(rr)
        rows.append({
            'seed': SEED,
            'test_accuracy': float(rr.test_metrics.get('accuracy', math.nan)),
            'test_macro_f1': float(rr.test_metrics.get('macro_f1', math.nan)),
            'test_weighted_f1': float(rr.test_metrics.get('weighted_f1', math.nan)),
            'val_macro_f1': float(rr.eval_metrics.get('eval_macro_f1', math.nan)),
            'output_dir': rr.output_dir,
        })
finally:
    # Restore globals
    SEED = _orig_seed
    LR = _orig_lr
    set_seed(SEED)

multi_df = pd.DataFrame(rows).sort_values('seed')
display(multi_df)

# Sample standard deviation (ddof=1); this is NaN when only one seed is run.
print('\n=== Stability summary (across seeds) ===')
print('macro-F1 mean:', float(multi_df['test_macro_f1'].mean()))
print('macro-F1 std :', float(multi_df['test_macro_f1'].std(ddof=1)))
print('acc mean    :', float(multi_df['test_accuracy'].mean()))
print('acc std     :', float(multi_df['test_accuracy'].std(ddof=1)))

# Timestamped file name so repeated executions do not overwrite earlier summaries.
out_path = Path(BASE_OUTPUT_DIR) / f"final_multiseed_{datetime.now().strftime('%Y-%m-%d_%H%M%S')}.csv"
multi_df.to_csv(out_path, index=False, encoding='utf-8')
print('Saved multiseed summary to:', str(out_path))

## Real-life test: run predictions on your own texts/CSVs
Use this when you have new DREAMS-like inputs (captions or dream narratives) that are **not** from the training dataset.

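It writes `external_eval_predictions.csv` (under `ml_experiments_anish/`) containing the predicted CHIME label and its confidence for each input row. If the input includes labels that match the model's label names, accuracy and F1 scores are printed as well.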

In [None]:
import math
from pathlib import Path

import pandas as pd
import torch
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# A) External data source
# When True the data is loaded from the Hugging Face dataset below; when False from EXTERNAL_CSV.
USE_HF_EXTERNAL_DATASET = True
HF_EXTERNAL_DATASET_NAME = 'ashh007/CHIME_external_evaluation'
# If this split is absent, _load_external_df uses the first split the dataset provides.
HF_EXTERNAL_SPLIT = 'train'  # change if your dataset uses a different split
# Column-name candidates, checked in order; the first one present is renamed to 'text' / 'label'.
HF_TEXT_COLUMN_CANDIDATES = ['text', 'CAPTIONS', 'captions', 'caption', 'sentence', 'content']
HF_LABEL_COLUMN_CANDIDATES = ['label', 'labels', 'LABEL', 'Labels', 'category', 'class', 'target']
HF_EXTERNAL_FILENAME = 'external_eval.csv'  # used by fallback loader (raw file download)

# If you want local CSV instead, set USE_HF_EXTERNAL_DATASET=False
EXTERNAL_CSV = Path('ml_experiments_anish') / 'external_eval.csv'
# The text column is required; the label column is optional and only used for scoring.
CSV_TEXT_COLUMN = 'text'
CSV_LABEL_COLUMN = 'labels'  # renamed to 'label' when present in the CSV

# B) Model source
# IMPORTANT: you must have a trained model accessible either locally OR on Hugging Face Hub.
# Resolution order in _resolve_model_dir: HF Hub id (if USE_HF_MODEL), then LOCAL_MODEL_DIR,
# then final_result from this kernel, then a search of the default local runs folder.
USE_HF_MODEL = False
HF_MODEL_ID = 'PUT_YOUR_USERNAME/PUT_YOUR_MODEL'  # if USE_HF_MODEL=True, set this to your fine-tuned model repo
LOCAL_MODEL_DIR = None  # set to a specific .../model folder if you have one locally

# Tokenizer truncation length: reuses the notebook's MAX_LENGTH when that global exists, else 128.
MAX_LEN = int(globals().get('MAX_LENGTH', 128))

def _pick_first_existing(columns: list[str], candidates: list[str]) -> str | None:
    colset = set(columns)
    for c in candidates:
        if c in colset:
            return c
    return None

def _unquote_csv_field(field: str) -> str:
    """Strip one pair of surrounding double quotes and unescape doubled quotes (``""`` -> ``"``)."""
    if len(field) >= 2 and field[0] == '"' and field[-1] == '"':
        return field[1:-1].replace('""', '"')
    return field


def _load_hf_external_fallback(repo_id: str) -> pd.DataFrame:
    """Fallback for HF dataset repos that contain a CSV with unquoted commas in the text field.

    The HF `datasets` CSV loader uses pandas' C-engine which expects proper quoting. If the
    file is `text,label` but text contains commas and isn't quoted, it will fail. This loader
    downloads the raw file and splits each row on the *last* comma, assuming the label field
    never contains commas.

    Args:
        repo_id: Hugging Face dataset repository id; the file ``HF_EXTERNAL_FILENAME`` is
            downloaded from it.

    Returns:
        DataFrame with string columns ``text`` and ``label``. When the header names the label
        column something else, that original name is kept as an additional column.

    Raises:
        RuntimeError: if no data row could be parsed.
    """
    from huggingface_hub import hf_hub_download

    csv_path = hf_hub_download(
        repo_id=repo_id, repo_type='dataset', filename=HF_EXTERNAL_FILENAME
    )
    rows: list[dict[str, str]] = []
    bad = 0
    with open(csv_path, 'r', encoding='utf-8') as f:
        header = f.readline().strip()
        # Expect 2 columns; use last header token as label name when possible.
        # Header tokens are unquoted too, so a header like "text","label" yields `label`.
        header_parts = [_unquote_csv_field(h.strip()) for h in header.split(',') if h.strip()]
        label_name = header_parts[-1] if len(header_parts) >= 2 else 'label'
        for line in f:
            line = line.rstrip('\n')
            if not line.strip():
                continue
            if ',' not in line:
                bad += 1
                continue
            text_part, label_part = line.rsplit(',', 1)
            # Strip surrounding quotes if present (common when partially cleaned). The label
            # is unquoted as well: a label left as '"Hope"' would never match the model's
            # label names during scoring.
            rows.append({
                'text': _unquote_csv_field(text_part.strip()),
                'label': _unquote_csv_field(label_part.strip()),
            })
    if bad:
        print(f'Fallback parser skipped {bad} malformed line(s).')
    df = pd.DataFrame(rows)
    if df.empty:
        raise RuntimeError(f'Fallback parser produced 0 rows from {repo_id}/{HF_EXTERNAL_FILENAME}')
    df['label'] = df['label'].astype(str)
    # Keep a copy under the original label column name. Skip this when the name would
    # collide with the normalized 'text'/'label' columns used by downstream code.
    if label_name not in ('label', 'text'):
        df[label_name] = df['label']
        df = df[['text', label_name, 'label']]
    df['text'] = df['text'].fillna('').astype(str)
    return df

def _load_external_df() -> pd.DataFrame:
    """Load the external evaluation data and normalize its column names.

    The source is chosen by ``USE_HF_EXTERNAL_DATASET``: either the Hugging Face dataset
    ``HF_EXTERNAL_DATASET_NAME`` (falling back to the raw-CSV parser if loading fails) or the
    local file ``EXTERNAL_CSV``.

    Returns:
        DataFrame with a string ``text`` column and, when a label column was found, a string
        ``label`` column. Missing texts become the empty string.

    Raises:
        ValueError: if the dataset name is still a placeholder, or the local CSV lacks the
            text column.
        FileNotFoundError: if the local CSV does not exist.
    """
    if USE_HF_EXTERNAL_DATASET:
        if HF_EXTERNAL_DATASET_NAME.startswith('PUT_YOUR_'):
            raise ValueError('Set HF_EXTERNAL_DATASET_NAME to your dataset id, e.g. "yourname/external-eval"')
        try:
            ds = load_dataset(HF_EXTERNAL_DATASET_NAME)
            # Use the configured split when available, otherwise the first one offered.
            split = HF_EXTERNAL_SPLIT if HF_EXTERNAL_SPLIT in ds else list(ds.keys())[0]
            base = ds[split]
            df_ext = base.to_pandas()
            tcol = _pick_first_existing(list(df_ext.columns), HF_TEXT_COLUMN_CANDIDATES)
            lcol = _pick_first_existing(list(df_ext.columns), HF_LABEL_COLUMN_CANDIDATES)
            if tcol is None:
                raise ValueError(f'Could not find a text column in HF dataset. Found: {list(df_ext.columns)}')
            out = df_ext.copy()
            out = out.rename(columns={tcol: 'text', **({lcol: 'label'} if lcol else {})})
            # fillna must run BEFORE astype(str): casting first turns missing values into the
            # literal strings 'nan'/'None', which fillna can no longer detect.
            out['text'] = out['text'].fillna('').astype(str)
            if 'label' in out.columns:
                out['label'] = out['label'].astype(str)
            return out
        except Exception as e:
            # Deliberate best-effort: any failure above triggers the raw-CSV fallback.
            # This is commonly caused by unquoted commas in the text field of a CSV file
            msg = str(e)
            print('HF external dataset load failed; using fallback raw CSV parser.')
            print('Original error (truncated):', msg[:300])
            return _load_hf_external_fallback(HF_EXTERNAL_DATASET_NAME)
    else:
        if not EXTERNAL_CSV.exists():
            raise FileNotFoundError(f'External CSV not found at: {EXTERNAL_CSV.resolve()}')
        # Use python engine to be tolerant of commas in text if the file isn't perfectly quoted
        df_ext = pd.read_csv(EXTERNAL_CSV, engine='python')
        if CSV_TEXT_COLUMN not in df_ext.columns:
            raise ValueError(f"External CSV must contain column '{CSV_TEXT_COLUMN}'. Found: {list(df_ext.columns)}")
        out = df_ext.copy()
        out = out.rename(columns={CSV_TEXT_COLUMN: 'text'})
        # Same ordering as above: fill missing values first, then cast to str.
        out['text'] = out['text'].fillna('').astype(str)
        if CSV_LABEL_COLUMN and (CSV_LABEL_COLUMN in out.columns):
            out = out.rename(columns={CSV_LABEL_COLUMN: 'label'})
            out['label'] = out['label'].astype(str)
        return out

def _resolve_model_dir() -> "Path | PurePosixPath":
    """Locate the fine-tuned model to use for external predictions.

    Resolution order:
      1. Hugging Face Hub id ``HF_MODEL_ID`` when ``USE_HF_MODEL`` is True.
      2. ``LOCAL_MODEL_DIR`` when set (must exist).
      3. ``<final_result.output_dir>/model`` when ``final_result`` exists in this kernel.
      4. The last ``model`` folder (ordered by parent folder name) under the default runs directory.

    Returns:
        A path-like reference; callers convert it with ``str()`` before loading. For a hub id
        this is a ``PurePosixPath`` so the id keeps its forward slash on every platform.

    Raises:
        ValueError: if ``HF_MODEL_ID`` is still the placeholder.
        FileNotFoundError: if ``LOCAL_MODEL_DIR`` is set but missing, or no model is found.
    """
    from pathlib import PurePosixPath

    # 1) Hugging Face model (best if you trained on Colab and pushed the model)
    if USE_HF_MODEL:
        if HF_MODEL_ID.startswith('PUT_YOUR_'):
            raise ValueError('Set HF_MODEL_ID to your fine-tuned model id, e.g. "yourname/chime-bert"')
        # A hub id is not a filesystem path: Path() would render "user/model" as
        # "user\model" on Windows once converted with str(), which is not a valid repo id.
        return PurePosixPath(HF_MODEL_ID)

    # 2) Explicit local model dir (if you copied from Colab or trained locally)
    if LOCAL_MODEL_DIR is not None:
        p = Path(LOCAL_MODEL_DIR)
        if p.exists():
            return p
        raise FileNotFoundError(f'LOCAL_MODEL_DIR does not exist: {p}')

    # 3) Try to use final_result if present in this kernel
    if 'final_result' in globals() and getattr(final_result, 'output_dir', None):
        cand = Path(final_result.output_dir) / 'model'
        if cand.exists():
            return cand

    # 4) Auto-search common local runs folder
    default_runs = Path('ml_experiments_anish') / 'experiment1_chime_text_overfitting' / 'runs'
    if default_runs.exists():
        # Ordered by run folder name; the last entry is returned.
        model_dirs = sorted(default_runs.glob('**/model'), key=lambda p: p.parent.name)
        if model_dirs:
            return model_dirs[-1]

    raise FileNotFoundError(
        'No trained model found. If you trained on Colab, either (a) push your model to Hugging Face Hub and set USE_HF_MODEL=True, '
        'or (b) copy the saved run folder locally and set LOCAL_MODEL_DIR to that folder.'
    )

def predict_texts(texts: list[str], model_ref, batch_size: int = 16):
    """Classify ``texts`` with the sequence-classification model found at ``model_ref``.

    Args:
        texts: input strings, processed in chunks of ``batch_size``.
        model_ref: local folder or hub id; converted with ``str()`` before loading.
        batch_size: number of texts tokenized and scored per forward pass.

    Returns:
        ``(pred_ids, confidences, id2label, label2id)`` where ``pred_ids`` holds the argmax
        class index per text, ``confidences`` the matching softmax probability, and the two
        dicts are taken from the model config with integer ids.
    """
    source = str(model_ref)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = AutoTokenizer.from_pretrained(source, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(source).to(device)
    model.eval()

    # Config mappings may carry string keys/values; normalize ids to int.
    id2label = {int(idx): name for idx, name in getattr(model.config, 'id2label', {}).items()}
    label2id = {name: int(idx) for name, idx in getattr(model.config, 'label2id', {}).items()}

    all_pred_ids: list[int] = []
    all_conf: list[float] = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            chunk = texts[start : start + batch_size]
            encoded = tokenizer(chunk, truncation=True, max_length=MAX_LEN, padding=True, return_tensors='pt')
            encoded = {key: tensor.to(device) for key, tensor in encoded.items()}
            probabilities = torch.softmax(model(**encoded).logits, dim=-1)
            # Highest class probability and its index for every row of the batch.
            top_prob, top_idx = torch.max(probabilities, dim=-1)
            all_pred_ids.extend(int(i) for i in top_idx.cpu().tolist())
            all_conf.extend(float(c) for c in top_prob.cpu().tolist())
    return all_pred_ids, all_conf, id2label, label2id

# Run external eval: load the data, resolve the model, predict, optionally score, then save.
ext = _load_external_df()
model_ref = _resolve_model_dir()
print('External rows:', len(ext))
print('Model source:', ('HF Hub' if USE_HF_MODEL else 'local/search'))
print('Model ref:', str(model_ref))

pred_ids, conf, id2label, label2id = predict_texts(ext['text'].tolist(), model_ref=model_ref, batch_size=16)
ext['pred_id'] = pred_ids
# Fall back to the numeric id as text when the model config has no name for it.
ext['pred_label'] = [id2label.get(int(i), str(i)) for i in ext['pred_id'].tolist()]
ext['pred_conf'] = conf

# Optional scoring if labels available
if 'label' in ext.columns:
    # Labels that are not keys of label2id map to NaN and are excluded from scoring.
    label_ids = ext['label'].astype(str).map(label2id)
    ok = label_ids.notna()
    n_scored = int(ok.sum())
    n_skipped = int(len(ext) - n_scored)
    if ok.any():
        from sklearn.metrics import accuracy_score, f1_score
        y_true = label_ids[ok].astype(int).to_numpy()
        y_pred = ext.loc[ok, 'pred_id'].astype(int).to_numpy()
        print('External accuracy:', float(accuracy_score(y_true, y_pred)))
        print('External macro_f1:', float(f1_score(y_true, y_pred, average='macro')))
        print('External weighted_f1:', float(f1_score(y_true, y_pred, average='weighted')))
        # Make a partial label mismatch visible: the metrics above cover only the matched rows.
        if n_skipped:
            unknown = sorted(set(ext.loc[~ok, 'label'].astype(str)))
            print(f'Scored {n_scored} of {len(ext)} rows; skipped {n_skipped} row(s) whose label is not a model label.')
            print('Unrecognized label values (up to 10):', unknown[:10])
    else:
        print('External labels present, but they do not match model labels:', list(label2id.keys()))

# Save a local copy for inspection
out_path = Path('ml_experiments_anish') / 'external_eval_predictions.csv'
out_path.parent.mkdir(parents=True, exist_ok=True)
ext.to_csv(out_path, index=False, encoding='utf-8')
print('Wrote predictions to:', str(out_path.resolve()))
display(ext.head(10))