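Paper: Tero Karras, "Maximizing Parallelism in the Construction of BVHs, Octrees, and k-d Trees", High-Performance Graphics 2012. This notebook re-implements the paper's binary radix tree (LBVH) construction over sorted Morton codes and adds a simple range query over the resulting tree.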

https://research.nvidia.com/sites/default/files/pubs/2012-06_Maximizing-Parallelism-in/karras2012hpg_paper.pdf

In [4]:
import numpy as np

def count_leading_zeros(x, bits=32):
    """Return the number of leading zero bits of ``x`` in a ``bits``-wide word.

    Zero is all leading zeros, so ``count_leading_zeros(0) == bits``.
    Values wider than ``bits`` yield a negative count (no clamping is done).
    """
    if x == 0:
        return bits
    return bits - x.bit_length()

def delta(morton_codes, i, j):
    """Length of the common prefix of keys ``i`` and ``j`` (the paper's delta(i, j)).

    Returns -1 as a sentinel when ``j`` lies outside ``[0, len(morton_codes))``;
    ``i`` is assumed to be a valid index.

    When the two keys are equal, the keys are treated as if their index were
    appended to them: the result is ``32 + clz(i ^ j)``. Duplicate codes thus still
    produce distinct prefix lengths. This assumes the codes fit in 32 bits, so that
    the fallback (>= 32) ranks above any real key prefix.
    """
    if not 0 <= j < len(morton_codes):
        return -1  # Sentinel for out-of-range
    key_i = morton_codes[i]
    key_j = morton_codes[j]
    if key_i != key_j:
        return count_leading_zeros(key_i ^ key_j)
    # Same code ⇒ fallback on index difference
    return 32 + count_leading_zeros(i ^ j)

def determine_range(morton_codes, i):
    """Return the inclusive leaf range ``(first, last)`` covered by internal node ``i``.

    Follows the range-determination step of Karras (2012):
    1. Pick the search direction from the neighbour sharing the longer prefix.
    2. Grow an upper bound on the range length exponentially.
    3. Binary-search the exact far end of the range.

    Args:
        morton_codes: Sorted sequence of integer Morton codes. Duplicates are
            allowed; ``delta`` breaks ties by index.
        i: Internal node index, expected in ``[0, len(morton_codes) - 2]``.

    Returns:
        ``(first, last)`` with ``first <= last``; plain Python ints when ``i`` is one.
    """
    # Direction d is +1 if key i shares a longer prefix with key i+1 than with
    # key i-1, -1 if shorter, and 0 only on a tie (same semantics as np.sign).
    # Plain ints are used because np.sign returns a NumPy scalar that leaks into
    # every derived index. delta() then evaluates `i ^ j` on that scalar and calls
    # .bit_length(), which NumPy scalars lack, so inputs with duplicate codes
    # raised AttributeError.
    ahead = delta(morton_codes, i, i + 1)
    behind = delta(morton_codes, i, i - 1)
    d = (ahead > behind) - (ahead < behind)

    # The range's common prefix must be longer than the prefix key i shares with
    # its neighbour on the opposite side.
    delta_min = delta(morton_codes, i, i - d)

    # Exponential search for an upper bound on the range length.
    lmax = 2
    while delta(morton_codes, i, i + lmax * d) > delta_min:
        lmax *= 2

    # Binary search for the exact length l within [0, lmax).
    l = 0
    t = lmax // 2
    while t >= 1:
        if delta(morton_codes, i, i + (l + t) * d) > delta_min:
            l += t
        t //= 2

    j = i + l * d
    return (min(i, j), max(i, j))

def find_split(morton_codes, first, last):
    """Return the split position for the leaf range ``[first, last]``.

    The split is the last index in ``[first, last)`` whose key still shares a
    longer prefix with key ``first`` than the whole range does. The node's children
    are then ``[first, split]`` and ``[split + 1, last]``.

    The position is found with a binary search whose step is halved (rounded up)
    each round, as in Karras (2012). Assumes ``morton_codes`` is sorted and
    ``first < last``.
    """
    range_prefix = delta(morton_codes, first, last)
    split = first
    step = last - first
    while step > 1:
        step = (step + 1) // 2  # ceil(step / 2)
        candidate = split + step
        if candidate >= last:
            continue
        if delta(morton_codes, first, candidate) > range_prefix:
            split = candidate
    return split

class Node:
    """Internal node of the binary radix tree produced by ``build_tree``.

    Attributes:
        index: Global node id. ``build_tree`` stores internal node k as ``k + n``,
            where ``n`` is the number of leaves.
        range: Inclusive ``(first, last)`` leaf range covered by the node.
        split: Last leaf index belonging to the left child.
        left, right: Global ids of the children; ids below ``n`` denote leaves.
    """

    def __init__(self, index, range_tuple, split, left, right):
        self.index, self.range, self.split = index, range_tuple, split
        self.left, self.right = left, right

def build_tree(morton_codes):
    """Build the ``n - 1`` internal nodes of a binary radix tree over sorted keys.

    Node ids: leaves keep their positions ``0..n-1``; internal node ``i`` gets
    global id ``i + n`` and is stored at ``nodes[i]``. Each node's range and split
    are computed independently of the others (the property the paper exploits for
    parallel construction).

    Returns:
        List of ``Node`` objects, one per internal node (empty when ``n < 2``).
    """
    n = len(morton_codes)

    def child_id(position, is_leaf):
        # A child that covers a single key is that leaf; otherwise it is the
        # internal node with the same position, offset into the internal id space.
        return position if is_leaf else position + n

    nodes = []
    for i in range(n - 1):
        first, last = determine_range(morton_codes, i)
        split = find_split(morton_codes, first, last)
        nodes.append(Node(
            index=i + n,
            range_tuple=(first, last),
            split=split,
            left=child_id(split, split == first),
            right=child_id(split + 1, split + 1 == last),
        ))
    return nodes


In [7]:
# Example keys from the paper (Morton codes, already sorted ascending).
morton_codes = [1, 2, 4, 5, 19, 24, 25, 30]
n = len(morton_codes)
nodes = build_tree(morton_codes)

# Internal nodes are listed by global id (leaves occupy 0..n-1), with the
# 0-based internal index in parentheses.
print("LBVH Internal Nodes:")
for node in nodes:
    print(
        f"Node {node.index} ({node.index - n}): range=({node.range[0]}, {node.range[1]}), "
        f"split={node.split}, left={node.left}, right={node.right}"
    )

LBVH Internal Nodes:
Node 8 (0): range=(0, 7), split=3, left=11, right=12
Node 9 (1): range=(0, 1), split=0, left=0, right=1
Node 10 (2): range=(2, 3), split=2, left=2, right=3
Node 11 (3): range=(0, 3), split=1, left=9, right=10
Node 12 (4): range=(4, 7), split=4, left=4, right=13
Node 13 (5): range=(5, 7), split=6, left=14, right=7
Node 14 (6): range=(5, 6), split=5, left=5, right=6


In [8]:
def traverse_tree_iterative(nodes, morton_codes, query_min, query_max):
    """Return sorted leaf indices whose Morton code lies in ``[query_min, query_max]``.

    Walks the tree depth-first with an explicit stack, using the node-id layout of
    ``build_tree``:
    - ids ``< n`` are leaves (positions in ``morton_codes``);
    - ids ``>= n`` are internal nodes stored at ``nodes[id - n]``.

    Because the codes are sorted, an internal node's keys span
    ``[morton_codes[range[0]], morton_codes[range[1]]]``, so subtrees whose span
    misses the query interval are pruned.

    This filters on the scalar Morton key only; it does not test spatial bounding
    boxes.

    Args:
        nodes: Internal nodes as returned by ``build_tree`` (need ``range``,
            ``left`` and ``right``).
        morton_codes: The sorted keys the tree was built from.
        query_min, query_max: Inclusive bounds of the key interval.
    """
    n = len(morton_codes)
    if n == 0:
        return []
    # With a single key there are no internal nodes and the root is leaf 0.
    # Otherwise the root is internal node 0, i.e. global id n.
    root = 0 if n == 1 else n

    results = []
    stack = [root]
    while stack:
        node_idx = stack.pop()

        if node_idx < n:  # leaf
            if query_min <= morton_codes[node_idx] <= query_max:
                results.append(node_idx)
            continue

        node = nodes[node_idx - n]
        first, last = node.range
        # Prune the whole subtree when its key span misses the query interval.
        if query_min > morton_codes[last] or query_max < morton_codes[first]:
            continue
        stack.append(node.left)
        stack.append(node.right)
    return sorted(results)

# Example usage: collect every leaf whose Morton code falls in [query_min, query_max].
query_min, query_max = 4, 19
result = traverse_tree_iterative(nodes, morton_codes, query_min, query_max)

print(f"\nPoints in region [{query_min}, {query_max}]:")
print(f"Indices: {result}")
print(f"Morton codes: {[morton_codes[idx] for idx in result]}")



Points in region [4, 19]:
Indices: [2, 3, 4]
Morton codes: [4, 5, 19]
