# VishwamAI Distillation Training

This notebook demonstrates training VishwamAI through knowledge distillation from Google's Gemma-3b-it model.

## Setup and Imports

In [None]:
import jax
import jax.numpy as jnp
from transformers import AutoModelForCausalLM, AutoTokenizer
from vishwamai import VishwamAI, VishwamAITokenizer, DistillationTrainer
from vishwamai.training import TPUTrainingConfig
import torch

# Report how many devices JAX sees and list them.
# NOTE(review): jax.device_count() is not restricted to TPUs here, so the
# "TPU" label presumably assumes the notebook runs on a TPU runtime — confirm.
print(f"Number of TPU devices: {jax.device_count()}")
print(f"JAX devices: {jax.devices()}")

## Configuration

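Set up the configuration for the student model (targeting roughly 1B parameters), together with the training, optimization, and distillation settings.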

In [None]:
# Student architecture (the notebook targets a ~1B-parameter model)
model_settings = {
    "vocab_size": 131072,
    "hidden_dim": 2048,  # Scaled for 1B params
    "num_layers": 24,
    "num_heads": 16,
    "head_dim": 128,
    "mlp_dim": 8192,
    "max_seq_len": 2048,
    "dropout_rate": 0.1,
}

# Optimizer / schedule hyperparameters; unpacked into TPUTrainingConfig later
training_settings = {
    "batch_size": 16,
    "grad_accum_steps": 4,
    "learning_rate": 1e-4,
    "warmup_steps": 2000,
    "max_steps": 100000,
    "weight_decay": 0.01,
    "max_grad_norm": 1.0,
}

# Numeric-precision and parallelism switches
optimization_settings = {
    "use_fp8": True,
    "use_pjit": True,
    "block_size": 128,
    "mixed_precision": True,
}

# Knowledge-distillation knobs (temperature and alpha are read by the trainer setup)
distillation_settings = {
    "temperature": 2.0,
    "alpha": 0.5,
    "use_intermediate_distillation": True,
}

# Single nested configuration consumed by the rest of the notebook
config = {
    "model": model_settings,
    "training": training_settings,
    "optimization": optimization_settings,
    "distillation": distillation_settings,
}

## Load Teacher Model (Gemma-3b-it)

In [None]:
# One checkpoint id shared by the model and the tokenizer so the two cannot drift apart.
# NOTE(review): confirm this repository id exists on the Hugging Face Hub.
TEACHER_MODEL_ID = "google/gemma-3b-it"

# Let Transformers place the weights automatically and load them in bfloat16.
teacher_load_kwargs = {
    "device_map": "auto",
    "torch_dtype": torch.bfloat16,
}

teacher_model = AutoModelForCausalLM.from_pretrained(TEACHER_MODEL_ID, **teacher_load_kwargs)
teacher_tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL_ID)

## Initialize Student Model

In [None]:
# Tokenizer checkpoint used for the student
STUDENT_TOKENIZER_ID = "vishwamai/base"

# Build the student from the full nested config, then load its tokenizer
student_model = VishwamAI.from_config(config)
student_tokenizer = VishwamAITokenizer.from_pretrained(STUDENT_TOKENIZER_ID)

## Prepare Training Data

In [None]:
from datasets import load_dataset
import numpy as np

# Corpus used for distillation (RedPajama sample)
DATASET_NAME = "togethercomputer/RedPajama-Data-1T-Sample"
dataset = load_dataset(DATASET_NAME)

def prepare_data(examples):
    """Tokenize a batch of texts and attach the teacher's logits for them.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        examples: Batch mapping with a ``"text"`` entry holding the raw strings.

    Returns:
        dict with ``"input_ids"`` and ``"attention_mask"`` from the student
        tokenizer, and ``"teacher_logits"`` as a float32 NumPy array.

    NOTE(review): the teacher is fed ids produced by the *student* tokenizer.
    That is only meaningful if both models share a vocabulary — confirm.
    NOTE(review): logits over the teacher's full vocabulary are returned for
    every position of every example; confirm the resulting dataset size is
    acceptable.
    """
    # Tokenize input text, padded/truncated to the student's context length
    student_encodings = student_tokenizer(
        examples["text"],
        truncation=True,
        max_length=config["model"]["max_seq_len"],
        padding="max_length",
        return_tensors="np"
    )
    input_ids = student_encodings["input_ids"]
    attention_mask = student_encodings["attention_mask"]

    # Get teacher logits
    teacher_device = teacher_model.device
    with torch.no_grad():
        teacher_outputs = teacher_model(
            input_ids=torch.tensor(input_ids).to(teacher_device),
            # Inputs are padded to max_length, so the mask is required to keep
            # the teacher from attending to padding positions.
            attention_mask=torch.tensor(attention_mask).to(teacher_device),
        )
        # The teacher is loaded in bfloat16, which NumPy cannot represent;
        # upcast to float32 before converting.
        teacher_logits = teacher_outputs.logits.float().cpu().numpy()

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "teacher_logits": teacher_logits
    }

# Prepare datasets
# Not every dataset ships a "validation" split; indexing it blindly raises
# KeyError. Fall back to a small, seeded hold-out carved from "train".
if "validation" in dataset:
    train_source = dataset["train"]
    eval_source = dataset["validation"]
else:
    holdout = dataset["train"].train_test_split(test_size=0.01, seed=42)
    train_source = holdout["train"]
    eval_source = holdout["test"]

train_dataset = train_source.map(
    prepare_data,
    batched=True,
    batch_size=8
)

eval_dataset = eval_source.map(
    prepare_data,
    batched=True,
    batch_size=8
)

## Setup Training

In [None]:
# Build the TPU training config from the "training" section of `config`
training_config = TPUTrainingConfig(**config["training"])

# Distillation hyperparameters forwarded to the trainer
distill_settings = config["distillation"]

trainer = DistillationTrainer(
    student_model=student_model,
    teacher_logits_dim=teacher_model.config.vocab_size,
    training_config=training_config,
    temperature=distill_settings["temperature"],
    alpha=distill_settings["alpha"],
)

# Data loaders are created by the trainer from the mapped datasets
train_loader = trainer.get_train_dataloader(train_dataset)
eval_loader = trainer.get_eval_dataloader(eval_dataset)

## Training Loop

In [None]:
# Training with progress tracking
from tqdm.auto import tqdm
import time

NUM_EPOCHS = 3  # Train for 3 epochs

print("Starting training...")
start_time = time.time()

for epoch in range(NUM_EPOCHS):
    # One pass over the training data, then score on the eval loader
    trainer.train_epoch(train_loader)
    metrics = trainer.evaluate(eval_loader)
    print(f"Epoch {epoch+1} - Eval loss: {metrics['eval_loss']:.4f}")

    # Checkpoint names are 1-based to match the printed epoch number
    checkpoint_name = f"checkpoint_epoch_{epoch+1}"
    trainer.save_checkpoint(checkpoint_name)

training_time = time.time() - start_time
print(f"\nTraining completed in {training_time/3600:.2f} hours")

## Evaluation

In [None]:
# Compare teacher and student model outputs
def compare_models(prompt: str, max_length: int = 100, temperature: float = 0.7):
    """Generate a completion for ``prompt`` with both models and print them.

    Args:
        prompt: Text passed to both tokenizers.
        max_length: Forwarded as ``max_length`` to both ``generate`` calls.
        temperature: Forwarded as ``temperature`` to both ``generate`` calls.

    Returns:
        None. The decoded teacher and student outputs are printed.
    """
    # Teacher model generation
    teacher_inputs = teacher_tokenizer(prompt, return_tensors="pt").to(teacher_model.device)
    teacher_outputs = teacher_model.generate(
        **teacher_inputs,
        max_length=max_length,
        # Transformers only applies `temperature` when sampling is enabled;
        # without do_sample=True the value is ignored and decoding is greedy.
        do_sample=True,
        temperature=temperature,
        num_return_sequences=1
    )
    teacher_response = teacher_tokenizer.decode(teacher_outputs[0], skip_special_tokens=True)

    # Student model generation
    student_inputs = student_tokenizer(prompt, return_tensors="jax")
    student_outputs = student_model.generate(
        student_inputs["input_ids"],
        max_length=max_length,
        temperature=temperature
    )
    student_response = student_tokenizer.decode(student_outputs[0], skip_special_tokens=True)

    print("Teacher output:")
    print(teacher_response)
    print("\nStudent output:")
    print(student_response)

# Prompts used for a side-by-side look at teacher vs. student generations
test_prompts = [
    "Explain how a computer processor works:",
    "Write a short poem about artificial intelligence:",
    "What are the key principles of machine learning?",
]

# Visual divider printed under each prompt header
separator = "-" * 50

for prompt in test_prompts:
    print(f"\nPrompt: {prompt}")
    print(separator)
    compare_models(prompt)

## Save Final Model

In [None]:
# Write the student model and then its tokenizer into the same directory
output_dir = "vishwamai_1b_distilled"
for artifact in (student_model, student_tokenizer):
    artifact.save_pretrained(output_dir)
print(f"Model saved to {output_dir}")