In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

# --- 1. Define the BNN Model Components ---

class BinaryActivation(torch.autograd.Function):
    """
    Binary Activation function (Sign function) with Straight-Through Estimator (STE).
    For the forward pass, the input is binarized to -1 or 1.
    For the backward pass, the gradient is passed through unchanged (STE).
    """
    @staticmethod
    def forward(ctx, input):
        # Save the input tensor for use in the backward pass.
        ctx.save_for_backward(input)
        # Binarize the input: if input >= 0, output is 1; otherwise, -1.
        return input.sign()

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve the input tensor saved during the forward pass.
        input, = ctx.saved_tensors
        # Straight-Through Estimator: The gradient from the subsequent layer
        # is passed directly back to the previous layer, effectively treating
        # the sign function as an identity for backpropagation.
        grad_input = grad_output.clone()
        return grad_input

class BinarizeWeights(torch.autograd.Function):
    """
    Binarizes weights to -1 or 1 using Straight-Through Estimator (STE).
    For the forward pass, weights are binarized.
    For the backward pass, gradients are computed with respect to the full-precision weights.
    """
    @staticmethod
    def forward(ctx, input):
        # Save the full-precision input (weights) for the backward pass.
        ctx.save_for_backward(input)
        # Binarize the input (weights) to -1 or 1.
        return input.sign()

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve the full-precision weights saved during the forward pass.
        input, = ctx.saved_tensors
        # Straight-Through Estimator: The gradient is passed directly through
        # to the full-precision weights. This allows the optimizer to update
        # the full-precision weights that are then binarized in the next forward pass.
        grad_input = grad_output.clone()
        return grad_input

class BNNLinear(nn.Linear):
    """
    Custom Linear layer for BNNs that binarizes its weights during the forward pass.
    """
    def __init__(self, in_features, out_features, bias=True):
        super(BNNLinear, self).__init__(in_features, out_features, bias)
        # Create an instance of our custom BinarizeWeights function.
        self.binarize = BinarizeWeights.apply

    def forward(self, input):
        # Binarize the layer's weights before performing the matrix multiplication.
        binarized_weight = self.binarize(self.weight)
        # Perform the standard linear operation (matrix multiplication + bias).
        output = F.linear(input, binarized_weight, self.bias)
        return output

class FullyConnectedBNN(nn.Module):
    """
    Fully Connected Binary Neural Network for MNIST classification.
    Uses custom BNNLinear layers, BatchNorm1d, and BinaryActivation.
    """
    def __init__(self, input_size, num_classes):
        super(FullyConnectedBNN, self).__init__()
        self.input_size = input_size
        self.num_classes = num_classes

        # Define the layers of the BNN.
        # Each hidden layer consists of a BNNLinear, BatchNorm1d, and BinaryActivation.

        # First hidden layer
        self.fc1 = BNNLinear(input_size, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.act1 = BinaryActivation.apply # Apply sign activation after BN

        # Second hidden layer
        self.fc2 = BNNLinear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.act2 = BinaryActivation.apply # Apply sign activation after BN

        # Output layer (no sign activation for the final classification output)
        self.fc3 = BNNLinear(256, num_classes)

    def forward(self, x):
        # Flatten the input image from (batch_size, 1, 28, 28) to (batch_size, 784).
        x = x.view(-1, self.input_size)

        # Pass through the first hidden layer
        x = self.fc1(x)
        x = self.bn1(x)
        x = self.act1(x) # Apply sign activation

        # Pass through the second hidden layer
        x = self.fc2(x)
        x = self.bn2(x)
        x = self.act2(x) # Apply sign activation

        # Pass through the output layer. The output of this layer will be
        # used by the CrossEntropyLoss function, which handles softmax internally.
        x = self.fc3(x)
        return x

# --- 2. Data Loading and Preprocessing ---

# Define transformations for the MNIST dataset:
# 1. Convert PIL Image to PyTorch Tensor.
# 2. Normalize the tensor with mean and standard deviation specific to MNIST.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)) # Mean and Std for MNIST dataset
])

# Load the MNIST training dataset.
train_dataset = torchvision.datasets.MNIST(
    root='./data',       # Directory where data will be downloaded
    train=True,          # Specify this is the training set
    download=True,       # Download the dataset if not already present
    transform=transform  # Apply the defined transformations
)

# Load the MNIST testing dataset.
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,         # Specify this is the test set
    download=True,
    transform=transform
)

# Define batch size for data loaders.
batch_size = 64

# Create data loaders for training and testing.
# Shuffling the training data helps with generalization.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False
)

# --- 3. Model Initialization, Loss Function, and Optimizer ---

# Calculate the input size (28x28 pixels for MNIST images).
input_size = 28 * 28
# Number of classes for MNIST (digits 0-9).
num_classes = 10

# Initialize the Fully Connected BNN model.
model = FullyConnectedBNN(input_size, num_classes)

# Determine the device to use (GPU if available, otherwise CPU).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device) # Move the model to the selected device.

# Define the loss function: CrossEntropyLoss is suitable for multi-class classification.
criterion = nn.CrossEntropyLoss()
# Define the optimizer: Adam is a good choice for training BNNs due to its adaptive learning rates.
optimizer = optim.Adam(model.parameters(), lr=0.001)

# --- 4. Training and Evaluation Functions ---

def train(model, device, train_loader, optimizer, epoch):
    """
    Trains the BNN model for one epoch.
    """
    model.train() # Set the model to training mode.
    running_loss = 0.0 # Initialize running loss for the epoch.

    # Iterate over batches in the training data loader.
    for batch_idx, (data, target) in enumerate(train_loader):
        # Move data and targets to the specified device (CPU/GPU).
        data, target = data.to(device), target.to(device)

        # Zero the gradients of the optimizer.
        optimizer.zero_grad()
        # Perform a forward pass to get model output.
        output = model(data)
        # Calculate the loss.
        loss = criterion(output, target)
        # Perform a backward pass to compute gradients.
        loss.backward()
        # Update model parameters using the optimizer.
        optimizer.step()

        running_loss += loss.item() # Accumulate the loss.
        # Print training progress periodically.
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
    print(f"Epoch {epoch} Training Loss: {running_loss / len(train_loader):.4f}")

def test(model, device, test_loader):
    """
    Evaluates the BNN model on the test dataset.
    """
    model.eval() # Set the model to evaluation mode.
    test_loss = 0
    correct = 0
    # Disable gradient computation during evaluation for efficiency.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum up batch loss.
            test_loss += criterion(output, target).item()
            # Get the index of the predicted class with the highest probability.
            pred = output.argmax(dim=1, keepdim=True)
            # Count correct predictions.
            correct += pred.eq(target.view_as(pred)).sum().item()

    # Calculate average test loss per sample.
    test_loss /= len(test_loader.dataset)

    # Print evaluation results.
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.0f}%)\n')
    return accuracy

# --- 5. Main Training Loop ---

num_epochs = 10 # Number of training epochs. This can be adjusted.
best_accuracy = 0.0 # To keep track of the best accuracy achieved.

# Run the training and testing loop for the specified number of epochs.
for epoch in range(1, num_epochs + 1):
    train(model, device, train_loader, optimizer, epoch) # Train for one epoch.
    accuracy = test(model, device, test_loader)      # Evaluate after each epoch.
    # Check if the current accuracy is better than the best recorded accuracy.
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        # Save the model's state_dict (all learnable parameters, including weights and biases for each layer).
        # This includes weights from fc1, fc2, fc3, and parameters from bn1, bn2.
        # The filename includes the accuracy for easy identification.
        torch.save(model.state_dict(), f"best_bnn_mnist_accuracy_{best_accuracy:.2f}%.pth")
        print(f"Saved new best model with accuracy: {best_accuracy:.2f}%")

print(f"Training finished. Best Test Accuracy: {best_accuracy:.2f}%")


Epoch 1 Training Loss: 3.3229

Test set: Average loss: 0.0481, Accuracy: 8604/10000 (86%)

Saved new best model with accuracy: 86.04%
Epoch 2 Training Loss: 2.8407

Test set: Average loss: 0.0397, Accuracy: 8688/10000 (87%)

Saved new best model with accuracy: 86.88%
Epoch 3 Training Loss: 2.6802

Test set: Average loss: 0.0341, Accuracy: 8865/10000 (89%)

Saved new best model with accuracy: 88.65%
Epoch 4 Training Loss: 2.5578

Test set: Average loss: 0.0350, Accuracy: 8908/10000 (89%)

Saved new best model with accuracy: 89.08%
Epoch 5 Training Loss: 2.5386

Test set: Average loss: 0.0382, Accuracy: 8641/10000 (86%)

Epoch 6 Training Loss: 2.4336

Test set: Average loss: 0.0477, Accuracy: 8580/10000 (86%)

Epoch 7 Training Loss: 2.4345

Test set: Average loss: 0.0400, Accuracy: 8588/10000 (86%)

Epoch 8 Training Loss: 2.3230

Test set: Average loss: 0.0313, Accuracy: 8945/10000 (89%)

Saved new best model with accuracy: 89.45%
Epoch 9 Training Loss: 2.3439

Test set: Average loss: 0.