<a href="https://colab.research.google.com/github/Anemll/qwen3_apple_style_2bit_qat_lora/blob/main/notebooks/Qwen3_QAT_KD_LoRA-per-layer-2bit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Qwen3-0.6B — Apple-style QAT (2-bit / 4-bit) + KD + LoRA recovery

This notebook mirrors the structure of common “phone deployment” notebooks, but uses **this repo’s** pipeline:

- **Stage A (recommended default):** KD-QAT on plain text (C4 streaming) or KD-cache QAT
- **Stage B:** LoRA recovery (either SFT or cached KD-LoRA)
- Plot `loss.csv`
- Run inference sanity checks

Notes:
- Qwen3 requires `transformers>=4.51.0`.
- Disk usage: C4 is huge, so prefer `--streaming` unless you explicitly want to download the dataset.
- Bitwidth: use `-q 2` (default) or `-q 4` (less aggressive). Checkpoints persist the bitwidth per layer.


## 0) Setup (Colab / local)

If you’re in Colab, clone the repo. If you’re already in the repo directory locally, you can skip this.

In [1]:
# ---- Config (edit these) ----
#MODEL_NAME = 'Qwen/Qwen3-4B-Thinking-2507'
MODEL_NAME = 'Qwen/Qwen3-0.6B'
TEACHER_NAME = MODEL_NAME
QUANT_BITS = 2  # 2 or 4
DEVICE = 'auto'
AMP_DTYPE = 'auto'
PARAM_DTYPE = 'auto'
DTYPE = 'auto'

# Cache dirs
CACHE_DIR_CHAT = 'caches/alpaca_chat_think_both_L128_K32_R256'
CACHE_DIR_TEXT = 'caches/c4_qwen3_L64_K32_R256'
#CACHE_DIR_CHAT = 'caches/Q4B_alpaca_chat_think_L128_K32_R256'
#CACHE_DIR_TEXT = 'caches/Q4B_c4_qwen3_L64_K32_R256'


In [2]:
# Colab-only:
%cd /content/
!git clone https://github.com/Anemll/qwen3_apple_style_2bit_qat_lora
%cd qwen3_apple_style_2bit_qat_lora
!git fetch
!git pull
!git reset --hard HEAD


/content
Cloning into 'qwen3_apple_style_2bit_qat_lora'...
remote: Enumerating objects: 228, done.
remote: Counting objects: 100% (228/228), done.
remote: Compressing objects: 100% (162/162), done.
remote: Total 228 (delta 131), reused 144 (delta 63), pack-reused 0 (from 0)
Receiving objects: 100% (228/228), 281.62 KiB | 9.08 MiB/s, done.
Resolving deltas: 100% (131/131), done.
/content/qwen3_apple_style_2bit_qat_lora
Already up to date.
HEAD is now at 69a5e58 Update layer-by-layer for 4 point checkpoint


## 1) Install dependencies (uv)

This repo is set up to work with `uv`.

In [3]:
!pip -q install uv
!uv pip install -r requirements.txt
!uv pip install -e .
# plotting
!uv pip install -q matplotlib
!uv pip install -q plot


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.2/22.2 MB[0m [31m123.5 MB/s[0m eta [36m0:00:00[0m
[?25h[2mUsing Python 3.12.12 environment at: /usr[0m
[2mAudited [1m7 packages[0m [2min 133ms[0m[0m
[2mUsing Python 3.12.12 environment at: /usr[0m
[2K[2mResolved [1m62 packages[0m [2min 419ms[0m[0m
[2K[2mPrepared [1m1 package[0m [2min 869ms[0m[0m
[2K[2mInstalled [1m1 package[0m [2min 1ms[0m[0m
 [32m+[39m [1mqat-lora[0m[2m==0.0.0 (from file:///content/qwen3_apple_style_2bit_qat_lora)[0m


## 2) Optional: Hugging Face login

If you hit gated model/dataset errors, log in.

In [None]:
from huggingface_hub import login
login()  # paste token when prompted


## 3) Quick environment check

In [4]:
import torch, transformers
print('torch', torch.__version__)
print('transformers', transformers.__version__)
print('cuda', torch.cuda.is_available())
print('mps', torch.backends.mps.is_available())


torch 2.9.0+cu126
transformers 4.57.3
cuda True
mps False


## 4) Stage A (Optional): KD-QAT on streaming C4

This stage aims to preserve the base model’s behavior when its weights are fake-quantized to low bit widths.

Tips:
- Start with a small run (`--max_steps 50`) to validate the pipeline.
- Use `-q 4` if 2-bit is too unstable; 4-bit is less aggressive.
- On MPS, prefer `--ema_decay 0` for KD-QAT.
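Fake quantization keeps the weights in floating point but snaps them onto a low-bit grid in the forward pass. Gradients pass through the rounding unchanged, which is the straight-through estimator (STE). With `-q 2` there are 2**2 = 4 levels per row, and with `-q 4` there are 16. The sketch below is a minimal illustration under stated assumptions: the `fake_quant` name, the symmetric half-integer level layout, and the per-row max-abs scale are hypothetical choices and do not describe this repo's quantizer.

```python
import torch

def fake_quant(w: torch.Tensor, bits: int) -> torch.Tensor:
    """Symmetric per-row fake quantization with a straight-through estimator.

    Hypothetical illustration: 2**bits evenly spaced levels centred on zero,
    e.g. bits=2 -> {-1.5, -0.5, 0.5, 1.5} * scale.
    """
    qmax = (2 ** bits - 1) / 2                      # 1.5 for 2-bit, 7.5 for 4-bit
    # Per-output-row scale so the largest |w| lands on the outermost level.
    scale = w.abs().amax(dim=1, keepdim=True).clamp_min(1e-8) / qmax
    # Shift by 0.5 so rounding produces half-integer levels, then clamp to the grid.
    q = torch.clamp(torch.round(w / scale - 0.5) + 0.5, -qmax, qmax)
    w_q = q * scale
    # STE: the forward pass sees w_q, the backward pass treats quantization as identity.
    return w + (w_q - w).detach()
```

Because of the `detach()` trick, `w.grad` after `fake_quant(w, 2).sum().backward()` is all ones. The optimizer keeps updating the full-precision weights, while the loss always sees their quantized version.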


In [None]:
# ============================================================
# GENERATE THINKING DATASET (Alpaca chat format)
# ============================================================
# SKIP THIS CELL if you already have the cache on Google Drive!
# Use the "LOAD FROM GOOGLE DRIVE" cell instead.

!python scripts/precompute_teacher_topk.py \
  --teacher_model_name_or_path {MODEL_NAME} \
  --dataset_name tatsu-lab/alpaca \
  --dataset_split train \
  --dataset_format alpaca_chat \
  --enable_thinking true \
  --max_length 128 \
  --topk 32 \
  --rand_neg 256 \
  --num_sequences 20000 \
  --batch_size 1 \
  --shard_size 512 \
  --device {DEVICE} \
  --dtype {DTYPE} \
  --output_dir {CACHE_DIR_CHAT}

[device] cuda | dtype=torch.bfloat16
tokenizer_config.json: 9.73kB [00:00, 39.0MB/s]
vocab.json: 2.78MB [00:00, 65.7MB/s]
merges.txt: 1.67MB [00:00, 132MB/s]
tokenizer.json: 100% 11.4M/11.4M [00:00<00:00, 20.4MB/s]
config.json: 100% 726/726 [00:00<00:00, 6.09MB/s]
2025-12-24 01:22:39.394668: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-12-24 01:22:39.417464: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1766539359.443242    8009 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1766539359.449385    8009 cuda_blas.cc:1407]

In [None]:
# ============================================================
# GENERATE TEXT DATASET (C4 streaming)
# ============================================================
# SKIP THIS CELL if you already have the cache on Google Drive!
# Use the "LOAD FROM GOOGLE DRIVE" cell instead.

import os

CACHE_DIR = CACHE_DIR_TEXT

if not os.path.isdir(CACHE_DIR):
    print(f"[cache] {CACHE_DIR} not found -> generating cache...")

    !python scripts/precompute_teacher_topk.py \
      --teacher_model_name_or_path {MODEL_NAME} \
      --dataset_name allenai/c4 \
      --dataset_config_name en \
      --dataset_split train \
      --dataset_text_field text \
      --streaming \
      --shuffle_buffer 10000 \
      --max_length 64 \
      --topk 32 \
      --rand_neg 256 \
      --num_sequences 2000 \
      --batch_size 1 \
      --shard_size 512 \
      --device {DEVICE} \
      --dtype {DTYPE} \
      --output_dir {CACHE_DIR}

else:
    print(f"[cache] {CACHE_DIR} already exists -> skipping generation")

In [None]:
# ============================================================
# COMPRESS CHAT CACHE (for Google Drive upload)
# ============================================================
# SKIP if cache is already compressed or loaded from Google Drive

import os

if os.path.isdir(CACHE_DIR_CHAT):
    print(f"[gzip] Compressing {CACHE_DIR_CHAT}...")
    !tar -zcvf {CACHE_DIR_CHAT}.tgz {CACHE_DIR_CHAT}
    compressed_size = os.path.getsize(f"{CACHE_DIR_CHAT}.tgz")
    print(f"[gzip] Done: {compressed_size / (1024**3):.2f} GB")
else:
    print(f"[gzip] Directory {CACHE_DIR_CHAT} not found. Skipping.")

[gzip] Compressing caches/alpaca_chat_think_L128_K32_R256...
caches/alpaca_chat_think_L128_K32_R256/
caches/alpaca_chat_think_L128_K32_R256/shard_00018.pt
caches/alpaca_chat_think_L128_K32_R256/shard_00011.pt
caches/alpaca_chat_think_L128_K32_R256/shard_00033.pt
caches/alpaca_chat_think_L128_K32_R256/shard_00031.pt
caches/alpaca_chat_think_L128_K32_R256/shard_00036.pt
caches/alpaca_chat_think_L128_K32_R256/shard_00038.pt
caches/alpaca_chat_think_L128_K32_R256/shard_00030.pt
caches/alpaca_chat_think_L128_K32_R256/shard_00000.pt
caches/alpaca_chat_think_L128_K32_R256/shard_00032.pt
caches/alpaca_chat_think_L128_K32_R256/shard_00010.pt
caches/alpaca_chat_think_L128_K32_R256/shard_00039.pt
caches/alpaca_chat_think_L128_K32_R256/shard_00008.pt
caches/alpaca_chat_think_L128_K32_R256/shard_00005.pt
caches/alpaca_chat_think_L128_K32_R256/shard_00025.pt
caches/alpaca_chat_think_L128_K32_R256/meta.json
caches/alpaca_chat_think_L128_K32_R256/shard_00015.pt
caches/alpaca_chat_think_L128_K32_R256/s

In [None]:
# ============================================================
# SAVE CACHED KD DATA TO GOOGLE DRIVE (run after generating cache)
# ============================================================
# This saves the generated cache to Google Drive for future sessions
# Only run this AFTER you've generated the cache with precompute_teacher_topk.py

from google.colab import drive
drive.mount('/content/drive')

# Create destination directory
!mkdir -p /content/drive/MyDrive/qwen3_caches

# Choose which cache to save (should match what you generated)
CACHE_NAME = "alpaca_chat_think_both_L128_K32_R256"

import os

# Check if cache exists - copy folder directly (no compression needed)
if os.path.isdir(f"caches/{CACHE_NAME}"):
    # Copy folder to Google Drive
    print(f"[save] Copying {CACHE_NAME} to Google Drive...")
    !rsync -ah --info=progress2 caches/{CACHE_NAME}/ /content/drive/MyDrive/qwen3_caches/{CACHE_NAME}/

    # Verify
    num_shards = len([f for f in os.listdir(f"/content/drive/MyDrive/qwen3_caches/{CACHE_NAME}") if f.startswith("shard_")])
    print(f"[save] Saved to Google Drive: {num_shards} shards")
else:
    print(f"[save] ERROR: Cache directory caches/{CACHE_NAME} not found")
    print("[save] Run precompute_teacher_topk.py first to generate the cache")


## 4.5) Google Drive Cache Management

**Workflow for KD Cache:**

1. **First time setup** (slow):
   - Run `precompute_teacher_topk.py` to generate cache
   - Run "SAVE TO GOOGLE DRIVE" cell to persist
   
2. **Subsequent sessions** (fast):
   - Run "LOAD FROM GOOGLE DRIVE" cell to restore cache
   - Skip cache generation step

The cached KD data is several GB. The `alpaca_chat_think_both_L128_K32_R256` folder is about 4.4 GB. It holds precomputed teacher top-k logits plus sampled negatives for knowledge-distillation training.
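After a restore it can help to sanity-check a cache folder before training. In this notebook the cache names encode the generation flags: `L` is `--max_length`, `K` is `--topk`, and `R` is `--rand_neg`. The sketch below is a hypothetical helper, not part of the repo. It assumes that naming convention, shard files named `shard_*.pt`, and that `meta.json` (seen in the tar listing above) is a JSON object. It does not interpret the contents of `meta.json`.

```python
import json
import os
import re

def describe_cache(cache_dir: str) -> dict:
    """Summarise a KD cache folder (hypothetical helper).

    Parses L/K/R from a name like '..._L128_K32_R256', counts shard_*.pt files,
    and lists the top-level keys of meta.json if it exists and is a JSON object.
    """
    name = os.path.basename(os.path.normpath(cache_dir))
    m = re.search(r"_L(\d+)_K(\d+)_R(\d+)$", name)
    info = {
        "name": name,
        "max_length": int(m.group(1)) if m else None,
        "topk": int(m.group(2)) if m else None,
        "rand_neg": int(m.group(3)) if m else None,
        "num_shards": 0,
        "meta_keys": [],
    }
    if os.path.isdir(cache_dir):
        files = os.listdir(cache_dir)
        info["num_shards"] = sum(f.startswith("shard_") and f.endswith(".pt") for f in files)
        meta_path = os.path.join(cache_dir, "meta.json")
        if os.path.isfile(meta_path):
            with open(meta_path) as fh:
                meta = json.load(fh)
            info["meta_keys"] = sorted(meta) if isinstance(meta, dict) else []
    return info
```

For example, `describe_cache(CACHE_DIR_CHAT)` should report `max_length=128`, `topk=32`, `rand_neg=256`. The chat cache was generated with `--num_sequences 20000 --shard_size 512`. Assuming `shard_size` counts sequences per shard, that gives ceil(20000 / 512) = 40 shards, which matches the 40 shards reported when the cache is loaded from Drive.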

In [None]:
# ============================================================
# COMPRESS TEXT CACHE (for Google Drive upload)
# ============================================================
# SKIP if cache is already compressed or loaded from Google Drive

import os

if os.path.isdir(CACHE_DIR_TEXT):
    print(f"[gzip] Compressing {CACHE_DIR_TEXT}...")
    !tar -zcvf {CACHE_DIR_TEXT}.tgz {CACHE_DIR_TEXT}
    compressed_size = os.path.getsize(f"{CACHE_DIR_TEXT}.tgz")
    print(f"[gzip] Done: {compressed_size / (1024**3):.2f} GB")
else:
    print(f"[gzip] Directory {CACHE_DIR_TEXT} not found. Skipping.")

#### (!!) LOAD CACHED KD DATA FROM GOOGLE DRIVE


In [5]:
# ============================================================
# LOAD CACHED KD DATA FROM GOOGLE DRIVE (run this cell first!)
# ============================================================
# Mount Google Drive and copy cached KD data back to local storage
# This avoids regenerating the cache every session

from google.colab import drive
drive.mount('/content/drive')

# Create local cache directory
!mkdir -p caches

# Cache folder to load (copy folder directly, no .tgz)
CACHE_NAME = "alpaca_chat_think_both_L128_K32_R256"

# Copy folder directly from Google Drive
SRC_PATH = f"/content/drive/MyDrive/qwen3_caches/{CACHE_NAME}"
DST_PATH = f"caches/{CACHE_NAME}"

print(f"[cache] Copying {CACHE_NAME} from Google Drive...")
!rsync -ah --info=progress2 {SRC_PATH}/ {DST_PATH}/

# Verify copy
import os
if os.path.isdir(DST_PATH):
    num_shards = len([f for f in os.listdir(DST_PATH) if f.startswith("shard_")])
    print(f"[cache] Successfully loaded {CACHE_NAME} with {num_shards} shards")
else:
    print(f"[cache] ERROR: Failed to copy {CACHE_NAME}")


Mounted at /content/drive
[cache] Copying alpaca_chat_think_both_L128_K32_R256 from Google Drive...
          4.40G 100%   25.73MB/s    0:02:43 (xfr#41, to-chk=0/42)
[cache] Successfully loaded alpaca_chat_think_both_L128_K32_R256 with 40 shards


In [6]:
# ============================================================
# LOAD 4-BIT CHECKPOINT FROM GOOGLE DRIVE (for 2-bit initialization)
# ============================================================
# Copy the 4-bit trained checkpoint to use as starting point for 2-bit training

import os

# Create runs directory
!mkdir -p runs

# 4-bit checkpoint to load (best result from 4-bit training)
CHECKPOINT_NAME = "qwen3_kdqat_cache_q2_4"
SRC_PATH = f"/content/drive/MyDrive/qwen3_runs/{CHECKPOINT_NAME}.tgz"
DST_PATH = f"runs/{CHECKPOINT_NAME}.tgz"

# Copy from Google Drive
print(f"[checkpoint] Copying {CHECKPOINT_NAME}.tgz from Google Drive...")
!rsync -ah --info=progress2 {SRC_PATH} {DST_PATH}

# Check tarball structure first
print(f"[checkpoint] Checking tarball structure...")
!tar -tzf {DST_PATH} | head -5

# Extract to runs/ directory (tarball contains folder without runs/ prefix)
print(f"[checkpoint] Extracting {CHECKPOINT_NAME}.tgz...")
!tar -xzf {DST_PATH} -C runs/

# Verify extraction
if os.path.isdir(f"runs/{CHECKPOINT_NAME}"):
    files = os.listdir(f"runs/{CHECKPOINT_NAME}")
    print(f"[checkpoint] Successfully loaded {CHECKPOINT_NAME} with {len(files)} files:")
    for f in sorted(files)[:5]:
        print(f"  - {f}")
    if len(files) > 5:
        print(f"  ... and {len(files)-5} more")
else:
    # Try to find where it extracted
    print(f"[checkpoint] Checking runs/ directory...")
    !ls -la runs/

[checkpoint] Copying qwen3_kdqat_cache_q2_4.tgz from Google Drive...
              0   0%    0.00kB/s    0:00:00 (xfr#0, to-chk=0/1)
[checkpoint] Checking tarball structure...
qwen3_kdqat_cache_q2_4/
qwen3_kdqat_cache_q2_4/special_tokens_map.json
qwen3_kdqat_cache_q2_4/loss.csv
qwen3_kdqat_cache_q2_4/added_tokens.json
qwen3_kdqat_cache_q2_4/tokenizer_config.json
[checkpoint] Extracting qwen3_kdqat_cache_q2_4.tgz...
[checkpoint] Successfully loaded qwen3_kdqat_cache_q2_4 with 12 files:
  - added_tokens.json
  - chat_template.jinja
  - final_state_dict.pt
  - loss.csv
  - merges.txt
  ... and 7 more


In [None]:
RUN_DIR = "runs/qwen3_kdqat_stream_q2"

# DISABLED (not used): see section 5 for the first QAT step
# Construct the command string in Python to ensure variable interpolation
command_str = f"""python scripts/train_qat.py \
  --model_name_or_path {MODEL_NAME} \
  --teacher_model_name_or_path {MODEL_NAME} \
  --distill_weight 1.0 \
  --distill_temperature 2.0 \
  --dataset_name allenai/c4 \
  --dataset_config_name en \
  --dataset_split train \
  --dataset_format text \
  --dataset_text_field text \
  --streaming \
  --shuffle_buffer 10000 \
  --output_dir {RUN_DIR} \
  --device {DEVICE} \
  --amp_dtype {AMP_DTYPE} \
  --param_dtype {PARAM_DTYPE} \
  -q {QUANT_BITS} \
  --max_length 128 \
  --per_device_train_batch_size 1 \
  --gradient_accumulation_steps 16 \
  --learning_rate 5e-6 \
  --warmup_steps 0 \
  --max_steps 50 \
  --skip_lm_head \
  --ema_decay 0 \
  --logging_steps 10 \
  --save_steps 50"""

# Execute the constructed command string
!{command_str}


### (Optional) Resume

`--resume_from_checkpoint auto` resolves to `checkpoint_last.pt` if it exists in the output directory.
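The rule above is small enough to write out. The sketch below reproduces only the documented `auto` behavior. `resolve_resume` is a hypothetical name. Two parts are assumptions about what the script might do and are not confirmed by the repo: treating any other value as an explicit path, and falling back to a fresh start when `checkpoint_last.pt` is missing.

```python
import os
from typing import Optional

def resolve_resume(resume_from: Optional[str], output_dir: str) -> Optional[str]:
    """Sketch of the documented '--resume_from_checkpoint auto' rule (hypothetical helper).

    'auto' -> <output_dir>/checkpoint_last.pt if it exists, else None (assumed fresh start).
    Any other non-None value is assumed to be an explicit checkpoint path.
    """
    if resume_from is None:
        return None
    if resume_from == "auto":
        candidate = os.path.join(output_dir, "checkpoint_last.pt")
        return candidate if os.path.isfile(candidate) else None
    return resume_from
```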

In [None]:
# !python scripts/train_qat.py ... --output_dir {RUN_DIR} --max_steps 500 --resume_from_checkpoint auto


## 5) KD-cache: precompute teacher top-k + negatives

Cache mode is MPS-friendly:
- no teacher model during training
- no full-vocab logits
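Training can skip full-vocab logits because each cached position keeps only a small candidate set: the teacher's top-k tokens (`--topk`) plus `--rand_neg` randomly sampled other tokens. That is a few hundred ids instead of Qwen3's vocabulary of roughly 150k tokens. The sketch below shows one way to build such a record from a single position's logits. The function name, the uniform sampling of negatives, and the storing of teacher logits for the negatives are assumptions. It is not the code in `precompute_teacher_topk.py`.

```python
import torch

def topk_with_negatives(logits: torch.Tensor, topk: int, rand_neg: int, generator=None):
    """Compress one position's full-vocab teacher logits into a candidate set (sketch).

    logits: 1-D tensor of shape [vocab].
    Returns top-k ids/logits plus `rand_neg` uniformly sampled ids outside the top-k
    and the teacher logits at those ids (whether the repo stores these is an assumption).
    """
    vocab = logits.shape[-1]
    top_vals, top_ids = logits.topk(topk)
    # Exclude the top-k ids so negatives never duplicate them.
    mask = torch.ones(vocab, dtype=torch.bool)
    mask[top_ids] = False
    pool = mask.nonzero(as_tuple=True)[0]
    pick = torch.randperm(pool.numel(), generator=generator)[:rand_neg]
    neg_ids = pool[pick]
    return top_ids, top_vals, neg_ids, logits[neg_ids]
```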

If you see good KD loss but bad greedy decoding, increase negative coverage (`--rand_neg`) and/or add hard top-1 terms:
- `--hard-top1-weight 0.05`
- `--hard-full-top1-weight 0.02`–`0.05`

### KD-cache QAT training

This uses cached teacher signals + candidate softmax.
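A "candidate softmax" here means taking the softmax over the cached candidate ids only, for both teacher and student, and then matching the two distributions. The sketch below is a hedged reading of the flags `--distill_temperature`, `--hard-top1-weight` and `--hard-full-top1-weight`, not the repo's loss. The function name, the T² scaling, and the exact form of both hard top-1 terms are assumptions. Here "hard top-1" is cross-entropy toward the teacher's best candidate inside the candidate set, and "hard full top-1" is the same target scored against the full vocabulary.

```python
import torch
import torch.nn.functional as F

def candidate_kd_loss(student_logits, cand_ids, teacher_cand_logits, temperature=2.0,
                      hard_top1_weight=0.0, hard_full_top1_weight=0.0):
    """KD over cached candidate tokens plus optional hard top-1 terms (sketch).

    student_logits:      [N, V] student logits for N token positions
    cand_ids:            [N, C] cached candidate ids (top-k + negatives)
    teacher_cand_logits: [N, C] cached teacher logits at those ids
    """
    s = student_logits.gather(-1, cand_ids)                 # student logits on candidates
    t = temperature
    teacher_p = F.softmax(teacher_cand_logits / t, dim=-1)
    student_logp = F.log_softmax(s / t, dim=-1)
    # KL(teacher || student) over the candidate set; T^2 keeps gradient scale comparable.
    loss = F.kl_div(student_logp, teacher_p, reduction="batchmean") * t * t
    target_pos = teacher_cand_logits.argmax(dim=-1)         # index of teacher top-1 in candidates
    if hard_top1_weight > 0:
        loss = loss + hard_top1_weight * F.cross_entropy(s, target_pos)
    if hard_full_top1_weight > 0:
        target_ids = cand_ids.gather(-1, target_pos.unsqueeze(-1)).squeeze(-1)
        loss = loss + hard_full_top1_weight * F.cross_entropy(student_logits, target_ids)
    return loss
```

The hard terms push the student's argmax toward the teacher's top-1 token. That is why they help when the soft KD loss looks good but greedy decoding does not.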

In [None]:
# ============================================================
# STAGE 1: KD-QAT (Conservative - freeze MLP/Attention)
# ============================================================
# First training stage with frozen output layers for stability

%pwd
%cd /content/qwen3_apple_style_2bit_qat_lora

CACHE_DIR = CACHE_DIR_CHAT
RUN_DIR_CACHE = "runs/qwen3_kdqat_cache_q2"

!python scripts/train_qat.py \
  --model_name_or_path {MODEL_NAME} \
  --output_dir {RUN_DIR_CACHE} \
  --device {DEVICE} \
  --amp_dtype {AMP_DTYPE} \
  --param_dtype {PARAM_DTYPE} \
  -q {QUANT_BITS} \
  --max_length 128 \
  --per_device_train_batch_size 1 \
  --gradient_accumulation_steps 4 \
  --learning_rate 5e-6 \
  --warmup_steps 0 \
  --max_steps 1000 \
  --save_steps 3000 \
  --logging_steps 5 \
  --skip_lm_head \
  --ema_decay 0 \
  --kd_cache_dir {CACHE_DIR} \
  --kd_cache_shuffle_files \
  --distill_temperature 2.0 \
  --distill_weight 1.0 \
  --hard-top1-weight 0.05 \
  --hard-full-top1-weight 0.03 \
  --ov-freeze \
  --freeze-last-mlp \
  --freeze-last-mlp-layers 1

In [None]:
RUN_DIR = "runs/qwen3_kdqat_cache_q2_2"
!python scripts/run_inference.py \
  --model_name_or_path {MODEL_NAME} \
  --qat_checkpoint {RUN_DIR}/qat_state_dict.pt \
  --device {DEVICE} \
  --dtype {DTYPE} \
  -q {QUANT_BITS} \
  --skip_lm_head \
  --prompt "What is the capital city of France?" \
  --do_sample true \
  --max_new_tokens 64

# Alternative prompts:
#   --prompt "What is the capital of France?" \
#   --prompt "What is the Apple Neural Engine?" \


In [None]:
# Define source and destination paths
SOURCE_FILE = "runs/qwen3_kdqat_cache_q2_2/qat_state_dict.pt"
DEST_DIR_GD = "/content/drive/MyDrive/runs/Q4B/q2_2/"

# Ensure the destination directory exists on Google Drive
!mkdir -p {DEST_DIR_GD}

# Copy the file to Google Drive
!cp -v {SOURCE_FILE} {DEST_DIR_GD}
print(f"Copied {SOURCE_FILE} to {DEST_DIR_GD}")

In [None]:
# ============================================================
# STAGE 2: KD-QAT (Unfrozen layers, resume from Stage 1)
# ============================================================
# Continue from Stage 1 without the last-MLP freeze (--ov-freeze is still set)

%cd /content/qwen3_apple_style_2bit_qat_lora

CACHE_DIR = CACHE_DIR_CHAT
INIT_DIR_CACHE = "runs/qwen3_kdqat_cache_q2"
RUN_DIR_CACHE = "runs/qwen3_kdqat_cache_q2_2"

!python scripts/train_qat.py \
  --model_name_or_path {MODEL_NAME} \
  --init_model_state {INIT_DIR_CACHE}/qat_state_dict.pt \
  --output_dir {RUN_DIR_CACHE} \
  --device {DEVICE} \
  --amp_dtype {AMP_DTYPE} \
  --param_dtype {PARAM_DTYPE} \
  -q {QUANT_BITS} \
  --max_length 128 \
  --per_device_train_batch_size 1 \
  --gradient_accumulation_steps 4 \
  --learning_rate 5e-6 \
  --warmup_steps 0 \
  --max_steps 1000 \
  --save_steps 3000 \
  --logging_steps 10 \
  --skip_lm_head \
  --ema_decay 0 \
  --kd_cache_dir {CACHE_DIR} \
  --kd_cache_shuffle_files \
  --distill_temperature 2.0 \
  --distill_weight 1.0 \
  --ov-freeze \
  --hard-top1-weight 0.02 \
  --hard-full-top1-weight 0.01

### (!!!!) Stage 3: resume KD-QAT with unfrozen attention and relaxed hard top-1 weights

---



In [12]:
# Stage 3 resume KD-QAT with unfrozen attention and relaxed hard-top/full!

CACHE_DIR = CACHE_DIR_CHAT  # Use config variable
#INIT_DIR_CACHE = "runs/qwen3_kdqat_cache_q2_2"
INIT_DIR_CACHE =  "runs/progressive_qat_q2_v1"

RUN_DIR_CACHE = "runs/qwen3_kdqat_cache_q2_3"

!python scripts/train_qat.py \
  --model_name_or_path {MODEL_NAME} \
  --init_model_state {INIT_DIR_CACHE}/qat_state_dict.pt \
  --output_dir {RUN_DIR_CACHE} \
  --device {DEVICE} \
  --amp_dtype {AMP_DTYPE} \
  --param_dtype {PARAM_DTYPE} \
  -q {QUANT_BITS} \
  --max_length 128 \
  --per_device_train_batch_size 128 \
  --gradient_accumulation_steps 1 \
  --learning_rate 5e-6 \
  --warmup_steps 0 \
  --max_steps 3000 \
  --save_steps 3000 \
  --logging_steps 10 \
  --skip_lm_head \
  --ema_decay 0 \
  --kd_cache_dir {CACHE_DIR} \
  --kd_cache_shuffle_files \
  --distill_temperature 2.0 \
  --distill_weight 1.0 \
  --hard-top1-weight 0.00 \
  --hard-full-top1-weight 0.0005


2025-12-24 23:29:24.155008: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-12-24 23:29:24.171198: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1766618964.192763   82536 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1766618964.199201   82536 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1766618964.215547   82536 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 
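Stage 3 also changes the batch shape. Stages 1 and 2 use 1 × 4 = 4 sequences per optimizer step. Stage 3 uses 128 × 1 = 128, which is 32× more sequences per step at the same 5e-6 learning rate. The arithmetic below uses the batch, accumulation and step values from the Stage 1 to Stage 4 cells. It assumes a single device, a cache of 20,000 sequences (`--num_sequences 20000`), and that training cycles through the whole cache. The token count is an upper bound because shorter sequences do not fill `--max_length 128`.

```python
def step_stats(batch, grad_accum, max_steps, max_length=128, cache_sequences=20_000):
    """Per-step workload and approximate passes over the KD cache (single device assumed)."""
    seqs_per_step = batch * grad_accum
    return {
        "seqs_per_step": seqs_per_step,
        "max_tokens_per_step": seqs_per_step * max_length,   # upper bound (padding included)
        "passes_over_cache": max_steps * seqs_per_step / cache_sequences,
    }

# (per_device_train_batch_size, gradient_accumulation_steps, max_steps) per stage
STAGES = {
    "Stage 1": (1, 4, 1000),
    "Stage 2": (1, 4, 1000),
    "Stage 3": (128, 1, 3000),
    "Stage 4": (160, 1, 500),
}
for name, (bs, ga, steps) in STAGES.items():
    print(name, step_stats(bs, ga, steps))
```

Under these assumptions Stages 1 and 2 each see only 0.2 passes over the cache. Stage 3 makes about 19.2 passes and Stage 4 makes 4.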

#### Save run


In [13]:
# ============================================================
# SAVE STAGE 3 KD-QAT CHECKPOINT TO GOOGLE DRIVE
# ============================================================

%cd /content/qwen3_apple_style_2bit_qat_lora

from google.colab import drive
import os

# Mount Google Drive if not already mounted
drive.mount('/content/drive')

# Source directory (Stage 3 KD-QAT output, RUN_DIR_CACHE in the Stage 3 cell)
RUN_NAME = "qwen3_kdqat_cache_q2_3"
RUN_DIR = f"runs/{RUN_NAME}"

# Destination on Google Drive
DEST_DIR_GD = "/content/drive/MyDrive/qwen3_runs/"
!mkdir -p {DEST_DIR_GD}

# Check if run directory exists and has content
if os.path.isdir(RUN_DIR) and os.listdir(RUN_DIR):
    # Compress the run directory
    print(f"[archive] Compressing {RUN_DIR}...")
    !tar -zcvf {RUN_NAME}.tgz -C runs {RUN_NAME}

    # Copy to Google Drive
    print(f"[save] Copying {RUN_NAME}.tgz to Google Drive...")
    !rsync -ah --info=progress2 {RUN_NAME}.tgz {DEST_DIR_GD}

    # Verify
    gd_size = os.path.getsize(f"{DEST_DIR_GD}/{RUN_NAME}.tgz")
    print(f"[save] Saved to Google Drive: {gd_size / (1024**3):.2f} GB")
else:
    print(f"[save] ERROR: {RUN_DIR} is empty or doesn't exist")

# Go to "Stage 3 resume" to continue distillation training

/content/qwen3_apple_style_2bit_qat_lora
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
[archive] Compressing runs/qwen3_kdqat_cache_q2_3...
qwen3_kdqat_cache_q2_3/
qwen3_kdqat_cache_q2_3/special_tokens_map.json
qwen3_kdqat_cache_q2_3/loss.csv
qwen3_kdqat_cache_q2_3/added_tokens.json
qwen3_kdqat_cache_q2_3/tokenizer_config.json
qwen3_kdqat_cache_q2_3/run_state.json
qwen3_kdqat_cache_q2_3/merges.txt
qwen3_kdqat_cache_q2_3/training_args.json
qwen3_kdqat_cache_q2_3/chat_template.jinja
qwen3_kdqat_cache_q2_3/vocab.json
qwen3_kdqat_cache_q2_3/final_state_dict.pt
qwen3_kdqat_cache_q2_3/qat_state_dict.pt
qwen3_kdqat_cache_q2_3/checkpoint_step3000.pt
qwen3_kdqat_cache_q2_3/checkpoint_last.pt
qwen3_kdqat_cache_q2_3/tokenizer.json
[save] Copying qwen3_kdqat_cache_q2_3.tgz to Google Drive...
          7.46G 100%  332.14MB/s    0:00:21 (xfr#1, to-chk=0/1)
[save] Saved to Google Drive: 6.95 GB


#### Pull and Extract Stage 3 KD-QAT Checkpoint

In [6]:
import os
from google.colab import drive

drive.mount('/content/drive')

# Define the checkpoint name to pull
RUN_NAME = "qwen3_kdqat_cache_q2_3"

# Source path on Google Drive
SRC_PATH_GD = f"/content/drive/MyDrive/qwen3_runs/{RUN_NAME}.tgz"
# Destination path locally
DST_PATH_LOCAL = f"{RUN_NAME}.tgz"

# Create runs directory if it doesn't exist
!mkdir -p runs

print(f"[pull] Copying {RUN_NAME}.tgz from Google Drive...")
!rsync -ah --info=progress2 {SRC_PATH_GD} {DST_PATH_LOCAL}

# Check if the tarball was copied successfully
if os.path.exists(DST_PATH_LOCAL):
    print(f"[pull] Extracting {RUN_NAME}.tgz...")
    !tar -xzf {DST_PATH_LOCAL} -C runs/
    print(f"[pull] Successfully extracted to runs/{RUN_NAME}")

    # Optionally, remove the tarball after extraction to save space
    # !rm {DST_PATH_LOCAL}
else:
    print(f"[pull] ERROR: {RUN_NAME}.tgz not found on Google Drive. Make sure it was saved correctly.")

# For inference checks, point TEST_RUN in the inference cell at runs/{RUN_NAME},
# the directory where the QAT checkpoint was extracted.

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
[pull] Copying qwen3_kdqat_cache_q2_3.tgz from Google Drive...
          7.46G 100%  136.13MB/s    0:00:52 (xfr#1, to-chk=0/1)
[pull] Extracting qwen3_kdqat_cache_q2_3.tgz...
[pull] Successfully extracted to runs/qwen3_kdqat_cache_q2_3


In [None]:
# Stage 4: resume KD-QAT from Stage 3 (attention still unfrozen).
# Changes vs. Stage 3: --hard-full-top1-weight 0.0, --learning_rate 2e-6,
#   --per_device_train_batch_size 160, --max_steps 500
CACHE_DIR = CACHE_DIR_CHAT  # Use config variable
INIT_DIR_CACHE = "runs/qwen3_kdqat_cache_q2_3"
RUN_DIR_CACHE = "runs/qwen3_kdqat_cache_q2_4"

!python scripts/train_qat.py \
  --model_name_or_path {MODEL_NAME} \
  --init_model_state {INIT_DIR_CACHE}/qat_state_dict.pt \
  --output_dir {RUN_DIR_CACHE} \
  --device {DEVICE} \
  --amp_dtype {AMP_DTYPE} \
  --param_dtype {PARAM_DTYPE} \
  -q {QUANT_BITS} \
  --max_length 128 \
  --per_device_train_batch_size 160 \
  --gradient_accumulation_steps 1 \
  --learning_rate 2e-6 \
  --warmup_steps 0 \
  --max_steps 500 \
  --save_steps 3000 \
  --logging_steps 5 \
  --skip_lm_head \
  --ema_decay 0 \
  --kd_cache_dir {CACHE_DIR} \
  --kd_cache_shuffle_files \
  --distill_temperature 2.0 \
  --distill_weight 1.0 \
  --hard-top1-weight 0.00 \
  --hard-full-top1-weight 0.0000


## 5.5) Progressive Layer-by-Layer QAT (Experimental)

This approach trains one layer at a time with:
- **Local reconstruction loss**: MSE between quantized and fp MLP outputs
- **Global KD loss**: Cached teacher logits
- **Prefix quantized / suffix fp**: Earlier layers stay quantized, later layers use full precision
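For one layer, the two losses above combine into a weighted sum. The local term is an MSE between the quantized and full-precision MLP outputs on a random subset of token positions. The global term is the cached-teacher KD loss. The sketch below illustrates that combination. The function name and the subsampling scheme are hypothetical. The defaults mirror `LOCAL_WEIGHT`, `GLOBAL_WEIGHT` and `LOCAL_TOKEN_SAMPLES` from this section's config cell, and the prefix-quantized/suffix-fp model wiring is not shown.

```python
import torch
import torch.nn.functional as F

def progressive_layer_loss(q_mlp_out, fp_mlp_out, global_kd,
                           local_weight=0.3, global_weight=1.0, local_token_samples=128):
    """Weighted local reconstruction + global KD for one layer (sketch).

    q_mlp_out, fp_mlp_out: [tokens, hidden] outputs of the quantized and full-precision
    versions of the same MLP on the same inputs. global_kd: scalar KD loss.
    Returns (total_loss, local_loss).
    """
    n = q_mlp_out.shape[0]
    # Sample a subset of token positions to keep the local loss cheap.
    idx = torch.randperm(n, device=q_mlp_out.device)[:min(local_token_samples, n)]
    # The full-precision output is a fixed target, so no gradient flows into it.
    local = F.mse_loss(q_mlp_out[idx], fp_mlp_out[idx].detach())
    return local_weight * local + global_weight * global_kd, local
```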

### Recommended Training Order (most stable first):

1. **E2E f-only** (Option 1): Train ALL f parameters at once
   - Skip progressive passes, just run Pass 4
   - Most stable, fastest validation
   
2. **Progressive f-only** (Option 2): Layer-by-layer f-param training
   - Uses `--train_f_only` flag
   - Disable local loss with `--local_weight 0.0`
   
3. **Full progressive** (Option 3): Train weights + f per layer
   - Most aggressive, may show instability at later layers

### GPU Configuration:

| GPU | Recommended batch_size |
|-----|------------------------|
| T4 (15GB) | 2-4 |
| V100 (32GB) | 4-8 |
| A100 (40GB) | 8-16 |
| A100 (80GB) / H100 | 16-32 |

In [8]:
# ---- Progressive QAT Config (2-bit from 4-bit checkpoint) ----
# Starting from 4-bit trained checkpoint for better 2-bit initialization

# 4-bit checkpoint as initialization (loaded from Google Drive)
INIT_CHECKPOINT = "runs/qwen3_kdqat_cache_q2_4/qat_state_dict.pt"

# Adjust batch_size for your GPU (A100: 8-16, V100: 4-8, T4: 2-4)
BATCH_SIZE = 96                # Increase for faster instances (A100/H100)
STEPS_PER_LAYER_MLP = 100      # Steps per MLP layer (Pass 1 + Pass 3)
STEPS_PER_LAYER_ATTN = 30      # Steps per attention layer (Pass 2)
E2E_STEPS = 500                # E2E quantizer tuning steps (Pass 4)
LOCAL_WEIGHT = 0.3             # Local reconstruction loss weight
GLOBAL_WEIGHT = 1.0            # Global KD loss weight
LOCAL_TOKEN_SAMPLES = 128      # Tokens to sample for local loss
MAX_GRAD_NORM = 1.0            # Gradient clipping (important for 2-bit)

# Learning rates (lower for 2-bit stability)
LR_PROGRESSIVE = 2e-6          # Learning rate for progressive passes
LR_E2E = 5e-5                  # Learning rate for E2E f-only tuning

# Output directories (2-bit versions)
RUN_DIR_E2E_FONLY = "runs/e2e_f_only_q2"
RUN_DIR_PROGRESSIVE_FONLY = "runs/progressive_f_only_q2"
RUN_DIR_PROGRESSIVE = "runs/progressive_qat_q2_v1"

### Option 1: E2E f-only Training (Recommended First)

**Most stable approach** - trains ALL `_f_param` (quantization scales) simultaneously.
Skip all progressive layer-by-layer passes and go straight to Pass 4.

This is recommended when:
- Progressive layer-by-layer shows instability (local loss hitting 10.0)
- You want to validate the infrastructure works before trying progressive
- You have limited time and want the fastest path to a working checkpoint

The `f` parameter is the learnable quantization scale from Apple-style quantization:
- Actual scale `s = softplus(f)` ensures positivity
- Training only `f` keeps the weights frozen, which is more stable at ultra-low bit widths
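The sketch below shows how `s = softplus(f)` and f-only training fit together. The layer fake-quantizes its weight on a grid scaled by a strictly positive per-row `s`. A helper then freezes every parameter except the scale parameters, so the optimizer only moves the quantization grid. `SoftplusScaleQuantLinear`, `set_f_only`, the per-row max-abs initialisation, and the STE placement are hypothetical. The `_f_param` attribute name only mirrors the naming used above and is not the repo's class.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftplusScaleQuantLinear(nn.Module):
    """Linear layer fake-quantized with a learnable per-row scale s = softplus(f) (sketch)."""

    def __init__(self, in_features: int, out_features: int, bits: int = 2):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        self.qmax = (2 ** bits - 1) / 2
        init = self.weight.detach().abs().amax(1, keepdim=True).clamp_min(1e-6) / self.qmax
        # Inverse softplus, so that softplus(f) equals the initial scale.
        self._f_param = nn.Parameter(torch.log(torch.expm1(init)))

    def forward(self, x):
        s = F.softplus(self._f_param)        # strictly positive, whatever value f takes
        u = self.weight / s                  # weights in units of the scale
        q = torch.clamp(torch.round(u - 0.5) + 0.5, -self.qmax, self.qmax)
        u_q = u + (q - u).detach()           # STE: rounding treated as identity in backward
        return F.linear(x, u_q * s)

def set_f_only(model: nn.Module, suffix: str = "_f_param") -> None:
    """Freeze every parameter except those whose name ends with `suffix`."""
    for name, p in model.named_parameters():
        p.requires_grad_(name.endswith(suffix))
```

After `set_f_only(model)`, a backward pass fills gradients only for the scale parameters. The weights stay at their initial (for example, 4-bit-trained) values, and only the grid they are rounded onto moves.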

In [9]:
# E2E f-only: Skip ALL progressive passes, train all f parameters at once
# This is the simplest and most stable approach
# Starting from 4-bit checkpoint for 2-bit training

%cd /content/qwen3_apple_style_2bit_qat_lora

!python scripts/train_qat_progressive.py \
  --model_name_or_path {MODEL_NAME} \
  --init_model_state {INIT_CHECKPOINT} \
  --output_dir {RUN_DIR_E2E_FONLY} \
  --device {DEVICE} \
  --amp_dtype {AMP_DTYPE} \
  --param_dtype {PARAM_DTYPE} \
  -q {QUANT_BITS} \
  --kd_cache_dir {CACHE_DIR_CHAT} \
  --batch_size {BATCH_SIZE} \
  --skip_mlp_pass \
  --skip_attention_pass \
  --skip_mlp_refinement \
  --e2e_steps {E2E_STEPS} \
  --e2e_learning_rate {LR_E2E} \
  --max_grad_norm {MAX_GRAD_NORM} \
  --logging_steps 10 \
  --skip_lm_head

/content/qwen3_apple_style_2bit_qat_lora
[device] cuda | amp_dtype=torch.bfloat16 | param_dtype=torch.bfloat16
tokenizer_config.json: 9.73kB [00:00, 38.5MB/s]
vocab.json: 2.78MB [00:00, 52.2MB/s]
merges.txt: 1.67MB [00:00, 139MB/s]
tokenizer.json: 100% 11.4M/11.4M [00:01<00:00, 10.2MB/s]
[model] Loading Qwen/Qwen3-0.6B
config.json: 100% 726/726 [00:00<00:00, 8.23MB/s]
2025-12-24 18:47:22.719539: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-12-24 18:47:22.740144: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1766602042.765073    9837 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for 

### Option 2: Progressive f-only Training

Layer-by-layer training but only trains `_f_param` (quantization scales), not weights.
More stable than full progressive training, but may still see instability at later layers.

Use `--train_f_only` flag to freeze weights and only train quantization scales per layer.

In [None]:
# Progressive f-only: Layer-by-layer, but only train quantization scales
# Use --train_f_only for more stable training
# Starting from 4-bit checkpoint for 2-bit training

%cd /content/qwen3_apple_style_2bit_qat_lora

!python scripts/train_qat_progressive.py \
  --model_name_or_path {MODEL_NAME} \
  --init_model_state {INIT_CHECKPOINT} \
  --output_dir {RUN_DIR_PROGRESSIVE_FONLY} \
  --device {DEVICE} \
  --amp_dtype {AMP_DTYPE} \
  --param_dtype {PARAM_DTYPE} \
  -q {QUANT_BITS} \
  --kd_cache_dir {CACHE_DIR_CHAT} \
  --batch_size {BATCH_SIZE} \
  --train_f_only \
  --steps_per_layer_mlp {STEPS_PER_LAYER_MLP} \
  --steps_per_layer_attn {STEPS_PER_LAYER_ATTN} \
  --e2e_steps {E2E_STEPS} \
  --local_weight 0.0 \
  --global_weight {GLOBAL_WEIGHT} \
  --max_grad_norm {MAX_GRAD_NORM} \
  --learning_rate {LR_PROGRESSIVE} \
  --e2e_learning_rate {LR_E2E} \
  --logging_steps 10 \
  --skip_lm_head \
  --skip_mlp_refinement

### (!!!) Option 3: Full Progressive Training (weights + f)

Full layer-by-layer training with weights and quantization scales.
Most aggressive but potentially unstable for ultra-low-bit (2-bit).

**Training Order (v3: three layer-wise passes plus E2E tuning):**
1. **Pass 1**: Train MLP layers (local reconstruction + global KD)
2. **Pass 2**: Train attention layers (global KD only)
3. **Pass 3**: MLP refinement (addresses MLP-attention coupling)
4. **Pass 4**: E2E quantizer-only tuning (f-param only)
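The three options in this section mostly differ in which passes they skip. Option 1 passes `--skip_mlp_pass --skip_attention_pass --skip_mlp_refinement`. Option 2 passes `--skip_mlp_refinement`. The Option 3 cell passes `--skip_attention_pass --skip_mlp_refinement`. The sketch below maps those flags to the pass list. It is a hypothetical helper, and it assumes Pass 4 runs whenever `--e2e_steps` is positive.

```python
def planned_passes(skip_mlp_pass=False, skip_attention_pass=False,
                   skip_mlp_refinement=False, e2e_steps=0):
    """List the passes a run would execute, given the skip flags (sketch)."""
    passes = []
    if not skip_mlp_pass:
        passes.append("Pass 1: MLP (local + global KD)")
    if not skip_attention_pass:
        passes.append("Pass 2: attention (global KD)")
    if not skip_mlp_refinement:
        passes.append("Pass 3: MLP refinement")
    if e2e_steps > 0:  # assumption: Pass 4 is enabled by a positive step count
        passes.append("Pass 4: E2E f-only")
    return passes
```

With `E2E_STEPS = 500`, Option 1 yields only Pass 4, and the Option 3 cell below yields Pass 1 followed by Pass 4.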

In [10]:
# Full Progressive: MLP pass + E2E f-only (skip attention/refinement for v1)
# For full 3-pass training, remove --skip_attention_pass and --skip_mlp_refinement
# Starting from 4-bit checkpoint for 2-bit training

%cd /content/qwen3_apple_style_2bit_qat_lora

!python scripts/train_qat_progressive.py \
  --model_name_or_path {MODEL_NAME} \
  --init_model_state {INIT_CHECKPOINT} \
  --output_dir {RUN_DIR_PROGRESSIVE} \
  --device {DEVICE} \
  --amp_dtype {AMP_DTYPE} \
  --param_dtype {PARAM_DTYPE} \
  -q {QUANT_BITS} \
  --kd_cache_dir {CACHE_DIR_CHAT} \
  --batch_size {BATCH_SIZE} \
  --steps_per_layer_mlp {STEPS_PER_LAYER_MLP} \
  --e2e_steps {E2E_STEPS} \
  --local_weight {LOCAL_WEIGHT} \
  --global_weight {GLOBAL_WEIGHT} \
  --local_token_samples {LOCAL_TOKEN_SAMPLES} \
  --max_grad_norm {MAX_GRAD_NORM} \
  --learning_rate {LR_PROGRESSIVE} \
  --e2e_learning_rate {LR_E2E} \
  --logging_steps 10 \
  --skip_lm_head \
  --skip_attention_pass \
  --max_layer_repeats 20 \
  --max_backtrack 5 \
  --layer_converge_threshold 0.4 \
  --skip_mlp_refinement

Streaming output truncated to the last 5000 lines.
  step 20: local=0.4505 global=1.4807
  step 30: local=0.4532 global=1.3654
  step 40: local=0.4528 global=1.4021
  step 50: local=0.4493 global=1.4020
  step 60: local=0.4494 global=1.4319
  step 70: local=0.4599 global=1.4422
  step 80: local=0.4542 global=1.4362
  step 90: local=0.4558 global=1.3706
  Layer 10 not converged (global=1.4458 > 0.4), repeating...

--- Layer 10/27 MLP (repeat 9/20) ---
  Trainable params: 9,437,187
  step 0: local=0.4655 global=1.3227
  step 10: local=0.4546 global=1.3391
  step 20: local=0.4520 global=1.3870
  step 30: local=0.4616 global=1.4243
  step 40: local=0.4457 global=1.3548
  step 50: local=0.4567 global=1.4170
  step 60: local=0.4602 global=1.4128
  step 70: local=0.4500 global=1.4022
  step 80: local=0.4606 global=1.3591
  step 90: local=0.4484 global=1.4510
  Layer 10 not converged (global=1.4140 > 0.4), repeating...

--- Layer 10/27 MLP (repeat 10/20) ---
  Trainable params: 9
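The log shows the per-layer retry logic. After each repeat, the global KD loss of layer 10 (about 1.4) is compared against `--layer_converge_threshold 0.4`. The layer is retrained until it passes or until `--max_layer_repeats 20` is exhausted.

The sketch below shows that control flow. It assumes that the logged `global=` value is the mean over the repeat's steps. It does not model `--max_backtrack`, because its semantics are not shown here. With the loss stuck near 1.4, a threshold of 0.4 is likely never met, so each layer would use all 20 repeats. A threshold near the observed plateau avoids that.

```python
from statistics import mean
from typing import Callable, List, Tuple


def train_layer_until_converged(
    run_repeat: Callable[[int], List[float]],
    threshold: float = 0.4,
    max_repeats: int = 20,
) -> Tuple[bool, int, float]:
    """Hypothetical retry loop for one layer.

    run_repeat(r) stands in for one repeat of per-layer training and
    returns the global KD losses logged during that repeat.
    Returns (converged, repeats_used, last_mean_loss).
    """
    avg = float("inf")
    for repeat in range(1, max_repeats + 1):
        avg = mean(run_repeat(repeat))  # assumed: logged global= is a mean
        if avg <= threshold:
            return True, repeat, avg
        print(f"  not converged (global={avg:.4f} > {threshold}), repeating...")
    return False, max_repeats, avg


def plateau(repeat: int) -> List[float]:
    # Roughly what layer 10 shows: global KD stuck near 1.4.
    return [1.41, 1.39, 1.42]


print(train_layer_until_converged(plateau, max_repeats=3))
```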

#### SAVE RUN

In [11]:
# ============================================================
# SAVE PROGRESSIVE QAT CHECKPOINT TO GOOGLE DRIVE
# ============================================================

%cd /content/qwen3_apple_style_2bit_qat_lora

from google.colab import drive
import os

# Mount Google Drive if not already mounted
drive.mount('/content/drive')

# Source directory (matches RUN_DIR_PROGRESSIVE from config)
RUN_NAME = "progressive_qat_q2_v1"
RUN_DIR = f"runs/{RUN_NAME}"

# Destination on Google Drive
DEST_DIR_GD = "/content/drive/MyDrive/qwen3_runs/"
!mkdir -p {DEST_DIR_GD}

# Check if run directory exists and has content
if os.path.isdir(RUN_DIR) and os.listdir(RUN_DIR):
    # Compress the run directory
    print(f"[archive] Compressing {RUN_DIR}...")
    !tar -zcvf {RUN_NAME}.tgz -C runs {RUN_NAME}

    # Copy to Google Drive
    print(f"[save] Copying {RUN_NAME}.tgz to Google Drive...")
    !rsync -ah --info=progress2 {RUN_NAME}.tgz {DEST_DIR_GD}

    # Verify
    gd_size = os.path.getsize(f"{DEST_DIR_GD}/{RUN_NAME}.tgz")
    print(f"[save] Saved to Google Drive: {gd_size / (1024**3):.2f} GB")
else:
    print(f"[save] ERROR: {RUN_DIR} is empty or doesn't exist")

# Next: go to "Stage 3 resume" to continue distillation training


/content/qwen3_apple_style_2bit_qat_lora
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
[archive] Compressing runs/progressive_qat_q2_v1...
progressive_qat_q2_v1/
progressive_qat_q2_v1/loss_per_layer.csv
progressive_qat_q2_v1/training_args.json
progressive_qat_q2_v1/qat_state_dict.pt
[save] Copying progressive_qat_q2_v1.tgz to Google Drive...
        945.50M 100%  442.21MB/s    0:00:02 (xfr#1, to-chk=0/1)
[save] Saved to Google Drive: 0.88 GB
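If shell magics are unavailable (for example, when the notebook runs as a plain Python script), the standard library can do the same archive-and-copy step. The sketch below uses a hypothetical helper, `archive_run`. The destination is whatever `DEST_DIR_GD` points to.

```python
import os
import shutil
import tarfile


def archive_run(run_dir: str, dest_dir: str, workdir: str = ".") -> str:
    """Pack run_dir into <name>.tgz under workdir, then copy it to dest_dir.

    Mirrors `tar -zcf <name>.tgz -C runs <name>` followed by the copy step.
    Returns the path of the copy in dest_dir.
    """
    run_dir = os.path.normpath(run_dir)
    if not (os.path.isdir(run_dir) and os.listdir(run_dir)):
        raise FileNotFoundError(f"{run_dir} is empty or doesn't exist")
    name = os.path.basename(run_dir)
    local_tgz = os.path.join(workdir, f"{name}.tgz")
    with tarfile.open(local_tgz, "w:gz") as tar:
        tar.add(run_dir, arcname=name)  # archive paths start at <name>/
    os.makedirs(dest_dir, exist_ok=True)
    dest = shutil.copy2(local_tgz, os.path.join(dest_dir, f"{name}.tgz"))
    size_gb = os.path.getsize(dest) / 1024**3
    print(f"[save] Saved to {dest}: {size_gb:.2f} GB")
    return dest
```

In this notebook it would be called as `archive_run(RUN_DIR, DEST_DIR_GD)` after Drive is mounted.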


In [None]:
# v3: Full 3-pass progressive training
# MLP -> Attention -> MLP refinement -> E2E f-only
# WARNING: May show instability at later layers for 2-bit
# Starting from 4-bit checkpoint for 2-bit training

RUN_DIR_PROGRESSIVE_V3 = "runs/progressive_qat_q2_v3"

!python scripts/train_qat_progressive.py \
  --model_name_or_path {MODEL_NAME} \
  --init_model_state {INIT_CHECKPOINT} \
  --output_dir {RUN_DIR_PROGRESSIVE_V3} \
  --device {DEVICE} \
  --amp_dtype {AMP_DTYPE} \
  --param_dtype {PARAM_DTYPE} \
  -q {QUANT_BITS} \
  --kd_cache_dir {CACHE_DIR_CHAT} \
  --batch_size {BATCH_SIZE} \
  --steps_per_layer_mlp {STEPS_PER_LAYER_MLP} \
  --steps_per_layer_attn {STEPS_PER_LAYER_ATTN} \
  --e2e_steps {E2E_STEPS} \
  --local_weight {LOCAL_WEIGHT} \
  --global_weight {GLOBAL_WEIGHT} \
  --local_token_samples {LOCAL_TOKEN_SAMPLES} \
  --max_grad_norm {MAX_GRAD_NORM} \
  --learning_rate {LR_PROGRESSIVE} \
  --e2e_learning_rate {LR_E2E} \
  --logging_steps 10 \
  --skip_lm_head

In [None]:
# ============================================================
# SAVE PROGRESSIVE QAT CHECKPOINT TO GOOGLE DRIVE
# ============================================================

from google.colab import drive
import os


# Mount Google Drive if not already mounted
drive.mount('/content/drive')

# Source directory
#RUN_NAME = "progressive_qat_v1"
RUN_NAME = "progressive_qat_q2_v3"

RUN_DIR = f"runs/{RUN_NAME}"

# Destination on Google Drive
DEST_DIR_GD = "/content/drive/MyDrive/qwen3_runs/"
!mkdir -p {DEST_DIR_GD}

# Check if run directory exists and has content
if os.path.isdir(RUN_DIR) and os.listdir(RUN_DIR):
    # Compress the run directory
    print(f"[archive] Compressing {RUN_DIR}...")
    !tar -zcvf {RUN_NAME}.tgz -C runs {RUN_NAME}

    # Copy to Google Drive
    print(f"[save] Copying {RUN_NAME}.tgz to Google Drive...")
    !rsync -ah --info=progress2 {RUN_NAME}.tgz {DEST_DIR_GD}

    # Verify
    gd_size = os.path.getsize(f"{DEST_DIR_GD}/{RUN_NAME}.tgz")
    print(f"[save] Saved to Google Drive: {gd_size / (1024**3):.2f} GB")

    # Cleanup local archive (optional)
    # !rm {RUN_NAME}.tgz
else:
    print(f"[save] ERROR: {RUN_DIR} is empty or doesn't exist")
    print("[save] Run progressive training first")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
[archive] Compressing runs/qwen3_kdqat_cache_q2_3...
qwen3_kdqat_cache_q2_3/
qwen3_kdqat_cache_q2_3/special_tokens_map.json
qwen3_kdqat_cache_q2_3/loss.csv
qwen3_kdqat_cache_q2_3/added_tokens.json
qwen3_kdqat_cache_q2_3/tokenizer_config.json
qwen3_kdqat_cache_q2_3/run_state.json
qwen3_kdqat_cache_q2_3/merges.txt
qwen3_kdqat_cache_q2_3/training_args.json
qwen3_kdqat_cache_q2_3/chat_template.jinja
qwen3_kdqat_cache_q2_3/vocab.json
qwen3_kdqat_cache_q2_3/final_state_dict.pt
qwen3_kdqat_cache_q2_3/qat_state_dict.pt
qwen3_kdqat_cache_q2_3/tokenizer.json
[save] Copying qwen3_kdqat_cache_q2_3.tgz to Google Drive...
          1.90G 100%  430.50MB/s    0:00:04 (xfr#1, to-chk=0/1)
[save] Saved to Google Drive: 1.77 GB


In [None]:
# v3 from the base model (no --init_model_state): full 3-pass progressive training
# MLP -> Attention -> MLP refinement -> E2E f-only
# WARNING: May show instability at later layers for 2-bit

RUN_DIR_PROGRESSIVE_V3 = "runs/progressive_qat_v3"

!python scripts/train_qat_progressive.py \
  --model_name_or_path {MODEL_NAME} \
  --output_dir {RUN_DIR_PROGRESSIVE_V3} \
  --device {DEVICE} \
  --amp_dtype {AMP_DTYPE} \
  --param_dtype {PARAM_DTYPE} \
  -q {QUANT_BITS} \
  --kd_cache_dir {CACHE_DIR_CHAT} \
  --batch_size {BATCH_SIZE} \
  --steps_per_layer_mlp {STEPS_PER_LAYER_MLP} \
  --steps_per_layer_attn {STEPS_PER_LAYER_ATTN} \
  --e2e_steps {E2E_STEPS} \
  --local_weight {LOCAL_WEIGHT} \
  --global_weight {GLOBAL_WEIGHT} \
  --local_token_samples {LOCAL_TOKEN_SAMPLES} \
  --max_grad_norm {MAX_GRAD_NORM} \
  --learning_rate {LR_PROGRESSIVE} \
  --e2e_learning_rate {LR_E2E} \
  --logging_steps 10 \
  --skip_lm_head

In [None]:
# Plot per-layer training progress
# Change PLOT_RUN to visualize different runs
import pandas as pd
import matplotlib.pyplot as plt
import os

# Choose which run to visualize
PLOT_RUN = RUN_DIR_E2E_FONLY  # or RUN_DIR_PROGRESSIVE, RUN_DIR_PROGRESSIVE_V3

csv_path = f"{PLOT_RUN}/loss_per_layer.csv"
if not os.path.exists(csv_path):
    print(f"Loss CSV not found at {csv_path}")
    print("Run training first or check the path.")
else:
    df = pd.read_csv(csv_path)

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Pass 1: MLP training (local loss)
    mlp_df = df[(df['pass'] == 1) & (df['component'] == 'mlp')]
    if not mlp_df.empty and 'local' in mlp_df.columns:
        for layer in mlp_df['layer'].unique():
            layer_df = mlp_df[mlp_df['layer'] == layer]
            axes[0, 0].plot(layer_df['step'], layer_df['local'], label=f'L{layer}', alpha=0.7)
        axes[0, 0].set_title('Pass 1: MLP Local Loss per Layer')
        axes[0, 0].set_xlabel('Step')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].legend(ncol=4, fontsize=6)
    else:
        axes[0, 0].set_title('Pass 1: MLP Local Loss (skipped or no local loss)')

    # Pass 1: MLP global loss
    if not mlp_df.empty and 'global' in mlp_df.columns:
        for layer in mlp_df['layer'].unique():
            layer_df = mlp_df[mlp_df['layer'] == layer]
            axes[0, 1].plot(layer_df['step'], layer_df['global'], label=f'L{layer}', alpha=0.7)
        axes[0, 1].set_title('Pass 1: MLP Global KD Loss per Layer')
        axes[0, 1].set_xlabel('Step')
        axes[0, 1].set_ylabel('Loss')
        axes[0, 1].legend(ncol=4, fontsize=6)
    else:
        axes[0, 1].set_title('Pass 1: MLP Global Loss (skipped)')

    # Pass 2: Attention training
    attn_df = df[(df['pass'] == 2) & (df['component'] == 'attn')]
    if not attn_df.empty and 'global' in attn_df.columns:
        for layer in attn_df['layer'].unique():
            layer_df = attn_df[attn_df['layer'] == layer]
            axes[1, 0].plot(layer_df['step'], layer_df['global'], label=f'L{layer}', alpha=0.7)
        axes[1, 0].set_title('Pass 2: Attention Global KD Loss per Layer')
        axes[1, 0].set_xlabel('Step')
        axes[1, 0].set_ylabel('Loss')
        axes[1, 0].legend(ncol=4, fontsize=6)
    else:
        axes[1, 0].set_title('Pass 2: Attention (skipped)')
        axes[1, 0].text(0.5, 0.5, 'Not run', ha='center', va='center', transform=axes[1, 0].transAxes)

    # Pass 4: E2E f-only tuning
    e2e_df = df[(df['pass'] == 4)]
    if not e2e_df.empty and 'global' in e2e_df.columns:
        axes[1, 1].plot(e2e_df['step'], e2e_df['global'], 'b-', linewidth=2)
        axes[1, 1].set_title('Pass 4: E2E f-only Tuning')
        axes[1, 1].set_xlabel('Step')
        axes[1, 1].set_ylabel('Global KD Loss')
    else:
        axes[1, 1].set_title('Pass 4: E2E (not yet run)')

    plt.tight_layout()
    plt.savefig(f"{PLOT_RUN}/loss_per_layer.png", dpi=150)
    plt.show()
    print(f"Saved to {PLOT_RUN}/loss_per_layer.png")

### Inference Check: Progressive QAT Results

Run a quick generation with a QAT checkpoint (progressive or KD-cache) as a sanity check.

In [8]:
# Test inference with a QAT checkpoint.
# Set TEST_RUN to any run directory that contains qat_state_dict.pt, e.g.
#   RUN_DIR_E2E_FONLY, RUN_DIR_PROGRESSIVE_FONLY, RUN_DIR_PROGRESSIVE,
#   or a run unpacked from Google Drive.
#TEST_RUN = RUN_DIR_E2E_FONLY
#TEST_RUN = "runs/progressive_qat_v1"
#TEST_RUN = "runs/qwen3_kdqat_cache_q2_4"
TEST_RUN = "runs/qwen3_kdqat_cache_q2_3"

!python scripts/run_inference.py \
  --model_name_or_path {MODEL_NAME} \
  --qat_checkpoint {TEST_RUN}/qat_state_dict.pt \
  --device {DEVICE} \
  --dtype {DTYPE} \
  -q {QUANT_BITS} \
  --skip_lm_head \
  --prompt "What is Apple Neural Engine?" \
  --do_sample true \
  --max_new_tokens 128


## 6) Stage B: LoRA recovery

Two options:
- **SFT LoRA** (Alpaca-style instruction tuning)
- **Cached KD-LoRA** (preserve teacher distribution; no new “skills”)
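The cached KD-LoRA cell below sets `--distill_temperature`, `--distill_weight` and `--hard-top1-weight`. The sketch shows one common way such flags combine: a temperature-scaled KD loss over cached top-K teacher logits, plus a small hard-label term on the teacher's top-1 token.

It rests on several assumptions:
- The cache stores the teacher's top-K logits per position. The `K32` in the cache directory names suggests K = 32.
- The student logits are gathered at the same token ids.
- The hard term is computed over those K candidates only.

`cached_kd_loss` is hypothetical. The repo's exact loss may differ, and this sketch does not model `--hard-full-top1-weight`.

```python
import math
from typing import List


def log_softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    lse = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - lse for z in scaled]


def cached_kd_loss(student_topk: List[float], teacher_topk: List[float],
                   temperature: float = 2.0, distill_weight: float = 1.0,
                   hard_top1_weight: float = 0.02) -> float:
    """KD loss for one position over the K cached token ids (hypothetical).

    student_topk[i] and teacher_topk[i] are logits for the same token id.
    """
    t_logp = log_softmax(teacher_topk, temperature)
    s_logp = log_softmax(student_topk, temperature)
    # KL(teacher || student) on the renormalized top-K distributions
    kl = sum(math.exp(t) * (t - s) for t, s in zip(t_logp, s_logp))
    kd = temperature ** 2 * kl  # T^2 keeps gradient scale comparable across T
    # Hard term: cross-entropy on the teacher's top-1 token (at T = 1)
    top1 = max(range(len(teacher_topk)), key=lambda i: teacher_topk[i])
    hard = -log_softmax(student_topk)[top1]
    return distill_weight * kd + hard_top1_weight * hard


teacher = [2.0, 1.0, 0.5, -1.0]
print(cached_kd_loss(teacher, teacher))               # only the small hard term
print(cached_kd_loss([0.0, 0.0, 0.0, 0.0], teacher))  # larger: distributions differ
```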


In [9]:
# ============================================================
# STAGE B: LoRA Recovery (Cached KD-LoRA)
# ============================================================
# Train LoRA adapters on top of QAT checkpoint

CACHE_DIR = CACHE_DIR_CHAT
# QAT checkpoint to recover (pick one):
#RUN_DIR_CACHE = "runs/qwen3_kdqat_cache_q2"
#RUN_DIR_CACHE = "runs/progressive_qat_v1"
#RUN_DIR_CACHE = "runs/qwen3_kdqat_cache_q2_4"
RUN_DIR_CACHE = "runs/qwen3_kdqat_cache_q2_3"


LORA_DIM = 32
LORA_RUN_KD = f"runs/qwen3_lora_recovery_cached_r{LORA_DIM}"

%cd /content/qwen3_apple_style_2bit_qat_lora

!python scripts/train_lora_recovery.py \
  --model_name_or_path {MODEL_NAME} \
  --qat_checkpoint {RUN_DIR_CACHE}/qat_state_dict.pt \
  --output_dir {LORA_RUN_KD} \
  --device {DEVICE} \
  --amp_dtype {AMP_DTYPE} \
  --param_dtype {PARAM_DTYPE} \
  -q {QUANT_BITS} \
  --per_device_train_batch_size 16 \
  --gradient_accumulation_steps 2 \
  --learning_rate 1e-5 \
  --warmup_steps 0 \
  --max_steps 1000 \
  --save_steps 3000 \
  --logging_steps 2 \
  --skip_lm_head \
  --lora_r {LORA_DIM} \
  --lora_alpha {LORA_DIM} \
  --lora_dropout 0.0 \
  --kd_cache_dir {CACHE_DIR} \
  --kd_cache_shuffle_files \
  --distill_temperature 2.0 \
  --distill_weight 1.0 \
  --hard-top1-weight 0.02 \
  --hard-full-top1-weight 0.01

/content/qwen3_apple_style_2bit_qat_lora
[device] cuda | amp_dtype=torch.bfloat16 | param_dtype=torch.bfloat16
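The Stage B cell sets `--lora_r` and `--lora_alpha` to the same value (`LORA_DIM`). The standard LoRA formulation scales the adapter output by alpha / r, so this choice gives a scale of 1.0 for any rank.

Below is a stdlib sketch of a LoRA-wrapped linear layer under that standard formulation. `lora_forward` and `lora_param_count` are hypothetical names. The repo's implementation may differ, for example in how the base weight is fake-quantized. With B initialized to zero, the adapter starts as an exact no-op, so recovery training begins from the QAT model's outputs.

```python
import random
from typing import List

Matrix = List[List[float]]


def matvec(m: Matrix, x: List[float]) -> List[float]:
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]


def lora_forward(W: Matrix, A: Matrix, B: Matrix, x: List[float],
                 r: int, alpha: float) -> List[float]:
    """y = W x + (alpha / r) * B (A x); W: out x in, A: r x in, B: out x r."""
    base = matvec(W, x)
    delta = matvec(B, matvec(A, x))
    scale = alpha / r
    return [b + scale * d for b, d in zip(base, delta)]


def lora_param_count(in_f: int, out_f: int, r: int) -> int:
    """Trainable adapter parameters for one linear layer (A and B only)."""
    return r * in_f + out_f * r


random.seed(0)
in_f, out_f, r = 4, 3, 2
W = [[random.uniform(-1, 1) for _ in range(in_f)] for _ in range(out_f)]
A = [[random.gauss(0, 0.02) for _ in range(in_f)] for _ in range(r)]
B = [[0.0] * r for _ in range(out_f)]  # zero init: adapter starts as a no-op
x = [1.0, -0.5, 0.25, 2.0]

assert lora_forward(W, A, B, x, r=r, alpha=r) == matvec(W, x)
# e.g. a 1024 -> 3072 projection at r=32 (assumed Qwen3-0.6B MLP shape)
print(lora_param_count(1024, 3072, 32))  # 131072
```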

## 7) Plot loss

In Colab, use `--no_show` + `--save` then display the PNG.

In [None]:
!python scripts/plot_loss.py --run_dir {RUN_DIR} --source csv --no_show --save {RUN_DIR}/loss.png
from IPython.display import display
from PIL import Image
display(Image.open(f"{RUN_DIR}/loss.png"))


## 8) Inference sanity checks

Use greedy decoding (`--do_sample false`) so outputs are reproducible, and keep them short (e.g. `--max_new_tokens 16`) for a quick check.
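`--do_sample false` selects greedy decoding: each step takes the single most likely token. Repeated runs on the same checkpoint then produce the same text, which makes differences between checkpoints easier to attribute. The toy sketch below contrasts the two selection rules over one step's logits. `pick_next_token` is hypothetical and is not the script's decoding code.

```python
import math
import random
from typing import List, Optional


def pick_next_token(logits: List[float], do_sample: bool,
                    temperature: float = 1.0,
                    rng: Optional[random.Random] = None) -> int:
    """Greedy (argmax) when do_sample is False, else sample from softmax(logits / T)."""
    if not do_sample:
        return max(range(len(logits)), key=lambda i: logits[i])
    rng = rng or random.Random()
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # stabilize exp
    weights = [math.exp(s - m) for s in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


logits = [1.2, 3.4, 0.1, 3.3]
print(pick_next_token(logits, do_sample=False))  # 1, on every call
rng = random.Random(0)
print([pick_next_token(logits, True, rng=rng) for _ in range(8)])  # varies with seed
```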

In [13]:
RUN_DIR = "runs/qwen3_kdqat_cache_q2_3"
#RUN_DIR = "runs/progressive_qat_v1"

%cd /content/qwen3_apple_style_2bit_qat_lora

!python scripts/run_inference.py \
  --model_name_or_path {MODEL_NAME} \
  --qat_checkpoint {RUN_DIR}/qat_state_dict.pt \
  --lora_checkpoint "runs/qwen3_lora_recovery_cached_r32/lora_only_state_dict.pt" \
  --device {DEVICE} \
  --dtype {DTYPE} \
  -q {QUANT_BITS} \
  --skip_lm_head \
  --lora_r {LORA_DIM} --lora_alpha {LORA_DIM} --lora_dropout 0.0 \
  --prompt "What is Machine Learning?" \
  --do_sample false \
  --enable_thinking true \
  --max_new_tokens 256


/content/qwen3_apple_style_2bit_qat_lora

In [None]:
RUN_DIR = "runs/qwen3_kdqat_cache_q2_4"
!python scripts/run_inference.py \
  --model_name_or_path {MODEL_NAME} \
  --qat_checkpoint {RUN_DIR}/qat_state_dict.pt \
  --lora_checkpoint "runs/qwen3_lora_recovery_cached_r32/lora_only_state_dict.pt" \
  --device {DEVICE} \
  --dtype {DTYPE} \
  -q {QUANT_BITS} \
  --skip_lm_head \
  --lora_r {LORA_DIM} --lora_alpha {LORA_DIM} --lora_dropout 0.0 \
  --prompt "Explain how neural networks learn in simple terms" \
  --do_sample false \
  --enable_thinking true \
  --max_new_tokens 1024



In [None]:
RUN_DIR = "runs/qwen3_kdqat_cache_q2_2"
LORA_DIM = 64  # must match the rank of the r64 LoRA checkpoint below
!python scripts/run_inference.py \
  --model_name_or_path {MODEL_NAME} \
  --qat_checkpoint {RUN_DIR}/qat_state_dict.pt \
  --lora_checkpoint "runs/qwen3_lora_recovery_cached_r64/lora_only_state_dict.pt" \
  --device {DEVICE} \
  --dtype {DTYPE} \
  -q {QUANT_BITS} \
  --skip_lm_head \
  --lora_r {LORA_DIM} --lora_alpha {LORA_DIM} --lora_dropout 0.0 \
  --prompt "2+2=" \
  --do_sample false \
  --enable_thinking true \
  --max_new_tokens 90


In [None]:
# Baseline: QAT checkpoint only (no --lora_checkpoint)
LORA_DIM = 64
RUN_DIR = "runs/qwen3_kdqat_cache_q2"
!python scripts/run_inference.py \
  --model_name_or_path {MODEL_NAME} \
  --qat_checkpoint {RUN_DIR}/qat_state_dict.pt \
  --device {DEVICE} \
  --dtype {DTYPE} \
  -q {QUANT_BITS} \
  --skip_lm_head \
  --lora_r {LORA_DIM} --lora_alpha {LORA_DIM} --lora_dropout 0.0 \
  --prompt "What is capital of France?" \
  --do_sample true \
  --max_new_tokens 64


## 9) Optional: snap weights to the exact grid

This produces a float checkpoint with weights snapped to the N-bit codebook (not bitpacked).
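Snapping replaces each weight with the nearest level of an N-bit grid but keeps the value as a float, so the file is no smaller than before. Bitpacking would then store only the 2^N level indices (four per byte at 2 bits).

The sketch below assumes a uniform symmetric grid of 2^N levels at half-integer multiples of a per-tensor scale `s`. `grid_levels` and `snap` are hypothetical. The actual codebook used by `hard_quantize_checkpoint.py`, and how `_f_param` maps to `s`, are not shown here and may differ.

```python
import math
from typing import List, Tuple


def grid_levels(bits: int) -> List[float]:
    """Half-integer levels, e.g. bits=2 -> [-1.5, -0.5, 0.5, 1.5]."""
    half = 2 ** (bits - 1)
    return [k + 0.5 for k in range(-half, half)]


def snap(weights: List[float], scale: float, bits: int) -> Tuple[List[float], List[int]]:
    """Return (snapped float weights, integer codes in [0, 2**bits))."""
    half = 2 ** (bits - 1)
    snapped, codes = [], []
    for w in weights:
        k = math.floor(w / scale)          # nearest half-integer level is k + 0.5
        k = max(-half, min(half - 1, k))   # clamp to the 2**bits levels
        snapped.append((k + 0.5) * scale)  # float value kept in the checkpoint
        codes.append(k + half)             # index a bitpacker would store
    return snapped, codes


w = [0.2, -0.2, 1.9, 5.0, -3.0]
snapped, codes = snap(w, scale=1.0, bits=2)
print(snapped)  # [0.5, -0.5, 1.5, 1.5, -1.5]
print(codes)    # [2, 1, 3, 3, 0]
```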

In [None]:
RUN_DIR = "runs/qwen3_kdqat_cache_q2_2"
!python scripts/hard_quantize_checkpoint.py \
  --model_name_or_path {MODEL_NAME} \
  --qat_checkpoint {RUN_DIR}/qat_state_dict.pt \
  --output_path {RUN_DIR}/hard_quant_full_state_dict.pt \
  -q {QUANT_BITS} \
  --skip_lm_head


In [None]:
%cd /content/qwen3_apple_style_2bit_qat_lora
%ls -l runs
