In [1]:
using Knet, Plots, Statistics, LinearAlgebra, Random

┌ Info: Precompiling Plots [91a5bcdd-55d7-5caf-9e0b-520d859cae80]
└ @ Base loading.jl:1260


In [2]:
# Make `argmax` work on Knet device arrays and AutoGrad-tracked values by
# first copying to a plain `Array` (resp. unwrapping the tracked value).
# Keyword arguments such as `dims` are forwarded, so calls like
# `argmax(w * x, dims=1)` in `acc` also work on these types.
# NOTE(review): this extends Base.argmax on types owned by Knet/AutoGrad
# (type piracy) — acceptable in a notebook, avoid in library code.
Base.argmax(a::KnetArray; kw...) = argmax(Array(a); kw...)
Base.argmax(a::AutoGrad.Value; kw...) = argmax(value(a); kw...)

In [3]:
# Narrow the text width used for displaying results in this notebook;
# presumably consulted by IJulia/Base when truncating the wide array
# printouts below (e.g. the 10×784 matrices) — confirm.
ENV["COLUMNS"] = 72

72

In [6]:
# Pull in the MNIST loader shipped with Knet (provides `mnist()`).
include(Knet.dir("data/mnist.jl"))

(xtrn, ytrn, xtst, ytst) = mnist()

# Array type used for data and weights throughout the notebook.
# NOTE(review): presumably swappable for a GPU type such as
# KnetArray{Float32}; confirm before relying on it.
ARRAY = Array{Float32}

# Flatten every image into a single column (inputs become 784×N matrices).
xtrn = ARRAY(mat(xtrn))
xtst = ARRAY(mat(xtst))

"""
    onehot(y; nclasses=maximum(y))

Return an `nclasses × length(y)` one-hot matrix converted to `ARRAY`. Column
`i` has a 1 in row `y[i]` and 0 elsewhere. Labels must be integers in
`1:nclasses`; anything else raises a `BoundsError`.

`nclasses` defaults to the largest label present. Pass it explicitly when a
subset of the data might not contain the highest class, so that training and
test targets end up with the same number of rows.
"""
function onehot(y; nclasses::Integer = maximum(y))
    # Fill on the CPU and convert once at the end: element-by-element writes
    # into a device array (if ARRAY is ever a GPU type) would be very slow.
    m = zeros(Float32, nclasses, length(y))
    for (col, label) in enumerate(y)
        m[label, col] = 1
    end
    return ARRAY(m)
end
    
# Replace the integer label vectors by one-hot target matrices.
ytrn, ytst = map(onehot, (ytrn, ytst))

└ @ Pkg D:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.4\Pkg\src\Pkg.jl:531


(Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 1.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 1.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0])

In [7]:
# Print the size and element type of each data array, one per line.
map(a -> println(summary(a)), (xtrn, ytrn, xtst, ytst))

784×60000 Array{Float32,2}
10×60000 Array{Float32,2}
784×10000 Array{Float32,2}
10×10000 Array{Float32,2}


(nothing, nothing, nothing, nothing)

In [11]:
# Dataset dimensions: number of training / test examples (columns), input
# feature count (rows of x) and number of classes (rows of the one-hot y).
NTRN, NTST, XDIN, YDIN = size(xtrn,2), size(xtst,2), size(xtrn,1), size(ytrn,1)

(60000, 10000, 784, 10)

In [12]:
# Model weights
# A random, untrained YDIN×XDIN weight matrix (one row per class), used in the
# next cells to illustrate class scores and baseline accuracy. It is unseeded,
# so its values change between runs. The training functions below start from
# their own zero matrix instead.
w = ARRAY(randn(YDIN, XDIN))

10×784 Array{Float32,2}:
  1.46921    -0.82374     1.02799    …   0.767117     1.18014
 -0.734889    1.63881    -0.233767      -0.176612     0.742851
 -0.588088   -0.635935   -1.07059       -1.20605      1.12807
  0.45404    -0.843134   -0.165599      -0.465834    -0.00227915
 -0.941397    0.277438   -2.19875       -0.00797649  -1.00516
  0.199459   -0.801237    0.825526   …  -0.294744     0.289672
 -1.3696      0.0872109  -0.0883313      0.230567     1.64465
  0.0182876   0.509391    0.0119192      0.0563827    0.303291
  0.667549   -0.668325   -1.19391       -0.0636985    0.36177
  0.553008   -1.84913    -0.291411       1.65455      1.6533

In [13]:
# class scores
# One column of YDIN class scores per training example; the predicted class
# of an example is the row with the largest score (see `acc`).
w * xtrn

10×60000 Array{Float32,2}:
  15.6275    12.3336      8.75735   …   18.4437     5.84873   7.28817
 -12.7581    -4.3409      2.39845        5.68609   -8.59408   7.63697
   2.85026   12.607     -10.0656        -1.70818   21.7291   -5.82019
  -2.91722    6.66727     0.775055      -5.3123    14.0391    9.04198
   3.53098    3.92803     6.70215       -3.03835    2.82004  10.1184
  18.5901    -2.57427    13.1555    …   12.5579    -3.06313  -6.2331
  -0.424006   4.61277    11.8919        10.2205     8.25196   8.79809
  -9.52184   -3.37142    10.3548        -0.253827  -2.168    -4.66507
  -0.666528   4.77355    -3.17451      -11.791     -4.8382   -5.25604
   3.76447   -0.925771    1.73937       14.8356    -4.5529    9.76019

In [14]:
# correct answers
# Recover each training example's integer label from its one-hot column,
# displayed as a 1×NTRN row.
map(i -> argmax(ytrn[:, i]), 1:NTRN)'

1×60000 Adjoint{Int64,Array{Int64,1}}:
 5  10  4  1  9  2  1  3  1  4  3  …  8  9  2  9  5  1  8  3  5  6  8

In [19]:
# Accuracy

"""
    acc(w, x, y)

Return the fraction of columns of `x` whose highest-scoring class under the
linear model `w * x` coincides with the position of the maximum in the
corresponding column of the target matrix `y`.
"""
function acc(w, x, y)
    predicted = argmax(w * x; dims=1)
    expected = argmax(y; dims=1)
    return mean(predicted .== expected)
end
# Accuracy of the random baseline weights `w` on the training and test sets.
acc(w, xtrn, ytrn), acc(w, xtst, ytst)

(0.14406666666666668, 0.1479)

In [20]:
# Training LOOPS
"""
    train(algo, x, y, T=2^20; rng=Random.default_rng())

Run `T` online updates of a linear classifier and return its weight matrix.

Training starts from an all-zero `size(y,1) × size(x,1)` matrix of type
`ARRAY`. Each step draws one random column `i` of `x`/`y` (with replacement,
using `rng`) and calls `algo(w, x[:,i], y[:,i])`, which must update `w` in
place. Progress (iteration, accuracy on `x`/`y`, and the norm of `w`) is
printed at iterations 1, 2, 4, 8, … and at iteration `T`.

Pass e.g. `rng = MersenneTwister(1)` for a reproducible run; the default
draws from the global random stream as before.
"""
function train(algo, x, y, T=2^20; rng::AbstractRNG=Random.default_rng())
    w = ARRAY(zeros(size(y, 1), size(x, 1)))
    n_examples = size(x, 2)
    next_print = 1
    for t in 1:T
        i = rand(rng, 1:n_examples)
        algo(w, x[:, i], y[:, i])
        if t == next_print
            println((iter = t, accuracy = acc(w, x, y), wnorm = norm(w)))
            # Log-spaced reporting; capping at T guarantees the last step prints.
            next_print = min(2t, T)
        end
    end
    return w
end

train (generic function with 2 methods)

In [21]:
# Perceptron

"""
    perceptron(w, x, y)

Perform one multiclass perceptron step on a single example, modifying `w` in
place. The true class is `argmax(y)` (for a one-hot `y`). If the
highest-scoring row of `w * x` is not the true class, `x` is added to the true
class's row and subtracted from the predicted class's row. A correctly
classified example leaves `w` untouched and the call returns `nothing`.
"""
function perceptron(w, x, y)
    predicted = argmax(w * x)
    actual = argmax(y)
    predicted == actual && return nothing
    w[actual, :] .+= x
    w[predicted, :] .-= x
end

perceptron (generic function with 1 method)

In [22]:
# Train with the hand-written perceptron update for the default 2^20 random
# single-example steps, timing the whole run (the first run also includes
# compilation time).
@time wperceptron = train(perceptron, xtrn, ytrn)

(iter = 1, accuracy = 0.09871666666666666, wnorm = 17.490957f0)
(iter = 2, accuracy = 0.15111666666666668, wnorm = 19.585426f0)
(iter = 4, accuracy = 0.13533333333333333, wnorm = 25.171194f0)
(iter = 8, accuracy = 0.16751666666666667, wnorm = 32.31005f0)
(iter = 16, accuracy = 0.11953333333333334, wnorm = 41.04537f0)
(iter = 32, accuracy = 0.29855, wnorm = 52.008053f0)
(iter = 64, accuracy = 0.4830333333333333, wnorm = 67.332634f0)
(iter = 128, accuracy = 0.5064666666666666, wnorm = 85.274055f0)
(iter = 256, accuracy = 0.6461333333333333, wnorm = 115.53338f0)
(iter = 512, accuracy = 0.6070166666666666, wnorm = 140.69481f0)
(iter = 1024, accuracy = 0.7332, wnorm = 174.76021f0)
(iter = 2048, accuracy = 0.7402333333333333, wnorm = 218.41365f0)
(iter = 4096, accuracy = 0.82755, wnorm = 272.1519f0)
(iter = 8192, accuracy = 0.8335, wnorm = 337.5124f0)
(iter = 16384, accuracy = 0.8628833333333333, wnorm = 413.61862f0)
(iter = 32768, accuracy = 0.8724833333333334, wnorm = 493.90195f0)
(iter = 

10×784 Array{Float32,2}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  …   0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0      0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0      0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     -0.698039  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0      0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  …   0.0       0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0      5.11373   0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     -0.227451  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     -2.14118   0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0     -2.04706   0.0  0.0  0.0  0.0

In [23]:
"""
    adaline(w, x, y; lr=0.0001)

Perform one ADALINE (least-mean-squares) step on a single example, updating
`w` in place and returning it. `w` moves against the gradient of
`0.5 * sum(abs2, w*x - y)`, i.e. `lr * (w*x - y) * x'` is subtracted from it.
"""
function adaline(w, x, y; lr=0.0001)
    residual = w * x - y
    step = lr * residual * x'
    return w .-= step
end

adaline (generic function with 1 method)

In [24]:
# Train with the hand-written ADALINE (least-squares) update, lr=0.0001 by
# default, timing the whole run.
@time wadaline = train(adaline, xtrn, ytrn)

(iter = 1, accuracy = 0.0993, wnorm = 0.0011197067f0)
(iter = 2, accuracy = 0.11498333333333334, wnorm = 0.0015732059f0)
(iter = 4, accuracy = 0.10363333333333333, wnorm = 0.002574287f0)
(iter = 8, accuracy = 0.11326666666666667, wnorm = 0.0034773455f0)
(iter = 16, accuracy = 0.19723333333333334, wnorm = 0.0052489904f0)
(iter = 32, accuracy = 0.15298333333333333, wnorm = 0.009137772f0)
(iter = 64, accuracy = 0.2892, wnorm = 0.014683782f0)
(iter = 128, accuracy = 0.4221666666666667, wnorm = 0.02620069f0)
(iter = 256, accuracy = 0.49605, wnorm = 0.043245707f0)
(iter = 512, accuracy = 0.5553333333333333, wnorm = 0.06823031f0)
(iter = 1024, accuracy = 0.73555, wnorm = 0.10469287f0)
(iter = 2048, accuracy = 0.7642333333333333, wnorm = 0.16665664f0)
(iter = 4096, accuracy = 0.7827166666666666, wnorm = 0.254468f0)
(iter = 8192, accuracy = 0.7974, wnorm = 0.35880274f0)
(iter = 16384, accuracy = 0.8251166666666667, wnorm = 0.46244785f0)
(iter = 32768, accuracy = 0.8390333333333333, wnorm = 0.56

10×784 Array{Float32,2}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0

In [25]:
"""
    softmax(w, x, y; lr=0.01)

Perform one stochastic-gradient step of multinomial logistic regression on a
single example, modifying `w` in place and returning it.

The class probabilities are the softmax of the scores `w * x`. `w` moves by
`-lr * (probs - y) * x'`, which is the gradient of the negative
log-likelihood when `y` is one-hot.

NOTE(review): Knet (loaded above) likely exports its own `softmax`; this
definition presumably shadows it in this session — confirm.
"""
function softmax(w, x, y; lr=0.01)
    scores = w * x
    # Shifting by the max score leaves the probabilities unchanged but keeps
    # `exp` from overflowing to Inf (which would make probs, and then w, NaN).
    probs = exp.(scores .- maximum(scores))
    probs ./= sum(probs)
    residual = probs - y
    return w .-= lr * residual * x'
end

softmax (generic function with 1 method)

In [26]:
# Train multinomial logistic regression with the hand-written softmax update
# (lr=0.01 by default), timing the whole run.
@time wsoftmax = train(softmax, xtrn, ytrn)

(iter = 1, accuracy = 0.09915, wnorm = 0.05874139f0)
(iter = 2, accuracy = 0.09641666666666666, wnorm = 0.09345736f0)
(iter = 4, accuracy = 0.17438333333333333, wnorm = 0.123013206f0)
(iter = 8, accuracy = 0.18971666666666667, wnorm = 0.20881957f0)
(iter = 16, accuracy = 0.26565, wnorm = 0.31606832f0)
(iter = 32, accuracy = 0.36556666666666665, wnorm = 0.48388472f0)
(iter = 64, accuracy = 0.4581, wnorm = 0.79003394f0)
(iter = 128, accuracy = 0.6109166666666667, wnorm = 1.2114425f0)
(iter = 256, accuracy = 0.7113833333333334, wnorm = 1.8747087f0)
(iter = 512, accuracy = 0.78195, wnorm = 2.7205887f0)
(iter = 1024, accuracy = 0.8278666666666666, wnorm = 3.6480122f0)
(iter = 2048, accuracy = 0.85375, wnorm = 4.725615f0)
(iter = 4096, accuracy = 0.8724833333333334, wnorm = 5.860556f0)
(iter = 8192, accuracy = 0.88945, wnorm = 7.097817f0)
(iter = 16384, accuracy = 0.8944166666666666, wnorm = 8.659387f0)
(iter = 32768, accuracy = 0.90695, wnorm = 10.418147f0)
(iter = 65536, accuracy = 0.91591

10×784 Array{Float32,2}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0

In [30]:
# training via optimization
"""
    optimize(loss, x, y; lr=0.1, iters=2^20, rng=Random.default_rng())

Train a linear classifier by plain stochastic gradient descent on `loss`.

Training starts from an all-zero `Param` weight matrix of size
`size(y,1) × size(x,1)`. Each iteration:
- draws one random column `i` (with replacement, using `rng`);
- differentiates `loss(w, x[:,i], y[:,i])` with `@diff`;
- applies `w .-= lr * ∇w`.

Progress is printed at iterations 1, 2, 4, … and at `iters`. Returns the
trained `Param`.

Pass e.g. `rng = MersenneTwister(1)` for a reproducible run.
"""
function optimize(loss, x, y; lr=0.1, iters=2^20, rng::AbstractRNG=Random.default_rng())
    w = Param(ARRAY(zeros(size(y, 1), size(x, 1))))
    n_examples = size(x, 2)
    next_print = 1
    for t in 1:iters
        i = rand(rng, 1:n_examples)
        L = @diff loss(w, x[:, i], y[:, i])
        ∇w = grad(L, w)
        # Presumably `grad` yields `nothing` when the loss did not depend on
        # `w` (AutoGrad convention — confirm); there is then nothing to update,
        # and `lr * nothing` would otherwise throw.
        if ∇w !== nothing
            w .-= lr * ∇w
        end
        if t == next_print
            println((iter = t, accuracy = acc(w, x, y), wnorm = norm(w)))
            next_print = min(2t, iters)
        end
    end
    return w
end

optimize (generic function with 1 method)

In [31]:
"""
    perceptronloss(w, x, y)

Return the perceptron loss for one example: the score of the predicted
(highest-scoring) class minus the score of the true class `argmax(y)`. It is
never negative, because the predicted class has the highest score, and it is
zero when the prediction is correct.
"""
function perceptronloss(w, x, y)
    scores = w * x
    predicted, actual = argmax(scores), argmax(y)
    return scores[predicted] - scores[actual]
end

perceptronloss (generic function with 1 method)

In [32]:
# The perceptron again, now by SGD on `perceptronloss`. With lr=1 a gradient
# step adds x to the true class's row and subtracts it from the predicted
# row, matching the hand-written `perceptron` update.
@time wperceptron2 = optimize(perceptronloss,xtrn,ytrn,lr=1);

(iter = 1, accuracy = 0.09736666666666667, wnorm = 11.994327f0)
(iter = 2, accuracy = 0.09751666666666667, wnorm = 16.98135f0)
(iter = 4, accuracy = 0.11601666666666667, wnorm = 21.507568f0)
(iter = 8, accuracy = 0.1274, wnorm = 27.092562f0)
(iter = 16, accuracy = 0.2914833333333333, wnorm = 35.89878f0)
(iter = 32, accuracy = 0.47555, wnorm = 44.747078f0)
(iter = 64, accuracy = 0.30051666666666665, wnorm = 63.20732f0)
(iter = 128, accuracy = 0.58285, wnorm = 82.46719f0)
(iter = 256, accuracy = 0.6792333333333334, wnorm = 109.288574f0)
(iter = 512, accuracy = 0.7593, wnorm = 144.27118f0)
(iter = 1024, accuracy = 0.7638666666666667, wnorm = 178.67558f0)
(iter = 2048, accuracy = 0.7983, wnorm = 227.84065f0)
(iter = 4096, accuracy = 0.7963333333333333, wnorm = 278.55872f0)
(iter = 8192, accuracy = 0.8452, wnorm = 335.31934f0)
(iter = 16384, accuracy = 0.8552, wnorm = 407.56577f0)
(iter = 32768, accuracy = 0.85715, wnorm = 479.41226f0)
(iter = 65536, accuracy = 0.8553, wnorm = 583.99976f0)


In [33]:
"""
    quadraticloss(w, x, y)

Return half the squared Euclidean distance between the scores `w * x` and the
target `y`. Its gradient with respect to `w` is `(w*x - y) * x'`, the ADALINE
update direction.
"""
function quadraticloss(w, x, y)
    residual = w * x - y
    return 0.5 * sum(abs2, residual)
end

quadraticloss (generic function with 1 method)

In [34]:
# ADALINE again, now by SGD on `quadraticloss` with the same lr=0.0001 as the
# hand-written `adaline` update.
@time wadaline2 = optimize(quadraticloss,xtrn,ytrn,lr=0.0001);

(iter = 1, accuracy = 0.09915, wnorm = 0.0009123806f0)
(iter = 2, accuracy = 0.18126666666666666, wnorm = 0.0011192395f0)
(iter = 4, accuracy = 0.12525, wnorm = 0.0019428008f0)
(iter = 8, accuracy = 0.24768333333333334, wnorm = 0.0032802117f0)
(iter = 16, accuracy = 0.2459, wnorm = 0.005103771f0)
(iter = 32, accuracy = 0.22216666666666668, wnorm = 0.008813391f0)
(iter = 64, accuracy = 0.2804833333333333, wnorm = 0.013814504f0)
(iter = 128, accuracy = 0.30006666666666665, wnorm = 0.02399184f0)
(iter = 256, accuracy = 0.35428333333333334, wnorm = 0.040716104f0)
(iter = 512, accuracy = 0.6429166666666667, wnorm = 0.06564615f0)
(iter = 1024, accuracy = 0.6990833333333333, wnorm = 0.10388712f0)
(iter = 2048, accuracy = 0.7501, wnorm = 0.16638325f0)
(iter = 4096, accuracy = 0.7838, wnorm = 0.25608873f0)
(iter = 8192, accuracy = 0.8011, wnorm = 0.35674772f0)
(iter = 16384, accuracy = 0.8284333333333334, wnorm = 0.46174675f0)
(iter = 32768, accuracy = 0.8363666666666667, wnorm = 0.5626146f0)
(

In [35]:
"""
    negloglik(w, x, y)

Return the negative log-likelihood of the true class `argmax(y)` under a
softmax over the scores `w * x`, i.e.
`log(sum(exp.(scores))) - scores[class]`.

It is computed with the log-sum-exp trick (scores shifted by their maximum).
That keeps the result finite for large scores, where the naive
`-log(exp(s) / sum(exp(s)))` overflows to `Inf`/`NaN`.
"""
function negloglik(w, x, y)
    scores = w * x
    # Any constant shift cancels mathematically; the max keeps exp ≤ 1.
    shift = maximum(scores)
    logz = shift + log(sum(exp.(scores .- shift)))
    class = argmax(y)
    return logz - scores[class]
end

negloglik (generic function with 1 method)

In [36]:
# Logistic regression again, now by SGD on `negloglik` with the same lr=0.01
# as the hand-written `softmax` update. For a one-hot y its gradient is
# (probs - y) * x'.
@time wsoftmax2 = optimize(negloglik,xtrn,ytrn,lr=0.01);

(iter = 1, accuracy = 0.09736666666666667, wnorm = 0.07533733f0)
(iter = 2, accuracy = 0.1178, wnorm = 0.11761533f0)
(iter = 4, accuracy = 0.13635, wnorm = 0.16791211f0)
(iter = 8, accuracy = 0.2720666666666667, wnorm = 0.21142988f0)
(iter = 16, accuracy = 0.1513, wnorm = 0.34892815f0)
(iter = 32, accuracy = 0.4332, wnorm = 0.48950693f0)
(iter = 64, accuracy = 0.5525, wnorm = 0.76276207f0)
(iter = 128, accuracy = 0.6042, wnorm = 1.1828429f0)
(iter = 256, accuracy = 0.6960833333333334, wnorm = 1.8737861f0)
(iter = 512, accuracy = 0.7996666666666666, wnorm = 2.6965687f0)
(iter = 1024, accuracy = 0.8137666666666666, wnorm = 3.6402164f0)
(iter = 2048, accuracy = 0.8639, wnorm = 4.7378488f0)
(iter = 4096, accuracy = 0.8807666666666667, wnorm = 5.9154253f0)
(iter = 8192, accuracy = 0.8876166666666667, wnorm = 7.185288f0)
(iter = 16384, accuracy = 0.8885166666666666, wnorm = 8.734807f0)
(iter = 32768, accuracy = 0.9077833333333334, wnorm = 10.440622f0)
(iter = 65536, accuracy = 0.910783333333