# DistilBERT on CPU with ONNX Runtime – FP32, Pruned, and Quantized Inference

This notebook benchmarks the inference performance of a fine-tuned DistilBERT model on the SST-2 sentiment classification task using ONNX Runtime on CPU.

We evaluate:

- Full-precision (FP32) inference using ONNX
- Structured-pruned DistilBERT models exported to ONNX (30%, 40%, 50% sparsity)
- Structured-pruned + 8-bit quantized models using ONNX Runtime’s dynamic quantization tools

All experiments are run in a CPU-only Google Colab environment.  
We report:

- Accuracy (on the full SST-2 validation set)
- Total inference time (seconds)
- Increase in process RAM (resident set size) over the evaluation run (MB)

In [None]:
# Install dependencies
!pip install -q transformers datasets onnx onnxruntime psutil

# Imports
import torch
import numpy as np
import psutil
import time
import os
import onnx
import onnxruntime as ort
import platform
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# System info
print("Torch version:", torch.__version__)
print("ONNX Runtime version:", ort.__version__)
print("CPU:", platform.processor())
print("Total system RAM (GB):", round(psutil.virtual_memory().total / (1024**3), 2))
print("CUDA available:", torch.cuda.is_available())

## 2. Load DistilBERT and SST-2 Validation Set

We load the fine-tuned DistilBERT model for SST-2 from Hugging Face, along with the full validation split from the GLUE benchmark.  
The dataset is tokenized to a fixed length of 128 tokens (shorter sentences are padded, longer ones truncated) and wrapped in a custom PyTorch-compatible dataset for ONNX evaluation.
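
To make the fixed-length tokenization concrete, the sketch below pads or truncates a list of token IDs and builds the matching attention mask. It is a plain-Python illustration of what `padding="max_length"` and `truncation=True` produce, not the tokenizer's actual implementation. The helper name `pad_to_max_length`, the example token IDs, and the pad ID of 0 are assumptions made for illustration.

```python
def pad_to_max_length(token_ids, max_length=128, pad_id=0):
    """Pad or truncate token IDs to max_length and build the attention mask.

    Hypothetical helper for illustration only. Unlike the real tokenizer, it
    truncates by simple slicing and does not re-append special tokens.
    """
    kept = list(token_ids[:max_length])
    n_pad = max_length - len(kept)
    input_ids = kept + [pad_id] * n_pad
    # 1 marks a real token, 0 marks padding the model should ignore
    attention_mask = [1] * len(kept) + [0] * n_pad
    return input_ids, attention_mask


# Made-up token IDs standing in for a short tokenized sentence.
ids, mask = pad_to_max_length([101, 2023, 3185, 2003, 2307, 102], max_length=8)
print(ids)   # [101, 2023, 3185, 2003, 2307, 102, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 1, 0, 0]
```

Because every sample is padded to 128 tokens, each forward pass processes the same sequence length regardless of how long the sentence really is. This keeps per-sample timings comparable, but the measured time includes compute spent on padding positions.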

In [None]:
# Load tokenizer and fine-tuned model
model_id = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)

# Load full SST-2 validation set
dataset = load_dataset("glue", "sst2", split="validation")

# Tokenize
def tokenize_function(example):
    return tokenizer(example["sentence"], padding="max_length", truncation=True, max_length=128)

tokenized_dataset = dataset.map(tokenize_function)

# Wrap as a PyTorch-compatible Dataset
class SST2Dataset(torch.utils.data.Dataset):
    def __init__(self, hf_dataset):
        self.input_ids = torch.tensor(hf_dataset["input_ids"])
        self.attention_mask = torch.tensor(hf_dataset["attention_mask"])
        self.labels = torch.tensor(hf_dataset["label"])

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "label": self.labels[idx]
        }

dataset = SST2Dataset(tokenized_dataset)
print(f"Loaded {len(dataset)} validation samples.")

## 3. Export DistilBERT (FP32) to ONNX Format

We export the full-precision PyTorch model to ONNX format using a representative input from the validation set.  
The exported ONNX model will be used with ONNX Runtime for inference on CPU.
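
A common sanity check after export is to confirm that the ONNX model reproduces the PyTorch logits within a small tolerance. The helper below is a minimal sketch of that comparison using NumPy only. The name `logits_match`, the default tolerance, and the example arrays are assumptions for illustration and are not part of the original notebook.

```python
import numpy as np


def logits_match(reference, candidate, atol=1e-4):
    """Return (ok, max_abs_diff) for two logit arrays of identical shape.

    Hypothetical helper: `reference` would be the PyTorch logits and
    `candidate` the ONNX Runtime logits for the same input.
    """
    reference = np.asarray(reference, dtype=np.float32)
    candidate = np.asarray(candidate, dtype=np.float32)
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    max_abs_diff = float(np.max(np.abs(reference - candidate)))
    return max_abs_diff <= atol, max_abs_diff


# Made-up logits for one sample with two classes.
ok, diff = logits_match([[-2.5, 2.7]], [[-2.50001, 2.70002]])
print(ok, diff)
```

In this notebook, `reference` could come from `model(**inputs_onnx).logits.detach().numpy()` and `candidate` from running an `ort.InferenceSession` on the exported file with the same inputs converted to NumPy. The tolerance would likely need loosening for quantized models.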

In [None]:
# Set model to evaluation mode and move to CPU
model.eval().cpu()

# Prepare dummy input using the first sample
sample = dataset[0]
inputs_onnx = {
    "input_ids": sample["input_ids"].unsqueeze(0),
    "attention_mask": sample["attention_mask"].unsqueeze(0)
}

# Export path
onnx_fp32_path = "distilbert_fp32.onnx"

# Export model to ONNX (batch dimension is dynamic; sequence length stays fixed at 128)
torch.onnx.export(
    model,
    args=(inputs_onnx["input_ids"], inputs_onnx["attention_mask"]),
    f=onnx_fp32_path,
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch_size"},
        "attention_mask": {0: "batch_size"},
        "logits": {0: "batch_size"},
    },
    opset_version=14,  # Note: scaled_dot_product_attention requires opset >= 14
)

print(f"Exported ONNX model to: {onnx_fp32_path}")

## 4. Define ONNX Runtime Evaluation Function

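This function runs inference on the full SST-2 validation set using ONNX Runtime, one sample at a time.  
It computes classification accuracy, total inference time (in seconds), and the increase in the process's resident memory (RSS, in MB) between the start and end of evaluation. This is a before/after difference, not a true peak.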
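
The evaluation function in this section measures memory as the difference in RSS before and after the loop, which can miss a spike that occurs mid-run. As a hedged alternative, the sketch below times a callable after a short warm-up and reads the process-wide peak RSS from the standard library's `resource` module (Unix only, so it should work on Colab's Linux runtime). The helper names `peak_rss_mb` and `time_calls`, the warm-up count, and the stand-in workload are assumptions for illustration.

```python
import resource
import sys
import time


def peak_rss_mb():
    """Peak resident set size of this process in MB (Unix only).

    ru_maxrss is reported in kilobytes on Linux and in bytes on macOS.
    """
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    divisor = 1024 ** 2 if sys.platform == "darwin" else 1024
    return peak / divisor


def time_calls(fn, n_calls, n_warmup=3):
    """Run fn() n_warmup times untimed, then time n_calls invocations.

    Returns (total_seconds, mean_seconds_per_call).
    """
    for _ in range(n_warmup):
        fn()
    start = time.perf_counter()
    for _ in range(n_calls):
        fn()
    total = time.perf_counter() - start
    return total, total / n_calls


# Stand-in workload; in the notebook this would wrap a session.run(...) call.
total_s, mean_s = time_calls(lambda: sum(range(10_000)), n_calls=100)
print(f"total: {total_s:.4f} s, mean per call: {mean_s * 1e3:.3f} ms")
print(f"peak RSS so far: {peak_rss_mb():.1f} MB")
```

Note that the peak reported this way is a high-water mark for the whole process lifetime, so it also includes memory used while loading the PyTorch model and dataset. Comparing variants fairly would require evaluating each one in a fresh process.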

In [None]:
def evaluate_onnx_model(onnx_path, dataset):
    """Evaluate an ONNX model on `dataset` with ONNX Runtime on CPU.

    Returns (accuracy, total inference time in seconds, RSS increase in MB).
    """
    # Initialize ONNX Runtime session on the CPU execution provider
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

    # Name of the first output (the logits)
    output_name = session.get_outputs()[0].name

    # Record resident memory (RSS) and start a monotonic timer
    process = psutil.Process(os.getpid())
    start_ram = process.memory_info().rss
    start_time = time.perf_counter()

    correct = 0
    total = 0

    # Run one sample at a time (batch size 1)
    for sample in dataset:
        inputs = {
            "input_ids": sample["input_ids"].unsqueeze(0).numpy(),
            "attention_mask": sample["attention_mask"].unsqueeze(0).numpy()
        }

        outputs = session.run([output_name], inputs)[0]
        pred = int(np.argmax(outputs))
        label = int(sample["label"])

        correct += int(pred == label)
        total += 1

    end_time = time.perf_counter()
    end_ram = process.memory_info().rss

    accuracy = correct / total
    total_latency = end_time - start_time
    # Difference in RSS between the start and end of the loop, not the peak
    ram_usage_mb = (end_ram - start_ram) / (1024 ** 2)

    return accuracy, total_latency, ram_usage_mb

## 5. Evaluate FP32 ONNX Model on CPU

We now run the full-precision ONNX model on the full SST-2 validation set with ONNX Runtime.  
This serves as our performance baseline for comparison with pruned and quantized models.
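
One way to use the baseline is to express each later variant relative to it. The helper below is a sketch of that comparison. The function name `compare_to_baseline` and the returned keys are illustrative assumptions. It expects the same `(accuracy, latency, ram)` tuple that `evaluate_onnx_model` returns.

```python
def compare_to_baseline(baseline, candidate):
    """Compare (accuracy, total_latency_s, ram_mb) tuples against a baseline.

    Hypothetical helper for illustration. Returns the accuracy change in
    percentage points and the latency speedup factor (baseline / candidate).
    """
    base_acc, base_latency, _ = baseline
    cand_acc, cand_latency, _ = candidate
    if cand_latency <= 0:
        raise ValueError("candidate latency must be positive")
    return {
        "accuracy_delta_pp": (cand_acc - base_acc) * 100.0,
        "speedup": base_latency / cand_latency,
    }
```

Once a pruned or quantized model has been evaluated, calling `compare_to_baseline((accuracy_fp32, latency_fp32, ram_fp32), (acc, latency, ram))` would give its speedup and accuracy change, where `acc`, `latency`, and `ram` stand for that model's results. The RAM value is left out of the comparison here because a before/after RSS difference can be small or even negative, which makes ratios unreliable.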

In [None]:
# Run evaluation
accuracy_fp32, latency_fp32, ram_fp32 = evaluate_onnx_model(onnx_fp32_path, dataset)

# Report results
print(f"Accuracy (FP32 ONNX): {accuracy_fp32:.2%}")
print(f"Total inference time (FP32 ONNX): {latency_fp32:.2f} seconds")
print(f"RAM usage increase (FP32 ONNX): {ram_fp32:.2f} MB")