In [1]:
from laplace import Laplace

In [2]:
import pandas as pd
import torch

import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(42)

# Load data from CSV file using pandas
df = pd.read_csv('data1.csv')

# Split the dataframe into x and y tensors
x = torch.from_numpy(df[['x1', 'x2']].to_numpy()).to(torch.float32)
y = torch.from_numpy(df['y'].to_numpy(dtype=int))

In [3]:
X = torch.tensor(x.T).float().T

# Convert y to a tensor of indices and one-hot encode it
y_unique = torch.unique(torch.tensor(y))
y_indices = torch.searchsorted(y_unique, torch.tensor(y))
y_train = nn.functional.one_hot(y_indices, num_classes=len(y_unique)).float()


  X = torch.tensor(x.T).float().T
  y_unique = torch.unique(torch.tensor(y))
  y_indices = torch.searchsorted(y_unique, torch.tensor(y))


In [4]:
data = list(zip(x, y_train))
n_hidden = 3
D = X.shape[1]
out_dim = y_train.shape[1]
model = nn.Sequential(
    nn.Linear(D, n_hidden),
    nn.Sigmoid(),
    nn.Linear(n_hidden, out_dim)
)

In [5]:
loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters())
epochs = 100
avg_loss = lambda data: torch.mean(torch.stack([loss_fn(model(x), y) for (x, y) in data]))
show_every = epochs // 10

In [6]:
for epoch in range(1, epochs+1):
    for (x, y) in data:
        opt.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        opt.step()
    if epoch % show_every == 0:
        print("Epoch ", epoch)
        print("Avg Loss: ", avg_loss(data).item())

Epoch  10
Avg Loss:  1.1412208080291748
Epoch  20
Avg Loss:  0.782179057598114
Epoch  30
Avg Loss:  0.511333703994751
Epoch  40
Avg Loss:  0.3474577069282532
Epoch  50
Avg Loss:  0.24560989439487457
Epoch  60
Avg Loss:  0.17748203873634338
Epoch  70
Avg Loss:  0.12997011840343475
Epoch  80
Avg Loss:  0.09598128497600555
Epoch  90
Avg Loss:  0.07127176225185394
Epoch  100
Avg Loss:  0.05311905965209007


In [7]:
la = Laplace(model, 'classification',
             subset_of_weights='all',
             hessian_structure='full')

In [8]:
la.fit(DataLoader(TensorDataset(X, y_train)))

In [9]:
la.optimize_prior_precision()

In [10]:
probit_predictions = la(X, link_approx='probit')

In [11]:
predictions_probit_df = pd.DataFrame(probit_predictions.numpy(), columns=['class1', 'class2', 'class3', 'class4'])

In [12]:
predictions_probit_df.to_csv('predictions1-Python.csv', index=False)

In [13]:
print(probit_predictions)

tensor([[0.6695, 0.0289, 0.1736, 0.1279],
        [0.6700, 0.0287, 0.1740, 0.1274],
        [0.6693, 0.0289, 0.1735, 0.1282],
        [0.6699, 0.0287, 0.1740, 0.1274],
        [0.6699, 0.0287, 0.1739, 0.1275],
        [0.6700, 0.0287, 0.1740, 0.1274],
        [0.6693, 0.0288, 0.1744, 0.1275],
        [0.6700, 0.0287, 0.1740, 0.1274],
        [0.6695, 0.0288, 0.1743, 0.1274],
        [0.6699, 0.0287, 0.1738, 0.1276],
        [0.6698, 0.0288, 0.1738, 0.1277],
        [0.6296, 0.0412, 0.1714, 0.1578],
        [0.6698, 0.0288, 0.1738, 0.1276],
        [0.6700, 0.0287, 0.1740, 0.1274],
        [0.6700, 0.0287, 0.1740, 0.1274],
        [0.6700, 0.0287, 0.1740, 0.1274],
        [0.6700, 0.0287, 0.1740, 0.1274],
        [0.6674, 0.0293, 0.1742, 0.1290],
        [0.6700, 0.0287, 0.1740, 0.1274],
        [0.6698, 0.0287, 0.1742, 0.1273],
        [0.6699, 0.0287, 0.1741, 0.1273],
        [0.6608, 0.0313, 0.1718, 0.1361],
        [0.6510, 0.0342, 0.1717, 0.1432],
        [0.6700, 0.0287, 0.1740, 0

In [14]:
torch.set_printoptions(sci_mode=False)
print(torch.softmax(model(X), dim=1))

tensor([[    0.9486,     0.0002,     0.0351,     0.0161],
        [    0.9488,     0.0002,     0.0352,     0.0158],
        [    0.9485,     0.0002,     0.0351,     0.0162],
        [    0.9487,     0.0002,     0.0353,     0.0158],
        [    0.9488,     0.0002,     0.0352,     0.0159],
        [    0.9488,     0.0002,     0.0352,     0.0158],
        [    0.9483,     0.0002,     0.0356,     0.0159],
        [    0.9488,     0.0002,     0.0352,     0.0158],
        [    0.9485,     0.0002,     0.0355,     0.0159],
        [    0.9487,     0.0002,     0.0352,     0.0159],
        [    0.9487,     0.0002,     0.0352,     0.0160],
        [    0.9451,     0.0002,     0.0341,     0.0205],
        [    0.9487,     0.0002,     0.0352,     0.0160],
        [    0.9488,     0.0002,     0.0352,     0.0158],
        [    0.9488,     0.0002,     0.0352,     0.0158],
        [    0.9488,     0.0002,     0.0352,     0.0158],
        [    0.9488,     0.0002,     0.0352,     0.0158],
        [    0