In [None]:
using Knet, Plots, Statistics, LinearAlgebra
# Set the COLUMNS environment variable (presumably to keep printed arrays narrow in the notebook output).
ENV["COLUMNS"]=40
# Array type used for the weights and inputs in the cells below: KnetArray with Float64 elements.
ARRAY = KnetArray{Float64}

In [None]:
# Load mnist data
# Run the loader script from Knet's data directory (presumably the one that defines `mnist()`).
include(Knet.dir("data/mnist.jl"))
# Judging by the names: training inputs/labels and test inputs/labels -- confirm against mnist.jl.
xtrn,ytrn,xtst,ytst = mnist()
# Print a one-line summary of each of the four values; the trailing `;` hides the cell's return value.
println.(summary.((xtrn,ytrn,xtst,ytst)));

In [None]:
# Minibatched reshaped version, shuffle=true corresponds to SGD-II
# Batch size 100; inputs are requested as ARRAY with size (784,:) so each batch is a 784-row matrix.
dtrn = minibatch(xtrn,ytrn,100;xtype=ARRAY,xsize=(784,:),shuffle=true)
# Take the first minibatch and print summaries to check the resulting types and sizes.
x1,y1 = first(dtrn)
println.(summary.((dtrn,x1,y1)));

In [None]:
# Loss of the linear model: Knet's `nll` applied to the scores `w*x` and the labels `y`.
function loss(w, x, y)
    scores = w * x
    return nll(scores, y)
end

In [None]:
# Train linear model 20 epochs, takes ~14secs
# 10x784 weight parameter; `loss` multiplies it with the 784-row input batches.
w = param(10,784,atype=ARRAY)
# Presumably releases unused Knet memory before the training run -- confirm in the Knet docs.
Knet.gc()
# Run adam over 20 repetitions of dtrn with a progress display and collect what each step
# yields (presumably the per-minibatch loss) into `losses`.
losses = collect(progress(adam((x,y)->loss(w,x,y), repeat(dtrn,20))));

In [None]:
# Train linear model without minibatching ~50 iters/sec
wmin = param(10,784,atype=ARRAY)
# The whole training set as one batch: inputs reshaped to 784 rows and converted to ARRAY, labels as-is.
x,y = ARRAY(reshape(xtrn,784,:)), ytrn
Knet.gc()
# 10000 adam steps on the same (x,y) pair. Note that this reassigns `losses`, replacing the
# values collected by the minibatch training cell.
losses = collect(progress(adam((x,y)->loss(wmin,x,y), repeat([(x,y)],10000))));

In [None]:
# Plot the values currently stored in `losses` (set by whichever training cell ran last).
plot(losses)

In [None]:
# Accumulate the scaled scatter of the per-batch gradients around their mean:
#     lr^2/(2n) * sum_k (m - g_k)*(m - g_k)'
# where g_k = grad(loss)(w,x,y) for the k-th (x,y) pair of `data`, m is the mean of
# those gradients and n is their count. Each deviation is flattened with `vec`, so the
# result is a square matrix of side length(w), created with the global ARRAY type.
function diffusiontensor(loss,w,data,lr=0.1) # lr=0.1 is default for sgd
    gradfn = grad(loss)
    # one gradient per (x,y) pair produced by `data`
    batchgrads = [ gradfn(w,xb,yb) for (xb,yb) in data ]
    nbatch = length(batchgrads)
    gmean = mean(batchgrads)
    scale = lr^2/(2*nbatch)
    dim = length(w)
    tensor = ARRAY(zeros(dim,dim))
    for gk in progress(batchgrads)
        dev = vec(gmean-gk)
        # tensor += scale * dev*dev', updated in place
        axpy!(scale,dev*dev',tensor)
    end
    return tensor
end

In [None]:
# compute diffusion tensor ~10 secs
Knet.gc()
# Evaluated at the full-batch weights `wmin` over the minibatches in `dtrn`, with the default lr=0.1.
dtmin = diffusiontensor(loss,wmin,dtrn);

In [None]:
# Compare with results from the old per-instance calc
# Load the values stored under the keys "dt" and "w" in dt01.jld2.
dt1,w1 = Knet.load("dt01.jld2","dt","w")
# Convert the current tensor and the current weight values to regular Arrays for comparison.
dt3,w3 = Array(dtmin),Array(wmin.value);
# Print sizes/types, value ranges and norms of the old (1) and new (3) results side by side.
@show summary.((dt1,w1,dt3,w3))
@show extrema(w1), norm(w1), extrema(w3), norm(w3)
@show extrema(dt1), norm(dt1), extrema(dt3), norm(dt3)
# Loose agreement check with a 30% relative tolerance.
@show isapprox(w1,w3,rtol=0.3), isapprox(dt1,dt3,rtol=0.3);

In [None]:
# Build a matrix of second derivatives of `loss(w,x,y)` one column at a time:
# column k is the flattened result of applying `grad` to the function that returns
# the k-th entry of grad(loss)(w,x,y). The output is allocated with `similar(w,n,n)`
# where n = length(w), and a progress display is shown over the n columns.
function hessian(loss,w,x,y)
    gradfn = grad(loss)
    # scalar-valued helper: the k-th entry of the gradient
    gradentry(w,x,y,k) = gradfn(w,x,y)[k]
    gradofentry = grad(gradentry)
    dim = length(w)
    hess = similar(w,dim,dim)
    for col in progress(1:dim)
        hess[:,col] .= vec(gradofentry(w,x,y,col))
    end
    return hess
end

In [None]:
# Compute hessian: ~6 mins
# NOTE: the condition is hard-coded to `true`, so the hessian is always recomputed and
# re-saved; the commented-out `!isfile(...)` test would instead reuse an existing hess03.jld2.
if true # !isfile("hess03.jld2")
    # Full training set as a single batch, as in the full-batch training cell.
    x,y = ARRAY(reshape(xtrn,784,:)), ytrn
    Knet.gc()
    hmin = hessian(loss,wmin,x,y)
    # Store a regular-Array copy under the key "h".
    Knet.save("hess03.jld2","h",Array(hmin))
else
    # Unreachable while the condition is `true`: load the saved hessian and convert it to ARRAY.
    hmin = ARRAY(Knet.load("hess03.jld2","h"))
end
# Cell output: summaries of the diffusion tensor and the hessian.
summary.((dtmin,hmin))

In [None]:
# Copy the hessian into a regular Array.
H = Array(hmin)
# Check whether H is approximately equal to its transpose (20% relative tolerance).
@show isapprox(H,H',rtol=0.2)
# Cell output: a 5x5 block of H as a spot check.
H[4001:4005,4001:4005]

In [None]:
# Timed eigendecomposition of the symmetrized hessian (H+H')/2.
@time F = eigen(Symmetric((H+H')/2))
# Plot the eigenvalues.
plot(F.values)

In [None]:
# Fractions of the eigenvalues that are negative, exactly zero, and positive.
@show mean(F.values .< 0)
@show mean(F.values .== 0)
@show mean(F.values .> 0)

In [None]:
# Last 11 entries of the eigenvalue vector (the largest ones if F.values is sorted ascending -- confirm).
F.values[end-10:end]

In [None]:
# This calculation is from Michael's overleaf notes:
# https://www.overleaf.com/2523873322bvvnxpwnskfk
#
# Solve HH*C + C*HH = (2/lr)*D for C, where HH = (H+H')/2 is the symmetrized H.
# With the eigendecomposition HH = O*diag(h)*O', the solution in the eigenbasis is
#     (O'*C*O)[i,j] = (2/lr) * (O'*D*O)[i,j] / (h[i] + h[j]).
# The equation is satisfied exactly only when no eigenvalue had to be raised to `eigfloor`.
#
# Arguments:
#   D        - square matrix (the diffusion tensor in this notebook)
#   H        - square matrix of the same size (the hessian); only its symmetric part is used
# Keywords:
#   lr       - value entering the (2/lr) prefactor
#   eigfloor - eigenvalues below this are raised to it, keeping every h[i]+h[j] positive
#   nthreads - BLAS thread count; note this changes the process-wide BLAS setting
# Returns the matrix C. The timings of the two expensive steps are printed via @time.
function covariancematrix(D,H;lr=0.1,eigfloor=1e-8,nthreads=20)
    LinearAlgebra.BLAS.set_num_threads(nthreads)
    # H is not exactly symmetric and eigen(H) would give complex values, so symmetrize first.
    @time F = eigen(Symmetric((H+H')/2))
    h = copy(F.values)
    h[h.<eigfloor] .= eigfloor
    O = F.vectors
    Nw = length(h)
    @time ODO = O'*D*O;
    # Delta[i,j] = ODO[i,j]/(h[i]+h[j]); a single fused broadcast written in place, which
    # traverses the matrices in memory (column-major) order instead of row by row.
    Delta = zeros(Nw,Nw)
    Delta .= ODO ./ (h .+ h')
    return (2/lr)*O*Delta*O'
end

In [None]:
# Convert the diffusion tensor and the hessian to regular Arrays, timing each conversion.
@time D = Array(dtmin); # ~ 0.3s
@time H = Array(hmin);  # ~ 0.3s

In [None]:
# Covariance matrix from D and H with the default lr=0.1; the trailing `;` hides the result.
C = covariancematrix(D,H);

In [None]:
# Check that C satisfies HH*C + C*HH = (2/lr)*D, with HH the symmetrized hessian.
# `lr` only exists as a keyword argument inside covariancematrix, so it has to be defined
# here; 0.1 is the default that was in effect when C was computed via covariancematrix(D,H).
lr = 0.1
HH = (H+H')/2
z1 = HH*C + C*HH
z2 = (2/lr)*D
# Cell output: whether both sides agree within a 20% relative tolerance.
isapprox(z1,z2,rtol=0.2)

In [None]:
# Spot check: a 5x5 block of the left-hand side HH*C + C*HH.
z1[3001:3005,3001:3005]

In [None]:
# The same 5x5 block of the right-hand side (2/lr)*D.
z2[3001:3005,3001:3005]

In [None]:
# The same 5x5 block of D itself, for reference.
D[3001:3005,3001:3005]