# VishwamAI Distillation Training

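This notebook distills knowledge from a large DeepSeek teacher model into a smaller VishwamAI student model. It downloads a partial teacher checkpoint, trains the student with distillation losses, optionally quantizes the result, and exports it to the Hugging Face Hub.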

In [None]:
# Install dependencies
!pip install -q --upgrade transformers datasets accelerate bitsandbytes sentencepiece \
    flax optax omegaconf huggingface-hub einops aim>=3.17.5


In [None]:
import os
import json
import logging
import gc
from pathlib import Path
from omegaconf import OmegaConf
import jax
import jax.numpy as jnp
from flax.training import train_state

# JAX runtime settings, applied immediately after importing jax:
#   - select TPU as the platform,
#   - use bfloat16 as the default matmul precision.
for _jax_flag, _jax_value in (
    ("jax_platform_name", "tpu"),
    ("jax_default_matmul_precision", "bfloat16"),
):
    jax.config.update(_jax_flag, _jax_value)
del _jax_flag, _jax_value  # keep the notebook namespace free of loop temporaries

from vishwamai.model import VishwamAIModel, ModelConfig
from vishwamai.tokenizer import VishwamAITokenizer
from vishwamai.distillation import VishwamaiGuruKnowledge, VishwamaiShaalaTrainer
from vishwamai.data_utils import create_train_dataloader, create_val_dataloader
from huggingface_hub import snapshot_download

In [None]:
from huggingface_hub import snapshot_download
import gc
import jax
import os

def download_partial_model(model_path: str, num_shards: int = 5, total_shards: int = 252):
    """Download only the first ``num_shards`` safetensors shards of a Hub model.

    Besides the selected weight shards, ``config.json`` and ``tokenizer.model`` are
    requested, so a very large sharded checkpoint can be partially materialised.

    Args:
        model_path: Hugging Face Hub repo id passed to ``snapshot_download``.
        num_shards: Number of leading shards to fetch (``model-00001-...`` onwards).
        total_shards: Total shard count encoded in the shard file names
            (``model-XXXXX-of-{total_shards:05d}.safetensors``). Defaults to 252,
            the value this function previously hard-coded.

    Returns:
        The local directory path returned by ``snapshot_download``.

    Raises:
        ValueError: If ``num_shards`` is outside ``0..total_shards`` (such shard names
            would not exist and would be silently skipped), or if the download fails.
    """
    if not 0 <= num_shards <= total_shards:
        raise ValueError(
            f"num_shards must be between 0 and {total_shards}, got {num_shards}"
        )

    # Reclaim host memory before a potentially large download.
    gc.collect()

    # Exact file names act as allow-list patterns for the first N shards.
    shard_files = [
        f"model-{index:05d}-of-{total_shards:05d}.safetensors"
        for index in range(1, num_shards + 1)
    ]
    allow_patterns = shard_files + ["config.json", "tokenizer.model"]

    try:
        # `resume_download` is omitted: recent huggingface_hub versions always resume
        # interrupted downloads and mark that argument as deprecated.
        local_path = snapshot_download(
            repo_id=model_path,
            allow_patterns=allow_patterns,
            local_files_only=False,
        )
    except Exception as e:
        # Chain the original error so its traceback is preserved.
        raise ValueError(f"Error downloading model shards: {str(e)}") from e

    print(f"Successfully downloaded {num_shards} model shards to {local_path}")
    return local_path

# Tunables for the setup below.
DISTILLATION_CONFIG_PATH = os.path.join(
    "vishwamai", "configs", "training", "perplexity_r1_distillation.yaml"
)
TEACHER_NUM_SHARDS = 5   # reduced from 15 to 5 for memory efficiency
TRAIN_BATCH_SIZE = 4     # reduced batch size for memory efficiency
MODEL_DTYPE = "bfloat16" # dtype written into both model configs for memory efficiency

# Load distillation configuration (teacher/student model configs + training settings).
distillation_config = OmegaConf.load(DISTILLATION_CONFIG_PATH)

# Fetch only the first TEACHER_NUM_SHARDS weight shards of the teacher checkpoint.
teacher_path = download_partial_model(
    distillation_config['distillation']['teacher_model']['path'],
    num_shards=TEACHER_NUM_SHARDS,
)

# Model configurations. These are nodes of `distillation_config`, so the dtype
# override below is also visible through the config object handed to the trainer.
teacher_config = distillation_config['distillation']['teacher_model']['config']
student_config = distillation_config['distillation']['student_model']['config']
teacher_config['dtype'] = MODEL_DTYPE
student_config['dtype'] = MODEL_DTYPE

# Build the models, collecting host garbage between the heavy steps.
# jax.clear_backends() is deliberately NOT called here: it is deprecated in recent
# JAX releases and resets the runtime clients that any already-created device arrays
# (e.g. teacher weights, if load_weights places them on device) depend on.
print("Initializing teacher model...")
teacher_model = VishwamAIModel(ModelConfig(**teacher_config))
gc.collect()

print("Loading teacher weights...")
teacher_model.load_weights(teacher_path)
gc.collect()

print("Initializing student model...")
student_model = VishwamAIModel(ModelConfig(**student_config))
gc.collect()

# Tokenizer sized to the teacher vocabulary.
print("Initializing tokenizer...")
tokenizer = VishwamAITokenizer(
    vocab_size=teacher_config["vocab_size"],
    model_prefix="vishwamai",
)

# Distillation trainer. The batch-size override is written into the shared config,
# so the data loaders created later from `distillation_config` also pick it up.
print("Initializing trainer...")
distillation_config['training']['batch_size'] = TRAIN_BATCH_SIZE
trainer = VishwamaiShaalaTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    cfg=distillation_config,
)

print("Setup complete!")

In [None]:
# Import Aim for experiment tracking
import aim

# Loop constants (previously inline magic numbers).
TEMPERATURE_DECAY_INTERVAL = 1000  # decay the distillation temperature every N steps
TEMPERATURE_DECAY_FACTOR = 0.95    # multiplicative decay applied per interval
MIN_TEMPERATURE = 1.0              # lower bound for the decayed temperature
NUM_CALIBRATION_STEPS = 100        # forwarded to trainer.quantize_model
CHECKPOINT_DIR = "checkpoints"

# Training schedule, read once instead of on every iteration.
training_cfg = distillation_config['training']
max_steps = training_cfg['max_steps']
logging_steps = training_cfg['logging_steps']
eval_steps = training_cfg['eval_steps']
save_steps = training_cfg['save_steps']

# Initialize Aim
aim_run = aim.Run(experiment="VishwamAI-Distillation")

# Create data loaders (batch size comes from distillation_config, set in the setup cell)
train_loader = create_train_dataloader(OmegaConf.create(distillation_config))
val_loader = create_val_dataloader(OmegaConf.create(distillation_config))

# Initialize training state
rng = jax.random.PRNGKey(42)
state = trainer.create_train_state(rng)

# Initialize guru knowledge
guru = VishwamaiGuruKnowledge(OmegaConf.create(distillation_config))


def _as_loggable(value):
    """Convert 0-d array values (JAX/NumPy) to Python floats; pass anything else through.

    Gives Aim tracking and ``:.4f`` formatting plain Python numbers.
    """
    if getattr(value, "ndim", None) == 0:
        return float(value)
    return value


def _save_step_checkpoint(train_state, path, step, step_metrics):
    """Save ``train_state`` plus guru metadata to ``path`` and register it with Aim."""
    trainer.save_checkpoint(
        state=train_state,
        path=path,
        guru=guru,
        metadata={
            'temperature': guru.temperature,
            'step': step,
            'metrics': step_metrics
        }
    )
    # NOTE(review): confirm `Run.track_artifact` exists in the installed Aim version.
    aim_run.track_artifact(path, name="checkpoints")


# Memory management: jax.clear_backends() is no longer called anywhere in this loop.
# It is deprecated and resets the JAX runtime clients, which can invalidate live device
# arrays such as `state`. gc.collect() runs only at logging/checkpoint boundaries
# rather than on every step.
try:
    # iter() returns iterators unchanged and creates a fresh iterator for re-iterable
    # loaders, so both kinds of loader work with next().
    train_iter = iter(train_loader)

    for step in range(max_steps):
        try:
            batch = next(train_iter)
        except StopIteration:
            # Loader exhausted before max_steps: start another pass over the data.
            train_iter = iter(train_loader)
            try:
                batch = next(train_iter)
            except StopIteration:
                raise RuntimeError(
                    f"train_loader produced no batches (exhausted at step {step})"
                ) from None

        # NOTE(review): a standalone teacher_model(...) forward pass used to run here,
        # but its outputs were never used. If train_step needs teacher outputs, pass
        # them in explicitly rather than computing and discarding them.

        # Training step with knowledge distillation
        state, metrics, rng = trainer.train_step(
            state=state,
            batch=batch,
            step=step,
            rng=rng
        )

        # Log metrics
        if step % logging_steps == 0:
            raw_metrics = {
                'kd_loss': metrics['kd_loss'],
                'feature_loss': metrics.get('feature_loss', 0.0),
                'attention_loss': metrics.get('attention_loss', 0.0),
                'hidden_loss': metrics.get('hidden_loss', 0.0),
                'total_loss': metrics['total_loss'],
                'temperature': guru.temperature
            }
            distill_metrics = {name: _as_loggable(value) for name, value in raw_metrics.items()}

            # Log to Aim
            for metric_name, metric_value in distill_metrics.items():
                aim_run.track(metric_value, name=metric_name, step=step)

            print(f"\nStep {step}:")
            print(f"KD Loss: {distill_metrics['kd_loss']:.4f}")
            print(f"Feature Loss: {distill_metrics['feature_loss']:.4f}")
            print(f"Total Loss: {distill_metrics['total_loss']:.4f}")

            gc.collect()

        if step % eval_steps == 0:
            eval_metrics = trainer.evaluate(
                state=state,
                val_loader=val_loader,
                teacher_model=teacher_model,
                guru=guru
            )

            # Log evaluation metrics
            for k, v in eval_metrics.items():
                aim_run.track(_as_loggable(v), name=f"eval_{k}", step=step)

        # Temperature decay once every TEMPERATURE_DECAY_INTERVAL completed steps
        # (skipping step 0, which previously decayed before any real training).
        if step > 0 and step % TEMPERATURE_DECAY_INTERVAL == 0:
            guru.temperature = max(
                MIN_TEMPERATURE, guru.temperature * TEMPERATURE_DECAY_FACTOR
            )

        # Periodic checkpoint
        if step % save_steps == 0:
            gc.collect()
            _save_step_checkpoint(state, f"{CHECKPOINT_DIR}/step_{step}", step, metrics)

    # The periodic save covers the last step only when (max_steps - 1) is a multiple of
    # save_steps, so always persist the final weights as well.
    if max_steps > 0:
        gc.collect()
        _save_step_checkpoint(state, f"{CHECKPOINT_DIR}/final", max_steps - 1, metrics)

    # Final quantization
    if distillation_config['distillation']['quantization']['enabled']:
        gc.collect()

        state = trainer.quantize_model(
            state=state,
            val_loader=val_loader,
            num_calibration_steps=NUM_CALIBRATION_STEPS
        )
        quantized_path = f"{CHECKPOINT_DIR}/quantized"
        trainer.save_checkpoint(state, quantized_path)
        aim_run.track_artifact(quantized_path, name="quantized_model")

except Exception as e:
    print(f"Training failed: {str(e)}")
    # Record the failure on the Aim run (Aim runs store params via item assignment).
    aim_run["error"] = str(e)
    # Re-raise so "Run All" stops here instead of continuing to the export cell with
    # missing or stale checkpoints.
    raise
finally:
    aim_run.close()
    gc.collect()

## Model Export

In [None]:
from huggingface_hub import HfApi
import os

# Free host memory before export. jax.clear_backends() is not called: it is
# deprecated in recent JAX releases and is not needed for a file upload.
gc.collect()

# Destination repository on the Hugging Face Hub
repo_id = "VishwamAI/VishwamAI-small"

# (local file, path inside the Hub repo)
# NOTE(review): nothing in this notebook writes configs/student_config.json or
# tokenizer/vishwamai.model, and checkpoints/quantized only exists when quantization
# is enabled. Confirm these files are produced before running this cell.
files_to_upload = [
    ("configs/student_config.json", "config.json"),
    ("checkpoints/quantized/model.safetensors", "model.safetensors"),
    ("tokenizer/vishwamai.model", "tokenizer.model")
]

# Fail fast before touching the Hub, so a missing artifact cannot leave the remote
# repo half-updated (e.g. new config but old weights).
missing_files = [local for local, _ in files_to_upload if not os.path.isfile(local)]
if missing_files:
    raise FileNotFoundError(f"Cannot export; missing local files: {missing_files}")

api = HfApi()

# Create repository (no-op if it already exists)
api.create_repo(repo_id, exist_ok=True)

# Upload one file at a time
for local_path, repo_path in files_to_upload:
    print(f"Uploading {local_path}...")
    api.upload_file(
        path_or_fileobj=local_path,
        path_in_repo=repo_path,
        repo_id=repo_id
    )
    gc.collect()

print("Export complete!")