# Computing photometry of Diffsky galaxies

This notebook demonstrates how to compute photometry of a population of Diffsky galaxies through arbitrary bandpasses.
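
In each band, the quantity being computed is a broadband magnitude. The sketch below uses the standard AB definition for a photon-counting detector, m_AB = -2.5 log10[∫ f_ν T(λ) dλ/λ / ∫ (3631 Jy) T(λ) dλ/λ]. It is a generic illustration with a made-up flat SED and a top-hat filter (wavelengths assumed to be in Angstroms), not necessarily how DSPS or diffsky evaluate the integral internally.

```python
import numpy as np

AB_ZEROPOINT_JY = 3631.0  # AB reference flux density in Jansky


def trapezoid(y, x):
    """Trapezoidal integral of y(x), written out to avoid NumPy version differences."""
    return np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x))


def ab_mag(wave, fnu_jy, trans):
    """AB magnitude of an SED f_nu(lambda) [Jy] observed through transmission T(lambda).

    Both integrals carry the photon-counting weight d(lambda)/lambda.
    """
    num = trapezoid(fnu_jy * trans / wave, wave)
    den = trapezoid(AB_ZEROPOINT_JY * trans / wave, wave)
    return -2.5 * np.log10(num / den)


# Toy example: a top-hat filter from 4000 to 5000 Angstrom
wave = np.linspace(3000.0, 6000.0, 3001)
trans = np.where((wave >= 4000.0) & (wave <= 5000.0), 1.0, 0.0)

flat_ab0 = np.full_like(wave, AB_ZEROPOINT_JY)  # flat f_nu exactly at the AB zeropoint
print(ab_mag(wave, flat_ab0, trans))            # ~0.0
print(ab_mag(wave, flat_ab0 / 100.0, trans))    # ~5.0 (100 times fainter)
```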

In [None]:
# ! wget -q https://portal.nersc.gov/project/hacc/aphearin/DSPS_data/ssp_data_fsps_v3.2_lgmet_age.sparse.h5

## Load the mock data

These cells show how to load the mock as it is natively produced by the diffsky source code. For the purposes of this demo, we work with a single lightcone patch of the natively generated HDF5 files, which is loaded from disk in the next cell.

#### Using the OpenCosmo toolkit
The hdf5 files produced by diffsky are later ingested by the [OpenCosmo](https://opencosmo.readthedocs.io/en/stable/) toolkit, which enables efficient querying, map-making, and other features. See [Accessing and Working With Diffsky Mock Galaxy Catalogs](https://github.com/ArgonneCPAC/opencosmo-examples/blob/main/03-Diffsky/demo_diffmah_diffstar.ipynb) for a tutorial on how to load diffsky mocks with OpenCosmo.

In [None]:
import os

import numpy as np
from diffsky.data_loaders import load_flat_hdf5

# Local directory holding the natively generated diffsky lightcone files
drn_mock = "/Users/aphearin/work/DATA/random_data/1122/smdpl_dr1"
basename_list = ["lc_cores-411.0.diffsky_gals.hdf5", ]

# Load each lightcone patch as a dict of column arrays
mock_collector = []
for bn in basename_list:
    fn_mock = os.path.join(drn_mock, bn)
    mock_this_bn = load_flat_hdf5(fn_mock, dataset='data')
    mock_collector.append(mock_this_bn)

# Concatenate the columns of all patches into a single dict
mock_data = {key: np.concatenate([mock[key] for mock in mock_collector]) for key in mock_collector[0].keys()}

In [None]:
from diffsky.data_loaders.hacc_utils import load_lc_mock as llcm

mock_info = llcm.load_lc_mock_info(fn_mock)
print(mock_info.keys())

In [None]:
ssp_mag_table, wave_eff_table = llcm.get_ssp_phot_tables(
    mock_info['tcurves'], mock_info['z_phot_table'], mock_info['ssp_data'], mock_info['sim_info'])

In [None]:
phot_data = llcm.compute_mock_photometry(
    mock_data, mock_info, mock_info['tcurves'], ssp_mag_table, wave_eff_table)
phot_data.keys()

#### Validate recalculation of LSST magnitudes

The next two cells verify that recomputing photometry through the LSST bands reproduces the values stored as columns in the mock.

In [None]:
import matplotlib.pyplot as plt

# Compare lsst_u (column 0 of obs_mags) against the column stored in the mock
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 4))
__=ax0.scatter(mock_data['lsst_u'], phot_data['obs_mags'][:, 0], s=1)
__=ax1.hist(phot_data['obs_mags'][:, 0]-mock_data['lsst_u'], alpha=0.7)
xlabel = ax0.set_xlabel('lsst_u mock')
ylabel = ax0.set_ylabel('lsst_u recomputed')
xlabel = ax1.set_xlabel('recomputed - mock')

In [None]:
assert np.allclose(mock_data['lsst_u'], phot_data['obs_mags'][:, 0], rtol=1e-4)
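
`np.allclose(a, b, rtol=1e-4)` tests |a - b| <= atol + rtol * |b| with the default `atol=1e-8`, so for magnitudes near 25 the relative tolerance allows differences of roughly 0.0025 mag. Because magnitudes are already logarithmic, it can be clearer to report residuals directly in magnitudes. The helper below is a sketch; `summarize_mag_residuals` is a hypothetical name (not part of diffsky), and the 0.005 mag tolerance is illustrative. In this notebook it could be called as `summarize_mag_residuals(mock_data['lsst_u'], phot_data['obs_mags'][:, 0])`.

```python
import numpy as np


def summarize_mag_residuals(mags_ref, mags_new, atol_mag=0.005):
    """Summarize recomputed-minus-reference magnitude residuals.

    atol_mag is an illustrative absolute tolerance in magnitudes, not a diffsky default.
    Non-finite residuals (NaN/inf) are excluded from the statistics.
    """
    resid = np.asarray(mags_new) - np.asarray(mags_ref)
    finite = np.isfinite(resid)
    if not finite.any():
        return {"n_finite": 0, "median": np.nan, "max_abs": np.nan, "passes": False}
    max_abs = float(np.max(np.abs(resid[finite])))
    return {
        "n_finite": int(finite.sum()),
        "median": float(np.median(resid[finite])),
        "max_abs": max_abs,
        "passes": max_abs <= atol_mag,
    }


# Toy usage with synthetic magnitudes and tiny perturbations
rng = np.random.default_rng(0)
ref = rng.uniform(20.0, 28.0, size=1000)
new = ref + rng.normal(0.0, 1e-4, size=ref.size)
print(summarize_mag_residuals(ref, new))
```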

In [None]:
mock_info['tcurves']._fields

## Computing photometry in other bands

To compute photometry in other bands, repeat the calculation above with a different set of transmission curves. To use your own transmission curves, format each one as a `namedtuple` with two fields, `wave` and `transmission`, and then bundle the collection into a second namedtuple with one field per bandpass. Below we demonstrate this using a few transmission curves associated with the COSMOS-20 dataset that are provided as part of DSPS.
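
The required format can also be illustrated with synthetic bandpasses, without any downloads. In this sketch the band names `tophat_blue`/`tophat_red`, the helper `tophat_tcurve`, the wavelength grid, and the assumption that wavelengths are in Angstroms are all hypothetical; only the two-level namedtuple structure mirrors the text.

```python
from collections import namedtuple

import numpy as np

TransmissionCurve = namedtuple("TransmissionCurve", ("wave", "transmission"))


def tophat_tcurve(wave_lo, wave_hi, n=500, pad=200.0):
    """Hypothetical helper: top-hat transmission curve on an increasing wavelength grid."""
    wave = np.linspace(wave_lo - pad, wave_hi + pad, n)
    trans = ((wave >= wave_lo) & (wave <= wave_hi)).astype(float)
    return TransmissionCurve(wave, trans)


# One field per bandpass, each holding a TransmissionCurve
band_names = ("tophat_blue", "tophat_red")
Tcurves = namedtuple("Tcurves", band_names)
toy_tcurves = Tcurves(tophat_tcurve(4000.0, 5000.0), tophat_tcurve(6000.0, 7000.0))

print(toy_tcurves._fields)                # ('tophat_blue', 'tophat_red')
print(toy_tcurves.tophat_red.wave.shape)  # (500,)
```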

In [None]:
# ! wget -q https://portal.nersc.gov/project/hacc/aphearin/DSPS_data/filters/J_uv_transmission.h5
# ! wget -q https://portal.nersc.gov/project/hacc/aphearin/DSPS_data/filters/H_uv_transmission.h5
# ! wget -q https://portal.nersc.gov/project/hacc/aphearin/DSPS_data/filters/K_uv_transmission.h5

In [None]:
import h5py
from collections import namedtuple

TransmissionCurve = namedtuple("TransmissionCurve", ("wave", "transmission"))

filter_nicknames = ("J_uv", "H_uv", "K_uv")
bname_list = [s + "_transmission.h5" for s in filter_nicknames]
tcurve_collector = []
for bname in bname_list:
    with h5py.File(bname, 'r') as hdf:
        tcurve = TransmissionCurve(hdf['wave'][:], hdf['transmission'][:])
    tcurve_collector.append(tcurve)

tcurves = namedtuple("Tcurves", filter_nicknames)(*tcurve_collector)

print(tcurves._fields)
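
Before passing custom curves to `get_ssp_phot_tables`, a quick sanity check can catch formatting problems. The function `check_tcurves` below is hypothetical, and its checks reflect general requirements of linear interpolation and numerical integration (for example, `np.interp` assumes an increasing `xp`), not documented diffsky requirements. In this notebook it could be run as `check_tcurves(tcurves)`.

```python
from collections import namedtuple

import numpy as np


def check_tcurves(tcurves):
    """Hypothetical sanity checks on a namedtuple of (wave, transmission) curves.

    Returns a list of problem descriptions; an empty list means no issues were found.
    """
    problems = []
    for name, tc in zip(tcurves._fields, tcurves):
        wave = np.asarray(tc.wave)
        trans = np.asarray(tc.transmission)
        if wave.shape != trans.shape:
            problems.append(f"{name}: wave and transmission shapes differ")
            continue
        if np.any(np.diff(wave) <= 0):
            problems.append(f"{name}: wave is not strictly increasing")
        if np.any(trans < 0):
            problems.append(f"{name}: negative transmission values")
        if not np.any(trans > 0):
            problems.append(f"{name}: transmission is zero everywhere")
    return problems


# Toy usage: one well-formed curve and one malformed curve
TC = namedtuple("TransmissionCurve", ("wave", "transmission"))
good = TC(np.linspace(1.0, 2.0, 5), np.ones(5))
bad = TC(np.array([2.0, 1.0, 3.0]), np.array([0.5, -0.1, 0.2]))
toy = namedtuple("Tcurves", ("good", "bad"))(good, bad)
print(check_tcurves(toy))
```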

In [None]:
ssp_mag_table, wave_eff_table = llcm.get_ssp_phot_tables(
    tcurves, mock_info['z_phot_table'], mock_info['ssp_data'], mock_info['sim_info'])

In [None]:
# phot_data2 = llcm.compute_mock_photometry(mock_data, mock_info, tcurves, ssp_mag_table, wave_eff_table)

In [None]:
mock_data['delta_scatter_ms'].shape

In [None]:
from diffsky.experimental import lc_phot_kern

wave_eff_galpop = lc_phot_kern.interp_vmap2(
    mock_data["redshift_true"], mock_info["z_phot_table"], wave_eff_table
)
wave_eff_galpop.shape
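
The cell above interpolates a quantity tabulated on the redshift grid `z_phot_table` to the true redshift of each galaxy. The NumPy sketch below shows that idea under assumptions not stated in the source: that `wave_eff_table` has shape `(n_z, n_bands)` and that the result has shape `(n_gals, n_bands)`. It is an assumed analogue of `lc_phot_kern.interp_vmap2`, not its actual implementation, and the toy table is chosen to be linear in redshift so that the interpolation is exact.

```python
import numpy as np


def interp_table_to_galaxies(z_gals, z_table, table):
    """Linearly interpolate each column of `table` (assumed shape (n_z, n_bands))
    to every galaxy redshift, returning an array of shape (n_gals, n_bands)."""
    table = np.asarray(table)
    return np.stack(
        [np.interp(z_gals, z_table, table[:, iband]) for iband in range(table.shape[1])],
        axis=1,
    )


# Toy table: two bands whose tabulated values scale as (1 + z)
z_table = np.linspace(0.0, 3.0, 31)
toy_table = np.outer(1.0 + z_table, [3500.0, 5000.0])  # shape (31, 2)

z_gals = np.array([0.5, 1.25])
print(interp_table_to_galaxies(z_gals, z_table, toy_table))
# values close to [[5250, 7500], [7875, 11250]]
```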

In [None]:
from diffsky.ssp_err_model import ssp_err_model
ssp_err_model.LAMBDA_REST

In [None]:
from diffsky import phot_utils
rest_wave_eff = phot_utils.get_wave_eff_from_tcurves(tcurves, 0.0)
rest_wave_eff
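
There are several conventions for the effective wavelength of a bandpass, and the one used by `phot_utils.get_wave_eff_from_tcurves` is not shown here. As an illustration only, the sketch computes the transmission-weighted mean wavelength, λ_eff = ∫ λ T dλ / ∫ T dλ, and applies the standard conversion λ_rest = λ_obs / (1 + z). Reading the `0.0` argument above as a redshift is an assumption.

```python
import numpy as np


def trapezoid(y, x):
    """Trapezoidal integral of y(x)."""
    return np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x))


def mean_wavelength(wave, trans):
    """Transmission-weighted mean wavelength (one of several effective-wavelength conventions)."""
    return trapezoid(wave * trans, wave) / trapezoid(trans, wave)


# Symmetric top-hat from 4000 to 5000 Angstrom on a 1 Angstrom grid
wave = np.linspace(3800.0, 5200.0, 1401)
trans = np.where((wave >= 4000.0) & (wave <= 5000.0), 1.0, 0.0)

lam_obs = mean_wavelength(wave, trans)  # ~4500 for a symmetric top-hat
z = 1.0
lam_rest = lam_obs / (1.0 + z)          # rest-frame wavelength probed at redshift z
print(lam_obs, lam_rest)
```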

In [None]:
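from jax import jit as jjit
from jax import numpy as jnp
from jax import vmap

# Map jnp.interp over the rows of its third argument (fp); x and xp are shared
interp_vmap = jjit(vmap(jnp.interp, in_axes=(None, None, 0)))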
# interp_vmap(wave_eff_galpop, ssp_err_model.LAMBDA_REST, mock_data['delta_scatter_ms']).shape

delta_scatter_obs_table = interp_vmap(rest_wave_eff, ssp_err_model.LAMBDA_REST, mock_data['delta_scatter_ms'])
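
With `in_axes=(None, None, 0)`, the evaluation points `x` and the nodes `xp` are shared, while `fp` is mapped over its leading axis. In the cell above that means each row of `mock_data['delta_scatter_ms']` (inferred from this `in_axes` choice to have shape `(n_gals, len(LAMBDA_REST))`) is evaluated at `rest_wave_eff`. The plain NumPy loop below is an equivalent sketch with toy numbers; `interp_rows` is a hypothetical name.

```python
import numpy as np


def interp_rows(x, xp, fp_rows):
    """NumPy analogue of vmap(jnp.interp, in_axes=(None, None, 0)).

    Evaluates each row of fp_rows (shape (n_gals, len(xp))) at the points x,
    returning an array of shape (n_gals, len(x)).
    """
    return np.array([np.interp(x, xp, fp) for fp in fp_rows])


lambda_rest = np.array([1000.0, 3000.0, 5000.0, 9000.0])  # toy node wavelengths
delta_scatter = np.array([[0.0, 0.1, 0.2, 0.4],
                          [1.0, 1.0, 1.0, 1.0]])           # toy per-galaxy values at the nodes
band_waves = np.array([2000.0, 4000.0, 7000.0])            # toy band effective wavelengths

out = interp_rows(band_waves, lambda_rest, delta_scatter)
print(out.shape)  # (2, 3)
print(out)        # approximately [[0.05, 0.15, 0.3], [1.0, 1.0, 1.0]]
```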

In [None]:
jnp.interp(3.0, ssp_err_model.LAMBDA_REST, ssp_err_model.LAMBDA_REST*5)
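
The scratch check above probes how the interpolation behaves. `np.interp`, and `jnp.interp` with its default `left`/`right` arguments, do not extrapolate: points below `xp[0]` return `fp[0]`, and points above `xp[-1]` return `fp[-1]`. Band wavelengths outside the tabulated `LAMBDA_REST` range would therefore silently receive endpoint values, so it can be worth flagging them. The nodes below are toy values, not the actual `LAMBDA_REST` grid.

```python
import numpy as np

xp = np.array([1000.0, 2000.0, 4000.0])  # toy interpolation nodes
fp = 5.0 * xp  # same pattern as the check above: fp is a linear function of xp

print(np.interp(1500.0, xp, fp))  # 7500.0 (inside the range: linear interpolation)
print(np.interp(3.0, xp, fp))     # 5000.0 (below the range: clamped to fp[0])
print(np.interp(9000.0, xp, fp))  # 20000.0 (above the range: clamped to fp[-1])

# Flag evaluation points that fall outside the tabulated range
x = np.array([3.0, 1500.0, 9000.0])
out_of_range = (x < xp[0]) | (x > xp[-1])
print(out_of_range)  # [ True False  True]
```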