```@meta
CurrentModule = LaplaceRedux
```
## Batched Jacobians
This notebook explores batched Jacobian computation, applied to a multi-class classification task.

### Multi-class problem

In [1]:
using Pkg
Pkg.activate(".")
using LaplaceRedux
using LaplaceRedux.Curvature
using LaplaceRedux.Data
using Flux
using Flux.Optimise: update!, Adam
using Plots
using Statistics
using MLUtils
using Zygote
using Printf
using NNlib
using BenchmarkTools
using Tullio

[32m[1m  Activating[22m[39m project at `~/Builds/navimakarov/LaplaceRedux.jl/dev/notebooks/batching`


In [2]:
# Toy multi-class dataset: `x` holds the samples, `y` the class labels.
x, y = Data.toy_data_multi()
# Lay the samples side by side as the columns of a single feature matrix.
X = reduce(hcat, x)
# One-hot encode the labels, then split the transposed encoding along its
# first dimension so that every sample gets its own indicator vector.
y_train = Flux.unstack(Flux.onehotbatch(y, unique(y))', 1)

100-element Vector{Vector{Bool}}:
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 ⋮
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]

In [3]:
# Layer sizes are read off the data: input width from `X`, one output per class.
D = size(X, 1)
out_dim = length(unique(y))
n_hidden = 3
nn = Chain(Dense(D, n_hidden, σ), Dense(n_hidden, out_dim))
# Training pairs: each sample together with its one-hot target.
data = zip(x, y_train)

# Cross-entropy computed on the raw network outputs (logits) for one
# input/target pair.
function loss(x, y)
    return Flux.Losses.logitcrossentropy(nn(x), y)
end

loss (generic function with 1 method)

In [4]:
using Flux.Optimise: update!, Adam

# Train `nn` with Adam, taking one optimisation step per sample, for a fixed
# number of epochs.
opt = Adam()
epochs = 100
# Mean of the per-sample loss over every (input, target) pair in `data`.
avg_loss(data) = mean(map(d -> loss(d[1], d[2]), data))
# Report ten times per run. Integer division keeps the modulus test below in
# integer arithmetic, and `max` avoids a zero interval when `epochs < 10`.
show_every = max(1, epochs ÷ 10)
# Collect the trainable parameters once instead of rebuilding the collection
# twice for every sample.
ps = Flux.params(nn)

for epoch in 1:epochs
    for d in data
        gs = gradient(() -> loss(d...), ps)
        update!(opt, ps, gs)
    end
    if epoch % show_every == 0
        println("Epoch " * string(epoch))
        @show avg_loss(data)
    end
end

[33m[1m│ [22m[39m  The input will be converted, but any earlier layers may be very slow.
[33m[1m│ [22m[39m  layer = Dense(2 => 3, σ)    [90m# 9 parameters[39m
[33m[1m│ [22m[39m  summary(x) = "2-element Vector{Float64}"
[33m[1m└ [22m[39m[90m@ Flux ~/.julia/packages/Flux/EHgZm/src/layers/stateless.jl:60[39m


Epoch 10
avg_loss(data) = 0.9327189f0
Epoch 20
avg_loss(data) = 0.6260527f0
Epoch 30
avg_loss(data) = 0.4192539f0
Epoch 40
avg_loss(data) = 0.29381973f0
Epoch 50
avg_loss(data) = 0.21203795f0
Epoch 60
avg_loss(data) = 0.15556468f0
Epoch 70
avg_loss(data) = 0.11524396f0
Epoch 80
avg_loss(data) = 0.08588406f0
Epoch 90
avg_loss(data) = 0.06424693f0
Epoch 100
avg_loss(data) = 0.048179574f0


In [5]:
Y = reduce(hcat, y_train)

4×100 Matrix{Bool}:
 1  1  1  1  1  1  1  1  1  1  1  1  1  …  0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     1  1  1  1  1  1  1  1  1  1  1  1

In [6]:
X

2×100 Matrix{Float64}:
 5.10701  1.61325  5.13842  1.42374  4.80034  …  -3.48668  -4.20696  -5.90665
 2.42447  2.48575  2.11091  1.93559  2.32837      1.31256   2.79496   1.81608

In [7]:
Y[:, 1], X[:, 1]

(Bool[1, 0, 0, 0], [5.107007423372803, 2.4244711420805967])

### Jacobians, non-batched

In [8]:
x_1 = X[:, 1]
x_2 = X[:, 2]

2-element Vector{Float64}:
 1.6132470587112007
 2.4857489228284155

In [9]:
jgrads_1 = jacobian(() -> nn(x_1), Flux.params(nn))

Grads(...)

In [10]:
jgrads_1.grads

IdDict{Any, Any} with 4 entries:
  Float32[2.42897 -0.14302… => Float32[3.07355f-5 -0.000129027 … -6.12534f-5 -0…
  Float32[1.74088 -4.06614… => Float32[0.999997 0.0 … 0.0 0.0; 0.0 0.999997 … 0…
  Float32[1.48922, -1.0898… => Float32[1.0 0.0 0.0 0.0; 0.0 1.0 0.0 0.0; 0.0 0.…
  Float32[0.514042, -0.832… => Float32[6.01831f-6 -2.52647f-5 -0.00236856; -1.1…

In [11]:
sum(length, jgrads_1.params)

25

In [12]:
nn

Chain(
  Dense(2 => 3, σ),                     # 9 parameters
  Dense(3 => 4),                        # 16 parameters
)                   # Total: 4 arrays, 25 parameters, 356 bytes.

In [13]:
sum(length, jgrads_1)

100

In [14]:
foreach(jac -> @show(size(jac)), jgrads_1)

size(jac) = (4, 6)
size(jac) = (4, 3)
size(jac) = (4, 12)
size(jac) = (4, 4)


In [15]:
fieldnames(Chain)

(:layers,)

In [16]:
jgrads_2 = jacobian(() -> nn(x_2), Flux.params(nn))

Grads(...)

In [17]:
J_2 = reduce(hcat, jgrads_2)
J_1 = reduce(hcat, jgrads_1)

4×25 Matrix{Float32}:
  3.07355f-5  -0.000129027  -0.0120963   …  0.0          1.0  0.0  0.0  0.0
 -5.75227f-5   2.45904f-5    0.00462227     0.0          0.0  1.0  0.0  0.0
  1.82142f-5  -9.86657f-5    0.00402219     0.0          0.0  0.0  1.0  0.0
 -5.62777f-5   6.15856f-5   -0.00942432     0.000444655  0.0  0.0  0.0  1.0

In [18]:
th = jgrads_1.params

Params([Float32[2.4289665 -0.14302832; -2.2368581 0.110216156; 0.08029391 -2.6969676], Float32[0.5140422, -0.83235115, -1.5891083], Float32[1.7408757 -4.066143 -5.329109; -3.258112 0.77494115 2.0363786; 1.0316626 -3.1093464 1.7720069; -3.1875958 1.940807 -4.1519613], Float32[1.489223, -1.0898395, -1.0472586, 0.6869783]])

In [19]:
jgrads_1[th[1]]

4×6 Matrix{Float32}:
  3.07355f-5  -0.000129027  -0.0120963   …  -6.12534f-5  -0.00574251
 -5.75227f-5   2.45904f-5    0.00462227      1.16739f-5   0.00219435
  1.82142f-5  -9.86657f-5    0.00402219     -4.684f-5     0.00190947
 -5.62777f-5   6.15856f-5   -0.00942432      2.92368f-5  -0.00447405

In [20]:
jgrads_1[th[length(th) - 1]]

4×12 Matrix{Float32}:
 0.999997  0.0       0.0       0.0       …  0.0          0.0
 0.0       0.999997  0.0       0.0          0.0          0.0
 0.0       0.0       0.999997  0.0          0.000444655  0.0
 0.0       0.0       0.0       0.999997     0.0          0.000444655

In [21]:
jgrads_1[th[length(th)]]

4×4 Matrix{Float32}:
 1.0  0.0  0.0  0.0
 0.0  1.0  0.0  0.0
 0.0  0.0  1.0  0.0
 0.0  0.0  0.0  1.0

In [22]:
J_1

4×25 Matrix{Float32}:
  3.07355f-5  -0.000129027  -0.0120963   …  0.0          1.0  0.0  0.0  0.0
 -5.75227f-5   2.45904f-5    0.00462227     0.0          0.0  1.0  0.0  0.0
  1.82142f-5  -9.86657f-5    0.00402219     0.0          0.0  0.0  1.0  0.0
 -5.62777f-5   6.15856f-5   -0.00942432     0.000444655  0.0  0.0  0.0  1.0

In [23]:
J_2

4×25 Matrix{Float32}:
  0.0460502  -0.0985883  -0.00244748   …  0.0          1.0  0.0  0.0  0.0
 -0.0861846   0.0187893   0.000935241     0.0          0.0  1.0  0.0  0.0
  0.0272899  -0.0753897   0.000813824     0.0          0.0  0.0  1.0  0.0
 -0.0843193   0.0470571  -0.00190686      0.000284766  0.0  0.0  0.0  1.0

### Jacobians, batched

In [24]:
dataloader_2 = DataLoader((X, Y), batchsize=2)

50-element DataLoader(::Tuple{Matrix{Float64}, Matrix{Bool}}, batchsize=2)
  with first element:
  (2×2 Matrix{Float64}, 4×2 Matrix{Bool},)

In [25]:
x_b, y_b = popfirst!(Iterators.Stateful(dataloader_2))

([5.107007423372803 1.6132470587112007; 2.4244711420805967 2.4857489228284155], Bool[1 1; 0 0; 0 0; 0 0])

In [26]:
x_b

2×2 Matrix{Float64}:
 5.10701  1.61325
 2.42447  2.48575

In [27]:
x_1

2-element Vector{Float64}:
 5.107007423372803
 2.4244711420805967

In [28]:
jgrads = jgrads_b = jacobian(() -> nn(x_b), Flux.params(nn))

Grads(...)

In [29]:
foreach(jac -> @show(size(jac)), jgrads_b)

size(jac) = (8, 6)
size(jac) = (8, 3)
size(jac) = (8, 12)
size(jac) = (8, 4)


In [30]:
jtmp = reduce(hcat, jgrads)

8×25 Matrix{Float32}:
  3.07355f-5  -0.000129027  -0.0120963    …  0.0          1.0  0.0  0.0  0.0
 -5.75227f-5   2.45904f-5    0.00462227      0.0          0.0  1.0  0.0  0.0
  1.82142f-5  -9.86657f-5    0.00402219      0.0          0.0  0.0  1.0  0.0
 -5.62777f-5   6.15856f-5   -0.00942432      0.000444655  0.0  0.0  0.0  1.0
  0.0460502   -0.0985883    -0.00244748      0.0          1.0  0.0  0.0  0.0
 -0.0861846    0.0187893     0.000935241  …  0.0          0.0  1.0  0.0  0.0
  0.0272899   -0.0753897     0.000813824     0.0          0.0  0.0  1.0  0.0
 -0.0843193    0.0470571    -0.00190686      0.000284766  0.0  0.0  0.0  1.0

In [38]:
# Take the two 4-row blocks of `jtmp` as views, so no rows are copied.
jtmp_1 = view(jtmp, 1:4, :)
jtmp_2 = view(jtmp, 5:8, :)
foreach(display, (jtmp_1, jtmp_2))

4×25 view(::Matrix{Float32}, 1:4, :) with eltype Float32:
  3.07355f-5  -0.000129027  -0.0120963   …  0.0          1.0  0.0  0.0  0.0
 -5.75227f-5   2.45904f-5    0.00462227     0.0          0.0  1.0  0.0  0.0
  1.82142f-5  -9.86657f-5    0.00402219     0.0          0.0  0.0  1.0  0.0
 -5.62777f-5   6.15856f-5   -0.00942432     0.000444655  0.0  0.0  0.0  1.0

4×25 view(::Matrix{Float32}, 5:8, :) with eltype Float32:
  0.0460502  -0.0985883  -0.00244748   …  0.0          1.0  0.0  0.0  0.0
 -0.0861846   0.0187893   0.000935241     0.0          0.0  1.0  0.0  0.0
  0.0272899  -0.0753897   0.000813824     0.0          0.0  0.0  1.0  0.0
 -0.0843193   0.0470571  -0.00190686      0.000284766  0.0  0.0  0.0  1.0

In [32]:
# Stack the two per-sample views along a new third dimension, so that
# `J[:, :, k]` is the Jacobian block of the k-th batch element.
J = cat(jtmp_1, jtmp_2, dims=3)
# `cat` builds a new 3-d array from the views, so this step allocates;
# the original note considered that cost unavoidable here.

4×25×2 Array{Float32, 3}:
[:, :, 1] =
  3.07355f-5  -0.000129027  -0.0120963   …  0.0          1.0  0.0  0.0  0.0
 -5.75227f-5   2.45904f-5    0.00462227     0.0          0.0  1.0  0.0  0.0
  1.82142f-5  -9.86657f-5    0.00402219     0.0          0.0  0.0  1.0  0.0
 -5.62777f-5   6.15856f-5   -0.00942432     0.000444655  0.0  0.0  0.0  1.0

[:, :, 2] =
  0.0460502  -0.0985883  -0.00244748   …  0.0          1.0  0.0  0.0  0.0
 -0.0861846   0.0187893   0.000935241     0.0          0.0  1.0  0.0  0.0
  0.0272899  -0.0753897   0.000813824     0.0          0.0  0.0  1.0  0.0
 -0.0843193   0.0470571  -0.00190686      0.000284766  0.0  0.0  0.0  1.0

In [68]:
stack([jtmp_1, jtmp_2])

4×25×2 Array{Float32, 3}:
[:, :, 1] =
  3.07355f-5  -0.000129027  -0.0120963   …  0.0          1.0  0.0  0.0  0.0
 -5.75227f-5   2.45904f-5    0.00462227     0.0          0.0  1.0  0.0  0.0
  1.82142f-5  -9.86657f-5    0.00402219     0.0          0.0  0.0  1.0  0.0
 -5.62777f-5   6.15856f-5   -0.00942432     0.000444655  0.0  0.0  0.0  1.0

[:, :, 2] =
  0.0460502  -0.0985883  -0.00244748   …  0.0          1.0  0.0  0.0  0.0
 -0.0861846   0.0187893   0.000935241     0.0          0.0  1.0  0.0  0.0
  0.0272899  -0.0753897   0.000813824     0.0          0.0  0.0  1.0  0.0
 -0.0843193   0.0470571  -0.00190686      0.000284766  0.0  0.0  0.0  1.0

In [41]:
batch_size = 2
outdim = 4
# Each batch element owns `outdim` consecutive rows of the stacked Jacobian, so
# the block for element `batch_index` starts at (batch_index - 1) * outdim + 1.
# (Striding by `batch_size` instead yields the overlapping ranges 1:4 and 3:6.)
for batch_index in 1:batch_size
    i = (batch_index - 1) * outdim + 1
    j = i + outdim - 1
    @show i
    @show j
end

i = 1
j = 4
i = 5
j = 8


In [45]:
# Walk the stacked rows in strides of `outdim`: each `i` is the first row of one
# batch element's block and `i + outdim - 1` is that block's last row.
for i in 1:outdim:(batch_size * outdim)
    @show i
    @show i + outdim - 1
end

i = 1
(i + outdim) - 1 = 4
i = 5
(i + outdim) - 1 = 8


In [69]:
# Split the stacked Jacobian `jtmp` into one `outdim`-row block per batch element.
# NOTE: plain range indexing copies, so despite the name each entry is a fresh
# `Matrix`, not a view; wrapping the index expression in `@view` would avoid that.
jviews = [jtmp[batch_start : (batch_start + outdim - 1), :] for batch_start in 1 : outdim : batch_size * outdim]
# The block starts must run up to `batch_size * outdim` (the number of stacked
# rows); stopping the range at `batch_size` would miss every block but the first.

2-element Vector{Matrix{Float32}}:
 [3.073554f-5 -0.00012902677 … 0.0 0.0; -5.7522677f-5 2.4590418f-5 … 0.0 0.0; 1.8214227f-5 -9.866572f-5 … 1.0 0.0; -5.62777f-5 6.158565f-5 … 0.0 1.0]
 [0.046050195 -0.09858835 … 0.0 0.0; -0.086184606 0.018789345 … 0.0 0.0; 0.027289864 -0.075389706 … 1.0 0.0; -0.084319286 0.047057115 … 0.0 1.0]

In [70]:
display(jviews[1])
display(jviews[2])

4×25 Matrix{Float32}:
  3.07355f-5  -0.000129027  -0.0120963   …  0.0          1.0  0.0  0.0  0.0
 -5.75227f-5   2.45904f-5    0.00462227     0.0          0.0  1.0  0.0  0.0
  1.82142f-5  -9.86657f-5    0.00402219     0.0          0.0  0.0  1.0  0.0
 -5.62777f-5   6.15856f-5   -0.00942432     0.000444655  0.0  0.0  0.0  1.0

4×25 Matrix{Float32}:
  0.0460502  -0.0985883  -0.00244748   …  0.0          1.0  0.0  0.0  0.0
 -0.0861846   0.0187893   0.000935241     0.0          0.0  1.0  0.0  0.0
  0.0272899  -0.0753897   0.000813824     0.0          0.0  0.0  1.0  0.0
 -0.0843193   0.0470571  -0.00190686      0.000284766  0.0  0.0  0.0  1.0

In [56]:
cat(jviews, dims=3)

2×1×1 Array{Matrix{Float32}, 3}:
[:, :, 1] =
 [3.073554f-5 -0.00012902677 … 0.0 0.0; -5.7522677f-5 2.4590418f-5 … 0.0 0.0; 1.8214227f-5 -9.866572f-5 … 1.0 0.0; -5.62777f-5 6.158565f-5 … 0.0 1.0]
 [0.046050195 -0.09858835 … 0.0 0.0; -0.086184606 0.018789345 … 0.0 0.0; 0.027289864 -0.075389706 … 1.0 0.0; -0.084319286 0.047057115 … 0.0 1.0]

In [58]:
Jt = reduce((a, b) -> cat(a, b, dims=3), jviews)

4×25×2 Array{Float32, 3}:
[:, :, 1] =
  3.07355f-5  -0.000129027  -0.0120963   …  0.0          1.0  0.0  0.0  0.0
 -5.75227f-5   2.45904f-5    0.00462227     0.0          0.0  1.0  0.0  0.0
  1.82142f-5  -9.86657f-5    0.00402219     0.0          0.0  0.0  1.0  0.0
 -5.62777f-5   6.15856f-5   -0.00942432     0.000444655  0.0  0.0  0.0  1.0

[:, :, 2] =
  0.0460502  -0.0985883  -0.00244748   …  0.0          1.0  0.0  0.0  0.0
 -0.0861846   0.0187893   0.000935241     0.0          0.0  1.0  0.0  0.0
  0.0272899  -0.0753897   0.000813824     0.0          0.0  0.0  1.0  0.0
 -0.0843193   0.0470571  -0.00190686      0.000284766  0.0  0.0  0.0  1.0

In [61]:
# Fails: `cat(dims=3)` is evaluated eagerly with no arrays and yields an empty
# `Vector{Any}`, which `reduce` then tries to call as its binary operator.
# A two-argument closure such as `(a, b) -> cat(a, b, dims=3)` is needed instead.
reduce(cat(dims=3), jviews)

LoadError: MethodError: objects of type Vector{Any} are not callable
Use square brackets [] for indexing an Array.

In [65]:
Js = stack(jviews)

4×25×2 Array{Float32, 3}:
[:, :, 1] =
  3.07355f-5  -0.000129027  -0.0120963   …  0.0          1.0  0.0  0.0  0.0
 -5.75227f-5   2.45904f-5    0.00462227     0.0          0.0  1.0  0.0  0.0
  1.82142f-5  -9.86657f-5    0.00402219     0.0          0.0  0.0  1.0  0.0
 -5.62777f-5   6.15856f-5   -0.00942432     0.000444655  0.0  0.0  0.0  1.0

[:, :, 2] =
  0.0460502  -0.0985883  -0.00244748   …  0.0          1.0  0.0  0.0  0.0
 -0.0861846   0.0187893   0.000935241     0.0          0.0  1.0  0.0  0.0
  0.0272899  -0.0753897   0.000813824     0.0          0.0  0.0  1.0  0.0
 -0.0843193   0.0470571  -0.00190686      0.000284766  0.0  0.0  0.0  1.0

In [67]:
Js == J

true

In [59]:
J == Jt

true

In [34]:
# Sanity check: each slice of the batched Jacobian must agree with the Jacobian
# computed for that sample on its own, up to an absolute tolerance of 5e-4.
@assert isapprox(J[:, :, 1], J_1, atol=.0005)
@assert isapprox(J[:, :, 2], J_2, atol=.0005)

### Laplace Approximation (not yet supported)

In [35]:
# la = Laplace(nn; likelihood=:classification)
# fit!(la, data)
# optimize_prior!(la; verbose=true, n_steps=1000)

In [36]:
#| output: true

# _labels = sort(unique(y))
# plt_list = []
# for target in _labels
#     plt = plot(la, X, y; target=target, clim=(0,1))
#     push!(plt_list, plt)
# end
# plot(plt_list...)

In [37]:
#| output: true

# _labels = sort(unique(y))
# plt_list = []
# for target in _labels
#     plt = plot(la, X, y; target=target, clim=(0,1), link_approx=:plugin)
#     push!(plt_list, plt)
# end
# plot(plt_list...)