## Bayesian NN implementation

**NOTE**: This script is designed for use in Kaggle, as it allows the use of GPUs. To run the script, download the file and upload it to Kaggle. Then, in Kaggle, add the dataset "MNIST in CSV" (slug `mnist-in-csv`) by Dariel Dato-on to the notebook; the code reads it from `/kaggle/input/mnist-in-csv/`.

### Setup explanation

MNIST has 70000 samples.
60000 are used for training, 10000 for testing.
Of the 60000 training samples, 10000 are held out for validation and the remaining 50000 are used for training.

The Bayesian NN has two layers.
First layer has 100 hidden neurons using a sigmoid function.
Output layer uses softmax. 

Optimization-based techniques (SGD and SGD with momentum) use the validation set to select the optimal regularizer lambda for the network weights (i.e., the lambda that scales the regularization term in the loss function).

For sampling-based techniques (SGLD and SGHMC), we place a weak gamma prior on each layer's weight regularizer lambda.

Sampling with SGLD and SGHMC is done using minibatches of 500 training samples, which are used to compute the stochastic gradient.
Hyperparameters are resampled after an entire pass over the training set (one epoch).
We use a total of 800 iterations/epochs. 
We have a burn-in of 50 samples.

In the Bayesian framework, we are treating lambda as a random variable, and place a prior distribution on it.
Since lambda has to be positive, we use a gamma prior.

In [None]:
import numpy as np
import random
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from torch.utils.data import TensorDataset
import matplotlib.pyplot as plt
import json
from torch.optim.optimizer import Optimizer
import math

In [None]:
# MNIST pixel intensities are 8-bit values (0-255). Rescaling them to [0, 1]
# keeps the sigmoid hidden units out of their saturated range.
PIXEL_SCALE = 255.0

# Load training data: one row per image, a 'label' column plus the pixel columns.
train_df = pd.read_csv('/kaggle/input/mnist-in-csv/mnist_train.csv')
X_train = train_df.drop('label', axis=1).values
y_train = train_df['label'].values

# Normalize and reshape to (N, 1, 28, 28). Each array is converted to a tensor
# exactly once; re-wrapping an existing tensor in torch.tensor() only triggers
# a copy-construct warning.
X_train = torch.tensor(X_train, dtype=torch.float32).reshape(-1, 1, 28, 28) / PIXEL_SCALE
y_train = torch.tensor(y_train, dtype=torch.long)

# Load test data (also labeled) and preprocess it exactly like the training data.
test_df = pd.read_csv('/kaggle/input/mnist-in-csv/mnist_test.csv')
X_test = test_df.drop('label', axis=1).values
y_test = test_df['label'].values

X_test = torch.tensor(X_test, dtype=torch.float32).reshape(-1, 1, 28, 28) / PIXEL_SCALE
y_test = torch.tensor(y_test, dtype=torch.long)

train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)

# DataLoaders: only the training set is shuffled, so evaluation order is fixed.
train_loader = DataLoader(train_dataset, batch_size=2000, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=2000, shuffle=False)

In [None]:
class SimpleNN(nn.Module):
    """Two-layer fully connected classifier: sigmoid hidden layer, linear logit output."""

    def __init__(self, input_dim=784, hidden_dim=100, output_dim=10):
        super().__init__()
        # The attribute names become the state_dict key prefixes ("fc1.*", "fc2.*").
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Flatten every sample in the batch and return unnormalised class scores."""
        batch_size = x.shape[0]
        flat = x.view(batch_size, -1)
        hidden = self.fc1(flat).sigmoid()
        # No softmax here: the cross-entropy loss applies log-softmax to the logits itself.
        return self.fc2(hidden)

In [None]:
class BayesianNN(nn.Module):
    """Two-layer network (sigmoid hidden layer, linear logit output) used as the sampled model."""

    def __init__(self, input_dim=784, hidden_dim=100, output_dim=10):
        super().__init__()
        # Both layers keep nn.Linear's default initialisation.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return class logits for `x`.

        No flattening happens here: the last dimension of `x` must already
        equal `input_dim`.
        """
        hidden = self.fc1(x).sigmoid()
        # Logits only; softmax is applied by the loss / by the caller.
        return self.fc2(hidden)
    

In [None]:
# Defaults for the SGD pre-training phase.
init_epochs = 100
lambda_reg = 1e-04
momentum = 0.9
init_lr = 0.05


def init_training(model, train_loader, test_loader, lr=init_lr, lambda_reg=lambda_reg, momentum=momentum, epochs=init_epochs):
    """Train `model` with momentum SGD and record the test error after every epoch.

    Args:
        model: network mapping a batch from the loaders to class logits.
        train_loader: iterable of (inputs, targets) batches used for the updates.
        test_loader: iterable of (inputs, targets) batches used for evaluation.
        lr: SGD learning rate.
        lambda_reg: L2 penalty strength, passed to SGD as `weight_decay`.
        momentum: SGD momentum coefficient.
        epochs: number of passes over `train_loader`.

    Returns:
        (model, test_errors, params): the trained model (left in eval mode),
        the per-epoch test error rates, and a cloned state_dict per epoch.
    """
    loss_fn = nn.CrossEntropyLoss()
    # weight_decay implements the L2 regulariser on all parameters.
    sgd = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=lambda_reg)

    test_errors, params = [], []

    for epoch_idx in range(epochs):
        epoch = epoch_idx + 1

        # ---- one pass over the training data ----
        model.train()
        batch_losses = []
        for inputs, labels in train_loader:
            sgd.zero_grad()
            batch_loss = loss_fn(model(inputs), labels)
            batch_loss.backward()
            sgd.step()
            batch_losses.append(batch_loss.item())
        # Sum (not mean) of the per-batch losses, as reported in the log line.
        total_loss = sum(batch_losses)

        # ---- test error after this epoch ----
        model.eval()
        n_correct = 0
        n_seen = 0
        with torch.no_grad():
            for inputs, labels in test_loader:
                hits = model(inputs).argmax(dim=1).eq(labels)
                n_correct += hits.sum().item()
                n_seen += labels.size(0)

        test_error = 1 - (n_correct / n_seen)
        test_errors.append(test_error)

        # Clone so later optimizer steps cannot modify the stored snapshot.
        params.append({name: tensor.clone() for name, tensor in model.state_dict().items()})

        print(f"Epoch {epoch}, Loss: {total_loss:.4f}, Test Error: {test_error:.4f}")

    return model, test_errors, params

In [None]:
# Defaults for the SGHMC sampling phase.
epochs = 700
burn_in = 100
eps = 0.05  # Learning rate (step size)
C = 0.1  # Friction
a = 1.0  # Gamma prior concentration for the weight regulariser lambda
b = 10  # Gamma prior rate for the weight regulariser lambda


def sghmc_training(model, init_params, train_loader, test_loader, eps=eps, C=C, epochs=epochs, burn_in=burn_in, a=a, b=b):
    """Sample network weights with SGHMC and track the ensemble test error.

    For every minibatch the update is

        r     <- r - eps * grad_U - eps * C * r + N(0, 2 * C * eps)
        theta <- theta + eps * r

    where grad_U is the minibatch cross-entropy gradient plus
    ``lambda * theta`` and lambda is drawn from Gamma(a, b) for each minibatch.
    After every epoch the test error is measured with `predict` over the
    posterior pool collected so far, and the current weights are then added
    to that pool.

    Args:
        model: network to sample; it is fed batches flattened to (batch, -1).
        init_params: state dicts that seed the posterior pool.
        train_loader: iterable of (inputs, targets) training batches.
        test_loader: iterable of (inputs, targets) evaluation batches.
        eps: step size.
        C: friction coefficient.
        epochs: number of passes over `train_loader`.
        burn_in: accepted for interface compatibility but not used; every
            epoch's sample is kept.
        a: concentration of the Gamma distribution lambda is drawn from.
        b: rate of the Gamma distribution lambda is drawn from.

    Returns:
        (posteriors, test_errors): the seed state dicts followed by one state
        dict per epoch, and the per-epoch test error rates.
    """
    model.train()

    theta = list(model.parameters())
    # One momentum buffer per parameter tensor, starting at rest.
    r = [torch.zeros_like(p) for p in theta]

    # Loop invariants: build the prior and the noise scale once.
    lambda_prior = torch.distributions.Gamma(a, b)
    noise_std = math.sqrt(2 * C * eps)

    test_errors = []
    posteriors = list(init_params)  # copy so the caller's list is not mutated

    for epoch in range(epochs):
        for x_batch, y_batch in train_loader:
            x_batch = x_batch.view(x_batch.size(0), -1)

            logits = model(x_batch)
            loss = F.cross_entropy(logits, y_batch)

            model.zero_grad()
            loss.backward()

            lambda_val = lambda_prior.sample().item()

            # Manual parameter update: keep it out of the autograd graph.
            with torch.no_grad():
                for i, param in enumerate(theta):
                    # Stochastic gradient of the potential: data term + Gaussian-prior term.
                    grad_U_tilde = param.grad + lambda_val * param
                    noise = torch.randn_like(param) * noise_std
                    r[i] = r[i] - eps * grad_U_tilde - eps * C * r[i] + noise
                    param.add_(eps * r[i])

        # Snapshot the chain state BEFORE evaluating: predict() swaps stored
        # posterior weights into `model`, which would otherwise overwrite the
        # current sample.
        current_sample = {k: v.clone() for k, v in model.state_dict().items()}

        # Evaluate on test set after each epoch
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for x_test, y_test in test_loader:
                x_test = x_test.view(x_test.size(0), -1)
                probs = predict(model, posteriors, x_test)
                pred = probs.argmax(dim=1)
                correct += (pred == y_test).sum().item()
                total += y_test.size(0)
        test_error = 1 - (correct / total)
        test_errors.append(test_error)

        # Put the chain back where it was (load_state_dict copies in place, so
        # the tensors referenced by `theta` stay valid), then save the true
        # current sample rather than whatever evaluation left in the model.
        model.load_state_dict(current_sample)
        posteriors.append(current_sample)

        print(f"Epoch {epoch + 1}/{epochs} | Test Error: {test_error:.4f}")
        model.train()

    return posteriors, test_errors

In [None]:
def predict(model, posteriors, sample, n_posteriors=10):
    """Predict class probabilities by averaging over randomly chosen posterior samples.

    Args:
        model: network whose state_dict layout matches every entry of `posteriors`.
        posteriors: sequence of state dicts (posterior weight samples).
        sample: batch of inputs; it is flattened to (batch, -1) before the forward pass.
        n_posteriors: maximum number of posterior samples to average. If fewer
            are available, all of them are used.

    Returns:
        Tensor of shape (batch, n_classes) holding the mean softmax output.

    Raises:
        ValueError: if `posteriors` is empty or `n_posteriors` is less than 1.

    Side effects:
        The model is switched to eval mode. Its weights are restored to their
        values on entry, so a caller in the middle of training or sampling
        does not lose its current parameters.
    """
    if n_posteriors < 1:
        raise ValueError("n_posteriors must be at least 1")
    if not posteriors:
        raise ValueError("posteriors must contain at least one sample")

    model.eval()
    sample = sample.view(sample.size(0), -1)

    # Never ask random.sample for more items than exist.
    n_used = min(n_posteriors, len(posteriors))
    chosen_posteriors = random.sample(posteriors, k=n_used)

    # load_state_dict overwrites the parameters in place, so remember the
    # weights the caller handed in.
    entry_state = {name: tensor.clone() for name, tensor in model.state_dict().items()}

    pred_probs = None
    try:
        with torch.no_grad():
            for posterior in chosen_posteriors:
                model.load_state_dict(posterior)  # Load model with posterior weights
                probs = F.softmax(model(sample), dim=1)
                # Accumulate from the first output so the class count and
                # device follow the model instead of being hard-coded.
                pred_probs = probs if pred_probs is None else pred_probs + probs
    finally:
        model.load_state_dict(entry_state)

    return pred_probs / n_used

In [None]:
# Pre-train a plain network with momentum SGD, keeping the per-epoch test
# errors and weight snapshots for the cells below.
init_model, init_errors, init_params = init_training(SimpleNN(), train_loader, test_loader)

In [None]:
# Start the sampler from the weights reached after the last pre-training epoch.
start_weights = init_params[-1]
model = BayesianNN()
model.load_state_dict(start_weights)

posteriors, test_errors = sghmc_training(model, init_params, train_loader, test_loader)

# Plot test error (one point per sampling epoch)
plt.plot(range(len(test_errors)), test_errors)
plt.title("Test Error Over Iterations")
plt.xlabel("Iteration")
plt.ylabel("Test Error")
plt.grid(True)
plt.show()

In [None]:
# Full error curve: SGD pre-training epochs first, then the SGHMC epochs.
test_errors_final = [*init_errors, *test_errors]

In [None]:
# Number the epochs from the data itself rather than from init_epochs + epochs,
# so the two columns always have the same length even if either training phase
# was run with a non-default epoch count.
epoch_index = list(range(1, len(test_errors_final) + 1))
df = pd.DataFrame({'epoch': epoch_index, 'test_error': test_errors_final})
df.to_csv('/kaggle/working/test_error_bayesianNN.csv', index=False)