In [1]:
using Pkg; Pkg.activate("."); Pkg.instantiate();

  Updating registry at `~/.julia/registries/General`
  Updating git-repo `https://github.com/JuliaRegistries/General.git`
[?25l[2K[?25h

Inspired by Joel Grus's blog post "Fizz Buzz in Tensorflow":
http://joelgrus.com/2016/05/23/fizz-buzz-in-tensorflow/

In [2]:
using Flux: Chain, Dense, params, crossentropy, onehotbatch,
            ADAM, train!, softmax
using Test



Data preparation

In [3]:
"""
    fizzbuzz(x::Integer)

Return the FizzBuzz class name for `x`:
`"fizzbuzz"` if `x` is divisible by both 3 and 5, `"fizz"` if divisible only by 3,
`"buzz"` if divisible only by 5, and `"else"` otherwise.

The returned strings are exactly the entries of `LABELS`, so the result can be
one-hot encoded directly.
"""
function fizzbuzz(x::Integer)
    is_divisible_by_three = x % 3 == 0
    is_divisible_by_five = x % 5 == 0

    # `&&` is the idiomatic (short-circuiting) boolean "and"; `&` is bitwise.
    if is_divisible_by_three && is_divisible_by_five
        return "fizzbuzz"
    elseif is_divisible_by_three
        return "fizz"
    elseif is_divisible_by_five
        return "buzz"
    else
        return "else"
    end
end

# Class names. Their order fixes the one-hot row order produced by `onehotbatch`
# and therefore the meaning of each index returned by `argmax` on model output.
const LABELS = ["fizz", "buzz", "fizzbuzz", "else"];

# Sanity check: one representative input per class, listed in LABELS order.
@test fizzbuzz.([3, 5, 15, 98]) == LABELS

# Training data: the integers 1..100 paired with their FizzBuzz class names.
raw_x = 1:100;
raw_y = map(fizzbuzz, raw_x);

Feature engineering

In [4]:
"""
    features(x)

Encode the integer `x` as the `Float64` feature vector `[x % 3, x % 5, x % 15]`.
"""
features(x) = float.([x % 3, x % 5, x % 15])

"""
    features(xs::AbstractArray)

Encode every element of `xs` and stack the resulting vectors as columns, giving a
`3 × length(xs)` `Float64` matrix (one sample per column, matching the 3-input first
layer of the model). An empty `xs` yields a `3 × 0` matrix.
"""
function features(xs::AbstractArray)
    # Without this guard, `reduce(hcat, ...)` on an empty collection would throw,
    # and the old splatted `hcat()` did not return a 3-row matrix.
    isempty(xs) && return Matrix{Float64}(undef, 3, 0)
    # `reduce(hcat, …)` avoids splatting every column into one varargs `hcat` call,
    # which scales poorly with the number of samples; `vec` keeps the argument a
    # vector so Base's specialized `reduce(hcat, ::AbstractVector)` method is used.
    return reduce(hcat, features.(vec(xs)))
end

# Model inputs: a 3 × 100 feature matrix (one column per integer in raw_x),
# and the matching one-hot targets whose rows follow LABELS order.
X = raw_x |> features;
y = onehotbatch(raw_y, LABELS);

Model

In [5]:
# Two dense layers and a softmax: 3 input features -> 10 hidden units -> 4 scores,
# one per entry of LABELS.
# NOTE(review): neither Dense layer is given an activation; if Flux's default is the
# identity, the two layers compose into a single affine map — consider e.g. `relu`.
m = Chain(Dense(3, 10), Dense(10, 4), softmax)

# Cross-entropy of the model's predictions for the batch `x` against one-hot targets
# `y`. Uses the argument `x` (previously the global `X` was used by mistake, so the
# loss ignored whatever inputs it was called with).
loss(x, y) = crossentropy(m(x), y)

# NOTE(review): `ADAM(params(m))` is the older Flux optimiser API, consistent with the
# `.data` access in `monitor`; confirm against the Flux version pinned by the project.
opt = ADAM(params(m))

#43 (generic function with 1 method)

Helpers

In [6]:
"""
    deepbuzz(x)

Predict the FizzBuzz output for integer `x` using the model `m`: the predicted class
name from `LABELS`, or `x` itself when the predicted class is `"else"`.
"""
function deepbuzz(x)
    label = LABELS[argmax(m(features(x)))]
    # Compare by name instead of a hard-coded index (previously 4), so reordering
    # LABELS cannot silently break the "return the number itself" case.
    return label == "else" ? x : label
end

"""
    monitor(e)

Print, on one line, the epoch number `e` (left-padded to width 4), the loss on the
full training set `(X, y)` rounded to 4 digits, and the model's predictions for the
probe inputs 3, 5, 15 and 98. Return those predictions.
"""
function monitor(e)
    epoch_label = lpad(e, 4)
    current_loss = round(loss(X, y).data; digits=4)
    print("epoch ", epoch_label, ": loss = ", current_loss, " | ")
    return @show deepbuzz.([3, 5, 15, 98])
end

monitor (generic function with 1 method)

Training

In [7]:
# Full-batch training: each call to `train!` is given the single batch `(X, y)`.
# Progress is reported after every 50th epoch (including epoch 0).
for epoch in 0:1000
    train!(loss, [(X, y)], opt)
    iszero(epoch % 50) && monitor(epoch)
end

epoch    0: loss = 1.5041 | deepbuzz.([3, 5, 15, 98]) = Any["fizzbuzz", 5, "fizz", 98]
epoch   50: loss = 0.9277 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, "fizz", 98]
epoch  100: loss = 0.7661 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, "buzz", 98]
epoch  150: loss = 0.6563 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, "fizzbuzz", 98]
epoch  200: loss = 0.5688 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, "fizzbuzz", 98]
epoch  250: loss = 0.4946 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, "fizzbuzz", 98]
epoch  300: loss = 0.4305 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", 5, "fizzbuzz", 98]
epoch  350: loss = 0.3748 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", "buzz", "fizzbuzz", 98]
epoch  400: loss = 0.3263 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", "buzz", "fizzbuzz", 98]
epoch  450: loss = 0.2842 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", "buzz", "fizzbuzz", 98]
epoch  500: loss = 0.2475 | deepbuzz.([3, 5, 15, 98]) = Any["fizz", "buzz", "fizzbuzz", 98]
epoch  550: loss = 0.2157 | dee