In [None]:
import numpy as np
from PIL import Image

# Define constants
IMAGE_WIDTH = 1200
IMAGE_HEIGHT = 600
SAMPLES_PER_PIXEL = 100

# Define the sphere parameters
SPHERE_CENTER = np.array([0.0, 0.0, -5.0])
SPHERE_RADIUS = 1.0

def ray_trace(ray_origin, ray_direction):
    # Calculate intersection point with sphere
    a = np.dot(ray_direction, ray_direction)
    b = 2 * np.dot(ray_direction, (ray_origin - SPHERE_CENTER))
    c = np.sum((ray_origin - SPHERE_CENTER) ** 2) - SPHERE_RADIUS ** 2

    discr = b ** 2 - 4 * a * c
    if discr < 0:
        return None

    t1 = (-b + np.sqrt(discr)) / (2 * a)
    t2 = (-b - np.sqrt(discr)) / (2 * a)

    # Select the closest intersection point
    t = min(t1, t2)
    if t < 0:
        return None

    # Calculate color of pixel based on intersection point
    intersection_point = ray_origin + t * ray_direction
    normal_vector = intersection_point - SPHERE_CENTER
    normal_vector /= np.linalg.norm(normal_vector)

    # Simple diffuse shading (no lighting, just ambient)
    color = 0.5 + 0.2 * normal_vector[2]
    return color

def render_image():
    image_buffer = np.zeros((IMAGE_HEIGHT, IMAGE_WIDTH, 3))

    for x in range(IMAGE_WIDTH):
        print(f"Processing: {x}")
        for y in range(IMAGE_HEIGHT):
            # Create ray from camera to pixel on image plane
            ray_origin = np.array([0.0, 0.0, 0.0])
            ray_direction = np.array([
                (x - IMAGE_WIDTH / 2) / IMAGE_WIDTH,
                -(y - IMAGE_HEIGHT / 2) / IMAGE_HEIGHT,
                -1.0
            ])
            ray_direction /= np.linalg.norm(ray_direction)

            # Accumulate color of pixel over multiple samples
            color_sum = 0.0
            for _ in range(SAMPLES_PER_PIXEL):
                # Perturb the ray direction to simulate Monte Carlo sampling
                perturbed_ray_direction = [
                    ray_direction[0] + 0.01 * np.random.randn(),
                    ray_direction[1] + 0.01 * np.random.randn(),
                    ray_direction[2]
                ]
                perturbed_ray_direction /= np.linalg.norm(perturbed_ray_direction)

                intersection_color = ray_trace(ray_origin, perturbed_ray_direction)
                if intersection_color is not None:
                    color_sum += intersection_color

            # Average the color over all samples
            color_avg = color_sum / SAMPLES_PER_PIXEL

            image_buffer[y, x] = [color_avg, 0.5 * (1 + color_avg), 0.2]

    return image_buffer

image_data = render_image()
img = Image.fromarray((255 * image_data).astype('uint8'))
img.save('ray_tracing_output.png')

Processing: 0
Processing: 1
Processing: 2
Processing: 3
Processing: 4
Processing: 5
Processing: 6
Processing: 7
Processing: 8
Processing: 9
Processing: 10
Processing: 11
Processing: 12
Processing: 13
Processing: 14
Processing: 15
Processing: 16
Processing: 17
Processing: 18
Processing: 19
Processing: 20
Processing: 21
Processing: 22
Processing: 23
Processing: 24
Processing: 25
Processing: 26
Processing: 27
Processing: 28
Processing: 29
Processing: 30
Processing: 31
Processing: 32
Processing: 33
Processing: 34
Processing: 35
Processing: 36
Processing: 37
Processing: 38
Processing: 39
Processing: 40
Processing: 41
Processing: 42
Processing: 43
Processing: 44
Processing: 45
Processing: 46
Processing: 47
Processing: 48
Processing: 49
Processing: 50
Processing: 51
Processing: 52
Processing: 53
Processing: 54
Processing: 55
Processing: 56
Processing: 57
Processing: 58
Processing: 59
Processing: 60
Processing: 61
Processing: 62
Processing: 63
Processing: 64
Processing: 65
Processing: 66
Proce

Using numba on the double loop on IMAGE_WIDTH * IMAGE_HEIGHT leads to a computational time of around 30s.

In [None]:
import numpy as np
from PIL import Image
import numba

# Define constants
IMAGE_WIDTH = 1200
IMAGE_HEIGHT = 600
SAMPLES_PER_PIXEL = 100

# Define the sphere parameters
SPHERE_CENTER = np.array([0.0, 0.0, -5.0])
SPHERE_RADIUS = 1.0

# Start numba modification
@numba.jit(nopython=True)
# End numba modification
def ray_trace(ray_origin, ray_direction):
    # Calculate intersection point with sphere
    a = np.dot(ray_direction, ray_direction)
    b = 2 * np.dot(ray_direction, (ray_origin - SPHERE_CENTER))
    c = np.sum((ray_origin - SPHERE_CENTER) ** 2) - SPHERE_RADIUS ** 2

    discr = b ** 2 - 4 * a * c
    if discr < 0:
        return None

    t1 = (-b + np.sqrt(discr)) / (2 * a)
    t2 = (-b - np.sqrt(discr)) / (2 * a)

    # Select the closest intersection point
    t = min(t1, t2)
    if t < 0:
        return None

    # Calculate color of pixel based on intersection point
    intersection_point = ray_origin + t * ray_direction
    normal_vector = intersection_point - SPHERE_CENTER
    normal_vector /= np.linalg.norm(normal_vector)

    # Simple diffuse shading (no lighting, just ambient)
    color = 0.5 + 0.2 * normal_vector[2]
    return color

# Start numba modification
@numba.jit(nopython=True, parallel=True)
def render_image(image_buffer):
    for x in numba.prange(IMAGE_WIDTH):
        for y in numba.prange(IMAGE_HEIGHT):
# End numba modification
            # Create ray from camera to pixel on image plane
            ray_origin = np.array([0.0, 0.0, 0.0])
            ray_direction = np.array([
                (x - IMAGE_WIDTH / 2) / IMAGE_WIDTH,
                -(y - IMAGE_HEIGHT / 2) / IMAGE_HEIGHT,
                -1.0
            ])
            ray_direction /= np.linalg.norm(ray_direction)

            # Accumulate color of pixel over multiple samples
            color_sum = 0.0
            for _ in range(SAMPLES_PER_PIXEL):
                # Perturb the ray direction to simulate Monte Carlo sampling
                perturbed_ray_direction = np.array([
                    ray_direction[0] + 0.01 * np.random.randn(),
                    ray_direction[1] + 0.01 * np.random.randn(),
                    ray_direction[2]
                ])
                perturbed_ray_direction /= np.linalg.norm(perturbed_ray_direction)

                intersection_color = ray_trace(ray_origin, perturbed_ray_direction)
                if intersection_color is not None:
                    color_sum += intersection_color

            # Average the color over all samples
            color_avg = color_sum / SAMPLES_PER_PIXEL

            image_buffer[y, x] = [color_avg, 0.5 * (1 + color_avg), 0.2]

image_buffer = np.zeros((IMAGE_HEIGHT, IMAGE_WIDTH, 3))
render_image(image_buffer)
img = Image.fromarray((255 * image_buffer).astype('uint8'))
print(img)
img.save('ray_tracing_output.png')

<PIL.Image.Image image mode=RGB size=1200x600 at 0x789CBD2C8100>


The double loop on IMAGE_WIDTH * IMAGE_HEIGHT is now vectorized.
Running on CPU leads to a computational time of around 30s, so a factor 1/90 is obtained compared to the unvectorized code.
Using jax.numpy leads to aroud 15s, so a factor 1/180 compared to the unvectorized code. Note that on GPU, a further factor 1/15 was obtained.
Interestingly, the performance (with NumPy) compared to Numba is similar. Thus Numba solution is interesting because it reaches the same performance of a vectorized code version, without the complexity of implementing the vectorization.

In [None]:
import os
os.environ["JAX_ENABLE_X64"] = "1"
os.environ["JAX_PLATFORM_NAME"] = "cpu"

In [4]:
import jax.numpy as jnp
from PIL import Image
import time
import jax

# Define constants
IMAGE_WIDTH = 1200
IMAGE_HEIGHT = 600
SAMPLES_PER_PIXEL = 100

# Define the sphere parameters
SPHERE_CENTER = jnp.array([0.0, 0.0, -5.0])
SPHERE_RADIUS = 1.0

def ray_trace(ray_origins, ray_directions):
    # Calculate intersection point with sphere
    a = jnp.sum(ray_directions ** 2, axis=-1)
    b = 2 * jnp.sum((ray_directions * (ray_origins - SPHERE_CENTER)), axis=-1)
    c = jnp.sum((ray_origins - SPHERE_CENTER) ** 2, axis=-1) - SPHERE_RADIUS ** 2

    discr = b ** 2 - 4 * a * c
    mask = discr > 0
    t1 = (-b + jnp.sqrt(discr)) / (2 * a)
    t2 = (-b - jnp.sqrt(discr)) / (2 * a)

    # Select the closest intersection point
    t = jnp.where(t1 < t2, t1, t2)
    mask &= t > 0

    # Calculate color of pixel based on intersection point
    intersection_points = ray_origins + mask[:,:,None] * ray_directions
    normal_vectors = (intersection_points - SPHERE_CENTER) / jnp.linalg.norm(intersection_points - SPHERE_CENTER, axis=-1)[:, :, None]
    colors = 0.5 + 0.2 * normal_vectors[:, :, 2]

    return colors

def render_image():
    x_coords = jnp.arange(IMAGE_WIDTH)
    y_coords = jnp.arange(IMAGE_HEIGHT)
    xx, yy = jnp.meshgrid(x_coords, y_coords)

    ray_origins = jnp.array([0.0, 0.0, 0.0])
    ray_directions_x = (xx - IMAGE_WIDTH / 2) / IMAGE_WIDTH
    ray_directions_y = -(yy - IMAGE_HEIGHT / 2) / IMAGE_HEIGHT
    ray_directions_z = -1.0 * jnp.ones((IMAGE_HEIGHT, IMAGE_WIDTH))
    ray_directions = jnp.stack([ray_directions_x, ray_directions_y, ray_directions_z], axis=-1)
    ray_directions /= jnp.linalg.norm(ray_directions, axis=-1)[:, :, None]

    colors = jnp.zeros((1, IMAGE_HEIGHT, IMAGE_WIDTH)) #jax.vmap(lambda x: jax.random.normal(jax.random.key(1024), shape=(IMAGE_HEIGHT, IMAGE_WIDTH)), in_axes=0)(jnp.zeros(SAMPLES_PER_PIXEL))
    for i in range(SAMPLES_PER_PIXEL):
        perturbed_ray_directions = ray_directions + 0.01 * jax.random.normal(jax.random.key(1024), ray_directions.shape)
        perturbed_ray_directions /= jnp.linalg.norm(perturbed_ray_directions, axis=-1)[:, :, None]
        colors += jnp.expand_dims(ray_trace(jnp.tile(ray_origins[None, None], (IMAGE_HEIGHT, IMAGE_WIDTH, 1)), perturbed_ray_directions), axis=0)

    image_buffer = jnp.mean(colors, axis=0)

    return image_buffer

start_time = time.time()
image_buffer = render_image()
end_time = time.time()

print(f"Rendering took {end_time - start_time} seconds")

img = Image.fromarray((255 * np.array(image_buffer)).astype('uint8'))
img.save('ray_tracing_output.png')

Rendering took 13.867013931274414 seconds
