In [9]:
import torch
from einops import rearrange

In [14]:
def get_point_cloud_from_depth_torch(depth, intrinsic, depth_scalar=1):
    *_, channels, height, width = depth.shape

    if channels == 1 and width != 1:
        depth = rearrange(depth, '... c h w -> ... h w c')

    device = depth.device
    px, py = float(intrinsic[0, 2]), float(intrinsic[1, 2])
    fx, fy = float(intrinsic[0, 0]), float(intrinsic[1, 1])

    stacked_p = torch.tensor([[px, py],], dtype=torch.float32, device=device).unsqueeze(0)
    stacked_f = torch.tensor([[fx, fy],], dtype=torch.float32, device=device).unsqueeze(0)
    
    coordinates = torch.stack(torch.meshgrid(torch.arange(height, device=device), torch.arange(width, device=device)), dim=-1)
    
    points = (((coordinates - stacked_p) * depth) / stacked_f)
    points = torch.concat([points, depth], dim=-1)
    points = rearrange(points, '... h w c -> ... (h w) c')

    return points

depth = torch.randn(512, 1, 1, 128, 128)
intrinsic = torch.randn(3,3)

get_point_cloud_from_depth_torch(depth, intrinsic)  # Call the function

tensor([12, 15])


tensor([[[[-1.0961e+00, -6.9444e-01,  9.7079e-01],
          [-3.3646e+00, -3.9396e+00,  2.9800e+00],
          [-1.4528e-01, -2.4817e-01,  1.2867e-01],
          ...,
          [-1.7149e+02, -5.9997e+01,  7.8377e-01],
          [ 8.0618e+01,  2.8429e+01, -3.6845e-01],
          [-1.0009e+01, -3.5572e+00,  4.5744e-02]]],


        [[[-3.3815e-01, -2.1424e-01,  2.9949e-01],
          [-2.8346e+00, -3.3191e+00,  2.5106e+00],
          [ 1.7915e-01,  3.0602e-01, -1.5867e-01],
          ...,
          [-5.5515e+01, -1.9422e+01,  2.5372e-01],
          [ 1.3516e+02,  4.7660e+01, -6.1771e-01],
          [ 1.6319e+02,  5.7999e+01, -7.4584e-01]]],


        [[[ 6.2079e-01,  3.9331e-01, -5.4983e-01],
          [-3.7141e+00, -4.3489e+00,  3.2896e+00],
          [-4.2219e-01, -7.2120e-01,  3.7393e-01],
          ...,
          [ 6.5385e+01,  2.2876e+01, -2.9883e-01],
          [ 3.6361e+01,  1.2822e+01, -1.6618e-01],
          [-3.0139e+02, -1.0712e+02,  1.3775e+00]]],


        ...,


        [[