# Qwen3-0.6B — Apple-style QAT (2-bit / 4-bit) + KD + LoRA recovery

This notebook mirrors the structure of common “phone deployment” notebooks, but uses **this repo’s** pipeline:

- **Stage A (recommended default):** KD-QAT on plain text (C4 streaming) or KD-cache QAT
- **Stage B:** LoRA recovery (either SFT or cached KD-LoRA)
- Plot `loss.csv`
- Run inference sanity checks

Notes:
- Qwen3 requires `transformers>=4.51.0`.
- Disk usage: the C4 `en` config alone is hundreds of gigabytes; prefer `--streaming` unless you explicitly want a local copy.
- Bitwidth: use `-q 2` (default) or `-q 4` (less aggressive). Checkpoints persist the bitwidth per layer.
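
The version requirement above can be checked up front rather than discovered through a model-loading error. The helper below is a minimal sketch using only the standard library; `version_tuple` and `meets_minimum` are hypothetical names, and pre-release suffixes such as `.dev0` are simply ignored.

```python
import re

def version_tuple(version: str) -> tuple:
    """Return the leading numeric components of a version string as ints."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))

def meets_minimum(installed: str, minimum: str = "4.51.0") -> bool:
    """True if `installed` is at least `minimum` (pre-release tags are ignored)."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so that "4.51" and "4.51.0" compare as equal.
    a = a + (0,) * (width - len(a))
    b = b + (0,) * (width - len(b))
    return a >= b
```

In the notebook this could be called as `meets_minimum(transformers.__version__)` once `transformers` is imported.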


## 0) Setup (Colab / local)

If you’re in Colab, clone the repo. If you’re already in the repo directory locally, you can skip this.
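
If the same notebook is shared between Colab and local use, the clone step can be gated on a runtime check. The snippet below is a sketch of a commonly used heuristic (looking for the `google.colab` module); `in_colab` is a hypothetical helper name, not part of this repo.

```python
import sys

def in_colab() -> bool:
    """Heuristic: the `google.colab` module is preloaded in Colab kernels."""
    return "google.colab" in sys.modules

if in_colab():
    print("Colab detected: clone the repo and cd into it first.")
else:
    print("Not in Colab: run the notebook from the repo root.")
```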

In [None]:
# ---- Config (edit these) ----
MODEL_NAME = 'Qwen/Qwen3-0.6B'
QUANT_BITS = 2  # 2 or 4
DEVICE = 'auto'
AMP_DTYPE = 'auto'
PARAM_DTYPE = 'auto'
DTYPE = 'auto'


In [None]:
# Colab-only:
# !git clone https://github.com/Anemll/qwen3_apple_style_2bit_qat_lora
# %cd qwen3_apple_style_2bit_qat_lora


## 1) Install dependencies (uv)

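This repo is set up to work with `uv`. If `uv pip install` reports that no virtual environment was found (this can happen in Colab), add `--system` to install into the active Python.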

In [None]:
!pip -q install uv
!uv pip install -r requirements.txt
!uv pip install -e .
# plotting
!uv pip install -q matplotlib


## 2) Optional: Hugging Face login

If you hit errors about gated models or datasets, log in with a Hugging Face access token.

In [None]:
# from huggingface_hub import login
# login()  # paste token when prompted


## 3) Quick environment check

In [None]:
import torch, transformers
print('torch', torch.__version__)
print('transformers', transformers.__version__)
print('cuda', torch.cuda.is_available())
print('mps', torch.backends.mps.is_available())


## 4) Stage A (recommended): KD-QAT on streaming C4

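Knowledge distillation (KD) trains the fake-quantized student to match the output distribution of the full-precision base model, which acts as the teacher, so the base model’s behavior is preserved under low-bit weights.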
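
As a rough illustration of the distillation objective, the sketch below computes a temperature-scaled KL divergence between teacher and student logits for a single token position. It follows the common textbook formulation (KL scaled by the squared temperature); whether `scripts/train_qat.py` uses exactly this form is an assumption, and `kd_loss` is a hypothetical name.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kd_loss(student_logits, teacher_logits, temperature=2.0):
    """KL(teacher || student) on temperature-softened distributions, times T^2."""
    p = softmax(teacher_logits, temperature)  # teacher target distribution
    q = softmax(student_logits, temperature)  # student prediction
    kl = sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0)
    # The T^2 factor keeps gradient magnitudes comparable across temperatures.
    return kl * temperature ** 2
```

A higher temperature flattens both distributions, so the student is also penalized for mismatches on lower-ranked tokens rather than only on the teacher's top choice.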

Tips:
- Start with a small run (`--max_steps 50`) to validate the pipeline.
- Use `-q 4` if 2-bit is too unstable; 4-bit is less aggressive.
- On MPS, prefer `--ema_decay 0` for KD-QAT.
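
To judge how small a "small run" is, it helps to estimate the token budget. The arithmetic below uses the values from the training cell in this section and assumes that `--max_steps` counts optimizer steps (not micro-batches) and that every sequence is filled to `--max_length`, so the result is an upper bound. `tokens_per_run` is a hypothetical helper.

```python
def tokens_per_run(max_steps, batch_size, grad_accum, max_length, num_devices=1):
    """Upper bound on tokens processed, assuming full-length sequences."""
    sequences_per_step = batch_size * grad_accum * num_devices
    return max_steps * sequences_per_step * max_length

smoke = tokens_per_run(max_steps=50, batch_size=1, grad_accum=16, max_length=128)
print(f"upper bound: {smoke:,} tokens")  # 50 * 16 * 128 = 102,400
```

At roughly a hundred thousand tokens, such a run only validates that the pipeline works end to end; it is not expected to recover much quality.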


In [None]:
RUN_DIR = f"runs/qwen3_kdqat_stream_q{QUANT_BITS}"

!python scripts/train_qat.py \
  --model_name_or_path {MODEL_NAME} \
  --teacher_model_name_or_path {MODEL_NAME} \
  --distill_weight 1.0 \
  --distill_temperature 2.0 \
  --dataset_name allenai/c4 \
  --dataset_config_name en \
  --dataset_split train \
  --dataset_format text \
  --dataset_text_field text \
  --streaming \
  --shuffle_buffer 10000 \
  --output_dir {RUN_DIR} \
  --device {DEVICE} \
  --amp_dtype {AMP_DTYPE} \
  --param_dtype {PARAM_DTYPE} \
  -q {QUANT_BITS} \
  --max_length 128 \
  --per_device_train_batch_size 1 \
  --gradient_accumulation_steps 16 \
  --learning_rate 5e-6 \
  --warmup_steps 0 \
  --max_steps 50 \
  --skip_lm_head \
  --ema_decay 0 \
  --logging_steps 10 \
  --save_steps 50


### (Optional) Resume

`--resume_from_checkpoint auto` resolves to `checkpoint_last.pt` if it exists in the output directory.
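
The resolution rule can be pictured as follows. This is a sketch of the described behavior, not the repo's code: `resolve_resume_path` is a hypothetical name, and returning `None` (start from scratch) when no checkpoint exists is an assumption.

```python
from pathlib import Path
from typing import Optional

def resolve_resume_path(output_dir: str, resume_arg: Optional[str]) -> Optional[Path]:
    """Map a --resume_from_checkpoint value to a concrete checkpoint path."""
    if resume_arg is None:
        return None  # no resume requested
    if resume_arg == "auto":
        candidate = Path(output_dir) / "checkpoint_last.pt"
        # Assumption: fall back to a fresh run when nothing has been saved yet.
        return candidate if candidate.exists() else None
    return Path(resume_arg)  # an explicit path is used as given
```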

In [None]:
# !python scripts/train_qat.py ... --output_dir {RUN_DIR} --max_steps 500 --resume_from_checkpoint auto


## 5) (Optional) KD-cache: precompute teacher top-k + negatives

Cache mode is MPS-friendly:
- no teacher model during training
- no full-vocab logits

If you see good KD loss but bad greedy decoding, increase negative coverage (`--rand_neg`) and/or add hard top-1 terms:
- `--hard-top1-weight 0.05`
- `--hard-full-top1-weight 0.02`–`0.05`
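
A toy example shows why a low KD loss does not guarantee matching greedy output: two distributions can be close in KL divergence and still disagree on the argmax. The hard top-1 term below is a generic cross-entropy against the teacher's top token; how the repo defines its two hard top-1 terms is not shown here, so treat `hard_top1_loss` as an illustrative assumption.

```python
import math

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def hard_top1_loss(teacher_probs, student_probs):
    """Cross-entropy of the student against the teacher's argmax token."""
    top1 = max(range(len(teacher_probs)), key=teacher_probs.__getitem__)
    return -math.log(student_probs[top1])

teacher = [0.40, 0.38, 0.22]
student = [0.38, 0.40, 0.22]  # nearly the same distribution, different argmax

print("KL:", round(kl_divergence(teacher, student), 4))
print("teacher argmax:", teacher.index(max(teacher)))
print("student argmax:", student.index(max(student)))
print("hard top-1 loss:", round(hard_top1_loss(teacher, student), 4))
```

The KL term is tiny while the hard top-1 term stays large, which is the kind of gap a small hard top-1 weight is meant to close.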

In [None]:
CACHE_DIR = "caches/c4_qwen3_L64_K32_R256"

!python scripts/precompute_teacher_topk.py \
  --teacher_model_name_or_path {MODEL_NAME} \
  --dataset_name allenai/c4 \
  --dataset_config_name en \
  --dataset_split train \
  --dataset_text_field text \
  --streaming \
  --shuffle_buffer 10000 \
  --max_length 64 \
  --topk 32 \
  --rand_neg 256 \
  --num_sequences 2000 \
  --batch_size 1 \
  --shard_size 512 \
  --device {DEVICE} \
  --dtype {DTYPE} \
  --output_dir {CACHE_DIR}


### KD-cache QAT training

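Training reads the cached teacher signals (top-k plus random negatives) and computes the softmax over those candidate tokens only, instead of over the full vocabulary.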
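
One way to picture a candidate softmax is sketched below for a single token position. It assumes that the cache holds teacher logits for the top-k tokens, that random negatives carry no teacher probability mass and only enter the student's normalizer, and that negative ids do not overlap the top-k ids. The repo's actual loss may differ; `candidate_kd_loss` is a hypothetical name.

```python
import math

def candidate_kd_loss(student_logits, topk_ids, topk_teacher_logits, neg_ids,
                      temperature=2.0):
    """Soft cross-entropy restricted to cached candidates (top-k plus negatives)."""
    # Teacher distribution over its cached top-k tokens only.
    t_scaled = [x / temperature for x in topk_teacher_logits]
    t_max = max(t_scaled)
    t_exp = [math.exp(x - t_max) for x in t_scaled]
    t_sum = sum(t_exp)
    teacher_probs = [e / t_sum for e in t_exp]

    # Student log-softmax over the candidate set, not the full vocabulary.
    cand_ids = list(topk_ids) + list(neg_ids)
    s_scaled = [student_logits[i] / temperature for i in cand_ids]
    s_max = max(s_scaled)
    log_z = s_max + math.log(sum(math.exp(x - s_max) for x in s_scaled))
    student_logp = [x - log_z for x in s_scaled[: len(topk_ids)]]

    # Negatives raise log_z, so probability leaked to them increases the loss.
    return -sum(p * lp for p, lp in zip(teacher_probs, student_logp))
```

Under these assumptions, a larger `--rand_neg` gives the student more chances to be penalized for putting weight on tokens the teacher never ranked highly.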

In [None]:
RUN_DIR_CACHE = f"runs/qwen3_kdqat_cache_q{QUANT_BITS}"

!python scripts/train_qat.py \
  --model_name_or_path {MODEL_NAME} \
  --output_dir {RUN_DIR_CACHE} \
  --device {DEVICE} \
  --amp_dtype {AMP_DTYPE} \
  --param_dtype {PARAM_DTYPE} \
  -q {QUANT_BITS} \
  --max_length 64 \
  --per_device_train_batch_size 1 \
  --gradient_accumulation_steps 8 \
  --learning_rate 5e-6 \
  --warmup_steps 0 \
  --max_steps 200 \
  --save_steps 50 \
  --logging_steps 10 \
  --skip_lm_head \
  --ema_decay 0 \
  --kd_cache_dir {CACHE_DIR} \
  --kd_cache_shuffle_files \
  --distill_temperature 2.0 \
  --distill_weight 1.0 \
  --hard-top1-weight 0.05 \
  --hard-full-top1-weight 0.02


## 6) Stage B: LoRA recovery

Two options:
- **SFT LoRA** (Alpaca-style instruction tuning)
- **Cached KD-LoRA** (preserve teacher distribution; no new “skills”)
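
For readers new to LoRA, the sketch below shows the standard low-rank update in plain Python: the frozen weight `W` is augmented by `(alpha / r) * B @ A`. With `--lora_r 32 --lora_alpha 32` the scale is 1.0. Which layers this repo wraps, and how the update interacts with fake quantization, is not covered here; the function names are hypothetical.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def lora_effective_weight(w, lora_a, lora_b, alpha, r):
    """Return W + (alpha / r) * B @ A, with B of shape (out, r) and A of shape (r, in)."""
    scale = alpha / r
    delta = matmul(lora_b, lora_a)  # (out x r) @ (r x in) -> (out x in)
    return [
        [w_ij + scale * d_ij for w_ij, d_ij in zip(w_row, d_row)]
        for w_row, d_row in zip(w, delta)
    ]

# B is conventionally initialized to zero, so training starts from the base weight.
w = [[1.0, 0.0], [0.0, 1.0]]
a = [[1.0, 2.0]]      # r=1, in=2
b = [[0.0], [0.0]]    # out=2, r=1
assert lora_effective_weight(w, a, b, alpha=32, r=32) == w
```

Only `A` and `B` are trained, which is why the adapter can be saved separately from the QAT checkpoint.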


In [None]:
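LORA_RUN = "runs/qwen3_lora_recovery_sft"
# Note: with --max_steps 50, --warmup_steps 50 means this whole smoke run is warmup.
# For a real run, raise --max_steps or lower --warmup_steps.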

!python scripts/train_lora_recovery.py \
  --model_name_or_path {MODEL_NAME} \
  --qat_checkpoint {RUN_DIR}/final_state_dict.pt \
  --dataset_name tatsu-lab/alpaca \
  --output_dir {LORA_RUN} \
  --device {DEVICE} \
  --amp_dtype {AMP_DTYPE} \
  --param_dtype {PARAM_DTYPE} \
  -q {QUANT_BITS} \
  --max_length 128 \
  --per_device_train_batch_size 1 \
  --gradient_accumulation_steps 8 \
  --learning_rate 2e-4 \
  --warmup_steps 50 \
  --max_steps 50 \
  --save_steps 50 \
  --logging_steps 10 \
  --skip_lm_head \
  --lora_r 32 \
  --lora_alpha 32 \
  --lora_dropout 0.0


In [None]:
LORA_RUN_KD = "runs/qwen3_lora_recovery_cached"

!python scripts/train_lora_recovery.py \
  --model_name_or_path {MODEL_NAME} \
  --qat_checkpoint {RUN_DIR_CACHE}/final_state_dict.pt \
  --output_dir {LORA_RUN_KD} \
  --device {DEVICE} \
  --amp_dtype {AMP_DTYPE} \
  --param_dtype {PARAM_DTYPE} \
  -q {QUANT_BITS} \
  --per_device_train_batch_size 1 \
  --gradient_accumulation_steps 8 \
  --learning_rate 5e-5 \
  --warmup_steps 0 \
  --max_steps 200 \
  --save_steps 50 \
  --logging_steps 10 \
  --skip_lm_head \
  --lora_r 16 \
  --lora_alpha 16 \
  --lora_dropout 0.0 \
  --kd_cache_dir {CACHE_DIR} \
  --kd_cache_shuffle_files \
  --distill_temperature 2.0 \
  --distill_weight 1.0 \
  --hard-top1-weight 0.05 \
  --hard-full-top1-weight 0.02


## 7) Plot loss

In Colab, pass `--no_show` and `--save`, then display the saved PNG.
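
To inspect `loss.csv` directly instead of going through the plotting script, something like the following could work. The column names `step` and `loss` are assumptions about the file layout, so check the CSV header first; both helper names are hypothetical.

```python
import csv

def read_loss_csv(path, step_col="step", loss_col="loss"):
    """Read step and loss columns from a CSV file with a header row."""
    steps, losses = [], []
    with open(path, newline="") as f:
        for row in csv.DictReader(f):
            steps.append(int(float(row[step_col])))
            losses.append(float(row[loss_col]))
    return steps, losses

def moving_average(values, window=5):
    """Trailing moving average; early entries average over what is available."""
    out = []
    running = 0.0
    for i, v in enumerate(values):
        running += v
        if i >= window:
            running -= values[i - window]  # drop the value leaving the window
        out.append(running / min(i + 1, window))
    return out
```

Smoothing helps with short runs, where the loss logged every few steps is noisy at batch size 1.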

In [None]:
!python scripts/plot_loss.py --run_dir {RUN_DIR} --source csv --no_show --save {RUN_DIR}/loss.png
from PIL import Image
display(Image.open(f"{RUN_DIR}/loss.png"))


## 8) Inference sanity checks

Use greedy decoding (`--do_sample false`) and keep outputs short (`--max_new_tokens 16`) so results are deterministic and easy to compare across checkpoints.
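
Greedy decoding picks the highest-scoring token at every step, which makes the output a pure function of the weights. The loop below is a model-agnostic sketch; `next_token_logits` stands in for a real model call, and `toy_model` is a made-up stand-in used only for illustration.

```python
def greedy_decode(next_token_logits, prompt_ids, max_new_tokens=16, eos_id=None):
    """Append the argmax token repeatedly until the budget or EOS is reached."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits = next_token_logits(ids)
        next_id = max(range(len(logits)), key=logits.__getitem__)  # argmax
        ids.append(next_id)
        if eos_id is not None and next_id == eos_id:
            break
    return ids

def toy_model(ids, vocab=5):
    """Toy scorer: always prefers (last token + 1) modulo the vocabulary size."""
    target = (ids[-1] + 1) % vocab
    return [1.0 if t == target else 0.0 for t in range(vocab)]
```

Because there is no sampling, any change in the output between two checkpoints reflects a change in the model rather than random noise.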

In [None]:
!python scripts/run_inference.py \
  --model_name_or_path {MODEL_NAME} \
  --qat_checkpoint {RUN_DIR}/final_state_dict.pt \
  --device {DEVICE} \
  --dtype {DTYPE} \
  -q {QUANT_BITS} \
  --skip_lm_head \
  --prompt "The capital of France is" \
  --plain_text \
  --do_sample false \
  --max_new_tokens 16


In [None]:
!python scripts/run_inference.py \
  --model_name_or_path {MODEL_NAME} \
  --qat_checkpoint {RUN_DIR}/final_state_dict.pt \
  --lora_checkpoint {LORA_RUN}/lora_only_state_dict.pt \
  --device {DEVICE} \
  --dtype {DTYPE} \
  -q {QUANT_BITS} \
  --skip_lm_head \
  --lora_r 32 --lora_alpha 32 --lora_dropout 0.0 \
  --prompt "The opposite of hot is" \
  --plain_text \
  --do_sample false \
  --max_new_tokens 16


## 9) Optional: snap weights to the exact grid

This produces a floating-point checkpoint whose weights are snapped to the N-bit codebook values. It is not bit-packed, so the file does not shrink.
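
To make "snapped to the grid" concrete, the sketch below rounds weights to the nearest level of a balanced grid with `2**bits` levels. The grid layout (levels at `(k + 0.5) * scale`) and the single shared `scale` are assumptions chosen for illustration; the codebook and per-layer scales used by `hard_quantize_checkpoint.py` may differ.

```python
import math

def snap_to_grid(weights, bits, scale):
    """Snap each weight to the nearest level of a balanced 2**bits-level grid."""
    half = 2 ** (bits - 1)  # number of levels on each side of zero
    snapped = []
    for w in weights:
        # Levels sit at (k + 0.5) * scale for integer k in [-half, half - 1].
        k = math.floor(w / scale)
        k = max(-half, min(half - 1, k))  # clamp out-of-range weights
        snapped.append((k + 0.5) * scale)
    return snapped

print(snap_to_grid([0.1, -0.1, 10.0, -10.0, 0.9], bits=2, scale=1.0))
# -> [0.5, -0.5, 1.5, -1.5, 0.5]
```

The values stay ordinary floats; only the set of distinct values per tensor is restricted, which is why a separate packing step would be needed to save space.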

In [None]:
!python scripts/hard_quantize_checkpoint.py \
  --model_name_or_path {MODEL_NAME} \
  --qat_checkpoint {RUN_DIR}/checkpoint_last.pt \
  --output_path {RUN_DIR}/hard_quant_full_state_dict.pt \
  -q {QUANT_BITS} \
  --skip_lm_head
