# Anemll-Style Layer-by-Layer QAT

This notebook implements layer-by-layer quantization-aware training (QAT) using `AnemllQATLinear` with:
- Groupwise LUT quantization
- Low-rank scale factors (A @ B)
- KD cache for distillation
- **Hard label loss** for improved convergence

## Pipeline:
1. Load model and replace linears with AnemllQATLinear
2. Layer-by-layer scale optimization (weights frozen)
3. Layer-by-layer weight training (with hard label loss)
4. End-to-end refinement
5. (Optional) LoRA recovery

## Distillation Options:
- `temperature`: KL divergence temperature (default: 2.0)
- `hard_top1_weight`: Hard label top-1 loss weight (recommended: 0.1 for weights, 0.0 for scales)
- `hard_full_weight`: Hard label full vocab loss weight (optional)
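The sketch below shows one way these three knobs could combine into a single loss. It assumes that the KD cache stores the teacher's top-K logits and token ids for each position (the cache name `..._K32_...` presumably means K=32), and that both hard-label terms use the teacher's top-1 token as the target. `kd_loss_sketch` is hypothetical, and the exact formulation inside `qat_lora` may differ.

```python
import torch
import torch.nn.functional as F

def kd_loss_sketch(student_logits, teacher_topk_logits, teacher_topk_idx,
                   temperature=2.0, hard_top1_weight=0.0, hard_full_weight=0.0):
    """Hypothetical KD loss (not the qat_lora implementation).

    student_logits:      [N, V] full-vocab student logits
    teacher_topk_logits: [N, K] teacher logits for its top-K tokens
    teacher_topk_idx:    [N, K] vocab ids, sorted so column 0 is the teacher's top-1
    """
    T = temperature
    # Student logits restricted to the teacher's K candidate tokens.
    s_topk = student_logits.gather(-1, teacher_topk_idx)
    # Soft-target KL(teacher || student) over the top-K support, with the usual T^2 scaling.
    loss = F.kl_div(
        F.log_softmax(s_topk / T, dim=-1),
        F.log_softmax(teacher_topk_logits / T, dim=-1),
        reduction="batchmean",
        log_target=True,
    ) * (T * T)
    hard_label = teacher_topk_idx[:, 0]  # teacher's top-1 token id
    if hard_top1_weight > 0:
        # Cross-entropy over the K candidates only; the target is position 0.
        loss = loss + hard_top1_weight * F.cross_entropy(s_topk, torch.zeros_like(hard_label))
    if hard_full_weight > 0:
        # Cross-entropy over the full vocabulary.
        loss = loss + hard_full_weight * F.cross_entropy(student_logits, hard_label)
    return loss
```

Because of the T² factor, the soft term stays on a comparable scale as the temperature changes. The hard terms are added unscaled, which is one reason their weights are kept small.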

In [1]:
# ============================================================
# GOOGLE DRIVE PATHS (STANDARD)
# ============================================================

# Checkpoints/runs go here
GD_RUNS = '/content/drive/MyDrive/qwen3_runs'

# KD caches go here
GD_CACHES = '/content/drive/MyDrive/qwen3_caches'

# Local directories (on Colab VM)
LOCAL_RUNS = 'runs'
LOCAL_CACHES = 'caches'

In [2]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### GitHub

In [3]:
# Clone the repo, or pull if it is already cloned
!git clone https://github.com/anemll/qwen3_apple_style_2bit_qat_lora.git || (cd qwen3_apple_style_2bit_qat_lora && git pull)
%cd qwen3_apple_style_2bit_qat_lora
# Pick up upstream updates, then discard any local modifications
!git fetch
!git pull
!git reset --hard HEAD
import sys
# Drop cached qat_lora modules so the import below loads the freshly pulled code
for k in list(sys.modules):
    if k.startswith('qat_lora'):
        sys.modules.pop(k)

from qat_lora import *

Cloning into 'qwen3_apple_style_2bit_qat_lora'...
remote: Enumerating objects: 420, done.
remote: Counting objects: 100% (110/110), done.
remote: Compressing objects: 100% (81/81), done.
remote: Total 420 (delta 75), reused 58 (delta 29), pack-reused 310 (from 1)
Receiving objects: 100% (420/420), 544.19 KiB | 10.88 MiB/s, done.
Resolving deltas: 100% (266/266), done.
/content/qwen3_apple_style_2bit_qat_lora
Already up to date.
HEAD is now at 07af0da Added E2E QAT support , CLI , weight snapping. Updated notebook Add Anemll QAT Command Line Reference and new training scripts - Introduced a comprehensive command line reference for Anemll-style Quantization-Aware Training (QAT), detailing usage, scripts, and parameters for both layer-by-layer and end-to-end training approaches. - Added new scripts for layer-by-layer training (`train_anemll_lbl.py`), end-to-end training (`train_anemll_qat.py`), and inference (`run_anemll_inference.py`), enhancing usability and flexibility in m

In [4]:
# Install dependencies
!pip install -q transformers accelerate safetensors

In [5]:
# ============================================================
# LOAD KD CACHE FROM GOOGLE DRIVE
# ============================================================

CACHE_NAME = 'alpaca_chat_think_both_L128_K32_R256'
CACHE_TGZ = f'{CACHE_NAME}.tgz'

!mkdir -p {LOCAL_CACHES}

# Check if cache exists locally
import os
cache_local_path = f'{LOCAL_CACHES}/{CACHE_NAME}'
if not os.path.exists(cache_local_path):
    print(f'Extracting {CACHE_TGZ} from Google Drive...')
    !tar -xzf {GD_CACHES}/{CACHE_TGZ} -C {LOCAL_CACHES}/
else:
    print(f'Cache already exists at {cache_local_path}')

!ls -la {cache_local_path}/ | head -10

Extracting alpaca_chat_think_both_L128_K32_R256.tgz from Google Drive...
total 4298960
drwx------ 2 root root      4096 Dec 18 00:00 .
drwxr-xr-x 3 root root      4096 Dec 26 19:09 ..
-rw------- 1 root root       421 Dec 18 00:15 meta.json
-rw------- 1 root root 112692165 Dec 18 00:15 shard_00000.pt
-rw------- 1 root root 112692165 Dec 18 00:15 shard_00001.pt
-rw------- 1 root root 112692165 Dec 18 00:15 shard_00002.pt
-rw------- 1 root root 112692165 Dec 18 00:15 shard_00003.pt
-rw------- 1 root root 112692165 Dec 18 00:16 shard_00004.pt
-rw------- 1 root root 112692165 Dec 18 00:16 shard_00005.pt


In [6]:
# ============================================================
# CONFIGURATION
# ============================================================

import torch

# Model
MODEL_ID = 'Qwen/Qwen3-0.6B'

# Quantization config (4-bit with groupwise LUT)
LUT_BITS = 4
LUT_SIZE = 2**LUT_BITS
GROUP_SIZE = 32      # Group size for scales
SCALE_RANK = 4       # Low-rank for A @ B scales

# Attention quantization (same LUT bits and group size; higher scale rank)
ATTN_LUT_BITS = 4
ATTN_LUT_SIZE = 2**ATTN_LUT_BITS
ATTN_GROUP_SIZE = 32
ATTN_SCALE_RANK = 8

# Training
BATCH_SIZE = 4
GRAD_ACCUM = 4

# Larger batches, no accumulation on GPU
if torch.cuda.is_available():
    BATCH_SIZE = 64
    GRAD_ACCUM = 1

LR = 2e-5
EPOCHS_PER_LAYER = 1

# KD / Distillation params
DISTILL_TEMP = 2.0
HARD_TOP1_WEIGHT = 0.1    # Hard label top-1 loss (helps convergence)
HARD_FULL_WEIGHT = 0.0005    # Hard label full vocab loss (optional)

# Device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DTYPE = torch.bfloat16


QUAL = f'q{LUT_BITS}_a{ATTN_LUT_BITS}'

print(f'Quality: {QUAL}')

print(f'Device: {DEVICE}, dtype: {DTYPE}')
print(f'Quant config: lut={LUT_SIZE}, group={GROUP_SIZE}, rank={SCALE_RANK}')
print(f'Distillation: temp={DISTILL_TEMP}, hard_top1={HARD_TOP1_WEIGHT}, hard_full={HARD_FULL_WEIGHT}')

Quality: q4_a4
Device: cuda, dtype: torch.bfloat16
Quant config: lut=16, group=32, rank=4
Distillation: temp=2.0, hard_top1=0.1, hard_full=0.0005


In [7]:
# ============================================================
# VERIFY LOCAL CACHE (RE-EXTRACT IF MISSING)
# ============================================================

import os
from pathlib import Path

# Verify drive is mounted and cache exists
if not os.path.exists('/content/drive/MyDrive'):
    print('Google Drive not mounted! Mounting now...')
    from google.colab import drive
    drive.mount('/content/drive')

if not os.path.exists(cache_local_path):
    print(f'Cache not found at {cache_local_path}')
    print(f'Extracting from Google Drive...')
    os.makedirs(LOCAL_CACHES, exist_ok=True)
    !tar -xzf {GD_CACHES}/{CACHE_TGZ} -C {LOCAL_CACHES}/

# Verify cache exists now
assert os.path.exists(cache_local_path), f'Cache still not found at {cache_local_path}'
cache_files = list(Path(cache_local_path).glob('*.pt'))
print(f'Cache ready: {len(cache_files)} files in {cache_local_path}')

Cache ready: 40 files in caches/alpaca_chat_think_both_L128_K32_R256


In [8]:
# ============================================================
# LOAD MODEL
# ============================================================

from transformers import AutoModelForCausalLM, AutoTokenizer

print(f'Loading {MODEL_ID}...')
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype=DTYPE,  # `torch_dtype` is deprecated in recent transformers
    trust_remote_code=True,
)
model.to(DEVICE)
model.eval()
print(f'Loaded. Parameters: {sum(p.numel() for p in model.parameters()):,}')

Loading Qwen/Qwen3-0.6B...


`torch_dtype` is deprecated! Use `dtype` instead!


Loaded. Parameters: 596,049,920


In [9]:
# ============================================================
# REPLACE LINEARS WITH AnemllQATLinear
# ============================================================

import sys
sys.path.insert(0, '.')

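# Force reimport to get the latest code.
# Reload submodules first, then the package, so the names re-exported by
# qat_lora (e.g. AnemllQATLinear) refer to the freshly reloaded classes.
# Reloading the package first can leave those names pointing at the old
# class objects, and isinstance() checks against them then fail.
import importlib
import qat_lora
import qat_lora.ane_qat_linear as ane_module
import qat_lora.layer_qat as layer_module
importlib.reload(ane_module)
importlib.reload(layer_module)
importlib.reload(qat_lora)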

from qat_lora import AnemllQuantConfig, replace_linear_with_anemll

# Debug: Check what modules exist in the model
print("Checking model structure...")
import torch.nn as nn
linear_count = 0
for name, m in model.named_modules():
    if isinstance(m, nn.Linear):
        linear_count += 1
        if linear_count <= 5:
            print(f"  Found Linear: {name}")
print(f"Total Linear modules: {linear_count}")

# Create configs
mlp_config = AnemllQuantConfig(
    lut_size=LUT_SIZE,
    group_size=GROUP_SIZE,
    scale_rank=SCALE_RANK,
    learnable_lut=False,
)

attn_config = AnemllQuantConfig(
    lut_size=ATTN_LUT_SIZE,
    group_size=ATTN_GROUP_SIZE,
    scale_rank=ATTN_SCALE_RANK,
    learnable_lut=False,
)

print('\nReplacing linear layers...')
count = replace_linear_with_anemll(
    model,
    mlp_config=mlp_config,
    attn_config=attn_config,
    quantize_attn=True,
    quantize_lm_head=False,
)

# Verify replacement worked
from qat_lora import AnemllQATLinear
qat_count = sum(1 for _, m in model.named_modules() if isinstance(m, AnemllQATLinear))
print(f"\nVerification: {qat_count} AnemllQATLinear modules in model")

Checking model structure...
  Found Linear: model.layers.0.self_attn.q_proj
  Found Linear: model.layers.0.self_attn.k_proj
  Found Linear: model.layers.0.self_attn.v_proj
  Found Linear: model.layers.0.self_attn.o_proj
  Found Linear: model.layers.0.mlp.gate_proj
Total Linear modules: 197

Replacing linear layers...
  [replaced] model.layers.0.self_attn.q_proj
  [replaced] model.layers.0.self_attn.k_proj
  [replaced] model.layers.0.self_attn.v_proj
  [replaced] model.layers.0.self_attn.o_proj
  [replaced] model.layers.0.mlp.gate_proj
  [replaced] model.layers.0.mlp.up_proj
  [replaced] model.layers.0.mlp.down_proj
  [replaced] model.layers.1.self_attn.q_proj
  [replaced] model.layers.1.self_attn.k_proj
  [replaced] model.layers.1.self_attn.v_proj
  [replaced] model.layers.1.self_attn.o_proj
  [replaced] model.layers.1.mlp.gate_proj
  [replaced] model.layers.1.mlp.up_proj
  [replaced] model.layers.1.mlp.down_proj
  [replaced] model.layers.2.self_attn.q_proj
  [replaced] model.layers.2.

In [10]:
# ============================================================
# IMPORT LAYER-BY-LAYER QAT UTILITIES & VERIFY GRADIENTS
# ============================================================

from qat_lora import (
    evaluate_kd_loss,
    train_all_layers,
    AnemllQATLinear,
)

print('Layer QAT utilities imported from qat_lora')

# Verify gradient flow works
print('\nVerifying gradient flow...')
layer0 = model.model.layers[0]
test_module = None
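for name, m in layer0.named_modules():
    # Match by class name: reloading qat_lora creates new class objects, so
    # isinstance() against a class imported before the reload can fail.
    if type(m).__name__ == 'AnemllQATLinear':
        test_module = m
        break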

if test_module is None:
    print("ERROR: No AnemllQATLinear modules found! Replacement failed.")
else:
    # Test gradient flow
    test_module.weight.requires_grad = True
    x = torch.randn(1, 10, test_module.in_features, device=DEVICE, dtype=DTYPE)
    y = test_module(x)
    loss = y.sum()
    try:
        loss.backward()
        if test_module.weight.grad is not None:
            print(f"  Gradient OK: weight.grad.shape = {test_module.weight.grad.shape}")
            test_module.weight.grad = None  # Clear for actual training
        else:
            print("  ERROR: weight.grad is None after backward!")
    except Exception as e:
        print(f"  ERROR during backward: {e}")

# Compute initial KD loss
print('\nComputing initial KD loss...')
initial_loss = evaluate_kd_loss(model, cache_local_path, DEVICE, num_samples=40, temperature=DISTILL_TEMP)
print(f'Initial KD Loss: {initial_loss:.4f}')

Layer QAT utilities imported from qat_lora

Verifying gradient flow...
ERROR: No AnemllQATLinear modules found! Replacement failed.

Computing initial KD loss...
Initial KD Loss: 1.0885

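The `ERROR: No AnemllQATLinear modules found!` line above appears even though the previous cell logged every replacement. One likely cause is the reload order in the previous cell, although the notebook does not confirm this. `importlib.reload` re-executes a module body and creates new class objects. A class bound before the reload is therefore not the class of instances created afterwards, so `isinstance` returns `False`. The self-contained sketch below reproduces the effect with a throwaway module named `toy_qat_mod` (hypothetical). The class-name comparison used later in this notebook (`type(module).__name__ == 'AnemllQATLinear'`) is not affected.

```python
import importlib
import pathlib
import sys
import tempfile

def reload_breaks_isinstance():
    """Return (isinstance_ok, same_class_name) for an object built after a reload."""
    tmp = tempfile.mkdtemp()
    pathlib.Path(tmp, "toy_qat_mod.py").write_text("class QATLinear:\n    pass\n")
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    try:
        import toy_qat_mod
        OldCls = toy_qat_mod.QATLinear   # name bound before the reload
        importlib.reload(toy_qat_mod)    # re-executes the module body
        obj = toy_qat_mod.QATLinear()    # instance of the *new* class object
        return isinstance(obj, OldCls), type(obj).__name__ == OldCls.__name__
    finally:
        sys.path.remove(tmp)
        sys.modules.pop("toy_qat_mod", None)

print(reload_breaks_isinstance())
```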

# **SCALE OPTIMIZATION** (Weights Frozen)

Before training the weights, optimize the low-rank per-weight scales (`scale_A @ scale_B`) layer by layer to reduce quantization error.

- Weights are **frozen**
- Only `scale_A` and `scale_B` are trained
- Far fewer parameters, so a higher learning rate can be used
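The sketch below puts "far fewer parameters" in numbers by counting trainable parameters for one decoder layer. It assumes Qwen3-0.6B projection shapes (hidden size 1024, 16 query heads and 8 KV heads of dim 128, MLP intermediate size 3072, 28 layers) and the ranks from the configuration cell (8 for attention, 4 for MLP). With those assumptions it reproduces the per-layer counts in the training logs below: 131,072 for scales and 15,728,640 for weights. It also reproduces the 420.0 MB `indices.pt` reported at save time, assuming one uint8 byte per index and 1 MB = 1024² bytes.

```python
# Assumed Qwen3-0.6B shapes, as (out_features, in_features) per projection.
HIDDEN, INTER, HEAD_DIM, N_Q, N_KV, N_LAYERS = 1024, 3072, 128, 16, 8, 28

ATTN_SHAPES = {
    "q_proj": (N_Q * HEAD_DIM, HIDDEN),
    "k_proj": (N_KV * HEAD_DIM, HIDDEN),
    "v_proj": (N_KV * HEAD_DIM, HIDDEN),
    "o_proj": (HIDDEN, N_Q * HEAD_DIM),
}
MLP_SHAPES = {
    "gate_proj": (INTER, HIDDEN),
    "up_proj": (INTER, HIDDEN),
    "down_proj": (HIDDEN, INTER),
}

def lowrank_scale_params(out_f, in_f, rank):
    # scale_A: [out, rank] plus scale_B: [rank, in]
    return out_f * rank + rank * in_f

def layer_counts(attn_rank=8, mlp_rank=4):
    """Return (weight params, scale params) for one decoder layer."""
    weights = sum(o * i for o, i in {**ATTN_SHAPES, **MLP_SHAPES}.values())
    scales = sum(lowrank_scale_params(o, i, attn_rank) for o, i in ATTN_SHAPES.values())
    scales += sum(lowrank_scale_params(o, i, mlp_rank) for o, i in MLP_SHAPES.values())
    return weights, scales

w, s = layer_counts()
print(f"per layer:  weights={w:,}  scales={s:,}  ratio={w / s:.0f}x")
print(f"all layers: weights={w * N_LAYERS:,}  scales={s * N_LAYERS:,}")
print(f"uint8 indices: {w * N_LAYERS / 2**20:.1f} MiB")
```

The scales are 120x smaller than the weights per layer. This is why `SCALE_LR` can be set roughly 50x higher than `LR`.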

In [11]:
# ============================================================
# LAYER-BY-LAYER SCALE OPTIMIZATION
# ============================================================
# Freeze weights, only train scale_A and scale_B tensors
# Higher LR since fewer parameters
# Note: Hard label loss not needed for scale optimization

SCALE_LR = 1e-3  # Higher LR for scales (fewer params)
SCALE_EPOCHS = 2  # More epochs since scales have less capacity


print('Starting scale-only layer-by-layer optimization...')
print(f'LR: {SCALE_LR}, Epochs per layer: {SCALE_EPOCHS}')

# Get loss before scale optimization
pre_scale_loss = evaluate_kd_loss(model, cache_local_path, DEVICE, num_samples=40)
print(f'KD Loss before scale optimization: {pre_scale_loss:.4f}')

# Train scales layer-by-layer (no hard label needed for scales)
scale_losses = train_all_layers(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    batch_size=BATCH_SIZE,
    lr=SCALE_LR,
    epochs_per_layer=SCALE_EPOCHS,
    grad_accum=GRAD_ACCUM,
    temperature=DISTILL_TEMP,
    train_weights=False,  # Freeze weights
    train_scales=True,    # Train scales only
    local_weight=0.5,
    global_weight=0.5,
    hard_top1_weight=0.0,  # Not needed for scale optimization
    hard_full_weight=0.0,
    verbose=True,
    steps_per_layer=100,
)

# Evaluate after scale optimization
post_scale_loss = evaluate_kd_loss(model, cache_local_path, DEVICE, num_samples=40)
print(f'\n=== Scale Optimization Results ===')
print(f'Before: {pre_scale_loss:.4f}')
print(f'After:  {post_scale_loss:.4f}')
print(f'Improvement: {pre_scale_loss - post_scale_loss:.4f}')

Starting scale-only layer-by-layer optimization...
LR: 0.001, Epochs per layer: 2
KD Loss before scale optimization: 1.0885
Training 28 layers (mode=scales only)...
Cache: caches/alpaca_chat_think_both_L128_K32_R256
Batch size: 64, Grad accum: 1
LR: 0.001, Steps per layer: 100

[Initial Global KD Loss]: 1.0885

=== Layer 0 === (131,072 trainable params, mode=scales only)
  [Global KD Loss BEFORE]: 1.0345
  step 10: local=0.0439 global=0.8089 (3.9s)
  step 20: local=0.0710 global=0.7905 (7.2s)
  step 30: local=0.0724 global=0.7238 (10.4s)
  step 40: local=0.0671 global=0.6626 (13.7s)
  step 50: local=0.0641 global=0.6703 (17.0s)
  step 60: local=0.0619 global=0.6684 (20.3s)
  step 70: local=0.0622 global=0.6874 (23.5s)
  step 80: local=0.0609 global=0.6786 (26.8s)
  step 90: local=0.0638 global=0.6774 (30.1s)
  step 100: local=0.0646 global=0.6453 (33.3s)
  ---
  [Local Loss]:   0.0283 -> 0.0646 (Δ=-0.0364)
  [Global Loss]:  1.0063 -> 0.6453 (Δ=0.3609)
  [Eval KD]:      1.0345 -> 0.6752

# **LAYER-BY-LAYER WEIGHT TRAINING** (Scales Frozen)

In [12]:
# ============================================================
# LAYER-BY-LAYER WEIGHT TRAINING
# ============================================================
# Train weights with hard label loss for better convergence

print('Starting layer-by-layer weight training...')
print(f'LR: {LR}, Hard label: top1={HARD_TOP1_WEIGHT}, full={HARD_FULL_WEIGHT}')

# Train all layers using the imported function
layer_losses = train_all_layers(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    batch_size=BATCH_SIZE,
    lr=LR,
    epochs_per_layer=EPOCHS_PER_LAYER,
    grad_accum=GRAD_ACCUM,
    temperature=DISTILL_TEMP,
    train_weights=True,   # Train weights
    train_scales=False,   # Keep scales frozen for now
    local_weight=0.5,
    global_weight=0.5,
    hard_top1_weight=HARD_TOP1_WEIGHT,  # Helps convergence
    hard_full_weight=HARD_FULL_WEIGHT,
    verbose=True,
    steps_per_layer=100,
)

Starting layer-by-layer weight training...
LR: 2e-05, Hard label: top1=0.1, full=0.0005
Training 28 layers (mode=weights)...
Cache: caches/alpaca_chat_think_both_L128_K32_R256
Batch size: 64, Grad accum: 1
LR: 2e-05, Steps per layer: 100
Hard label: top1=0.1, full=0.0005

[Initial Global KD Loss]: 0.3553

=== Layer 0 === (15,728,640 trainable params, mode=weights)
  Hard label: top1=0.1, full=0.0005
  [Global KD Loss BEFORE]: 0.3761
  step 10: local=0.0611 global=0.4788 (3.7s)
  step 20: local=0.0566 global=0.4685 (7.3s)
  step 30: local=0.0539 global=0.4798 (10.9s)
  step 40: local=0.0555 global=0.4607 (14.5s)
  step 50: local=0.0579 global=0.4550 (18.2s)
  step 60: local=0.0562 global=0.4516 (21.8s)
  step 70: local=0.0518 global=0.4697 (25.4s)
  step 80: local=0.0544 global=0.4401 (29.1s)
  step 90: local=0.0541 global=0.4612 (32.7s)
  step 100: local=0.0510 global=0.4753 (36.4s)
  ---
  [Local Loss]:   0.0640 -> 0.0510 (Δ=0.0130)
  [Global Loss]:  0.4520 -> 0.4753 (Δ=-0.0234)
  [Ev

In [13]:
# ============================================================
# EVALUATE AFTER LAYER-BY-LAYER
# ============================================================

model.eval()
post_layer_loss = evaluate_kd_loss(model, cache_local_path, DEVICE, num_samples=40)
print(f'Initial KD Loss: {initial_loss:.4f}')
print(f'After Layer-by-Layer: {post_layer_loss:.4f}')
print(f'Improvement: {initial_loss - post_layer_loss:.4f}')

Initial KD Loss: 1.0885
After Layer-by-Layer: 0.2810
Improvement: 0.8074


In [14]:
# ============================================================
# SAVE CHECKPOINT
# ============================================================

import os

RUN_NAME = f'anemll_{QUAL}_layer_by_layer_v1'
SAVE_DIR = f'{LOCAL_RUNS}/{RUN_NAME}'

os.makedirs(SAVE_DIR, exist_ok=True)

# Save state dict
torch.save(model.state_dict(), f'{SAVE_DIR}/model_state_dict.pt')

# Save config
import json
config = {
    'model_id': MODEL_ID,
    'lut_size': LUT_SIZE,
    'group_size': GROUP_SIZE,
    'scale_rank': SCALE_RANK,
    'attn_lut_size': ATTN_LUT_SIZE,
    'attn_group_size': ATTN_GROUP_SIZE,
    'attn_scale_rank': ATTN_SCALE_RANK,
    'initial_kd_loss': initial_loss,
    'post_layer_loss': post_layer_loss,
    'layer_losses': layer_losses,
}
with open(f'{SAVE_DIR}/config.json', 'w') as f:
    json.dump(config, f, indent=2)

print(f'Saved to {SAVE_DIR}')

Saved to runs/anemll_q4_a4_layer_by_layer_v1


In [15]:
# ============================================================
# UPLOAD TO GOOGLE DRIVE
# ============================================================

!mkdir -p {GD_RUNS}
!tar -czvf {RUN_NAME}.tgz -C {LOCAL_RUNS} {RUN_NAME}
!cp {RUN_NAME}.tgz {GD_RUNS}/
print(f'Uploaded to {GD_RUNS}/{RUN_NAME}.tgz')

anemll_q4_a4_layer_by_layer_v1/
anemll_q4_a4_layer_by_layer_v1/config.json
anemll_q4_a4_layer_by_layer_v1/model_state_dict.pt
Uploaded to /content/drive/MyDrive/qwen3_runs/anemll_q4_a4_layer_by_layer_v1.tgz


In [16]:
# ============================================================
# END-TO-END KD-QAT: TRAIN SCALES (WEIGHTS FROZEN)
# ============================================================
# Scale training doesn't need hard label loss

from qat_lora import train_e2e

# Train scales (weights frozen) - higher LR since fewer params
e2e_scales_result = train_e2e(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    max_steps=500,
    batch_size=128 if torch.cuda.is_available() else 32,
    lr=1e-3,  # Higher LR for scales
    temperature=DISTILL_TEMP,
    train_weights=False,
    train_scales=True,
    hard_top1_weight=0.0,  # Not needed for scale training
    hard_full_weight=0.0,
    logging_steps=20,
    eval_steps=100,
    verbose=True,
)

=== End-to-End KD-QAT ===
Mode: scales
Trainable params: 3,670,016
Steps: 500, LR: 0.001, Batch: 128

Initial KD Loss: 0.2811
[20/500] loss=0.4970 (17s, ETA 408s)
[40/500] loss=0.3894 (30s, ETA 344s)
[60/500] loss=0.3566 (43s, ETA 314s)
[80/500] loss=0.3363 (56s, ETA 293s)
[100/500] loss=0.3298 (69s, ETA 275s)
  [Eval] KD Loss: 0.3063 (best: 0.2811)
[120/500] loss=0.3201 (86s, ETA 271s)
[140/500] loss=0.3161 (99s, ETA 253s)
[160/500] loss=0.3138 (111s, ETA 236s)
[180/500] loss=0.2946 (124s, ETA 220s)
[200/500] loss=0.2914 (137s, ETA 205s)
  [Eval] KD Loss: 0.2948 (best: 0.2811)
[220/500] loss=0.2922 (154s, ETA 196s)
[240/500] loss=0.2901 (167s, ETA 181s)
[260/500] loss=0.2929 (180s, ETA 166s)
[280/500] loss=0.2944 (193s, ETA 151s)
[300/500] loss=0.2999 (206s, ETA 137s)
  [Eval] KD Loss: 0.2697 (best: 0.2811)
[320/500] loss=0.2909 (223s, ETA 126s)
[340/500] loss=0.2857 (236s, ETA 111s)
[360/500] loss=0.2839 (249s, ETA 97s)
[380/500] loss=0.2839 (262s, ETA 83s)
[400/500] loss=0.2864 (275

In [46]:
# Free Python and GPU memory between training runs
import gc
gc.collect()
torch.cuda.empty_cache()

In [42]:
# ============================================================
# END-TO-END KD-QAT: TRAIN WEIGHTS (SCALES FROZEN)
# ============================================================
# Use hard label loss for stable weight training

from qat_lora import train_e2e, save_checkpoint, load_checkpoint, unfreeze_model_for_training

# Unfreeze for training (clear any cached weights)
unfreeze_model_for_training(model)

print('E2E weight training with hard label loss...')
print(f'Hard label: top1={HARD_TOP1_WEIGHT}, full={HARD_FULL_WEIGHT}')

# Train weights (scales frozen)
e2e_weights_result = train_e2e(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    max_steps=3000,
    batch_size=64 if torch.cuda.is_available() else 32,
    lr=5e-6,
    temperature=DISTILL_TEMP,
    train_weights=True,
    train_scales=False,
    hard_top1_weight=HARD_TOP1_WEIGHT,  # Helps prevent divergence
    hard_full_weight=HARD_FULL_WEIGHT,
    logging_steps=50,
    eval_steps=500,
    verbose=True,
)

E2E weight training with hard label loss...
Hard label: top1=0.1, full=0.0005
=== End-to-End KD-QAT ===
Mode: weights
Trainable params: 440,401,920
Steps: 3000, LR: 5e-06, Batch: 64
Hard label: top1=0.1, full=0.0005

Initial KD Loss: 0.2660
[50/3000] loss=0.3469 (25s, ETA 1469s)
[100/3000] loss=0.3468 (45s, ETA 1317s)
[150/3000] loss=0.3414 (66s, ETA 1252s)
[200/3000] loss=0.3259 (86s, ETA 1210s)
[250/3000] loss=0.3268 (107s, ETA 1177s)
[300/3000] loss=0.3270 (127s, ETA 1147s)
[350/3000] loss=0.3154 (148s, ETA 1119s)
[400/3000] loss=0.3064 (168s, ETA 1094s)
[450/3000] loss=0.3027 (189s, ETA 1070s)
[500/3000] loss=0.3052 (209s, ETA 1047s)
  [Eval] KD Loss: 0.2119 (best: 0.2660)
[550/3000] loss=0.3018 (235s, ETA 1047s)
[600/3000] loss=0.3028 (256s, ETA 1022s)
[650/3000] loss=0.2885 (276s, ETA 998s)
[700/3000] loss=0.2932 (297s, ETA 975s)
[750/3000] loss=0.2870 (317s, ETA 952s)
[800/3000] loss=0.2854 (338s, ETA 930s)
[850/3000] loss=0.2875 (359s, ETA 907s)
[900/3000] loss=0.2861 (379s, ET

In [43]:
# ============================================================
# SAVE FINAL CHECKPOINT
# ============================================================

E2E_RUN_NAME = f'anemll_{QUAL}_e2e_v1'
E2E_SAVE_DIR = f'{LOCAL_RUNS}/{E2E_RUN_NAME}'

# Save with config
config = {
    'model_id': MODEL_ID,
    'lut_size': LUT_SIZE,
    'group_size': GROUP_SIZE,
    'scale_rank': SCALE_RANK,
    'attn_lut_size': ATTN_LUT_SIZE,
    'attn_group_size': ATTN_GROUP_SIZE,
    'attn_scale_rank': ATTN_SCALE_RANK,
    'e2e_weights_result': e2e_weights_result,
    'e2e_scales_result': e2e_scales_result,
}

save_checkpoint(model, E2E_SAVE_DIR, config=config)

# Upload to Google Drive
!tar -czvf {E2E_RUN_NAME}.tgz -C {LOCAL_RUNS} {E2E_RUN_NAME}
!cp {E2E_RUN_NAME}.tgz {GD_RUNS}/
print(f'\nUploaded to {GD_RUNS}/{E2E_RUN_NAME}.tgz')

Saved checkpoint to runs/anemll_q4_a4_e2e_v1/
  - model_state_dict.pt
  - indices.pt (196 layers, 420.0 MB)
  - config.json
anemll_q4_a4_e2e_v1/
anemll_q4_a4_e2e_v1/config.json
anemll_q4_a4_e2e_v1/indices.pt
anemll_q4_a4_e2e_v1/model_state_dict.pt

Uploaded to /content/drive/MyDrive/qwen3_runs/anemll_q4_a4_e2e_v1.tgz


# **INFERENCE OPTIMIZATION**

Before running inference, freeze all layers to precompute quantized weights.
This avoids recomputing `LUT[indices] * (scale_A @ scale_B)` on every forward pass.
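The sketch below shows that per-layer computation in PyTorch. It assumes a 16-entry LUT with values in [-1, 1], per-weight `uint8` indices of shape `[out, in]`, and low-rank scales `scale_A: [out, rank]` and `scale_B: [rank, in]`, which matches the layout listed in the export cell's comments. The uniform LUT and the `dequantize` helper are hypothetical. Caching the result means the gather and the rank-r matmul run once rather than on every forward pass.

```python
import torch

def dequantize(lut, indices, scale_A, scale_B):
    """W[o, i] = lut[indices[o, i]] * (scale_A @ scale_B)[o, i]"""
    normalized = lut[indices.long()]  # gather LUT values; uint8 must be cast to long to index
    scales = scale_A @ scale_B        # low-rank per-weight scales: [out, in]
    return normalized * scales

torch.manual_seed(0)
out_f, in_f, rank, lut_size = 8, 16, 4, 16
lut = torch.linspace(-1.0, 1.0, lut_size)  # hypothetical uniform LUT
indices = torch.randint(0, lut_size, (out_f, in_f), dtype=torch.uint8)
scale_A = torch.rand(out_f, rank)
scale_B = torch.rand(rank, in_f)

W_cached = dequantize(lut, indices, scale_A, scale_B)  # computed once
x = torch.randn(3, in_f)
y = x @ W_cached.t()                                   # reused on every forward
```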

In [47]:
# ============================================================
# FREEZE MODEL FOR FAST INFERENCE
# ============================================================
# Precompute quantized weights once for all layers
# This caches LUT[idx] * scale to avoid recomputation per token

from qat_lora import freeze_model_for_inference, unfreeze_model_for_training

print('Freezing model for inference...')
num_frozen = freeze_model_for_inference(model, verbose=False)
print(f'Frozen {num_frozen} layers')

# To resume training later:
# unfreeze_model_for_training(model)

Freezing model for inference...


KeyError: "attribute '_cached_weight_q' already exists"
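That message is raised by PyTorch's `nn.Module.register_buffer` when the name already exists on the module as a plain attribute rather than as a buffer. One plausible cause, not confirmed here, is that the cache slot was first assigned directly (e.g. `self._cached_weight_q = None`) and was later registered as a buffer. The sketch below reproduces the error and shows a pattern that avoids it: register the slot as a buffer up front, then only assign to it. `ToyQATLinear` and its `freeze`/`unfreeze` methods are hypothetical and are not the `qat_lora` implementation.

```python
import torch
import torch.nn as nn

def reproduce_error() -> str:
    m = nn.Module()
    m._cached_weight_q = None  # plain attribute, not a buffer
    try:
        m.register_buffer("_cached_weight_q", torch.zeros(1))
    except KeyError as e:
        return str(e)
    return ""

class ToyQATLinear(nn.Module):
    """Hypothetical fake-quant linear with a cached (frozen) weight."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        # Register the cache slot once, even though it starts empty.
        # Non-persistent, so it is not written to state_dict().
        self.register_buffer("_cached_weight_q", None, persistent=False)

    def freeze(self) -> None:
        # Assigning to a registered buffer name just updates the buffer.
        self._cached_weight_q = self.weight.detach().clone()

    def unfreeze(self) -> None:
        # Clearing keeps the name registered, so freeze() can run again.
        self._cached_weight_q = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self._cached_weight_q if self._cached_weight_q is not None else self.weight
        return x @ w.t()
```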

In [48]:
import torch

# ============================================================
# TEST INFERENCE
# ============================================================

def run_inference(model, tokenizer, prompt, max_new_tokens=128):
    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': prompt}
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors='pt').to(DEVICE)

    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

    return tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

# List of prompts to test
prompts = [
    'What is the capital of France?',
    'What is Apple Neural Engine?',
    'Explain quantum mechanics',
    'What is speed of light'
]

model.eval() # Set model to evaluation mode once

for prompt in prompts:
    response = run_inference(model, tokenizer, prompt,max_new_tokens=1024)
    print(f'Prompt: {prompt}')
    print(f'Response: {response}')
    print('-' * 50) # Separator for readability


The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Prompt: What is the capital of France?
Response: <think>
Okay, the user is asking for the capital of France. Let me recall... France's capital is Paris. I think that's right. But wait, maybe I should double-check. Paris is indeed the capital, so I'm confident. No need to overcomplicate. Just answer with Paris.
</think>

The capital of France is **Paris**.
--------------------------------------------------
Prompt: What is Apple Neural Engine?
Response: <think>
Okay, the user is asking about Apple Neural Engine. Let me start by recalling what I know about Apple Neural Engine. It's a framework developed by Apple for AI and machine learning. The main purpose is to accelerate training and deployment of AI models. 

First, I should explain the basic concept. Apple Neural Engine is a set of tools and libraries that allow developers to run machine learning models on Apple's hardware. It's designed to be efficient and scalable, which is important for large-scale applications.

I should mention 

## Next Steps

After layer-by-layer training, you can:

1. **End-to-end refinement** - Unfreeze all layers and train together
2. **Train scales (A, B)** - Unfreeze scale_A, scale_B parameters
3. **LoRA recovery** - Add LoRA adapters to recover quality
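For the LoRA recovery step, the sketch below follows the standard LoRA formulation, which is not necessarily what this repo implements. The adapter wraps a frozen base layer and adds a trainable low-rank update `B @ A` scaled by `alpha / r`. Because `B` starts at zero, the wrapped layer initially behaves exactly like the quantized base. `LoRAWrapper` and its defaults are hypothetical.

```python
import math
import torch
import torch.nn as nn

class LoRAWrapper(nn.Module):
    """Hypothetical LoRA adapter around a frozen base layer."""

    def __init__(self, base: nn.Module, in_features: int, out_features: int,
                 r: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # keep the quantized base frozen
        self.lora_A = nn.Parameter(torch.empty(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))  # zero init: no-op at start
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + (x @ self.lora_A.t() @ self.lora_B.t()) * self.scaling

base = nn.Linear(16, 8)
lora = LoRAWrapper(base, 16, 8, r=4)
x = torch.randn(2, 16)
print(torch.allclose(lora(x), base(x)))  # identical until lora_B is trained
```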

# **END-TO-END KD-QAT REFINEMENT**

After layer-by-layer training, refine the model end to end, training all layers jointly against the KD loss.

Two modes:
1. **Train weights** (scales frozen) - Fine-tune weights globally with hard label loss
2. **Train scales** (weights frozen) - Optimize scales for better quantization
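The sketch below shows one way the two modes can be expressed, assuming that trainable tensors are identified by parameter name: `scale_A`/`scale_B` are scales, `weight` is a weight, and everything else stays frozen. `ToyLowRankLinear` and `set_trainable` are hypothetical. `train_e2e(train_weights=..., train_scales=...)` presumably does something equivalent, but its internals are not shown in this notebook.

```python
import torch
import torch.nn as nn

class ToyLowRankLinear(nn.Module):
    """Hypothetical stand-in with a weight and low-rank scales."""

    def __init__(self, in_f: int, out_f: int, rank: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_f, in_f))
        self.scale_A = nn.Parameter(torch.rand(out_f, rank))
        self.scale_B = nn.Parameter(torch.rand(rank, in_f))

def set_trainable(model: nn.Module, train_weights: bool, train_scales: bool) -> int:
    """Toggle requires_grad by parameter name; return the trainable count."""
    n = 0
    for name, p in model.named_parameters():
        leaf = name.rsplit(".", 1)[-1]
        if leaf in ("scale_A", "scale_B"):
            p.requires_grad_(train_scales)
        elif leaf == "weight":
            p.requires_grad_(train_weights)
        else:
            p.requires_grad_(False)
        n += p.numel() if p.requires_grad else 0
    return n

model = nn.Sequential(ToyLowRankLinear(16, 32, 4), ToyLowRankLinear(32, 16, 4))
print(set_trainable(model, train_weights=False, train_scales=True))  # scales mode
print(set_trainable(model, train_weights=True, train_scales=False))  # weights mode
# Pass only the trainable tensors to the optimizer.
opt = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=5e-6)
```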

## Distillation Options

| Parameter | Weight Training | Scale Training |
|-----------|----------------|----------------|
| `temperature` | 2.0 | 2.0 |
| `hard_top1_weight` | 0.1 (recommended) | 0.0 |
| `hard_full_weight` | 0.0 (optional; this notebook uses 0.0005) | 0.0 |

Hard label loss helps prevent divergence during weight training.

# **EXPORT FOR ANEMLL CONVERTER**

Snap weights to quantized values and export for external tools.

Two export modes:
- `store_lut_values=True`: weights = LUT[idx] (normalized to [-1, 1]); scales are stored separately
- `store_lut_values=False`: weights = LUT[idx] * scale (fully dequantized)
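The sketch below reads one layer back from the export. It assumes each entry follows the layout described in the export cell's comments: `indices`, `quantized_weights`, `lut`, and `scales`, where `scales` is either a `{'scale_A', 'scale_B'}` dict or a full `[out, in]` tensor. Rebuilding from the indices and rebuilding from the stored LUT values should give the same dequantized weight that `store_lut_values=False` would write. `dequantize_entry` and the toy entry are hypothetical.

```python
import torch

def full_scales(scales):
    """Expand low-rank scales (A @ B), or pass a full [out, in] tensor through."""
    if isinstance(scales, dict):
        return scales["scale_A"] @ scales["scale_B"]
    return scales

def dequantize_entry(entry, use_indices=True):
    """Rebuild W = LUT[idx] * scale from one exported layer entry."""
    if use_indices:
        normalized = entry["lut"][entry["indices"].long()]
    else:
        normalized = entry["quantized_weights"]  # already LUT[idx], in [-1, 1]
    return normalized * full_scales(entry["scales"])

# Toy entry with the assumed layout.
torch.manual_seed(0)
out_f, in_f, rank = 4, 6, 2
lut = torch.linspace(-1.0, 1.0, 16)
idx = torch.randint(0, 16, (out_f, in_f), dtype=torch.uint8)
entry = {
    "indices": idx,
    "quantized_weights": lut[idx.long()],
    "lut": lut,
    "scales": {"scale_A": torch.rand(out_f, rank), "scale_B": torch.rand(rank, in_f)},
}
W_from_idx = dequantize_entry(entry, use_indices=True)
W_from_qw = dequantize_entry(entry, use_indices=False)
```

Since every LUT value lies in [-1, 1], each dequantized weight is bounded in magnitude by its scale. This gives a quick sanity check on an export.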

In [49]:
# ============================================================
# SNAP WEIGHTS AND EXPORT
# ============================================================
# Snap weights to LUT[idx] values for ANEMLL converter

from qat_lora import snap_all_weights, export_quantized_model, unfreeze_model_for_training

# First unfreeze to clear cached weights
unfreeze_model_for_training(model)

# Export quantized representation BEFORE snapping (keeps original weights)
print('Exporting quantized model representation...')
export_dict = export_quantized_model(model, verbose=True)

# Save export for ANEMLL converter
EXPORT_DIR = f'{LOCAL_RUNS}/{E2E_RUN_NAME}_export'
os.makedirs(EXPORT_DIR, exist_ok=True)
torch.save(export_dict, f'{EXPORT_DIR}/quantized_model.pt')
print(f'\nSaved export to {EXPORT_DIR}/quantized_model.pt')

# Each layer in export_dict contains:
# - indices: [out, in] uint8 LUT indices
# - quantized_weights: [out, in] LUT[idx] values in [-1, 1]
# - scales: {'scale_A': [out, rank], 'scale_B': [rank, in]} or full [out, in]
# - lut: [lut_size] values
# - bias, in_features, out_features, etc.

Exporting quantized model representation...
  [export] model.layers.0.self_attn.q_proj: idx=2048.0KB, qw=4096.0KB [-1.000, 1.000], scales=48.0KB
  [export] model.layers.0.self_attn.k_proj: idx=1024.0KB, qw=2048.0KB [-1.000, 1.000], scales=32.0KB
  [export] model.layers.0.self_attn.v_proj: idx=1024.0KB, qw=2048.0KB [-1.000, 1.000], scales=32.0KB
  [export] model.layers.0.self_attn.o_proj: idx=2048.0KB, qw=4096.0KB [-1.000, 1.000], scales=48.0KB
  [export] model.layers.0.mlp.gate_proj: idx=3072.0KB, qw=6144.0KB [-1.000, 1.000], scales=32.0KB
  [export] model.layers.0.mlp.up_proj: idx=3072.0KB, qw=6144.0KB [-1.000, 1.000], scales=32.0KB
  [export] model.layers.0.mlp.down_proj: idx=3072.0KB, qw=6144.0KB [-1.000, 1.000], scales=32.0KB
  [export] model.layers.1.self_attn.q_proj: idx=2048.0KB, qw=4096.0KB [-1.000, 1.000], scales=48.0KB
  [export] model.layers.1.self_attn.k_proj: idx=1024.0KB, qw=2048.0KB [-1.000, 1.000], scales=32.0KB
  [export] model.layers.1.self_attn.v_proj: idx=1024.0KB, 

In [51]:
# ============================================================
# SNAP WEIGHTS TO FULL DEQUANT AND TEST
# ============================================================
# Snap weights = LUT[idx] * scale, then disable fake_quant for direct use

print('Snapping weights to full dequantized values (LUT[idx] * scale)...')
indices = snap_all_weights(model, store_lut_values=False, verbose=True)

# Disable fake quantization - use snapped weights directly
for name, module in model.named_modules():
    if type(module).__name__ == 'AnemllQATLinear':
        module.enable_fake_quant = False

print('\nTesting inference with snapped weights...')
model.eval()

# Quick test
response = run_inference(model, tokenizer, 'What is 2+2?', max_new_tokens=1024)
print(f'Prompt: What is 2+2?')
print(f'Response: {response}')

Snapping weights to full dequantized values (LUT[idx] * scale)...
  [snapped] model.layers.0.self_attn.q_proj: rel_error=0.000000, range=[-0.562, 0.625]
  [snapped] model.layers.0.self_attn.k_proj: rel_error=0.000000, range=[-0.389, 0.404]
  [snapped] model.layers.0.self_attn.v_proj: rel_error=0.000000, range=[-0.150, 0.169]
  [snapped] model.layers.0.self_attn.o_proj: rel_error=0.000000, range=[-0.402, 0.268]
  [snapped] model.layers.0.mlp.gate_proj: rel_error=0.000000, range=[-0.406, 0.305]
  [snapped] model.layers.0.mlp.up_proj: rel_error=0.000000, range=[-0.377, 0.326]
  [snapped] model.layers.0.mlp.down_proj: rel_error=0.000000, range=[-0.434, 0.371]
  [snapped] model.layers.1.self_attn.q_proj: rel_error=0.000000, range=[-0.490, 0.660]
  [snapped] model.layers.1.self_attn.k_proj: rel_error=0.000000, range=[-0.391, 0.250]
  [snapped] model.layers.1.self_attn.v_proj: rel_error=0.000000, range=[-0.156, 0.209]
  [snapped] model.layers.1.self_attn.o_proj: rel_error=0.000000, range=[-0.