# MNIST Quantization Benchmark

This notebook trains a simple MLP on MNIST using the Needle framework (CUDA backend), quantizes its Linear layers to Int8, and compares the Float32 and Int8 models on accuracy, inference speed, and estimated parameter memory.

In [1]:
import itertools
import os
import sys
import time

import numpy as np

# Needle selects its array backend from NEEDLE_BACKEND ("nd" = the NDArray
# backend, which provides CUDA). It is presumably read when needle is imported,
# so keep this before `import needle` below.
os.environ.update(NEEDLE_BACKEND="nd")
# Make the in-repo needle sources under ./python importable.
NEEDLE_SRC_DIR = "./python"
sys.path.append(NEEDLE_SRC_DIR)

import needle as ndl
from needle import nn, optim
from needle.data import DataLoader
from needle.data.datasets.mnist_dataset import MNISTDataset

# Seed NumPy's global RNG so weight initialization and DataLoader shuffling are
# repeatable under Restart & Run All.
# NOTE(review): assumes this Needle build draws its initializers and shuffle
# order from np.random — confirm for the CUDA backend.
SEED = 0
np.random.seed(SEED)

# Check device
device = ndl.cuda()
print(f"Using device: {device}")

Using needle backend
Using device: cuda()


In [2]:
# Load Data
# The four gzipped MNIST IDX files are expected inside DATA_DIR.
DATA_DIR = "./data"
BATCH_SIZE = 128

TRAIN_FILES = (
    f"{DATA_DIR}/train-images-idx3-ubyte.gz",
    f"{DATA_DIR}/train-labels-idx1-ubyte.gz",
)
TEST_FILES = (
    f"{DATA_DIR}/t10k-images-idx3-ubyte.gz",
    f"{DATA_DIR}/t10k-labels-idx1-ubyte.gz",
)

# Check up front so a missing download is reported together with where the
# files are expected to live.
missing_files = [p for p in (*TRAIN_FILES, *TEST_FILES) if not os.path.isfile(p)]
if missing_files:
    raise FileNotFoundError(
        f"MNIST files not found (place them in {DATA_DIR!r}): {missing_files}"
    )

train_dataset = MNISTDataset(*TRAIN_FILES)
test_dataset = MNISTDataset(*TEST_FILES)

# Only the training split is shuffled; evaluation order does not affect accuracy.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Train size: {len(train_dataset)}")
print(f"Test size: {len(test_dataset)}")

Train size: 60000
Test size: 10000


In [3]:
class MLP(nn.Module):
    """Two-layer perceptron classifier: flatten -> Linear -> ReLU -> Linear.

    Args:
        input_dim: Number of features after flattening each sample.
        hidden_dim: Width of the hidden layer.
        output_dim: Number of output logits (classes).
        device: Needle device on which both Linear layers are created.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, device=None):
        super().__init__()
        # Both Linear layers share the same placement and float32 dtype.
        # Attribute creation order (linear1, relu, linear2) is kept as-is.
        layer_opts = {"device": device, "dtype": "float32"}
        self.linear1 = nn.Linear(input_dim, hidden_dim, **layer_opts)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, output_dim, **layer_opts)

    def forward(self, x):
        """Flatten every non-batch dimension of `x` and return the logits."""
        batch_size = x.shape[0]
        hidden = self.relu(self.linear1(x.reshape((batch_size, -1))))
        return self.linear2(hidden)

# Input is a flattened 28x28 MNIST image; output is one logit per digit class.
INPUT_DIM, HIDDEN_DIM, NUM_CLASSES = 28 * 28, 256, 10
model = MLP(INPUT_DIM, HIDDEN_DIM, NUM_CLASSES, device=device)
print("Model initialized.")

Model initialized.


In [None]:
# Training
learning_rate = 0.001
epochs = 3

optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = nn.SoftmaxLoss()

print("Starting training...")
model.train()

for epoch in range(epochs):
    total_loss = 0.0
    count = 0
    for batch_X, batch_y in train_loader:
        # Move data to device
        batch_X = ndl.Tensor(batch_X, device=device)
        batch_y = ndl.Tensor(batch_y, device=device)

        optimizer.reset_grad()
        out = model(batch_X)
        loss = loss_fn(out, batch_y)
        loss.backward()
        optimizer.step()

        # Keep the running sum a plain Python float rather than an ndarray.
        total_loss += float(loss.numpy())
        count += 1

    # An empty loader would otherwise surface as a bare ZeroDivisionError below.
    if count == 0:
        raise RuntimeError("train_loader yielded no batches; check the dataset and batch size.")
    print(f"Epoch {epoch+1}/{epochs}, Avg Loss: {total_loss/count:.4f}")

print("Training complete.")

Starting training...
Epoch 1/3, Avg Loss: 0.3093
Epoch 1/3, Avg Loss: 0.3093
Epoch 2/3, Avg Loss: 0.1339
Epoch 2/3, Avg Loss: 0.1339
Epoch 3/3, Avg Loss: 0.0915
Training complete.


In [None]:
def benchmark_inference(model, loader, device, num_warmup=3):
    """Measure accuracy and end-to-end inference throughput of `model` on `loader`.

    The model is switched to eval mode, `num_warmup` batches are run untimed, and
    then one full pass over `loader` is timed with `time.perf_counter`. The timed
    region covers moving each batch to `device`, the forward pass, and copying the
    logits back to NumPy for the argmax, so it is an end-to-end figure rather than
    pure compute time.

    Args:
        model: Needle module returning per-class logits; predictions are taken
            as the argmax over axis 1, so (batch, num_classes) output is assumed.
        loader: Re-iterable of (batch_X, batch_y) pairs; batch_y must provide
            `.numpy()`. It is iterated once for warm-up and once for timing.
        device: Needle device the inputs are placed on.
        num_warmup: Number of leading batches to run (and discard) before timing.

    Returns:
        dict with "accuracy" (fraction of correct predictions), "duration"
        (seconds spent in the timed pass) and "throughput" (samples / second).

    Raises:
        ValueError: if the timed pass sees no samples (empty loader).
    """
    model.eval()

    # Warm-up: islice stops after exactly num_warmup batches without pulling an
    # extra batch from the loader.
    for batch_X, _ in itertools.islice(loader, num_warmup):
        model(ndl.Tensor(batch_X, device=device))

    correct = 0
    total = 0
    start_time = time.perf_counter()
    for batch_X, batch_y in loader:
        out = model(ndl.Tensor(batch_X, device=device))
        pred = out.numpy().argmax(axis=1)
        # Flatten labels so a (batch, 1) label array cannot broadcast against the
        # (batch,) predictions into a (batch, batch) comparison.
        labels = batch_y.numpy().reshape(-1)
        correct += int((pred == labels).sum())
        total += labels.shape[0]
    duration = time.perf_counter() - start_time

    if total == 0:
        raise ValueError("benchmark_inference: loader yielded no samples")

    return {
        "accuracy": correct / total,
        "duration": duration,
        "throughput": total / duration,
    }

def estimate_memory(model):
    """Estimate the parameter memory of the nn.Linear layers in `model`, in MiB.

    Only nn.Linear modules are counted; any other parameterized module is ignored.
    `model` itself is included, so passing a bare nn.Linear also works.
    Per layer:
      - Float32 weights: 4 bytes per weight element.
      - Quantized weights (``use_int8`` set and ``_weight_q`` present): whatever
        ``_weight_q.memory_bytes()`` reports — presumably Int8 weights plus their
        scale / zero-point metadata; confirm in the quantization code.
      - Bias, when present: counted at 4 bytes per element (float32) either way.

    Returns:
        Total bytes divided by 1024 * 1024 (i.e. MiB).
    """
    def _num_elements(shape):
        # Works for bias/weight of any rank (e.g. (out,) or (1, out)).
        return int(np.prod(shape))

    total_bytes = 0
    seen = set()
    # _children() yields the submodules found recursively below `model`; dedupe
    # by identity so a layer referenced from two places is counted once.
    for m in [model, *model._children()]:
        if id(m) in seen or not isinstance(m, nn.Linear):
            continue
        seen.add(id(m))

        bias = getattr(m, "bias", None)
        bias_mem = _num_elements(bias.shape) * 4 if bias is not None else 0

        # getattr keeps this working on Linear layers that predate quantization.
        weight_q = getattr(m, "_weight_q", None)
        if getattr(m, "use_int8", False) and weight_q is not None:
            weight_mem = weight_q.memory_bytes()
        else:
            weight_mem = _num_elements(m.weight.shape) * 4

        total_bytes += weight_mem + bias_mem

    return total_bytes / (1024 * 1024)  # MB

In [None]:
# Benchmark Float32, quantize `model` in place, then benchmark Int8 on the same object.
def _report(res, mem):
    """Print one model's benchmark results and its estimated parameter memory."""
    print(f"Accuracy: {res['accuracy']:.4f}")
    print(f"Inference Time: {res['duration']:.4f} s")
    print(f"Throughput: {res['throughput']:.2f} samples/s")
    print(f"Estimated Memory: {mem:.4f} MB")


quant_layers = (model.linear1, model.linear2)

# enable_quantization mutates the layers in place, so re-running this cell would
# silently benchmark the Int8 model as the "Float32" baseline. This mirrors the
# check estimate_memory uses to decide a layer is quantized.
if any(
    getattr(layer, "use_int8", False) and getattr(layer, "_weight_q", None) is not None
    for layer in quant_layers
):
    raise RuntimeError(
        "Model is already quantized; re-run the model and training cells to "
        "restore the Float32 baseline before running this cell again."
    )

# Benchmark Float32
print("Benchmarking Float32 Model...")
res_fp32 = benchmark_inference(model, test_loader, device)
mem_fp32 = estimate_memory(model)
_report(res_fp32, mem_fp32)

# Quantize
print("\nQuantizing Model...")
model.eval()
for layer in quant_layers:
    layer.enable_quantization(axis=1)

# Benchmark Int8
print("\nBenchmarking Int8 Model...")
res_int8 = benchmark_inference(model, test_loader, device)
mem_int8 = estimate_memory(model)
_report(res_int8, mem_int8)

# Comparison: Speedup > 1 means Int8 is faster; Accuracy Drop is in percentage points.
print("\n=== Comparison ===")
print(f"Memory Reduction: {100 * (1 - mem_int8/mem_fp32):.2f}%")
print(f"Speedup: {res_fp32['duration']/res_int8['duration']:.2f}x")
print(f"Accuracy Drop: {100 * (res_fp32['accuracy'] - res_int8['accuracy']):.2f}%")