# Bayesian Neural Network Evaluation on MNIST
This notebook replicates the architecture and setup from the baseline paper and evaluates the custom Bayesian Neural Network implementation from the local `sbnn` module on the MNIST test set.

In [1]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os

# Seed the torch and NumPy RNGs up front so that anything drawing from them
# (e.g. DataLoader shuffling, parameter initialisation) is repeatable from a
# fresh kernel. Bit-exact reproducibility on CUDA is still not guaranteed.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cpu


In [3]:

# Import the BayesianNet from your file
from sbnn import BayesianNet, train, BayesianNetLeNet

# Sanity check (currently commented out): instantiate the LeNet variant to confirm the import works
# model = BayesianNetLeNet(input_dim=784, hidden_dim=1000, output_dim=10).to(device)



BayesianNetLeNet(
  (conv1): BayesianBinaryConv2d()
  (conv2): BayesianBinaryConv2d()
  (fc1): BayesianBinaryLinear()
  (fc2): BayesianBinaryLinear()
)


In [7]:

# MNIST data pipeline: flattened image tensors paired with one-hot float labels.
batch_size = 1000

# Image transform: convert to a tensor, then flatten it to 1-D with view(-1).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda img: img.view(-1)),
])


def target_transform(y):
    """Turn an integer class label into a one-hot float vector of length 10."""
    class_index = torch.tensor(y)
    return F.one_hot(class_index, num_classes=10).float()


# Training split: fetched into ./data when missing, shuffled by the loader.
train_dataset = datasets.MNIST(
    root='./data', train=True, download=True,
    transform=transform, target_transform=target_transform,
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Test split: read from the same ./data root, iterated in a fixed order.
test_dataset = datasets.MNIST(
    root='./data', train=False,
    transform=transform, target_transform=target_transform,
)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [10]:

# Build the Bayesian network (input_dim=784, hidden_dim=1000, output_dim=10)
# and move it onto the selected device.
model = BayesianNet(input_dim=784, hidden_dim=1000, output_dim=10)
model = model.to(device)
# model = BayesianNetLeNet().to(device)

# Adam over every model parameter with a learning rate of 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Training function using your code
def train_mnist(model, train_loader, epochs=10):
    """Train ``model`` on ``train_loader`` and print per-epoch loss and accuracy.

    Uses the notebook-level ``optimizer`` and ``device`` globals.

    Args:
        model: Module whose forward pass returns a 3-tuple; only the first
            element is used, as the per-class scores.
        train_loader: Iterable of ``(x, y)`` batches. ``y`` may hold integer
            class indices of shape ``(batch,)`` or one-hot / probability rows
            of shape ``(batch, num_classes)``.
        epochs: Number of passes over ``train_loader``.

    Returns:
        None. For each epoch the loss summed over batches and the training
        accuracy are printed.
    """
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        correct = 0
        total = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            # x = x.view(-1, 1, 28, 28)  # reshape for conv2d (enable for the conv-based model)
            optimizer.zero_grad()
            out, _, _ = model(x)
            # F.cross_entropy expects raw logits.
            # NOTE(review): confirm model(x) does not already apply softmax.
            loss = F.cross_entropy(out, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            pred = out.argmax(dim=1)
            # Targets may be one-hot rows (batch, num_classes); reduce them to
            # class indices so the comparison with `pred` is element-wise
            # instead of failing to broadcast (batch,) against (batch, C).
            labels = y.argmax(dim=1) if y.dim() > 1 else y
            correct += (pred == labels).sum().item()
            total += y.size(0)
        # Guard against an empty loader, which would otherwise divide by zero.
        accuracy = correct / total if total else 0.0
        print(f"Epoch {epoch}: Loss={total_loss:.4f}, Accuracy={accuracy:.2%}")

# Run 50 training epochs; the per-epoch loss and accuracy are printed below.
train_mnist(model=model, train_loader=train_loader, epochs=50)


Epoch 0: Loss=132.7400, Accuracy=24.58%
Epoch 1: Loss=117.3858, Accuracy=50.24%
Epoch 2: Loss=111.0137, Accuracy=60.94%
Epoch 3: Loss=108.0653, Accuracy=65.94%
Epoch 4: Loss=106.5121, Accuracy=68.51%
Epoch 5: Loss=104.9642, Accuracy=71.08%
Epoch 6: Loss=102.4993, Accuracy=75.19%
Epoch 7: Loss=99.9216, Accuracy=79.50%
Epoch 8: Loss=98.5379, Accuracy=81.82%
Epoch 9: Loss=97.7857, Accuracy=83.10%
Epoch 10: Loss=97.1953, Accuracy=84.05%
Epoch 11: Loss=96.6279, Accuracy=85.02%
Epoch 12: Loss=96.3001, Accuracy=85.57%
Epoch 13: Loss=96.1181, Accuracy=85.86%
Epoch 14: Loss=95.8978, Accuracy=86.23%
Epoch 15: Loss=95.4503, Accuracy=86.99%
Epoch 16: Loss=95.4588, Accuracy=86.96%
Epoch 17: Loss=95.0413, Accuracy=87.65%
Epoch 18: Loss=94.8933, Accuracy=87.92%
Epoch 19: Loss=94.7934, Accuracy=88.09%
Epoch 20: Loss=94.5514, Accuracy=88.51%
Epoch 21: Loss=94.4263, Accuracy=88.70%
Epoch 22: Loss=94.4128, Accuracy=88.73%
Epoch 23: Loss=94.2320, Accuracy=89.02%
Epoch 24: Loss=94.2493, Accuracy=89.00%
Epo

In [11]:

# Evaluate with predictive uncertainty from Bayesian model
from sklearn.metrics import accuracy_score

model.eval()
all_preds = []
all_targets = []

with torch.no_grad():
    for x, y in test_loader:
        x = x.to(device)
        # Presumably the mean / std of the class scores over 20 sampled
        # forward passes -- confirm against sbnn.predict_multiple.
        # std_preds is not used below; accuracy depends only on mean_preds.
        mean_preds, std_preds = model.predict_multiple(x, n_samples=20)
        preds = mean_preds.argmax(dim=1)
        # The loader may yield one-hot targets (batch, num_classes); collapse
        # them to class indices so accuracy_score compares multiclass labels
        # with multiclass predictions instead of rejecting mixed target types.
        labels = y.argmax(dim=1) if y.dim() > 1 else y
        all_preds.extend(preds.cpu().numpy())
        all_targets.extend(labels.cpu().numpy())

print("Test Accuracy:", accuracy_score(all_targets, all_preds))


Test Accuracy: 0.9401
