# 🧪 BitNet 1.58b From Scratch: Ternary LLMs (2025)

[!["Open In Colab"](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/model-size-reduction/blob/main/chronology/bitnet_demo.ipynb)

## 📖 The Theory: The Era of 1-bit LLMs

BitNet 1.58b is a landmark architecture in which every weight of the linear layers is restricted to one of three values: `{-1, 0, 1}`. It is called **1.58-bit** because a three-valued symbol carries $\log_2(3) \approx 1.58$ bits of information.
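
To make the 1.58-bit figure concrete, the sketch below packs five ternary values into a single byte using base-3 encoding, which is possible because $3^5 = 243 \le 256$. That gives 1.6 bits per weight, close to the $\log_2(3)$ bound. The helper names `pack5` and `unpack5` are hypothetical, and this is only one possible layout, not the storage format of any particular BitNet implementation.

```python
import math

def pack5(trits):
    """Pack five values from {-1, 0, 1} into one integer in [0, 242]."""
    assert len(trits) == 5 and all(t in (-1, 0, 1) for t in trits)
    code = 0
    for t in reversed(trits):
        code = code * 3 + (t + 1)  # shift each value to a base-3 digit in {0, 1, 2}
    return code

def unpack5(code):
    """Inverse of pack5: recover the five ternary values."""
    trits = []
    for _ in range(5):
        code, digit = divmod(code, 3)
        trits.append(digit - 1)  # map the digit back to {-1, 0, 1}
    return trits

weights = [1, -1, 0, 0, 1]
byte = pack5(weights)
print(byte, unpack5(byte))
print(f"information per ternary weight: {math.log2(3):.3f} bits")
print(f"achieved by 5-per-byte packing: {8 / 5:.1f} bits")
```

The round trip is lossless for every combination of five ternary values, and the largest code (242) still fits in an unsigned byte.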

### Why Ternary?
- **Addition-only Math**: Multiplying by `1` or `-1` reduces to adding or subtracting, and multiplying by `0` means skipping the term. This removes the multiply from each Multiply-Accumulate (MAC) operation in a matrix multiplication, and the multiply is typically the more energy-hungry half of a MAC.
- **Hardware Efficiency**: Each weight needs fewer than 2 bits instead of 16, so far less data has to move from DRAM, and the compute units can be simpler because they only need adders.
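
The addition-only claim can be checked with a small sketch in plain Python. The functions `ternary_dot` and `ternary_matvec` are hypothetical names introduced for illustration; they compute a matrix-vector product against ternary weights using only additions and subtractions, and the result is compared with an ordinary multiply-based reference.

```python
def ternary_dot(x, w):
    """Dot product of activations x with ternary weights w, without multiplying."""
    acc = 0.0
    for xi, wi in zip(x, w):
        if wi == 1:
            acc += xi      # multiply by +1: add
        elif wi == -1:
            acc -= xi      # multiply by -1: subtract
        # wi == 0: the term contributes nothing, so it is skipped
    return acc

def ternary_matvec(x, W_cols):
    """Apply ternary_dot once per output feature (one weight column each)."""
    return [ternary_dot(x, col) for col in W_cols]

x = [0.5, -2.0, 3.0, 1.5]
W_cols = [[1, 0, -1, 1], [-1, -1, 0, 0]]  # two output features

fast = ternary_matvec(x, W_cols)
reference = [sum(xi * wi for xi, wi in zip(x, col)) for col in W_cols]
print(fast, reference)
```

Both lists agree, which is the sense in which a ternary layer needs no multiplier in its inner loop.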

### The Quantization Logic
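To convert a weight matrix $W$ to ternary, the paper uses absmean normalization: scale by the mean absolute value, then clip and round. Scaling by $\max(|W|)$ instead would round most weights to $0$, because typical weights are much smaller than the largest one.

$$\gamma = \frac{1}{nm}\sum_{i,j} |W_{ij}|$$
$$W_{quant} = \text{round}\left(\text{clip}\left(\frac{W}{\gamma + \epsilon}, -1, 1\right)\right)$$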
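
The choice of $\gamma$ matters more than the formula suggests. The sketch below applies the same round-and-clip rule to synthetic Gaussian weights with two candidate scales, the maximum and the mean of $|W|$, and reports how many weights end up as zero. The function `ternarize`, the sample size, and the Gaussian assumption are all illustrative choices, not taken from the paper.

```python
import random

def ternarize(weights, gamma, eps=1e-7):
    """Apply round(clip(w / (gamma + eps), -1, 1)) to each weight."""
    out = []
    for w in weights:
        scaled = max(-1.0, min(1.0, w / (gamma + eps)))
        out.append(int(round(scaled)))
    return out

random.seed(0)
weights = [random.gauss(0.0, 1.0) for _ in range(10_000)]

gamma_max = max(abs(w) for w in weights)
gamma_mean = sum(abs(w) for w in weights) / len(weights)

q_max = ternarize(weights, gamma_max)
q_mean = ternarize(weights, gamma_mean)

zero_frac_max = q_max.count(0) / len(q_max)
zero_frac_mean = q_mean.count(0) / len(q_mean)
print(f"zeros with max scale:  {zero_frac_max:.1%}")
print(f"zeros with mean scale: {zero_frac_mean:.1%}")
```

With the max scale, only weights larger than half the maximum survive rounding, so the large majority collapse to zero. The mean scale keeps all three values in regular use.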

---

In [None]:
import torch
import torch.nn as nn

def bitnet_quantize(W):
    """
    Manual BitNet 1.58b weight quantization.
    Scales and maps weights to {-1, 0, 1}.
    """
    # 1. Compute gamma as the mean absolute value (absmean), as in the paper.
    # Using max(|W|) here would round almost every weight to 0.
    gamma = torch.mean(torch.abs(W))

    # 2. Scale, clip to [-1, 1], and round to the nearest of {-1, 0, 1}.
    # A small epsilon avoids division by zero for an all-zero matrix.
    W_scaled = W / (gamma + 1e-7)
    W_quant = torch.round(torch.clamp(W_scaled, -1, 1)).to(torch.int8)
    
    return W_quant, gamma

def bitnet_matmul(X, W_quant, gamma):
    """
    Simulated BitNet inference.
    In BitNet, X would be 8-bit quantized activations; this demo passes floats.
    """
    # In hardware, this is an addition-only matrix multiplication.
    # Here, we simulate it with a float32 matmul over the ternary weights.
    result = torch.matmul(X.to(torch.float32), W_quant.to(torch.float32))

    # Rescale by gamma to undo the weight normalization
    return result * gamma

# Demonstration
W = torch.randn(512, 1024)
W_q, gamma = bitnet_quantize(W)

print(f"Original weight sample: {W[0, :3].tolist()}")
print(f"Ternary weight sample:  {W_q[0, :3].tolist()}")
print(f"Unique values in W_q: {torch.unique(W_q).tolist()}")

X = torch.randn(1, 512)
out = bitnet_matmul(X, W_q, gamma)
print(f"\nOutput shape: {tuple(out.shape)}")
print(f"Output Scale (Gamma): {gamma.item():.4f}")
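
The `bitnet_matmul` docstring mentions 8-bit activations, which the demo above does not produce. As a rough sketch of that missing step, the plain-Python example below quantizes one activation vector to signed 8-bit integers with absmax scaling, accumulates against ternary weights using integer adds and subtracts, and rescales the result. The names `quantize_activations` and `int_ternary_dot`, the symmetric range of ±127, and the value of `gamma` are assumptions made for illustration.

```python
def quantize_activations(x, bits=8, eps=1e-7):
    """Absmax-quantize a list of floats to signed integers; returns (ints, scale)."""
    q_max = 2 ** (bits - 1) - 1                 # 127 for 8 bits
    scale = q_max / (max(abs(v) for v in x) + eps)
    x_q = [max(-q_max, min(q_max, round(v * scale))) for v in x]
    return x_q, scale

def int_ternary_dot(x_q, w_q):
    """Integer accumulate: add, subtract, or skip depending on the ternary weight."""
    acc = 0
    for xi, wi in zip(x_q, w_q):
        if wi == 1:
            acc += xi
        elif wi == -1:
            acc -= xi
    return acc

x = [0.5, -2.0, 3.0, 1.5]
w_q = [1, 0, -1, 1]
gamma = 0.8  # hypothetical weight scale, as returned by a weight quantizer

x_q, scale = quantize_activations(x)
# Undo both scales: gamma for the weights, 1/scale for the activations.
y = int_ternary_dot(x_q, w_q) * gamma / scale
reference = sum(xi * wi for xi, wi in zip(x, w_q)) * gamma
print(x_q, y, reference)
```

The integer path and the float reference differ only by the activation rounding error, which shrinks as the number of bits grows.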