# F-MNIST Multilayer Perceptron
This notebook uses [Flux](https://github.com/FluxML/Flux.jl) to train a multilayer perceptron on the [Fashion-MNIST](https://github.com/zalandoresearch/fashion-mnist) dataset. It then exports the trained model with [FluxJS](https://github.com/FluxML/FluxJS.jl) to [deeplearn.js](https://deeplearnjs.org/) (since renamed TensorFlow.js, whose `tf` namespace the generated code uses) so that the model can run in the browser.

(Credit for this notebook goes to [WooKyoung Noh](https://github.com/wookay)).

In [1]:
using Flux
using Flux: onehotbatch, argmax, crossentropy, throttle, @epochs
using BSON
using Base.Iterators: repeated
using FluxJS
using MLDatasets # FashionMNIST
using ColorTypes: N0f8, Gray

[1m[36mINFO: [39m[22m[36mRecompiling stale cache file /home/alex/.julia/lib/v0.6/Graphics.ji for module Graphics.
[39m[1m[36mINFO: [39m[22m[36mRecompiling stale cache file /home/alex/.julia/lib/v0.6/ImageCore.ji for module ImageCore.
Expr(:call, Expr(:., :Base, :include_from_node1)::Any, "/home/alex/.julia/v0.6/FFTW/src/FFTW.jl")::Any
  ** incremental compilation may be broken for this module **

[1m[36mINFO: [39m[22m[36mRecompiling stale cache file /home/alex/.julia/lib/v0.6/MLDatasets.ji for module MLDatasets.
[39m

In [2]:
# Image type: a 2-D array of `Gray{N0f8}` pixels (8-bit normalised fixed-point grayscale).
const Img = Array{Gray{N0f8},2}

"""
    prepare_train(; n = 60_000)

Load the Fashion-MNIST training set and return `(X, Y)` for its first `n` images.

Column `i` of `X` holds the flattened pixels of image `i` (converted to `Img`'s element
type, then passed through `float`). Column `i` of `Y` is the one-hot encoding of that
image's label over the classes `0:9`. Both are passed through `gpu`.
"""
function prepare_train(; n::Integer = 60_000)
    # Load the training set (60_000 images); only the first `n` are used.
    train_x, train_y = FashionMNIST.traindata()

    trainrange = 1:n
    # Convert the whole stack of images in one call instead of image by image.
    imgs = convert(Array{eltype(Img),3}, train_x[:, :, trainrange])
    # Stack images into one large batch. Reshaping to one column per image gives exactly
    # the layout of `hcat`-ing the flattened images, without splatting tens of thousands
    # of arrays into a single varargs call.
    X = float.(reshape(imgs, :, length(trainrange))) |> gpu
    # One-hot-encode the labels
    Y = onehotbatch(train_y[trainrange], 0:9) |> gpu
    return X, Y
end

"""
    prepare_test(; n = 1_000, ndemo = 100)

Load the Fashion-MNIST test set and return `(tX, tY)` for its first `n` images.

Column `i` of `tX` holds the flattened, `float`-converted pixels of image `i`. Column `i`
of `tY` is the one-hot encoding of its label over `0:9`. Both are passed through `gpu`.

As a side effect, write `test_images.bson` for the web demo. It contains the first
`ndemo` images, flattened into a single `Float32` vector under `:images`, and their
labels as `Int32` under `:labels`.

Throws an `ArgumentError` unless `0 <= ndemo <= n`.
"""
function prepare_test(; n::Integer = 1_000, ndemo::Integer = 100)
    0 <= ndemo <= n ||
        throw(ArgumentError("ndemo must be between 0 and n = $n, got $ndemo"))

    # Load the test set (10_000 images); only the first `n` are used.
    test_x, test_y = FashionMNIST.testdata()

    testrange = 1:n
    test_imgs = convert(Array{eltype(Img),3}, test_x[:, :, testrange])
    # One column per image (same layout as `hcat` of the flattened images). Keep this
    # host-side copy so the demo file below is never built from a GPU array.
    tX_cpu = float.(reshape(test_imgs, :, length(testrange)))
    tY = onehotbatch(test_y[testrange], 0:9) |> gpu

    # Save the first `ndemo` images in a bson for use in the web demo
    bson("test_images.bson", Dict(
        :images => vec(Float32.(tX_cpu[:, 1:ndemo])),
        :labels => Int32.(test_y[1:ndemo])
    ))

    tX = tX_cpu |> gpu
    return tX, tY
end

# Build the training batch and the test batch.
X, Y = prepare_train()
tX, tY = prepare_test()

# Multilayer perceptron: 28×28 = 784 pixel inputs → 32 ReLU hidden units → 10 class
# scores normalised by softmax. The assembled model is passed through `gpu`.
m = gpu(Chain(Dense(28^2, 32, relu), Dense(32, 10), softmax))

Chain(Dense(784, 32, NNlib.relu), Dense(32, 10), NNlib.softmax)

In [3]:
# Cross-entropy between the softmax output of `m` and the one-hot targets `y`.
function loss(x, y)
    return crossentropy(m(x), y)
end

# Fraction of examples (columns) whose highest-scoring predicted class equals the label.
function accuracy(x, y)
    predicted = argmax(m(x))
    actual = argmax(y)
    return mean(predicted .== actual)
end

# Full-batch training: each epoch takes 200 ADAM steps on the whole (X, Y) batch. The
# callback, wrapped in `throttle(..., 2)`, prints the training-set loss.
opt = ADAM(params(m))
dataset = repeated((X, Y), 200)
evalcb = () -> @show(loss(X, Y))

@epochs 5 Flux.train!(loss, dataset, opt, cb = throttle(evalcb, 2))

println("Training set accuracy: ", accuracy(X, Y))
# previously noted result: 0.983

# Evaluated on the test subset returned by `prepare_test`.
println("Test set accuracy: ", accuracy(tX, tY))
# previously noted result: 0.83

[1m[36mINFO: [39m[22m[36mEpoch 1
[39m

loss(X, Y) = 2.319393114238521 (tracked)
loss(X, Y) = 1.8492554278209932 (tracked)
loss(X, Y) = 1.465513296237327 (tracked)
loss(X, Y) = 1.207173994425875 (tracked)
loss(X, Y) = 1.0362687231537662 (tracked)
loss(X, Y) = 0.9271482992555827 (tracked)
loss(X, Y) = 0.8630851948043434 (tracked)
loss(X, Y) = 0.8037644415911736 (tracked)
loss(X, Y) = 0.7649856856332238 (tracked)
loss(X, Y) = 0.726926655542281 (tracked)
loss(X, Y) = 0.6997025804171768 (tracked)
loss(X, Y) = 0.6767615869350542 (tracked)
loss(X, Y) = 0.6524676297710844 (tracked)
loss(X, Y) = 0.6311056607728174 (tracked)
loss(X, Y) = 0.6122592764697472 (tracked)
loss(X, Y) = 0.5944632160638561 (tracked)
loss(X, Y) = 0.5785307247531124 (tracked)
loss(X, Y) = 0.5640768693919396 (tracked)
loss(X, Y) = 0.5511865210894342 (tracked)
loss(X, Y) = 0.5396836403987842 (tracked)
loss(X, Y) = 0.5293738020058126 (tracked)
loss(X, Y) = 0.5201390581958197 (tracked)
loss(X, Y) = 0.5118085087319725 (tracked)
loss(X, Y) = 0.5042554802866919 (track

[1m[36mINFO: [39m[22m[36mEpoch 2
[39m

loss(X, Y) = 0.4502621911132155 (tracked)
loss(X, Y) = 0.44683130489333106 (tracked)
loss(X, Y) = 0.44407039422716904 (tracked)
loss(X, Y) = 0.4414153647409257 (tracked)
loss(X, Y) = 0.4393455991399356 (tracked)
loss(X, Y) = 0.4368265315915678 (tracked)
loss(X, Y) = 0.43437839415680224 (tracked)
loss(X, Y) = 0.43198811640709706 (tracked)
loss(X, Y) = 0.42963606603755744 (tracked)
loss(X, Y) = 0.4282075174652327 (tracked)
loss(X, Y) = 0.42635100317340247 (tracked)
loss(X, Y) = 0.4245033800905603 (tracked)
loss(X, Y) = 0.4226436438372134 (tracked)
loss(X, Y) = 0.42125164020957645 (tracked)
loss(X, Y) = 0.41892879029299424 (tracked)
loss(X, Y) = 0.41622673812845984 (tracked)
loss(X, Y) = 0.41351798157830805 (tracked)
loss(X, Y) = 0.41089071340946653 (tracked)
loss(X, Y) = 0.40834345573836023 (tracked)
loss(X, Y) = 0.40628398274759797 (tracked)
loss(X, Y) = 0.40427737120650986 (tracked)
loss(X, Y) = 0.40194703785209135 (tracked)
loss(X, Y) = 0.39969735045919647 (tracked)
loss(X, Y) = 0.397

[1m[36mINFO: [39m[22m[36mEpoch 3
[39m

loss(X, Y) = 0.3727750497773655 (tracked)
loss(X, Y) = 0.3713583263323641 (tracked)
loss(X, Y) = 0.36996805406537087 (tracked)
loss(X, Y) = 0.36857954553462574 (tracked)
loss(X, Y) = 0.3672518476646927 (tracked)
loss(X, Y) = 0.3659750890601754 (tracked)
loss(X, Y) = 0.36470835004733504 (tracked)
loss(X, Y) = 0.363491600024328 (tracked)
loss(X, Y) = 0.3622990541610327 (tracked)
loss(X, Y) = 0.36110250063831456 (tracked)
loss(X, Y) = 0.3599278223293612 (tracked)
loss(X, Y) = 0.358796996409914 (tracked)
loss(X, Y) = 0.3579405109512548 (tracked)
loss(X, Y) = 0.3570248084599389 (tracked)
loss(X, Y) = 0.35587035622918534 (tracked)
loss(X, Y) = 0.35477356633041135 (tracked)
loss(X, Y) = 0.35372548937565723 (tracked)
loss(X, Y) = 0.3527097460778781 (tracked)
loss(X, Y) = 0.3518722217844755 (tracked)
loss(X, Y) = 0.35090592313506913 (tracked)
loss(X, Y) = 0.349958644949679 (tracked)
loss(X, Y) = 0.34901102464518674 (tracked)
loss(X, Y) = 0.3480935227007142 (tracked)
loss(X, Y) = 0.3471947916060

[1m[36mINFO: [39m[22m[36mEpoch 4
[39m

loss(X, Y) = 0.3384036607450193 (tracked)
loss(X, Y) = 0.33773055184016276 (tracked)
loss(X, Y) = 0.3369674572532239 (tracked)
loss(X, Y) = 0.33616962710187576 (tracked)
loss(X, Y) = 0.3354822625457311 (tracked)
loss(X, Y) = 0.33491497915125196 (tracked)
loss(X, Y) = 0.3340783622805057 (tracked)
loss(X, Y) = 0.33349458449635616 (tracked)
loss(X, Y) = 0.332888920605551 (tracked)
loss(X, Y) = 0.33210106083479196 (tracked)
loss(X, Y) = 0.3314637970276914 (tracked)
loss(X, Y) = 0.3308454456606266 (tracked)
loss(X, Y) = 0.33017801754011433 (tracked)
loss(X, Y) = 0.32947939948126137 (tracked)
loss(X, Y) = 0.32893236599506404 (tracked)
loss(X, Y) = 0.3284084230618013 (tracked)
loss(X, Y) = 0.3278847312174922 (tracked)
loss(X, Y) = 0.3276210727111398 (tracked)
loss(X, Y) = 0.3266664618181417 (tracked)
loss(X, Y) = 0.3263384364060749 (tracked)
loss(X, Y) = 0.3256763706849859 (tracked)
loss(X, Y) = 0.3252434742695658 (tracked)
loss(X, Y) = 0.32469468294154574 (tracked)
loss(X, Y) = 0.32433859832

[1m[36mINFO: [39m[22m[36mEpoch 5
[39m

loss(X, Y) = 0.31772328532882127 (tracked)
loss(X, Y) = 0.31730095811094405 (tracked)
loss(X, Y) = 0.3168698304375715 (tracked)
loss(X, Y) = 0.3163401783947382 (tracked)
loss(X, Y) = 0.31585198177809104 (tracked)
loss(X, Y) = 0.31548748862140613 (tracked)
loss(X, Y) = 0.31509024904324384 (tracked)
loss(X, Y) = 0.31429506922194583 (tracked)
loss(X, Y) = 0.313906818257718 (tracked)
loss(X, Y) = 0.3134208054567013 (tracked)
loss(X, Y) = 0.3128650150930208 (tracked)
loss(X, Y) = 0.3124280320710254 (tracked)
loss(X, Y) = 0.3120869980329065 (tracked)
loss(X, Y) = 0.31169809354876954 (tracked)
loss(X, Y) = 0.3109621643727837 (tracked)
loss(X, Y) = 0.3105407329774212 (tracked)
loss(X, Y) = 0.3100911456022601 (tracked)
loss(X, Y) = 0.30956707077353035 (tracked)
loss(X, Y) = 0.30902275323928224 (tracked)
loss(X, Y) = 0.30854440095529606 (tracked)
loss(X, Y) = 0.30813938077501624 (tracked)
loss(X, Y) = 0.30770502875097894 (tracked)
loss(X, Y) = 0.30793818806563245 (tracked)
loss(X, Y) = 0.3068514

In [3]:
# See the deeplearn.js representation of the model: `@code_js` prints the JavaScript
# generated for a forward pass of `m`, traced on a single example (the first column of `X`).
@code_js m(X[:,1])

let model = (function () {
  let math = tf;
  function alligator(coati) {
    return math.add(math.matrixTimesVector(model.weights[0], coati), model.weights[1]);
  };
  function cobra(eland) {
    return math.relu(math.add(math.matrixTimesVector(model.weights[2], eland), model.weights[3]));
  };
  function model(jellyfish) {
    return math.softmax(alligator(cobra(jellyfish)));
  };
  model.weights = [];
  return model;
})();
flux.fetchWeights("model.bson").then((function (ws) {
  return model.weights = ws;
}));


In [5]:
# Write the model javascript and the model weights to files.
# `"mlp"` is the name passed to FluxJS for the output; `X[:,1]` is a single example input
# used to trace the model's forward pass.
# NOTE(review): the generated loader printed by `@code_js` fetches "model.bson"; confirm
# that the weight file name `compile` writes for "mlp" matches what the web page loads.
FluxJS.compile("mlp", m, X[:,1])