In [None]:
# Clone and setup repository
!git clone https://github.com/VishwamAI/VishwamAI
%cd VishwamAI

# Install dependencies
!pip install -q --upgrade transformers datasets accelerate bitsandbytes sentencepiece \
    flax optax omegaconf huggingface-hub einops aim>=3.17.5 safetensors

In [None]:
import os
import json
import logging
from pathlib import Path
from omegaconf import OmegaConf
import jax
import jax.numpy as jnp
from flax.training import train_state
import gc

from vishwamai.model import VishwamAIModel, ModelConfig
from vishwamai.tokenizer import VishwamAITokenizer
from vishwamai.distillation import VishwamaiGuruKnowledge, VishwamaiShaalaTrainer
from vishwamai.data_utils import create_train_dataloader, create_val_dataloader
from huggingface_hub import snapshot_download

logging.basicConfig(level=logging.INFO)

In [None]:
# Hyper-parameters common to teacher and student. Defining them once keeps the
# two configurations from drifting apart; in particular both models share the
# same vocabulary size.
_shared_model_kwargs = dict(
    vocab_size=151936,
    hidden_dropout_prob=0.1,
    attention_dropout_prob=0.1,
    max_position_embeddings=2048,
    initializer_range=0.02,
    layer_norm_eps=1e-5,
    use_cache=True,
    use_flash_attention=True,
    use_rope=True,
    use_gqa=True,
    num_key_value_heads=8,
    dtype="bfloat16",
)

# Teacher: the large model whose outputs the student learns from.
# NOTE(review): the published QwQ-32B config.json appears to list
# hidden_size=5120, 64 layers, 40 attention heads, intermediate_size=27648 and
# vocab_size=152064; confirm these values against what
# `load_weights(..., reduced_size=True)` expects.
teacher_config = ModelConfig(
    hidden_size=7168,
    num_layers=60,
    num_attention_heads=56,
    intermediate_size=28672,
    **_shared_model_kwargs,
)

# Student: the smaller model trained by distillation.
student_config = ModelConfig(
    hidden_size=2048,
    num_layers=24,
    num_attention_heads=32,
    intermediate_size=8192,
    **_shared_model_kwargs,
)

In [None]:
# Download and initialize QwQ-32B teacher model
def download_model(model_path: str, num_shards: int = 5):
    """Download config, tokenizer files and the first ``num_shards`` weight shards.

    Args:
        model_path: Hugging Face Hub repository id, e.g. ``"Qwen/QwQ-32B"``.
        num_shards: How many leading ``model-XXXXX-of-*.safetensors`` shards to
            fetch (0 fetches only config/tokenizer files).

    Returns:
        Local path of the downloaded snapshot directory.
    """
    # Shard files are named ``model-XXXXX-of-YYYYY.safetensors``. The total
    # (YYYYY) differs per repository, so match it with a wildcard instead of a
    # hard-coded count; a wrong total silently matches no files at all.
    patterns = [f"model-{i:05d}-of-*.safetensors" for i in range(1, num_shards + 1)]
    # Config, weight index and tokenizer files. Qwen-family repos typically
    # ship a BPE tokenizer (tokenizer.json / vocab.json / merges.txt) rather
    # than a SentencePiece tokenizer.model; request both layouts.
    patterns += [
        "config.json",
        "model.safetensors.index.json",
        "tokenizer.model",
        "tokenizer.json",
        "tokenizer_config.json",
        "vocab.json",
        "merges.txt",
    ]
    # `resume_download` is deprecated in recent huggingface_hub releases,
    # which resume interrupted downloads automatically.
    return snapshot_download(
        repo_id=model_path,
        allow_patterns=patterns,
        local_files_only=False,
    )

teacher_path = download_model("Qwen/QwQ-32B", num_shards=5)
print(f"Downloaded teacher model to {teacher_path}")

In [None]:
# Initialize models, tokenizer and trainer
# Teacher: build the architecture from `teacher_config`, then load the
# checkpoint files downloaded into `teacher_path`.
# NOTE(review): only the first few shards were downloaded; presumably
# `reduced_size=True` tells load_weights to tolerate that — confirm in
# vishwamai.model.
teacher_model = VishwamAIModel(teacher_config)
teacher_model.load_weights(teacher_path, reduced_size=True)

# Student: no pretrained weights are loaded here; it is trained by the
# distillation loop below.
student_model = VishwamAIModel(student_config)

# Tokenizer sized to the teacher's vocabulary (the student config uses the same
# vocab_size), presumably so token ids index both models' logits identically.
tokenizer = VishwamAITokenizer(
    vocab_size=teacher_config.vocab_size,
    model_prefix="vishwamai"
)

# Load distillation config
# Relative path: assumes the working directory is the cloned repository root.
distillation_config = OmegaConf.load('configs/distillation_config.yaml')

# Trainer wiring teacher and student together under the distillation config.
trainer = VishwamaiShaalaTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    cfg=distillation_config
)

print("Initialized models and trainer successfully!")

In [None]:
# Create data loaders
train_loader = create_train_dataloader(distillation_config)
val_loader = create_val_dataloader(distillation_config)

# Initialize training state.
SEED = 42
rng = jax.random.PRNGKey(SEED)
# JAX PRNG keys must never be consumed twice. The training loop keeps splitting
# `rng`, so parameter initialisation gets its own subkey instead of `rng` itself.
rng, init_rng = jax.random.split(rng)
state = trainer.create_train_state(init_rng)

print("Created data loaders and training state!")

In [None]:
# Training loop
from tqdm.auto import tqdm

num_epochs = 10
steps_per_epoch = len(train_loader)

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    # `iter()` yields a fresh iterator for re-iterable loaders (anything with
    # `len()`, like a DataLoader) and returns the object itself for loaders
    # that are already iterators, so `next()` below works for both. Calling
    # `next()` directly on a plain iterable raises TypeError.
    batch_iter = iter(train_loader)

    with tqdm(total=steps_per_epoch, desc=f"epoch {epoch+1}") as pbar:
        for step in range(steps_per_epoch):
            batch = next(batch_iter)
            # Fresh dropout/noise key for every step.
            rng, train_rng = jax.random.split(rng)

            outputs, state = trainer.train_step(state, batch, train_rng)

            pbar.update(1)
            # float() turns 0-d device arrays into Python scalars before formatting.
            pbar.set_postfix({
                'loss': f"{float(outputs['loss']):.4f}",
                'kd_loss': f"{float(outputs['metrics']['kd_loss']):.4f}"
            })

            # Periodically release host-side garbage.
            if step % 10 == 0:
                gc.collect()

    # Validation and checkpointing run after the progress bar has been closed.
    val_metrics = trainer.eval_step(state, val_loader)
    print(f"Validation metrics: {val_metrics}")

    if (epoch + 1) % 2 == 0:
        # NOTE(review): `state` is not passed here; confirm save_checkpoint
        # persists the latest trained parameters.
        trainer.save_checkpoint(f"checkpoint_epoch_{epoch+1}")

print("Training completed!")

In [None]:
# Save distilled model to Hugging Face Hub
from huggingface_hub import HfApi

def save_to_hub(model, tokenizer, config, repo_id, token=None, params=None):
    """Save the distilled model to the Hugging Face Hub.

    Creates ``repo_id`` (public) if needed, writes ``config.json``, the
    tokenizer files and ``model.safetensors`` into the local
    ``distilled_model`` directory, then uploads that directory.

    Args:
        model: Student model. Not read by this function (the weights come from
            ``params``); kept for call compatibility.
        tokenizer: Tokenizer exposing ``save(directory)``.
        config: Model config whose attributes are serialised into ``config.json``.
        repo_id: Target Hub repository id.
        token: Hugging Face access token, or ``None`` for huggingface_hub's
            default token lookup.
        params: Parameter pytree to save. Defaults to the notebook-global
            ``state.params`` produced by the training loop.

    Returns:
        None. If the repository cannot be created, the error is printed and
        nothing is written or uploaded.
    """
    # Imported here: `stf` is otherwise only bound by a later cell, so calling
    # this function in a top-to-bottom run raised NameError.
    import safetensors.flax as stf
    from flax.traverse_util import flatten_dict

    api = HfApi(token=token)

    # Create repo if it doesn't exist
    try:
        api.create_repo(repo_id, private=False, exist_ok=True)
    except Exception as e:
        print(f"Error creating repo: {e}")
        return

    # Save model files locally first
    tmp_dir = "distilled_model"
    os.makedirs(tmp_dir, exist_ok=True)

    # Save config
    config_dict = {
        "model_type": "VishwamAI",
        "architectures": ["VishwamAIModel"],
        **config.__dict__
    }
    with open(f"{tmp_dir}/config.json", 'w') as f:
        # default=str keeps the dump from failing if the config holds non-JSON values.
        json.dump(config_dict, f, indent=2, default=str)

    # Save tokenizer
    tokenizer.save(tmp_dir)

    # Save model weights in safetensors format. save_file needs a flat
    # {name: array} mapping, while flax parameters are nested dicts, so
    # flatten them into dotted names (already-flat dicts pass through unchanged).
    if params is None:
        params = state.params
    weights_file = f"{tmp_dir}/model.safetensors"
    stf.save_file(flatten_dict(params, sep="."), weights_file)

    # Upload to Hub
    api.upload_folder(
        folder_path=tmp_dir,
        repo_id=repo_id,
        commit_message="Upload distilled model"
    )

    print(f"Successfully uploaded model to {repo_id}")

# Save model after training
repo_id = "VishwamAI/qwq-32b-distilled"  # Change this to your desired repo name
# Never hard-code credentials in a notebook: read the token from the
# environment. When HF_TOKEN is unset this is None, and huggingface_hub falls
# back to the token stored by `huggingface-cli login`.
hf_token = os.environ.get("HF_TOKEN")

save_to_hub(
    model=student_model,
    tokenizer=tokenizer,
    config=student_config,
    repo_id=repo_id,
    token=hf_token
)

In [None]:
# Create model card
# `tmp_dir` and `api` are local variables of `save_to_hub`, not notebook
# globals, so recreate them here (reusing `hf_token` and `repo_id` from the
# save cell) instead of raising NameError.
tmp_dir = "distilled_model"
os.makedirs(tmp_dir, exist_ok=True)
api = HfApi(token=hf_token)

# The string starts directly with `---`: the Hub only parses the YAML metadata
# block when it is the very first line of README.md.
model_card = """---
language:
- en
tags:
- distillation
- qwq-32b
- vishwamai
license: apache-2.0
---

# QwQ-32B Distilled Model

This is a distilled version of the QwQ-32B model. The model has been trained using knowledge distillation to compress the knowledge from the larger teacher model into a smaller, more efficient student model.

## Model Details
- Teacher Model: QwQ-32B
- Student Architecture: {student_arch}
- Vocabulary Size: {vocab_size}
- Training Data: Various high-quality datasets

## Usage
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("{repo_id}")
tokenizer = AutoTokenizer.from_pretrained("{repo_id}")
```
""".format(
    student_arch=f"Hidden Size: {student_config.hidden_size}, Layers: {student_config.num_layers}",
    vocab_size=student_config.vocab_size,
    repo_id=repo_id
)

with open(f"{tmp_dir}/README.md", 'w') as f:
    f.write(model_card)

# Upload model card
api.upload_file(
    path_or_fileobj=f"{tmp_dir}/README.md",
    path_in_repo="README.md",
    repo_id=repo_id,
    commit_message="Add model card"
)

In [None]:
# Push model to HuggingFace Hub
from huggingface_hub import HfApi
import safetensors.flax as stf

def push_to_hub(model_name: str = "VishwamAI/VishwamAI", token: str = None, params=None):
    """Push the distilled student model to the Hugging Face Hub.

    Writes the student config, tokenizer, weights (``model.safetensors``) and a
    ``README.md`` model card into the local ``distilled_model`` directory,
    creates ``model_name`` on the Hub if needed and uploads the directory.

    Args:
        model_name: Target Hub repository id.
        token: Hugging Face access token, or ``None`` for huggingface_hub's
            default token lookup.
        params: Parameter pytree to upload. Defaults to the notebook-global
            ``state.params`` produced by the training loop.

    Note:
        Reads the notebook globals ``student_config``, ``teacher_config``,
        ``tokenizer``, ``state`` and ``stf``.
    """
    import textwrap

    from flax.traverse_util import flatten_dict

    api = HfApi(token=token)

    # Create local files
    tmp_dir = "distilled_model"
    os.makedirs(tmp_dir, exist_ok=True)

    # NOTE(review): `save_to_hub` writes the config with json.dump and the
    # tokenizer with `tokenizer.save`; confirm ModelConfig and
    # VishwamAITokenizer actually provide `save_pretrained`.
    student_config.save_pretrained(tmp_dir)
    tokenizer.save_pretrained(tmp_dir)

    # The training loop rebinds `state` on every step and never touches
    # `student_model`, so the trained weights live in `state.params`.
    if params is None:
        params = state.params
    # safetensors' save_file takes a flat {name: array} mapping; flatten the
    # nested flax parameter dict into dotted names.
    model_path = f"{tmp_dir}/model.safetensors"
    stf.save_file(flatten_dict(params, sep="."), model_path)

    # Dedent so the YAML front matter starts at column 0 on the first line (the
    # Hub only parses metadata there) and the markdown is not rendered as an
    # indented code block.
    model_card = textwrap.dedent(f"""\
        ---
        language:
        - en
        tags:
        - vishwamai
        - distillation
        - qwq-32b
        license: apache-2.0
        ---

        # VishwamAI Distilled Model

        This is a distilled version of QwQ-32B model using VishwamAI's distillation pipeline.

        ## Model Details
        - Teacher: QwQ-32B ({teacher_config.hidden_size} hidden size)
        - Student: VishwamAI ({student_config.hidden_size} hidden size)
        - Vocab Size: {student_config.vocab_size}
        - Context Length: {student_config.max_position_embeddings}
        """)

    with open(f"{tmp_dir}/README.md", "w") as f:
        f.write(model_card)

    # Push to hub
    api.create_repo(model_name, exist_ok=True)
    api.upload_folder(
        folder_path=tmp_dir,
        repo_id=model_name,
        commit_message="Upload distilled model"
    )
    print(f"Model successfully pushed to {model_name}")

# Push the model. The token comes from the HF_TOKEN environment variable rather
# than being hard-coded; if unset, huggingface_hub uses the cached
# `huggingface-cli login` token.
push_to_hub(token=os.environ.get("HF_TOKEN"))

In [None]:
# Test the uploaded model
from transformers import AutoModelForCausalLM, AutoTokenizer

# Must match the repository used by push_to_hub above.
HUB_REPO_ID = "VishwamAI/VishwamAI"

# Load the model from Hub. Distinct names (hub_*) avoid rebinding the notebook's
# `tokenizer` (the VishwamAITokenizer used by the save/push cells) and the
# training loop's `outputs`, so re-running earlier cells stays correct.
# NOTE(review): AutoModelForCausalLM only resolves model types known to
# transformers (or loaded with trust_remote_code); confirm the uploaded
# config and Flax safetensors weights can be loaded this way.
hub_model = AutoModelForCausalLM.from_pretrained(HUB_REPO_ID)
hub_tokenizer = AutoTokenizer.from_pretrained(HUB_REPO_ID)

# Test inference. max_new_tokens bounds the generated continuation
# independently of the prompt length.
prompt = "Hello, my name is"
inputs = hub_tokenizer(prompt, return_tensors="pt")
generated = hub_model.generate(**inputs, max_new_tokens=50)
print(hub_tokenizer.decode(generated[0], skip_special_tokens=True))