```@meta
CurrentModule = LaplaceRedux
```
## Batched Jacobians
As applied to a multi-class classification task.

### Multi-class problem

In [1]:
using Pkg
Pkg.activate(".")
using LaplaceRedux
using LaplaceRedux.Curvature
using LaplaceRedux.Data
using Flux
using Flux.Optimise: update!, Adam
using Plots
using Statistics
using MLUtils
using Zygote
using Printf
using NNlib
using BenchmarkTools
using Tullio

[32m[1m  Activating[22m[39m new project at `c:\Users\Andrei\LaplaceRedux.jl\src`


In [2]:
x, y = Data.toy_data_multi(2000)
X = hcat(x...)
y_train = Flux.onehotbatch(y, unique(y))
y_train = Flux.unstack(y_train',1)

2000-element Vector{Vector{Bool}}:
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 [1, 0, 0, 0]
 ‚ãÆ
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]
 [0, 0, 0, 1]

In [3]:
data = zip(x,y_train)
n_hidden = 3
D = size(X,1)
out_dim = length(unique(y))
nn = Chain(
    Dense(D, n_hidden, œÉ),
    Dense(n_hidden, out_dim)
)  
loss(x, y) = Flux.Losses.logitcrossentropy(nn(x), y)

loss (generic function with 1 method)

In [4]:
using Flux.Optimise: update!, Adam
opt = Adam()
epochs = 100
avg_loss(data) = mean(map(d -> loss(d[1],d[2]), data))
show_every = epochs/10

for epoch = 1:epochs
    for d in data
        gs = gradient(Flux.params(nn)) do
            l = loss(d...)
        end
        update!(opt, Flux.params(nn), gs)
    end
    if epoch % show_every == 0
        println("Epoch " * string(epoch))
        @show avg_loss(data)
    end
end

Epoch 10

‚îÇ   The input will be converted, but any earlier layers may be very slow.
‚îÇ   layer = Dense(2 => 3, œÉ)
‚îÇ   summary(x) = 2-element Vector{Float64}
‚îî @ Flux C:\Users\Andrei\.julia\packages\Flux\EHgZm\src\layers\stateless.jl:60



avg_loss(data) = 0.009490322f0


Epoch 20
avg_loss(data) = 8.6731576f-5


Epoch 30
avg_loss(data) = 7.357594f-7


Epoch 40
avg_loss(data) = 9.536743f-10


Epoch 50
avg_loss(data) = 1.1324881f-9


Epoch 60
avg_loss(data) = 2.2649762f-9


Epoch 70
avg_loss(data) = 1.3709067f-9


Epoch 80
avg_loss(data) = 8.9406954f-10


Epoch 90
avg_loss(data) = 7.152557f-10


Epoch 100
avg_loss(data) = 7.7486034f-10


In [5]:
Y = reduce(hcat, y_train)

4√ó2000 Matrix{Bool}:
 1  1  1  1  1  1  1  1  1  1  1  1  1  ‚Ä¶  0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     1  1  1  1  1  1  1  1  1  1  1  1

In [6]:
X

2√ó2000 Matrix{Float64}:
 2.15895  3.57961  4.66783  1.60082  3.00085  ‚Ä¶  -4.66207  -2.61483  -6.6308
 1.5807   5.39027  4.58806  2.82622  3.89678      5.12095   2.62973   4.58499

In [7]:
Y[:, 1], X[:, 1]

(Bool[1, 0, 0, 0], [2.1589545269914048, 1.5807007219814284])

### Jacobians, non-batched

In [8]:
x_1 = X[:, 1]
x_2 = X[:, 2]

2-element Vector{Float64}:
 3.579606460511134
 5.390268461925427

In [9]:
jgrads_1 = jacobian(() -> nn(x_1), Flux.params(nn))

Grads(...)

In [10]:
jgrads_1.grads

IdDict{Any, Any} with 4 entries:
  Float32[-38.628, -26.060‚Ä¶ => Float32[1.0 0.0 0.0 0.0; 0.0 1.0 0.0 0.0; 0.0 0.‚Ä¶
  Float32[-22.4144 -36.416‚Ä¶ => Float32[0.997183 0.0 ‚Ä¶ 0.0 0.0; 0.0 0.997183 ‚Ä¶ 0‚Ä¶
  Float32[3.19916, -2.5561‚Ä¶ => Float32[-0.0629575 -0.000201889 -0.0; -0.0974679‚Ä¶
  Float32[-0.0361987 1.738‚Ä¶ => Float32[-0.135922 -0.000435868 ‚Ä¶ -0.000319126 -0‚Ä¶

In [11]:
sum(length, jgrads_1.params)

25

In [12]:
nn

Chain(
  Dense(2 => 3, œÉ),                     [90m# 9 parameters[39m
  Dense(3 => 4),                        [90m# 16 parameters[39m
) [90m                  # Total: 4 arrays, [39m25 parameters, 356 bytes.

In [13]:
sum(length, jgrads_1)

100

In [14]:
foreach(jac -> @show(size(jac)), jgrads_1)

size(jac) = (4, 6)
size(jac) = (4, 3)
size(jac) = (4, 12)
size(jac) = (4, 4)


In [15]:
fieldnames(Chain)

(:layers,)

In [16]:
jgrads_2 = jacobian(() -> nn(x_2), Flux.params(nn))

Grads(...)

In [17]:
J_2 = reduce(hcat, jgrads_2)
J_1 = reduce(hcat, jgrads_1)

4√ó25 Matrix{Float32}:
 -0.135922   -0.000435868  -0.0  ‚Ä¶  0.0  0.0  0.0  1.0  0.0  0.0  0.0
 -0.210429   -0.000103633  -0.0     1.0  0.0  0.0  0.0  1.0  0.0  0.0
 -0.257459   -0.000196603  -0.0     0.0  1.0  0.0  0.0  0.0  1.0  0.0
 -0.0922676  -0.000355282  -0.0     0.0  0.0  1.0  0.0  0.0  0.0  1.0

In [18]:
th = jgrads_1.params

Params([Float32[-0.03619867 1.7386854; -0.84552133 -4.8846498; 6.613449 -0.03254935], Float32[3.1991556, -2.5561907, 3.662587], Float32[-22.41437 -36.416393 -2.334856; -34.700882 -8.658449 -41.25319; -42.456375 -16.425999 -2.4271576; -15.21544 -29.683495 -40.88694], Float32[-38.62803, -26.060854, -37.017788, -26.423927]])

In [19]:
jgrads_1[th[1]]

4√ó6 Matrix{Float32}:
 -0.135922   -0.000435868  -0.0  -0.099517   -0.000319126  -0.0
 -0.210429   -0.000103633  -0.0  -0.154068   -7.58761f-5   -0.0
 -0.257459   -0.000196603  -0.0  -0.188501   -0.000143945  -0.0
 -0.0922676  -0.000355282  -0.0  -0.0675547  -0.000260124  -0.0

In [20]:
jgrads_1[th[length(th) - 1]]

4√ó12 Matrix{Float32}:
 0.997183  0.0       0.0       0.0       ‚Ä¶  0.0         1.0  0.0  0.0  0.0
 0.0       0.997183  0.0       0.0          0.0         0.0  1.0  0.0  0.0
 0.0       0.0       0.997183  0.0          0.0         0.0  0.0  1.0  0.0
 0.0       0.0       0.0       0.997183     5.54392f-6  0.0  0.0  0.0  1.0

In [21]:
jgrads_1[th[length(th)]]

4√ó4 Matrix{Float32}:
 1.0  0.0  0.0  0.0
 0.0  1.0  0.0  0.0
 0.0  0.0  1.0  0.0
 0.0  0.0  0.0  1.0

In [22]:
J_1

4√ó25 Matrix{Float32}:
 -0.135922   -0.000435868  -0.0  ‚Ä¶  0.0  0.0  0.0  1.0  0.0  0.0  0.0
 -0.210429   -0.000103633  -0.0     1.0  0.0  0.0  0.0  1.0  0.0  0.0
 -0.257459   -0.000196603  -0.0     0.0  1.0  0.0  0.0  0.0  1.0  0.0
 -0.0922676  -0.000355282  -0.0     0.0  0.0  1.0  0.0  0.0  0.0  1.0

In [23]:
J_2

4√ó25 Matrix{Float32}:
 -0.000315634  -1.802f-12    -0.0  ‚Ä¶  0.0  0.0  0.0  1.0  0.0  0.0  0.0
 -0.00048865   -4.28447f-13  -0.0     1.0  0.0  0.0  0.0  1.0  0.0  0.0
 -0.000597861  -8.1281f-13   -0.0     0.0  1.0  0.0  0.0  0.0  1.0  0.0
 -0.000214261  -1.46883f-12  -0.0     0.0  0.0  1.0  0.0  0.0  0.0  1.0

### Jacobians, batched

In [24]:
dataloader_2 = DataLoader((X, Y), batchsize=2)

1000-element DataLoader(::Tuple{Matrix{Float64}, Matrix{Bool}}, batchsize=2)
  with first element:
  (2√ó2 Matrix{Float64}, 4√ó2 Matrix{Bool},)

In [25]:
x_b, y_b = popfirst!(Iterators.Stateful(dataloader_2))

([2.1589545269914048 3.579606460511134; 1.5807007219814284 5.390268461925427], Bool[1 1; 0 0; 0 0; 0 0])

In [26]:
x_b

2√ó2 Matrix{Float64}:
 2.15895  3.57961
 1.5807   5.39027

In [27]:
x_1

2-element Vector{Float64}:
 2.1589545269914048
 1.5807007219814284

In [28]:
jgrads = jgrads_b = jacobian(() -> nn(x_b), Flux.params(nn))

Grads(...)

In [29]:
foreach(jac -> @show(size(jac)), jgrads_b)

size(jac) = (8, 6)
size(jac) = (8, 3)
size(jac) = (8, 12)
size(jac) = (8, 4)


In [30]:
jtmp = reduce(hcat, jgrads)

8√ó25 Matrix{Float32}:
 -0.135922     -0.000435868  0.0  ‚Ä¶  0.0  0.0  0.0  1.0  0.0  0.0  0.0
 -0.210429     -0.000103633  0.0     1.0  0.0  0.0  0.0  1.0  0.0  0.0
 -0.257459     -0.000196603  0.0     0.0  1.0  0.0  0.0  0.0  1.0  0.0
 -0.0922676    -0.000355282  0.0     0.0  0.0  1.0  0.0  0.0  0.0  1.0
 -0.000315634  -1.802f-12    0.0     0.0  0.0  0.0  1.0  0.0  0.0  0.0
 -0.00048865   -4.28447f-13  0.0  ‚Ä¶  1.0  0.0  0.0  0.0  1.0  0.0  0.0
 -0.000597861  -8.1281f-13   0.0     0.0  1.0  0.0  0.0  0.0  1.0  0.0
 -0.000214261  -1.46883f-12  0.0     0.0  0.0  1.0  0.0  0.0  0.0  1.0

In [31]:
jtmp_1 = @view jtmp[1:4, :]
jtmp_2 = @view jtmp[5:8, :]
display(jtmp_1)
display(jtmp_2)
# use views to avoid allocations

4√ó25 view(::Matrix{Float32}, 1:4, :) with eltype Float32:
 -0.135922   -0.000435868  0.0  ‚Ä¶  0.0  0.0  0.0  1.0  0.0  0.0  0.0
 -0.210429   -0.000103633  0.0     1.0  0.0  0.0  0.0  1.0  0.0  0.0
 -0.257459   -0.000196603  0.0     0.0  1.0  0.0  0.0  0.0  1.0  0.0
 -0.0922676  -0.000355282  0.0     0.0  0.0  1.0  0.0  0.0  0.0  1.0

4√ó25 view(::Matrix{Float32}, 5:8, :) with eltype Float32:
 -0.000315634  -1.802f-12    0.0  ‚Ä¶  0.0  0.0  0.0  1.0  0.0  0.0  0.0
 -0.00048865   -4.28447f-13  0.0     1.0  0.0  0.0  0.0  1.0  0.0  0.0
 -0.000597861  -8.1281f-13   0.0     0.0  1.0  0.0  0.0  0.0  1.0  0.0
 -0.000214261  -1.46883f-12  0.0     0.0  0.0  1.0  0.0  0.0  0.0  1.0

In [32]:
J = cat(jtmp_1, jtmp_2, dims=3)
# this is expensive, but there is no way around it

4√ó25√ó2 Array{Float32, 3}:
[:, :, 1] =
 -0.135922   -0.000435868  0.0  ‚Ä¶  0.0  0.0  0.0  1.0  0.0  0.0  0.0
 -0.210429   -0.000103633  0.0     1.0  0.0  0.0  0.0  1.0  0.0  0.0
 -0.257459   -0.000196603  0.0     0.0  1.0  0.0  0.0  0.0  1.0  0.0
 -0.0922676  -0.000355282  0.0     0.0  0.0  1.0  0.0  0.0  0.0  1.0

[:, :, 2] =
 -0.000315634  -1.802f-12    0.0  ‚Ä¶  0.0  0.0  0.0  1.0  0.0  0.0  0.0
 -0.00048865   -4.28447f-13  0.0     1.0  0.0  0.0  0.0  1.0  0.0  0.0
 -0.000597861  -8.1281f-13   0.0     0.0  1.0  0.0  0.0  0.0  1.0  0.0
 -0.000214261  -1.46883f-12  0.0     0.0  0.0  1.0  0.0  0.0  0.0  1.0

In [33]:
stack([jtmp_1, jtmp_2])

4√ó25√ó2 Array{Float32, 3}:
[:, :, 1] =
 -0.135922   -0.000435868  0.0  ‚Ä¶  0.0  0.0  0.0  1.0  0.0  0.0  0.0
 -0.210429   -0.000103633  0.0     1.0  0.0  0.0  0.0  1.0  0.0  0.0
 -0.257459   -0.000196603  0.0     0.0  1.0  0.0  0.0  0.0  1.0  0.0
 -0.0922676  -0.000355282  0.0     0.0  0.0  1.0  0.0  0.0  0.0  1.0

[:, :, 2] =
 -0.000315634  -1.802f-12    0.0  ‚Ä¶  0.0  0.0  0.0  1.0  0.0  0.0  0.0
 -0.00048865   -4.28447f-13  0.0     1.0  0.0  0.0  0.0  1.0  0.0  0.0
 -0.000597861  -8.1281f-13   0.0     0.0  1.0  0.0  0.0  0.0  1.0  0.0
 -0.000214261  -1.46883f-12  0.0     0.0  0.0  1.0  0.0  0.0  0.0  1.0

In [34]:
batch_size = 2
outdim = 4
for batch_index in 1:batch_size
    i = (batch_index - 1) * batch_size + 1
    j = i + outdim - 1
    @show i
    @show j
end

i = 1
j = 4
i = 3
j = 6


In [35]:
for i in 1:outdim:(batch_size * outdim)
    @show i
    @show i + outdim - 1
end

i = 1
(i + outdim) - 1 = 4
i = 5
(i + outdim) - 1 = 8


In [36]:
#jviews = [(batch_start, batch_start + outdim - 1) for batch_start in 1 : outdim : batch_size * outdim]
jviews = [jtmp[batch_start : (batch_start + outdim - 1), :] for batch_start in 1 : outdim : batch_size * outdim]
#display(jviews)
# jviews = [@view J[batch_start:batch_start + outdim - 1, :] for batch_start in 1:outdim:batch_size]

2-element Vector{Matrix{Float32}}:
 [-0.13592248 -0.0004358684 ‚Ä¶ 0.0 0.0; -0.21042885 -0.000103633116 ‚Ä¶ 0.0 0.0; -0.25745878 -0.00019660302 ‚Ä¶ 1.0 0.0; -0.09226761 -0.00035528222 ‚Ä¶ 0.0 1.0]
 [-0.00031563427 -1.8019975f-12 ‚Ä¶ 0.0 0.0; -0.0004886503 -4.284472f-13 ‚Ä¶ 0.0 0.0; -0.0005978615 -8.1280994f-13 ‚Ä¶ 1.0 0.0; -0.00021426051 -1.4688323f-12 ‚Ä¶ 0.0 1.0]

In [37]:
display(jviews[1])
display(jviews[2])

4√ó25 Matrix{Float32}:
 -0.135922   -0.000435868  0.0  ‚Ä¶  0.0  0.0  0.0  1.0  0.0  0.0  0.0
 -0.210429   -0.000103633  0.0     1.0  0.0  0.0  0.0  1.0  0.0  0.0
 -0.257459   -0.000196603  0.0     0.0  1.0  0.0  0.0  0.0  1.0  0.0
 -0.0922676  -0.000355282  0.0     0.0  0.0  1.0  0.0  0.0  0.0  1.0

4√ó25 Matrix{Float32}:
 -0.000315634  -1.802f-12    0.0  ‚Ä¶  0.0  0.0  0.0  1.0  0.0  0.0  0.0
 -0.00048865   -4.28447f-13  0.0     1.0  0.0  0.0  0.0  1.0  0.0  0.0
 -0.000597861  -8.1281f-13   0.0     0.0  1.0  0.0  0.0  0.0  1.0  0.0
 -0.000214261  -1.46883f-12  0.0     0.0  0.0  1.0  0.0  0.0  0.0  1.0

In [38]:
cat(jviews, dims=3)

2√ó1√ó1 Array{Matrix{Float32}, 3}:
[:, :, 1] =
 [-0.13592248 -0.0004358684 ‚Ä¶ 0.0 0.0; -0.21042885 -0.000103633116 ‚Ä¶ 0.0 0.0; -0.25745878 -0.00019660302 ‚Ä¶ 1.0 0.0; -0.09226761 -0.00035528222 ‚Ä¶ 0.0 1.0]
 [-0.00031563427 -1.8019975f-12 ‚Ä¶ 0.0 0.0; -0.0004886503 -4.284472f-13 ‚Ä¶ 0.0 0.0; -0.0005978615 -8.1280994f-13 ‚Ä¶ 1.0 0.0; -0.00021426051 -1.4688323f-12 ‚Ä¶ 0.0 1.0]

In [39]:
Jt = reduce((a, b) -> cat(a, b, dims=3), jviews)

4√ó25√ó2 Array{Float32, 3}:
[:, :, 1] =
 -0.135922   -0.000435868  0.0  ‚Ä¶  0.0  0.0  0.0  1.0  0.0  0.0  0.0
 -0.210429   -0.000103633  0.0     1.0  0.0  0.0  0.0  1.0  0.0  0.0
 -0.257459   -0.000196603  0.0     0.0  1.0  0.0  0.0  0.0  1.0  0.0
 -0.0922676  -0.000355282  0.0     0.0  0.0  1.0  0.0  0.0  0.0  1.0

[:, :, 2] =
 -0.000315634  -1.802f-12    0.0  ‚Ä¶  0.0  0.0  0.0  1.0  0.0  0.0  0.0
 -0.00048865   -4.28447f-13  0.0     1.0  0.0  0.0  0.0  1.0  0.0  0.0
 -0.000597861  -8.1281f-13   0.0     0.0  1.0  0.0  0.0  0.0  1.0  0.0
 -0.000214261  -1.46883f-12  0.0     0.0  0.0  1.0  0.0  0.0  0.0  1.0

In [40]:
reduce(cat(dims=3), jviews)

MethodError: MethodError: objects of type Vector{Any} are not callable
Use square brackets [] for indexing an Array.

In [41]:
Js = stack(jviews)

4√ó25√ó2 Array{Float32, 3}:
[:, :, 1] =
 -0.135922   -0.000435868  0.0  ‚Ä¶  0.0  0.0  0.0  1.0  0.0  0.0  0.0
 -0.210429   -0.000103633  0.0     1.0  0.0  0.0  0.0  1.0  0.0  0.0
 -0.257459   -0.000196603  0.0     0.0  1.0  0.0  0.0  0.0  1.0  0.0
 -0.0922676  -0.000355282  0.0     0.0  0.0  1.0  0.0  0.0  0.0  1.0

[:, :, 2] =
 -0.000315634  -1.802f-12    0.0  ‚Ä¶  0.0  0.0  0.0  1.0  0.0  0.0  0.0
 -0.00048865   -4.28447f-13  0.0     1.0  0.0  0.0  0.0  1.0  0.0  0.0
 -0.000597861  -8.1281f-13   0.0     0.0  1.0  0.0  0.0  0.0  1.0  0.0
 -0.000214261  -1.46883f-12  0.0     0.0  0.0  1.0  0.0  0.0  0.0  1.0

In [42]:
Js == J

true

In [43]:
J == Jt

true

In [44]:
@assert isapprox(J[:, :, 1], J_1, atol=.0005)
@assert isapprox(J[:, :, 2], J_2, atol=.0005)

### Laplace Approximation (not yet supported)

In [45]:
# la = Laplace(nn; likelihood=:classification)
# fit!(la, data)
# optimize_prior!(la; verbose=true, n_steps=1000)

In [46]:
#| output: true

# _labels = sort(unique(y))
# plt_list = []
# for target in _labels
#     plt = plot(la, X, y; target=target, clim=(0,1))
#     push!(plt_list, plt)
# end
# plot(plt_list...)

In [47]:
#| output: true

# _labels = sort(unique(y))
# plt_list = []
# for target in _labels
#     plt = plot(la, X, y; target=target, clim=(0,1), link_approx=:plugin)
#     push!(plt_list, plt)
# end
# plot(plt_list...)

# Checking the implementation of Jacobians batched

In [48]:
function jacobians_batched(nn, X::AbstractArray)
    # Output:
    yÃÇ = nn(X)
    batch_size = size(X)[end]
    K = get_outdim(nn)
    # Jacobian:
    grads = jacobian(() -> nn(X), Flux.params(nn))
    grads_joint = reduce(hcat, grads)
    views = [grads_joint[batch_start : (batch_start + K - 1), :] for batch_start in 1 : K : batch_size * K]
    ùêâ = stack(views)
    return ùêâ, yÃÇ
end

 function get_outdim(model::Chain)
     return [size(p) for p in Flux.params(model)][end][1]
 end

function jacobians(nn, X::AbstractArray)
    # Output:
    yÃÇ = nn(X)
    # Jacobian:
    ùêâ = jacobian(() -> nn(X), Flux.params(nn))
    ùêâ = permutedims(reduce(hcat, [ùêâ[Œ∏] for Œ∏ in Flux.params(nn)]))
    return ùêâ, yÃÇ
end

jacobians (generic function with 1 method)

In [49]:
x[1]

2-element Vector{Float64}:
 2.1589545269914048
 1.5807007219814284

In [50]:
J_1, yhat_1 = jacobians(nn, x[1])
display(J_1)
display(yhat_1)

25√ó4 Matrix{Float32}:
 -0.135922     -0.210429     -0.257459     -0.0922676
 -0.000435868  -0.000103633  -0.000196603  -0.000355282
 -0.0          -0.0          -0.0          -0.0
 -0.099517     -0.154068     -0.188501     -0.0675547
 -0.000319126  -7.58761f-5   -0.000143945  -0.000260124
 -0.0          -0.0          -0.0          -0.0
 -0.0629575    -0.0974679    -0.119252     -0.0427372
 -0.000201889  -4.80015f-5   -9.1064f-5    -0.000164562
 -0.0          -0.0          -0.0          -0.0
  0.997183      0.0           0.0           0.0
  ‚ãÆ                                        
  0.0           0.0           0.0           5.54392f-6
  1.0           0.0           0.0           0.0
  0.0           1.0           0.0           0.0
  0.0           0.0           1.0           0.0
  0.0           0.0           0.0           1.0
  1.0           0.0           0.0           0.0
  0.0           1.0           0.0           0.0
  0.0           0.0           1.0           0.0
  0.0           0.

4-element Vector{Float32}:
  -63.314323
 -101.91723
  -81.78183
  -82.48361

In [51]:
J_b, yhat_b = jacobians_batched(nn, X)
display(J_b)
display(yhat_b)

4√ó25√ó2000 Array{Float32, 3}:
[:, :, 1] =
 -0.135922   -0.000435868  0.0  ‚Ä¶  0.0  0.0  0.0  1.0  0.0  0.0  0.0
 -0.210429   -0.000103633  0.0     1.0  0.0  0.0  0.0  1.0  0.0  0.0
 -0.257459   -0.000196603  0.0     0.0  1.0  0.0  0.0  0.0  1.0  0.0
 -0.0922676  -0.000355282  0.0     0.0  0.0  1.0  0.0  0.0  0.0  1.0

[:, :, 2] =
 -0.000315634  -1.802f-12    0.0  ‚Ä¶  0.0  0.0  0.0  1.0  0.0  0.0  0.0
 -0.00048865   -4.28447f-13  0.0     1.0  0.0  0.0  0.0  1.0  0.0  0.0
 -0.000597861  -8.1281f-13   0.0     0.0  1.0  0.0  0.0  0.0  1.0  0.0
 -0.000214261  -1.46883f-12  0.0     0.0  0.0  1.0  0.0  0.0  0.0  1.0

[:, :, 3] =
 -0.00173364  -4.71217f-11  0.0  ‚Ä¶  0.0  0.0  0.0  1.0  0.0  0.0  0.0
 -0.00268394  -1.12038f-11  0.0     1.0  0.0  0.0  0.0  1.0  0.0  0.0
 -0.00328379  -2.12547f-11  0.0     0.0  1.0  0.0  0.0  0.0  1.0  0.0
 -0.00117684  -3.84095f-11  0.0     0.0  0.0  1.0  0.0  0.0  0.0  1.0

;;; ‚Ä¶ 

[:, :, 1998] =
 0.000485823  9.29418f-9  1.46142f-11  ‚Ä¶  0.0          1.

4√ó2000 Matrix{Float32}:
  -63.3143   -63.3772   -63.3769   -63.3701  ‚Ä¶  -61.0423  -61.0339  -61.0422
 -101.917   -102.015   -102.014   -102.004      -60.7616  -60.7485  -60.7614
  -81.7818   -81.9012   -81.9006   -81.8878     -79.474   -79.4579  -79.4737
  -82.4836   -82.5262   -82.5261   -82.5215     -41.6393  -41.6336  -41.6392

In [52]:
function method_batched()
    jacobians_batched(nn, X)
end

function method_unbatched()
    for x in eachcol(X)
        jacobians(nn, x)
    end
end

method_unbatched (generic function with 1 method)

In [53]:
@benchmark method_batched()

BenchmarkTools.Trial: 10 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m ‚Ä¶ [35mmax[39m[90m):  [39m[36m[1m504.703 ms[22m[39m ‚Ä¶ [35m564.012 ms[39m  [90m‚îä[39m GC [90m([39mmin ‚Ä¶ max[90m): [39m6.69% ‚Ä¶ 7.00%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m530.400 ms               [22m[39m[90m‚îä[39m GC [90m([39mmedian[90m):    [39m7.01%
 Time  [90m([39m[32m[1mmean[22m[39m ¬± [32mœÉ[39m[90m):   [39m[32m[1m529.960 ms[22m[39m ¬± [32m 19.483 ms[39m  [90m‚îä[39m GC [90m([39mmean ¬± œÉ[90m):  [39m6.93% ¬± 0.38%

  [39m‚ñà[39m [39m [39m [39m [39m [39m [39m [39m [39m‚ñÅ[39m [39m [39m [39m [39m [39m [39m‚ñÅ[39m [39m [39m [39m [39m [39m [39m [39m [39m‚ñÅ[34m [39m[39m‚ñÅ[39m [39m [39m [39m [39m‚ñÅ[39m [39m [39m [39m [39m [39m‚ñÅ[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m‚ñÅ[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m‚ñÅ[39m 

In [54]:
@benchmark method_unbatched()

BenchmarkTools.Trial: 78 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m ‚Ä¶ [35mmax[39m[90m):  [39m[36m[1m57.296 ms[22m[39m ‚Ä¶ [35m79.891 ms[39m  [90m‚îä[39m GC [90m([39mmin ‚Ä¶ max[90m): [39m0.00% ‚Ä¶ 6.25%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m65.176 ms              [22m[39m[90m‚îä[39m GC [90m([39mmedian[90m):    [39m5.87%
 Time  [90m([39m[32m[1mmean[22m[39m ¬± [32mœÉ[39m[90m):   [39m[32m[1m64.810 ms[22m[39m ¬± [32m 3.934 ms[39m  [90m‚îä[39m GC [90m([39mmean ¬± œÉ[90m):  [39m5.84% ¬± 1.54%

  [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m‚ñÉ[39m [39m [39m‚ñÉ[39m [39m‚ñÅ[39m [39m [39m [39m [39m [39m [39m [39m [32m [39m[34m [39m[39m [39m [39m [39m‚ñà[39m [39m [39m [39m‚ñÉ[39m‚ñÅ[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m‚ñá[

In [55]:
function jacobianStacking_Loop()
    function jacobians_batched_aux(nn, X::AbstractArray)
        # Output:
        #yÃÇ = nn(X)
        batch_size = size(X)[end]
        K = get_outdim(nn)
        # Jacobian:
        grads = jacobian(() -> nn(X), Flux.params(nn))
        #grads_joint = reduce(hcat, grads)
        #views = [grads_joint[batch_start : (batch_start + K - 1), :] for batch_start in 1 : K : batch_size * K]
        #ùêâ = stack(views)
        return grads
    end

    for x in eachcol(X)
        jacobians_batched_aux(nn,x)
    end
    
end

jacobianStacking_Loop (generic function with 1 method)

In [56]:
function jacobianStacking()
    function jacobians_batched_aux(nn, X::AbstractArray)
        # Output:
        #yÃÇ = nn(X)
        batch_size = size(X)[end]
        K = get_outdim(nn)
        # Jacobian:
        grads = jacobian(() -> nn(X), Flux.params(nn))
        #grads_joint = reduce(hcat, grads)
        #views = [grads_joint[batch_start : (batch_start + K - 1), :] for batch_start in 1 : K : batch_size * K]
        #ùêâ = stack(views)
        return grads
    end
    
    jacobians_batched_aux(nn,X)
    
end

jacobianStacking (generic function with 1 method)

In [57]:
@benchmark jacobianStacking_Loop()

BenchmarkTools.Trial: 75 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m ‚Ä¶ [35mmax[39m[90m):  [39m[36m[1m59.999 ms[22m[39m ‚Ä¶ [35m86.614 ms[39m  [90m‚îä[39m GC [90m([39mmin ‚Ä¶ max[90m): [39m0.00% ‚Ä¶ 4.59%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m66.474 ms              [22m[39m[90m‚îä[39m GC [90m([39mmedian[90m):    [39m6.08%
 Time  [90m([39m[32m[1mmean[22m[39m ¬± [32mœÉ[39m[90m):   [39m[32m[1m67.433 ms[22m[39m ¬± [32m 4.543 ms[39m  [90m‚îä[39m GC [90m([39mmean ¬± œÉ[90m):  [39m5.73% ¬± 2.05%

  [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m‚ñÅ[39m [39m‚ñà[39m [39m [39m [39m [39m [39m [39m‚ñÉ[39m‚ñÉ[39m [34m [39m[39m [39m [39m [32m [39m[39m [39m [39m [39m [39m‚ñÅ[39m [39m‚ñÅ[39m [39m‚ñÅ[39m [39m [39m [39m [39m [39m‚ñÅ[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m‚

In [58]:
@benchmark jacobianStacking()

BenchmarkTools.Trial: 10 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m ‚Ä¶ [35mmax[39m[90m):  [39m[36m[1m506.045 ms[22m[39m ‚Ä¶ [35m553.887 ms[39m  [90m‚îä[39m GC [90m([39mmin ‚Ä¶ max[90m): [39m6.81% ‚Ä¶ 7.78%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m540.004 ms               [22m[39m[90m‚îä[39m GC [90m([39mmedian[90m):    [39m7.18%
 Time  [90m([39m[32m[1mmean[22m[39m ¬± [32mœÉ[39m[90m):   [39m[32m[1m534.986 ms[22m[39m ¬± [32m 15.579 ms[39m  [90m‚îä[39m GC [90m([39mmean ¬± œÉ[90m):  [39m7.25% ¬± 0.40%

  [39m‚ñà[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m‚ñà[39m [39m [39m [39m‚ñà[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m‚ñà[39m [39m [32m [39m[39m [39m‚ñà[34m [39m[39m [39m [39m [39m [39m [39m‚ñà[39m [39m‚ñà[39m [39m [39m‚ñà[39m [39m‚ñà[39m [39m [39m [39m [39m [39m [39m‚ñ