# Simple LeNet for MNIST classification

In [1]:
using Knet
using NNHelferlein
using MLDatasets: MNIST

┌ Info: Precompiling Knet [1902f260-5fb4-5aff-8c31-6271790ab950]
└ @ Base loading.jl:1278
┌ Info: Precompiling NNHelferlein [b9e938e5-d80d-48a2-bb0e-6649b4a98aeb]
└ @ Base loading.jl:1278
┌ Info: Precompiling MLDatasets [eb30cadb-4394-5ae3-aed4-317e484a6458]
└ @ Base loading.jl:1278


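### Get MNIST data:
Digit labels 0–9 are remapped so that digit 0 becomes class 10. This gives the 1-based class indices 1–10 used by the 10-output classifier below.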

In [2]:
# Load MNIST as Float32 images: train split for training, test split for validation.
xtrn, ytrn = MNIST.traindata(Float32)
xvld, yvld = MNIST.testdata(Float32)

# Relabel digit 0 as class 10 (in place) so labels are the indices 1..10
# matching the 10 outputs of the prediction layer.
replace!(ytrn, 0 => 10)
replace!(yvld, 0 => 10)

# Cut both sets into minibatches of 100 samples, reshaping each image to
# 28×28×1 (one channel) for the first Conv layer.
dtrn = minibatch(xtrn, ytrn, 100; xsize = (28, 28, 1, :))
dvld = minibatch(xvld, yvld, 100; xsize = (28, 28, 1, :));

### Define LeNet with NNHelferlein types:
(Knet style)

In [3]:
# LeNet-style CNN for 28×28×1 inputs. Convs use no padding (see the printed
# model: padding (0, 0)); Pool() is assumed to halve each spatial dimension:
#   Conv 5×5, 1→20   : 28×28×1  -> 24×24×20
#   Pool             :          -> 12×12×20
#   Conv 5×5, 20→50  :          ->  8×8×50
#   Pool             :          ->  4×4×50
#   Flat             :          ->  4*4*50 = 800 features
#   Dense 800→512, then Predictions 512→10 (one output per class).
lenet = Classifier(
    Conv(5, 5, 1, 20),
    Pool(),
    Conv(5, 5, 20, 50),
    Pool(),
    Flat(),
    Dense(800, 512),
    Predictions(512, 10)
)

Classifier((Conv(P(Array{Float32,4}(5,5,1,20)), P(Array{Float32,4}(1,1,20,1)), (0, 0), Knet.Ops20.relu), Pool(Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}()), Conv(P(Array{Float32,4}(5,5,20,50)), P(Array{Float32,4}(1,1,50,1)), (0, 0), Knet.Ops20.relu), Pool(Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}()), Flat(), Dense(P(Array{Float32,2}(512,800)), P(Array{Float32,1}(512)), Knet.Ops20.sigm), Dense(P(Array{Float32,2}(10,512)), P(Array{Float32,1}(10)), identity)))

### Train with TensorBoard log:

In [5]:
# Train `lenet` with Adam for one epoch, logging to TensorBoard under
# logs/example_run/<timestamp>. With eval_freq = 2 the model is evaluated twice
# per epoch, each time on eval_size = 25% of the validation minibatches.
# mb_loss_freq presumably controls how often (in minibatches) the training
# loss is logged; accuracy is computed with Knet's `accuracy`.
tb_train!(lenet, Adam, dtrn, dvld;
          epochs = 1,
          acc_fun = accuracy,
          eval_size = 0.25,
          eval_freq = 2,
          mb_loss_freq = 100,
          tb_name = "example_run",
          tb_text = "NNHelferlein example")

Training 1 epochs with 600 minibatches/epoch
    (and 100 validation mbs).
Evaluation is performed every 300 minibatches (with 25 mbs).
Watch the progress with TensorBoard at: logs/example_run/2021-01-20T18-08-11


Progress: 100%|█████████████████████████████████████████| Time: 0:02:11


Training finished with:
Training loss:       0.0684726345182086
Training accuracy:   0.9784833333333395
Validation loss:     0.0632524535007542
Validation accuracy: 0.9794999999999997


Classifier((Conv(P(Array{Float32,4}(5,5,1,20)), P(Array{Float32,4}(1,1,20,1)), (0, 0), Knet.Ops20.relu), Pool(Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}()), Conv(P(Array{Float32,4}(5,5,20,50)), P(Array{Float32,4}(1,1,50,1)), (0, 0), Knet.Ops20.relu), Pool(Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}()), Flat(), Dense(P(Array{Float32,2}(512,800)), P(Array{Float32,1}(512)), Knet.Ops20.sigm), Dense(P(Array{Float32,2}(10,512)), P(Array{Float32,1}(10)), identity)))