In [47]:
import torch
def get_reference_points(H, W, Z=8, num_points_in_pillar=1, dim='3d', bs=1, device='cpu', dtype=torch.float):
    """Get normalized 3D reference points at the cell centers of an H x W x Z grid.

    These points are used in spatial cross-attention (SCA). Along each axis the
    i-th cell center is (i + 0.5) / N, so every coordinate lies in (0, 1).

    Args:
        H, W: spatial shape of the BEV plane (rows, columns).
        Z: number of height levels in each pillar.
        num_points_in_pillar: unused; accepted for call compatibility.
        dim: which reference points to build; only '3d' is implemented.
        bs: unused; the result has no batch dimension.
        device: device on which the points are allocated.
        dtype: floating-point dtype of the points.

    Returns:
        Tensor of shape (H * W * Z, 3) with columns (x, y, z). Rows are ordered
        with the H (y) index outermost and the Z index innermost.

    Raises:
        ValueError: if ``dim`` is not '3d'.
    """
    if dim != '3d':
        # Fail loudly instead of implicitly returning None.
        raise ValueError(f"unsupported dim {dim!r}; only '3d' is implemented")

    # Per-axis normalized cell centers, broadcast to a common (H, W, Z) grid.
    zs = torch.linspace(0.5, Z - 0.5, Z, dtype=dtype,
                        device=device).view(1, 1, Z).expand(H, W, Z) / Z
    xs = torch.linspace(0.5, W - 0.5, W, dtype=dtype,
                        device=device).view(1, W, 1).expand(H, W, Z) / W
    ys = torch.linspace(0.5, H - 0.5, H, dtype=dtype,
                        device=device).view(H, 1, 1).expand(H, W, Z) / H
    ref_3d = torch.stack((xs, ys, zs), -1)
    # (H, W, Z, 3) -> (H * W * Z, 3)
    return ref_3d.flatten(0, 2)

TypeError: get_reference_points() missing 2 required positional arguments: 'H' and 'W'

torch.Size([1, 32768, 3])

In [29]:
# Scratch check of the z-axis reference-point layout, broadcast as (Z, H, W).
NUM_Z_LEVELS = 10  # number of height levels; also the normalization denominator
BEV_H, BEV_W = 64, 64
num_points_in_pillar = 1
dim = '3d'
bs = 1
device = 'cpu'
dtype = torch.float
# Normalized cell-center heights (i + 0.5) / NUM_Z_LEVELS, repeated over the BEV plane.
zs = torch.linspace(0.5, NUM_Z_LEVELS - 0.5, NUM_Z_LEVELS, dtype=dtype,
                    device=device).view(-1, 1, 1).expand(NUM_Z_LEVELS, BEV_H, BEV_W) / NUM_Z_LEVELS

In [2]:
import torch

def cartesian_to_spherical(x, y, z):
    """Convert Cartesian coordinates to (range, azimuth, elevation).

    Args:
        x, y, z: float tensors of broadcast-compatible shape.

    Returns:
        r: Euclidean distance from the origin.
        theta: azimuth in radians, atan2(y, x).
        phi: elevation in radians above the x-y plane, in [-pi/2, pi/2].
            At the origin phi is 0 rather than NaN.
    """
    r = torch.sqrt(x**2 + y**2 + z**2)
    theta = torch.atan2(y, x)  # azimuth
    # atan2(z, rho) equals asin(z / r) but avoids the 0/0 NaN at the origin
    # and is better conditioned near the poles.
    rho = torch.sqrt(x**2 + y**2)
    phi = torch.atan2(z, rho)  # elevation
    return r, theta, phi

def get_reference_points_spherical(H=64, W=64, Z=10, bs=1, device='cpu', dtype=torch.float,
                                   x_range=(0.0, 51.2), y_range=(-25.6, 25.6), z_range=(-5.0, 3.0),
                                   r_res=1.68, theta_res_deg=4.0, phi_res_deg=4.0,
                                   num_r_bins=44, num_theta_bins=27, num_phi_bins=10):
    """Map Cartesian voxel centers to normalized spherical-grid coordinates.

    A Cartesian volume spanning ``x_range`` x ``y_range`` x ``z_range`` is split
    into W x H x Z equal cells. Each cell center is converted to
    (range, azimuth, elevation) and normalized against a spherical grid of:
    ``num_r_bins`` range bins of size ``r_res``; ``num_theta_bins`` azimuth bins
    of ``theta_res_deg`` degrees centered on azimuth 0; and ``num_phi_bins``
    elevation bins of ``phi_res_deg`` degrees centered on elevation 0.

    Args:
        H: number of cells along y.
        W: number of cells along x.
        Z: number of cells along z.
        bs: unused; the result has no batch dimension.
        device: device on which the points are allocated.
        dtype: floating-point dtype of the points.
        x_range, y_range, z_range: (min, max) extent of the Cartesian volume.
        r_res: size of one range bin, in the same units as the ranges.
        theta_res_deg, phi_res_deg: angular bin size in degrees.
        num_r_bins, num_theta_bins, num_phi_bins: number of bins per spherical axis.

    Returns:
        Tensor of shape (W * H * Z, 3) with columns (r, theta, phi), normalized so
        the spherical grid spans [0, 1]. Values outside [0, 1] fall outside that
        grid. Rows are ordered with the x index outermost and the z index innermost.
        NOTE(review): get_reference_points orders rows with H (y) outermost instead;
        confirm consumers expect this x-major layout.
    """
    def _cell_centers(bounds, n):
        # Centers of n equal cells spanning [lo, hi]: lo + (i + 0.5) * (hi - lo) / n.
        lo, hi = bounds
        half_cell = (hi - lo) / (2 * n)
        return torch.linspace(lo + half_cell, hi - half_cell, n, dtype=dtype, device=device)

    # Half-cell offsets are derived per axis from the range and cell count.
    # For the defaults this is 0.4 on every axis, but it stays correct for any H, W, Z.
    xs = _cell_centers(x_range, W)
    ys = _cell_centers(y_range, H)
    zs = _cell_centers(z_range, Z)

    grid_x, grid_y, grid_z = torch.meshgrid(xs, ys, zs, indexing='ij')
    ref_3d = torch.stack((grid_x.flatten(), grid_y.flatten(), grid_z.flatten()), -1)

    rs, thetas, phis = cartesian_to_spherical(ref_3d[:, 0], ref_3d[:, 1], ref_3d[:, 2])

    theta_res = theta_res_deg * (torch.pi / 180)  # degrees -> radians
    phi_res = phi_res_deg * (torch.pi / 180)      # degrees -> radians

    # Angles are shifted by half the bin count so that angle 0 maps to 0.5.
    rs_index = (rs / r_res) / num_r_bins
    thetas_index = ((thetas / theta_res) + num_theta_bins / 2) / num_theta_bins
    phis_index = ((phis / phi_res) + num_phi_bins / 2) / num_phi_bins
    return torch.stack([rs_index, thetas_index, phis_index], -1)

# Usage: spherical reference points for a 64 x 64 BEV grid with 10 height levels.
H = 64
W = 64
Z = 10
ref_3d_spherical = get_reference_points_spherical(
    H=H, W=W, Z=Z, device='cpu', dtype=torch.float32,
)
ref_3d_spherical

tensor([[ 0.3466, -0.3249,  0.2414],
        [ 0.3448, -0.3249,  0.2856],
        [ 0.3434, -0.3249,  0.3303],
        ...,
        [ 0.7673,  0.7443,  0.5253],
        [ 0.7675,  0.7443,  0.5455],
        [ 0.7679,  0.7443,  0.5656]])

In [3]:
# One row per voxel: (W * H * Z, 3) = (40960, 3) for the 64 x 64 x 10 grid.
ref_3d_spherical.size()

torch.Size([40960, 3])