## Learn how to build a simple multilayer perceptron on the MNIST dataset

This MNIST example comes from: https://github.com/FluxML/model-zoo/blob/master/mnist/mlp.jl

Let's start by loading `Flux`, importing a few things from `Flux` explicitly, and bringing the `repeated` function into our scope.

In [None]:
using Flux, Flux.Data.MNIST
using Flux: onehotbatch, argmax, crossentropy, throttle
using Base.Iterators: repeated

We can now store all the MNIST images in `imgs` and take a peek into this vector to see what the data looks like.

In [None]:
imgs = MNIST.images()
imgs[3]

Let's look at the type of an individual image.

In [None]:
typeof(imgs[3])

#### Reorganizing our array of images

We see this is a 2D array that stores `ColorTypes`. To work more easily with this data, let's convert all `ColorTypes` to floating point numbers.

In [None]:
fpt_imgs = float.(imgs)

Now we can see what `imgs[3]` looks like as an array of floats, rather than as an array of colors!

In [None]:
fpt_imgs[3]

**Let's stack the images to create one large 2D array, `X`, that stores the data for each image as a column.**

To do this, we can **first** use `reshape` to unravel each image, creating a 1D array (`Vector`) of floats from a 2D array (`Matrix`) of floats.

In [None]:
unraveled_fpt_imgs = reshape.(fpt_imgs, :);
typeof(unraveled_fpt_imgs)

(Note that `Vector` is an alias for a 1D `Array`.)

In [None]:
Vector
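
To make the alias concrete, here is a small sketch in plain Julia. The variable names are made up for illustration. The same relationship holds for `Matrix` and 2D `Array`s.

```julia
# `Vector{T}` and `Matrix{T}` are aliases for 1D and 2D `Array`s.
v = [1.0, 2.0, 3.0]

is_vector_alias = Vector{Float64} === Array{Float64,1}
is_matrix_alias = Matrix{Float64} === Array{Float64,2}

@show typeof(v) ndims(v) is_vector_alias is_matrix_alias
```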

This makes `unraveled_fpt_imgs` a `Vector` of `Vector`s, in which `imgs[3]` is now stored as

In [None]:
unraveled_fpt_imgs[3]
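
It helps to see the unraveling on something much smaller than a 28x28 image. The sketch below uses a made-up 2x3 matrix `A` (not MNIST data) to show that `reshape(A, :)` reads the entries column by column, and that the broadcast form `reshape.(..., :)` does the same for every matrix in a collection.

```julia
# A tiny stand-in for an image: 2 rows, 3 columns.
A = [1.0 2.0 3.0
     4.0 5.0 6.0]

# Julia stores arrays column by column, so unraveling walks down each column in turn.
v = reshape(A, :)

# Broadcasting applies the same reshape to each "image" in a Vector of matrices.
small_imgs = [A, 10 .* A]
unraveled = reshape.(small_imgs, :)

@show v
@show typeof(unraveled)
```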

After using `reshape` to get a `Vector` of `Vector`s, we can use `hcat` to build a `Matrix`, `X`, from `unraveled_fpt_imgs` where the `Vector`s stored in `unraveled_fpt_imgs` will become the columns of `X`.

Note that we're using the "splat" operator below, `...`, which passes each element of a collection to a function as a separate argument, rather than passing the collection itself as a single argument.

In [None]:
X = hcat(unraveled_fpt_imgs...)
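
Here is the same pattern on three short vectors, so the result is easy to check by eye. The names `cols`, `M`, and `M2` are made up for this sketch. `reduce(hcat, cols)` is shown as an alternative that avoids splatting, which some people prefer when the collection is very long.

```julia
cols = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # three 2-element Vectors

# Splatting expands `cols` into separate arguments:
# hcat(cols...) is the same call as hcat(cols[1], cols[2], cols[3]).
M = hcat(cols...)

# An alternative that builds the same matrix without splatting.
M2 = reduce(hcat, cols)

@show size(M)
@show M[:, 2]
```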

#### How to go back to images from this 2D `Array`

Each column of `X` is now an image reshaped into a vector of floating-point numbers. Let's pick one column and see what the digit is.

Let's try to view the second image in the original array, `imgs`, by taking the second column of `X`.

In [None]:
onefigure = X[:,2]

We'll `reshape` this array to a 2D, 28x28 array,

In [None]:
t1 = reshape(onefigure,28,28)

and finally use `colorview` from the `Images` package to view the handwritten digit.

In [None]:
using Images

In [None]:
colorview(Gray, t1)
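
The round trip from image to column and back can be checked on toy data. In this sketch, `img1` and `img2` are made-up 2x2 "images". Stacking them as columns and then reshaping a column recovers the original matrix exactly.

```julia
img1 = [1.0 3.0
        2.0 4.0]
img2 = [5.0 7.0
        6.0 8.0]

# One unraveled image per column, as with `X` above.
Xd = hcat(reshape(img1, :), reshape(img2, :))

# Take the second column and fold it back into a 2x2 matrix.
recovered = reshape(Xd[:, 2], 2, 2)

@show size(Xd)
@show recovered == img2
```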

*Our data is in working order!*

For our machine to learn the digit with which each image is associated, we'll need to train it using correct answers. Therefore we'll make use of the `labels` associated with these images from MNIST.

In [None]:
labels = MNIST.labels() # the true labels

One-hot-encode the labels with `onehotbatch`

In [None]:
Y = onehotbatch(labels, 0:9)

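which gives a binary indicator vector for each image: column `j` of `Y` has a single nonzero entry, in the row that corresponds to the digit `labels[j]`.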
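
To make the encoding concrete, here is a minimal sketch in plain Julia, without Flux, that builds the same kind of indicator matrix by hand. The names `labels_demo` and `Y_demo` are made up for illustration. Note that `onehotbatch` itself returns a specialized one-hot type rather than the dense `Float64` matrix built here.

```julia
labels_demo = [5, 0, 4]   # made-up labels for three images
classes = 0:9

# Rows index the classes 0..9, columns index the images.
# Entry (i, j) is 1.0 when image j has the i-th class, else 0.0.
Y_demo = [Float64(c == l) for c in classes, l in labels_demo]

@show size(Y_demo)
@show Y_demo[:, 1]   # the label 5 sets row 6, because the rows start at digit 0
```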

Build the network: a hidden layer of 32 `relu` units, followed by a 10-unit output layer and a `softmax`.

In [None]:
m = Chain(
  Dense(28^2, 32, relu),
  Dense(32, 10),
  softmax)
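
Conceptually, each `Dense` layer computes an activation applied to `W * x .+ b`, and `softmax` turns the final scores into probabilities that sum to 1. The sketch below spells this out by hand for a made-up 4 -> 3 -> 2 network. The weights and the helpers `relu_demo` and `softmax_demo` are hypothetical stand-ins, not Flux's own implementations.

```julia
relu_demo(z) = max(zero(z), z)

function softmax_demo(z)
    e = exp.(z .- maximum(z))   # subtracting the max avoids overflow in exp
    return e ./ sum(e)
end

# Made-up weights and biases for a tiny 4 -> 3 -> 2 network.
W1 = [ 0.1 -0.2  0.3  0.0
       0.0  0.5 -0.1  0.2
      -0.3  0.1  0.0  0.4]
b1 = zeros(3)
W2 = [ 0.2 -0.1  0.3
      -0.4  0.1  0.2]
b2 = zeros(2)

x = [1.0, 0.5, 0.0, 2.0]

h = relu_demo.(W1 * x .+ b1)     # hidden layer, like Dense(4, 3, relu)
p = softmax_demo(W2 * h .+ b2)   # output layer plus softmax

@show p
@show sum(p)
```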

Define the loss function and the accuracy

In [None]:
loss(x, y) = Flux.crossentropy(m(x), y)
accuracy(x, y) = mean(argmax(m(x)) .== argmax(y))
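
Both quantities can be worked out by hand on a tiny example. The sketch below assumes that the cross entropy is averaged over the columns (one column per sample), which is our reading of how `crossentropy` behaves here. The helpers `crossentropy_demo`, `colargmax`, and `accuracy_demo` are made up for illustration and use only base Julia.

```julia
# Columns are samples, rows are classes.
P = [0.7 0.2 0.1
     0.2 0.5 0.3
     0.1 0.3 0.6]    # made-up predicted probabilities
Yd = [1.0 0.0 0.0
      0.0 0.0 1.0
      0.0 1.0 0.0]   # one-hot truth: classes 1, 3, 2

# Average negative log-probability assigned to the true class.
crossentropy_demo(p, y) = -sum(y .* log.(p)) / size(y, 2)

# Row index of the largest entry in each column.
colargmax(A) = [findmax(A[:, j])[2] for j in 1:size(A, 2)]

# Fraction of columns where the predicted class matches the true class.
accuracy_demo(p, y) = sum(colargmax(p) .== colargmax(y)) / size(y, 2)

@show crossentropy_demo(P, Yd)
@show accuracy_demo(P, Yd)
```

Only the first sample is classified correctly here, so the accuracy is one third.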

Use `X` and `Y` to create our training data, then declare our evaluation callback and the optimizer:

In [None]:
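dataset = repeated((X, Y), 200)     # the full batch (X, Y), repeated 200 times
evalcb = () -> @show(loss(X, Y))    # callback: print the current loss on the training data
opt = ADAM(Flux.params(m))          # ADAM optimizer over the model's trainable parameters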
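
`repeated` does not copy the data. It is a lazy iterator that yields the same object a fixed number of times, so iterating over `dataset` amounts to 200 passes over the full batch. A small sketch with a made-up stand-in for `(X, Y)`:

```julia
using Base.Iterators: repeated

# A made-up stand-in for the (X, Y) tuple.
batch = ([1.0 2.0; 3.0 4.0], [1.0 0.0; 0.0 1.0])

data = repeated(batch, 3)   # yields `batch` three times
steps = collect(data)

@show length(steps)
@show all(step -> step === batch, steps)   # every element is the very same tuple
```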

So far, we have defined our training data and our evaluation functions.

Let's take a look at the function signature of `Flux.train!`.

In [None]:
?Flux.train!

**Now we can train our model and look at the accuracy thereafter.**

In [None]:
# Train on `dataset`; `throttle` limits `evalcb` to at most one call every 10 seconds.
Flux.train!(loss, dataset, opt, cb = throttle(evalcb, 10))

accuracy(X, Y)
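
The idea behind throttling a callback can be sketched in a few lines of plain Julia. `throttle_demo` below is a hypothetical, simplified stand-in and not Flux's `throttle`. It only captures the basic behavior: run the function if enough time has passed since the last run, otherwise do nothing.

```julia
function throttle_demo(f, timeout)
    last_call = -Inf   # time of the most recent call that actually ran
    return function (args...)
        t = time()
        if t - last_call >= timeout
            last_call = t
            return f(args...)
        end
        return nothing   # called again too soon: skip
    end
end

calls = Ref(0)
cb = throttle_demo(() -> (calls[] += 1), 10)

# Three calls in quick succession: only the first falls outside the 10-second window.
cb(); cb(); cb()

@show calls[]
```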

Now that we've trained our model, let's create test data, `tX`, 

In [None]:
tX = hcat(float.(reshape.(MNIST.images(:test), :))...)

and run our model on one of the images from `tX`

In [None]:
test_image = m(tX[:,1])

In [None]:
indmax(test_image) - 1

`test_image` holds the model's ten output probabilities, one for each digit from 0 to 9. Its largest element is the 8th, so our model says that this image is a "7".
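
The subtraction of 1 is needed because Julia indices start at 1 while the digits start at 0. Here is a sketch with a made-up output vector `scores`, using `findmax` from base Julia to get both the largest value and its position:

```julia
# Made-up model output: ten probabilities for the digits 0 through 9.
scores = [0.01, 0.02, 0.01, 0.03, 0.02, 0.01, 0.05, 0.80, 0.03, 0.02]

best_score, best_index = findmax(scores)

# Index 1 corresponds to digit 0, index 2 to digit 1, and so on.
digit = (0:9)[best_index]

@show best_index
@show digit
```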

Now we can look at the original image.

In [None]:
using Images
t1 = reshape(tX[:,1],28,28)
colorview(Gray, t1)

And there we have it!