# 🧪 AWQ From Scratch: Activation-aware Weight Quantization (2023)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/model-size-reduction/blob/main/chronology/awq_demo.ipynb)

## 📖 The Theory: Protecting Salient Weights

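AWQ (Activation-aware Weight Quantization) is based on the observation that **not all weights are equally important**. A linear layer computes $Y = XW$, so a rounding error in a weight reaches the output multiplied by the activation on that weight's input channel. Weights on channels with large activations ("salient weights") therefore contribute far more to the output error when they are quantized poorly.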
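To make the saliency argument concrete, here is a minimal sketch. It assumes a symmetric 4-bit per-row round-to-nearest quantizer (`fake_quant_sym`, a hypothetical helper for illustration, not part of AWQ) and synthetic data in which the first four input channels are made 20× larger. The error that input channel $j$ would cause on its own factorizes as $\overline{x_j^2}\cdot\overline{\Delta w_{:,j}^2}$, so the salient channels dominate the output error even though their weights are rounded no worse than any others.

```python
import torch

torch.manual_seed(0)
n_tokens, in_features, out_features = 64, 256, 128
w = torch.randn(out_features, in_features)   # nn.Linear layout: y = x @ w.T
x = torch.randn(n_tokens, in_features)
x[:, :4] *= 20.0                             # hypothetical salient channels 0..3


def fake_quant_sym(w, n_bits=4):
    """Symmetric per-row round-to-nearest fake quantization (illustrative helper)."""
    q_max = 2 ** (n_bits - 1) - 1
    scale = w.abs().amax(dim=1, keepdim=True) / q_max
    return torch.round(w / scale) * scale


dw = fake_quant_sym(w) - w                   # weight rounding error, [out, in]

# Error channel j would cause on its own, averaged over (token, row):
# mean_{t,o} (x[t, j] * dw[o, j])**2 = mean_t x[t, j]**2 * mean_o dw[o, j]**2
weight_err = dw.pow(2).mean(dim=0)
per_channel_err = x.pow(2).mean(dim=0) * weight_err

print("weight MSE, salient vs other:", weight_err[:4].mean().item(), weight_err[4:].mean().item())
print("output MSE, salient vs other:", per_channel_err[:4].mean().item(), per_channel_err[4:].mean().item())
```

The weight-space error is about the same for both groups; only the activation factor separates them, which is the information AWQ exploits.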

### The Scaling Strategy
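Keeping salient weights in higher precision would require mixed-precision storage and kernels. AWQ instead **scales up** the input channels of $W$ that see large activations before quantization. The quantization step of each row is set by that row's largest magnitude, so multiplying one channel by $s > 1$ usually leaves the step almost unchanged, and the rounding error *relative to* that channel's weights shrinks by roughly $1/s$. To keep the layer mathematically equivalent, the matching activations are scaled down by $1/s$: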

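$$Y = XW = \big(X \operatorname{diag}(1/s)\big)\big(\operatorname{diag}(s)\, W\big) \approx \big(X \operatorname{diag}(1/s)\big)\, Q\big(\operatorname{diag}(s)\, W\big)$$

Here $X$ is `[batch, in_features]`, $W$ is `[in_features, out_features]` (the transpose of PyTorch's `nn.Linear.weight`), and $Q(\cdot)$ is the weight quantizer. Only the scaled weight $\operatorname{diag}(s)\, W$ is quantized.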
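A quick numerical check of this identity, as a sketch: without quantization the rewrite is exact, and with a quantizer in the loop, scaling a single salient channel by $s = 2$ lowers the output error. Assumptions: PyTorch's `[out, in]` weight layout, so $\operatorname{diag}(s)\,W$ becomes `w * s` broadcast along the input dimension; a hypothetical symmetric per-row quantizer `fake_quant_sym`; synthetic data with channel 0 made 20× larger.

```python
import torch

torch.manual_seed(0)
out_features, in_features = 64, 128
w = torch.randn(out_features, in_features)   # [out, in], so y = x @ w.T
x = torch.randn(32, in_features)
x[:, 0] *= 20.0                              # hypothetical salient input channel 0


def fake_quant_sym(w, n_bits=4):
    """Symmetric per-row round-to-nearest fake quantization (illustrative helper)."""
    q_max = 2 ** (n_bits - 1) - 1
    scale = w.abs().amax(dim=1, keepdim=True) / q_max
    return torch.round(w / scale) * scale


s = torch.ones(in_features)
s[0] = 2.0                                   # scale up only the salient channel

# 1. Without quantization: (X diag(1/s)) (diag(s) W) == X W
y_ref = x @ w.t()
y_eq = (x / s) @ (w * s).t()
print("max |diff| without quantization:", (y_ref - y_eq).abs().max().item())

# 2. With quantization: plain round-to-nearest vs. scaled round-to-nearest
mse_rtn = (y_ref - x @ fake_quant_sym(w).t()).pow(2).mean().item()
mse_awq = (y_ref - (x / s) @ fake_quant_sym(w * s).t()).pow(2).mean().item()
print(f"output MSE, plain RTN:  {mse_rtn:.4f}")
print(f"output MSE, scaled RTN: {mse_awq:.4f}")
```

The salient channel's rounding error is divided by $s$ on the way out, while the other channels pay only a small penalty when the scaled weight raises a row's maximum.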

### Finding the Optimal Scale
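AWQ searches for the per-input-channel scale $s$ that minimizes the output error of the quantized layer:

$$s^{*} = \arg\min_{s} \big\lVert \big(X \operatorname{diag}(1/s)\big)\, Q\big(\operatorname{diag}(s)\, W\big) - XW \big\rVert$$

To keep the search cheap, $s$ is restricted to the activation statistic raised to a single exponent:

$$s = s_X^{\alpha}, \qquad \alpha \in [0, 1]$$

where $s_X$ is the average magnitude of each input channel over the calibration set and $\alpha$ is chosen by grid search. $\alpha = 0$ gives $s = 1$ (plain round-to-nearest); larger $\alpha$ protects high-activation channels more aggressively.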
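A small worked example of the scale formula, as a sketch. The per-channel statistic `s_x` is made up, and `awq_scale` is a hypothetical helper; it applies the same clamp-and-normalize step used in the search loop below, which rescales $s$ so that $\max(s)\cdot\min(s) = 1$ and keeps the scales centered around 1.

```python
import torch


def awq_scale(s_x, alpha):
    """s = s_X ** alpha, clamped away from zero, normalized so max(s) * min(s) == 1."""
    s = s_x.pow(alpha).clamp(min=1e-4)
    return s / torch.sqrt(s.max() * s.min())


s_x = torch.tensor([8.0, 1.0, 0.5, 2.0])     # hypothetical per-channel mean |x|
for alpha in (0.0, 0.5, 1.0):
    print(alpha, [round(v, 3) for v in awq_scale(s_x, alpha).tolist()])
# 0.0 [1.0, 1.0, 1.0, 1.0]
# 0.5 [2.0, 0.707, 0.5, 1.0]
# 1.0 [4.0, 0.5, 0.25, 1.0]
```

Raising $\alpha$ widens the spread of $s$ (here from 1× at $\alpha = 0$ to 16× between the largest and smallest channel at $\alpha = 1$). That protects salient channels more but also inflates the row maxima that set the quantization step, which is the trade-off the grid search resolves.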

---

```python
import torch


def pseudo_quantize_tensor(w, n_bits, scale, zero):
    """Simulated min-max quantization: quantize to n_bits integers, then dequantize."""
    w_q = torch.round(w / scale + zero)
    w_q = torch.clamp(w_q, 0, 2**n_bits - 1)
    return (w_q - zero) * scale


def awq_from_scratch(w, x, n_bits=4, n_grid=20):
    """
    Simplified AWQ scale search.
    w: [out_features, in_features] - weight matrix (nn.Linear layout)
    x: [n_tokens, in_features]     - calibration activations
    Returns the best per-input-channel scale s and the alpha that produced it.
    """
    # 1. Activation statistic s_X: mean |x| of each input channel
    s_x = x.abs().mean(dim=0)

    # Full-precision reference output
    org_out = x @ w.t()
    max_int = 2**n_bits - 1

    best_error, best_s, best_alpha = float("inf"), None, None

    # 2. Grid search over alpha in [0, 1]; alpha = 0 gives s = 1 (plain RTN)
    print("Searching for optimal AWQ scaling factor...")
    for alpha in torch.linspace(0, 1, n_grid).tolist():
        # Candidate scale s = s_X ** alpha, clamped away from zero and
        # normalized so that max(s) * min(s) == 1
        s = s_x.pow(alpha).clamp(min=1e-4)
        s = s / torch.sqrt(s.max() * s.min())

        # Scale up the input channels of W (column j is multiplied by s[j])
        w_scaled = w * s.view(1, -1)

        # Per-row (output channel) min-max quantization parameters
        w_min = w_scaled.amin(dim=1, keepdim=True)
        w_max = w_scaled.amax(dim=1, keepdim=True)
        scale = (w_max - w_min).clamp(min=1e-5) / max_int
        zero = torch.round(-w_min / scale)
        w_q = pseudo_quantize_tensor(w_scaled, n_bits, scale, zero)

        # Undo the scale. Dividing W by s here is equivalent to
        # multiplying X by diag(1/s) at inference time.
        w_q_final = w_q / s.view(1, -1)

        # Output MSE against the full-precision layer
        cur_out = x @ w_q_final.t()
        err = (org_out - cur_out).pow(2).mean().item()

        if err < best_error:
            best_error, best_s, best_alpha = err, s, alpha

    print(f"Best alpha: {best_alpha:.3f}, best error: {best_error:.6f}")
    return best_s, best_alpha


# Test implementation
torch.manual_seed(0)
in_features, out_features = 512, 1024
w = torch.randn(out_features, in_features)
x = torch.randn(16, in_features)
x[:, :10] *= 10.0  # Make the first 10 input features salient

s, alpha = awq_from_scratch(w, x)
print(f"Scale factor for salient feature 0: {s[0].item():.4f}")
print(f"Scale factor for normal feature 50: {s[50].item():.4f}")
```
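At inference time the $1/s$ does not need its own operation: it can be folded into whatever produces $X$. A minimal sketch, assuming the quantized layer reads directly from an `nn.LayerNorm` (as the attention projections do in many transformer blocks). The random `s` is only a stand-in for the scale found by the search, and quantization of the folded weight is omitted so the equivalence check is exact.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_out = 16, 32
ln = nn.LayerNorm(d_model)
fc = nn.Linear(d_model, d_out)
with torch.no_grad():
    # Non-trivial LayerNorm affine parameters so the check is meaningful
    ln.weight.uniform_(0.5, 1.5)
    ln.bias.uniform_(-0.1, 0.1)

h = torch.randn(8, d_model)
y_ref = fc(ln(h))

s = torch.rand(d_model) + 0.5      # stand-in for the searched per-channel scale

with torch.no_grad():
    # LayerNorm output gamma * n + beta becomes (gamma * n + beta) / s = X diag(1/s)
    ln.weight.div_(s)
    ln.bias.div_(s)
    # nn.Linear stores W as [out, in]; scaling column j by s[j] gives diag(s) W.
    # This folded weight is the tensor that would then be quantized.
    fc.weight.mul_(s.view(1, -1))

y_folded = fc(ln(h))
print("max |diff| after folding:", (y_ref - y_folded).abs().max().item())
```

The linear layer's bias is untouched because it is added after the matrix product, where the $s$ and $1/s$ factors have already cancelled.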