# Sampling-free Laplace Approximation for Bayesian Neural Network
This notebook demonstrates how to compute the KFAC (Kronecker-factored) approximation of the Fisher information matrix for a PyTorch model, and how to use it to perform approximate Bayesian inference.

## 1 Import packages

In [None]:
# Choose which GPU is visible to CUDA. This is done before torch is imported
# so the setting is in place when CUDA is initialised.
# setdefault keeps a CUDA_VISIBLE_DEVICES value that was already exported in
# the environment and only falls back to GPU "7" when none was given.
import os
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "7")

# standard imports
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

# from the repository
from models.wrapper import BaseNet
from models.curvatures import BlockDiagonal, KFAC, EFB, INF
from models.utilities import calibration_curve
from models import plot

## 2 Basic functions

In [None]:
def gradient(y, x, grad_outputs=None):
    '''
    Vector-Jacobian product of y with respect to x.

    y: output tensor to differentiate
    x: tensor to differentiate with respect to
    grad_outputs: weights applied to y (the "vector" of the product);
        a tensor of ones shaped like y is used when omitted

    Returns the gradient with respect to x. The autograd graph is kept and
    extended (create_graph / retain_graph), so the result can itself be
    differentiated. Because allow_unused=True, the result is None when x does
    not take part in the computation of y.
    '''
    vector = torch.ones_like(y) if grad_outputs is None else grad_outputs
    (grad,) = torch.autograd.grad(
        y,
        [x],
        grad_outputs=vector,
        create_graph=True,
        retain_graph=True,
        allow_unused=True,
    )
    return grad

def jacobian(y, x, device):
    '''
    Jacobian of y with respect to x, assembled one output column at a time.

    Row i is the result of `gradient` called with a selector that is 1 on
    column i of y (for every row of y) and 0 elsewhere, so the contributions
    of all rows of y are summed into that row.

    y: output, indexed as y[:, i] (needs at least two dimensions)
    x: parameter
    device: device the result is moved to

    Returns a tensor of shape (y.shape[1], number of elements in x).
    '''
    n_outputs = y.shape[1]
    jac = torch.zeros(n_outputs, x.numel()).to(device)
    for col in range(n_outputs):
        selector = torch.zeros_like(y)
        selector[:, col] = 1
        jac[col, :] = torch.flatten(gradient(y, x, selector))
    return jac


## 3 Load the MNIST dataset (training and test splits)

In [None]:
models_dir = 'theta'
results_dir = 'results'
device = "cuda" if torch.cuda.is_available() else "cpu"

# load and normalize MNIST
# Redirect the download to an alternative mirror. Only entries that are full
# URLs are rewritten: torchvision releases that keep bare file names in
# `resources` (with the mirrors listed separately) would otherwise end up with
# a URL in place of the file name and fail to download.
new_mirror = 'https://ossci-datasets.s3.amazonaws.com/mnist'
datasets.MNIST.resources = [
    ('/'.join([new_mirror, url.split('/')[-1]]) if '://' in url else url, md5)
    for url, md5 in datasets.MNIST.resources
]

# Training split, served in batches of 32 in dataset order (no shuffling).
train_set = datasets.MNIST(root="./data",
                           train=True,
                           transform=transforms.ToTensor(),
                           download=True)
train_loader = DataLoader(train_set, batch_size=32)

# And some for evaluating/testing
# One image per batch: the evaluation cells compute a variance per sample.
test_set = datasets.MNIST(root="./data",
                          train=False,
                          transform=transforms.ToTensor(),
                          download=True)
test_loader = DataLoader(test_set, batch_size=1)


## 4 Define the network model and train

In [None]:
# Train the network on the training loader, then score it on the test loader.
net = BaseNet(lr=1e-3, epoch=3, batch_size=32, device=device)
criterion = nn.CrossEntropyLoss().to(device)
net.train(train_loader, criterion)
sgd_predictions, sgd_labels = net.eval(test_loader)

# Accuracy = share of samples whose highest-scoring class matches the label.
sgd_pred_classes = np.argmax(sgd_predictions.cpu().numpy(), axis=1)
sgd_accuracy = 100 * np.mean(sgd_pred_classes == sgd_labels.numpy())
print(f"MAP Accuracy: {sgd_accuracy:.2f}%")

## 5 Compute the Kronecker-factored Fisher information matrix (FIM)

In [None]:
kfac = KFAC(net.model)

# One pass over the training data; the dataset labels are not used here.
for images, _ in tqdm(train_loader):
    logits = net.model(images.to(device))
    # A rank-1 Kronecker factored FiM approximation.
    # The targets are sampled from the model's own categorical prediction.
    sampled_labels = torch.distributions.Categorical(logits=logits).sample()
    loss = criterion(logits, sampled_labels)
    # Clear old gradients so this backward pass reflects the current batch only.
    net.model.zero_grad()
    loss.backward()
    kfac.update(batch_size=images.size(0))

## 6 Compute the inverses of H and Q

In [None]:
# The two values handed to `invert`; their exact role is defined by the
# curvature class (presumably regularisation terms — confirm in models.curvatures).
add, multiply = 1, 200

# The evaluation cells below refer to the curvature object as `estimator`.
estimator = kfac
estimator.invert(add, multiply)

## 7 Evaluate model performance on the test set

In [None]:
# Per-sample evaluation on the test set: the predictive mean is the softmax
# output of the trained network, the predictive variance is accumulated layer
# by layer as |J H J^T|, and the uncertainty is the entropy (in bits) of a
# Gaussian with that variance.
targets = torch.Tensor()
kfac_prediction = torch.Tensor().to(device)
kfac_entropy_lst = []
const = 2 * np.e * np.pi
for images, labels in tqdm(test_loader):
    # prediction mean, equals to the MAP output
    pred_mean = torch.nn.functional.softmax(net.model(images.to(device)), dim=1)
    # compute prediction variance (accumulated in `pred_std`)
    pred_std = 0
    # Select the predicted class as the output to differentiate.
    # test_loader yields one sample per batch, so `idx` holds a single class.
    idx = np.argmax(pred_mean.cpu().detach().numpy(), axis=1)
    grad_outputs = torch.zeros_like(pred_mean)
    grad_outputs[:, idx] = 1
    for layer in list(estimator.model.modules())[1:]:
        # Only layers tracked by the curvature estimator contribute.
        # This runs on CPU as well as GPU, matching the noise-image cell;
        # skipping it on CPU left the variance at 0 and the entropy at -inf.
        if layer not in estimator.state:
            continue
        Q_i = estimator.inv_state[layer][0]
        H_i = estimator.inv_state[layer][1]
        g = [
            torch.flatten(gradient(pred_mean, p, grad_outputs=grad_outputs))
            for p in layer.parameters()
        ]
        # Gradients of all parameters of the layer as a single row vector.
        J_i = torch.cat(g, dim=0).unsqueeze(0)
        H = torch.kron(Q_i, H_i)
        pred_std += torch.abs(J_i @ H @ J_i.t()).item()
    # uncertainty
    entropy = 0.5 * np.log2(const * pred_std)
    kfac_entropy_lst.append(entropy)
    # ground truth
    targets = torch.cat([targets, labels])
    # prediction, mean value of the gaussian distribution
    # Detached so the per-sample autograd graphs are released and the stored
    # predictions can later be converted with .numpy().
    kfac_prediction = torch.cat([kfac_prediction, pred_mean.detach()])
# Built once after the loop instead of being rebuilt on every iteration.
kfac_uncertainty = np.array(kfac_entropy_lst)
print(f"KFAC Accuracy: {100 * np.mean(np.argmax(kfac_prediction.cpu().detach().numpy(), axis=1) == targets.numpy()):.2f}%")
print(f"Mean KFAC Entropy:{np.mean(kfac_uncertainty)}%")


## 8 Evaluate model performance on Gaussian noise images

In [None]:
# Same uncertainty estimate as for the test set, applied to 10000 images of
# pure Gaussian noise.
res_entropy_lst = []
const = 2 * np.e * np.pi
for _ in tqdm(range(10000)):
    # `images` is the last batch left over from the test-set cell; it is only
    # used as a template for the shape, dtype and device of the noise.
    noise = torch.randn_like(images)
    pred_mean = torch.nn.functional.softmax(net.model(noise.to(device)), dim=1)
    # compute prediction variance (accumulated in `pred_std`)
    pred_std = 0
    # Select the predicted class as the output to differentiate.
    idx = np.argmax(pred_mean.cpu().detach().numpy(), axis=1)
    grad_outputs = torch.zeros_like(pred_mean)
    grad_outputs[:, idx] = 1
    for layer in list(estimator.model.modules())[1:]:
        # Only layers tracked by the curvature estimator contribute.
        if layer not in estimator.state:
            continue
        Q_i = estimator.inv_state[layer][0]
        H_i = estimator.inv_state[layer][1]
        g = [
            torch.flatten(gradient(pred_mean, p, grad_outputs=grad_outputs))
            for p in layer.parameters()
        ]
        # Gradients of all parameters of the layer as a single row vector.
        J_i = torch.cat(g, dim=0).unsqueeze(0)
        H = torch.kron(Q_i, H_i)
        pred_std += torch.abs(J_i @ H @ J_i.t()).item()
    entropy = 0.5 * np.log2(const * pred_std)
    res_entropy_lst.append(entropy)
# Built once after the loop instead of being rebuilt on every iteration.
res_uncertainty = np.array(res_entropy_lst)
print(f"Mean Noise Entropy:{np.mean(res_uncertainty)}%")


## 9 Plot the results

In [None]:
# calibration
# Convert once to NumPy. The KFAC predictions are detached first: .numpy()
# raises on a tensor that still requires grad.
sgd_probs = sgd_predictions.cpu().numpy()
sgd_targets = sgd_labels.numpy()
kfac_probs = kfac_prediction.detach().cpu().numpy()
kfac_targets = targets.numpy()

ece_nn = calibration_curve(sgd_probs, sgd_targets)[0]
ece_bnn = calibration_curve(kfac_probs, kfac_targets)[0]
print(f"ECE NN: {100 * ece_nn:.2f}%, ECE BNN: {100 * ece_bnn:.2f}%")

# Reliability diagrams, one panel per method.
fig, ax = plt.subplots(ncols=2, nrows=1, figsize=(12, 6), tight_layout=True)
ax[0].set_title('SGD', fontsize=16)
ax[1].set_title('KFAC-Laplace', fontsize=16)
plot.reliability_diagram(sgd_probs, sgd_targets, axis=ax[0])
plot.reliability_diagram(kfac_probs, kfac_targets, axis=ax[1])

# Calibration curves of both methods on shared axes.
fig, ax = plt.subplots(figsize=(12, 7), tight_layout=True)
# First two colours of the configured colour cycle, read from the public
# rcParams entry rather than from a private attribute of the axes.
cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
c1 = cycle_colors[0]
c2 = cycle_colors[1 % len(cycle_colors)]
plot.calibration(sgd_probs, sgd_targets, color=c1, label="SGD", axis=ax)
plot.calibration(kfac_probs, kfac_targets, color=c2, label="KFAC-Laplace", axis=ax)
