In [1]:
import os
import sys
import numpy as np
import swyft
import pickle
import matplotlib.pyplot as plt
import torch
import importlib
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import LearningRateMonitor
torch.set_float32_matmul_precision('medium')
device_notebook = "cuda" if torch.cuda.is_available() else "cpu"
import wandb
import copy
from torch.multiprocessing import Pool
torch.multiprocessing.set_start_method('spawn',force=True)
torch.set_num_threads(28)
import itertools
import subprocess

In [2]:
main_dir = "ALPs_with_SWYFT"
thesis_figs = os.getcwd().split(main_dir)[0]+"/"+main_dir+"/thesis_figures/"

In [3]:
def _get_HDI_thresholds(x, cred_level=[0.68268, 0.95450, 0.99730]):
    x = x.flatten()
    x = np.sort(x)[::-1]  # Sort backwards
    total_mass = x.sum()
    enclosed_mass = np.cumsum(x)
    idx = [np.argmax(enclosed_mass >= total_mass * f) for f in cred_level]
    levels = np.array(x[idx])
    return levels
    
def generate_expected_limits(samples,
                             prior_samples,
                             bounds,
                             net = None,  
                             trainer = None,
                             contour_matrix = None,
                             predictions = None,
                             ax=None,
                             limit_credibility=0.9973,
                             levels = [0.003,0.05,0.34,0.682,0.95,0.9973,1],
                             fill=True,
                             bins=50,
                             batch_size = 1024,
                             param_names = ['m','g'],
                             colors = ['r','#FFA500','y','g','b','k'],
                             alpha = 0.5,
                             alpha_variable = False,
                            ):
    
    if isinstance(samples,int):
        n_limits = samples
    else:
        n_limits = len(samples)
        
    if isinstance(prior_samples, int):
        n_prior_samples = prior_samples
    else:
        n_prior_samples = len(prior_samples)

    
    if not np.any(predictions) and not np.any(contour_matrix):
        repeat = n_prior_samples // batch_size + (n_prior_samples % batch_size > 0)
        
        predictions = trainer.infer(
            net,
            samples.get_dataloader(batch_size=1,repeat=repeat),
            prior_samples.get_dataloader(batch_size=batch_size)
        )
        
    if not np.any(contour_matrix):
        for i in range(n_limits):
    
            predictions_i = copy.deepcopy(predictions)
    
            predictions_i[0].logratios = predictions[0].logratios[i*n_prior_samples:(i+1)*n_prior_samples]
            predictions_i[0].params = predictions[0].params[i*n_prior_samples:(i+1)*n_prior_samples]
            predictions_i[1].logratios = predictions[1].logratios[i*n_prior_samples:(i+1)*n_prior_samples]
            predictions_i[1].params = predictions[1].params[i*n_prior_samples:(i+1)*n_prior_samples]
    
            counts, _ = swyft.get_pdf(
                predictions_i,
                param_names,
                bins = bins,
            )
    
            if i==0:
                X,Y = np.meshgrid(np.linspace(0,counts.shape[0]-1,counts.shape[0]),np.linspace(0,counts.shape[1]-1,counts.shape[1]))
                matrix_total = np.zeros(X.shape)
    
            plt.figure('dummy')
            levels_limits=sorted(_get_HDI_thresholds(counts,cred_level=[0,limit_credibility]))
            limit_contour = plt.contourf(counts.T,levels=levels_limits)
            # plt.clf()
    
            matrix_i = np.ones(X.shape)
    
            for collection in limit_contour.collections:
                for path in collection.get_paths():
                    mask = path.contains_points(np.vstack((X.flatten(), Y.flatten())).T,radius=1e-9)
                    mask = mask.reshape(X.shape)
                    matrix_i[mask] = 0
            
            matrix_total += matrix_i
    
    else:
        matrix_total = contour_matrix

    if not ax:
        fig = plt.figure()
        fig.add_subplot(1,1,1)
        ax = fig.axes[0]
    
    for li in range(len(levels)-1):
        ax.contourf(matrix_total,
                    levels=[levels[li]*n_limits,levels[li+1]*n_limits],
                    extent=[bounds[0][0], bounds[0][1], bounds[1][0], bounds[1][1]],
                    colors = colors[li],
                    alpha = (li+2)*alpha/len(levels) if alpha_variable else alpha,
                   )

    plt.close('dummy')
    
    if not ax:
        return matrix_total, fig
    else:
        return matrix_total,ax

In [68]:
def convert_pair_to_index(pair,n_indices):
    pair = sorted(pair)
    return (pair[0]+1)*(n_indices-1+n_indices-pair[0]-1)/2 - n_indices + pair[1]

In [69]:
convert_pair_to_index([(4,2),(3,4)],5)

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

In [46]:
4*(5-1+5-3-1)/2 - 5 + 4

9.0

In [None]:
4 + 3 + 2 + 1

In [74]:
priors['agnostic3']['predictions'][1].logratios[:,[2,4,7]].shape

torch.Size([100000, 3])

In [79]:
priors['agnostic3']['predictions'][1].params[:,[2,4,7],:]

tensor([[[-1.6014,  3.8724],
         [-0.8914, -8.1186],
         [-8.1186,  3.8724]],

        [[ 2.2802,  3.5610],
         [ 0.8959, -9.7811],
         [-9.7811,  3.5610]],

        [[-0.0467,  0.6703],
         [-0.0634, -8.7696],
         [-8.7696,  0.6703]],

        ...,

        [[-1.8627,  3.1752],
         [-1.1117, -8.4786],
         [-8.4786,  3.1752]],

        [[ 2.9359,  2.0395],
         [-1.5558, -9.7220],
         [-9.7220,  2.0395]],

        [[-0.9742,  1.0217],
         [ 0.8667, -9.2012],
         [-9.2012,  1.0217]]], dtype=torch.float64)

In [89]:
eval('(3,4))'[:-1])

(3, 4)

In [85]:
eval()

AttributeError: type object 'str' has no attribute 'concat'

In [81]:
'(3,4),(1,2)'.split('),')

['(3,4', '(1,2)']

In [None]:
sys.getsizeof(priors['agnostic3']['predictions'][1].logratios[:,0].to(torch.float32))

In [None]:
sys.getsizeof(priors['confident2']['predictions'])

In [6]:
import pympler

ModuleNotFoundError: No module named 'pympler'

In [100]:
priors['confident2']['predictions'][1].logratios[:,[3,4]].float()[0][0]

tensor(0.0537)

In [97]:
priors['confident2']['predictions'][1].logratios[:,[3,4]]

tensor([[ 0.0537,  0.2366],
        [ 0.0043,  0.2580],
        [ 0.0275,  0.2439],
        ...,
        [ 0.0499,  0.2460],
        [-0.0132,  0.2781],
        [ 0.0459,  0.2443]], dtype=torch.float64)

In [5]:
names = ['agnostic3','confident2']
colors_priors = ['r','#FFA500','y','g','b', ]

priors = {}
for ip, name in enumerate(names):

    priors[name] = {'name': name}

    priors[name]['results_path'] = '/home/gertwk/ALPs_with_SWYFT/cluster_runs/analysis_results/'+name

    priors[name]['config_vars'] = priors[name]['results_path'] +'/config_variables.pickle'

    priors[name]['config_phys'] = priors[name]['results_path'] +'/physics_variables.pickle'
    
    priors[name]['truncation_record'] = priors[name]['results_path'] +'/truncation_record.pickle'
    
    sys.path.append(priors[name]['results_path'])
    import param_function
    import ALP_quick_sim
    with open(priors[name]['config_vars'], 'rb') as file: config_objects = pickle.load(file)
    for key in config_objects.keys(): priors[name][key] = config_objects[key]
    with open(priors[name]['config_phys'], 'rb') as file: config_objects = pickle.load(file)
    for key in config_objects.keys(): priors[name][key] = config_objects[key]
    with open(priors[name]['truncation_record'], 'rb') as file: config_objects = pickle.load(file)
    for key in config_objects.keys(): priors[name][key] = config_objects[key]
    sys.path.remove(priors[name]['results_path'])
    sys.path.append(priors[name]['results_path']+'/train_output/net')
    import network
    sys.path.remove(priors[name]['results_path']+'/train_output/net')
    
    count = 0
    for combo in itertools.product(*priors[name]['hyperparams'].values()):
        if count == priors[name]['which_grid_point']:
            hyperparams_point = {}
            for i, key in enumerate(priors[name]['hyperparams'].keys()):
                hyperparams_point[key]=combo[i]
        count +=1
        
    priors[name]['net_path'] = (priors[name]['results_path'] + '/train_output/net/trained_network_round_'
                                +str(priors[name]['which_truncation'])+'_gridpoint_'+str(priors[name]['which_grid_point'])+'.pt')

    priors[name]['net'] = network.NetworkCorner(nbins=priors[name]['A'].nbins, marginals=priors[name]['POI_indices'], 
                                                param_names=priors[name]['A'].param_names, **hyperparams_point)

    priors[name]['net'].load_state_dict(torch.load(priors[name]['net_path']))

    with open('/home/gertwk/ALPs_with_SWYFT/cluster_runs/analysis_results/'+'test3'+'/explim_predictions.pickle', 'rb') as file:
        priors[name]['predictions'] = pickle.load(file)
    
    del sys.modules['param_function']
    del sys.modules['ALP_quick_sim']
    del sys.modules['network']




In [None]:
fig = plt.figure(figsize=(12,12))
for i in range(4):
    fig.add_subplot(2,2,i+1)
    name = names[i%2]
    fig.axes[-1].set_title(name)
    _ = generate_expected_limits(10, 10_000, bounds = [priors[name]['bounds'][0], priors[name]['bounds'][1]],
                                 predictions=priors[name]['predictions'], ax=fig.axes[-1])
    

fig.savefig(thesis_figs+'expected_limits_4_priors.pdf')

In [101]:
with open('/home/gertwk/ALPs_with_SWYFT/cluster_runs/analysis_results/'+'flare0_agnostic_mock2'+'/explim_predictions.pickle', 'rb') as file:
        preds = pickle.load(file)

In [106]:
preds[1].params.shape

torch.Size([1000000, 2, 2])