# Basic Creation Demo

This notebook shows how to load a normalizing flow from pzflow, wrap it for rail.creation, and draw galaxy samples with redshift posteriors.

In [None]:
from pzflow.examples import example_flow
from rail.creation import Creator, engines
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpl_patches

Let's load the example galaxy redshift flow from pzflow. To see how this flow was constructed, look at this pzflow [example notebook](https://github.com/jfcrenshaw/pzflow/blob/main/examples/redshift_example.ipynb). The flow is wrapped in a `FlowEngine`, an adapter that exposes the flow's methods under the names the `Creator` object expects.

In [None]:
flow = engines.FlowEngine(example_flow())
creator = Creator(flow)
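The adapter idea behind `FlowEngine` can be sketched with a toy example. `ToyFlow`, `ToyEngine`, and the `draw` method are hypothetical names invented for illustration; this is not the actual `FlowEngine` code, and the interface `Creator` really relies on may differ. The point is only that the wrapper translates a consumer's expected call (here an assumed `sample(n_samples, seed)`) into whatever the wrapped model provides.

```python
import numpy as np
import pandas as pd


class ToyFlow:
    """Hypothetical stand-in for a trained flow whose method names differ from what a consumer expects."""

    columns = ("redshift", "u", "g")

    def draw(self, n, rng):
        # draw a redshift, then two magnitudes loosely correlated with it (toy model)
        z = rng.uniform(0.0, 2.0, size=n)
        u = 24.0 + z + rng.normal(0.0, 0.1, size=n)
        g = 23.5 + 0.5 * z + rng.normal(0.0, 0.1, size=n)
        return np.column_stack([z, u, g])


class ToyEngine:
    """Hypothetical adapter: exposes sample(n_samples, seed) and delegates to ToyFlow.draw."""

    def __init__(self, flow):
        self.flow = flow

    def sample(self, n_samples, seed=None):
        # a fixed seed gives a reproducible random generator, hence reproducible samples
        rng = np.random.default_rng(seed)
        data = self.flow.draw(n_samples, rng)
        return pd.DataFrame(data, columns=list(self.flow.columns))


toy_engine = ToyEngine(ToyFlow())
toy_df = toy_engine.sample(n_samples=5, seed=0)
print(toy_df)
```

Keeping the translation in a thin wrapper means the consumer never needs to know which model library sits underneath.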

Now we can draw samples from the creator:

In [None]:
samples = creator.sample(n_samples=10000, seed=0)
samples

These samples don't include redshift posteriors. To include them, set `include_pdf=True`:

In [None]:
samples_w_pdfs = creator.sample(n_samples=10000, seed=0, include_pdf=True)
samples_w_pdfs
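Conceptually, a redshift posterior on a grid can be obtained by evaluating the joint density of redshift and photometry along the grid at fixed magnitudes, then normalizing. The sketch below does this for a toy bivariate Gaussian standing in for the flow, with a single hypothetical magnitude `m_obs`; it illustrates the idea and is not a claim about pzflow's exact implementation. For a standard bivariate normal with correlation $\rho$, the exact conditional is $z \mid m \sim \mathcal{N}(\rho m,\, 1-\rho^2)$, which gives a check on the numerical result.

```python
import numpy as np


def joint_density(z, m, rho=0.8):
    """Toy joint density p(z, m): standard bivariate normal with correlation rho (hypothetical model)."""
    norm = 1.0 / (2.0 * np.pi * np.sqrt(1.0 - rho**2))
    q = (z**2 - 2.0 * rho * z * m + m**2) / (1.0 - rho**2)
    return norm * np.exp(-0.5 * q)


def posterior_on_grid(m_obs, grid, rho=0.8):
    # evaluate the joint density along the redshift grid at fixed photometry...
    p = joint_density(grid, m_obs, rho)
    # ...and normalize with the trapezoid rule so the result integrates to 1 on the grid
    area = np.sum(0.5 * (p[1:] + p[:-1]) * np.diff(grid))
    return p / area


toy_grid = np.linspace(-4.0, 4.0, 801)
toy_pz = posterior_on_grid(m_obs=1.0, grid=toy_grid)
# the peak should sit near the exact conditional mean rho * m_obs = 0.8
print(toy_grid[toy_pz.argmax()])
```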

The last column of `samples_w_pdfs`, `pz_pdf`, contains the true redshift posteriors. You can also access the grid on which the posteriors were evaluated:

In [None]:
samples_w_pdfs.attrs['pz_grid']

This is the default redshift grid, but we can also define our own grid:

In [None]:
grid = np.arange(0, 2.5, 0.5)
samples_w_pdfs2 = creator.sample(n_samples=10000, seed=0, include_pdf=True, pz_grid=grid)
samples_w_pdfs2

In [None]:
samples_w_pdfs2.attrs['pz_grid']
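Note that `np.arange` excludes its stop value, so `np.arange(0, 2.5, 0.5)` yields only five grid points and each posterior in `samples_w_pdfs2` is evaluated at just those five redshifts. When a specific endpoint and resolution are wanted, `np.linspace` is often the clearer choice; the 0.01 spacing below is an arbitrary example, not a recommended value.

```python
import numpy as np

# np.arange excludes the stop value: 5 points, 0.0, 0.5, 1.0, 1.5, 2.0
coarse = np.arange(0, 2.5, 0.5)
print(coarse)

# np.linspace includes both endpoints and fixes the number of points instead of the step
fine = np.linspace(0, 2.5, 251)  # 251 points -> spacing of 0.01
print(len(fine), fine[-1])
```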

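Let's plot the posteriors of the first 12 galaxies, marking each true redshift with a dashed line and listing the magnitudes in the legend: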

In [None]:
fig, axes = plt.subplots(4, 3, figsize=(15, 20))
axes = axes.flatten()

zs = samples_w_pdfs.attrs['pz_grid']  # grid on which the posteriors were evaluated
band_labels = ['u', 'g', 'r', 'i', 'z', 'y']

for idx, ax in enumerate(axes):
    # unpack the true redshift, the six magnitudes, and the posterior of galaxy idx
    redshift, u, g, r, i, z, y, pz = samples_w_pdfs.loc[idx]
    ax.plot(zs, pz)

    # mark the true redshift
    ax.axvline(redshift, c='C3', ls='--', zorder=0)

    # legend listing the magnitudes; invisible handles so only the text is shown
    handles = [mpl_patches.Rectangle((0, 0), 1, 1, fc="white", ec="white", lw=0, alpha=0)] * 6
    labels = [f'{band} = {mag:.2f}' for band, mag in zip(band_labels, [u, g, r, i, z, y])]
    ax.legend(handles, labels, loc='best',
              fancybox=True, framealpha=0.7,
              handlelength=0, handletextpad=0)

    ax.set_xlim(0, 2)
    ax.set_ylim(0, 30)
    ax.set_xlabel("Redshift")
    # raw string so the LaTeX backslashes are not treated as escape sequences
    ax.set_ylabel(r"$p(z \,|\, ugrizy)$")

Now let's look at the maximum a posteriori (MAP) point estimates, i.e. the grid redshift at which each posterior peaks:

In [None]:
# MAP estimate: the grid value at which each posterior peaks
pz_grid = samples_w_pdfs.attrs['pz_grid']
z_map = [pz_grid[pz.argmax()] for pz in samples_w_pdfs['pz_pdf']]
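The same MAP computation can be vectorized by stacking the posteriors into a 2-D array (one row per galaxy) and taking `np.argmax` along the grid axis. The sketch below uses hypothetical toy Gaussian posteriors rather than the flow's output, and prefixes its names with `toy_` so it does not overwrite the notebook's `z_map`.

```python
import numpy as np

# hypothetical data: a shared grid and one Gaussian "posterior" per galaxy (rows)
toy_grid = np.linspace(0.0, 2.0, 201)
toy_centers = np.array([0.3, 0.75, 1.2])
toy_pdfs = np.exp(-0.5 * ((toy_grid[None, :] - toy_centers[:, None]) / 0.1) ** 2)

# MAP estimate for every row at once: index of the peak along the grid axis
toy_z_map = toy_grid[np.argmax(toy_pdfs, axis=1)]
print(toy_z_map)
```

In this notebook the analogous call would be `samples_w_pdfs.attrs['pz_grid'][np.argmax(np.stack(samples_w_pdfs['pz_pdf'].to_numpy()), axis=1)]`, assuming every entry of `pz_pdf` is a 1-D array on the shared grid, as the loop above implies.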

In [None]:
fig, ax = plt.subplots(figsize=(4, 4), constrained_layout=True)
ax.scatter(samples_w_pdfs['redshift'], z_map, s=1)
ax.set_xlabel('True redshift')
# raw string so "\m" is not treated as an escape sequence
ax.set_ylabel(r'$z_\mathrm{MAP}$')
ax.set_xlim(0, 2)
ax.set_ylim(0, 2)
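The scatter plot can also be summarized numerically. The sketch below computes statistics commonly used for photometric redshifts, all based on the normalized residual $\Delta z = (z_\mathrm{est} - z_\mathrm{true})/(1 + z_\mathrm{true})$: the median bias, the normalized median absolute deviation $\sigma_\mathrm{NMAD}$, and the fraction of outliers with $|\Delta z| > 0.15$. The function name `photoz_metrics` is hypothetical, and the 0.15 cut is a common convention rather than a fixed standard; thresholds vary between surveys.

```python
import numpy as np


def photoz_metrics(z_true, z_est, outlier_cut=0.15):
    """Hypothetical helper: bias, sigma_NMAD, and outlier fraction of dz = (z_est - z_true) / (1 + z_true)."""
    z_true = np.asarray(z_true, dtype=float)
    z_est = np.asarray(z_est, dtype=float)
    dz = (z_est - z_true) / (1.0 + z_true)
    bias = np.median(dz)
    # median absolute deviation, scaled by 1.4826 to match the sigma of a Gaussian
    sigma_nmad = 1.4826 * np.median(np.abs(dz - bias))
    outlier_frac = np.mean(np.abs(dz) > outlier_cut)
    return bias, sigma_nmad, outlier_frac


# toy example: three perfect estimates and one catastrophic outlier
toy_z_true = np.array([0.5, 1.0, 1.5, 0.2])
toy_z_est = np.array([0.5, 1.0, 1.5, 1.0])
print(photoz_metrics(toy_z_true, toy_z_est))
```

With the arrays from this notebook, `photoz_metrics(samples_w_pdfs['redshift'], z_map)` would summarize the scatter plot above.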