# VishwamAI Enhanced Distillation with ToT and Error Correction

This notebook implements teacher-student knowledge distillation, loading the teacher model's weights from the safetensors format for efficient model handling.

In [None]:
import os
import json
import logging
from pathlib import Path

# JAX related imports
import jax
import jax.numpy as jnp
import optax

# Data processing and visualization
import numpy as np
import matplotlib.pyplot as plt
import yaml
from tqdm.auto import tqdm

# Model loading and weights handling
from huggingface_hub import login, snapshot_download
import safetensors.flax as stf
import safetensors
from omegaconf import OmegaConf

# Configure root logging once for the whole notebook (INFO and above, timestamped),
# then grab the named logger every later cell writes to.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger("vishwamai_distillation")

# Import VishwamAI components
from vishwamai.model import VishwamAIModel, ModelConfig
from vishwamai.tokenizer import VishwamAITokenizer
from vishwamai.tot import TreeOfThoughts
from vishwamai.integration import ToTIntegrationLayer
from vishwamai.training.distillation import DistillationTrainer
from vishwamai.training.data import create_train_dataloader, create_val_dataloader
from vishwamai.training.utils import set_seed, setup_training, get_training_args

# Report which devices JAX can see, then fix the project-wide seed.
# NOTE(review): set_seed comes from vishwamai.training.utils; which RNGs it seeds
# (Python / NumPy / a JAX PRNG key) is not visible here - confirm there.
print(
    f"JAX devices: {jax.devices()}"
)
set_seed(42)

In [None]:
def load_safetensor_weights(model_path: Path, memory_limit_gb: float = 4.0):
    """Load model weights from the ``*.safetensors`` shards in ``model_path``.

    Shards are read in sorted file-name order and merged into one flat dict.
    Loading stops before the first shard that would push the running total
    past ``memory_limit_gb``. The budget is checked against each shard's size
    on disk *before* the shard is read, so an over-budget shard is never
    loaded into memory. On-disk size includes the small safetensors header,
    so the estimate is slightly conservative.

    Args:
        model_path: Directory containing the shard files (a ``str`` is accepted too).
        memory_limit_gb: Budget for the loaded shards, in GiB (1024**3 bytes).

    Returns:
        Dict mapping tensor names to the arrays returned by
        ``safetensors.flax.load_file``. If the budget is hit, only the shards
        loaded so far are included and a warning is logged.

    Raises:
        ValueError: If no ``*.safetensors`` file exists in ``model_path``, or if
            the first shard alone exceeds the memory budget. Returning an empty
            dict in that case would silently produce an unloaded model.
    """
    model_path = Path(model_path)
    memory_limit = memory_limit_gb * 1024 ** 3

    shard_files = sorted(model_path.glob("*.safetensors"))
    if not shard_files:
        raise ValueError(f"No safetensors files found in {model_path}")

    weights = {}
    loaded_bytes = 0
    loaded_shards = 0
    for shard_file in shard_files:
        # Check the budget using the file size *before* reading the shard.
        # Loading it first and then discarding it would defeat the limit.
        shard_size = shard_file.stat().st_size
        if loaded_bytes + shard_size > memory_limit:
            if loaded_shards == 0:
                raise ValueError(
                    f"Shard {shard_file.name} ({shard_size} bytes) exceeds the "
                    f"memory limit of {memory_limit_gb} GB; increase memory_limit_gb"
                )
            logger.warning("Memory limit reached, skipping remaining shards")
            break

        logger.info(f"Loading shard: {shard_file.name}")
        weights.update(stf.load_file(str(shard_file)))
        loaded_bytes += shard_size
        loaded_shards += 1

    logger.info(
        f"Loaded {loaded_shards}/{len(shard_files)} shards "
        f"({loaded_bytes / 1024 ** 3:.2f} GB)"
    )
    return weights

def download_partial_model(model_path: str, num_shards: int = 15):
    """Download model shards in safetensors format"""
    patterns = [f"model-{i+1:05d}-of-00252.safetensors" for i in range(num_shards)]
    patterns.extend(["config.json", "tokenizer.model", "tokenizer_config.json"])
    
    try:
        local_path = snapshot_download(
            repo_id=model_path,
            allow_patterns=patterns,
            local_files_only=False,
            resume_download=True
        )
        logger.info(f"Downloaded model to: {local_path}")
        return local_path
    except Exception as e:
        raise ValueError(f"Error downloading model: {str(e)}")

# Configuration
TEACHER_MODEL_ID = "perplexity-ai/r1-1776"

# Prefer the HF_TOKEN environment variable so a real token never has to be
# pasted into (and committed with) the notebook. The placeholder is only a
# fallback and is rejected explicitly below.
_HF_TOKEN_PLACEHOLDER = "your_token_here"
HF_TOKEN = os.environ.get("HF_TOKEN", _HF_TOKEN_PLACEHOLDER)  # Or replace the placeholder with an actual token

# Fail fast with an actionable message instead of sending the placeholder to the Hub.
if HF_TOKEN == _HF_TOKEN_PLACEHOLDER:
    raise ValueError(
        "No Hugging Face token configured: set the HF_TOKEN environment variable "
        "or replace the placeholder value of HF_TOKEN."
    )

# Login to Hugging Face
try:
    login(token=HF_TOKEN)
    print("Successfully logged in to Hugging Face Hub")
except Exception as e:
    print(f"Error logging in: {str(e)}\nPlease check your token and try again.")
    raise

In [None]:
# Load and validate configuration
config_path = Path("vishwamai/configs/training/perplexity_r1_distillation.yaml")
if not config_path.exists():
    raise FileNotFoundError(f"Configuration file not found: {config_path}")

config = OmegaConf.load(config_path)

# Get training arguments and run the project's training setup hook.
# NOTE(review): `training_setup` is not referenced elsewhere in the visible
# notebook; it is kept in case setup_training() has side effects - confirm.
training_args = get_training_args(config)
training_setup = setup_training(config)

# Initialize models
try:
    # Fetch the first few teacher shards (plus config/tokenizer files) and load them.
    teacher_path = Path(download_partial_model(TEACHER_MODEL_ID, num_shards=5))
    teacher_weights = load_safetensor_weights(teacher_path)

    # Convert the OmegaConf sub-tree to plain Python containers, resolving any
    # ${...} interpolations, so ModelConfig receives ordinary dicts/lists
    # rather than DictConfig/ListConfig objects.
    teacher_config = OmegaConf.to_container(
        config.distillation.teacher_model.config, resolve=True
    )
    teacher_model = VishwamAIModel(ModelConfig(**teacher_config))
    teacher_model.load_weights(teacher_weights)
    # Drop this notebook-global reference to the raw weight dict. If
    # load_weights copied the arrays, the host-side copy can now be freed
    # instead of living for the rest of the session.
    del teacher_weights

    student_config = OmegaConf.to_container(
        config.distillation.student_model.config, resolve=True
    )
    student_model = VishwamAIModel(ModelConfig(**student_config))
    tokenizer = VishwamAITokenizer.from_pretrained(str(teacher_path))

    print("Models initialized successfully")

except Exception as e:
    logger.error(f"Error initializing models: {str(e)}")
    raise

In [None]:
# Build the data loaders and the distillation trainer, then make sure the
# checkpoint directory exists. Any failure is logged and re-raised.
try:
    # Only the training split is shuffled.
    train_loader = create_train_dataloader(config, tokenizer=tokenizer, shuffle=True)
    val_loader = create_val_dataloader(config, tokenizer=tokenizer)

    # The trainer receives both models, both loaders, the tokenizer and the config.
    trainer = DistillationTrainer(
        teacher_model=teacher_model,
        student_model=student_model,
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        tokenizer=tokenizer,
        config=config,
        training_args=training_args,
        use_safetensors=True,
    )

    # Create the checkpoint directory (and any missing parents); no-op if it exists.
    checkpoint_dir = Path(config.training.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
except Exception as e:
    logger.error(f"Error initializing training components: {str(e)}")
    raise
else:
    print("Training components initialized successfully")

In [None]:
# Training loop with proper error handling.
# The trainer state is always saved afterwards; the directory name reflects the
# outcome, so a successful run is not mislabelled as "interrupted".
training_succeeded = False
try:
    logger.info("Starting training")
    metrics = trainer.train(
        num_steps=config.training.max_steps,
        checkpoint_dir=checkpoint_dir,
        save_steps=config.training.save_steps,
        eval_steps=config.training.eval_steps,
    )
    training_succeeded = True
    print("Training completed successfully!")

except Exception as e:
    logger.error(f"Training failed: {str(e)}")
    raise
finally:
    # Runs on success, on errors and on KeyboardInterrupt (e.g. a stopped cell).
    state_dir = checkpoint_dir / ("final_state" if training_succeeded else "interrupted_state")
    try:
        trainer.save_checkpoint(state_dir)
    except Exception as save_err:
        # Best effort: a failed save must not mask the original training error.
        # `except Exception` (not a bare except) lets Ctrl-C / SystemExit through.
        logger.warning(f"Failed to save {state_dir.name}: {save_err}")

In [None]:
# Plot and save training results: losses on the left, learning rate on the right.
fig, (loss_ax, lr_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Each series is plotted against its own length, so series with different
# numbers of entries do not raise a length-mismatch error. The x value is the
# index of the logged entry.
for metric_key, metric_label in (('loss', 'Total Loss'), ('distill_loss', 'Distillation Loss')):
    values = metrics[metric_key]
    loss_ax.plot(range(len(values)), values, label=metric_label)
loss_ax.set_xlabel('Steps')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training Losses')
loss_ax.legend()
loss_ax.grid(True)

lr_values = metrics['learning_rate']
lr_ax.plot(range(len(lr_values)), lr_values)
lr_ax.set_xlabel('Steps')
lr_ax.set_ylabel('Learning Rate')
lr_ax.set_title('Learning Rate Schedule')
lr_ax.grid(True)

fig.tight_layout()
fig.savefig(checkpoint_dir / 'training_metrics.png')
plt.show()

# Save detailed metrics and configuration
def _json_default(obj):
    """Encode values that ``json`` cannot serialize natively.

    Objects exposing ``tolist()`` (e.g. NumPy/JAX arrays and scalars) become
    plain lists/numbers. Anything else falls back to ``str(obj)``, so one
    unexpected type does not lose the whole record at the end of a long run.
    """
    if hasattr(obj, "tolist"):
        return obj.tolist()
    return str(obj)


metrics_path = checkpoint_dir / 'training_metrics.json'
# Serialize fully in memory first. Streaming json.dump into an open file would
# leave a truncated file behind if encoding failed partway through.
metrics_payload = json.dumps(
    {
        'metrics': metrics,
        # resolve=True records the actual values used, not raw ${...} interpolations.
        'config': OmegaConf.to_container(config, resolve=True),
        'training_args': training_args,
    },
    indent=2,
    default=_json_default,
)
metrics_path.write_text(metrics_payload, encoding="utf-8")

# Print training summary. Guard against empty metric series (e.g. when no step
# was logged), which would otherwise raise IndexError on [-1].
print("\nTraining Summary:")
for summary_label, summary_key in (("Final loss", "loss"), ("Final distillation loss", "distill_loss")):
    series = metrics[summary_key]
    if len(series):
        print(f"{summary_label}: {series[-1]:.4f}")
    else:
        print(f"{summary_label}: n/a (no values logged)")
print(f"\nResults saved to: {checkpoint_dir}")