In [1]:
from laplace import Laplace

In [2]:
import pandas as pd
import torch

import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(43)

# Load data from CSV file using pandas
df = pd.read_csv('data1.csv')

# Split the dataframe into x and y tensors
x = torch.from_numpy(df[['x1', 'x2']].to_numpy()).to(torch.float32)
y = torch.from_numpy(df['y'].to_numpy(dtype=int))

In [3]:
X = torch.tensor(x.T).float().T

# Convert y to a tensor of indices and one-hot encode it
y_unique = torch.unique(torch.tensor(y))
y_indices = torch.searchsorted(y_unique, torch.tensor(y))
y_train = nn.functional.one_hot(y_indices, num_classes=len(y_unique)).float()


  X = torch.tensor(x.T).float().T
  y_unique = torch.unique(torch.tensor(y))
  y_indices = torch.searchsorted(y_unique, torch.tensor(y))


In [4]:
data = list(zip(x, y_train))
n_hidden = 3
D = X.shape[1]
out_dim = y_train.shape[1]
model = nn.Sequential(
    nn.Linear(D, n_hidden),
    nn.Sigmoid(),
    nn.Linear(n_hidden, out_dim)
)

In [5]:
loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters())
epochs = 200
avg_loss = lambda data: torch.mean(torch.stack([loss_fn(model(x), y) for (x, y) in data]))
show_every = epochs // 10

In [6]:
for epoch in range(1, epochs+1):
    for (x, y) in data:
        opt.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        opt.step()
    if epoch % show_every == 0:
        print("Epoch ", epoch)
        print("Avg Loss: ", avg_loss(data).item())

Epoch  20
Avg Loss:  0.661486029624939
Epoch  40
Avg Loss:  0.3299190402030945
Epoch  60
Avg Loss:  0.17765741050243378
Epoch  80
Avg Loss:  0.09842698276042938
Epoch  100
Avg Loss:  0.05515485629439354
Epoch  120
Avg Loss:  0.031066937372088432
Epoch  140
Avg Loss:  0.017541011795401573
Epoch  160
Avg Loss:  0.009913784451782703
Epoch  180
Avg Loss:  0.00560393463820219
Epoch  200
Avg Loss:  0.0031665244605392218


In [7]:
la = Laplace(model, 'classification',
             subset_of_weights='all',
             hessian_structure='full')

In [19]:
la.fit(DataLoader(TensorDataset(X, y_train), batch_size=len(y_train)))

In [16]:
la.optimize_prior_precision()

In [10]:
probit_predictions = la(X, link_approx='probit')

In [11]:
predictions_probit_df = pd.DataFrame(probit_predictions.numpy(), columns=['class1', 'class2', 'class3', 'class4'])

In [12]:
predictions_probit_df.to_csv('predictions1-Python.csv', index=False)

In [13]:
print(probit_predictions)

tensor([[0.7440, 0.0258, 0.1556, 0.0746],
        [0.7555, 0.0246, 0.1492, 0.0707],
        [0.7527, 0.0250, 0.1502, 0.0721],
        [0.7555, 0.0246, 0.1492, 0.0707],
        [0.7556, 0.0246, 0.1490, 0.0708],
        [0.7555, 0.0246, 0.1492, 0.0707],
        [0.7555, 0.0246, 0.1492, 0.0708],
        [0.7555, 0.0246, 0.1492, 0.0707],
        [0.7555, 0.0246, 0.1492, 0.0708],
        [0.7549, 0.0247, 0.1493, 0.0711],
        [0.7551, 0.0247, 0.1491, 0.0711],
        [0.4583, 0.0461, 0.2968, 0.1989],
        [0.7550, 0.0247, 0.1492, 0.0711],
        [0.7556, 0.0246, 0.1491, 0.0707],
        [0.7555, 0.0246, 0.1492, 0.0707],
        [0.7556, 0.0246, 0.1492, 0.0707],
        [0.7555, 0.0246, 0.1492, 0.0707],
        [0.7546, 0.0248, 0.1491, 0.0715],
        [0.7555, 0.0246, 0.1492, 0.0707],
        [0.7555, 0.0246, 0.1492, 0.0707],
        [0.7555, 0.0246, 0.1492, 0.0707],
        [0.7303, 0.0274, 0.1618, 0.0805],
        [0.6391, 0.0345, 0.2163, 0.1101],
        [0.7555, 0.0246, 0.1492, 0

In [43]:
torch.set_printoptions(sci_mode=False)
print(torch.softmax(model(X), dim=1))

tensor([[    0.9965624213,     0.0000016738,     0.0033581415,     0.0000777742],
        [    0.9965446591,     0.0000015602,     0.0033833627,     0.0000704577],
        [    0.9965562820,     0.0000016364,     0.0033667870,     0.0000753244],
        [    0.9965448976,     0.0000015629,     0.0033830507,     0.0000706247],
        [    0.9965470433,     0.0000015747,     0.0033800169,     0.0000713822],
        [    0.9965445399,     0.0000015598,     0.0033834479,     0.0000704343],
        [    0.9965406060,     0.0000015698,     0.0033868612,     0.0000709172],
        [    0.9965445399,     0.0000015599,     0.0033834300,     0.0000704395],
        [    0.9965425134,     0.0000015673,     0.0033851531,     0.0000708191],
        [    0.9965503216,     0.0000015944,     0.0033755407,     0.0000726426],
        [    0.9965498447,     0.0000015926,     0.0033760718,     0.0000725192],
        [    0.9966515899,     0.0000045085,     0.0030323223,     0.0003116591],
        [    0.9

In [44]:
f_mu, f_var = la._glm_predictive_distribution(X)

In [24]:
print(f_mu)

tensor([[  4.4230,  -8.8740,  -1.2699,  -5.0352],
        [  4.4861,  -8.8812,  -1.1993,  -5.0710],
        [  4.4436,  -8.8760,  -1.2467,  -5.0467],
        [  4.4845,  -8.8810,  -1.2010,  -5.0701],
        [  4.4778,  -8.8802,  -1.2087,  -5.0662],
        [  4.4863,  -8.8812,  -1.1991,  -5.0711],
        [  4.4810,  -8.8800,  -1.2033,  -5.0695],
        [  4.4862,  -8.8812,  -1.1992,  -5.0710],
        [  4.4823,  -8.8804,  -1.2026,  -5.0696],
        [  4.4666,  -8.8789,  -1.2212,  -5.0599],
        [  4.4677,  -8.8790,  -1.2199,  -5.0605],
        [  3.5402,  -8.7660,  -2.2549,  -4.5300],
        [  4.4663,  -8.8789,  -1.2215,  -5.0597],
        [  4.4844,  -8.8810,  -1.2012,  -5.0700],
        [  4.4861,  -8.8812,  -1.1993,  -5.0710],
        [  4.4855,  -8.8811,  -1.2000,  -5.0706],
        [  4.4857,  -8.8811,  -1.1998,  -5.0707],
        [  4.4506,  -8.8739,  -1.2341,  -5.0530],
        [  4.4862,  -8.8812,  -1.1992,  -5.0710],
        [  4.4852,  -8.8809,  -1.1999,  -5.0709],


In [25]:
print(f_var.diagonal(dim1=1, dim2=2))

tensor([[ 31.0196,  41.0548,  31.7070,  52.5846],
        [ 28.6754,  41.1042,  28.9481,  52.8043],
        [ 29.1422,  41.0515,  29.4433,  52.4307],
        [ 28.6683,  41.1017,  28.9372,  52.7885],
        [ 28.6439,  41.0930,  28.8842,  52.6709],
        [ 28.6777,  41.1045,  28.9518,  52.8084],
        [ 28.6612,  41.0956,  28.9257,  52.7675],
        [ 28.6773,  41.1045,  28.9512,  52.8077],
        [ 28.6633,  41.0978,  28.9292,  52.7743],
        [ 28.7576,  41.0801,  28.9946,  52.5394],
        [ 28.7016,  41.0795,  28.9364,  52.5712],
        [416.1187,  45.3382, 513.8153, 164.9300],
        [ 28.7258,  41.0782,  28.9615,  52.5534],
        [ 28.6601,  41.1019,  28.9235,  52.7733],
        [ 28.6761,  41.1042,  28.9493,  52.8054],
        [ 28.6707,  41.1033,  28.9409,  52.7953],
        [ 28.6728,  41.1036,  28.9442,  52.7992],
        [ 28.7053,  41.0407,  28.9585,  52.5625],
        [ 28.6768,  41.1044,  28.9504,  52.8067],
        [ 28.6732,  41.1027,  28.9450,  52.7996],


In [42]:
torch.set_printoptions(sci_mode=False)
print(la.posterior_covariance)

tensor([[    33.2744750977,     -2.7266263962,      0.0084387623,
              0.0185401570,     -0.0011283623,     -0.0002128702,
             -0.3882472217,      0.0052650245,     -0.0002229449,
              0.0123387901,      0.0913084522,     -0.1025983989,
              0.1065993235,     -0.0530657545,      0.0336275436,
             -0.3958779871,      0.3983125091,     -0.3104789257,
              0.2769282758,     -0.4365876615,      0.3794456422,
              0.0944162831,      0.0544639640,     -0.0592246763,
             -0.0896923766],
        [    -2.7266263962,     23.2198352814,      0.0193300936,
              0.1178108230,     -0.0019358058,     -0.0011457719,
             -1.8955211639,      0.0141997524,     -0.0006601219,
             -0.0162161049,      0.4701569676,     -0.2017673552,
             -0.1011129394,      0.1467449516,      0.2128175646,
             -0.4680263996,      0.9485751390,     -0.6647963524,
              0.5853176117,     -1.5656356812, 

In [41]:
torch.set_printoptions(sci_mode=False)
print(la.H)

tensor([[     0.0010755209,      0.0037801524,     -0.0000104518,
             -0.0000289333,     -0.0000008550,     -0.0000006675,
              0.0005564168,     -0.0000064578,      0.0000001300,
             -0.0000044865,      0.0001706312,      0.0000074638,
              0.0000281542,      0.0000367419,      0.0000196151,
              0.0010411981,     -0.0008379242,      0.0010915000,
             -0.0010648549,      0.0006306022,     -0.0011185764,
              0.0001781956,      0.0000637196,      0.0002894249,
             -0.0005312876],
        [     0.0037801513,      0.0148211205,     -0.0000289333,
             -0.0001513119,     -0.0000006675,     -0.0000006818,
              0.0025601147,     -0.0000208819,      0.0000000975,
             -0.0000739473,      0.0000371799,     -0.0000244968,
             -0.0000516416,     -0.0000780119,     -0.0001319733,
              0.0015557257,     -0.0023374027,      0.0020077359,
             -0.0014300867,      0.0023784733, 

In [28]:
print(torch.diag(la.prior_precision_diag))

tensor([[0.0294, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0294, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0294, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0294, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0294, 0.0000, 0.0000, 0.0000, 0.0000,
       

In [40]:
torch.set_printoptions(sci_mode=False, threshold=10_000)
Js, f_mu = la.backend.jacobians(X)
Js = Js.permute(2, 0, 1).contiguous().view(25, -1)
print(Js)

tensor([[    -0.1411598027,      0.0156491511,     -0.1581617594,
              0.0795832351,     -0.0009964401,      0.0001104666,
             -0.0011164560,      0.0005617741,     -0.0893566012,
              0.0099061858,     -0.1001191437,      0.0503775664,
             -0.0051915809,      0.0005755451,     -0.0058168797,
              0.0029269152,     -0.0247540437,      0.0027442644,
             -0.0277355369,      0.0139558604,     -0.0001150907,
              0.0000127591,     -0.0001289528,      0.0000648860,
             -0.0107975937,      0.0011970347,     -0.0120981066,
              0.0060874792,     -0.0003120851,      0.0000345981,
             -0.0003496741,      0.0001759477,     -0.0090285754,
              0.0010009191,     -0.0101160184,      0.0050901398,
             -0.0514490008,      0.0057037012,     -0.0576457642,
              0.0290059745,     -0.0458919220,      0.0050876359,
             -0.0514193587,      0.0258729979,     -0.9222305417,
          