In [1]:
import taichi as ti
import numpy as np
import torch

ti.init(arch=ti.cpu)

# ---------------- 设置维度 ----------------
B = 2   # batch size
n = 8   # 输入维度
m = 4   # 输出维度

# ---------------- 定义 Taichi 变量 ----------------
W = ti.Matrix.field(m, n, dtype=ti.f32, shape=())
bias = ti.Vector.field(m, dtype=ti.f32, shape=())

# x[b] : b号样本的输入向量 (size=n)
x = ti.Vector.field(n, dtype=ti.f32, shape=(B,), needs_grad=True)
# y[b] : b号样本的输出向量 (size=m)
y = ti.Vector.field(m, dtype=ti.f32, shape=(B,), needs_grad=True)

# v[b] : 对应每个样本的向量 (size=m), 用于 J^T v 的乘法
v = ti.Vector.field(m, dtype=ti.f32, shape=(B,))

# 对应每个样本的结果 (J^T v)[b] (size=n)
JVP_result = ti.Vector.field(n, dtype=ti.f32, shape=(B,))

# ---------------- Kernel：前向传播 (处理 batch 中所有样本) ----------------
@ti.kernel
def compute_y_batch():
    for b_ in range(B):
        y[b_] = ti.tanh(W[None] @ x[b_] + bias[None])

# ---------------- Kernel：累加第 i 个分量的贡献到 JVP_result ----------------
@ti.kernel
def accumulate_jvp(i: ti.i32):
    for b_ in range(B):
        for j in ti.static(range(n)):
            JVP_result[b_][j] += x.grad[b_][j] * v[b_][i]

# ---------------- (1) 随机初始化数据 ----------------
torch.manual_seed(42)  # for reproducibility

W_np = torch.randn(m, n, dtype=torch.float32).numpy()
bias_np = torch.randn(m, dtype=torch.float32).numpy()
x_np = torch.randn(B, n, dtype=torch.float32).numpy()
v_np = torch.randn(B, m, dtype=torch.float32).numpy()

# 将 NumPy 数组赋值给 Taichi
W[None] = W_np
bias[None] = bias_np
for b_ in range(B):
    x[b_] = x_np[b_]
    v[b_] = v_np[b_]
    JVP_result[b_] = ti.Vector(np.zeros(n, dtype=np.float32))

# ---------------- (2) 在 Taichi 中计算 J^T v ----------------
for i in range(m):
    # 重置梯度
    for b_ in range(B):
        x.grad[b_] = ti.Vector(np.zeros(n, dtype=np.float32))
        y.grad[b_] = ti.Vector(np.zeros(m, dtype=np.float32))

    # 对输出的第 i 个分量打上 grad=1 (每个样本都打 1)
    for b_ in range(B):
        y.grad[b_][i] = 1.0

    # 前向 + 反向
    compute_y_batch()
    compute_y_batch.grad()

    # 累加到 JVP_result
    accumulate_jvp(i)

# 拉取 Taichi 结果到 NumPy
taichi_jtv = np.vstack([JVP_result[b_].to_numpy() for b_ in range(B)])  # shape (B, n)

# -------------------------------------------------------------
#       用 PyTorch 的 torch.autograd.functional.jacobian
#       构造 full Jacobian，然后与 v 做乘法
# -------------------------------------------------------------
W_torch = torch.tensor(W_np, dtype=torch.float32)   # shape (m, n)
b_torch = torch.tensor(bias_np, dtype=torch.float32)# shape (m,)
x_torch = torch.tensor(x_np, dtype=torch.float32)   # shape (B, n), no grad needed for jacobian call
v_torch = torch.tensor(v_np, dtype=torch.float32)   # shape (B, m)

def forward_fn(x_):
    """x_.shape = (B, n). Returns (B, m)."""
    return torch.tanh(x_ @ W_torch.T + b_torch)

# (3) 计算 full Jacobian: shape = (B, m, B, n)
jac = torch.autograd.functional.jacobian(forward_fn, x_torch, create_graph=False)
# 正确提取每个样本的 Jacobian 矩阵
# 原始 jac 形状为 (B, m, B, n)，其中 jac[b, :, b, :] 对应样本 b 的 Jacobian
jac_transposed = jac.transpose(1, 2)  # 形状变为 (B, B, m, n)
jac_transposed = jac_transposed.diagonal(dim1=0, dim2=1).movedim(-1, 0)  # 形状 (B, m, n)
#print(jac_transposed.shape,v_torch.shape)
# 正确计算 J^T v
torch_jtv_jacobian = torch.einsum('bmn,bm->bn', jac_transposed, v_torch).numpy()


# (4) 做对比
print("\n============================")
print("Taichi (J^T v) vs. PyTorch (full Jacobian) (J^T v)")
print("============================\n")
for b_ in range(B):
    print(f"Sample {b_}:")
    print("  Taichi :  ", taichi_jtv[b_])
    print("  PyTorch:  ", torch_jtv_jacobian[b_])
    err = np.max(np.abs(taichi_jtv[b_] - torch_jtv_jacobian[b_]))
    print(f"  Max error = {err:e}")


[Taichi] version 1.7.2, llvm 15.0.5, commit 0131dce9, osx, python 3.9.21


[I 03/17/25 16:07:16.732 205289] [shell.py:_shell_pop_print@23] Graphical python shell detected, using wrapped sys.stdout


[Taichi] Starting on arch=x64

Taichi (J^T v) vs. PyTorch (full Jacobian) (J^T v)

Sample 0:
  Taichi :   [ 0.40772447  0.85730678  0.18844506  0.31153518 -0.23295417 -0.08521497
 -0.24905279  0.5076037 ]
  PyTorch:   [ 0.40772447  0.8573068   0.18844506  0.31153518 -0.23295417 -0.08521497
 -0.2490528   0.5076037 ]
  Max error = 0.000000e+00
Sample 1:
  Taichi :   [-1.31211734  0.4779858  -0.65214491 -1.96736133 -0.38303438 -0.48829293
 -0.42431751 -0.01537222]
  PyTorch:   [-1.3121173   0.4779858  -0.6521449  -1.9673613  -0.38303438 -0.48829293
 -0.4243175  -0.01537222]
  Max error = 0.000000e+00


In [2]:
import time

# ---------------- (5) 加入速度比较 ----------------
# Taichi 计算时间
start_time = time.time()
for i in range(m):
    # 重置梯度
    for b_ in range(B):
        x.grad[b_] = ti.Vector(np.zeros(n, dtype=np.float32))
        y.grad[b_] = ti.Vector(np.zeros(m, dtype=np.float32))

    # 对输出的第 i 个分量打上 grad=1 (每个样本都打 1)
    for b_ in range(B):
        y.grad[b_][i] = 1.0

    # 前向 + 反向
    compute_y_batch()
    compute_y_batch.grad()

    # 累加到 JVP_result
    accumulate_jvp(i)
taichi_time = time.time() - start_time

# PyTorch 计算时间
start_time = time.time()
jac = torch.autograd.functional.jacobian(forward_fn, x_torch, create_graph=False)
jac_transposed = jac.transpose(1, 2)  # 形状变为 (B, B, m, n)
jac_transposed = jac_transposed.diagonal(dim1=0, dim2=1).movedim(-1, 0)  # 形状 (B, m, n)
torch_jtv_jacobian = torch.einsum('bmn,bm->bn', jac_transposed, v_torch).numpy()
torch_time = time.time() - start_time

# ---------------- (6) 输出速度比较结果 ----------------
print("\n============================")
print("速度比较")
print("============================")
print(f"Taichi 计算时间: {taichi_time * 1000:.2f} ms")
print(f"PyTorch 计算时间: {torch_time * 1000:.2f} ms")
print(f"Taichi 比 PyTorch 快 {torch_time / taichi_time:.2f} 倍")


速度比较
Taichi 计算时间: 1.88 ms
PyTorch 计算时间: 2.86 ms
Taichi 比 PyTorch 快 1.53 倍
