In [1]:
# Copyright 2023 resspect software
# Author: Emille E. O. Ishida
#
# created on 17 January 2023
#
# Licensed MIT License;
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://opensource.org/license/mit/
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Before anything, fit all PLAsTiCC data with SALT2 to see what survives, using the [script from the metrics paper](https://github.com/COINtoolbox/RESSPECT_metric/blob/main/code/01_SALT2_fit.py).

In [2]:
import pandas as pd
import numpy as np
import os
import glob

# Read SALT2 fit results

In [3]:
# choose between DDF or WFD
subsample = 'DDF'

# path to SALT2 fit results
input_dir = '/media/RESSPECT/data/PLAsTiCC/for_metrics/final_data3/' + subsample + '/SALT2_fit/'

# list all available types; sorted so that the row order of the combined
# table does not depend on the arbitrary order returned by the file system
type_list = sorted(os.listdir(input_dir))

fitres_list = []

# files that raised a ValueError while being read (kept for inspection)
skipped_files = []

for name in type_list:

    flist = sorted(glob.glob(input_dir + name + '/fitres/master_fitres_*.fitres'))

    for fname in flist:
        try:
            # sep=r'\s+' is the replacement for delim_whitespace=True,
            # which is deprecated since pandas 2.2
            fitres_temp = pd.read_csv(fname, sep=r'\s+', comment='#')
            fitres_temp['zHD'] = fitres_temp['SIM_ZCMB']
            fitres_list.append(fitres_temp.fillna(-99))
        except ValueError as err:
            # best effort: keep going, but report the file instead of
            # dropping it silently
            skipped_files.append(fname)
            print('Skipping', fname, '-', err)

if not fitres_list:
    raise ValueError('No SALT2 fit results could be read from ' + input_dir)

# files may not share the same columns; fill the gaps created by concat
fitres_all = pd.concat(fitres_list, ignore_index=True).fillna(-99)

In [4]:
# (number of rows, number of columns) of the combined SALT2 fit table
fitres_all.shape

(10228, 111)

Read the full PLAsTiCC test metadata and confirm that all objects surviving SALT2 are part of the test sample

In [5]:
# read PLAsTiCC test
fname_test = '/media/RESSPECT/data/PLAsTiCC/PLAsTiCC_zenodo/plasticc_test_metadata.csv'
zenodo_test = pd.read_csv(fname_test)

In [6]:
# one boolean per SALT2 fit result: True if its CID is listed among the
# object ids of the test metadata (hashed lookup instead of one linear
# scan of the metadata per fit result)
flag_fitres = fitres_all['CID'].isin(zenodo_test['object_id']).tolist()

# number of fit results found in the test metadata
sum(flag_fitres)

10228

### Build validation, test and pool samples

- Validation set is used for code development and fine-tuning
- Test data set is held back through the entire process and only used to produce results for final publication
- Pool sample is used for queries during the active learning loop -- its objects do not need to survive SALT2

In [7]:
# Set proportions for test and validation samples
frac_test_val = 0.2

# fixed seed so that the samples written to disk can be regenerated exactly
random_seed = 42
rng = np.random.RandomState(random_seed)

# draw validation + test in one go, then send half of the draw to validation
val_test = fitres_all.sample(n=int(2 * frac_test_val * fitres_all.shape[0]),
                             replace=False, random_state=rng)
validation = val_test.sample(n=int(0.5 * val_test.shape[0]),
                             replace=False, random_state=rng)

# test holds every drawn object whose CID did not go to validation
flag_test = ~val_test['CID'].isin(validation['CID']).values
test = val_test[flag_test]

# keep only metadata rows whose ddf_bool matches the chosen subsample
flag_subsample = (zenodo_test['ddf_bool'] == int('DDF' == subsample)).values
test_ids = zenodo_test[flag_subsample]['object_id'].values

# pool holds every object of the subsample not used for validation or test
flag_pool = ~np.isin(test_ids, val_test['CID'].values)
pool = zenodo_test[flag_subsample][flag_pool]

print(' *** Sample sizes *** \n')
print('Survived SALT2: ', fitres_all.shape[0], '  ( 100 %)')
print('    Validation: ', validation.shape[0], '   (', int(100 * validation.shape[0]/fitres_all.shape[0]), ' %)')
print('          Test: ', test.shape[0], '   (', int(100 * test.shape[0]/fitres_all.shape[0]), ' %)')
print('\n')
print('Zenodo test: ', len(test_ids), '  ( 100 %)')
print('       Pool: ', pool.shape[0], '   (', int(100 * pool.shape[0]/len(test_ids)), ' %)')

 *** Sample sizes *** 

Survived SALT2:  10228   ( 100 %)
    Validation:  2045    ( 19  %)
          Test:  2046    ( 20  %)


Zenodo test:  32926   ( 100 %)
       Pool:  28835    ( 87  %)


# write each sample to its own csv file inside the initial_samples directory
initial_samples_dir = '/media/RESSPECT/data/PLAsTiCC/for_pipeline/' + subsample + '/initial_samples/'

# (table, file name suffix) pairs, written in this order
frames_to_save = [
    (validation, '_validation_fitres.csv'),
    (test, '_test_fitres.csv'),
    (pool, '_pool_metadata.csv'),
]

for frame, suffix in frames_to_save:
    frame.to_csv(initial_samples_dir + subsample + suffix, index=False)

# Separate perfect samples for comparison