<a href="https://colab.research.google.com/github/CharlesPlusC/RadMapJAX/blob/main/Ray2ForceConfig.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import jax.numpy as jnp
from jax import grad, value_and_grad, jit
from jax import random
from jax import lax

from collections import namedtuple

import os
import time

import plotly.express as px
import plotly.graph_objects as go

def normalize(x):
  """Rescale vectors along the last axis to unit Euclidean length.

  Zero-length vectors become NaN (0 / 0); downstream code relies on that
  NaN to mark missed rays.
  """
  lengths = jnp.linalg.norm(x, axis = -1, keepdims = True)
  return x / lengths

def visualize_rays(positions, directions, length=0.1, max_rays=1000):
    """Show up to ``max_rays`` rays as short blue 3-D segments in a Plotly figure.

    Args:
        positions: ray origins; any leading shape with a trailing axis of size 3
            (e.g. the grid-shaped arrays built in ``World.sample_light``).
        directions: ray directions, same shape as ``positions``.
        length: drawn length of each segment along its direction.
        max_rays: cap on the number of rays drawn (the first ones after flattening).
    """
    # Flatten first so grid-shaped inputs yield individual rays rather than rows of rays.
    starts = jnp.reshape(positions, (-1, 3))[:max_rays]
    ends = starts + length * jnp.reshape(directions, (-1, 3))[:max_rays]

    # One trace with None separators draws disconnected segments; this is far
    # cheaper for Plotly than adding one trace per ray.
    xs, ys, zs = [], [], []
    for (x0, y0, z0), (x1, y1, z1) in zip(starts.tolist(), ends.tolist()):
        xs += [x0, x1, None]
        ys += [y0, y1, None]
        zs += [z0, z1, None]

    fig = go.Figure(go.Scatter3d(x=xs, y=ys, z=zs,
                                 mode='lines',
                                 line=dict(color='blue', width=2)))
    fig.update_layout(title="Ray Visualization",
                      scene=dict(aspectmode='data'),
                      showlegend=False)
    fig.show()

# Immutable bundle of rays: `position` holds origins, `direction` holds directions.
Ray = namedtuple("Ray", ["position", "direction"])

def spherical_to_carthesian(spherical):
  """Convert angle pairs to unit direction vectors.

  *spherical* has shape (..., 2): ``[..., 0]`` is the angle whose sine gives z
  and ``[..., 1]`` is the azimuth in the x/y plane. The result has shape (..., 3).
  Note the overall negative sign: angles (0, 0) map to (-1, 0, 0).
  Any number of leading axes is accepted (previously exactly two were required).
  """
  theta = spherical[..., 0]
  phi = spherical[..., 1]
  x = -jnp.cos(theta) * jnp.cos(phi)
  y = -jnp.cos(theta) * jnp.sin(phi)
  z = -jnp.sin(theta)
  return jnp.stack((x, y, z), axis = -1)

def grid(shape):
  """Return a (shape[0], shape[1], 2) array of normalised grid coordinates.

  ``[..., 0]`` runs from 0 to 1 down the rows; ``[..., 1]`` runs from 0 to 1
  across the columns.
  """
  rows = jnp.linspace(0, 1, shape[0])
  cols = jnp.linspace(0, 1, shape[1])
  row_coord, col_coord = jnp.meshgrid(rows, cols, indexing = "ij")
  return jnp.stack((row_coord, col_coord), axis = -1)

def sphere(shape):
  """Return unit directions on a (shape[0], shape[1]) angular grid.

  The first grid coordinate is mapped to [pi/2, 3*pi/2] and the second to
  [0, 2*pi] before conversion with ``spherical_to_carthesian``.
  """
  scale = jnp.array([jnp.pi, 2 * jnp.pi])
  offset = jnp.array([-jnp.pi/2, 0])
  angles = grid(shape) * scale - offset
  return spherical_to_carthesian(angles)

class Mesh:
  """Triangle mesh read from a Wavefront OBJ file.

  Provides area-weighted face sampling (``proto_indices``) and brute-force
  ray/triangle intersection. On load, vertices are rescaled so the bounding box
  spans [-0.5, 0.5] along every axis with non-zero extent.
  """

  def __init__(self, filename):
    """Load ``v`` and ``f`` records from *filename*, normalise, and build the sampling table.

    Only the vertex index of each ``v/vt/vn`` face token is used. Faces with more
    than three vertices are fan-triangulated as (v0, vk, vk+1); negative
    (relative) indices count back from the last vertex read so far.

    Raises:
      ValueError: if the file has no vertices, no face with at least three
        vertices, or a face uses the invalid index 0.
    """
    vertices = []
    faces = []
    # Lists + a context manager: the file is closed deterministically and
    # appending is O(1) instead of rebuilding an ever-growing tuple per line.
    with open(filename, "r") as obj_file:
      for line in obj_file:
        tokens = line.split()
        if not tokens:
          continue
        if tokens[0] == "v":
          vertices.append(tuple(float(x) for x in tokens[1:4]))
        elif tokens[0] == "f":
          polygon = [self._parse_index(token, len(vertices)) for token in tokens[1:]]
          # Fan triangulation keeps every triangle of a quad / n-gon instead of
          # only the first one.
          for k in range(1, len(polygon) - 1):
            faces.append((polygon[0], polygon[k], polygon[k + 1]))
    if not vertices or not faces:
      raise ValueError(f"{filename!r} contains no vertices or no triangular faces")

    self.vertices = jnp.array(vertices)
    self.faces = jnp.array(faces)

    box_min = jnp.min(self.vertices, axis = 0)
    box_max = jnp.max(self.vertices, axis = 0)
    box_size = box_max - box_min
    # A flat mesh has zero extent along some axis; avoid dividing by zero there.
    box_size = jnp.where(box_size > 0, box_size, 1.0)
    box_center = (box_min + box_max) / 2
    # NOTE(review): each axis is scaled by its own extent, so non-cubic models are
    # stretched into a cube (which also tilts oblique face normals). Preserved
    # from the original; confirm whether a uniform scale was intended.
    self.vertices = (self.vertices - box_center) / box_size

    self.build_area_sampling()

  @staticmethod
  def _parse_index(token, vertex_count):
    """Convert one OBJ face token (``v``, ``v/vt``, ``v//vn``, ``v/vt/vn``) to a 0-based vertex index."""
    index = int(token.split('/')[0])
    if index > 0:
      return index - 1
    if index < 0:
      return vertex_count + index
    raise ValueError("OBJ vertex indices are 1-based; got 0")

  def heron(self, sides):
    """Return triangle area(s) from the three side LENGTHS via Heron's formula.

    *sides* is indexable as ``sides[0]``, ``sides[1]``, ``sides[2]``; each entry
    may be a scalar or an array of per-triangle lengths (not squared lengths).
    """
    a, b, c = sides[0], sides[1], sides[2]
    s = (a + b + c) / 2
    # Heron's formula is the square root of a product; clamp tiny negative values
    # produced by rounding on degenerate triangles.
    return jnp.sqrt(jnp.maximum(s * (s - a) * (s - b) * (s - c), 0.0))

  def build_area_sampling(self):
    """Precompute a table mapping uniform samples to area-weighted face indices.

    ``self.proto_indices[k]`` is the face whose cumulative-area interval contains
    ``k / (proto_size - 1)``, so indexing it with ``int(u * proto_size)`` for
    ``u`` in [0, 1) picks faces with probability approximately proportional to
    their area.
    """
    self.proto_size = 100_000

    vertex0 = self.vertices[self.faces[..., 0]]
    vertex1 = self.vertices[self.faces[..., 1]]
    vertex2 = self.vertices[self.faces[..., 2]]
    # Area = |e01 x e02| / 2; numerically more robust than Heron's formula for
    # thin triangles.
    areas = 0.5 * jnp.linalg.norm(jnp.cross(vertex1 - vertex0, vertex2 - vertex0), axis = -1)
    cumulative_area = jnp.cumsum(areas)
    cdf = cumulative_area / cumulative_area[-1]

    proto = jnp.linspace(0, 1, self.proto_size)
    self.proto_indices = jnp.searchsorted(cdf, proto)

  def get_normal(self, face):
    """Return the unnormalised normal(s) e01 x e02 of *face* (vertex-index triples on the last axis)."""
    vertex0 = self.vertices[face[...,0]]
    vertex1 = self.vertices[face[...,1]]
    vertex2 = self.vertices[face[...,2]]
    v0v1 = vertex1 - vertex0
    v0v2 = vertex2 - vertex0
    return jnp.cross(v0v1, v0v2)

  def intersect(self, rays, t_min, t_max):
    """Find the nearest triangle hit along every ray by looping over all faces.

    Args:
      rays: object with ``position`` and ``direction`` arrays of shape (..., 3).
      t_min: hits with ray parameter t below this are ignored (avoids self-hits).
      t_max: hits with ray parameter t above this are ignored.

    Returns:
      ``(best_ts, best_normals)``: ray parameter of the nearest accepted hit
      (NaN where nothing was hit) and the unnormalised normal of that face
      (zeros where nothing was hit).
    """

    def intersect_tri(face, best_ts, best_normals):
      vertex0 = self.vertices[face[0]]
      vertex1 = self.vertices[face[1]]
      vertex2 = self.vertices[face[2]]

      v0v1 = vertex1 - vertex0
      v0v2 = vertex2 - vertex0

      plane_normal = jnp.cross(v0v1, v0v2)
      plane_d = -jnp.sum(vertex0 * plane_normal, axis = -1)

      # Ray/plane intersection parameter t.
      normal_dot_origin = jnp.sum(rays.position * plane_normal, axis = -1) + plane_d
      normal_dot_direction = jnp.sum(rays.direction * plane_normal, axis = -1)
      tri_ts = -normal_dot_origin / normal_dot_direction

      hit = rays.position + tri_ts[..., jnp.newaxis] * rays.direction

      # Inside-outside test: the hit point must lie on the inner side of all three edges.
      edge0 = vertex1 - vertex0
      edge1 = vertex2 - vertex1
      edge2 = vertex0 - vertex2
      vp0 = hit - vertex0
      vp1 = hit - vertex1
      vp2 = hit - vertex2
      f0 = jnp.sum(jnp.cross(edge0, vp0) * plane_normal, axis = -1)
      f1 = jnp.sum(jnp.cross(edge1, vp1) * plane_normal, axis = -1)
      f2 = jnp.sum(jnp.cross(edge2, vp2) * plane_normal, axis = -1)
      tri_ts = jnp.where(f0 > 0, tri_ts, float('nan'))
      tri_ts = jnp.where(f1 > 0, tri_ts, float('nan'))
      tri_ts = jnp.where(f2 > 0, tri_ts, float('nan'))

      # Reject misses and non-finite t (rays parallel to the plane can yield inf),
      # hits farther than the current best, and hits outside [t_min, t_max].
      # `tri_ts > NaN` is False, so the first valid hit replaces the NaN start value.
      skip = (~jnp.isfinite(tri_ts)) | (tri_ts > best_ts) | (tri_ts < t_min) | (tri_ts > t_max)
      best_ts = jnp.where(skip, best_ts, tri_ts)
      best_normals = jnp.where(skip[..., jnp.newaxis], best_normals, plane_normal)
      return best_ts, best_normals

    def intersect_for_i(index, state):
      best_ts, best_normals = state
      return intersect_tri(self.faces[index], best_ts, best_normals)

    best_ts = jnp.full(rays.position.shape[:-1], float('nan'))
    best_normals = jnp.zeros_like(rays.position)
    return lax.fori_loop(0, self.faces.shape[0], intersect_for_i, (best_ts, best_normals))

class Log:
  """Records ray positions after each bounce and plots them as 3-D paths."""

  def flatten(self, vertices):
    """Collapse every axis except the last into one: (..., D) -> (N, D).

    Accepts any rank >= 1 (previously only rank-3 input worked).
    """
    return jnp.reshape(vertices, (-1, vertices.shape[-1]))

  def init(self, vertices):
    """Start an empty path sized for the flattened number of rays in *vertices*."""
    vertices = self.flatten(vertices)
    self.path = jnp.empty((0, vertices.shape[0], vertices.shape[-1]))

  def add(self, new_vertices):
    """Append one snapshot of ray positions; ``self.path`` has shape (steps, rays, D)."""
    new_vertices = self.flatten(new_vertices)
    self.path = jnp.append(self.path, new_vertices[jnp.newaxis, ...], axis = 0)

  def plot(self):
    """Show the recorded points coloured by step, plus one line per ray, in [-1, 1]^3."""
    depths = jnp.arange(0, self.path.shape[0])
    depths = jnp.repeat(depths, self.path.shape[1])
    index = jnp.arange(0, self.path.shape[1])
    index = jnp.tile(index, self.path.shape[0])

    flat = jnp.reshape(self.path, (-1, self.path.shape[2]))

    data_frame = {
      "x": flat[:,0],
      "y": flat[:,1],
      "z": flat[:,2],
      "depth": depths,
      "index": index
    }

    scene = dict(
      aspectratio = dict(x = 1, y = 1, z = 1),
      aspectmode = "manual",
      xaxis = dict(range=[-1,1]),
      yaxis = dict(range=[-1,1]),
      zaxis = dict(range=[-1,1]),
    )

    fig1 = px.scatter_3d(data_frame, x = "x", y = "y", z = "z", color = "depth")
    fig1.update_layout(scene=scene)
    fig2 = px.line_3d(data_frame, x = "x", y = "y", z = "z", color = "index")
    fig2.update_layout(scene=scene)

    fig3 = go.Figure(data=fig1.data + fig2.data)
    fig3.show()

class NullLog(Log):
  """Drop-in replacement for ``Log`` that records nothing and draws nothing."""

  def init(self, vertices):
    """Ignore the initial vertices."""
    return None

  def add(self, new_vertices):
    """Ignore appended vertices."""
    return None

  def plot(self):
    """Draw nothing."""
    return None

class World:
  """Scene wrapper: the mesh plus the PRNG state used to sample rays."""

  def __init__(self, mesh_filename = "GPS2F_v6_full.obj", seed = 0):
    """Load the mesh and seed three independent PRNG streams.

    Args:
      mesh_filename: OBJ file to load (path relative to the working directory).
      seed: base seed; the position, index and direction streams use
        ``seed``, ``seed + 1`` and ``seed + 2`` respectively.
    """
    self.position_key = random.PRNGKey(seed)
    self.index_key = random.PRNGKey(seed + 1)
    self.direction_key = random.PRNGKey(seed + 2)
    self.mesh = Mesh(mesh_filename)
    self.light_position = jnp.array([0.9, -0.3, 0.4])

  def sample_directions(self, shape):
    """Return ``shape[0]`` random unit directions (normalised points of the cube [-1, 1]^3)."""
    self.direction_key, subkey = random.split(self.direction_key)
    return normalize(2 * random.uniform(subkey, (shape[0], 3)) - 1)

  def sample_camera(self, sample_shape):
    """Return parallel rays along +z starting on the z = -5 plane over [-1, 1]^2."""
    a = 2 * grid(sample_shape) - jnp.array([1,1])
    b = -5 * jnp.ones(sample_shape)
    positions = jnp.concatenate((a, b[...,jnp.newaxis]), axis = -1)
    directions = jnp.ones_like(positions) * jnp.array([0,0,1])
    return Ray(positions, directions)

  def sample_light(self, sample_shape):
    """Return rays pointing towards the origin from about 3 units away.

    Directions come from ``sphere(sample_shape)``; origins are
    ``-3 * direction`` plus uniform jitter inside a 0.01-wide cube.
    """
    self.position_key, subkey = random.split(self.position_key)
    xi = random.uniform(subkey, sample_shape + (3+3,))
    # The `0 *` term leaves the grid directions un-jittered.
    directions = normalize(sphere(sample_shape) + 0 * xi[...,0:3])
    positions = 0.01 * (xi[...,3:] - 0.5) - 3 * directions
    return Ray(positions, directions)

  def sample_surface(self, sample_shape):
    """Sample one point on the mesh surface per cell of *sample_shape*.

    Returns:
      ``(Ray, normals)``: origins on faces chosen through the mesh's
      ``proto_indices`` table, directions from ``sphere(sample_shape)``, and the
      unit normal of each chosen face.
    """
    self.index_key, subkey = random.split(self.index_key)
    xi = random.uniform(subkey, sample_shape + (3,))

    # xi[..., 0] selects a face through the precomputed sampling table.
    face_indices = self.mesh.proto_indices[jnp.array(xi[...,0] * self.mesh.proto_size, dtype = int)]

    sampled_faces = self.mesh.faces[face_indices]
    sampled_vertex0 = self.mesh.vertices[sampled_faces[...,0]]
    sampled_vertex1 = self.mesh.vertices[sampled_faces[...,1]]
    sampled_vertex2 = self.mesh.vertices[sampled_faces[...,2]]

    edge0 = sampled_vertex1 - sampled_vertex0
    edge1 = sampled_vertex2 - sampled_vertex0

    # Fold (u, v) pairs with u + v >= 1 back into the triangle.
    bounded_xi = jnp.where(xi[...,1,jnp.newaxis] + xi[...,2,jnp.newaxis] < 1, xi, 1 - xi)
    positions = sampled_vertex0 + bounded_xi[...,1,jnp.newaxis] * edge0 + bounded_xi[...,2,jnp.newaxis] * edge1
    directions = sphere(sample_shape)
    normals = normalize(jnp.cross(edge0, edge1))

    return Ray(positions, directions), normals

  def intersect(self, rays, t_min = 0.1, t_max = 100):
    """Intersect *rays* with the mesh; return hit parameters and unit normals (NaN on a miss)."""
    ts, normals = self.mesh.intersect(rays, t_min, t_max)
    return ts, normalize(normals)

def ray_to_force(incoming_direction, outgoing_direction, normal, specular=1.0, reflectivity=1.0, normal_only=False):
    """Return the per-ray force contribution, with the speed of light normalised to 1.

    With ``normal_only`` the result is ``|incoming_direction| * normal``, and the
    material parameters are neither used nor validated. Otherwise the result
    blends two terms:

    * a specular term ``incoming + reflectivity * outgoing``, weighted by ``specular``;
    * a diffuse term ``incoming + reflectivity * normal / 3``, weighted by ``1 - specular``.

    NOTE(review): common Lambertian SRP models use a 2/3 normal factor for diffuse
    reflection; confirm the 1/3 used here is intended.

    Raises:
        ValueError: if ``specular`` or ``reflectivity`` lies outside [0, 1]
            (checked only when ``normal_only`` is false).
    """
    speed_of_light = 1

    if normal_only:
        magnitude = jnp.linalg.norm(incoming_direction, axis=-1, keepdims=True)
        return magnitude * normal / speed_of_light

    if not (0 <= specular <= 1 and 0 <= reflectivity <= 1):
        raise ValueError("Specular and reflectivity values must be between 0 and 1")

    one_third = 1.0 / 3.0
    diffuse = 1 - specular
    specular_term = specular * (incoming_direction + reflectivity * outgoing_direction)
    diffuse_term = diffuse * (incoming_direction + reflectivity * (one_third * normal))

    return (specular_term + diffuse_term) / speed_of_light

def reflect(d, n):
  """Mirror direction(s) *d* about the plane with normal(s) *n*; returns unit vectors.

  Vectors live on the last axis and leading axes broadcast. Both inputs are
  normalised first, so r = d - 2 (d . n) n is evaluated on unit vectors and the
  result depends only on the directions of *d* and *n*, not their lengths.
  NaN inputs (e.g. normals of missed rays) propagate to NaN outputs.
  """
  d_hat = d / jnp.linalg.norm(d, axis = -1, keepdims = True)
  n_hat = n / jnp.linalg.norm(n, axis = -1, keepdims = True)
  reflected = d_hat - 2 * jnp.sum(d_hat * n_hat, axis = -1, keepdims = True) * n_hat
  return reflected / jnp.linalg.norm(reflected, axis = -1, keepdims = True)

def estimate_backward(sample_shape, world, number_of_bounces, log, force_config):
  """Estimate surface forces by tracing rays outward from random points on the mesh.

  One ray is started per cell of *sample_shape*. Its origin is a random surface
  point (faces picked through the mesh's ``proto_indices`` table) and its
  direction comes from ``sphere(sample_shape)``. Each ray's force is computed once,
  up front, with ``ray_to_force(**force_config)``; the bounce loop only decides
  which rays keep that force.

  Args:
    sample_shape: (rows, cols) grid of rays.
    world: ``World`` providing ``sample_surface`` and ``intersect``.
    number_of_bounces: number of intersect-and-reflect steps.
    log: ``Log``-like recorder with ``init``/``add``/``plot``.
    force_config: keyword arguments forwarded to ``ray_to_force``.

  Returns:
    Force array of shape ``sample_shape + (3,)``, zeroed for discarded rays.
  """
  rays, normals = world.sample_surface(sample_shape)

  log.init(rays.position)
  log.add(rays.position)

  incoming_direction = rays.direction
  outgoing_direction = reflect(rays.direction, normals)
  force = ray_to_force(incoming_direction, outgoing_direction, normals, **force_config)

  for _ in range(number_of_bounces):
    best_ts, best_normals = world.intersect(rays)
    # A miss yields NaN t, so the new position is NaN and stays NaN for the
    # remaining bounces.
    new_position = rays.position + best_ts[..., jnp.newaxis] * rays.direction
    new_direction = reflect(rays.direction, -best_normals)
    rays = Ray(new_position, new_direction)
    log.add(rays.position)

  # `nan=` must be a keyword: the second positional parameter of
  # jnp.nan_to_num is `copy`, not the NaN replacement value.
  on_normal_side = jnp.sum(jnp.nan_to_num(normalize(incoming_direction), nan = 0.0) * normals, axis = -1) > 0
  escaped = jnp.any(jnp.isnan(rays.position), axis = -1)
  # NOTE(review): with `or`, every ray on the normal side keeps its force even if
  # it later hits the mesh (occlusion is ignored for those rays); confirm whether
  # `and` was intended.
  keep = jnp.logical_or(escaped, on_normal_side)
  force = jnp.where(keep[..., jnp.newaxis], force, 0)

  log.plot()

  return force

def estimate_forward(sample_shape, world, number_of_bounces, log, force_config = None):
  """Trace rays from the light sphere towards the mesh and return the last hit normals.

  Args:
    sample_shape: (rows, cols) grid of rays passed to ``world.sample_light``.
    world: ``World`` providing ``sample_light`` and ``intersect``.
    number_of_bounces: number of intersect-and-reflect steps.
    log: ``Log``-like recorder with ``init``/``add``/``plot``.
    force_config: accepted and ignored, so this estimator has the same call
      signature as ``estimate_backward`` and can be passed to ``build_force_map``.

  Returns:
    Unit normals of the surfaces hit at the final bounce, with zeros where the
    ray missed (and all zeros when ``number_of_bounces`` is 0).
  """
  rays = world.sample_light(sample_shape)

  log.init(rays.position)
  log.add(rays.position)

  # Defined up front so zero bounces returns zeros instead of raising UnboundLocalError.
  best_normals = jnp.zeros_like(rays.direction)
  for _ in range(number_of_bounces):
    best_ts, best_normals = world.intersect(rays)
    rays = Ray(rays.position + best_ts[..., jnp.newaxis] * rays.direction, rays.direction)
    log.add(rays.position)
    rays = Ray(rays.position, normalize(reflect(normalize(rays.direction), -normalize(best_normals))))

  log.plot()

  # `nan=` must be a keyword: jnp.nan_to_num's second positional parameter is `copy`.
  return jnp.nan_to_num(best_normals, nan = 0.0)

def reconstruct(radiance):
  """Map one estimator pass onto the accumulated solution (currently the identity)."""
  result = radiance
  return result

# Module-level scene used by build_force_map below; constructing it loads the
# mesh OBJ file (path relative to the working directory).
world = World()

# Earlier downsampling approach, kept for reference: sums four rolled copies of
# the image (a 2x2 box filter with wrap-around) and keeps every second row and
# column. Superseded by the scipy-based `downsample` defined below.
# def downsample(image):
#   result = 0
#   result += jnp.roll(image, (0, 0), axis=(0,1))
#   result += jnp.roll(image, (1, 0), axis=(0,1))
#   result += jnp.roll(image, (0, 1), axis=(0,1))
#   result += jnp.roll(image, (1, 1), axis=(0,1))
#   result = result[::2,::2]
#   return result / 4

def downsample(image, target_shape=(181, 361)):
    """Resample *image* to ``target_shape`` (rows, cols) with bilinear interpolation.

    Any axes after the first two (e.g. an xyz/RGB channel axis) keep their size.
    The default target reproduces the previously hard-coded 181 x 361 output.

    Raises:
        ValueError: if *image* has fewer than two dimensions.
    """
    from scipy.ndimage import zoom
    import numpy as np

    image = np.asarray(image)
    if image.ndim < 2:
        raise ValueError(f"expected an image with at least 2 dimensions, got shape {image.shape}")

    # Zoom only the two spatial axes; trailing axes use factor 1.
    zoom_factor = (target_shape[0] / image.shape[0], target_shape[1] / image.shape[1]) + (1,) * (image.ndim - 2)

    # order=1 selects (bi)linear spline interpolation.
    return zoom(image, zoom_factor, order=1)

def calculate_rms_difference(image1, image2):
    """Return the root-mean-square difference between two equally shaped arrays.

    Raises:
        ValueError: if the shapes differ; the message includes both shapes
            (previously they were printed to stdout instead).
    """
    if image1.shape != image2.shape:
        raise ValueError(
            "Images must have the same shape for RMS calculation. "
            f"Got {image1.shape} and {image2.shape}."
        )

    return jnp.sqrt(jnp.mean(jnp.square(image1 - image2)))

def load_and_display_jnp_file(file_path):
    """Load an array saved with ``np.save``, print its shape, and return it flipped along axis 0.

    Despite the name, nothing is displayed. The array must have at least three
    dimensions (it is indexed as ``[::-1, :, :]``).
    """
    data = jnp.load(file_path)
    print(f"Shape of the array: {data.shape}")
    return data[::-1, :, :]

# (rows, cols) of the ray grid every estimator pass fills; the force map has an extra xyz axis.
solution_shape = (1_000, 2_000)

def build_force_map(sample_count, estimator, force_config, number_of_bounces=2, profile_interval=1, save_path="/content/images"):
    """Accumulate ``sample_count`` estimator passes into a force map and save it.

    Each pass calls ``estimator(solution_shape, world, number_of_bounces, NullLog(),
    force_config)`` and adds the result to a running sum. After every pass:

    * the running mean is downsampled, min-max normalised and shown with Plotly;
    * the same image is written to ``save_path`` as a PNG.

    Finally, the accumulated sum is saved as ``force_map<config>.npy`` in ``save_path``.

    Args:
        sample_count: number of estimator passes.
        estimator: callable with the signature above (e.g. ``estimate_backward``).
        force_config: keyword arguments forwarded to the estimator; also used to
            tag the output filenames.
        number_of_bounces: bounces traced per pass.
        profile_interval: print timing every this many passes (<= 0 disables it).
        save_path: directory for the PNG frames and the ``.npy`` file (created if missing).

    Returns:
        The accumulated (summed, not averaged) force map.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    def tonemap(values):
        """Min-max normalise to [0, 1]; a constant image maps to all zeros."""
        lo, hi = values.min(), values.max()
        span = hi - lo
        return (values - lo) / span if span > 0 else np.zeros_like(values)

    # Tag output files with the config so runs with different configs (e.g. the
    # experiments loop) do not overwrite each other's frames.
    config_tag = "".join(f"_{key}{value}" for key, value in force_config.items())
    # One ray per grid cell; solution.size would also count the 3 xyz components.
    million_rays = int(np.prod(solution_shape)) / (1000 * 1000)

    solution = jnp.zeros(solution_shape + (3,))
    total_time = 0

    os.makedirs(save_path, exist_ok=True)

    for i in range(sample_count):
        start_time = time.perf_counter()
        radiance = estimator(solution_shape, world, number_of_bounces, NullLog(), force_config)
        solution = solution + reconstruct(radiance)
        # JAX dispatches asynchronously; wait for the result so the timing is real.
        solution.block_until_ready()
        time_taken = time.perf_counter() - start_time
        total_time += time_taken

        if profile_interval > 0 and i % profile_interval == 0:
            rays_per_second = million_rays / time_taken
            print("Step {} [total time {:0.1f}s]: {:0.1f} Msamples/s ({:0.1f}M samples in {:0.3f}s)".format(i, total_time, rays_per_second, million_rays, time_taken))

        image = tonemap(downsample(solution / (i + 1)))
        fig = px.imshow(image)
        fig.show()

        # Save the frame without axes or padding.
        image_np = np.array(image)
        plt.figure()
        plt.imshow(image_np)
        plt.axis('off')
        image_filename = os.path.join(save_path, f"image{config_tag}_{i:04d}.png")
        plt.savefig(image_filename, bbox_inches='tight', pad_inches=0)
        plt.close()

    force_map_full_path = os.path.join(save_path, f"force_map{config_tag}.npy")
    # NOTE(review): this stores the running *sum* over all passes; divide by
    # sample_count if downstream code expects a per-pass mean.
    np.save(force_map_full_path, np.array(solution))
    print(f"Force map saved to {force_map_full_path}")
    return solution



def create_gif_from_images(images_path, gif_path, duration=0.5):
    """Assemble every ``.png`` in *images_path*, in filename order, into a GIF.

    Args:
        images_path: directory containing the frames.
        gif_path: output file path.
        duration: per-frame duration passed straight to ``imageio.mimsave``
            (NOTE(review): its unit depends on the imageio version/plugin — confirm).

    Raises:
        ValueError: if *images_path* contains no ``.png`` files.
    """
    import imageio

    frame_names = sorted(name for name in os.listdir(images_path) if name.endswith('.png'))
    if not frame_names:
        raise ValueError(f"No .png images found in {images_path!r}")
    images = [imageio.imread(os.path.join(images_path, name)) for name in frame_names]

    imageio.mimsave(gif_path, images, duration=duration, palettesize=256, subrectangles=True)

    print(f"GIF created at {gif_path}")

# Force-model configurations to compare. Each dict reaches ray_to_force as
# keyword arguments via build_force_map -> estimate_backward.
experiments = [
    dict(specular=1.0, reflectivity=0.0, normal_only=True),  # Surface normals
    dict(specular=1.0, reflectivity=1.0),                    # Only specular
    dict(specular=0.0, reflectivity=1.0),                    # Only diffuse
    dict(specular=0.0, reflectivity=0.0),                    # No reflections
    dict(specular=0.5, reflectivity=0.7),                    # GPS2F gold MLI material
]

# 100 backward-estimator passes per configuration.
for force_conf in experiments:
    build_force_map(100, estimate_backward, force_conf)
    print(f"built force map for: {force_conf}")

# create_gif_from_images('/content/images', '/content/force_map.gif')
# print("Done")


Step 0 [total time 19.6s]: 0.3 Msamples/s (6.0M samples in 19.606s)


Step 1 [total time 38.9s]: 0.3 Msamples/s (6.0M samples in 19.336s)


Step 2 [total time 58.5s]: 0.3 Msamples/s (6.0M samples in 19.591s)


Step 3 [total time 77.9s]: 0.3 Msamples/s (6.0M samples in 19.365s)


Step 4 [total time 97.3s]: 0.3 Msamples/s (6.0M samples in 19.388s)


Step 5 [total time 116.6s]: 0.3 Msamples/s (6.0M samples in 19.355s)


Step 6 [total time 136.0s]: 0.3 Msamples/s (6.0M samples in 19.352s)
