# validating $p(z | photometry)$ for ELAsTiCC

_Alex Malz (~~GCCL@RUB~~LINCC@CMU)_

The goal here is to validate realistically complex (not necessarily "good") mock photo-$z$ posteriors for host galaxies. 

In [None]:
import bisect
import corner
import math
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import random
# Seed Python's RNG so the random.sample draws below are reproducible.
# random.seed must be *called*; assigning to it replaces the function and seeds nothing.
random.seed(42)
import scipy.stats as sps
import subprocess
import sys
eps = sys.float_info.epsilon

In [None]:
import pzflow
from pzflow import Flow
from pzflow.bijectors import Chain, ColorTransform, InvSoftplus, StandardScaler, RollingSplineCoupling, ShiftBounds
from pzflow.distributions import Uniform, Joint, Normal

In [None]:
import qp

## Host photometry

Let's pick one hostlib for now.
This is somewhat slow because each file is roughly a gigabyte.

In [None]:
# Host libraries keyed by name; the integer is read further down as the
# number of chunks for that library.  Insertion order matters because a
# library is later picked by position.
hl_heads = dict(
    SNIa=64,
    SNII=128,
    SNIbc=128,
    # SNIbc_Pt2=None,
    UNMATCHED_KN_SHIFT=64,
    UNMATCHED_COSMODC2=64,
)

In [None]:
# Pick one host library by its position in hl_heads and build the path to
# its unzipped HOSTLIB file.
pick_one = 1
which_hl = list(hl_heads)[pick_one]
in_path = f'/global/cfs/cdirs/lsst/groups/TD/SN/SNANA/SURVEYS/LSST/ROOT/PLASTICC_DEV/HOSTLIB/magerr/unzip/{which_hl}_GHOST.HOSTLIB_RESTOREDMAG1YR'  # .gz
# in_path = '/global/cfs/cdirs/lsst/groups/TD/SN/SNANA/SURVEYS/LSST/ROOT/PLASTICC_DEV/HOSTLIB/sandbox/SNII_GHOST_NOHEADER.HOSTLIB'
# Rows to skip before the header when reading the file.  The commented-out
# shell pipeline searches the gzipped file for the VARNAMES line instead.
hl_head = 0  # int(subprocess.check_output(f"zcat {in_path} | cat -n | sed -n '/VARNAMES/ {{ p; q }}'  | awk '{{print $1-1}}'", shell=True))

In [None]:
# Load the space-delimited host library; the first row after the skipped
# lines is taken as the header.
df = pd.read_csv(in_path, sep=' ', header=0, skiprows=hl_head)
# df.set_index('GALID')
nhost = df.shape[0]
# nhost = 100
# df = df[:nhost]
print(nhost)

In [None]:
df.columns

In [None]:
# df[['u_obs', 'g_obs', 'r_obs', 'i_obs', 'z_obs',
#        'Y_obs', 'u_obs_err', 'g_obs_err',
#        'r_obs_err', 'i_obs_err', 'z_obs_err', 'Y_obs_err']].isnull().sum()

In [None]:
true_locs = df['ZTRUE'].values.reshape((nhost, 1))

## Photo-$z$ model

In [None]:
# Photometric bands in the order used to form adjacent-band colours.
bands = ['u', 'g', 'r', 'i', 'z', 'y']

# Rename the HOSTLIB columns to short band names and keep redshift, the six
# magnitudes and their six errors, in that order.
hl_df_in = df.rename(columns={'ZTRUE': 'redshift',
                              'u_obs': 'u',
                              'g_obs': 'g',
                              'r_obs': 'r',
                              'i_obs': 'i',
                              'z_obs': 'z',
                              'Y_obs': 'y',
                              'u_obs_err': 'u_err',
                              'g_obs_err': 'g_err',
                              'r_obs_err': 'r_err',
                              'i_obs_err': 'i_err',
                              'z_obs_err': 'z_err',
                              'Y_obs_err': 'y_err'})[['redshift'] + bands + [band + '_err' for band in bands]]

quantities = hl_df_in.columns

# convert magnitudes to colors
# Take an explicit copy so the new colour columns are written to an
# independent frame rather than to a selection of hl_df_in.
hl_df_colors = hl_df_in[['redshift', 'r', 'r_err']].copy()
for bluer, redder in zip(bands[:-1], bands[1:]):
    color = bluer + '-' + redder
    hl_df_colors[color] = hl_df_in[bluer] - hl_df_in[redder]
    # The colour error is the quadrature sum of the two magnitude errors.
    hl_df_colors[color + '_err'] = np.sqrt(hl_df_in[bluer + '_err']**2 + hl_df_in[redder + '_err']**2)

hl_df = hl_df_colors[['redshift', 'u-g', 'g-r', 'r-i', 'i-z', 'z-y', 'r', 'u-g_err', 'g-r_err', 'r-i_err', 'i-z_err', 'z-y_err', 'r_err']].iloc[:nhost]
# Copy again: columns are added to df_subset in a later cell.
df_subset = hl_df[['redshift', 'u-g', 'g-r', 'r-i', 'i-z', 'z-y', 'r']].copy()

# Load the saved flow and replace its latent distribution with a 7-D uniform on (-5, 5).
flow = Flow(file='../data_files/model_photo-zs_uniform_splbin64_epoch100_flow.pkl')# this path will not change any time soon
flow.latent = Uniform((-5, 5), (-5, 5), (-5, 5), (-5, 5), (-5, 5), (-5, 5), (-5, 5))

In [None]:
# plt.hist(hl_df[])

In [None]:
nsamp = 10000

In [None]:
# Overlay draws from the flow (red) on a random subsample of the host-library
# inputs (blue) in a single corner plot and save it.
samples = flow.sample(nsamp, seed=0)

fig = plt.figure(figsize=(12,12))

# ranges = [(-0.1,2.4), (19.5,33), (19,32), (19,29), (19,29), (19,28), (19,28)]

corner_labels = ['redshift', 'u-g', 'g-r', 'r-i', 'i-z', 'z-y', 'r']

corner.corner(
    samples,
    fig=fig,
    color='r',
    label='model',
    labels=corner_labels,
    bins=20,
    hist_bin_factor=2,
    data_kwargs={'ms':3},
    contour_kwargs={'linewidths':2},
)

corner.corner(
    df_subset.sample(nsamp),
    fig=fig,
    color='b',
    label='inputs',
    labels=corner_labels,
    bins=20,
    hist_bin_factor=2,
    data_kwargs={'ms':3},
    show_titles=True,
);

# corner.corner(res_plot.sample(nsamp), fig=fig, bins=20, hist_bin_factor=2, color='g', data_kwargs={'ms':3}, show_titles=True, label='combo', labels=['redshift', 'u-g', 'g-r', 'r-i', 'i-z', 'z-y', 'r']);

fig.legend()
plt.savefig('./debug_1yr.png')

In [None]:
# hardmax = {}

In [None]:
# hardmax['r'] = max(samples['r'])

In [None]:
# hardmax['g'] = max(samples['g-r']) + hardmax['r']

In [None]:
# hardmax['u'] = max(samples['u-g']) + hardmax['g']

In [None]:
# hardmax['i'] = hardmax['r'] - min(samples['r-i'])

In [None]:
# hardmax['z'] = hardmax['i'] - min(samples['i-z'])

In [None]:
# hardmax['y'] = hardmax['z'] - min(samples['z-y'])

In [None]:
# hardmax

## Estimated photo-$z$ information

just one chunk for now

In [None]:
# Number of chunks for the selected library, and one chunk index drawn
# uniformly from [0, chunks) with NumPy's global RNG.
chunks = hl_heads[which_hl]
idx = np.random.randint(0, chunks)
print(idx)

In [None]:
# Each chunk is a whole number of batches: batch_factor batches of
# batch_size hosts, sized so that `chunks` chunks cover all nhost hosts.
batch_size = 50
batch_factor = math.ceil(nhost / (batch_size * chunks))
chunk_size = batch_factor * batch_size

In [None]:
# idx = 86

# Read the photo-z PDF ensemble written for the selected chunk; the file name
# encodes the library, chunk index, chunk size and batch size.
glob_path = '/global/cfs/cdirs/lsst/groups/TD/SN/SNANA/SURVEYS/LSST/ROOT/PLASTICC_DEV/HOSTLIB/zquants/1yr/'
in_name = f'pz{which_hl}batch{idx}_{batch_size * batch_factor}chunks{batch_size}magcap'
in_pdfs = qp.read(f'{glob_path}{in_name}.fits')

# pdf_path = '/global/cfs/cdirs/lsst/groups/TD/SN/SNANA/SURVEYS/LSST/ROOT/PLASTICC_DEV/HOSTLIB/magerr/unzip/'+which_hl+'_GHOST_AIMALZPZS.HOSTLIB'
# res = pd.read_csv(pdf_path, skiprows=hl_head, delimiter=' ', header=0)

In [None]:
# df_subset = df[df['GALID'].isin(in_pdfs.ancil['GALID'])]

In [None]:
# Attach the galaxy IDs and the colour / r-band errors, then keep only the
# rows that belong to chunk `idx`.
# assign() returns a new frame, so nothing is written into a selection of
# another DataFrame (which is what triggers SettingWithCopyWarning).
err_cols = ['u-g_err', 'g-r_err', 'r-i_err', 'i-z_err', 'z-y_err', 'r_err']
df_subset = df_subset.assign(GALID=df['GALID'].copy(),
                             **{col: hl_df[col].copy() for col in err_cols})
# The final chunk can be shorter than chunk_size, hence the min().
df_subset = df_subset.iloc[idx * chunk_size : min((idx+1) * chunk_size, nhost)]

In [None]:
# df_subset

### Point estimates (median)

In [None]:
# Position of the 0.5 quantile (the median) among the stored quantile levels.
# Match with a tolerance instead of exact float equality, and fail with a
# clear message if the ensemble carries no median.
_median_matches = np.flatnonzero(np.isclose(in_pdfs.metadata()['quants'][0], 0.5))
if _median_matches.size == 0:
    raise ValueError("no 0.5 quantile found in in_pdfs.metadata()['quants']")
med_ind = _median_matches[0]

In [None]:
# Fraction of objects whose median is NaN.  Normalise by the number of
# objects actually present rather than the nominal chunk_size, because the
# last chunk can be shorter.
np.mean(np.isnan(in_pdfs.objdata()['locs'].T[med_ind]))

In [None]:
# True for objects whose median quantile location is not NaN.
mask = np.logical_not(np.isnan(in_pdfs.objdata()['locs'][:, med_ind]))

In [None]:
# Log-spaced redshift grid: 300 points from 1e-3 to 3.
zgrid = np.logspace(start=-3., stop=np.log10(3.), num=300)

In [None]:
# 2-D density of median photo-z against true redshift for objects with a
# non-NaN median, binned on zgrid along both axes.
fig = plt.figure(figsize=(5, 5))
plt.hist2d(
    df_subset['redshift'].values[mask],
    in_pdfs.objdata()['locs'][mask].T[med_ind],
    bins=zgrid,
    density=True,
    cmap='gray_r',
);
plt.xlabel('true redshift')
plt.ylabel('median photo-z')

In [None]:
# res_plot = res.rename(columns={'ZTRUE':'redshift',
#                            'Y_obs':'y', 
#                            'r_obs':'r', 
#                            # 'u_obs':'u', 
#                            'g_obs':'g', 
#                            'z_obs':'z', 
#                            'i_obs':'i'})
# res_plot_colors = res_plot.copy()[['redshift', 'r']]
# res_plot['u'] = 0
# for i in range(len(quantities)-2-6):
#     res_plot_colors[quantities[i+1]+'-'+quantities[i+2]] = res_plot[quantities[i+1]] - res_plot[quantities[i+2]]
#     # hl_df_colors[quantities[i+1]+'-'+quantities[i+2]+'_err'] = np.sqrt(hl_df[quantities[6+i+1]]**2 + hl_df[quantities[6+i+2]]**2)
      
# res_plot = res_plot_colors[['redshift', 'u-g', 'g-r', 'r-i', 'i-z', 'z-y', 'r']]

In [None]:
# len(df)

In [None]:
# len(res)

In [None]:
# plt.hist(df['Y_obs'], bins=100);
# plt.hist(res['Y_obs'], bins=100);
# plt.semilogy()

In [None]:
# plt.hist2d(df['Y_obs']-df['z_obs'], res['Y_obs']-res['z_obs'], bins=(100,100));

In [None]:
# plt.hist(df['ZTRUE'], bins=zgrid, alpha=0.5);
# plt.hist(res['ZTRUE'], bins=zgrid, alpha=0.5);
# plt.hist(res['ZPHOT'], bins=zgrid, alpha=0.5);
# plt.xlabel()

In [None]:
# wherenans = ~np.sum(np.isnan(in_pdfs.objdata()['locs']), axis=1)
# plt.hist2d(df_subset['redshift'][wherenans], in_pdfs.objdata()['locs'][wherenans].T[6])

In [None]:
# truth = df_subset['redshift'].to_numpy()

In [None]:
# medians = in_pdfs.objdata()['locs'][:, 4]

In [None]:
# True redshifts and median photo-z estimates, restricted to objects whose
# median is not NaN so the two arrays stay aligned.
truth = df_subset['redshift'].to_numpy()[mask]
medians = in_pdfs.objdata()['locs'][mask][:, med_ind]

In [None]:
fig = plt.figure(figsize=(5, 5))
plt.scatter(truth, medians, s=1, alpha=0.1);
plt.xlabel('true redshift')
plt.ylabel('median photo-z')

In [None]:
# Normalised redshift histograms of the truth and the medians on the same grid.
for values, name in ((truth, 'truth'), (medians, 'medians')):
    plt.hist(values, alpha=0.5, label=name, bins=zgrid, density=True);
plt.legend()
plt.xlabel('z')

### Point estimate metrics

In [None]:
# Redshift residual normalised two ways: by (1 + estimate) and by (1 + truth).
dz = medians - truth
bias_mlg = dz / (1 + medians)
bias_desc = dz / (1 + truth)

for bias in (bias_mlg, bias_desc):
    print(np.mean(bias))

In [None]:
# Compare the two bias definitions on a common binning.
bias_bins = np.arange(-3., 1., 0.1)
for values, name in ((bias_mlg, 'robust bias'), (bias_desc, 'canonical bias')):
    plt.hist(values, bins=bias_bins, alpha=0.5, label=name);
plt.legend()
plt.xlabel('bias')

In [None]:
plt.scatter(truth, medians, s=1, cmap=mpl.cm.viridis_r, c=np.abs(bias_mlg), vmin=0, vmax=max(np.abs(bias_mlg)), alpha=0.1)
plt.colorbar()
plt.xlabel('true redshift')
plt.ylabel('median photo-z')

In [None]:
# Convert each object's inter-quartile range to a Gaussian-equivalent sigma.
# For a normal distribution IQR ~= 1.349 * sigma, so sigma = IQR / 1.349
# (the original multiplied by 1.349, inflating the scatter by ~1.82x).
scatter_mlg = in_pdfs.ancil['iqr'][:,0][mask] / 1.349
# Same scatter normalised by (1 + true redshift).
scatter_desc = scatter_mlg / (1 + truth)

print(np.mean(scatter_mlg))
print(np.mean(scatter_desc))

In [None]:
# Distribution of the two scatter definitions, binned on the redshift grid.
for values, name in ((scatter_mlg, 'robust scatter'), (scatter_desc, 'canonical scatter')):
    plt.hist(values, bins=zgrid, alpha=0.5, label=name);
plt.legend()
# plt.semilogx()
plt.xlabel('scatter')

In [None]:
plt.scatter(truth, medians, s=1, cmap=mpl.cm.viridis_r, c=scatter_mlg, alpha=0.1, vmin=0, vmax=max(scatter_mlg))
plt.colorbar()
plt.xlabel('true redshift')
plt.ylabel('median photo-z')

In [None]:
# plt.scatter(np.abs(bias_mlg), 3. * scatter_mlg, s=1, cmap=mpl.cm.viridis_r, c=truth, vmin=0, vmax=3, alpha=0.1)
# plt.plot([0, 3], [0, 3], c='k')
# plt.colorbar()
# plt.xlabel('bias')
# plt.ylabel('outlier threshold')

In [None]:
# Outlier threshold per object: three times the scatter, floored at 0.06.
thresh = np.maximum(3 * scatter_mlg, 0.06)

In [None]:
# Indices (into the masked arrays) of objects whose |bias| exceeds their threshold.
is_outlier = np.flatnonzero(np.abs(bias_mlg) > thresh)

In [None]:
len(is_outlier) / len(thresh)

### Quantiles

In [None]:
# quants = np.linspace(0., 1., 11)
# quants[0] += eps
# quants[-1] -= eps

plot a few

In [None]:
plot_inds = random.sample(range(in_pdfs.npdf), 10)
print(plot_inds)

In [None]:
galids = in_pdfs.ancil['GALID'][plot_inds]
print(galids)

In [None]:
to_plot = df_subset.iloc[plot_inds]
print(to_plot['GALID'].values)

In [None]:
# df['GALID'].iloc[to_plot.index]

In [None]:
# truth = df_subset['redshift'].values[mask]
# medians = in_pdfs.objdata()['locs'][mask].T[med_ind]

In [None]:
# Per-band magnitude caps; magnitudes above these are counted as flagged below.
mag_caps = dict(zip('ugrizy', (41.0, 36.1, 34.8, 34.9, 34.9, 34.9)))

In [None]:
# One panel per selected object: the flow posterior on zgrid, the PDF stored
# in in_pdfs, the true redshift and the raw quantile locations.
fig, ax = plt.subplots(len(to_plot), 1, figsize=(5, 20))
evalled = flow.posterior(to_plot, column='redshift', grid=zgrid, err_samples=10)
for i, galind in enumerate(to_plot.index):
    # Count the bands whose magnitude exceeds its cap.  galind is an index
    # *label*, so look it up with .loc rather than positionally.
    flagged = int(sum(hl_df_in[col].loc[galind] > cap for col, cap in mag_caps.items()))
    print((i, galind, flagged))
    # Height of this object's model PDF, used to scale the vertical markers
    # (the original referenced an undefined name `example` here).
    peak = np.max(evalled[i])
    ax[i].plot(zgrid, evalled[i], c='b', label='model PDF')
    in_pdfs[[plot_inds[i]]].plot(axes=ax[i], color='g', label='quantile PDF')
    ax[i].vlines(to_plot['redshift'].iloc[i], 0, peak, color='r',  label='true redshift')
    ax[i].vlines(in_pdfs.objdata()['locs'][plot_inds[i]], 0, peak/5., color='k', linewidth=.75, label='raw quantiles')
    # ax[i].set_ylim(0, 15)
    ax[i].text(2, 5, str(galids[i]))
    ax[i].set_ylabel('$p(z)$')
    # Only the bottom panel keeps its x tick labels and axis label.
    if i == len(to_plot)-1:
        ax[i].set_xlabel('$z$')
    else:
        ax[i].set_xticklabels([])
    ax[i].legend(loc='upper right')
fig.subplots_adjust(wspace=0, hspace=0)

# scratch below here

In [None]:
in_pdfs.objdata()['locs']

In [None]:
plt.hist(in_pdfs.objdata()['locs'].T[0])

In [None]:
is_outlier

In [None]:
in_pdfs.objdata()['locs'][35]

In [None]:
rando = 26
# Comparison grids are defined here so this cell does not depend on names
# that are only created further down the notebook.
zgrid_lo = np.logspace(-3., np.log10(3.), 100)
zgrid_hi = np.logspace(-3., np.log10(3.), 300)
# NOTE(review): the original plotted evalled / evalled_lo / evalled_hi, which
# are not defined at this point; evaluating the stored PDF of object `rando`
# on each grid is assumed to be the intent — confirm.
one_pdf = in_pdfs[[rando]]
plt.title(str(in_pdfs.ancil['GALID'][rando]))
plt.vlines(in_pdfs.objdata()['locs'][rando], -1, 1, color='k', label='quantiles')
for grid, name in ((zgrid, 'orig. grid'), (zgrid_lo, 'lower-res grid'), (zgrid_hi, 'higher-res grid')):
    plt.plot(grid, np.ravel(one_pdf.pdf(grid)), label=name)
plt.legend()
plt.xlim(0, 3)

In [None]:
rando = 35
# Comparison grids are defined here so this cell does not depend on names
# that are only created further down the notebook.
zgrid_lo = np.logspace(-3., np.log10(3.), 100)
zgrid_hi = np.logspace(-3., np.log10(3.), 300)
# NOTE(review): the original plotted evalled / evalled_lo / evalled_hi, which
# are not defined at this point; evaluating the stored PDF of object `rando`
# on each grid is assumed to be the intent — confirm.
one_pdf = in_pdfs[[rando]]
plt.title(str(in_pdfs.ancil['GALID'][rando]))
plt.vlines(in_pdfs.objdata()['locs'][rando], -1, 1, color='k', label='quantiles')
for grid, name in ((zgrid, 'orig. grid'), (zgrid_lo, 'lower-res grid'), (zgrid_hi, 'higher-res grid')):
    plt.plot(grid, np.ravel(one_pdf.pdf(grid)), label=name)
plt.legend()
plt.xlim(0, 3)

In [None]:
# Objects whose last quantile location sits exactly on the top edge of zgrid.
is_broken = np.flatnonzero(in_pdfs.objdata()['locs'][:, -1] == zgrid[-1])

In [None]:
len(is_outlier) / len(thresh) - len(is_broken) / len(thresh)

In [None]:
plt.hist(truth[is_outlier], bins=np.arange(0., 3.01, 0.01), alpha=0.5, label='truth');
plt.hist(medians[is_outlier], bins=np.arange(0., 3.01, 0.01), alpha=0.5, label='median');
plt.legend()

In [None]:
np.max(medians[is_outlier])

In [None]:
# Column names for the eleven stored quantiles, 0% to 100% in steps of 10%.
quantlabs = [f'ZPHOT_Q{pct:03d}' for pct in range(0, 101, 10)]

### Reconstructed PDFs

plot one just to see what the quantiles look like for a given PDF

In [None]:
# plot_one = random.sample(range(nhost), 1)
# plt.plot(zgrid, sps.truncnorm(a=pos_min[plot_one][0], b=pos_max[plot_one][0], 
#                               loc=obs_locs[plot_one][0], scale=sigma*(1+obs_locs[plot_one][0])).pdf(zgrid))
# plt.vlines(df[quantlabs].loc[plot_one], -1, 1, linestyle='--', color='k')
# # plt.xlim(obs_locs[plot_one][0]-5*sigma*(1+obs_locs[plot_one][0]), 
# #          obs_locs[plot_one][0]+5*sigma*(1+obs_locs[plot_one][0]))
# plt.xlim(df['ZPHOT_Q000'].loc[plot_one].values[0]-0.01, df['ZPHOT_Q100'].loc[plot_one].values[0]+0.01)
# plt.text(obs_locs[plot_one][0], 2, str(df['GALID'].loc[plot_one[0]]))
# plt.xlabel('$z$')
# plt.ylabel('$p(z)$')

Plot histograms of the first and last quantiles (0% and 100%), since these endpoints are not really well-defined.

In [None]:
# Side-by-side histograms of the 0% and 100% quantile columns.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].hist(df['ZPHOT_Q000'], bins=300)
# Raw strings: '\%' is not a valid Python escape and warns in a normal string.
ax[0].set_xlabel(r'$z_{0\%}$')
ax[1].hist(df['ZPHOT_Q100'], bins=300)
ax[1].set_xlabel(r'$z_{100\%}$')

also save $p(z_{median})$ and [inter-quartile range](https://en.wikipedia.org/wiki/Interquartile_range)

unfortunately these are quite slow now!

In [None]:
# Store the inter-quartile range: the 0.75 ppf minus the 0.25 ppf.
# NOTE(review): `posterior` is not defined anywhere in the visible notebook —
# confirm where it is created before running this cell.
df['IQR_ZPHOT'] = posterior.ppf(0.75) - posterior.ppf(0.25)

In [None]:
# Store the PDF evaluated at its own median.
# NOTE(review): `posterior` is not defined anywhere in the visible notebook —
# confirm where it is created before running this cell.
df['P_ZPHOT'] = posterior.pdf(posterior.median())

save a file

In [None]:
# Write the augmented table as a space-delimited file, without the index column.
out_path = f'/global/cfs/cdirs/lsst/groups/TD/SN/SNANA/SURVEYS/LSST/ROOT/PLASTICC_DEV/HOSTLIB/zquants/{which_hl}_dummy_pz.csv'
df.to_csv(out_path, sep=' ', index=False)

## reading in quantiles

In [None]:
# Rebind the working names for the quantile section: the object count now
# comes from the PDF ensemble, zgrid becomes a 250-point grid starting at
# 10**-2.5, and to_plot becomes ten random object indices.
nhost = in_pdfs.objdata()['locs'].shape[0]
zgrid = np.logspace(-2.5, np.log10(3.), num=250)
to_plot = random.sample(range(nhost), k=10)

# Lower- and higher-resolution grids over 1e-3 to 3.
zgrid_lo = np.logspace(-3., np.log10(3.), num=100)
zgrid_hi = np.logspace(-3., np.log10(3.), num=300)

In [None]:
quants = in_pdfs.metadata()['quants'][0]

In [None]:
plot_one = 3

In [None]:
zq_vals = in_pdfs[plot_one].objdata()['locs'][0]

In [None]:
len(zq_vals)

In [None]:
len(quants)

In [None]:
# evalled = in_pdfs.pdf(zgrid)
# evalled_lo = in_pdfs.pdf(zgrid_lo)
# evalled_hi = in_pdfs.pdf(zgrid_hi)

In [None]:
# rando = 1
# plt.title(str(in_pdfs.ancil['GALID'][rando]))
# plt.vlines(in_pdfs.objdata()['locs'][rando], -1, 1, color='k', label='quantiles')
# plt.plot(zgrid, evalled[rando], label='orig. grid')
# plt.plot(zgrid_lo, evalled_lo[rando], label='lower-res grid')
# plt.plot(zgrid_hi, evalled_hi[rando], label='higher-res grid')
# plt.legend()
# # plt.xlim(0, 1)

## reconstructing a pdf from quantiles

read from a generic hostlib file

In [None]:
# df = pd.read_csv(out_path, delimiter=' ', header=0)

pick one for demonstration

In [None]:
# plot_one = random.sample(range(nhost), 1)[0]
# zq_vals = df[quantlabs].loc[plot_one].values

This is the reconstruction algorithm from [ye olde qp](https://github.com/aimalz/qp/blob/master/qp/pdf.py#L554).

Note: The code linked above behaves unexpectedly under Python 3; 
to reproduce the results of [Malz & Marshall+ 2017](http://stacks.iop.org/1538-3881/156/i=1/a=35) you must (unfortunately) run it under Python 2.

Another note: You can save yourself one float by replacing the redshifts where $CDF=0$ and $CDF=1$ with $p(z_{q})$ for any of the saved quantiles $q$, at the cost of some inaccuracy in the tails.

In [None]:
# derivative = (quants[1:] - quants[:-1]) / (zq_vals[1:] - zq_vals[:-1])
# derivative = np.insert(derivative, 0, 0.)
# derivative = np.append(derivative, 0.)

# def pdf_inside(zgrid):
#     pdf = np.zeros_like(zgrid)
#     for n in range(len(zgrid)):
#         ind = bisect.bisect_left(zq_vals, zgrid[n])
#         pdf[n] = derivative[ind]
#     return(pdf)

# Short aliases for the quantile levels and their redshift locations.
q = quants
z = zq_vals

# Slope of the CDF between consecutive quantiles, i.e. the piecewise-constant
# density, padded with eps on both sides for points outside the quantile range.
slopes = np.diff(q) / np.diff(z)
derivative = np.append(np.insert(slopes, 0, eps), eps)
def pdf_inside(xf):
    """Evaluate the piecewise-constant quantile PDF at the points ``xf``.

    Reads the module-level ``z`` (quantile locations, assumed sorted in
    ascending order) and ``derivative`` (one density value per interval, with
    one extra entry at each end for points outside the quantile range).

    Parameters
    ----------
    xf : array-like
        Points at which to evaluate the PDF.

    Returns
    -------
    numpy.ndarray
        ``derivative[i]`` for each point, where ``i`` is the left insertion
        index of the point in ``z`` (the same index ``bisect.bisect_left``
        gives).
    """
    # One vectorised binary search replaces the per-element Python loop.
    return np.asarray(derivative, dtype=float)[np.searchsorted(z, xf, side='left')]

eval_pdf = pdf_inside(zgrid)

show difference between original and reconstruction

In [None]:
# plt.plot(zgrid, sps.truncnorm(a=pos_min[plot_one][0], b=pos_max[plot_one][0], 
#                               loc=obs_locs[plot_one][0], scale=sigma*(1+obs_locs[plot_one][0])).pdf(zgrid),
#                               label='original PDF')
# plt.vlines(df[quantlabs].loc[plot_one], -1, 1, linestyle='--')
plt.vlines(zq_vals, -1, 1, linestyle='--')
# plt.xlim(zq_vals[0] - 0.01, zq_vals[-1] + 0.01)
plt.plot(zgrid, eval_pdf, label='reconstructed from '+str(len(quants))+' quantiles')
# plt.text(obs_locs[plot_one][0], 2, str(df['GALID'].loc[plot_one]))
plt.legend(loc='upper right')
plt.xlabel('$z$')
plt.ylabel('$p(z)$')

## Troubleshooting

In [None]:
# Rebuild magnitudes, errors and colours for the rows of the current chunk.
# The raw *_obs magnitude columns live in df (df_subset only holds colours),
# so take the chunk's rows from df by index label before renaming.
bands = ['u', 'g', 'r', 'i', 'z', 'y']
hl_df = df.loc[df_subset.index].rename(columns={'ZTRUE': 'redshift',
                                                'u_obs': 'u',
                                                'g_obs': 'g',
                                                'r_obs': 'r',
                                                'i_obs': 'i',
                                                'z_obs': 'z',
                                                'Y_obs': 'y',
                                                'u_obs_err': 'u_err',
                                                'g_obs_err': 'g_err',
                                                'r_obs_err': 'r_err',
                                                'i_obs_err': 'i_err',
                                                'z_obs_err': 'z_err',
                                                'Y_obs_err': 'y_err'})[['redshift'] + bands + [band + '_err' for band in bands]]

quantities = hl_df.columns

# Explicit copy so the colour columns are written to an independent frame.
hl_df_colors = hl_df[['redshift', 'r', 'r_err']].copy()
for bluer, redder in zip(bands[:-1], bands[1:]):
    color = bluer + '-' + redder
    hl_df_colors[color] = hl_df[bluer] - hl_df[redder]
    # The colour error is the quadrature sum of the two magnitude errors.
    hl_df_colors[color + '_err'] = np.sqrt(hl_df[bluer + '_err']**2 + hl_df[redder + '_err']**2)
print(hl_df_colors.columns)

# hl_df = hl_df_colors[['redshift', 'u-g', 'g-r', 'r-i', 'i-z', 'z-y', 'r', 'u-g_err', 'g-r_err', 'r-i_err', 'i-z_err', 'z-y_err', 'r_err']][:nhost]

In [None]:
flow = Flow(file='../data_files/model_photo-zs_uniform_splbin64_epoch100_flow.pkl')
# this path will not change any time soon
flow.latent = Uniform((-5, 5), (-5, 5), (-5, 5), (-5, 5), (-5, 5), (-5, 5), (-5, 5))

In [None]:
# hl_df_colors.iloc[is_outlier[0]]

In [None]:
flow_z = flow.posterior(hl_df_colors.iloc[is_broken[:10]],#[['u-g', 'g-r', 'r-i', 'i-z', 'z-y', 'r']], 
                            column='redshift', grid=zgrid, err_samples=1000)#, batch_size=min(batch_size, len(hl_subset)))

In [None]:
hl_df_colors.iloc[is_broken[:10]]

In [None]:
df.iloc[is_broken[:10]]

In [None]:
# Plot one recomputed flow posterior with the truth (black) and median (red)
# of an outlier marked as vertical lines.
# NOTE(review): flow_z[i] was computed for row is_broken[i], while the label
# uses df_subset row i and the lines use is_outlier[i] — confirm which object
# is meant.
for i in [2]:
    # flow_z = flow.posterior(hl_df_colors.iloc[i],#[['u-g', 'g-r', 'r-i', 'i-z', 'z-y', 'r']], 
    #                         column='redshift', grid=zgrid, err_samples=100)
    # The keyword is figsize; plt.figure has no `size` argument.
    plt.figure(figsize=(5,5))
    plt.plot(zgrid, flow_z[i], label=df_subset['GALID'].iloc[i])
# for i in range(10):
    # is_outlier indexes the masked arrays, so read the truth from `truth`
    # (df_subset has no 'ZTRUE' column and is not masked).
    plt.vlines(truth[is_outlier[i]], -1., 100, color='k')
    plt.vlines(medians[is_outlier[i]], -1., 100, color='r')
    plt.legend()
    plt.close()

## SCRATCH: scaling calculations

In [None]:
#which_hl, nhost, ncompleted, nremain, nredo
# Per-library bookkeeping; each tuple is (nhost, ncompleted, nremain, nredo).
scaling = dict(
    SNIa=(2141261, 871000, 1270261, 424984),
    SNII=(2449022, 76600, 2372422, 36194),
    SNIbc=(3354141, 0, 3354141, 0),
    UNMATCHED_KN_SHIFT=(1907364, 835800, 1071564, 443893),
    UNMATCHED_COSMODC2=(1907364, 537300, 1370064, 285367),
)

In [None]:
# redoing bad magnitude errors

# For each library, report how many hosts must be redone and the implied run
# time.  The table is named `scaling`; `scaling220622` is not defined.
for key, (_, _, _, n_redo) in scaling.items():
    n_sub = math.ceil(n_redo / 50)
    print(f'{key} needs to redo {n_redo}')
    print(f'in {n_sub} sub-batches of 50')
    # 0.5 minutes per sub-batch, divided by 60 to get hours.
    print(f'so {n_sub*0.5/60.}hours at 30s per sub-batch')
    print(f'so {n_sub*0.5/60./32}hours per job')

In [None]:
# not yet done

# For each library, report how many hosts remain and the run time per job.
# The table is named `scaling`; `scaling220622` is not defined.
# NOTE(review): the sub-batch count uses the total nhost (element 0), as in
# the original, not nremain (element 2) — confirm that is intended.
for key, (n_host, _, n_remain, _) in scaling.items():
    n_sub = math.ceil(n_host / 50)
    print(f'{key} still needs to do {n_remain}')
    print(f'in {n_sub} sub-batches of 50')
    # print(f'so {math.ceil(scaling220622[key][2] / 50)*0.5/60.} hours at 30s per sub-batch')
    print(f'so {n_sub*0.5/60./64} hours per job')

In [None]:
# Rebinds hl_heads: the values are now pairs of integers instead of a single
# integer, so cells above that read hl_heads[...] as a number will not work
# after this cell has run.
hl_heads = dict(
    SNIa=(10, 2141291),  # 2.75 hours
    SNII=(19, 2449022),  # 3.25 hours
    SNIbc=(19, 3354171),  # 4.25 hours
    UNMATCHED_KN_SHIFT=(19, 1907394),  # 2.5 hours
    UNMATCHED_COSMODC2=(18, 1907393),  # 2.5 hours
)
# for key in hl_heads.keys():
    # testpath = '/global/cfs/cdirs/lsst/groups/TD/SN/SNANA/SURVEYS/LSST/ROOT/PLASTICC_DEV/HOSTLIB/magerr/debug/'+key+'_BADMAGS.gz'
    # print((key, str(math.ceil(hl_heads[key][1]/(25*100)/32)), "number of 30-minute batches per core"))
# for key in hl_heads.keys():
    # testpath = '/global/cfs/cdirs/lsst/groups/TD/SN/SNANA/SURVEYS/LSST/ROOT/PLASTICC_DEV/HOSTLIB/magerr/debug/'+key+'_BADMAGS.gz'
    # print((key, str(math.ceil(hl_heads[key][1]/(25*100)/32)), "number of 30-minute batches per core"))

In [None]:
import gzip

# Record the number of rows in every gzipped host library, printing each
# file's size in GiB and its row count along the way.
# NOTE: the loop rebinds the notebook-level names which_hl, hl_path, hl_head and df.
lengths = {}
for which_hl in hl_heads.keys():
    # if which_hl != 'SNIbc':
    hl_path = '/global/cfs/cdirs/lsst/groups/TD/SN/SNANA/SURVEYS/LSST/ROOT/PLASTICC_DEV/HOSTLIB/magerr/'+which_hl+'_GHOST.HOSTLIB.gz'
        # hl_path = '/global/cfs/cdirs/lsst/groups/TD/SN/SNANA/SURVEYS/LSST/ROOT/PLASTICC_DEV/HOSTLIB/magerr/debug/'+which_hl+'_GHOST.HOSTLIB_BADMAGS.gz'
    print((which_hl, os.path.getsize(hl_path)/(1024**3)))
    # Zero-based line number of the first line containing VARNAMES, i.e. the
    # number of rows to skip so that line becomes the header.  Reading with
    # gzip avoids the shell pipeline and stops at the first match.
    with gzip.open(hl_path, 'rb') as hostlib:
        hl_head = next((n for n, line in enumerate(hostlib) if b'VARNAMES' in line), None)
    if hl_head is None:
        raise ValueError(f'no VARNAMES header line found in {hl_path}')
    df = pd.read_csv(hl_path, skiprows=hl_head, delimiter=' ', header=0)
    lengths[which_hl] = len(df)
    print(len(df))