# Needle Quantization Demo

This notebook demonstrates the `int8` quantization support in Needle.
We will cover:
1. Basic quantization and dequantization of tensors.
2. Quantized matrix multiplication.
3. Dynamic and static quantization of a neural network.

In [None]:
import importlib  # available for interactive use; see the reload note below
import sys

import numpy as np

# Make the repository's `python/` directory importable. Guarded so that
# re-running this cell does not keep prepending duplicate sys.path entries.
NEEDLE_PYTHON_DIR = "./python"
if NEEDLE_PYTHON_DIR not in sys.path:
    sys.path.insert(0, NEEDLE_PYTHON_DIR)

# Do NOT importlib.reload() individual needle submodules here. reload()
# re-executes a module and creates new class objects (NDArray, Tensor, ...),
# but names other modules bound via `from ... import ...` (including the
# top-level `needle` namespace) keep pointing at the old classes, so identity
# and isinstance checks can break even on a fresh kernel. To pick up edits to
# the needle sources, restart the kernel (or use IPython's `%autoreload 2`).
import needle
import needle.autograd
import needle.backend_ndarray
import needle.backend_ndarray.ndarray
import needle.backend_selection
import needle as ndl
import needle.nn as nn
from needle import backend_ndarray as nd

print(f"Needle file: {ndl.__file__}")

# Seed NumPy's global RNG so the random data generated below is reproducible.
np.random.seed(42)

In [None]:
# Sanity check: autograd must use the very same NDArray class object as the
# ndarray backend. If they diverge (e.g. after a partial module reload),
# isinstance checks on tensor data silently fail.
from needle.backend_ndarray.ndarray import NDArray as NDArrayOrig
from needle.autograd import NDArray as NDArrayAuto
print(f"NDArrayOrig: {NDArrayOrig}")
print(f"NDArrayAuto: {NDArrayAuto}")
print(f"Same? {NDArrayOrig is NDArrayAuto}")
assert NDArrayOrig is NDArrayAuto, (
    "needle.autograd.NDArray is not needle.backend_ndarray.ndarray.NDArray; "
    "restart the kernel to get a consistent set of classes"
)

# Check Tensor init logic: a float32 Tensor's realized data should be that NDArray.
t = ndl.Tensor([1, 2, 3], dtype="float32")
arr = t.realize_cached_data()
print(f"Array type: {type(arr)}")
print(f"Array dtype: {arr.dtype}")
print(f"Is instance of NDArrayAuto? {isinstance(arr, NDArrayAuto)}")
assert isinstance(arr, NDArrayAuto), f"Tensor data is {type(arr)}, expected {NDArrayAuto}"

# Check quantize with scale=1.0 and zero_point=0.
q = t.quantize_int8(1.0, 0)
q_arr = q.realize_cached_data()
print(f"Quantized Array type: {type(q_arr)}")
# `_quant_params` is a private attribute; inspected here for debugging only.
print(f"Quantized Array params: {q_arr._quant_params}")

## 1. Basic Quantization

We can quantize a `float32` tensor to `int8` using `quantize_int8`.
This stores the data as `int8` and attaches scale and zero_point metadata.

In [None]:
# Helper that derives (scale, zero_point) from an observed value range.
from needle.quantization import compute_scale_zero_point

# A small 5x5 float32 tensor to quantize.
x = ndl.Tensor(np.random.randn(5, 5).astype(np.float32))
x_np = x.numpy()  # fetched once and reused for display, range and error below
print("Original Tensor (first row):", x_np[0])

# The quantization parameters come from the tensor's min/max range.
min_val, max_val = x_np.min(), x_np.max()
scale, zp = compute_scale_zero_point(min_val, max_val)
print(f"Scale: {scale}, Zero Point: {zp}")

# Store as int8; the raw integer payload is read via the cached NDArray.
x_quant = x.quantize_int8(scale, zp)
print("Quantized Tensor (int8 values):", x_quant.realize_cached_data().numpy()[0])
print("Quantized Tensor dtype:", x_quant.dtype)

# Round-trip back to float32 and report the worst-case element error.
x_dequant = x_quant.dequantize()
x_dequant_np = x_dequant.numpy()
print("Dequantized Tensor:", x_dequant_np[0])

error = np.abs(x_np - x_dequant_np).max()
print(f"Max Quantization Error: {error}")

## 2. Quantized Matrix Multiplication

We can perform matrix multiplication directly on `int8` tensors.
The result is a `float32` tensor (dequantized output).

In [None]:
# Imported here as well so this cell does not silently depend on an earlier one.
from needle.quantization import compute_scale_zero_point


def quantize_minmax(t):
    """Quantize a float32 Tensor to int8 using its observed min/max range.

    Args:
        t: needle Tensor holding float32 data.

    Returns:
        (quantized_tensor, scale, zero_point), so callers can inspect the
        quantization parameters that were used.
    """
    t_np = t.numpy()
    scale, zero_point = compute_scale_zero_point(t_np.min(), t_np.max())
    return t.quantize_int8(scale, zero_point), scale, zero_point


# Two random 128x128 float32 matrices.
A = ndl.Tensor(np.random.randn(128, 128).astype(np.float32))
B = ndl.Tensor(np.random.randn(128, 128).astype(np.float32))

# Quantize each operand with its own scale / zero point.
A_quant, scale_a, zp_a = quantize_minmax(A)
B_quant, scale_b, zp_b = quantize_minmax(B)

# Quantized matmul vs. the float32 reference.
C_quant_out = A_quant @ B_quant
C_ref = A @ B

C_ref_np = C_ref.numpy()
C_quant_np = C_quant_out.numpy()
print("Reference Output (first 5):", C_ref_np.flatten()[:5])
print("Quantized Output (first 5):", C_quant_np.flatten()[:5])

error = np.abs(C_ref_np - C_quant_np).mean()
print(f"Mean Absolute Error: {error}")
# MAE alone is hard to judge without the magnitude of the outputs.
rel_error = error / np.abs(C_ref_np).mean()
print(f"Relative Error (MAE / mean |ref|): {rel_error:.4%}")

## 3. Neural Network Quantization

We will train a simple MLP on synthetic data, then apply Dynamic and Static quantization.

In [None]:
# Synthetic regression data: targets come from a random two-layer ReLU
# "teacher" network with the same layer sizes as the MLP trained below.
N = 1000
input_dim = 32
hidden_dim = 128
output_dim = 10

X = np.random.randn(N, input_dim).astype(np.float32)
# Ground-truth teacher weights, scaled by 1/sqrt(fan_in).
# Divide by a float32 scalar: np.sqrt returns a float64 NumPy scalar, and under
# NumPy 2 promotion rules (NEP 50) float32_array / float64_scalar is float64,
# which would silently upcast W1, W2 and therefore Y.
W1 = np.random.randn(input_dim, hidden_dim).astype(np.float32) / np.float32(np.sqrt(input_dim))
W2 = np.random.randn(hidden_dim, output_dim).astype(np.float32) / np.float32(np.sqrt(hidden_dim))
Y = np.maximum(0, X @ W1) @ W2  # teacher uses a ReLU hidden activation

class SimpleDataset(ndl.data.Dataset):
    """In-memory dataset pairing row `idx` of X with row `idx` of Y."""

    def __init__(self, X, Y):
        # Kept as-is; indexing is delegated to the underlying arrays.
        self.X, self.Y = X, Y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # NOTE(review): idx may be a single index or an index array depending on
        # the DataLoader; NumPy indexing handles both for ndarray inputs.
        features = self.X[idx]
        targets = self.Y[idx]
        return features, targets


# Wrap the synthetic arrays and batch them.
dataset = SimpleDataset(X, Y)
dataloader = ndl.data.DataLoader(dataset, batch_size=32)

class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Attribute names matter: later cells quantize linear1/linear2 by name.
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden = self.relu(self.linear1(x))
        return self.linear2(hidden)


model = MLP(input_dim, hidden_dim, output_dim)

# Brief training run; the goal is a reasonable float32 model to quantize,
# not a converged one.
NUM_EPOCHS = 5
LEARNING_RATE = 0.01

opt = ndl.optim.Adam(model.parameters(), lr=LEARNING_RATE)


def l2_loss(y_pred, y_true):
    """Sum of squared errors over all outputs, averaged over the batch dimension.

    needle's nn module (nn_basic) has no MSELoss, so the loss is built from
    elementwise Tensor ops.

    Args:
        y_pred: predictions Tensor whose first dimension is the batch.
        y_true: target Tensor with the same shape as y_pred.

    Returns:
        Scalar Tensor suitable for .backward().
    """
    return ((y_pred - y_true) ** 2).sum() / y_pred.shape[0]


print("Training...")
for epoch in range(NUM_EPOCHS):
    total_loss = 0.0
    num_batches = 0  # counted explicitly instead of relying on len(dataloader)
    for x_batch, y_batch in dataloader:
        opt.reset_grad()
        loss = l2_loss(model(x_batch), y_batch)
        loss.backward()
        opt.step()
        total_loss += loss.numpy().item()  # plain Python float for clean printing
        num_batches += 1
    print(f"Epoch {epoch}, Loss: {total_loss / max(num_batches, 1)}")

print("Training complete.")

### Dynamic Quantization

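The weights are quantized once, ahead of time. The activations are quantized on the fly during inference, with their scale and zero point determined at runtime rather than fixed in advance.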

In [None]:
import time


def evaluate_l2(model, loader):
    """Run `model` over every batch in `loader` and score it with `l2_loss`.

    Args:
        model: callable mapping an input batch to predictions.
        loader: iterable yielding (inputs, targets) batches.

    Returns:
        (mean per-batch L2 loss, elapsed seconds). The timing covers the forward
        pass plus the loss computation and the .numpy() copy, not just matmuls.
    """
    total_loss = 0.0
    num_batches = 0  # counted explicitly instead of relying on len(loader)
    start = time.perf_counter()  # monotonic, high-resolution interval clock
    for x_batch, y_batch in loader:
        total_loss += l2_loss(model(x_batch), y_batch).numpy().item()
        num_batches += 1
    elapsed = time.perf_counter() - start
    return total_loss / max(num_batches, 1), elapsed


# Measure the float32 model FIRST: once quantize_weights() has run, forward()
# uses the stored quantized weights, so no float baseline is available later.
# Consequently this cell is not idempotent - re-running it without re-training
# reports the already-quantized model in the "Float32" lines.
# The first pass may also include one-off warm-up overhead.
float_error, float_time = evaluate_l2(model, dataloader)
print(f"Float32 Inference Time: {float_time:.4f}s")
print(f"Float32 Mean L2 Error: {float_error}")

# Quantize weights
model.linear1.quantize_weights()
model.linear2.quantize_weights()
print("Weights quantized.")

dynamic_error, dynamic_time = evaluate_l2(model, dataloader)
print(f"Dynamic Quantization Inference Time: {dynamic_time:.4f}s")
print(f"Mean L2 Error: {dynamic_error}")

### Static Quantization

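We first calibrate the model by running representative data through it to choose fixed scale/zero-point values for the activations. Inference then uses these static parameters instead of determining them at runtime.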

In [None]:
# Imported here so this cell does not depend on the previous one for `time`.
import time

from needle.quantization import calibrate

# Calibrate activation quantization parameters.
# NOTE(review): calibration and evaluation both use `dataloader` (which is also
# the training data), so the error below is not a held-out estimate.
print("Calibrating...")
calibrate(model, dataloader)
print("Calibration complete.")

# Measure inference time and error
start_time = time.perf_counter()  # monotonic, high-resolution interval clock
total_error = 0.0
num_batches = 0  # counted explicitly instead of relying on len(dataloader)
for x_batch, y_batch in dataloader:
    out = model(x_batch)
    total_error += l2_loss(out, y_batch).numpy().item()
    num_batches += 1
end_time = time.perf_counter()

print(f"Static Quantization Inference Time: {end_time - start_time:.4f}s")
# The reported value is a per-batch mean, not a total.
print(f"Mean L2 Error: {total_error / max(num_batches, 1)}")