## Batched Gradients

In [62]:
using Pkg
Pkg.activate(".")
using LaplaceRedux
using LaplaceRedux.Curvature
using LaplaceRedux.Data
using Flux
using Flux.Optimise: update!, Adam
using Plots
using Statistics
using MLUtils
using Zygote
using Printf
using NNlib
using BenchmarkTools
using Tullio

[32m[1m  Activating[22m[39m project at `~/Builds/navimakarov/LaplaceRedux.jl/dev/notebooks/batching`


In [63]:
# Build a small toy regression dataset (same generator as the tests) and keep it in
# three forms: per-sample pairs (`data`), matrices (`X`, `Y`) and a mini-batch loader.

n = 128      # number of samples requested from the generator
bsize = 2    # mini-batch size used by `dataloader`
data_dict = Dict()

x, y = LaplaceRedux.Data.toy_data_regression(n)

# Each input wrapped in its own 1-element vector, for per-sample model calls.
xs = [[xi] for xi in x]

# Inputs and targets concatenated horizontally for the batched loader.
X = reduce(hcat, x)
Y = reduce(hcat, y)

dataloader = DataLoader((X, Y); batchsize=bsize)
data = zip(xs, y)

# Everything the training cell needs, stored under the task name.
data_dict[:regression] = Dict(
    :data => data,
    :X => X,
    :y => y,
    :outdim => 1,
    :loss_fun => :mse,
    :likelihood => :regression,
)

Dict{Symbol, Any} with 6 entries:
  :loss_fun   => :mse
  :y          => [0.674481, 0.476305, 1.0961, 1.50355, 0.414426, 0.221046, 0.72…
  :likelihood => :regression
  :X          => [2.02611 2.39063 … 6.35966 4.52948]
  :outdim     => 1
  :data       => zip([[2.02611], [2.39063], [1.3071], [7.45034], [2.85453], [2.…

In [64]:
# Train a NN model

val = data_dict[:regression]

# Unpack:
data = val[:data]
X = val[:X]
y = val[:y]
outdim = val[:outdim]
loss_fun = val[:loss_fun]
likelihood = val[:likelihood]

# Neural network: one sigmoid hidden layer followed by a linear output layer.
n_hidden = 32
D = size(X, 1)
nn = Chain(Dense(D, n_hidden, σ), Dense(n_hidden, outdim))

# Training objective: the Flux loss selected by `loss_fun` plus an L2 weight penalty.
λ = 0.01
sqnorm(x) = sum(abs2, x)
weight_regularization(λ=λ) = 1 / 2 * λ^2 * sum(sqnorm, Flux.params(nn))
loss(x, y) = getfield(Flux.Losses, loss_fun)(nn(x), y) + weight_regularization()

opt = Adam()
epochs = 200
avg_loss(data) = mean(map(d -> loss(d[1], d[2]), data))

# Report roughly ten times over the run. Integer division keeps the modulo test below
# exact for any `epochs`; `max` guards against a zero divisor when `epochs < 10`.
show_every = max(1, epochs ÷ 10)

# The parameter collection does not change between steps, so build it once.
ps = Flux.params(nn)

# NOTE(review): `data` yields Float64 inputs while the layers hold Float32 weights
# (Flux warns about the conversion on the first call) — consider converting the data.
for epoch in 1:epochs
    # One optimiser step per individual sample.
    for d in data
        gs = gradient(ps) do
            loss(d...)
        end
        update!(opt, ps, gs)
    end
    if epoch % show_every == 0
        println("Epoch " * string(epoch))
        @show avg_loss(data)
    end
end


[33m[1m│ [22m[39m  The input will be converted, but any earlier layers may be very slow.
[33m[1m│ [22m[39m  layer = Dense(1 => 32, σ)   [90m# 64 parameters[39m
[33m[1m│ [22m[39m  summary(x) = "1-element Vector{Float64}"
[33m[1m└ [22m[39m[90m@ Flux ~/.julia/packages/Flux/EHgZm/src/layers/stateless.jl:60[39m


Epoch 20
avg_loss(data) = 0.47340725422171465
Epoch 40
avg_loss(data) = 0.4534691405424245
Epoch 60
avg_loss(data) = 0.43944693575348115
Epoch 80
avg_loss(data) = 0.42765904204060134
Epoch 100
avg_loss(data) = 0.4159723891351543
Epoch 120
avg_loss(data) = 0.40304985023726797
Epoch 140
avg_loss(data) = 0.3883432639032817
Epoch 160
avg_loss(data) = 0.36502531939915644
Epoch 180
avg_loss(data) = 0.3273398871608472
Epoch 200
avg_loss(data) = 0.280764489288256


In [65]:
model = nn

Chain(
  Dense(1 => 32, σ),                    [90m# 64 parameters[39m
  Dense(32 => 1),                       [90m# 33 parameters[39m
) [90m                  # Total: 4 arrays, [39m97 parameters, 644 bytes.

In [66]:
# Regression loss: squared error between the prediction of the global `model` at `x`
# and the target `ytrue`, reduced with `agg` (a plain sum unless overridden).
function loss_fn(x, ytrue; agg=sum)
    prediction = model(x)
    return Flux.Losses.mse(prediction, ytrue; agg=agg)
end

loss_fn (generic function with 1 method)

In [7]:
dataloader

LoadError: UndefVarError: `dataloader` not defined

In [67]:
# Take one datapoint of the dataset
x_1 = xs[1]
y_1 = y[1]

0.6744812875199306

In [68]:
loss_fn(x_1, y_1)

0.07596102806142857

In [69]:
grads_1 = gradient(() -> loss_fn(x_1, y_1), Flux.params(model))
grads = grads_1

Grads(...)

In [70]:
grads.grads

IdDict{Any, Any} with 6 entries:
  Float32[0.314995, -0.498… => Float32[0.0788479, -0.0320896, -0.132364, -0.119…
  Float32[0.0238793]        => Float32[-0.551221]
  :(Main.y_1)               => 0.551221
  :(Main.x_1)               => Float32[0.305503]
  Float32[-0.574315 0.2331… => Float32[-0.258767 -0.266765 … -0.377837 -0.39490…
  Float32[-0.215867; 0.214… => Float32[0.159754; -0.065017; … ; 0.0460124; 0.07…

In [71]:
grads.params

Params([Float32[-0.215867; 0.21430783; … ; 0.60598916; 0.5763072;;], Float32[0.31499466, -0.4984213, 1.7169719, 1.4942384, -0.044635363, 1.6780366, -0.655377, -0.09566895, 0.062958434, -1.4049017  …  0.29456976, -0.5352677, -0.3875355, -1.7812886, 0.23457813, -0.38505232, -1.9335868, 0.20516092, -0.44884127, -0.24087419], Float32[-0.57431465 0.23310256 … -0.1910847 -0.34639055], Float32[0.02387928]])

In [72]:
length(grads.params)

4

In [73]:
sum(length, grads.params) # 97 params total

97

In [74]:
fieldnames(Params)

(:order, :params)

In [75]:
grads.params.order

Zygote.Buffer{Any, Vector{Any}}(Any[Float32[-0.215867; 0.21430783; … ; 0.60598916; 0.5763072;;], Float32[0.31499466, -0.4984213, 1.7169719, 1.4942384, -0.044635363, 1.6780366, -0.655377, -0.09566895, 0.062958434, -1.4049017  …  0.29456976, -0.5352677, -0.3875355, -1.7812886, 0.23457813, -0.38505232, -1.9335868, 0.20516092, -0.44884127, -0.24087419], Float32[-0.57431465 0.23310256 … -0.1910847 -0.34639055], Float32[0.02387928]], false)

In [76]:
th_1 = grads.params[1]

32×1 Matrix{Float32}:
 -0.215867
  0.21430783
 -0.8522523
 -0.8058787
 -0.55217594
 -0.8438247
  0.64416796
 -0.58848274
 -0.20825455
  0.7841481
  0.21280028
 -0.21573345
 -0.79292846
  ⋮
  0.19885172
  0.2104885
 -0.2148419
  0.21682474
  0.20893377
  0.86282456
 -0.21216859
  0.21194375
  0.8965222
 -0.2127403
  0.60598916
  0.5763072

In [77]:
# This is the gradient of the loss function value at (x_1, y_1) for th_1
# Size 32
grads.grads[th_1]

32×1 Matrix{Float32}:
  0.15975425
 -0.06501701
 -0.26818433
 -0.24190344
 -0.014645586
 -0.3032601
  0.10062788
  0.0010204703
  0.09501783
  0.078615464
 -0.07067647
  0.15351215
 -0.22601257
  ⋮
 -0.07786055
 -0.16788235
  0.16065204
 -0.06647736
 -0.07539779
  0.05932422
  0.15040717
 -0.13158734
  0.06730304
  0.11090459
  0.046012405
  0.0785949

In [78]:
# The same for x_2, y_2
x_2 = xs[2]
y_2 = y[2]
grads_2 = gradient(() -> loss_fn(x_2, y_2), Flux.params(model))
grads_2.grads[th_1]

32×1 Matrix{Float32}:
  0.19042273
 -0.078068666
 -0.31359237
 -0.27836326
 -0.015698383
 -0.35359997
  0.11078293
  0.001075707
  0.11229756
  0.09004506
 -0.08475374
  0.18294053
 -0.2589492
  ⋮
 -0.09289562
 -0.20045124
  0.19136347
 -0.07992181
 -0.09019053
  0.069791876
  0.17880549
 -0.15734246
  0.08012773
  0.13168693
  0.050335698
  0.08534028

In [79]:
# Now take one batch combining the two
x_b, y_b = popfirst!(Iterators.Stateful(dataloader))

([2.0261055314447827 2.3906266989746072], [0.6744812875199306 0.47630528861687577])

In [80]:
@show x_1
@show x_2
@show y_1
@show y_2

x_1 = [2.0261055314447827]
x_2 = [2.3906266989746072]
y_1 = 0.6744812875199306
y_2 = 0.47630528861687577


0.47630528861687577

In [81]:
# Take the *Jacobian* to obtain the gradients for the batch parts
grads_b = jacobian(() -> loss_fn(x_b, y_b, agg=identity), Flux.params(model))

Grads(...)

In [82]:
grads_b.grads[th_1]

2×32 Matrix{Float64}:
 0.159754  -0.065017   -0.268184  …  0.110905  0.0460124  0.0785949
 0.190423  -0.0780687  -0.313592     0.131687  0.0503357  0.0853403

In [83]:
th = grads_b.params

Params([Float32[-0.215867; 0.21430783; … ; 0.60598916; 0.5763072;;], Float32[0.31499466, -0.4984213, 1.7169719, 1.4942384, -0.044635363, 1.6780366, -0.655377, -0.09566895, 0.062958434, -1.4049017  …  0.29456976, -0.5352677, -0.3875355, -1.7812886, 0.23457813, -0.38505232, -1.9335868, 0.20516092, -0.44884127, -0.24087419], Float32[-0.57431465 0.23310256 … -0.1910847 -0.34639055], Float32[0.02387928]])

In [84]:
grads_b.params == grads_1.params == grads_2.params

true

In [85]:
grads_b[th_1]

2×32 Matrix{Float64}:
 0.159754  -0.065017   -0.268184  …  0.110905  0.0460124  0.0785949
 0.190423  -0.0780687  -0.313592     0.131687  0.0503357  0.0853403

In [86]:
grads

Grads(...)

In [87]:
grads_1[th_1]'

1×32 adjoint(::Matrix{Float32}) with eltype Float32:
 0.159754  -0.065017  -0.268184  -0.241903  …  0.110905  0.0460124  0.0785949

In [88]:
grads_2[th_1]'

1×32 adjoint(::Matrix{Float32}) with eltype Float32:
 0.190423  -0.0780687  -0.313592  …  0.131687  0.0503357  0.0853403

In [30]:
# grads_1 .* transpose.(grads_1)

In [89]:
sum(length, grads_1)

97

In [90]:
sum(length, grads_1.params)

97

In [91]:
sum(length, grads_b)

194

In [92]:
# We would like to obtain g as a 97x2 matrix
# containing as the first column the gradient for loss_fn(x_1, y_1) wrt theta (of size 97)
# and as the second column, one for loss_fn(x_2, y_2)
g = permutedims(reduce(hcat, grads_b))

97×2 Matrix{Float64}:
  0.159754     0.190423
 -0.065017    -0.0780687
 -0.268184    -0.313592
 -0.241903    -0.278363
 -0.0146456   -0.0156984
 -0.30326     -0.3536
  0.100628     0.110783
  0.00102047   0.00107571
  0.0950178    0.112298
  0.0786154    0.090045
 -0.0706765   -0.0847537
  0.153512     0.18294
 -0.226013    -0.258949
  ⋮           
 -0.287043    -0.302533
 -0.25625     -0.249635
 -0.262397    -0.27783
 -0.280541    -0.295868
 -0.271047    -0.319367
 -0.248782    -0.242239
 -0.281723    -0.297219
 -0.259486    -0.309467
 -0.244615    -0.238014
 -0.377837    -0.40967
 -0.394907    -0.424281
 -0.55122     -0.560401

In [117]:
# Per-sample loss gradients for the batch `(x, y)`, returned as one matrix with a row
# per model parameter and a column per loss value.
# Presumably each column of `x`/`y` is one sample, as produced by the DataLoader above
# — TODO confirm for other input layouts.
function gradient_helper(model, x, y)
    ps = Flux.params(model)
    # Evaluate the loss on the `model` argument itself (not on a global) so that the
    # Jacobian is taken for the same network whose parameters are in `ps`.
    # `agg=identity` keeps the element-wise losses instead of reducing them.
    grads = jacobian(() -> Flux.Losses.mse(model(x), y; agg=identity), ps)
    # Each per-parameter-array Jacobian has one row per loss value; concatenating them
    # side by side and transposing yields the (parameters × samples) layout.
    return permutedims(reduce(hcat, grads))
end

gradient_helper (generic function with 1 method)

In [93]:
# sanity check
grads_b[th[1]]

2×32 Matrix{Float64}:
 0.159754  -0.065017   -0.268184  …  0.110905  0.0460124  0.0785949
 0.190423  -0.0780687  -0.313592     0.131687  0.0503357  0.0853403

In [94]:
grads_b[th[length(th)-1]]

2×32 Matrix{Float64}:
 -0.258767  -0.266765  -0.274262  …  -0.244615  -0.377837  -0.394907
 -0.252126  -0.282149  -0.235686     -0.238014  -0.40967   -0.424281

In [95]:
grads_b[th[length(th)]] # sanity: checked

2×1 Matrix{Float64}:
 -0.5512204170227051
 -0.5604010224342346

In [96]:
g

97×2 Matrix{Float64}:
  0.159754     0.190423
 -0.065017    -0.0780687
 -0.268184    -0.313592
 -0.241903    -0.278363
 -0.0146456   -0.0156984
 -0.30326     -0.3536
  0.100628     0.110783
  0.00102047   0.00107571
  0.0950178    0.112298
  0.0786154    0.090045
 -0.0706765   -0.0847537
  0.153512     0.18294
 -0.226013    -0.258949
  ⋮           
 -0.287043    -0.302533
 -0.25625     -0.249635
 -0.262397    -0.27783
 -0.280541    -0.295868
 -0.271047    -0.319367
 -0.248782    -0.242239
 -0.281723    -0.297219
 -0.259486    -0.309467
 -0.244615    -0.238014
 -0.377837    -0.40967
 -0.394907    -0.424281
 -0.55122     -0.560401

In [97]:
vg = g * transpose(g)

97×97 Matrix{Float64}:
  0.0617822    -0.0252528    -0.102559     …  -0.143881     -0.194773
 -0.0252528     0.0103219     0.0419183        0.0587987     0.0795884
 -0.102559      0.0419183     0.170263         0.238959      0.323566
 -0.0916518     0.0374593     0.152167         0.213634      0.289337
 -0.00532902    0.00217776    0.00885061       0.0124442     0.0168703
 -0.115781      0.0473221     0.192216     …   0.269785      0.365321
  0.0371713    -0.0151912    -0.0617275       -0.0867417    -0.117551
  0.000367863  -0.000150327  -0.000611008     -0.000859393  -0.00116533
  0.0365635    -0.0149447    -0.0606979       -0.0851689    -0.115307
  0.0297058    -0.012141     -0.0493209       -0.0692502    -0.0937958
 -0.0274299     0.0112118     0.0455324    …   0.06387       0.0864544
  0.0593602    -0.0242628    -0.0985383       -0.138241     -0.187139
 -0.0854163     0.0349105     0.141817         0.199121      0.269698
  ⋮                                        ⋱   ⋮            


In [98]:
foreach(x -> @show(size(x)), grads)

size(x) = (32, 1)
size(x) = (32,)
size(x) = (1, 32)
size(x) = (1,)


In [99]:
# Flatten each single-sample gradient into one long vector, following the order of
# the parameter arrays.
g_1 = mapreduce(p -> vec(grads_1[p]), vcat, grads_1.params)
g_2 = mapreduce(p -> vec(grads_2[p]), vcat, grads_2.params)
# v(g) denotes the outer product g * g'
vg_1 = g_1 * g_1'
vg_2 = g_2 * g_2'

97×97 Matrix{Float32}:
  0.0362608    -0.014866    -0.0597151    …  -0.0807928    -0.106713
 -0.014866      0.00609472   0.0244817        0.0331231     0.0437498
 -0.0597151     0.0244817    0.0983402        0.133051      0.175738
 -0.0530067     0.0217314    0.0872926        0.118104      0.155995
 -0.00298933    0.00122555   0.00492289       0.00666053    0.00879739
 -0.0673335     0.0276051    0.110886     …   0.150026      0.198158
  0.0210956    -0.00864868  -0.0347407       -0.0470031    -0.0620829
  0.000204839  -8.3979f-5   -0.000337334     -0.000456402  -0.000602827
  0.021384     -0.00876692  -0.0352157       -0.0476458    -0.0629317
  0.0171466    -0.0070297   -0.0282374       -0.0382044    -0.0504613
 -0.016139      0.00661661   0.0265781    …   0.0359594     0.0474961
  0.034836     -0.0142819   -0.0573688       -0.0776182    -0.10252
 -0.0493098     0.0202158    0.0812045        0.109867      0.145115
  ⋮                                       ⋱   ⋮            
 -0.0576091

In [100]:
isapprox(vg_1 + vg_2, vg, atol=.0005) 
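# This is not what we want: `vg` is the *sum* of the per-sample outer products.
# We want them kept separate, as a 97x97x2 array (one 97x97 slice per sample).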

true

In [101]:
# Reshape g as 97x1x2 since `batched_mul` (from NNlib) expects 3-d arrays
gm = reshape(g, size(g, 1), 1, size(g, 2))

97×1×2 Array{Float64, 3}:
[:, :, 1] =
  0.15975421667099
 -0.06501699239015579
 -0.26818427443504333
 -0.2419033646583557
 -0.014645582996308804
 -0.30326002836227417
  0.1006278395652771
  0.0010204700520262122
  0.09501780569553375
  0.0786154493689537
 -0.07067645341157913
  0.153512105345726
 -0.22601251304149628
  ⋮
 -0.2870430648326874
 -0.2562498450279236
 -0.2623968720436096
 -0.28054124116897583
 -0.2710472047328949
 -0.24878238141536713
 -0.2817233204841614
 -0.2594864070415497
 -0.2446153610944748
 -0.3778369426727295
 -0.39490675926208496
 -0.5512204170227051

[:, :, 2] =
  0.19042271375656128
 -0.07806865870952606
 -0.3135923743247986
 -0.2783631980419159
 -0.015698380768299103
 -0.3535999357700348
  0.1107829213142395
  0.0010757070267573
  0.11229754239320755
  0.09004504978656769
 -0.08475372940301895
  0.1829404979944229
 -0.2589491605758667
  ⋮
 -0.3025325536727905
 -0.2496354728937149
 -0.27782997488975525
 -0.2958679497241974
 -0.319366991519928
 -0.2422387450933456

In [102]:
g

97×2 Matrix{Float64}:
  0.159754     0.190423
 -0.065017    -0.0780687
 -0.268184    -0.313592
 -0.241903    -0.278363
 -0.0146456   -0.0156984
 -0.30326     -0.3536
  0.100628     0.110783
  0.00102047   0.00107571
  0.0950178    0.112298
  0.0786154    0.090045
 -0.0706765   -0.0847537
  0.153512     0.18294
 -0.226013    -0.258949
  ⋮           
 -0.287043    -0.302533
 -0.25625     -0.249635
 -0.262397    -0.27783
 -0.280541    -0.295868
 -0.271047    -0.319367
 -0.248782    -0.242239
 -0.281723    -0.297219
 -0.259486    -0.309467
 -0.244615    -0.238014
 -0.377837    -0.40967
 -0.394907    -0.424281
 -0.55122     -0.560401

In [103]:
vgm = batched_mul(gm, batched_transpose(gm))

97×97×2 Array{Float64, 3}:
[:, :, 1] =
  0.0255214    -0.0103867    -0.0428436    …  -0.063088     -0.0880598
 -0.0103867     0.00422721    0.0174365        0.0256756     0.0358387
 -0.0428436     0.0174365     0.0719228        0.105908      0.147829
 -0.0386451     0.0157278     0.0648747        0.0955293     0.133342
 -0.00233969    0.000952212   0.00392772       0.00578364    0.00807294
 -0.0484471     0.0197171     0.0813296    …   0.119759      0.167163
  0.0160757    -0.00654252   -0.0269868       -0.0397386    -0.0554681
  0.000163024  -6.63479e-5   -0.000273674     -0.000402991  -0.000562504
  0.0151795    -0.00617777   -0.0254823       -0.0375232    -0.0523758
  0.0125591    -0.00511134   -0.0210834       -0.0310458    -0.0433344
 -0.0112909     0.00459517    0.0189543    …   0.0279106     0.0389583
  0.0245242    -0.0099809    -0.0411695       -0.060623     -0.084619
 -0.0361065     0.0146947     0.060613         0.0892539     0.124583
  ⋮                                     

In [104]:
@assert isapprox(vgm[:,:,1], vg_1, atol=.0005)
@assert isapprox(vgm[:,:,2], vg_2, atol=.0005)

In [105]:
vg_2

97×97 Matrix{Float32}:
  0.0362608    -0.014866    -0.0597151    …  -0.0807928    -0.106713
 -0.014866      0.00609472   0.0244817        0.0331231     0.0437498
 -0.0597151     0.0244817    0.0983402        0.133051      0.175738
 -0.0530067     0.0217314    0.0872926        0.118104      0.155995
 -0.00298933    0.00122555   0.00492289       0.00666053    0.00879739
 -0.0673335     0.0276051    0.110886     …   0.150026      0.198158
  0.0210956    -0.00864868  -0.0347407       -0.0470031    -0.0620829
  0.000204839  -8.3979f-5   -0.000337334     -0.000456402  -0.000602827
  0.021384     -0.00876692  -0.0352157       -0.0476458    -0.0629317
  0.0171466    -0.0070297   -0.0282374       -0.0382044    -0.0504613
 -0.016139      0.00661661   0.0265781    …   0.0359594     0.0474961
  0.034836     -0.0142819   -0.0573688       -0.0776182    -0.10252
 -0.0493098     0.0202158    0.0812045        0.109867      0.145115
  ⋮                                       ⋱   ⋮            
 -0.0576091

In [106]:
g

97×2 Matrix{Float64}:
  0.159754     0.190423
 -0.065017    -0.0780687
 -0.268184    -0.313592
 -0.241903    -0.278363
 -0.0146456   -0.0156984
 -0.30326     -0.3536
  0.100628     0.110783
  0.00102047   0.00107571
  0.0950178    0.112298
  0.0786154    0.090045
 -0.0706765   -0.0847537
  0.153512     0.18294
 -0.226013    -0.258949
  ⋮           
 -0.287043    -0.302533
 -0.25625     -0.249635
 -0.262397    -0.27783
 -0.280541    -0.295868
 -0.271047    -0.319367
 -0.248782    -0.242239
 -0.281723    -0.297219
 -0.259486    -0.309467
 -0.244615    -0.238014
 -0.377837    -0.40967
 -0.394907    -0.424281
 -0.55122     -0.560401

In [107]:
# Now try einsums
using Tullio

In [108]:
@tullio H[i, j, b] := g[i, b] * g[j, b]

97×97×2 Array{Float64, 3}:
[:, :, 1] =
  0.0255214    -0.0103867    -0.0428436    …  -0.063088     -0.0880598
 -0.0103867     0.00422721    0.0174365        0.0256756     0.0358387
 -0.0428436     0.0174365     0.0719228        0.105908      0.147829
 -0.0386451     0.0157278     0.0648747        0.0955293     0.133342
 -0.00233969    0.000952212   0.00392772       0.00578364    0.00807294
 -0.0484471     0.0197171     0.0813296    …   0.119759      0.167163
  0.0160757    -0.00654252   -0.0269868       -0.0397386    -0.0554681
  0.000163024  -6.63479e-5   -0.000273674     -0.000402991  -0.000562504
  0.0151795    -0.00617777   -0.0254823       -0.0375232    -0.0523758
  0.0125591    -0.00511134   -0.0210834       -0.0310458    -0.0433344
 -0.0112909     0.00459517    0.0189543    …   0.0279106     0.0389583
  0.0245242    -0.0099809    -0.0411695       -0.060623     -0.084619
 -0.0361065     0.0146947     0.060613         0.0892539     0.124583
  ⋮                                     

In [109]:
@assert isapprox(vgm, H, atol=eps())

In [110]:
eps()

2.220446049250313e-16

## Benchmarks for batched computation of H = 𝐠 * 𝐠'

In [111]:
using BenchmarkTools

In [112]:
# Method 1: reshape + batched matrix multiplication (`batched_mul`)
# Turns every column of `g` into an (n × 1) slice of a 3-d array and multiplies each
# slice by its own transpose, giving one outer product per column.
function method_1(g)
    columns = reshape(g, size(g, 1), 1, size(g, 2))
    return batched_mul(columns, batched_transpose(columns))
end

method_1 (generic function with 1 method)

In [113]:
# Method 2: Tullio Einstein summation
# Slice `s` of the result is the outer product of column `s` of `g` with itself;
# no index is summed over, so every column keeps its own slice.
function method_2(g)
    @tullio outer[r, c, s] := g[r, s] * g[c, s]
    return outer
end

method_2 (generic function with 1 method)

In [114]:
@benchmark method_1($g)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m 49.388 μs[22m[39m … [35m 10.842 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 97.90%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m 59.316 μs               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m125.341 μs[22m[39m ± [32m464.658 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m19.16% ±  5.34%

  [39m▆[34m█[39m[39m▃[39m▄[39m▄[39m▂[39m▁[39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▂[39m▂[39m▂[39m▁[39m▁[39m [39m▁[39m▁[39m▁[39m▁[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▂
  [39m█[34m█[39

In [115]:
@benchmark method_2($g)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m36.671 μs[22m[39m … [35m  7.133 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 93.05%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m43.790 μs               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m68.586 μs[22m[39m ± [32m347.323 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m27.82% ±  5.47%

  [39m▅[34m█[39m[39m▅[39m▄[39m▃[39m▂[32m▁[39m[39m▁[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▁
  [39m█[34m█[39m[39m█[

In [119]:
# More benchmarks:
# batchsize = 128
dataloader_128 = DataLoader((X, Y); batchsize=128)
# Only the first batch of the loader is needed.
x_128, y_128 = first(dataloader_128)
# Gradient matrix for that batch, used as input to the benchmarks below.
g_128 = gradient_helper(model, x_128, y_128)

97×128 Matrix{Float64}:
  0.159754     0.190423     0.135743    …   0.227085     -0.778728
 -0.065017    -0.0780687   -0.0544587      -0.100478      0.333144
 -0.268184    -0.313592    -0.207638       -0.046609      0.545158
 -0.241903    -0.278363    -0.196131       -0.0452416     0.489152
 -0.0146456   -0.0156984   -0.0148095      -0.00403204    0.0294513
 -0.30326     -0.3536      -0.236733    …  -0.0534575     0.616227
  0.100628     0.110783     0.0937034       0.0248455    -0.206268
  0.00102047   0.00107571   0.00106327      0.000229966  -0.00182176
  0.0950178    0.112298     0.0821739       0.124926     -0.439863
  0.0786154    0.090045     0.0647664       0.0154556    -0.160701
 -0.0706765   -0.0847537   -0.0593582   …  -0.10792       0.359276
  0.153512     0.18294      0.130499        0.217723     -0.747227
 -0.226013    -0.258949    -0.185587       -0.043134      0.456392
  ⋮                                     ⋱                
 -0.287043    -0.302533    -0.349393    …  -

In [123]:
g

97×2 Matrix{Float64}:
  0.159754     0.190423
 -0.065017    -0.0780687
 -0.268184    -0.313592
 -0.241903    -0.278363
 -0.0146456   -0.0156984
 -0.30326     -0.3536
  0.100628     0.110783
  0.00102047   0.00107571
  0.0950178    0.112298
  0.0786154    0.090045
 -0.0706765   -0.0847537
  0.153512     0.18294
 -0.226013    -0.258949
  ⋮           
 -0.287043    -0.302533
 -0.25625     -0.249635
 -0.262397    -0.27783
 -0.280541    -0.295868
 -0.271047    -0.319367
 -0.248782    -0.242239
 -0.281723    -0.297219
 -0.259486    -0.309467
 -0.244615    -0.238014
 -0.377837    -0.40967
 -0.394907    -0.424281
 -0.55122     -0.560401

In [120]:
@benchmark method_1($g_128) 

BenchmarkTools.Trial: 579 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m2.888 ms[22m[39m … [35m29.510 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 19.11%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m3.354 ms              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m8.627 ms[22m[39m ± [32m 8.521 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m19.00% ± 22.84%

  [39m█[34m█[39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▁[39m▁[39m▂[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▁[39m▂[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▂[39m [39m [39m 
  [39m█[34m█[39m[39m▆[39m▄[39m▁[39m▁[39

In [121]:
@benchmark method_2($g_128) 

BenchmarkTools.Trial: 670 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m2.314 ms[22m[39m … [35m27.435 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 21.58%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m2.915 ms              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m7.450 ms[22m[39m ± [32m 7.511 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m20.67% ± 23.42%

  [39m█[39m█[34m▃[39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▁[39m [39m▁[39m [39m [39m [39m [39m [39m [39m▁[39m [39m [39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▁[39m [39m [39m [39m [39m [39m [39m▁[39m [39m 
  [39m█[39m█[34m█[39m[39m▆[39m▁[39m▁[39

In [124]:
@benchmark method_1($g_128) samples=10_000

BenchmarkTools.Trial: 540 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m2.902 ms[22m[39m … [35m35.119 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 18.32%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m3.562 ms              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m9.247 ms[22m[39m ± [32m 9.138 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m18.95% ± 22.85%

  [39m█[34m▅[39m[39m▄[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▁[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▂[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▁[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m█[34m█[39m[39m█[39m█[39m▅[39m▁[39

In [125]:
@benchmark method_2($g_128) samples=10_000

BenchmarkTools.Trial: 639 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m2.310 ms[22m[39m … [35m31.376 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 0.00% … 19.24%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m3.014 ms              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m7.795 ms[22m[39m ± [32m 7.830 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m20.68% ± 23.47%

  [39m▅[39m█[34m▃[39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▁[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▁[39m [39m [39m [39m [39m [39m 
  [39m█[39m█[34m█[39m[39m▆[39m▄[39m▁[39

### Benchmark conclusions
Method 2, with Tullio, is faster, probably because it makes fewer allocations. The effect is smaller for larger batch sizes: compare the almost two-fold difference in mean time for batch size 2 with the ~16% difference in mean time for batch size 128.
Tullio would also employ multi-threading for large enough matrices.

Machine: Thinkpad X1 Gen 6, Intel i7 8th Gen, Intel GPU