In [None]:
# ============================================================
# Jigsaw — SCRATCH Hyperparameter Tuning (Grid/Random, Resumable)
# * Windows/VS Code/Jupyter safe DataLoader (num_workers=0 on Windows)
# * tqdm console progress bars (no ipywidgets errors)
# * Early stopping
# * TensorBoard logging (train loss, val AUC/ACC, hparams snapshot)
# * Skips combos already logged in trial_results_scratch.csv
# * Optional per-trial checkpoints
# * Best trial of the session writes submission_scratch.(csv|xlsx)
# ============================================================

import os, re, json, time, random, hashlib, platform
from datetime import datetime, timezone
from itertools import product
from typing import Dict, Any, List

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score

# ------------------- Paths & switches -------------------
# Input files; their existence is asserted in the "Load data" section.
TRAIN_PATH = "train.csv"
TEST_PATH  = "test.csv"
SUB_PATH   = "sample_submission.csv"

RESULTS_CSV = "trial_results_scratch.csv"    # resumable log (append)
CHECKPOINT_DIR = "checkpoints_scratch"       # per-trial .pt files
SAVE_CHECKPOINTS = True

# TensorBoard
ENABLE_TENSORBOARD = True                    # turn on/off
TB_LOGDIR_BASE = "tb_scratch_tune"           # tensorboard --logdir tb_scratch_tune
TB_WRITE_HPARAMS = True

EARLY_STOP_PATIENCE = 3                      # epochs without AUC gain before stopping (0 to disable)
# NOTE(review): the __main__ block passes max_trials=200 as a literal, so this
# constant is currently not what caps a session.
MAX_TRIALS_PER_RUN  = 10                     # safety cap per session
SAVE_BEST_SESSION_SUBMISSION = True
SAVE_EVERY_TRIAL_SUBMISSION = True  # NEW: Save submission after every trial
SUBMISSION_CSV  = "submission_scratch.csv"
SUBMISSION_XLSX = "submission_scratch.xlsx"

# --- IO / dataloader runtime safety (Windows/Jupyter safe) ---
IS_WINDOWS = (os.name == "nt")
NUM_WORKERS = 0 if IS_WINDOWS else 2         # KEY: avoid multiprocessing on Windows
PERSISTENT_WORKERS = False                   # only passed to DataLoader when NUM_WORKERS > 0
PIN_MEMORY = torch.cuda.is_available()
LOG_EVERY_N = 50                              # fallback batch logging if tqdm unavailable

# Progress bars: force console (no ipywidgets)
# `tqdm` ends up as either the console progress-bar class or None; the training
# loop falls back to printing every LOG_EVERY_N batches when it is None.
FORCE_CONSOLE_TQDM = True
if FORCE_CONSOLE_TQDM:
    os.environ["TQDM_NOTEBOOK"] = "0"
    try:
        from tqdm import tqdm  # console bar
    except Exception:
        tqdm = None
else:
    tqdm = None

# Single device string used for the model and every batch.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

# ------------------- Load data -------------------
# Stop immediately with a readable message if any of the three input files is missing.
assert os.path.exists(TRAIN_PATH) and os.path.exists(TEST_PATH) and os.path.exists(SUB_PATH), \
    "Place train.csv, test.csv, sample_submission.csv in the working directory."

# Free-text columns that are null-filled and whitespace-stripped below.
TEXT_COLS = ['body','rule','subreddit','positive_example_1','positive_example_2','negative_example_1','negative_example_2']
train_df = pd.read_csv(TRAIN_PATH)
test_df  = pd.read_csv(TEST_PATH)

# Clean the text columns of both frames in place: NaN -> "", cast to str and
# strip surrounding whitespace. Columns absent from a frame are skipped.
for df in [train_df, test_df]:
    for c in TEXT_COLS:
        if c in df.columns:
            df[c] = df[c].fillna("").astype(str).str.strip()

def build_input_template(row):
    """Concatenate the tagged text fields of one row into a single string.

    ``row`` may be anything indexable by column name (a DataFrame row or a
    dict). Fields are emitted in a fixed order, each prefixed with its tag,
    and joined with ``" [SEP] "``; the subreddit is rendered as ``r/<name>``.
    """
    tagged_columns = (
        ("[COMMENT]", "body"),
        ("[RULE]", "rule"),
        ("[POS_EX_1]", "positive_example_1"),
        ("[POS_EX_2]", "positive_example_2"),
        ("[NEG_EX_1]", "negative_example_1"),
        ("[NEG_EX_2]", "negative_example_2"),
    )
    segments = [f"{tag} {row[column]}" for tag, column in tagged_columns]
    segments.append(f"[SUBREDDIT] r/{row['subreddit']}")
    return " [SEP] ".join(segments)

# Add the combined "input_text" column to each frame that does not have it yet.
# The frames are checked independently so that test_df is still filled in when
# train_df already carries the column (and vice versa).
for _frame in (train_df, test_df):
    if "input_text" not in _frame.columns:
        _frame["input_text"] = _frame.apply(build_input_template, axis=1)

# ------------------- Utils -------------------
def set_seed(seed:int=42):
    """Seed Python's ``random``, NumPy and PyTorch (and all CUDA devices when CUDA is available)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def now_iso():
    """Return the current UTC time as a timezone-aware ISO-8601 string."""
    utc_now = datetime.now(tz=timezone.utc)
    return utc_now.isoformat()

def combo_key(params:Dict[str,Any])->str:
    """Return a stable MD5 hex digest identifying a hyperparameter dict.

    The dict is serialised to JSON with its keys in sorted order, so the
    digest does not depend on the order in which the keys were inserted.
    """
    ordered = dict(sorted(params.items(), key=lambda item: item[0]))
    payload = json.dumps(ordered, sort_keys=True).encode("utf-8")
    return hashlib.md5(payload).hexdigest()

def load_done_keys(path:str)->set:
    """Return the combo keys of trials already completed according to the results CSV.

    Only rows whose ``status`` starts with ``"ok"`` count as completed, so a
    trial that was logged with an error is retried on the next run instead of
    being skipped forever. A log without a ``status`` column is treated as
    all-completed.

    Args:
        path: Location of the results CSV.

    Returns:
        A set of key strings; empty when the file is missing, unreadable or
        has no ``key`` column.
    """
    if not os.path.exists(path):
        return set()
    try:
        df = pd.read_csv(path)
        if "key" not in df.columns:
            return set()
        if "status" in df.columns:
            df = df[df["status"].astype(str).str.startswith("ok")]
        return set(df["key"].astype(str).tolist())
    except Exception:
        # Deliberately best-effort: a corrupt or empty log must not stop the
        # tuner, it only means nothing is skipped.
        return set()

def append_result_row(row:Dict[str,Any], path=RESULTS_CSV):
    """Append one result row to the CSV at ``path``.

    The header line is written only when the file does not exist yet; later
    rows are appended without a header, in the key order of ``row``.
    """
    frame = pd.DataFrame([row], columns=list(row.keys()))
    already_there = os.path.exists(path)
    frame.to_csv(path, mode="a" if already_there else "w",
                 header=not already_there, index=False)

# Create the checkpoint and TensorBoard directories up front; exist_ok makes
# re-running the cell harmless.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(TB_LOGDIR_BASE, exist_ok=True)

# ------------------- Tokenizer/Vocab -------------------
TOKEN_RE = re.compile(r"[A-Za-z0-9_']+")
def tokenize(s):
    """Lower-case ``s`` and return its runs of letters, digits, underscores and apostrophes.

    Falsy input (``None`` or ``""``) yields an empty list.
    """
    text = s if s else ""
    return TOKEN_RE.findall(text.lower())

VOCAB_CACHE: Dict[int, Dict[str,int]] = {}
def build_vocab(df:pd.DataFrame, vocab_size:int=30000)->Dict[str,int]:
    """Build (or fetch from cache) a token -> id mapping from the ``body`` and ``rule`` columns.

    Ids 0 and 1 are reserved for ``<pad>`` and ``<unk>``; the remaining ids go
    to the ``vocab_size - 2`` most frequent tokens, most frequent first.
    The cache is keyed on ``vocab_size`` only, so a later call with a
    different ``df`` but the same size returns the first vocabulary built.
    """
    cached = VOCAB_CACHE.get(vocab_size)
    if cached is not None:
        return cached
    from collections import Counter
    freq = Counter(
        tok
        for column in ("body", "rule")
        for text in df[column].tolist()
        for tok in tokenize(text)
    )
    vocab = {"<pad>": 0, "<unk>": 1}
    top_tokens = freq.most_common(vocab_size - 2)
    vocab.update((tok, idx) for idx, (tok, _) in enumerate(top_tokens, start=2))
    VOCAB_CACHE[vocab_size] = vocab
    return vocab

def encode_text(s, vocab, max_len):
    """Turn text into a fixed-length int64 array of token ids.

    Tokens missing from ``vocab`` map to id 1; the sequence is cut to
    ``max_len`` tokens and right-padded with id 0 up to ``max_len``.
    """
    ids = [vocab.get(tok, 1) for tok in tokenize(s)[:max_len]]
    padding = [0] * (max_len - len(ids))
    return np.array(ids + padding, dtype=np.int64)

class ScratchDataset(Dataset):
    """Dataset that encodes each row's ``body`` and ``rule`` text as one id sequence.

    Each field gets ``seq_len // 2`` positions and the two halves are
    concatenated (body first). With ``with_labels=True`` an item is
    ``(ids, label)`` with the label read from ``rule_violation`` as float32;
    otherwise only ``ids`` is returned.
    """
    def __init__(self, df, vocab, seq_len, with_labels=True):
        # reset_index so that positional indices 0..len-1 work with .loc below
        self.df = df.reset_index(drop=True)
        self.vocab = vocab
        self.seq_len = seq_len
        self.with_labels = with_labels

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.loc[i]
        per_field = self.seq_len // 2
        halves = [encode_text(row[column], self.vocab, per_field) for column in ("body", "rule")]
        features = torch.tensor(np.concatenate(halves), dtype=torch.long)
        if not self.with_labels:
            return features
        label = int(row["rule_violation"])
        return features, torch.tensor(label, dtype=torch.float32)

def make_dataloader(ds: Dataset, batch_size: int, shuffle: bool) -> DataLoader:
    """Wrap ``ds`` in a DataLoader configured from the module-level runtime switches.

    Worker-only options (``prefetch_factor``, ``persistent_workers``) are
    passed only when ``NUM_WORKERS > 0``, and ``pin_memory`` only when CUDA is
    available.
    """
    loader_args = {"batch_size": batch_size, "shuffle": shuffle, "num_workers": NUM_WORKERS}
    if NUM_WORKERS > 0:
        loader_args.update(prefetch_factor=2, persistent_workers=PERSISTENT_WORKERS)
    if torch.cuda.is_available():
        loader_args["pin_memory"] = PIN_MEMORY
    return DataLoader(ds, **loader_args)

# ------------------- Model -------------------
def parse_kernel_sizes(spec:str):
    """Parse a dash-separated spec such as ``"3-5-7"`` into a list of ints.

    Pieces that are not plain digit strings are ignored; when nothing usable
    remains the default ``[3, 5]`` is returned.
    """
    pieces = (piece.strip() for piece in str(spec).split("-"))
    sizes = [int(piece) for piece in pieces if piece.isdigit()]
    return sizes if sizes else [3, 5]

def _growth_factor(growth) -> float:
    """Parse a growth spec such as "x1.5", "x2" or "x2.0" into a multiplier (1.0 if unparseable)."""
    text = str(growth).strip().lower()
    if not text.startswith("x"):
        return 1.0
    try:
        factor = float(text[1:])
    except ValueError:
        return 1.0
    # Reject zero, negative, infinite and NaN multipliers (NaN fails both comparisons).
    return factor if 0 < factor < float("inf") else 1.0


def channel_schedule(start:int, blocks:int, growth:str):
    """Return the output channel count of each conv block.

    Args:
        start: Channels of the first block.
        blocks: Number of blocks; the result always contains at least ``start``.
        growth: Per-block multiplier written as ``"x<number>"`` (``"x1.2"``,
            ``"x1.5"``, ``"x2"``, ``"x2.0"``, ...). Any other value keeps the
            channel count constant.

    Returns:
        A list of ints; each entry is the previous one times the multiplier,
        rounded to the nearest integer.
    """
    factor = _growth_factor(growth)
    chs = [start]
    for _ in range(1, blocks):
        prev = chs[-1]
        # factor == 1.0 keeps the previous value untouched (no rounding applied).
        chs.append(prev if factor == 1.0 else max(1, int(round(prev * factor))))
    return chs

class TextCNN(nn.Module):
    """Stacked 1-D convolutional text classifier producing one logit per sample.

    Token ids are embedded (id 0 is the padding index), passed through
    ``conv_blocks`` Conv1d -> (BatchNorm1d | Identity) -> ReLU blocks, pooled
    over the sequence axis (average when ``pooling == "avg"``, max otherwise),
    passed through dropout and projected to a single logit.
    """
    def __init__(self, vocab_size, emb_dim, conv_blocks, channels_start,
                 channel_growth, kernel_sizes_spec, use_batchnorm=True,
                 pooling="max", dropout=0.2):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        kernel_sizes = parse_kernel_sizes(kernel_sizes_spec)
        widths = channel_schedule(channels_start, conv_blocks, channel_growth)
        self.blocks = nn.ModuleList()
        prev_width = emb_dim
        for depth, width in enumerate(widths[:conv_blocks]):
            # Blocks beyond the end of the kernel list reuse the last kernel size.
            kernel = kernel_sizes[min(depth, len(kernel_sizes) - 1)]
            layers = [nn.Conv1d(prev_width, width, kernel_size=kernel, padding=kernel // 2)]
            layers.append(nn.BatchNorm1d(width) if use_batchnorm else nn.Identity())
            layers.append(nn.ReLU())
            self.blocks.append(nn.Sequential(*layers))
            prev_width = width
        self.pooling = pooling
        self.drop = nn.Dropout(dropout)
        self.fc = nn.Linear(prev_width, 1)

    def forward(self, x):
        # assumes x is a [B, L] batch of token ids, giving [B, E, L] after the transpose
        feats = self.emb(x).transpose(1, 2)
        for block in self.blocks:
            feats = block(feats)
        pool = F.adaptive_avg_pool1d if self.pooling == "avg" else F.adaptive_max_pool1d
        pooled = pool(feats, 1).squeeze(-1)
        return self.fc(self.drop(pooled)).squeeze(-1)

# ------------------- Loss/Optim/Val -------------------
class BCEWithLS(nn.Module):
    """Binary cross-entropy on logits with optional label smoothing toward 0.5."""
    def __init__(self, smoothing=0.0):
        super().__init__()
        self.s = smoothing

    def forward(self, logits, targets):
        if self.s > 0:
            eps = self.s
            targets = (1 - eps) * targets + eps * 0.5
        return F.binary_cross_entropy_with_logits(logits, targets)

class FocalLoss(nn.Module):
    """Binary focal loss on logits with optional label smoothing toward 0.5.

    Args:
        gamma: Focusing exponent applied to the probability of the wrong class.
        smoothing: Label-smoothing amount; 0 disables it.
    """
    def __init__(self, gamma=2.0, smoothing=0.0):
        super().__init__()
        self.g = gamma
        self.s = smoothing

    def forward(self, logits, targets):
        p = torch.sigmoid(logits)
        if self.s > 0:
            targets = targets * (1 - self.s) + 0.5 * self.s
        eps = 1e-8
        # Both probabilities are clamped to [eps, 1] before the log. The
        # previous code passed 1 - 1e-8 as the *minimum* for (1 - p), which
        # pinned it to ~1 and zeroed the negative-class term.
        log_p = torch.log(torch.clamp(p, eps, 1.0))
        log_not_p = torch.log(torch.clamp(1 - p, eps, 1.0))
        loss_pos = -targets * ((1 - p) ** self.g) * log_p
        loss_neg = -(1 - targets) * (p ** self.g) * log_not_p
        return (loss_pos + loss_neg).mean()

def get_loss(name, smoothing):
    """Return the loss module selected by ``name``.

    ``"focal"`` and ``"focal_loss"`` (the spelling used in the search space)
    select ``FocalLoss`` with gamma 2.0; every other name, including
    ``"bce_logits"``, selects ``BCEWithLS``.

    Args:
        name: Loss identifier; compared case-insensitively.
        smoothing: Label-smoothing amount forwarded to the loss.
    """
    key = str(name).strip().lower()
    if key in ("focal", "focal_loss"):
        return FocalLoss(2.0, smoothing)
    return BCEWithLS(smoothing)

def make_optimizer(model, name, lr, weight_decay):
    """Create the optimizer named ``name`` over all parameters of ``model``.

    Args:
        model: Module whose parameters are optimised.
        name: One of ``"adamw"``, ``"adam"``, ``"sgd"`` (momentum 0.9) or ``"rmsprop"``.
        lr: Learning rate.
        weight_decay: Weight-decay coefficient passed to the optimizer.

    Raises:
        ValueError: If ``name`` is not one of the supported optimizers.
    """
    params = model.parameters()
    if name == "adamw":
        return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
    if name == "adam":
        return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)
    if name == "sgd":
        return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=weight_decay)
    if name == "rmsprop":
        return torch.optim.RMSprop(params, lr=lr, weight_decay=weight_decay)
    raise ValueError(f"Unknown optimizer: {name}")

def _epoch_validate(model, dl, device="cpu"):
    """Score ``model`` on every batch of ``dl`` without tracking gradients.

    Returns ``(auc, acc, preds, ys)``: ROC-AUC of the sigmoid outputs,
    accuracy at a 0.5 threshold, and the concatenated prediction and label
    arrays. The model is left in eval mode.
    """
    model.eval()
    pred_chunks, label_chunks = [], []
    with torch.no_grad():
        for xb, yb in dl:
            xb = xb.to(device)
            yb = yb.to(device)
            probs = torch.sigmoid(model(xb))
            pred_chunks.append(probs.detach().cpu().numpy())
            label_chunks.append(yb.detach().cpu().numpy())
    preds = np.concatenate(pred_chunks)
    ys = np.concatenate(label_chunks)
    auc = roc_auc_score(ys, preds)
    hard_preds = (preds >= 0.5).astype(int)
    acc = accuracy_score(ys.astype(int), hard_preds)
    return auc, acc, preds, ys

# ------------------- Train one combo -------------------
def train_eval_once_with_best(params:dict, enable_tb:bool=False):
    """Train one TextCNN configuration and report its best validation epoch.

    The module-level ``train_df`` is split 80/20, stratified on
    ``rule_violation`` and seeded with ``params["seed"]``. After every epoch
    the model is scored on the validation part and the weights of the epoch
    with the highest AUC are kept. Training stops early after
    ``EARLY_STOP_PATIENCE`` consecutive epochs without an AUC gain (0 disables
    early stopping).

    When ``params["class_weighting"] == "balanced"`` the loss is plain BCE with
    a positive-class weight, and the configured ``loss_fn`` / label smoothing
    are not used.

    Args:
        params: Hyperparameter dict as produced by ``coerce_one``.
        enable_tb: When true, log to a new run directory under ``TB_LOGDIR_BASE``.

    Returns:
        ``(best_auc, best_acc, best_state, vocab)``. ``best_state`` is an
        independent CPU copy of the best epoch's ``state_dict``; it is ``None``
        when no epoch improved on the initial AUC of -1.0.
    """
    set_seed(int(params["seed"]))
    vocab = build_vocab(train_df, int(params["vocab_size"]))
    seq_len = int(params["seq_len"])
    tr, va = train_test_split(train_df, test_size=0.2, random_state=int(params["seed"]),
                              stratify=train_df["rule_violation"])
    ds_tr = ScratchDataset(tr, vocab, seq_len, True)
    ds_va = ScratchDataset(va, vocab, seq_len, True)
    dl_tr = make_dataloader(ds_tr, int(params["batch_size"]), True)
    dl_va = make_dataloader(ds_va, int(params["batch_size"]), False)

    model = TextCNN(
        vocab_size=len(vocab),
        emb_dim=int(params["emb_dim"]),
        conv_blocks=int(params["conv_blocks"]),
        channels_start=int(params["channels_start"]),
        channel_growth=str(params["channel_growth"]),
        kernel_sizes_spec=str(params["kernel_sizes"]),
        use_batchnorm=bool(params["use_batchnorm"]),
        pooling=str(params["pooling"]),
        dropout=float(params["dropout"])
    ).to(DEVICE)

    opt = make_optimizer(model, str(params["optimizer"]), float(params["learning_rate"]), float(params["weight_decay"]))
    loss_fn = get_loss(str(params["loss_fn"]), float(params["label_smoothing"]))
    grad_clip = float(params["grad_clip"])
    epochs = int(params["epochs"])

    pos_weight = None
    if str(params["class_weighting"])=="balanced":
        n_pos = float(tr["rule_violation"].sum())
        # negatives / positives; float32 so it matches the logits' dtype
        pos_weight = torch.tensor([(len(tr) - n_pos) / (n_pos + 1e-6)],
                                  dtype=torch.float32, device=DEVICE)

    tb = None
    tb_run_dir = None
    if enable_tb:
        try:
            from torch.utils.tensorboard import SummaryWriter
            tag = (
                f"emb{params['emb_dim']}_cb{params['conv_blocks']}_ch{params['channels_start']}"
                f"_lr{params['learning_rate']}_bs{params['batch_size']}"
            )
            tb_run_dir = os.path.join(TB_LOGDIR_BASE, f"{tag}_{datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S')}")
            tb = SummaryWriter(log_dir=tb_run_dir)
            tb.add_text("hparams/json", json.dumps(params, indent=2))
        except Exception as e:
            print("TensorBoard unavailable:", e)
            tb = None

    best_auc, best_acc, best_state = -1.0, 0.0, None
    global_step = 0
    no_improve = 0

    try:
        for ep in range(epochs):
            model.train()
            iterator = dl_tr if tqdm is None else tqdm(dl_tr, leave=False, desc=f"Epoch {ep+1}/{epochs}")
            for i, (xb, yb) in enumerate(iterator):
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                opt.zero_grad()
                logits = model(xb)
                loss = (F.binary_cross_entropy_with_logits(logits, yb, pos_weight=pos_weight)
                        if pos_weight is not None else loss_fn(logits, yb))
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                opt.step()
                if tqdm is None and (i % LOG_EVERY_N == 0):
                    print(f"  batch {i:>4}/{len(dl_tr)}  loss={float(loss.item()):.4f}")
                if tb:
                    tb.add_scalar("train/loss", float(loss.item()), global_step)
                global_step += 1

            auc, acc, _, _ = _epoch_validate(model, dl_va, device=DEVICE)
            improved = auc > best_auc + 1e-5
            if improved:
                best_auc, best_acc = auc, acc
                # clone() is essential: for a model already on the CPU, .cpu()
                # returns the very same tensors, so without a copy the "best"
                # snapshot would keep changing as training continues.
                best_state = {k: v.detach().cpu().clone() for k,v in model.state_dict().items()}
                no_improve = 0
            else:
                no_improve += 1

            print(f"[SCRATCH] Epoch {ep+1}/{epochs} AUC={auc:.5f} ACC={acc:.4f} "
                  f"(best {best_auc:.5f}, patience {no_improve}/{EARLY_STOP_PATIENCE})")
            if tb:
                tb.add_scalar("val/auc", float(auc), ep)
                tb.add_scalar("val/accuracy", float(acc), ep)

            if EARLY_STOP_PATIENCE and no_improve >= EARLY_STOP_PATIENCE:
                print("Early stopping: no improvement.")
                break

        if tb:
            # Snapshot final best metrics + (optional) hparams summary
            tb.add_scalar("val/best_auc", float(best_auc))
            tb.add_scalar("val/best_acc", float(best_acc))
            if TB_WRITE_HPARAMS:
                try:
                    from torch.utils.tensorboard.summary import hparams
                    metric_dict = {"hparam/best_auc": float(best_auc), "hparam/best_acc": float(best_acc)}
                    # hparams() returns three summaries (experiment, session
                    # start, session end); each has to be written separately.
                    for summary in hparams(params, metric_dict):
                        tb.file_writer.add_summary(summary)
                except Exception as e:
                    # Optional extra: report the problem but keep the trial result.
                    print("TensorBoard hparams not written:", e)
    finally:
        # Close the writer even when training raises, so event files are flushed.
        if tb:
            tb.close()

    return best_auc, best_acc, best_state, vocab

# ------------------- Predict test with a state -------------------
def predict_test_with_state(best_state, params, vocab, out_csv="submission_scratch.csv"):
    """Score the module-level ``test_df`` with saved weights and write a submission CSV.

    Args:
        best_state: ``state_dict`` to load into a freshly built ``TextCNN``.
        params: Hyperparameters the state was trained with (architecture,
            ``seq_len`` and ``batch_size`` are read from it).
        vocab: Token -> id mapping used during training.
        out_csv: Destination path of the submission file.

    Returns:
        ``out_csv``.

    Raises:
        ValueError: If the submission frame and the predictions differ in length.
    """
    seq_len = int(params["seq_len"])
    # Same encoding as in training; with_labels=False yields only the token ids.
    test_ds = ScratchDataset(test_df, vocab, seq_len, with_labels=False)
    test_dl = make_dataloader(test_ds, int(params["batch_size"]), False)

    model = TextCNN(
        vocab_size=len(vocab),
        emb_dim=int(params["emb_dim"]),
        conv_blocks=int(params["conv_blocks"]),
        channels_start=int(params["channels_start"]),
        channel_growth=str(params["channel_growth"]),
        kernel_sizes_spec=str(params["kernel_sizes"]),
        use_batchnorm=bool(params["use_batchnorm"]),
        pooling=str(params["pooling"]),
        dropout=float(params["dropout"])
    ).to(DEVICE)
    model.load_state_dict({k: v.to(DEVICE) for k,v in best_state.items()})
    model.eval()

    preds = []
    with torch.no_grad():
        for xb in test_dl:
            xb = xb.to(DEVICE)
            p = torch.sigmoid(model(xb)).detach().cpu().numpy()
            preds.append(p)
    preds = np.concatenate(preds).reshape(-1)

    # Start from the sample submission; fall back to test_df's row ids (or a
    # plain 0..n-1 range) when the sample file has no row_id column.
    sub = pd.read_csv(SUB_PATH).copy()
    if "row_id" not in sub.columns:
        if "row_id" in test_df.columns:
            sub = test_df[["row_id"]].copy()
        else:
            sub["row_id"] = np.arange(len(preds))
    if len(sub) != len(preds):
        raise ValueError(
            f"Submission has {len(sub)} rows but {len(preds)} predictions were produced."
        )
    sub["rule_violation"] = np.clip(preds, 0, 1)
    sub.to_csv(out_csv, index=False)
    print(f"✅ Wrote {out_csv} (rows={len(sub)})")
    return out_csv

# ------------------- Param space handling -------------------
# Fallback values merged underneath every combo by coerce_one(); keys coming
# from the search space or from the caller's constants override them.
CONSTANTS_DEFAULT = {
    "vocab_size": 30000,
    "use_batchnorm": True,
    "pooling": "max",
    "optimizer": "adamw",
    "grad_clip": 1.0,
    "scheduler": "none",        # catalog only
    "class_weighting": "none",
    "seed": 42,
}
# Keys that must be present after merging, otherwise coerce_one() raises KeyError.
REQ = ['seq_len','emb_dim','conv_blocks','channels_start','channel_growth','kernel_sizes',
       'dropout','weight_decay','label_smoothing','learning_rate','batch_size','epochs','loss_fn']
# Type-coercion tables used by coerce_one(): every listed key is cast to the
# named type. "vocab_size" is not listed; it is cast with int() where it is used.
INTS   = ["seq_len","emb_dim","conv_blocks","channels_start","batch_size","epochs","seed"]
FLOATS = ["dropout","weight_decay","label_smoothing","learning_rate","grad_clip"]
STRS   = ["channel_growth","kernel_sizes","pooling","optimizer","scheduler","loss_fn","class_weighting"]
BOOLS  = ["use_batchnorm"]

def coerce_one(p:Dict[str,Any])->Dict[str,Any]:
    """Merge ``p`` over ``CONSTANTS_DEFAULT`` and cast the known keys to their types.

    Keys listed in ``INTS`` / ``FLOATS`` / ``STRS`` are cast with ``int`` /
    ``float`` / ``str``. Keys in ``BOOLS`` accept real booleans or the strings
    "true", "1", "yes", "y" (case-insensitive) for true. Unlisted keys pass
    through unchanged.

    Raises:
        KeyError: If a key from ``REQ`` is missing after the merge.
    """
    merged = dict(CONSTANTS_DEFAULT)
    merged.update(p)
    missing = [name for name in REQ if name not in merged]
    if missing:
        raise KeyError(f"Missing required param(s): {missing}")
    for names, caster in ((INTS, int), (FLOATS, float), (STRS, str)):
        for name in names:
            merged[name] = caster(merged[name])
    for name in BOOLS:
        raw = merged[name]
        if isinstance(raw, str):
            merged[name] = raw.strip().lower() in ("true","1","yes","y")
        else:
            merged[name] = bool(raw)
    return merged

def expand_grid(space:Dict[str,List[Any]], shuffle=True, seed=42)->List[Dict[str,Any]]:
    """Expand a search space into the full list of parameter combinations.

    Values that are not a list or tuple are treated as a single fixed choice.
    When ``shuffle`` is true the combinations are shuffled with a private
    ``random.Random(seed)``, so the order is reproducible and the global RNG
    state is untouched.
    """
    from itertools import product
    names = list(space)
    choices = [
        space[name] if isinstance(space[name], (list, tuple)) else [space[name]]
        for name in names
    ]
    combos = [dict(zip(names, picks)) for picks in product(*choices)]
    if shuffle:
        random.Random(seed).shuffle(combos)
    return combos

# ------------------- Tuner (grid/random + resume) -------------------
_MAX_MATERIALIZED_COMBOS = 1_000_000   # larger grids are sampled lazily instead of being expanded


def _csv_to_xlsx(csv_path, xlsx_path):
    """Copy a submission CSV into a one-sheet XLSX workbook (uses the xlsxwriter engine)."""
    sub_df = pd.read_csv(csv_path)
    with pd.ExcelWriter(xlsx_path, engine="xlsxwriter") as w:
        sub_df.to_excel(w, sheet_name="submission", index=False)


def _plan_combos(space, mode, n_samples, seed):
    """Return ``(combos, count)``: the combos to visit, in order, and how many there are.

    Grids of up to ``_MAX_MATERIALIZED_COMBOS`` entries are expanded and
    shuffled with ``expand_grid``. Larger grids are never materialized;
    instead a generator yields distinct combos in a seeded random order, so
    the order is identical in every session and resume keeps working.
    """
    keys = list(space.keys())
    vals = [list(space[k]) if isinstance(space[k], (list, tuple)) else [space[k]] for k in keys]
    total = 1
    for v in vals:
        total *= len(v)
    take = n_samples if (mode == "random" and n_samples is not None) else None

    if total <= _MAX_MATERIALIZED_COMBOS:
        grid = expand_grid(space, shuffle=True, seed=seed)
        if take is not None:
            grid = grid[:take]  # shuffled already
        return grid, len(grid)

    count = total if take is None else max(0, min(int(take), total))

    def lazy():
        rnd = random.Random(seed)
        seen = set()
        while len(seen) < count:
            flat = rnd.randrange(total)
            if flat in seen:
                continue
            seen.add(flat)
            # Decode the flat index into one choice per key (mixed radix).
            combo, rest = {}, flat
            for k, v in zip(keys, vals):
                rest, pos = divmod(rest, len(v))
                combo[k] = v[pos]
            yield combo

    return lazy(), count


def run_param_space(space:Dict[str,List[Any]],
                    constants:Dict[str,Any]=None,
                    mode:str="grid",        # "grid" or "random"
                    n_samples:int=None,     # only for mode="random"
                    max_trials:int=10,
                    enable_tb:bool=False,
                    save_best_submission:bool=True):
    """Run a resumable hyperparameter search and log every trial to ``RESULTS_CSV``.

    Combos whose key ``load_done_keys`` reports as done are skipped. Each
    remaining combo is trained with ``train_eval_once_with_best``; a failing
    trial is logged with its error message and the search continues.

    Args:
        space: Search space, one list of candidate values per parameter.
        constants: Values applied to every combo; they override ``space``.
        mode: ``"grid"`` visits every combo; ``"random"`` together with
            ``n_samples`` limits the run to that many combos.
        n_samples: Number of combos to keep in random mode.
        max_trials: Maximum trials to run in this session; 0 or ``None`` means no cap.
        enable_tb: Forwarded to the training function.
        save_best_submission: Write ``SUBMISSION_CSV`` / ``SUBMISSION_XLSX``
            from the best trial of this session.
    """
    constants = constants or {}
    combos, n_combos = _plan_combos(space, mode, n_samples, int(constants.get("seed", 42)))

    done = load_done_keys(RESULTS_CSV)
    print(f"Total combos: {n_combos} | Completed in CSV: {len(done)}")

    best_auc = -1.0
    best_payload = None
    ran = 0
    t0 = time.time()
    cap_label = max_trials if max_trials else "all"

    for idx, raw in enumerate(combos):
        params = coerce_one({**raw, **constants})
        key = combo_key(params)
        if key in done:
            continue

        print(f"\n=== Trial {ran+1}/{cap_label} | idx={idx} ===")
        print({k: params[k] for k in REQ})

        t1 = time.time()
        try:
            auc, acc, state, vocab = train_eval_once_with_best(params, enable_tb)
            status = "ok"
            if SAVE_CHECKPOINTS and state is not None:
                torch.save({"state_dict": state, "params": params},
                           os.path.join(CHECKPOINT_DIR, f"{key}.pt"))
        except Exception as e:
            # One broken combo must not end the session: log it and move on.
            auc, acc = float("nan"), float("nan")
            state, vocab = None, None
            status = f"error: {e}"
            print("❌", e)
        dur = time.time() - t1

        row_out = {
            "timestamp": now_iso(),
            "key": key,
            "mode": "scratch",
            "device": DEVICE,
            "python": platform.python_version(),
            "grid_idx": idx,
            "val_auc": auc,
            "val_acc": acc,
            "runtime_sec": round(dur,2),
            "status": status,
            **{f"hp/{k}": params[k] for k in sorted(params)}
        }
        append_result_row(row_out, RESULTS_CSV)
        ran += 1

        if status == "ok" and auc > best_auc:
            best_auc = auc
            best_payload = (state, params, vocab)

        # Save a submission after every successful trial (if enabled), not
        # only after trials that set a new session best.
        if status == "ok" and state is not None and SAVE_EVERY_TRIAL_SUBMISSION:
            trial_submission_csv = f"submission_trial_{ran:03d}_auc_{auc:.4f}.csv"
            trial_submission_xlsx = f"submission_trial_{ran:03d}_auc_{auc:.4f}.xlsx"

            predict_test_with_state(state, params, vocab, out_csv=trial_submission_csv)
            try:
                _csv_to_xlsx(trial_submission_csv, trial_submission_xlsx)
                print(f"✅ Wrote trial submission: {trial_submission_csv} and {trial_submission_xlsx}")
            except Exception as e:
                print(f"Note: could not write XLSX for trial {ran}:", e)

        # A falsy cap (0 / None) means "no cap".
        if max_trials and ran >= max_trials:
            break

    print(f"\nSession done. Ran {ran} trial(s) in {round(time.time()-t0,2)}s.")
    if best_payload and save_best_submission:
        state, params, vocab = best_payload
        predict_test_with_state(state, params, vocab, out_csv=SUBMISSION_CSV)
        try:
            _csv_to_xlsx(SUBMISSION_CSV, SUBMISSION_XLSX)
            print(f"✅ Wrote {SUBMISSION_XLSX}")
        except Exception as e:
            print("Note: could not write XLSX submission:", e)
    else:
        print("No submission written this session.")

# ============================================================
# DEFINE YOUR PARAM SPACE HERE (laptop-safe; resume lets you add more)
# 
# HOW TO ADD MORE PARAMETERS:
# 1. Add new parameter to PARAM_SPACE with a list of values to try
# 2. Update your model/training code to use the new parameter
# 3. Add parameter to REQ list if it's required for model creation
# 4. The system will automatically generate all combinations
#
# EXAMPLES:
# - Add new optimizers: optimizer=["adam", "adamw", "sgd", "rmsprop"]
# - Add new architectures: model_type=["cnn", "transformer", "lstm"]
# - Add new data augmentation: augmentation=["none", "backtranslation", "paraphrase"]
# ============================================================
# NOTE(review): the Cartesian product of the lists below is about 6.7e11
# combinations, far too many to enumerate exhaustively. Trim the lists or
# keep the number of trials per session small.
# NOTE(review): warmup_epochs, scheduler_type, activation, pooling_type,
# use_residual and attention_heads are not read by the model/training code
# visible in this cell. They still multiply the grid and change each combo's
# key, so combos differing only in those keys train the same configuration.
PARAM_SPACE = dict(
    # Capacity/structure
    seq_len=[200, 224, 256, 288, 320],  # Expanded sequence lengths
    emb_dim=[96, 128, 160, 192, 224],   # More embedding dimensions
    conv_blocks=[1, 2, 3],              # More convolution blocks
    channels_start=[96, 128, 160, 192], # More starting channels
    channel_growth=["x1.2", "x1.5", "x2.0"],  # Different growth rates
    kernel_sizes=["3-5-7", "3-5-7-9", "5-7-9"],  # More kernel size combinations

    # Optimization/regularization
    optimizer=["adam", "adamw", "sgd"],  # More optimizers
    learning_rate=[5e-4, 8e-4, 1e-3, 1.2e-3, 1.5e-3],  # More learning rates
    batch_size=[32, 64, 128, 256],       # More batch sizes
    epochs=[6, 8, 10, 12],               # More epoch options
    dropout=[0.1, 0.15, 0.2, 0.25, 0.3], # More dropout rates
    weight_decay=[0, 1e-5, 1e-4, 2e-4, 5e-4],  # More weight decay options
    label_smoothing=[0.0, 0.01, 0.03, 0.05],   # More label smoothing options
    loss_fn=["bce_logits", "focal_loss"],       # More loss functions
    
    # NEW: Additional hyperparameters
    warmup_epochs=[0, 1, 2],             # Learning rate warmup
    scheduler_type=["none", "cosine", "step", "plateau"],  # Learning rate schedulers
    grad_clip=[0.5, 1.0, 1.5, 2.0],     # Gradient clipping values
    activation=["relu", "gelu", "swish"], # Activation functions
    pooling_type=["max", "avg", "attention"],  # Pooling methods
    use_batchnorm=[True, False],         # Batch normalization toggle
    use_residual=[True, False],          # Residual connections
    attention_heads=[1, 2, 4],           # Multi-head attention (if using attention pooling)
)

# Constants applied to every combo (change here if needed)
CONSTANTS = dict(
    vocab_size=30000,
    class_weighting="none",
    seed=42,
    # Note: use_batchnorm and grad_clip are now in PARAM_SPACE. "pooling" and
    # "scheduler" still come from CONSTANTS_DEFAULT, because PARAM_SPACE uses
    # the different key names pooling_type / scheduler_type.
)

# ============================================================
# GO: run grid (or random sample) with resume
# Start TensorBoard in a terminal:  tensorboard --logdir tb_scratch_tune
# ============================================================
# Entry point: start (or resume) the search over PARAM_SPACE. Combos already
# recorded in RESULTS_CSV are skipped, so the cell can simply be re-run.
if __name__ == "__main__":
    run_param_space(
        PARAM_SPACE,
        constants=CONSTANTS,
        mode="grid",             # or "random"
        n_samples=None,          # only used for mode="random"
        max_trials=200, #MAX_TRIALS_PER_RUN
        enable_tb=ENABLE_TENSORBOARD,
        save_best_submission=SAVE_BEST_SESSION_SUBMISSION
    )


Device: cpu


In [None]:
# ============================================================
# Jigsaw — SCRATCH Hyperparameter Tuning (Grid/Random, Resumable)
# * Preprocessing: nulls, normalization, outlier clipping
# * TextCNN (BN + Dropout) from scratch; activation/residual options
# * tqdm console bars (no ipywidgets), Windows-safe DataLoader
# * Early stopping + TensorBoard (train loss/acc, val AUC/ACC, hparams)
# * Saves per-trial artifacts + checkpoints; robust to interrupts
# * Resume/skip via trial_results_scratch.csv; ALL remaining when max_trials=0
# * Best trial writes submission_scratch.(csv|xlsx)
# ============================================================

import os, re, json, time, random, hashlib, platform, math, html
from datetime import datetime, timezone
from itertools import product
from typing import Dict, Any, List

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score

# ------------------- Paths & switches -------------------
# Input files; their existence is asserted in the "Load data" section.
TRAIN_PATH = "train.csv"
TEST_PATH  = "test.csv"
SUB_PATH   = "sample_submission.csv"

RESULTS_CSV = "trial_results_scratch.csv"    # resumable log (append)
CHECKPOINT_DIR = "checkpoints_scratch"       # per-trial .pt files
SAVE_CHECKPOINTS = True

# TensorBoard
ENABLE_TENSORBOARD = True                    # turn on/off
TB_LOGDIR_BASE = "tb_scratch_tune"           # tensorboard --logdir tb_scratch_tune
TB_WRITE_HPARAMS = True

EARLY_STOP_PATIENCE = 3                      # epochs without AUC gain before stopping (0 to disable)
MAX_TRIALS_PER_RUN  = 0                      # 0/None = run ALL remaining combos
SAVE_BEST_SESSION_SUBMISSION = True
SUBMISSION_CSV  = "submission_scratch.csv"
SUBMISSION_XLSX = "submission_scratch.xlsx"

# Save extra artifacts per trial
SAVE_TRIAL_ARTIFACTS = True
ARTIFACT_DIR = "artifacts_scratch"

# Preprocessing config (read by normalize_and_clip_df)
NORMALIZE_TEXT = True     # run clean_text() over every text column
OUTLIER_CHAR_MAX = 4000   # clip very long texts (per column)
# NOTE(review): REPORT_TOP_OUTLIERS is not referenced by normalize_and_clip_df;
# confirm whether later code uses it.
REPORT_TOP_OUTLIERS = 3   # just to print a tiny summary

# --- IO / dataloader runtime safety (Windows/Jupyter safe) ---
IS_WINDOWS = (os.name == "nt")
NUM_WORKERS = 0 if IS_WINDOWS else 2         # avoid multiprocessing on Windows
PERSISTENT_WORKERS = False
PIN_MEMORY = torch.cuda.is_available()
LOG_EVERY_N = 50                              # fallback batch logging if tqdm unavailable

# Progress bars: force console (no ipywidgets)
# `tqdm` ends up as either the console progress-bar class or None.
FORCE_CONSOLE_TQDM = True
if FORCE_CONSOLE_TQDM:
    os.environ["TQDM_NOTEBOOK"] = "0"
    try:
        from tqdm import tqdm  # console bar
    except Exception:
        tqdm = None
else:
    tqdm = None

# Single device string for the whole run.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

# Create all output directories up front; exist_ok makes re-runs harmless.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(TB_LOGDIR_BASE, exist_ok=True)
os.makedirs(ARTIFACT_DIR, exist_ok=True)

# ------------------- Load data -------------------
# Stop immediately with a readable message if any of the three input files is missing.
assert os.path.exists(TRAIN_PATH) and os.path.exists(TEST_PATH) and os.path.exists(SUB_PATH), \
    "Place train.csv, test.csv, sample_submission.csv in the working directory."

# Free-text columns handled by normalize_and_clip_df().
TEXT_COLS = ['body','rule','subreddit','positive_example_1','positive_example_2','negative_example_1','negative_example_2']
train_df = pd.read_csv(TRAIN_PATH)
test_df  = pd.read_csv(TEST_PATH)

# --- Preprocessing: nulls -> empty, normalization, outlier clipping ---
URL_PATTERN = re.compile(r"(https?://\S+|www\.\S+)", re.IGNORECASE)
TAG_PATTERN = re.compile(r"<[^>]+>")
WS_PATTERN  = re.compile(r"\s+")

def clean_text(s: str) -> str:
    """Normalise one text value; anything that is not a ``str`` becomes ``""``.

    Steps, in order: unescape HTML entities, replace URLs with a placeholder
    token, drop ``<...>`` tags, resolve ``&amp;``/``&lt;``/``&gt;`` that are
    still present, lower-case, and collapse all whitespace to single spaces.
    Because lower-casing happens after the URL step, the placeholder appears
    as ``url`` in the result.
    """
    if not isinstance(s, str):
        return ""
    text = html.unescape(s)
    text = URL_PATTERN.sub(" URL ", text)
    text = TAG_PATTERN.sub(" ", text)
    for entity, char in (("&amp;", "&"), ("&lt;", "<"), ("&gt;", ">")):
        text = text.replace(entity, char)
    text = text.lower()
    text = re.sub(r"[\t\r\n]+", " ", text)
    return WS_PATTERN.sub(" ", text).strip()

def normalize_and_clip_df(df: pd.DataFrame, label="train"):
    """Clean the text columns of ``df`` in place and return the same frame.

    For every column of ``TEXT_COLS`` present in ``df``: nulls become ``""``,
    values are cast to ``str``, ``clean_text`` is applied when
    ``NORMALIZE_TEXT`` is set, and values longer than ``OUTLIER_CHAR_MAX``
    characters are truncated to that length. When anything was truncated, a
    one-line summary tagged with ``label`` is printed.
    """
    clipped = {}
    for col in (c for c in TEXT_COLS if c in df.columns):
        # Nulls -> "" and force str before any text processing
        series = df[col].fillna("").astype(str)
        if NORMALIZE_TEXT:
            series = series.map(clean_text)
        df[col] = series
        # Truncate outliers by character length
        too_long = df[col].str.len() > OUTLIER_CHAR_MAX
        n_long = int(too_long.sum())
        if n_long:
            df.loc[too_long, col] = df.loc[too_long, col].str.slice(0, OUTLIER_CHAR_MAX)
        clipped[col] = n_long
    if any(clipped.values()):
        print(f"[Preprocess] {label}: clipped long texts per column (>{OUTLIER_CHAR_MAX} chars):", clipped)
    return df

# Apply the same preprocessing to both frames; the label only tags the printed summary.
train_df = normalize_and_clip_df(train_df, "train")
test_df  = normalize_and_clip_df(test_df,  "test")

def build_input_template(row):
    """Render one row as a single ``" [SEP] "``-joined string of tagged fields.

    ``row`` may be anything indexable by column name. The field order is
    fixed: comment, rule, two positive examples, two negative examples, then
    the subreddit written as ``r/<name>``.
    """
    segments = []
    for tag, column in (
        ("[COMMENT]", "body"),
        ("[RULE]", "rule"),
        ("[POS_EX_1]", "positive_example_1"),
        ("[POS_EX_2]", "positive_example_2"),
        ("[NEG_EX_1]", "negative_example_1"),
        ("[NEG_EX_2]", "negative_example_2"),
    ):
        segments.append(f"{tag} {row[column]}")
    segments.append(f"[SUBREDDIT] r/{row['subreddit']}")
    return " [SEP] ".join(segments)

# Add the combined "input_text" column to each frame that does not have it yet.
# The frames are checked independently so that test_df is still filled in when
# train_df already carries the column (and vice versa).
for _frame in (train_df, test_df):
    if "input_text" not in _frame.columns:
        _frame["input_text"] = _frame.apply(build_input_template, axis=1)

# ------------------- Utils -------------------
def set_seed(seed:int=42):
    """Make runs repeatable by seeding ``random``, NumPy and PyTorch (CUDA too when present)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)

def now_iso():
    """Return the current time in UTC, formatted as an ISO-8601 string with offset."""
    stamp = datetime.now(timezone.utc)
    return stamp.isoformat()

def combo_key(params:Dict[str,Any])->str:
    """Hash a hyperparameter dict into an MD5 hex digest that ignores key insertion order."""
    canonical = {}
    for name in sorted(params):
        canonical[name] = params[name]
    serialized = json.dumps(canonical, sort_keys=True)
    digest = hashlib.md5(serialized.encode("utf-8"))
    return digest.hexdigest()

def load_done_keys(path: str) -> set:
    """Return the keys of trials the results CSV records as successfully completed.

    Rows count as completed when their ``status`` starts with ``"ok"``; if the
    file has no ``status`` column every row counts. A missing or unreadable
    file, or one without a ``key`` column, yields an empty set.
    """
    if os.path.exists(path):
        try:
            log = pd.read_csv(path)
            if "key" in log.columns:
                if "status" in log.columns:
                    # Only skip trials that completed successfully
                    succeeded = log["status"].astype(str).str.startswith("ok")
                    log = log[succeeded]
                return set(log["key"].astype(str).tolist())
        except Exception:
            return set()
    return set()

def append_result_row(row:Dict[str,Any], path=RESULTS_CSV):
    """Append one result row to the CSV log, creating the file on first use.

    The header line is written only when the file does not exist yet; later
    rows are appended without a header, in the key order of ``row``.

    Args:
        row: Column name -> value for the single row to write.
        path: Target CSV file (defaults to RESULTS_CSV).
    """
    frame = pd.DataFrame([row], columns=list(row.keys()))
    already_there = os.path.exists(path)
    frame.to_csv(path,
                 mode="a" if already_there else "w",
                 header=not already_there,
                 index=False)

# Heartbeat + safe I/O helpers
def write_json(obj, path):
    """Write ``obj`` as indented JSON to ``path`` via a temporary sibling file.

    The data is first written to ``path + ".tmp"`` and then moved over the
    target, so readers never observe a half-written file at ``path``.

    Args:
        obj: JSON-serializable object.
        path: Destination file path (string).
    """
    staging_path = path + ".tmp"
    with open(staging_path, "w") as handle:
        json.dump(obj, handle, indent=2)
    # Swap the finished file into place in a single step (atomic).
    os.replace(staging_path, path)

def save_latest_checkpoint(key, model_state, params):
    """Persist the most recent weights of a trial as ``<key>_latest.pt``.

    Does nothing when SAVE_CHECKPOINTS is falsy.

    Args:
        key: Trial identifier used in the file name.
        model_state: State dict to store under "state_dict".
        params: Hyperparameters to store under "params".
    """
    if SAVE_CHECKPOINTS:
        payload = {"state_dict": model_state, "params": params}
        target = os.path.join(CHECKPOINT_DIR, f"{key}_latest.pt")
        torch.save(payload, target)

# ------------------- Tokenizer/Vocab -------------------
# A token is a maximal run of ASCII letters, digits, underscores or apostrophes.
TOKEN_RE = re.compile(r"[A-Za-z0-9_']+")
def tokenize(s):
    """Lower-case ``s`` and split it into word tokens.

    Args:
        s: Text to tokenize. Missing values (None, float NaN) and other falsy
            inputs yield no tokens; any other non-string value is tokenized
            via ``str(s)``.

    Returns:
        List of lower-cased tokens matched by TOKEN_RE.
    """
    if not isinstance(s, str):
        # pandas stores missing text as float NaN, which is truthy and has no
        # .lower(); NaN is the only value that is not equal to itself.
        if s is None or s != s or not s:
            return []
        s = str(s)
    return TOKEN_RE.findall(s.lower())

# Vocabularies already built, keyed by vocab_size only (the DataFrame is not part of the key).
VOCAB_CACHE: Dict[int, Dict[str,int]] = {}
def build_vocab(df:pd.DataFrame, vocab_size:int=30000)->Dict[str,int]:
    """Build (or fetch from cache) a token -> id mapping from the "body" and "rule" columns.

    Ids 0 and 1 are reserved for "<pad>" and "<unk>"; the remaining
    ``vocab_size - 2`` ids go to the most frequent tokens, most frequent first.

    Args:
        df: Frame providing the "body" and "rule" text columns.
        vocab_size: Maximum number of entries, including the two reserved ones.

    Returns:
        The vocabulary dict. Because the cache is keyed by ``vocab_size`` alone,
        a later call with the same size returns the cached dict even if ``df``
        differs.
    """
    cached = VOCAB_CACHE.get(vocab_size)
    if cached is not None:
        return cached

    from collections import Counter
    frequencies = Counter()
    for column in ("body", "rule"):
        for text in df[column].tolist():
            frequencies.update(tokenize(text))

    top_tokens = [token for token, _ in frequencies.most_common(vocab_size - 2)]
    vocab = {"<pad>": 0, "<unk>": 1}
    vocab.update({token: idx for idx, token in enumerate(top_tokens, start=2)})

    VOCAB_CACHE[vocab_size] = vocab
    return vocab

def encode_text(s, vocab, max_len):
    """Convert text to a fixed-length array of token ids.

    Tokens missing from ``vocab`` map to id 1; the sequence is truncated to
    ``max_len`` and right-padded with id 0 when shorter.

    Args:
        s: Text to encode (passed to ``tokenize``).
        vocab: Token -> id mapping.
        max_len: Target sequence length.

    Returns:
        ``np.int64`` array of token ids.
    """
    token_ids = [vocab.get(token, 1) for token in tokenize(s)]
    token_ids = token_ids[:max_len]
    shortfall = max_len - len(token_ids)
    if shortfall > 0:
        token_ids.extend([0] * shortfall)
    return np.array(token_ids, dtype=np.int64)

class ScratchDataset(Dataset):
    """Dataset that encodes each row's "body" and "rule" text into one id sequence.

    Every field gets ``seq_len // 2`` positions; the two encoded halves are
    concatenated (body first). With ``with_labels=True`` an item is
    ``(ids, label)`` where the label comes from the "rule_violation" column as
    a float32 tensor; otherwise an item is just ``ids``.
    """

    def __init__(self, df, vocab, seq_len, with_labels=True):
        # Reset the index so that positional item numbers work with .loc below.
        self.df = df.reset_index(drop=True)
        self.vocab = vocab
        self.seq_len = seq_len
        self.with_labels = with_labels

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.loc[i]
        per_field = self.seq_len // 2
        halves = [encode_text(row[column], self.vocab, per_field)
                  for column in ("body", "rule")]
        ids = torch.tensor(np.concatenate(halves), dtype=torch.long)
        if not self.with_labels:
            return ids
        label = torch.tensor(int(row["rule_violation"]), dtype=torch.float32)
        return ids, label

def make_dataloader(ds: Dataset, batch_size: int, shuffle: bool) -> DataLoader:
    """Create a DataLoader using the module-level runtime settings.

    ``prefetch_factor`` and ``persistent_workers`` are only passed when worker
    processes are used (NUM_WORKERS > 0); ``pin_memory`` is only passed when
    CUDA is available.

    Args:
        ds: Dataset to iterate.
        batch_size: Samples per batch.
        shuffle: Whether to reshuffle every epoch.

    Returns:
        The configured DataLoader.
    """
    loader_args = {"batch_size": batch_size, "shuffle": shuffle, "num_workers": NUM_WORKERS}
    if NUM_WORKERS > 0:
        loader_args.update(prefetch_factor=2, persistent_workers=PERSISTENT_WORKERS)
    if torch.cuda.is_available():
        loader_args["pin_memory"] = PIN_MEMORY
    return DataLoader(ds, **loader_args)

# ------------------- Model -------------------
def parse_kernel_sizes(spec:str):
    """Parse a dash-separated kernel-size spec such as "3-5-7" into a list of ints.

    Parts that are not plain digit strings, and sizes that are not positive,
    are ignored (a kernel size of 0 cannot be used by a convolution layer).

    Args:
        spec: The spec; it is converted with ``str()`` first, so a bare int works too.

    Returns:
        The parsed positive sizes in order, or ``[3, 5]`` when none is valid.
    """
    sizes = []
    for part in str(spec).split("-"):
        part = part.strip()
        if part.isdigit():
            size = int(part)
            if size > 0:
                sizes.append(size)
    return sizes or [3,5]

def channel_schedule(start:int, blocks:int, growth:str):
    """Compute the output channel count of each conv block.

    Args:
        start: Channel count of the first block.
        blocks: Number of blocks; the result always contains at least ``start``.
        growth: "x1.5" multiplies by 1.5 (rounded to int) per block, "x2"
            doubles per block, anything else keeps the width constant.

    Returns:
        List of channel counts, one per block.
    """
    def widen(width):
        if growth == "x1.5":
            return int(round(width * 1.5))
        if growth == "x2":
            return width * 2
        return width

    widths = [start]
    for _ in range(blocks - 1):
        widths.append(widen(widths[-1]))
    return widths

class TextCNN(nn.Module):
    """Stacked 1-D convolutional text classifier producing one logit per example.

    Pipeline: token embedding -> ``conv_blocks`` x (Conv1d -> optional
    BatchNorm1d -> activation, with an optional residual add when shapes match)
    -> global max/avg pooling over the sequence -> dropout -> linear layer.

    Args:
        vocab_size: Number of embedding rows (id 0 is the padding index).
        emb_dim: Embedding width.
        conv_blocks: Number of convolution blocks.
        channels_start / channel_growth: Passed to ``channel_schedule`` to get
            each block's output width.
        kernel_sizes_spec: Passed to ``parse_kernel_sizes``; block ``i`` uses
            the i-th size, and the last size is reused once the list runs out.
        use_batchnorm: Insert BatchNorm1d after each convolution.
        pooling: "avg" for global average pooling, anything else for max.
        dropout: Dropout probability before the final linear layer.
        activation: "relu" (case-insensitive) for ReLU, anything else for GELU.
        residual: Add a block's input to its output when both have the same shape.
    """

    def __init__(self, vocab_size, emb_dim, conv_blocks, channels_start,
                 channel_growth, kernel_sizes_spec, use_batchnorm=True,
                 pooling="max", dropout=0.2, activation="relu", residual=False):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        wants_relu = str(activation).lower() == "relu"
        self.act = nn.ReLU() if wants_relu else nn.GELU()
        self.residual = bool(residual)

        kernel_sizes = parse_kernel_sizes(kernel_sizes_spec)
        widths = channel_schedule(channels_start, conv_blocks, channel_growth)
        last_kernel_idx = len(kernel_sizes) - 1

        self.blocks = nn.ModuleList()
        prev_width = emb_dim
        for idx in range(conv_blocks):
            kernel = kernel_sizes[min(idx, last_kernel_idx)]
            width = widths[idx]
            stage = nn.ModuleDict({
                "conv": nn.Conv1d(prev_width, width, kernel_size=kernel, padding=kernel//2),
                "bn": nn.BatchNorm1d(width) if use_batchnorm else nn.Identity(),
            })
            self.blocks.append(stage)
            prev_width = width

        self.pooling = pooling
        self.drop = nn.Dropout(dropout)
        self.fc = nn.Linear(prev_width, 1)

    def forward(self, x):
        feats = self.emb(x).transpose(1,2)   # [B,E,L]
        for stage in self.blocks:
            out = self.act(stage["bn"](stage["conv"](feats)))
            # The skip connection is only possible when the block kept the shape.
            use_skip = self.residual and out.shape == feats.shape
            feats = out + feats if use_skip else out
        pool = F.adaptive_avg_pool1d if self.pooling == "avg" else F.adaptive_max_pool1d
        pooled = pool(feats, 1).squeeze(-1)
        return self.fc(self.drop(pooled)).squeeze(-1)

# ------------------- Loss/Optim/Val -------------------
class BCEWithLS(nn.Module):
    """Binary cross-entropy on logits with optional label smoothing.

    With smoothing ``s > 0`` every target ``t`` is replaced by
    ``t * (1 - s) + 0.5 * s`` before the loss is computed.
    """

    def __init__(self, smoothing=0.0):
        super().__init__()
        self.s = smoothing

    def forward(self, logits, targets):
        smoothing = self.s
        if smoothing > 0:
            targets = targets * (1 - smoothing) + 0.5 * smoothing
        return F.binary_cross_entropy_with_logits(logits, targets)

class FocalLoss(nn.Module):
    """Binary focal loss on logits with optional label smoothing.

    Per element: ``-t * (1-p)^gamma * log(p) - (1-t) * p^gamma * log(1-p)``
    with ``p = sigmoid(logit)``; the mean over all elements is returned.

    Args:
        gamma: Focusing exponent.
        smoothing: When > 0, targets become ``t * (1 - s) + 0.5 * s``.
    """

    def __init__(self, gamma=2.0, smoothing=0.0):
        super().__init__()
        self.g = gamma
        self.s = smoothing

    def forward(self, logits, targets):
        p = torch.sigmoid(logits)
        if self.s > 0:
            targets = targets * (1 - self.s) + 0.5 * self.s
        loss_pos = -targets * ((1 - p) ** self.g) * torch.log(torch.clamp(p, 1e-8, 1.0))
        # Clamp to [1e-8, 1] exactly like the positive branch. Passing ~1.0 as the
        # *lower* bound would force log(1-p) to ~0 and erase the negative-class loss.
        loss_neg = -(1 - targets) * (p ** self.g) * torch.log(torch.clamp(1 - p, 1e-8, 1.0))
        return (loss_pos + loss_neg).mean()

def get_loss(name, smoothing):
    """Return the loss module for ``name``.

    "focal" selects FocalLoss with gamma 2.0; every other name selects
    BCEWithLS. ``smoothing`` is forwarded as the label-smoothing amount.
    """
    if name == "focal":
        return FocalLoss(2.0, smoothing)
    return BCEWithLS(smoothing)

def make_optimizer(model, name, lr, weight_decay):
    """Create the optimizer named ``name`` over ``model``'s parameters.

    Args:
        model: Module whose parameters are optimized.
        name: "adamw" or "sgd" (SGD uses momentum 0.9).
        lr: Learning rate.
        weight_decay: Weight-decay coefficient.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: If ``name`` is not a known optimizer.
    """
    if name == "adamw":
        return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    if name == "sgd":
        return torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9,
                               weight_decay=weight_decay)
    raise ValueError(f"Unknown optimizer: {name}")

def _epoch_validate(model, dl, device="cpu"):
    """Score ``model`` on a labelled loader.

    Puts the model in eval mode, runs it without gradients over every batch and
    thresholds the sigmoid probabilities at 0.5 for accuracy.

    Args:
        model: Module returning one logit per example.
        dl: Iterable of ``(inputs, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        ``(auc, acc, preds, ys)``: ROC-AUC, accuracy, and the concatenated
        probability and label arrays.
    """
    model.eval()
    prob_chunks, label_chunks = [], []
    with torch.no_grad():
        for xb, yb in dl:
            xb = xb.to(device)
            yb = yb.to(device)
            probs = torch.sigmoid(model(xb))
            prob_chunks.append(probs.detach().cpu().numpy())
            label_chunks.append(yb.detach().cpu().numpy())
    preds = np.concatenate(prob_chunks)
    ys = np.concatenate(label_chunks)
    auc = roc_auc_score(ys, preds)
    hard_preds = (preds >= 0.5).astype(int)
    acc = accuracy_score(ys.astype(int), hard_preds)
    return auc, acc, preds, ys

# ------------------- Train one combo -------------------
def train_eval_once_with_best(params:dict, enable_tb:bool=False, key:str=None):
    """Train and validate one hyperparameter combination, keeping the best epoch by validation AUC.

    The training frame is split 80/20 (stratified on "rule_violation", seeded by
    ``params["seed"]``). After every epoch the model is validated, a heartbeat JSON
    (when SAVE_TRIAL_ARTIFACTS) and a "latest" checkpoint are written, and training
    stops early once EARLY_STOP_PATIENCE epochs pass without an AUC gain.

    Args:
        params: Hyperparameter dict; every key read below must be present.
        enable_tb: When True, try to log to TensorBoard under TB_LOGDIR_BASE.
        key: Identifier used in heartbeat/checkpoint file names; defaults to
            ``combo_key(params)``.

    Returns:
        ``(best_auc, best_acc, best_state, vocab, epoch_hist, final_probs, final_true)``.
        ``best_state`` is an independent CPU copy of the best epoch's weights (None if
        no epoch improved); ``final_probs`` / ``final_true`` are the validation
        probabilities / labels computed with those best weights, as plain lists.

    Raises:
        KeyboardInterrupt: Re-raised after a heartbeat and the latest checkpoint were saved.
    """
    key = key or combo_key(params)
    set_seed(int(params["seed"]))
    vocab = build_vocab(train_df, int(params["vocab_size"]))
    seq_len = int(params["seq_len"])
    tr, va = train_test_split(train_df, test_size=0.2, random_state=int(params["seed"]),
                              stratify=train_df["rule_violation"])
    ds_tr = ScratchDataset(tr, vocab, seq_len, True)
    ds_va = ScratchDataset(va, vocab, seq_len, True)
    dl_tr = make_dataloader(ds_tr, int(params["batch_size"]), True)
    dl_va = make_dataloader(ds_va, int(params["batch_size"]), False)

    model = TextCNN(
        vocab_size=len(vocab),
        emb_dim=int(params["emb_dim"]),
        conv_blocks=int(params["conv_blocks"]),
        channels_start=int(params["channels_start"]),
        channel_growth=str(params["channel_growth"]),
        kernel_sizes_spec=str(params["kernel_sizes"]),
        use_batchnorm=bool(params["use_batchnorm"]),
        pooling=str(params["pooling"]),
        dropout=float(params["dropout"]),
        activation=str(params["activation"]),
        residual=bool(params["residual"]),
    ).to(DEVICE)

    opt = make_optimizer(model, str(params["optimizer"]), float(params["learning_rate"]), float(params["weight_decay"]))
    loss_fn = get_loss(str(params["loss_fn"]), float(params["label_smoothing"]))
    grad_clip = float(params["grad_clip"])
    epochs = int(params["epochs"])

    # With "balanced" class weighting the loop uses a pos_weight-ed BCE instead of loss_fn.
    pos_weight = None
    if str(params["class_weighting"])=="balanced":
        pos_weight = torch.tensor([(len(tr)-tr["rule_violation"].sum())/(tr["rule_violation"].sum()+1e-6)], device=DEVICE)

    tb = None
    if enable_tb:
        try:
            from torch.utils.tensorboard import SummaryWriter
            tag = (
                f"emb{params['emb_dim']}_cb{params['conv_blocks']}_ch{params['channels_start']}"
                f"_lr{params['learning_rate']}_bs{params['batch_size']}"
            )
            tb_run_dir = os.path.join(TB_LOGDIR_BASE, f"{tag}_{datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S')}")
            tb = SummaryWriter(log_dir=tb_run_dir)
            tb.add_text("hparams/json", json.dumps(params, indent=2))
        except Exception as e:
            print("TensorBoard unavailable:", e)
            tb = None

    best_auc, best_acc, best_state = -1.0, 0.0, None
    global_step = 0
    no_improve = 0
    epoch_hist = []
    heartbeat_path = os.path.join(ARTIFACT_DIR, f"{key}_heartbeat.json")

    try:
        for ep in range(epochs):
            model.train()
            iterator = dl_tr if tqdm is None else tqdm(dl_tr, leave=False, desc=f"Epoch {ep+1}/{epochs}")
            running_loss, nb = 0.0, 0
            train_correct, train_total = 0, 0

            for i, (xb, yb) in enumerate(iterator):
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                opt.zero_grad()
                logits = model(xb)
                loss = (F.binary_cross_entropy_with_logits(logits, yb, pos_weight=pos_weight)
                        if pos_weight is not None else loss_fn(logits, yb))
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                opt.step()

                # track train loss & accuracy
                running_loss += float(loss.item()); nb += 1
                with torch.no_grad():
                    probs = torch.sigmoid(logits)
                    preds = (probs >= 0.5).float()
                    train_correct += int((preds == yb).sum().item())
                    train_total   += int(yb.numel())

                if tqdm is None and (i % LOG_EVERY_N == 0):
                    print(f"  batch {i:>4}/{len(dl_tr)}  loss={float(loss.item()):.4f}")
                if tb:
                    tb.add_scalar("train/loss", float(loss.item()), global_step)
                global_step += 1

            # ---- validation at epoch end ----
            auc, acc, _, _ = _epoch_validate(model, dl_va, device=DEVICE)
            train_acc = (train_correct / max(1, train_total))
            # None when the training loader yielded no batch (avoids dividing by zero).
            mean_loss = (running_loss / nb) if nb else None
            epoch_hist.append({"epoch": ep+1,
                               "train_loss": mean_loss,
                               "train_acc": float(train_acc),
                               "val_auc": float(auc), "val_acc": float(acc)})

            improved = auc > best_auc + 1e-5
            if improved:
                best_auc, best_acc = auc, acc
                # clone() is essential: on CPU, .detach().cpu() returns tensors that share
                # storage with the live parameters, so without a copy later optimizer steps
                # would silently overwrite this "best" snapshot.
                best_state = {k: v.detach().cpu().clone() for k,v in model.state_dict().items()}
                no_improve = 0
            else:
                no_improve += 1

            loss_txt = "nan" if mean_loss is None else f"{mean_loss:.4f}"
            print(f"[SCRATCH] Epoch {ep+1}/{epochs} | "
                  f"train_loss={loss_txt} train_acc={train_acc:.4f} "
                  f"val_auc={auc:.5f} val_acc={acc:.4f} "
                  f"(best {best_auc:.5f}, patience {no_improve}/{EARLY_STOP_PATIENCE})")
            if tb:
                tb.add_scalar("train/acc", float(train_acc), ep)
                tb.add_scalar("val/auc", float(auc), ep)
                tb.add_scalar("val/accuracy", float(acc), ep)

            # ---- heartbeat + latest checkpoint EVERY epoch ----
            if SAVE_TRIAL_ARTIFACTS:
                write_json({
                    "timestamp": now_iso(),
                    "key": key,
                    "params": params,
                    "best_auc": float(best_auc),
                    "best_acc": float(best_acc),
                    "last_epoch": ep+1,
                    "history_tail": epoch_hist[-5:],
                }, heartbeat_path)
            # Written to disk immediately, so no copy of the tensors is needed here.
            save_latest_checkpoint(key, {k: v.detach().cpu() for k,v in model.state_dict().items()}, params)

            if EARLY_STOP_PATIENCE and no_improve >= EARLY_STOP_PATIENCE:
                print("Early stopping: no improvement.")
                break

        # Final validation predictions/labels for artifacts. Restore the best-epoch weights
        # first so the returned probabilities describe the same model as best_state/best_auc.
        if best_state is not None:
            model.load_state_dict(best_state)
        _, _, final_probs, final_true = _epoch_validate(model, dl_va, device=DEVICE)

        if tb:
            tb.add_scalar("val/best_auc", float(best_auc))
            tb.add_scalar("val/best_acc", float(best_acc))
            if TB_WRITE_HPARAMS:
                try:
                    from torch.utils.tensorboard.summary import hparams
                    metric_dict = {"hparam/best_auc": float(best_auc), "hparam/best_acc": float(best_acc)}
                    # hparams() returns three summaries (experiment, session start, session
                    # end); each one has to be written on its own.
                    for summary in hparams(params, metric_dict):
                        tb.file_writer.add_summary(summary)
                    # Log the metrics under the same tags so they can be matched to the hparams.
                    for metric_tag, metric_value in metric_dict.items():
                        tb.add_scalar(metric_tag, metric_value)
                except Exception as e:
                    # Best effort only: hparams logging must never fail the trial.
                    print("Note: could not write TensorBoard hparams:", e)

        return (best_auc, best_acc, best_state, vocab,
                epoch_hist, final_probs.tolist(), final_true.tolist())

    except KeyboardInterrupt:
        # save a heartbeat right when interrupted
        if SAVE_TRIAL_ARTIFACTS:
            write_json({
                "timestamp": now_iso(),
                "key": key,
                "params": params,
                "best_auc": float(best_auc),
                "best_acc": float(best_acc),
                "interrupted": True,
                "history": epoch_hist
            }, heartbeat_path)
        # also persist the latest weights
        save_latest_checkpoint(key, {k: v.detach().cpu() for k,v in model.state_dict().items()}, params)
        raise  # let caller log CSV row with "interrupted"

    finally:
        # Close the writer on every exit path (success, interrupt, error) so event
        # files are flushed and the writer is not leaked.
        if tb is not None:
            tb.close()

# ------------------- Predict test with a state -------------------
def predict_test_with_state(best_state, params, vocab, out_csv="submission_scratch.csv"):
    """Rebuild the model from ``params``, load ``best_state`` and write test predictions.

    Args:
        best_state: State dict to load into a freshly constructed TextCNN.
        params: Hyperparameters the state was trained with (architecture, seq_len, batch_size).
        vocab: Token -> id mapping used during training.
        out_csv: Path of the submission CSV to write.

    Returns:
        ``out_csv``, after the file has been written.
    """
    seq_len = int(params["seq_len"])

    # Reuse the training dataset class (without labels) instead of a private copy, so
    # test rows are always encoded exactly like training/validation rows.
    test_ds = ScratchDataset(test_df, vocab, seq_len, with_labels=False)
    test_dl = make_dataloader(test_ds, int(params["batch_size"]), False)

    model = TextCNN(
        vocab_size=len(vocab),
        emb_dim=int(params["emb_dim"]),
        conv_blocks=int(params["conv_blocks"]),
        channels_start=int(params["channels_start"]),
        channel_growth=str(params["channel_growth"]),
        kernel_sizes_spec=str(params["kernel_sizes"]),
        use_batchnorm=bool(params["use_batchnorm"]),
        pooling=str(params["pooling"]),
        dropout=float(params["dropout"]),
        activation=str(params["activation"]),
        residual=bool(params["residual"]),
    ).to(DEVICE)
    model.load_state_dict({k: v.to(DEVICE) for k,v in best_state.items()})
    model.eval()

    preds = []
    with torch.no_grad():
        for xb in test_dl:
            xb = xb.to(DEVICE)
            p = torch.sigmoid(model(xb)).detach().cpu().numpy()
            preds.append(p)
    preds = np.concatenate(preds).reshape(-1)

    # Start from the sample submission; fall back to test_df's row_id (or a plain
    # 0..n-1 range) when the sample file has no row_id column.
    sub = pd.read_csv(SUB_PATH).copy()
    if "row_id" not in sub.columns:
        if "row_id" in test_df.columns:
            sub = test_df[["row_id"]].copy()
        else:
            sub["row_id"] = np.arange(len(preds))
    sub["rule_violation"] = np.clip(preds, 0, 1)
    sub.to_csv(out_csv, index=False)
    print(f"✅ Wrote {out_csv} (rows={len(sub)})")
    return out_csv

# ------------------- Subgroup evaluation (optional interpretability) -------------------
def evaluate_subgroups(params, best_state, vocab):
    """Optional: quick subgroup metrics by rule/subreddit on validation split.

    Recreates the seeded 80/20 stratified split, scores the validation part with
    a model rebuilt from ``params`` + ``best_state``, and prints/returns AUC and
    accuracy per rule and for the ten subreddits with the highest AUC.

    Returns:
        ``(by_rule, by_sr)`` DataFrames with "AUC" and "ACC" columns. AUC is
        NaN for a group that contains only one class.
    """
    seed = int(params["seed"])
    set_seed(seed)
    seq_len = int(params["seq_len"])
    _, holdout = train_test_split(train_df, test_size=0.2, random_state=seed,
                                  stratify=train_df["rule_violation"])
    holdout_ds = ScratchDataset(holdout, vocab, seq_len, True)
    holdout_dl = make_dataloader(holdout_ds, int(params["batch_size"]), False)

    model = TextCNN(
        vocab_size=len(vocab),
        emb_dim=int(params["emb_dim"]),
        conv_blocks=int(params["conv_blocks"]),
        channels_start=int(params["channels_start"]),
        channel_growth=str(params["channel_growth"]),
        kernel_sizes_spec=str(params["kernel_sizes"]),
        use_batchnorm=bool(params["use_batchnorm"]),
        pooling=str(params["pooling"]),
        dropout=float(params["dropout"]),
        activation=str(params["activation"]),
        residual=bool(params["residual"]),
    ).to(DEVICE)
    model.load_state_dict({name: tensor.to(DEVICE) for name, tensor in best_state.items()})
    model.eval()

    with torch.no_grad():
        chunks = [torch.sigmoid(model(xb.to(DEVICE))).detach().cpu().numpy()
                  for xb, _ in holdout_dl]
    scores = np.concatenate(chunks).reshape(-1)

    scored = holdout.reset_index(drop=True).copy()
    scored["pred"] = scores
    scored["pred_bin"] = (scored["pred"] >= 0.5).astype(int)

    def _group_metrics(group):
        labels = group["rule_violation"].values
        probs = group["pred"].values
        # ROC-AUC is undefined when a group holds a single class.
        group_auc = roc_auc_score(labels, probs) if len(np.unique(labels)) > 1 else np.nan
        group_acc = accuracy_score(labels.astype(int), (probs >= 0.5).astype(int))
        return pd.Series({"AUC": group_auc, "ACC": group_acc})

    by_rule = scored.groupby("rule", dropna=False).apply(_group_metrics)
    per_subreddit = scored.groupby("subreddit", dropna=False).apply(_group_metrics)
    by_sr = per_subreddit.sort_values("AUC", ascending=False).head(10)
    print("\n=== Interpretation ===")
    print("Top 10 subreddits by AUC:\n", by_sr)
    print("\nBy rule:\n", by_rule)
    return by_rule, by_sr

# ------------------- Param space handling -------------------
# Fallback values that coerce_one merges underneath every trial's params
# (values supplied by the caller take precedence over these).
CONSTANTS_DEFAULT = {
    "vocab_size": 30000,
    "use_batchnorm": True,
    "pooling": "max",
    "optimizer": "adamw",
    "grad_clip": 1.0,
    "scheduler": "none",        # catalog only
    "class_weighting": "none",
    "seed": 42,
    "activation": "relu",
    "residual": False,
}
# Keys that must be present after the defaults are merged in; coerce_one raises
# KeyError when any of them is missing.
REQ = ['seq_len','emb_dim','conv_blocks','channels_start','channel_growth','kernel_sizes',
       'dropout','weight_decay','label_smoothing','learning_rate','batch_size','epochs','loss_fn']
# Type-coercion groups used by coerce_one: every key in a group is cast to that type.
# "vocab_size" is in none of them; train_eval_once_with_best casts it with int() on use.
INTS   = ["seq_len","emb_dim","conv_blocks","channels_start","batch_size","epochs","seed"]
FLOATS = ["dropout","weight_decay","label_smoothing","learning_rate","grad_clip"]
STRS   = ["channel_growth","kernel_sizes","pooling","optimizer","scheduler","loss_fn","class_weighting","activation"]
# Boolean keys; string values "true"/"1"/"yes"/"y" (case-insensitive) count as True.
BOOLS  = ["use_batchnorm","residual"]

def coerce_one(p:Dict[str,Any])->Dict[str,Any]:
    """Fill in defaults and normalize the value types of one parameter dict.

    Args:
        p: Raw parameters; they override CONSTANTS_DEFAULT.

    Returns:
        A new dict in which the keys listed in INTS / FLOATS / STRS are cast to
        int / float / str and the keys in BOOLS are real booleans (strings
        "true", "1", "yes", "y" — case-insensitive, stripped — mean True).

    Raises:
        KeyError: If any key from REQ is missing after merging the defaults.
    """
    merged = {**CONSTANTS_DEFAULT, **p}
    absent = [name for name in REQ if name not in merged]
    if absent:
        raise KeyError(f"Missing required param(s): {absent}")

    for caster, names in ((int, INTS), (float, FLOATS), (str, STRS)):
        for name in names:
            merged[name] = caster(merged[name])

    truthy_words = ("true","1","yes","y")
    for name in BOOLS:
        raw = merged[name]
        if isinstance(raw, str):
            merged[name] = raw.strip().lower() in truthy_words
        else:
            merged[name] = bool(raw)
    return merged

def expand_grid(space:Dict[str,List[Any]], shuffle=True, seed=42)->List[Dict[str,Any]]:
    """Expand a search space into the list of all parameter combinations.

    Args:
        space: Name -> candidate values. A value that is not a list/tuple is
            treated as a single fixed candidate.
        shuffle: Shuffle the combinations with a private, seeded RNG.
        seed: Seed for that RNG (the global ``random`` state is untouched).

    Returns:
        One dict per combination (Cartesian product of all candidates).
    """
    names = list(space.keys())
    candidates = []
    for name in names:
        value = space[name]
        candidates.append(value if isinstance(value, (list, tuple)) else [value])

    combos = [dict(zip(names, picked)) for picked in product(*candidates)]
    if shuffle:
        random.Random(seed).shuffle(combos)
    return combos

# ------------------- Tuner (grid/random + resume + interrupt-safe) -------------------
def run_param_space(space:Dict[str,List[Any]],
                    constants:Dict[str,Any]=None,
                    mode:str="grid",        # "grid" or "random"
                    n_samples:int=None,     # only for mode="random"
                    max_trials:int=10,
                    enable_tb:bool=False,
                    save_best_submission:bool=True):
    """Run (and resume) a hyperparameter search, logging every trial to RESULTS_CSV.

    The shuffled grid is walked in order; combos whose key is already logged as
    successful — in the CSV or earlier in this session — are skipped. Each trial
    appends one CSV row (status "ok", "interrupted" or "error: ..."), and the best
    successful trial of the session is used to write the submission files.

    Args:
        space: Name -> list of candidate values (see expand_grid).
        constants: Values merged on top of every grid point (they win on conflicts).
        mode: "grid" walks the whole shuffled grid; "random" keeps only the first
            ``n_samples`` shuffled combos.
        n_samples: Sample size for mode="random".
        max_trials: Maximum trials to run this session; None/0 means all remaining.
        enable_tb: Forwarded to train_eval_once_with_best.
        save_best_submission: Write SUBMISSION_CSV / SUBMISSION_XLSX for the best trial.
    """
    constants = constants or {}
    grid = expand_grid(space, shuffle=True, seed=int(constants.get("seed", 42)))
    if mode == "random" and n_samples is not None:
        grid = grid[:n_samples]  # shuffled already

    # Constants are merged on top of each grid point, so a key present in both is
    # silently pinned and its swept values only create duplicate combos. Say so.
    shadowed = sorted(set(space) & set(constants))
    if shadowed:
        print(f"⚠️ Constants override swept key(s) {shadowed}; their grid values are ignored.")

    done = load_done_keys(RESULTS_CSV)
    print(f"Total combos: {len(grid)} | Completed in CSV: {len(done)}")

    # Allow unlimited runs if max_trials == 0/None
    if max_trials in (None, 0):
        max_trials = len(grid)

    best_auc = -1.0
    best_payload = None
    ran = 0
    t0 = time.time()

    for idx, raw in enumerate(grid):
        params = coerce_one({**raw, **constants})
        key = combo_key(params)
        if key in done:
            continue

        print(f"\n=== Trial {ran+1}/{max_trials} | idx={idx} ===")
        print({k: params[k] for k in REQ})

        t1 = time.time()
        running_marker = os.path.join(ARTIFACT_DIR, f"{key}.running")
        open(running_marker, "w").close()  # create empty marker

        auc = acc = np.nan
        state = vocab = None
        status = "running"
        hist = []
        final_probs = final_true = []

        try:
            (auc, acc, state, vocab,
             hist, final_probs, final_true) = train_eval_once_with_best(params, enable_tb, key=key)
            status = "ok"

            if SAVE_CHECKPOINTS and state is not None:
                torch.save({"state_dict": state, "params": params},
                           os.path.join(CHECKPOINT_DIR, f"{key}.pt"))

            if SAVE_TRIAL_ARTIFACTS:
                write_json({
                    "timestamp": now_iso(),
                    "params": params,
                    "best_auc": float(auc),
                    "best_acc": float(acc),
                    "history": hist
                }, os.path.join(ARTIFACT_DIR, f"{key}_metrics.json"))
                np.savez_compressed(os.path.join(ARTIFACT_DIR, f"{key}_val.npz"),
                                    probs=np.array(final_probs, dtype=np.float32),
                                    y=np.array(final_true, dtype=np.int64))

        except KeyboardInterrupt:
            status = "interrupted"
            print("⚠️ Trial interrupted by user. Saved heartbeat & latest checkpoint.")

        except Exception as e:
            status = f"error: {e}"
            print("❌", e)

        finally:
            # Bookkeeping only. Loop control (break) stays outside this clause, because
            # a break inside `finally` would discard any exception still propagating.
            dur = time.time() - t1
            row_out = {
                "timestamp": now_iso(),
                "key": key,
                "mode": "scratch",
                "device": DEVICE,
                "python": platform.python_version(),
                "grid_idx": idx,
                "val_auc": auc,
                "val_acc": acc,
                "runtime_sec": round(dur,2),
                "status": status,
                **{f"hp/{k}": params[k] for k in sorted(params)}
            }
            append_result_row(row_out, RESULTS_CSV)
            ran += 1
            try:
                if os.path.exists(running_marker):
                    os.remove(running_marker)
            except OSError:
                # A leftover marker file is harmless; never fail the session over it.
                pass

        if status == "ok":
            # Remember the key so an identical combo later in this grid is not re-run.
            done.add(key)
            if auc > best_auc:
                best_auc = auc
                best_payload = (state, params, vocab)

        if ran >= max_trials:
            break

    print(f"\nSession done. Ran {ran} trial(s) in {round(time.time()-t0,2)}s.")
    if best_payload and save_best_submission:
        state, params, vocab = best_payload
        predict_test_with_state(state, params, vocab, out_csv=SUBMISSION_CSV)
        try:
            sub_df = pd.read_csv(SUBMISSION_CSV)
            with pd.ExcelWriter(SUBMISSION_XLSX, engine="xlsxwriter") as w:
                sub_df.to_excel(w, sheet_name="submission", index=False)
            print(f"✅ Wrote {SUBMISSION_XLSX}")
        except Exception as e:
            print("Note: could not write XLSX submission:", e)
    else:
        print("No submission written this session.")

# ============================================================
# DEFINE YOUR PARAM SPACE HERE (laptop-safe; resume lets you add more)
# ============================================================
# Search space: each key maps to the list of values to try. expand_grid takes the
# Cartesian product, so the lists below yield 3*3*3 * 2**8 = 6912 combinations.
PARAM_SPACE = dict(
    # Capacity/structure
    seq_len=[200, 224, 256],
    emb_dim=[128, 160, 192],
    conv_blocks=[1, 2],
    channels_start=[128, 160],
    channel_growth=["x1.5"],
    kernel_sizes=["3-5-7"],
    activation=["relu", "gelu"],
    residual=[False, True],

    # Optimization/regularization
    optimizer=["adamw"],
    learning_rate=[8e-4, 1e-3, 1.2e-3],
    batch_size=[64, 128],
    epochs=[8],
    dropout=[0.2, 0.25],
    weight_decay=[1e-4, 2e-4],
    label_smoothing=[0.0, 0.03],
    loss_fn=["bce_logits"],
)

# Constants applied to every combo (change here if needed)
CONSTANTS = dict(
    vocab_size=30000,
    use_batchnorm=True,
    pooling="max",
    grad_clip=1.0,
    scheduler="none",
    class_weighting="none",
    seed=42,
    # NOTE: `activation` and `residual` are deliberately NOT pinned here. run_param_space
    # merges these constants on top of every grid point, so pinning them would override
    # the values swept in PARAM_SPACE and turn those grid points into duplicates.
    # Spaces that do not sweep them still get "relu" / False from CONSTANTS_DEFAULT.
)

# ============================================================
# GO: run grid (or random sample) with resume
# Start TensorBoard in a terminal:  tensorboard --logdir tb_scratch_tune
# ============================================================
# Entry point: start (or resume) the search over PARAM_SPACE. Combos already logged
# as successful in RESULTS_CSV are skipped, so re-running this continues the sweep.
if __name__ == "__main__":
    run_param_space(
        PARAM_SPACE,
        constants=CONSTANTS,
        mode="grid",             # or "random"
        n_samples=None,          # only used for mode="random"
        max_trials=MAX_TRIALS_PER_RUN,  # 0/None => ALL remaining
        enable_tb=ENABLE_TENSORBOARD,
        save_best_submission=SAVE_BEST_SESSION_SUBMISSION
    )


Device: cpu
Total combos: 6912 | Completed in CSV: 0

=== Trial 1/6912 | idx=0 ===
{'seq_len': 256, 'emb_dim': 192, 'conv_blocks': 2, 'channels_start': 128, 'channel_growth': 'x1.5', 'kernel_sizes': '3-5-7', 'dropout': 0.25, 'weight_decay': 0.0002, 'label_smoothing': 0.0, 'learning_rate': 0.0008, 'batch_size': 128, 'epochs': 8, 'loss_fn': 'bce_logits'}


                                                          

[SCRATCH] Epoch 1/8 | train_loss=1.4572 train_acc=0.5410 val_auc=0.72612 val_acc=0.6601 (best 0.72612, patience 0/3)


                                                          

[SCRATCH] Epoch 2/8 | train_loss=0.7424 train_acc=0.6531 val_auc=0.78167 val_acc=0.6897 (best 0.78167, patience 0/3)


                                                          

[SCRATCH] Epoch 3/8 | train_loss=0.5147 train_acc=0.7505 val_auc=0.76682 val_acc=0.6798 (best 0.78167, patience 1/3)


                                                          

[SCRATCH] Epoch 4/8 | train_loss=0.3598 train_acc=0.8410 val_auc=0.77058 val_acc=0.6921 (best 0.78167, patience 2/3)


                                                          

[SCRATCH] Epoch 5/8 | train_loss=0.2432 train_acc=0.9150 val_auc=0.75811 val_acc=0.6478 (best 0.78167, patience 3/3)
Early stopping: no improvement.

=== Trial 2/6912 | idx=1 ===
{'seq_len': 200, 'emb_dim': 160, 'conv_blocks': 1, 'channels_start': 128, 'channel_growth': 'x1.5', 'kernel_sizes': '3-5-7', 'dropout': 0.2, 'weight_decay': 0.0002, 'label_smoothing': 0.03, 'learning_rate': 0.0008, 'batch_size': 128, 'epochs': 8, 'loss_fn': 'bce_logits'}


                                                          

[SCRATCH] Epoch 1/8 | train_loss=0.8820 train_acc=0.5551 val_auc=0.72709 val_acc=0.6700 (best 0.72709, patience 0/3)


                                                          

[SCRATCH] Epoch 2/8 | train_loss=0.6301 train_acc=0.6802 val_auc=0.75141 val_acc=0.6724 (best 0.75141, patience 0/3)


                                                          

[SCRATCH] Epoch 3/8 | train_loss=0.5314 train_acc=0.7461 val_auc=0.76806 val_acc=0.6946 (best 0.76806, patience 0/3)


                                                          

[SCRATCH] Epoch 4/8 | train_loss=0.4442 train_acc=0.8084 val_auc=0.77388 val_acc=0.6650 (best 0.77388, patience 0/3)


                                                          

[SCRATCH] Epoch 5/8 | train_loss=0.3813 train_acc=0.8503 val_auc=0.77672 val_acc=0.6995 (best 0.77672, patience 0/3)


                                                          

[SCRATCH] Epoch 6/8 | train_loss=0.3395 train_acc=0.8749 val_auc=0.78512 val_acc=0.7167 (best 0.78512, patience 0/3)


                                                          

[SCRATCH] Epoch 7/8 | train_loss=0.3141 train_acc=0.8916 val_auc=0.78541 val_acc=0.7217 (best 0.78541, patience 0/3)


                                                          

[SCRATCH] Epoch 8/8 | train_loss=0.2694 train_acc=0.9193 val_auc=0.77633 val_acc=0.7118 (best 0.78541, patience 1/3)

=== Trial 3/6912 | idx=2 ===
{'seq_len': 200, 'emb_dim': 128, 'conv_blocks': 1, 'channels_start': 160, 'channel_growth': 'x1.5', 'kernel_sizes': '3-5-7', 'dropout': 0.25, 'weight_decay': 0.0002, 'label_smoothing': 0.03, 'learning_rate': 0.0012, 'batch_size': 64, 'epochs': 8, 'loss_fn': 'bce_logits'}


                                                          

[SCRATCH] Epoch 1/8 | train_loss=0.9469 train_acc=0.5533 val_auc=0.78714 val_acc=0.6921 (best 0.78714, patience 0/3)


                                                          

[SCRATCH] Epoch 2/8 | train_loss=0.6239 train_acc=0.7018 val_auc=0.79107 val_acc=0.6995 (best 0.79107, patience 0/3)


                                                          

[SCRATCH] Epoch 3/8 | train_loss=0.4577 train_acc=0.8022 val_auc=0.78871 val_acc=0.7069 (best 0.79107, patience 1/3)


                                                          

[SCRATCH] Epoch 4/8 | train_loss=0.3945 train_acc=0.8355 val_auc=0.80396 val_acc=0.7167 (best 0.80396, patience 0/3)


                                                          

[SCRATCH] Epoch 5/8 | train_loss=0.3313 train_acc=0.8799 val_auc=0.80150 val_acc=0.7167 (best 0.80396, patience 1/3)


                                                          

[SCRATCH] Epoch 6/8 | train_loss=0.2911 train_acc=0.9008 val_auc=0.80505 val_acc=0.7217 (best 0.80505, patience 0/3)


                                                          

⚠️ Trial interrupted by user. Saved heartbeat & latest checkpoint.

=== Trial 4/6912 | idx=3 ===
{'seq_len': 200, 'emb_dim': 192, 'conv_blocks': 2, 'channels_start': 160, 'channel_growth': 'x1.5', 'kernel_sizes': '3-5-7', 'dropout': 0.2, 'weight_decay': 0.0001, 'label_smoothing': 0.03, 'learning_rate': 0.001, 'batch_size': 128, 'epochs': 8, 'loss_fn': 'bce_logits'}


                                                          

[SCRATCH] Epoch 1/8 | train_loss=1.3909 train_acc=0.5391 val_auc=0.72277 val_acc=0.6404 (best 0.72277, patience 0/3)


                                                          

[SCRATCH] Epoch 2/8 | train_loss=0.5998 train_acc=0.6827 val_auc=0.76495 val_acc=0.6601 (best 0.76495, patience 0/3)


                                                          

[SCRATCH] Epoch 3/8 | train_loss=0.4539 train_acc=0.8016 val_auc=0.77041 val_acc=0.6675 (best 0.77041, patience 0/3)


                                                          

[SCRATCH] Epoch 4/8 | train_loss=0.3194 train_acc=0.8934 val_auc=0.75968 val_acc=0.7020 (best 0.77041, patience 1/3)


                                                          

[SCRATCH] Epoch 5/8 | train_loss=0.2302 train_acc=0.9421 val_auc=0.76248 val_acc=0.6946 (best 0.77041, patience 2/3)


                                                          

[SCRATCH] Epoch 6/8 | train_loss=0.2020 train_acc=0.9600 val_auc=0.75524 val_acc=0.6527 (best 0.77041, patience 3/3)
Early stopping: no improvement.

=== Trial 5/6912 | idx=4 ===
{'seq_len': 256, 'emb_dim': 160, 'conv_blocks': 2, 'channels_start': 128, 'channel_growth': 'x1.5', 'kernel_sizes': '3-5-7', 'dropout': 0.25, 'weight_decay': 0.0002, 'label_smoothing': 0.0, 'learning_rate': 0.0008, 'batch_size': 64, 'epochs': 8, 'loss_fn': 'bce_logits'}


                                                          

[SCRATCH] Epoch 1/8 | train_loss=1.1306 train_acc=0.5595 val_auc=0.73556 val_acc=0.6749 (best 0.73556, patience 0/3)


                                                          

[SCRATCH] Epoch 2/8 | train_loss=0.6225 train_acc=0.6895 val_auc=0.74803 val_acc=0.6453 (best 0.74803, patience 0/3)


                                                         

⚠️ Trial interrupted by user. Saved heartbeat & latest checkpoint.

=== Trial 6/6912 | idx=5 ===
{'seq_len': 224, 'emb_dim': 192, 'conv_blocks': 2, 'channels_start': 128, 'channel_growth': 'x1.5', 'kernel_sizes': '3-5-7', 'dropout': 0.25, 'weight_decay': 0.0001, 'label_smoothing': 0.0, 'learning_rate': 0.001, 'batch_size': 128, 'epochs': 8, 'loss_fn': 'bce_logits'}


                                                          

[SCRATCH] Epoch 1/8 | train_loss=1.4604 train_acc=0.5465 val_auc=0.70917 val_acc=0.5567 (best 0.70917, patience 0/3)


                                                          

[SCRATCH] Epoch 2/8 | train_loss=0.6969 train_acc=0.6642 val_auc=0.77692 val_acc=0.6798 (best 0.77692, patience 0/3)


                                                          

[SCRATCH] Epoch 3/8 | train_loss=0.5040 train_acc=0.7554 val_auc=0.76808 val_acc=0.6601 (best 0.77692, patience 1/3)


                                                          

[SCRATCH] Epoch 4/8 | train_loss=0.3588 train_acc=0.8398 val_auc=0.77114 val_acc=0.6576 (best 0.77692, patience 2/3)


                                                          

[SCRATCH] Epoch 5/8 | train_loss=0.2404 train_acc=0.9125 val_auc=0.78396 val_acc=0.6946 (best 0.78396, patience 0/3)


                                                          

[SCRATCH] Epoch 6/8 | train_loss=0.1993 train_acc=0.9217 val_auc=0.78367 val_acc=0.6872 (best 0.78396, patience 1/3)


                                                          

[SCRATCH] Epoch 7/8 | train_loss=0.1542 train_acc=0.9526 val_auc=0.79284 val_acc=0.7069 (best 0.79284, patience 0/3)


                                                          

[SCRATCH] Epoch 8/8 | train_loss=0.1270 train_acc=0.9618 val_auc=0.78723 val_acc=0.7118 (best 0.79284, patience 1/3)

=== Trial 7/6912 | idx=6 ===
{'seq_len': 200, 'emb_dim': 192, 'conv_blocks': 2, 'channels_start': 128, 'channel_growth': 'x1.5', 'kernel_sizes': '3-5-7', 'dropout': 0.2, 'weight_decay': 0.0001, 'label_smoothing': 0.03, 'learning_rate': 0.0008, 'batch_size': 64, 'epochs': 8, 'loss_fn': 'bce_logits'}


Epoch 1/8:  46%|████▌     | 12/26 [00:01<00:01,  8.22it/s]