In [25]:
%load_ext autoreload
%autoreload 2

from Shared.shared import *
from Shared.specific_CNB_sim import *

import pandas as pd

# Folder layout for this simulation run.
sim_name = f"Dopri5_1k"
sim_folder = f"sim_output/{sim_name}"
fig_folder = f"figures_local/{sim_name}"
Cl_folder = f"Shared/Cls"
# Neutrino mass grid as stored on disk (the file name says eV).
# NOTE(review): unlike `nu_massrange` in the Mini-Sim cell, this one is NOT
# multiplied by Params.eV - confirm which unit downstream code expects.
nu_m_range = jnp.load(f"{sim_folder}/neutrino_massrange_eV.npy")
# A handful of masses of interest, converted to numerical units via Params.eV.
nu_m_picks = jnp.array([0.01, 0.05, 0.1, 0.2, 0.3])*Params.eV
simdata = SimData(sim_folder)

# Day-1 number densities. Per the printed output, pix_dens_FD has shape
# (1, 5, 768) and tot_dens_FD has shape (1, 50).
# NOTE(review): "FD" presumably means Fermi-Dirac, and the axes are presumably
# (day, mass in nu_m_picks, healpix pixel) and (day, mass in nu_m_range) - confirm.
pix_dens_FD = jnp.load(f"{sim_folder}/pixel_densities_day1.npy")
tot_dens_FD = jnp.load(f"{sim_folder}/total_densities_day1.npy")
print(pix_dens_FD.shape)
print(tot_dens_FD.shape)
print(tot_dens_FD)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
(1, 5, 768)
(1, 50)
[[56.59817452 56.60432873 56.60716104 56.60837604 56.60885944 56.60903686
  56.60909659 56.60911493 56.60912002 56.6091213  56.60912158 56.60912164
  56.60912165 56.60912165 56.60912165 56.60912165 56.60912165 56.60912164
  56.60912164 56.60912164 56.60912164 56.60912164 56.60912163 56.60912163
  56.60912162 56.60912162 56.60912161 56.6091216  56.60912159 56.60912157
  56.60912155 56.60912153 56.6091215  56.60912146 56.60912142 56.60912137
  56.60912131 56.60912123 56.60912113 56.60912101 56.60912086 56.60912068
  56.60912045 56.60912012 56.60911983 56.60911941 56.60911889 56.60911825
  56.60911746 56.60911649]]


# Earth-Sun distances

In [8]:
# Load the daily Earth-Sun distance table into a pandas DataFrame.
df = pd.read_excel('Data/Earth-Sun_distances.xlsx')


# Take every second column starting at column 1 (presumably the distance
# columns, with dates in the other columns - confirm against the sheet),
# coerce to numbers, and flatten with stack(), which (in pandas' default mode)
# also drops the NaNs produced by coercion.
# NOTE(review): stack() walks the frame row by row. If the sheet stores
# consecutive days down a column (e.g. one month per column pair), the
# flattened result is NOT in day order - confirm.
# The trailing [:-1] drops the last entry (see the TODO below).
ES_distances = jnp.array(df.iloc[:, 1::2].apply(pd.to_numeric, errors='coerce')\
                                  .stack().reset_index(drop=True).tolist())[:-1]
# Distances are in AU (hence the Params.AU factor); express them in kpc.
ES_dists_kpc = ES_distances*Params.AU/Params.kpc

# print(ES_distances[jnp.array([0,-1])])
# TODO(review): the table yields 366 values before [:-1]. Check whether the
# first and last entries are the same calendar day before dropping one.

# Day numbers 1..N matching ES_distances.
days = jnp.arange(1,len(ES_distances)+1)

print(ES_distances.shape, days.shape)

print(ES_distances[0], ES_dists_kpc[0])

(365,) (365,)
0.9833098 4.7672204384173685e-09


# Solar gravity function

In [20]:
@jax.jit
def sun_gravity(x_i, com_sun, eps):
    """
    Gravitational acceleration of a point-mass Sun at position x_i.

    Args:
        x_i: position (3-vector) in numerical units.
        com_sun: position of the Sun's centre of mass.
        eps: softening length, added to the distance before cubing.

    Returns:
        The acceleration vector
        -G*Msun*(x_i - com_sun) / (|x_i - com_sun| + eps)**3.
    """
    offset = x_i - com_sun
    softened_dist = jnp.linalg.norm(offset) + eps

    # Acceleration is minus the gradient of the gravitational potential.
    return -(Params.G*Params.Msun*offset/softened_dist**3)

# Place the Sun at the origin.
com_sun = jnp.array([0.,0.,0.])

# init_coords = jnp.load(f"{sim_folder}/init_xyz_halo1.npy")*Params.kpc
# init_coords = jnp.load(f"{sim_folder}/init_xyz_modulation.npy")*Params.kpc
# print(init_coords/Params.kpc)
# Test position on the x-axis: the day-1 Earth-Sun distance (in kpc), converted
# to numerical units.
# NOTE(review): the extra factor 10000 puts the test point 10^4 times farther
# out than the Earth. Its purpose is undocumented - confirm.
init_dis = ES_dists_kpc[0]*Params.kpc*10000
init_coords = np.array([float(init_dis), 0., 0.])

# Initial velocities from the simulation (presumably stored in kpc/s),
# converted to numerical units.
init_vels = jnp.load(f"{sim_folder}/initial_velocities.npy")*(Params.kpc/Params.s)
# print(init_vels[0,0,:]*(Params.kpc/Params.s)/(Params.km/Params.s))

# Softening length of 350_000/3e16 kpc, presumably 350,000 km expressed in kpc
# (1 kpc is about 3.1e16 km) - confirm.
eps = 350_000/3e16*Params.kpc
# eps = 0.
grad_sun = sun_gravity(init_coords, com_sun, eps)
print(grad_sun)

# Replace NaNs with zeros and zero out components below the cutoff.
# NOTE(review): EOMs_sun in the Mini-Sim cell uses cutoff = 1e-100 for the same
# step. The printed value here (~1e-34) lies just above this 1e-35 cutoff.
grad_sun = jnp.nan_to_num(grad_sun)
cutoff = 1e-35
grad_sun = jnp.where(jnp.abs(grad_sun) < cutoff, 0.0, grad_sun)

# Convert the acceleration to kpc/s^2.
grad_sun /= (Params.kpc/Params.s**2)
print(grad_sun)

[-1.34655724e-34 -0.00000000e+00 -0.00000000e+00]
[-1.98760102e-30  0.00000000e+00  0.00000000e+00]


# Creation of integration steps (z and s)

In [None]:
from astropy.time import Time
from scipy.integrate import quad

def z_at_age(age, xtol=1e-16, rtol=1e-12):
    """
    Find the redshift at which the universe had the given age.

    Solves cosmo.age(z) == age for z with Brent's method on the bracket
    z in [0, 10].

    Args:
        age: age of the universe in years.
        xtol: absolute tolerance on z. Redshifts within the past year are tiny
            (z is roughly H0 * lookback time, of order 1e-13 per day), so the
            previous hard-coded 1e-12 could exceed the root itself. 1e-16 is
            close to what float64 ages of ~1e10 yr can resolve.
        rtol: relative tolerance on z (unchanged default).

    Returns:
        float: the redshift z.
    """
    def age_diff(z):
        # Zero where the model age at redshift z equals the requested age.
        return cosmo.age(z).to(apu.year).value - age

    sol = root_scalar(age_diff, bracket=[0, 10], method='brentq',
                      xtol=xtol, rtol=rtol)
    return sol.root

def get_redshift_array(days=365):
    """
    Redshift for each of the last `days` days.

    Element d corresponds to a lookback time of d days, so element 0 is today
    (z = 0). Lookback times are converted to years with astropy units and
    inverted to redshifts with z_at_age.

    Args:
        days: number of days to sample (default 365, one year).

    Returns:
        jnp.ndarray of shape (days,), requested as float64. JAX keeps float64
        only when 64-bit mode is enabled.
    """
    # Age of the universe today, in years.
    age_today = cosmo.age(0).to(apu.year).value

    # The lookback time of day d is simply d days. The result does not depend
    # on today's calendar date, so no wall-clock Time object is needed.
    redshifts = [
        z_at_age(age_today - (day*apu.day).to(apu.year).value)
        for day in range(days)
    ]

    # Build the array once instead of rebuilding it with .at[].set per element.
    return jnp.array(redshifts, dtype=jnp.float64)


def s_of_z(z, args):
    """
    Time variable s at redshift z, following eqn. 4.1 of Mertsch et al. (2020).

    Integrates 1/a_dot(z') from 0 to z, where a_dot = H(z') / (1 + z') and
    H(z') = H0 * sqrt(Omega_M (1+z')^3 + Omega_L). Only Omega_M and Omega_L
    are kept in the Hubble equation.

    Args:
        z (float): redshift (upper integration limit).
        args: parameter object providing H0, s, Omega_M and Omega_L.

    Returns:
        float: time variable s (in [seconds], since H0 is converted to 1/s).
    """

    def integrand(z_prime):
        # H0 expressed in units of 1/s.
        hubble_0 = args.H0/(1/args.s)
        one_plus_z = 1. + z_prime
        a_dot = np.sqrt(args.Omega_M*one_plus_z**3 + args.Omega_L)/one_plus_z*hubble_0
        return 1./a_dot

    integral = quad(integrand, 0., z)[0]
    return jnp.float64(integral)


# Get the redshift array for the past year (index 0 = today).
z_int_steps_1year = get_redshift_array()
print(z_int_steps_1year)

# Corresponding values of the time variable s (in seconds, per s_of_z).
s_int_steps_1year = jnp.array([s_of_z(z, Params()) for z in z_int_steps_1year])
print(s_int_steps_1year.min(), s_int_steps_1year[2], s_int_steps_1year.max())
# Diagnostic only: candidate initial step sizes built from neighbouring s values.
for i in range(5):
    dt0 = (s_int_steps_1year[i+2]+s_int_steps_1year[i+1])/10_000
    print(dt0)

# Persist both grids; the Mini-Sim cell loads them back.
jnp.save(f"{sim_folder}/z_int_steps_1year.npy", z_int_steps_1year)
jnp.save(f"{sim_folder}/s_int_steps_1year.npy", s_int_steps_1year)

In [None]:
# Display the one-year s grid for inspection.
s_int_steps_1year
# dt0 = (s_int_steps_1year[1])
# dt0

# Mini-Sim: 1 Pixel

In [None]:
# Simulation parameters.
with open(f'{sim_folder}/sim_parameters.yaml', 'r') as file:
    sim_setup = yaml.safe_load(file)

# NOTE(review): CPUs_sim and neutrinos are not used anywhere in the visible notebook.
CPUs_sim = 128
neutrinos = 1000
# NOTE(review): this overwrites the init_dis set in the solar-gravity cell
# (Earth-Sun distance x 10000) with the halo-GC distance from the YAML file.
init_dis = sim_setup['initial_haloGC_distance']
# One-year integration grids saved by the "integration steps" cell.
z_int_steps = jnp.load(f'{sim_folder}/z_int_steps_1year.npy')
s_int_steps = jnp.load(f'{sim_folder}/s_int_steps_1year.npy')
# Neutrino masses, converted from eV to numerical units.
nu_massrange = jnp.load(f'{sim_folder}/neutrino_massrange_eV.npy')*Params.eV


@jax.jit
def EOMs_sun(s_val, y, args):
    """
    Right-hand side of the neutrino equations of motion in the Sun's field,
    in the form expected by diffrax.ODETerm.

    Args:
        s_val: current value of the time variable s (integration variable).
        y: state (positions, velocities). backtrack_1_neutrino passes a (2, 3)
            array, presumably in kpc and kpc/s given the conversions below.
        args: tuple (s_int_steps, z_int_steps, kpc, s). These are the grids used
            to interpolate z(s) and the numerical values of 1 kpc and 1 s.

    Returns:
        (2, 3) array -[u, (1+z)^-2 * acceleration]. The overall minus sign makes
        the integration run backwards in time.
    """

    # Unpack the input data
    s_int_steps, z_int_steps, kpc, s = args

    # Initialize vector: split the state into position and velocity.
    x_i, u_i = y

    # Switch to "numerical reality" here.
    x_i *= kpc
    u_i *= (kpc/s)

    # Find z corresponding to s via interpolation on the (s, z) grids.
    z = Utils.jax_interpolate(s_val, s_int_steps, z_int_steps)

    # Compute gradient of sun.
    # Softening length: presumably 350,000 km expressed in kpc - confirm.
    eps = 350_000/3e16*kpc
    # NOTE(review): this calls SimExec.sun_gravity, not the sun_gravity defined
    # in this notebook. The Sun is placed at the origin.
    grad_sun = SimExec.sun_gravity(x_i, jnp.array([0.,0.,0.]), eps)

    # Replace NaNs with zeros and apply cutoff
    # NOTE(review): the solar-gravity test cell uses cutoff = 1e-35 instead.
    grad_sun = jnp.nan_to_num(grad_sun)
    cutoff = 1e-100
    grad_sun = jnp.where(jnp.abs(grad_sun) < cutoff, 0.0, grad_sun)

    # Switch to "physical reality" here.
    grad_sun /= (kpc/s**2)
    # x_i is not used after this point. u_i returns to its input units, up to
    # floating-point rounding from the multiply/divide round trip.
    x_i /= kpc
    u_i /= (kpc/s)

    # Hamilton eqns. for integration (global minus, s.t. we go back in time).
    # The acceleration term is scaled by (1+z)^-2.
    dyds = -jnp.array([
        u_i, 1./(1.+z)**2 * grad_sun
    ])

    return dyds
    # return -u_i, -1./(1.+z)**2*grad_sun


@jax.jit
def backtrack_1_neutrino(init_vector, s_int_steps, z_int_steps, kpc, s):
    """
    Simulate the trajectory of one neutrino with diffrax.

    Solves the ODEs given by EOMs_sun from s_int_steps[0] to s_int_steps[-1],
    saving the solution at every entry of s_int_steps.

    Args:
        init_vector: 6-dim starting vector (3 position components followed by
            3 velocity components).
        s_int_steps: grid of the time variable s; also the save points.
        z_int_steps: redshifts matching s_int_steps (used to interpolate z(s)).
        kpc, s: numerical values of 1 kpc and 1 s, forwarded to EOMs_sun.

    Returns:
        Array of shape (2, 6): the state at the first and at the last save point.
    """

    # Initial vector in correct shape for EOMs function: (position, velocity).
    y0 = init_vector.reshape(2, 3)

    # ODE solver setup
    term = diffrax.ODETerm(EOMs_sun)
    t0 = s_int_steps[0]
    t1 = s_int_steps[-1]
    # Initial step-size guess for the adaptive controller.
    dt0 = (s_int_steps[2]) / 2

    solver = diffrax.Dopri8()
    stepsize_controller = diffrax.PIDController(rtol=1e-3, atol=1e-6)

    args = (s_int_steps, z_int_steps, kpc, s)

    # Specify timesteps where solutions should be saved
    saveat = diffrax.SaveAt(ts=jnp.array(s_int_steps))

    # Solve the coupled ODEs, i.e. the EOMs of the neutrino
    sol = diffrax.diffeqsolve(
        term, solver,
        t0=t0, t1=t1,
        dt0=dt0,
        y0=y0, max_steps=10000,
        saveat=saveat,
        stepsize_controller=stepsize_controller,
        args=args)

    # One row per save point. Derive the count from the saved array instead of
    # hard-coding 365, so grids of any length work.
    trajectory = sol.ys.reshape(-1, 6)

    # Only return the initial [0] and last [-1] positions and velocities
    return jnp.stack([trajectory[0], trajectory[-1]])


def simulate_neutrinos_1_pix(init_xyz, init_vels, common_args):
    """
    Backtrack every neutrino belonging to one healpix pixel.

    All neutrinos share the starting position `init_xyz` and differ only in
    their initial velocity (one row of `init_vels` each).

    Args:
        init_xyz: starting position (3 components).
        init_vels: initial velocities, one row of 3 components per neutrino.
        common_args: remaining positional arguments for backtrack_1_neutrino,
            i.e. (s_int_steps, z_int_steps, kpc, s).

    Returns:
        Array of shape (neutrinos, 2, 6) stacking backtrack_1_neutrino outputs.
    """
    n_neutrinos = init_vels.shape[0]

    # 6-dim start vectors: the shared position followed by each velocity.
    start_vectors = jnp.array([
        jnp.concatenate((init_xyz, init_vels[idx])) for idx in range(n_neutrinos)
    ])

    results = []
    for vector in start_vectors:
        results.append(backtrack_1_neutrino(vector, *common_args))

    return jnp.array(results)


### ============== ###
### Run Simulation ###
### ============== ###

# File name ending
end_str = f'modulation'

# Initial position on the x-axis at distance init_dis.
# NOTE(review): init_dis comes from 'initial_haloGC_distance' in the YAML file
# (Mini-Sim config cell), not from the Earth-Sun table - confirm this is intended.
init_xyz = np.array([float(init_dis), 0., 0.])
jnp.save(f'{sim_folder}/init_xyz_{end_str}.npy', init_xyz)

print(f"*** Simulation for modulation ***")

sim_start = time.perf_counter()


init_vels = np.load(f'{sim_folder}/initial_velocities.npy')  
# shape = (Npix, neutrinos per pixel, 3)

common_args = (s_int_steps, z_int_steps, Params.kpc, Params.s)

# Simulate all neutrinos of pixel 0 only, without multiprocessing.
nu_vectors = simulate_neutrinos_1_pix(init_xyz, init_vels[0], common_args)

sim_time = time.perf_counter()-sim_start
print(f"Simulation time: {sim_time/60.:.2f} min, {sim_time/(60**2):.2f} h")

# Shape: (neutrinos per pixel, 2, 6).
# NOTE(review): nu_vectors is not saved here, yet a later cell loads
# vectors_modulation.npy from sim_folder - Restart & Run All depends on an
# externally produced file.
print(nu_vectors.shape)

# Analyze vectors

In [12]:
# Load the day-1 vectors of all pixels and flatten to (neutrinos, 2, 6).
# Axis 1 presumably holds the first / last saved state; within each state,
# components 0:3 are position and 3:6 are velocity.
nu_vectors = jnp.load(f"{sim_folder}/vectors_day1.npy").reshape(-1, 2, 6)
print(nu_vectors.shape)

# NOTE(review): the z0 / z4 suffixes suggest redshifts 0 and 4 (halo runs);
# for one-year runs the second state is the end of the backtracking - confirm naming.
positions_z0 = nu_vectors[:, 0, :3]
positions_z4 = nu_vectors[:, 1, :3]
velocities_z0 = nu_vectors[:, 0, 3:]
velocities_z4 = nu_vectors[:, 1, 3:]

print(positions_z4.shape)

(768000, 2, 6)
(768000, 3)


# Convert vectors to number densities

In [24]:
# Load the modulation-run vectors and print the first neutrino's (start, end) states.
# NOTE(review): the printed end state is all NaN, presumably a failed
# integration - investigate before computing densities from these vectors.
nu_vectors = jnp.load(f"{sim_folder}/vectors_modulation.npy")

print(nu_vectors.reshape(-1,2,6)[0])

# # Compute individual number densities for each healpixel
# nu_allsky_masses = jnp.array([0.01, 0.05, 0.1, 0.2, 0.3])*Params.eV
# pix_dens = Physics.number_densities_all_sky(
#     v_arr=nu_vectors[..., 3:],
#     m_arr=nu_allsky_masses,
#     pix_sr=simdata.pix_sr,
#     args=Params())

# # Compute total number density, by using all neutrino vectors for integral
# tot_dens = Physics.number_densities_mass_range(
#     v_arr=nu_vectors.reshape(-1, 2, 6)[..., 3:], 
#     m_arr=nu_m_range, 
#     pix_sr=4*Params.Pi,
#     args=Params())



[[ 4.76722044e-09  0.00000000e+00  0.00000000e+00 -3.91294247e-18
  -3.91294247e-18 -5.40072969e-17]
 [            nan             nan             nan             nan
              nan             nan]]


# Plot outputs

In [None]:
def modulation(nu_dens_days=None):
    
    # 1 to 365 for each day of the year
    days = jnp.arange(1, 366)
    
    # Placeholder, actual data after finishing project
    if nu_dens_days is None:
        nu_dens_days = jnp.ones_like(days)*56

    # Create the plot
    plt.figure(figsize=(10, 6))
    plt.plot(days, nu_dens_days, label='FD_sim')

    # Customize x-axis to show ticks and labels only on specific dates
    tick_days = [305, 32, 122, 213]  # Corresponding to Nov 1, Feb 1, May 1, and Aug 1
    tick_labels = ['Nov 1', 'Feb 1', 'May 1', 'Aug 1']

    plt.xticks(tick_days, tick_labels)

    # Add labels and title
    plt.xlabel('Day of the Year')
    plt.ylabel('Number Density')
    plt.title('Number Densities Across the Year')
    plt.legend()

    # Show grid for better readability
    plt.grid(True, which="major", linestyle="dashed")

    # Display the plot
    plt.show()

# Render the placeholder curve (no per-day densities passed yet).
modulation()