# VAZHI Quick Validation Test

**Purpose**: Validate training setup in 5-10 minutes before committing to hours of training.

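**Key Fixes from the Failed Run:**
1. Match the training prompt format to the one the base model expects (assumed here to be Alpaca-style; Step 1 checks this by comparing several prompt formats)
2. Disable packing (without per-sequence attention masking, e.g. as provided by flash attention, packed samples can attend to each other and cross-contaminate)
3. Lower the learning rate (to prevent catastrophic forgetting)
4. Train on a small subset first (50 samples)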

In [None]:
# CRITICAL: Force single GPU to avoid DataParallel issues
# Set this before anything initialises CUDA in this kernel - that is why it runs
# before `import torch` (next cell). Changing it after CUDA has been initialised
# typically has no effect on the running process.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Install dependencies
# `!pip` is IPython shell magic, so this cell only runs inside a notebook.
# NOTE(review): transformers is not in this list - presumably preinstalled in the
# runtime; verify, or add it here.
!pip install -q bitsandbytes peft trl accelerate datasets

In [None]:
import torch

# Environment sanity check: torch version, CUDA availability and GPU name.
_has_cuda = torch.cuda.is_available()
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {_has_cuda}")
if _has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Step 1: Check Base Model's Expected Format

First, let's see what format the pre-trained model expects.

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "abhinand/gemma-2b-it-tamil-v0.1-alpha"

# Load only the tokenizer first, so its special tokens and chat template can be
# inspected before the (much larger) model weights are loaded.
tokenizer = AutoTokenizer.from_pretrained(model_name)

rule = "=" * 50
print(rule)
print("TOKENIZER INFO:")
print(rule)
print(f"Vocab size: {tokenizer.vocab_size}")
for tok_label, tok_text, tok_id in (
    ("EOS", tokenizer.eos_token, tokenizer.eos_token_id),
    ("BOS", tokenizer.bos_token, tokenizer.bos_token_id),
    ("PAD", tokenizer.pad_token, tokenizer.pad_token_id),
):
    print(f"{tok_label} token: {tok_text} (id: {tok_id})")

# A missing attribute and an empty/None template are both reported as "no template".
chat_tpl = getattr(tokenizer, 'chat_template', None)
if chat_tpl:
    print("\nChat template exists: YES")
    print(f"Template: {chat_tpl[:200]}...")
else:
    print("\nChat template: NO (use simple format)")

In [None]:
# Load model in 4-bit (NF4 weights with double quantization, bfloat16 compute)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Keep the tokenizer's own pad token when it has one; fall back to EOS only
# when it is missing. Aliasing pad to EOS means any code that masks labels by
# pad_token_id may also mask the real EOS ending each response, so the model
# may never learn where to stop generating.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # right-pad for causal-LM training

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map={"": 0},  # whole model on GPU 0 (the only one made visible above)
    # NOTE(review): this allows executing Python code shipped in the model repo.
    # Gemma is natively supported by transformers, so it is probably not needed
    # for this checkpoint - confirm and drop it if so.
    trust_remote_code=True,
)

print(f"Model loaded! Memory: {model.get_memory_footprint() / 1e9:.2f} GB")

In [None]:
# Test different prompt formats to see which works best
def test_format(prompt_text):
    """Sample up to 50 new tokens after ``prompt_text`` and return the decoded text.

    The returned string is the prompt followed by the sampled continuation,
    with special tokens removed. Relies on the notebook-level ``model`` and
    ``tokenizer``.
    """
    encoded = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=50,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

question = "தமிழ்நாட்டின் தலைநகரம் எது?"

# Candidate prompt layouts, compared on the same question.
alpaca_prompt = f"### Instruction:\n{question}\n\n### Response:\n"
simple_prompt = f"Question: {question}\nAnswer:"
direct_prompt = f"{question}"

for section_no, (section_title, candidate) in enumerate((
    ("FORMAT 1: Alpaca Style", alpaca_prompt),
    ("FORMAT 2: Simple", simple_prompt),
    ("FORMAT 3: Direct", direct_prompt),
)):
    # Blank line before every section header except the first.
    leading = "\n" if section_no else ""
    print(leading + "=" * 50)
    print(section_title)
    print("=" * 50)
    print(f"Prompt: {candidate}")
    print(f"Output: {test_format(candidate)}")

## Step 2: Prepare Training Data with CORRECT Format

We need to convert our data to match the base model's expected format.

In [None]:
from datasets import load_dataset

# Fetch the full dataset; later cells only train on a small slice of it.
dataset = load_dataset("CryptoYogi/vazhi-tamil-v05")
train_split = dataset['train']
print(f"Full dataset: {len(train_split)} samples")

# Peek at the first row to see which columns the conversion step must handle.
divider = "=" * 50
print("\n" + divider)
print("ORIGINAL DATA FORMAT:")
print(divider)
sample = train_split[0]
print(f"Keys: {sample.keys()}")
for column, heading, width in (
    ('text', "Text format", 300),
    ('instruction', "Instruction", 100),
    ('output', "Output", 100),
):
    if column in sample:
        print(f"{heading}: {sample[column][:width]}...")

In [None]:
# Use SMALL SUBSET for quick validation
VALIDATION_SIZE = 50  # Just 50 samples for quick test

import re

# ChatML turn extractors, compiled once rather than importing `re` and
# recompiling inside the per-row function. The first matching user and
# assistant turns are used. The assistant turn is also accepted when it runs to
# the end of the text without a closing <|im_end|> (such rows used to be
# silently dropped).
_CHATML_USER_RE = re.compile(r'<\|im_start\|>user\n(.+?)<\|im_end\|>', re.DOTALL)
_CHATML_ASSISTANT_RE = re.compile(
    r'<\|im_start\|>assistant\n(.+?)(?:<\|im_end\|>|\Z)', re.DOTALL
)


def _alpaca_text(instruction, output):
    """Render one instruction/response pair in the Alpaca training layout."""
    return f"### Instruction:\n{instruction}\n\n### Response:\n{output}"


# Format function - convert each row to the Alpaca layout used for training
def format_to_alpaca(example):
    """Convert one dataset row into a single Alpaca-formatted ``text`` field.

    Input shapes are tried in this order:
    1. ``text`` containing a ChatML conversation: the first user/assistant
       turns are re-rendered in Alpaca layout.
    2. ``text`` already containing ``### Instruction``: passed through as-is.
    3. Separate ``instruction`` / ``output`` columns: rendered in Alpaca layout.

    Args:
        example: One dataset row (mapping of column name to value).

    Returns:
        ``{"text": formatted}``. ``formatted`` is the empty string when the row
        cannot be converted, so the downstream length filter drops it.
    """
    text = example['text'] if 'text' in example else None
    if text:
        if '<|im_start|>' in text:
            user_match = _CHATML_USER_RE.search(text)
            assistant_match = _CHATML_ASSISTANT_RE.search(text)
            if user_match and assistant_match:
                return {"text": _alpaca_text(user_match.group(1).strip(),
                                             assistant_match.group(1).strip())}

        # Already in Alpaca layout: keep unchanged.
        if '### Instruction' in text:
            return {"text": text}

    # Separate instruction/output columns. Missing (None) values are treated as
    # unconvertible instead of being rendered as the literal string "None".
    instruction = example['instruction'] if 'instruction' in example else None
    output = example['output'] if 'output' in example else None
    if instruction is not None and output is not None:
        return {"text": _alpaca_text(instruction, output)}

    # Unconvertible row.
    return {"text": ""}

# Take small subset. Clamp to the dataset length so a split shorter than
# VALIDATION_SIZE doesn't make .select() fail with an out-of-range index.
subset_size = min(VALIDATION_SIZE, len(dataset['train']))
small_train = dataset['train'].select(range(subset_size))

# Format to Alpaca, then drop rows the converter could not handle
# (format_to_alpaca returns an empty "text" for those).
formatted_train = small_train.map(format_to_alpaca)
formatted_train = formatted_train.filter(lambda x: len(x['text']) > 10)

# Fail with a clear message instead of an IndexError on formatted_train[0].
if len(formatted_train) == 0:
    raise ValueError(
        "No samples survived Alpaca conversion; compare the dataset columns "
        "printed above with what format_to_alpaca expects."
    )

print(f"Validation subset: {len(formatted_train)} samples")
print("\nFormatted sample:")
print(formatted_train[0]['text'][:400])

## Step 3: Add LoRA and Quick Training

In [None]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# prepare_model_for_kbit_training enables gradient checkpointing itself, so a
# separate model.gradient_checkpointing_enable() call is redundant. Pass the
# same non-reentrant setting that the SFTConfig below uses, so both agree.
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

lora_config = LoraConfig(
    r=8,  # Smaller rank for quick test
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # attention projections only
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

In [None]:
from trl import SFTTrainer, SFTConfig
import time

# CONSERVATIVE settings for validation
training_config = SFTConfig(
    output_dir="./vazhi-validation",

    # Dataset
    dataset_text_field="text",
    max_length=256,  # Shorter for quick test
    packing=False,   # DISABLED - was causing issues!

    # Small batch for quick test
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,

    # LOWER learning rate to prevent catastrophic forgetting
    learning_rate=5e-5,  # Was 2e-4, now 5e-5 (4x lower)
    num_train_epochs=1,
    warmup_ratio=0.1,

    # Optimizer
    optim="paged_adamw_8bit",
    weight_decay=0.01,

    # Precision
    fp16=False,
    bf16=True,

    # Logging - more frequent for validation
    logging_steps=5,

    # Memory: recompute activations in the backward pass (non-reentrant variant)
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},

    seed=42,
    report_to="none",
)

trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=formatted_train,
    args=training_config,
)

# Report the values actually in the config rather than hard-coded copies,
# so this summary cannot drift from the real settings.
effective_batch = (training_config.per_device_train_batch_size
                   * training_config.gradient_accumulation_steps)
print(f"Quick validation: {len(formatted_train)} samples")
print(f"Learning rate: {training_config.learning_rate:g} (conservative)")
print(f"Packing: {'ENABLED' if training_config.packing else 'DISABLED'}")
print(f"Effective batch size: {effective_batch}")
print("Expected time: ~2-5 minutes")

In [None]:
# Quick test BEFORE training
def test_model(prompt, max_new_tokens=100):
    """Sample an answer to ``prompt`` using the same Alpaca template as training.

    Only the newly generated tokens are decoded, so the result is the model's
    response without the echoed prompt. Relies on the notebook-level ``model``
    and ``tokenizer``.

    Generation runs in eval mode, which disables dropout such as the LoRA
    dropout. The model's previous mode is restored afterwards, because it may
    be left in training mode after ``trainer.train()``.

    Args:
        prompt: The question/instruction text.
        max_new_tokens: Upper bound on generated tokens (default 100).

    Returns:
        The decoded response with surrounding whitespace stripped.
    """
    inputs = tokenizer(
        f"### Instruction:\n{prompt}\n\n### Response:\n",
        return_tensors="pt"
    ).to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )
    finally:
        model.train(was_training)  # restore the caller's mode

    # Slice off the prompt by token count instead of splitting the decoded text
    # on "### Response:", which breaks if the marker appears in the output.
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

print("=" * 60)
print("BEFORE TRAINING:")
print("=" * 60)
before_response = test_model("தமிழ்நாட்டின் தலைநகரம் எது?")
print("Q: தமிழ்நாட்டின் தலைநகரம் எது?")
print(f"A: {before_response[:200]}")

In [None]:
# Run quick training
print("Starting quick validation training...")
# perf_counter is monotonic, so wall-clock adjustments can't skew the timing.
start = time.perf_counter()

trainer_stats = trainer.train()

elapsed = time.perf_counter() - start
print("\nTraining complete!")
print(f"Time: {elapsed / 60:.1f} minutes")
# training_loss is the loss averaged over the whole run, not the last step's loss.
print(f"Average training loss: {trainer_stats.training_loss:.4f}")

In [None]:
# Test AFTER training - check if model still works
print("=" * 60)
print("AFTER TRAINING:")
print("=" * 60)

after_response = test_model("தமிழ்நாட்டின் தலைநகரம் எது?")
print("Q: தமிழ்நாட்டின் தலைநகரம் எது?")
print(f"A: {after_response[:200]}")

# Heuristic coherence check: share of characters in the Tamil Unicode block
# (U+0B80-U+0BFF) plus a minimum length, to catch empty or garbage output.
import re
MAX_AVG_LOSS = 1.5      # pass threshold on the average training loss
MIN_TAMIL_PCT = 20      # % of response characters that must be Tamil
MIN_RESPONSE_LEN = 20   # responses this short or shorter count as failures

tamil_chars = len(re.findall(r'[\u0B80-\u0BFF]', after_response))
total_chars = len(after_response)
tamil_pct = (tamil_chars / total_chars * 100) if total_chars > 0 else 0
is_coherent = tamil_pct > MIN_TAMIL_PCT and total_chars > MIN_RESPONSE_LEN

print("\n" + "=" * 60)
print("VALIDATION RESULTS:")
print("=" * 60)
print(f"Average training loss: {trainer_stats.training_loss:.4f}")
print(f"Tamil content: {tamil_pct:.1f}%")
print(f"Response coherent: {'YES' if is_coherent else 'NO - TRAINING FAILED'}")

# Pass requires a low loss AND the same coherence check reported above, so the
# two verdicts can no longer contradict each other.
if trainer_stats.training_loss < MAX_AVG_LOSS and is_coherent:
    print("\n✅ VALIDATION PASSED - Safe to run full training!")
else:
    print("\n❌ VALIDATION FAILED - Do not proceed with full training!")

## Validation Summary

**If validation PASSED:**
- Update the full training notebook with these fixes:
  1. Use Alpaca format (not ChatML)
  2. Disable packing
  3. Use learning rate 5e-5
  4. Run full training

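**If validation FAILED:**
- Check the format conversion (inspect the formatted sample printed in Step 2)
- Try an even lower learning rate (1e-5)
- Check whether the base model handles the Alpaca prompt format at all (compare the outputs from Step 1)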