# Rubin LSST DESC DC2: Accessing variable truth information (for stars and SNe) with GCRCatalogs

**Authors**: Yao-Yuan Mao (@yymao)

This notebook will illustrate the basics of accessing the variable truth information, such as the light curves for stars and SNe.

**Prerequisite (optional)**: The Object Table tutorial provides an introduction to `GCRCatalogs`. It is not strictly needed in this tutorial, but if you have questions about the use of `GCRCatalogs`, please check out the Object Table tutorial.

**Learning objectives**: After going through this notebook, you should be able to:
  1. Know the differences among the various truth catalogs related to variable truth information.
  2. Load the light curves for any SNe or stars of interest.

## Before you start

Make sure you have followed the instructions on the [DESC Data Portal](https://lsstdesc-portal.nersc.gov/) to 
download the data files, install `GCRCatalogs`, and set up `root_dir` for `GCRCatalogs`.

In this example notebook, the following files will be needed:
- `sn_truth_summary.parquet`
- `sn_variability_truth.parquet`
- `star_lc_stats_int_id.parquet`
- `star_truth_summary_int_id.parquet`
- `star_variability_truth_int_id.parquet`

## Import necessary packages

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import pandas as pd
from astropy.coordinates import SkyCoord

In [None]:
import GCRCatalogs
from GCRCatalogs import GCRQuery
from GCRCatalogs.utils import first

In [None]:
def flux2mag(flux):
    """
    A convenience function to convert flux to magnitude
    """
    with np.errstate(divide="ignore"):
        return 22.5 - 2.5 * np.log10(flux)

## Supernova light curves

We will start by finding light curves for supernovae. When accessing DC2 variable truth information, one important concept is that the light curves are stored separately from a "summary table".

In the summary table, each row corresponds to one object (supernova), and it contains "summary" truth information such as coordinates, redshift, and time-averaged fluxes, but not the light curves. 

We can use the summary table to figure out the IDs of the supernovae we are interested in (e.g., in a certain sky area or redshift range); these IDs will later be used to extract the light curves.
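To make this two-table layout concrete, here is a toy sketch using made-up IDs and values (not real DC2 data). The summary table has one row per object, the variability table has many rows per object, and the shared `id` column links them. The column names mirror the ones used later in this notebook; everything else is illustrative.

```python
import pandas as pd

# Toy summary table: one row per object (values are made up)
summary = pd.DataFrame({
    "id": ["sn_a", "sn_b"],
    "redshift": [0.3, 0.9],
})

# Toy variability table: one row per (object, time, band) measurement
variability = pd.DataFrame({
    "id": ["sn_a", "sn_a", "sn_a", "sn_b", "sn_b"],
    "MJD": [60000.0, 60001.0, 60001.0, 60000.0, 60002.0],
    "bandpass": ["r", "r", "i", "g", "r"],
    "delta_flux": [10.0, 25.0, 30.0, 5.0, 8.0],
})

# Step 1: use the summary table to pick the IDs of interest
ids_of_interest = summary.loc[summary["redshift"] < 0.8, "id"].values

# Step 2: use those IDs to pull the matching light-curve rows
light_curve_rows = variability[variability["id"].isin(ids_of_interest)]
print(light_curve_rows)
```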

### Loading supernova summary table

We first load the summary table into `sn_summary_cat` and check out what columns are available.

In [None]:
sn_summary_cat = GCRCatalogs.load_catalog("desc_dc2_run2.2i_sn_truth_summary")

In [None]:
sorted(sn_summary_cat.list_all_quantities())

We will now actually load the data into memory. For this tutorial, we only need the ID (`id`), coordinates (`ra`, `dec`), redshift, and extinction-corrected fluxes (`flux_*_noMW`).

Recall that `get_quantities` returns a Python dictionary. We can use `pd.DataFrame` to turn it into a Pandas DataFrame.
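As a quick illustration of that conversion, the sketch below builds a dictionary shaped like what `get_quantities` returns (quantity names as keys, equal-length NumPy arrays as values). The dictionary and its values are a hypothetical stand-in, not real catalog output.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the dictionary returned by get_quantities:
# keys are quantity names, values are equal-length NumPy arrays
quantities = {
    "id": np.array(["sn_a", "sn_b", "sn_c"]),
    "ra": np.array([60.49, 60.51, 61.20]),
    "dec": np.array([-36.60, -36.61, -36.00]),
    "redshift": np.array([0.35, 0.72, 1.10]),
}

# pd.DataFrame turns each key into a column, preserving the key order
df = pd.DataFrame(quantities)
print(df.columns.tolist())
```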

In [None]:
d_summary = pd.DataFrame(sn_summary_cat.get_quantities(["id", "ra", "dec", "redshift"] + [f"flux_{band}_noMW" for band in "ugrizy"]))
d_summary.head()

The fluxes (`flux_*` and `flux_*_noMW`) in the summary table are time-averaged fluxes over *infinite* time. For supernovae, those averages will always be zero. We can check if that's the case:

In [None]:
d_summary[[f"flux_{band}_noMW" for band in "ugrizy"]].max()

Now that we have the summary table, we can use it to select, say, supernovae near a certain coordinate or in a certain redshift range. In this tutorial, we will find all supernovae that are within 2 arcmin of (RA, Dec) = (60.5, -36.6) and have z < 0.8. We will store these supernovae in `objects_to_plot`, which we will use later to extract and plot their light curves.
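The cell below uses `SkyCoord.separation` for the angular cut. As a dependency-light cross-check, here is a sketch of the same two cuts written with the haversine formula in plain NumPy. This is a swapped-in technique for illustration, not what the notebook itself uses, and the toy catalog values are made up.

```python
import numpy as np

def angular_separation_arcmin(ra1, dec1, ra2, dec2):
    """Great-circle separation via the haversine formula.
    Inputs are in degrees; the result is in arcminutes."""
    ra1, dec1, ra2, dec2 = map(np.radians, (ra1, dec1, ra2, dec2))
    # Haversine form stays numerically stable for small separations
    a = (np.sin((dec2 - dec1) / 2) ** 2
         + np.cos(dec1) * np.cos(dec2) * np.sin((ra2 - ra1) / 2) ** 2)
    return np.degrees(2 * np.arcsin(np.sqrt(a))) * 60.0

# Toy catalog (made-up values)
ra = np.array([60.50, 60.52, 61.00])
dec = np.array([-36.60, -36.61, -36.60])
redshift = np.array([0.5, 0.9, 0.3])

# Same two cuts as the notebook: within 2 arcmin of (60.5, -36.6) and z < 0.8
selected = angular_separation_arcmin(ra, dec, 60.5, -36.6) < 2
selected &= redshift < 0.8
print(np.flatnonzero(selected))
```

Here the second object passes the sky cut (about 1.1 arcmin away) but fails the redshift cut, and the third is far outside the 2 arcmin radius.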

In [None]:
selected = SkyCoord(d_summary["ra"].values, d_summary["dec"].values, unit="deg").separation(SkyCoord(60.5, -36.6, unit="deg")).arcmin < 2
selected &= d_summary["redshift"] < 0.8

objects_to_plot = d_summary.iloc[np.flatnonzero(selected)]
objects_to_plot

### Loading supernova light curve (variability) table

Now that we have a list of supernovae that we want to inspect further, we can load the "variability table", which contains the light curves. 

We will load the "variability table" into `sn_light_curve_cat` and check what columns are available. 

In [None]:
sn_light_curve_cat = GCRCatalogs.load_catalog("desc_dc2_run2.2i_sn_truth_variability")
sn_light_curve_cat.list_all_quantities()

Unlike the summary table, in the variability table each row corresponds to one "measurement" of the flux of a given object, at a given time, in a given filter. Hence, a single supernova may correspond to many rows.

Let's peek at the beginning of the variability table. Note that the variability table is much larger, and you probably don't want to load the whole thing into memory. Here we set `return_iterator=True` so that we can look at just the first chunk of the table.
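Why does this keep memory low? A minimal sketch, assuming that `first` behaves like Python's `next(iter(...))` for our purposes (we have not checked its exact implementation): the iterator yields one chunk at a time, so taking only the first item means the remaining chunks are never built. `fake_chunk_iterator` is a hypothetical stand-in for `get_quantities(..., return_iterator=True)`.

```python
import numpy as np

produced = []  # records which chunks the generator actually built

def fake_chunk_iterator(n_chunks=3, chunk_size=4):
    """Hypothetical stand-in for get_quantities(..., return_iterator=True):
    lazily yields one dictionary of arrays per chunk."""
    for k in range(n_chunks):
        produced.append(k)
        start = k * chunk_size
        yield {"row": np.arange(start, start + chunk_size)}

# Assumed equivalent to GCRCatalogs.utils.first for this illustration
first_chunk = next(iter(fake_chunk_iterator()))
print(first_chunk["row"], produced)
```

Only chunk 0 is ever created; chunks 1 and 2 stay unevaluated because the generator is suspended after its first `yield`.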

In [None]:
d_lc = pd.DataFrame(first(sn_light_curve_cat.get_quantities(sn_light_curve_cat.list_all_quantities(), return_iterator=True)))
d_lc.head()

You can see from above that the first five rows are all different measurements (at different times or in different filters) of a single object (`MS_9177_2005`).

Now that we have a basic idea of how the variability table is structured, we can search for the light curves of the supernovae we identified earlier. In short, we need to extract all rows whose IDs appear in `objects_to_plot`.

Because the variability table is large, we don't want to load the whole thing and then do the search. By specifying `filters` in `get_quantities`, the backend code applies the search to each smaller chunk, which keeps runtime memory usage low.

The filter (i.e., search query) we use here looks a bit complicated, but it simply means: find all rows whose `id` is in `objects_to_plot["id"].values`:
```python
GCRQuery((lambda x: np.isin(x, objects_to_plot["id"].values), "id"))
```
Even with this filter, the cell below will take a while to run. If runtime is a concern, you may want to use the SQL database that we also provide to access this information instead.
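Conceptually, the filter is applied chunk by chunk, so only the matching rows of each chunk are retained. The sketch below mimics that idea on two toy chunks with made-up values; it illustrates the principle and is not GCRCatalogs' actual internal code.

```python
import numpy as np
import pandas as pd

# Hypothetical chunks of a long variability table (toy values)
chunks = [
    {"id": np.array(["a", "b", "c"]), "delta_flux": np.array([1.0, 2.0, 3.0])},
    {"id": np.array(["a", "d", "b"]), "delta_flux": np.array([4.0, 5.0, 6.0])},
]
ids_wanted = np.array(["a", "b"])

kept = []
for chunk in chunks:
    # Same test as the GCRQuery lambda: np.isin on the "id" column
    mask = np.isin(chunk["id"], ids_wanted)
    # Keep only matching rows, so memory holds roughly one chunk plus the matches
    kept.append({name: values[mask] for name, values in chunk.items()})

# Stitch the per-chunk matches into one table
filtered = pd.DataFrame({name: np.concatenate([c[name] for c in kept]) for name in chunks[0]})
print(filtered)
```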

In [None]:
d_lc = pd.DataFrame(sn_light_curve_cat.get_quantities(
    ['id', 'MJD', 'bandpass', 'delta_flux'],
    filters=[GCRQuery((lambda x: np.isin(x, objects_to_plot["id"].values), "id"))],
))
d_lc.head()

At this point, the `d_lc` DataFrame has all the light curve data for everything in `objects_to_plot`. We now want to split the table by both object ID and bandpass (filter). Pandas' `groupby` function comes in handy!
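To see what `groupby(...).groups` gives us, here is a small sketch on a toy table with made-up values. The result maps each `(id, bandpass)` pair to the index labels of its rows, and `.get` returns `None` for pairs with no measurements, which is exactly what the plotting loop checks.

```python
import pandas as pd

# Toy light-curve rows (made-up values)
d_lc_toy = pd.DataFrame({
    "id": ["sn_a", "sn_a", "sn_a", "sn_b"],
    "bandpass": ["r", "i", "r", "g"],
    "MJD": [60000.0, 60000.5, 60002.0, 60001.0],
})

# .groups maps each (id, bandpass) key to the index labels of its rows
groups = d_lc_toy.groupby(["id", "bandpass"]).groups
print(groups[("sn_a", "r")].tolist())  # row labels of sn_a's r-band points

# Missing combinations come back as None
print(groups.get(("sn_b", "r")))
```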

In [None]:
d_lc_grouped = d_lc.groupby(["id", "bandpass"]).groups

In [None]:
# squeeze=False keeps `ax` a 2D array even if only one object is selected
fig, ax = plt.subplots(ncols=len(objects_to_plot), figsize=(len(objects_to_plot)*4, 4), squeeze=False)

for obj, ax_this in zip(objects_to_plot.itertuples(), ax.flat):
    for i, band in enumerate("ugrizy"):
        indices = d_lc_grouped.get((obj.id, band))
        if indices is not None:
            ax_this.plot(d_lc["MJD"][indices], flux2mag(getattr(obj, f"flux_{band}_noMW") + d_lc["delta_flux"][indices]), 'o:', color=f"C{i}", label=f"${band}$")
    ax_this.legend(ncol=3)
    ax_this.set_title(obj.id)
    ax_this.set_xlabel("MJD")
    ax_this.set_ylabel("magnitude")
    ax_this.set_ylim(25, 10)

fig.tight_layout()        

Now we have the light curves! You might notice that when we compute the magnitudes, we add the time-averaged `flux_X_noMW` to `delta_flux`. This is not necessary for SNe because, as mentioned earlier, their infinite-time-averaged fluxes are zero. But we keep the code generic here (as we will see later, this addition is needed for stars).
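A toy sketch of why that addition matters, using the notebook's own `flux2mag` conversion (zero point 22.5) and made-up flux values. For a supernova the time-averaged flux is zero, so `delta_flux` alone sets the magnitude; for a star, `delta_flux` is a small perturbation around a non-zero mean.

```python
import numpy as np

def flux2mag(flux):
    """Same conversion as the notebook's helper (zero point 22.5)."""
    with np.errstate(divide="ignore"):
        return 22.5 - 2.5 * np.log10(flux)

# Made-up numbers: zero mean flux for the SN, non-zero mean flux for the star
sn_mean_flux, star_mean_flux = 0.0, 1000.0
delta_flux = np.array([0.0, 100.0, 10.0])

sn_mag = flux2mag(sn_mean_flux + delta_flux)      # first entry is +inf (no flux yet)
star_mag = flux2mag(star_mean_flux + delta_flux)  # small wiggles around flux2mag(1000)
print(sn_mag, star_mag)
```

Dropping `star_mean_flux` here would turn a steady star of roughly 15 mag into a fictitious object fainter than 17 mag with infinite-magnitude points.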


## Variable star light curves

Now we are going to repeat everything we have done so far, but with variable stars. As before, we will start by getting the summary information and use it to identify some variable stars of interest.

### Loading star summary table

In [None]:
star_summary_cat = GCRCatalogs.load_catalog("desc_dc2_run2.2i_star_truth_summary")

In [None]:
sorted(star_summary_cat.list_all_quantities())

You'll notice that there are a few extra columns like `stdev_X`. This is because not all stars in this summary table are "variable"; some are rather static. 

In fact, only those stars whose `stdev` in at least one band is larger than 0.001 are stored in the variability table. So here we will also grab `stdev_r` when looking for interesting variable stars.
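The membership rule stated above can be written as a boolean mask over the `stdev_*` columns. A minimal sketch on a toy table with made-up values (only three bands, for brevity):

```python
import pandas as pd

# Toy star summary rows (made-up values) with per-band standard deviations
stars = pd.DataFrame({
    "id": [1, 2, 3],
    "stdev_g": [0.0, 0.0005, 0.0],
    "stdev_r": [0.0, 0.0020, 0.0],
    "stdev_i": [0.0, 0.0010, 0.0030],
})

# Rule: a star is in the variability table if stdev > 0.001 in at least one band
stdev_cols = [c for c in stars.columns if c.startswith("stdev_")]
in_variability_table = (stars[stdev_cols] > 1.0e-3).any(axis=1)
print(stars.loc[in_variability_table, "id"].tolist())
```

Under this rule, cutting on `stdev_r` alone is sufficient but not necessary: star 3 in this toy example varies only in the i band, so an r-only cut would skip it even though it has a light curve.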

In [None]:
d_summary = pd.DataFrame(star_summary_cat.get_quantities(["id", "ra", "dec", "stdev_r"] + [f"flux_{band}_noMW" for band in "ugrizy"]))

Another important difference between stars and supernovae is that the infinite-time-averaged fluxes of stars are usually non-zero! We can verify that:

In [None]:
d_summary[[f"flux_{band}_noMW" for band in "ugrizy"]].min()

As before, we will now use the summary table to select variable stars near a certain coordinate. Here, we further select bright stars and stars with `stdev_r` > 0.001 (to ensure that they appear in the variability table). We will store the selected stars in `objects_to_plot`, which we will use later to extract and plot their light curves.

In [None]:
selected = SkyCoord(d_summary["ra"].values, d_summary["dec"].values, unit="deg").separation(SkyCoord(60.5, -36.6, unit="deg")).arcmin < 2
selected &= d_summary["flux_r_noMW"] > 1.0e3
selected &= d_summary["stdev_r"] > 1.0e-3

objects_to_plot = d_summary.iloc[np.flatnonzero(selected)]
objects_to_plot

### Loading star light curve (variability) table

We will now load the "variability table" for stars.

In [None]:
star_lc_cat = GCRCatalogs.load_catalog("desc_dc2_run2.2i_star_truth_variability")

Again, recall that in the variability table each row corresponds to one measurement, and we need to collect all the rows whose `id` values appear in `objects_to_plot["id"].values`.

Also remember that the variability table is large (even larger for stars). We will use the same `filters` trick again to reduce the memory footprint, but the cell below will still reach about 4 GB of peak memory usage and will take a while to run. If runtime is a concern, you may want to use the SQL database that we also provide to access this information instead.

In [None]:
d_lc = pd.DataFrame(star_lc_cat.get_quantities(
    ['id', 'MJD', 'bandpass', 'delta_flux'],
    filters=[GCRQuery((lambda x: np.isin(x, objects_to_plot["id"].values), "id"))],
))
d_lc.head()

Once we have collected the light curve information in `d_lc`, we can use the same `groupby` trick to split the DataFrame by ID and bandpass, and then make the light curve plots.

In [None]:
d_lc_grouped = d_lc.groupby(["id", "bandpass"]).groups

In [None]:
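# squeeze=False keeps `ax` a 2D array (so `ax.flat` works) even if only one star is selected
fig, ax = plt.subplots(ncols=len(objects_to_plot), figsize=(len(objects_to_plot)*4, 4), squeeze=False)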

for obj, ax_this in zip(objects_to_plot.itertuples(), ax.flat):
    for i, band in enumerate("ugrizy"):
        indices = d_lc_grouped.get((obj.id, band))
        if indices is not None:
            ax_this.plot(d_lc["MJD"][indices], flux2mag(getattr(obj, f"flux_{band}_noMW") + d_lc["delta_flux"][indices]), 'o:', color=f"C{i}", label=f"${band}$")
    ax_this.legend(ncol=3)
    ax_this.set_title(obj.id)
    ax_this.set_xlabel("MJD")
    ax_this.set_ylabel("magnitude")

fig.tight_layout()        

Note that here we must add the time-averaged `flux_X_noMW` to `delta_flux` to get sensible magnitudes. We expect the variability of the stars to be much smaller than that of the SNe.

We can also plot the variability by subtracting the time-averaged magnitude:
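The difference of two magnitudes has a simple closed form, because the zero point cancels: with time-averaged flux F̄ and `delta_flux` ΔF, Δm = m(F̄ + ΔF) − m(F̄) = −2.5 log10(1 + ΔF / F̄). The sketch below checks this numerically with made-up values, reusing the notebook's `flux2mag` conversion.

```python
import numpy as np

def flux2mag(flux):
    # Same helper as in the notebook (zero point 22.5)
    with np.errstate(divide="ignore"):
        return 22.5 - 2.5 * np.log10(flux)

mean_flux = 2000.0                         # toy time-averaged flux
delta_flux = np.array([-50.0, 0.0, 80.0])  # toy delta_flux values

# What the plotting cell computes: a difference of two magnitudes
delta_mag_direct = flux2mag(mean_flux + delta_flux) - flux2mag(mean_flux)

# Equivalent closed form: only the fractional flux change matters
delta_mag_ratio = -2.5 * np.log10(1.0 + delta_flux / mean_flux)

print(np.allclose(delta_mag_direct, delta_mag_ratio))
```

Since brighter means smaller magnitude, a positive `delta_flux` gives a negative Δm. For |ΔF| much smaller than F̄, Δm is approximately −1.086 ΔF / F̄, so the Δm curves essentially track the fractional flux variation.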

In [None]:
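# squeeze=False keeps `ax` a 2D array (so `ax.flat` works) even if only one star is selected
fig, ax = plt.subplots(ncols=len(objects_to_plot), figsize=(len(objects_to_plot)*4, 4), squeeze=False)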

for obj, ax_this in zip(objects_to_plot.itertuples(), ax.flat):
    for i, band in [(2,"r")]:
        indices = d_lc_grouped.get((obj.id, band))
        if indices is not None:
            ax_this.plot(d_lc["MJD"][indices], flux2mag(getattr(obj, f"flux_{band}_noMW") + d_lc["delta_flux"][indices]) - flux2mag(getattr(obj, f"flux_{band}_noMW")), 'o:', color=f"C{i}", label=f"${band}$")
    ax_this.legend(ncol=3)
    ax_this.set_title(obj.id)
    ax_this.set_xlabel("MJD")
    ax_this.set_ylabel(r"$\Delta$ magnitude")
fig.tight_layout()        