# Introduction

This notebook is based on the [fast.ai v3 part 2 course "02_fully_connected.ipynb" notebook](https://github.com/fastai/course-v3/blob/master/nbs/dl2/02_fully_connected.ipynb) and the [Knet.jl documentation](https://denizyuret.github.io/Knet.jl/stable/).

The goal is to build all the components from scratch in order to understand how each one works.

# Imports

In [1]:
using Knet
using Statistics
using Test
using AutoGrad

# Data

## Data structures

First, we set up two simple structures for data handling:
* **dataset**: container for x and y data
* **dataloader**: holds training and validation data

In [2]:
struct dataset
    x
    y
end

In [3]:
### THIS NEEDS TO BE UPDATED TO SUPPORT ITERATION !!!
struct dataloader
    train
    valid
end

## MNIST data

Next, we load the MNIST data with a Knet library helper function and then build our own dataloader from it.

In [4]:
include(Knet.dir("data","mnist.jl"))

In [273]:
mnist_dl_orig = dataloader(mnistdata()...);

The Knet dataloader is already set up, as can be seen below, but we will disassemble it and build it up again.

In [6]:
summary.((mnist_dl_orig.train,mnist_dl_orig.valid))

("600-element Knet.Data{Tuple{KnetArray{Float32,4},Array{UInt8,1}}}", "100-element Knet.Data{Tuple{KnetArray{Float32,4},Array{UInt8,1}}}")

By iterating over the original Knet dataloader, we concatenate the x and y batches and store them in our train and valid dataset structures:

In [7]:
mnist_ds_train = dataset(cat((x for (x,y) in mnist_dl_orig.train)..., dims=4),
                         cat((y for (x,y) in mnist_dl_orig.train)..., dims=1));

mnist_ds_valid = dataset(cat((x for (x,y) in mnist_dl_orig.valid)..., dims=4),
                         cat((y for (x,y) in mnist_dl_orig.valid)..., dims=1));

In [8]:
summary.((mnist_ds_train.x, mnist_ds_train.y))

("28×28×1×60000 KnetArray{Float32,4}", "60000-element Array{UInt8,1}")

In [9]:
summary.((mnist_ds_valid.x, mnist_ds_valid.y))

("28×28×1×10000 KnetArray{Float32,4}", "10000-element Array{UInt8,1}")

Now, we can put the data into our custom dataloader:

In [10]:
mnist_dl = dataloader(mnist_ds_train, mnist_ds_valid);

Finally, we test our dataloader for completeness:

In [11]:
@testset "Check MNIST data" begin
    @testset "Training data" begin
        @test size(mnist_dl.train.x) == (28,28,1,60000)
        @test size(mnist_dl.train.y) == (60000,)
    end
    @testset "Validation data" begin
        @test size(mnist_dl.valid.x) == (28,28,1,10000)
        @test size(mnist_dl.valid.y) == (10000,)
    end
end;

Test Summary:    | Pass  Total
Check MNIST data |    4      4


## Data functions

### Mean & variance

Next, we create a simple function to get summary statistics (mean and standard deviation) on our data.

In [12]:
function get_stats(data)
    d_mean = mean(data)
    d_std = sqrt(var(data))
    return d_mean, d_std
end;

We save our train x statistics for later.

In [13]:
train_ds_mean, train_ds_std = get_stats(mnist_dl.train.x)

(0.13066033f0, 0.3081088f0)

### Normalize

Then, we define a normalization function for our data and test it.

In [14]:
function normalize(x, m, s)
    (x.-m)./s
end;

In [15]:
mnist_train_ds_norm = normalize(mnist_dl.train.x, train_ds_mean, train_ds_std);

In [16]:
get_stats(mnist_train_ds_norm)

(3.3241324f-7, 1.0000026f0)

In [17]:
@testset "Test dataset normalization" begin
    m, s = get_stats(mnist_train_ds_norm)
    @test m ≈ 0 atol=1e-3
    @test s ≈ 1 atol=1e-3
end;

Test Summary:              | Pass  Total
Test dataset normalization |    2      2


And we pack our tests into a function to check for zero mean and unit variance ("zmuv") to easily reuse it later for testing.

In [18]:
function test_zmuv(data, atol=1e-3)
    m, s = get_stats(data)
    @test m ≈ 0 atol=atol
    @test s ≈ 1 atol=atol
end;

In [19]:
test_zmuv(mnist_train_ds_norm)

Test Passed

# Model building blocks

## Knet helper functions

We borrow the helper functions for datatype handling and parameter creation from the [Knet library](https://github.com/denizyuret/Knet.jl/blob/master/src/train.jl#L100):

In [20]:
param(x::Union{Array,KnetArray}; atype=identity) = Param(atype(x))
param(d...; init=xavier_uniform, atype=atype()) = Param(atype(init(d...)))
param0(d...; atype=atype()) = param(d...; init=zeros, atype=atype)
atype() = (gpu() >= 0 ? KnetArray{Float32} : Array{Float32});

In [21]:
@testset "Test Knet helper functions" begin
    dims = rand(1:100,2)
    @test typeof(atype()(undef, dims...)) == (gpu() >= 0 ? KnetArray{Float32,2} : Array{Float32,2})
    @test typeof(param(dims...)) == Param{gpu() >= 0 ? KnetArray{Float32,2} : Array{Float32,2}}
    @test typeof(param0(dims[1])) == Param{gpu() >= 0 ? KnetArray{Float32,1} : Array{Float32,1}}
    @test typeof(param0(dims[2])) == Param{gpu() >= 0 ? KnetArray{Float32,1} : Array{Float32,1}}
end;

Test Summary:              | Pass  Total
Test Knet helper functions |    4      4


## Flatten function

Because we look into a fully connected network first, we need a function that changes our input data from (28, 28, 1, x) to (784, x).

In [22]:
flatten(x) = (reshape(x, (784,:)));

In [23]:
@testset "Test flatten function" begin
    for x in rand(1:10,3) # test 3x
        @test size(flatten(randn(28,28,1,x))) == (784,x) 
    end
end;

Test Summary:         | Pass  Total
Test flatten function |    3      3


## Linear

We create our custom linear layer structure as shown in the [linear model Knet tutorial](https://github.com/denizyuret/Knet.jl/blob/master/tutorial/30.lin.ipynb):

In [24]:
struct Linear_c; w; b; end

and its constructor, which optionally accepts a custom initialization function:

In [25]:
#Linear_c(i::Int,o::Int) = Linear_c(param((o,i)), param0(o));

In [26]:
Linear_c(i::Int,o::Int;init=xavier_uniform) = Linear_c(param(o,i;init=init), param0(o));

and then we define the forward function of our linear layer:

In [27]:
(m::Linear_c)(x) = m.w * x .+ m.b

Then, we can check the methods of `Linear_c` and which one gets called with two integers.

In [28]:
methods(Linear_c)

In [29]:
@which Linear_c(16,32)

In [30]:
@which Linear_c(16,32;init=xavier_normal) # why does the output look like this?

In [31]:
@testset "Test linear layer" begin
    lin = Linear_c(4,2) # create simple linear layer
    lin.w.value = hcat(atype()(1.:1.:4.),atype()(5.:1.:8.))' # can this be done easier?
    lin.b.value = atype()(fill(1.,2))
    x = atype()(1.:1.:4.)
    @test lin(x) == [31,71]
end;

Test Summary:     | Pass  Total
Test linear layer |    1      1
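
The arithmetic behind this test can be checked by hand on plain arrays. The following is a minimal sketch that does not use `Param` or `KnetArray`; the variable names are chosen for illustration only, and the weight matrix is written as a literal:

```julia
# Plain-array version of the forward pass w * x .+ b from the test above
w = Float32[1 2 3 4; 5 6 7 8]   # 2×4 weight matrix (2 outputs, 4 inputs)
b = Float32[1, 1]               # one bias per output
x = Float32[1, 2, 3, 4]         # single input vector

out = w * x .+ b                # row 1: 1+4+9+16+1, row 2: 5+12+21+32+1
println(out)                    # [31, 71]
```

Writing the matrix as a literal is one possible alternative to the `hcat(...)'` construction in the test.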


## ReLU

The ReLU function can be implemented simply by:

In [32]:
relu_c(x) = max.(x,0);

In [33]:
@testset "Test ReLU function" begin
    @test relu_c(atype()([-1.,2.,-3.])) == [0,2,0]
end;

Test Summary:      | Pass  Total
Test ReLU function |    1      1


## Model setup

To set up our model, we define a chain-of-layers structure as shown in the [Knet CNN tutorial](https://github.com/denizyuret/Knet.jl/blob/master/tutorial/50.cnn.ipynb):

In [34]:
struct Chain
    layers
    Chain(layers...) = new(layers)
end

And define that it goes through each layer when it gets an input:

In [35]:
(c::Chain)(x) = (for l in c.layers; x = l(x); end; x)
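
To illustrate that the layers are applied in the order in which they are passed, here is a small self-contained sketch. The two "layers" `double` and `add_one` are made up for this example and are not part of the model:

```julia
# Same definitions as above, repeated so the snippet runs on its own
struct Chain
    layers
    Chain(layers...) = new(layers)
end
(c::Chain)(x) = (for l in c.layers; x = l(x); end; x)

# Hypothetical stand-in layers for illustration
double(x) = 2 .* x
add_one(x) = x .+ 1

println(Chain(double, add_one)([1, 2]))   # double first, then add one: [3, 5]
println(Chain(add_one, double)([1, 2]))   # add one first, then double: [4, 6]
```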

We also implement a function to easily get a fresh model with a specific initialization:

In [156]:
function get_model(;init=xavier_normal)
    return model = Chain(flatten, Linear_c(784,50;init=init), relu_c, Linear_c(50,10;init=init), relu_c)
end;

In [37]:
@testset "Test correct model output shape" begin
    model_c = get_model()
    @test size(model_c(mnist_dl.train.x)) == (10,60000)
end;

Test Summary:                   | Pass  Total
Test correct model output shape |    1      1


## Model initialization

This part is based on the [fastai v3 part 2 course notebook on initialization](https://github.com/fastai/course-v3/blob/master/nbs/dl2/02b_initializing.ipynb).

### Zero mean unit variance

If we use a normal distribution with zero mean and unit variance ("zmuv") for the weights, the activations overflow to infinity after a few matrix multiplications:

In [38]:
function check_mm_stats(w,x;n=100)
    for i in 1:n
        x = w * x # this resembles the matrix multiplication in our linear layer
        println(i, get_stats(x))
        if isnan(sum(x)) # a better way to check for NaN in Julia?
            break
        end
    end
end

check_mm_stats (generic function with 1 method)

In [39]:
check_mm_stats(randn(Float32, (512,512)), randn(Float32, (512)))

1(-0.8880974f0, 21.919664f0)
2(10.800011f0, 502.33023f0)
3(164.73192f0, 11663.487f0)
4(8193.326f0, 268020.53f0)
5(164598.72f0, 5.9489035f6)
6(-1.1952992f6, 1.2947917f8)
7(-1.6434246f7, 2.8907553f9)
8(6.949123f9, 6.7945808f10)
9(6.792001f10, 1.5690845f12)
10(5.606563f11, 3.3972894f13)
11(5.913527f13, 7.1514474f14)
12(-3.8653397f14, 1.5622399f16)
13(2.9918182f16, 3.65233f17)
14(1.931122f17, Inf32)
15(7.707326f18, Inf32)
16(-2.7277535f19, Inf32)
17(2.2705316f21, Inf32)
18(2.9920891f23, Inf32)
19(2.1653806f24, Inf32)
20(4.9999434f25, Inf32)
21(5.023884f26, Inf32)
22(-1.3973304f27, Inf32)
23(6.1725947f29, Inf32)
24(-4.100791f30, Inf32)
25(7.4782933f31, Inf32)
26(3.3923973f33, Inf32)
27(-6.8517203f34, Inf32)
28(NaN32, NaN32)


Note: We can also check for NaN ("not a number") by comparing a value with itself, since NaN is the only floating-point value that is not equal to itself:

In [None]:
sum(a) != sum(a)

### Xavier Glorot initialization

For a square weight matrix, as used here, the Xavier Glorot initialization scales the weights by `1/sqrt(input_dim)`. The following example shows the effect by rescaling the random weights from above accordingly:

In [41]:
check_mm_stats(randn(Float32, (512,512))/sqrt(512), randn(Float32, (512)))

1(-0.03128440527211984, 0.9785400099795301)
2(0.062482564351289384, 0.9549475212082945)
3(0.0481612188786319, 0.897197185466266)
4(-0.007637851584996487, 0.8614452382368822)
5(0.01940059081517396, 0.8781310017838251)
6(0.02907944766253427, 0.8601837807772011)
7(-0.040974408407286486, 0.907097594229379)
8(0.010406160384609326, 0.9092702193102259)
9(0.04064044307602728, 0.9253791546453949)
10(-0.012968853651001162, 0.9695259501085394)
11(0.008465413304944758, 1.0170732557773408)
12(0.01677841115613776, 1.0164794541532973)
13(0.0005815771015007525, 1.1024772594993877)
14(-0.029632750084465632, 1.133589763946688)
15(-0.003811471577445865, 1.1089892330208575)
16(-0.017838824738245485, 1.110239180125311)
17(0.011163569469248829, 1.1481655253111993)
18(0.02303792451940852, 1.184842686489866)
19(0.03407754855491532, 1.1810687393628223)
20(0.01451591977301915, 1.2007394076397606)
21(-0.008970411758448052, 1.2263296435290871)
22(-0.07677806514846948, 1.2510283767254162)
23(0.007374703220732181, 

This also works with the Xavier Glorot initialization functions that come with Knet:

In [42]:
check_mm_stats(xavier_normal(512,512), randn(Float32, (512)))

1(-0.025764756214730098, 1.010365641103227)
2(-0.016510672247334893, 1.042731301502906)
3(0.07053541933821068, 1.0187607752633279)
4(0.04026334702784537, 1.025618062803471)
5(0.035956734763293016, 1.063607671170404)
6(-0.10194041706101799, 1.00667690681816)
7(-0.04064300263478686, 0.9570560552958817)
8(-0.03268762703511828, 0.9851841596286874)
9(0.004810227598166334, 0.9530903967575174)
10(0.02014791169552281, 0.9365707797325832)
11(0.10223101924053843, 0.9336582276873219)
12(-0.0025316407980058298, 0.9038199316402679)
13(0.02479400765447818, 0.9087256432560729)
14(0.013034428801039251, 0.9122329308100718)
15(-0.014545509081038873, 0.9046908392713527)
16(0.031570572995641685, 0.8961197561106242)
17(-0.03964315245440956, 0.8995107916422198)
18(-0.015570348654177564, 0.8833648727429358)
19(0.00894669550701821, 0.8485627275007369)
20(-0.04180286678714634, 0.8093296806658057)
21(-0.013149888477411107, 0.7985905700114407)
22(0.02007755778078706, 0.8161308405103112)
23(0.0441146114730565, 0.

In [43]:
#check_mm_stats(xavier_uniform(512,512), randn(Float32, (512)))
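
The `1/sqrt(input_dim)` factor can be motivated by a short variance calculation. This is a sketch under simplifying assumptions: the entries of $w$ and $x$ are independent and zero-mean, $x$ has unit variance, and the symbols $n$ (input dimension) and $\sigma_w$ (standard deviation of the weights) are introduced here for illustration:

$$
y_i = \sum_{j=1}^{n} w_{ij}\, x_j
\quad\Longrightarrow\quad
\operatorname{Var}(y_i) = \sum_{j=1}^{n} \operatorname{Var}(w_{ij})\,\operatorname{Var}(x_j) = n\,\sigma_w^2
$$

With $\sigma_w = 1$ the standard deviation grows by roughly $\sqrt{n}$ per multiplication ($\sqrt{512} \approx 22.6$, in line with the first line of the unscaled output above), whereas $\sigma_w = 1/\sqrt{n}$ keeps it near 1.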

### Kaiming He initialization

If the ReLU activation is used in a network it is better to use the Kaiming He initialization which accounts for that.

The simple version implemented below is based on the [PyTorch implementation](https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.xavier_normal_) and on the implementation of the Xavier Glorot initialization from a [Knet tutorial]():
```
xavier(o,i) = (s = sqrt(2/(i+o)); 2s .* rand(o,i) .- s)
```
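(This version does not account for receptive field sizes, see [PyTorch implementation code](https://pytorch.org/docs/stable/_modules/torch/nn/init.html#xavier_normal_).)

Note that the scale used in `he_normal` below, `sqrt(2/(i+o))`, is the same one that the Xavier Glorot normal initialization uses. The initialization proposed by He et al. scales by `sqrt(2/fan_in)` instead, which would be `sqrt(2/i)` here. The difference matters once a ReLU follows each multiplication.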

In [44]:
he_normal(o::Int,i::Int) = (b = sqrt(2/(i+o)); randn(o,i)*b)

he_normal (generic function with 1 method)

In [45]:
check_mm_stats(he_normal(512,512), randn(Float32, (512)))

1(-0.0006492275970580106, 0.9983577168520512)
2(-0.06287906861338996, 1.0304645527581902)
3(0.02704631342697731, 1.034926733052963)
4(-0.07850524147808256, 1.0316022692658526)
5(0.03238308405560818, 1.0312478549789863)
6(-0.02594280233487835, 0.9828517398988978)
7(-0.05657203544481125, 0.9155525786378651)
8(0.06508963032276632, 0.940677072682043)
9(0.05575416547666757, 0.9691694048448378)
10(-0.03085558250163454, 0.9930964770119972)
11(-0.057211790192234206, 0.950263062597301)
12(-0.0023962580708100723, 0.8962698644531173)
13(0.061895259869711845, 0.8639441363972277)
14(0.03314479323749496, 0.8137855376392461)
15(-0.04404186704225108, 0.817127236608736)
16(-0.09783040161725179, 0.8128496295763988)
17(-0.018847339186553062, 0.7699102999619977)
18(0.017600189639281182, 0.7998207975395049)
19(0.0805317180472713, 0.7867505364884475)
20(0.008261247006206956, 0.7898148361088396)
21(0.012312285760713174, 0.7726014702792978)
22(-0.029635533840860626, 0.7915063637661032)
23(-0.02150964555015184

We can also test that with a ReLU after each matrix multiplication:

In [46]:
function check_mm_stats_relu(w,x;n=100)
    for i in 1:n
        x = w * x # this resembles the matrix multiplication in our linear layer
        x = relu_c(x)
        println(i, get_stats(x))
        if isnan(sum(x)) # a better way to check for NaN in Julia?
            break
        end
    end
end

check_mm_stats_relu (generic function with 1 method)

In [47]:
check_mm_stats_relu(he_normal(512,512), randn(Float32, (512)))

1(0.3416252017081806, 0.540076598310794)
2(0.2804155352066413, 0.3854349324191724)
3(0.18465434935092823, 0.2582271690528979)
4(0.13957001597581256, 0.20014044915148343)
5(0.09886264971727451, 0.1442685959099948)
6(0.0731251549671309, 0.10612856924695935)
7(0.05056022076114854, 0.07410406296732527)
8(0.03537099215925616, 0.05243103515924084)
9(0.02703086724581158, 0.03772932167412195)
10(0.019042557243630027, 0.02739192197016558)
11(0.013784664278644955, 0.018946384160686877)
12(0.009871147524037971, 0.013357376626618456)
13(0.006791658064556748, 0.009754233082952103)
14(0.004903933999820292, 0.006982795756707133)
15(0.0034400848890012493, 0.004695231844630468)
16(0.0022877896976704324, 0.003240014825219356)
17(0.0016013121013383578, 0.0023100646098956597)
18(0.0011342193281117332, 0.0015548544143519434)
19(0.0007855623868600321, 0.0010864821387561155)
20(0.0005291421908566735, 0.0007492127440402382)
21(0.00037799831265795876, 0.0005191928191632235)
22(0.0002501942358213617, 0.00035795
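
How quickly the activations shrink in such a ReLU chain depends on the weight scale. The following is a hedged sketch that compares a fan-in scale `sqrt(2/i)` with the `sqrt(2/(i+o))` scale. The names `he_normal_fanin`, `xavier_like` and `rms_after` are hypothetical and not part of Knet. Unlike `check_mm_stats_relu`, it draws a fresh weight matrix for every step and reports the root mean square of the activations:

```julia
using Random, Statistics

# Hypothetical helpers for illustration; not part of Knet.
he_normal_fanin(o::Int, i::Int) = randn(Float32, o, i) .* sqrt(2f0 / i)        # fan-in scale
xavier_like(o::Int, i::Int)     = randn(Float32, o, i) .* sqrt(2f0 / (i + o))  # fan-in + fan-out scale

relu_c(x) = max.(x, 0)

# Root mean square of the activations after n "matrix multiplication + ReLU"
# steps, with a new random weight matrix drawn for every step.
function rms_after(init; n=10, d=512)
    x = randn(Float32, d)
    for _ in 1:n
        x = relu_c(init(d, d) * x)
    end
    return sqrt(mean(abs2, x))
end

Random.seed!(0)
println("fan-in scale:           ", rms_after(he_normal_fanin))
println("fan-in + fan-out scale: ", rms_after(xavier_like))
```

Under these assumptions the fan-in scale should keep the root mean square close to its starting value, while the other scale roughly halves the mean square at every step.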

# Loss functions

## Loss helper functions

### Log-sum-exp (LSE) trick

We implement the log-sum-exp (LSE) trick in order to use it in a numerically stable NLL loss (see next section).

In [48]:
function lse(x; dims=1)
    x_max = maximum(x, dims=dims)
    x_max .+ log.(sum(exp.(x .- x_max), dims=dims))
end;

In [49]:
@testset "Check LSE trick" begin
    y = atype()(rand(10,60000))
    @test lse(y) ≈ log.(sum(exp.(y), dims=1))
end;

Test Summary:   | Pass  Total
Check LSE trick |    1      1
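
The test above uses inputs in [0, 1), where the naive formula works as well. To show why the trick is needed, here is a small sketch with large inputs on plain arrays. `naive_lse` is a hypothetical name used only for comparison:

```julia
# Same definition as above, repeated so the snippet runs on its own
function lse(x; dims=1)
    x_max = maximum(x, dims=dims)
    x_max .+ log.(sum(exp.(x .- x_max), dims=dims))
end

# Hypothetical naive variant, for comparison only
naive_lse(x; dims=1) = log.(sum(exp.(x), dims=dims))

x = Float32[1000, 1001, 1002]
println(naive_lse(x))   # exp overflows in Float32, so the result is Inf
println(lse(x))         # stays finite, about 1002.41
```

Subtracting the maximum first means the largest exponent is `exp(0) = 1`, so the sum cannot overflow.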


### Log-softmax

In the next step, we can incorporate the LSE trick in our log-softmax implementation, that we will then use in the NLL loss.

In [50]:
function logsoftmax_c(x; dims=1)
    x .- lse(x; dims=dims)
end;

For completeness, find below the naive implementation without the LSE trick:

In [51]:
function logsoftmax_c_nolse(x; dims=1)
    x .- log.(sum(exp.(x), dims=dims))
end;

In [52]:
@testset "Check log-softmax" begin
    y = atype()(rand(10,60000))
    @test logsoftmax_c(y) ≈ (y .- log.(sum(exp.(y), dims=1)))
    @test logsoftmax_c(y) ≈ logsoftmax_c_nolse(y)
end;

Test Summary:     | Pass  Total
Check log-softmax |    2      2


## Negative log-likelihood (NLL)

With the building blocks from above we can put together the NLL loss:

In [53]:
function nll_c(y, preds)
    y_int = Int.(y) # because data comes in UInt8
    scores = logsoftmax_c(preds)
    mean(-scores[y_int[i],i] for i in 1:length(y_int))
end;

In order to test the NLL implementation, we compare the NLL loss of a random guess for our data to the NLL loss we would get by predicting the same probability for every class, i.e., 1/(number of classes). The two should be similar because our NLL loss implementation aggregates with a mean:

In [54]:
-log(1/10.)

2.3025850929940455

In [55]:
@testset "Check NLL loss" begin
    @test nll_c(mnist_dl.train.y, atype()(rand(10,60000))) ≈ -log(1/10.) atol=0.05
end;

Test Summary:  | Pass  Total
Check NLL loss |    1      1
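
The reference value can also be reproduced exactly on a toy input: if every class gets the same score, the log-softmax is `-log(number of classes)` for every entry. This is a minimal sketch on plain arrays with made-up labels; the helper definitions are compact copies of the ones above so that the snippet runs on its own:

```julia
using Statistics

lse(x; dims=1) = (m = maximum(x, dims=dims); m .+ log.(sum(exp.(x .- m), dims=dims)))
logsoftmax_c(x; dims=1) = x .- lse(x; dims=dims)

function nll_c(y, preds)
    scores = logsoftmax_c(preds)
    mean(-scores[Int(y[i]), i] for i in 1:length(y))
end

y = UInt8[1, 5, 10, 3]                  # made-up labels in 1:10
uniform_scores = zeros(Float32, 10, 4)  # identical score for every class
loss = nll_c(y, uniform_scores)
println(loss ≈ log(10))                 # true
```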


# Metrics

## Accuracy

Next, we implement an accuracy metric to further monitor our training progress:

In [56]:
function acc_c(y, preds)
    y_int = Int.(y) # because data comes in UInt8
    y_int = reshape(y_int, (1,length(y_int))) # this seems to be necessary?
    mean(y_int .== map(i->i[1], argmax(preds, dims=1)))
end;

We can test our accuracy metric by comparing again to a random guess and by specific test data.

In [59]:
@testset "Check accuracy metric" begin
    @test acc_c(mnist_dl.train.y, atype()(randn(10,60000))) ≈ 1/10. atol=0.01
    @test acc_c([1, 2, 1],[1 0 1;0 1 0]) == 1
    @test acc_c([1 2 1],[1 2 1;0 1 2]) == 1/3
    @test acc_c([1 2 1],[1 0 1;0 1 2]) == 2/3
    @test acc_c([1 2 1],[1 0 1;0 1 0]) == 1
end;

Test Summary:         | Pass  Total
Check accuracy metric |    5      5
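
The individual steps of the metric can be followed on a small example. `argmax(preds, dims=1)` returns one `CartesianIndex` per column, and its first component is the predicted class (row). This is a sketch with made-up scores for 3 classes and 4 samples; the intermediate names are chosen for illustration:

```julia
using Statistics

preds = [0.1 0.7 0.2 0.3;
         0.8 0.2 0.3 0.3;
         0.1 0.1 0.5 0.4]      # rows: classes, columns: samples
y = UInt8[2, 1, 3, 1]          # made-up true labels

idx = argmax(preds, dims=1)              # 1×4 matrix of CartesianIndex
pred_class = map(i -> i[1], idx)         # keep only the row index: [2 1 3 3]
y_row = reshape(Int.(y), 1, length(y))   # same 1×4 shape as pred_class
acc = mean(y_row .== pred_class)
println(acc)                             # 3 of 4 correct: 0.75
```

The reshape makes the labels a 1×n row so that the elementwise comparison lines up with the 1×n result of `argmax`.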


# Training loop

## Minibatch dataloader

Here we use the [minibatch function from Knet](https://denizyuret.github.io/Knet.jl/latest/softmax/#MNIST-example-1) to set up our dataloader with a specific batch size (bs):

In [60]:
x_train,y_train,x_valid,y_valid = mnist()
bs = 64
#mnist_dl_orig_train = minibatch(x_train, y_train, bs, shuffle=true, xtype=atype());
#mnist_dl_orig_valid = minibatch(x_valid, y_valid, bs, shuffle=true, xtype=atype());
mnist_dl_orig = dataloader(minibatch(x_train, y_train, bs, shuffle=true,  xtype=atype()),
                           minibatch(x_valid, y_valid, bs, shuffle=false, xtype=atype()));

In [61]:
summary.((mnist_dl_orig.train,mnist_dl_orig.valid))

("937-element Knet.Data{Tuple{KnetArray{Float32,4},Array{UInt8,1}}}", "156-element Knet.Data{Tuple{KnetArray{Float32,4},Array{UInt8,1}}}")

***Code for minibatch dataloader from scratch comes here!***
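
A minimal sketch of what such a from-scratch minibatch iterator could look like, using Julia's `Base.iterate` protocol on plain arrays. The `MiniBatches` type and its fields are hypothetical and are not Knet's implementation. It assumes that samples lie along the last dimension of `x` and that a trailing partial batch is dropped:

```julia
using Random

# Hypothetical from-scratch minibatch iterator; not Knet's implementation.
struct MiniBatches
    x               # inputs, samples along the last dimension
    y               # labels, one per sample
    bs::Int         # batch size
    shuffle::Bool   # draw a new sample order on every pass
end

# Only full batches are produced; a trailing partial batch is dropped.
Base.length(d::MiniBatches) = div(length(d.y), d.bs)

function Base.iterate(d::MiniBatches, state=nothing)
    if state === nothing
        n = length(d.y)
        order = d.shuffle ? randperm(n) : collect(1:n)
        pos = 1
    else
        order, pos = state
    end
    pos + d.bs - 1 > length(order) && return nothing
    idx = order[pos:pos + d.bs - 1]
    colons = ntuple(_ -> Colon(), ndims(d.x) - 1)   # keep all leading dimensions
    return (d.x[colons..., idx], d.y[idx]), (order, pos + d.bs)
end

# Usage with made-up data: 10 samples of shape 2×1, batch size 4
x = reshape(collect(Float32, 1:20), 2, 1, 10)
y = collect(UInt8, 1:10)
for (xb, yb) in MiniBatches(x, y, 4, false)
    println(size(xb), " ", Int.(yb))
end
```

With 60000 samples and a batch size of 64 this sketch would yield `div(60000, 64) = 937` batches, the same count as the Knet summary above.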

## Optimization loop

### Simple stochastic gradient descent (SGD)

First, we implement a basic SGD loop for the training data. In order to update our weights we define a learning rate:

In [126]:
lr = 0.5;

In [136]:
function basic_sgd_loop(epochs=1;model=model_c, data=mnist_dl_orig, lr=lr)
    for e in 1:epochs
        for (n, (x, y)) in enumerate(data.train)
            J = @diff nll_c(y, model(x))
            n % 100 == 0 && println("epoch: ",e,", batch: ",n,", loss: ",value(J))
            for p in params(J)
                p .-= lr .* grad(J, p)
            end
        end
    end
end

basic_sgd_loop (generic function with 2 methods)

With our `get_model` function we get a new model which we can use for subsequent training:

In [159]:
model_c = get_model(;init=he_normal)

Chain((flatten, Linear_c(P(KnetArray{Float32,2}(50,784)), P(KnetArray{Float32,1}(50))), relu_c, Linear_c(P(KnetArray{Float32,2}(10,50)), P(KnetArray{Float32,1}(10))), relu_c))

In [160]:
basic_sgd_loop()

epoch: 1, batch: 100, loss: 0.7908371
epoch: 1, batch: 200, loss: 0.5546519
epoch: 1, batch: 300, loss: 0.08012933
epoch: 1, batch: 400, loss: 0.27340904
epoch: 1, batch: 500, loss: 0.22885036
epoch: 1, batch: 600, loss: 0.11693973
epoch: 1, batch: 700, loss: 0.08952023
epoch: 1, batch: 800, loss: 0.22877154
epoch: 1, batch: 900, loss: 0.26514804


Our results look reasonable as the loss is decreasing during training!
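
To make the update rule `p .-= lr .* grad(J, p)` concrete without AutoGrad, here is a sketch on a toy least-squares problem where the gradients can be written by hand. The data, the helper `loss_and_grads` and the function `toy_gd` are made up for illustration, and the whole data set is used as a single batch:

```julia
using Statistics

# Toy data (made up for illustration): y = 3x + 0.5 without noise
xs = collect(-1.0:0.1:1.0)
ys = 3 .* xs .+ 0.5

# Hypothetical helper: mean squared error and its hand-derived gradients
function loss_and_grads(w, b)
    err = w .* xs .+ b .- ys
    return mean(abs2, err), 2 * mean(err .* xs), 2 * mean(err)
end

function toy_gd(; lr=0.5, steps=100)
    w, b = 0.0, 0.0
    for _ in 1:steps
        _, gw, gb = loss_and_grads(w, b)
        w -= lr * gw   # same rule as p .-= lr .* grad(J, p)
        b -= lr * gb
    end
    return w, b
end

w, b = toy_gd()
println((w, b))   # close to the true parameters (3, 0.5)
```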

### SGD with validation

Next, we add our validation data and calculate its loss and accuracy as well, to better monitor the training progress:

In [234]:
function sgd_loop(epochs=1;model=model_c, data=mnist_dl_orig, lr=lr)
    for e in 1:epochs
        
        # train loop
        n_samples_train = 0
        n_loss_train = 0
        for (n, (x, y)) in enumerate(data.train)
            J = @diff nll_c(y, model(x))
            n_samples_train += length(y)
            n_loss_train += value(J) * length(y)
            for p in params(J)
                p .-= lr .* grad(J, p)
            end
        end

        # valid loop
        n_samples_valid = 0
        n_loss_valid = 0
        n_acc_valid = 0
        for (n, (x, y)) in enumerate(data.valid)
            n_samples_valid += length(y)
            n_loss_valid += nll_c(y, model(x)) * length(y)
            n_acc_valid += acc_c(y, model(x)) * length(y)
        end
        println("epoch: ",e,
                ", train loss: ",round(n_loss_train/n_samples_train, digits=3),
                ", valid loss: ",round(n_loss_valid/n_samples_valid, digits=3),
                ", valid acc: ",round(n_acc_valid/n_samples_valid, digits=3))
    end
end;

In [161]:
model_c = get_model(;init=he_normal)

Chain((flatten, Linear_c(P(KnetArray{Float32,2}(50,784)), P(KnetArray{Float32,1}(50))), relu_c, Linear_c(P(KnetArray{Float32,2}(10,50)), P(KnetArray{Float32,1}(10))), relu_c))

In [162]:
sgd_loop(5)

epoch: 1, train loss: 0.311, valid loss: 0.15, valid acc: 0.952
epoch: 2, train loss: 0.136, valid loss: 0.142, valid acc: 0.957
epoch: 3, train loss: 0.107, valid loss: 0.119, valid acc: 0.964
epoch: 4, train loss: 0.088, valid loss: 0.106, valid acc: 0.971
epoch: 5, train loss: 0.075, valid loss: 0.107, valid acc: 0.969


Our training and validation losses are going down and the accuracy is going up, so the training looks good!

### Other optimizers

In [229]:
function add_opt(model, opt)
    for p in params(model)
        p.opt = opt
    end
    return model
end;

In [249]:
function get_model_opt(;init=xavier_normal, opt=SGD(lr=0.5))
    model = Chain(flatten, Linear_c(784,50;init=init), relu_c, Linear_c(50,10;init=init), relu_c)
    add_opt(model, opt) # add optimizer to model layers
    return model
end;

In [303]:
function fit(epochs=1;model=model_c_opt, data=mnist_dl_orig, lr=lr)
    for e in 1:epochs
        
        # train loop
        n_samples_train = 0
        n_loss_train = 0
        for (n, (x, y)) in enumerate(data.train)
            J = @diff nll_c(y, model(x))
            n_samples_train += length(y)
            n_loss_train += value(J) * length(y)
            #update!(model, J)
            for p in params(model)
                p .-= lr .* grad(J, p)
            end
        end

        # valid loop
        n_samples_valid = 0
        n_loss_valid = 0
        n_acc_valid = 0
        for (n, (x, y)) in enumerate(data.valid)
            n_samples_valid += length(y)
            n_loss_valid += nll_c(y, model(x)) * length(y)
            n_acc_valid += acc_c(y, model(x)) * length(y)
        end
        println("epoch: ",e,
                ", train loss: ",round(n_loss_train/n_samples_train, digits=3),
                ", valid loss: ",round(n_loss_valid/n_samples_valid, digits=3),
                ", valid acc: ",round(n_acc_valid/n_samples_valid, digits=3))
    end
end;

First, we check our new model and optimizer setup with our new fit training loop:

In [304]:
model_c_opt = get_model_opt()

Chain((flatten, Linear_c(P(KnetArray{Float32,2}(50,784)), P(KnetArray{Float32,1}(50))), relu_c, Linear_c(P(KnetArray{Float32,2}(10,50)), P(KnetArray{Float32,1}(10))), relu_c))

In [305]:
@test typeof(model_c_opt.layers[2].w.opt) == typeof(SGD())

Test Passed

In [306]:
fit(5)

epoch: 1, train loss: 0.495, valid loss: 0.199, valid acc: 0.937
epoch: 2, train loss: 0.157, valid loss: 0.146, valid acc: 0.955
epoch: 3, train loss: 0.117, valid loss: 0.124, valid acc: 0.96
epoch: 4, train loss: 0.095, valid loss: 0.122, valid acc: 0.962
epoch: 5, train loss: 0.082, valid loss: 0.124, valid acc: 0.962


Now we can try the same but with the Adam optimizer:

In [307]:
model_c_opt = get_model_opt(;opt=Adam())

Chain((flatten, Linear_c(P(KnetArray{Float32,2}(50,784)), P(KnetArray{Float32,1}(50))), relu_c, Linear_c(P(KnetArray{Float32,2}(10,50)), P(KnetArray{Float32,1}(10))), relu_c))

In [308]:
@test typeof(model_c_opt.layers[2].w.opt) == typeof(Adam())

Test Passed

In [309]:
fit(5)

epoch: 1, train loss: 0.41, valid loss: 0.199, valid acc: 0.937
epoch: 2, train loss: 0.159, valid loss: 0.147, valid acc: 0.955
epoch: 3, train loss: 0.122, valid loss: 0.131, valid acc: 0.961
epoch: 4, train loss: 0.101, valid loss: 0.122, valid acc: 0.964
epoch: 5, train loss: 0.087, valid loss: 0.116, valid acc: 0.967


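This also looks fine! Note, however, that `fit` as written still applies the manual update `p .-= lr .* grad(J, p)` and never reads `p.opt`, so this run is effectively plain SGD with `lr = 0.5` again. That explains why the numbers are so close to the previous run.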

#### With data normalization

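With data normalization applied to the training batches, the validation metrics get worse. Note that the validation loop below still passes the unnormalized `x` to the model, so training and validation inputs are on different scales. This mismatch, rather than normalization itself, likely explains the poor validation results: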

In [310]:
function fit(epochs=1;model=model_c_opt, data=mnist_dl_orig, lr=lr)
    for e in 1:epochs
        
        # train loop
        n_samples_train = 0
        n_loss_train = 0
        for (n, (x, y)) in enumerate(data.train)
            J = @diff nll_c(y, model(normalize(x, train_ds_mean, train_ds_std)))
            n_samples_train += length(y)
            n_loss_train += value(J) * length(y)
            #update!(model, J)
            for p in params(model)
                p .-= lr .* grad(J, p)
            end
        end

        # valid loop
        n_samples_valid = 0
        n_loss_valid = 0
        n_acc_valid = 0
        for (n, (x, y)) in enumerate(data.valid)
            n_samples_valid += length(y)
            n_loss_valid += nll_c(y, model(x)) * length(y)
            n_acc_valid += acc_c(y, model(x)) * length(y)
        end
        println("epoch: ",e,
                ", train loss: ",round(n_loss_train/n_samples_train, digits=3),
                ", valid loss: ",round(n_loss_valid/n_samples_valid, digits=3),
                ", valid acc: ",round(n_acc_valid/n_samples_valid, digits=3))
    end
end;

In [311]:
model_c_opt = get_model_opt(;opt=Adam())

Chain((flatten, Linear_c(P(KnetArray{Float32,2}(50,784)), P(KnetArray{Float32,1}(50))), relu_c, Linear_c(P(KnetArray{Float32,2}(10,50)), P(KnetArray{Float32,1}(10))), relu_c))

In [312]:
fit(5)

epoch: 1, train loss: 0.841, valid loss: 1.0, valid acc: 0.768
epoch: 2, train loss: 0.369, valid loss: 0.733, valid acc: 0.869
epoch: 3, train loss: 0.245, valid loss: 1.306, valid acc: 0.48
epoch: 4, train loss: 0.202, valid loss: 1.365, valid acc: 0.489
epoch: 5, train loss: 0.175, valid loss: 1.243, valid acc: 0.532
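
Whatever transform the model sees during training generally has to be applied at evaluation time as well. Here is a minimal sketch with made-up arrays that reuses the training statistics for both splits. The names `train_x`, `valid_x`, `train_n` and `valid_n` are hypothetical, and `get_stats` and `normalize` are compact copies of the functions defined earlier:

```julia
using Random, Statistics

get_stats(data) = (mean(data), sqrt(var(data)))
normalize(x, m, s) = (x .- m) ./ s

Random.seed!(0)
train_x = rand(Float32, 28, 28, 1, 100)   # stand-in for training images in [0, 1)
valid_x = rand(Float32, 28, 28, 1, 20)    # stand-in for validation images

m, s = get_stats(train_x)                 # statistics come from the training split only
train_n = normalize(train_x, m, s)
valid_n = normalize(valid_x, m, s)        # same m and s, so both splits share one scale

println(get_stats(train_n))
println(get_stats(valid_n))
```

In the `fit` loop this would correspond to calling `normalize(x, train_ds_mean, train_ds_std)` in the validation loop too.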
