# LeNet-5 Experiment Verification with NAS

This notebook verifies the LeNet-5 experiment using Neural Architecture Search (NAS) to find optimal pruning configurations, followed by quantization.

## Overview
1. **Baseline Model:** Train or load the baseline LeNet-5 model.
2. **NAS (Neural Architecture Search):** 
    - Generate random pruning configurations.
    - Train an `Estimator` (MLP) to predict accuracy based on configuration.
    - Use brute-force search with the estimator to find the best configuration under constraints (e.g., >99% relative accuracy).
3. **Pruning & Quantization:** Apply the optimal configuration found by NAS to the model and verify performance.

**Verification:** Prune, Quantize, and Retrain (QAT) the model using the found configuration to verify the results (Table I).

In [None]:
# Standard library
import itertools
import os
import sys

# Third-party
# (The previous try/except ImportError re-ran these exact imports in the
# handler, so it could never recover from a failure; a plain import block
# surfaces a missing dependency just as well and is easier to read.)
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torch.utils import data
from torchvision import datasets, transforms

In [None]:
# Add project root to path.
# The path is relative to the current working directory, so the notebook is
# expected to be run from its own folder. The membership check keeps re-running
# this cell from appending the same entry again.
PROJECT_ROOT = "../../"
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

try:
    from development import (
        Sequential,
        Conv2d,
        BatchNorm2d,
        ReLU,
        MaxPool2d,
        Flatten,
        Linear,
        EarlyStopper,
        QuantizationGranularity,
        QuantizationScheme,
        Estimator
    )
except ImportError:
    print("Error: Could not import 'development' module. Check path setup.")
    # Fail fast: every later cell depends on these names, so continuing would
    # only replace this clear error with a NameError further down.
    raise


### Configuration and reproducibility

In [None]:
# --- Device selection -------------------------------------------------------
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# --- Experiment constants ---------------------------------------------------
# INPUT_SHAPE[1:] is used as the image resize target and the first conv layer
# takes 1 input channel, i.e. the layout is (channels, height, width).
INPUT_SHAPE = (1, 28, 28)
DATASET_DIR = "../../Datasets"              # dataset root passed to torchvision
LENET5_FILE = "lenet5_state_dict.pth"       # cached baseline weights
NAS_PARAMS_FILE = "nas_parameters.pth"      # cached NAS samples
LUCKY_NUMBER = 25                           # global random seed

# --- Reproducibility --------------------------------------------------------
torch.manual_seed(LUCKY_NUMBER)
if DEVICE == "cuda":
    torch.cuda.manual_seed(LUCKY_NUMBER)
    # Trade cuDNN autotuning for deterministic kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

### Getting the Dataset

In [None]:
def get_data_loaders(batch_size=32, num_workers=None):
    """Build the MNIST train and test data loaders.

    Each image is randomly cropped to 24x24, resized back to
    ``INPUT_SHAPE[1:]`` and converted to a tensor. The dataset is downloaded
    into ``DATASET_DIR`` when it is not already present.

    Args:
        batch_size: Number of samples per batch for both loaders.
        num_workers: Number of loader worker processes. ``None`` (default)
            uses ``os.cpu_count()``, falling back to 0 (load in the main
            process) when the CPU count cannot be determined.

    Returns:
        Tuple ``(train_loader, test_loader)``; only the train loader shuffles.
    """
    if num_workers is None:
        # os.cpu_count() is documented to return None when the count is
        # undetermined, and DataLoader requires an integer.
        num_workers = os.cpu_count() or 0

    print("Loading MNIST dataset...")
    data_transform = transforms.Compose([
        transforms.RandomCrop((24, 24)),
        transforms.Resize(INPUT_SHAPE[1:]),
        transforms.ToTensor(),
    ])

    train_set = datasets.MNIST(DATASET_DIR, train=True, download=True, transform=data_transform)
    # NOTE(review): the test split shares the random-crop transform, so
    # evaluation sees different crops on every pass. Kept as-is to stay
    # comparable with existing results -- confirm this is intended.
    test_set = datasets.MNIST(DATASET_DIR, train=False, download=True, transform=data_transform)

    train_loader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, test_loader

def accuracy_fun(y_pred, y_true):
    """Return the top-1 accuracy of a batch as a percentage.

    Args:
        y_pred: Score tensor; the predicted class is the argmax along dim 1.
        y_true: Tensor of target class indices compared against the argmax.

    Returns:
        Python float: the fraction of matching predictions multiplied by 100.
    """
    predicted_labels = y_pred.argmax(dim=1)
    hit_mask = predicted_labels == y_true
    hit_rate = hit_mask.to(torch.float).mean()
    return hit_rate.item() * 100

# Build the loaders once; the training, NAS-sampling and evaluation cells
# below all reuse these two objects.
mnist_train_loader, mnist_test_loader = get_data_loaders()

### Defining and Training the baseline model

In [None]:
def get_lenet5_model():
    """Build the LeNet-5 style network and move it to ``DEVICE``.

    The network is two conv -> batch-norm -> ReLU -> max-pool stages followed
    by a flatten and two linear layers (400 -> 84 -> 10).

    Returns:
        The result of ``Sequential(...).to(DEVICE)``.
    """
    # NOTE(review): the first conv passes `pad=` while the second passes
    # `padding=`; confirm the project's Conv2d accepts both keywords.
    feature_layers = [
        Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, pad=(2, 2, 2, 2), bias=True),
        BatchNorm2d(num_features=6),
        ReLU(),
        MaxPool2d(kernel_size=2, stride=2, padding=0),
        Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0, bias=True),
        BatchNorm2d(num_features=16),
        ReLU(),
        MaxPool2d(kernel_size=2, stride=2, padding=0),
    ]
    classifier_layers = [
        Flatten(),
        Linear(in_features=16*5*5, out_features=84, bias=True),
        ReLU(),
        Linear(in_features=84, out_features=10, bias=True),
    ]
    network = Sequential(*feature_layers, *classifier_layers)
    return network.to(DEVICE)

lenet5_model = get_lenet5_model()

# Train or Load Baseline: reuse the cached weights when the checkpoint file
# exists, otherwise train from scratch and write the checkpoint.
if not os.path.exists(LENET5_FILE):
    print("Training baseline model...")
    # NOTE(review): early_stopper is constructed here but is not passed to
    # fit() below, so it has no visible effect in this cell -- confirm whether
    # fit() was meant to receive it.
    early_stopper = EarlyStopper(
        monitor_metric="validation_loss",
        delta=1e-7,
        mode="min",
        patience=4,
        restore_best_state_dict=True,
    )
    criterion_fun = nn.CrossEntropyLoss()
    optimizer_fun = optim.Adam(lenet5_model.parameters(), lr=1.e-3)

    # The MNIST test loader doubles as the validation loader here.
    lenet5_model.fit(
        mnist_train_loader,
        20,
        criterion_fun,
        optimizer_fun,
        validation_dataloader=mnist_test_loader,
        metrics={"acc": accuracy_fun},
        device=DEVICE,
    )
    torch.save(lenet5_model.state_dict(), LENET5_FILE)
else:
    print(f"Loading baseline from {LENET5_FILE}...")
    lenet5_model.load_state_dict(torch.load(LENET5_FILE, map_location=DEVICE))

# Report the baseline's accuracy on the test loader.
baseline_metrics = lenet5_model.evaluate(mnist_test_loader, {'acc': accuracy_fun}, device=DEVICE)
print(f"Baseline Accuracy: {baseline_metrics['acc']:.2f}%")

### Generating NAS Training Samples

In [None]:
print("\n--- Generating NAS Data ---")

# Single source of truth for the sample count, so the cache-size warning
# cannot drift from the number of samples actually requested.
NAS_NUM_SAMPLES = 1000

with_training = True

if os.path.exists(NAS_PARAMS_FILE):
    # Reuse the cached samples; sampling is the expensive step of the notebook.
    print(f"Loading NAS parameters from {NAS_PARAMS_FILE}...")
    nas_parameters = torch.load(NAS_PARAMS_FILE)
    if len(nas_parameters) != NAS_NUM_SAMPLES:
        print(f"Warning: Expected {NAS_NUM_SAMPLES} params, got", len(nas_parameters))
else:
    print("Running NAS Sampling (this may take time)...")
    # Using get_nas_prune_channel from Sequential (refactored earlier)
    nas_parameters = lenet5_model.get_nas_prune_channel(
        INPUT_SHAPE, mnist_test_loader, accuracy_fun, DEVICE,
        num_data=NAS_NUM_SAMPLES, train=with_training, train_dataloader=mnist_train_loader, epochs=2,
        criterion_fun=nn.CrossEntropyLoss(), random_seed=LUCKY_NUMBER
    )
    torch.save(nas_parameters, NAS_PARAMS_FILE)

print(f"Generated {len(nas_parameters)} NAS samples.")

### Training the NAS Estimator

In [None]:
print("\n--- Training Accuracy Estimator ---")

# Fit the accuracy estimator on the NAS samples: three hidden layers of
# width 128 with a dropout setting of 0.75, trained for 400 epochs.
estimator = Estimator(
    nas_parameters,
    device=DEVICE,
    hidden_dim=[128] * 3,
    dropout=0.75,
)
# The returned history is plotted in the next cell.
estimator_history = estimator.fit(epochs=400)

In [None]:
# Plot Training History
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].plot(estimator_history["validation_abs"], label="val_abs")
axes[0].plot(estimator_history["train_abs"], label="train_abs")
axes[0].legend(); axes[0].set_title("MAE Loss")

axes[1].plot(estimator_history["validation_mse"], label="val_mse")
axes[1].plot(estimator_history["train_mse"], label="train_mse")
axes[1].legend(); axes[1].set_title("MSE Loss")
plt.show()

### Search for the best set of parameters

In [None]:
def brute_force_search_prune_config(
    model,
    estimator,
    input_shape,
    condition=lambda metric, size, comb: True,
    objective=lambda metric, size, comb: metric,
    maximize=True,
    verbose=True
):
    """Exhaustively search the channel-pruning configurations of ``model``.

    Every combination of the values returned by
    ``model.get_prune_channel_possible_hyperparameters()`` is scored with:

    * ``metric``: ``estimator.predict`` for the combination, divided by the
      estimator's prediction for the all-zero (no sparsity) configuration.
    * ``size``: byte size of the model compressed with that combination,
      divided by the byte size of ``model``.

    Args:
        model: Object providing ``get_size_in_bytes()``,
            ``get_prune_channel_possible_hyperparameters()`` (a dict mapping
            each hyperparameter key to its candidate values) and
            ``init_compress(config, input_shape)``.
        estimator: Object whose ``predict`` takes a ``(1, n)`` float tensor.
            Values are passed in the dict's key order, which must match the
            order used when the NAS samples were generated.
        input_shape: Forwarded unchanged to ``model.init_compress``.
        condition: ``condition(metric, size, comb) -> bool`` hard filter;
            combinations for which it is falsy are skipped.
        objective: ``objective(metric, size, comb)`` value to optimise.
        maximize: Maximise the objective if True, otherwise minimise it.
        verbose: Print a line each time a new best combination is found.

    Returns:
        ``(best_comb, [metric, size])`` for the best combination, or
        ``(None, None)`` when no combination satisfies ``condition``. Ties
        keep the first combination encountered.
    """
    # Fetch the search space once; the key order defines the estimator input.
    search_space = model.get_prune_channel_possible_hyperparameters()
    keys = list(search_space.keys())

    best_value = float("-inf") if maximize else float("inf")
    best_comb = None
    best_result_info = None

    # Baselines used to normalise size and metric.
    original_size = model.get_size_in_bytes()
    original_metric = estimator.predict(torch.zeros(1, len(keys)))

    print("Starting Brute Force Search...")
    for values in itertools.product(*search_space.values()):
        comb = dict(zip(keys, values))

        # Predicted metric, relative to the unpruned prediction.
        metric = estimator.predict(
            torch.tensor(list(values), dtype=torch.float32).unsqueeze(0)
        ) / original_metric

        # Build the compressed model only to measure its relative size.
        compressed = model.init_compress({
            "prune_channel": {"sparsity": comb, "metric": "l2"}
        }, input_shape)
        size = compressed.get_size_in_bytes() / original_size

        # -------- HARD FILTERS --------
        if not condition(metric, size, comb):
            continue

        # -------- OBJECTIVE VALUE --------
        obj = objective(metric, size, comb)

        improved = obj > best_value if maximize else obj < best_value
        if improved:
            best_value = obj
            best_comb = comb
            best_result_info = [metric, size]
            if verbose:
                print(f"✔ New best: obj={obj:.4f}, metric={metric:.2f}, size_ratio={size:.4f}, comb={comb}")

    return best_comb, best_result_info

In [None]:
# The metric handed to `condition` is the estimator's prediction divided by its
# prediction for the unpruned configuration, so the threshold is a ratio
# (relative accuracy), not an absolute 0-100 accuracy.
MIN_RELATIVE_ACCURACY = 0.989

print(f"Searching for smallest config with relative accuracy > {MIN_RELATIVE_ACCURACY:.1%} ...")
best_comb, best_result_info = brute_force_search_prune_config(
    lenet5_model,
    estimator,
    INPUT_SHAPE,
    condition=lambda metric, size, comb: metric > MIN_RELATIVE_ACCURACY,
    objective=lambda metric, size, comb: size,
    maximize=False, # Minimize Size
    verbose=True
)

In [None]:
# The search returns None when no combination passes the condition; report
# that explicitly instead of printing "None" as if it were a configuration.
if best_comb is None:
    print("\nNo configuration satisfied the search condition.")
else:
    print(f"\nOptimal Configuration Found: {best_comb}")
