In [None]:
import jax
import numpy as np
from jax import jit as jjit
from jax import numpy as jnp
from jax import vmap
from jax.scipy.integrate import trapezoid
from matplotlib import pyplot as plt

# Hex codes of matplotlib's default "tab10" red, green and blue.
mred, mgreen, mblue = '#d62728', '#2ca02c', '#1f77b4'

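## Delay-time distribution (DTD), Eq. 6 of https://academic.oup.com/mnras/article/506/3/3330/6318383
A power law with a minimum delay $t_p$: $\mathrm{DTD}(\tau) = A\,\tau^{\beta}$ for $\tau \ge t_p$ and $0$ for $\tau < t_p$.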

In [None]:
@jjit
def DTD(tau, A, beta, tp):
    """Power-law delay-time distribution (Eq. 6 of MNRAS 506, 3330).

    DTD(tau) = A * tau**beta  for tau >= tp,
    DTD(tau) = 0              for tau <  tp.

    Parameters
    ----------
    tau : float or array
        Delay time in Gyr. Arrays are handled element-wise (broadcasting).
    A : float
        Normalisation, in units of 1e-13 Msun^-1 yr^-1.
    beta : float
        Power-law slope.
    tp : float
        Minimum delay time in Gyr; the DTD is zero below it.

    Nominal values: A=2.11, beta=-1.13, tp=0.04.
    """
    below_tp = tau < tp
    # "Double where": feed a harmless value (1.0) into the power law wherever the
    # result is masked out. With beta < 0, tau**beta is inf at tau = 0; even when
    # that branch is discarded, reverse-mode AD multiplies its derivative by a zero
    # cotangent and 0 * inf = NaN would poison the gradient. (lax.cond does not
    # avoid this under vmap, where it is lowered to a select of both branches.)
    safe_tau = jnp.where(below_tp, 1.0, tau)
    return jnp.where(below_tp, 0.0, A * jnp.power(safe_tau, beta))

In [None]:
# Vectorise DTD over its first argument (tau); A, beta and tp are shared scalars.
DTD_vmap = jjit(vmap(DTD, in_axes=(0,) + (None,) * 3))

In [None]:
# Nominal DTD parameters.
A = 2.11      # normalisation [1e-13 Msun^-1 yr^-1]
beta = -1.13  # power-law slope
tp = 0.04     # minimum delay time [Gyr]; the DTD is zero below it

In [None]:
# DTD on a linear grid of delay times from 0 to 1 Gyr.
tau = jnp.linspace(0, 1, 100)
fig, ax = plt.subplots(1, 1)
ax.plot(tau, DTD_vmap(tau, A, beta, tp))
ax.set_xlabel(r'Delay time $\tau$ [Gyr]')
ax.set_ylabel(r'DTD [$10^{-13}\,M_\odot^{-1}\,\mathrm{yr}^{-1}$]')
# Trailing ';' suppresses the repr instead of assigning to IPython's `__`.
ax.set_title('Power-law delay-time distribution');

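## Placeholder star-formation history: star-forming main sequence (Eq. 8 of https://academic.oup.com/mnras/article/506/3/3330/6318383) with a crude linear time–redshift relation ($z = H_0 t$, i.e. Hubble's law applied to the light-travel distance $c\,t$)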

In [None]:
@jjit
def SFH(t, M):
    """Placeholder star-formation history evaluated at time t.

    Redshift comes from a crude linear relation: t is turned into a light-travel
    distance c*t in Mpc, and Hubble's law z = H0 * d / c is applied with
    H0 = 68 km/s/Mpc and c = 3e5 km/s.

    t : time in Gyr
    M : mass scale, in units of 1e10 (presumably Msun - TODO confirm)
    """
    # Gyr -> yr -> s, times c [m/s] gives metres, then m -> pc -> Mpc.
    seconds = t * 1e9 * 3.16e7
    metres = seconds * 3e8
    dist_mpc = metres / 3.09e16 / 1e6
    redshift = 68. * dist_mpc / 3e5
    rise = jnp.exp(1.9 * redshift)
    turnover = jnp.exp(1.7 * (redshift - 2)) + jnp.exp(0.2 * (redshift - 2))
    return jnp.power(M, 0.7) * (rise / turnover)

In [None]:
# Batch SFH over an array of times t; the mass M is shared by every element.
SFH_vmap = jjit(
    vmap(SFH, in_axes=(0, None))
)

In [None]:
M = 1.0  # mass scale for SFH, in units of 1e10 (presumably Msun - TODO confirm)

In [None]:
# Placeholder SFH on a linear time grid from 0 to 10 Gyr.
t = jnp.linspace(0, 10, 100)
fig, ax = plt.subplots(1, 1)
ax.plot(t, SFH_vmap(t, M))
ax.set_xlabel(r'$t$ [Gyr]')
ax.set_ylabel('SFH (arbitrary normalisation)')
# Trailing ';' suppresses the repr instead of assigning to IPython's `__`.
ax.set_title('Placeholder star-formation history');

## DIFFSTAR SFH
Use an example galaxy whose star-formation history has been fitted with DIFFSTAR; its best-fit parameters are listed below.

In [None]:
from diffstar.stars import _get_bounded_sfr_params
from diffstar.quenching import _get_bounded_q_params
from diffstar.stars import calculate_sm_sfr_fstar_history_from_mah

In [None]:
# Best-fit DIFFSTAR record for one example halo, stored as strings
# (as they would be read from a fit-results file).
output_data = {
    'halo_id': '1251',
    'lgmcrit': '1.15884e+01',
    'lgy_at_mcrit': '1.17262e-01',
    'indx_lo': '6.05819e-01',
    'indx_hi': '-7.67132e-01',
    'tau_dep': '-9.95171e+00',
    'qt': '1.06138e+00',
    'qs': '-4.43739e+00',
    'q_drop': '-1.09065e+00',
    'q_rejuv': '-4.21483e+00',
    'loss': '2.96707e-02',
    'success': '1',
}

In [None]:
colnames = list(output_data.keys())
# Select the fit parameters by name rather than by slicing the dict's key order,
# so a reordered or extended record cannot silently feed the wrong values into
# the wrong parameters (a missing column raises KeyError instead).
sfr_colnames = ['lgmcrit', 'lgy_at_mcrit', 'indx_lo', 'indx_hi', 'tau_dep']
q_colnames = ['qt', 'qs', 'q_drop', 'q_rejuv']
# Values are stored as strings; convert to float64 arrays.
# NOTE(review): the `u_` prefix presumably means "unbounded" (to be passed through
# diffstar's _get_bounded_* helpers) - confirm.
u_sfr_fit_params = np.array([float(output_data[key]) for key in sfr_colnames])
u_q_fit_params = np.array([float(output_data[key]) for key in q_colnames])

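## Supernova rate
The rate at time $t_0$ is the DTD convolved with the SFH,
$\mathrm{SNR}(t_0) = \int_{t_0}^{12\,\mathrm{Gyr}} \mathrm{DTD}(\tau - t_0)\,\mathrm{SFH}(\tau)\,d\tau = \int \mathrm{DTD}(\tau - t_0)\,\mathrm{SFH}(\tau)\,\tau\,d\ln\tau$,
evaluated with the trapezoidal rule on a grid uniform in $\ln\tau$.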

In [None]:
@jjit
def SNR_kernel(logtau, t0, M, A, beta, tp):
    """Integrand of the supernova rate, written for integration in ln(tau).

    tau = exp(logtau) is passed to SFH as its time argument, and the DTD is
    evaluated at the delay tau - t0. Because d(tau) = tau * d(ln tau), the product
    is multiplied by tau so that integrating over logtau equals integrating over tau.
    """
    tau = jnp.exp(logtau)
    delay = tau - t0
    jacobian = tau
    return DTD(delay, A, beta, tp) * SFH(tau, M) * jacobian

In [None]:
# Vectorise the integrand over logtau only; the remaining five arguments are shared.
SNR_kernel_vmap = jjit(vmap(SNR_kernel, in_axes=(0,) + (None,) * 5))

In [None]:
@jjit
def SNR(t0, M, A, beta, tp):
    """Supernova rate at time t0 (Gyr): the DTD convolved with the SFH.

    Integrates SNR_kernel with the trapezoidal rule on a 1000-point grid that is
    uniform in ln(tau), from tau = t0 up to tau = 12 Gyr.

    t0 must be > 0 (the grid starts at ln(t0)). For t0 >= 12 Gyr the
    integration range is empty and the returned rate is 0.
    """
    logtf = jnp.log(12.)
    # Clamp the lower limit: for t0 > 12 Gyr an unclamped grid would run backwards
    # and the trapezoidal rule would return a negative (unphysical) rate.
    logt0 = jnp.minimum(jnp.log(t0), logtf)
    logtaus = jnp.linspace(logt0, logtf, 1000)
    kernel = SNR_kernel_vmap(logtaus, t0, M, A, beta, tp)
    # jax.scipy.integrate.trapezoid replaces jnp.trapz (deprecated and removed in
    # recent JAX releases); arguments are (y, x) as before.
    return trapezoid(kernel, logtaus)

In [None]:
# Evaluate the SN rate for an array of t0 values; M, A, beta and tp are shared.
SNR_vmap = jjit(vmap(SNR, in_axes=(0,) + (None,) * 4))

In [None]:
# Times (Gyr) at which to evaluate the supernova rate.
t0 = jnp.linspace(0.1, 12, num=20)

In [None]:
snr_at_t0 = SNR_vmap(t0, M, A, beta, tp)
snr_at_t0

In [None]:
# Supernova rate as a function of the evaluation time t0.
fig, ax = plt.subplots(1, 1)
ax.plot(t0, SNR_vmap(t0, M, A, beta, tp))
ax.set_xlabel(r'$t_0$ [Gyr]')
ax.set_ylabel('SN rate (arbitrary units)')
# Trailing ';' suppresses the repr instead of assigning to IPython's `__`.
ax.set_title('DTD convolved with placeholder SFH');