In [None]:
using Knet, Plots, Statistics, LinearAlgebra, Base.Iterators, Random
# Narrow the column width Julia uses when displaying arrays in cell output.
ENV["COLUMNS"]=40
# Array type used for the GPU copies of the data and for model parameters.
ARRAY = KnetArray{Float64}

# Load MNIST data

In [None]:
# Pull in the MNIST loader that ships with Knet and read the train/test splits.
include(Knet.dir("data/mnist.jl"))
(xtrn, ytrn, xtst, ytst) = mnist()

# GPU copies of the inputs for batch training.
gtrn = ARRAY(xtrn)
gtst = ARRAY(xtst)

# Print one summary line per array, in the order listed.
for arr in (xtrn, ytrn, xtst, ytst, gtrn, gtst)
    println(summary(arr))
end

# Define regularized linear model with softmax loss

In [None]:
# Default L2 regularization coefficient (λ) picked up by the Linear constructors.
LAMBDA=1e-4

In [None]:
# Linear model with L2 regularization (without the penalty w blows up).
struct Linear
    w   # weights; any array that can be reshaped to 10 rows
    λ   # L2 penalty coefficient
end

# Constructor: wrap an existing weight array.
function Linear(w; λ=LAMBDA)
    return Linear(w, λ)
end

# Constructor: fresh parameters for `i` inputs and `o` outputs.
function Linear(i::Int, o::Int; λ=LAMBDA)
    weights = param(o, i, atype=ARRAY)
    return Linear(weights, λ)
end

# Predict: view the weights as a 10-row matrix and multiply with the input matrix.
function (f::Linear)(x)
    weights = reshape(f.w, 10, :)
    return weights * mat(x)
end

# Loss: negative log likelihood of the predictions plus (λ/2) * sum of squared weights.
function (f::Linear)(x, y)
    fit = nll(f(x), y)
    penalty = (f.λ/2) * sum(abs2, f.w)
    return fit + penalty
end

# Find minimum

In [None]:
# Find the minimum without minibatching (~50 iters/sec); the result is cached on disk.
if isfile("fmin03.jld2")
    # Cache hit: reuse the stored model and its loss history.
    fmin, losses = Knet.load("fmin03.jld2", "fmin", "losses")
else
    fmin = Linear(784,10)
    # The same full-dataset (x, y) pair presented 10000 times.
    data = repeat([(gtrn,ytrn)],10000)
    Knet.gc()
    # One loss value is collected per Adam step.
    losses = collect(progress(adam(fmin,data)))
    Knet.save("fmin03.jld2", "fmin", fmin, "losses", losses)
end
@show summary(fmin.w)
# Last 11 recorded losses, shown as a row.
losses[end-10:end]'

In [None]:
# Loss history with the y-axis restricted to [0.26, 0.28].
plot(losses, ylim=(.26,.28))

# Minibatch data

In [None]:
# Minibatching for SGD-II, shuffle=true corresponds to SGD-II, i.e. without replacement
# dtrn1 = minibatch(xtrn,ytrn,100;xtype=ARRAY,shuffle=true)
# x1,y1 = first(dtrn1)
# println.(summary.((dtrn1,x1,y1)));

In [None]:
# Minibatching for SGD-I, i.e. sampling WITH replacement. Knet.minibatch can't do
# this, so we define a small endless iterator:
#   x: inputs, one instance per column    y: labels    n: minibatch size
struct MB; x; y; n; end

# The size trait belongs on the TYPE: generic code (zip, flatten, collect, ...) asks
# IteratorSize(typeof(itr)). Instance queries still work through Base's fallback
# IteratorSize(x) = IteratorSize(typeof(x)).
Base.IteratorSize(::Type{MB}) = Base.IsInfinite()

# Each call draws n column indices uniformly with replacement and returns the
# matching (x, y) minibatch, with x converted to ARRAY. The incoming state is ignored
# and the returned state is a dummy, so iteration never terminates by itself.
function Base.iterate(d::MB, s...)
    r = rand(1:length(d.y), d.n)
    return ((ARRAY(d.x[:,r]), d.y[r]), true)
end
# Endless sampler of 100-instance minibatches over the training set. mat(xtrn) turns
# the images into a matrix (presumably one column per image; MB indexes columns).
dtrn = MB(mat(xtrn), ytrn, 100)
# Print summaries of the iterator and of the x and y of one sampled minibatch.
println.(summary.((dtrn,first(dtrn)...)));

# Record trajectory with SGD starting at minimum

In [None]:
# Learning rate for the SGD trajectory; also the default `lr` of diffusiontensor and
# the LR in the final (2/LR)*D comparison.
LR = 0.001

In [None]:
# Trajectory of w starting from wmin recorded after each update: ~1000 updates/sec
f = deepcopy(fmin)  # train a copy so that fmin itself stays at the minimum
f.w.opt = nothing  # We do not want to use Adam! => TODO: warn about this.
d = take(dtrn,9999)
# One column per snapshot: column 1 is the starting point, column k+1 follows update k.
W = zeros(eltype(f.w),length(f.w),1+length(d))
i = 1
W[:,i] = vec(Array(f.w))
Knet.gc()
for t in progress(sgd(f,d; lr=LR))
    # `global` keeps this counter update valid when the cell is run as a script:
    # without it, a top-level loop in a file treats `i` as a new local and the
    # increment fails. In the notebook the behavior is unchanged.
    global i += 1
    W[:,i] = vec(Array(f.w))
end
@show summary(W);
@show f.w.opt;

In [None]:
# Plot losses on whole dataset, first steps seem transient
# r picks every 10th snapshot (column of W); each one is copied to ARRAY, wrapped in a
# Linear model and evaluated with the regularized loss on the full training set.
r = 1:10:size(W,2)
@time plot(r, [Linear(ARRAY(W[:,i]))(gtrn,ytrn) for i in r])

In [None]:
# Plot trajectory of two random dimensions
# Seems to converge to a slightly different point?
# Interesting patterns: staircase, globe, H shaped
# r1 and r2 are two row indices of W (weight coordinates) drawn uniformly at random.
@show r1,r2 = rand(1:size(W,1)),rand(1:size(W,1))
scatter(W[r1,1:end],W[r2,1:end])
scatter!(W[r1,1:10], W[r2,1:10],mc=:red) # mark beginning (first 10 snapshots) with red
scatter!(W[r1,end-9:end],W[r2,end-9:end],mc=:yellow) # mark end (last 10 snapshots) with yellow

In [None]:
# Minibatch training seems to converge to a slightly worse spot
wmin = Array(vec(fmin.w.value))   # full-batch minimum as a CPU vector
μ = mean(W[:,2500:end],dims=2)    # mean of the snapshots from column 2500 onward
wlast = W[:,end]                  # last recorded snapshot
@show norm(wmin), norm(μ)
@show norm(wmin - μ)
@show extrema(wmin - μ)
# Fraction of coordinates where wmin and μ differ by more than 0.01.
@show mean(abs.(wmin - μ) .> 0.01)
# Regularized loss on the CPU training arrays at the minimum, the mean and the last snapshot.
@show Linear(wmin)(xtrn,ytrn)
@show Linear(μ)(xtrn,ytrn)
@show Linear(wlast)(xtrn,ytrn)

# Covariance of SGD trajectory around minimum

In [None]:
# Snapshots from column 2500 onward are treated as the stable part of the trajectory.
Wstable = W[:,2500:end];  @show summary(Wstable)
# Per-coordinate mean over the stable snapshots.
μ = mean(Wstable,dims=2); @show summary(μ)
# Centered snapshots.
Wzero = Wstable .- μ;     @show summary(Wzero)
# Covariance, normalized by the number of snapshots (not by that number minus one).
Σ = (Wzero * Wzero') / size(Wzero,2); @show summary(Σ)
@show norm(Σ),extrema(Σ);

In [None]:
# Σ is not positive definite, MLE fails because Hessian=inverse(Σ)
# Note that this is not the same as the loss Hessian defined below, it is the distribution Hessian!
# using Distributions
# b = fit_mle(MvNormal, Wstable)

In [None]:
# Widen the display so that the 5x5 slice below is printed in full.
ENV["COLUMNS"]=94
# A 5x5 diagonal block of Σ for visual inspection.
Σ[3001:3005,3001:3005]

In [None]:
# Check that Σ is exactly (elementwise) symmetric.
@show Σ == Σ'
# Disabled: eigendecomposition of Σ (slow).
#@time eigenΣ = eigen(Symmetric(Σ))  # ~53s
#eigenΣ.values'

In [None]:
#plot(eigenΣ.values .+ 1e-16, yscale=:log10)

# Hessian of loss around minimum

In [None]:
# Dense matrix of second derivatives of loss(w,x,y) with respect to w.
# Column i is the gradient, with respect to w, of the i-th component of ∇loss,
# so one extra differentiation pass is made per entry of w.
# The result is allocated with similar(w, n, n) and is returned without symmetrization.
function hessian(loss,w,x,y)
    gradient = grad(loss)
    # Scalar-valued function of w that selects one gradient component, differentiated again.
    dcomponent = grad((w,x,y,i) -> gradient(w,x,y)[i])
    dim = length(w)
    hess = similar(w,dim,dim)
    for col in progress(1:dim)
        hess[:,col] .= vec(dcomponent(w,x,y,col))
    end
    return hess # Symmetric(Array(0.5*(h+h')))
end

In [None]:
# Hessian of the loss at the minimum; computing it takes ~6 mins, so it is cached on disk.
if isfile("hess03.jld2")
    # Cache hit: reuse the stored matrix.
    hmin = Knet.load("hess03.jld2","h")
else
    Knet.gc()
    wmin = fmin.w.value
    # Loss as a function of the raw weight array, as hessian() expects.
    function loss(w,x,y)
        return Linear(w)(x,y)
    end
    hmin = hessian(loss,wmin,gtrn,ytrn)
    Knet.save("hess03.jld2","h",hmin)
end
(summary(hmin), norm(hmin), extrema(Array(hmin)))

In [None]:
# A 5x5 diagonal block of the computed Hessian for visual inspection.
display(hmin[3001:3005,3001:3005])
# Approximate-symmetry checks of hmin at two relative tolerances.
@show isapprox(hmin,hmin',rtol=0.2)
@show isapprox(hmin,hmin',rtol=0.3);

In [None]:
# Symmetrized Hessian (average of hmin and its transpose) as a CPU array.
H = Array(0.5*(hmin + hmin'));

In [None]:
#@time eigenH = eigen(Symmetric(H))  # ~53s
#eigenH.values'

In [None]:
#eigenH.values[4000:4010]'

In [None]:
#plot(eigenH.values)

# Hessian (numeric check)

In [None]:
# f(w) ≈ f(wmin) + (w-wmin)' g + 1/2 (w-wmin)' H (w-wmin)
# Gradient at wmin is ≈0, so the middle term can be assumed 0
# Record the regularized full-batch loss at fmin so that it can be differentiated.
df = @diff fmin(gtrn,ytrn)
# J: gradient of that loss with respect to fmin.w, flattened to a vector.
J = vec(grad(df, fmin.w)); @show summary(J)
# The gradient norm shows how close fmin is to a stationary point.
norm(J)

In [None]:
# We can sample points around wmin with distance ~ 1
# Mean Euclidean distance of the stable snapshots from their mean μ ...
@show mean(sqrt.(sum(abs2, Wstable .- μ,dims=1)))
# ... and from the full-batch minimum wmin.
@show mean(sqrt.(sum(abs2, Wstable .- vec(Array(wmin)), dims=1)));

In [None]:
# Apparently we do not need the 1/2 factor?
# hmin and hgpu give identical results?
# adding first order term does not make much difference as expected
# NOTE(review): the Taylor expansion quoted in this notebook carries a 1/2 on the
# quadratic term, while the predictions below omit it -- confirm which is intended.
hgpu = KnetArray(H)   # symmetrized Hessian H as a KnetArray
@show norm(hmin), norm(hgpu), norm(hmin-hgpu)
wmin = vec(fmin.w.value)
# Gaussian perturbation scaled by 1/sqrt(length) so that its expected squared norm is 1.
wrnd = randn!(similar(wmin)) / sqrt(length(wmin))
# One-argument method: regularized loss on the full training set.
loss(w) = Linear(w)(gtrn,ytrn)
@show loss(wmin)
# Actual loss at the perturbed point, followed by three quadratic-model predictions.
@show loss(wmin + wrnd)
@show loss(wmin) + wrnd' * hgpu * wrnd
@show loss(wmin) + wrnd' * hmin * wrnd
@show loss(wmin) + J' * wrnd + wrnd' * hmin * wrnd

# Diffusion Tensor

In [None]:
# Diffusion tensor estimated from per-minibatch gradients of `loss` at `w`:
#     D = lr^2/(2n) * sum_k (m - g_k) * (m - g_k)'
# where g_k is the gradient on the k-th (x, y) pair of `data`, m is the mean of the
# g_k and n is the number of pairs. `data` must be finite; it is consumed once.
# The accumulation runs on an ARRAY buffer and the result is returned as an Array.
function diffusiontensor(loss,w,data,lr=LR) # lr=0.1 is default for sgd
    gradfn = grad(loss)
    samples = [ gradfn(w,x,y) for (x,y) in data ]
    nsamples = length(samples)
    center = mean(samples)
    scale = lr^2/(2 * nsamples)
    dim = length(w)
    acc = ARRAY(zeros(dim,dim))
    for sample in progress(samples)
        dev = vec(center-sample)
        # acc += scale * dev*dev' (scaled outer product of the deviation)
        axpy!(scale,dev*dev',acc)
    end
    return Array(acc)
end

In [None]:
# Give wmin the 10-row matrix layout that Linear uses for prediction.
wmin = reshape(wmin,10,:)
# Summaries of wmin and of the x and y of one sampled minibatch.
summary.((wmin,first(dtrn)...))

In [None]:
# compute diffusion tensor ~24 secs/600 iter. should we go longer?
Knet.gc()
# Three-argument loss on raw weights, the form diffusiontensor differentiates.
loss(w,x,y) = Linear(w)(x,y)
# Estimate from 600 minibatches drawn from dtrn, using the default lr.
D = diffusiontensor(loss,wmin,take(dtrn,600));

In [None]:
# Types and sizes of the covariance Σ, symmetrized Hessian H and diffusion tensor D.
summary.((Σ,H,D))

In [None]:
# Exact (elementwise) symmetry checks of the three matrices.
@show Σ == Σ'
@show H == H'
@show D == D'

In [None]:
# Test the relation H*Σ + Σ*H = (2/LR)*D between Hessian, covariance and diffusion tensor.
a = H*Σ + Σ*H
b = (2/LR)*D
# Approximate equality with the default tolerances of isapprox.
a ≈ b

In [None]:
# Norms of both sides of the relation and of their difference.
norm(a),norm(b),norm(a-b)