Inspired by https://joelgrus.com/2016/05/23/fizz-buzz-in-tensorflow/

In [34]:
import torch
import numpy as np
from torch import nn

In [35]:
# Width of the binary encoding fed to the network; the training data covers
# integers up to 2**BITS (exclusive).
BITS = 12

In [36]:
# Display text for each class index. The order matches the integers returned
# by get_label: 0 = fizz, 1 = buzz, 2 = fizzbuzz, 3 = no label.
label_text = [
    'fizz', 
    'buzz', 
    'fizzbuzz',
    ''
]

In [67]:
def to_binary(i, num_bits):
    """Encode ``i`` as an array of its ``num_bits`` lowest bits, MSB first.

    Any bits of ``i`` above position ``num_bits - 1`` are dropped.
    """
    # Visit bit positions from high to low so the result is already in the
    # usual (most-significant-first) order and needs no reversal.
    return np.array([(i >> shift) & 1 for shift in reversed(range(num_bits))])

In [68]:
# Flip back to least-significant-bit first: 4 == 0b100, so only index 2 is set.
to_binary(4, 12)[::-1]

array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [69]:
def get_label(i):
    """Return the FizzBuzz class index for ``i``.

    2 when ``i`` is divisible by 15, 1 when divisible by 5, 0 when divisible
    by 3, and 3 otherwise.
    """
    # Checked in order, so the most specific divisor (15) wins.
    for divisor, label in ((15, 2), (5, 1), (3, 0)):
        if i % divisor == 0:
            return label
    return 3

In [70]:
# Training set: every integer from 101 up to (but excluding) 2**BITS, encoded
# as bit vectors (inputs) and FizzBuzz class indices (targets).
train_numbers = range(101, 2**BITS)
train_x = np.array([to_binary(n, BITS) for n in train_numbers])
train_y = np.array([get_label(n) for n in train_numbers])

In [71]:
def create_net():
    """Build the classifier: BITS inputs -> 100 ReLU units -> one log-probability per label."""
    hidden_units = 100
    layers = [
        nn.Linear(BITS, hidden_units),
        nn.ReLU(),
        nn.Linear(hidden_units, len(label_text)),
        # Log-softmax over the last dimension, so outputs are log-probabilities.
        nn.LogSoftmax(dim=-1),
    ]
    return nn.Sequential(*layers)

In [82]:
# Negative log-likelihood loss; pairs with the LogSoftmax output layer built
# by create_net.
loss = nn.NLLLoss()

In [83]:
def accuracy(model, loss):
    """Return the fraction of the integers 1..99 that ``model`` labels correctly.

    Args:
        model: Callable taking a float tensor with one BITS-wide row per
            number and returning one row of per-class scores per input row.
        loss: Unused; kept so existing ``accuracy(model, loss)`` calls keep
            working.

    Returns:
        The accuracy as a Python float between 0 and 1.
    """
    # NOTE(review): 100 is in neither this range nor the training range
    # (which starts at 101) -- confirm whether leaving it out is intended.
    numbers = range(1, 100)
    x = torch.from_numpy(np.array([to_binary(n, BITS) for n in numbers])).float()
    y = torch.from_numpy(np.array([get_label(n) for n in numbers]))
    # Evaluation only: skip autograd bookkeeping, since this is called from
    # inside the training loop and the graph would never be used.
    with torch.no_grad():
        pred = torch.argmax(model(x), dim=1)
    acc = (y == pred).sum().float() / x.shape[0]
    return acc.item()

In [91]:
# Instantiate the network that the optimizer and training loop operate on.
model = create_net()

In [92]:
# SGD with Nesterov momentum over all of the model's parameters.
opt = torch.optim.SGD(
    model.parameters(), 
    lr=5e-3, 
    momentum=0.9, 
    nesterov=True)

Training Loop:

In [95]:
# Tensors for training: float inputs for the linear layers, integer class
# indices as targets for the NLL loss.
train_x_t = torch.from_numpy(train_x).float()
train_y_t = torch.from_numpy(train_y)

print_every = 100
batch_size = 32
test_pred = None  # not used in this cell; kept in case other cells read it

# Ceiling division: the last batch may be short but is never empty, even when
# the number of samples is an exact multiple of batch_size.
batches = (train_x_t.shape[0] + batch_size - 1) // batch_size

for i in range(500):
    for b in range(batches):
        start = b * batch_size
        end = start + batch_size
        bt = train_x_t[start:end]
        by = train_y_t[start:end]
        y_pred = model(bt)
        loss_val = loss(y_pred, by)
        model.zero_grad()
        loss_val.backward()
        opt.step()
        # Report once every `print_every` epochs, right after the first
        # batch's optimizer step.
        if i % print_every == 1 and b == 0:
            print(i, b, "Loss:", loss_val.item(), "test accuracy:", accuracy(model, loss))
# Final report, using the loss of the very last batch.
print(i, b, "Loss:", loss_val.item(), "test accuracy:", accuracy(model, loss))

1 0 Loss: 0.01220577210187912 test accuracy: 0.9797979593276978
101 0 Loss: 0.0042299311608076096 test accuracy: 1.0
201 0 Loss: 0.0035613637883216143 test accuracy: 1.0
301 0 Loss: 0.003199198981747031 test accuracy: 1.0
401 0 Loss: 0.002829307224601507 test accuracy: 1.0
499 124 Loss: 0.0019508247496560216 test accuracy: 1.0


In [109]:
def predict(num):
    """Classify ``num`` with the trained model.

    Args:
        num: Integer in ``[0, 2**BITS)``.

    Returns:
        A ``(num, label)`` tuple, where ``label`` is the ``label_text`` entry
        for the most probable class.

    Raises:
        ValueError: If ``num`` does not fit in BITS bits. ``to_binary`` keeps
            only the low BITS bits, so such a value would otherwise be
            silently encoded as a different number.
    """
    if not 0 <= num < 2 ** BITS:
        raise ValueError(
            "num must be in [0, {}), got {}".format(2 ** BITS, num))
    enc = torch.from_numpy(to_binary(num, BITS)).float()
    # Inference only: no gradients are needed.
    with torch.no_grad():
        # The model outputs log-probabilities; exp turns them into probabilities.
        pred = torch.exp(model(enc))
    index = torch.argmax(pred).item()
    return (num, label_text[index])

In [110]:
# Spot-check numbers below the training range (which starts at 101), covering every label.
predict(75), predict(10), predict(3), predict(21), predict(7)

((75, 'fizzbuzz'), (10, 'buzz'), (3, 'fizz'), (21, 'fizz'), (7, ''))