In [1]:
# Activate the project environment and print the status of its packages.
# NOTE(review): the path below is absolute and specific to one machine; it looks like an
# artifact of how this notebook was generated. Confirm/adjust before running elsewhere.
using Pkg; Pkg.activate("/home/dhairyagandhi96/temp/model-zoo/script/.."); Pkg.status();

    Status `~/temp/model-zoo/Project.toml`
  [1520ce14]   AbstractTrees v0.2.1
  [fbb218c0] ↑ BSON v0.2.3 ⇒ v0.2.4
  [54eefc05]   Cascadia v0.4.0
  [8f4d0f93]   Conda v1.3.0
  [864edb3b] ↑ DataStructures v0.17.0 ⇒ v0.17.5
  [31c24e10] ↑ Distributions v0.21.3 ⇒ v0.21.5
  [587475ba]   Flux v0.9.0
  [708ec375]   Gumbo v0.5.1
  [b0807396]   Gym v1.1.3
  [cd3eb016] ↑ HTTP v0.8.6 ⇒ v0.8.7
  [6218d12a]   ImageMagick v0.7.5
  [916415d5]   Images v0.18.0
  [e5e0dc1b]   Juno v0.7.2
  [ca7b5df7]   MFCC v0.3.1
  [dbeba491] + Metalhead v0.4.0 #c4d1eba (https://github.com/FluxML/Metalhead.jl.git)
  [91a5bcdd] ↑ Plots v0.26.3 ⇒ v0.27.0
  [2913bbd2]   StatsBase v0.32.0
  [98b73d46]   Trebuchet v0.1.0
  [8149f6b0] ↑ WAV v1.0.2 ⇒ v1.0.3
  [e88e6eb3] + Zygote v0.4.1
  [10745b16]   Statistics 
  [4ec0a83e]   Unicode 


Inspired by the "Fizz Buzz in Tensorflow" blog post by Joel Grus:
http://joelgrus.com/2016/05/23/fizz-buzz-in-tensorflow/

In [2]:
using Flux: Chain, Dense, params, crossentropy, onehotbatch,
            ADAM, train!, softmax
using Test

Data preparation

In [3]:
"""
    fizzbuzz(x::Integer)

Return the ground-truth FizzBuzz label for `x`:

- `"fizzbuzz"` when `x` is divisible by both 3 and 5,
- `"fizz"` when `x` is divisible by 3 only,
- `"buzz"` when `x` is divisible by 5 only,
- `"else"` otherwise.

Accepts any `Integer` subtype (`Int8`, `UInt`, `BigInt`, ...), not just `Int`.
"""
function fizzbuzz(x::Integer)
    by_three = x % 3 == 0
    by_five = x % 5 == 0

    # The combined case must be checked first; otherwise multiples of 15 would be
    # labelled "fizz" by the divisible-by-three branch.
    if by_three && by_five
        return "fizzbuzz"
    elseif by_three
        return "fizz"
    elseif by_five
        return "buzz"
    else
        return "else"
    end
end

# The four output classes; their order here fixes the class index of each label.
const LABELS = ["fizz", "buzz", "fizzbuzz", "else"];

# Sanity check: one representative input per class, listed in `LABELS` order.
@test map(fizzbuzz, [3, 5, 15, 98]) == LABELS

# Raw inputs (the integers 1 to 100) and their ground-truth labels.
raw_x = 1:100;
raw_y = map(fizzbuzz, raw_x);

Feature engineering

In [4]:
# Feature vector for a single number: its remainders modulo 3, 5 and 15, as floats.
features(x) = float.([x % 3, x % 5, x % 15])

# Feature matrix for a collection of numbers: one 3-element column per element of `x`,
# in iteration (column-major) order.
function features(x::AbstractArray)
    columns = map(features, vec(x))
    # An empty input has no columns to concatenate; return an empty 3-row matrix.
    isempty(columns) && return zeros(3, 0)
    # `reduce(hcat, ...)` concatenates without splatting the whole collection into
    # `hcat`'s argument list, so the cost does not blow up with the number of inputs.
    return reduce(hcat, columns)
end

# Feature matrix for the training inputs: one column of remainders per number in `raw_x`.
X = features(raw_x);
# One-hot encode the ground-truth labels against `LABELS`.
y = onehotbatch(raw_y, LABELS);

Model

In [5]:
# Classifier: 3 input features -> 10 hidden units -> 4 class scores, turned into
# probabilities by `softmax`.
m = Chain(
    Dense(3, 10),
    Dense(10, 4),
    softmax,
)

# Cross-entropy between the model's predicted distribution for `x` and the targets `y`.
function loss(x, y)
    predicted = m(x)
    return crossentropy(predicted, y)
end

# ADAM optimiser with its default settings.
opt = ADAM()

Flux.Optimise.ADAM(0.001, (0.9, 0.999), IdDict{Any,Any}())

Helpers

In [6]:
# Model-backed FizzBuzz: pick the most probable class for `x`. Class 4 is "else" in
# `LABELS`, in which case the number itself is returned instead of a label.
function deepbuzz(x)
    a = argmax(m(features(x)))
    if a == 4
        return x
    end
    return LABELS[a]
end

# Print a one-line progress report for epoch `e`: the loss on the full training set
# (rounded to 4 digits) followed by the model's answers for four probe inputs.
function monitor(e)
    epoch = lpad(e, 4)
    current_loss = round(loss(X, y).data; digits=4)
    print("epoch $(epoch): loss = $(current_loss) | ")
    @show deepbuzz.([3, 5, 15, 98])
end

monitor (generic function with 1 method)

Training

In [7]:
# Full-batch training: each epoch is a single `train!` step on the whole data set,
# with a progress report every 50 epochs (starting at epoch 0).
for e in 0:1000
    train!(loss, params(m), [(X, y)], opt)
    e % 50 == 0 && monitor(e)
end

epoch    0: loss = 7.4925 | deepbuzz.([3, 5, 15, 98]) = ["fizz", "buzz", "fizzbuzz", "buzz"]
epoch   50: loss = 4.6865 | deepbuzz.([3, 5, 15, 98]) = ["fizz", "buzz", "fizzbuzz", "buzz"]
epoch  100: loss = 2.3992 | deepbuzz.([3, 5, 15, 98]) = ["fizz", "buzz", "fizzbuzz", "buzz"]
epoch  150: loss = 0.8405 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, "fizzbuzz", 98]
epoch  200: loss = 0.6807 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, "fizzbuzz", 98]
epoch  250: loss = 0.6487 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, "fizzbuzz", 98]
epoch  300: loss = 0.618 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, "fizzbuzz", 98]
epoch  350: loss = 0.5882 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, "fizzbuzz", 98]
epoch  400: loss = 0.559 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, "fizzbuzz", 98]
epoch  450: loss = 0.5301 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, "fizzbuzz", 98]
epoch  500: loss = 0.5016 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, "fizzbuzz", 98]
epoch  550: loss = 0.4733 |