In [28]:
using Pkg  # src
Pkg.activate("../FluxTraining/docs")  # src

Activating environment at `~/.julia/dev/FluxTraining/docs/Project.toml`


Let's put *FluxTraining.jl* to work and train a model on the MNIST dataset.

MNIST is simple enough that we can focus on the part where *FluxTraining.jl* comes in: the training.

## Setup

*If you want to run this tutorial yourself, you can find the notebook file [here](https://github.com/lorenzoh/FluxTraining.jl/blob/master/docs/tutorials/mnist.ipynb).*

To make data loading and batching a bit easier, we'll install some additional dependencies:

In [11]:
using Pkg
Pkg.add(url="https://github.com/lorenzoh/DataLoaders.jl")
Pkg.add("MLDataPattern")

Updating git-repo `https://github.com/lorenzoh/DataLoaders.jl`
Resolving package versions...
No Changes to `~/.julia/dev/FluxTraining/docs/Project.toml`
No Changes to `~/.julia/dev/FluxTraining/docs/Manifest.toml`
Resolving package versions...
No Changes to `~/.julia/dev/FluxTraining/docs/Project.toml`
No Changes to `~/.julia/dev/FluxTraining/docs/Manifest.toml`


Now we can import everything we'll need.

In [6]:
using DataLoaders: DataLoader
using MLDataPattern: splitobs
using Flux
using FluxTraining

## Overview

There are 4 pieces that you always need to construct and train a [`Learner`](#):

- a model;
- data;
- an optimizer; and
- a loss function

## Building a `Learner`

Let's look at the **data** first.

*FluxTraining.jl* is agnostic of the data source. The only requirements are:

- it is iterable and each iteration returns a tuple `(xs, ys)`;
- the model can take in `xs`, i.e. `model(xs)` works; and
- the loss function can take model outputs and `ys`, i.e. `lossfn(model(xs), ys)` returns a scalar
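
To make these three requirements concrete, here is a minimal sketch in plain Julia. The helper `check_data_contract` is hypothetical (it is not part of *FluxTraining.jl*), and the toy model and loss are stand-ins chosen only so the example runs without any packages.

```julia
# Hypothetical helper (not part of FluxTraining.jl): checks the three
# requirements against the first item of a data iterator.
function check_data_contract(dataiter, model, lossfn)
    batch = first(dataiter)                 # requirement 1: iterable ...
    batch isa Tuple && length(batch) == 2 || return false  # ... yielding (xs, ys)
    xs, ys = batch
    ŷs = model(xs)                          # requirement 2: model(xs) works
    loss = lossfn(ŷs, ys)                   # requirement 3: loss on outputs and ys
    return loss isa Real                    # ... and the result is a scalar
end

# Toy stand-ins: a linear "model" and mean squared error.
W = [1.0 0.0; 0.0 1.0]
toymodel(xs) = W * xs
toyloss(ŷs, ys) = sum(abs2, ŷs .- ys) / length(ys)

# A plain Vector of (xs, ys) tuples already qualifies as a data iterator.
toydata = [(rand(2, 4), rand(2, 4)) for _ in 1:3]

check_data_contract(toydata, toymodel, toyloss)  # true
```

Anything that passes a check like this, whether a vector of tuples or a batching iterator, should be usable as a data source.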


Data loading isn't the focus of this tutorial, so we'll gloss over the details. Here's the code for getting a data iterator of the MNIST dataset. We use `DataLoaders.DataLoader` to create an iterator of batches from our dataset.

In [15]:
xs, ys = (
    # convert each image into a 28×28×1 array of floats
    [Float32.(reshape(img, 28, 28, 1)) for img in Flux.Data.MNIST.images()],
    # one-hot encode the labels
    [Flux.onehot(y, 0:9) for y in Flux.Data.MNIST.labels()],
)

# split into training and validation sets (`splitobs` defaults to a 70/30 split)
traindata, valdata = splitobs((xs, ys))

# create iterators
trainiter, valiter = DataLoader(traindata, 128), DataLoader(valdata, 256);
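
As a rough mental model of what the split and batching steps do, here is a conceptual stand-in that uses only Base Julia and operates on observation indices. It assumes a 70/30 split without shuffling and is a sketch of the idea, not how `splitobs` or `DataLoader` are implemented.

```julia
# Conceptual stand-in for splitting and batching, on indices only.
nobs = 10
indices = collect(1:nobs)

# Split: the first 70% of observations for training, the rest for validation.
ntrain = round(Int, 0.7 * nobs)
trainidx, validx = indices[1:ntrain], indices[ntrain+1:end]

# Batch: group the training indices into chunks of at most 3 observations.
batchsize = 3
batches = collect(Iterators.partition(trainidx, batchsize))

println(length(trainidx), " training / ", length(validx), " validation observations")
println(length(batches), " training batches")
```

Note that the last chunk here is smaller than `batchsize`; real data loaders have to decide whether to keep or drop such a partial batch.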

Next, let's create a simple *Flux.jl* **model** that we'll train to classify the MNIST digits.

In [16]:
model = Chain(
    Conv((3, 3), 1 => 16, relu, pad = 1, stride = 2),
    Conv((3, 3), 16 => 32, relu, pad = 1),
    GlobalMeanPool(),
    Flux.flatten,
    Dense(32, 10),
)

Chain(Conv((3, 3), 1=>16, relu), Conv((3, 3), 16=>32, relu), GlobalMeanPool(), flatten, Dense(32, 10))
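
To see how a 28×28×1 image becomes 10 class scores, it can help to trace the array sizes by hand. The sketch below uses the standard convolution output-size formula; `convout` is a hypothetical helper written for this illustration, not a *Flux.jl* function.

```julia
# Output length of a convolution along one spatial dimension.
convout(n, k; pad = 0, stride = 1) = (n + 2pad - k) ÷ stride + 1

h1 = convout(28, 3; pad = 1, stride = 2)  # first Conv layer:  28 -> 14
h2 = convout(h1, 3; pad = 1)              # second Conv layer: 14 -> 14

# Per-sample sizes after each layer (the batch dimension is omitted):
println("Conv 1 => 16:     ", (h1, h1, 16))
println("Conv 16 => 32:    ", (h2, h2, 32))
println("GlobalMeanPool:   ", (1, 1, 32))   # averages over both spatial dims
println("flatten:          ", (32,))
println("Dense(32, 10):    ", (10,))
```

The global mean pool is what makes `Dense(32, 10)` fit: it reduces each of the 32 feature maps to a single number regardless of spatial size.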

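We'll use *categorical cross entropy* as a **loss function** and *ADAM* as an **optimizer**. Because the model ends in a plain `Dense` layer without a `softmax`, we use `logitcrossentropy`, which expects raw scores (logits).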

In [20]:
lossfn = Flux.Losses.logitcrossentropy
optim = Flux.ADAM();
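
For intuition about the loss, here is a plain-Julia sketch of cross entropy computed from logits for a single sample. The names `logsumexp`, `logsoftmax`, and `toy_logitcrossentropy` are defined here for illustration only; the *Flux.jl* function also handles whole batches, which this sketch does not.

```julia
# Numerically stable log(sum(exp.(z))).
function logsumexp(z)
    m = maximum(z)
    return m + log(sum(exp.(z .- m)))
end

logsoftmax(z) = z .- logsumexp(z)

# Cross entropy between one-hot target `y` and raw scores `z`, single sample.
toy_logitcrossentropy(z, y) = -sum(y .* logsoftmax(z))

y = [i == 4 for i in 1:10]          # one-hot vector for the digit 3 (labels 0:9)
z = zeros(10)                       # a model that scores all digits equally

toy_logitcrossentropy(z, y)         # equals log(10) ≈ 2.3026
```

A classifier that cannot tell the ten digits apart therefore sits at a loss of about 2.3, which is a useful baseline when reading training logs.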

Now we're ready to create a [`Learner`](#). At this point you can also add any callbacks, like [`ToGPU`](#) to run the training on your GPU if you have one available. Some callbacks are also [included by default](../callbacks/reference.md).

Since we're classifying digits, we also use the [`Metrics`](#) callback to track the accuracy of the model's predictions:

In [25]:
learner = Learner(model, (trainiter, valiter), optim, lossfn, ToGPU(), Metrics(accuracy))

Learner()
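
The accuracy being tracked is the fraction of samples whose highest-scoring class matches the label. A minimal plain-Julia sketch of that computation follows; `batchaccuracy` is a hypothetical name used for illustration and is not the library's `accuracy` implementation.

```julia
# Fraction of columns (samples) where the predicted class equals the true class.
function batchaccuracy(ŷs::AbstractMatrix, ys::AbstractMatrix)
    preds = [argmax(col) for col in eachcol(ŷs)]  # highest score per sample
    truth = [argmax(col) for col in eachcol(ys)]  # position of the 1 in each one-hot
    return sum(preds .== truth) / length(truth)
end

# Three samples, three classes; one column per sample.
ŷs = [2.0 0.1 0.3;
      0.5 1.5 0.2;
      0.1 0.2 0.1]
ys = Bool[1 0 0;
          0 1 0;
          0 0 1]

batchaccuracy(ŷs, ys)  # 2 of 3 predictions are correct
```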

## Training

With a `Learner` in place, training is as simple as calling [`fit!`](#)`(learner, nepochs)`.

In [26]:
FluxTraining.fit!(learner, 10)

Epoch 1 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:01

Loss: 2.0401972503285277
Accuracy: 0.2504511778115501

Epoch 2 ValidationPhase(): 100%|████████████████████████| Time: 0:00:00

Loss: 1.7366820892817538
Accuracy: 0.37671654929577464

Epoch 2 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:00

Loss: 1.6143117624937944
Accuracy: 0.4416555851063829

Epoch 3 ValidationPhase(): 100%|████████████████████████| Time: 0:00:00

Loss: 1.4844125492472044
Accuracy: 0.4939480633802816

Epoch 3 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:00

Loss: 1.4098337729288815
Accuracy: 0.5474924012158052

Epoch 4 ValidationPhase(): 100%|████████████████████████| Time: 0:00:00

Loss: 1.3011583778220164
Accuracy: 0.5921544894366199

Epoch 4 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:00

Loss: 1.2435689039505724
Accuracy: 0.6170687689969607

Epoch 5 ValidationPhase(): 100%|████████████████████████| Time: 0:00:00

Loss: 1.1359686717181139
Accuracy: 0.6670664612676056

Epoch 5 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:00

Loss: 1.109516095970177
Accuracy: 0.6673632218844984

Epoch 6 ValidationPhase(): 100%|████████████████████████| Time: 0:00:00

Loss: 1.0406720579510005
Accuracy: 0.6853983274647888

Epoch 6 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:00

Loss: 1.0109151308297386
Accuracy: 0.6997292933130699

Epoch 7 ValidationPhase(): 100%|████████████████████████| Time: 0:00:00

Loss: 0.9342230894196202
Accuracy: 0.7232394366197183

Epoch 7 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:00

Loss: 0.9322713883452141
Accuracy: 0.719153685410334

Epoch 8 ValidationPhase(): 100%|████████████████████████| Time: 0:00:00

Loss: 0.869062495063728
Accuracy: 0.7425286091549296

Epoch 8 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:00

Loss: 0.8683096063535627
Accuracy: 0.739575417933131

Epoch 9 ValidationPhase(): 100%|████████████████████████| Time: 0:00:00

Loss: 0.812866168122896
Accuracy: 0.7532790492957749

Epoch 9 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:00

Loss: 0.8111663188977807
Accuracy: 0.7559365501519754

Epoch 10 ValidationPhase(): 100%|███████████████████████| Time: 0:00:00

Loss: 0.7621234353159515
Accuracy: 0.7681007922535211

Epoch 10 TrainingPhase(): 100%|█████████████████████████| Time: 0:00:00

Loss: 0.7691689948905203
Accuracy: 0.7685457826747714

Epoch 11 ValidationPhase(): 100%|███████████████████████| Time: 0:00:00

Loss: 0.724167847717312
Accuracy: 0.7846170774647885


Learner()
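
The log above alternates between a training phase and a validation phase for each epoch. As a rough picture of that rhythm, here is a hand-written loop on a toy one-weight problem in plain Julia. It is only a conceptual sketch under simplifying assumptions (a hand-derived gradient, plain gradient descent instead of ADAM, no callbacks), and `toyfit` is a hypothetical name, not how `fit!` is implemented.

```julia
# Toy problem: learn y = 2x with a single weight `w` and mean squared error.
trainbatches = [([1.0, 2.0], [2.0, 4.0]), ([3.0, 4.0], [6.0, 8.0])]
valbatches   = [([5.0], [10.0])]

predict(w, xs) = w .* xs
loss(w, xs, ys) = sum(abs2, predict(w, xs) .- ys) / length(xs)
# Hand-derived gradient of `loss` with respect to `w`.
grad(w, xs, ys) = 2 * sum((predict(w, xs) .- ys) .* xs) / length(xs)

function toyfit(w, trainbatches, valbatches, nepochs; lr = 0.01)
    for epoch in 1:nepochs
        # Training phase: update the weight after every batch.
        for (xs, ys) in trainbatches
            w -= lr * grad(w, xs, ys)
        end
        # Validation phase: only measure the loss, no updates.
        valloss = sum(loss(w, xs, ys) for (xs, ys) in valbatches) / length(valbatches)
        if epoch % 10 == 0
            println("Epoch $epoch, validation loss: $valloss")
        end
    end
    return w
end

w = toyfit(0.0, trainbatches, valbatches, 50)
```

The weight ends up very close to 2, and the validation loss shrinks towards zero, mirroring the falling loss values in the real training log.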