# Anemll-Style Layer-by-Layer QAT

This notebook implements layer-by-layer QAT training using `AnemllQATLinear` with:
- Groupwise LUT quantization
- Low-rank scale factors (A @ B)
- KD cache for distillation
- **Hard label loss** for improved convergence

## Pipeline:
1. Load model and replace linears with AnemllQATLinear
2. Layer-by-layer scale optimization (weights frozen)
3. Layer-by-layer weight training (with hard label loss)
4. End-to-end refinement
5. (Optional) LoRA recovery

## Distillation Options:
- `temperature`: KL divergence temperature (default: 2.0)
- `hard_top1_weight`: Hard label top-1 loss weight (recommended: 0.1 for weights, 0.0 for scales)
- `hard_full_weight`: Hard label full vocab loss weight (optional)

In [1]:
# ============================================================
# GOOGLE DRIVE PATHS (STANDARD)
# ============================================================

# Checkpoints/runs go here
GD_RUNS = '/content/drive/MyDrive/qwen3_runs'

# KD caches go here
GD_CACHES = '/content/drive/MyDrive/qwen3_caches'

# Local directories (on Colab VM)
LOCAL_RUNS = 'runs'
LOCAL_CACHES = 'caches'

In [2]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


### GITUB

In [14]:
# Clone repo if needed
!git clone https://github.com/anemll/qwen3_apple_style_2bit_qat_lora.git || (cd qwen3_apple_style_2bit_qat_lora && git pull)
%cd qwen3_apple_style_2bit_qat_lora
# to allow updates
!git fetch
!git pull
!git reset --hard HEAD
import sys
[sys.modules.pop(k) for k in list(sys.modules) if k.startswith('qat_lora')]

from qat_lora import *

Cloning into 'qwen3_apple_style_2bit_qat_lora'...
remote: Enumerating objects: 437, done.[K
remote: Counting objects: 100% (127/127), done.[K
remote: Compressing objects: 100% (97/97), done.[K
remote: Total 437 (delta 89), reused 58 (delta 30), pack-reused 310 (from 1)[K
Receiving objects: 100% (437/437), 573.78 KiB | 6.52 MiB/s, done.
Resolving deltas: 100% (280/280), done.
/content/qwen3_apple_style_2bit_qat_lora/qwen3_apple_style_2bit_qat_lora/qwen3_apple_style_2bit_qat_lora
Already up to date.
HEAD is now at 998f33c Enhance AnemllQATLinear and training functions for improved memory management and learning rate scheduling


In [4]:
# Install dependencies
!pip install -q transformers accelerate safetensors

In [18]:
# ============================================================
# LOAD KD CACHE FROM GOOGLE DRIVE
# ============================================================

#CACHE_NAME = 'alpaca_chat_think_both_L128_K32_R256'
#CACHE_NAME = 'alpaca_chat_think_both_L128_K64_R512'
CACHE_NAME = 'alpaca_chat_think_both_L128_K128_R1024'


CACHE_TGZ = f'{CACHE_NAME}.tgz'

!mkdir -p {LOCAL_CACHES}

# Check if cache exists locally
import os
cache_local_path = f'{LOCAL_CACHES}/{CACHE_NAME}'
if not os.path.exists(cache_local_path):
    print(f'Extracting {CACHE_TGZ} from Google Drive...')
    !tar -xzf {GD_CACHES}/{CACHE_TGZ} -C {LOCAL_CACHES}/
else:
    print(f'Cache already exists at {cache_local_path}')

!ls -la {cache_local_path}/ | head -10

Extracting alpaca_chat_think_both_L128_K128_R1024.tgz from Google Drive...
total 17157672
drwx------ 2 root root      4096 Dec 26 02:45 .
drwxr-xr-x 3 root root      4096 Dec 27 00:09 ..
-rw------- 1 root root       423 Dec 26 02:45 meta.json
-rw------- 1 root root 899550149 Dec 26 02:46 shard_00000.pt
-rw------- 1 root root 899550149 Dec 26 02:43 shard_00001.pt
-rw------- 1 root root 899550149 Dec 26 02:44 shard_00002.pt
-rw------- 1 root root 899550149 Dec 26 02:44 shard_00003.pt
-rw------- 1 root root 899550149 Dec 26 02:44 shard_00004.pt
-rw------- 1 root root 899550149 Dec 26 02:43 shard_00005.pt


In [15]:
# ============================================================
# CONFIGURATION
# ============================================================

import torch

# Model
MODEL_ID = 'Qwen/Qwen3-0.6B'

# Quantization config (4-bit with groupwise LUT)
LUT_BITS = 2
LUT_SIZE = 2**LUT_BITS
GROUP_SIZE = 16      # Group size for scales
SCALE_RANK = 32       # Low-rank for A @ B scales

# Attention quantization (same params)
ATTN_LUT_BITS = 4
ATTN_LUT_SIZE = 2**ATTN_LUT_BITS
ATTN_GROUP_SIZE = 16
ATTN_SCALE_RANK = 8

# Training
BATCH_SIZE = 4
GRAD_ACCUM = 4

if torch.cuda.is_available():
    BATCH_SIZE=32
    GRAD_ACCUM=1

LR = 2e-5
EPOCHS_PER_LAYER = 1

# KD / Distillation params
DISTILL_TEMP = 2.0
HARD_TOP1_WEIGHT = 0.2    # Hard label top-1 loss (helps convergence)
HARD_FULL_WEIGHT = 0.00005    # Hard label full vocab loss (optional)

# Device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DTYPE = torch.bfloat16


QUAL = f'q{LUT_BITS}_a{ATTN_LUT_BITS}'

print(f'Quality: {QUAL}')

print(f'Device: {DEVICE}, dtype: {DTYPE}')
print(f'Quant config: lut={LUT_SIZE}, group={GROUP_SIZE}, rank={SCALE_RANK}')
print(f'Distillation: temp={DISTILL_TEMP}, hard_top1={HARD_TOP1_WEIGHT}, hard_full={HARD_FULL_WEIGHT}')

Quality: q2_a4
Device: cuda, dtype: torch.bfloat16
Quant config: lut=4, group=16, rank=32
Distillation: temp=2.0, hard_top1=0.2, hard_full=5e-05


In [8]:
# ============================================================
# Extracting LOCAL CACHE
# ============================================================

import os
from pathlib import Path

# Verify drive is mounted and cache exists
if not os.path.exists('/content/drive/MyDrive'):
    print('Google Drive not mounted! Mounting now...')
    from google.colab import drive
    drive.mount('/content/drive')

if not os.path.exists(cache_local_path):
    print(f'Cache not found at {cache_local_path}')
    print(f'Extracting from Google Drive...')
    os.makedirs(LOCAL_CACHES, exist_ok=True)
    !tar -xzf {GD_CACHES}/{CACHE_TGZ} -C {LOCAL_CACHES}/

# Verify cache exists now
assert os.path.exists(cache_local_path), f'Cache still not found at {cache_local_path}'
cache_files = list(Path(cache_local_path).glob('*.pt'))
print(f'Cache ready: {len(cache_files)} files in {cache_local_path}')

Cache ready: 20 files in caches/alpaca_chat_think_both_L128_K128_R1024


In [5]:
# ============================================================
# LOAD MODEL
# ============================================================

from transformers import AutoModelForCausalLM, AutoTokenizer

print(f'Loading {MODEL_ID}...')
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    trust_remote_code=True,
)
model.to(DEVICE)
model.eval()
print(f'Loaded. Parameters: {sum(p.numel() for p in model.parameters()):,}')

Loading Qwen/Qwen3-0.6B...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
`torch_dtype` is deprecated! Use `dtype` instead!


Loaded. Parameters: 596,049,920


In [8]:
# ============================================================
# REPLACE LINEARS WITH AnemllQATLinear
# ============================================================

import sys
sys.path.insert(0, '.')

# Force reimport to get latest code
import importlib
import qat_lora
importlib.reload(qat_lora)
import qat_lora.ane_qat_linear as ane_module
importlib.reload(ane_module)
import qat_lora.layer_qat as layer_module
importlib.reload(layer_module)

from qat_lora import AnemllQuantConfig, replace_linear_with_anemll

# Debug: Check what modules exist in the model
print("Checking model structure...")
import torch.nn as nn
linear_count = 0
for name, m in model.named_modules():
    if isinstance(m, nn.Linear):
        linear_count += 1
        if linear_count <= 5:
            print(f"  Found Linear: {name}")
print(f"Total Linear modules: {linear_count}")

# Create configs
mlp_config = AnemllQuantConfig(
    lut_size=LUT_SIZE,
    group_size=GROUP_SIZE,
    scale_rank=SCALE_RANK,
    learnable_lut=False,
)

attn_config = AnemllQuantConfig(
    lut_size=ATTN_LUT_SIZE,
    group_size=ATTN_GROUP_SIZE,
    scale_rank=ATTN_SCALE_RANK,
    learnable_lut=False,
)

print('\nReplacing linear layers...')
count = replace_linear_with_anemll(
    model,
    mlp_config=mlp_config,
    attn_config=attn_config,
    quantize_attn=True,
    quantize_lm_head=False,
)

# Verify replacement worked
from qat_lora import AnemllQATLinear
qat_count = sum(1 for _, m in model.named_modules() if isinstance(m, AnemllQATLinear))
print(f"\nVerification: {qat_count} AnemllQATLinear modules in model")

Checking model structure...
  Found Linear: model.layers.0.self_attn.q_proj
  Found Linear: model.layers.0.self_attn.k_proj
  Found Linear: model.layers.0.self_attn.v_proj
  Found Linear: model.layers.0.self_attn.o_proj
  Found Linear: model.layers.0.mlp.gate_proj
Total Linear modules: 197

Replacing linear layers...
  [replaced] model.layers.0.self_attn.q_proj
  [replaced] model.layers.0.self_attn.k_proj
  [replaced] model.layers.0.self_attn.v_proj
  [replaced] model.layers.0.self_attn.o_proj
  [replaced] model.layers.0.mlp.gate_proj
  [replaced] model.layers.0.mlp.up_proj
  [replaced] model.layers.0.mlp.down_proj
  [replaced] model.layers.1.self_attn.q_proj
  [replaced] model.layers.1.self_attn.k_proj
  [replaced] model.layers.1.self_attn.v_proj
  [replaced] model.layers.1.self_attn.o_proj
  [replaced] model.layers.1.mlp.gate_proj
  [replaced] model.layers.1.mlp.up_proj
  [replaced] model.layers.1.mlp.down_proj
  [replaced] model.layers.2.self_attn.q_proj
  [replaced] model.layers.2.

In [16]:
# ============================================================
# IMPORT LAYER-BY-LAYER QAT UTILITIES & VERIFY GRADIENTS
# ============================================================

from qat_lora import (
    evaluate_kd_loss,
    train_all_layers,
    AnemllQATLinear,
)

print('Layer QAT utilities imported from qat_lora')

# Verify gradient flow works
print('\nVerifying gradient flow...')
layer0 = model.model.layers[0]
test_module = None
for name, m in layer0.named_modules():
    if isinstance(m, AnemllQATLinear):
        test_module = m
        break

if test_module is None:
    print("ERROR: No AnemllQATLinear modules found! Replacement failed.")
else:
    # Test gradient flow
    test_module.weight.requires_grad = True
    x = torch.randn(1, 10, test_module.in_features, device=DEVICE, dtype=DTYPE)
    y = test_module(x)
    loss = y.sum()
    try:
        loss.backward()
        if test_module.weight.grad is not None:
            print(f"  Gradient OK: weight.grad.shape = {test_module.weight.grad.shape}")
            test_module.weight.grad = None  # Clear for actual training
        else:
            print("  ERROR: weight.grad is None after backward!")
    except Exception as e:
        print(f"  ERROR during backward: {e}")

# Compute initial KD loss
print('\nComputing initial KD loss...')
initial_loss = evaluate_kd_loss(model, cache_local_path, DEVICE, num_samples=40, temperature=DISTILL_TEMP)
print(f'Initial KD Loss: {initial_loss:.4f}')

Layer QAT utilities imported from qat_lora

Verifying gradient flow...
ERROR: No AnemllQATLinear modules found! Replacement failed.

Computing initial KD loss...
Initial KD Loss: 0.0000


# **SCALE OPTIMIZATION** (Weights Frozen)

After layer-by-layer QAT on weights, optimize the per-weight scales (A @ B) to further reduce quantization error.

- Weights are **frozen**
- Only `scale_A` and `scale_B` are trained
- Much fewer parameters → can use higher learning rate

In [11]:
# ============================================================
# LAYER-BY-LAYER SCALE OPTIMIZATION
# ============================================================
# Freeze weights, only train scale_A and scale_B tensors
# Higher LR since fewer parameters
# Note: Hard label loss not needed for scale optimization

SCALE_LR = 1e-3  # Higher LR for scales (fewer params)
SCALE_EPOCHS = 2  # More epochs since scales have less capacity


print('Starting scale-only layer-by-layer optimization...')
print(f'LR: {SCALE_LR}, Epochs per layer: {SCALE_EPOCHS}')

# Get loss before scale optimization
pre_scale_loss = evaluate_kd_loss(model, cache_local_path, DEVICE, num_samples=40)
print(f'KD Loss before scale optimization: {pre_scale_loss:.4f}')

# Train scales layer-by-layer (no hard label needed for scales)
scale_losses = train_all_layers(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    batch_size=BATCH_SIZE,
    lr=SCALE_LR,
    epochs_per_layer=SCALE_EPOCHS,
    grad_accum=GRAD_ACCUM,
    temperature=DISTILL_TEMP,
    train_weights=False,  # Freeze weights
    train_scales=True,    # Train scales only
    local_weight=0.5,
    global_weight=0.5,
    hard_top1_weight=0.0,  # Not needed for scale optimization
    hard_full_weight=0.0,
    verbose=True,
    steps_per_layer=100,
)

# Evaluate after scale optimization
post_scale_loss = evaluate_kd_loss(model, cache_local_path, DEVICE, num_samples=40)
print(f'\n=== Scale Optimization Results ===')
print(f'Before: {pre_scale_loss:.4f}')
print(f'After:  {post_scale_loss:.4f}')
print(f'Improvement: {pre_scale_loss - post_scale_loss:.4f}')

Starting scale-only layer-by-layer optimization...
LR: 0.001, Epochs per layer: 2
KD Loss before scale optimization: 1.0885
Training 28 layers (mode=scales only)...
Cache: caches/alpaca_chat_think_both_L128_K32_R256
Batch size: 64, Grad accum: 1
LR: 0.001, Steps per layer: 100

[Initial Global KD Loss]: 1.0885

=== Layer 0 === (131,072 trainable params, mode=scales only)
  [Global KD Loss BEFORE]: 1.0345
  step 10: local=0.0439 global=0.8089 (3.9s)
  step 20: local=0.0710 global=0.7905 (7.2s)
  step 30: local=0.0724 global=0.7238 (10.4s)
  step 40: local=0.0671 global=0.6626 (13.7s)
  step 50: local=0.0641 global=0.6703 (17.0s)
  step 60: local=0.0619 global=0.6684 (20.3s)
  step 70: local=0.0622 global=0.6874 (23.5s)
  step 80: local=0.0609 global=0.6786 (26.8s)
  step 90: local=0.0638 global=0.6774 (30.1s)
  step 100: local=0.0646 global=0.6453 (33.3s)
  ---
  [Local Loss]:   0.0283 -> 0.0646 (Δ=-0.0364)
  [Global Loss]:  1.0063 -> 0.6453 (Δ=0.3609)
  [Eval KD]:      1.0345 -> 0.6752

# **RUN** LAYER-BY-LAYER TRAINING

In [33]:
# ============================================================
# LAYER-BY-LAYER WEIGHT TRAINING
# ============================================================
# Train weights with hard label loss for better convergence

print('Starting layer-by-layer weight training...')
print(f'LR: {LR}, Hard label: top1={HARD_TOP1_WEIGHT}, full={HARD_FULL_WEIGHT}')

# Train all layers using the imported function
layer_losses = train_all_layers(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    batch_size=BATCH_SIZE,
    lr=LR,
    epochs_per_layer=EPOCHS_PER_LAYER,
    grad_accum=GRAD_ACCUM,
    temperature=DISTILL_TEMP,
    train_weights=True,   # Train weights
    train_scales=False,   # Keep scales frozen for now
    local_weight=0.5,
    global_weight=0.5,
    hard_top1_weight=HARD_TOP1_WEIGHT,  # Helps convergence
    hard_full_weight=HARD_FULL_WEIGHT,
    verbose=True,
    steps_per_layer=100,
)

Starting layer-by-layer weight training...
LR: 2e-05, Hard label: top1=0.2, full=5e-05
Training 28 layers (mode=weights)...
Cache: caches/alpaca_chat_think_both_L128_K128_R1024
Batch size: 32, Grad accum: 1
LR: 2e-05, Steps per layer: 100
Hard label: top1=0.2, full=5e-05

[Initial Global KD Loss]: 0.5927

=== Layer 0 === (15,728,640 trainable params, mode=weights)
  Hard label: top1=0.2, full=5e-05
  [Global KD Loss BEFORE]: 0.6361
  step 10: local=0.0000 global=0.7199 (2.4s)
  step 20: local=0.0000 global=0.7089 (4.1s)
  step 30: local=0.0000 global=0.8131 (5.8s)
  step 40: local=0.0000 global=0.7799 (8.0s)
  step 50: local=0.0000 global=0.7745 (9.7s)
  step 60: local=0.0000 global=0.6737 (11.4s)
  step 70: local=0.0000 global=0.7553 (13.6s)
  step 80: local=0.0000 global=0.8080 (15.4s)
  step 90: local=0.0000 global=0.7113 (17.1s)
  step 100: local=0.0000 global=0.7533 (19.2s)
  ---
  [Local Loss]:   0.0000 -> 0.0000 (Δ=-0.0000)
  [Global Loss]:  0.8142 -> 0.7533 (Δ=0.0609)
  [Eval K

KeyboardInterrupt: 

In [13]:
# ============================================================
# EVALUATE AFTER LAYER-BY-LAYER
# ============================================================

model.eval()
post_layer_loss = evaluate_kd_loss(model, cache_local_path, DEVICE, num_samples=40)
print(f'Initial KD Loss: {initial_loss:.4f}')
print(f'After Layer-by-Layer: {post_layer_loss:.4f}')
print(f'Improvement: {initial_loss - post_layer_loss:.4f}')

Initial KD Loss: 1.0885
After Layer-by-Layer: 0.2810
Improvement: 0.8074


In [14]:
# ============================================================
# SAVE CHECKPOINT
# ============================================================

import os

RUN_NAME = f'anemll_{QUAL}_layer_by_layer_v1'
SAVE_DIR = f'{LOCAL_RUNS}/{RUN_NAME}'

os.makedirs(SAVE_DIR, exist_ok=True)

# Save state dict
torch.save(model.state_dict(), f'{SAVE_DIR}/model_state_dict.pt')

# Save config
import json
config = {
    'model_id': MODEL_ID,
    'lut_size': LUT_SIZE,
    'group_size': GROUP_SIZE,
    'scale_rank': SCALE_RANK,
    'attn_lut_size': ATTN_LUT_SIZE,
    'attn_group_size': ATTN_GROUP_SIZE,
    'attn_scale_rank': ATTN_SCALE_RANK,
    'initial_kd_loss': initial_loss,
    'post_layer_loss': post_layer_loss,
    'layer_losses': layer_losses,
}
with open(f'{SAVE_DIR}/config.json', 'w') as f:
    json.dump(config, f, indent=2)

print(f'Saved to {SAVE_DIR}')

Saved to runs/anemll_q4_a4_layer_by_layer_v1


In [15]:
# ============================================================
# UPLOAD TO GOOGLE DRIVE
# ============================================================

!tar -czvf {RUN_NAME}.tgz -C {LOCAL_RUNS} {RUN_NAME}
!cp {RUN_NAME}.tgz {GD_RUNS}/
print(f'Uploaded to {GD_RUNS}/{RUN_NAME}.tgz')

anemll_q4_a4_layer_by_layer_v1/
anemll_q4_a4_layer_by_layer_v1/config.json
anemll_q4_a4_layer_by_layer_v1/model_state_dict.pt
Uploaded to /content/drive/MyDrive/qwen3_runs/anemll_q4_a4_layer_by_layer_v1.tgz


# **END-TO-END KD-QAT REFINEMENT**

After layer-by-layer training, refine the model with all layers unfrozen.

Two modes:
1. **Train weights** (scales frozen) - Fine-tune weights globally with hard label loss
2. **Train scales** (weights frozen) - Optimize scales for better quantization

## Distillation Options

| Parameter | Weight Training | Scale Training |
|-----------|----------------|----------------|
| `temperature` | 2.0 | 2.0 |
| `hard_top1_weight` | 0.1 (recommended) | 0.0 |
| `hard_full_weight` | 0.0 | 0.0 |

Hard label loss helps prevent divergence during weight training.

In [12]:
# ============================================================
# END-TO-END KD-QAT: TRAIN SCALES (WEIGHTS FROZEN)
# ============================================================
# Scale training doesn't need hard label loss

# Train scales (weights frozen) - higher LR since fewer params
e2e_scales_result = train_e2e(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    max_steps=2000,
    batch_size=64 if torch.cuda.is_available() else 32,
    lr=5e-4,  # Higher LR for scales
    use_cosine_schedule=True,
    warmup_steps=100,          # Linear warmup
    min_lr_ratio=0.1,      # End at 5e-5
    temperature=DISTILL_TEMP,
    train_weights=False,
    train_scales=True,
    hard_top1_weight=0.0,  # Not needed for scale training
    hard_full_weight=0.0,
    logging_steps=20,
    eval_steps=100,
    verbose=True,
)

=== End-to-End KD-QAT ===
Mode: scales
Trainable params: 13,303,808
Steps: 2000, LR: 0.0005, Batch: 64

Initial KD Loss: 10.1585
[20/2000] loss=6.8576 (13s, ETA 1330s)
[40/2000] loss=5.0527 (21s, ETA 1042s)
[60/2000] loss=3.2987 (29s, ETA 941s)
[80/2000] loss=2.3769 (37s, ETA 886s)
[100/2000] loss=1.9908 (45s, ETA 860s)
  [Eval] KD Loss: 1.8185 (best: 10.1585)
[120/2000] loss=1.7580 (59s, ETA 919s)
[140/2000] loss=1.6288 (66s, ETA 883s)
[160/2000] loss=1.5216 (74s, ETA 854s)
[180/2000] loss=1.4983 (82s, ETA 833s)
[200/2000] loss=1.3948 (90s, ETA 812s)
  [Eval] KD Loss: 1.3068 (best: 1.8185)
[220/2000] loss=1.3684 (104s, ETA 842s)
[240/2000] loss=1.3460 (112s, ETA 820s)
[260/2000] loss=1.2889 (120s, ETA 801s)
[280/2000] loss=1.2834 (127s, ETA 783s)
[300/2000] loss=1.2119 (136s, ETA 769s)
  [Eval] KD Loss: 1.1732 (best: 1.3068)
[320/2000] loss=1.2199 (149s, ETA 780s)
[340/2000] loss=1.2266 (156s, ETA 764s)
[360/2000] loss=1.1805 (164s, ETA 748s)
[380/2000] loss=1.1151 (172s, ETA 735s)
[4

In [28]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
# ============================================================
# FINE-TUNE END-TO-END KD-QAT: TRAIN WEIGHTS (SCALES FROZEN)
# ============================================================
# Use hard label loss for stable weight training

from qat_lora import train_e2e, save_checkpoint, load_checkpoint, unfreeze_model_for_training

# Unfreeze for training (clear any cached weights)
unfreeze_model_for_training(model)

print('E2E weight training with hard label loss...')
print(f'Hard label: top1={HARD_TOP1_WEIGHT}, full={HARD_FULL_WEIGHT}')

# Train weights (scales frozen)
e2e_weights_result = train_e2e(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    max_steps=1000,
    batch_size=64 if torch.cuda.is_available() else 32,
    lr=1e-4,  # 5x lower for polish
    use_cosine_schedule=True,
    warmup_steps=0,        # Skip - LR already gentle
    min_lr_ratio=0.1,      # End at 1e-5
    temperature=DISTILL_TEMP,
    train_weights=True,
    train_scales=False,
    hard_top1_weight=0.2,  # Helps prevent divergence
    hard_full_weight=HARD_FULL_WEIGHT,
    logging_steps=50,
    eval_steps=500,
    verbose=True,
)

E2E weight training with hard label loss...
Hard label: top1=0.2, full=5e-05
=== End-to-End KD-QAT ===
Mode: weights
Trainable params: 440,401,920
Steps: 1000, LR: 0.0001, Batch: 64
Hard label: top1=0.2, full=5e-05

Initial KD Loss: 0.6396


In [29]:
# ============================================================
# END-TO-END KD-QAT: TRAIN WEIGHTS (SCALES FROZEN)
# ============================================================
# Use hard label loss for stable weight training

from qat_lora import train_e2e, save_checkpoint, load_checkpoint, unfreeze_model_for_training

# Unfreeze for training (clear any cached weights)
unfreeze_model_for_training(model)

print('E2E weight training with hard label loss...')
print(f'Hard label: top1={HARD_TOP1_WEIGHT}, full={HARD_FULL_WEIGHT}')

# Train weights (scales frozen)
e2e_weights_result = train_e2e(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    max_steps=2000,
    batch_size=64 if torch.cuda.is_available() else 32,
    lr=5e-7,
    temperature=DISTILL_TEMP,
    train_weights=True,
    train_scales=False,
    hard_top1_weight=0.2,  # Helps prevent divergence
    hard_full_weight=HARD_FULL_WEIGHT,
    logging_steps=50,
    eval_steps=500,
    verbose=True,
)

E2E weight training with hard label loss...
Hard label: top1=0.2, full=5e-05
=== End-to-End KD-QAT ===
Mode: weights
Trainable params: 440,401,920
Steps: 2000, LR: 1e-06, Batch: 64
Hard label: top1=0.2, full=5e-05

Initial KD Loss: 0.6547
[50/2000] loss=0.8903 (27s, ETA 1066s)
[100/2000] loss=0.8553 (49s, ETA 938s)
[150/2000] loss=0.8465 (71s, ETA 881s)
[200/2000] loss=0.8545 (93s, ETA 839s)
[250/2000] loss=0.8477 (116s, ETA 811s)
[300/2000] loss=0.8571 (138s, ETA 781s)
[350/2000] loss=0.8432 (160s, ETA 753s)
[400/2000] loss=0.8444 (182s, ETA 726s)
[450/2000] loss=0.8565 (204s, ETA 703s)
[500/2000] loss=0.8385 (226s, ETA 678s)
  [Eval] KD Loss: 0.6324 (best: 0.6547)
[550/2000] loss=0.8398 (254s, ETA 669s)
[600/2000] loss=0.8162 (276s, ETA 644s)
[650/2000] loss=0.8234 (298s, ETA 619s)
[700/2000] loss=0.8082 (320s, ETA 594s)
[750/2000] loss=0.8313 (342s, ETA 570s)
[800/2000] loss=0.8517 (364s, ETA 546s)
[850/2000] loss=0.8297 (387s, ETA 523s)
[900/2000] loss=0.7969 (409s, ETA 500s)
[950/

In [19]:
# ============================================================
# SAVE FINAL CHECKPOINT
# ============================================================

unfreeze_model_for_training(model)
E2E_RUN_NAME = f'anemll_{QUAL}_e2e_v2_scales_only'
E2E_SAVE_DIR = f'{LOCAL_RUNS}/{E2E_RUN_NAME}'

# Save with config
config = {
    'model_id': MODEL_ID,
    'lut_size': LUT_SIZE,
    'group_size': GROUP_SIZE,
    'scale_rank': SCALE_RANK,
    'attn_lut_size': ATTN_LUT_SIZE,
    'attn_group_size': ATTN_GROUP_SIZE,
    'attn_scale_rank': ATTN_SCALE_RANK,
    #'e2e_weights_result': e2e_weights_result,
    'e2e_scales_result': e2e_scales_result,
}

save_checkpoint(model, E2E_SAVE_DIR, config=config)

# Upload to Google Drive
!tar -czvf {E2E_RUN_NAME}.tgz -C {LOCAL_RUNS} {E2E_RUN_NAME}
!cp {E2E_RUN_NAME}.tgz {GD_RUNS}/
print(f'\nUploaded to {GD_RUNS}/{E2E_RUN_NAME}.tgz')

Saved checkpoint to runs/anemll_q2_a4_e2e_v2_scales_only/
  - model_state_dict.pt
  - indices.pt (196 layers, 420.0 MB)
  - config.json
anemll_q2_a4_e2e_v2_scales_only/
anemll_q2_a4_e2e_v2_scales_only/config.json
anemll_q2_a4_e2e_v2_scales_only/indices.pt
anemll_q2_a4_e2e_v2_scales_only/model_state_dict.pt

Uploaded to /content/drive/MyDrive/qwen3_runs/anemll_q2_a4_e2e_v2_scales_only.tgz


# **INFERENCE OPTIMIZATION**

Before running inference, freeze all layers to precompute quantized weights.
This avoids recomputing `LUT[indices] * (scale_A @ scale_B)` on every forward pass.

In [30]:
# ============================================================
# FREEZE MODEL FOR FAST INFERENCE
# ============================================================
# Precompute quantized weights once for all layers
# This caches LUT[idx] * scale to avoid recomputation per token

from qat_lora import freeze_model_for_inference, unfreeze_model_for_training


print('Freezing model for inference...')
num_frozen = freeze_model_for_inference(model, verbose=False)
print(f'Frozen {num_frozen} layers')

# To resume training later:
# unfreeze_model_for_training(model)

Freezing model for inference...


KeyError: "attribute '_cached_weight_q' already exists"

In [31]:
import torch

# ============================================================
# TEST INFERENCE
# ============================================================

def run_inference(model, tokenizer, prompt, max_new_tokens=128):
    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': prompt}
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors='pt').to(DEVICE)

    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

    return tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

# List of prompts to test
prompts = [
    'What is the capital of France?',
    'What is Apple Neural Engine?',
    'Explain quantum mechanics',
    'What is speed of light'
]

model.eval() # Set model to evaluation mode once

for prompt in prompts:
    response = run_inference(model, tokenizer, prompt,max_new_tokens=1024)
    print(f'Prompt: {prompt}')
    print(f'Response: {response}')
    print('-' * 50) # Separator for readability


Prompt: What is the capital of France?
Response: <think>
<think>
</think>

The capital of France is **Paris**.
--------------------------------------------------
Prompt: What is Apple Neural Engine?
Response: <think>
<think>
</think>

Apple Neural Engine is a system of neural processing technology developed by Apple. It is designed to enable the development of neural networks and to support the development of artificial intelligence. The system is used to create and train neural networks for various applications, including machine learning, natural language processing, and other AI systems.
--------------------------------------------------
Prompt: Explain quantum mechanics
Response: <think>
<think>
Okay, I'm a helpful assistant. Let me explain quantum mechanics in a simple way. It's the study of how particles behave at the quantum level. In the early days, people thought that particles are particles, but they also found that they can behave like waves. This is called quantum mechanics

## Next Steps

After layer-by-layer training, you can:

1. **End-to-end refinement** - Unfreeze all layers and train together
2. **Train scales (A, B)** - Unfreeze scale_A, scale_B parameters
3. **LoRA recovery** - Add LoRA adapters to recover quality

# **EXPORT FOR ANEMLL CONVERTER**

Snap weights to quantized values and export for external tools.

Two export modes:
- `store_lut_values=True`: weights = LUT[idx] (normalized in [-1,1]), scales separate
- `store_lut_values=False`: weights = LUT[idx] * scale (full dequant)

In [49]:
# ============================================================
# SNAP WEIGHTS AND EXPORT
# ============================================================
# Snap weights to LUT[idx] values for ANEMLL converter

from qat_lora import snap_all_weights, export_quantized_model, unfreeze_model_for_training

# First unfreeze to clear cached weights
unfreeze_model_for_training(model)

# Export quantized representation BEFORE snapping (keeps original weights)
print('Exporting quantized model representation...')
export_dict = export_quantized_model(model, verbose=True)

# Save export for ANEMLL converter
EXPORT_DIR = f'{LOCAL_RUNS}/{E2E_RUN_NAME}_export'
os.makedirs(EXPORT_DIR, exist_ok=True)
torch.save(export_dict, f'{EXPORT_DIR}/quantized_model.pt')
print(f'\nSaved export to {EXPORT_DIR}/quantized_model.pt')

# Each layer in export_dict contains:
# - indices: [out, in] uint8 LUT indices
# - quantized_weights: [out, in] LUT[idx] values in [-1, 1]
# - scales: {'scale_A': [out, rank], 'scale_B': [rank, in]} or full [out, in]
# - lut: [lut_size] values
# - bias, in_features, out_features, etc.

Exporting quantized model representation...
  [export] model.layers.0.self_attn.q_proj: idx=2048.0KB, qw=4096.0KB [-1.000, 1.000], scales=48.0KB
  [export] model.layers.0.self_attn.k_proj: idx=1024.0KB, qw=2048.0KB [-1.000, 1.000], scales=32.0KB
  [export] model.layers.0.self_attn.v_proj: idx=1024.0KB, qw=2048.0KB [-1.000, 1.000], scales=32.0KB
  [export] model.layers.0.self_attn.o_proj: idx=2048.0KB, qw=4096.0KB [-1.000, 1.000], scales=48.0KB
  [export] model.layers.0.mlp.gate_proj: idx=3072.0KB, qw=6144.0KB [-1.000, 1.000], scales=32.0KB
  [export] model.layers.0.mlp.up_proj: idx=3072.0KB, qw=6144.0KB [-1.000, 1.000], scales=32.0KB
  [export] model.layers.0.mlp.down_proj: idx=3072.0KB, qw=6144.0KB [-1.000, 1.000], scales=32.0KB
  [export] model.layers.1.self_attn.q_proj: idx=2048.0KB, qw=4096.0KB [-1.000, 1.000], scales=48.0KB
  [export] model.layers.1.self_attn.k_proj: idx=1024.0KB, qw=2048.0KB [-1.000, 1.000], scales=32.0KB
  [export] model.layers.1.self_attn.v_proj: idx=1024.0KB, 

In [32]:
# ============================================================
# SNAP WEIGHTS TO FULL DEQUANT AND TEST
# ============================================================
# Snap weights = LUT[idx] * scale, then disable fake_quant for direct use

print('Snapping weights to full dequantized values (LUT[idx] * scale)...')
indices = snap_all_weights(model, store_lut_values=False, verbose=True)

# Disable fake quantization - use snapped weights directly
for name, module in model.named_modules():
    if type(module).__name__ == 'AnemllQATLinear':
        module.enable_fake_quant = False

print('\nTesting inference with snapped weights...')
model.eval()

# Quick test
response = run_inference(model, tokenizer, 'What is 2+2?', max_new_tokens=1024)
print(f'Prompt: What is 2+2?')
print(f'Response: {response}')

Snapping weights to full dequantized values (LUT[idx] * scale)...
  [snapped] model.layers.0.self_attn.q_proj: rel_error=0.114258, range=[-0.586, 0.570]
  [snapped] model.layers.0.self_attn.k_proj: rel_error=0.110840, range=[-0.426, 0.395]
  [snapped] model.layers.0.self_attn.v_proj: rel_error=0.101562, range=[-0.128, 0.169]
  [snapped] model.layers.0.self_attn.o_proj: rel_error=0.127930, range=[-0.439, 0.249]
  [snapped] model.layers.0.mlp.gate_proj: rel_error=0.425781, range=[-0.385, 0.381]
  [snapped] model.layers.0.mlp.up_proj: rel_error=0.433594, range=[-0.355, 0.334]
  [snapped] model.layers.0.mlp.down_proj: rel_error=0.449219, range=[-0.365, 0.359]
  [snapped] model.layers.1.self_attn.q_proj: rel_error=0.112305, range=[-0.494, 0.641]
  [snapped] model.layers.1.self_attn.k_proj: rel_error=0.111328, range=[-0.340, 0.328]
  [snapped] model.layers.1.self_attn.v_proj: rel_error=0.103027, range=[-0.154, 0.188]
  [snapped] model.layers.1.self_attn.o_proj: rel_error=0.126953, range=[-0.