# Lesson 1

褚则伟 zeweichu@gmail.com

[Reference](https://pytorch.org/tutorials/beginner/pytorch_with_examples.html)


What is PyTorch?
================

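PyTorch is a Python-based scientific computing library with the following features:

- It is similar to NumPy, but it can use GPUs
- It can be used to define deep learning models, and to train and run them flexibly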

Tensors
---------------


A Tensor is similar to NumPy's ndarray; the main difference is that a Tensor can also run on a GPU to accelerate computation.


In [1]:
from __future__ import print_function
import torch

Construct an uninitialized 5x3 matrix:

In [2]:
x = torch.empty(5, 3)
print(x)

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 4.7339e+30, 1.4347e-19],
        [2.7909e+23, 1.8037e+28, 1.7237e+25],
        [9.1041e-12, 6.2609e+22, 4.7428e+30],
        [3.8001e-39, 0.0000e+00, 0.0000e+00]])


Construct a randomly initialized matrix:

In [3]:
x = torch.rand(5, 3)
print(x)

tensor([[0.4821, 0.3854, 0.8517],
        [0.7962, 0.0632, 0.5409],
        [0.8891, 0.6112, 0.7829],
        [0.0715, 0.8069, 0.2608],
        [0.3292, 0.0119, 0.2759]])


Construct a matrix filled with zeros and of dtype long:

In [4]:
x = torch.zeros(5, 3, dtype=torch.long)
print(x)

tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]])


Construct a tensor directly from data:

In [5]:
x = torch.tensor([5.5, 3])
print(x)

tensor([5.5000, 3.0000])


You can also create a tensor from an existing tensor. These methods reuse the properties of the input tensor, such as its dtype, unless new values are provided.

In [6]:
x = x.new_ones(5, 3, dtype=torch.double)      # new_* methods take in sizes
print(x)

x = torch.randn_like(x, dtype=torch.float)    # override dtype!
print(x)                                      # result has the same size

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], dtype=torch.float64)
tensor([[ 1.4793, -2.4772,  0.9738],
        [ 2.0328,  1.3981,  1.7509],
        [-0.7931, -0.0291, -0.6803],
        [-1.2944, -0.7352, -0.9346],
        [ 0.5917, -0.5149, -1.8149]])


Get the tensor's shape:

In [7]:
print(x.size())

torch.Size([5, 3])


<div class="alert alert-info"><h4>Note</h4><p>``torch.Size`` is a subclass of tuple, so it supports all tuple operations.</p></div>
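
A quick check of what the note means. This minimal sketch uses an arbitrary 5x3 tensor; because `torch.Size` subclasses `tuple`, a size can be unpacked, indexed and compared like any tuple.

```python
import torch

x = torch.rand(5, 3)
size = x.size()

# torch.Size is a subclass of tuple, so ordinary tuple operations work on it
print(isinstance(size, tuple))  # True
rows, cols = size               # tuple unpacking
print(rows, cols)               # 5 3
print(size[0], len(size))       # 5 2
print(x.shape == size)          # True: .shape gives the same torch.Size
```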

Operations
----------


There are many tensor operations. Let's start with addition.



In [8]:
y = torch.rand(5, 3)
print(x + y)

tensor([[ 1.7113, -1.5490,  1.4009],
        [ 2.4590,  1.6504,  2.6889],
        [-0.3609,  0.4950, -0.3357],
        [-0.5029, -0.3086, -0.1498],
        [ 1.2850, -0.3189, -0.8868]])


Addition: an alternative syntax


In [9]:
print(torch.add(x, y))

tensor([[ 1.7113, -1.5490,  1.4009],
        [ 2.4590,  1.6504,  2.6889],
        [-0.3609,  0.4950, -0.3357],
        [-0.5029, -0.3086, -0.1498],
        [ 1.2850, -0.3189, -0.8868]])


Addition: providing an output tensor as an argument

In [10]:
result = torch.empty(5, 3)
torch.add(x, y, out=result)
print(result)

tensor([[ 1.7113, -1.5490,  1.4009],
        [ 2.4590,  1.6504,  2.6889],
        [-0.3609,  0.4950, -0.3357],
        [-0.5029, -0.3086, -0.1498],
        [ 1.2850, -0.3189, -0.8868]])


Addition: in-place

In [11]:
# adds x to y
y.add_(x)
print(y)

tensor([[ 1.7113, -1.5490,  1.4009],
        [ 2.4590,  1.6504,  2.6889],
        [-0.3609,  0.4950, -0.3357],
        [-0.5029, -0.3086, -0.1498],
        [ 1.2850, -0.3189, -0.8868]])


<div class="alert alert-info"><h4>Note</h4><p>Any operation that modifies a tensor in place is post-fixed with ``_``.
    For example, ``x.copy_(y)`` and ``x.t_()`` will change ``x``.</p></div>
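
To make the naming rule concrete, here is a small sketch contrasting the out-of-place `add` with the in-place `add_` on a tensor of ones (the values are arbitrary). The in-place call writes into `x`'s own storage, which we check by comparing `data_ptr()`.

```python
import torch

x = torch.ones(3)

# Out-of-place: returns a new tensor and leaves x unchanged
y = x.add(1)
print(x)  # tensor([1., 1., 1.])
print(y)  # tensor([2., 2., 2.])

# In-place (trailing underscore): modifies x itself and returns it
z = x.add_(1)
print(x)                              # tensor([2., 2., 2.])
print(z.data_ptr() == x.data_ptr())  # True: same underlying storage
```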

All the usual NumPy-style indexing works on PyTorch tensors.


In [12]:
print(x[:, 1])

tensor([-2.4772,  1.3981, -0.0291, -0.7352, -0.5149])


Resizing: if you want to resize or reshape a tensor, use ``Tensor.view``:

In [13]:
x = torch.randn(4, 4)
y = x.view(16)
z = x.view(-1, 8)  # the size -1 is inferred from other dimensions
print(x.size(), y.size(), z.size())

torch.Size([4, 4]) torch.Size([16]) torch.Size([2, 8])
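
One caveat worth knowing: `view` only works when the requested shape is compatible with the tensor's memory layout. The sketch below (a small 2x3 example of our own) transposes a tensor, which makes it non-contiguous, so `view` raises a `RuntimeError` while `reshape` succeeds by copying when it has to.

```python
import torch

x = torch.arange(6).view(2, 3)   # contiguous 2x3 tensor: [[0, 1, 2], [3, 4, 5]]
t = x.t()                        # transpose: a 3x2 view with swapped strides
print(t.is_contiguous())         # False

# view() needs a memory layout compatible with the new shape,
# so it fails on this non-contiguous tensor
view_failed = False
try:
    t.view(6)
except RuntimeError:
    view_failed = True
print(view_failed)               # True

# reshape() returns a view when possible and copies otherwise
flat = t.reshape(6)
print(flat)                      # tensor([0, 3, 1, 4, 2, 5])

# Equivalent: make the data contiguous first, then view it
print(torch.equal(flat, t.contiguous().view(6)))  # True
```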


If you have a one-element tensor, use ``.item()`` to get its value as a Python number.

In [14]:
x = torch.randn(1)
print(x)
print(x.item())

tensor([0.4726])
0.4726296067237854


**Further reading**


  A wide range of Tensor operations, including transposing, indexing, slicing,
  mathematical operations, linear algebra, and random numbers, is documented at
  <https://pytorch.org/docs/torch>.

Converting between NumPy and Tensors
------------

Converting a Torch Tensor to a NumPy array and back is easy.

The Torch Tensor and the NumPy array share the same underlying memory (for tensors on the CPU), so changing one also changes the other.

Converting a Torch Tensor to a NumPy array


In [15]:
a = torch.ones(5)
print(a)

tensor([1., 1., 1., 1., 1.])


In [16]:
b = a.numpy()
print(b)

[1. 1. 1. 1. 1.]


Modify the tensor in place and observe that the NumPy array changes as well.

In [17]:
a.add_(1)
print(a)
print(b)

tensor([2., 2., 2., 2., 2.])
[2. 2. 2. 2. 2.]


Converting a NumPy ndarray to a Torch Tensor (again, the two share memory, so updating the array also updates the tensor):

In [18]:
import numpy as np
a = np.ones(5)
b = torch.from_numpy(a)
np.add(a, 1, out=a)
print(a)
print(b)

[2. 2. 2. 2. 2.]
tensor([2., 2., 2., 2., 2.], dtype=torch.float64)


Tensors on the CPU can be converted to NumPy and back; a CUDA tensor must be moved to the CPU first.
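
Memory sharing only applies to `torch.from_numpy` and `.numpy()`; `torch.tensor(...)` always copies its input. In addition, a tensor with `requires_grad=True` (introduced in the autograd sections below) has to be detached before `.numpy()` can be called. A small sketch with arbitrary values:

```python
import numpy as np
import torch

a = np.ones(3)
shared = torch.from_numpy(a)   # shares memory with a
copied = torch.tensor(a)       # always copies the data

a += 1                         # in-place NumPy update
print(shared)  # tensor([2., 2., 2.], dtype=torch.float64)
print(copied)  # tensor([1., 1., 1.], dtype=torch.float64)

# A tensor that requires grad cannot be converted directly...
w = torch.ones(2, requires_grad=True)
numpy_failed = False
try:
    w.numpy()
except RuntimeError:
    numpy_failed = True

# ...but detaching it from the autograd graph first works
w_np = w.detach().numpy()
print(numpy_failed, w_np)  # True [1. 1.]
```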

CUDA Tensors
------------

Tensors can be moved to any device with the ``.to`` method.



In [19]:
# let us run this cell only if CUDA is available
# We will use ``torch.device`` objects to move tensors in and out of GPU
if torch.cuda.is_available():
    device = torch.device("cuda")          # a CUDA device object
    y = torch.ones_like(x, device=device)  # directly create a tensor on GPU
    x = x.to(device)                       # or just use strings ``.to("cuda")``
    z = x + y
    print(z)
    print(z.to("cpu", torch.double))       # ``.to`` can also change dtype together!
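
The cell above only runs when CUDA is present. A common device-agnostic pattern, sketched below, selects the GPU when one is available and otherwise falls back to the CPU, so the same code runs on either machine.

```python
import torch

# Pick the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = torch.randn(3, device=device)   # created directly on `device`
y = torch.ones(3).to(device)        # created on the CPU, then moved to `device`
z = x + y                           # both operands live on the same device
z_cpu = z.to("cpu", torch.double)   # move back and change dtype in one call
print(z.device, z_cpu.device, z_cpu.dtype)
```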


Warm-up: a two-layer neural network in NumPy
--------------

A fully-connected ReLU network with one hidden layer and no biases, trained to predict y from x with an L2 loss.

This implementation uses NumPy alone to compute the forward pass, the loss, and the backward pass.

A NumPy ndarray is a generic n-dimensional array. It knows nothing about deep learning, gradients, or computation graphs; it is simply a data structure for numerical computation.



In [20]:
import numpy as np

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = np.random.randn(N, D_in)
y = np.random.randn(N, D_out)

# Randomly initialize weights
w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.dot(w1)
    h_relu = np.maximum(h, 0)
    y_pred = h_relu.dot(w2)

    # Compute and print loss
    loss = np.square(y_pred - y).sum()
    print(t, loss)

    # Backprop to compute gradients of w1 and w2 with respect to loss
    # loss = sum((y_pred - y) ** 2), so d(loss)/d(y_pred) = 2 * (y_pred - y)
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.T.dot(grad_y_pred)
    grad_h_relu = grad_y_pred.dot(w2.T)
    grad_h = grad_h_relu.copy()
    grad_h[h < 0] = 0
    grad_w1 = x.T.dot(grad_h)

    # Update weights
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 34399246.46047344
1 29023199.257758312
2 25155679.85447208
3 20344203.603057466
4 14771404.625789404
5 9796072.99431371
6 6194144.749997159
7 3948427.3657580013
8 2637928.1726997104
9 1879876.2597949505
10 1424349.925182723
11 1131684.579785501
12 930879.9521737935
13 783503.167740541
14 669981.8287784329
15 579151.6288421676
16 504610.5781504087
17 442295.18952143926
18 389647.44224490353
19 344718.3535892912
20 306120.2245707266
21 272728.24885829526
22 243778.8617292929
23 218485.92082002352
24 196304.70602822883
25 176774.2980280186
26 159509.34934842546
27 144200.52956072442
28 130597.06878493169
29 118484.47548850597
30 107661.24303895692
31 97973.75762285746
32 89291.0096051952
33 81500.46898789635
34 74477.4654945682
35 68139.90452489533
36 62418.87519034026
37 57241.53801123622
38 52545.34658231941
39 48280.5552386464
40 44399.73653914068
41 40864.495617471934
42 37640.08489317873
43 34695.77852549495
44 32004.894008637555
45 29545.09481447049
46 27292.93700341219
47 25232.8

367 0.0004326546423136559
368 0.0004116382458083261
369 0.0003916440959886334
370 0.0003726296356534275
371 0.0003545443586216977
372 0.000337347352488608
373 0.00032099061370803334
374 0.0003054229784132819
375 0.00029061647064382485
376 0.0002765299098361774
377 0.0002631327221101076
378 0.0002503865963973947
379 0.0002382599294869431
380 0.00022672670184804494
381 0.00021575299560298047
382 0.00020531375263207438
383 0.000195381616896771
384 0.00018593500698085453
385 0.00017694494225329907
386 0.00016839225855899982
387 0.00016025517275686525
388 0.00015251350815142156
389 0.0001451491411549753
390 0.0001381428245892601
391 0.00013147417414693054
392 0.00012512977608770297
393 0.00011909308605343111
394 0.00011334857979979945
395 0.00010788480695473414
396 0.00010268704883570024
397 9.773868892276339e-05
398 9.303020197524704e-05
399 8.85491663624475e-05
400 8.428485316645869e-05
401 8.022778747190388e-05
402 7.636668153099922e-05
403 7.269236014951034e-05
404 6.919607836124983e-05
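
Hand-derived gradients like the ones in the warm-up cell are easy to get wrong, so it can be worth checking them against central finite differences. This is a minimal sketch, assuming tiny hypothetical sizes (N=4, D_in=5, H=3, D_out=2) so the loop stays cheap; the analytic formulas are copied from the warm-up cell, and `loss_fn` / `numeric_grad` are helper names introduced here.

```python
import numpy as np

# Tiny hypothetical sizes so the finite-difference loop is cheap
rng = np.random.default_rng(0)
N, D_in, H, D_out = 4, 5, 3, 2
x = rng.standard_normal((N, D_in))
y = rng.standard_normal((N, D_out))
w1 = rng.standard_normal((D_in, H))
w2 = rng.standard_normal((H, D_out))

def loss_fn(w1, w2):
    h_relu = np.maximum(x.dot(w1), 0)
    return np.square(h_relu.dot(w2) - y).sum()

# Analytic gradients, same formulas as the warm-up cell
h = x.dot(w1)
h_relu = np.maximum(h, 0)
grad_y_pred = 2.0 * (h_relu.dot(w2) - y)
grad_w2 = h_relu.T.dot(grad_y_pred)
grad_h = grad_y_pred.dot(w2.T)
grad_h[h < 0] = 0
grad_w1 = x.T.dot(grad_h)

def numeric_grad(f, w, eps=1e-6):
    """Central finite differences, perturbing one entry of w at a time."""
    grad = np.zeros_like(w)
    for idx in np.ndindex(w.shape):
        old = w[idx]
        w[idx] = old + eps
        plus = f()
        w[idx] = old - eps
        minus = f()
        w[idx] = old  # restore the original value
        grad[idx] = (plus - minus) / (2 * eps)
    return grad

num_w1 = numeric_grad(lambda: loss_fn(w1, w2), w1)
num_w2 = numeric_grad(lambda: loss_fn(w1, w2), w2)
print(np.max(np.abs(num_w1 - grad_w1)), np.max(np.abs(num_w2 - grad_w2)))
```

Both printed differences should be close to zero; a large value would point to a mistake in the backward formulas.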


PyTorch: Tensors
----------------

This time we use PyTorch Tensors to implement the forward pass, the loss, and the backward pass.

A PyTorch Tensor is much like a NumPy ndarray. The biggest difference is that a PyTorch Tensor can run on either the CPU or the GPU; to run operations on a GPU, place the Tensor on a CUDA device.
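
Moving from the NumPy version to the PyTorch version is mostly a matter of renaming operations: `.dot` becomes `.mm`, `np.maximum(h, 0)` becomes `.clamp(min=0)`, `.T` becomes `.t()`, and `.copy()` becomes `.clone()`. A minimal sketch, assuming small random inputs of our own choosing, that checks both versions of one hidden-layer step agree:

```python
import numpy as np
import torch

rng = np.random.default_rng(0)
x_np = rng.standard_normal((4, 5))
w_np = rng.standard_normal((5, 3))

# NumPy version of the hidden layer from the warm-up
h_np = np.maximum(x_np.dot(w_np), 0)

# PyTorch version: .dot -> .mm, np.maximum(., 0) -> .clamp(min=0)
x_t = torch.from_numpy(x_np)
w_t = torch.from_numpy(w_np)
h_t = x_t.mm(w_t).clamp(min=0)

same_hidden = np.allclose(h_np, h_t.numpy())
# .T -> .t()
same_transpose = torch.equal(w_t.t(), torch.from_numpy(w_np.T.copy()))
print(same_hidden, same_transpose)  # True True
```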


In [21]:
import torch


dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.mm(w1)
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum().item()
    print(t, loss)

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

    # Update weights using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

0 31704728.0
1 25331164.0
2 22378086.0
3 19262238.0
4 15348289.0
5 11017595.0
6 7356282.0
7 4705923.5
8 3027346.5
9 2012536.375
10 1409662.25
11 1041771.75
12 807321.0625
13 649262.0
14 536533.1875
15 451980.875
16 385983.53125
17 332925.53125
18 289368.1875
19 253030.78125
20 222354.703125
21 196214.3125
22 173766.515625
23 154378.140625
24 137539.375
25 122867.1015625
26 110037.3515625
27 98769.4921875
28 88842.109375
29 80063.15625
30 72279.015625
31 65361.66796875
32 59195.42578125
33 53687.4453125
34 48757.57421875
35 44338.4453125
36 40370.34765625
37 36803.1484375
38 33587.4453125
39 30684.1640625
40 28059.435546875
41 25683.255859375
42 23528.814453125
43 21570.8515625
44 19792.4296875
45 18175.244140625
46 16704.6640625
47 15364.2578125
48 14141.7509765625
49 13026.609375
50 12007.3115234375
51 11075.3896484375
52 10221.8857421875
53 9439.876953125
54 8722.13671875
55 8063.46826171875
56 7458.20703125
57 6901.8876953125
58 6390.34375
59 5919.4794921875
60 5485.79345703125
61 5

375 0.0002844816190190613
376 0.00027625024085864425
377 0.0002687727683223784
378 0.0002608516369946301
379 0.00025311342324130237
380 0.0002469048195052892
381 0.00024049097555689514
382 0.0002342124644201249
383 0.00022811403323430568
384 0.00022231723414734006
385 0.0002166029589716345
386 0.00021077181736472994
387 0.00020510501053649932
388 0.00020020001102238894
389 0.0001948442222783342
390 0.00018990584067068994
391 0.00018529882072471082
392 0.00018070911755785346
393 0.00017650797963142395
394 0.00017214834224432707
395 0.0001683011942077428
396 0.00016451899136882275
397 0.00016050187696237117
398 0.00015686434926465154
399 0.00015321985119953752
400 0.0001501761726103723
401 0.00014639270375482738
402 0.00014274154091253877
403 0.0001396275474689901
404 0.0001364489580737427
405 0.00013346801279112697
406 0.00013024920190218836
407 0.00012755846546497196
408 0.00012532222899608314
409 0.0001224723382620141
410 0.00011974618973908946
411 0.00011740042100427672
412 0.0001144

A simple autograd example

In [22]:
# Create tensors.
x = torch.tensor(1., requires_grad=True)
w = torch.tensor(2., requires_grad=True)
b = torch.tensor(3., requires_grad=True)

# Build a computational graph.
y = w * x + b    # y = 2 * x + 3

# Compute gradients.
y.backward()

# Print out the gradients.
print(x.grad)    # x.grad = 2 
print(w.grad)    # w.grad = 1 
print(b.grad)    # b.grad = 1 

tensor(2.)
tensor(1.)
tensor(1.)
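
One detail that explains the `grad.zero_()` and `zero_grad()` calls in the training loops later on: `backward()` accumulates into `.grad` instead of overwriting it. The sketch below reuses the same `w * x + b` graph and calls `backward()` twice.

```python
import torch

x = torch.tensor(1., requires_grad=True)
w = torch.tensor(2., requires_grad=True)
b = torch.tensor(3., requires_grad=True)

# First backward pass: dy/dx = w = 2
y = w * x + b
y.backward()
first = x.grad.item()
print(first)   # 2.0

# Rebuild the graph and call backward() again: the new gradient is ADDED
y = w * x + b
y.backward()
second = x.grad.item()
print(second)  # 4.0

# Reset the gradient in place before the next backward pass
x.grad.zero_()
print(x.grad)  # tensor(0.)
```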



PyTorch: Tensors and autograd
-------------------------------

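One of PyTorch's most important features is autograd: once you have defined the forward pass and computed the loss, PyTorch can automatically compute the gradients of the loss with respect to all model parameters.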

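A PyTorch Tensor represents a node in the computation graph. If ``x`` is a Tensor with ``x.requires_grad=True``, then ``x.grad`` is another Tensor holding the gradient of some scalar value (usually the loss) with respect to ``x``.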
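
To see that autograd computes exactly what the manual backward pass of the previous section computed, the sketch below builds the same two-layer network with tiny hypothetical sizes, lets `loss.backward()` fill `w1.grad` and `w2.grad`, and compares them against the hand-written formulas. The tolerance `atol=1e-4` is our own choice to absorb float32 rounding.

```python
import torch

torch.manual_seed(0)
N, D_in, H, D_out = 4, 5, 3, 2   # tiny hypothetical sizes
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)
w1 = torch.randn(D_in, H, requires_grad=True)
w2 = torch.randn(H, D_out, requires_grad=True)

# Autograd: build the graph in the forward pass, then call backward()
loss = (x.mm(w1).clamp(min=0).mm(w2) - y).pow(2).sum()
loss.backward()

# Manual backward pass, as in the "PyTorch: Tensors" section
with torch.no_grad():
    h = x.mm(w1)
    h_relu = h.clamp(min=0)
    grad_y_pred = 2.0 * (h_relu.mm(w2) - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h = grad_y_pred.mm(w2.t())
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

ok_w1 = torch.allclose(w1.grad, grad_w1, atol=1e-4)
ok_w2 = torch.allclose(w2.grad, grad_w2, atol=1e-4)
print(ok_w1, ok_w2)  # True True
```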


In [23]:
import torch

dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

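# Create random Tensors to hold inputs and outputs.
# requires_grad=False (the default) means we do not need gradients
# with respect to these Tensors during the backward pass.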
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Create random Tensors for the weights.
# requires_grad=True means we want gradients with respect to these
# Tensors during the backward pass.
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(500):
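    # Forward pass: compute predicted y using operations on Tensors. This is the
    # same forward pass as before, but we no longer need to keep the intermediate
    # values, since we are not implementing the backward pass by hand.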
    y_pred = x.mm(w1).clamp(min=0).mm(w2)

    # Compute the loss from the forward pass.
    # loss is a 0-dimensional (scalar) Tensor of shape ();
    # loss.item() returns its value as a Python number.
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Use autograd to compute the backward pass. This computes the gradient of
    # loss with respect to every Tensor that has requires_grad=True. Afterwards,
    # w1.grad and w2.grad hold the gradients of loss with respect to w1 and w2.
    loss.backward()

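    # Update the weights manually with gradient descent (an automatic way is
    # introduced later). Wrap the update in torch.no_grad(): w1 and w2 have
    # requires_grad=True, but we do not want autograd to track the update.
    # An alternative is to operate on weight.data and weight.grad.data:
    # tensor.data returns a tensor that shares memory with the original
    # tensor but does not record computation history.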
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # Manually zero the gradients after updating weights
        w1.grad.zero_()
        w2.grad.zero_()

0 31590738.0
1 34389704.0
2 44504280.0
3 52598508.0
4 46752264.0
5 27227634.0
6 10779343.0
7 3889138.75
8 1856397.875
9 1232127.25
10 967278.5
11 806383.9375
12 687169.25
13 591936.25
14 513579.40625
15 448339.5
16 393390.71875
17 346772.71875
18 306952.625
19 272743.90625
20 243250.578125
21 217760.4375
22 195513.75
23 176012.4375
24 158848.59375
25 143694.4375
26 130272.53125
27 118357.1328125
28 107732.5625
29 98245.9296875
30 89754.4375
31 82145.9765625
32 75299.703125
33 69130.7265625
34 63549.09375
35 58498.18359375
36 53914.7421875
37 49751.984375
38 45963.8515625
39 42512.19140625
40 39364.1484375
41 36486.7421875
42 33852.94921875
43 31441.951171875
44 29230.11328125
45 27200.080078125
46 25335.595703125
47 23618.97265625
48 22036.193359375
49 20575.412109375
50 19227.5078125
51 17980.865234375
52 16826.919921875
53 15756.392578125
54 14762.513671875
55 13839.58203125
56 12981.9228515625
57 12184.3896484375
58 11442.140625
59 10750.8681640625
60 10106.751953125
61 9505.8720703

394 0.005430158693343401
395 0.005243257619440556
396 0.005058295093476772
397 0.0048800683580338955
398 0.004707938991487026
399 0.004541801754385233
400 0.004385354463011026
401 0.0042332010343670845
402 0.0040851193480193615
403 0.003942274488508701
404 0.003809330752119422
405 0.0036788880825042725
406 0.0035530496388673782
407 0.0034328829497098923
408 0.003316469956189394
409 0.0032058244105428457
410 0.003095718566328287
411 0.002996482653543353
412 0.002896404592320323
413 0.002801347989588976
414 0.0027062646113336086
415 0.0026161009445786476
416 0.002530781552195549
417 0.002449025632813573
418 0.002370838774368167
419 0.002294242149218917
420 0.002220114693045616
421 0.002151642693206668
422 0.0020829373970627785
423 0.0020190104842185974
424 0.0019563380628824234
425 0.0018947365460917354
426 0.0018343634437769651
427 0.0017779992194846272
428 0.0017241643508896232
429 0.001670036930590868
430 0.0016198739176616073
431 0.0015696510672569275
432 0.0015243508387356997
433 0.
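
Why the update above sits inside `torch.no_grad()`: operations on a tensor with `requires_grad=True` are normally recorded, and PyTorch refuses to modify such a leaf tensor in place while recording. A small sketch with an arbitrary 3-element weight:

```python
import torch

w = torch.randn(3, requires_grad=True)

# Outside no_grad, operations on w are recorded in the graph
recorded = (w * 2).requires_grad
print(recorded)  # True

# Inside torch.no_grad() nothing is recorded
with torch.no_grad():
    inside = (w * 2).requires_grad
print(inside)    # False

# An in-place update of a leaf that requires grad fails while recording...
inplace_failed = False
try:
    w -= 0.1
except RuntimeError:
    inplace_failed = True
print(inplace_failed)  # True

# ...but is allowed inside no_grad, and w keeps requires_grad=True
with torch.no_grad():
    w -= 0.1
print(w.requires_grad)  # True
```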


PyTorch: nn
-----------


This time we build the network with PyTorch's nn package.
PyTorch autograd still builds the computation graph and computes
the gradients for us automatically.
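
A short sketch of what the `nn` model in the next cell contains, using the same sizes: each `nn.Linear(in, out)` stores a weight of shape `(out, in)` and a bias of shape `(out,)`, all reachable through `model.parameters()`, and `MSELoss(reduction='sum')` is the same sum of squared errors written by hand earlier. The batch of 8 random samples is our own choice.

```python
import torch

D_in, H, D_out = 1000, 100, 10
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# nn.Linear(in, out): weight has shape (out, in), bias has shape (out,)
print(model[0].weight.shape, model[0].bias.shape)  # torch.Size([100, 1000]) torch.Size([100])

# All learnable tensors are reachable through model.parameters()
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 1000*100 + 100 + 100*10 + 10 = 101110

# MSELoss(reduction='sum') equals the hand-written sum of squared errors
x = torch.randn(8, D_in)
y = torch.randn(8, D_out)
y_pred = model(x)
loss_fn = torch.nn.MSELoss(reduction='sum')
same_loss = torch.allclose(loss_fn(y_pred, y), (y_pred - y).pow(2).sum())
print(same_loss)  # True
```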




In [2]:
import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model as a sequence of layers. nn.Sequential
# is a Module which contains other Modules, and applies them in sequence to
# produce its output. Each Linear Module computes output from input using a
# linear function, and holds internal Tensors for its weight and bias.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
torch.nn.init.normal_(model[0].weight)
torch.nn.init.normal_(model[2].weight)

# The nn package also contains definitions of popular loss functions; in this
# case we will use Mean Squared Error (MSE) as our loss function.
loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__ operator so you can call them like functions. When
    # doing so you pass a Tensor of input data to the Module and it produces
    # a Tensor of output data.
    y_pred = model(x)

    # Compute and print loss. We pass Tensors containing the predicted and true
    # values of y, and the loss function returns a Tensor containing the
    # loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Zero the gradients before running the backward pass.
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    # parameters of the model. Internally, the parameters of each Module are stored
    # in Tensors with requires_grad=True, so this call will compute gradients for
    # all learnable parameters in the model.
    loss.backward()

    # Update the weights using gradient descent. Each parameter is a Tensor, so
    # we can access its gradients like we did before.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

0 35284576.0
1 32432502.0
2 33032056.0
3 31040754.0
4 24188534.0
5 15314465.0
6 8266213.0
7 4301408.0
8 2409370.5
9 1536877.125
10 1103504.25
11 857924.6875
12 697947.75
13 582105.25
14 492513.09375
15 420523.875
16 361562.375
17 312568.5
18 271470.28125
19 236731.046875
20 207311.296875
21 182228.4375
22 160721.953125
23 142198.15625
24 126180.734375
25 112269.6328125
26 100122.375
27 89506.40625
28 80192.046875
29 71993.2578125
30 64751.7578125
31 58336.12890625
32 52642.5234375
33 47582.78125
34 43079.55078125
35 39059.6484375
36 35458.6875
37 32231.09765625
38 29333.8515625
39 26726.8984375
40 24377.580078125
41 22257.607421875
42 20341.6953125
43 18609.6015625
44 17039.90234375
45 15616.369140625
46 14323.9052734375
47 13149.0888671875
48 12079.79296875
49 11105.9375
50 10217.779296875
51 9407.205078125
52 8666.0849609375
53 7988.56201171875
54 7368.7734375
55 6800.861328125
56 6280.4072265625
57 5803.27587890625
58 5365.1240234375
59 4962.60009765625
60 4592.69189453125
61 4252.4


PyTorch: optim
--------------

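This time, instead of updating the model's weights by hand, we use the optim package to update the parameters.
The optim package provides a variety of optimization algorithms, including SGD with momentum, RMSProp, and Adam.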
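
As a sanity check on what an optimizer does, this sketch (a small hypothetical `nn.Linear(4, 2)` model and random data) shows that one step of plain `torch.optim.SGD`, with no momentum or weight decay, applies the same rule as the manual loop from the previous section: `param -= learning_rate * param.grad`. Adam keeps per-parameter state, so its update is different.

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 4)
y = torch.randn(8, 2)
learning_rate = 1e-2

model = torch.nn.Linear(4, 2)
# Snapshot of the initial parameters, to be updated by hand
manual = [p.detach().clone() for p in model.parameters()]

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
loss = (model(x) - y).pow(2).sum()
optimizer.zero_grad()
loss.backward()

# Manual rule from the previous section: param - learning_rate * param.grad
manual = [m - learning_rate * p.grad for m, p in zip(manual, model.parameters())]

optimizer.step()  # plain SGD applies the same rule to the real parameters

same = all(torch.allclose(p, m) for p, m in zip(model.parameters(), manual))
print(same)  # True
```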


In [8]:
import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use Adam; the optim package contains many other
# optimization algorithms. The first argument to the Adam constructor tells the
# optimizer which Tensors it should update.
# Alternatively, plain SGD:
# learning_rate = 1e-4
# optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for t in range(500):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the variables it will update (which are the learnable
    # weights of the model). This is because by default, gradients are
    # accumulated in buffers (i.e., not overwritten) whenever .backward()
    # is called. Check out the docs of torch.autograd.backward for more details.
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model
    # parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters
    optimizer.step()

0 665.336669921875
1 648.9841918945312
2 633.0789794921875
3 617.599365234375
4 602.4755859375
5 587.7586669921875
6 573.465087890625
7 559.5175170898438
8 545.9730834960938
9 532.826904296875
10 520.11474609375
11 507.81353759765625
12 495.8321533203125
13 484.17486572265625
14 472.873291015625
15 461.8371276855469
16 451.0406494140625
17 440.4892272949219
18 430.27142333984375
19 420.3191223144531
20 410.5938720703125
21 401.08624267578125
22 391.8051452636719
23 382.7082214355469
24 373.7925109863281
25 365.119140625
26 356.6669921875
27 348.3946838378906
28 340.3127746582031
29 332.4108581542969
30 324.68994140625
31 317.16448974609375
32 309.8135070800781
33 302.61700439453125
34 295.558837890625
35 288.6358947753906
36 281.82830810546875
37 275.2060852050781
38 268.714111328125
39 262.3332214355469
40 256.0782470703125
41 249.93118286132812
42 243.91510009765625
43 238.0201416015625
44 232.23870849609375
45 226.56715393066406
46 221.00692749023438
47 215.56346130371094
48 210.232


PyTorch: Custom nn Modules
--------------------------

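You can define your own model by subclassing nn.Module. Whenever you need a model more complex than a simple sequence of existing Modules (which is what nn.Sequential provides), define it as an nn.Module subclass with its own ``forward`` method.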
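
One thing to keep in mind when writing a subclass: `nn.Module` only registers submodules assigned directly as attributes (or held in containers such as `nn.ModuleList`). Layers kept in a plain Python list are invisible to `model.parameters()`, so an optimizer would never update them. The three class names below are hypothetical, made up for this illustration.

```python
import torch

class RegisteredNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Modules assigned as attributes are registered automatically
        self.linear1 = torch.nn.Linear(4, 3)
        self.linear2 = torch.nn.Linear(3, 2)

class UnregisteredNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A plain Python list hides the layers from nn.Module
        self.layers = [torch.nn.Linear(4, 3), torch.nn.Linear(3, 2)]

class ModuleListNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers every layer it holds
        self.layers = torch.nn.ModuleList([torch.nn.Linear(4, 3), torch.nn.Linear(3, 2)])

# Count parameter tensors (each Linear has a weight and a bias)
counts = {type(net).__name__: len(list(net.parameters()))
          for net in (RegisteredNet(), UnregisteredNet(), ModuleListNet())}
print(counts)  # {'RegisteredNet': 4, 'UnregisteredNet': 0, 'ModuleListNet': 4}
```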



In [26]:
import torch


class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we instantiate two nn.Linear modules and assign them as
        member variables.
        """
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = TwoLayerNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters of the two
# nn.Linear modules which are members of the model.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 656.6958618164062
1 608.1090087890625
2 566.172607421875
3 529.2335815429688
4 496.7382507324219
5 467.453125
6 440.5755310058594
7 416.12872314453125
8 393.6068420410156
9 372.708251953125
10 353.00006103515625
11 334.477783203125
12 316.97283935546875
13 300.36737060546875
14 284.6544189453125
15 269.65936279296875
16 255.33456420898438
17 241.66688537597656
18 228.60800170898438
19 216.09536743164062
20 204.13780212402344
21 192.75645446777344
22 181.89234924316406
23 171.58370971679688
24 161.7939453125
25 152.4780731201172
26 143.59371948242188
27 135.14727783203125
28 127.13992309570312
29 119.55585479736328
30 112.37797546386719
31 105.62073516845703
32 99.24383544921875
33 93.24134826660156
34 87.58341979980469
35 82.25212860107422
36 77.24210357666016
37 72.55087280273438
38 68.1427230834961
39 64.00277709960938
40 60.1308479309082
41 56.49887466430664
42 53.0952033996582
43 49.906524658203125
44 46.91959762573242
45 44.11970520019531
46 41.50297164916992
47 39.0628700256347

375 0.0006505012279376388
376 0.0006343786371871829
377 0.0006186614627949893
378 0.0006033276440575719
379 0.0005883832345716655
380 0.0005738206673413515
381 0.0005596213741227984
382 0.0005457888473756611
383 0.0005322962533682585
384 0.0005191444652155042
385 0.000506328884512186
386 0.0004938290221616626
387 0.00048163760220631957
388 0.0004697689728345722
389 0.0004582055553328246
390 0.00044691533548757434
391 0.00043590739369392395
392 0.0004251690406817943
393 0.0004147063591517508
394 0.00040450665983371437
395 0.0003945553908124566
396 0.0003848606429528445
397 0.00037539892946369946
398 0.0003661849768832326
399 0.00035720854066312313
400 0.000348439411027357
401 0.0003398970584385097
402 0.00033156739664264023
403 0.0003234421892557293
404 0.0003155224258080125
405 0.00030779733788222075
406 0.0003002593875862658
407 0.00029291390092112124
408 0.00028574312455020845
409 0.0002787590492516756
410 0.000271946337306872
411 0.00026530082686804235
412 0.00025882109184749424
413

# FizzBuzz

FizzBuzz is a simple game. The rules: count up from 1; when you reach a multiple of 3, say "fizz"; a multiple of 5, say "buzz"; a multiple of 15, say "fizzbuzz"; otherwise, just say the number.

We can write a small program that decides whether to return the number itself or fizz, buzz, or fizzbuzz.

In [9]:
# Encode the desired output as a class index:
# 0 = the number itself, 1 = "fizz", 2 = "buzz", 3 = "fizzbuzz"
def fizz_buzz_encode(i):
    if   i % 15 == 0: return 3
    elif i % 5  == 0: return 2
    elif i % 3  == 0: return 1
    else:             return 0
    
def fizz_buzz_decode(i, prediction):
    return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]

print(fizz_buzz_decode(1, fizz_buzz_encode(1)))
print(fizz_buzz_decode(2, fizz_buzz_encode(2)))
print(fizz_buzz_decode(5, fizz_buzz_encode(5)))
print(fizz_buzz_decode(12, fizz_buzz_encode(12)))
print(fizz_buzz_decode(15, fizz_buzz_encode(15)))

1
2
buzz
fizz
fizzbuzz


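First we define the model's inputs and outputs (the training data). Each number is represented by its binary digits, and we train on the numbers 101 to 1023 (2 ** NUM_DIGITS - 1), so 1 to 100 are never seen during training.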

In [10]:
import numpy as np
import torch

NUM_DIGITS = 10

# Represent each input by an array of its binary digits.
def binary_encode(i, num_digits):
    return np.array([i >> d & 1 for d in range(num_digits)])

trX = torch.tensor(np.array([binary_encode(i, NUM_DIGITS) for i in range(101, 2 ** NUM_DIGITS)]), dtype=torch.float)
trY = torch.tensor([fizz_buzz_encode(i) for i in range(101, 2 ** NUM_DIGITS)], dtype=torch.long)
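
To make the input encoding concrete: `binary_encode` lists the bits of `i` starting from the least-significant one, and 10 digits cover every number up to 2 ** 10 - 1 = 1023. A short sketch repeating the function from the cell above, with a few example values of our own:

```python
import numpy as np

def binary_encode(i, num_digits):
    # Bit d of i (shift right by d, keep the lowest bit), least-significant first
    return np.array([i >> d & 1 for d in range(num_digits)])

print(binary_encode(5, 4))   # [1 0 1 0]  (5 = 0b0101, read from the lowest bit)
print(binary_encode(6, 4))   # [0 1 1 0]  (6 = 0b0110)

# Summing bit * 2**d recovers the original number
bits = binary_encode(1000, 10)
recovered = sum(int(b) << d for d, b in enumerate(bits))
print(recovered)             # 1000
```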

Then we define the model in PyTorch:

In [11]:
# Define the model
NUM_HIDDEN = 100
model = torch.nn.Sequential(
    torch.nn.Linear(NUM_DIGITS, NUM_HIDDEN),
    torch.nn.ReLU(),
    torch.nn.Linear(NUM_HIDDEN, 4)
)

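- To teach the model to play FizzBuzz, we need to define a loss function and an optimization algorithm.
- The optimization algorithm repeatedly adjusts the model's parameters to decrease the loss, so that the model reaches as low a loss as possible on this task.
- A low loss usually means the model performs well; a high loss means it performs poorly.
- Since FizzBuzz is essentially a classification problem, we use the Cross Entropy Loss.
- For the optimizer we use Stochastic Gradient Descent (SGD).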
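
`CrossEntropyLoss` expects raw, unnormalized scores (logits) and integer class indices, which is why the model ends in a plain `Linear` layer with no softmax and why the targets are a long tensor. The sketch below, using random logits and made-up targets, checks that it equals `log_softmax` followed by the negative log-likelihood of the correct class, averaged over the batch.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 4)               # raw model outputs: 5 samples, 4 classes
target = torch.tensor([0, 3, 1, 2, 0])   # class indices, like trY

ce = torch.nn.CrossEntropyLoss()(logits, target)

# The same value in two explicit steps
two_step = F.nll_loss(F.log_softmax(logits, dim=1), target)
# ...or by hand: mean of -log p(correct class)
by_hand = -F.log_softmax(logits, dim=1)[torch.arange(5), target].mean()

print(torch.allclose(ce, two_step), torch.allclose(ce, by_hand))  # True True
```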

In [12]:
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = 0.05)

Below is the training code:

In [13]:
# Start training it
BATCH_SIZE = 128
for epoch in range(10000):
    for start in range(0, len(trX), BATCH_SIZE):
        end = start + BATCH_SIZE
        batchX = trX[start:end]
        batchY = trY[start:end]

        y_pred = model(batchX)
        loss = loss_fn(y_pred, batchY)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Find loss on training data
    loss = loss_fn(model(trX), trY).item()
    print('Epoch:', epoch, 'Loss:', loss)

Epoch: 0 Loss: 1.1830308437347412
Epoch: 1 Loss: 1.154726266860962
Epoch: 2 Loss: 1.1481016874313354
Epoch: 3 Loss: 1.1456503868103027
Epoch: 4 Loss: 1.1444178819656372
Epoch: 5 Loss: 1.1436442136764526
Epoch: 6 Loss: 1.1430857181549072
Epoch: 7 Loss: 1.142643928527832
Epoch: 8 Loss: 1.1422686576843262
Epoch: 9 Loss: 1.141944408416748
Epoch: 10 Loss: 1.1416491270065308
Epoch: 11 Loss: 1.1413792371749878
Epoch: 12 Loss: 1.1411278247833252
Epoch: 13 Loss: 1.1408933401107788
Epoch: 14 Loss: 1.140674352645874
Epoch: 15 Loss: 1.1404683589935303
Epoch: 16 Loss: 1.1402713060379028
Epoch: 17 Loss: 1.1400843858718872
Epoch: 18 Loss: 1.1399037837982178
Epoch: 19 Loss: 1.1397323608398438
Epoch: 20 Loss: 1.1395655870437622
Epoch: 21 Loss: 1.139404058456421
Epoch: 22 Loss: 1.1392452716827393
Epoch: 23 Loss: 1.139093279838562
Epoch: 24 Loss: 1.1389470100402832
Epoch: 25 Loss: 1.1388049125671387
Epoch: 26 Loss: 1.1386692523956299
Epoch: 27 Loss: 1.1385384798049927
Epoch: 28 Loss: 1.1384092569351196
E

Epoch: 460 Loss: 0.955564558506012
Epoch: 461 Loss: 0.955300509929657
Epoch: 462 Loss: 0.9528169631958008
Epoch: 463 Loss: 0.9527114033699036
Epoch: 464 Loss: 0.9501111507415771
Epoch: 465 Loss: 0.9524601101875305
Epoch: 466 Loss: 0.9508019089698792
Epoch: 467 Loss: 0.94879549741745
Epoch: 468 Loss: 0.9485660195350647
Epoch: 469 Loss: 0.947771430015564
Epoch: 470 Loss: 0.9463621973991394
Epoch: 471 Loss: 0.9447852969169617
Epoch: 472 Loss: 0.9463176727294922
Epoch: 473 Loss: 0.942451536655426
Epoch: 474 Loss: 0.943315327167511
Epoch: 475 Loss: 0.9403923749923706
Epoch: 476 Loss: 0.9428916573524475
Epoch: 477 Loss: 0.9392880797386169
Epoch: 478 Loss: 0.9389854669570923
Epoch: 479 Loss: 0.9389311671257019
Epoch: 480 Loss: 0.9371123909950256
Epoch: 481 Loss: 0.9358103275299072
Epoch: 482 Loss: 0.9360257387161255
Epoch: 483 Loss: 0.933333158493042
Epoch: 484 Loss: 0.9336093068122864
Epoch: 485 Loss: 0.9313328266143799
Epoch: 486 Loss: 0.9320828318595886
Epoch: 487 Loss: 0.9307708740234375


Epoch: 916 Loss: 0.49431347846984863
Epoch: 917 Loss: 0.49400439858436584
Epoch: 918 Loss: 0.4926239252090454
Epoch: 919 Loss: 0.49269041419029236
Epoch: 920 Loss: 0.491123229265213
Epoch: 921 Loss: 0.49085748195648193
Epoch: 922 Loss: 0.4894615709781647
Epoch: 923 Loss: 0.49001845717430115
Epoch: 924 Loss: 0.48794025182724
Epoch: 925 Loss: 0.48777738213539124
Epoch: 926 Loss: 0.48735925555229187
Epoch: 927 Loss: 0.4870048463344574
Epoch: 928 Loss: 0.4847734868526459
Epoch: 929 Loss: 0.4858655035495758
Epoch: 930 Loss: 0.4837634265422821
Epoch: 931 Loss: 0.48287758231163025
Epoch: 932 Loss: 0.4826996624469757
Epoch: 933 Loss: 0.48170971870422363
Epoch: 934 Loss: 0.4818239212036133
Epoch: 935 Loss: 0.4798835515975952
Epoch: 936 Loss: 0.480034202337265
Epoch: 937 Loss: 0.47886136174201965
Epoch: 938 Loss: 0.478089839220047
Epoch: 939 Loss: 0.47834452986717224
Epoch: 940 Loss: 0.4766152799129486
Epoch: 941 Loss: 0.4762994945049286
Epoch: 942 Loss: 0.47530168294906616
Epoch: 943 Loss: 0.47

Epoch: 1355 Loss: 0.2591025233268738
Epoch: 1356 Loss: 0.2588486075401306
Epoch: 1357 Loss: 0.2583293616771698
Epoch: 1358 Loss: 0.25809767842292786
Epoch: 1359 Loss: 0.25786876678466797
Epoch: 1360 Loss: 0.2573625147342682
Epoch: 1361 Loss: 0.2572280168533325
Epoch: 1362 Loss: 0.25675249099731445
Epoch: 1363 Loss: 0.25644758343696594
Epoch: 1364 Loss: 0.25623515248298645
Epoch: 1365 Loss: 0.2557029724121094
Epoch: 1366 Loss: 0.25556832551956177
Epoch: 1367 Loss: 0.2551448941230774
Epoch: 1368 Loss: 0.2548053562641144
Epoch: 1369 Loss: 0.25453221797943115
Epoch: 1370 Loss: 0.25420501828193665
Epoch: 1371 Loss: 0.25374817848205566
Epoch: 1372 Loss: 0.2535306513309479
Epoch: 1373 Loss: 0.253292053937912
Epoch: 1374 Loss: 0.2527993619441986
Epoch: 1375 Loss: 0.25253695249557495
Epoch: 1376 Loss: 0.2522173225879669
Epoch: 1377 Loss: 0.251817911863327
Epoch: 1378 Loss: 0.25163233280181885
Epoch: 1379 Loss: 0.2513378858566284
Epoch: 1380 Loss: 0.2510133683681488

Epoch: 1789 Loss: 0.16320858895778656
Epoch: 1790 Loss: 0.16308173537254333
Epoch: 1791 Loss: 0.16297517716884613
Epoch: 1792 Loss: 0.16275718808174133
Epoch: 1793 Loss: 0.16264288127422333
Epoch: 1794 Loss: 0.16241124272346497
Epoch: 1795 Loss: 0.1623183786869049
Epoch: 1796 Loss: 0.1621813178062439
Epoch: 1797 Loss: 0.1620868444442749
Epoch: 1798 Loss: 0.16195814311504364
Epoch: 1799 Loss: 0.16185784339904785
Epoch: 1800 Loss: 0.16160328686237335
Epoch: 1801 Loss: 0.16144298017024994
Epoch: 1802 Loss: 0.16140632331371307
Epoch: 1803 Loss: 0.16122616827487946
Epoch: 1804 Loss: 0.16098584234714508
Epoch: 1805 Loss: 0.16086074709892273
Epoch: 1806 Loss: 0.16073186695575714
Epoch: 1807 Loss: 0.16067461669445038
Epoch: 1808 Loss: 0.1605333536863327
Epoch: 1809 Loss: 0.1603788435459137
Epoch: 1810 Loss: 0.16026204824447632
Epoch: 1811 Loss: 0.16007807850837708
Epoch: 1812 Loss: 0.159893199801445
Epoch: 1813 Loss: 0.15978185832500458
Epoch: 1814 Loss: 0.15963368117809296

Epoch: 2223 Loss: 0.11568006873130798
Epoch: 2224 Loss: 0.11562129855155945
Epoch: 2225 Loss: 0.11561109125614166
Epoch: 2226 Loss: 0.11546032130718231
Epoch: 2227 Loss: 0.11545832455158234
Epoch: 2228 Loss: 0.11533285677433014
Epoch: 2229 Loss: 0.11529581993818283
Epoch: 2230 Loss: 0.11515042185783386
Epoch: 2231 Loss: 0.11510957032442093
Epoch: 2232 Loss: 0.11496042460203171
Epoch: 2233 Loss: 0.11492668837308884
Epoch: 2234 Loss: 0.11477892845869064
Epoch: 2235 Loss: 0.11477494239807129
Epoch: 2236 Loss: 0.11458668857812881
Epoch: 2237 Loss: 0.11453967541456223
Epoch: 2238 Loss: 0.1145249754190445
Epoch: 2239 Loss: 0.11436062306165695
Epoch: 2240 Loss: 0.11433528363704681
Epoch: 2241 Loss: 0.11422601342201233
Epoch: 2242 Loss: 0.11417543143033981
Epoch: 2243 Loss: 0.11405922472476959
Epoch: 2244 Loss: 0.11399277299642563
Epoch: 2245 Loss: 0.11403156071901321
Epoch: 2246 Loss: 0.11378809064626694
Epoch: 2247 Loss: 0.11376475542783737
Epoch: 2248 Loss: 0.11368866264820099

Epoch: 2655 Loss: 0.08510478585958481
Epoch: 2656 Loss: 0.08512859046459198
Epoch: 2657 Loss: 0.08499963581562042
Epoch: 2658 Loss: 0.08498919010162354
Epoch: 2659 Loss: 0.08497168868780136
Epoch: 2660 Loss: 0.084864541888237
Epoch: 2661 Loss: 0.08479481190443039
Epoch: 2662 Loss: 0.08469557762145996
Epoch: 2663 Loss: 0.08469956368207932
Epoch: 2664 Loss: 0.0845843032002449
Epoch: 2665 Loss: 0.0845298245549202
Epoch: 2666 Loss: 0.08445698767900467
Epoch: 2667 Loss: 0.08444090932607651
Epoch: 2668 Loss: 0.08441707491874695
Epoch: 2669 Loss: 0.0842650905251503
Epoch: 2670 Loss: 0.0842733159661293
Epoch: 2671 Loss: 0.08419335633516312
Epoch: 2672 Loss: 0.08410793542861938
Epoch: 2673 Loss: 0.08411204814910889
Epoch: 2674 Loss: 0.08403468877077103
Epoch: 2675 Loss: 0.08394497632980347
Epoch: 2676 Loss: 0.0839448869228363
Epoch: 2677 Loss: 0.08379952609539032
Epoch: 2678 Loss: 0.0837685689330101
Epoch: 2679 Loss: 0.08371113985776901
Epoch: 2680 Loss: 0.08368319272994995

Epoch: 3087 Loss: 0.06350138038396835
Epoch: 3088 Loss: 0.06348108500242233
Epoch: 3089 Loss: 0.06342444568872452
Epoch: 3090 Loss: 0.06341073662042618
Epoch: 3091 Loss: 0.06334488838911057
Epoch: 3092 Loss: 0.06330767273902893
Epoch: 3093 Loss: 0.06331552565097809
Epoch: 3094 Loss: 0.06324208527803421
Epoch: 3095 Loss: 0.06316233426332474
Epoch: 3096 Loss: 0.0631156861782074
Epoch: 3097 Loss: 0.06312043964862823
Epoch: 3098 Loss: 0.06306398659944534
Epoch: 3099 Loss: 0.0630311444401741
Epoch: 3100 Loss: 0.06300246715545654
Epoch: 3101 Loss: 0.06294472515583038
Epoch: 3102 Loss: 0.062893807888031
Epoch: 3103 Loss: 0.06285963952541351
Epoch: 3104 Loss: 0.06282580643892288
Epoch: 3105 Loss: 0.0627516359090805
Epoch: 3106 Loss: 0.06275436282157898
Epoch: 3107 Loss: 0.06271260976791382
Epoch: 3108 Loss: 0.06265950947999954
Epoch: 3109 Loss: 0.06260956823825836
Epoch: 3110 Loss: 0.06257007271051407
Epoch: 3111 Loss: 0.06249317526817322
Epoch: 3112 Loss: 0.06249535083770752

Epoch: 3515 Loss: 0.04864414036273956
Epoch: 3516 Loss: 0.04859721288084984
Epoch: 3517 Loss: 0.04855599254369736
Epoch: 3518 Loss: 0.048539627343416214
Epoch: 3519 Loss: 0.04851280525326729
Epoch: 3520 Loss: 0.048481766134500504
Epoch: 3521 Loss: 0.048417117446660995
Epoch: 3522 Loss: 0.048427458852529526
Epoch: 3523 Loss: 0.048387560993433
Epoch: 3524 Loss: 0.04836104437708855
Epoch: 3525 Loss: 0.048297904431819916
Epoch: 3526 Loss: 0.04830915853381157
Epoch: 3527 Loss: 0.04826648160815239
Epoch: 3528 Loss: 0.04823547601699829
Epoch: 3529 Loss: 0.048212893307209015
Epoch: 3530 Loss: 0.048146337270736694
Epoch: 3531 Loss: 0.04815470427274704
Epoch: 3532 Loss: 0.04812120646238327
Epoch: 3533 Loss: 0.04808728024363518
Epoch: 3534 Loss: 0.04804634675383568
Epoch: 3535 Loss: 0.04801720753312111
Epoch: 3536 Loss: 0.048010509461164474
Epoch: 3537 Loss: 0.047970205545425415
Epoch: 3538 Loss: 0.04790303111076355
Epoch: 3539 Loss: 0.04791119322180748
Epoch: 3540 Loss: 0.04788000136613846

Epoch: 3942 Loss: 0.03761538118124008
Epoch: 3943 Loss: 0.03759725019335747
Epoch: 3944 Loss: 0.03755766525864601
Epoch: 3945 Loss: 0.03756934031844139
Epoch: 3946 Loss: 0.0375404879450798
Epoch: 3947 Loss: 0.037527699023485184
Epoch: 3948 Loss: 0.03749304264783859
Epoch: 3949 Loss: 0.03747009485960007
Epoch: 3950 Loss: 0.03746798634529114
Epoch: 3951 Loss: 0.037446681410074234
Epoch: 3952 Loss: 0.03741373494267464
Epoch: 3953 Loss: 0.03737418353557587
Epoch: 3954 Loss: 0.03736887127161026
Epoch: 3955 Loss: 0.03734169155359268
Epoch: 3956 Loss: 0.037325967103242874
Epoch: 3957 Loss: 0.03731459006667137
Epoch: 3958 Loss: 0.03729785233736038
Epoch: 3959 Loss: 0.03726079314947128
Epoch: 3960 Loss: 0.037257641553878784
Epoch: 3961 Loss: 0.037222180515527725
Epoch: 3962 Loss: 0.03718649595975876
Epoch: 3963 Loss: 0.037170007824897766
Epoch: 3964 Loss: 0.03715326264500618
Epoch: 3965 Loss: 0.03715755417943001
Epoch: 3966 Loss: 0.037116728723049164
Epoch: 3967 Loss: 0.037104282528162

Epoch: 4369 Loss: 0.030071662738919258
Epoch: 4370 Loss: 0.030035950243473053
Epoch: 4371 Loss: 0.03003639355301857
Epoch: 4372 Loss: 0.030015461146831512
Epoch: 4373 Loss: 0.029988454654812813
Epoch: 4374 Loss: 0.02999514900147915
Epoch: 4375 Loss: 0.02996733784675598
Epoch: 4376 Loss: 0.02996576949954033
Epoch: 4377 Loss: 0.02995605766773224
Epoch: 4378 Loss: 0.029922395944595337
Epoch: 4379 Loss: 0.029909387230873108
Epoch: 4380 Loss: 0.02990051731467247
Epoch: 4381 Loss: 0.029881827533245087
Epoch: 4382 Loss: 0.029878320172429085
Epoch: 4383 Loss: 0.02985219471156597
Epoch: 4384 Loss: 0.02984016202390194
Epoch: 4385 Loss: 0.02982771396636963
Epoch: 4386 Loss: 0.029805846512317657
Epoch: 4387 Loss: 0.029788121581077576
Epoch: 4388 Loss: 0.029782680794596672
Epoch: 4389 Loss: 0.029763534665107727
Epoch: 4390 Loss: 0.029761139303445816
Epoch: 4391 Loss: 0.02975706197321415
Epoch: 4392 Loss: 0.029727093875408173
Epoch: 4393 Loss: 0.02970656380057335

Epoch: 4793 Loss: 0.0247193593531847
Epoch: 4794 Loss: 0.02469531260430813
Epoch: 4795 Loss: 0.02469918318092823
Epoch: 4796 Loss: 0.024692051112651825
Epoch: 4797 Loss: 0.024659203365445137
Epoch: 4798 Loss: 0.02464747242629528
Epoch: 4799 Loss: 0.02465033158659935
Epoch: 4800 Loss: 0.024646103382110596
Epoch: 4801 Loss: 0.024624941870570183
Epoch: 4802 Loss: 0.02462282031774521
Epoch: 4803 Loss: 0.02459602802991867
Epoch: 4804 Loss: 0.0245954729616642
Epoch: 4805 Loss: 0.02458273246884346
Epoch: 4806 Loss: 0.02456013672053814
Epoch: 4807 Loss: 0.024563943967223167
Epoch: 4808 Loss: 0.024554887786507607
Epoch: 4809 Loss: 0.02454095520079136
Epoch: 4810 Loss: 0.024528145790100098
Epoch: 4811 Loss: 0.02450895868241787
Epoch: 4812 Loss: 0.0245161484926939
Epoch: 4813 Loss: 0.024493735283613205
Epoch: 4814 Loss: 0.024495214223861694
Epoch: 4815 Loss: 0.02447408065199852
Epoch: 4816 Loss: 0.02447664365172386
Epoch: 4817 Loss: 0.02444697916507721
Epoch: 4818 Loss: 0.0244361013174057

Epoch: 5217 Loss: 0.020775513723492622
Epoch: 5218 Loss: 0.02076730877161026
Epoch: 5219 Loss: 0.020759424194693565
Epoch: 5220 Loss: 0.020743245258927345
Epoch: 5221 Loss: 0.020742079243063927
Epoch: 5222 Loss: 0.02073465660214424
Epoch: 5223 Loss: 0.020725175738334656
Epoch: 5224 Loss: 0.020723676308989525
Epoch: 5225 Loss: 0.020710280165076256
Epoch: 5226 Loss: 0.020703468471765518
Epoch: 5227 Loss: 0.02069724351167679
Epoch: 5228 Loss: 0.020683584734797478
Epoch: 5229 Loss: 0.020676622167229652
Epoch: 5230 Loss: 0.020680086687207222
Epoch: 5231 Loss: 0.020657232031226158
Epoch: 5232 Loss: 0.02065407857298851
Epoch: 5233 Loss: 0.020646845921874046
Epoch: 5234 Loss: 0.020637020468711853
Epoch: 5235 Loss: 0.020637935027480125
Epoch: 5236 Loss: 0.020623592659831047
Epoch: 5237 Loss: 0.02061941847205162
Epoch: 5238 Loss: 0.02060708776116371
Epoch: 5239 Loss: 0.020603463053703308
Epoch: 5240 Loss: 0.020595818758010864
Epoch: 5241 Loss: 0.020578857511281967

Epoch: 5641 Loss: 0.017798885703086853
Epoch: 5642 Loss: 0.01778426207602024
Epoch: 5643 Loss: 0.017785092815756798
Epoch: 5644 Loss: 0.017780130729079247
Epoch: 5645 Loss: 0.017770903185009956
Epoch: 5646 Loss: 0.01776626892387867
Epoch: 5647 Loss: 0.017755992710590363
Epoch: 5648 Loss: 0.017751313745975494
Epoch: 5649 Loss: 0.017749985679984093
Epoch: 5650 Loss: 0.017741549760103226
Epoch: 5651 Loss: 0.017736375331878662
Epoch: 5652 Loss: 0.017720717936754227
Epoch: 5653 Loss: 0.01772133819758892
Epoch: 5654 Loss: 0.017714451998472214
Epoch: 5655 Loss: 0.017708970233798027
Epoch: 5656 Loss: 0.01769893802702427
Epoch: 5657 Loss: 0.017693065106868744
Epoch: 5658 Loss: 0.01769004575908184
Epoch: 5659 Loss: 0.017684752121567726
Epoch: 5660 Loss: 0.017679614946246147
Epoch: 5661 Loss: 0.017664430662989616
Epoch: 5662 Loss: 0.017660904675722122
Epoch: 5663 Loss: 0.017655791714787483
Epoch: 5664 Loss: 0.017652206122875214
Epoch: 5665 Loss: 0.01764996163547039

Epoch: 6064 Loss: 0.015307975932955742
Epoch: 6065 Loss: 0.015304038301110268
Epoch: 6066 Loss: 0.01529699843376875
Epoch: 6067 Loss: 0.015291950665414333
Epoch: 6068 Loss: 0.015283928252756596
Epoch: 6069 Loss: 0.015285785309970379
Epoch: 6070 Loss: 0.015279894694685936
Epoch: 6071 Loss: 0.01527296844869852
Epoch: 6072 Loss: 0.015266054309904575
Epoch: 6073 Loss: 0.015260014683008194
Epoch: 6074 Loss: 0.01525754202157259
Epoch: 6075 Loss: 0.015251112170517445
Epoch: 6076 Loss: 0.015239736996591091
Epoch: 6077 Loss: 0.015236946754157543
Epoch: 6078 Loss: 0.015233361162245274
Epoch: 6079 Loss: 0.015229363925755024
Epoch: 6080 Loss: 0.015222443267703056
Epoch: 6081 Loss: 0.015217356383800507
Epoch: 6082 Loss: 0.015214623883366585
Epoch: 6083 Loss: 0.015208860859274864
Epoch: 6084 Loss: 0.015201973728835583
Epoch: 6085 Loss: 0.015199034474790096
Epoch: 6086 Loss: 0.01519101019948721
Epoch: 6087 Loss: 0.015184912830591202
Epoch: 6088 Loss: 0.015181416645646095

Epoch: 6486 Loss: 0.013384479098021984
Epoch: 6487 Loss: 0.013383720070123672
Epoch: 6488 Loss: 0.013376779854297638
Epoch: 6489 Loss: 0.013372219167649746
Epoch: 6490 Loss: 0.013376050628721714
Epoch: 6491 Loss: 0.013368066400289536
Epoch: 6492 Loss: 0.013361657969653606
Epoch: 6493 Loss: 0.013358446769416332
Epoch: 6494 Loss: 0.013357756659388542
Epoch: 6495 Loss: 0.013355244882404804
Epoch: 6496 Loss: 0.013343114405870438
Epoch: 6497 Loss: 0.013346418738365173
Epoch: 6498 Loss: 0.013342425227165222
Epoch: 6499 Loss: 0.01333602424710989
Epoch: 6500 Loss: 0.01333218440413475
Epoch: 6501 Loss: 0.013330920599400997
Epoch: 6502 Loss: 0.013323621824383736
Epoch: 6503 Loss: 0.013319705612957478
Epoch: 6504 Loss: 0.01331861037760973
Epoch: 6505 Loss: 0.013309535570442677
Epoch: 6506 Loss: 0.013310359790921211
Epoch: 6507 Loss: 0.013304263353347778
Epoch: 6508 Loss: 0.013298939913511276
Epoch: 6509 Loss: 0.013295703567564487
Epoch: 6510 Loss: 0.013292481191456318

Epoch: 6908 Loss: 0.011891866102814674
Epoch: 6909 Loss: 0.011891811154782772
Epoch: 6910 Loss: 0.01188511960208416
Epoch: 6911 Loss: 0.011880991980433464
Epoch: 6912 Loss: 0.011877680197358131
Epoch: 6913 Loss: 0.011878480203449726
Epoch: 6914 Loss: 0.011876565404236317
Epoch: 6915 Loss: 0.011868329718708992
Epoch: 6916 Loss: 0.011866657063364983
Epoch: 6917 Loss: 0.011865249834954739
Epoch: 6918 Loss: 0.011860457248985767
Epoch: 6919 Loss: 0.011859288439154625
Epoch: 6920 Loss: 0.01185622625052929
Epoch: 6921 Loss: 0.011849803850054741
Epoch: 6922 Loss: 0.011849602684378624
Epoch: 6923 Loss: 0.011842848733067513
Epoch: 6924 Loss: 0.011841187253594398
Epoch: 6925 Loss: 0.011838323436677456
Epoch: 6926 Loss: 0.01183371152728796
Epoch: 6927 Loss: 0.011831785552203655
Epoch: 6928 Loss: 0.011828416027128696
Epoch: 6929 Loss: 0.011827508918941021
Epoch: 6930 Loss: 0.011821595020592213
Epoch: 6931 Loss: 0.011818910017609596
Epoch: 6932 Loss: 0.011815263889729977

Epoch: 7329 Loss: 0.010678734630346298
Epoch: 7330 Loss: 0.010677246376872063
Epoch: 7331 Loss: 0.010673513635993004
Epoch: 7332 Loss: 0.010670274496078491
Epoch: 7333 Loss: 0.010668925009667873
Epoch: 7334 Loss: 0.010666847229003906
Epoch: 7335 Loss: 0.010663609020411968
Epoch: 7336 Loss: 0.010661721229553223
Epoch: 7337 Loss: 0.010660112835466862
Epoch: 7338 Loss: 0.01065454538911581
Epoch: 7339 Loss: 0.010651963762938976
Epoch: 7340 Loss: 0.010650607757270336
Epoch: 7341 Loss: 0.01064769085496664
Epoch: 7342 Loss: 0.010644188150763512
Epoch: 7343 Loss: 0.010641015134751797
Epoch: 7344 Loss: 0.010641386732459068
Epoch: 7345 Loss: 0.010637298226356506
Epoch: 7346 Loss: 0.01063461322337389
Epoch: 7347 Loss: 0.010634565725922585
Epoch: 7348 Loss: 0.010628748685121536
Epoch: 7349 Loss: 0.010628625750541687
Epoch: 7350 Loss: 0.010622911155223846
Epoch: 7351 Loss: 0.010620822198688984
Epoch: 7352 Loss: 0.010621165856719017
Epoch: 7353 Loss: 0.010619184002280235

Epoch: 7750 Loss: 0.009671836160123348
Epoch: 7751 Loss: 0.009666908532381058
Epoch: 7752 Loss: 0.00966612808406353
Epoch: 7753 Loss: 0.009663299657404423
Epoch: 7754 Loss: 0.009663427248597145
Epoch: 7755 Loss: 0.00965804886072874
Epoch: 7756 Loss: 0.009657131507992744
Epoch: 7757 Loss: 0.009655500762164593
Epoch: 7758 Loss: 0.009652075357735157
Epoch: 7759 Loss: 0.009651806205511093
Epoch: 7760 Loss: 0.0096494872123003
Epoch: 7761 Loss: 0.00964576005935669
Epoch: 7762 Loss: 0.009643412195146084
Epoch: 7763 Loss: 0.009642749093472958
Epoch: 7764 Loss: 0.00964102428406477
Epoch: 7765 Loss: 0.009639257565140724
Epoch: 7766 Loss: 0.009635954163968563
Epoch: 7767 Loss: 0.009633798152208328
Epoch: 7768 Loss: 0.009631331078708172
Epoch: 7769 Loss: 0.009629553183913231
Epoch: 7770 Loss: 0.0096265384927392
Epoch: 7771 Loss: 0.00962323322892189
Epoch: 7772 Loss: 0.009622563607990742
Epoch: 7773 Loss: 0.009619858115911484
Epoch: 7774 Loss: 0.009617182426154613

Epoch: 8171 Loss: 0.008812904357910156
Epoch: 8172 Loss: 0.008810250088572502
Epoch: 8173 Loss: 0.008809497579932213
Epoch: 8174 Loss: 0.008806302212178707
Epoch: 8175 Loss: 0.008805260062217712
Epoch: 8176 Loss: 0.008804518729448318
Epoch: 8177 Loss: 0.008801174350082874
Epoch: 8178 Loss: 0.008800052106380463
Epoch: 8179 Loss: 0.008797317743301392
Epoch: 8180 Loss: 0.00879385694861412
Epoch: 8181 Loss: 0.008794771507382393
Epoch: 8182 Loss: 0.008791784755885601
Epoch: 8183 Loss: 0.008790872059762478
Epoch: 8184 Loss: 0.008787807077169418
Epoch: 8185 Loss: 0.00878863874822855
Epoch: 8186 Loss: 0.008784188888967037
Epoch: 8187 Loss: 0.008782368153333664
Epoch: 8188 Loss: 0.008780639618635178
Epoch: 8189 Loss: 0.008778505958616734
Epoch: 8190 Loss: 0.008779403753578663
Epoch: 8191 Loss: 0.008776476606726646
Epoch: 8192 Loss: 0.00877394899725914
Epoch: 8193 Loss: 0.008771105669438839
Epoch: 8194 Loss: 0.008770299144089222
Epoch: 8195 Loss: 0.008768292143940926

Epoch: 8593 Loss: 0.008078635670244694
Epoch: 8594 Loss: 0.008076764643192291
Epoch: 8595 Loss: 0.00807438138872385
Epoch: 8596 Loss: 0.008074554614722729
Epoch: 8597 Loss: 0.008073106408119202
Epoch: 8598 Loss: 0.008070179261267185
Epoch: 8599 Loss: 0.008067533373832703
Epoch: 8600 Loss: 0.008067985065281391
Epoch: 8601 Loss: 0.008065298199653625
Epoch: 8602 Loss: 0.008064255118370056
Epoch: 8603 Loss: 0.008063171058893204
Epoch: 8604 Loss: 0.008060683496296406
Epoch: 8605 Loss: 0.008058788254857063
Epoch: 8606 Loss: 0.008058460429310799
Epoch: 8607 Loss: 0.008056743070483208
Epoch: 8608 Loss: 0.008053973317146301
Epoch: 8609 Loss: 0.00805263128131628
Epoch: 8610 Loss: 0.008050705306231976
Epoch: 8611 Loss: 0.00805003009736538
Epoch: 8612 Loss: 0.008046949282288551
Epoch: 8613 Loss: 0.008047002367675304
Epoch: 8614 Loss: 0.008044973015785217
Epoch: 8615 Loss: 0.00804260652512312
Epoch: 8616 Loss: 0.008041645400226116
Epoch: 8617 Loss: 0.008039516396820545

Epoch: 8804 Loss: 0.007750099990516901
Epoch: 8805 Loss: 0.007748206611722708
Epoch: 8806 Loss: 0.007746836636215448
Epoch: 8807 Loss: 0.007745313923805952
Epoch: 8808 Loss: 0.007743860129266977
Epoch: 8809 Loss: 0.007742151152342558
Epoch: 8810 Loss: 0.0077410065568983555
Epoch: 8811 Loss: 0.007738983258605003
Epoch: 8812 Loss: 0.007737601641565561
Epoch: 8813 Loss: 0.007736830040812492
Epoch: 8814 Loss: 0.007735193707048893
Epoch: 8815 Loss: 0.007734062150120735
Epoch: 8816 Loss: 0.007732938975095749
Epoch: 8817 Loss: 0.007730201352387667
Epoch: 8818 Loss: 0.007728886790573597
Epoch: 8819 Loss: 0.007727320306003094
Epoch: 8820 Loss: 0.0077259354293346405
Epoch: 8821 Loss: 0.0077249412424862385
Epoch: 8822 Loss: 0.007723172660917044
Epoch: 8823 Loss: 0.007721311878412962
Epoch: 8824 Loss: 0.007719706278294325
Epoch: 8825 Loss: 0.007718480192124844
Epoch: 8826 Loss: 0.007717662025243044
Epoch: 8827 Loss: 0.0077149770222604275
Epoch: 8828 Loss: 0.007714030332863331

Epoch: 9223 Loss: 0.0071679516695439816
Epoch: 9224 Loss: 0.007166317664086819
Epoch: 9225 Loss: 0.00716443033888936
Epoch: 9226 Loss: 0.0071643041446805
Epoch: 9227 Loss: 0.007162617985159159
Epoch: 9228 Loss: 0.007161527872085571
Epoch: 9229 Loss: 0.007159453816711903
Epoch: 9230 Loss: 0.007158221211284399
Epoch: 9231 Loss: 0.007157512474805117
Epoch: 9232 Loss: 0.0071578421629965305
Epoch: 9233 Loss: 0.007153663318604231
Epoch: 9234 Loss: 0.007153362967073917
Epoch: 9235 Loss: 0.007152981124818325
Epoch: 9236 Loss: 0.007150046061724424
Epoch: 9237 Loss: 0.007149914279580116
Epoch: 9238 Loss: 0.007147929165512323
Epoch: 9239 Loss: 0.007147182710468769
Epoch: 9240 Loss: 0.007145698182284832
Epoch: 9241 Loss: 0.007143624592572451
Epoch: 9242 Loss: 0.007143177557736635
Epoch: 9243 Loss: 0.007142225746065378
Epoch: 9244 Loss: 0.007139974273741245
Epoch: 9245 Loss: 0.007139210123568773
Epoch: 9246 Loss: 0.00713769719004631
Epoch: 9247 Loss: 0.007136509288102388

Epoch: 9642 Loss: 0.006660243030637503
Epoch: 9643 Loss: 0.006658557336777449
Epoch: 9644 Loss: 0.00665704719722271
Epoch: 9645 Loss: 0.006656096316874027
Epoch: 9646 Loss: 0.006654903292655945
Epoch: 9647 Loss: 0.006653842516243458
Epoch: 9648 Loss: 0.006652991287410259
Epoch: 9649 Loss: 0.006651673465967178
Epoch: 9650 Loss: 0.006650335621088743
Epoch: 9651 Loss: 0.006649484857916832
Epoch: 9652 Loss: 0.006648705806583166
Epoch: 9653 Loss: 0.006646750960499048
Epoch: 9654 Loss: 0.00664652930572629
Epoch: 9655 Loss: 0.006644642446190119
Epoch: 9656 Loss: 0.00664388807490468
Epoch: 9657 Loss: 0.006642211228609085
Epoch: 9658 Loss: 0.006641430780291557
Epoch: 9659 Loss: 0.0066404580138623714
Epoch: 9660 Loss: 0.006639582570642233
Epoch: 9661 Loss: 0.0066382321529090405
Epoch: 9662 Loss: 0.006636765319854021
Epoch: 9663 Loss: 0.006635616067796946
Epoch: 9664 Loss: 0.006634220015257597
Epoch: 9665 Loss: 0.0066345687955617905
Epoch: 9666 Loss: 0.006632519885897636

Epoch: 9852 Loss: 0.00642844894900918
Epoch: 9853 Loss: 0.006426795851439238
Epoch: 9854 Loss: 0.006426469888538122
Epoch: 9855 Loss: 0.006426021456718445
Epoch: 9856 Loss: 0.00642486521974206
Epoch: 9857 Loss: 0.006422880571335554
Epoch: 9858 Loss: 0.006422063335776329
Epoch: 9859 Loss: 0.00642079021781683
Epoch: 9860 Loss: 0.00642005680128932
Epoch: 9861 Loss: 0.006418760400265455
Epoch: 9862 Loss: 0.006417789962142706
Epoch: 9863 Loss: 0.006416906137019396
Epoch: 9864 Loss: 0.006415777374058962
Epoch: 9865 Loss: 0.006414301227778196
Epoch: 9866 Loss: 0.006413406226783991
Epoch: 9867 Loss: 0.006412367802113295
Epoch: 9868 Loss: 0.006411366630345583
Epoch: 9869 Loss: 0.006410631351172924
Epoch: 9870 Loss: 0.006408738438040018
Epoch: 9871 Loss: 0.006408374290913343
Epoch: 9872 Loss: 0.006406840868294239
Epoch: 9873 Loss: 0.0064062392339110374
Epoch: 9874 Loss: 0.006405849475413561
Epoch: 9875 Loss: 0.006403896491974592
Epoch: 9876 Loss: 0.006403105333447456

Finally, we use the trained model to play FizzBuzz on the numbers 1 to 1000.
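The cell below relies on three helpers defined earlier in the notebook: `binary_encode` (turns a number into the model's input bits), `fizz_buzz_encode` (number to class index) and `fizz_buzz_decode` (class index to the printed string). As a reminder of what they do, here is a minimal pure-Python sketch. It assumes the common 4-class scheme (0 = the number itself, 1 = fizz, 2 = buzz, 3 = fizzbuzz) and least-significant-bit-first encoding; the notebook's own definitions may differ in detail.

```python
# Hypothetical re-implementations of the notebook's helpers, assuming the
# 4-class scheme: 0 -> the number itself, 1 -> "fizz", 2 -> "buzz", 3 -> "fizzbuzz".

def binary_encode(i, num_digits):
    """Return the num_digits lowest bits of i, least significant bit first."""
    return [(i >> d) & 1 for d in range(num_digits)]

def fizz_buzz_encode(i):
    """Map an integer to its class index under the assumed 4-class scheme."""
    if i % 15 == 0:
        return 3
    if i % 5 == 0:
        return 2
    if i % 3 == 0:
        return 1
    return 0

def fizz_buzz_decode(i, prediction):
    """Turn a (predicted) class index back into the string FizzBuzz prints."""
    return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]

# Ground-truth answers for 1..15, obtained by encoding and then decoding
print([fizz_buzz_decode(i, fizz_buzz_encode(i)) for i in range(1, 16)])
# The 10-bit input vector for the number 6
print(binary_encode(6, 10))
```

Since the test inputs go up to 1000, `NUM_DIGITS` has to be at least 10 (2^10 = 1024) for every input to get a distinct encoding.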

In [17]:
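# Binary-encode the numbers 1..1000 and run the trained model on them
testX = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(1, 1001)])
with torch.no_grad():  # inference only, no gradients needed
    testY = model(testX)
# max(1) returns (values, indices); the indices are the predicted classes
predictions = zip(range(1, 1001), testY.max(1)[1].tolist())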

print([fizz_buzz_decode(i, x) for (i, x) in predictions])

['1', '2', 'fizz', '4', 'buzz', 'fizz', '7', '8', 'fizz', 'buzz', '11', 'fizz', '13', '14', 'fizzbuzz', '16', '17', 'fizz', 'buzz', 'buzz', 'fizz', '22', '23', 'fizz', 'buzz', '26', 'fizz', '28', '29', 'fizzbuzz', '31', '32', 'fizz', '34', 'buzz', 'fizz', '37', '38', 'fizz', 'buzz', '41', 'fizz', '43', '44', 'fizzbuzz', '46', '47', 'fizz', '49', 'buzz', 'fizz', '52', 'fizz', 'fizz', 'buzz', '56', 'fizz', '58', '59', 'fizzbuzz', '61', '62', 'fizz', '64', 'buzz', 'fizz', '67', '68', 'fizz', 'buzz', '71', 'fizz', '73', '74', 'fizzbuzz', '76', '77', 'fizz', '79', 'buzz', 'fizz', '82', '83', '84', 'buzz', '86', 'fizz', '88', '89', 'fizzbuzz', '91', '92', 'fizz', '94', 'buzz', 'fizz', '97', '98', 'fizz', 'buzz', '101', 'fizz', '103', '104', 'fizzbuzz', '106', '107', 'fizz', '109', 'buzz', 'fizz', '112', '113', 'fizz', 'buzz', '116', 'fizz', '118', '119', 'fizzbuzz', '121', '122', 'fizz', '124', 'buzz', 'fizz', '127', '128', 'fizz', 'buzz', '131', 'fizz', '133', '134', 'fizzbuzz', '136', '137
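The printed list already shows a few mistakes: 19 comes out as 'buzz', 53 as 'fizz' and 84 as '84'. Rather than spotting errors by eye, one can compare the predicted class indices with the ground truth directly. The sketch below is plain Python; `evaluate` is a hypothetical helper, the class order (0 = number, 1 = fizz, 2 = buzz, 3 = fizzbuzz) is an assumption, and the `predicted` values in the example are made up for illustration rather than taken from the trained model.

```python
# Hypothetical evaluation helper: accuracy plus the list of misclassified inputs.

def fizz_buzz_encode(i):
    # Assumed 4-class scheme: 0 number, 1 fizz, 2 buzz, 3 fizzbuzz
    if i % 15 == 0:
        return 3
    if i % 5 == 0:
        return 2
    if i % 3 == 0:
        return 1
    return 0

def evaluate(numbers, predicted):
    """Return (accuracy, numbers whose predicted class differs from the truth)."""
    wrong = [n for n, p in zip(numbers, predicted) if p != fizz_buzz_encode(n)]
    accuracy = 1 - len(wrong) / len(numbers)
    return accuracy, wrong

# Toy example: correct labels for 1..15, except that 4 is called "fizz"
numbers = list(range(1, 16))
predicted = [fizz_buzz_encode(n) for n in numbers]
predicted[3] = 1
acc, wrong = evaluate(numbers, predicted)
print(acc, wrong)
```

With the notebook's tensors this would be called as `evaluate(list(range(1, 1001)), testY.max(1)[1].tolist())`, provided the notebook's `fizz_buzz_encode` uses the same class order.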

In [18]:
# Compare predicted class indices with the ground-truth labels for 1..1000
correct = testY.max(1)[1].numpy() == np.array([fizz_buzz_encode(i) for i in range(1, 1001)])
print(np.sum(correct))  # number of correct predictions
correct

997


array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
       False,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True, False,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True, False,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,