# VAZHI Dataset Factory v4.1 — 3-Stage Data Pipeline

**Pipeline:** Retrieve → Curate → Compose

```
┌─ Stage 1: RETRIEVE (CPU, ~30 min) ──────────────────────────────┐
│ IndicAlign (9 subsets) + tamil-orca + GSM8K_TAMIL + local        │
│ → Upload raw to HF: CryptoYogi/vazhi-raw-tamil-qa-v1            │
├─ Stage 2: CURATE (CPU then GPU, ~3-5 hours) ───────────────────┤
│ Pass 1 (CPU): lang-id → tamil% → empties → MinHash dedup        │
│ Pass 2 (GPU): perplexity + IndicSBERT on candidate subset       │
│ → Upload curated to HF: CryptoYogi/vazhi-curated-tamil-qa-v1    │
├─ Stage 3: COMPOSE (CPU, ~5 min) ────────────────────────────────┤
│ Filter (quality + dedup + toxicity + token-length ≤ 2048)        │
│ → ChatML conversion → absolute count targets → stratified split  │
│ → Upload final SFT to HF: CryptoYogi/vazhi-tamil-sft-v4_1      │
└──────────────────────────────────────────────────────────────────┘
```

**Key changes from v4.0:**
- **Broad retrieval** — ALL available Tamil Q&A from every source (~520K+), no artificial caps
- **ML-based curation** — fasttext lang-id, MinHash dedup, perplexity scoring, semantic clustering
- **Absolute count targets** — replaces the percentage-based anchoring that caused cascading downsampling
- **max_seq_length=2048** — stops the 74% domain pack rejection caused by the previous 1024-token window
- **Safety routing** — Toxic_Matrix/HHRLHF_T refusal pairs routed to safety bucket, not filtered
- **HF checkpointing** — each stage uploads to HF before continuing

**Run on:** Kaggle P100 (GPU needed for perplexity + embeddings in Stage 2)

**Output:** `CryptoYogi/vazhi-tamil-sft-v4_1` with `train` and `validation` splits

## 1. Config & Dependencies

In [None]:
!pip install -q datasets huggingface_hub fasttext-wheel text-dedup sentence-transformers hdbscan

import json
import os
import re
import random
import hashlib
import subprocess
from collections import Counter
from pathlib import Path

import numpy as np
from datasets import load_dataset, Dataset, DatasetDict
from huggingface_hub import login, HfApi

# === CONFIG ===
# Pipeline version; echoed in the config banner printed by this cell.
VERSION = "4.1"
# Hugging Face dataset repo IDs, one per stage output (raw -> curated -> final SFT).
RAW_DATASET = "CryptoYogi/vazhi-raw-tamil-qa-v1"
CURATED_DATASET = "CryptoYogi/vazhi-curated-tamil-qa-v1"
OUTPUT_DATASET = "CryptoYogi/vazhi-tamil-sft-v4_1"
# Causal LM loaded in Stage 2 to compute perplexity scores.
DAPT_MODEL = "CryptoYogi/qwen3-0.6b-tamil-v1_1"
# NOTE(review): not referenced in the visible part of the notebook — presumably the
# tokenizer behind the Stage 3 token-length filter; confirm against the later cells.
TOKENIZER_MODEL = "Qwen/Qwen3-0.6B"
# Public repo that holds the local source data (sparse-cloned in Stage 1).
REPO_URL = "https://github.com/CryptoYogiLLC/vazhi.git"
# Seed both RNGs so sampling steps (spot-checks, stratified PPL sampling) are repeatable.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

SFT_MAX_SEQ_LENGTH = 2048  # Training window — stops domain pack rejection

# Default system turn used by to_chatml(). The escaped text is Tamil; in English, roughly:
# "You are VAZHI, an AI assistant for Tamil people. Answer clearly and helpfully in
# Tamil. If you do not know, say 'I don't know'."
SYSTEM_PROMPT = (
    "\u0ba8\u0bc0\u0b99\u0bcd\u0b95\u0bb3\u0bcd VAZHI (\u0bb5\u0bb4\u0bbf), \u0ba4\u0bae\u0bbf\u0bb4\u0bcd \u0bae\u0b95\u0bcd\u0b95\u0bb3\u0bc1\u0b95\u0bcd\u0b95\u0bbe\u0ba9 AI \u0b89\u0ba4\u0bb5\u0bbf\u0baf\u0bbe\u0bb3\u0bb0\u0bcd. "
    "\u0ba4\u0bae\u0bbf\u0bb4\u0bbf\u0bb2\u0bcd \u0ba4\u0bc6\u0bb3\u0bbf\u0bb5\u0bbe\u0b95\u0bb5\u0bc1\u0bae\u0bcd \u0b89\u0ba4\u0bb5\u0bbf\u0baf\u0bbe\u0b95\u0bb5\u0bc1\u0bae\u0bcd \u0baa\u0ba4\u0bbf\u0bb2\u0bb3\u0bbf\u0baf\u0bc1\u0b99\u0bcd\u0b95\u0bb3\u0bcd. "
    '\u0ba4\u0bc6\u0bb0\u0bbf\u0baf\u0bbe\u0bb5\u0bbf\u0b9f\u0bcd\u0b9f\u0bbe\u0bb2\u0bcd "\u0ba4\u0bc6\u0bb0\u0bbf\u0baf\u0bb5\u0bbf\u0bb2\u0bcd\u0bb2\u0bc8" \u0b8e\u0ba9\u0bcd\u0bb1\u0bc1 \u0b9a\u0bca\u0bb2\u0bcd\u0bb2\u0bc1\u0b99\u0bcd\u0b95\u0bb3\u0bcd.'
)

# Absolute count targets per source bucket (no percentage anchoring)
# Each bucket carries a min / target / max sample count; the sum of the "target"
# values is printed as the expected dataset size.
BUCKET_TARGETS = {
    "vazhi_packs":    {"min": 1500, "target": 2500, "max": 3000},
    "handcrafted":    {"min": 100,  "target": 147,  "max": 200},
    "general":        {"min": 250,  "target": 350,  "max": 500},
    "indicalign":     {"min": 2000, "target": 4000, "max": 6000},
    "tamil_orca":     {"min": 1000, "target": 2000, "max": 3000},
    "gsm8k_tamil":    {"min": 200,  "target": 400,  "max": 600},
    "safety":         {"min": 200,  "target": 500,  "max": 1000},
}

# Source priority for dedup (higher = keep)
# Sources missing from this map fall back to priority 1 at the lookup sites.
SOURCE_PRIORITY = {
    "vazhi_packs": 10, "handcrafted": 10,
    "general": 5,
    "indicalign": 2, "tamil_orca": 2, "gsm8k_tamil": 2,
}

# Echo the effective configuration so the run log records exactly what was used.
target_total = sum(cfg["target"] for cfg in BUCKET_TARGETS.values())
bucket_lines = [
    f"     {name}: {cfg['target']} (range {cfg['min']}-{cfg['max']})"
    for name, cfg in BUCKET_TARGETS.items()
]
config_report = [
    f"\u2705 Config loaded: Dataset Factory v{VERSION}",
    f"   Raw output: {RAW_DATASET}",
    f"   Curated output: {CURATED_DATASET}",
    f"   Final SFT output: {OUTPUT_DATASET}",
    f"   max_seq_length: {SFT_MAX_SEQ_LENGTH}",
    f"\n   Bucket targets (absolute counts):",
    *bucket_lines,
    f"   Expected total: ~{target_total}",
]
print("\n".join(config_report))

In [None]:
# HuggingFace login
# Tries three credential sources in order and stops at the first one that works:
#   1. the HF_TOKEN secret from Kaggle secrets,
#   2. the HF_TOKEN secret from Colab user data,
#   3. an interactive login() prompt.
# Any exception in a step (missing module, missing secret, failed login) moves on
# to the next step, so the cell runs on Kaggle, on Colab, or locally.
try:
    from kaggle_secrets import UserSecretsClient
    secrets = UserSecretsClient()
    hf_token = secrets.get_secret("HF_TOKEN")
    login(token=hf_token)
    print("\u2705 Logged in via Kaggle secrets")
except Exception:
    # Not on Kaggle (or the secret/login failed): fall back to Colab.
    try:
        from google.colab import userdata
        hf_token = userdata.get('HF_TOKEN')
        login(token=hf_token)
        print("\u2705 Logged in via Colab secrets")
    except Exception:
        # Last resort: prompt for a token. hf_token stays undefined on this path.
        login()
        print("\u2705 Logged in interactively")

---
# Stage 1: RETRIEVE

Pull ALL available Tamil Q&A from every open-licensed source. Minimal filtering (only reject empty/null). Store with source metadata for downstream curation.

**Output:** `CryptoYogi/vazhi-raw-tamil-qa-v1` on HuggingFace

In [None]:
# Fetch the local source data: sparse-clone data/ and vazhi-packs/ on the first run,
# otherwise refresh the existing checkout with a pull.
REPO_DIR = Path("/kaggle/working/vazhi")
if REPO_DIR.exists():
    subprocess.run(["git", "pull"], cwd=str(REPO_DIR), check=True)
    print(f"\u2705 Repo already at {REPO_DIR}, pulled latest")
else:
    print(f"Cloning {REPO_URL} (sparse: data/ + vazhi-packs/)...")
    clone_cmd = [
        "git", "clone", "--depth", "1", "--filter=blob:none", "--sparse",
        REPO_URL, str(REPO_DIR),
    ]
    sparse_cmd = ["git", "sparse-checkout", "set", "data/", "vazhi-packs/"]
    subprocess.run(clone_cmd, check=True)
    subprocess.run(sparse_cmd, cwd=str(REPO_DIR), check=True)
    print(f"\u2705 Cloned to {REPO_DIR}")

# Directory layout inside the checkout.
SOURCES_DIR = REPO_DIR / "data" / "sources"
LEGACY_DIR = REPO_DIR / "data" / "LEGACY"
PACKS_DIR = SOURCES_DIR / "sft" / "vazhi-packs"
HANDCRAFTED_DIR = SOURCES_DIR / "sft" / "handcrafted"

# Fail fast if any directory the Stage 1 loaders read from is missing.
required_dirs = {
    "sources/sft/vazhi-packs": PACKS_DIR,
    "sources/sft/handcrafted": HANDCRAFTED_DIR,
    "LEGACY": LEGACY_DIR,
}
for label, d in required_dirs.items():
    assert d.exists(), f"Missing: {d}"
    print(f"  \u2705 {label}: {d}")
print(f"\n\u2705 All source data available")

In [None]:
# Helper functions (reused from v4.0 + new additions)

def to_chatml(instruction, output, system_prompt=None):
    """Render one instruction/output pair as a three-turn ChatML string.

    The turns are system, user and assistant, each wrapped in
    ``<|im_start|>role`` / ``<|im_end|>`` and separated by a newline.
    A missing or empty ``system_prompt`` falls back to ``SYSTEM_PROMPT``.
    """
    turns = (
        ("system", system_prompt or SYSTEM_PROMPT),
        ("user", instruction),
        ("assistant", output),
    )
    return "\n".join(
        f"<|im_start|>{role}\n{content}<|im_end|>" for role, content in turns
    )


# Matches one system/user/assistant ChatML exchange; group 1 is the user turn and
# group 2 the assistant turn. DOTALL lets turn bodies span multiple lines.
CHATML_PATTERN = re.compile(
    r'<\|im_start\|>system\n.+?<\|im_end\|>\n'
    r'<\|im_start\|>user\n(.+?)<\|im_end\|>\n'
    r'<\|im_start\|>assistant\n(.+?)<\|im_end\|>',
    re.DOTALL
)


def validate_chatml_strict(text):
    """Check that ``text`` contains a ChatML exchange with real user and assistant text.

    Returns ``(True, "ok")`` on success, otherwise ``(False, reason)``. A turn
    counts as empty when its stripped content is shorter than 2 characters.
    """
    found = CHATML_PATTERN.search(text)
    if found is None:
        return False, "no ChatML structure found"
    # User is checked before assistant, so a doubly-empty sample reports the user turn.
    for role, body in zip(("user", "assistant"), found.groups()):
        if len(body.strip()) < 2:
            return False, f"empty {role} content"
    return True, "ok"


def count_tamil_chars(text):
    """Return how many characters of ``text`` lie in the Tamil block U+0B80..U+0BFF."""
    return len([ch for ch in text if 0x0B80 <= ord(ch) <= 0x0BFF])


def tamil_char_pct(text):
    """Return the fraction (0.0-1.0) of characters in ``text`` that are Tamil.

    Empty or missing text yields 0.0. The denominator is the full length,
    so spaces, digits and punctuation count against the ratio.
    """
    return count_tamil_chars(text) / len(text) if text else 0.0


def is_verbatim_kural_qa(question, answer):
    """Return True when a Q&A pair looks like a request for exact verse text.

    Two signals are used:
    * the question matches a "recite/write kural N" style pattern (case-insensitive);
    * the answer is shorter than 200 characters, contains a newline, and has none
      of the explanation words listed below.
    """
    question_patterns = (
        r'\u0b95\u0bc1\u0bb1\u0bb3\u0bcd\s*\d+\s*(\u0b8e\u0ba9\u0bcd\u0ba9|\u0b9a\u0bca\u0bb2\u0bcd\u0bb2\u0bc1|\u0b9a\u0bca\u0bb2\u0bcd\u0bb2\u0bc1\u0b99\u0bcd\u0b95\u0bb3\u0bcd|\u0b8e\u0bb4\u0bc1\u0ba4\u0bbf\s*\u0b95\u0bbe\u0b9f\u0bcd\u0b9f\u0bc1|\u0b95\u0bc2\u0bb1\u0bc1\u0b95)',
        r'(first|\u0bae\u0bc1\u0ba4\u0bb2\u0bcd)\s*\u0b95\u0bc1\u0bb1\u0bb3\u0bcd\s*(\u0b8e\u0ba9\u0bcd\u0ba9|\u0b9a\u0bca\u0bb2\u0bcd\u0bb2\u0bc1|\u0b9a\u0bca\u0bb2\u0bcd\u0bb2\u0bc1\u0b99\u0bcd\u0b95\u0bb3\u0bcd)',
        r'\u0ba4\u0bbf\u0bb0\u0bc1\u0b95\u0bcd\u0b95\u0bc1\u0bb1\u0bb3\u0bbf\u0ba9\u0bcd\s+\u0bae\u0bc1\u0ba4\u0bb2\u0bcd\s+\u0b95\u0bc1\u0bb1\u0bb3\u0bcd',
        r'\u0b95\u0bc1\u0bb1\u0bb3\u0bcd\s*[\d]+(?:\s*\u0b90)?\s*\u0b8e\u0bb4\u0bc1\u0ba4\u0bbf',
    )
    if any(re.search(pattern, question, re.IGNORECASE) for pattern in question_patterns):
        return True

    # Tamil words for "explanation" / "meaning"; their presence marks an explained answer.
    explanation_markers = (
        "\u0bb5\u0bbf\u0bb3\u0b95\u0bcd\u0b95\u0bae\u0bcd",
        "\u0baa\u0bca\u0bb0\u0bc1\u0bb3\u0bcd",
        "\u0b85\u0bb0\u0bcd\u0ba4\u0bcd\u0ba4\u0bae\u0bcd",
    )
    return (
        len(answer) < 200
        and "\n" in answer
        and all(marker not in answer for marker in explanation_markers)
    )


# Self-test: push one synthetic pair through the ChatML builder and validator, and
# check that the Tamil ratio of a mixed string stays inside [0, 1].
good = to_chatml("test question", "test answer")
valid, reason = validate_chatml_strict(good)
assert valid, f"Self-test failed: {reason}"
mixed_ratio = tamil_char_pct("\u0ba4\u0bae\u0bbf\u0bb4\u0bcd test")
assert 0.0 <= mixed_ratio <= 1.0
print("\u2705 Helper functions defined and self-tested")

In [None]:
# Stage 1A: Load IndicAlign — ALL 9 Tamil subsets (no cap)
# Stream each completely, extract tam_Taml pairs, tag with source metadata.

INDICALIGN_SUBSETS = [
    "Dolly_T", "WikiHow", "Wiki_Conv", "OpenAssistant_T",
    "Anudesh", "HHRLHF_T", "Indic_ShareLlama", "Wiki_Chat", "Toxic_Matrix",
]

raw_samples = []  # Global accumulator for all Stage 1 sources


def extract_tamil_pairs(item):
    """Yield cleaned ``(instruction, output)`` tuples from one IndicAlign record.

    Reads the record's ``tam_Taml`` field, which is expected to be a list of
    ``[instruction, output, ...]`` sequences. Malformed entries (wrong type,
    fewer than two elements) and pairs where either side is empty or null
    after stripping are skipped.
    """
    pairs = item.get("tam_Taml", [])
    if not pairs or not isinstance(pairs, (list, tuple)):
        return
    for pair in pairs:
        if not isinstance(pair, (list, tuple)) or len(pair) < 2:
            continue
        instruction = str(pair[0]).strip() if pair[0] else ""
        output = str(pair[1]).strip() if pair[1] else ""
        # Minimal filter: reject empty/null only
        if instruction and output:
            yield instruction, output


print("Loading IndicAlign subsets (streaming all)...")
for subset in INDICALIGN_SUBSETS:
    subset_count = 0
    try:
        ds = load_dataset(
            "ai4bharat/indic-align",
            subset,
            split="train",
            streaming=True,
        )
        # Rows are appended as they are read. An intermediate batch list would
        # not save memory (everything ends up in raw_samples anyway) and would
        # lose its unflushed tail if the stream failed part-way through.
        for item in ds:
            for instruction, output in extract_tamil_pairs(item):
                raw_samples.append({
                    "instruction": instruction,
                    "output": output,
                    "source": "indicalign",
                    "subset": subset,
                })
                subset_count += 1
        print(f"  {subset}: {subset_count:,} pairs")
    except Exception as e:
        # Best-effort: keep the rows already read, but state how many there are
        # so a partially loaded subset is visible in the log.
        print(f"  {subset}: FAILED after {subset_count:,} pairs - {e}")

indicalign_total = sum(1 for s in raw_samples if s["source"] == "indicalign")
print(f"\n\u2705 IndicAlign total: {indicalign_total:,} pairs")

In [None]:
# Stage 1B: Load tamil-orca (~97K samples)
# Schema: Instruction, Query, Answer — combine Instruction+Query into instruction field.


def field_text(value):
    """Return ``value`` as a stripped string, mapping a null field to ``""``.

    A plain ``str(value)`` would turn ``None`` into the literal text "None",
    which then passes the emptiness checks and ends up in the dataset.
    """
    return "" if value is None else str(value).strip()


print("Loading tamil-orca (streaming)...")
orca_count = 0
try:
    ds = load_dataset("azharmo/tamil-orca", split="train", streaming=True)

    for item in ds:
        instruction_part = field_text(item.get("Instruction"))
        query_part = field_text(item.get("Query"))
        answer = field_text(item.get("Answer"))

        # Combine Instruction + Query for the instruction field; skip the row
        # when both are empty or when there is no answer.
        instruction = "\n".join(part for part in (instruction_part, query_part) if part)
        if not instruction or not answer:
            continue

        # Append directly so a mid-stream failure cannot drop buffered rows.
        raw_samples.append({
            "instruction": instruction,
            "output": answer,
            "source": "tamil_orca",
            "subset": "",
        })
        orca_count += 1

    print(f"\u2705 tamil-orca: {orca_count:,} samples")
except Exception as e:
    # Best-effort: rows read before the failure are kept; report how many.
    print(f"\u274c tamil-orca FAILED after {orca_count:,} samples: {e}")

In [None]:
# Stage 1C: Load GSM8K_TAMIL (~8.8K math samples)
# Schema: question, answer


def field_text(value):
    """Return ``value`` as a stripped string, mapping a null field to ``""``.

    A plain ``str(value)`` would turn ``None`` into the literal text "None",
    which then passes the emptiness checks and ends up in the dataset.
    """
    return "" if value is None else str(value).strip()


print("Loading GSM8K_TAMIL (streaming)...")
gsm_count = 0
try:
    ds = load_dataset("Vishal0407/GSM8K_TAMIL", split="train", streaming=True)

    for item in ds:
        question = field_text(item.get("question"))
        answer = field_text(item.get("answer"))

        # Minimal filter: both sides must be non-empty.
        if not question or not answer:
            continue

        raw_samples.append({
            "instruction": question,
            "output": answer,
            "source": "gsm8k_tamil",
            "subset": "",
        })
        gsm_count += 1

    print(f"\u2705 GSM8K_TAMIL: {gsm_count:,} samples")
except Exception as e:
    # Best-effort: rows read before the failure are kept; report how many.
    print(f"\u274c GSM8K_TAMIL FAILED after {gsm_count:,} samples: {e}")

In [None]:
# Stage 1D: Load local sources (vazhi-packs, handcrafted, general from LEGACY)

# --- vazhi-packs ---
# Every *.json file in PACKS_DIR holds a list of {"instruction", "output"} records;
# the file stem is stored as the sample's subset.
print("Loading vazhi-packs...")
packs_count = 0
for pack_file in sorted(PACKS_DIR.glob("*.json")):
    pairs = json.loads(pack_file.read_text(encoding="utf-8"))
    kept = []
    for pair in pairs:
        instruction = pair.get("instruction", "").strip()
        output = pair.get("output", "").strip()
        if instruction and output:
            kept.append({
                "instruction": instruction,
                "output": output,
                "source": "vazhi_packs",
                "subset": pack_file.stem,
            })
    raw_samples.extend(kept)
    file_count = len(kept)
    packs_count += file_count
    print(f"  {pack_file.stem}: {file_count}")
print(f"\u2705 vazhi-packs total: {packs_count:,}")

# --- handcrafted ---
# Same record layout as the packs, read from HANDCRAFTED_DIR and tagged "handcrafted".
print("\nLoading handcrafted...")
hc_count = 0
for hc_file in sorted(HANDCRAFTED_DIR.glob("*.json")):
    items = json.loads(hc_file.read_text(encoding="utf-8"))
    kept = []
    for item in items:
        instruction = item.get("instruction", "").strip()
        output = item.get("output", "").strip()
        if instruction and output:
            kept.append({
                "instruction": instruction,
                "output": output,
                "source": "handcrafted",
                "subset": hc_file.stem,
            })
    raw_samples.extend(kept)
    file_count = len(kept)
    hc_count += file_count
    print(f"  {hc_file.stem}: {file_count}")
print(f"\u2705 handcrafted total: {hc_count:,}")

# --- general (LEGACY) ---
# A fixed allow-list of LEGACY files. Records may use either the
# instruction/output or the question/answer key pair.
print("\nLoading general (LEGACY)...")
GENERAL_FILES = [
    "06_health.json", "09_weather.json", "10_shopping.json",
    "12_daily_routines.json", "13_emotions.json",
    "14_chennai_dialect.json", "15_madurai_dialect.json",
    "16_kongu_dialect.json", "31_malaysia_dialect.json",
    "03_numbers_time.json",
]
gen_count = 0
for fname in GENERAL_FILES:
    fpath = LEGACY_DIR / fname
    if not fpath.exists():
        print(f"  {fname}: NOT FOUND, skipping")
        continue
    items = json.loads(fpath.read_text(encoding="utf-8"))
    subset_name = fname.replace(".json", "")
    kept = []
    for item in items:
        instruction = item.get("instruction", item.get("question", "")).strip()
        output = item.get("output", item.get("answer", "")).strip()
        if instruction and output:
            kept.append({
                "instruction": instruction,
                "output": output,
                "source": "general",
                "subset": subset_name,
            })
    raw_samples.extend(kept)
    file_count = len(kept)
    gen_count += file_count
    print(f"  {fname}: {file_count}")
print(f"\u2705 general total: {gen_count:,}")

In [None]:
# Stage 1E: Schema normalization — attach char_length and tamil_pct to every sample.
# Both are computed over instruction + output concatenated without a separator.

print(f"Normalizing {len(raw_samples):,} samples...")
for sample in raw_samples:
    text = "".join((sample["instruction"], sample["output"]))
    sample.update(
        char_length=len(text),
        tamil_pct=round(tamil_char_pct(text), 4),
    )

print(f"\u2705 Schema normalization complete")
print(f"   Fields: {list(raw_samples[0].keys())}")

In [None]:
# Stage 1F: Stats, spot-check, and upload raw dataset to HF

rule = "=" * 60
print(rule)
print(f"STAGE 1 COMPLETE: {len(raw_samples):,} raw samples")
print(rule)

# Per-source counts
source_counts = Counter(s["source"] for s in raw_samples)
print("\nPer-source counts:")
for src, count in source_counts.most_common():
    print(f"  {src}: {count:,}")

# Per-subset counts for IndicAlign
subset_counts = Counter(s["subset"] for s in raw_samples if s["source"] == "indicalign")
print("\nIndicAlign subset breakdown:")
for sub, count in subset_counts.most_common():
    print(f"  {sub}: {count:,}")

# Length distribution (sort once; min/median/max are read off the sorted copy)
lengths = [s["char_length"] for s in raw_samples]
sorted_lengths = sorted(lengths)
print(f"\nLength stats: min={sorted_lengths[0]}, median={sorted_lengths[len(lengths)//2]}, max={sorted_lengths[-1]}, mean={sum(lengths)/len(lengths):.0f}")

# Tamil % distribution
tamil_pcts = [s["tamil_pct"] for s in raw_samples]
sorted_pcts = sorted(tamil_pcts)
print(f"Tamil% stats: min={sorted_pcts[0]:.2f}, median={sorted_pcts[len(tamil_pcts)//2]:.2f}, mean={sum(tamil_pcts)/len(tamil_pcts):.2f}")

# Spot-check: 3 random samples from each source
print("\n" + rule)
print("Spot-check (3 per source):")
print(rule)
samples_by_source = {}
for s in raw_samples:
    samples_by_source.setdefault(s["source"], []).append(s)
for src in sorted(samples_by_source):
    src_samples = samples_by_source[src]
    picks = random.sample(src_samples, min(3, len(src_samples)))
    print(f"\n[{src.upper()}]")
    for p in picks:
        print(f"  Q: {p['instruction'][:80]}...")
        print(f"  A: {p['output'][:80]}...")
        print(f"  tamil%={p['tamil_pct']:.2f} len={p['char_length']}")

# Upload to HuggingFace (creates the dataset repo if it does not exist yet)
print(f"\nUploading {len(raw_samples):,} raw samples to {RAW_DATASET}...")
raw_ds = Dataset.from_list(raw_samples)
api = HfApi()
api.create_repo(RAW_DATASET, repo_type="dataset", exist_ok=True)
raw_ds.push_to_hub(RAW_DATASET)
print(f"\u2705 Raw dataset uploaded: https://huggingface.co/datasets/{RAW_DATASET}")

---
# Stage 2: CURATE

Two-pass curation: cheap CPU filters first to reduce the pool, then GPU-based scoring on the surviving candidates.

**Input:** In-memory raw dataset (also backed up on HF from Stage 1)  
**Output:** `CryptoYogi/vazhi-curated-tamil-qa-v1` on HuggingFace

### Pass 1 (CPU): Language detection → heuristics → dedup → toxicity  
### Pass 2 (GPU): Perplexity scoring → semantic categorization → composite quality score

In [None]:
# Stage 2 — Pass 1A: Language Detection with fasttext lid.176.bin

import fasttext
import urllib.request

# Download the fasttext language-ID model once and cache it in the working directory.
LID_MODEL_PATH = "/kaggle/working/lid.176.bin"
if not os.path.exists(LID_MODEL_PATH):
    print("Downloading fasttext language ID model (126MB)...")
    # Download to a temporary name and rename only after the transfer finished.
    # Otherwise an interrupted download leaves a truncated file at the final
    # path, and every later run skips the download and fails in load_model().
    partial_path = LID_MODEL_PATH + ".part"
    urllib.request.urlretrieve(
        "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin",
        partial_path,
    )
    os.replace(partial_path, LID_MODEL_PATH)
lid_model = fasttext.load_model(LID_MODEL_PATH)
print("\u2705 fasttext lid.176.bin loaded")


def detect_lang(text):
    """Return ``(lang_code, confidence)`` for ``text`` using the fasttext LID model.

    Newlines are replaced with spaces before prediction. Text that is empty
    after cleaning yields ``("unknown", 0.0)`` without calling the model.
    """
    flattened = text.replace("\n", " ").strip()
    if flattened == "":
        return "unknown", 0.0
    top_labels, top_scores = lid_model.predict(flattened, k=1)
    # fasttext labels look like "__label__ta"; keep only the language code.
    code = top_labels[0].replace("__label__", "")
    return code, float(top_scores[0])


# Run language detection on all samples: tag each one with lang_id / lang_confidence,
# then keep only those detected as Tamil ("ta") with confidence >= 0.6.
before = len(raw_samples)
print(f"Running language detection on {before:,} samples...")
for position, sample in enumerate(raw_samples):
    detected, confidence = detect_lang(sample["instruction"] + " " + sample["output"])
    sample["lang_id"] = detected
    sample["lang_confidence"] = round(confidence, 4)
    if position > 0 and position % 50000 == 0:
        print(f"  ...{position:,} / {len(raw_samples):,}")

# Partition in one pass: confident Tamil goes to `candidates`, the rest to `non_tamil`.
candidates, non_tamil = [], []
for sample in raw_samples:
    is_confident_tamil = sample["lang_id"] == "ta" and sample["lang_confidence"] >= 0.6
    (candidates if is_confident_tamil else non_tamil).append(sample)
lang_dropped = before - len(candidates)

print(f"\n\u2705 Language detection complete")
print(f"   Before: {before:,}")
print(f"   Dropped (non-Tamil or low confidence): {lang_dropped:,}")
print(f"   Remaining: {len(candidates):,}")

# Show lang distribution of dropped samples
dropped_langs = Counter(sample["lang_id"] for sample in non_tamil)
print(f"   Dropped language distribution: {dropped_langs.most_common(10)}")

In [None]:
# Stage 2 — Pass 1B: Quality Heuristics + Format Sanity

def repetition_score(text):
    """Return the share of whitespace tokens taken by the most frequent token.

    Values near 1.0 indicate repetitive text. Texts with fewer than 5 tokens
    score 0.0 because the ratio is not meaningful for them.
    """
    tokens = text.split()
    total = len(tokens)
    if total < 5:
        return 0.0
    _, top_count = Counter(tokens).most_common(1)[0]
    return top_count / total


def is_sane(instruction, output):
    """Return False when the output shows signs of template corruption.

    Rejected: an output equal to the instruction (ignoring surrounding
    whitespace), leaked ``<think>`` tags, embedded ``data:image`` payloads,
    a leaked ChatML system header, or a doubled "systemsystem" token
    (matched case-insensitively).
    """
    if output.strip() == instruction.strip():
        return False
    leaked_markers = ("<think>", "data:image", "<|im_start|>system")
    if any(marker in output for marker in leaked_markers):
        return False
    return "systemsystem" not in output.lower()


def is_echo(instruction, output):
    """Return True when more than 80% of the output's distinct tokens also occur
    in the instruction, meaning the answer mostly repeats the question.

    Empty inputs, or an output without any tokens, are never echoes.
    """
    if not (instruction and output):
        return False
    answer_tokens = set(output.split())
    if not answer_tokens:
        return False
    shared = answer_tokens.intersection(instruction.split())
    return len(shared) / len(answer_tokens) > 0.80


# Apply every heuristic to each candidate and record which ones fired; a sample
# survives only when no flag is raised.
print(f"Running quality heuristics on {len(candidates):,} candidates...")
heuristic_stats = Counter()
clean_candidates = []

for s in candidates:
    instr_text, out_text = s["instruction"], s["output"]

    rep = repetition_score(out_text)
    s["repetition"] = round(rep, 4)

    # (flag name, fired?) in the order the flags are reported.
    checks = (
        ("low_tamil_pct", s["tamil_pct"] < 0.30),
        # Repetition only counts for outputs longer than 20 tokens.
        ("high_repetition", rep > 0.25 and len(out_text.split()) > 20),
        ("echo", is_echo(instr_text, out_text)),
        ("format_insane", not is_sane(instr_text, out_text)),
        ("trivial_output", len(out_text.strip()) < 10),
        ("trivial_instruction", len(instr_text.strip()) < 5),
    )
    flags = [name for name, fired in checks if fired]
    s["heuristic_flags"] = flags
    heuristic_stats.update(flags)

    if not flags:
        clean_candidates.append(s)

heuristic_dropped = len(candidates) - len(clean_candidates)
print(f"\n\u2705 Heuristic filtering complete")
print(f"   Before: {len(candidates):,}")
print(f"   Dropped: {heuristic_dropped:,}")
print(f"   Remaining: {len(clean_candidates):,}")
print(f"   Flag distribution:")
for flag, count in heuristic_stats.most_common():
    print(f"     {flag}: {count:,}")

In [None]:
# Stage 2 — Pass 1C: Deduplication with Source Priority
# Phase 1: exact dedupe on instruction[:200] (fast hash-based)
# Phase 2: near-duplicate detection via MinHash (if time permits)

print(f"Deduplicating {len(clean_candidates):,} candidates...")


def _dedup_rank(sample):
    """Sort key for choosing between duplicates: source priority, then Tamil ratio."""
    return SOURCE_PRIORITY.get(sample["source"], 1), sample["tamil_pct"]


# Phase 1: samples sharing the same lower-cased first 200 instruction characters
# collapse to one. A later sample replaces the kept one only if it ranks strictly higher.
seen = {}
for s in clean_candidates:
    key = s["instruction"][:200].strip().lower()
    incumbent = seen.get(key)
    if incumbent is None or _dedup_rank(s) > _dedup_rank(incumbent):
        seen[key] = s

exact_deduped = list(seen.values())
exact_dupes = len(clean_candidates) - len(exact_deduped)
print(f"  Phase 1 (exact): {len(clean_candidates):,} \u2192 {len(exact_deduped):,} ({exact_dupes:,} duplicates removed)")

# Phase 2: MinHash near-duplicate detection (Jaccard >= 0.8)
# Uses MD5 hash of character 3-grams for efficiency
def get_char_ngrams(text, n=3):
    """Return the set of character n-grams of ``text`` (lower-cased, stripped).

    Non-empty text shorter than ``n`` is returned as a single shingle. Without
    that, all short texts would map to the empty set, get identical MinHash
    signatures, and be treated as duplicates of each other. Empty text still
    yields an empty set.
    """
    text = text.lower().strip()
    if not text:
        return set()
    if len(text) < n:
        return {text}
    return {text[i:i + n] for i in range(len(text) - n + 1)}

def minhash_signature(ngrams, num_hashes=128):
    """Return a MinHash signature (list of ``num_hashes`` values) for a set of n-grams.

    Hash function ``i`` is MD5 of the string ``"{i}_{ngram}"`` reduced modulo 2**32;
    each signature slot holds the minimum of that hash over all n-grams.
    An empty set yields a signature filled with ``float('inf')``.
    """
    if not ngrams:
        return [float('inf')] * num_hashes

    def seeded_hash(seed, gram):
        digest = hashlib.md5(f"{seed}_{gram}".encode()).digest()
        return int.from_bytes(digest, "big") % (2**32)

    return [
        min(seeded_hash(seed, gram) for gram in ngrams)
        for seed in range(num_hashes)
    ]

def jaccard_from_sigs(sig1, sig2):
    """Estimate Jaccard similarity as the fraction of positions where the two
    MinHash signatures agree. The denominator is ``len(sig1)``.
    """
    matches = 0
    for left, right in zip(sig1, sig2):
        if left == right:
            matches += 1
    return matches / len(sig1)

# Only run MinHash if pool is manageable (<= 300K)
if len(exact_deduped) <= 300000:
    print(f"  Phase 2 (MinHash): computing signatures for {len(exact_deduped):,} samples...")
    sigs = []
    for i, s in enumerate(exact_deduped):
        ngrams = get_char_ngrams(s["instruction"])
        sigs.append(minhash_signature(ngrams))
        if i % 50000 == 0 and i > 0:
            print(f"    ...{i:,} / {len(exact_deduped):,} signatures computed")

    # LSH-style banding for efficient near-duplicate detection.
    # 32 bands of 4 rows make two samples candidates from an estimated Jaccard of
    # roughly (1/32)^(1/4) ~= 0.42 upward. The 0.8 cut-off is enforced by the
    # full-signature comparison below, not by the banding itself.
    BAND_SIZE = 4
    NUM_BANDS = 128 // BAND_SIZE
    near_dupes = set()  # indices into exact_deduped that will be dropped

    for band_idx in range(NUM_BANDS):
        # band slice -> index of the current surviving representative of that bucket
        buckets = {}
        start = band_idx * BAND_SIZE
        end = start + BAND_SIZE
        for i, sig in enumerate(sigs):
            # A sample removed earlier must not become a bucket representative.
            # It would shadow every later member of the bucket, because
            # comparisons against a removed sample are skipped.
            if i in near_dupes:
                continue
            band_hash = tuple(sig[start:end])
            j = buckets.get(band_hash)
            if j is None:
                buckets[band_hash] = i
                continue
            # Found a candidate pair — verify with full Jaccard
            if jaccard_from_sigs(sigs[i], sigs[j]) < 0.8:
                continue
            # Keep higher-priority source; on a priority tie keep the higher
            # Tamil ratio, and on a full tie keep the newcomer i.
            si, sj = exact_deduped[i], exact_deduped[j]
            pi = SOURCE_PRIORITY.get(si["source"], 1)
            pj = SOURCE_PRIORITY.get(sj["source"], 1)
            if pi < pj or (pi == pj and si["tamil_pct"] < sj["tamil_pct"]):
                near_dupes.add(i)
            else:
                near_dupes.add(j)
                # The representative was just removed: hand the bucket to the
                # survivor so later members are still compared with a live sample.
                buckets[band_hash] = i

    deduped = [s for i, s in enumerate(exact_deduped) if i not in near_dupes]
    print(f"  Phase 2 (MinHash): {len(exact_deduped):,} \u2192 {len(deduped):,} ({len(near_dupes):,} near-duplicates removed)")
else:
    deduped = exact_deduped
    print(f"  Phase 2 (MinHash): skipped (pool > 300K, exact dedupe sufficient)")

# Mark all with is_duplicate = False (survivors)
for s in deduped:
    s["is_duplicate"] = False

total_dedup_removed = len(clean_candidates) - len(deduped)
print(f"\n\u2705 Deduplication complete: {len(clean_candidates):,} \u2192 {len(deduped):,} ({total_dedup_removed:,} total removed)")

In [None]:
# Stage 2 — Pass 1D: Toxicity Screening with Source-Aware Logic
# For Toxic_Matrix/HHRLHF_T: toxic prompt + safe output = safety training data
# For all others: any toxicity = flag for exclusion

# Tamil toxicity wordlist (common slurs and NSFW terms)
# Minimal list — expand as needed
# NOTE: entries are matched as raw substrings of the text, so short stems
# (e.g. the "kill" stem) also hit longer words that merely contain them.
TOXIC_WORDS_TA = [
    "\u0ba4\u0bc7\u0bb5\u0b9f\u0bbf\u0baf\u0bbe",  # derogatory
    "\u0bb5\u0bc7\u0b9a\u0bbf",  # slur
    "\u0baa\u0bc1\u0ba3\u0bcd\u0b9f\u0bc8",  # vulgar
    "\u0ba4\u0bb1\u0bcd\u0b95\u0bca\u0bb2\u0bc8",  # suicide
    "\u0b95\u0bca\u0bb2\u0bcd\u0bb2",  # kill
    "\u0bb5\u0bc6\u0b9f\u0bbf\u0baa\u0bca\u0bb0\u0bc1\u0bb3\u0bcd",  # explosive
]

# English phrases; matched against the lower-cased text.
TOXIC_WORDS_EN = [
    "suicide method", "how to kill", "make a bomb", "hack into",
    "child porn", "exploit children",
]


def check_wordlist(text):
    """Return a list of tags for every toxicity term found in ``text``.

    Tamil terms are matched as-is and tagged ``ta:<first 10 chars>``; English
    phrases are matched against the lower-cased text and tagged
    ``en:<first 15 chars>``. Tamil hits come first, each group in wordlist order.
    """
    lowered = text.lower()
    tamil_hits = [f"ta:{word[:10]}" for word in TOXIC_WORDS_TA if word in text]
    english_hits = [f"en:{phrase[:15]}" for phrase in TOXIC_WORDS_EN if phrase in lowered]
    return tamil_hits + english_hits


def classify_toxicity(instruction, output, source, subset):
    """Classify one sample and return ``(toxicity_flags, is_safety_sample)``.

    For the Toxic_Matrix / HHRLHF_T subsets, a flagged prompt with an unflagged
    response is treated as safety training data: only the prompt flags are
    returned, with ``True``. In every other case the prompt and response flags
    are concatenated and returned with ``False``. ``source`` is accepted but
    not used in the decision.
    """
    prompt_hits = check_wordlist(instruction)
    reply_hits = check_wordlist(output)

    from_safety_subset = subset in ("Toxic_Matrix", "HHRLHF_T")
    if from_safety_subset and prompt_hits and not reply_hits:
        return prompt_hits, True

    return prompt_hits + reply_hits, False


# Tag every deduplicated sample with its toxicity flags and safety routing, then
# tally the three outcomes (safety-routed / toxic / clean).
print(f"Running toxicity screening on {len(deduped):,} candidates...")

for s in deduped:
    flags, routed_to_safety = classify_toxicity(
        s["instruction"], s["output"], s["source"], s["subset"]
    )
    s.update(toxicity_flags=flags, is_safety_sample=routed_to_safety)

safety_count = sum(1 for s in deduped if s["is_safety_sample"])
# Toxic = flagged but not routed to the safety bucket.
toxic_dropped = sum(
    1 for s in deduped if s["toxicity_flags"] and not s["is_safety_sample"]
)

print(f"\u2705 Toxicity screening complete")
print(f"   Safety samples (toxic prompt + safe refusal): {safety_count:,}")
print(f"   Toxic (will be excluded in Stage 3): {toxic_dropped:,}")
print(f"   Clean: {len(deduped) - safety_count - toxic_dropped:,}")

In [None]:
# Stage 2 — Pass 1 Summary Stats
# Prints the filtering funnel (raw -> lang-id -> heuristics -> dedup) and the
# per-source breakdown of the survivors.

rule = "=" * 60
funnel_report = [
    rule,
    "PASS 1 SUMMARY (CPU Filters)",
    rule,
    f"  Raw input:          {len(raw_samples):,}",
    f"  After lang-id:      {len(candidates):,} (-{lang_dropped:,})",
    f"  After heuristics:   {len(clean_candidates):,} (-{heuristic_dropped:,})",
    f"  After dedup:        {len(deduped):,} (-{total_dedup_removed:,})",
    f"  Safety routed:      {safety_count:,}",
    f"  Toxic flagged:      {toxic_dropped:,}",
    f"  \u2192 Pass 1 output:   {len(deduped):,} candidates",
]
print("\n".join(funnel_report))

# Per-source breakdown of survivors
survivor_sources = Counter(s["source"] for s in deduped)
print(f"\nSurvivors by source:")
for src, count in survivor_sources.most_common():
    print(f"  {src}: {count:,}")

In [None]:
# Stage 2 — Pass 2E: Perplexity Scoring (GPU)
# Uses DAPT v1.1 model to score output quality.
# PPL is a fluency metric, NOT a quality metric — use as weak signal for garbage detection.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the scoring model in half precision and let device_map place it.
print(f"Loading DAPT v1.1 model for perplexity scoring: {DAPT_MODEL}")
ppl_tokenizer = AutoTokenizer.from_pretrained(DAPT_MODEL)
ppl_model = AutoModelForCausalLM.from_pretrained(
    DAPT_MODEL, torch_dtype=torch.float16, device_map="auto"
)
ppl_model.eval()
print(f"\u2705 PPL model loaded on {ppl_model.device}")

# Score everything when the pool fits the budget; otherwise draw a sample that is
# proportional per source, with a floor of 100 samples for each source.
MAX_PPL_SAMPLES = 200000
if len(deduped) <= MAX_PPL_SAMPLES:
    ppl_candidates = deduped
    ppl_candidate_set = set(id(s) for s in ppl_candidates)
    print(f"  Scoring all {len(ppl_candidates):,} candidates")
else:
    print(f"Pool ({len(deduped):,}) > {MAX_PPL_SAMPLES:,}, stratified sampling...")
    by_source = {}
    for s in deduped:
        if s["source"] not in by_source:
            by_source[s["source"]] = []
        by_source[s["source"]].append(s)
    ppl_candidates = []
    for src, samples in by_source.items():
        proportion = len(samples) / len(deduped)
        n = max(100, int(MAX_PPL_SAMPLES * proportion))
        # Small sources may hold fewer samples than their quota.
        quota = min(n, len(samples))
        ppl_candidates.extend(random.sample(samples, quota))
    ppl_candidate_set = set(id(s) for s in ppl_candidates)
    print(f"  Selected {len(ppl_candidates):,} for PPL scoring")


def compute_perplexity(text, max_length=512):
    """Return the perplexity of ``text`` under the loaded PPL model.

    The text is tokenized with ``ppl_tokenizer`` (truncated to ``max_length``
    tokens), passed through ``ppl_model`` with its own input ids as labels,
    and the exponential of the returned loss is reported.

    Args:
        text: String to score.
        max_length: Maximum number of tokens kept by the tokenizer.

    Returns:
        The perplexity as a Python float.

    Raises:
        ValueError: If ``text`` tokenizes to fewer than two tokens, in which
            case no next-token loss can be computed.
    """
    inputs = ppl_tokenizer(
        text, return_tensors="pt", truncation=True, max_length=max_length
    ).to(ppl_model.device)

    # A next-token loss needs at least one (context, target) pair. With zero or
    # one token the loss is undefined and would surface as an error or NaN
    # (assumes the model shifts labels internally, as Hugging Face causal LMs
    # do — TODO confirm for DAPT_MODEL). Fail explicitly instead.
    n_tokens = inputs["input_ids"].shape[-1]
    if n_tokens < 2:
        raise ValueError(
            f"need at least 2 tokens to compute perplexity, got {n_tokens}"
        )

    with torch.no_grad():
        out = ppl_model(**inputs, labels=inputs["input_ids"])

    # The model is loaded in float16; upcast before exponentiating so that a
    # large loss does not overflow if it comes back in half precision.
    return torch.exp(out.loss.float()).item()


import math

# Score every selected candidate. Scoring is best-effort: a sample that cannot
# be scored gets perplexity=None and the run continues, but failures are
# counted and the first error is reported instead of being swallowed silently.
print(f"\nComputing perplexity for {len(ppl_candidates):,} samples...")
ppl_failures = 0
first_ppl_error = None
for i, s in enumerate(ppl_candidates):
    try:
        ppl = compute_perplexity(s["output"])
    except Exception as exc:
        ppl = None
        ppl_failures += 1
        if first_ppl_error is None:
            first_ppl_error = repr(exc)

    # NaN is not a usable score: it breaks min/max/sorted in the stats below
    # and compares False against any threshold. Treat it as "unscored".
    if ppl is not None and math.isnan(ppl):
        ppl = None
        ppl_failures += 1
        if first_ppl_error is None:
            first_ppl_error = "perplexity was NaN"

    s["perplexity"] = None if ppl is None else round(ppl, 2)

    if (i + 1) % 10000 == 0:
        print(f"  ...{i+1:,} / {len(ppl_candidates):,}")

if ppl_failures:
    print(
        f"  \u26a0\ufe0f {ppl_failures:,} samples could not be scored "
        f"(perplexity=None); first error: {first_ppl_error}"
    )

# Mark unscored samples (those outside the scored subset) with an explicit None
for s in deduped:
    s.setdefault("perplexity", None)

# Summarise the scored perplexities (samples left at None are excluded).
scored = [s["perplexity"] for s in deduped if s["perplexity"] is not None]
if scored:
    ppl_median = sorted(scored)[len(scored) // 2]
    ppl_mean = sum(scored) / len(scored)
    n_below_50 = len([p for p in scored if p < 50])
    n_50_to_200 = len([p for p in scored if 50 <= p < 200])
    n_200_plus = len([p for p in scored if p >= 200])

    print(f"\n\u2705 Perplexity scoring complete")
    print(f"   Scored: {len(scored):,} / {len(deduped):,}")
    print(f"   PPL stats: min={min(scored):.1f}, median={ppl_median:.1f}, max={max(scored):.1f}, mean={ppl_mean:.1f}")
    print(f"   PPL < 50: {n_below_50:,}")
    print(f"   PPL 50-200: {n_50_to_200:,}")
    print(f"   PPL >= 200 (likely garbage): {n_200_plus:,}")

# Drop the model reference and release cached CUDA memory before the next pass.
del ppl_model
torch.cuda.empty_cache()
print("\u2705 PPL model unloaded, GPU memory freed")

In [None]:
# Stage 2 — Pass 2F: Semantic Categorization (IndicSBERT + HDBSCAN)
# Compute sentence embeddings, cluster with HDBSCAN, map to VAZHI domains.

from sentence_transformers import SentenceTransformer
import hdbscan

print("Loading IndicSBERT for semantic categorization...")
sbert_model = SentenceTransformer("l3cube-pune/tamil-sentence-similarity-sbert")
print(f"\u2705 IndicSBERT loaded")

# Embed the instruction side of every sample. Each instruction is cut to its
# first 512 characters before encoding, and the vectors are L2-normalised.
instructions = []
for s in deduped:
    instructions.append(s["instruction"][:512])
print(f"Computing embeddings for {len(instructions):,} instructions...")
_encode_options = {
    "batch_size": 256,
    "show_progress_bar": True,
    "normalize_embeddings": True,
}
embeddings = sbert_model.encode(instructions, **_encode_options)
print(f"\u2705 Embeddings computed: shape {embeddings.shape}")

# Cluster the instruction embeddings with HDBSCAN. The label -1 is treated as
# noise (a point that belongs to no cluster) in everything that follows.
print("Clustering with HDBSCAN...")
_hdbscan_options = {
    "min_cluster_size": 50,
    "min_samples": 10,
    "metric": "euclidean",
    "cluster_selection_method": "eom",
}
clusterer = hdbscan.HDBSCAN(**_hdbscan_options)
cluster_labels = clusterer.fit_predict(embeddings)

has_noise = -1 in cluster_labels
n_clusters = len(set(cluster_labels)) - (1 if has_noise else 0)
noise_count = sum(1 for label in cluster_labels if label == -1)
print(f"\u2705 Clustering complete: {n_clusters} clusters, {noise_count:,} noise points")

# Attach the cluster id and a readable category name to every sample.
for s, label in zip(deduped, cluster_labels):
    s["embedding_cluster"] = int(label)
    s["auto_category"] = "unclustered" if label < 0 else f"cluster_{label}"

# Show the 10 largest real clusters with one example instruction each.
cluster_counts = Counter(cluster_labels)

# Remember the first instruction seen for each cluster in a single pass, rather
# than rescanning every sample once per printed cluster.
first_instruction = {}
for s in deduped:
    first_instruction.setdefault(s["embedding_cluster"], s["instruction"])

print(f"\nTop 10 clusters:")
clusters_shown = 0
for cluster_id, count in cluster_counts.most_common():
    if cluster_id == -1:
        # Noise is reported in rank order but does not count towards the 10.
        print(f"  noise: {count:,} samples")
        continue
    # embedding_cluster was stored as a plain int, so look it up as one.
    example = first_instruction[int(cluster_id)][:80]
    print(f"  cluster_{cluster_id}: {count:,} samples — e.g. \"{example}...\"")
    clusters_shown += 1
    if clusters_shown == 10:
        break

# Free GPU memory
del sbert_model, embeddings
torch.cuda.empty_cache()
print("\n\u2705 SBERT model unloaded, GPU memory freed")

In [None]:
# Stage 2 — Pass 2G: Composite Quality Score + Tokenized Length

from transformers import AutoTokenizer as AT

# Load the tokenizer named by TOKENIZER_MODEL; it is used below to measure the
# tokenized length of each ChatML-wrapped sample.
print(f"Loading tokenizer for length validation: {TOKENIZER_MODEL}")
sft_tokenizer = AT.from_pretrained(TOKENIZER_MODEL, trust_remote_code=True)
_sft_vocab_size = sft_tokenizer.vocab_size
print(f"\u2705 Tokenizer loaded (vocab size: {_sft_vocab_size:,})")


def compute_quality_score(s):
    """Blend a sample's curation signals into one score; higher is better.

    Weighted terms (the weights sum to 1.0):
        0.35 * lang_confidence
        0.25 * 1 / (1 + 10 * repetition)
        0.20 * tamil_pct
        0.10 if tokenized_length <= SFT_MAX_SEQ_LENGTH, else 0
        0.10 if toxicity_flags is empty, else 0

    Missing fields default to 0.0 for lang_confidence, repetition and
    tamil_pct, to "over the limit" for tokenized_length, and to "no flags"
    for toxicity_flags. The result lies in 0-1 provided lang_confidence and
    tamil_pct are themselves fractions in 0-1 (presumably so; verify against
    the code that sets them).

    Args:
        s: Sample dict carrying the fields listed above.

    Returns:
        The weighted sum rounded to 4 decimal places.
    """
    fits_window = s.get("tokenized_length", SFT_MAX_SEQ_LENGTH + 1) <= SFT_MAX_SEQ_LENGTH
    is_flagged = bool(s.get("toxicity_flags", []))

    lang_term = s.get("lang_confidence", 0.0) * 0.35
    repetition_term = (1 / (1 + s.get("repetition", 0.0) * 10)) * 0.25
    tamil_term = s.get("tamil_pct", 0.0) * 0.20
    length_term = (1.0 if fits_window else 0.0) * 0.10
    toxicity_term = (0.0 if is_flagged else 1.0) * 0.10

    # Summed left to right in the same order as the weights are listed.
    total = lang_term + repetition_term + tamil_term + length_term + toxicity_term
    return round(total, 4)


print(f"Computing tokenized lengths and quality scores for {len(deduped):,} samples...")
for done, s in enumerate(deduped, start=1):
    # Length is measured on the ChatML-wrapped text, without extra special tokens.
    token_ids = sft_tokenizer.encode(
        to_chatml(s["instruction"], s["output"]), add_special_tokens=False
    )
    s["tokenized_length"] = len(token_ids)

    # The quality score reads tokenized_length, so it is computed afterwards.
    s["quality_score"] = compute_quality_score(s)

    if done % 50000 == 0:
        print(f"  ...{done:,} / {len(deduped):,}")

# Summary statistics for tokenized length and quality score.
tok_lengths = [s["tokenized_length"] for s in deduped]
quality_scores = [s["quality_score"] for s in deduped]

print(f"\n\u2705 Scoring complete")
if tok_lengths:
    print(f"   Token length: min={min(tok_lengths)}, median={sorted(tok_lengths)[len(tok_lengths)//2]}, max={max(tok_lengths)}")
    # Label the limit with the configured value so the text cannot drift from
    # the threshold that is actually applied.
    print(f"   Within {SFT_MAX_SEQ_LENGTH} tokens: {sum(1 for t in tok_lengths if t <= SFT_MAX_SEQ_LENGTH):,} / {len(tok_lengths):,}")
    print(f"   Quality score: min={min(quality_scores):.3f}, median={sorted(quality_scores)[len(quality_scores)//2]:.3f}, mean={sum(quality_scores)/len(quality_scores):.3f}")
    print(f"   Score >= 0.45: {sum(1 for q in quality_scores if q >= 0.45):,}")
    print(f"   Score < 0.45: {sum(1 for q in quality_scores if q < 0.45):,}")
else:
    # min()/max() and the mean are undefined on an empty pool.
    print("   \u26a0\ufe0f No samples to summarise")

In [None]:
# Stage 2 — Upload curated dataset to HF

# Project every sample onto a fixed, serializable schema. The first group of
# fields must be present on every sample; the rest fall back to a default when
# absent. Key order is kept stable because it determines the column order.
_required_fields = (
    "instruction",
    "output",
    "source",
    "subset",
    "char_length",
    "tamil_pct",
    "lang_id",
    "lang_confidence",
)
curated_records = [
    {
        **{field: s[field] for field in _required_fields},
        "heuristic_flags": s.get("heuristic_flags", []),
        "repetition": s.get("repetition", 0.0),
        "toxicity_flags": s.get("toxicity_flags", []),
        "is_safety_sample": s.get("is_safety_sample", False),
        "is_duplicate": s.get("is_duplicate", False),
        "perplexity": s.get("perplexity"),
        "embedding_cluster": s.get("embedding_cluster"),
        "auto_category": s.get("auto_category"),
        "quality_score": s.get("quality_score", 0.0),
        "tokenized_length": s.get("tokenized_length", 0),
    }
    for s in deduped
]

print(f"\nUploading {len(curated_records):,} curated samples to {CURATED_DATASET}...")
curated_ds = Dataset.from_list(curated_records)
api.create_repo(CURATED_DATASET, repo_type="dataset", exist_ok=True)
curated_ds.push_to_hub(CURATED_DATASET)

_curated_schema = list(curated_records[0])
print(f"\n\u2705 Curated dataset uploaded: https://huggingface.co/datasets/{CURATED_DATASET}")
print(f"   Total samples: {len(curated_records):,}")
print(f"   Schema: {_curated_schema}")

---
# Stage 3: COMPOSE

Select samples from the curated pools using absolute per-bucket count targets, then build the final SFT dataset.

**Input:** In-memory curated dataset (also backed up on HF from Stage 2)  
**Output:** `CryptoYogi/vazhi-tamil-sft-v4_1` on HuggingFace

In [None]:
# Stage 3A: Filtering

print("Applying Stage 3 filters...")
df = curated_records.copy()  # Shallow copy: a plain list of dicts, not a DataFrame

before = len(df)

# Filters run in this order; each entry is (label for the log line, predicate
# that returns True for samples to KEEP).
_stage3_filters = [
    # Hard filters
    ("dedup filter", lambda s: not s["is_duplicate"]),
    (f"token length \u2264 {SFT_MAX_SEQ_LENGTH}", lambda s: s["tokenized_length"] <= SFT_MAX_SEQ_LENGTH),
    ("lang_id == ta", lambda s: s["lang_id"] == "ta"),
    ("clean heuristics", lambda s: len(s["heuristic_flags"]) == 0),
    # Toxicity: flagged samples are kept only when they are safety samples
    ("toxicity filter", lambda s: len(s["toxicity_flags"]) == 0 or s["is_safety_sample"]),
    # Soft quality filter
    ("quality \u2265 0.45", lambda s: s["quality_score"] >= 0.45),
    # PPL garbage filter: unscored samples (None) pass through
    ("PPL < 200", lambda s: s["perplexity"] is None or s["perplexity"] < 200),
]
for filter_label, keep in _stage3_filters:
    n_in = len(df)
    df = [s for s in df if keep(s)]
    print(f"  After {filter_label}: {len(df):,} (-{n_in - len(df):,})")

print(f"\n\u2705 Filtering complete: {before:,} \u2192 {len(df):,} ({before - len(df):,} removed)")

# How many samples each source still contributes after filtering
filtered_sources = Counter(s["source"] for s in df)
print(f"\nFiltered pool by source:")
for source_name, n_kept in filtered_sources.most_common():
    print(f"  {source_name}: {n_kept:,}")

In [None]:
# Stage 3B: Composition with Absolute Count Targets
# Each source has independent min/max — no anchoring, no cascading.

# Separate safety samples from source pools
safety_pool = [s for s in df if s["is_safety_sample"]]
non_safety = [s for s in df if not s["is_safety_sample"]]

# Build source pools: one per source for non-safety samples, plus "safety"
source_pools = {}
for s in non_safety:
    source_pools.setdefault(s["source"], []).append(s)
source_pools["safety"] = safety_pool

# Only pools named in BUCKET_TARGETS are composed below. Report any non-empty
# pool that has no target so its samples are not dropped without a trace.
unrouted_pools = {
    name: len(members)
    for name, members in source_pools.items()
    if members and name not in BUCKET_TARGETS
}
for name in sorted(unrouted_pools):
    print(f"  \u26a0\ufe0f {name}: {unrouted_pools[name]:,} samples have no bucket target and will be excluded")

print("Composing final dataset with absolute count targets...")
composed = {}
total_composed = 0

for bucket_name, targets in BUCKET_TARGETS.items():
    pool = source_pools.get(bucket_name, [])
    target = targets["target"]
    min_count = targets["min"]
    max_count = targets["max"]

    if len(pool) < min_count:
        print(f"  \u26a0\ufe0f {bucket_name}: only {len(pool):,} available, min is {min_count}")
        # Use all available — warn but don't fail
        selected = pool
    elif len(pool) <= target:
        # Pool smaller than target: use all
        selected = pool
    else:
        # Pool larger than target: keep the highest-scoring samples, at most
        # `target` of them and never more than `max`.
        use_count = min(target, max_count)
        pool_sorted = sorted(pool, key=lambda x: x["quality_score"], reverse=True)
        selected = pool_sorted[:use_count]

    composed[bucket_name] = selected
    total_composed += len(selected)
    print(f"  {bucket_name}: {len(selected):,} / {len(pool):,} available (target: {target}, range: {min_count}-{max_count})")

print(f"\n\u2705 Composition complete: {total_composed:,} total samples")

# Verify minimums
all_met = True
for bucket_name, targets in BUCKET_TARGETS.items():
    actual = len(composed.get(bucket_name, []))
    if actual < targets["min"]:
        print(f"  \u274c {bucket_name}: {actual} < min {targets['min']}")
        all_met = False
if all_met:
    print("\u2705 All bucket minimums met")

In [None]:
# Stage 3C: ChatML Conversion + Validation

all_samples = []
chatml_failures = 0
# Every sample dropped in this stage is accounted for, so a bucket that falls
# below its minimum can be traced to a cause.
chatml_failure_reasons = Counter()
kural_filtered = 0

for bucket_name, samples in composed.items():
    for s in samples:
        # Apply anti-memorization filter for Thirukkural (vazhi_packs only)
        if s["source"] == "vazhi_packs" and is_verbatim_kural_qa(s["instruction"], s["output"]):
            kural_filtered += 1
            continue

        text = to_chatml(s["instruction"], s["output"])

        # Strict ChatML validation
        valid, reason = validate_chatml_strict(text)
        if not valid:
            chatml_failures += 1
            # str() keeps the tally usable whatever type the validator returns.
            chatml_failure_reasons[str(reason)] += 1
            continue

        all_samples.append({
            "text": text,
            "bucket": bucket_name,
            "source": s["source"],
            "subset": s["subset"],
            "quality_score": s["quality_score"],
            "tokenized_length": s["tokenized_length"],
        })

random.shuffle(all_samples)

print(f"\u2705 ChatML conversion complete")
print(f"   Valid samples: {len(all_samples):,}")
print(f"   ChatML failures: {chatml_failures}")
for failure_reason, n_failed in chatml_failure_reasons.most_common():
    print(f"     - {failure_reason}: {n_failed:,}")
print(f"   Verbatim-kural samples skipped: {kural_filtered}")

# Bucket distribution
bucket_counts = Counter(s["bucket"] for s in all_samples)
print(f"\n\U0001f4ca Final bucket distribution:")
for bucket, count in sorted(bucket_counts.items()):
    pct = 100 * count / len(all_samples)
    print(f"  {bucket}: {count:,} ({pct:.1f}%)")

In [None]:
# Stage 3D: Stratified Train/Eval Split (90/10 by bucket)

EVAL_RATIO = 0.10
train_samples = []
eval_samples = []

# Fail with a clear message instead of a ZeroDivisionError / empty-max error
# further down when nothing survived the earlier stages.
if not all_samples:
    raise ValueError("all_samples is empty; nothing to split")

by_bucket = {}
for s in all_samples:
    by_bucket.setdefault(s["bucket"], []).append(s)

for bucket, samples in by_bucket.items():
    random.shuffle(samples)
    if len(samples) < 2:
        # A single sample cannot be split. Keep it for training rather than
        # leaving the bucket with no training data at all.
        print(f"  \u26a0\ufe0f {bucket}: only {len(samples)} sample, kept in train (no eval sample)")
        train_samples.extend(samples)
        continue
    # At least one eval sample per bucket, but always leave one for training.
    n_eval = min(max(1, int(len(samples) * EVAL_RATIO)), len(samples) - 1)
    eval_samples.extend(samples[:n_eval])
    train_samples.extend(samples[n_eval:])

random.shuffle(train_samples)
random.shuffle(eval_samples)

print(f"\U0001f4ca Stratified split:")
print(f"  Train: {len(train_samples):,}")
print(f"  Eval:  {len(eval_samples):,}")
print(f"  Eval ratio: {len(eval_samples) / (len(train_samples) + len(eval_samples)):.1%}")

# Show which buckets are represented in eval
eval_buckets = Counter(s["bucket"] for s in eval_samples)
print(f"\n  Eval bucket distribution:")
for bucket, count in sorted(eval_buckets.items()):
    print(f"    {bucket}: {count}")

# Sanity check: every sample must be within the token limit
max_tok = max(s["tokenized_length"] for s in all_samples)
print(f"\n  Max tokenized length: {max_tok} (limit: {SFT_MAX_SEQ_LENGTH})")
assert max_tok <= SFT_MAX_SEQ_LENGTH, f"Token length violation: {max_tok} > {SFT_MAX_SEQ_LENGTH}"

In [None]:
# Stage 3E: Upload to HuggingFace + Summary

# Wrap both splits as datasets and publish them together under OUTPUT_DATASET.
train_ds = Dataset.from_list(train_samples)
eval_ds = Dataset.from_list(eval_samples)
_final_splits = {"train": train_ds, "validation": eval_ds}
dataset_dict = DatasetDict(_final_splits)

api.create_repo(OUTPUT_DATASET, repo_type="dataset", exist_ok=True)
dataset_dict.push_to_hub(OUTPUT_DATASET)

_n_train, _n_validation = len(train_ds), len(eval_ds)
print(f"\n\u2705 Dataset uploaded: https://huggingface.co/datasets/{OUTPUT_DATASET}")
print(f"   Train: {_n_train:,} samples")
print(f"   Validation: {_n_validation:,} samples")

# Final summary: per-stage totals, split sizes and bucket composition.
_rule = "=" * 60
print("\n" + _rule)
print(f"VAZHI Dataset Factory v{VERSION} \u2014 COMPLETE")
print(_rule)
print(f"\n  Stage 1 (Retrieve): {len(raw_samples):,} raw samples \u2192 {RAW_DATASET}")
print(f"  Stage 2 (Curate):   {len(curated_records):,} curated samples \u2192 {CURATED_DATASET}")
print(f"  Stage 3 (Compose):  {len(all_samples):,} final SFT samples \u2192 {OUTPUT_DATASET}")
print(f"\n  Train: {len(train_samples):,} | Eval: {len(eval_samples):,}")
print(f"  max_seq_length: {SFT_MAX_SEQ_LENGTH}")

print(f"\n  Bucket composition:")
for bucket in sorted(bucket_counts):
    count = bucket_counts[bucket]
    spec = BUCKET_TARGETS[bucket]
    # Check mark when the bucket reached its minimum, warning sign otherwise.
    if count >= spec["min"]:
        status = "\u2705"
    else:
        status = "\u26a0\ufe0f"
    print(f"    {status} {bucket}: {count:,} (target: {spec['target']}, range: {spec['min']}-{spec['max']})")

# Print up to two examples per bucket, in the (shuffled) order of all_samples.
print(f"\n{_rule}")
print("Sample outputs (2 per bucket):")
print(_rule)

shown = Counter()
for s in all_samples:
    bucket_label = s["bucket"]
    if shown[bucket_label] < 2:
        shown[bucket_label] += 1
        print(f"\n[{bucket_label.upper()}] source={s['source']} quality={s['quality_score']:.3f}")
        match = CHATML_PATTERN.search(s["text"])
        if match:
            question, answer = match.group(1, 2)
            print(f"  Q: {question[:100]}")
            print(f"  A: {answer[:150]}")
    # Stop as soon as no target bucket is still short of two examples.
    if not any(shown[b] < 2 for b in BUCKET_TARGETS):
        break

print(f"\n\u2705 Dataset Factory v{VERSION} complete!")
print(f"   Next step: SFT training with conservative LoRA (r=8, q_proj+v_proj, 2 epochs)")
print(f"   Base model: {DAPT_MODEL}")