# Setup

In [1]:
using BenchmarkTools
using Flux
using Plots

In [2]:
# Repository root: one directory above the directory reported by @__DIR__.
repo_root = normpath(@__DIR__, "..")
# Load the local VariationalMLP module from source; it is brought into scope
# afterwards with `using .VariationalMLP`.
include(joinpath(repo_root, "VariationalMLP", "src", "VariationalMLP.jl"))

using .VariationalMLP

# Quick Demo

Based on: https://juliaci.github.io/BenchmarkTools.jl/stable/

In [3]:
# A small random Float64 vector to sort in the demo below.
data = rand(10)

10-element Vector{Float64}:
 0.07549001527315768
 0.12188035963654287
 0.7008360714296678
 0.4655713087716429
 0.4268167561015853
 0.7500596980461736
 0.5373464383141158
 0.42586681069202026
 0.7685771176491065
 0.11001477371649049

In [4]:
# Non-mutating sort: returns a new sorted vector and leaves `data` unchanged.
sort(data)

10-element Vector{Float64}:
 0.07549001527315768
 0.11001477371649049
 0.12188035963654287
 0.42586681069202026
 0.4268167561015853
 0.4655713087716429
 0.5373464383141158
 0.7008360714296678
 0.7500596980461736
 0.7685771176491065

Let's see how fast `sort` runs on a fresh 10-element random vector.

In [6]:
# `setup` draws a fresh 10-element vector for every sample. Inside the benchmarked
# expression `data` refers to that setup variable, not the global above.
@benchmark sort(data) setup=(data=rand(10))

BenchmarkTools.Trial: 10000 samples with 993 evaluations per sample.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m28.490 ns[22m[39m … [35m 3.313 μs[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 98.40%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m40.660 ns              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m45.364 ns[22m[39m ± [32m55.291 ns[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m9.26% ±  9.51%

  [39m [39m▅[39m█[34m▇[39m[32m▄[39m[39m▂[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▁
  [39m█[39m█[39m█[34m

In [7]:
# @btime prints the minimum time and allocation count, then returns the result of sin(x).
@btime sin(x) setup=(x=rand())

  3.625 ns (0 allocations: 0 bytes)


0.09945107896698284

In [8]:
# Random 3×3 matrix and its inverse. rand() gives no conditioning guarantee, so
# a badly conditioned draw is possible (unlikely for a 3×3).
A = rand(3,3)
A_inv = inv(A)
println(A)
println(A_inv)

[0.16960491226192365 0.21345033388990253 0.8214833343078708; 0.44049007481230285 0.07189574966377632 0.8905580783032689; 0.7500428033315432 0.2601684624573597 0.39359645907360974]
[-1.682132105978903 1.0727316803409799 1.0836317694275912; 4.090278049320793 -4.543570284887986 1.7434556524173581; 0.501807688240894 0.9591016627264686 -0.6767397938475697]


In [9]:
# Sanity check: should equal the identity up to floating-point rounding (~1e-16).
A * A_inv

3×3 Matrix{Float64}:
  1.0           0.0           0.0
 -5.55112e-17   1.0          -1.11022e-16
  3.33067e-16  -5.55112e-17   1.0

In [10]:
# `$A` interpolates the global into the benchmark, so the non-constant global
# lookup is not part of the measured time.
@btime inv($A)

  308.845 ns (7 allocations: 1.88 KiB)


3×3 Matrix{Float64}:
 -1.68213    1.07273    1.08363
  4.09028   -4.54357    1.74346
  0.501808   0.959102  -0.67674

# Benchmark loops

# Variational Dropout layers (Molchanov, Kingma, Graves)

In [11]:
# Float32 input to match the Float32 parameters the layers below are built with.
# Columns are samples (Flux convention): 300 features × 32 samples.
x = rand(Float32, 300, 32)  # input: 300 features, batch of 32

300×32 Matrix{Float32}:
 0.958314   0.933941    0.401141   …  0.152534    0.281412    0.894579
 0.120783   0.00399035  0.489561      0.172023    0.875931    0.996966
 0.105867   0.0485231   0.242761      0.00636274  0.120786    0.247303
 0.939628   0.0534546   0.73097       0.690218    0.642835    0.892169
 0.61906    0.358696    0.803592      0.176367    0.0289826   0.734983
 0.125696   0.769063    0.0430561  …  0.355827    0.519506    0.609749
 0.704685   0.734138    0.452008      0.220216    0.00714821  0.45026
 0.430373   0.278624    0.33018       0.766232    0.952734    0.869474
 0.349086   0.367057    0.45821       0.938272    0.386168    0.0920134
 0.0979801  0.519031    0.395636      0.904562    0.560142    0.383232
 0.230415   0.258155    0.60156    …  0.670963    0.0975702   0.303938
 0.021451   0.110649    0.102099      0.275709    0.973505    0.347845
 0.205095   0.44207     0.620219      0.584542    0.148073    0.734525
 ⋮                                 ⋱              ⋮  

In [12]:
# One deterministic baseline layer plus one variational layer per parameterisation,
# all mapping 300 -> 100 features with relu.
# Numbering in this section: 1 = Molchanov, 2 = Kingma, 3 = Graves.
# (The training section further down numbers Graves/Kingma the other way round.)
dense = Dense(300, 100, relu)
var_layer_1 = make_variational(300, 100; activation=relu, parameterisation=:molchanov)
var_layer_2 = make_variational(300, 100; activation=relu, parameterisation=:kingma)
var_layer_3 = make_variational(300, 100; activation=relu, parameterisation=:graves)

Main.VariationalMLP.VariationalDense{typeof(relu)}(Float32[0.48671097 0.28783625 … 0.7951985 -0.7215633; 0.5177368 0.42106435 … -0.7710353 0.8546423; … ; -1.3780183 -0.6903357 … 2.2115 0.13292938; 0.47789004 -0.12540379 … -0.20387489 -0.13186419], Float32[0.74457884 -2.4324222 … 2.0892978 -1.2033985; 0.4253627 -0.7595672 … 0.018620761 -0.61241436; … ; -2.560398 -1.1910667 … -1.4692762 -0.17510343; -0.6946142 -1.7486149 … 0.61208826 0.61407274], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], NNlib.relu)

In [13]:
# Display the baseline layer summary (shape and parameter count).
dense

Dense(300 => 100, relu)  [90m# 30_100 parameters[39m

In [14]:
# Display the Molchanov variational layer and its parameter arrays.
var_layer_1

VariationalDropoutMolchanov{typeof(relu)}(Float32[-0.43685755 0.08489554 … -1.6618814 -0.29524508; -1.5998878 1.0415661 … 2.0147724 -1.3684102; … ; 1.3930212 -0.6209397 … -0.74008626 0.54539466; -1.6403211 -0.107296005 … -2.3517299 -1.0712726], Float32[0.232926 0.32656965 … 1.6909105 0.22774501; 0.4568992 -0.4828619 … -0.6694798 -0.6123085; … ; -0.07900804 0.26763338 … 1.2399093 0.21933515; 0.6857367 1.1743313 … 0.3457506 -1.0528866], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], NNlib.relu)

Let's apply each layer to the input batch `x`.

In [15]:
# Forward pass of the baseline layer: 300×32 input -> 100×32 output.
dense(x)

100×32 Matrix{Float32}:
 0.445163   0.0         0.115413    …  0.0        0.107702   0.033223
 0.0        0.0         0.0            0.0        0.0        0.0
 0.0        0.0         0.0            0.0        0.0        0.0
 0.0        0.109663    0.0            0.0        0.235564   0.0
 0.0384172  0.0         0.0            0.0        0.0        0.288226
 0.0        0.0         0.0         …  0.0        0.0        0.0
 0.500724   0.419961    0.236422       0.268847   0.202376   0.211306
 0.0        0.0         0.0            0.0        0.0        0.0
 0.0        0.0         0.0            0.0        0.0        0.183608
 0.0        0.0         0.0            0.0        0.0        0.0
 0.833865   0.984048    0.828721    …  1.13949    0.550333   0.85089
 0.0        0.0         0.0            0.0        0.0        0.0
 0.160715   0.600113    0.34477        0.681765   0.791276   0.0
 ⋮                                  ⋱             ⋮          
 0.809335   0.453948    1.41536        0.4029

In [16]:
# Forward pass of the Molchanov layer (100×32 output, same shape as `dense(x)`).
var_layer_1(x)

100×32 Matrix{Float32}:
  6.25843    4.82337   0.0      8.76104   …   0.0       2.26321   3.36543
  0.0        0.0       0.0      0.0           0.0       0.0       0.0
  0.0        0.0       0.0      0.0           0.0       0.0       0.0
  0.0        0.0       0.0      0.0           0.0       0.0       0.0
 10.5237    16.4919   11.7713  21.5509       21.164    15.1271   17.2656
 21.878     25.9373   21.0121  19.9754    …  12.7315   15.2481   15.9507
  0.144467  20.4795    0.0      0.254445      0.0       5.77656  13.1105
  0.0        0.0       0.0      0.0           0.0       0.0       0.0
 13.767     15.8598   10.033   21.1814       15.7677    0.0      14.8222
  1.60674    0.0       0.0      0.0           0.0       0.0       0.0
 26.0615    37.0587   39.5522  29.0457    …  42.0399   18.5582   22.1077
  0.0        0.0       0.0      0.0           0.0       0.0       0.0
 15.9796    23.9304   14.1293  21.5131       28.7888    1.21887  17.8526
  ⋮                                       ⋱ 

In [17]:
# Forward pass of the Kingma layer.
var_layer_2(x)

100×32 Matrix{Float32}:
 10.6547   19.201     10.4253     …  18.6034    20.7754    16.7124
  0.0       4.88031    0.0513558      5.98526    0.0        0.0
  0.0       8.40891    0.0            0.0        0.0        0.0
  0.0       0.0        0.0            0.0        0.0        0.0
 27.7688   26.6627    35.969         35.7167     7.0453    33.4402
 26.062    17.9019    10.2154     …  35.3665    28.5088    17.0703
  0.0       0.0        0.0            0.0        0.0        0.0
  3.17546   4.04305   11.3343         5.55978    5.70891    8.64492
 16.4642   15.5575    12.482         21.083     11.1843    16.4729
  0.0       0.0        0.0            0.0        6.52704    0.0
 12.3954    0.305874  14.1105     …   1.52634    0.0        8.83331
  0.0       2.57792    0.0            2.50675    0.0        0.0
  9.8828    0.0        0.0           10.5361     5.29919    7.56782
  ⋮                               ⋱              ⋮         
 11.7035    0.0        3.08482       14.5257    18.3814    3

In [18]:
# Forward pass of the Graves layer. The output printed by the @btime cell below
# for the same `x` differs from this one, so the forward pass is stochastic.
var_layer_3(x)

100×32 Matrix{Float32}:
  0.0       0.0       5.8415    0.0      …   0.0       0.0       7.71049
  0.0       0.0       0.0       0.0          0.0       0.0       0.0
  0.0       0.68419   0.0       0.0          0.0       0.0       0.0
  0.0       0.0       0.0       0.0          0.0       0.0       0.0
 10.7764   29.7994   18.7487   36.7623      16.4782   17.0477   10.4503
 12.8516   16.1304    0.0      19.7109   …  12.6342    0.0      11.522
 10.342     0.0      15.7623   12.2853      14.4564    0.0       5.00547
 17.5177    0.0       1.26094  14.8939       1.61457  14.0908   28.0417
 12.6938    8.06788  14.7526   15.2005       8.17956  12.1103   21.9166
  3.73391   0.0       1.42143   9.4565      17.0895    2.64702   5.93608
  0.0       0.0       0.0       0.0      …   0.0       0.0       0.0
  0.0       7.61388   7.88679   0.0          9.98885  14.3312    0.0
 24.1744   12.4572    5.61219  18.2081       3.08661   0.0       7.90194
  ⋮                                      ⋱          

## One layer

In [19]:
# Benchmark the forward pass of each layer on the same input `x`.
# `$`-interpolation makes @btime treat the globals as constants, so the timing
# excludes global-variable lookup. The cell's displayed value is the result of
# the last @btime call (the Graves layer output).
println("Benchmark: Dense layer forward")
@btime $dense($x)

println("Benchmark: Molchanov layer forward")
@btime $var_layer_1($x)

println("Benchmark: Kingma layer forward")
@btime $var_layer_2($x)

println("Benchmark: Graves layer forward")
@btime $var_layer_3($x)

Benchmark: Dense layer forward
  33.166 μs (6 allocations: 25.16 KiB)
Benchmark: Molchanov layer forward
  296.584 μs (12 allocations: 259.81 KiB)
Benchmark: Kigma layer forward
  337.167 μs (12 allocations: 259.81 KiB)
Benchmark: Graves layer forward
  297.459 μs (12 allocations: 259.81 KiB)


100×32 Matrix{Float32}:
 12.9601    4.55049  12.7292    0.0      …  11.4549     0.0       16.0513
 10.1546   11.6482   11.2421   12.3546       0.0       15.2563     0.0
  0.0      10.6981    0.0       0.0          0.0        0.545844   1.4399
  0.0       0.0      10.3312    0.0          0.0       19.3328    18.0895
 21.5884   32.9366   35.8767   33.2585      44.6518    15.397     14.2991
 27.3598   30.9148   17.6174   25.68     …  24.0545     2.16785   29.2188
  0.0       0.0       0.0       0.0          0.0        1.11744    0.0
  0.0       2.94395   0.0       0.0          0.0        0.0        0.0
  0.0       1.71854   4.12568   0.0          2.7017     7.73407    8.77951
  0.0       0.0       0.0       0.0          0.0        0.0        0.0
  0.0       0.0       0.0       3.03009  …   0.0        0.575575   3.9425
 33.8649   35.0778   41.8807   20.2876      45.5214    30.8685    19.1409
  0.0       0.0       0.0       5.79657      0.0        0.0        0.0
  ⋮                         

In [20]:
# Full trial (all samples) of the Dense forward pass, kept for the histogram below.
# Note: `bench_dense` is reassigned later to the training-loop trial.
bench_dense = @benchmark $dense($x)

BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m33.458 μs[22m[39m … [35m 14.505 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 99.39%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m35.250 μs               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m37.915 μs[22m[39m ± [32m152.853 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m5.09% ±  1.40%

  [39m [39m [39m▁[39m▃[39m█[39m▆[34m▅[39m[39m▄[39m▁[39m [39m [39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▄[39m▆[39m

In [21]:
"""
    pretty_hist(bench; filter_thresh = 100.0, bins = 100)

Plot a histogram of the per-sample execution times stored in a BenchmarkTools
trial `bench`.

The times in `bench.times` (nanoseconds) are converted to microseconds. Only
samples strictly below `filter_thresh` μs are kept, so that rare long-tail
samples do not stretch the x-axis.

# Keywords
- `filter_thresh`: cut-off in microseconds; samples `>= filter_thresh` are dropped.
- `bins`: number of histogram bins (default `100`, the previously hard-coded value).

# Returns
The Plots.jl plot object (it is not displayed or saved here).
"""
function pretty_hist(bench; filter_thresh = 100.0, bins = 100)
    # BenchmarkTools records times in nanoseconds; plot in microseconds.
    times_us = bench.times ./ 1_000

    # Drop long-tail samples above the threshold.
    trimmed = filter(t -> t < filter_thresh, times_us)
    if isempty(trimmed)
        # Otherwise this would silently produce an empty plot.
        @warn "pretty_hist: every sample is >= filter_thresh; the histogram will be empty" filter_thresh
    end

    p = histogram(trimmed;
        bins=bins,
        xlabel="Execution Time (μs)",
        ylabel="Frequency",
        legend=false,
        linecolor=:black,
        fillcolor=:cornflowerblue,
        alpha=0.8,
        framestyle=:box,
    )

    return p
end



pretty_hist (generic function with 1 method)

In [22]:
# Default 100 μs cut-off suits the ~35 μs Dense forward pass. savefig is given
# no extension and writes a .png (see the returned path).
p1  = pretty_hist(bench_dense)
savefig(p1, "Dense_FW_time")


"/Users/andreibleahu/Documents/MSc Data Science - St Andrews/Artifacts_CS5999_AIB/Tutorials/Dense_FW_time.png"

In [23]:
# Full trial of the Molchanov forward pass (name reused later for the training trial).
bench_molchanov = @benchmark $var_layer_1($x)

BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m300.459 μs[22m[39m … [35m 3.397 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 90.00%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m312.292 μs              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m325.329 μs[22m[39m ± [32m97.932 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m3.55% ±  8.64%

  [39m▇[39m█[34m█[39m[39m▅[32m▂[39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▂
  [39m█[39m█[34m

In [24]:
# Median is ~312 μs in the trial above, so raise the cut-off to 400 μs.
p2 = pretty_hist(bench_molchanov; filter_thresh= 400)
savefig(p2, "Molchanov_FW_time")

"/Users/andreibleahu/Documents/MSc Data Science - St Andrews/Artifacts_CS5999_AIB/Tutorials/Molchanov_FW_time.png"

In [25]:
# Full trial of the Kingma forward pass (var_layer_2 in this section).
bench_kingma = @benchmark $var_layer_2($x)

BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m341.250 μs[22m[39m … [35m  6.890 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 94.21%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m353.334 μs               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m366.525 μs[22m[39m ± [32m110.343 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m3.17% ±  8.17%

  [39m▇[39m▇[34m█[39m[39m▆[32m▃[39m[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▂
  [39m█[

In [26]:
# Kingma is slightly slower (median ~353 μs above), so use a 500 μs cut-off.
p3 = pretty_hist(bench_kingma; filter_thresh= 500)
savefig(p3, "Kingma_FW_time")

"/Users/andreibleahu/Documents/MSc Data Science - St Andrews/Artifacts_CS5999_AIB/Tutorials/Kingma_FW_time.png"

In [27]:
# Full trial of the Graves forward pass (var_layer_3 in this section).
bench_graves = @benchmark $var_layer_3($x)

BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m303.916 μs[22m[39m … [35m 3.406 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 90.13%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m315.792 μs              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m328.290 μs[22m[39m ± [32m95.738 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m3.41% ±  8.47%

  [39m▇[39m▇[34m█[39m[39m▅[32m▂[39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▂
  [39m█[39m█[34m

In [28]:
# Median ~316 μs in the trial above, so a 400 μs cut-off keeps the bulk of samples.
p4 = pretty_hist(bench_graves; filter_thresh= 400)
savefig(p4, "Graves_FW_time")

"/Users/andreibleahu/Documents/MSc Data Science - St Andrews/Artifacts_CS5999_AIB/Tutorials/Graves_FW_time.png"

# Training 

In [29]:
# Tiny training loop timing
"""
    train_dense_epoch(model, data)

Run one epoch of Adam training of a plain Flux `model` over `data`, updating
`model` in place. Returns `nothing`.

`data` is iterated as `(x, y)` batches. The loss is
`Flux.logitcrossentropy(m(x), y)`, so the model should output raw logits and `y`
should be a one-hot target batch. A fresh Adam state is created on every call.
"""
function train_dense_epoch(model, data)
    opt = Flux.setup(Adam(), model)
    for (x, y) in data
        # Differentiate through the closure argument `m`. Calling the captured
        # `model` here instead makes the gradient w.r.t. `m` `nothing`, which
        # turns `update!` into a no-op, so the model was never trained.
        _, grads = Flux.withgradient(model) do m
            Flux.logitcrossentropy(m(x), y)
        end
        Flux.update!(opt, model, grads[1])
    end
end

train_dense_epoch (generic function with 1 method)

In [30]:
"""
    train_variational_epoch(model, data; N=320)

Train the variational `model` for one pass over `data` with a freshly set-up
`Adam()` optimiser. Returns `nothing`.

Despite the missing `!`, `model` is modified in place by `Flux.update!`. For each
`(x, y)` batch, the objective is
`energy_loss(m, x, y, N; task_type=:classification)`, presumably provided by the
VariationalMLP module. `N` is forwarded unchanged to `energy_loss`. It is presumably
the training-set size used to scale the KL term (10 batches × 32 = 320 here);
confirm against `energy_loss`.
"""
function train_variational_epoch(model, data; N=320)
    state = Flux.setup(Adam(), model)
    for batch in data
        xb, yb = batch
        objective = m -> energy_loss(m, xb, yb, N; task_type=:classification)
        _, gs = Flux.withgradient(objective, model)
        Flux.update!(state, model, first(gs))
    end
end

train_variational_epoch (generic function with 1 method)

In [31]:
# Synthetic dataset: 10 batches, each a 300×32 Float32 input and a 10×32 one-hot
# label matrix with uniformly random classes (320 samples in total).
data_loader = [(rand(Float32, 300, 32), Flux.onehotbatch(rand(1:10, 32), 1:10)) for _ in 1:10]


10-element Vector{Tuple{Matrix{Float32}, OneHotArrays.OneHotMatrix{UInt32, Vector{UInt32}}}}:
 ([0.39160913 0.18003982 … 0.35694957 0.7631394; 0.8190077 0.61038536 … 0.5345167 0.801996; … ; 0.9872733 0.13055027 … 0.6169957 0.20513117; 0.31079632 0.16451049 … 0.9997054 0.136599], [0 0 … 0 0; 0 0 … 1 0; … ; 0 0 … 0 0; 0 1 … 0 0])
 ([0.6636929 0.49097973 … 0.663972 0.79873574; 0.19095194 0.06527573 … 0.74971825 0.45436847; … ; 0.8961866 0.19018757 … 0.5380682 0.38584286; 0.6985424 0.44699335 … 0.26459897 0.15281671], [1 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0])
 ([0.3495115 0.18851721 … 0.9999215 0.7795014; 0.6919953 0.8691436 … 0.4219218 0.73622715; … ; 0.66776294 0.8726363 … 0.35489386 0.37834877; 0.4895143 0.50491124 … 0.848169 0.8801688], [1 0 … 0 0; 0 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 0 0])
 ([0.93426543 0.1787582 … 0.5586577 0.13671368; 0.28404105 0.5410484 … 0.5570679 0.46198702; … ; 0.202057 0.4057253 … 0.21942729 0.109384894; 0.51172554 0.95929885 … 0.55424845 0.37142193], [0 0 

In [32]:
# Two-layer 300 -> 100 -> 10 classifiers: relu on the hidden layer and identity on
# the output layer, so the models emit raw logits.
# Numbering in this section: 1 = Molchanov, 2 = Graves, 3 = Kingma
# (the reverse of var_layer_2 / var_layer_3 in the single-layer section).
dense_model = Chain(Dense(300, 100, relu), Dense(100, 10))
variational_model_1 = make_model([300, 100, 10]; variant=:molchanov, activations=[relu], final_activation=identity)
variational_model_2 = make_model([300, 100, 10]; variant=:graves, activations=[relu], final_activation=identity)
variational_model_3 = make_model([300, 100, 10]; variant=:kingma, activations=[relu], final_activation=identity)

VarChain{Vector{Main.VariationalMLP.AbstractVariationalLayer}}(Main.VariationalMLP.AbstractVariationalLayer[VariationalDropoutKingma{typeof(relu)}(Float32[0.79698175 1.7667935 … 0.8886948 0.836783; 1.2104137 0.91788214 … -1.2020686 1.501433; … ; -0.42749378 -0.9193341 … -0.9006701 -0.025668351; 0.6532287 0.8128992 … 0.6692754 0.8686027], Float32[0.2615443 1.0059106 … -0.81860894 0.6270579; 0.5952025 -1.1872762 … 1.5421598 -0.40772814; … ; 1.661227 -0.11368676 … -0.28369877 -0.32568318; 0.40808854 1.1286542 … 0.04657433 -1.3461018], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], NNlib.relu), VariationalDropoutKingma{typeof(identity)}(Float32[-1.4180003 0.08133791 … -0.1769691 0.039955594; -0.18554993 0.5168965 … -0.41056782 -0.036238007; … ; 0.86575174 1.6516117 … -0.27139163 -0.9617676; -0.06687585 -2.1209407 … 0.27028772 0.1208957], Float32[1.6278579 0.40476662 … 0.2731237 1.5446397; -1.4909778 1.0354555 … -0.16032845 1.

In [33]:
# Times one full epoch, including Flux.setup. The model is mutated by every sample,
# so later samples continue training the same `dense_model`.
# Overwrites the forward-pass `bench_dense` trial from above.
bench_dense = @benchmark train_dense_epoch($dense_model, $data_loader)

BenchmarkTools.Trial: 2779 samples with 1 evaluation per sample.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m1.482 ms[22m[39m … [35m  8.445 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 78.91%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m1.642 ms               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m1.798 ms[22m[39m ± [32m447.298 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m7.15% ± 11.81%

  [39m▅[39m▆[39m▇[39m█[34m█[39m[39m▇[39m▄[39m▂[32m▁[39m[39m [39m [39m [39m [39m▂[39m▃[39m▃[39m▂[39m▂[39m▁[39m▁[39m▁[39m▁[39m▁[39m▁[39m▁[39m▁[39m▁[39m▁[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▁
  [39m█[39m█[39m█[39m█[3

In [34]:
# One training epoch of the Molchanov model (mutated across samples, as above).
bench_molchanov = @benchmark train_variational_epoch($variational_model_1, $data_loader)

BenchmarkTools.Trial: 232 samples with 1 evaluation per sample.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m17.761 ms[22m[39m … [35m167.262 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 7.10% … 88.90%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m20.208 ms               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 9.43%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m21.587 ms[22m[39m ± [32m 10.026 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m12.62% ±  7.03%

  [39m [39m [39m▂[39m [39m▂[39m▃[39m▆[39m▄[39m█[34m▃[39m[39m [39m [39m▃[39m▄[39m▁[32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▅[39m▃[39

In [35]:
# One training epoch of the Graves model (variational_model_2 in this section).
bench_graves = @benchmark train_variational_epoch($variational_model_2, $data_loader)

BenchmarkTools.Trial: 338 samples with 1 evaluation per sample.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m11.973 ms[22m[39m … [35m140.433 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 4.80% … 90.64%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m14.029 ms               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m10.33%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m14.813 ms[22m[39m ± [32m  7.208 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m12.88% ±  8.03%

  [39m [39m [39m [39m▄[39m▅[39m█[34m█[39m[39m▄[39m▄[32m▂[39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▄[39m▆[39

In [36]:
# One training epoch of the Kingma model (variational_model_3 in this section).
bench_kingma = @benchmark train_variational_epoch($variational_model_3, $data_loader)

BenchmarkTools.Trial: 311 samples with 1 evaluation per sample.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m13.440 ms[22m[39m … [35m144.176 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 4.77% … 89.56%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m15.238 ms               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m 9.89%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m16.084 ms[22m[39m ± [32m  7.569 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m12.39% ±  7.91%

  [39m [39m [39m [39m [39m▃[39m▂[39m▃[39m█[34m [39m[39m▄[39m [39m▃[32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▃[39m▆[39

# Benchmark across optimizers

In [37]:
"""
    train_variational_epoch!(model, data, opt_ctor; N=320)

Train `model` in place for one pass over `data`, using an optimiser produced by
the zero-argument constructor `opt_ctor` (e.g. `() -> Adam(0.01)`). Returns `nothing`.

Optimiser state is rebuilt on every call via `Flux.setup(opt_ctor(), model)`.
For each `(x, y)` batch, the gradient of
`energy_loss(m, x, y, N; task_type=:classification)` w.r.t. the model is applied
with `Flux.update!`. `N` is forwarded unchanged to `energy_loss`; it is presumably
the dataset size used to scale the KL term, so confirm against its definition.
"""
function train_variational_epoch!(model, data, opt_ctor; N=320)
    optimiser = opt_ctor()
    state = Flux.setup(optimiser, model)
    for (xb, yb) in data
        loss_of = m -> energy_loss(m, xb, yb, N; task_type=:classification)
        _, gs = Flux.withgradient(loss_of, model)
        Flux.update!(state, model, first(gs))
    end
end


train_variational_epoch! (generic function with 1 method)

I will define a wrapper function that benchmarks one training epoch with a given optimizer and returns a named tuple of the median time, memory, allocation count and GC time.

In [38]:
"""
    benchmark_optimizer(opt_ctor, model, data; N=320)

Benchmark one epoch of `train_variational_epoch!(model, data, opt_ctor; N)` with
`@benchmark`, and summarise the trial by its median estimate.

`GC.gc()` runs in the `setup` phase before each sample. `model` is trained in
place by every sample, so later samples start from already-updated weights.
Pass a copy if the caller's model must stay untouched.

# Returns
A `NamedTuple` `(optimizer, time_ns, memory, allocs, gc_time_ns)` taken from the
median `TrialEstimate`. `optimizer` is `string(opt_ctor)`; for an anonymous
constructor this is only its generated name (e.g. `"#15"`), so callers may
replace it with a readable label.
"""
function benchmark_optimizer(opt_ctor, model, data; N=320)
    result = @benchmark train_variational_epoch!($model, $data, $opt_ctor; N=$N) setup=(GC.gc())
    # Compute the median estimate once instead of once per field.
    est = median(result)
    return (
        optimizer = string(opt_ctor),
        time_ns = est.time,
        memory = est.memory,
        allocs = est.allocs,
        gc_time_ns = est.gctime,
    )
end


benchmark_optimizer (generic function with 1 method)

Let's compare several optimizers at the same fixed learning rate:

In [39]:
# Zero-argument constructors rather than optimiser instances, so each benchmark
# call builds fresh optimiser state via Flux.setup(opt_ctor(), model).
# Keep the order in sync with `opt_names` below; they are paired by position.
const LR = 0.01  # fixed learning rate shared by all optimisers
optimizers = [
    () -> Adam(LR),
    () -> RMSProp(LR),
    () -> Descent(LR),
    () -> AdaGrad(LR)
]


4-element Vector{Function}:
 #15 (generic function with 1 method)
 #16 (generic function with 1 method)
 #17 (generic function with 1 method)
 #18 (generic function with 1 method)

Let's give them readable names (used as the `optimizer` column in the results table):

In [40]:
# Row labels for the results table; must match the order of `optimizers`,
# since experiment_opts pairs the two lists with zip.
opt_names = [Adam, RMSProp, Descent, AdaGrad]

4-element Vector{DataType}:
 Adam
 RMSProp
 Descent
 AdaGrad

In [41]:
"""
    experiment_opts(model; optimizers = optimizers, opt_names = opt_names, data = data_loader)

Benchmark one training epoch of `model` under each optimizer constructor and
collect the median results into a `DataFrame`.

Every optimizer is run on its own `deepcopy(model)`, so runs do not share trained
weights and the caller's `model` is left untouched. `optimizers` and `opt_names`
are paired element-wise with `zip`; surplus entries in the longer list are ignored.

# Keywords
- `optimizers`: zero-argument functions returning a Flux optimiser (default: global `optimizers`).
- `opt_names`: labels stored in the `optimizer` column (default: global `opt_names`).
- `data`: iterable of `(x, y)` batches to train on (default: global `data_loader`,
  which was previously hard-wired).

# Returns
A `DataFrame` with columns `:optimizer`, `:time_ms`, `:memory` (bytes, as
reported by BenchmarkTools), `:allocs`, `:gc_ms`.
"""
function experiment_opts(model; optimizers = optimizers, opt_names = opt_names, data = data_loader)
    results = [
        # Replace the auto-generated `optimizer` string with the readable label.
        merge(benchmark_optimizer(opt_ctor, deepcopy(model), data), (; optimizer = name))
        for (opt_ctor, name) in zip(optimizers, opt_names)
    ]
    df = DataFrame(results)
    df.time_ms = df.time_ns ./ 1e6
    df.gc_ms = df.gc_time_ns ./ 1e6
    select!(df, [:optimizer, :time_ms, :memory, :allocs, :gc_ms])
    return df
end


experiment_opts (generic function with 1 method)

## Molchanov

In [42]:
using DataFrames 

In [43]:
# Optimizer comparison for the Molchanov model (each optimizer gets a fresh deepcopy).
df_1 = experiment_opts(variational_model_1)

Row,optimizer,time_ms,memory,allocs,gc_ms
Unnamed: 0_level_1,DataType,Float64,Int64,Int64,Float64
1,Adam,18.9168,63220720,19083,0.0
2,RMSProp,18.1409,62974320,19223,0.0
3,Descent,16.4232,62700464,18751,0.0
4,AdaGrad,18.2282,62953968,18767,0.0


## Graves

In [44]:
# Optimizer comparison for the Graves model (variational_model_2 in the training section).
df_2 = experiment_opts(variational_model_2)

Row,optimizer,time_ms,memory,allocs,gc_ms
Unnamed: 0_level_1,DataType,Float64,Int64,Int64,Float64
1,Adam,14.2445,45918800,18593,0.0
2,RMSProp,13.5628,45672400,18733,0.0
3,Descent,12.8325,45398544,18261,0.0
4,AdaGrad,13.6418,45652048,18277,0.0


## Kingma

In [45]:
# Optimizer comparison for the Kingma model (variational_model_3 in the training section).
df_3 = experiment_opts(variational_model_3)

Row,optimizer,time_ms,memory,allocs,gc_ms
Unnamed: 0_level_1,DataType,Float64,Int64,Int64,Float64
1,Adam,15.7484,48300880,18293,0.0
2,RMSProp,14.9609,48054480,18433,0.0
3,Descent,14.2896,47780624,17961,0.0
4,AdaGrad,15.0419,48034128,17977,0.0
