In [17]:
import torch

def compute_id_stats(residual: torch.Tensor):
    """Per-channel mean and population std of ``residual`` grouped by ``n + h``.

    Every element ``residual[n, h, c]`` is assigned the integer id ``n + h``,
    which yields ``N + H - 1`` groups (ids ``0 .. N + H - 2``). For each id and
    each channel ``c`` the mean and the population standard deviation
    (divide by the group size, not size - 1) are returned.

    Args:
        residual: tensor of shape ``(N, H, C)``. It may be non-contiguous.

    Returns:
        ``(mean, std)``, each of shape ``(N + H - 1, C)``. Row ``k`` holds the
        statistics of all elements with ``n + h == k``.
    """
    N, H, C = residual.shape
    device = residual.device
    num_ids = N + H - 1

    # ids[n, h] = n + h, flattened in the same (n, h) row-major order as x_flat.
    ids = (
        torch.arange(N, device=device)[:, None]
        + torch.arange(H, device=device)[None, :]
    ).reshape(-1)  # [N*H]

    # reshape (not view) so non-contiguous inputs (permute/transpose) also work.
    x_flat = residual.reshape(-1, C)  # [N*H, C]

    # Per-id sums, accumulated row-wise along dim 0.
    sum_per_id = torch.zeros(
        (num_ids, C), dtype=residual.dtype, device=device
    ).index_add_(0, ids, x_flat)

    # Group sizes, shaped [num_ids, 1] so they broadcast over the channel dim.
    counts = (
        torch.bincount(ids, minlength=num_ids)
        .to(dtype=residual.dtype)
        .unsqueeze(-1)
    )

    mean = sum_per_id / counts

    # Two-pass variance: accumulate squared deviations from the group mean.
    # This avoids the cancellation of E[x^2] - E[x]^2, which can go negative
    # in floating point and turn sqrt into NaN.
    centered = x_flat - mean[ids]  # [N*H, C]
    sq_dev_per_id = torch.zeros_like(mean).index_add_(0, ids, centered * centered)
    std = torch.sqrt(sq_dev_per_id / counts)

    return mean, std

# Example usage: a 3 x 2 x 2 input whose channel 0 holds 1..6 and channel 1
# holds 7..12, laid out row-major over (n, h).
N, H, C = 3, 2, 2
X = torch.tensor(
    [
        [[1, 7], [2, 8]],
        [[3, 9], [4, 10]],
        [[5, 11], [6, 12]],
    ],
    dtype=torch.float32,
)
mean, std = compute_id_stats(X)
print("均值:", mean)
print("标准差:", std)

均值: tensor([[ 1.0000,  7.0000],
        [ 2.5000,  8.5000],
        [ 4.5000, 10.5000],
        [ 6.0000, 12.0000]])
标准差: tensor([[0.0000, 0.0000],
        [0.5000, 0.5000],
        [0.5000, 0.5000],
        [0.0000, 0.0000]])
