## Batched Gradients

In [1]:
using Pkg
Pkg.activate(".")
using LaplaceRedux
using LaplaceRedux.Curvature
using LaplaceRedux.Data
using Flux
using Flux.Optimise: update!, Adam
using Plots
using Statistics
using MLUtils
using Zygote
using Printf
using NNlib
using BenchmarkTools
using Tullio

[32m[1m  Activating[22m[39m new project at `c:\Users\Andrei\LaplaceRedux.jl\src`


In [2]:
# Setup some dummy data, like in tests

n = 128 * 128
data_dict = Dict()
bsize = 2

x, y = LaplaceRedux.Data.toy_data_regression(n)
xs = [[x] for x in x]
X, Y = reduce(hcat, x), reduce(hcat, y)

dataloader = DataLoader((X, Y), batchsize=bsize)
data = zip(xs, y)
data_dict[:regression] = Dict(
    :data => data,
    :X => X,
    :y => y,
    :outdim => 1,
    :loss_fun => :mse,
    :likelihood => :regression,
)

Dict{Symbol, Any} with 6 entries:
  :loss_fun   => :mse
  :y          => [0.0154858, 0.0650503, -0.466806, 1.06432, 0.80842, 0.369409, …
  :likelihood => :regression
  :X          => [6.89082 0.26044 … 6.6274 6.56329]
  :outdim     => 1
  :data       => zip([[6.89082], [0.26044], [3.53151], [1.10064], [7.44497], [6…

In [3]:
n

16384

In [4]:
# Train a NN model

val = data_dict[:regression]

# Unpack:
data = val[:data]
X = val[:X]
y = val[:y]
outdim = val[:outdim]
loss_fun = val[:loss_fun]
likelihood = val[:likelihood]

# Neural network:
n_hidden = 32
D = size(X, 1)
nn = Chain(Dense(D, n_hidden, σ), Dense(n_hidden, outdim))
λ = 0.01
sqnorm(x) = sum(abs2, x)
weight_regularization(λ=λ) = 1 / 2 * λ^2 * sum(sqnorm, Flux.params(nn))
loss(x, y) = getfield(Flux.Losses, loss_fun)(nn(x), y) + weight_regularization()


opt = Adam()
epochs = 200
avg_loss(data) = mean(map(d -> loss(d[1], d[2]), data))
show_every = epochs / 10

for epoch in 1:epochs
    for d in data
        gs = gradient(Flux.params(nn)) do
            l = loss(d...)
        end
        update!(opt, Flux.params(nn), gs)
    end
    if epoch % show_every == 0
        println("Epoch " * string(epoch))
        @show avg_loss(data)
    end
end


Epoch 20

│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(1 => 32, σ)
│   summary(x) = 1-element Vector{Float64}
└ @ Flux C:\Users\Andrei\.julia\packages\Flux\EHgZm\src\layers\stateless.jl:60



avg_loss(data) = 0.09968076230125669


Epoch 40
avg_loss(data) = 0.09893657241923182


Epoch 60
avg_loss(data) = 0.09873215264137986


Epoch 80
avg_loss(data) = 0.098570424924625


Epoch 100
avg_loss(data) = 0.09843579397679747


Epoch 120
avg_loss(data) = 0.09837301508460453


Epoch 140
avg_loss(data) = 0.0984010883063332


Epoch 160
avg_loss(data) = 0.09845822254666581


Epoch 180
avg_loss(data) = 0.09849405293357148


Epoch 200
avg_loss(data) = 0.09848949786115607


In [5]:
model = nn

Chain(
  Dense(1 => 32, σ),                    [90m# 64 parameters[39m
  Dense(32 => 1),                       [90m# 33 parameters[39m
) [90m                  # Total: 4 arrays, [39m97 parameters, 644 bytes.

In [6]:
# Loss fun for regression
loss_fn(x, ytrue; agg=sum) = Flux.Losses.mse(model(x), ytrue, agg=agg)

loss_fn (generic function with 1 method)

In [7]:
dataloader

8192-element DataLoader(::Tuple{Matrix{Float64}, Matrix{Float64}}, batchsize=2)
  with first element:
  (1×2 Matrix{Float64}, 1×2 Matrix{Float64},)

In [8]:
# Take one datapoint of the dataset
x_1 = xs[1]
y_1 = y[1]

0.01548584004498288

In [9]:
grads_1 = gradient(() -> loss_fn(x_1, y_1), Flux.params(model))
grads = grads_1

Grads(...)

In [10]:
grads.grads

IdDict{Any, Any} with 6 entries:
  Float32[-1.86301; 0.5380… => Float32[-8.32705f-5; 1.67715; … ; -0.880106; -0.…
  Float32[1.07408, -2.6962… => Float32[-1.20843f-5, 0.243389, -0.0682107, 0.219…
  Float32[-1.69082 1.35515… => Float32[7.14703f-6 0.67339 … 0.144992 0.110392]
  :(Main.x_1)               => Float32[0.709215]
  Float32[-0.437689]        => Float32[0.91832]
  :(Main.y_1)               => -0.91832

In [11]:
grads.params

Params([Float32[-1.8630121; 0.53805554; … ; -0.40431502; -0.39164895;;], Float32[1.0740811, -2.696294, 0.56854427, -2.529265, 0.028561326, -1.7896606, -1.3546953, -1.3722371, -1.0168822, -2.1914685  …  -0.8684253, -1.4725244, -2.3014822, -0.34306574, -1.7645233, 0.17289999, 5.812564, -0.34634212, 1.1120374, 0.7083484], Float32[-1.6908232 1.355146 … -1.0460447 -0.84596074], Float32[-0.4376891]])

In [12]:
length(grads.params)

4

In [13]:
sum(length, grads.params) # 97 params total

97

In [14]:
fieldnames(Params)

(:order, :params)

In [15]:
grads.params.order

Zygote.Buffer{Any, Vector{Any}}(Any[Float32[-1.8630121; 0.53805554; … ; -0.40431502; -0.39164895;;], Float32[1.0740811, -2.696294, 0.56854427, -2.529265, 0.028561326, -1.7896606, -1.3546953, -1.3722371, -1.0168822, -2.1914685  …  -0.8684253, -1.4725244, -2.3014822, -0.34306574, -1.7645233, 0.17289999, 5.812564, -0.34634212, 1.1120374, 0.7083484], Float32[-1.6908232 1.355146 … -1.0460447 -0.84596074], Float32[-0.4376891]], false)

In [16]:
th_1 = grads.params[1]

32×1 Matrix{Float32}:
 -1.8630121
  0.53805554
 -0.39032742
  0.5150949
  0.5423486
  0.41597626
  0.35198945
  0.35496658
  0.26668763
  0.46952686
  ⋮
  0.3709763
  1.1276038
 -0.45483503
  0.4125146
 -0.39623737
 -1.611097
 -0.4577712
 -0.40431502
 -0.39164895

In [17]:
# This is the gradient of the loss function value at (x_1, y_1) for th_1
# Size 32
grads.grads[th_1]

32×1 Matrix{Float32}:
 -8.3270534f-5
  1.6771513
 -0.4700276
  1.513821
 -0.02774583
  0.85494155
  0.5385639
  0.55021006
  0.32691005
  1.1954195
  ⋮
  0.6180835
 -0.040892553
 -0.03335479
  0.8333848
 -0.24627012
  0.12060051
 -0.030416934
 -0.8801061
 -0.5661583

In [18]:
# The same for x_2, y_2
x_2 = xs[2]
y_2 = y[2]
grads_2 = gradient(() -> loss_fn(x_2, y_2), Flux.params(model))
grads_2.grads[th_1]

32×1 Matrix{Float32}:
 -0.038889773
  0.009075242
 -0.0184421
  0.009423961
 -0.0049317447
  0.009449902
  0.007710292
  0.007814531
  0.0049118726
  0.009786163
  ⋮
  0.008342939
 -0.016162949
 -0.004309263
  0.009371281
 -0.0145887155
  0.0017233177
 -0.004013892
 -0.020545674
 -0.01935912

In [19]:
# Now take one batch combining the two
x_b, y_b = popfirst!(Iterators.Stateful(dataloader))

([6.890818797403573 0.26043975403126574], [0.01548584004498288 0.0650502542475323])

In [20]:
@show x_1
@show x_2
@show y_1
@show y_2

x_1 = [6.890818797403573]
x_2 = [0.26043975403126574]
y_1 = 0.01548584004498288
y_2 = 0.0650502542475323


0.0650502542475323

In [21]:
# Take the *Jacobian* to obtain the gradients for the batch parts
grads_b = jacobian(() -> loss_fn(x_b, y_b, agg=identity), Flux.params(model))

Grads(...)

In [22]:
grads_b.grads[th_1]

2×32 Matrix{Float64}:
 -8.32704e-5  1.67715     -0.470027   …  -0.0304169  -0.880105   -0.566157
 -0.0388899   0.00907526  -0.0184421     -0.0040139  -0.0205457  -0.0193592

In [23]:
th = grads_b.params

Params([Float32[-1.8630121; 0.53805554; … ; -0.40431502; -0.39164895;;], Float32[1.0740811, -2.696294, 0.56854427, -2.529265, 0.028561326, -1.7896606, -1.3546953, -1.3722371, -1.0168822, -2.1914685  …  -0.8684253, -1.4725244, -2.3014822, -0.34306574, -1.7645233, 0.17289999, 5.812564, -0.34634212, 1.1120374, 0.7083484], Float32[-1.6908232 1.355146 … -1.0460447 -0.84596074], Float32[-0.4376891]])

In [24]:
grads_b.params == grads_1.params == grads_2.params

true

In [25]:
grads_b[th_1]

2×32 Matrix{Float64}:
 -8.32704e-5  1.67715     -0.470027   …  -0.0304169  -0.880105   -0.566157
 -0.0388899   0.00907526  -0.0184421     -0.0040139  -0.0205457  -0.0193592

In [26]:
grads

Grads(...)

In [27]:
grads_1[th_1]'

1×32 adjoint(::Matrix{Float32}) with eltype Float32:
 -8.32705f-5  1.67715  -0.470028  …  -0.0304169  -0.880106  -0.566158

In [28]:
grads_2[th_1]'

1×32 adjoint(::Matrix{Float32}) with eltype Float32:
 -0.0388898  0.00907524  -0.0184421  …  -0.00401389  -0.0205457  -0.0193591

In [29]:
# grads_1 .* transpose.(grads_1)

In [30]:
sum(length, grads_1)

97

In [31]:
sum(length, grads_1.params)

97

In [32]:
sum(length, grads_b)

194

In [33]:
# We would like to obtain g as a 97x2 matrix
# containing as the first column the gradient for loss_fn(x_1, y_1) wrt theta (of size 97)
# and as the second column, one for loss_fn(x_2, y_2)
g = permutedims(reduce(hcat, grads_b))

97×2 Matrix{Float64}:
 -8.32704e-5  -0.0388899
  1.67715      0.00907526
 -0.470027    -0.0184421
  1.51382      0.00942398
 -0.0277458   -0.00493176
  0.85494      0.00944993
  0.538563     0.00771031
  0.550209     0.00781455
  0.326909     0.00491189
  1.19542      0.00978619
  ⋮           
  0.914462     0.0455521
  0.0275188    0.148765
  0.685179     0.0616235
  0.0660494    0.199091
  0.00461032   0.383036
  0.0268981    0.148396
  0.144992     0.281804
  0.110392     0.248993
  0.918319     0.384778

In [34]:
function gradient_helper(model, x, y)
    grads = jacobian(() -> loss_fn(x, y, agg=identity), Flux.params(model))
    g = permutedims(reduce(hcat, grads))
end

gradient_helper (generic function with 1 method)

In [35]:
# sanity check
grads_b[th[1]]

2×32 Matrix{Float64}:
 -8.32704e-5  1.67715     -0.470027   …  -0.0304169  -0.880105   -0.566157
 -0.0388899   0.00907526  -0.0184421     -0.0040139  -0.0205457  -0.0193592

In [36]:
grads_b[th[length(th)-1]]

2×32 Matrix{Float64}:
 7.14702e-6  0.673388   0.0983151  …  0.0268981  0.144992  0.110392
 0.247454    0.0277092  0.236503      0.148396   0.281804  0.248993

In [37]:
grads_b[th[length(th)]] # sanity: checked

2×1 Matrix{Float64}:
 0.9183186292648315
 0.3847779333591461

In [38]:
vg = g * transpose(g)

97×97 Matrix{Float64}:
  0.00151243   -0.000492593  …  -0.0109714  -0.00969248  -0.0150404
 -0.000492593   2.81291          0.24573     0.187403     1.54365
  0.000756352  -0.788472        -0.0733471  -0.0564792   -0.43873
 -0.000492554   2.53898          0.222147    0.16946      1.39379
  0.000194106  -0.0465786       -0.0054127  -0.00429088  -0.0273771
 -0.000438698   1.43395      …   0.126622    0.0967315    0.788744
 -0.000344699   0.90332          0.0802601   0.0613729    0.497539
 -0.000349723   0.922853         0.081978    0.0626845    0.508274
 -0.000218244   0.54832          0.0487834   0.0373112    0.302097
 -0.000480126   2.00498          0.176084    0.134401     1.10154
  ⋮                          ⋱               ⋮           
 -0.00184766    1.5341           0.145426    0.112291     0.857295
 -0.00578773    0.0475031        0.0459125   0.0400791    0.0825123
 -0.00245358    1.14971      …   0.116711    0.090982     0.652924
 -0.00774814    0.112581         0.0656814   0.05

In [39]:
foreach(x -> @show(size(x)), grads)

size(x) = (32, 1)
size(x) = (32,)
size(x) = (1, 32)
size(x) = (1,)


In [40]:
g_1 = reduce(vcat, [vec(grads_1[th]) for th in grads_1.params])
g_2 = reduce(vcat, [vec(grads_2[th]) for th in grads_2.params])
# let v(g) denote gg'
vg_1 = g_1 * transpose(g_1)
vg_2 = g_2 * transpose(g_2)

97×97 Matrix{Float32}:
  0.00151241   -0.000352934   0.000717209  …  -0.00968324  -0.0149639
 -0.000352934   8.236f-5     -0.000167367      0.00225966   0.00349194
  0.000717209  -0.000167367   0.000340111     -0.00459193  -0.0070961
 -0.000366496   8.55247f-5   -0.000173798      0.00234649   0.00362612
  0.000191794  -4.47568f-5    9.09517f-5      -0.00122796  -0.00189762
 -0.000367505   8.57601f-5   -0.000174276  …   0.00235295   0.0036361
 -0.000299851   6.99728f-5   -0.000142194      0.0019198    0.00296674
 -0.000303905   7.09188f-5   -0.000144116      0.00194576   0.00300685
 -0.000191022   4.45764f-5   -9.05852f-5       0.00122302   0.00188998
 -0.000380582   8.88118f-5   -0.000180477      0.00243668   0.00376549
  ⋮                                        ⋱   ⋮           
 -0.00177151    0.000413395  -0.000840074      0.0113421    0.0175274
 -0.00578541    0.00135007   -0.00274353       0.0370411    0.0572411
 -0.00239652    0.000559246  -0.00113646   …   0.0153437    0.0237112


In [41]:
isapprox(vg_1 + vg_2, vg, atol=.0005) 
# This is not it.
# We want a 97x97x2 matrix

true

In [42]:
# Reshape g as 97x1x2 since MLUtils expects 3-d arrays
gm = reshape(g, size(g, 1), 1, size(g, 2))

97×1×2 Array{Float64, 3}:
[:, :, 1] =
 -8.327038813149557e-5
  1.677148461341858
 -0.47002679109573364
  1.513818383216858
 -0.02774578519165516
  0.8549401164054871
  0.5385630130767822
  0.5502091646194458
  0.32690948247909546
  1.1954176425933838
  ⋮
  0.9144623279571533
  0.027518773451447487
  0.6851785778999329
  0.06604937463998795
  0.004610323812812567
  0.026898102834820747
  0.14499185979366302
  0.1103919968008995
  0.9183186292648315

[:, :, 2] =
 -0.038889866322278976
  0.009075264446437359
 -0.018442144617438316
  0.009423984214663506
 -0.0049317567609250546
  0.00944992620497942
  0.007710311561822891
  0.007814550772309303
  0.004911885131150484
  0.009786187671124935
  ⋮
  0.045552100986242294
  0.14876462519168854
  0.06162345036864281
  0.19909138977527618
  0.38303571939468384
  0.1483960598707199
  0.2818041443824768
  0.24899256229400635
  0.3847779333591461

In [43]:
vgm = batched_mul(gm, batched_transpose(gm))

97×97×2 Array{Float64, 3}:
[:, :, 1] =
  6.93396e-9   -0.000139657   3.91393e-5  …  -9.19238e-6   -7.64687e-5
 -0.000139657   2.81283      -0.788305        0.185144      1.54016
  3.91393e-5   -0.788305      0.220925       -0.0518872    -0.431634
 -0.000126056   2.5389       -0.711535        0.167113      1.39017
  2.3104e-6    -0.0465338     0.0130413      -0.00306291   -0.0254795
 -7.11912e-5    1.43386      -0.401845    …   0.0943785     0.785107
 -4.48464e-5    0.90325      -0.253139        0.059453      0.494572
 -4.58161e-5    0.922782     -0.258613        0.0607387     0.505267
 -2.72219e-5    0.548276     -0.153656        0.0360882     0.300207
 -9.95429e-5    2.00489      -0.561878        0.131965      1.09777
  ⋮                                       ⋱   ⋮            
 -7.61476e-5    1.53369      -0.429822        0.100949      0.839768
 -2.2915e-6     0.0461531    -0.0129346       0.00303785    0.025271
 -5.70551e-5    1.14915      -0.322052    …   0.0756382     0.629212
 -5.

In [44]:
@assert isapprox(vgm[:,:,1], vg_1, atol=.0005)
@assert isapprox(vgm[:,:,2], vg_2, atol=.0005)

In [45]:
vg_2

97×97 Matrix{Float32}:
  0.00151241   -0.000352934   0.000717209  …  -0.00968324  -0.0149639
 -0.000352934   8.236f-5     -0.000167367      0.00225966   0.00349194
  0.000717209  -0.000167367   0.000340111     -0.00459193  -0.0070961
 -0.000366496   8.55247f-5   -0.000173798      0.00234649   0.00362612
  0.000191794  -4.47568f-5    9.09517f-5      -0.00122796  -0.00189762
 -0.000367505   8.57601f-5   -0.000174276  …   0.00235295   0.0036361
 -0.000299851   6.99728f-5   -0.000142194      0.0019198    0.00296674
 -0.000303905   7.09188f-5   -0.000144116      0.00194576   0.00300685
 -0.000191022   4.45764f-5   -9.05852f-5       0.00122302   0.00188998
 -0.000380582   8.88118f-5   -0.000180477      0.00243668   0.00376549
  ⋮                                        ⋱   ⋮           
 -0.00177151    0.000413395  -0.000840074      0.0113421    0.0175274
 -0.00578541    0.00135007   -0.00274353       0.0370411    0.0572411
 -0.00239652    0.000559246  -0.00113646   …   0.0153437    0.0237112


In [46]:
g

97×2 Matrix{Float64}:
 -8.32704e-5  -0.0388899
  1.67715      0.00907526
 -0.470027    -0.0184421
  1.51382      0.00942398
 -0.0277458   -0.00493176
  0.85494      0.00944993
  0.538563     0.00771031
  0.550209     0.00781455
  0.326909     0.00491189
  1.19542      0.00978619
  ⋮           
  0.914462     0.0455521
  0.0275188    0.148765
  0.685179     0.0616235
  0.0660494    0.199091
  0.00461032   0.383036
  0.0268981    0.148396
  0.144992     0.281804
  0.110392     0.248993
  0.918319     0.384778

In [47]:
# Now try einsums
using Tullio

In [48]:
@tullio H[i, j, b] := g[i, b] * g[j, b]

97×97×2 Array{Float64, 3}:
[:, :, 1] =
  6.93396e-9   -0.000139657   3.91393e-5  …  -9.19238e-6   -7.64687e-5
 -0.000139657   2.81283      -0.788305        0.185144      1.54016
  3.91393e-5   -0.788305      0.220925       -0.0518872    -0.431634
 -0.000126056   2.5389       -0.711535        0.167113      1.39017
  2.3104e-6    -0.0465338     0.0130413      -0.00306291   -0.0254795
 -7.11912e-5    1.43386      -0.401845    …   0.0943785     0.785107
 -4.48464e-5    0.90325      -0.253139        0.059453      0.494572
 -4.58161e-5    0.922782     -0.258613        0.0607387     0.505267
 -2.72219e-5    0.548276     -0.153656        0.0360882     0.300207
 -9.95429e-5    2.00489      -0.561878        0.131965      1.09777
  ⋮                                       ⋱   ⋮            
 -7.61476e-5    1.53369      -0.429822        0.100949      0.839768
 -2.2915e-6     0.0461531    -0.0129346       0.00303785    0.025271
 -5.70551e-5    1.14915      -0.322052    …   0.0756382     0.629212
 -5.

In [49]:
@assert isapprox(vgm, H, atol=eps())

In [50]:
eps()

2.220446049250313e-16

## Benchmarks for batched computation of H = 𝐠 * 𝐠'

In [51]:
using BenchmarkTools

In [52]:
# Method 1: reshape + MLUtils
function method_1(g)
    gm = reshape(g, size(g, 1), 1, size(g, 2))
    vgm = batched_mul(gm, batched_transpose(gm))
end

method_1 (generic function with 1 method)

In [53]:
# Method 2: Tullio Einstein summation
function method_2(g)
    @tullio H[i, j, b] := g[i, b] * g[j, b]
end

method_2 (generic function with 1 method)

In [54]:
@benchmark method_1($g)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m13.100 μs[22m[39m … [35m  6.498 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 97.79%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m22.400 μs               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m53.724 μs[22m[39m ± [32m172.483 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m16.11% ±  5.39%

  [39m█[39m█[34m▆[39m[39m▄[39m▃[39m▃[39m▂[39m▂[39m▁[39m▁[32m▂[39m[39m▃[39m▃[39m▆[39m▅[39m▄[39m▃[39m▃[39m▃[39m▃[39m▂[39m▂[39m▂[39m▂[39m▁[39m▁[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▂
  [39m█[39m█[34m█[39m[

In [55]:
@benchmark method_2($g)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m11.500 μs[22m[39m … [35m  4.994 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 97.87%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m17.000 μs               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m31.683 μs[22m[39m ± [32m111.664 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m17.02% ±  5.26%

  [39m▆[39m█[34m▇[39m[39m▅[39m▄[39m▄[39m▃[32m▃[39m[39m▃[39m▂[39m▂[39m▁[39m▁[39m [39m▁[39m▁[39m▁[39m▂[39m▁[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▂
  [39m█[39m█[34m█[39m[

In [56]:
# More benchmarks:
# batchsize = 128
dataloader_128 = DataLoader((X, Y), batchsize=128)
x_128, y_128 = popfirst!(Iterators.Stateful(dataloader_128))
g_128 = gradient_helper(model, x_128, y_128)

97×128 Matrix{Float64}:
 -8.32704e-5  -0.0388899   -0.00291601  …  -2.4884e-8    -0.0213082
  1.67715      0.00907526   0.124145        0.00171719    0.00496514
 -0.470027    -0.0184421   -0.0708067      -0.000471961  -0.00985723
  1.51382      0.00942398   0.116051        0.00156517    0.00514572
 -0.0277458   -0.00493176  -0.00928818     -2.37233e-5   -0.00261334
  0.85494      0.00944993   0.0742789   …   0.000920082   0.00511361
  0.538563     0.00771031   0.0477155       0.000597577   0.00414917
  0.550209     0.00781455   0.0487987       0.000609484   0.00420624
  0.326909     0.00491189   0.0259904       0.000384619   0.00263024
  1.19542      0.00978619   0.0977518       0.00125914    0.00532181
  ⋮                                     ⋱                
  0.914462     0.0455521    0.10208         0.0010599     0.020767
  0.0275188    0.148765     0.0150914       2.24747e-5    0.0626304
  0.685179     0.0616235    0.0513032   …   0.000851968   0.0269822
  0.0660494    0.199091   

In [57]:
@benchmark method_1($g_128) 

BenchmarkTools.Trial: 1268 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m2.408 ms[22m[39m … [35m9.986 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 50.55%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m3.132 ms             [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m3.915 ms[22m[39m ± [32m1.714 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m20.17% ± 22.10%

  [39m [39m▁[39m▆[39m█[39m▆[39m▄[39m▂[34m▁[39m[39m▁[39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▄[39m█[39m█[39m█[39m█[39m█[39m█[34m█[39m

In [58]:
@benchmark method_2($g_128) 

BenchmarkTools.Trial: 851 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m3.126 ms[22m[39m … [35m15.410 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 54.25%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m4.918 ms              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m5.848 ms[22m[39m ± [32m 2.364 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m18.60% ± 21.46%

  [39m [39m [39m [39m [39m [39m▅[39m▄[39m█[39m▇[39m▇[39m▇[34m▄[39m[39m▃[39m [39m▂[39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▂[39m▄[39m▅[39m█[39m█[39m█[39m█[3

In [59]:
@benchmark method_1($g_128) samples=10_000

BenchmarkTools.Trial: 1284 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m2.125 ms[22m[39m … [35m10.474 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 62.47%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m3.128 ms              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m3.866 ms[22m[39m ± [32m 1.737 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m20.06% ± 22.29%

  [39m [39m▄[39m▄[39m▃[39m▄[39m█[39m█[39m▆[39m▅[34m▄[39m[39m▁[39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▂[39m█[39m█[39m█[39m█[39m█[39m█[

In [60]:
@benchmark method_2($g_128) samples=10_000

BenchmarkTools.Trial: 813 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m3.131 ms[22m[39m … [35m16.548 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 57.13%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m5.115 ms              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m6.127 ms[22m[39m ± [32m 2.667 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m18.36% ± 21.39%

  [39m [39m [39m▇[39m█[39m▄[39m▂[39m▂[39m▁[39m▅[39m▃[39m▃[34m▄[39m[39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▁[39m▄[39m█[39m█[39m█[39m█[39m█[3

### Benchmark conclusions
Method 2, with Tullio, is faster, probably since it makes less allocations. The effect is lesser for larger batchsizes -- compare the almost two-fold difference for batchsize 2, and the ~16% difference for batchsize 128.
Tullio would also employ multi-threading for large-enough matrices.

Machine: Thinkpad X1 Gen 6, Intel i7 8th Gen, Intel GPU

## Summary

In [61]:
x_b, y_b = popfirst!(Iterators.Stateful(dataloader))
grads_b = jacobian(() -> loss_fn(x_b, y_b, agg=identity), Flux.params(model))
g = transpose(reduce(hcat, grads_b))  # Q: could using permutedims make the next computation faster, because of BLAS-optimised multiplication, at a cost of an allocation?
@tullio H_b[i, j, b] := g[i, b] * g[j, b]
@tullio H[i, j] := g[i, b] * g[j, b]

97×97 Matrix{Float64}:
  0.00151243   -0.000492593  …  -0.0109714  -0.00969248  -0.0150404
 -0.000492593   2.81291          0.24573     0.187403     1.54365
  0.000756352  -0.788472        -0.0733471  -0.0564792   -0.43873
 -0.000492554   2.53898          0.222147    0.16946      1.39379
  0.000194106  -0.0465786       -0.0054127  -0.00429088  -0.0273771
 -0.000438698   1.43395      …   0.126622    0.0967315    0.788744
 -0.000344699   0.90332          0.0802601   0.0613729    0.497539
 -0.000349723   0.922853         0.081978    0.0626845    0.508274
 -0.000218244   0.54832          0.0487834   0.0373112    0.302097
 -0.000480126   2.00498          0.176084    0.134401     1.10154
  ⋮                          ⋱               ⋮           
 -0.00184766    1.5341           0.145426    0.112291     0.857295
 -0.00578773    0.0475031        0.0459125   0.0400791    0.0825123
 -0.00245358    1.14971      …   0.116711    0.090982     0.652924
 -0.00774814    0.112581         0.0656814   0.05

In [62]:
H_b

97×97×2 Array{Float64, 3}:
[:, :, 1] =
  6.93396e-9   -0.000139657   3.91393e-5  …  -9.19238e-6   -7.64687e-5
 -0.000139657   2.81283      -0.788305        0.185144      1.54016
  3.91393e-5   -0.788305      0.220925       -0.0518872    -0.431634
 -0.000126056   2.5389       -0.711535        0.167113      1.39017
  2.3104e-6    -0.0465338     0.0130413      -0.00306291   -0.0254795
 -7.11912e-5    1.43386      -0.401845    …   0.0943785     0.785107
 -4.48464e-5    0.90325      -0.253139        0.059453      0.494572
 -4.58161e-5    0.922782     -0.258613        0.0607387     0.505267
 -2.72219e-5    0.548276     -0.153656        0.0360882     0.300207
 -9.95429e-5    2.00489      -0.561878        0.131965      1.09777
  ⋮                                       ⋱   ⋮            
 -7.61476e-5    1.53369      -0.429822        0.100949      0.839768
 -2.2915e-6     0.0461531    -0.0129346       0.00303785    0.025271
 -5.70551e-5    1.14915      -0.322052    …   0.0756382     0.629212
 -5.

In [63]:
@assert isapprox(H, H_b[:, :, 1] + H_b[:, :, 2], atol=.05)

## Sanity check

In [64]:
data

zip([[6.890818797403573], [0.26043975403126574], [3.531508829900008], [1.1006431286934388], [7.444971612833461], [6.9055480307075765], [7.096606595920804], [7.49290808338976], [7.886053029307477], [1.9954021528253456]  …  [1.0126553374571472], [3.1642602669955346], [5.1392679687161955], [2.308002086996864], [7.878664260943012], [6.128956950142811], [2.8510610517270507], [2.500786700142399], [6.627403099097526], [6.563289310410638]], [0.01548584004498288, 0.0650502542475323, -0.466806016348434, 1.0643227230458363, 0.8084199374169351, 0.3694085963699799, 0.8605000431231095, 1.217079975762081, 0.8372465451682657, 1.0274790781864027  …  0.616120436749096, 0.3727942823307355, -1.245267664673427, 0.6760783565161077, 1.5992073824265285, -0.9113091145217157, 0.11834121182959989, 0.3930714012847796, 0.2987255761617808, 0.9730906195136584])

In [65]:
function fit_la_unbatched(nn, data, X, y)
    la = Laplace(nn; likelihood=:regression, λ=λ, subset_of_weights=:all)
    fit!(la, data)
    plot(la, X, y)
end

fit_la_unbatched (generic function with 1 method)

In [66]:
dataloader_128 = DataLoader((X, Y), batchsize=128)

128-element DataLoader(::Tuple{Matrix{Float64}, Matrix{Float64}}, batchsize=128)
  with first element:
  (1×128 Matrix{Float64}, 1×128 Matrix{Float64},)

In [67]:
dataloader_1 = DataLoader((X, Y), batchsize=1)

16384-element DataLoader(::Tuple{Matrix{Float64}, Matrix{Float64}})
  with first element:
  (1×1 Matrix{Float64}, 1×1 Matrix{Float64},)

In [68]:
dataloader_2 = DataLoader((X, Y), batchsize=2)

8192-element DataLoader(::Tuple{Matrix{Float64}, Matrix{Float64}}, batchsize=2)
  with first element:
  (1×2 Matrix{Float64}, 1×2 Matrix{Float64},)

In [69]:
function fit_la_batched(nn, dataloader, X, y)
    la_b = Laplace(nn; likelihood=:regression, λ=λ, subset_of_weights=:all)
    fit!(la_b, dataloader)
    plot(la_b, X, y)
end

fit_la_batched (generic function with 1 method)

In [70]:
isapprox(la_b.H, la.H, atol=.005)

UndefVarError: UndefVarError: `la_b` not defined

In [71]:
length(dataloader)

8192

In [72]:
@benchmark fit_la_unbatched($nn, $data, $X, $y)

BenchmarkTools.Trial: 6 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m761.406 ms[22m[39m … [35m948.132 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m14.65% … 14.59%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m855.867 ms               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m14.76%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m853.590 ms[22m[39m ± [32m 59.339 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m14.74% ±  0.36%

  [39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▁[39m [39m [39m [32m▁[39m[34m█[39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▁[39m [39m 
  [39m█[39m▁[39m▁[

In [73]:
@benchmark fit_la_batched($nn, $dataloader_128, $X, $y)

BenchmarkTools.Trial: 137 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m26.382 ms[22m[39m … [35m56.091 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 16.39%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m36.550 ms              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m9.72%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m36.700 ms[22m[39m ± [32m 5.110 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m7.98% ±  4.73%

  [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▃[39m [39m [39m▁[39m▆[39m [39m▃[39m▃[39m▁[39m█[39m [39m▁[39m [39m [34m▆[39m[39m [39m▁[39m [39m▄[39m▃[39m [39m [39m [39m▄[39m▃[39m [39m▃[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▄[39m▄[39m▁[39m▁[39m▆[39m▁[39m▆[

In [74]:
@benchmark fit_la_batched($nn, $dataloader_1, $X, $y)

BenchmarkTools.Trial: 6 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m793.336 ms[22m[39m … [35m886.060 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m14.24% … 14.78%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m864.310 ms               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m14.04%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m849.946 ms[22m[39m ± [32m 35.433 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m14.27% ±  0.45%

  [39m█[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m█[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [34m█[39m[39m [39m [39m█[39m [39m [39m█[39m [39m [39m [39m [39m [39m [39m [39m [39m█[39m [39m 
  [39m█[39m▁[39m▁[

In [75]:
@benchmark fit_la_batched($nn, $dataloader_2, $X, $y)

BenchmarkTools.Trial: 12 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m361.541 ms[22m[39m … [35m504.637 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m16.55% … 13.99%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m416.384 ms               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m14.60%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m430.859 ms[22m[39m ± [32m 51.180 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m14.38% ±  1.11%

  [39m█[39m [39m█[39m [39m [39m [39m [39m█[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m█[39m [39m█[34m█[39m[39m [39m [39m [39m█[39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m█[39m [39m [39m [39m█[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m█[39m [39m [39m [39m█[39m█[39m [39m 
  [39m█[39m▁[39m█

In [76]:
# Benchmarking with globals

function run_unbatched()
    la = Laplace(nn; likelihood=:regression, λ=λ, subset_of_weights=:all)
    fit!(la, data)
    plot(la, X, y)
end

function run_batched()
    la_b = Laplace(nn; likelihood=:regression, λ=λ, subset_of_weights=:all)
    fit!(la_b, dataloader_128)
    plot(la_b, X, y)
end

run_batched (generic function with 1 method)

In [80]:
@benchmark run_unbatched()

BenchmarkTools.Trial: 7 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m747.493 ms[22m[39m … [35m885.277 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m13.99% … 14.70%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m828.813 ms               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m14.36%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m821.796 ms[22m[39m ± [32m 53.276 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m14.38% ±  0.27%

  [39m█[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m█[39m [39m [34m█[39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [32m [39m[39m [39m [39m█[39m [39m [39m [39m [39m [39m [39m [39m [39m█[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m█[39m [39m [39m█[39m [39m 
  [39m█[39m▁[39m▁[

In [81]:
@benchmark run_batched()

BenchmarkTools.Trial: 130 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m29.107 ms[22m[39m … [35m53.286 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 8.66%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m38.712 ms              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m9.34%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m38.657 ms[22m[39m ± [32m 4.578 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m7.51% ± 4.33%

  [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▁[39m [39m [39m▄[39m▅[39m▂[39m [39m [39m [39m [39m [39m [39m [39m [34m▁[39m[39m [39m▂[39m▄[39m [39m [39m▁[39m█[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▃[39m▃[39m▁[39m▁[39m▃[39m▁[39m▅[39