In [None]:
%matplotlib inline


PyTorch: Defining new autograd functions
----------------------------------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation computes the forward pass using operations on PyTorch
Variables, and uses PyTorch autograd to compute gradients.

In this implementation we implement our own custom autograd function to perform
the ReLU function.



In [1]:
import torch


class MyReLU(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions by subclassing
    torch.autograd.Function and implementing the forward and backward passes
    which operate on Tensors.
    """

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for backward computation. You can cache arbitrary
        objects for use in the backward pass using the ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input


dtype = torch.float
device = torch.device("cpu")
# dtype = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Create random Tensors for weights.
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(500):
    # To apply our Function, we use Function.apply method. We alias this as 'relu'.
    relu = MyReLU.apply

    # Forward pass: compute predicted y using operations; we compute
    # ReLU using our custom autograd operation.
    y_pred = relu(x.mm(w1)).mm(w2)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Use autograd to compute the backward pass.
    loss.backward()

    # Update weights using gradient descent
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # Manually zero the gradients after updating weights
        w1.grad.zero_()
        w2.grad.zero_()

0 31368516.0
1 27116342.0
2 26994252.0
3 26711260.0
4 23696660.0
5 18024618.0
6 11763796.0
7 6974952.0
8 4040564.0
9 2451887.75
10 1620544.875
11 1171665.75
12 909745.4375
13 741265.6875
14 622273.3125
15 531981.625
16 459972.59375
17 400775.65625
18 351157.3125
19 309089.28125
20 273072.5
21 242047.03125
22 215222.0
23 191908.328125
24 171558.109375
25 153713.8125
26 138014.609375
27 124164.15625
28 111907.875
29 101041.40625
30 91367.4765625
31 82750.0
32 75056.9609375
33 68176.8515625
34 62007.78515625
35 56462.8359375
36 51458.109375
37 46949.25390625
38 42878.80078125
39 39199.65625
40 35875.78125
41 32872.98046875
42 30150.13671875
43 27674.638671875
44 25423.931640625
45 23373.302734375
46 21502.181640625
47 19794.759765625
48 18235.1953125
49 16808.55859375
50 15503.115234375
51 14307.4345703125
52 13211.0234375
53 12205.919921875
54 11282.8203125
55 10434.869140625
56 9655.3740234375
57 8938.619140625
58 8278.58984375
59 7670.88232421875
60 7110.953125
61 6594.8525390625
62 61

436 6.25054890406318e-05
437 6.147770181996748e-05
438 6.0576530813705176e-05
439 5.935509761911817e-05
440 5.855140261701308e-05
441 5.7672052207635716e-05
442 5.6561890232842416e-05
443 5.56758968741633e-05
444 5.463084380608052e-05
445 5.3761563322041184e-05
446 5.270666952128522e-05
447 5.198900180403143e-05
448 5.108256664243527e-05
449 5.025399150326848e-05
450 4.944618194713257e-05
451 4.869233089266345e-05
452 4.777517824550159e-05
453 4.683215593104251e-05
454 4.6289067540783435e-05
455 4.562174581224099e-05
456 4.499458009377122e-05
457 4.425035876920447e-05
458 4.371292016003281e-05
459 4.2987914639525115e-05
460 4.2368043068563566e-05
461 4.1737275751074776e-05
462 4.121288293390535e-05
463 4.063494270667434e-05
464 3.9895061490824446e-05
465 3.942671901313588e-05
466 3.869970532832667e-05
467 3.807366010732949e-05
468 3.7536578020080924e-05
469 3.7137953768251464e-05
470 3.672583989100531e-05
471 3.6165791243547574e-05
472 3.5642475268105045e-05
473 3.512912371661514e-05
4