In [1]:
import torch

In [2]:
class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        '''
        In the constructor we instantiate two nn.Linear modules
        and assign them as member values.
        '''
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)
        
    def forward(self, x):
        '''
        In the forward function we accept a Tensor of input data
        and we must return a Tensor of output data. We can use 
        Modules defined in the constuctor as well as arbitrary
        operators on Tensors.
        '''
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred

In [3]:
# N: batch size
# D_in: input dimension
# H: hidden dimension
# D_out: output dimension
N, D_in, H, D_out = 64, 1000, 100, 10

In [24]:
# creating random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

In [25]:
# instantiating the class
model = TwoLayerNet(D_in, H, D_out)

In [26]:
# loss and Optimizer function
# model.parameters() in the SGD constructor contains
# the learnable parameters of the two nn.Linear modules
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

In [27]:
# no. of epochs
epochs = 500

In [28]:
for t in range(epochs):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)
    
    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())
    
    # zero gradients, perform a backward pass, and update weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 1.1038658618927002
1 1.103738784790039
2 1.1036115884780884
3 1.1034845113754272
4 1.1033574342727661
5 1.103230357170105
6 1.1031032800674438
7 1.1029763221740723
8 1.1028493642807007
9 1.102722406387329
10 1.1025954484939575
11 1.102468490600586
12 1.102341651916504
13 1.1022148132324219
14 1.1020878553390503
15 1.1019611358642578
16 1.1018344163894653
17 1.1017078161239624
18 1.1015812158584595
19 1.1014546155929565
20 1.1013281345367432
21 1.1012015342712402
22 1.1010750532150269
23 1.1009485721588135
24 1.1008220911026
25 1.1006956100463867
26 1.100569248199463
27 1.100442886352539
28 1.1003164052963257
29 1.1001900434494019
30 1.1000638008117676
31 1.0999374389648438
32 1.09981107711792
33 1.0996848344802856
34 1.0995585918426514
35 1.099432349205017
36 1.0993062257766724
37 1.0991801023483276
38 1.099053978919983
39 1.0989279747009277
40 1.098801851272583
41 1.0986758470535278
42 1.0985498428344727
43 1.0984238386154175
44 1.0982978343963623
45 1.0981719493865967
46 1.09804594