In [4]:
import taichi as ti
import torch
from typing import *

ti.init(arch=ti.gpu)  # 使用GPU进行加速

# 定义一个大的预定义大小
MAX_DAYS = 1000
MAX_STOCKS = 1000

# 定义数据字段
x_field = ti.field(dtype=ti.f32, shape=(MAX_DAYS, MAX_STOCKS))
n_field = ti.field(dtype=ti.i32, shape=(MAX_DAYS,))
mask_field = ti.field(dtype=ti.i32, shape=(MAX_DAYS, MAX_STOCKS))  # 使用整数代替布尔值
mean_field = ti.field(dtype=ti.f32, shape=(MAX_DAYS,))
std_field = ti.field(dtype=ti.f32, shape=(MAX_DAYS,))


@ti.kernel
def compute_mean_std(num_days: int, num_stocks: int):
    for i in range(num_days):
        sum_x = 0.0
        sum_x2 = 0.0
        for j in range(num_stocks):
            if not mask_field[i, j]:
                sum_x += x_field[i, j]
                sum_x2 += x_field[i, j] ** 2
        mean_val = sum_x / n_field[i]
        mean_field[i] = mean_val
        std_field[i] = ti.sqrt(sum_x2 / n_field[i] - mean_val**2)


def masked_mean_std_taichi(
    x: torch.Tensor,
    n: Optional[torch.Tensor] = None,
    mask: Optional[torch.Tensor] = None,
):
    num_days, num_stocks = x.shape
    if mask is None:
        mask = torch.isnan(x)
    if n is None:
        n = (~mask).sum(dim=1)

    # 将数据复制到taichi字段中
    x_field[:num_days, :num_stocks] = x.numpy()
    n_field[:num_days] = n.numpy()
    mask_field[:num_days, :num_stocks] = mask.numpy().astype(int)

    compute_mean_std(num_days, num_stocks)

    # 从taichi字段中获取结果
    mean = mean_field[:num_days].to_numpy()
    std = std_field[:num_days].to_numpy()

    return torch.tensor(mean), torch.tensor(std)


# 测试
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mean, std = masked_mean_std_taichi(x)
print(mean, std)

[Taichi] Starting on arch=cuda


TypeError: write_float(): incompatible function arguments. The following argument types are supported:
    1. (self: taichi._lib.core.taichi_python.SNode, arg0: List[int], arg1: float) -> None

Invoked with: <taichi._lib.core.taichi_python.SNode object at 0x7f9358bbb6f0>, (slice(None, 2, None), slice(None, 3, None), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), array([[1., 2., 3.],
       [4., 5., 6.]], dtype=float32)