# VishwamAI QwQ-32B Distillation Training

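This notebook distills the QwQ-32B teacher model into a smaller VishwamAI student model. It downloads a subset of the teacher's weight shards, trains the student with knowledge distillation, and saves the resulting student model and tokenizer.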

<a href="https://colab.research.google.com/github/VishwamAI/VishwamAI/blob/main/train_vishwamai_distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/VishwamAI/VishwamAI
%cd VishwamAI

In [None]:
# Install dependencies
!pip install -q --upgrade transformers datasets accelerate bitsandbytes sentencepiece \
    flax optax omegaconf huggingface-hub einops aim>=3.17.5 safetensors


In [None]:
# Verify installation and imports (stdlib -> third-party -> project)
import gc
import json
import logging
import os
from pathlib import Path

import jax
import jax.numpy as jnp
import safetensors
from flax.training import train_state
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf

# Project modules; presumably resolved from the cloned repo root that the `%cd VishwamAI`
# cell above switches into — confirm if running outside that directory.
from vishwamai.model import VishwamAIModel, ModelConfig
from vishwamai.tokenizer import VishwamAITokenizer
from vishwamai.distillation import VishwamaiGuruKnowledge, VishwamaiShaalaTrainer
from vishwamai.data_utils import create_train_dataloader, create_val_dataloader

print("All imports successful!")

In [None]:
# Load configuration
config = OmegaConf.load('configs/distillation_config.yaml')

# YAML keys that ModelConfig does not accept as constructor arguments.
UNSUPPORTED_MODEL_KEYS = {'rope_scaling', 'seq_length'}


def to_model_config_kwargs(section) -> dict:
    """Turn an OmegaConf model section into keyword arguments for ``ModelConfig``.

    Interpolations are resolved and nested nodes are converted to plain dicts/lists
    (via ``OmegaConf.to_container``), then every key in ``UNSUPPORTED_MODEL_KEYS`` is dropped.

    Args:
        section: An OmegaConf node holding the model's constructor settings.

    Returns:
        A plain ``dict`` suitable for ``ModelConfig(**kwargs)``.
    """
    plain = OmegaConf.to_container(section, resolve=True)
    return {k: v for k, v in plain.items() if k not in UNSUPPORTED_MODEL_KEYS}


# Initialize teacher and student configs. The same filter is applied to both because
# both are passed to the same ModelConfig constructor.
teacher_model_config = to_model_config_kwargs(config.distillation.teacher_model.config)
teacher_config = ModelConfig(**teacher_model_config)
student_config = ModelConfig(**to_model_config_kwargs(config.distillation.student_model))

print("Teacher config:\n", OmegaConf.to_yaml(teacher_config))
print("\nStudent config:\n", OmegaConf.to_yaml(student_config))

In [None]:
# Download QwQ-32B model shards
def download_model_shards(model_path: str, num_shards: int = 5, total_shards=None,
                          extra_files=("config.json", "tokenizer.model")):
    """Download the first ``num_shards`` safetensors shards of a Hub repo plus a few small files.

    Shards are matched by the ``model-XXXXX-of-YYYYY.safetensors`` naming scheme.

    Args:
        model_path: Hugging Face Hub repo id to download from.
        num_shards: How many leading shards to fetch; must be >= 1.
        total_shards: Total shard count that appears in the file names (the ``YYYYY`` part).
            ``None`` (default) matches any total with a wildcard, so the patterns work
            regardless of how many shards the repository is split into.
        extra_files: Additional file names / glob patterns fetched alongside the shards.

    Returns:
        The local snapshot directory returned by ``snapshot_download``.

    Raises:
        ValueError: If ``num_shards`` < 1, or it exceeds ``total_shards`` when that is given.
    """
    if num_shards < 1:
        raise ValueError(f"num_shards must be >= 1, got {num_shards}")
    if total_shards is not None and num_shards > total_shards:
        raise ValueError(
            f"num_shards ({num_shards}) exceeds total_shards ({total_shards})"
        )

    # A hard-coded total only matches repos with exactly that many shards; anything
    # else would silently download no weights, hence the wildcard default.
    total_part = "*" if total_shards is None else f"{total_shards:05d}"
    patterns = [
        f"model-{index:05d}-of-{total_part}.safetensors"
        for index in range(1, num_shards + 1)
    ]
    patterns.extend(extra_files)

    # `resume_download` is not passed: current huggingface_hub resumes interrupted
    # downloads automatically and has deprecated that argument.
    return snapshot_download(
        repo_id=model_path,
        allow_patterns=patterns,
        local_files_only=False,
    )

# Only the leading shards are fetched to keep disk and memory use down.
# NOTE(review): this yields a partial checkpoint; confirm load_weights(reduced_size=True)
# is designed for exactly this subset.
teacher_repo_id = config.distillation.teacher_model.path
teacher_path = download_model_shards(teacher_repo_id, num_shards=5)
print(f"Downloaded QwQ-32B model to {teacher_path}")

In [None]:
# Initialize models and tokenizer.
# Both models are built from the ModelConfig objects created in the configuration cell.
teacher_model = VishwamAIModel(teacher_config)
student_model = VishwamAIModel(student_config)

# Only the teacher receives pretrained weights; the student starts from its own initialization.
print("Loading teacher weights...")
teacher_model.load_weights(teacher_path, reduced_size=True)

# NOTE(review): the tokenizer is constructed from a vocab size and prefix rather than loaded
# from `teacher_path`; confirm it matches the teacher's vocabulary.
print("\nInitializing tokenizer...")
tokenizer = VishwamAITokenizer(vocab_size=teacher_config.vocab_size, model_prefix="vishwamai")

print("Models and tokenizer initialized successfully!")

In [None]:
# Initialize the distillation trainer (teacher -> student) and its training state.
SEED = 42

trainer = VishwamaiShaalaTrainer(teacher_model=teacher_model, student_model=student_model, cfg=config)

# Fixed PRNG seed for reproducibility; `rng` is split again for every training step.
rng = jax.random.PRNGKey(SEED)
state = trainer.create_train_state(rng)

print("Trainer and training state initialized!")

In [None]:
# Create data loaders.
# create_train_dataloader / create_val_dataloader are imported once in the imports cell;
# both are built from the same loaded `config`.
print("Creating data loaders...")
train_loader = create_train_dataloader(config)
val_loader = create_val_dataloader(config)
print("Data loaders created successfully!")

In [None]:
# Training loop
from tqdm.auto import tqdm

num_epochs = 10
CHECKPOINT_EVERY = 2  # save a checkpoint every N epochs
GC_EVERY = 10         # run Python garbage collection every N steps
steps_per_epoch = len(train_loader)

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    # Take an iterator per epoch: a re-iterable loader restarts its data, while an iterator
    # (e.g. an endless stream) returns itself from iter() and simply continues. Calling next()
    # on the loader directly would fail for loaders that are iterable but not iterators.
    epoch_batches = iter(train_loader)

    with tqdm(total=steps_per_epoch) as pbar:
        for step in range(steps_per_epoch):
            # Get batch and train
            batch = next(epoch_batches, None)
            if batch is None:
                raise RuntimeError(
                    f"train_loader ran out of batches at step {step} of epoch {epoch+1}; "
                    f"expected {steps_per_epoch} steps per epoch"
                )
            rng, train_rng = jax.random.split(rng)

            # Training step: returns the step outputs and the updated (functional) state
            outputs, state = trainer.train_step(state, batch, train_rng)

            # Update progress
            metrics = outputs['metrics']
            pbar.update(1)
            pbar.set_postfix({
                'loss': f"{outputs['loss']:.4f}",
                'kd_loss': f"{metrics['kd_loss']:.4f}",
                'correction_rate': f"{metrics.get('error_correction_rate', 0.0):.4f}"
            })

            # Memory management (`gc` is imported in the imports cell)
            if step % GC_EVERY == 0:
                gc.collect()

    # Validation (after the epoch's progress bar has closed)
    val_metrics = trainer.eval_step(state, val_loader)
    print(f"\nValidation metrics: {val_metrics}")

    # Save checkpoint
    if (epoch + 1) % CHECKPOINT_EVERY == 0:
        # NOTE(review): `state` is not passed here; confirm save_checkpoint has access to the
        # latest training state rather than the one from create_train_state.
        trainer.save_checkpoint(f"checkpoint_epoch_{epoch+1}")

print("Training completed!")

In [None]:
# Save final model: the distilled student and the tokenizer are given the same output path.
final_save_path = "final_distilled_model"

student_model.save_pretrained(final_save_path)  # the student only; the teacher is not saved
tokenizer.save(final_save_path)

print(f"Saved distilled model to {final_save_path}")