In [3]:
import torch
def scatter_nd_pytorch(indices, updates, shape):
    # Ensure indices are long tensor for indexing
    indices = indices.long()
    
    # Create a zero tensor of the specified shape
    result = torch.zeros(shape, dtype=updates.dtype, device=updates.device)
    
    # Unpack the dimensions for scattering
    dim_length = indices.shape[-1]
    idx_expanded = tuple(indices[:, :, i] for i in range(dim_length))
    
    # Scatter the updates to the result tensor
    result[idx_expanded] = updates
    return result


# Example usage with your provided shapes and sizes
N_rays = 8192
N_samples = 42
N_obj = 1
N_samples_obj = 6

# Dummy data for demonstration
id_z_vals_bckg = torch.randint(0, N_samples + N_obj*N_samples_obj, (N_rays, 6, 2))
raw_bckg = torch.randn(N_rays, 6, 4)

# Shape of the raw tensor
raw_sh = [N_rays, N_samples + N_obj*N_samples_obj, 4]

# Create the raw tensor
raw = torch.zeros(raw_sh)

# Perform the scatter operation
raw += scatter_nd_pytorch(id_z_vals_bckg, raw_bckg, raw_sh)

In [7]:
raw

tensor([[[-0.4114,  0.4762,  0.2824, -0.4915],
         [-0.9814, -0.0852,  0.5516,  0.0407],
         [ 1.6542,  0.2371, -0.9970,  0.5172],
         ...,
         [ 0.3117,  0.4094, -1.5334, -0.4685],
         [-1.5857, -0.1623,  1.4708, -0.6366],
         [ 0.5398, -0.5206, -0.7659,  0.2284]],

        [[ 1.5041,  0.5396, -0.3422, -0.8521],
         [-0.2654,  0.7120, -0.6745,  1.3746],
         [ 1.1241,  1.4274, -0.2997,  0.0702],
         ...,
         [-1.6938, -0.3264, -0.8954, -0.5679],
         [-0.2659, -0.7734, -0.6231,  0.4535],
         [ 0.3490,  0.8625,  0.3576,  1.1135]],

        [[ 1.8997,  0.5756, -1.0827, -0.6513],
         [ 1.0553, -0.2906, -1.2896, -0.4517],
         [-1.1076,  0.6554, -1.4548,  0.4433],
         ...,
         [ 0.8763, -1.4792, -0.6201, -0.8161],
         [ 1.4393, -1.1915,  0.3240,  1.8210],
         [ 1.5799,  0.7605,  0.0423,  0.0083]],

        ...,

        [[ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]

: 

In [6]:
id_z_vals_bckg

tensor([[[31, 32],
         [ 1,  4],
         [11,  0],
         [43, 11],
         [ 7,  4],
         [14, 34]],

        [[13, 37],
         [13, 25],
         [ 1,  5],
         [43, 29],
         [13, 28],
         [ 9, 11]],

        [[44, 25],
         [13, 23],
         [31, 43],
         [44,  9],
         [43, 44],
         [10,  8]],

        ...,

        [[10, 21],
         [ 0,  3],
         [35, 18],
         [43, 39],
         [ 2, 27],
         [31,  3]],

        [[45, 45],
         [21, 11],
         [42, 14],
         [ 1,  4],
         [18,  1],
         [30,  9]],

        [[16, 34],
         [35, 22],
         [29, 22],
         [13, 10],
         [32, 29],
         [42, 36]]])