# VishwamAI Distillation Training

This notebook implements knowledge distillation from a DeepSeek-architecture teacher model (the `perplexity-ai/r1-1776` checkpoint) into a smaller VishwamAI student model.

In [None]:
# Clone the repository and make the checkout the working directory for the
# rest of the notebook (later cells use relative paths such as configs/ and
# checkpoints/).
# NOTE(review): this cell is not idempotent — a second run in the same session
# starts from inside the checkout because %cd has already changed directory.
!git clone https://github.com/VishwamAI/VishwamAI.git
%cd VishwamAI

In [None]:
# Install dependencies into the kernel's environment (each package listed once).
# NOTE(review): versions are unpinned; pin them once a known-good set is confirmed.
%pip install -q transformers datasets accelerate bitsandbytes wandb safetensors sentencepiece flax optax omegaconf huggingface-hub

In [None]:
import os
import json
import logging
from pathlib import Path
import wandb
from omegaconf import OmegaConf
import jax
import jax.numpy as jnp
from flax.training import train_state

from vishwamai.model import VishwamAIModel, ModelConfig
from vishwamai.tokenizer import VishwamAITokenizer
from vishwamai.distillation import VishwamaiGuruKnowledge, VishwamaiShaalaTrainer
from vishwamai.data_utils import create_train_dataloader, create_val_dataloader

## Configuration Setup

In [None]:
# Teacher model config (DeepSeek)
# Expanded into ModelConfig(**teacher_config) when the teacher is built, and
# written to configs/teacher_config.json by the save step at the end of this cell.
teacher_config = {
    "hidden_size": 7168,
    "intermediate_size": 18432,
    "num_attention_heads": 128,
    "num_layers": 61,
    "num_key_value_heads": 128,
    "vocab_size": 129280,  # also used to size the tokenizer in the model-setup cell
    "max_position_embeddings": 163840
}

# Student model config (smaller VishwamAI)
# Expanded into ModelConfig(**student_config) when the student is built, and
# written to configs/student_config.json, which the export cells upload as config.json.
student_config = {
    "hidden_size": 2048,  # Smaller hidden size (teacher: 7168)
    "intermediate_size": 8192,
    "num_attention_heads": 32,
    "num_layers": 24,  # Fewer layers (teacher: 61)
    "num_key_value_heads": 32,
    "vocab_size": 129280,  # Same value as the teacher's vocab_size
    "max_position_embeddings": 163840  # Same value as the teacher's
}

# Distillation configuration
# The whole dict is wrapped with OmegaConf.create(...) and handed to the
# trainer, the data-loader factories and VishwamaiGuruKnowledge; it is also
# written to configs/distillation_config.yaml. Comments below note which keys
# this notebook reads directly; the rest are consumed (if at all) by the
# vishwamai package — verify there.
distillation_config = {
    "training": {
        "max_steps": 50000,  # upper bound of the training loop
        "eval_steps": 500,  # evaluate whenever step % eval_steps == 0
        "save_steps": 1000,  # checkpoint whenever step % save_steps == 0
        "logging_steps": 100,  # log to wandb whenever step % logging_steps == 0
        "warmup_steps": 2000  # not read in this notebook — presumably used by the trainer; verify
    },
    "distillation": {
        "teacher_model": {
            "path": "perplexity-ai/r1-1776",  # passed to teacher_model.load_weights(...)
            # NOTE(review): the training loop decays guru.temperature; presumably
            # it is initialised from this value — confirm in VishwamaiGuruKnowledge
            "temperature": 2.0,
            "alpha": 0.5  # Weight between distillation and task loss
        },
        "feature_distillation": {
            "layers": [0, 8, 16],  # Layer indices to match
            "loss_weight": 0.1
        },
        "attention_distillation": {
            "loss_weight": 0.1
        },
        "hidden_distillation": {
            "loss_weight": 0.1
        },
        # Pruning settings are not read in this notebook — presumably applied
        # inside trainer.train_step; verify
        "pruning": {
            "enabled": True,
            "target_sparsity": 0.3,
            "begin_step": 1000,
            "end_step": 40000,
            "pruning_schedule": "cubic"
        },
        "quantization": {
            "enabled": True,  # gates the trainer.quantize_model(...) call after the training loop
            "precision": "int8"
        }
    },
    # Optimizer and data settings are not read directly in this notebook
    "optimizer": {
        "learning_rate": 1e-4,
        "weight_decay": 0.01,
        "beta1": 0.9,
        "beta2": 0.95,
        "clip_grad_norm": 1.0
    },
    "data": {
        "train_batch_size": 32,
        "eval_batch_size": 32,
        "max_length": 2048
    }
}

# Persist all three configurations under configs/
config_path = Path("configs")
config_path.mkdir(exist_ok=True)

# The two architecture configs are written as indented JSON
for json_name, json_payload in (
    ("teacher_config.json", teacher_config),
    ("student_config.json", student_config),
):
    with open(config_path / json_name, "w") as f:
        json.dump(json_payload, f, indent=2)

# The distillation settings are written as YAML through OmegaConf
with open(config_path / "distillation_config.yaml", "w") as f:
    OmegaConf.save(OmegaConf.create(distillation_config), f)

## Model Setup

In [None]:
# Teacher: build from its architecture config, then load weights from the
# path configured under distillation.teacher_model
teacher_weights_path = distillation_config['distillation']['teacher_model']['path']
teacher_model = VishwamAIModel(ModelConfig(**teacher_config))
teacher_model.load_weights(teacher_weights_path)

# Student: built from its own, smaller architecture config
student_model = VishwamAIModel(ModelConfig(**student_config))

# Tokenizer sized to the teacher's vocabulary
tokenizer = VishwamAITokenizer(
    model_prefix="vishwamai",
    vocab_size=teacher_config["vocab_size"],
)

# Distillation trainer wired to both models and the distillation settings
trainer = VishwamaiShaalaTrainer(
    cfg=OmegaConf.create(distillation_config),
    student_model=student_model,
    teacher_model=teacher_model,
)

## Training Loop

In [None]:
# Initialize wandb
wandb.init(
    project="VishwamAI",
    name="vishwamai-distillation",
    config=distillation_config
)

# The interval settings are read on every step; look the section up once
training_cfg = distillation_config['training']

# Everything after wandb.init() runs inside try/finally so the wandb run is
# always closed, including when data loading or state creation fails.
try:
    # Create data loaders
    train_loader = create_train_dataloader(OmegaConf.create(distillation_config))
    val_loader = create_val_dataloader(OmegaConf.create(distillation_config))

    # Initialize training state
    rng = jax.random.PRNGKey(42)
    state = trainer.create_train_state(rng)

    # Initialize guru knowledge with feature matching
    guru = VishwamaiGuruKnowledge(OmegaConf.create(distillation_config))

    # Training loop with guru knowledge
    for step in range(training_cfg['max_steps']):
        # NOTE(review): assumes train_loader is an iterator that yields at
        # least max_steps batches — confirm against create_train_dataloader
        batch = next(train_loader)

        # Get teacher predictions and features
        teacher_outputs = teacher_model(
            batch['input_ids'],
            attention_mask=batch['attention_mask'],
            output_hidden_states=True,
            output_attentions=True
        )

        # Training step with knowledge distillation
        state, metrics, rng = trainer.train_step(
            state=state,
            batch=batch,
            teacher_outputs=teacher_outputs,
            guru=guru,
            step=step,
            rng=rng
        )

        # Enhanced logging with distillation metrics
        if step % training_cfg['logging_steps'] == 0:
            distill_metrics = {
                'kd_loss': metrics['kd_loss'],
                'feature_loss': metrics.get('feature_loss', 0.0),
                'attention_loss': metrics.get('attention_loss', 0.0),
                'hidden_loss': metrics.get('hidden_loss', 0.0),
                'total_loss': metrics['total_loss'],
                'temperature': guru.temperature
            }
            wandb.log({**metrics, **distill_metrics}, step=step)

            # Print current distillation progress
            print(f"\nStep {step}:")
            print(f"KD Loss: {distill_metrics['kd_loss']:.4f}")
            print(f"Feature Loss: {distill_metrics['feature_loss']:.4f}")
            print(f"Total Loss: {distill_metrics['total_loss']:.4f}")

        # Evaluation with feature matching
        if step % training_cfg['eval_steps'] == 0:
            eval_metrics = trainer.evaluate(
                state=state,
                val_loader=val_loader,
                teacher_model=teacher_model,
                guru=guru
            )
            wandb.log({f"eval_{k}": v for k, v in eval_metrics.items()}, step=step)

        # Dynamic temperature adjustment based on training progress
        if step % 1000 == 0:
            guru.temperature = max(
                1.0,  # Minimum temperature
                guru.temperature * 0.95  # Gradual temperature decay
            )

        # Save checkpoint with distillation state
        if step % training_cfg['save_steps'] == 0:
            ckpt_path = f"checkpoints/step_{step}"
            trainer.save_checkpoint(
                state=state,
                path=ckpt_path,
                guru=guru,  # Save guru state
                metadata={
                    'temperature': guru.temperature,
                    'step': step,
                    'metrics': metrics
                }
            )

    # Final quantization with teacher guidance
    if distillation_config['distillation']['quantization']['enabled']:
        state = trainer.quantize_model(
            state=state,
            val_loader=val_loader,
            teacher_model=teacher_model,
            guru=guru,
            num_calibration_steps=100
        )
        trainer.save_checkpoint(state, "checkpoints/quantized")

except Exception as e:
    print(f"Training failed: {str(e)}")
    # Re-raise so the cell fails visibly: swallowing the error would let
    # "Run All" continue into the export cells after a failed training run.
    raise
finally:
    wandb.finish()

## Model Export

In [None]:
from huggingface_hub import HfApi

# Publish the distilled model artifacts to the HuggingFace Hub
api = HfApi()

repo_id = "VishwamAI/VishwamAI-small"
api.create_repo(repo_id, exist_ok=True)

# (local file, file name inside the repo) pairs, uploaded in this order
export_files = (
    ("configs/student_config.json", "config.json"),
    ("checkpoints/quantized/model.safetensors", "model.safetensors"),
    ("tokenizer/vishwamai.model", "tokenizer.model"),
)
for local_file, remote_name in export_files:
    api.upload_file(
        path_or_fileobj=local_file,
        path_in_repo=remote_name,
        repo_id=repo_id
    )

## Evaluation and Analysis

In [None]:
# Compare teacher and student model performance
def evaluate_models(teacher, student, val_loader, num_batches=10):
    """Compare teacher and student on up to ``num_batches`` validation batches.

    Args:
        teacher: Callable model, invoked as ``teacher(input_ids, deterministic=True)``.
            The returned mapping must contain ``'loss'`` and may contain
            ``'accuracy'`` (treated as 0 when absent). Must expose ``param_count``.
        student: Same contract as ``teacher``.
        val_loader: Iterator of batches (advanced with ``next``); every batch
            must provide ``'input_ids'``.
        num_batches: Maximum number of batches to evaluate. Fewer are used if
            the loader is exhausted first.

    Returns:
        dict with keys ``'teacher'`` and ``'student'`` (each a dict of metric
        averages over the evaluated batches) and ``'compression_ratio'`` (the
        teacher/student ``param_count`` ratio formatted like ``"4.00x"``).

    Raises:
        ValueError: If no batch could be evaluated (``num_batches <= 0`` or an
            already exhausted loader).
    """
    def _batch_metrics(model, input_ids):
        # One forward pass reduced to plain-float metrics
        output = model(input_ids, deterministic=True)
        return {
            'loss': float(output['loss']),
            'accuracy': float(output.get('accuracy', 0))
        }

    def _average(records):
        # Per-key mean over a non-empty list of metric dicts
        return {key: sum(r[key] for r in records) / len(records)
                for key in records[0]}

    teacher_metrics = []
    student_metrics = []

    for _ in range(num_batches):
        try:
            batch = next(val_loader)
        except StopIteration:
            # Loader ran out early: average over what was collected instead
            # of letting a bare StopIteration escape.
            break

        teacher_metrics.append(_batch_metrics(teacher, batch['input_ids']))
        student_metrics.append(_batch_metrics(student, batch['input_ids']))

    if not teacher_metrics:
        raise ValueError(
            "evaluate_models: no validation batches were evaluated "
            "(val_loader is exhausted or num_batches <= 0)"
        )

    return {
        'teacher': _average(teacher_metrics),
        'student': _average(student_metrics),
        'compression_ratio': f"{teacher.param_count / student.param_count:.2f}x"
    }

# Run the comparison, then print the summary one line at a time
results = evaluate_models(teacher_model, student_model, val_loader)
report_lines = [
    "\nModel Comparison:",
    f"Compression Ratio: {results['compression_ratio']}",
    "\nTeacher Model:",
    f"Loss: {results['teacher']['loss']:.4f}",
    f"Accuracy: {results['teacher']['accuracy']:.4f}",
    "\nStudent Model:",
    f"Loss: {results['student']['loss']:.4f}",
    f"Accuracy: {results['student']['accuracy']:.4f}",
]
for report_line in report_lines:
    print(report_line)

In [None]:
# Hub locations: the teacher repo to pull from and the repo to publish to
TEACHER_MODEL_PATH = "perplexity-ai/r1-1776"
OUTPUT_MODEL_PATH = "VishwamAI/Perplexity_r1_disttled_experiment"

import os
from huggingface_hub import snapshot_download

# Fetch only the weight shards, the model config and the tokenizer model
teacher_file_patterns = ["*.safetensors", "config.json", "tokenizer.model"]
teacher_path = snapshot_download(
    allow_patterns=teacher_file_patterns,
    repo_id=TEACHER_MODEL_PATH,
)
print(f"Downloaded teacher model to {teacher_path}")

In [None]:
# Modified model saving logic
def save_sharded_model(state, save_dir, num_shards=252):
    """Save model weights as a set of safetensors shard files.

    The parameter tree in ``state.params`` is flattened to dot-separated
    tensor names (``"layer.kernel"``) and packed greedily, in tree order, into
    shards of about ``total_elements / num_shards`` elements each. A tensor is
    never split, so the number of files actually written can differ from
    ``num_shards``; file names always carry the real shard count
    (``model-00001-of-00003.safetensors``).

    Args:
        state: Object exposing ``params``: a (possibly nested) mapping whose
            leaves are arrays with a ``size`` attribute.
        save_dir: Output directory; created if it does not exist.
        num_shards: Target number of shards, must be >= 1.

    Returns:
        List of the shard file paths written, in shard order (empty when
        ``state.params`` has no leaves).

    Raises:
        ValueError: If ``num_shards`` is smaller than 1.
    """
    # Local imports: the notebook's import cell does not bring these in
    from collections.abc import Mapping
    from safetensors.flax import save_file

    if num_shards < 1:
        raise ValueError(f"num_shards must be >= 1, got {num_shards}")

    os.makedirs(save_dir, exist_ok=True)

    def _flatten(tree, prefix=""):
        # Yield (dotted_name, leaf) pairs for every array in a nested mapping
        for key, value in tree.items():
            name = f"{prefix}{key}"
            if isinstance(value, Mapping):
                yield from _flatten(value, f"{name}.")
            else:
                yield name, value

    named_params = list(_flatten(state.params))
    if not named_params:
        return []

    # Per-shard element budget (ceiling division, never below one element)
    total_params = sum(param.size for _, param in named_params)
    params_per_shard = max(1, -(-total_params // num_shards))

    # Group first, write afterwards, so file names can state the real count
    shards = []
    shard_dict = {}
    current_size = 0
    for name, param in named_params:
        # Only close a shard that already holds something: an oversized tensor
        # gets a shard of its own instead of producing an empty file.
        if shard_dict and current_size + param.size > params_per_shard:
            shards.append(shard_dict)
            shard_dict = {}
            current_size = 0
        shard_dict[name] = param
        current_size += param.size
    shards.append(shard_dict)

    shard_paths = []
    for index, tensors in enumerate(shards, start=1):
        shard_path = os.path.join(
            save_dir,
            f"model-{index:05d}-of-{len(shards):05d}.safetensors"
        )
        save_file(tensors, shard_path)
        shard_paths.append(shard_path)

    return shard_paths

In [None]:
# Export cell for the sharded checkpoint
from huggingface_hub import HfApi

# Client used for every Hub call below
api = HfApi()

# Write the shards locally before anything is uploaded
save_dir = "checkpoints/final"
save_sharded_model(state, save_dir)

# Make sure the destination repo exists
api.create_repo(OUTPUT_MODEL_PATH, exist_ok=True)

# Student architecture config
api.upload_file(
    repo_id=OUTPUT_MODEL_PATH,
    path_or_fileobj="configs/student_config.json",
    path_in_repo="config.json"
)

# Every .safetensors file found in save_dir, in sorted file-name order
shard_files = [
    entry for entry in sorted(os.listdir(save_dir))
    if entry.endswith(".safetensors")
]
for shard_file in shard_files:
    api.upload_file(
        repo_id=OUTPUT_MODEL_PATH,
        path_or_fileobj=os.path.join(save_dir, shard_file),
        path_in_repo=shard_file
    )

# Tokenizer model
api.upload_file(
    repo_id=OUTPUT_MODEL_PATH,
    path_or_fileobj="tokenizer/vishwamai.model",
    path_in_repo="tokenizer.model"
)