In [1]:
# ref
# https://pytorch.apachecn.org/2.0/tutorials/intermediate/jacobians_hessians/#google_vignette
# https://blog.csdn.net/apache/article/details/113925886
# https://github.com/Haskely/pytorch-jacobian
# https://bobondemon.github.io/2024/02/07/高效率計算-Jacobian-Hessian-VJP-JVP-HVP/

In [2]:
from functools import partial

import torch
import torch.nn.functional as F
from torch.func import jacrev, vjp, vmap

torch.manual_seed(0)

<torch._C.Generator at 0x7b2ff199cb30>

In [3]:
def predict(weight, bias, x):
    return F.linear(x, weight, bias).tanh()

# 单批量

In [4]:
D = 16

weight = torch.randn(D, D)
bias = torch.randn(D)
x = torch.randn(D)  # feature vector

ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
predict(weight, bias, x).shape, ft_jacobian.shape

(torch.Size([16]), torch.Size([16, 16]))

# 多批量

In [5]:
batch_size = 10
Din = 31
Dout = 33

weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(batch_size, Din)  # feature vector

compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))
batch_jacobian = compute_batch_jacobian(weight, bias, x)
predict(weight, bias, x).shape, batch_jacobian.shape

(torch.Size([10, 33]), torch.Size([10, 33, 31]))

# 多批量，多输入

In [6]:
def predict(x, x0):
    print(x.shape)
    print(x0.shape)
    out = x**3 - x0
    return x**3


batch_size = 10
Din = 31
Dout = 33

weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(batch_size, Din)  # feature vector
x0 = torch.randn(batch_size, Din) * 0 + 1

compute_batch_jacobian = vmap(jacrev(predict, argnums=0), in_dims=(0, 0))
batch_jacobian = compute_batch_jacobian(x, x0)

print(batch_jacobian.shape)

torch.Size([31])
torch.Size([31])
torch.Size([10, 31, 31])


In [13]:
def predict(x, x0):
    print(x.shape)
    print(x0.shape)
    eq1 = torch.zeros(31)
    out[:14] = x[:14]
    return x**3


batch_size = 10
Din = 31
Dout = 33

weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(batch_size, Din)  # feature vector
x0 = torch.randn(batch_size, Din) * 0 + 1

# predict_2 = partial(predict, x0=x0)
compute_batch_jacobian = vmap(jacrev(predict, argnums=0), in_dims=(0, 0))
batch_jacobian = compute_batch_jacobian(x, x0)

print(batch_jacobian.shape)

torch.Size([31])
torch.Size([31])


RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over in a vmap. Please try to use out-of-place operators instead of inplace arithmetic. If said operator is being called inside the PyTorch framework, please file a bug report instead.