Paper: "Maximizing Parallelism in the Construction of BVHs, Octrees, and k-d Trees" (Karras, HPG 2012)

https://research.nvidia.com/sites/default/files/pubs/2012-06_Maximizing-Parallelism-in/karras2012hpg_paper.pdf

In [2]:
import numpy as np

def print_bits(num: int, num_bits: int) -> None:
    """Write ``num`` to stdout in binary, left-padded with zeros to ``num_bits`` digits.

    Args:
        num: Integer to display.
        num_bits: Minimum width of the printed string; wider values are not truncated.
    """
    print(bin(num)[2:].zfill(num_bits))

def interleave_bits(x: int, y: int, z: int, num_bits_per_dim: int) -> int:
    """
    Interleave bits from x, y, and z coordinates into a single Morton code.

    Bit ``i`` of ``x``, ``y`` and ``z`` is placed at bit ``3*i``, ``3*i + 1``
    and ``3*i + 2`` of the result respectively, so the code is
    ``3 * num_bits_per_dim`` bits wide.

    Args:
        x, y, z: Non-negative integer coordinates.
        num_bits_per_dim: Number of low-order bits taken from each coordinate;
            any higher bits are ignored.

    Returns:
        Interleaved Morton code.
    """
    result = 0
    for i in range(num_bits_per_dim):
        result |= ((x >> i) & 1) << (3 * i)
        result |= ((y >> i) & 1) << (3 * i + 1)
        result |= ((z >> i) & 1) << (3 * i + 2)
    return result

def bit_deinterleave_bits(morton_code: int, num_bits_per_dim: int) -> tuple[int, int, int]:
    """
    Split a Morton code back into its x, y and z coordinates.

    Inverse of the interleaving: bits ``3*k``, ``3*k + 1`` and ``3*k + 2`` of
    the code become bit ``k`` of x, y and z respectively.

    Args:
        morton_code: Interleaved code to decode.
        num_bits_per_dim: How many bits to recover for each axis.

    Returns:
        ``(x, y, z)`` as a tuple of ints.
    """
    xs = ys = zs = 0
    for bit in range(num_bits_per_dim):
        base = 3 * bit
        xs |= ((morton_code >> base) & 1) << bit
        ys |= ((morton_code >> (base + 1)) & 1) << bit
        zs |= ((morton_code >> (base + 2)) & 1) << bit
    return xs, ys, zs

def fractional_to_binary(num: float, num_bits: int) -> int:
    """
    Encode a fraction in ``[0, 1)`` as a ``num_bits``-bit fixed-point integer.

    The most significant of the ``num_bits`` output bits is worth 1/2, the
    next 1/4, and so on; the expansion is truncated (not rounded) and stops
    early once the remaining fraction is exhausted.

    Args:
        num: Fraction to encode, ``0 <= num < 1``.
        num_bits: Number of output bits.

    Returns:
        Fixed-point bit pattern as an integer.
    """
    assert 0 <= num < 1.0

    bits = 0
    remainder = num
    for position in range(num_bits - 1, -1, -1):
        if remainder <= 0:
            break
        # Doubling shifts the next binary digit into the integer part.
        remainder *= 2
        if remainder >= 1.0:
            bits |= 1 << position
            remainder -= 1.0
    return bits

def coordinate_to_morton_code(num: float, num_bits: int) -> int:
    """
    Quantize a coordinate in ``[0, 1)`` to a ``num_bits``-wide integer.

    Despite the name, no interleaving happens here: the value is the truncated
    binary fraction of ``num`` using only ``num_bits - 1`` bits, so bit
    ``num_bits - 1`` of the result is always zero.

    Args:
        num: Coordinate value, ``0 <= num < 1``.
        num_bits: Width of the result in bits.

    Returns:
        Fixed-point encoding of ``num``.
    """
    assert 0 <= num < 1.0
    fraction_bits = num_bits - 1
    return fractional_to_binary(num, fraction_bits)

def binary_to_fractional(binary: int, num_bits: int) -> float:
    """
    Interpret the low ``num_bits`` bits of ``binary`` as a fixed-point fraction.

    Bit ``num_bits - 1`` is worth 1/2, bit ``num_bits - 2`` is worth 1/4, and
    so on down to bit 0, worth ``2**-num_bits``; bits above are ignored.

    Args:
        binary: Integer holding the fraction bits.
        num_bits: How many low-order bits to read (at most 64).

    Returns:
        Fractional value between 0 and 1.
    """
    assert num_bits <= 64

    total = 0.0
    # Accumulate from the least significant bit upward.
    for position in range(num_bits):
        bit = (binary >> position) & 1
        total += bit * 1.0 / (1 << (num_bits - position))
    return total

def morton_code_to_coordinate(morton_code: int, num_bits: int) -> float:
    """
    Decode a value produced by ``coordinate_to_morton_code`` back to ``[0, 1)``.

    The top bit (bit ``num_bits - 1``) must be clear; the remaining
    ``num_bits - 1`` low bits are read as a fixed-point fraction.

    Args:
        morton_code: Encoded coordinate.
        num_bits: Width of the encoding; requires ``1 <= num_bits`` and
            ``3 * num_bits <= 64``.

    Returns:
        Decoded coordinate value.
    """
    assert num_bits >= 1 and 3 * num_bits <= 64
    top_bit = 1 << (num_bits - 1)
    assert (morton_code & top_bit) == 0

    fraction_bits = num_bits - 1
    return binary_to_fractional(morton_code, fraction_bits)

In [None]:
# Example usage: round-trip a point through Morton encoding and decoding.
x, y, z = 5, 6, 7
num_bits = 4  # bits per coordinate, so the Morton code is num_bits * 3 = 12 bits wide
morton = interleave_bits(x, y, z, num_bits)
print_bits(morton, num_bits * 3)  # Print the interleaved bits

# Convert back; the recovered coordinates should equal the originals
x2, y2, z2 = bit_deinterleave_bits(morton, num_bits)
print(f"Original: ({x}, {y}, {z}), Recovered: ({x2}, {y2}, {z2})")

In [25]:
import numpy as np
def count_leading_zeros(x, num_bits=30):
    """
    Count the leading zero bits of ``x`` viewed as a ``num_bits``-wide word.

    Args:
        x: Non-negative integer (Python int or numpy integer scalar).
        num_bits: Word width in bits.

    Returns:
        ``num_bits`` minus the bit length of ``x`` (so ``num_bits`` when
        ``x == 0``). The result is negative if ``x`` needs more than
        ``num_bits`` bits.
    """
    # int() accepts numpy integer scalars (e.g. indices derived from np.sign),
    # which may not provide Python's int.bit_length() method.
    return num_bits - int(x).bit_length()

def delta(morton_codes, i, j, num_bits=30):
    """
    Length of the common prefix of leaves ``i`` and ``j`` (the paper's delta(i, j)).

    Args:
        morton_codes: Sorted Morton codes.
        i: Index of the reference leaf (assumed in range).
        j: Index of the other leaf; may be out of range.
        num_bits: Width of each Morton code in bits.

    Returns:
        -1 if ``j`` is outside ``[0, len(morton_codes))``. If the two codes
        differ, the number of leading bits they share. If they are equal, the
        indices break the tie: ``num_bits`` plus the leading zeros of
        ``i ^ j`` in a 30-bit word, which exceeds any result for differing
        codes as long as there are fewer than 2**30 leaves.
    """
    if not 0 <= j < len(morton_codes):
        return -1  # Sentinel for out-of-range

    code_i, code_j = morton_codes[i], morton_codes[j]
    if code_i != code_j:
        return count_leading_zeros(code_i ^ code_j, num_bits)
    # Same code => fall back on the index difference.
    return count_leading_zeros(i ^ j) + num_bits

def determine_range(morton_codes, i, num_bits=30):
    """
    Find the range of leaves covered by internal node ``i``.

    The range has leaf ``i`` at one end and extends toward the neighbour that
    shares the longer common prefix with leaf ``i``. Its far end is the last
    leaf whose common prefix with ``i`` is still longer than the prefix ``i``
    shares with its other neighbour.

    Args:
        morton_codes: Morton codes sorted in ascending order.
        i: Index of the internal node, ``0 <= i < len(morton_codes) - 1``.
        num_bits: Width of each Morton code in bits.

    Returns:
        ``(first, last)`` inclusive leaf indices as plain Python ints.
    """
    # Direction (+1 / -1) of the range. For sorted input the two neighbour
    # prefixes are never equal, so a plain comparison replaces np.sign. This
    # keeps every derived index a Python int: delta()'s tie-break for equal
    # codes calls int.bit_length() on i ^ j.
    if delta(morton_codes, i, i + 1, num_bits) > delta(morton_codes, i, i - 1, num_bits):
        d = 1
    else:
        d = -1
    delta_min = delta(morton_codes, i, i - d, num_bits)

    # Exponential search for an upper bound on the range length.
    lmax = 2
    while delta(morton_codes, i, i + lmax * d, num_bits) > delta_min:
        lmax *= 2

    # Binary search to refine the exact length within [0, lmax).
    l = 0
    t = lmax // 2
    while t >= 1:
        if delta(morton_codes, i, i + (l + t) * d, num_bits) > delta_min:
            l += t
        t //= 2

    j = i + l * d
    return (min(i, j), max(i, j))

def find_split(morton_codes, delta_node, first, last, num_bits=30):
    """
    Locate where the node covering leaves ``[first, last]`` splits.

    Binary-searches for the largest index in ``[first, last)`` whose common
    prefix with leaf ``first`` is longer than ``delta_node``. Falls back to
    ``first`` if no such index is found.

    Args:
        morton_codes: Sorted Morton codes.
        delta_node: Common-prefix length of the whole range.
        first, last: Inclusive leaf range of the node (``first < last``).
        num_bits: Width of each Morton code in bits.

    Returns:
        Split index; the left child covers ``[first, split]`` and the right
        child ``[split + 1, last]``.
    """
    best = first
    step = last - first

    while step > 1:
        step = (step + 1) // 2  # ceil(step / 2)
        candidate = best + step
        if candidate >= last:
            continue
        if delta(morton_codes, first, candidate, num_bits) > delta_node:
            best = candidate
    return best

def find_range_3d(prefix, num_bits_prefix, num_bits = 30):
    """
    Compute the integer bounding box of all Morton codes sharing a prefix.

    Every code whose top ``num_bits_prefix`` bits match ``prefix`` lies in the
    box. The lowest such code (free bits all 0) and the highest (free bits
    all 1) are de-interleaved into per-axis minima and maxima.

    Args:
        prefix: Any Morton code carrying the shared prefix, e.g. a code in
            the node's range.
        num_bits_prefix: Length of the shared prefix in bits. Values outside
            ``[0, num_bits]`` are clamped. ``delta`` reports lengths above
            ``num_bits`` when all codes in a node are identical.
        num_bits: Total width of a Morton code; must be a multiple of 3.

    Returns:
        ``((xmin, ymin, zmin), (xmax, ymax, zmax))``, both corners inclusive.
    """
    assert(num_bits % 3 == 0)

    # Clamp so the shift below is never negative (a negative shift count
    # raises ValueError); a shared prefix cannot exceed the code width.
    prefix_len = max(0, min(num_bits_prefix, num_bits))

    # mask1 keeps the shared prefix bits; mask2 selects the remaining free bits.
    mask1 = ((1 << prefix_len) - 1) << (num_bits - prefix_len)
    mask2 = mask1 ^ ((1 << num_bits) - 1)

    # Smallest and largest codes with this prefix.
    min_val = prefix & mask1
    max_val = prefix | mask2

    xmin, ymin, zmin = bit_deinterleave_bits(min_val, num_bits // 3)
    xmax, ymax, zmax = bit_deinterleave_bits(max_val, num_bits // 3)
    return (xmin, ymin, zmin), (xmax, ymax, zmax)

class Node:
    """
    Internal node of the LBVH built by ``build_tree``.

    Fields (plain public attributes):
        index: Global node number; ``build_tree`` uses ``i + n`` for the i-th
            internal node over ``n`` leaves.
        split: Leaf index at which the node's range is split between children.
        left, right: Child numbers; values below ``n`` denote leaves, others
            internal nodes.
        min_point, max_point: Inclusive integer corners of the bounding box.
    """

    def __init__(self, index, split, left, right, min_point, max_point):
        self.index = index
        self.split = split
        self.left = left
        self.right = right
        self.min_point = min_point
        self.max_point = max_point

    def __repr__(self):
        return (
            f"{type(self).__name__}(index={self.index!r}, split={self.split!r}, "
            f"left={self.left!r}, right={self.right!r}, "
            f"min_point={self.min_point!r}, max_point={self.max_point!r})"
        )

def build_tree(morton_codes, num_bits = 30):
    """
    Build the internal nodes of a linear BVH over sorted Morton codes.

    Leaves are numbered ``0 .. n-1`` (the index of their Morton code).
    Internal node ``i`` is stored at ``nodes[i]`` and numbered ``i + n``, so a
    child reference below ``n`` is a leaf and anything else is internal.
    ``nodes[0]`` is the root.

    Args:
        morton_codes: Morton codes sorted in ascending order.
        num_bits: Width of each Morton code in bits (a multiple of 3).

    Returns:
        List of ``n - 1`` ``Node`` objects; empty when ``n < 2``.
    """
    n = len(morton_codes)
    nodes = [None] * (n - 1)

    for i in range(n - 1):
        first, last = determine_range(morton_codes, i, num_bits)
        # Normalize to Python ints so node fields and the index arithmetic
        # below never carry numpy scalars.
        first, last = int(first), int(last)
        delta_node = delta(morton_codes, first, last, num_bits)
        split = find_split(morton_codes, delta_node, first, last, num_bits)

        # For identical codes, delta() adds an index tie-break on top of
        # num_bits; the shared Morton prefix itself can be at most num_bits long.
        prefix_len = min(delta_node, num_bits)
        min_point, max_point = find_range_3d(morton_codes[split], prefix_len, num_bits)

        # A child covering a single leaf is that leaf; otherwise it is the
        # internal node with the same index, offset by n.
        left = split if split == first else split + n
        right = split + 1 if split + 1 == last else split + 1 + n
        nodes[i] = Node(
            index=i+n,
            split=split,
            left=left,
            right=right,
            min_point=min_point,
            max_point=max_point
        )
    return nodes


In [26]:
# Small worked example: five 9-bit Morton codes (3 bits per axis), already
# sorted in ascending order as build_tree requires.
# NOTE(review): originally described as "the paper’s example" — confirm these
# match the keys used in the paper's figure.
morton_codes = [0, 0b100, 0b001000000, 0b001000001, 0b111111111]
nodes = build_tree(morton_codes, 9)

# Print result clearly; node.index - n is the node's position in `nodes`
print("LBVH Internal Nodes:")
n = len(morton_codes)
for node in nodes:
    print(f"Node {node.index} ({node.index - n}): min_point={node.min_point}, max_point={node.max_point}, "
          f"split={node.split}, left={node.left}, right={node.right}")

LBVH Internal Nodes:
Node 5 (0): min_point=(0, 0, 0), max_point=(7, 7, 7), split=3, left=8, right=4
Node 6 (1): min_point=(0, 0, 0), max_point=(1, 1, 1), split=0, left=0, right=1
Node 7 (2): min_point=(4, 0, 0), max_point=(5, 0, 0), split=2, left=2, right=3
Node 8 (3): min_point=(0, 0, 0), max_point=(7, 3, 3), split=1, left=6, right=7


In [35]:
def check_intersect(query_min, query_max, min_point, max_point):
    """Return True if the query box and the node box overlap on all three axes (bounds inclusive)."""
    for axis in range(3):
        overlaps_on_axis = query_max[axis] >= min_point[axis] and query_min[axis] <= max_point[axis]
        if not overlaps_on_axis:
            return False
    return True

def check_inside(morton_code, query_min, query_max, num_bits = 30):
    """
    Test whether the point encoded by ``morton_code`` lies inside the query box.

    Bounds are inclusive on every axis; ``num_bits`` is the total code width,
    split evenly across x, y and z.
    """
    point = bit_deinterleave_bits(morton_code, num_bits//3)
    return all(query_min[axis] <= point[axis] <= query_max[axis] for axis in range(3))

def traverse_tree(nodes, morton_codes, query_min, query_max, num_bits = 30):
    """
    Return the indices of all leaves whose point lies in the query box.

    Performs an iterative depth-first walk from the root. Subtrees whose
    bounding box misses the query are pruned; leaves are tested exactly.

    Args:
        nodes: Internal nodes as returned by ``build_tree``.
        morton_codes: The sorted Morton codes the tree was built from.
        query_min, query_max: Inclusive corners of the query box.
        num_bits: Width of each Morton code in bits.

    Returns:
        Sorted list of matching leaf indices.
    """
    n = len(morton_codes)
    if n == 0:
        return []

    results = []
    # Internal node numbers start at n, and nodes[0] (number n) is the root.
    # A single point has no internal nodes, so the root is leaf 0 itself.
    stack = [n if n > 1 else 0]

    while stack:
        node_idx = stack.pop()

        # Handle leaf node
        if node_idx < n:
            if check_inside(morton_codes[node_idx], query_min, query_max, num_bits):
                results.append(node_idx)
            continue

        node = nodes[node_idx - n]
        if check_intersect(query_min, query_max, node.min_point, node.max_point):
            stack.append(node.left)
            stack.append(node.right)
    return sorted(results)

# Example usage:
query_min = (0,0,0)
query_max = (4,4,4)
result = traverse_tree(nodes, morton_codes, query_min, query_max, 9)
print(f"\nPoints in region [{query_min}, {query_max}]:")
print(f"Indices: {[int(x) for x in result]}")
print(f"Morton codes: {[morton_codes[i] for i in result]}")



Points in region [(0, 0, 0), (4, 4, 4)]:
Indices: [0, 1, 2]
Morton codes: [0, 4, 64]


In [141]:
import random

# Generate random points on a (2**num_bits)^3 integer grid
num_bits = 4
max_coord = (1 << num_bits) - 1

num_points = 5000  # Adjust this number to control how many points to generate before removing duplicates
# random.seed(42)  # Set fixed random seed (uncomment for reproducible runs)
points = [(random.randint(0,max_coord), random.randint(0,max_coord), random.randint(0,max_coord)) for _ in range(num_points)]

# Remove duplicate points via a set (resulting order is arbitrary; sorted below)
unique_points = list(set(tuple(point) for point in points))
print(f"Generated {len(unique_points)} unique 3D grid points:")
# print(unique_points)

morton_codes = [interleave_bits(x, y, z, num_bits) for x, y, z in unique_points]
# print("before sort")
# print(morton_codes)
# print(unique_points)

# Sort morton_codes and unique_points together using zip (build_tree needs ascending codes)
sorted_pairs = sorted(zip(morton_codes, unique_points))
morton_codes, unique_points = zip(*sorted_pairs)
morton_codes = list(morton_codes)
unique_points = list(unique_points)

# print("after sort")
# print(morton_codes)
# print(unique_points)
internal_nodes = build_tree(morton_codes, num_bits * 3)

# Generate random query window within valid range (inclusive corners)
x1, x2 = sorted([random.randint(0, max_coord), random.randint(0, max_coord)])
y1, y2 = sorted([random.randint(0, max_coord), random.randint(0, max_coord)])
z1, z2 = sorted([random.randint(0, max_coord), random.randint(0, max_coord)])
query_min = (x1, y1, z1)
query_max = (x2, y2, z2)
print(f"Query window: [{query_min}, {query_max}]")

result = traverse_tree(internal_nodes, morton_codes, query_min, query_max, num_bits * 3)
# traverse_tree already returns sorted indices; this re-sort is redundant but harmless
result = sorted(result)
print(f"number of points in query window: {len(result)}")
# print(f"Indices: {[int(x) for x in result]}")
# print(f"Morton codes: {[morton_codes[i] for i in result]}")



# Brute-force reference: test every point directly against the query window
bench_result = []
for i in range(len(unique_points)):
    if (check_inside(morton_codes[i], query_min, query_max, num_bits * 3)):
        bench_result.append(i)

# print("\nComparing results:")
# print(f"Brute force result: {bench_result}")
# print(f"Tree traversal result: {result}")

# The tree traversal must return exactly the brute-force set
if sorted(bench_result) == sorted(result):
    print("Results match!")
else:
    print("Results don't match!")
    print("Missing points:", set(bench_result) - set(result))
    print("Extra points:", set(result) - set(bench_result))





Generated 2892 unique 3D grid points:
Query window: [(10, 8, 11), (15, 14, 15)]
number of points in query window: 148
Results match!
