In [None]:
using Pkg,Statistics,Random,Printf,GZip,Knet,Plots,LinearAlgebra,Distributions #,Interact,StatsBase

## Generate samples from a noisy Gaussian

In [None]:
# ENV["GRDIR"]=""
# Pkg.build("GR")

In [None]:
Random.seed!(4)
Range = 3.0       # half-width of the x interval on which the target Gaussian is sampled
Incr = 0.03       # spacing of the x samples (sets how many samples we learn from)
Noise_std = 0.1;  # standard deviation of the additive noise on the Gaussian

In [None]:
# generate the data from which we'll learn the Gaussian function
# (all arguments are keyword arguments, i.e. they come after ";")
"""
    gen_noisy_gaussian(; range=1.0, noise=0.1, incr=Incr)

Sample the target `exp(-x^2)` on the grid `-range:incr:range` and add i.i.d.
Gaussian noise with standard deviation `noise`, drawn from the global RNG.
`incr` defaults to the global `Incr` at call time.

Returns the tuple `(x, y)` of two equally long vectors.
"""
function gen_noisy_gaussian(; range=1.0, noise=0.1, incr=Incr)
    # Use the keyword argument (it shadows Base.range in this body); the previous
    # version read the global `Range` here and silently ignored `range`.
    x = collect(-range:incr:range)
    y = exp.(-x.^2) + randn(length(x))*noise # additive gaussian noise
    return (x, y)
end
# output is two vectors x,y

In [None]:
# Draw the noisy training set on the grid -Range:Incr:Range (advances the global RNG seeded above).
(x_train,y_train) = gen_noisy_gaussian(range=Range,noise=Noise_std);
# Drop the last sample (x ≈ +Range) from both vectors. With the default Range/Incr this leaves
# 200 points instead of 201 — presumably so Ntrain is a multiple of Batchsize (TODO confirm).
pop!(x_train);pop!(y_train);

In [None]:
Ntrain =length(x_train) # number of training data points

In [None]:
# Noisy training samples together with the exact target exp(-x^2) at the same x values.
plot(x_train,[y_train,exp.(-x_train.^2)])

## Construct the network and the loss function

In [None]:
HiddenSize = 2       # neurons in the single hidden layer
Batchsize  = 10      # minibatch size used by every SGD step
RegWeight  = 0.001   # λ, weight of the L2 penalty in reg()
InitNorm   = 0.5;    # Euclidean norm of the initial flattened weight vector

In [None]:
# The deep learning package requires a certain array structure for the weights,
# but the later analysis is easier with all of them in a single column vector.
"""
    flat(w)

Concatenate the four parameter blocks `w[1]`, `w[2]`, `w[3]` and `w[4]`, in that
order, into one vector. `unflat` performs the inverse split.
"""
function flat(w)
    blocks = (w[1], w[2], w[3], w[4])
    return vcat(blocks...)
end

In [None]:
# reconstruct the weight array from the flat weight vector
"""
    unflat(wf)

Inverse of `flat`: split a flat vector of length `3h + 1` back into
`[w1, w2, w3, w4]`. Here `w1`, `w2` and `w3` are length-`h` vectors and `w4` is the
last (scalar) entry. The hidden size `h` is inferred from `length(wf)` instead of
being read from the global `HiddenSize`.

Throws `DimensionMismatch` if `length(wf)` is not of the form `3h + 1`.
"""
function unflat(wf)
    h, r = divrem(length(wf) - 1, 3)
    r == 0 || throw(DimensionMismatch("length(wf) = $(length(wf)) is not of the form 3h+1"))
    return [wf[1:h], wf[h+1:2h], wf[2h+1:3h], wf[end]]
end

In [None]:
# Change this seed to try different initial weights w/o changing the training data
Random.seed!(2);

In [None]:
# initialize weights: w = [w1,w2,w3,w4] -> output = w3*tanh.(w1*x .+ w2) .+ w4
# w1, w2, w3 are length-HiddenSize vectors and w4 is a scalar; entries start uniform in [0,1).
w = [rand(HiddenSize),rand(HiddenSize),rand(HiddenSize),rand()];
w = InitNorm*w/norm(flat(w)); # rescale w so that the norm is InitNorm
# total number of scalar parameters (the length of flat(w), i.e. 3*HiddenSize + 1)
Nweights = length(flat(w));

In [None]:
"""
    predict(w, x)

Network output `w3' * tanh.(w1*x' .+ w2) .+ w4` for the samples in vector `x`.
Returns a 1×N row (one column per sample).
"""
function predict(w,x)
    hidden = tanh.(w[1]*x' .+ w[2])   # HiddenSize × N hidden-layer activations
    return w[3]'*hidden .+ w[4]
end

In [None]:
# Mean squared error of the network on the samples (x, y); x and y hold one entry per
# training data point.
function sqloss(w,x,y)
    residual = y' - predict(w,x)   # 1×N row, matching the orientation of predict's output
    return mean(abs2, residual)
end

# L2 penalty: RegWeight times the sum of squared norms of all four weight blocks
function reg(w)
    total = sum(norm(w[k])^2 for k in 1:4)
    return RegWeight * total
end

# Training objective: mean squared error plus the L2 regularization term.
loss(w,x,y) = sqloss(w,x,y) + reg(w)

## Gradient-calculating functions

In [None]:
# grad() is a "functional": its input and output are functions. Note that grad()
# requires loss to be a scalar function; the returned function computes the gradient
# with respect to the first argument (the weights w).
lossgradient = grad(loss)
sqlossgradient = grad(sqloss)   # NOTE(review): not used elsewhere in the visible code
reggradient = grad(reg)         # NOTE(review): not used elsewhere in the visible code

In [None]:
# calculate gradient at the initial w
# dw has the structure of w: each weight block w_i is replaced with the gradient wrt w_i
dw = lossgradient(w,x_train,y_train)

## Training function (with replacement)

In [None]:
"""
    batchtrain!(w, lr)

Do one SGD step on `w` in place, with learning rate `lr`. The minibatch holds
`Batchsize` distinct training points, taken as the first `Batchsize` entries of a
fresh random permutation of `1:Ntrain`. Every call draws a new batch, independently
of previous calls. Returns the updated `w`.
"""
function batchtrain!(w,lr)
    batch = randperm(Ntrain)[1:Batchsize]
    xb = x_train[batch]
    yb = y_train[batch]
    # gradient of the full objective (MSE + L2 penalty) on this batch
    g = lossgradient(w,xb,yb)
    # gradient-descent update, block by block
    for k in eachindex(w)
        w[k] -= lr*g[k]
    end
    return w
end

In [None]:
# THE TRAINING FUNCTION: ONE EPOCH = Ntrain ÷ Batchsize RANDOM MINIBATCH STEPS.
# Each step (batchtrain!) draws a fresh random batch, independently of the other
# steps, so batches are sampled WITH replacement across an epoch.

"""
    mytrain!(w; lr=0.1)

Run one epoch of `Ntrain ÷ Batchsize` SGD steps on `w` in place and return `w`.
The remainder `Ntrain % Batchsize` does not add a partial extra step.
"""
function mytrain!(w;lr=0.1)
    nsteps = Ntrain ÷ Batchsize
    foreach(_ -> batchtrain!(w,lr), 1:nsteps)
    return w
end

## Initial training run

In [None]:
Nepoch = 2000   # number of epochs in the quick initial training run
η = 0.001;      # base learning rate (the initial run uses 10η)

In [None]:
# Increase learning rate by a factor of 10 for the initial run
# mytrain! updates the global w in place and returns it, so each epoch's result is
# snapshotted with deepcopy. NOTE(review): batchtrain! rebinds w[i] (w[i] -= ...)
# rather than mutating the inner arrays, so a shallow copy would likely also suffice.
@time w_training = [ deepcopy(mytrain!(w,lr=10*η)) for i=1:Nepoch ];  # copy only copies the top layer, does not descend.
# Row i of wf_training is the flattened weight vector after epoch i.
wf_training = zeros(Nepoch,Nweights); for i=1:Nepoch wf_training[i,:] = flat(w_training[i]) end

In [None]:
# Checking if the training worked. Compare the learned function with the actual gaussian
xplot=collect(-Range:0.01:Range)
# plot the learned function, the target gaussian and the noisy training samples
plot(xplot,[predict(w,xplot)',exp.(-xplot.^2)]); scatter!(x_train,y_train,leg=false)

In [None]:
# Plot the full-dataset loss vs epoch, evaluated every SamplingRate-th epoch
SamplingRate=10;
# NOTE: this reassigns the globals x and y; later cells overwrite them again.
x = collect(1:SamplingRate:Nepoch);
y = [loss(w_training[i],x_train,y_train) for i in x];
plot(x,y)
#plot(x,y,xaxis=:log10,yaxis=:log10) # can also plot in log-log scale

## Diffusion tensor, Hessian, Covariance Matrix

In [None]:
"""
    diffusiontensor(w, xt, yt, Nb, lr)

Analytic SGD diffusion tensor at weights `w` for batch size `Nb` and learning rate `lr`:

    lr^2 * (Nt-Nb)/(2*Nb*(Nt-1)) * (Nt*V*V' - dL*dL')

Column `i` of `V` is the single-sample loss gradient at `(xt[i], yt[i])` divided by
`Nt = length(xt)`. `dL` is the full-batch loss gradient at `w`.
"""
function diffusiontensor(w,xt,yt,Nb,lr)
    nw = length(flat(w))   # dimension of the tensor (number of scalar weights)
    Nt = length(xt)        # number of training examples summed over
    V = zeros(nw, Nt)
    for i in 1:Nt
        V[:, i] = flat(lossgradient(w, [xt[i]], [yt[i]]))
    end
    V /= Nt
    dL = flat(lossgradient(w, xt, yt))
    prefac = (Nt - Nb)/(2*Nb*(Nt - 1))
    return lr^2 * prefac * (Nt*V*V' - dL*dL')
end

In [None]:
# Calculate the diffusion tensor by sampling the noise
"""
    diffusiontensor_num(w, n, lr)

Monte-Carlo estimate of the SGD diffusion tensor at `w`. It performs `n` independent
single SGD steps (`batchtrain!`, learning rate `lr`), each from a fresh copy of `w`.
It returns half the 1/n-normalized sample covariance of the resulting weight vectors.

`w` itself is not modified. The global RNG advances with every step.
"""
function diffusiontensor_num(w,n,lr) # n: number of samples used for estimation
    # Dimension taken from `w` (as in diffusiontensor/hessianmatrix), not the global Nweights.
    nw = length(flat(w))
    wlist = zeros(nw, n)
    for i in 1:n
        ww = deepcopy(w)   # batchtrain! mutates its argument; keep the caller's w intact
        batchtrain!(ww, lr)
        wlist[:, i] = flat(ww)
    end

    # subtract the sample mean of every weight component
    wlist .-= mean(wlist; dims=2)

    # D[α,β] = (1/n) Σ_i wlist[α,i]*wlist[β,i], computed as one matrix product
    D = (wlist * wlist') / n
    return 0.5*D
end

In [None]:
# To calculate the Hessian,
# define a function returning elements of the loss-gradient vector dL/dw_j

"""
    lossgradj(w, x, y, j)

Return component `j` of the flattened full-loss gradient at `w`.
Differentiating this with `grad` gives column `j` of the Hessian.
"""
function lossgradj(w,x,y,j)
    g = flat(lossgradient(w,x,y))
    return g[j]
end

In [None]:
# Define a function calculating a column of Hessian matrix:
# Returns d^2(L)/dw_idw_j for all i and given j.
# The result has the structure of w (the first argument, which grad differentiates);
# hessianmatrix flattens it into column j.
lossgradgrad = grad(lossgradj)

In [None]:
"""
    hessianmatrix(w, x, y)

Hessian of the full loss at `w` on the samples `(x, y)`, as a dense
`length(flat(w)) × length(flat(w))` matrix assembled column by column.
"""
function hessianmatrix(w,x,y)
    n = length(flat(w))
    H = zeros(n, n)
    for col in 1:n
        H[:, col] = flat(lossgradgrad(w, x, y, col))   # ∇_w (∂L/∂w_col)
    end
    return H
end

In [None]:
# This calculation is from Michael's overleaf notes:
# https://www.overleaf.com/2523873322bvvnxpwnskfk
"""
    covariancematrix(D, H, lr)

Solve `H*C + C*H = (2/lr)*D` for the stationary covariance `C`.

In the eigenbasis of `H` (`H = O*Diagonal(h)*O'`) the equation decouples to
`(O'*C*O)[i,j] = (2/lr) * (O'*D*O)[i,j] / (h[i] + h[j])`.

`H` is wrapped in `Symmetric` (its upper triangle is used) before diagonalizing.
A Hessian that is only symmetric up to rounding then still gives real eigenvalues
and an orthonormal `O`, so `O'` is its inverse. Requires `h[i] + h[j] != 0` for all
pairs (e.g. positive-definite `H`); otherwise entries become Inf/NaN.
"""
function covariancematrix(D,H,lr) # is a function of the learning rate
    F = eigen(Symmetric(H))
    h = F.values
    O = F.vectors
    # Delta[i,j] = (O'DO)[i,j] / (h[i] + h[j])
    Delta = (O'*D*O) ./ (h .+ h')
    return (2/lr)*O*Delta*O'
end

### Use Newton's method to find the minimum of the loss

In [None]:
# Use Newton's method to get the true minimum of the full (full-batch) loss function,
# starting from the end point of the SGD training run.
"""
    newtonminimize(wf, x, y; nsteps=10)

Run `nsteps` Newton iterations on the full loss over `(x, y)`, starting from the flat
weight vector `wf`. Each step solves `H(wf) * δ = ∇L(wf)` and sets `wf = wf - δ`.
Returns the final flat weight vector.
"""
function newtonminimize(wf, x, y; nsteps=10)
    for _ in 1:nsteps
        Hess = hessianmatrix(unflat(wf), x, y)
        gradwf = flat(lossgradient(unflat(wf), x, y))
        # solve the linear system instead of forming inv(Hess) explicitly
        wf = wf - Hess \ gradwf
    end
    return wf
end

# Wrapping the iteration in a function also avoids the soft-scope ambiguity of a
# top-level `for` loop that reassigns the global `wf` when run outside a notebook.
wf = newtonminimize(flat(w_training[end]), x_train, y_train)

wminf = wf[:,1]
wmin = unflat(wminf)

In [None]:
# Sanity check: every block of the full-loss gradient at wmin should be ~0.
lossgradient(wmin,x_train,y_train)

### Hessian at the loss minimum

In [None]:
Hessmin = hessianmatrix(wmin,x_train,y_train)

In [None]:
# At a minimum all eigenvalues should be positive; covariancematrix divides by h[i]+h[j].
eigvals(Hessmin)

### Diffusion tensor at the loss minimum

In [None]:
# diffusion tensor at the loss minimum (analytic, from per-sample gradients via VV')
Dmin = diffusiontensor(wmin,x_train,y_train,Batchsize,η)

In [None]:
# compare with D estimated numerically from 10000 independent single SGD steps started at wmin
# (element-wise ratio; entries ≈ 1 mean agreement). This advances the global RNG.
diffusiontensor_num(wmin,10000,η) ./ Dmin

### Covariance matrix as a function of Hessian and Diffusion matrix

In [None]:
Covmin = covariancematrix(Dmin,Hessmin,η)

In [None]:
# verify that the math is right: HC+CH = (2/η)D, so this residual should be ~0 up to rounding
Hessmin*Covmin + Covmin*Hessmin - (2/η)*Dmin

## Steady state

In [None]:
Random.seed!(1) # Verified that the results don't change for different seeds.
Nmarkov = 3_000_000   # number of memoryless ("Markovian") single SGD steps
trans = 500_000;      # steps before this index are left out of the steady-state statistics

In [None]:
w = deepcopy(wmin); # start from the minimum of the potential (copy, since batchtrain! updates w in place)

In [None]:
# Nmarkov single SGD steps at learning rate η; a snapshot of w is stored after every step.
@time w_ss = [ deepcopy(batchtrain!(w,η)) for step=1:Nmarkov ];

In [None]:
# Construct the flat trajectory: row i holds flat(w_ss[i]), the weights after step i.
# `@time` now covers the whole construction (before, `@time a; b` only timed the
# `zeros` allocation). The `let` block keeps the loop variables local, which avoids
# soft-scope ambiguity outside notebooks.
@time wf_ss = let out = zeros(Nmarkov,Nweights)
    for i in 1:Nmarkov
        out[i,:] = flat(w_ss[i])
    end
    out
end;

### Visualize the steady-state distribution

In [None]:
# Indices into the flat weight vector of the two weights to visualize
# (with HiddenSize = 2 these are the two output weights w3).
xid, yid = 5, 6;

In [None]:
# visualize: 2D histogram of weights xid and yid over the steady-state part of the
# trajectory (steps trans..Nmarkov), with the loss minimum overlaid.

using StatsBase

ss_range=collect(trans:Nmarkov)
wx = wf_ss[ss_range,xid]
wy = wf_ss[ss_range,yid]
resxy=(200,200) # histogram bins

fith = fit(Histogram,(wx,wy),nbins=resxy)

# The next two bare expressions are not displayed (only a cell's last value is shown).
fith.weights # bin counts
fith.edges # bin boundaries
maxhist=maximum(fith.weights) # largest bin count; reused below to scale the contour functions

histogram2d(wx,wy,bins=resxy)
scatter!([wminf[xid,1]],[wminf[yid,1]],leg=false,markercolor="cyan",markersize=4) # loss minimum

### Fit a multivariate Gaussian to the steady-state (equilibrium) data

In [None]:
# Maximum-likelihood multivariate normal fit; fit_mle takes one sample per column, hence the transpose.
Fit_ss = fit_mle(MvNormal,wf_ss[ss_range,:]')

### Steady-state mean


In [None]:
meanwf = Distributions.mean(Fit_ss)   # flat mean vector
meanw = unflat(meanwf)                # the same mean in the [w1,w2,w3,w4] structure

### Covariance matrix (from ss-trajectory)

In [None]:
Cov_ss = Distributions.cov(Fit_ss)

In [None]:
# Compare with the solution of ΣH + HΣ = (2/η)D (element-wise ratio; ≈ 1 where they agree)
Cov_ss./covariancematrix(Dmin,Hessmin,η)

In [None]:
# Inverse of the 2x2 marginal covariance of weights (xid, yid): the matrix in the
# exponent of the projected (marginal) Gaussian fit
Cov_xy_inv = inv(Cov_ss[[xid,yid],[xid,yid]])

### Visualize the steady-state distribution on top of the loss landscape

In [None]:
# Construct a grid enclosing the steady-state trajectory

# Spread (max - min) of a collection of samples.
minmaxdiff(t) = maximum(t)-minimum(t)

"""
    makegrid(xvec, yvec, mean, xindex, yindex; Nx=10, Ny=10, zoom=0.75)

Build a 2D plotting grid in the plane of weight components `xindex` and `yindex`,
centred on `(mean[xindex], mean[yindex])`.

The half-widths are `zoom * minmaxdiff(xvec)` and `zoom * minmaxdiff(yvec)`, sampled
with `Nx` / `Ny` steps per half-width.

Returns `(x, y, Imask, xmask, ymask)`:
- `x`, `y`: the grid coordinates.
- `Imask`: the identity with the `(xindex, xindex)` and `(yindex, yindex)` entries zeroed.
- `xmask`, `ymask`: unit vectors along `xindex` and `yindex`.

The full weight vector for grid point `(s, t)` is `Imask*mean + s*xmask + t*ymask`.
The dimension is `length(mean)`; the global `Nweights` is no longer read.
"""
function makegrid(xvec,yvec,mean,xindex,yindex;Nx=10,Ny=10,zoom=0.75)
    xhalf = zoom*minmaxdiff(xvec)
    yhalf = zoom*minmaxdiff(yvec)
    x = collect(-xhalf:xhalf/Nx:xhalf) .+ mean[xindex]
    y = collect(-yhalf:yhalf/Ny:yhalf) .+ mean[yindex]

    n = length(mean)
    Id = Matrix{Float64}(I, n, n)
    xmask = Id[:, xindex]
    ymask = Id[:, yindex]
    Imask = Id - xmask*xmask' - ymask*ymask' # zero the two selected diagonal elements
    return (x,y,Imask,xmask,ymask)
end

(x,y,Imask,xmask,ymask) = makegrid(wx,wy,meanwf,xid,yid)

histogram2d(wx,wy,bins=200,aspect_ratio=1.0)

meanxy = meanwf[[xid yid]]   # 1x2 row (indexing with a 1x2 matrix)
## mv-Gaussian fit contours
# fexp: minus the quadratic form of the fitted marginal Gaussian (its log-density up to a constant)
fexp(s,t) = -(([s t]-meanxy)*Cov_xy_inv*([s t]-meanxy)')[1]
# normalized by the value at the grid corner (x[end], y[end]) and scaled to the histogram peak
ffit(s,t) =  maxhist * fexp(s,t)/fexp(x[end],y[end])
# NOTE(review): ffit looks up `fexp` by name at call time and `fexp` is redefined below;
# this relies on contour! evaluating ffit immediately (assumed Plots behavior — confirm).
contour!(x,y,ffit,linestyle=:dash)

## Loss contours
# NOTE(review): for length(x) == 2N+1 this gives N, one point left of the grid centre
# (index N+1), and Int(...) throws InexactError for an even length. The reference
# point only rescales the contour values below.
midx = Int((length(x)-1)/2)
midy = Int((length(y)-1)/2)
# excess loss over the minimum, with all weights except xid/yid held at the steady-state mean
fexp(s,t) = loss(unflat(Imask*meanwf + s*xmask + t*ymask),x_train,y_train) - loss(wmin,x_train,y_train)
flossxy(s,t) = maxhist * log(fexp(s,t))/log(fexp(x[midx],y[midy]))
contour!(x,y,flossxy)

### Move to the eigen-coordinates

In [None]:
# pick two eigen directions. For a symmetric matrix eigvecs orders the eigenvectors by
# ascending eigenvalue, so Xid/Yid are the directions of largest and second-largest
# steady-state variance (assumes Cov_ss is exactly symmetric — TODO confirm).
Xid = Nweights
Yid = Nweights-1

O = eigvecs(Cov_ss);   # eigenvectors as columns; orthonormal, so O' is the inverse of O

# revert to original weights
# NOTE(review): the commented lines set Xidx/Yidx, but the code below uses Xid/Yid.
#O *= O' # identity
#Xidx = 5
#Yidx = 6


W_ss = wf_ss*O; # sample weights are row vectors; row i = step i in eigen-coordinates
Wx = W_ss[trans:end,Xid]
Wy = W_ss[trans:end,Yid]

COV_ss = O'*Cov_ss*O   # covariance in the eigenbasis (diagonal up to rounding)
COV_xy_inv = inv(COV_ss[[Xid,Yid],[Xid,Yid]])

meanW = O'*meanwf   # steady-state mean in eigen-coordinates
Wminf = O'*wminf;   # loss minimum in eigen-coordinates


In [None]:
(x,y,Imask,xmask,ymask) = makegrid(Wx,Wy,meanW,Xid,Yid)

histogram2d(Wx,Wy,bins=200,aspect_ratio=1)

meanXY = meanW[[Xid Yid]]
# Contours of the fit mv-Gaussian (maxhist is the peak count of the earlier
# original-coordinate histogram, reused here only as a scale)
fexp(s,t) = -(([s t]-meanXY)*COV_xy_inv*([s t]-meanXY)')[1]
Ffit(s,t) = maxhist* fexp(s,t)/fexp(x[end],y[end])
contour!(x,y,Ffit,linestyle=:dash)

# contours of loss
midx = Int((length(x)-1)/2)
midy = Int((length(y)-1)/2)
# O*(...) maps the eigen-plane point back to the original weights before evaluating the loss
fexp(s,t) = loss(unflat(O*(Imask*meanW + s*xmask + t*ymask)),x_train,y_train) - loss(wmin,x_train,y_train)
# NOTE(review): Flossxy reads the globals fexp, x, y, midx and midy at call time, so its
# values change if a later cell reassigns them.
Flossxy(s,t) = (maxhist/5) * (log(fexp(s,t)) - log(fexp(x[midx],y[midy])))
contour!(x,y,Flossxy)

### Area-sweep estimation of rotation

In [None]:
"""
    arealvelocity(traj, x, y, center)

Running average of the signed area swept by a trajectory around `center`.

`traj` holds one point per row; only columns `x` and `y` are used. For each step from
row `n` to row `n+1`, the signed triangle area `(a_x*b_y - a_y*b_x)/2` is accumulated.
`a` and `b` are the two points relative to the centre, so counter-clockwise motion is
positive.

Returns a vector `A` of length `size(traj, 1)`. `A[n]` is the area accumulated over
the first `n-1` steps, divided by `n`. A point that coincides with the centre
contributes zero area instead of NaN. An empty trajectory yields an empty vector.
"""
function arealvelocity(traj,x,y,center)
    N = size(traj, 1)
    Areasum = zeros(N)
    for n in 1:N-1
        ax, ay = traj[n, x] - center[x], traj[n, y] - center[y]
        bx, by = traj[n+1, x] - center[x], traj[n+1, y] - center[y]
        # z-component of the cross product a×b, halved: the signed triangle area
        Areasum[n+1] = Areasum[n] + (ax*by - ay*bx)/2
    end
    Areasum ./= 1:N
    return Areasum
end

# test.. (commented out): a noisy unit circle traversed once. After dividing by π/N
# (the unit-circle area per step) the curve is expected to be close to 1.
# N=10000;
# circ = zeros(N,2)
# for n=1:N
#     circ[n,:] = [cos(2π*n/N)+0.1*rand(),sin(2π*n/N)+0.1*rand()]
# end
# plot(arealvelocity(circ,1,2,[0,0])/(π/N),ylim=[0,2])

# running average of the signed area swept in the (Xid, Yid) eigen-plane around the steady-state mean
parr = arealvelocity(W_ss[trans:end,:],Xid,Yid,meanW);

In [None]:
# Normalization scale (π/2)·σ_X·σ_Y from the eigen-variances of the two plotted directions;
# every 100th value of the running swept-area average is plotted.
mynorm = sqrt(π^2*COV_ss[Xid,Xid]*COV_ss[Yid,Yid]/4)
plot(parr[1:100:end]/mynorm,leg=false,ylim=[-0.0001,0.0001]) 

In [None]:
# Function that calculates the probability current at a given w
# again, from Michael's notes.

# Note that the current is calculated in the original weight basis

"""
    currentvec(w, wcenter, x, y, lr)

Probability-current vector at weights `w`, in the original weight basis:
`(Dmin*inv(C) - lr*H) * flat(w - wcenter)`.

`H` is the full-loss Hessian at `w` and `C = covariancematrix(Dmin, H, lr)`.
Reads the global `Dmin`.
"""
function currentvec(w,wcenter,x,y,lr)
    H = hessianmatrix(w,x,y)
    C = covariancematrix(Dmin,H,lr)
    deltaw = flat(w-wcenter)
    return (Dmin*inv(C) - lr*H)*deltaw
end

In [None]:
# Visualizing the vector fields

# Scan the vicinity of a centre point (the centre of the grid) to get a sense of the
# current vector field. Two components are scanned while the rest are fixed to their
# value at the centre.

"""
    gridarray(wcenter, x_index, y_index, Nx, Ny, Lx, Ly)

Return a `length(wcenter) × (Nx+1)*(Ny+1)` matrix whose columns are copies of
`wcenter`. In each copy, component `x_index` is offset by `-Lx:2Lx/Nx:Lx` and
component `y_index` by `-Ly:2Ly/Ny:Ly`.

Column `1 + nx*(Ny+1) + ny` holds grid point `(nx, ny)`, for `nx in 0:Nx` and
`ny in 0:Ny`. The dimension is taken from `wcenter`; the global `Nweights` is not read.
"""
function gridarray(wcenter,x_index,y_index,Nx,Ny,Lx,Ly)
    warray = zeros(length(wcenter), (Nx+1)*(Ny+1))
    for nx in 0:Nx, ny in 0:Ny
        dw = zero(wcenter)
        dw[x_index] = -Lx + 2*Lx*nx/Nx
        dw[y_index] = -Ly + 2*Ly*ny/Ny
        warray[:, 1 + nx*(Ny+1) + ny] = wcenter + dw
    end
    return warray
end


In [None]:
# Half-widths of the grid: half the sampled spread along each plotted eigen-direction.
Lx = minmaxdiff(Wx)/2
Ly = minmaxdiff(Wy)/2
Nx = Ny = 10
# (Nx+1)*(Ny+1) grid points centred on the loss minimum Wminf, in eigen-coordinates
Warray = gridarray(Wminf,Xid,Yid,Nx,Ny,Lx,Ly)

# Scatter plot of the grid points. wstar (the loss minimum Wminf) is the second series
# (default palette colour; no colour is set explicitly).
#scatter(warray[x_index,:],warray[y_index,:]); scatter!([wmin_ef[x_index]],[wmin_ef[y_index]],leg=false)
scatter(Warray[Xid,:],Warray[Yid,:]); scatter!([Wminf[Xid]],[Wminf[Yid]],leg=false)

In [None]:
# calculate current vectors
# Need to back transform the warray in eigen-basis to the original weight basis
# using (O*warray), since currentvec() is defined for the original weights.
# O' then maps each current back to the eigenbasis, so column i of `currents`
# is in the same basis as Warray[:,i].
# NOTE(review): currentvec recomputes the Hessian at every grid point.

npts = length(Warray[1,:])
currents = zeros(Nweights,npts)
for i=1:npts
    currents[:,i] = O'*currentvec(unflat(O*Warray[:,i]),meanw,x_train,y_train,η)
end
currents

In [None]:
using Interact

# Interactive current-field plot: choose any two eigen-directions with the sliders.
# The slider ranges follow Nweights (they were hard-coded to 1:7, which only matches
# HiddenSize = 2); the defaults are still 5 and 7 in that case.
m = @manipulate for xind in slider(1:Nweights,value=min(5,Nweights)), yind in slider(1:Nweights,value=Nweights)
    # steady-state samples along the two selected eigen-directions
    Wxx = W_ss[trans:end,xind]
    Wyy = W_ss[trans:end,yind]
    Lx = minmaxdiff(Wxx)/2
    Ly = minmaxdiff(Wyy)/2
    Nx = Ny = 10
    # grid centred on the loss minimum, spanning half the sampled spread in each direction
    Warr = gridarray(Wminf,xind,yind,Nx,Ny,Lx,Ly)
    npts = length(Warr[1,:])
    currents = zeros(Nweights,npts)
    for i=1:npts
        # evaluate in original weights (O*...), express the current in the eigenbasis (O'*...)
        currents[:,i] = O'*currentvec(unflat(O*Warr[:,i]),meanw,x_train,y_train,η)
    end
    x = Warr[xind,:];
    y = Warr[yind,:];
    # scale so the largest arrow component is ~3 grid spacings (geometric mean of x/y spacings)
    mynorm = 3*sqrt(4*Lx*Ly/(Nx*Ny))/maximum(abs.(currents[[xind yind],:]))
    u_cur = mynorm*currents[xind,:];
    v_cur = mynorm*currents[yind,:];
    histogram2d(Wxx,Wyy,bins=200)#,aspect_ratio=1.0)
    quiver!(x, y, quiver=(u_cur, v_cur));
    scatter!([meanW[xind]],[meanW[yind]],leg=false)
end

In [None]:
# plot the current vector field in the plane of eigen-directions xind, yind
xind = 6
yind = 4
# Samples and grid extent for THIS plane. Previously Lx, Ly came from the Xid/Yid plane
# of an earlier cell, and the contour grid was built from Wx, Wy (also the Xid/Yid plane).
Wxx = W_ss[trans:end,xind]
Wyy = W_ss[trans:end,yind]
Lx = minmaxdiff(Wxx)/2
Ly = minmaxdiff(Wyy)/2
Warray = gridarray(Wminf,xind,yind,Nx,Ny,Lx,Ly)
npts = length(Warray[1,:])
currents = zeros(Nweights,npts)
for i=1:npts
    # evaluate in original weights (O*...), express the current in the eigenbasis (O'*...)
    currents[:,i] = O'*currentvec(unflat(O*Warray[:,i]),meanw,x_train,y_train,η)
end
x = Warray[xind,:];
y = Warray[yind,:];
# scale so the largest arrow component is ~3 grid spacings
mynorm = 3*sqrt(4*Lx*Ly/(Nx*Ny))/maximum(abs.(currents[[xind yind],:]))
u_cur = mynorm*currents[xind,:];
v_cur = mynorm*currents[yind,:];

histogram2d(Wxx,Wyy,bins=200)
quiver!(x, y, quiver=(u_cur, v_cur));
scatter!([meanW[xind]],[meanW[yind]],leg=false)

(xx,yy,Imask,xmask,ymask) = makegrid(Wxx,Wyy,meanW,xind,yind)
midx = (length(xx)-1) ÷ 2
midy = (length(yy)-1) ÷ 2

# contours of loss on the (xx, yy) grid. This is evaluated here rather than via the
# earlier Flossxy, which reads the globals x and y — those now hold the quiver points above.
lossmin = loss(wmin,x_train,y_train)
lossexcess(s,t) = loss(unflat(O*(Imask*meanW + s*xmask + t*ymask)),x_train,y_train) - lossmin
logref = log(lossexcess(xx[midx],yy[midy]))
contour!(xx,yy,(s,t) -> (maxhist/5) * (log(lossexcess(s,t)) - logref))