# BNN KFAC Toy example

### 1 Import packages

In [33]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch import nn, optim
from torch.optim import Optimizer
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torch.nn.functional as F


import torchvision
from torchvision import datasets, transforms

import time
import warnings
warnings.filterwarnings("ignore")

### 2 Define a Convolutional Neural Network

In [19]:
class Net(nn.Module):
    """Small CNN classifier that caches per-layer activations for KFAC.

    Architecture for ``1 x 28 x 28`` inputs::

        conv1 (1->5, 5x5)  -> ReLU -> 2x2 max-pool    # bs x 5 x 12 x 12
        conv2 (5->10, 5x5) -> ReLU -> 2x2 max-pool    # bs x 10 x 4 x 4
        flatten (160) -> fc1 (160 -> output_dim)

    Every ``forward`` stores detached copies of the conv pre-activations
    ``h1``/``h2`` and of the flattened layer inputs ``a0``/``a1``/``a2``.
    Each of the ``a*`` copies has a trailing column of ones (bias input).

    Args:
        input_dim: number of input features, ``channels * side * side`` with
            ``channels == conv1.in_channels`` (1) and a square image.
        output_dim: number of classes (width of the returned logits).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.conv1 = nn.Conv2d(1, 5, 5)  # bs x 1 x 28 x 28 -> bs x 5 x 24 x 24
        self.pool = nn.MaxPool2d(2, 2)  # halves both spatial dimensions
        self.conv2 = nn.Conv2d(5, 10, 5)  # bs x 5 x 12 x 12 -> bs x 10 x 8 x 8
        self.fc1 = nn.Linear(10 * 4 * 4, output_dim)  # input: pooled conv2 output, flattened

        # Detached quantities cached by the most recent forward pass.
        self.one = None
        self.a2 = None
        self.h2 = None
        self.a1 = None
        self.h1 = None
        self.a0 = None

    def forward(self, x):
        """Return logits of shape ``(batch, output_dim)``.

        ``x`` may be images ``(batch, 1, side, side)`` or flattened
        ``(batch, input_dim)``. Conv2d needs NCHW input, so both forms are
        reshaped to images first.
        """
        batch_size = x.shape[0]
        self.one = x.new_ones(batch_size, 1)

        channels = self.conv1.in_channels
        side = int(round((self.input_dim // channels) ** 0.5))
        a0 = x.reshape(batch_size, channels, side, side)
        self.a0 = torch.cat((a0.detach().reshape(batch_size, -1), self.one), dim=1)

        h1 = self.conv1(a0)
        self.h1 = h1.detach()

        a1 = self.pool(F.relu(h1))
        # Flatten before appending the ones column: cat needs matching ranks.
        self.a1 = torch.cat((a1.detach().reshape(batch_size, -1), self.one), dim=1)

        h2 = self.conv2(a1)
        self.h2 = h2.detach()

        a2 = self.pool(F.relu(h2))
        self.a2 = torch.cat((a2.detach().reshape(batch_size, -1), self.one), dim=1)

        # fc1 expects the flattened (batch, 160) feature vector.
        h3 = self.fc1(a2.reshape(batch_size, -1))
        return h3

    def sample_predict(self, x, Nsamples, Qinv1, HHinv1, MAP1, Qinv2, HHinv2, MAP2, Qinv3, HHinv3, MAP3):
        """Return ``Nsamples`` sampled logits, shape ``(Nsamples, batch, output_dim)``.

        Each sample draws weights for three fully connected layers from their
        matrix-normal posteriors ``(MAPk, Qinvk, HHinvk)`` via
        ``sample_K_laplace_MN``. It then evaluates ``x -> W1 -> ReLU -> W2 ->
        ReLU -> W3``. ``MAPk`` holds ``[W_k | b_k]`` (bias in the last column).

        NOTE(review): this is an MLP forward pass on the flattened input; it
        does not use this module's conv layers, so the MAP/factor shapes must
        describe such an MLP - confirm against the KFAC parameters passed in.
        """
        # Uninitialised buffer with x's dtype/device; every slot is filled below.
        predictions = x.data.new(Nsamples, x.shape[0], self.output_dim)
        x = x.view(-1, self.input_dim)
        for i in range(Nsamples):
            w1, b1 = sample_K_laplace_MN(MAP1, Qinv1, HHinv1)
            a = torch.matmul(x, torch.t(w1)) + b1.unsqueeze(0)
            a = F.relu(a)

            w2, b2 = sample_K_laplace_MN(MAP2, Qinv2, HHinv2)
            a = torch.matmul(a, torch.t(w2)) + b2.unsqueeze(0)
            a = F.relu(a)

            w3, b3 = sample_K_laplace_MN(MAP3, Qinv3, HHinv3)
            y = torch.matmul(a, torch.t(w3)) + b3.unsqueeze(0)
            predictions[i] = y
        return predictions

### 3 Define KFAC functions

In [35]:
def to_variable(var=(), cuda=True, volatile=False):
    """Convert every item of ``var`` into a torch tensor and return them as a list.

    numpy arrays become ``torch.FloatTensor``. Tensors are moved to the GPU
    when ``cuda`` is true and they are not there yet. Anything that is not yet
    a ``Variable`` is wrapped in one (a legacy API: in modern PyTorch every
    tensor already counts as a ``Variable``, so this wrap is normally skipped).
    """
    def _convert(item):
        if isinstance(item, np.ndarray):
            item = torch.from_numpy(item).type(torch.FloatTensor)
        # Same evaluation order as before: .is_cuda is read before `cuda`.
        if not item.is_cuda and cuda:
            item = item.cuda()
        if not isinstance(item, Variable):
            item = Variable(item, volatile=volatile)
        return item

    return [_convert(item) for item in var]

def softmax_CE_preact_hessian(last_layer_acts):
    """Return the per-example Hessian of softmax cross-entropy w.r.t. the logits.

    Args:
        last_layer_acts: ``(batch, classes)`` softmax probabilities ``p``
            (activations, not pre-activations).

    Returns:
        ``(batch, classes, classes)`` tensor with ``H[b, i, j] = -p_i * p_j``
        for ``i != j`` and ``H[b, i, i] = p_i * (1 - p_i)``.
    """
    probs = last_layer_acts
    # Off-diagonal entries: H_ij = -p_i * p_j.
    Hl = -probs.unsqueeze(1) * probs.unsqueeze(2)
    # Diagonal entries: H_ii = p_i * (1 - p_i). Write through a diagonal view
    # rather than a uint8 mask: that indexing is deprecated, and the mask was
    # always built on the CPU regardless of the input's device.
    Hl.diagonal(dim1=1, dim2=2).copy_(probs * (1 - probs))
    return Hl

def layer_act_hessian_recurse(prev_hessian, prev_weights, layer_pre_acts):
    """Back-propagate a per-example pre-activation Hessian through one ReLU layer.

    Computes ``H_l = B W^T H_{l+1} W B + D`` for every example, where
    ``W = prev_weights`` and ``B = diag(relu'(layer_pre_acts))``. The
    second-derivative term ``D`` is zero for the piecewise-linear ReLU and is
    therefore omitted.

    Args:
        prev_hessian: ``(batch, out, out)`` Hessian w.r.t. the next layer's
            pre-activations.
        prev_weights: ``(out, n)`` weight matrix of the next layer.
        layer_pre_acts: ``(batch, n)`` pre-activations of this layer.

    Returns:
        ``(batch, n, n)`` Hessian w.r.t. this layer's pre-activations.
    """
    # ReLU derivative (0/1) in the weights' dtype and on their device.
    d_act = (layer_pre_acts > 0).to(dtype=prev_weights.dtype, device=prev_weights.device)
    # W^T H W per example; matmul broadcasts W over the batch dimension.
    Hl = torch.matmul(torch.matmul(torch.t(prev_weights), prev_hessian), prev_weights)
    # With a diagonal B, B @ M @ B scales row i by d_i and column j by d_j.
    return d_act.unsqueeze(2) * Hl * d_act.unsqueeze(1)

def chol_scale_invert_kron_factor(factor, prior_scale, data_scale, upper=False):
    """Return a Cholesky factor of ``(data_scale * factor + prior_scale * I)^-1``.

    Args:
        factor: square symmetric positive semi-definite Kronecker factor.
        prior_scale: multiple of the identity added as regularisation.
        data_scale: multiplier applied to ``factor``.
        upper: if True return the upper factor ``U`` (``U^T U = inverse``);
            otherwise return the lower factor ``L`` (``L L^T = inverse``).
    """
    eye = torch.eye(factor.shape[0], dtype=factor.dtype, device=factor.device)
    scaled_factor = data_scale * factor + prior_scale * eye
    inv_factor = torch.inverse(scaled_factor)
    # The computed inverse can be slightly asymmetric; symmetrize so the
    # factorisation does not depend on which triangle is read.
    inv_factor = 0.5 * (inv_factor + inv_factor.transpose(-2, -1))
    # torch.cholesky is deprecated; torch.linalg.cholesky returns L (lower).
    chol_lower = torch.linalg.cholesky(inv_factor)
    return chol_lower.transpose(-2, -1) if upper else chol_lower

def sample_K_laplace_MN(MAP, upper_Qinv, lower_HHinv):
    """Draw one matrix-normal sample around ``MAP`` and split off the bias.

    The posterior covariance is taken as ``Qinv (kron) HHinv``. With
    ``lower_HHinv = A`` (``A A^T = HHinv``) and ``upper_Qinv = B``
    (``B^T B = Qinv``), the sample is ``MAP + A @ Z @ B``. Here ``Z`` is an
    isotropic unit-variance matrix normal with the same shape as ``MAP``.

    Returns:
        ``(weights, bias)`` = all columns but the last, and the last column.
    """
    noise = MAP.data.new(MAP.size()).normal_(mean=0, std=1)
    sample = MAP + (lower_HHinv @ noise) @ upper_Qinv
    return sample[:, :-1], sample[:, -1]

### 4 Network wrapper

In [25]:
class BaseNet(object):
    """Bookkeeping helpers shared by network wrappers.

    Subclasses are expected to set ``self.model``, ``self.optimizer``,
    ``self.lr``, ``self.epoch`` and ``self.schedule``.
    """

    def __init__(self):
        pass

    def get_nb_parameters(self):
        """Return the total number of scalar parameters in ``self.model``."""
        # Built-in sum: np.sum over a generator is deprecated by NumPy.
        return sum(p.numel() for p in self.model.parameters())

    def set_mode_train(self, train=True):
        """Put the model in training mode (``train=True``) or evaluation mode."""
        if train:
            self.model.train()
        else:
            self.model.eval()

    def update_lr(self, epoch, gamma=0.99):
        """Advance ``self.epoch`` and possibly decay the learning rate by ``gamma``.

        The decay happens when a schedule is set and is either empty (decay
        every call) or contains ``epoch``.
        """
        self.epoch += 1
        if self.schedule is None:
            return
        if len(self.schedule) == 0 or epoch in self.schedule:
            self.lr *= gamma
            # Both values must be in the tuple; `% self.lr, epoch` raised TypeError.
            print('learning rate: %f  (%d)\n' % (self.lr, epoch))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.lr

    def save(self, filename):
        """Pickle epoch, learning rate, model and optimizer objects to ``filename``."""
        torch.save({
            'epoch': self.epoch,
            'lr': self.lr,
            'model': self.model,
            'optimizer': self.optimizer}, filename)

    def load(self, filename):
        """Restore the state written by ``save`` and return the stored epoch.

        The checkpoint holds whole pickled Module/Optimizer objects, so it must
        be loaded with ``weights_only=False``. Only load trusted files.
        """
        try:
            state_dict = torch.load(filename, weights_only=False)
        except TypeError:
            # torch < 1.13 has no `weights_only` argument.
            state_dict = torch.load(filename)
        self.epoch = state_dict['epoch']
        self.lr = state_dict['lr']
        self.model = state_dict['model']
        self.optimizer = state_dict['optimizer']
        print('restoring epoch: %d, lr: %f' % (self.epoch, self.lr))
        return self.epoch
    
class KBayes_Net(BaseNet):
    """Train ``Net`` by MAP estimation and sample from a KFAC Laplace posterior.

    MAP training is SGD (momentum 0.5) on the summed cross-entropy. A zero-mean
    Gaussian weight prior with standard deviation ``prior_sig`` enters as
    ``weight_decay = 1 / prior_sig**2``, so ``prior_sig`` must be non-zero.
    """
    eps = 1e-6

    def __init__(self, lr=1e-3, channels_in=1, side_in=28, cuda=False, classes=10, batch_size=128, prior_sig=0):
        super(KBayes_Net, self).__init__()
        print('Net created.')
        self.lr = lr
        self.schedule = None  # [] #[50,200,400,600]
        self.cuda = cuda
        self.channels_in = channels_in
        self.prior_sig = prior_sig
        self.classes = classes
        self.batch_size = batch_size
        self.side_in = side_in
        self.create_net()
        self.create_opt()
        self.epoch = 0
        self.test = False

    def create_net(self):
        """Build the (seeded) ``Net`` and move it to the GPU if requested."""
        torch.manual_seed(42)
        if self.cuda:
            torch.cuda.manual_seed(42)
        self.model = Net(input_dim=self.channels_in*self.side_in*self.side_in, output_dim=self.classes)
        if self.cuda:
            self.model.cuda()
        print('Total params: %.2fK' % (self.get_nb_parameters() / 1000.0))

    def create_opt(self):
        """Create the SGD optimizer; the weight prior appears as weight decay.

        Raises:
            ValueError: if ``prior_sig`` is zero/None. The weight decay
                ``1 / prior_sig**2`` is then undefined; previously the default
                ``prior_sig=0`` failed with a bare ZeroDivisionError.
        """
        if not self.prior_sig:
            raise ValueError('prior_sig must be non-zero (weight_decay = 1 / prior_sig**2)')
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.5,
                                         weight_decay=1/self.prior_sig**2)

    def fit(self, x, y):
        """Run one SGD step on ``(x, y)``.

        Returns:
            ``(summed loss, number of misclassified examples)`` as tensors.
        """
        x, y = to_variable(var=(x, y.long()), cuda=self.cuda)
        self.optimizer.zero_grad()
        out = self.model(x)
        loss = F.cross_entropy(out, y, reduction='sum')
        loss.backward()
        self.optimizer.step()

        pred = out.data.max(dim=1, keepdim=False)[1]  # index of the max logit
        err = pred.ne(y.data).sum()
        return loss.data, err

    def eval(self, x, y, train=False):
        """Evaluate the MAP network on ``(x, y)`` without updating it.

        Returns:
            ``(summed loss, number of errors, softmax probabilities on CPU)``.
            ``train`` is accepted for interface compatibility and is unused.
        """
        x, y = to_variable(var=(x, y.long()), cuda=self.cuda)
        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            out = self.model(x)
            loss = F.cross_entropy(out, y, reduction='sum')
            probs = F.softmax(out, dim=1).data.cpu()
            pred = out.data.max(dim=1, keepdim=False)[1]  # index of the max logit
            err = pred.ne(y.data).sum()
        return loss.data, err, probs

    def get_K_laplace_params(self, trainloader):
        """Estimate the Kronecker factors of a KFAC Laplace approximation.

        For each of three layers k, averages over every example in ``trainloader``:

        * ``EQk``: ``E[a a^T]`` of the layer input with a trailing 1 (bias).
        * ``EHHk``: ``E[H]`` of the loss Hessian w.r.t. the layer's
          pre-activations, propagated back from the softmax output with
          ``layer_act_hessian_recurse``.

        It also returns ``MAPk = [W_k | b_k]``.

        NOTE(review): this expects ``self.model`` to be a three-layer MLP with
        ``n_hid``, ``fc1``, ``fc2`` and ``fc3`` attributes and 2-D cached
        activations. The convolutional ``Net`` in this notebook has no
        ``n_hid``, ``fc2`` or ``fc3``, so calling this with it raises
        AttributeError. TODO: port to the conv architecture.

        Returns:
            ``(EQ1, EHH1, MAP1, EQ2, EHH2, MAP2, EQ3, EHH3, MAP3)``.

        Raises:
            ValueError: if ``trainloader`` yields no examples.
        """
        self.model.eval()
        it_counter = 0
        weight = self.model.fc1.weight.data
        cum_HH1 = weight.new_zeros(self.model.n_hid, self.model.n_hid)
        cum_HH2 = weight.new_zeros(self.model.n_hid, self.model.n_hid)
        cum_HH3 = weight.new_zeros(self.model.output_dim, self.model.output_dim)

        cum_Q1 = weight.new_zeros(self.model.input_dim+1, self.model.input_dim+1)
        cum_Q2 = weight.new_zeros(self.model.n_hid+1, self.model.n_hid+1)
        cum_Q3 = weight.new_zeros(self.model.n_hid+1, self.model.n_hid+1)

        # Only forward passes are needed: the factors come from cached
        # activations and closed-form Hessians, so no gradients are computed.
        with torch.no_grad():
            for x, _ in trainloader:
                # Use this net's own device setting, not a notebook global.
                x = to_variable(var=(x,), cuda=self.cuda)[0]
                out = self.model(x)
                out_act = F.softmax(out, dim=1)

                HH3 = softmax_CE_preact_hessian(out_act.data)
                cum_HH3 += HH3.sum(dim=0)
                Q3 = torch.bmm(self.model.a2.data.unsqueeze(2), self.model.a2.data.unsqueeze(1))
                cum_Q3 += Q3.sum(dim=0)

                HH2 = layer_act_hessian_recurse(prev_hessian=HH3, prev_weights=self.model.fc3.weight.data,
                                                layer_pre_acts=self.model.h2.data)
                cum_HH2 += HH2.sum(dim=0)
                Q2 = torch.bmm(self.model.a1.data.unsqueeze(2), self.model.a1.data.unsqueeze(1))
                cum_Q2 += Q2.sum(dim=0)

                HH1 = layer_act_hessian_recurse(prev_hessian=HH2, prev_weights=self.model.fc2.weight.data,
                                                layer_pre_acts=self.model.h1.data)
                cum_HH1 += HH1.sum(dim=0)
                Q1 = torch.bmm(self.model.a0.data.unsqueeze(2), self.model.a0.data.unsqueeze(1))
                cum_Q1 += Q1.sum(dim=0)

                it_counter += x.shape[0]
                print(it_counter)

        if it_counter == 0:
            raise ValueError('trainloader yielded no examples')

        EHH3 = cum_HH3 / it_counter
        EHH2 = cum_HH2 / it_counter
        EHH1 = cum_HH1 / it_counter

        EQ3 = cum_Q3 / it_counter
        EQ2 = cum_Q2 / it_counter
        EQ1 = cum_Q1 / it_counter

        MAP3 = torch.cat((self.model.fc3.weight.data, self.model.fc3.bias.data.unsqueeze(1)), dim=1)
        MAP2 = torch.cat((self.model.fc2.weight.data, self.model.fc2.bias.data.unsqueeze(1)), dim=1)
        MAP1 = torch.cat((self.model.fc1.weight.data, self.model.fc1.bias.data.unsqueeze(1)), dim=1)

        return EQ1, EHH1, MAP1, EQ2, EHH2, MAP2, EQ3, EHH3, MAP3

    def sample_eval(self, x, y, Nsamples, scale_inv_EQ1, scale_inv_EHH1, MAP1, scale_inv_EQ2, scale_inv_EHH2, MAP2, scale_inv_EQ3, scale_inv_EHH3, MAP3, logits=False):
        """Evaluate the Monte Carlo predictive over ``Nsamples`` posterior samples.

        With ``logits=True`` the sampled logits are averaged before the
        softmax. Otherwise the softmax probabilities are averaged and the loss
        is the NLL of the mean probability.

        Returns:
            ``(summed loss, number of errors, mean probabilities on CPU)``.
        """
        x, y = to_variable(var=(x, y.long()), cuda=self.cuda)

        out = self.model.sample_predict(x, Nsamples, scale_inv_EQ1, scale_inv_EHH1, MAP1, scale_inv_EQ2, scale_inv_EHH2, MAP2, scale_inv_EQ3, scale_inv_EHH3, MAP3)

        if logits:
            mean_out = out.mean(dim=0, keepdim=False)
            loss = F.cross_entropy(mean_out, y, reduction='sum')
            probs = F.softmax(mean_out, dim=1).data.cpu()
        else:
            mean_out = F.softmax(out, dim=2).mean(dim=0, keepdim=False)
            probs = mean_out.data.cpu()
            log_mean_probs_out = torch.log(mean_out)
            loss = F.nll_loss(log_mean_probs_out, y, reduction='sum')

        pred = mean_out.data.max(dim=1, keepdim=False)[1]  # index of the max score
        err = pred.ne(y.data).sum()

        return loss.data, err, probs

    def all_sample_eval(self, x, y, Nsamples, scale_inv_EQ1, scale_inv_EHH1, MAP1, scale_inv_EQ2, scale_inv_EHH2, MAP2, scale_inv_EQ3, scale_inv_EHH3, MAP3):
        """Return per-sample softmax probabilities, shape ``(Nsamples, batch, classes)``."""
        x, y = to_variable(var=(x, y.long()), cuda=self.cuda)
        out = self.model.sample_predict(x, Nsamples, scale_inv_EQ1, scale_inv_EHH1, MAP1, scale_inv_EQ2, scale_inv_EHH2, MAP2, scale_inv_EQ3, scale_inv_EHH3, MAP3)
        prob_out = F.softmax(out, dim=2)
        prob_out = prob_out.data
        return prob_out

    def get_weight_samples(self, Nsamples, scale_inv_EQ1, scale_inv_EHH1, MAP1, scale_inv_EQ2, scale_inv_EHH2, MAP2, scale_inv_EQ3, scale_inv_EHH3, MAP3):
        """Return all sampled weights (biases excluded) as one flat numpy array.

        For each of ``Nsamples`` draws, the flattened weights of layers 1, 2
        and 3 are concatenated in that order.
        """
        chunks = []
        for _ in range(Nsamples):
            w1, b1 = sample_K_laplace_MN(MAP1, scale_inv_EQ1, scale_inv_EHH1)
            w2, b2 = sample_K_laplace_MN(MAP2, scale_inv_EQ2, scale_inv_EHH2)
            w3, b3 = sample_K_laplace_MN(MAP3, scale_inv_EQ3, scale_inv_EHH3)
            chunks.extend(w.cpu().numpy().ravel() for w in (w1, w2, w3))

        if not chunks:
            return np.array([])
        # One concatenate instead of appending weights one by one.
        return np.concatenate(chunks)

### 5 Load MNIST (`ToTensor` scales pixels to [0, 1])

In [26]:
use_cuda = torch.cuda.is_available()
# KBayes_Net moves tensors with `.cuda()` (the current default GPU), so use the
# default CUDA device here as well rather than a hard-coded index.
device = torch.device("cuda" if use_cuda else "cpu")

transform = torchvision.transforms.ToTensor()

# Download MNIST from the PyTorch S3 mirror.
new_mirror = 'https://ossci-datasets.s3.amazonaws.com/mnist'
if hasattr(datasets.MNIST, 'mirrors'):
    # Newer torchvision: `resources` holds bare file names that are appended
    # to each entry of `mirrors`; putting full URLs there would break the
    # download, so just try the S3 mirror first.
    if new_mirror + '/' not in datasets.MNIST.mirrors:
        datasets.MNIST.mirrors = [new_mirror + '/'] + list(datasets.MNIST.mirrors)
else:
    # Older torchvision: `resources` holds full URLs; swap the host.
    torchvision.datasets.MNIST.resources = [
        ('/'.join([new_mirror, url.split('/')[-1]]), md5)
        for url, md5 in datasets.MNIST.resources
    ]
dataset = datasets.MNIST(
    "./data", train=True, download=True, transform=transform
)

### 6 Split into training, validation and test sets

In [27]:
# Random (unseeded) split of the 60000 MNIST training images into
# 5000 train / 1000 validation / 54000 held-out test examples.
trainset, valset, testset = torch.utils.data.random_split(dataset, [5000, 1000, 54000])

# Data loaders (DataLoader is torch.utils.data.DataLoader)
trainloader = DataLoader(trainset, batch_size=8, shuffle=True, num_workers=3)
valloader = DataLoader(valset, batch_size=1)
testloader = DataLoader(testset, batch_size=1, shuffle=False, num_workers=3)

In [28]:
# MAP-training hyperparameters. prior_sig is the std of the Gaussian weight
# prior; KBayes_Net passes weight_decay = 1 / prior_sig**2 to SGD, i.e. 1e-8
# here (almost no regularisation).
lr = 1e-3
prior_sig = 10000
batch_size = 100  # NOTE(review): stored on the net only; the trainloader above uses batch_size=8
net = KBayes_Net(lr=lr, channels_in=1, side_in=28, cuda=use_cuda, classes=10, batch_size=batch_size, prior_sig=prior_sig)

Net created.
Total params: 3.00K


### 7 Train the network

In [None]:
import os

# Directory for checkpoints; reuse `models_dir` if an earlier cell defined it.
models_dir = globals().get('models_dir', './models')
os.makedirs(models_dir, exist_ok=True)

nb_epochs = 10
nb_its_dev = 1  # run validation every `nb_its_dev` epochs

# Per-epoch metrics, averaged per example.
pred_cost_train = np.zeros(nb_epochs)
err_train = np.zeros(nb_epochs)

cost_dev = np.zeros(nb_epochs)
err_dev = np.zeros(nb_epochs)
best_err = np.inf

tic0 = time.time()
epoch = 0
for i in range(epoch, nb_epochs):
    net.set_mode_train(True)
    tic = time.time()
    nb_samples = 0
    for x, y in trainloader:
        cost_pred, err = net.fit(x, y)
        # .item() turns the (possibly CUDA) 0-d tensors into Python numbers.
        err_train[i] += err.item()
        pred_cost_train[i] += cost_pred.item()
        nb_samples += len(x)

    pred_cost_train[i] /= nb_samples
    err_train[i] /= nb_samples

    toc = time.time()
    net.epoch = i
    # ---- print
    print("it %d/%d, Jtr_pred = %f, err = %f, " % (i, nb_epochs, pred_cost_train[i], err_train[i]), end="")
    print("time: %f seconds\n" % (toc - tic))
    # ---- dev
    if i % nb_its_dev == 0:
        net.set_mode_train(False)
        nb_samples = 0
        for x, y in valloader:
            cost, err, probs = net.eval(x, y)
            cost_dev[i] += cost.item()
            err_dev[i] += err.item()
            nb_samples += len(x)

        cost_dev[i] /= nb_samples
        err_dev[i] /= nb_samples
        print('Jdev = %f, err = %f\n' % (cost_dev[i], err_dev[i]))

        if err_dev[i] < best_err:
            best_err = err_dev[i]
            net.save(models_dir + '/theta_best.dat')

toc0 = time.time()
runtime_per_it = (toc0 - tic0) / float(nb_epochs)
print('average time: %f seconds\n' % runtime_per_it)
# results
print('\nRESULTS:')
nb_parameters = net.get_nb_parameters()
# Only epochs on which validation actually ran hold meaningful dev metrics.
best_cost_dev = np.min(cost_dev[::nb_its_dev])
best_cost_train = np.min(pred_cost_train)
err_dev_min = err_dev[::nb_its_dev].min()

print('  cost_dev: %f (cost_train %f)' % (best_cost_dev, best_cost_train))
print('  err_dev: %f' % (err_dev_min))
print('  nb_parameters: %d (%.2fK)' % (nb_parameters, nb_parameters / 1000.0))
print('  time_per_it: %fs\n' % (runtime_per_it))




### Compute the Kronecker-factored (KFAC) Hessian approximation

In [None]:
import os
import pickle

# Forward-only pass over the training subset to accumulate the Kronecker factors.
# NOTE(review): get_K_laplace_params reads n_hid/fc2/fc3, which the conv Net lacks.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, pin_memory=use_cuda, num_workers=3)
EQ1, EHH1, MAP1, EQ2, EHH2, MAP2, EQ3, EHH3, MAP3 = net.get_K_laplace_params(trainloader)


def save_object(obj, filename):
    """Pickle ``obj`` to ``filename``, overwriting any existing file."""
    with open(filename, 'wb') as output:
        pickle.dump(obj, output, pickle.HIGHEST_PROTOCOL)


# Save next to the checkpoint: the loading cell reads
# models_dir + '/block_hessian_params.pkl'. The former path
# '/block_hessian_params.pkl' pointed at the filesystem root.
models_dir = globals().get('models_dir', './models')
os.makedirs(models_dir, exist_ok=True)
h_params = [EQ1, EHH1, MAP1, EQ2, EHH2, MAP2, EQ3, EHH3, MAP3]
save_object(h_params, models_dir + '/block_hessian_params.pkl')


### Load the trained network and the Hessian factors

In [None]:
import pickle

# load data

# Preprocess exactly as during training and Hessian estimation (ToTensor only,
# pixels in [0, 1]). Normalising with the MNIST mean/std here would feed the
# MAP network inputs on a different scale than it was fitted on.
transform_train = transforms.Compose([
    transforms.ToTensor(),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
])

use_cuda = torch.cuda.is_available()

trainset = datasets.MNIST(root='../data', train=True, download=True, transform=transform_train)
valset = datasets.MNIST(root='../data', train=False, download=True, transform=transform_test)

# pin_memory only pays off when batches are copied to a GPU.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory=use_cuda,
                                          num_workers=3)
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, pin_memory=use_cuda,
                                        num_workers=3)

lr = 1e-3
prior_sig = 2
net = KBayes_Net(lr=lr, channels_in=1, side_in=28, cuda=use_cuda, classes=10, batch_size=batch_size, prior_sig=prior_sig)
# net.load replaces both the model and the optimizer with the checkpointed ones.
models_dir = globals().get('models_dir', './models')
net.load(models_dir+'/theta_best.dat')

# Only unpickle files this notebook produced itself: pickle can run arbitrary code.
with open(models_dir+'/block_hessian_params.pkl', 'rb') as params_file:
    [EQ1, EHH1, MAP1, EQ2, EHH2, MAP2, EQ3, EHH3, MAP3] = pickle.load(params_file)

### Scale the Kronecker factors and compute Cholesky factors of their inverses

In [None]:
# Each Kronecker factor F is regularised as data_scale * F + prior_scale * I
# before inversion. 60000 is presumably the MNIST training-set size - confirm.
data_scale = np.sqrt(60000)

prior_sig = 0.15
prior_prec = 1/prior_sig**2
prior_scale = np.sqrt(prior_prec)

# Input-side factors (EQ) get upper Cholesky factors and output-side factors
# (EHH) lower ones, matching sample_K_laplace_MN(MAP, upper_Qinv, lower_HHinv).
# Computed layer by layer, EQ before EHH, as before.
(scale_inv_EQ1, scale_inv_EHH1), (scale_inv_EQ2, scale_inv_EHH2), (scale_inv_EQ3, scale_inv_EHH3) = [
    (chol_scale_invert_kron_factor(eq, prior_scale, data_scale, upper=True),
     chol_scale_invert_kron_factor(ehh, prior_scale, data_scale, upper=False))
    for eq, ehh in ((EQ1, EHH1), (EQ2, EHH2), (EQ3, EHH3))
]


### MAP inference on test set

In [None]:
batch_size = 100

# pin_memory only pays off when batches are copied to a GPU.
valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, pin_memory=use_cuda,
                                        num_workers=4)

nb_samples = 0

# Both metrics are reported per sample (divided by nb_samples below).
test_cost = 0.0
test_err = 0
test_predictions = np.zeros((len(valset), net.classes))

net.set_mode_train(False)

with torch.no_grad():
    for x, y in valloader:
        cost, err, probs = net.eval(x, y)
        test_cost += cost.item()
        test_err += err.item()
        test_predictions[nb_samples:nb_samples+len(x), :] = probs.numpy()
        nb_samples += len(x)

test_cost /= nb_samples
test_err /= nb_samples
print('Loglike = %5.6f, err = %1.6f\n' % (-test_cost, test_err))
