In [55]:
import torch

Setting `requires_grad=True` on a tensor tells autograd to track every operation applied to it, so that a later call to `backward()` can compute the gradients with respect to that tensor automatically.

In [56]:
# A 2x2 tensor of ones whose operations are recorded by autograd.
shape = (2, 2)
x = torch.ones(shape, requires_grad=True)
print(x)

tensor([[1., 1.],
        [1., 1.]], requires_grad=True)


In [57]:
# Adding a constant to a tracked tensor yields a new tensor that remembers
# the operation that produced it.
offset = 2
y = x + offset
print(y)

tensor([[3., 3.],
        [3., 3.]], grad_fn=<AddBackward0>)


Because an operation was applied to `x`, the resulting tensor `y` carries a `grad_fn` attribute that records the operation that created it (here an addition, `AddBackward0`).

In [58]:
# Build z = 2 * y^2 step by step, then reduce it to a single scalar.
y_squared = y * y
z = y_squared * 2
output = z.mean()
print(output, z)

tensor(18., grad_fn=<MeanBackward0>) tensor([[18., 18.],
        [18., 18.]], grad_fn=<MulBackward0>)


Backpropagation with `backward()`:

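Calling `backward()` without arguments works only for *scalar* outputs; for a non-scalar tensor, a gradient tensor of matching shape has to be passed in (as in the `y.backward(v)` example further down).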

In [59]:
# Backpropagate from the scalar `output`; the resulting gradient with respect
# to `x` becomes available as `x.grad`.
output.backward()

In [60]:
# Show d(output)/dx, which the backward pass stored on x.
grad_of_x = x.grad
print(grad_of_x)


tensor([[3., 3.],
        [3., 3.]])


The `requires_grad` flag can also be switched on after a tensor has been created, using the in-place method `requires_grad_()`:

In [61]:
# A tensor created without the flag is not tracked by autograd.
a = torch.rand(3, 3)
print(a)

# Plain arithmetic on an untracked tensor gives another untracked tensor.
numerator = a * 3
denominator = a - 1
a = numerator / denominator
print(a)

print(a.requires_grad)

# Flip the flag in place (the trailing underscore marks an in-place method).
a.requires_grad_(True)

print(a.requires_grad)

tensor([[0.8072, 0.0287, 0.5787],
        [0.3036, 0.8935, 0.7434],
        [0.8447, 0.7086, 0.5354]])
tensor([[-12.5575,  -0.0886,  -4.1207],
        [ -1.3079, -25.1821,  -8.6913],
        [-16.3122,  -7.2959,  -3.4566]])
False
True


In [62]:
# Start from a random leaf vector that autograd tracks.
x = torch.randn(3, requires_grad=True)
print(x)

y = x * 2
print(y)

# Keep doubling y until its vector norm reaches 1000.
# The norm is only used for the loop condition and for printing, so it is
# taken on a detached view: `.detach()` is the supported way to read values
# without recording the operation, whereas the legacy `.data` accessor
# bypasses autograd's correctness checks.
while y.detach().norm() < 1000:
    print(y.detach().norm())
    # This multiplication *is* recorded, so y stays connected to x.
    y = y * 2

print(y)


tensor([1.0309, 0.7556, 0.7971], requires_grad=True)
tensor([2.0618, 1.5112, 1.5941], grad_fn=<MulBackward0>)
tensor(3.0126)
tensor(6.0253)
tensor(12.0506)
tensor(24.1012)
tensor(48.2023)
tensor(96.4046)
tensor(192.8092)
tensor(385.6184)
tensor(771.2368)
tensor([1055.6305,  773.7512,  816.1973], grad_fn=<MulBackward0>)


In [63]:
# y is a vector, not a scalar, so backward() needs an explicit gradient
# argument with one weight per component of y.
v = torch.tensor([1e-1, 1.0, 1e-4], dtype=torch.float32)
y.backward(gradient=v)


In [64]:
# Display the gradient stored on x after the backward call on y.
x.grad

tensor([1.0240e+02, 1.0240e+03, 1.0240e-01])

To stop autograd from tracking operations, either call `.detach()` to get a tensor that is excluded from gradient tracking, or wrap the code block in `with torch.no_grad():`:

In [65]:
# By default both the leaf tensor and results computed from it are tracked.
print(x.requires_grad)
tracked_square = x ** 2
print(tracked_square.requires_grad)

# The same computation inside no_grad() produces an untracked result.
with torch.no_grad():
    untracked_square = x ** 2
    print(untracked_square.requires_grad)


True
True
False


In [66]:
# x itself is still tracked ...
print(x.requires_grad)

# ... while its detached counterpart is not,
y = x.detach()
print(y.requires_grad)

# even though both hold exactly the same values.
same_values = (x == y).all()
print(same_values)


True
False
tensor(True)
