```@meta
CurrentModule = LaplaceRedux
```

# Bayesian MLP


```julia
#| echo: false
using Pkg; Pkg.activate("docs")
# Import libraries
using Flux, Plots, Random, Statistics, LaplaceRedux
theme(:lime)
```

This time we use a synthetic dataset containing samples that are not linearly separable:


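```julia
# Generate a toy dataset that is not linearly separable (200 points)
xs, ys = LaplaceRedux.Data.toy_data_non_linear(200)
X = hcat(xs...) # bring into tabular format: one column per sample
data = zip(xs, ys) # (input, label) pairs used for training
```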
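Flux layers expect a matrix input with one feature per row and one observation per column. That is why the vector of samples `xs` is concatenated horizontally, and why `size(X, 1)` is used as the input dimension further down. `zip` lazily pairs each input with its label for the training loop. The sketch below uses made-up values (the `_demo` names are hypothetical) to show the resulting shapes:

```julia
# Made-up inputs and labels, for illustration only
xs_demo = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]   # three 2-dimensional inputs
ys_demo = [0.0, 1.0, 1.0]                         # three binary labels

X_demo = hcat(xs_demo...)          # 2×3 matrix: features × samples
data_demo = zip(xs_demo, ys_demo)  # lazy iterator of (x, y) pairs

println(size(X_demo))              # (2, 3)
for (x, y) in data_demo
    println(x, " => ", y)
end
```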

For the classification task we build a neural network with a single hidden layer and train it with weight decay, that is, an L2 penalty on its parameters.


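```julia
n_hidden = 32
D = size(X,1)               # input dimension (number of features)
nn = Chain(
    Dense(D, n_hidden, σ),  # single hidden layer with sigmoid activation
    Dense(n_hidden, 1)      # output layer returning one raw logit
)
λ = 0.01
sqnorm(x) = sum(abs2, x)
weight_regularization(λ=λ) = 1/2 * λ^2 * sum(sqnorm, Flux.params(nn))
# logitbinarycrossentropy applies the sigmoid internally, so nn outputs logits
loss(x, y) = Flux.Losses.logitbinarycrossentropy(nn(x), y) + weight_regularization();
```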
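The objective combines a data term and a weight penalty. The sketch below writes out both in plain Julia, using hypothetical helper names, to make two points explicit.

- Cross-entropy is computed from the raw network output (a logit), which allows a numerically stable rewrite.
- A penalty of the standard weight-decay form λ/2‖θ‖² (assumed here) equals, up to an additive constant, the negative log-density of a zero-mean Gaussian prior with precision λ.

This correspondence is why a weight-decay strength can be read as a prior precision in the Laplace approximation.

```julia
# Numerically stable log σ(z) for large |z| of either sign
log_logistic(z) = min(z, zero(z)) - log1p(exp(-abs(z)))

# Binary cross-entropy from a logit z instead of a probability:
#   -[y log σ(z) + (1 - y) log(1 - σ(z))] = (1 - y) z - log σ(z)
logit_bce(z, y) = (1 - y) * z - log_logistic(z)

# Standard weight decay (assumed form): λ/2 ‖θ‖²
l2_penalty(θ, λ) = λ / 2 * sum(abs2, θ)

# Negative log-density of an isotropic Gaussian prior N(0, λ⁻¹ I),
# summed over coordinates: each term is ½ log(2π/λ) + λ θᵢ² / 2
neg_log_gaussian_prior(θ, λ) = sum(t -> 0.5 * log(2π / λ) + λ * t^2 / 2, θ)

λ_demo = 0.01
θ1, θ2 = [0.3, -1.2, 2.0], [5.0, 0.0, -0.7]

println(logit_bce(0.0, 1.0))       # log(2) ≈ 0.693
println(logit_bce(1000.0, 0.0))    # 1000.0, no overflow
# Prior minus penalty is the same constant for any θ of the same length:
println(neg_log_gaussian_prior(θ1, λ_demo) - l2_penalty(θ1, λ_demo))
println(neg_log_gaussian_prior(θ2, λ_demo) - l2_penalty(θ2, λ_demo))
```

The constant depends only on λ and the number of parameters, so it does not affect where the minimum lies.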

We train the model for 200 epochs, by which point the training loss has stagnated.


```julia
using Flux.Optimise: update!, Adam
opt = Adam()
epochs = 200
avg_loss(data) = mean(map(d -> loss(d[1], d[2]), data))
show_every = epochs ÷ 10  # integer, so `epoch % show_every` is an exact check
ps = Flux.params(nn)      # collect the trainable parameters once

for epoch = 1:epochs
  for d in data
    # gradient of the per-sample loss w.r.t. the implicit parameters
    gs = gradient(ps) do
      loss(d...)
    end
    update!(opt, ps, gs)
  end
  if epoch % show_every == 0
    println("Epoch " * string(epoch))
    @show avg_loss(data)
  end
end
```
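Each call to `update!` applies one Adam step to every parameter array. For intuition, here is a sketch of the textbook Adam update in plain Julia. `AdamState` and `adam_step!` are hypothetical names, and Flux's own implementation may differ in small details, such as exactly where ε enters. The default step size η = 0.001 and decay rates β₁ = 0.9, β₂ = 0.999 match the defaults of Flux's `Adam()`.

```julia
# Textbook Adam, elementwise over a parameter vector (illustrative sketch)
mutable struct AdamState
    m::Vector{Float64}   # running mean of the gradient (first moment)
    v::Vector{Float64}   # running mean of the squared gradient (second moment)
    t::Int               # number of steps taken so far
end
AdamState(n::Int) = AdamState(zeros(n), zeros(n), 0)

function adam_step!(θ, g, s::AdamState; η=0.001, β1=0.9, β2=0.999, ϵ=1e-8)
    s.t += 1
    @. s.m = β1 * s.m + (1 - β1) * g
    @. s.v = β2 * s.v + (1 - β2) * g^2
    m̂ = s.m ./ (1 - β1^s.t)   # bias-corrected first moment
    v̂ = s.v ./ (1 - β2^s.t)   # bias-corrected second moment
    @. θ -= η * m̂ / (sqrt(v̂) + ϵ)
    return θ
end

# Usage: minimize f(θ) = ‖θ - c‖² / 2, whose gradient is θ - c
c = [1.0, -2.0]
θ = zeros(2)
state = AdamState(length(θ))
for _ in 1:5_000
    adam_step!(θ, θ .- c, state; η=0.01)
end
println(θ)   # close to c
```

Because the step is normalized by the running gradient magnitude, early updates move each parameter by roughly η regardless of the gradient's scale.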

## Laplace Approximation

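The Laplace approximation can then be fitted to the trained network. Here it is restricted to the weights of the last layer (`subset_of_weights=:last_layer`), and `λ` is passed as the prior precision: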


```julia
la = Laplace(nn; likelihood=:classification, λ=λ, subset_of_weights=:last_layer)
fit!(la, data)
```
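With `subset_of_weights=:last_layer`, only the weights of the final `Dense` layer are treated probabilistically. The hidden layer is kept at its trained values and acts as a fixed feature map.

The from-scratch sketch below shows the core computation under that assumption for a Bernoulli likelihood. The posterior is approximated by a Gaussian centred at the trained weights, with covariance equal to the inverse of the generalized Gauss-Newton (GGN) curvature plus the prior precision λ.

This is not LaplaceRedux's implementation, which may use a different curvature approximation such as the empirical Fisher. The helper names and the toy feature matrix are hypothetical.

```julia
using LinearAlgebra

# Logistic function (named to avoid clashing with Flux's `σ`)
logistic(z) = 1 / (1 + exp(-z))

# Posterior covariance over last-layer weights (hypothetical helper):
#   Φ     : features × N matrix of last-layer inputs (hidden activations,
#           plus a row of ones so the bias is part of the weight vector)
#   w_map : trained (MAP) last-layer weights, one per row of Φ
#   λ     : precision of the isotropic Gaussian prior N(0, λ⁻¹ I)
function last_layer_laplace(Φ::AbstractMatrix, w_map::AbstractVector, λ::Real)
    p = logistic.(Φ' * w_map)                  # predicted probabilities at the MAP
    # GGN of the Bernoulli negative log-likelihood, Σₙ pₙ(1 - pₙ) φₙ φₙᵀ,
    # plus the prior precision on the diagonal
    H = Φ * Diagonal(p .* (1 .- p)) * Φ' + λ * I
    return inv(Symmetric(H))                   # posterior covariance Σ = H⁻¹
end

# Variance of the logit f(x) = wᵀφ(x) when w ~ N(w_map, Σ)
logit_variance(Σ, φ) = dot(φ, Σ * φ)

# Toy numbers for illustration only: 2 hidden units + bias row, 3 samples
Φ = [0.1 0.5 0.9;
     0.2 0.4 0.8;
     1.0 1.0 1.0]
w_map = [1.0, -1.0, 0.5]

Σ = last_layer_laplace(Φ, w_map, 0.01)
println(isposdef(Matrix(Σ)))                   # true
println(logit_variance(Σ, [0.5, 0.4, 1.0]))
```

Restricting the approximation to the last layer keeps the matrix that has to be inverted small: its size is set by the number of hidden units, not by the total number of network parameters.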

The plots below show the predicted probabilities over the input space for the plugin estimator (left), which uses only the trained point estimate of the weights, and for the Laplace approximation (right).


```julia
#| output: true

# Contour plots of the predicted probabilities (no zoom)
zoom=0
p_plugin = plot(la, X, ys; title="Plugin", link_approx=:plugin, clim=(0,1), zoom=zoom)
p_laplace = plot(la, X, ys; title="Laplace", clim=(0,1), zoom=zoom)
plot(p_plugin, p_laplace, layout=(1,2), size=(1000,400))
```
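The two panels differ in how the predicted probability is formed.

- With `link_approx=:plugin`, the logit computed from the trained weights is passed straight through the sigmoid.
- The Laplace panel instead averages the sigmoid over the Gaussian posterior of the logit.

A standard closed-form approximation to that average, assumed in the sketch below and not necessarily identical to what LaplaceRedux computes, is the probit approximation. The logit variance typically grows away from the training data, so this approximation pulls probabilities in those regions towards 0.5. The function names are hypothetical.

```julia
# Logistic function (named to avoid clashing with Flux's `σ`)
logistic(z) = 1 / (1 + exp(-z))

# Plugin estimate: squash the point-estimate logit μ, ignoring uncertainty
plugin_prob(μ) = logistic(μ)

# Probit approximation to E[logistic(f)] for a Gaussian logit f ~ N(μ, v):
#   E[logistic(f)] ≈ logistic(μ / √(1 + π v / 8))
probit_prob(μ, v) = logistic(μ / sqrt(1 + π * v / 8))

μ = 4.0                          # a confident logit
println(plugin_prob(μ))          # ≈ 0.982
println(probit_prob(μ, 0.0))     # equals the plugin value when v = 0
println(probit_prob(μ, 100.0))   # ≈ 0.653, pulled towards 0.5 by the large variance
```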

Zooming out, we see that the plugin estimator makes high-confidence predictions in regions that contain no training samples, whereas the Laplace approximation is much more conservative there.


```julia
#| output: true

zoom=-50
p_plugin = plot(la, X, ys; title="Plugin", link_approx=:plugin, clim=(0,1), zoom=zoom)
p_laplace = plot(la, X, ys; title="Laplace", clim=(0,1), zoom=zoom)
plot(p_plugin, p_laplace, layout=(1,2), size=(1000,400))
```