In [None]:
%matplotlib inline


PyTorch: Control Flow + Weight Sharing
--------------------------------------
PyTorch：控制流 + 参数共享
--------------------------------------

To showcase the power of PyTorch dynamic graphs, we will implement a very strange
model: a fully-connected ReLU network that on each forward pass randomly chooses
a number between 1 and 4 and has that many hidden layers, reusing the same
weights multiple times to compute the innermost hidden layers.



In [8]:
import random
import torch
from torch.autograd import Variable


class DynamicNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we construct three nn.Linear instances that we will use
        in the forward pass.
        """
        super(DynamicNet, self).__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose either 0, 1, 2, or 3
        and reuse the middle_linear Module that many times to compute hidden layer
        representations.

        Since each forward pass builds a dynamic computation graph, we can use normal
        Python control-flow operators like loops or conditional statements when
        defining the forward pass of the model.

        Here we also see that it is perfectly safe to reuse the same Module many
        times when defining a computational graph. This is a big improvement from Lua
        Torch, where each Module could be used only once.
        """
        h_relu = self.input_linear(x).clamp(min=0)
        for _ in range(random.randint(0, 3)):  
            #动态网络表现在这，每次前向计算可能都不一样，可能以下函数会计算多次，但这个东西怎么反向计算呢？
            #在这里参数都是共享的，前向计算了几次，反向计算就有多少次；
            #在原来的基础上循环前向计算相同次数的反向计算，但是更新参数只会更新一次。
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        y_pred = self.output_linear(h_relu)
        return y_pred


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs, and wrap them in Variables
x = Variable(torch.randn(N, D_in))
y = Variable(torch.randn(N, D_out), requires_grad=False)

# Construct our model by instantiating the class defined above
model = DynamicNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum
criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.data[0])

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 644.113525390625
1 601.1029052734375
2 529.4418334960938
3 623.205810546875
4 624.9813232421875
5 622.7435913085938
6 314.7531433105469
7 621.8671264648438
8 245.9272918701172
9 620.7202758789062
10 619.9015502929688
11 599.138671875
12 617.7550048828125
13 132.6993865966797
14 620.1903686523438
15 617.6774291992188
16 549.537353515625
17 607.6634521484375
18 599.76220703125
19 81.80657958984375
20 75.15414428710938
21 468.3707580566406
22 563.2669067382812
23 54.04541015625
24 539.77685546875
25 44.50946044921875
26 38.32130813598633
27 498.91510009765625
28 481.95831298828125
29 566.0974731445312
30 435.88043212890625
31 406.72216796875
32 281.3657531738281
33 336.4341125488281
34 54.803340911865234
35 412.0787353515625
36 61.69672775268555
37 346.5909729003906
38 206.01133728027344
39 53.128501892089844
40 175.39865112304688
41 247.8813018798828
42 133.5842742919922
43 114.14833068847656
44 96.24835205078125
45 152.98471069335938
46 136.13735961914062
47 115.55523681640625
48 122.

420 0.6512894034385681
421 3.430307626724243
422 0.8898411989212036
423 1.7777690887451172
424 1.9297442436218262
425 1.0623478889465332
426 0.9839991331100464
427 0.6497685313224792
428 0.6333263516426086
429 1.833083987236023
430 1.4843039512634277
431 0.40622130036354065
432 0.5584956407546997
433 1.1217098236083984
434 0.3669961392879486
435 0.9342831969261169
436 0.24214166402816772
437 0.18409274518489838
438 1.1766005754470825
439 0.8138988614082336
440 0.46037307381629944
441 0.1364617645740509
442 0.6049593091011047
443 1.1329712867736816
444 0.4817863404750824
445 0.5966701507568359
446 0.8293098211288452
447 0.5952858924865723
448 0.9541667103767395
449 0.7673025131225586
450 0.4877229928970337
451 0.42036017775535583
452 1.8591797351837158
453 1.3634849786758423
454 0.790248453617096
455 1.0548418760299683
456 0.819815456867218
457 1.0562429428100586
458 0.39113539457321167
459 0.42810821533203125
460 1.3604860305786133
461 1.2254717350006104
462 0.5003488659858704
463 0.28