In [1]:
import os
import sys
import numpy as np
import swyft
import pickle
import matplotlib.pyplot as plt
import torch
import importlib
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import LearningRateMonitor
torch.set_float32_matmul_precision('medium')
device_notebook = "cuda" if torch.cuda.is_available() else "cpu"
import wandb
import copy
from torch.multiprocessing import Pool
torch.multiprocessing.set_start_method('spawn',force=True)
torch.set_num_threads(28)
import itertools
import subprocess
from tqdm.auto import tqdm
sys.path.append('/home/gertwk/ALPs_with_SWYFT/analysis_scripts/ALP_sim')
from explim_functions import generate_expected_limits
import sympy as sy
from scipy.stats import norm, lognorm
stdnorm = norm()
from swyft.plot.mass import _get_jefferys_interval as interval
import random
import matplotlib.gridspec as gridspec
from matplotlib.colors import to_rgb
from matplotlib.ticker import MaxNLocator

In [2]:
main_dir = "ALPs_with_SWYFT"
# <everything before the project folder in the cwd>/ALPs_with_SWYFT/thesis_figures/
thesis_figs = "/".join([os.getcwd().split(main_dir)[0], main_dir, "thesis_figures", ""])

In [3]:
names = ['flare0_informed',]
colors_priors = ['r','#FFA500','y','g','b', ]

# Shared copy of the ALP simulator; it is taken off sys.path while each run's own copy is imported.
ALP_SIM_DIR = '/home/gertwk/ALPs_with_SWYFT/analysis_scripts/ALP_sim'

priors = {}
for ip, name in enumerate(names):

    prior = priors[name] = {'name': name}

    # Locations of everything the analysis run `name` wrote to disk.
    prior['results_path'] = '/home/gertwk/ALPs_with_SWYFT/cluster_runs/analysis_results/'+name
    prior['store_path'] = prior['results_path']+"/sim_output/store"
    prior['config_vars'] = prior['results_path'] +'/config_variables.pickle'
    prior['config_phys'] = prior['results_path'] +'/physics_variables.pickle'
    prior['truncation_record'] = prior['results_path'] +'/truncation_record.pickle'

    # Import param_function / ALP_quick_sim from the run directory rather than the shared copy:
    # drop the shared directory from sys.path and forget any cached ALP_quick_sim module.
    # NOTE(review): presumably needed so the pickles below resolve against the run's code version.
    removed_ALP_sim = 0
    try:
        sys.path.remove(ALP_SIM_DIR)
        removed_ALP_sim = 1
    except ValueError:
        pass
    sys.modules.pop('ALP_quick_sim', None)
    sys.path.append(prior['results_path'])
    import param_function
    import ALP_quick_sim
    # Merge the three configuration pickles into this run's dict (later files win on clashes).
    # pickle.load can execute arbitrary code: only load run directories you trust.
    for path_key in ('config_vars', 'config_phys', 'truncation_record'):
        with open(prior[path_key], 'rb') as file:
            prior.update(pickle.load(file))
    sys.path.remove(prior['results_path'])
    sys.path.append(prior['results_path']+'/train_output/net')
    import network
    sys.path.remove(prior['results_path']+'/train_output/net')

    # The hyperparameter grid is the Cartesian product of the per-hyperparameter value lists,
    # enumerated in itertools.product order; pick the configured grid point explicitly so a bad
    # index fails loudly instead of leaving hyperparams_point undefined or stale from a prior run.
    hyperparam_names = list(prior['hyperparams'].keys())
    hyperparam_grid = list(itertools.product(*prior['hyperparams'].values()))
    grid_point = prior['which_grid_point']
    if not 0 <= grid_point < len(hyperparam_grid):
        raise IndexError(
            f"which_grid_point={grid_point} is outside the {len(hyperparam_grid)}-point "
            f"hyperparameter grid of run {name!r}"
        )
    hyperparams_point = dict(zip(hyperparam_names, hyperparam_grid[grid_point]))

    # One trained network per truncation round 0..which_truncation.
    prior['net_path'] = {}
    prior['net'] = {}
    for rnd in range(prior['which_truncation']+1):
        round = 'round_'+str(rnd)
        prior['net_path'][round] = (prior['results_path'] + '/train_output/net/trained_network_'
                                    + round + '_gridpoint_' + str(grid_point) + '.pt')
        prior['net'][round] = network.NetworkCorner(
            nbins=prior['A'].nbins,
            marginals=prior['POI_indices'],
            param_names=prior['A'].param_names,
            **hyperparams_point,
        )
        prior['net'][round].load_state_dict(torch.load(prior['net_path'][round]))

    with open(prior['results_path']+'/explim_predictions.pickle', 'rb') as file:
        prior['predictions'] = pickle.load(file)

    # Stores of truncated rounds carry a round/grid-point suffix; round-0 stores have none.
    store_base = prior['store_path'] + "/" + prior['store_name']
    if prior['which_truncation'] > 0:
        store_suffix = "_round_"+str(prior['which_truncation'])+"_gridpoint_"+str(grid_point)
    else:
        store_suffix = ""
    store = swyft.ZarrStore(store_base + store_suffix)
    store_explim = swyft.ZarrStore(store_base + "_explim" + store_suffix)
    store_prior = swyft.ZarrStore(store_base + "_prior" + store_suffix)
    prior['samples'] = store.get_sample_store()
    prior['samples_explim'] = store_explim.get_sample_store()
    prior['samples_prior'] = store_prior.get_sample_store()

    # Forget the run-specific modules so the next run imports its own copies, then restore the
    # shared simulator directory if it was removed above.
    del sys.modules['param_function']
    del sys.modules['ALP_quick_sim']
    del sys.modules['network']
    if removed_ALP_sim: sys.path.append(ALP_SIM_DIR)




In [4]:
# swyft trainer used for every network evaluation below: GPU, 64-bit precision, no logger.
trainer = swyft.SwyftTrainer(
    accelerator='cuda',
    precision=64,
    logger=False,
)

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [5]:
# Make the shared ALP simulator directory importable (DRP_test / reference_functions below)
# without stacking a duplicate sys.path entry every time this cell is re-run.
alp_sim_dir = '/home/gertwk/ALPs_with_SWYFT/analysis_scripts/ALP_sim'
if alp_sim_dir not in sys.path:
    sys.path.append(alp_sim_dir)

In [6]:
def convert_pair_to_index(pair,n_indices):
    """Position of an unordered index pair in the row-by-row upper triangle of n_indices items.

    Pairs (i, j) with i < j are enumerated (0,1), (0,2), ..., (0,n-1), (1,2), ... -> 0, 1, 2, ...
    The pair is sorted first, so (j, i) maps to the same position as (i, j).
    """
    ordered = sorted(pair)
    low, high = ordered[0], ordered[1]
    # (low+1)*(2n-low-2)/2 pairs lie in rows 0..low; stepping back n and forward `high`
    # lands on (low, high).
    return int((low + 1) * (2 * n_indices - low - 2) / 2 - n_indices + high)

In [7]:
def weight(exp,n_bins):
    """Raised-cosine window sampled at n_bins points on [-1, 1], warped by a power law.

    Each grid point x is mapped to sign(x)*|x|**exp and the window 0.5 + 0.5*cos(pi*.) is
    evaluated there: it peaks at 1 for x = 0 and falls to 0 at x = +-1; `exp` shapes the fall-off.
    """
    grid = np.linspace(-1, 1, n_bins)
    phase = np.pi * np.sign(grid) * np.abs(grid) ** exp
    return 0.5 + 0.5 * np.cos(phase)

In [8]:
def blend(color1, color2, amount=0.5):
    """Mix two colours linearly: amount*color1 + (1-amount)*color2, returned as a tuple."""
    first = np.asarray(color1)
    second = np.asarray(color2)
    mixed = first * amount + second * (1 - amount)
    return tuple(mixed)

In [9]:
def p_to_z(x):
    """Two-sided probability x -> z such that P(|Z| < z) = x for a standard normal Z."""
    upper_quantile = 0.5 + x / 2
    return stdnorm.ppf(upper_quantile)

def z_to_p(x):
    """Standard-normal probability mass inside [-x, x] (inverse of p_to_z)."""
    cdf_upper = stdnorm.cdf(x)
    cdf_lower = stdnorm.cdf(-x)
    return cdf_upper - cdf_lower

In [10]:
# Forget cached copies so edits to these helper modules are picked up when the cell is re-run.
for _module_name in ('DRP_test', 'reference_functions'):
    sys.modules.pop(_module_name, None)
from DRP_test import get_drp_coverage, get_drp_coverage_torch, draw_DRP_samples_fast
from reference_functions import References
R = References()
references2D = R.references2D

In [11]:
# One matplotlib colour code per truncation round: red, yellow, green, blue.
round_colors = list('rygb')

In [12]:
# Cache directory (and sub-folder name) for the DRP/HPD pickles produced below.
storage_location, storage_identifier = (
    '/storage/gertwk/ALPs_with_SWYFT/notebooks/thesis_results',
    'DRP_plots_new_references-Copy1-6',
)

In [13]:
%%time
n_samps = 500_000        # take the last n_samps samples of the run's sample store
n_prior_samps = 20_173   # take the first n_prior_samps samples of the prior store
n_refs = 100             # number of reference-point sets drawn for the DRP test

name = names[0]
samples = priors[name]['samples'][-n_samps:]
prior_samples = priors[name]['samples_prior'][:n_prior_samps]
# Truncation round and grid point come from the run's own records, not from a hard-coded value.
which_truncation = priors[name]['which_truncation']
which_grid_point = priors[name]['which_grid_point']
POIs = priors[name]['POI_indices']
A = priors[name]['A']
bounds = np.array(priors[name]['bounds_rounds'][which_grid_point][which_truncation])
# Actual sizes (can be smaller than requested if the stores hold fewer samples).
len_samps = len(samples)
len_prior_samps = len(prior_samples)

CPU times: user 27.2 ms, sys: 0 ns, total: 27.2 ms
Wall time: 223 µs


In [14]:
%%time
overwrite_DRP_storage = 0
storage_dir = os.path.join(storage_location, storage_identifier)
# exist_ok avoids the exists()/mkdir() race; makedirs also creates missing parent directories.
os.makedirs(storage_dir, exist_ok=True)
for rnd in range(which_truncation+1):
    round = 'round_'+str(rnd)
    filename = os.path.join(storage_dir, 'DRP_samples_'+str(len_samps)+'_'+str(len_prior_samps)+'_'+round)
    if not os.path.exists(filename) or overwrite_DRP_storage:
        draws1d, draws2d, weights1d, weights2d = draw_DRP_samples_fast(
            priors[name]['net'][round],
            trainer,
            samples,
            prior_samples,
            batch_size = 1024*4,
        )
        DRP_coverage_samples = {
            'draws1d': draws1d,
            'draws2d': draws2d,
            'weights1d': weights1d,
            'weights2d': weights2d,
        }
        # Dump to a temporary file and rename it into place, so an interrupted dump never
        # leaves a truncated cache file that later runs would mistake for a finished one.
        tmp_filename = filename + '.tmp'
        with open(tmp_filename,'wb') as file:
            pickle.dump(DRP_coverage_samples,file)
        os.replace(tmp_filename, filename)
        del DRP_coverage_samples, draws1d, draws2d, weights1d, weights2d

Predicting: 0it [00:00, ?it/s]

TypeError: 'NoneType' object is not subscriptable

In [None]:
%%time
overwrite_HPD_storage = 0
storage_dir = os.path.join(storage_location, storage_identifier)
# exist_ok avoids the exists()/mkdir() race; makedirs also creates missing parent directories.
os.makedirs(storage_dir, exist_ok=True)
for rnd in range(which_truncation+1):
    round = 'round_'+str(rnd)
    filename = os.path.join(storage_dir, 'HPD_samples_'+str(len_samps)+'_'+str(len_prior_samps)+'_'+round)
    if not os.path.exists(filename) or overwrite_HPD_storage:
        HPD_coverage_samples = trainer.test_coverage(priors[name]['net'][round], samples, prior_samples)
        # Write-then-rename so an interrupted dump cannot leave a truncated cache file behind.
        tmp_filename = filename + '.tmp'
        with open(tmp_filename,'wb') as file:
            pickle.dump(HPD_coverage_samples,file)
        os.replace(tmp_filename, filename)
        del HPD_coverage_samples



Predicting: 0it [00:00, ?it/s]

In [None]:
%%time
# n_refs reference sets per dimensionality; all 1D sets are drawn before the 2D sets.
references_1d = []
for _ref_index in range(n_refs):
    # the 1D tests only keep the first column of the reference points
    references_1d.append(references2D(samples)[0][:, [0]])

references_2d = []
for _ref_index in range(n_refs):
    references_2d.append(references2D(samples)[0])

In [None]:
%%time

overwrite_DRP_draws = 0


def _per_reference_dicts():
    """Fresh [1D, 2D] pair of lists holding one empty dict per reference set."""
    return [[{} for _ in references_1d], [{} for _ in references_2d]]


round_keys = ['round_'+str(rnd) for rnd in range(which_truncation+1)]
# Layout of every result container: result[round][dims][ref_i][key], dims 0 = 1D, 1 = 2D marginals.
ecp_pp = {rk: _per_reference_dicts() for rk in round_keys}
alpha_pp = {rk: _per_reference_dicts() for rk in round_keys}
ecp_zz = {rk: _per_reference_dicts() for rk in round_keys}
alpha_zz = {rk: _per_reference_dicts() for rk in round_keys}
f_pp = {rk: _per_reference_dicts() for rk in round_keys}
f_zz = {rk: _per_reference_dicts() for rk in round_keys}
# validation_sums[round][dims][key]: sum of f_score/n_refs over the reference sets.
validation_sums = {rk: [{},{}] for rk in round_keys}
rows = len(POIs)

len_samps = len(samples)
len_prior_samps = len(prior_samples)
# random.sample raises ValueError if asked for more items than exist; the stores may hold
# fewer samples than requested.
n_draw = min(n_samps, len_samps)
n_prior_draw = min(n_prior_samps, len_prior_samps)
storage_dir = storage_location+'/'+storage_identifier

for rnd in range(which_truncation+1):
    round = 'round_'+str(rnd)
    filename = storage_dir+'/DRP_samples_'+str(len_samps)+'_'+str(len_prior_samps)+'_'+round
    with open(filename,'rb') as file: coverage_samples = pickle.load(file)
    keys_1d = list(coverage_samples['draws1d'].keys())
    keys_2d = list(coverage_samples['draws2d'].keys())
    filename_draws = storage_dir+'/DRP_draws_'+str(len_samps)+'_'+str(len_prior_samps)+'_'+round

    if os.path.exists(filename_draws) and not overwrite_DRP_draws:
        with open(filename_draws,'rb') as file: DRP_draws = pickle.load(file)
        ecp_pp[round] = DRP_draws['ecp_pp']
        ecp_zz[round] = DRP_draws['ecp_zz']
        alpha_pp[round] = DRP_draws['alpha_pp']
        alpha_zz[round] = DRP_draws['alpha_zz']
        f_pp[round] = DRP_draws['f_pp']
        f_zz[round] = DRP_draws['f_zz']
        validation_sums[round] = DRP_draws['f_score']
    else:
        # 1D marginals: assumes keys_1d is ordered like POIs.
        for i, key in enumerate(keys_1d):
            draws = coverage_samples['draws1d'][key]
            samps = samples['params'][:,[POIs[i]]]
            weights = coverage_samples['weights1d'][key]

            validation_sums[round][0][key] = 0
            for ref_i in range(len(references_1d)):
                # sampling from a range avoids materialising a huge index list on every draw
                random_indices = random.sample(range(len_samps), n_draw)
                random_prior_indices = random.sample(range(len_prior_samps), n_prior_draw)

                (ecp_pp[round][0][ref_i][key], alpha_pp[round][0][ref_i][key],
                 ecp_zz[round][0][ref_i][key], alpha_zz[round][0][ref_i][key],
                 f_pp[round][0][ref_i][key], f_zz[round][0][ref_i][key],
                 f_score, _) = get_drp_coverage_torch(
                    draws[random_prior_indices][:,random_indices,:],
                    samps[random_indices],
                    weights = weights[random_prior_indices][:,random_indices],
                    theta_names=A.param_names[POIs[i]],
                    bounds = np.array(bounds)[[POIs[i]]],
                    references = references_1d[ref_i][random_indices],
                    device='cuda',
                    intermediate_figures=False,
                )
                validation_sums[round][0][key] += f_score/n_refs

        # 2D marginals: panel pairs (column, row) with column < row, visited (0,1), (0,2), ...,
        # (1,2), ...; assumes keys_2d follows that order. Panel indices are positions in POIs,
        # mapped to parameter indices exactly as in the 1D loop.
        row = 0
        column = 0
        for i, key in enumerate(keys_2d):
            row+=1
            if row >= rows:
                column+=1
                row = 1+column
            param_pair = [POIs[column], POIs[row]]

            draws = coverage_samples['draws2d'][key]
            samps = samples['params'][:,param_pair]
            weights = coverage_samples['weights2d'][key]

            validation_sums[round][1][key] = 0
            for ref_i in range(len(references_2d)):
                random_indices = random.sample(range(len_samps), n_draw)
                random_prior_indices = random.sample(range(len_prior_samps), n_prior_draw)

                # every output goes into the 2D slot [1] (f_zz previously landed in the 1D slot)
                (ecp_pp[round][1][ref_i][key], alpha_pp[round][1][ref_i][key],
                 ecp_zz[round][1][ref_i][key], alpha_zz[round][1][ref_i][key],
                 f_pp[round][1][ref_i][key], f_zz[round][1][ref_i][key],
                 f_score, _) = get_drp_coverage_torch(
                    draws[random_prior_indices][:,random_indices,:],
                    samps[random_indices],
                    weights = weights[random_prior_indices][:,random_indices],
                    theta_names=np.array(A.param_names)[param_pair],
                    bounds = np.array(bounds)[param_pair],
                    references = references_2d[ref_i][random_indices],
                    device='cuda',
                    intermediate_figures=False,
                )
                validation_sums[round][1][key] += f_score/n_refs

        DRP_draws = {
            'ecp_pp': ecp_pp[round],
            'ecp_zz': ecp_zz[round],
            'alpha_pp': alpha_pp[round],
            'alpha_zz': alpha_zz[round],
            'f_pp': f_pp[round],
            'f_zz': f_zz[round],
            'f_score': validation_sums[round],
        }
        with open(filename_draws,'wb') as file: pickle.dump(DRP_draws,file)

# Averaged f_score per marginal, one line per truncation round.
for i, key in enumerate(keys_1d):
    print(key)
    for rnd in range(which_truncation+1):
        print(validation_sums['round_'+str(rnd)][0][key])
    print()

for i, key in enumerate(keys_2d):
    print(key)
    for rnd in range(which_truncation+1):
        print(validation_sums['round_'+str(rnd)][1][key])
    print()



In [None]:
%%time

# CONFIGURATION
label_size = 15
label_pad = 0
tick_size = 10
x_tick_pad = 1
y_tick_pad = -1
x_tick_rotation = 45
y_tick_rotation = 45
max_y_ticks = 4
max_z = 3                  # Z-Z curves are cut at this Z value
significance1 = 5/n_samps
significance2 = 1/n_samps
blend_amount = 0.5
opacity = 0.25
HPD_residuals = True       # passed as `residuals=` to the swyft HPD coverage plots
legend_loc = (0,2.1)

automatic_rel_ticks = True # let MaxNLocator pick the y ticks of the residual / HPD panels

# LaTeX labels are raw strings: sequences like "\m" are invalid Python escapes
# (SyntaxWarning since Python 3.12); the resulting string values are unchanged.
x_axis_pp = r"$1-\alpha$"
y_axis_pp = 'EC'
ticks_pp = [0.5,0.68,0.95]

x_axis_zz = r"$Z_{1-\alpha}$"
y_axis_zz = r"$Z_\mathrm{EC}$"
ticks_zz = [1,2,3]

x_axis_rel = r"$1-\alpha$"
y_axis_rel = r"$\mathrm{EC}$ residuals"
x_ticks_rel = ticks_pp
y_ticks_rel = [-0.1,0,0.1]   # symmetric around zero, like y_ticks_rel_zz

x_axis_rel_zz = r"$Z_{1-\alpha}$"
y_axis_rel_zz = r"$Z_\mathrm{EC}$ residuals"
x_ticks_rel_zz = ticks_zz
y_ticks_rel_zz = [-1,0,1]

POI_names = ['$m_a$', r'$g_{a \gamma}$', 'Spectral Amplitude', 'Spectral Index', 'Cut-off Energy']

# Per-figure settings, in fig_list order: pp, zz, rel, rel_zz, HPD_pp, HPD_zz.
x_axis_list = [x_axis_pp,x_axis_zz,x_axis_rel,x_axis_rel_zz,x_axis_pp,x_axis_zz]
x_tick_list = [ticks_pp,ticks_zz,ticks_pp,ticks_zz,ticks_pp,ticks_zz]

# The HPD figures show residuals exactly when HPD_residuals is set, so they take the
# residual labels/ticks then and the plain P-P / Z-Z ones otherwise.
if HPD_residuals:
    y_axis_list = [y_axis_pp,y_axis_zz,y_axis_rel,y_axis_rel_zz,y_axis_rel,y_axis_rel_zz]
    y_tick_list = [ticks_pp,ticks_zz,y_ticks_rel,y_ticks_rel_zz,y_ticks_rel,y_ticks_rel_zz]
else:
    y_axis_list = [y_axis_pp,y_axis_zz,y_axis_rel,y_axis_rel_zz,y_axis_pp,y_axis_zz]
    y_tick_list = [ticks_pp,ticks_zz,y_ticks_rel,y_ticks_rel_zz,ticks_pp,ticks_zz]

# Earlier rounds are lightened by blending with white; the final round keeps its full colour.
adjusted_colors = [blend(to_rgb(col),(1,1,1),amount=blend_amount) for col in round_colors[:-1]]
adjusted_colors.append(round_colors[-1])

rows = len(POIs)


# Six 12x12-inch corner figures, created in fig_list order.
(DRP_fig_pp, DRP_fig_zz, DRP_fig_rel, DRP_fig_rel_zz,
 DRP_fig_HPD_pp, DRP_fig_HPD_zz) = fig_list = [plt.figure(figsize=(12, 12)) for _ in range(6)]

# Residual panels get wider horizontal spacing; the HPD figures follow HPD_residuals.
hpd_wspace = 0.25 if HPD_residuals else 0.1
for figure, wspace in zip(fig_list, (0.1, 0.1, 0.25, 0.25, hpd_wspace, hpd_wspace)):
    figure.subplots_adjust(hspace=0.1, wspace=wspace)

#LOADING HPD COVERAGE SAMPLES
# Load the swyft HPD coverage results of every round into their own dict. The name
# coverage_samples previously held the last round's DRP draws/weights; starting fresh keeps
# the two kinds of data apart and lets those large arrays be released.
coverage_samples = {}
for rnd in range(which_truncation+1):
    round = 'round_'+str(rnd)
    filename = storage_location+'/'+storage_identifier+'/HPD_samples_'+str(len_samps)+'_'+str(len_prior_samps)+'_'+round
    with open(filename,'rb') as file:
        coverage_samples[round] = pickle.load(file)

#ITERATION OVER FIGURES AND SUBFIGURES
# Corner layout: panel (row, column) with row >= column. Diagonal panels show the 1D coverage of
# keys_1d, off-diagonal panels the 2D coverage of keys_2d; both lists are consumed in the order
# the panels are visited (down each column, then on to the next column). The (row=1, column=0)
# panel is drawn twice: in the corner and again on an extra axis in region gs[:-3, 3:] of a 5x5 grid.
extra_title = 'Combined coverage for ('+POI_names[0]+','+POI_names[1]+')'
extra_title_size = 15
extra_legend_size = 10
extra_legend_loc = legend_loc

row = -1
column = 0
index_1d=-1
index_2d=-1
make_extra_plot = False
gs = gridspec.GridSpec(5,5)
i = -1
while i < len(keys_1d)+len(keys_2d)-1:
    i+=1
    if not make_extra_plot:
        # advance down the current column, wrapping to the next diagonal panel
        row+=1
        if row == rows:
            column+=1
            row = column
        if row == column:
            dims = 0
            index_1d += 1
            key = keys_1d[index_1d]
        else:
            dims = 1
            index_2d += 1
            key = keys_2d[index_2d]

        for fig in fig_list: fig.add_subplot(rows, rows, rows*row+column+1)

    else:
        # second pass over the same (row, column, key) on the extra axis
        for fig in fig_list: fig.add_subplot(gs[:-3,3:])

    print(str(i)+'/'+str(len(keys_1d)+len(keys_2d))+' ('+str(row)+','+str(column)+')',flush=True,end='\r')

    # PLOTTING OF COVERAGES

    for rnd in range(which_truncation+1):
        round = 'round_'+str(rnd)
        for ref_i in range(len(ecp_pp[round][dims])):
            # prepend a 0 so every curve starts at the origin
            ecp_ex_pp = np.zeros(len(ecp_pp[round][dims][ref_i][key])+1)
            alpha_ex_pp = np.zeros(len(alpha_pp[round][dims][ref_i][key])+1)
            ecp_ex_zz = np.zeros(len(ecp_zz[round][dims][ref_i][key])+1)
            alpha_ex_zz_orig = np.zeros(len(alpha_zz[round][dims][ref_i][key])+1)
            ecp_ex_pp[1:] = ecp_pp[round][dims][ref_i][key]
            alpha_ex_pp[1:] = alpha_pp[round][dims][ref_i][key]
            ecp_ex_zz[1:] = ecp_zz[round][dims][ref_i][key]
            alpha_ex_zz_orig[1:] = alpha_zz[round][dims][ref_i][key]
            # Z-Z curves are cut at max_z
            alpha_ex_zz = alpha_ex_zz_orig[alpha_ex_zz_orig<=max_z]
            ecp_ex_zz = ecp_ex_zz[alpha_ex_zz_orig<=max_z]
            if rnd < which_truncation:
                # earlier rounds: shaded band between the curve and the diagonal / zero line
                label = 'Range after truncation #' + str(rnd) if ref_i == 0 else None
                DRP_fig_pp.axes[-1].fill_between(alpha_ex_pp,ecp_ex_pp,alpha_ex_pp, color=adjusted_colors[rnd],label=label)
                DRP_fig_zz.axes[-1].fill_between(alpha_ex_zz,ecp_ex_zz,alpha_ex_zz, color=adjusted_colors[rnd],label=label)
                DRP_fig_rel.axes[-1].fill_between(alpha_ex_pp,(ecp_ex_pp-alpha_ex_pp),np.zeros(len(alpha_ex_pp)), color=adjusted_colors[rnd],label=label)
                DRP_fig_rel_zz.axes[-1].fill_between(alpha_ex_zz,(ecp_ex_zz-alpha_ex_zz),np.zeros(len(alpha_ex_zz)), color=adjusted_colors[rnd],label=label)
            else:
                # final round: one translucent line per reference set
                label = 'Final coverages' if ref_i == 0 else None
                DRP_fig_pp.axes[-1].plot(alpha_ex_pp, ecp_ex_pp, round_colors[rnd],alpha=opacity,label=label)
                DRP_fig_zz.axes[-1].plot(alpha_ex_zz, ecp_ex_zz, round_colors[rnd],alpha=opacity,label=label)
                DRP_fig_rel.axes[-1].plot(alpha_ex_pp, (ecp_ex_pp-alpha_ex_pp), round_colors[rnd],alpha=opacity,label=label)
                DRP_fig_rel_zz.axes[-1].plot(alpha_ex_zz, (ecp_ex_zz-alpha_ex_zz), round_colors[rnd],alpha=opacity,label=label)
        label = 'Coverage with error after truncation #' + str(rnd)
        # 1D keys go to swyft unchanged; 2D keys are strings whose space-separated entries are
        # evaluated into a tuple. NOTE(review): eval on pickled keys -- trusted local data only.
        hpd_key = key if dims == 0 else eval(key.replace(' ',','))
        # residuals follow HPD_residuals on every panel, so diagonal and off-diagonal panels agree
        swyft.plot_pp(coverage_samples[round], hpd_key,ax = DRP_fig_HPD_pp.axes[-1],color=adjusted_colors[rnd],interval_opacity=0.25,label=label,interval_label=None,x_label=None,y_label=None,diagonal_color=None,residuals=HPD_residuals)
        swyft.plot_zz(coverage_samples[round], hpd_key,ax = DRP_fig_HPD_zz.axes[-1],color=adjusted_colors[rnd],interval_opacity=0.25,sigma_color=None,label=label,interval_label=None,x_label=None,y_label=None,diagonal_color=None,residuals=HPD_residuals,z_max=max_z)

    #CONFIGURATION OF INDIVIDUAL AXES

    # ticks everywhere, tick labels only on the outer panels (set further down)
    for fig_i, fig in enumerate(fig_list):
        fig.axes[-1].set_xticks(x_tick_list[fig_i])
        fig.axes[-1].set_xticklabels([])
        fig.axes[-1].set_yticks(y_tick_list[fig_i])
        fig.axes[-1].set_yticklabels([])

    if automatic_rel_ticks:
        for fig in (DRP_fig_rel, DRP_fig_rel_zz, DRP_fig_HPD_pp, DRP_fig_HPD_zz):
            fig.axes[-1].yaxis.set_major_locator(MaxNLocator(nbins=max_y_ticks))
            fig.axes[-1].yaxis.set_major_formatter(plt.ScalarFormatter(None))

        DRP_fig_pp.axes[-1].set_yticks(ticks_pp)
        DRP_fig_zz.axes[-1].set_yticks(ticks_zz)

    if row==column:
        for fig in fig_list: fig.axes[-1].set_title(POI_names[row])

    if row==rows-1:
        for fig_i, fig in enumerate(fig_list):
            fig.axes[-1].tick_params(axis='x',labelsize=tick_size,pad=x_tick_pad,rotation=x_tick_rotation)
            fig.axes[-1].set_xticklabels(x_tick_list[fig_i])
            fig.axes[-1].set_xlabel(x_axis_list[fig_i], fontsize=label_size, labelpad=label_pad)

    if column == 0:
        DRP_fig_pp.axes[-1].set_yticklabels(ticks_pp)
        DRP_fig_zz.axes[-1].set_yticklabels(ticks_zz)
        for fig_i, fig in enumerate(fig_list):
            fig.axes[-1].set_ylabel(y_axis_list[fig_i], fontsize=label_size, labelpad=label_pad)
            if not automatic_rel_ticks: fig.axes[-1].set_yticklabels(y_tick_list[fig_i])

    for fig in fig_list:
        fig.axes[-1].tick_params(axis='y',labelsize=tick_size,rotation=y_tick_rotation, pad=y_tick_pad)

    # Jeffreys intervals for n_samps draws at each nominal level (last round / last reference
    # set), at two significance levels. Column 0 is taken as the lower and column 1 as the upper
    # edge, as in the Z-Z branch (p_to_z is monotonic).
    uncertainty1 = interval((alpha_ex_pp*n_samps).astype(int),n_samps,alpha = significance1)
    lower_uncertainty1 = uncertainty1[:,0]
    upper_uncertainty1 = uncertainty1[:,1]
    uncertainty2 = interval((alpha_ex_pp*n_samps).astype(int),n_samps,alpha = significance2)
    lower_uncertainty2 = uncertainty2[:,0]
    upper_uncertainty2 = uncertainty2[:,1]
    uncertainty1_zz = interval((z_to_p(alpha_ex_zz)*n_samps).astype(int),n_samps,alpha=significance1)
    upper_uncertainty1_zz = p_to_z(uncertainty1_zz[:,1])
    lower_uncertainty1_zz = p_to_z(uncertainty1_zz[:,0])
    uncertainty2_zz = interval((z_to_p(alpha_ex_zz)*n_samps).astype(int),n_samps,alpha=significance2)
    upper_uncertainty2_zz = p_to_z(uncertainty2_zz[:,1])
    lower_uncertainty2_zz = p_to_z(uncertainty2_zz[:,0])

    DRP_fig_pp.axes[-1].plot(alpha_ex_pp, alpha_ex_pp,'k-',label='$ECP=1-\\alpha$')
    DRP_fig_pp.axes[-1].plot(alpha_ex_pp, upper_uncertainty1,'k--', label='Significance = {:1g}'.format(significance1))
    DRP_fig_pp.axes[-1].plot(alpha_ex_pp, lower_uncertainty1,'k--')
    DRP_fig_pp.axes[-1].plot(alpha_ex_pp, upper_uncertainty2,'k:', label='Significance = {:1g}'.format(significance2))
    DRP_fig_pp.axes[-1].plot(alpha_ex_pp, lower_uncertainty2,'k:')

    DRP_fig_rel.axes[-1].plot(alpha_ex_pp, (alpha_ex_pp-alpha_ex_pp), 'k-',label='$ECP=1-\\alpha$')
    DRP_fig_rel.axes[-1].plot(alpha_ex_pp, (upper_uncertainty1-alpha_ex_pp), 'k--',label='Significance = {:1g}'.format(significance1))
    DRP_fig_rel.axes[-1].plot(alpha_ex_pp, (lower_uncertainty1-alpha_ex_pp), 'k--')
    DRP_fig_rel.axes[-1].plot(alpha_ex_pp, (upper_uncertainty2-alpha_ex_pp), 'k:',label='Significance = {:1g}'.format(significance2))
    DRP_fig_rel.axes[-1].plot(alpha_ex_pp, (lower_uncertainty2-alpha_ex_pp), 'k:')

    DRP_fig_zz.axes[-1].plot(alpha_ex_zz, alpha_ex_zz, 'k-',label='$ECP=1-\\alpha$')
    DRP_fig_zz.axes[-1].plot(alpha_ex_zz, upper_uncertainty1_zz,'k--',label='Significance = {:1g}'.format(significance1))
    DRP_fig_zz.axes[-1].plot(alpha_ex_zz, lower_uncertainty1_zz,'k--')
    DRP_fig_zz.axes[-1].plot(alpha_ex_zz, upper_uncertainty2_zz,'k:',label='Significance = {:1g}'.format(significance2))
    DRP_fig_zz.axes[-1].plot(alpha_ex_zz, lower_uncertainty2_zz,'k:')

    DRP_fig_rel_zz.axes[-1].plot(alpha_ex_zz, (alpha_ex_zz-alpha_ex_zz), 'k-',label='$ECP=1-\\alpha$')
    DRP_fig_rel_zz.axes[-1].plot(alpha_ex_zz, (upper_uncertainty1_zz-alpha_ex_zz), 'k--',label='Significance = {:1g}'.format(significance1))
    DRP_fig_rel_zz.axes[-1].plot(alpha_ex_zz, (lower_uncertainty1_zz-alpha_ex_zz), 'k--')
    DRP_fig_rel_zz.axes[-1].plot(alpha_ex_zz, (upper_uncertainty2_zz-alpha_ex_zz), 'k:',label='Significance = {:1g}'.format(significance2))
    DRP_fig_rel_zz.axes[-1].plot(alpha_ex_zz, (lower_uncertainty2_zz-alpha_ex_zz), 'k:')

    if HPD_residuals:
        DRP_fig_HPD_pp.axes[-1].plot(alpha_ex_pp, (alpha_ex_pp-alpha_ex_pp), 'k-',label='$ECP=1-\\alpha$')
        DRP_fig_HPD_zz.axes[-1].plot(alpha_ex_zz, (alpha_ex_zz-alpha_ex_zz), 'k-',label='$ECP=1-\\alpha$')
        DRP_fig_HPD_zz.axes[-1].set_ylim([-0.5*max_z,0.5*max_z])
    else:
        DRP_fig_HPD_pp.axes[-1].plot(alpha_ex_pp, alpha_ex_pp, 'k-',label='$ECP=1-\\alpha$')
        DRP_fig_HPD_zz.axes[-1].plot(alpha_ex_zz, alpha_ex_zz, 'k-',label='$ECP=1-\\alpha$')
        DRP_fig_HPD_zz.axes[-1].set_ylim([0,max_z])

    # DEALING WITH EXTRA AXIS OUTSIDE CORNER PLOT

    if row==1 and column==0:
        if not make_extra_plot:
            # repeat this panel once more on the extra axis
            i -= 1
            make_extra_plot=True
        else:
            make_extra_plot=False

            for fig_i, fig in enumerate(fig_list):
                fig.axes[-1].set_xticks(x_tick_list[fig_i])
                fig.axes[-1].set_xticklabels(x_tick_list[fig_i])
                fig.axes[-1].tick_params(axis='x',labelsize=tick_size,pad=x_tick_pad)
                fig.axes[-1].tick_params(axis='y',labelsize=tick_size,pad=y_tick_pad)
                fig.axes[-1].set_xlabel(x_axis_list[fig_i], fontsize=label_size, labelpad=label_pad)
                fig.axes[-1].set_ylabel(y_axis_list[fig_i], fontsize=label_size, labelpad=label_pad)
                fig.axes[-1].set_title(extra_title, fontsize=extra_title_size)

            DRP_fig_pp.axes[-1].set_yticklabels(ticks_pp)
            DRP_fig_zz.axes[-1].set_yticklabels(ticks_zz)
            if not HPD_residuals:
                DRP_fig_HPD_pp.axes[-1].set_yticklabels(ticks_pp)
                DRP_fig_HPD_zz.axes[-1].set_yticklabels(ticks_zz)

    elif row==1 and column==1:
        for fig in fig_list:
            fig.axes[-1].legend(prop={'size': extra_legend_size}, loc = "upper left", bbox_to_anchor=extra_legend_loc)
  

In [None]:
# Scratch figure for trying out label, tick and legend styling next to a swyft P-P plot.
testfig = plt.figure(figsize = (12, 6))
testfig.add_subplot(1,2,1)
testfig.axes[0].plot([0,1],[0,1], color = (1,0.5,0), linestyle='--')
testfig.axes[0].fill_between([0,1],[0,1],[0,0.5], label = 'yoop')
testfig.axes[0].fill_between([0,1],[0,-1])
testfig.axes[0].set_xlabel("Empirical $1-\\alpha$", fontsize=10,labelpad=0)
testfig.axes[0].set_xlabel(None)
# raw string: "\m", "\;" and "\D" are invalid Python escapes (SyntaxWarning since 3.12)
testfig.axes[0].set_ylabel(r"$\mathrm{ECP}\;/\;\Delta \mathrm{ECP}$", fontsize=10, labelpad=0)
testfig.axes[0].tick_params(axis='both',labelsize=10,pad=1)
testfig.axes[0].tick_params(axis='y',rotation=45, pad=-1)
testfig.axes[0].legend(loc='upper left', bbox_to_anchor=(0,1.1))
testfig.add_subplot(1,2,2)
for rnd in range(which_truncation+1):
    round = 'round_'+str(rnd)
    swyft.plot_pp(coverage_samples[round], keys_1d[0],ax = testfig.axes[1],diagonal_color='g',color=adjusted_colors[rnd])
testfig.tight_layout(w_pad = -1)
testfig.axes[0].set_yticks([0,0.2,0.4])
testfig.axes[0].set_yticklabels([0,0.2,0.4])
testfig.axes[0].yaxis.set_major_locator(plt.AutoLocator())
testfig.axes[0].yaxis.set_major_formatter(plt.ScalarFormatter(None))

In [None]:
single_title = extra_title
single_title_size = 30
single_label_size = 25
single_tick_size = 20
single_x_tick_pad = 5
single_y_tick_pad = 1
legend_size=15

# Keep only axes[1] of the P-P figure (the second panel added, i.e. row 1 / column 0) visible
# and enlarge the figure so that panel reads on its own.
focus_ax = DRP_fig_pp.axes[1]
for ax in DRP_fig_pp.axes:
    ax.set_visible(ax is focus_ax)
focus_ax.set_xticks([0.5,0.682,0.954])
focus_ax.set_xticklabels([0.5,0.68,0.95])
for tick_axis, tick_pad in (('x', single_x_tick_pad), ('y', single_y_tick_pad)):
    focus_ax.tick_params(axis=tick_axis, labelsize=single_tick_size, pad=tick_pad)
focus_ax.set_xlabel(x_axis_pp, fontsize=single_label_size, labelpad=label_pad)
focus_ax.set_ylabel(y_axis_pp, fontsize=single_label_size, labelpad=label_pad)
focus_ax.set_title(single_title, fontsize=single_title_size)
focus_ax.legend(prop={'size': legend_size})
DRP_fig_pp.set_figheight(12*5)
DRP_fig_pp.set_figwidth(12*5)
DRP_fig_pp

In [None]:
# Same single-panel view for the Z-Z figure, with Z ticks at 1, 2 and 3 on both axes.
focus_ax = DRP_fig_zz.axes[1]
for ax in DRP_fig_zz.axes:
    ax.set_visible(ax is focus_ax)
z_levels = [1,2,3]
focus_ax.set_xticks(z_levels)
focus_ax.set_xticklabels(z_levels)
focus_ax.set_yticks(z_levels)
focus_ax.set_yticklabels(z_levels)
for tick_axis, tick_pad in (('x', single_x_tick_pad), ('y', single_y_tick_pad)):
    focus_ax.tick_params(axis=tick_axis, labelsize=single_tick_size, pad=tick_pad)
focus_ax.set_xlabel(x_axis_zz, fontsize=single_label_size, labelpad=label_pad)
focus_ax.set_ylabel(y_axis_zz, fontsize=single_label_size, labelpad=label_pad)
focus_ax.set_title(single_title, fontsize=single_title_size)
focus_ax.legend(prop={'size': legend_size})
DRP_fig_zz.set_figheight(12*5)
DRP_fig_zz.set_figwidth(12*5)
DRP_fig_zz

In [None]:
# Single-panel view of the P-P residual figure.
focus_ax = DRP_fig_rel.axes[1]
for ax in DRP_fig_rel.axes:
    ax.set_visible(ax is focus_ax)
for tick_axis, tick_pad in (('x', single_x_tick_pad), ('y', single_y_tick_pad)):
    focus_ax.tick_params(axis=tick_axis, labelsize=single_tick_size, pad=tick_pad)
focus_ax.set_xlabel(x_axis_rel, fontsize=single_label_size, labelpad=label_pad)
focus_ax.set_ylabel(y_axis_rel, fontsize=single_label_size, labelpad=label_pad)
focus_ax.set_title(single_title, fontsize=single_title_size)
focus_ax.legend(prop={'size': legend_size})
DRP_fig_rel.set_figheight(12*5)
DRP_fig_rel.set_figwidth(12*5)
DRP_fig_rel

In [None]:
# Single-panel view of the Z-Z residual figure.
focus_ax = DRP_fig_rel_zz.axes[1]
for ax in DRP_fig_rel_zz.axes:
    ax.set_visible(ax is focus_ax)
for tick_axis, tick_pad in (('x', single_x_tick_pad), ('y', single_y_tick_pad)):
    focus_ax.tick_params(axis=tick_axis, labelsize=single_tick_size, pad=tick_pad)
focus_ax.set_xlabel(x_axis_rel_zz, fontsize=single_label_size, labelpad=label_pad)
focus_ax.set_ylabel(y_axis_rel_zz, fontsize=single_label_size, labelpad=label_pad)
focus_ax.set_title(single_title, fontsize=single_title_size)
focus_ax.legend(prop={'size': legend_size})
DRP_fig_rel_zz.set_figheight(12*5)
DRP_fig_rel_zz.set_figwidth(12*5)
DRP_fig_rel_zz

In [None]:
# Single-panel view of the swyft HPD Z-Z figure, with a fixed symmetric y range.
focus_ax = DRP_fig_HPD_zz.axes[1]
for ax in DRP_fig_HPD_zz.axes:
    ax.set_visible(ax is focus_ax)
for tick_axis, tick_pad in (('x', single_x_tick_pad), ('y', single_y_tick_pad)):
    focus_ax.tick_params(axis=tick_axis, labelsize=single_tick_size, pad=tick_pad)
focus_ax.set_xlabel(x_axis_rel_zz, fontsize=single_label_size, labelpad=label_pad)
focus_ax.set_ylabel(y_axis_rel_zz, fontsize=single_label_size, labelpad=label_pad)
focus_ax.set_title(single_title, fontsize=single_title_size)
focus_ax.legend(prop={'size': legend_size})
DRP_fig_HPD_zz.set_figheight(12*5)
DRP_fig_HPD_zz.set_figwidth(12*5)
focus_ax.set_ylim([-0.75,0.75])
DRP_fig_HPD_zz