In [1]:
using GZip

"""
    gzget(fname)

Return the bytes read through `gzopen` from the gzip file `fname`. If no file with
that name exists locally it is first downloaded from the MNIST directory on
yann.lecun.com and saved under the same name.
"""
function gzget(fname)
    if !isfile(fname)
        download("http://yann.lecun.com/exdb/mnist/$fname", fname)
    end
    # `gzopen(f, name)` is the function-first form of the do-block: it opens the
    # stream, applies `read`, and closes the stream again.
    return gzopen(read, fname)
end

"""
    mnist()

Load the four MNIST files and return `(xtrain, xtest, ytrain, ytest)` as raw byte
vectors with the fixed-size file headers removed: 16 bytes from each `idx3` image
file and 8 bytes from each `idx1` label file.
"""
function mnist()
    images(name) = gzget(name)[17:end]   # skip 16 header bytes
    labels(name) = gzget(name)[9:end]    # skip 8 header bytes
    # Tuple elements are evaluated left to right, so files are fetched in this order.
    return (images("train-images-idx3-ubyte.gz"),
            images("t10k-images-idx3-ubyte.gz"),
            labels("train-labels-idx1-ubyte.gz"),
            labels("t10k-labels-idx1-ubyte.gz"))
end

"""
    minibatch(x, y, batchsz)

Split raw MNIST bytes into a vector of `(xbatch, ybatch)` minibatches.

`x` is a flat vector of pixel bytes (784 per image) and `y` holds one label per
image. Each `xbatch` is a 784×batchsz matrix of pixels divided by 255. Each
`ybatch` is a 10×batchsz sparse one-hot matrix in which label `k` sets row `k`
and label 0 sets row 10. Trailing images that do not fill a whole batch are
dropped. `y` is never modified.

Throws `ArgumentError` if `batchsz` is not positive, and `DimensionMismatch` if the
number of images in `x` differs from `length(y)`.
"""
function minibatch(x, y, batchsz)
    xrows  = 784   # pixels per image (28*28)
    yrows  = 10    # number of label rows
    xscale = 255   # pixel bytes are scaled into [0, 1]
    batchsz > 0 || throw(ArgumentError("batchsz must be positive, got $batchsz"))
    xbatch(a) = reshape(a ./ xscale, xrows, length(a) ÷ xrows)
    function ybatch(a)
        # Relabel 0 -> 10 without mutating `a`. The old `a[a.==0]=10` mutated its
        # argument and is a scalar-to-many assignment that Julia 1.x rejects.
        rows = Int[v == 0 ? yrows : Int(v) for v in a]
        n = length(a)
        return sparse(rows, 1:n, fill(one(eltype(a)), n), yrows, n)
    end
    xcols = div(length(x), xrows)
    xcols == length(y) ||
        throw(DimensionMismatch("x holds $xcols images but y has $(length(y)) labels"))
    # Element type deliberately stays `Any`: callers may replace entries with
    # batches of a different array type (e.g. reshaped 4-D image arrays).
    data = Any[]
    for i in 1:batchsz:(xcols - batchsz + 1)
        j = i + batchsz - 1
        push!(data, (xbatch(x[(i - 1)*xrows + 1:j*xrows]), ybatch(y[i:j])))
    end
    return data
end

minibatch (generic function with 1 method)

In [2]:
"""
    weights(h...)

Initialize parameters for a fully connected network with 28*28 = 784 inputs,
hidden layer sizes `h` and 10 outputs. Returns a flat `Any` vector
`[W1, b1, W2, b2, ...]`. Each `W` is `fanout×fanin` with entries `0.1*randn`,
and each `b` is a `fanout×1` column of zeros.
"""
function weights(h...)
    params = Any[]
    fanin = 28*28
    for fanout in [h..., 10]
        push!(params, 0.1*randn(fanout, fanin), zeros(fanout, 1))
        fanin = fanout   # this layer's output size feeds the next layer
    end
    return params
end

"""
    predict(w, x)

Run the fully connected network `w = [W1, b1, W2, b2, ...]` on input columns `x`.
Each layer computes `W*x .+ b`. Every layer except the last one is followed by
`max(0, ·)`, and the last layer's raw scores are returned.
"""
function predict(w, x)
    iout = length(w) - 1   # index of the final layer's weight matrix
    for i in 1:2:length(w)
        x = w[i]*x .+ w[i + 1]
        i < iout && (x = max(0, x))
    end
    return x
end

"""
    loss(w, x, y)

Mean negative log-likelihood per column: the network scores are log-normalized
along dimension 1 with Knet's `logp`, multiplied elementwise by the (one-hot)
label matrix `y`, summed, negated and divided by the number of columns of `y`.
"""
function loss(w, x, y)
    logprobs = logp(predict(w, x), 1)
    return -sum(y .* logprobs) / size(y, 2)
end

"""
    accuracy(w, samples)

Fraction of instances in `samples` (an iterable of `(x, y)` batches, with `y`
one-hot per column) whose labelled row attains the column maximum of the
network's scores. A tie counts as correct when the label is among the maxima.
"""
function accuracy(w, samples)
    ncorrect = 0
    ntotal = 0
    for (x, y) in samples
        scores = predict(w, x)
        ismax = scores .== maximum(scores, 1)
        ncorrect += sum(y .* ismax)
        ntotal += size(y, 2)
    end
    return ncorrect / ntotal
end

accuracy (generic function with 1 method)

In [3]:
using Knet

# Gradient function of `loss` built with Knet's `grad`. `train` calls it as
# `∇loss(w, x, y)` and indexes the result like `w` to update each parameter.
∇loss = grad(loss)

"""
    train(w, samples; μ=.1)

One pass of plain SGD over `samples`. For every `(x, y)` batch, each parameter
`w[k]` is decreased by `μ` times the matching entry of `∇loss(w, x, y)`.
`w` is updated in place and also returned.
"""
function train(w, samples; μ=.1)
    for (xb, yb) in samples
        g = ∇loss(w, xb, yb)
        for k in eachindex(w)
            w[k] -= μ*g[k]
        end
    end
    return w
end

train (generic function with 1 method)

In [4]:
# MLP experiment: three hidden layers (128, 64, 32 units), minibatches of 32.
batchsz = 32

w = weights(128, 64, 32)

# Raw MNIST bytes (gzget downloads any missing file), cut into full-size minibatches.
xtrain, xtest, ytrain, ytest = mnist()

trainset = minibatch(xtrain, ytrain, batchsz)
testset  = minibatch(xtest, ytest, batchsz)

# 7 epochs of SGD; after each epoch print: epoch, train accuracy, test accuracy.
@time for epoch in 1:7
    train(w, trainset)
    @printf("%d\t%.5f\t%.5f\n", epoch, accuracy(w, trainset), accuracy(w, testset))
end

1	0.94527	0.94181
2	0.96315	0.96004
3	0.97247	0.96394
4	0.97480	0.96484
5	0.98297	0.97065
6	0.98602	0.97015
7	0.98445	0.97005
 26.498682 seconds (15.34 M allocations: 40.165 GB, 8.20% gc time)


In [5]:
"""
    minibatch4(x, y, batchsz)

Build minibatches with `minibatch`, then reshape every 784×n image matrix into a
28×28×1×n array for the convolutional `predict`. The label batches are left
unchanged. The vector returned by `minibatch` is modified in place and returned.
"""
function minibatch4(x, y, batchsz)
    data = minibatch(x, y, batchsz)
    for i in eachindex(data)
        # Distinct names so the function's own arguments `x` and `y` are not rebound.
        xb, yb = data[i]
        # Use the batch's actual column count instead of trusting `batchsz`.
        data[i] = (reshape(xb, 28, 28, 1, size(xb, 2)), yb)
    end
    return data
end

minibatch4 (generic function with 1 method)

In [6]:
"""
    weights()

Initialize parameters for the convolutional network as the flat vector
`[W1, b1, W2, b2, W3, b3, W4, b4]`:

- two banks of 5×5 filters, 1→20 and 20→50 channels, each with a 1×1×C×1 bias;
- an 800→500 dense layer and a 500→10 output layer, each with a column bias.

Weights are `0.1*randn`; biases are zeros.
"""
function weights()
    ϵ = 0.1
    conv1 = ϵ*randn(5, 5, 1, 20)
    bconv1 = zeros(1, 1, 20, 1)
    conv2 = ϵ*randn(5, 5, 20, 50)
    bconv2 = zeros(1, 1, 50, 1)
    dense = ϵ*randn(500, 800)
    bdense = zeros(500, 1)
    out = ϵ*randn(10, 500)
    bout = zeros(10, 1)
    return [conv1, bconv1, conv2, bconv2, dense, bdense, out, bout]
end

"""
    predict(w, x)

Convolutional forward pass. The leading `(filter, bias)` pairs of `w` are applied
as `conv4` (no padding) + bias + `max(·, 0)` + `pool`. The result is flattened
with `mat`. It then goes through the remaining dense pairs with `max(·, 0)`,
except for the final pair, which returns raw scores. The last four entries of
`w` are treated as the dense (weight, bias) pairs.
"""
function predict(w, x)
    nconv = length(w) - 4
    for k in 1:2:nconv
        x = pool(max(conv4(w[k], x; padding=0) .+ w[k + 1], 0))
    end
    h = mat(x)
    for k in nconv + 1:2:length(w) - 2
        h = max(w[k]*h .+ w[k + 1], 0)
    end
    return w[end - 1]*h .+ w[end]
end

"""
    loss(w, x, y)

Average negative log-likelihood over the columns of `y`. The predictions are
log-normalized along dimension 1 with Knet's `logp` and paired elementwise with
the one-hot labels.
"""
function loss(w, x, y)
    scores = predict(w, x)
    nll = -sum(y .* logp(scores, 1))
    return nll / size(y, 2)
end

"""
    accuracy(w, samples)

Share of instances across all `(x, y)` batches in `samples` whose one-hot label
row holds the column maximum of `predict(w, x)`. Ties count as correct when the
label is among the tied maxima.
"""
function accuracy(w, samples)
    hits, total = 0, 0
    for (xb, yb) in samples
        out = predict(w, xb)
        hits += sum(yb .* (out .== maximum(out, 1)))
        total += size(yb, 2)
    end
    return hits / total
end

using Knet

# Gradient function of `loss` built with Knet's `grad`, rebound in this cell for
# the CNN. `train` calls `∇loss(w, x, y)` and indexes the result like `w`.
∇loss = grad(loss)

"""
    train(w, data; μ=.1)

Single SGD epoch over the `(x, y)` batches in `data`. After computing
`∇loss(w, x, y)` for a batch, each parameter `w[k]` is replaced by
`w[k] - μ*gradient[k]`. Returns `w`, which is modified in place.
"""
function train(w, data; μ=.1)
    for (xb, yb) in data
        gradient = ∇loss(w, xb, yb)
        for k in eachindex(w)
            w[k] = w[k] - μ*gradient[k]
        end
    end
    return w
end



train (generic function with 1 method)

In [11]:
# CNN experiment: same data pipeline, with image batches reshaped to 4-D by minibatch4.
batchsz = 32

w = weights()

xtrain, xtest, ytrain, ytest = mnist()

trainset = minibatch4(xtrain, ytrain, batchsz)
testset  = minibatch4(xtest, ytest, batchsz)

# "Epoch 0" line: train/test accuracy of the randomly initialized network.
@printf("%d\t%.5f\t%.5f\n", 0, accuracy(w, trainset), accuracy(w, testset))

# Experimental on CPU, run at your own risk.
# Uncomment to train for 8 epochs, printing epoch, train and test accuracy each time.
#@time for epoch in 1:8
#    train(w, trainset)
#    @printf("%d\t%.5f\t%.5f\n", epoch, accuracy(w, trainset), accuracy(w, testset))
#end

0	0.07558	0.07943
