# Loading pre-trained models

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import jax
from jax import numpy as jnp

import diffmahnet

In [None]:
# Bind the names of the bundled pre-trained models and display them as the
# cell's output.
(available_names := diffmahnet.pretrained_model_names)

In [None]:
# Load the pre-trained models for centrals and satellites, in that order.
centrals_model, satellites_model = map(
    diffmahnet.load_pretrained_model,
    ("cenflow_v2_0.eqx", "satflow_v2_0.eqx"),
)

## Generate diffmah parameters
- Condition on a 3×3 grid of $u = (\log M_{\rm obs}, t_{\rm obs})$ values (9 conditioning pairs)
- For each conditioning pair, generate a sample of 1000 MAHs (9000 in total)

In [None]:
n_sample = 1000
m_grid = jnp.array([11.0, 12.5, 14.0])
t_grid = jnp.array([2.0, 7.5, 13.0])

# jnp.meshgrid defaults to "xy" indexing, so after flattening the mass value
# varies fastest and the time value slowest across the 9 (mass, time) pairs.
mass_mesh, time_mesh = jnp.meshgrid(m_grid, t_grid)

# Repeat each conditioning pair n_sample times so every pair owns a
# contiguous run of n_sample entries.
m_vals = jnp.repeat(mass_mesh.ravel(), n_sample)
t_vals = jnp.repeat(time_mesh.ravel(), n_sample)
print(m_vals)
print(t_vals)

## Create functions roughly equivalent to `mc_diffmah_*pop()` for our trained models

In [None]:
# Build one Monte Carlo function per population. Compared with
# mc_diffmah_cenpop, these:
# - return only a single set of DiffmahParams
# - do not depend on lgt0 or t_peak
mc_diffmahnet_cenpop = centrals_model.make_mc_diffmahnet()
mc_diffmahnet_satpop = satellites_model.make_mc_diffmahnet()

# Split a single seed into two keys so centrals and satellites draw from
# independent random streams.
randkey = jax.random.key(0)
keys = jax.random.split(randkey, 2)
cen_key, sat_key = keys

cenflow_diffmahparams = mc_diffmahnet_cenpop(
    centrals_model.get_params(), m_vals, t_vals, cen_key
)
satflow_diffmahparams = mc_diffmahnet_satpop(
    satellites_model.get_params(), m_vals, t_vals, sat_key
)

In [None]:
# Compute log-mass accretion histories from the predicted diffmah parameters
# (plotting happens in the next cell).
#
# tgrid holds one time grid per halo: 100 points from 0.1 Gyr up to that
# halo's t_obs. jnp.linspace with an array-valued stop returns shape
# (100, n_halos), so transpose to (n_halos, 100) to get one row per halo.
# Not to be confused with t_grid, the grid of conditioning t_obs values.
#
# NOTE(review): the third argument of log_mah_kern is presumably
# lgt0 = log10(t0), with t0 = 13.8 Gyr the present-day cosmic time -- confirm
# against the diffmah documentation. It is computed once and shared by both
# populations so the two calls cannot drift apart.
lgt0 = np.log10(13.8)
tgrid = jnp.linspace(0.1, t_vals, 100).T
cen_mah = diffmahnet.log_mah_kern(cenflow_diffmahparams, tgrid, lgt0)
sat_mah = diffmahnet.log_mah_kern(satflow_diffmahparams, tgrid, lgt0)

In [None]:
# Overlay a thinned subset of MAHs. Each {M_obs, t_obs} pair owns a contiguous
# run of 1000 samples, so keeping every 200th row leaves 5 halos per pair.
stride = 200

# Zero-length lines register the legend entries without drawing anything, so
# the many unlabeled curves below do not flood the legend.
for colour, label in (("C0", "centrals"), ("C1", "satellites")):
    plt.plot([], [], label=label, color=colour)
for colour, mah in (("C0", cen_mah), ("C1", sat_mah)):
    plt.plot(tgrid[::stride].T, mah[::stride].T, color=colour, alpha=0.5)

plt.legend(frameon=False)
plt.xlabel("$\\rm t \\; [Gyr]$")
plt.ylabel("$\\rm \\log(M_h(t)/M_\\odot)$")
plt.show()