In [0]:
import random
import torch

We implement a (weird) fully-connected ReLU network that on each forward pass chooses a random number between 1 and 4 and uses that many hidden layers: the input layer always produces the first hidden layer, and a single shared middle layer is then applied 0 to 3 more times, reusing the same weights to compute the innermost hidden layers.

In [0]:
class DynamicNet(torch.nn.Module):
    """Fully-connected ReLU network whose depth is re-sampled on every forward pass.

    The network always applies ``input_linear`` and ``output_linear``. In
    between, the single ``middle_linear`` layer is applied a random number of
    times (0 to ``max_middle_layers`` inclusive), so the same weights are shared
    by every inner hidden layer.
    """

    def __init__(self, D_in, H, D_out, max_middle_layers=3):
        """
        Instantiate the three nn.Linear modules used in the forward pass.

        Args:
            D_in: Size of each input sample.
            H: Size of every hidden representation.
            D_out: Size of each output sample.
            max_middle_layers: Inclusive upper bound on how many times
                ``middle_linear`` may be applied per forward pass. The default
                of 3 yields between 1 and 4 hidden layers in total; 0 makes the
                network deterministic with exactly one hidden layer.

        Raises:
            ValueError: If ``max_middle_layers`` is negative.
        """
        super(DynamicNet, self).__init__()
        if max_middle_layers < 0:
            raise ValueError(
                "max_middle_layers must be non-negative, got {!r}".format(max_middle_layers)
            )
        self.max_middle_layers = max_middle_layers
        # Creation order is kept (input, middle, output) so that parameter
        # initialisation consumes the torch RNG in the same order as before.
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        For the forward pass of the model, we randomly choose an integer between
        0 and ``max_middle_layers`` (0, 1, 2, or 3 by default) and reuse the
        middle_linear Module that many times to compute hidden layer
        representations.

        The count is drawn from Python's ``random`` module, so seed it with
        ``random.seed`` to make the chosen depth reproducible.
        """
        # clamp(min=0) is the ReLU non-linearity.
        h_relu = self.input_linear(x).clamp(min=0)
        for _ in range(random.randint(0, self.max_middle_layers)):
            # Same Module on every iteration: the weights are shared.
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        y_pred = self.output_linear(h_relu)
        return y_pred
      

In [4]:
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Seed every source of randomness before anything random happens, so that
# Restart & Run All reproduces the same loss curve:
#   - torch drives the random data and the weight initialisation,
#   - Python's `random` drives the per-step depth chosen inside DynamicNet.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = DynamicNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model.
    # The network depth is re-sampled on every call, which is why the printed
    # loss jumps around instead of decreasing monotonically.
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    # Gradients accumulate across backward() calls, so they must be cleared
    # before each new backward pass.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 627.1648559570312
1 624.8715209960938
2 650.8927001953125
3 627.8784790039062
4 621.5174560546875
5 608.927978515625
6 497.1159362792969
7 446.33062744140625
8 614.066650390625
9 587.6821899414062
10 302.4366760253906
11 594.8131103515625
12 592.0335693359375
13 587.6215209960938
14 581.5404052734375
15 573.7177124023438
16 564.1738891601562
17 143.7026824951172
18 534.2293701171875
19 513.8088989257812
20 591.9434204101562
21 502.76080322265625
22 426.2236328125
23 564.0747680664062
24 548.1019287109375
25 526.4710693359375
26 498.5574951171875
27 283.1812438964844
28 254.2452392578125
29 222.02203369140625
30 174.9481658935547
31 362.5265808105469
32 163.79759216308594
33 297.35772705078125
34 268.6069641113281
35 235.6332244873047
36 145.01133728027344
37 123.94227600097656
38 195.7516632080078
39 113.5147476196289
40 64.6424331665039
41 185.19960021972656
42 84.90140533447266
43 51.91461944580078
44 51.59046936035156
45 517.9658203125
46 178.0287322998047
47 289.9883117675781
48 