In [2]:
using Knet
using JLD
using Statistics

In [68]:
# Convolutional layer without bias: applies `f` elementwise to conv4(w, x; stride=s, mode=0).
# The spatial size of the output depends on the kernel size and the stride, so any Dense
# layer placed after this one must be sized to the resulting flattened feature count.
struct Conv
    w  # kernel Param of size (w1, w2, cx, cy): kernel width/height, input/output channels
    s  # stride forwarded to conv4
    f  # activation applied elementwise to the conv4 output
end

function (layer::Conv)(x)
    y = conv4(layer.w, x; stride = layer.s, mode = 0)
    return layer.f.(y)
end

function Conv(w1::Int, w2::Int, cx::Int, cy::Int, s::Int, f = relu)
    kernel = param(w1, w2, cx, cy)
    return Conv(kernel, s, f)
end

Conv

In [4]:
# Fully connected layer: computes f.(W * mat(x) .+ b), where `mat` flattens the input
# to a matrix before the multiply.
struct Dense
    w  # weight Param of size (o, i)
    b  # bias Param of length o
    f  # activation applied elementwise
end

function (layer::Dense)(x)
    # kept as one fused broadcast so the bias add and activation happen in a single pass
    return layer.f.(layer.w * mat(x) .+ layer.b)
end

function Dense(i::Int, o::Int, f = relu)
    weights = param(o, i)
    bias = param0(o)
    return Dense(weights, bias, f)
end;

In [5]:
# Sequential container: `model(x)` feeds x through each layer in order;
# `model(x, y)` returns the negative log-likelihood of the output against labels y.
struct Chain
    layers  # tuple of callables, applied first to last
    Chain(args...) = new(args)
end

(model::Chain)(x) = foldl((h, layer) -> layer(h), model.layers; init = x)

(model::Chain)(x, y) = nll(model(x), y)

In [18]:
# Training helpers: initopt! attaches an optimizer to every model parameter; result trains a model and reports its test accuracy
"""
    initopt!(model, optimizer = () -> Adam(lr = 0.0001, gclip = 0.0))

Attach a fresh optimizer to every parameter of `model` by setting `par.opt`.

`optimizer` is a zero-argument function that is called once per parameter, so each
parameter gets its own independent optimizer state. For backward compatibility, a
string such as `"Adam(lr=0.0001, gclip = 0.0)"` is also accepted. It is parsed once
and evaluated once per parameter.

Returns `nothing`.
"""
function initopt!(model, optimizer = () -> Adam(lr = 0.0001, gclip = 0.0))
    for par in params(model)
        par.opt = optimizer()
    end
    return nothing
end

function initopt!(model, optimizer::AbstractString)
    # NOTE(review): eval executes arbitrary code; only pass trusted optimizer strings.
    # Parse once (instead of once per parameter); evaluate per parameter so that each
    # one still gets its own optimizer instance.
    expr = Meta.parse(optimizer)
    return initopt!(model, () -> eval(expr))
end

"""
    result(model, trn_data, tst_data, epochs; verbose = false)

Train `model` for `epochs` passes over `trn_data`, then print its accuracy on
`tst_data` and the mean training loss of the final epoch.

Before training, `initopt!(model)` gives every parameter a fresh optimizer. Calling
`result` again therefore restarts optimizer state but continues from the current weights.
`trn_data` must iterate `(x, y)` minibatches, and `model(x, y)` is used as the loss.
Set `verbose = true` to print the mean loss after every epoch.

Returns the vector of per-epoch mean training losses (empty when `epochs == 0`;
an epoch with no batches records `NaN`).
"""
function result(model, trn_data, tst_data, epochs; verbose::Bool = false)
    initopt!(model)
    epoch_losses = Float64[]
    for epoch in 1:epochs
        batch_losses = Float64[]
        for (x, y) in trn_data
            # @diff records the forward pass so grad can be taken w.r.t. each Param
            tape = @diff model(x, y)
            push!(batch_losses, value(tape))
            for par in params(model)
                update!(value(par), grad(tape, par), par.opt)
            end
        end
        # mean of an empty collection throws; record NaN for an empty dataset instead
        epoch_loss = isempty(batch_losses) ? NaN : mean(batch_losses)
        push!(epoch_losses, epoch_loss)
        verbose && println("Epoch Number : $epoch, mean loss = $epoch_loss")
    end
    println("Final Accuracy for test set = ", accuracy(model, tst_data))
    if isempty(epoch_losses)
        # epochs == 0: there is no last epoch to report (original indexing would throw)
        println("Final Loss : n/a (no epochs run)")
    else
        println("Final Loss :", epoch_losses[end])
    end
    return epoch_losses
end

result (generic function with 1 method)

In [6]:
# Load MNIST through the loader script that ships in Knet's data directory;
# the include is what brings mnist() and mnistdata() into scope here.
include(Knet.dir("data", "mnist.jl"))
# Raw train/test arrays (not used by the training cells below).
(xtrain, ytrain, xtest, ytest) = mnist();
# Minibatch iterators: dtrn for training, dtst for evaluation.
(dtrn, dtst) = mnistdata();

┌ Info: Loading MNIST...
└ @ Main /home/ahnaf/.julia/packages/Knet/05UDD/data/mnist.jl:33


## Baseline Results

In [7]:
# Baseline MLP (not LeNet): two dense layers with identity activations. Composing two
# affine maps is itself affine, so this is effectively a linear classifier; the recorded
# run below reached ~91% test accuracy after 10 passes over dtrn in ~9 s.
model1 = Chain(Dense(784, 256, identity), Dense(256, 10, identity))
progress!(adam(model1, repeat(dtrn,10)))
println(accuracy(model1, dtst))   # fraction of test examples classified correctly
println(zeroone(model1, dtst))    # test error rate (1 - accuracy)
println(nll(model1, dtst))        # negative log-likelihood loss on the test set

2.56e-01  100.00%┣██████████████████████████▉┫ 6000/6000 [00:09/00:09, 696.91i/s]
0.9138
0.08620000000000005
0.30675378


In [72]:
# Conv + dense model: 10x10 kernels, 1 -> 8 channels, stride 5, identity activation.
# On 28x28 MNIST images this gives floor((28 - 10) / 5) + 1 = 4, i.e. 4x4x8 = 128
# features, which is why the dense layer takes 128 inputs.
conv_layer = Conv(10, 10, 1, 8, 5, identity)
dense_layer = Dense(128, 10, identity)
model2 = Chain(conv_layer, dense_layer)
progress!(adam(model2, repeat(dtrn, 10)))
println(accuracy(model2, dtst))
println(nll(model2, dtst))

2.89e-01  100.00%┣█████████████████████████▉┫ 6000/6000 [00:05/00:05, 1296.71i/s]
0.9177
0.2896525
