In [1]:
# A loss function measures the discrepancy between the labels and the predictions;
# the loss value is a scalar.
# autograd records the operations applied to tensors created with requires_grad=True
# and, when backward() is called on the loss, computes the gradient of the loss with
# respect to those tensors. (Using those gradients to update the parameters is the
# optimizer's job, not autograd's.)
import torch

# w and b below are randomly initialised; seeding makes Restart & Run All reproduce
# the same gradients and outputs every time.
torch.manual_seed(0)

x = torch.ones(5)  # input tensor
y = torch.zeros(3)  # expected output, i.e. the ground-truth labels
# requires_grad=True tells autograd to track operations on these tensors so that
# their gradients can be computed later.
w = torch.randn(5, 3, requires_grad=True)  # model weight parameter
b = torch.randn(3, requires_grad=True)  # bias parameter, optimised during training
z = torch.matmul(x, w) + b  # z = x @ w + b: the model's prediction (logits)
# Binary cross-entropy on logits; reduced to a scalar (default reduction='mean').
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

In [2]:
# grad_fn references the autograd node (backward function) that produced a tensor:
# it is how autograd remembers how each non-leaf tensor was computed.
grad_fn_report = (
    f"Gradient function for z = {z.grad_fn}\n"
    f"Gradient function for loss = {loss.grad_fn}\n"
)
print(grad_fn_report)

Gradient function for z = <AddBackward0 object at 0x0000020FD61735C8>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward0 object at 0x0000020FD6173448>



In [3]:
# Backpropagate: compute the derivative of the loss with respect to each tensor
# created with requires_grad=True (here w and b) and store it in its .grad.
# For example, if loss = w * b, then w.grad would be b.
# Without retain_graph=True the graph is released after this call, so calling
# loss.backward() a second time fails; re-run the setup cell first.
loss.backward()
for param in (w, b):
    print(param.grad)

tensor([[0.0931, 0.1116, 0.1795],
        [0.0931, 0.1116, 0.1795],
        [0.0931, 0.1116, 0.1795],
        [0.0931, 0.1116, 0.1795],
        [0.0931, 0.1116, 0.1795]])
tensor([0.0931, 0.1116, 0.1795])


In [5]:
# Disabling gradient tracking. Typical reasons:
#   1. Marking some parameters of a network as frozen -- a common scenario when
#      fine-tuning a pretrained network.
#   2. Speeding up computation when only a forward pass is needed.

# Method 1: run the computation inside a torch.no_grad() context.
z = x @ w + b
print(z.requires_grad)  # w and b require grad, so the result is tracked

with torch.no_grad():
    z = x @ w + b
print(z.requires_grad)  # computed under no_grad: not tracked

# Method 2: detach an already-computed tensor from the graph.
z = x @ w + b
z_det = z.detach()
print(z_det.requires_grad)  # the detached copy is not tracked

True
False
False


In [10]:
# Gradient accumulation demo.
# out = ((inp + 1) ** 2)^T is not a scalar, so backward() is given an explicit
# upstream gradient of ones (same shape as out); inp.grad is then 2 * (inp + 1).
# retain_graph=True keeps the graph alive so backward() can be called repeatedly.
# Each backward() ADDS into inp.grad, so the second call shows 2 * (inp + 1) * 2;
# zeroing inp.grad in place resets it before the third call.
inp = torch.eye(4, 5, requires_grad=True)
out = (inp + 1).pow(2).t()
for heading, reset_first in (
    ("First call", False),
    ("\nSecond call", False),
    ("\nCall after zeroing gradients", True),
):
    if reset_first:
        inp.grad.zero_()
    out.backward(torch.ones_like(out), retain_graph=True)
    print(f"{heading}\n{inp.grad}")

First call
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.]])

Second call
tensor([[8., 4., 4., 4., 4.],
        [4., 8., 4., 4., 4.],
        [4., 4., 8., 4., 4.],
        [4., 4., 4., 8., 4.]])

Call after zeroing gradients
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.]])
