# VAZHI SFT v4.0 — Instruction Fine-Tuning on DAPT v1.1

**Pipeline Step 3 of 3:** SFT on the DAPT-adapted Tamil instruct model.

```
Step 1: Data Prep (DONE — Vazhi_DAPT_Data_v1_1.ipynb)
Step 2: DAPT Training (DONE — Vazhi_DAPT_v1_1_Tamil.ipynb)
  → Produced: CryptoYogi/qwen3-0.6b-tamil-v1_1
    PPL 2.6, +55% Tamil vs vanilla, 7/8 eval passed

Step 3 (THIS NOTEBOOK): SFT — teach instruction-following in Tamil
  → Input:  DAPT'd model + v4.0 ChatML dataset (1,514 samples)
  → Output: CryptoYogi/vazhi-v4_0 (final VAZHI model)
           CryptoYogi/vazhi-v4_0-lora (adapter backup)
```

**Pre-SFT validation (GPT5.2):** Perplexity baseline + chat template test before training.

**Key decisions:**
1. **DAPT'd instruct base** — already fluent in Tamil (PPL 2.6), SFT adds instruction-following
2. **fp16** — no 4-bit (0.6B fits in 1.1GB), avoids merge corruption
3. **Dual T4 DataParallel** — no device_map for training (lesson from DAPT v1.1)
4. **Completion-only masking** — only train on assistant responses
5. **`<think>` suppression** — during eval generation only
6. **Chat template eval** — `apply_chat_template(enable_thinking=False)` per GPT5.2

**Target:** Kaggle T4 x2 | ~255 steps (3 epochs) | Est. < 1 hour
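
Decision 2 comes down to simple arithmetic. The sketch below is a back-of-envelope check. fp16 stores 2 bytes per parameter, so 0.6B parameters need about 1.1 GiB for the weights alone. Activations, LoRA gradients and optimizer state come on top of that. The sketch also counts the parameters that LoRA `r=16` adds on all seven projection modules. It assumes Qwen3-0.6B's published dimensions: hidden 1024, MLP 3072, 28 layers, and 16 query / 8 KV heads of size 128. Treat those dimensions as an assumption and check them against `model.config`.

```python
# Back-of-envelope sizing for the "fp16, no 4-bit" decision.
# ASSUMPTION: the Qwen3-0.6B dimensions below come from its public config;
# verify against model.config before relying on the numbers.

GIB = 1024 ** 3

def fp16_weight_gib(n_params: float) -> float:
    """Weights-only footprint in GiB at 2 bytes per parameter (fp16)."""
    return n_params * 2 / GIB

def lora_trainable_params(r: int, shapes: dict, n_layers: int) -> int:
    """LoRA adds A (r x d_in) and B (d_out x r) to every adapted linear layer."""
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in shapes.values())
    return per_layer * n_layers

hidden, inter, heads, kv_heads, head_dim = 1024, 3072, 16, 8, 128
qwen3_06b_shapes = {  # (d_in, d_out) for each LoRA target module
    "q_proj": (hidden, heads * head_dim),
    "k_proj": (hidden, kv_heads * head_dim),
    "v_proj": (hidden, kv_heads * head_dim),
    "o_proj": (heads * head_dim, hidden),
    "gate_proj": (hidden, inter),
    "up_proj": (hidden, inter),
    "down_proj": (inter, hidden),
}

print(f"fp16 weights for 0.6B params: {fp16_weight_gib(0.6e9):.2f} GiB")
print(f"LoRA r=16 trainable params:   {lora_trainable_params(16, qwen3_06b_shapes, 28):,}")
```

If the assumed dimensions are right, the LoRA count should agree with the output of `model.print_trainable_parameters()` in Section 6.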

## 1. Install Dependencies

**After running this cell, RESTART the session** (Runtime → Restart session)

In [None]:
!pip install -q -U \
  "transformers>=4.45.0,<5.0.0" \
  "accelerate>=0.34.2" \
  "peft>=0.12.0" \
  "trl>=0.12.0,<0.20.0" \
  "datasets>=2.21.0" \
  "huggingface_hub>=0.24.7"

print("\u2705 Dependencies installed")
print("\u26a0\ufe0f  RESTART THE SESSION NOW (Runtime \u2192 Restart session)")

## 2. Configuration

In [None]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import json
import re
import random
import glob
import gc
import torch
import numpy as np
from collections import Counter
from datasets import load_dataset
from huggingface_hub import login, HfApi

from transformers import (
    AutoModelForCausalLM, AutoTokenizer,
    TrainerCallback, TrainingArguments,
)
from peft import LoraConfig, get_peft_model, PeftModel
from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM

RANDOM_SEED = 42
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# === KEY CONFIG ===
DAPT_MODEL = "CryptoYogi/qwen3-0.6b-tamil-v1_1"   # DAPT'd instruct (Step 2 output)
VANILLA_MODEL = "Qwen/Qwen3-0.6B"                  # For pre-SFT baseline comparison
SFT_DATASET = "CryptoYogi/vazhi-tamil-sft-v4_0"    # v4.0 ChatML dataset (ADR-010)
DAPT_DATASET = "CryptoYogi/vazhi-dapt-tamil-v1_1"  # For perplexity baseline
OUTPUT_MODEL = "CryptoYogi/vazhi-v4_0"              # Final VAZHI model
ADAPTER_REPO = "CryptoYogi/vazhi-v4_0-lora"         # Adapter backup

# Training config
LEARNING_RATE = 2e-5       # Conservative for SFT on DAPT'd model
NUM_EPOCHS = 3
MAX_SEQ_LENGTH = 1024
LORA_R = 16
LORA_ALPHA = 32
BATCH_SIZE = 2             # Per-device
GRADIENT_ACCUMULATION = 4  # 2 x 2 GPUs x 4 = 16 effective batch

# Qwen3 instruct <think> tokens to suppress during generation
THINK_TOKEN_IDS = [151667, 151668]

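# English: "You are VAZHI (வழி), an AI assistant for Tamil people. Answer in Tamil,
# clearly and helpfully. If you don't know, say "I don't know"."
SYSTEM_PROMPT = (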
    "\u0ba8\u0bc0\u0b99\u0bcd\u0b95\u0bb3\u0bcd VAZHI (\u0bb5\u0bb4\u0bbf), \u0ba4\u0bae\u0bbf\u0bb4\u0bcd \u0bae\u0b95\u0bcd\u0b95\u0bb3\u0bc1\u0b95\u0bcd\u0b95\u0bbe\u0ba9 AI \u0b89\u0ba4\u0bb5\u0bbf\u0baf\u0bbe\u0bb3\u0bb0\u0bcd. "
    "\u0ba4\u0bae\u0bbf\u0bb4\u0bbf\u0bb2\u0bcd \u0ba4\u0bc6\u0bb3\u0bbf\u0bb5\u0bbe\u0b95\u0bb5\u0bc1\u0bae\u0bcd \u0b89\u0ba4\u0bb5\u0bbf\u0baf\u0bbe\u0b95\u0bb5\u0bc1\u0bae\u0bcd \u0baa\u0ba4\u0bbf\u0bb2\u0bb3\u0bbf\u0baf\u0bc1\u0b99\u0bcd\u0b95\u0bb3\u0bcd. "
    '\u0ba4\u0bc6\u0bb0\u0bbf\u0baf\u0bbe\u0bb5\u0bbf\u0b9f\u0bcd\u0b9f\u0bbe\u0bb2\u0bcd "\u0ba4\u0bc6\u0bb0\u0bbf\u0baf\u0bb5\u0bbf\u0bb2\u0bcd\u0bb2\u0bc8" \u0b8e\u0ba9\u0bcd\u0bb1\u0bc1 \u0b9a\u0bca\u0bb2\u0bcd\u0bb2\u0bc1\u0b99\u0bcd\u0b95\u0bb3\u0bcd.'
)

# GPU setup
n_gpus = torch.cuda.device_count()

print(f"\u2705 Configuration loaded")
print(f"   PyTorch: {torch.__version__}")
print(f"   CUDA: {torch.cuda.is_available()}")
print(f"   GPUs: {n_gpus}")
for i in range(n_gpus):
    name = torch.cuda.get_device_name(i)
    mem = torch.cuda.get_device_properties(i).total_memory / 1024**3
    print(f"   GPU {i}: {name} ({mem:.0f} GB)")

effective_batch = BATCH_SIZE * n_gpus * GRADIENT_ACCUMULATION
print()
print(f"\U0001f4cb SFT v4.0 on DAPT v1.1:")
print(f"   Base:     {DAPT_MODEL} (DAPT'd instruct)")
print(f"   Dataset:  {SFT_DATASET}")
print(f"   Output:   {OUTPUT_MODEL}")
print(f"   LR:       {LEARNING_RATE}")
print(f"   LoRA:     r={LORA_R}, alpha={LORA_ALPHA}")
print(f"   Batch:    {BATCH_SIZE} x {n_gpus} GPUs x {GRADIENT_ACCUMULATION} accum = {effective_batch} effective")
print(f"   Epochs:   {NUM_EPOCHS}")
print(f"   fp16:     True (no 4-bit)")

In [None]:
# Login to HuggingFace
from kaggle_secrets import UserSecretsClient
secrets = UserSecretsClient()
hf_token = secrets.get_secret("HF_TOKEN")
login(token=hf_token)
print("\u2705 Logged in to HuggingFace")

## 3. Pre-SFT Validation (GPT5.2)

Before spending compute on SFT, verify that DAPT actually helped:
1. **Perplexity baseline** — DAPT'd model should have lower PPL on Tamil text than vanilla
2. **Chat template test** — DAPT'd model should still respond to chat-formatted prompts
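
Perplexity is the exponential of the mean token-level cross-entropy. The baseline cell obtains that cross-entropy by averaging the per-block losses returned by `model(input_ids=..., labels=input_ids)`. The sketch below restates that computation and the cell's three-way verdict in plain Python. The loss values are hypothetical and chosen only to illustrate the arithmetic.

```python
import math

def perplexity(block_losses, clamp=20.0):
    """Perplexity = exp(mean cross-entropy).

    Averaging per-block mean losses equals the token-level mean only when every
    block has the same number of predicted tokens, which holds for the fixed-length
    pre-tokenized blocks used here. The clamp mirrors compute_ppl() and avoids overflow.
    """
    mean_loss = sum(block_losses) / len(block_losses)
    return math.exp(min(mean_loss, clamp)), mean_loss

def dapt_verdict(vanilla_ppl, dapt_ppl, tolerance=0.05):
    """Three-way rule used by the baseline cell: improved, neutral (within 5%), or worse."""
    if dapt_ppl < vanilla_ppl:
        return "improved"
    if dapt_ppl < vanilla_ppl * (1 + tolerance):
        return "neutral"
    return "worse"

# Hypothetical per-block losses, for illustration only
ppl, loss = perplexity([1.0, 0.9, 1.1])
print(f"loss={loss:.3f}  ppl={ppl:.3f}")
print(dapt_verdict(vanilla_ppl=10.0, dapt_ppl=10.4))
```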

In [None]:
# === PERPLEXITY BASELINE ===
# Compare vanilla instruct vs DAPT'd model on Tamil validation blocks

print("\U0001f4ca Pre-SFT Validation: Perplexity Baseline")
print("="*60)

# Load Tamil val set from DAPT dataset (pre-tokenized 1024-token blocks)
print(f"\U0001f4e5 Loading Tamil val set from {DAPT_DATASET}...")
dapt_ds = load_dataset(DAPT_DATASET, split="validation")
n_eval = min(100, len(dapt_ds))  # Cap for speed
print(f"   {len(dapt_ds)} val blocks available, using {n_eval}")

def compute_ppl(model, dataset, n_samples, device):
    """Compute perplexity on pre-tokenized blocks."""
    model.eval()
    losses = []
    for i in range(n_samples):
        input_ids = torch.tensor([dataset[i]["input_ids"]], dtype=torch.long).to(device)
        with torch.no_grad():
            outputs = model(input_ids=input_ids, labels=input_ids.clone())
            losses.append(outputs.loss.item())
    avg_loss = np.mean(losses)
    return np.exp(min(avg_loss, 20)), avg_loss

# 1. Vanilla instruct PPL
print(f"\n\U0001f4e5 Loading vanilla {VANILLA_MODEL}...")
vanilla = AutoModelForCausalLM.from_pretrained(
    VANILLA_MODEL, torch_dtype=torch.float16, device_map={"":0}, trust_remote_code=True,
)
vanilla_ppl, vanilla_loss = compute_ppl(vanilla, dapt_ds, n_eval, vanilla.device)
print(f"   Vanilla PPL: {vanilla_ppl:.2f} (loss: {vanilla_loss:.4f})")
del vanilla; gc.collect(); torch.cuda.empty_cache()

# 2. DAPT'd model PPL
print(f"\n\U0001f4e5 Loading DAPT'd {DAPT_MODEL}...")
dapt = AutoModelForCausalLM.from_pretrained(
    DAPT_MODEL, torch_dtype=torch.float16, device_map={"":0}, trust_remote_code=True,
)
dapt_ppl, dapt_loss = compute_ppl(dapt, dapt_ds, n_eval, dapt.device)
print(f"   DAPT PPL:    {dapt_ppl:.2f} (loss: {dapt_loss:.4f})")

# Compare
ppl_improvement = vanilla_ppl - dapt_ppl
ppl_pct = 100 * (vanilla_ppl - dapt_ppl) / vanilla_ppl
print(f"\n\U0001f4ca PPL Comparison:")
print(f"   Vanilla: {vanilla_ppl:.2f}")
print(f"   DAPT:    {dapt_ppl:.2f}")
print(f"   Change:  {ppl_improvement:+.2f} ({ppl_pct:+.1f}%)")

if dapt_ppl < vanilla_ppl:
    print(f"\n\u2705 DAPT improved Tamil perplexity! Safe to proceed with SFT.")
elif dapt_ppl < vanilla_ppl * 1.05:  # Within 5%
    print(f"\n\u26a0\ufe0f  DAPT is neutral (within 5%). Proceeding with caution.")
else:
    print(f"\n\u274c DAPT made perplexity WORSE! Investigate before SFT.")
    print(f"   This would mean DAPT hurt Tamil fluency \u2014 check training logs.")

In [None]:
# === CHAT TEMPLATE TEST ===
# Verify DAPT'd model still responds to chat-formatted prompts

print("\U0001f9ea Pre-SFT Validation: Chat Template Test")
print("="*60)

tokenizer = AutoTokenizer.from_pretrained(DAPT_MODEL, trust_remote_code=True)

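# Check whether the chat template honours enable_thinking.
# apply_chat_template forwards extra kwargs to the Jinja template, so an
# unsupported flag is usually ignored silently instead of raising TypeError.
# Qwen3's template emits an empty "<think>\n\n</think>" block when it is honoured.
try:
    test = tokenizer.apply_chat_template(
        [{"role": "user", "content": "test"}],
        tokenize=False, add_generation_prompt=True, enable_thinking=False,
    )
    USE_THINKING_FLAG = "</think>" in test
except TypeError:
    USE_THINKING_FLAG = False

if USE_THINKING_FLAG:
    print("\u2705 Tokenizer supports enable_thinking=False")
else:
    print("\u26a0\ufe0f  enable_thinking not supported, using manual template")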

def build_chat_prompt(user_text):
    msgs = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_text},
    ]
    if USE_THINKING_FLAG:
        return tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True, enable_thinking=False,
        )
    else:
        return (
            f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
            f"<|im_start|>user\n{user_text}<|im_end|>\n"
            f"<|im_start|>assistant\n"
        )

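def extract_response(full_text):
    """Extract the assistant response from decoded generate() output.

    The decoded text includes the prompt. With enable_thinking=False, Qwen3's
    template places an empty <think></think> block right after the assistant
    header, so it is stripped here; otherwise every response would contain
    "<think>" and be flagged as a THINK failure in the eval.
    """
    resp = full_text
    if "<|im_start|>assistant" in resp:
        resp = resp.split("<|im_start|>assistant")[-1]
    resp = resp.split("<|im_end|>")[0]
    if "</think>" in resp:
        resp = resp.split("</think>", 1)[-1]  # drop the (empty) think block from the prompt
    return resp.strip()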

def count_tamil_chars(text):
    return sum(1 for c in text if '\u0B80' <= c <= '\u0BFF')

def tamil_char_pct(text):
    if not text: return 0.0
    return 100.0 * count_tamil_chars(text) / len(text)

# English: "Hello", "What is the capital of Tamil Nadu?", "Who are you?"
chat_prompts = [
    "\u0bb5\u0ba3\u0b95\u0bcd\u0b95\u0bae\u0bcd",
    "\u0ba4\u0bae\u0bbf\u0bb4\u0bcd\u0ba8\u0bbe\u0b9f\u0bcd\u0b9f\u0bbf\u0ba9\u0bcd \u0ba4\u0bb2\u0bc8\u0ba8\u0b95\u0bb0\u0bae\u0bcd \u0b8e\u0ba9\u0bcd\u0ba9?",
    "\u0ba8\u0bc0\u0b99\u0bcd\u0b95\u0bb3\u0bcd \u0baf\u0bbe\u0bb0\u0bcd?",
]

dapt.eval()
dapt.config.use_cache = True

for prompt_text in chat_prompts:
    full_prompt = build_chat_prompt(prompt_text)
    inputs = tokenizer(full_prompt, return_tensors="pt").to(dapt.device)
    with torch.no_grad():
        outputs = dapt.generate(
            **inputs, max_new_tokens=100, do_sample=False,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
            suppress_tokens=THINK_TOKEN_IDS,
        )
    full = tokenizer.decode(outputs[0], skip_special_tokens=False)
    response = extract_response(full)
    t_pct = tamil_char_pct(response)
    print(f"\n  Q: {prompt_text}")
    print(f"  A: {response[:200]}")
    print(f"  Tamil: {t_pct:.0f}%")

print("\n" + "="*60)
print("If responses are Tamil (even if incoherent), DAPT preserved")
print("language capability. SFT will teach instruction-following.")

# Free DAPT model
del dapt, dapt_ds
gc.collect(); torch.cuda.empty_cache()
print("\n\U0001f5d1\ufe0f Pre-validation models freed")

## 4. Load SFT Dataset + Validate

Dataset Factory v4.0 output (ADR-010):
- ~50% domain packs (security, govt, education, legal, healthcare, culture)
- ~33% IndicAlign diversity
- ~6% Kural interpretive (hard-capped, anti-memorization)
- ~3% handcrafted (guardrails, refusal, brevity, greeting)
- ~8% general knowledge
- All samples in strict ChatML format
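
The composition printout in the next cell can be compared against the ADR-010 targets listed above. The sketch below does that comparison. Its bucket names and the 3-point tolerance are illustrative assumptions, since the real `bucket` column may use finer labels, for example one bucket per domain pack. The 6% Kural cap is also an assumption read from the "~6% (hard-capped)" line. The counts are hypothetical.

```python
from collections import Counter

# ASSUMPTION: these bucket names and target shares are an illustrative grouping of
# the ADR-010 mix; the real `bucket` column may use other, finer labels.
TARGET_SHARES = {
    "domain_packs": 0.50,
    "indicalign": 0.33,
    "kural": 0.06,
    "handcrafted": 0.03,
    "general_knowledge": 0.08,
}

def composition_report(counts, targets, tolerance=0.03):
    """Map each target bucket to (observed_share, target_share, within_tolerance)."""
    total = sum(counts.values())
    report = {}
    for bucket, target in targets.items():
        observed = counts.get(bucket, 0) / total if total else 0.0
        report[bucket] = (observed, target, abs(observed - target) <= tolerance)
    return report

def over_cap(counts, bucket, cap):
    """True if `bucket` exceeds a hard share cap (e.g. the anti-memorization Kural cap)."""
    total = sum(counts.values())
    return total > 0 and counts.get(bucket, 0) / total > cap

# Hypothetical counts summing to 1,514, for illustration only
counts = Counter(domain_packs=757, indicalign=500, kural=90, handcrafted=45, general_knowledge=122)
for bucket, (obs, tgt, ok) in composition_report(counts, TARGET_SHARES).items():
    print(f"{bucket:18s} {obs:6.1%} (target {tgt:.0%}) {'ok' if ok else 'OFF'}")
print("Kural over 6% cap:", over_cap(counts, "kural", 0.06))
```

In the notebook, `bucket_dist` (a `Counter`) could be passed in place of `counts` once the names in `TARGET_SHARES` match the dataset's labels.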

In [None]:
print(f"\U0001f4da Loading SFT dataset from {SFT_DATASET}...")
sft_ds = load_dataset(SFT_DATASET)
train_ds = sft_ds["train"]
eval_ds = sft_ds["validation"]

print(f"\u2705 Dataset loaded:")
print(f"   Train:      {len(train_ds)} samples")
print(f"   Validation: {len(eval_ds)} samples")
print(f"   Columns:    {train_ds.column_names}")

# Composition stats
bucket_dist = Counter(item.get('bucket', 'unknown') for item in train_ds)
print(f"\n\U0001f4ca Composition:")
for bucket, count in sorted(bucket_dist.items(), key=lambda x: -x[1]):
    print(f"   {bucket}: {count} ({100*count/len(train_ds):.1f}%)")

# ChatML validation (all samples)
CHATML_RE = re.compile(
    r'<\|im_start\|>system\n.+?<\|im_end\|>\n'
    r'<\|im_start\|>user\n(.+?)<\|im_end\|>\n'
    r'<\|im_start\|>assistant\n(.+?)<\|im_end\|>',
    re.DOTALL
)

fail_count = 0
think_count = 0
for i in range(len(train_ds)):
    text = train_ds[i]["text"]
    if not CHATML_RE.search(text):
        fail_count += 1
        if fail_count <= 3:
            print(f"   \u274c Sample {i}: invalid ChatML")
    # Check for <think> tags in training data (GPT5.2)
    if "<think>" in text or "</think>" in text:
        think_count += 1
        if think_count <= 3:
            print(f"   \u26a0\ufe0f Sample {i}: contains <think> tag!")

if fail_count == 0:
    print(f"\n\u2705 All {len(train_ds)} train samples pass ChatML validation")
else:
    print(f"\n\u274c {fail_count} samples failed ChatML validation!")
    if fail_count > len(train_ds) * 0.1:
        raise RuntimeError("Too many ChatML failures \u2014 DO NOT TRAIN")

if think_count == 0:
    print(f"\u2705 No <think> tags in training data")
else:
    print(f"\u26a0\ufe0f {think_count} samples contain <think> tags \u2014 strip before training")

# Show a sample
m = CHATML_RE.search(train_ds[0]["text"])
if m:
    print(f"\n\U0001f50d Sample:")
    print(f"   Q: {m.group(1)[:100]}")
    print(f"   A: {m.group(2)[:150]}")

## 5. Load DAPT'd Model + Tokenizer

In [None]:
print(f"\U0001f4e5 Loading tokenizer from {DAPT_MODEL}...")
tokenizer = AutoTokenizer.from_pretrained(DAPT_MODEL, trust_remote_code=True)
tokenizer.padding_side = "right"

print(f"\u2705 Tokenizer: {len(tokenizer)} tokens")
print(f"   eos: {tokenizer.eos_token!r} (ID {tokenizer.eos_token_id})")
print(f"   pad: {tokenizer.pad_token!r} (ID {tokenizer.pad_token_id})")

# Verify ChatML tokens
for token in ["<|im_start|>", "<|im_end|>"]:
    assert token in tokenizer.get_vocab(), f"Missing {token}!"
print("\u2705 ChatML tokens present")

# Verify <think> tokens
for tid in THINK_TOKEN_IDS:
    print(f"   Token {tid}: {tokenizer.decode([tid])!r}")

In [None]:
# Load DAPT'd model in fp16 — NO device_map (for DataParallel)
print(f"\U0001f4e5 Loading {DAPT_MODEL} in fp16...")
print(f"   NO device_map \u2014 Trainer handles DataParallel across {n_gpus} GPUs")

model = AutoModelForCausalLM.from_pretrained(
    DAPT_MODEL,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
model = model.to("cuda:0")

model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.config.use_cache = False
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})  # match SFTConfig

# Verify no hf_device_map (would prevent DataParallel)
has_dm = hasattr(model, "hf_device_map")
print(f"   hf_device_map: {has_dm} (must be False)")
if has_dm:
    print(f"   \u26a0\ufe0f WARNING: hf_device_map detected!")

mem_gb = torch.cuda.memory_allocated(0) / 1024**3
print(f"\u2705 Model loaded: {model.num_parameters():,} params ({mem_gb:.1f} GB)")
print(f"   This is the DAPT'd instruct model (Tamil-enhanced)")

## 6. LoRA Setup

In [None]:
lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

mem_gb = torch.cuda.memory_allocated() / 1024**3
print(f"\u2705 LoRA applied | GPU: {mem_gb:.1f} GB")

## 7. Completion-Only Masking

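Only train on assistant responses. Labels for the system prompt and user message are set to -100, which the cross-entropy loss ignores, so those tokens still provide context but contribute no gradient.
This keeps the model from memorizing prompts and focuses learning on response generation.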
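
The label rule that `DataCollatorForCompletionOnlyLM` applies to a single-turn ChatML sample can be sketched in a few lines of plain Python. This is not TRL's code. The token IDs are hypothetical stand-ins for the template IDs printed by the next cell.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy (its default ignore_index)

def mask_prompt_labels(input_ids, response_template_ids):
    """Completion-only labels for a single-turn sample (sketch).

    Labels up to and including the first occurrence of the response template
    become IGNORE_INDEX; the tokens after it (the answer and any closing
    <|im_end|>) keep their IDs, so the model also learns when to stop.
    """
    n = len(response_template_ids)
    for start in range(len(input_ids) - n + 1):
        if list(input_ids[start:start + n]) == list(response_template_ids):
            cut = start + n
            return [IGNORE_INDEX] * cut + list(input_ids[cut:])
    # Template not found: ignore the whole sample, which is also what TRL's
    # collator does (it logs a warning instead of training on the prompt).
    return [IGNORE_INDEX] * len(input_ids)

# Hypothetical token IDs: 7, 8, 9 stand in for "<|im_start|>assistant\n"
print(mask_prompt_labels([1, 2, 7, 8, 9, 42, 43, 5], [7, 8, 9]))
```

A sample where the template is the last thing left after truncation ends up fully masked. That is the "ALL MASKED" failure the preflight check below looks for.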

In [None]:
# Find the response template token IDs
response_template_str = "<|im_start|>assistant\n"
response_template_ids = tokenizer.encode(response_template_str, add_special_tokens=False)
print(f"Response template: {response_template_str!r}")
print(f"Token IDs: {response_template_ids}")
print(f"Decoded: {tokenizer.decode(response_template_ids)!r}")

# Fallback: without trailing newline (Lesson #26)
response_template_short = "<|im_start|>assistant"
response_template_short_ids = tokenizer.encode(response_template_short, add_special_tokens=False)
print(f"\nShort template: {response_template_short!r}")
print(f"Short IDs: {response_template_short_ids}")

# Verify which template is found in actual data
sample_ids = tokenizer.encode(train_ds[0]["text"], add_special_tokens=False)

def find_subseq(seq, subseq):
    for i in range(len(seq) - len(subseq) + 1):
        if seq[i:i+len(subseq)] == subseq:
            return i
    return -1

pos = find_subseq(sample_ids, response_template_ids)
if pos >= 0:
    print(f"\n\u2705 Full template found at position {pos}")
    use_template_ids = response_template_ids
else:
    pos = find_subseq(sample_ids, response_template_short_ids)
    if pos >= 0:
        print(f"\n\u26a0\ufe0f Using short template (found at position {pos})")
        use_template_ids = response_template_short_ids
    else:
        raise RuntimeError("FATAL: Neither template found in tokenized sample!")

# Create collator
collator = DataCollatorForCompletionOnlyLM(
    response_template=use_template_ids,
    tokenizer=tokenizer,
)

# Preflight: verify masking on 20 samples
print(f"\n\U0001f4ca Preflight masking verification (20 samples)...")
fail_count = 0
total_trainable = 0
total_tokens = 0

for idx in range(min(20, len(train_ds))):
    t = tokenizer(
        train_ds[idx]["text"], return_tensors="pt",
        truncation=True, max_length=MAX_SEQ_LENGTH,
    )
    b = collator([{"input_ids": t["input_ids"][0], "attention_mask": t["attention_mask"][0]}])
    n_train = (b["labels"][0] != -100).sum().item()
    n_total = len(b["labels"][0])
    total_trainable += n_train
    total_tokens += n_total
    if n_train == 0 or n_train == n_total:
        fail_count += 1
        status = "ALL MASKED" if n_train == 0 else "NO MASKING"
        print(f"   \u274c Sample {idx}: {n_train}/{n_total} {status}")

if fail_count == 0:
    pct = 100 * total_trainable / total_tokens
    print(f"   All 20 passed \u2705 (avg {pct:.1f}% trainable tokens)")
else:
    print(f"\n\u274c {fail_count}/20 samples have masking issues!")
    if fail_count > 5:
        raise RuntimeError("TOO MANY FAILURES \u2014 DO NOT TRAIN")

In [None]:
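# === PREFLIGHT MINI-TRAINING (GPT5.2) ===
# Run 2 optimizer steps on a 200-sample subset (only BATCH_SIZE x n_gpus x 2 = 8
# samples are actually consumed) to catch Trainer/device/config issues before
# burning the full training budget. Note: these steps update the shared LoRA weights slightly.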

print("\U0001f6e1\ufe0f Preflight: mini-training (2 steps, 200 samples)...")

preflight_ds = train_ds.select(range(min(200, len(train_ds))))

preflight_config = SFTConfig(
    output_dir="/kaggle/working/preflight_sft",
    num_train_epochs=1,
    max_steps=2,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=1,
    learning_rate=LEARNING_RATE,
    logging_steps=1,
    save_strategy="no",
    fp16=True,
    report_to="none",
    seed=RANDOM_SEED,
    dataset_text_field="text",
    max_seq_length=MAX_SEQ_LENGTH,
    packing=False,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

preflight_trainer = SFTTrainer(
    model=model,
    train_dataset=preflight_ds,
    args=preflight_config,
    processing_class=tokenizer,
    data_collator=collator,
)

preflight_result = preflight_trainer.train()
preflight_loss = preflight_result.metrics.get("train_loss", 0)
print(f"\u2705 Preflight complete! Loss: {preflight_loss:.4f}")
print(f"   No AcceleratorState errors, no device mismatches.")
print(f"   Trainer + LoRA + DataCollator + DataParallel all working.")

# Clean up preflight artifacts
del preflight_trainer
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
import shutil
if os.path.exists("/kaggle/working/preflight_sft"):
    shutil.rmtree("/kaggle/working/preflight_sft")
print("   Preflight artifacts cleaned up.")

## 8. Training

SFT on DAPT'd model: ~255 steps (3 epochs × ~85 steps/epoch).
Should complete in well under 1 hour on dual T4.
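
The step estimate follows from the batch settings in Section 2. The sketch below restates the Training cell's arithmetic. The 1,360-sample train size is a hypothetical figure back-computed from "~85 steps/epoch" at an effective batch of 16.

```python
def plan_steps(n_train, per_device_batch, n_gpus, grad_accum, epochs):
    """Approximate optimizer-step plan under DataParallel.

    Each forward pass consumes per_device_batch * n_gpus samples, and the
    optimizer steps once every grad_accum passes. Floor division matches the
    Training cell's estimate; the Trainer's exact count can differ by a step
    because of how the last partial batch is handled.
    """
    effective = per_device_batch * max(n_gpus, 1) * grad_accum
    steps_per_epoch = max(n_train // effective, 1)
    return effective, steps_per_epoch, steps_per_epoch * epochs

# ~85 steps/epoch at batch 16 implies roughly 1,360 train samples (hypothetical figure)
print(plan_steps(n_train=1360, per_device_batch=2, n_gpus=2, grad_accum=4, epochs=3))  # (16, 85, 255)
```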

In [None]:
# Compute steps
steps_per_epoch = len(train_ds) // (BATCH_SIZE * n_gpus * GRADIENT_ACCUMULATION)
total_steps = steps_per_epoch * NUM_EPOCHS
log_steps = max(total_steps // 30, 5)
eval_steps = max(steps_per_epoch // 2, 10)
save_steps = max(steps_per_epoch, 20)

print(f"\U0001f4ca Training Plan:")
print(f"   Train samples:    {len(train_ds)}")
print(f"   Steps/epoch:      ~{steps_per_epoch}")
print(f"   Total steps:      ~{total_steps}")
print(f"   Log every:        {log_steps} steps")
print(f"   Eval every:       {eval_steps} steps")
print(f"   Save every:       {save_steps} steps")

# Loss logging callback
class LossLoggingCallback(TrainerCallback):
    def __init__(self):
        self.losses = []
        self.eval_losses = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs:
            if "loss" in logs:
                step = state.global_step
                loss = logs["loss"]
                lr = logs.get("learning_rate", 0)
                self.losses.append((step, loss))
                print(f"  Step {step:4d}/{total_steps} | Loss: {loss:.4f} | LR: {lr:.2e}")
            if "eval_loss" in logs:
                self.eval_losses.append((state.global_step, logs["eval_loss"]))
                print(f"  \U0001f4ca Eval Loss: {logs['eval_loss']:.4f}")

loss_callback = LossLoggingCallback()

OUTPUT_DIR = "/kaggle/working/vazhi-sft-v4_0"

sft_config = SFTConfig(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    logging_steps=log_steps,
    save_steps=save_steps,
    eval_steps=eval_steps,
    eval_strategy="steps",
    save_total_limit=3,
    fp16=True,
    bf16=False,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    max_grad_norm=1.0,
    optim="adamw_torch",
    report_to="none",
    seed=RANDOM_SEED,
    load_best_model_at_end=False,
    dataloader_pin_memory=True,
    dataset_text_field="text",
    max_seq_length=MAX_SEQ_LENGTH,
    packing=False,
)

trainer = SFTTrainer(
    model=model,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    args=sft_config,
    processing_class=tokenizer,
    data_collator=collator,
    callbacks=[loss_callback],
)

print(f"\u2705 SFTTrainer ready")
print(f"   Model: {DAPT_MODEL} (DAPT'd instruct)")
print(f"   Completion-only masking: \u2705")
print(f"   Dual T4 DataParallel: {n_gpus} GPUs")

In [None]:
print("\U0001f680 Starting SFT v4.0 training...")
print(f"   ~{total_steps} steps, {NUM_EPOCHS} epochs")
print(f"   Base: DAPT'd instruct (Tamil PPL 2.6)")
print()

train_result = trainer.train()

print("\n\u2705 Training complete!")
metrics = train_result.metrics
for k, v in metrics.items():
    print(f"   {k}: {v}")

# Final eval
print("\n\U0001f4ca Final eval...")
eval_metrics = trainer.evaluate()
for k, v in eval_metrics.items():
    print(f"   {k}: {v}")

# Loss summary
if loss_callback.losses:
    start = loss_callback.losses[0][1]
    end = loss_callback.losses[-1][1]
    print(f"\n\U0001f4c8 Loss: {start:.4f} \u2192 {end:.4f} ({100*(start-end)/start:.1f}% drop)")

## 8a. Resume from Checkpoint (if Kaggle disconnected)

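**Only run if training was interrupted.** Skip if training completed above. If the Python kernel itself restarted, first re-run the earlier cells so that `trainer` and the Section 3 helpers exist again, skipping the preflight mini-training and `trainer.train()` cells.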
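
The resume cell in this section sorts `checkpoint-*` directories by modification time. A timestamp-independent alternative ranks them by the step number in the `checkpoint-<step>` directory names that the Trainer writes. That approach still works if folders are copied and their timestamps reset. `latest_checkpoint` is a hypothetical helper. `transformers.trainer_utils.get_last_checkpoint` does essentially the same job and could be used instead.

```python
import os
import re

_CKPT_RE = re.compile(r"^checkpoint-(\d+)$")

def latest_checkpoint(output_dir):
    """Return the checkpoint-<step> subdirectory with the highest step, or None."""
    if not os.path.isdir(output_dir):
        return None
    best_step, best_path = -1, None
    for name in os.listdir(output_dir):
        m = _CKPT_RE.match(name)
        path = os.path.join(output_dir, name)
        # Compare numerically: a plain string sort would put checkpoint-100 before checkpoint-20
        if m and os.path.isdir(path) and int(m.group(1)) > best_step:
            best_step, best_path = int(m.group(1)), path
    return best_path
```

Usage in the resume cell would be `latest = latest_checkpoint(OUTPUT_DIR)`, followed by `trainer.train(resume_from_checkpoint=latest)` when `latest` is not `None`.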

In [None]:
# === UNCOMMENT ONLY IF TRAINING WAS INTERRUPTED ===

# checkpoints = sorted(glob.glob(f"{OUTPUT_DIR}/checkpoint-*"), key=os.path.getmtime)
# if checkpoints:
#     latest = checkpoints[-1]
#     print(f"\U0001f504 Resuming from {latest}")
#     train_result = trainer.train(resume_from_checkpoint=latest)
#     metrics = train_result.metrics       # used by the metadata cell in Section 9
#     eval_metrics = trainer.evaluate()    # likewise
#     print("\u2705 Resumed and completed!")
#     for k, v in metrics.items():
#         print(f"   {k}: {v}")
# else:
#     print("\u274c No checkpoints found.")

## 9. Save & Upload LoRA Adapter

In [None]:
ADAPTER_PATH = "/kaggle/working/vazhi-sft-v4_0-lora"

print("\U0001f4be Saving LoRA adapter...")
trainer.save_model(ADAPTER_PATH)
tokenizer.save_pretrained(ADAPTER_PATH)

# Save training metadata for reproducibility (GPT5.2)
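import hashlib
# Fingerprint the whole train split (cheap at ~1.5k samples) rather than only the first rows
ds_hash = hashlib.md5("\n".join(train_ds["text"]).encode("utf-8")).hexdigest()[:12]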
metadata = {
    "base_model": DAPT_MODEL,
    "dataset": SFT_DATASET,
    "dataset_hash": ds_hash,
    "train_samples": len(train_ds),
    "eval_samples": len(eval_ds),
    "learning_rate": LEARNING_RATE,
    "epochs": NUM_EPOCHS,
    "lora_r": LORA_R,
    "lora_alpha": LORA_ALPHA,
    "max_seq_length": MAX_SEQ_LENGTH,
    "effective_batch": BATCH_SIZE * n_gpus * GRADIENT_ACCUMULATION,
    "train_loss": metrics.get("train_loss"),
    "eval_loss": eval_metrics.get("eval_loss"),
}
with open(f"{ADAPTER_PATH}/training_metadata.json", "w") as f:
    json.dump(metadata, f, indent=2)
print(f"   training_metadata.json saved")

adapter_files = glob.glob(f"{ADAPTER_PATH}/*")
print(f"   Files: {[os.path.basename(f) for f in adapter_files]}")
assert any('adapter' in f for f in adapter_files), "No adapter files!"
print("\u2705 Adapter saved")

# Upload
api = HfApi()
api.create_repo(ADAPTER_REPO, exist_ok=True)
print(f"\U0001f4e4 Uploading to {ADAPTER_REPO}...")
api.upload_folder(
    folder_path=ADAPTER_PATH,
    repo_id=ADAPTER_REPO,
    commit_message=f"SFT v4.0 adapter on DAPT v1.1, r={LORA_R}, lr={LEARNING_RATE}, {NUM_EPOCHS} epochs",
)
print(f"\u2705 Adapter: https://huggingface.co/{ADAPTER_REPO}")

## 10. Merge LoRA in FP16

**Hard rule (Lesson #27/39):** NEVER merge into 4-bit. Reload base in fp16, merge there.

Note: We merge onto the DAPT'd model, not vanilla — this preserves Tamil fluency.
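
LoRA keeps the base weight W frozen and learns a low-rank update B·A, which PEFT scales by `lora_alpha / r`. Here that is 32/16 = 2. `merge_and_unload()` folds the scaled update into W once, so the merged model computes the same outputs without the adapter. The toy sketch below shows that arithmetic in plain Python, with no PEFT involved. With an fp16 base, the sum W + ΔW is stored directly in fp16. With a 4-bit base, the merged weights would have to be re-quantized, which is a plausible source of the corruption that the hard rule guards against.

```python
def matmul(a, b):
    """Plain nested-list matrix product (a: m x k, b: k x n)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def matvec(m, v):
    return [sum(x * y for x, y in zip(row, v)) for row in m]

def merge_lora(W, A, B, alpha, r):
    """W' = W + (alpha / r) * B @ A, with W: d_out x d_in, B: d_out x r, A: r x d_in."""
    scale = alpha / r
    delta = matmul(B, A)
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, delta)]

# Toy example: d_in=3, d_out=2, r=1, alpha=2 (same scale of 2 as r=16, alpha=32)
W = [[1, 0, 0], [0, 1, 0]]
A = [[1, 2, 3]]
B = [[1], [0]]
x = [1, 1, 1]

merged = matvec(merge_lora(W, A, B, alpha=2, r=1), x)
unmerged = [wx + (2 / 1) * bax for wx, bax in zip(matvec(W, x), matvec(B, matvec(A, x)))]
print(merged, unmerged)  # both [13.0, 1.0]
```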

In [None]:
# Free training model
del model, trainer
gc.collect(); torch.cuda.empty_cache()
print("\U0001f5d1\ufe0f Training model freed")

# Reload DAPT'd base in fp16 for clean merge
print(f"\U0001f517 Loading {DAPT_MODEL} in fp16 for merge...")
base_fp16 = AutoModelForCausalLM.from_pretrained(
    DAPT_MODEL, torch_dtype=torch.float16, device_map={"":0}, trust_remote_code=True,
)

peft_model = PeftModel.from_pretrained(base_fp16, ADAPTER_PATH)
peft_model.gradient_checkpointing_disable()
peft_model.config.use_cache = True
peft_model.eval()

print("\U0001f500 Merging LoRA in fp16...")
merged_model = peft_model.merge_and_unload()
print(f"\u2705 Merged: {merged_model.num_parameters():,} params")

## 11. Evaluation — Chat-Templated Prompts

**This is SFT eval, not DAPT eval.** We test instruction-following with chat template,
not raw text continuation. Uses `apply_chat_template(enable_thinking=False)` per GPT5.2.
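
As far as we understand Qwen3's chat template, `enable_thinking=False` makes the generation prompt end with the assistant header followed by an empty think block. The decoded output of `generate()` therefore contains `<think>` even though the model never produced it. Checks such as `has_think` should run only on the text after that block. The sketch below shows the idea at string level. The exact prefix string is an assumption; print `build_chat_prompt(...)` to confirm it. An alternative that avoids string surgery is to decode only the new tokens, `outputs[0][inputs["input_ids"].shape[1]:]`.

```python
ASSISTANT_HEADER = "<|im_start|>assistant\n"
EMPTY_THINK = "<think>\n\n</think>\n\n"  # assumed Qwen3 prefix when enable_thinking=False

def strip_generation_prefix(decoded: str) -> str:
    """Return only the assistant's answer from a full decoded generate() output."""
    answer = decoded.split(ASSISTANT_HEADER)[-1]
    if answer.startswith(EMPTY_THINK):
        answer = answer[len(EMPTY_THINK):]  # came from the prompt, not from the model
    return answer.split("<|im_end|>")[0].strip()

decoded = (
    "<|im_start|>system\nS<|im_end|>\n<|im_start|>user\nQ<|im_end|>\n"
    "<|im_start|>assistant\n<think>\n\n</think>\n\nவணக்கம்!<|im_end|>"
)
print(strip_generation_prefix(decoded))
```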

In [None]:
merged_model.eval()
merged_model.config.use_cache = True

def compute_repeat_ratio(text, n=3):
    """Fraction of words covered by a repeated word n-gram (any occurrence after the first). >0.2 is bad (GPT5.2)."""
    words = text.split()
    if len(words) < n:
        return 0.0
    ngrams = [tuple(words[i:i+n]) for i in range(len(words) - n + 1)]
    seen = set()
    repeated_positions = set()
    for i, ng in enumerate(ngrams):
        if ng in seen:
            for j in range(i, i + n):
                repeated_positions.add(j)
        seen.add(ng)
    return len(repeated_positions) / max(len(words), 1)

test_prompts = [
    # Greetings (2): "Hello", "Who are you?"
    ("greeting", "\u0bb5\u0ba3\u0b95\u0bcd\u0b95\u0bae\u0bcd"),
    ("greeting", "\u0ba8\u0bc0\u0b99\u0bcd\u0b95\u0bb3\u0bcd \u0baf\u0bbe\u0bb0\u0bcd?"),
    # Factual (3): "What is the capital of Tamil Nadu?", "When is Pongal celebrated?", "What is 2+2?"
    ("factual", "\u0ba4\u0bae\u0bbf\u0bb4\u0bcd\u0ba8\u0bbe\u0b9f\u0bcd\u0b9f\u0bbf\u0ba9\u0bcd \u0ba4\u0bb2\u0bc8\u0ba8\u0b95\u0bb0\u0bae\u0bcd \u0b8e\u0ba9\u0bcd\u0ba9?"),
    ("factual", "\u0baa\u0bca\u0b99\u0bcd\u0b95\u0bb2\u0bcd \u0b8e\u0baa\u0bcd\u0baa\u0bcb\u0ba4\u0bc1 \u0b95\u0bca\u0ba3\u0bcd\u0b9f\u0bbe\u0b9f\u0baa\u0bcd\u0baa\u0b9f\u0bc1\u0b95\u0bbf\u0bb1\u0ba4\u0bc1?"),
    ("factual", "2+2 \u0b8e\u0ba9\u0bcd\u0ba9?"),
    # Culture (2): "Who is Thiruvalluvar?", "What is special about the Tamil language?"
    ("culture", "\u0ba4\u0bbf\u0bb0\u0bc1\u0bb5\u0bb3\u0bcd\u0bb3\u0bc1\u0bb5\u0bb0\u0bcd \u0baf\u0bbe\u0bb0\u0bcd?"),
    ("culture", "\u0ba4\u0bae\u0bbf\u0bb4\u0bcd \u0bae\u0bca\u0bb4\u0bbf\u0baf\u0bbf\u0ba9\u0bcd \u0b9a\u0bbf\u0bb1\u0baa\u0bcd\u0baa\u0bc1 \u0b8e\u0ba9\u0bcd\u0ba9?"),
    # Safety (2): "What should I do if I get a scam message?", "What should I do in a house fire?"
    ("safety", "\u0b92\u0bb0\u0bc1 scam message \u0bb5\u0ba8\u0bcd\u0ba4\u0bbe\u0bb2\u0bcd \u0b8e\u0ba9\u0bcd\u0ba9 \u0b9a\u0bc6\u0baf\u0bcd\u0bb5\u0ba4\u0bc1?"),
    ("safety", "\u0bb5\u0bc0\u0b9f\u0bcd\u0b9f\u0bbf\u0bb2\u0bcd \u0ba4\u0bc0 \u0bb5\u0bbf\u0baa\u0ba4\u0bcd\u0ba4\u0bc1 \u0b8e\u0ba9\u0bcd\u0ba9 \u0b9a\u0bc6\u0baf\u0bcd\u0baf \u0bb5\u0bc7\u0ba3\u0bcd\u0b9f\u0bc1\u0bae\u0bcd?"),
    # Refusal (2): "Will the stock market rise tomorrow?", "Does my computer have a virus?"
    ("refusal", "\u0ba8\u0bbe\u0bb3\u0bc8 \u0baa\u0b99\u0bcd\u0b95\u0bc1 \u0b9a\u0ba8\u0bcd\u0ba4\u0bc8 \u0b8f\u0bb1\u0bc1\u0bae\u0bbe?"),
    ("refusal", "\u0b8e\u0ba9\u0bcd \u0b95\u0ba3\u0bbf\u0ba9\u0bbf\u0baf\u0bbf\u0bb2\u0bcd \u0bb5\u0bc8\u0bb0\u0bb8\u0bcd \u0b87\u0bb0\u0bc1\u0b95\u0bcd\u0b95\u0bbf\u0bb1\u0ba4\u0bbe?"),
    # General (1): "What can I eat in the morning?"
    ("general", "\u0b95\u0bbe\u0bb2\u0bc8\u0baf\u0bbf\u0bb2\u0bcd \u0b8e\u0ba9\u0bcd\u0ba9 \u0b9a\u0bbe\u0baa\u0bcd\u0baa\u0bbf\u0b9f\u0bb2\u0bbe\u0bae\u0bcd?"),
]

print(f"\n{'='*60}")
print(f"\U0001f9ea SFT v4.0 EVAL: {len(test_prompts)} chat-templated prompts")
print(f"   Using: {'apply_chat_template(enable_thinking=False)' if USE_THINKING_FLAG else 'manual ChatML'}")
print(f"{'='*60}")

results = []

for category, prompt_text in test_prompts:
    full_prompt = build_chat_prompt(prompt_text)
    inputs = tokenizer(full_prompt, return_tensors="pt").to(merged_model.device)

    gen_kwargs = dict(
        max_new_tokens=150,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        suppress_tokens=THINK_TOKEN_IDS,
        no_repeat_ngram_size=4,
    )
    if category == "factual":
        gen_kwargs["do_sample"] = False
    else:
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = 0.3
        gen_kwargs["top_p"] = 0.9
        gen_kwargs["repetition_penalty"] = 1.2

    with torch.no_grad():
        outputs = merged_model.generate(**inputs, **gen_kwargs)

    full = tokenizer.decode(outputs[0], skip_special_tokens=False)
    response = extract_response(full)

    t_pct = tamil_char_pct(response)
    words = response.split()
    repeat_r = compute_repeat_ratio(response)
    has_loop = repeat_r > 0.2
    has_system = "system" in response.lower()[:50]
    has_think = "<think>" in response
    is_empty = len(response.strip()) < 5
    is_code = any(c in response[:100] for c in ['=True', '={"', 'var ', 'function', '<br'])

    status = "\u2705"
    if is_code: status = "\u274c CODE"
    elif has_loop: status = "\u26a0\ufe0f LOOP"
    elif has_system: status = "\u274c SYSTEM"
    elif has_think: status = "\u274c THINK"
    elif is_empty: status = "\u274c EMPTY"
    elif t_pct < 20 and category not in ["factual"]: status = "\u26a0\ufe0f LOW TAMIL"

    results.append((category, prompt_text, response[:300], status, t_pct, repeat_r))

    print(f"\n[{category.upper()}] {status} (Tamil: {t_pct:.0f}%, Repeat: {repeat_r:.2f})")
    print(f"  Q: {prompt_text}")
    print(f"  A: {response[:300]}")
    print("-" * 50)

In [None]:
# === EVAL SUMMARY ===
print(f"\n{'='*60}")
print(f"\U0001f4ca SFT v4.0 EVAL SUMMARY")
print(f"{'='*60}")

pass_count = sum(1 for r in results if r[3] == "\u2705")
avg_tamil = np.mean([r[4] for r in results])
avg_repeat = np.mean([r[5] for r in results])
max_repeat = max(r[5] for r in results)

print(f"   Passed:      {pass_count}/{len(results)}")
print(f"   Avg Tamil:   {avg_tamil:.0f}%")
print(f"   Avg Repeat:  {avg_repeat:.2f} (>0.2 is bad)")
print(f"   Max Repeat:  {max_repeat:.2f}")
print()

for cat, prompt, resp, status, tamil, repeat in results:
    print(f"   {status} [{cat}] {prompt[:40]}... (Tamil: {tamil:.0f}%, Rep: {repeat:.2f})")

print(f"\n\U0001f4cb Previous SFT attempts for comparison:")
print(f"   v3.8 (SFT-only, no DAPT): 0/12 passed, avg Tamil 52% \u274c")
print(f"   v3.6 (merge corruption):  0/12 passed, 0% Tamil \u274c")

if pass_count >= len(results) * 0.8 and avg_tamil > 30 and avg_repeat < 0.2:
    print(f"\n\U0001f389 SFT v4.0 successful! Proceed to GGUF quantization.")
elif pass_count >= len(results) * 0.5:
    print(f"\n\u26a0\ufe0f  Partial success. Upload and test manually.")
else:
    print(f"\n\u274c SFT failed. Check loss curve and pre-merge sanity.")
    print(f"   If loss converged but output is bad: try more epochs or higher LR")
    print(f"   If loss didn't converge: dataset may need more samples")

## 12. Upload Merged Model

In [None]:
api = HfApi()
api.create_repo(OUTPUT_MODEL, exist_ok=True)

print(f"\U0001f4e4 Pushing merged fp16 model to {OUTPUT_MODEL}...")
merged_model.push_to_hub(
    OUTPUT_MODEL,
    private=False,
    commit_message=(
        f"SFT v4.0: VAZHI Tamil assistant, DAPT v1.1 base, "
        f"LoRA r={LORA_R}, lr={LEARNING_RATE}, {NUM_EPOCHS} epochs, "
        f"{len(train_ds)} samples"
    ),
)
tokenizer.push_to_hub(OUTPUT_MODEL)

print(f"\n\u2705 Model:   https://huggingface.co/{OUTPUT_MODEL}")
print(f"\u2705 Adapter: https://huggingface.co/{ADAPTER_REPO}")
print(f"\n\U0001f449 Next: Convert to GGUF (Q4_K_M) for mobile deployment")

## Summary

### Pipeline Complete

| Step | Notebook | Artifact |
|------|----------|----------|
| 1. Data Prep | `Vazhi_DAPT_Data_v1_1.ipynb` | `CryptoYogi/vazhi-dapt-tamil-v1_1` |
| 2. DAPT | `Vazhi_DAPT_v1_1_Tamil.ipynb` | `CryptoYogi/qwen3-0.6b-tamil-v1_1` |
| 3. SFT (this) | `Vazhi_SFT_v4_0_OnDAPT.ipynb` | `CryptoYogi/vazhi-v4_0` |

### Key Config

| Parameter | Value |
|-----------|-------|
| Base model | `CryptoYogi/qwen3-0.6b-tamil-v1_1` (DAPT'd instruct) |
| Dataset | `CryptoYogi/vazhi-tamil-sft-v4_0` (ADR-010, ~1,514 samples) |
| LR | 2e-5 |
| Epochs | 3 |
| LoRA | r=16, alpha=32 |
| Masking | Completion-only (assistant responses only) |
| Merge | fp16 (never 4-bit) |

### Next Steps
1. **GGUF Quantization** — Convert to Q4_K_M (~462MB) for mobile
2. **Mobile test** — Load in Flutter app via llama.cpp
3. **Iterate** — If quality is insufficient, increase epochs or add data

### If SFT failed
1. Check loss curve — did it converge?
2. Check pre-merge sanity — did PeftModel responses look ok?
3. Try higher LR (5e-5) or more epochs (5)
4. Dataset may need more samples — combine with more IndicAlign
5. Fallback: Sarvam-1 IQ3_M (1.17GB, proven Tamil)