In [1]:
import torch

In [10]:
x = torch.tensor(-2,dtype=torch.float, requires_grad=True)
y = torch.tensor(5,dtype=torch.float, requires_grad=True)
z = torch.tensor(-4,dtype=torch.float, requires_grad=True)

In [11]:
print (x,y,z)

tensor(-2., requires_grad=True) tensor(5., requires_grad=True) tensor(-4., requires_grad=True)


In [12]:
# perform the forward pass
q = x + y # q becomes 3
f = q * z # f becomes -12

In [13]:
print (q,f)

tensor(3., grad_fn=<AddBackward0>) tensor(-12., grad_fn=<MulBackward0>)


In [17]:
f.backward()

In [18]:
print (x.grad)

tensor(-4.)


In [19]:
print (y.grad)

tensor(-4.)


In [20]:
# perform the backward pass (backpropagation) in reverse order:
# first backprop through f = q * z
dfdz = q # df/dz = q, so gradient on z becomes 3
dfdq = z # df/dq = z, so gradient on q becomes -4
# now backprop through q = x + y
dfdx = 1.0 * dfdq # dq/dx = 1. And the multiplication here is the chain rule!
dfdy = 1.0 * dfdq # dq/dy = 1

print (dfdx, dfdy)

tensor(-4., grad_fn=<MulBackward0>) tensor(-4., grad_fn=<MulBackward0>)


In [21]:
print (z.grad)

tensor(3.)


In [30]:
from torch.autograd import Variable

In [80]:
W = torch.randn(5, 10,dtype=torch.float, requires_grad=True)
X = torch.randn(10, 3,dtype=torch.float, requires_grad=True)
D = W @ X
P = D.mean()

In [81]:
print (P)

tensor(-1.3501, grad_fn=<MeanBackward0>)


In [82]:
P.backward()

In [83]:
print (W.grad.shape)
print (X.grad.shape)

torch.Size([5, 10])
torch.Size([10, 3])


In [103]:
W = torch.randn(5, 10,dtype=torch.float, requires_grad=True)
X = torch.randn(10, 3,dtype=torch.float, requires_grad=True)
D = W @ X

In [106]:
explicit_grad_to_D = torch.randn(5,3)

In [107]:
D.backward(explicit_grad_to_D)

In [108]:
print (X.grad)

tensor([[-2.0097,  1.8341,  0.7918],
        [-0.8307,  2.2695, -1.6885],
        [-0.8939, -2.0668, -5.4994],
        [ 3.7884, -0.9896,  2.6973],
        [-3.4458,  2.0048, -4.2438],
        [ 1.7317, -1.9592, -0.7908],
        [ 0.1546,  2.1604,  2.1281],
        [-1.5641,  1.6967,  3.8119],
        [-0.4025, -0.6474, -3.5855],
        [-0.5514,  2.3924, -0.7945]])
