In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import tables_io

In [2]:
from rail.estimation.algos.tpz_lite import TPZliteInformer

In [3]:
from rail.core.utils import RAILDIR

In [4]:
RAILDIR

'/Users/sam/anaconda3/envs/mlztest/lib/python3.10/site-packages'

In [5]:
datafile = os.path.join(RAILDIR,"rail/examples_data/testdata/test_dc2_training_9816.hdf5")

In [6]:
import rail
import qp
from rail.core.data import TableHandle
from rail.core.stage import RailStage

In [7]:
DS = RailStage.data_store
DS.__class__.allow_overwrite = True

In [8]:
training_data = DS.read_file("training_data", TableHandle, datafile)


In [9]:
training_data()

OrderedDict([('photometry',
              OrderedDict([('id',
                            array([8062500000, 8062500062, 8062500124, ..., 8082681636, 8082693813,
                                   8082707059])),
                           ('mag_err_g_lsst',
                            array([0.00500126, 0.00508365, 0.00505737, ..., 0.01664717, 0.03818999,
                                   0.05916394], dtype=float32)),
                           ('mag_err_i_lsst',
                            array([0.00500074, 0.00507535, 0.00501555, ..., 0.0153863 , 0.03277681,
                                   0.04307469], dtype=float32)),
                           ('mag_err_r_lsst',
                            array([0.00500058, 0.00504773, 0.00501542, ..., 0.0122792 , 0.02692565,
                                   0.03255744], dtype=float32)),
                           ('mag_err_u_lsst',
                            array([0.00504562, 0.00955173, 0.01114765, ..., 0.20123477, 0.7962344 ,
         

In [10]:
tpz_dict = dict(zmin=0.0, zmax=3.0, nzbins=301, hdf5_groupname='photometry', ref_band="mag_i_lsst", nrandom=3)

In [11]:
pz_train = TPZliteInformer.make_stage(name='inform_TPZ', model='demo_tpz.pkl', **tpz_dict)

In [12]:
pz_train.inform(training_data)

creating 3 random realizations...
making a total of 15 trees for 3 random realizations * 5 bootstraps
making 1 of 15...
making 2 of 15...
making 3 of 15...
making 4 of 15...
making 5 of 15...
making 6 of 15...
making 7 of 15...
making 8 of 15...
making 9 of 15...
making 10 of 15...
making 11 of 15...
making 12 of 15...
making 13 of 15...
making 14 of 15...
making 15 of 15...
Inserting handle into data store.  model_inform_TPZ: inprogress_demo_tpz.pkl, inform_TPZ


<rail.core.data.ModelHandle at 0x142848250>

# try estimate!

In [13]:
from rail.estimation.algos.tpz_lite import TPZliteEstimator

In [14]:
testfile = os.path.join(RAILDIR,"rail/examples_data/testdata/test_dc2_validation_9816.hdf5")

In [15]:
test_data = DS.read_file("test_data", TableHandle, testfile)


In [16]:
#test_data()

In [17]:
test_dict = dict(hdf5_groupname='photometry', ref_band="mag_i_lsst")

In [18]:
test_runner = TPZliteEstimator.make_stage(name="test_tpz", model=pz_train.get_handle('model'), **test_dict)

In [19]:
res = test_runner.estimate(test_data)

Process 0 running estimator on chunk 0 - 10000


  ef = ((e68 - e681) / (area - area1)) * (0.68 - area1) + e681


Inserting handle into data store.  output_test_tpz: inprogress_output_test_tpz.hdf5, test_tpz
Process 0 running estimator on chunk 10000 - 20000
Process 0 running estimator on chunk 20000 - 20449


# test opening the model file

In [20]:
import pickle

In [21]:
with open("demo_tpz.pkl", "rb") as f:
    model = pickle.load(f)

In [22]:
model.keys()

dict_keys(['trainkeys', 'treedict', 'use_atts', 'zmin', 'zmax', 'nzbins', 'att_dict', 'keyatt', 'nrandom', 'ntrees', 'minleaf', 'natt', 'sigmafactor', 'bands', 'rmsfactor'])

In [23]:
# save a tree with np.save to match what TPZ does
extree = model['treedict']['tree_2']

In [24]:
model['treedict']

{'tree_0': <rail.estimation.algos.ml_codes.TPZ.Rtree at 0x145474c10>,
 'tree_1': <rail.estimation.algos.ml_codes.TPZ.Rtree at 0x146d93ca0>,
 'tree_2': <rail.estimation.algos.ml_codes.TPZ.Rtree at 0x145626c20>,
 'tree_3': <rail.estimation.algos.ml_codes.TPZ.Rtree at 0x144a2efe0>,
 'tree_4': <rail.estimation.algos.ml_codes.TPZ.Rtree at 0x144ca0670>,
 'tree_5': <rail.estimation.algos.ml_codes.TPZ.Rtree at 0x1470ffcd0>,
 'tree_6': <rail.estimation.algos.ml_codes.TPZ.Rtree at 0x147475ba0>,
 'tree_7': <rail.estimation.algos.ml_codes.TPZ.Rtree at 0x1478966e0>,
 'tree_8': <rail.estimation.algos.ml_codes.TPZ.Rtree at 0x147ce5f30>,
 'tree_9': <rail.estimation.algos.ml_codes.TPZ.Rtree at 0x152f05090>,
 'tree_10': <rail.estimation.algos.ml_codes.TPZ.Rtree at 0x153337c40>,
 'tree_11': <rail.estimation.algos.ml_codes.TPZ.Rtree at 0x1577655d0>,
 'tree_12': <rail.estimation.algos.ml_codes.TPZ.Rtree at 0x157b8f820>,
 'tree_13': <rail.estimation.algos.ml_codes.TPZ.Rtree at 0x16970d300>,
 'tree_14': <rai

In [25]:
np.save("testtree.npy", extree, allow_pickle=True)

In [26]:
loadtree = np.load("testtree.npy", allow_pickle=True)

In [27]:
loadtree

array(<rail.estimation.algos.ml_codes.TPZ.Rtree object at 0x144f116c0>,
      dtype=object)

In [28]:
xx = loadtree.item()

In [29]:
xx

<rail.estimation.algos.ml_codes.TPZ.Rtree at 0x144f116c0>

In [30]:
xdd = xx.dict_dim

In [31]:
xdd

'all'

In [32]:
# pull out data same way it is done in code

In [33]:
from rail.estimation.algos.mlz_utils import data

In [34]:
inputdata = test_data()['photometry']
Ng_temp = np.array(list(inputdata.values()))
testkeys = list(test_data()['photometry'].keys())

In [35]:
Ng_temp

array([[8.06250000e+09, 8.06250003e+09, 8.06250006e+09, ...,
        8.08269386e+09, 8.08270064e+09, 8.08270714e+09],
       [5.05747600e-03, 5.04860934e-03, 5.01400605e-03, ...,
        3.20556425e-02, 3.61963399e-02, 3.40182148e-02],
       [5.01721911e-03, 5.02923271e-03, 5.00793941e-03, ...,
        3.60456705e-02, 3.88294198e-02, 3.82704698e-02],
       ...,
       [1.90898361e+01, 1.96379852e+01, 1.85592480e+01, ...,
        2.52381458e+01, 2.49651814e+01, 2.50031528e+01],
       [1.91976585e+01, 1.96890240e+01, 1.86532288e+01, ...,
        2.47688522e+01, 2.49558201e+01, 2.49272480e+01],
       [2.30460936e-02, 2.18762275e-02, 4.41930996e-02, ...,
        3.02101440e+00, 2.98104019e+00, 2.95916868e+00]])

In [36]:
testkeys

['id',
 'mag_err_g_lsst',
 'mag_err_i_lsst',
 'mag_err_r_lsst',
 'mag_err_u_lsst',
 'mag_err_y_lsst',
 'mag_err_z_lsst',
 'mag_g_lsst',
 'mag_i_lsst',
 'mag_r_lsst',
 'mag_u_lsst',
 'mag_y_lsst',
 'mag_z_lsst',
 'redshift']

In [37]:
Ng = np.array(Ng_temp, 'i')

In [38]:
Ng

array([[-527434591, -527434560, -527434529, ..., -507240737, -507233948,
        -507227451],
       [         0,          0,          0, ...,          0,          0,
                 0],
       [         0,          0,          0, ...,          0,          0,
                 0],
       ...,
       [        19,         19,         18, ...,         25,         24,
                25],
       [        19,         19,         18, ...,         24,         24,
                24],
       [         0,          0,          0, ...,          3,          2,
                 2]], dtype=int32)

In [39]:
Ng_temp.T.shape

(20449, 14)

In [40]:
class objfromdict(object):
    def __init__(self, d):
        for k, v in d.items():
            setattr(self, k, v)

In [41]:
modelatt = objfromdict(model)

In [42]:
test_att_dict={'mag_g_lsst': {'type': 'real', 'ind': 1, 'eind': 7},
                   'mag_i_lsst': {'type': 'real', 'ind': 2, 'eind': 8},
                   'mag_r_lsst': {'type': 'real', 'ind': 3, 'eind': 9},
                   'mag_u_lsst': {'type': 'real', 'ind': 4, 'eind': 10},
                   'mag_y_lsst': {'type': 'real', 'ind': 5, 'eind': 11},
                   'mag_z_lsst': {'type': 'real', 'ind': 6, 'eind': 12},
                   'redshift': {'type': 'real', 'ind': -1, 'eind': -1}}

In [43]:
Test = data.catalog(modelatt, Ng_temp.T, testkeys, modelatt.use_atts, test_att_dict)

In [44]:
Test.get_XY()

In [45]:
Test.X

array([[1.12391384e-02, 5.05747600e-03, 5.01634413e-03, 5.01721911e-03,
        5.03180083e-03, 5.10173244e-03],
       [7.49955885e-03, 5.04860934e-03, 5.02237631e-03, 5.02923271e-03,
        5.06408280e-03, 5.24003245e-03],
       [5.64152608e-03, 5.01400605e-03, 5.00636036e-03, 5.00793941e-03,
        5.01549384e-03, 5.04636532e-03],
       ...,
       [2.66208248e+01, 3.20556425e-02, 2.37318054e-02, 3.60456705e-02,
        6.79195374e-02, 2.40162894e-01],
       [4.04541135e-01, 3.61963399e-02, 2.55193245e-02, 3.88294198e-02,
        8.05819854e-02, 1.86846241e-01],
       [1.19701362e+00, 3.40182148e-02, 2.62609888e-02, 3.82704698e-02,
        7.85020068e-02, 1.93484738e-01]])

In [46]:
Test.nobj

20449