# VishwamAI Enhanced Distillation with ToT and Error Correction

This notebook implements knowledge distillation from larger models to a smaller VishwamAI model with Tree of Thoughts (ToT) integration and error correction components.

In [None]:
# Clone the repository
# Shell escape: fetches the VishwamAI sources from GitHub into ./VishwamAI.
!git clone https://github.com/VishwamAI/VishwamAI.git
# Switch the kernel's working directory into the checkout
# (presumably so the `vishwamai` imports further down resolve from it — confirm).
%cd VishwamAI

In [None]:
# Install dependencies
%pip install -q transformers datasets accelerate bitsandbytes wandb safetensors sentencepiece flax optax omegaconf safetensors huggingface-hub einops mlflow aim

In [None]:
import os
import json
import logging
from pathlib import Path
import mlflow
import aim
from omegaconf import OmegaConf
import jax
import jax.numpy as jnp
import flax
from flax.training import train_state
from huggingface_hub import snapshot_download, HfApi
import safetensors.flax as stf
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

# Configure logging
# Root logger at INFO with timestamp, logger name and level in every record.
logging.basicConfig(level=logging.INFO, 
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Notebook-wide logger used by the data-loading and training cells below.
logger = logging.getLogger("vishwamai_distillation")

# Import VishwamAI components
# These come from the repository cloned in the first cell.
from vishwamai.model import VishwamAIModel, ModelConfig
from vishwamai.tokenizer import VishwamAITokenizer
from vishwamai.tot import TreeOfThoughts
from vishwamai.integration import ToTIntegrationLayer, MixtureDensityNetwork, MultiLevelToTAttention
from vishwamai.error_correction_trainer import ErrorCorrectionTrainer
from vishwamai.training import train, create_train_dataloader, create_val_dataloader
# NOTE(review): this is vishwamai.transformer.VishwamAIModel under the alias
# VisionTransformer10B; the alias name suggests a vision model — confirm it is the
# intended class and consider a clearer alias.
from vishwamai.transformer import VishwamAIModel as VisionTransformer10B

# Show which accelerators JAX can see before any model is built.
print(f"JAX devices: {jax.devices()}")

## Enhanced Configuration Setup

Setting up configuration for distillation with Tree of Thoughts and Error Correction components.

In [None]:
# Create advanced configuration with ToT and Error Correction components
# Layout of the dictionary:
#   model        -> student architecture; unpacked into ModelConfig(**config.model) later in the notebook
#   training     -> optimiser, schedule and logging settings plus ToT / error-correction knobs
#   data         -> dataset name, splits and sample limits
#   distillation -> teacher repo and config, loss weights, output path
config = {
    "model": {
        "vocab_size": 129280,
        "hidden_size": 2048,  # Smaller hidden size for student model
        "num_layers": 24,  # Fewer layers
        "num_attention_heads": 32,
        "intermediate_size": 8192,
        "hidden_dropout_prob": 0.1,
        "attention_dropout_prob": 0.1,
        "max_position_embeddings": 4096,  # Reduced context length for better training efficiency
        "initializer_range": 0.02,
        "layer_norm_eps": 1e-5,
        "use_cache": True,
        "pad_token_id": 0,
        "bos_token_id": 1,
        "eos_token_id": 2,
        "tie_word_embeddings": True,
        "use_flash_attention": True,
        "use_rope": True,
        "use_alibi": False,
        "use_gqa": True,
        # NOTE(review): equal to num_attention_heads (32), i.e. one KV head per query
        # head even though use_gqa is True — confirm this is intended.
        "num_key_value_heads": 32,
        "dtype": "bfloat16",
        
        # MoE configuration
        "moe_enabled": True,
        "num_experts": 8,
        "expert_capacity": 0.25,
        "expert_dropout": 0.1,
        "moe_layers": [5, 11, 17, 23],  # Specific layers to use MoE
        
        # MoD configuration
        "use_mod": True,
        "mod_num_mixtures": 5,
        "mod_balance_weight": 0.01,
        
        # ToT configuration
        "use_tot": True
    },
    
    "training": {
        "learning_rate": 5e-5,
        "warmup_steps": 1000,
        "max_steps": 100000,
        "batch_size": 16,
        "eval_batch_size": 8,
        "adam_beta1": 0.9,
        "adam_beta2": 0.999,
        "weight_decay": 0.01,
        "max_grad_norm": 1.0,
        "z_loss": 0.01,
        # log_every / eval_every are the values passed explicitly to train() in the training cell.
        "log_every": 100,
        "eval_every": 1000,
        "save_every": 5000,
        "checkpoint_dir": "checkpoints/tot_distillation",
        "seed": 42,
        # NOTE(review): these three repeat log_every / eval_every / save_every with the
        # same values — confirm which set the trainer actually reads and keep them in sync.
        "logging_steps": 100,
        "eval_steps": 1000,
        "save_steps": 5000,
        
        # ToT specific training config
        "use_tot": True,
        "tot_search_strategy": "beam",
        "tot_max_thoughts": 5,
        "tot_max_depth": 3,
        "tot_beam_width": 8,
        "tot_pruning_threshold": 0.3,
        "tot_exploration_factor": 1.0,
        "tot_guidance_alpha": 0.2,
        
        # Error correction config
        "use_error_correction": True,
        "error_history_size": 100,
        "error_threshold_percentile": 85.0,
        "ec_loss_weight": 0.2
    },
    
    "data": {
        "dataset_name": "c4",
        "train_split": "train",
        "val_split": "validation",
        "text_column": "text",
        "max_seq_length": 1024,
        "preprocessing_num_workers": 4,
        "max_train_samples": 100000,
        "max_val_samples": 5000
    },
    
    "distillation": {
        "teacher_model": {
            # Hub repo id handed to download_partial_model below.
            "path": "perplexity-ai/r1-1776",
            # Unpacked into ModelConfig(**...) when the teacher is built; vocab_size
            # matches the student's (129280).
            "config": {
                "hidden_size": 7168,
                "intermediate_size": 18432,
                "num_attention_heads": 128,
                "num_layers": 61,
                "num_key_value_heads": 128,
                "vocab_size": 129280,
                "max_position_embeddings": 163840
            }
        },
        "kd_temperature": 2.0,
        "alpha_kd": 0.5,  # Weight for KL divergence loss
        "alpha_ce": 0.5,  # Weight for cross-entropy loss
        "alpha_tot": 0.2,  # Weight for ToT guidance
        "feature_matching": True,  # Whether to match intermediate features
        "feature_layers": [5, 11, 17, 23],  # Layers to match features
        "quantization": {
            "enabled": False
        },
        "output_path": "VishwamAI/tot-enhanced-distilled-model"
    }
}

# Convert to OmegaConf format
# After this point sections are reachable by attribute access (config.model.hidden_size).
config = OmegaConf.create(config)

# Create directories for checkpoint saving
# exist_ok=True keeps the cell re-runnable when the directory is already there.
os.makedirs(config.training.checkpoint_dir, exist_ok=True)

## Download and Initialize Teacher Model

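Downloading only a subset of the teacher model's weight shards to limit memory and disk usage. Note that a partial download does not contain the full set of teacher weights.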

In [None]:
def download_partial_model(model_path: str, num_shards: int = 5, total_shards: int = 252):
    """Download only the first ``num_shards`` weight shards of a Hub model repo.

    Shard files are selected by name (``model-00001-of-NNNNN.safetensors`` ...),
    together with ``config.json`` and ``tokenizer.model``.

    Args:
        model_path: Repo id passed to ``snapshot_download`` as ``repo_id``.
        num_shards: How many leading shards to fetch. ``0`` fetches only the
            non-weight files.
        total_shards: Shard count encoded in the repo's shard file names.
            Defaults to 252, the value that used to be hard-coded here.

    Returns:
        The local snapshot directory returned by ``snapshot_download``.

    Raises:
        ValueError: If ``num_shards`` is outside ``0..total_shards``, or if the
            download fails (the original exception is chained as ``__cause__``).
    """
    # A request beyond the real shard count would only produce patterns that match
    # nothing and then report success, so reject it up front.
    if total_shards < 1 or not 0 <= num_shards <= total_shards:
        raise ValueError(
            f"num_shards must be between 0 and total_shards ({total_shards}), got {num_shards}"
        )

    patterns = [
        f"model-{shard_index:05d}-of-{total_shards:05d}.safetensors"
        for shard_index in range(1, num_shards + 1)
    ]
    patterns.extend(["config.json", "tokenizer.model"])  # Add other required files

    try:
        # resume_download is intentionally not passed: huggingface_hub deprecated it
        # (downloads resume automatically) and announced its removal.
        local_path = snapshot_download(
            repo_id=model_path,
            allow_patterns=patterns,
            local_files_only=False,
        )
    except Exception as e:
        # Callers only need to handle ValueError; chaining keeps the root cause visible.
        raise ValueError(f"Error downloading model shards: {str(e)}") from e

    print(f"Successfully downloaded {num_shards} model shards to {local_path}")
    return local_path

# Fetch a five-shard subset of the teacher checkpoint; the count is kept
# small for memory efficiency.
teacher_path = download_partial_model(config.distillation.teacher_model.path, num_shards=5)

## Initialize Models and Components

Setting up teacher model, student model, and associated components for ToT integration and error correction.

In [None]:
# Initialize error correction trainer
# Receives the whole config plus the ToT/MoD switches and the error-history settings.
# NOTE(review): error_trainer is not referenced again in the cells shown in this
# notebook — confirm how it is meant to reach train().
error_trainer = ErrorCorrectionTrainer(
    config=config,
    use_tot=config.training.use_tot,
    use_mod=config.model.use_mod,
    history_size=config.training.error_history_size,
    threshold_percentile=config.training.error_threshold_percentile
)

# Initialize teacher model (with reduced size for memory efficiency)
# The teacher uses the same VishwamAIModel class as the student, configured only from
# the fields under distillation.teacher_model.config.
teacher_config = ModelConfig(**config.distillation.teacher_model.config)
teacher_model = VishwamAIModel(teacher_config)
# NOTE(review): teacher_path contains only the shards fetched by download_partial_model;
# presumably reduced_size=True lets load_weights cope with the missing ones — confirm.
teacher_model.load_weights(teacher_path, reduced_size=True)
print("Teacher model loaded successfully")

# Initialize student model
# Every key of config.model is forwarded to ModelConfig as a keyword argument.
student_config = ModelConfig(**config.model)
student_model = VishwamAIModel(student_config)
print("Student model initialized")

# Initialize tokenizer
# vocab_size is taken from the student config so tokenizer and model agree.
tokenizer = VishwamAITokenizer(
    vocab_size=config.model.vocab_size,
    model_prefix="vishwamai",
    error_tokens=True  # Enable error tokens for integration with error correction
)
print("Tokenizer initialized with error correction tokens")

In [None]:
# Initialize Tree of Thoughts components for student model
# VisionTransformer10B is the alias given to vishwamai.transformer.VishwamAIModel in the
# import cell. This builds a new model instance from student_config.
# NOTE(review): the ToT search is therefore wired to this separate instance rather than
# to student_model itself — confirm that is intended.
vision_transformer = VisionTransformer10B(student_config)
tot_model = TreeOfThoughts(
    transformer=vision_transformer,
    max_thoughts=config.training.tot_max_thoughts,
    max_depth=config.training.tot_max_depth,
    beam_width=config.training.tot_beam_width,
    pruning_threshold=config.training.tot_pruning_threshold,
    exploration_factor=config.training.tot_exploration_factor
)

# Create integration components
# NOTE(review): ToTIntegrationLayer is given the OmegaConf node config.model, whereas the
# models above are given the ModelConfig object — confirm which type it expects.
tot_integration = ToTIntegrationLayer(config.model)
# The attention head count is capped at 8 (the student has 32 heads in the config above).
mla = MultiLevelToTAttention(
    hidden_size=config.model.hidden_size,
    num_heads=min(8, config.model.num_attention_heads)
)
print("ToT components initialized")

# Create MoD component if enabled
if config.model.use_mod:
    mod_layer = MixtureDensityNetwork(
        hidden_size=config.model.hidden_size,
        num_mixtures=config.model.mod_num_mixtures
    )
    student_model.mod_layer = mod_layer
    print("MoD components initialized")

# Add components to student model
# The components are attached as plain attributes on the already-constructed model.
# NOTE(review): if VishwamAIModel is a flax.linen Module, attribute assignment after
# construction may be rejected (Flax modules are frozen outside setup) — confirm.
student_model.tot_model = tot_model
student_model.tot_integration = tot_integration
student_model.tot_mla = mla
student_model.use_tot = config.model.use_tot
student_model.use_mod = config.model.use_mod

print("All advanced components integrated into student model")

## Define Enhanced Knowledge Distillation Loss

Creating a specialized loss function that combines standard KD loss with ToT guidance.

In [None]:
import jax.nn as nn

def tot_guided_distillation_loss(
    student_logits,
    teacher_logits,
    labels,
    tot_outputs=None,
    temperature=1.0,
    alpha_kd=0.5,
    alpha_ce=0.5,
    alpha_tot=0.2
):
    """Enhanced KD loss with ToT guidance.

    Combines three terms, each summed (not averaged) over all elements:

    * ``kd_loss``: cross-entropy of the student against the temperature-softened
      teacher distribution, scaled by ``temperature ** 2``. This differs from the
      KL divergence only by the teacher's entropy, which is constant w.r.t. the
      student.
    * ``ce_loss``: cross-entropy of the (unsoftened) student against ``labels``.
    * an optional ToT term, added only when ``tot_outputs`` is given and
      ``alpha_tot > 0``: the soft cross-entropy re-weighted by the ToT attention
      weights, plus any precomputed ``feature_loss``.

    Args:
        student_logits: Student logits; the last axis is the class/vocabulary axis.
        teacher_logits: Teacher logits, elementwise-compatible with ``student_logits``.
        labels: Integer class ids, one-hot encoded over ``student_logits.shape[-1]``.
        tot_outputs: Optional mapping that may contain ``'attention_weights'`` and
            ``'feature_loss'``. ``None`` (the default) disables the ToT term.
        temperature: Softmax temperature applied to both logits for the KD term.
        alpha_kd: Weight of ``kd_loss``.
        alpha_ce: Weight of ``ce_loss``.
        alpha_tot: Weight of ``tot_loss + feature_loss``.

    Returns:
        ``(combined_loss, metrics)`` where ``metrics`` holds ``kd_loss``,
        ``ce_loss``, ``tot_loss`` and ``feature_loss``.
    """
    has_tot_outputs = tot_outputs is not None
    apply_tot = has_tot_outputs and alpha_tot > 0.0

    # Softened distributions; the student's log-probs are shared by the KD and ToT terms.
    teacher_probs = nn.softmax(teacher_logits / temperature)
    student_soft_log_probs = nn.log_softmax(student_logits / temperature)

    # Soft-target loss between student and teacher
    kd_loss = -jnp.sum(teacher_probs * student_soft_log_probs) * (temperature ** 2)

    # Standard cross entropy with ground truth
    ce_loss = -jnp.sum(nn.one_hot(labels, student_logits.shape[-1]) * nn.log_softmax(student_logits))

    # Add ToT guidance if available
    tot_loss = 0.0
    if apply_tot:
        attention_weights = tot_outputs.get('attention_weights', None)
        if attention_weights is not None:
            # Use ToT attention weights to guide where to focus learning.
            # Averaging over the first two axes must leave something that broadcasts
            # against teacher_probs — TODO confirm the expected attention layout.
            attention_weights_flat = jnp.mean(attention_weights, axis=(0, 1))
            teacher_probs_weighted = teacher_probs * attention_weights_flat
            tot_loss = -jnp.sum(teacher_probs_weighted * student_soft_log_probs)

    # Add feature matching loss if available.
    # Guarded on has_tot_outputs: a membership test on None raised TypeError whenever
    # the function was called with its default tot_outputs.
    feature_loss = 0.0
    if has_tot_outputs and 'feature_loss' in tot_outputs:
        feature_loss = tot_outputs['feature_loss']

    # Combine losses with weights
    combined_loss = alpha_kd * kd_loss + alpha_ce * ce_loss
    if apply_tot:
        combined_loss += alpha_tot * (tot_loss + feature_loss)

    return combined_loss, {
        "kd_loss": kd_loss,
        "ce_loss": ce_loss,
        "tot_loss": tot_loss,
        "feature_loss": feature_loss
    }

## Create Data Loaders

Setting up data loaders for training and evaluation.

In [None]:
# Create data loaders
# Both factories come from vishwamai.training and receive the full OmegaConf config.
logger.info("Creating data loaders...")
train_loader = create_train_dataloader(config)
val_loader = create_val_dataloader(config)
logger.info("Data loaders created successfully")

## Advanced Training with ToT and Error Correction

Running the training process with tracking and visualization.

In [None]:
# Initialize experiment tracking
mlflow.set_experiment("VishwamAI-ToT-Distillation")
aim_run = aim.Run(experiment="VishwamAI-ToT-Distillation")
logger.info("Starting advanced training with ToT and error correction")

# Get JAX PRNGKey for reproducibility
# NOTE(review): rng_key is not passed to train() in this cell — confirm how the seed
# reaches the trainer (possibly via config.training.seed).
rng_key = jax.random.PRNGKey(config.training.seed)

# Plain-Python copy of the configuration for the tracking tools
config_dict = OmegaConf.to_container(config)

# Everything that happens after the Aim run is opened sits inside try/finally so the
# run is always closed, including when parameter logging fails before training starts.
try:
    # Track configurations in experiment tools.
    # Item assignment is Aim's documented way to attach parameters to a Run.
    aim_run["hparams"] = config_dict

    # Run training with ToT and error correction integration
    with mlflow.start_run() as run:
        # Log important configuration parameters
        mlflow.log_params({
            "model.hidden_size": config.model.hidden_size,
            "model.num_layers": config.model.num_layers,
            "model.num_attention_heads": config.model.num_attention_heads,
            "training.learning_rate": config.training.learning_rate,
            "training.batch_size": config.training.batch_size,
            "training.use_tot": config.training.use_tot,
            "training.use_error_correction": config.training.use_error_correction,
            "model.use_mod": config.model.use_mod,
            "distillation.kd_temperature": config.distillation.kd_temperature,
            "distillation.alpha_tot": config.distillation.alpha_tot
        })

        try:
            # Start the training.
            # NOTE(review): neither teacher_model, error_trainer nor
            # tot_guided_distillation_loss is passed here — confirm that train()
            # sets up distillation from `config` alone.
            final_state = train(
                model=student_model,
                config=config,
                train_dataloader=train_loader,
                val_dataloader=val_loader,
                num_steps=config.training.max_steps,
                log_every=config.training.log_every,
                eval_every=config.training.eval_every,
                checkpoint_dir=config.training.checkpoint_dir
            )

            # Log final metrics (converted to plain floats for the trackers)
            best_metrics = final_state.best_metrics
            final_metrics = {
                "final_loss": float(best_metrics['loss']),
                "final_accuracy": float(best_metrics['accuracy']),
                "ec_improvement": float(best_metrics.get('ec_improvement', 0.0))
            }

            # Log to both tracking systems
            mlflow.log_metrics(final_metrics)
            for k, v in final_metrics.items():
                aim_run.track(v, name=k)

            # Log model artifacts
            mlflow.log_artifacts(config.training.checkpoint_dir)

            logger.info("Training completed successfully!")
            logger.info(f"Best loss: {final_metrics['final_loss']:.4f}")
            logger.info(f"Best accuracy: {final_metrics['final_accuracy']:.4f}")
            logger.info(f"Error correction improvement: {final_metrics['ec_improvement']:.4f}")

        except Exception as e:
            # Best-effort by design: the failure is recorded and the notebook continues.
            # logger.exception keeps the traceback that a message-only log would drop.
            logger.exception(f"Training failed: {str(e)}")
            error_text = str(e)
            # MLflow rejects over-long param values, which would raise from inside this
            # handler and hide the real error; 500 characters is a conservative cap
            # (NOTE(review): confirm against the tracking backend in use).
            mlflow.log_param("error", error_text[:500])
            aim_run["error"] = error_text
finally:
    aim_run.close()

## Visualize Training Results

Creating visualizations of the training metrics.

In [None]:
# Extract metrics from the training run
def get_training_metrics_from_mlflow(
    run_id=None,
    experiment_name="VishwamAI-ToT-Distillation",
    metric_keys=("loss", "accuracy", "ec_improvement", "tot_score", "kd_loss"),
):
    """Get per-step metric histories from MLflow tracking.

    Args:
        run_id: Run to read. When ``None``, the first run returned by
            ``search_runs`` for ``experiment_name`` is used (MLflow's default
            ordering is most recent start time first).
        experiment_name: Experiment searched when ``run_id`` is ``None``.
        metric_keys: Metric names to collect; names the run never logged are skipped.

    Returns:
        ``{metric_name: [(step, value), ...]}``. Empty when the experiment does
        not exist or has no runs.
    """
    client = mlflow.tracking.MlflowClient()

    if run_id is None:
        # get_experiment_by_name returns None for an unknown experiment; treat that
        # the same as "no runs" instead of failing on the attribute access.
        experiment = mlflow.get_experiment_by_name(experiment_name)
        if experiment is None:
            return {}
        runs = client.search_runs(
            experiment_ids=[experiment.experiment_id],
            max_results=1,  # only the latest run is needed
        )
        if not runs:
            return {}
        run_id = runs[0].info.run_id

    # run.data.metrics only holds the latest value per key; it is used here just to
    # find out which of the requested keys exist before fetching their full history.
    logged_metrics = client.get_run(run_id).data.metrics

    return {
        key: [(m.step, m.value) for m in client.get_metric_history(run_id, key)]
        for key in metric_keys
        if key in logged_metrics
    }

# Create visualizations
metrics = get_training_metrics_from_mlflow()

# One entry per subplot, in row-major order:
# (title, y-axis label, [(metric key, legend label or None), ...])
panel_specs = [
    ('Training and KD Loss', 'Loss',
     [('loss', 'Training Loss'), ('kd_loss', 'KD Loss')]),
    ('Model Accuracy', 'Accuracy',
     [('accuracy', None)]),
    ('Error Correction Impact', 'Error Correction Improvement (%)',
     [('ec_improvement', None)]),
    ('Tree of Thoughts Quality Score', 'ToT Score',
     [('tot_score', None)]),
]

if metrics:
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    for ax, (title, ylabel, series) in zip(axes.flat, panel_specs):
        has_labelled_line = False
        for key, label in series:
            history = metrics.get(key)
            if not history:
                # Metric was never logged (or has no points): leave the panel empty.
                continue
            steps, values = zip(*history)
            ax.plot(steps, values, label=label)
            has_labelled_line = has_labelled_line or label is not None

        ax.set_xlabel('Training Steps')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        # Only draw a legend when something labelled was plotted; calling legend()
        # on an axes without labelled lines just emits a warning.
        if has_labelled_line:
            ax.legend()
        ax.grid(True)

    fig.tight_layout()
    fig.savefig('training_metrics.png')
    plt.show()
else:
    print("No metrics available to visualize.")