Before anything, fit all PLAsTiCC data with SALT2 to see what survives, using the [script from the metrics paper](https://github.com/COINtoolbox/RESSPECT_metric/blob/main/code/01_SALT2_fit.py).

In [1]:
import pandas as pd
import numpy as np
import os
import glob

# Read SALT2 fit results

In [2]:
# choose between DDF or WFD
subsample = 'DDF'

# path to SALT2 fit results
input_dir = '/media/RESSPECT/data/PLAsTiCC/for_metrics/final_data3/' + subsample + '/SALT2_fit/'

# list all available types (one entry of input_dir per type).
# Sorted because os.listdir returns entries in arbitrary order, and the
# row order of the combined table should not depend on the file system.
type_list = sorted(os.listdir(input_dir))

# one data frame per fitres file that could be parsed
fitres_list = []

for name in type_list:

    # all master fitres files stored for this type (sorted for the same reason)
    flist = sorted(glob.glob(input_dir + name + '/fitres/master_fitres_*.fitres'))

    for fname in flist:
        try:
            # sep=r'\s+' is the documented replacement for the deprecated
            # delim_whitespace=True; lines starting with '#' are ignored
            fitres_temp = pd.read_csv(fname, sep=r'\s+', comment='#')
        except ValueError as err:
            # best effort: skip files pandas cannot parse, but say so
            # instead of dropping them silently
            print('Skipping ' + fname + ': ' + str(err))
        else:
            fitres_list.append(fitres_temp)

# pd.concat fails with a generic message on an empty list; point at the path instead
if not fitres_list:
    raise ValueError('No readable fitres files found under ' + input_dir)

# stack everything into a single table with a fresh 0..N-1 index
fitres_all = pd.concat(fitres_list, ignore_index=True)

In [3]:
# (number of rows, number of columns) of the combined fit results table
fitres_all.shape

(10228, 111)

Read the full test metadata and confirm that all objects surviving the SALT2 fit are in the test set

In [4]:
# read PLAsTiCC test metadata
# (its 'object_id' column is compared against the fitres 'CID' column below)
fname_test = '/media/RESSPECT/data/PLAsTiCC/PLAsTiCC_zenodo/plasticc_test_metadata.csv'
zenodo_test = pd.read_csv(fname_test)

In [5]:
# One boolean per fitted object: True when its CID is listed as an
# object_id in the test metadata. Series.isin performs the membership
# test in a single vectorized pass instead of scanning the whole
# object_id column once per fitted object.
flag_fitres = fitres_all['CID'].isin(zenodo_test['object_id']).tolist()

# number of fitted objects found in the test metadata
sum(flag_fitres)

10228

### Build validation, test and pool samples

- Validation set is used for code development and fine-tuning
- Test data set is held back through the entire process and only used to produce results for final publication
- Pool sample is the set of objects available to be queried during the active learning loop

In [6]:
# Set proportions for test and validation samples
# (fraction of the SALT2 survivors assigned to EACH of the two samples)
frac_test_val = 0.2

# Fixed seed so that re-running this cell reproduces the same split;
# a single generator is shared by both draws below.
random_seed = 42
rng = np.random.RandomState(random_seed)

# draw validation + test together, without replacement ...
val_test = fitres_all.sample(n=int(2 * frac_test_val * fitres_all.shape[0]),
                             replace=False, random_state=rng)

# ... then half of that joint draw becomes the validation sample
validation = val_test.sample(n=int(0.5 * val_test.shape[0]),
                             replace=False, random_state=rng)

# test: objects of the joint draw whose CID is not in validation
flag_test = ~val_test['CID'].isin(validation['CID']).values
test = val_test[flag_test]

# pool: every fitted object whose CID was not drawn at all
flag_pool = ~fitres_all['CID'].isin(val_test['CID']).values
pool = fitres_all[flag_pool]

# summary table; percentages are truncated to integers by int()
print(' *** Sample sizes *** \n')
print('Survived SALT2: ', fitres_all.shape[0], '  ( 100 %)')
print('    Validation: ', validation.shape[0], '   (', int(100 * validation.shape[0]/fitres_all.shape[0]), ' %)')
print('          Test: ', test.shape[0], '   (', int(100 * test.shape[0]/fitres_all.shape[0]), ' %)')
print('          Pool: ', pool.shape[0], '   (', int(100 * pool.shape[0]/fitres_all.shape[0]), ' %)')

 *** Sample sizes *** 

Survived SALT2:  10228   ( 100 %)
    Validation:  2045    ( 19  %)
          Test:  2046    ( 20  %)
          Pool:  6137    ( 60  %)


In [8]:
# Write each sample to its own CSV file, named
# <subsample>_<validation|test|pool>_fitres.csv, without the index column.
output_dir = '/media/RESSPECT/data/PLAsTiCC/for_pipeline/initial_samples/'

for _sample, _suffix in [(validation, '_validation_fitres.csv'),
                         (test, '_test_fitres.csv'),
                         (pool, '_pool_fitres.csv')]:
    _sample.to_csv(output_dir + subsample + _suffix, index=False)