In [15]:
import torch
from torch.autograd import Variable, grad
from torch.utils.data import DataLoader, TensorDataset

from collections import OrderedDict

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

import os

# Project root. Override it with the PROJECT_DIR environment variable instead of
# editing this machine-specific absolute path.
PROJECT_DIR = os.environ.get("PROJECT_DIR", "/home/s2113174/Projects-1")
os.chdir(PROJECT_DIR)

# Set to an int (e.g. 1234) for reproducible data noise and weight initialisation;
# None keeps the unseeded behaviour.
SEED = None
if SEED is not None:
    np.random.seed(SEED)
    torch.manual_seed(SEED)

# CUDA for PyTorch
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

In [16]:
# Deep neural network
class DNN(torch.nn.Module):
    """Fully connected feed-forward network.

    Parameters
    ----------
    layers : sequence of int
        Layer widths ``[in, hidden_1, ..., hidden_k, out]``; needs at least an
        input and an output width.
    activation : type, optional
        ``torch.nn.Module`` subclass instantiated after every hidden ``Linear``
        layer (default ``torch.nn.Tanh``). The output layer has no activation.

    Sub-modules live in ``self.layers`` (a ``torch.nn.Sequential``) under the names
    ``layer_<i>`` / ``activation_<i>``, so parameters are exposed as
    ``layers.layer_<i>.weight`` and ``layers.layer_<i>.bias``.

    Raises
    ------
    ValueError
        If ``layers`` has fewer than two entries.
    """

    def __init__(self, layers, activation=torch.nn.Tanh):
        super(DNN, self).__init__()
        if len(layers) < 2:
            raise ValueError(
                'layers must list at least an input and an output width, got %r' % (layers,))

        # Number of Linear layers
        self.depth = len(layers) - 1

        # Activation class; a fresh instance is created for every hidden layer
        self.activation = activation

        # Hidden blocks: Linear followed by the activation
        layer_list = []
        for i in range(self.depth - 1):
            layer_list.append(('layer_%d' % i, torch.nn.Linear(layers[i], layers[i + 1])))
            layer_list.append(('activation_%d' % i, self.activation()))
        # Output layer: purely linear
        layer_list.append(
            ('layer_%d' % (self.depth - 1), torch.nn.Linear(layers[-2], layers[-1]))
        )

        # Deploy layers
        self.layers = torch.nn.Sequential(OrderedDict(layer_list))

    def forward(self, x):
        """Apply the network to ``x``; its last dimension must equal ``layers[0]``."""
        return self.layers(x)

In [17]:
def test_set(max_space=2, obs=1, param=1, mean=0, std=0):
    """Sample ``param / (2*pi) * sin(2*pi*t)`` on ``obs`` evenly spaced points of
    ``[0, max_space]`` and add Gaussian noise drawn with ``np.random.normal(mean, std)``.

    Returns
    -------
    t : np.ndarray of shape (obs,)
        The sample locations.
    noisy_solution : np.ndarray of shape (obs,)
        Curve values plus noise (exactly the curve when ``std == 0``).
    """
    grid = np.linspace(0, max_space, obs)
    omega = 2 * np.pi
    clean = (param / omega) * np.sin(omega * grid)
    # Noise is always drawn (even for std == 0) so the global RNG advances the same way.
    perturbation = np.random.normal(mean, std, grid.size)
    return grid, clean + perturbation


def data(max_space=2, obs=1, param=1, mean=0, std=0.1, batch_size=None):
    """Build a DataLoader of noisy samples of ``param / (2*pi) * sin(2*pi*t)``.

    ``t`` holds ``obs`` evenly spaced points of ``[0, max_space]``. Noise drawn
    with ``np.random.normal(mean, std)`` is added to the curve. Inputs and targets
    are float32 column tensors of shape ``(obs, 1)``.

    Parameters
    ----------
    batch_size : int, optional
        Mini-batch size. Defaults to ``obs``, i.e. the whole data set in one batch.

    Returns
    -------
    torch.utils.data.DataLoader
        Unshuffled loader yielding ``(t, noisy_solution)`` batches.
    """
    t = np.linspace(0, max_space, obs)
    sol = (param / (2 * np.pi)) * np.sin(2 * np.pi * t)
    noisy_sol = sol + np.random.normal(mean, std, len(t))

    x = torch.tensor(t, dtype=torch.float32).reshape(-1, 1)
    y = torch.tensor(noisy_sol, dtype=torch.float32).reshape(-1, 1)

    if batch_size is None:
        # A single batch holding every sample; max() keeps batch_size valid for obs == 0.
        batch_size = max(obs, 1)
    return DataLoader(TensorDataset(x, y), batch_size=batch_size)

In [18]:
nobs = 100  # number of evaluation points on [0, max_space] (default 2)
t, y = test_set(obs=nobs)  # test_set defaults to std=0, i.e. noiseless targets

# Widths: 1 input, one hidden layer of 10 units, 1 output.
layers = [1, *([10] * 1), 1]
model = DNN(layers)
loss = torch.nn.MSELoss(reduction='mean')

In [19]:
from backpack import backpack, extend
from backpack.extensions import DiagHessian, DiagGGNExact

# BackPACK-extended model and loss so that backward() also computes the
# requested curvature quantities for every parameter.
model_ = extend(model, use_converter=True)
lossfunc_ = extend(loss)

# Inputs/targets as (N, 1) float32 columns. as_tensor accepts both the numpy arrays
# from test_set and tensors left behind by later cells, without copy-constructing.
loss_ = lossfunc_(
    model_(torch.as_tensor(t, dtype=torch.float32).reshape(-1, 1).detach().requires_grad_(True)),
    torch.as_tensor(y, dtype=torch.float32).reshape(-1, 1),
)

with backpack(DiagHessian(), DiagGGNExact()):
    loss_.backward()

for name, param in model_.named_parameters():
    print(name)
    # Print the shape under the ".shape" label and the values under their own label.
    print(".diag_ggn_exact.shape:   ", tuple(param.diag_ggn_exact.shape))
    print(".diag_ggn_exact:         ", param.diag_ggn_exact)


layers.layer_0.weight
.diag_ggn_exact.shape:    tensor([[7.4235e-05],
        [2.6942e-02],
        [1.2357e-03],
        [6.4169e-02],
        [2.6704e-03],
        [7.5335e-05],
        [1.0758e-02],
        [1.4902e-02],
        [7.6811e-02],
        [1.4195e-02]])
layers.layer_0.bias
.diag_ggn_exact.shape:    tensor([0.0002, 0.0211, 0.0013, 0.0479, 0.0090, 0.0002, 0.0116, 0.0116, 0.0718,
        0.0253])
layers.layer_1.weight
.diag_ggn_exact.shape:    tensor([[0.9941, 0.7067, 1.2323, 0.0028, 1.5282, 1.1405, 0.9663, 0.4120, 0.2626,
         0.6643]])
layers.layer_1.bias
.diag_ggn_exact.shape:    tensor([2.])


In [20]:
from laplace import Laplace

# Last-layer Laplace approximation with a diagonal Hessian via the laplace library.
la = Laplace(model, 'regression', subset_of_weights='last_layer', hessian_structure='diag')

# Noisy training data (data() defaults to std=0.1), served as a single batch.
dta = data(obs = nobs)

la.fit(dta)

print(la.H)

# `x` is read again further down; the targets get their own name so the
# notebook-level `y` from test_set is not silently overwritten.
x, y_batch = next(iter(dta))

# Functional mean and variance of the Laplace posterior at the training inputs.
fm, varl = la(x)

print(varl)

tensor([ 49.7045,  35.3336,  61.6160,   0.1379,  76.4101,  57.0230,  48.3130,
         20.6006,  13.1279,  33.2144, 100.0000])
tensor([[[0.0842]],

        [[0.0830]],

        [[0.0819]],

        [[0.0808]],

        [[0.0797]],

        [[0.0787]],

        [[0.0777]],

        [[0.0767]],

        [[0.0758]],

        [[0.0750]],

        [[0.0742]],

        [[0.0734]],

        [[0.0727]],

        [[0.0720]],

        [[0.0714]],

        [[0.0708]],

        [[0.0702]],

        [[0.0697]],

        [[0.0693]],

        [[0.0689]],

        [[0.0686]],

        [[0.0683]],

        [[0.0680]],

        [[0.0678]],

        [[0.0677]],

        [[0.0676]],

        [[0.0676]],

        [[0.0676]],

        [[0.0677]],

        [[0.0678]],

        [[0.0680]],

        [[0.0682]],

        [[0.0685]],

        [[0.0688]],

        [[0.0692]],

        [[0.0696]],

        [[0.0701]],

        [[0.0706]],

        [[0.0712]],

        [[0.0719]],

        [[0.0726]],

        [[0.

In [21]:
# Forward activations captured by hooks, keyed by the name given to the hook factory.
forw_activation = {}


def forw_getActivation(name):
    """Return a forward hook that stores a detached copy of the module's output
    in ``forw_activation[name]`` (overwriting any earlier value)."""
    def hook(module, inputs, output):
        forw_activation[name] = output.detach()
    return hook


# model.layers[1] is the activation after the first Linear layer.
h1 = model.layers[1].register_forward_hook(forw_getActivation('layers.activation_0'))

In [22]:
# Inputs as an (N, 1) float32 leaf tensor that requires grad. as_tensor(...).detach()
# accepts the numpy array from test_set as well as the tensor produced by a previous
# run of this cell, so the cell can be re-executed.
t = torch.as_tensor(t, dtype=torch.float32).reshape(-1, 1).detach().requires_grad_(True)

# Forward pass: the hook registered on model.layers[1] records the hidden
# activations; it is then removed so later forward passes do not overwrite them.
y_ = model(t)

h1.remove()

Loss = loss(y_, torch.as_tensor(y, dtype=torch.float32).reshape(-1, 1))

# First and second derivatives of the loss w.r.t. the network outputs. The MSE
# output Hessian is diagonal, so a vector-Jacobian product with ones recovers it.
df_f = grad(Loss, y_, create_graph=True)[0]

ddf_ff = grad(df_f, y_, torch.ones_like(df_f))[0]


In [23]:
# Last-layer parameters (weights followed by the bias), used as the linearisation point.
wt, bias = model.layers[-1].weight, model.layers[-1].bias

param_MAP = torch.cat((wt, bias.reshape(1, 1)), 1)

# Number of last-layer parameters as a plain int. A torch.Size never compares
# equal to an int, which would break size checks against len(...).
nparam = param_MAP.numel()

In [24]:
# Per-sample derivative of the output w.r.t. the last-layer (weights, bias): [phi, 1].
df_theta = torch.cat(
    (forw_activation['layers.activation_0'], torch.ones_like(ddf_ff)),
    dim=1,
)

# Diagonal GGN of the last layer. For MSELoss(reduction='mean') the output curvature
# ddf_ff is 2/nobs per sample, so the nobs/2 factor rescales it to 1.
H = (nobs / 2) * (df_theta * ddf_ff * df_theta).sum(dim=0)

print(H)

tensor([ 49.7045,  35.3336,  61.6160,   0.1379,  76.4101,  57.0230,  48.3130,
         20.6006,  13.1279,  33.2144, 100.0000])


In [25]:
# Network output and hidden-layer features captured by the forward hook.
f, phi = y_, forw_activation['layers.activation_0']

bsize = phi.shape[0]
output_size = f.shape[-1]

# Jacobians from the feature vector 'phi'. The identity is allocated on phi's device
# rather than on the device of an unrelated tensor from another cell.
identity = torch.eye(output_size, device=phi.device).unsqueeze(0).tile(bsize, 1, 1)
# Weight block, batch x output x params; params are ordered like weight.reshape(-1)
Js = torch.einsum('kp,kij->kijp', phi, identity).reshape(bsize, output_size, -1)
# Bias block: d f_i / d b_j is the identity
Js = torch.cat([Js, identity], dim=2)

In [26]:
def sigma_noise():
    """Return the observation-noise standard deviation held in the global ``_sigma_noise``."""
    noise_std = _sigma_noise
    return noise_std

def _H_factor():
    sigma2 = sigma_noise().square()
    return 1 / sigma2 / temperature

def prior_precision_diag(prior_precision, n_params, device=None):
    """Obtain the diagonal prior precision \\(p_0\\) constructed from either
    a scalar or a diagonal prior precision.

    Parameters
    ----------
    prior_precision : torch.Tensor or float
        A single value (0-d or 1-element tensor, or Python number) shared by all
        parameters, or one value per parameter.
    n_params : int or torch.Size
        Number of parameters; a shape / ``torch.Size`` is reduced to its element count.
    device : torch.device, optional
        Device of the result. Defaults to the device of ``prior_precision`` so that
        the two never mix CPU and CUDA tensors.

    Returns
    -------
    prior_precision_diag : torch.Tensor
        1-D tensor with ``n_params`` entries.

    Raises
    ------
    ValueError
        If ``prior_precision`` is neither a scalar nor has ``n_params`` entries.
    """
    prior_precision = torch.as_tensor(prior_precision)
    # np.prod accepts a plain int as well as a shape / torch.Size.
    n_params = int(np.prod(n_params))
    if device is None:
        device = prior_precision.device

    if prior_precision.numel() == 1:  # scalar
        return prior_precision.reshape(()).to(device) * torch.ones(n_params, device=device)

    if prior_precision.numel() == n_params:  # diagonal
        return prior_precision.reshape(-1).to(device)

    # TODO: per-layer prior precisions are not supported here.
    raise ValueError('Mismatch of prior and model. Diagonal or scalar prior expected.')

# Hyper-parameters of the manual computation, as 1-element integer tensors.
_sigma_noise = torch.tensor([1])     # noise std, returned by sigma_noise()
temperature = torch.tensor([1])      # divides the GGN scale in _H_factor()
prior_precision = torch.tensor([1])  # scalar prior precision

# NOTE: this rebinds `prior_precision_diag` from the function defined above to its
# result tensor (read by the next cell); calling it again requires re-running the def.
prior_precision_diag = prior_precision_diag(prior_precision=prior_precision, n_params=nparam)

In [27]:
# Diagonal posterior precision: scaled GGN diagonal plus the prior precision.
post_presicion = _H_factor() * H + prior_precision_diag

# With a diagonal precision, the posterior variances are the elementwise reciprocals.
post_variance = post_presicion.reciprocal()

# Functional variance J diag(post_variance) J^T for every input.
functional_var = torch.einsum('ncp,p,nkp->nck', Js, post_variance, Js)

print(functional_var)

tensor([[[0.0842]],

        [[0.0830]],

        [[0.0819]],

        [[0.0808]],

        [[0.0797]],

        [[0.0787]],

        [[0.0777]],

        [[0.0767]],

        [[0.0758]],

        [[0.0750]],

        [[0.0742]],

        [[0.0734]],

        [[0.0727]],

        [[0.0720]],

        [[0.0714]],

        [[0.0708]],

        [[0.0702]],

        [[0.0697]],

        [[0.0693]],

        [[0.0689]],

        [[0.0686]],

        [[0.0683]],

        [[0.0680]],

        [[0.0678]],

        [[0.0677]],

        [[0.0676]],

        [[0.0676]],

        [[0.0676]],

        [[0.0677]],

        [[0.0678]],

        [[0.0680]],

        [[0.0682]],

        [[0.0685]],

        [[0.0688]],

        [[0.0692]],

        [[0.0696]],

        [[0.0701]],

        [[0.0706]],

        [[0.0712]],

        [[0.0719]],

        [[0.0726]],

        [[0.0733]],

        [[0.0741]],

        [[0.0750]],

        [[0.0759]],

        [[0.0768]],

        [[0.0778]],

        [[0.0