# Flux JL

## Flux Model Zoo Examples

# 3. LeNet CNN

**FluxML contributors**

**Source:** https://github.com/FluxML/model-zoo/blob/master/vision/conv_mnist/conv_mnist.jl

In this notebook we implement the LeNet5 convolutional neural network in Julia using Flux.

In [20]:
using Flux
using Flux.Data: DataLoader
using Flux.Optimise: Optimiser, WeightDecay
using Flux: onehotbatch, onecold
using Flux.Losses: logitcrossentropy
using Statistics, Random
using MLDatasets: MNIST
using ProgressMeter: @showprogress
using Logging: with_logger
using TensorBoardLogger: TBLogger, tb_overwrite, set_step!, set_step_increment!
using CUDA
import BSON


┌ Info: Precompiling ProgressMeter [92933f4c-e287-5a05-a399-4b506db050ca]
└ @ Base loading.jl:1278


#### Utility Functions

In [17]:
# Count every scalar entry held in the arrays returned by `Flux.params(model)`.
function num_params(model)
    return sum(length, Flux.params(model))
end

# Round `x` to four digits after the decimal point.
function round4(x)
    return round(x; digits=4)
end
;

### LeNet5

Let's start with the model:

In [10]:
# Build the LeNet5 network as a Flux `Chain`.
#
# Keyword arguments:
#   imgdims  - input dimensions; the first two entries size the dense head and the
#              last entry is used as the number of input channels (default `(28,28,1)`)
#   nclasses - output width of the final `Dense` layer (default `10`)
function LeNet5(; imgdims=(28,28,1), nclasses=10)
    # Size left along each of the first two dims after both conv/pool stages;
    # the formula is based on stride 1 and padding 0.
    feat_dim1 = imgdims[1] ÷ 4 - 3
    feat_dim2 = imgdims[2] ÷ 4 - 3
    # The second convolution emits 16 channels, so `flatten` yields dim1 * dim2 * 16 features.
    flat_features = feat_dim1 * feat_dim2 * 16
    input_channels = imgdims[end]

    layers = (
        Conv((5,5), input_channels => 6, relu),   # the pair is n_in_channels => n_out_channels
        MaxPool((2, 2)),
        Conv((5,5), 6 => 16, relu),
        MaxPool((2, 2)),
        flatten,
        Dense(flat_features, 120, relu),
        Dense(120, 84, relu),
        Dense(84, nclasses),
    )
    return Chain(layers...)
end

LeNet5 (generic function with 1 method)

### DataLoader

We will use the MNIST dataset.

In [28]:
# Load MNIST and wrap the train and test splits in `DataLoader`s.
#
# `batchsize` is passed to both loaders; only the training loader shuffles.
# Returns the tuple `(train_loader, test_loader)`.
function get_data(batchsize)
    train_images, train_labels = MNIST.traindata(Float32)
    test_images, test_labels = MNIST.testdata(Float32)

    # Insert a singleton third dimension so each array is 28 x 28 x 1 x N.
    add_channel(images) = reshape(images, 28, 28, 1, :)
    # One-hot encode the labels against the classes 0:9.
    encode(labels) = onehotbatch(labels, 0:9)

    train_set = (add_channel(train_images), encode(train_labels))
    test_set = (add_channel(test_images), encode(test_labels))

    train_loader = DataLoader(train_set; batchsize=batchsize, shuffle=true)
    test_loader = DataLoader(test_set; batchsize=batchsize)

    return train_loader, test_loader
end

get_data (generic function with 1 method)

#### Loss Function

In [13]:
# Training objective: forwards the model output `ŷ` and the targets `y` to `logitcrossentropy`.
function loss(ŷ, y)
    return logitcrossentropy(ŷ, y)
end

loss (generic function with 1 method)

In [14]:
# Compute the average loss and the accuracy (in percent) of `model` over `dataloader`.
#
# Each batch is moved with `device` before the forward pass. Returns a named tuple
# `(loss=..., accuracy=...)` with both values rounded by `round4`.
function evaluate(dataloader, model, device)
    total_loss = 0f0
    n_correct = 0
    n_seen = 0

    for (x, y) in dataloader
        xb = device(x)
        yb = device(y)
        outputs = model(xb)
        batch_len = size(xb)[end]   # the last dimension indexes the samples of the batch

        # `loss` is weighted by the batch length so that dividing by `n_seen` averages per sample.
        total_loss += loss(outputs, yb) * batch_len
        # Predictions and targets are compared on the CPU.
        n_correct += sum(onecold(cpu(outputs)) .== onecold(cpu(yb)))
        n_seen += batch_len
    end

    mean_loss = round4(total_loss / n_seen)
    percent_correct = round4(100 * (n_correct / n_seen))

    return (loss=mean_loss, accuracy=percent_correct)
end

evaluate (generic function with 1 method)

#### Training Loop

In [29]:
# Train a `LeNet5` model on the MNIST loaders returned by `get_data`.
#
# Keyword arguments are forwarded to `Args(; kws...)`; the fields read here are `η`, `λ`,
# `batchsize`, `epochs`, `seed`, `use_cuda`, `infotime`, `checktime`, `tensorboard` and
# `savepath`.
#
# Metrics are printed (and logged to TensorBoard when `tensorboard` is set) once before
# training and then every `infotime` epochs. The model is saved to
# `joinpath(savepath, "model.bson")` every `checktime` epochs. A non-positive `infotime`
# or `checktime` disables the corresponding periodic step.
function train(; kws...)
    args = Args(; kws...)
    args.seed > 0 && Random.seed!(args.seed)
    # Fall back to the CPU when CUDA was requested but is not functional.
    use_cuda = args.use_cuda && CUDA.functional()

    if use_cuda
        device = gpu
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    ## Data
    train_loader, test_loader = get_data(args.batchsize)
    @info "Dataset MNIST: $(train_loader.nobs) train and $(test_loader.nobs) test examples"

    ## Model
    model = LeNet5() |> device
    @info "LeNet5 model: $(num_params(model)) trainable params"

    ## Optimiser
    θ = Flux.params(model)

    optimiser = ADAM(args.η)
    if args.λ > 0 # add weight decay, i.e. L2 regularization
        # NOTE(review): the rules of an `Optimiser` chain are presumably applied left to
        # right, so the decay term is added after the ADAM step rather than to the raw
        # gradient — confirm this ordering is the intended one.
        optimiser = Optimiser(optimiser, WeightDecay(args.λ))
    end

    ## TensorBoard Logger
    if args.tensorboard
        tblogger = TBLogger(args.savepath, tb_overwrite)
        set_step_increment!(tblogger, 0) # 0 auto increment since we manually set_step!
        @info "TensorBoard logging at \"$(args.savepath)\""
    end

    ## Epoch logging
    # Evaluate on both loaders, print the metrics and, if enabled, log them to TensorBoard.
    function report(epoch)
        # Named `*_metrics` so the locals do not shadow the enclosing `train` function.
        train_metrics = evaluate(train_loader, model, device)
        test_metrics = evaluate(test_loader, model, device)
        println("Epoch: $epoch   Train: $(train_metrics)   Test: $(test_metrics)")
        if args.tensorboard
            set_step!(tblogger, epoch)
            with_logger(tblogger) do
                @info "train" loss=train_metrics.loss  acc=train_metrics.accuracy
                @info "test"  loss=test_metrics.loss   acc=test_metrics.accuracy
            end
        end
    end

    ## Training Loop
    @info "Training started ..."
    report(0)
    for epoch in 1:args.epochs
        @showprogress for (x,y) in train_loader
            x, y = x |> device, y |> device
            ∂loss = Flux.gradient(θ) do
                        ŷ = model(x)
                        loss(ŷ, y)
                    end

            Flux.Optimise.update!(optimiser, θ, ∂loss)
        end

        ## Printing and logging
        # Guard against `infotime == 0`, which would otherwise raise a DivideError in `%`
        # (same convention as `checktime` below).
        args.infotime > 0 && epoch % args.infotime == 0 && report(epoch)
        if args.checktime > 0 && epoch % args.checktime == 0
            mkpath(args.savepath) # does nothing when the directory already exists
            modelpath = joinpath(args.savepath, "model.bson")
            let model = cpu(model) # return model to cpu before serialization
                BSON.@save modelpath model epoch
            end
            @info "Model saved in \"$(modelpath)\""
        end
    end
end
            

train (generic function with 1 method)

## Train LeNet5

#### Programme Parameters

Arguments for the `train` function:

In [32]:
# Settings for a training run; every field has a default and can be overridden by keyword.
Base.@kwdef mutable struct Args
    # learning rate
    η = 3e-4
    # L2 regularizer param, implemented as weight decay
    λ = 0
    # batch size
    batchsize = 128
    # number of epochs
    epochs = 10
    # set seed > 0 for reproducibility
    seed = 0
    # if true use cuda (if available)
    use_cuda = false
    # report every `infotime` epochs
    infotime = 1
    # Save the model every `checktime` epochs. Set to 0 for no checkpoints.
    checktime = 5
    # log training with tensorboard
    tensorboard = false
    # results path
    savepath = "runs/lenet5"
end

Args

In [33]:
# Run training with the default `Args` settings (no keyword overrides are passed).
train()

┌ Info: Training on CPU
└ @ Main In[29]:12
┌ Info: Dataset MNIST: 60000 train and 10000 test examples
└ @ Main In[29]:17
┌ Info: LeNet5 model: 44426 trainable params
└ @ Main In[29]:21
┌ Info: TensorBoard logging at "runs/lenet5"
└ @ Main In[29]:35
┌ Info: Training started ...
└ @ Main In[29]:53


Epoch: 0   Train: (loss = 2.3028f0, accuracy = 11.4017)   Test: (loss = 2.3019f0, accuracy = 11.51)


Progress: 100%|█████████████████████████████████████████| Time: 0:00:32


Epoch: 1   Train: (loss = 0.1819f0, accuracy = 94.5933)   Test: (loss = 0.1659f0, accuracy = 95.13)


Progress:  56%|███████████████████████▏                 |  ETA: 0:00:14