# VAZHI Gemma-2B Tamil Fine-tuning (Full Dataset)

**Goal:** Fine-tune Gemma-2B Tamil on ALL VAZHI data to build a complete Tamil AI assistant.

**Key Insight:** The model's size (1.63 GB as a Q4_K_M GGUF) does not depend on how much data we train on. It is fixed by the parameter count and the quantization format.
So we train on everything to maximize knowledge.
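Here is a back-of-the-envelope check. A quantized file's size is roughly the number of weights times the average bits stored per weight, divided by 8. The amount of training data never enters that formula, because fine-tuning (and merging a LoRA adapter) changes the weights' values, not their count. This is a minimal sketch. The ~2.5B weight count and 5 bits per weight are illustrative assumptions. Mixed formats like Q4_K_M keep some tensors at higher precision, so treat the result as a ballpark figure.

```python
def estimate_model_bytes(n_params: float, bits_per_weight: float) -> float:
    """Rough file size: each weight costs bits_per_weight bits on average.

    There is deliberately no dataset-size argument: training changes the
    values of the weights, not how many of them there are.
    """
    return n_params * bits_per_weight / 8


# Illustrative assumptions: ~2.5e9 weights at ~5 bits each on average.
size_gb = estimate_model_bytes(2.5e9, 5.0) / 1e9
print(f"~{size_gb:.2f} GB")
```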

**Training Data:** 11,112 items covering:
- 🛡️ Scam/Security protection
- 🏛️ Government schemes
- 🏥 Healthcare
- 📚 Education
- ⚖️ Legal
- 🪷 Culture (Thirukkural, Siddhars, Classical literature)
- 🗣️ Dialects (Chennai, Madurai, Kongu)
- 🚫 Guardrails ("I don't know" responses)

**Output:** A single ~1.63 GB model (Q4_K_M GGUF) covering every domain above.

In [None]:
# Install dependencies
!pip install -q transformers datasets peft accelerate bitsandbytes trl huggingface_hub

In [None]:
import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
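Step 5 sets `bf16=True`, which assumes native bfloat16 support. NVIDIA introduced that with the Ampere generation (compute capability 8.0), and a T4 is only 7.5. A small helper makes the check explicit. `preferred_dtype_name` is a hypothetical name, not a library function. If it reports fp16, the bf16 loading and training settings below would need revisiting.

```python
def preferred_dtype_name(capability: tuple[int, int]) -> str:
    """Pick a mixed-precision mode from a CUDA compute capability (major, minor).

    Native bfloat16 arrives with Ampere (8.0); older cards such as the
    T4 (7.5) lack it, so fall back to fp16 there.
    """
    major, _minor = capability
    return "bf16" if major >= 8 else "fp16"


# In the notebook (requires a GPU):
# print(preferred_dtype_name(torch.cuda.get_device_capability(0)))
print(preferred_dtype_name((7, 5)))  # T4
print(preferred_dtype_name((8, 0)))  # A100
```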

## Step 1: Load Base Model

Load the Gemma-2B Tamil model that we verified works.

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

MODEL_ID = "abhinand/gemma-2b-it-tamil-v0.1-alpha"

print(f"Loading {MODEL_ID}...")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
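# Gemma's tokenizer ships with a dedicated <pad> token; only fall back to EOS if it is
# missing, since padding with EOS can mask the EOS labels the model needs to learn to stop.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token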
tokenizer.padding_side = "right"

# Try loading in bf16 first (preferred).
# If that runs out of memory, fall back to 8-bit (NOT 4-bit: 4-bit corrupted this model's output for us).
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    print("Loaded in bfloat16!")
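except RuntimeError as e:  # torch.cuda.OutOfMemoryError is a RuntimeError subclass
    if "out of memory" in str(e).lower():
        print("OOM with bf16, trying 8-bit quantization...")
        torch.cuda.empty_cache()  # release cached memory from the failed bf16 attempt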
        bnb_config = BitsAndBytesConfig(
            load_in_8bit=True,  # 8-bit, NOT 4-bit!
        )
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            quantization_config=bnb_config,
            device_map="auto",
        )
        print("Loaded in 8-bit!")
    else:
        raise e

print(f"Parameters: {model.num_parameters() / 1e9:.2f}B")
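Switching from bf16 to 8-bit roughly halves the memory the weights need, because bf16 stores 2 bytes per parameter and int8 stores 1. A minimal sketch of that arithmetic follows. `weight_memory_gb` is a hypothetical helper, and 2.5e9 is an illustrative parameter count, so substitute `model.num_parameters()`. With `device_map="auto"`, Accelerate may place layers on the CPU instead of raising an out-of-memory error, so the `except` branch above may never fire. Comparing this estimate with the GPU memory printed earlier tells you which path to expect.

```python
BYTES_PER_PARAM = {"fp32": 4, "bf16": 2, "int8": 1}


def weight_memory_gb(n_params: float, dtype: str) -> float:
    """GPU memory for the weights alone, in GB (1 GB = 1e9 bytes).

    Activations, KV cache, gradients and optimizer state come on top,
    so treat this as a lower bound rather than a budget.
    """
    return n_params * BYTES_PER_PARAM[dtype] / 1e9


# Illustrative count; in the notebook use model.num_parameters().
for dtype in ("bf16", "int8"):
    print(dtype, f"{weight_memory_gb(2.5e9, dtype):.1f} GB")
```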

## Step 2: Verify Base Model Works

In [None]:
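def test_model(model, tokenizer, prompt):
    """Quick test of model output (sampled, so answers vary between runs)"""
    model.eval()  # disable dropout (including LoRA dropout) while generating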
    formatted = f"### Instruction:\n{prompt}\n\n### Response:\n"
    inputs = tokenizer(formatted, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response.split("### Response:")[-1].strip()

# Test before fine-tuning
print("=" * 50)
print("BEFORE FINE-TUNING:")
print("=" * 50)

test_prompts = [
    "திருக்குறளின் முதல் குறள் என்ன?",  # What is the first kural of the Thirukkural?
    "PM-KISAN திட்டம் என்ன?",  # What is the PM-KISAN scheme?
    "தமிழ்நாட்டின் தலைநகரம் எது?",  # What is the capital of Tamil Nadu?
    "இந்த SMS உண்மையா? 'நீங்கள் lottery வென்றீர்கள், ₹500 அனுப்புங்கள்'",  # Is this SMS genuine? 'You won the lottery, send ₹500'
]

for prompt in test_prompts:
    print(f"\nQ: {prompt}")
    print(f"A: {test_model(model, tokenizer, prompt)[:200]}...")
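The prompt layout `### Instruction: ... ### Response:` is hard-coded in three places: `test_model`, `format_for_training` and the GGUF validation loop in Step 10. If they drift apart, the fine-tuned model gets queried in a format it never saw during training. The sketch below factors the layout into one place. `build_prompt` and `extract_response` are hypothetical helper names, not part of any library.

```python
PROMPT_TEMPLATE = "### Instruction:\n{instruction}\n\n### Response:\n"


def build_prompt(instruction: str, response: str = "") -> str:
    """Same layout for training (with a response) and inference (without)."""
    return PROMPT_TEMPLATE.format(instruction=instruction) + response


def extract_response(decoded: str) -> str:
    """Keep only the text after the last response marker."""
    return decoded.split("### Response:")[-1].strip()


train_text = build_prompt("Q?", "A.")
infer_text = build_prompt("Q?")
# The inference prompt must be an exact prefix of the training text.
assert train_text.startswith(infer_text)
print(repr(extract_response(train_text)))
```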

## Step 3: Load Full VAZHI Dataset

In [None]:
from datasets import load_dataset

# Load full VAZHI dataset from HuggingFace
print("Loading VAZHI dataset from HuggingFace...")
dataset = load_dataset("CryptoYogi/vazhi-tamil-v05")

print(f"\nDataset loaded!")
print(f"Train: {len(dataset['train'])} samples")
print(f"Validation: {len(dataset['validation'])} samples")

# Show sample
print(f"\nSample:")
sample = dataset['train'][0]
print(f"Keys: {sample.keys()}")
if 'text' in sample:
    print(f"Text: {sample['text'][:300]}...")
elif 'instruction' in sample:
    print(f"Instruction: {sample['instruction'][:100]}...")
    print(f"Output: {sample['output'][:100]}...")

In [None]:
def format_for_training(example):
    """Format as instruction-response pair"""
    # Handle both formats
    if 'text' in example and example['text']:
        return {"text": example['text']}
    elif 'instruction' in example and 'output' in example:
        text = f"### Instruction:\n{example['instruction']}\n\n### Response:\n{example['output']}"
        return {"text": text}
    else:
        return {"text": ""}

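# Format datasets, dropping the original columns so only "text" remains
train_dataset = dataset['train'].map(format_for_training, remove_columns=dataset['train'].column_names)
val_dataset = dataset['validation'].map(format_for_training, remove_columns=dataset['validation'].column_names)

# Drop empty or near-empty samples (10 characters or fewer)
train_dataset = train_dataset.filter(lambda x: len(x['text']) > 10)
val_dataset = val_dataset.filter(lambda x: len(x['text']) > 10)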

print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")
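Training truncates each sample to the 512-token limit set in Step 5. Tamil text may tokenize into more pieces per word than English, so long answers can lose their endings. You can check how many samples exceed the limit before training. This is a sketch: `fraction_over_limit` is a hypothetical helper, and the whitespace "tokenizer" in the demo is only a stand-in for illustration. In the notebook, count real tokens with the loaded `tokenizer`.

```python
from typing import Callable, Iterable


def fraction_over_limit(texts: Iterable[str], token_len: Callable[[str], int], max_len: int) -> float:
    """Share of samples whose token count exceeds max_len (these get truncated)."""
    lengths = [token_len(t) for t in texts]
    if not lengths:
        return 0.0
    return sum(n > max_len for n in lengths) / len(lengths)


# In the notebook, e.g.:
# token_len = lambda t: len(tokenizer(t)["input_ids"])
# print(fraction_over_limit(train_dataset["text"], token_len, 512))

# Stand-in tokenizer for illustration only: whitespace split.
demo = ["a b c", "a b c d e f", "a"]
print(fraction_over_limit(demo, lambda t: len(t.split()), 4))
```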

## Step 4: Configure LoRA (Conservative)

Using conservative settings to avoid corrupting the model.

In [None]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

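# Prepare for training if using 8-bit quantization
if getattr(model, "is_loaded_in_8bit", False):
    model = prepare_model_for_kbit_training(model)
else:
    # With gradient checkpointing and a frozen base model, the checkpointed blocks
    # otherwise receive inputs with requires_grad=False and LoRA may get no gradients.
    model.enable_input_require_grads()

# Conservative LoRA config
lora_config = LoraConfig(
    r=8,                    # Moderate rank
    lora_alpha=16,          # Common convention: alpha = 2*r
    target_modules=["q_proj", "v_proj"],  # Attention query/value projections only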
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# Apply LoRA
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
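`print_trainable_parameters()` shows how small the adapter is. The count follows from LoRA's structure: each adapted `d_in × d_out` matrix gets two low-rank factors, A (`r × d_in`) and B (`d_out × r`), for `r·(d_in + d_out)` parameters. This is a sketch. The Gemma-2B dimensions below (hidden size 2048, 8 query heads and 1 key/value head of size 256, 18 layers) are assumptions to check against `model.config` (`hidden_size`, `num_attention_heads`, `num_key_value_heads`, `head_dim`, `num_hidden_layers`).

```python
def lora_param_count(rank: int, shapes: list[tuple[int, int]], n_layers: int) -> int:
    """LoRA adds A (rank x d_in) and B (d_out x rank) for each adapted matrix."""
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in shapes)
    return per_layer * n_layers


# Assumed Gemma-2B shapes: q_proj 2048 -> 8 heads x 256, v_proj 2048 -> 1 KV head x 256.
shapes = [(2048, 2048), (2048, 256)]
print(lora_param_count(8, shapes, 18))
```

Doubling `r` doubles the adapter size. It still stays tiny next to the ~2.5B frozen base weights, which is why the merged model keeps the same size.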

## Step 5: Training Configuration

In [None]:
from trl import SFTTrainer, SFTConfig

# Calculate steps
batch_size = 2
grad_accum = 8
effective_batch = batch_size * grad_accum  # 16
steps_per_epoch = len(train_dataset) // effective_batch
total_steps = steps_per_epoch * 2  # 2 epochs

print(f"Effective batch size: {effective_batch}")
print(f"Steps per epoch: {steps_per_epoch}")
print(f"Total steps (2 epochs): {total_steps}")

training_args = SFTConfig(
    output_dir="./vazhi-gemma-full",
    
    # Learning settings
    learning_rate=1e-5,         # Conservative
    num_train_epochs=2,         # 2 epochs
    
    # Batch settings
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=grad_accum,
    
    # Stability
    max_grad_norm=0.3,
    warmup_ratio=0.1,
    weight_decay=0.01,
    
    # Precision (bf16 assumes an Ampere-or-newer GPU such as A100 or L4; a T4 lacks native bf16)
    bf16=True,
    
    # Memory optimization
    gradient_checkpointing=True,
    optim="adamw_torch",
    
    # Logging & saving
    logging_steps=25,
    eval_strategy="steps",
    eval_steps=100,
    save_steps=200,
    save_total_limit=3,
    
    # Data (older TRL versions name this argument max_seq_length)
    max_length=512,
    dataset_text_field="text",
    
    # Reporting
    report_to="none",
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    processing_class=tokenizer,  # older TRL versions use tokenizer=tokenizer
)

print("\nTrainer configured!")
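The step counts printed above come from simple arithmetic. The sketch below also derives the warmup length implied by `warmup_ratio=0.1`. It assumes warmup is `ceil(total × ratio)`, which we believe matches how transformers converts the ratio. It uses the same floor division as the notebook, so the Trainer's own count may differ by a step. `schedule` is a hypothetical helper, and 10,000 is a made-up sample count.

```python
import math


def schedule(n_samples: int, batch_size: int, grad_accum: int, epochs: int, warmup_ratio: float):
    """Approximate (steps per epoch, total optimizer steps, warmup steps)."""
    steps_per_epoch = n_samples // (batch_size * grad_accum)
    total = steps_per_epoch * epochs
    warmup = math.ceil(total * warmup_ratio)
    return steps_per_epoch, total, warmup


# Hypothetical sample count, for illustration only.
print(schedule(10_000, 2, 8, 2, 0.1))
```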

## Step 6: Train!

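This takes roughly 2-3 hours on a T4 with the full dataset. However, a T4 has no native bfloat16 support, so `bf16=True` above may be rejected or run slowly there. An Ampere-or-newer GPU avoids this.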

In [None]:
print("="*50)
print("STARTING TRAINING")
print("="*50)
print(f"Training samples: {len(train_dataset)}")
print(f"Expected steps: {total_steps}")
print("\nWatch for:")
print("- Loss should decrease gradually")
print("- No sudden spikes (indicates divergence)")
print("- Target final loss: ~1.5-2.5")
print("="*50)

trainer.train()
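The "no sudden spikes" check can be automated once training finishes. The Trainer's `state.log_history` holds one dict per logging step, and the training entries carry a `"loss"` key. This is a minimal sketch. `find_loss_spikes`, the 4-step window and the 1.5× threshold are arbitrary choices for illustration, not values from this notebook, and the demo losses are made up.

```python
def find_loss_spikes(losses: list[float], window: int = 4, factor: float = 1.5) -> list[int]:
    """Indices where loss exceeds `factor` times the mean of the previous `window` values."""
    spikes = []
    for i in range(window, len(losses)):
        baseline = sum(losses[i - window:i]) / window
        if losses[i] > factor * baseline:
            spikes.append(i)
    return spikes


# After training, in the notebook:
# losses = [e["loss"] for e in trainer.state.log_history if "loss" in e]
demo = [2.9, 2.7, 2.5, 2.4, 2.3, 4.1, 2.2, 2.1]  # made-up losses
print(find_loss_spikes(demo))
```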

In [None]:
# Save the trained LoRA adapter (adapter weights + config, not the full model)
trainer.save_model("./vazhi-gemma-full-final")
print("Training complete! Adapter saved.")

## Step 7: Test Fine-tuned Model

In [None]:
print("=" * 50)
print("AFTER FINE-TUNING:")
print("=" * 50)

# Core tests
for prompt in test_prompts:
    print(f"\nQ: {prompt}")
    print(f"A: {test_model(model, tokenizer, prompt)[:300]}")

# Domain-specific tests
domain_tests = [
    # Culture
    "வள்ளுவர் யார்?",  # Who is Valluvar?
    # Scam
    "OTP share பண்ணலாமா?",  # Can I share my OTP?
    # Health
    "CMCHIS திட்டம் என்ன?",  # What is the CMCHIS scheme?
    # Guardrails
    "Bitcoin-ல் invest பண்ணலாமா?",  # Can I invest in Bitcoin?
]

print("\n" + "=" * 50)
print("DOMAIN-SPECIFIC TESTS:")
print("=" * 50)

for prompt in domain_tests:
    print(f"\nQ: {prompt}")
    print(f"A: {test_model(model, tokenizer, prompt)[:300]}")

## Step 8: Merge LoRA & Save

In [None]:
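from peft import PeftModel

# Weights loaded in 8-bit (bitsandbytes) aren't suitable for merging and GGUF conversion,
# so in that case reload the base model in bf16 and apply the saved adapter first.
if getattr(model, "is_loaded_in_8bit", False):
    base = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
    model = PeftModel.from_pretrained(base, "./vazhi-gemma-full-final")

# Merge LoRA into base model
print("Merging LoRA into base model...")
merged_model = model.merge_and_unload()

# Save merged model
merged_model.save_pretrained("./vazhi-gemma-merged")
tokenizer.save_pretrained("./vazhi-gemma-merged")

print("Merged model saved!")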

## Step 9: Convert to GGUF

In [None]:
# Clone and build llama.cpp
!git clone --depth 1 https://github.com/ggerganov/llama.cpp.git 2>/dev/null || echo "Already exists"
!cd llama.cpp && mkdir -p build && cd build && cmake .. -DGGML_CUDA=OFF && cmake --build . --config Release -j4 -- llama-quantize

print("llama.cpp ready!")

In [None]:
import subprocess
import sys
import os

# Convert to GGUF F16 (use this kernel's Python so the installed packages are found)
print("Converting to GGUF F16...")
result = subprocess.run([
    sys.executable, "llama.cpp/convert_hf_to_gguf.py",
    "./vazhi-gemma-merged",
    "--outfile", "./vazhi-gemma-f16.gguf",
    "--outtype", "f16"
], capture_output=True, text=True)

# Check the exit code too: a stale file from an earlier run would pass an existence check
if result.returncode == 0 and os.path.exists("./vazhi-gemma-f16.gguf"):
    size = os.path.getsize("./vazhi-gemma-f16.gguf") / 1e9
    print(f"✅ F16 created: {size:.2f} GB")
else:
    print(f"❌ F16 failed: {result.stderr[-500:]}")

In [None]:
# Quantize to Q4_K_M
print("Quantizing to Q4_K_M...")
result = subprocess.run([
    "./llama.cpp/build/bin/llama-quantize",
    "./vazhi-gemma-f16.gguf",
    "./vazhi-gemma-q4_k_m.gguf",
    "Q4_K_M"
], capture_output=True, text=True)

if result.returncode == 0 and os.path.exists("./vazhi-gemma-q4_k_m.gguf"):
    size = os.path.getsize("./vazhi-gemma-q4_k_m.gguf") / 1e9
    print(f"✅ Q4_K_M created: {size:.2f} GB")
else:
    print(f"❌ Q4_K_M failed: {result.stderr[-500:]}")
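A file that exists is not necessarily a valid one. A truncated write or a leftover from an earlier run would also pass. Reading the GGUF header is a slightly stronger sanity check. The sketch below is based on the published GGUF layout: the 4-byte magic `b"GGUF"`, then a little-endian uint32 version and two uint64 counts (tensors, metadata key/value pairs). `read_gguf_header` is a hypothetical helper. The `gguf` Python package that ships with llama.cpp provides a full reader if you need more.

```python
import os
import struct
import tempfile


def read_gguf_header(path: str) -> tuple[int, int, int]:
    """Return (version, tensor_count, metadata_kv_count) from a GGUF file."""
    with open(path, "rb") as f:
        header = f.read(24)
    if len(header) < 24 or header[:4] != b"GGUF":
        raise ValueError(f"{path} is not a GGUF file")
    # "<IQQ": little-endian uint32 + uint64 + uint64, 20 bytes after the magic.
    version, n_tensors, n_kv = struct.unpack("<IQQ", header[4:24])
    return version, n_tensors, n_kv


# In the notebook: print(read_gguf_header("./vazhi-gemma-q4_k_m.gguf"))
# Self-contained demo with a fake header (the values are made up):
fake = os.path.join(tempfile.mkdtemp(), "fake.gguf")
with open(fake, "wb") as f:
    f.write(b"GGUF" + struct.pack("<IQQ", 3, 10, 20))
print(read_gguf_header(fake))  # (3, 10, 20)
```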

## Step 10: Final GGUF Validation

In [None]:
!pip install -q llama-cpp-python

In [None]:
from llama_cpp import Llama

print("Loading GGUF model...")
llm = Llama(
    model_path="./vazhi-gemma-q4_k_m.gguf",
    n_ctx=512,
    n_threads=4,
    verbose=False
)

print("=" * 60)
print("FINAL GGUF Q4_K_M VALIDATION")
print("=" * 60)

all_tests = test_prompts + domain_tests

results = []
for prompt in all_tests:
    print(f"\nQ: {prompt}")
    response = llm(
        f"### Instruction:\n{prompt}\n\n### Response:\n",
        max_tokens=150,
        stop=["###", "\n\n"],  # note: also stops at the first blank line, cutting multi-paragraph answers
        echo=False
    )
    answer = response['choices'][0]['text'].strip()
    print(f"A: {answer[:300]}")
    results.append({"q": prompt, "a": answer})

In [None]:
# Quality check
print("\n" + "=" * 60)
print("QUALITY SUMMARY")
print("=" * 60)

def check_tamil(text):
    """Count Tamil characters"""
    tamil = sum(1 for c in text if 0x0B80 <= ord(c) <= 0x0BFF)
    return tamil

for r in results:
    tamil_chars = check_tamil(r['a'])
    status = "✅" if tamil_chars > 10 else "⚠️"
    print(f"{status} Tamil chars: {tamil_chars:3d} | {r['q'][:40]}...")
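Counting Tamil characters catches empty or English-only answers. However, a long English reply with a few Tamil words still passes the `> 10` threshold. A ratio is harder to fool. The sketch below uses a hypothetical `tamil_ratio` helper over the same U+0B80–U+0BFF block as `check_tamil`. Which ratio counts as "good enough" is up to you.

```python
TAMIL_START, TAMIL_END = 0x0B80, 0x0BFF


def tamil_ratio(text: str) -> float:
    """Fraction of letter-like characters that fall in the Tamil Unicode block.

    Tamil vowel signs and the virama are combining marks, which str.isalpha()
    rejects, so they are recognized via the block range instead.
    """
    tamil = letters = 0
    for c in text:
        in_block = TAMIL_START <= ord(c) <= TAMIL_END
        if in_block or c.isalpha():
            letters += 1
            tamil += in_block
    return tamil / letters if letters else 0.0


# In the notebook:
# for r in results:
#     print(f"{tamil_ratio(r['a']):.2f} | {r['q'][:40]}")
print(tamil_ratio("வணக்கம்"))  # 1.0
```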

## Step 11: Upload to HuggingFace (Optional)

In [None]:
# Uncomment to upload
# from huggingface_hub import HfApi, login
# login()  # Enter your HF token
# 
# api = HfApi()
# api.upload_file(
#     path_or_fileobj="./vazhi-gemma-q4_k_m.gguf",
#     path_in_repo="vazhi-gemma-q4_k_m.gguf",
#     repo_id="CryptoYogi/vazhi-model-v1",
#     repo_type="model",
# )
# print("Uploaded to HuggingFace!")

## Results Summary

### Training
- Dataset: 11,112 samples (full VAZHI)
- Base: Gemma-2B Tamil
- Method: LoRA (r=8, bf16)
- Duration: ~2-3 hours on T4

### Output
- Model: `vazhi-gemma-q4_k_m.gguf`
- Size: ~1.63 GB
- Covers: All 6 knowledge packs + guardrails

### Quality Checklist
- [ ] Produces coherent Tamil
- [ ] First Thirukkural couplet correct
- [ ] Tamil Nadu capital (Chennai) correct
- [ ] Scam detection works
- [ ] Govt schemes accurate
- [ ] Guardrails work (refuses unknown)

If all pass → Ready for mobile integration! 🎉
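Two checklist items have unambiguous answers. The first Thirukkural couplet begins "அகர முதல", and Tamil Nadu's capital is Chennai (சென்னை). The sketch below checks those two against the Step 10 `results` list. `keyword_report` and `EXPECTED` are hypothetical names, and the prompts must match `test_prompts` exactly. Substring matching is a crude proxy that will miss correct answers phrased differently. Scam, scheme and guardrail items still need a human read.

```python
# Hypothetical expectations: prompt -> substrings a correct answer should contain.
EXPECTED = {
    "திருக்குறளின் முதல் குறள் என்ன?": ["அகர முதல"],
    "தமிழ்நாட்டின் தலைநகரம் எது?": ["சென்னை"],
}


def keyword_report(results, expected):
    """Map each expected prompt to True/False, or None if it was not asked."""
    answers = {r["q"]: r["a"] for r in results}
    report = {}
    for prompt, keywords in expected.items():
        answer = answers.get(prompt)
        report[prompt] = None if answer is None else all(k in answer for k in keywords)
    return report


# In the notebook: print(keyword_report(results, EXPECTED))
demo_results = [{"q": "தமிழ்நாட்டின் தலைநகரம் எது?", "a": "தமிழ்நாட்டின் தலைநகரம் சென்னை."}]
print(keyword_report(demo_results, EXPECTED))
```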