In [1]:
%load_ext lab_black
# Switch to the repository root (the path printed below) so project paths resolve.
# NOTE(review): `%cd ..` is relative, so re-running this cell moves up one more
# directory level — only run it once per kernel session.
%cd ..

/root/data2/jaehyeok/dev/ddpm/space-filling-pytorch


In [2]:
import torch as th

In [3]:
# GPU used for the comparison below. The original notebook hard-coded device 9,
# which fails with "invalid device ordinal" on machines with fewer GPUs; set the
# CUDA_DEVICE_INDEX environment variable to pick another one without editing code.
import os

CUDA_DEVICE_INDEX = int(os.environ.get("CUDA_DEVICE_INDEX", "9"))
th.cuda.set_device(CUDA_DEVICE_INDEX)

In [4]:
# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------

import torch
from typing import Optional, Union


class KeyLUT:
    """Pre-computed lookup tables for 3D Morton (z-order) key conversion.

    * Encoding tables (one per axis) map an 8-bit coordinate value (0..255) to
      that axis' bit-interleaved contribution to a key.
    * Decoding tables map a 9-bit key chunk (0..511) back to the three 3-bit
      x, y, z values it contains.

    Within every 3-bit group of a key, x occupies the high bit, y the middle
    bit and z the low bit. Tables are built once on CPU; copies for other
    devices are made lazily on first request and cached.
    """

    def __init__(self):
        cpu = torch.device("cpu")
        byte_values = torch.arange(256, dtype=torch.int64)
        chunk_values = torch.arange(512, dtype=torch.int64)
        zeros = torch.zeros(256, dtype=torch.int64)

        # Only one axis is non-zero per table, so each entry is exactly that
        # axis' interleaved bit pattern.
        encode_tables = (
            self.xyz2key(byte_values, zeros, zeros, 8),
            self.xyz2key(zeros, byte_values, zeros, 8),
            self.xyz2key(zeros, zeros, byte_values, 8),
        )
        self._encode = {cpu: encode_tables}
        self._decode = {cpu: self.key2xyz(chunk_values, 9)}

    def encode_lut(self, device=torch.device("cpu")):
        """Return the (x, y, z) encoding tables on ``device`` (cached per device)."""
        return self._cached_on(self._encode, device)

    def decode_lut(self, device=torch.device("cpu")):
        """Return the (x, y, z) decoding tables on ``device`` (cached per device)."""
        return self._cached_on(self._decode, device)

    @staticmethod
    def _cached_on(cache, device):
        # Copy the CPU master tables to ``device`` the first time it is requested.
        if device not in cache:
            master = cache[torch.device("cpu")]
            cache[device] = tuple(table.to(device) for table in master)
        return cache[device]

    def xyz2key(self, x, y, z, depth):
        """Interleave the low ``depth`` bits of x, y, z (bit i -> 3i+2, 3i+1, 3i)."""
        key = torch.zeros_like(x)
        for bit in range(depth):
            sel = 1 << bit
            for coord, offset in ((x, 2), (y, 1), (z, 0)):
                key = key | ((coord & sel) << (2 * bit + offset))
        return key

    def key2xyz(self, key, depth):
        """Inverse of :meth:`xyz2key`: extract ``depth`` bits per axis from ``key``."""
        coords = [torch.zeros_like(key) for _ in range(3)]
        for bit in range(depth):
            for axis, offset in enumerate((2, 1, 0)):
                src = 3 * bit + offset
                coords[axis] = coords[axis] | ((key & (1 << src)) >> (2 * bit + offset))
        return tuple(coords)


# Single shared lookup-table instance used by the module-level xyz2key/key2xyz
# functions; per-device copies of its tables are created lazily on first use.
_key_lut = KeyLUT()


def xyz2key(
    x: torch.Tensor,
    y: torch.Tensor,
    z: torch.Tensor,
    b: Optional[Union[torch.Tensor, int]] = None,
    depth: int = 16,
):
    r"""Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys
    based on pre-computed look up tables. The speed of this function is much
    faster than the method based on for-loop.

    Args:
      x (torch.Tensor): The x coordinate.
      y (torch.Tensor): The y coordinate.
      z (torch.Tensor): The z coordinate.
      b (torch.Tensor or int): The batch index of the coordinates, and should be
          smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of
          :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`.
      depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).

    Returns:
      torch.Tensor: int64 keys. Bits ``0 .. 3*depth-1`` hold the interleaved
      coordinates (x high, then y, then z within each 3-bit group); when
      :attr:`b` is given it is stored from bit 48 upwards.

    Raises:
      ValueError: If :attr:`depth` is outside ``[0, 16]``.
    """
    # The encode LUTs cover 8 bits per axis and are applied at most twice, so
    # anything deeper would index past the 256-entry tables.
    if not 0 <= depth <= 16:
        raise ValueError(f"depth must be in [0, 16], got {depth}")

    EX, EY, EZ = _key_lut.encode_lut(x.device)
    x, y, z = x.long(), y.long(), z.long()

    # Low byte of each coordinate (or fewer bits when depth <= 8) -> key bits 0..23.
    mask = 255 if depth > 8 else (1 << depth) - 1
    key = EX[x & mask] | EY[y & mask] | EZ[z & mask]
    if depth > 8:
        # Remaining high bits of each coordinate -> key bits 24..47.
        mask = (1 << (depth - 8)) - 1
        key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]
        key = key16 << 24 | key

    if b is not None:
        # b may be a plain Python int (as documented); only tensors need casting.
        if isinstance(b, torch.Tensor):
            b = b.long()
        key = b << 48 | key

    return key


def key2xyz(key: torch.Tensor, depth: int = 16):
    r"""Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates
    and the batch index based on pre-computed look up tables.

    Args:
      key (torch.Tensor): The shuffled key.
      depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).

    Returns:
      tuple: ``(x, y, z, b)`` tensors shaped like :attr:`key`, where ``b`` is
      ``key >> 48`` (the batch index) and x, y, z are decoded from the low 48 bits.
    """
    DX, DY, DZ = _key_lut.decode_lut(key.device)

    batch = key >> 48
    payload = key & ((1 << 48) - 1)

    coords = [torch.zeros_like(key) for _ in range(3)]
    # Each 9-bit chunk of the payload carries 3 bits for each axis.
    for chunk_idx in range((depth + 2) // 3):
        chunk = payload >> (chunk_idx * 9) & 511
        shift = chunk_idx * 3
        for axis, table in enumerate((DX, DY, DZ)):
            coords[axis] = coords[axis] | (table[chunk] << shift)

    x, y, z = coords
    return x, y, z, batch

In [8]:
from typing import Tuple

import torch as th
import triton
import triton.language as tl
from torchtyping import TensorType


@triton.autotune(
    # Only BLOCK_SIZE=256 is currently enabled; other sizes are kept commented out.
    configs=[
        # triton.Config({"BLOCK_SIZE": 32}),
        # triton.Config({"BLOCK_SIZE": 64}),
        # triton.Config({"BLOCK_SIZE": 128}),
        triton.Config({"BLOCK_SIZE": 256}),
        # triton.Config({"BLOCK_SIZE": 512}),
        # triton.Config({"BLOCK_SIZE": 1024}),
    ],
    # Autotuning is re-run whenever the total point count BN changes.
    key=["BN"],
)
@triton.jit
def point_to_zorder_3d_depth16_fp32_kernel(
    xyz_ptr,  # point coordinates; point p's component k is read at xyz_ptr + p * 3 + k (no strides -> assumes packed layout)
    distance_ptr,  # output: one z-order key per point, written at distance_ptr + p
    BN,  # total number of points (batch * points per batch)
    space_size,  # grid cells per axis; only the low 16 bits of each cell index are encoded
    x_offset,  # component index (0..2) holding x
    y_offset,  # component index (0..2) holding y
    z_offset,  # component index (0..2) holding z
    BLOCK_SIZE: tl.constexpr,  # points processed per program
):
    # load data
    # Each program handles BLOCK_SIZE consecutive point indices; lanes with
    # idx_bn >= BN are masked out of both the loads and the final store.
    pid = tl.program_id(0)
    idx_bn = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = idx_bn < BN
    # TODO no coalescing: three separate stride-3 loads per point instead of one vector load.
    fx = tl.load(xyz_ptr + idx_bn * 3 + x_offset, mask=mask)
    fy = tl.load(xyz_ptr + idx_bn * 3 + y_offset, mask=mask)
    fz = tl.load(xyz_ptr + idx_bn * 3 + z_offset, mask=mask)
    # Map coordinates from [-1, 1] to cell indices in [0, space_size): the int
    # cast drops the fractional part, then values are clamped so points outside
    # [-1, 1) land in the border cells.
    x = ((fx + 1) / 2 * space_size).to(tl.int64)
    y = ((fy + 1) / 2 * space_size).to(tl.int64)
    z = ((fz + 1) / 2 * space_size).to(tl.int64)
    x = tl.minimum(tl.maximum(x, 0), space_size - 1)
    y = tl.minimum(tl.maximum(y, 0), space_size - 1)
    z = tl.minimum(tl.maximum(z, 0), space_size - 1)

    # calculate z-order
    # Bit i of x goes to key bit 3i+2, of y to 3i+1, of z to 3i (loop is unrolled
    # at compile time); the result occupies key bits 0..47.
    ret = 0
    for i in tl.static_range(0, 16):
        q = 1 << i
        ret |= (x & q) << (2 * i + 2)
        ret |= (y & q) << (2 * i + 1)
        ret |= (z & q) << (2 * i + 0)

    # write results
    tl.store(distance_ptr + idx_bn, ret, mask=mask)


def point_to_zorder_3d_depth16_fp32(
    xyz: TensorType["b", "n", 3, th.float32],
    space_size: int,
    x_offset: int = 0,
    y_offset: int = 1,
    z_offset: int = 2,
):
    """Compute 3D z-order (Morton) keys for a batch of points with a Triton kernel.

    Coordinates are expected in [-1, 1] and are quantized to ``space_size``
    cells per axis; the kernel clamps out-of-range values to the border cells.
    Within each 3-bit group of a key, x is the high bit, then y, then z; 16 bits
    per axis are encoded, so keys fit in 48 bits.

    Args:
        xyz: ``(B, N, 3)`` point coordinates as a GPU tensor (the Triton kernel
            runs on the GPU).
        space_size: grid cells per axis; must be in ``[1, 2**16]`` because the
            kernel only encodes 16 bits per axis.
        x_offset, y_offset, z_offset: which channel of the last dimension holds
            x, y and z respectively (each 0, 1 or 2).

    Returns:
        ``(B, N)`` int64 tensor of z-order keys.

    Raises:
        ValueError: on a non-``(B, N, 3)`` input, an out-of-range ``space_size``
            or an out-of-range offset.
    """
    if xyz.dim() != 3 or xyz.shape[2] != 3:
        raise ValueError(f"xyz must have shape (B, N, 3), got {tuple(xyz.shape)}")
    if not 1 <= space_size <= 2**16:
        raise ValueError(f"space_size must be in [1, 65536], got {space_size}")
    offsets = {"x_offset": x_offset, "y_offset": y_offset, "z_offset": z_offset}
    for name, offset in offsets.items():
        if offset not in (0, 1, 2):
            raise ValueError(f"{name} must be 0, 1 or 2, got {offset}")

    B, N = xyz.shape[:2]
    distance = xyz.new_empty(B, N, dtype=th.int64)
    if B * N == 0:
        # Nothing to encode; avoid launching an empty grid.
        return distance

    # The kernel reads point i at xyz_ptr[i * 3 + offset] and ignores strides,
    # so views such as xyz6[..., :3] must be densely packed first.
    xyz = xyz.contiguous()
    grid = lambda meta: (triton.cdiv(B * N, meta["BLOCK_SIZE"]),)
    point_to_zorder_3d_depth16_fp32_kernel[grid](
        xyz, distance, B * N, space_size, x_offset, y_offset, z_offset
    )
    return distance

In [6]:
# Smoke test: 16 random points in [-1, 1)^3 quantized on a 2**16 grid per axis.
# Seeded so the displayed keys are reproducible when the notebook is re-run.
th.manual_seed(0)
xyz = th.rand(16, 3, device="cuda") * 2 - 1
grid_size = 2**16

In [15]:
# Triton keys for a single batch: (1, N) after adding the batch dimension.
z_order_triton = point_to_zorder_3d_depth16_fp32(
    xyz.unsqueeze(0), grid_size, x_offset=0, y_offset=1, z_offset=2
)
z_order_triton

tensor([[142258343134507, 234905236643413, 266121751766580, 229605166709668,
         274726867874873, 221983309916557,  34985663610472, 198169605513854,
         120898085593575,  54730846400195, 253166080228275, 191542425668609,
         249175721684259, 149382059561983, 124312971898722, 247139660076827]],
       device='cuda:9')

In [14]:
# PTv3 reference: quantize with the same [-1, 1] -> [0, grid_size) mapping and
# border clamping the Triton kernel applies, then encode with the LUT-based xyz2key.
# Without the clamp, a coordinate of exactly 1.0 maps to grid_size, which the LUT
# encoder silently wraps to 0 while the kernel clamps it to grid_size - 1.
grid_coord = ((xyz + 1) / 2 * grid_size).long().clamp(0, grid_size - 1)
x, y, z = grid_coord.unbind(dim=1)
z_order_ptv3 = xyz2key(x, y, z)
z_order_ptv3

tensor([142258343134507, 234905236643413, 266121751766580, 229605166709668,
        274726867874873, 221983309916557,  34985663610472, 198169605513854,
        120898085593575,  54730846400195, 253166080228275, 191542425668609,
        249175721684259, 149382059561983, 124312971898722, 247139660076827],
       device='cuda:9')

In [16]:
# Keys are int64 and must match bit-for-bit. allclose's default rtol=1e-5 on
# ~1e14-sized keys tolerates differences of ~1e9, so it cannot detect real
# mismatches. Drop the batch dim so shapes agree instead of relying on broadcasting.
th.equal(z_order_triton.squeeze(0), z_order_ptv3)

True

# Large-Scale Test
---
Repeat the Triton vs. PTv3 `xyz2key` comparison on 32,768 random points.

In [17]:
# 32,768 random points in [-1, 1)^3 quantized on a 2**16 grid per axis.
# Seeded so the displayed keys are reproducible when the notebook is re-run.
th.manual_seed(0)
xyz = th.rand(32768, 3, device="cuda") * 2 - 1
grid_size = 2**16

In [18]:
# Triton keys for a single batch: (1, N) after adding the batch dimension.
z_order_triton = point_to_zorder_3d_depth16_fp32(
    xyz.unsqueeze(0), grid_size, x_offset=0, y_offset=1, z_offset=2
)
z_order_triton

tensor([[138784506524412, 224138053161915, 185899124151734,  ...,
           3557457093838,  77484789105525, 249790088739659]], device='cuda:9')

In [19]:
# PTv3 reference: quantize with the same [-1, 1] -> [0, grid_size) mapping and
# border clamping the Triton kernel applies, then encode with the LUT-based xyz2key.
# Without the clamp, a coordinate of exactly 1.0 maps to grid_size, which the LUT
# encoder silently wraps to 0 while the kernel clamps it to grid_size - 1.
grid_coord = ((xyz + 1) / 2 * grid_size).long().clamp(0, grid_size - 1)
x, y, z = grid_coord.unbind(dim=1)
z_order_ptv3 = xyz2key(x, y, z)
z_order_ptv3

tensor([138784506524412, 224138053161915, 185899124151734,  ...,
          3557457093838,  77484789105525, 249790088739659], device='cuda:9')

In [20]:
# Keys are int64 and must match bit-for-bit. allclose's default rtol=1e-5 on
# ~1e14-sized keys tolerates differences of ~1e9, so it cannot detect real
# mismatches. Drop the batch dim so shapes agree instead of relying on broadcasting.
th.equal(z_order_triton.squeeze(0), z_order_ptv3)

True