# Generate KD Cache with Qwen3-32B Teacher (All 3 Variants)

This notebook generates a knowledge-distillation (KD) cache of top-k teacher outputs, using **Qwen3-32B** as the teacher model.

**Three thinking variants:**
- `true` - with thinking tags (`<think>...</think>`)
- `false` - without thinking tags
- `none` - raw text (no chat template)
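
As a rough illustration of how the three variants could be rendered for one conversation, here is a minimal sketch. It assumes a Hugging Face tokenizer whose chat template honors an `enable_thinking` keyword (as Qwen3's does). The helper name `render_variants` is hypothetical, and the way `none` joins raw message contents is an assumption, not necessarily what `precompute_teacher_topk.py` does.

```python
def render_variants(tokenizer, messages):
    """Return the three text variants for one conversation.

    `tokenizer` is any object exposing Hugging Face's apply_chat_template.
    """
    variants = {}
    for name, flag in (("true", True), ("false", False)):
        # Extra keyword arguments are forwarded to the chat template;
        # Qwen3's template reads `enable_thinking`.
        variants[name] = tokenizer.apply_chat_template(
            messages, tokenize=False, enable_thinking=flag
        )
    # Assumption: "none" joins the raw message contents without any template.
    variants["none"] = "\n".join(m["content"] for m in messages)
    return variants


# Usage (downloads the tokenizer, so it is left commented out):
# from transformers import AutoTokenizer
# tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-32B")
# msgs = [{"role": "user", "content": "Hi"},
#         {"role": "assistant", "content": "Hello"}]
# for name, text in render_variants(tok, msgs).items():
#     print(name, repr(text))
```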

**Requirements:**
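- GPU with 80GB+ VRAM (A100 80GB, H100) to load the teacher in bf16 (the `--dtype bf16` setting used below)
- Or 40GB+ VRAM with 4-bit quantization (requires a script change; see Memory Optimization Options)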
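
These VRAM figures follow from a back-of-the-envelope estimate of weight memory alone. The sketch below assumes roughly 32.8 billion parameters for Qwen3-32B and ignores activations, the CUDA context, and quantization overhead, so actual usage will be higher.

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Approximate memory for model weights alone, in decimal GB."""
    return num_params * bits_per_param / 8 / 1e9


QWEN3_32B_PARAMS = 32.8e9  # approximate parameter count (assumption)

bf16_gb = weight_memory_gb(QWEN3_32B_PARAMS, 16)  # 2 bytes per weight
nf4_gb = weight_memory_gb(QWEN3_32B_PARAMS, 4)    # half a byte per weight

print(f"bf16 weights:  ~{bf16_gb:.0f} GB")
print(f"4-bit weights: ~{nf4_gb:.0f} GB")
```

The gap between these numbers and the stated requirements is the headroom left for batches of activations and logits over the full vocabulary.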

**Estimated time:** ~4-8 hours for 50K sequences on A100

In [None]:
# ============================================================
# CHECK GPU
# ============================================================

!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv

In [None]:
# ============================================================
# SETUP: Clone repo and install dependencies
# ============================================================

!git clone https://github.com/anemll/qwen3_apple_style_2bit_qat_lora.git 2>/dev/null || echo "Repo already exists"
%cd qwen3_apple_style_2bit_qat_lora
!git pull origin main

!pip install -q transformers accelerate datasets sentencepiece
!pip install -q bitsandbytes  # For 4-bit quantization if needed

In [None]:
# ============================================================
# CONFIG
# ============================================================

# Teacher model - Qwen3-32B
MODEL_NAME = 'Qwen/Qwen3-32B'

# Dataset
DATASET_NAME = 'teknium/OpenHermes-2.5'
DATASET_FORMAT = 'openhermes'

# Cache parameters
TOP_K = 128           # Number of top-k logits to cache
RAND_NEG = 64         # Random negatives
MAX_LENGTH = 128      # Sequence length
NUM_SEQUENCES = 50000 # Total sequences (3x variants = 150K effective samples)
SHARD_SIZE = 1000     # Samples per shard

# Batch size - adjust based on GPU memory
# A100 80GB: batch_size=8-16
# A100 40GB with 4-bit: batch_size=16-32
BATCH_SIZE = 8

# Enable all 3 thinking variants
ENABLE_THINKING = 'all'

# Output cache name
CACHE_NAME = f"openhermes_32B_L{MAX_LENGTH}_K{TOP_K}_ALL_N{NUM_SEQUENCES//1000}K"
CACHE_DIR = f"caches/{CACHE_NAME}"

print("=" * 60)
print("KD Cache Configuration")
print("=" * 60)
print(f"Teacher model:    {MODEL_NAME}")
print(f"Dataset:          {DATASET_NAME}")
print(f"Thinking mode:    {ENABLE_THINKING} (3 variants)")
print(f"Top-K:            {TOP_K}")
print(f"Random negatives: {RAND_NEG}")
print(f"Max length:       {MAX_LENGTH}")
print(f"Sequences:        {NUM_SEQUENCES}")
print(f"Batch size:       {BATCH_SIZE}")
print(f"Output:           {CACHE_DIR}")
print("=" * 60)
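
Before starting a multi-hour run, it can help to estimate how much disk (and Drive) space the cache might need. The sketch below is only a rough estimate. It assumes each token position stores `TOP_K + RAND_NEG` entries as a 4-byte index plus a 2-byte value, and that every sequence is stored once per variant; the actual on-disk format of `precompute_teacher_topk.py` may differ. The function name is hypothetical.

```python
def estimate_cache_bytes(num_sequences, variants, max_length, slots_per_token,
                         index_bytes=4, value_bytes=2):
    """Rough cache size: one (index, value) pair per kept slot per token."""
    tokens = num_sequences * variants * max_length
    return tokens * slots_per_token * (index_bytes + value_bytes)


# Values mirror the config above: 50K sequences, 3 variants,
# length 128, and 128 top-k plus 64 random-negative slots per token.
size = estimate_cache_bytes(50_000, 3, 128, 128 + 64)
print(f"Estimated cache size: ~{size / 1024**3:.1f} GiB")
```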

In [None]:
# ============================================================
# OPTIONAL: Mount Google Drive for saving cache
# ============================================================

SAVE_TO_DRIVE = True  # Set to False if not using Colab/Drive

if SAVE_TO_DRIVE:
    try:
        from google.colab import drive
        drive.mount('/content/drive')
        DRIVE_CACHE_DIR = '/content/drive/MyDrive/qwen3_caches'
        !mkdir -p {DRIVE_CACHE_DIR}
        print(f"Google Drive mounted. Will save to: {DRIVE_CACHE_DIR}")
    except Exception:  # e.g. ImportError outside Colab; avoids swallowing KeyboardInterrupt
        print("Not running in Colab or Drive mount failed. Saving locally only.")
        SAVE_TO_DRIVE = False

In [None]:
# ============================================================
# GENERATE KD CACHE
# ============================================================
# Estimated time: ~4-8 hours for 50K sequences on A100
#
# With --enable_thinking all, each input generates 3 variants:
#   1. With thinking tags
#   2. Without thinking tags  
#   3. Raw text (no template)
# ============================================================

# Colab default path; adjust if the repo was cloned somewhere else
%cd /content/qwen3_apple_style_2bit_qat_lora

!python scripts/precompute_teacher_topk.py \
    --teacher_model_name_or_path {MODEL_NAME} \
    --dataset_name {DATASET_NAME} \
    --dataset_split train \
    --dataset_format {DATASET_FORMAT} \
    --enable_thinking {ENABLE_THINKING} \
    --max_length {MAX_LENGTH} \
    --topk {TOP_K} \
    --rand_neg {RAND_NEG} \
    --num_sequences {NUM_SEQUENCES} \
    --output_dir {CACHE_DIR} \
    --batch_size {BATCH_SIZE} \
    --shard_size {SHARD_SIZE} \
    --dtype bf16 \
    --device auto
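
The script's internals are not shown in this notebook. As a rough illustration of what caching top-k outputs plus random negatives can mean for a single token position, here is a pure-Python sketch. The function name and the `neg_*` keys are hypothetical; `topk_indices` and `topk_probs` mirror the keys the inspection cell looks for. A real implementation would presumably work on batched GPU tensors rather than Python lists.

```python
import heapq
import math
import random


def topk_with_random_negatives(logits, top_k, rand_neg, rng=random):
    """Keep the top_k highest-logit token ids plus rand_neg random other ids."""
    vocab = range(len(logits))
    top_ids = heapq.nlargest(top_k, vocab, key=logits.__getitem__)
    top_set = set(top_ids)
    rest = [i for i in vocab if i not in top_set]
    neg_ids = rng.sample(rest, min(rand_neg, len(rest)))

    # Softmax over the full vocabulary, so the kept probabilities are the
    # teacher's true probabilities rather than renormalized ones.
    peak = max(logits)
    norm = sum(math.exp(x - peak) for x in logits)

    def prob(i):
        return math.exp(logits[i] - peak) / norm

    return {
        "topk_indices": top_ids,
        "topk_probs": [prob(i) for i in top_ids],
        "neg_indices": neg_ids,
        "neg_probs": [prob(i) for i in neg_ids],
    }


# Toy vocabulary of 6 tokens: keep the 2 best and 2 random others.
entry = topk_with_random_negatives([2.0, -1.0, 0.5, 3.0, 0.0, -2.0],
                                   top_k=2, rand_neg=2)
print(entry["topk_indices"], entry["neg_indices"])
```

Storing only these entries instead of the full vocabulary distribution is what keeps the cache small enough to write to disk.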

In [None]:
# ============================================================
# VERIFY CACHE
# ============================================================

import os
import json

if os.path.isdir(CACHE_DIR):
    # Count shards
    shards = [f for f in os.listdir(CACHE_DIR) if f.startswith('shard_')]
    print(f"[cache] Generated {len(shards)} shards")
    
    # Check meta.json
    meta_path = os.path.join(CACHE_DIR, 'meta.json')
    if os.path.exists(meta_path):
        with open(meta_path) as f:
            meta = json.load(f)
        print(f"[cache] Meta info:")
        for k, v in meta.items():
            print(f"  - {k}: {v}")
    
    # Calculate size
    total_size = sum(os.path.getsize(os.path.join(CACHE_DIR, f)) for f in os.listdir(CACHE_DIR))
    print(f"[cache] Total size: {total_size / (1024**3):.2f} GB")
else:
    print(f"[cache] ERROR: {CACHE_DIR} not found")

In [None]:
# ============================================================
# INSPECT A SAMPLE SHARD
# ============================================================

import torch
import os

shard_files = sorted([f for f in os.listdir(CACHE_DIR) if f.startswith('shard_')])
if shard_files:
    shard_path = os.path.join(CACHE_DIR, shard_files[0])
    shard = torch.load(shard_path, map_location='cpu')
    
    print(f"Shard: {shard_files[0]}")
    print(f"Keys: {list(shard.keys())}")
    
    if 'input_ids' in shard:
        print(f"\ninput_ids shape: {shard['input_ids'].shape}")
    if 'topk_indices' in shard:
        print(f"topk_indices shape: {shard['topk_indices'].shape}")
    if 'topk_probs' in shard:
        print(f"topk_probs shape: {shard['topk_probs'].shape}")

In [None]:
# ============================================================
# SAVE TO GOOGLE DRIVE
# ============================================================

if SAVE_TO_DRIVE and os.path.isdir(CACHE_DIR):
    print(f"[save] Copying {CACHE_NAME} to Google Drive...")
    !rsync -ah --info=progress2 {CACHE_DIR}/ {DRIVE_CACHE_DIR}/{CACHE_NAME}/
    
    # Verify
    gd_path = f"{DRIVE_CACHE_DIR}/{CACHE_NAME}"
    if os.path.isdir(gd_path):
        num_files = len(os.listdir(gd_path))
        print(f"[save] Successfully saved to Google Drive: {num_files} files")
    else:
        print(f"[save] ERROR: Failed to copy to Google Drive")
else:
    print("[save] Skipping Google Drive save")
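
The check above only confirms that the destination directory exists. A slightly stricter check, sketched below with a hypothetical helper name, compares file names and sizes between the local cache and the Drive copy. It does not hash file contents, so treat it as a quick sanity check rather than an integrity guarantee.

```python
import os


def compare_dirs(src, dst):
    """Return a list of files missing from dst or whose size differs from src."""
    problems = []
    for name in sorted(os.listdir(src)):
        src_path = os.path.join(src, name)
        if not os.path.isfile(src_path):
            continue  # only top-level files are compared
        dst_path = os.path.join(dst, name)
        if not os.path.isfile(dst_path):
            problems.append(f"missing: {name}")
        elif os.path.getsize(dst_path) != os.path.getsize(src_path):
            problems.append(f"size mismatch: {name}")
    return problems


# Usage in this notebook (names taken from the cells above):
# problems = compare_dirs(CACHE_DIR, f"{DRIVE_CACHE_DIR}/{CACHE_NAME}")
# print("[save] OK" if not problems else problems)
```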

## Memory Optimization Options

If you run out of memory with Qwen3-32B, try these options:

### Option 1: Reduce batch size
```python
BATCH_SIZE = 4  # or even 2
```

### Option 2: Use 4-bit quantization (requires a script change)
```python
# Modify precompute_teacher_topk.py to support load_in_4bit,
# or load the model in a custom cell:
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto"
)
```

### Option 3: Use Qwen3-14B instead
```python
MODEL_NAME = 'Qwen/Qwen3-14B'  # Fits in 40GB with bf16
```

## Summary

**Cache generated:** `openhermes_32B_L128_K128_ALL_N50K`

**Contains:**
- 50K sequences × 3 variants = ~150K effective training samples
- Top-128 teacher logits from Qwen3-32B
- Three thinking modes: think, no-think, no-template

**To use in training:**
```python
CACHE_DIR = 'caches/openhermes_32B_L128_K128_ALL_N50K'
```

**Load from Google Drive:**
```bash
rsync -ah --info=progress2 /content/drive/MyDrive/qwen3_caches/openhermes_32B_L128_K128_ALL_N50K/ caches/openhermes_32B_L128_K128_ALL_N50K/
```