In [None]:
using Knet, Plots, Statistics, LinearAlgebra, Base.Iterators, Random, StatsBase
# Global hyper-parameters for the experiment.
ARRAY     = Array{Float64}  # array type used for data, weights and accumulated tensors
LAMBDA    = 1e-4            # L2 regularization strength (loss adds LAMBDA/2 * sum(abs2, w))
LR        = 1e-3            # learning rate for SGD and the diffusion-tensor estimate
XSIZE     = 1               # input dimension
YSIZE     = 1               # output dimension
BATCHSIZE = 10              # minibatch size
DITER     = 10000           # minibatch gradients used to estimate the diffusion tensor
CITER     = 10000           # SGD trajectory length used for the covariance matrix
CINIT     = 5000            # burn-in: trajectory columns before CINIT are discarded for the covariance

In [None]:
# Synthetic regression problem: noisy samples of exp(-x^2) on [-Range, Range].
Range      = 3.0   # half-width of the sampled x interval
Incr       = 0.03  # spacing between consecutive x samples (controls the sample count)
Noise_std  = 0.1   # standard deviation of the additive Gaussian noise on y
HiddenSize = 2     # number of tanh hidden units in the network
"""
    gen_noisy_gaussian(; range=1.0, noise=0.1, incr=Incr)

Sample `x = -range:incr:range` and return `(x, y)` where
`y = exp.(-x.^2)` plus additive Gaussian noise of standard deviation `noise`.
Both arrays are converted to `ARRAY`.

The `range` keyword controls the sampled interval. `incr` defaults to the global
`Incr`.
"""
function gen_noisy_gaussian(; range=1.0, noise=0.1, incr=Incr)
    # `range` shadows Base.range inside this function; only the `:` operator is used here.
    x = collect(-range:incr:range)
    y = exp.(-x.^2) + randn(length(x))*noise  # additive gaussian noise
    return (ARRAY(x), ARRAY(y))
end
Random.seed!(4)  # reproducible noise
(xtrn, ytrn) = gen_noisy_gaussian(noise=Noise_std, range=Range)
# Drop the last sample (at x ≈ +Range) from both arrays.
# NOTE(review): the reason for dropping it is not stated — confirm intent.
foreach(pop!, (xtrn, ytrn))
atrn = xtrn  # alias used by later cells; refers to the same array object as xtrn

# Concatenate the four weight groups w[1], w[2], w[3], w[4] into one flat vector.
flat(w) = vcat((w[k] for k in 1:4)...)

# Split the flat parameter vector `wf` into the network's weight groups
# (input weights, hidden biases, output weights, output bias).
# The first three groups have HiddenSize entries each; the output bias is wf[end].
function unflat(wf)
    h = HiddenSize
    win  = wf[1:h]
    bhid = wf[h+1:2h]
    wout = wf[2h+1:3h]
    return (win, bhid, wout, wf[end])
end

# One-hidden-layer tanh network. For a vector `x` of N inputs, returns a 1×N row
# of predictions.
function pred(wf, x)
    win, bhid, wout, bout = unflat(wf)
    hidden = tanh.(win * x' .+ bhid)   # HiddenSize × N hidden activations
    return wout' * hidden .+ bout
end

# Mean squared error of the network on (x, y) plus the L2 penalty
# (LAMBDA/2) * sum(abs2, wf), applied to every entry of the flat parameter vector.
function loss(wf, x, y)
    residual = y' - pred(wf, x)            # 1×N row of errors
    mse = mean(abs2, residual)
    penalty = (LAMBDA / 2) * sum(abs2, wf)
    return mse + penalty
end
lossgradient = grad(loss)  # gradient with respect to the first argument (wf)

foreach(println, summary.((xtrn, ytrn)))

# Find minimum

In [None]:
# Find the minimum of the full-batch loss with Adam (no minibatching).
# NOTE(review): the original note "~50 iters/sec, converges to .267218 in 3 mins"
# may be stale for the current configuration — re-measure.
wminfile = "wmin_cov_test.jld2"
if true # !isfile(wminfile)   (caching disabled: always recompute)
    # Parameter count follows the layout used by `unflat`: three groups of
    # HiddenSize entries plus a scalar output bias.
    # Random (non-zero) init is required: a zero init cannot train a multi-layer net.
    wmin = Param(ARRAY(0.1*randn(3*HiddenSize + 1)))
    args = repeat([(wmin,xtrn,ytrn)],10000)   # 10000 full-batch Adam steps
    Knet.gc()
    losses = collect(progress(adam(loss,args)))
    Knet.save(wminfile, "wmin", wmin, "losses", losses)
else
    wmin, losses = Knet.load(wminfile, "wmin", "losses");
end
@show value(wmin)
losses[end-4:end]'

In [None]:
plot(xtrn, hcat(ytrn, vec(pred(wmin, xtrn))))  # noisy targets vs. network fit at wmin

# Hessian of loss around minimum

In [None]:
"""
    hessian(loss, w, x, y)

Return the dense `n×n` Hessian (`n = length(w)`) of `loss(w, x, y)` with respect
to `w`. It is computed one column at a time with nested AutoGrad `grad` calls.
`w` may be a `Param` or a plain array; only its value is used.
"""
function hessian(loss,w,x,y)
    ∇loss = grad(loss)                       # gradient w.r.t. the first argument
    ∇lossi(w,x,y,i) = ∇loss(w,x,y)[i]        # i-th gradient component (a scalar)
    ∇∇lossi = grad(∇lossi)                   # its gradient = column i of the Hessian
    w = value(w)
    n = length(w)
    h = similar(w,n,n)
    for i in progress(1:n)
        col = ∇∇lossi(w,x,y,i)
        # AutoGrad returns `nothing` when the differentiated value does not depend
        # on w; the corresponding Hessian column is then identically zero.
        if col === nothing
            h[:,i] .= 0
        else
            h[:,i] .= vec(col)
        end
    end
    return h
end

In [None]:
# Hessian of the full-batch loss at wmin.
# Original note: ~6 mins, ~4:20 with slower _logp? TODO: reoptimize loss.jl
hessfile = "hess_cov_test.jld2"
if true # !isfile(hessfile)
    Knet.gc()
    hmin = hessian(loss, wmin, atrn, ytrn)
    Knet.save(hessfile, "h", hmin)
else
    hmin = Knet.load(hessfile, "h")
end
# Print: summary, value range, Frobenius norm, and asymmetry norm(hmin - hmin').
for item in (summary(hmin), extrema(Array(hmin)), norm(hmin), norm(hmin - hmin'))
    println(item)
end

In [None]:
# Eigendecomposition of the symmetrized Hessian.
heigfile = "heig_cov_test.jld2"
H = Symmetric(Array((hmin + hmin') / 2))   # average with transpose to remove numerical asymmetry
if true # !isfile(heigfile)
    @time eigenH = eigen(H)
    Knet.save(heigfile, "eigenH", eigenH)
else
    eigenH = Knet.load(heigfile, "eigenH")
end
eigenH.values'

In [None]:
#plot(eigenH.values, yscale=:log10) |> display
#describe(eigenH.values)

# Hessian (numeric check)

In [None]:
# Second-order Taylor expansion around wmin:
#   f(w) ≈ f(wmin) + (w-wmin)' g + 1/2 (w-wmin)' H (w-wmin)
# At a minimum the gradient g should be ≈ 0; its norm is printed to confirm this.
df = @diff loss(wmin, atrn, ytrn)
J = grad(df, wmin) |> vec
@show summary(J)
@show norm(J);

In [None]:
# Test approx at ~0.1 distance around wmin
# adding first order term does not make much difference as expected
# Compare the true full-batch loss at a random nearby point with its quadratic
# Taylor model, both without and with the first-order (gradient) term.
# wrnd: random perturbation, each entry ~ N(0, 0.1^2/length(wmin)), so its norm is ≈ 0.1.
wrnd = 0.1 * randn!(similar(wmin)) / sqrt(length(wmin))
lossw(w) = loss(w,atrn,ytrn)  # full-batch loss as a function of the weights only
@show lossw(wmin)  # f(wmin)
@show lossw(wmin + wrnd)  # true loss at the perturbed point
@show lossw(wmin) + 0.5 * wrnd' * hmin * wrnd  # quadratic model, gradient term dropped
@show lossw(wmin) + J' * wrnd + 0.5 * wrnd' * hmin * wrnd  # quadratic model incl. gradient term

# Minibatch data

In [None]:
# Minibatching for SGD-I, i.e. sampling with replacement. Knet.minibatch can't do
# this, so we define our own infinite iterator.
"""
    MB(x, y, n)

Infinite minibatch iterator over the dataset `(x, y)`. Each iteration draws `n`
indices uniformly *with replacement* from `1:length(y)` and yields `(x[r], y[r])`.
Iteration never terminates; wrap it in `take(d, k)` to get `k` minibatches.
"""
struct MB{X,Y}
    x::X
    y::Y
    n::Int
end

# The size trait must be declared on the *type*. Base wrappers such as generators,
# `enumerate` and `zip` query `IteratorSize(typeof(itr))`. A method defined only on
# instances leaves that query at Base's default `HasLength()`, which leads to a
# `length(::MB)` MethodError. The instance form `IteratorSize(d)` still works
# because Base forwards it to the type.
Base.IteratorSize(::Type{<:MB}) = Base.IsInfinite()

# The iterator keeps no state; `true` is a dummy state value.
function Base.iterate(d::MB, s...)
    r = rand(1:length(d.y), d.n)
    return ((d.x[r], d.y[r]), true)
end
dtrn = MB(xtrn, ytrn, BATCHSIZE)  # infinite stream of with-replacement minibatches
foreach(println, summary.((xtrn, ytrn, first(dtrn)...)))

# Diffusion Tensor

In [None]:
"""
    diffusiontensor(loss, w, x, y; iters=DITER, lr=LR, batchsize=BATCHSIZE)

Estimate the SGD diffusion tensor at `w`:

    D = lr^2 / (2*iters) * Σ_t (g₀ - g_t)(g₀ - g_t)'

`g₀` is the full-batch gradient of `loss` at `w`. Each `g_t` is the gradient on
one of `iters` minibatches of `batchsize` samples, drawn with replacement from
`(x, y)` via `MB`. Returns a dense `length(w)×length(w)` `Array`.
"""
function diffusiontensor(loss,w,x,y;iters=DITER,lr=LR,batchsize=BATCHSIZE)
    ∇loss = grad(loss)
    gfull = ∇loss(w, ARRAY(x), y)                   # full-batch gradient at w
    batches = take(MB(x, y, batchsize), iters)
    minibatch_grads = (∇loss(w, xb, yb) for (xb, yb) in batches)
    scale = lr^2 / (2iters)
    npar = length(w)
    acc = ARRAY(zeros(npar, npar))
    for gb in progress(minibatch_grads)
        δ = vec(gfull - gb)
        axpy!(scale, δ * δ', acc)                   # acc += scale * δδ'
    end
    return Array(acc)
end

In [None]:
# Diffusion tensor at wmin (original note: ~20 iters/sec, ~1000 iters/min).
dtfile = "dt_cov_test.jld2"
if true # !isfile(dtfile)
    Knet.gc()
    D = diffusiontensor(loss, wmin, xtrn, ytrn)
    Knet.save(dtfile, "D", D)
else
    D = Knet.load(dtfile, "D")
end
dump(summarystats(vec(D)))

In [None]:
#Convergence (with LR=0.1):
#norm(d100),norm(d1000),norm(d100-d1000)
#(9.006901905269366e-5, 7.966566524654371e-5, 4.746245825247884e-5)
#norm(d1000),norm(d2000),norm(d1000-d2000)
#(7.966566524654371e-5, 7.976183314696048e-5, 1.8908596356951573e-5)
#norm(d4000),norm(d2000),norm(d4000-d2000)
#(8.024754500933944e-5, 7.976183314696048e-5, 1.3446098748867312e-5)
#norm(d10000),norm(d4000),norm(d10000-d4000)
#(7.942869188760434e-5, 8.024754500933944e-5, 8.963487531775732e-6)

In [None]:
# Eigendecomposition of the (symmetric) diffusion tensor.
deigfile = "deig_cov_test.jld2"
if true # !isfile(deigfile)
    @time eigenD = eigen(Symmetric(D))
    Knet.save(deigfile, "eigenD", eigenD)
else
    eigenD = Knet.load(deigfile, "eigenD")
end
eigenD.values'

In [None]:
#plot(eigenD.values .+ 1e-23, yscale=:log10) |> display
#summarystats(eigenD.values) |> dump

# Record trajectory with SGD starting at minimum

In [None]:
# Trajectory of w starting from wmin recorded after each update:
# ~1000 updates/sec, ~16 secs total
# Runs SGD (lr=LR, minibatches of BATCHSIZE drawn with replacement by MB) from a
# copy of wmin and stores the weights after every step as a column of W.
# W has 1 + (CITER-1) = CITER columns; column 1 is the starting point wmin.
trajfile = "traj_cov_test.jld2"
if true # !isfile(trajfile)
    w = Param(ARRAY(value(wmin)))  # the ARRAY constructor copies, so SGD below does not modify wmin
    data = MB(xtrn,ytrn,BATCHSIZE)
    d = take(data,CITER-1)  # CITER-1 SGD steps
    W = zeros(eltype(w),length(w),1+length(d))
    i = 1; W[:,i] = Array(vec(w))
    f(x,y) = loss(w,x,y)  # minibatch loss; w is captured from the enclosing scope
    Knet.gc()
    # Knet's sgd(f, d) yields once per minibatch after updating, in place, the Params
    # that f depends on (here w). Reading w inside the loop therefore records the
    # post-update weights.
    for t in progress(sgd(f,d; lr=LR))
        i += 1
        W[:,i] = Array(vec(w))
    end
    Knet.save(trajfile,"W",W)
else
    W = Knet.load(trajfile,"W")
end
summary(W)

In [None]:
# Full-dataset loss along the SGD trajectory, evaluated at every 10th recorded step.
# The first steps look transient. Original note: ~10 secs.
r = 1:10:size(W,2)
@time plot(r, map(i -> loss(ARRAY(W[:, i]), atrn, ytrn), r))

In [None]:
# Plot the SGD trajectory projected onto two random weight coordinates.
# Observations: seems to converge to a slightly different point?
# Interesting patterns: staircase, globe, H shaped.
# Draw two *distinct* coordinates (no replacement). With two independent draws they
# could coincide, which collapses the plot onto the diagonal.
r1, r2 = sample(1:size(W,1), 2; replace=false)
@show r1, r2
scatter(W[r1,:], W[r2,:])                                # whole trajectory
scatter!(W[r1,1:10], W[r2,1:10], mc=:red)                # mark beginning with red
scatter!(W[r1,end-9:end], W[r2,end-9:end], mc=:yellow)   # mark end with yellow

In [None]:
# Minibatch training seems to converge to a slightly worse spot
# Compare the full-batch optimum w0 with the mean μ of the post-burn-in SGD iterates
# (columns CINIT:end of W) and with the final SGD iterate w1.
w0 = Array(value(wmin))
μ = mean(W[:,CINIT:end],dims=2)  # length(w0)×1 matrix (dims=2 keeps a column); w0 - μ is then n×1
w1 = W[:,end]
@show norm(w0), norm(μ), norm(w0 - μ)
@show extrema(w0), extrema(μ), extrema(w0 - μ)
@show mean(abs.(w0 - μ) .> 0.01)  # fraction of coordinates that differ by more than 0.01
@show loss(w0,xtrn,ytrn)
@show loss(μ,xtrn,ytrn)
@show loss(w1,xtrn,ytrn)

# Covariance of SGD trajectory around minimum

In [None]:
# Empirical covariance of the post-burn-in SGD iterates. Rows of W are parameters
# and columns are time steps, so Σ is length(w)×length(w). It is normalized by the
# number of columns (not N-1).
# NOTE(review): successive SGD iterates are correlated, so the effective sample
# size is smaller than size(Wzero,2).
Wstable = W[:,CINIT:end];  @show summary(Wstable)
μ = mean(Wstable,dims=2); @show summary(μ)
Wzero = Wstable .- μ;     @show summary(Wzero)
Σ = (Wzero * Wzero') / size(Wzero,2); @show summary(Σ)
@show norm(Σ),extrema(Σ);

In [None]:
# Eigendecomposition of the empirical SGD covariance Σ.
# The cache name follows the "*_cov_test.jld2" convention of the other caches in this
# notebook, so it cannot collide with a file from a different experiment.
ceigfile = "ceig_cov_test.jld2"
if true # !isfile(ceigfile)
    @time eigenC = eigen(Symmetric(Σ))
    Knet.save(ceigfile, "eigenC", eigenC)
else
    eigenC = Knet.load(ceigfile, "eigenC")
end
eigenC.values'

In [None]:
#plot(eigenC.values .+ 1e-19, yscale=:log10) |> display
#summarystats(eigenC.values) |> dump

# Check equation

In [None]:
map(summary, (H, D, Σ))  # shapes/types of Hessian, diffusion tensor, covariance

In [None]:
# Stationary-covariance relation to test: H Σ + Σ H ≈ (2/LR) D.
# Note: isapprox uses its default rtol (√eps), which is very tight for a Monte-Carlo
# estimate; see the norm comparison in the next cell.
a, b = H*Σ + Σ*H, (2/LR)*D
isapprox(a, b)

In [None]:
map(norm, (a, b, a - b))  # size of each side and of their difference

# Try fit_mle for covariance: gives the same result

In [None]:
# Cross-check Σ against the maximum-likelihood Gaussian fit of the same iterates.
# The fitted Gaussian is a distribution over weights: its Hessian is inverse(Σ), which
# is not the loss Hessian H. On mnist the MLE failed because Σ was not positive definite.
using Distributions
mvn = fit_mle(MvNormal, Wstable)   # named `mvn` so it does not shadow the exported `fit` function
C = Matrix(mvn.Σ)                  # dense copy of the fitted covariance (any parameter count)

In [None]:
# Same stationary-covariance relation, using the MLE covariance C instead of Σ.
a, b = H*C + C*H, (2/LR)*D
isapprox(a, b)

In [None]:
map(norm, (a, b, a - b))  # size of each side and of their difference

In [None]:
isapprox(C, Σ)  # MLE covariance should match the hand-computed Σ