In [1]:
from laplace import Laplace

In [2]:
import pandas as pd
import torch

import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(43)

# Load data from CSV file using pandas
df = pd.read_csv('data1.csv')

# Split the dataframe into x and y tensors
x = torch.from_numpy(df[['x1', 'x2']].to_numpy()).to(torch.float32)
y = torch.from_numpy(df['y'].to_numpy(dtype=int))

In [3]:
X = torch.tensor(x.T).float().T

# Convert y to a tensor of indices and one-hot encode it
y_unique = torch.unique(torch.tensor(y))
y_indices = torch.searchsorted(y_unique, torch.tensor(y))
y_train = nn.functional.one_hot(y_indices, num_classes=len(y_unique)).float()


  X = torch.tensor(x.T).float().T
  y_unique = torch.unique(torch.tensor(y))
  y_indices = torch.searchsorted(y_unique, torch.tensor(y))


In [4]:
data = list(zip(x, y_train))
n_hidden = 3
D = X.shape[1]
out_dim = y_train.shape[1]
model = nn.Sequential(
    nn.Linear(D, n_hidden),
    nn.Sigmoid(),
    nn.Linear(n_hidden, out_dim)
)

In [None]:

loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters())
epochs = 200
avg_loss = lambda data: torch.mean(torch.stack([loss_fn(model(x), y) for (x, y) in data]))
show_every = epochs // 10


In [None]:

for epoch in range(1, epochs+1):
    for (x, y) in data:
        opt.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        opt.step()
    if epoch % show_every == 0:
        print("Epoch ", epoch)
        print("Avg Loss: ", avg_loss(data).item())


In [5]:
py_layer1 = torch.tensor([[-2.7541, 0.5183],
                          [3.5415, 0.1424],
                          [0.1192, -4.3529]])

py_layer2 = torch.tensor([-0.9672, 1.5726, -2.0302])

py_layer3 = torch.tensor([[-5.3452, 2.9503, -8.0299],
                          [0.5926, -7.3648, 4.3767],
                          [-5.9890, 0.3879, 3.8754],
                          [3.0135, -6.8771, -7.4405]])

py_layer4 = torch.tensor([1.5361, -1.5164, -1.5870, 1.8060])

model[0].weight.data = py_layer1
model[0].bias.data = py_layer2
model[2].weight.data = py_layer3
model[2].bias.data = py_layer4

In [6]:
from laplace.curvature import AsdlGGN
la = Laplace(model, 'classification',
             subset_of_weights='all',
             hessian_structure='full', backend=AsdlGGN)

In [7]:
la.fit(DataLoader(TensorDataset(X, y_train), batch_size=len(y_train)))

In [8]:
la.optimize_prior_precision()

In [9]:
print(la.backend)

<laplace.curvature.asdl.AsdlGGN object at 0x000001DE4C808AF0>


In [10]:
probit_predictions = la(X, link_approx='probit')

In [11]:
print(probit_predictions)

tensor([[0.7440, 0.0258, 0.1556, 0.0746],
        [0.7555, 0.0246, 0.1492, 0.0707],
        [0.7527, 0.0250, 0.1502, 0.0721],
        [0.7555, 0.0246, 0.1492, 0.0707],
        [0.7556, 0.0246, 0.1490, 0.0708],
        [0.7555, 0.0246, 0.1492, 0.0707],
        [0.7555, 0.0246, 0.1492, 0.0708],
        [0.7555, 0.0246, 0.1492, 0.0707],
        [0.7555, 0.0246, 0.1492, 0.0708],
        [0.7549, 0.0247, 0.1493, 0.0711],
        [0.7551, 0.0247, 0.1491, 0.0711],
        [0.4583, 0.0461, 0.2968, 0.1989],
        [0.7550, 0.0247, 0.1492, 0.0711],
        [0.7556, 0.0246, 0.1491, 0.0707],
        [0.7555, 0.0246, 0.1492, 0.0707],
        [0.7556, 0.0246, 0.1492, 0.0707],
        [0.7555, 0.0246, 0.1492, 0.0707],
        [0.7546, 0.0248, 0.1491, 0.0715],
        [0.7555, 0.0246, 0.1492, 0.0707],
        [0.7555, 0.0246, 0.1492, 0.0707],
        [0.7555, 0.0246, 0.1492, 0.0707],
        [0.7303, 0.0274, 0.1618, 0.0805],
        [0.6391, 0.0345, 0.2163, 0.1101],
        [0.7555, 0.0246, 0.1492, 0

In [14]:
print(la.posterior.H)

tensor([[ 1.0761e-03,  3.7821e-03, -1.0456e-05, -2.8949e-05, -8.5513e-07,
         -6.6766e-07,  5.5666e-04, -6.4605e-06,  1.3006e-07, -4.5244e-06,
          1.7044e-04,  7.4636e-06,  2.8183e-05,  3.6739e-05,  1.9637e-05,
          1.0415e-03, -8.3795e-04,  1.0918e-03, -1.0652e-03,  6.3075e-04,
         -1.1189e-03,  1.7800e-04,  6.3745e-05,  2.8971e-04, -5.3148e-04],
        [ 3.7821e-03,  1.4829e-02, -2.8949e-05, -1.5137e-04, -6.6766e-07,
         -6.8189e-07,  2.5613e-03, -2.0894e-05,  9.7518e-08, -7.4117e-05,
          3.6303e-05, -2.4495e-05, -5.1605e-05, -7.7990e-05, -1.3192e-04,
          1.5562e-03, -2.3375e-03,  2.0083e-03, -1.4305e-03,  2.3791e-03,
         -1.8518e-03,  3.0909e-05, -1.3455e-04, -4.6160e-04,  5.6515e-04],
        [-1.0456e-05, -2.8949e-05,  1.0246e-05, -3.9215e-05, -5.2818e-08,
         -3.0151e-08, -6.4605e-06,  8.3477e-06, -1.6097e-08, -2.3514e-06,
          2.1715e-04,  1.5073e-04,  2.1744e-05, -1.8274e-04, -1.6839e-04,
         -1.7099e-05, -2.3722e-05,  