# Anemll V2 FP16 Training Pipeline

End-to-end FP16 training pipeline for V2 quantization-aware training (QAT). Keeping every step in FP16 avoids a precision mismatch with the Apple Neural Engine (ANE).

## Key Differences from BF16 Training:
- Model loaded in **FP16** (not BF16)
- LUT created in **FP16**
- Indices computed in **FP16** (same as ANE)
- Uses **GradScaler** for stable FP16 training (CUDA)

## Benefits:
- **No precision mismatch**: Indices computed in FP16 = same as ANE
- **No snap needed**: Model is already in FP16 format
- **ANE-ready**: Direct export without conversion
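
To make the mismatch concrete, the sketch below rounds one value to FP16 and to BF16 using only the standard library, then looks up the nearest entry of a 16-entry LUT. The uniform LUT on [-1, 1], the nearest-entry rule, and the helper names (`to_fp16`, `to_bf16`, `nearest_index`) are assumptions for illustration; the actual index computation in `AnemllQATLinearV2` may differ. The LUT is kept in full precision here to isolate the effect of rounding the weight.

```python
import struct

def to_fp16(x: float) -> float:
    """Round x to IEEE half precision (10 fraction bits) and back."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

def to_bf16(x: float) -> float:
    """Round x to bfloat16 (7 fraction bits) via its float32 bit pattern."""
    bits = struct.unpack('<I', struct.pack('<f', x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    bits &= 0xFFFF0000                   # keep the upper 16 bits only
    return struct.unpack('<f', struct.pack('<I', bits))[0]

# Assumption: a uniform 4-bit LUT on [-1, 1]; the real LUT may be shaped differently.
LUT = [-1.0 + 2.0 * k / 15 for k in range(16)]

def nearest_index(value: float, lut: list) -> int:
    """Index of the LUT entry closest to value (first entry wins on ties)."""
    return min(range(len(lut)), key=lambda k: abs(value - lut[k]))

w = 0.13331  # sits just below the midpoint between LUT[8] and LUT[9]
idx_fp16 = nearest_index(to_fp16(w), LUT)
idx_bf16 = nearest_index(to_bf16(w), LUT)
print(idx_fp16, idx_bf16)  # prints: 8 9
```

The coarser BF16 grid pushes the value across the midpoint, so the two precisions select different LUT entries for the same weight.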

## Pipeline:
1. Load the model in FP16
2. Replace linear layers with `AnemllQATLinearV2`
3. Convert V2 layers to FP16 with `convert_model_to_fp16_v2` (ensures the LUT is FP16)
4. Freeze Q with `freeze_Q_all` (indices computed in FP16)
5. Train with `use_fp16=True` (GradScaler on CUDA)
6. Save the checkpoint; no snap needed

In [None]:
# ============================================================
# GOOGLE DRIVE PATHS (STANDARD)
# ============================================================

# Checkpoints/runs go here
GD_RUNS = '/content/drive/MyDrive/qwen3_runs'

# KD caches go here
GD_CACHES = '/content/drive/MyDrive/qwen3_caches'

# Local directories (on Colab VM)
LOCAL_RUNS = 'runs'
LOCAL_CACHES = 'caches'

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone repo if needed
!git clone https://github.com/anemll/qwen3_apple_style_2bit_qat_lora.git || (cd qwen3_apple_style_2bit_qat_lora && git pull)
%cd qwen3_apple_style_2bit_qat_lora
!git fetch && git pull
!git reset --hard HEAD

import sys
# Drop cached qat_lora modules so a re-run imports the freshly pulled code
for name in [k for k in sys.modules if k.startswith('qat_lora')]:
    sys.modules.pop(name)

# Import V2 modules
from qat_lora import (
    AnemllQATLinearV2,
    AnemllQuantConfigV2,
    replace_linear_with_anemll_v2,
    freeze_Q_all,
    freeze_model_for_inference_v2,
    unfreeze_model_for_training_v2,
    convert_model_to_fp16_v2,  # NEW: FP16 conversion
    evaluate_kd_loss,
    train_all_layers,
    train_e2e,
    save_checkpoint,
)

In [None]:
!pip install -q transformers accelerate safetensors

In [None]:
# ============================================================
# LOAD KD CACHE FROM GOOGLE DRIVE
# ============================================================

CACHE_NAME = 'alpaca_chat_think_both_L128_K128_R1024'
CACHE_TGZ = f'{CACHE_NAME}.tgz'

!mkdir -p {LOCAL_CACHES}

import os
cache_local_path = f'{LOCAL_CACHES}/{CACHE_NAME}'
if not os.path.exists(cache_local_path):
    print(f'Extracting {CACHE_TGZ} from Google Drive...')
    !tar -xzf {GD_CACHES}/{CACHE_TGZ} -C {LOCAL_CACHES}/
else:
    print(f'Cache already exists at {cache_local_path}')

!ls -la {cache_local_path}/ | head -10
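
The extraction step above relies on shell commands. A pure-Python equivalent may be easier to reuse outside Colab. The sketch below is one possible version; `extract_cache_if_missing` is a hypothetical helper, and it assumes the archive contains a single top-level directory named after the cache, as the `tar` call above does.

```python
import os
import tarfile

def extract_cache_if_missing(tgz_path: str, dest_dir: str, cache_name: str) -> str:
    """Extract tgz_path into dest_dir unless dest_dir/cache_name already exists."""
    target = os.path.join(dest_dir, cache_name)
    if os.path.isdir(target):
        return target  # already extracted, nothing to do
    if not os.path.isfile(tgz_path):
        raise FileNotFoundError(f'Cache archive not found: {tgz_path}')
    os.makedirs(dest_dir, exist_ok=True)
    with tarfile.open(tgz_path, 'r:gz') as tar:
        # Use the safer 'data' extraction filter where this Python supports it.
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(dest_dir, filter='data')
        else:
            tar.extractall(dest_dir)
    return target

# Usage with the notebook's variables (shown as a comment because they are defined elsewhere):
# cache_local_path = extract_cache_if_missing(f'{GD_CACHES}/{CACHE_TGZ}', LOCAL_CACHES, CACHE_NAME)
```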

In [None]:
# ============================================================
# CONFIGURATION - FP16 TRAINING
# ============================================================

import torch

# Model
MODEL_ID = 'Qwen/Qwen3-0.6B'

# Quantization config
LUT_BITS = 4
LUT_SIZE = 2**LUT_BITS
SCALE_RANK = 4

# Attention quantization
ATTN_LUT_BITS = 4
ATTN_LUT_SIZE = 2**ATTN_LUT_BITS
ATTN_SCALE_RANK = 4

# Device - CUDA is required because FP16 training relies on GradScaler
if not torch.cuda.is_available():
    raise RuntimeError("FP16 training with GradScaler requires CUDA!")
DEVICE = torch.device('cuda')

# Training (CUDA is guaranteed by the check above)
BATCH_SIZE = 32
GRAD_ACCUM = 1  # kept for reference; not passed to train_e2e in this notebook

# KD / Distillation params
DISTILL_TEMP = 2.0

# *** FP16 from the start! ***
DTYPE = torch.float16
USE_FP16_TRAINING = True

QUAL = f'q{LUT_BITS}_a{ATTN_LUT_BITS}_fp16'

print(f'=== FP16 Training Pipeline ===')
print(f'Quality: {QUAL}')
print(f'Device: {DEVICE}, dtype: {DTYPE}')
print(f'GradScaler: Enabled (CUDA FP16)')
print(f'Quant config: lut={LUT_SIZE}, rank={SCALE_RANK}')
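
As a rough check on what `LUT_BITS` and `SCALE_RANK` cost in storage, the sketch below estimates effective bits per weight for one quantized layer. It assumes each layer stores one LUT index per weight, one FP16 LUT, and low-rank scale factors shaped `scale_A` (out x rank), `scale_B` (rank x in), and `rank_magnitude` (rank). Those shapes and the 3072 x 1024 layer size are illustrative assumptions, not values read from the library.

```python
def effective_bits_per_weight(out_features: int, in_features: int,
                              lut_bits: int, scale_rank: int,
                              fp_bits: int = 16) -> float:
    """Storage per weight, counting indices, the LUT, and low-rank scales."""
    n_weights = out_features * in_features
    index_bits = n_weights * lut_bits                  # one index per weight
    lut_storage_bits = (2 ** lut_bits) * fp_bits       # one FP16 LUT per layer (assumed)
    # Assumed shapes: scale_A [out, rank], scale_B [rank, in], rank_magnitude [rank]
    scale_values = scale_rank * (out_features + in_features) + scale_rank
    scale_bits = scale_values * fp_bits
    return (index_bits + lut_storage_bits + scale_bits) / n_weights

bpw = effective_bits_per_weight(3072, 1024, lut_bits=4, scale_rank=4)
print(f'{bpw:.3f} bits per weight')
```

Under these assumptions the scale factors add well under a tenth of a bit per weight on top of the 4-bit indices.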

In [None]:
# ============================================================
# LOAD MODEL IN FP16
# ============================================================

from transformers import AutoModelForCausalLM, AutoTokenizer

print(f'Loading {MODEL_ID} in FP16...')
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,  # FP16 from the start!
    trust_remote_code=True,
)
model.to(DEVICE)
model.eval()

print(f'Loaded in FP16. Parameters: {sum(p.numel() for p in model.parameters()):,}')
print(f'Model dtype: {next(model.parameters()).dtype}')

In [None]:
# ============================================================
# REPLACE LINEARS WITH V2
# ============================================================

mlp_config = AnemllQuantConfigV2(
    lut_size=LUT_SIZE,
    scale_rank=SCALE_RANK,
    learnable_lut=False,
    force_positive_scales=True,
    positive_scale_method="abs",
    magnitude_activation="softplus",
    magnitude_eps=1e-6,
)

attn_config = AnemllQuantConfigV2(
    lut_size=ATTN_LUT_SIZE,
    scale_rank=ATTN_SCALE_RANK,
    learnable_lut=False,
    force_positive_scales=True,
    positive_scale_method="abs",
    magnitude_activation="softplus",
    magnitude_eps=1e-6,
)

print('Replacing linear layers with V2...')
count = replace_linear_with_anemll_v2(
    model,
    mlp_config=mlp_config,
    attn_config=attn_config,
    quantize_attn=True,
    quantize_lm_head=False,
)

print(f'Replaced {count} layers with AnemllQATLinearV2')

In [None]:
# ============================================================
# CONVERT V2 LAYERS TO FP16 (CRITICAL!)
# ============================================================
# This ensures LUT is created in FP16, so indices will be computed in FP16
# This must be done BEFORE freeze_Q_all()

print('Converting V2 layers to FP16 (LUT, scales, weights)...')
convert_model_to_fp16_v2(model, verbose=True)

# Verify FP16 conversion
for name, module in model.named_modules():
    if type(module).__name__ == 'AnemllQATLinearV2':
        print(f'\nVerifying {name}:')
        print(f'  weight.dtype: {module.weight.dtype}')
        print(f'  scale_A.dtype: {module.scale_A.dtype}')
        print(f'  lut.dtype: {module.lut.dtype}')
        break

In [None]:
# ============================================================
# FREEZE Q IN FP16
# ============================================================
# Now indices will be computed using FP16 arithmetic - same as ANE!

print('Freezing Q (computing indices in FP16)...')
freeze_Q_all(model, verbose=False)
print('Q frozen for all layers in FP16.')
print('Indices computed in FP16 = same precision as ANE inference!')

In [None]:
# ============================================================
# VERIFY FP16 SETUP
# ============================================================

print('=== FP16 Setup Verification ===')

for name, module in model.named_modules():
    if type(module).__name__ == 'AnemllQATLinearV2':
        print(f'\nLayer: {name}')
        print(f'  weight.dtype: {module.weight.dtype}')
        print(f'  scale_A.dtype: {module.scale_A.dtype}')
        print(f'  lut.dtype: {module.lut.dtype}')
        print(f'  _Q.dtype: {module._Q.dtype if module._Q is not None else "None"}')
        print(f'  _indices.dtype: {module._indices.dtype if module._indices is not None else "None"}')
        break

# Initial loss
initial_loss = evaluate_kd_loss(model, cache_local_path, DEVICE, num_samples=40, temperature=DISTILL_TEMP)
print(f'\nInitial KD Loss (FP16): {initial_loss:.4f}')
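
For readers unfamiliar with the role of `DISTILL_TEMP`, the sketch below shows a common temperature-scaled distillation loss: the KL divergence between softened teacher and student distributions, multiplied by T squared. It is a standard-library illustration only; `softmax` and `kd_loss` are hypothetical helpers, and `evaluate_kd_loss` may use a different formulation (for example, one based on cached top-K teacher logits).

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax of logits divided by temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kd_loss(teacher_logits, student_logits, temperature=2.0):
    """KL(teacher || student) on temperature-softened distributions, scaled by T^2."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    kl = sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0.0)
    return kl * temperature ** 2

teacher = [2.0, 1.0, 0.1]
print(kd_loss(teacher, teacher))          # identical logits give zero loss
print(kd_loss(teacher, [0.1, 1.0, 2.0]))  # mismatched logits give a positive loss
```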

# FP16 Training with GradScaler

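Training uses:
- `torch.amp.autocast` to run the forward pass in FP16
- `torch.cuda.amp.GradScaler` (spelled `torch.amp.GradScaler('cuda')` in recent PyTorch releases) to scale the loss before the backward pass

Scaling the loss keeps small gradients from underflowing to zero in FP16; the gradients are unscaled again before the optimizer step.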
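
Gradient underflow can be reproduced without a GPU. The sketch below uses the standard library's half-precision packing to show a small gradient vanishing in FP16 and surviving once it is multiplied by a loss scale. The gradient value is made up for illustration, and `to_fp16` is a hypothetical helper; the scale of 2**16 matches the documented default `init_scale` of PyTorch's GradScaler.

```python
import struct

def to_fp16(x: float) -> float:
    """Round x to IEEE half precision and back."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

grad = 1e-8            # illustrative gradient, below FP16's smallest subnormal
loss_scale = 65536.0   # 2**16

unscaled = to_fp16(grad)               # rounds to zero: the update is lost
scaled = to_fp16(grad * loss_scale)    # comfortably inside the FP16 range
recovered = scaled / loss_scale        # unscale in higher precision

print(unscaled, recovered)  # unscaled is 0.0; recovered is close to 1e-8
```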

In [None]:
# ============================================================
# FP16 E2E TRAINING - MLP SCALES
# ============================================================

print('=== FP16 Training with GradScaler ===')

e2e_mlp_result = train_e2e(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    max_steps=4000,
    batch_size=BATCH_SIZE,
    lr=5e-4,
    use_cosine_schedule=True,
    warmup_steps=100,
    min_lr_ratio=0.1,
    temperature=DISTILL_TEMP,
    train_weights=False,
    train_scales=True,
    hard_top1_weight=0.0,
    hard_full_weight=0.0,
    logging_steps=20,
    eval_steps=100,
    verbose=True,
    train_mlp_only=True,
    use_fp16=True,  # FP16 with GradScaler!
)

print(f'\nMLP Training Result:')
print(f'  Initial: {e2e_mlp_result["initial_loss"]:.4f}')
print(f'  Final: {e2e_mlp_result["final_loss"]:.4f}')
print(f'  Best: {e2e_mlp_result["best_loss"]:.4f}')
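
The arguments `use_cosine_schedule`, `warmup_steps`, and `min_lr_ratio` suggest a linear warmup followed by cosine decay down to `min_lr_ratio * lr`. The sketch below shows one common formulation of that schedule; `lr_at_step` is a hypothetical helper, and the exact curve inside `train_e2e` may differ.

```python
import math

def lr_at_step(step: int, base_lr: float, max_steps: int,
               warmup_steps: int, min_lr_ratio: float) -> float:
    """Linear warmup to base_lr, then cosine decay to base_lr * min_lr_ratio."""
    if warmup_steps > 0 and step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    progress = min(1.0, max(0.0, progress))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return base_lr * (min_lr_ratio + (1.0 - min_lr_ratio) * cosine)

# Values from the MLP stage above: lr=5e-4, max_steps=4000, warmup_steps=100
for step in (0, 99, 2000, 4000):
    print(step, lr_at_step(step, 5e-4, 4000, 100, 0.1))
```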

In [None]:
# ============================================================
# FP16 E2E TRAINING - ATTENTION SCALES
# ============================================================

unfreeze_model_for_training_v2(model)

# Freeze MLP scales, only train attention
for name, module in model.named_modules():
    if type(module).__name__ == 'AnemllQATLinearV2':
        is_attn = any(x in name for x in ['q_proj', 'k_proj', 'v_proj', 'o_proj'])
        if hasattr(module, 'scale_A') and module.scale_A is not None:
            module.scale_A.requires_grad = is_attn
            module.scale_B.requires_grad = is_attn
            module.rank_magnitude.requires_grad = is_attn
        module.weight.requires_grad = False

e2e_attn_result = train_e2e(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    max_steps=2000,
    batch_size=BATCH_SIZE,
    lr=1e-4,
    use_cosine_schedule=True,
    warmup_steps=100,
    min_lr_ratio=0.1,
    temperature=DISTILL_TEMP,
    train_weights=False,
    train_scales=True,
    hard_top1_weight=0.0,
    hard_full_weight=0.0,
    logging_steps=20,
    eval_steps=100,
    verbose=True,
    train_mlp_only=False,
    use_fp16=True,
)

print(f'\nAttention Training Result:')
print(f'  Initial: {e2e_attn_result["initial_loss"]:.4f}')
print(f'  Final: {e2e_attn_result["final_loss"]:.4f}')

In [None]:
# ============================================================
# FP16 E2E TRAINING - JOINT MLP + ATTENTION
# ============================================================

unfreeze_model_for_training_v2(model)

# Enable ALL scales
for name, module in model.named_modules():
    if type(module).__name__ == 'AnemllQATLinearV2':
        if hasattr(module, 'scale_A') and module.scale_A is not None:
            module.scale_A.requires_grad = True
            module.scale_B.requires_grad = True
            module.rank_magnitude.requires_grad = True
        module.weight.requires_grad = False

e2e_joint_result = train_e2e(
    model=model,
    cache_dir=cache_local_path,
    device=DEVICE,
    max_steps=1000,
    batch_size=BATCH_SIZE,
    lr=5e-5,
    use_cosine_schedule=True,
    warmup_steps=50,
    min_lr_ratio=0.1,
    temperature=DISTILL_TEMP,
    train_weights=False,
    train_scales=True,
    hard_top1_weight=0.0,
    hard_full_weight=0.0,
    logging_steps=20,
    eval_steps=100,
    verbose=True,
    train_mlp_only=False,
    use_fp16=True,
)

print(f'\nJoint Training Result:')
print(f'  Initial: {e2e_joint_result["initial_loss"]:.4f}')
print(f'  Final: {e2e_joint_result["final_loss"]:.4f}')

In [None]:
# ============================================================
# SAVE FP16 CHECKPOINT (NO SNAP NEEDED!)
# ============================================================
# Model is already in FP16, indices computed in FP16
# No snap_for_ane() needed - model is ANE-ready!

import json
import os

RUN_NAME = f'anemll_v2_{QUAL}_trained'  # QUAL already ends in "_fp16"
SAVE_DIR = f'{LOCAL_RUNS}/{RUN_NAME}'
os.makedirs(SAVE_DIR, exist_ok=True)

# Save state dict
torch.save(model.state_dict(), f'{SAVE_DIR}/model_state_dict.pt')

# Save config
config = {
    'model_id': MODEL_ID,
    'version': 'v2',
    'precision': 'fp16',  # FP16 trained!
    'lut_bits': LUT_BITS,
    'attn_lut_bits': ATTN_LUT_BITS,
    'scale_rank': SCALE_RANK,
    'attn_scale_rank': ATTN_SCALE_RANK,
    'initial_loss': initial_loss,
    'final_loss': e2e_joint_result['final_loss'],
    'training_mode': 'fp16_gradscaler',
}
with open(f'{SAVE_DIR}/config.json', 'w') as f:
    json.dump(config, f, indent=2)

print(f'\n=== FP16 Checkpoint Saved ===')
print(f'Path: {SAVE_DIR}')
print(f'Precision: FP16 (ANE-ready, no snap needed)')
print(f'Final Loss: {e2e_joint_result["final_loss"]:.4f}')
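
When the checkpoint is reloaded later, the quantized layers have to be rebuilt with the same settings before the state dict can be loaded. The sketch below reads `config.json` back and checks the keys written above; `load_run_config` and `REQUIRED_KEYS` are hypothetical names, and rebuilding the model itself is left out.

```python
import json
import os

# Keys written by the save cell above that are needed to rebuild the V2 layers
REQUIRED_KEYS = ('model_id', 'precision', 'lut_bits', 'attn_lut_bits',
                 'scale_rank', 'attn_scale_rank')

def load_run_config(save_dir: str) -> dict:
    """Load config.json from save_dir and derive the LUT sizes from the bit widths."""
    with open(os.path.join(save_dir, 'config.json')) as f:
        config = json.load(f)
    missing = [k for k in REQUIRED_KEYS if k not in config]
    if missing:
        raise KeyError(f'config.json is missing keys: {missing}')
    if config['precision'] != 'fp16':
        raise ValueError(f"Expected an fp16 checkpoint, got {config['precision']!r}")
    config['lut_size'] = 2 ** config['lut_bits']
    config['attn_lut_size'] = 2 ** config['attn_lut_bits']
    return config

# Usage with the notebook's variable (shown as a comment because it is defined elsewhere):
# run_config = load_run_config(SAVE_DIR)
```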

In [None]:
# Upload to Google Drive
!tar -czvf {RUN_NAME}.tgz -C {LOCAL_RUNS} {RUN_NAME}
!cp {RUN_NAME}.tgz {GD_RUNS}/
print(f'Uploaded to {GD_RUNS}/{RUN_NAME}.tgz')

In [None]:
# ============================================================
# TEST FP16 INFERENCE
# ============================================================

# Freeze for fast inference
freeze_model_for_inference_v2(model, verbose=False)

def run_inference(model, tokenizer, prompt, max_new_tokens=256):
    messages = [{'role': 'user', 'content': prompt}]
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, enable_thinking=True
    )
    inputs = tokenizer(text, return_tensors='pt').to(DEVICE)
    
    with torch.no_grad():
        output = model.generate(
            **inputs, max_new_tokens=max_new_tokens,
            do_sample=True, temperature=0.6, top_p=0.9,
        )
    return tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=False)

model.eval()

prompts = [
    'What is the capital of France?',
    'What is 2+2?',
    'Explain quantum mechanics briefly.',
]

print('=== FP16 Inference Test ===')
for prompt in prompts:
    response = run_inference(model, tokenizer, prompt)
    print(f'\nPrompt: {prompt}')
    print(f'Response: {response}')
    print('-' * 50)

# Summary

## FP16 Training Complete!

The model was trained entirely in FP16:
- LUT created in FP16
- Indices computed in FP16
- Training stabilized with GradScaler
- No `snap_for_ane()` needed; the model is ANE-ready

## Key Difference from BF16 Training:
- **BF16**: Indices computed in BF16, then snapped to FP16 (potential mismatch)
- **FP16**: Indices computed in FP16 from start (same as ANE)

## Next Steps:
1. Convert to CoreML with the ANEMLL converter
2. Deploy to ANE

No additional snapping is needed at either step.