# 🧪 GPTQ From Scratch: Optimal Brain Quantization (2023)

[!["Open In Colab"](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/model-size-reduction/blob/main/chronology/gptq_demo.ipynb)

## 📖 The Theory: OBQ & The Hessian Matrix

GPTQ (named after the paper *GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers*) is a one-shot, post-training method for quantizing weights to 3–4 bits. It builds on **Optimal Brain Quantization (OBQ)**. OBQ minimizes the error that quantization introduces into the layer's *output*, so each weight's rounding error is weighted by how sensitive the output is to that weight.

### The Objective Function
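We want to find the $W_{quant}$ that minimizes the squared error of the layer's output on calibration inputs $X$ (one sample per column). Not all weights are equal, though: some "hurt" more when rounded. This sensitivity is captured by the **Hessian matrix** of the objective, $H = 2XX^T$.

$$\min_{W_{quant}} \|W X - W_{quant} X\|_F^2 = \min_{W_{quant}} \tfrac{1}{2} \sum_{r} (w_r - w_{quant,r})\, H\, (w_r - w_{quant,r})^T$$

Here $w_r$ is row $r$ of $W$. Each diagonal entry $H_{ii}$ measures how sensitive the output is to input column $i$ on its own. The off-diagonal entries couple the columns, and this coupling is what lets GPTQ compensate one weight's rounding error with the other weights.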

### The Greedy Update
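GPTQ quantizes $W$ one column at a time, processing all rows in parallel. After quantizing column $i$, it adjusts the *remaining*, not-yet-quantized columns $F$ to compensate for the error it just introduced:

$$\delta w_F = -\frac{w_i - \mathrm{quant}(w_i)}{[H_F^{-1}]_{ii}} \cdot [H_F^{-1}]_{:,i}$$

Here $H_F^{-1}$ is the inverse of the Hessian restricted to the columns that are still unquantized (including $i$). Re-inverting after every step would be expensive, so GPTQ reads all of these rows at once from the upper Cholesky factor of $H^{-1}$. This compensatory step is what makes GPTQ so accurate at 4 bits.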

---

In [None]:
import torch
import torch.nn as nn

class GPTQManual:
    """Minimal GPTQ (Optimal Brain Quantization) for a single ``nn.Linear`` layer.

    Usage:
        1. Construct with the layer.
        2. Feed calibration inputs through :meth:`add_batch`.
        3. Call :meth:`quantize`, which overwrites ``layer.weight`` in place
           with quantized values.
    """

    def __init__(self, layer):
        self.layer = layer
        self.dev = layer.weight.device
        # nn.Linear stores its weight as (out_features, in_features); the
        # Hessian has one row/column per input feature (weight column).
        self.rows = layer.weight.shape[0]
        self.columns = layer.weight.shape[1]
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0

    def add_batch(self, inp):
        """Fold a batch of layer inputs into the running Hessian estimate.

        Keeps ``H`` equal to ``2 / nsamples * sum(x x^T)`` over all rows seen
        so far. ``nsamples`` counts entries along the leading batch dimension;
        a 2-D input is treated as a single batch entry.

        Args:
            inp: input activations of the layer, ``(..., in_features)``.
        """
        if inp.dim() == 2:
            inp = inp.unsqueeze(0)
        n_new = inp.shape[0]
        if isinstance(self.layer, nn.Linear):
            # Flatten all leading dims into rows, then put features first:
            # (in_features, n_rows), so that inp @ inp.T is (in, in).
            inp = inp.reshape(-1, inp.shape[-1]).t()
        # Rescale the old running average before adding the new contribution.
        self.H *= self.nsamples / (self.nsamples + n_new)
        self.nsamples += n_new
        inp = inp.float()
        self.H += 2 / self.nsamples * torch.matmul(inp, inp.t())

    def quantize(self, bits=4, percdamp=0.01):
        """Quantize ``layer.weight`` in place with the GPTQ column-by-column update.

        Args:
            bits: width of the symmetric integer grid
                ``[-2**(bits-1), 2**(bits-1) - 1]``.
            percdamp: dampening added to the Hessian diagonal, as a fraction
                of its mean diagonal value, so that it is positive definite.

        Raises:
            RuntimeError: if :meth:`add_batch` was never called (the Hessian
                would be all zeros and every weight would be zeroed).
        """
        if self.nsamples == 0:
            raise RuntimeError("add_batch() must be called before quantize()")

        W = self.layer.weight.data.clone().float()
        # Work on a copy: quantize() must not corrupt the accumulated statistics.
        H = self.H.clone()

        # Inputs that were always zero carry no information: pin their Hessian
        # diagonal to 1 (keeps H invertible) and drop the matching weights.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        # Dampening makes H positive definite so the Cholesky factorizations succeed.
        damp = percdamp * torch.mean(torch.diag(H))
        H += damp * torch.eye(self.columns, device=self.dev)

        # Invert H via Cholesky, then keep the *upper* Cholesky factor U of H^-1.
        # At step i, U[i, i:] / U[i, i] equals row i of the inverse Hessian of the
        # still-unquantized columns i.. divided by its diagonal: exactly the
        # ratio the OBQ update needs. No re-inversion after each column.
        H_inv = torch.cholesky_inverse(torch.linalg.cholesky(H))
        H_inv = torch.linalg.cholesky(H_inv, upper=True)

        Q = torch.zeros_like(W)
        qmin, qmax = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1

        # For demonstration: one symmetric min-max scale for the whole tensor.
        scale = torch.max(torch.abs(W)) / qmax
        if scale == 0:
            scale = torch.ones_like(scale)  # all-zero weights: avoid 0/0 -> NaN

        print(f"Starting GPTQ loop for {self.columns} columns...")
        for i in range(self.columns):
            w = W[:, i]
            # 1. Round column i (already compensated by earlier steps) onto the grid.
            q_val = torch.round(w / scale).clamp(qmin, qmax) * scale
            Q[:, i] = q_val

            # 2. Rounding error, scaled by the Cholesky diagonal.
            err = (w - q_val) / H_inv[i, i]

            # 3. OBQ update: spread the error over the remaining columns i..
            W[:, i:] -= err.unsqueeze(1) * H_inv[i, i:].unsqueeze(0)

        self.layer.weight.data = Q.to(self.layer.weight.dtype)
        print("Quantization Complete.")

In [None]:
# Test implementation: quantize a random Linear layer with GPTQManual and
# compare its output error on the calibration data with plain rounding (RTN).
layer = nn.Linear(128, 256)
original_weight = layer.weight.data.clone()

gptq = GPTQManual(layer)
# Simulate 10 batches of calibration data (one 128-d sample each)
calibration_batches = [torch.randn(1, 128) for _ in range(10)]
for calibration_data in calibration_batches:
    gptq.add_batch(calibration_data)

bits = 4
gptq.quantize(bits=bits)

unique_vals = torch.unique(layer.weight.data).shape[0]
print(f"Unique values in weight matrix: {unique_vals} (at most {2**bits} for {bits}-bit)")

# RTN baseline with the same per-tensor symmetric scale, measured in output space.
X_cal = torch.cat(calibration_batches)              # (10, 128)
qmax = 2 ** (bits - 1) - 1
scale = original_weight.abs().max() / qmax
W_rtn = torch.round(original_weight / scale).clamp(-qmax - 1, qmax) * scale
err_rtn = ((original_weight - W_rtn) @ X_cal.t()).pow(2).sum().item()
err_gptq = ((original_weight - layer.weight.data) @ X_cal.t()).pow(2).sum().item()
print(f"Output error on calibration data  RTN: {err_rtn:.4f}  GPTQ: {err_gptq:.4f}")

## 🔢 Worked Example with Numbers

To make the update concrete, let's trace the math through a tiny, hand-traceable example.

In [None]:
# Tiny example: GPTQ OBQ update on a single weight row [w0, w1, w2]
import torch

print("=== GPTQ: column-by-column compensation on 3 weights ===\n")

W = torch.tensor([[0.6, -0.4, 0.8]])          # 1 output × 3 inputs
print(f"Original weights W = {W.tolist()[0]}")

# 5 calibration input samples (3 features)
X = torch.tensor([
    [1.2, 0.3, 0.8],
    [0.5, 1.8, 0.2],
    [0.9, 0.4, 1.5],
    [0.3, 1.1, 0.7],
    [1.0, 0.6, 0.9],
])
# Hessian H = 2 * X^T @ X
H = 2 * X.t() @ X
print("\nHessian H = 2·X^T·X")
print(H.numpy().round(3))
print(f"H diagonal (sensitivity): {H.diag().tolist()}")
# Derive the most sensitive weight from the data instead of hard-coding it.
most_sensitive = int(H.diag().argmax())
print(f"  → w[{most_sensitive}] is most sensitive "
      f"(H[{most_sensitive},{most_sensitive}]={H[most_sensitive, most_sensitive]:.2f}) — quantizing it badly hurts most")

# Dampen H, invert it via Cholesky, and keep the upper Cholesky factor U of H^-1.
# U[i, i:] / U[i, i] is the inverse-Hessian row (over its diagonal) of the columns
# that are still unquantized at step i, so the inverse stays correct at every step.
H_damp = H + 0.01 * H.diag().mean() * torch.eye(3)
U = torch.linalg.cholesky(torch.cholesky_inverse(torch.linalg.cholesky(H_damp)), upper=True)

bits = 4
qmin, qmax = -(2**(bits-1)), 2**(bits-1) - 1
scale = W.abs().max() / qmax
W_q = W.clone().float()
Q = torch.zeros_like(W_q)

print(f"\n4-bit scale = {scale:.4f}")
print(f"\n{'Col':>4}  {'w_cur':>8}  {'q':>8}  {'round_err':>10}  {'compensation on remaining cols'}")
for i in range(3):
    # clone(): W_q[0, i] is a view that the in-place update below overwrites.
    w_i = W_q[0, i].clone()
    q_i = (w_i / scale).round().clamp(qmin, qmax) * scale
    Q[0, i] = q_i
    err = (w_i - q_i) / U[i, i]
    comp = err * U[i, i:]
    W_q[0, i:] -= comp
    print(f"{i:>4}  {w_i.item():+8.4f}  {q_i.item():+8.4f}  {(w_i-q_i).item():+10.4f}  {comp.tolist()}")

print(f"\nOriginal : {W.tolist()[0]}")
print(f"Quantized: {Q.tolist()[0]}")
print(f"Per-elem error: {(W - Q).abs().tolist()[0]}")

# Compare with plain round-to-nearest (RTN). The quantity GPTQ minimizes is the
# error of the layer *output* on the calibration data, not the per-element error.
Q_rtn = (W / scale).round().clamp(qmin, qmax) * scale
err_rtn = ((W - Q_rtn) @ X.t()).pow(2).sum().item()
err_gptq = ((W - Q) @ X.t()).pow(2).sum().item()
print(f"\nRTN  quantized: {Q_rtn.tolist()[0]}  output error ||(W-Q)X^T||² = {err_rtn:.5f}")
print(f"GPTQ quantized: {Q.tolist()[0]}  output error ||(W-Q)X^T||² = {err_gptq:.5f}")
print("With compensation, later weights absorb the rounding error of earlier ones,")
print("so GPTQ may round differently from RTN but never loses in output error here.")


## 🧪 GPT-2 Evaluation

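Apply 4-bit round-to-nearest (RTN) quantization to all 2D weight matrices of GPT-2 and compare perplexity before and after quantization. RTN is the per-weight rounding step that GPTQ builds on, applied here without GPTQ's Hessian-based error compensation.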

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import copy
import torch

model_id = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_id)
model = GPT2LMHeadModel.from_pretrained(model_id).eval()

text = "The quick brown fox jumps over the lazy dog. Transformers are powerful sequence models."
inputs = tokenizer(text, return_tensors="pt")


def perplexity(mdl, inputs):
    """Return exp(mean token cross-entropy) of `mdl` on `inputs` (labels = input ids)."""
    with torch.no_grad():
        loss = mdl(**inputs, labels=inputs["input_ids"]).loss
    return torch.exp(loss).item()


def quantize_rtn_(mdl, bits=4):
    """Round every 2-D parameter of `mdl` in place onto a symmetric `bits`-bit grid.

    Uses one min-max scale per tensor. This is plain round-to-nearest (RTN),
    i.e. the per-weight rounding step of GPTQ *without* its Hessian-based
    error compensation.
    """
    qmin, qmax = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    for _, param in mdl.named_parameters():
        if param.dim() != 2:
            continue
        W = param.data.float()
        scale = torch.max(torch.abs(W)) / qmax
        if scale == 0:
            continue  # all-zero tensor is already exactly representable
        W_q = torch.round(W / scale).clamp(qmin, qmax) * scale
        param.data = W_q.to(param.dtype)


baseline_ppl = perplexity(model, inputs)
print(f"Baseline GPT-2 Perplexity: {baseline_ppl:.2f}")

bits = 4
model_q = copy.deepcopy(model)
quantize_rtn_(model_q, bits=bits)

quant_ppl = perplexity(model_q, inputs)
print(f"RTN 4-bit GPT-2 Perplexity: {quant_ppl:.2f}")
print(f"Delta:                      {quant_ppl - baseline_ppl:+.2f}")
