# Imports

In [34]:
using MLDatasets: MNIST
using Knet, IterTools, MLDatasets
using Base.Iterators: take, drop, cycle, Stateful
using Printf
using Knet:minibatch
using Knet:minimize
using Knet
using Knet: Param
using Knet: Knet, dir, accuracy, progress, sgd, gc, Data, nll, relu
using Flatten
using Flux.Data;
using Flux, Statistics
import Flatten: flattenable

# Processing Data

In [35]:
# This loads the MNIST handwritten digit recognition dataset. This code is based off the Knet Tutorial Notebook. 
xtrn,ytrn = MNIST.traindata(Float32)
xtst,ytst = MNIST.testdata(Float32)
println.(summary.((xtrn,ytrn,xtst,ytst)));

28×28×60000 Array{Float32, 3}
60000-element Vector{Int64}
28×28×10000 Array{Float32, 3}
10000-element Vector{Int64}


In [36]:
xtrn = reshape(xtrn, 784, 60000 ) 
xtst = reshape(xtst, 784, 10000 )
println(summary.((xtrn, xtst))) # can see the data that is flattened 

("784×60000 Matrix{Float32}", "784×10000 Matrix{Float32}")


In [37]:
#Preprocessing targets: one hot vectors
# ytrn = onehotbatch(ytrn, 0:9)
# ytst = onehotbatch(ytst, 0:9)

# Batch Processing

In [38]:
train_loader = DataLoader((xtrn, ytrn), batchsize=128);

In [39]:
train_loader

DataLoader{Tuple{Matrix{Float32}, Vector{Int64}}, Random._GLOBAL_RNG, Val{nothing}}((Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [5, 0, 4, 1, 9, 2, 1, 3, 1, 4  …  9, 2, 9, 5, 1, 8, 3, 5, 6, 8]), 128, false, true, false, false, Val{nothing}(), Random._GLOBAL_RNG())

In [40]:
(x,y) = first(train_loader) #gives the first minibatch from training dataset
println.(summary.((x,y)));

784×128 Matrix{Float32}
128-element Vector{Int64}


In [41]:
train_loader

DataLoader{Tuple{Matrix{Float32}, Vector{Int64}}, Random._GLOBAL_RNG, Val{nothing}}((Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], [5, 0, 4, 1, 9, 2, 1, 3, 1, 4  …  9, 2, 9, 5, 1, 8, 3, 5, 6, 8]), 128, false, true, false, false, Val{nothing}(), Random._GLOBAL_RNG())

# Define Dense Layer

In [42]:
struct Dense1; w; b; f; end
Dense1(i,o; f=relu) = Dense1(param(o,i), param0(o), f)
(d::Dense1)(x) = d.f.(d.w * mat(x) .+ d.b)

# Define Chain Layer


In [43]:
# Define a chain of layers and a loss function:
struct Chain; layers; end
(c::Chain)(x) = (for l in c.layers; x = l(x); end; x)
(c::Chain)(x,y) = nll(c(x),y)

# Define the Model

In [44]:
model = Chain((Dense1(784, 100), Dense1(100, 10)))

Chain((Dense1(P(Matrix{Float32}(100,784)), P(Vector{Float32}(100)), Knet.Ops20.relu), Dense1(P(Matrix{Float32}(10,100)), P(Vector{Float32}(10)), Knet.Ops20.relu)))

# Training

In [45]:
model(x) #checking if training is working

10×128 Matrix{Float32}:
 0.214272  0.269956   0.0         …  0.0        0.0        0.0
 0.0       0.0        0.0            0.0        0.0        0.0688322
 0.447498  0.223454   0.391645       0.0        0.0        0.0
 0.0       0.0        0.209459       0.0        0.0        0.0
 0.11944   0.0        0.0            0.124931   0.0        0.0
 0.480624  0.0618137  0.0567281   …  0.337361   0.321017   0.279554
 0.567259  0.579033   0.0            0.0177205  0.0287222  0.0821771
 0.519518  0.900291   0.177696       0.304197   0.39036    0.436304
 0.0       0.0        0.00714952     0.0        0.0        0.0
 0.0       0.0        0.0            0.0        0.0        0.0

In [47]:
loss(xtst, ytst) = nll(model(xtst), ytst)
evalcb = () -> (loss(xtst, ytst)) #function that will be called to get the loss 

    for epoch in 1:10
        @time begin
        progress!(adam(model, train_loader; lr = 1e-3))
        @printf("epoch is %d, loss is %f, accuracy is %f", epoch, (evalcb()), accuracy(model, train_loader))
        end 
    end 


println("Overall Loss: ", evalcb()) 
println("Overall Accuracy: ", accuracy(model, train_loader))


┣████████████████████┫ [100.00%, 469/469, 00:04/00:04, 122.72i/s] 
┣                    ┫ [0.21%, 1/469, 00:00/00:03, 162.35i/s] 

epoch is 1, loss is 0.223415, accuracy is 0.935518  4.726154 seconds (327.16 k allocations: 1.256 GiB, 15.36% gc time, 2.45% compilation time)


┣████████████████████┫ [100.00%, 469/469, 00:03/00:03, 185.82i/s] 
┣                    ┫ [0.21%, 1/469, 00:00/00:02, 205.09i/s] 

epoch is 2, loss is 0.161276, accuracy is 0.956081  3.293366 seconds (311.97 k allocations: 1.254 GiB, 4.57% gc time)


┣████████████████████┫ [100.00%, 469/469, 00:03/00:03, 161.56i/s] 
┣                    ┫ [0.21%, 1/469, 00:00/00:03, 169.43i/s] 

epoch is 3, loss is 0.129508, accuracy is 0.967232  3.659111 seconds (312.32 k allocations: 1.254 GiB, 4.26% gc time)


┣████████████████████┫ [100.00%, 469/469, 00:03/00:03, 151.47i/s] 
┣                    ┫ [0.21%, 1/469, 00:00/00:03, 162.74i/s] 

epoch is 4, loss is 0.111773, accuracy is 0.973371  3.976988 seconds (312.50 k allocations: 1.254 GiB, 4.70% gc time)


┣████████████████████┫ [100.00%, 469/469, 00:03/00:03, 163.64i/s] 
┣                    ┫ [0.21%, 1/469, 00:00/00:02, 247.97i/s] 

epoch is 5, loss is 0.100974, accuracy is 0.978142  3.635032 seconds (312.30 k allocations: 1.254 GiB, 5.48% gc time)


┣████████████████████┫ [100.00%, 469/469, 00:03/00:03, 156.73i/s] 
┣                    ┫ [0.21%, 1/469, 00:00/00:03, 172.15i/s] 

epoch is 6, loss is 0.093855, accuracy is 0.981822  3.744225 seconds (312.30 k allocations: 1.254 GiB, 4.23% gc time)


┣████████████████████┫ [100.00%, 469/469, 00:03/00:03, 162.35i/s] 
┣                    ┫ [0.21%, 1/469, 00:00/00:03, 164.20i/s] 

epoch is 7, loss is 0.089960, accuracy is 0.984115  3.644892 seconds (312.31 k allocations: 1.254 GiB, 3.83% gc time)


┣████████████████████┫ [100.00%, 469/469, 00:03/00:03, 170.55i/s] 
┣                    ┫ [0.21%, 1/469, 00:00/00:02, 196.92i/s] 

epoch is 8, loss is 0.087794, accuracy is 0.986242  3.537420 seconds (312.31 k allocations: 1.254 GiB, 5.22% gc time)


┣████████████████████┫ [100.00%, 469/469, 00:04/00:04, 109.85i/s] 
┣                    ┫ [0.21%, 1/469, 00:00/00:03, 179.85i/s] 

epoch is 9, loss is 0.085906, accuracy is 0.987814  5.220644 seconds (312.67 k allocations: 1.254 GiB, 4.46% gc time)


┣████████████████████┫ [100.00%, 469/469, 00:03/00:03, 157.98i/s] 


epoch is 10, loss is 0.086418, accuracy is 0.989034  3.682666 seconds (312.32 k allocations: 1.254 GiB, 3.70% gc time)
Overall Loss: 0.086418
Overall Accuracy: 0.989034155001202
