# Anemll-Style Layer-by-Layer QAT

This notebook implements layer-by-layer QAT training using `AnemllQATLinear` with:
- Groupwise LUT quantization
- Low-rank scale factors (A @ B)
- KD cache for distillation

## Pipeline:
1. Load model and replace linears with AnemllQATLinear
2. Layer-by-layer QAT (freeze all but current layer)
3. End-to-end refinement
4. (Optional) LoRA recovery

In [2]:
# ============================================================
# GOOGLE DRIVE PATHS (STANDARD)
# ============================================================

# Checkpoints/runs go here
GD_RUNS = '/content/drive/MyDrive/qwen3_runs'

# KD caches go here
GD_CACHES = '/content/drive/MyDrive/qwen3_caches'

# Local directories (on Colab VM)
LOCAL_RUNS = 'runs'
LOCAL_CACHES = 'caches'

In [3]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### GITUB

In [4]:
# Clone repo if needed
!git clone https://github.com/anemll/qwen3_apple_style_2bit_qat_lora.git || (cd qwen3_apple_style_2bit_qat_lora && git pull)
%cd qwen3_apple_style_2bit_qat_lora
# to allow updates
!git fetch
!git pull
# Install dependencies

Cloning into 'qwen3_apple_style_2bit_qat_lora'...
remote: Enumerating objects: 346, done.[K
remote: Counting objects: 100% (36/36), done.[K
remote: Compressing objects: 100% (22/22), done.[K
remote: Total 346 (delta 20), reused 28 (delta 14), pack-reused 310 (from 1)[K
Receiving objects: 100% (346/346), 438.22 KiB | 950.00 KiB/s, done.
Resolving deltas: 100% (211/211), done.
/content/qwen3_apple_style_2bit_qat_lora
Already up to date.


In [5]:
# Install dependencies
!pip install -q transformers accelerate safetensors

In [6]:
# ============================================================
# LOAD KD CACHE FROM GOOGLE DRIVE
# ============================================================

CACHE_NAME = 'alpaca_chat_think_both_L128_K32_R256'
CACHE_TGZ = f'{CACHE_NAME}.tgz'

!mkdir -p {LOCAL_CACHES}

# Check if cache exists locally
import os
cache_local_path = f'{LOCAL_CACHES}/{CACHE_NAME}'
if not os.path.exists(cache_local_path):
    print(f'Extracting {CACHE_TGZ} from Google Drive...')
    !tar -xzf {GD_CACHES}/{CACHE_TGZ} -C {LOCAL_CACHES}/
else:
    print(f'Cache already exists at {cache_local_path}')

!ls -la {cache_local_path}/ | head -10

Extracting alpaca_chat_think_both_L128_K32_R256.tgz from Google Drive...
total 4298960
drwx------ 2 root root      4096 Dec 18 00:00 .
drwxr-xr-x 3 root root      4096 Dec 26 05:59 ..
-rw------- 1 root root       421 Dec 18 00:15 meta.json
-rw------- 1 root root 112692165 Dec 18 00:15 shard_00000.pt
-rw------- 1 root root 112692165 Dec 18 00:15 shard_00001.pt
-rw------- 1 root root 112692165 Dec 18 00:15 shard_00002.pt
-rw------- 1 root root 112692165 Dec 18 00:15 shard_00003.pt
-rw------- 1 root root 112692165 Dec 18 00:16 shard_00004.pt
-rw------- 1 root root 112692165 Dec 18 00:16 shard_00005.pt


In [9]:
# ============================================================
# CONFIGURATION
# ============================================================

import torch

# Model
MODEL_ID = 'Qwen/Qwen3-0.6B'

# Quantization config (4-bit with groupwise LUT)
LUT_SIZE = 16        # 4-bit = 16 levels
GROUP_SIZE = 32      # Group size for scales
SCALE_RANK = 4       # Low-rank for A @ B scales

# Attention quantization (same params)
ATTN_LUT_SIZE = 16
ATTN_GROUP_SIZE = 32
ATTN_SCALE_RANK = 8

# Training
BATCH_SIZE = 64
if torch.cuda.is_available():
    BATCH_SIZE=64
GRAD_ACCUM = 4
LR = 2e-5
EPOCHS_PER_LAYER = 1

# KD params
DISTILL_TEMP = 2.0

# Device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DTYPE = torch.bfloat16

print(f'Device: {DEVICE}, dtype: {DTYPE}')
print(f'Quant config: lut={LUT_SIZE}, group={GROUP_SIZE}, rank={SCALE_RANK}')

Device: cuda, dtype: torch.bfloat16
Quant config: lut=16, group=32, rank=4


In [10]:
# ============================================================
# LOAD MODEL
# ============================================================

from transformers import AutoModelForCausalLM, AutoTokenizer

print(f'Loading {MODEL_ID}...')
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    trust_remote_code=True,
)
model.to(DEVICE)
model.eval()
print(f'Loaded. Parameters: {sum(p.numel() for p in model.parameters()):,}')

Loading Qwen/Qwen3-0.6B...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/726 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/1.50G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Loaded. Parameters: 596,049,920


In [11]:
# ============================================================
# REPLACE LINEARS WITH AnemllQATLinear
# ============================================================

import sys
sys.path.insert(0, '.')

# Force reimport to get latest code
import importlib
import qat_lora
importlib.reload(qat_lora)
import qat_lora.ane_qat_linear as ane_module
importlib.reload(ane_module)
import qat_lora.layer_qat as layer_module
importlib.reload(layer_module)

from qat_lora import AnemllQuantConfig, replace_linear_with_anemll

# Debug: Check what modules exist in the model
print("Checking model structure...")
import torch.nn as nn
linear_count = 0
for name, m in model.named_modules():
    if isinstance(m, nn.Linear):
        linear_count += 1
        if linear_count <= 5:
            print(f"  Found Linear: {name}")
print(f"Total Linear modules: {linear_count}")

# Create configs
mlp_config = AnemllQuantConfig(
    lut_size=LUT_SIZE,
    group_size=GROUP_SIZE,
    scale_rank=SCALE_RANK,
    learnable_lut=False,
)

attn_config = AnemllQuantConfig(
    lut_size=ATTN_LUT_SIZE,
    group_size=ATTN_GROUP_SIZE,
    scale_rank=ATTN_SCALE_RANK,
    learnable_lut=False,
)

print('\nReplacing linear layers...')
count = replace_linear_with_anemll(
    model,
    mlp_config=mlp_config,
    attn_config=attn_config,
    quantize_attn=True,
    quantize_lm_head=False,
)

# Verify replacement worked
from qat_lora import AnemllQATLinear
qat_count = sum(1 for _, m in model.named_modules() if isinstance(m, AnemllQATLinear))
print(f"\nVerification: {qat_count} AnemllQATLinear modules in model")

Checking model structure...
  Found Linear: model.layers.0.self_attn.q_proj
  Found Linear: model.layers.0.self_attn.k_proj
  Found Linear: model.layers.0.self_attn.v_proj
  Found Linear: model.layers.0.self_attn.o_proj
  Found Linear: model.layers.0.mlp.gate_proj
Total Linear modules: 197

Replacing linear layers...
  [replaced] model.layers.0.self_attn.q_proj
  [replaced] model.layers.0.self_attn.k_proj
  [replaced] model.layers.0.self_attn.v_proj
  [replaced] model.layers.0.self_attn.o_proj
  [replaced] model.layers.0.mlp.gate_proj
  [replaced] model.layers.0.mlp.up_proj
  [replaced] model.layers.0.mlp.down_proj
  [replaced] model.layers.1.self_attn.q_proj
  [replaced] model.layers.1.self_attn.k_proj
  [replaced] model.layers.1.self_attn.v_proj
  [replaced] model.layers.1.self_attn.o_proj
  [replaced] model.layers.1.mlp.gate_proj
  [replaced] model.layers.1.mlp.up_proj
  [replaced] model.layers.1.mlp.down_proj
  [replaced] model.layers.2.self_attn.q_proj
  [replaced] model.layers.2.

In [14]:
# ============================================================
# IMPORT LAYER-BY-LAYER QAT UTILITIES & VERIFY GRADIENTS
# ============================================================

from qat_lora import (
    evaluate_kd_loss,
    train_all_layers,
    AnemllQATLinear,
)

print('Layer QAT utilities imported from qat_lora')

# Verify gradient flow works
print('\nVerifying gradient flow...')
layer0 = model.model.layers[0]
test_module = None
for name, m in layer0.named_modules():
    if isinstance(m, AnemllQATLinear):
        test_module = m
        break

if test_module is None:
    print("ERROR: No AnemllQATLinear modules found! Replacement failed.")
else:
    # Test gradient flow
    test_module.weight.requires_grad = True
    x = torch.randn(1, 10, test_module.in_features, device=DEVICE, dtype=DTYPE)
    y = test_module(x)
    loss = y.sum()
    try:
        loss.backward()
        if test_module.weight.grad is not None:
            print(f"  Gradient OK: weight.grad.shape = {test_module.weight.grad.shape}")
            test_module.weight.grad = None  # Clear for actual training
        else:
            print("  ERROR: weight.grad is None after backward!")
    except Exception as e:
        print(f"  ERROR during backward: {e}")

# Compute initial KD loss
print('\nComputing initial KD loss...')
initial_loss = evaluate_kd_loss(model, cache_local_path, DEVICE, num_samples=40, temperature=DISTILL_TEMP)
print(f'Initial KD Loss: {initial_loss:.4f}')

Layer QAT utilities imported from qat_lora

Verifying gradient flow...
ERROR: No AnemllQATLinear modules found! Replacement failed.

Computing initial KD loss...
Initial KD Loss: 141.2995


In [15]:
# ============================================================
# RUN LAYER-BY-LAYER TRAINING
# ============================================================

import os
from pathlib import Path

# Verify drive is mounted and cache exists
if not os.path.exists('/content/drive/MyDrive'):
    print('Google Drive not mounted! Mounting now...')
    from google.colab import drive
    drive.mount('/content/drive')

if not os.path.exists(cache_local_path):
    print(f'Cache not found at {cache_local_path}')
    print(f'Extracting from Google Drive...')
    os.makedirs(LOCAL_CACHES, exist_ok=True)
    !tar -xzf {GD_CACHES}/{CACHE_TGZ} -C {LOCAL_CACHES}/

# Verify cache exists now
assert os.path.exists(cache_local_path), f'Cache still not found at {cache_local_path}'
cache_files = list(Path(cache_local_path).glob('*.pt'))
print(f'Cache ready: {len(cache_files)} files in {cache_local_path}')

Cache ready: 40 files in caches/alpaca_chat_think_both_L128_K32_R256


# **RUN** LAYER-BY-LAYER TRAINING

In [None]:
# Train all layers using the imported function
layer_losses = train_all_layers(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    batch_size=BATCH_SIZE,
    lr=LR,
    epochs_per_layer=EPOCHS_PER_LAYER,
    grad_accum=GRAD_ACCUM,
    temperature=DISTILL_TEMP,
    train_scales=False,  # Keep scales frozen for now
    verbose=True,
)

Training 28 layers...
Cache: caches/alpaca_chat_think_both_L128_K32_R256
Batch size: 64, Grad accum: 4
LR: 2e-05, Epochs per layer: 1

[Initial Global KD Loss]: 141.2995

=== Layer 0 === (15,728,640 trainable params)
  [Global KD Loss BEFORE]: 134.4200
  Step 10, Train Loss: 113.8174 (13.5s)
  Step 20, Train Loss: 106.2101 (26.6s)
  Step 30, Train Loss: 101.3711 (39.6s)
  Step 40, Train Loss: 98.3777 (52.7s)
  Step 50, Train Loss: 96.0273 (65.8s)


In [None]:
# ============================================================
# EVALUATE AFTER LAYER-BY-LAYER
# ============================================================

model.eval()
post_layer_loss = evaluate_kd_loss(model, cache_local_path, DEVICE, num_samples=40)
print(f'Initial KD Loss: {initial_loss:.4f}')
print(f'After Layer-by-Layer: {post_layer_loss:.4f}')
print(f'Improvement: {initial_loss - post_layer_loss:.4f}')

In [None]:
# ============================================================
# SAVE CHECKPOINT
# ============================================================

import os

RUN_NAME = 'anemll_q4_layer_by_layer_v1'
SAVE_DIR = f'{LOCAL_RUNS}/{RUN_NAME}'

os.makedirs(SAVE_DIR, exist_ok=True)

# Save state dict
torch.save(model.state_dict(), f'{SAVE_DIR}/model_state_dict.pt')

# Save config
import json
config = {
    'model_id': MODEL_ID,
    'lut_size': LUT_SIZE,
    'group_size': GROUP_SIZE,
    'scale_rank': SCALE_RANK,
    'attn_lut_size': ATTN_LUT_SIZE,
    'attn_group_size': ATTN_GROUP_SIZE,
    'attn_scale_rank': ATTN_SCALE_RANK,
    'initial_kd_loss': initial_loss,
    'post_layer_loss': post_layer_loss,
    'layer_losses': layer_losses,
}
with open(f'{SAVE_DIR}/config.json', 'w') as f:
    json.dump(config, f, indent=2)

print(f'Saved to {SAVE_DIR}')

In [None]:
# ============================================================
# UPLOAD TO GOOGLE DRIVE
# ============================================================

!tar -czvf {RUN_NAME}.tgz -C {LOCAL_RUNS} {RUN_NAME}
!cp {RUN_NAME}.tgz {GD_RUNS}/
print(f'Uploaded to {GD_RUNS}/{RUN_NAME}.tgz')

In [None]:
# ============================================================
# TEST INFERENCE
# ============================================================

def run_inference(model, tokenizer, prompt, max_new_tokens=128):
    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': prompt}
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors='pt').to(DEVICE)

    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

    return tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

# Test
model.eval()
prompt = 'What is the capital of France?'
response = run_inference(model, tokenizer, prompt)
print(f'Prompt: {prompt}')
print(f'Response: {response}')

## Next Steps

After layer-by-layer training, you can:

1. **End-to-end refinement** - Unfreeze all layers and train together
2. **Train scales (A, B)** - Unfreeze scale_A, scale_B parameters
3. **LoRA recovery** - Add LoRA adapters to recover quality