In [6]:
from Shared.specific_CNB_sim import *


total_start = time.perf_counter()
init_dis =8.178
angle_momentum_decay = jnp.load(f'neutrino_angle_momentum_decay.npy') 
decayed_neutrinos_z = jnp.load(f'decayed_neutrinos_z.npy')
z_array = jnp.load(f'sim_output/no_gravity/z_int_steps.npy')
neutrino_momenta = jnp.load(f'sim_output/no_gravity/neutrino_momenta.npy')
s_int_steps=jnp.load(f'sim_output/no_gravity/s_int_steps.npy')
z_int_steps=jnp.load(f'sim_output/no_gravity/z_int_steps.npy')



In [2]:
@jax.jit
def EOMs_no_gravity(s_val, y, Nr):
    Nr_index = Nr 

    def find_nearest(array, value):
        idx = jnp.argmin(jnp.abs(array - value))
        return array[idx]

    # Find z corresponding to s via interpolation.
    z = jnp.interp(s_val, s_int_steps, z_int_steps)
    z_nearest = find_nearest(z_array, z)
    # check index of z in our z array
    z_index = jnp.where(z_nearest == z_array)[0]
    neutrino_number = decayed_neutrinos_z[int(z_index)][Nr_index]
    prev_neutrino_number = decayed_neutrinos_z[int(z_index) - 1][Nr_index]

    if neutrino_number == 0 and prev_neutrino_number == 1:
        _, u_i_p = y  # choose a random angle between 0 and 180
        p_i = find_nearest(neutrino_momenta, 0.06 * ((u_i_p[0]) ** 2 + (u_i_p[1]) ** 2 + (u_i_p[2]) ** 2))
        p_index = jnp.where(neutrino_momenta == p_i)[0]
        angle_decay_theta = random.randint(random.PRNGKey(0), (0, 180))
        angle_decay_phi = random.randint(random.PRNGKey(0), (0, 360))
        momentum_daughter = angle_momentum_decay[angle_decay_theta][int(p_index)]
        u_i = [(1 / 0.05) * momentum_daughter * jnp.sin(angle_decay_theta) * jnp.cos(angle_decay_phi),
               (1 / 0.05) * momentum_daughter * jnp.sin(angle_decay_theta) * jnp.sin(angle_decay_phi),
               (1 / 0.05) * momentum_daughter * jnp.cos(angle_decay_theta)]
    else:
        # Initialize vector.
        _, u_i = y

    dyds = -jnp.array([
        u_i, jnp.zeros(3)
    ])

    return dyds

In [7]:
@jax.jit
def backtrack_1_neutrino(init_vector, s_int_steps):

    """
    Simulate trajectory of 1 neutrino. Input is 6-dim. vector containing starting positions and velocities of neutrino. Solves ODEs given by the EOMs function with an jax-accelerated integration routine, using the diffrax library. Output are the positions and velocities at each timestep, which was specified with diffrax.SaveAt. 
    """
    y0_r, Nr = init_vector[0:-1], init_vector[-1]
    # Initial vector in correct shape for EOMs function
    y0 = y0_r.reshape(2,3)

    # ODE solver setup
    term = diffrax.ODETerm(EOMs_no_gravity)
    t0 = s_int_steps[0]
    t1 = s_int_steps[-1]
    dt0 = (s_int_steps[0] + s_int_steps[1]) / 1000
    

    ### ------------- ###
    ### Dopri5 Solver ###
    ### ------------- ###
    solver = diffrax.Dopri5()
    stepsize_controller = diffrax.PIDController(rtol=1e-3, atol=1e-6)
    # note: no change for tighter rtol and atol, e.g. rtol=1e-5, atol=1e-9


    # Specify timesteps where solutions should be saved
    saveat = diffrax.SaveAt(ts=jnp.array(s_int_steps))
    
    # Solve the coupled ODEs, i.e. the EOMs of the neutrino
    sol = diffrax.diffeqsolve(
        term, solver, 
        t0=t0, t1=t1, dt0=dt0, y0=y0, max_steps=10000,
        saveat=saveat, stepsize_controller=stepsize_controller, args=(Nr,))
    
    trajectory = sol.ys.reshape(100,6)

    # Only return the initial [0] and last [-1] positions and velocities
    return jnp.stack([trajectory[0], trajectory[-1]])



In [13]:
@jax.jit
#@partial(jit, static_argnums=1)
def simulate_neutrinos_1_pix(init_xyz, init_vels, s_int_steps):
    
    """
    Function for the multiprocessing routine below, which simulates all neutrinos for 1 pixel on the healpix skymap.
    """
   
    # Neutrinos per pixel
    nus = init_vels.shape[0]

    # Make vector with same starting position but different velocities
    init_vectors_0 = jnp.array(
        [jnp.concatenate((init_xyz, init_vels[k])) for k in range(nus)])
    
    Nr_column = jnp.arange(1000).reshape(-1, 1)
    # Concatenate the additional column to the original array
    init_vectors = jnp.hstack((init_vectors_0, Nr_column))

    trajectories = jnp.array([
        backtrack_1_neutrino(vec, s_int_steps) for vec in init_vectors])

    return trajectories  # shape = (neutrinos, 2, 3)



In [14]:
# File name ending
end_str = f'halo1'

# Initial position (Earth)
init_xyz = jnp.array([float(init_dis), 0., 0.])
jnp.save(f'sim_output/no_gravity/init_xyz_{end_str}.npy', init_xyz)


### ============== ###
### Run Simulation ###
### ============== ###

print(f"*** Simulation for no_gravity ***")

sim_start = time.perf_counter()

with open(f'sim_output/no_gravity/sim_parameters.yaml', 'r') as file:
    sim_setup = yaml.safe_load(file)

init_vels = jnp.load(f'sim_output/no_gravity/initial_velocities.npy')  
Nr = jnp.arange(192000)

trajectories_1_pix = simulate_neutrinos_1_pix(init_xyz, init_vels[0], s_int_steps)


*** Simulation for no_gravity ***
(Traced<ShapedArray(float64[])>with<DynamicJaxprTrace(level=3/0)>,)


TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float64[].
The error occurred while tracing the function _fn at /home/fzimmer/.conda/envs/neutrino_clustering/lib/python3.10/site-packages/equinox/_make_jaxpr.py:36 for make_jaxpr. This concrete value was not available in Python because it depends on the value of the argument _dynamic_flat[3].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError