<a href="https://colab.research.google.com/github/VishwamAI/VishwamAI/blob/main/train_vishwamai_distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VishwamAI Distillation Training

This notebook implements knowledge distillation from a DeepSeek-architecture teacher model (`perplexity-ai/r1-1776`) into a smaller VishwamAI student model.

In [1]:
# Clone the repository and change directory
!git clone https://github.com/VishwamAI/VishwamAI
%cd VishwamAI

fatal: destination path 'VishwamAI' already exists and is not an empty directory.
/content/VishwamAI


In [2]:
# Install dependencies
!pip install -q --upgrade transformers datasets accelerate bitsandbytes sentencepiece \
    flax optax omegaconf huggingface-hub einops aim>=3.17.5 safetensors


In [3]:
import os
import json
import logging
from pathlib import Path
from omegaconf import OmegaConf
import jax
import jax.numpy as jnp
from flax.training import train_state
import gc
import jax

from vishwamai.model import VishwamAIModel, ModelConfig
import safetensors
from vishwamai.tokenizer import VishwamAITokenizer
from vishwamai.distillation import VishwamaiGuruKnowledge, VishwamaiShaalaTrainer
from vishwamai.data_utils import create_train_dataloader, create_val_dataloader
from huggingface_hub import snapshot_download


## Configuration Setup

In [4]:
# Teacher model config (DeepSeek)
# Note: the Model Setup cell later rebinds `teacher_config` and `student_config`
# to the values found in `distillation_config`.
teacher_config = dict(
    vocab_size=129280,
    hidden_size=7168,
    num_layers=61,
    num_attention_heads=128,
    intermediate_size=18432,
    hidden_dropout_prob=0.1,
    attention_dropout_prob=0.1,
    max_position_embeddings=163840,
    initializer_range=0.02,
    layer_norm_eps=1e-5,
    use_cache=True,
    pad_token_id=0,
    bos_token_id=1,
    eos_token_id=2,
    tie_word_embeddings=True,
    use_flash_attention=True,
    use_rope=True,
    use_alibi=False,
    use_gqa=True,
    num_key_value_heads=128,
    dtype="bfloat16",
)

# Student model config (smaller VishwamAI): identical to the teacher except for
# the width/depth entries overridden here. Unpacking the teacher first keeps the
# keys in the same order as the teacher dict.
student_config = {
    **teacher_config,
    "hidden_size": 2048,  # Smaller hidden size
    "num_layers": 24,  # Fewer layers
    "num_attention_heads": 32,
    "intermediate_size": 8192,
    "num_key_value_heads": 32,
}

## Model Setup

In [6]:
import safetensors.flax as stf
import safetensors
def download_partial_model(model_path: str, num_shards: int = 15, total_shards: int = 252):
    """Download only the first ``num_shards`` weight shards of a Hub model.

    Args:
        model_path: Hugging Face Hub repo id passed to ``snapshot_download``.
        num_shards: Number of leading shards to fetch (shards 1..num_shards).
        total_shards: Total shard count used in the repo's file names
            (``model-XXXXX-of-YYYYY.safetensors``). Defaults to 252, the value
            that was previously hard-coded in the pattern.

    Returns:
        The local snapshot directory returned by ``snapshot_download``.

    Raises:
        ValueError: If the download fails; the original exception is attached
            as ``__cause__`` so the full traceback is preserved.
    """
    patterns = [
        f"model-{index:05d}-of-{total_shards:05d}.safetensors"
        for index in range(1, num_shards + 1)
    ]
    patterns.extend(["config.json", "tokenizer.model"])  # Add other required files

    try:
        # `resume_download` is no longer passed: it is deprecated in current
        # huggingface_hub releases, where interrupted downloads resume on their own.
        local_path = snapshot_download(
            repo_id=model_path,
            allow_patterns=patterns,
        )
    except Exception as e:
        raise ValueError(f"Error downloading model shards: {str(e)}") from e

    print(f"Successfully downloaded {num_shards} model shards to {local_path}")
    return local_path

# Load distillation configuration: use the YAML from the repo checkout when it
# exists, otherwise build a default config inline.
config_path = os.path.join("vishwamai", "configs", "training", "perplexity_r1_distillation.yaml")
if os.path.exists(config_path):
    distillation_config = OmegaConf.load(config_path)
else:
    # NOTE(review): this fallback only defines `distillation.teacher_model` and
    # `distillation.student_model`. The training cell also reads
    # distillation_config['training'][...] and
    # distillation_config['distillation']['quantization'], which are not set here.
    distillation_config = OmegaConf.create({
        'distillation': {
            'teacher_model': {
                'path': "perplexity-ai/r1-1776",
                'config': {
                    'hidden_size': 7168,
                    'intermediate_size': 18432,
                    'num_attention_heads': 128,
                    'num_layers': 61,
                    'num_key_value_heads': 128,
                    'vocab_size': 129280,
                    'max_position_embeddings': 163840,
                },
            },
            'student_model': {
                'path': "model-00001-to-00015-of-00252.safetensors",
                'config': {
                    'hidden_size': 2048,
                    'intermediate_size': 8192,
                    'num_attention_heads': 32,
                    'num_layers': 24,
                    'num_key_value_heads': 32,
                    'vocab_size': 129280,
                    'max_position_embeddings': 163840,
                },
            },
        },
    })
    print(f"Warning: Config file not found at {config_path}. Using default config.")

# Download partial teacher model (only 5 shards, reduced for memory efficiency)
teacher_path = download_partial_model(
    distillation_config['distillation']['teacher_model']['path'],
    num_shards=5,
)

# Model hyper-parameters come from the distillation config. This rebinds the
# `teacher_config` / `student_config` names assigned in the Configuration Setup cell.
teacher_config = distillation_config['distillation']['teacher_model']['config']
student_config = distillation_config['distillation']['student_model']['config']

# Teacher: build the model, then load the downloaded shards in reduced-size mode
teacher_model = VishwamAIModel(ModelConfig(**teacher_config))
teacher_model.load_weights(teacher_path, reduced_size=True)

# Student: built from its config only; no weights are loaded here
student_model = VishwamAIModel(ModelConfig(**student_config))

# Tokenizer (placeholder) sized to the teacher vocabulary
tokenizer = VishwamAITokenizer(
    vocab_size=teacher_config["vocab_size"],
    model_prefix="vishwamai",
)

# Trainer (placeholder) wiring teacher, student and config together
trainer = VishwamaiShaalaTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    cfg=distillation_config,
)

print("Setup completed successfully!")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

config.json:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

model-00004-of-00252.safetensors:   0%|          | 0.00/5.24G [00:00<?, ?B/s]

model-00001-of-00252.safetensors:   0%|          | 0.00/5.22G [00:00<?, ?B/s]

model-00003-of-00252.safetensors:   0%|          | 0.00/5.24G [00:00<?, ?B/s]

model-00002-of-00252.safetensors:   0%|          | 0.00/5.24G [00:00<?, ?B/s]

model-00005-of-00252.safetensors:   0%|          | 0.00/5.33G [00:00<?, ?B/s]

Successfully downloaded 5 model shards to /root/.cache/huggingface/hub/models--perplexity-ai--r1-1776/snapshots/1ae3222e162fe7dd8511b1f74f27e100e7b82d6a
Debug: Using safetensors.flax as stf: <module 'safetensors.flax' from '/usr/local/lib/python3.11/dist-packages/safetensors/flax.py'>
Loading reduced size model for memory constraints...
Loading model-00001-of-00252.safetensors...
Loading model-00002-of-00252.safetensors...
Loading model-00003-of-00252.safetensors...
Loading model-00004-of-00252.safetensors...
Loading model-00005-of-00252.safetensors...
Setup completed successfully!


## Training Loop

In [None]:
# Import Aim for experiment tracking
import aim

# Temperature schedule for the distillation "guru" (previously inline literals).
TEMPERATURE_DECAY_INTERVAL = 1000  # steps between temperature updates
TEMPERATURE_DECAY_FACTOR = 0.95    # multiplier applied at each update
MIN_TEMPERATURE = 1.0              # the temperature never drops below this
NUM_CALIBRATION_STEPS = 100        # passed to trainer.quantize_model
SEED = 42

# Initialize Aim
aim_run = aim.Run(experiment="VishwamAI-Distillation")

# Everything after the run is opened lives inside the try block so that the
# run is closed (finally) even when data loading or state creation fails.
try:
    # Create data loaders
    train_loader = create_train_dataloader(OmegaConf.create(distillation_config))
    val_loader = create_val_dataloader(OmegaConf.create(distillation_config))

    # Initialize training state
    rng = jax.random.PRNGKey(SEED)
    state = trainer.create_train_state(rng)

    # Initialize guru knowledge with feature matching
    guru = VishwamaiGuruKnowledge(OmegaConf.create(distillation_config))

    training_cfg = distillation_config['training']

    # Start training
    for step in range(training_cfg['max_steps']):
        batch = next(train_loader)

        # Get teacher predictions and features
        # NOTE(review): `teacher_outputs` is not read anywhere below and is not
        # passed to trainer.train_step -- confirm whether the trainer runs the
        # teacher itself, in which case this forward pass is redundant work.
        teacher_outputs = teacher_model(
            batch['input_ids'],
            attention_mask=batch['attention_mask'],
            output_hidden_states=True,
            output_attentions=True
        )

        # Training step with knowledge distillation
        state, metrics, rng = trainer.train_step(
            state=state,
            batch=batch,
            step=step,
            rng=rng
        )

        # Enhanced logging with distillation metrics
        if step % training_cfg['logging_steps'] == 0:
            distill_metrics = {
                'kd_loss': metrics['kd_loss'],
                'feature_loss': metrics.get('feature_loss', 0.0),
                'attention_loss': metrics.get('attention_loss', 0.0),
                'hidden_loss': metrics.get('hidden_loss', 0.0),
                'total_loss': metrics['total_loss'],
                'temperature': guru.temperature
            }

            # Log to Aim
            for metric_name, metric_value in distill_metrics.items():
                aim_run.track(metric_value, name=metric_name, step=step)

            # Print current distillation progress
            print(f"\nStep {step}:")
            print(f"KD Loss: {distill_metrics['kd_loss']:.4f}")
            print(f"Feature Loss: {distill_metrics['feature_loss']:.4f}")
            print(f"Total Loss: {distill_metrics['total_loss']:.4f}")

        # Evaluation with feature matching
        if step % training_cfg['eval_steps'] == 0:
            eval_metrics = trainer.evaluate(
                state=state,
                val_loader=val_loader,
                teacher_model=teacher_model,
                guru=guru
            )

            # Log evaluation metrics to Aim
            for k, v in eval_metrics.items():
                aim_run.track(v, name=f"eval_{k}", step=step)

        # Dynamic temperature adjustment (also fires at step 0, as before)
        if step % TEMPERATURE_DECAY_INTERVAL == 0:
            guru.temperature = max(MIN_TEMPERATURE, guru.temperature * TEMPERATURE_DECAY_FACTOR)

        # Save checkpoint
        if step % training_cfg['save_steps'] == 0:
            ckpt_path = f"checkpoints/step_{step}"
            trainer.save_checkpoint(
                state=state,
                path=ckpt_path,
                guru=guru,
                metadata={
                    'temperature': guru.temperature,
                    'step': step,
                    'metrics': metrics
                }
            )
            # Track checkpoint as artifact in Aim
            # NOTE(review): confirm `Run.track_artifact` exists in the installed
            # Aim version; the artifact methods I know of on `aim.Run` are
            # `log_artifact` / `log_artifacts` (which need an artifacts URI) -- TODO confirm.
            aim_run.track_artifact(ckpt_path, name="checkpoints")

    # Final quantization with teacher guidance
    if distillation_config['distillation']['quantization']['enabled']:
        state = trainer.quantize_model(
            state=state,
            val_loader=val_loader,
            num_calibration_steps=NUM_CALIBRATION_STEPS
        )
        trainer.save_checkpoint(state, "checkpoints/quantized")
        aim_run.track_artifact("checkpoints/quantized", name="quantized_model")

except Exception as e:
    print(f"Training failed: {str(e)}")
    # Record the failure on the run. Aim 3 stores run metadata through item
    # assignment; `set_params` belongs to the older Aim 2 `Session` API.
    aim_run["error"] = str(e)
    # Re-raise so the cell visibly fails with a traceback; otherwise the export
    # cells below would run against a missing or partially trained `state`.
    raise
finally:
    aim_run.close()




Downloading data:   0%|          | 0/1024 [00:00<?, ?files/s]

c4-train.00033-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00034-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00035-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00036-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00037-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00038-of-01024.json.gz:   0%|          | 0.00/318M [00:00<?, ?B/s]

c4-train.00039-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00040-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00041-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00042-of-01024.json.gz:   0%|          | 0.00/318M [00:00<?, ?B/s]

c4-train.00043-of-01024.json.gz:   0%|          | 0.00/318M [00:00<?, ?B/s]

c4-train.00044-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00045-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00046-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00047-of-01024.json.gz:   0%|          | 0.00/318M [00:00<?, ?B/s]

c4-train.00048-of-01024.json.gz:   0%|          | 0.00/318M [00:00<?, ?B/s]

c4-train.00049-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00050-of-01024.json.gz:   0%|          | 0.00/318M [00:00<?, ?B/s]

c4-train.00051-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00052-of-01024.json.gz:   0%|          | 0.00/321M [00:00<?, ?B/s]

c4-train.00053-of-01024.json.gz:   0%|          | 0.00/321M [00:00<?, ?B/s]

c4-train.00054-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00055-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00056-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00057-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00058-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00059-of-01024.json.gz:   0%|          | 0.00/318M [00:00<?, ?B/s]

c4-train.00060-of-01024.json.gz:   0%|          | 0.00/318M [00:00<?, ?B/s]

c4-train.00061-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00062-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00063-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00064-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00065-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00066-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00067-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00068-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00069-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00070-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00071-of-01024.json.gz:   0%|          | 0.00/318M [00:00<?, ?B/s]

c4-train.00072-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00073-of-01024.json.gz:   0%|          | 0.00/318M [00:00<?, ?B/s]

c4-train.00074-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00075-of-01024.json.gz:   0%|          | 0.00/321M [00:00<?, ?B/s]

c4-train.00076-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00077-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00078-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00079-of-01024.json.gz:   0%|          | 0.00/321M [00:00<?, ?B/s]

c4-train.00080-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00081-of-01024.json.gz:   0%|          | 0.00/318M [00:00<?, ?B/s]

c4-train.00082-of-01024.json.gz:   0%|          | 0.00/318M [00:00<?, ?B/s]

c4-train.00083-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00084-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00085-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00086-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00087-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00088-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00089-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00090-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00091-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00092-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00093-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00094-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00095-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00096-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00097-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00098-of-01024.json.gz:   0%|          | 0.00/321M [00:00<?, ?B/s]

c4-train.00099-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00100-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00101-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00102-of-01024.json.gz:   0%|          | 0.00/321M [00:00<?, ?B/s]

c4-train.00103-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00104-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00105-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00106-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00107-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00108-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00109-of-01024.json.gz:   0%|          | 0.00/318M [00:00<?, ?B/s]

c4-train.00110-of-01024.json.gz:   0%|          | 0.00/320M [00:00<?, ?B/s]

c4-train.00111-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00112-of-01024.json.gz:   0%|          | 0.00/319M [00:00<?, ?B/s]

c4-train.00113-of-01024.json.gz:   0%|          | 0.00/318M [00:00<?, ?B/s]

## Model Export

In [None]:
from huggingface_hub import HfApi

# Push distilled model to HuggingFace Hub
api = HfApi()

repo_id = "VishwamAI/VishwamAI-small"
api.create_repo(repo_id, exist_ok=True)

# Upload model files: (local file, file name inside the repo), sent in this order.
# NOTE(review): no visible cell writes configs/student_config.json or
# tokenizer/vishwamai.model, and checkpoints/quantized is only saved when
# quantization is enabled -- confirm these files exist before running.
files_to_upload = (
    ("configs/student_config.json", "config.json"),
    ("checkpoints/quantized/model.safetensors", "model.safetensors"),
    ("tokenizer/vishwamai.model", "tokenizer.model"),
)
for local_file, repo_file in files_to_upload:
    api.upload_file(
        path_or_fileobj=local_file,
        path_in_repo=repo_file,
        repo_id=repo_id,
    )

## Evaluation and Analysis

In [None]:
# Compare teacher and student model performance
def _collect_batch_metrics(model, batch):
    """Run ``model`` on one batch and return its loss/accuracy as plain floats."""
    output = model(batch['input_ids'], deterministic=True)
    return {
        'loss': float(output['loss']),
        # Accuracy is optional in the model output; it defaults to 0.
        'accuracy': float(output.get('accuracy', 0)),
    }


def _average_metrics(records):
    """Average every metric key over a non-empty list of per-batch dicts."""
    return {key: sum(r[key] for r in records) / len(records) for key in records[0]}


def evaluate_models(teacher, student, val_loader, num_batches=10):
    """Evaluate teacher and student on the same validation batches.

    Args:
        teacher: Callable model; called as ``teacher(input_ids, deterministic=True)``
            and expected to return a mapping with ``'loss'`` and optionally
            ``'accuracy'``. Must expose ``param_count``.
        student: Same contract as ``teacher``.
        val_loader: Iterator (or iterable) yielding batches with an
            ``'input_ids'`` entry.
        num_batches: Maximum number of batches to evaluate. Evaluation stops
            early, without error, if the loader runs out first.

    Returns:
        Dict with ``'teacher'`` and ``'student'`` average metrics and a
        ``'compression_ratio'`` string such as ``"4.00x"``.

    Raises:
        ValueError: If no batch could be evaluated (empty loader or
            ``num_batches <= 0``).
    """
    # iter() returns an iterator unchanged, so existing callers that pass an
    # iterator behave as before; plain iterables are accepted as well.
    batches = iter(val_loader)
    teacher_metrics = []
    student_metrics = []

    for _ in range(num_batches):
        try:
            batch = next(batches)
        except StopIteration:
            # Fewer batches than requested: average over what was seen.
            break

        teacher_metrics.append(_collect_batch_metrics(teacher, batch))
        student_metrics.append(_collect_batch_metrics(student, batch))

    if not teacher_metrics:
        raise ValueError("evaluate_models needs at least one validation batch")

    return {
        'teacher': _average_metrics(teacher_metrics),
        'student': _average_metrics(student_metrics),
        'compression_ratio': f"{teacher.param_count / student.param_count:.2f}x"
    }

# Run the comparison and print one summary block per model.
results = evaluate_models(teacher_model, student_model, val_loader)

print("\nModel Comparison:")
print(f"Compression Ratio: {results['compression_ratio']}")
for model_key, heading in (("teacher", "\nTeacher Model:"), ("student", "\nStudent Model:")):
    model_metrics = results[model_key]
    print(heading)
    print(f"Loss: {model_metrics['loss']:.4f}")
    print(f"Accuracy: {model_metrics['accuracy']:.4f}")

In [None]:
# Update model paths
TEACHER_MODEL_PATH = "perplexity-ai/r1-1776"
OUTPUT_MODEL_PATH = "VishwamAI/Perplexity_r1_disttled_experiment"

import os
from huggingface_hub import snapshot_download

# Download teacher model files.
# NOTE(review): unlike download_partial_model, the "*.safetensors" glob is not
# limited to the first few shards; the shard names used earlier
# ("model-XXXXX-of-00252") suggest this fetches all 252 weight files -- confirm
# that a full download is intended here.
teacher_file_patterns = ["*.safetensors", "config.json", "tokenizer.model"]
teacher_path = snapshot_download(
    repo_id=TEACHER_MODEL_PATH,
    allow_patterns=teacher_file_patterns,
)
print(f"Downloaded teacher model to {teacher_path}")

In [None]:
# Modified model saving logic
def save_sharded_model(state, save_dir, num_shards=15):
    """Save model weights in sharded safetensors format.

    The parameter tree in ``state.params`` is flattened to dotted names
    (``"layer.kernel"``), greedily packed into shards of roughly
    ``total_size / num_shards`` elements, and written as
    ``model-XXXXX-of-YYYYY.safetensors`` where ``YYYYY`` is the number of
    files actually written.

    Args:
        state: Object exposing ``params``, a (possibly nested) mapping whose
            leaves have a ``size`` attribute.
        save_dir: Output directory; created if missing.
        num_shards: Target number of shards. Greedy packing may produce a
            different count; file names always reflect the real count.

    Returns:
        List of written shard paths, in shard order.

    Raises:
        ValueError: If ``num_shards`` is not a positive integer.
    """
    import os
    from collections.abc import Mapping

    import safetensors.flax as stf

    if num_shards < 1:
        raise ValueError("num_shards must be a positive integer")

    def _flatten(tree, prefix=""):
        """Yield (dotted_name, leaf) pairs for every leaf of a nested mapping."""
        for key, value in tree.items():
            name = f"{prefix}.{key}" if prefix else str(key)
            if isinstance(value, Mapping):
                yield from _flatten(value, name)
            else:
                yield name, value

    # Flatten first: iterating only the top level would hand whole sub-trees
    # (which have no `.size`) to the size accounting and to save_file.
    flat_params = list(_flatten(state.params))

    # Calculate the element budget per shard (at least 1 so tiny models still pack).
    total_params = sum(param.size for _, param in flat_params)
    params_per_shard = max(1, total_params // num_shards)

    # Greedy packing. The `shard_dict` check prevents flushing an empty shard
    # when a single tensor is larger than the budget.
    shards = []
    shard_dict = {}
    current_size = 0
    for name, param in flat_params:
        if shard_dict and current_size + param.size > params_per_shard:
            shards.append(shard_dict)
            shard_dict = {}
            current_size = 0
        shard_dict[name] = param
        current_size += param.size
    if shard_dict:
        shards.append(shard_dict)

    # Write only after packing so the "-of-N" suffix matches the real count.
    os.makedirs(save_dir, exist_ok=True)
    shard_paths = []
    for index, tensors in enumerate(shards, start=1):
        shard_path = os.path.join(
            save_dir,
            f"model-{index:05d}-of-{len(shards):05d}.safetensors"
        )
        stf.save_file(tensors, shard_path)
        shard_paths.append(shard_path)

    return shard_paths

In [None]:
# Modified model export cell
from huggingface_hub import HfApi

# Push distilled model to HuggingFace Hub
api = HfApi()

# Write the trained parameters to disk as sharded safetensors files.
# `state` and OUTPUT_MODEL_PATH come from earlier cells (training / model paths).
save_dir = "checkpoints/final"
save_sharded_model(state, save_dir)

# Make sure the target repo exists on the Hub
api.create_repo(OUTPUT_MODEL_PATH, exist_ok=True)

# Configuration goes first
api.upload_file(
    path_or_fileobj="configs/student_config.json",
    path_in_repo="config.json",
    repo_id=OUTPUT_MODEL_PATH,
)

# Then every shard found in the save directory, in sorted name order
shard_files = sorted(
    entry for entry in os.listdir(save_dir) if entry.endswith(".safetensors")
)
for shard_file in shard_files:
    api.upload_file(
        path_or_fileobj=os.path.join(save_dir, shard_file),
        path_in_repo=shard_file,
        repo_id=OUTPUT_MODEL_PATH,
    )

# Tokenizer last
api.upload_file(
    path_or_fileobj="tokenizer/vishwamai.model",
    path_in_repo="tokenizer.model",
    repo_id=OUTPUT_MODEL_PATH,
)