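When training neural networks, the most commonly used algorithm is backpropagation. In this algorithm, the parameters (model weights) are adjusted according to the gradient of the loss function with respect to each parameter.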

In [1]:
import torch

x = torch.ones(5) # input tensor
y = torch.zeros(3) # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w) + b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

In [3]:
# References to the backward functions recorded for each tensor
print(f"Gradient function for z = {z.grad_fn}")
print(f"Gradient function for loss = {loss.grad_fn}")

Gradient function for z = <AddBackward0 object at 0x00000158561B3E80>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward0 object at 0x00000158561B34E0>


In [4]:
# Call loss.backward() to compute the gradients
loss.backward()
print(w.grad)
print(b.grad)

tensor([[0.1560, 0.2988, 0.2274],
        [0.1560, 0.2988, 0.2274],
        [0.1560, 0.2988, 0.2274],
        [0.1560, 0.2988, 0.2274],
        [0.1560, 0.2988, 0.2274]])
tensor([0.1560, 0.2988, 0.2274])
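
Every row of `w.grad` in the output above is identical and equal to `b.grad`. A minimal sketch of why, assuming the default `reduction="mean"` of `binary_cross_entropy_with_logits`: the gradient with respect to `z` is `(sigmoid(z) - y) / 3`, and because `x` is all ones, each row of `w` receives that same vector. The seed and the names `dz`, `manual_w_grad` and `manual_b_grad` are illustrative and not part of the original notebook.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

x = torch.ones(5)   # input tensor
y = torch.zeros(3)  # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)

z = torch.matmul(x, w) + b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
loss.backward()

# Hand-derived gradients, assuming the default reduction="mean":
# d(loss)/dz = (sigmoid(z) - y) / number_of_elements
with torch.no_grad():
    dz = (torch.sigmoid(z) - y) / y.numel()
    manual_b_grad = dz                  # z depends on b with coefficient 1
    manual_w_grad = torch.outer(x, dz)  # d(loss)/dw[i, j] = x[i] * dz[j]

print(torch.allclose(w.grad, manual_w_grad, atol=1e-6))
print(torch.allclose(b.grad, manual_b_grad, atol=1e-6))
```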


For performance reasons, a backward pass can be run only once on a given graph. If you need several `backward` calls on the same graph, pass `retain_graph=True` to `backward`.
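
The following sketch illustrates the difference. It uses a small hypothetical loss, `(w ** 2).sum()`, chosen because its backward pass needs saved intermediate values. The variable `second_call_failed` is only an illustration flag.

```python
import torch

w = torch.randn(3, requires_grad=True)

# Without retain_graph, the values saved for backward are freed after the first call.
loss = (w ** 2).sum()
loss.backward()
try:
    loss.backward()
    second_call_failed = False
except RuntimeError:
    second_call_failed = True
print(f"Second backward without retain_graph failed: {second_call_failed}")

# With retain_graph=True on the first call, the graph survives for a second call.
w.grad = None  # discard the gradient from the first experiment
loss = (w ** 2).sum()
loss.backward(retain_graph=True)
loss.backward()

# Gradients accumulate across calls: 2*w + 2*w = 4*w
print(torch.allclose(w.grad, 4 * w.detach()))
```

Keeping the graph alive costs memory, so `retain_graph=True` is usually reserved for cases where a second backward pass is really needed.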

In [5]:
# Disable gradient tracking
z = torch.matmul(x, w) + b
print(z.requires_grad)

with torch.no_grad():
    z = torch.matmul(x, w) + b

print(z.requires_grad)

True
False


In [6]:
# Alternative: use the detach() method
z = torch.matmul(x, w) + b
z_det = z.detach()
print(z_det.requires_grad)

False
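
As a small supplementary sketch, the checks below show what `detach()` changes and what it does not: the detached tensor has no `grad_fn`, but it still shares the same underlying memory as the original tensor. The setup repeats the tensors defined earlier so the snippet runs on its own.

```python
import torch

x = torch.ones(5)
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)

z = torch.matmul(x, w) + b
z_det = z.detach()

print(z.grad_fn is not None)              # z is still part of the graph
print(z_det.grad_fn is None)              # the detached tensor has no history
print(z_det.data_ptr() == z.data_ptr())   # both views use the same storage
```

Because the storage is shared, an in-place change to `z_det` is also visible through `z`. Call `clone()` on the detached tensor if an independent copy is needed.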


In [8]:
inp = torch.eye(5, requires_grad=True)
out = (inp+1).pow(2)
out.backward(torch.ones_like(inp), retain_graph=True)
print(f"First call\n{inp.grad}")
# PyTorch accumulates gradients on every backward pass
out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nSecond call\n{inp.grad}")
inp.grad.zero_()
out.backward(torch.ones_like(inp), retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")

First call
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.],
        [2., 2., 2., 2., 4.]])

Second call
tensor([[8., 4., 4., 4., 4.],
        [4., 8., 4., 4., 4.],
        [4., 4., 8., 4., 4.],
        [4., 4., 4., 8., 4.],
        [4., 4., 4., 4., 8.]])

Call after zeroing gradients
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.],
        [2., 2., 2., 2., 4.]])
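
Because gradients accumulate, a training loop usually zeroes them after each update. The sketch below combines the pieces shown so far (`backward()`, `torch.no_grad()` and `grad.zero_()`) into a hand-written gradient-descent loop on the earlier example. The learning rate, step count and seed are hypothetical values chosen for illustration. In practice an optimizer from `torch.optim` would normally handle the update.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

x = torch.ones(5)
y = torch.zeros(3)
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
learning_rate = 0.1  # hypothetical value

losses = []
for step in range(20):
    z = torch.matmul(x, w) + b
    loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
    losses.append(loss.item())

    loss.backward()  # fills w.grad and b.grad

    # The update itself must not be recorded in the graph.
    with torch.no_grad():
        w -= learning_rate * w.grad
        b -= learning_rate * b.grad

    # Reset the gradients so the next backward() does not add to stale values.
    w.grad.zero_()
    b.grad.zero_()

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```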


**For a scalar output such as a loss, calling backward() with no arguments is equivalent to backward(torch.tensor(1.0))**
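
A minimal check of this equivalence, using a hypothetical scalar loss `(w ** 2).sum()`. The last part shows that a non-scalar output cannot be passed to `backward()` without an explicit gradient argument, which is why the earlier cell passed `torch.ones_like(inp)`.

```python
import torch

w = torch.randn(4, requires_grad=True)

# Implicit gradient argument
loss = (w ** 2).sum()
loss.backward()
grad_implicit = w.grad.clone()

# Explicit gradient argument of 1.0
w.grad = None  # discard the accumulated gradient before the second run
loss = (w ** 2).sum()
loss.backward(torch.tensor(1.0))
grad_explicit = w.grad.clone()

print(torch.equal(grad_implicit, grad_explicit))

# A non-scalar output has no implicit gradient argument.
out = w ** 2
try:
    out.backward()
    needs_argument = False
except RuntimeError:
    needs_argument = True
print(f"Non-scalar backward() requires an argument: {needs_argument}")
```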