In [None]:
import pandas as pd
import pickle 
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import kde
from copy import deepcopy
import numpy as np
from collections import OrderedDict

In [None]:
# Marker shape per transient class (not used elsewhere in this notebook view).
# NOTE(review): the original literal listed 'SLSN' twice ('o' first, then 'v');
# Python keeps only the last value, so 'o' was never in effect. The duplicate
# has been dropped, keeping SLSN in its original (first) position with the
# value it actually had. If 'o' was meant for a different class, add that
# class explicitly — TODO confirm.
# NOTE(review): 't' is not a standard matplotlib marker code; confirm before
# passing all_shapes['CART'] as a `marker=` argument.
all_shapes = {'SLSN': 'v',
              'SNIax': 's',
              'SNII': 'd',
              'SNIbc': 'X',
              'AGN': '^',
              'TDE': '<',
              'KN': '>',
              'CART': 't'}

# Opacity per simulation case, keyed by lower-case field name. Within each
# contaminant family the opacity falls as the contamination fraction falls
# (25 %/10 % -> 1/0.9 down to 1 % -> 0.45); reference cases are fully opaque.
alpha_dict = dict()
alpha_dict['wfd']= {          'perfect3000': 1, 
                              'fiducial3000': 1, 
                              'random3000': 1,
                              '75SNIa25SNII': 1, 
                              '90SNIa10SNII': 0.9,
                              '95SNIa5SNII': 0.75,
                              '98SNIa2SNII': 0.6,
                              '99SNIa1SNII': 0.45,
                              '90SNIa10SNIbc': 0.9,
                              '95SNIa5SNIbc': 0.75,
                              '98SNIa2SNIbc': 0.6,
                              '99SNIa1SNIbc': 0.45,
                              '75SNIa25SNIax': 1,
                              '90SNIa10SNIax': 0.9,
                              '95SNIa5SNIax': 0.75,
                              '98SNIa2SNIax': 0.6,
                              '99SNIa1SNIax': 0.45,
                              '98SNIa2CART': 0.6,
                              '99SNIa1CART': 0.45,
                              '98SNIa2SLSN': 0.6,
                              '99SNIa1SLSN': 0.45
                  }
# DDF has fewer simulated cases (no 25 % cases, no 10 % SN-Ibc, no CART/SLSN).
alpha_dict['ddf'] = {
                          'perfect3000': 1, 
                          'fiducial3000': 1, 
                          'random3000': 1,
                          '90SNIa10SNII': 0.9,
                          '95SNIa5SNII': 0.75,
                          '98SNIa2SNII': 0.6,
                          '99SNIa1SNII': 0.45,
                          '95SNIa5SNIbc': 0.75,
                          '98SNIa2SNIbc': 0.6,
                          '99SNIa1SNIbc': 0.45,
                          '90SNIa10SNIax': 0.9,
                          '95SNIa5SNIax': 0.75,
                          '98SNIa2SNIax': 0.6,
                          '99SNIa1SNIax': 0.45}

def make_remap_dict_perc(file_extension):
    """Return display labels (class name plus contamination percentage) per case.

    Parameters
    ----------
    file_extension : str
        Survey field name. Exactly ``'WFD'`` selects the full WFD case list;
        any other value selects the reduced DDF case list.

    Returns
    -------
    OrderedDict
        Maps case identifiers such as ``'98SNIa2SNII'`` to labels such as
        ``'SN-II 2 %'``, in display order.
    """
    # Every known case with its label, in the display order used for WFD.
    labelled_cases = [
        ('perfect3000', 'Perfect'),
        ('fiducial3000', 'Fiducial'),
        ('random3000', 'Random'),
        ('75SNIa25SNII', 'SN-II 25 %'),
        ('90SNIa10SNII', 'SN-II 10 %'),
        ('95SNIa5SNII', 'SN-II 5 %'),
        ('98SNIa2SNII', 'SN-II 2 %'),
        ('99SNIa1SNII', 'SN-II 1 %'),
        ('90SNIa10SNIbc', 'SN-Ibc 10 %'),
        ('95SNIa5SNIbc', 'SN-Ibc 5 %'),
        ('98SNIa2SNIbc', 'SN-Ibc 2 %'),
        ('99SNIa1SNIbc', 'SN-Ibc 1 %'),
        ('75SNIa25SNIax', 'SN-Iax 25 %'),
        ('90SNIa10SNIax', 'SN-Iax 10 %'),
        ('95SNIa5SNIax', 'SN-Iax 5 %'),
        ('98SNIa2SNIax', 'SN-Iax 2 %'),
        ('99SNIa1SNIax', 'SN-Iax 1 %'),
        ('98SNIa2CART', 'CART 2 %'),
        ('99SNIa1CART', 'CART 1 %'),
        ('98SNIa2SLSN', 'SLSN 2 %'),
        ('99SNIa1SLSN', 'SLSN 1 %'),
    ]
    # Cases that exist only for WFD. The DDF table is the WFD table with these
    # removed; the relative order of the remaining cases is unchanged.
    wfd_only_cases = {
        '75SNIa25SNII', '90SNIa10SNIbc', '75SNIa25SNIax',
        '98SNIa2CART', '99SNIa1CART', '98SNIa2SLSN', '99SNIa1SLSN',
    }

    if file_extension != 'WFD':
        labelled_cases = [(case, label) for case, label in labelled_cases
                          if case not in wfd_only_cases]
    return OrderedDict(labelled_cases)

def make_remap_dict_small(file_extension):
    """Return short display labels (class name only) per simulation case.

    Used for legend entries, where cases of the same contaminant class at
    different fractions share one label.

    Parameters
    ----------
    file_extension : str
        Survey field name. Exactly ``'WFD'`` selects the full WFD case list;
        any other value selects the reduced DDF case list.

    Returns
    -------
    OrderedDict
        Maps case identifiers such as ``'98SNIa2SNII'`` to labels such as
        ``'SN-II'``, in display order.
    """
    # Single source of truth for labels. The DDF table is derived from it by
    # filtering, so the two fields can no longer disagree on a label (the DDF
    # branch previously mislabelled '95SNIa5SNIbc' as 'SN-Ib').
    labelled_cases = [
        ('perfect3000', 'Perfect'),
        ('fiducial3000', 'Fiducial'),
        ('random3000', 'Random'),
        ('75SNIa25SNII', 'SN-II'),
        ('90SNIa10SNII', 'SN-II'),
        ('95SNIa5SNII', 'SN-II'),
        ('98SNIa2SNII', 'SN-II'),
        ('99SNIa1SNII', 'SN-II'),
        ('90SNIa10SNIbc', 'SN-Ibc'),
        ('95SNIa5SNIbc', 'SN-Ibc'),
        ('98SNIa2SNIbc', 'SN-Ibc'),
        ('99SNIa1SNIbc', 'SN-Ibc'),
        ('75SNIa25SNIax', 'SN-Iax'),
        ('90SNIa10SNIax', 'SN-Iax'),
        ('95SNIa5SNIax', 'SN-Iax'),
        ('98SNIa2SNIax', 'SN-Iax'),
        ('99SNIa1SNIax', 'SN-Iax'),
        ('98SNIa2CART', 'CART'),
        ('99SNIa1CART', 'CART'),
        ('98SNIa2SLSN', 'SLSN'),
        ('99SNIa1SLSN', 'SLSN'),
    ]
    # Cases that exist only for WFD; DDF keeps the rest in the same order.
    wfd_only_cases = {
        '75SNIa25SNII', '90SNIa10SNIbc', '75SNIa25SNIax',
        '98SNIa2CART', '99SNIa1CART', '98SNIa2SLSN', '99SNIa1SLSN',
    }

    if file_extension != 'WFD':
        labelled_cases = [(case, label) for case, label in labelled_cases
                          if case not in wfd_only_cases]
    return OrderedDict(labelled_cases)

# Per-field colour map for each simulation case, loaded from a local pickle.
# The `with` block closes the file handle (the original left it open).
# NOTE(review): pickle.load can execute arbitrary code; only load a trusted
# colors.pkl.
with open("colors.pkl", "rb") as a_file:
    contaminant_colors = pickle.load(a_file)
# Draw the 'perfect' reference case in black in both fields.
contaminant_colors['wfd']['perfect3000'] = 'black'
contaminant_colors['ddf']['perfect3000'] = 'black'
# Upper-case aliases point at the same dict objects, so plotting code can index
# by the 'DDF'/'WFD' field names used elsewhere in the notebook.
contaminant_colors['DDF'] = contaminant_colors['ddf']
contaminant_colors['WFD'] = contaminant_colors['wfd']


In [None]:
# Version number of the results directory ("results/v<N>/") to read, per field
# and per contamination group index.
k = {}
k['DDF'] = [7, 7, 7]
k['WFD'] = [5, 5, 5]

# groups[field][i]: cases drawn together in one panel. Each group shares one
# contamination fraction (1 %, 2 %, 5 %) and always includes 'perfect3000' as
# a reference.
groups = {}
groups['DDF'] = {}
groups['DDF'][0] = ['99SNIa1SNII', '99SNIa1SNIax', '99SNIa1SNIbc', 'perfect3000']
groups['DDF'][1] = ['98SNIa2SNII', '98SNIa2SNIax', '98SNIa2SNIbc', 'perfect3000']
groups['DDF'][2] = ['95SNIa5SNII', '95SNIa5SNIax', '95SNIa5SNIbc', 'perfect3000']

groups['WFD'] = {}
groups['WFD'][0] = ['99SNIa1SNII', '99SNIa1SNIax', '99SNIa1SNIbc', '99SNIa1CART', '99SNIa1SLSN', 'perfect3000']
groups['WFD'][1] = ['98SNIa2SNII', '98SNIa2SNIax', '98SNIa2SNIbc', '98SNIa2CART', '98SNIa2SLSN', 'perfect3000']
groups['WFD'][2] = ['95SNIa5SNII', '95SNIa5SNIax', '95SNIa5SNIbc', 'perfect3000']

# Root of the per-field posterior outputs.
base_dir = '/media/RESSPECT/data/PLAsTiCC/for_metrics/final_data3/'
# Rows drawn (without replacement) from each chain file; read_csv + sample
# raises if a chain has fewer rows than this.
n_samples = 2000

data_temp = []

for field in ['WFD', 'DDF']:
    # The label map depends only on the field, so build it once per field
    # rather than once per case.
    remaps = make_remap_dict_perc(field)
    for i in range(3):
        for case in groups[field][i]:
            fname = (base_dir + field + '/results/v' + str(k[field][i]) + '/'
                     + '3000/posteriors/csv/chains_' + case + '_lowz_withbias.csv.gz')

            # NOTE(review): 'perfect3000' appears in all three groups, so its
            # chain is read and sampled once per group and all three samples
            # end up in data_all — confirm that is intended.
            data = pd.read_csv(fname, index_col=False)
            data['case'] = case
            data['case_label'] = remaps[case]
            data['field'] = field
            data_temp.append(data.sample(n=n_samples, replace=False))

data_all = pd.concat(data_temp, ignore_index=True)

In [None]:
fs_axlabel = 22
fs_legend = 22
fs_ticks = 16

# Contamination fraction shown in each panel's annotation (row-major order).
percs = ['1 %', '2 %', '5 %', '1 %', '2 %', '5 %']

legend_objs = []
legends_done = []
axs = {}

# Legend label of the fiducial-cosmology marker; also used to order the legend.
fiducial_label = r'w=-1, $\Omega_m=0.3$'

plt.figure(figsize=(24,10))

for j in range(6):
    # Top row (j = 0, 1, 2) shows DDF and the bottom row (j = 3, 4, 5) shows WFD.
    # Within a row, column j % 3 selects the contamination group.
    field = 'DDF' if j < 3 else 'WFD'
    panel_cases = groups[field][j % 3]

    flag_field = data_all['field'].values == field

    axs[j] = plt.subplot(2, 3, j + 1)
    # Fiducial cosmology marker (w = -1, Omega_m = 0.3).
    axs[j].scatter([0.3], [-1], marker='*', s=80, color='black', zorder=4)

    for i, case_name in enumerate(panel_cases):
        flag_plot = np.logical_and(flag_field, data_all['case'].values == case_name)

        # The perfect sample is drawn as a faint dashed reference contour.
        if 'perfect' in case_name:
            alpha = 0.25
            line = "--"
        else:
            alpha = 1
            line = '-'

        # Single contour at level 0.05. Per seaborn's docs, levels are
        # iso-proportions of the density, so this contour encloses ~95 % of
        # the mass.
        sns.kdeplot(
            data=data_all[flag_plot], x="om", y="w", levels=[0.05], ax=axs[j],
            color=contaminant_colors[field][case_name],
            linewidths=3, zorder=i + 1, linestyles=line, alpha=alpha,
        )
    axs[j].set_xlim(0.25, 0.35)
    axs[j].set_ylim(-1.5, -0.85)

    # Only the first column carries y ticks/label; only the bottom row carries x.
    if j % 3 == 0:
        axs[j].set_ylabel('w',  fontsize=fs_axlabel)
        axs[j].set_yticks([-1.5, -1.3, -1.1, -0.9])
        axs[j].tick_params(axis='y', labelsize=fs_ticks)
    else:
        axs[j].set_yticks([])
        axs[j].set_ylabel('')
    if j < 3:
        axs[j].set_xticks([])
        axs[j].set_xlabel('')
    else:
        axs[j].set_xticks([0.26, 0.28, 0.3, 0.32, 0.34])
        axs[j].tick_params(axis='x', labelsize=fs_ticks)
        axs[j].set_xlabel(r'$\Omega_m$', fontsize=fs_axlabel)

    axs[j].text(0.255, -1.45, field + ' - ' + percs[j], fontsize=fs_axlabel)

# Legend proxies on the top-right panel: one empty line per distinct case
# (first occurrence wins), styled like its contour.
for field in ['DDF', 'WFD']:
    short_labels = make_remap_dict_small(field)
    # The group index is deliberately NOT named `k`: that name holds the
    # per-field version dict read by the data-loading cell, and clobbering it
    # breaks re-running that cell.
    for g in range(3):
        for a in groups[field][g]:
            if a in legends_done:
                continue
            if 'perfect' in a:
                alpha = 0.25
                line = "--"
            else:
                alpha = 1
                line = '-'

            legends_done.append(a)
            l1 = axs[2].plot([], [], color=contaminant_colors[field][a],
                             label=short_labels[a], alpha=alpha, ls=line, lw=3)
            legend_objs.append(l1)

# One proxy for the fiducial marker, added once after both fields.
l1 = axs[2].scatter([], [], marker='*', s=80, color='black', label=fiducial_label, zorder=4)

handles, labels = axs[2].get_legend_handles_labels()
# Several cases share a short label (e.g. 'SN-II' at 1, 2 and 5 %);
# list.index picks the first handle carrying each label.
legend_order = ['SN-II', 'SN-Iax', 'SN-Ibc', 'CART', 'SLSN', 'Perfect', fiducial_label]
order = [labels.index(name) for name in legend_order]

axs[2].legend([handles[idx] for idx in order], [labels[idx] for idx in order],
              frameon=False, bbox_to_anchor=(1.75, 1.), labelspacing=0.9, fontsize=fs_legend)

plt.subplots_adjust(hspace=0., wspace=0., right=0.75)
#plt.show()
plt.savefig('om_w.pdf', bbox_inches='tight')