In [1]:
using Pkg
Pkg.activate(".")
using LaplaceRedux
using LaplaceRedux.Curvature
using LaplaceRedux.Data
using Flux
using Flux.Optimise: update!, Adam
using Plots
using Statistics
using MLUtils
using Zygote
using Printf
using NNlib
using BenchmarkTools
using Tullio

[32m[1m  Activating[22m[39m project at `C:\Users\Lenovo\Desktop\SP\LaplaceRedux.jl\dev\notebooks\KFAC`
[33m[1m│ [22m[39mFor performance reasons, it is recommended to upgrade to a driver that supports CUDA 11.2 or higher.
[33m[1m└ [22m[39m[90m@ CUDA C:\Users\Lenovo\.julia\packages\CUDA\pCcGc\src\initialization.jl:76[39m


In [2]:
n = 2000
data_dict = Dict()
bsize = 10

x, y = LaplaceRedux.Data.toy_data_regression(n)
xs = [[x] for x in x]
X, Y = reduce(hcat, x), reduce(hcat, y)

dataloader = DataLoader((X, Y), batchsize=bsize)
data = zip(xs, y)
data_dict[:regression] = Dict(
    :data => data,
    :X => X,
    :y => y,
    :outdim => 1,
    :loss_fun => :mse,
    :likelihood => :regression,
)

Dict{Symbol, Any} with 6 entries:
  :loss_fun   => :mse
  :y          => [-0.386202, 0.570752, 0.625995, 0.707377, -0.62453, 0.264788, …
  :likelihood => :regression
  :X          => [5.88109 6.67184 … 4.65069 2.39512]
  :outdim     => 1
  :data       => zip([[5.88109], [6.67184], [2.47541], [7.34438], [5.66666], [6…

In [3]:
# Train a NN model

val = data_dict[:regression]

# Unpack:
data = val[:data]
X = val[:X]
y = val[:y]
outdim = val[:outdim]
loss_fun = val[:loss_fun]
likelihood = val[:likelihood]

# Neural network:
n_hidden = 32
D = size(X, 1)
nn = Chain(Dense(D, n_hidden, σ), Dense(n_hidden, outdim))
λ = 0.01
sqnorm(x) = sum(abs2, x)
weight_regularization(λ=λ) = 1 / 2 * λ^2 * sum(sqnorm, Flux.params(nn))
loss(x, y) = getfield(Flux.Losses, loss_fun)(nn(x), y) + weight_regularization()


opt = Adam()
epochs = 200
avg_loss(data) = mean(map(d -> loss(d[1], d[2]), data))
show_every = epochs / 10

for epoch in 1:epochs
    for d in data
        gs = gradient(Flux.params(nn)) do
            l = loss(d...)
        end
        update!(opt, Flux.params(nn), gs)
    end
    if epoch % show_every == 0
        println("Epoch " * string(epoch))
        @show avg_loss(data)
    end
end

H_facs = nothing

[33m[1m│ [22m[39m  The input will be converted, but any earlier layers may be very slow.
[33m[1m│ [22m[39m  layer = Dense(1 => 32, σ)   [90m# 64 parameters[39m
[33m[1m│ [22m[39m  summary(x) = "1-element Vector{Float64}"
[33m[1m└ [22m[39m[90m@ Flux C:\Users\Lenovo\.julia\packages\Flux\EHgZm\src\layers\stateless.jl:60[39m


Epoch 20
avg_loss(data) = 0.14175186915526503
Epoch 40
avg_loss(data) = 0.10254603506120666
Epoch 60
avg_loss(data) = 0.10143080699237239
Epoch 80
avg_loss(data) = 0.10076707222884074
Epoch 100
avg_loss(data) = 0.10027311312150973
Epoch 120
avg_loss(data) = 0.09986135727059184
Epoch 140
avg_loss(data) = 0.09950496421290969
Epoch 160
avg_loss(data) = 0.09919309187483936
Epoch 180
avg_loss(data) = 0.0989243709846722
Epoch 200
avg_loss(data) = 0.09869832505084136


In [4]:
function get_array_sizes(array)
    if isa(array, AbstractArray)
        sizes = size(array)
        println("Array size: $sizes")
        
        for element in array
            get_array_sizes(element)
        end
    end
end

get_array_sizes (generic function with 1 method)

In [5]:
function print_arrays(array)
    if isa(array, AbstractArray)
        if ndims(array) == 2  # Check if it's a matrix
            println(array)
        elseif ndims(array) > 2  # Check if it's an array with more than two dimensions
            println(array)
        end

        for element in array
            print_arrays(element)
        end
    end
end


print_arrays (generic function with 1 method)

In [6]:
mutable struct Kron
    kfacs :: Union{AbstractArray,Nothing}
end

mutable struct KronDecomposed
    eigenvectors :: Union{AbstractArray,Nothing}
    eigenvalues :: Union{AbstractArray,Nothing}
    damping :: Bool
end


In [7]:
function init(model)
    kfacs = []

    for p in Flux.params(model)
        if ndims(p) == 1  # bias
            P = size(p, 1)
            push!(kfacs, [zeros(P, P)])
        elseif 4 >= ndims(p) >= 2  # fully connected or conv
            if ndims(p) == 2  # fully connected
                P_in, P_out = size(p)
            elseif ndims(p) > 2
                P_in, P_out = size(p, 1), prod(size(p)[2:end])
            end
            
            push!(kfacs, [
                zeros(P_in, P_in),
                zeros(P_out, P_out)
            ])
        else
            error("Invalid parameter shape in network.")
        end
    end
    @show kfacs
    @show get_array_sizes(kfacs)
    @show print_arrays(kfacs)
    return Kron(kfacs)
end

init (generic function with 1 method)

In [8]:
function fitBeta(la::Laplace, data; batched::Bool=false, batchsize::Int, override::Bool=true)
    if override
        H = init(la.model)          
        loss = 0.0
        n_data = 0
    end

    # Training:
    for d in data
        x, y = d
        loss_batch, H_batch =_curv_closure(la.curvature, x, y, length(data))
        loss += loss_batch
        @show(H_batch)
        @show(H)
        H = [[Fi + Fj for (Fi, Fj) in zip(Fi_row, Fj_row)] for (Fi_row, Fj_row) in zip(H.kfacs, H_batch.kfacs)]
        n_data += batchsize
    end

    # Store output:
    la.loss = loss                                                           # Loss
    la.H = H                                                                 # Hessian
    la.P = LaplaceRedux.posterior_precision(la)                              # posterior precision
    la.Σ = LaplaceRedux.posterior_covariance(la)                             # posterior covariance
    return la.n_data = n_data                                                # number of observations
end

fitBeta (generic function with 1 method)

In [9]:
function fitAux(la, train_loader, override=true, damping=false)
    if override
        H_facs = nothing
    end

    if !isnothing(H_facs)
        n_data_old = la.n_data
        n_data_new = length(train_loader)
        la.H = init(la.model) # re-init H non-decomposed
        # discount previous Kronecker factors to sum up properly together with new ones
        H_facs = _rescale_factors(H_facs, n_data_old / (n_data_old + n_data_new))
    end

    fitBeta(la, train_loader, batched=false, batchsize=1, override=override)

    if isnothing(H_facs)
        H_facs =la.H
    else
        # discount new factors that were computed assuming N = n_data_new
        la.H = _rescale_factors(la.H, n_data_new / (n_data_new + n_data_old))
        H_facs += la.H
    end
    # Decompose to self.H for all required quantities but keep H_facs for further inference
    la.H = decompose(la.H_facs, damping=damping)
end

fitAux (generic function with 3 methods)

In [10]:
function _curv_closure(curvature, x, y, N)
    return kron(curvature, x, y, N)#la.backend.kron(X, y, N=N)
end

_curv_closure (generic function with 1 method)

In [11]:
function kron(curvature, x, y, N)
    #context = ifelse(self.stochastic, KFAC, KFLR)
    loss = curvature.factor * curvature.loss_fun(x, y)
    𝐠 = gradient(() -> curvature.loss_fun(x, y), Flux.params(curvature.model))
    𝐠 = reduce(vcat, [vec(𝐠[i]') for i in curvature.params])  
    # backpack(context()) do
    #     backward(loss)
    # end
    kron = Kron(𝐠)
    @show(kron)
    kron = _rescale_kron_factors(kron, length(y), N)
    return curvature.factor * loss, kron#curvature.factor * detach(loss), curvature.factor * kron
end

kron (generic function with 1 method)

In [12]:
function decompose(kron,damping=false)
    """
    Eigendecompose Kronecker factors and turn into `KronDecomposed`.
    
    Parameters
    ----------
    damping : bool
        use damping

    Returns
    -------
    kron_decomposed : KronDecomposed
    """
    eigvecs = []
    eigvals = []
    for F in kron.kfacs
        Qs = []
        ls = []
        for Hi in F
            l, Q = symeig(Hi)
            push!(Qs, Q)
            push!(ls, l)
        push!(eigvecs, Qs)
        push!(eigvals, ls)
        end
    end
    return KronDecomposed(eigvecs, eigvals, damping=damping)
end

decompose (generic function with 2 methods)

In [13]:
function _rescale_factors(kron, factor)
    for F in kron.kfacs
        if length(F) == 2
            F[1] *= factor
        end
    end
    return kron
end

_rescale_factors (generic function with 1 method)

In [14]:
function _rescale_kron_factors(kron, M, N)
    # Renormalize Kronecker factor to sum up correctly over N data points with batches of M
    # for M=N (full-batch) just M/N=1
    for F in kron.kfacs
        if length(F) == 2
            F[1] *= M/N
        end
    end
    return kron
end

_rescale_kron_factors (generic function with 1 method)

In [20]:
function _get_kron_factors(la)
    return Kron([p.kfac for p in la.model.parameters])
end

_get_kron_factors (generic function with 1 method)

In [21]:
dataloader = DataLoader((X, Y), batchsize=10)

200-element DataLoader(::Tuple{Matrix{Float64}, Matrix{Float64}}, batchsize=10)
  with first element:
  (1×10 Matrix{Float64}, 1×10 Matrix{Float64},)

In [22]:
function fit_la(nn, dataloader, X, y)
    la_b = Laplace(nn; likelihood=:regression, λ=λ, subset_of_weights=:all)
    fitAux(la_b, dataloader)
    plot(la_b, X, y )
end

fit_la (generic function with 1 method)

In [23]:
fit_la(nn, dataloader, X, y)

kfacs = Any[[[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0

LoadError: MethodError: no method matching +(::Matrix{Float64}, ::Float32)
For element-wise addition, use broadcasting with dot syntax: array .+ scalar
[0mClosest candidates are:
[0m  +(::Any, ::Any, [91m::Any[39m, [91m::Any...[39m) at operators.jl:591
[0m  +([91m::T[39m, ::T) where T<:Union{Float16, Float32, Float64} at float.jl:383
[0m  +([91m::Union{InitialValues.NonspecificInitialValue, InitialValues.SpecificInitialValue{typeof(+)}}[39m, ::Any) at C:\Users\Lenovo\.julia\packages\InitialValues\OWP8V\src\InitialValues.jl:154
[0m  ...