In [None]:
%matplotlib inline


PyTorch: Defining New autograd Functions
----------------------------------------

A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.

This implementation computes the forward pass using operations on PyTorch
Variables, and uses PyTorch autograd to compute gradients.

In this implementation we implement our own custom autograd function to perform
the ReLU function.



In [4]:
import torch

class MyReLU(torch.autograd.Function):
    """
    We can implement out own custom autograd Functions by subclassing torch.autograd.Function and implementing the forward
    and backward passes which operate on Tensors.
    """
    
    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we receive a Tensor containing the input and return a Tensor containing the output. 
        ctx is a context object that can be used to stash information for backward computation. You can cache arbitrary
        objects for use in the backward pass using the ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        return input.clamp(min=0)
    
    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss with respect to the output, 
        and we need to compute the gradient of the loss with respect to the input.
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input
    

    
dtype = torch.float
device = torch.device("cuda:0")

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.

N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialize weights and set gradients compute
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(500):
    # To apply our Function we use Function.apply method. We alias this as 'relu'.
    relu = MyReLU.apply
    
    # Forward pass: computed predicted y using operations. We compute ReLU using our custom autograd operation.
    y_pred = relu(x.mm(w1)).mm(w2)
    
    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())
    
    # Backprop to compute gradients of w1 and w2 w.r.t loss
    loss.backward()
    
    # Update weights
    with torch.no_grad():
        w1 -=learning_rate * w1.grad
        w2 -=learning_rate * w2.grad
        
        # Manually zero the gradients after updating weights
        w1.grad.zero_()
        w2.grad.zero_()

0 31413648.0
1 25686550.0
2 24327268.0
3 23221008.0
4 20491104.0
5 15970276.0
6 11062328.0
7 7004372.0
8 4294336.5
9 2670113.5
10 1750279.0
11 1225514.25
12 915114.25
13 719471.25
14 587329.1875
15 491794.625
16 418853.125
17 360896.21875
18 313515.6875
19 274034.25
20 240663.65625
21 212193.875
22 187689.0625
23 166496.75
24 148108.40625
25 132061.546875
26 118025.4765625
27 105703.984375
28 94864.203125
29 85294.8671875
30 76827.46875
31 69323.7265625
32 62650.16015625
33 56706.90625
34 51401.4921875
35 46652.5390625
36 42396.54296875
37 38579.24609375
38 35146.515625
39 32056.978515625
40 29275.861328125
41 26765.068359375
42 24495.830078125
43 22441.58203125
44 20578.86328125
45 18889.59375
46 17354.28125
47 15958.3818359375
48 14688.8037109375
49 13531.3935546875
50 12475.6728515625
51 11511.451171875
52 10629.66796875
53 9823.748046875
54 9085.466796875
55 8408.5634765625
56 7787.35107421875
57 7217.30810546875
58 6693.40576171875
59 6211.77392578125
60 5768.68798828125
61 5360.4

424 0.00021186517551541328
425 0.00020707747898995876
426 0.00020306264923419803
427 0.0001988906878978014
428 0.00019451265688985586
429 0.00019030935072805732
430 0.00018619338516145945
431 0.00018223203369416296
432 0.00017831931472755969
433 0.00017484009731560946
434 0.00017141674470622092
435 0.00016742442676331848
436 0.00016401856555603445
437 0.00016054656589403749
438 0.00015697741764597595
439 0.00015374629583675414
440 0.00015092945250216872
441 0.00014823541278019547
442 0.00014468689914792776
443 0.0001423382491338998
444 0.00013918636250309646
445 0.00013679885887540877
446 0.00013423929340206087
447 0.00013149289588909596
448 0.00012952665565535426
449 0.00012710060400422662
450 0.00012436023098416626
451 0.00012233969755470753
452 0.0001198705576825887
453 0.00011792506120400503
454 0.00011636711133178324
455 0.00011442979302955791
456 0.00011222963803447783
457 0.0001101930538425222
458 0.00010782614117488265
459 0.00010644999565556645
460 0.0001047117548296228
461 0.