# VishwamAI QwQ-32B Distillation

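Knowledge distillation from [Qwen/QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) (teacher) into a smaller VishwamAI student model on TPU, using gradient accumulation to reach a larger effective batch size while keeping a per-device batch of 1.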

In [None]:
# Clone the project and make it the working directory: later cells import the
# `vishwamai` package and load `configs/distillation_config.yaml` by relative path.
# NOTE(review): re-running this cell after a successful run clones into a nested
# `vishwamai/vishwamai` directory, because the cwd has already been changed.
!git clone https://github.com/VishwamAI/vishwamai.git
%cd vishwamai

In [None]:
# Install dependencies
!pip install -q --upgrade transformers datasets accelerate bitsandbytes sentencepiece \
    flax optax omegaconf huggingface-hub einops aim>=3.17.5 safetensors

In [None]:
import jax
import jax.numpy as jnp
import safetensors.flax as stf
from omegaconf import OmegaConf
import aim
from huggingface_hub import snapshot_download

from vishwamai.model import VishwamAIModel, ModelConfig
from vishwamai.qwen_distiller import QwenDistillationTrainer
from vishwamai.qwen_data import QwenDataLoader

# Report which accelerator devices JAX can see, and how many there are.
for _label, _value in (
    ("JAX devices", jax.devices()),
    ("Number of devices", jax.device_count()),
):
    print(f"{_label}: {_value}")

In [None]:
# Load configuration (relative path: the working directory must be the repo root).
config = OmegaConf.load('configs/distillation_config.yaml')

# TPU-aware batch settings
NUM_TPU_DEVICES = jax.device_count()
BATCH_SIZE = 1  # Per device
GRAD_ACCUM_STEPS = 16  # Micro-batches accumulated per optimizer update
# Examples contributing to one optimizer update across all devices.
# This scales with the device count, e.g. 1 * 16 * 8 = 128 on an 8-device host.
EFFECTIVE_BATCH_SIZE = BATCH_SIZE * GRAD_ACCUM_STEPS * NUM_TPU_DEVICES

# Write the notebook's batch settings back into the loaded config.
config.training.batch_size = BATCH_SIZE
config.training.gradient_accumulation_steps = GRAD_ACCUM_STEPS

print("Configuration loaded with:")
print(f"- Batch size per device: {BATCH_SIZE}")
print(f"- Gradient accumulation steps: {GRAD_ACCUM_STEPS}")
print(f"- Effective batch size: {EFFECTIVE_BATCH_SIZE}")

In [None]:
# Download QwQ model
print("Downloading QwQ-32B model files...")
# The Qwen repo appears to ship a BPE tokenizer (tokenizer.json / vocab.json /
# merges.txt) and a model.safetensors.index.json shard map, not a SentencePiece
# `tokenizer.model`. "*.json" (which also covers config.json) and "merges.txt"
# pick those up; "tokenizer.model" is kept so a repo that does ship one still matches.
# `resume_download` is omitted: it is deprecated in current huggingface_hub
# releases, where interrupted downloads always resume. `local_files_only=False`
# is the default and was dropped as redundant.
qwq_path = snapshot_download(
    "Qwen/QwQ-32B",
    allow_patterns=["*.safetensors", "*.json", "merges.txt", "tokenizer.model"],
)
print(f"Downloaded QwQ-32B to {qwq_path}")

In [None]:
# Build the data loader over the downloaded teacher files.
# The sequence length is taken from the teacher's max_position_embeddings.
# NOTE(review): with a per-device batch of 1 this length may be very memory-hungry;
# confirm it is the intended training length.
teacher_max_len = config.distillation.teacher_model.config.max_position_embeddings
loader_kwargs = dict(
    safetensor_dir=qwq_path,
    batch_size=BATCH_SIZE,
    max_sequence_length=teacher_max_len,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
)
loader = QwenDataLoader(**loader_kwargs)

print("Data loader initialized with gradient accumulation")

In [None]:
# Build the teacher and student models from their respective config sections.
teacher_cfg = ModelConfig(**config.distillation.teacher_model.config)
teacher_model = VishwamAIModel(teacher_cfg)
student_cfg = ModelConfig(**config.distillation.student_model.config)
student_model = VishwamAIModel(student_cfg)

print("Loading QwQ model weights...")
params = loader.load_all_shards()
# Attach the loaded weights to the teacher. This presumably uses Flax
# Module.bind, so the teacher can later be called without passing params.
teacher_model = teacher_model.bind({'params': params})
print("Models initialized successfully")

In [None]:
# Initialize trainer with experiment tracking
aim_run = aim.Run(
    experiment="VishwamAI-QwQ-Distillation",
    log_system_params=True
)
# aim.Run records hyperparameters through item assignment (run["hparams"] = {...});
# it has no `set_params` method, so the previous call raised AttributeError.
aim_run["hparams"] = {
    "teacher_model": "QwQ-32B",
    "student_model": "VishwamAI-7B",
    "batch_size_per_device": BATCH_SIZE,
    "gradient_accumulation_steps": GRAD_ACCUM_STEPS,
    "effective_batch_size": BATCH_SIZE * GRAD_ACCUM_STEPS * NUM_TPU_DEVICES,
    # The full resolved config is merged in last, so its top-level keys win on collision.
    **OmegaConf.to_container(config, resolve=True)
}

trainer = QwenDistillationTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    cfg=config
)

# Fixed seed so the train-state initialization is reproducible.
rng = jax.random.PRNGKey(42)
state = trainer.create_train_state(rng)
print("Training setup complete")

In [None]:
# Training loop with gradient accumulation
import gc  # used for periodic cleanup below; was previously missing (NameError)
import time

from tqdm.notebook import tqdm

NUM_EPOCHS = 5

try:
    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")
        epoch_start = time.time()

        # Process shards sequentially; each shard produces one optimizer update.
        # NOTE(review): shards come from the teacher's safetensors directory but are
        # indexed for 'input_ids'/'labels' below — confirm get_shard_stream yields
        # token data and not only model weights.
        for shard_name, shard_params in tqdm(loader.get_shard_stream(), desc="Processing QwQ shards"):
            # Running sum of micro-batch gradients for this shard.
            accumulated_gradients = None

            for _ in range(GRAD_ACCUM_STEPS):
                # NOTE(review): every micro-batch is built from the same shard arrays;
                # confirm create_training_batch draws a different slice on each call.
                batch = loader.create_training_batch(
                    input_ids=shard_params['input_ids'],
                    labels=shard_params.get('labels')
                )

                # Split a fresh key for every micro-step. Passing the same key each
                # time would repeat identical dropout/sampling randomness.
                rng, step_rng = jax.random.split(rng)
                # NOTE(review): confirm train_step_with_grads does not itself apply
                # `grads` to `state`; otherwise apply_gradients below updates twice.
                state, metrics, grads = trainer.train_step_with_grads(
                    state=state,
                    batch=batch,
                    rng=step_rng
                )

                # Accumulate gradients (jax.tree_map is deprecated; use jax.tree_util).
                if accumulated_gradients is None:
                    accumulated_gradients = grads
                else:
                    accumulated_gradients = jax.tree_util.tree_map(
                        lambda acc, g: acc + g,
                        accumulated_gradients,
                        grads
                    )

            # Average the summed gradients and apply a single optimizer update.
            accumulated_gradients = jax.tree_util.tree_map(
                lambda g: g / GRAD_ACCUM_STEPS,
                accumulated_gradients
            )
            state = state.apply_gradients(grads=accumulated_gradients)
            # Convert the step counter to a plain int once. It is used for the
            # modulo checks, Aim's `step=` argument and the checkpoint path.
            step = int(state.step)

            # Log metrics (from the last micro-batch of this accumulation window).
            if step % config.training.logging_steps == 0:
                aim_run.track(
                    metrics,
                    step=step,
                    context={
                        'shard': shard_name,
                        'epoch': epoch
                    }
                )
                print(f"Step {step}: loss={float(metrics['loss']):.4f}")

            # Save checkpoint
            if step % config.training.save_steps == 0:
                ckpt_path = f"checkpoints/step_{step}"
                trainer.save_checkpoint(
                    state=state,
                    path=ckpt_path,
                    extra_info={
                        'epoch': epoch,
                        'shard': shard_name,
                        'metrics': metrics
                    }
                )
                # aim.Run has no `track_artifact`; record the path as run metadata
                # instead. (Run.log_artifact exists but requires set_artifacts_uri().)
                aim_run["last_checkpoint"] = ckpt_path

            # Periodic host-side memory cleanup. jax.clear_caches() is deliberately
            # not called here: it drops JAX's compilation caches, so any jitted
            # function (e.g. the training step) would be recompiled on its next call.
            if step % 10 == 0:
                gc.collect()

        epoch_time = time.time() - epoch_start
        print(f"Epoch completed in {epoch_time:.2f}s")

except KeyboardInterrupt:
    print("Training interrupted, saving checkpoint...")
    trainer.save_checkpoint(state=state, path="checkpoints/interrupted")
finally:
    # Always close the Aim run, whether training finished, was interrupted or failed.
    aim_run.close()

In [None]:
# Save the trained state as the final distilled model, with provenance metadata.
final_path = "final_vishwamai_model"
training_summary = {
    "batch_size_per_device": BATCH_SIZE,
    "gradient_accumulation_steps": GRAD_ACCUM_STEPS,
    "effective_batch_size": BATCH_SIZE * GRAD_ACCUM_STEPS * NUM_TPU_DEVICES,
}
override = dict(
    parent_model="Qwen/QwQ-32B",
    distillation_version="v1.0",
    training_config=training_summary,
)
trainer.save_model(state=state, path=final_path, config_override=override)
print(f"Distilled model saved to {final_path}")