# 🧪 NF4 From Scratch: NormalFloat 4-bit Quantization (2023)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/model-size-reduction/blob/main/chronology/nf4_demo.ipynb)

## 📖 Theory: Information-Theoretic Optimality

NF4 (NormalFloat 4) is the data type introduced in **QLoRA** (Dettmers et al., 2023).
The paper describes it as **information-theoretically optimal** for zero-centered,
normally distributed data, which pre-trained neural-network weights approximately follow.

### Why Not INT4?

Standard INT4 places its 16 levels **uniformly** across $[x_{\min}, x_{\max}]$.
For a Gaussian weight distribution, most probability mass sits near zero,
so many uniformly-spaced levels fall in rarely-visited tails.
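NF4 instead places levels at **quantiles** of $\mathcal{N}(0,1)$.
The simplified construction used in this notebook is

$$q_i = \Phi^{-1}\!\left(\frac{i + 0.5}{16}\right), \quad i = 0,\dots,15$$

where $\Phi^{-1}$ is the inverse standard-normal CDF and the resulting levels are rescaled to $[-1, 1]$.
Each level sits at the probability midpoint of an equal $1/16$ slice, so no level is wasted on rare values.
The official NF4 table in QLoRA differs slightly: it is built asymmetrically so that zero is represented exactly.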
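
To make the contrast concrete, here is a minimal sketch that builds both grids on $[-1, 1]$ and compares their spacing near zero and in the tails. It uses only the Python standard library (`statistics.NormalDist`) rather than the notebook's torch/scipy stack, and the helper names `uniform_levels` and `quantile_levels` are illustrative, not taken from QLoRA.

```python
from statistics import NormalDist

def uniform_levels(n=16):
    """INT4-style grid: n evenly spaced levels on [-1, 1]."""
    return [-1 + 2 * i / (n - 1) for i in range(n)]

def quantile_levels(n=16):
    """Simplified NF4-style grid: N(0,1) quantiles at (i + 0.5) / n, rescaled to [-1, 1]."""
    nd = NormalDist()
    raw = [nd.inv_cdf((i + 0.5) / n) for i in range(n)]
    top = max(abs(v) for v in raw)
    return [v / top for v in raw]

uni = uniform_levels()
qnt = quantile_levels()

uniform_step = uni[1] - uni[0]   # identical everywhere on the uniform grid
central_gap = qnt[8] - qnt[7]    # spacing of the two levels around zero
tail_gap = qnt[15] - qnt[14]     # spacing of the two outermost levels

print(f"uniform step : {uniform_step:.3f}")
print(f"quantile gaps: {central_gap:.3f} (center) vs {tail_gap:.3f} (tail)")
```

The quantile grid is finer than the uniform one around zero and coarser in the tails, which is the trade the section describes.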

### Per-Block Absmax Normalization

Each weight block is normalised by its absolute maximum before lookup:

$$\hat{w} = \frac{w}{\max|W|} \in [-1,\,1]$$

Each normalized weight is then mapped to the nearest NF4 level. The scale $\max|W|$ is stored
once per block, alongside that block's 4-bit indices, and multiplied back at dequantization time.
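
The per-block scheme can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the QLoRA implementation: the block size of 64 is the value the QLoRA paper uses (this notebook does not set one), the level table is the simplified quantile grid from above, and `quantize_blockwise` / `dequantize_blockwise` are hypothetical helper names.

```python
import random
from statistics import NormalDist

BLOCK_SIZE = 64  # assumption: the weight block size reported in the QLoRA paper

def simplified_levels():
    """Simplified NF4-style levels: N(0,1) quantiles rescaled to [-1, 1]."""
    nd = NormalDist()
    raw = [nd.inv_cdf((i + 0.5) / 16) for i in range(16)]
    top = max(abs(v) for v in raw)
    return [v / top for v in raw]

def quantize_blockwise(weights, levels, block_size=BLOCK_SIZE):
    """Return one 4-bit index per weight and one absmax scale per block."""
    indices, scales = [], []
    for start in range(0, len(weights), block_size):
        block = weights[start:start + block_size]
        scale = max(abs(w) for w in block) or 1.0  # all-zero block: avoid dividing by zero
        scales.append(scale)
        for w in block:
            x = w / scale  # now in [-1, 1]
            indices.append(min(range(16), key=lambda k: abs(levels[k] - x)))
    return indices, scales

def dequantize_blockwise(indices, scales, levels, block_size=BLOCK_SIZE):
    """Look up each level and multiply by the scale of the block it belongs to."""
    return [levels[idx] * scales[pos // block_size] for pos, idx in enumerate(indices)]

random.seed(0)
weights = [random.gauss(0.0, 0.02) for _ in range(256)]  # synthetic stand-in for real weights
levels = simplified_levels()
indices, scales = quantize_blockwise(weights, levels)
restored = dequantize_blockwise(indices, scales, levels)

max_err = max(abs(a - b) for a, b in zip(weights, restored))
print(f"{len(indices)} indices, {len(scales)} scales, max abs error {max_err:.5f}")
```

Because each block carries its own scale, a single outlier only coarsens the grid for the 64 weights that share its block rather than for the whole tensor.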

### Double Quantization

QLoRA further compresses per-block scales from 32-bit floats to **8-bit floats**,
saving an additional $\approx 0.37$ bits per parameter at negligible accuracy loss.
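
The $\approx 0.37$ figure can be reproduced with a few lines of arithmetic. The block sizes below (64 weights per scale, 256 scales per second-level scale) come from the QLoRA paper rather than from this notebook, so treat them as assumptions of the sketch.

```python
WEIGHT_BLOCK = 64   # weights sharing one scale (QLoRA paper value)
SCALE_BLOCK = 256   # first-level scales sharing one second-level FP32 scale (QLoRA paper value)

# Without double quantization: one 32-bit scale per block of 64 weights
bits_before = 32 / WEIGHT_BLOCK

# With double quantization: one 8-bit scale per block, plus one 32-bit
# second-level scale amortized over 64 * 256 weights
bits_after = 8 / WEIGHT_BLOCK + 32 / (WEIGHT_BLOCK * SCALE_BLOCK)

saving = bits_before - bits_after
print(f"{bits_before:.3f} {bits_after:.3f} {saving:.3f}")  # prints: 0.500 0.127 0.373
```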

### NF4 vs INT4 at a Glance

| Property | INT4 | NF4 |
|---|---|---|
| Level placement | Uniform | Quantile-based |
| Optimal for | Uniform distributions | Normal distributions |
| Typical weight MSE | Higher | Lower on near-Gaussian weights (~30% is a rough figure) |
| Used in | General PTQ | QLoRA fine-tuning |

### Limitations
* Assumes weights are **normally distributed** -- fails for bimodal or heavy-tailed layers.
* Requires storing a per-block FP8 scale, adding a small memory overhead.
* Dequantization introduces a lookup step that can be slower than pure INT arithmetic.
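
One rough way to flag layers that break the normality assumption is to look at excess kurtosis, which is close to 0 for Gaussian data and grows with heavier tails. This is a generic statistical heuristic, not something QLoRA prescribes, and the synthetic samples below stand in for real weights; any cut-off you choose is a judgment call.

```python
import random

def excess_kurtosis(values):
    """Sample excess kurtosis: about 0 for Gaussian data, larger for heavy tails."""
    n = len(values)
    mean = sum(values) / n
    m2 = sum((v - mean) ** 2 for v in values) / n
    m4 = sum((v - mean) ** 4 for v in values) / n
    return m4 / (m2 ** 2) - 3.0

random.seed(0)
gaussian = [random.gauss(0.0, 1.0) for _ in range(20000)]
# Laplace samples (heavy-tailed): exponential magnitude with a random sign
laplace = [random.choice((-1.0, 1.0)) * random.expovariate(1.0) for _ in range(20000)]

k_gaussian = excess_kurtosis(gaussian)
k_laplace = excess_kurtosis(laplace)
print(f"Gaussian excess kurtosis: {k_gaussian:+.2f}")
print(f"Laplace excess kurtosis : {k_laplace:+.2f}")
```

A layer whose weights score far from 0 is a candidate for closer inspection before trusting a quantile grid designed for normal data.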

---

In [None]:
import torch
from scipy.stats import norm

def create_nf4_map():
    """Build a simplified 16-level NF4-style lookup table, normalized to [-1, 1]."""
    # Evaluate the standard-normal quantile function at the centers of
    # 16 equal-probability buckets: (i + 0.5) / 16 for i = 0..15.
    offset = 1.0 / (2 * 16)
    p_values = torch.linspace(offset, 1 - offset, 16, dtype=torch.float64)

    # Note: this is a simplification. The official NF4 table in QLoRA is built
    # asymmetrically so that zero is one of the levels; this symmetric version
    # has no exact zero.
    nf4_values = norm.ppf(p_values.numpy())
    nf4_values = torch.from_numpy(nf4_values).float()

    # Normalize so the outermost levels sit at -1 and +1
    nf4_values = nf4_values / nf4_values.abs().max()
    return nf4_values.sort()[0]

nf4_map = create_nf4_map()
print(f"NF4 Lookup Table (16 levels):\n{nf4_map}")

## 🛠️ Implementation: Manual NF4 Mapping

Let's implement the mapping from FP32 to the closest NF4 level.

In [None]:
def quantize_nf4(w, nf4_map):
    """
    Quantize a tensor to the nearest NF4 level, then dequantize it again.
    w: Tensor of any range; it is normalized by its absolute maximum here.
    Returns (dequantized tensor, level indices).
    Note: one absmax is used for the whole tensor, not one per block.
    """
    # 1. Normalize to [-1, 1] (the clamp guards against an all-zero tensor)
    abs_max = torch.max(torch.abs(w)).clamp(min=1e-12)
    w_norm = w / abs_max

    # 2. Find the closest level for every weight.
    # searchsorted would use less memory; the pairwise-difference
    # method is used here for clarity.
    w_flat = w_norm.reshape(-1, 1)
    diff = torch.abs(w_flat - nf4_map.view(1, -1))
    indices = torch.argmin(diff, dim=1)

    # 3. Simulate dequantization
    q_w = nf4_map[indices].view(w.shape)
    return q_w * abs_max, indices

# Test with Normal data
w_raw = torch.randn(1024, 1024)
w_nf4, w_indices = quantize_nf4(w_raw, nf4_map)

error = (w_raw - w_nf4).pow(2).mean()
print(f"Mean Squared Error: {error:.6f}")
print(f"Compression: 32-bit to 4-bit indices (8x smaller storage)")
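
The pairwise-difference lookup builds an N×16 matrix, which gets expensive for large tensors. A lighter alternative is a binary search over the midpoints between adjacent levels. The sketch below shows the idea with the standard library's `bisect` and checks it against a brute-force scan; `torch.bucketize` applies the same idea to whole tensors. The level table is rebuilt here with `statistics.NormalDist` so the example runs on its own, and the function names are illustrative.

```python
import random
from bisect import bisect_left
from statistics import NormalDist

# Simplified NF4-style levels (sorted ascending), rescaled to [-1, 1]
nd = NormalDist()
raw = [nd.inv_cdf((i + 0.5) / 16) for i in range(16)]
top = max(abs(v) for v in raw)
levels = [v / top for v in raw]

# 15 decision boundaries: a value belongs to level k if it lies between
# midpoints[k - 1] and midpoints[k]
midpoints = [(a + b) / 2 for a, b in zip(levels, levels[1:])]

def nearest_by_search(x):
    """O(log 16) lookup: count how many midpoints lie below x."""
    return bisect_left(midpoints, x)

def nearest_by_scan(x):
    """O(16) reference: scan all levels for the smallest distance."""
    return min(range(16), key=lambda k: abs(levels[k] - x))

random.seed(0)
samples = [random.uniform(-1.0, 1.0) for _ in range(1000)]
agree = all(nearest_by_search(x) == nearest_by_scan(x) for x in samples)
print("binary search matches brute force:", agree)
```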

## 🔢 Worked Example with Numbers

Before applying this to a real model, let's trace through the math with a tiny, hand-traceable example.

In [None]:
# Tiny example: quantize 6 weights to NF4 step-by-step
# (nf4_map and quantize_nf4 are defined in the cell above)
import torch

w = torch.tensor([0.9, -0.3, 0.05, -0.85, 0.4, -0.1])
print(f"Original weights: {[round(v,2) for v in w.tolist()]}")

# Step 1 - Normalize to [-1, 1]
abs_max = w.abs().max()
w_norm  = w / abs_max
print(f"\nStep 1  Normalize (÷{abs_max:.2f}):")
print(f"  {[round(v,4) for v in w_norm.tolist()]}")

# Step 2 - Show the 16 NF4 levels
print("\nStep 2  NF4 lookup table (16 levels):")
for i, v in enumerate(nf4_map.tolist()):
    print(f"  [{i:2d}]  {v:+.4f}")

# Step 3 - Map each normalized value to the nearest NF4 level
diff    = (w_norm.view(-1, 1) - nf4_map.view(1, -1)).abs()
indices = diff.argmin(dim=1)
w_q     = nf4_map[indices]
print("\nStep 3  Nearest NF4 level + dequantize (×abs_max):")
for orig, nrm, idx, q in zip(w.tolist(), w_norm.tolist(), indices.tolist(), w_q.tolist()):
    recon = q * abs_max.item()
    err   = abs(orig - recon)
    print(f"  {orig:+.4f}  →  norm {nrm:+.4f}  →  NF4[{idx:2d}]={q:+.4f}  →  recon {recon:+.4f}  (err {err:.4f})")

print(f"\nMean absolute error : {(w - w_q*abs_max).abs().mean():.6f}")
print(f"Storage: 4-bit index (0-15) instead of 32-bit float  →  8× smaller")
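
The 8× figure assumes the indices really occupy 4 bits each, which in practice means packing two of them into every byte. Below is a minimal sketch of that packing; the index values are made up for illustration (they are not the output of the cell above), and `pack_nibbles` / `unpack_nibbles` are hypothetical helper names.

```python
def pack_nibbles(indices):
    """Pack 4-bit indices (0-15) two per byte; pads with a 0 if the count is odd."""
    if any(not 0 <= i <= 15 for i in indices):
        raise ValueError("indices must be in 0..15")
    padded = list(indices) + [0] * (len(indices) % 2)
    # First index of each pair goes in the high nibble, second in the low nibble
    return bytes((hi << 4) | lo for hi, lo in zip(padded[0::2], padded[1::2]))

def unpack_nibbles(packed, count):
    """Recover the first `count` 4-bit indices from the packed bytes."""
    out = []
    for byte in packed:
        out.append(byte >> 4)
        out.append(byte & 0x0F)
    return out[:count]

indices = [15, 4, 8, 0, 11, 6]  # illustrative values only
packed = pack_nibbles(indices)
restored = unpack_nibbles(packed, len(indices))
print(len(indices), "indices ->", len(packed), "bytes")  # prints: 6 indices -> 3 bytes
```

Note that the 8× ratio counts only the indices; the absmax scale adds a little on top.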


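## 🧪 GPT-2 Evaluation

Apply the method to all 2D weight matrices of GPT-2 and compare perplexity before and after quantization.
Note that `quantize_nf4` uses one absmax per tensor rather than per block, and that perplexity on a single short sentence is only a quick sanity check.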

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch, copy

model_id = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_id)
model = GPT2LMHeadModel.from_pretrained(model_id).eval()

text = "The quick brown fox jumps over the lazy dog. Transformers are powerful sequence models."
inputs = tokenizer(text, return_tensors="pt")

def perplexity(mdl, inputs):
    with torch.no_grad():
        loss = mdl(**inputs, labels=inputs["input_ids"]).loss
    return torch.exp(loss).item()

baseline_ppl = perplexity(model, inputs)
print(f"Baseline GPT-2 Perplexity: {baseline_ppl:.2f}")

model_q = copy.deepcopy(model)
for name, param in model_q.named_parameters():
    if param.dim() == 2:
        q_w, _ = quantize_nf4(param.data, nf4_map)
        param.data = q_w

quant_ppl = perplexity(model_q, inputs)
print(f"NF4 GPT-2 Perplexity:      {quant_ppl:.2f}")
print(f"Delta:                     {quant_ppl - baseline_ppl:+.2f}")


## 📚 References

1. **Dettmers, T., Pagnoni, A., Holtzman, A., & Zettlemoyer, L.** (2023).  
   *QLoRA: Efficient Finetuning of Quantized LLMs.* NeurIPS 2023.  
   [arXiv:2305.14314](https://arxiv.org/abs/2305.14314)

2. **Dettmers, T., & Zettlemoyer, L.** (2022).  
   *The case for 4-bit precision: k-bit Inference Scaling Laws.* ICML 2023.  
   [arXiv:2212.09720](https://arxiv.org/abs/2212.09720)

3. **Hu, E., Shen, Y., Wallis, P., et al.** (2021).  
   *LoRA: Low-Rank Adaptation of Large Language Models.* ICLR 2022.  
   [arXiv:2106.09685](https://arxiv.org/abs/2106.09685)
