# 🧪 BitNet 1.58b From Scratch: Ternary LLMs (2025)

[!["Open In Colab"](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/model-size-reduction/blob/main/chronology/bitnet_demo.ipynb)

## 📖 The Theory: The Era of 1-bit LLMs

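BitNet 1.58b (named BitNet b1.58 in the original paper) is a landmark architecture in which every weight of its linear layers is restricted to three values: `{-1, 0, 1}`. It is called **1.58-bit** because a value with three possible states carries $\log_2(3) \approx 1.58$ bits of information.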
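One way to see where 1.58 comes from, and how close real storage can get to it, is to pack ternary values in base 3. The sketch below is an illustrative encoding chosen for this notebook, not necessarily the layout BitNet kernels use: since $3^5 = 243 \le 256$, five ternary weights fit in one byte, or 1.6 bits per weight. `pack5` and `unpack5` are hypothetical helper names.

```python
import math

def pack5(trits):
    """Pack five ternary values in {-1, 0, 1} into one byte (base-3 digits)."""
    assert len(trits) == 5
    code = 0
    for t in reversed(trits):
        code = code * 3 + (t + 1)  # map {-1, 0, 1} -> digits {0, 1, 2}
    return code  # at most 3**5 - 1 = 242, so it fits in a uint8

def unpack5(code):
    """Inverse of pack5: recover five ternary values from one byte."""
    trits = []
    for _ in range(5):
        trits.append(code % 3 - 1)  # least significant digit is the first weight
        code //= 3
    return trits

print(f"log2(3) = {math.log2(3):.4f} bits per ternary value")
print(f"base-3 packing: 8 bits / 5 weights = {8 / 5:.1f} bits per weight")

w = [1, 0, -1, -1, 1]
code = pack5(w)
print(code, unpack5(code))  # → 167 [1, 0, -1, -1, 1]
```

The 0.015-bit gap between 1.6 and $\log_2(3)$ is the cost of keeping each group aligned to a whole byte.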

### Why Ternary?
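- **Addition-only Math**: Multiplying by `1` or `-1` is just adding or subtracting, and multiplying by `0` can be skipped entirely. This removes the multiplication from the multiply-accumulate (MAC) operation at the heart of matrix multiplication, and a multiplier costs far more energy and chip area than an adder.
- **Hardware Efficiency**: A ternary weight needs about 1.58 bits instead of the 16 bits of FP16, so roughly 10× less weight data has to move from DRAM. This matters because LLM inference is often limited by memory bandwidth rather than arithmetic.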
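To make the addition-only claim concrete, here is a minimal sketch that computes `X @ W` for a ternary `W` by adding the activations where a weight is `+1`, subtracting them where it is `-1`, and skipping zeros. The helper name `ternary_matmul_add_only` is hypothetical, a regular `torch` matmul is used only as a reference check, and the Python loop is for illustration, not speed.

```python
import torch

def ternary_matmul_add_only(X, W_t):
    """
    Compute X @ W_t for ternary W_t (values in {-1, 0, 1}) using only
    additions and subtractions of activation columns; zero weights are skipped.
    X: (batch, in_features) float tensor; W_t: (in_features, out_features) int tensor.
    """
    batch, _ = X.shape
    out_f = W_t.shape[1]
    Y = torch.zeros(batch, out_f, dtype=X.dtype)
    for j in range(out_f):
        plus = W_t[:, j] == 1    # inputs to add for output j
        minus = W_t[:, j] == -1  # inputs to subtract for output j
        Y[:, j] = X[:, plus].sum(dim=1) - X[:, minus].sum(dim=1)
    return Y

torch.manual_seed(0)
X = torch.randn(4, 16)
W_t = torch.randint(-1, 2, (16, 8), dtype=torch.int8)  # random ternary weights

Y_add = ternary_matmul_add_only(X, W_t)
Y_ref = X @ W_t.float()  # ordinary multiply-accumulate, used only as a check
print(torch.allclose(Y_add, Y_ref, atol=1e-5))
```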

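### The Quantization Logic
The BitNet b1.58 paper quantizes each weight matrix $W \in \mathbb{R}^{n \times m}$ with *absmean* scaling: divide by the mean absolute value, then round to the nearest value in `{-1, 0, 1}`:

$$\gamma = \frac{1}{nm}\sum_{i,j} |W_{ij}|$$
$$W_{quant} = \text{round}\left(\text{clip}\left(\frac{W}{\gamma + \epsilon}, -1, 1\right)\right)$$

The code in this notebook uses the simpler $abs\_max$ scale, $\gamma = \max(|W|)$, with the same clip-and-round step. For bell-shaped weight distributions, $\max(|W|)$ lies far out in the tail, so most scaled weights land in $(-0.5, 0.5)$ and round to `0`.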
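The sketch below compares the two per-tensor scales, abs_max and absmean (the mean of $|W|$, which is what the BitNet b1.58 paper uses), by counting how many weights stay nonzero and measuring the reconstruction error. A random Gaussian matrix stands in for a trained weight matrix, which is an assumption, and `ternary_quantize` is a hypothetical helper.

```python
import torch

def ternary_quantize(W, scale="absmean", eps=1e-7):
    """Per-tensor ternary quantization with an 'absmean' or 'absmax' scale."""
    if scale == "absmean":
        gamma = W.abs().mean()
    elif scale == "absmax":
        gamma = W.abs().max()
    else:
        raise ValueError(f"unknown scale: {scale}")
    W_q = torch.round(torch.clamp(W / (gamma + eps), -1, 1)).to(torch.int8)
    return W_q, gamma

torch.manual_seed(0)
W = torch.randn(256, 256)  # stand-in for a trained weight matrix (assumption)

results = {}
for scale in ("absmax", "absmean"):
    W_q, gamma = ternary_quantize(W, scale)
    W_rec = W_q.float() * gamma                      # dequantize
    nonzero = (W_q != 0).float().mean().item()       # fraction of +1/-1 weights
    mse = ((W - W_rec) ** 2).mean().item()           # reconstruction error
    results[scale] = (nonzero, mse)
    print(f"{scale:>7}: gamma={gamma.item():.3f}  nonzero={nonzero:.1%}  MSE={mse:.4f}")
```

With abs_max, only weights beyond half of the single largest magnitude survive; absmean keeps the threshold near the bulk of the distribution, so far more weights carry information.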

---

```python
import torch

def bitnet_quantize(W):
    """
    Per-tensor ternary quantization (abs_max variant).
    Scales the weights and maps them to {-1, 0, 1}.
    """
    # 1. Calculate gamma. The BitNet b1.58 paper uses absmean (mean of |W|);
    #    this notebook uses abs_max, which keeps every scaled weight in [-1, 1]
    #    but rounds most weights to 0 when the distribution has a long tail.
    gamma = torch.max(torch.abs(W))

    # 2. Scale, clip, and round to ternary.
    #    The small epsilon avoids division by zero for an all-zero W.
    W_scaled = W / (gamma + 1e-7)
    W_quant = torch.round(torch.clamp(W_scaled, -1, 1)).to(torch.int8)

    return W_quant, gamma

def bitnet_matmul(X, W_quant, gamma):
    """
    Simulated BitNet inference.
    In BitNet b1.58, X would be 8-bit quantized activations;
    here X is kept in floating point for simplicity.
    """
    # In hardware this is an addition-only matrix multiplication;
    # here we emulate it with a float32 matmul on the ternary values.
    result = torch.matmul(X.to(torch.float32), W_quant.to(torch.float32))

    # Rescale back to the original weight range
    return result * gamma

# Demonstration
W = torch.randn(512, 1024)
W_q, gamma = bitnet_quantize(W)

print(f"Original weight sample: {W[0, :3].tolist()}")
print(f"Ternary weight sample:  {W_q[0, :3].tolist()}")
print(f"Unique values in W_q: {torch.unique(W_q).tolist()}")
print(f"Fraction of nonzero weights: {(W_q != 0).float().mean():.4f}")

X = torch.randn(1, 512)
out = bitnet_matmul(X, W_q, gamma)
print(f"\nOutput shape: {tuple(out.shape)}")
print(f"Output Scale (Gamma): {gamma:.4f}")
```
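The `bitnet_matmul` docstring assumes 8-bit activations but never quantizes them. Below is a minimal sketch of that missing step, assuming per-token (per-row) absmax scaling to the symmetric int8 range `[-127, 127]`. The BitNet papers describe absmax activation quantization, but the exact range, rounding, and epsilon here are assumptions, `absmax_quantize_activations` and `bitlinear_forward` are hypothetical helpers, and the weights use the paper's absmean scale.

```python
import torch

def absmax_quantize_activations(X, bits=8, eps=1e-5):
    """Per-row absmax quantization to signed integers.
    Assumes a symmetric range [-(2**(bits-1) - 1), 2**(bits-1) - 1], i.e. [-127, 127] for 8 bits."""
    q_max = 2 ** (bits - 1) - 1
    gamma_x = X.abs().max(dim=-1, keepdim=True).values.clamp(min=eps)
    X_q = torch.round(X * q_max / gamma_x).clamp(-q_max, q_max).to(torch.int8)
    return X_q, gamma_x / q_max  # int8 values and the per-row step size

def bitlinear_forward(X, W_q, gamma_w):
    """Sketch: int8 activations times ternary weights, rescaled after accumulation."""
    X_q, step_x = absmax_quantize_activations(X)
    # Integer accumulation; int32 avoids overflowing int8 partial sums.
    acc = torch.matmul(X_q.to(torch.int32), W_q.to(torch.int32))
    return acc.float() * step_x * gamma_w

torch.manual_seed(0)
W = torch.randn(512, 1024)
gamma_w = W.abs().mean()  # absmean weight scale
W_q = torch.round(torch.clamp(W / (gamma_w + 1e-7), -1, 1)).to(torch.int8)

X = torch.randn(2, 512)
y_sim = bitlinear_forward(X, W_q, gamma_w)
y_ref = X @ (W_q.float() * gamma_w)  # same ternary weights, full-precision activations
rel_err = ((y_sim - y_ref).norm() / y_ref.norm()).item()
print(f"Relative error from 8-bit activations: {rel_err:.4f}")
```

All scaling happens once per output element after the integer accumulation, so the inner loop only ever combines int8 activations with ternary weights.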

## 🔢 Worked Example with Numbers

Now let's trace the same math through a tiny example that can be checked by hand.

```python
# Tiny example: BitNet 1.58b ternary quantization on a small vector
import torch

w = torch.tensor([0.8, -0.3, 0.05, -0.7, 0.2, -0.95])
print(f"Original weights: {[round(v, 2) for v in w.tolist()]}")

# Step 1: gamma = abs_max
gamma = w.abs().max()
print(f"\nStep 1  Gamma (abs_max) = {gamma:.4f}")

# Step 2: scale, clamp to [-1, 1], then round to {-1, 0, 1}
w_scaled = w / (gamma + 1e-7)
w_ternary = w_scaled.clamp(-1, 1).round().to(torch.int8)
print("\nStep 2  Scale → round → ternary:")
print(f"  {'orig':>7}  {'÷gamma':>8}  {'round':>6}")
for orig, sc, t in zip(w.tolist(), w_scaled.tolist(), w_ternary.tolist()):
    print(f"  {orig:+7.4f}  {sc:+8.4f}  {t:+6d}")

# Step 3: dequantize (multiply by gamma)
w_recon = w_ternary.float() * gamma
print(f"\nStep 3  Reconstruct (ternary × gamma={gamma:.4f}):")
print(f"  {'orig':>7}  {'ternary':>8}  {'recon':>8}  {'err':>8}")
for orig, t, r in zip(w.tolist(), w_ternary.tolist(), w_recon.tolist()):
    print(f"  {orig:+7.4f}  {t:+8d}  {r:+8.4f}  {abs(orig-r):8.4f}")

print(f"\nMSE: {((w - w_recon)**2).mean():.6f}")
print("\nHardware insight: in y = x @ W, each term x_i * W_ij with ternary W_ij is")
print("  +1 → ADD      x_i to the accumulator")
print("   0 → SKIP     (no operation)")
print("  -1 → SUBTRACT x_i from the accumulator")
```
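Applying the add/subtract/skip rule by hand to the ternary weights the worked example above produces, `[1, 0, 0, -1, 0, -1]` with gamma ≈ 0.95, gives a dot product that needs no multiplications until a single final rescale. The input vector `x` is hypothetical.

```python
# Ternary weights and abs_max scale from the worked example above
w_ternary = [1, 0, 0, -1, 0, -1]
gamma = 0.95
x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]  # hypothetical input activations

acc = 0.0
for xi, t in zip(x, w_ternary):
    if t == 1:
        acc += xi   # +1: add the activation
    elif t == -1:
        acc -= xi   # -1: subtract the activation
    # 0: skip, no work at all
y = acc * gamma     # one multiply per output, for the scale only

# Reference: ordinary dot product with the dequantized weights
y_ref = sum(xi * t * gamma for xi, t in zip(x, w_ternary))
print(acc, y, y_ref)  # acc = 1 - 4 - 6 = -9.0
```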


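## 🧪 GPT-2 Evaluation

Apply the abs_max ternary quantization to every 2D parameter of GPT-2 (this includes the token and position embeddings and, through weight tying, the LM head) and compare perplexity on a short sample sentence before and after quantization. BitNet b1.58 trains ternary weights from scratch; quantizing an already-trained full-precision model this way is post-training quantization, so expect a large perplexity increase.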

```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch, copy

model_id = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_id)
model = GPT2LMHeadModel.from_pretrained(model_id).eval()

text = "The quick brown fox jumps over the lazy dog. Transformers are powerful sequence models."
inputs = tokenizer(text, return_tensors="pt")

def perplexity(mdl, inputs):
    with torch.no_grad():
        loss = mdl(**inputs, labels=inputs["input_ids"]).loss
    return torch.exp(loss).item()

baseline_ppl = perplexity(model, inputs)
print(f"Baseline GPT-2 Perplexity:      {baseline_ppl:.2f}")

model_q = copy.deepcopy(model)
for name, param in model_q.named_parameters():
    if param.dim() == 2:
        w_q, gamma = bitnet_quantize(param.data)
        param.data = w_q.float() * gamma

quant_ppl = perplexity(model_q, inputs)
print(f"BitNet 1.58b GPT-2 Perplexity:  {quant_ppl:.2f}")
print(f"Delta:                          {quant_ppl - baseline_ppl:+.2f}")
```
