# LeNet-5 Experiment Reproduction Guide

This guide explains how to reproduce the LeNet-5 (Baseline, Pruned, and Quantized-Pruned) experiments from the "Deep Microcompression" paper.

## Required File Structure

This notebook assumes it lives inside the original project's directory tree, in a sub-directory of the `experiments` directory, so that the `development` module can be imported from two directory levels up (`../../`).


## What to Expect

The script will run the full experiment, which involves three stages:

1. Baseline Model: Trains the original LeNet-5 model (up to 15 epochs with early stopping) and saves its weights as lenet5_state_dict.pth. If that file already exists, the weights are loaded and training is skipped.

1. Pruned Model: Starts from the baseline model, applies the optimal structured pruning configuration used in the code (conv2d_1: 9, linear_0: 64), and retrains the model (20 epochs).

1. Quantized-Pruned Model: Applies 4-bit static quantization to the pruned model and performs Quantization-Aware Training (QAT) (15 epochs).

The script prints the accuracy after each of these three stages, the layer-wise size reduction after pruning and after quantization, and a final summary with the size of each model, so the results can be checked against Table 2 of the paper.

### Importing the necessary libraries

In [None]:
import sys
import os

try:
    import torch
    from torch import nn, optim
    from torch.utils import data
    from torchvision import datasets, transforms

except ImportError:
    %pip install torch torchvision tqdm
    import torch
    from torch import nn, optim
    from torch.utils import data
    from torchvision import datasets, transforms

In [None]:
# Make the project root importable. The relative path is resolved against the
# current working directory, so this assumes the notebook is run from
# 'project_root/experiments/reproduce_table1'.
sys.path.append("../../")

try:
    from development import (
        Sequential,
        Conv2d,
        BatchNorm2d,
        ReLU,
        MaxPool2d,
        Flatten,
        Linear,
        EarlyStopper,
        QuantizationGranularity,
        QuantizationScheme
    )
except ImportError:
    print("Error: Could not import the 'development' module.")
    print("Please ensure this script is run from 'experiments/reproduce_table1/'")
    print("and the 'development' module is in the project root ('../../').")
    # Every later cell needs these names. Re-raise so the run stops here with
    # the real import error instead of failing later with a NameError.
    raise


In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# Checkpoint file for the baseline weights. When it already exists, the
# baseline stage loads it instead of training from scratch.
BASELINE_MODEL_FILE = "lenet5_state_dict.pth"
# Shape of a single input sample; the last two entries are the spatial size
# the images are resized to.
INPUT_SHAPE = (1, 28, 28)
# Directory handed to torchvision as the MNIST root (downloaded there if missing).
DATASET_DIR = "../../Datasets"

In [None]:
# Seed value shared by every random number generator seeded below.
LUCKY_NUMBER = 25

# Set random seed for reproducibility
torch.manual_seed(LUCKY_NUMBER)
if DEVICE == "cuda":
    torch.cuda.manual_seed(LUCKY_NUMBER)
# cuDNN determinism: request deterministic kernels and turn off benchmark
# mode (the cuDNN autotuner).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# cuBLAS workspace setting that PyTorch's reproducibility notes ask for when
# deterministic algorithms are enabled on CUDA.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# Ask PyTorch to use deterministic algorithms and to raise an error for
# operations that only have a nondeterministic implementation.
torch.use_deterministic_algorithms(True)


### Getting the MNIST Dataset

In [None]:
# --- 1. Load Data ---
def get_data_loaders():
    """Build the MNIST training and test data loaders.

    The dataset is read from ``DATASET_DIR`` and downloaded there when it is
    missing. Both splits share one transform: a random 24x24 crop, a resize
    to the spatial size given by ``INPUT_SHAPE[1:]`` and conversion to a
    tensor.

    Returns:
        tuple: ``(train_loader, test_loader)``. Both use a batch size of 32
        and keep the last partial batch; only the training loader shuffles.
    """
    print("Loading MNIST dataset...")
    # NOTE(review): the random crop is also applied to the test split, so
    # evaluation images are augmented as well -- confirm this is intended.
    data_transform = transforms.Compose([
        transforms.RandomCrop((24, 24)),
        transforms.Resize(INPUT_SHAPE[1:]),
        transforms.ToTensor(),
    ])

    mnist_train_dataset = datasets.MNIST(DATASET_DIR, train=True, download=True, transform=data_transform)
    mnist_test_dataset = datasets.MNIST(DATASET_DIR, train=False, download=True, transform=data_transform)

    # os.cpu_count() returns None when the CPU count cannot be determined,
    # and DataLoader does not accept None for num_workers. Fall back to 0,
    # which loads the data in the main process.
    num_workers = os.cpu_count() or 0

    mnist_train_loader = data.DataLoader(mnist_train_dataset, batch_size=32, shuffle=True, num_workers=num_workers, drop_last=False)
    mnist_test_loader = data.DataLoader(mnist_test_dataset, batch_size=32, shuffle=False, num_workers=num_workers, drop_last=False)

    return mnist_train_loader, mnist_test_loader


In [None]:
# --- 3. Training & Evaluation Functions ---
def accuracy_fun(y_pred, y_true):
    """Return the percentage of samples whose arg-max prediction equals the label.

    Args:
        y_pred: Per-class scores; the predicted class is the arg-max over dim 1.
        y_true: Target class indices compared element-wise with the predictions.

    Returns:
        float: Accuracy in percent (0 to 100).
    """
    predicted_labels = y_pred.argmax(dim=1)
    hits = (predicted_labels == y_true).to(torch.float)
    return hits.mean().item() * 100

### Defining and Training the baseline model

In [None]:
# --- 2. Define Model ---
def get_baseline_model():
    """Assemble the baseline LeNet-5 network and move it to ``DEVICE``.

    The network is two convolution stages (convolution, batch norm, ReLU,
    2x2 max-pool) followed by a flatten step and two linear layers, the last
    of which has 10 outputs.

    Returns:
        The ``Sequential`` model after ``.to(DEVICE)``.
    """
    # Layers are created in the same order in which they are stacked.
    feature_layers = [
        Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, pad=(2, 2, 2, 2), bias=True),
        BatchNorm2d(num_features=6),
        ReLU(),
        MaxPool2d(kernel_size=2, stride=2, padding=0),
        Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0, bias=True),
        BatchNorm2d(num_features=16),
        ReLU(),
        MaxPool2d(kernel_size=2, stride=2, padding=0),
    ]
    classifier_layers = [
        Flatten(),
        Linear(in_features=16 * 5 * 5, out_features=84, bias=True),
        ReLU(),
        Linear(in_features=84, out_features=10, bias=True),
    ]
    network = Sequential(*feature_layers, *classifier_layers)
    return network.to(DEVICE)


def train_baseline(model, train_loader, test_loader):
    """Load or train the baseline model (stage 1).

    When ``BASELINE_MODEL_FILE`` exists, its weights are loaded into
    ``model`` and no training takes place. Otherwise the model is fitted for
    up to 15 epochs with Adam, a ``ReduceLROnPlateau`` scheduler and an
    early stopper that monitors the validation loss, and the resulting
    weights are saved to ``BASELINE_MODEL_FILE``.

    Args:
        model: Model to load weights into or to train.
        train_loader: Loader for the training data.
        test_loader: Loader passed to ``fit`` as the validation data.

    Returns:
        The same ``model`` object, moved to ``DEVICE``.
    """
    print("\n--- STAGE 1: Training Baseline Model ---")
    if os.path.exists(BASELINE_MODEL_FILE):
        print(f"Loading existing baseline weights from {BASELINE_MODEL_FILE}...")
        # map_location keeps a checkpoint that was saved from CUDA tensors
        # loadable on a machine without a GPU.
        state_dict = torch.load(BASELINE_MODEL_FILE, map_location=DEVICE, weights_only=True)
        # NOTE(review): strict=False accepts a checkpoint whose keys do not
        # fully match the model -- confirm this leniency is intended.
        model.load_state_dict(state_dict, strict=False)
        model.to(DEVICE)
        return model

    # One value feeds both the message and fit() so they cannot drift apart.
    max_epochs = 15
    print(f"No baseline weights found. Training from scratch (up to {max_epochs} epochs)...")
    early_stopper = EarlyStopper(
        monitor_metric="validation_loss",
        delta=1e-7,
        mode="min",
        patience=4,
        restore_best_state_dict=True,
    )
    criterion_fun = nn.CrossEntropyLoss()
    optimizer_fun = optim.Adam(model.parameters(), lr=1.e-3)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer_fun, mode="min", patience=2)

    model.fit(
        train_loader, max_epochs,
        criterion_fun, optimizer_fun, lr_scheduler,
        validation_dataloader=test_loader,
        metrics={"acc": accuracy_fun},
        callbacks=[early_stopper],
        device=DEVICE
    )

    print(f"Saving baseline weights to {BASELINE_MODEL_FILE}...")
    # Save from the CPU so the checkpoint holds CPU tensors, then move the
    # model back to the working device.
    torch.save(model.cpu().state_dict(), BASELINE_MODEL_FILE)
    model.to(DEVICE)
    return model


### Pruning the model

In [None]:
def train_pruned(baseline_model, train_loader, test_loader):
    """Apply channel pruning to the baseline model and retrain it (stage 2).

    The pruning configuration (``conv2d_0: 0``, ``conv2d_1: 9``,
    ``linear_0: 64``, metric ``"l2"``) is passed to
    ``baseline_model.init_compress``. The model it returns is then fitted
    for 20 epochs with SGD and a ``ReduceLROnPlateau`` scheduler.

    Args:
        baseline_model: Trained baseline model to compress.
        train_loader: Loader for the training data.
        test_loader: Loader passed to ``fit`` as the validation data.

    Returns:
        The pruned and retrained model, on ``DEVICE``.
    """
    print("\n--- STAGE 2: Applying Pruning & Retraining ---")

    # One value feeds both the progress message and fit(). The message
    # previously announced 15 epochs while 20 were actually run.
    num_epochs = 20

    # Pruning parameters from paper (Table 1 / Sec 4.1.1)
    pruning_config = {
        "prune_channel": {
            "sparsity": {
                "conv2d_0": 0,
                "conv2d_1": 9,
                "linear_0": 64
            },
            "metric": "l2"
        }
    }

    print(f"Applying pruning config: {pruning_config['prune_channel']['sparsity']}")

    # Re-initialize model architecture with pruning
    pruned_model = baseline_model.init_compress(pruning_config, INPUT_SHAPE).to(DEVICE)

    # Retrain (fine-tune) the pruned model
    print(f"Retraining pruned model ({num_epochs} epochs)...")
    criterion_fun = nn.CrossEntropyLoss()
    optimizer_fun = optim.SGD(pruned_model.parameters(), lr=1.e-3, weight_decay=5e-4, momentum=.9)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer_fun, mode="min", patience=1)

    pruned_model.fit(
        train_loader, num_epochs,
        criterion_fun, optimizer_fun, lr_scheduler,
        validation_dataloader=test_loader,
        metrics={"acc": accuracy_fun},
        device=DEVICE
    )
    return pruned_model


### Quantizing and Bit Packing the model

In [None]:
def train_quantized_pruned(pruned_model, train_loader, test_loader):
    """Add 4-bit static quantization to the pruned model and run QAT (stage 3).

    A combined pruning + quantization configuration is handed to
    ``pruned_model.init_compress`` together with the inputs of the first
    test batch as calibration data. The returned model is then fitted for 15
    epochs with SGD and a ``ReduceLROnPlateau`` scheduler.

    Args:
        pruned_model: Model produced by the pruning stage.
        train_loader: Loader for the training data.
        test_loader: Loader used for calibration and as validation data.

    Returns:
        The quantized-pruned model after training, on ``DEVICE``.
    """
    print("\n--- STAGE 3: Applying Quantization (QAT) & Retraining ---")

    # Pruning section first, quantization section second (paper Table 2,
    # 4-bit static).
    compression_config = {
        "prune_channel": {
            "sparsity": {"conv2d_0": 0, "conv2d_1": 9, "linear_0": 64},
            "metric": "l2",
        },
        "quantize": {
            "scheme": QuantizationScheme.STATIC,
            "granularity": QuantizationGranularity.PER_TENSOR,
            "bitwidth": 4,
        },
    }

    print("Applying quantization config: 4-bit, STATIC, PER_TENSOR")

    # The inputs of the first test batch serve as calibration data.
    first_batch = next(iter(test_loader))
    calibration_inputs = first_batch[0].to(DEVICE)

    # Initialize compression for QAT.
    qat_model = pruned_model.init_compress(compression_config, INPUT_SHAPE, calibration_inputs)
    qat_model = qat_model.to(DEVICE)

    print("Performing QAT (15 epochs)...")
    loss_fn = nn.CrossEntropyLoss()
    sgd = optim.SGD(qat_model.parameters(), lr=1.e-4, weight_decay=5e-4, momentum=.9)
    plateau_scheduler = optim.lr_scheduler.ReduceLROnPlateau(sgd, mode="min", patience=1)

    qat_model.fit(
        train_loader,
        15,
        loss_fn,
        sgd,
        plateau_scheduler,
        validation_dataloader=test_loader,
        metrics={"acc": accuracy_fun},
        device=DEVICE,
    )
    return qat_model


### Baseline Result

In [None]:
# Build the MNIST loaders once; the cells of every stage reuse them.
train_loader, test_loader = get_data_loaders()

# --- STAGE 1: BASELINE ---
# Create the network, then load cached weights or train it from scratch.
baseline_model = train_baseline(get_baseline_model(), train_loader, test_loader)

print("Evaluating baseline model...")
baseline_eval = baseline_model.evaluate(
    test_loader,
    {"acc": accuracy_fun},
    device=DEVICE,
)
print("==> STAGE 1 (Baseline) COMPLETE ==")
print(f"    Accuracy: {baseline_eval['acc']:.2f}%")


### Pruned Result


In [None]:

# --- STAGE 2: PRUNED ---
# NOTE(review): baseline_model itself is handed to train_pruned; no copy is
# made here -- confirm init_compress leaves the baseline model untouched.
pruned_model = train_pruned(baseline_model, train_loader, test_loader)

print("Evaluating pruned model...")
pruned_eval = pruned_model.evaluate(
    test_loader,
    {"acc": accuracy_fun},
    device=DEVICE,
)
print("\n==> STAGE 2 (Pruned) COMPLETE ==")
print(f"    Accuracy: {pruned_eval['acc']:.2f}%")

print("\n===> Layerwise Prunning Results:")
for i, (name, module) in enumerate(pruned_model.names_layers()):
    # Report only the conv2d and linear layers.
    if "conv2d" not in name and "linear" not in name:
        continue
    # The baseline layer is looked up by the same position as the pruned layer.
    baseline_bits = baseline_model[i].get_size_in_bits()
    pruned_bits = module.get_size_in_bits()
    print(f"    Layer name : {name}, Original size {baseline_bits/(8*1024)} Reduced size {pruned_bits/(8*1024)}:  Size Ratio: {(1 - pruned_bits/baseline_bits)*100:.2f}%")


### Quantized and Bit-Packed Result

In [None]:
# --- STAGE 3: QUANTIZED-PRUNED ---
quantized_model = train_quantized_pruned(pruned_model, train_loader, test_loader)

print("Evaluating final quantized-pruned model...")
quantized_eval = quantized_model.evaluate(
    test_loader,
    {"acc": accuracy_fun},
    device=DEVICE,
)
print("\n==> STAGE 3 (Quantized-Pruned) COMPLETE ==")
print(f"    Accuracy: {quantized_eval['acc']:.2f}%")

print("\n===> Layerwise Pruned and Quantized Results:")
for i, (name, module) in enumerate(quantized_model.names_layers()):
    # Report only the conv2d and linear layers.
    if "conv2d" not in name and "linear" not in name:
        continue
    # The baseline layer is looked up by the same position as the compressed layer.
    baseline_bits = baseline_model[i].get_size_in_bits()
    compressed_bits = module.get_size_in_bits()
    print(f"    Layer name : {name}, Original size {baseline_bits/(8*1024):8.4f} Reduced size {compressed_bits/(8*1024):8.4f}:  Size Ratio: {(1 - compressed_bits/baseline_bits)*100:.2f}%")


### Final Result

In [None]:
# Call fuse() on each model in stage order and record the byte size reported
# by the fused result.
original_size, pruned_size, quantized_size = (
    stage_model.fuse(device=DEVICE).get_size_in_bytes()
    for stage_model in (baseline_model, pruned_model, quantized_model)
)

In [None]:
# Print the accuracy and size (in KB) of every stage; the compressed models
# also show their size as a percentage of the baseline size.
print(
    "\n--- REPRODUCTION FINISHED ---",
    "\nFinal Results Summary:",
    f"Baseline:   {baseline_eval['acc']:.2f}% Acc, {original_size/1024:7.2f}KB",
    f"Pruned:     {pruned_eval['acc']:.2f}% Acc, {pruned_size/1024:7.2f}KB, {pruned_size/original_size*100:.2f}% of original",
    f"Quantized:  {quantized_eval['acc']:.2f}% Acc, {quantized_size/1024:7.2f}KB, {quantized_size/original_size*100:.2f}% of original",
    sep="\n",
)

In [None]:
# Per-layer values used as the "sparsity" entry of the channel-pruning config
# built in the next cell. Only conv2d_0, conv2d_1 and linear_0 are non-zero.
# NOTE(review): the name looks like a typo of "arduino"; it is kept because
# the next cell refers to it by this name.
ardunio_config = {'conv2d_0':4,
    'conv2d_1': 7,
    'linear_0': 55,
    'batchnorm2d_0': 0,
    'relu_0': 0,
    'maxpool2d_0': 0,
    'batchnorm2d_1': 0,
    'relu_1': 0,
    'maxpool2d_1': 0,
    'flatten_0': 0,
    'relu_2': 0,
    'linear_1': 0
}

In [None]:
# Compression settings for the Arduino variant: channel pruning driven by
# ardunio_config plus 4-bit static per-tensor quantization.
pruning_config = {
    "prune_channel": {
        "sparsity": ardunio_config,
        "metric": "l2"
    }
}

quantization_config = {
    "quantize": {
        "scheme": QuantizationScheme.STATIC,
        "granularity": QuantizationGranularity.PER_TENSOR,
        "bitwidth": 4
    }
}
# NOTE(review): full_compression_config is only referenced by the
# commented-out code further down in this cell.
full_compression_config = {**pruning_config, **quantization_config}


# NOTE(review): both calls pass four arguments, but train_pruned and
# train_quantized_pruned as defined in this notebook take three
# (model, train_loader, test_loader). lenet5_model, mnist_train_loader and
# mnist_test_loader are also not defined in the visible cells (the matching
# names above appear to be baseline_model, train_loader and test_loader).
# Confirm and adapt before running this cell.
uno_pruned_model = train_pruned(lenet5_model, ardunio_config, mnist_train_loader, mnist_test_loader)
uno_quantized_model = train_quantized_pruned(uno_pruned_model, ardunio_config, mnist_train_loader, mnist_test_loader)

# # Get one batch of calibration data
# loader_iter = iter(mnist_test_loader)
# calibration_data = torch.concat([next(loader_iter)[0] for _ in range(5)], dim=0).to(DEVICE)

# # Initialize compression for QAT
# uno_model = lenet5_model.init_compress(
#     full_compression_config, 
#     INPUT_SHAPE, 
#     calibration_data
# ).to(DEVICE)

# Fuse the quantized model, switch it to eval mode and export it as C
# sources to two target directories.
# NOTE(review): arduino_uno_src_dir, arduino_uno_include_dir, hp_src_dir,
# hp_include_dir and test_input are not defined in the visible cells --
# they must be defined before this runs.
fused_model = uno_quantized_model.fuse().to(DEVICE).eval()
fused_model.convert_to_c(INPUT_SHAPE, "uno_model", arduino_uno_src_dir, arduino_uno_include_dir, test_input)
fused_model.convert_to_c(INPUT_SHAPE, "uno_model", hp_src_dir, hp_include_dir, test_input)
# As the last expression of the cell, the result is shown as the cell output.
fused_model.get_max_workspace_arena(INPUT_SHAPE)
# fused_model.output_quantize.apply(fused_model(test_input))
# uno_quantized_model.evaluate(mnist_test_loader, {"acc": accuracy_fun}, DEVICE)