# Qwen3 2-bit QAT - Improved Pipeline v1

This notebook implements improved 2-bit QAT training with:

**Key Improvements:**
1. **MSE Grid-Search f_init** - Optimal scale initialization before training
2. **Full 3-Pass Progressive** - Three layer-by-layer passes (MLP, Attention, MLP Refinement) followed by E2E tuning
3. **Relaxed Convergence** - 2-bit needs looser thresholds (0.8-1.2 vs 0.4 for 4-bit)
4. **Lower Learning Rates** - More stable for 2-bit (1e-6 vs 5e-6 for 4-bit)

**Starting Point:** 4-bit QAT checkpoint (not from scratch)

**Pipeline:**
```
4-bit checkpoint
    |
    v
[MSE Grid-Search f_init Calibration] <- NEW
    |
    v
Pass 1: MLP L-b-L (local + global KD)
    |
    v
Pass 2: Attention L-b-L (global KD)  <- Previously skipped
    |
    v
Pass 3: MLP Refinement L-b-L        <- Previously skipped
    |
    v
Pass 4: E2E f-only tuning
    |
    v
[Optional: KD-QAT Refinement]
    |
    v
[Optional: LoRA Recovery]
```
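
As a rough illustration of the calibration step, the sketch below picks a scale by trying a grid of candidates and keeping the one with the lowest reconstruction MSE. It is not the repository's implementation: the function names, the symmetric four-level grid, and the treatment of `f_init` as a single per-tensor scale are all assumptions made for illustration.

```python
import math

# Assumed symmetric 2-bit grid (4 levels); the repository may use a different one.
LEVELS_2BIT = (-1.5, -0.5, 0.5, 1.5)

def fake_quantize(weights, scale, levels=LEVELS_2BIT):
    """Snap each weight to the nearest representable value `level * scale`."""
    return [scale * min(levels, key=lambda q: abs(w / scale - q)) for w in weights]

def mse(a, b):
    """Mean squared error between two equal-length sequences."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

def mse_grid_search_scale(weights, levels=LEVELS_2BIT, num_candidates=100):
    """Return (best_scale, best_mse) over a linear grid of candidate scales."""
    max_abs = max(abs(w) for w in weights)
    if max_abs == 0.0:
        raise ValueError("all-zero weights have no meaningful scale")
    # Max-abs scale: the largest weight maps exactly onto the outermost level.
    base = max_abs / max(abs(q) for q in levels)
    best_scale, best_err = base, math.inf
    for i in range(1, num_candidates + 1):
        scale = base * i / num_candidates  # candidates from base/num_candidates up to base
        err = mse(weights, fake_quantize(weights, scale, levels))
        if err < best_err:
            best_scale, best_err = scale, err
    return best_scale, best_err

# Hypothetical weights with one outlier
weights = [0.05, -0.1, 0.2, -0.15, 0.1, 3.0]
scale, err = mse_grid_search_scale(weights)
print(f"scale={scale:.4f}  mse={err:.6f}")
```

Because the max-abs scale is itself one of the candidates, the selected scale is never worse than max-abs initialization in MSE terms.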

In [None]:
# ============================================================
# SETUP: Clone repo and install dependencies
# ============================================================

!git clone https://github.com/anemll/qwen3_apple_style_2bit_qat_lora.git
%cd qwen3_apple_style_2bit_qat_lora
!pip install -q transformers accelerate datasets sentencepiece

In [None]:
# ============================================================
# CONFIG
# ============================================================

MODEL_NAME = 'Qwen/Qwen3-0.6B'

# Quantization
QUANT_BITS = 2

# KD Cache (use K=64 if available, otherwise K=32)
# K=64 provides richer teacher signal for 2-bit
CACHE_DIR_CHAT = 'caches/alpaca_chat_think_both_L128_K64_R512'
# Fallback to K=32 if K=64 not available:
# CACHE_DIR_CHAT = 'caches/alpaca_chat_think_both_L128_K32_R256'

# 4-bit checkpoint to start from
INIT_CHECKPOINT = 'runs/qwen3_kdqat_cache_q2_4/qat_state_dict.pt'

# Output directory
RUN_DIR = 'runs/progressive_qat_q2_improved_v1'

# Training parameters (optimized for 2-bit)
BATCH_SIZE = 2
STEPS_PER_MLP = 50
STEPS_PER_ATTN = 30
E2E_STEPS = 200

# Lower learning rates for 2-bit stability
LEARNING_RATE = 1e-6      # Down from 5e-6 for 4-bit
E2E_LEARNING_RATE = 5e-7  # Down from 1e-6 for 4-bit

# Relaxed convergence for 2-bit (4 levels is hard!)
LAYER_CONVERGE_THRESHOLD = 1.0  # Up from 0.4 for 4-bit
MAX_LAYER_REPEATS = 3           # Allow more attempts

# Calibration method for f_init
CALIBRATE_METHOD = 'mse_grid'  # 'mse_grid', 'newton', or 'percentile'

print("Config:")
print(f"  - Model: {MODEL_NAME}")
print(f"  - Quant bits: {QUANT_BITS}")
print(f"  - Init checkpoint: {INIT_CHECKPOINT}")
print(f"  - Cache: {CACHE_DIR_CHAT}")
print(f"  - Output: {RUN_DIR}")
print(f"  - Calibration: {CALIBRATE_METHOD}")

In [None]:
# ============================================================
# MOUNT GOOGLE DRIVE
# ============================================================

from google.colab import drive
drive.mount('/content/drive')

In [None]:
# ============================================================
# LOAD 4-BIT CHECKPOINT FROM GOOGLE DRIVE
# ============================================================

import os

CHECKPOINT_TAR = 'qwen3_kdqat_cache_q2_4.tgz'
GD_CHECKPOINT = f'/content/drive/MyDrive/qwen3_caches/{CHECKPOINT_TAR}'

if os.path.exists(GD_CHECKPOINT):
    print(f"[load] Extracting 4-bit checkpoint from Google Drive...")
    !mkdir -p runs
    !tar -xzf {GD_CHECKPOINT} -C runs/
    
    # Verify
    if os.path.exists(INIT_CHECKPOINT):
        print(f"[load] Checkpoint ready: {INIT_CHECKPOINT}")
    else:
        print(f"[load] ERROR: Expected {INIT_CHECKPOINT} not found")
        !ls -la runs/
else:
    print(f"[load] ERROR: {GD_CHECKPOINT} not found")
    print("Please upload your 4-bit checkpoint to Google Drive first.")

In [None]:
# ============================================================
# LOAD KD CACHE FROM GOOGLE DRIVE
# ============================================================

import os

# Try K=64 first, fall back to K=32
CACHE_OPTIONS = [
    'alpaca_chat_think_both_L128_K64_R512',
    'alpaca_chat_think_both_L128_K32_R256',
]

GD_CACHE_DIR = '/content/drive/MyDrive/qwen3_caches'

cache_loaded = False
for cache_name in CACHE_OPTIONS:
    gd_cache_path = f"{GD_CACHE_DIR}/{cache_name}"
    if os.path.isdir(gd_cache_path):
        print(f"[cache] Found {cache_name}, copying...")
        !mkdir -p caches
        !rsync -ah --info=progress2 {gd_cache_path}/ caches/{cache_name}/
        CACHE_DIR_CHAT = f'caches/{cache_name}'
        cache_loaded = True
        break

if cache_loaded:
    print(f"[cache] Using: {CACHE_DIR_CHAT}")
else:
    print("[cache] ERROR: No KD cache found. Generate one first using Generate_KD_Cache_K64_K128.ipynb")

## Progressive QAT Training (Full 3-Pass)

This runs the complete progressive pipeline: three layer-by-layer (L-b-L) passes followed by end-to-end (E2E) tuning:
- **Pass 1**: MLP layers (local reconstruction + global KD)
- **Pass 2**: Attention layers (global KD only)
- **Pass 3**: MLP refinement (fixes MLP-attention coupling)
- **Pass 4**: E2E f-only tuning
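
The interaction between `LAYER_CONVERGE_THRESHOLD` and `MAX_LAYER_REPEATS` can be pictured as a retry loop around each layer. The sketch below is an assumption about how such a loop might look, not the logic of `train_qat_progressive.py`; `train_once` is a hypothetical callable that runs one block of steps on a single layer and returns the resulting loss.

```python
def train_layer_until_converged(train_once, threshold=1.0, max_repeats=3):
    """Call train_once() up to max_repeats times, stopping once loss <= threshold.

    Returns (converged, history), where history lists the loss after each attempt.
    """
    history = []
    for _ in range(max_repeats):
        loss = train_once()  # hypothetical: one block of training steps on this layer
        history.append(loss)
        if loss <= threshold:
            return True, history
    return False, history

# Toy stand-in for a real training step: the loss improves on every attempt.
fake_losses = iter([1.4, 1.1, 0.9])
converged, history = train_layer_until_converged(lambda: next(fake_losses))
print(converged, history)
```

With a 4-bit-style threshold of 0.4, the same toy losses would exhaust all three attempts without converging, which is the motivation for the relaxed value.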

In [None]:
# ============================================================
# PROGRESSIVE QAT TRAINING (Full 3-Pass with MSE Calibration)
# ============================================================
# This is the main training cell
# Expected time: ~2-3 hours on A100 for all 4 passes

%cd /content/qwen3_apple_style_2bit_qat_lora

!python scripts/train_qat_progressive.py \
  --model_name_or_path {MODEL_NAME} \
  --init_model_state {INIT_CHECKPOINT} \
  --output_dir {RUN_DIR} \
  --kd_cache_dir {CACHE_DIR_CHAT} \
  --quant_bits {QUANT_BITS} \
  --calibrate_f_init {CALIBRATE_METHOD} \
  --skip_lm_head \
  --batch_size {BATCH_SIZE} \
  --steps_per_layer_mlp {STEPS_PER_MLP} \
  --steps_per_layer_attn {STEPS_PER_ATTN} \
  --e2e_steps {E2E_STEPS} \
  --learning_rate {LEARNING_RATE} \
  --e2e_learning_rate {E2E_LEARNING_RATE} \
  --layer_converge_threshold {LAYER_CONVERGE_THRESHOLD} \
  --max_layer_repeats {MAX_LAYER_REPEATS} \
  --logging_steps 10 \
  --device auto

In [None]:
# ============================================================
# CHECK TRAINING RESULTS
# ============================================================

import os
import pandas as pd

loss_log = f"{RUN_DIR}/loss_per_layer.csv"
if os.path.exists(loss_log):
    df = pd.read_csv(loss_log)
    print("Loss summary by pass:")
    print(df.groupby(['pass', 'component'])['global'].agg(['min', 'max', 'mean']))
    
    # Final loss
    final_loss = df['global'].iloc[-1]
    print(f"\nFinal global loss: {final_loss:.4f}")
else:
    print(f"Loss log not found: {loss_log}")

## Optional: KD-QAT Refinement

If the final global loss reported above is still high (>1.5), set `RUN_KDQAT_REFINE = True` to run additional KD-QAT refinement.
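
The threshold check can also be scripted. The helper below is a sketch and not part of the repository: it assumes `loss_per_layer.csv` contains the `global` column read by the results cell and that its last row reflects the final loss. The function name is hypothetical.

```python
import csv

def needs_kdqat_refinement(loss_csv_path, threshold=1.5, column="global"):
    """Return True when the last logged value in `column` exceeds `threshold`."""
    with open(loss_csv_path, newline="") as f:
        rows = list(csv.DictReader(f))
    if not rows:
        raise ValueError(f"no rows in {loss_csv_path}")
    return float(rows[-1][column]) > threshold
```

It could then drive the flag directly, for example `RUN_KDQAT_REFINE = needs_kdqat_refinement(f"{RUN_DIR}/loss_per_layer.csv")`.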

In [None]:
# ============================================================
# OPTIONAL: KD-QAT REFINEMENT (if needed)
# ============================================================

RUN_KDQAT_REFINE = False  # Set to True to run
KDQAT_STEPS = 1000
KDQAT_LR = 1e-6
KDQAT_OUTPUT = 'runs/kdqat_refine_q2_improved_v1'

if RUN_KDQAT_REFINE:
    !python scripts/train_qat.py \
      --model_name_or_path {MODEL_NAME} \
      --init_model_state {RUN_DIR}/qat_state_dict.pt \
      --output_dir {KDQAT_OUTPUT} \
      --kd_cache_dir {CACHE_DIR_CHAT} \
      --quant_bits {QUANT_BITS} \
      --skip_lm_head \
      --batch_size {BATCH_SIZE} \
      --max_steps {KDQAT_STEPS} \
      --learning_rate {KDQAT_LR} \
      --logging_steps 50 \
      --save_steps 500 \
      --device auto
else:
    print("KD-QAT refinement skipped. Set RUN_KDQAT_REFINE = True to run.")

In [None]:
# ============================================================
# SAVE TO GOOGLE DRIVE
# ============================================================

import os

SAVE_NAME = 'progressive_qat_q2_improved_v1'
GD_DEST = f'/content/drive/MyDrive/qwen3_caches/{SAVE_NAME}.tgz'

if os.path.isdir(RUN_DIR):
    print(f"[save] Compressing {RUN_DIR}...")
    !tar -czvf {SAVE_NAME}.tgz -C runs {os.path.basename(RUN_DIR)}
    
    print(f"[save] Copying to Google Drive...")
    !cp {SAVE_NAME}.tgz {GD_DEST}
    
    if os.path.exists(GD_DEST):
        size_mb = os.path.getsize(GD_DEST) / (1024*1024)
        print(f"[save] Saved to {GD_DEST} ({size_mb:.1f} MB)")
    else:
        print(f"[save] ERROR: Failed to save")
else:
    print(f"[save] ERROR: {RUN_DIR} not found")

## Inference Test

In [None]:
# ============================================================
# INFERENCE TEST
# ============================================================

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import sys
sys.path.append('/content/qwen3_apple_style_2bit_qat_lora')

from qat_lora.model_utils import replace_linear_with_qat
from qat_lora.quantizer import QATQuantConfig

# Load model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)

# Apply QAT structure
qc = QATQuantConfig(n_bits=QUANT_BITS)
replace_linear_with_qat(model, qc=qc, exclude_regex=r"(^lm_head$)", verbose=False)

# Load trained weights
CHECKPOINT_TO_TEST = f"{RUN_DIR}/qat_state_dict.pt"
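state_dict = torch.load(CHECKPOINT_TO_TEST, map_location='cpu')
# strict=False tolerates key mismatches, so report them instead of hiding them
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"[warn] missing keys: {len(load_result.missing_keys)}, "
          f"unexpected keys: {len(load_result.unexpected_keys)}")
model = model.to('cuda').eval()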

print(f"Loaded: {CHECKPOINT_TO_TEST}")

In [None]:
# ============================================================
# GENERATE TEST
# ============================================================

def generate(prompt, max_new_tokens=100, temperature=0.7):
    messages = [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to('cuda')
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.9,
            repetition_penalty=1.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    
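    # Decode only the newly generated tokens, not the echoed prompt
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return response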

# Test prompts
prompts = [
    "What is 2 + 2?",
    "Explain quantum computing in simple terms.",
    "Write a haiku about programming.",
]

for p in prompts:
    print(f"\n{'='*60}")
    print(f"Prompt: {p}")
    print(f"{'='*60}")
    print(generate(p))

## Summary

**Improvements in this notebook:**

| Feature | Previous | This Notebook |
|---------|----------|---------------|
| f_init calibration | Newton heuristic | MSE Grid-Search |
| Progressive passes | MLP only | MLP + Attn + Refine |
| Converge threshold | 0.4 | 1.0 (relaxed) |
| Learning rate | 5e-6 | 1e-6 (lower) |
| Max repeats | 1 | 3 |

**If results are still not satisfactory, try:**
1. Use K=128 cache (generate with `Generate_KD_Cache_K64_K128.ipynb`)
2. Add EMA (future notebook)
3. Mixed precision (4-bit embeddings, 2-bit for the rest)
4. LoRA with error-initialized weights
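
Item 1 rests on the idea that a larger K gives the student a fuller view of the teacher distribution. The sketch below shows a top-K distillation loss for a single token position. It assumes the cache stores the teacher's K largest logits together with their token ids; the cache layout and every name here are assumptions, not taken from the repository.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def topk_kd_loss(student_logits, teacher_topk_ids, teacher_topk_logits, temperature=1.0):
    """KL(teacher || student), with both distributions renormalized over the teacher's top-K ids."""
    t = softmax([z / temperature for z in teacher_topk_logits])
    s = softmax([student_logits[i] / temperature for i in teacher_topk_ids])
    # Skip zero-probability teacher entries and clamp the student to avoid log(0).
    return sum(p * (math.log(p) - math.log(max(q, 1e-12)))
               for p, q in zip(t, s) if p > 0.0)

# Toy vocabulary of 4 tokens, teacher cache with K=2
student = [0.0, 2.0, 1.0, -1.0]
print(topk_kd_loss(student, [1, 2], [1.0, 2.0]))
```

In this formulation only the K cached token ids contribute to the loss, so a larger K constrains more of the student's distribution at the cost of a bigger cache.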