In [0]:
import torch

## Defining custom nn module

In [0]:
class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we instantiate two nn.Linear modules and assign them as
        member variables.
        """
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred

## Use our custom nn module

In [3]:
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = TwoLayerNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters of the two
# nn.Linear modules which are members of the model.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 736.8270874023438
1 684.160888671875
2 639.2748413085938
3 600.6077880859375
4 566.5045776367188
5 535.7425537109375
6 507.5771789550781
7 481.83465576171875
8 458.1763916015625
9 436.0893859863281
10 415.3513488769531
11 395.8398132324219
12 377.2782897949219
13 359.56842041015625
14 342.55572509765625
15 326.248779296875
16 310.65484619140625
17 295.6502685546875
18 281.24114990234375
19 267.3606872558594
20 253.957275390625
21 241.05027770996094
22 228.63514709472656
23 216.69972229003906
24 205.22238159179688
25 194.1946258544922
26 183.5950469970703
27 173.43321228027344
28 163.73406982421875
29 154.46853637695312
30 145.62994384765625
31 137.24935913085938
32 129.29934692382812
33 121.74323272705078
34 114.56651306152344
35 107.76819610595703
36 101.34439849853516
37 95.27494812011719
38 89.55268859863281
39 84.16302490234375
40 79.1001205444336
41 74.33307647705078
42 69.84589385986328
43 65.63115692138672
44 61.672630310058594
45 57.953224182128906
46 54.468257904052734
47 51