# check the operation separately without the forward and backward pass

In [1]:
import numpy as np
import warp as wp
import torch
import matplotlib.pyplot as plt
import os
import os.path as osp

In [2]:
from datetime import datetime
today = datetime.today().strftime('%Y%m%d')
log_dir = f"logs/{today}"
os.makedirs(log_dir, exist_ok=True)

data_path = "/data/ruihan/projects/PhysDreamer/physdreamer/warp_mpm/logs/20250321"
device = "cuda:0"

In [3]:
# load the grid_v_in data from torch and warp respectively
grid_v_in_torch = np.load(osp.join(data_path, "grid_v_in_torch.npz"))["grid_v_in"]
grid_v_in_torch = torch.from_numpy(grid_v_in_torch).to(device)
print(f"check grid_v_in_torch shape {grid_v_in_torch.shape}, dtype {grid_v_in_torch.dtype}")


check grid_v_in_torch shape torch.Size([5, 5, 5, 3]), dtype torch.float64


In [25]:
wp.init()

@wp.kernel
def sum_grid_v_in(
    grid_v_in: wp.array(dtype=wp.vec3d),
    loss: wp.array(dtype=wp.float64)  
):
    """
    Compute how grid velocity `grid_v_in` contributes to the loss gradient.
    """

    i, j, k = wp.tid()  # Thread index for grid
    


    # Fetch the current grid velocity
    grid_v_tensor = grid_v_in[i, j, k]

    # Accumulate the loss gradient contribution
    print(i)
    wp.atomic_add(loss, 0, grid_v_tensor[0])
    wp.atomic_add(loss, 0, grid_v_tensor[1])
    wp.atomic_add(loss, 0, grid_v_tensor[2])

In [17]:
grid_v_in_warp = np.load(osp.join(data_path, "grid_v_in_wp.npz"))["grid_v_in"]
print(f"shape {grid_v_in_warp.shape}")
grid_size = (grid_v_in_warp.shape[0], grid_v_in_warp.shape[1], grid_v_in_warp.shape[2])
print(f"grid_size {grid_size}")
grid_v_in_warp = wp.from_torch(torch.from_numpy(grid_v_in_warp).to(device))
print(f"check grid_v_in_warp, dtype {grid_v_in_warp.dtype}, type {type(grid_v_in_warp)}")

shape (5, 5, 5, 3)
grid_size (5, 5, 5)
check grid_v_in_warp, dtype <class 'warp.types.float64'>, type <class 'warp.types.array'>


In [29]:
# compute torch sum
sum_torch = torch.sum(grid_v_in_torch)
# compute warp sum
sum_wp = wp.zeros(dtype=wp.float64)

tape = wp.Tape()
with tape:

    wp.launch(
        kernel=sum_grid_v_in,
        dim=grid_size,
        inputs=[grid_v_in_warp, sum_wp],
        device=device
    ) 
print(f"sum_torch {sum_torch}, sum_wp {sum_wp}")

sum_torch 4.263256414560601e-14, sum_wp []
