# Qwen3-0.6B — Apple-style QAT (2-bit / 4-bit) + KD + LoRA recovery

This notebook mirrors the structure of common “phone deployment” notebooks, but uses **this repo’s** pipeline:

- **Stage A (recommended default):** KD-QAT on plain text (C4 streaming) or KD-cache QAT
- **Stage B:** LoRA recovery (either SFT or cached KD-LoRA)
- Plot `loss.csv`
- Run inference sanity checks

Notes:
- Qwen3 requires `transformers>=4.51.0`.
- For disk usage: C4 is huge; prefer `--streaming` unless you explicitly want to download.
- Bitwidth: use `-q 2` (default) or `-q 4` (less aggressive). Checkpoints persist the bitwidth per layer.


## 0) Setup (Colab / local)

If you’re in Colab, clone the repo. If you’re already in the repo directory locally, you can skip this.

In [None]:
# ---- Config (edit these) ----
MODEL_NAME = 'Qwen/Qwen3-0.6B'
TEACHER_NAME = MODEL_NAME
QUANT_BITS = 4  # 2 or 4
DEVICE = 'auto'
AMP_DTYPE = 'auto'
PARAM_DTYPE = 'auto'
DTYPE = 'auto'

# Cache dirs
CACHE_DIR_CHAT = 'caches/alpaca_chat_think_both_L128_K32_R256'
CACHE_DIR_TEXT = 'caches/c4_qwen3_L64_K32_R256'


In [None]:
# Colab-only:
!git clone https://github.com/Anemll/qwen3_apple_style_2bit_qat_lora
%cd qwen3_apple_style_2bit_qat_lora


Cloning into 'qwen3_apple_style_2bit_qat_lora'...
remote: Enumerating objects: 131, done.
remote: Counting objects: 100% (131/131), done.
remote: Compressing objects: 100% (82/82), done.
remote: Total 131 (delta 68), reused 109 (delta 46), pack-reused 0 (from 0)
Receiving objects: 100% (131/131), 99.37 KiB | 16.56 MiB/s, done.
Resolving deltas: 100% (68/68), done.
/content/qwen3_apple_style_2bit_qat_lora


## 1) Install dependencies (uv)

This repo is set up to work with `uv`.

In [None]:
!pip -q install uv
!uv pip install -r requirements.txt
!uv pip install -e .
# plotting
!uv pip install -q matplotlib
!uv pip install -q plot


Using Python 3.12.12 environment at: /usr
Audited 7 packages in 134ms
Using Python 3.12.12 environment at: /usr
Resolved 62 packages in 361ms

## 2) Optional: Hugging Face login

If you hit gated model/dataset errors, log in.

In [None]:
# from huggingface_hub import login
# login()  # paste token when prompted


## 3) Quick environment check

In [None]:
import torch, transformers
print('torch', torch.__version__)
print('transformers', transformers.__version__)
print('cuda', torch.cuda.is_available())
print('mps', torch.backends.mps.is_available())


torch 2.9.0+cu126
transformers 4.57.3
cuda True
mps False
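The check above only prints versions. As a minimal sketch, the `transformers>=4.51.0` requirement from the notes can be turned into an explicit guard. The helper names `version_tuple` and `meets_minimum` are hypothetical and not part of this repo; the parser only reads the leading numeric part of a version string, so local suffixes such as `+cu126` are ignored.

```python
import re


def version_tuple(version: str) -> tuple:
    """Parse the leading numeric components, e.g. '2.9.0+cu126' -> (2, 9, 0)."""
    match = re.match(r"(\d+(?:\.\d+)*)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(1).split("."))


def meets_minimum(installed: str, minimum: str) -> bool:
    """Return True if `installed` is at least `minimum` (numeric parts only)."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so that '4.51' and '4.51.0' compare as equal.
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


if __name__ == "__main__":
    # Version string taken from the output shown above.
    print(meets_minimum("4.57.3", "4.51.0"))
```

In the notebook this could be used as `assert meets_minimum(transformers.__version__, "4.51.0")` right after the imports.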


## 4) Stage A (optional): KD-QAT on streaming C4

This preserves the base model’s behavior under low-bit fake-quant weights.

Tips:
- Start with a small run (`--max_steps 50`) to validate the pipeline.
- Use `-q 4` if 2-bit is too unstable; 4-bit is less aggressive.
- On MPS, prefer `--ema_decay 0` for KD-QAT.
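To make "fake-quant weights" concrete, here is a small sketch of a quantize-dequantize step in plain Python. It assumes a symmetric uniform grid scaled by the largest absolute weight; the grid layout, the scaling rule, and the function name `fake_quantize` are illustrative assumptions, not the codebook this repo actually uses.

```python
def fake_quantize(weights, bits):
    """Snap each float to the nearest point of an assumed symmetric uniform grid.

    The result is still a list of floats ("fake" quantization): values are
    restricted to 2**bits levels but are not stored as packed integers.
    """
    levels = 2 ** bits
    max_abs = max(abs(w) for w in weights) or 1.0  # avoid a zero scale
    step = 2 * max_abs / (levels - 1)  # spacing between neighbouring grid points
    snapped = []
    for w in weights:
        index = round((w + max_abs) / step)      # nearest grid index
        index = min(max(index, 0), levels - 1)   # clamp to the valid range
        snapped.append(-max_abs + index * step)  # back to a float value
    return snapped


if __name__ == "__main__":
    print(fake_quantize([-1.0, -0.2, 0.1, 0.9, 1.0], bits=2))
```

With `bits=2` only four distinct values survive, which is why 2-bit runs tend to be less stable than 4-bit runs with sixteen levels. QAT setups commonly pass gradients through the rounding step with a straight-through estimator; that part is not shown here.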


In [None]:
# Generate the thinking (chat) KD cache. SKIP this cell if you copy the cache from Google Drive.
!python scripts/precompute_teacher_topk.py \
  --teacher_model_name_or_path {MODEL_NAME} \
  --dataset_name tatsu-lab/alpaca \
  --dataset_split train \
  --dataset_format alpaca_chat \
  --enable_thinking both \
  --max_length 128 \
  --topk 32 \
  --rand_neg 256 \
  --num_sequences 20000 \
  --batch_size 1 \
  --shard_size 512 \
  --device {DEVICE} \
  --dtype {DTYPE} \
  --output_dir {CACHE_DIR_CHAT}


[device] cuda | dtype=torch.bfloat16
2025-12-17 23:40:12.337234: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-12-17 23:40:12.357384: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1766014812.382692   39984 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1766014812.388150   39984 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1766014812.402070   39984 computation_placer.cc:177] computation placer already registered. Pl

In [None]:
# Generate the plain-text (C4) KD cache. SKIP this cell if you copy the cache from Google Drive.
import os

CACHE_DIR = CACHE_DIR_TEXT

if not os.path.isdir(CACHE_DIR):
    print(f"[cache] {CACHE_DIR} not found → generating cache")

    !python scripts/precompute_teacher_topk.py \
      --teacher_model_name_or_path {MODEL_NAME} \
      --dataset_name allenai/c4 \
      --dataset_config_name en \
      --dataset_split train \
      --dataset_text_field text \
      --streaming \
      --shuffle_buffer 10000 \
      --max_length 64 \
      --topk 32 \
      --rand_neg 256 \
      --num_sequences 2000 \
      --batch_size 1 \
      --shard_size 512 \
      --device {DEVICE} \
      --dtype {DTYPE} \
      --output_dir {CACHE_DIR}

else:
    print(f"[cache] {CACHE_DIR} already exists → skipping generation")


[device] cuda | dtype=torch.bfloat16

In [None]:
# Mount Google Drive and copy the training sets
from google.colab import drive
drive.mount('/content/drive')

# Copy the caches back from Drive
!mkdir -p caches

!rsync -ah --info=progress2 \
  /content/drive/MyDrive/qwen3_caches/c4_qwen3_L64_K32_R256/ \
  caches/c4_qwen3_L64_K32_R256/

!rsync -ah --info=progress2 \
  /content/drive/MyDrive/qwen3_caches/alpaca_chat_think_both_L128_K32_R256/ \
  caches/alpaca_chat_think_both_L128_K32_R256


Mounted at /content/drive
        218.38M 100%   10.69MB/s    0:00:19 (xfr#5, to-chk=0/6)
          4.40G 100%   11.82MB/s    0:05:55 (xfr#41, to-chk=0/42)


In [None]:
# Copy the training sets to Google Drive
from google.colab import drive
drive.mount('/content/drive')

!mkdir -p /content/drive/MyDrive/qwen3_caches

# Alternative using cp:
#!cp -rv caches/c4_qwen3_L64_K32_R256 /content/drive/MyDrive/qwen3_caches/
#!cp -rv caches/alpaca_chat_think_both_L128_K32_R256  /content/drive/MyDrive/qwen3_caches/

!rsync -ah --progress caches/c4_qwen3_L64_K32_R256 /content/drive/MyDrive/qwen3_caches/
!rsync -ah --progress caches/alpaca_chat_think_both_L128_K32_R256 /content/drive/MyDrive/qwen3_caches/

#RUN_DIR = "runs/qwen3_kdqat_stream_q2"



In [None]:
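# DISABLED: not used in this walkthrough. See section 5 for the first QAT step.
RUN_DIR = "runs/qwen3_kdqat_stream_q2"  # output directory for this streaming run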
!python scripts/train_qat.py \
  --model_name_or_path {MODEL_NAME} \
  --teacher_model_name_or_path {MODEL_NAME} \
  --distill_weight 1.0 \
  --distill_temperature 2.0 \
  --dataset_name allenai/c4 \
  --dataset_config_name en \
  --dataset_split train \
  --dataset_format text \
  --dataset_text_field text \
  --streaming \
  --shuffle_buffer 10000 \
  --output_dir {RUN_DIR} \
  --device {DEVICE} \
  --amp_dtype {AMP_DTYPE} \
  --param_dtype {PARAM_DTYPE} \
  -q {QUANT_BITS} \
  --max_length 128 \
  --per_device_train_batch_size 1 \
  --gradient_accumulation_steps 16 \
  --learning_rate 5e-6 \
  --warmup_steps 0 \
  --max_steps 50 \
  --skip_lm_head \
  --ema_decay 0 \
  --logging_steps 10 \
  --save_steps 50


### (Optional) Resume

`--resume_from_checkpoint auto` resolves to `checkpoint_last.pt` if it exists in the output directory.
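The `auto` behaviour described above can be pictured with a short sketch. The function name `resolve_resume_checkpoint` is hypothetical, and returning `None` when no checkpoint exists (meaning "start fresh") is an assumption; the actual script may handle that case differently.

```python
from pathlib import Path
from typing import Optional


def resolve_resume_checkpoint(output_dir: str, resume_from: Optional[str]) -> Optional[Path]:
    """Map a resume argument to a checkpoint path, or None if nothing to resume."""
    if resume_from is None:
        return None
    if resume_from == "auto":
        # 'auto' looks for checkpoint_last.pt inside the output directory.
        candidate = Path(output_dir) / "checkpoint_last.pt"
        return candidate if candidate.is_file() else None
    # Any other value is treated as an explicit checkpoint path.
    return Path(resume_from)


if __name__ == "__main__":
    print(resolve_resume_checkpoint("runs/example_run", "auto"))
```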

In [None]:
# !python scripts/train_qat.py ... --output_dir {RUN_DIR} --max_steps 500 --resume_from_checkpoint auto


## 5) KD-cache: precompute teacher top-k + negatives

Cache mode is MPS-friendly:
- no teacher model during training
- no full-vocab logits
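As an illustration of why the cache avoids full-vocab logits, the sketch below builds a candidate set for a single token position: the teacher's top-k token ids plus a random sample of other ids as negatives. The function `build_candidates` and its sampling details are assumptions for illustration; the cache format written by `scripts/precompute_teacher_topk.py` is not shown in this notebook.

```python
import random


def build_candidates(teacher_logits, topk, rand_neg, seed=0):
    """Return (topk_ids, negative_ids) for one token position.

    topk_ids are the teacher's highest-scoring token ids; negative_ids are
    sampled uniformly from the remaining vocabulary without replacement.
    """
    vocab = len(teacher_logits)
    ranked = sorted(range(vocab), key=lambda i: teacher_logits[i], reverse=True)
    topk_ids = ranked[:topk]
    rest = ranked[topk:]
    rng = random.Random(seed)  # fixed seed keeps the example reproducible
    negative_ids = rng.sample(rest, min(rand_neg, len(rest)))
    return topk_ids, negative_ids


if __name__ == "__main__":
    logits = [0.1, 2.0, -1.0, 3.0, 0.5, 1.5]
    print(build_candidates(logits, topk=2, rand_neg=3))
```

Under this scheme, `--topk 32` and `--rand_neg 256` would keep at most 288 candidate ids per position rather than one score per vocabulary entry.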

If you see good KD loss but bad greedy decoding, increase negative coverage (`--rand_neg`) and/or add hard top-1 terms:
- `--hard-top1-weight 0.05`
- `--hard-full-top1-weight 0.02`–`0.05`

### KD-cache QAT training

This uses cached teacher signals + candidate softmax.
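A rough sketch of a candidate-softmax distillation loss follows. It assumes the student and teacher logits have already been gathered at the same candidate ids, uses KL(teacher || student) at temperature `T` with the common `T**2` scaling, and models the hard top-1 term as a cross-entropy on the teacher's best candidate. All of these choices, and the name `candidate_kd_loss`, are assumptions; the loss in `scripts/train_qat.py` may differ.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a short list of logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def candidate_kd_loss(student_logits, teacher_logits, temperature=2.0, hard_top1_weight=0.0):
    """Soft KD loss over the candidate set, plus an optional hard top-1 term."""
    p = softmax(teacher_logits, temperature)  # teacher distribution on candidates
    q = softmax(student_logits, temperature)  # student distribution on candidates
    kl = sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0)
    loss = (temperature ** 2) * kl
    if hard_top1_weight > 0:
        # Cross-entropy that pushes the student toward the teacher's top-1 candidate.
        top1 = max(range(len(teacher_logits)), key=lambda i: teacher_logits[i])
        loss += hard_top1_weight * -math.log(softmax(student_logits)[top1])
    return loss


if __name__ == "__main__":
    teacher = [2.0, 1.0, 0.0, -1.0]
    student = [0.0, 1.0, 2.0, -1.0]
    print(candidate_kd_loss(student, teacher, hard_top1_weight=0.05))
```

Because the softmax is restricted to the candidates, a low KD loss says little about tokens outside the set, which is consistent with the advice above to raise `--rand_neg` or add hard top-1 terms when greedy decoding looks bad.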

In [None]:
# Stage 1: conservative run with MLP and attention freezes (--freeze-last-mlp, --ov-freeze)
#CACHE_DIR = CACHE_DIR_TEXT
CACHE_DIR = "caches/alpaca_chat_think_both_L128_K32_R256"
RUN_DIR_CACHE = "runs/qwen3_kdqat_cache_q2"

!python scripts/train_qat.py \
  --model_name_or_path {MODEL_NAME} \
  --output_dir {RUN_DIR_CACHE} \
  --device {DEVICE} \
  --amp_dtype {AMP_DTYPE} \
  --param_dtype {PARAM_DTYPE} \
  -q {QUANT_BITS} \
  --max_length 128 \
  --per_device_train_batch_size 128 \
  --gradient_accumulation_steps 1 \
  --learning_rate 5e-6 \
  --warmup_steps 0 \
  --max_steps 1000 \
  --save_steps 3000 \
  --logging_steps 5 \
  --skip_lm_head \
  --ema_decay 0 \
  --kd_cache_dir {CACHE_DIR} \
  --kd_cache_shuffle_files \
  --distill_temperature 2.0 \
  --distill_weight 1.0 \
  --hard-top1-weight 0.05 \
  --hard-full-top1-weight 0.03 \
  --ov-freeze \
  --freeze-last-mlp \
  --freeze-last-mlp-layers 1
  #  --save_steps 500 \



In [None]:
# Stage 2: resume KD-QAT with the last MLP layer unfrozen (--ov-freeze still set)
CACHE_DIR = "caches/alpaca_chat_think_both_L128_K32_R256"
INIT_DIR_CACHE = "runs/qwen3_kdqat_cache_q2"
RUN_DIR_CACHE = "runs/qwen3_kdqat_cache_q2_2"

!python scripts/train_qat.py \
  --model_name_or_path {MODEL_NAME} \
  --init_model_state {INIT_DIR_CACHE}/qat_state_dict.pt \
  --output_dir {RUN_DIR_CACHE} \
  --device {DEVICE} \
  --amp_dtype {AMP_DTYPE} \
  --param_dtype {PARAM_DTYPE} \
  -q {QUANT_BITS} \
  --max_length 128 \
  --per_device_train_batch_size 128 \
  --gradient_accumulation_steps 1 \
  --learning_rate 5e-6 \
  --warmup_steps 0 \
  --max_steps 1000 \
  --save_steps 3000 \
  --logging_steps 10 \
  --skip_lm_head \
  --ema_decay 0 \
  --kd_cache_dir {CACHE_DIR} \
  --kd_cache_shuffle_files \
  --distill_temperature 2.0 \
  --distill_weight 1.0 \
  --ov-freeze \
  --hard-top1-weight 0.02 \
  --hard-full-top1-weight 0.01



In [None]:
# Stage 3: resume KD-QAT with attention unfrozen and relaxed hard-top1 / hard-full-top1 weights

CACHE_DIR = "caches/alpaca_chat_think_both_L128_K32_R256"
INIT_DIR_CACHE = "runs/qwen3_kdqat_cache_q2_2"
RUN_DIR_CACHE = "runs/qwen3_kdqat_cache_q2_3"

!python scripts/train_qat.py \
  --model_name_or_path {MODEL_NAME} \
  --init_model_state {INIT_DIR_CACHE}/qat_state_dict.pt \
  --output_dir {RUN_DIR_CACHE} \
  --device {DEVICE} \
  --amp_dtype {AMP_DTYPE} \
  --param_dtype {PARAM_DTYPE} \
  -q {QUANT_BITS} \
  --max_length 128 \
  --per_device_train_batch_size 128 \
  --gradient_accumulation_steps 1 \
  --learning_rate 5e-6 \
  --warmup_steps 0 \
  --max_steps 1000 \
  --save_steps 3000 \
  --logging_steps 10 \
  --skip_lm_head \
  --ema_decay 0 \
  --kd_cache_dir {CACHE_DIR} \
  --kd_cache_shuffle_files \
  --distill_temperature 2.0 \
  --distill_weight 1.0 \
  --hard-top1-weight 0.00 \
  --hard-full-top1-weight 0.0005



In [None]:
# Stage 4: resume KD-QAT with attention unfrozen and both hard top-1 terms disabled
#   --hard-full-top1-weight 0.0000
#   --learning_rate 2e-6
CACHE_DIR = "caches/alpaca_chat_think_both_L128_K32_R256"
INIT_DIR_CACHE = "runs/qwen3_kdqat_cache_q2_3"
RUN_DIR_CACHE = "runs/qwen3_kdqat_cache_q2_4"

!python scripts/train_qat.py \
  --model_name_or_path {MODEL_NAME} \
  --init_model_state {INIT_DIR_CACHE}/qat_state_dict.pt \
  --output_dir {RUN_DIR_CACHE} \
  --device {DEVICE} \
  --amp_dtype {AMP_DTYPE} \
  --param_dtype {PARAM_DTYPE} \
  -q {QUANT_BITS} \
  --max_length 128 \
  --per_device_train_batch_size 160 \
  --gradient_accumulation_steps 1 \
  --learning_rate 2e-6 \
  --warmup_steps 0 \
  --max_steps 500 \
  --save_steps 3000 \
  --logging_steps 5 \
  --skip_lm_head \
  --ema_decay 0 \
  --kd_cache_dir {CACHE_DIR} \
  --kd_cache_shuffle_files \
  --distill_temperature 2.0 \
  --distill_weight 1.0 \
  --hard-top1-weight 0.00 \
  --hard-full-top1-weight 0.0000



## 6) Stage B: LoRA recovery

Two options:
- **SFT LoRA** (Alpaca-style instruction tuning)
- **Cached KD-LoRA** (preserve teacher distribution; no new “skills”)
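For readers new to LoRA, the sketch below shows the standard low-rank update on top of a fixed base weight: `W_eff = W + (alpha / r) * B @ A`. It uses plain nested lists so it runs without any dependencies. Whether this repo applies exactly this scaling convention is an assumption, and `lora_effective_weight` is a hypothetical helper name.

```python
def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def lora_effective_weight(base, lora_a, lora_b, alpha, r):
    """Return base + (alpha / r) * (lora_b @ lora_a).

    base:   out x in  (the fixed, fake-quantized weight)
    lora_a: r x in    (trainable)
    lora_b: out x r   (trainable)
    """
    scale = alpha / r
    delta = matmul(lora_b, lora_a)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(base, delta)
    ]


if __name__ == "__main__":
    base = [[1.0, 0.0], [0.0, 1.0]]
    lora_a = [[1.0, 2.0]]    # rank 1
    lora_b = [[0.5], [0.0]]
    print(lora_effective_weight(base, lora_a, lora_b, alpha=1, r=1))
```

Under this convention, setting `--lora_alpha` equal to `--lora_r`, as the cell below does, gives a scale of exactly 1.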


In [None]:
CACHE_DIR = "caches/alpaca_chat_think_both_L128_K32_R256"
RUN_DIR_CACHE = "runs/qwen3_kdqat_cache_q2_2"
LORA_DIM = 64
LORA_RUN_KD = f"runs/qwen3_lora_recovery_cached_r{LORA_DIM}" # Define LORA_RUN_KD as well

#   --max_steps 438 \
!python scripts/train_lora_recovery.py \
  --model_name_or_path {MODEL_NAME} \
  --qat_checkpoint {RUN_DIR_CACHE}/final_state_dict.pt \
  --output_dir {LORA_RUN_KD} \
  --device {DEVICE} \
  --amp_dtype {AMP_DTYPE} \
  --param_dtype {PARAM_DTYPE} \
  -q {QUANT_BITS} \
  --per_device_train_batch_size 128 \
  --gradient_accumulation_steps 1 \
  --learning_rate 1e-5 \
  --warmup_steps 0 \
  --max_steps 1500 \
  --save_steps 3000 \
  --logging_steps 2 \
  --skip_lm_head \
  --lora_r {LORA_DIM} \
  --lora_alpha {LORA_DIM} \
  --lora_dropout 0.0 \
  --kd_cache_dir {CACHE_DIR} \
  --kd_cache_shuffle_files \
  --distill_temperature 2.0 \
  --distill_weight 1.0 \
  --hard-top1-weight 0.02 \
  --hard-full-top1-weight 0.01


[device] cuda | amp_dtype=torch.bfloat16 | param_dtype=torch.bfloat16

## 7) Plot loss

In Colab, use `--no_show` + `--save` then display the PNG.
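If you prefer to inspect the numbers directly, a small stdlib-only sketch for reading and smoothing a loss log is shown here. The column names `step` and `loss` are assumptions about the layout of `loss.csv`; adjust them to match the actual header. The helpers `read_losses` and `moving_average` are hypothetical and not part of `scripts/plot_loss.py`.

```python
import csv
import io


def read_losses(csv_text, step_col="step", loss_col="loss"):
    """Parse CSV text into a list of (step, loss) pairs."""
    reader = csv.DictReader(io.StringIO(csv_text))
    return [(int(row[step_col]), float(row[loss_col])) for row in reader]


def moving_average(values, window):
    """Trailing moving average; early points average over what is available."""
    smoothed = []
    running = 0.0
    for i, value in enumerate(values):
        running += value
        if i >= window:
            running -= values[i - window]  # drop the value that left the window
        smoothed.append(running / min(i + 1, window))
    return smoothed


if __name__ == "__main__":
    # Made-up numbers, for illustration only.
    sample = "step,loss\n10,4.0\n20,2.0\n30,6.0\n40,8.0\n"
    rows = read_losses(sample)
    print(moving_average([loss for _, loss in rows], window=2))
```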

In [None]:
# RUN_DIR must point at the run directory whose loss.csv you want to plot.
!python scripts/plot_loss.py --run_dir {RUN_DIR} --source csv --no_show --save {RUN_DIR}/loss.png
from PIL import Image
display(Image.open(f"{RUN_DIR}/loss.png"))


## 8) Inference sanity checks

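Use greedy decoding (`--do_sample false`) and keep outputs short (for example `--max_new_tokens 16`). The cells below use longer limits, and the last one samples (`--do_sample true`) from the stage-1 checkpoint without LoRA.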

In [None]:
RUN_DIR = "runs/qwen3_kdqat_cache_q2_2"
!python scripts/run_inference.py \
  --model_name_or_path {MODEL_NAME} \
  --qat_checkpoint {RUN_DIR}"/qat_state_dict.pt" \
  --lora_checkpoint "runs/qwen3_lora_recovery_cached_r64/lora_only_state_dict.pt" \
  --device {DEVICE} \
  --dtype {DTYPE} \
  -q {QUANT_BITS} \
  --skip_lm_head \
  --lora_r {LORA_DIM} --lora_alpha {LORA_DIM} --lora_dropout 0.0 \
  --prompt "What is capital of France?" \
  --do_sample false \
  --enable_thinking true \
  --max_new_tokens 40



In [None]:
RUN_DIR = "runs/qwen3_kdqat_cache_q2_2"
!python scripts/run_inference.py \
  --model_name_or_path {MODEL_NAME} \
  --qat_checkpoint {RUN_DIR}"/qat_state_dict.pt" \
  --lora_checkpoint "runs/qwen3_lora_recovery_cached_r64/lora_only_state_dict.pt" \
  --device {DEVICE} \
  --dtype {DTYPE} \
  -q {QUANT_BITS} \
  --skip_lm_head \
  --lora_r {LORA_DIM} --lora_alpha {LORA_DIM} --lora_dropout 0.0 \
  --prompt "What is Apple Neural Engine?" \
  --do_sample false \
  --enable_thinking true \
  --max_new_tokens 100



In [None]:
RUN_DIR = "runs/qwen3_kdqat_cache_q2_2"
!python scripts/run_inference.py \
  --model_name_or_path {MODEL_NAME} \
  --qat_checkpoint {RUN_DIR}"/qat_state_dict.pt" \
  --lora_checkpoint "runs/qwen3_lora_recovery_cached_r64/lora_only_state_dict.pt" \
  --device {DEVICE} \
  --dtype {DTYPE} \
  -q {QUANT_BITS} \
  --skip_lm_head \
  --lora_r {LORA_DIM} --lora_alpha {LORA_DIM} --lora_dropout 0.0 \
  --prompt "2+2=" \
  --do_sample false \
  --enable_thinking true \
  --max_new_tokens 90



In [None]:
LORA_DIM = 64
RUN_DIR = "runs/qwen3_kdqat_cache_q2"
!python scripts/run_inference.py \
  --model_name_or_path {MODEL_NAME} \
  --qat_checkpoint {RUN_DIR}"/qat_state_dict.pt" \
  --device {DEVICE} \
  --dtype {DTYPE} \
  -q {QUANT_BITS} \
  --skip_lm_head \
  --lora_r {LORA_DIM} --lora_alpha {LORA_DIM} --lora_dropout 0.0 \
  --prompt "What is capital of France?" \
  --do_sample true \
  --max_new_tokens 64



## 9) Optional: snap weights to the exact grid

This produces a float checkpoint with weights snapped to the N-bit codebook (not bitpacked).
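One way to sanity-check a snapped checkpoint is to count distinct values: a row that sits on an N-bit codebook can contain at most `2**N` different numbers. The sketch below does this for a plain list of floats. Checking per row (rather than per tensor or per group) is an assumption about the quantization granularity, and `is_on_grid` is a hypothetical helper, not part of `scripts/hard_quantize_checkpoint.py`.

```python
def is_on_grid(row, bits, tol=1e-6):
    """Return True if `row` has at most 2**bits distinct values (within tol)."""
    distinct = []
    for value in row:
        # Treat values closer than tol as the same grid point.
        if not any(abs(value - seen) <= tol for seen in distinct):
            distinct.append(value)
    return len(distinct) <= 2 ** bits


if __name__ == "__main__":
    snapped_row = [-1.0, -1 / 3, 1 / 3, 1.0, 1.0, -1.0]
    print(is_on_grid(snapped_row, bits=2))
```

Since the checkpoint is not bitpacked, its file size stays that of a float checkpoint; only the set of values each weight can take is reduced.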

In [None]:
RUN_DIR = "runs/qwen3_kdqat_cache_q2_2"
!python scripts/hard_quantize_checkpoint.py \
  --model_name_or_path {MODEL_NAME} \
  --qat_checkpoint {RUN_DIR}/checkpoint_last.pt \
  --output_path {RUN_DIR}/hard_quant_full_state_dict.pt \
  -q {QUANT_BITS} \
  --skip_lm_head


In [None]:
%cd qwen3_apple_style_2bit_qat_lora
%ls -l runs


[Errno 2] No such file or directory: 'qwen3_apple_style_2bit_qat_lora'
/content/qwen3_apple_style_2bit_qat_lora
total 8
drwxr-xr-x 2 root root 4096 Dec 17 21:49 qwen3_kdqat_cache_q2/
drwxr-xr-x 2 root root 4096 Dec 17 22:01 qwen3_lora_recovery_cached/
