In [None]:
%matplotlib inline


PyTorch: nn
-----------

A fully-connected ReLU network with one hidden layer, trained to predict y from x
by minimizing squared Euclidean distance.

This implementation uses the nn package from PyTorch to build the network.
PyTorch autograd makes it easy to define computational graphs and take gradients,
but raw autograd can be a bit too low-level for defining complex neural networks;
this is where the nn package can help. The nn package defines a set of Modules,
which you can think of as a neural network layer that has produces output from
input and may have some trainable weights.



In [4]:
import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model as a sequence of layers. 
# nn.Sequential is a Module which contains other Modules, and applies them in sequence to produce its output.
# Each Linear Module computes output from input using a linear function, and holds internal Tensors for its weight and bias.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# The nn package also contains definitions of popular loss functions; in this case we will use MSE as loss function
loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4
for t in range(500):
    # Forward pass
    y_pred = model(x)
    
    # Compute and print loss
    loss = loss_fn(y_pred, y)
    print(t, loss.item())
    
    # Zero the gradients before running the backward pass
    model.zero_grad()
    
    # Backward pass
    loss.backward()
    
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

0 689.7198486328125
1 637.0250244140625
2 591.122802734375
3 550.8493041992188
4 514.6917724609375
5 482.4339904785156
6 452.9874572753906
7 426.0262451171875
8 401.0783996582031
9 378.09686279296875
10 356.7677001953125
11 336.6618347167969
12 317.73712158203125
13 299.9366760253906
14 283.01104736328125
15 266.9770202636719
16 251.75967407226562
17 237.33480834960938
18 223.65911865234375
19 210.67059326171875
20 198.38137817382812
21 186.75399780273438
22 175.73867797851562
23 165.32000732421875
24 155.4608612060547
25 146.13311767578125
26 137.29576110839844
27 128.9393768310547
28 121.04786682128906
29 113.60173797607422
30 106.58842468261719
31 99.97036743164062
32 93.74945068359375
33 87.89073181152344
34 82.38420867919922
35 77.21085357666016
36 72.35086059570312
37 67.7947769165039
38 63.52042007446289
39 59.52016067504883
40 55.776756286621094
41 52.2702751159668
42 48.9859619140625
43 45.907432556152344
44 43.03394317626953
45 40.3491096496582
46 37.832305908203125
47 35.481

393 3.556175215635449e-05
394 3.451555676292628e-05
395 3.3502623409731314e-05
396 3.251800808357075e-05
397 3.1567327823722735e-05
398 3.0641636840300635e-05
399 2.974254311993718e-05
400 2.887102345994208e-05
401 2.8028844099026173e-05
402 2.72083307208959e-05
403 2.6416772016091272e-05
404 2.5643870685598813e-05
405 2.489598591637332e-05
406 2.4171344193746336e-05
407 2.3468721337849274e-05
408 2.278597094118595e-05
409 2.211954051745124e-05
410 2.1478783310158178e-05
411 2.08561268664198e-05
412 2.0249286535545252e-05
413 1.966334275493864e-05
414 1.9092722141067497e-05
415 1.8538295989856124e-05
416 1.8003765944740735e-05
417 1.748215072439052e-05
418 1.6977950508589856e-05
419 1.648815486987587e-05
420 1.6011759726097807e-05
421 1.5549920135526918e-05
422 1.5102520592336077e-05
423 1.466716275899671e-05
424 1.4244887097447645e-05
425 1.3834629498887807e-05
426 1.3435971595754381e-05
427 1.3050308552919887e-05
428 1.2676277037826367e-05
429 1.2311090358707588e-05
430 1.19580799946