# Config-Driven Fine-Tuning of Gemma 3 1B on Kaggle TPU (Tunix/JAX)

This notebook implements an end-to-end **LoRA SFT → GRPO** pipeline for **Gemma 3 1B Instruct** using **Tunix** on **Kaggle TPU**.

Key goals:
- **No hard-coded hyperparameters**: all configuration is read from **`config.yaml`**.
- Reasonable **small defaults** to validate the full pipeline quickly.
- Save a **merged (base + LoRA) Hugging Face model** so you can **skip retraining** and run **inference** immediately.

Prerequisites (Kaggle):
1. Enable **TPU VM** in notebook settings.
2. Attach a **Gemma 3 1B IT** weights dataset under `/kaggle/input/...` (Transformers format).
3. Attach your training data under `/kaggle/input/...` (CSV/Parquet).

Notes:
- The **first GRPO step** includes **XLA compilation** and can be the slowest part of the run.
- If `export/.../merged_lora/` already exists and `skip_training_if_export_exists: true`, training is skipped and the notebook goes straight to inference.


In [None]:

# =========================
# 0) Configuration (config.yaml)
# =========================
# All tunables live in config.yaml. This notebook reads them and assigns variables.
# If config.yaml doesn't exist, we write a small "smoke test" config you can edit.

from __future__ import annotations

from pathlib import Path
from typing import Optional, Literal, List, Dict, Any
import os

# YAML loader
try:
    import yaml  # PyYAML
except Exception as e:
    raise RuntimeError(
        "Missing dependency: pyyaml. Install it with `pip install pyyaml` and re-run."
    ) from e

# Pydantic config models (recommended for validation)
try:
    from pydantic import BaseModel, Field
except Exception as e:
    raise RuntimeError(
        "Missing dependency: pydantic. Install it with `pip install pydantic` and re-run."
    ) from e


DEFAULT_CONFIG_YAML = """
model:
  family: gemma3
  id: google/gemma-3-1b-it
  # If null, we auto-search under /kaggle/input
  local_dir: null
  # Used only when auto-searching; a substring that should appear in the model path
  prefer: gemma-3-1b-it
  dtype: bfloat16
  rank: 64
  alpha: 64.0
  lora_module_path: ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum"
  eos_tokens: [1, 106]

data:
  # If null, we auto-search under /kaggle/input for common train files (train.csv/train.parquet)
  train_path: null
  # Optional: if null we split from train
  eval_path: null
  val_frac: 0.02
  seed: 123
  max_train_rows: 5000
  max_eval_rows: 500
  q_col: null
  a_col: null
  r_col: null
  id_col: null
  system_prompt: "You are a helpful assistant."
  template: "{system_prompt}\n\nQuestion: {question}\n\nAnswer: "

sft:
  enabled: true
  max_seq_len: 512
  global_batch_size: 64
  # learning rate scaling:
  #   lr = lr_ref_global * global_batch_size / lr_ref_batch
  lr_ref_global: 2.0e-5
  lr_ref_batch: 64
  max_steps: 50
  eval_every_steps: 25
  weight_decay: 0.05
  b1: 0.9
  b2: 0.99
  max_examples: 4000

grpo:
  enabled: true
  # Keep these small for pipeline validation
  max_prompt_length: 256
  total_generation_steps: 128
  safety_max_prompt_length: 512
  safety_max_generation_steps: 256
  temperature: 0.9
  top_p: 1.0
  top_k: 50
  num_generations: 8
  num_iterations: 1
  beta: 0.04
  epsilon: 0.2
  max_steps: 300
  eval_every_steps: 50
  warmup_steps: 30
  learning_rate: 5.0e-6
  b1: 0.9
  b2: 0.95
  weight_decay: 0.05
  max_grad_norm: 1.0
  train_micro_batch_size: 4
  rollout_micro_batch_pref: 1
  compute_logps_micro_batch_pref: 1
  offload_to_cpu: false
  kv_cache_extra: 256
  max_train_examples: 2000
  max_eval_examples: 200

export:
  enabled: true
  export_root: "/kaggle/working/export"
  run_name: "google_gemma-3-1b-it_tunix_sft_grpo"
  merged_subdir: "merged_lora"
  save_after_sft: true
  save_after_grpo: true
  save_interval_steps: 100
  max_to_keep: 2

inference:
  enabled: true
  # If null, defaults to export_root/run_name/merged_subdir
  model_dir: null
  max_new_tokens: 256
  do_sample: true

runtime:
  require_tpu: true
  mesh_fsdp: 8
  mesh_tp: 1
  skip_training_if_export_exists: true
  force_retrain: false
  silence_asyncio_noise: true
"""


class ModelCfg(BaseModel):
    family: Literal["gemma3", "gemma2"] = "gemma3"
    id: str = "google/gemma-3-1b-it"
    local_dir: Optional[str] = None
    prefer: Optional[str] = "gemma-3-1b-it"
    dtype: Literal["bfloat16", "float16", "float32"] = "bfloat16"
    rank: int = 64
    alpha: float = 64.0
    lora_module_path: str = ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum"
    eos_tokens: List[int] = Field(default_factory=lambda: [1, 106])


class DataCfg(BaseModel):
    train_path: Optional[str] = None
    eval_path: Optional[str] = None
    val_frac: float = 0.02
    seed: int = 123
    max_train_rows: int = 5000
    max_eval_rows: int = 500
    q_col: Optional[str] = None
    a_col: Optional[str] = None
    r_col: Optional[str] = None
    id_col: Optional[str] = None
    system_prompt: str = "You are a helpful assistant."
    template: str = "{system_prompt}\n\nQuestion: {question}\n\nAnswer: "


class SFTCfg(BaseModel):
    enabled: bool = True
    max_seq_len: int = 512
    global_batch_size: int = 64
    lr_ref_global: float = 2.0e-5
    lr_ref_batch: int = 64
    max_steps: int = 50
    eval_every_steps: int = 25
    weight_decay: float = 0.05
    b1: float = 0.9
    b2: float = 0.99
    max_examples: int = 4000

    def learning_rate(self) -> float:
        # Scale LR linearly with global batch size
        return float(self.lr_ref_global) * float(self.global_batch_size) / float(self.lr_ref_batch)


class GRPOCfg(BaseModel):
    enabled: bool = True
    max_prompt_length: int = 256
    total_generation_steps: int = 128
    safety_max_prompt_length: int = 512
    safety_max_generation_steps: int = 256
    temperature: float = 0.9
    top_p: float = 1.0
    top_k: int = 50
    num_generations: int = 8
    num_iterations: int = 1
    beta: float = 0.04
    epsilon: float = 0.2
    max_steps: int = 300
    eval_every_steps: int = 50
    warmup_steps: int = 30
    learning_rate: float = 5.0e-6
    b1: float = 0.9
    b2: float = 0.95
    weight_decay: float = 0.05
    max_grad_norm: Optional[float] = 1.0
    train_micro_batch_size: int = 4
    rollout_micro_batch_pref: int = 1
    compute_logps_micro_batch_pref: int = 1
    offload_to_cpu: bool = False
    kv_cache_extra: int = 256
    max_train_examples: int = 2000
    max_eval_examples: int = 200


class ExportCfg(BaseModel):
    enabled: bool = True
    export_root: str = "/kaggle/working/export"
    run_name: str = "google_gemma-3-1b-it_tunix_sft_grpo"
    merged_subdir: str = "merged_lora"
    save_after_sft: bool = True
    save_after_grpo: bool = True
    save_interval_steps: int = 100
    max_to_keep: int = 2


class InferenceCfg(BaseModel):
    enabled: bool = True
    model_dir: Optional[str] = None
    max_new_tokens: int = 256
    do_sample: bool = True


class RuntimeCfg(BaseModel):
    require_tpu: bool = True
    mesh_fsdp: int = 8
    mesh_tp: int = 1
    skip_training_if_export_exists: bool = True
    force_retrain: bool = False
    silence_asyncio_noise: bool = True


class NotebookCfg(BaseModel):
    model: ModelCfg = Field(default_factory=ModelCfg)
    data: DataCfg = Field(default_factory=DataCfg)
    sft: SFTCfg = Field(default_factory=SFTCfg)
    grpo: GRPOCfg = Field(default_factory=GRPOCfg)
    export: ExportCfg = Field(default_factory=ExportCfg)
    inference: InferenceCfg = Field(default_factory=InferenceCfg)
    runtime: RuntimeCfg = Field(default_factory=RuntimeCfg)


CONFIG_PATH = Path("config.yaml")

if not CONFIG_PATH.exists():
    CONFIG_PATH.write_text(DEFAULT_CONFIG_YAML.strip() + "\n", encoding="utf-8")
    print(f"✅ Wrote default config.yaml to: {CONFIG_PATH.resolve()}")
    print("   Edit it (especially data.train_path / model.local_dir) and re-run this cell if needed.")

cfg_dict = yaml.safe_load(CONFIG_PATH.read_text(encoding="utf-8"))
cfg = NotebookCfg.model_validate(cfg_dict)

print("✅ Loaded config.yaml")
print("  model:", cfg.model.family, "|", cfg.model.id, "| rank:", cfg.model.rank)
print("  data: train_path:", cfg.data.train_path, "| val_frac:", cfg.data.val_frac)
print("  sft: enabled:", cfg.sft.enabled, "| steps:", cfg.sft.max_steps, "| lr:", cfg.sft.learning_rate())
print("  grpo: enabled:", cfg.grpo.enabled, "| steps:", cfg.grpo.max_steps, "| G:", cfg.grpo.num_generations)
print("  export:", cfg.export.export_root, "/", cfg.export.run_name, "/", cfg.export.merged_subdir)


In [None]:

# =========================
# 1) Assign notebook variables from cfg (compat layer)
# =========================
from pathlib import Path

# Paths
KAGGLE_INPUT = Path("/kaggle/input")
KAGGLE_WORKING = Path("/kaggle/working")

RUN_NAME = cfg.export.run_name

EXPORT_ROOT = Path(cfg.export.export_root)
RUN_EXPORT_DIR = EXPORT_ROOT / RUN_NAME
MERGED_LORA_DIR = RUN_EXPORT_DIR / cfg.export.merged_subdir

TB_DIR = KAGGLE_WORKING / "tensorboard" / RUN_NAME
CKPT_DIR = KAGGLE_WORKING / "checkpoints" / RUN_NAME

for p in [TB_DIR, CKPT_DIR, RUN_EXPORT_DIR]:
    p.mkdir(parents=True, exist_ok=True)

# Model
MODEL_FAMILY = cfg.model.family
MODEL_ID = cfg.model.id
RANK = int(cfg.model.rank)
ALPHA = float(cfg.model.alpha)
LORA_MODULE_PATH = cfg.model.lora_module_path
DTYPE_STR = cfg.model.dtype

# Data
VAL_FRAC = float(cfg.data.val_frac)
q_col = cfg.data.q_col
a_col = cfg.data.a_col
r_col = cfg.data.r_col
id_col = cfg.data.id_col
SYSTEM_PROMPT = cfg.data.system_prompt
TEMPLATE = cfg.data.template

# SFT
RUN_SFT = bool(cfg.sft.enabled)
SFT_MAX_SEQ_LEN = int(cfg.sft.max_seq_len)
SFT_GLOBAL_BATCH_SIZE = int(cfg.sft.global_batch_size)
SFT_LEARNING_RATE = float(cfg.sft.learning_rate())
SFT_MAX_STEPS = int(cfg.sft.max_steps)
SFT_EVAL_EVERY = int(cfg.sft.eval_every_steps)
SFT_WEIGHT_DECAY = float(cfg.sft.weight_decay)
SFT_B1 = float(cfg.sft.b1)
SFT_B2 = float(cfg.sft.b2)
SFT_MAX_EXAMPLES = int(cfg.sft.max_examples)

# GRPO
RUN_GRPO = bool(cfg.grpo.enabled)
MAX_PROMPT_LENGTH = int(cfg.grpo.max_prompt_length)
TOTAL_GENERATION_STEPS = int(cfg.grpo.total_generation_steps)
TEMPERATURE = float(cfg.grpo.temperature)
TOP_P = float(cfg.grpo.top_p)
TOP_K = int(cfg.grpo.top_k)
NUM_GENERATIONS = int(cfg.grpo.num_generations)
NUM_ITERATIONS = int(cfg.grpo.num_iterations)
BETA = float(cfg.grpo.beta)
EPSILON = float(cfg.grpo.epsilon)

MAX_STEPS = int(cfg.grpo.max_steps)
EVAL_EVERY_N_STEPS = int(cfg.grpo.eval_every_steps)
WARMUP_STEPS = int(cfg.grpo.warmup_steps)
LEARNING_RATE = float(cfg.grpo.learning_rate)
B1 = float(cfg.grpo.b1)
B2 = float(cfg.grpo.b2)
WEIGHT_DECAY = float(cfg.grpo.weight_decay)
MAX_GRAD_NORM = cfg.grpo.max_grad_norm

TRAIN_MICRO_BATCH_SIZE = int(cfg.grpo.train_micro_batch_size)
ROLLOUT_MB_PREF = int(cfg.grpo.rollout_micro_batch_pref)
LOGPS_MB_PREF = int(cfg.grpo.compute_logps_micro_batch_pref)
OFFLOAD_TO_CPU = bool(cfg.grpo.offload_to_cpu)

SAFE_KV_CACHE_EXTRA = int(cfg.grpo.kv_cache_extra)
GRPO_SAFETY_MAX_PROMPT_LENGTH = int(cfg.grpo.safety_max_prompt_length)
GRPO_SAFETY_MAX_GEN_STEPS = int(cfg.grpo.safety_max_generation_steps)

GRPO_MAX_TRAIN_EXAMPLES = int(cfg.grpo.max_train_examples)
GRPO_MAX_EVAL_EXAMPLES = int(cfg.grpo.max_eval_examples)

# Runtime
REQUIRE_TPU = bool(cfg.runtime.require_tpu)
MESH_FSDP = int(cfg.runtime.mesh_fsdp)
MESH_TP = int(cfg.runtime.mesh_tp)
SKIP_IF_EXPORT_EXISTS = bool(cfg.runtime.skip_training_if_export_exists)
FORCE_RETRAIN = bool(cfg.runtime.force_retrain)
SILENCE_ASYNCIO_NOISE = bool(cfg.runtime.silence_asyncio_noise)

# Inference
RUN_INFERENCE = bool(cfg.inference.enabled)
INFER_DIR = Path(cfg.inference.model_dir) if cfg.inference.model_dir else MERGED_LORA_DIR
INFER_MAX_NEW_TOKENS = int(cfg.inference.max_new_tokens)
INFER_DO_SAMPLE = bool(cfg.inference.do_sample)

print("✅ Variables assigned from config.yaml via Pydantic.")


In [None]:

# =========================
# 2) Print configuration summary (for reproducibility)
# =========================
import json
from pprint import pprint

print("=== Run Identifiers ===")
print("RUN_NAME:", RUN_NAME)
print("EXPORT_DIR:", RUN_EXPORT_DIR)
print("MERGED_LORA_DIR:", MERGED_LORA_DIR)
print("TB_DIR:", TB_DIR)
print("CKPT_DIR:", CKPT_DIR)

print("\n=== Model ===")
print("MODEL_FAMILY:", MODEL_FAMILY)
print("MODEL_ID:", MODEL_ID)
print("DTYPE:", DTYPE_STR)
print("LoRA rank/alpha:", RANK, "/", ALPHA)
print("LoRA module_path:", LORA_MODULE_PATH)

print("\n=== Data ===")
print("VAL_FRAC:", VAL_FRAC)
print("q_col/a_col/r_col/id_col:", q_col, a_col, r_col, id_col)
print("Max rows (train/eval):", cfg.data.max_train_rows, "/", cfg.data.max_eval_rows)

print("\n=== SFT ===")
print("RUN_SFT:", RUN_SFT)
print("SFT_MAX_SEQ_LEN:", SFT_MAX_SEQ_LEN)
print("SFT_GLOBAL_BATCH_SIZE:", SFT_GLOBAL_BATCH_SIZE)
print("SFT_LEARNING_RATE:", SFT_LEARNING_RATE)
print("SFT_MAX_STEPS:", SFT_MAX_STEPS)
print("SFT_EVAL_EVERY:", SFT_EVAL_EVERY)
print("SFT_MAX_EXAMPLES:", SFT_MAX_EXAMPLES)

print("\n=== GRPO ===")
print("RUN_GRPO:", RUN_GRPO)
print("MAX_PROMPT_LENGTH:", MAX_PROMPT_LENGTH)
print("GRPO_SAFETY_MAX_PROMPT_LENGTH:", GRPO_SAFETY_MAX_PROMPT_LENGTH)
print("TOTAL_GENERATION_STEPS:", TOTAL_GENERATION_STEPS)
print("GRPO_SAFETY_MAX_GEN_STEPS:", GRPO_SAFETY_MAX_GEN_STEPS)
print("TEMPERATURE/TOP_P/TOP_K:", TEMPERATURE, TOP_P, TOP_K)
print("NUM_GENERATIONS (G):", NUM_GENERATIONS)
print("NUM_ITERATIONS (μ):", NUM_ITERATIONS)
print("BETA/EPSILON:", BETA, EPSILON)
print("MAX_STEPS:", MAX_STEPS)
print("EVAL_EVERY_N_STEPS:", EVAL_EVERY_N_STEPS)
print("TRAIN_MICRO_BATCH_SIZE:", TRAIN_MICRO_BATCH_SIZE)
print("ROLLOUT_MB_PREF / LOGPS_MB_PREF:", ROLLOUT_MB_PREF, "/", LOGPS_MB_PREF)
print("OFFLOAD_TO_CPU:", OFFLOAD_TO_CPU)

print("\n=== Inference ===")
print("RUN_INFERENCE:", RUN_INFERENCE)
print("INFER_DIR:", INFER_DIR)
print("INFER_MAX_NEW_TOKENS:", INFER_MAX_NEW_TOKENS)
print("INFER_DO_SAMPLE:", INFER_DO_SAMPLE)

print("\n=== Full config (dict) ===")
pprint(cfg.model_dump())


In [None]:

# =========================
# 3) Runtime setup (versions, seeding, logging)
# =========================
import os
import random
import numpy as np

# Reduce noisy logs (esp. ipykernel asyncio)
import logging
import warnings

if SILENCE_ASYNCIO_NOISE:
    logging.getLogger("asyncio").setLevel(logging.CRITICAL)
    warnings.filterwarnings("ignore", message=r".*coroutine.*was never awaited.*")
    warnings.filterwarnings("ignore", message=r".*Task was destroyed but it is pending.*")

os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

def seed_everything(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    try:
        import jax
        jax.random.PRNGKey(seed)  # warms JAX RNG
    except Exception:
        pass

seed_everything(cfg.data.seed)

import jax
import jax.numpy as jnp

print("Python:", os.sys.version)
print("JAX:", jax.__version__)
print("Backend:", jax.default_backend())
print("Devices:", len(jax.devices()), jax.devices()[:1])

if REQUIRE_TPU and jax.default_backend() != "tpu":
    raise RuntimeError(
        "This run requires TPU, but JAX backend is not TPU. "
        "Enable TPU VM in Kaggle Notebook Settings, then restart the session."
    )

def show_hbm_usage() -> None:
    """Best-effort per-device memory report (works on TPU in most Kaggle images)."""
    for d in jax.devices():
        try:
            stats = d.memory_stats()
            used = stats.get("bytes_in_use", 0) / (1024**3)
            limit = stats.get("bytes_limit", 0) / (1024**3)
            print(f"{d}: {used:.2f} GiB / {limit:.2f} GiB ({(used/limit*100) if limit else 0:.1f}%)")
        except Exception:
            print(f"{d}: memory_stats() unavailable")


In [None]:

# =========================
# (Optional) IPython display helpers
# =========================
from IPython.display import display


In [None]:

# =========================
# 3.1) Dependency check (expected in Kaggle TPU images where Tunix is preinstalled)
# =========================
import importlib

required_modules = [
    "tunix",
    "qwix",
    "grain",
    "orbax.checkpoint",
    "transformers",
]
missing = []
for m in required_modules:
    try:
        importlib.import_module(m)
    except Exception:
        missing.append(m)

if missing:
    raise RuntimeError(
        "Missing required Python packages: "
        + ", ".join(missing)
        + "\nIf you're on Kaggle with internet enabled, install with:\n"
        + "  pip install -q google-tunix[prod]==0.1.3 qwix grain orbax-checkpoint transformers\n"
    )

print("✅ Dependencies look available.")


In [None]:

# =========================
# 4) Create TPU mesh (AUTO axis types) — matches Tunix demo expectations
# =========================
from jax.sharding import AxisType

assert len(jax.devices()) >= (MESH_FSDP * MESH_TP), (
    f"Need at least {MESH_FSDP*MESH_TP} devices for mesh, found {len(jax.devices())}."
)

if hasattr(jax, "make_mesh"):
    mesh = jax.make_mesh(
        (MESH_FSDP, MESH_TP),
        ("fsdp", "tp"),
        axis_types=(AxisType.Auto, AxisType.Auto),
    )
else:
    # Fallback: older JAX (axis_types may not be supported; if you hit sharding issues, upgrade JAX)
    from jax.experimental import mesh_utils
    from jax.sharding import Mesh
    mesh = Mesh(mesh_utils.create_device_mesh((MESH_FSDP, MESH_TP)), ("fsdp", "tp"))

print("✅ Mesh:", mesh)


## Data loading and preparation

In [None]:

# =========================
# 5) Locate + load training data (CSV/Parquet)
# =========================
from pathlib import Path
import pandas as pd

def find_first_file(root: Path, patterns: list[str]) -> Path | None:
    root = Path(root)
    for pat in patterns:
        hits = sorted(root.rglob(pat))
        if hits:
            return hits[0]
    return None

def load_table(path: Path) -> pd.DataFrame:
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Dataset file not found: {path}")
    if path.suffix.lower() == ".csv":
        return pd.read_csv(path)
    if path.suffix.lower() in [".parquet", ".pq"]:
        return pd.read_parquet(path)
    raise ValueError(f"Unsupported file type: {path.suffix} (use CSV or Parquet)")

# Pick train/eval paths
train_path = Path(cfg.data.train_path) if cfg.data.train_path else None
eval_path  = Path(cfg.data.eval_path)  if cfg.data.eval_path  else None

if train_path is None:
    # Auto-search common patterns in Kaggle inputs
    train_path = find_first_file(
        KAGGLE_INPUT,
        patterns=[
            "**/train.csv",
            "**/train.parquet",
            "**/training.csv",
            "**/training.parquet",
            "**/*train*.csv",
            "**/*train*.parquet",
        ],
    )

if train_path is None:
    raise FileNotFoundError(
        "Could not auto-locate a training file under /kaggle/input. "
        "Set data.train_path in config.yaml to an explicit CSV/Parquet path."
    )

train_df = load_table(train_path)
print("✅ Loaded train_df:", train_path)
print("  rows:", len(train_df), "| cols:", len(train_df.columns))
display(train_df.head(3))

# Optional: cap rows for quick pipeline validation
if cfg.data.max_train_rows and len(train_df) > cfg.data.max_train_rows:
    train_df = train_df.sample(n=cfg.data.max_train_rows, random_state=cfg.data.seed).reset_index(drop=True)
    print(f"✅ Capped train_df to max_train_rows={cfg.data.max_train_rows}. New rows:", len(train_df))

if eval_path is not None and eval_path.exists():
    eval_df = load_table(eval_path)
    print("✅ Loaded eval_df:", eval_path, "| rows:", len(eval_df))
    if cfg.data.max_eval_rows and len(eval_df) > cfg.data.max_eval_rows:
        eval_df = eval_df.sample(n=cfg.data.max_eval_rows, random_state=cfg.data.seed).reset_index(drop=True)
        print(f"✅ Capped eval_df to max_eval_rows={cfg.data.max_eval_rows}. New rows:", len(eval_df))
else:
    eval_df = None
    print("ℹ️ eval_path not provided (or missing). We'll create val split from train_df.")


In [None]:

# =========================
# 6) Infer dataset schema (q_col / a_col / r_col / id_col)
# =========================
import re

def to_str(x) -> str:
    if x is None:
        return ""
    if isinstance(x, str):
        return x
    return str(x)

def guess_column(columns: list[str], candidates: list[str]) -> str | None:
    cols = {c.lower(): c for c in columns}
    for cand in candidates:
        if cand.lower() in cols:
            return cols[cand.lower()]
    # fuzzy: contains substring
    for c in columns:
        cl = c.lower()
        for cand in candidates:
            if cand.lower() in cl:
                return c
    return None

cols = list(train_df.columns)

# Apply overrides if user set them in config
if q_col is None:
    q_col = guess_column(cols, ["question", "prompt", "query", "input", "instruction"])
if a_col is None:
    a_col = guess_column(cols, ["answer", "output", "response", "completion", "solution"])
if r_col is None:
    r_col = guess_column(cols, ["reward", "score", "label", "target"])
if id_col is None:
    id_col = guess_column(cols, ["id", "uid", "uuid", "task_id", "example_id"])

print("✅ Column mapping:")
print("  q_col:", q_col)
print("  a_col:", a_col)
print("  r_col:", r_col)
print("  id_col:", id_col)

if q_col is None:
    raise ValueError(
        "Could not infer q_col (question/prompt). "
        "Set data.q_col in config.yaml to the correct column name."
    )

# Basic sanity sample
sample_rows = train_df[[q_col] + ([a_col] if a_col else [])].head(3)
display(sample_rows)


In [None]:

# =========================
# 7) Train/val split (if eval_df not provided)
# =========================
import numpy as np

if eval_df is None:
    rng = np.random.default_rng(cfg.data.seed)
    idx = rng.permutation(len(train_df))
    n_val = max(1, int(len(idx) * VAL_FRAC))
    val_idx = idx[:n_val]
    tr_idx  = idx[n_val:]
    val_df = train_df.iloc[val_idx].reset_index(drop=True)
    train_df_split = train_df.iloc[tr_idx].reset_index(drop=True)
else:
    train_df_split = train_df
    val_df = eval_df

print("✅ Split sizes:")
print("  train_df_split:", len(train_df_split))
print("  val_df:", len(val_df))


In [None]:

# =========================
# 8) Dataset diagnostics (what will be used and how)
# =========================
from collections import Counter

def build_prompt(q: str) -> str:
    return TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=str(q))

# How much data will SFT / GRPO actually consume?
sft_n = min(SFT_MAX_EXAMPLES, len(train_df_split)) if RUN_SFT else 0
grpo_train_n = min(GRPO_MAX_TRAIN_EXAMPLES, len(train_df_split)) if RUN_GRPO else 0
grpo_val_n = min(GRPO_MAX_EVAL_EXAMPLES, len(val_df)) if RUN_GRPO else 0

print("=== Dataset usage plan ===")
print(f"Total train rows available: {len(train_df_split)}")
print(f"Total val rows available:   {len(val_df)}")
print("")
print("SFT seeding:")
print(f"  RUN_SFT={RUN_SFT}")
print(f"  Using up to: {sft_n} examples (cap: SFT_MAX_EXAMPLES={SFT_MAX_EXAMPLES})")
print(f"  Technique: supervised fine-tuning (teacher forcing) with LoRA adapters")
print("")
print("GRPO:")
print(f"  RUN_GRPO={RUN_GRPO}")
print(f"  Using up to: {grpo_train_n} train examples (cap: GRPO_MAX_TRAIN_EXAMPLES={GRPO_MAX_TRAIN_EXAMPLES})")
print(f"  Using up to: {grpo_val_n}   eval examples (cap: GRPO_MAX_EVAL_EXAMPLES={GRPO_MAX_EVAL_EXAMPLES})")
print(f"  Technique: reinforcement learning (GRPO) with num_generations={NUM_GENERATIONS}, num_iterations={NUM_ITERATIONS}")
print("")

# Prompt length sanity (character-based, cheap)
if len(train_df_split) > 0:
    sample_prompts = [build_prompt(to_str(x)) for x in train_df_split[q_col].head(min(200, len(train_df_split)))]
    lengths = np.array([len(p) for p in sample_prompts])
    print("Prompt length (chars) on a small sample:")
    print("  min/median/max:", int(lengths.min()), int(np.median(lengths)), int(lengths.max()))


## Model loading (Gemma + Tunix) and LoRA setup

In [None]:

# =========================
# 9) Locate Gemma model directory (Transformers-style)
# =========================
from pathlib import Path

def find_model_dir(root: Path, prefer_substring: str | None = None) -> Path | None:
    """
    Tries to locate a local HuggingFace Transformers model directory under `root`
    by searching for config.json and one or more *.safetensors files.
    """
    root = Path(root)
    candidates: list[Path] = []
    for cfg_path in root.rglob("config.json"):
        d = cfg_path.parent
        has_weights = any(d.glob("*.safetensors")) or (d / "model.safetensors").exists()
        has_tokenizer = (d / "tokenizer.model").exists() or (d / "tokenizer.json").exists()
        if has_weights and has_tokenizer:
            candidates.append(d)

    if not candidates:
        return None

    if prefer_substring:
        prefer = prefer_substring.lower()
        preferred = [c for c in candidates if prefer in str(c).lower()]
        if preferred:
            # prefer shortest path (more specific)
            preferred.sort(key=lambda p: len(str(p)))
            return preferred[0]

    candidates.sort(key=lambda p: len(str(p)))
    return candidates[0]

local_model_path = Path(cfg.model.local_dir) if cfg.model.local_dir else None
if local_model_path is None:
    local_model_path = find_model_dir(KAGGLE_INPUT, prefer_substring=cfg.model.prefer)

if local_model_path is None or not local_model_path.exists():
    raise FileNotFoundError(
        "Could not locate Gemma weights under /kaggle/input. "
        "Set model.local_dir in config.yaml to the exact Transformers directory "
        "(the folder containing config.json + *.safetensors + tokenizer.*)."
    )

print("✅ local_model_path:", local_model_path)

tokenizer_path = local_model_path / "tokenizer.model"
if not tokenizer_path.exists():
    # Some models ship tokenizer.json instead; tokenizer_adapter supports tokenizer.model best.
    print("⚠️ tokenizer.model not found. tokenizer_adapter may still work if tokenizer.json exists.")


In [None]:

# =========================
# 10) Load base model on TPU (Tunix + safe tensors)
# =========================
import qwix
from flax import nnx

# Tunix imports (Gemma family-specific)
if MODEL_FAMILY == "gemma3":
    from tunix.models.gemma3 import model as gemma_lib
    from tunix.models.gemma3 import params_safetensors as params_safetensors_lib
    from tunix.models.gemma3 import params as gemma_params
elif MODEL_FAMILY == "gemma2":
    from tunix.models.gemma2 import model as gemma_lib
    from tunix.models.gemma2 import params as gemma_params
    # Gemma2 typically uses orbax checkpoints in Tunix demos; safe tensors support may differ.
    raise NotImplementedError("This notebook currently focuses on gemma3 safe-tensors loading.")
else:
    raise ValueError(f"Unsupported MODEL_FAMILY={MODEL_FAMILY}")

from tunix.generate import tokenizer_adapter as tokenizer_lib

# dtype mapping
DTYPE_MAP = {
    "bfloat16": jnp.bfloat16,
    "float16": jnp.float16,
    "float32": jnp.float32,
}
DTYPE = DTYPE_MAP[DTYPE_STR]

print("\n--- HBM BEFORE model load ---")
show_hbm_usage()

# Model config for Gemma3-1B
model_config = gemma_lib.ModelConfig.gemma3_1b()

with mesh:
    base_model = params_safetensors_lib.create_model_from_safe_tensors(
        str(local_model_path),
        model_config,
        mesh,
        dtype=DTYPE,
    )

print("\n--- HBM AFTER base model load ---")
show_hbm_usage()

# Tokenizer (prefer local tokenizer.model)
tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=str(tokenizer_path))

# EOS tokens
EOS_TOKENS = list(dict.fromkeys([*cfg.model.eos_tokens]))
print("EOS_TOKENS:", EOS_TOKENS)


In [None]:

# =========================
# 11) Apply LoRA (critical: use model-provided sharded inputs)
# =========================
lora_provider = qwix.LoraProvider(
    module_path=LORA_MODULE_PATH,
    rank=RANK,
    alpha=ALPHA,
)

with mesh:
    # Model-provided inputs ensure correct sharding for this model+mesh (prevents gather sharding ambiguity)
    model_input = base_model.get_model_input()
    lora_policy = qwix.apply_lora_to_model(
        base_model,
        lora_provider,
        rngs=nnx.Rngs(params=0, lora=1),
        **model_input,
    )

print("✅ LoRA model created.")
print("\n--- HBM AFTER LoRA apply ---")
show_hbm_usage()


In [None]:

# =========================
# 12) Skip logic: if exported merged model exists, skip training
# =========================
def export_is_present(model_dir: Path) -> bool:
    if not model_dir.exists():
        return False
    has_weights = any(model_dir.glob("*.safetensors")) or (model_dir / "model.safetensors").exists()
    has_cfg = (model_dir / "config.json").exists()
    has_tokenizer = (model_dir / "tokenizer.json").exists() or (model_dir / "tokenizer.model").exists()
    return bool(has_weights and has_cfg and has_tokenizer)

EXPORT_PRESENT = export_is_present(MERGED_LORA_DIR)
print("EXPORT_PRESENT:", EXPORT_PRESENT, "|", MERGED_LORA_DIR)

if FORCE_RETRAIN:
    print("⚠️ force_retrain=true: will retrain even if export exists.")


## Stage 1 — LoRA SFT (format seeding)

In [None]:
# =========================
# 13) Build SFT dataset (TrainingInput) — tokenized + masked
# =========================
import numpy as np
import grain
from typing import List

from transformers import AutoTokenizer

# --- FIX: Correct imports for Tunix SFT ---
# sft_utils is not a direct module; we import 'utils' and alias it.
from tunix.sft import peft_trainer
from tunix.sft import utils as sft_utils 

# If your dataset has no answers, SFT is not meaningful.
if a_col is None:
    print("⚠️ a_col is None — SFT will be skipped (no supervised targets).")
    sft_ds = None
else:
    # Local tokenizer (HF) for easy encoding; must be local_files_only=True in Kaggle offline mode.
    hf_tok = AutoTokenizer.from_pretrained(str(local_model_path), local_files_only=True)
    if hf_tok.pad_token_id is None:
        # Gemma often has no pad token by default; using EOS as pad is a common workaround.
        hf_tok.pad_token = hf_tok.eos_token

    def tokenize_sft(prompt_text: str, answer_text: str) -> peft_trainer.TrainingInput:
        # Full text contains prompt + answer + EOS
        prompt_ids = hf_tok(prompt_text, add_special_tokens=False)["input_ids"]
        full_text = prompt_text + to_str(answer_text) + hf_tok.eos_token
        full_ids = hf_tok(full_text, add_special_tokens=False)["input_ids"]

        # Mask: 0 for prompt, 1 for answer region
        start = min(len(prompt_ids), len(full_ids))
        mask = [0] * start + [1] * (len(full_ids) - start)

        # Truncate
        full_ids = full_ids[:SFT_MAX_SEQ_LEN]
        mask = mask[:SFT_MAX_SEQ_LEN]

        # Pad
        pad_len = SFT_MAX_SEQ_LEN - len(full_ids)
        if pad_len > 0:
            full_ids += [hf_tok.pad_token_id] * pad_len
            mask += [0] * pad_len

        return peft_trainer.TrainingInput(
            input_tokens=np.asarray(full_ids, dtype=np.int32),
            input_mask=np.asarray(mask, dtype=np.int32),
        )

    # Pick subset for SFT seeding
    # NOTE: Ensure 'train_df_split' and 'SFT_MAX_EXAMPLES' are defined above. 
    # If not, replace 'train_df_split' with 'train_df' and 'SFT_MAX_EXAMPLES' with 5000.
    sft_n = min(SFT_MAX_EXAMPLES, len(train_df_split))
    sft_rows = train_df_split.head(sft_n)

    sft_inputs: List[peft_trainer.TrainingInput] = []
    for _, row in sft_rows.iterrows():
        q = to_str(row[q_col])
        a = to_str(row[a_col])
        prompt = TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=q)
        sft_inputs.append(tokenize_sft(prompt, a))

    print(f"✅ Built SFT inputs: {len(sft_inputs)} examples | seq_len={SFT_MAX_SEQ_LEN}")

    # Grain dataset: shuffle + batch + repeat
    # Fallback to seed=42 if 'cfg' is not defined
    _seed = cfg.data.seed if 'cfg' in globals() else 42
    sft_ds = grain.MapDataset.source(sft_inputs).shuffle(seed=_seed).batch(SFT_GLOBAL_BATCH_SIZE).repeat()

In [None]:

# =========================
# 14) TRAINING: Stage 1 — LoRA SFT seeding
# =========================
import optax
from tunix.sft import metrics_logger

RUN_SFT_EFFECTIVE = RUN_SFT and (sft_ds is not None) and not (EXPORT_PRESENT and SKIP_IF_EXPORT_EXISTS and not FORCE_RETRAIN)

if not RUN_SFT_EFFECTIVE:
    print("⚠️ Skipping SFT seeding.")
    if not RUN_SFT:
        print("  - cfg.sft.enabled=false")
    elif sft_ds is None:
        print("  - sft_ds is None (likely missing answers)")
    elif EXPORT_PRESENT and SKIP_IF_EXPORT_EXISTS and not FORCE_RETRAIN:
        print("  - Exported model already present and skip_training_if_export_exists=true")
else:
    SFT_CKPT_DIR = CKPT_DIR / "sft_seed"
    SFT_CKPT_DIR.mkdir(parents=True, exist_ok=True)

    sft_logging_options = metrics_logger.MetricsLoggerOptions(
        log_dir=str(TB_DIR / "sft_seed"),
        flush_every_n_steps=10,
    )

    training_config = peft_trainer.TrainingConfig(
        eval_every_n_steps=SFT_EVAL_EVERY,
        max_steps=SFT_MAX_STEPS,
        metrics_logging_options=sft_logging_options,
        checkpoint_root_directory=str(SFT_CKPT_DIR),
    )

    def gen_model_input_fn(x: peft_trainer.TrainingInput):
        pad_mask = x.input_tokens != hf_tok.pad_token_id
        positions = sft_utils.build_positions_from_mask(pad_mask)
        attention_mask = sft_utils.make_causal_attn_mask(pad_mask)
        return {
            "input_tokens": x.input_tokens,
            "input_mask": x.input_mask,
            "positions": positions,
            "attention_mask": attention_mask,
        }

    sft_optimizer = optax.adamw(
        learning_rate=SFT_LEARNING_RATE,
        b1=SFT_B1,
        b2=SFT_B2,
        weight_decay=SFT_WEIGHT_DECAY,
    )

    sft_trainer = (
        peft_trainer.PeftTrainer(lora_policy, sft_optimizer, training_config)
        .with_gen_model_input_fn(gen_model_input_fn)
    )

    print("Starting SFT seeding...")
    print("Backend:", jax.default_backend())
    print("SFT_LEARNING_RATE:", SFT_LEARNING_RATE, "(global batch:", SFT_GLOBAL_BATCH_SIZE, ")")
    print("SFT_MAX_STEPS:", SFT_MAX_STEPS)

    with mesh:
        sft_trainer.train(sft_ds, None)

    print("✅ SFT seeding complete.")


In [None]:

# =========================
# 15) Export merged model (base + LoRA) for inference (HF-compatible directory)
# =========================
import shutil
import json
import time

def export_merged_lora(output_dir: Path) -> Path:
    """
    Writes a HuggingFace-style directory with merged LoRA weights.
    Returns the output_dir path.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Merge base + lora and save as safetensors
    # (Gemma3 helper provided by Tunix)
    gemma_params.save_lora_merged_model_as_safetensors(
        str(local_model_path),
        str(output_dir),
        lora_policy,
        rank=RANK,
        alpha=ALPHA,
    )

    # Copy tokenizer/config artifacts expected by HF loaders
    for fname in [
        "config.json",
        "generation_config.json",
        "tokenizer.model",
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "added_tokens.json",
    ]:
        src = local_model_path / fname
        dst = output_dir / fname
        if src.exists() and not dst.exists():
            shutil.copy2(src, dst)

    # Save run metadata
    meta = {
        "run_name": RUN_NAME,
        "model_id": MODEL_ID,
        "local_model_path": str(local_model_path),
        "dtype": DTYPE_STR,
        "lora_rank": RANK,
        "lora_alpha": ALPHA,
        "sft": cfg.sft.model_dump(),
        "grpo": cfg.grpo.model_dump(),
        "data": {
            "train_path": str(train_path),
            "eval_path": str(eval_path) if eval_path else None,
            "q_col": q_col,
            "a_col": a_col,
            "r_col": r_col,
            "id_col": id_col,
            "train_rows_used": len(train_df_split),
            "val_rows_used": len(val_df),
        },
        "timestamp_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    }
    (output_dir / "tunix_run_metadata.json").write_text(json.dumps(meta, indent=2), encoding="utf-8")

    # Also store the exact config.yaml used
    shutil.copy2(CONFIG_PATH, output_dir / "config.yaml")

    return output_dir

if cfg.export.enabled and cfg.export.save_after_sft and RUN_SFT_EFFECTIVE:
    print("Exporting merged model after SFT to:", MERGED_LORA_DIR)
    out_dir = export_merged_lora(MERGED_LORA_DIR)
    print("✅ Export complete:", out_dir)
else:
    print("ℹ️ Skipping export after SFT (either disabled or SFT didn't run).")


## Stage 2 — GRPO (reinforcement learning)

In [None]:

# =========================
# 16) Pre-GRPO dataset report (sizes, sampling, what will be used)
# =========================
print("=== Pre-GRPO dataset report ===")
print("Technique: GRPO (Group Relative Policy Optimization) on-policy RL with num_generations samples per prompt.")
print("")

print("Data sources:")
print("  train_df_split rows:", len(train_df_split))
print("  val_df rows:", len(val_df))
print("")

print("Configured caps:")
print("  GRPO_MAX_TRAIN_EXAMPLES:", GRPO_MAX_TRAIN_EXAMPLES)
print("  GRPO_MAX_EVAL_EXAMPLES:", GRPO_MAX_EVAL_EXAMPLES)
print("")

train_use = min(GRPO_MAX_TRAIN_EXAMPLES, len(train_df_split))
val_use   = min(GRPO_MAX_EVAL_EXAMPLES, len(val_df))

print("Will use:")
print("  train examples:", train_use, f"({train_use/len(train_df_split)*100:.1f}% of available)" if len(train_df_split) else "")
print("  eval examples:",  val_use,   f"({val_use/len(val_df)*100:.1f}% of available)" if len(val_df) else "")
print("")

print("Batching:")
print("  TRAIN_MICRO_BATCH_SIZE:", TRAIN_MICRO_BATCH_SIZE)
print("  NUM_GENERATIONS:", NUM_GENERATIONS, "=> full_batch_size =", TRAIN_MICRO_BATCH_SIZE * NUM_GENERATIONS)
print("  MAX_PROMPT_LENGTH:", MAX_PROMPT_LENGTH)
print("  TOTAL_GENERATION_STEPS:", TOTAL_GENERATION_STEPS)
print("")

# Show a couple of samples
if len(train_df_split) > 0:
    _tmp = train_df_split[[q_col] + ([a_col] if a_col else [])].head(2)
    display(_tmp)


In [None]:

# =========================
# 17) GRPO reward functions (minimal, configurable)
# =========================
import numpy as np
import re

def normalize_answer(s: str) -> str:
    s = (s or "").strip()
    s = re.sub(r"\s+", " ", s)
    return s

def reward_format(prompts, completions, **kwargs):
    """
    Reward +1 if completion contains something that looks like a final answer.
    You should customize this to your task (e.g., JSON format, XML tags, etc.).
    """
    rewards = []
    for c in completions:
        c = c or ""
        has_text = len(c.strip()) > 0
        rewards.append(1.0 if has_text else 0.0)
    return np.asarray(rewards, dtype=np.float32)

def reward_answer(prompts, completions, answer=None, **kwargs):
    """
    If ground-truth answers are present, reward exact normalized match.
    If no answer column exists, returns zeros.
    """
    if answer is None:
        return np.zeros(len(completions), dtype=np.float32)

    rewards = []
    for c, a in zip(completions, answer):
        c_norm = normalize_answer(c)
        a_norm = normalize_answer(a)
        rewards.append(1.0 if (a_norm and a_norm in c_norm) else 0.0)
    return np.asarray(rewards, dtype=np.float32)

REWARD_FNS = [reward_format, reward_answer]
print("✅ Reward functions registered:", [fn.__name__ for fn in REWARD_FNS])


In [None]:
# =========================
# 18) TRAINING: Stage 2 — GRPO (batch-safe + OOM-safe defaults)
# =========================
import math
import numpy as np
import grain
import optax
import orbax.checkpoint as ocp

from tunix.sft import metrics_logger

from tunix.rl import rl_cluster as rl_cluster_lib
# --- FIX: Import base_rollout from the correct submodule ---
from tunix.rl.rollout import base_rollout
from tunix.rl.grpo.grpo_learner import GRPOLearner, GRPOConfig

def pick_divisor_microbatch(full_batch_size: int, preferred: int) -> int:
    """Largest <= preferred divisor of full_batch_size."""
    preferred = max(1, int(preferred))
    for mb in range(preferred, 0, -1):
        if full_batch_size % mb == 0:
            return mb
    return 1

RUN_GRPO_EFFECTIVE = RUN_GRPO and not (EXPORT_PRESENT and SKIP_IF_EXPORT_EXISTS and not FORCE_RETRAIN)

if not RUN_GRPO_EFFECTIVE:
    print("⚠️ Skipping GRPO training.")
    if not RUN_GRPO:
        print("  - cfg.grpo.enabled=false")
    elif EXPORT_PRESENT and SKIP_IF_EXPORT_EXISTS and not FORCE_RETRAIN:
        print("  - Exported model already present and skip_training_if_export_exists=true")
else:
    # ------------------------------------------------------------
    # A) Build train_dataset / val_dataset for GRPO (finite + repeat)
    # ------------------------------------------------------------
    # Cap examples for quick testing
    train_use = min(GRPO_MAX_TRAIN_EXAMPLES, len(train_df_split))
    val_use   = min(GRPO_MAX_EVAL_EXAMPLES, len(val_df))

    def build_prompt(q: str) -> str:
        return TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=str(q))

    def make_example_from_row(row) -> dict:
        q = to_str(row[q_col])
        a = to_str(row[a_col]) if a_col is not None else ""
        return {"prompts": build_prompt(q), "question": q, "answer": a}

    train_examples = [make_example_from_row(r) for _, r in train_df_split.head(train_use).iterrows()]
    val_examples   = [make_example_from_row(r) for _, r in val_df.head(val_use).iterrows()]

    # Use cfg.data.seed if available, else default to 42
    _seed = cfg.data.seed if 'cfg' in globals() else 42
    train_base = grain.MapDataset.source(train_examples).shuffle(seed=_seed).batch(TRAIN_MICRO_BATCH_SIZE)
    val_base   = grain.MapDataset.source(val_examples).batch(TRAIN_MICRO_BATCH_SIZE)

    steps_per_epoch = len(train_base)
    need_epochs = max(1, math.ceil(MAX_STEPS / max(1, steps_per_epoch)) + 1)

    train_dataset = train_base.repeat(need_epochs)
    val_dataset   = val_base.repeat(max(1, math.ceil((MAX_STEPS / max(1, len(val_base))) * 0.2) + 1))

    print("✅ Built GRPO datasets")
    print("  train examples:", len(train_examples), "| batches/epoch:", len(train_base))
    print("  val examples:", len(val_examples), "| batches:", len(val_base))
    print("  MAX_STEPS:", MAX_STEPS, "| repeated epochs:", need_epochs)

    # ------------------------------------------------------------
    # B) Length clamp (compile-time safety)
    # ------------------------------------------------------------
    SAFE_MAX_PROMPT_LENGTH = min(int(MAX_PROMPT_LENGTH), int(GRPO_SAFETY_MAX_PROMPT_LENGTH))
    SAFE_GEN_STEPS = min(int(TOTAL_GENERATION_STEPS), int(GRPO_SAFETY_MAX_GEN_STEPS))
    SAFE_KV_CACHE = SAFE_MAX_PROMPT_LENGTH + SAFE_GEN_STEPS + SAFE_KV_CACHE_EXTRA

    print("GRPO length clamp:")
    print("  MAX_PROMPT_LENGTH:", MAX_PROMPT_LENGTH, "->", SAFE_MAX_PROMPT_LENGTH)
    print("  TOTAL_GENERATION_STEPS:", TOTAL_GENERATION_STEPS, "->", SAFE_GEN_STEPS)
    print("  kv_cache_size:", SAFE_KV_CACHE)

    # ------------------------------------------------------------
    # C) Batch sizing constraints (divisibility)
    # ------------------------------------------------------------
    full_batch_size = int(TRAIN_MICRO_BATCH_SIZE) * int(NUM_GENERATIONS)

    roll_mb = pick_divisor_microbatch(full_batch_size, preferred=ROLLOUT_MB_PREF)
    logp_mb = pick_divisor_microbatch(full_batch_size, preferred=LOGPS_MB_PREF)

    print("GRPO batch sizing:")
    print("  TRAIN_MICRO_BATCH_SIZE:", TRAIN_MICRO_BATCH_SIZE)
    print("  NUM_GENERATIONS:", NUM_GENERATIONS)
    print("  full_batch_size:", full_batch_size)
    print("  rollout_micro_batch_size:", roll_mb)
    print("  compute_logps_micro_batch_size:", logp_mb)

    # ------------------------------------------------------------
    # D) Optimizer + configs
    # ------------------------------------------------------------
    checkpointing_options = ocp.CheckpointManagerOptions(
        save_interval_steps=int(cfg.export.save_interval_steps),
        max_to_keep=int(cfg.export.max_to_keep),
    )
    metrics_logging_options = metrics_logger.MetricsLoggerOptions(
        log_dir=str(TB_DIR / "grpo"),
        flush_every_n_steps=10,
    )

    lr_schedule = optax.schedules.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        decay_steps=MAX_STEPS,
        end_value=0.0,
    )

    optimizer = optax.adamw(
        learning_rate=lr_schedule,
        b1=B1,
        b2=B2,
        weight_decay=WEIGHT_DECAY,
    )
    if MAX_GRAD_NORM is not None:
        optimizer = optax.chain(optax.clip_by_global_norm(MAX_GRAD_NORM), optimizer)

    cluster_config = rl_cluster_lib.ClusterConfig(
        role_to_mesh={
            rl_cluster_lib.Role.ACTOR: mesh,
            rl_cluster_lib.Role.REFERENCE: mesh,
            rl_cluster_lib.Role.ROLLOUT: mesh,
        },
        rollout_engine="vanilla",
        offload_to_cpu=OFFLOAD_TO_CPU,
        training_config=rl_cluster_lib.RLTrainingConfig(
            actor_optimizer=optimizer,
            eval_every_n_steps=EVAL_EVERY_N_STEPS,
            max_steps=MAX_STEPS,
            mini_batch_size=TRAIN_MICRO_BATCH_SIZE,
            train_micro_batch_size=TRAIN_MICRO_BATCH_SIZE,
            rollout_micro_batch_size=roll_mb,
            compute_logps_micro_batch_size=logp_mb,
            data_sharding_axis=("fsdp",),
            metrics_logging_options=metrics_logging_options,
            checkpoint_root_directory=str(CKPT_DIR / "grpo"),
            checkpointing_options=checkpointing_options,
        ),
        rollout_config=base_rollout.RolloutConfig(
            max_tokens_to_generate=SAFE_GEN_STEPS,
            max_prompt_length=SAFE_MAX_PROMPT_LENGTH,
            kv_cache_size=SAFE_KV_CACHE,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            top_k=TOP_K,
            eos_tokens=EOS_TOKENS,
        ),
    )

    algo_config = GRPOConfig(
        num_generations=NUM_GENERATIONS,
        num_iterations=NUM_ITERATIONS,
        beta=BETA,
        epsilon=EPSILON,
    )

    rl_cluster = rl_cluster_lib.RLCluster(
        actor=lora_policy,
        reference=base_model,
        tokenizer=tokenizer,
        cluster_config=cluster_config,
    )

    grpo_trainer = GRPOLearner(
        rl_cluster=rl_cluster,
        reward_fns=REWARD_FNS,
        algo_config=algo_config,
    )

    print("Starting GRPO training...")
    print("  Note: first step includes XLA compilation and is typically the slowest.")
    with mesh:
        grpo_trainer.train(train_dataset, val_dataset)

    print("✅ GRPO training complete.")

In [None]:
print('is running')

In [None]:

# =========================
# 19) Export merged model after GRPO (final artifact)
# =========================
if cfg.export.enabled and cfg.export.save_after_grpo and RUN_GRPO_EFFECTIVE:
    print("Exporting merged model after GRPO to:", MERGED_LORA_DIR)
    out_dir = export_merged_lora(MERGED_LORA_DIR)
    print("✅ Export complete:", out_dir)
else:
    print("ℹ️ Skipping export after GRPO (either disabled or GRPO didn't run).")

# Refresh export-present flag (useful if you ran training now)
EXPORT_PRESENT = export_is_present(MERGED_LORA_DIR)
print("EXPORT_PRESENT:", EXPORT_PRESENT, "|", MERGED_LORA_DIR)


## Metrics and evaluation

In [None]:

# =========================
# 20) Where are metrics logged?
# =========================
import glob

print("TensorBoard root:", TB_DIR)
event_files = glob.glob(str(TB_DIR / "**" / "events.*"), recursive=True)
print("Found event files:", len(event_files))
for f in event_files[:10]:
    print(" -", f)

print("\nSFT metrics are written to:", TB_DIR / "sft_seed")
print("GRPO metrics are written to:", TB_DIR / "grpo")
print("\nGRPO evaluation runs every EVAL_EVERY_N_STEPS =", EVAL_EVERY_N_STEPS)


## Inference (load the saved merged model and run predictions)

In [None]:

# =========================
# 21) Inference from merged model directory (no retraining required)
# =========================
from pathlib import Path

if not export_is_present(INFER_DIR):
    print("⚠️ Inference skipped: model directory not found or incomplete:", INFER_DIR)
else:
    print("✅ Loading model for inference from:", INFER_DIR)

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    tok = AutoTokenizer.from_pretrained(str(INFER_DIR), local_files_only=True)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token

    # CPU inference (Kaggle TPU VMs typically don't have a CUDA GPU)
    model = AutoModelForCausalLM.from_pretrained(
        str(INFER_DIR),
        local_files_only=True,
        torch_dtype=torch.float32,
        device_map="cpu",
    )
    model.eval()

    def generate_one(question: str) -> str:
        prompt = TEMPLATE.format(system_prompt=SYSTEM_PROMPT, question=question)
        inputs = tok(prompt, return_tensors="pt")
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=INFER_MAX_NEW_TOKENS,
                do_sample=INFER_DO_SAMPLE,
                temperature=TEMPERATURE if INFER_DO_SAMPLE else None,
                top_p=TOP_P if INFER_DO_SAMPLE else None,
                top_k=TOP_K if INFER_DO_SAMPLE else None,
                pad_token_id=tok.pad_token_id,
                eos_token_id=tok.eos_token_id,
            )
        return tok.decode(out[0], skip_special_tokens=True)

    # Quick smoke test
    test_q = to_str(train_df_split[q_col].iloc[0]) if len(train_df_split) else "What is 2+2?"
    print("\n--- Sample generation ---")
    print(generate_one(test_q))


## Troubleshooting and practical tips

Common issues and the most effective fixes:

- **GRPO seems “stuck” at the beginning**  
  The first step includes XLA compilation for multiple graphs (rollout + log-prob computation). Reduce:
  - `grpo.max_prompt_length`
  - `grpo.total_generation_steps`
  - `grpo.num_generations`
  - `grpo.train_micro_batch_size`

- **HBM OOM during reference log-prob computation**  
  Set:
  - `grpo.compute_logps_micro_batch_pref: 1`
  - `grpo.offload_to_cpu: true` (slower, but reduces TPU HBM usage)

- **You do not want to retrain**  
  If `export/<run_name>/<merged_subdir>/` exists, keep:
  - `runtime.skip_training_if_export_exists: true`  
  and simply run the **Inference** section.

- **Want to force a new run**  
  Set:
  - `runtime.force_retrain: true`


In [None]:
print('is running')

In [None]:
# =========================
# 22) Enforce required output format: <reasoning>...</reasoning><answer>...</answer>
# =========================
import re

RE_STRICT = re.compile(r"^<reasoning>.*?</reasoning><answer>.*?</answer>\s*$", re.DOTALL)

def sanitize_to_xml(text: str) -> str:
    t = (text or "").strip()

    # If already compliant
    if RE_STRICT.match(t):
        return t

    # Try to extract "Answer:" and "Explanation:" style outputs
    # 1) Answer
    ans = ""
    m_ans = re.search(r"(?:^|\n)Answer:\s*(.*?)(?:\n|$)", t, re.IGNORECASE)
    if m_ans:
        ans = m_ans.group(1).strip()

    # 2) Explanation/Reasoning
    reason = ""
    m_exp = re.search(r"(?:^|\n)Explanation:\s*(.*)$", t, re.IGNORECASE | re.DOTALL)
    if m_exp:
        reason = m_exp.group(1).strip()

    # If still missing, fallback heuristics
    if not ans:
        # last non-empty line as answer
        lines = [x.strip() for x in t.splitlines() if x.strip()]
        ans = lines[-1] if lines else "N/A"

    if not reason:
        # everything except last line as reasoning
        lines = [x.strip() for x in t.splitlines() if x.strip()]
        reason = "\n".join(lines[:-1]).strip() if len(lines) > 1 else "I will reason step by step."

    # Ensure non-empty answer
    if not ans.strip():
        ans = "N/A"

    return f"<reasoning>{reason}</reasoning><answer>{ans}</answer>"

# Quick check on your last generation
sample_out = generate_one(test_q)
print("RAW:\n", sample_out[:500], "...\n")
print("SANITIZED:\n", sanitize_to_xml(sample_out)[:500], "...\n")

In [None]:
# =========================
# 22.4) Inspect /kaggle/input for submission/test files
# =========================
from pathlib import Path

KAGGLE_INPUT = Path("/kaggle/input")

csvs = sorted(KAGGLE_INPUT.rglob("*.csv"))
jsons = sorted(KAGGLE_INPUT.rglob("*.json"))
jsonls = sorted(KAGGLE_INPUT.rglob("*.jsonl"))
parquets = sorted(KAGGLE_INPUT.rglob("*.parquet"))

print("CSV files found:", len(csvs))
for p in csvs[:200]:
    print(" -", p)

print("\nJSON files found:", len(jsons))
for p in jsons[:50]:
    print(" -", p)

print("\nJSONL files found:", len(jsonls))
for p in jsonls[:50]:
    print(" -", p)

print("\nParquet files found:", len(parquets))
for p in parquets[:50]:
    print(" -", p)

In [None]:
# =========================
# 22.5) Load test_df + sample_sub (robust, nonstandard filenames)
# =========================
from pathlib import Path
import pandas as pd
import re

KAGGLE_INPUT = Path("/kaggle/input")

def find_candidates(ext="csv"):
    return sorted(KAGGLE_INPUT.rglob(f"*.{ext}"))

def looks_like_sample_submission(df: pd.DataFrame) -> bool:
    if df is None or df.shape[1] < 2:
        return False
    cols = [c.lower() for c in df.columns]
    # heuristics: one id-like col + one pred-like col
    id_like = any(any(k in c for k in ["id", "qid", "uid", "index"]) for c in cols)
    return id_like

# ---- Find a test file ----
test_path = None
for p in find_candidates("csv"):
    n = p.name.lower()
    if "test" in n and "train" not in n and "submission" not in n:
        test_path = p
        break

if test_path is None:
    raise FileNotFoundError("Could not auto-find a test CSV under /kaggle/input. Use the listing cell output to set test_path manually.")

print("✅ test_df:", test_path)
test_df = pd.read_csv(test_path)

# ---- Try to find a sample submission (any CSV with id-like cols) ----
sample_sub = None
sample_path = None
for p in find_candidates("csv"):
    n = p.name.lower()
    if "submission" in n or "sample" in n:
        try:
            df = pd.read_csv(p)
            if looks_like_sample_submission(df):
                sample_sub = df
                sample_path = p
                break
        except Exception:
            pass

if sample_sub is not None:
    print("✅ sample_sub:", sample_path)
    print("sample_sub columns:", list(sample_sub.columns))
else:
    print("⚠️ No sample submission file found. Will create one from test_df.")
    # Try to find an id column in test_df
    id_col_guess = None
    for c in test_df.columns:
        if c.lower() in ["id", "qid", "uid", "index"]:
            id_col_guess = c
            break
    if id_col_guess is None:
        # fallback: create an id as row index
        test_df = test_df.reset_index().rename(columns={"index": "id"})
        id_col_guess = "id"

    # Create a default submission template with a placeholder prediction col name
    sample_sub = pd.DataFrame({id_col_guess: test_df[id_col_guess].values, "prediction": ""})
    print("Created sample_sub with columns:", list(sample_sub.columns))

print("\n--- test_df columns ---")
print(list(test_df.columns))
print(test_df.head(2))

# ---- Infer q_col if needed ----
def infer_question_col(df):
    candidates = ["question", "prompt", "query", "problem", "input", "text", "instruction"]
    cols = df.columns
    for c in candidates:
        if c in cols:
            return c
    for c in cols:
        lc = c.lower()
        if any(k in lc for k in candidates):
            return c
    # fallback: longest text-like col
    best, best_len = None, -1
    for c in cols:
        s = df[c]
        if s.dtype == "object":
            avg = s.dropna().astype(str).head(200).map(len).mean()
            if avg > best_len:
                best_len = avg
                best = c
    return best

if "q_col" not in globals() or q_col is None or q_col not in test_df.columns:
    q_col = infer_question_col(test_df)
    print("\n✅ Inferred q_col for test_df:", q_col)

assert q_col in test_df.columns, f"q_col='{q_col}' not found in test_df columns."

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# =========================
# Zip merged_lora/ -> merged_lora.zip
# =========================
from pathlib import Path
import os, shutil, subprocess

MERGED_DIR = Path("/kaggle/working/export/google_gemma-3-1b-it_tunix_sft_grpo/merged_lora")
ZIP_BASENAME = Path("/kaggle/working/merged_lora")  # .zip will be appended

required = ["config.json", "model.safetensors", "tokenizer.model"]
missing = [f for f in required if not (MERGED_DIR / f).exists()]
if missing:
    raise FileNotFoundError(f"Missing required files in {MERGED_DIR}: {missing}")

# Remove old zip if present
zip_path = ZIP_BASENAME.with_suffix(".zip")
if zip_path.exists():
    zip_path.unlink()

# Create zip (shutil.make_archive creates merged_lora.zip containing the folder)
archive_path = shutil.make_archive(str(ZIP_BASENAME), "zip", root_dir=str(MERGED_DIR.parent), base_dir=MERGED_DIR.name)
print("✅ Created zip:", archive_path)
print("Size (MB):", round(Path(archive_path).stat().st_size / (1024**2), 2))

In [None]:
# =========================
# 23) Batch inference → predictions.csv (id + prediction)
# =========================
import pandas as pd
from tqdm import tqdm
import re

# Ensure test_df and sample_sub exist
assert test_df is not None and "question" in test_df.columns and "id" in test_df.columns
assert sample_sub is not None and "id" in sample_sub.columns and "prediction" in sample_sub.columns

RE_STRICT = re.compile(r"^<reasoning>.*?</reasoning><answer>.*?</answer>\s*$", re.DOTALL)

def sanitize_to_xml(text: str) -> str:
    t = (text or "").strip()

    # already compliant
    if RE_STRICT.match(t):
        return t

    # Try to capture "Answer:" and "Explanation:" style
    ans = ""
    m_ans = re.search(r"(?:^|\n)Answer:\s*(.*?)(?:\n|$)", t, re.IGNORECASE)
    if m_ans:
        ans = m_ans.group(1).strip()

    reason = ""
    m_exp = re.search(r"(?:^|\n)Explanation:\s*(.*)$", t, re.IGNORECASE | re.DOTALL)
    if m_exp:
        reason = m_exp.group(1).strip()

    # If model used GSM8K-style #### final
    if not ans:
        m_hash = re.findall(r"####\s*(.+)", t)
        if m_hash:
            ans = m_hash[-1].strip()

    # Fallback: last non-empty line
    if not ans:
        lines = [x.strip() for x in t.splitlines() if x.strip()]
        ans = lines[-1] if lines else "N/A"

    if not reason:
        # everything before the last line (best effort)
        lines = [x.strip() for x in t.splitlines() if x.strip()]
        reason = "\n".join(lines[:-1]).strip() if len(lines) > 1 else "I will reason step by step."

    if not ans.strip():
        ans = "N/A"

    return f"<reasoning>{reason}</reasoning><answer>{ans}</answer>"

# Generate predictions
preds = []
for q in tqdm(test_df["question"].astype(str).tolist(), desc="Generating"):
    raw = generate_one(q)  # from your inference cell
    preds.append(sanitize_to_xml(raw))

submission = sample_sub.copy()
submission["prediction"] = preds

out_path = "/kaggle/working/predictions.csv"
submission.to_csv(out_path, index=False)
print("✅ Wrote:", out_path)
submission.head()

In [None]:
## Load model

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
# =========================
# 0) Settings
# =========================-
from pathlib import Path

# If you attached your merged model as a Kaggle Dataset, it will be under /kaggle/input/<dataset-slug>/merged_lora
# Leave MODEL_DIR=None to auto-discover.
MODEL_DIR = None  # e.g. Path("/kaggle/input/my-merged-model-dataset/merged_lora")

# Inference knobs (keep small-ish for speed on CPU)
INFER_BATCH_SIZE = 4
INFER_MAX_NEW_TOKENS = 256
INFER_DO_SAMPLE = False   # deterministic
INFER_TEMPERATURE = 0.7
INFER_TOP_P = 0.95
INFER_TOP_K = 50

# Prompt
SYSTEM_PROMPT = "You are a helpful assistant."

# Output
SUBMISSION_PATH = Path("/kaggle/working/submission.csv")

In [3]:
# =========================
# 1) Find merged model dir + verify it's complete
# =========================
from pathlib import Path

KAGGLE_INPUT = Path("/kaggle/input")

def find_hf_model_dirs(root: Path):
    """Return candidate HF model dirs under root that look like (config.json + weights + tokenizer)."""
    cands = []
    for p in root.rglob("*"):
        if not p.is_dir():
            continue
        has_config = (p / "config.json").exists()
        has_weights = (p / "model.safetensors").exists() or (p / "model.safetensors.index.json").exists()
        has_tok = any((p / f).exists() for f in ["tokenizer.json", "tokenizer.model", "spiece.model"])
        if has_config and has_weights and has_tok:
            cands.append(p)
    # prefer shorter paths (more likely the intended directory)
    return sorted(cands, key=lambda x: (len(str(x)), str(x)))

def verify_hf_model_dir(p: Path) -> tuple[bool, list[str]]:
    missing = []
    if not p.exists():
        return False, [f"Directory does not exist: {p}"]
    if not (p / "config.json").exists():
        missing.append("config.json")
    has_weights = (p / "model.safetensors").exists() or (p / "model.safetensors.index.json").exists()
    if not has_weights:
        missing.append("model.safetensors OR model.safetensors.index.json (+ shards)")
    has_tokenizer = any((p / f).exists() for f in ["tokenizer.json", "tokenizer.model", "spiece.model"])
    if not has_tokenizer:
        missing.append("tokenizer.json OR tokenizer.model/spiece.model")
    return (len(missing) == 0), missing

if MODEL_DIR is None:
    cands = find_hf_model_dirs(KAGGLE_INPUT)
    if not cands:
        raise FileNotFoundError(
            "Could not find a HF model directory under /kaggle/input.\n"
            "Attach your merged_lora dataset, then rerun."
        )
    MODEL_DIR = cands[0]

MODEL_DIR = Path(MODEL_DIR)
ok, missing = verify_hf_model_dir(MODEL_DIR)

print("MODEL_DIR:", MODEL_DIR)
print("✅ Model folder looks usable." if ok else "❌ Model folder incomplete:")
for m in missing:
    print(" -", m)

assert ok, "Fix MODEL_DIR (or dataset contents) and rerun."

MODEL_DIR: /kaggle/input/gemma-3/transformers/gemma-3-1b-it/1
✅ Model folder looks usable.


In [4]:
# =========================
# 2) Load tokenizer + model (CPU/GPU auto)
# =========================
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

print("Device:", device, "| dtype:", dtype)

tok = AutoTokenizer.from_pretrained(str(MODEL_DIR), local_files_only=True, use_fast=True)
if tok.pad_token_id is None:
    tok.pad_token = tok.eos_token

# NOTE: transformers recently warns torch_dtype deprecated -> dtype; handle both.
model_kwargs = dict(local_files_only=True)
try:
    model = AutoModelForCausalLM.from_pretrained(str(MODEL_DIR), dtype=dtype, **model_kwargs)
except TypeError:
    model = AutoModelForCausalLM.from_pretrained(str(MODEL_DIR), torch_dtype=dtype, **model_kwargs)

model.to(device)
model.eval()

print("✅ Loaded model + tokenizer.")



Device: cpu | dtype: torch.float32




✅ Loaded model + tokenizer.


In [5]:
# =========================
# 3) Prompt + batched generation helpers (chat_template-safe)
# =========================
import torch

def build_prompt(question: str) -> str:
    """
    Robust prompt builder:
    - Tries tokenizer.chat_template with structured content (Gemma-style expects .text)
    - Falls back to plain-string chat_template
    - Falls back to a simple prompt
    """
    question = str(question)

    # Try chat template if available
    if getattr(tok, "chat_template", None):
        # 1) Structured content (most compatible with Gemma templates)
        messages_structured = [
            {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
            {"role": "user", "content": [{"type": "text", "text": question}]},
        ]
        try:
            return tok.apply_chat_template(
                messages_structured,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception:
            # 2) Plain string content (some templates accept this)
            messages_plain = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": question},
            ]
            try:
                return tok.apply_chat_template(
                    messages_plain,
                    tokenize=False,
                    add_generation_prompt=True,
                )
            except Exception:
                pass

    # 3) Hard fallback (always works)
    return f"{SYSTEM_PROMPT}\nQuestion: {question}\nAnswer:"

@torch.inference_mode()
def generate_batch(questions):
    """
    Batched generation that returns ONLY the completion (not the prompt).
    Works on CPU or GPU depending on `device`.
    """
    prompts = [build_prompt(q) for q in questions]

    inputs = tok(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(device)

    # prompt lengths per row (count of non-pad tokens)
    prompt_lens = inputs["attention_mask"].sum(dim=1).tolist()

    gen_kwargs = dict(
        max_new_tokens=int(INFER_MAX_NEW_TOKENS),
        do_sample=bool(INFER_DO_SAMPLE),
        pad_token_id=tok.pad_token_id,
        eos_token_id=tok.eos_token_id,
    )
    if INFER_DO_SAMPLE:
        gen_kwargs.update(dict(
            temperature=float(INFER_TEMPERATURE),
            top_p=float(INFER_TOP_P),
            top_k=int(INFER_TOP_K),
        ))

    out = model.generate(**inputs, **gen_kwargs)

    # Decode only newly generated tokens (not the prompt)
    texts = []
    for i in range(out.shape[0]):
        start = int(prompt_lens[i])
        completion_ids = out[i][start:]
        texts.append(tok.decode(completion_ids, skip_special_tokens=True).strip())
    return texts

# Smoke test
print(generate_batch(["What is 2+2?"])[0])

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


2 + 2 = 4
roneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneroneronerone


In [6]:
# =========================
# 4) Load test_df + sample_sub (robust)
# =========================
import pandas as pd
from pathlib import Path

KAGGLE_INPUT = Path("/kaggle/input")

# Find sample_submission if it exists
sample_paths = sorted(KAGGLE_INPUT.rglob("sample_submission.csv"))
sample_sub = None
sample_path = None
if sample_paths:
    sample_path = sample_paths[0]
    sample_sub = pd.read_csv(sample_path)
    print("✅ sample_submission:", sample_path)
else:
    print("⚠️ sample_submission.csv not found. Will create a template from test_df.")

# Find a test CSV (common in competitions)
# Prefer any file named test.csv; else any csv with "test" and not "train"
csvs = sorted(KAGGLE_INPUT.rglob("*.csv"))
test_path = None
for p in csvs:
    if p.name.lower() == "test.csv":
        test_path = p
        break
if test_path is None:
    for p in csvs:
        n = p.name.lower()
        if "test" in n and "train" not in n and "submission" not in n:
            test_path = p
            break

if test_path is None:
    raise FileNotFoundError("Could not locate a test CSV under /kaggle/input.")

test_df = pd.read_csv(test_path)
print("✅ test_df:", test_path)
print("test_df columns:", list(test_df.columns))
display(test_df.head(2))

# Infer q_col
def infer_question_col(df: pd.DataFrame) -> str:
    candidates = ["question", "prompt", "query", "problem", "input", "text", "instruction"]
    cols = list(df.columns)
    lower_map = {c.lower(): c for c in cols}

    # exact matches
    for c in candidates:
        if c in lower_map:
            return lower_map[c]

    # substring matches
    for c in cols:
        lc = c.lower()
        if any(k in lc for k in candidates):
            return c

    # fallback: longest text-like
    best, best_len = None, -1
    for c in cols:
        s = df[c]
        if s.dtype == "object":
            avg = s.dropna().astype(str).head(200).map(len).mean()
            if avg > best_len:
                best_len = avg
                best = c
    if best is None:
        raise ValueError("Could not infer a question/prompt column. Set q_col manually.")
    return best

q_col = infer_question_col(test_df)
print("✅ q_col:", q_col)

# Build sample_sub if missing
if sample_sub is None:
    # Prefer an id column if present
    id_col = None
    for c in test_df.columns:
        if c.lower() in ["id", "qid", "uid", "index"]:
            id_col = c
            break
    if id_col is None:
        test_df = test_df.reset_index().rename(columns={"index": "id"})
        id_col = "id"

    sample_sub = pd.DataFrame({id_col: test_df[id_col].values, "prediction": ""})
    sample_path = None
    print("✅ Created sample_sub with columns:", list(sample_sub.columns))

print("sample_sub columns:", list(sample_sub.columns))
display(sample_sub.head(2))

⚠️ sample_submission.csv not found. Will create a template from test_df.
✅ test_df: /kaggle/input/grade-school-math-8k-q-a/main_test.csv
test_df columns: ['question', 'answer']


Unnamed: 0,question,answer
0,Janet’s ducks lay 16 eggs per day. She eats th...,Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eg...
1,A robe takes 2 bolts of blue fiber and half th...,It takes 2/2=<<2/2=1>>1 bolt of white fiber\nS...


✅ q_col: question
✅ Created sample_sub with columns: ['id', 'prediction']
sample_sub columns: ['id', 'prediction']


Unnamed: 0,id,prediction
0,0,
1,1,


In [None]:
# =========================
# 5) Batch inference -> submission.csv
# =========================
from tqdm import tqdm
import math

sub_cols = list(sample_sub.columns)
id_col_sub = sub_cols[0]
pred_col_sub = sub_cols[1] if len(sub_cols) > 1 else sub_cols[0]

questions = test_df[q_col].astype(str).tolist()

preds = []
for i in tqdm(range(0, len(questions), INFER_BATCH_SIZE), desc="Generating"):
    batch_q = questions[i:i+INFER_BATCH_SIZE]
    batch_out = generate_batch(batch_q)
    preds.extend(batch_out)

submission = sample_sub.copy()

# If sample_sub has same row count as test, fill directly. Otherwise align by id if possible.
if len(submission) == len(preds):
    submission[pred_col_sub] = preds
else:
    # safer: create a fresh dataframe
    print("⚠️ sample_sub rows != test rows. Writing a new submission with id alignment.")
    if id_col_sub in test_df.columns:
        submission = pd.DataFrame({id_col_sub: test_df[id_col_sub].values, pred_col_sub: preds})
    else:
        submission = pd.DataFrame({id_col_sub: list(range(len(preds))), pred_col_sub: preds})

submission.to_csv(SUBMISSION_PATH, index=False)
print("✅ Wrote:", SUBMISSION_PATH)
display(submission.head(10))

Generating:  62%|██████▏   | 206/330 [1:46:35<1:04:09, 31.05s/it]


KeyboardInterrupt: 

In [None]:
# =========================
# 6) OPTIONAL: Submit via Kaggle CLI + check latest score
# =========================
import os, subprocess
from pathlib import Path

def detect_competition_slug():
    # Heuristic: competition slug is usually the directory under /kaggle/input that contains sample_submission.csv
    sample_paths = sorted(Path("/kaggle/input").rglob("sample_submission.csv"))
    if not sample_paths:
        return None
    # /kaggle/input/<slug>/.../sample_submission.csv -> slug is parts[3]
    return sample_paths[0].parts[3]

COMPETITION = detect_competition_slug()
print("Detected COMPETITION slug:", COMPETITION)

# Check if kaggle CLI exists
try:
    subprocess.run(["kaggle", "--version"], check=True, capture_output=True, text=True)
    kaggle_cli_ok = True
except Exception as e:
    kaggle_cli_ok = False
    print("❌ kaggle CLI not available:", e)

# Check credentials
kaggle_json = Path.home() / ".kaggle" / "kaggle.json"
has_creds = kaggle_json.exists() or (os.environ.get("KAGGLE_USERNAME") and os.environ.get("KAGGLE_KEY"))

print("Kaggle credentials present:", bool(has_creds))

if (not kaggle_cli_ok) or (not has_creds) or (COMPETITION is None):
    print("⚠️ Cannot auto-submit from notebook.")
    print("   You can still submit manually in Kaggle UI using:", SUBMISSION_PATH)
else:
    msg = "Inference submission from merged_lora"
    print("Submitting...")
    subprocess.run(
        ["kaggle", "competitions", "submit", "-c", COMPETITION, "-f", str(SUBMISSION_PATH), "-m", msg],
        check=True
    )
    print("✅ Submitted.")

    print("\nLatest submissions (top):")
    # This prints table including scores when available
    subprocess.run(["kaggle", "competitions", "submissions", "-c", COMPETITION], check=True)

In [None]:
# =========================
# 7) OPTIONAL: Local evaluation (if 'answer' exists)
# =========================
import re

def extract_final_number(text: str):
    if text is None:
        return None
    s = str(text)

    # GSM8K often uses "#### <answer>"
    if "####" in s:
        tail = s.split("####")[-1].strip()
        m = re.search(r"[-+]?\d+(?:\.\d+)?", tail.replace(",", ""))
        return m.group(0) if m else None

    # fallback: last number anywhere
    nums = re.findall(r"[-+]?\d+(?:\.\d+)?", s.replace(",", ""))
    return nums[-1] if nums else None

if "answer" not in test_df.columns:
    print("⚠️ No 'answer' column in test_df; cannot do local accuracy here.")
else:
    N = min(50, len(test_df))
    qs = test_df[q_col].astype(str).head(N).tolist()
    gold = test_df["answer"].astype(str).head(N).tolist()

    pred_texts = []
    for i in tqdm(range(0, N, INFER_BATCH_SIZE), desc="Eval generating"):
        pred_texts.extend(generate_batch(qs[i:i+INFER_BATCH_SIZE]))

    pred_nums = [extract_final_number(t) for t in pred_texts]
    gold_nums = [extract_final_number(t) for t in gold]

    correct = 0
    for p, g in zip(pred_nums, gold_nums):
        if (p is not None) and (g is not None) and (p == g):
            correct += 1

    acc = correct / N if N else 0.0
    print(f"✅ Local numeric exact-match on {N} examples: {acc:.3f} ({correct}/{N})")