## PyTorch: Control Flow + Weight Sharing
As an example of dynamic graphs and weight sharing, we implement a very strange model: a fully-connected ReLU network that, on each forward pass, picks a random number between 1 and 4 and uses that many hidden layers, reusing the same weights multiple times to compute the innermost hidden layers.

In [1]:
import random
import torch

In [2]:
class DynamicNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        '''
        In the constructor we construct three nn.Linear instances
        that we will use in the forward pass.
        '''
        super(DynamicNet, self).__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)
        
    def forward(self, x):
        '''
        For the forward pass, we randomly choose an integer from 0 to 3
        and reuse the middle_linear Module that many times to compute
        hidden layer representations.
        
        Since each forward pass builds a dynamic computation graph,
        we can use normal Python control-flow operators like loops
        or conditional statements when defining the forward pass of the model.
        '''
        h_relu = self.input_linear(x).clamp(min=0)
        for _ in range(random.randint(0, 3)):
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        y_pred = self.output_linear(h_relu)
        
        return y_pred
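
The weight sharing used in `forward` can be checked in isolation. The sketch below is a stand-alone illustration, not part of the model above: the names `shared`, `apply_shared`, and `x_small`, and the tiny layer size, are assumptions chosen for the example. It applies one `nn.Linear` a variable number of times and confirms that the parameter count stays the same, and that a single gradient tensor is produced for the shared weight however many times the layer was reused.

```python
import torch

torch.manual_seed(0)

# One layer that is reused several times (hypothetical stand-in for middle_linear)
shared = torch.nn.Linear(4, 4)
x_small = torch.randn(2, 4)

def apply_shared(x, depth):
    """Apply the same Linear + ReLU `depth` times, reusing the same weights."""
    h = x
    for _ in range(depth):
        h = shared(h).clamp(min=0)
    return h

# The parameter count does not depend on how often the layer is applied:
# 4*4 weights + 4 biases
n_params = sum(p.numel() for p in shared.parameters())

grad_shapes = {}
for depth in (1, 3):
    shared.zero_grad()
    apply_shared(x_small, depth).sum().backward()
    # autograd sums the contributions of every reuse into one gradient tensor
    grad_shapes[depth] = tuple(shared.weight.grad.shape)

print(n_params, grad_shapes)
```

Because every application goes through the same `Parameter` objects, the optimizer sees one weight matrix and one bias for the middle layer regardless of the depth sampled on a given pass.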

In [3]:
# N: batch size
# D_in: input dimension
# H: hidden dimension
# D_out: output dimension
N, D_in, H, D_out = 64, 1000, 100, 10

In [4]:
# creating random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

In [5]:
# constructing our model by instantiating the class defined above
model = DynamicNet(D_in, H, D_out)

In [6]:
# constructing the loss function and Optimizer
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

In [9]:
# number of epochs
epochs = 500

In [10]:
for t in range(epochs):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)
    
    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())
    
    # Zero gradients, perform a backward pass, and update the weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 165.4602813720703
1 261.2413024902344
2 96.89674377441406
3 159.33856201171875
4 251.6715087890625
5 147.5995635986328
6 184.430908203125
7 108.24604034423828
8 75.5179672241211
9 76.15865325927734
10 96.05320739746094
11 195.11940002441406
12 71.62260437011719
13 68.87356567382812
14 52.84550476074219
15 88.43640899658203
16 59.90333938598633
17 47.007904052734375
18 148.77040100097656
19 114.01355743408203
20 28.117427825927734
21 107.98941040039062
22 52.590171813964844
23 43.41320037841797
24 38.60298538208008
25 62.08235168457031
26 100.0115966796875
27 22.716875076293945
28 43.713951110839844
29 41.14851760864258
30 33.15886688232422
31 41.06728744506836
32 85.38646697998047
33 46.186893463134766
34 50.577178955078125
35 32.59864807128906
36 51.406978607177734
37 28.631372451782227
38 18.074790954589844
39 47.49089813232422
40 19.819948196411133
41 14.962172508239746
42 29.113447189331055
43 29.284706115722656
44 16.281211853027344
45 23.290889739990234
46 81.88089752197266
...
479 0.038315415382385254
480 1.1275575160980225
481 0.08300705999135971
482 0.8971609473228455
483 0.4640955924987793
484 0.08656749874353409
485 0.4761935770511627
486 1.9717895984649658
487 0.2491835206747055
488 0.03292009234428406
489 1.4700579643249512
490 0.3536103069782257
491 0.5469667911529541
492 0.02119143307209015
493 0.7032222747802734
494 0.7734482288360596
495 0.0293220654129982
496 0.5886643528938293
497 0.40399280190467834
498 0.1285419911146164
499 0.8187296986579895
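
The printed loss does not fall smoothly, which is plausibly because each forward pass samples a different depth, so consecutive iterations evaluate different networks. One way to inspect this is to evaluate the loss at each depth separately. The sketch below assumes a hypothetical variant of the model, `FixedDepthNet`, whose `forward` takes an optional `depth` argument; that argument, the class name, and the small demo tensors are not part of the original notebook.

```python
import random
import torch

class FixedDepthNet(torch.nn.Module):
    """Hypothetical variant of DynamicNet that accepts an explicit depth."""

    def __init__(self, D_in, H, D_out):
        super().__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x, depth=None):
        # Fall back to the random choice when no depth is given
        if depth is None:
            depth = random.randint(0, 3)
        h_relu = self.input_linear(x).clamp(min=0)
        for _ in range(depth):
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        return self.output_linear(h_relu)

torch.manual_seed(0)
net = FixedDepthNet(20, 8, 3)
x_demo = torch.randn(5, 20)
y_demo = torch.randn(5, 3)
criterion_demo = torch.nn.MSELoss(reduction='sum')

# Evaluate the same inputs at every possible depth, without tracking gradients
with torch.no_grad():
    losses_by_depth = {
        d: criterion_demo(net(x_demo, depth=d), y_demo).item()
        for d in range(4)
    }

print(losses_by_depth)
```

With the depth held fixed the output is deterministic, so differences between the four values come only from how many times `middle_linear` was applied.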
