<a href="https://colab.research.google.com/github/VishwamAI/VishwamAI/blob/main/train_vishwamai_distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VishwamAI QwQ-32B Distillation

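Memory-efficient knowledge distillation from the [Qwen/QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) teacher into a smaller VishwamAI student model, loading the teacher's weights in chunks to keep memory usage low.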

In [1]:
!git clone https://github.com/VishwamAI/VishwamAI.git
%cd VishwamAI

Cloning into 'VishwamAI'...
remote: Enumerating objects: 2575, done.[K
remote: Counting objects: 100% (85/85), done.[K
remote: Compressing objects: 100% (62/62), done.[K
remote: Total 2575 (delta 35), reused 54 (delta 23), pack-reused 2490 (from 1)[K
Receiving objects: 100% (2575/2575), 35.46 MiB | 43.86 MiB/s, done.
Resolving deltas: 100% (1289/1289), done.
/content/VishwamAI


In [2]:
# Install dependencies
!pip install -q --upgrade transformers datasets accelerate bitsandbytes sentencepiece \
    flax optax omegaconf huggingface-hub einops aim>=3.17.5 safetensors

In [3]:
import os
import jax
import jax.numpy as jnp
from omegaconf import OmegaConf
import aim
import gc
from huggingface_hub import snapshot_download

from vishwamai.model import VishwamAIModel, ModelConfig
from vishwamai.qwen_distiller import QwenDistillationTrainer  # Updated import
from vishwamai.qwen_data import QwenDataLoader
from vishwamai.tensor_utils import get_memory_usage

# Start from a clean slate: drop JAX's compilation/staging caches and run the
# Python garbage collector before taking the baseline measurement.
jax.clear_caches()
gc.collect()

# Report which accelerators JAX sees and the starting memory usage
# (as measured by vishwamai.tensor_utils.get_memory_usage).
print(f"JAX devices: {jax.devices()}")
print(f"Number of devices: {jax.device_count()}")
baseline_note = f"Initial memory usage: {get_memory_usage():.2f}GB"
print(baseline_note)

JAX devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
Number of devices: 8
Initial memory usage: 1.60GB


In [4]:
# Memory-efficient settings, used by the config overrides below and by later cells.
CHUNK_SIZE = 32        # chunk size passed to the loader / memory_optimization config
BATCH_SIZE = 1         # per-device batch size
GRAD_ACCUM_STEPS = 16  # micro-steps whose gradients are summed per optimizer update

# Load the base YAML configuration, then override it with the settings above.
config = OmegaConf.load('configs/distillation_config.yaml')
config.training.batch_size = BATCH_SIZE
config.training.gradient_accumulation_steps = GRAD_ACCUM_STEPS
config.memory_optimization.chunk_size = CHUNK_SIZE

print("Configuration loaded with:")
for setting_label, setting_value in (
    ("Chunk size", CHUNK_SIZE),
    ("Batch size per device", BATCH_SIZE),
    ("Gradient accumulation steps", GRAD_ACCUM_STEPS),
):
    print(f"- {setting_label}: {setting_value}")

Configuration loaded with:
- Chunk size: 32
- Batch size per device: 1
- Gradient accumulation steps: 16


In [5]:
# Download the QwQ-32B teacher checkpoint and verify that every weight shard arrived.
EXPECTED_QWQ_SHARDS = 14  # shards are named model-000NN-of-00014.safetensors

print("Downloading QwQ-32B model files...")
# huggingface_hub always resumes interrupted downloads; the `resume_download`
# flag is deprecated, so it is no longer passed.
# NOTE(review): the recorded run fetched 15 files (14 shards + config.json), so
# "tokenizer.model" matched nothing in this repo. Confirm which tokenizer files
# (e.g. tokenizer.json) are actually needed downstream.
qwq_path = snapshot_download(
    "Qwen/QwQ-32B",
    allow_patterns=["*.safetensors", "config.json", "tokenizer.model"],
    local_files_only=False,
)

# Verify shard count (sorted so any listing/printing is deterministic).
shard_files = sorted(f for f in os.listdir(qwq_path) if f.endswith('.safetensors'))
print(f"Found {len(shard_files)} safetensor shards")
assert len(shard_files) == EXPECTED_QWQ_SHARDS, (
    f"Expected {EXPECTED_QWQ_SHARDS} safetensor files, found {len(shard_files)}"
)

Downloading QwQ-32B model files...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Fetching 15 files:   0%|          | 0/15 [00:00<?, ?it/s]

Found 14 safetensor shards


In [6]:
# Build the QwQ data loader over the downloaded safetensor directory.
# Sequence length follows the teacher config's max_position_embeddings; the
# batching, accumulation and chunking knobs reuse the settings-cell constants.
# NOTE(review): the loader's own log (recorded output) reports a per-device batch
# size of 2 and a global batch size of 16 for BATCH_SIZE=1 on 8 devices. Confirm
# how QwenDataLoader interprets `batch_size`.
loader = QwenDataLoader(
    safetensor_dir=qwq_path,
    max_sequence_length=config.distillation.teacher_model.config.max_position_embeddings,
    batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    chunk_size=CHUNK_SIZE,
)

print("Data loader initialized with chunked loading")
print("Memory usage: {:.2f}GB".format(get_memory_usage()))

Using provided chunk size: 32
Initialized loader with:
 - 8 devices
 - Global batch size: 16
 - Per-device batch size: 2
 - Gradient accumulation steps: 16
 - Chunk size: 32
 - Initial memory usage: 1.61GB
Data loader initialized with chunked loading
Memory usage: 1.61GB


In [None]:
# Construct teacher and student from their config sections, load the QwQ
# weights through the loader, and bind them to the teacher.
# NOTE(review): in the recorded run this cell never finished; its output stops
# after "Loading shard model-00001-of-00014.safetensors...".
print("Initializing models...")
print(f"Memory before: {get_memory_usage():.2f}GB")

distill_cfg = config.distillation
teacher_model = VishwamAIModel(ModelConfig(**distill_cfg.teacher_model.config))
student_model = VishwamAIModel(ModelConfig(**distill_cfg.student_model.config))

print("\nLoading QwQ model weights in chunks...")
params = loader.load_all_shards()  # chunked loading is handled inside the loader
teacher_model = teacher_model.bind({'params': params})

# Drop JAX caches and collect garbage left over from loading before re-measuring.
jax.clear_caches()
gc.collect()

print(f"Memory after: {get_memory_usage():.2f}GB")
print("Models initialized successfully")

Initializing models...
Memory before: 1.61GB

Loading QwQ model weights in chunks...
Loading shard model-00001-of-00014.safetensors...
Memory before: 1.61GB


In [None]:
# Initialize trainer with experiment tracking.
# The Aim run created here is closed by the training cell's `finally` clause.
aim_run = aim.Run(
    experiment=config.monitoring.aim_experiment,
    log_system_params=True
)
# Aim stores run hyperparameters via item assignment (run["hparams"] = {...});
# aim.Run exposes no `set_params` method.
aim_run["hparams"] = {
    "teacher_model": "QwQ-32B",
    "student_model": "VishwamAI-7B",
    "chunk_size": CHUNK_SIZE,
    "batch_size_per_device": BATCH_SIZE,
    "gradient_accumulation_steps": GRAD_ACCUM_STEPS,
    "memory_initial": get_memory_usage(),
    # Full resolved config; its top-level sections are merged after the explicit
    # keys above, so a same-named config key would win.
    **OmegaConf.to_container(config, resolve=True)
}

trainer = QwenDistillationTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    cfg=config
)

# Fixed PRNG seed for train-state creation; the training loop splits this key further.
rng = jax.random.PRNGKey(42)
state = trainer.create_train_state(rng)
print("Training setup complete")
print(f"Memory usage: {get_memory_usage():.2f}GB")

In [None]:
# Training loop with memory management.
# One optimizer update per shard: GRAD_ACCUM_STEPS micro-steps are run, their
# gradients are summed, averaged, and applied once.
from tqdm.notebook import tqdm
import time

NUM_EPOCHS = 5

try:
    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")
        epoch_start = time.time()

        # Process shards sequentially with memory cleanup.
        # NOTE(review): get_shard_stream() is iterated as QwQ "shards", yet each item
        # is indexed for 'input_ids'/'labels' below. Confirm it yields tokenized
        # training data rather than teacher weight tensors.
        for shard_name, shard_params in tqdm(loader.get_shard_stream(), desc="Processing QwQ shards"):
            # Memory check
            current_mem = get_memory_usage()
            if current_mem > config.memory_optimization.max_memory_gb:
                print(f"Warning: High memory usage ({current_mem:.2f}GB)")
                # Last-resort path only: clearing JAX caches forces a recompile on the next step.
                jax.clear_caches()
                gc.collect()

            # Accumulate gradients over multiple steps
            accumulated_gradients = None

            for accum_step in range(GRAD_ACCUM_STEPS):
                # Create batch for current accumulation step
                batch = loader.create_training_batch(
                    input_ids=shard_params['input_ids'],
                    labels=shard_params.get('labels')
                )

                # Fresh key per micro-step; reusing one key repeats every random draw.
                rng, step_rng = jax.random.split(rng)

                # NOTE(review): the returned `state` is kept AND apply_gradients() is
                # called again below. Confirm train_step_with_grads does not already
                # apply the update, or gradients are applied twice.
                state, metrics, grads = trainer.train_step_with_grads(
                    state=state,
                    batch=batch,
                    rng=step_rng
                )

                # Accumulate gradients (jax.tree_map is deprecated/removed; use tree_util)
                if accumulated_gradients is None:
                    accumulated_gradients = grads
                else:
                    accumulated_gradients = jax.tree_util.tree_map(
                        lambda x, y: x + y,
                        accumulated_gradients,
                        grads
                    )

                # Periodic host-side cleanup. jax.clear_caches() is deliberately not
                # called here: it discards compiled executables (forcing the train step
                # to be re-traced and recompiled) and does not free device buffers.
                if accum_step % 4 == 0:
                    gc.collect()

            # Apply the averaged accumulated gradients
            accumulated_gradients = jax.tree_util.tree_map(
                lambda x: x / GRAD_ACCUM_STEPS,
                accumulated_gradients
            )
            state = state.apply_gradients(grads=accumulated_gradients)
            step = int(state.step)  # plain int for modulo checks, Aim and paths

            # Log metrics with memory usage
            if step % config.training.logging_steps == 0:
                current_mem = get_memory_usage()
                # Copy instead of mutating, in case metrics is an immutable mapping.
                metrics = {**metrics, 'memory_usage': current_mem}
                # Aim contexts identify a series. A per-step value such as memory
                # would start a new series on every call, so memory is logged as a
                # metric and the epoch goes through the dedicated `epoch` argument.
                aim_run.track(
                    metrics,
                    step=step,
                    epoch=epoch,
                    context={'shard': shard_name}
                )
                print(f"Step {step}: loss={metrics['loss']:.4f}, memory={current_mem:.2f}GB")

            # Save checkpoint
            if step % config.training.save_steps == 0:
                ckpt_path = f"checkpoints/step_{step}"
                trainer.save_checkpoint(
                    state=state,
                    path=ckpt_path,
                    extra_info={
                        'epoch': epoch,
                        'shard': shard_name,
                        'metrics': metrics,
                        'memory_usage': get_memory_usage()
                    }
                )
                # aim.Run has no `track_artifact`; record the checkpoint location as a
                # run param. Uploading files would need Run.set_artifacts_uri() +
                # Run.log_artifacts().
                aim_run['last_checkpoint'] = ckpt_path

            # Memory cleanup after each shard: drop references, then collect.
            del accumulated_gradients
            del shard_params
            gc.collect()

        epoch_time = time.time() - epoch_start
        print(f"Epoch completed in {epoch_time:.2f}s")
        print(f"Memory usage: {get_memory_usage():.2f}GB")

except KeyboardInterrupt:
    print("Training interrupted, saving checkpoint...")
    trainer.save_checkpoint(state, "checkpoints/interrupted")
finally:
    aim_run.close()

In [None]:
# Save final model with memory usage stats.
final_path = "final_vishwamai_model"

# Clear memory before saving
jax.clear_caches()
gc.collect()

# Measured once, right after the cleanup above. This is the usage at save time,
# not a peak, so it is stored under a key that says so (previously mislabeled
# "peak_memory_usage").
memory_at_save = get_memory_usage()

trainer.save_model(
    state=state,
    path=final_path,
    config_override={
        "parent_model": "Qwen/QwQ-32B",
        "distillation_version": "v1.0",
        "training_config": {
            "chunk_size": CHUNK_SIZE,
            "batch_size_per_device": BATCH_SIZE,
            "gradient_accumulation_steps": GRAD_ACCUM_STEPS,
            "memory_usage_at_save": memory_at_save,
        },
    },
)
print(f"Distilled model saved to {final_path}")
print(f"Final memory usage: {get_memory_usage():.2f}GB")