# 🧪 AWQ From Scratch: Activation-aware Weight Quantization (2023)

[!["Open In Colab"](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/model-size-reduction/blob/main/chronology/awq_demo.ipynb)

## 📖 The Theory: Protecting Salient Weights

AWQ (Activation-aware Weight Quantization) is based on the observation that **not all weights are equally important**. Weights corresponding to large activation values ("salient weights") contribute significantly more to the final error if quantized poorly.

### The Scaling Strategy
Instead of searching for a complex non-linear mapping, AWQ simply **scales up** the most important weights before quantization. By multiplying a weight by $s > 1$, we move it to a higher precision region of the quantization grid. To maintain mathematical equivalence, we must scale down the activations by $1/s$.

$$Y = \left(X \cdot \operatorname{diag}(1/s)\right) \cdot \left(\operatorname{diag}(s) \cdot W\right)$$

### Finding the Optimal Scale
AWQ searches for a scale factor $s$ that minimizes the output error. A common heuristic is to use the activation magnitude raised to some power:

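$s = s_X^{\alpha}$, where $s_X$ is the per-input-channel activation scale (the mean absolute activation over the calibration batch) and $\alpha \in [0, 1]$ is chosen by grid search.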

---

In [None]:
import torch
import torch.nn as nn
import numpy as np

def pseudo_quantize_tensor(w, n_bits, scale, zero):
    """Simulate affine quantization: snap `w` to an n_bits grid, then dequantize.

    Values are mapped to integer codes with `scale` and `zero`, rounded,
    clipped to [0, 2**n_bits - 1], and mapped back to the original domain.
    """
    highest_code = 2**n_bits - 1
    codes = (w / scale + zero).round().clamp(min=0, max=highest_code)
    return (codes - zero) * scale

def awq_from_scratch(w, x, n_bits=4, n_grid=20):
    """
    Simplified AWQ scale search.

    Grid-searches the exponent ``alpha`` in ``s = mean(|x|) ** alpha`` and keeps
    the per-input-channel scale whose scale -> quantize -> unscale round trip
    gives the smallest mean squared output error on the calibration batch.

    Args:
        w: [out_features, in_features] - Weight matrix
        x: [batch, in_features] - Calibration activations
        n_bits: Bit width of the symmetric row-wise quantizer (must be >= 2).
        n_grid: Number of alpha values tried, evenly spaced in [0, 1].

    Returns:
        Tensor of shape [in_features] holding the best scale, or None when no
        candidate produced a comparable (non-NaN) error.

    Raises:
        ValueError: If ``n_bits`` is smaller than 2 (the symmetric grid would
            have no non-zero level).
    """
    if n_bits < 2:
        raise ValueError(f"n_bits must be >= 2, got {n_bits}")
    q_max = 2 ** (n_bits - 1) - 1  # largest level of the symmetric grid

    # 1. Measure Activation Statistics (mean magnitude of each input feature)
    act_scale = torch.mean(torch.abs(x), dim=0)

    # 2. Search for the best alpha (heuristic power for scaling)
    best_error = float('inf')
    best_s = None

    org_out = torch.matmul(x, w.t())

    print("Searching for optimal AWQ scaling factor...")
    for alpha in np.linspace(0, 1, n_grid):
        # Scale based on activation magnitude. The floor keeps dead channels
        # (mean |x| == 0) from producing s == 0, which would turn the unscale
        # step below into 0/0 = NaN and make every alpha > 0 lose the search.
        s = act_scale.pow(float(alpha)).clamp(min=1e-4)
        s = s / torch.sqrt(s.max() * s.min() + 1e-8)  # Normalize scale
        s_row = s.view(1, -1)

        # Apply scale to weight
        w_scaled = w * s_row

        # Quantize the scaled weight (row-wise symmetric; the row maximum maps
        # to q_max, so no explicit clamp is needed)
        cur_max = torch.max(torch.abs(w_scaled), dim=1, keepdim=True)[0]
        cur_scale = (cur_max / q_max) + 1e-8
        w_q = torch.round(w_scaled / cur_scale) * cur_scale

        # Reverse scale for inference simulation
        w_q_final = w_q / s_row

        # Measure error
        cur_out = torch.matmul(x, w_q_final.t())
        err = (org_out - cur_out).pow(2).mean().item()

        if err < best_error:
            best_error = err
            best_s = s

    print(f"Best Error found: {best_error:.6f}")
    return best_s

# Smoke test: random 512 -> 1024 layer with 16 calibration rows
in_features = 512
out_features = 1024
w = torch.randn((out_features, in_features))
x = torch.randn((16, in_features))
# Inflate the first ten input features so they act as the salient channels
x[:, :10].mul_(10.0)

s = awq_from_scratch(w, x)
print(f"Scale factor for salient feature 0: {s[0]:.4f}")
print(f"Scale factor for normal feature 50: {s[50]:.4f}")

## 🔢 Worked Example with Numbers

Before applying AWQ to a full model, let’s trace through the math with a tiny, hand-traceable example.

In [None]:
# Tiny example: AWQ protects a salient input channel
import torch

print("=== AWQ: scaling a salient channel before quantization ===\n")

# 2 outputs, 4 inputs — channel 2 is salient.
# The channel-2 weights (0.33, -0.26) are deliberately NOT on the naive 4-bit
# grid of their row; a weight that is already exactly representable (e.g. the
# row maximum) has zero naive error and leaves AWQ nothing to improve.
W     = torch.tensor([[ 0.5, -0.3,  0.33, -0.2],
                       [-0.4,  0.7, -0.26,  0.3]])
x_max = torch.tensor([0.10, 0.15, 8.00, 0.12])   # activation scales

print(f"Weight matrix W (2×4):\n{W.numpy()}")
print(f"\nActivation magnitude per input channel: {x_max.tolist()}")
print(f"  → Channel 2 is SALIENT (8.0 vs ~0.1 for the others)")

def quant4(w):
    """Row-wise 4-bit symmetric quantization."""
    q_max = 7
    scale = w.abs().max(dim=1, keepdim=True)[0] / q_max + 1e-8
    return (w / scale).round().clamp(-q_max, q_max) * scale

def act_weighted_mse(w_hat):
    """Mean squared weight error with each input channel weighted by its activation scale.

    AWQ targets the error of the layer OUTPUT, where a weight error is
    multiplied by the activation it meets; a plain weight-space MSE would
    ignore that and penalize exactly the trade-off AWQ makes on purpose.
    """
    return ((W - w_hat) * x_max.view(1, -1)).pow(2).mean()

# Baseline — naive 4-bit, no scaling
W_q_naive = quant4(W)
mse_naive  = act_weighted_mse(W_q_naive)
print(f"\n--- Naive 4-bit (no AWQ) ---")
print(f"Quantized W:\n{W_q_naive.numpy().round(4)}")
print(f"Activation-weighted MSE: {mse_naive:.6f}")

# AWQ — scale salient channel UP before quantizing
alpha = 0.5
s     = x_max.pow(alpha)
s     = s / (s.max() * s.min()).sqrt()       # normalize scale

print(f"\n--- AWQ (alpha={alpha}) ---")
print(f"Per-channel scale s = x_max^{alpha} (normalized): {s.tolist()}")
print(f"  → Channel 2 scaled by {s[2]:.2f}× before quantization (more precision allocated)")

W_scaled  = W * s.view(1, -1)              # scale up salient weights
W_q_scaled = quant4(W_scaled)
W_q_awq   = W_q_scaled / s.view(1, -1)    # undo scale for inference
mse_awq   = act_weighted_mse(W_q_awq)

print(f"Quantized W (after unscaling):\n{W_q_awq.numpy().round(4)}")
print(f"Activation-weighted MSE: {mse_awq:.6f}")
print(f"\nAWQ improvement: {(1 - mse_awq/mse_naive)*100:.1f}% lower activation-weighted MSE for the same 4-bit budget")


## 🧪 GPT-2 Evaluation

Apply the method to all 2D weight matrices of GPT-2 and compare perplexity before and after quantization.

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch, copy

# Load the pretrained "gpt2" tokenizer and language model; .eval() switches
# the model to evaluation mode for the measurements below.
model_id = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_id)
model = GPT2LMHeadModel.from_pretrained(model_id).eval()

# One short sample, reused for the activation-statistics forward pass and for
# both perplexity measurements (baseline and quantized).
text = "The quick brown fox jumps over the lazy dog. Transformers are powerful sequence models."
inputs = tokenizer(text, return_tensors="pt")

def perplexity(mdl, inputs):
    """Return exp(loss) of `mdl` on `inputs`, using the input ids as their own labels."""
    targets = inputs["input_ids"]
    with torch.no_grad():
        outputs = mdl(labels=targets, **inputs)
    nll = outputs.loss
    return nll.exp().item()

baseline_ppl = perplexity(model, inputs)
print(f"Baseline GPT-2 Perplexity: {baseline_ppl:.2f}")

# Collect per-layer input activation stats via hooks
act_stats = {}
def make_stat_hook(name):
    """Build a forward hook that stores the mean |input| per last-dim feature under `name`."""
    def hook(module, inp, out):
        # Skip calls without a floating-point positional input: embedding
        # layers receive integer token/position ids, which are not activations.
        if not inp or not torch.is_tensor(inp[0]) or not inp[0].is_floating_point():
            return
        # Flatten every leading dimension so each row is one feature vector
        x = inp[0].detach().float().reshape(-1, inp[0].shape[-1])
        act_stats[name] = torch.mean(torch.abs(x), dim=0)
    return hook

hooks = []
try:
    for name, module in model.named_modules():
        if hasattr(module, "weight") and module.weight is not None and module.weight.dim() == 2:
            hooks.append(module.register_forward_hook(make_stat_hook(name)))
    with torch.no_grad():
        model(**inputs)
finally:
    # Always detach the hooks, even if the forward pass raised, so the model
    # is not left instrumented for the later perplexity runs.
    for h in hooks:
        h.remove()

# Apply AWQ (alpha=0.5) + 4-bit quantization to all 2D weights
n_bits = 4
q_max = 2 ** (n_bits - 1) - 1  # symmetric signed grid: levels -7..7 for 4 bits


def _input_dim(module):
    """Return the weight axis indexed by input features, or None if unknown."""
    if isinstance(module, torch.nn.Linear):
        return 1  # nn.Linear stores weight as [out_features, in_features]
    if type(module).__name__ == "Conv1D":
        # GPT-2 uses transformers' Conv1D, whose weight is [in_features, out_features]
        return 0
    return None  # e.g. nn.Embedding: no input-activation axis to scale


model_q = copy.deepcopy(model)
# Visit Linear/Conv1D layers first (the sort is stable) so that a weight shared
# with an embedding (GPT-2 ties lm_head to the token embedding) is quantized
# together with its activation statistics.
candidates = sorted(model_q.named_modules(), key=lambda item: _input_dim(item[1]) is None)
processed = set()  # ids of parameters already quantized; avoids quantizing a tied weight twice
for name, module in candidates:
    weight = getattr(module, "weight", None)
    if weight is None or weight.dim() != 2 or id(weight) in processed:
        continue
    processed.add(id(weight))

    w = weight.data.float()
    in_dim = _input_dim(module)
    x_max = act_stats.get(name)
    s = None
    if in_dim is not None and x_max is not None and x_max.shape[0] == w.shape[in_dim]:
        # Floor the scale so a dead channel cannot cause 0/0 when unscaling
        s = x_max.pow(0.5).clamp(min=1e-4)
        s = s / (s.max() * s.min()).sqrt().clamp(min=1e-8)
        # Broadcast along the INPUT-feature axis of this layer's weight layout
        s = s.view(-1, 1) if in_dim == 0 else s.view(1, -1)
    w_scaled = w if s is None else w * s

    # One step size per output channel: reduce over the input axis
    # (axis 1 for layers without a known input axis, i.e. per embedding row).
    reduce_dim = 1 if in_dim is None else in_dim
    scale = torch.max(torch.abs(w_scaled), dim=reduce_dim, keepdim=True)[0] / q_max + 1e-8
    w_q = torch.round(w_scaled / scale).clamp(-q_max, q_max) * scale
    if s is not None:
        w_q = w_q / s  # undo the AWQ scale so the layer still sees raw activations
    module.weight.data = w_q.to(weight.dtype)

quant_ppl = perplexity(model_q, inputs)
print(f"AWQ GPT-2 Perplexity:      {quant_ppl:.2f}")
print(f"Delta:                     {quant_ppl - baseline_ppl:+.2f}")
