In [None]:
using Knet, Plots, Statistics, LinearAlgebra
# using IterTools, Base.Iterators # take, cycle, takenth
# using StatsBase, Statistics # mean, describe
# Set the COLUMNS environment variable to 40 (presumably to keep printed array
# output narrow in the notebook — confirm).
ENV["COLUMNS"] = "40"
# Array type used below for the minibatched inputs, the weights and the tensor.
ARRAY = KnetArray{Float64}

In [None]:
# Load the MNIST data using the loader script located via Knet.dir
include(Knet.dir("data/mnist.jl"))
xtrn, ytrn, xtst, ytst = mnist()
# Print a one-line summary of each of the four arrays
foreach(println ∘ summary, (xtrn, ytrn, xtst, ytst))

In [None]:
# Training data in shuffled minibatches of 100, with inputs reshaped to 784 rows
# and converted to ARRAY
dtrn = minibatch(xtrn, ytrn, 100; xtype=ARRAY, xsize=(784,:), shuffle=true)
# Peek at the first minibatch and print summaries of the iterator and its parts
x1, y1 = first(dtrn)
foreach(println ∘ summary, (dtrn, x1, y1))

In [None]:
# Train a linear model with SGD for 20 repetitions of the training data
# (assumed to be enough for convergence), collecting the values yielded by sgd.
w = param(10, 784, atype=ARRAY)

# Linear prediction: weights times inputs
function pred(w, x)
    return w * x
end

# nll of the linear predictions against the labels y
function loss(w, x, y)
    return nll(pred(w, x), y)
end

losses = collect(progress(sgd(repeat(dtrn, 20)) do x, y
    loss(w, x, y)
end));

In [None]:
# Plot the values collected from the sgd iterator in the training cell
plot(losses)

In [None]:
# Gradient of loss(w,x,y) with respect to w, obtained by recording the loss
# computation with @diff and querying the result with grad.
function ∇loss(w, x, y)
    tape = @diff loss(w, x, y)
    return grad(tape, w)
end

# diffusiontensor(w, data, lr=0.1)
#
# Computes (lr^2/2) * Σ (m-g)*(m-g)', where g ranges over the gradients of the
# loss with respect to w for each (x,y) minibatch in data and m is the mean of
# those gradients. The sum is NOT divided by the number of minibatches.
# The accumulator is a length(w) × length(w) matrix created as an ARRAY.
function diffusiontensor(w, data, lr=0.1)
    nparams = length(w)
    batchgrads = [ ∇loss(w, x, y) for (x, y) in data ]
    meangrad = mean(batchgrads)
    scatter = ARRAY(zeros(nparams, nparams))
    # Accumulate the outer products with axpy! rather than a broadcasted
    # `scatter .+= dev*dev'`; the original author noted axpy! is faster.
    for g in progress(batchgrads)
        dev = vec(meangrad - g)
        axpy!(1, dev * dev', scatter)
    end
    return (lr^2/2) * scatter
end

In [None]:
# Compute the diffusion tensor over the training minibatches and show its
# smallest and largest entries (copied to a regular Array first).
# Values recorded here earlier, presumably extrema from previous runs:
# (-0.00012176828791668883, 0.0002895118920005772)
# (-0.00010208279622571245, 0.000302902658497447)
dt = diffusiontensor(w,dtrn)
extrema(Array(dt)) 

In [None]:
# Load the reference tensor ("dt") and weights ("w") saved in dt01.jld2
dt1,w1 = Knet.load("dt01.jld2","dt","w") # Results from the per-instance calc

In [None]:
# Copy this run's tensor and weights into regular Arrays
dt2 = Array(dt)
w2 = Array(w.value)
# Compare weights directly, and the tensors after scaling each by its norm,
# both with a relative tolerance of 0.3
@show isapprox(w1,w2,rtol=0.3)
@show isapprox(dt1/norm(dt1),dt2/norm(dt2),rtol=0.3)

In [None]:
# Show the norms of the reference tensor and of this run's tensor side by side
norm(dt1), norm(dt2)

In [None]:
# Save this run's tensor and weights to dt02.jld2 under the keys "dt" and "w"
Knet.save("dt02.jld2","dt",dt2,"w",w2)