# VAZHI v0.6 - Sarvam-2B Fine-tuning

**Goal**: Fine-tune Sarvam-2B to follow Tamil instructions

**Why Sarvam-2B?**
- Already trained on 2T tokens of 10 Indian languages (including Tamil)
- Only needs instruction tuning to follow commands
- Q4_K_M size: ~1.2GB (fits mobile)

**Training Data**:
1. AI4Bharat IndicAlign - Tamil instruction pairs
2. VAZHI dataset - Domain-specific (Thirukkural, govt, health, etc.)

**Key Fixes from Failed Qwen Training**:
- LoRA r=8 (was 32 - too aggressive)
- LoRA applied only to `q_proj`/`v_proj` (was all 7 projection modules)
- Float16 training (was 4-bit - unstable)
- Learning rate 1e-5 (was 5e-5)
- Gradient clipping enabled (max_grad_norm=0.3)

## 1. Setup

In [None]:
# Install dependencies
# NOTE(review): versions are unpinned, so later cells run against whatever
# releases are current (TRL's SFTConfig arguments, for one, have changed
# between releases) — consider pinning once a working combination is found.
# bitsandbytes is installed but this notebook loads the model in float16 and
# never configures 4-bit/8-bit quantization with it.
!pip install -q torch transformers datasets peft accelerate bitsandbytes
!pip install -q trl huggingface_hub sentencepiece

In [None]:
# Login to HuggingFace
from huggingface_hub import login

# For Kaggle: use Kaggle secrets
# For Colab: use Colab secrets or manual login
# Each platform path is tried in turn. Only ordinary errors (module not
# present, secret missing, login rejected) trigger the fallback; a bare
# `except:` would also swallow KeyboardInterrupt and make the cell impossible
# to stop while it waits on a prompt.
try:
    from kaggle_secrets import UserSecretsClient
    # Named to avoid shadowing the stdlib `secrets` module in the notebook namespace
    secrets_client = UserSecretsClient()
    hf_token = secrets_client.get_secret("HF_TOKEN")
    login(token=hf_token)
    print("Logged in via Kaggle secrets")
except Exception:
    try:
        from google.colab import userdata
        hf_token = userdata.get('HF_TOKEN')
        login(token=hf_token)
        print("Logged in via Colab secrets")
    except Exception:
        login()  # Manual login

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset, concatenate_datasets, Dataset
import gc

# Report which accelerator this session is running on, plus its total memory.
cuda_ready = torch.cuda.is_available()
device_label = torch.cuda.get_device_name(0) if cuda_ready else 'None'
print(f"GPU: {device_label}")
if cuda_ready:
    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Memory: {total_mem_gb:.1f} GB")

# Configuration
BASE_MODEL = "sarvamai/sarvam-2b-v0.5"   # Hugging Face Hub id of the base model
OUTPUT_DIR = "./vazhi-sarvam-2b"          # trainer checkpoint directory
LORA_OUTPUT = "./vazhi-sarvam-lora"       # where the trained LoRA adapter is saved

## 2. Load Training Data

Combine:
1. IndicAlign Tamil subset (instruction-tuning)
2. VAZHI dataset (domain-specific)

In [None]:
# Load AI4Bharat IndicAlign
# Best-effort: any exception while loading or filtering leaves
# tamil_align = None, and later cells then train on the VAZHI data alone.
print("Loading IndicAlign dataset...")

try:
    # Try loading IndicAlign
    indic_align = load_dataset("ai4bharat/indic-align", split="train")
    print(f"IndicAlign loaded: {len(indic_align)} total examples")
    
    # Filter for Tamil
    # Keeps rows whose "language" or "lang" field equals "ta". If the dataset
    # has neither column, every row is dropped and the subset is empty.
    # NOTE(review): IndicAlign may be published as several named subsets and
    # may not carry a per-row language column — confirm the config name and
    # schema, otherwise this cell silently yields no Tamil data.
    tamil_align = indic_align.filter(lambda x: x.get("language", "") == "ta" or x.get("lang", "") == "ta")
    print(f"Tamil subset: {len(tamil_align)} examples")
except Exception as e:
    print(f"Could not load IndicAlign: {e}")
    print("Will use VAZHI data only")
    tamil_align = None

In [None]:
# Check IndicAlign structure
# Only runs when the Tamil subset was loaded and has at least one row.
has_tamil_rows = tamil_align is not None and len(tamil_align) > 0
if has_tamil_rows:
    print("IndicAlign columns:", tamil_align.column_names)
    print("\nSample entry:")
    first_row = tamil_align[0]
    print(first_row)

In [None]:
# Load VAZHI dataset
# Rows carry a "text" column (previewed below and parsed by format_vazhi).
print("\nLoading VAZHI dataset...")
vazhi_data = load_dataset("CryptoYogi/vazhi-tamil-v05", split="train")
print(f"VAZHI loaded: {len(vazhi_data)} examples")
print("\nSample:")
# Show only the first 500 characters of the first row
print(vazhi_data[0]["text"][:500])

In [None]:
# Format IndicAlign for Sarvam (if available)
def format_indic_align(example):
    """Format one IndicAlign row into the "### Instruction / ### Response" prompt.

    The column layout is not fixed here, so the instruction is read from
    ``instruction`` (falling back to ``input``) and the response from
    ``response`` (falling back to ``output``).  A key that is missing, ``None``
    or blank falls through to the next candidate, so a row never renders the
    literal string ``"None"`` into the training text.

    Args:
        example: One dataset row (mapping of column name to value).

    Returns:
        ``{"text": prompt}``; a section is empty when none of its candidate
        columns has content.
    """
    def pick(*keys):
        # First non-blank value among `keys`, stringified and stripped; "" if none.
        for key in keys:
            value = example.get(key)
            if value is not None and str(value).strip():
                return str(value).strip()
        return ""

    instruction = pick("instruction", "input")
    response = pick("response", "output")

    # Sarvam uses a simple format
    text = f"""### Instruction:
{instruction}

### Response:
{response}"""
    return {"text": text}

# Format VAZHI (already in ChatML, convert to Sarvam format)
def format_vazhi(example):
    """Convert a VAZHI ChatML row into the "### Instruction / ### Response" prompt.

    Only the text of the first ``user`` turn and the first ``assistant`` turn
    is kept (a system turn, if any, is dropped).  Rows that do not contain both
    a user and an assistant tag are returned unchanged, i.e. still in ChatML.

    Args:
        example: Dataset row whose ``"text"`` column holds a ChatML string.

    Returns:
        ``{"text": converted_or_original_text}``.
    """
    text = example["text"]
    user_tag = "<|im_start|>user"
    assistant_tag = "<|im_start|>assistant"
    end_tag = "<|im_end|>"

    # If parsing is not possible, return as-is
    if user_tag not in text or assistant_tag not in text:
        return {"text": text}

    # Both tags are present, so split(...)[1] always exists; no try/except needed.
    user_part = text.split(user_tag)[1].split(end_tag)[0].strip()
    assistant_part = text.split(assistant_tag)[1].split(end_tag)[0].strip()

    formatted = f"""### Instruction:
{user_part}

### Response:
{assistant_part}"""
    return {"text": formatted}

In [None]:
# Process datasets
print("Formatting datasets...")

# Upper bound on IndicAlign rows mixed in, so general instruction data does
# not overwhelm VAZHI's domain knowledge.
MAX_INDIC_EXAMPLES = 50000


def _has_instruction_and_response(example):
    """Return True if a formatted example has a non-empty instruction AND response."""
    head, sep, response = example["text"].partition("### Response:")
    instruction = head.replace("### Instruction:", "", 1)
    return bool(sep) and bool(instruction.strip()) and bool(response.strip())


# Format VAZHI
vazhi_formatted = vazhi_data.map(format_vazhi, remove_columns=vazhi_data.column_names)
print(f"VAZHI formatted: {len(vazhi_formatted)} examples")

# Format IndicAlign if available
indic_formatted = None
if tamil_align is not None and len(tamil_align) > 0:
    indic_formatted = tamil_align.map(format_indic_align, remove_columns=tamil_align.column_names)
    # Rows whose columns did not match format_indic_align's lookups render as an
    # empty template; training on them would reward empty answers, so drop them.
    n_before_filter = len(indic_formatted)
    indic_formatted = indic_formatted.filter(_has_instruction_and_response)
    print(f"IndicAlign formatted: {len(indic_formatted)} examples "
          f"({n_before_filter - len(indic_formatted)} empty rows dropped)")

if indic_formatted is not None and len(indic_formatted) > 0:
    # Combine datasets
    # Limit IndicAlign to avoid overwhelming VAZHI's domain knowledge
    n_indic = min(MAX_INDIC_EXAMPLES, len(indic_formatted))
    indic_sample = indic_formatted.shuffle(seed=42).select(range(n_indic))
    combined_data = concatenate_datasets([vazhi_formatted, indic_sample])
    print(f"\nCombined: {len(combined_data)} examples")
else:
    combined_data = vazhi_formatted
    print(f"\nUsing VAZHI only: {len(combined_data)} examples")

# Shuffle
train_dataset = combined_data.shuffle(seed=42)
print(f"\nFinal training dataset: {len(train_dataset)} examples")

In [None]:
# Preview the first training example (after shuffling), truncated to 800 chars
print("Sample formatted entry:")
first_example = train_dataset[0]
print(first_example["text"][:800])

## 3. Load Sarvam-2B Model

In [None]:
# Load the base model with float16 weights (deliberately NOT 4-bit, for training stability)
print(f"Loading {BASE_MODEL} in float16...")

load_kwargs = dict(torch_dtype=torch.float16, device_map="auto", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, **load_kwargs)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)

# If the tokenizer defines no pad token, reuse EOS so batches can be padded;
# padding goes on the right.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

print(f"Model loaded! Parameters: {model.num_parameters():,}")
print(f"Vocab size: {tokenizer.vocab_size:,}")

## 4. Test Base Model (Before Training)

In [None]:
# Tamil smoke-test prompts, asked both before and after fine-tuning so the
# answers can be compared.
TEST_QUESTIONS = [
    "வணக்கம், நீங்கள் யார்?",  # "Hello, who are you?"
    "திருக்குறளின் முதல் குறள் என்ன?",  # "What is the first kural of the Thirukkural?"
    "தமிழ்நாட்டின் தலைநகரம் எது?",  # "What is the capital of Tamil Nadu?"
]

def test_model(model, tokenizer, question, max_tokens=150):
    """Sample an answer to *question* using the training prompt format.

    The model is switched to eval mode for generation and then restored to the
    mode it was in. This lets the function be called both before and after
    training without side effects. In train mode, LoRA dropout would be active
    while sampling, and Transformers models with gradient checkpointing enabled
    turn off the KV cache, which makes generation very slow.

    Args:
        model: Causal LM (base or PEFT-wrapped) supporting ``generate``.
        tokenizer: Tokenizer matching *model*.
        question: Question placed in the ``### Instruction:`` section.
        max_tokens: Maximum number of new tokens to sample.

    Returns:
        The generated answer text only (never the prompt). The text is cut where
        the model starts a new ``### Instruction:`` block and stripped of
        surrounding whitespace.
    """
    prompt = f"""### Instruction:
{question}

### Response:
"""

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.2,
            )
    finally:
        if was_training:
            model.train()

    # Decode only the newly generated tokens so the prompt never leaks into the answer
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    # The prompt format holds one instruction/response pair; anything after a
    # new instruction header is continuation, not the answer.
    response = response.split("### Instruction:")[0]
    # If the model echoed a response header, keep only what follows it.
    if "### Response:" in response:
        response = response.split("### Response:")[-1]
    return response.strip()

# Baseline answers from the untuned model
separator = "=" * 60
print(separator)
print("BEFORE TRAINING - Base Sarvam-2B")
print(separator)

for question in TEST_QUESTIONS:
    print(f"\nQ: {question}")
    print(f"A: {test_model(model, tokenizer, question)}")

## 5. Configure LoRA (Conservative Settings)

In [None]:
# Deliberately conservative LoRA settings (lessons from the failed Qwen run)
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,                                  # low rank (the Qwen run used 32 — too aggressive)
    lora_alpha=16,                        # alpha / r = 2
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "v_proj"],  # query/value attention projections only
)

# Make the input embeddings' outputs require grad so gradients can reach the
# LoRA weights through checkpointed layers (gradient checkpointing is enabled
# in the training config).
model.enable_input_require_grads()
model = get_peft_model(model, lora_config)

# Count trainable vs. total parameters in one pass over the model
trainable_params = all_params = 0
for param in model.parameters():
    n_elems = param.numel()
    all_params += n_elems
    if param.requires_grad:
        trainable_params += n_elems
trainable_share = 100 * trainable_params / all_params
print(f"Trainable: {trainable_params:,} / {all_params:,} = {trainable_share:.4f}%")

## 6. Training Configuration

In [None]:
import dataclasses

from trl import SFTConfig, SFTTrainer

# Maximum number of tokens per training example.
MAX_SEQ_LENGTH = 1024

# TRL renamed SFTConfig's `max_seq_length` to `max_length` and later dropped the
# old name, so passing `max_seq_length` fails on current releases. Use whichever
# field the installed version actually defines.
_sft_fields = {field.name for field in dataclasses.fields(SFTConfig)}
_seq_len_arg = "max_length" if "max_length" in _sft_fields else "max_seq_length"

# CONSERVATIVE training settings
sft_config = SFTConfig(
    output_dir=OUTPUT_DIR,
    num_train_epochs=2,              # Start with 2 epochs
    per_device_train_batch_size=2,   # Small batch for stability
    gradient_accumulation_steps=8,   # Effective batch = 16 per device
    learning_rate=1e-5,              # VERY LOW (was 5e-5)
    weight_decay=0.01,
    warmup_ratio=0.1,                # 10% warmup
    logging_steps=25,
    save_steps=200,
    save_total_limit=3,
    fp16=True,                       # Float16 mixed-precision training
    optim="adamw_torch",             # Standard optimizer
    lr_scheduler_type="cosine",
    report_to="none",
    gradient_checkpointing=True,
    max_grad_norm=0.3,               # Gradient clipping!
    **{_seq_len_arg: MAX_SEQ_LENGTH},
)


def formatting_func(example):
    """Return the pre-formatted prompt text of a training row."""
    return example["text"]


trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=train_dataset,
    processing_class=tokenizer,
    formatting_func=formatting_func,
)

print("Trainer configured!")
print(f"  Learning rate: {sft_config.learning_rate}")
print(f"  LoRA rank: {lora_config.r}")
print(f"  Gradient clipping: {sft_config.max_grad_norm}")
print(f"  Epochs: {sft_config.num_train_epochs}")
print(f"  Max sequence length: {MAX_SEQ_LENGTH} (via {_seq_len_arg})")

## 7. Train!

In [None]:
# Rules of thumb to watch in the logged loss while training runs
training_notes = (
    "Starting training...",
    "Watch for:",
    "  - Loss should decrease steadily",
    "  - If loss > 3.0, something is wrong",
    "  - If loss jumps suddenly, stop and reduce LR",
)
for note in training_notes:
    print(note)

trainer.train()

print("\nTraining complete!")

In [None]:
# Save the LoRA adapter and the tokenizer into the same directory; the merge
# cell later reloads the adapter from LORA_OUTPUT.
for artifact in (model, tokenizer):
    artifact.save_pretrained(LORA_OUTPUT)
print(f"LoRA saved to {LORA_OUTPUT}")

## 8. Test After Training

In [None]:
# Same smoke-test prompts, now answered by the fine-tuned model
header_rule = "=" * 60
print(header_rule)
print("AFTER TRAINING - Fine-tuned Sarvam-2B")
print(header_rule)

for prompt_text in TEST_QUESTIONS:
    print(f"\nQ: {prompt_text}")
    print(f"A: {test_model(model, tokenizer, prompt_text)}")

In [None]:
# Domain-specific probes for the fine-tuned model
DOMAIN_QUESTIONS = [
    "ஔவையாரின் ஆத்திசூடி பற்றி சொல்லுங்கள்",  # "Tell me about Avvaiyar's Aathichoodi"
    "OTP யாரிடமும் சொல்லலாமா?",  # "Can I share an OTP with anyone?"
    "CMCHIS என்றால் என்ன?",  # "What is CMCHIS?"
]

print("\nDomain-specific tests:")
for domain_q in DOMAIN_QUESTIONS:
    print(f"\nQ: {domain_q}")
    answer = test_model(model, tokenizer, domain_q)
    print(f"A: {answer}")

## 9. Merge and Save Full Model

In [None]:
# Clear memory
del trainer
gc.collect()
torch.cuda.empty_cache()

# Reload base model for merging
print("Reloading base model for merging...")
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)

# Load LoRA adapter
model = PeftModel.from_pretrained(base_model, LORA_OUTPUT)

# Merge
print("Merging LoRA weights...")
merged_model = model.merge_and_unload()

# Save merged model
MERGED_OUTPUT = "./vazhi-sarvam-merged"
merged_model.save_pretrained(MERGED_OUTPUT, safe_serialization=True)
tokenizer.save_pretrained(MERGED_OUTPUT)
print(f"Merged model saved to {MERGED_OUTPUT}")

!ls -lh {MERGED_OUTPUT}

## 10. Convert to GGUF

In [None]:
# Setup llama.cpp
# Clones llama.cpp, which provides convert_hf_to_gguf.py and the quantize/CLI
# tools used below, and installs the converter's Python requirements.
# NOTE(review): `git clone` errors if ./llama.cpp already exists, e.g. when the
# cell is re-run.
!git clone https://github.com/ggerganov/llama.cpp.git
!cd llama.cpp && pip install -q -r requirements.txt

In [None]:
# Convert to GGUF F16
# Converts the merged HF checkpoint in MERGED_OUTPUT (IPython substitutes the
# Python variable into the shell command) into one float16 GGUF file, which
# the quantization cell reads.
print("Converting to GGUF F16...")
!python llama.cpp/convert_hf_to_gguf.py {MERGED_OUTPUT} --outfile vazhi-sarvam-f16.gguf --outtype f16
!ls -lh vazhi-sarvam-f16.gguf

In [None]:
# Build quantize tool
# Configures llama.cpp with CMake in llama.cpp/build and compiles only the
# llama-quantize target. The binary is invoked from llama.cpp/build/bin/ below.
!cd llama.cpp && mkdir -p build && cd build && cmake .. && make -j4 llama-quantize

In [None]:
# Quantize to Q8_0
# Both quantized files are produced from the F16 GGUF, not from each other.
print("Quantizing to Q8_0...")
!./llama.cpp/build/bin/llama-quantize vazhi-sarvam-f16.gguf vazhi-sarvam-q8_0.gguf q8_0

# Quantize to Q4_K_M (target for mobile)
print("\nQuantizing to Q4_K_M...")
!./llama.cpp/build/bin/llama-quantize vazhi-sarvam-f16.gguf vazhi-sarvam-q4_k_m.gguf q4_k_m

print("\nAll GGUF files:")
!ls -lh vazhi-sarvam-*.gguf

## 11. Test GGUF Output Quality

**Critical**: Does Tamil survive quantization?

In [None]:
# Build llama-cli
# Reuses the CMake build directory configured in the quantize-tool cell, so
# that cell must have run first.
!cd llama.cpp && cd build && make -j4 llama-cli

In [None]:
# Test Q4_K_M
# Runs the quantized model on CPU only (-ngl 0: no layers offloaded to GPU),
# using the same "### Instruction / ### Response" prompt format as training.
# NOTE(review): recent llama.cpp builds may start llama-cli in interactive
# conversation mode when the GGUF carries a chat template, which would block
# this cell — confirm the behavior of the installed version.
print("=" * 60)
print("GGUF Q4_K_M TEST")
print("=" * 60)

test_prompt = """### Instruction:
திருக்குறளின் முதல் குறள் என்ன?

### Response:
"""

!./llama.cpp/build/bin/llama-cli -m vazhi-sarvam-q4_k_m.gguf \
    -p "{test_prompt}" \
    -n 150 --temp 0.7 -ngl 0

In [None]:
# Test more questions
# Same CPU-only llama-cli run for each question. stderr is discarded and only
# the last 20 lines of stdout are shown.
print("\nTesting additional questions...")

for q in ["வணக்கம், நீங்கள் யார்?", "தமிழ்நாட்டின் தலைநகரம் எது?"]:
    prompt = f"""### Instruction:
{q}

### Response:
"""
    print(f"\nQ: {q}")
    !./llama.cpp/build/bin/llama-cli -m vazhi-sarvam-q4_k_m.gguf -p "{prompt}" -n 150 --temp 0.7 -ngl 0 2>/dev/null | tail -20

## 12. Upload to HuggingFace

In [None]:
from huggingface_hub import HfApi, create_repo

api = HfApi()
GGUF_REPO = "CryptoYogi/vazhi-sarvam-gguf"

# Create the target model repo (no error if it already exists)
create_repo(GGUF_REPO, repo_type="model", exist_ok=True)
print(f"Repository: {GGUF_REPO}")

# (progress message, local file) pairs; each file keeps its name in the repo
gguf_uploads = [
    ("\nUploading Q4_K_M...", "vazhi-sarvam-q4_k_m.gguf"),
    ("Uploading Q8_0...", "vazhi-sarvam-q8_0.gguf"),
]
for progress_msg, gguf_file in gguf_uploads:
    print(progress_msg)
    api.upload_file(
        path_or_fileobj=gguf_file,
        path_in_repo=gguf_file,
        repo_id=GGUF_REPO,
    )

print(f"\nDone! Models at: https://huggingface.co/{GGUF_REPO}")

## 13. Summary

In [None]:
# Print a recap of the run followed by the actual GGUF file sizes.
# NOTE: the settings listed here are hard-coded text, not read from
# lora_config / sft_config — keep them in sync if those change.
print("""
================================================================
VAZHI v0.6 TRAINING SUMMARY
================================================================

Base Model: Sarvam-2B-v0.5
Training Data: IndicAlign (Tamil) + VAZHI domain data

Key Settings (Conservative):
- LoRA rank: 8 (not 32)
- Learning rate: 1e-5 (not 5e-5)
- Gradient clipping: 0.3
- Float16 training (not 4-bit)

GGUF File Sizes:
""")
!ls -lh vazhi-sarvam-*.gguf

# The "Expected" sizes below are rough targets for comparison, not measurements.
print("""
Expected:
- F16:     ~4GB
- Q8_0:    ~2GB  
- Q4_K_M:  ~1.2GB  <-- Target for mobile!

Next Steps:
1. Test GGUF quality (Tamil coherence)
2. If good, integrate with VAZHI app
3. If bad, may need larger model or different approach
""")