In [None]:
import numpy as np
import matplotlib.pylab as plt
import pandas as pd
from glob import glob

In [None]:
# Lookup tables used to translate between class codes and class names.

# Integer class code -> class name. The codes are the values compared against
# the `true_target` column of the PLAsTiCC metadata further down.
types_names = {
    90: 'Ia',
    67: '91bg',
    52: 'Iax',
    42: 'II',
    62: 'Ibc',
    95: 'SLSN',
    15: 'TDE',
    64: 'KN',
    88: 'AGN',
    92: 'RRL',
    65: 'M-dwarf',
    16: 'EB',
    53: 'Mira',
    6: 'MicroL',
    991: 'MicroLB',
    992: 'ILOT',
    993: 'CART',
    994: 'PISN',
    995: 'MLString',
}

# Class code (same keys as `types_names`, except 995 which has no entry) ->
# either a single integer id or a {sub-index: id} dict when the class maps to
# several ids.
# NOTE(review): presumably these ids are SNANA type numbers -- confirm.
SNANA_types = {
    90: 11,
    62: {1: 3, 2: 13},
    42: {1: 2, 2: 12, 3: 14},
    67: 41,
    52: 43,
    64: 51,
    95: 60,
    994: 61,
    992: 62,
    993: 63,
    15: 64,
    88: 70,
    92: 80,
    65: 81,
    16: 83,
    53: 84,
    991: 90,
    6: {1: 91, 2: 93},
}

# Id (the values stored in `SNANA_types`) -> class name.
SNANA_names = {
    11: 'Ia',
    3: 'Ibc',
    13: 'Ibc',
    2: 'II',
    12: 'II',
    14: 'II',
    41: '91bg',
    43: 'Iax',
    51: 'KN',
    60: 'SLSN',
    61: 'PISN',
    62: 'ILOT',
    63: 'CART',
    64: 'TDE',
    70: 'AGN',
    80: 'RRL',
    81: 'M-dwarf',
    83: 'EB',
    84: 'Mira',
    90: 'MicroLB',
    91: 'MicroL',
    93: 'MicroL',
}

# DDF (Deep Drilling Fields)

In [None]:
# Classes shown in the paper for the DDF sample
sntypes_ddf = ['Ia', 'II', 'Ibc', 'Iax']

# Load the PLAsTiCC test-set metadata (zenodo release)
fname = '/media/RESSPECT/data/PLAsTiCC/PLAsTiCC_zenodo/plasticc_test_metadata.csv'
test_metadata = pd.read_csv(fname)

# Boolean mask over the metadata rows: True where `ddf_bool` equals 1
ddf_flag = test_metadata['ddf_bool'].values == 1

# True redshift (`true_z`) of the DDF objects, keyed by class name
raw_z_ddf = {}
for class_code in (42, 62, 52, 90):
    is_class = test_metadata['true_target'].values == class_code
    in_ddf_and_class = is_class & ddf_flag
    raw_z_ddf[types_names[class_code]] = test_metadata.loc[in_ddf_and_class, 'true_z'].values

# Pool every class listed in `sntypes_ddf` into a single flat list
pooled_ddf_z = []
for class_name in sntypes_ddf:
    pooled_ddf_z.extend(raw_z_ddf[class_name])
raw_z_ddf['all'] = pooled_ddf_z

In [None]:
# Report how the raw PLAsTiCC test sample splits between the two fields.
# The counts are computed once with np.count_nonzero: the builtin sum() walks
# a NumPy boolean array one element at a time in Python, and the original code
# repeated that walk for every printed number.
n_raw_total = test_metadata.shape[0]
n_raw_ddf = np.count_nonzero(ddf_flag)     # rows flagged as DDF
n_raw_wfd = np.count_nonzero(~ddf_flag)    # every other row

print('Total number objects in raw data:', n_raw_total)
print('    of which: ', n_raw_ddf, '   from DDF (', round(100 * n_raw_ddf / n_raw_total, 4),'%)')
print('         and: ', n_raw_wfd, ' from WFD (', round(100 * n_raw_wfd / n_raw_total, 4),'%)')

In [None]:
# Collect the SIM_ZCMB values of every DDF object that survived the SALT2 fit,
# for the classes used in the paper.
#   ddf_fitres_dict   : class name -> list of SIM_ZCMB values (plus an 'all' entry)
#   all_ddf_z         : the values of every class pooled together
#   skipped_ddf_files : fitres files pandas could not parse
ddf_fitres_dict = {}
all_ddf_z = []
skipped_ddf_files = []
for name in sntypes_ddf:
    # sorted() makes the read order independent of the filesystem listing order
    flist_ddf = sorted(glob('/media/RESSPECT/data/PLAsTiCC/for_metrics/final_data3/DDF/SALT2_fit/' +
                            name + '/fitres/master_*.fitres'))

    ddf_fitres_list = []
    # A dedicated loop variable is used so the metadata path stored in `fname`
    # by an earlier cell is not overwritten.
    for fitres_path in flist_ddf:
        try:
            # sep=r'\s+' is the documented equivalent of the deprecated
            # delim_whitespace=True
            fitres_ddf_temp = pd.read_csv(fitres_path, sep=r'\s+', comment='#')
        except ValueError:
            # pandas' EmptyDataError / ParserError derive from ValueError.
            # Such files are still skipped, but recorded instead of silently ignored.
            skipped_ddf_files.append(fitres_path)
            continue
        ddf_fitres_list.extend(fitres_ddf_temp['SIM_ZCMB'].values)

    # Classes without any surviving object get no entry in the dictionary
    if len(ddf_fitres_list) > 0:
        ddf_fitres_dict[name] = ddf_fitres_list
        all_ddf_z.extend(ddf_fitres_list)

ddf_fitres_dict['all'] = all_ddf_z

if skipped_ddf_files:
    print('WARNING:', len(skipped_ddf_files),
          'DDF fitres file(s) could not be parsed and were skipped (see skipped_ddf_files)')

In [None]:
# Collect the SIM_ZCMB values of every DDF object that survived the SALT2 fit,
# for the classes that did NOT make it to the paper plots.
#   ddf_fitres_dict2   : class name -> list of SIM_ZCMB values (plus an 'all' entry)
#   all_ddf_z2         : the values of every such class pooled together
#   skipped_ddf_files2 : fitres files pandas could not parse
ddf_fitres_dict2 = {}
all_ddf_z2 = []
skipped_ddf_files2 = []

# Every known class name except the ones already handled in `sntypes_ddf`
sntypes_ddf2 = [class_name for class_name in types_names.values()
                if class_name not in sntypes_ddf]

for name2 in sntypes_ddf2:
    # sorted() makes the read order independent of the filesystem listing order
    flist_ddf2 = sorted(glob('/media/RESSPECT/data/PLAsTiCC/for_metrics/final_data3/DDF/SALT2_fit/' +
                             name2 + '/fitres/master_*.fitres'))

    ddf_fitres_list2 = []
    # A dedicated loop variable is used so the metadata path stored in `fname`
    # by an earlier cell is not overwritten.
    for fitres_path2 in flist_ddf2:
        try:
            # sep=r'\s+' is the documented equivalent of the deprecated
            # delim_whitespace=True
            fitres_ddf_temp2 = pd.read_csv(fitres_path2, sep=r'\s+', comment='#')
        except ValueError:
            # pandas' EmptyDataError / ParserError derive from ValueError.
            # Such files are still skipped, but recorded instead of silently ignored.
            skipped_ddf_files2.append(fitres_path2)
            continue
        ddf_fitres_list2.extend(fitres_ddf_temp2['SIM_ZCMB'].values)

    # Classes without any surviving object get no entry in the dictionary
    if len(ddf_fitres_list2) > 0:
        ddf_fitres_dict2[name2] = ddf_fitres_list2
        all_ddf_z2.extend(ddf_fitres_list2)

ddf_fitres_dict2['all'] = all_ddf_z2

if skipped_ddf_files2:
    print('WARNING:', len(skipped_ddf_files2),
          'DDF fitres file(s) could not be parsed and were skipped (see skipped_ddf_files2)')

# WFD (Wide-Fast-Deep)

In [None]:
# Classes shown in the paper for the WFD sample
sntypes_wfd = ['Ia', 'II', 'Ibc', 'Iax', 'CART', 'SLSN']

# True redshift (`true_z`) of the non-DDF objects, keyed by class name
raw_z_wfd = {}
for class_code in (90, 42, 62, 52, 993, 95):
    is_class = test_metadata['true_target'].values == class_code
    in_wfd_and_class = is_class & ~ddf_flag
    raw_z_wfd[types_names[class_code]] = test_metadata.loc[in_wfd_and_class, 'true_z'].values

# Pool every class listed in `sntypes_wfd` into a single flat list
pooled_wfd_z = []
for class_name in sntypes_wfd:
    pooled_wfd_z.extend(raw_z_wfd[class_name])
raw_z_wfd['all'] = pooled_wfd_z

In [None]:
# Collect the SIM_ZCMB values of every WFD object that survived the SALT2 fit,
# for the classes used in the paper.
#   wfd_fitres_dict   : class name -> list of SIM_ZCMB values (plus an 'all' entry)
#   all_wfd_z         : the values of every class pooled together
#   skipped_wfd_files : fitres files pandas could not parse
wfd_fitres_dict = {}
all_wfd_z = []
skipped_wfd_files = []
for name in sntypes_wfd:
    # sorted() makes the read order independent of the filesystem listing order
    flist_wfd = sorted(glob('/media/RESSPECT/data/PLAsTiCC/for_metrics/final_data3/WFD/SALT2_fit/' +
                            name + '/fitres/master_*.fitres'))

    wfd_fitres_list = []
    # A dedicated loop variable is used so the metadata path stored in `fname`
    # by an earlier cell is not overwritten.
    for fitres_path in flist_wfd:
        try:
            # sep=r'\s+' is the documented equivalent of the deprecated
            # delim_whitespace=True
            fitres_wfd_temp = pd.read_csv(fitres_path, sep=r'\s+', comment='#')
        except ValueError:
            # pandas' EmptyDataError / ParserError derive from ValueError.
            # Such files are still skipped, but recorded instead of silently ignored.
            skipped_wfd_files.append(fitres_path)
            continue
        wfd_fitres_list.extend(fitres_wfd_temp['SIM_ZCMB'].values)

    # Classes without any surviving object get no entry in the dictionary
    if len(wfd_fitres_list) > 0:
        print(name, len(wfd_fitres_list))
        wfd_fitres_dict[name] = wfd_fitres_list
        all_wfd_z.extend(wfd_fitres_list)

wfd_fitres_dict['all'] = all_wfd_z

if skipped_wfd_files:
    print('WARNING:', len(skipped_wfd_files),
          'WFD fitres file(s) could not be parsed and were skipped (see skipped_wfd_files)')

In [None]:
# Collect the SIM_ZCMB values of every WFD object that survived the SALT2 fit,
# for the classes that did NOT make it to the paper plots.
#   wfd_fitres_dict2   : class name -> list of SIM_ZCMB values (plus an 'all' entry)
#   all_wfd_z2         : the values of every such class pooled together
#   skipped_wfd_files2 : fitres files pandas could not parse
wfd_fitres_dict2 = {}
all_wfd_z2 = []
skipped_wfd_files2 = []

# Every known class name except the ones already handled in `sntypes_wfd`
sntypes_wfd2 = [class_name for class_name in types_names.values()
                if class_name not in sntypes_wfd]

for name2 in sntypes_wfd2:
    # sorted() makes the read order independent of the filesystem listing order
    flist_wfd2 = sorted(glob('/media/RESSPECT/data/PLAsTiCC/for_metrics/final_data3/WFD/SALT2_fit/' +
                             name2 + '/fitres/master_*.fitres'))

    wfd_fitres_list2 = []
    for fname2 in flist_wfd2:
        try:
            # sep=r'\s+' is the documented equivalent of the deprecated
            # delim_whitespace=True
            fitres_wfd_temp2 = pd.read_csv(fname2, sep=r'\s+', comment='#')
        except ValueError:
            # pandas' EmptyDataError / ParserError derive from ValueError.
            # Such files are still skipped, but recorded instead of silently ignored.
            skipped_wfd_files2.append(fname2)
            continue
        wfd_fitres_list2.extend(fitres_wfd_temp2['SIM_ZCMB'].values)

    # Classes without any surviving object get no entry in the dictionary
    if len(wfd_fitres_list2) > 0:
        wfd_fitres_dict2[name2] = wfd_fitres_list2
        all_wfd_z2.extend(wfd_fitres_list2)

wfd_fitres_dict2['all'] = all_wfd_z2

if skipped_wfd_files2:
    print('WARNING:', len(skipped_wfd_files2),
          'WFD fitres file(s) could not be parsed and were skipped (see skipped_wfd_files2)')

In [None]:
# Number of objects with a SALT2 fit in each field: the classes used in the
# paper (`*_fitres_dict`) plus all remaining classes (`*_fitres_dict2`).
n_salt2fit_ddf = len(ddf_fitres_dict['all']) + len(ddf_fitres_dict2['all'])
n_salt2fit_wfd = len(wfd_fitres_dict['all']) + len(wfd_fitres_dict2['all'])
n_salt2fit = n_salt2fit_wfd + n_salt2fit_ddf

# Survivors as a percentage of the full raw sample
pct_of_raw = round(100 * n_salt2fit / test_metadata.shape[0], 4)
print('After SALT2 fit: ', n_salt2fit, ' (', pct_of_raw, '%)')

# Share of each field within the surviving population
pct_ddf = round(100 * n_salt2fit_ddf / n_salt2fit, 4)
print('    of which: ', n_salt2fit_ddf, ' (', pct_ddf, '% of SALT2fit pop from DDF)')
pct_wfd = round(100 * n_salt2fit_wfd / n_salt2fit, 4)
print('         and: ', n_salt2fit_wfd, ' (', pct_wfd, '% of SALT2fit pop from WFD)')

In [None]:
# Put the pooled 'all' sample at the front of each class list; the plotting
# cell indexes these lists by position.
# The guard makes the cell idempotent: without it, re-running the cell would
# prepend 'all' a second time and shift every class index used for plotting.
if 'all' not in sntypes_ddf:
    sntypes_ddf = ['all'] + sntypes_ddf
if 'all' not in sntypes_wfd:
    sntypes_wfd = ['all'] + sntypes_wfd

In [None]:
# Grid of histograms comparing, per field and class, the redshifts of the
# original PLAsTiCC objects with those of the objects that survived the SALT2 fit.
norm = True   # forwarded as `density=` to plt.hist and used to pick the output file name

plt.figure(figsize=(16,30))

# Panels of the 7x2 grid that hold DDF plots; 11 and 13 are left empty and
# every remaining panel holds a WFD plot.
plt_indx = [1,3,5,7,9]

for i in range(1, 15):
    if i in plt_indx:
        field_tag, class_names, class_pos = 'DDF', sntypes_ddf, i // 2
        raw_by_class, fit_by_class = raw_z_ddf, ddf_fitres_dict
    elif i not in [11, 13]:
        field_tag, class_names, class_pos = 'WFD', sntypes_wfd, i // 2 - 1
        raw_by_class, fit_by_class = raw_z_wfd, wfd_fitres_dict
    else:
        continue

    plt.subplot(7, 2, i)
    class_label = class_names[class_pos]
    hist_style = dict(bins=50, density=norm, alpha=0.5)
    plt.hist(raw_by_class[class_label], label='original PLAsTiCC', **hist_style)
    plt.hist(fit_by_class[class_label], label='after SALT2 fit', **hist_style)
    plt.xlabel('true_z', fontsize=14)
    plt.ylabel('N', fontsize=14)
    leg1 = plt.legend(title=field_tag + ' - ' + class_label, fontsize=12)
    leg1.get_title().set_fontsize('12')
    plt.xlim(0,2)

# NOTE(review): the file names are spelled 'redshit'; kept unchanged because
# other material may reference them -- confirm before renaming.
output_png = 'redshit_normalized.png' if norm else 'redshit_not_normalized.png'
plt.savefig(output_png)