In [2]:
import numpy as np
import os, sys
import torch
sys.path.append('/Users/zhanglige/Desktop/JP-Lab/Code/Velocity_Flow_Matching/')
import dnnlib
from training.networks import ToyMLP
#set up device first
device_name = 'cpu' #can swap this to cuda:0, etc pending on resources
device = torch.device(device_name)
print(device)
mlp = ToyMLP(dim=16, time_varying=True, n_hidden=1, w=8)
mlp.train().to(device)
print(mlp)

cpu
ToyMLP(
  (net): Sequential(
    (0): Linear(in_features=17, out_features=8, bias=True)
    (1): Tanh()
    (2): Linear(in_features=8, out_features=8, bias=True)
    (3): Tanh()
    (4): Linear(in_features=8, out_features=16, bias=True)
  )
)


In [None]:
#first, set up inputs for net 
#am choosing small batch size to make computation faster 
batch_size = 3
flat_data_dim = 16
imgs = torch.randn(batch_size, flat_data_dim).type(torch.float32).to(device)
ts = torch.rand(batch_size, device=device) 
#ok now calc Jacobian of net_out w.r.t imgs input 
#set requires_grad to True for net inputs... 
ts.requires_grad=True
imgs.requires_grad=True


mlp_jac = torch.autograd.functional.jacobian(mlp, (imgs, ts))
#print(mlp_jac[0].shape)
torch_jac = torch.sum(mlp_jac[0], dim=2) #collapse over extra batch dim, see comment above # then it would be img:(B,D',D)
nabla_imgs = torch_jac.transpose(2,1) #transpose, to get gradient
print()
print("JVP vector shape: ",mlp(imgs, ts).unsqueeze(1).shape)
print("Jacobian Transpose Shape: ", nabla_imgs.shape)#[B,D,D']
imgs_jvp = torch.einsum('bij, bjk -> bik', mlp(imgs, ts).unsqueeze(1), nabla_imgs).squeeze(1) # (B,D)
print(imgs_jvp.shape)


JVP vector shape:  torch.Size([3, 1, 16])
Jacobian Transpose Shape:  torch.Size([3, 16, 16])
torch.Size([3, 16])


In [1]:
#######################################################
#                  1) PyTorch 部分
#######################################################
import numpy as np
import os, sys
import torch

# 这里是你原本的路径设置
sys.path.append('/Users/zhanglige/Desktop/JP-Lab/Code/Velocity_Flow_Matching/')
import dnnlib
from training.networks import ToyMLP

device_name = 'cpu'  # 可以根据需要切换 'cuda:0' 等
device = torch.device(device_name)
print("Current PyTorch device:", device)

# 构造网络 (注意：dim=16, time_varying=True => 输入其实是 16 + 1 = 17 维)
mlp = ToyMLP(dim=8, time_varying=True, n_hidden=1, w=4)
mlp.train().to(device)
print("ToyMLP Structure:\n", mlp)

# 准备输入
batch_size = 2
flat_data_dim = 8
imgs = torch.randn(batch_size, flat_data_dim, device=device, dtype=torch.float32)
ts = torch.rand(batch_size, device=device, dtype=torch.float32)
# 需要求梯度
imgs.requires_grad=True
ts.requires_grad=True

# ---------------- 计算 Jacobian + 处理 + JVP ----------------
mlp_jac = torch.autograd.functional.jacobian(mlp, (imgs, ts))
# mlp_jac[0] => wrt imgs, shape通常是 [batch, out_dim, batch, in_dim]
# 根据你的注释，你是要在 dim=2 处进行 sum
torch_jac = torch.sum(mlp_jac[0], dim=2)  # => shape [B, out_dim, in_dim]
nabla_imgs = torch_jac.transpose(2, 1)   # => shape [B, in_dim, out_dim]

# 这里构造一个要与 nabla_imgs 相乘的向量
# 你使用的是 net output: mlp(imgs, ts).unsqueeze(1) => shape [B,1,D']
out_torch = mlp(imgs, ts)  # => (B,16)
print("\nJVP vector shape: ", out_torch.shape)

# 类似做法: (B,1,D') x (B,D',D) -> (B,D)
# einsum: 'bij, bjk -> bik'
imgs_jvp = torch.einsum('bij,bjk->bik', out_torch.unsqueeze(1), nabla_imgs).squeeze(1)
print("PyTorch final JVP shape: ", imgs_jvp.shape)  # => [B,D]

#######################################################
#                  2) Taichi 部分
#######################################################
import taichi as ti

use_gpu = False  # 可以改成 True 在支持的环境下使用 GPU
if use_gpu:
    ti.init(arch=ti.gpu)
else:
    ti.init(arch=ti.cpu)

B = batch_size
dim_x = 8   # imgs 维度
dim_t = 1    # ts 维度
dim_in = dim_x + dim_t  # 17
dim_h1 = 4
dim_h2 = 4
dim_out = 8

# ------------------ Taichi 字段: MLP 参数 + 输入输出 ------------------
W1 = ti.Matrix.field(dim_h1, dim_in, dtype=ti.f32, shape=())
b1 = ti.Vector.field(dim_h1, dtype=ti.f32, shape=())

W2 = ti.Matrix.field(dim_h2, dim_h1, dtype=ti.f32, shape=())
b2 = ti.Vector.field(dim_h2, dtype=ti.f32, shape=())

W3 = ti.Matrix.field(dim_out, dim_h2, dtype=ti.f32, shape=())
b3 = ti.Vector.field(dim_out, dtype=ti.f32, shape=())

# 输入: imgs + ts 合并成 (B,17)
x_in = ti.Vector.field(dim_in, dtype=ti.f32, shape=(B,), needs_grad=True)
y_out = ti.Vector.field(dim_out, dtype=ti.f32, shape=(B,), needs_grad=True)

# 用于 JVP 的向量
v = ti.Vector.field(dim_out, dtype=ti.f32, shape=(B,))
# 用于存放 x_in.grad 的结果
JVP_result = ti.Vector.field(dim_in, dtype=ti.f32, shape=(B,))

# ------------------ Taichi Forward Kernel (与 ToyMLP 相同结构) ------------------
@ti.kernel
def forward_mlp():
    for b in range(B):
        # 读入 x_in[b], shape=17
        # layer1: 17->8
        z1 = W1[None] @ x_in[b] + b1[None]
        for i in range(dim_h1):
            z1[i] = ti.tanh(z1[i])
        # layer2: 8->8
        z2 = W2[None] @ z1 + b2[None]
        for i in range(dim_h2):
            z2[i] = ti.tanh(z2[i])
        # layer3: 8->16
        z3 = W3[None] @ z2 + b3[None]
        y_out[b] = z3

# ------------------ 辅助 Kernel ------------------
@ti.kernel
def clear_gradients():
    for b in range(B):
        x_in.grad[b].fill(0.0)
        y_out.grad[b].fill(0.0)

@ti.kernel
def set_v_grad():
    for b in range(B):
        y_out.grad[b] = v[b]

@ti.kernel
def copy_grad_to_result():
    for b in range(B):
        JVP_result[b] = x_in.grad[b]

# ------------------ 最终 JVP 计算 ------------------
def taichi_compute_jvp():
    clear_gradients()
    set_v_grad()   # 将 y_out 的梯度设为 v
    forward_mlp()  # 前向
    forward_mlp.grad()  # 反向
    copy_grad_to_result()
    return JVP_result.to_numpy()

#######################################################
#       3) 参数对齐 + 结果对比
#######################################################

# (1) 将 PyTorch 中的参数拷贝到 Taichi (保证同样的 weight/bias)
with torch.no_grad():
    # layer0
    W1_np = mlp.net[0].weight.detach().cpu().numpy()  # shape=(8,17)
    b1_np = mlp.net[0].bias.detach().cpu().numpy()    # shape=(8,)
    # layer1
    W2_np = mlp.net[2].weight.detach().cpu().numpy()  # shape=(8,8)
    b2_np = mlp.net[2].bias.detach().cpu().numpy()    # shape=(8,)
    # layer2
    W3_np = mlp.net[4].weight.detach().cpu().numpy()  # shape=(16,8)
    b3_np = mlp.net[4].bias.detach().cpu().numpy()    # shape=(16,)

# 赋值到 Taichi
W1[None] = W1_np
b1[None] = b1_np
W2[None] = W2_np
b2[None] = b2_np
W3[None] = W3_np
b3[None] = b3_np

# (2) 将 PyTorch imgs + ts 合并到 Taichi x_in
# imgs.shape=[B,16], ts.shape=[B]
imgs_np = imgs.detach().cpu().numpy()
ts_np   = ts.detach().cpu().numpy().reshape(-1, 1)  # shape=(B,1)
x_in_np = np.concatenate([imgs_np, ts_np], axis=1)  # => shape [B,17]
for i in range(B):
    x_in[i] = x_in_np[i]

# (3) 构造 v: analog of PyTorch => out_torch: (B,16)
out_np = out_torch.detach().cpu().numpy()
for i in range(B):
    v[i] = out_np[i]

# (4) Taichi 前向 + JVP
forward_mlp()  # 预热
taichi_jvp = taichi_compute_jvp()  # => shape [B,17]

# (5) 对比与 PyTorch 计算的 imgs_jvp
# PyTorch里 imgs_jvp.shape=[B,16]，而 Taichi 这里是 [B,17] (包含了对 ts 的梯度)
# 如果只关心对 imgs 的梯度, 取 taichi_jvp[:, :16]
taichi_jvp_imgs = taichi_jvp[:, :16]

print("\nTaichi's JVP w.r.t. x_in (imgs + ts) has shape:", taichi_jvp.shape)
print("  => w.r.t. imgs shape is", taichi_jvp_imgs.shape)

# 我们只对比 w.r.t. imgs 的 16 维
taichi_vs_pyt = taichi_jvp_imgs - imgs_jvp.detach().cpu().numpy()

max_diff = np.max(np.abs(taichi_vs_pyt))
print(f"Compare partial gradient w.r.t. 'imgs' => max abs diff: {max_diff:.3e}")
print("Taichi JVP sample0[:5]:", taichi_jvp_imgs[0,:5])
print("PyTorch JVP sample0[:5]:", imgs_jvp.detach().cpu().numpy()[0,:5])


Current PyTorch device: cpu
ToyMLP Structure:
 ToyMLP(
  (net): Sequential(
    (0): Linear(in_features=9, out_features=4, bias=True)
    (1): Tanh()
    (2): Linear(in_features=4, out_features=4, bias=True)
    (3): Tanh()
    (4): Linear(in_features=4, out_features=8, bias=True)
  )
)

JVP vector shape:  torch.Size([2, 8])
PyTorch final JVP shape:  torch.Size([2, 8])
[Taichi] version 1.7.2, llvm 15.0.5, commit 0131dce9, osx, python 3.9.21


[I 03/18/25 20:56:54.572 787052] [shell.py:_shell_pop_print@23] Graphical python shell detected, using wrapped sys.stdout


[Taichi] Starting on arch=x64


    x = ti.field(ti.f32, (4, 9)).
 See https://docs.taichi-lang.org/docs/field#matrix-size for more details.
  File "/Users/zhanglige/opt/anaconda3/envs/Flow/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/Users/zhanglige/opt/anaconda3/envs/Flow/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/Users/zhanglige/opt/anaconda3/envs/Flow/lib/python3.9/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/Users/zhanglige/opt/anaconda3/envs/Flow/lib/python3.9/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/Users/zhanglige/opt/anaconda3/envs/Flow/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
  File "/Users/zhanglige/opt/anaconda3/envs/Flow/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 205, in start
    self.asyncio_loop.run_foreve

: 