# MLP model on MNIST

In [None]:
using Knet, Plots, Statistics, LinearAlgebra, Base.Iterators, Random, StatsBase, JLD
# Experiment configuration.
# Several of these values are interpolated into the cache file names used further
# down (wmin, diffusion tensor, trajectory), so changing one of them makes the
# corresponding result be recomputed instead of loaded from disk.

# Sets the COLUMNS environment variable (presumably caps the width of printed output — confirm)
ENV["COLUMNS"] = 80
# Array type used for the device copies of the data and for the minibatches
ARRAY = KnetArray{Float64}
XSIZE=784   # input dimension
HIDDENSIZE=32 # hidden layer dimension
YSIZE=10    # output dimension
BATCHSIZE=10 # minibatch size
LAMBDA=1e-2 # regularization parameter (default λ of `loss`)
#LAMBDA=1e-10
#LR=1e-1     # learning rate
LR=1e-2     # learning rate
#MITER=10^4  # iterations for finding minimum
MITER=10^5  # iterations for finding minimum
DITER=10^5  # iterations for diffusion tensor
CITER=10^7  # iterations for covariance trajectory
CFREQ=10^2  # keep every CFREQ points on trajectory

# Define regularized MLP model (one ReLU hidden layer) with softmax loss

In [None]:
# One-hidden-layer MLP evaluated on a flat weight vector `w`, laid out as
#   [ W1 (HIDDENSIZE×XSIZE) ; b1 (HIDDENSIZE) ; W2 (YSIZE×HIDDENSIZE) ].
# `x` is reshaped to XSIZE×N; the result is the YSIZE×N matrix of class scores.
function pred(w, x)
    nw1 = HIDDENSIZE * XSIZE          # number of entries in W1
    nb1 = HIDDENSIZE * (XSIZE + 1)    # index of the last entry of b1
    W1 = reshape(w[1:nw1], HIDDENSIZE, XSIZE)
    b1 = w[nw1+1:nb1]
    W2 = reshape(w[nb1+1:end], YSIZE, HIDDENSIZE)
    hidden = max.(0, W1 * reshape(x, XSIZE, :) .+ b1)   # ReLU hidden layer
    return W2 * hidden
end

# Regularized objective: `nll` of the predictions plus (λ/2)·‖w‖².
function loss(w, x, y; λ=LAMBDA)
    datafit = nll(pred(w, x), y)
    penalty = (λ/2) * sum(abs2, w)
    return datafit + penalty
end

# Load MNIST data

In [None]:
include(Knet.dir("data/mnist.jl"))
xtrn,ytrn,xtst,ytst = mnist()
atrn,atst = ARRAY(xtrn), ARRAY(xtst) # GPU copies for batch training
println.(summary.((xtrn,ytrn,xtst,ytst,atrn,atst)));

# Minibatch data

In [None]:
# Endless minibatch iterator for SGD-I, i.e. sampling WITH replacement.
# Knet.minibatch can't do this, so a dedicated struct is defined.
struct MB
    x   # inputs; flattened with `mat` and indexed by column when sampling
    y   # labels, indexed with the same random indices as the columns of `mat(x)`
    n   # number of instances drawn per minibatch
end

# Declare the iterator as infinite; need this for collect to work.
Base.Iterators.IteratorSize(::Type{MB}) = Base.IsInfinite()

# Every step draws `n` random indices and returns the (inputs, labels) pair,
# with the inputs converted to ARRAY. The iteration state is a dummy `true`.
function Base.iterate(d::MB, s...)
    picks = rand(1:length(d.y), d.n)
    batch = (ARRAY(mat(d.x)[:, picks]), d.y[picks])
    return (batch, true)
end

dtrn = MB(xtrn, ytrn, BATCHSIZE)
println.(summary.(first(dtrn)));

# Find minimum

In [None]:
LAMBDA,MITER,BATCHSIZE

In [None]:
# Find minimum without minibatching
# ~50 iters/sec, converges in 3 mins to 
# 0.267218 for LAMBDA=1e-4 (err=0.25, reg=0.02)
# 0.344490 for LAMBDA=1e-3 (err=0.29, reg=0.05)
# 0.558482 for LAMBDA=1e-2 (err=0.41, reg=0.15)
wminfile = "MLP-wmin-$LAMBDA-$MITER.jld2"
if !isfile(wminfile)
    wmin = Param(ARRAY(LAMBDA*rand((XSIZE+1)*HIDDENSIZE + HIDDENSIZE*YSIZE)))
    args = repeat([(wmin,atrn,ytrn)],MITER)
    Knet.gc()
    losses = collect(progress(adam(loss,args)))
    Knet.save(wminfile, "wmin", wmin, "losses", losses)
else
    wmin, losses = Knet.load(wminfile, "wmin", "losses");
end
@show summary(wmin)
losses[end-4:end]'

In [None]:
println.((
(loss(wmin,atrn,ytrn),nll(pred(wmin,atrn),ytrn),(LAMBDA/2)*sum(abs2,wmin)),
(loss(wmin,atst,ytst),nll(pred(wmin,atst),ytst),(LAMBDA/2)*sum(abs2,wmin)),
(accuracy(pred(wmin,atrn),ytrn),accuracy(pred(wmin,atst),ytst))));

In [None]:
∇loss = grad(loss)
∇lossi(w,x,y,i) = ∇loss(w,x,y)[i]
∇∇lossi = grad(∇lossi)
w = value(wmin)
n = length(w)
∇lossi(w,atrn,ytrn,1)

# Hessian of loss around minimum

In [None]:
"""
    hessian(loss, w, x, y)

Return the dense `n×n` matrix of second derivatives of `loss(w, x, y)` with
respect to `w`, where `n = length(w)`.

The matrix is built one column at a time: column `i` is the gradient (with
respect to `w`) of the `i`-th component of `∇loss(w, x, y)`. `w` is unwrapped
with `value` first, so a `Param` may be passed. The result is a host `Array`.
"""
function hessian(loss,w,x,y)
    ∇loss = grad(loss)
    ∇lossi(w,x,y,i) = ∇loss(w,x,y)[i]
    ∇∇lossi = grad(∇lossi)
    w = value(w)
    n = length(w)
    h = zeros(n, n)
    for i in progress(1:n)
        # `h` lives in host memory, so copy each column with `Array` (as
        # diffusiontensor does) instead of converting it to the device type.
        h[:,i] .= Array(vec(∇∇lossi(w,x,y,i)))
    end
    return h
end

In [None]:
# Compute hessian: ~5 mins
hessfile = "MLP-hess-$LAMBDA.jld"
if !isfile(hessfile)
    hmin = hessian(loss,wmin,atrn,ytrn)
    save(hessfile,"h",hmin)
else
    hmin = load(hessfile,"h")
end
println.((summary(hmin),extrema(Array(hmin)),norm(hmin),norm(hmin-hmin')));

# Eigenvalues of the Hessian

In [None]:
heigfile = "MLP-nosym_heig-$LAMBDA.jld2"
H = Array(hmin)
if !isfile(heigfile)
#    @time eigenH = eigen(Symmetric(H)) # ~53s
    @time eigenH = eigen(H) # ~53s
    Knet.save(heigfile,"eigenH",eigenH)
else
    eigenH = Knet.load(heigfile,"eigenH")
end
eigenH.values'

In [None]:
summarystats(real.(eigenH.values)) |> dump
scatter(real.(eigenH.values),xaxis=:log10,yaxis=:log10)

# Diffusion Tensor

In [None]:
"""
    diffusiontensor(loss, w, x, y; iters=DITER, lr=LR, batchsize=BATCHSIZE)

Estimate the diffusion tensor of SGD at `w`.

With `g0` the gradient of `loss` on the full data `(x, y)` and `g_i` the gradient
on the `i`-th random minibatch (drawn with replacement by `MB`), returns

    lr^2 / (2 * iters) * Σ_i (g0 - g_i) * (g0 - g_i)'

as a `length(w) × length(w)` host matrix.
"""
function diffusiontensor(loss,w,x,y;iters=DITER,lr=LR,batchsize=BATCHSIZE)
    ∇loss = grad(loss)
    grad0 = Array(∇loss(w, ARRAY(x), y))
    data = MB(x,y,batchsize)
    grads = zeros(length(w), iters)
    for (i,d) in progress(enumerate(take(data,iters)))
        # Store the deviation from the full gradient directly; this avoids a
        # second length(w)×iters temporary for `grad0 .- grads` afterwards.
        grads[:,i] .= grad0 .- Array(∇loss(w,d...))
    end
    prefac = (lr^2)/(2iters)
    @time v = prefac * (grads * grads')
    return v
end

In [None]:
LAMBDA,LR,BATCHSIZE,DITER

In [None]:
dtfile = "MLP-dt-$LAMBDA-$LR-$BATCHSIZE-$DITER.jld2"
if !isfile(dtfile)
    Knet.gc()
    D = diffusiontensor(loss,wmin,xtrn,ytrn) # ~700 iters/sec
    Knet.save(dtfile,"D",D)
else
    D = Knet.load(dtfile,"D")
end
summarystats(vec(D)) |> dump

# Record trajectory with SGD starting at minimum

In [None]:
LAMBDA,LR,BATCHSIZE,CITER,CFREQ

In [None]:
# Trajectory of w starting from wmin under plain SGD with minibatches drawn by MB.
# One column of W is stored every CFREQ updates (CITER updates in total);
# the result is cached in `trajfile` and reloaded on later runs.
# ~1000 updates/sec
trajfile = "MLP-traj-$LAMBDA-$LR-$BATCHSIZE-$CITER-$CFREQ.jld2"
if !isfile(trajfile)
    w = Param(ARRAY(value(wmin)))
    data = MB(xtrn,ytrn,BATCHSIZE)
    d = take(data,CITER)
    W = zeros(eltype(w),length(w),div(CITER,CFREQ))
    f(x,y) = loss(w,x,y)
    Knet.gc()
    i = 0
    for t in progress(sgd(f,d; lr=LR))
        # Quotient/remainder get their own names so that Base's `div` and `rem`
        # (`div` is called just above) are neither shadowed nor reassigned.
        i += 1; (col,left) = divrem(i,CFREQ)
        if left == 0
            W[:,col] = Array(vec(w))
        end
    end
    Knet.save(trajfile,"W",W)
else
    W = Knet.load(trajfile,"W")
end
summary(W)

In [None]:
# Plot losses on whole dataset, first steps seem transient, ~10 secs
r = 1:100:size(W,2)
@time plot(r, [loss(ARRAY(W[:,i]),atrn,ytrn) for i in r])

In [None]:
rr1,rr2 = rand(1:size(W,1)),rand(1:size(W,1))

In [None]:
# Plot trajectory of two random dimensions
@show rr1,rr2 = rand(1:size(W,1)),rand(1:size(W,1))
if norm(W[rr1,:]) > 0 && norm(W[rr2,:]) > 0
    histogram2d(W[rr1,:],W[rr2,:],background_color="black")
end

In [None]:
# Minibatch training seems to converge to a slightly worse spot
w0 = Array(value(wmin))
μ = mean(W[:,2500:end],dims=2)
w1 = W[:,end]
@show norm(w0), norm(μ), norm(w0 - μ)
@show extrema(w0), extrema(μ), extrema(w0 - μ)
@show mean(abs.(w0 - μ) .> 0.01)
@show loss(w0,xtrn,ytrn)
@show loss(μ,xtrn,ytrn)
@show loss(w1,xtrn,ytrn)

# Covariance of SGD trajectory around minimum

In [None]:
#Wstable = W[:,2500:end];  @show summary(Wstable)
Wstable = W
μ = mean(Wstable,dims=2); @show summary(μ)
Wzero = Wstable .- μ;     @show summary(Wzero)
Σ = (Wzero * Wzero') / size(Wzero,2)
@show summary(Σ)
@show norm(Σ)
@show extrema(Σ)
@show norm(diag(Σ));

In [None]:
# check for convergence
n2 = div(size(W,2),2)
w1 = W[:,1:n2]
w2 = W[:,1+n2:end]
w1 = w1 .- mean(w1,dims=2)
w2 = w2 .- mean(w2,dims=2)
Σ1 = (w1 * w1') / size(w1,2)
Σ2 = (w2 * w2') / size(w2,2);

In [None]:
# The variances (diagonal elements) converge
norm(diag(Σ1)),norm(diag(Σ2)),norm(diag(Σ1)-diag(Σ2))

In [None]:
# The off diagonal elements are still not there
norm(Σ1),norm(Σ2),norm(Σ1-Σ2)

# Check Einstein relation

In [None]:
summary.((H,D,Σ))

In [None]:
a = H*Σ + Σ*H
b = (2/LR)*D
norm(a),norm(b),norm(a-b)

## Covariance eigenspace

In [None]:
# 46 sec
@time eigenΣ = eigen(Σ);

In [None]:
 ΛΣ = eigenΣ.values; O = eigenΣ.vectors;

In [None]:
# check that the eigenvectors/values are OK
# norm(ΛΣ[end]*O[:,end]),norm(Σ*O[:,end]-ΛΣ[end]*O[:,end])

In [None]:
# transform the trajectory to the eigenbasis
Weig = O'*W;

In [None]:
## check Einstein relation in top Neig eigen-directions of Σ
Neig=100
Or = O[:,end-Neig+1:end];
aa = Or'*a*Or
bb = Or'*b*Or
norm(aa),norm(bb)/norm(aa),norm(aa-bb)/norm(aa)

In [None]:
aa[1:10,1:10]

In [None]:
bb[1:10,1:10]

In [None]:
# pick two eigen directions
Nweights = size(W,1)
Xid = Nweights
Yid = Nweights-1

#O = eigvecs(Σ);

#W_ss = W*O; # sample weights are row vectors
Wx = Weig[Xid,:]
Wy = Weig[Yid,:]

#COV_ss = O'*Σ*O
#COV_xy_inv = inv(COV_ss[[Xid,Yid],[Xid,Yid]])
COV_xy_inv = Diagonal([1/ΛΣ[Xid],1/ΛΣ[Yid]]) + zeros(2,2)
μeig = O'*μ
W0eig = O'*w0;

In [None]:
myhplot = histogram2d(Wx,Wy
    ,bins=100
    ,aspect_ratio=1
    ,background_color="black"
);

In [None]:
display(myhplot)

In [None]:
# Construct a grid enclosing the steady-state trajectory

# Spread (max - min) of a collection.
minmaxdiff(t) = maximum(t)-minimum(t)

"""
    makegrid(xvec, yvec, mean, xindex, yindex; Nx=50, Ny=50, zoom=0.35)

Build a rectangular grid centred on `(mean[xindex], mean[yindex])`.

The half-width along each axis is `zoom` times the spread of `xvec` / `yvec`,
divided into `Nx` / `Ny` steps on each side of the centre.

Returns `(x, y, Imask, xmask, ymask)`:
- `x`, `y`: grid coordinates along the two axes;
- `xmask`, `ymask`: unit vectors of length `length(mean)` selecting `xindex` / `yindex`;
- `Imask`: the identity matrix with the `xindex` and `yindex` diagonal entries zeroed.
"""
function makegrid(xvec,yvec,mean,xindex,yindex;Nx=50,Ny=50,zoom=0.35)
    Lx,Ly = minmaxdiff(xvec),minmaxdiff(yvec)
    xrange,yrange = zoom*Lx,zoom*Ly
    dx,dy = xrange/Nx,yrange/Ny
    x = collect(-xrange:dx:xrange) .+ mean[xindex]
    y = collect(-yrange:dy:yrange) .+ mean[yindex]

    # Masks for calculating the weights corresponding to grid points.
    # The dimension comes from `mean` itself, and the identity is built from a
    # vector of ones: no dense n×n temporary and no reliance on a global.
    n = length(mean)
    Identity = Diagonal(ones(n)) # unit matrix
    xmask = Identity[:,xindex]
    ymask = Identity[:,yindex]
    Imask = Identity - xmask*xmask' - ymask*ymask' # set two diagonal elements to zero
    return (x,y,Imask,xmask,ymask)
end

In [None]:
(x,y,Imask,xmask,ymask) = makegrid(Wx,Wy,W0eig,Xid,Yid;Nx=12,Ny=12,zoom=0.5)
#(x,y,Imask,xmask,ymask) = makegrid(Wx,Wy,μeig,Xid,Yid;Nx=4,Ny=4);

In [None]:
# Centre of the fitted Gaussian: the two selected components of W0eig, as a 1×2 row.
meanXY = W0eig[[Xid Yid]]

# Contours of the fit mv-Gaussian
# Negative quadratic form of the fit at the point (s, t).
function ffit(s, t)
    offset = [s t] - meanXY
    quad = offset * COV_xy_inv * offset'
    return -quad[1]
end

# Same surface, rescaled so that the grid corner (x[end], y[end]) maps to 250.
Ffit(s, t) = 250 * ffit(s, t) / ffit(x[end], y[end])

# Overlay the dashed contours on the current plot.
contour!(x, y, Ffit; linestyle=:dash, levels=10, linewidth=2)

In [None]:
# contours of loss
# Centre index of each grid axis. The axes built by makegrid are symmetric
# (-r:dr:r shifted by the centre), so on a 2N+1 point axis the centre is index
# N+1. `cld` gives that index and, unlike Int((len-1)/2), cannot throw on an
# even length.
midx = cld(length(x), 2)
midy = cld(length(y), 2)
# Reference loss at w0; evaluated once here instead of at every grid point.
loss0 = loss(w0,xtrn,ytrn)
# Scaled excess loss at the point whose coordinates along the two selected
# eigen-directions are (s, t); all other coordinates are taken from μeig.
#fexp(s,t) = 1e5*(loss(O*(Imask*W0eig + s*xmask + t*ymask),xtrn,ytrn) - loss0)
fexp(s,t) = 1e5*(loss(O*(Imask*μeig + s*xmask + t*ymask),xtrn,ytrn) - loss0)
# Log-scaled excess loss, shifted so that it is zero at the grid centre.
logfexpmidp1 = log(1+fexp(x[midx],y[midy]))
Flossxy(s,t) = 1e2*(log((1+fexp(s,t))) - logfexpmidp1)

In [None]:
Flossxy(0,0)

## N,N-1

In [None]:
@time contour!(x,y,Flossxy
#contour!(x,y,fexp
    ,levels=10
    ,linewidth=2
)