# DistilBERT Inference on CPU with ONNX Runtime – FP32 and INT8 Quantization

This notebook benchmarks the inference performance of the **DistilBERT** model fine-tuned on the **SST-2 sentiment classification task**, using **ONNX Runtime** in a **CPU-only** environment (Google Colab).

We evaluate:
- Full-precision (FP32) inference via ONNX  
- 8-bit quantized inference using ONNX Runtime’s dynamic quantization tools [1]

All evaluations are performed on the full SST-2 validation set [2].  
We report:
- **Accuracy**
- **Total inference time** (in seconds)
- **RAM usage increase** during evaluation (in MB)

> ⚠️ Structured pruning is **not supported** in ONNX exports via PyTorch’s pruning API. See [Section 8](#8.-Limitations:-Structured-Pruning-and-ONNX) for details.

**References:**  
[1] ONNX Runtime Quantization Docs: https://onnxruntime.ai/docs/performance/quantization.html  
[2] SST-2 from the GLUE Benchmark: https://huggingface.co/datasets/glue/viewer/sst2  
[3] Sanh et al. (2019). *DistilBERT: A distilled version of BERT.* https://arxiv.org/abs/1910.01108


In [None]:
# Install dependencies
!pip install -q transformers datasets onnx onnxruntime psutil

# Imports
import torch
import numpy as np
import psutil
import time
import os
import onnx
import onnxruntime as ort
import platform
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# System info
print("Torch version:", torch.__version__)
print("ONNX Runtime version:", ort.__version__)
print("CPU:", platform.processor())
print("Total system RAM (GB):", round(psutil.virtual_memory().total / (1024**3), 2))
print("CUDA available:", torch.cuda.is_available())

Torch version: 2.6.0+cu124
ONNX Runtime version: 1.22.1
CPU: x86_64
Total system RAM (GB): 12.67
CUDA available: False


## 1. Load DistilBERT and SST-2 Validation Set

We load the fine-tuned DistilBERT model for SST-2 from Hugging Face, along with the full validation split from the GLUE benchmark.  
The dataset is tokenized to a fixed maximum length of 128 tokens and wrapped into a custom PyTorch-compatible dataset for ONNX evaluation.

In [None]:
# Load tokenizer and fine-tuned model
model_id = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)

# Load full SST-2 validation set
dataset = load_dataset("glue", "sst2", split="validation")

# Tokenize
def tokenize_function(example):
    return tokenizer(example["sentence"], padding="max_length", truncation=True, max_length=128)

tokenized_dataset = dataset.map(tokenize_function)

# Wrap as a PyTorch-compatible Dataset
class SST2Dataset(torch.utils.data.Dataset):
    def __init__(self, hf_dataset):
        self.input_ids = torch.tensor(hf_dataset["input_ids"])
        self.attention_mask = torch.tensor(hf_dataset["attention_mask"])
        self.labels = torch.tensor(hf_dataset["label"])

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "label": self.labels[idx]
        }

dataset = SST2Dataset(tokenized_dataset)
print(f"Loaded {len(dataset)} validation samples.")

## 2. Export DistilBERT (FP32) to ONNX Format

We export the full-precision PyTorch model to ONNX format using a representative input from the validation set.  
The exported ONNX model will be used with ONNX Runtime for inference on CPU.

In [None]:
# Set model to evaluation mode and move to CPU
model.eval().cpu()

# Prepare dummy input using the first sample
sample = dataset[0]
inputs_onnx = {
    "input_ids": sample["input_ids"].unsqueeze(0),
    "attention_mask": sample["attention_mask"].unsqueeze(0)
}

# Export path
onnx_fp32_path = "distilbert_fp32.onnx"

# Export model to ONNX
torch.onnx.export(
    model,
    args=(inputs_onnx["input_ids"], inputs_onnx["attention_mask"]),
    f=onnx_fp32_path,
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
    dynamic_axes={"input_ids": {0: "batch_size"}, "attention_mask": {0: "batch_size"}},
    opset_version=14,  # Note: scaled_dot_product_attention requires opset >= 14
)

print(f"Exported ONNX model to: {onnx_fp32_path}")

Exported ONNX model to: distilbert_fp32.onnx


## 3. Define ONNX Runtime Evaluation Function

This function runs inference on the full SST-2 validation set using ONNX Runtime.  
It computes classification accuracy, total inference time (in seconds), and the change in the process's resident memory (RSS) between the start and end of evaluation.

In [None]:
def evaluate_onnx_model(onnx_path, dataset):
    # Initialize ONNX Runtime session
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

    # Get input names
    input_names = {inp.name: inp.name for inp in session.get_inputs()}
    output_name = session.get_outputs()[0].name

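    # Memory and timing baselines (taken after the session is created,
    # so the loaded model is already counted in start_ram)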
    process = psutil.Process(os.getpid())
    start_ram = process.memory_info().rss
    start_time = time.time()

    correct = 0
    total = 0

    for sample in dataset:
        inputs = {
            input_names["input_ids"]: sample["input_ids"].unsqueeze(0).numpy(),
            input_names["attention_mask"]: sample["attention_mask"].unsqueeze(0).numpy()
        }

        outputs = session.run([output_name], inputs)[0]
        pred = int(np.argmax(outputs))
        label = int(sample["label"])

        correct += (pred == label)
        total += 1

    end_time = time.time()
    end_ram = process.memory_info().rss

    accuracy = correct / total
    total_latency = end_time - start_time
    ram_usage_mb = (end_ram - start_ram) / (1024 ** 2)

    return accuracy, total_latency, ram_usage_mb
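
The function above reports a single total wall-clock time. If a per-sample view is also wanted, a helper along the following lines could wrap any single-sample inference call. This is only a sketch: `time_calls` and its result keys are hypothetical names that do not appear in the notebook, and `fn` stands in for something like a lambda around `session.run(...)`.

```python
import math
import statistics
import time


def time_calls(fn, inputs):
    """Call fn on each input and summarise per-call latency.

    Hypothetical helper, not part of the notebook.
    """
    latencies_ms = []
    for x in inputs:
        start = time.perf_counter()  # monotonic, high-resolution clock
        fn(x)
        latencies_ms.append((time.perf_counter() - start) * 1000.0)

    if not latencies_ms:
        raise ValueError("inputs must contain at least one item")

    ordered = sorted(latencies_ms)
    # Nearest-rank 95th percentile.
    p95_index = max(0, math.ceil(0.95 * len(ordered)) - 1)
    return {
        "count": len(ordered),
        "total_s": sum(ordered) / 1000.0,
        "mean_ms": statistics.fmean(ordered),
        "median_ms": statistics.median(ordered),
        "p95_ms": ordered[p95_index],
    }
```

Reporting the median and a high percentile alongside the total makes it easier to see whether a few slow calls dominate the runtime.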

## 4. Evaluate FP32 ONNX Model on CPU

We now run inference using the full-precision ONNX model on the full SST-2 validation set using ONNX Runtime.  
This serves as our performance baseline for comparison with the quantized model.

In [None]:
# Run evaluation
accuracy_fp32, latency_fp32, ram_fp32 = evaluate_onnx_model("distilbert_fp32.onnx", dataset)

# Report results
print(f"Accuracy (FP32 ONNX): {accuracy_fp32:.2%}")
print(f"Total inference time (FP32 ONNX): {latency_fp32:.2f} seconds")
print(f"RAM usage increase (FP32 ONNX): {ram_fp32:.2f} MB")

Accuracy (FP32 ONNX): 91.06%
Total inference time (FP32 ONNX): 191.68 seconds
RAM usage increase (FP32 ONNX): 0.00 MB


## 5. Quantize the ONNX Model to 8-bit (INT8)

We use ONNX Runtime's post-training dynamic quantization tool to convert the full-precision model to 8-bit integers. Weights are quantized to INT8 ahead of time, and activations are quantized on the fly at inference time, which reduces model size and can lower inference latency on CPU.
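
As a rough illustration of what 8-bit quantization does to a tensor, here is a toy sketch in plain Python. It assumes symmetric scaling with a single scale per tensor, and the function names are hypothetical; it is not ONNX Runtime's actual implementation.

```python
def quantize_symmetric_int8(values):
    """Map floats to int8 codes with one per-tensor scale (toy sketch)."""
    max_abs = max(abs(v) for v in values)
    # Avoid division by zero for an all-zero tensor.
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    codes = [max(-127, min(127, round(v / scale))) for v in values]
    return codes, scale


def dequantize(codes, scale):
    """Recover approximate floats from int8 codes."""
    return [c * scale for c in codes]


weights = [0.42, -1.27, 0.05, 0.90]
codes, scale = quantize_symmetric_int8(weights)
restored = dequantize(codes, scale)

# Rounding error per element is bounded by half a quantization step.
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print("codes:", codes)
print("scale:", scale)
print("max round-trip error:", max_error)
```

Each value is stored as a small integer plus a shared floating-point scale, which is where the size reduction comes from. The bounded rounding error is one reason the accuracy impact tends to be small.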

In [None]:
from onnxruntime.quantization import quantize_dynamic, QuantType

onnx_int8_path = "distilbert_int8.onnx"

# Quantize model (dynamic quantization: INT8 weights, activations quantized at runtime)
quantize_dynamic(
    model_input="distilbert_fp32.onnx",
    model_output=onnx_int8_path,
    weight_type=QuantType.QInt8  # use QInt8 for signed int weights
)

print(f"Quantized ONNX model saved to: {onnx_int8_path}")
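
The size reduction is not measured anywhere in this notebook. One simple way to check it is to compare the two files on disk, as in the sketch below. `file_size_mb` and `report_sizes` are hypothetical helper names introduced for illustration.

```python
import os


def file_size_mb(path):
    """Return the on-disk size of a file in MiB."""
    return os.path.getsize(path) / (1024 ** 2)


def report_sizes(fp32_path, int8_path):
    """Print both model sizes and the FP32-to-INT8 size ratio."""
    fp32_mb = file_size_mb(fp32_path)
    int8_mb = file_size_mb(int8_path)
    print(f"FP32 model: {fp32_mb:.1f} MB")
    print(f"INT8 model: {int8_mb:.1f} MB")
    print(f"Size ratio (FP32 / INT8): {fp32_mb / int8_mb:.2f}x")
    return fp32_mb, int8_mb
```

In this notebook it could be called as `report_sizes("distilbert_fp32.onnx", onnx_int8_path)` once both files exist.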

## 6. Evaluate the INT8 Quantized ONNX Model

We evaluate the quantized model on the same SST-2 validation set using ONNX Runtime.

In [None]:
accuracy_int8, latency_int8, ram_int8 = evaluate_onnx_model(onnx_int8_path, dataset)

# Report
print(f"Accuracy (INT8 ONNX): {accuracy_int8:.2%}")
print(f"Total inference time (INT8 ONNX): {latency_int8:.2f} seconds")
print(f"RAM usage increase (INT8 ONNX): {ram_int8:.2f} MB")

Accuracy (INT8 ONNX): 90.48%
Total inference time (INT8 ONNX): 121.08 seconds
RAM usage increase (INT8 ONNX): 0.00 MB


## 7. Summary: ONNX FP32 vs INT8

| Model     | Accuracy | Inference Time (s) | RAM Usage Increase (MB) |
|-----------|----------|--------------------|--------------------------|
| FP32 ONNX | 91.06%   | 191.68             | 0.00                     |
| INT8 ONNX | 90.48%   | 121.08             | 0.00                     |

The 8-bit quantized model loses 0.58 percentage points of accuracy relative to the FP32 model (90.48% vs. 91.06%) while cutting total inference time by about 36.8% (191.68 s to 121.08 s), a speedup of roughly 1.58x.
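
The derived figures can be reproduced from the table with a few lines of arithmetic. The sample count used for the per-sample latency is an assumption (the SST-2 validation split is taken to have 872 examples); the other inputs are copied from the table.

```python
# Figures copied from the summary table.
fp32_time_s = 191.68
int8_time_s = 121.08
fp32_acc = 0.9106
int8_acc = 0.9048

# Assumption: the SST-2 validation split contains 872 examples.
num_samples = 872

time_reduction = 1.0 - int8_time_s / fp32_time_s
speedup = fp32_time_s / int8_time_s
acc_drop_pp = (fp32_acc - int8_acc) * 100.0

print(f"Time reduction: {time_reduction:.1%}")
print(f"Speedup: {speedup:.2f}x")
print(f"Accuracy drop: {acc_drop_pp:.2f} percentage points")
print(f"Mean latency per sample (FP32): {fp32_time_s / num_samples * 1000:.1f} ms")
print(f"Mean latency per sample (INT8): {int8_time_s / num_samples * 1000:.1f} ms")
```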

**Notes:**
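- RAM usage is measured as the change in process RSS between the start and end of the evaluation loop. The ONNX Runtime session is created before the starting measurement, so the loaded model is already included in the baseline. Together with buffer reuse, this likely explains why the delta is reported as 0.00 MB. The figure does not reflect peak or total memory use.
- This benchmark uses dynamic post-training quantization. Weights are converted to 8-bit integers ahead of time, while the quantization parameters for activations (scale and zero point) are computed on the fly at inference time, so no calibration dataset is needed. Accuracy degradation is small here (0.58 percentage points).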
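
A start/end RSS delta only captures net growth, so transient allocations made during evaluation are invisible to it. One alternative is the peak resident set size reported by the operating system. The sketch below uses the standard-library `resource` module, which is available on Unix only. `peak_rss_mb` is a hypothetical helper name, and the 50 MiB buffer is just a stand-in workload.

```python
import resource
import sys


def peak_rss_mb():
    """Return the process's peak resident set size in MiB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in kilobytes on Linux and in bytes on macOS.
    divisor = 1024 ** 2 if sys.platform == "darwin" else 1024
    return peak / divisor


before = peak_rss_mb()
buffer = bytearray(50 * 1024 * 1024)  # transient allocation of about 50 MiB
del buffer
after = peak_rss_mb()

print(f"Peak RSS before: {before:.1f} MB")
print(f"Peak RSS after:  {after:.1f} MB")
```

Because the peak value never decreases, it should still register the temporary buffer after it has been freed, whereas a start/end RSS delta may not.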

## 8. Limitations: Structured Pruning and ONNX

This notebook focuses on benchmarking ONNX Runtime inference for DistilBERT models in two configurations: full-precision (FP32) and 8-bit dynamically quantized. Structured pruning, although used in our PyTorch CPU experiments, is not included here for the following reasons:

- PyTorch’s `torch.nn.utils.prune` module does not modify layer structures. Instead, it applies binary masks over the original weights using reparameterization during the forward pass.
- When such a pruned model is exported to ONNX, these masks are removed and the exported graph contains the original dense weight tensors — including any zeroed-out values.
- ONNX Runtime executes inference on these dense weights, with no awareness of sparsity. As a result, pruning yields no performance benefit unless the model is manually restructured to remove the pruned dimensions entirely.
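
The difference between masking and restructuring can be shown with a toy matrix-vector product in plain Python. This is only a sketch with made-up weights and a hypothetical `dense_matvec` helper; it does not use PyTorch or ONNX Runtime.

```python
def dense_matvec(weight, x):
    """Multiply a row-major matrix by a vector, counting multiply-adds."""
    macs = 0
    out = []
    for row in weight:
        acc = 0.0
        for w, v in zip(row, x):
            acc += w * v
            macs += 1
        out.append(acc)
    return out, macs


weight = [
    [1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
    [7.0, 8.0, 9.0],
    [1.0, 0.0, 1.0],
]
x = [1.0, 1.0, 1.0]
row_mask = [1, 0, 1, 0]  # structured mask: output rows 1 and 3 are pruned

# Mask-based pruning: same shape, zeros are stored and multiplied like any weight.
masked = [[w * m for w in row] for row, m in zip(weight, row_mask)]
masked_out, masked_macs = dense_matvec(masked, x)

# Manual restructuring: pruned rows are physically removed.
compact = [row for row, m in zip(weight, row_mask) if m]
compact_out, compact_macs = dense_matvec(compact, x)

print("masked: ", masked_out, "multiply-adds:", masked_macs)
print("compact:", compact_out, "multiply-adds:", compact_macs)
```

The masked matrix performs as many multiply-adds as the unpruned one, while the compact matrix does half the work. In a real network, any layer that consumes the pruned outputs would also need its input dimension reduced to match.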

These limitations are consistent with PyTorch’s official pruning tutorial [4], and they explain why structured-pruned models were excluded from this ONNX evaluation. Instead, pruning results are presented in our PyTorch CPU-based benchmarking notebook (`n2_dbert_quant_prun_cpu.ipynb`).

> [4] https://pytorch.org/tutorials/intermediate/pruning_tutorial.html
