# Training an image classifier

Let's use *FluxTraining.jl* to train a model on the MNIST dataset.

MNIST is simple enough that we can focus on the part where *FluxTraining.jl* comes in: the training. For examples of using *FluxTraining.jl* on larger datasets, see the documentation of [FastAI.jl](https://github.com/FluxML/FastAI.jl).

## Setup

*If you want to run this tutorial yourself, you can find the notebook file [here](https://github.com/lorenzoh/FluxTraining.jl/blob/master/docs/tutorials/mnist.ipynb)*.

To make data loading and batching a bit easier, we'll install two additional dependencies: MLUtils.jl for batching and MLDatasets.jl for the MNIST dataset itself.

```julia
using Pkg; Pkg.add(["MLUtils", "MLDatasets"])
```

Now we can import everything we'll need.

```julia
using MLUtils: DataLoader, splitobs, unsqueeze
using MLDatasets: MNIST
using Flux
using Flux: onehotbatch
using FluxTraining
```

## Overview

There are 4 pieces that you always need to construct and train a [`Learner`](#):

- a model
- data
- an optimizer; and
- a loss function

## Building a `Learner`

Let's look at the **data** first.

*FluxTraining.jl* is agnostic to the data source. The only requirements are:

- it is iterable and each iteration returns a tuple `(xs, ys)`
- the model can take in `xs`, i.e. `model(xs)` works; and
- the loss function can take model outputs and `ys`, i.e. `lossfn(model(xs), ys)` returns a scalar
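As a quick illustration of this contract, here is a minimal sketch in plain Julia, with no packages. The helper `check_batches` and the toy `toymodel`, `toyloss`, and `toydata` are hypothetical names made up for this example; they are not part of *FluxTraining.jl*. The sketch also shows that something as simple as a `Vector` of tuples already counts as a data iterator.

```julia
# Hypothetical helper (not part of FluxTraining.jl): checks the three data
# requirements listed above for every batch of an iterable.
function check_batches(dataiter, model, lossfn)
    n = 0
    for batch in dataiter
        # Requirement 1: each iteration yields a 2-tuple `(xs, ys)`
        batch isa Tuple && length(batch) == 2 ||
            error("expected a tuple (xs, ys), got $(typeof(batch))")
        xs, ys = batch
        # Requirement 2: the model accepts `xs`
        yhat = model(xs)
        # Requirement 3: the loss of `(model(xs), ys)` is a scalar
        loss = lossfn(yhat, ys)
        loss isa Number || error("loss must be a scalar, got $(typeof(loss))")
        n += 1
    end
    return n  # number of batches checked
end

# Toy stand-ins: a matrix multiplication as the "model", mean squared error
# as the loss, and a plain Vector of (xs, ys) tuples as the data iterator.
W = Float32[1 0 2; 0 1 -1]
toymodel(xs) = W * xs
toyloss(yhat, ys) = sum(abs2, yhat .- ys) / length(ys)
toydata = [(ones(Float32, 3, 4), zeros(Float32, 2, 4)) for _ in 1:5]

check_batches(toydata, toymodel, toyloss)  # returns 5
```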


Glossing over the details, since they're not the focus of this tutorial, here's the code for getting data iterators for the MNIST dataset. We use `DataLoader` to create an iterator of batches from our dataset.

```julia
const LABELS = 0:9

# unsqueeze to reshape from (28, 28, numobs) to (28, 28, 1, numobs)
function preprocess((data, targets))
    return unsqueeze(data; dims = 3), onehotbatch(targets, LABELS)
end

# traindata and testdata contain both inputs (pixel values) and targets (correct labels)
traindata = MNIST(Float32, :train)[:] |> preprocess
testdata = MNIST(Float32, :test)[:] |> preprocess

# create iterators that yield batches of (xs, ys)
trainiter, testiter = DataLoader(traindata, batchsize=128), DataLoader(testdata, batchsize=256);
```
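To see what `preprocess` and `DataLoader` do to the array shapes, the following plain-Julia sketch mimics them on a few random stand-in images. It uses `reshape` in place of `unsqueeze`, a broadcasted comparison in place of `onehotbatch`, and manual slicing in place of `DataLoader`, so the names `x`, `y`, `x4`, `yoh`, and `batches` are illustrative only. It assumes that batching happens along the last (observation) dimension and that a smaller final batch is kept, which is how MLUtils treats arrays by default.

```julia
# Five random stand-ins for MNIST images; MLDatasets returns the
# features as a (28, 28, numobs) array, with observations last.
x = rand(Float32, 28, 28, 5)
y = [3, 0, 7, 7, 1]   # hypothetical digit labels
labels = 0:9

# Same result as `unsqueeze(x; dims = 3)`: insert a size-1 channel dimension,
# giving the (width, height, channels, batch) layout that `Conv` expects.
x4 = reshape(x, size(x, 1), size(x, 2), 1, size(x, 3))

# Dense equivalent of `onehotbatch(y, labels)`: one column per sample,
# with a 1 in the row of that sample's label.
yoh = Float32.(labels .== permutedims(y))

# DataLoader-style batching along the last (observation) dimension;
# the final batch is smaller when the data does not divide evenly.
bs = 2
batches = [(x4[:, :, :, i:min(i + bs - 1, end)],
            yoh[:, i:min(i + bs - 1, end)]) for i in 1:bs:size(x4, 4)]
size.(first.(batches))   # [(28, 28, 1, 2), (28, 28, 1, 2), (28, 28, 1, 1)]

# For the real training split (60_000 images) with batchsize=128:
nbatches_train = cld(60_000, 128)   # 469 batches; the last holds 96 images
```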

Next, let's create a simple *Flux.jl* **model** that we'll train to classify the MNIST digits.

```julia
model = Chain(
    Conv((3, 3), 1 => 16, relu, pad = 1, stride = 2),
    Conv((3, 3), 16 => 32, relu, pad = 1),
    GlobalMeanPool(),
    Flux.flatten,
    Dense(32 => 10),
)
```

```
Chain(
  Conv((3, 3), 1 => 16, relu, pad=1, stride=2),  # 160 parameters
  Conv((3, 3), 16 => 32, relu, pad=1),  # 4_640 parameters
  GlobalMeanPool(),
  Flux.flatten,
  Dense(32 => 10),                      # 330 parameters
)                   # Total: 6 arrays, 5_130 parameters, 20.867 KiB.
```
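The shapes and parameter counts in the model summary can be checked by hand. This sketch uses hypothetical helpers (`conv_outsize`, `conv_params`, `dense_params`; these are not Flux functions) and the standard output-size formula for a convolution along one spatial dimension, `(n + 2pad - k) ÷ stride + 1`.

```julia
# Hypothetical helpers (not Flux API) reproducing the layer arithmetic by hand.
# Output size of a convolution along one spatial dimension.
conv_outsize(n; k, pad = 0, stride = 1) = (n + 2pad - k) ÷ stride + 1
# Weights (k × k × in × out) plus one bias per output channel.
conv_params(k, cin, cout) = k * k * cin * cout + cout
# Weights (nout × nin) plus one bias per output.
dense_params(nin, nout) = nin * nout + nout

s1 = conv_outsize(28; k = 3, pad = 1, stride = 2)  # 14: the first Conv halves 28×28
s2 = conv_outsize(s1; k = 3, pad = 1)              # 14: the second Conv keeps the size
# GlobalMeanPool averages each 14×14 feature map to 1×1, so flatten leaves
# 32 features per image, matching the input size of Dense(32 => 10).

nparams = (conv_params(3, 1, 16), conv_params(3, 16, 32), dense_params(32, 10))
sum(nparams)  # 5130, the total reported in the model summary
```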

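We'll use *categorical cross-entropy* as the **loss function** and *Adam* as the **optimizer**. Because the model's final `Dense` layer has no softmax, it outputs unnormalized scores (logits), so we use `logitcrossentropy`, which applies the softmax inside the loss.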

```julia
lossfn = Flux.Losses.logitcrossentropy
optimizer = Flux.Adam();
```
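To make the choice of `logitcrossentropy` concrete, here is a plain-Julia sketch of the computation, compared against the naive "softmax first, then log" approach. It assumes Flux's defaults of classes along `dims = 1` and a mean over the batch; `logsoftmax_cols`, `logit_ce`, `softmax_cols`, and `naive_ce` are hypothetical names for illustration.

```julia
# Numerically stable log-softmax over each column (one column per sample):
# subtracting the column maximum keeps exp from overflowing.
function logsoftmax_cols(z)
    m = maximum(z; dims = 1)
    return z .- m .- log.(sum(exp.(z .- m); dims = 1))
end

# Cross-entropy on logits: per-sample loss -sum(y .* logsoftmax(yhat)) over
# classes, averaged over the batch (assumed to mirror Flux's default mean).
logit_ce(yhat, y) = sum(-sum(y .* logsoftmax_cols(yhat); dims = 1)) / size(y, 2)

# Naive alternative: softmax first, then take the log of the probabilities.
softmax_cols(z) = exp.(z) ./ sum(exp.(z); dims = 1)
naive_ce(yhat, y) = sum(-sum(y .* log.(softmax_cols(yhat)); dims = 1)) / size(y, 2)

logits = [2.0 -1.0; 0.5 3.0; -0.3 0.0]   # 3 classes × 2 samples
targets = [1.0 0.0; 0.0 1.0; 0.0 0.0]    # one-hot targets: classes 1 and 2
logit_ce(logits, targets) ≈ naive_ce(logits, targets)  # true: both agree here

# With a very large logit, exp overflows in the naive version only.
extreme = reshape([1000.0, 0.0, 0.0], 3, 1)
target2 = reshape([0.0, 1.0, 0.0], 3, 1)
logit_ce(extreme, target2)   # 1000.0
naive_ce(extreme, target2)   # NaN
```

For ordinary inputs both versions agree, but only the fused version stays finite when a logit is large, which is why the loss takes logits directly instead of softmax outputs.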

Now we're ready to create a [`Learner`](#). At this point you can also add any callbacks, like [`ToGPU`](#), to run the training on your GPU if you have one available. Some callbacks are also [included by default](../callbacks/reference.md).

Since we're classifying digits, we also use the [`Metrics`](#) callback to track the accuracy of the model's predictions:

```julia
learner = Learner(model, lossfn; callbacks=[ToGPU(), Metrics(accuracy)], optimizer)
```

```
Learner()
```
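For intuition about what the `Metrics(accuracy)` callback reports, here is a plain-Julia sketch of classification accuracy for logits and one-hot targets stored one column per sample, as in our batches. `accuracy_sketch` is a hypothetical name; *FluxTraining.jl*'s own `accuracy` may differ in details such as how it handles ties.

```julia
# Hypothetical sketch of classification accuracy for column-per-sample data.
# `argmax(A; dims = 1)` returns the CartesianIndex of each column's maximum, so
# a prediction and a target compare equal exactly when they pick the same class.
function accuracy_sketch(yhat, y)
    predicted = argmax(yhat; dims = 1)
    actual = argmax(y; dims = 1)
    return count(predicted .== actual) / size(y, 2)
end

scores = [2.0 0.1 0.3; 0.5 3.0 0.2; 0.1 0.2 0.4]  # logits: 3 classes × 3 samples
targets_oh = [1 0 0; 0 0 0; 0 1 1]                 # one-hot: classes 1, 3, 3
accuracy_sketch(scores, targets_oh)  # 2/3: samples 1 and 3 are correct
```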

## Training

With a `Learner` in place, training is as simple as calling [`fit!`](#)`(learner, nepochs, dataiters)`.

```julia
FluxTraining.fit!(learner, 10, (trainiter, testiter))
```

```
Epoch 1 TrainingPhase() ...
```
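Conceptually, each epoch of `fit!` runs a training phase over `trainiter` (the `TrainingPhase()` in the output above) followed by a validation phase over `testiter`. The plain-Julia sketch below shows only that two-phase structure, on a hypothetical toy linear-regression problem with a hand-derived gradient and plain gradient descent instead of Adam. It is not how *FluxTraining.jl* is implemented: the real loop computes gradients with automatic differentiation and runs callbacks such as `Metrics` and `ToGPU` along the way.

```julia
using Random: MersenneTwister

# Hypothetical toy problem: recover W_true in y = W_true * x from noise-free data.
rng = MersenneTwister(42)
W_true = [1.0 -2.0 0.5; 0.0 1.0 3.0]
function makebatch(rng)
    x = randn(rng, 3, 8)            # 3 features × 8 samples
    return (x, W_true * x)
end
trainbatches = [makebatch(rng) for _ in 1:4]
validbatches = [makebatch(rng) for _ in 1:2]

W = zeros(2, 3)                      # the "model" parameters
toymse(yhat, y) = sum(abs2, yhat .- y) / length(y)
validation_loss() = sum(toymse(W * x, y) for (x, y) in validbatches) / length(validbatches)
lr = 0.1                             # step size for plain gradient descent

history = [validation_loss()]        # validation loss before training
for epoch in 1:10
    # Training phase: one gradient step per batch, updating W in place.
    for (x, y) in trainbatches
        grad = (2 / length(y)) .* (W * x .- y) * x'   # gradient of toymse w.r.t. W
        W .-= lr .* grad
    end
    # Validation phase: evaluate only; parameters are not updated.
    push!(history, validation_loss())
end
history[end] < history[1]  # true: the validation loss went down
```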
