# Anemll-Style Layer-by-Layer QAT

This notebook implements layer-by-layer QAT training using `AnemllQATLinear` with:
- Groupwise LUT quantization
- Low-rank scale factors (A @ B)
- KD cache for distillation
- **Hard label loss** for improved convergence

## Pipeline:
1. Load model and replace linears with AnemllQATLinear
2. Layer-by-layer scale optimization (weights frozen)
3. Layer-by-layer weight training (with hard label loss)
4. End-to-end refinement
5. (Optional) LoRA recovery

## Distillation Options:
- `temperature`: KL divergence temperature (default: 2.0)
- `hard_top1_weight`: Hard label top-1 loss weight (recommended: 0.1 for weights, 0.0 for scales)
- `hard_full_weight`: Hard label full vocab loss weight (optional)

In [1]:
# ============================================================
# GOOGLE DRIVE PATHS (STANDARD)
# ============================================================

# Checkpoints/runs go here
GD_RUNS = '/content/drive/MyDrive/qwen3_runs'

# KD caches go here
GD_CACHES = '/content/drive/MyDrive/qwen3_caches'

# Local directories (on Colab VM)
LOCAL_RUNS = 'runs'
LOCAL_CACHES = 'caches'

In [2]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### GITUB

In [3]:
# Clone repo if needed
!git clone https://github.com/anemll/qwen3_apple_style_2bit_qat_lora.git || (cd qwen3_apple_style_2bit_qat_lora && git pull)
%cd qwen3_apple_style_2bit_qat_lora
# to allow updates
!git fetch
!git pull
!git reset --hard HEAD
import sys
[sys.modules.pop(k) for k in list(sys.modules) if k.startswith('qat_lora')]

from qat_lora import *

Cloning into 'qwen3_apple_style_2bit_qat_lora'...
remote: Enumerating objects: 449, done.[K
remote: Counting objects: 100% (139/139), done.[K
remote: Compressing objects: 100% (104/104), done.[K
remote: Total 449 (delta 98), reused 68 (delta 35), pack-reused 310 (from 1)[K
Receiving objects: 100% (449/449), 589.07 KiB | 12.02 MiB/s, done.
Resolving deltas: 100% (289/289), done.
/content/qwen3_apple_style_2bit_qat_lora
Already up to date.
HEAD is now at c577e82 Enhance logging in train_e2e function with formatted time display


In [4]:
# Install dependencies
!pip install -q transformers accelerate safetensors

In [5]:
# ============================================================
# LOAD KD CACHE FROM GOOGLE DRIVE
# ============================================================

#CACHE_NAME = 'alpaca_chat_think_both_L128_K32_R256'
#CACHE_NAME = 'alpaca_chat_think_both_L128_K64_R512'
CACHE_NAME = 'alpaca_chat_think_both_L128_K128_R1024'


CACHE_TGZ = f'{CACHE_NAME}.tgz'

!mkdir -p {LOCAL_CACHES}

# Check if cache exists locally
import os
cache_local_path = f'{LOCAL_CACHES}/{CACHE_NAME}'
if not os.path.exists(cache_local_path):
    print(f'Extracting {CACHE_TGZ} from Google Drive...')
    !tar -xzf {GD_CACHES}/{CACHE_TGZ} -C {LOCAL_CACHES}/
else:
    print(f'Cache already exists at {cache_local_path}')

!ls -la {cache_local_path}/ | head -10

Extracting alpaca_chat_think_both_L128_K128_R1024.tgz from Google Drive...
total 17157672
drwx------ 2 root root      4096 Dec 26 02:45 .
drwxr-xr-x 3 root root      4096 Dec 29 00:46 ..
-rw------- 1 root root       423 Dec 26 02:45 meta.json
-rw------- 1 root root 899550149 Dec 26 02:46 shard_00000.pt
-rw------- 1 root root 899550149 Dec 26 02:43 shard_00001.pt
-rw------- 1 root root 899550149 Dec 26 02:44 shard_00002.pt
-rw------- 1 root root 899550149 Dec 26 02:44 shard_00003.pt
-rw------- 1 root root 899550149 Dec 26 02:44 shard_00004.pt
-rw------- 1 root root 899550149 Dec 26 02:43 shard_00005.pt


In [6]:
# ============================================================
# CONFIGURATION
# ============================================================

import torch

# Model
MODEL_ID = 'Qwen/Qwen3-0.6B'

# Quantization config (4-bit with groupwise LUT)
LUT_BITS = 4
LUT_SIZE = 2**LUT_BITS
GROUP_SIZE = 16      # Group size for scales
SCALE_RANK = 4       # Low-rank for A @ B scales

# Attention quantization (same params)
ATTN_LUT_BITS = 4
ATTN_LUT_SIZE = 2**ATTN_LUT_BITS
ATTN_GROUP_SIZE = 16
ATTN_SCALE_RANK = 4

# Training
BATCH_SIZE = 4
GRAD_ACCUM = 4

if torch.cuda.is_available():
    BATCH_SIZE=32
    GRAD_ACCUM=1

LR = 2e-5
EPOCHS_PER_LAYER = 1

# KD / Distillation params
DISTILL_TEMP = 2.0
HARD_TOP1_WEIGHT = 0.2    # Hard label top-1 loss (helps convergence)
HARD_FULL_WEIGHT = 0.00005    # Hard label full vocab loss (optional)

# Device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DTYPE = torch.bfloat16


QUAL = f'q{LUT_BITS}_a{ATTN_LUT_BITS}'

print(f'Quality: {QUAL}')

print(f'Device: {DEVICE}, dtype: {DTYPE}')
print(f'Quant config: lut={LUT_SIZE}, group={GROUP_SIZE}, rank={SCALE_RANK}')
print(f'Distillation: temp={DISTILL_TEMP}, hard_top1={HARD_TOP1_WEIGHT}, hard_full={HARD_FULL_WEIGHT}')

Quality: q4_a4
Device: cuda, dtype: torch.bfloat16
Quant config: lut=16, group=16, rank=4
Distillation: temp=2.0, hard_top1=0.2, hard_full=5e-05


In [7]:
# ============================================================
# Extracting LOCAL CACHE
# ============================================================

import os
from pathlib import Path

# Verify drive is mounted and cache exists
if not os.path.exists('/content/drive/MyDrive'):
    print('Google Drive not mounted! Mounting now...')
    from google.colab import drive
    drive.mount('/content/drive')

if not os.path.exists(cache_local_path):
    print(f'Cache not found at {cache_local_path}')
    print(f'Extracting from Google Drive...')
    os.makedirs(LOCAL_CACHES, exist_ok=True)
    !tar -xzf {GD_CACHES}/{CACHE_TGZ} -C {LOCAL_CACHES}/

# Verify cache exists now
assert os.path.exists(cache_local_path), f'Cache still not found at {cache_local_path}'
cache_files = list(Path(cache_local_path).glob('*.pt'))
print(f'Cache ready: {len(cache_files)} files in {cache_local_path}')

Cache ready: 20 files in caches/alpaca_chat_think_both_L128_K128_R1024


In [8]:
# ============================================================
# LOAD MODEL
# ============================================================

from transformers import AutoModelForCausalLM, AutoTokenizer

print(f'Loading {MODEL_ID}...')
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    trust_remote_code=True,
)
model.to(DEVICE)
model.eval()
print(f'Loaded. Parameters: {sum(p.numel() for p in model.parameters()):,}')

Loading Qwen/Qwen3-0.6B...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/726 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/1.50G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Loaded. Parameters: 596,049,920


In [9]:
# ============================================================
# REPLACE LINEARS WITH AnemllQATLinear
# ============================================================

import sys
sys.path.insert(0, '.')

# Force reimport to get latest code
import importlib
import qat_lora
importlib.reload(qat_lora)
import qat_lora.ane_qat_linear as ane_module
importlib.reload(ane_module)
import qat_lora.layer_qat as layer_module
importlib.reload(layer_module)

from qat_lora import AnemllQuantConfig, replace_linear_with_anemll

# Debug: Check what modules exist in the model
print("Checking model structure...")
import torch.nn as nn
linear_count = 0
for name, m in model.named_modules():
    if isinstance(m, nn.Linear):
        linear_count += 1
        if linear_count <= 5:
            print(f"  Found Linear: {name}")
print(f"Total Linear modules: {linear_count}")

# Create configs
mlp_config = AnemllQuantConfig(
    lut_size=LUT_SIZE,
    group_size=GROUP_SIZE,
    scale_rank=SCALE_RANK,
    learnable_lut=False,
)

attn_config = AnemllQuantConfig(
    lut_size=ATTN_LUT_SIZE,
    group_size=ATTN_GROUP_SIZE,
    scale_rank=ATTN_SCALE_RANK,
    learnable_lut=False,
)

print('\nReplacing linear layers...')
count = replace_linear_with_anemll(
    model,
    mlp_config=mlp_config,
    attn_config=attn_config,
    quantize_attn=True,
    quantize_lm_head=False,
)

# Verify replacement worked
from qat_lora import AnemllQATLinear
qat_count = sum(1 for _, m in model.named_modules() if isinstance(m, AnemllQATLinear))
print(f"\nVerification: {qat_count} AnemllQATLinear modules in model")

Checking model structure...
  Found Linear: model.layers.0.self_attn.q_proj
  Found Linear: model.layers.0.self_attn.k_proj
  Found Linear: model.layers.0.self_attn.v_proj
  Found Linear: model.layers.0.self_attn.o_proj
  Found Linear: model.layers.0.mlp.gate_proj
Total Linear modules: 197

Replacing linear layers...
  [replaced] model.layers.0.self_attn.q_proj
  [replaced] model.layers.0.self_attn.k_proj
  [replaced] model.layers.0.self_attn.v_proj
  [replaced] model.layers.0.self_attn.o_proj
  [replaced] model.layers.0.mlp.gate_proj
  [replaced] model.layers.0.mlp.up_proj
  [replaced] model.layers.0.mlp.down_proj
  [replaced] model.layers.1.self_attn.q_proj
  [replaced] model.layers.1.self_attn.k_proj
  [replaced] model.layers.1.self_attn.v_proj
  [replaced] model.layers.1.self_attn.o_proj
  [replaced] model.layers.1.mlp.gate_proj
  [replaced] model.layers.1.mlp.up_proj
  [replaced] model.layers.1.mlp.down_proj
  [replaced] model.layers.2.self_attn.q_proj
  [replaced] model.layers.2.

In [10]:
# ============================================================
# IMPORT LAYER-BY-LAYER QAT UTILITIES & VERIFY GRADIENTS
# ============================================================

from qat_lora import (
    evaluate_kd_loss,
    train_all_layers,
    AnemllQATLinear,
)

print('Layer QAT utilities imported from qat_lora')

# Verify gradient flow works
print('\nVerifying gradient flow...')
layer0 = model.model.layers[0]
test_module = None
for name, m in layer0.named_modules():
    if isinstance(m, AnemllQATLinear):
        test_module = m
        break

if test_module is None:
    print("ERROR: No AnemllQATLinear modules found! Replacement failed.")
else:
    # Test gradient flow
    test_module.weight.requires_grad = True
    x = torch.randn(1, 10, test_module.in_features, device=DEVICE, dtype=DTYPE)
    y = test_module(x)
    loss = y.sum()
    try:
        loss.backward()
        if test_module.weight.grad is not None:
            print(f"  Gradient OK: weight.grad.shape = {test_module.weight.grad.shape}")
            test_module.weight.grad = None  # Clear for actual training
        else:
            print("  ERROR: weight.grad is None after backward!")
    except Exception as e:
        print(f"  ERROR during backward: {e}")

# Compute initial KD loss
print('\nComputing initial KD loss...')
initial_loss = evaluate_kd_loss(model, cache_local_path, DEVICE, num_samples=40, temperature=DISTILL_TEMP)
print(f'Initial KD Loss: {initial_loss:.4f}')

Layer QAT utilities imported from qat_lora

Verifying gradient flow...
ERROR: No AnemllQATLinear modules found! Replacement failed.

Computing initial KD loss...
Initial KD Loss: 1.6692


# **SCALE OPTIMIZATION** (Weights Frozen)

After layer-by-layer QAT on weights, optimize the per-weight scales (A @ B) to further reduce quantization error.

- Weights are **frozen**
- Only `scale_A` and `scale_B` are trained
- Much fewer parameters → can use higher learning rate

In [None]:
# ============================================================
# LAYER-BY-LAYER SCALE OPTIMIZATION
# ============================================================
# Freeze weights, only train scale_A and scale_B tensors
# Higher LR since fewer parameters
# Note: Hard label loss not needed for scale optimization

SCALE_LR = 1e-3  # Higher LR for scales (fewer params)
SCALE_EPOCHS = 2  # More epochs since scales have less capacity


print('Starting scale-only layer-by-layer optimization...')
print(f'LR: {SCALE_LR}, Epochs per layer: {SCALE_EPOCHS}')

# Get loss before scale optimization
pre_scale_loss = evaluate_kd_loss(model, cache_local_path, DEVICE, num_samples=40)
print(f'KD Loss before scale optimization: {pre_scale_loss:.4f}')

# Train scales layer-by-layer (no hard label needed for scales)
scale_losses = train_all_layers(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    batch_size=BATCH_SIZE,
    lr=SCALE_LR,
    epochs_per_layer=SCALE_EPOCHS,
    grad_accum=GRAD_ACCUM,
    temperature=DISTILL_TEMP,
    train_weights=False,  # Freeze weights
    train_scales=True,    # Train scales only
    local_weight=0.5,
    global_weight=0.5,
    hard_top1_weight=0.0,  # Not needed for scale optimization
    hard_full_weight=0.0,
    verbose=True,
    steps_per_layer=100,
)

# Evaluate after scale optimization
post_scale_loss = evaluate_kd_loss(model, cache_local_path, DEVICE, num_samples=40)
print(f'\n=== Scale Optimization Results ===')
print(f'Before: {pre_scale_loss:.4f}')
print(f'After:  {post_scale_loss:.4f}')
print(f'Improvement: {pre_scale_loss - post_scale_loss:.4f}')

# **RUN** LAYER-BY-LAYER TRAINING

In [None]:
# ============================================================
# LAYER-BY-LAYER WEIGHT TRAINING
# ============================================================
# Train weights with hard label loss for better convergence

print('Starting layer-by-layer weight training...')
print(f'LR: {LR}, Hard label: top1={HARD_TOP1_WEIGHT}, full={HARD_FULL_WEIGHT}')

# Train all layers using the imported function
layer_losses = train_all_layers(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    batch_size=BATCH_SIZE,
    lr=LR,
    epochs_per_layer=EPOCHS_PER_LAYER,
    grad_accum=GRAD_ACCUM,
    temperature=DISTILL_TEMP,
    train_weights=True,   # Train weights
    train_scales=False,   # Keep scales frozen for now
    local_weight=0.5,
    global_weight=0.5,
    hard_top1_weight=HARD_TOP1_WEIGHT,  # Helps convergence
    hard_full_weight=HARD_FULL_WEIGHT,
    verbose=True,
    steps_per_layer=100,
)

In [None]:
# ============================================================
# EVALUATE AFTER LAYER-BY-LAYER
# ============================================================

model.eval()
post_layer_loss = evaluate_kd_loss(model, cache_local_path, DEVICE, num_samples=40)
print(f'Initial KD Loss: {initial_loss:.4f}')
print(f'After Layer-by-Layer: {post_layer_loss:.4f}')
print(f'Improvement: {initial_loss - post_layer_loss:.4f}')

In [None]:
# ============================================================
# SAVE CHECKPOINT
# ============================================================

import os

RUN_NAME = f'anemll_{QUAL}_layer_by_layer_v1'
SAVE_DIR = f'{LOCAL_RUNS}/{RUN_NAME}'

os.makedirs(SAVE_DIR, exist_ok=True)

# Save state dict
torch.save(model.state_dict(), f'{SAVE_DIR}/model_state_dict.pt')

# Save config
import json
config = {
    'model_id': MODEL_ID,
    'lut_size': LUT_SIZE,
    'group_size': GROUP_SIZE,
    'scale_rank': SCALE_RANK,
    'attn_lut_size': ATTN_LUT_SIZE,
    'attn_group_size': ATTN_GROUP_SIZE,
    'attn_scale_rank': ATTN_SCALE_RANK,
    'initial_kd_loss': initial_loss,
    'post_layer_loss': post_layer_loss,
    'layer_losses': layer_losses,
}
with open(f'{SAVE_DIR}/config.json', 'w') as f:
    json.dump(config, f, indent=2)

print(f'Saved to {SAVE_DIR}')

In [None]:
# ============================================================
# UPLOAD TO GOOGLE DRIVE
# ============================================================

!tar -czvf {RUN_NAME}.tgz -C {LOCAL_RUNS} {RUN_NAME}
!cp {RUN_NAME}.tgz {GD_RUNS}/
print(f'Uploaded to {GD_RUNS}/{RUN_NAME}.tgz')

# **END-TO-END KD-QAT REFINEMENT**

After layer-by-layer training, refine the model with all layers unfrozen.

Two modes:
1. **Train weights** (scales frozen) - Fine-tune weights globally with hard label loss
2. **Train scales** (weights frozen) - Optimize scales for better quantization

## Distillation Options

| Parameter | Weight Training | Scale Training |
|-----------|----------------|----------------|
| `temperature` | 2.0 | 2.0 |
| `hard_top1_weight` | 0.1 (recommended) | 0.0 |
| `hard_full_weight` | 0.0 | 0.0 |

Hard label loss helps prevent divergence during weight training.

In [11]:
# ============================================================
# END-TO-END KD-QAT: TRAIN SCALES (WEIGHTS FROZEN)
# ============================================================
# Scale training doesn't need hard label loss

# Train scales (weights frozen) - higher LR since fewer params
e2e_scales_result = train_e2e(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    max_steps=4000,
    batch_size=64 if torch.cuda.is_available() else 32,
    lr=5e-4,  # Higher LR for scales
    use_cosine_schedule=True,
    warmup_steps=100,          # Linear warmup
    min_lr_ratio=0.1,      # End at 5e-5
    temperature=DISTILL_TEMP,
    train_weights=False,
    train_scales=True,
    hard_top1_weight=0.0,  # Not needed for scale training
    hard_full_weight=0.0,
    logging_steps=20,
    eval_steps=100,
    verbose=True,
    train_mlp_only=True,  # ← Freeze attention (4-bit), train MLP (2-bit) only
)

=== End-to-End KD-QAT ===
Mode: scales (MLP only)
Trainable params: 1,376,256
Frozen attention params: 177,307,648
Steps: 4000, LR: 0.0005, Batch: 64
LR Schedule: warmup=100, cosine→5.00e-05

Initial KD Loss: 1.6692
[20/4000] loss=1.5389 lr=1.00e-04 (0:12, ETA 42:31)
[40/4000] loss=1.0328 lr=2.00e-04 (0:20, ETA 33:18)
[60/4000] loss=0.7492 lr=3.00e-04 (0:27, ETA 30:08)
[80/4000] loss=0.6013 lr=4.00e-04 (0:34, ETA 28:29)
[100/4000] loss=0.5284 lr=5.00e-04 (0:42, ETA 27:48)
  [Eval] KD Loss: 0.5085 (best: 1.6692)
[120/4000] loss=0.4974 lr=5.00e-04 (0:55, ETA 30:06)
[140/4000] loss=0.4700 lr=5.00e-04 (1:03, ETA 29:03)
[160/4000] loss=0.4445 lr=5.00e-04 (1:10, ETA 28:13)
[180/4000] loss=0.4349 lr=5.00e-04 (1:18, ETA 27:45)
[200/4000] loss=0.4200 lr=4.99e-04 (1:25, ETA 27:10)
  [Eval] KD Loss: 0.4008 (best: 0.5085)
[220/4000] loss=0.4154 lr=4.99e-04 (1:38, ETA 28:18)
[240/4000] loss=0.4169 lr=4.99e-04 (1:46, ETA 27:42)
[260/4000] loss=0.4216 lr=4.98e-04 (1:53, ETA 27:10)
[280/4000] loss=0.4

In [None]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [12]:
# ============================================================
# FINE-TUNE END-TO-END KD-QAT: TRAIN WEIGHTS (SCALES FROZEN)
# ============================================================
# Use hard label loss for stable weight training

from qat_lora import train_e2e, save_checkpoint, load_checkpoint, unfreeze_model_for_training

# Unfreeze for training (clear any cached weights)
unfreeze_model_for_training(model)

print('E2E weight training with hard label loss...')
print(f'Hard label: top1={HARD_TOP1_WEIGHT}, full={HARD_FULL_WEIGHT}')

# Train weights (scales frozen)
e2e_weights_result = train_e2e(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    max_steps=1000,
    batch_size=64 if torch.cuda.is_available() else 32,
    lr=1e-4,  # 5x lower for polish
    use_cosine_schedule=True,
    warmup_steps=0,        # Skip - LR already gentle
    min_lr_ratio=0.1,      # End at 1e-5
    temperature=DISTILL_TEMP,
    train_weights=True,
    train_scales=False,
    hard_top1_weight=0.2,  # Helps prevent divergence
    hard_full_weight=HARD_FULL_WEIGHT,
    logging_steps=50,
    eval_steps=500,
    verbose=True,
)

E2E weight training with hard label loss...
Hard label: top1=0.2, full=5e-05
=== End-to-End KD-QAT ===
Mode: weights
Trainable params: 440,401,920
Steps: 1000, LR: 0.0001, Batch: 64
Hard label: top1=0.2, full=5e-05
LR Schedule: cosine→1.00e-05

Initial KD Loss: 0.3142
[50/1000] loss=0.9857 lr=9.94e-05 (0:27, ETA 8:45)
[100/1000] loss=0.5629 lr=9.78e-05 (0:49, ETA 7:28)
[150/1000] loss=0.5132 lr=9.51e-05 (1:11, ETA 6:46)
[200/1000] loss=0.5000 lr=9.14e-05 (1:34, ETA 6:16)
[250/1000] loss=0.4771 lr=8.68e-05 (1:56, ETA 5:50)
[300/1000] loss=0.4587 lr=8.15e-05 (2:19, ETA 5:24)
[350/1000] loss=0.4179 lr=7.54e-05 (2:41, ETA 4:59)
[400/1000] loss=0.3830 lr=6.89e-05 (3:03, ETA 4:35)
[450/1000] loss=0.3623 lr=6.20e-05 (3:25, ETA 4:11)
[500/1000] loss=0.3584 lr=5.50e-05 (3:47, ETA 3:47)
  [Eval] KD Loss: 0.1882 (best: 0.3142)
[550/1000] loss=0.3518 lr=4.80e-05 (4:16, ETA 3:29)
[600/1000] loss=0.3498 lr=4.11e-05 (4:38, ETA 3:05)
[650/1000] loss=0.3628 lr=3.46e-05 (5:00, ETA 2:41)
[700/1000] loss=

In [13]:
# ============================================================
# END-TO-END KD-QAT: TRAIN WEIGHTS (SCALES FROZEN)
# ============================================================
# Use hard label loss for stable weight training

from qat_lora import train_e2e, save_checkpoint, load_checkpoint, unfreeze_model_for_training

# Unfreeze for training (clear any cached weights)
unfreeze_model_for_training(model)

print('E2E weight training with hard label loss...')
print(f'Hard label: top1={HARD_TOP1_WEIGHT}, full={HARD_FULL_WEIGHT}')

# Train weights (scales frozen)
e2e_weights_result = train_e2e(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    max_steps=1000,
    batch_size=64 if torch.cuda.is_available() else 32,
    lr=1e-5,
    temperature=DISTILL_TEMP,
    train_weights=True,
    train_scales=False,
    hard_top1_weight=0.2,  # Helps prevent divergence
    hard_full_weight=0.0,
    logging_steps=50,
    eval_steps=500,
    verbose=True,
    train_mlp_only=True,
    use_cosine_schedule=True,
    warmup_steps=100,
)

E2E weight training with hard label loss...
Hard label: top1=0.2, full=5e-05
=== End-to-End KD-QAT ===
Mode: weights (MLP only)
Trainable params: 264,241,152
Frozen attention params: 177,307,648
Steps: 1000, LR: 1e-05, Batch: 64
Hard label: top1=0.2, full=0.0
LR Schedule: warmup=100, cosine→1.00e-06

Initial KD Loss: 0.1596
[50/1000] loss=0.2883 lr=5.00e-06 (0:25, ETA 8:08)
[100/1000] loss=0.2848 lr=1.00e-05 (0:46, ETA 6:57)
[150/1000] loss=0.2875 lr=9.93e-06 (1:07, ETA 6:20)
[200/1000] loss=0.2845 lr=9.73e-06 (1:27, ETA 5:51)
[250/1000] loss=0.2866 lr=9.40e-06 (1:49, ETA 5:27)
[300/1000] loss=0.2947 lr=8.95e-06 (2:09, ETA 5:03)
[350/1000] loss=0.2884 lr=8.39e-06 (2:30, ETA 4:39)
[400/1000] loss=0.2872 lr=7.75e-06 (2:51, ETA 4:16)
[450/1000] loss=0.2792 lr=7.04e-06 (3:12, ETA 3:54)
[500/1000] loss=0.2810 lr=6.28e-06 (3:32, ETA 3:32)
  [Eval] KD Loss: 0.1494 (best: 0.1596)
[550/1000] loss=0.2802 lr=5.50e-06 (3:59, ETA 3:15)
[600/1000] loss=0.2783 lr=4.72e-06 (4:20, ETA 2:53)
[650/1000] 

# END-TO-END -Attention: GENTLE MLP+ATT TRAIN SCALES (WEIGHTS FROZEN)


In [None]:
# ============================================================
# END-TO-END KD-QAT: GENTLE MLP+ATT TRAIN SCALES (WEIGHTS FROZEN)
# ============================================================
# Scale training doesn't need hard label loss
unfreeze_model_for_training(model)

# Freeze MLP scales, only train attention scales
for name, module in model.named_modules():
    if type(module).__name__ == 'AnemllQATLinear':
        is_attn = any(x in name for x in ['q_proj', 'k_proj', 'v_proj', 'o_proj'])
        if hasattr(module, 'scale_A') and module.scale_A is not None:
            module.scale_A.requires_grad = is_attn
            module.scale_B.requires_grad = is_attn
        module.weight.requires_grad = False

# Train scales (weights frozen) - higher LR since fewer params
e2e_scales_result = train_e2e(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    max_steps=2000,
    batch_size=64 if torch.cuda.is_available() else 32,
    lr=1e-4,  # Higher LR for scales
    use_cosine_schedule=True,
    warmup_steps=100,          # Linear warmup
    min_lr_ratio=0.1,      # End at 5e-5
    temperature=DISTILL_TEMP,
    train_weights=False,
    train_scales=True,
    hard_top1_weight=0.0,  # Not needed for scale training
    hard_full_weight=0.0,
    logging_steps=20,
    eval_steps=100,
    verbose=True,
    train_mlp_only=False,  # ← Freeze attention (4-bit), train MLP (2-bit) only
)

=== End-to-End KD-QAT ===
Mode: scales
Trainable params: 13,303,808
Steps: 2000, LR: 0.0001, Batch: 64
LR Schedule: warmup=100, cosine→1.00e-05

Initial KD Loss: 0.4702


KeyboardInterrupt: 

In [16]:
# ============================================================
# SAVE FINAL CHECKPOINT
# ============================================================

unfreeze_model_for_training(model)
E2E_RUN_NAME = f'anemll_{QUAL}_e2e_v2_scales_only'
E2E_SAVE_DIR = f'{LOCAL_RUNS}/{E2E_RUN_NAME}'

# Save with config
config = {
    'model_id': MODEL_ID,
    'lut_size': LUT_SIZE,
    #'group_size': GROUP_SIZE,
    'scale_rank': SCALE_RANK,
    'attn_lut_size': ATTN_LUT_SIZE,
    'attn_group_size': ATTN_GROUP_SIZE,
    'attn_scale_rank': ATTN_SCALE_RANK,
    #'e2e_weights_result': e2e_weights_result,
    'e2e_scales_result': e2e_scales_result,
}

save_checkpoint(model, E2E_SAVE_DIR, config=config)

# Upload to Google Drive
!tar -czvf {E2E_RUN_NAME}.tgz -C {LOCAL_RUNS} {E2E_RUN_NAME}
!cp {E2E_RUN_NAME}.tgz {GD_RUNS}/
print(f'\nUploaded to {GD_RUNS}/{E2E_RUN_NAME}.tgz')

Saved checkpoint to runs/anemll_q4_a4_e2e_v2_scales_only/
  - model_state_dict.pt
  - indices.pt (196 layers, 420.0 MB)
  - config.json
anemll_q4_a4_e2e_v2_scales_only/
anemll_q4_a4_e2e_v2_scales_only/config.json
anemll_q4_a4_e2e_v2_scales_only/indices.pt
anemll_q4_a4_e2e_v2_scales_only/model_state_dict.pt

Uploaded to /content/drive/MyDrive/qwen3_runs/anemll_q4_a4_e2e_v2_scales_only.tgz


# **INFERENCE OPTIMIZATION**

Before running inference, freeze all layers to precompute quantized weights.
This avoids recomputing `LUT[indices] * (scale_A @ scale_B)` on every forward pass.

In [18]:
# ============================================================
# FREEZE MODEL FOR FAST INFERENCE
# ============================================================
# Precompute quantized weights once for all layers
# This caches LUT[idx] * scale to avoid recomputation per token

from qat_lora import freeze_model_for_inference, unfreeze_model_for_training


print('Freezing model for inference...')
num_frozen = freeze_model_for_inference(model, verbose=False)
print(f'Frozen {num_frozen} layers')

# To resume training later:
# unfreeze_model_for_training(model)

Freezing model for inference...
Frozen 196 layers


In [19]:
import torch

# ============================================================
# TEST INFERENCE
# ============================================================

def run_inference(model, tokenizer, prompt, max_new_tokens=128):
    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': prompt}
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors='pt').to(DEVICE)

    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

    return tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

# List of prompts to test
prompts = [
    'What is the capital of France?',
    'What is Apple Neural Engine?',
    'Explain quantum mechanics',
    'What is speed of light'
]

model.eval() # Set model to evaluation mode once

for prompt in prompts:
    response = run_inference(model, tokenizer, prompt,max_new_tokens=1024)
    print(f'Prompt: {prompt}')
    print(f'Response: {response}')
    print('-' * 50) # Separator for readability


The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Prompt: What is the capital of France?
Response: <think>
<think>
</think>

The capital of France is **Paris**.
--------------------------------------------------
Prompt: What is Apple Neural Engine?
Response: <think>
<think>
</think>

The **Apple Neural Engine** is a powerful computing platform developed by Apple Inc. It is designed to run on Apple devices and is used for various applications, including AI and machine learning. It is a key component of Apple's ecosystem and is known for its performance and efficiency in handling complex tasks.
--------------------------------------------------
Prompt: Explain quantum mechanics
Response: <think>
<think>
</think>

Quantum mechanics is a fundamental theory of physics that describes the behavior of particles at the smallest scales, such as atoms, molecules, and even subatomic particles. It is a revolutionary theory that challenges our classical understanding of the physical world, which is based on the principles of classical mechanics and

In [None]:
# ============================================================
# TEST INFERENCE
# ============================================================

def run_inference(model, tokenizer, prompt, max_new_tokens=512):
    messages = [
        {'role': 'user', 'content': prompt}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=True
    )
    inputs = tokenizer(text, return_tensors='pt').to(DEVICE)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.1,
        )

    return tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=False)

# List of prompts to test
prompts = [
    'What is the capital of France?',
    'What is Apple Neural Engine?',
    'Explain quantum mechanics',
    'What is speed of light'
]

model.eval()

for prompt in prompts:
    response = run_inference(model, tokenizer, prompt, max_new_tokens=512)
    print(f'Prompt: {prompt}')
    print(f'Response: {response}')
    print('-' * 50)


Prompt: What is the capital of France?
Response: <think>
<think>
</think>

The capital of France is **Paris**, located in the western part of the country.<|im_end|>
--------------------------------------------------
Prompt: What is Apple Neural Engine?
Response: <think>
<think>
</think>

The **Apple Neural Engine** is a key component of Apple's operating system, specifically the **iOS** operating system. It is developed by Apple and is part of the iOS ecosystem. The neural engine is responsible for handling various tasks such as image processing, speech recognition, and machine learning, which are essential for applications like apps, games, and other services. It plays a crucial role in enabling the development and performance of Apple products.<|im_end|>
--------------------------------------------------
Prompt: Explain quantum mechanics
Response: <think>
<think>
</think>

Quantum Mechanics is a fundamental field of physics that describes the behavior of particles at the smallest pos

## Next Steps

After layer-by-layer training, you can:

1. **End-to-end refinement** - Unfreeze all layers and train together
2. **Train scales (A, B)** - Unfreeze scale_A, scale_B parameters
3. **LoRA recovery** - Add LoRA adapters to recover quality

# **EXPORT FOR ANEMLL CONVERTER**

Snap weights to quantized values and export for external tools.

Two export modes:
- `store_lut_values=True`: weights = LUT[idx] (normalized in [-1,1]), scales separate
- `store_lut_values=False`: weights = LUT[idx] * scale (full dequant)

In [None]:
# ============================================================
# SNAP WEIGHTS AND EXPORT
# ============================================================
# Snap weights to LUT[idx] values for ANEMLL converter

from qat_lora import snap_all_weights, export_quantized_model, unfreeze_model_for_training

# First unfreeze to clear cached weights
unfreeze_model_for_training(model)

# Export quantized representation BEFORE snapping (keeps original weights)
print('Exporting quantized model representation...')
export_dict = export_quantized_model(model, verbose=True)

# Save export for ANEMLL converter
EXPORT_DIR = f'{LOCAL_RUNS}/{E2E_RUN_NAME}_export'
os.makedirs(EXPORT_DIR, exist_ok=True)
torch.save(export_dict, f'{EXPORT_DIR}/quantized_model.pt')
print(f'\nSaved export to {EXPORT_DIR}/quantized_model.pt')

# Each layer in export_dict contains:
# - indices: [out, in] uint8 LUT indices
# - quantized_weights: [out, in] LUT[idx] values in [-1, 1]
# - scales: {'scale_A': [out, rank], 'scale_B': [rank, in]} or full [out, in]
# - lut: [lut_size] values
# - bias, in_features, out_features, etc.

In [None]:
# ============================================================
# SNAP WEIGHTS TO FULL DEQUANT AND TEST
# ============================================================
# Snap weights = LUT[idx] * scale, then disable fake_quant for direct use

print('Snapping weights to full dequantized values (LUT[idx] * scale)...')
indices = snap_all_weights(model, store_lut_values=False, verbose=True)

# Disable fake quantization - use snapped weights directly
for name, module in model.named_modules():
    if type(module).__name__ == 'AnemllQATLinear':
        module.enable_fake_quant = False

print('\nTesting inference with snapped weights...')
model.eval()

# Quick test
response = run_inference(model, tokenizer, 'What is 2+2?', max_new_tokens=1024)
print(f'Prompt: What is 2+2?')
print(f'Response: {response}')

In [None]:
torch.save(model.state_dict(), '/tmp/backup_mlp_e2e_0.4613.pt')  # Local, fast

torch.save(model.state_dict(), '/tmp/backup_mlp_e2e_w_0.3824.pt')  # Local, fast

In [15]:
torch.save(model.state_dict(), '/tmp/backup_mlp_e4e_4_4.pt')  # Local, fast

In [None]:
model.load_state_dict(torch.load('/tmp/backup_initial.pt', map_location=DEVICE))

<All keys matched successfully>

In [None]:
model.load_state_dict(torch.load('/tmp/backup_mlp_e2e_w_0.3824.pt', map_location=DEVICE))

<All keys matched successfully>