# Chapter 9: Acceleration Structures

## Speeding Up Ray Tracing

This notebook covers:
- Bounding Volume Hierarchies (BVH)
- Axis-Aligned Bounding Boxes (AABB)
- BVH construction algorithms
- Ray-AABB intersection
- Performance analysis

**Key References:** Marschner & Shirley Ch. 12, Physically Based Rendering Ch. 4

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/computer-vision/blob/main/chapter_09_acceleration_structures.ipynb)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import math
import time
from typing import Tuple, List, Optional
from dataclasses import dataclass

print("✓ Imports loaded")

---

## 1. Acceleration Structure Theory

### 1.1 The Performance Problem

**Naive ray tracing** tests every ray against every object:
- $R$ rays × $N$ objects = $O(R \cdot N)$ intersection tests
- For a 1920×1080 image (about 2M primary rays) with 1000 objects: $2\text{M} \times 1000 \approx 2$ billion tests!

**Goal:** Reduce intersection tests from $O(N)$ to $O(\log N)$ per ray.
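The arithmetic behind these numbers can be checked directly. This is a rough sketch that assumes one primary ray per pixel and an idealized, balanced binary BVH where a ray descends about $\log_2 N$ levels. Real traversals often visit several boxes per level, so treat the BVH figure as a lower bound, not a prediction.

```python
import math

width, height, num_objects = 1920, 1080, 1000
num_rays = width * height               # one primary ray per pixel
naive_tests = num_rays * num_objects    # every ray against every object

# Idealized estimate: a balanced binary BVH over N objects is about log2(N) levels deep
bvh_depth = math.ceil(math.log2(num_objects))

print(f"{num_rays:,} rays, {naive_tests:,} naive tests")
print(f"~{bvh_depth} levels per ray in a balanced binary BVH")
```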

### 1.2 Axis-Aligned Bounding Box (AABB)

An **AABB** is a box whose faces are aligned with the coordinate axes; it is fully defined by its minimum and maximum corners:

$$\text{AABB} = [x_{\min}, x_{\max}] \times [y_{\min}, y_{\max}] \times [z_{\min}, z_{\max}]$$

**Ray-AABB intersection** (slab method):

For ray $\mathbf{r}(t) = \mathbf{o} + t\mathbf{d}$:

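$$t_{x,0} = \frac{x_{\min} - o_x}{d_x}, \quad t_{x,1} = \frac{x_{\max} - o_x}{d_x}$$
$$t_{y,0} = \frac{y_{\min} - o_y}{d_y}, \quad t_{y,1} = \frac{y_{\max} - o_y}{d_y}$$
$$t_{z,0} = \frac{z_{\min} - o_z}{d_z}, \quad t_{z,1} = \frac{z_{\max} - o_z}{d_z}$$

If $d_a < 0$ for an axis $a \in \{x, y, z\}$, the ray reaches the max plane before the min plane, so each pair must be ordered:

$$t_{a,\text{near}} = \min(t_{a,0}, t_{a,1}), \quad t_{a,\text{far}} = \max(t_{a,0}, t_{a,1})$$

If $d_a = 0$, the ray is parallel to that slab and can only hit the box if $a_{\min} \leq o_a \leq a_{\max}$.

**Intersection test:**
$$t_{\text{enter}} = \max(t_{x,\text{near}}, t_{y,\text{near}}, t_{z,\text{near}})$$
$$t_{\text{exit}} = \min(t_{x,\text{far}}, t_{y,\text{far}}, t_{z,\text{far}})$$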

Ray hits box if $t_{\text{enter}} \leq t_{\text{exit}}$ and $t_{\text{exit}} > 0$.
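A minimal standalone sketch of the slab test, assuming points and directions are plain 3-tuples (the function name `ray_aabb_slab` is our own, not part of this notebook). It shows the near/far swap for negative direction components and the explicit parallel-ray case.

```python
import math
from typing import Optional, Sequence, Tuple

def ray_aabb_slab(origin: Sequence[float], direction: Sequence[float],
                  box_min: Sequence[float], box_max: Sequence[float]
                  ) -> Optional[Tuple[float, float]]:
    """Return (t_enter, t_exit) if the ray hits the box, else None."""
    t_enter, t_exit = -math.inf, math.inf
    for o, d, lo, hi in zip(origin, direction, box_min, box_max):
        if d == 0.0:
            # Parallel to this slab: the origin must lie between the two planes
            if o < lo or o > hi:
                return None
            continue
        t_near = (lo - o) / d
        t_far = (hi - o) / d
        if t_near > t_far:      # negative direction: the max plane is reached first
            t_near, t_far = t_far, t_near
        t_enter = max(t_enter, t_near)
        t_exit = min(t_exit, t_far)
    # t_enter < 0 means the origin is inside the box
    if t_enter <= t_exit and t_exit > 0:
        return (t_enter, t_exit)
    return None

print(ray_aabb_slab((0, 0, 0), (0, 0, -1), (-1, -1, -5), (1, 1, -3)))  # (3.0, 5.0)
```

Without the swap, the $-z$ ray above would compute $t_{z,0} = 5$ and $t_{z,1} = 3$ and incorrectly report a miss.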

### 1.3 Bounding Volume Hierarchy (BVH)

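A **BVH** is a tree in which every node stores an AABB enclosing all primitives beneath it:
- **Leaf nodes** store a small list of primitives (triangles, spheres)
- **Internal nodes** store pointers to their children; their AABB bounds both children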

**Ray traversal:**
```
intersect(ray, node):
  if not ray.intersects(node.bbox):
    return miss

  if node.is_leaf:
    return closest hit among the primitives in the leaf
  else:
    hit_left  = intersect(ray, node.left)
    hit_right = intersect(ray, node.right)
    return the closer of hit_left and hit_right
```

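**Complexity:** roughly $O(\log N)$ per ray on average for a well-balanced tree; in the worst case (a ray overlapping most boxes) traversal still visits $O(N)$ nodes.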
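The recursive traversal above (and the `BVH.intersect` method later in this notebook) visits every child whose box the ray hits. A common refinement is to remember the closest hit so far and skip any box whose entry distance is already farther. The sketch below assumes spheres only, plain tuples, and an explicit stack; `Node`, `box_entry`, `hit_sphere`, and `traverse` are hypothetical names, not the notebook's classes.

```python
import math
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

Vec = Tuple[float, float, float]

@dataclass
class Node:
    lo: Vec                                   # AABB min corner
    hi: Vec                                   # AABB max corner
    left: Optional["Node"] = None
    right: Optional["Node"] = None
    spheres: List[Tuple[Vec, float]] = field(default_factory=list)  # (center, radius) in leaves

def box_entry(o: Vec, d: Vec, lo: Vec, hi: Vec, t_max: float) -> Optional[float]:
    """Slab test clipped to [0, t_max]; returns the entry distance or None on a miss."""
    t0, t1 = 0.0, t_max
    for i in range(3):
        if d[i] == 0.0:
            if not lo[i] <= o[i] <= hi[i]:
                return None
            continue
        a, b = (lo[i] - o[i]) / d[i], (hi[i] - o[i]) / d[i]
        t0, t1 = max(t0, min(a, b)), min(t1, max(a, b))
    return t0 if t0 <= t1 else None

def hit_sphere(o: Vec, d: Vec, c: Vec, r: float) -> Optional[float]:
    """Nearest ray-sphere hit beyond 1e-3; assumes d has unit length."""
    oc = [o[i] - c[i] for i in range(3)]
    b = sum(oc[i] * d[i] for i in range(3))
    disc = b * b - (sum(x * x for x in oc) - r * r)
    if disc < 0:
        return None
    s = math.sqrt(disc)
    for t in (-b - s, -b + s):
        if t > 1e-3:
            return t
    return None

def traverse(root: Node, o: Vec, d: Vec) -> Tuple[Optional[float], int]:
    """Stack-based closest-hit traversal; returns (t, number of sphere tests)."""
    closest, sphere_tests = math.inf, 0
    stack = [root]
    while stack:
        node = stack.pop()
        # Cull boxes the ray misses or that start beyond the closest hit so far
        if box_entry(o, d, node.lo, node.hi, closest) is None:
            continue
        if node.left is None and node.right is None:
            for c, r in node.spheres:
                sphere_tests += 1
                t = hit_sphere(o, d, c, r)
                if t is not None and t < closest:
                    closest = t
        else:
            # The right child is pushed last, so it is popped (visited) first
            stack.extend(n for n in (node.left, node.right) if n is not None)
    return (closest if closest < math.inf else None), sphere_tests

# Two leaves along -z; the nearer one is the right child, so it is visited first
near = Node((-1.0, -1.0, -6.0), (1.0, 1.0, -4.0), spheres=[((0.0, 0.0, -5.0), 1.0)])
far = Node((-1.0, -1.0, -11.0), (1.0, 1.0, -9.0), spheres=[((0.0, 0.0, -10.0), 1.0)])
root = Node((-1.0, -1.0, -11.0), (1.0, 1.0, -4.0), left=far, right=near)
print(traverse(root, (0.0, 0.0, 0.0), (0.0, 0.0, -1.0)))  # (4.0, 1): far leaf culled
```

Pruning works best when the nearer child is visited first; here that order is fixed by construction, whereas practical implementations often sort the two children by their box entry distance before pushing them.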

### 1.4 BVH Construction

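**SAH (Surface Area Heuristic)** chooses the split that minimizes the expected traversal cost. For uniformly distributed rays that hit the parent box, the probability of also hitting a child box is approximately the ratio of their surface areas, which gives: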

$$C = C_{\text{trav}} + \frac{SA_L}{SA_P} N_L C_{\text{isect}} + \frac{SA_R}{SA_P} N_R C_{\text{isect}}$$

where:
- $SA_P$ = surface area of parent box
- $SA_L, SA_R$ = surface areas of left/right child boxes
- $N_L, N_R$ = number of primitives in left/right
- $C_{\text{trav}}$ = cost of traversal (typically 1)
- $C_{\text{isect}}$ = cost of intersection (typically 1-10)
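A worked example of the SAH formula, assuming $C_{\text{trav}} = C_{\text{isect}} = 1$ and children whose boxes exactly fill the two parts of the parent (a simplification; real child boxes shrink to fit their primitives). The function names are illustrative only.

```python
def box_surface_area(dx: float, dy: float, dz: float) -> float:
    """Surface area of a box with side lengths dx, dy, dz."""
    return 2.0 * (dx * dy + dy * dz + dz * dx)

def sah_cost(sa_parent, sa_left, sa_right, n_left, n_right, c_trav=1.0, c_isect=1.0):
    """SAH estimate of the cost of splitting a node into two children."""
    return (c_trav
            + (sa_left / sa_parent) * n_left * c_isect
            + (sa_right / sa_parent) * n_right * c_isect)

# Parent: a 2 x 1 x 1 box holding 8 primitives
sa_p = box_surface_area(2, 1, 1)                      # 10.0
# Split A: two 1 x 1 x 1 halves with 4 primitives each
cost_a = sah_cost(sa_p, box_surface_area(1, 1, 1), box_surface_area(1, 1, 1), 4, 4)
# Split B: a 1.5-wide box with 7 primitives and a 0.5-wide box with 1
cost_b = sah_cost(sa_p, box_surface_area(1.5, 1, 1), box_surface_area(0.5, 1, 1), 7, 1)
leaf_cost = 8 * 1.0                                   # no split: test all 8 primitives
print(cost_a, cost_b, leaf_cost)                      # approx. 5.8, 7.0, 8.0
```

Split A wins: $1 + 0.6 \cdot 4 + 0.6 \cdot 4 = 5.8$, versus $1 + 0.8 \cdot 7 + 0.4 \cdot 1 = 7.0$ for split B, and both beat keeping all 8 primitives in one leaf. An SAH builder evaluates many candidate splits this way and stops splitting when no candidate beats the leaf cost.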

**Simpler approach:** split along the longest axis of the node's box at the median primitive centroid (the method implemented in this notebook).
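The median split can be sketched on an array of centroids. This version assumes centroids are stored as an `(n, 3)` NumPy array and uses `np.argpartition`, which places the median in its sorted position in linear average time instead of fully sorting as the notebook's `BVH._build` does. `median_split` is a hypothetical helper name.

```python
import numpy as np

def median_split(centroids: np.ndarray, bbox_min, bbox_max):
    """Split primitive indices at the median centroid along the box's longest axis."""
    extent = np.asarray(bbox_max, dtype=float) - np.asarray(bbox_min, dtype=float)
    axis = int(np.argmax(extent))            # longest axis: 0=x, 1=y, 2=z
    mid = len(centroids) // 2
    # After argpartition, position `mid` holds the mid-th smallest value; indices
    # before it have values <= it and indices after it have values >= it
    order = np.argpartition(centroids[:, axis], mid)
    return axis, order[:mid], order[mid:]

centroids = np.array([[0.0, 0.0, 0.0], [4.0, 0.0, 0.0],
                      [1.0, 1.0, 0.0], [3.0, 0.0, 1.0]])
axis, left, right = median_split(centroids, (0, 0, 0), (4, 1, 1))
print(axis, sorted(left.tolist()), sorted(right.tolist()))  # 0 [0, 2] [1, 3]
```

Because the split always puts half of the primitives on each side, the tree depth stays near $\log_2 N$ regardless of how the primitives are distributed, though the resulting boxes may overlap more than an SAH split would allow.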

### 1.5 Other Acceleration Structures

**KD-Tree:**
- Space partitioning (divides space, not objects)
- Axis-aligned splits
- More complex to build and traverse

**Grid:**
- Uniform subdivision
- Simple but inefficient for non-uniform scenes

**Octree:**
- Recursive spatial subdivision
- Good for sparse scenes

**BVH advantages:**
- ✅ Easy to implement
- ✅ Fast to build
- ✅ Good average-case performance
- ✅ Works well with dynamic scenes (the tree can be refit or rebuilt when objects move)
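To see why a uniform grid struggles with non-uniform scenes, the sketch below bins point primitives into a $res^3$ grid and counts how many cells are actually used. It assumes points stand in for primitive centroids; `grid_cell_counts` is a hypothetical helper, and the scene is synthetic (a dense cluster plus one far outlier, the classic "teapot in a stadium" layout).

```python
import numpy as np

def grid_cell_counts(points, bbox_min, bbox_max, res):
    """Count how many points fall in each cell of a res x res x res uniform grid."""
    pts = np.asarray(points, dtype=float)
    lo = np.asarray(bbox_min, dtype=float)
    hi = np.asarray(bbox_max, dtype=float)
    # Map each point to integer cell coordinates; clip so points on the max face stay inside
    cells = np.clip(np.floor((pts - lo) / (hi - lo) * res).astype(int), 0, res - 1)
    flat = np.ravel_multi_index(tuple(cells.T), (res, res, res))
    return np.bincount(flat, minlength=res ** 3)

# 999 points clustered near the origin plus one far outlier
rng = np.random.default_rng(0)
points = np.vstack([rng.normal(0.0, 1.0, size=(999, 3)), [[99.0, 99.0, 99.0]]])
counts = grid_cell_counts(points, (-100, -100, -100), (100, 100, 100), res=8)
print(counts.size, np.count_nonzero(counts), counts.max())
```

The outlier stretches the grid to cover a huge volume, so almost all 512 cells are empty while a handful near the origin hold nearly every primitive; a ray through the cluster gains little. A BVH adapts to the cluster because its boxes follow the primitives rather than a fixed subdivision of space.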

In [None]:
class Vec3:
    """3D Vector class"""
    def __init__(self, x=0.0, y=0.0, z=0.0):
        self.x = float(x)
        self.y = float(y)
        self.z = float(z)
    
    def __add__(self, other):
        return Vec3(self.x + other.x, self.y + other.y, self.z + other.z)
    
    def __sub__(self, other):
        return Vec3(self.x - other.x, self.y - other.y, self.z - other.z)
    
    def __mul__(self, scalar):
        if isinstance(scalar, Vec3):
            return Vec3(self.x * scalar.x, self.y * scalar.y, self.z * scalar.z)
        return Vec3(self.x * scalar, self.y * scalar, self.z * scalar)
    
    def __rmul__(self, scalar):
        return self.__mul__(scalar)
    
    def __truediv__(self, scalar):
        return Vec3(self.x / scalar, self.y / scalar, self.z / scalar)
    
    def __neg__(self):
        return Vec3(-self.x, -self.y, -self.z)
    
    def dot(self, other):
        return self.x * other.x + self.y * other.y + self.z * other.z
    
    def length(self):
        return math.sqrt(self.x**2 + self.y**2 + self.z**2)
    
    def normalize(self):
        l = self.length()
        return self / l if l > 0 else Vec3(0, 0, 0)
    
    def __repr__(self):
        return f"Vec3({self.x:.2f}, {self.y:.2f}, {self.z:.2f})"

class Ray:
    """Ray with origin and direction"""
    def __init__(self, origin: Vec3, direction: Vec3):
        self.origin = origin
        self.direction = direction.normalize()
    
    def at(self, t: float) -> Vec3:
        return self.origin + self.direction * t

class AABB:
    """Axis-Aligned Bounding Box"""
    def __init__(self, min_point: Vec3 = None, max_point: Vec3 = None):
        if min_point is None:
            self.min = Vec3(float('inf'), float('inf'), float('inf'))
        else:
            self.min = min_point
        
        if max_point is None:
            self.max = Vec3(float('-inf'), float('-inf'), float('-inf'))
        else:
            self.max = max_point
    
    def expand(self, point: Vec3):
        """Expand box to include point"""
        self.min.x = min(self.min.x, point.x)
        self.min.y = min(self.min.y, point.y)
        self.min.z = min(self.min.z, point.z)
        self.max.x = max(self.max.x, point.x)
        self.max.y = max(self.max.y, point.y)
        self.max.z = max(self.max.z, point.z)
    
    def merge(self, other: 'AABB') -> 'AABB':
        """Merge with another AABB"""
        return AABB(
            Vec3(min(self.min.x, other.min.x),
                 min(self.min.y, other.min.y),
                 min(self.min.z, other.min.z)),
            Vec3(max(self.max.x, other.max.x),
                 max(self.max.y, other.max.y),
                 max(self.max.z, other.max.z))
        )
    
    def surface_area(self) -> float:
        """Compute surface area"""
        d = self.max - self.min
        return 2.0 * (d.x * d.y + d.y * d.z + d.z * d.x)
    
    def centroid(self) -> Vec3:
        """Return center point"""
        return (self.min + self.max) * 0.5
    
    def longest_axis(self) -> int:
        """Return index of longest axis (0=x, 1=y, 2=z)"""
        d = self.max - self.min
        if d.x > d.y and d.x > d.z:
            return 0
        elif d.y > d.z:
            return 1
        else:
            return 2
    
    def intersect(self, ray: Ray) -> Optional[Tuple[float, float]]:
        """Ray-AABB intersection using slab method"""
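        # Compute intersection t values for each slab.
        # A zero direction component maps to +inf: the products below then become
        # -inf/+inf when the origin lies between the slab planes (that slab never
        # limits t) and same-signed infinities when it lies outside (a guaranteed miss).
        # An origin exactly on a slab plane gives 0 * inf = nan, which is not handled here.
        inv_dir = Vec3(
            1.0 / ray.direction.x if ray.direction.x != 0 else float('inf'),
            1.0 / ray.direction.y if ray.direction.y != 0 else float('inf'),
            1.0 / ray.direction.z if ray.direction.z != 0 else float('inf')
        )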
        
        t1 = (self.min.x - ray.origin.x) * inv_dir.x
        t2 = (self.max.x - ray.origin.x) * inv_dir.x
        t_min = min(t1, t2)
        t_max = max(t1, t2)
        
        t1 = (self.min.y - ray.origin.y) * inv_dir.y
        t2 = (self.max.y - ray.origin.y) * inv_dir.y
        t_min = max(t_min, min(t1, t2))
        t_max = min(t_max, max(t1, t2))
        
        t1 = (self.min.z - ray.origin.z) * inv_dir.z
        t2 = (self.max.z - ray.origin.z) * inv_dir.z
        t_min = max(t_min, min(t1, t2))
        t_max = min(t_max, max(t1, t2))
        
        if t_max >= t_min and t_max > 0:
            return (t_min, t_max)
        return None
    
    def __repr__(self):
        return f"AABB({self.min}, {self.max})"

print("✓ Core classes loaded")

In [None]:
class Sphere:
    """Sphere primitive"""
    def __init__(self, center: Vec3, radius: float, color: Vec3 = None):
        self.center = center
        self.radius = radius
        self.color = color if color else Vec3(1, 1, 1)
    
    def get_bbox(self) -> AABB:
        """Get bounding box"""
        r = Vec3(self.radius, self.radius, self.radius)
        return AABB(self.center - r, self.center + r)
    
    def intersect(self, ray: Ray) -> Optional[float]:
        """Ray-sphere intersection"""
        oc = ray.origin - self.center
        a = ray.direction.dot(ray.direction)
        b = 2.0 * oc.dot(ray.direction)
        c = oc.dot(oc) - self.radius * self.radius
        discriminant = b * b - 4 * a * c
        
        if discriminant < 0:
            return None
        
        t = (-b - math.sqrt(discriminant)) / (2 * a)
        if t < 0.001:
            t = (-b + math.sqrt(discriminant)) / (2 * a)
            if t < 0.001:
                return None
        
        return t

@dataclass
class BVHNode:
    """BVH tree node"""
    bbox: AABB
    left: Optional['BVHNode'] = None
    right: Optional['BVHNode'] = None
    primitives: Optional[List[Sphere]] = None  # set only for leaf nodes
    
    def is_leaf(self) -> bool:
        return self.primitives is not None

class BVH:
    """Bounding Volume Hierarchy"""
    def __init__(self, primitives: List[Sphere], max_leaf_size: int = 4):
        self.max_leaf_size = max_leaf_size
        self.intersection_tests = 0  # For performance analysis
        self.root = self._build(primitives)
    
    def _build(self, primitives: List[Sphere]) -> BVHNode:
        """Build BVH recursively (note: sorts `primitives` in place)"""
        # Compute bounding box for all primitives
        bbox = AABB()
        for prim in primitives:
            prim_bbox = prim.get_bbox()
            bbox = bbox.merge(prim_bbox)
        
        # Create leaf if few primitives
        if len(primitives) <= self.max_leaf_size:
            return BVHNode(bbox=bbox, primitives=primitives)
        
        # Split along longest axis at median
        axis = bbox.longest_axis()
        
        # Sort primitives by centroid along chosen axis
        if axis == 0:
            primitives.sort(key=lambda p: p.center.x)
        elif axis == 1:
            primitives.sort(key=lambda p: p.center.y)
        else:
            primitives.sort(key=lambda p: p.center.z)
        
        # Split at median
        mid = len(primitives) // 2
        left_prims = primitives[:mid]
        right_prims = primitives[mid:]
        
        # Recursively build children
        left = self._build(left_prims) if left_prims else None
        right = self._build(right_prims) if right_prims else None
        
        return BVHNode(bbox=bbox, left=left, right=right)
    
    def intersect(self, ray: Ray, node: BVHNode = None) -> Optional[Tuple[float, Sphere]]:
        """Find nearest intersection with ray"""
        if node is None:
            node = self.root
            self.intersection_tests = 0
        
        # Test against bounding box
        self.intersection_tests += 1
        if not node.bbox.intersect(ray):
            return None
        
        # Leaf node: test primitives
        if node.is_leaf():
            closest_t = float('inf')
            closest_prim = None
            
            for prim in node.primitives:
                self.intersection_tests += 1
                t = prim.intersect(ray)
                if t is not None and t < closest_t:
                    closest_t = t
                    closest_prim = prim
            
            if closest_prim is not None:
                return (closest_t, closest_prim)
            return None
        
        # Internal node: traverse children
        left_hit = self.intersect(ray, node.left) if node.left else None
        right_hit = self.intersect(ray, node.right) if node.right else None
        
        # Return closest hit
        if left_hit and right_hit:
            return left_hit if left_hit[0] < right_hit[0] else right_hit
        elif left_hit:
            return left_hit
        else:
            return right_hit
    
    def get_stats(self, node: BVHNode = None, depth: int = 0) -> dict:
        """Get tree statistics"""
        if node is None:
            node = self.root
        
        if node.is_leaf():
            return {
                'max_depth': depth,
                'num_leaves': 1,
                'num_nodes': 1,
                'total_prims': len(node.primitives)
            }
        
        left_stats = self.get_stats(node.left, depth + 1) if node.left else {}
        right_stats = self.get_stats(node.right, depth + 1) if node.right else {}
        
        return {
            'max_depth': max(left_stats.get('max_depth', 0), right_stats.get('max_depth', 0)),
            'num_leaves': left_stats.get('num_leaves', 0) + right_stats.get('num_leaves', 0),
            'num_nodes': 1 + left_stats.get('num_nodes', 0) + right_stats.get('num_nodes', 0),
            'total_prims': left_stats.get('total_prims', 0) + right_stats.get('total_prims', 0)
        }

print("✓ BVH implementation loaded")

In [None]:
# Example 1: AABB Intersection Test
print("Example 1: AABB Ray Intersection\n")

# Create bounding box
bbox = AABB(Vec3(-1, -1, -5), Vec3(1, 1, -3))

# Test rays
test_rays = [
    (Ray(Vec3(0, 0, 0), Vec3(0, 0, -1)), "Hit (center)"),
    (Ray(Vec3(0, 0, 0), Vec3(0.1, 0, -1)), "Hit (offset)"),
    (Ray(Vec3(0, 0, 0), Vec3(2, 0, -1)), "Miss (too far right)"),
    (Ray(Vec3(0, 0, 0), Vec3(0, 0, 1)), "Miss (wrong direction)"),
]

print(f"Testing ray-AABB intersection:")
print(f"Box: {bbox}\n")

for ray, desc in test_rays:
    result = bbox.intersect(ray)
    if result:
        t_min, t_max = result
        print(f"{desc:25s}: HIT (t_min={t_min:.2f}, t_max={t_max:.2f})")
    else:
        print(f"{desc:25s}: MISS")

print("\n" + "="*60 + "\n")

In [None]:
# Example 2: BVH Construction and Statistics
print("Example 2: BVH Construction\n")

# Create scene with many spheres
np.random.seed(42)
num_spheres = 100
spheres = []

for i in range(num_spheres):
    center = Vec3(
        np.random.uniform(-10, 10),
        np.random.uniform(-10, 10),
        np.random.uniform(-20, -5)
    )
    radius = np.random.uniform(0.2, 0.8)
    color = Vec3(np.random.random(), np.random.random(), np.random.random())
    spheres.append(Sphere(center, radius, color))

# Build BVH
print(f"Building BVH for {num_spheres} spheres...")
start_time = time.perf_counter()
bvh = BVH(spheres, max_leaf_size=4)
build_time = time.perf_counter() - start_time

# Get statistics
stats = bvh.get_stats()

print(f"\nBVH Statistics:")
print(f"  Build time: {build_time*1000:.2f} ms")
print(f"  Total nodes: {stats['num_nodes']}")
print(f"  Leaf nodes: {stats['num_leaves']}")
print(f"  Max depth: {stats['max_depth']}")
print(f"  Avg primitives per leaf: {stats['total_prims'] / stats['num_leaves']:.1f}")

print("\n" + "="*60 + "\n")

In [None]:
# Example 3: Performance Comparison (BVH vs Naive)
print("Example 3: Performance Comparison\n")

# Test with different scene sizes
scene_sizes = [10, 50, 100, 500]
num_test_rays = 1000

print(f"Testing {num_test_rays} rays against scenes of varying sizes:\n")
print(f"{'Spheres':<10} {'Naive (ms)':<15} {'BVH (ms)':<15} {'Speedup':<10} {'BVH tests/ray'}")
print("-" * 70)

for num_spheres in scene_sizes:
    # Create scene
    np.random.seed(42)
    spheres = []
    for i in range(num_spheres):
        center = Vec3(
            np.random.uniform(-10, 10),
            np.random.uniform(-10, 10),
            np.random.uniform(-20, -5)
        )
        radius = np.random.uniform(0.2, 0.5)
        spheres.append(Sphere(center, radius))
    
    # Build BVH
    bvh = BVH(spheres.copy())
    
    # Generate test rays
    test_rays = []
    for _ in range(num_test_rays):
        origin = Vec3(0, 0, 0)
        direction = Vec3(
            np.random.uniform(-1, 1),
            np.random.uniform(-1, 1),
            -1
        )
        test_rays.append(Ray(origin, direction))
    
    # Naive intersection test
    start_time = time.perf_counter()
    naive_tests = 0
    for ray in test_rays:
        for sphere in spheres:
            naive_tests += 1
            sphere.intersect(ray)
    naive_time = (time.perf_counter() - start_time) * 1000
    
    # BVH intersection test
    start_time = time.perf_counter()
    bvh_total_tests = 0
    for ray in test_rays:
        bvh.intersect(ray)
        bvh_total_tests += bvh.intersection_tests  # box tests + sphere tests
    bvh_time = (time.perf_counter() - start_time) * 1000
    
    speedup = naive_time / bvh_time if bvh_time > 0 else 0
    avg_tests = bvh_total_tests / num_test_rays
    
    print(f"{num_spheres:<10} {naive_time:<15.2f} {bvh_time:<15.2f} {f'{speedup:.1f}x':<10} {avg_tests:.1f}")

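print("\nNote: 'BVH tests/ray' counts box tests plus sphere tests; naive uses N sphere tests per ray.")
print("Expect the speedup to grow with the number of spheres.")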
print("\n" + "="*60 + "\n")

In [None]:
# Example 4: Simple Ray Tracer with BVH
print("Example 4: Ray Traced Image with BVH\n")

# Create scene
np.random.seed(123)
spheres = []

# Add some spheres
for i in range(30):
    center = Vec3(
        np.random.uniform(-5, 5),
        np.random.uniform(-3, 3),
        np.random.uniform(-15, -8)
    )
    radius = np.random.uniform(0.3, 0.7)
    color = Vec3(np.random.random(), np.random.random(), np.random.random())
    spheres.append(Sphere(center, radius, color))

# Build BVH
bvh = BVH(spheres)

# Render
width, height = 400, 300
image = np.zeros((height, width, 3))

camera_pos = Vec3(0, 0, 0)
fov = 60
aspect = width / height

print("Rendering with BVH...")
start_time = time.time()
total_tests = 0

for j in range(height):
    for i in range(width):
        # Generate ray
        u = (i / width - 0.5) * 2 * aspect * math.tan(math.radians(fov) / 2)
        v = (0.5 - j / height) * 2 * math.tan(math.radians(fov) / 2)
        
        ray = Ray(camera_pos, Vec3(u, v, -1))
        
        # Intersect with BVH
        hit = bvh.intersect(ray)
        total_tests += bvh.intersection_tests
        
        if hit:
            t, sphere = hit
            image[j, i] = [sphere.color.x, sphere.color.y, sphere.color.z]
        else:
            # Background gradient
            t = 0.5 * (ray.direction.y + 1.0)
            image[j, i] = [(1.0-t)*1.0 + t*0.5, (1.0-t)*1.0 + t*0.7, (1.0-t)*1.0 + t*1.0]

render_time = time.time() - start_time

print(f"Render complete in {render_time:.2f}s")
print(f"Avg BVH tests per ray (box + sphere): {total_tests / (width * height):.1f}")
print(f"(Naive needs {len(spheres)} sphere tests per ray)")

plt.figure(figsize=(12, 9))
plt.imshow(image)
plt.title(f"Ray Traced Scene ({len(spheres)} spheres, BVH accelerated)")
plt.axis('off')
plt.tight_layout()
plt.show()

---

## Summary

In this chapter, you implemented:

✅ **Axis-Aligned Bounding Boxes** - Fast ray-box intersection test  
✅ **Bounding Volume Hierarchy** - Binary tree acceleration structure  
✅ **BVH Construction** - Median split along longest axis  
✅ **BVH Traversal** - Recursive ray-tree intersection  
✅ **Performance Analysis** - Comparing naive vs accelerated rendering  

**Key Insights:**
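- BVH reduces per-ray cost from O(N) to roughly O(log N) on average (poorly balanced trees or heavily overlapping boxes push it back toward O(N))
- A ray-AABB test needs only a few subtractions, multiplications, and comparisons, so it is typically cheaper than intersecting complex primitives
- Speedup grows with the number of objects in the scene
- A well-built BVH often cuts tests per ray by 10-100× in complex scenes
- Median split is simple and reasonably effective; SAH-based splits usually produce faster trees
- Tree depth and balance affect performance: deeper or lopsided trees mean more box tests per ray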

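**Performance Impact** (rough orders of magnitude; actual speedups depend on hardware, scene layout, and this pure-Python implementation, so compare with your Example 3 results):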
- 100 objects: ~10× faster
- 500 objects: ~50× faster  
- 10,000 objects: ~500× faster

**Next Chapter:** Distribution Ray Tracing