To showcase the power of PyTorch dynamic graphs, we will implement a very strange model: a fully-connected ReLU network that on each forward pass randomly chooses a number between 1 and 4 and uses that many hidden layers (the input layer always produces the first hidden representation, and a single shared middle layer is then applied 0 to 3 more times), reusing the same weights multiple times to compute the innermost hidden layers.

In [2]:
import torch
import random

In [3]:
class DynamicNet(torch.nn.Module):
    """Fully-connected ReLU network whose depth is re-drawn on every forward pass."""

    def __init__(self, D_in, H, D_out):
        """
        Build the three nn.Linear layers used by forward(): one mapping the
        input to the hidden width, one hidden-to-hidden layer that is shared
        by every repeated middle step, and one mapping the hidden width to
        the output.
        """
        super().__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        Compute the prediction for x with a randomly chosen number of middle
        layers.

        A value of 0, 1, 2, or 3 is drawn with Python's random module on each
        call, and middle_linear is applied that many times, each application
        followed by a ReLU (implemented as a clamp at zero).

        Because the computation graph is rebuilt on every forward pass,
        ordinary Python control flow such as this loop can decide the shape
        of the network at run time.

        Applying the same Module several times within one graph is perfectly
        safe; this is a big improvement from Lua Torch, where each Module
        could be used only once.
        """
        hidden = torch.clamp(self.input_linear(x), min=0)

        # Draw the depth once per call so each pass has a fixed, random depth.
        extra_layers = random.randint(0, 3)
        for _ in range(extra_layers):
            hidden = torch.clamp(self.middle_linear(hidden), min=0)

        return self.output_linear(hidden)

In [4]:
# Problem dimensions.
N = 64        # batch size
D_in = 1000   # input dimension
H = 100       # hidden dimension
D_out = 10    # output dimension

# Random Tensors standing in for the inputs (N x D_in) and the
# target outputs (N x D_out).
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

In [5]:
# Instantiate the network defined above with the chosen layer sizes.
model = DynamicNet(D_in=D_in, H=H, D_out=D_out)

In [6]:
# Loss function and Optimizer. The loss is the squared error summed over all
# elements. Plain stochastic gradient descent struggles with this strange
# model, so momentum is enabled.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    # Start every step with cleared gradients.
    optimizer.zero_grad()

    # Forward pass: run x through the model, then score the prediction
    # against y and report the loss for this step.
    y_pred = model(x)
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Backward pass followed by the weight update.
    loss.backward()
    optimizer.step()

0 637.8958129882812
1 636.5549926757812
2 632.8723754882812
3 614.2298583984375
4 627.6593017578125
5 641.2145385742188
6 621.7235717773438
7 618.8411254882812
8 610.5510864257812
9 613.4346313476562
10 489.9781188964844
11 448.3780822753906
12 608.1876831054688
13 349.1754150390625
14 541.2648315429688
15 596.7841186523438
16 603.661376953125
17 601.8534545898438
18 584.8062744140625
19 577.8792724609375
20 493.40399169921875
21 557.2153930664062
22 581.8602294921875
23 573.761962890625
24 125.92034912109375
25 553.4913330078125
26 540.7816772460938
27 104.73466491699219
28 364.9103088378906
29 85.4630126953125
30 480.9323425292969
31 66.40955352783203
32 390.08428955078125
33 274.93994140625
34 408.7766418457031
35 231.0667266845703
36 303.07470703125
37 336.3463439941406
38 90.8254165649414
39 84.40387725830078
40 64.11050415039062
41 156.04241943359375
42 282.8420104980469
43 116.72854614257812
44 32.92977523803711
45 271.66632080078125
46 94.31676483154297
47 41.22771453857422
48 