In [5]:
#!/usr/bin/env python3
"""
Complete post-training optimization pipeline for Binary Neural Networks (BNNs)
with TALL evaluation and hardware deployment preparation.

This pipeline includes:
1. Baseline evaluation
2. BatchNorm folding
3. Bias constant clamping
4. TALL parameter optimization
5. Hardware export

Usage:
    python post_training_pipeline.py --model-path bnn_deep_best.pth --model-type deep
"""

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import numpy as np
from tqdm import tqdm
from copy import deepcopy

from BNN_model import (
    BinaryMLP, TALLClassifier, 
    build_cam4_deep, build_cam4_shallow,
    apply_post_training_optimization,
    export_hardware_weights,
    create_hardware_checkpoint
)

from evaluate_mnist import evaluate_model

def get_mnist_test_loader(batch_size=1000, data_dir='./data', num_workers=2):
    """Build a DataLoader over the MNIST test split.

    Args:
        batch_size: Number of images per batch.
        data_dir: Root directory for the dataset; requested with
            ``download=True`` so it is fetched there when missing.
        num_workers: Number of DataLoader worker processes. Defaults to 2,
            the value that used to be hard-coded; pass 0 to load in the
            main process (useful when worker processes are unavailable).

    Returns:
        A ``DataLoader`` that iterates the test set in order (no shuffling),
        yielding tensors normalized with the constants below.
    """
    # Per-channel mean / std used for normalization (the commonly quoted
    # MNIST statistics); MNIST images have a single channel.
    mnist_mean = (0.1307,)
    mnist_std = (0.3081,)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mnist_mean, mnist_std)
    ])

    test_dataset = torchvision.datasets.MNIST(
        root=data_dir, train=False, transform=transform, download=True
    )

    # shuffle=False keeps evaluation order deterministic between runs.
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    return test_loader



    


In [6]:
# Setup device: use the GPU only when one is actually available, so the
# notebook still runs end-to-end (on CPU) on machines without CUDA.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f"Using device: {device}")

# Load test data
print("Loading MNIST test set...")
test_loader = get_mnist_test_loader(batch_size=1000, data_dir="./data")

# Which architecture/checkpoint to evaluate: "deep" or "shallow".
model_type = "deep"  # Change to "shallow" for the shallow model
def load_bnn_model(model_type, device, thresholds=None):
    """Build a BNN of the requested type and load its best checkpoint.

    Args:
        model_type: ``"shallow"`` or ``"deep"``; selects both the builder and
            the checkpoint file (``bnn_shallow_best.pth`` / ``bnn_deep_best.pth``).
        device: Device the checkpoint is mapped to and the model is moved to.
        thresholds: Thresholds forwarded to ``build_cam4_deep``. ``None``
            (the default) means ``[0, 0, 0]``.
            NOTE(review): this argument is not forwarded to the shallow
            builder, so it has no effect when ``model_type == "shallow"``.

    Returns:
        Tuple ``(model, checkpoint)``: the model with loaded weights on
        ``device`` and the raw checkpoint dict.

    Raises:
        ValueError: If ``model_type`` is neither ``"shallow"`` nor ``"deep"``.
    """
    # Build the default per call instead of using a mutable default argument,
    # which would be shared (and could be mutated) across calls.
    if thresholds is None:
        thresholds = [0, 0, 0]

    # Create and load model based on model_type
    if model_type == "shallow":
        print("Creating Shallow BNN model...")
        model = build_cam4_shallow(num_classes=10)
        model_path = "bnn_shallow_best.pth"
    elif model_type == "deep":
        print("Creating Deep BNN model...")
        model = build_cam4_deep(num_classes=10, thresholds=thresholds)
        model_path = "bnn_deep_best.pth"
    else:
        raise ValueError(f"Unknown model_type: {model_type}")

    print(f"Loading model from {model_path}...")
    # SECURITY: weights_only=False unpickles arbitrary Python objects; only
    # load checkpoint files that come from a trusted source.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)

    print(f"Model trained for {checkpoint.get('epoch', 'unknown')} epochs")

    # Apply the numeric format only when the stored value supports it; a
    # missing or non-numeric 'best_acc' used to crash the ':.2f' format spec.
    best_acc = checkpoint.get('best_acc')
    try:
        best_acc_text = f"{best_acc:.2f}%"
    except (TypeError, ValueError):
        best_acc_text = "unknown"
    print(f"Best test accuracy(popcount last layer output): {best_acc_text}")
    return model, checkpoint

# Load the model selected by `model_type`; no thresholds are passed here, so
# load_bnn_model's default thresholds apply.
model, checkpoint = load_bnn_model(model_type, device)

    


Using device: cuda
Loading MNIST test set...
Creating Deep BNN model...
Loading model from bnn_deep_best.pth...
Model trained for 130 epochs
Best test accuracy(popcount last layer output): 97.38%


In [7]:
# Evaluate the loaded model on the MNIST test loader; only the first value
# returned by evaluate_model (the accuracy) is used, the second is discarded.
# NOTE(review): the recorded output of this cell (52.17%) is far below the
# accuracy stored in the checkpoint (97.38%, labelled "popcount last layer
# output"). Presumably the two numbers use different output decodings —
# confirm this gap is expected before using it as the baseline.
accuracy, _ = evaluate_model(model, test_loader, device, verbose=False)
print(f"Baseline test accuracy with last layer binary output: {accuracy:.2f}%")

Baseline test accuracy with last layer binary output: 52.17%


# Threshold experiment

In [8]:
# Reload the same checkpoint with non-zero thresholds and re-evaluate, so the
# result can be compared against the baseline accuracy.
experiment_thresholds = [0.1, 0.1, 0.1]
model, checkpoint = load_bnn_model(model_type, device, thresholds=experiment_thresholds)
accuracy, _ = evaluate_model(model, test_loader, device, verbose=False)
# Report the result; previously the accuracy was computed but never shown.
print(f"Test accuracy with thresholds={experiment_thresholds}: {accuracy:.2f}%")

Creating Deep BNN model...
Loading model from bnn_deep_best.pth...
Model trained for 130 epochs
Best test accuracy(popcount last layer output): 97.38%
