## Zygote self-autodifferentiability

In [1]:
using Zygote

In [53]:
function fj(x, bias)
    jacc = Zygote.jacobian(x->x.^2, x)
    jac = jacc[1]
    # @info jacc
    @info "jac" jac
    return jac * x .+ bias
end

fj (generic function with 1 method)

In [52]:
function fg(x, bias)
    grad = Zygote.gradient(x->sum(x.^2), x)[1]
    @info "grad" grad
    return grad .* x .+ bias
end

fg (generic function with 1 method)

In [54]:
x = [1., 2., 3.]
b = [0., 0., 0.]

[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39mjac
[36m[1m│ [22m[39m  jac =
[36m[1m│ [22m[39m   3×3 Matrix{Float64}:
[36m[1m│ [22m[39m    2.0  0.0  0.0
[36m[1m│ [22m[39m    0.0  4.0  0.0
[36m[1m└ [22m[39m    0.0  0.0  6.0


3-element Vector{Float64}:
  2.0
  8.0
 18.0

In [99]:
function altjacobian(f, arg, outdim)
   (reduce(vcat, map(outidx -> gradient(x -> f(x)[outidx], arg)[1]', range(1, outdim))), ) 
end

altjacobian (generic function with 2 methods)

In [None]:
for i in range(1,3)
    @info Zygote.gradient()
end

In [85]:
function f(x)
    @info "f run"
    [x[1]^2, x[2] + x[3]]
end

f (generic function with 2 methods)

In [68]:
f([2,3,4])

2-element Vector{Int64}:
 4
 7

In [105]:
jacobian(f, [2.,3.,4.])

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mf run


([4.0 0.0 0.0; 0.0 1.0 1.0],)

In [104]:
altjacobian(f, [2.,3.,4.], 2)

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mf run
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mf run


([4.0 0.0 0.0; 0.0 1.0 1.0],)

In [87]:
g = gradient(x -> f(x)[1], [2,3,4])[1]

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mf run


3-element Vector{Float64}:
 4.0
 0.0
 0.0

In [88]:
arg = [2,3,4]


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mf run
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mf run


2×3 Matrix{Float64}:
 4.0  0.0  0.0
 0.0  1.0  1.0

In [96]:
xx = reduce(hcat, [[1, 2], [3, 4]])

2×2 Matrix{Int64}:
 1  3
 2  4

In [98]:
jacobian(x -> xx, 0)[1]

4-element Vector{Int64}:
 0
 0
 0
 0

In [63]:
altjacobian(1,1)

(1,)

In [None]:
fj(x, b)

In [55]:
fg(x, [0, 0, 0])

[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39mgrad
[36m[1m│ [22m[39m  grad =
[36m[1m│ [22m[39m   3-element Vector{Float64}:
[36m[1m│ [22m[39m    2.0
[36m[1m│ [22m[39m    4.0
[36m[1m└ [22m[39m    6.0


3-element Vector{Float64}:
  2.0
  8.0
 18.0

In [56]:
Zygote.gradient(b -> sum(fj(x,b)), rand(3))

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m([2.0 0.0 0.0; 0.0 4.0 0.0; 0.0 0.0 6.0],)
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m[2.0 0.0 0.0; 0.0 4.0 0.0; 0.0 0.0 6.0]


LoadError: Mutating arrays is not supported -- called copyto!(SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations


In [57]:
Zygote.gradient(b -> sum(fg(x,b)), rand(3))

[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39mgrad
[36m[1m│ [22m[39m  grad =
[36m[1m│ [22m[39m   3-element Vector{Float64}:
[36m[1m│ [22m[39m    2.0
[36m[1m│ [22m[39m    4.0
[36m[1m└ [22m[39m    6.0


(Fill(1.0, 3),)

In [60]:
Zygote.gradient(x -> sum(fg(x,b)), rand(3))

[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39mgrad
[36m[1m│ [22m[39m  grad =
[36m[1m│ [22m[39m   3-element Vector{Float64}:
[36m[1m│ [22m[39m    0.8113012203912522
[36m[1m│ [22m[39m    0.8916525522283958
[36m[1m└ [22m[39m    0.25571020915909703


([1.6226024407825044, 1.7833051044567916, 0.5114204183181941],)

In [61]:
Zygote.gradient(x -> sum(fj(x,b)), rand(3))

[36m[1m┌ [22m[39m[36m[1mInfo: [22m[39mjac
[36m[1m│ [22m[39m  jac =
[36m[1m│ [22m[39m   3×3 Matrix{Float64}:
[36m[1m│ [22m[39m    1.29859  0.0       0.0
[36m[1m│ [22m[39m    0.0      0.470737  0.0
[36m[1m└ [22m[39m    0.0      0.0       1.56187


LoadError: Mutating arrays is not supported -- called copyto!(SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations


## Zygote diff-ability wrt array elements

In [41]:
x = [1 2 3; 4 5 6]; y = [2, 3]; z = [1, 10, 100];
# x .* y .* z'
ps = Params([x[2,3]])
g = gradient(ps) do
    sum(x .* y .* z')
end
x[2,3]

Grads(...)

In [34]:
g[x]

2×3 Matrix{Float64}:
 2.0  20.0  200.0
 3.0  30.0  300.0

In [42]:
g[x[2,3]]

LoadError: Only reference types can be differentiated with `Params`.

## Testing DataLoader with existing code

In [1]:
using Pkg
Pkg.activate(".")
using LaplaceRedux
using LaplaceRedux.Curvature
using LaplaceRedux.Data
using Flux
using Flux.Optimise: update!, Adam
using Plots
using Statistics
using MLUtils
using Zygote

[32m[1m  Activating[22m[39m project at `~/Builds/navimakarov/LaplaceRedux.jl/dev/notebooks/batching`


In [12]:
# SETUP
n = 100
data_dict = Dict()

# Regression batched:
x, y = LaplaceRedux.Data.toy_data_regression(n)
xs = [[x] for x in x]
X, Y = reduce(hcat, x), reduce(hcat, y)

data = DataLoader((X, Y), batchsize=10)
data = (xs, y)
data_dict[:regression_batched] = Dict(
    :data => data,
    :X => X,
    :y => y,
    :outdim => 1,
    :loss_fun => :mse,
    :likelihood => :regression,
)

Dict{Symbol, Any} with 6 entries:
  :loss_fun   => :mse
  :y          => [0.604458, -0.276429, 0.185929, 0.269064, 1.23586, -0.626247, …
  :likelihood => :regression
  :X          => [0.590667 3.4422 … 2.14491 1.3144]
  :outdim     => 1
  :data       => ([[0.590667], [3.4422], [0.205944], [0.0737931], [1.99856], [3…

In [5]:
X

1×100 Matrix{Float64}:
 6.2507  4.97587  4.48094  1.67  …  0.495325  4.22701  3.12035  7.23925

In [4]:
Y

1×100 Matrix{Float64}:
 0.436515  -0.625754  -1.29171  0.439809  …  -0.60571  -0.0255798  0.998537

In [7]:
data

10-element DataLoader(::Tuple{Matrix{Float64}, Matrix{Float64}}, batchsize=10)
  with first element:
  (1×10 Matrix{Float64}, 1×10 Matrix{Float64},)

In [145]:
reduce(hcat, y)

1×10 Matrix{Float64}:
 -0.675062  -1.38496  -0.84902  -1.02599  …  -0.210813  1.18098  0.0332229

In [140]:
typeof(nn(x))

Matrix{Float32}[90m (alias for [39m[90mArray{Float32, 2}[39m[90m)[39m

In [135]:
nnx = nn(x)

1×10 Matrix{Float32}:
 -0.496463  -0.531433  -0.493528  -0.51279  …  -0.469021  -0.40483  -0.254351

In [137]:
size(nnx)

(1, 10)

In [18]:
for d in data
    x, y = d
    # # Assert on batching
    # @assert size(x)[end] == 10
    # @assert size(x)[end] == 10
    # @sho
    @show nn(x)
    @show size(nn(x))
    @show y
    @show size(y)
    break
end

nn(x) = Float32[-1.5544128]
size(nn(x)) = (1,)
y = [3.442203160906069]
size(y) = (1,)


In [136]:
# Here we have "zip" code copied from the laplace.jl tests

# Regression:
x, y = LaplaceRedux.Data.toy_data_regression(n)
xs = [[x] for x in x]
X = hcat(xs...)
Y = reduce(hcat, y)
data = zip(xs, y)
data_dict[:regression] = Dict(
    :data => data,
    :X => X,
    :y => y,
    :outdim => 1,
    :loss_fun => :mse,
    :likelihood => :regression,
)

val = data_dict[:regression]

# Unpack:
data = val[:data]
X = val[:X]
y = val[:y]
outdim = val[:outdim]
loss_fun = val[:loss_fun]
likelihood = val[:likelihood]

# Neural network:
n_hidden = 32
D = size(X, 1)
nn = Chain(Dense(D, n_hidden, σ), Dense(n_hidden, outdim))
λ = 0.01
sqnorm(x) = sum(abs2, x)
weight_regularization(λ=λ) = 1 / 2 * λ^2 * sum(sqnorm, Flux.params(nn))
# loss(x, y) = getfield(Flux.Losses, loss_fun)(nn(x), y) + weight_regularization()
# loss(x, y; kwargs...) = getfield(Flux.Losses, loss_fun)(nn(x), y; kwargs...) + weight_regularization()
# loss(x, y) = getfield(Flux.Losses, loss_fun)(nn(x), y, agg=sum) + weight_regularization()
loss(x, y) = Flux.Losses.mse(nn(x), y, agg=sum) + weight_regularization()

opt = Adam()
epochs = 200
avg_loss(data) = mean(map(d -> loss(d[1], d[2]), data))
show_every = epochs / 10

for epoch in 1:epochs
    for d in data
        gs = gradient(Flux.params(nn)) do
            l = loss(d...)
        end
        update!(opt, Flux.params(nn), gs)
    end
    if epoch % show_every == 0
        println("Epoch " * string(epoch))
        @show avg_loss(data)
    end
end


Epoch 20
avg_loss(data) = 0.5183304577653728
Epoch 40
avg_loss(data) = 0.49096486112199406
Epoch 60
avg_loss(data) = 0.47714176549877707
Epoch 80
avg_loss(data) = 0.46774219217521074
Epoch 100
avg_loss(data) = 0.4591247429787631
Epoch 120
avg_loss(data) = 0.45029180242439104
Epoch 140
avg_loss(data) = 0.4409879624804595
Epoch 160
avg_loss(data) = 0.43108512387503894
Epoch 180
avg_loss(data) = 0.42045859958327
Epoch 200
avg_loss(data) = 0.40896995681616233


In [161]:
# Fit unbatched LA
if outdim == 1
    la = Laplace(nn; likelihood=likelihood, λ=λ, subset_of_weights=:last_layer)
    fit!(la, data)
    # optimize_prior!(la; verbose=true)
    plot(la, X, y, title="batchsize=N/A")                              # standard
    savefig(@sprintf("fig-01-%02d.png", 0))
    # plot(la, X, y; xlims=(-5, 5), ylims=(-5, 5))  # lims
    # plot(la, X, y; link_approx=:plugin)         # plugin approximation
end

"/home/vd/Builds/navimakarov/LaplaceRedux.jl/dev/notebooks/batching/fig-01-00.png"

In [162]:
size(X)

(1, 100)

In [42]:
using Printf

In [131]:
loss(x, y) = Flux.Losses.mse(nn(x), y, agg=sum) + weight_regularization()

loss (generic function with 1 method)

In [134]:
# The loss function should be summative
# The loss of the batch is the sum of losses of its part
x_1 = [1.5261694931931657]
y_1 = [1.5517179074064962]
x_2 = [5.795661372040879]
y_2 = [-0.6585143352250986]
x_b = hcat(x_1, x_2)
y_b = hcat(y_1, y_2)
@show loss(x_1, y_1)
@show loss(x_2, y_2)
@show loss(x_1, y_1) + loss(x_2, y_2)
@show loss(x_b, y_b)

loss(x_1, y_1) = 1.2346902412711558
loss(x_2, y_2) = 0.3298519739905986
loss(x_1, y_1) + loss(x_2, y_2) = 1.5645422152617545
loss(x_b, y_b) = 1.5630184623774577


1.5630184623774577

In [173]:
# Now batch the same data, repeat the procedure

# X = reduce(hcat, x)
Y = reduce(hcat, y)

b = 2
data = DataLoader((X, Y), batchsize=b)
# Fit LA
if outdim == 1
    la = Laplace(nn; likelihood=likelihood, λ=λ, subset_of_weights=:last_layer)
    fit!(la, data)
    # optimize_prior!(la; verbose=false)
    plot(la, X, y, title="batchsize=$b")                              # standard
    # savefig(@sprintf("fig-%02d.png", batchsize))
    # plot(la, X, y; xlims=(-5, 5), ylims=(-5, 5))  # lims
    # plot(la, X, y; link_approx=:plugin)         # plugin approximation
end

d = ([1.2715212289774493 2.0476179114531954], [0.9257148467376903 1.2434332632351088])
𝐠 = Float32[-1.5193033, -1.5851523, -2.04893, -1.5471814, -1.5022987, -1.2056206, -1.4960097, -1.6127374, -1.528202, -1.5929381, -1.4918164, -1.5876769, -1.599696, -1.5008733, -1.5858833, -1.7561449, -1.033882, -1.5522914, -1.425071, -1.5456303, -1.1466767, -1.6216621, -1.0841335, -1.6034269, -1.4396887, -2.1658235, -1.8749877, -1.4993583, -1.8684943, -1.8874115, -1.5460135, -1.5593024, -3.1463819]
size(𝐠) = (33,)

d = ([1.2715212289774493], [0.9257148467376903])
𝐠 = Float32[-0.52817833, -0.43119496, -0.5210522, -0.4708877, -0.45764622, -0.42256445, -0.45563537, -0.43949524, -0.53112304, -0.43344456, -0.4039374, -0.43194032, -0.43552247, -0.4571392, -0.43146086, -0.4209185, -0.36306658, -0.42193952, -0.49689957, -0.47045025, -0.40221792, -0.44249928, -0.38054764, -0.43663415, -0.4390404, -0.56177676, -0.46146542, -0.45663112, -0.45890695, -0.4663252, -0.47067362, -0.423838, -0.9058165]
size(𝐠) = (33,

LoadError: AssertionError: isapprox(H_1 + H_2, H_batch, atol = 0.05)

In [167]:
g_1 = Float32[-0.52817833, -0.43119496, -0.5210522, -0.4708877, -0.45764622, -0.42256445, -0.45563537, -0.43949524, -0.53112304, -0.43344456, -0.4039374, -0.43194032, -0.43552247, -0.4571392, -0.43146086, -0.4209185, -0.36306658, -0.42193952, -0.49689957, -0.47045025, -0.40221792, -0.44249928, -0.38054764, -0.43663415, -0.4390404, -0.56177676, -0.46146542, -0.45663112, -0.45890695, -0.4663252, -0.47067362, -0.423838, -0.9058165]
g_2 = Float32[-0.9911245, -1.1539569, -1.5278772, -1.0762932, -1.044652, -0.78305584, -1.0403739, -1.1732417, -0.99707854, -1.1594931, -1.0878786, -1.1557361, -1.1641731, -1.0437337, -1.1544219, -1.3352259, -0.67081517, -1.1303514, -0.92817104, -1.0751797, -0.7444584, -1.1791624, -0.7035855, -1.1667923, -1.0006479, -1.6040461, -1.4135218, -1.0427268, -1.4095868, -1.4210858, -1.0753393, -1.1354641, -2.2405646]
g_batch = Float32[-1.5193033, -1.5851523, -2.04893, -1.5471814, -1.5022987, -1.2056206, -1.4960097, -1.6127374, -1.528202, -1.5929381, -1.4918164, -1.5876769, -1.599696, -1.5008733, -1.5858833, -1.7561449, -1.033882, -1.5522914, -1.425071, -1.5456303, -1.1466767, -1.6216621, -1.0841335, -1.6034269, -1.4396887, -2.1658235, -1.8749877, -1.4993583, -1.8684943, -1.8874115, -1.5460135, -1.5593024, -3.1463819]

33-element Vector{Float32}:
 -1.5193033
 -1.5851523
 -2.04893
 -1.5471814
 -1.5022987
 -1.2056206
 -1.4960097
 -1.6127374
 -1.528202
 -1.5929381
 -1.4918164
 -1.5876769
 -1.599696
  ⋮
 -1.6216621
 -1.0841335
 -1.6034269
 -1.4396887
 -2.1658235
 -1.8749877
 -1.4993583
 -1.8684943
 -1.8874115
 -1.5460135
 -1.5593024
 -3.1463819

In [174]:
g_1 * g_2'

33×33 Matrix{Float32}:
 0.52349   0.609495  0.806992  0.568475  …  0.567971  0.599728  1.18342
 0.427368  0.49758   0.658813  0.464092     0.463681  0.489606  0.96612
 0.516428  0.601272  0.796104  0.560805     0.560308  0.591636  1.16745
 0.466708  0.543384  0.719459  0.506813     0.506364  0.534676  1.05505
 0.453584  0.528104  0.699227  0.492562     0.492125  0.519641  1.02539
 0.418814  0.487621  0.645627  0.454803  …  0.4544    0.479807  0.946783
 0.451591  0.525784  0.696155  0.490397     0.489963  0.517358  1.02088
 0.435594  0.507159  0.671495  0.473026     0.472607  0.499031  0.984717
 0.526409  0.612893  0.811491  0.571644     0.571137  0.603071  1.19002
 0.429598  0.500176  0.66225   0.466513     0.4661    0.492161  0.971161
 0.400352  0.466126  0.617167  0.434755  …  0.43437   0.458656  0.905048
 0.428107  0.498441  0.659952  0.464894     0.464482  0.490453  0.96779
 0.431657  0.502574  0.665425  0.46875      0.468334  0.49452   0.975816
 ⋮                                  

In [176]:
(g_2 * g_1')'

33×33 adjoint(::Matrix{Float32}) with eltype Float32:
 0.52349   0.609495  0.806992  0.568475  …  0.567971  0.599728  1.18342
 0.427368  0.49758   0.658813  0.464092     0.463681  0.489606  0.96612
 0.516428  0.601272  0.796104  0.560805     0.560308  0.591636  1.16745
 0.466708  0.543384  0.719459  0.506813     0.506364  0.534676  1.05505
 0.453584  0.528104  0.699227  0.492562     0.492125  0.519641  1.02539
 0.418814  0.487621  0.645627  0.454803  …  0.4544    0.479807  0.946783
 0.451591  0.525784  0.696155  0.490397     0.489963  0.517358  1.02088
 0.435594  0.507159  0.671495  0.473026     0.472607  0.499031  0.984717
 0.526409  0.612893  0.811491  0.571644     0.571137  0.603071  1.19002
 0.429598  0.500176  0.66225   0.466513     0.4661    0.492161  0.971161
 0.400352  0.466126  0.617167  0.434755  …  0.43437   0.458656  0.905048
 0.428107  0.498441  0.659952  0.464894     0.464482  0.490453  0.96779
 0.431657  0.502574  0.665425  0.46875      0.468334  0.49452   0.975816
 ⋮   

In [172]:
isapprox(g_1 + g_2, g_batch, atol=.00005)

true

In [59]:
for d in data
    x, y = d
    @show x
    @show [x[1]]
    @show y
    break
end

x = [5.175778702715506 7.783548482579036]
[x[1]] = [5.175778702715506]
y = [-1.2076412794280516 0.7338655322191958]


In [158]:
# Plotting confidence range progression for increasing batchsizes
# Save to a series of PNG images, which can be merged to a gif via imagemagick
for batchsize in range(1, 1)
    # Now batch the same data, repeat the procedure
    Y = reduce(hcat, y)
    data = DataLoader((X, Y), batchsize=batchsize)
    # Fit LA
    if outdim == 1
        la = Laplace(nn; likelihood=likelihood, λ=λ, subset_of_weights=:last_layer)
        fit!(la, data)
        # optimize_prior!(la; verbose=true, show_every=10_000)
        plot(la, X, y, title="batchsize=$batchsize")                              # standard
        savefig(@sprintf("fig-01-%02d.png", batchsize))
        # plot(la, X, y; xlims=(-5, 5), ylims=(-5, 5))  # lims
        # plot(la, X, y; link_approx=:plugin)         # plugin approximation
    end
end