# VishwamAI Model Distillation

This notebook demonstrates knowledge distillation from the teacher model `perplexity-ai/r1-1776` (the *guru*) into a smaller VishwamAI student model (the *shishya*).

In [None]:
import jax
import jax.numpy as jnp
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from datasets import load_dataset
import os
from tqdm import tqdm
from vishwamai.distillation import VishwamaiGuruKnowledge, VishwamaiShaalaTrainer
from vishwamai.model import VishwamAIModel, ModelConfig
from omegaconf import OmegaConf

# Show the devices visible to JAX's default backend before doing any work.
print("JAX devices:", jax.devices())

## Load Configuration

In [None]:
# Read the Perplexity-R1 -> VishwamAI distillation settings from YAML.
config = OmegaConf.load('vishwamai/configs/training/perplexity_r1_distillation.yaml')

# Echo the key settings so each run is self-documenting in the notebook output.
print(f"Teacher Model: {config.teacher_model.path}")
print(f"Student Model: {config.student_model.path}")

print("\nModel Architecture:")
print("Hidden Size:", config.student_model.hidden_size)
print("Num Layers:", config.student_model.num_layers)
print("Attention Heads:", config.student_model.num_heads)
print("KV Heads:", config.student_model.num_kv_heads)

print("\nDistillation Parameters:")
print("Temperature:", config.teacher_model.temperature)
print("Alpha:", config.teacher_model.alpha)
print("Feature Layers:", config.distillation.feature_distillation.layers)

## Initialize Teacher Model (Guru)

In [None]:
# Teacher (guru): tokenizer, config and causal-LM weights, all loaded from the
# same checkpoint path given in the config.
teacher_tokenizer = AutoTokenizer.from_pretrained(config.teacher_model.path)
teacher_config = AutoConfig.from_pretrained(config.teacher_model.path)
# torch_dtype='auto' keeps the checkpoint's dtype; device_map='auto' lets
# transformers decide device placement of the weights.
teacher_model = AutoModelForCausalLM.from_pretrained(
    config.teacher_model.path,
    torch_dtype='auto',
    device_map='auto',
)

print("Teacher model loaded:", teacher_config.model_type)

## Initialize Student Model (Shishya)

In [None]:
# Student (shishya): a VishwamAI model sized by the `student_model` config
# section. It reuses the teacher's vocab size so both models share token ids.
student_config = ModelConfig(
    # Vocabulary shared with the teacher tokenizer.
    vocab_size=teacher_config.vocab_size,
    # Width / depth / context length.
    hidden_size=config.student_model.hidden_size,
    intermediate_size=config.student_model.intermediate_size,
    num_layers=config.student_model.num_layers,
    max_position_embeddings=config.student_model.max_seq_len,
    # Attention layout: grouped-query attention with its own KV-head count.
    num_attention_heads=config.student_model.num_heads,
    num_key_value_heads=config.student_model.num_kv_heads,
    use_gqa=True,
    use_flash_attention=True,
)

student_model = VishwamAIModel(student_config)
print("Student model initialized with", student_config.num_layers, "layers")

## Prepare Training Data

In [None]:
# Fetch the 'train' split of the dataset named by `config.data.path`.
train_dataset = load_dataset(config.data.path, split='train')

# Causal-LM tokenizers often ship without a pad token, and padding='max_length'
# below raises a ValueError in that case. Fall back to EOS so fixed-length
# batching works; this is a no-op when a pad token is already defined.
if teacher_tokenizer.pad_token is None:
    teacher_tokenizer.pad_token = teacher_tokenizer.eos_token


def preprocess_function(examples):
    """Tokenize a batch of raw text with the teacher tokenizer.

    Args:
        examples: A batch passed by ``Dataset.map(batched=True)``; its
            ``'text'`` column holds the raw strings.

    Returns:
        The tokenizer's batch encoding (e.g. ``input_ids`` /
        ``attention_mask``), with every sequence truncated and padded to the
        student's ``max_seq_len``.
    """
    # Use the student's context length: the student is built with
    # max_position_embeddings=config.student_model.max_seq_len, so sequences
    # must fit it. `student_model` is the config section used for this value
    # everywhere else in the notebook.
    return teacher_tokenizer(
        examples['text'],
        max_length=config.student_model.max_seq_len,
        truncation=True,
        padding='max_length',
    )

# Tokenize the whole split with 4 worker processes and drop the raw columns,
# leaving only the tokenizer outputs.
tokenized_dataset = train_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=train_dataset.column_names,
    num_proc=4,
)

print("Training dataset size:", len(tokenized_dataset))

## Initialize the Distillation Trainer

In [None]:
# Wire teacher and student into the guru-shishya (Shaala) trainer.
trainer = VishwamaiShaalaTrainer(
    cfg=config,
    teacher_model=teacher_model,
    student_model=student_model,
)

# Seeded PRNG key from the config; create_train_state builds the initial
# training state from it.
rng = jax.random.PRNGKey(config.training.seed)
state = trainer.create_train_state(rng)

## Training Loop

In [None]:
# Distillation training loop: one progress bar per epoch, losses shown every
# `log_every` steps, and a checkpoint written every `save_every` epochs.
num_epochs = config.training.num_epochs
# Nominal bar length (full batches only); if the dataloader also yields a
# trailing partial batch the bar overshoots by one — TODO confirm drop_last.
steps_per_epoch = len(tokenized_dataset) // config.training.batch_size

for epoch in range(num_epochs):
    with tqdm(desc=f"Epoch {epoch+1}", total=steps_per_epoch) as pbar:
        # NOTE(review): `step` restarts at 0 every epoch and is passed to
        # train_step; if train_step drives an LR schedule or warmup from it,
        # a global step counter is probably intended — confirm.
        for step, batch in enumerate(trainer.get_train_dataloader(tokenized_dataset)):
            # train_step threads the PRNG key: it returns the key to use next.
            state, loss_dict, rng = trainer.train_step(state, batch, step, rng)

            if not step % config.training.log_every:
                pbar.set_postfix({
                    'loss': f"{loss_dict['total_loss']:.4f}",
                    'kd_loss': f"{loss_dict['kd_loss']:.4f}",
                })
            pbar.update(1)

        # End-of-epoch checkpoint (written while the epoch's bar is still open).
        if not (epoch + 1) % config.training.save_every:
            trainer.save_checkpoint(
                state,
                f"{config.student_model.path}/checkpoint-{epoch+1}",
            )

## Save Final Model

In [None]:
# Metadata describing the distilled student. Every value is read from the same
# config that drove training, so the card always matches the actual run
# (including which teacher checkpoint was distilled).
model_card = {
    # Teacher checkpoint the student was distilled from.
    "base_model": config.teacher_model.path,
    "model_type": "distilled-language-model",
    "distillation_method": "VishwamAI guru-shishya knowledge transfer",
    "architecture": {
        "hidden_size": config.student_model.hidden_size,
        "num_layers": config.student_model.num_layers,
        "num_heads": config.student_model.num_heads,
        "num_kv_heads": config.student_model.num_kv_heads,
    },
    "training": {
        "distillation_temperature": config.teacher_model.temperature,
        "alpha": config.teacher_model.alpha,
        "epochs": config.training.num_epochs,
    },
}

# Persist the final student state with its model card; push_to_hub=True asks
# the trainer to also upload it to the Hugging Face Hub.
trainer.save_model(
    state,
    config.student_model.path,
    model_card=model_card,
    push_to_hub=True,
)

print("\n🎉 Distillation complete!")
print("💡 Model saved to:", config.student_model.path)
print("📊 Model card and weights pushed to Hugging Face Hub")

## Model Validation

In [None]:
# Quick validation test: feed the same prompt to teacher and student through
# the trainer's generate_text helper and print both answers side by side.
test_input = "What is the capital of France?"

# Get teacher output
teacher_output = trainer.generate_text(teacher_model, test_input)
print("Teacher output:", teacher_output)

# Get student output
# NOTE(review): only the student module is passed here, not the trained
# `state`; unless generate_text picks up the trained parameters from the
# trainer itself, this may sample from initial/untrained weights — confirm.
student_output = trainer.generate_text(student_model, test_input)
print("Student output:", student_output)