# Anemll-Style Layer-by-Layer QAT

This notebook implements layer-by-layer QAT training using `AnemllQATLinear` with:
- Groupwise LUT quantization
- Low-rank scale factors (A @ B), kept frozen during the layer-by-layer stage in this notebook
- KD cache for distillation

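## Pipeline:
1. Load model and replace linears with AnemllQATLinear
2. Layer-by-layer QAT (freeze all but current layer)
3. End-to-end refinement (not yet implemented in this notebook; see Next Steps)
4. (Optional) LoRA recovery (not yet implemented in this notebook; see Next Steps)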

In [None]:
# ============================================================
# GOOGLE DRIVE PATHS (STANDARD)
# ============================================================

# Checkpoints/runs go here (the upload cell copies <RUN_NAME>.tgz into this folder)
GD_RUNS = '/content/drive/MyDrive/qwen3_runs'

# KD caches go here (the cache cell extracts <CACHE_NAME>.tgz from this folder)
GD_CACHES = '/content/drive/MyDrive/qwen3_caches'

# Local directories (on Colab VM). These are relative paths, resolved against
# the working directory when used (the repo checkout after the clone cell's %cd).
LOCAL_RUNS = 'runs'
LOCAL_CACHES = 'caches'

In [None]:
# Mount Google Drive at /content/drive (GD_RUNS and GD_CACHES live under it).
# Colab-only: the google.colab module is not available outside Colab.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone repo if needed; if the checkout already exists, `git clone` fails and
# the `||` branch pulls inside the existing checkout instead.
!git clone https://github.com/anemll/qwen3_apple_style_2bit_qat_lora.git || (cd qwen3_apple_style_2bit_qat_lora && git pull)
%cd qwen3_apple_style_2bit_qat_lora
# Fetch/pull again so re-running this cell picks up upstream changes
# (a no-op right after a fresh clone).
!git fetch
!git pull
# (Dependencies are installed in the next cell.)

In [None]:
# Install dependencies (quiet mode). torch itself is not installed here;
# presumably the Colab runtime's preinstalled torch is used (confirm elsewhere).
!pip install -q transformers accelerate safetensors

In [None]:
# ============================================================
# LOAD KD CACHE FROM GOOGLE DRIVE
# ============================================================

CACHE_NAME = 'alpaca_chat_think_both_L128_K32_R256'
CACHE_TGZ = f'{CACHE_NAME}.tgz'

!mkdir -p {LOCAL_CACHES}

# Check if cache exists locally
import os
cache_local_path = f'{LOCAL_CACHES}/{CACHE_NAME}'
if not os.path.exists(cache_local_path):
    print(f'Extracting {CACHE_TGZ} from Google Drive...')
    !tar -xzf {GD_CACHES}/{CACHE_TGZ} -C {LOCAL_CACHES}/
else:
    print(f'Cache already exists at {cache_local_path}')

!ls -la {cache_local_path}/ | head -10

In [None]:
# ============================================================
# CONFIGURATION
# ============================================================

import torch

# Model
MODEL_ID = 'Qwen/Qwen3-0.6B'

# Quantization config for the MLP projections (4-bit with groupwise LUT)
LUT_SIZE = 16        # 4-bit = 16 levels
GROUP_SIZE = 32      # Group size for scales
SCALE_RANK = 4       # Low-rank for A @ B scales

# Attention quantization: same LUT size and group size as the MLP,
# but a higher scale rank (8 vs 4)
ATTN_LUT_SIZE = 16
ATTN_GROUP_SIZE = 32
ATTN_SCALE_RANK = 8

# Training
BATCH_SIZE = 4       # cache samples per micro-batch
GRAD_ACCUM = 4       # micro-batches per optimizer step
LR = 2e-5
EPOCHS_PER_LAYER = 1  # passes over the KD cache for each layer

# KD params
DISTILL_TEMP = 2.0   # softmax temperature; the KD loss is scaled by T**2

# Device
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DTYPE = torch.bfloat16

print(f'Device: {DEVICE}, dtype: {DTYPE}')
print(f'Quant config: lut={LUT_SIZE}, group={GROUP_SIZE}, rank={SCALE_RANK}')

In [None]:
# ============================================================
# LOAD MODEL
# ============================================================

from transformers import AutoModelForCausalLM, AutoTokenizer

print(f'Loading {MODEL_ID}...')
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Load in the configured dtype, move to the target device and switch to eval
# mode; nn.Module.to() and .eval() return the module itself, so they chain.
model = (
    AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        trust_remote_code=True,
    )
    .to(DEVICE)
    .eval()
)
print(f'Loaded. Parameters: {sum(map(torch.numel, model.parameters())):,}')

In [None]:
# ============================================================
# REPLACE LINEARS WITH AnemllQATLinear
# ============================================================

import sys
sys.path.insert(0, '.')

from qat_lora.ane_qat_linear import AnemllQATLinear, AnemllQuantConfig
import re
import torch.nn as nn

def replace_linear_with_anemll(
    model: nn.Module,
    mlp_config: AnemllQuantConfig,
    attn_config: AnemllQuantConfig = None,
    quantize_attn: bool = True,
):
    """Replace MLP and optionally attention linears with AnemllQATLinear.

    A module is replaced when it is an ``nn.Linear`` (and not already an
    ``AnemllQATLinear``) whose qualified name ends in
    ``.mlp.{gate,up,down}_proj`` or ``.self_attn.{q,k,v,o}_proj``. The new
    module is built with ``AnemllQATLinear.from_linear`` and set on the parent
    module in place.

    Args:
        model: Model modified in place.
        mlp_config: Quantization config for the MLP projections.
        attn_config: Quantization config for the attention projections. If
            ``None``, attention projections are left untouched.
        quantize_attn: When False, attention projections are left untouched
            even if ``attn_config`` is given.

    Returns:
        Number of modules replaced.
    """
    mlp_pattern = re.compile(r'\.mlp\.(gate_proj|up_proj|down_proj)$')
    attn_pattern = re.compile(r'\.self_attn\.(q_proj|k_proj|v_proj|o_proj)$')
    use_attn = quantize_attn and attn_config is not None

    # Name -> module map built once, so each parent lookup below is O(1)
    # instead of re-walking the whole module tree per replaced layer.
    modules_by_name = dict(model.named_modules())

    # Collect first, apply afterwards: the module tree must not be mutated
    # while it is being iterated.
    replacements = []
    for name, module in modules_by_name.items():
        if not isinstance(module, nn.Linear) or isinstance(module, AnemllQATLinear):
            continue

        if mlp_pattern.search(name):
            cfg = mlp_config
        elif use_attn and attn_pattern.search(name):
            cfg = attn_config
        else:
            continue

        new_module = AnemllQATLinear.from_linear(module, config=cfg)

        # 'a.b.c' -> parent 'a.b', attribute 'c'; a top-level name has no parent.
        parent_name, _, attr = name.rpartition('.')
        parent = modules_by_name[parent_name] if parent_name else model
        replacements.append((parent, attr, new_module, name))

    for parent, attr, new_module, name in replacements:
        setattr(parent, attr, new_module)
        print(f'  [replaced] {name}')

    return len(replacements)

# Create configs: MLP and attention projections get separate quantization
# settings; the LUT is kept fixed (not learnable) for both.
mlp_config = AnemllQuantConfig(
    lut_size=LUT_SIZE, group_size=GROUP_SIZE, scale_rank=SCALE_RANK, learnable_lut=False,
)
attn_config = AnemllQuantConfig(
    lut_size=ATTN_LUT_SIZE, group_size=ATTN_GROUP_SIZE, scale_rank=ATTN_SCALE_RANK, learnable_lut=False,
)

print('Replacing linear layers...')
count = replace_linear_with_anemll(
    model, mlp_config, attn_config=attn_config, quantize_attn=True
)
print(f'\nReplaced {count} layers')

In [None]:
# ============================================================
# INITIAL KD LOSS (before training)
# ============================================================

import torch.nn.functional as F
from pathlib import Path

def compute_kd_loss_batch(model, batch, device, temperature=2.0):
    """Compute the top-k knowledge-distillation loss for one batch.

    Only the student logits for the teacher's cached top-k token ids are
    materialised: final hidden states are contracted against the matching rows
    of ``model.lm_head.weight`` instead of projecting onto the full vocabulary.

    Gradients flow through this function, so it is usable for training.
    Evaluation-only callers should wrap it in ``torch.no_grad()`` themselves.

    Args:
        model: Causal LM exposing ``model.model`` (called with ``input_ids`` /
            ``attention_mask`` and returning ``last_hidden_state``) and
            ``model.lm_head``.
        batch: Dict with ``input_ids``, optional ``attention_mask``, and the
            cached ``topk_idx`` / ``topk_logits`` indexed as [B, S, K].
        device: Device the batch tensors are moved to.
        temperature: Softmax temperature; the loss is scaled by ``T**2``.

    Returns:
        Scalar tensor: top-k KL(teacher || student) computed with
        ``reduction='batchmean'`` (summed over positions and top-k entries,
        divided by the batch size B).
    """
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch.get('attention_mask')
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    topk_idx = batch['topk_idx'].to(device).long()
    topk_logits = batch['topk_logits'].to(device).float()

    # Hidden states only (no full-vocab logits). The forward pass must NOT run
    # under torch.no_grad(): that would detach the loss from the trainable
    # weights and make loss.backward() in the training loop fail.
    out = model.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        use_cache=False,
        return_dict=True,
    )
    hidden = out.last_hidden_state[:, :-1, :]  # last position dropped -> [B, S, H]
    B, S, H = hidden.shape

    # Only compute logits for top-k
    K = topk_idx.size(-1)
    seq_len = min(S, topk_idx.size(1))

    h = hidden[:, :seq_len, :].reshape(B * seq_len, H)
    idx = topk_idx[:, :seq_len, :].reshape(B * seq_len, K)

    w = model.lm_head.weight[idx]  # [N, K, H]
    student_topk = torch.einsum('nh,nkh->nk', h, w).view(B, seq_len, K)
    # Softmax/KL in fp32: the model may run in bf16, where log_softmax over the
    # student logits loses precision; the teacher logits are already fp32.
    student_topk = student_topk.float()

    # KL divergence with temperature.
    # NOTE(review): attention_mask is not applied to the loss, so padded
    # positions (if the cache contains any) contribute to it — confirm the
    # cache is unpadded.
    t_logits = topk_logits[:, :seq_len, :]
    teacher_probs = F.softmax(t_logits / temperature, dim=-1)
    student_log_probs = F.log_softmax(student_topk / temperature, dim=-1)
    kl = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (temperature ** 2)

    return kl

def evaluate_kd_loss(model, cache_dir, device, num_samples=40, temperature=2.0):
    """Average the KD loss over the first ``num_samples`` cache files.

    Files are ``*.pt`` entries of ``cache_dir`` in sorted order. Each file is
    scored on its own (batch size 1) under ``torch.no_grad()`` with the model
    in eval mode. Returns 0.0 when no files are found.
    """
    def _with_batch_dim(tensor, unbatched_rank):
        # Cache files may store one sample without the leading batch axis.
        if tensor.dim() == unbatched_rank:
            return tensor.unsqueeze(0)
        return tensor

    sample_files = sorted(Path(cache_dir).glob('*.pt'))[:num_samples]
    losses = []

    model.eval()
    with torch.no_grad():
        for sample_file in sample_files:
            record = torch.load(sample_file, map_location='cpu', weights_only=True)
            mask = record.get('attention_mask')
            batch = {
                'input_ids': _with_batch_dim(record['input_ids'], 1),
                'attention_mask': None if mask is None else _with_batch_dim(mask, 1),
                'topk_idx': _with_batch_dim(record['topk_idx'], 2),
                'topk_logits': _with_batch_dim(record['topk_logits'], 2),
            }
            losses.append(compute_kd_loss_batch(model, batch, device, temperature).item())

    return sum(losses) / max(1, len(losses))

# Baseline KD loss of the quantized-but-untrained model on the first 40 cache
# files; the post-training evaluation cell compares against this value.
print('Computing initial KD loss...')
initial_loss = evaluate_kd_loss(model, cache_local_path, DEVICE, num_samples=40)
print(f'Initial KD Loss: {initial_loss:.4f}')

In [None]:
# ============================================================
# LAYER-BY-LAYER QAT TRAINING
# ============================================================

from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
import time

class KDCacheDataset(Dataset):
    """Map-style dataset over the ``*.pt`` files of a KD cache directory.

    Each item is one cache file loaded onto the CPU, returned as a dict with
    ``input_ids``, ``attention_mask`` (all ones when the file has none),
    ``topk_idx`` and ``topk_logits``.
    """

    def __init__(self, cache_dir):
        # Sorted so the index -> file mapping is deterministic.
        self.files = sorted(Path(cache_dir).glob('*.pt'))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        record = torch.load(self.files[idx], map_location='cpu', weights_only=True)
        input_ids = record['input_ids']
        if 'attention_mask' in record:
            mask = record['attention_mask']
        else:
            mask = torch.ones_like(input_ids)
        return {
            'input_ids': input_ids,
            'attention_mask': mask,
            'topk_idx': record['topk_idx'],
            'topk_logits': record['topk_logits'],
        }

def collate_fn(batch):
    """Stack a list of per-sample dicts into one dict of batched tensors.

    All samples must have equal-shaped tensors per key (no padding is done).
    """
    keys = ('input_ids', 'attention_mask', 'topk_idx', 'topk_logits')
    return {key: torch.stack([sample[key] for sample in batch]) for key in keys}

def get_layer_modules(model, layer_idx):
    """Return ``(qualified_name, module)`` pairs for every AnemllQATLinear in
    decoder layer ``layer_idx`` (names are prefixed with ``layers.<idx>.``)."""
    prefix = f'layers.{layer_idx}.'
    decoder_layer = model.model.layers[layer_idx]
    return [
        (prefix + sub_name, sub_module)
        for sub_name, sub_module in decoder_layer.named_modules()
        if isinstance(sub_module, AnemllQATLinear)
    ]

def freeze_all_except_layer(model, layer_idx):
    """Make only the main weights of one layer's AnemllQATLinear modules trainable.

    Every parameter of ``model`` is frozen first. Then ``weight`` is re-enabled
    for each AnemllQATLinear in ``model.model.layers[layer_idx]``. Their
    ``scale_A`` / ``scale_B`` (when not None) are explicitly kept frozen.

    Returns:
        Number of weight elements that were made trainable.
    """
    for param in model.parameters():
        param.requires_grad = False

    trainable = 0
    for _, qlinear in get_layer_modules(model, layer_idx):
        qlinear.weight.requires_grad = True
        trainable += qlinear.weight.numel()
        # Low-rank scale factors stay frozen during the layer-by-layer stage.
        for scale_attr in ('scale_A', 'scale_B'):
            scale = getattr(qlinear, scale_attr)
            if scale is not None:
                scale.requires_grad = False

    return trainable

def train_layer(model, layer_idx, dataloader, device, lr=2e-5, epochs=1, grad_accum=4,
                cache_dir=None, temperature=None):
    """Train a single decoder layer with top-k KD, then evaluate it.

    Only the trainable parameters selected by ``freeze_all_except_layer`` are
    optimised (AdamW, fresh optimizer per call). Gradients are accumulated over
    ``grad_accum`` micro-batches per optimizer step. A trailing partial window
    at the end of an epoch is also stepped, normalised by its real size, so no
    gradient is left unapplied.

    Args:
        model: Model being trained (modified in place).
        layer_idx: Index into ``model.model.layers``.
        dataloader: Yields batches accepted by ``compute_kd_loss_batch``.
        device: Device the batches are moved to.
        lr: AdamW learning rate.
        epochs: Passes over ``dataloader``.
        grad_accum: Micro-batches per optimizer step (values < 1 are treated as 1).
        cache_dir: KD cache used for the final evaluation; defaults to the
            notebook-level ``cache_local_path``.
        temperature: KD temperature; defaults to the notebook-level ``DISTILL_TEMP``.

    Returns:
        Eval KD loss on the first 20 cache files after training this layer.
    """
    if cache_dir is None:
        cache_dir = cache_local_path
    if temperature is None:
        temperature = DISTILL_TEMP
    grad_accum = max(1, int(grad_accum))

    trainable = freeze_all_except_layer(model, layer_idx)
    print(f'\n=== Layer {layer_idx} === ({trainable:,} trainable params)')

    # Get trainable params
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW(params, lr=lr)

    model.train()
    total_loss = 0.0  # sum of per-step mean micro-batch losses
    steps = 0
    num_batches = len(dataloader)

    for _ in range(epochs):
        optimizer.zero_grad(set_to_none=True)
        window_loss = 0.0
        for i, batch in enumerate(dataloader):
            # Size of the accumulation window this micro-batch belongs to; the
            # last window of an epoch may be shorter than grad_accum.
            window_start = (i // grad_accum) * grad_accum
            window_len = min(grad_accum, num_batches - window_start)

            loss = compute_kd_loss_batch(model, batch, device, temperature)
            (loss / window_len).backward()
            window_loss += loss.item()

            if (i + 1) % grad_accum == 0 or (i + 1) == num_batches:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                steps += 1
                # Log the mean over every micro-batch in the window, not just the last.
                total_loss += window_loss / window_len
                window_loss = 0.0

                if steps % 10 == 0:
                    avg = total_loss / steps
                    print(f'  Step {steps}, Loss: {avg:.4f}')

    # Drop this layer's gradient buffers; the next layer freezes these params.
    optimizer.zero_grad(set_to_none=True)

    # Final eval
    model.eval()
    eval_loss = evaluate_kd_loss(
        model, cache_dir, device, num_samples=20, temperature=temperature
    )
    print(f'  Layer {layer_idx} done. Eval Loss: {eval_loss:.4f}')

    return eval_loss

In [None]:
# ============================================================
# RUN LAYER-BY-LAYER TRAINING
# ============================================================

# Shuffled loader over the whole KD cache; the same loader is reused for
# every layer.
dataset = KDCacheDataset(cache_local_path)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)
print(f'Dataset: {len(dataset)} samples')
print(f'Batches per epoch: {len(dataloader)}')

num_layers = len(model.model.layers)
print(f'Number of layers: {num_layers}')

# Train the decoder layers one at a time, first to last, recording each
# layer's post-training eval loss.
layer_losses = []
t0 = time.time()
for idx in range(num_layers):
    layer_losses.append(
        train_layer(
            model,
            idx,
            dataloader,
            DEVICE,
            lr=LR,
            epochs=EPOCHS_PER_LAYER,
            grad_accum=GRAD_ACCUM,
        )
    )

elapsed = time.time() - t0
print(f'\nLayer-by-layer training complete in {elapsed:.1f}s')
formatted_losses = [format(value, '.4f') for value in layer_losses]
print(f'Final losses: {formatted_losses}')

In [None]:
# ============================================================
# EVALUATE AFTER LAYER-BY-LAYER
# ============================================================

# Same 40-file evaluation as the initial KD loss, so the two numbers are
# comparable; a positive "Improvement" means the KD loss went down.
model.eval()
post_layer_loss = evaluate_kd_loss(model, cache_local_path, DEVICE, num_samples=40)
print(f'Initial KD Loss: {initial_loss:.4f}')
print(f'After Layer-by-Layer: {post_layer_loss:.4f}')
print(f'Improvement: {initial_loss - post_layer_loss:.4f}')

In [None]:
# ============================================================
# SAVE CHECKPOINT
# ============================================================

import json
import os

RUN_NAME = 'anemll_q4_layer_by_layer_v1'
SAVE_DIR = f'{LOCAL_RUNS}/{RUN_NAME}'

os.makedirs(SAVE_DIR, exist_ok=True)

# Save state dict
torch.save(model.state_dict(), f'{SAVE_DIR}/model_state_dict.pt')

# Save config: the quantization settings needed to rebuild the model, the
# training hyperparameters and KD cache used (so runs can be told apart and
# reproduced), and the resulting losses.
config = {
    'model_id': MODEL_ID,
    'lut_size': LUT_SIZE,
    'group_size': GROUP_SIZE,
    'scale_rank': SCALE_RANK,
    'attn_lut_size': ATTN_LUT_SIZE,
    'attn_group_size': ATTN_GROUP_SIZE,
    'attn_scale_rank': ATTN_SCALE_RANK,
    'initial_kd_loss': initial_loss,
    'post_layer_loss': post_layer_loss,
    'layer_losses': layer_losses,
    'batch_size': BATCH_SIZE,
    'grad_accum': GRAD_ACCUM,
    'lr': LR,
    'epochs_per_layer': EPOCHS_PER_LAYER,
    'distill_temp': DISTILL_TEMP,
    'kd_cache': CACHE_NAME,
}
with open(f'{SAVE_DIR}/config.json', 'w') as f:
    json.dump(config, f, indent=2)

print(f'Saved to {SAVE_DIR}')

In [None]:
# ============================================================
# UPLOAD TO GOOGLE DRIVE
# ============================================================

!tar -czvf {RUN_NAME}.tgz -C {LOCAL_RUNS} {RUN_NAME}
!cp {RUN_NAME}.tgz {GD_RUNS}/
print(f'Uploaded to {GD_RUNS}/{RUN_NAME}.tgz')

In [None]:
# ============================================================
# TEST INFERENCE
# ============================================================

def run_inference(model, tokenizer, prompt, max_new_tokens=128):
    """Greedy-decode a chat reply to ``prompt`` and return only the new text.

    The prompt is wrapped in a system + user chat template. Tokens are moved to
    the notebook-level ``DEVICE``. The echoed prompt tokens are stripped before
    decoding.
    """
    chat = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': prompt},
    ]
    rendered = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    encoded = tokenizer(rendered, return_tensors='pt').to(DEVICE)
    prompt_len = encoded['input_ids'].shape[1]

    with torch.no_grad():
        generated = model.generate(**encoded, max_new_tokens=max_new_tokens, do_sample=False)

    new_tokens = generated[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

# Test
model.eval()
prompt = 'What is the capital of France?'
response = run_inference(model, tokenizer, prompt)
print(f'Prompt: {prompt}')
print(f'Response: {response}')

## Next Steps

After layer-by-layer training, you can:

1. **End-to-end refinement** - Unfreeze all layers and train together
2. **Train scales (A, B)** - Unfreeze scale_A, scale_B parameters
3. **LoRA recovery** - Add LoRA adapters to recover quality