# Fashion MNIST in Flux - CPU version
This notebook trains a simple multilayer perceptron (MLP) on Fashion-MNIST using only the CPU (no GPU).

In [1]:
using Flux, MLDatasets
using Flux: onehotbatch, argmax, crossentropy
using Flux.Optimise: runall
using Flux.Tracker: back!, value, data

### Load the test and training data

In [2]:
# Training set: `convert2features` turns the image tensor into a feature matrix
# with one column per image (784×60000 for the training set); labels 0–9 are
# encoded as one-hot columns by `onehotbatch`.
x = FashionMNIST.convert2features(FashionMNIST.traintensor());
println("Train x:", size(x))
y = onehotbatch(FashionMNIST.trainlabels(), 0:9)
println("Train y:", size(y))

# Test set, used to report held-out accuracy during training.
xt = FashionMNIST.convert2features(FashionMNIST.testtensor());
println("Test x:", size(xt))
yt = onehotbatch(FashionMNIST.testlabels(), 0:9);
println("Test y:", size(yt))

# Full-batch training: the same (x, y) pair is yielded 250 times, i.e. 250
# optimisation steps over the whole training set.
dataset = Iterators.repeated((x, y), 250);

Test x:(784, 60000)
Test y:(10, 60000)
Validation x:(784, 10000)
Validation y:(10, 10000)


### Define the MLP model, loss function, accuracy metric and optimiser

In [3]:
# MLP: 784 inputs -> 128 (relu) -> 32 (relu) -> 10 outputs, turned into class
# probabilities by the final softmax.
model = Chain(Dense(784, 128, relu), Dense(128, 32, relu), Dense(32,10), softmax)

# Cross-entropy between the predicted class probabilities and the one-hot targets.
function loss(x, y)
    return crossentropy(model(x), y)
end

# Fraction of samples whose predicted class (argmax of the model output)
# matches the target class (argmax of the one-hot label).
function accuracy(x, y)
    predicted = argmax(model(x))
    expected = argmax(y)
    return mean(predicted .== expected)
end

# ADAM optimiser over the trainable parameters of `model`.
opt = ADAM(params(model));

### Training function
Flux provides a `train!` function, but its built-in support for callbacks is very limited, so we define a custom training loop instead.

In [4]:
# Format a fraction as a whole-number percentage string, truncating (flooring)
# any fractional percent, e.g. 0.485 -> "48%". The value is first rounded to 6
# decimal places in integer space, which avoids floating-point artefacts such as
# 100 * 0.57 == 56.99999999999999 being displayed as "56%".
pct(x) = string(fld(round(Int64, 1_000_000 * x), 10_000), "%")
# Print one progress line: the fraction of steps done (i out of n) as a
# percentage, the loss `l` to 3 decimals, and the training/test accuracies
# `atrain`/`atest` as percentages.
function callback(i, n, l, atrain, atest)
    progress = pct(i/n)
    lossstr = @sprintf("%.3f", l)
    trainstr = pct(atrain)
    teststr = pct(atest)
    println(progress, ": loss=", lossstr,
            ", accuracy on training set=", trainstr,
            ", accuracy on test set=", teststr)
    return nothing
end;

"""
    my_train!(loss, data, opt; cb=true, every=10)

Run one optimisation step per element `d` of `data`: evaluate `loss(d...)`,
back-propagate with `back!`, then apply the parameter update with `opt()`.

Throws an error if the loss evaluates to `Inf` or `NaN`.

When `cb` is `true`, every `every` steps the global `callback` is called with the
step index, the total number of steps, the loss of this step (computed before the
update) and the accuracies on the global training (`x`, `y`) and test
(`xt`, `yt`) sets. `every` must be positive.
"""
function my_train!(loss, data, opt; cb=true, every::Integer=10)
    every > 0 || throw(ArgumentError("every must be positive, got $every"))
    n = length(data)
    for (i, d) in enumerate(data)
        l = loss(d...)
        lv = value(l)
        isinf(lv) && error("Loss is Inf")
        isnan(lv) && error("Loss is NaN")
        back!(l)
        opt()
        # Evaluating accuracy on the full training and test sets is costly,
        # so only do it periodically.
        if cb && i % every == 0
            callback(i, n, lv, accuracy(x, y), accuracy(xt, yt))
        end
    end
    return nothing
end;

### Running the training

In [5]:
# Train on the full-batch dataset with progress reporting enabled; @time also
# reports elapsed time and allocations (including first-call compilation).
@time my_train!(loss, dataset, opt; cb = true)

4%: loss=1.484, accuracy on training set=48%, accuracy on test set=48%
8%: loss=0.975, accuracy on training set=73%, accuracy on test set=72%
12%: loss=0.706, accuracy on training set=77%, accuracy on test set=76%
16%: loss=0.597, accuracy on training set=80%, accuracy on test set=79%
20%: loss=0.537, accuracy on training set=82%, accuracy on test set=80%
24%: loss=0.499, accuracy on training set=83%, accuracy on test set=81%
28%: loss=0.470, accuracy on training set=84%, accuracy on test set=82%
32%: loss=0.448, accuracy on training set=84%, accuracy on test set=83%
36%: loss=0.433, accuracy on training set=85%, accuracy on test set=83%
40%: loss=0.416, accuracy on training set=85%, accuracy on test set=84%
44%: loss=0.402, accuracy on training set=86%, accuracy on test set=84%
48%: loss=0.391, accuracy on training set=86%, accuracy on test set=84%
52%: loss=0.380, accuracy on training set=86%, accuracy on test set=85%
56%: loss=0.373, accuracy on training set=87%, accuracy on test se