In [None]:
%pip install -U "trl>=0.9.6" "transformers>=4.44" "datasets>=2.19" "accelerate>=0.33" peft bitsandbytes einops


In [None]:
# =========================
# 1) Imports & config
# =========================
import os, re, time
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, EarlyStoppingCallback
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import torch
import time, os
from transformers import TrainerCallback, TrainerControl


BASE = "Qwen/Qwen2.5-7B-Instruct"
# BASE = "google/gemma-2-9b-it"  # uncomment to switch to Gemma 2 9B IT

# --- Paths (adjust to your Kaggle Dataset names) ---
# NOTE(review): the training file lives in a Kaggle dataset named "test-data";
# confirm it really is the training split and not held-out evaluation data.
TRAIN_PATH = "/kaggle/input/test-data/Untitled-1 (1).jsonl"
VAL_PATH = "/kaggle/input/eval-data/dev.jsonl"
# Output folder is named after the model id without its organisation prefix,
# e.g. "sft-Qwen2.5-7B-Instruct".
OUT_DIR = f"/kaggle/working/sft-{BASE.rsplit('/', 1)[-1]}"
Path(OUT_DIR).mkdir(parents=True, exist_ok=True)

# Report the runtime environment once, up front.
n_gpus = torch.cuda.device_count()
print("Torch:", torch.__version__, "| GPUs:", n_gpus)
for gpu_idx in range(n_gpus):
    print(f" GPU{gpu_idx}:", torch.cuda.get_device_name(gpu_idx))
print("Output dir:", OUT_DIR)

# =========================================
# 2) Data loader with schema auto-detection
# =========================================
# Full ISO-8601 timestamp with seconds and a mandatory zone designator,
# either "Z" or a "+HH:MM" / "-HH:MM" offset, e.g. "2024-05-10T14:30:00+02:00".
# Fractional seconds and colon-less offsets ("+0200") do not match.
ISO_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(Z|[+\-]\d{2}:\d{2})$")

class TimeLimitCallback(TrainerCallback):
    """Stop training once a wall-clock budget has been used up.

    The clock starts when the callback is constructed, not when training
    begins. On the first step that ends at or beyond ``limit_seconds``, the
    callback saves the current model (and the tokenizer, if one was given) to
    ``<save_dir>/checkpoint_timecap_<YYYYmmdd-HHMMSS>``, writes a
    ``TIME_CAP.txt`` marker there, and asks the Trainer to evaluate, save a
    regular checkpoint and stop.

    Args:
        limit_seconds: wall-clock budget in seconds.
        save_dir: parent folder for the time-cap checkpoint (created if missing).
        tokenizer: optional object with ``save_pretrained``, saved alongside.
    """

    def __init__(self, limit_seconds, save_dir, tokenizer=None):
        self.limit = limit_seconds
        self.save_dir = save_dir
        self.tokenizer = tokenizer
        self.t0 = time.time()  # budget is measured from construction time
        os.makedirs(save_dir, exist_ok=True)

    def on_step_end(self, args, state, control: TrainerControl, **kwargs):
        spent = time.time() - self.t0
        if not (spent >= self.limit):
            return None  # still within budget

        stamp = time.strftime("%Y%m%d-%H%M%S")
        cap_dir = os.path.join(self.save_dir, f"checkpoint_timecap_{stamp}")
        print(f"\n[TimeCap] {spent:.0f}s elapsed — saving to {cap_dir} and stopping…")

        # Snapshot the current weights (for a PEFT-wrapped model this is the
        # adapter) and, when available, the tokenizer.
        kwargs["model"].save_pretrained(cap_dir)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(cap_dir)

        # One last evaluation + checkpoint, then halt the training loop.
        control.should_evaluate = True
        control.should_save = True
        control.should_training_stop = True

        # Marker file recording where and when the cap fired.
        with open(os.path.join(cap_dir, "TIME_CAP.txt"), "w") as marker:
            marker.write(f"Stopped at global_step={state.global_step}, elapsed={spent:.1f}s\n")
        return control

def is_iso_ok(s: str) -> bool:
    """Return True if ``s`` has a shape the output policy allows.

    After stripping whitespace (``None`` counts as ``""``), the value is
    accepted when it is the literal ``ABSTAIN``, contains ``||``, or is a
    timestamp matching ``ISO_RE``.
    """
    candidate = (s or "").strip()
    if candidate == "ABSTAIN" or "||" in candidate:
        # NOTE(review): any text around "||" is accepted (e.g. "foo||bar");
        # confirm whether both sides of an A||B answer should be ISO-checked.
        return True
    return ISO_RE.match(candidate) is not None

def load_jsonl(path):
    """Load a JSON-Lines file as a single Hugging Face ``Dataset``.

    The ``json`` builder places a lone data file under the ``"train"`` split,
    so that split is requested regardless of what the file contains.
    """
    return load_dataset(
        "json",
        data_files=path,
        split="train",
    )

DEFAULT_SYS = (
    "You are a precise time normalizer. Output ONE line in ISO-8601 with offset "
    "unless policy says ABSTAIN or A||B. No prose."
)

# Separator between prompt and answer in every formatted training sample.
_ASSISTANT_TAG = "[ASSISTANT]\n"


def _safe_str(value) -> str:
    """Return ``value`` as a stripped string; ``None`` becomes ``""``."""
    return "" if value is None else str(value).strip()


def _first_line(text: str) -> str:
    """Return the first non-blank line of ``text`` (stripped), or ``""`` if there is none."""
    lines = text.strip().splitlines()
    return lines[0].strip() if lines else ""


def _has_value(ex: dict, key: str) -> bool:
    """True when ``key`` is present in ``ex`` with a non-``None`` value.

    When a JSONL file mixes row schemas, ``datasets`` gives every row every
    column and fills the ones a row lacks with ``None``. Key presence alone
    therefore does not identify a row's schema.
    """
    return ex.get(key) is not None


def _build_prompt(system, user) -> str:
    """Render the plain-text [SYSTEM]/[USER]/[ASSISTANT] template used for SFT.

    NOTE(review): ``<s>`` is emitted as literal text, not through the
    tokenizer's chat template; confirm this is intended for the chosen model.
    """
    return f"<s>[SYSTEM]\n{system}\n[/SYSTEM]\n[USER]\n{user}\n[/USER]\n{_ASSISTANT_TAG}"


def extract_prompt_and_answer(ex: dict):
    """Detect the schema of one raw row and return ``(prompt, answer)``.

    Schemas are tried in this order:
      E) ``text`` that already contains the ``[ASSISTANT]`` tag: everything up
         to and including the tag is the prompt, and the first line after it
         is the answer;
      A) ``messages`` (list of role/content dicts) + ``answer``: uses the first
         system message (falling back to ``DEFAULT_SYS``) and the last user message;
      B) ``prompt`` + ``answer``;
      C) ``input`` + ``response``, or ``question`` + ``target``;
      D) ``input_text`` + ``gold``.
    The input-side field must be non-``None`` for A-D to match. This keeps
    ``None``-filled columns from other schemas from hijacking a row. For A-D
    the answer is returned whole (stripped; ``None`` -> ``""``).

    Raises:
        KeyError: if the row matches none of the schemas.
    """
    text = ex.get("text")
    if isinstance(text, str) and _ASSISTANT_TAG in text:
        head, _, tail = text.partition(_ASSISTANT_TAG)
        # _first_line tolerates an empty answer instead of raising IndexError.
        return head + _ASSISTANT_TAG, _first_line(tail)

    if _has_value(ex, "messages") and "answer" in ex:
        msgs = ex["messages"]

        def role(m):
            # Tolerate missing/None roles rather than crashing on .lower().
            return str(m.get("role") or "").lower()

        # `or ""` keeps a None content from rendering as the string "None".
        system_msg = next((m.get("content") or "" for m in msgs if role(m) == "system"), "")
        user_msg = next((m.get("content") or "" for m in reversed(msgs) if role(m) == "user"), "")
        return _build_prompt(system_msg or DEFAULT_SYS, user_msg), _safe_str(ex.get("answer"))

    if _has_value(ex, "prompt") and "answer" in ex:
        return _build_prompt(DEFAULT_SYS, ex["prompt"]), _safe_str(ex.get("answer"))

    for in_key, out_key in (("input", "response"), ("question", "target")):
        if _has_value(ex, in_key) and out_key in ex:
            return _build_prompt(DEFAULT_SYS, ex[in_key]), _safe_str(ex.get(out_key))

    if _has_value(ex, "input_text") and "gold" in ex:
        return _build_prompt(DEFAULT_SYS, ex["input_text"]), _safe_str(ex.get("gold"))

    raise KeyError("Unsupported row schema; keys present: " + ", ".join(sorted(ex.keys())))


def to_sft_text(ex):
    """Map a raw row to ``{"text": prompt + first line of the answer}``.

    A blank answer yields just the prompt instead of raising ``IndexError``.
    """
    prompt, ans = extract_prompt_and_answer(ex)
    return {"text": prompt + _first_line(ans)}

# Load both JSONL splits, then collapse every row into a single "text" column.
train_raw, val_raw = load_jsonl(TRAIN_PATH), load_jsonl(VAL_PATH)

train_cols = list(train_raw.column_names)
val_cols = list(val_raw.column_names)

train_ds = train_raw.map(to_sft_text, desc="Format train", remove_columns=train_cols)
val_ds = val_raw.map(to_sft_text, desc="Format val", remove_columns=val_cols)

print("Example formatted sample:\n", train_ds[0]["text"][:400] + " ...")
print("Train size:", len(train_ds), " Val size:", len(val_ds))

# ====================================
# 3) Tokenizer & Model (4-bit QLoRA)
# ====================================
# 4-bit weights with float16 compute.
# NOTE(review): bnb_4bit_quant_type is not set, so transformers' documented
# default ("fp4") applies; QLoRA recipes usually use "nf4" — confirm intent.
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype="float16",
)

tok = AutoTokenizer.from_pretrained(BASE, use_fast=True, trust_remote_code=True)
# Both padding and truncation act on the left edge of a sequence.
tok.padding_side = "left"
tok.truncation_side = "left"

model = AutoModelForCausalLM.from_pretrained(
    BASE,
    device_map="auto",
    quantization_config=bnb,
    trust_remote_code=True,
)

# LoRA on the attention projections (q/k/v/o) and the gate/up/down projections.
lora = LoraConfig(
    task_type="CAUSAL_LM",
    r=32,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
    + ["gate_proj", "up_proj", "down_proj"],
)

# ===========================================
# 4) Trainer config + monitoring parameters
# ===========================================
# Per training process, each optimiser step sees
# per_device_train_batch_size x gradient_accumulation_steps = 2 x 32 = 64 samples.
cfg = SFTConfig(
    output_dir=OUT_DIR,
    report_to=["tensorboard"],
    # --- optimisation ---
    num_train_epochs=1,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=32,
    fp16=True,
    bf16=False,
    packing=False,
    # --- logging / evaluation / checkpoints (counted in optimiser steps) ---
    logging_steps=5,
    eval_strategy="steps",
    eval_steps=100,
    save_steps=500,  # must remain a multiple of eval_steps for load_best_model_at_end
    # --- keep the checkpoint with the lowest eval loss ---
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

# NOTE(review): recent TRL versions truncate with SFTConfig's own length
# setting, so this tokenizer attribute may not cap training sequences — confirm.
tok.model_max_length = 256

trainer = SFTTrainer(
    model=model,
    args=cfg,
    peft_config=lora,
    processing_class=tok,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    formatting_func=lambda example: example["text"],
)

# Stop after 3 consecutive evaluations without an eval_loss improvement.
trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=3))

# ======================================================
# 5) Custom monitoring: EM / Format-OK / Abstention-OK
# ======================================================
from transformers import TrainerCallback

def generate_one(model, tok, prompt, max_new_tokens=40):
    """Greedy-decode one completion for ``prompt`` and return its first line.

    Args:
        model: causal LM exposing ``.device`` and ``.generate``.
        tok: tokenizer used to encode ``prompt`` and decode the reply.
        prompt: fully formatted prompt (ending with the ``[ASSISTANT]`` tag).
        max_new_tokens: cap on the number of generated tokens.

    Returns:
        The first non-blank line of the generated text, stripped, or ``""``
        when the model produced nothing but whitespace/special tokens.
    """
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        # Greedy decoding; no temperature is passed because it only applies
        # to sampling and otherwise just triggers a generation-config warning.
        out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # Decode only the newly generated tokens rather than re-splitting the
    # echoed prompt on the "[ASSISTANT]" tag.
    reply = tok.decode(out[0, prompt_len:], skip_special_tokens=True)
    lines = reply.strip().splitlines()
    return lines[0].strip() if lines else ""

def eval_on_slice(model, tok, ds, n=128):
    """Greedy-decode up to ``n`` rows of ``ds`` and score them against gold answers.

    Each row's ``"text"`` is split at the first ``[ASSISTANT]`` tag. The part up
    to and including the tag is the prompt; the first line after it is the gold
    answer (``""`` if blank). The model is switched to eval mode for the
    duration (so dropout is off) and restored to training mode afterwards if
    it was training.

    Returns:
        dict with
          em:         fraction of predictions exactly equal to the gold line;
          format_ok:  fraction of predictions accepted by ``is_iso_ok``;
          abstain_ok: fraction of rows that are NOT a missed abstention (gold is
                      "ABSTAIN" but the prediction is not). Rows whose gold is
                      not ABSTAIN always count as OK, so false abstentions are
                      not penalised by this metric;
          n:          number of rows scored.
        The three rates are NaN when there is nothing to score.
    """
    n = min(n, len(ds))
    if n <= 0:
        nan = float("nan")
        return {"em": nan, "format_ok": nan, "abstain_ok": nan, "n": 0}

    tag = "[ASSISTANT]\n"
    was_training = getattr(model, "training", False)
    model.eval()
    ems = fmt_ok = abst_ok = 0
    try:
        for i in range(n):
            head, _, tail = ds[i]["text"].partition(tag)
            gold_lines = tail.strip().splitlines()
            gold = gold_lines[0].strip() if gold_lines else ""
            pred = generate_one(model, tok, head + tag)
            ems += pred == gold
            fmt_ok += is_iso_ok(pred)
            # Only a *missed* abstention counts against abstain_ok.
            abst_ok += (gold != "ABSTAIN") or (pred == "ABSTAIN")
    finally:
        if was_training:
            model.train()
    return {"em": ems / n, "format_ok": fmt_ok / n, "abstain_ok": abst_ok / n, "n": n}

class DomainEvalCallback(TrainerCallback):
    """Record domain metrics (EM / format / abstention) to a CSV at each evaluation.

    ``on_evaluate`` scores the whole ``ds_slice`` with ``eval_on_slice`` and
    rewrites ``out_csv`` with every row collected so far. If training ends
    without any evaluation having run, ``on_train_end`` scores up to the first
    64 rows once, so the CSV does not stay header-only.
    """

    def __init__(self, tok, ds_slice, out_csv):
        self.tok = tok
        self.ds_slice = ds_slice
        self.out_csv = out_csv
        self.buffer = []
        # Create the CSV with just its header so the file exists before any eval.
        header_only = pd.DataFrame(columns=["step", "time", "em", "format_ok", "abstain_ok", "n"])
        header_only.to_csv(out_csv, index=False)

    def _score_and_record(self, model, state, n):
        """Score ``n`` rows, append the row to the buffer, rewrite the CSV, return the row."""
        scores = eval_on_slice(model, self.tok, self.ds_slice, n=n)
        record = dict(step=int(state.global_step), time=time.time(), **scores)
        self.buffer.append(record)
        pd.DataFrame(self.buffer).to_csv(self.out_csv, index=False)
        return record

    def on_evaluate(self, args, state, control, **kwargs):
        record = self._score_and_record(kwargs["model"], state, len(self.ds_slice))
        print(f"[domain] step={record['step']} EM={record['em']:.3f} fmt={record['format_ok']:.3f} abst={record['abstain_ok']:.3f}")

    def on_train_end(self, args, state, control, **kwargs):
        if self.buffer:
            return  # at least one evaluation already produced data
        print("[domain] No evaluations during training, running final evaluation...")
        record = self._score_and_record(kwargs["model"], state, min(64, len(self.ds_slice)))
        print(f"[domain] final: EM={record['em']:.3f} fmt={record['format_ok']:.3f} abst={record['abstain_ok']:.3f}")

# Domain metrics are computed on at most the first 256 validation rows.
dom_csv = f"{OUT_DIR}/domain_metrics.csv"
val_slice = val_ds.select(range(min(len(val_ds), 256)))
trainer.add_callback(DomainEvalCallback(tok, val_slice, dom_csv))
# 42 900 s = 11 h 55 min. This is presumably chosen to stay under a 12-hour
# session limit; confirm.
trainer.add_callback(
    TimeLimitCallback(tokenizer=tok, save_dir=OUT_DIR, limit_seconds=42900)
)

# ======================
# 6) Train (SFT QLoRA)
# ======================
trainer.train()

# Persist the adapter weights and the tokenizer next to each other.
trainer.model.save_pretrained(OUT_DIR)
tok.save_pretrained(OUT_DIR)
print("Saved model/adapter to:", OUT_DIR)

# ==============================
# 7) Plot & save monitoring figs
# ==============================
# (metric column, y-axis label for real curves, human-readable title)
_DOMAIN_PLOT_SPECS = (
    ("em", "Exact-Match", "Exact-Match"),
    ("format_ok", "Format OK Rate", "Format Compliance"),
    ("abstain_ok", "Abstention Compliance", "Abstention Compliance"),
)


def _save_domain_figure(out_dir, metric, ylabel, title, draw):
    """Render one 7.5x4.5 figure via ``draw(ax)`` and save it as ``<metric>_curve.png``.

    Axis labels, title, grid and layout are applied here; ``draw`` only adds
    the content. The figure is closed even if drawing or saving raises, so
    failures do not leave figures open in pyplot's global state.

    Returns:
        Path of the written PNG.
    """
    fig, ax = plt.subplots(figsize=(7.5, 4.5))
    try:
        draw(ax)
        ax.set_xlabel("Global Step")
        ax.set_ylabel(ylabel)
        ax.set_title(f"{title} over training")
        ax.grid(True, linewidth=0.4)
        fig.tight_layout()
        path = os.path.join(out_dir, f"{metric}_curve.png")
        fig.savefig(path, dpi=160)
    finally:
        plt.close(fig)
    return path


def _save_placeholder_plots(out_dir, message):
    """Write a placeholder PNG showing ``message`` for every domain metric."""
    def draw(ax):
        ax.text(0.5, 0.5, message, ha="center", va="center", transform=ax.transAxes)

    for metric, _ylabel, title in _DOMAIN_PLOT_SPECS:
        # Placeholders use the title as their y-axis label.
        print("Saved placeholder:", _save_domain_figure(out_dir, metric, title, title, draw))


def plot_domain_curves(csv_path, out_dir):
    """Plot EM, format-OK and abstain-OK against global step from ``csv_path``.

    Writes ``em_curve.png``, ``format_ok_curve.png`` and ``abstain_ok_curve.png``
    into ``out_dir``. If the CSV is missing, zero-byte, or has no data rows,
    placeholder images with an explanatory message are written instead.
    """
    if not os.path.exists(csv_path):
        print(f"Warning: {csv_path} does not exist. Creating empty plots.")
        _save_placeholder_plots(out_dir, f"No data available\n({csv_path} not found)")
        return

    try:
        df = pd.read_csv(csv_path)
    except pd.errors.EmptyDataError:
        # A zero-byte file (not even a header row) carries no data either.
        df = pd.DataFrame()
    if df.empty:
        print(f"{os.path.basename(csv_path)} is empty; creating placeholder plots.")
        _save_placeholder_plots(out_dir, "No evaluation data available")
        return

    for metric, ylabel, title in _DOMAIN_PLOT_SPECS:
        def draw(ax, column=metric):  # bind the current column explicitly
            ax.plot(df["step"], df[column], marker="o")

        print("Saved:", _save_domain_figure(out_dir, metric, ylabel, title, draw))


plot_domain_curves(dom_csv, OUT_DIR)

# ===============================
# 8) TensorBoard (optional)
# ===============================
# %load_ext tensorboard
# %tensorboard --logdir "$OUT_DIR"