# Decoding a neural network from a JSON file into PyTorch, then encoding its Hessian from the `laplace` package as JSON

In [1]:
# imports
from laplace import Laplace

import numpy as np
import pandas as pd
import torch

import json

import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

# Fix PyTorch's global RNG so random initialisation (e.g. of nn.Linear layers) is reproducible.
torch.manual_seed(seed=43)
# Show tensors in fixed-point rather than scientific notation.
torch.set_printoptions(sci_mode=False)

In [72]:
# Import data from csv

# Load data from CSV file using pandas
df = pd.read_csv('data.csv')

# Split the dataframe into x (one row of features per sample) and y (class labels, 1-based).
x = torch.from_numpy(df[['x1', 'x2']].to_numpy()).to(torch.float32)
y = torch.from_numpy(df['y'].to_numpy(dtype=int))

# Features x samples layout; the model-init cell reads the input dimension from it.
X = x.T

# Labels are 1-based, so shift them to 0-based indices for one_hot.
y_indices = (y - 1).long()
assert y_indices.min() >= 0, "labels in data.csv are expected to start at 1"

# Size the encoding by the largest label rather than by the number of distinct
# labels: counting distinct labels undercounts (and makes one_hot raise) when some
# class is absent from the file.
y_train = nn.functional.one_hot(y_indices, num_classes=int(y_indices.max()) + 1).float()
y_train

tensor([[1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.],
        [0., 0., 0., 1.],
        [0.,

In [73]:
# Init model
#
# A 2-layer MLP: Linear(D -> n_hidden) -> Sigmoid -> Linear(n_hidden -> out_dim).
# Its layer order/sizes must match the layers stored in nn.json, which are copied in later.

# (feature row, one-hot target) pairs
data = list(zip(x, y_train))

n_hidden = 15
D = x.shape[1]  # number of input features (same as X.shape[0]); 2 for x1, x2
out_dim = y_train.shape[-1]  # number of classes (width of the one-hot targets)

loss_fn = nn.CrossEntropyLoss()

model = nn.Sequential(
    nn.Linear(in_features=D, out_features=n_hidden),
    nn.Sigmoid(),
    nn.Linear(in_features=n_hidden, out_features=out_dim),
)

In [74]:
# Read the serialised network: a list with one dict per dense layer, each holding
# 'weight' and 'bias' as nested lists (see the displayed output below).
# JSON text is UTF-8 by spec, so don't depend on the platform's default encoding;
# parse after the file has been closed.
with open('nn.json', encoding='utf-8') as fin:
    nn_json_str = fin.read()
nn_json = json.loads(nn_json_str)

nn_json

[{'bias': [3.3988118,
   3.5933194,
   2.9422352,
   3.3322582,
   3.3399665,
   3.6317368,
   3.5383172,
   3.3447003,
   3.5671725,
   3.1051874,
   3.601897,
   3.2274942,
   3.4626565,
   3.2945428,
   3.2137966],
  'weight': [[-0.4472613,
    -0.36524177,
    -0.49843046,
    -0.4671523,
    -0.45563155,
    -0.4199196,
    -0.37798968,
    -0.46011707,
    -0.38260448,
    -0.5449862,
    -0.3152714,
    -0.50319517,
    -0.462592,
    -0.42082742,
    -0.47164235],
   [-0.42735013,
    -0.34834513,
    -0.47763515,
    -0.44645718,
    -0.4354668,
    -0.40082324,
    -0.36067995,
    -0.4397303,
    -0.36509258,
    -0.5209345,
    -0.29996133,
    -0.48096442,
    -0.44180417,
    -0.40227932,
    -0.45106304]]},
 {'bias': [3.843687, 3.8287387, 3.8443809, 3.8220103],
  'weight': [[3.599201, 3.2941632, 3.9565911, 3.8795183],
   [3.7613833, 4.164982, 4.4525414, 4.509027],
   [3.7834215, 3.9831934, 3.702734, 3.7160733],
   [4.5212874, 4.1562057, 3.5352004, 4.1731434],
   [4.11941

In [75]:
# Copy the parameters stored in nn.json into `model`, one nn.Linear at a time, in order.
# Per the displayed nn.json output, each 'weight' is stored as (in_features, out_features)
# -- e.g. 2 lists of 15 for the first layer -- so it is transposed to PyTorch's
# (out_features, in_features) layout; 'bias' is a flat list.
linear_layers = [module for module in model if isinstance(module, nn.Linear)]
assert len(linear_layers) == len(nn_json), "model and nn.json have different numbers of layers"

# In-place writes to parameters that require grad are only allowed under no_grad.
with torch.no_grad():
    for layer, layer_read in zip(linear_layers, nn_json):
        tensor_w = torch.tensor(layer_read['weight']).T
        # 1-D: no transpose needed (`.T` on non-2-D tensors is deprecated in PyTorch).
        tensor_b = torch.tensor(layer_read['bias'])
        # copy_ broadcasts its source, so check shapes explicitly to catch a layout mismatch.
        if tensor_w.shape != layer.weight.shape or tensor_b.shape != layer.bias.shape:
            raise ValueError(
                f"nn.json layer does not fit {layer}: "
                f"weight {tuple(tensor_w.shape)} vs {tuple(layer.weight.shape)}, "
                f"bias {tuple(tensor_b.shape)} vs {tuple(layer.bias.shape)}"
            )
        layer.weight.copy_(tensor_w)
        layer.bias.copy_(tensor_b)

In [76]:
# Predict a class for every sample and compare with the true labels.
# Softmax is monotonic, so taking argmax directly on the logits gives the same class;
# `+ 1` maps 0-based output indices back to the 1-based labels in `y`.
# model(...) (not model.forward(...)) so any registered module hooks still run.
with torch.no_grad():
    logits = model(x)
y_hat = torch.argmax(logits, dim=1) + 1

# NOTE(review): the recorded output shows most predictions disagreeing with `y`;
# confirm the architecture/activation above matches the network that produced nn.json.
y_hat == y

tensor([ True,  True, False,  True,  True, False,  True,  True,  True,  True,
         True,  True, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False])