In [1]:
%matplotlib inline


PyTorch: Custom nn Modules
--------------------------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation defines the model as a custom Module subclass. Whenever you
want a model more complex than a simple sequence of existing Modules you will
need to define your model this way.



In [2]:
import torch


class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we instantiate two nn.Linear modules and assign them as
        member variables.
        """
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = TwoLayerNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters of the two
# nn.Linear modules which are members of the model.
criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 686.9542846679688
1 634.451171875
2 589.1823120117188
3 549.3012084960938
4 513.879150390625
5 481.9059143066406
6 453.0523376464844
7 426.7938232421875
8 402.4256286621094
9 379.7935791015625
10 358.55865478515625
11 338.73822021484375
12 320.1229248046875
13 302.5858459472656
14 286.0292053222656
15 270.29901123046875
16 255.3907470703125
17 241.27320861816406
18 227.87142944335938
19 215.20631408691406
20 203.12429809570312
21 191.617919921875
22 180.68690490722656
23 170.33123779296875
24 160.54180908203125
25 151.21832275390625
26 142.38148498535156
27 134.03012084960938
28 126.15363311767578
29 118.71261596679688
30 111.66876983642578
31 105.02369689941406
32 98.76378631591797
33 92.87251281738281
34 87.2999267578125
35 82.0508804321289
36 77.10986328125
37 72.46900939941406
38 68.11077117919922
39 64.02490997314453
40 60.1901969909668
41 56.58099365234375
42 53.19809341430664
43 50.016998291015625
44 47.03483200073242
45 44.23588180541992
46 41.61253356933594
47 39.13805007934

403 1.782151412044186e-05
404 1.7253185433219187e-05
405 1.6704034351278096e-05
406 1.6173546100617386e-05
407 1.5659226846764795e-05
408 1.5160311704676133e-05
409 1.4679354535473976e-05
410 1.4213873328117188e-05
411 1.3761885384155903e-05
412 1.3324894098332152e-05
413 1.2902564776595682e-05
414 1.2494057955336757e-05
415 1.2098275874450337e-05
416 1.171531494037481e-05
417 1.134414560510777e-05
418 1.0985183507727925e-05
419 1.0639475476637017e-05
420 1.0302263945050072e-05
421 9.976946785172913e-06
422 9.662579032010399e-06
423 9.357478120364249e-06
424 9.062619028554764e-06
425 8.776893992035184e-06
426 8.500183866999578e-06
427 8.232649634010158e-06
428 7.973510946612805e-06
429 7.722505870333407e-06
430 7.479656233044807e-06
431 7.243721029226435e-06
432 7.016451036179205e-06
433 6.796541129006073e-06
434 6.583052254427457e-06
435 6.376568308041897e-06
436 6.176441274874378e-06
437 5.9827498262166046e-06
438 5.795832294097636e-06
439 5.6143567235267255e-06
440 5.438695097836899