# LeNet-5 Experiment Reproduction Guide

This guide explains how to reproduce the LeNet-5 (Baseline, Pruned, and Quantized-Pruned) experiments from the "Deep Microcompression" paper.

## Required File Structure

This script assumes it sits inside the original project tree at `project_root/experiments/reproduce_table1/`. The `development` module must be importable from the project root two levels up (`../../`), and MNIST is downloaded to `../../Datasets`.


## What to Expect

The script will run the full experiment, which involves three stages:

1. Baseline Model: Trains the original LeNet-5 model (up to 15 epochs with early stopping) and saves it as lenet5_state_dict.pth. If that file already exists, its weights are loaded and training is skipped.

1. Pruned Model: Loads the baseline weights, applies the optimal structured pruning (conv2d_1: 9, linear_0: 64), and retrains the model (20 epochs).

1. Quantized-Pruned Model: Applies 4-bit static quantization to the pruned model and fine-tunes it with Quantization-Aware Training (QAT) for 15 epochs.

The script prints the accuracy after each stage, the per-layer sizes after pruning and after quantization, and a final summary of accuracy and model size, so the results can be checked against Table 2 of the paper.

### Importing the necessary libraries

In [None]:
import sys
import os

try:
    import torch
    from torch import nn, optim
    from torch.utils import data
    from torchvision import datasets, transforms

except ImportError:
    %pip install torch torchvision tqdm
    import torch
    from torch import nn, optim
    from torch.utils import data
    from torchvision import datasets, transforms

In [2]:
# This assumes the script is in 'project_root/experiments/reproduce_table1'
sys.path.append("../../")

try:
    from development import (
        Sequential,
        Conv2d,
        BatchNorm2d,
        ReLU,
        MaxPool2d,
        Flatten,
        Linear,
        EarlyStopper,
        QuantizationGranularity,
        QuantizationScheme
    )
except ImportError:
    print("Error: Could not import the 'development' module.")
    print("Please ensure this script is run from 'experiments/reproduce_table1/'")
    print("and the 'development' module is in the project root ('../../').")
    raise  # the remaining cells cannot run without 'development'


In [3]:
# Use the GPU if one is available, otherwise fall back to the CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# If this file exists, the baseline weights are loaded from it and Stage 1 training is skipped
BASELINE_MODEL_FILE = "lenet5_state_dict.pth"
INPUT_SHAPE = (1, 28, 28)
DATASET_DIR = "../../Datasets"

Using device: cuda


In [4]:
LUCKY_NUMBER = 25

# Set random seed for reproducibility
torch.manual_seed(LUCKY_NUMBER)
if DEVICE == "cuda":
    torch.cuda.manual_seed(LUCKY_NUMBER)
# Deterministic cuDNN and cuBLAS (use_deterministic_algorithms needs CUBLAS_WORKSPACE_CONFIG on CUDA)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.use_deterministic_algorithms(True)


### Getting the MNIST Dataset

In [5]:
# --- 1. Load Data ---
def get_data_loaders():
    print("Loading MNIST dataset...")
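    # Augmentation: random 24x24 crop, resized back to 28x28 (INPUT_SHAPE[1:]).
    # Note: the same transform is also used for the test set below.
    data_transform = transforms.Compose([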
        transforms.RandomCrop((24, 24)),
        transforms.Resize(INPUT_SHAPE[1:]),
        transforms.ToTensor(),
    ])
    
    mnist_train_dataset = datasets.MNIST(DATASET_DIR, train=True, download=True, transform=data_transform)
    mnist_test_dataset = datasets.MNIST(DATASET_DIR, train=False, download=True, transform=data_transform)
    
    mnist_train_loader = data.DataLoader(mnist_train_dataset, batch_size=32, shuffle=True, num_workers=os.cpu_count(), drop_last=False) # type: ignore
    mnist_test_loader = data.DataLoader(mnist_test_dataset, batch_size=32, shuffle=False, num_workers=os.cpu_count(), drop_last=False) # type: ignore
    
    return mnist_train_loader, mnist_test_loader


In [6]:
# --- Evaluation metric: top-1 accuracy (%) ---
def accuracy_fun(y_pred, y_true):
    return (y_pred.argmax(dim=1) == y_true).to(torch.float).mean().item() * 100
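`accuracy_fun` takes the raw logits, picks the highest-scoring class in each row with `argmax(dim=1)`, and returns the percentage of rows that match the labels. The same computation in plain Python, using a hypothetical `top1_accuracy` helper on nested lists, shows the steps explicitly:

```python
def top1_accuracy(logits, labels):
    """Percentage of rows whose highest-scoring class equals the label.

    Plain-Python mirror of accuracy_fun: `logits` is a list of per-class score
    lists and `labels` a list of integer class indices.
    """
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    correct = 0
    for scores, label in zip(logits, labels):
        # index of the largest score, i.e. argmax over the class dimension
        predicted = max(range(len(scores)), key=lambda k: scores[k])
        correct += int(predicted == label)
    return correct / len(labels) * 100


logits = [[0.1, 2.0, -1.0], [3.0, 0.5, 0.2], [0.0, 0.1, 0.9], [1.0, 1.5, 0.3]]
labels = [1, 0, 2, 0]
print(top1_accuracy(logits, labels))  # 75.0 (the last row predicts class 1, not 0)
```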

### Defining and Training the baseline model

In [7]:
# --- 2. Define Model ---
def get_baseline_model():
    return Sequential(
        Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, pad=(2, 2, 2, 2), bias=True),
        BatchNorm2d(num_features=6),
        ReLU(),
        MaxPool2d(kernel_size=2, stride=2, padding=0),
        Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0, bias=True),
        BatchNorm2d(num_features=16),
        ReLU(),
        MaxPool2d(kernel_size=2, stride=2, padding=0),
        Flatten(),
        Linear(in_features=16*5*5, out_features=84, bias=True),
        ReLU(),
        Linear(in_features=84, out_features=10, bias=True)
    ).to(DEVICE)


def train_baseline(model, train_loader, test_loader):
    print("\n--- STAGE 1: Training Baseline Model ---")
    if os.path.exists(BASELINE_MODEL_FILE):
        print(f"Loading existing baseline weights from {BASELINE_MODEL_FILE}...")
        model.load_state_dict(torch.load(BASELINE_MODEL_FILE, weights_only=True), strict=False)
        model.to(DEVICE)
        return model

    print(f"No baseline weights found. Training from scratch (up to 15 epochs)...")
    early_stopper = EarlyStopper(
        monitor_metric="validation_loss",
        delta=1e-7,
        mode="min",
        patience=4,
        restore_best_state_dict=True,
    )
    criterion_fun = nn.CrossEntropyLoss()
    optimizer_fun = optim.Adam(model.parameters(), lr=1.e-3)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer_fun, mode="min", patience=2)

    model.fit(
        train_loader, 15, 
        criterion_fun, optimizer_fun, lr_scheduler,
        validation_dataloader=test_loader, 
        metrics={"acc": accuracy_fun},
        callbacks=[early_stopper],
        device=DEVICE
    )
    
    print(f"Saving baseline weights to {BASELINE_MODEL_FILE}...")
    torch.save(model.cpu().state_dict(), BASELINE_MODEL_FILE)
    model.to(DEVICE)
    return model
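The `EarlyStopper` configured above watches `validation_loss` in `min` mode, stops after `patience=4` epochs without an improvement larger than `delta`, and restores the best weights. The class below is a hypothetical, simplified stand-in written only to illustrate that logic; the real `development.EarlyStopper` may count epochs or compare values differently.

```python
import copy


class MinimalEarlyStopper:
    """Hypothetical, simplified early stopper (mode="min" only).

    Stops once the monitored value has not improved by more than `delta` for
    `patience` consecutive epochs, and keeps a copy of the best state seen.
    """

    def __init__(self, delta=1e-7, patience=4):
        self.delta = delta
        self.patience = patience
        self.best_value = float("inf")
        self.best_state = None
        self.bad_epochs = 0

    def step(self, value, state):
        """Record one epoch's monitored value; return True when training should stop."""
        if value < self.best_value - self.delta:
            self.best_value = value
            # an independent copy, so later updates do not overwrite the best state
            self.best_state = copy.deepcopy(state)
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = MinimalEarlyStopper(delta=1e-7, patience=4)
validation_losses = [0.30, 0.20, 0.15, 0.16, 0.17, 0.15, 0.18]
for epoch, loss in enumerate(validation_losses):
    if stopper.step(loss, {"epoch": epoch}):
        # with restore_best_state_dict=True the model would be reset to best_state here
        print(f"stopped after epoch {epoch}; best epoch was {stopper.best_state['epoch']}")
        break
```

With these values the loop stops after epoch 6 and reports epoch 2 as the best: the repeated 0.15 at epoch 5 does not beat the earlier 0.15 by more than `delta`, so it counts as a non-improving epoch.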


### Pruning the model

In [8]:
def train_pruned(baseline_model, train_loader, test_loader):
    print("\n--- STAGE 2: Applying Pruning & Retraining ---")
    
    # Pruning parameters from paper (Table 1 / Sec 4.1.1)
    pruning_config = {
        "prune_channel": {
            "sparsity": {
                "conv2d_0": 0,
                "conv2d_1": 9,
                "linear_0": 64
            },
            "metric": "l2"
        }
    }
    
    print(f"Applying pruning config: {pruning_config['prune_channel']['sparsity']}")
    
    # Re-initialize model architecture with pruning
    pruned_model = baseline_model.init_compress(pruning_config, INPUT_SHAPE).to(DEVICE)
    
    # Retrain (fine-tune) the pruned model
    print("Retraining pruned model (20 epochs)...")
    criterion_fun = nn.CrossEntropyLoss()
    optimizer_fun = optim.SGD(pruned_model.parameters(), lr=1.e-3, weight_decay=5e-4, momentum=.9)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer_fun, mode="min", patience=1)

    pruned_model.fit(
        train_loader, 20, 
        criterion_fun, optimizer_fun, lr_scheduler,
        validation_dataloader=test_loader, 
        metrics={"acc": accuracy_fun},
        device=DEVICE
    )
    return pruned_model
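The config selects `"metric": "l2"`. A common reading of L2-norm structured pruning is: score each output channel by the L2 norm of its weights, remove the lowest-scoring channels, and drop the matching inputs of the next layer. The sketch below assumes that reading, and that each sparsity value is the number of channels removed. The helpers `l2_keep_indices` and `slice_next_layer_inputs` are hypothetical and operate on plain lists; they are not the `development` implementation.

```python
import math


def l2_keep_indices(weights, n_prune):
    """Return the sorted indices of output channels kept after L2-norm pruning.

    `weights` holds one flat list of weights per output channel; the `n_prune`
    channels with the smallest L2 norm are removed.
    """
    if not 0 <= n_prune < len(weights):
        raise ValueError("n_prune must leave at least one channel")
    norms = [math.sqrt(sum(w * w for w in row)) for row in weights]
    # channels ordered from largest to smallest norm; keep the first len - n_prune
    ranked = sorted(range(len(weights)), key=lambda c: norms[c], reverse=True)
    return sorted(ranked[: len(weights) - n_prune])


def slice_next_layer_inputs(next_weights, kept):
    """Keep only the input columns of the following layer that still exist."""
    return [[row[c] for c in kept] for row in next_weights]


# Toy layer with 4 output channels (2 weights each); prune the 2 weakest.
layer = [[0.1, 0.1], [1.0, -2.0], [0.0, 0.05], [0.5, 0.5]]
kept = l2_keep_indices(layer, n_prune=2)
print(kept)  # [1, 3]

# The next layer consumed 4 inputs; after pruning it consumes only 2.
next_layer = [[1, 2, 3, 4], [5, 6, 7, 8]]
print(slice_next_layer_inputs(next_layer, kept))  # [[2, 4], [6, 8]]
```

In LeNet-5, `Flatten` turns each of conv2d_1's output channels into 5x5 = 25 consecutive inputs of linear_0, so removing one channel there removes a block of 25 input columns rather than a single one.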


### Quantizing and Bit Packing the model

In [9]:
def train_quantized_pruned(pruned_model, train_loader, test_loader):
    print("\n--- STAGE 3: Applying Quantization (QAT) & Retraining ---")
    
    # Configs from the paper (Table 2, 4-bit static); the Stage 2 pruning config is passed again with the quantization config
    pruning_config = {
        "prune_channel": {
            "sparsity": {
                "conv2d_0": 0,
                "conv2d_1": 9,
                "linear_0": 64
            },
            "metric": "l2"
        }
    }
    quantization_config = {
        "quantize": {
            "scheme": QuantizationScheme.STATIC,
            "granularity": QuantizationGranularity.PER_TENSOR,
            "bitwidth": 4
        }
    }
    full_compression_config = {**pruning_config, **quantization_config}

    print(f"Applying quantization config: 4-bit, STATIC, PER_TENSOR")
    
    # Calibrate the static quantization ranges on one batch from the test loader
    calibration_data = next(iter(test_loader))[0].to(DEVICE)
    
    # Initialize compression for QAT
    quantized_model = pruned_model.init_compress(
        full_compression_config, 
        INPUT_SHAPE, 
        calibration_data
    ).to(DEVICE)

    # Perform Quantization-Aware Training (15 epochs from paper Table 2)
    print("Performing QAT (15 epochs)...")
    criterion_fun = nn.CrossEntropyLoss()
    optimizer_fun = optim.SGD(quantized_model.parameters(), lr=1.e-4, weight_decay=5e-4, momentum=.9)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer_fun, mode="min", patience=1)

    quantized_model.fit(
        train_loader, 15, 
        criterion_fun, optimizer_fun, lr_scheduler,
        validation_dataloader=test_loader, 
        metrics={"acc": accuracy_fun},
        device=DEVICE
    )
    return quantized_model
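`QuantizationScheme.STATIC` with `PER_TENSOR` granularity is usually understood as: quantization ranges are fixed ahead of time from calibration data (here, the single batch passed as `calibration_data`), and each tensor gets one scale and one zero point. The sketch below illustrates that idea for 4 bits with an asymmetric, unsigned mapping onto the codes 0 to 15. The helper names are hypothetical, and whether `development` uses this mapping or, for example, a symmetric signed one is an assumption. In QAT, a quantize-then-dequantize round trip of this kind is typically applied in the forward pass so that training sees the rounding error.

```python
def calibrate(values, bits=4):
    """Choose one (scale, zero_point) pair for a whole tensor (per-tensor).

    Asymmetric, unsigned mapping: integer codes live in [0, 2**bits - 1] and the
    float range observed in the calibration data is stretched to cover them.
    """
    qmin, qmax = 0, 2**bits - 1
    lo = min(min(values), 0.0)  # the range must include 0 so that 0 stays exact
    hi = max(max(values), 0.0)
    scale = (hi - lo) / (qmax - qmin) or 1.0  # guard against an all-zero tensor
    zero_point = min(qmax, max(qmin, round(qmin - lo / scale)))
    return scale, zero_point


def quantize(values, scale, zero_point, bits=4):
    """Map floats to integer codes, clamping anything outside the calibrated range."""
    qmax = 2**bits - 1
    return [min(qmax, max(0, round(v / scale) + zero_point)) for v in values]


def dequantize(codes, scale, zero_point):
    """Map integer codes back to (approximate) floats."""
    return [(q - zero_point) * scale for q in codes]


# Ranges are fixed once from calibration data ("static"), then reused.
calibration = [-1.0, -0.2, 0.0, 0.4, 2.0]
scale, zero_point = calibrate(calibration)  # range [-1, 2] -> scale 0.2, zero point 5
codes = quantize([-1.0, 0.15, 0.4, 3.0], scale, zero_point)
print(codes)  # [0, 6, 7, 15]; 3.0 is outside the calibrated range and clamps to 15
print(dequantize(codes, scale, zero_point))
```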


### Baseline Result

In [10]:
# Get Data
train_loader, test_loader = get_data_loaders()

# --- STAGE 1: BASELINE ---
baseline_model = get_baseline_model()
baseline_model = train_baseline(baseline_model, train_loader, test_loader)

print("Evaluating baseline model...")
baseline_eval = baseline_model.evaluate(test_loader, {"acc": accuracy_fun}, device=DEVICE)
print(f"==> STAGE 1 (Baseline) COMPLETE ==")
print(f"    Accuracy: {baseline_eval['acc']:.2f}%")


Loading MNIST dataset...

--- STAGE 1: Training Baseline Model ---
Loading existing baseline weights from lenet5_state_dict.pth...
Evaluating baseline model...
==> STAGE 1 (Baseline) COMPLETE ==
    Accuracy: 99.35%
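The "Original size" values in the layerwise tables, and the 144.95 KB baseline in the final summary, follow directly from the architecture. Assuming BatchNorm is folded into the preceding convolution (as the `fuse()` call suggests) and every weight and bias is a 4-byte float32, the arithmetic is sketched below; `conv_out` is a hypothetical helper.

```python
def conv_out(size, kernel, stride=1, pad=0):
    """Spatial output size of a convolution or pooling window (square input)."""
    return (size + 2 * pad - kernel) // stride + 1


# Spatial size through the baseline LeNet-5 (input 28x28).
s = conv_out(28, kernel=5, pad=2)    # conv2d_0 -> 28
s = conv_out(s, kernel=2, stride=2)  # pool     -> 14
s = conv_out(s, kernel=5)            # conv2d_1 -> 10
s = conv_out(s, kernel=2, stride=2)  # pool     -> 5
flat = 16 * s * s                    # Flatten  -> 400 = 16*5*5

# Parameters per layer after BatchNorm is folded into the conv (weights + bias).
params = {
    "conv2d_0": 1 * 6 * 5 * 5 + 6,
    "conv2d_1": 6 * 16 * 5 * 5 + 16,
    "linear_0": flat * 84 + 84,
    "linear_1": 84 * 10 + 10,
}
for name, n in params.items():
    print(f"{name}: {n} params, {n * 4 / 1024:.4f} KB as float32")

total_bytes = 4 * sum(params.values())
print(f"total: {total_bytes / 1024:.2f} KB")  # 144.95 KB
```

The per-layer values (0.6094, 9.4375, 131.5781 and 3.3203 KB) match the "Original size" column, and linear_0 alone holds about 91% of the baseline's parameters, which is why pruning it pays off most.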




### Pruned Result


In [11]:

# --- STAGE 2: PRUNED ---
# baseline_model is passed directly (no copy is made); it is reused below as the reference for the original layer sizes
pruned_model = train_pruned(baseline_model, train_loader, test_loader)

print("Evaluating pruned model...")
pruned_eval = pruned_model.evaluate(test_loader, {"acc": accuracy_fun}, device=DEVICE)
print(f"\n==> STAGE 2 (Pruned) COMPLETE ==")
print(f"    Accuracy: {pruned_eval['acc']:.2f}%")

print(f"\n===> Layerwise Pruning Results:")
for i, (name, module) in enumerate(pruned_model.names_layers()):
    if ("conv2d" in name) or ("linear" in name):
        print(f"    Layer name : {name}, Original size {baseline_model[i].get_size_in_bits()/(8*1024)} Reduced size {module.get_size_in_bits()/(8*1024)}:  Size Ratio: {(1 - module.get_size_in_bits()/baseline_model[i].get_size_in_bits())*100:.2f}%") # type: ignore



--- STAGE 2: Applying Pruning & Retraining ---
Applying pruning config: {'conv2d_0': 0, 'conv2d_1': 9, 'linear_0': 64}
Retraining pruned model (15 epochs)...


DMC Training (Epochs 1-20):   5%|▌         | 1/20 [00:15<05:02, 15.93s/it]

epoch    0 | train loss 0.2733 | validation loss 0.1073 | train acc 92.0550 | validation acc 96.8200


DMC Training (Epochs 1-20):  10%|█         | 2/20 [00:31<04:44, 15.82s/it]

epoch    1 | train loss 0.1048 | validation loss 0.0873 | train acc 96.8150 | validation acc 97.4100


DMC Training (Epochs 1-20):  15%|█▌        | 3/20 [00:49<04:46, 16.87s/it]

epoch    2 | train loss 0.0873 | validation loss 0.0689 | train acc 97.3367 | validation acc 97.9000


DMC Training (Epochs 1-20):  20%|██        | 4/20 [01:07<04:32, 17.04s/it]

epoch    3 | train loss 0.0792 | validation loss 0.0678 | train acc 97.6200 | validation acc 97.9200


DMC Training (Epochs 1-20):  25%|██▌       | 5/20 [01:24<04:16, 17.09s/it]

epoch    4 | train loss 0.0721 | validation loss 0.0583 | train acc 97.7833 | validation acc 98.2100


DMC Training (Epochs 1-20):  30%|███       | 6/20 [01:41<03:58, 17.05s/it]

epoch    5 | train loss 0.0677 | validation loss 0.0600 | train acc 97.7983 | validation acc 98.0000


DMC Training (Epochs 1-20):  35%|███▌      | 7/20 [01:58<03:42, 17.08s/it]

epoch    6 | train loss 0.0647 | validation loss 0.0531 | train acc 97.9600 | validation acc 98.3900


DMC Training (Epochs 1-20):  40%|████      | 8/20 [02:15<03:24, 17.07s/it]

epoch    7 | train loss 0.0610 | validation loss 0.0578 | train acc 98.1300 | validation acc 98.0900


DMC Training (Epochs 1-20):  45%|████▌     | 9/20 [02:32<03:09, 17.18s/it]

epoch    8 | train loss 0.0595 | validation loss 0.0477 | train acc 98.1183 | validation acc 98.5200


DMC Training (Epochs 1-20):  50%|█████     | 10/20 [02:49<02:51, 17.16s/it]

epoch    9 | train loss 0.0599 | validation loss 0.0572 | train acc 98.1167 | validation acc 98.3400


DMC Training (Epochs 1-20):  55%|█████▌    | 11/20 [03:06<02:34, 17.12s/it]

epoch   10 | train loss 0.0572 | validation loss 0.0492 | train acc 98.2300 | validation acc 98.5000


DMC Training (Epochs 1-20):  60%|██████    | 12/20 [03:23<02:15, 16.93s/it]

epoch   11 | train loss 0.0531 | validation loss 0.0497 | train acc 98.3617 | validation acc 98.3600


DMC Training (Epochs 1-20):  65%|██████▌   | 13/20 [03:40<01:58, 16.87s/it]

epoch   12 | train loss 0.0531 | validation loss 0.0476 | train acc 98.3900 | validation acc 98.3700


DMC Training (Epochs 1-20):  70%|███████   | 14/20 [03:56<01:40, 16.78s/it]

epoch   13 | train loss 0.0532 | validation loss 0.0461 | train acc 98.3117 | validation acc 98.4900


DMC Training (Epochs 1-20):  75%|███████▌  | 15/20 [04:13<01:23, 16.69s/it]

epoch   14 | train loss 0.0518 | validation loss 0.0484 | train acc 98.3933 | validation acc 98.3900


DMC Training (Epochs 1-20):  80%|████████  | 16/20 [04:29<01:06, 16.68s/it]

epoch   15 | train loss 0.0518 | validation loss 0.0478 | train acc 98.3900 | validation acc 98.5100


DMC Training (Epochs 1-20):  85%|████████▌ | 17/20 [04:46<00:49, 16.63s/it]

epoch   16 | train loss 0.0510 | validation loss 0.0446 | train acc 98.4300 | validation acc 98.6300


DMC Training (Epochs 1-20):  90%|█████████ | 18/20 [05:03<00:33, 16.62s/it]

epoch   17 | train loss 0.0516 | validation loss 0.0502 | train acc 98.4150 | validation acc 98.3700


DMC Training (Epochs 1-20):  95%|█████████▌| 19/20 [05:19<00:16, 16.68s/it]

epoch   18 | train loss 0.0515 | validation loss 0.0482 | train acc 98.3283 | validation acc 98.4500


DMC Training (Epochs 1-20): 100%|██████████| 20/20 [05:36<00:00, 16.83s/it]


epoch   19 | train loss 0.0512 | validation loss 0.0481 | train acc 98.4300 | validation acc 98.5500
Evaluating pruned model...

==> STAGE 2 (Pruned) COMPLETE ==
    Accuracy: 98.55%

===> Layerwise Pruning Results:
    Layer name : conv2d_0, Original size 0.609375 Reduced size 0.609375:  Size Ratio: 0.00%
    Layer name : conv2d_1, Original size 9.4375 Reduced size 4.12890625:  Size Ratio: 56.25%
    Layer name : linear_0, Original size 131.578125 Reduced size 13.75:  Size Ratio: 89.55%
    Layer name : linear_1, Original size 3.3203125 Reduced size 0.8203125:  Size Ratio: 75.29%
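These numbers are consistent with reading each `sparsity` value as the number of output channels or neurons removed: conv2d_1 keeps 16 - 9 = 7 filters and linear_0 keeps 84 - 64 = 20 neurons. Because the next layer loses the matching inputs, linear_1 shrinks by 75.29% even though it has no sparsity entry. The sketch below recomputes the table under that reading, assuming float32 storage and BatchNorm folded into the convolutions; `layer_params` is a hypothetical helper, not part of `development`.

```python
KERNEL = 5 * 5       # every convolution uses a 5x5 kernel
FEATURE_MAP = 5 * 5  # conv2d_1 output after pooling is 5x5 per channel


def layer_params(sparsity, channels=(6, 16), hidden=84, classes=10):
    """Weights + biases per layer when `sparsity` maps a layer to units removed.

    Assumes BatchNorm is folded into the convolutions (as after fuse()).
    """
    c0 = channels[0] - sparsity.get("conv2d_0", 0)
    c1 = channels[1] - sparsity.get("conv2d_1", 0)
    h = hidden - sparsity.get("linear_0", 0)
    return {
        "conv2d_0": 1 * c0 * KERNEL + c0,
        "conv2d_1": c0 * c1 * KERNEL + c1,     # fewer filters
        "linear_0": c1 * FEATURE_MAP * h + h,  # fewer inputs and outputs
        "linear_1": h * classes + classes,     # fewer inputs only
    }


base = layer_params({})
pruned = layer_params({"conv2d_0": 0, "conv2d_1": 9, "linear_0": 64})
for name in base:
    reduction = (1 - pruned[name] / base[name]) * 100
    # float32: 4 bytes per value
    print(f"{name}: {base[name] * 4 / 1024} KB -> {pruned[name] * 4 / 1024} KB ({reduction:.2f}% smaller)")

ratio = sum(pruned.values()) / sum(base.values()) * 100
print(f"pruned model: {ratio:.2f}% of the baseline size")  # 13.32%
```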


### Quantized and Bit-Packed Result

In [12]:
# --- STAGE 3: QUANTIZED-PRUNED ---
quantized_model = train_quantized_pruned(pruned_model, train_loader, test_loader)

print("Evaluating final quantized-pruned model...")
quantized_eval = quantized_model.evaluate(test_loader, {"acc": accuracy_fun}, device=DEVICE)
print(f"\n==> STAGE 3 (Quantized-Pruned) COMPLETE ==")
print(f"    Accuracy: {quantized_eval['acc']:.2f}%")

print(f"\n===> Layerwise Pruned and Quantized Results:")
for i, (name, module) in enumerate(quantized_model.names_layers()):
    if ("conv2d" in name) or ("linear" in name):
        print(f"    Layer name : {name}, Original size {baseline_model[i].get_size_in_bits()/(8*1024):8.4f} Reduced size {module.get_size_in_bits()/(8*1024):8.4f}:  Size Ratio: {(1 - module.get_size_in_bits()/baseline_model[i].get_size_in_bits())*100:.2f}%") # type: ignore



--- STAGE 3: Applying Quantization (QAT) & Retraining ---
Applying quantization config: 4-bit, STATIC, PER_TENSOR
Performing QAT (15 epochs)...


DMC Training (Epochs 1-15):   7%|▋         | 1/15 [00:35<08:16, 35.48s/it]

epoch    0 | train loss 0.1089 | validation loss 0.0923 | train acc 96.3633 | validation acc 97.0000


DMC Training (Epochs 1-15):  13%|█▎        | 2/15 [01:11<07:42, 35.58s/it]

epoch    1 | train loss 0.0980 | validation loss 0.0879 | train acc 96.6467 | validation acc 96.9600


DMC Training (Epochs 1-15):  20%|██        | 3/15 [01:46<07:06, 35.53s/it]

epoch    2 | train loss 0.0962 | validation loss 0.0838 | train acc 96.6883 | validation acc 97.0600


DMC Training (Epochs 1-15):  27%|██▋       | 4/15 [02:22<06:30, 35.53s/it]

epoch    3 | train loss 0.0948 | validation loss 0.0882 | train acc 96.7250 | validation acc 96.7700


DMC Training (Epochs 1-15):  33%|███▎      | 5/15 [02:57<05:55, 35.54s/it]

epoch    4 | train loss 0.0980 | validation loss 0.0872 | train acc 96.6933 | validation acc 96.8200


DMC Training (Epochs 1-15):  40%|████      | 6/15 [03:31<05:15, 35.09s/it]

epoch    5 | train loss 0.0924 | validation loss 0.0838 | train acc 96.8400 | validation acc 96.8300


DMC Training (Epochs 1-15):  47%|████▋     | 7/15 [04:06<04:38, 34.80s/it]

epoch    6 | train loss 0.0922 | validation loss 0.0841 | train acc 96.8100 | validation acc 97.2500


DMC Training (Epochs 1-15):  53%|█████▎    | 8/15 [04:41<04:04, 34.99s/it]

epoch    7 | train loss 0.0890 | validation loss 0.0798 | train acc 96.9867 | validation acc 97.0800


DMC Training (Epochs 1-15):  60%|██████    | 9/15 [05:15<03:28, 34.77s/it]

epoch    8 | train loss 0.0897 | validation loss 0.0813 | train acc 96.9667 | validation acc 97.1000


DMC Training (Epochs 1-15):  67%|██████▋   | 10/15 [05:50<02:53, 34.65s/it]

epoch    9 | train loss 0.0898 | validation loss 0.0796 | train acc 96.9250 | validation acc 97.1600


DMC Training (Epochs 1-15):  73%|███████▎  | 11/15 [06:23<02:17, 34.33s/it]

epoch   10 | train loss 0.0919 | validation loss 0.0737 | train acc 96.8883 | validation acc 97.3600


DMC Training (Epochs 1-15):  80%|████████  | 12/15 [06:57<01:42, 34.17s/it]

epoch   11 | train loss 0.0942 | validation loss 0.0738 | train acc 96.8867 | validation acc 97.4200


DMC Training (Epochs 1-15):  87%|████████▋ | 13/15 [07:31<01:08, 34.10s/it]

epoch   12 | train loss 0.0933 | validation loss 0.0817 | train acc 96.8633 | validation acc 97.1000


DMC Training (Epochs 1-15):  93%|█████████▎| 14/15 [08:05<00:34, 34.19s/it]

epoch   13 | train loss 0.0904 | validation loss 0.0779 | train acc 96.9100 | validation acc 97.0600


DMC Training (Epochs 1-15): 100%|██████████| 15/15 [08:40<00:00, 34.69s/it]


epoch   14 | train loss 0.0909 | validation loss 0.0772 | train acc 96.8950 | validation acc 97.4600
Evaluating final quantized-pruned model...

==> STAGE 3 (Quantized-Pruned) COMPLETE ==
    Accuracy: 97.59%

===> Layerwise Pruned and Quantized Results:
    Layer name : conv2d_0, Original size   0.6094 Reduced size   0.1064:  Size Ratio: 82.53%
    Layer name : conv2d_1, Original size   9.4375 Reduced size   0.5498:  Size Ratio: 94.17%
    Layer name : linear_0, Original size 131.5781 Reduced size   1.7969:  Size Ratio: 98.63%
    Layer name : linear_1, Original size   3.3203 Reduced size   0.1465:  Size Ratio: 95.59%
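Bit packing is what turns 4-bit codes into real storage savings: two codes share one byte instead of each occupying at least a full byte (or 4 bytes as float32). The sketch below uses hypothetical helper names, and its nibble order (first value in the low 4 bits) is an assumption rather than the layout `development` necessarily uses.

```python
def pack_nibbles(codes):
    """Pack unsigned 4-bit codes (0..15) two per byte, low nibble first."""
    if any(not 0 <= c <= 15 for c in codes):
        raise ValueError("every code must fit in 4 bits")
    padded = list(codes) + [0] * (len(codes) % 2)  # pad odd lengths with a zero nibble
    return bytes(padded[i] | (padded[i + 1] << 4) for i in range(0, len(padded), 2))


def unpack_nibbles(packed, count):
    """Inverse of pack_nibbles; `count` drops the padding nibble if there is one."""
    codes = []
    for byte in packed:
        codes.extend((byte & 0x0F, byte >> 4))
    return codes[:count]


codes = [0, 6, 7, 15, 3]
packed = pack_nibbles(codes)
print(len(packed), packed.hex())  # 3 60f703 (5 codes fit in 3 bytes)
assert unpack_nibbles(packed, len(codes)) == codes

# conv2d_0 has 1*6*5*5 = 150 weights: 600 bytes as float32, 75 bytes packed.
print(150 * 4, len(pack_nibbles([0] * 150)))
```

The logged conv2d_0 size (0.1064 KB, about 109 bytes) is larger than the 75 bytes of packed weights, presumably because biases and quantization parameters are stored as well.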


### Final Result

In [15]:
original_size = baseline_model.fuse(device=DEVICE).get_size_in_bytes()
pruned_size = pruned_model.fuse(device=DEVICE).get_size_in_bytes()
quantized_size = quantized_model.fuse(device=DEVICE).get_size_in_bytes()

In [16]:
print("\n--- REPRODUCTION FINISHED ---")
print("\nFinal Results Summary:")
print(f"Baseline:   {baseline_eval['acc']:.2f}% Acc, {original_size/1024:7.2f}KB")
print(f"Pruned:     {pruned_eval['acc']:.2f}% Acc, {pruned_size/1024:7.2f}KB, {pruned_size/original_size*100:.2f}% of original")
print(f"Quantized:  {quantized_eval['acc']:.2f}% Acc, {quantized_size/1024:7.2f}KB, {quantized_size/original_size*100:.2f}% of original")


--- REPRODUCTION FINISHED ---

Final Results Summary:
Baseline:   99.35% Acc,  144.95KB
Pruned:     98.55% Acc,   19.31KB, 13.32% of original
Quantized:  97.59% Acc,    2.60KB, 1.79% of original
