In [None]:
%matplotlib inline


PyTorch: Custom nn Modules
--------------------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation defines the model as a custom Module subclass. Whenever you
want a model more complex than a simple sequence of existing Modules, you will
need to define your model this way.



In [3]:
import torch

class TwoLayerNet(torch.nn.Module):
    """Fully-connected network with a single ReLU hidden layer (Linear -> ReLU -> Linear)."""

    def __init__(self, D_in, H, D_out):
        """
        Build the two affine layers and register them as submodules.

        Args:
            D_in: number of input features.
            H: number of hidden units.
            D_out: number of output features.
        """
        super().__init__()
        # Assigning nn.Module instances as attributes registers their weights and
        # biases, which is what makes them show up in self.parameters().
        self.linear1 = torch.nn.Linear(in_features=D_in, out_features=H)
        self.linear2 = torch.nn.Linear(in_features=H, out_features=D_out)

    def forward(self, x):
        """
        Map a Tensor of inputs to a Tensor of predictions.

        The input goes through the first linear layer, is clamped at zero from below
        (the ReLU non-linearity), and is then fed to the second linear layer.
        """
        hidden = torch.clamp(self.linear1(x), min=0)
        return self.linear2(hidden)
    
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Seed PyTorch's global RNG so that the random data and the initial weights (and
# therefore the printed losses) are the same on every Restart & Run All.
torch.manual_seed(0)

# Create random input and output data
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the custom Module subclass defined above.
model = TwoLayerNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. The call to model.parameters() in the SGD constructor yields the
# learnable parameters of the two nn.Linear modules which are members of the model.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero the gradients before running the backward pass; backward() accumulates
    # into .grad, so stale gradients from the previous step must be cleared first.
    optimizer.zero_grad()

    # Backward pass: compute the gradient of the loss with respect to the model parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its parameters
    optimizer.step()

0 684.6431884765625
1 636.395263671875
2 594.4322509765625
3 557.5004272460938
4 524.3043823242188
5 494.3757019042969
6 467.4385681152344
7 442.8833312988281
8 420.12445068359375
9 399.0678405761719
10 379.4973449707031
11 361.17291259765625
12 343.9154357910156
13 327.60601806640625
14 312.07391357421875
15 297.2497253417969
16 283.15191650390625
17 269.76458740234375
18 256.93353271484375
19 244.6273956298828
20 232.84896850585938
21 221.56820678710938
22 210.75926208496094
23 200.388427734375
24 190.4715118408203
25 180.98196411132812
26 171.9142303466797
27 163.24472045898438
28 154.96888732910156
29 147.08094787597656
30 139.5520782470703
31 132.36215209960938
32 125.51335906982422
33 118.984130859375
34 112.77422332763672
35 106.87351989746094
36 101.2674331665039
37 95.93646240234375
38 90.88064575195312
39 86.08831787109375
40 81.53853607177734
41 77.2236328125
42 73.12403869628906
43 69.23712921142578
44 65.55020141601562
45 62.05568313598633
46 58.74513626098633
47 55.612339

377 6.810043851146474e-05
378 6.581954221474007e-05
379 6.36159093119204e-05
380 6.148531247163191e-05
381 5.94327284488827e-05
382 5.744309237343259e-05
383 5.551880894927308e-05
384 5.366704863263294e-05
385 5.1876693760277703e-05
386 5.014684938942082e-05
387 4.847067248192616e-05
388 4.6854263928253204e-05
389 4.529393481789157e-05
390 4.3785639718407765e-05
391 4.232787250657566e-05
392 4.091809751116671e-05
393 3.955721331294626e-05
394 3.8243149901973084e-05
395 3.697204010677524e-05
396 3.5745437344303355e-05
397 3.45579901477322e-05
398 3.3412183256587014e-05
399 3.23055028275121e-05
400 3.123383430647664e-05
401 3.020002259290777e-05
402 2.9198297852417454e-05
403 2.8231630494701676e-05
404 2.73004116024822e-05
405 2.6396675821160898e-05
406 2.5525454475427978e-05
407 2.4680673959665e-05
408 2.3867523850640282e-05
409 2.3079315724316984e-05
410 2.2317219190881588e-05
411 2.1580985048785806e-05
412 2.0871508240816183e-05
413 2.018285886151716e-05
414 1.9519457055139355e-05
415