# VishwamAI Model Distillation

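This notebook demonstrates knowledge distillation from the Perplexity AI model `perplexity-ai/r1-1776` (the *guru*, or teacher) into our VishwamAI model (the *shishya*, or student). Compared with a plain distillation run, it adds the glue needed to pass data between the PyTorch teacher and the JAX student, validates the feature-distillation layer indices against the student's depth, and adds explicit error checks for robustness.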

In [None]:
import jax
import jax.numpy as jnp
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from datasets import load_dataset
import os
from tqdm import tqdm
from vishwamai.distillation import VishwamaiGuruKnowledge, VishwamaiShaalaTrainer
from vishwamai.model import VishwamAIModel, ModelConfig
from omegaconf import OmegaConf

# Report which accelerators JAX can see before any model is loaded.
available_devices = jax.devices()
print(f"JAX devices: {available_devices}")

## Load Configuration

In [None]:
# Load Perplexity-R1 distillation configuration
CONFIG_PATH = 'vishwamai/configs/training/perplexity_r1_distillation.yaml'
try:
    config = OmegaConf.load(CONFIG_PATH)
except FileNotFoundError as err:
    # Name the missing path and chain the original error so its traceback is kept.
    raise FileNotFoundError(
        f"Configuration file not found at {CONFIG_PATH!r}. Please check the path."
    ) from err

# Display the configuration first so a failing value below is visible in the output.
print("Teacher Model:", config.teacher_model.path)
print("Student Model:", config.student_model.path)
print("\nModel Architecture:")
print(f"Hidden Size: {config.student_model.hidden_size}")
print(f"Num Layers: {config.student_model.num_layers}")
print(f"Attention Heads: {config.student_model.num_heads}")
print(f"KV Heads: {config.student_model.num_kv_heads}")

print("\nDistillation Parameters:")
print(f"Temperature: {config.teacher_model.temperature}")
print(f"Alpha: {config.teacher_model.alpha}")
print(f"Feature Layers: {config.distillation.feature_distillation.layers}")

# Validate key parameters with explicit exceptions: `assert` statements are
# removed when Python runs with -O. The `not ... > 0` form also rejects NaN.
if not config.teacher_model.temperature > 0:
    raise ValueError("Temperature must be positive")
if not 0 <= config.teacher_model.alpha <= 1:
    raise ValueError("Alpha must be between 0 and 1")

## Initialize Teacher Model (Guru)

In [None]:
# Load the teacher model (Perplexity r1-1776) together with its tokenizer and config.
teacher_tokenizer = AutoTokenizer.from_pretrained(config.teacher_model.path)
teacher_config = AutoConfig.from_pretrained(config.teacher_model.path)
teacher_model = AutoModelForCausalLM.from_pretrained(
    config.teacher_model.path,
    device_map='auto',  # device placement is delegated to the loader here
    torch_dtype=torch.bfloat16,  # load weights in bfloat16 for consistency
)

# The teacher is only used for inference during distillation:
# - eval() disables training-only layers such as dropout;
# - requires_grad_(False) freezes its parameters so teacher forward passes
#   do not record autograd graphs or keep activations alive for a backward pass.
teacher_model.eval()
teacher_model.requires_grad_(False)
print(f"Teacher model loaded: {teacher_config.model_type}")
print(f"Teacher layers: {teacher_config.num_hidden_layers}")

## Initialize Student Model (Shishya)

In [None]:
# Initialize the student model. The vocabulary size is taken from the teacher's
# config; every other size comes from the `student_model` section of the config.
student_config = ModelConfig(
    vocab_size=teacher_config.vocab_size,
    hidden_size=config.student_model.hidden_size,
    num_layers=config.student_model.num_layers,
    num_attention_heads=config.student_model.num_heads,
    intermediate_size=config.student_model.intermediate_size,
    max_position_embeddings=config.student_model.max_seq_len,
    use_flash_attention=True,
    use_gqa=True,
    num_key_value_heads=config.student_model.num_kv_heads,
)

student_model = VishwamAIModel(student_config)
print(f"Student model initialized with {student_config.num_layers} layers")

# Validate the feature-distillation layer indices against the student's depth.
# Raise instead of assert so the check still runs under `python -O`.
feature_layers = config.distillation.feature_distillation.layers
if feature_layers and max(feature_layers) >= student_config.num_layers:
    raise ValueError(
        "Feature distillation layers exceed student model layers: "
        f"max index {max(feature_layers)} >= num_layers {student_config.num_layers}"
    )

## Prepare Training Data

In [None]:
# Load the training split of the configured dataset.
try:
    train_dataset = load_dataset(config.data.path, split='train')
except Exception as e:
    # Wrap any loader failure with the dataset path. Chaining with `from e`
    # marks the original error as the direct cause in the traceback.
    raise ValueError(f"Failed to load dataset from {config.data.path}: {str(e)}") from e

def preprocess_function(examples):
    """Tokenize a batch of raw examples with the teacher's tokenizer.

    Args:
        examples: A batched ``datasets`` mapping; only its ``'text'`` column is read.

    Returns:
        A plain dict mapping each tokenizer output field to a NumPy array,
        with every sequence padded/truncated to the student's ``max_seq_len``.
    """
    model_inputs = teacher_tokenizer(
        examples['text'],
        # Use the student's context length, the same value the student model
        # is built with (max_position_embeddings), so no tokenized sequence is
        # longer than the student can embed.
        max_length=config.student_model.max_seq_len,
        truncation=True,
        padding='max_length',
        # Ask the tokenizer for NumPy arrays directly (JAX-friendly) instead of
        # building PyTorch tensors only to convert them back.
        return_tensors='np',
    )
    return dict(model_inputs)

# padding='max_length' in preprocess_function needs a pad token. Many causal-LM
# tokenizers do not define one, so reuse EOS in that case.
if teacher_tokenizer.pad_token is None:
    teacher_tokenizer.pad_token = teacher_tokenizer.eos_token

# Tokenize the whole split in batches across 4 worker processes, dropping the
# raw columns so only the tokenizer outputs remain.
tokenized_dataset = train_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=4,
    remove_columns=train_dataset.column_names,
)

print(f"Training dataset size: {len(tokenized_dataset)}")
# Validate dataset (raise rather than assert so it survives `python -O`).
if len(tokenized_dataset) == 0:
    raise ValueError("Tokenized dataset is empty")

## Initialize Distillation

In [None]:
# Pair the guru (teacher) and the shishya (student) under a single trainer
# driven by the loaded config.
trainer = VishwamaiShaalaTrainer(
    cfg=config,
    student_model=student_model,
    teacher_model=teacher_model,
)

# Seed JAX's PRNG from the configured seed, then build the initial train state.
training_seed = config.training.seed
rng = jax.random.PRNGKey(training_seed)
state = trainer.create_train_state(rng)

## Training Loop

In [None]:
# Training loop with progress tracking and fail-fast error handling.
num_epochs = config.training.num_epochs
batch_size = config.training.batch_size
steps_per_epoch = len(tokenized_dataset) // batch_size
global_step = 0  # counts optimizer steps across all epochs (never reset)

for epoch in range(num_epochs):
    with tqdm(total=steps_per_epoch, desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
        dataloader = trainer.get_train_dataloader(tokenized_dataset)
        for step, batch in enumerate(dataloader):
            try:
                # Ensure batch is in JAX format
                batch = {k: jnp.asarray(v) for k, v in batch.items()}
                # Pass the global step rather than the per-epoch index (which
                # restarts at 0 every epoch), so any step-dependent schedule
                # keeps advancing across epochs.
                # NOTE(review): confirm against VishwamaiShaalaTrainer.train_step.
                state, loss_dict, rng = trainer.train_step(state, batch, global_step, rng)
            except Exception as e:
                # Report where it failed, then re-raise. Silently breaking would
                # continue to the next epoch, checkpoint a possibly broken state,
                # and let later cells save/push it as a finished model.
                print(f"Error in epoch {epoch+1}, step {step}: {e}")
                raise
            global_step += 1

            if step % config.training.log_every == 0:
                pbar.set_postfix({
                    'loss': f"{float(loss_dict['total_loss']):.4f}",
                    'kd_loss': f"{float(loss_dict['kd_loss']):.4f}",
                })
            pbar.update(1)

        # Save a checkpoint every `save_every` epochs (not necessarily every epoch).
        if (epoch + 1) % config.training.save_every == 0:
            checkpoint_path = os.path.join(config.student_model.path, f"checkpoint-{epoch+1}")
            trainer.save_checkpoint(state, checkpoint_path)
            print(f"Checkpoint saved at: {checkpoint_path}")

## Save Final Model

In [None]:
# Save the final distilled model together with a model card describing how it
# was produced.
# Hub upload is on by default; set `training.push_to_hub: false` in the config
# to save locally only.
push_to_hub = config.training.get("push_to_hub", True)

model_card = {
    # Record the teacher that was actually loaded, not a hard-coded ID, so the
    # card stays accurate if the config points at a different checkpoint.
    "base_model": config.teacher_model.path,
    "model_type": "distilled-language-model",
    "distillation_method": "VishwamAI guru-shishya knowledge transfer",
    "architecture": {
        "hidden_size": config.student_model.hidden_size,
        "num_layers": config.student_model.num_layers,
        "num_heads": config.student_model.num_heads,
        "num_kv_heads": config.student_model.num_kv_heads,
    },
    "training": {
        "distillation_temperature": config.teacher_model.temperature,
        "alpha": config.teacher_model.alpha,
        "epochs": config.training.num_epochs,
    },
}

trainer.save_model(
    state,
    config.student_model.path,
    push_to_hub=push_to_hub,
    model_card=model_card,
)

print("\n🎉 Distillation complete!")
print(f"💡 Model saved to: {config.student_model.path}")
# Only claim a Hub upload when one was actually requested.
if push_to_hub:
    print("📊 Model card and weights pushed to Hugging Face Hub")

## Model Validation

In [None]:
# Smoke test: generate a short completion from each model for the same prompt
# and check that both come back as strings.
test_input = "What is the capital of France?"
generation_max_length = 50

# NOTE(review): `student_model` is the module built before training; the trained
# weights are carried in `state` (returned by train_step and passed to save_model).
# Confirm that trainer.generate_text uses the trained parameters; otherwise this
# compares the teacher against an untrained student.
teacher_output = trainer.generate_text(
    teacher_model, test_input, max_length=generation_max_length
)
print("Teacher output:", teacher_output)

student_output = trainer.generate_text(
    student_model, test_input, max_length=generation_max_length
)
print("Student output:", student_output)

# Basic output type checks
assert isinstance(teacher_output, str), "Teacher output is not a string"
assert isinstance(student_output, str), "Student output is not a string"