# Using a Creator to Calculate True Posteriors for a Galaxy Sample

author: John Franklin Crenshaw, Sam Schmidt, Eric Charles, others...

last run successfully: April 26, 2023

This notebook demonstrates how to use a RAIL Engine to calculate true posteriors for galaxy samples drawn from the same Engine. Note that this notebook assumes you have already read through `degradation-demo.ipynb`.

Calculating posteriors is more complicated than drawing samples, because it requires more knowledge of the model that underlies the Engine. In this example, we will use the same engine we used in the Degradation demo: `FlowEngine`, which wraps a normalizing flow from the [pzflow](https://github.com/jfcrenshaw/pzflow) package.

This notebook will cover three scenarios of increasing complexity:
1. [**Calculating posteriors without errors**](#NoErrors)
2. [**Calculating posteriors while convolving errors**](#ErrConv)
3. [**Calculating posteriors with missing bands**](#MissingBands)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pzflow.examples import get_example_flow
from rail.creation.engines.flowEngine import FlowCreator, FlowPosterior
from rail.creation.degradation import (
    InvRedshiftIncompleteness,
    LineConfusion,
    LSSTErrorModel,
    QuantityCut,
)
from rail.core.data import TableHandle
from rail.core.stage import RailStage
from rail.core.utilStages import ColumnMapper

In [None]:
import pzflow
import os
flow_file = os.path.join(os.path.dirname(pzflow.__file__), 'example_files', 'example-flow.pzflow.pkl')

We'll start by setting up the RAIL data store. RAIL uses [ceci](https://github.com/LSSTDESC/ceci), which is designed for pipelines rather than interactive notebooks. The data store works around that and enables us to use data interactively. See the `rail/examples/goldenspike/goldenspike.ipynb` example notebook for more details on the Data Store.

In [None]:
DS = RailStage.data_store
DS.__class__.allow_overwrite = True

<a id="NoErrors"></a>
## 1. Calculating posteriors without errors

For a basic first example, let's make a Creator with no degradation and draw a sample.

Note that the `FlowEngine.sample` method hands back a `DataHandle`. When talking to RAIL stages, we can use this as though it were the underlying data and pass it as an argument. This allows the RAIL stages to keep track of where their inputs are coming from.

In [None]:
n_samples=6
# create the FlowCreator
flowCreator = FlowCreator.make_stage(name='truth', model=flow_file, n_samples=n_samples)
# draw a few samples
samples_truth = flowCreator.sample(n_samples, seed=0)

Now, let's calculate true posteriors for this sample. Note the important fact here: these are literally the true posteriors for the sample because pzflow gives us direct access to the probability distribution from which the sample was drawn!

When calculating posteriors, the Engine will always require `data`, which is a pandas DataFrame of the galaxies for which we are calculating posteriors (in our case, `samples_truth`). Because we are using a `FlowEngine`, we must also provide `grid`, because `FlowEngine` calculates posteriors over a grid of redshift values.
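
To make the role of `grid` concrete, here is a minimal sketch of a grid-based posterior that does not use RAIL or pzflow. It assumes a toy two-dimensional Gaussian joint density over redshift and a single magnitude. The names `toy_joint_pdf` and `grid_posterior` and all of the numbers are hypothetical and only illustrate the idea of evaluating a joint density along a redshift grid and normalizing it.

```python
import numpy as np

def toy_joint_pdf(z, mag):
    """Toy bivariate Gaussian density p(z, mag), standing in for a trained flow."""
    mean = np.array([1.0, 24.0])
    cov = np.array([[0.09, 0.12], [0.12, 0.64]])
    diff = np.stack([z - mean[0], mag - mean[1]], axis=-1)
    exponent = -0.5 * np.einsum("...i,ij,...j->...", diff, np.linalg.inv(cov), diff)
    return np.exp(exponent) / (2 * np.pi * np.sqrt(np.linalg.det(cov)))

def grid_posterior(mag_obs, z_grid):
    """Evaluate p(z, mag_obs) on the grid, then normalize to get p(z | mag_obs)."""
    unnormalized = toy_joint_pdf(z_grid, np.full_like(z_grid, mag_obs))
    # trapezoid rule for the normalizing integral over redshift
    area = np.sum(0.5 * (unnormalized[1:] + unnormalized[:-1]) * np.diff(z_grid))
    return unnormalized / area

z_grid = np.linspace(0, 2.5, 100)
posterior = grid_posterior(25.0, z_grid)
z_peak = z_grid[np.argmax(posterior)]  # grid point where the posterior is largest
```

The grid sets both the redshift range and the resolution of the returned posterior, so a coarser grid is cheaper but can miss narrow peaks.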

Let's calculate posteriors for every galaxy in our sample:

In [None]:
flow_post = FlowPosterior.make_stage(name='truth_post', 
                                             column='redshift',
                                             grid = np.linspace(0, 2.5, 100),
                                             marg_rules=dict(flag=np.nan, 
                                                             u=lambda row: np.linspace(25, 31, 10)),
                                             flow=flow_file)

In [None]:
pdfs = flow_post.get_posterior(samples_truth, column='redshift')

Note that Creator returns the pdfs as a [qp](https://github.com/LSSTDESC/qp) Ensemble:

In [None]:
pdfs.data

Let's plot these pdfs: 

In [None]:
fig, axes = plt.subplots(2, 3, constrained_layout=True, dpi=120)

for i, ax in enumerate(axes.flatten()):

    # plot the pdf
    pdfs.data[i].plot_native(axes=ax)

    # plot the true redshift
    ax.axvline(samples_truth.data["redshift"][i], c="k", ls="--")

    # remove x-ticks on top row
    if i < 3:
        ax.set(xticks=[])
    # set x-label on bottom row
    else:
        ax.set(xlabel="redshift")
    # set y-label on far left column
    if i % 3 == 0:
        ax.set(ylabel="p(z)")

The true posteriors are in blue, and the true redshifts are marked by the dashed vertical black lines.

<a id="ErrConv"></a>
## 2. Calculating posteriors while convolving errors
Now, let's get a little more sophisticated.

Let's recreate the Engine/Degradation we were using at the end of the Degradation demo.

I will make one change however:
the LSST Error Model sometimes results in non-detections for faint galaxies.
These non-detections are flagged with NaN.
Calculating posteriors for galaxies with missing magnitudes is more complicated, so for now, I will add one additional QuantityCut to remove any galaxies with missing magnitudes.
To see how to calculate posteriors for galaxies with missing magnitudes, see [Section 3](#MissingBands).
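
As a rough illustration of why a cut at `np.inf` can remove non-detections, the sketch below uses plain NumPy rather than `QuantityCut`. It assumes the cut keeps rows whose value is strictly below the threshold, which is an assumption about the stage and not something shown in this notebook. The array `mags` is made-up data.

```python
import numpy as np

# toy magnitudes; NaN marks a non-detection
mags = np.array([22.1, np.nan, 24.7, 26.3, np.nan])

# every finite magnitude is below infinity, but any comparison with NaN is False
keep = mags < np.inf
detected = mags[keep]
```

Under that assumption, an infinite threshold leaves all detected galaxies untouched and drops only the rows flagged with NaN.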

Now let's draw a degraded sample:

In [None]:
n_samples=20
# create the FlowEngine
flowEngine_degr = FlowCreator.make_stage(name='degraded', flow_file=flow_file, n_samples=n_samples)
# draw a few samples
samples_degr = flowEngine_degr.sample(n_samples, seed=0)

# set up the error model
errorModel = LSSTErrorModel.make_stage(name='lsst_errors', input='xx')
# remove galaxies with a missing (NaN) magnitude in any band
quantityCut = QuantityCut.make_stage(name='gold_cut',  input='xx', cuts={band: np.inf for band in "ugrizy"})
inv_incomplete = InvRedshiftIncompleteness.make_stage(name='incompleteness', pivot_redshift=0.8)

OII = 3727
OIII = 5007

lc_2p_0II_0III = LineConfusion.make_stage(name='lc_2p_0II_0III', 
                                          true_wavelen=OII, wrong_wavelen=OIII, frac_wrong=0.02)
lc_1p_0III_0II = LineConfusion.make_stage(name='lc_1p_0III_0II',
                                          true_wavelen=OIII, wrong_wavelen=OII, frac_wrong=0.01)
detection = QuantityCut.make_stage(name='detection', cuts={"i": 25.3})

data = samples_degr
for degr in [errorModel, quantityCut, inv_incomplete, lc_2p_0II_0III, lc_1p_0III_0II, detection]:
    data = degr(data)

In [None]:
samples_degraded_wo_nondetects = data.data
samples_degraded_wo_nondetects

This sample has photometric errors that we would like to convolve into the redshift posteriors, so that the posteriors are fully consistent with the errors. We can perform this convolution by sampling from the error distributions, calculating a posterior for each sample, and averaging the results.
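
The sample-and-average procedure can be sketched with a toy model that does not depend on RAIL. The sketch assumes a hypothetical posterior `toy_posterior` that is a narrow Gaussian whose mean shifts linearly with magnitude, and Gaussian magnitude errors. All names and numbers are illustrative.

```python
import numpy as np

def toy_posterior(mag, z_grid):
    """Toy p(z | mag): a narrow Gaussian whose mean shifts linearly with magnitude."""
    mean = 1.0 + 0.2 * (mag - 24.0)
    sigma = 0.05
    return np.exp(-0.5 * ((z_grid - mean) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

def convolved_posterior(mag_obs, mag_err, z_grid, err_samples, seed=0):
    """Average the toy posterior over draws from the magnitude error distribution."""
    rng = np.random.default_rng(seed)
    mag_draws = rng.normal(mag_obs, mag_err, size=err_samples)
    pdfs = np.array([toy_posterior(m, z_grid) for m in mag_draws])
    return pdfs.mean(axis=0)

def grid_std(pdf, z_grid):
    """Standard deviation of a pdf tabulated on a uniform grid."""
    weights = pdf / pdf.sum()
    mean = np.sum(weights * z_grid)
    return np.sqrt(np.sum(weights * (z_grid - mean) ** 2))

z_grid = np.linspace(0, 2.5, 500)
no_errors = toy_posterior(24.5, z_grid)
with_errors = convolved_posterior(24.5, 0.5, z_grid, err_samples=2000)
```

In this toy model the averaged posterior is broader than the single-magnitude posterior, because each draw of the magnitude shifts the peak to a slightly different redshift.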

`FlowEngine` has this functionality already built in - we just have to provide `err_samples` when setting up the `FlowPosterior` stage, before calling its `get_posterior` method.

Let's calculate posteriors with a variable number of error samples.

In [None]:
grid = np.linspace(0, 2.5, 100)
def get_degr_post(key, data, **kwargs):
    flow_degr_post = FlowPosterior.make_stage(name=f'degr_post_{key}', **kwargs) 
    return flow_degr_post.get_posterior(data, column='redshift')

In [None]:
degr_kwargs = dict(column='redshift', flow_file=flow_file, 
                   marg_rules=dict(flag=np.nan, u=lambda row: np.linspace(25, 31, 10)),
                   grid=grid, seed=0, batch_size=2)
pdfs_errs_convolved = {
    err_samples: get_degr_post(str(err_samples), data,
                               err_samples=err_samples, **degr_kwargs)
    for err_samples in [1, 10, 100, 1000]
}

In [None]:
fig, axes = plt.subplots(2, 3, dpi=120)

for i, ax in enumerate(axes.flatten()):

    # set dummy values for xlim
    xlim = [np.inf, -np.inf]

    for pdfs_ in pdfs_errs_convolved.values():

        # plot the pdf
        pdfs_.data[i].plot_native(axes=ax)

        # get the x value where the pdf first rises above 2
        xmin = grid[np.argmax(pdfs_.data[i].pdf(grid)[0] > 2)]
        if xmin < xlim[0]:
            xlim[0] = xmin
            
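        # get the x value where the pdf finally falls below 2
        # (clip the index so that an offset of 0 maps to the last grid point, not grid[0])
        n_from_end = np.argmax(pdfs_.data[i].pdf(grid)[0, ::-1] > 2)
        xmax = grid[min(len(grid) - n_from_end, len(grid) - 1)]
        if xmax > xlim[1]:
            xlim[1] = xmax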

    # plot the true redshift
    #z_true = samples_degraded_wo_nondetects["redshift"][i]
    #ax.axvline(z_true, c="k", ls="--")

    # set x-label on bottom row
    if i >= 3:
        ax.set(xlabel="redshift")
    # set y-label on far left column
    if i % 3 == 0:
        ax.set(ylabel="p(z)")

    # set the x-limits so we can see more detail
    xlim[0] -= 0.2
    xlim[1] += 0.2
    ax.set(xlim=xlim, yticks=[])

# create the legend
for i, n in enumerate(pdfs_errs_convolved):
    label = "1 sample" if n == 1 else f"{n} samples"
    axes[0, 1].plot([], [], c=f"C{i}", label=label)
axes[0, 1].legend(
    bbox_to_anchor=(0.5, 1.3), 
    loc="upper center",
    ncol=4,
)

plt.show()

You can see the effect of convolving the errors. In particular, notice that without error convolution (1 sample), the redshift posterior is often totally inconsistent with the true redshift. As you convolve more samples, the posterior generally broadens and becomes consistent with the true redshift.

Also notice how the posterior continues to change as you convolve more and more samples. This suggests that you need to do a little testing to ensure that you have convolved enough samples.
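
One simple way to test for convergence is to keep doubling the number of error samples until the averaged posterior stops changing. The sketch below does this for a toy model with Gaussian posteriors and Gaussian magnitude errors. The functions `averaged_pdf` and `samples_needed`, the tolerance, and the stopping rule are all hypothetical choices and are not part of RAIL or pzflow.

```python
import numpy as np

def averaged_pdf(n_samples, z_grid, rng):
    """Average toy Gaussian posteriors over n_samples draws of a noisy magnitude."""
    means = 1.0 + 0.2 * rng.normal(0.0, 0.5, size=n_samples)
    pdfs = np.exp(-0.5 * ((z_grid[None, :] - means[:, None]) / 0.05) ** 2)
    pdfs /= 0.05 * np.sqrt(2 * np.pi)
    return pdfs.mean(axis=0)

def samples_needed(z_grid, tol=0.1, start=10, max_samples=10000, seed=0):
    """Double the sample count until the averaged pdf changes by less than tol.

    The change is the largest absolute difference between successive estimates,
    relative to the peak height. Returns None if the limit is reached first.
    """
    rng = np.random.default_rng(seed)
    n = start
    previous = averaged_pdf(n, z_grid, rng)
    while n < max_samples:
        n *= 2
        current = averaged_pdf(n, z_grid, rng)
        change = np.max(np.abs(current - previous)) / np.max(current)
        if change < tol:
            return n
        previous = current
    return None

z_grid = np.linspace(0, 2.5, 200)
n_needed = samples_needed(z_grid)
```

A check like this only compares successive estimates, so a loose tolerance can stop too early. Plotting the posteriors, as this notebook does, remains a useful sanity check.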

Let's plot these same posteriors with even more samples to make sure they have converged:

**WARNING**: Running the next cell with the larger sample counts (the commented-out list) may exhaust your computer's memory.

In [None]:
pdfs_errs_convolved_more_samples = {
    err_samples: get_degr_post(str(err_samples), data,
                               err_samples=err_samples, **degr_kwargs)
    # full convergence test (memory-hungry): [1000, 2000, 5000, 10000]
    for err_samples in [12]
}

In [None]:
fig, axes = plt.subplots(2, 3, dpi=120)

for i, ax in enumerate(axes.flatten()):

    # set dummy values for xlim
    xlim = [np.inf, -np.inf]

    for pdfs_ in pdfs_errs_convolved_more_samples.values():

        # plot the pdf
        pdfs_.data[i].plot_native(axes=ax)

        # get the x value where the pdf first rises above 2
        xmin = grid[np.argmax(pdfs_.data[i].pdf(grid)[0] > 2)]
        if xmin < xlim[0]:
            xlim[0] = xmin
            
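        # get the x value where the pdf finally falls below 2
        # (clip the index so that an offset of 0 maps to the last grid point, not grid[0])
        n_from_end = np.argmax(pdfs_.data[i].pdf(grid)[0, ::-1] > 2)
        xmax = grid[min(len(grid) - n_from_end, len(grid) - 1)]
        if xmax > xlim[1]:
            xlim[1] = xmax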

    # plot the true redshift
    #z_true = samples_degraded_wo_nondetects["redshift"][i]
    #ax.axvline(z_true, c="k", ls="--")

    # set x-label on bottom row
    if i >= 3:
        ax.set(xlabel="redshift")
    # set y-label on far left column
    if i % 3 == 0:
        ax.set(ylabel="p(z)")

    # set the x-limits so we can see more detail
    xlim[0] -= 0.2
    xlim[1] += 0.2
    ax.set(xlim=xlim, yticks=[])

# create the legend from the sample counts that were actually computed
for i, n in enumerate(pdfs_errs_convolved_more_samples):
    axes[0, 1].plot([], [], c=f"C{i}", label=f"{n} samples")
axes[0, 1].legend(
    bbox_to_anchor=(0.5, 1.3), 
    loc="upper center",
    ncol=4,
)

plt.show()

With the larger sample counts, notice that two of these galaxies may take upwards of 10,000 samples to converge (convolving over 10,000 samples takes about 0.5 seconds per galaxy on my computer).

<a id="MissingBands"></a>
## 3. Calculating posteriors with missing bands

Now let's finally tackle posterior calculation with missing bands.

First, let's make a sample that has missing bands. Let's use the same degrader as we used above, except without the final QuantityCut that removed non-detections:

In [None]:
samples_degraded = DS['output_lc_1p_0III_0II']

You can see that galaxy 3 has a non-detection in the u band. `FlowEngine` can handle missing values by marginalizing over that value. By default, `FlowEngine` will marginalize over NaNs in the u band, using the grid `u = np.linspace(25, 31, 10)`. This default grid should work in most cases, but you may want to change the flag for non-detections, use a different grid for the u band, or marginalize over non-detections in other bands. In order to do these things, you must supply `FlowEngine` with marginalization rules in the form of the `marg_rules` dictionary.
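
The idea of marginalizing over a missing value on a grid can be sketched without RAIL. The example below assumes a toy bivariate Gaussian joint density over redshift and u magnitude. It sums the joint density over a grid of u values and normalizes the result over redshift. The names `toy_joint_pdf` and `marginalized_posterior` and all of the numbers are hypothetical, and the sketch does not claim to reproduce how pzflow performs the sum internally.

```python
import numpy as np

def toy_joint_pdf(z, u):
    """Toy bivariate Gaussian density p(z, u) with correlated redshift and magnitude."""
    sz, su, rho = 0.3, 1.0, 0.5
    dz, du = (z - 1.0) / sz, (u - 27.0) / su
    exponent = -(dz**2 - 2 * rho * dz * du + du**2) / (2 * (1 - rho**2))
    return np.exp(exponent) / (2 * np.pi * sz * su * np.sqrt(1 - rho**2))

def marginalized_posterior(z_grid, u_grid):
    """Sum p(z, u) over the u grid and normalize over z to approximate p(z)."""
    zz, uu = np.meshgrid(z_grid, u_grid, indexing="ij")
    joint = toy_joint_pdf(zz, uu)        # shape (len(z_grid), len(u_grid))
    unnormalized = joint.sum(axis=1)     # uniform u grid, so a plain sum suffices
    dz = z_grid[1] - z_grid[0]
    return unnormalized / (unnormalized.sum() * dz)

z_grid = np.linspace(0, 2.5, 100)
coarse = marginalized_posterior(z_grid, np.linspace(23, 31, 10))
fine = marginalized_posterior(z_grid, np.linspace(23, 31, 100))
```

Comparing `coarse` with `fine` is the same kind of resolution test that is applied to the u band grid later in this section.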

Let's imagine we want to use a different grid for u band marginalization. In order to determine what grid to use, we will create a histogram of non-detections in u band vs true u band magnitude (assuming year 10 LSST errors). This will tell us which values of u are reasonable to marginalize over.

In [None]:
# get true u band magnitudes
true_u = DS['output_degraded'].data["u"].to_numpy()
# get the observed u band magnitudes
obs_u = DS['output_lsst_errors'].data["u"].to_numpy()

# create the figure
fig, ax = plt.subplots(constrained_layout=True, dpi=100)
# plot the u band detections
ax.hist(true_u[~np.isnan(obs_u)], bins="fd", label="detected")
# plot the u band non-detections
ax.hist(true_u[np.isnan(obs_u)], bins="fd", label="non-detected")

ax.legend()
ax.set(xlabel="true u magnitude")

plt.show()

Based on this histogram, I will marginalize over u band values from 27 to 31. Just as I tested different numbers of error samples above, I will test different resolutions for the u band grid here.
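
The range can also be read off programmatically instead of by eye. The sketch below uses synthetic magnitudes and a hard detection threshold, both of which are made up for illustration (a real error model produces non-detections stochastically). The choice of the 1st and 99th percentiles is likewise a hypothetical convention.

```python
import numpy as np

rng = np.random.default_rng(0)

# synthetic "true" u magnitudes
true_u = rng.normal(26.5, 1.5, size=5000)
# toy detection model: anything fainter than 27.5 becomes a non-detection (NaN)
obs_u = np.where(true_u < 27.5, true_u, np.nan)

# true magnitudes of the galaxies that were not detected
missing = np.isnan(obs_u)
lo, hi = np.percentile(true_u[missing], [1, 99])

# candidate marginalization grid spanning the bulk of the non-detections
u_grid = np.linspace(lo, hi, 10)
```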

I will provide our new u band grid in the `marg_rules` dictionary, which will also include `"flag"` which tells `FlowEngine` what my flag for non-detections is.
In this simple example, we are using a fixed grid for the u band, but notice that the u band rule takes the form of a function - this is because the grid over which to marginalize can be a function of any of the other variables in the row. 
If I wanted to marginalize over any other bands, I would need to include corresponding rules in `marg_rules` too.
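
As a sketch of a rule that depends on the other variables in the row, the example below builds the u grid from the observed g magnitude. The rule `u_grid_from_g` and its offsets are hypothetical. The sketch also assumes that the row can be indexed by column name, and it uses a plain dict in place of a table row, because this notebook does not show the exact object that is passed to the rule.

```python
import numpy as np

def u_grid_from_g(row):
    """Hypothetical rule: search for u within a window set by the observed g magnitude."""
    return np.linspace(row["g"] + 0.5, row["g"] + 4.5, 10)

marg_rules = {
    "flag": np.nan,      # value that marks a non-detection
    "u": u_grid_from_g,  # grid of u values to marginalize over, built per row
}

# a plain dict stands in for one table row (an assumption for illustration)
row = {"u": np.nan, "g": 25.0}
u_grid = marg_rules["u"](row)
```

A fixed grid, like the one used in this notebook, is the special case where the function ignores its argument.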

For this example, I will only calculate pdfs for galaxy 3, which is the galaxy with a non-detection in the u band.

In [None]:
from rail.core.utilStages import RowSelector
# dict to save the marginalized posteriors
pdfs_u_marginalized = {}

row3_selector = RowSelector.make_stage(name='select_row3', start=3, stop=4)
row3_degraded = row3_selector(samples_degraded)

degr_post_kwargs = dict(grid=grid, err_samples=10000, seed=0, flow_file=flow_file, column='redshift')

# iterate over variable grid resolution
for nbins in [10, 20, 50, 100]:

    # set up the marginalization rules for this grid resolution
    marg_rules = {
        "flag": errorModel.config["ndFlag"],
        "u": lambda row: np.linspace(27, 31, nbins)
    }

    
    # calculate the posterior by marginalizing over u and sampling
    # from the error distributions of the other bands
    pdfs_u_marginalized[nbins] = get_degr_post(f'degr_post_nbins_{nbins}',
                                               row3_degraded,
                                               marg_rules=marg_rules,
                                               **degr_post_kwargs)

    

In [None]:
fig, ax = plt.subplots(dpi=100)
for i in [10, 20, 50, 100]:
    pdfs_u_marginalized[i]()[0].plot_native(axes=ax, label=f"{i} bins")
ax.axvline(samples_degraded().iloc[3]["redshift"], label="True redshift", c="k")
ax.legend()
ax.set(xlabel="Redshift")
plt.show()

Notice that the resolution with only 10 bins is sufficient for this marginalization.

In this example, only one of the bands featured a non-detection, but you can easily marginalize over more bands by including corresponding rules in the `marg_rules` dict. For example, let's artificially force a non-detection in the y band as well:

In [None]:
row3_degraded()

In [None]:
sample_double_degraded = row3_degraded().copy()
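# force a non-detection by setting every column from index 11 onward to NaN
# (assumed here to be the y band magnitude and its error)
sample_double_degraded.iloc[0, 11:] = np.nan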
DS.add_data('double_degr', sample_double_degraded, TableHandle)
sample_double_degraded

In [None]:
# set up the marginalization rules for u and y marginalization
marg_rules = {
    "flag": errorModel.config.ndFlag,
    "u": lambda row: np.linspace(27, 31, 10),
    "y": lambda row: np.linspace(21, 25, 10),
}

# calculate the posterior by marginalizing over u and y, and sampling
# from the error distributions of the other bands
#pdf_double_marginalized = get_degr_post('double_degr', 
#                                        DS['double_degr'],
#                                        flow_file=flow_file,
#                                        input='xx',
#                                        column='redshift',
#                                        grid=grid, 
#                                        err_samples=10000, 
#                                        seed=0, 
#                                        marg_rules=marg_rules)

In [None]:
#fig, ax = plt.subplots(dpi=100)
#pdf_double_marginalized.data[0].plot_native(axes=ax)
#ax.axvline(sample_double_degraded.iloc[0]["redshift"], label="True redshift", c="k")
#ax.legend()
#ax.set(xlabel="Redshift")
#plt.show()

Note that marginalizing over multiple bands quickly gets expensive.
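
A back-of-the-envelope count shows why. The sketch below assumes that the density is evaluated once for every combination of redshift grid point, error sample, and marginalization grid point. This is an assumption about the cost model, not a description of how pzflow batches its work, and `n_density_evaluations` is a hypothetical helper.

```python
import math

def n_density_evaluations(n_z, err_samples, marg_grid_sizes):
    """Count density evaluations per galaxy, assuming a full product grid."""
    return n_z * err_samples * math.prod(marg_grid_sizes)

# settings similar to those used above: 100 redshifts, 10000 error samples
one_band = n_density_evaluations(100, 10000, [10])       # marginalize over u only
two_bands = n_density_evaluations(100, 10000, [10, 10])  # marginalize over u and y
```

Under this assumption, each additional marginalized band multiplies the work by the size of its grid.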