In [1]:
%matplotlib inline


PyTorch: Control Flow + Weight Sharing
--------------------------------------

To showcase the power of PyTorch dynamic graphs, we will implement a very strange
model: a fully-connected ReLU network that on each forward pass randomly chooses
a number between 1 and 4 and has that many hidden layers, reusing the same
weights multiple times to compute the innermost hidden layers.



In [6]:
import random
import torch

class DynamicNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we instantiate three nn.Linear modules that we will use in the forward pass.
        """
        super(DynamicNet, self).__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose either 0, 1, 2 or 3 and reuse the middle_linear Module
        that many times to compute hidden layer representations.
        
        Since each forward pass builds a dynamic computation graph, we can use normal Python control-flow operators 
        like loops or conditional statements when defining the forward pass of the model.
        
        Here we also see that it is perfectly safe to reuse the same Module many times when defining a computational graph.
        This is a big improvement from Lua Torch, where each Module could be used only once.
        """
        h_relu = self.input_linear(x).clamp(min=0)
        for _ in range(random.randint(0,3)):
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        y_pred = self.output_linear(h_relu)
        return y_pred
        
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model as a sequence of layers. 
# nn.Sequential is a Module which contains other Modules, and applies them in sequence to produce its output.
# Each Linear Module computes output from input using a linear function, and holds internal Tensors for its weight and bias.
model = DynamicNet(D_in, H, D_out)

# Construct out loss functions and Optimizer. 
# Training this strange model with vanilla SGD is tough, so we use momentum
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    # Forward pass
    y_pred = model(x)
    
    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())
    
    # Zero the gradients before running the backward pass
    model.zero_grad()
    
    # Backward pass
    loss.backward()
    
    # Calling the step function on an Optimizer makes and update to its parameters
    optimizer.step()

0 686.158935546875
1 645.3951416015625
2 636.0466918945312
3 638.6068725585938
4 637.2940673828125
5 508.0
6 626.281494140625
7 623.15478515625
8 606.9229736328125
9 630.5114135742188
10 612.5642700195312
11 342.9231262207031
12 314.528076171875
13 601.3003540039062
14 244.7935791015625
15 207.2559051513672
16 562.058349609375
17 618.7696533203125
18 116.0136489868164
19 522.2223510742188
20 82.94682312011719
21 559.734619140625
22 66.45045471191406
23 435.85894775390625
24 404.7574157714844
25 67.23566436767578
26 492.2685546875
27 300.6871643066406
28 442.68426513671875
29 524.11328125
30 117.36742401123047
31 188.24317932128906
32 110.45944213867188
33 90.54247283935547
34 63.07945251464844
35 139.38507080078125
36 127.74549865722656
37 248.20631408691406
38 339.3763122558594
39 307.3236999511719
40 270.86614990234375
41 52.97814178466797
42 88.54129791259766
43 51.17705154418945
44 224.05467224121094
45 96.63970184326172
46 110.48670196533203
47 64.12554168701172
48 318.92840576171

382 1.7385120391845703
383 0.6553234457969666
384 0.42818766832351685
385 1.0637290477752686
386 2.326870918273926
387 1.3079681396484375
388 1.8601759672164917
389 1.1364649534225464
390 0.9749656915664673
391 0.7625821828842163
392 0.8831744194030762
393 0.8885843753814697
394 0.7702021598815918
395 0.885308563709259
396 0.7721874713897705
397 0.39400985836982727
398 0.6588807106018066
399 0.761936604976654
400 1.8723974227905273
401 0.41867193579673767
402 1.9484823942184448
403 1.6181418895721436
404 0.34289440512657166
405 1.4806108474731445
406 0.39685243368148804
407 0.40670090913772583
408 0.35830971598625183
409 0.2784680724143982
410 0.6272037625312805
411 0.16187551617622375
412 0.5833662748336792
413 1.1715284585952759
414 0.23817725479602814
415 0.2820730209350586
416 0.26332560181617737
417 0.736952543258667
418 0.1797962635755539
419 0.4604541063308716
420 0.8670156002044678
421 0.7974526286125183
422 0.2068466693162918
423 0.6982595324516296
424 0.45856165885925293
425 