In [None]:
%matplotlib inline


PyTorch: Custom nn Modules
--------------------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation defines the model as a custom Module subclass. Whenever you
want a model more complex than a simple sequence of existing Modules, you will
need to define your model this way.



In [1]:
import torch


class TwoLayerNet(torch.nn.Module):
    """Two nn.Linear layers with a ReLU (clamp at zero) between them."""

    def __init__(self, D_in, H, D_out):
        """
        Build the two nn.Linear sub-modules and store them as attributes so
        that their weights and biases are registered as parameters of this
        Module.

        D_in is the input dimension, H the hidden dimension and D_out the
        output dimension.
        """
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        Map a Tensor of input data to a Tensor of predictions.

        The input goes through the first linear layer, negative activations
        are clamped to zero (ReLU), and the result goes through the second
        linear layer.
        """
        hidden = torch.clamp(self.linear1(x), min=0)
        return self.linear2(hidden)


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = TwoLayerNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. model.parameters() yields the
# learnable parameters of the two nn.Linear modules which are members of the
# model; passing it to the SGD constructor tells the optimizer which Tensors
# it should update.
# reduction='sum' is the replacement for the deprecated size_average=False:
# the squared errors are summed rather than averaged.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    # Gradients are zeroed first because backward() accumulates into .grad.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 724.3505859375
1 670.3387451171875
2 623.8209228515625
3 583.058349609375
4 546.6366577148438
5 513.4913330078125
6 483.0377502441406
7 454.89349365234375
8 428.8160705566406
9 404.4842529296875
10 381.7021179199219
11 360.38873291015625
12 340.2685546875
13 321.18829345703125
14 303.0686950683594
15 285.8721618652344
16 269.4504699707031
17 253.77236938476562
18 238.8438720703125
19 224.63856506347656
20 211.13441467285156
21 198.28201293945312
22 186.09686279296875
23 174.52684020996094
24 163.55003356933594
25 153.11355590820312
26 143.24098205566406
27 133.9525909423828
28 125.22652435302734
29 117.0047378540039
30 109.24691772460938
31 101.9602279663086
32 95.11880493164062
33 88.72640228271484
34 82.75401306152344
35 77.1789321899414
36 71.977783203125
37 67.13279724121094
38 62.61711502075195
39 58.41838073730469
40 54.51648712158203
41 50.88768768310547
42 47.515201568603516
43 44.391929626464844
44 41.488494873046875
45 38.7930908203125
46 36.28755187988281
47 33.96456146240

396 0.00010048761032521725
397 9.749881428433582e-05
398 9.459884313400835e-05
399 9.178585605695844e-05
400 8.905752474674955e-05
401 8.641182648716494e-05
402 8.385132969124243e-05
403 8.13594160717912e-05
404 7.89469777373597e-05
405 7.660653500352055e-05
406 7.433307473547757e-05
407 7.21293909009546e-05
408 6.999797187745571e-05
409 6.792126077925786e-05
410 6.590816337848082e-05
411 6.39643840258941e-05
412 6.206794205354527e-05
413 6.023239984642714e-05
414 5.845347186550498e-05
415 5.672735642292537e-05
416 5.50492295587901e-05
417 5.3425530495587736e-05
418 5.1846796850441024e-05
419 5.0319096772000194e-05
420 4.883572910330258e-05
421 4.739240466733463e-05
422 4.59965922345873e-05
423 4.463949881028384e-05
424 4.332458775024861e-05
425 4.205185905448161e-05
426 4.081570659764111e-05
427 3.961402762797661e-05
428 3.844917955575511e-05
429 3.7316993257263675e-05
430 3.62192586180754e-05
431 3.515621574479155e-05
432 3.412362275412306e-05
433 3.31203373207245e-05
434 3.214956814