In [1]:
# TODO make a replica of LA as in the julia example

In [1]:
from laplace import Laplace

import numpy as np
import pandas as pd
import torch

import json

import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(43)
torch.set_printoptions(sci_mode=False)

In [2]:
# Import data from csv

# Load data from CSV file using pandas
df = pd.read_csv('data1.csv')

# Split the dataframe into x and y tensors
x = torch.from_numpy(df[['x1', 'x2']].to_numpy()).to(torch.float32)
y = torch.from_numpy(df['y'].to_numpy(dtype=int))

X = x.T

y_unique = torch.unique(y)
y_indices = y - 1
y_train = nn.functional.one_hot(y_indices, num_classes=len(y_unique)).float()
y_train

tensor([[1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0.,

In [3]:
# Init model

data = list(zip(x, y_train))
n_hidden = 3
D = X.shape[0]  # == 2
out_dim = y_train.shape[1]  # == 4
model = nn.Sequential(
    nn.Linear(D, n_hidden),
    nn.Sigmoid(),
    nn.Linear(n_hidden, out_dim)
)

loss_fn = nn.CrossEntropyLoss()
# opt = torch.optim.Adam(model.parameters())
# epochs = 200
# avg_loss = lambda data: torch.mean(torch.stack([loss_fn(model(x), y) for (x, y) in data]))
# show_every = epochs // 10

In [4]:
# Import model from json

with open('nn.json') as fin:
    nn_json_str = fin.read()
    nn_json = json.loads(nn_json_str)

In [5]:
nn_json

[{'bias': [-0.14515024, 0.80743283, 1.2083956],
  'weight': [[-0.49617776, -0.29888734, 2.791775],
   [-1.1155951, 2.777497, -0.07422561]]},
 {'bias': [-0.5658274, 0.6530026, -0.5378328, 0.25649628],
  'weight': [[-4.0386014, 1.5438478, 0.19519025, -2.0000153],
   [-0.27907616, -4.3666024, -5.0731287, 1.5551064],
   [1.7367028, -3.2249343, 2.2207005, -4.285545]]}]

In [6]:
# with torch.no_grad():

assert len(model.state_dict()) == 2 * len(nn_json)
iter_states = iter(model.state_dict())

# for layer in model.state_dict():
#     print(layer)
for layer_read in nn_json:
    state_w = next(iter_states)
    state_b = next(iter_states)
    tensor_w = torch.tensor(layer_read['weight']).T
    tensor_b = torch.tensor(layer_read['bias']).T
    model.state_dict()[state_w].data.copy_(tensor_w)
    model.state_dict()[state_b].data.copy_(tensor_b)
    # model.state_dict()[layer].data.fill_(const)
    
# NOTE: DOES NOT WORK
# params = list(model.parameters())
# assert len(params) == 2 * len(nn_json)
# for idx_layer in range(len(nn_json)):
#     layer = nn_json[idx_layer]
#     idx_param_w = idx_layer * 2
#     idx_param_b = idx_param_w + 1
#     print(torch.tensor(layer['weight']).T)
#     print(torch.tensor(layer['bias']).T)
#     params[idx_param_w].data = torch.tensor(layer['weight']).T
#     params[idx_param_b].data = torch.tensor(layer['bias']).T

  tensor_b = torch.tensor(layer_read['bias']).T


In [7]:
[model.state_dict()[layer].data for layer in model.state_dict()]

[tensor([[-0.4962, -1.1156],
         [-0.2989,  2.7775],
         [ 2.7918, -0.0742]]),
 tensor([-0.1452,  0.8074,  1.2084]),
 tensor([[-4.0386, -0.2791,  1.7367],
         [ 1.5438, -4.3666, -3.2249],
         [ 0.1952, -5.0731,  2.2207],
         [-2.0000,  1.5551, -4.2855]]),
 tensor([-0.5658,  0.6530, -0.5378,  0.2565])]

In [8]:
y_hat = torch.argmax(torch.softmax(model.forward(x), dim=1), dim=1) + 1

In [9]:
y_hat == y

tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True])

## Laplace Approximation

In [10]:
la = Laplace(model, 'classification',
             subset_of_weights='all',
             hessian_structure='full')

In [11]:
la.fit(DataLoader(TensorDataset(x, y_train), batch_size=1))
# NOTE: batch size 1 since there is no batching in Julia (yet)

In [16]:
la.posterior.H

tensor([[     7.3767,     -3.7364,     -0.0546,      0.0092,     -0.0039,
              0.0019,     -0.6562,      0.0065,      0.0006,     -0.0466,
              1.0685,     -0.5575,     -0.3250,     -0.9833,      0.2835,
              0.1213,     -0.0023,      0.2356,      0.2502,     -0.0829,
              0.0384,      0.6027,     -0.7324,      0.2051,     -0.0754],
        [    -3.7364,      2.2330,      0.0092,     -0.0060,      0.0019,
             -0.0060,      0.2223,     -0.0026,     -0.0009,      0.1228,
             -0.8440,      0.2922,     -0.0784,      0.4128,     -0.2649,
             -0.0386,      0.0576,     -0.0692,     -0.0058,      0.3736,
              0.0419,     -0.4785,      0.1213,     -0.0320,      0.3892],
        [    -0.0546,      0.0092,      0.0174,      0.0058,     -0.0001,
             -0.0001,      0.0065,      0.0020,     -0.0000,      0.0012,
              0.0069,      0.0062,      0.0204,      0.0144,     -0.0003,
              0.0081,     -0.0179,  

In [13]:
pred = la(x, pred_type='glm', link_approx='probit')