A fully-connected ReLU network with one hidden layer, trained to predict y from x by minimizing squared Euclidean distance.
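
As a small sketch of what "squared Euclidean distance" means here, the snippet below compares a hand-written sum of squared differences with `torch.nn.MSELoss(reduction='sum')`, the loss used later in this notebook. The tensor shapes and the seed are illustrative assumptions, not values from the notebook.

```python
import torch

torch.manual_seed(0)  # fixed seed so the sketch is reproducible

# Illustrative shapes only: 4 samples with 3 outputs each.
y_pred = torch.randn(4, 3)
y = torch.randn(4, 3)

# Squared Euclidean distance written out by hand.
manual = (y_pred - y).pow(2).sum()

# The same quantity from the built-in loss with summed reduction.
builtin = torch.nn.MSELoss(reduction='sum')(y_pred, y)

print(manual.item(), builtin.item())
```

With `reduction='sum'` the loss is not divided by the number of elements, so its scale grows with the batch size and the output dimension.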

This implementation defines the model as a custom Module subclass. Whenever you need a model more complex than a simple sequence of existing Modules, you will need to define it this way.
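
For contrast, a network this simple could also be written as a plain sequence of existing Modules. The sketch below uses `torch.nn.Sequential` with illustrative sizes; the name `seq_model` is hypothetical. A custom subclass becomes necessary once the forward pass needs something a fixed sequence cannot express, such as a skip connection, a reused layer, or Python control flow.

```python
import torch

# Illustrative sizes, not the ones used later in the notebook.
D_in, H, D_out = 8, 4, 2

# Linear -> ReLU -> Linear, expressed as a fixed sequence of Modules.
seq_model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# A batch of 3 inputs yields a batch of 3 outputs of size D_out.
out = seq_model(torch.randn(3, D_in))
print(out.shape)  # torch.Size([3, 2])
```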

In [1]:
import torch

In [2]:

class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we instantiate two nn.Linear modules and assign them as
        member variables.
        """
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred
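
The forward pass applies the ReLU nonlinearity with `clamp(min=0)`, an ordinary Tensor operation rather than a Module. As a minimal check, under the assumption that only the elementwise result matters, the sketch below confirms that it agrees with `torch.relu` on a small hand-picked tensor.

```python
import torch

# Hand-picked values covering negative, zero, and positive entries.
z = torch.tensor([[-1.5, 0.0, 2.0],
                  [3.0, -0.1, 0.5]])

via_clamp = z.clamp(min=0)   # the form used in forward()
via_relu = torch.relu(z)     # the functional ReLU

print(via_clamp)
print(torch.equal(via_clamp, via_relu))  # True
```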

In [3]:
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

In [4]:
# Construct our model by instantiating the class defined above
model = TwoLayerNet(D_in, H, D_out)
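
Modules assigned as attributes in the constructor are registered automatically, so their weights and biases show up in `parameters()` and `named_parameters()`. The sketch below lists them for a hypothetical stand-in class, `TinyNet`, which mirrors the two-layer structure above so that the snippet runs on its own.

```python
import torch

class TinyNet(torch.nn.Module):  # hypothetical stand-in for TwoLayerNet
    def __init__(self, D_in, H, D_out):
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        return self.linear2(self.linear1(x).clamp(min=0))

net = TinyNet(1000, 100, 10)

# Map each registered parameter name to its shape.
shapes = {name: tuple(p.shape) for name, p in net.named_parameters()}
for name, shape in shapes.items():
    print(name, shape)

# Total number of learnable scalars: 100*1000 + 100 + 10*100 + 10.
total = sum(p.numel() for p in net.parameters())
print(total)  # 101110
```

These four tensors are what an optimizer receives when it is constructed with `net.parameters()`.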

In [6]:
# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor returns the learnable parameters of the two
# nn.Linear modules that are members of the model.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 680.48876953125
1 627.3369750976562
2 581.5759887695312
3 541.5928344726562
4 506.42340087890625
5 475.12841796875
6 446.62420654296875
7 420.659912109375
8 396.8742980957031
9 374.8042907714844
10 354.3362731933594
11 335.12823486328125
12 316.979736328125
13 299.9565124511719
14 283.8301086425781
15 268.49676513671875
16 253.98379516601562
17 240.285400390625
18 227.17185974121094
19 214.6721954345703
20 202.81863403320312
21 191.56460571289062
22 180.8815460205078
23 170.7073211669922
24 161.054931640625
25 151.89430236816406
26 143.2033233642578
27 134.95925903320312
28 127.14183044433594
29 119.71399688720703
30 112.69258880615234
31 106.04601287841797
32 99.7802734375
33 93.86283874511719
34 88.29151153564453
35 83.02970886230469
36 78.07930755615234
37 73.41121673583984
38 69.01280212402344
39 64.86299133300781
40 60.959232330322266
41 57.299598693847656
42 53.86094665527344
43 50.63246154785156
44 47.604713439941406
45 44.76101303100586
46 42.09214782714844
47 39.571434020996