In [32]:
import torch

# backward: torch.autograd.backward(out) has the same effect as out.backward().
# It computes the gradients of the leaf tensors of the computation graph.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
out = (x ** 2 * 3).sum()  # reduce to a scalar before calling backward
torch.autograd.backward(out)
print(x.grad)

# grad: computes and returns the sum of the gradients of the outputs with
# respect to each of the inputs.
# grad_outputs can be read as the factor multiplied onto the derivative (the
# "vector" of the vector-Jacobian product). It has to be supplied here because
# y is not a scalar.

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x ** 2
grad_output = torch.ones_like(y)  # initial gradient: all ones

z = torch.autograd.grad(outputs=y, inputs=x, grad_outputs=grad_output)
print(z)

# Same torch.autograd.grad call, this time through a linear layer: y = line(x).
# NOTE(review): the layer is randomly initialised, so the printed values vary
# from run to run unless a seed was set in an earlier cell.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
line = torch.nn.Linear(in_features=3, out_features=2)
y = line(x)
grad_output = torch.ones_like(y)  # initial gradient: all ones

z = torch.autograd.grad(outputs=y, inputs=x, grad_outputs=grad_output)
print(z)

tensor([ 6., 12., 18.])
(tensor([2., 4., 6.]),)
tensor([-1.0743, -0.7277], grad_fn=<AddBackward0>)
(tensor([ 0.6039,  0.2139, -0.9551]),)


<font face='黑体' size=5>Context method mixins</font>


In [None]:
# ctx对象可以调用的方法

# ctx.mark_non_differentiable()
# Marks the parts of the forward output that do not need gradients.
class Func(Function):
    """Sort a tensor and return ``(sorted_values, indices)``.

    Demonstrates ``ctx.mark_non_differentiable``: the integer index output is
    flagged as non-differentiable, so only the sorted values carry gradient.
    """

    @staticmethod
    def forward(ctx, x):
        """Sort ``x`` and remember what backward needs.

        Args:
            ctx: autograd context object.
            x: tensor to sort; ``Tensor.sort()`` is called with its defaults,
                i.e. along the last dimension.

        Returns:
            Tuple ``(sorted_values, indices)``.
        """
        # Local name avoids shadowing the builtin ``sorted``.
        sorted_x, idx = x.sort()
        ctx.mark_non_differentiable(idx)
        ctx.save_for_backward(x, idx)
        return sorted_x, idx

    @staticmethod
    @once_differentiable
    def backward(ctx, g1, g2):  # still need to accept g2
        """Route the gradient of the sorted values back to the input positions.

        Args:
            ctx: autograd context object populated by ``forward``.
            g1: gradient w.r.t. the sorted values.
            g2: gradient slot for the non-differentiable indices; unused.

        Returns:
            Gradient w.r.t. ``x``, with the same shape as ``x``.
        """
        x, idx = ctx.saved_tensors
        grad_input = torch.zeros_like(x)
        # sorted_x[..., j] == x[..., idx[..., j]], so each incoming gradient is
        # scattered back along the sorted (last) dimension. Unlike
        # index_add_(0, idx, g1), which needs a 1-D index, this also works when
        # x has more than one dimension.
        grad_input.scatter_add_(-1, idx, g1)
        return grad_input
    
# ctx.save_for_backward()
# Stores the tensors that the backward pass will need.
# Non-tensor values can simply be kept as attributes on ctx.
# Inside backward they are retrieved through ctx.saved_tensors.
class Func(Function):
    """Custom Function computing ``x * y + y * z + (x * z) * y``.

    Tensors needed by backward are stored with ``ctx.save_for_backward``; the
    non-tensor argument ``z`` is stored as a plain attribute on ``ctx``.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
        """Compute the output and stash everything backward reads."""
        w = x * z
        out = x * y + y * z + w * y
        ctx.z = z  # z is not a tensor
        ctx.save_for_backward(x, y, w, out)
        return out

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        """Return the gradients for ``(x, y, z)``; ``z`` gets ``None``."""
        z = ctx.z
        # Unpacked in the order they were saved; the saved output is not used here.
        x, y, w, _ = ctx.saved_tensors
        d_out_dx = y + y * z
        d_out_dy = x + z + w
        # One return value per forward input; the non-tensor z has no gradient.
        return grad_out * d_out_dx, grad_out * d_out_dy, None
# Example call: two double-precision scalar leaf tensors plus a plain Python
# int, which is passed to Func.apply as the third argument.
a = torch.tensor(1.0, dtype=torch.double, requires_grad=True)
b = torch.tensor(2.0, dtype=torch.double, requires_grad=True)
c = 4
d = Func.apply(a, b, c)