# Qwen3 4-bit KD-QAT Refinement + LoRA Recovery

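This notebook refines a 4-bit QAT checkpoint of `Qwen/Qwen3-0.6B` with KD-QAT, first on the K=64 KD cache and then on the K=128 cache, and finishes with LoRA recovery training.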

**Pipeline:**
```
4-bit QAT checkpoint
    |
    v
Stage 1: KD-QAT with K=64 (warm-up)
    |
    v
Stage 2: KD-QAT with K=128 (polish)
    |
    v
Inference Test
    |
    v
Stage 3: LoRA Recovery
    |
    v
Final Inference Test
```

**Key Features:**
- Progressive K: K=64 → K=128 for a refined teacher signal
- Unfrozen attention in later stages
- Relaxed hard-top1 weights for better convergence
- LoRA recovery for final quality boost
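To make "refined teacher signal" concrete: assuming the `K` in the cache names is the number of top teacher logits stored per position, a larger K covers more of the teacher's probability mass. The pure-Python sketch below (toy logits; `topk_mass` is an illustrative name, not part of the repository) measures that coverage after temperature scaling with T=2.0, the `--distill_temperature` used in every stage.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax with temperature scaling."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def topk_mass(logits, k, temperature=1.0):
    """Fraction of the tempered teacher distribution covered by its k largest entries."""
    probs = sorted(softmax(logits, temperature), reverse=True)
    return sum(probs[:k])

# Toy "teacher" logits over a 1,000-token vocabulary (illustrative only).
toy_logits = [10.0 - 0.02 * i for i in range(1000)]

for k in (64, 128):
    print(f"K={k}: {topk_mass(toy_logits, k, temperature=2.0):.3f} of the T=2.0 mass")
```

The toy distribution is deliberately flat; on a peaked distribution most of the mass sits in the first few tokens and the gap between K=64 and K=128 shrinks.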

In [1]:
# ============================================================
# SETUP
# ============================================================

!git clone https://github.com/anemll/qwen3_apple_style_2bit_qat_lora.git
%cd qwen3_apple_style_2bit_qat_lora
!pip install -q transformers accelerate datasets sentencepiece

Cloning into 'qwen3_apple_style_2bit_qat_lora'...
remote: Enumerating objects: 283, done.
remote: Counting objects: 100% (283/283), done.
remote: Compressing objects: 100% (207/207), done.
remote: Total 283 (delta 172), reused 165 (delta 73), pack-reused 0 (from 0)
Receiving objects: 100% (283/283), 369.92 KiB | 5.52 MiB/s, done.
Resolving deltas: 100% (172/172), done.
/content/qwen3_apple_style_2bit_qat_lora


In [3]:
# ============================================================
# CONFIG
# ============================================================

MODEL_NAME = 'Qwen/Qwen3-0.6B'

# Quantization
QUANT_BITS = 4

# Device settings
DEVICE = 'auto'
AMP_DTYPE = 'auto'
PARAM_DTYPE = 'auto'

# KD Caches (progressive: K64 -> K128)
CACHE_K64 = 'caches/alpaca_chat_think_both_L128_K64_R512'
CACHE_K128 = 'caches/alpaca_chat_think_both_L128_K128_R512'

# Input checkpoint (your trained 4-bit QAT)
INIT_CHECKPOINT = 'runs/qwen3_kdqat_cache_q4/qat_state_dict.pt'

# Output directories
RUN_K64 = 'runs/qwen3_q4_kd_k64'
RUN_K128 = 'runs/qwen3_q4_kd_k128'
RUN_LORA = 'runs/qwen3_q4_lora_recovery'

# Training parameters
BATCH_SIZE = 64
GRAD_ACCUM = 1

print(f"Config:")
print(f"  - Model: {MODEL_NAME}")
print(f"  - Quant bits: {QUANT_BITS}")
print(f"  - Init checkpoint: {INIT_CHECKPOINT}")
print(f"  - Caches: K64={CACHE_K64}, K128={CACHE_K128}")

Config:
  - Model: Qwen/Qwen3-0.6B
  - Quant bits: 4
  - Init checkpoint: runs/qwen3_kdqat_cache_q4/qat_state_dict.pt
  - Caches: K64=caches/alpaca_chat_think_both_L128_K64_R512, K128=caches/alpaca_chat_think_both_L128_K128_R512
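The cache directory names appear to encode their settings. Assuming the pattern `<dataset>_L<max length>_K<top-K>_R<...>` (the meaning of `R` is not stated in this notebook), a small helper can confirm that each cache's `L` tag matches the `--max_length 128` used in Stages 1 and 2. `parse_cache_tags` is a hypothetical helper, not part of the repository:

```python
import re

def parse_cache_tags(cache_path):
    """Extract the trailing L/K/R numeric tags from a KD cache directory name.

    Assumption: names end in '_L<len>_K<topk>_R<n>', as in
    'alpaca_chat_think_both_L128_K64_R512'. The meaning of R is not documented here.
    """
    name = cache_path.rstrip('/').split('/')[-1]
    match = re.search(r'_L(\d+)_K(\d+)_R(\d+)$', name)
    if match is None:
        raise ValueError(f"unrecognized cache name: {name}")
    length, topk, r = (int(g) for g in match.groups())
    return {'L': length, 'K': topk, 'R': r}

MAX_LENGTH = 128  # mirrors --max_length in Stages 1 and 2 (hypothetical variable name)

for path in ('caches/alpaca_chat_think_both_L128_K64_R512',
             'caches/alpaca_chat_think_both_L128_K128_R512'):
    tags = parse_cache_tags(path)
    # Flag a mismatch between the cache's assumed length tag and the training sequence length.
    status = 'ok' if tags['L'] == MAX_LENGTH else 'MISMATCH'
    print(path, tags, status)
```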


In [4]:
# ============================================================
# MOUNT GOOGLE DRIVE
# ============================================================

from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [7]:
# ============================================================
# LOAD CHECKPOINT FROM GOOGLE DRIVE
# ============================================================

import os

GD_BASE = '/content/drive/MyDrive/qwen3_runs'

# Load 4-bit checkpoint
CHECKPOINT_TAR = 'qwen3_kdqat_cache_q2_4.tgz'  # Adjust name as needed
GD_CHECKPOINT = f'{GD_BASE}/{CHECKPOINT_TAR}'

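if os.path.exists(GD_CHECKPOINT):
    print(f"[load] Extracting checkpoint from Google Drive...")
    !mkdir -p runs
    !tar -xzf {GD_CHECKPOINT} -C runs/
    print(f"[load] Done.")
    # The archive name need not match INIT_CHECKPOINT's folder; verify the path Stage 1 will read.
    if not os.path.exists(INIT_CHECKPOINT):
        print(f"[load] WARNING: {INIT_CHECKPOINT} not found after extraction; adjust INIT_CHECKPOINT.")
else:
    print(f"[load] Checkpoint not found: {GD_CHECKPOINT}")
    print("Please upload your 4-bit checkpoint or adjust the path.")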

[load] Extracting checkpoint from Google Drive...
[load] Done.
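The archive name (`qwen3_kdqat_cache_q2_4.tgz`) does not obviously match the folder in `INIT_CHECKPOINT` (`runs/qwen3_kdqat_cache_q4/`). One way to see where `qat_state_dict.pt` will land after `tar -xzf ... -C runs/` is to list the archive with the standard-library `tarfile` module. `checkpoint_paths_in_tar` is a hypothetical helper:

```python
import posixpath
import tarfile

def checkpoint_paths_in_tar(tar_path, dest_dir='runs', filename='qat_state_dict.pt'):
    """Return the paths `filename` would have after extracting `tar_path` into `dest_dir`."""
    with tarfile.open(tar_path, 'r:*') as tar:
        # normpath drops a leading './' that some archivers add to member names.
        names = [posixpath.normpath(m.name) for m in tar.getmembers() if m.isfile()]
    return [posixpath.join(dest_dir, n) for n in names if posixpath.basename(n) == filename]

# Usage in the notebook (after the Drive mount):
# paths = checkpoint_paths_in_tar(GD_CHECKPOINT)
# print(paths)
# if INIT_CHECKPOINT not in paths:
#     print("Set INIT_CHECKPOINT to one of the paths above.")
```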


In [8]:
# ============================================================
# LOAD KD CACHES FROM GOOGLE DRIVE
# ============================================================

import os

GD_BASE = '/content/drive/MyDrive/qwen3_caches'

!mkdir -p caches

# Load K=64 cache
cache_name_64 = 'alpaca_chat_think_both_L128_K64_R512'
gd_cache_64 = f"{GD_BASE}/{cache_name_64}"
if os.path.isdir(gd_cache_64):
    print(f"[cache] Copying K=64 cache...")
    !rsync -ah --info=progress2 {gd_cache_64}/ caches/{cache_name_64}/
    print(f"[cache] K=64 ready: {CACHE_K64}")
else:
    print(f"[cache] K=64 not found at {gd_cache_64}")

# Load K=128 cache
cache_name_128 = 'alpaca_chat_think_both_L128_K128_R512'
gd_cache_128 = f"{GD_BASE}/{cache_name_128}"
if os.path.isdir(gd_cache_128):
    print(f"[cache] Copying K=128 cache...")
    !rsync -ah --info=progress2 {gd_cache_128}/ caches/{cache_name_128}/
    print(f"[cache] K=128 ready: {CACHE_K128}")
else:
    print(f"[cache] K=128 not found at {gd_cache_128}")

[cache] Copying K=64 cache...
          8.79G 100%   85.37MB/s    0:01:38 (xfr#21, to-chk=0/22)
[cache] K=64 ready: caches/alpaca_chat_think_both_L128_K64_R512
[cache] K=128 not found at /content/drive/MyDrive/qwen3_caches/alpaca_chat_think_both_L128_K128_R512
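The output above shows that the K=128 cache was not found, yet Stages 2 and 3 both read `CACHE_K128`. A small pre-flight check (a sketch; `cache_summary` is a hypothetical name) can catch a missing or empty cache before a training run starts:

```python
import os

def cache_summary(cache_dir):
    """Return (num_files, total_bytes) for a cache directory, or None if it is missing."""
    if not os.path.isdir(cache_dir):
        return None
    n_files, n_bytes = 0, 0
    for root, _dirs, files in os.walk(cache_dir):
        for name in files:
            n_files += 1
            n_bytes += os.path.getsize(os.path.join(root, name))
    return n_files, n_bytes

# Usage in the notebook: stop before training if a required cache is absent or empty.
# for path in (CACHE_K64, CACHE_K128):
#     info = cache_summary(path)
#     if not info or info[0] == 0:
#         raise FileNotFoundError(f"KD cache missing or empty: {path}")
#     print(f"{path}: {info[0]} files, {info[1] / 1e9:.2f} GB")
```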


## Stage 1: KD-QAT with K=64

Warm-up refinement with the K=64 cache at a learning rate of 3e-6, using relaxed hard-top1 weights (`--hard-top1-weight 0.01`, `--hard-full-top1-weight 0.005`).
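The loss in `scripts/train_qat.py` is not shown in this notebook. As an assumption-labeled sketch of how the flags might combine for a single token position: a temperature-scaled KL divergence between the cached top-K teacher distribution and the student restricted to the same K token ids (multiplied by T², the convention from Hinton et al.), weighted by `--distill_weight`, plus a small cross-entropy toward the teacher's top-1 token weighted by `--hard-top1-weight`. `--hard-full-top1-weight` presumably adds a similar term over the full vocabulary and is not modeled. All function names are hypothetical:

```python
import math

def log_softmax(xs):
    """Numerically stable log-softmax over a list of floats."""
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]

def topk_kd_loss(student_logits, teacher_topk_idx, teacher_topk_logits,
                 temperature=2.0, distill_weight=1.0, hard_top1_weight=0.01):
    """Per-position KD loss against a cached top-K teacher (illustrative sketch)."""
    # Restrict the student to the teacher's top-K token ids.
    s_sub = [student_logits[i] for i in teacher_topk_idx]
    t_logp = log_softmax([x / temperature for x in teacher_topk_logits])
    s_logp = log_softmax([x / temperature for x in s_sub])
    # KL(teacher || student) over the K entries, scaled by T^2.
    kl = sum(math.exp(tl) * (tl - sl) for tl, sl in zip(t_logp, s_logp))
    kd = kl * temperature ** 2
    # Hard top-1 term: student cross-entropy (T=1, over the K ids) on the teacher's argmax.
    top1 = max(range(len(teacher_topk_logits)), key=lambda j: teacher_topk_logits[j])
    hard = -log_softmax(s_sub)[top1]
    return distill_weight * kd + hard_top1_weight * hard

# Toy example: vocabulary of 10, cached top-3 teacher entries.
student = [0.1 * i for i in range(10)]
print(topk_kd_loss(student, teacher_topk_idx=[9, 8, 7], teacher_topk_logits=[3.0, 2.0, 1.0]))
```

Under this reading, lowering `hard_top1_weight` (0.01 in Stage 1, 0.005 in Stage 2) lets the soft teacher distribution dominate the objective.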

In [None]:
# ============================================================
# STAGE 1: KD-QAT with K=64
# ============================================================

%cd /content/qwen3_apple_style_2bit_qat_lora

!python scripts/train_qat.py \
  --model_name_or_path {MODEL_NAME} \
  --init_model_state {INIT_CHECKPOINT} \
  --output_dir {RUN_K64} \
  --device {DEVICE} \
  --amp_dtype {AMP_DTYPE} \
  --param_dtype {PARAM_DTYPE} \
  -q {QUANT_BITS} \
  --max_length 128 \
  --per_device_train_batch_size {BATCH_SIZE} \
  --gradient_accumulation_steps {GRAD_ACCUM} \
  --learning_rate 3e-6 \
  --warmup_steps 50 \
  --max_steps 1500 \
  --save_steps 1500 \
  --logging_steps 10 \
  --skip_lm_head \
  --ema_decay 0 \
  --kd_cache_dir {CACHE_K64} \
  --kd_cache_shuffle_files \
  --distill_temperature 2.0 \
  --distill_weight 1.0 \
  --hard-top1-weight 0.01 \
  --hard-full-top1-weight 0.005

print(f"\n[Stage 1] K=64 refinement complete. Checkpoint: {RUN_K64}")

## Stage 2: KD-QAT with K=128 (Polish)

Final polish with the K=128 cache for a richer teacher signal than K=64, at a lower learning rate (1e-6 vs. 3e-6), with no warm-up and halved hard-top1 weights (0.005 / 0.002).

In [None]:
# ============================================================
# STAGE 2: KD-QAT with K=128 (Polish)
# ============================================================

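%cd /content/qwen3_apple_style_2bit_qat_lora

import os
# Fail fast: this stage needs the K=128 cache, which the cache-loading cell may report as missing.
assert os.path.isdir(CACHE_K128), f"K=128 KD cache not found: {CACHE_K128}"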

!python scripts/train_qat.py \
  --model_name_or_path {MODEL_NAME} \
  --init_model_state {RUN_K64}/qat_state_dict.pt \
  --output_dir {RUN_K128} \
  --device {DEVICE} \
  --amp_dtype {AMP_DTYPE} \
  --param_dtype {PARAM_DTYPE} \
  -q {QUANT_BITS} \
  --max_length 128 \
  --per_device_train_batch_size {BATCH_SIZE} \
  --gradient_accumulation_steps {GRAD_ACCUM} \
  --learning_rate 1e-6 \
  --warmup_steps 0 \
  --max_steps 500 \
  --save_steps 500 \
  --logging_steps 10 \
  --skip_lm_head \
  --ema_decay 0 \
  --kd_cache_dir {CACHE_K128} \
  --kd_cache_shuffle_files \
  --distill_temperature 2.0 \
  --distill_weight 1.0 \
  --hard-top1-weight 0.005 \
  --hard-full-top1-weight 0.002

print(f"\n[Stage 2] K=128 polish complete. Checkpoint: {RUN_K128}")

## Inference Test (Pre-LoRA)

In [None]:
# ============================================================
# INFERENCE TEST (Pre-LoRA)
# ============================================================

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import sys
sys.path.append('/content/qwen3_apple_style_2bit_qat_lora')

from qat_lora.model_utils import replace_linear_with_qat
from qat_lora.quantizer import QATQuantConfig

# Load model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)

# Apply QAT structure
qc = QATQuantConfig(n_bits=QUANT_BITS)
replace_linear_with_qat(model, qc=qc, exclude_regex=r"(^lm_head$)", verbose=False)

# Load trained weights (K=128 checkpoint)
CHECKPOINT_TO_TEST = f"{RUN_K128}/qat_state_dict.pt"
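state_dict = torch.load(CHECKPOINT_TO_TEST, map_location='cpu')
# strict=False silently skips mismatched keys; report them so a wrong checkpoint is noticed.
load_result = model.load_state_dict(state_dict, strict=False)
print(f"missing keys: {len(load_result.missing_keys)}, unexpected keys: {len(load_result.unexpected_keys)}")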
model = model.to('cuda').eval()

print(f"Loaded: {CHECKPOINT_TO_TEST}")
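`replace_linear_with_qat` and `QATQuantConfig(n_bits=QUANT_BITS)` come from the repository, and their exact scheme (group size, scale learning, rounding) is not shown here. As a rough mental model only, the sketch below applies symmetric, per-row, round-to-nearest 4-bit fake quantization to a small weight matrix in pure Python. Real QAT typically pairs the rounding with a straight-through estimator so gradients can flow, which this sketch omits; `fake_quantize` is a hypothetical name.

```python
def fake_quantize_row(row, n_bits=4):
    """Symmetric round-to-nearest fake quantization of one weight row.

    Returns de-quantized floats, i.e. the values a QAT forward pass would use.
    """
    qmax = 2 ** (n_bits - 1) - 1          # 7 for 4-bit
    qmin = -qmax - 1                      # -8 for 4-bit
    max_abs = max(abs(w) for w in row)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    # Python's round() rounds exact halves to even.
    q = [min(qmax, max(qmin, round(w / scale))) for w in row]
    return [qi * scale for qi in q]

def fake_quantize(weight, n_bits=4):
    """Per-output-channel (per-row) fake quantization of a 2-D weight given as nested lists."""
    return [fake_quantize_row(r, n_bits) for r in weight]

w = [[0.70, -0.35, 0.10, 0.0], [0.02, -0.04, 0.01, 0.03]]
wq = fake_quantize(w, n_bits=4)
print(wq)
```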

In [None]:
# ============================================================
# GENERATE TEST (Pre-LoRA)
# ============================================================

def generate(prompt, max_new_tokens=100, temperature=0.7):
    messages = [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to('cuda')

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.9,
            repetition_penalty=1.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

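    # Decode only the newly generated tokens so the prompt is not echoed back.
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return response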

# Test prompts
prompts = [
    "What is 2 + 2?",
    "Explain quantum computing in simple terms.",
    "Write a haiku about programming.",
]

print("=" * 60)
print("PRE-LORA INFERENCE TEST")
print("=" * 60)

for p in prompts:
    print(f"\nPrompt: {p}")
    print("-" * 40)
    print(generate(p))
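Qwen3's chat template enables thinking mode by default, so with `max_new_tokens=100` a response may be mostly (or entirely) a `<think>` block. A small post-processing helper can separate reasoning from the answer; it is hypothetical and assumes the `<think>`/`</think>` tags survive `skip_special_tokens=True` as plain text.

```python
def split_thinking(response):
    """Split a decoded Qwen3-style response into (thinking, answer).

    Assumes reasoning, if present, is wrapped in <think>...</think>. If the closing
    tag is missing (e.g. generation stopped mid-thought), everything is thinking.
    """
    open_tag, close_tag = "<think>", "</think>"
    if close_tag in response:
        thinking, answer = response.split(close_tag, 1)
        return thinking.replace(open_tag, "", 1).strip(), answer.strip()
    if open_tag in response:
        return response.split(open_tag, 1)[1].strip(), ""
    return "", response.strip()

# Usage in the notebook:
# thinking, answer = split_thinking(generate(p))
```

Alternatively, the Qwen3 model card shows passing `enable_thinking=False` to `tokenizer.apply_chat_template` to turn thinking off for a quick smoke test.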

In [None]:
# Free GPU memory before LoRA training (the training script runs in a separate process)
import gc
del model
gc.collect()
torch.cuda.empty_cache()

## Stage 3: LoRA Recovery

Train LoRA adapters on top of the refined QAT checkpoint to recover any remaining quality loss.
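Assuming `train_lora_recovery.py` uses the standard LoRA parameterization (Hu et al.), each adapted layer computes `W_q x + (alpha / r) * B (A x)`, with the quantized base `W_q` frozen and `B` zero-initialized so training starts exactly from the Stage 2 model. With `LORA_R = 32` and `LORA_ALPHA = 32` the scale `alpha / r` is 1.0. `LoRALinearSketch` and `lora_param_count` are illustrative names, not the repository's classes:

```python
import random

def matvec(m, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, x)) for row in m]

def lora_param_count(in_features, out_features, r):
    """Trainable parameters added by one rank-r adapter: A is r x in, B is out x r."""
    return r * (in_features + out_features)

class LoRALinearSketch:
    """Frozen (already fake-quantized) base weight plus a trainable low-rank update."""

    def __init__(self, base_weight, r, alpha, seed=0):
        rng = random.Random(seed)
        out_f, in_f = len(base_weight), len(base_weight[0])
        self.w = base_weight                                                       # frozen
        self.a = [[rng.gauss(0.0, 0.01) for _ in range(in_f)] for _ in range(r)]   # r x in
        self.b = [[0.0] * r for _ in range(out_f)]                                 # out x r, zero init
        self.scale = alpha / r

    def forward(self, x):
        base = matvec(self.w, x)
        delta = matvec(self.b, matvec(self.a, x))
        return [y + self.scale * d for y, d in zip(base, delta)]

    def num_trainable(self):
        return len(self.a) * len(self.a[0]) + len(self.b) * len(self.b[0])

# For a 1024 x 1024 projection at r=32, the adapter adds 65,536 trainable weights.
print(lora_param_count(1024, 1024, 32), "vs", 1024 * 1024, "frozen base weights")
```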

In [None]:
# ============================================================
# STAGE 3: LoRA Recovery
# ============================================================

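%cd /content/qwen3_apple_style_2bit_qat_lora

import os
# Fail fast: LoRA recovery also distills from the K=128 cache.
assert os.path.isdir(CACHE_K128), f"K=128 KD cache not found: {CACHE_K128}"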

LORA_R = 32
LORA_ALPHA = 32

!python scripts/train_lora_recovery.py \
  --model_name_or_path {MODEL_NAME} \
  --qat_checkpoint {RUN_K128}/qat_state_dict.pt \
  --output_dir {RUN_LORA} \
  --device {DEVICE} \
  --amp_dtype {AMP_DTYPE} \
  --param_dtype {PARAM_DTYPE} \
  -q {QUANT_BITS} \
  --per_device_train_batch_size 16 \
  --gradient_accumulation_steps 2 \
  --learning_rate 1e-5 \
  --warmup_steps 50 \
  --max_steps 2000 \
  --save_steps 2000 \
  --logging_steps 10 \
  --skip_lm_head \
  --lora_r {LORA_R} \
  --lora_alpha {LORA_ALPHA} \
  --lora_dropout 0.0 \
  --kd_cache_dir {CACHE_K128} \
  --kd_cache_shuffle_files \
  --distill_temperature 2.0 \
  --distill_weight 1.0 \
  --hard-top1-weight 0.02 \
  --hard-full-top1-weight 0.01

print(f"\n[Stage 3] LoRA recovery complete. Checkpoint: {RUN_LORA}")

## Final Inference Test (With LoRA)

In [None]:
# ============================================================
# FINAL INFERENCE TEST (With LoRA)
# ============================================================

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import sys
sys.path.append('/content/qwen3_apple_style_2bit_qat_lora')

from qat_lora.model_utils import replace_linear_with_qat
from qat_lora.quantizer import QATQuantConfig

# Load model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)

# Apply QAT structure with LoRA
qc = QATQuantConfig(n_bits=QUANT_BITS)
replace_linear_with_qat(
    model,
    qc=qc,
    exclude_regex=r"(^lm_head$)",
    lora_r=LORA_R,
    lora_alpha=LORA_ALPHA,
    verbose=False
)

# Load trained weights (LoRA checkpoint)
CHECKPOINT_TO_TEST = f"{RUN_LORA}/qat_lora_state_dict.pt"
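state_dict = torch.load(CHECKPOINT_TO_TEST, map_location='cpu')
# strict=False silently skips mismatched keys; report them so missing LoRA weights are noticed.
load_result = model.load_state_dict(state_dict, strict=False)
print(f"missing keys: {len(load_result.missing_keys)}, unexpected keys: {len(load_result.unexpected_keys)}")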
model = model.to('cuda').eval()

print(f"Loaded: {CHECKPOINT_TO_TEST}")

In [None]:
# ============================================================
# GENERATE TEST (With LoRA)
# ============================================================

print("=" * 60)
print("POST-LORA INFERENCE TEST")
print("=" * 60)

for p in prompts:
    print(f"\nPrompt: {p}")
    print("-" * 40)
    print(generate(p))

## Save to Google Drive

In [None]:
# ============================================================
# SAVE ALL CHECKPOINTS TO GOOGLE DRIVE
# ============================================================

import os

GD_DEST = '/content/drive/MyDrive/qwen3_runs'  # same Drive folder the input checkpoint was loaded from

# Save K=64 checkpoint
if os.path.isdir(RUN_K64):
    save_name = os.path.basename(RUN_K64)
    !tar -czvf {save_name}.tgz -C runs {save_name}
    !cp {save_name}.tgz {GD_DEST}/
    print(f"[save] K=64 checkpoint saved to {GD_DEST}/{save_name}.tgz")

# Save K=128 checkpoint
if os.path.isdir(RUN_K128):
    save_name = os.path.basename(RUN_K128)
    !tar -czvf {save_name}.tgz -C runs {save_name}
    !cp {save_name}.tgz {GD_DEST}/
    print(f"[save] K=128 checkpoint saved to {GD_DEST}/{save_name}.tgz")

# Save LoRA checkpoint
if os.path.isdir(RUN_LORA):
    save_name = os.path.basename(RUN_LORA)
    !tar -czvf {save_name}.tgz -C runs {save_name}
    !cp {save_name}.tgz {GD_DEST}/
    print(f"[save] LoRA checkpoint saved to {GD_DEST}/{save_name}.tgz")

print("\n[save] All checkpoints saved!")
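The shell version above prints "All checkpoints saved!" even when a run directory is missing. A possible pure-Python alternative (a sketch; `archive_run` and its `work_dir` parameter are hypothetical names, not part of the repository) uses the standard-library `tarfile` and `shutil` modules and returns only the archives it actually wrote:

```python
import os
import shutil
import tarfile

def archive_run(run_dir, dest_dir, work_dir="."):
    """Tar+gzip `run_dir` (its basename becomes the top-level folder) and copy it to `dest_dir`.

    Returns the copied archive's path, or None if `run_dir` does not exist.
    """
    if not os.path.isdir(run_dir):
        return None
    name = os.path.basename(os.path.normpath(run_dir))
    archive = os.path.join(work_dir, f"{name}.tgz")
    with tarfile.open(archive, "w:gz") as tar:
        tar.add(run_dir, arcname=name)  # same layout as `tar -czf name.tgz -C runs name`
    os.makedirs(dest_dir, exist_ok=True)
    return shutil.copy(archive, dest_dir)

# Usage in the notebook:
# saved = [p for p in (archive_run(d, GD_DEST) for d in (RUN_K64, RUN_K128, RUN_LORA)) if p]
# print(f"[save] {len(saved)} checkpoint archive(s) saved: {saved}")
```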

## Summary

**Pipeline completed:**

| Stage | Cache | Steps | Output |
|-------|-------|-------|--------|
| 1. KD-QAT warm-up | K=64 | 1500 | `runs/qwen3_q4_kd_k64` |
| 2. KD-QAT polish | K=128 | 500 | `runs/qwen3_q4_kd_k128` |
| 3. LoRA recovery | K=128 | 2000 | `runs/qwen3_q4_lora_recovery` |
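For comparing stages, the number of training sequences each one consumes follows from the step and batch flags in the training cells, assuming a single GPU and that `--max_steps` counts optimizer steps (each covering `--gradient_accumulation_steps` micro-batches). Stages 1 and 2 cap sequences at 128 tokens via `--max_length 128`; the Stage 3 cell does not pass `--max_length`, so its token count is not derivable here.

```python
# Hyperparameters copied from the training cells above.
stages = {
    "1. KD-QAT warm-up (K=64)": {"steps": 1500, "batch": 64, "accum": 1},
    "2. KD-QAT polish (K=128)": {"steps": 500, "batch": 64, "accum": 1},
    "3. LoRA recovery (K=128)": {"steps": 2000, "batch": 16, "accum": 2},
}

def sequences_seen(steps, batch, accum):
    """Sequences consumed = optimizer steps x per-device batch x grad-accum steps (one device)."""
    return steps * batch * accum

for name, cfg in stages.items():
    n = sequences_seen(**cfg)
    print(f"{name}: effective batch {cfg['batch'] * cfg['accum']}, {n:,} sequences")
```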

**Expected loss progression (rough targets, not measured results):**
- Initial 4-bit: ~0.5
- After K=64: ~0.4
- After K=128: ~0.35
- With LoRA: ~0.25-0.3

**Next steps:**
- Evaluate on downstream benchmarks
- Export for deployment (quantize LoRA weights if needed)