In [None]:
%matplotlib inline


PyTorch: Control Flow + Weight Sharing
--------------------------------------

To showcase the power of PyTorch dynamic graphs, we will implement a very strange
model: a fully-connected ReLU network that, on each forward pass, randomly chooses
a number between 1 and 4 (inclusive) and uses that many hidden layers. The first
hidden layer has its own weights; the remaining 0–3 innermost hidden layers all
reuse one shared set of weights.



In [1]:
import random
import torch


class DynamicNet(torch.nn.Module):
    """Fully-connected ReLU network whose depth is re-drawn on every forward pass.

    The module owns exactly three ``torch.nn.Linear`` layers:

    * ``input_linear``  -- projects ``D_in`` features to ``H``;
    * ``middle_linear`` -- an ``H -> H`` layer that ``forward`` may apply
      several times in a row, sharing its weights across those applications;
    * ``output_linear`` -- projects ``H`` features to ``D_out``.
    """

    def __init__(self, D_in, H, D_out):
        """Build the three linear layers used by ``forward``.

        Args:
            D_in: number of input features.
            H: width of every hidden layer.
            D_out: number of output features.
        """
        super().__init__()
        # Keep the input -> middle -> output creation order: it fixes the order
        # in which the layers draw their initial weights from torch's RNG and
        # the order in which ``parameters()`` yields them.
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """Push ``x`` through a randomly deep stack of ReLU layers.

        After the input projection, ``middle_linear`` is applied a random
        number of times, drawn uniformly from {0, 1, 2, 3} with Python's
        ``random`` module, so the network has 1 to 4 hidden layers in total.
        ReLU is written as ``clamp(min=0)``.

        Because autograd records the graph while this code runs, ordinary
        Python control flow (the loop below) is enough to vary the depth per
        call. Applying the same sub-module several times within one graph is
        also fine, unlike in Lua Torch, where each Module could be used only
        once.

        Args:
            x: input tensor whose last dimension is ``D_in``.

        Returns:
            The output of ``output_linear`` (last dimension ``D_out``).
        """
        hidden = self.input_linear(x).clamp(min=0)

        extra_layers = random.randint(0, 3)
        for _ in range(extra_layers):
            # Same weights every time: this is the weight-sharing part.
            hidden = self.middle_linear(hidden).clamp(min=0)

        return self.output_linear(hidden)


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Seed both sources of randomness so Restart & Run All reproduces the same
# loss curve: `torch` drives the data below and the weight initialisation,
# and Python's `random` picks the number of middle layers in
# DynamicNet.forward on every step.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = DynamicNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum.
# reduction='sum' is the non-deprecated spelling of size_average=False: the loss
# is the sum of squared errors over all elements, not their mean.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 684.8455810546875
1 715.6790771484375
2 680.0076904296875
3 683.5552978515625
4 677.4606323242188
5 668.28466796875
6 652.6633911132812
7 668.4595947265625
8 491.0012512207031
9 650.9712524414062
10 646.75048828125
11 657.6286010742188
12 380.71978759765625
13 346.2368469238281
14 590.6543579101562
15 626.5357666015625
16 234.5246124267578
17 614.5545043945312
18 544.4700317382812
19 521.798583984375
20 630.9016723632812
21 623.3340454101562
22 432.3409729003906
23 600.044189453125
24 582.3042602539062
25 337.1929931640625
26 534.210693359375
27 164.92201232910156
28 259.1739807128906
29 391.61859130859375
30 360.3932189941406
31 200.52752685546875
32 184.99545288085938
33 268.5693054199219
34 318.0976867675781
35 141.74195861816406
36 270.9896545410156
37 230.8187713623047
38 152.89410400390625
39 167.7162322998047
40 148.23609924316406
41 170.6865997314453
42 138.9336700439453
43 118.53469848632812
44 135.08352661132812
45 206.50604248046875
46 96.02307891845703
47 74.8746643066406

411 0.47666582465171814
412 0.09399555623531342
413 0.5593478679656982
414 0.44827011227607727
415 0.12351332604885101
416 1.18004310131073
417 0.4098416268825531
418 0.5684894323348999
419 0.8889914751052856
420 1.624813437461853
421 0.9577195048332214
422 1.0328303575515747
423 0.34034597873687744
424 2.020263433456421
425 0.4594763219356537
426 0.3217617869377136
427 0.29588648676872253
428 3.304525375366211
429 0.9785993695259094
430 3.021498918533325
431 1.5116639137268066
432 1.6282178163528442
433 1.1358128786087036
434 2.3062069416046143
435 0.9410892128944397
436 0.34163859486579895
437 4.140970706939697
438 1.2594940662384033
439 0.6974071860313416
440 1.104834794998169
441 0.9831544756889343
442 0.6095170378684998
443 0.6906258463859558
444 0.5749287605285645
445 0.9768989086151123
446 0.5293078422546387
447 0.7490023374557495
448 0.7101702094078064
449 0.5341802835464478
450 0.5009061694145203
451 0.5524845719337463
452 0.42883118987083435
453 0.29929524660110474
454 0.7004