In [1]:
import torch

In [4]:
# Inputs and targets are plain constants; they do not track gradients.
x = torch.ones(5)
y = torch.zeros(3)
# To differentiate the loss with respect to a tensor, its requires_grad flag
# must be True. It can be set at creation time (as for w) or afterwards with
# the in-place method requires_grad_(True) (as for b).
# No seed is set, so the random values below differ from run to run.
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3).requires_grad_(True)
# Affine map of x (the logits), then binary cross-entropy against the targets.
forward = x @ w + b
loss = torch.nn.functional.binary_cross_entropy_with_logits(input=forward, target=y)
print(loss)

tensor(3.7261, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


In [5]:
# Show the backward-function objects that autograd recorded for each result.
# Note: the label "z" in the first message refers to the tensor `forward`.
print(
    f"Gradient function for z = {forward.grad_fn}",
    f"Gradient function for loss = {loss.grad_fn}",
    sep="\n",
)

Gradient function for z = <AddBackward0 object at 0x000001362D4337F0>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward0 object at 0x000001362D432920>


In [6]:
# Backpropagate from the scalar loss; this fills in .grad on the leaf tensors
# that require gradients (w and b).
loss.backward()
# One gradient per line: first d(loss)/dw, then d(loss)/db.
print(w.grad, b.grad, sep="\n")

tensor([[0.2497, 0.1637, 0.3333],
        [0.2497, 0.1637, 0.3333],
        [0.2497, 0.1637, 0.3333],
        [0.2497, 0.1637, 0.3333],
        [0.2497, 0.1637, 0.3333]])
tensor([0.2497, 0.1637, 0.3333])


In [7]:
# Disabling gradient tracking is useful for treating some parameters of a
# network as frozen, and for speeding up computation when only the forward
# pass is needed.

# Computed normally, the result is tracked by autograd.
z = x @ w + b
print(z.requires_grad)

# The same expression evaluated inside no_grad() is not tracked.
with torch.no_grad():
    z = x @ w + b
print(z.requires_grad)

True
False


In [8]:
# detach() returns a new tensor, z_det, that is cut off from the computation
# graph (it shares data with z rather than copying it), so it no longer
# requires gradients.
z = x @ w + b
z_det = z.detach()
print(z_det.requires_grad)

False


In [9]:
# Gradient accumulation: every backward() call adds into .grad instead of
# overwriting it, so .grad must be zeroed first to obtain the correct gradient.
inp = torch.eye(4, 5, requires_grad=True)
out = ((inp + 1) ** 2).T
# out is not a scalar, so backward() is given an explicit upstream gradient.
upstream = torch.ones_like(out)

# retain_graph=True keeps the graph so backward() can be called again on out.
out.backward(upstream, retain_graph=True)
print(f"First call\n{inp.grad}")

# Second call without zeroing: the printed gradient is doubled.
out.backward(upstream, retain_graph=True)
print(f"\nSecond call\n{inp.grad}")

# Reset the accumulated gradient in place, then backpropagate once more.
inp.grad.zero_()
out.backward(upstream, retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")

First call
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.]])

Second call
tensor([[8., 4., 4., 4., 4.],
        [4., 8., 4., 4., 4.],
        [4., 4., 8., 4., 4.],
        [4., 4., 4., 8., 4.]])

Call after zeroing gradients
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.]])
