# DistilBERT Compression Sweep (CPU) – Pruning + Quantization

This notebook evaluates a fine-tuned DistilBERT model on the SST-2 sentiment classification task under unstructured L1 (magnitude) pruning and 8-bit dynamic quantization.

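The intended comparison is:
- Pruned FP32 inference
- Pruned + Quantized (INT8) inference

The sweep in Section 6 evaluates only the pruned + quantized variant; a pruned FP32 baseline is not run in this notebook.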

Across pruning levels:
- 30%
- 40%
- 50%

Reported metrics:
- Accuracy (%)
- Latency (total wall-clock seconds for one pass over the validation set)
- Process RAM usage (resident set size, MB)

All inference is done on CPU using PyTorch and Hugging Face Transformers.

In [None]:
!pip install transformers datasets torch fsspec==2023.6.0 psutil

import torch.nn.utils.prune as prune
import copy
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
import torch
import time
import psutil
import pandas as pd

## 1. Load Fine-Tuned Model and Tokenizer

In [None]:
model_id = "distilbert-base-uncased-finetuned-sst-2-english"

tokenizer = AutoTokenizer.from_pretrained(model_id)
base_model = AutoModelForSequenceClassification.from_pretrained(model_id)

## 2. Load and Preprocess SST-2 Validation Set

In [4]:
dataset = load_dataset("glue", "sst2")
texts = dataset["validation"]["sentence"][:]
labels = dataset["validation"]["label"]

encodings = tokenizer(texts, truncation=True, padding=True, max_length=128, return_tensors="pt")

## 3. Define Evaluation Dataset and DataLoader

In [5]:
class SST2Dataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            "input_ids": self.encodings["input_ids"][idx],
            "attention_mask": self.encodings["attention_mask"][idx],
            "label": torch.tensor(self.labels[idx]),
        }

val_dataset = SST2Dataset(encodings, labels)
val_loader = DataLoader(val_dataset, batch_size=16)

## 4. Prune and Quantize the Model

This function:
- Deepcopies the original fine-tuned model
- Applies **unstructured L1 pruning** to all `nn.Linear` layers
- Removes pruning masks (making pruning permanent)
- Applies **8-bit dynamic quantization** via PyTorch

Returns the quantized pruned model.

In [6]:
def prune_and_quantize(base_model, pruning_amount):
    model = copy.deepcopy(base_model)

    # Apply pruning
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=pruning_amount)

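    # Remove pruning masks (folds each mask into the weight tensor permanently)
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            try:
                prune.remove(module, 'weight')
            except ValueError:
                # Raised when 'weight' on this module was never pruned
                pass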

    # Apply dynamic quantization
    model.cpu()
    quantized_model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )

    return quantized_model
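
The pruning step above can be checked in isolation before quantization. The following is a minimal sketch, not part of the original notebook: `prune_only`, `linear_sparsity`, and the toy `nn.Sequential` model are hypothetical names introduced for illustration. It applies the same `prune.l1_unstructured` and `prune.remove` calls, returns an FP32 model (which could serve as the pruned FP32 baseline), and reports the fraction of zeroed weights.

```python
import copy

import torch
import torch.nn.utils.prune as prune


def prune_only(base_model, pruning_amount):
    """Return an FP32 copy with L1 unstructured pruning baked into every nn.Linear."""
    model = copy.deepcopy(base_model)
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name="weight", amount=pruning_amount)
            prune.remove(module, "weight")  # fold the mask into the weight tensor
    return model


def linear_sparsity(model):
    """Fraction of exactly-zero weights across all nn.Linear layers."""
    zeros = 0
    total = 0
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            zeros += (module.weight == 0).sum().item()
            total += module.weight.numel()
    return zeros / total


# Toy stand-in for the DistilBERT classifier (assumption, for illustration only).
torch.manual_seed(0)
toy_model = torch.nn.Sequential(
    torch.nn.Linear(64, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)
)

pruned_fp32 = prune_only(toy_model, 0.5)
original_sparsity = linear_sparsity(toy_model)
sparsity = linear_sparsity(pruned_fp32)
print(f"sparsity before: {original_sparsity:.2f}, after: {sparsity:.2f}")
```

Because pruning is applied per layer, each `nn.Linear` ends up with the requested fraction of zeros, but the weights are still stored and multiplied as dense tensors.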

## 5. Define Evaluation Function

In [7]:
def evaluate_model(model, dataloader):
    model.eval()
    correct = 0
    total = 0
    start = time.time()

    with torch.no_grad():
        for batch in dataloader:
            inputs = {
                "input_ids": batch["input_ids"],
                "attention_mask": batch["attention_mask"]
            }
            outputs = model(**inputs)
            predictions = outputs.logits.argmax(dim=-1)
            correct += (predictions == batch["label"]).sum().item()
            total += batch["label"].size(0)

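    end = time.time()
    latency = end - start  # total wall-clock seconds for the whole pass, not per sample
    accuracy = correct / total * 100  # percent
    # RSS of the whole Python process (base model, dataset, and allocator caches included)
    memory_mb = psutil.Process().memory_info().rss / (1024 * 1024)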

    return accuracy, latency, memory_mb
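
The function above times a single cold pass with `time.time()`. A more stable per-sample figure could be obtained with an untimed warm-up and `time.perf_counter()`. The sketch below is an assumption about how that might look, not the notebook's method: `timed_ms_per_sample`, `run_batch`, and `count` are hypothetical names, and the workload is a stand-in rather than a real model.

```python
import time


def timed_ms_per_sample(run_batch, batches, warmup_batches=1, count=len):
    """Average wall-clock milliseconds per sample over all batches.

    run_batch: callable executed once per batch (e.g. a model forward pass).
    count:     callable returning the number of samples in a batch.
    """
    batches = list(batches)

    # Untimed warm-up so one-off costs (lazy init, caches) are excluded.
    for batch in batches[:warmup_batches]:
        run_batch(batch)

    n_samples = 0
    start = time.perf_counter()
    for batch in batches:
        run_batch(batch)
        n_samples += count(batch)
    elapsed = time.perf_counter() - start

    return elapsed / n_samples * 1000.0


# Stand-in workload: five batches of 16 dummy samples each.
fake_batches = [[0] * 16 for _ in range(5)]
ms = timed_ms_per_sample(lambda batch: sum(batch), fake_batches)
print(f"{ms:.6f} ms/sample")
```

For the dictionary batches produced by `val_loader`, `count` would need to be something like `lambda b: b["label"].size(0)`, since `len` on a dict returns the number of keys.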


## 6. Run Compression Sweep

We apply pruning at 30%, 40%, and 50%, followed by quantization and evaluation.

In [None]:
results = []
for pruning_level in [0.3, 0.4, 0.5]:
    print(f"Running test for pruning level: {int(pruning_level * 100)}%")
    quantized_model = prune_and_quantize(base_model, pruning_level)
    acc, time_taken, mem = evaluate_model(quantized_model, val_loader)
    results.append({
        "Pruning": f"{int(pruning_level * 100)}%",
        "Accuracy": acc,
        "Latency (s)": time_taken,
        "Memory (MB)": mem
    })

In [9]:
pd.DataFrame(results)

|   | Pruning | Accuracy | Latency (s) | Memory (MB) |
|---|---------|----------|-------------|-------------|
| 0 | 30% | 90.481651 | 60.211333 | 2330.03125 |
| 1 | 40% | 88.53211 | 56.272866 | 2493.65625 |
| 2 | 50% | 87.155963 | 47.522224 | 2697.054688 |
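
The latency column is total seconds for one pass over the validation set. Converting it to milliseconds per sample is a simple division, sketched below. The sample count of 872 is the size of the SST-2 validation split and is hard-coded here as an assumption; in the notebook it would be `len(val_dataset)`.

```python
# Size of the SST-2 validation split (assumption: the full split was evaluated).
N_SAMPLES = 872

# Total latency in seconds, copied from the results table above.
reported_seconds = {"30%": 60.211333, "40%": 56.272866, "50%": 47.522224}

# seconds per pass -> milliseconds per sample
ms_per_sample = {
    level: seconds / N_SAMPLES * 1000.0
    for level, seconds in reported_seconds.items()
}

for level, ms in ms_per_sample.items():
    print(f"{level}: {ms:.1f} ms/sample")
```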


## 7. Summary

Unstructured L1 pruning combined with 8-bit dynamic quantization results in:

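- A **gradual drop in accuracy** as sparsity increases (90.48% to 87.16%), as expected
- A **drop in measured latency** with higher pruning (60.2 s to 47.5 s). Dense INT8 kernels do not skip zeroed weights, so fewer operations is an unlikely cause; run order, warm-up, and system noise may contribute
- An **increase in reported RAM usage** (2330 MB to 2697 MB). The figure is process-wide RSS sampled after each run, so it also counts the base model and memory retained from earlier iterations, in addition to any overhead from quantization structures or PyTorch internals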

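These results are consistent with pruning offering latency improvements on CPU, but they do not confirm it: each configuration was timed once, and unstructured sparsity is not exploited by dense kernels. Memory usage likewise does not drop unless a sparse-aware backend is used (e.g., ONNX Runtime or DeepSparse).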
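
One way to separate a model's own footprint from the rest of the process is to record the change in RSS around the step that builds it, rather than the absolute value. The sketch below illustrates the idea with `psutil`. `rss_mb` and `rss_delta_mb` are hypothetical helper names, and the bytearray is only a stand-in for building a model.

```python
import gc

import psutil


def rss_mb():
    """Current resident set size of this process, in MB."""
    return psutil.Process().memory_info().rss / (1024 * 1024)


def rss_delta_mb(build):
    """Call build() and return (result, change in process RSS in MB)."""
    gc.collect()  # drop garbage from earlier steps before taking the baseline
    before = rss_mb()
    result = build()
    after = rss_mb()
    return result, after - before


# Stand-in for constructing a model; the delta depends on the OS and allocator.
buffer, delta = rss_delta_mb(lambda: bytearray(8 * 1024 * 1024))
print(f"RSS change: {delta:.1f} MB")
```

RSS deltas are still approximate, because the allocator may reuse or retain pages, so repeated measurements in a fresh process per configuration would give cleaner numbers.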
