In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# FizzBuzz

In [2]:
# 一般实现

In [3]:
# 编码
def fizz_buzz_encode(i):
    if i % 15 == 0:
        return 3
    elif i % 5 == 0:
        return 2
    elif i % 3 == 0:
        return 1
    else:
        return 0

# 解码
def fizz_buzz_decode(i, prediction):
    return [str(i), 'fizz', 'buzz', 'fizzbuzz'][prediction]

# 帮助
def helper(i):
    print(fizz_buzz_decode(i=i, prediction=fizz_buzz_encode(i)))

for i in range(1, 16):
    helper(i)

1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz


## 神经网络学会玩FizzBuzz

In [4]:
# 我们首先定义模型的输入与输出（训练数据），这里把每一个数字转为二进制数据，模型容易学习

In [5]:
# 二进制数字，即将训练数据二进制化

In [6]:
NUM_DIGITS = 10  # 10个位数，所以最多只有1024个数字可以被表示出

def binary_encode(i, num_digits):
    return np.array([i >> d & 1 for d in range(num_digits)][::-1])

trX = torch.Tensor([binary_encode(i=i, num_digits=NUM_DIGITS) for i in range(101, 2 ** NUM_DIGITS)])
trY = torch.LongTensor([fizz_buzz_encode(i=i) for i in range(101, 2 ** NUM_DIGITS)])

In [7]:
# 然后我们用PyTorch定义两层神经网络

In [8]:
NUM_HIDDEN = 100
model = nn.Sequential(
    nn.Linear(in_features=NUM_DIGITS, out_features=NUM_HIDDEN),
    nn.ReLU(),
    nn.Linear(in_features=NUM_HIDDEN, out_features=4)  # 4 logits, after softmax, we will get a probability distribution
)
if torch.cuda.is_available():
    model = model.cuda()

In [9]:
# 为了让我们的模型学会FizzBuzz这个游戏，我们需要定义一个损失函数和一个优化算法。这个优化算法会不断优化（降低）损失函数，使得模型在该任务上取得尽可能低的损失值。由于FizzBuzz这个游戏本质上是一个分类问题，我们选用Cross Entropy Loss函数。优化函数我们选择SGD。

In [11]:
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.01)

In [12]:
# 模型训练

In [13]:
BATCH_SIZE = 128
for epoch in range(1000):
    for start in range(0, len(trX), BATCH_SIZE):  # 每BATCH_SIZE个数据截取一次数据
        end = start + BATCH_SIZE
        batchX = trX[start: end]
        batchY = trY[start: end]

        if torch.cuda.is_available():
            batchX = batchX.cuda()
            batchY = batchY.cuda()
        
        y_pred = model(batchX)
        loss = loss_fn(y_pred, batchY)

        print(f'Epoch：{epoch}; Loss：{loss.item()}')

        # 优化三部曲：1.清空之前的梯度；2.loss.backward()；3.一步更新optimizer.step()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

0.0019578970968723297
Epoch：933; Loss：0.000933770788833499
Epoch：934; Loss：0.0015001632273197174
Epoch：934; Loss：0.0014131814241409302
Epoch：934; Loss：0.002041023224592209
Epoch：934; Loss：0.0010188296437263489
Epoch：934; Loss：0.0025875698775053024
Epoch：934; Loss：0.0017970912158489227
Epoch：934; Loss：0.0019313544034957886
Epoch：934; Loss：0.0009228918352164328
Epoch：935; Loss：0.0014886707067489624
Epoch：935; Loss：0.0014076456427574158
Epoch：935; Loss：0.002028524875640869
Epoch：935; Loss：0.0010108593851327896
Epoch：935; Loss：0.0025868192315101624
Epoch：935; Loss：0.0017909817397594452
Epoch：935; Loss：0.0019331052899360657
Epoch：935; Loss：0.000924339983612299
Epoch：936; Loss：0.0014781318604946136
Epoch：936; Loss：0.0014014467597007751
Epoch：936; Loss：0.0020238086581230164
Epoch：936; Loss：0.0010090656578540802
Epoch：936; Loss：0.002567969262599945
Epoch：936; Loss：0.0017879493534564972
Epoch：936; Loss：0.0019254572689533234
Epoch：936; Loss：0.0009161101188510656
Epoch：937; Loss：0.001475464552640

In [58]:
# 最后用训练好的模型尝试在1到100这些数字上玩FizzBuzz游戏

In [14]:
testX = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(1, 101)])
if torch.cuda.is_available():
    testX = testX.cuda()
# 测试阶段参数无需梯度
with torch.no_grad():
    testY = model(testX)  # testY.shape=torch.Size([100, 4])

# predicts = zip(range(1, 101), list(testY.max))

In [16]:
testY.shape

torch.Size([100, 4])

In [43]:
predictions = zip(range(1, 101), testY.max(1)[1].cpu().data.tolist())
print([fizz_buzz_decode(i, x) for i, x in predictions])

['1', 'buzz', 'fizz', '4', 'buzz', 'fizz', '7', '8', 'fizz', '10', '11', '12', '13', '14', 'fizzbuzz', '16', '17', 'fizz', '19', 'buzz', 'fizz', '22', '23', 'fizz', 'buzz', '26', 'fizz', '28', '29', 'fizzbuzz', '31', 'buzz', 'fizz', 'buzz', 'buzz', 'fizz', '37', 'buzz', 'fizz', 'buzz', '41', 'fizz', '43', '44', 'fizzbuzz', '46', '47', 'fizz', '49', 'buzz', 'fizz', '52', 'fizz', 'fizz', 'buzz', '56', 'fizz', '58', '59', 'fizzbuzz', '61', '62', 'fizz', '64', '65', '66', '67', 'fizz', 'fizz', 'buzz', '71', '72', '73', '74', 'fizzbuzz', '76', '77', 'fizz', '79', 'buzz', 'fizz', '82', '83', 'buzz', 'fizz', '86', 'fizz', '88', '89', 'fizzbuzz', '91', '92', 'fizz', '94', 'buzz', 'fizz', '97', 'buzz', 'fizz', 'fizz']
