In [None]:
import GCRCatalogs
import pandas as pd
import tables_io

In [None]:
import os

# Make GCRCatalogs read its catalog configurations from files, then list the
# available catalogs whose name contains "rubin".
os.environ.update(GCR_CONFIG_SOURCE="files")
GCRCatalogs.get_available_catalog_names(name_contains="rubin")

In [None]:
# Load the Roman-Rubin 2023 simulated catalog (v1.1.3, "elais" variant).
_catalog_name = "roman_rubin_2023_v1.1.3_elais"
rubinsim = GCRCatalogs.load_catalog(_catalog_name)

In [None]:
# Display every native quantity (column name) exposed by the loaded catalog.
rubinsim.list_all_native_quantities()

In [None]:
# Keep the observed LSST and Roman magnitudes (excluding the "nodust" variants)
# plus the redshift, restricted to a single HEALPix pixel.
HEALPIX_PIXEL = 10552
_native_cols = rubinsim.list_all_native_quantities()
rubinsim_relevantcols = [
    col
    for col in _native_cols
    if ("LSST_obs" in col or "ROMAN_obs" in col) and "nodust" not in col
]
rubinsim_relevantcols.append("redshift")
rubinsim_quantities = rubinsim.get_quantities(
    rubinsim_relevantcols,
    native_filters=[f"healpix_pixel == {HEALPIX_PIXEL}"],
)

In [None]:
# Inspect the raw object returned by get_quantities.
rubinsim_quantities

In [None]:
# Tabulate the selected quantities (presumably a dict of column arrays) as a DataFrame.
rubinsim_df = pd.DataFrame(rubinsim_quantities)

In [None]:
# Inspect the catalog DataFrame.
rubinsim_df

In [None]:
from rail.core.data import PqHandle
from rail.core.stage import RailStage

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from jax import numpy as jnp

In [None]:
# RAIL's shared data store, holding the handles produced by every stage.
DS = RailStage.data_store
# Presumably lets re-executed cells overwrite existing entries in the store instead of
# failing on duplicate names — TODO confirm against RAIL docs.
DS.__class__.allow_overwrite = True

In [None]:
# Native catalog column -> RAIL-style magnitude column.
rename_dict = {
    "LSST_obs_u": "mag_u_lsst",
    "LSST_obs_g": "mag_g_lsst",
    "LSST_obs_r": "mag_r_lsst",
    "LSST_obs_i": "mag_i_lsst",
    "LSST_obs_z": "mag_z_lsst",
    "LSST_obs_y": "mag_y_lsst",
    "ROMAN_obs_R062": "mag_wfi_f062_roman",
    "ROMAN_obs_Z087": "mag_wfi_f087_roman",
    "ROMAN_obs_Y106": "mag_wfi_f106_roman",
    "ROMAN_obs_J129": "mag_wfi_f129_roman",
    "ROMAN_obs_W146": "mag_wfi_f146_roman",
    "ROMAN_obs_H158": "mag_wfi_f158_roman",
    "ROMAN_obs_F184": "mag_wfi_f184_roman",
    "ROMAN_obs_K213": "mag_wfi_f213_roman",
}

# Error-model band name -> RAIL column, used as the photerr renameDict.
# LSST bands are the last "_" token of the native name (u, g, ...);
# Roman bands are the first letter of that token (R062 -> R, Z087 -> Z, ...).
band_dict = {}
for native_name, rail_name in rename_dict.items():
    if "lsst" in rail_name:
        band_dict[native_name.rsplit("_", 1)[-1]] = rail_name
for native_name, rail_name in rename_dict.items():
    if "roman" in rail_name:
        band_dict[native_name.rsplit("_", 1)[-1][0]] = rail_name
band_dict

In [None]:
# Rename the native catalog columns to the RAIL-style names (all columns are kept).
data_df = (
    rubinsim_df
    .rename(columns=rename_dict)
    .copy()
)
data_df

In [None]:
# Wrap the truth catalog in a RAIL handle named 'input' so degrader stages can consume it.
data_truth = PqHandle('input')
data_truth.set_data(data_df)

## Degrader 1: LSST & Roman errors

Now, we will demonstrate the `LSSTErrorModel`, which adds photometric errors using a model similar to the one from [Ivezić et al. 2019](https://arxiv.org/abs/0805.2366). Specifically, it uses the model from that paper without making the high-SNR assumption; to restore this assumption, and therefore use the exact model from the paper, set `highSNR=True`. The `RomanErrorModel` is then applied on top of it to add Roman photometric errors.

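Let's create the error models. We keep the default settings except for the band renaming (`band_dict`) and the NaN non-detection flag; the Roman model additionally overrides the Y-band depth: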

In [None]:
from rail.creation.degraders.photometric_errors import LSSTErrorModel
from rail.creation.degraders.photometric_errors import RomanErrorModel

In [None]:
# LSST photometric error model: bands are renamed through band_dict and
# non-detections are flagged with NaN.
lsst_stage_kwargs = dict(
    name="lsst_error_model",
    renameDict=band_dict,
    ndFlag=np.nan,
)
errorModel_lsst = LSSTErrorModel.make_stage(**lsst_stage_kwargs)
samples_w_lssterrs = errorModel_lsst(data_truth)

In [None]:
# Calling the output handle returns the sample with LSST errors, for inspection.
samples_w_lssterrs()

In [None]:
# Roman photometric error model, with the same band renaming and NaN non-detection
# flag as the LSST error model.
errorModel_Roman = RomanErrorModel.make_stage(
    name="roman_error_model",
    renameDict=band_dict,
    ndFlag=np.nan
)
# Override the 5-sigma limiting magnitude of the "Y" band (mapped to
# mag_wfi_f106_roman by band_dict) to 27.0.
errorModel_Roman.config['m5']['Y'] = 27.0
# NOTE(review): this cell previously also set config['theta']['Y'] = 27.0. In photerr,
# `theta` is the seeing/PSF FWHM in arcseconds, so 27.0 looks like a copy-paste of the
# m5 value and was removed. Set an explicit PSF size here if a non-default one is intended.

In [None]:
# Add Roman photometric errors on top of the LSST-degraded sample.
samples_w_romanerrs = errorModel_Roman(samples_w_lssterrs)

In [None]:
# Inspect the sample after both the LSST and the Roman error models.
samples_w_romanerrs()

In [None]:
#samples_w_lssterrs_df = tables_io.convertObj(samples_w_lssterrs.data, tables_io.types.PD_DATAFRAME)
#samples_w_romanerrs_df = tables_io.convertObj(samples_w_romanerrs.data, tables_io.types.PD_DATAFRAME)

In [None]:
#samples_w_errs_df = (samples_w_lssterrs_df.drop(columns=["W", "R", "Z", "Y", "J", "H", "F", "K"])).merge(samples_w_romanerrs_df.drop(columns=["u", "g", "r", "i", "z", "y"]))

In [None]:
#samples_w_errs = PqHandle('input')
#samples_w_errs.set_data(samples_w_errs_df)

In [None]:
#samples_w_errs()

## Degrader 2 : Quantity Cut (magnitude)

Recall how the sample above has galaxies as dim as magnitude 30. This is well beyond the LSST 5-sigma limiting magnitudes, so it will be useful to apply cuts to the data to filter out these super-dim samples. We can apply these cuts using the `QuantityCut` degrader. This degrader will cut out any samples that do not pass all of the specified cuts.

Photometric errors were added above. Now let's make and run a degrader that cuts at $i < 25.3$ (the LSST gold sample), applied to the noisy `mag_i_lsst` column.

If you look at the `mag_i_lsst` column, you will see there are no longer any samples with $i > 25.3$. The number of galaxies returned is substantially reduced compared to the input sample and, unlike with the `LSSTErrorModel` degrader, is not equal to the number of input objects. Users should note that with degraders that remove galaxies from the sample, the size of the output sample will not equal that of the input sample.

One more note: it is easy to use the `QuantityCut` degrader as an SNR cut on the magnitudes. The magnitude equation is $m = -2.5 \log_{10}(f) + \text{const}$. Taking the derivative (in absolute value), we have
$$
dm = \frac{2.5}{\ln(10)} \frac{df}{f} = \frac{2.5}{\ln(10)} \frac{1}{\mathrm{SNR}}.
$$
So if you want to make a cut on galaxies above a certain SNR, you can make a cut
$$
dm < \frac{2.5}{\ln(10)} \frac{1}{\mathrm{SNR}}.
$$
For example, an SNR cut on the LSST i band of this catalog would look like this: `QuantityCut.make_stage(name="snr_cut", cuts={"mag_i_lsst_err": 2.5/np.log(10) * 1/SNR})`.

In [None]:
from rail.creation.degraders.quantityCut import QuantityCut

In [None]:
# LSST gold sample: keep galaxies with (noisy) i-band magnitude below 25.3.
I_GOLD_LIMIT = 25.3
mag_cut = QuantityCut.make_stage(
    name="cuts",
    cuts={"mag_i_lsst": I_GOLD_LIMIT},
)
samples_mag = mag_cut(samples_w_romanerrs)

In [None]:
# Inspect the magnitude-cut sample.
samples_mag()

## Degrader 3 : Inv redshift incompleteness

Next, we will demonstrate the `InvRedshiftIncompleteness` degrader, applied to the magnitude-cut sample. It applies a selection function that keeps galaxies with probability $p_{\text{keep}}(z) = \min(1, \frac{z_p}{z})$, where $z_p$ is the "pivot" redshift. We'll use $z_p = 0.8$.

In [None]:
from rail.creation.degraders.spectroscopic_degraders import InvRedshiftIncompleteness

In [None]:
# Keep each galaxy with probability min(1, z_p / z), using pivot z_p = 0.8.
PIVOT_REDSHIFT = 0.8
inv_incomplete = InvRedshiftIncompleteness.make_stage(
    name="incompleteness",
    pivot_redshift=PIVOT_REDSHIFT,
)
samples_incomplete_mag = inv_incomplete(samples_mag)

In [None]:
# Inspect the sample after the redshift-incompleteness selection.
samples_incomplete_mag()

## Degrader 4: LineConfusion

`LineConfusion` is a degrader that simulates spectroscopic errors resulting from the confusion of different emission lines.

For this example, let's use the degrader to simulate a scenario in which 2% of [OII] lines are mistaken for [OIII] lines, and 1% of [OIII] lines are mistaken for [OII] lines. (Note: I do not know how realistic this scenario is!)

In [None]:
from rail.creation.degraders.spectroscopic_degraders import LineConfusion

In [None]:
# Emission-line wavelengths (presumably rest-frame, in Angstrom).
OII = 3727
OIII = 5007

# 2% of [OII] lines mistaken for [OIII].
lc_2p_0II_0III = LineConfusion.make_stage(
    name="lc_2p_0II_0III",
    true_wavelen=OII,
    wrong_wavelen=OIII,
    frac_wrong=0.02,
)
# 1% of [OIII] lines mistaken for [OII].
lc_1p_0III_0II = LineConfusion.make_stage(
    name="lc_1p_0III_0II",
    true_wavelen=OIII,
    wrong_wavelen=OII,
    frac_wrong=0.01,
)

# Apply the [OII]->[OIII] confusion first, then [OIII]->[OII].
_samples_conf_oii = lc_2p_0II_0III(samples_incomplete_mag)
samples_conf_inc_mag = lc_1p_0III_0II(_samples_conf_oii)

In [None]:
# Inspect the sample after both line-confusion degraders.
samples_conf_inc_mag()

## Check Data

In [None]:
# Compare the normalized redshift distribution after each degradation step.
zmin = 0
zmax = 3.1

hist_settings = dict(
    bins=50,
    range=(zmin, zmax),
    density=True,
    histtype="step",
)

_hist_samples = [
    ("Roman-Rubin sample", data_truth),
    ("Mag. cut", samples_mag),
    ("Incomplete Mag. Cut", samples_incomplete_mag),
    ("Confused Incomplete Mag. Cut", samples_conf_inc_mag),
]

fig, ax = plt.subplots(figsize=(5, 4), dpi=100)
for _label, _handle in _hist_samples:
    ax.hist(_handle()["redshift"], label=_label, **hist_settings)
ax.legend(title="Sample")
ax.set(xlim=(zmin, zmax), xlabel="Redshift", ylabel="Galaxy density")
plt.show()

## Run BPZ on the catalog

In [None]:
from rail.tools.table_tools import ColumnMapper, TableConverter
import tables_io


def _move_last_token_second(col):
    """Move the final '_'-separated token to second place: 'mag_i_lsst_err' -> 'mag_err_i_lsst'."""
    tokens = col.split("_")
    return "_".join([tokens[0], tokens[-1], *tokens[1:-1]])


# Rename every error column so it matches the mag_err_<band>_<survey> names given to BPZ.
rename_dict_bpz = {
    col: _move_last_token_second(col)
    for col in samples_conf_inc_mag.data.keys()
    if "err" in col
}
rename_dict_bpz

In [None]:
# Stage 1: rename the error columns; stage 2: convert the table to a numpy dict.
col_remapper = ColumnMapper.make_stage(name="col_remapper", columns=rename_dict_bpz)
table_conv = TableConverter.make_stage(name="table_conv", output_format="numpyDict")

data_colmap = col_remapper(samples_conf_inc_mag)
data_bpz = table_conv(data_colmap)

In [None]:
# Inspect the BPZ-ready data (error columns renamed, numpyDict format).
data_bpz()

In [None]:
# Convert the numpyDict back to a pandas DataFrame for inspection and sub-sampling.
data_bpz_df = tables_io.convertObj(data_bpz.data, tables_io.types.PD_DATAFRAME)
data_bpz_df

In [None]:
# Random training subset for the BPZ informer. A fixed random_state makes the subset,
# and therefore the trained model, reproducible across kernel restarts. The size is
# capped at the catalog length so small catalogs do not raise.
TRAIN_SIZE = 5000
TRAIN_SEED = 42
train_data_df = data_bpz_df.sample(
    n=min(TRAIN_SIZE, len(data_bpz_df)),
    random_state=TRAIN_SEED,
)
train_data_df

In [None]:
from rail.core.data import TableHandle

# Register the training subset in the data store, then convert it to a numpy dict.
train_data = DS.add_data("train_data", train_data_df, TableHandle)
table_conv_train = TableConverter.make_stage(name="table_conv_train", output_format="numpyDict")
train_data_conv = table_conv_train(train_data)

In [None]:
# Inspect the converted training data.
train_data_conv()

In [None]:
# Magnitude, error and filter names for the six LSST bands.
bands = ["u", "g", "r", "i", "z", "y"]
lsst_bands = [f"mag_{b}_lsst" for b in bands]
lsst_errs = [f"mag_err_{b}_lsst" for b in bands]
lsst_filts = [f"DC2LSST_{b}" for b in bands]
print(lsst_bands)
print(lsst_filts)

In [None]:
# Magnitude, error and filter names for the Roman bands used with BPZ.
robands = ["wfi_f106", "wfi_f129", "wfi_f158", "wfi_f184"]
roman_bands, roman_errs, roman_filts = [], [], []
for _rb in robands:
    roman_bands.append(f"mag_{_rb}_roman")
    roman_errs.append(f"mag_err_{_rb}_roman")
    roman_filts.append(f"roman_{_rb}")
print(roman_bands)
print(roman_filts)

In [None]:
from rail.estimation.algos.bpz_lite import BPZliteInformer, BPZliteEstimator

from rail.core.data import ModelHandle

# Root of the local RAIL source checkout. Override it with the RAILDIR environment
# variable instead of editing this user-specific absolute path.
RAILDIR = os.environ.get("RAILDIR", "/global/u2/j/jcheval/rail_base/src")
# BPZ example-data directory: holds the COSMOS prior file and is BPZ's data_path.
BPZ_DATA_PATH = os.path.join(RAILDIR, "rail/examples_data/estimation_data/data")

cosmospriorfile = os.path.join(BPZ_DATA_PATH, "COSMOS31_HDFN_prior.pkl")
cosmosprior = DS.read_file("cosmos_prior", ModelHandle, cosmospriorfile)
# Bare file name; presumably resolved by BPZ under data_path (full path would be
# <BPZ_DATA_PATH>/SED/COSMOS_seds.list) — TODO confirm.
sedfile = "COSMOS_seds.list"

inform_bpz = BPZliteInformer.make_stage(
    name="inform_bpz",
    nondetect_val=np.nan,
    model="bpz.pkl",
    # NOTE(review): train_data_conv is a flat numpyDict produced by TableConverter;
    # confirm that reading it under the "photometry" group is what is intended.
    hdf5_groupname="photometry",
    data_path=BPZ_DATA_PATH,
)

inform_bpz.inform(train_data_conv)

In [None]:
# BPZ estimator using the model trained by inform_bpz, without a prior. No bands or
# spectra_file are passed, so the stage defaults apply. hdf5_groupname is empty because
# the input (data_bpz) is a flat numpyDict.
estimate_bpz = BPZliteEstimator.make_stage(
    name="estimate_bpz",
    model=inform_bpz.get_handle("model"),
    data_path="/global/u2/j/jcheval/rail_base/src/rail/examples_data/estimation_data/data",
    hdf5_groupname="",
    nondetect_val=np.nan,
    no_prior=True,
)

In [None]:
# Run the default-configuration BPZ estimator on the full degraded catalog.
bpz_estimated = estimate_bpz.estimate(data_bpz)

In [None]:
# Point estimate: mode of each BPZ PDF on a 301-point grid from 0.01 to 3 + 4*0.15.
_zmode_grid = np.linspace(0.01, 3. + 4. * 0.15, 301, endpoint=True)
z_phot = bpz_estimated.data.mode(grid=_zmode_grid)

In [None]:
# Shape of the array of photo-z point estimates.
z_phot.shape

In [None]:
# True redshifts of the galaxies passed to BPZ, used as z_spec in the plots below.
z_true = data_bpz()['redshift']

In [None]:
# Photo-z vs spec-z for the default BPZ run. Dotted line: z_phot = z_spec;
# solid lines: z_spec +/- OUTLIER_WIDTH * (1 + z_spec).
OUTLIER_WIDTH = 0.15
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
zs = np.linspace(0.01, 3., 100)
ax.scatter(z_true, z_phot, alpha=0.1, s=2, label='BPZ, no prior, LSST filters\nDefault SED templates set')
ax.plot(zs, zs, 'k:')
ax.plot(zs, zs + (1 + zs) * OUTLIER_WIDTH, 'k-')
ax.plot(zs, zs - (1 + zs) * OUTLIER_WIDTH, 'k-')
ax.set_xlabel('z_spec')
ax.set_ylabel('z_phot')
ax.set_xlim(-0.01, 3.1)
ax.set_ylim(-0.01, 3.1)
ax.set_aspect('equal', 'box')
ax.grid()
ax.legend()
plt.show()

### Try with non-default SEDs but just LSST bands

In [None]:
# BPZ configuration: COSMOS SED templates, LSST bands only, no prior.
cosmos_dict = {
    "hdf5_groupname": "photometry",
    "output": "bpz_results_COSMOS_SEDs_LSST.hdf5",
    "spectra_file": sedfile,
    "bands": lsst_bands,
    "err_bands": lsst_errs,
    "filter_list": lsst_filts,
    "prior_band": "mag_i_lsst",
    "no_prior": True,
}

In [None]:
# BPZ estimator using the COSMOS prior model and the LSST-only configuration above.
run_newseds = BPZliteEstimator.make_stage(
    name="bpz_newseds_lsst",
    model=cosmosprior,
    data_path="/global/u2/j/jcheval/rail_base/src/rail/examples_data/estimation_data/data",
    **cosmos_dict,
)

In [None]:
# Run BPZ with the COSMOS SED templates on the LSST bands only.
newseds_bpz_estimated = run_newseds.estimate(data_bpz)

In [None]:
# Mode of each PDF on the same 301-point grid (0.01 to 3 + 4*0.15) as the default run.
_zmode_grid = np.linspace(0.01, 3. + 4. * 0.15, 301, endpoint=True)
z_phot_new = newseds_bpz_estimated.data.mode(grid=_zmode_grid)

In [None]:
# Photo-z vs spec-z for BPZ with COSMOS templates and LSST bands. Dotted line:
# z_phot = z_spec; solid lines: z_spec +/- OUTLIER_WIDTH * (1 + z_spec).
OUTLIER_WIDTH = 0.15
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
zs = np.linspace(0.01, 3., 100)
ax.scatter(z_true, z_phot_new, alpha=0.2, s=4, label='BPZ, no prior, LSST filters\nPolletta et al 2007 and BC03 SED templates set')
ax.plot(zs, zs, 'k:')
ax.plot(zs, zs + (1 + zs) * OUTLIER_WIDTH, 'k-')
ax.plot(zs, zs - (1 + zs) * OUTLIER_WIDTH, 'k-')
ax.set_xlabel('z_spec')
ax.set_ylabel('z_phot')
ax.set_xlim(-0.01, 3.1)
ax.set_ylim(-0.01, 3.1)
ax.set_aspect('equal', 'box')
ax.grid()
ax.legend()
plt.show()

### Try with non-default SEDs and LSST + Roman (F106, F129, F158, F184) bands

In [None]:
# BPZ configuration: COSMOS SED templates, LSST + Roman bands, no prior,
# a magnitude limit of 28 and a zero-point error of 0.01 for every band.
_all_bands = lsst_bands + roman_bands
cosmos_roman_dict = {
    "hdf5_groupname": "photometry",
    "output": "bpz_results_COSMOS_SEDs_LSST_ROMAN.hdf5",
    "spectra_file": sedfile,
    "bands": lsst_bands + roman_bands,
    "err_bands": lsst_errs + roman_errs,
    "filter_list": lsst_filts + roman_filts,
    "mag_limits": dict.fromkeys(_all_bands, 28),
    "zp_errors": np.full(len(_all_bands), 0.01),
    "prior_band": "mag_i_lsst",
    "no_prior": True,
}

In [None]:
# BPZ estimator using the COSMOS prior model and the LSST + Roman configuration above.
run_newseds_roman = BPZliteEstimator.make_stage(
    name="bpz_newseds_lsstRoman",
    model=cosmosprior,
    data_path="/global/u2/j/jcheval/rail_base/src/rail/examples_data/estimation_data/data",
    **cosmos_roman_dict,
)

In [None]:
# Run BPZ with the COSMOS SED templates on the LSST + Roman bands.
roman_bpz_estimated = run_newseds_roman.estimate(data_bpz)

In [None]:
# Mode of each PDF on the same 301-point grid (0.01 to 3 + 4*0.15) as the other runs.
_zmode_grid = np.linspace(0.01, 3. + 4. * 0.15, 301, endpoint=True)
z_phot_roman = roman_bpz_estimated.data.mode(grid=_zmode_grid)

In [None]:
# Photo-z vs spec-z for BPZ with COSMOS templates and LSST + Roman bands. Dotted line:
# z_phot = z_spec; solid lines: z_spec +/- OUTLIER_WIDTH * (1 + z_spec).
OUTLIER_WIDTH = 0.15
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
zs = np.linspace(0.01, 3., 100)
ax.scatter(z_true, z_phot_roman, alpha=0.2, s=4, label='BPZ, no prior, LSST+Roman filters\nPolletta et al 2007 and BC03 SED templates set')
ax.plot(zs, zs, 'k:')
ax.plot(zs, zs + (1 + zs) * OUTLIER_WIDTH, 'k-')
ax.plot(zs, zs - (1 + zs) * OUTLIER_WIDTH, 'k-')
ax.set_xlabel('z_spec')
ax.set_ylabel('z_phot')
ax.set_xlim(-0.01, 3.1)
ax.set_ylim(-0.01, 3.1)
ax.set_aspect('equal', 'box')
ax.grid()
ax.legend()
plt.show()

## Save as appropriate input for process_fors2.photoZ

In [None]:
# Column renaming for the process_fors2.photoZ input catalog:
#   redshift               -> z_spec
#   mag_<b>_lsst_err       -> mag_err_<b>_lsst   (LSST magnitudes keep the name mag_<b>_lsst)
#   mag_wfi_<f>_roman      -> mag_roman_wfi_<f>
#   mag_wfi_<f>_roman_err  -> mag_err_roman_wfi_<f>
# so that each magnitude column `mag_X` has its error in `mag_err_X`.
# The redshift -> z_spec entry is required: later cells read final_cat_df['z_spec'].


def _reorder_tokens(col, order):
    """Split `col` on '_' and re-join the tokens selected by the indices in `order`."""
    parts = col.split("_")
    return "_".join(parts[idx] for idx in order)


_catalog_cols = list(samples_conf_inc_mag.data.keys())

rerename_dict = {"redshift": "z_spec"}
rerename_dict.update(
    {c: _reorder_tokens(c, (0, -1, -3, -2)) for c in _catalog_cols if "lsst" in c and "err" in c}
)
rerename_dict.update(
    {c: _reorder_tokens(c, (0, -1, -3, -2)) for c in _catalog_cols if "roman" in c and "err" not in c}
)
rerename_dict.update(
    {c: _reorder_tokens(c, (0, -1, -2, -4, -3)) for c in _catalog_cols if "roman" in c and "err" in c}
)
rerename_dict

In [None]:
from rail.tools.table_tools import ColumnMapper

# Apply the process_fors2 renaming to the fully degraded sample and display the result.
col_remapper_proF2 = ColumnMapper.make_stage(name="col_remapper_proF2", columns=rerename_dict)
cat_for_processf2 = col_remapper_proF2(samples_conf_inc_mag)
cat_for_processf2()

In [None]:
# Convert the renamed catalog to a pandas DataFrame.
final_cat_df = tables_io.convertObj(cat_for_processf2.data, tables_io.types.PD_DATAFRAME)
final_cat_df

In [None]:
# Magnitude columns (not errors) whose matching `mag_err_<rest>` column is absent.
cols_to_drop = []
for _col in final_cat_df.columns:
    _is_magnitude = "mag" in _col and "err" not in _col
    if _is_magnitude and f"mag_err_{_col.partition('_')[2]}" not in final_cat_df.columns:
        cols_to_drop.append(_col)
cols_to_drop

In [None]:
# Drop magnitudes without an error column. Rebinding (instead of inplace=True) and
# errors="ignore" make this cell safe to re-run without a KeyError.
final_cat_df = final_cat_df.drop(columns=cols_to_drop, errors="ignore")

In [None]:
# Inspect the final catalog after dropping magnitudes that lack an error column.
final_cat_df

In [None]:
# Set to False to skip writing the catalog to disk.
SAVE_CATALOG = True
if SAVE_CATALOG:
    final_cat_df.to_hdf('magszgalaxies_lsstroman_gold_hp10552.h5', key='photometry')

## Test distribution manipulations (qp histogram ensemble and PIT metrics)

In [None]:
import qp

In [None]:
# Histogram of the true redshifts: 301 bins over [zmin, zmax].
zmin, zmax = 0.01, 3.1
hcounts, hbins = np.histogram(final_cat_df["z_spec"], bins=301, range=(zmin, zmax))
hbins, hcounts

In [None]:
# Ensemble of N_PDFS identical histogram PDFs, each equal to the z_spec histogram.
# N_PDFS matches the 10 true redshifts used for the PIT test.
# np.tile replaces np.row_stack, which is deprecated in NumPy 2.0.
N_PDFS = 10
ens_h = qp.Ensemble(qp.hist, data=dict(bins=hbins, pdfs=np.tile(hcounts, (N_PDFS, 1))))
# NOTE(review): `grid` is not used by the cells visible in this notebook.
grid = np.linspace(zmin, zmax, 302, endpoint=True)

In [None]:
# Plot the histogram ensemble over the redshift range used to build it.
ens_h.plot(xlim=(zmin, zmax))

In [None]:
# True redshifts of the first 10 galaxies, one per PDF in ens_h.
ztrue=final_cat_df['z_spec'].values[:10]

In [None]:
# Probability Integral Transform of the ensemble evaluated at the true redshifts.
hPIT = qp.metrics.pit.PIT(ens_h, ztrue)

In [None]:
# Summary meta-metrics computed from the PIT.
meta_metr=hPIT.calculate_pit_meta_metrics()

In [None]:
# Display the PIT meta-metrics.
meta_metr