In [None]:
using Knet, Plots
# using IterTools, Base.Iterators # take, cycle, takenth
# using StatsBase, Statistics # mean, describe
# Set the COLUMNS environment variable (presumably to keep printed array output narrow
# in the notebook — confirm).
ENV["COLUMNS"]=40
# Array type used throughout this notebook for model weights and input data.
ARRAY = KnetArray{Float64}

In [None]:
# Bring in the MNIST loader script that ships with Knet, then fetch the data:
# training inputs/labels (xtrn, ytrn) and test inputs/labels (xtst, ytst).
include(Knet.dir("data/mnist.jl"))
xtrn, ytrn, xtst, ytst = mnist()
# Print a one-line summary of each of the four arrays.
foreach(arr -> println(summary(arr)), (xtrn, ytrn, xtst, ytst));

In [None]:
# Shuffled minibatch iterator over the training set, 100 examples per batch.
# xtype/xsize ask for inputs converted to ARRAY and reshaped to 784 rows.
dtrn = minibatch(xtrn, ytrn, 100; xtype=ARRAY, xsize=(784, :), shuffle=true)
# Pull out the first minibatch so its input/label shapes can be inspected.
(x1, y1) = first(dtrn)
foreach(obj -> println(summary(obj)), (dtrn, x1, y1));

In [None]:
# Train linear model until convergence
# Linear model holding a single weight matrix `w`.
# The field type is a struct parameter so that `f.w` is concretely typed; an untyped
# field would be `Any` and make every use of `f.w` dynamically dispatched.
# Construction is unchanged: `Lin(w)` infers the parameter from its argument.
struct Lin{T}
    w::T
end

# Prediction: multiply the weight matrix by the input `x`.
(f::Lin)(x) = f.w * x

# Loss: `nll` of the prediction for `x` against the targets `y`.
(f::Lin)(x, y) = nll(f(x), y)
# Linear model with a 10x784 parameter matrix stored as ARRAY.
f = Lin(param(10, 784; atype=ARRAY))
# Run sgd on `f` over repeat(dtrn, 20), wrapped in a progress display, and collect
# every value the sgd iterator yields (presumably the per-minibatch loss — confirm).
losses = sgd(f, repeat(dtrn, 20)) |> progress |> collect;

In [None]:
# Plot the values collected from the sgd iterator during training.
plot(losses)

In [None]:
"""
    diffusiontensor(f, xt, yt, Nb, lr)

Estimate the diffusion tensor of minibatch SGD for the model `f` at its current
weights `f.w`.

For every example `i` the gradient `g_i` of `f(xt[:,i:i], yt[i:i])` with respect to
`f.w` is computed and the outer products `g_i * g_i'` are accumulated. With `dL` the
gradient of `f(xt, yt)` evaluated on the whole data set and `Nt = length(yt)`, the
returned matrix is

    (lr^2 / 2) * (Nt - Nb) / (Nb * (Nt - 1)) * (sum_i(g_i * g_i') / Nt - dL * dL')

NOTE(review): reading the bracket as a gradient covariance assumes `f(xt, yt)` returns
the loss averaged over examples, so that `dL` is the mean of the `g_i` — confirm.

# Arguments
- `f`: callable model with a field `w` that `grad` can differentiate with respect to;
  `f(x, y)` is the function being differentiated.
- `xt`: inputs, indexed as a matrix with one example per column.
- `yt`: targets, one entry per example.
- `Nb`: minibatch size; must be positive.
- `lr`: learning rate.

# Returns
A square matrix with `length(f.w)` rows and columns. The accumulator is allocated
with the global `ARRAY` type.

# Throws
- `ArgumentError` if `length(yt) < 2` or `Nb <= 0`, because the prefactor would
  otherwise divide by zero and silently produce `NaN`/`Inf`.
"""
function diffusiontensor(f,xt,yt,Nb,lr)
    w = f.w
    Nw = length(w)  # number of weights, that is, dimensions of the diffusion tensor
    Nt = length(yt) # number of training examples to be summed over
    Nt > 1 || throw(ArgumentError("need at least 2 training examples, got $Nt"))
    Nb > 0 || throw(ArgumentError("minibatch size Nb must be positive, got $Nb"))
    prefac = (lr^2 / 2) * ((Nt-Nb) / (Nb*(Nt-1)))
    # Only the Nw x Nw sum of outer products is kept, not the Nw x Nt gradient matrix.
    vvt = ARRAY(undef,Nw,Nw)
    vvt .= 0
    # Add the outer product of example i's gradient to vvt (in place).
    # The locals are declared `local` so they can never alias a variable of the
    # enclosing function: a name assigned both here and in the outer body would be
    # one shared, boxed variable.
    function addouter!(i)
        local tape = @diff f(xt[:,i:i],yt[i:i])
        local g = vec(grad(tape,w))
        vvt .+= g * g'
    end
    progress!(addouter!(i) for i in 1:Nt)
    # Gradient of f evaluated on the whole data set at once.
    fulltape = @diff f(xt,yt)
    dL = vec(grad(fulltape,w))
    return prefac * (vvt/Nt - dL*dL')
end

In [None]:
# Whole training set reshaped to 784 rows and converted to ARRAY.
xt = reshape(xtrn, 784, :) |> ARRAY
# Training labels, used as-is.
yt = ytrn
# Minibatch size; same value as the batch size passed to minibatch() above.
Nb = 100
lr = Knet.SGD().lr  # the default learning rate = 0.1
# Ask Knet to run its garbage collection before the heavy computation that follows
# (presumably to release unused GPU memory — confirm).
Knet.gc()

In [None]:
# Diffusion tensor of the trained model over the full training set.
dt = diffusiontensor(f, xt, yt, Nb, lr)
# Sanity check: the tensor should equal its own adjoint (exact comparison).
dt == adjoint(dt)

In [None]:
# Copy the diffusion tensor into a regular Julia Array.
dtcpu = Array(dt)
# Smallest and largest entries; the trailing comment looks like the output recorded
# from an earlier run — values may differ between runs.
extrema(dtcpu) # (-2.1491637463846156e-7, 5.306538667800477e-7)

In [None]:
# Save the CPU copy of the diffusion tensor under the key "dt" and the raw weight
# array (the `value` of the parameter f.w) under the key "w" to dt01.jld2.
Knet.save("dt01.jld2","dt",dtcpu,"w",f.w.value)