# VAZHI SFT v3.6 — Return to Instruct Model

**Key changes from v3.5 (failed):**
1. **Back to instruct model:** Qwen3-0.6B (NOT Base). The instruct model already has Tamil capability; the Base model produced code-like garbage with SFT alone
2. **`<think>` token suppression:** suppress during generation, not during training
3. **LR 2e-5:** v3.3 used 1e-4, which was too aggressive for the instruct model
4. **Completion-only masking:** kept from v3.5, but with a robust response template
5. **Strict ChatML validation:** regex-based filter rejects samples missing a user or assistant turn
6. **Length filtering:** cap at 1500 chars to avoid truncation issues
7. **Dataset rebalancing:** add refusal/brevity samples, reduce Kural dominance
8. **Quality evaluation:** Tamil char %, coherence checks, not just pattern absence

### Why this will work
v3.3 proved the instruct model produces Tamil. It had three fixable issues:
- `<think>` tags in output → suppressed via `suppress_tokens` in generation
- LR 1e-4 too aggressive → reduced to 2e-5
- Kural-biased responses → dataset rebalancing

v3.5's pivot to the Base model was a regression: we fix what works instead of starting over.

**Target:** Kaggle P100 (16GB)

## 1. Install Dependencies

**After running this cell, RESTART the session** (Runtime → Restart session)

In [None]:
# Install dependencies — pin TRL to avoid DataCollatorForCompletionOnlyLM removal
!pip install -q -U \
  "transformers>=4.51.0" \
  "accelerate>=0.34.2" \
  "peft>=0.12.0" \
  "trl>=0.12.0,<0.20.0" \
  "bitsandbytes>=0.43.3" \
  "datasets>=2.21.0" \
  "huggingface_hub>=0.24.7"

print("\u2705 Dependencies installed")
print("\u26a0\ufe0f  RESTART THE SESSION NOW (Runtime \u2192 Restart session)")

## 2. Imports & Configuration

In [None]:
# Force single GPU BEFORE importing torch
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import json
import re
import random
import torch
import numpy as np
from datasets import load_dataset, Dataset
from huggingface_hub import login, HfApi

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM

RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# === KEY CONFIG ===
SOURCE_DATASET = "CryptoYogi/vazhi-tamil-sft-v3_3"  # Existing balanced dataset
BASE_MODEL = "Qwen/Qwen3-0.6B"                     # INSTRUCT model (NOT Base!)
OUTPUT_DATASET = "CryptoYogi/vazhi-tamil-sft-v3_6"
OUTPUT_MODEL = "CryptoYogi/vazhi-qwen3-v3_6"

# v3.6 hyperparameters
LEARNING_RATE = 2e-5   # v3.3 used 1e-4 (too aggressive)
NUM_EPOCHS = 3
MAX_SEQ_LENGTH = 1024  # v3.5 showed 512 causes too many truncation warnings
MAX_CHAR_LENGTH = 1500 # Filter out samples longer than this (avoids truncation)
LORA_R = 16            # v3.5 used 32 but instruct model needs less adaptation
LORA_ALPHA = 32

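# English: "You are VAZHI (வழி), an AI assistant for Tamil people.
# Answer in Tamil, clearly and helpfully. If you don't know, say 'I don't know' (தெரியவில்லை)."
SYSTEM_PROMPT = (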
    "நீங்கள் VAZHI (வழி), தமிழ் மக்களுக்கான AI உதவியாளர். "
    "தமிழில் தெளிவாகவும் உதவியாகவும் பதிலளியுங்கள். "
    'தெரியாவிட்டால் "தெரியவில்லை" என்று சொல்லுங்கள்.'
)

print(f"\u2705 Configuration loaded")
print(f"   PyTorch: {torch.__version__}")
print(f"   CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
print(f"")
print(f"\U0001f511 KEY CHANGES in v3.6:")
print(f"   1. INSTRUCT model: {BASE_MODEL} (NOT Base!)")
print(f"   2. LR: {LEARNING_RATE} (v3.3 used 1e-4)")
print(f"   3. <think> token suppression during generation")
print(f"   4. Completion-only masking (robust template)")
print(f"   5. Strict ChatML validation + length filtering")

In [None]:
# Login to HuggingFace
from kaggle_secrets import UserSecretsClient
secrets = UserSecretsClient()
hf_token = secrets.get_secret("HF_TOKEN")
login(token=hf_token)
print("\u2705 Logged in to HuggingFace")

## 3. Dataset Construction — Strict ChatML + Rebalancing

**Rules (from 10 failed attempts):**
1. Every sample MUST have `<|im_start|>user` AND `<|im_start|>assistant` with non-empty content
2. No raw text — only ChatML formatted samples
3. Max 1500 chars per sample — avoids truncation at 1024 tokens
4. No exact verse memorization — Thirukkural retrieval is handled by SQLite
5. Include refusal and brevity discipline samples
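The 1500-character cap in rule 3 is only a proxy for the 1024-token limit: a byte-level BPE tokenizer can spend more than one token on a single Tamil character, so a character cap alone does not guarantee a sample fits. A minimal sketch of a token-level check, assuming a `count_tokens` callable; the helper name `filter_by_token_length` is hypothetical and not part of this notebook:

```python
from typing import Callable, Iterable, List, Tuple


def filter_by_token_length(
    texts: Iterable[str],
    count_tokens: Callable[[str], int],
    max_tokens: int,
) -> Tuple[List[str], int]:
    """Keep samples whose token count fits in max_tokens; return (kept, n_dropped)."""
    kept: List[str] = []
    dropped = 0
    for text in texts:
        if count_tokens(text) <= max_tokens:
            kept.append(text)
        else:
            dropped += 1
    return kept, dropped


# Toy run with a character-count stand-in for the tokenizer.
kept, dropped = filter_by_token_length(["ab", "abcdefgh", "abcd"], len, 4)
print(kept, dropped)  # ['ab', 'abcd'] 1
```

In the notebook, `count_tokens` could be `lambda t: len(tokenizer(t, add_special_tokens=False)["input_ids"])` with `max_tokens=MAX_SEQ_LENGTH`, applied once the tokenizer from Section 4 is loaded.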

In [None]:
# === STRICT ChatML VALIDATOR ===
# Regex-based: rejects anything without proper user AND assistant segments

CHATML_PATTERN = re.compile(
    r'<\|im_start\|>system\n.+?<\|im_end\|>\n'
    r'<\|im_start\|>user\n(.+?)<\|im_end\|>\n'
    r'<\|im_start\|>assistant\n(.+?)<\|im_end\|>',
    re.DOTALL
)

def validate_chatml_strict(text):
    """Validate a sample has proper ChatML with non-empty user AND assistant."""
    match = CHATML_PATTERN.search(text)
    if not match:
        return False, "no ChatML structure found"
    
    user_content = match.group(1).strip()
    assistant_content = match.group(2).strip()
    
    if len(user_content) < 2:
        return False, "empty user content"
    if len(assistant_content) < 2:
        return False, "empty assistant content"
    
    return True, "ok"


def to_chatml(instruction, output, system_prompt=None):
    """Convert instruction/output to strict ChatML format."""
    sp = system_prompt or SYSTEM_PROMPT
    return (
        f"<|im_start|>system\n{sp}<|im_end|>\n"
        f"<|im_start|>user\n{instruction}<|im_end|>\n"
        f"<|im_start|>assistant\n{output}<|im_end|>"
    )


def count_tamil_chars(text):
    """Count characters in the Tamil Unicode block (U+0B80 to U+0BFF)."""
    return sum(1 for c in text if '\u0b80' <= c <= '\u0bff')


def tamil_char_pct(text):
    """Get Tamil character percentage."""
    if not text:
        return 0.0
    return 100.0 * count_tamil_chars(text) / len(text)


print("\u2705 Validation functions defined")

# Self-test
good = to_chatml("test question", "test answer")
valid, reason = validate_chatml_strict(good)
assert valid, f"Self-test failed: {reason}"

bad = "<|im_start|>system\ntest<|im_end|>\n<|im_start|>user\ntest<|im_end|>"
valid, reason = validate_chatml_strict(bad)
assert not valid, "Should reject missing assistant"
print("\u2705 Self-tests passed")

In [None]:
# === LOAD AND FILTER EXISTING DATASET ===

print(f"\U0001f4da Loading source dataset from {SOURCE_DATASET}...")
source_ds = load_dataset(SOURCE_DATASET, split="train")
print(f"   Loaded {len(source_ds)} samples")

# Strict validation pass
valid_samples = []
rejected = {"no_structure": 0, "empty_user": 0, "empty_assistant": 0, "too_long": 0}

for item in source_ds:
    text = item.get("text", "")
    
    # Length filter first
    if len(text) > MAX_CHAR_LENGTH:
        rejected["too_long"] += 1
        continue
    
    valid, reason = validate_chatml_strict(text)
    if valid:
        valid_samples.append(text)
    else:
        if "structure" in reason:
            rejected["no_structure"] += 1
        elif "user" in reason:
            rejected["empty_user"] += 1
        elif "assistant" in reason:
            rejected["empty_assistant"] += 1

print(f"\n\U0001f4ca Filtering results:")
print(f"   Valid samples: {len(valid_samples)}")
print(f"   Rejected:")
for k, v in rejected.items():
    if v > 0:
        print(f"     {k}: {v}")

# Categorize valid samples (keywords: குறள் = Kural, திருக்குறள் = Thirukkural, திருவள்ளுவர் = Thiruvalluvar)
kural_samples = []
other_samples = []
for text in valid_samples:
    if any(k in text for k in ['குறள்', 'திருக்குறள்', 'திருவள்ளுவர்']):
        kural_samples.append(text)
    else:
        other_samples.append(text)

print(f"\n   Kural-related: {len(kural_samples)} ({100*len(kural_samples)/len(valid_samples):.1f}%)")
print(f"   Other: {len(other_samples)} ({100*len(other_samples)/len(valid_samples):.1f}%)")

In [None]:
# === REFUSAL AND BREVITY SAMPLES ===
# GPT5.2: "Small models learn answer length priors aggressively. You must intentionally teach brevity."

refusal_samples = [
    ("2050-ல் யார் பிரதமர் ஆவார்?", "எதிர்காலத்தை கணிக்க என்னால் முடியாது. தெரியவில்லை."),
    ("நாளை பங்கு சந்தை ஏறுமா?", "பங்கு சந்தை நிலவரத்தை கணிக்க என்னால் இயலாது. நிதி ஆலோசகரிடம் கேளுங்கள்."),
    ("என் கணினியில் வைரஸ் இருக்கிறதா?", "உங்கள் கணினியை சரிபார்க்க என்னால் முடியாது. IT நிபுணரிடம் கேளுங்கள்."),
    ("என்னுடைய கடவுச்சொல் என்ன?", "உங்கள் தனிப்பட்ட தகவல்கள் என்னிடம் இல்லை."),
    ("யாரையாவது ஏமாற்ற உதவுங்கள்", "மன்னிக்கவும், சட்டவிரோத செயல்களுக்கு உதவ முடியாது."),
    ("மருத்துவ ஆலோசனை தேவை", "மருத்துவ ஆலோசனைக்கு தகுதி வாய்ந்த மருத்துவரை அணுகுங்கள். நான் பொது தகவல் மட்டுமே தர முடியும்."),
    ("சட்ட ஆலோசனை வேண்டும்", "சட்ட விஷயங்களுக்கு வழக்கறிஞரை அணுகுவது சிறந்தது. நான் பொது தகவல் மட்டுமே தர முடியும்."),
    ("இந்த ஓட்டல்ல பணம் பண்ணலாமா?", "எந்த ஓட்டல் பற்றியும் எனக்கு தெரியாது. நிதி ஆலோசகரிடம் கேளுங்கள்."),
    ("என்ன க்ரிப்டோ வாங்கவேண்டும்?", "மன்னிக்கவும், நிதி ஆலோசனை தர என்னால் முடியாது. SEBI பதிவு செய்த நிறுவனங்களிடம் கேளுங்கள்."),
    ("முகவரி படத்தை படிக்க தருவாயா?", "மன்னிக்கவும், முகவரி படிக்க என்னால் முடியாது. ஜோதிடரிடம் கேளுங்கள்."),
]

# Short-answer discipline samples
brevity_samples = [
    ("தமிழ்நாட்டின் தலைநகரம் என்ன?", "சென்னை."),
    ("2+2 என்ன?", "4."),
    ("10 x 10 என்ன?", "100."),
    ("ஒரு வாரத்தில் எத்தனை நாட்கள்?", "ஏழு நாட்கள்."),
    ("இந்தியாவின் தலைநகரம் எது?", "புது தில்லி."),
    ("சூரியன் எந்த திசையில் உதிக்கும்?", "கிழக்கு திசையில்."),
    ("H2O என்பது என்ன?", "தண்ணீர் (நீர்)."),
    ("தமிழ் எழுத்துக்கள் எத்தனை?", "247."),
    ("பொங்கல் எப்போது?", "தை மாதம் முதல் நாள் (ஜனவரி 14/15)."),
    ("ஆம் என்றால் ஆங்கிலத்தில்?", "Yes."),
    ("இல்லை என்றால் ஆங்கிலத்தில்?", "No."),
    ("Good morning தமிழில் என்ன?", "காலை வணக்கம்."),
    ("நன்றி என்றால் ஆங்கிலத்தில்?", "Thank you."),
    ("மிகப்பெரிய கண்டம் எது?", "ஆசியா."),
    ("மனித உடலில் எத்தனை எலும்புகள்?", "206."),
    ("பூமியின் இயற்கை துணைக்கோள் எது?", "நிலவு (சந்திரன்)."),
]

# Conversational greetings
greeting_samples = [
    ("வணக்கம்", "வணக்கம்! நான் வழி. உங்களுக்கு எப்படி உதவ வேண்டும்?"),
    ("hi", "வணக்கம்! எப்படி உதவலாம்?"),
    ("hello", "வணக்கம்! கேளுங்கள்."),
    ("நீங்கள் யார்?", "நான் வழி (VAZHI), தமிழ் மக்களுக்கான AI உதவியாளர்."),
    ("நன்றி", "மகிழ்ச்சி! வேறு உதவி தேவைப்பட்டால் கேளுங்கள்."),
    ("bye", "வணக்கம்! இனிய நாள் வாழ்த்துக்கள்."),
    ("சரி", "சரி, வேறு ஏதாவது கேள்வி இருக்கிறதா?"),
]

# Convert to ChatML
extra_samples = []
for instruction, output in refusal_samples:
    extra_samples.append(to_chatml(instruction, output))
for instruction, output in brevity_samples:
    extra_samples.append(to_chatml(instruction, output))
for instruction, output in greeting_samples:
    extra_samples.append(to_chatml(instruction, output))

# Validate all new samples
for text in extra_samples:
    valid, reason = validate_chatml_strict(text)
    assert valid, f"Extra sample failed validation: {reason}\n{text[:200]}"

print(f"\u2705 Added {len(extra_samples)} refusal/brevity/greeting samples")
print(f"   Refusal: {len(refusal_samples)}")
print(f"   Brevity: {len(brevity_samples)}")
print(f"   Greeting: {len(greeting_samples)}")

In [None]:
# === REBALANCE AND COMBINE ===

# Downsample Kural to ~15% of total
target_kural_pct = 0.15
total_non_kural = len(other_samples) + len(extra_samples)
target_kural_count = int(target_kural_pct * total_non_kural / (1 - target_kural_pct))

if len(kural_samples) > target_kural_count:
    downsampled_kural = random.sample(kural_samples, target_kural_count)
else:
    downsampled_kural = kural_samples

print(f"\U0001f3af Kural downsampling: {len(kural_samples)} \u2192 {len(downsampled_kural)}")

# Combine everything
final_texts = []
final_texts.extend(downsampled_kural)
final_texts.extend(other_samples)
final_texts.extend(extra_samples)

# Shuffle
random.shuffle(final_texts)

# Final validation pass
final_valid = []
for text in final_texts:
    valid, _ = validate_chatml_strict(text)
    if valid:
        final_valid.append({"text": text})

# Stats
kural_count = sum(1 for s in final_valid
                  if any(k in s["text"] for k in ['குறள்', 'திருக்குறள்', 'திருவள்ளுவர்']))
avg_len = sum(len(s["text"]) for s in final_valid) / len(final_valid)
short_count = sum(1 for s in final_valid if len(s["text"]) < 400)

print(f"\n\U0001f4ca Final v3.6 dataset:")
print(f"   Total samples: {len(final_valid)}")
print(f"   Kural: {kural_count} ({100*kural_count/len(final_valid):.1f}%)")
print(f"   Short (<400 chars): {short_count} ({100*short_count/len(final_valid):.1f}%)")
print(f"   Avg length: {avg_len:.0f} chars")
print(f"   100% ChatML validated \u2705")
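The downsampling target in this cell comes from solving k / (k + n) = p for k, where n is the number of non-Kural samples (`other_samples` plus `extra_samples`) and p is `target_kural_pct`, giving k = p·n / (1 − p). A standalone version of that arithmetic; the helper name and the sample count of 850 are hypothetical:

```python
def kural_target_count(n_other: int, target_pct: float) -> int:
    """Kural sample count k so that k / (k + n_other) is close to target_pct.

    Solving k / (k + n) = p for k gives k = p * n / (1 - p). int() truncates,
    as in the notebook, so the realised share lands at or just below p.
    """
    if not 0.0 <= target_pct < 1.0:
        raise ValueError("target_pct must be in [0, 1)")
    return int(target_pct * n_other / (1 - target_pct))


n_other = 850  # toy value: non-Kural + refusal/brevity/greeting samples
k = kural_target_count(n_other, 0.15)
print(k, round(k / (k + n_other), 3))
```

Because the notebook only downsamples when `len(kural_samples)` exceeds this target, a dataset with fewer Kural samples keeps all of them and ends up below 15%.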

In [None]:
# === UPLOAD TO HUGGINGFACE ===

balanced_ds = Dataset.from_list(final_valid)

api = HfApi()
api.create_repo(OUTPUT_DATASET, repo_type="dataset", exist_ok=True)

balanced_ds.push_to_hub(OUTPUT_DATASET, split="train")
print(f"\u2705 Dataset uploaded: https://huggingface.co/datasets/{OUTPUT_DATASET}")
print(f"   {len(balanced_ds)} samples")

## 4. Load Model + Tokenizer

**CRITICAL:** Using `Qwen/Qwen3-0.6B` (INSTRUCT), not Base.
The instruct model already has Tamil capability — v3.3 proved this.

In [None]:
print(f"\U0001f4e5 Loading tokenizer from {BASE_MODEL}...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
tokenizer.padding_side = "right"

# DO NOT modify pad_token for Qwen3 instruct — it already has proper tokens
# DO NOT add special tokens — instruct model already has ChatML tokens
print(f"\u2705 Tokenizer ready: {len(tokenizer)} tokens")
print(f"   pad_token: {tokenizer.pad_token!r} (ID {tokenizer.pad_token_id})")
print(f"   eos_token: {tokenizer.eos_token!r} (ID {tokenizer.eos_token_id})")

# Verify ChatML tokens exist
for token in ["<|im_start|>", "<|im_end|>"]:
    assert token in tokenizer.get_vocab(), f"Missing {token} in tokenizer!"
print("\u2705 ChatML tokens present in tokenizer")

# Get <think> token ID for suppression during generation
think_token_ids = tokenizer.encode("<think>", add_special_tokens=False)
print(f"\n\U0001f9e0 <think> token IDs: {think_token_ids}")
print(f"   Will suppress these during generation to prevent think-mode output")
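`tokenizer.encode("<think>", ...)` only yields a safe suppression list if `<think>` is a single vocabulary entry; if it ever split into pieces, suppressing those pieces would also block ordinary text that shares them. A small guard that fails loudly instead, sketched with a hypothetical helper and a toy vocabulary whose IDs are invented (they are not Qwen3's real IDs):

```python
from typing import Callable, Dict, List


def single_token_id(token: str, encode: Callable[[str], List[int]]) -> int:
    """Return the one ID for `token`, or raise if it encodes to several pieces."""
    ids = encode(token)
    if len(ids) != 1:
        raise ValueError(f"{token!r} encodes to {len(ids)} tokens: {ids}")
    return ids[0]


# Toy vocabulary for illustration only; IDs are made up.
toy_vocab: Dict[str, List[int]] = {
    "<think>": [9001],
    "</think>": [9002],
    "<thonk>": [27, 339, 29],  # a string that splits into sub-pieces
}
toy_encode = lambda s: toy_vocab[s]

suppress_ids = sorted({single_token_id(t, toy_encode) for t in ("<think>", "</think>")})
print(suppress_ids)  # [9001, 9002]
```

In the notebook, the encoder would be `lambda s: tokenizer.encode(s, add_special_tokens=False)`.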

In [None]:
# 4-bit quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

print(f"\U0001f4e5 Loading model {BASE_MODEL}...")
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map={"":0},
    trust_remote_code=True,
)

# Prepare for training
model = prepare_model_for_kbit_training(model)
model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.config.use_cache = False  # Required for gradient checkpointing

print(f"\u2705 Model loaded: {model.num_parameters():,} params")

## 5. LoRA Setup

In [None]:
lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# Convert any bf16 params to fp16 (safety check for P100)
bf16_count = sum(1 for _, p in model.named_parameters() if p.dtype == torch.bfloat16)
if bf16_count > 0:
    print(f"\u26a0\ufe0f  Converting {bf16_count} bf16 parameters to fp16")
    for name, param in model.named_parameters():
        if param.dtype == torch.bfloat16:
            param.data = param.data.to(torch.float16)
else:
    print("\u2705 No bf16 parameters")
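With `LORA_R = 16` and `LORA_ALPHA = 32`, the adapter update is scaled by alpha / r = 2 (PEFT's default scaling when rsLoRA is off). A pure-Python sketch of the update on toy matrices; helper names are hypothetical, and real LoRA layers apply B·A to activations rather than materializing the merged weight:

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain nested-list matrix product (a: m x k, b: k x n)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def lora_effective_weight(w: Matrix, a: Matrix, b: Matrix, alpha: float, r: int) -> Matrix:
    """W_eff = W + (alpha / r) * (B @ A), with A: r x d_in and B: d_out x r."""
    scale = alpha / r
    delta = matmul(b, a)
    return [[w[i][j] + scale * delta[i][j] for j in range(len(w[0]))] for i in range(len(w))]


w = [[1.0, 0.0], [0.0, 1.0]]
a = [[1.0, 1.0]]          # r = 1, d_in = 2
b_init = [[0.0], [0.0]]   # B starts at zero, so the adapter is a no-op at step 0
b_trained = [[1.0], [0.0]]

print(lora_effective_weight(w, a, b_init, alpha=2, r=1))     # [[1.0, 0.0], [0.0, 1.0]]
print(lora_effective_weight(w, a, b_trained, alpha=2, r=1))  # [[3.0, 2.0], [0.0, 1.0]]
print(32 / 16)  # scale for LORA_ALPHA=32, LORA_R=16 -> 2.0
```

The zero-initialized B is why a freshly wrapped PEFT model generates exactly like the base instruct model before the first optimizer step.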

## 6. Completion-Only Masking (Robust Template)

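**Fix from v3.5:** Try the full template `"<|im_start|>assistant\n"` first, and fall back to `"<|im_start|>assistant"` (no trailing newline) only if the full sequence is not found in a tokenized sample.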
GPT5.2: "The newline can tokenize differently depending on context."
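Completion-only masking means the loss covers only assistant tokens: every position up to the end of the response template gets label -100. A simplified sketch on plain lists (toy IDs, hypothetical helper names) that roughly mirrors single-turn collator behavior, and shows why a template that tokenizes differently in context leaves the whole sample masked:

```python
from typing import List, Optional

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def find_subsequence(seq: List[int], sub: List[int]) -> Optional[int]:
    """Start index of the first occurrence of `sub` in `seq`, or None."""
    for i in range(len(seq) - len(sub) + 1):
        if seq[i:i + len(sub)] == sub:
            return i
    return None


def completion_only_labels(input_ids: List[int], template_ids: List[int]) -> List[int]:
    """Labels equal input_ids, except everything through the response template is masked."""
    start = find_subsequence(input_ids, template_ids)
    if start is None:
        # Template not found: mask the whole sample, so it contributes no loss
        return [IGNORE_INDEX] * len(input_ids)
    cut = start + len(template_ids)
    return [IGNORE_INDEX] * cut + input_ids[cut:]


# Toy IDs: 7 = "<|im_start|>", 8 = "assistant", 9 = "\n", 50-52 = answer tokens
ids = [7, 1, 2, 3, 7, 8, 9, 50, 51, 52]
print(completion_only_labels(ids, [7, 8, 9]))  # seven -100s, then 50, 51, 52
# If "\n" merged into a different token (99) in context, the template no longer matches:
print(completion_only_labels([7, 8, 99, 51], [7, 8, 9]))  # all -100
```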

In [None]:
# Try the full template (WITH trailing newline) first; the short one below is the fallback
response_template_str = "<|im_start|>assistant\n"
response_template_ids = tokenizer.encode(response_template_str, add_special_tokens=False)

print(f"Response template: {response_template_str!r}")
print(f"Token IDs: {response_template_ids}")
print(f"Decoded back: {tokenizer.decode(response_template_ids)!r}")

# If the full template fails, try without newline
# (tokenization can split newlines differently in context)
response_template_short = "<|im_start|>assistant"
response_template_short_ids = tokenizer.encode(response_template_short, add_special_tokens=False)
print(f"\nShort template: {response_template_short!r}")
print(f"Short token IDs: {response_template_short_ids}")

# Verify template can be found in actual data
sample_text = balanced_ds[0]["text"]
sample_ids = tokenizer.encode(sample_text, add_special_tokens=False)

def find_template(sample_ids, template_ids):
    for i in range(len(sample_ids) - len(template_ids) + 1):
        if sample_ids[i:i+len(template_ids)] == template_ids:
            return i
    return -1

pos = find_template(sample_ids, response_template_ids)
if pos >= 0:
    print(f"\n\u2705 Full template found at token position {pos}")
    use_template_ids = response_template_ids
else:
    pos = find_template(sample_ids, response_template_short_ids)
    if pos >= 0:
        print(f"\n\u26a0\ufe0f  Full template NOT found, but short template found at position {pos}")
        print("   Using short template instead")
        use_template_ids = response_template_short_ids
    else:
        print("\n\u274c STOP: Neither template found in tokenized sample!")
        print("   Debug token-by-token:")
        # Show surrounding tokens
        for i, tid in enumerate(sample_ids):
            decoded = tokenizer.decode([tid])
            if 'assistant' in decoded.lower() or tid in response_template_ids:
                context = sample_ids[max(0,i-3):i+5]
                print(f"   Position {i}: {context} = {[tokenizer.decode([t]) for t in context]}")
        use_template_ids = response_template_ids  # Will fail at preflight

In [None]:
# Create collator and run preflight verification
collator = DataCollatorForCompletionOnlyLM(
    response_template=use_template_ids,
    tokenizer=tokenizer,
)

# Preflight: check 20 samples
print(f"\n\U0001f4ca Preflight masking verification (20 samples)...")
fail_count = 0
total_trainable = 0
total_tokens = 0

check_indices = list(range(min(20, len(balanced_ds))))
for idx in check_indices:
    t = tokenizer(
        balanced_ds[idx]["text"], 
        return_tensors="pt", 
        truncation=True, 
        max_length=MAX_SEQ_LENGTH
    )
    b = collator([{"input_ids": t["input_ids"][0], "attention_mask": t["attention_mask"][0]}])
    n_train = (b["labels"][0] != -100).sum().item()
    n_total = len(b["labels"][0])
    total_trainable += n_train
    total_tokens += n_total
    
    if n_train == 0 or n_train == n_total:
        fail_count += 1
        status = "\u274c ALL MASKED" if n_train == 0 else "\u274c NO MASKING"
        print(f"   Sample {idx}: {n_train}/{n_total} {status}")

n_checked = len(check_indices)
if fail_count == 0:
    pct = 100 * total_trainable / total_tokens
    print(f"   All {n_checked} samples passed \u2705")
    print(f"   Avg trainable: {pct:.1f}% of tokens")
else:
    print(f"\n\u274c {fail_count}/{n_checked} samples have masking issues!")
    if fail_count > 5:
        print("   TOO MANY FAILURES \u2014 DO NOT PROCEED WITH TRAINING")

## 7. Training

**Key settings:**
- LR 2e-5 (not 1e-4)
- Save every 50 steps for early checking
- FP32 mode for P100
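These settings imply an effective batch of `per_device_train_batch_size × gradient_accumulation_steps` = 16 samples per optimizer step. A rough sketch for estimating how many optimizer steps, warmup steps, and 50-step checkpoints a run produces; the helper name and the 3200-sample size are hypothetical, and the Trainer's exact rounding can differ between transformers versions:

```python
import math


def training_schedule(n_samples: int, batch_size: int, grad_accum: int, epochs: int,
                      warmup_ratio: float, save_steps: int) -> dict:
    """Approximate optimizer-step counts for a Trainer-style run."""
    micro_batches = math.ceil(n_samples / batch_size)
    steps_per_epoch = max(micro_batches // grad_accum, 1)
    total_steps = steps_per_epoch * epochs
    return {
        "effective_batch": batch_size * grad_accum,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": math.ceil(total_steps * warmup_ratio),
        "checkpoints_written": total_steps // save_steps,
    }


# Toy dataset size; the real count is printed by the dataset-construction cells.
print(training_schedule(3200, batch_size=1, grad_accum=16, epochs=3,
                        warmup_ratio=0.1, save_steps=50))
```

With `save_total_limit=3`, only the three most recent of those checkpoints stay on disk.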

In [None]:
sft_config = SFTConfig(
    output_dir="/kaggle/working/vazhi-v3_6",
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    logging_steps=25,
    save_steps=50,         # Early checkpoints for quality checking
    save_total_limit=3,
    fp16=False,            # Disabled: Qwen3 has internal bf16 ops
    bf16=False,            # Disabled: P100 doesn't support bf16
    gradient_checkpointing=True,
    max_grad_norm=1.0,
    optim="paged_adamw_8bit",
    report_to="none",
    dataset_text_field="text",
    max_seq_length=MAX_SEQ_LENGTH,
    packing=False,
)

trainer = SFTTrainer(
    model=model,
    train_dataset=balanced_ds,
    args=sft_config,
    processing_class=tokenizer,
    data_collator=collator,
)

print("\u2705 Trainer initialized")
print(f"   Model: {BASE_MODEL} (INSTRUCT)")
print(f"   LR: {LEARNING_RATE} (v3.3 used 1e-4)")
print(f"   Epochs: {NUM_EPOCHS}")
print(f"   LoRA: r={LORA_R}, alpha={LORA_ALPHA}")
print(f"   Max seq length: {MAX_SEQ_LENGTH}")
print(f"   Save steps: 50 (for early quality check)")
print(f"   Data collator: DataCollatorForCompletionOnlyLM")

In [None]:
print("\n\U0001f680 Starting training...")
trainer.train()
print("\n\u2705 Training complete!")

## 8. Save & Push to HuggingFace

In [None]:
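from peft import PeftModel

ADAPTER_DIR = "/kaggle/working/vazhi-v3_6-final"

print("\U0001f4be Saving LoRA adapter...")
trainer.save_model(ADAPTER_DIR)

# Merge into a fresh fp16 copy of the base model. Merging into the 4-bit
# training model would bake quantization error into the weights and save a
# bitsandbytes 4-bit checkpoint, which GGUF converters generally can't read.
print("\U0001f500 Merging LoRA weights into fp16 base...")
base_fp16 = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,
    device_map={"": 0},
    trust_remote_code=True,
)
merged_model = PeftModel.from_pretrained(base_fp16, ADAPTER_DIR).merge_and_unload()

api = HfApi()
api.create_repo(OUTPUT_MODEL, exist_ok=True)

print(f"\U0001f4e4 Pushing to {OUTPUT_MODEL}...")
merged_model.push_to_hub(OUTPUT_MODEL, private=False)
tokenizer.push_to_hub(OUTPUT_MODEL, private=False)

print(f"\n\u2705 Model uploaded: https://huggingface.co/{OUTPUT_MODEL}")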

## 9. Quality Evaluation

**Changes from v3.5:**
1. **`<think>` token suppression** via `suppress_tokens`
2. **Tamil character % check:** non-factual responses below 20% Tamil characters are flagged `LOW TAMIL`
3. **Quality scoring:** not just pattern absence
4. **Greedy decoding** for factual QA (deterministic)

In [None]:
merged_model.config.use_cache = True

# Get <think> token IDs to suppress
think_suppress_ids = tokenizer.encode("<think>", add_special_tokens=False)
# Also suppress </think> if present
think_close_ids = tokenizer.encode("</think>", add_special_tokens=False)
suppress_ids = list(set(think_suppress_ids + think_close_ids))
print(f"\U0001f9e0 Suppressing token IDs: {suppress_ids}")
print(f"   Decoded: {[tokenizer.decode([t]) for t in suppress_ids]}")

test_prompts = [
    # Greetings (2)
    ("greeting", "வணக்கம்"),
    ("greeting", "நீங்கள் யார்?"),
    # Factual (3): greedy decoding
    ("factual", "தமிழ்நாட்டின் தலைநகரம் என்ன?"),
    ("factual", "2+2 என்ன?"),
    ("factual", "பொங்கல் எப்போது கொண்டாடப்படுகிறது?"),
    # Culture (2)
    ("culture", "திருக்குறளின் முதல் குறள் என்ன?"),
    ("culture", "திருவள்ளுவர் யார்?"),
    # Safety (2)
    ("safety", "ஒரு scam message வந்தால் என்ன செய்வது?"),
    ("safety", "வீட்டில் தீ விபத்து ஏற்பட்டால் என்ன செய்ய வேண்டும்?"),
    # Refusal (2)
    ("refusal", "நாளை பங்கு சந்தை ஏறுமா?"),
    ("refusal", "என் கணினியில் வைரஸ் இருக்கிறதா?"),
    # General (1)
    ("general", "தமிழ் மொழியின் சிறப்பு என்ன?"),
]

print(f"\n{'='*60}")
print(f"\U0001f9ea EVALUATION: {len(test_prompts)} prompts")
print(f"{'='*60}")

results = []

for category, prompt_text in test_prompts:
    full_prompt = (
        f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
        f"<|im_start|>user\n{prompt_text}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )
    
    inputs = tokenizer(full_prompt, return_tensors="pt").to(merged_model.device)
    
    # Use greedy for factual, sampling for others
    gen_kwargs = dict(
        max_new_tokens=150,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        suppress_tokens=suppress_ids,  # Suppress <think> tokens
    )
    
    if category == "factual":
        gen_kwargs["do_sample"] = False  # Greedy for factual
    else:
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = 0.3
        gen_kwargs["top_p"] = 0.9
        gen_kwargs["repetition_penalty"] = 1.2
    
    with torch.no_grad():
        outputs = merged_model.generate(**inputs, **gen_kwargs)
    
    # Decode only the newly generated tokens, then cut at the first <|im_end|>
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=False)
    response = response.split("<|im_end|>")[0].strip()
    
    # Quality checks
    tamil_pct = tamil_char_pct(response)
    words = response.split()
    # Only flag loops on longer outputs; one-word factual answers are not loops
    has_loop = len(words) >= 10 and len(set(words)) < len(words) * 0.3
    has_system = "system" in response.lower()[:50]
    has_think = "<think>" in response
    is_empty = not response.strip()  # "4." is a valid answer, so no minimum length
    is_code = any(c in response[:100] for c in ['=True', '="', 'var ', 'function', '{"type', '<br'])
    
    # Status with Tamil quality check
    status = "\u2705"
    if is_code: status = "\u274c CODE"
    elif has_loop: status = "\u26a0\ufe0f LOOP"
    elif has_system: status = "\u274c SYSTEM LEAK"
    elif has_think: status = "\u274c THINK LEAK"
    elif is_empty: status = "\u274c EMPTY"
    elif tamil_pct < 20 and category not in ["factual"]:
        status = "\u26a0\ufe0f LOW TAMIL"
    
    results.append((category, prompt_text, response[:200], status, tamil_pct))
    
    print(f"\n[{category.upper()}] {status} (Tamil: {tamil_pct:.0f}%)")
    print(f"Q: {prompt_text}")
    print(f"A: {response[:300]}")
    print("-" * 50)

# Summary
print(f"\n{'='*60}")
print(f"\U0001f4ca EVALUATION SUMMARY")
print(f"{'='*60}")
pass_count = sum(1 for r in results if r[3] == "\u2705")
avg_tamil = sum(r[4] for r in results) / len(results)
print(f"   Passed: {pass_count}/{len(results)}")
print(f"   Avg Tamil: {avg_tamil:.0f}%")
for cat, prompt, resp, status, tamil in results:
    print(f"   {status} [{cat}] {prompt[:40]}... (Tamil: {tamil:.0f}%)")

if pass_count >= len(results) * 0.8 and avg_tamil > 30:
    print(f"\n\U0001f389 Model looks good! Ready for GGUF conversion.")
elif pass_count >= len(results) * 0.5:
    print(f"\n\u26a0\ufe0f  Partially working. Review failures above.")
else:
    print(f"\n\u274c Too many failures. Check:\n   1. Dataset quality\n   2. Loss curve\n   3. Consider DAPT stage")

## Summary

### v3.6 Changes from v3.5

| Setting | v3.5 (failed) | v3.6 (this notebook) |
|---------|---------------|---------------------|
| **Base Model** | **Qwen3-0.6B-Base** | **Qwen3-0.6B (INSTRUCT)** |
| Model has Tamil | No (code/web/Chinese) | Yes (proven in v3.3) |
| `<think>` handling | N/A (base model) | **Suppressed in generation** |
| Learning rate | 2e-5 | 2e-5 (same) |
| LoRA rank | 32 | **16** (instruct needs less) |
| ChatML validation | Basic | **Strict regex** |
| Length filtering | None (max 18K chars!) | **1500 char cap** |
| Refusal samples | None | **10 samples** |
| Brevity samples | None | **16 samples** |
| Response template | `assistant\n` (fragile) | **Robust with fallback** |
| Eval: Tamil check | None | **Tamil char %** |
| Eval: code check | None | **Code pattern detection** |
| Eval: greedy mode | No | **Yes (for factual)** |
| Save steps | 100 | **50** (earlier quality check) |

### If this succeeds:
1. Convert to GGUF (Q4_K_M ~462MB, Q5_K_M ~526MB)
2. Test on mobile via Flutter app
3. Ship hybrid retrieval + LLM reasoning

### If this fails:
1. Add Micro-DAPT stage before SFT (raw Tamil text, no ChatML)
2. Pre-build labels manually instead of TRL template detection
3. Try Sarvam-1 IQ3_M (proven Tamil, 1.17GB)
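A Micro-DAPT stage would train on raw Tamil text with no ChatML and no masking: every token is a label. A common preprocessing step, similar in spirit to `group_texts` in the Hugging Face causal language modeling examples, concatenates documents and slices fixed-length blocks. A minimal sketch on token-ID lists, assuming an EOS separator between documents; helper names are hypothetical:

```python
from typing import Dict, List


def pack_into_blocks(token_lists: List[List[int]], block_size: int, eos_id: int) -> List[List[int]]:
    """Concatenate documents (each followed by eos_id) and cut into equal blocks.

    The trailing remainder shorter than block_size is dropped.
    """
    stream: List[int] = []
    for ids in token_lists:
        stream.extend(ids)
        stream.append(eos_id)
    n_full = len(stream) // block_size
    return [stream[i * block_size:(i + 1) * block_size] for i in range(n_full)]


def to_examples(blocks: List[List[int]]) -> List[Dict[str, List[int]]]:
    """Causal-LM examples: labels are a copy of input_ids (the model shifts them)."""
    return [{"input_ids": blk, "labels": list(blk)} for blk in blocks]


blocks = pack_into_blocks([[1, 2, 3], [4, 5]], block_size=3, eos_id=0)
print(blocks)  # [[1, 2, 3], [0, 4, 5]]
print(to_examples(blocks)[0])
```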