In [None]:
%matplotlib inline

import glob
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.colors as colors
import matplotlib.cm as cmx
from mpl_toolkits.axes_grid1 import make_axes_locatable
from collections import OrderedDict

from resspect import cosmo_metric_utils as cmu
from copy import deepcopy
import arviz as az

In [None]:
# Run configuration. `nobjs` is kept as a string because it is spliced into paths.
nobjs = '3000'
field = 'WFD'

# One summary table per numbered realisation (v0 ... v9).
files = [
    f'/media/RESSPECT/data/PLAsTiCC/for_metrics/final_data3/{field}/results/v{version}/{nobjs}/summary_stats.csv'
    for version in range(10)
]

In [None]:
# Contaminant classes available for each field, in plotting order. Each entry is
# (token used in the case name, label prefix, contamination percentages).
if field == 'WFD':
    contaminants = [('SNII', 'SN-II', [25, 10, 5, 2, 1]),
                    ('SNIbc', 'SN-Ibc', [10, 5, 2, 1]),
                    ('SNIax', 'SN-Iax', [25, 10, 5, 2, 1]),
                    ('CART', 'CART', [2, 1]),
                    ('SLSN', 'SLSN', [2, 1])]
else:
    contaminants = [('SNII', 'SN-II', [10, 5, 2, 1]),
                    ('SNIbc', 'SN-Ibc', [5, 2, 1]),
                    ('SNIax', 'SN-Iax', [10, 5, 2, 1])]

# remap_dict: case name -> tick label, in plotting order.
# color_nums: contamination percentage of each case, in the same order
# (the three special cases are assigned 1).
# Both are filled by the same loop so they cannot drift out of step; they are
# zipped together downstream and zip() would silently drop unmatched entries.
remap_dict = OrderedDict([('perfect3000', 'Perfect'),
                          ('fiducial3000', 'Fiducial'),
                          ('random3000', 'Random')])
color_levels = [1, 1, 1]

for token, label_prefix, percentages in contaminants:
    for pct in percentages:
        # e.g. pct=25, token='SNII' -> '75SNIa25SNII': 'SN-II 25 %'
        case_name = str(100 - pct) + 'SNIa' + str(pct) + token
        remap_dict[case_name] = label_prefix + ' ' + str(pct) + ' %'
        color_levels.append(pct)

color_nums = np.array(color_levels)

# Marker shape per class; both spellings of a class name (case token and
# label prefix) are accepted as keys.
all_shapes = {'SLSN': 'o',
              'SNIax': 's',
              'SN-Iax': 's',
              'SNII': 'd',
              'SN-II': 'd',
              'SNIbc': 'X',
              'SN-Ibc': 'X',
              'AGN': '^',
              'CART': 'v',
              'perfect': 'P',
              'fiducial': 'p',
              'random': 'H' }

# Colour map: log-scaled 'plasma_r' between 1 and 30.
rainbow = cm = plt.get_cmap('plasma_r')
cNorm  = colors.LogNorm(vmin=1, vmax=30) # alternative: colors.Normalize(vmin=0, vmax=50)
scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=rainbow)
# color_map[k] holds the colour of the value k + 1, because np.arange starts at 1.
# NOTE(review): the plotting code indexes color_map with the raw percentage, so a
# 1 % case gets the colour computed for 2 -- confirm this offset is intended.
color_map = scalarMap.to_rgba(np.arange(1, 30))

In [None]:
# summary_dict[metric][case] -> list of values, one per numbered realisation
# in which the case is present.
summary_dict = {}
for metric in ['fom3', 'KLD', 'Wasserstein', 'FM', 'wfit', 'wfit_std', 'stan', 'stan_std']:
    summary_dict[metric] = {}

# j -> index from 0 to number of cases
# a -> case
# c -> contamination percentage. Cases with more than one contaminant are set to 1
for j, (a, c) in enumerate(zip(remap_dict, color_nums)):
    for metric in summary_dict:
        summary_dict[metric][a] = []

# Metrics copied verbatim from the summary tables: summary_dict key -> CSV column.
csv_columns = {'fom3': 'fom3',
               'Wasserstein': 'EMD',
               'wfit': 'wfit_w_lowz',
               'wfit_std': 'wfit_wsig_lowz',
               'stan': 'stan_w_lowz',
               'stan_std': 'stan_wsig_lowz'}

# Read every summary table once instead of once per case.
summary_tables = [pd.read_csv(f) for f in files]

for a in summary_dict['fom3']:
    for df_t in summary_tables:
        rows = df_t.loc[df_t['case'] == a]

        # A realisation that lacks this case contributes nothing, so the lists
        # can be shorter than the number of realisations.
        if rows.empty:
            continue

        for metric, column in csv_columns.items():
            summary_dict[metric][a].append(rows[column].values[0])

        # KLD is stored as log10 of the tabulated value.
        summary_dict['KLD'][a].append(np.log10(rows['KLD'].values[0]))

# Fisher-matrix metric: fractional difference with respect to the 'perfect'
# sample of the same realisation.
# NOTE(review): element [1] of the first item returned by cmu.fisher_results
# is presumably the uncertainty on w -- confirm against resspect.
for i in range(len(files)):
    dist_loc_base = '/media/RESSPECT/data/PLAsTiCC/for_metrics/final_data3/' + field + \
                    '/results/v' + str(i) + '/' + nobjs + '/stan_input/'

    # The reference depends only on the realisation, so evaluate it once per
    # realisation rather than once per case.
    df_fisher = pd.read_csv(dist_loc_base + 'stan_input_salt2mu_lowz_withbias_perfect' + nobjs + '.csv')
    sig_perf = cmu.fisher_results(df_fisher['z'].values, df_fisher['muerr'].values)[0]

    for a in summary_dict['FM']:
        df_ = pd.read_csv(dist_loc_base + 'stan_input_salt2mu_lowz_withbias_' + a + '.csv')
        sig = cmu.fisher_results(df_['z'].values, df_['muerr'].values)[0]
        summary_dict['FM'][a].append((sig[1] - sig_perf[1]) / sig_perf[1])

In [None]:
fig, axes = plt.subplots(1, 6, figsize=(22,10), sharey=True)

# Panels, left to right: FOM3, Fisher matrix, wfit, Stan (Bayes), KLD, Wasserstein.
ax3, ax4, ax1, ax2, ax6, ax5 = axes

# Vertical guide at -1 on the two w panels.
for w_panel in (ax1, ax2):
    w_panel.axvline(-1, color='saddlebrown', alpha=0.65)

perfect_case = 'perfect3000'

#### wfit / stan perfect
if field == 'DDF':
    # DDF uses a single realisation, selected by list position jj.
    jj = 3
    wfit_perf_mean = summary_dict['wfit'][perfect_case][jj]
    wfit_perf_std = summary_dict['wfit_std'][perfect_case][jj]
    stan_perf_mean =  summary_dict['stan'][perfect_case][jj]
    stan_perf_std = summary_dict['stan_std'][perfect_case][jj]
else:
    # Otherwise: mean over realisations, with the largest quoted uncertainty.
    wfit_perf_mean = np.mean(summary_dict['wfit'][perfect_case])
    wfit_perf_std = max(summary_dict['wfit_std'][perfect_case])
    stan_perf_mean =  np.mean(summary_dict['stan'][perfect_case])
    stan_perf_std = max(summary_dict['stan_std'][perfect_case])

#### FOM3 perfect
fom_perf_mean =  np.mean(summary_dict['fom3'][perfect_case])
fom_perf_std = np.std(summary_dict['fom3'][perfect_case])

##### Fisher perfect
fisher_perf_mean =  np.mean(summary_dict['FM'][perfect_case])
fisher_perf_std = np.std(summary_dict['FM'][perfect_case])

# Dash-dotted line at the 'perfect' value plus a grey band of +/- its spread.
for panel, centre, half_width in ((ax1, wfit_perf_mean, wfit_perf_std),
                                  (ax2, stan_perf_mean, stan_perf_std),
                                  (ax3, fom_perf_mean, fom_perf_std),
                                  (ax4, fisher_perf_mean, fisher_perf_std)):
    panel.axvline(centre, color='k', ls='-.')
    panel.axvspan(centre - half_width, centre + half_width, alpha=0.15, color='grey')

# Wasserstein: reference line at 0.
ax5.axvline(0, color='k', ls='-.')

# KLD: no reference line is drawn.
#ax6.axvline(1, color='k', ls='-.')

# Display bounds. A value beyond a bound is pinned to it and drawn as a limit arrow.
W_FLOOR = -1.3       # wfit and Stan panels (both panels now share this threshold)
EMD_CEILING = 0.35   # Wasserstein / EMD panel
KLD_CEILING = 4.35   # log(KLD) panel
ROW_STEP = 0.8       # vertical distance between consecutive rows


def _central(metric, case):
    """Return the plotted value of `metric` for `case`.

    DDF uses the single realisation at position `jj`; any other field uses
    the mean over realisations.
    """
    values = summary_dict[metric][case]
    return values[jj] if field == 'DDF' else np.mean(values)


def _largest(metric, case):
    """Return the plotted uncertainty of `metric` for `case`.

    DDF uses the single realisation at position `jj`; any other field uses
    the largest value over realisations.
    """
    values = summary_dict[metric][case]
    return values[jj] if field == 'DDF' else max(values)


def _draw_w(ax, w, w_sig, y, color, marker, mfc, ms=10):
    """Draw one w estimate and its +/- w_sig bar on row `y` of `ax`.

    A value below W_FLOOR is drawn at W_FLOOR as an upper-limit arrow instead.
    """
    if w < W_FLOOR:
        ax.errorbar([W_FLOOR], [y], xerr=[0.03], marker=marker, color=color,
                    xuplims=True, markersize=10, mfc=mfc)
    else:
        ax.plot(w, y, marker=marker, ls='none', color=color, ms=ms, mfc=mfc)
        ax.plot([w - w_sig, w + w_sig], [y, y], "|-", color=color, ms=10)


def _draw_metric(ax, x, y, span, ceiling, arrow, color, marker, mfc):
    """Draw one metric value on row `y` of `ax`.

    span    -- [min, max] over realisations drawn as a line, or None for no line.
    ceiling -- if not None and x exceeds it, draw a lower-limit arrow of length
               `arrow` at the ceiling instead of the marker and span.
    """
    if ceiling is not None and x > ceiling:
        ax.errorbar([ceiling], [y], xerr=[arrow], marker=marker, color=color,
                    xlolims=True, markersize=10, mfc=mfc)
        return
    if span is not None:
        ax.plot(span, [y, y], color=color, ms=10)
    ax.plot(x, y, marker=marker, ls='none', color=color, ms=10, mfc=mfc)


i = 0            # running vertical offset; rows are drawn at y = -i
tick_lbls = []   # y tick labels, top to bottom ('' marks a spacer row)
i_list = []      # y position of every tick, top to bottom
for j, (a, c) in enumerate(zip(remap_dict, color_nums)):

    y = -i
    class_ = str.split(remap_dict[a])[0]

    # ---- values to plot -------------------------------------------------
    wfw = _central('wfit', a)
    wfw_sig = _largest('wfit_std', a)
    bw = _central('stan', a)
    bw_sig = _largest('stan_std', a)

    fom3 = np.mean(summary_dict['fom3'][a])
    fom3_sig = np.std(summary_dict['fom3'][a], ddof=1)

    fm = np.nanmean(summary_dict['FM'][a])
    fm_span = [np.min(summary_dict['FM'][a]), np.max(summary_dict['FM'][a])]

    wsd = _central('Wasserstein', a)
    kld = _central('KLD', a)

    # The min-max range over realisations is only shown for WFD.
    if field == 'WFD':
        wsd_span = [np.min(summary_dict['Wasserstein'][a]), np.max(summary_dict['Wasserstein'][a])]
        kld_span = [np.min(summary_dict['KLD'][a]), np.max(summary_dict['KLD'][a])]
    else:
        wsd_span = None
        kld_span = None

    # ---- style of this row ----------------------------------------------
    if 'fiducial' in a:
        kind, color = 'fiducial', 'c'
    elif 'random' in a:
        kind, color = 'random', 'magenta'
    elif 'perfect' in a:
        kind, color = 'perfect', 'k'
    else:
        kind, color = class_, color_map[c]

    marker = all_shapes[kind]
    # Hollow markers for WFD, filled markers otherwise.
    mfc = "none" if field == 'WFD' else color

    # ---- wfit and Stan/Bayes ---------------------------------------------
    # Random and perfect markers are slightly larger on the wfit panel only
    # (kept from the original figure).
    wfit_ms = 12 if kind in ('random', 'perfect') else 10
    _draw_w(ax1, wfw, wfw_sig, y, color, marker, mfc, ms=wfit_ms)
    _draw_w(ax2, bw, bw_sig, y, color, marker, mfc)

    # ---- FOM3 ------------------------------------------------------------
    ax3.plot(fom3, y, marker=marker, ls='none', color=color, ms=10, mfc=mfc)
    ax3.plot([fom3 - fom3_sig, fom3 + fom3_sig], [y, y], "|-", color=color, ms=10)

    # ---- Fisher, Wasserstein, KLD ------------------------------------------
    if kind == 'perfect':
        # The reference case sits at 0 on the Fisher and Wasserstein panels
        # and has no KLD marker.
        ax4.plot(0, y, marker=marker, ls='none', color=color, ms=10, mfc=mfc)
        ax5.plot(0, y, marker=marker, ls='none', color=color, ms=10, mfc=mfc)
    else:
        # NOTE(review): the fiducial case is never clipped on the EMD and KLD
        # panels (kept as in the original) -- confirm this is intended.
        clip = kind != 'fiducial'
        _draw_metric(ax4, fm, y, fm_span, None, None, color, marker, mfc)
        _draw_metric(ax5, wsd, y, wsd_span, EMD_CEILING if clip else None, 0.03,
                     color, marker, mfc)
        _draw_metric(ax6, kld, y, kld_span, KLD_CEILING if clip else None, 0.5,
                     color, marker, mfc)

    # ---- tick bookkeeping --------------------------------------------------
    tick_lbls.append(remap_dict[a])
    i_list.append(-i)
    i += ROW_STEP

    # Leave an empty spacer row after the special cases and after each class
    # (the 1 % case is the last one of every class).
    if 'random' in a or '99SNIa1' in a:
        i_list.append(-i)
        i += ROW_STEP
        tick_lbls.append('')
        
fs = 18    # size of x-axis labels
ts = 14   # size of x-axis ticks

# i_list / tick_lbls were filled top to bottom; reverse them so the tick
# positions are ascending in y.
tick_locs = i_list[::-1]
ax1.set_yticks(tick_locs)
ax1.set_yticklabels(tick_lbls[::-1])
ax1.set_ylim(i_list[-1]-0.5, i_list[0]+0.5)

for ax, xlabel in ((ax1, r'wfit'),
                   (ax2, r'StanIa'),
                   (ax3, 'FOM3'),
                   (ax4, 'FM Fractional Difference'),
                   (ax5, 'EMD'),
                   (ax6, 'log(KLD)')):
    ax.set_xlabel(xlabel, fontsize=fs)
    ax.tick_params(labelsize=ts)

plt.subplots_adjust(bottom=0.15, wspace=0.3) # wspace=0.05

# Horizontal ranges; only the upper limit of the Stan panel depends on the field.
ax1.set_xlim(-1.35, -0.95)
ax2.set_xlim(-1.35, -0.95 if field == 'WFD' else -0.9)
ax3.set_xlim(-0.05, 1.05)
ax4.set_xlim(-0.075, 0.35)
ax5.set_xlim(-0.01, 0.4)
ax6.set_xlim(-0.1, 5)

# Spacer rows carry an empty label. Their tick indices are derived from the
# labels instead of being listed by hand, so every spacer is covered whatever
# the set of cases is.
ticks = [k for k, lbl in enumerate(tick_lbls[::-1]) if lbl == '']

# Shaded bands, given as (ymin, ymax) in axes-fraction units; hand-tuned per field.
if field == 'WFD':
    bands = [(0.65, 0.87), (0.25, 0.45), (0.04, 0.13)]
    band_xmax = 2e10
else:
    bands = [(0.5, 0.815), (0.05, 0.275)]
    band_xmax = 5e10

for ax in axes:
    # Hide the tick marks of the spacer rows.
    yticks = ax.yaxis.get_major_ticks()
    for t in ticks:
        yticks[t].set_visible(False)

    # The x extent is deliberately far wider than any panel so each band
    # spans the full width.
    for ymin, ymax in bands:
        ax.axvspan(-10, band_xmax, ymin=ymin, ymax=ymax, alpha=0.08, color='tab:purple')


#plt.savefig('combined_metrics_' + field + '_all_versions.pdf', bbox_inches='tight')
plt.show()