# Calculating SFH with Diffstar

This notebook gives two basic illustrations of how to use diffstar to model star formation histories (SFHs): first for an individual galaxy, then for a population of galaxies.

### SFH of an individual diffstar galaxy

In the cells below, we'll grab the default diffmah and diffstar parameters, then use the `sfh_singlegal` function to calculate the SFH.

In [None]:
import numpy as np
from diffstar.defaults import DEFAULT_MAH_PARAMS
from diffstar.defaults import DEFAULT_MS_PARAMS
from diffstar.defaults import DEFAULT_Q_PARAMS

today_gyr = 13.8  # present-day cosmic time in Gyr
tarr = np.linspace(0.1, today_gyr, 100)  # grid of cosmic times in Gyr

In [None]:
from diffstar import sfh_singlegal

sfh_gal = sfh_singlegal(
    tarr, DEFAULT_MAH_PARAMS, DEFAULT_MS_PARAMS, DEFAULT_Q_PARAMS)

### SFHs of a population of diffstar galaxies

For the purposes of this toy demonstration, we'll first create a small diffstar population by adding random Gaussian noise to the default diffstar parameters. The diffmah parameters are left at their default values for every galaxy.

In [None]:
n_gals = 5

mah_params_galpop = np.tile(DEFAULT_MAH_PARAMS, n_gals)
mah_params_galpop = mah_params_galpop.reshape((n_gals, -1))

ms_params_galpop = np.tile(DEFAULT_MS_PARAMS, n_gals)
ms_params_galpop = ms_params_galpop.reshape((n_gals, -1))

q_params_galpop = np.tile(DEFAULT_Q_PARAMS, n_gals)
q_params_galpop = q_params_galpop.reshape((n_gals, -1))

ms_noise = np.random.normal(
    loc=0, scale=0.25, size=(n_gals, DEFAULT_MS_PARAMS.size))
ms_params_galpop = ms_params_galpop + ms_noise

q_noise = np.random.normal(
    loc=0, scale=0.1, size=(n_gals, DEFAULT_Q_PARAMS.size))
q_params_galpop = q_params_galpop + q_noise
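The tile-and-add-noise pattern above can be sketched on its own with plain NumPy. This is a minimal illustration under stated assumptions: `default_params` is a hypothetical stand-in array (not diffstar's actual default values), and a seeded `np.random.default_rng` generator is swapped in for `np.random.normal` so that the noisy population is reproducible from run to run.

```python
import numpy as np

# Hypothetical stand-in for a default parameter array.
# These are NOT diffstar's actual default values.
default_params = np.array([12.0, 0.5, -1.0, 2.0])
n_gals = 5

# A seeded generator makes the random population reproducible.
rng = np.random.default_rng(seed=0)

# Repeat the defaults once per galaxy, giving shape (n_gals, n_params).
params_galpop = np.tile(default_params, (n_gals, 1))

# Add independent Gaussian noise to every parameter of every galaxy.
noise = rng.normal(loc=0.0, scale=0.25, size=params_galpop.shape)
noisy_params_galpop = params_galpop + noise

print(noisy_params_galpop.shape)
```

Passing the tuple `(n_gals, 1)` to `np.tile` builds the two-dimensional array directly, so no separate `reshape` call is needed.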

The `sfh_galpop` function calculates the SFHs of an entire population at once. This calculation is vectorized with `jax.vmap`, so it will be more efficient than a loop over successive calls to `sfh_singlegal`.

In [None]:
from diffstar import sfh_galpop

sfh_pop = sfh_galpop(
    tarr, mah_params_galpop, ms_params_galpop, q_params_galpop)
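To illustrate the general `jax.vmap` pattern behind a population-level calculation like this one, here is a minimal sketch using a toy model. The function `toy_sfh_single` and its two parameters are hypothetical and are not diffstar's model or internals; the sketch only shows how a single-galaxy function can be mapped over the leading axis of a parameter array while sharing one time grid.

```python
import jax
import jax.numpy as jnp


def toy_sfh_single(tarr, params):
    # Hypothetical toy model (not diffstar's): SFR(t) = amp * t * exp(-t / tau)
    amp = params[0]
    tau = params[1]
    return amp * tarr * jnp.exp(-tarr / tau)


# in_axes=(None, 0): share tarr across galaxies, map over rows of params.
toy_sfh_pop = jax.vmap(toy_sfh_single, in_axes=(None, 0))

tarr = jnp.linspace(0.1, 13.8, 100)
params_pop = jnp.array([[1.0, 2.0], [3.0, 4.0], [0.5, 1.0]])

# Vectorized evaluation: one call, output shape (n_gals, n_times).
sfh_vmapped = toy_sfh_pop(tarr, params_pop)

# Equivalent explicit loop over galaxies, for comparison.
sfh_looped = jnp.stack([toy_sfh_single(tarr, p) for p in params_pop])
```

Both approaches give the same array; the vmapped version expresses the loop as a single batched computation.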

In [None]:
from matplotlib import pyplot as plt

fig, ax = plt.subplots(1, 1)
ylim = ax.set_ylim(1e-3, 50)
yscale = ax.set_yscale('log')

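# Dashed black curve: SFH of the single default galaxy
__ = ax.plot(tarr, sfh_gal, '--', color='k')

# One solid curve per galaxy in the noisy population
for igal in range(n_gals):
    __ = ax.plot(tarr, sfh_pop[igal, :])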

xlabel = ax.set_xlabel(r'${\rm cosmic\ time\ [Gyr]}$')
ylabel = ax.set_ylabel(r'${\rm SFR\ [M_{\odot}/yr]}$')