In [1]:
%matplotlib inline


PyTorch: Control Flow + Weight Sharing
--------------------------------------

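To showcase the power of PyTorch dynamic graphs, we will implement a very strange
model: a fully-connected ReLU network that on each forward pass randomly chooses
a number between 1 and 4 and has that many hidden layers, reusing the same
weights multiple times to compute the innermost hidden layers. (The input layer
always produces the first hidden layer; a single shared middle layer is then
applied 0 to 3 more times, which gives the 1 to 4 hidden layers in total.)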



In [2]:
import random
import torch
from torch.autograd import Variable


class DynamicNet(torch.nn.Module):
    """Fully-connected ReLU network whose depth is re-drawn on every forward pass."""

    def __init__(self, D_in, H, D_out):
        """
        Create the three nn.Linear layers used by forward(): an input projection
        (D_in -> H), one shared hidden layer (H -> H) and an output projection
        (H -> D_out).
        """
        super().__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        Run x through the input layer, then through the shared middle layer a
        random number of times (0, 1, 2 or 3, drawn afresh on every call), and
        finally through the output layer. A ReLU follows every layer except
        the output layer.

        Because the computation graph is rebuilt on each forward pass, ordinary
        Python control flow (loops, conditionals) can decide the network's
        shape at run time.

        Applying the same Module several times inside one graph is safe: every
        application shares the same weights. Lua Torch, by contrast, allowed
        each Module to be used only once.
        """
        # ReLU written as clamp(min=0), applied to the input projection.
        hidden = torch.clamp(self.input_linear(x), min=0)

        # Draw how many extra times the shared hidden layer is applied.
        repeats_left = random.randint(0, 3)
        while repeats_left > 0:
            hidden = torch.clamp(self.middle_linear(hidden), min=0)
            repeats_left -= 1

        return self.output_linear(hidden)


In [3]:


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Seed both RNGs so a fresh "Restart & Run All" reproduces the same run:
# torch drives the random data and the weight initialisation, while Python's
# `random` drives the number of hidden layers DynamicNet picks per forward pass.
random.seed(0)
torch.manual_seed(0)

# Create random Tensors to hold inputs and outputs. Since PyTorch 0.4 plain
# Tensors take part in autograd directly, so the deprecated Variable wrapper is
# not needed; tensors default to requires_grad=False.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = DynamicNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum.
# reduction='sum' is the current spelling of the deprecated size_average=False:
# the squared errors are summed rather than averaged.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss. The loss is a 0-dim tensor, so read its Python
    # value with .item(); indexing it (loss.data[0]) raises on current PyTorch.
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 620.4891357421875
1 618.0677490234375
2 648.7521362304688
3 614.8499145507812
4 610.2169799804688
5 619.343994140625
6 603.6087646484375
7 607.4952392578125
8 606.0618896484375
9 604.703369140625
10 593.4639282226562
11 585.6552734375
12 588.67578125
13 394.923828125
14 599.9515991210938
15 580.244140625
16 597.57861328125
17 546.2041625976562
18 593.6824951171875
19 563.4633178710938
20 556.8660888671875
21 266.631103515625
22 579.8040161132812
23 530.8768310546875
24 463.6202392578125
25 506.1432800292969
26 489.68475341796875
27 169.88270568847656
28 152.59906005859375
29 130.2489013671875
30 417.3656005859375
31 335.2411193847656
32 468.54718017578125
33 71.90532684326172
34 65.40592956542969
35 260.2340393066406
36 239.87025451660156
37 291.664794921875
38 62.372337341308594
39 172.6224822998047
40 153.015625
41 72.78180694580078
42 67.8285903930664
43 107.8554458618164
44 95.7746810913086
45 278.8534240722656
46 256.7251892089844
47 230.15737915039062
48 205.0227508544922
49 62

389 0.3545897901058197
390 1.0907611846923828
391 0.9372134208679199
392 0.6535137295722961
393 0.982441246509552
394 0.48749077320098877
395 0.7933921217918396
396 0.8023062944412231
397 1.0111744403839111
398 0.5391867160797119
399 0.30734434723854065
400 0.1504114866256714
401 0.14359991252422333
402 0.23422028124332428
403 6.473554611206055
404 0.7285903692245483
405 1.0780205726623535
406 0.5025356411933899
407 0.6904423236846924
408 9.163646697998047
409 5.190720558166504
410 3.1262121200561523
411 2.918837547302246
412 8.598191261291504
413 5.3283257484436035
414 1.7247415781021118
415 1.7734147310256958
416 3.070840358734131
417 11.406994819641113
418 33.56575393676758
419 5.490185737609863
420 10.577467918395996
421 32.08915328979492
422 13.71599006652832
423 8.813142776489258
424 0.8618404269218445
425 8.968876838684082
426 9.3276948928833
427 4.003738880157471
428 16.00235366821289
429 4.429225921630859
430 7.274837017059326
431 2.82851505279541
432 5.830391883850098
433 19.