In [19]:
import torch

def f(x):
    return torch.stack([
        torch.abs(x[0] ** 2 + torch.sin(x[1]) - 5 * x[2] + 3).pow(0.3),
        torch.abs(x[0] ** 1.6 + torch.cos(x[1]) - 8 * x[2] + 9).pow(0.7),
        torch.abs(x[0] ** 1.8 + torch.tan(x[1]) - 4 * x[2] + 6).pow(0.4)
    ])
    # return 2 * x + 1

z0 = torch.rand((3, ), requires_grad=True)
z1 = torch.rand((3, ), requires_grad=True)

z2 = z0 + f(z1)
print(z2)
z3 = z1 + f(z2)
print(z3)
z4 = z2 + f(z3)
print(z4)

loss = torch.sum(z4)
# loss = torch.sum((z4 - 5) ** 3) + 1
print(loss)

# 直接计算dz1
dz1 = torch.autograd.grad(loss, z1, retain_graph=True)[0]
print("Direct dz1:", dz1)

dz3 = torch.autograd.grad(loss, z3, retain_graph=True)[0]
print("Direct dz3:", dz3)

loss_grad_z4 = torch.autograd.grad(loss, z4, retain_graph=True, create_graph=True)[0]
print("loss_grad_z4: ", loss_grad_z4)
loss_grad_z3 = torch.autograd.grad(loss, z3, retain_graph=True, create_graph=True)[0]
print("loss_grad_z3: ", loss_grad_z3)
loss_grad_z2 = torch.autograd.grad(loss, z2, retain_graph=True, create_graph=True)[0]
print("loss_grad_z2: ", loss_grad_z2)

z4_grad_z3 = torch.autograd.functional.jacobian(lambda x: z2 + f(x), z3)
print("z3_grad_z1: ", z4_grad_z3)
z3_grad_z1 = torch.autograd.functional.jacobian(lambda x: x + f(z2), z1)
print("z3_grad_z1: ", z3_grad_z1)
z2_grad_z1 = torch.autograd.functional.jacobian(lambda x: z0 + f(x), z1)
print("z2_grad_z1: ", z2_grad_z1)
z3_grad_z2 = torch.autograd.functional.jacobian(lambda x: z1 + f(x), z2)
print("z3_grad_z2: ", z3_grad_z2)

# dz1_rec = loss_grad_z3 @ z3_grad_z1 + loss_grad_z2 @ z2_grad_z1  
dz1_rec = torch.autograd.functional.vjp(lambda x: x, inputs=z1, v=loss_grad_z3)[1] +\
                 torch.autograd.functional.vjp(lambda x: f(x), inputs=z1, v=loss_grad_z2)[1]

print("Reconstructed dz1:", dz1_rec)

dz3_rec = torch.autograd.functional.vjp(func=lambda x: f(x), inputs=z3, v=loss_grad_z4)[1]
print("Reconstructed dz3:", dz3_rec)


tensor([0.7175, 3.6058, 2.5113], grad_fn=<AddBackward0>)
tensor([2.4393, 5.5836, 2.1795], grad_fn=<AddBackward0>)
tensor([2.0482, 6.0121, 3.6615], grad_fn=<AddBackward0>)
tensor(11.7217, grad_fn=<SumBackward0>)
Direct dz1: tensor([-0.0930,  0.8862, -0.0346])
Direct dz3: tensor([-0.8729,  0.1267,  3.3168])
loss_grad_z4:  tensor([1., 1., 1.])
loss_grad_z3:  tensor([-0.8729,  0.1267,  3.3168], grad_fn=<AddBackward0>)
loss_grad_z2:  tensor([0.0734, 0.0731, 3.8192], grad_fn=<AddBackward0>)
z3_grad_z1:  tensor([[-0.7515, -0.1179,  0.7702],
        [-1.3126, -0.3094,  3.8437],
        [ 1.1913,  0.5539, -1.2972]])
z3_grad_z1:  tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])
z2_grad_z1:  tensor([[ 0.9736,  1.0203, -5.1229],
        [ 0.4353, -0.0388, -3.4014],
        [ 0.1771,  0.1800, -0.7139]])
z3_grad_z2:  tensor([[-0.0891,  0.0555,  0.3105],
        [-0.4423, -0.1510,  2.6987],
        [-0.2859, -0.2591,  0.8286]])
Reconstructed dz1: tensor([-0.0930,  0.8862, -0.0346])
