# Linear model on MNIST

In [None]:
using Knet, Plots, Statistics, LinearAlgebra, Base.Iterators, Random, StatsBase
ENV["COLUMNS"] = 80
ARRAY = KnetArray{Float64}   # array type for GPU-side data and parameters

# Problem dimensions
XSIZE = 100      # input dimension
YSIZE = 10       # output dimension

# Optimization settings
BATCHSIZE = 100  # minibatch size
LAMBDA = 1e-2    # regularization parameter
LR = 1e-1        # learning rate

# Iteration budgets
MITER = 5000     # iterations for finding minimum
DITER = 10^5     # iterations for diffusion tensor
CITER = 10^6     # iterations for covariance trajectory
CFREQ = 10^1     # keep every CFREQ points on trajectory

@show gpu();

# Define regularized linear model with softmax loss

In [None]:
# Linear scores: w is viewed as a YSIZE×XSIZE weight matrix, x as XSIZE×batch inputs.
function pred(w, x)
    wmat = reshape(w, YSIZE, XSIZE)
    xmat = reshape(x, XSIZE, :)
    return wmat * xmat
end

# Regularized loss: softmax negative log-likelihood of the scores plus (λ/2)·‖w‖².
function loss(w, x, y; λ=LAMBDA)
    return nll(pred(w, x), y) + (λ / 2) * sum(abs2, w)
end

# Load MNIST data

In [None]:
# Load the 10×10 MNIST variant from disk
xtrn, ytrn, xtst, ytst = Knet.load("mnist10x10.jld2", "xtrn", "ytrn", "xtst", "ytst")
# GPU copies for batch training
atrn = ARRAY(xtrn)
atst = ARRAY(xtst)
foreach(println ∘ summary, (xtrn, ytrn, xtst, ytst, atrn, atst));

# Minibatch data

In [None]:
# Minibatching for SGD-I, i.e. sampling with replacement. Knet.minibatch can't do this,
# so we define a new iterator type.
#   x: input data (flattened with `mat` on every draw), y: labels, n: batch size.
# Fields are type-parameterized so that accesses to d.x and d.y are type-stable.
struct MB{X,Y}
    x::X
    y::Y
    n::Int
end

# Declare the iterator infinite, so that `take(d, k)` has a known length
# (needed for collect to work). `<:MB` is required because MB is parametric.
Base.IteratorSize(::Type{<:MB}) = Base.IsInfinite()

# Each call draws d.n indices uniformly with replacement and returns the
# (ARRAY-converted input batch, label batch) pair; the iteration state is ignored.
function Base.iterate(d::MB, state...)
    r = rand(1:length(d.y), d.n)
    return ((ARRAY(mat(d.x)[:, r]), d.y[r]), true)
end
dtrn = MB(xtrn, ytrn, BATCHSIZE)            # with-replacement sampler over the training set
foreach(println ∘ summary, first(dtrn));    # shapes of one sampled (x, y) batch

# Find minimum

In [None]:
(LAMBDA, MITER)   # settings that key the cached minimum file below

In [None]:
# Find minimum without minibatching: full-batch Adam starting from w = 0.
# ~50 iters/sec, LAMBDA=1e-2, MITER=5000, converges in 1 min to
# (trnloss = 0.88162015413913, nll = 0.6166511380346865, reg = 0.2649690161044434)
# The result and the loss history are cached on disk, keyed by LAMBDA and MITER.
wminfile = "mnist10x10linear-wmin-$LAMBDA-$MITER.jld2"
if isfile(wminfile)
    wmin, losses = Knet.load(wminfile, "wmin", "losses")
else
    wmin = Param(ARRAY(zeros(XSIZE*YSIZE)))
    # The same full-batch (w, x, y) tuple, repeated MITER times, is the "data" for adam
    args = repeat([(wmin, atrn, ytrn)], MITER)
    Knet.gc()
    losses = collect(progress(adam(loss, args)))
    Knet.save(wminfile, "wmin", wmin, "losses", losses)
end
@show summary(wmin)
losses[1000:1000:end]'

In [None]:
# Loss decomposition (total, nll, L2 penalty) at wmin on train and test, then accuracies
foreach(println, (
    (trnloss=loss(wmin,atrn,ytrn), nll=nll(pred(wmin,atrn),ytrn), reg=(LAMBDA/2)*sum(abs2,wmin)),
    (tstloss=loss(wmin,atst,ytst), nll=nll(pred(wmin,atst),ytst), reg=(LAMBDA/2)*sum(abs2,wmin)),
    (trnacc=accuracy(pred(wmin,atrn),ytrn), tstacc=accuracy(pred(wmin,atst),ytst)),
))

# Hessian of loss around minimum

In [None]:
# Dense Hessian of `loss(w, x, y)` with respect to w, built one column at a time.
# Column k is the gradient (w.r.t. w) of the k-th entry of ∇loss, obtained by
# applying AutoGrad's `grad` a second time. This costs length(w) second-order passes.
# Returns a CPU matrix with the element type of w.
function hessian(loss,w,x,y)
    gradfn = grad(loss)                                  # w ↦ ∇loss(w, x, y)
    kth_partial(wv, xv, yv, k) = gradfn(wv, xv, yv)[k]   # w ↦ ∂loss/∂w[k]
    hesscol = grad(kth_partial)                          # w ↦ ∇(∂loss/∂w[k])
    wval = value(w)                                      # unwrap a Param, if any
    n = length(wval)
    hmat = similar(Array(wval), n, n)
    for k in progress(1:n)
        hmat[:, k] .= Array(vec(hesscol(wval, x, y, k)))
    end
    return hmat
end

In [None]:
# Hessian at wmin (~15 secs to compute), cached on disk keyed by LAMBDA
hessfile = "mnist10x10linear-hess-$LAMBDA.jld2"
if isfile(hessfile)
    H = Knet.load(hessfile,"h")
else
    Knet.gc()
    H = hessian(loss,wmin,atrn,ytrn)
    Knet.save(hessfile,"h",H)
end
dump(summarystats(vec(H)))
@show norm(H), norm(H - H');   # second entry measures numerical asymmetry

# Eigenvalues of the Hessian

In [None]:
# Eigendecomposition of H viewed as Symmetric (upper triangle), cached keyed by LAMBDA
heigfile = "mnist10x10linear-heig-$LAMBDA.jld2"
if isfile(heigfile)
    eigenH = Knet.load(heigfile,"eigenH")
else
    @time eigenH = eigen(Symmetric(H)) # ~1s
    Knet.save(heigfile,"eigenH",eigenH)
end
eigenH.values'

In [None]:
dump(summarystats(eigenH.values))
plot(reverse(eigenH.values); yscale=:log10)   # eigenvalues, largest first

# Diffusion Tensor

In [None]:
# Estimate the SGD diffusion tensor at w:
#   D = lr²/(2·iters) · Σ_t (g0 - g_t)(g0 - g_t)'
# where g0 is the full-batch gradient of `loss` at w and g_t are `iters` minibatch
# gradients, each from `batchsize` examples drawn with replacement from (x, y).
# x, y are the CPU training data; returns a dense length(w)×length(w) CPU matrix.
function diffusiontensor(loss,w,x,y;iters=DITER,lr=LR,batchsize=BATCHSIZE)
    ∇loss = grad(loss)
    grad0 = Array(∇loss(w, ARRAY(x), y))   # full-batch gradient
    data = MB(x,y,batchsize)
    # One column per minibatch gradient. This is large: e.g. 1000×10^5 Float64 ≈ 800MB.
    grads = zeros(length(w), iters)
    for (i,d) in progress(enumerate(take(data,iters)))
        grads[:,i] .= Array(∇loss(w,d...))
    end
    prefac = (lr^2)/(2iters)
    # Center in place: `grads = grad0 .- grads` would allocate a second matrix of the
    # same (large) size while the first is still alive.
    grads .= grad0 .- grads
    @time v = prefac * (grads * grads')
    return v
end

In [None]:
(LAMBDA, LR, BATCHSIZE, DITER)   # settings that key the diffusion-tensor cache below

In [None]:
# Diffusion tensor at wmin, cached on disk keyed by LAMBDA, LR, BATCHSIZE and DITER
dtfile = "mnist10x10linear-diff-$LAMBDA-$LR-$BATCHSIZE-$DITER.jld2"
if isfile(dtfile)
    D = Knet.load(dtfile,"D")
else
    Knet.gc()
    D = diffusiontensor(loss,wmin,xtrn,ytrn) # ~1600 iters/sec, 60secs total
    Knet.save(dtfile,"D",D)
end
dump(summarystats(vec(D)))
@show norm(D), norm(D - D');   # second entry measures numerical asymmetry

# Eigenvalues of the diffusion tensor

In [None]:
# Eigendecomposition of D viewed as Symmetric, cached with the same key as D
deigfile = "mnist10x10linear-deig-$LAMBDA-$LR-$BATCHSIZE-$DITER.jld2"
if isfile(deigfile)
    eigenD = Knet.load(deigfile,"eigenD")
else
    @time eigenD = eigen(Symmetric(D)) # ~1s
    Knet.save(deigfile,"eigenD",eigenD)
end
eigenD.values'

In [None]:
dump(summarystats(eigenD.values))
# tiny offset, presumably so zero eigenvalues can be drawn on the log axis
plot(reverse(eigenD.values) .+ 1e-20; yscale=:log10)

# Record trajectory with SGD starting at minimum

In [None]:
(LAMBDA, LR, BATCHSIZE, CITER, CFREQ)   # settings that key the trajectory cache below

In [None]:
# Trajectory of w starting from wmin recorded after each update:
# ~1800 updates/sec, ~9 mins total for CITER=1M
# sgd updates w in place. Every CFREQ-th iterate is stored as a column of W, giving a
# length(w) × (CITER ÷ CFREQ) matrix that is cached on disk keyed by all the settings.
trajfile = "mnist10x10linear-traj-$LAMBDA-$LR-$BATCHSIZE-$CITER-$CFREQ.jld2"
if !isfile(trajfile)
    w = Param(ARRAY(value(wmin)))
    data = MB(xtrn,ytrn,BATCHSIZE)
    d = take(data,CITER)
    W = zeros(eltype(w),length(w),div(CITER,CFREQ))
    f(x,y) = loss(w,x,y)
    Knet.gc()
    i = 0
    for t in progress(sgd(f,d; lr=LR))
        global i   # explicit, so the counter works outside soft-scope (e.g. in scripts)
        i += 1
        # Use % and ÷ rather than assigning to names `div`/`rem`. Rebinding Base.div
        # (already called above in this cell) at top level is fragile and can error.
        if i % CFREQ == 0
            W[:, i ÷ CFREQ] = Array(value(w))
        end
    end
    Knet.save(trajfile,"W",W)
else
    W = Knet.load(trajfile,"W")
end
summary(W)

In [None]:
# Full-training-set loss at every 100th stored point (~6 secs); the first steps look transient
r = range(1, size(W,2); step=100)
@time plot(r, map(i -> loss(ARRAY(W[:,i]), atrn, ytrn), r))

In [None]:
# Joint histogram of the trajectory along two randomly chosen coordinates
@show i1, i2 = rand(1:size(W,1)), rand(1:size(W,1))
if all(k -> norm(W[k,:]) > 0, (i1, i2))   # skip coordinates that are identically zero
    histogram2d(W[i1,:], W[i2,:])
end

In [None]:
# Minibatch training seems to converge to a slightly worse spot, but mean is close to optimal.
# w0: full-batch minimum (CPU copy); μ: mean of stored iterates from column 2500 on;
# w1: last stored iterate. Losses are evaluated on the CPU training data.
w0, μ, w1 = Array(value(wmin)), mean(W[:, 2500:end]; dims=2), W[:, end]
@show norm(w0), norm(μ), norm(w0 - μ)
@show extrema(w0), extrema(μ), extrema(w0 - μ)
@show mean(abs.(w0 - μ) .> 0.01)   # fraction of coordinates differing by more than 0.01
@show loss(w0, xtrn, ytrn)
@show loss(μ, xtrn, ytrn)
@show loss(w1, xtrn, ytrn)

# Covariance of SGD trajectory around minimum

In [None]:
# Empirical covariance Σ of the stored SGD iterates (1/N normalization, N = columns of W).
# To discard the initial transient, restrict Wstable to W[:,2500:end] instead of all of W.
Wstable = W
μ = mean(Wstable,dims=2); @show summary(μ)
Wzero = Wstable .- μ;     @show summary(Wzero)
@time Σ = (Wzero * Wzero') / size(Wzero,2) # ~1s
summarystats(vec(Σ)) |> dump
@show norm(Σ), norm(Σ-Σ');   # norm(Σ-Σ') measures numerical asymmetry
@show summary(Σ)
@show extrema(Σ)
@show norm(diag(Σ));

In [None]:
# Check for convergence: estimate the covariance separately on each half of the trajectory
n2 = size(W,2) ÷ 2
w1, w2 = W[:, 1:n2], W[:, n2+1:end]
w1 .-= mean(w1, dims=2)   # w1, w2 are fresh copies, so centering in place is safe
w2 .-= mean(w2, dims=2)
Σ1, Σ2 = (w1 * w1') / size(w1,2), (w2 * w2') / size(w2,2);

In [None]:
# The variances (diagonal elements) of the two half-trajectory estimates agree
(norm(diag(Σ1)), norm(diag(Σ2)), norm(diag(Σ1) - diag(Σ2)))

In [None]:
# The full matrices, including off-diagonal covariances, have not converged yet
(norm(Σ1), norm(Σ2), norm(Σ1 - Σ2))

# Eigenvalues of the covariance

In [None]:
# Eigendecomposition of Σ viewed as Symmetric, cached keyed by the trajectory settings
ceigfile = "mnist10x10linear-ceig-$LAMBDA-$LR-$BATCHSIZE-$CITER-$CFREQ.jld2"
if isfile(ceigfile)
    eigenC = Knet.load(ceigfile,"eigenC")
else
    @time eigenC = eigen(Symmetric(Σ)) # ~1s
    Knet.save(ceigfile,"eigenC",eigenC)
end
eigenC.values'

In [None]:
dump(summarystats(eigenC.values))
# A linear y-axis suffices here; a log axis would need an offset such as .+ 1e-19
plot(reverse(eigenC.values))

# Check the stationary covariance equation HΣ + ΣH = (2/LR)·D

In [None]:
map(summary, (H, D, Σ))   # the three matrices appearing in the equation

In [None]:
# Left and right sides of HΣ + ΣH = (2/LR)·D, compared by norm
a = H*Σ + Σ*H
b = (2/LR)*D
(norm(a), norm(b), norm(a - b))

In [None]:
# Solve for Sigma as in Mike's notes: find S with H*S + S*H = D, given eigenH = eigen of H.
# With H = O*Λ*O', the equation decouples in the eigenbasis:
#   (O'SO)[i,j] = (O'DO)[i,j] / (λi + λj)
function fsigma(eigenH,D)
    λ = eigenH.values
    O = eigenH.vectors
    Drot = O' * D * O                 # D expressed in the eigenbasis of H
    Δ = zero(D)
    Δ .= Drot ./ (λ .+ λ')            # λ .+ λ' is the matrix of pairwise sums λi + λj
    return O * Δ * O'                 # rotate back to the original basis
end

In [None]:
S = fsigma(eigenH, D)   # solves H*S + S*H = D (no 2/LR factor)
(norm(Σ), norm(S), norm(Σ - S))

In [None]:
dΣ, dS = diag(Σ), diag(S)   # compare the variances only
(norm(dΣ), norm(dS), norm(dΣ - dS))

In [None]:
dump(summarystats(dΣ ./ dS))   # per-coordinate ratio of observed to predicted variance

In [None]:
# Include the 2/LR factor from HΣ + ΣH = (2/LR)·D (fsigma solves H*S + S*H = D).
# For LR = 0.1 this is the factor 20 the name refers to.
S20 = (2/LR)*fsigma(eigenH,D)
norm(Σ),norm(S20),norm(Σ-S20)

In [None]:
dΣ, dS20 = diag(Σ), diag(S20)   # compare the variances only
(norm(dΣ), norm(dS20), norm(dΣ - dS20))

In [None]:
(S == S', Σ == Σ')   # exact (bitwise) symmetry checks

In [None]:
(norm(S), norm(S - S'))   # size of S and of its numerical asymmetry

In [None]:
(mean(diag(S20)), mean(S20 - Diagonal(S20)))   # mean variance; mean over all entries with diagonal zeroed

In [None]:
(mean(diag(Σ)), mean(Σ - Diagonal(Σ)))   # mean variance; mean over all entries with diagonal zeroed

# Scratch work below this line (exploratory checks, not part of the main analysis)

In [None]:
# norm(a),norm(b),norm(a-b)
# (0.001831031218512692, 0.0015956482650563955, 0.000668750876334433): CITER=1M CFREQ=100
# (0.0017456054525856169, 0.0015956482650563955, 0.00044092592553571534): CITER=1M CFREQ=10

In [None]:
let idx = 3000:3004   # small window of the left-hand side HΣ + ΣH
    a[idx, idx]
end

In [None]:
let idx = 3000:3004   # same window of the right-hand side (2/LR)·D
    b[idx, idx]
end

In [None]:
let idx = 3000:3004   # elementwise ratio of the two sides in the same window
    (a ./ b)[idx, idx]
end

In [None]:
dump(summarystats(abs.(vec(a))))   # magnitude distribution of the entries of a

In [None]:
# Thresholded copies of a and b: entries with magnitude below 1e-7 are set to 0.
# `copy` is required. With `a0 = a`, a0 aliases a and the zeroing would also modify a (and b).
a0 = copy(a); a0[abs.(a0) .< 1e-7] .= 0
b0 = copy(b); b0[abs.(b0) .< 1e-7] .= 0

In [None]:
(norm(a0), norm(b0), norm(a0 - b0))   # equation check after dropping tiny entries

# Check convergence of covariance

In [None]:
# Covariance with 1/N normalization of the samples stored as the columns of W
# (one row per coordinate).
function sigma(W)
    centered = W .- mean(W; dims=2)
    return (centered * centered') / size(centered, 2)
end

In [None]:
# Covariance estimated from growing prefixes of the trajectory (1K ... 100K stored points)
sigmas = [sigma(W[:, 1:npts]) for npts in 1000 .* (1, 2, 5, 10, 20, 50, 100)];

In [None]:
# (norm of previous, norm of next, norm of their difference) for consecutive prefixes
for (prev, curr) in zip(sigmas, Iterators.drop(sigmas, 1))
    println((norm(prev), norm(curr), norm(curr - prev)))
end

In [None]:
# 1K,2K,...,10K sampling every 100 up to 1M convergence
# Recorded output pasted as tuple literals; evaluating this cell only displays the last one.
# Each row is (norm(Σ_prev), norm(Σ_next), norm(Σ_next - Σ_prev)) for consecutive prefix sizes.
(0.17517119848767196, 0.18735112177809712, 0.1736369823475756)
(0.18735112177809712, 0.18194022797202322, 0.13163000956991822)
(0.18194022797202322, 0.17754343048529778, 0.10970021736453625)
(0.17754343048529778, 0.16902882189691706, 0.08681261945879391)
(0.16902882189691706, 0.16118998631852746, 0.07373054738895173)
(0.16118998631852746, 0.15395757281329958, 0.06383283177104303)
(0.15395757281329958, 0.14673395708896458, 0.05427855979560519)
(0.14673395708896458, 0.14047011468580411, 0.04757881795972907)
(0.14047011468580411, 0.13584321137382668, 0.04594090264182607)

In [None]:
# 1K,2K,5K,10K,20K,50K,100K sampling every 10 up to 1M
# Recorded output pasted as tuple literals; evaluating this cell only displays the last one.
# Each row is (norm(Σ_prev), norm(Σ_next), norm(Σ_next - Σ_prev)) for consecutive prefix sizes.
(0.09419547449556358, 0.12122309088637423, 0.11268632188794234)
(0.12122309088637423, 0.15926002834426348, 0.15786197264298812)
(0.15926002834426348, 0.17998457741465182, 0.16622733785950075)
(0.17998457741465182, 0.18679671394160588, 0.1732145346986742)
(0.18679671394160588, 0.16400614726671942, 0.17491052732529933)
(0.16400614726671942, 0.1335983636197007, 0.1283291677336581)

In [None]:
# Same comparison restricted to the variances (diagonals)
for (prev, curr) in zip(sigmas, Iterators.drop(sigmas, 1))
    println((norm(diag(prev)), norm(diag(curr)), norm(diag(curr - prev))))
end

In [None]:
plot(map(s -> norm(diag(s)), sigmas))   # variance norm vs. prefix size

In [None]:
plot(sort!(diag(H)))   # Hessian diagonal, ascending (diag returns a fresh vector)

# Hessian (numeric check)

In [None]:
# Second-order expansion around the minimum:
#   f(w) ≈ f(wmin) + (w-wmin)' g + 1/2 (w-wmin)' H (w-wmin)
# The gradient at wmin is ≈ 0, so the first-order term can be assumed 0; check its norm.
df = @diff loss(wmin, atrn, ytrn)
J = vec(grad(df, wmin))
@show summary(J)
@show norm(J);

In [None]:
# Test the quadratic approximation at distance ~1 around wmin
# (entries of wrnd are N(0, 1/n), so norm(wrnd) ≈ 1).
# Adding the first-order term does not make much difference, as expected.
# The Hessian computed earlier is H, a CPU matrix (`hmin` was never defined). The quadratic
# forms are therefore evaluated on CPU copies, and the loss on the GPU arrays.
wrnd = randn!(similar(value(wmin)))
wrnd = wrnd / sqrt(length(wrnd))
wcpu = Array(wrnd)
lossw(w) = loss(w,atrn,ytrn)
@show lossw(wmin)
@show lossw(value(wmin) + wrnd)
@show lossw(wmin) + 0.5 * dot(wcpu, H * wcpu)
@show lossw(wmin) + dot(Array(J), wcpu) + 0.5 * dot(wcpu, H * wcpu)

# Check convergence of variance in SGD trajectory

In [None]:
(BATCHSIZE, LR, LAMBDA)   # settings used by the variance-tracking run below

In [None]:
# Track the running per-coordinate variance of SGD iterates starting at wmin.
# wvar(w) reads w after each sgd step. This relies on Param(w) sharing w's storage,
# so the in-place updates to pw are visible through w (assumed; confirm for AutoGrad).
w  = ARRAY(value(wmin))
w1 = w .+ 0          # running sum of iterates (a copy of w)
w2 = w .* w          # running sum of squared iterates
nw = 1               # number of iterates accumulated so far
pw = Param(w)
data = MB(xtrn,ytrn,BATCHSIZE)
f(x,y) = loss(pw,x,y;λ=LAMBDA)   # `λ`, not the LaTeX escape `\lambda` (a syntax error)

# Add the current iterate to the running sums and return the norm of the
# variance estimate E[w²] - E[w]².
function wvar(w)
    global nw, w1, w2
    nw += 1
    w1 .+= w
    w2 .+= w .* w
    v = w2 / nw - (w1 .* w1) / (nw * nw) # E[x^2] - E[x]^2
    return norm(v)
end

c = collect(progress(wvar(w) for wloss in sgd(f, take(data,500000), lr=LR)));

In [None]:
plot(c; xscale=:log10)   # running variance norm vs. SGD step

In [None]:
(norm(w), norm(w .* w), norm(w1), norm(w2))   # current iterate and the running sums

# Diffusion tensor convergence
(with LAMBDA=0.0001, LR=0.1)

In [None]:
#norm(d100),norm(d1000),norm(d100-d1000)         #(9.006901905269366e-5, 7.966566524654371e-5, 4.746245825247884e-5)
#norm(d1000),norm(d2000),norm(d1000-d2000)       #(7.966566524654371e-5, 7.976183314696048e-5, 1.8908596356951573e-5)
#norm(d4000),norm(d2000),norm(d4000-d2000)       #(8.024754500933944e-5, 7.976183314696048e-5, 1.3446098748867312e-5)
#norm(d10000),norm(d4000),norm(d10000-d4000)     #(7.942869188760434e-5, 8.024754500933944e-5, 8.963487531775732e-6)
#norm(d10000),norm(d20000),norm(d20000-d10000)   #(7.955107659141219e-5, 7.976461525637195e-5, 5.902290704618786e-6)
#norm(d20000),norm(d50000),norm(d50000-d20000)   #(7.976461525637195e-5, 7.971793474706985e-5, 3.962945418293533e-6)
#norm(d50000),norm(d100000),norm(d100000-d50000) #(7.971793474706985e-5, 7.978241325281978e-5, 2.739184868054573e-6)

# Scatter plots for trajectory

In [None]:
# Seems to converge to a slightly different point?
# Interesting patterns: staircase, globe, H shaped
@show r1, r2 = rand(1:size(W,1)), rand(1:size(W,1))
scatter(W[r1,:], W[r2,:])                                # whole trajectory
scatter!(W[r1,1:10], W[r2,1:10], mc=:red)                # first 10 points in red
scatter!(W[r1,end-9:end], W[r2,end-9:end], mc=:yellow)   # last 10 points in yellow

# Computing covariance with fit_mle fails

In [None]:
# b = fit_mle(MvNormal, Wstable)   (requires `using Distributions`) fails:
# Σ is not positive definite, and the MLE needs inverse(Σ), the distribution's Hessian.
# This is not the loss Hessian H computed earlier in this notebook.
@show sum(diag(Σ) .== 0) # 670 dims do not move at all!
@show sum(wmin .== 0) # these stay at 0 in the w matrix.
findall(diag(Σ) .== 0) == findall(Array(value(wmin)) .== 0) # to make sure they are the same

In [None]:
# NOTE(review): presumably releases memory held by Knet (e.g. cached GPU arrays); confirm
Knet.gc()