# VishwamAI QwQ-32B Distillation

Knowledge distillation from the [Qwen/QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) teacher model into the smaller VishwamAI-7B student model.

In [None]:
import jax
import jax.numpy as jnp
import safetensors.flax as stf
from omegaconf import OmegaConf
import aim
from huggingface_hub import snapshot_download

from vishwamai.model import VishwamAIModel, ModelConfig
from vishwamai.qwen_distiller import QwenDistillationTrainer
from vishwamai.qwen_data import QwenDataLoader

# Report the devices visible to JAX so the accelerator setup can be checked at a glance.
print("JAX devices:", jax.devices())

In [None]:
# Load the distillation configuration (teacher/student model configs and
# training hyperparameters are read from it in later cells).
config = OmegaConf.load('configs/distillation_config.yaml')
print("Configuration loaded successfully")

# Download the QwQ-32B teacher checkpoint.
# QwQ-32B is a Qwen2-family model whose tokenizer is distributed as
# tokenizer.json / tokenizer_config.json / vocab.json / merges.txt rather than a
# SentencePiece `tokenizer.model`, so the previous pattern list fetched no
# tokenizer files. The shard index maps parameter names to the weight shards.
# `tokenizer.model` is kept so the request remains a superset of the original.
# NOTE(review): `resume_download` is deprecated in recent huggingface_hub
# releases (downloads always resume); drop it once older versions are unsupported.
qwq_path = snapshot_download(
    "Qwen/QwQ-32B",
    allow_patterns=[
        "*.safetensors",
        "model.safetensors.index.json",
        "config.json",
        "tokenizer.model",
        "tokenizer.json",
        "tokenizer_config.json",
        "vocab.json",
        "merges.txt",
    ],
    local_files_only=False,
    resume_download=True,
)
print(f"Downloaded QwQ-32B to {qwq_path}")

In [None]:
# Build the data loader over the downloaded QwQ checkpoint directory.
# NOTE(review): the sequence length comes from the *teacher's*
# max_position_embeddings; for a long-context teacher this may be far larger
# than what a training batch can hold — confirm this is intended.
train_batch_size = config.training.batch_size
teacher_max_positions = config.distillation.teacher_model.config.max_position_embeddings

loader = QwenDataLoader(
    safetensor_dir=qwq_path,
    batch_size=train_batch_size,
    max_sequence_length=teacher_max_positions,
)

print("Data loader initialized")

In [None]:
# Construct the teacher and student modules from their respective config sections,
# then attach the downloaded teacher weights to the teacher module.
teacher_cfg = ModelConfig(**config.distillation.teacher_model.config)
teacher_model = VishwamAIModel(teacher_cfg)
student_cfg = ModelConfig(**config.distillation.student_model.config)
student_model = VishwamAIModel(student_cfg)

print("Loading QwQ model weights...")
# NOTE(review): this loads every shard at once and binds it as the teacher's
# parameter tree; confirm the loader returns keys matching VishwamAIModel's
# parameter structure and that host memory can hold the full checkpoint.
params = loader.load_all_shards()
teacher_model = teacher_model.bind({'params': params})
print("Models initialized successfully")

In [None]:
# Initialize experiment tracking and the distillation trainer.
aim_run = aim.Run(
    experiment="VishwamAI-QwQ-Distillation",
    log_system_params=True,
)
# Aim records run-level hyperparameters via item assignment on the Run
# (`aim.Run` has no `set_params`). resolve=True turns the OmegaConf tree into
# plain dicts/lists with interpolations expanded, so the values are serializable.
aim_run["hparams"] = {
    "teacher_model": "QwQ-32B",
    "student_model": "VishwamAI-7B",
    **OmegaConf.to_container(config, resolve=True),
}

trainer = QwenDistillationTrainer(
    teacher_model=teacher_model,
    student_model=student_model,
    cfg=config,
)

rng = jax.random.PRNGKey(42)
# Derive a dedicated key for parameter initialization. Otherwise the same key
# would be consumed by init here and again by the training loop; JAX keys must
# not be reused.
rng, init_rng = jax.random.split(rng)
state = trainer.create_train_state(init_rng)
print("Training setup complete")

In [None]:
# Training loop: several passes over the loader's shard stream, with periodic
# metric logging, checkpointing, and host-side garbage collection.
import gc  # previously used below without being imported (NameError at step 10)
import time

from tqdm.notebook import tqdm

NUM_EPOCHS = 5  # number of full passes through the shard stream

try:
    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")
        epoch_start = time.time()

        # Process shards sequentially.
        # NOTE(review): each yielded item is unpacked as (name, dict) and the dict
        # is expected to contain 'input_ids'; confirm get_shard_stream() yields
        # token data rather than teacher weight tensors.
        for shard_name, shard_params in tqdm(loader.get_shard_stream(), desc="Processing QwQ shards"):
            batch = loader.create_training_batch(
                input_ids=shard_params['input_ids']
            )

            # Split off a fresh key every step. Passing the same key to every
            # step would repeat identical stochastic behaviour (e.g. dropout masks).
            rng, step_rng = jax.random.split(rng)
            state, metrics = trainer.train_step(
                state=state,
                batch=batch,
                rng=step_rng,
            )

            # Pull the step counter to the host once. This avoids repeated
            # device reads and gives Aim and the checkpoint path a plain int.
            step = int(state.step)

            # Log metrics with shard context.
            if step % config.training.logging_steps == 0:
                aim_run.track(
                    metrics,
                    step=step,
                    context={
                        'shard': shard_name,
                        'epoch': epoch,
                    },
                )
                print(f"Step {step}: loss={float(metrics['loss']):.4f}")

            # Save checkpoint.
            if step % config.training.save_steps == 0:
                ckpt_path = f"checkpoints/step_{step}"
                trainer.save_checkpoint(
                    state=state,
                    path=ckpt_path,
                    extra_info={
                        'epoch': epoch,
                        'shard': shard_name,
                        'metrics': metrics,
                    },
                )
                # NOTE(review): `track_artifact` does not appear in Aim's Run API
                # (artifacts are logged with `log_artifact` once an artifacts URI
                # is configured) — verify against the installed Aim version.
                aim_run.track_artifact(ckpt_path, name="checkpoints")

            # Periodic host-side garbage collection. jax.clear_caches() was removed
            # from here: it discards JAX's compilation caches, so any jit-compiled
            # train step would be recompiled every 10 steps.
            if step % 10 == 0:
                gc.collect()

        epoch_time = time.time() - epoch_start
        print(f"Epoch completed in {epoch_time:.2f}s")

except KeyboardInterrupt:
    print("Training interrupted, saving checkpoint...")
    trainer.save_checkpoint(state, "checkpoints/interrupted")
finally:
    # Always close the Aim run, even when training stops on an error.
    aim_run.close()

In [None]:
# Write the final distilled student model, tagged with provenance metadata
# that records the teacher checkpoint and the distillation version.
final_path = "final_vishwamai_model"
provenance_overrides = {
    "parent_model": "Qwen/QwQ-32B",
    "distillation_version": "v1.0",
}
trainer.save_model(state=state, path=final_path, config_override=provenance_overrides)
print(f"Distilled model saved to {final_path}")