```@meta
CurrentModule = LaplaceRedux
```
## Batched Jacobians
As applied to a multi-class classification task.

### Multi-class problem

In [1]:
using Pkg
Pkg.activate(".")
using LaplaceRedux
using LaplaceRedux.Curvature
using LaplaceRedux.Data
using Flux
using Flux.Optimise: update!, Adam
using Plots
using Statistics
using MLUtils
using Zygote
using Printf
using NNlib
using BenchmarkTools
using Tullio

[32m[1m  Activating[22m[39m project at `~/Builds/navimakarov/LaplaceRedux.jl/dev/notebooks/batching`


In [2]:
x, y = Data.toy_data_multi()
X = hcat(x...)
y_train = Flux.onehotbatch(y, unique(y))
y_train = Flux.unstack(y_train',1)

100-element Vector{Vector{Bool}}:
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 ⋮
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]

In [3]:
data = zip(x,y_train)
n_hidden = 3
D = size(X,1)
out_dim = length(unique(y))
nn = Chain(
    Dense(D, n_hidden, σ),
    Dense(n_hidden, out_dim)
)  
loss(x, y) = Flux.Losses.logitcrossentropy(nn(x), y)

loss (generic function with 1 method)

In [4]:
using Flux.Optimise: update!, Adam
opt = Adam()
epochs = 100
avg_loss(data) = mean(map(d -> loss(d[1],d[2]), data))
show_every = epochs/10

for epoch = 1:epochs
    for d in data
        gs = gradient(Flux.params(nn)) do
            l = loss(d...)
        end
        update!(opt, Flux.params(nn), gs)
    end
    if epoch % show_every == 0
        println("Epoch " * string(epoch))
        @show avg_loss(data)
    end
end

[33m[1m│ [22m[39m  The input will be converted, but any earlier layers may be very slow.
[33m[1m│ [22m[39m  layer = Dense(2 => 3, σ)    [90m# 9 parameters[39m
[33m[1m│ [22m[39m  summary(x) = "2-element Vector{Float64}"
[33m[1m└ [22m[39m[90m@ Flux ~/.julia/packages/Flux/EHgZm/src/layers/stateless.jl:60[39m


Epoch 10
avg_loss(data) = 0.85723996f0
Epoch 20
avg_loss(data) = 0.5992859f0
Epoch 30
avg_loss(data) = 0.43424362f0
Epoch 40
avg_loss(data) = 0.3196775f0
Epoch 50
avg_loss(data) = 0.23693909f0
Epoch 60
avg_loss(data) = 0.17637047f0
Epoch 70
avg_loss(data) = 0.13184537f0
Epoch 80
avg_loss(data) = 0.09902897f0
Epoch 90
avg_loss(data) = 0.07476296f0
Epoch 100
avg_loss(data) = 0.056746498f0


In [14]:
Y = reduce(hcat, y_train)

4×100 Matrix{Bool}:
 1  1  1  1  1  1  1  1  1  1  1  1  1  …  0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     1  1  1  1  1  1  1  1  1  1  1  1

In [15]:
X

2×100 Matrix{Float64}:
 3.45212  4.03781  3.85089  3.41911  5.31203  …  -2.64932  -6.52157  -6.3513
 4.11912  1.88961  2.18114  1.31315  2.84395      2.88045   1.02894   4.07211

In [18]:
Y[:, 1], X[:, 1]

(Bool[1, 0, 0, 0], [3.452124036821634, 4.119117743595139])

### Jacobians, non-batched

In [19]:
x_1 = X[:, 1]
x_2 = X[:, 2]

2-element Vector{Float64}:
 4.0378059584920205
 1.8896085919713759

In [20]:
jgrads_1 = jacobian(() -> nn(x_1), Flux.params(nn))

Grads(...)

In [21]:
jgrads_1.grads

IdDict{Any, Any} with 4 entries:
  Float32[2.53876 -0.09505… => Float32[0.000218209 0.109107 … 0.130188 -8.41697…
  Float32[-2.56515, 2.3463… => Float32[1.0 0.0 0.0 0.0; 0.0 1.0 0.0 0.0; 0.0 0.…
  Float32[0.507228, -2.065… => Float32[6.32101f-5 0.0316058 -2.04339f-7; -0.000…
  Float32[0.454429 2.25833… => Float32[0.999861 0.0 … 0.0 0.0; 0.0 0.999861 … 0…

In [25]:
sum(length, jgrads_1.params)

25

In [26]:
nn

Chain(
  Dense(2 => 3, σ),                     [90m# 9 parameters[39m
  Dense(3 => 4),                        [90m# 16 parameters[39m
) [90m                  # Total: 4 arrays, [39m25 parameters, 356 bytes.

In [27]:
sum(length, jgrads_1)

100

In [29]:
foreach(jac -> @show(size(jac)), jgrads_1)

size(jac) = (4, 6)
size(jac) = (4, 3)
size(jac) = (4, 12)
size(jac) = (4, 4)


In [24]:
fieldnames(Chain)

(:layers,)

In [30]:
jgrads_2 = jacobian(() -> nn(x_2), Flux.params(nn))

Grads(...)

In [37]:
J_2 = reduce(hcat, jgrads_2)
J_1 = reduce(hcat, jgrads_1)

4×25 Matrix{Float32}:
  0.000218209   0.109107   -7.05404f-7  …  0.0       1.0  0.0  0.0  0.0
 -0.00193867   -0.182956   -8.0478f-5      0.0       0.0  1.0  0.0  0.0
  0.00134268   -0.0728651  -8.56164f-5     0.0       0.0  0.0  1.0  0.0
 -0.00233176   -0.0644499   3.5387f-5      0.999995  0.0  0.0  0.0  1.0

In [32]:
th = jgrads_1.params

Params([Float32[2.5387568 -0.09505186; 0.82812595 0.8368814; -0.120450914 2.838133], Float32[0.5072278, -2.0655606, 0.940312], Float32[0.4544287 2.258334 -0.040812615; -4.0373597 -3.7868693 -4.6562204; 2.7961922 -1.5081825 -4.953516; -4.855989 -1.3340013 2.0473907], Float32[-2.5651526, 2.3463519, -0.6851924, -0.514734]])

In [33]:
jgrads_1[th[1]]

4×6 Matrix{Float32}:
  0.000218209   0.109107   -7.05404f-7   0.00026037   0.130188   -8.41697f-7
 -0.00193867   -0.182956   -8.0478f-5   -0.00231325  -0.218305   -9.60274f-5
  0.00134268   -0.0728651  -8.56164f-5   0.00160211  -0.0869436  -0.000102159
 -0.00233176   -0.0644499   3.5387f-5   -0.00278229  -0.0769024   4.22243f-5

In [36]:
jgrads_1[th[length(th) - 1]]

4×12 Matrix{Float32}:
 0.999861  0.0       0.0       0.0       …  0.0       0.0       0.0
 0.0       0.999861  0.0       0.0          0.999995  0.0       0.0
 0.0       0.0       0.999861  0.0          0.0       0.999995  0.0
 0.0       0.0       0.0       0.999861     0.0       0.0       0.999995

In [35]:
jgrads_1[th[length(th)]]

4×4 Matrix{Float32}:
 1.0  0.0  0.0  0.0
 0.0  1.0  0.0  0.0
 0.0  0.0  1.0  0.0
 0.0  0.0  0.0  1.0

In [42]:
J_1

4×25 Matrix{Float32}:
  0.000218209   0.109107   -7.05404f-7  …  0.0       1.0  0.0  0.0  0.0
 -0.00193867   -0.182956   -8.0478f-5      0.0       0.0  1.0  0.0  0.0
  0.00134268   -0.0728651  -8.56164f-5     0.0       0.0  0.0  1.0  0.0
 -0.00233176   -0.0644499   3.5387f-5      0.999995  0.0  0.0  0.0  1.0

In [39]:
J_2

4×25 Matrix{Float32}:
  4.68084f-5    0.46733   -0.000487648  …  0.0       1.0  0.0  0.0  0.0
 -0.000415868  -0.783639  -0.0556346       0.0       0.0  1.0  0.0  0.0
  0.000288022  -0.312097  -0.0591869       0.0       0.0  0.0  1.0  0.0
 -0.000500191  -0.276053   0.0244632       0.997032  0.0  0.0  0.0  1.0

### Jacobians, batched

In [43]:
dataloader_2 = DataLoader((X, Y), batchsize=2)

50-element DataLoader(::Tuple{Matrix{Float64}, Matrix{Bool}}, batchsize=2)
  with first element:
  (2×2 Matrix{Float64}, 4×2 Matrix{Bool},)

In [44]:
x_b, y_b = popfirst!(Iterators.Stateful(dataloader_2))

([3.452124036821634 4.0378059584920205; 4.119117743595139 1.8896085919713759], Bool[1 1; 0 0; 0 0; 0 0])

In [47]:
x_b

2×2 Matrix{Float64}:
 3.45212  4.03781
 4.11912  1.88961

In [48]:
x_1

2-element Vector{Float64}:
 3.452124036821634
 4.119117743595139

In [51]:
jgrads = jgrads_b = jacobian(() -> nn(x_b), Flux.params(nn))

Grads(...)

In [50]:
foreach(jac -> @show(size(jac)), jgrads_b)

size(jac) = (8, 6)
size(jac) = (8, 3)
size(jac) = (8, 12)
size(jac) = (8, 4)


In [52]:
jtmp = reduce(hcat, jgrads)

8×25 Matrix{Float32}:
  0.000218209   0.109107   -7.05404f-7   …  0.0       1.0  0.0  0.0  0.0
 -0.00193867   -0.182956   -8.0478f-5       0.0       0.0  1.0  0.0  0.0
  0.00134268   -0.0728651  -8.56164f-5      0.0       0.0  0.0  1.0  0.0
 -0.00233176   -0.0644499   3.5387f-5       0.999995  0.0  0.0  0.0  1.0
  4.68084f-5    0.46733    -0.000487648     0.0       1.0  0.0  0.0  0.0
 -0.000415868  -0.783639   -0.0556346    …  0.0       0.0  1.0  0.0  0.0
  0.000288022  -0.312097   -0.0591869       0.0       0.0  0.0  1.0  0.0
 -0.000500191  -0.276053    0.0244632       0.997032  0.0  0.0  0.0  1.0

In [59]:
jtmp_1 = @view jtmp[1:4, :]
jtmp_2 = @view jtmp[5:8, :]
# use views to avoid allocations

4×25 view(::Matrix{Float32}, 5:8, :) with eltype Float32:
  4.68084f-5    0.46733   -0.000487648  …  0.0       1.0  0.0  0.0  0.0
 -0.000415868  -0.783639  -0.0556346       0.0       0.0  1.0  0.0  0.0
  0.000288022  -0.312097  -0.0591869       0.0       0.0  0.0  1.0  0.0
 -0.000500191  -0.276053   0.0244632       0.997032  0.0  0.0  0.0  1.0

In [63]:
J = cat(jtmp_1, jtmp_2, dims=3)
# this is expensive, but there is no way around it

4×25×2 Array{Float32, 3}:
[:, :, 1] =
  0.000218209   0.109107   -7.05404f-7  …  0.0       1.0  0.0  0.0  0.0
 -0.00193867   -0.182956   -8.0478f-5      0.0       0.0  1.0  0.0  0.0
  0.00134268   -0.0728651  -8.56164f-5     0.0       0.0  0.0  1.0  0.0
 -0.00233176   -0.0644499   3.5387f-5      0.999995  0.0  0.0  0.0  1.0

[:, :, 2] =
  4.68084f-5    0.46733   -0.000487648  …  0.0       1.0  0.0  0.0  0.0
 -0.000415868  -0.783639  -0.0556346       0.0       0.0  1.0  0.0  0.0
  0.000288022  -0.312097  -0.0591869       0.0       0.0  0.0  1.0  0.0
 -0.000500191  -0.276053   0.0244632       0.997032  0.0  0.0  0.0  1.0

In [72]:
b = 2
outdim = 4
for batch_index in 1:b
    
end

4×25 view(::Matrix{Float32}, 5:8, :) with eltype Float32:
  4.68084f-5    0.46733   -0.000487648  …  0.0       1.0  0.0  0.0  0.0
 -0.000415868  -0.783639  -0.0556346       0.0       0.0  1.0  0.0  0.0
  0.000288022  -0.312097  -0.0591869       0.0       0.0  0.0  1.0  0.0
 -0.000500191  -0.276053   0.0244632       0.997032  0.0  0.0  0.0  1.0

4×25 view(::Matrix{Float32}, 5:8, :) with eltype Float32:
  4.68084f-5    0.46733   -0.000487648  …  0.0       1.0  0.0  0.0  0.0
 -0.000415868  -0.783639  -0.0556346       0.0       0.0  1.0  0.0  0.0
  0.000288022  -0.312097  -0.0591869       0.0       0.0  0.0  1.0  0.0
 -0.000500191  -0.276053   0.0244632       0.997032  0.0  0.0  0.0  1.0

In [64]:
@assert isapprox(J[:, :, 1], J_1, atol=.0005)
@assert isapprox(J[:, :, 2], J_2, atol=.0005)

### Laplace Approximation (not yet supported)

In [5]:
# la = Laplace(nn; likelihood=:classification)
# fit!(la, data)
# optimize_prior!(la; verbose=true, n_steps=1000)

LoadError: AssertionError: Support for multi-class output still lacking, sorry. Currently only regression and binary classification models are supported.

In [None]:
#| output: true

# _labels = sort(unique(y))
# plt_list = []
# for target in _labels
#     plt = plot(la, X, y; target=target, clim=(0,1))
#     push!(plt_list, plt)
# end
# plot(plt_list...)

In [None]:
#| output: true

# _labels = sort(unique(y))
# plt_list = []
# for target in _labels
#     plt = plot(la, X, y; target=target, clim=(0,1), link_approx=:plugin)
#     push!(plt_list, plt)
# end
# plot(plt_list...)