# VAZHI Dataset Factory v4.1.2 — Stage 2 + 3 (GPU)

**Resumes from HF checkpoint** — Stage 1 already completed in v4.1 (CPU).

```
Stage 1 (DONE): 37,947 raw samples → CryptoYogi/vazhi-raw-tamil-qa-v1
┌─ Stage 2: CURATE (this notebook, GPU) ─────────────────────────┐
│ Pass 1 (CPU): lang-id → heuristics → dedup → toxicity          │
│ Pass 2 (GPU): perplexity + IndicSBERT → composite score        │
│ → Upload curated to HF: CryptoYogi/vazhi-curated-tamil-qa-v1   │
├─ Stage 3: COMPOSE (CPU) ────────────────────────────────────────┤
│ Filter → ChatML → absolute count targets → stratified split     │
│ → Upload final SFT to HF: CryptoYogi/vazhi-tamil-sft-v4_1      │
└─────────────────────────────────────────────────────────────────┘
```

**Sources (8 total, from Stage 1):**
- **IndicAlign** (5 subsets): Dolly_T, WikiHow, Indic_ShareLlama, HHRLHF_T, Toxic_Matrix
- **Local**: vazhi-packs (3,007), handcrafted (147), general (8,793)

**Fixes in v4.1.2:**
- Source-aware filtering: vazhi_packs/handcrafted bypass lang-id and tamil_pct filters
- fasttext NumPy 2.x patch (np.array copy=False → np.asarray)
- GPU runtime required for perplexity scoring and SBERT embeddings

**Run on:** Colab T4/A100 or Kaggle P100 (GPU required)

In [None]:
# 1. Config & Dependencies
!pip install -q datasets huggingface_hub fasttext-wheel sentence-transformers hdbscan

import json
import os
import re
import gc
import sys
import time
import random
import hashlib
from collections import Counter
from pathlib import Path

import numpy as np
import torch
from datasets import load_dataset, Dataset, DatasetDict
from huggingface_hub import login, HfApi

# === CONFIG ===
# Pipeline version, shown in the banners printed by later cells.
VERSION = "4.1.2"
# Stage 1 output (this notebook's input), Stage 2 output, and Stage 3 output repos on the HF Hub.
RAW_DATASET = "CryptoYogi/vazhi-raw-tamil-qa-v1"
CURATED_DATASET = "CryptoYogi/vazhi-curated-tamil-qa-v1"
OUTPUT_DATASET = "CryptoYogi/vazhi-tamil-sft-v4_1"
# Model loaded for perplexity scoring in Pass 2E (also named as the SFT base model in the final summary).
DAPT_MODEL = "CryptoYogi/qwen3-0.6b-tamil-v1_1"
# Tokenizer used in Pass 2G to measure the tokenized length of each ChatML sample.
TOKENIZER_MODEL = "Qwen/Qwen3-0.6B"
# Seed Python's and NumPy's global RNGs so the shuffles and splits below are reproducible.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# Max tokenized ChatML length; longer samples lose the length bonus in the quality
# score and are dropped in Stage 3A.
SFT_MAX_SEQ_LENGTH = 2048

# Tamil system prompt that to_chatml() puts in every sample's system turn.
# Approximate English: "You are VAZHI, an AI assistant for Tamil people. Answer
# clearly and helpfully in Tamil. If you don't know, say 'I don't know'."
SYSTEM_PROMPT = (
    "\u0ba8\u0bc0\u0b99\u0bcd\u0b95\u0bb3\u0bcd VAZHI (\u0bb5\u0bb4\u0bbf), \u0ba4\u0bae\u0bbf\u0bb4\u0bcd \u0bae\u0b95\u0bcd\u0b95\u0bb3\u0bc1\u0b95\u0bcd\u0b95\u0bbe\u0ba9 AI \u0b89\u0ba4\u0bb5\u0bbf\u0baf\u0bbe\u0bb3\u0bb0\u0bcd. "
    "\u0ba4\u0bae\u0bbf\u0bb4\u0bbf\u0bb2\u0bcd \u0ba4\u0bc6\u0bb3\u0bbf\u0bb5\u0bbe\u0b95\u0bb5\u0bc1\u0bae\u0bcd \u0b89\u0ba4\u0bb5\u0bbf\u0baf\u0bbe\u0b95\u0bb5\u0bc1\u0bae\u0bcd \u0baa\u0ba4\u0bbf\u0bb2\u0bb3\u0bbf\u0baf\u0bc1\u0b99\u0bcd\u0b95\u0bb3\u0bcd. "
    '\u0ba4\u0bc6\u0bb0\u0bbf\u0baf\u0bbe\u0bb5\u0bbf\u0b9f\u0bcd\u0b9f\u0bbe\u0bb2\u0bcd "\u0ba4\u0bc6\u0bb0\u0bbf\u0baf\u0bb5\u0bbf\u0bb2\u0bcd\u0bb2\u0bc8" \u0b8e\u0ba9\u0bcd\u0bb1\u0bc1 \u0b9a\u0bca\u0bb2\u0bcd\u0bb2\u0bc1\u0b99\u0bcd\u0b95\u0bb3\u0bcd.'
)

# Absolute per-bucket counts for Stage 3B. Keys are matched against sample
# "source" values, plus the synthetic "safety" bucket for safety-routed samples.
# Stage 3B keeps a whole pool when it has <= "target" samples, otherwise the
# top min(target, max) by quality_score; "min" is only checked and reported.
BUCKET_TARGETS = {
    "vazhi_packs":    {"min": 2500, "target": 3000, "max": 3000},
    "handcrafted":    {"min": 100,  "target": 147,  "max": 200},
    "general":        {"min": 300,  "target": 500,  "max": 700},
    "indicalign":     {"min": 10000, "target": 12000, "max": 14000},
    "safety":         {"min": 1500, "target": 2000,  "max": 2500},
}

# Winner selection for exact duplicates in Pass 1C: the higher value wins;
# sources not listed here default to 1.
SOURCE_PRIORITY = {
    "vazhi_packs": 10, "handcrafted": 10,
    "general": 5,
    "indicalign": 2,
}

# GPU check + tier detection.
# Qwen3 has a ~151K-token vocab, so the PPL forward pass materialises logits of
# shape batch x seq x 151K at 2 bytes/element (bf16/fp16). At seq=512 that is
# ~155 MB per sample: ~2.5 GB at batch=16 and ~5 GB at batch=32.
if not torch.cuda.is_available():
    # Explicit raise rather than `assert`, so the check survives `python -O`.
    raise RuntimeError("GPU required! Change runtime: Runtime > Change runtime type > GPU")

GPU_DISPLAY_NAME = torch.cuda.get_device_name(0)
gpu_name = GPU_DISPLAY_NAME.lower()
# bf16-capable cards (A100/A10, L4, H100) use bfloat16; others (e.g. T4, P100) use fp16.
IS_HIGH_END_GPU = any(x in gpu_name for x in ["a100", "l4", "h100", "a10"])
MODEL_DTYPE = torch.bfloat16 if IS_HIGH_END_GPU else torch.float16
SBERT_BATCH_SIZE = 512 if IS_HIGH_END_GPU else 256
VRAM_GB = torch.cuda.get_device_properties(0).total_memory / 1e9

# The perplexity batch size depends only on total VRAM (see the logits estimate above).
if VRAM_GB >= 40:
    PPL_BATCH_SIZE = 32
elif VRAM_GB >= 22:
    PPL_BATCH_SIZE = 16
else:
    PPL_BATCH_SIZE = 8

print(f"\u2705 GPU: {GPU_DISPLAY_NAME} ({VRAM_GB:.0f}GB)")
print(f"   Tier: {'high-end' if IS_HIGH_END_GPU else 'standard'} \u2192 dtype={MODEL_DTYPE}, PPL batch={PPL_BATCH_SIZE}, SBERT batch={SBERT_BATCH_SIZE}")
print(f"\u2705 Config loaded: Dataset Factory v{VERSION}")
print(f"   Resuming from HF: {RAW_DATASET}")
print(f"   Output: {OUTPUT_DATASET}")

In [None]:
# 2. HuggingFace login
# Token sources, in order: Kaggle secret, Colab secret, HF_TOKEN environment
# variable. If none yields a token, fall back to an interactive login prompt.


def _find_hf_token():
    """Return ``(token, origin_label)`` from the first source with a token, else ``(None, None)``.

    The Kaggle/Colab lookups are best-effort. Their modules only exist on their
    own platform, and they may raise platform-specific errors when the secret is
    missing or not shared with the notebook, so any failure moves on to the
    next source. Login happens outside this helper, so an invalid token now
    surfaces as an error instead of silently falling through to another source.
    """
    try:
        from kaggle_secrets import UserSecretsClient
        token = UserSecretsClient().get_secret("HF_TOKEN")
        if token:
            return token, "Kaggle secrets"
    except Exception:
        pass  # not on Kaggle, or the secret is unavailable
    try:
        from google.colab import userdata
        token = userdata.get('HF_TOKEN')
        if token:
            return token, "Colab secrets"
    except Exception:
        pass  # not on Colab, or the secret is unavailable
    token = os.environ.get("HF_TOKEN")
    if token:
        return token, "HF_TOKEN env var"
    return None, None


hf_token, hf_token_origin = _find_hf_token()
if hf_token:
    login(token=hf_token)
    print(f"\u2705 Logged in via {hf_token_origin}")
else:
    login()
    print("\u2705 Logged in interactively")

In [None]:
# 3. Resume from HF checkpoint (Stage 1 already complete)
# Pulls the Stage 1 output from the Hub; `raw_ds` and `total_raw` feed Pass 1.
print(f"Loading raw dataset from HF: {RAW_DATASET}")
raw_ds = load_dataset(RAW_DATASET, split="train")
total_raw = len(raw_ds)
print(f"\u2705 Loaded {total_raw:,} samples")
print(f"   Columns: {raw_ds.column_names}")
source_breakdown = Counter(raw_ds['source']).most_common()
print(f"   Sources: {source_breakdown}")

In [None]:
# 4. Helper functions

def to_chatml(instruction, output, system_prompt=None):
    """Render one single-turn sample in Qwen ChatML format.

    Args:
        instruction: Text of the user turn.
        output: Text of the assistant turn.
        system_prompt: Text of the system turn. Any falsy value (None or "")
            falls back to the module-level SYSTEM_PROMPT.

    Returns:
        The system, user and assistant turns, each wrapped in ``<|im_start|>ROLE``
        ... ``<|im_end|>`` markers and joined by newlines, with no trailing newline.
    """
    turns = (
        ("system", system_prompt or SYSTEM_PROMPT),
        ("user", instruction),
        ("assistant", output),
    )
    return "\n".join(f"<|im_start|>{role}\n{content}<|im_end|>" for role, content in turns)

# Captures (user content, assistant content) from a system/user/assistant ChatML triple.
CHATML_PATTERN = re.compile(
    r'<\|im_start\|>system\n.+?<\|im_end\|>\n'
    r'<\|im_start\|>user\n(.+?)<\|im_end\|>\n'
    r'<\|im_start\|>assistant\n(.+?)<\|im_end\|>',
    re.DOTALL
)


def validate_chatml_strict(text):
    """Check that ``text`` contains a well-formed system/user/assistant triple.

    Uses ``CHATML_PATTERN.search``, so only the first matching triple is
    inspected and text around it is not rejected. The user and assistant
    contents must each be at least 2 characters long after stripping whitespace.

    Returns:
        ``(True, "ok")`` on success, otherwise ``(False, reason)``.
    """
    found = CHATML_PATTERN.search(text)
    if found is None:
        return False, "no ChatML structure found"
    for group_idx, failure in ((1, "empty user content"), (2, "empty assistant content")):
        if len(found.group(group_idx).strip()) < 2:
            return False, failure
    return True, "ok"

def count_tamil_chars(text):
    """Return how many characters of ``text`` fall in the Tamil Unicode block (U+0B80-U+0BFF)."""
    return sum('\u0b80' <= ch <= '\u0bff' for ch in text)

def tamil_char_pct(text):
    """Return the fraction (0.0-1.0) of ``text`` made up of Tamil-block characters.

    Empty or otherwise falsy input yields 0.0 rather than dividing by zero.
    """
    return count_tamil_chars(text) / len(text) if text else 0.0

def is_verbatim_kural_qa(question, answer):
    """Heuristically detect Q&A pairs that just request or recite a Thirukkural couplet.

    Returns True when either:
      * ``question`` matches one of the "recite kural N" style patterns below
        (case-insensitive), or
      * ``answer`` is short (< 200 chars), spans several lines, and contains
        none of the explanation words below, i.e. it looks like a bare
        couplet with no commentary.
    Otherwise returns False.
    """
    # Transliterated: "kural N enna/sollu/sollungal/ezhuthi kaattu/kooruga",
    # "first/muthal kural ...", "thirukkuralin muthal kural", "kural N (ai) ezhuthi".
    question_patterns = (
        r'\u0b95\u0bc1\u0bb1\u0bb3\u0bcd\s*\d+\s*(\u0b8e\u0ba9\u0bcd\u0ba9|\u0b9a\u0bca\u0bb2\u0bcd\u0bb2\u0bc1|\u0b9a\u0bca\u0bb2\u0bcd\u0bb2\u0bc1\u0b99\u0bcd\u0b95\u0bb3\u0bcd|\u0b8e\u0bb4\u0bc1\u0ba4\u0bbf\s*\u0b95\u0bbe\u0b9f\u0bcd\u0b9f\u0bc1|\u0b95\u0bc2\u0bb1\u0bc1\u0b95)',
        r'(first|\u0bae\u0bc1\u0ba4\u0bb2\u0bcd)\s*\u0b95\u0bc1\u0bb1\u0bb3\u0bcd\s*(\u0b8e\u0ba9\u0bcd\u0ba9|\u0b9a\u0bca\u0bb2\u0bcd\u0bb2\u0bc1|\u0b9a\u0bca\u0bb2\u0bcd\u0bb2\u0bc1\u0b99\u0bcd\u0b95\u0bb3\u0bcd)',
        r'\u0ba4\u0bbf\u0bb0\u0bc1\u0b95\u0bcd\u0b95\u0bc1\u0bb1\u0bb3\u0bbf\u0ba9\u0bcd\s+\u0bae\u0bc1\u0ba4\u0bb2\u0bcd\s+\u0b95\u0bc1\u0bb1\u0bb3\u0bcd',
        r'\u0b95\u0bc1\u0bb1\u0bb3\u0bcd\s*[\d]+(?:\s*\u0b90)?\s*\u0b8e\u0bb4\u0bc1\u0ba4\u0bbf',
    )
    if any(re.search(pattern, question, re.IGNORECASE) for pattern in question_patterns):
        return True

    # Transliterated: vilakkam (explanation), porul (meaning), artham (meaning).
    explanation_words = ("\u0bb5\u0bbf\u0bb3\u0b95\u0bcd\u0b95\u0bae\u0bcd", "\u0baa\u0bca\u0bb0\u0bc1\u0bb3\u0bcd", "\u0b85\u0bb0\u0bcd\u0ba4\u0bcd\u0ba4\u0bae\u0bcd")
    looks_like_bare_couplet = (
        len(answer) < 200
        and "\n" in answer
        and not any(word in answer for word in explanation_words)
    )
    return looks_like_bare_couplet

# Self-test: a sample built by to_chatml must pass strict validation.
good = to_chatml("test", "answer")
is_valid, _reason = validate_chatml_strict(good)
assert is_valid
print("\u2705 Helper functions ready")

---
## Checkpoint Resume (run ONLY if resuming from a crash)
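If the runtime restarted and you lost in-memory data, first re-run the config, login, and helper-function cells (they define the constants and functions later cells need), then run the next cell to reload the local checkpoint saved after PPL scoring. **Skip this cell on a fresh run.**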

In [None]:
# CHECKPOINT RESUME — Run ONLY if resuming from crash. Skip on fresh run.
# Loads the deduped list (with PPL scores) plus the Pass 1 counters from the
# local checkpoint file. This restores data only; re-run the config, login and
# helper-function cells first.

import json, os

# Checkpoint directory: Colab's /content, else Kaggle's /kaggle/working, else the
# current directory. On Colab this resolves to /content/vazhi_ppl_checkpoint.json,
# the file the checkpoint-save cell writes there. A hard-coded /content path
# would not exist on Kaggle.
if os.path.isdir("/content"):
    _ckpt_dir = "/content"
elif os.path.isdir("/kaggle/working"):
    _ckpt_dir = "/kaggle/working"
else:
    _ckpt_dir = "."
CHECKPOINT_PATH = os.path.join(_ckpt_dir, "vazhi_ppl_checkpoint.json")

if os.path.exists(CHECKPOINT_PATH):
    print(f"Loading checkpoint from {CHECKPOINT_PATH}...")
    with open(CHECKPOINT_PATH, encoding="utf-8") as f:
        checkpoint = json.load(f)
    deduped = checkpoint["deduped"]
    total_raw = checkpoint["total_raw"]
    # Length-only placeholders: the Pass 1 summary only reads len() of these lists.
    candidates = [None] * checkpoint["n_candidates"]
    clean_candidates = [None] * checkpoint["n_clean"]
    heuristic_dropped = checkpoint["heuristic_dropped"]
    total_dedup_removed = checkpoint["total_dedup_removed"]
    safety_count = checkpoint["safety_count"]
    toxic_dropped = checkpoint["toxic_dropped"]
    print(f"✅ Resumed {len(deduped):,} samples with PPL scores")
    print(f"   PPL scored: {sum(1 for s in deduped if s.get('perplexity') is not None):,}")
    print(f"   → Skip Pass 1A–1D, the Pass 1 summary, PPL scoring and checkpoint save; continue from Pass 2F (SBERT)")
else:
    print(f"No checkpoint at {CHECKPOINT_PATH} — run fresh from Pass 1A")

---
# Stage 2: CURATE

### Pass 1 (CPU): Language detection → heuristics → dedup → toxicity
### Pass 2 (GPU): Perplexity scoring → SBERT embeddings → composite quality score → keyword-based domain tagging

**Fix in v4.1.2:** vazhi_packs and handcrafted bypass lang-id and tamil_pct filters (hand-curated product data)

In [None]:
# Pass 1A: Language Detection with fasttext
# Patch fasttext for NumPy 2.x compatibility before importing: rewrite the
# installed FastText.py's `np.array(probs, copy=False)` (rejected by NumPy 2.x
# when a copy is needed) to `np.asarray(probs)`.

import fasttext
ft_file = os.path.join(os.path.dirname(fasttext.__file__), 'FastText.py')
with open(ft_file) as f:
    code = f.read()
if 'np.array(probs, copy=False)' in code:
    code = code.replace('np.array(probs, copy=False)', 'np.asarray(probs)')
    with open(ft_file, 'w') as f:
        f.write(code)
    print("  Patched fasttext for NumPy 2.x")
# Evict cached fasttext modules so the re-import below loads the patched source.
for mod in [k for k in sys.modules if k.startswith('fasttext')]:
    del sys.modules[mod]
import fasttext
import urllib.request

LID_MODEL_PATH = "/tmp/lid.176.bin"
LID_MODEL_URL = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin"
if not os.path.exists(LID_MODEL_PATH):
    print("Downloading fasttext language ID model (126MB)...")
    # Download under a temporary name and rename only after success. Otherwise an
    # interrupted download leaves a truncated file that the exists() check
    # above would accept on every later run.
    _partial_path = LID_MODEL_PATH + ".part"
    urllib.request.urlretrieve(LID_MODEL_URL, _partial_path)
    os.replace(_partial_path, LID_MODEL_PATH)
lid_model = fasttext.load_model(LID_MODEL_PATH)
print("\u2705 fasttext lid.176.bin loaded")


def detect_lang(text):
    """Return ``(lang_code, confidence)`` for ``text`` from the fasttext LID model.

    fasttext's predict() works on a single line, so newlines are turned into
    spaces first. Blank input or an empty prediction yields
    ``("unknown", 0.0)``. ``lang_code`` is the predicted label with the
    ``__label__`` marker removed (e.g. ``"ta"``).
    """
    single_line = " ".join(text.split("\n")).strip()
    if single_line:
        labels, scores = lid_model.predict(single_line, k=1)
        if labels:
            return labels[0].replace("__label__", ""), float(scores[0])
    return "unknown", 0.0


def add_lang_fields(example):
    """``datasets.map`` callback that tags one example with its detected language.

    Detection runs on the instruction and output joined by a single space.
    Writes ``lang_id`` and ``lang_confidence`` (rounded to 4 decimals) into the
    example and returns it.
    """
    joined_text = " ".join((example["instruction"], example["output"]))
    detected_code, detected_conf = detect_lang(joined_text)
    example["lang_id"] = detected_code
    example["lang_confidence"] = round(detected_conf, 4)
    return example


print(f"Running language detection on {len(raw_ds):,} samples...")
t0 = time.time()
raw_ds = raw_ds.map(add_lang_fields, desc="Language detection")
lang_elapsed = time.time() - t0

lang_counts = Counter(raw_ds["lang_id"])
print(f"\nLanguage distribution (top 10): {lang_counts.most_common(10)}")
print(f"   Lang-id took {lang_elapsed:.0f}s ({len(raw_ds)/lang_elapsed:.0f} samples/sec)")

# Sources whose hand-curated product data is kept whatever language fasttext reports.
_LANG_BYPASS_SOURCES = ("vazhi_packs", "handcrafted")


def _passes_language_filter(row):
    """Keep a row from a bypass source, or one detected as Tamil with confidence >= 0.6."""
    if row["source"] in _LANG_BYPASS_SOURCES:
        return True
    return row["lang_id"] == "ta" and row["lang_confidence"] >= 0.6


# SOURCE-AWARE FILTER: Keep Tamil OR hand-curated product data
before = len(raw_ds)
tamil_ds = raw_ds.filter(_passes_language_filter)
lang_dropped = before - len(tamil_ds)

print(f"\n\u2705 Language detection complete")
print(f"   Before: {before:,}")
print(f"   Dropped (non-Tamil, excl. vazhi_packs/handcrafted): {lang_dropped:,}")
print(f"   Remaining: {len(tamil_ds):,}")

# Per-source check
src_counts = Counter(tamil_ds["source"])
print(f"   Per-source: {src_counts.most_common()}")

# From here on the pipeline works on a plain list of dicts.
candidates = tamil_ds.to_list()
print(f"\u2705 {len(candidates):,} candidates in memory")

del raw_ds, tamil_ds
gc.collect()

In [None]:
# Pass 1B: Quality Heuristics + Format Sanity
# FIX: vazhi_packs/handcrafted bypass tamil_pct filter

def repetition_score(text):
    """Return the share of whitespace tokens taken by the single most frequent token.

    Texts with fewer than 5 tokens are too short to judge and score 0.0.
    Otherwise the result is in (0, 1], where 1.0 means every token is identical.
    """
    tokens = text.split()
    n_tokens = len(tokens)
    if n_tokens >= 5:
        return max(Counter(tokens).values()) / n_tokens
    return 0.0

def is_sane(instruction, output):
    """Return False for samples with obvious formatting corruption, True otherwise.

    Rejected: output containing a ``<think>`` tag, output that equals the
    instruction (ignoring surrounding whitespace), an inline ``data:image``
    payload, an embedded ChatML system header, or a doubled "systemsystem"
    marker (case-insensitive).
    """
    if "<think>" in output or output.strip() == instruction.strip():
        return False
    lowered = output.lower()
    return not (
        "data:image" in output
        or "<|im_start|>system" in output
        or "systemsystem" in lowered
    )

def is_echo(instruction, output):
    """Return True when the output mostly parrots the instruction.

    Compares sets of whitespace tokens: the sample is an echo when more than
    80% of the distinct output tokens also occur in the instruction. An empty
    instruction or output, or an all-whitespace output, is never an echo.
    """
    if instruction and output:
        output_tokens = set(output.split())
        if output_tokens:
            shared = output_tokens.intersection(instruction.split())
            return len(shared) / len(output_tokens) > 0.80
    return False

print(f"Running quality heuristics on {len(candidates):,} candidates...")
heuristic_stats = Counter()
clean_candidates = []
# Hand-curated product data is exempt from the Tamil-share check only.
_TAMIL_PCT_EXEMPT = ("vazhi_packs", "handcrafted")

for s in candidates:
    instr, out = s["instruction"], s["output"]
    rep = repetition_score(out)
    s["repetition"] = round(rep, 4)
    # Ordered (flag, triggered) pairs; a sample is clean only if none trigger.
    checks = (
        ("low_tamil_pct", s["tamil_pct"] < 0.30 and s["source"] not in _TAMIL_PCT_EXEMPT),
        ("high_repetition", rep > 0.25 and len(out.split()) > 20),
        ("echo", is_echo(instr, out)),
        ("format_insane", not is_sane(instr, out)),
        ("trivial_output", len(out.strip()) < 10),
        ("trivial_instruction", len(instr.strip()) < 5),
    )
    flags = [name for name, triggered in checks if triggered]
    s["heuristic_flags"] = flags
    heuristic_stats.update(flags)
    if not flags:
        clean_candidates.append(s)

heuristic_dropped = len(candidates) - len(clean_candidates)
print(f"\n\u2705 Heuristic filtering complete")
print(f"   Before: {len(candidates):,}")
print(f"   Dropped: {heuristic_dropped:,}")
print(f"   Remaining: {len(clean_candidates):,}")
print(f"   Flags: {heuristic_stats.most_common()}")
# Verify vazhi_packs survived
_clean_by_source = Counter(c["source"] for c in clean_candidates)
vp_count = _clean_by_source["vazhi_packs"]
hc_count = _clean_by_source["handcrafted"]
print(f"   vazhi_packs: {vp_count:,} | handcrafted: {hc_count:,}")

In [None]:
# Pass 1C: Deduplication (exact only — MinHash skipped for <50K pool)
# Key = first 200 chars of the instruction, stripped and lowercased. On a
# collision the sample with higher SOURCE_PRIORITY wins (unlisted sources = 1);
# ties go to the higher tamil_pct, and full ties keep the first sample seen.

print(f"Deduplicating {len(clean_candidates):,} candidates...")


def _dedup_rank(sample):
    """Ordering key for duplicate resolution: (source priority, tamil_pct)."""
    return SOURCE_PRIORITY.get(sample["source"], 1), sample["tamil_pct"]


seen = {}
for s in clean_candidates:
    key = s["instruction"][:200].strip().lower()
    incumbent = seen.get(key)
    if incumbent is None or _dedup_rank(s) > _dedup_rank(incumbent):
        seen[key] = s

deduped = list(seen.values())
exact_dupes = len(clean_candidates) - len(deduped)
print(f"  Exact dedupe: {len(clean_candidates):,} \u2192 {len(deduped):,} ({exact_dupes:,} removed)")

# NOTE(review): every survivor is marked non-duplicate, so the is_duplicate
# filter in Stage 3A currently removes nothing.
for s in deduped:
    s["is_duplicate"] = False

total_dedup_removed = exact_dupes
print(f"\u2705 Deduplication complete")

In [None]:
# Pass 1D: Toxicity Screening (source-aware)

# Tamil substrings (matched against the raw text) and English phrases (matched
# against the lowercased text) that mark a text as potentially toxic.
TOXIC_WORDS_TA = [
    "\u0ba4\u0bc7\u0bb5\u0b9f\u0bbf\u0baf\u0bbe", "\u0bb5\u0bc7\u0b9a\u0bbf",
    "\u0baa\u0bc1\u0ba3\u0bcd\u0b9f\u0bc8", "\u0ba4\u0bb1\u0bcd\u0b95\u0bca\u0bb2\u0bc8",
    "\u0b95\u0bca\u0bb2\u0bcd\u0bb2", "\u0bb5\u0bc6\u0b9f\u0bbf\u0baa\u0bca\u0bb0\u0bc1\u0bb3\u0bcd",
]
TOXIC_WORDS_EN = [
    "suicide method", "how to kill", "make a bomb", "hack into",
    "child porn", "exploit children",
]


def check_wordlist(text):
    """Return wordlist hits in ``text`` as ``"ta:<word[:10]>"`` / ``"en:<phrase[:15]>"`` labels.

    Tamil hits come first (in TOXIC_WORDS_TA order), then English hits.
    NOTE(review): this is plain substring matching, so short stems can also
    fire inside longer, benign words; confirm the false-positive rate is acceptable.
    """
    lowered = text.lower()
    tamil_hits = [f"ta:{word[:10]}" for word in TOXIC_WORDS_TA if word in text]
    english_hits = [f"en:{phrase[:15]}" for phrase in TOXIC_WORDS_EN if phrase in lowered]
    return tamil_hits + english_hits


def classify_toxicity(instruction, output, source, subset):
    """Return ``(flags, is_safety_sample)`` for one sample.

    A sample from the ``Toxic_Matrix`` or ``HHRLHF_T`` subsets whose instruction
    hits the wordlist but whose output is clean is treated as a safety
    demonstration and returns ``(instruction_flags, True)``. Every other sample
    returns ``(instruction_flags + output_flags, False)``.

    NOTE(review): ``source`` is accepted but unused, so hand-curated sources get
    no special treatment here despite the "source-aware" label.
    """
    prompt_hits = check_wordlist(instruction)
    response_hits = check_wordlist(output)
    if prompt_hits and not response_hits and subset in ("Toxic_Matrix", "HHRLHF_T"):
        return prompt_hits, True
    return prompt_hits + response_hits, False

print(f"Running toxicity screening on {len(deduped):,} candidates...")
# Tally outcomes per sample. "toxic" samples are only flagged here; the Stage 3A
# toxicity filter is what removes them.
_tox_outcomes = Counter()
for s in deduped:
    tox_flags, is_safety = classify_toxicity(s["instruction"], s["output"], s["source"], s["subset"])
    s["toxicity_flags"] = tox_flags
    s["is_safety_sample"] = is_safety
    if is_safety:
        _tox_outcomes["safety"] += 1
    elif tox_flags:
        _tox_outcomes["toxic"] += 1
safety_count = _tox_outcomes["safety"]
toxic_dropped = _tox_outcomes["toxic"]

print(f"\u2705 Toxicity screening complete")
print(f"   Safety samples: {safety_count:,}")
print(f"   Toxic flagged: {toxic_dropped:,}")
print(f"   Clean: {len(deduped) - safety_count - toxic_dropped:,}")

In [None]:
# Pass 1 Summary

lang_dropped = total_raw - len(candidates)
_rule = "=" * 60
_summary_lines = [
    _rule,
    "PASS 1 SUMMARY (CPU Filters)",
    _rule,
    f"  Raw input:          {total_raw:,}",
    f"  After lang-id:      {len(candidates):,} (-{lang_dropped:,})",
    f"  After heuristics:   {len(clean_candidates):,} (-{heuristic_dropped:,})",
    f"  After dedup:        {len(deduped):,} (-{total_dedup_removed:,})",
    f"  Safety routed:      {safety_count:,}",
    f"  Toxic flagged:      {toxic_dropped:,}",
    f"  \u2192 Pass 1 output:   {len(deduped):,} candidates",
]
print("\n".join(_summary_lines))

survivor_sources = Counter(s["source"] for s in deduped)
print(f"\nSurvivors by source:")
for src, count in survivor_sources.most_common():
    print(f"  {src}: {count:,}")

In [None]:
# Pass 2E: Perplexity Scoring (GPU, BATCHED)
# Qwen3 151K vocab → logits = batch × seq × 151K × 2 bytes
# batch=16, seq=512 → ~2.5GB per batch (safe for L4 22GB)

from transformers import AutoModelForCausalLM, AutoTokenizer

print(f"Loading DAPT v1.1 model: {DAPT_MODEL}")
ppl_tokenizer = AutoTokenizer.from_pretrained(DAPT_MODEL)
# Batched scoring pads to the longest text, so reuse EOS as the pad token when none is defined.
if ppl_tokenizer.pad_token is None:
    ppl_tokenizer.pad_token = ppl_tokenizer.eos_token

_ppl_load_kwargs = {"torch_dtype": MODEL_DTYPE, "device_map": "auto"}
ppl_model = AutoModelForCausalLM.from_pretrained(DAPT_MODEL, **_ppl_load_kwargs).eval()
print(f"\u2705 PPL model loaded on {ppl_model.device} (dtype={MODEL_DTYPE})")

# Score every deduped sample. This is an alias, not a copy, so scores written
# through ppl_candidates land directly on the dicts in `deduped`.
ppl_candidates = deduped
print(f"  Scoring {len(ppl_candidates):,} candidates in batches of {PPL_BATCH_SIZE}")


def masked_perplexity(sample_logits, sample_ids, sample_mask, loss_fct):
    """Return exp(mean next-token loss) for one padded sequence, or None.

    Args:
        sample_logits: Logits for one sequence, indexed ``[position, vocab]``.
        sample_ids: Token ids for the same sequence, indexed ``[position]``.
        sample_mask: Attention mask (1 = real token, 0 = padding), indexed ``[position]``.
        loss_fct: ``CrossEntropyLoss(reduction='none')``, giving one loss per position.

    Position t predicts token t+1. A (context, target) pair is scored only when
    both tokens are real, which is correct for both right- and left-padded
    batches. Logits are upcast to float32, so the loss and exp() are not computed
    in bf16/fp16 (exp overflows fp16 once the mean loss exceeds ~11).

    Returns:
        The perplexity as a Python float, or None when no pair is scoreable
        (e.g. a single-token sequence).
    """
    shift_logits = sample_logits[:-1].float()
    shift_labels = sample_ids[1:]
    pair_mask = (sample_mask[:-1] * sample_mask[1:]).float()
    n_pairs = pair_mask.sum()
    if n_pairs <= 0:
        return None
    losses = loss_fct(shift_logits, shift_labels)
    return torch.exp((losses * pair_mask).sum() / n_pairs).item()


def compute_perplexity_batch(texts, batch_size=None, max_length=512):
    """Score every text's perplexity under ``ppl_model`` (one forward pass per batch).

    Args:
        texts: List of strings to score.
        batch_size: Texts per forward pass. None (the default) means
            PPL_BATCH_SIZE from the config cell.
        max_length: Per-text token truncation limit.

    Returns:
        A list aligned with ``texts`` holding a float perplexity per text, or
        None for texts with no scoreable token pair.
    """
    if batch_size is None:
        batch_size = PPL_BATCH_SIZE
    ppls = []
    loss_fct = torch.nn.CrossEntropyLoss(reduction='none')

    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]
        inputs = ppl_tokenizer(
            batch, return_tensors="pt", truncation=True,
            max_length=max_length, padding=True
        ).to(ppl_model.device)

        with torch.no_grad():
            logits = ppl_model(**inputs).logits

        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        # Per-sample work happens inside masked_perplexity, so its temporaries are
        # released on return and no slice of `logits` outlives the del below.
        for j in range(len(batch)):
            ppls.append(masked_perplexity(logits[j], input_ids[j], attention_mask[j], loss_fct))

        # Free GPU memory between batches (151K-vocab logits are large).
        del logits, inputs, input_ids, attention_mask
        torch.cuda.empty_cache()

        if (i // batch_size) % 50 == 0:
            print(f"  ...{min(i+batch_size, len(texts)):,} / {len(texts):,}")

    return ppls


output_texts = [s["output"] for s in ppl_candidates]
print(f"\nComputing perplexity...")
t0 = time.time()
ppl_scores = compute_perplexity_batch(output_texts)
elapsed = time.time() - t0

# Store rounded scores; unscoreable samples keep None.
for s, raw_ppl in zip(ppl_candidates, ppl_scores):
    s["perplexity"] = None if raw_ppl is None else round(raw_ppl, 2)

# Guarantee the key exists on every deduped sample.
for s in deduped:
    s.setdefault("perplexity", None)

scored = [s["perplexity"] for s in deduped if s["perplexity"] is not None]
if scored:
    median_ppl = sorted(scored)[len(scored) // 2]
    low_band = sum(1 for p in scored if p < 50)
    mid_band = sum(1 for p in scored if 50 <= p < 200)
    high_band = sum(1 for p in scored if p >= 200)
    print(f"\n\u2705 Perplexity scoring complete in {elapsed:.0f}s ({elapsed/60:.1f} min)")
    print(f"   Scored: {len(scored):,} / {len(deduped):,}")
    print(f"   Speed: {len(scored)/elapsed:.0f} samples/sec")
    print(f"   PPL: min={min(scored):.1f}, median={median_ppl:.1f}, max={max(scored):.1f}")
    print(f"   PPL < 50: {low_band:,}")
    print(f"   PPL 50-200: {mid_band:,}")
    print(f"   PPL >= 200 (garbage): {high_band:,}")

del ppl_model
torch.cuda.empty_cache()
print("\u2705 PPL model unloaded")

In [None]:
# CHECKPOINT SAVE — Persist PPL-scored data to local file
# Protects against runtime restart losing 15+ min of GPU work

# Checkpoint directory: Colab's /content, else Kaggle's /kaggle/working, else
# the current directory. A hard-coded /content path does not exist on Kaggle,
# so the write below would fail there.
if os.path.isdir("/content"):
    _ckpt_dir = "/content"
elif os.path.isdir("/kaggle/working"):
    _ckpt_dir = "/kaggle/working"
else:
    _ckpt_dir = "."
CHECKPOINT_PATH = os.path.join(_ckpt_dir, "vazhi_ppl_checkpoint.json")

checkpoint = {
    "deduped": deduped,
    "total_raw": total_raw,
    "n_candidates": len(candidates),
    "n_clean": len(clean_candidates),
    "heuristic_dropped": heuristic_dropped,
    "total_dedup_removed": total_dedup_removed,
    "safety_count": safety_count,
    "toxic_dropped": toxic_dropped,
}

import json
# Write to a temp file and atomically swap it in, so a crash mid-write never
# leaves a truncated checkpoint behind. UTF-8 with ensure_ascii=False stores
# Tamil characters directly instead of as 6-byte escape sequences.
_tmp_checkpoint_path = CHECKPOINT_PATH + ".tmp"
with open(_tmp_checkpoint_path, "w", encoding="utf-8") as f:
    json.dump(checkpoint, f, ensure_ascii=False)
os.replace(_tmp_checkpoint_path, CHECKPOINT_PATH)

size_mb = os.path.getsize(CHECKPOINT_PATH) / 1e6
print(f"✅ Checkpoint saved: {CHECKPOINT_PATH} ({size_mb:.1f} MB)")
print(f"   {len(deduped):,} samples with PPL scores preserved")
print(f"   If runtime restarts, run the resume cell to skip Pass 1 + PPL")

In [None]:
# Pass 2F: SBERT Embeddings (skip HDBSCAN — domain classification uses keyword tagger)
# HDBSCAN is O(n²) on 35K × 768-dim = too slow. Keyword tagger is faster and directly useful.
# NOTE(review): the embeddings computed below are deleted at the end of this
# cell without being stored or used; only placeholder cluster fields are
# written. Confirm whether this GPU pass is still needed.

from sentence_transformers import SentenceTransformer

SBERT_MODEL_NAME = "l3cube-pune/tamil-sentence-similarity-sbert"
print("Loading IndicSBERT...")
sbert_model = SentenceTransformer(SBERT_MODEL_NAME)
print(f"✅ IndicSBERT loaded")

# Embed the first 512 characters of each instruction (unit-normalised vectors).
instructions = [s["instruction"][:512] for s in deduped]
print(f"Computing embeddings for {len(instructions):,} instructions (batch={SBERT_BATCH_SIZE})...")
t0 = time.time()
embeddings = sbert_model.encode(
    instructions,
    batch_size=SBERT_BATCH_SIZE,
    show_progress_bar=True,
    normalize_embeddings=True,
)
embed_elapsed = time.time() - t0
print(f"✅ Embeddings: {embeddings.shape} in {embed_elapsed:.0f}s ({len(instructions)/embed_elapsed:.0f} samples/sec)")

# Skip HDBSCAN — domain classification uses keyword tagger (next cell)
for s in deduped:
    s.update(embedding_cluster=-1, auto_category="unclustered")

del sbert_model, embeddings
torch.cuda.empty_cache()
print(f"✅ Done — domain classification in next cell")

In [None]:
# Pass 2G: Composite Quality Score + Tokenized Length

from transformers import AutoTokenizer as AT

# This tokenizer is only used to measure ChatML lengths against SFT_MAX_SEQ_LENGTH.
print(f"Loading tokenizer: {TOKENIZER_MODEL}")
sft_tokenizer = AT.from_pretrained(TOKENIZER_MODEL, trust_remote_code=True)
_sft_vocab_size = sft_tokenizer.vocab_size
print(f"\u2705 Tokenizer loaded (vocab: {_sft_vocab_size:,})")


def compute_quality_score(s):
    """Blend per-sample signals into one quality score, rounded to 4 decimals.

    Weights: fasttext language confidence 0.35, repetition penalty
    1 / (1 + 10 * repetition) 0.25, Tamil character share 0.20, fitting within
    SFT_MAX_SEQ_LENGTH tokens 0.10, and having no toxicity flags 0.10. With
    inputs in their normal ranges the score lies in [0, 1].

    A missing lang_confidence or tamil_pct counts as 0.0, a missing
    tokenized_length counts as too long, and a missing repetition or
    toxicity_flags counts as clean.
    """
    fits_context = s.get("tokenized_length", SFT_MAX_SEQ_LENGTH + 1) <= SFT_MAX_SEQ_LENGTH
    language_term = s.get("lang_confidence", 0.0) * 0.35
    repetition_term = (1 / (1 + s.get("repetition", 0.0) * 10)) * 0.25
    tamil_term = s.get("tamil_pct", 0.0) * 0.20
    length_term = (1.0 if fits_context else 0.0) * 0.10
    safety_term = (0.0 if s.get("toxicity_flags", []) else 1.0) * 0.10
    return round(language_term + repetition_term + tamil_term + length_term + safety_term, 4)


print(f"Scoring {len(deduped):,} samples...")
for n_done, s in enumerate(deduped, start=1):
    # Length is measured on the exact ChatML text SFT will see (system prompt included).
    sft_text = to_chatml(s["instruction"], s["output"])
    s["tokenized_length"] = len(sft_tokenizer.encode(sft_text, add_special_tokens=False))
    s["quality_score"] = compute_quality_score(s)
    if n_done % 10000 == 0:
        print(f"  ...{n_done:,} / {len(deduped):,}")

tok_lengths = [s["tokenized_length"] for s in deduped]
quality_scores = [s["quality_score"] for s in deduped]
print(f"\n\u2705 Scoring complete")
print(f"   Tokens: min={min(tok_lengths)}, median={sorted(tok_lengths)[len(tok_lengths)//2]}, max={max(tok_lengths)}")
print(f"   Within 2048: {sum(1 for t in tok_lengths if t <= SFT_MAX_SEQ_LENGTH):,} / {len(tok_lengths):,}")
print(f"   Quality: min={min(quality_scores):.3f}, median={sorted(quality_scores)[len(quality_scores)//2]:.3f}")
print(f"   Score >= 0.45: {sum(1 for q in quality_scores if q >= 0.45):,}")

In [None]:
# Pass 2H: Domain Classification (keyword-based + source-aware)
# Makes curated dataset filterable by domain for selective SFT emphasis

# Domain -> keyword list used by classify_domain(). Each list holds Tamil entries
# (apparently word stems, so inflected forms can match) followed by English ones.
# classify_domain compares lowercased keywords against lowercased instruction +
# output text. Dict order matters: on a tie in hit count the earlier domain wins.
# The keys must also match the vazhi_packs "subset" names, which classify_domain
# uses directly as the domain.
DOMAIN_KEYWORDS = {
    "healthcare": [
        "மருத்துவ", "சிகிச்சை", "நோய்", "மருந்து", "உடல்", "ஆரோக்கிய",
        "மருத்துவர்", "காய்ச்சல்", "நீரிழிவு", "இரத்த", "புற்றுநோய்",
        "கர்ப்ப", "தடுப்பூசி", "ஊட்டச்சத்து", "முதலுதவி", "மனநல",
        "hospital", "doctor", "health", "diabetes", "fever", "medicine",
    ],
    "legal": [
        "சட்ட", "வழக்கு", "நீதிமன்ற", "உரிமை", "FIR", "காவல்",
        "வழக்கறிஞர்", "ஜாமின்", "தண்டனை", "குற்ற", "சொத்து",
        "பத்திரம்", "விவாகரத்து", "புகார்", "நுகர்வோர்",
        "legal", "court", "law", "police", "rights", "property",
    ],
    "education": [
        "கல்வி", "பள்ளி", "பல்கலை", "தேர்வு", "படிப்பு", "மாணவ",
        "புலமைப்பரிசு", "NEET", "கல்லூரி", "பட்டம்", "ஆசிரிய",
        "பாடத்திட்ட", "சான்றிதழ்", "அரசுப்பள்ளி", "உதவித்தொகை",
        "school", "education", "exam", "scholarship", "student",
    ],
    "security": [
        "மோசடி", "ஏமாற்ற", "பாதுகாப்பு", "OTP", "ஹேக்",
        "சைபர்", "போலி", "வங்கி மோசடி", "ஃபிஷிங்", "லிங்க்",
        "scam", "fraud", "phishing", "hack", "password", "cyber",
    ],
    "government": [
        "அரசு", "திட்ட", "ஓய்வூதிய", "ரேஷன்", "மானிய",
        "ஆதார்", "வாக்காளர்", "பிறப்பு சான்றிதழ்", "பாஸ்போர்ட்",
        "வருமான சான்றிதழ்", "சமூக நல", "இலவச", "அரசாணை",
        "pension", "ration", "scheme", "government", "subsidy", "aadhaar",
    ],
    "culture": [
        "திருக்குறள்", "பண்பாடு", "கலாச்சார", "பொங்கல்", "திருவிழா",
        "தமிழ் இலக்கிய", "சங்க", "பாரதியார்", "கோவில்", "நாட்டுப்புற",
        "பரதநாட்டிய", "கர்நாடக இசை", "சிலப்பதிகாரம்", "தொல்காப்பிய",
        "culture", "temple", "festival", "literature", "heritage",
    ],
}


def _keyword_in_text(keyword, lowered_text):
    """Return True if ``keyword`` occurs in the already-lowercased ``lowered_text``.

    Non-ASCII (Tamil) keywords are stems and match as plain substrings. ASCII
    keywords must not be glued to other ASCII letters (an optional "s"/"es"
    plural is allowed), so "ration" no longer fires inside "generation". Only
    ASCII letters count as word characters here, so a Tamil case suffix written
    directly after an English keyword still matches.
    """
    kw = keyword.lower()
    if not kw.isascii():
        return kw in lowered_text
    pattern = r"(?<![a-z])" + re.escape(kw) + r"(?:e?s)?(?![a-z])"
    return re.search(pattern, lowered_text) is not None


def classify_domain(sample, keywords=None):
    """Assign a coarse domain label to one sample.

    Precedence:
      1. vazhi_packs samples whose ``subset`` is a known domain keep that subset.
      2. Safety-routed samples (``is_safety_sample`` true) -> "safety".
      3. handcrafted samples -> "safety".
      4. Otherwise the domain with the most distinct keyword hits in
         instruction + output (ties keep the earlier domain in dict order),
         provided it has at least 2 hits; else "general".

    Matching is case-insensitive. English keywords are matched as whole words,
    so that e.g. "fir" (from "FIR") does not match inside "first", "exam" inside
    "example", or "culture" inside "agriculture".

    Args:
        sample: Dict with "source", "instruction" and "output", and optionally
            "subset" and "is_safety_sample".
        keywords: Mapping of domain -> keyword list. Defaults to DOMAIN_KEYWORDS.

    Returns:
        The domain label string.
    """
    domain_keywords = DOMAIN_KEYWORDS if keywords is None else keywords
    source = sample["source"]
    subset = sample.get("subset", "")

    # vazhi_packs: subset IS the domain
    if source == "vazhi_packs" and subset in domain_keywords:
        return subset

    # Safety samples get "safety" domain
    if sample.get("is_safety_sample", False):
        return "safety"

    # Handcrafted: map to safety/guardrails
    if source == "handcrafted":
        return "safety"

    # Keyword matching on instruction + output
    text = (sample["instruction"] + " " + sample["output"]).lower()
    best_domain = "general"
    best_hits = 0
    for domain, domain_kws in domain_keywords.items():
        hits = sum(1 for kw in domain_kws if _keyword_in_text(kw, text))
        if hits > best_hits:
            best_hits = hits
            best_domain = domain

    # Require at least 2 keyword hits to avoid false positives
    return best_domain if best_hits >= 2 else "general"


print(f"Classifying {len(deduped):,} samples into domains...")
t0 = time.time()
domain_counts = Counter()
for s in deduped:
    label = classify_domain(s)
    s["domain"] = label
    domain_counts[label] += 1
elapsed = time.time() - t0

print(f"\n✅ Domain classification complete in {elapsed:.1f}s")
print(f"\nDomain distribution:")
n_deduped = len(deduped)
for domain, count in domain_counts.most_common():
    pct = 100 * count / n_deduped
    print(f"  {domain}: {count:,} ({pct:.1f}%)")

# Cross-check: vazhi_packs domains should match their pack names
vp_domains = Counter(s["domain"] for s in deduped if s["source"] == "vazhi_packs")
print(f"\nvazhi_packs domains: {vp_domains.most_common()}")

In [None]:
# Upload curated dataset to HF


def _curated_record(s):
    """Project one working sample onto the published curated-dataset columns (fixed order).

    Core fields are read directly (KeyError if missing). Pipeline annotations
    fall back to neutral defaults when a pass did not set them.
    """
    return {
        "instruction": s["instruction"], "output": s["output"],
        "source": s["source"], "subset": s["subset"],
        "char_length": s["char_length"], "tamil_pct": s["tamil_pct"],
        "lang_id": s["lang_id"], "lang_confidence": s["lang_confidence"],
        "heuristic_flags": s.get("heuristic_flags", []),
        "repetition": s.get("repetition", 0.0),
        "toxicity_flags": s.get("toxicity_flags", []),
        "is_safety_sample": s.get("is_safety_sample", False),
        "is_duplicate": s.get("is_duplicate", False),
        "perplexity": s.get("perplexity"),
        "embedding_cluster": s.get("embedding_cluster"),
        "auto_category": s.get("auto_category"),
        "domain": s.get("domain", "general"),
        "quality_score": s.get("quality_score", 0.0),
        "tokenized_length": s.get("tokenized_length", 0),
    }


curated_records = [_curated_record(s) for s in deduped]

print(f"Uploading {len(curated_records):,} curated samples to {CURATED_DATASET}...")
api = HfApi()
api.create_repo(CURATED_DATASET, repo_type="dataset", exist_ok=True)
curated_ds = Dataset.from_list(curated_records)
curated_ds.push_to_hub(CURATED_DATASET)
print(f"✅ Curated dataset uploaded: https://huggingface.co/datasets/{CURATED_DATASET}")

---
# Stage 3: COMPOSE

Select from curated pools with absolute count targets, build final SFT dataset.

In [None]:
# Stage 3A: Filtering

print("Applying Stage 3 filters...")
df = curated_records.copy()  # a list of dicts (not a DataFrame)
before = len(df)

# (label, keep-predicate) pairs, applied in order; each step prints how many
# samples survive and how many it removed.
_stage3_steps = [
    ("After dedup", lambda s: not s["is_duplicate"]),
    (f"After token \u2264 {SFT_MAX_SEQ_LENGTH}", lambda s: s["tokenized_length"] <= SFT_MAX_SEQ_LENGTH),
    ("After lang_id (source-aware)", lambda s: s["lang_id"] == "ta" or s["source"] in ("vazhi_packs", "handcrafted")),
    ("After clean heuristics", lambda s: len(s["heuristic_flags"]) == 0),
    ("After toxicity", lambda s: len(s["toxicity_flags"]) == 0 or s["is_safety_sample"]),
    ("After quality \u2265 0.45", lambda s: s["quality_score"] >= 0.45),
    ("After PPL < 200", lambda s: s["perplexity"] is None or s["perplexity"] < 200),
]
for _step_label, _keep in _stage3_steps:
    b = len(df)
    df = [s for s in df if _keep(s)]
    print(f"  {_step_label}: {len(df):,} (-{b - len(df):,})")

print(f"\n\u2705 Filtering: {before:,} \u2192 {len(df):,}")
filtered_sources = Counter(s["source"] for s in df)
print(f"\nFiltered pool by source:")
for src, count in filtered_sources.most_common():
    print(f"  {src}: {count:,}")

In [None]:
# Stage 3B: Composition with Absolute Count Targets

safety_pool = [s for s in df if s["is_safety_sample"]]
non_safety = [s for s in df if not s["is_safety_sample"]]

# Non-safety samples are pooled by source; safety samples form their own pool.
source_pools = {}
for s in non_safety:
    source_pools.setdefault(s["source"], []).append(s)
source_pools["safety"] = safety_pool


def _select_for_bucket(bucket_name, pool, targets):
    """Choose one bucket's samples.

    Keeps the whole pool when it has <= ``target`` samples (warning first when
    it is below ``min``). Otherwise keeps the top ``min(target, max)`` by
    quality_score; the sort is stable, so ties keep pool order.
    """
    if len(pool) < targets["min"]:
        print(f"  \u26a0\ufe0f {bucket_name}: only {len(pool):,} available, min is {targets['min']}")
        return pool
    if len(pool) <= targets["target"]:
        return pool
    cap = min(targets["target"], targets["max"])
    return sorted(pool, key=lambda x: x["quality_score"], reverse=True)[:cap]


print("Composing final dataset...")
composed = {}
total_composed = 0

for bucket_name, targets in BUCKET_TARGETS.items():
    pool = source_pools.get(bucket_name, [])
    selected = _select_for_bucket(bucket_name, pool, targets)
    composed[bucket_name] = selected
    total_composed += len(selected)
    print(f"  {bucket_name}: {len(selected):,} / {len(pool):,} (target: {targets['target']}, range: {targets['min']}-{targets['max']})")

print(f"\n\u2705 Composition: {total_composed:,} total")

_unmet_buckets = [
    (name, len(composed.get(name, [])), spec["min"])
    for name, spec in BUCKET_TARGETS.items()
    if len(composed.get(name, [])) < spec["min"]
]
for name, actual, minimum in _unmet_buckets:
    print(f"  \u274c {name}: {actual} < min {minimum}")
all_met = not _unmet_buckets
if all_met:
    print("\u2705 All bucket minimums met")

In [None]:
# Stage 3C: ChatML Conversion + Validation

all_samples = []
chatml_failures = 0


def _sft_record(bucket_name, s, text):
    """Build the published SFT row for one validated ChatML sample."""
    return {
        "text": text, "bucket": bucket_name,
        "source": s["source"], "subset": s["subset"],
        "quality_score": s["quality_score"],
        "tokenized_length": s["tokenized_length"],
    }


for bucket_name, samples in composed.items():
    for s in samples:
        # vazhi_packs Q&A that merely recite a Kural verbatim are dropped
        # (silently; they are not counted as failures).
        if s["source"] == "vazhi_packs" and is_verbatim_kural_qa(s["instruction"], s["output"]):
            continue
        text = to_chatml(s["instruction"], s["output"])
        valid, reason = validate_chatml_strict(text)
        if valid:
            all_samples.append(_sft_record(bucket_name, s, text))
        else:
            chatml_failures += 1

random.shuffle(all_samples)
print(f"\u2705 ChatML: {len(all_samples):,} valid, {chatml_failures} failures")

bucket_counts = Counter(s["bucket"] for s in all_samples)
print(f"\n\U0001f4ca Bucket distribution:")
for bucket, count in sorted(bucket_counts.items()):
    pct = 100 * count / len(all_samples)
    print(f"  {bucket}: {count:,} ({pct:.1f}%)")

In [None]:
# Stage 3D: Stratified Train/Eval Split (90/10)

EVAL_RATIO = 0.10
train_samples = []
eval_samples = []

by_bucket = {}
for s in all_samples:
    by_bucket.setdefault(s["bucket"], []).append(s)

# Split each bucket separately so train and eval keep roughly the same bucket
# mix; every bucket contributes at least one eval sample.
for bucket, samples in by_bucket.items():
    random.shuffle(samples)
    n_eval = max(1, int(len(samples) * EVAL_RATIO))
    eval_samples += samples[:n_eval]
    train_samples += samples[n_eval:]

random.shuffle(train_samples)
random.shuffle(eval_samples)

_n_split_total = len(train_samples) + len(eval_samples)
print(f"\U0001f4ca Split: Train={len(train_samples):,} Eval={len(eval_samples):,}")
print(f"  Eval ratio: {len(eval_samples) / _n_split_total:.1%}")

max_tok = max(s["tokenized_length"] for s in all_samples)
print(f"  Max tokens: {max_tok} (limit: {SFT_MAX_SEQ_LENGTH})")
assert max_tok <= SFT_MAX_SEQ_LENGTH

In [None]:
# Stage 3E: Upload to HuggingFace + Summary

train_ds = Dataset.from_list(train_samples)
eval_ds = Dataset.from_list(eval_samples)
dataset_dict = DatasetDict({"train": train_ds, "validation": eval_ds})

api.create_repo(OUTPUT_DATASET, repo_type="dataset", exist_ok=True)
dataset_dict.push_to_hub(OUTPUT_DATASET)

print(f"\n\u2705 Uploaded: https://huggingface.co/datasets/{OUTPUT_DATASET}")
print(f"   Train: {len(train_ds):,} | Eval: {len(eval_ds):,}")

_banner = "=" * 60
print(f"\n{_banner}")
print(f"VAZHI Dataset Factory v{VERSION} \u2014 COMPLETE")
print(f"{_banner}")
print(f"\n  Stage 1 (Retrieve): {total_raw:,} raw \u2192 {RAW_DATASET}")
print(f"  Stage 2 (Curate):   {len(curated_records):,} curated \u2192 {CURATED_DATASET}")
print(f"  Stage 3 (Compose):  {len(all_samples):,} final \u2192 {OUTPUT_DATASET}")

print(f"\n  Buckets:")
for bucket, count in sorted(bucket_counts.items()):
    target = BUCKET_TARGETS[bucket]
    status = "\u2705" if count >= target["min"] else "\u26a0\ufe0f"
    print(f"    {status} {bucket}: {count:,} (target: {target['target']})")

# Print up to two example Q/A pairs per bucket, stopping once every bucket has two.
print(f"\n  Sample outputs (2 per bucket):")
shown = Counter()
for s in all_samples:
    bucket = s['bucket']
    if shown[bucket] < 2:
        shown[bucket] += 1
        print(f"\n  [{bucket.upper()}] source={s['source']} quality={s['quality_score']:.3f}")
        match = CHATML_PATTERN.search(s["text"])
        if match:
            print(f"    Q: {match.group(1)[:100]}")
            print(f"    A: {match.group(2)[:150]}")
    if all(shown[b] >= 2 for b in BUCKET_TARGETS):
        break

print(f"\n\u2705 Done! Next: SFT training with LoRA (r=8, q_proj+v_proj, 2 epochs)")
print(f"   Base model: {DAPT_MODEL}")