In [1]:
import sys
from pathlib import Path

# Make the parent directory importable so the local `iactrace` package can be
# imported without installing it (assumes the notebook runs from a subdirectory
# of the project — confirm). The path is resolved to an absolute one so a later
# change of working directory does not break imports, and the membership guard
# keeps the cell idempotent when it is re-run.
_parent_dir = str(Path('..').resolve())
if _parent_dir not in sys.path:
    sys.path.insert(0, _parent_dir)

In [2]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

In [3]:
from iactrace import (
    Telescope, Cylinder, MCIntegrator, HexagonalSensor, SquareSensor, 
    hexshow, plot_telescope_geometry
)

# Load the HESS I telescope configuration

In [4]:
# Build the HESS I telescope from its mirror file; the dish geometry
# (focal length, dish radius, surface shape) is passed explicitly.
telescope = Telescope.from_csv(
    "../configs/HESS/hess_mirrors.dat",
    focal_length=15.0,
    dish_radius=15.0,
    surface_type='spherical',
)

In [5]:
# Integrator and sensor handed to `telescope.compile` in the next cell.
integrator = MCIntegrator(n_samples=256, roughness=10.0)

# A HexagonalSensor(hex_centers=...) can be swapped in here, but it needs a
# `hex_centers` array that this notebook does not define.
# NOTE(review): the two positional args presumably give the pixel grid size — confirm.
sensor = SquareSensor(100, 100)

In [6]:
print("Compiling simulation...")

# NOTE(review): sensor_plane is a pair of 3-vectors, presumably
# (point on the plane, plane normal) — confirm against Telescope.compile.
# Its z-coordinate of 15.0 equals the focal_length given to Telescope.from_csv.
sim = telescope.compile(
    integrator=integrator,
    sensor=sensor,
    source_type='infinity',
    sensor_plane=(
        jnp.array([0.0, 0.0, 15.0]),
        jnp.array([0.0, 0.0, -1.0]),
    ),
    # Fixed, distinct PRNG keys keep the compiled simulation reproducible.
    sampling_key=jax.random.key(42),
    alignment_key=jax.random.key(11),
)

print("Simulation compiled!")

Compiling simulation...
Simulation compiled!


# Simulate a star field

In [7]:
# Star field: `n_stars` unit direction vectors spread over a small square
# patch of sky (5 degrees across) around the -z direction.
n_stars = 1000
fov_deg = 5
fov_rad = fov_deg * jnp.pi / 180
half_fov = fov_rad / 2

key = jax.random.key(143)
key1, key2 = jax.random.split(key)

# Independent x and y offsets, uniform between -half_fov and half_fov; with
# z = -1 and the normalisation below they act as small tilts away from -z.
x, y = (
    jax.random.uniform(k, (n_stars,), minval=-half_fov, maxval=half_fov)
    for k in (key1, key2)
)
z = -jnp.ones(n_stars)

raw_dirs = jnp.column_stack((x, y, z))
stars = raw_dirs / jnp.linalg.norm(raw_dirs, axis=1, keepdims=True)

In [8]:
%%time
# Time the star-field simulation; block_until_ready() waits for JAX's
# asynchronous dispatch so %%time measures the actual computation.
# NOTE(review): in the saved run this call raised
# "IndexError: Too many indices: 1-dimensional array indexed with 2 regular indices",
# so `result` was never defined and the plot cell below failed as well.
# Root cause is not visible here — TODO inspect the full traceback (e.g. how
# `sim` interprets its second argument, given source_type was already passed
# to `telescope.compile`).
result = sim(stars, 'infinity').block_until_ready()

IndexError: Too many indices: 1-dimensional array indexed with 2 regular indices.

In [9]:
# Sensor image of the star field (assumes `result` is a 2-D array — confirm).
fig, ax = plt.subplots()
im = ax.imshow(result)
fig.colorbar(im, ax=ax)
ax.set(title="Star field", xlabel="Pixel column", ylabel="Pixel row");

NameError: name 'result' is not defined

# Simulate a point source

In [None]:
# Point sources at a finite distance: `N_points` positions with x and y drawn
# uniformly from [-1, 1) and a fixed z of 120 (same length units as the
# telescope configuration — presumably; confirm).
N_points = 1

key = jax.random.key(12)
key1, key2 = jax.random.split(key)

# Shapes are passed as tuples, matching the star-field cell and the
# documented `shape` argument of jax.random.uniform.
x = jax.random.uniform(key1, (N_points,), minval=-1, maxval=1)
y = jax.random.uniform(key2, (N_points,), minval=-1, maxval=1)
z = jnp.full(N_points, 120.0)

# One row per source: shape (N_points, 3).
points = jnp.stack([x, y, z], axis=1)

In [None]:
%%time
# Time the point-source simulation; block_until_ready() waits for JAX's
# asynchronous dispatch so %%time measures the actual computation.
# NOTE(review): `sim` was compiled above with source_type='infinity'; confirm
# it also accepts 'point' sources here, or compile a separate simulation.
result = sim(points, 'point').block_until_ready()

In [None]:
# Sensor image of the point source(s) (assumes `result` is a 2-D array — confirm).
fig, ax = plt.subplots()
im = ax.imshow(result)
fig.colorbar(im, ax=ax)
ax.set(title="Point source", xlabel="Pixel column", ylabel="Pixel row");

In [None]:
# NOTE(review): stray scratch calculation; its purpose is not evident from the
# notebook — document what it represents or remove this cell.
380*64