In [None]:
import os

import numpy as np
import tables_io

import rail
from rail.creation.degradation import LSSTErrorModel, InvRedshiftIncompleteness, LineConfusion, QuantityCut
from rail.creation.engines.flowEngine import FlowEngine, FlowPosterior
from rail.core.data import TableHandle
from rail.core.stage import RailStage
from rail.core.utilStages import ColumnMapper, TableConverter

from rail.estimation.algos.bpz_lite import BPZ_lite
from rail.estimation.algos.trainZ import Train_trainZ, TrainZ
from rail.estimation.algos.sklearn_nn import Train_SimpleNN, SimpleNN
from rail.estimation.algos.randomPZ import RandomPZ
from rail.estimation.algos.flexzboost import Train_FZBoost, FZBoost

from rail.evaluation.evaluator import Evaluator


In [None]:
# Shared DataStore used by every RAIL stage in this notebook. The flag is set on
# the class (not the instance), so it applies to every DataStore object;
# presumably this lets cells be re-run without clashing on existing entries.
DS = RailStage.data_store
type(DS).allow_overwrite = True

In [None]:
# Locate the pretrained flow relative to the installed `rail` package.
RAIL_DIR = os.path.join(os.path.dirname(rail.__file__), '..')
flow_file = os.path.join(RAIL_DIR, 'examples/goldenspike/data/pretrained_flow.pkl')

# LSST bands, in order u, g, r, i, z, y.
bands = list('ugrizy')
# Band letter -> magnitude column name, e.g. 'u' -> 'mag_u_lsst'.
band_dict = dict((band, f'mag_{band}_lsst') for band in bands)
# Rename map for the ColumnMapper stages: 'mag_<b>_lsst_err' -> 'mag_err_<b>_lsst'.
rename_dict = dict((f'mag_{band}_lsst_err', f'mag_err_{band}_lsst') for band in bands)

In [None]:
# Stages for the test-sample pipeline:
# flow draw -> photometric errors -> column renaming -> numpy-dict conversion,
# plus a flow posterior evaluated on the renamed table.
flow_engine_test = FlowEngine.make_stage(
    name='flow_engine_test',
    flow_file=flow_file,
    n_samples=50,
)
lsst_error_model_test = LSSTErrorModel.make_stage(
    name='lsst_error_model_test',
    bandNames=band_dict,
)
col_remapper_test = ColumnMapper.make_stage(
    name='col_remapper_test',
    hdf5_groupname='',
    columns=rename_dict,
)
# Posterior redshift grid: 21 evenly spaced points from 0 to 5.
flow_post_test = FlowPosterior.make_stage(
    name='flow_post_test',
    column='redshift',
    flow_file=flow_file,
    grid=np.linspace(0.0, 5.0, num=21),
)
table_conv_test = TableConverter.make_stage(
    name='table_conv_test',
    output_format='numpyDict',
    seed=12345,  # NOTE(review): unclear whether TableConverter uses a seed; confirm
)


In [None]:
# Run the test-sample pipeline; each stage's result feeds the next one.
# Draw 50 samples from the flow with seed 12345. This pre-error table is also
# used later as the evaluation truth.
test_data_orig = flow_engine_test.sample(50, 12345)
test_data_errs = lsst_error_model_test(test_data_orig)
test_data_pq = col_remapper_test(test_data_errs)
# Flow posterior on the renamed table; not used elsewhere in this notebook.
test_data_post = flow_post_test.get_posterior(test_data_pq, 'redshift', err_samples=None)
test_data = table_conv_test(test_data_pq)

In [None]:
# Stages for the training-sample pipeline:
# flow draw -> photometric errors -> redshift incompleteness -> line confusion
# -> magnitude cut -> column renaming -> numpy-dict conversion.
flow_engine_train = FlowEngine.make_stage(
    name='flow_engine_train',
    flow_file=flow_file,
    n_samples=50,
    seed=12345,
)
lsst_error_model_train = LSSTErrorModel.make_stage(
    name='lsst_error_model_train',
    bandNames=band_dict,
)
inv_redshift = InvRedshiftIncompleteness.make_stage(
    name='inv_redshift',
    pivot_redshift=1.0,
)
# Presumably assigns 5% of objects a redshift derived from the wrong emission
# line (3727 instead of 5007); confirm against LineConfusion's docs.
line_confusion = LineConfusion.make_stage(
    name='line_confusion',
    true_wavelen=5007.,
    wrong_wavelen=3727.,
    frac_wrong=0.05,
)
# Magnitude cut on the i band at 25.3 (cut direction per QuantityCut; presumably
# keeps mag_i_lsst < 25.3).
quantity_cut = QuantityCut.make_stage(
    name='quantity_cut',
    cuts={'mag_i_lsst': 25.3},
)
# NOTE(review): unlike col_remapper_test, no hdf5_groupname is passed here;
# confirm which configuration is intended.
col_remapper_train = ColumnMapper.make_stage(
    name='col_remapper_train',
    columns=rename_dict,
)
table_conv_train = TableConverter.make_stage(
    name='table_conv_train',
    output_format='numpyDict',
)
 


In [None]:
# Run the training-sample pipeline; each stage's result feeds the next one:
# flow draw -> photometric errors -> redshift incompleteness -> line confusion
# -> magnitude cut -> column renaming -> numpy-dict conversion.
#
# The sample seed must differ from the test sample's seed (12345). Both engines
# use the same flow file and sample size, so reusing the seed would draw the
# very same galaxies, and the estimators would be trained on the test objects.
train_data_orig = flow_engine_train.sample(50, 54321)
train_data_errs = lsst_error_model_train(train_data_orig)
train_data_inc = inv_redshift(train_data_errs)
train_data_conf = line_confusion(train_data_inc)
train_data_cut = quantity_cut(train_data_conf)
train_data_pq = col_remapper_train(train_data_cut)
train_data = table_conv_train(train_data_pq)

In [None]:
# Informer (training) stages for the three trainable estimators. All of them
# train on the same converted table produced by `table_conv_train`. The path is
# defined once so the stages cannot drift apart; FZBoost now uses the same
# .hdf5 file as the other two.
train_input_file = 'inprogress_output_table_conv_train.hdf5'

train_trainZ = Train_trainZ.make_stage(name='train_trainZ', input=train_input_file,
                                       model_file='trainZ.pkl', hdf5_groupname='')

train_simpleNN = Train_SimpleNN.make_stage(name='train_simpleNN', input=train_input_file,
                                           model_file='simpleNN.pkl', hdf5_groupname='')

train_fzboost = Train_FZBoost.make_stage(name='train_FZBoost', input=train_input_file,
                                         model_file='fzboost.pkl', hdf5_groupname='')

In [None]:
# Inspect which columns the converted (numpyDict) training data contains.
train_data.data.keys()

In [None]:
# Train only the trainZ estimator on the degraded training sample.
# NOTE(review): the SimpleNN and FZBoost informers are constructed in the
# previous cell but their training is disabled here; confirm whether to
# re-enable or remove them.
train_trainZ.inform(train_data)
#train_simpleNN.inform(train_data)
#train_fzboost.inform(train_data)

In [None]:
# Estimator stages applied to the test sample.
# BPZ and randomPZ are configured with the string model_file='None'; presumably
# they need no pre-trained model.
test_bpz = BPZ_lite.make_stage(
    name='test_bpz',
    model_file='None',
    hdf5_groupname='',
    columns_file='../estimation/configs/test_bpz.columns',
)
# Presumably loads the model written by `train_trainZ` (model_file='trainZ.pkl'
# with the 'inprogress_' prefix); confirm the file name.
test_trainZ = TrainZ.make_stage(
    name='test_trainZ',
    hdf5_groupname='',
    model_file='inprogress_trainZ.pkl',
)
test_randomPZ = RandomPZ.make_stage(
    name='test_randomZ',
    hdf5_groupname='',
    model_file='None',
)

# Disabled: their informers are not run above.
#test_simpleNN = SimpleNN.make_stage(name='test_simpleNN', 
#                                    model_file='simpleNN.pkl')

#test_fzboost = FZBoost.create(name='test_FZBoost', 
#                              model_file='fzboost.pkl', 
#                              aliases=dict(input='test_data', output='fzboost_estim'))

In [None]:
# Run each estimator on the same converted test table, in the order
# BPZ, trainZ, randomPZ.
bpz_estim, trainZ_estim, randomPZ_estim = (
    stage.estimate(test_data)
    for stage in (test_bpz, test_trainZ, test_randomPZ)
)

In [None]:
# Score each estimator's output against the original (pre-error) test sample,
# using one Evaluator stage per estimator.
# NOTE(review): randomPZ_estim is computed above but not evaluated here; confirm
# whether it should be added.
eval_dict = {'bpz': bpz_estim, 'trainZ': trainZ_estim}
truth = test_data_orig

result_dict = {
    key: Evaluator.make_stage(name=f'{key}_eval', truth=truth).evaluate(val, truth)
    for key, val in eval_dict.items()
}

In [None]:
# Convert each evaluator's output to a pandas DataFrame for display.
# (tables_io is imported in the top import cell.)
results_tables = {
    key: tables_io.convertObj(val.data, tables_io.types.PD_DATAFRAME)
    for key, val in result_dict.items()
}

In [None]:
# Display the BPZ evaluation results.
results_tables['bpz']

In [None]:
# Display the trainZ evaluation results.
results_tables['trainZ']