In [1]:
%matplotlib inline


PyTorch: Control Flow + Weight Sharing
--------------------------------------

To showcase the power of PyTorch's dynamic graphs, we will implement a very strange
model: a fully-connected ReLU network that on each forward pass randomly chooses
a number between 1 and 4 and has that many hidden layers (the input layer is
always applied, followed by 0 to 3 applications of a shared middle layer),
reusing the same weights multiple times to compute the innermost hidden layers.



In [6]:
import random
import torch


class DynamicNet(torch.nn.Module):
    """
    Fully-connected ReLU network whose depth is re-sampled on every forward pass.

    The input layer and output layer are always applied once. Between them, a
    single shared ``middle_linear`` layer is applied a random number of times
    (between 0 and ``max_middle_layers``, inclusive), so the same weights are
    reused for every inner hidden layer.
    """

    def __init__(self, D_in, H, D_out, max_middle_layers=3):
        """
        In the constructor we construct three nn.Linear instances that we will use
        in the forward pass.

        Args:
            D_in: number of input features.
            H: number of features in every hidden layer.
            D_out: number of output features.
            max_middle_layers: inclusive upper bound on how many times the shared
                middle layer is applied in one forward pass. The default of 3
                reproduces the original behaviour (0, 1, 2, or 3 reuses).

        Raises:
            ValueError: if ``max_middle_layers`` is negative.
        """
        super(DynamicNet, self).__init__()
        # Fail at construction time rather than on the first forward pass, where
        # random.randint(0, negative) would raise a less helpful error.
        if max_middle_layers < 0:
            raise ValueError(
                "max_middle_layers must be >= 0, got {}".format(max_middle_layers)
            )
        self.max_middle_layers = max_middle_layers
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose an integer between 0
        and ``self.max_middle_layers`` (inclusive; 0, 1, 2, or 3 by default) and
        reuse the middle_linear Module that many times to compute hidden layer
        representations.

        Since each forward pass builds a dynamic computation graph, we can use normal
        Python control-flow operators like loops or conditional statements when
        defining the forward pass of the model.

        Here we also see that it is perfectly safe to reuse the same Module many
        times when defining a computational graph. This is a big improvement from Lua
        Torch, where each Module could be used only once.

        Args:
            x: input Tensor that ``input_linear`` can consume (last dimension D_in).

        Returns:
            The Tensor produced by ``output_linear`` (last dimension D_out).
        """
        # clamp(min=0) is the ReLU non-linearity.
        h_relu = self.input_linear(x).clamp(min=0)
        # The depth is drawn from Python's global `random` generator, so seeding
        # `random` makes the sequence of depths reproducible.
        num_middle_layers = random.randint(0, self.max_middle_layers)
        for _ in range(num_middle_layers):
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        y_pred = self.output_linear(h_relu)
        return y_pred


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Optimisation settings, gathered here so they are easy to find and tune.
SEED = 0
LEARNING_RATE = 1e-4
MOMENTUM = 0.9
NUM_STEPS = 500

# The run is stochastic in two places: torch draws the data and the initial
# weights, and Python's `random` draws the network depth on every forward pass.
# Seed both so that a fresh "Restart & Run All" is repeatable.
random.seed(SEED)
torch.manual_seed(SEED)

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = DynamicNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum.
# reduction='sum' sums the squared errors instead of averaging them; it is the
# supported replacement for the deprecated size_average=False argument.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
for t in range(NUM_STEPS):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    # Gradients accumulate across backward() calls, so they must be cleared
    # before each new backward pass.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 685.166748046875
1 716.2672119140625
2 669.1483154296875
3 666.2140502929688
4 649.7866821289062
5 663.4417724609375
6 662.125732421875
7 515.1488037109375
8 474.4584045410156
9 418.6979675292969
10 651.2013549804688
11 604.246826171875
12 273.0561828613281
13 230.65745544433594
14 642.33447265625
15 654.3265991210938
16 631.9901733398438
17 623.1936645507812
18 610.7982788085938
19 95.9649658203125
20 517.8375854492188
21 564.3501586914062
22 87.16880798339844
23 446.4664611816406
24 504.5140686035156
25 379.14117431640625
26 448.0452880859375
27 412.2851257324219
28 372.36846923828125
29 248.2703399658203
30 220.92465209960938
31 281.49908447265625
32 377.87957763671875
33 336.4659423828125
34 179.73191833496094
35 243.58837890625
36 117.1220703125
37 194.98748779296875
38 143.90576171875
39 149.97337341308594
40 217.8767852783203
41 113.43510437011719
42 109.23954010009766
43 85.36133575439453
44 288.4894714355469
45 348.6808776855469
46 207.8501739501953
47 225.65953063964844
48 

446 0.045957356691360474
447 0.9381728172302246
448 0.3522927165031433
449 0.64239901304245
450 0.5884502530097961
451 0.3817349374294281
452 0.08661539107561111
453 0.3979063630104065
454 0.6674526333808899
455 0.13075731694698334
456 0.34571757912635803
457 0.1751527488231659
458 0.41949495673179626
459 0.6893526315689087
460 0.3510071039199829
461 0.49145132303237915
462 0.4775811433792114
463 0.5453117489814758
464 0.47334325313568115
465 0.3833746314048767
466 0.5152800679206848
467 0.2873888909816742
468 0.493534117937088
469 0.08340216428041458
470 0.5045223832130432
471 0.08799059689044952
472 0.0833733007311821
473 0.0680755078792572
474 0.4551537334918976
475 0.507404088973999
476 0.38691529631614685
477 0.2923770546913147
478 0.323481023311615
479 0.5142210125923157
480 0.2810901999473572
481 0.08804640173912048
482 0.2459821254014969
483 0.23974791169166565
484 0.44389182329177856
485 0.17990386486053467
486 0.1565612405538559
487 0.545596182346344
488 0.1389005333185196
48