In [None]:
%matplotlib inline


PyTorch: Defining New autograd Functions
----------------------------------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation computes the forward pass using operations on PyTorch
Tensors, and uses PyTorch autograd to compute gradients.

Here we define our own custom autograd Function to compute the ReLU
nonlinearity.



In [1]:
import torch


class MyReLU(torch.autograd.Function):
    """
    A hand-written ReLU for use with autograd.

    A custom differentiable operation is defined by subclassing
    torch.autograd.Function and supplying static forward and backward
    methods, both of which operate on Tensors.
    """

    @staticmethod
    def forward(ctx, input):
        """
        Compute the element-wise ReLU of the input Tensor and return it.

        ctx is the context object shared with backward. The input Tensor is
        stashed on it with ctx.save_for_backward so that backward can tell
        which elements were clamped to zero.
        """
        output = input.clamp(min=0)
        ctx.save_for_backward(input)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        Receive the gradient of the loss with respect to the output and
        return the gradient of the loss with respect to the input.
        """
        (saved_input,) = ctx.saved_tensors
        # The incoming gradient passes through unchanged, except where the
        # input was strictly negative; there it is zeroed. masked_fill is
        # out-of-place, so grad_output itself is left untouched.
        return grad_output.masked_fill(saved_input < 0, 0)


# All tensors below are 32-bit floats living on the CPU.
dtype = torch.float32
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# Problem dimensions.
N = 64        # batch size
D_in = 1000   # input dimension
H = 100       # hidden dimension
D_out = 10    # output dimension

# Random input batch and random targets.
x = torch.randn((N, D_in), device=device, dtype=dtype)
y = torch.randn((N, D_out), device=device, dtype=dtype)

# Randomly initialised weight matrices; requires_grad=True makes autograd
# track operations on them so that gradients can be computed for them.
w1 = torch.randn((D_in, H), device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn((H, D_out), device=device, dtype=dtype, requires_grad=True)

# Step size for the gradient-descent updates.
learning_rate = 1e-6
# To apply our Function, we use the Function.apply method, aliased here as
# 'relu'. The alias is the same on every iteration, so it is bound once
# before the loop rather than on each pass.
relu = MyReLU.apply

for t in range(500):
    # Forward pass: compute predicted y using operations; we compute
    # ReLU using our custom autograd operation.
    y_pred = relu(x.mm(w1)).mm(w2)

    # Compute and print loss (sum of squared errors over the batch).
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Use autograd to compute the backward pass. This populates w1.grad
    # and w2.grad.
    loss.backward()

    # Update weights using gradient descent. The in-place updates are
    # wrapped in torch.no_grad() so that autograd does not track them.
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # Manually zero the gradients after updating weights, so the next
        # backward pass does not add onto this iteration's gradients.
        w1.grad.zero_()
        w2.grad.zero_()

0 42057624.0
1 41157980.0
2 38808380.0
3 30229334.0
4 18554472.0
5 9621538.0
6 4931002.5
7 2856007.0
8 1926771.625
9 1450763.0
10 1161231.125
11 959224.5
12 805862.5625
13 684087.625
14 585033.0625
15 503369.3125
16 435371.59375
17 378312.40625
18 330138.53125
19 289265.15625
20 254411.703125
21 224492.3125
22 198684.609375
23 176357.078125
24 156968.765625
25 140073.453125
26 125316.734375
27 112384.96875
28 101011.2421875
29 90980.5390625
30 82110.4921875
31 74237.546875
32 67236.9140625
33 60995.9296875
34 55423.4140625
35 50434.875
36 45960.8984375
37 41940.25
38 38321.91796875
39 35058.03515625
40 32109.81640625
41 29445.36328125
42 27035.888671875
43 24855.162109375
44 22871.451171875
45 21067.01953125
46 19423.40625
47 17925.0625
48 16555.56640625
49 15303.4560546875
50 14157.205078125
51 13107.7880859375
52 12144.8076171875
53 11260.240234375
54 10447.50390625
55 9700.0859375
56 9013.1083984375
57 8381.0751953125
58 7798.0283203125
59 7259.80908203125
60 6762.80859375
61 6303.3

416 0.0004680791462305933
417 0.0004555700288619846
418 0.00044283716124482453
419 0.0004308807256165892
420 0.0004197129455860704
421 0.0004094038449693471
422 0.0003987450909335166
423 0.0003876851696986705
424 0.00037719920510426164
425 0.0003684620023705065
426 0.00035879481583833694
427 0.00034950030385516584
428 0.0003407062613405287
429 0.00033312157029286027
430 0.00032627591281197965
431 0.0003183223307132721
432 0.0003100494504906237
433 0.00030293010058812797
434 0.00029629163327626884
435 0.00028917062445543706
436 0.000282198452623561
437 0.0002765730023384094
438 0.0002694855793379247
439 0.00026313646230846643
440 0.000256833533057943
441 0.00025161702069453895
442 0.00024545230553485453
443 0.0002403398248134181
444 0.00023494652123190463
445 0.00022974605963099748
446 0.00022465288930106908
447 0.0002197691355831921
448 0.0002146622573491186
449 0.00021044549066573381
450 0.0002058181562460959
451 0.00020202450104989111
452 0.00019731292559299618
453 0.0001933701714733