# Calculate ${\rm SFH}(t)$ for a population of galaxies selected from the catalog

This notebook shows how to use the Diffstar source code to calculate ${\rm SFH}(t)$ for a large galaxy population. 

### Read in diffmah and diffstar parameters from the catalog

We will read in the catalog, select samples of red and blue galaxies, and compute their star-formation histories.

In [None]:
import sys
# Put a specific gcr-catalogs checkout at the front of the module search path
# before GCRCatalogs is imported (presumably so this checkout is the one that
# gets imported -- confirm).
# NOTE(review): absolute, user-specific path; this cell is not portable as written.
sys.path.insert(0, '/global/homes/k/kovacs/gcr-catalogs_diffsky_v0.1')
import numpy as np
import GCRCatalogs
# Load the 'diffsky_v0.1_p3765_vsmall' catalog.
# NOTE(review): the name `gc` is also the name of a standard-library module.
gc = GCRCatalogs.load_catalog('diffsky_v0.1_p3765_vsmall')
# Sorted list of every native quantity the catalog reports.
# NOTE(review): not referenced again in the cells visible here.
native_quantities = sorted(gc.list_all_native_quantities())
# Parameter-name lists; used below to select catalog columns.
from lsstdesc_diffsky.constants import MAH_PNAMES, MS_U_PNAMES, Q_U_PNAMES
print(MAH_PNAMES, MS_U_PNAMES, Q_U_PNAMES)
print(gc.cosmology)

#### Get params from catalog

In [None]:
# Fetch every column named in the three parameter-name lists, plus the
# redshift and the r and i magnitudes.
cat_data = gc.get_quantities(
    MAH_PNAMES + MS_U_PNAMES + Q_U_PNAMES + ['redshift', 'mag_r', 'mag_i']
)

## Define r-i color and define cuts for a red and blue sample

In [None]:
# r-i color, stored alongside the other catalog columns.
cat_data['r-i'] = cat_data['mag_r'] - cat_data['mag_i']

# Both samples keep only redshift < 1.0; they differ in the color cut.
red_mask = (cat_data['redshift'] < 1.0) & (cat_data['r-i'] > 1.3)
blue_mask = (cat_data['redshift'] < 1.0) & (cat_data['r-i'] < 0.1)

# Number of selected galaxies: red sample first, then blue.
print(np.count_nonzero(red_mask))
print(np.count_nonzero(blue_mask))

### Build the JAX kernel and compute the SFHs

This calculation involves a natural tradeoff between memory usage and compute cycles. Different choices of the input kwargs `tobs_loop` and `galpop_loop` perform differently on CPUs and GPUs.

In [None]:
# Age of the catalog cosmology at redshift 0, as a plain number
# (presumably in Gyr, matching the time axis of the plot -- confirm).
T0 = gc.cosmology.age(0.0).value
# Base-10 logarithm of T0; passed to get_sfh_from_params below.
LGT0 = np.log10(T0)

Collect SFH parameters into arrays; define time array

In [None]:
# Stack the catalog columns for each parameter family and transpose, giving
# one row per galaxy and one column per parameter name (assuming every
# catalog column is 1-D).
mah_params = np.transpose([cat_data[name] for name in MAH_PNAMES])
ms_u_params = np.transpose([cat_data[name] for name in MS_U_PNAMES])
q_u_params = np.transpose([cat_data[name] for name in Q_U_PNAMES])
print(mah_params.shape, ms_u_params.shape, q_u_params.shape)

# 100 evenly spaced times from 0.1 up to T0.
times = np.linspace(0.1, T0, 100)

In [None]:
from lsstdesc_diffsky.photometry.get_SFH_from_params import get_sfh_from_params

# Evaluate the star-formation histories of each color-selected subsample on
# the shared `times` array.
sfh_red = get_sfh_from_params(
    mah_params[red_mask],
    ms_u_params[red_mask],
    q_u_params[red_mask],
    LGT0,
    times,
)
sfh_blue = get_sfh_from_params(
    mah_params[blue_mask],
    ms_u_params[blue_mask],
    q_u_params[blue_mask],
    LGT0,
    times,
)
print(sfh_red.shape, sfh_blue.shape)

### Plot a few example SFHs from our red and blue subsamples

In [None]:
%matplotlib inline
from matplotlib import pyplot as plt
fig, ax_all = plt.subplots(1, 2, figsize=(15, 5), sharey=True)

for ax, sfh, color in zip(ax_all.flat, [sfh_red, sfh_blue], ['r', 'blue']):
    
    for __ in range(25):
        iplot = np.random.randint(0, sfh.shape[0])
        __ = ax.plot(times, sfh[iplot, :], lw=0.5, color=color)
        
    ax.set_yscale('log')
    ax.set_xlabel('Time (Gyr')
    ax.set_ylabel('SFR ($M_\odot$/yr)')

fig.savefig('sfh.png')