In [1]:
import sys
sys.path.append('..')

from tools.siren import *
from tools.table import *

from tools.simulation import create_siren_grid
from tools.generate import new_differentiable_get_rays

import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

import time

# --- Data: training table and the SIREN lookup grid built from it ---
table = Table('cprof_mu_train_10000ev.h5')
table_data = create_siren_grid(table)

# --- Pretrained SIREN: returns (model, parameters) ---
siren_model, model_params = load_siren_jax('siren_cprof_mu.pkl')

# --- Ray-generation inputs shared by the benchmark cells below ---
origin = jnp.array((0.5, 0.0, -0.5))
direction = jnp.array((1.0, -1.0, 0.2))
Nphot = 1_000_000
energy = 500
key = random.PRNGKey(0)

In [2]:
# import time
# import jax

# import sys
# sys.path.append('..')

# from tools.generate import new_differentiable_get_rays

# Nphot = 1000000
# energy =500

# def profile_function(f, *args, **kwargs):
#     # Compile the function first
#     results = f(*args, **kwargs)
#     jax.block_until_ready(results)  # Use jax.block_until_ready on the tuple
    
#     # Measure runtime
#     times = []
#     for _ in range(3):  # Number of iterations
#         start = time.perf_counter()
#         results = f(*args, **kwargs)
#         jax.block_until_ready(results)  # Block on all results
#         end = time.perf_counter()
#         times.append(end - start)
    
#     print(f"Average runtime: {sum(times)/len(times)*1000:.2f} ms")
#     print(f"Min runtime: {min(times)*1000:.2f} ms")
#     print(f"Max runtime: {max(times)*1000:.2f} ms")

# # Use it

In [3]:
import time
import jax
from tools.generate import differentiable_get_rays

Nphot = 10**5  # overrides the 1_000_000 set in cell 1; used by all benchmarks below

def profile_function(f, *args, n_runs=20, **kwargs):
    """Time repeated calls of ``f(*args, **kwargs)`` and print summary statistics.

    ``f`` is first called once untimed as a warm-up. If ``f`` is jitted, this
    keeps tracing/compilation out of the measurements. ``f`` is then called
    ``n_runs`` more times with wall-clock timing. Every result is passed to
    ``jax.block_until_ready`` so that JAX's asynchronous dispatch does not make
    a call appear to finish before its outputs are actually computed.

    Args:
        f: Callable to profile.
        *args: Positional arguments forwarded to ``f``.
        n_runs: Number of timed calls (keyword-only, default 20). Must be >= 1.
        **kwargs: Keyword arguments forwarded to ``f``.

    Returns:
        None. Average/min/max runtimes (in ms) are printed.

    Raises:
        ValueError: If ``n_runs`` is less than 1.
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    # Warm-up (not timed): compiles jitted functions. An un-jitted callable is
    # simply executed once here.
    jax.block_until_ready(f(*args, **kwargs))

    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        # Block on all outputs (works for tuples / pytrees of arrays).
        jax.block_until_ready(f(*args, **kwargs))
        times.append(time.perf_counter() - start)

    print(f"Average runtime: {sum(times)/len(times)*1000:.2f} ms")
    print(f"Min runtime: {min(times)*1000:.2f} ms")
    print(f"Max runtime: {max(times)*1000:.2f} ms")

In [4]:
# Baseline: time the old ray generator `differentiable_get_rays`.
# NOTE(review): the third positional argument here is the literal 40, whereas the
# new generator in the next cell is called with `energy` (500). Confirm what this
# argument means and whether the two benchmarks use comparable inputs.
profile_function(differentiable_get_rays, origin, direction, 40, Nphot, key)

Average runtime: 2.79 ms
Min runtime: 2.53 ms
Max runtime: 3.06 ms


In [5]:
# Time the new ray generator. Compared with the baseline it additionally takes the
# SIREN grid (`table_data`) and the network parameters (`model_params`).
profile_function(new_differentiable_get_rays, origin, direction, energy, Nphot, table_data, model_params, key)

Average runtime: 301.59 ms
Min runtime: 298.48 ms
Max runtime: 314.67 ms


In [6]:
import jax
import jax.numpy as jnp

def create_random_points(M, key=None):
    """
    Create M random points on the z = 0 plane.

    - X and Y coordinates are sampled uniformly from the half-open interval
      [-1, 1) (``jax.random.uniform`` excludes ``maxval``).
    - The Z coordinate is fixed to 0.

    Args:
        M (int): Number of points.
        key (jax.random.PRNGKey, optional): PRNG key used for sampling.
            Defaults to ``jax.random.PRNGKey(0)``, so calls without an explicit
            key always return the same points.

    Returns:
        jnp.ndarray: Array of shape (M, 3) containing the 3D points.
    """
    if key is None:
        key = jax.random.PRNGKey(0)

    # Random X and Y coordinates, shape (M, 2)
    xy = jax.random.uniform(key, shape=(M, 2), minval=-1, maxval=1)

    # Zero Z column, shape (M, 1)
    z = jnp.zeros((M, 1))

    return jnp.concatenate([xy, z], axis=1)

random_points = create_random_points(M=Nphot)  # one (x, y, 0) sample per photon

In [7]:
import sys
sys.path.append('..')

from tools.siren import *
from tools.table import *

from tools.simulation import create_siren_grid
from tools.generate import new_differentiable_get_rays

import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

siren_model, model_params = load_siren_jax('siren_cprof_mu.pkl')

# Build a SIREN with an explicit architecture and evaluate it on the sample points.
# NOTE(review): this re-loads the same file as cell 1, and the hyperparameters are
# hard-coded here. Confirm they match the architecture stored in
# 'siren_cprof_mu.pkl'; the loaded `siren_model` may be usable directly.
siren_kwargs = dict(
    hidden_features=256,
    hidden_layers=3,
    out_features=1,
    outermost_linear=True,
)
model = SIREN(**siren_kwargs)

photon_weights, _ = model.apply(model_params, random_points)

In [8]:
# Time a single SIREN forward pass over `random_points`.
# NOTE(review): `model.apply` is passed without `jax.jit`, so profile_function's
# warm-up call does not compile it (unless SIREN jits internally). The timing is
# then presumably op-by-op dispatch. Consider `jax.jit(model.apply)` for a fair
# comparison with the jitted generators above; confirm.
profile_function(model.apply, model_params, random_points)

Average runtime: 408.20 ms
Min runtime: 398.53 ms
Max runtime: 439.48 ms
