# Encoding a neural network generated with Julia into a JSON file

- Step 1: generate a non-linear dataset
- Step 2: Write data to a CSV file
- Step 3: Build the neural network
- Step 4: Train the neural network
- Step 5: Encode NN to JSON file

In [1]:
# imports
using Pkg; Pkg.activate(".")

using Flux, Plots, Random, Statistics, LaplaceRedux
using Flux.Optimise: update!, Adam
theme(:lime)

using CSV
using DataFrames

using JSON
using Serialization
using Tullio

using LinearAlgebra
using Zygote

Random.seed!(42)

[32m[1m  Activating[22m[39m new project at `c:\Users\adeli\OneDrive\Desktop\facultate\2nd year\Q4 - Software Project\LaplaceRedux.jl\dev\notebooks\nn_encoding`


TaskLocalRNG()

### Step 1: generate a non-linear dataset

In [10]:
xs, ys = LaplaceRedux.Data.toy_data_multi(200)
y = ys
X = hcat(xs...) # bring into tabular format
y_train = Flux.onehotbatch(y, unique(y))
y_train = Flux.unstack(y_train',1)
data = zip(xs,y_train)

zip([[3.7950666156682473, 2.97611820240477], [4.862485498045742, 3.101107590056898], [2.5053720102137795, 2.3587018391829364], [1.315703869831591, 3.5442202681951054], [4.565778297827177, 1.7510607028519065], [1.7589719206288321, 1.3576727445892125], [3.943013774519966, 3.741286429325216], [3.1122879091765596, 2.5529812305142228], [1.253312408704216, 4.74084188997857], [4.691959234426529, 1.1545636415108023]  …  [-2.890000398522491, 4.87299122704623], [-6.146643464113058, 3.2877910187493464], [-4.8087486639006425, 2.0009040246959], [-4.475500383963048, 3.7263022893829714], [-3.3266416974577453, 2.22424443653751], [-4.065381866029225, 3.938421457822125], [-5.9580911539916395, 2.754015408541254], [-6.907438373060961, 1.6580922894001633], [-3.5873416907149442, 3.946450375745781], [-6.3978130361325825, 3.121697603169221]], Vector{Bool}[[1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0]  …  [0, 0, 0, 1],

### Step 2: Write data to a CSV file

In [11]:
# Open the CSV file for writing
file = "data_multi.csv"
csv_file = open(file, "w")

# Write the header
write(csv_file, "x1,x2,y\n")

# Write the data
for ((x1, x2), y) in zip(xs, y)
    write(csv_file, "$x1,$x2,$y\n")
end

# Close the CSV file
close(csv_file)

### Step 3: Build the neural network

In [12]:
n_hidden = 3
D = size(X,1)
nn = Chain(
    Dense(D, n_hidden, σ),
    Dense(n_hidden, length(unique(ys)))
)
loss(x, y) = Flux.Losses.logitbinarycrossentropy(nn(x), y) 

loss (generic function with 1 method)

### Step 4: Training the neural network

In [13]:
opt = Adam(1e-3)
epochs = 100
avg_loss(data) = mean(map(d -> loss(d[1],d[2]), data))
show_every = epochs/10

for epoch = 1:epochs
  for d in data
    gs = gradient(Flux.params(nn)) do
      l = loss(d...)
    end
    update!(opt, Flux.params(nn), gs)
  end
  if epoch % show_every == 0
    println("Epoch " * string(epoch))
    @show avg_loss(data)
  end
end

Epoch 10


avg_loss(data) = 0.41262257f0
Epoch 20
avg_loss(data) = 0.27379435f0


Epoch 30
avg_loss(data) = 0.19174951f0


Epoch 40
avg_loss(data) = 0.13720119f0
Epoch 50
avg_loss(data) = 0.098459475f0


Epoch 60
avg_loss(data) = 0.07051084f0
Epoch 70
avg_loss(data) = 0.050339155f0


Epoch 80
avg_loss(data) = 0.035825156f0
Epoch 90
avg_loss(data) = 0.025427673f0


Epoch 100
avg_loss(data) = 0.018010197f0


In [14]:
foreach(display, Flux.params(nn))

3×2 Matrix{Float32}:
 -4.06603    0.038519
  0.296283   1.45003
  0.36257   -2.25655

3-element Vector{Float32}:
 -2.0739872
  1.6585876
 -0.7064312

4×3 Matrix{Float32}:
 -8.1377    2.48922  -7.281
  6.03308  -8.89387  -0.258845
 -9.18194  -5.51285   4.0174
  5.56673  -0.4937   -8.3651

4-element Vector{Float32}:
  1.3787575
 -2.5149424
  0.47881675
 -2.0565474

### Step 5: Encode NN to JSON file

In [15]:
serialize_json_nn(nn::Chain)::String = JSON.json([Dict(:weight => nn.layers[i].weight, :bias => nn.layers[i].bias) for i in range(1, length(nn.layers))])
# Export as JSON
write("nn_multi.json", serialize_json_nn(nn))
serialize("nn-binary_multi.jlb", nn)

In [16]:
la = Laplace(nn; likelihood=:classification, hessian_structure=:full, subset_of_weights=:all, backend=:EmpiricalFisher)
fit!(la, data)
# Set display options
ENV["LINES"] = 10^6
ENV["COLUMNS"] = 10^6

# Print the matrix without truncation
@show la.H
#show(stdout, "text/plain", la.H)

la.H = [3.656652628079067e-9 -5.403992571221671e-9 6.516155091326693e-10 5.762908428850302e-9 -5.310900988153527e-9 6.099719431693359e-10 2.941773409872113e-9 -3.0787643194904303e-9 3.5573419374418943e-10 1.2040479717925205e-8 -5.129513615243977e-9 5.1134333261043e-9 -1.2024401883362966e-8 -2.3490973134165885e-7 4.604588714640018e-9 2.0229579186084173e-8 2.1007801023451872e-7 3.543143173292551e-11 1.677338321155334e-8 -2.2431423234845463e-8 5.622573485961441e-9 -2.344106088354596e-7 2.1308623449486085e-8 -2.1042864132946475e-9 2.1520868195547848e-7; -5.403992571221671e-9 4.5026972697556424e-5 -6.092930577291811e-6 -5.3109007448680335e-9 -1.6634741525838413e-5 -1.2336998295148686e-6 -3.078764215647539e-9 1.8458724333458552e-6 -7.460336561351442e-7 5.039699632437161e-6 9.728969358822497e-6 -1.951123231762401e-6 -1.2817436189922726e-5 1.1516181410645772e-5 5.404992512563883e-6 -2.754671424842552e-6 -1.4166439199158597e-5 3.5242740503661435e-6 2.661593208406441e-6 -3.683136685587436e-6 -2.

25×25 Matrix{Float64}:
  3.65665e-9   -5.40399e-9   6.51616e-10   5.76291e-9   -5.3109e-9    6.09972e-10   2.94177e-9   -3.07876e-9   3.55734e-10   1.20405e-8   -5.12951e-9   5.11343e-9   -1.20244e-8  -2.3491e-7     4.60459e-9   2.02296e-8    2.10078e-7    3.54314e-11   1.67734e-8  -2.24314e-8    5.62257e-9   -2.34411e-7    2.13086e-8  -2.10429e-9    2.15209e-7
 -5.40399e-9    4.5027e-5   -6.09293e-6   -5.3109e-9    -1.66347e-5  -1.2337e-6    -3.07876e-9    1.84587e-6  -7.46034e-7    5.0397e-6     9.72897e-6  -1.95112e-6   -1.28174e-5   1.15162e-5    5.40499e-6  -2.75467e-6   -1.41664e-5    3.52427e-6    2.66159e-6  -3.68314e-6   -2.50281e-6    1.39665e-5    8.30896e-6  -4.84429e-6   -1.74311e-5
  6.51616e-10  -6.09293e-6   4.18382e-5    6.09972e-10  -1.2337e-6    1.08178e-5    3.55734e-10  -7.46034e-7   9.07269e-6   -4.23106e-7   -1.70539e-6   5.79037e-7    1.54944e-6  -2.29018e-5   -4.01591e-7   7.67822e-6    1.56249e-5   -3.22113e-6   -1.30572e-6   1.81142e-6    2.71539e-6   -2.3194

In [21]:
function gen_mapping_sq(params)::Array{Tuple{Int64, Int64}}
    mapping_lin = gen_mapping(params)
    length_theta = sum(length, params)
    mapping_sq = Array{Tuple{Int64, Int64}}(undef, length_theta, length_theta)
    for (i, i_) in enumerate(mapping_lin)
        for (j, j_) in enumerate(mapping_lin)
            mapping_sq[i, j] = (i_, j_)
        end
    end
    return mapping_sq
end

import Base: getindex

function getindex(r::Matrix{Float64}, I::Matrix{Tuple{Int64, Int64}})
    l = Matrix{Float64}(undef, size(I))
    for (i, j) in Iterators.product(1:size(I, 1), 1:size(I, 2))
        # Unpack 2d index
        x, y = I[i, j]
        l[i, j] = r[x, y]
    end
    return l
end

function gen_mapping(params)
    theta_length = sum(length, params)
    offset = 0
    mapping = []
    for param in params
        indices = collect(1:length(param))
        indices_updated = vec(reshape(offset .+ indices, size(param))')
        append!(mapping, indices_updated)
        offset += length(param)
    end
    mapping
end

hessian = la.H

gen_mapping_sqt(params) = gen_mapping_sq(map(transpose, params))
#to_col_order(hessian, nn) = hessian[gen_mapping_sqt(Flux.params(nn))]
to_row_order(hessian, nn) = hessian[gen_mapping_sq(Flux.params(nn))]

to_row_order(hessian, nn)

25×25 Matrix{Float64}:
  3.65665e-9    5.76291e-9   -5.40399e-9  -5.3109e-9    6.51616e-10   6.09972e-10   2.94177e-9   -3.07876e-9   3.55734e-10   1.20405e-8   -2.3491e-7     3.54314e-11  -5.12951e-9   4.60459e-9   1.67734e-8   5.11343e-9    2.02296e-8   -2.24314e-8   -1.20244e-8   2.10078e-7    5.62257e-9   -2.34411e-7    2.13086e-8  -2.10429e-9    2.15209e-7
  5.76291e-9    3.5473e-8    -5.3109e-9   -1.94341e-8   6.09972e-10   2.33569e-10   4.45899e-9   -2.23384e-9   1.68785e-10  -1.68941e-8   -6.11917e-7   -6.9455e-9    -2.10078e-8  -6.21891e-9  -1.3514e-7    8.28043e-9    4.78423e-8    1.33744e-7    2.96229e-8   5.70302e-7    8.34397e-9   -6.18698e-7   -1.40769e-7   1.80938e-7    5.7854e-7
 -5.40399e-9   -5.3109e-9     4.5027e-5   -1.66347e-5  -6.09293e-6   -1.2337e-6    -3.07876e-9    1.84587e-6  -7.46034e-7    5.0397e-6     1.15162e-5    3.52427e-6    9.72897e-6   5.40499e-6   2.66159e-6  -1.95112e-6   -2.75467e-6   -3.68314e-6   -1.28174e-5  -1.41664e-5   -2.50281e-6    1.39665

In [None]:
la = Laplace(nn; likelihood=:classification, hessian_structure=:full, subset_of_weights=:all, backend=:EmpiricalFisher)
fit!(la, data)
# Set display options
ENV["LINES"] = 10^6
ENV["COLUMNS"] = 10^6

# Print the matrix without truncation
@show la.H
#show(stdout, "text/plain", la.H)

In [17]:
predictions_plugin = reduce(hcat, predict(la, X, link_approx=:plugin))
predictions_plugin = reshape(predictions_plugin, 4, 200)

4×200 Matrix{Float32}:
 0.998231     0.998233     0.998218     0.998216     0.998195     0.99812      0.998234     0.998225     0.998214     0.997648     0.998235     0.998231     0.998224     0.998235     0.998195     0.998235     0.998235     0.997881     0.998233     0.998227     0.998211     0.998232     0.998232     0.998231     0.998204     0.998235     0.998235     0.998204     0.998228     0.998235     0.998093     0.998234     0.998235     0.998166     0.998235     0.998105     0.998235     0.998192     0.998233     0.998235     0.998234     0.998235    0.998059     0.998194     0.998232     0.998235     0.998234     0.998167     0.99823      0.998235    3.46446f-8   3.21119f-8   3.07626f-8   3.44787f-8   3.07067f-8   3.07239f-8   4.79678f-8   3.08539f-8   3.24057f-8   3.49924f-8   3.07535f-8   3.19737f-8   3.07526f-8   3.28229f-8   4.79722f-8   6.31313f-8   3.30742f-8   3.10638f-8   3.20518f-8   3.28384f-8   3.45735f-8   3.2665f-8    3.1187f-8    3.43496f-8   3.12035f-8   3.0