# Encoding a neural network generated with Julia into a JSON file (multi-class classification)

- Step 1: generate a multi-class dataset
- Step 2: Write data to a CSV file
- Step 3: Build the neural network
- Step 4: Train the neural network
- Step 5: Encode NN to JSON file
- Step 6: Laplace Approximation and Show Hessian
- Step 7: Convert Hessian to PyTorch format

In [1]:
# imports
using Pkg; Pkg.activate(".")

using Flux, Plots, Random, Statistics, LaplaceRedux
using Flux.Optimise: update!, Adam
theme(:lime)

using CSV
using DataFrames

using JSON
using Serialization
using Tullio

using LinearAlgebra
using Zygote

Random.seed!(42)

[32m[1m  Activating[22m[39m new project at `c:\Users\adeli\OneDrive\Desktop\facultate\2nd year\Q4 - Software Project\LaplaceRedux.jl\dev\notebooks\nn_encoding`






TaskLocalRNG()

### Step 1: generate a multi-class dataset

In [2]:
xs, ys = LaplaceRedux.Data.toy_data_multi(200)
y = ys
X = hcat(xs...) # bring into tabular format
y_train = Flux.onehotbatch(y, unique(y))
y_train = Flux.unstack(y_train',1)
data = zip(xs,y_train)

zip([[1.7810855910752834, 1.453755621187046], [2.4474772862101295, 2.444795783744807], [2.1636346147891956, 5.3628177740880245], [1.7489738983854946, 1.0580423579013771], [3.371567531990057, 4.3997622868544735], [3.1735996163580333, 2.560337078130556], [2.7579852508892126, 2.618623127686306], [4.612431480759879, 2.1708445365124773], [5.413409450634119, 5.119188930199143], [1.4249415185713727, 5.083347540628576]  …  [-5.87184114898721, 5.105441383957622], [-4.82690755957905, 1.4430727186165218], [-4.968292355001326, 3.282447703989634], [-5.742254048942101, 2.5556485379415905], [-5.0160641261110595, 1.9538519714978864], [-2.8216229870174887, 4.499115295840324], [-2.676817480886136, 3.1732572835074833], [-3.064760674081976, 4.34544106009164], [-6.697040240569821, 1.5950690324757328], [-4.5212172786289955, 1.6608140597795131]], Vector{Bool}[[1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0]  …  [0, 0, 0

### Step 2: Write data to a CSV file

In [3]:
# Open the CSV file for writing
file = "data_multi.csv"
csv_file = open(file, "w")

# Write the header
write(csv_file, "x1,x2,y\n")

# Write the data
for ((x1, x2), y) in zip(xs, y)
    write(csv_file, "$x1,$x2,$y\n")
end

# Close the CSV file
close(csv_file)

### Step 3: Build the neural network

In [4]:
n_hidden = 3
D = size(X,1)
nn = Chain(
    Dense(D, n_hidden, σ),
    Dense(n_hidden, length(unique(ys)))
)
loss(x, y) = Flux.Losses.logitbinarycrossentropy(nn(x), y) 

loss (generic function with 1 method)

### Step 4: Training the neural network

In [5]:
opt = Adam(1e-3)
epochs = 100
avg_loss(data) = mean(map(d -> loss(d[1],d[2]), data))
show_every = epochs/10

for epoch = 1:epochs
  for d in data
    gs = gradient(Flux.params(nn)) do
      l = loss(d...)
    end
    update!(opt, Flux.params(nn), gs)
  end
  if epoch % show_every == 0
    println("Epoch " * string(epoch))
    @show avg_loss(data)
  end
end

│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(2 => 3, σ)
│   summary(x) = 2-element Vector{Float64}
└ @ Flux C:\Users\adeli\.julia\packages\Flux\FWgS0\src\layers\stateless.jl:50


Epoch 10


avg_loss(data) = 0.37087834f0
Epoch 20
avg_loss(data) = 0.24804346f0


Epoch 30
avg_loss(data) = 0.17368573f0


Epoch 40
avg_loss(data) = 0.123093165f0
Epoch 50
avg_loss(data) = 0.08742809f0


Epoch 60
avg_loss(data) = 0.062047515f0
Epoch 70
avg_loss(data) = 0.043952193f0


Epoch 80
avg_loss(data) = 0.031074619f0
Epoch 90
avg_loss(data) = 0.021936228f0


Epoch 100
avg_loss(data) = 0.015468704f0


In [6]:
foreach(display, Flux.params(nn))

3×2 Matrix{Float32}:
  4.59859    0.0204697
 -0.335302   1.85503
 -0.328612  -1.97061

3-element Vector{Float32}:
  2.2643619
  1.4519827
 -0.82282156

4×3 Matrix{Float32}:
  6.11972  -0.582848  -9.10659
 -8.30581  -6.45552    3.18972
  6.42595  -9.15029   -0.395065
 -8.3736    3.33583   -6.32245

4-element Vector{Float32}:
 -2.3268585
  0.75011677
 -2.5603814
  0.76309216

### Step 5: Encode NN to JSON file

In [7]:
serialize_json_nn(nn::Chain)::String = JSON.json([Dict(:weight => nn.layers[i].weight, :bias => nn.layers[i].bias) for i in range(1, length(nn.layers))])
# Export as JSON
write("nn_multi.json", serialize_json_nn(nn))
serialize("nn-binary_multi.jlb", nn)

### Step 6: Laplace Approximation and show the resulting hessian

In [8]:
la = Laplace(nn; likelihood=:classification, hessian_structure=:full, subset_of_weights=:all, backend=:EmpiricalFisher)
fit!(la, data)
# Set display options
ENV["LINES"] = 10^6
ENV["COLUMNS"] = 10^6

# Print the matrix without truncation
@show la.posterior.H
#show(stdout, "text/plain", la.posterior.H)

la.posterior.H = [1.2414799343406667e-10 -8.70478359308363e-10 -5.365167638103371e-10 9.896436775358515e-11 6.278836517773332e-10 -3.0977435932427816e-10 7.719825047406326e-11 -1.1233912098870789e-10 6.068413229272477e-11 1.3711046776144547e-8 -6.127632074022059e-9 4.340313025721895e-9 -1.192517327752897e-8 1.0226922995218704e-8 -3.197365162565264e-10 -4.59389656917394e-9 -5.313715195042519e-9 -2.7559679349697222e-9 2.9000479053192684e-9 6.458977089818387e-10 -7.908298447775048e-10 7.555581616945707e-9 2.4130789881012524e-9 -3.8089987379800135e-9 -6.160937836830458e-9; -8.70478359308363e-10 1.818238501348184e-5 -1.5747558733751787e-6 6.278836533752697e-10 6.499349427525691e-6 -2.2511386343248472e-7 -1.1233911570432882e-10 -2.4021626456213656e-7 -1.1042195640961331e-7 4.036858836780817e-6 2.7468476593134115e-7 -2.7814720151130414e-6 -1.5298253656003157e-6 3.9561189819345995e-6 6.595409221673422e-7 -2.3104807409521766e-6 -2.3049112835753117e-6 2.2426243748430427e-7 -1.576981127417165e-7 




25×25 Matrix{Float64}:
  1.24148e-10  -8.70478e-10  -5.36517e-10   9.89644e-11   6.27884e-10  -3.09774e-10   7.71983e-11  -1.12339e-10   6.06841e-11   1.3711e-8   -6.12763e-9    4.34031e-9  -1.19252e-8    1.02269e-8  -3.19737e-10  -4.5939e-9   -5.31372e-9   -2.75597e-9   2.90005e-9    6.45898e-10  -7.9083e-10    7.55558e-9   2.41308e-9  -3.809e-9    -6.16094e-9
 -8.70478e-10   1.81824e-5   -1.57476e-6    6.27884e-10   6.49935e-6   -2.25114e-7   -1.12339e-10  -2.40216e-7   -1.10422e-7    4.03686e-6   2.74685e-7   -2.78147e-6  -1.52983e-6    3.95612e-6   6.59541e-7   -2.31048e-6  -2.30491e-6    2.24262e-7  -1.57698e-7    1.16675e-6   -1.23287e-6    4.55228e-6  -3.64859e-8  -1.17728e-6  -3.33787e-6
 -5.36517e-10  -1.57476e-6    2.78834e-5   -3.09774e-10  -2.25114e-7   -5.4547e-6     6.06841e-11  -1.10422e-7   -3.86858e-6   -1.48501e-6  -2.27813e-7    9.97732e-7   7.15044e-7   -3.71618e-8  -2.57895e-6    5.67701e-7   2.04811e-6   -6.54957e-7  -1.30085e-6    7.10652e-7    1.24505e-6   -2.99

### Step 7: Convert Hessian to PyTorch format

In [9]:
function gen_mapping_sq(params)::Array{Tuple{Int64, Int64}}
    mapping_lin = gen_mapping(params)
    length_theta = sum(length, params)
    mapping_sq = Array{Tuple{Int64, Int64}}(undef, length_theta, length_theta)
    for (i, i_) in enumerate(mapping_lin)
        for (j, j_) in enumerate(mapping_lin)
            mapping_sq[i, j] = (i_, j_)
        end
    end
    return mapping_sq
end

import Base: getindex

function getindex(r::Matrix{Float64}, I::Matrix{Tuple{Int64, Int64}})
    l = Matrix{Float64}(undef, size(I))
    for (i, j) in Iterators.product(1:size(I, 1), 1:size(I, 2))
        # Unpack 2d index
        x, y = I[i, j]
        l[i, j] = r[x, y]
    end
    return l
end

function gen_mapping(params)
    theta_length = sum(length, params)
    offset = 0
    mapping = []
    for param in params
        indices = collect(1:length(param))
        indices_updated = vec(reshape(offset .+ indices, size(param))')
        append!(mapping, indices_updated)
        offset += length(param)
    end
    mapping
end

hessian = la.posterior.H

gen_mapping_sqt(params) = gen_mapping_sq(map(transpose, params))
#to_col_order(hessian, nn) = hessian[gen_mapping_sqt(Flux.params(nn))]
to_row_order(hessian, nn) = hessian[gen_mapping_sq(Flux.params(nn))]

to_row_order(hessian, nn)

25×25 Matrix{Float64}:
  1.24148e-10   9.89644e-11  -8.70478e-10   6.27884e-10  -5.36517e-10  -3.09774e-10   7.71983e-11  -1.12339e-10   6.06841e-11   1.3711e-8    1.02269e-8  -2.75597e-9  -6.12763e-9   -3.19737e-10   2.90005e-9    4.34031e-9  -4.5939e-9    6.45898e-10  -1.19252e-8   -5.31372e-9   -7.9083e-10    7.55558e-9   2.41308e-9  -3.809e-9    -6.16094e-9
  9.89644e-11   1.06517e-9    6.27884e-10  -2.98016e-9   -3.09774e-10  -3.41035e-10   7.16e-11      1.00405e-9   -1.49852e-10   5.68021e-8   5.74841e-8   7.90777e-9   1.82174e-8    9.01142e-10   3.23436e-8   -3.92745e-8  -1.35088e-8  -3.89647e-8   -3.57446e-8   -4.48785e-8   -1.28446e-9    6.53726e-8   3.3024e-8   -5.22846e-8  -4.61119e-8
 -8.70478e-10   6.27884e-10   1.81824e-5    6.49935e-6   -1.57476e-6   -2.25114e-7   -1.12339e-10  -2.40216e-7   -1.10422e-7    4.03686e-6   3.95612e-6   2.24262e-7   2.74685e-7    6.59541e-7   -1.57698e-7   -2.78147e-6  -2.31048e-6   1.16675e-6   -1.52983e-6   -2.30491e-6   -1.23287e-6    4.55

In [10]:
la = Laplace(nn; likelihood=:classification, hessian_structure=:kron, subset_of_weights=:all, backend=:GGN)
fit!(la, data)

@show predict(la, X)

predict(la, X) = [0.9904355164477457 0.9914082404850619 0.9914622037217969 0.9877270726648183 0.9914629897796067 0.9914119073108244 0.9914300049799578 0.9905677767929202 0.9914629982583352 0.9914557660503315 0.990557536973172 0.9909760070251171 0.9914616382266314 0.9914613846203649 0.9914624770537133 0.9914626615454909 0.9914330291664273 0.9914542125084412 0.9532635205457702 0.991447009040373 0.9892025452197587 0.991464196414552 0.9914631966725679 0.9914629321527905 0.9867860576215675 0.9914626983768582 0.9914629345413632 0.9914509728199524 0.991462916239084 0.9914612385184631 0.9914459500401883 0.9847235233686035 0.9914583500706569 0.9885874235909862 0.989476098722508 0.9905809686024607 0.991442022964726 0.9914097343250348 0.991462638290327 0.9912487551008279 0.9822435136053109 0.9914599880949913 0.9914628970349257 0.9914603274797066 0.9914551918179799 0.991433707303987 0.9914632630047706 0.9914334202255172 0.9913403589060291 0.988233804267558 8.688410704156406e-6 1.7006750081112994e-




4×200 Matrix{Float64}:
 0.990436    0.991408    0.991462    0.987727    0.991463    0.991412    0.99143     0.990568    0.991463    0.991456    0.990558    0.990976    0.991462    0.991461    0.991462    0.991463    0.991433    0.991454    0.953264    0.991447    0.989203    0.991464    0.991463    0.991463    0.986786    0.991463    0.991463    0.991451    0.991463    0.991461    0.991446    0.984724    0.991458    0.988587    0.989476    0.990581    0.991442    0.99141     0.991463    0.991249    0.982244    0.99146     0.991463    0.99146     0.991455    0.991434    0.991463    0.991433    0.99134     0.988234    8.68841e-6   1.70068e-5  1.04196e-5  3.52724e-5  3.2494e-5   2.12676e-5  1.01243e-5  8.63676e-6   1.26918e-5  1.2776e-5   8.64464e-6  8.6368e-6    8.6981e-6    8.6365e-6    9.68577e-6   8.63931e-6  8.70057e-6  8.86067e-6   5.67494e-5  1.42385e-5  8.67462e-6   1.76601e-5  8.6372e-6    1.56463e-5  8.9853e-6    8.66916e-6   8.85418e-6   8.86582e-5  8.63771e-6   8.83109e-6   1.