In [None]:
from __future__ import print_function, division
import os

from IPython import display
import sys
import importlib

import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import math
import copy
import pickle
import corner
import jax

import matplotlib as mpl
from matplotlib import rcParams
from matplotlib.colors import Normalize
import matplotlib.cm as cm
from matplotlib.colors import LogNorm
from matplotlib import ticker
from matplotlib import colors
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.lines import Line2D

# Inline figures are emitted in both PDF and PNG form.
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('pdf', 'png')

# Global matplotlib style for the notebook, set in one place.
# 'figure.autolayout' used to be switched on and then off again a few lines
# later; only the final value (False) is kept here. Cells that need it call
# tight_layout() themselves.
rcParams.update({
    'font.family': 'serif',
    'savefig.dpi': 75,
    'figure.autolayout': False,
    'figure.figsize': (10, 6),
    'axes.labelsize': 18,
    'axes.titlesize': 15,
    'axes.edgecolor': 'k',
    'font.size': 16,
    'lines.linewidth': 2.0,
    'lines.markersize': 8,
    'legend.fontsize': 14,
    'xtick.color': 'k',
    'ytick.color': 'k',
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    # All text is typeset through LaTeX (labels below use \texttt, \ell, ...).
    'text.usetex': True,
})

In [None]:
import sys
# Add the parent directory to the module search path.
# NOTE(review): the `utils` imports below presumably rely on this, i.e. the
# notebook is expected to run from a sub-directory of the repository -- confirm.
sys.path.append("..")

from utils import ed_fcts_amarel as ef
from utils import create_mask
from utils import ed_plotting as eplt

In [None]:
from astropy.io import fits
import astropy.units as u
import astropy
import healpy as hp

In [None]:
# Pick the GPU and set the XLA allocation flag before the model module is imported.
gpu_id = '2'
os.environ.update(
    XLA_PYTHON_CLIENT_PREALLOCATE='false',
    CUDA_VISIBLE_DEVICES=gpu_id,
)

from models.poissonian_gp import EbinPoissonModel

In [None]:
# Build the template model; below it is only used as the source of the
# HEALPix resolution, the energy-bin index and the binned counts.
ebinmodel = EbinPoissonModel(
        # important parameters
        is_gp = False,
        data_file = 'fermi_data_sum',

        # default parameters
        nfw_gamma = 1.,
        blg_names = [ef.gen_blg_name_(i)[0] for i in range(5)],
        dif_names = ['gceNNo'],
        )

nside = ebinmodel.nside
ebin = ie = ebinmodel.ebin
data = ebinmodel.counts[ie]
mask_p = ebinmodel.mask_roi_arr[ie]  # replaced by the pickled mask_p loaded just below

# load masks
# The context manager closes the file handle as soon as the masks are read.
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
with open('/data/edr76/gce-gp/figures/data/masks.p', 'rb') as masks_file:
    mask_p, mask, outer_mask, total_mask, mask_40 = pickle.load(masks_file)

In [None]:
# Angular coordinates of every HEALPix pixel at this resolution.
theta_list, phi_list = hp.pix2ang(nside, np.arange(hp.nside2npix(nside)))
# Values of phi above pi are shifted down by 2*pi; theta is shifted by -pi/2.
phi_list[phi_list > np.pi] -= 2*np.pi
theta_list = theta_list - np.pi/2

In [None]:
# Rectangular all-sky sampling grid: xsize points in azimuth, ysize in colatitude.
xsize = 800
ysize = xsize // 2
# Colatitude runs from pi down to 0, azimuth from -pi to pi.
theta, phi = np.linspace(np.pi, 0, ysize), np.linspace(-np.pi, np.pi, xsize)
# Plot coordinates in radians; the longitude axis is stored in descending order.
longitude = np.radians(np.linspace(-180, 180, xsize))[::-1]
latitude = np.radians(np.linspace(-90, 90, ysize))
# project the map to a rectangular matrix xsize x ysize
PHI, THETA = np.meshgrid(phi, theta)

In [None]:
# HEALPix pixel index for every point of the rectangular (THETA, PHI) grid.
grid_pix = hp.ang2pix(nside, THETA, PHI)
# Counts looked up at those pixels, i.e. the map resampled onto the grid.
grid_map = data[grid_pix]
# Upper colour limit passed to pcolormesh/colorbar for the log10(counts) maps.
vmax = 4.5

# Figure 1

In [None]:
# All-sky Mollweide view of log10(counts) with a 20-degree circle at the centre.
fig = plt.figure(figsize=(12, 6), dpi= 120)
ax = fig.add_subplot(111,projection='mollweide')

ax.set_title('Data')
# Circle of radius 20 degrees, expressed in radians like the plot coordinates.
t = np.linspace(0, 2 * np.pi, 100)
l_list = 20./180*np.pi * np.cos(t)
b_list = 20./180*np.pi * np.sin(t)

ret = ax.pcolormesh(longitude, latitude, np.log10(grid_map),cmap='viridis',
              vmin=0,vmax=vmax)
ax.plot(l_list,b_list, color="k", ls = "-", lw = 1.)
# Raw strings: '\e' and '\c' are not valid Python escape sequences and trigger
# a warning in plain string literals. The resulting text is unchanged.
ax.set_xlabel(r'$\ell~(^\circ)$')
ax.set_ylabel(r'$b~(^\circ)$')
ax.tick_params(axis='x', colors='k')
ax.set_longitude_grid(60)
ax.set_latitude_grid(30)
ax.grid()
# Tick labels are written in descending order, matching the reversed longitude axis.
ax.set_xticklabels([r'$120^\circ$',r'$60^\circ$',r'$0^\circ$',r'$-60^\circ$',r'$-120^\circ$'])

cb = fig.colorbar(ret, orientation='horizontal', shrink=.4, pad=0.10, ticks=[0, vmax])
cb.ax.xaxis.set_label_text(r'$\log_{10}({\rm counts})$')
cb.ax.xaxis.labelpad = -8

plt.savefig('figures/fig_allsky_data.png',format='png',bbox_inches='tight', dpi=300)

In [None]:
# Sampling grid covering +/- 20 degrees around (0, 0): 400 x 200 points.
xsize = 400
ysize = xsize // 2
# Colatitude spans 90 deg +/- 20 deg in descending order; azimuth spans +/- 20 deg.
theta, phi = (
    np.linspace(np.pi*(0.5+20/180), np.pi*(0.5-20/180), ysize),
    np.linspace(-np.pi/180*20, np.pi/180*20, xsize),
)
# Plot coordinates in radians; the longitude axis is stored in descending order.
longitude = np.radians(np.linspace(-20, 20, xsize))[::-1]
latitude = np.radians(np.linspace(-20, 20, ysize))
# project the map to a rectangular matrix xsize x ysize
PHI, THETA = np.meshgrid(phi, theta)

In [None]:
# Pixel lookup for the current (THETA, PHI) grid.
grid_pix = hp.ang2pix(nside, THETA, PHI)
# NOTE(review): grid_map is reassigned from a masked array in the next cell;
# this unmasked version looks unused after that -- confirm.
grid_map = data[grid_pix]
# Upper colour limit for the log10 maps (same value as set earlier).
vmax = 4.5

In [None]:
# Full-sky array that keeps the counts only where `mask` is False; every
# pixel where `mask` is True stays at zero.
full_array = np.zeros(hp.nside2npix(nside))
full_array[~mask] = data[~mask]
# Resample onto the plotting grid (replaces the unmasked grid_map).
grid_map = full_array[grid_pix]

In [None]:
# Same construction with `mask_p`: counts are kept where mask_p is False,
# zero elsewhere, then resampled onto the plotting grid.
full_array_ps = np.zeros(hp.nside2npix(nside))
full_array_ps[~mask_p] = data[~mask_p]
grid_map_ps = full_array_ps[grid_pix]

In [None]:
# Sampling grid covering +/- 40 degrees around (0, 0): 400 x 200 points.
xsize = 400
ysize = xsize // 2
# Colatitude spans 90 deg +/- 40 deg in descending order; azimuth spans +/- 40 deg.
theta40, phi40 = (
    np.linspace(np.pi*(0.5+40/180), np.pi*(0.5-40/180), ysize),
    np.linspace(-np.pi/180*40, np.pi/180*40, xsize),
)
# Plot coordinates in radians; the longitude axis is stored in descending order.
longitude40 = np.radians(np.linspace(-40, 40, xsize))[::-1]
latitude40 = np.radians(np.linspace(-40, 40, ysize))
# project the map to a rectangular matrix xsize x ysize
PHI40, THETA40 = np.meshgrid(phi40, theta40)

In [None]:
# Pixel lookup for the 40-degree grid.
grid_pix40 = hp.ang2pix(nside, THETA40, PHI40)
# Counts kept only where `total_mask` is False (zero elsewhere), then
# resampled onto the 40-degree grid.
full_array_complete = np.zeros(hp.nside2npix(nside))
full_array_complete[~total_mask] = data[~total_mask]
grid_map_complete = full_array_complete[grid_pix40]

In [None]:
from matplotlib import text as mtext
from matplotlib import patches

class CurvedText(mtext.Text):
    """
    A text object that follows an arbitrary curve.

    Every character of ``text`` becomes its own ``Text`` artist on ``axes``.
    Each time this object is asked to draw, the characters are re-positioned
    and rotated along the polyline ``(x, y)``; this object itself draws nothing.

    Parameters
    ----------
    x, y : sequence of float
        Vertices of the curve, in data coordinates of ``axes``.
    text : str
        String to lay out along the curve.
    axes : matplotlib axes
        Axes that receives this artist and one artist per character.
    **kwargs
        Passed to ``Text`` for this object and for every non-space character.
    """
    def __init__(self, x, y, text, axes, **kwargs):
        super(CurvedText, self).__init__(x[0],y[0],' ', **kwargs)

        axes.add_artist(self)

        ##saving the curve:
        self.__x = x
        self.__y = y
        self.__zorder = self.get_zorder()

        ##creating the text objects
        self.__Characters = []
        for c in text:
            if c == ' ':
                ##make this an invisible 'a':
                ##(created without **kwargs, so it uses the default text properties)
                t = mtext.Text(0,0,'a')
                t.set_alpha(0.0)
            else:
                t = mtext.Text(0,0,c, **kwargs)

            #resetting unnecessary arguments
            t.set_ha('center')
            t.set_rotation(0)
            t.set_zorder(self.__zorder +1)

            self.__Characters.append((c,t))
            axes.add_artist(t)


    ##overloading some member functions, to assure correct functionality
    ##on update
    def set_zorder(self, zorder):
        """Set the z-order and keep every character one level above it."""
        super(CurvedText, self).set_zorder(zorder)
        self.__zorder = self.get_zorder()
        for c,t in self.__Characters:
            t.set_zorder(self.__zorder+1)

    def draw(self, renderer, *args, **kwargs):
        """
        Overload of the Text.draw() function. Do not do
        any drawing, but update the positions and rotation
        angles of self.__Characters.
        """
        self.update_positions(renderer)

    def update_positions(self,renderer):
        """
        Update positions and rotations of the individual text elements.

        Characters are placed one after another along the arc length of the
        curve measured in display units, starting at a fixed offset of 10.
        A character that does not fit on the curve is made invisible.
        """
        #preparations

        ##determining the aspect ratio:
        ##from https://stackoverflow.com/a/42014041/2454357

        ##data limits
        xlim = self.axes.get_xlim()
        ylim = self.axes.get_ylim()
        ## Axis size on figure
        figW, figH = self.axes.get_figure().get_size_inches()
        ## Ratio of display units
        _, _, w, h = self.axes.get_position().bounds
        ##final aspect ratio
        aspect = ((figW * w)/(figH * h))*(ylim[1]-ylim[0])/(xlim[1]-xlim[0])

        #points of the curve in figure coordinates:
        x_fig,y_fig = (
            np.array(coords) for coords in zip(*self.axes.transData.transform([
            (i,j) for i,j in zip(self.__x,self.__y)
            ])))

        #point distances in figure coordinates
        x_fig_dist = (x_fig[1:]-x_fig[:-1])
        y_fig_dist = (y_fig[1:]-y_fig[:-1])
        r_fig_dist = np.sqrt(x_fig_dist**2+y_fig_dist**2)

        #arc length in figure coordinates
        l_fig = np.insert(np.cumsum(r_fig_dist),0,0)
        #angles in figure coordinates
        rads = np.arctan2(y_fig_dist, x_fig_dist)
        degs = np.rad2deg(rads)


        rel_pos = 10
        for c,t in self.__Characters:
            #finding the width of c:
            t.set_rotation(0)
            t.set_va('center')
            bbox1  = t.get_window_extent(renderer=renderer)
            w = bbox1.width

            #ignore all letters that don't fit.
            #'>=' rather than '>': a centre landing exactly on the last vertex
            #would give il == ir == last index, and the ir += 1 below would
            #then index past the end of the curve.
            if rel_pos+w/2 >= l_fig[-1]:
                t.set_alpha(0.0)
                rel_pos += w
                continue

            elif c != ' ':
                t.set_alpha(1.0)

            #finding the two data points between which the horizontal
            #center point of the character will be situated
            #left and right indices:
            il = np.where(rel_pos+w/2 >= l_fig)[0][-1]
            ir = np.where(rel_pos+w/2 <= l_fig)[0][0]

            #if we exactly hit a data point:
            if ir == il:
                ir += 1

            #how much of the letter width was needed to find il:
            used = l_fig[il]-rel_pos
            rel_pos = l_fig[il]

            #relative distance between il and ir where the center
            #of the character will be
            fraction = (w/2-used)/r_fig_dist[il]

            ##setting the character position in data coordinates:
            ##interpolate between the two points:
            x = self.__x[il]+fraction*(self.__x[ir]-self.__x[il])
            y = self.__y[il]+fraction*(self.__y[ir]-self.__y[il])

            #getting the offset when setting correct vertical alignment
            #in data coordinates
            t.set_va(self.get_va())
            bbox2  = t.get_window_extent(renderer=renderer)

            bbox1d = self.axes.transData.inverted().transform(bbox1)
            bbox2d = self.axes.transData.inverted().transform(bbox2)
            dr = np.array(bbox2d[0]-bbox1d[0])

            #the rotation/stretch matrix
            rad = rads[il]
            rot_mat = np.array([
                [math.cos(rad), math.sin(rad)*aspect],
                [-math.sin(rad)/aspect, math.cos(rad)]
            ])

            ##computing the offset vector of the rotated character
            drp = np.dot(dr,rot_mat)

            #setting final position and rotation:
            t.set_position(np.array([x,y])+drp)
            t.set_rotation(degs[il])

            t.set_va('center')
            t.set_ha('center')

            #updating rel_pos to right edge of character
            rel_pos += w-used

In [None]:
# Bottom: Data Processing
# Three panels: data within 20 deg, data with point-source/disk masks applied,
# and the complete region of interest with its inner/buffer/outer rings.
fig, axes = plt.subplots(figsize=(12.5, 4.25), dpi= 120, nrows = 1, ncols = 3)

# Hide everything farther than 20 deg from the centre. The two meshgrid outputs
# enter the mask symmetrically, so it does not matter which one is named lat/lon.
mesh_lat,mesh_lon = np.meshgrid(longitude,latitude)
ring_mask = (mesh_lat**2+mesh_lon**2)*(180/np.pi)**2>20**2

# Raw strings are used for all LaTeX labels in this cell: '\c' and '\e' are
# not valid Python escape sequences. The rendered text is unchanged.
axes[0].set_title(r'Data Within 20$^{\circ}$ of GC')
axes[0].set_facecolor('grey')
axes[0].pcolormesh(-longitude*180/np.pi, latitude*180/np.pi, np.ma.masked_array(np.log10(grid_map),mask=ring_mask),cmap='viridis',
              vmin=0,vmax=vmax)

# Unit circle, scaled below to draw the 20/30/40 degree rings.
t = np.linspace(0, 2 * np.pi, 100)
l_list = np.cos(t)
b_list = np.sin(t)

axes[0].plot(20*l_list,20*b_list, color="k", ls = "-", lw = 2.)
axes[0].set_xlabel(r'$\ell~(^\circ)$')
axes[0].set_ylabel(r'$b~(^\circ)$')
axes[0].set_xticks([-20,-10,0,10,20])
axes[0].xaxis.set_inverted (True)

axes[1].set_title('Data With Point Source and Disk Masks')
axes[1].set_facecolor('grey')

# White disc behind the map inside the 20 deg circle.
circle1 = plt.Circle((0, 0), 20, color='w',fill=True)

axes[1].add_patch(circle1)

axes[1].pcolormesh(-longitude*180/np.pi, latitude*180/np.pi, np.ma.masked_array(np.log10(grid_map_ps),mask=ring_mask),cmap='viridis',
              vmin=0,vmax=vmax)

axes[1].plot(20*l_list,20*b_list, color="k", ls = "-", lw = 2.)
axes[1].set_xlabel(r'$\ell~(^\circ)$')
axes[1].set_ylabel(r'$b~(^\circ)$')
axes[1].set_xticks([-20,-10,0,10,20])
axes[1].xaxis.set_inverted (True)


axes[2].set_title('Complete ROI')
axes[2].set_frame_on(False)

axes[2].pcolormesh(-longitude40*180/np.pi, latitude40*180/np.pi, np.log10(grid_map_complete),cmap='viridis',
              vmin=0,vmax=vmax,zorder=-20)

axes[2].plot(20*l_list,20*b_list, color="k", ls = "-", lw = 2.)
axes[2].plot(30*l_list,30*b_list, color="k", ls = "-", lw = 2.)
axes[2].plot(40*l_list,40*b_list, color="k", ls = "-", lw = 2.)
axes[2].set_xlabel(r'$\ell~(^\circ)$')
axes[2].set_ylabel(r'$b~(^\circ)$')
axes[2].set_xlim(40,-40) # added since annuluses extend beyond 40 deg
axes[2].set_ylim(-40,40)
axes[2].set_xticks([-40,-20,0,20,40])

# Ring labels that follow the circles.
CurvedText(x = 20*l_list[8:], y = 20*b_list[8:], text = 'Inner', color = 'black', va = 'bottom', axes = axes[2])
CurvedText(x = -30*l_list[60:], y = 30*b_list[60:], text = 'Buffer', color = 'red', va = 'bottom', axes = axes[2])
CurvedText(x = 40*l_list[8:], y = 40*b_list[8:], text = 'Outer', color = 'black', va = 'bottom', axes = axes[2])

# load boundary and fill shapes
_, _, inner_roi_x, inner_roi_y, outer_roi_low_x, outer_roi_low_y, outer_roi_high_x, outer_roi_high_y, outer_roi_lim_x, outer_roi_lim_y = eplt.preprocess_map_shapes(rad = False)

# Grey fill between the inner boundary and the lower outer boundary. It is
# drawn above the map (zorder -20) and below the default-zorder lines.
annulus_roi_x = [inner_roi_x, outer_roi_low_x[::-1]]
annulus_roi_y = [inner_roi_y, outer_roi_low_y[::-1]]

axes[2].fill(np.ravel(annulus_roi_x), np.ravel(annulus_roi_y), color = 'gray', zorder = -10)

# Grey fill between the upper outer boundary and the plot limit.
annulus_roi_x = [outer_roi_high_x, outer_roi_lim_x[::-1]]
annulus_roi_y = [outer_roi_high_y, outer_roi_lim_y[::-1]]

axes[2].fill(np.ravel(annulus_roi_x), np.ravel(annulus_roi_y), color = 'gray', zorder = -10)

axes[2].set_frame_on(True) # adds border frame


# Nudge the fifth x tick label of the first two panels 5 points to the right.
dx = 5/72.; dy = 0/72.
offset = mpl.transforms.ScaledTranslation(dx, dy, fig.dpi_scale_trans)
for label in axes[0].xaxis.get_majorticklabels()[4:5]:
    label.set_transform(label.get_transform() + offset)
for label in axes[1].xaxis.get_majorticklabels()[4:5]:
    label.set_transform(label.get_transform() + offset)

plt.tight_layout()

plt.savefig('figures/fig_allsky_data_inset.png',format='png',bbox_inches='tight', dpi=300)

# Figure 2

In [None]:
# Short names of the five known templates, in a fixed order.
ordered_known_templates = [
    "iso",
    "psc",
    "bub",
    "pib",
    "ics",
]

In [None]:
# load data
# `data` is rebound here and replaces the counts array defined earlier in the notebook.
# The context manager closes the file handle as soon as the contents are read.
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
with open('/data/edr76/gce-gp/figures/data/figdata_1_3.p', 'rb') as figdata_file:
    data, temp_dict = pickle.load(figdata_file)

In [None]:
# Sampling grid covering +/- 40 degrees around (0, 0): 400 x 200 points.
# (Same definitions as the earlier 40-degree grid cell.)
xsize = 400
ysize = xsize // 2
# Colatitude spans 90 deg +/- 40 deg in descending order; azimuth spans +/- 40 deg.
theta40, phi40 = (
    np.linspace(np.pi*(0.5+40/180), np.pi*(0.5-40/180), ysize),
    np.linspace(-np.pi/180*40, np.pi/180*40, xsize),
)
# Plot coordinates in radians; the longitude axis is stored in descending order.
longitude40 = np.radians(np.linspace(-40, 40, xsize))[::-1]
latitude40 = np.radians(np.linspace(-40, 40, ysize))
# project the map to a rectangular matrix xsize x ysize
PHI40, THETA40 = np.meshgrid(phi40, theta40)

In [None]:
def _template_on_grid40(template_values):
    """Embed a template into a full-sky map and resample it on the 40-degree grid.

    ``template_values`` is written into the pixels where ``mask_40`` is False;
    all other pixels stay at zero.

    Returns
    -------
    (full_map, grid) : tuple
        The full-sky array and its values at ``grid_pix40``.
    """
    full_map = np.zeros(hp.nside2npix(nside))
    full_map[~mask_40] = template_values
    return full_map, full_map[grid_pix40]


grid_pix40 = hp.ang2pix(nside, THETA40, PHI40)

# One (full-sky map, gridded map) pair per known template.
iso_map, grid_iso = _template_on_grid40(temp_dict['iso'])
psc_map, grid_psc = _template_on_grid40(temp_dict['psc'])
bub_map, grid_bub = _template_on_grid40(temp_dict['bub'])
pib_map, grid_pib = _template_on_grid40(temp_dict['pib'])
ics_map, grid_ics = _template_on_grid40(temp_dict['ics'])

In [None]:
# Top row of the known-template figure: iso, psc and bub, each shown as
# log10(template / max) with its own horizontal colour bar.
fig, axes = plt.subplots(figsize=(12, 4.7), dpi= 120, nrows = 1, ncols = 3)

# Unit circle, scaled to 40 deg to outline the region.
t = np.linspace(0, 2 * np.pi, 100)
l_list = np.cos(t)
b_list = np.sin(t)

# (template name, gridded template, lower colour limit, pin the x ticks?)
# The iso panel is the only one that leaves its x ticks to matplotlib.
panels = [
    ('iso', grid_iso, -0.1, False),
    ('psc', grid_psc, -10, True),
    ('bub', grid_bub, -3, True),
]

for panel_ax, (temp_name, grid_temp, panel_vmin, pin_xticks) in zip(axes, panels):
    m = np.log10(grid_temp/grid_temp.max())

    im = panel_ax.pcolormesh(-longitude40*180/np.pi, latitude40*180/np.pi, m,cmap='viridis',
                  vmin=panel_vmin,vmax=0)

    panel_ax.set_title(r'\texttt{' + temp_name + '}')
    panel_ax.plot(40*l_list,40*b_list, color="k", ls = "-", lw = 2., zorder = 11)
    # Raw strings: '\e' and '\c' are not valid Python escape sequences.
    panel_ax.set_xlabel(r'$\ell~(^\circ)$')
    panel_ax.set_ylabel(r'$b~(^\circ)$')
    if pin_xticks:
        panel_ax.set_xticks([-40,-20,0,20,40])
    divider = make_axes_locatable(panel_ax)
    cax = divider.append_axes('bottom', size='5%', pad=0.6)
    cb=fig.colorbar(im, cax=cax, orientation='horizontal')
    cb.ax.xaxis.set_label_text(r'$\log_{10}(\lambda/\lambda_{\rm max})$')


# load boundary and fill shapes
_, _, inner_roi_x, inner_roi_y, outer_roi_low_x, outer_roi_low_y, outer_roi_high_x, outer_roi_high_y, outer_roi_lim_x, outer_roi_lim_y = eplt.preprocess_map_shapes(rad = False)

annulus_roi_x = [outer_roi_high_x, outer_roi_lim_x[::-1]]
annulus_roi_y = [outer_roi_high_y, outer_roi_lim_y[::-1]]

# Grey fill drawn above the map (zorder 10) and below the circle (zorder 11).
# The limits are then fixed and the x axis inverted on every panel.
for i in range(3):
    axes[i].fill(np.ravel(annulus_roi_x), np.ravel(annulus_roi_y), color = 'gray', zorder = 10)
    axes[i].set_frame_on(True)
    axes[i].set_xlim(-40,40) # added since annuluses extend beyond 40 deg
    axes[i].set_ylim(-40,40) # added since annuluses extend beyond 40 deg
    axes[i].xaxis.set_inverted (True)

plt.tight_layout()
plt.savefig('figures/fig_templates_known_top.png',format='png',bbox_inches='tight', dpi=300)

In [None]:
# Bottom row of the known-template figure: pib and ics, each shown as
# log10(template / max) with its own horizontal colour bar.
fig, axes = plt.subplots(figsize=(12/3*2, 4.7), dpi= 120, nrows = 1, ncols = 2)

# Unit circle, scaled to 40 deg to outline the region.
t = np.linspace(0, 2 * np.pi, 100)
l_list = np.cos(t)
b_list = np.sin(t)

# (template name, gridded template, lower colour limit)
panels = [
    ('pib', grid_pib, -3),
    ('ics', grid_ics, -2),
]

for panel_ax, (temp_name, grid_temp, panel_vmin) in zip(axes, panels):
    m = np.log10(grid_temp/grid_temp.max())

    im = panel_ax.pcolormesh(-longitude40*180/np.pi, latitude40*180/np.pi, m,cmap='viridis',
                  vmin=panel_vmin,vmax=0)

    panel_ax.set_title(r'\texttt{' + temp_name + '}')
    panel_ax.plot(40*l_list,40*b_list, color="k", ls = "-", lw = 2., zorder = 11)
    # Raw strings: '\e' and '\c' are not valid Python escape sequences.
    panel_ax.set_xlabel(r'$\ell~(^\circ)$')
    panel_ax.set_ylabel(r'$b~(^\circ)$')
    panel_ax.set_xticks([-40,-20,0,20,40])
    divider = make_axes_locatable(panel_ax)
    cax = divider.append_axes('bottom', size='5%', pad=0.6)
    cb=fig.colorbar(im, cax=cax, orientation='horizontal')
    cb.ax.xaxis.set_label_text(r'$\log_{10}(\lambda/\lambda_{\rm max})$')

# annulus_roi_x / annulus_roi_y are not defined in this cell; they are the
# values left behind by the cell that draws the top row.
for i in range(2):
    axes[i].fill(np.ravel(annulus_roi_x), np.ravel(annulus_roi_y), color = 'gray', zorder = 10)
    axes[i].set_frame_on(True)
    axes[i].set_xlim(-40,40) # added since annuluses extend beyond 40 deg
    axes[i].set_ylim(-40,40) # added since annuluses extend beyond 40 deg
    axes[i].xaxis.set_inverted (True)

plt.tight_layout()
plt.savefig('figures/fig_templates_known_bottom.png',format='png',bbox_inches='tight', dpi=300)

# Figure 3

In [None]:
# Print the first element returned by gen_blg_name_ for indices 0..4.
for n in range(5):
    blg_name = ef.gen_blg_name_(n)[0]
    print(blg_name)

In [None]:
# 2 x 3 grid: the nfw template in the first panel, followed by the five bulge
# templates, each shown as log10(template / max) inside the 20-degree circle.
fig, axes = plt.subplots(figsize=(12, 2*4.7), dpi= 120, nrows = 2, ncols = 3)

t = np.linspace(0, 2 * np.pi, 100)
l_list = np.cos(t)
b_list = np.sin(t)

# Uses the THETA/PHI/longitude/latitude grid currently in the namespace.
# NOTE(review): the +/- 20 deg limits below suggest the 20-degree grid is expected -- confirm.
grid_pix = hp.ang2pix(nside, THETA, PHI)
nfw_map= np.zeros(hp.nside2npix(nside))
nfw_map[~mask_40] = temp_dict['nfw']
grid_nfw = nfw_map[grid_pix]

m = np.log10(grid_nfw/grid_nfw.max())


# Hide everything farther than 20 deg from the centre (the expression is
# symmetric in the two meshgrid outputs).
mesh_lat,mesh_lon = np.meshgrid(longitude,latitude)
ring_mask = (mesh_lat**2+mesh_lon**2)*(180/np.pi)**2>20**2

im = axes[0,0].pcolormesh(-longitude*180/np.pi, latitude*180/np.pi,np.ma.masked_array(m,mask=ring_mask),cmap='viridis',
              vmin=-3,vmax=0)

axes[0,0].set_title(r'\texttt{nfw}')
axes[0,0].plot(20*l_list,20*b_list, color="k", ls = "-", lw = 2., zorder = 11)
# Raw strings: '\e' and '\c' are not valid Python escape sequences.
axes[0,0].set_xlabel(r'$\ell~(^\circ)$')
axes[0,0].set_ylabel(r'$b~(^\circ)$')
axes[0,0].set_xticks([-20,-10,0,10,20])

# load boundary and fill shapes
_, _, inner_roi_x, inner_roi_y, outer_roi_low_x, outer_roi_low_y, outer_roi_high_x, outer_roi_high_y, outer_roi_lim_x, outer_roi_lim_y = eplt.preprocess_map_shapes(rad = False)

annulus_roi_x = [inner_roi_x, outer_roi_low_x[::-1]]
annulus_roi_y = [inner_roi_y, outer_roi_low_y[::-1]]

axes[0,0].fill(np.ravel(annulus_roi_x), np.ravel(annulus_roi_y), color = 'gray', zorder = 10)
axes[0,0].set_frame_on(True)
axes[0,0].set_xlim(-20,20) # fix the limits at +/- 20 deg after drawing the fill
axes[0,0].set_ylim(-20,20)
axes[0,0].xaxis.set_inverted (True)

# Nudge the fifth x tick label 5 points to the right.
dx = 5/72.; dy = 0/72.
offset = mpl.transforms.ScaledTranslation(dx, dy, fig.dpi_scale_trans)
for label in axes[0,0].xaxis.get_majorticklabels()[4:5]:
    label.set_transform(label.get_transform() + offset)
divider = make_axes_locatable(axes[0,0])
cax = divider.append_axes('bottom', size='5%', pad=0.6)
cb=fig.colorbar(im, cax=cax, orientation='horizontal')
cb.ax.xaxis.set_label_text(r'$\log_{10}(\lambda/\lambda_{\rm max})$')

for n in range(5):
    # Panel n+1 in row-major order (panel 0 holds nfw).
    yspot, xspot = divmod(n+1, 3)
    panel = axes[yspot,xspot]
    name = ef.gen_blg_name_(n)[0]
    string = r'\texttt{'+name+'}'
    bulge_map= np.zeros(hp.nside2npix(nside))
    bulge_map[~mask_40] = temp_dict['blg'][n]
    grid_bulge = bulge_map[grid_pix]

    m = np.log10(grid_bulge/grid_bulge.max())

    im = panel.pcolormesh(-longitude*180/np.pi, latitude*180/np.pi, np.ma.masked_array(m,mask=ring_mask),cmap='viridis',
                  vmin=-3,vmax=0)

    panel.set_title(string)
    panel.plot(20*l_list,20*b_list, color="k", ls = "-", lw = 2., zorder = 11)
    panel.set_xlabel(r'$\ell~(^\circ)$')
    panel.set_ylabel(r'$b~(^\circ)$')
    panel.set_xticks([-20,-10,0,10,20])
    for label in panel.xaxis.get_majorticklabels()[4:5]:
        label.set_transform(label.get_transform() + offset)
    divider = make_axes_locatable(panel)
    cax = divider.append_axes('bottom', size='5%', pad=0.6)
    cb=fig.colorbar(im, cax=cax, orientation='horizontal')
    cb.ax.xaxis.set_label_text(r'$\log_{10}(\lambda/\lambda_{\rm max})$')


    panel.fill(np.ravel(annulus_roi_x), np.ravel(annulus_roi_y), color = 'gray', zorder = 10)
    panel.set_frame_on(True)
    panel.set_xlim(-20,20) # fix the limits at +/- 20 deg after drawing the fill
    panel.set_ylim(-20,20)
    panel.xaxis.set_inverted (True)

plt.tight_layout()
plt.savefig('figures/fig_templates_gce.png',format='png',bbox_inches='tight', dpi=300)

# Figure 4

In [None]:
# Directory prefix for the pickled figure data; the trailing slash matters
# because file names are appended with plain string concatenation.
ed_data_location = '/data/edr76/gce-gp/figures/'

In [None]:
# load data
# `data` and `temp_dict` are rebound here and replace the earlier values.
# Context managers close each file handle as soon as its contents are read.
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
with open(ed_data_location+'data/figdata_4_6.p', 'rb') as figdata_file:
    data, corner_samples, temp_sample_dict, temp_sample_dict_cmask, temp_dict = pickle.load(figdata_file)

# load cartesian data
with open(ed_data_location+'data/figdata_4_6_cart.p', 'rb') as figdata_cart_file:
    mask_map_cart, exp_gp_samples_cart, gp_true, tot_samples_cart, sim_samples, model_residuals_cart, data_residuals_cart = pickle.load(figdata_cart_file)

In [None]:
# Parameter names, in the key order of corner_samples.
names = list(corner_samples)

In [None]:
# LaTeX axis labels for the corner plot, one per entry of the tuple below.
names_dressed = [
    r'$S_{\texttt{' + short_name + '}}$'
    for short_name in ('bub', 'ics', 'iso', 'pib', 'psc', 'gp')
]

In [None]:
# Stack the samples into a (parameter, sample) array, one row per name.
n_samples = len(corner_samples[names[0]])
template_sample_array = np.zeros((len(names), n_samples))
for i, name in enumerate(names):
    template_sample_array[i] = corner_samples[name]

In [None]:
# Corner plot of the sampled parameters with the reference values from
# temp_dict overlaid in red.
fig = corner.corner(template_sample_array.T, labels=names_dressed, quantiles=[0.16, 0.5, 0.84], show_titles=True, title_kwargs={"fontsize": 12})

N_var = len(names)
axes = np.array(fig.axes).reshape((N_var, N_var))

# Reference value per parameter, where one is available. 'S_gp' is compared
# against the sum of the nfw and blg values; every other name is looked up
# directly. Names without a reference value are left out and get no overlay,
# so a missing key can no longer raise a KeyError when it is paired with 'S_gp'.
truths = {}
for name in names:
    if name == 'S_gp':
        truths[name] = temp_dict['S_nfw'] + temp_dict['S_blg']
    elif name in temp_dict:
        truths[name] = temp_dict[name]

# Diagonal panels: vertical line at the reference value.
for i, name in enumerate(names):
    if name in truths:
        axes[i,i].axvline(truths[name], color='red', linestyle='--')

# Lower-triangle panels: cross-hair and marker at the pair of reference values.
for yi in range(N_var):
    for xi in range(yi):
        name_x, name_y = names[xi], names[yi]
        if name_x not in truths or name_y not in truths:
            continue

        ax = axes[yi,xi]
        value_x = truths[name_x]
        value_y = truths[name_y]

        ax.axvline(value_x, color='red', linestyle='--')
        ax.axhline(value_y, color='red', linestyle='--')
        ax.plot(value_x, value_y, "sr")

#plt.tight_layout()
plt.savefig('figures/fig_sim_corner_plot.pdf',format='pdf',bbox_inches='tight')#, dpi=300)

In [None]:
# Template names that are assigned to names_sim for the histogram below.
temp_names_sim = [
    "iso",
    "psc",
    "bub",
    "pib",
    "ics",
    "blg",
    "nfw",
]

In [None]:
# Histogram of the summed counts per posterior sample for each fitted
# template, with dashed vertical lines at the corresponding simulated totals.
fig = plt.figure(figsize=(12, 4), dpi= 120)
ax = fig.add_subplot(111)
bins = np.logspace(3.,5.,150)

# Fixed plotting order and one colour code per template name (same index).
all_temp_names = ['iso', 'psc', 'bub', 'pib', 'ics', 'blg', 'gp', 'nfw', 'dsk']
ccodes = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C5', 'C6', 'C7']
names = list(temp_sample_dict.keys())

# Simulated templates restricted to the pixels where mask_p is False.
temp_sim_dict = {name: temp_dict[name][~mask_p] for name in temp_dict.keys() if name in all_temp_names}

ordered_names = [name for name in all_temp_names if name in names]
for name in ordered_names:
    ccode = ccodes[all_temp_names.index(name)]
    temp_sum = temp_sample_dict[name].sum(axis = 1) # sum over spatial bins
    ax.hist(temp_sum, bins = bins, alpha = 1, label = r'\texttt{'+name+'}', density = False, histtype = 'step', color = ccode)

names_sim = temp_names_sim # this piece is provided by the "settings" file since we only save a dictionary with all the fit parameters
ordered_names_sim = [name for name in all_temp_names if name in names_sim]

for name in ordered_names_sim:
    # nfw gets no line of its own: its total is folded into the blg/gp line.
    if name == 'nfw':
        continue
    ccode = ccodes[all_temp_names.index(name)]
    if name in ('gp', 'blg'):
        temp_sum_sim = temp_sim_dict['blg'].sum(axis = 0) + temp_sim_dict['nfw'].sum(axis = 0)
    else:
        temp_sum_sim = temp_sim_dict[name].sum(axis = 0)
    ax.axvline(temp_sum_sim, linestyle='--', c = ccode)

# Replace the histogram patches in the legend with plain lines of the same colour.
handles, labels = ax.get_legend_handles_labels()
new_handles = [Line2D([], [], c=h.get_edgecolor()) for h in handles]

ax.legend(handles=new_handles, labels=labels,frameon=False,loc=2)
ax.set_xscale('log')
ax.set_xlabel(r'$\mathrm{Counts}$')
# NOTE(review): the histograms are drawn with density = False, yet the axis is labelled "Density" -- confirm.
ax.set_ylabel(r'$\mathrm{Density}$')

#plt.tight_layout()
plt.savefig('figures/fig_sim_log_counts_hist.pdf',format='pdf',bbox_inches='tight')#, dpi=300)

# Figure 5

In [None]:
# Side length (in pixels) of the square Cartesian maps used by the slice plots below;
# it is passed to ef.healpix_to_cart / the cart_plot_1d helpers and sizes the zero reference maps.
n_pixels = 160

In [None]:
slice_val = 3.2  # y-value of slice

fig, axes = plt.subplots(figsize=(12, 6 - 0.125), dpi= 120, nrows = 1, ncols = 2)

# 1d slice of total rate map: sample percentiles with the sim_cart / raw_cart overlays
q = np.percentile(tot_samples_cart, [2.5,16,50,84,97.5], axis = 0)
sim_cart = ef.healpix_to_cart(sim_samples, mask, n_pixels = n_pixels, nside = 128)
raw_cart = ef.healpix_to_cart(data[~mask], mask, n_pixels = n_pixels, nside = 128)

plt.axes(axes[0])
eplt.cart_plot_1d(q, sim_cart = sim_cart, raw_cart = raw_cart, 
    slice_val = slice_val, 
    ylim = [20., 90.],
    mask_map_cart = mask_map_cart,
    n_pixels = n_pixels,)

# 1d slice of total rate residuals, compared against an all-zero reference map
q = np.percentile(data_residuals_cart, [2.5,16,50,84,97.5], axis = 0)
sim_cart = np.zeros((n_pixels,n_pixels))
raw_cart = None

plt.axes(axes[1])
eplt.cart_plot_1d(q, sim_cart = sim_cart, raw_cart = raw_cart, 
    slice_val = slice_val, 
    mask_map_cart = mask_map_cart,
    n_pixels = n_pixels,
    # raw string: '\m' is not a valid escape sequence in a normal string literal
    ylabel = r'$(y - y_\mathrm{pred})$', q_color = 'purple', line_color = 'k', ls = '--')

# Pin the tick positions before relabelling them: set_xticklabels assigns labels by
# index, so without fixed ticks the text depends on whatever the auto-locator picks.
for panel in axes:
    panel.set_xticks([-20, -10, 0, 10, 20])
    panel.set_xticklabels(['$20$','$10$','$0$','$-10$','$-20$'])

fig.tight_layout(pad = 0.2)
fig.savefig('figures/fig_sim_tot_slice.pdf', bbox_inches='tight')

# Figure 6

In [None]:
fig, axes = plt.subplots(figsize=(12/3*2, 4.), dpi= 120, nrows = 1, ncols = 2) 

q = np.percentile(exp_gp_samples_cart, [2.5,16,50,84,97.5], axis = 0) # cartesian sample map quantiles
sim_cart = ef.healpix_to_cart(gp_true, mask, n_pixels = n_pixels, nside = 128) # simulated rate cartesian map
raw_cart = None

# Same quantiles sliced horizontally (left panel) and vertically (right panel).
# NOTE(review): the aspect keyword is passed together with an existing Axes; confirm
# that the installed matplotlib actually honours it in that call form.
for panel, slice_dir in zip(axes, ['horizontal', 'vertical']):
    plt.axes(panel, aspect = 'equal')
    eplt.cart_plot_1d(q, sim_cart = sim_cart, raw_cart = raw_cart, 
        slice_dir = slice_dir, slice_val = slice_val, 
        mask_map_cart = mask_map_cart,
        n_pixels = n_pixels,
        ylabel = r'$\lambda_{f}$', q_color = 'darkorange', line_color = 'green')

# Fix the tick positions first so that the five labels are guaranteed to land on
# -20 ... 20 regardless of what the auto-locator would choose for this panel width.
axes[0].set_xticks([-20, -10, 0, 10, 20])
axes[0].set_xticklabels(['$20$','$10$','$0$','$-10$','$-20$'])
plt.tight_layout()
plt.savefig('figures/fig_sim_gp_samples_slices.pdf',format='pdf',bbox_inches='tight')#, dpi=300)

In [None]:
# Sampling grid (in radians) for a 40 deg x 40 deg window: xsize columns by ysize rows.
xsize = 400
ysize = xsize // 2

# Angles handed to hp.ang2pix by the map cells: theta runs from 110 deg down to 70 deg,
# phi from -20 deg up to +20 deg.
theta = np.linspace(np.pi * (0.5 + 20 / 180), np.pi * (0.5 - 20 / 180), num=ysize)
phi = np.linspace(-np.pi / 180 * 20, np.pi / 180 * 20, num=xsize)

# Plot-axis coordinates; longitude is stored in descending order (+20 deg first).
longitude = np.radians(np.linspace(-20, 20, num=xsize))[::-1]
latitude = np.radians(np.linspace(-20, 20, num=ysize))

# project the map to a rectangular matrix xsize x ysize
PHI, THETA = np.meshgrid(phi, theta)

In [None]:
fig, axes = plt.subplots(figsize=(12/0.85, 4), dpi= 120, nrows = 1, ncols = 3) 

# unit circle, scaled to a 20 deg radius when drawn on each panel
t = np.linspace(0, 2 * np.pi, 100)
l_list = np.cos(t)
b_list = np.sin(t)

# Everything below is identical for all panels, so compute it once.
grid_pix = hp.ang2pix(nside, THETA, PHI)

# load boundary and fill shapes
# (assumes eplt.preprocess_map_shapes only returns arrays and draws nothing -- TODO confirm)
_, _, inner_roi_x, inner_roi_y, outer_roi_low_x, outer_roi_low_y, outer_roi_high_x, outer_roi_high_y, outer_roi_lim_x, outer_roi_lim_y = eplt.preprocess_map_shapes(rad = False)
annulus_roi_x = [inner_roi_x, outer_roi_low_x[::-1]]
annulus_roi_y = [inner_roi_y, outer_roi_low_y[::-1]]

# small horizontal shift (5 pt) applied to one tick label per panel
dx = 5/72.; dy = 0/72. 
offset = mpl.transforms.ScaledTranslation(dx, dy, fig.dpi_scale_trans)

for i in range(3):
    # scatter the i-th 'gp' sample into a full-sky map, then resample it on the grid
    temp_map_0= np.zeros(hp.nside2npix(nside))
    temp_map_0[~mask] = temp_sample_dict_cmask['gp'][i]
    grid_0 = temp_map_0[grid_pix]

    im = axes[i].pcolormesh(-longitude*180/np.pi, latitude*180/np.pi,np.log10(grid_0),cmap='viridis',
                  vmin=-0.7,vmax=1.75)

    axes[i].set_facecolor('grey')
    axes[i].plot(20*l_list,20*b_list, color="k", ls = "-", lw = 2.,zorder=11)
    # raw strings: '\e' and '\c' are not valid escape sequences in normal string literals
    axes[i].set_xlabel(r'$\ell~(^\circ)$')
    axes[i].set_ylabel(r'$b~(^\circ)$')
    axes[i].set_xticks([-20,-10,0,10,20])
    for label in axes[i].xaxis.get_majorticklabels()[4:5]:
        label.set_transform(label.get_transform() + offset)

    axes[i].fill(np.ravel(annulus_roi_x), np.ravel(annulus_roi_y), color = 'gray', zorder = 10)
    axes[i].set_frame_on(True)
    axes[i].set_xlim(-20,20) # added since annuluses extend beyond 40 deg
    axes[i].set_ylim(-20,20)
    axes[i].xaxis.set_inverted (True)

# one shared colorbar in a dedicated axes to the right of the panels
plt.tight_layout()
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.81, 0.2, 0.015, 0.72])
fig.colorbar(im, cax=cbar_ax,label=r'$\log_{10}(\lambda)$')

plt.savefig('figures/fig_sim_gp_samples_draws.png',format='png',bbox_inches='tight', dpi=300)

# Figure 7

In [None]:
# load data
# load data
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
# The context managers close the file handles, which a bare open() inside
# pickle.load(...) never does.
with open(ed_data_location+'data/figdata_7_8.p', 'rb') as fh:
    data, temp_sample_dict_list, temp_sample_dict_cmask_list, ll_samples_dict, x_range = pickle.load(fh)

# load cartesian data
with open(ed_data_location+'data/figdata_7_8_cart.p', 'rb') as fh:
    mask_map_cart, exp_gp_samples_cart_list, tot_samples_cart_list, model_residuals_cart_list = pickle.load(fh)
raw_cart = ef.healpix_to_cart(data[~mask], mask, n_pixels = n_pixels, nside = 128)

In [None]:
# Roman-numeral labels for the 80 diffuse models (dif_names ordered from 0 - 79)
roman_nums_arr = list(map(ef.int_to_Roman, range(1, 80 + 1)))

# The highlighted models; 'O' is prepended as the label of the first entry in each model list.
best_models_rom = ['X', 'XV', 'XLVIII', 'XLIX', 'LIII']
model_names = ['O', *best_models_rom]

# 1-based position of every highlighted numeral inside roman_nums_arr
best_models = []
for numeral in best_models_rom:
    best_models.append(roman_nums_arr.index(numeral) + 1)

In [None]:
nrows = 6 ; ncols = 4
vmin = -0.7 ; vmax = 1.75

# plot diagnostic plots
slice_val = 3.2  # y-value of slice (constant, so set once instead of per iteration)

# Named violin_colors rather than `colors` so the list does not shadow the
# `matplotlib.colors` module imported at the top of the notebook.
violin_colors = ['red', 'blue', 'green']

# One summary row (4 panels) per model; each row is saved to its own PDF.
for i in range(6): 
    fig, axes = plt.subplots(figsize=(20, 4), dpi= 120, nrows = 1, ncols = ncols)

    # panel 0: violin plot of the 'll_total' / 'll_inner' / 'll_outer' samples of model i
    all_data = [ll_samples_dict['ll_total'][i], ll_samples_dict['ll_inner'][i], ll_samples_dict['ll_outer'][i]]
    eplt.violin_plot(all_data, violin_colors, ax = axes[0])
    axes[0].set_xlim([-150,100])

    # panel 1: histogram over samples of each template's summed counts
    plt.axes(axes[1])
    temp_sample_dict = temp_sample_dict_list[i]

    axes[1].text(0.6,0.85, 'Model {}'.format(model_names[i]), 
                 horizontalalignment = 'center', verticalalignment = 'center', 
                 fontsize = 30,transform = axes[1].transAxes)

    names = list(temp_sample_dict.keys())

    # keep the canonical template order (and its colour assignment) from all_temp_names
    ordered_names = [name for name in all_temp_names if name in names]
    for name in ordered_names:
        idx = all_temp_names.index(name)
        ccode = ccodes[idx]
        temp_sum = temp_sample_dict[name].sum(axis = 1) # sum over spatial bins
        axes[1].hist(temp_sum, bins = bins, alpha = 1, label = r'\texttt{'+name+'}', density = False, histtype = 'step', color = ccode)

    # legend entries drawn as plain lines in the histogram edge colours
    handles, labels = axes[1].get_legend_handles_labels()
    new_handles = [Line2D([], [], c=h.get_edgecolor()) for h in handles]

    axes[1].legend(handles=new_handles, labels=labels,frameon=False,loc=2)
    axes[1].set_xscale('log')
    axes[1].set_xlabel(r'$\mathrm{Counts}$')
    # NOTE(review): the histogram is drawn with density = False while the axis is
    # labelled "Density" -- confirm which of the two is intended.
    axes[1].set_ylabel(r'$\mathrm{Density}$')

    tot_samples_cart = tot_samples_cart_list[i]
    data_residuals_cart = model_residuals_cart_list[i]

    # panel 2: 1d slice of total rate map
    q = np.percentile(tot_samples_cart, [2.5,16,50,84,97.5], axis = 0)

    plt.axes(axes[2])
    eplt.cart_plot_1d(q, sim_cart = None, raw_cart = raw_cart, 
        slice_val = slice_val, 
        ylim = [20., 90.],
        mask_map_cart = mask_map_cart,
        n_pixels = n_pixels,)

    # panel 3: 1d slice of total rate residuals against an all-zero reference map
    q = np.percentile(data_residuals_cart, [2.5,16,50,84,97.5], axis = 0)
    sim_cart = np.zeros((n_pixels,n_pixels))

    plt.axes(axes[3])
    eplt.cart_plot_1d(q, sim_cart = sim_cart, raw_cart = None, 
        slice_val = slice_val, 
        mask_map_cart = mask_map_cart,
        n_pixels = n_pixels,
        # raw string: '\m' is not a valid escape sequence in a normal string literal
        ylabel = r'$(y - y_\mathrm{pred})$', q_color = 'purple', line_color = 'k', ls = '--')

    # Pin the tick positions before relabelling so the labels cannot drift with the auto-locator.
    axes[2].set_xticks([-20, -10, 0, 10, 20])
    axes[2].set_xticklabels(['$20$','$10$','$0$','$-10$','$-20$'])
    dx = 5/72.; dy = 0/72. 
    offset = mpl.transforms.ScaledTranslation(dx, dy, fig.dpi_scale_trans)
    for label in axes[2].xaxis.get_majorticklabels()[0:1]:
        label.set_transform(label.get_transform() + offset)
    axes[3].set_xticks([-20, -10, 0, 10, 20])
    axes[3].set_xticklabels(['$20$','$10$','$0$','$-10$','$-20$'])
    
    plt.tight_layout()
    savestring = 'fig_fit_to_data_summary_'+str(i)
    plt.savefig('figures/'+savestring+'.pdf',format='pdf',bbox_inches='tight')
    plt.show()

# Figure 8

In [None]:
vmin = -0.7 ; vmax = 1.75
slice_val = 3.2  # y-value of slice (same value the Figure 7 cell uses)
n_models = 6

# Random sample indices, one row per model. The row count must equal n_models:
# a (5, 5) table indexed with i = 5 is silently clamped by JAX to the last row, so the
# last two models would share the same indices. The upper bound is taken from the
# per-model sample dictionaries that are actually indexed below, not from a
# dictionary left over from an earlier figure.
# NOTE: changing the table shape changes the drawn indices relative to older runs.
n_gp_draws = min(d['gp'].shape[0] for d in temp_sample_dict_cmask_list)
idxs = jax.random.randint(jax.random.PRNGKey(53), (n_models, 6 - 1), 0, n_gp_draws)

# unit circle, scaled to a 20 deg radius when drawn on each map panel
t = np.linspace(0, 2 * np.pi, 100)
l_list = np.cos(t)
b_list = np.sin(t)

# Loop-invariant inputs of the map panels.
grid_pix = hp.ang2pix(nside, THETA, PHI)
# (assumes eplt.preprocess_map_shapes only returns arrays and draws nothing -- TODO confirm)
_, _, inner_roi_x, inner_roi_y, outer_roi_low_x, outer_roi_low_y, outer_roi_high_x, outer_roi_high_y, outer_roi_lim_x, outer_roi_lim_y = eplt.preprocess_map_shapes(rad = False)
annulus_roi_x = [inner_roi_x, outer_roi_low_x[::-1]]
annulus_roi_y = [inner_roi_y, outer_roi_low_y[::-1]]

for i in range(n_models): 
    fig, axes = plt.subplots(figsize=(20, 20/5.2), dpi= 120, nrows = 1, ncols = 5)
    exp_gp_samples_cart = exp_gp_samples_cart_list[i]
    temp_sample_dict_cmask = temp_sample_dict_cmask_list[i]

    # 5 pt horizontal shift for the first tick label of a panel (depends on this figure's dpi transform)
    dx = 5/72.; dy = 0/72. 
    offset = mpl.transforms.ScaledTranslation(dx, dy, fig.dpi_scale_trans)

    # 1d slice of GCE
    q = np.percentile(exp_gp_samples_cart, [2.5,16,50,84,97.5], axis = 0) # cartesian sample map quantiles

    # NOTE(review): the aspect keyword is passed together with an existing Axes; confirm
    # that the installed matplotlib actually honours it in that call form.
    plt.axes(axes[0], aspect = 'equal')
    eplt.cart_plot_1d(q, sim_cart = None, raw_cart = None, 
        slice_dir = 'horizontal', slice_val = slice_val, 
        mask_map_cart = mask_map_cart,
        n_pixels = n_pixels,
        ylabel = r'$\lambda_{\rm gce}$', q_color = 'darkorange', line_color = 'green')
    
    axes[0].text(0.05,0.1, 'Model {}'.format(model_names[i]), 
                 horizontalalignment = 'left', verticalalignment = 'center', 
                 fontsize = 30,transform = axes[0].transAxes)

    plt.axes(axes[1], aspect = 'equal')
    eplt.cart_plot_1d(q, sim_cart = None, raw_cart = None, 
        slice_dir = 'vertical', slice_val = slice_val, 
        mask_map_cart = mask_map_cart,
        n_pixels = n_pixels,
        # same upright subscript as the horizontal panel
        ylabel = r'$\lambda_{\rm gce}$', q_color = 'darkorange', line_color = 'green')

    # panels 2-4: three randomly chosen 'gp' samples drawn as maps
    for j in range(2,5):
        idx = int(idxs[i,j-1])
        temp_map_0= np.zeros(hp.nside2npix(nside))
        temp_map_0[~mask] = temp_sample_dict_cmask['gp'][idx]
        grid_0 = temp_map_0[grid_pix]

        im = axes[j].pcolormesh(longitude*180/np.pi, latitude*180/np.pi,np.log10(grid_0),cmap='viridis',
                      vmin=vmin,vmax=vmax)

        axes[j].plot(20*l_list,20*b_list, color="k", ls = "-", lw = 2.,zorder=11)
        # raw strings: '\e' and '\c' are not valid escape sequences in normal string literals
        axes[j].set_xlabel(r'$\ell~(^\circ)$')
        axes[j].set_ylabel(r'$b~(^\circ)$')

        # fix the tick positions first, then attach the (sign-flipped) labels to them
        axes[j].set_xticks([-20,-10,0,10,20])
        axes[j].set_xticklabels(['$20$','$10$','$0$','$-10$','$-20$'])
        for label in axes[j].xaxis.get_majorticklabels()[0:1]:
            label.set_transform(label.get_transform() + offset)

        axes[j].fill(np.ravel(annulus_roi_x), np.ravel(annulus_roi_y), color = 'gray', zorder = 10)
        axes[j].set_frame_on(True)
        axes[j].set_xlim(-20,20) # added since annuluses extend beyond 40 deg
        axes[j].set_ylim(-20,20)
            
    # the first slice panel gets three labels only, pinned to -20, 0 and 20
    axes[0].set_xticks([-20, 0, 20])
    axes[0].set_xticklabels(['$20$','$0$','$-20$'])
    
    for label in axes[0].xaxis.get_majorticklabels()[0:1]:
        label.set_transform(label.get_transform() + offset)
    
    plt.tight_layout()
    fig.subplots_adjust(right=0.9)
    cbar_ax = fig.add_axes([0.91, 0.2, 0.01, 0.72])
    fig.colorbar(im, cax=cbar_ax,label=r'$\log_{10}(\lambda)$')
    savestring = 'fig_fit_to_data_gp_samples'+str(i)
    plt.savefig('figures/'+savestring+'.png',format='png',bbox_inches='tight',dpi=300)
    plt.show()

# Figure 9

In [None]:
# load data
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
# The context manager closes the file handle, which a bare open() never does.
with open(ed_data_location+'data/figdata_9.p', 'rb') as fh:
    temp_samples_list, temp_svi_results_list = pickle.load(fh)

In [None]:

# Bulge template ids (1-based) and their \texttt-wrapped display names
blg_ids = np.arange(1,6)
num_blgs = len(blg_ids)
blg_names = []
for blg_id in blg_ids:
    # ef.gen_blg_name_ is called with the 0-based id; element [0] of its result is the name
    blg_names.append(r'\texttt{' + ef.gen_blg_name_(int(blg_id) - 1)[0] + '}')

# highlight best models (dif_names ordered from 0 - 79)
roman_nums_arr = list(map(ef.int_to_Roman, range(1, 80 + 1)))
best_models_rom = ['X', 'XV', 'XLVIII', 'XLIX', 'LIII']
model_names = ['O', *best_models_rom]
# 1-based position of each highlighted numeral within roman_nums_arr
best_models = [1 + roman_nums_arr.index(numeral) for numeral in best_models_rom]
from scipy.special import softmax

In [None]:
# Display the \texttt-wrapped bulge template labels built in the previous cell.
blg_names

In [None]:
fig = plt.figure(figsize=(12,10), dpi= 120)

# Shared layout of the grouped bar charts: one bar per trial in each group.
num_trials = 6
width = 0.14
# Named trial_colors rather than `colors` so the list does not shadow the
# `matplotlib.colors` module imported at the top of the notebook.
trial_colors = ['C' + str(n) for n in range(num_trials)]

# --- top left: medians of S_blg and S_nfw with 16%-84% error bars -------------------
ax11 = plt.subplot2grid((4, 3), (0, 0), colspan = 2)
ind = np.arange(2)  # one group per parameter
for n in range(num_trials):
    temp_samples = temp_samples_list[n]
    q_blg = np.quantile(temp_samples['S_blg'], [0.16, 0.5, 0.84])
    q_nfw = np.quantile(temp_samples['S_nfw'], [0.16, 0.5, 0.84])

    rel_norms = np.array([q_blg[1], q_nfw[1]])
    # yerr rows: distance from the median down to the 16% / up to the 84% quantile
    yerr = np.array([[q_blg[1] - q_blg[0], q_nfw[1] - q_nfw[0]], [q_blg[2] - q_blg[1], q_nfw[2] - q_nfw[1]]])

    # bars of the different trials sit side by side, centred on the group position
    ax11.bar(ind - (num_trials - 1) / 2 * width + n * width , rel_norms, width, yerr = yerr,capsize=5,
             color = trial_colors[n], label = 'Dif Model {}'.format(n))

ax11.set_ylim(0, 5.5)
ax11.set_xticks(ind)
ax11.set_xticklabels([r'$S_{\rm blg}^{10^{\circ}}$', r'$S_{\rm nfw}^{10^{\circ}}$'])

# --- top right: median of gamma with 16%-84% error bars ----------------------------
ax12 = plt.subplot2grid((4, 3), (0, 2), colspan = 1)
ind = np.arange(1)
for n in range(num_trials):
    temp_samples = temp_samples_list[n]
    if 'gamma' not in list(temp_samples.keys()):
        # no 'gamma' samples for this trial: draw a zero-height bar instead
        q_gamma = np.quantile(np.zeros_like(temp_samples['S_nfw']), [0.16, 0.5, 0.84])
    else:
        q_gamma = np.quantile(temp_samples['gamma'], [0.16, 0.5, 0.84])

    rel_norms = np.array([q_gamma[1]])
    yerr = np.array([[q_gamma[1] - q_gamma[0]], [q_gamma[2] - q_gamma[1]]])

    ax12.bar(ind - (num_trials - 1) / 2 * width + n * width , rel_norms, width, yerr = yerr,capsize=5, 
             color = trial_colors[n], label = 'Dif Model {}'.format(n))

ax12.set_ylim(0, 2.)
ax12.set_xticks(ind)
ax12.set_xticklabels([r'$\gamma$'])
ax12.axhline(y = 0.2, color = 'k', linestyle = '--')

# --- second row: per-template medians of theta_blg with 16%-84% error bars ----------
ax2 = plt.subplot2grid((4, 3), (1, 0), colspan = 3)
ind = np.arange(num_blgs)
for n in range(num_trials):
    temp_samples = temp_samples_list[n]
    q_rel = np.quantile(temp_samples['theta_blg'], [0.16, 0.5, 0.84], axis = 0)

    rel_norms = [q_rel[1,i] for i in range(num_blgs)]
    yerr = [[q_rel[1,i] - q_rel[0,i] for i in range(num_blgs)], [q_rel[2,i] - q_rel[1,i] for i in range(num_blgs)]]

    ax2.bar(ind - (num_trials - 1) / 2 * width + n * width , rel_norms, width, yerr = yerr,capsize=5, 
            color = trial_colors[n], label = 'Model {}'.format(model_names[n]))

ax2.set_xlabel(r'Relative Normalizations ($\theta_{\rm blg}$)')
ax2.set_xticks(ind)
ax2.set_xticklabels(blg_names)
ax2.legend(frameon=False,loc=1,ncol=2)

plt.tight_layout()
fig.savefig('figures/fig_fit_to_data_bulge_comparisons.pdf', bbox_inches='tight')

# Figure 10

In [None]:
# load data
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
# The context manager closes the file handle, which a bare open() never does.
with open(ed_data_location+'data/figdata_10.p', 'rb') as fh:
    temp_samples_list, temp_svi_results_list = pickle.load(fh)

In [None]:
# histogram plots of bulges

# blg id and blg name
blg_ids = np.arange(1,6)
num_blgs = len(blg_ids)
blg_names = [r'\texttt{'+ef.gen_blg_name_(int(blg_ids[j]) - 1)[0]+'}' for j in range(len(blg_ids))]

fig = plt.figure(figsize=(12,10), dpi= 120)

# Shared layout of the grouped bar charts: one bar per trial in each group.
num_trials = 6
width = 0.14
# Named trial_colors rather than `colors` so the list does not shadow the
# `matplotlib.colors` module imported at the top of the notebook.
trial_colors = ['C' + str(n) for n in range(num_trials)]

# --- top row: medians of S_blg and S_nfw with 16%-84% error bars --------------------
ax11 = plt.subplot2grid((4, 2), (0, 0), colspan = 2)
ind = np.arange(2)  # one group per parameter
for n in range(num_trials):
    temp_samples = temp_samples_list[n]
    q_blg = np.quantile(temp_samples['S_blg'], [0.16, 0.5, 0.84])
    q_nfw = np.quantile(temp_samples['S_nfw'], [0.16, 0.5, 0.84])

    rel_norms = np.array([q_blg[1], q_nfw[1]])
    # yerr rows: distance from the median down to the 16% / up to the 84% quantile
    yerr = np.array([[q_blg[1] - q_blg[0], q_nfw[1] - q_nfw[0]], [q_blg[2] - q_blg[1], q_nfw[2] - q_nfw[1]]])

    # bars of the different trials sit side by side, centred on the group position
    ax11.bar(ind - (num_trials - 1) / 2 * width + n * width , rel_norms, width, yerr = yerr, capsize=5, 
             color = trial_colors[n], label = 'Dif Model {}'.format(n))

ax11.set_ylim(0, 5.5)
ax11.set_xticks(ind)
ax11.set_xticklabels([r'$S_{\rm blg}^{10^{\circ}}$', r'$S_{\rm nfw}^{10^{\circ}}$'])

# --- second row: per-template medians of theta_blg with 16%-84% error bars ----------
ax2 = plt.subplot2grid((4, 2), (1, 0), colspan = 2)
ind = np.arange(num_blgs)
for n in range(num_trials):
    temp_samples = temp_samples_list[n]
    q_rel = np.quantile(temp_samples['theta_blg'], [0.16, 0.5, 0.84], axis = 0)

    rel_norms = [q_rel[1,i] for i in range(num_blgs)]
    yerr = [[q_rel[1,i] - q_rel[0,i] for i in range(num_blgs)], [q_rel[2,i] - q_rel[1,i] for i in range(num_blgs)]]

    ax2.bar(ind - (num_trials - 1) / 2 * width + n * width , rel_norms, width, yerr = yerr, capsize=5, 
            color = trial_colors[n], label = 'Model {}'.format(model_names[n]))

# upright "blg" subscript, matching the Figure 9 axis label
ax2.set_xlabel(r'Relative Normalizations ($\theta_{\rm blg}$)')
ax2.set_xticks(ind)
ax2.set_xticklabels(blg_names)
ax2.legend(frameon=False,loc=1,ncol=2)

plt.tight_layout()
fig.savefig('figures/fig_fit_to_data_bulge_comparisons_gamma_1p2.pdf', bbox_inches='tight')

# Figure 11

In [None]:
# load data
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
# The context manager closes the file handle, which a bare open() never does.
with open(ed_data_location+'data/figdata_10p5.p', 'rb') as fh:
    samples_dict, temp_sample_dict, temp_sample_dict_cmask = pickle.load(fh)

In [None]:
fig, axes = plt.subplots(figsize=(12, 6*0.95), dpi= 120, nrows = 1, ncols = 2)
bins = np.logspace(3.,5.,150)

# canonical template order and the colour assigned to each template
all_temp_names = ['iso', 'psc', 'bub', 'pib', 'ics', 'blg', 'gp', 'nfw', 'dsk']
ccodes = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C5', 'C6', 'C7']
names = list(temp_sample_dict.keys())

# --- left panel: histogram over samples of each template's summed counts ------------
ordered_names = [name for name in all_temp_names if name in names]
for name in ordered_names:
    idx = all_temp_names.index(name)
    ccode = ccodes[idx]
    temp_sum = temp_sample_dict[name].sum(axis = 1) # sum over spatial bins
    axes[0].hist(temp_sum, bins = bins, alpha = 1, label = r'\texttt{'+name+'}', density = False, histtype = 'step', color = ccode)

# legend entries drawn as plain lines in the histogram edge colours
handles, labels = axes[0].get_legend_handles_labels()
new_handles = [Line2D([], [], c=h.get_edgecolor()) for h in handles]

axes[0].legend(handles=new_handles, labels=labels,frameon=False,loc=2)
axes[0].set_xscale('log')
axes[0].set_xlabel(r'$\mathrm{Counts}$')
axes[0].set_ylabel(r'$\mathrm{Density}$')

# --- right panel: per-pixel median of the 'gp' samples as a map ----------------------
q = np.percentile(temp_sample_dict_cmask['gp'], 50, axis = 0)

grid_pix = hp.ang2pix(nside, THETA, PHI)
temp_map_0= np.zeros(hp.nside2npix(nside))
temp_map_0[~mask] = q
grid_0 = temp_map_0[grid_pix]

im = axes[1].pcolormesh(-longitude*180/np.pi, latitude*180/np.pi,np.log10(grid_0),cmap='viridis',
              vmin=-0.7,vmax=1.26)

# Unit circle for the 20 deg boundary; defined here so the cell does not depend on
# l_list / b_list having been left behind by an earlier figure cell.
t = np.linspace(0, 2 * np.pi, 100)
l_list = np.cos(t)
b_list = np.sin(t)

axes[1].plot(20*l_list,20*b_list, color="k", ls = "-", lw = 2., zorder = 11)
# raw strings: '\e' and '\c' are not valid escape sequences in normal string literals
axes[1].set_xlabel(r'$\ell~(^\circ)$')
axes[1].set_ylabel(r'$b~(^\circ)$')
axes[1].set_ylim(-20,20)
axes[1].set_xlim(-20,20)
axes[1].set_xticks([-20,-10,0,10,20])
axes[1].xaxis.set_inverted(True)


# load boundary and fill shapes
_, _, inner_roi_x, inner_roi_y, outer_roi_low_x, outer_roi_low_y, outer_roi_high_x, outer_roi_high_y, outer_roi_lim_x, outer_roi_lim_y = eplt.preprocess_map_shapes(rad = False)

annulus_roi_x = [inner_roi_x, outer_roi_low_x[::-1]]
annulus_roi_y = [inner_roi_y, outer_roi_low_y[::-1]]

axes[1].fill(np.ravel(annulus_roi_x), np.ravel(annulus_roi_y), color = 'gray', zorder = 10)


# shift the label of the last tick (x = 20) by 5 pt
dx = 5/72.; dy = 0/72. 
offset = mpl.transforms.ScaledTranslation(dx, dy, fig.dpi_scale_trans)
for label in axes[1].xaxis.get_majorticklabels()[4:5]:
    label.set_transform(label.get_transform() + offset)

plt.tight_layout()
fig.subplots_adjust(right=0.9)
cbar_ax = fig.add_axes([0.91, 0.14, 0.015, 0.805])
fig.colorbar(im, cax=cbar_ax,label=r'$\log_{10}(\lambda)$')
plt.savefig('figures/fig_fit_to_data_all_models.png',format='png',bbox_inches='tight',dpi=300)

In [None]:
fig, axes = plt.subplots(figsize=(12, 6), dpi= 120, nrows = 2, ncols = 1)

roman_nums_arr = [ef.int_to_Roman(i) for i in range(1,80+1)]
best_models_rom = ['X', 'XV', 'XLVIII', 'XLIX', 'LIII']
best_models_idx = [roman_nums_arr.index(r) for r in best_models_rom]
# (the global `best_models` list is deliberately left untouched: resetting it to []
# here would silently wipe the list built by the earlier label cells)

# x positions 0..80 and one colour per position: the first entry is green (gceNNo),
# highlighted models are red, everything else is blue.
x = np.arange(0,80+1)
c = ['green'] + ['red' if i in best_models_idx else 'blue' for i in range(80)]

# One panel per parameter: median with 16%-84% error bars at each x position.
panel_specs = [('theta_pib', r'$\theta_{\texttt{pib}}$'), ('theta_ics', r'$\theta_{\texttt{ics}}$')]
for panel, (key, panel_ylabel) in zip(axes, panel_specs):
    q = np.quantile(samples_dict[key], [0.16, 0.5, 0.84], axis = 0)
    # errorbar is called point by point because each point has its own colour
    for i in range(80+1):
        panel.errorbar(x[i], q[1][i], yerr= np.array([[q[1][i] - q[0][i]], [q[2][i] - q[1][i]]]), fmt='o', color=c[i],capsize=3)
    panel.set_ylabel(panel_ylabel)
    panel.set_xlim(-1, 81)

axes[1].set_xlabel('Model Number')

plt.subplots_adjust(wspace=0, hspace=0)
plt.savefig('figures/fig_fit_to_data_all_models_hist.pdf',format='pdf',bbox_inches='tight')

# Figure 12

In [None]:
# load data
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
# The context managers close the file handles, which a bare open() never does.
with open(ed_data_location+'data/figdata_11.p', 'rb') as fh:
    temp_samples, temp_sample_dict_cmask, blg_list = pickle.load(fh)

# load cartesian data
with open(ed_data_location+'data/figdata_11_cart.p', 'rb') as fh:
    mask_map_cart, exp_gp_samples_cart, nfw_cart, blg_cart, blg_pieces_cart = pickle.load(fh)

In [None]:
keys = list(temp_samples.keys())

# Axis label of every parameter that may appear in the corner plot. Columns and labels
# are both derived from this mapping, so a label can never end up on the wrong column
# (previously the columns followed the dict order while the labels were hard-coded),
# and a parameter missing from temp_samples drops its label as well.
corner_labels = {
    'S_blg': r'$S_{\texttt{blg}}^{10^{\circ}}$',
    'S_nfw': r'$S_{\texttt{nfw}}^{10^{\circ}}$',
    'gamma': r'$\gamma$',
}
corner_keys = [k for k in corner_labels if k in keys]

# shape (n_params, n_samples); corner expects one column per parameter, hence the .T below
temp_samples_tot_list = np.array([temp_samples[k] for k in corner_keys])

fig = corner.corner(temp_samples_tot_list.T, labels = [corner_labels[k] for k in corner_keys],
                    quantiles=[0.16, 0.5, 0.84], show_titles=True, title_kwargs={"fontsize": 12})
fig.savefig('figures/fig_alltemps_corner.pdf',format='pdf', bbox_inches='tight')

In [None]:
def _slice_1d(arr, slice_dir, idx):
    """Return row `idx` (horizontal slice) or column `idx` (vertical slice) of a 2d map."""
    return arr[idx, :] if slice_dir == 'horizontal' else arr[:, idx]


def cart_plot_1d_multi(q_list,
                 sim_cart=None, raw_cart=None,
                 mask_map_cart = None,
                 n_pixels = 160, res_scale = 1, map_size = 40,
                 slice_dir = 'horizontal', slice_val = 2.,
                 yscale = 'linear', ylim = None,
                 q_colors = None, line_color = 'blue', scatter_color = 'k', ylabel = 'Counts', ls = '-', q_labels = None,
                 samples = None):
    """Plot a 1d slice through several Cartesian quantile maps on the current pyplot axes.

    For every entry ``q`` of ``q_list`` the slice of ``q[1]`` is drawn as a line and the
    region between the slices of ``q[0]`` and ``q[2]`` is shaded in the same colour.

    Parameters
    ----------
    q_list : sequence of indexable quantile stacks; each ``q[k]`` is a 2d map, with
        ``q[0]`` / ``q[1]`` / ``q[2]`` used as lower band / line / upper band and
        ``q[-1]`` used for the upper plot bound.
    sim_cart, raw_cart : optional 2d maps, sliced the same way and drawn as a line
        labelled 'True' and as markers labelled 'Data' respectively.
    mask_map_cart : optional 2d map; runs of NaN along the slice are shaded grey.
        The x/y limits are only set when this map is given.
    n_pixels, res_scale, map_size : forwarded to ``ef.cart_coords``;
        ``map_size / n_pixels`` is used as the pixel scale.
    slice_dir : 'horizontal' (slice at fixed second coordinate, x axis labelled ell) or
        'vertical' (slice at fixed first coordinate, x axis labelled b).
    slice_val : coordinate value at which the slice is taken.
    yscale : passed to ``plt.yscale``.
    ylim : y limits; if None they are derived from the sliced quantiles.
    q_colors, q_labels : per-entry colours / legend labels. ``None`` selects the colour
        cycle entries 'C0', 'C1', ... and no labels.
    line_color, ls : colour and line style of the ``sim_cart`` line.
    scatter_color : colour of the ``raw_cart`` markers.
    ylabel : y-axis label.
    samples : optional stack of maps; the first four are overplotted in grey.

    Returns
    -------
    None. An unsupported ``slice_dir`` prints a message and draws nothing.
    """
    if slice_dir not in ('horizontal', 'vertical'):
        print("slice_dir must be 'horizontal' or 'vertical'; got {!r}.".format(slice_dir))
        return

    # usable defaults for the per-entry styling (the None defaults used to fail on indexing)
    if q_labels is None:
        q_labels = [None] * len(q_list)
    if q_colors is None:
        q_colors = ['C' + str(i) for i in range(len(q_list))]

    # generate cartesian grid
    Nx1, Nx2, x1_plt, x2_plt, x1_c, x2_c, x = ef.cart_coords(n_pixels, res_scale, map_size)
    pix_scale = map_size / n_pixels

    # Locate the slice: of all grid centres closer than one (scaled) pixel to slice_val,
    # the second match is used.
    if slice_dir == 'horizontal':
        idx = np.where(np.abs(x2_c - slice_val) < pix_scale * res_scale)[0][1]
        print('Slice at y = {:.5f} deg'.format(x[idx,0,1]))
        coords = x[idx,:,0]
        xlabel = r'$\ell$ $(^\circ)$'
        legend_kwargs = dict(frameon = False, fontsize = 10, loc = 2)
    else:
        idx = np.where(np.abs(x1_c - slice_val) < pix_scale * res_scale)[0][1]
        print('Slice at x = {:.5f} deg'.format(x[0,idx,0]))
        coords = x[:,idx,1]
        xlabel = r'$b$ $(^\circ)$'
        legend_kwargs = dict(frameon = False, fontsize = 10)

    # line plus shaded band for every quantile stack
    for i, q in enumerate(q_list):
        q_label = q_labels[i]
        q_color = q_colors[i]
        plt.plot(coords, _slice_1d(q[1], slice_dir, idx), c = q_color, label = q_label)
        plt.fill_between(coords, _slice_1d(q[0], slice_dir, idx), _slice_1d(q[2], slice_dir, idx),
                         color = q_color, alpha = 0.2)
    if sim_cart is not None:
        plt.plot(coords, _slice_1d(sim_cart, slice_dir, idx), c = line_color, label = 'True', ls = ls)
    if raw_cart is not None:
        plt.errorbar(coords, _slice_1d(raw_cart, slice_dir, idx), fmt = 'o', c = scatter_color, alpha = 0.5, label = 'Data')
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend(**legend_kwargs)
    plt.axvline(0, color='k', ls = '--', lw = 0.5)
    plt.yscale(yscale)

    if samples is not None:
        for i in range(4):
            plt.plot(coords, _slice_1d(samples[i], slice_dir, idx), c = 'gray', alpha = 0.3)

    if mask_map_cart is not None:
        # positions along the slice where the mask map is NaN
        nan_mask = np.isnan(_slice_1d(mask_map_cart, slice_dir, idx))
        x_nan = coords[nan_mask]

        # Split the NaN positions into contiguous runs: a new run starts wherever the
        # spacing to the previous position exceeds one pixel (0.01 added to account for
        # floating point error; 1 is added to account for the diff shift in indices).
        split_indices = np.where(np.diff(x_nan) > pix_scale + 0.01)[0] + 1
        split_indices = np.insert(split_indices, 0, 0)
        split_indices = np.append(split_indices, len(x_nan))

        # Bounds over ALL entries of q_list (the lower bound used to look only at the
        # last entry, whereas the upper bound already covered every entry).
        max_q = np.max([np.max(_slice_1d(q[-1], slice_dir, idx)) for q in q_list])
        min_q = np.min([np.min(_slice_1d(q[0], slice_dir, idx)) for q in q_list])

        for i in range(len(split_indices) - 1):
            x_fill = x_nan[split_indices[i]:split_indices[i+1]]
            # min with 0 so the band starts below the axis whether q is negative or positive
            y_fill = np.zeros_like(x_fill) + np.min([0, min_q]) - 0.25 - 1
            # the factor 100 makes the grey band extend far beyond the visible y range
            plt.fill_between(x_fill, 100 * y_fill, 100 * max_q + 1, color = 'gray', alpha = 0.175, edgecolor = None)

        plt.xlim(-20,20)
        if ylim is None:
            plt.ylim(np.min([-0.25, min_q - 0.25]), max_q + 0.25)
        else:
            plt.ylim(ylim)

In [None]:
fig, axes = plt.subplots(figsize=(12, 4), dpi= 120, nrows = 1, ncols = 3)

# Slice position for both slice panels; set before its first use so the cell does not
# depend on a value left behind by an earlier cell.
slice_val = 3.2

# --- panel 0: histograms of the theta_blg samples of every bulge template ------------
for n in range(temp_samples['theta_blg'].shape[1]):
    axes[0].hist(temp_samples['theta_blg'][:,n], bins = 50, histtype = 'step', color = 'C' + str(n + 4), density = True, label = r'$\texttt{'+ef.gen_blg_name_(n)[0]+'}$')
axes[0].set_xlabel(r'$\theta_{\texttt{blg}}$')
axes[0].set_ylabel('Density')
# legend entries drawn as plain lines in the histogram edge colours
handles, labels = axes[0].get_legend_handles_labels()
new_handles = [Line2D([], [], c=h.get_edgecolor()) for h in handles]

axes[0].legend(handles=new_handles, labels=labels,frameon=False,fontsize=12)#,loc=2)
axes[0].set_ylim(0, 15)


# 16/50/84 percentile maps of every component that is overplotted in the slice panels
q_list = []
q_list.append(np.percentile(exp_gp_samples_cart, [16,50,84], axis = 0)) # cartesian sample map quantiles
q_list.append(np.percentile(blg_cart + nfw_cart, [16,50,84], axis = 0)) # cartesian sample map quantiles
q_list.append(np.percentile(nfw_cart, [16,50,84], axis = 0)) # cartesian sample map quantiles
q_list.append(np.percentile(blg_cart, [16,50,84], axis = 0)) # cartesian sample map quantiles
# the bulge template with the largest mean theta_blg gets its own curve
idx = np.argmax(temp_samples['theta_blg'].mean(axis = 0))
q_list.append(np.percentile(blg_pieces_cart[idx], [16,50,84], axis = 0)) # cartesian sample map quantiles

q_labels = ['GP', 'NFW + Blg', 'NFW', 'Blg', r'$\texttt{'+ef.gen_blg_name_(idx)[0]+'}$']
q_colors = ['C' + str(i) for i in range(len(q_list))]

# position of the dashed vertical guide lines drawn on both slice panels
guide_x = np.sqrt(10.**2. - 3.375**2.)

# --- panel 1: horizontal slice ---------------------------------------------------------
plt.axes(axes[1])
cart_plot_1d_multi(q_list,
    slice_val = slice_val, 
    mask_map_cart = mask_map_cart,
    n_pixels = n_pixels,
    # raw string: '\l' is not a valid escape sequence in a normal string literal
    ylabel = r'$\lambda$', q_colors = q_colors, line_color = 'green', q_labels = q_labels,
    )
axes[1].axvline(guide_x, c = 'k', ls = '--', lw = 1)
axes[1].axvline(-guide_x, c = 'k', ls = '--', lw = 1)
axes[1].set_ylim([0,15])


# --- panel 2: vertical slice, same curves without legend labels -----------------------
q_labels = [None] * len(q_list)
q_colors = ['C' + str(i) for i in range(len(q_list))]

plt.axes(axes[2])
cart_plot_1d_multi(q_list, 
    slice_dir = 'vertical', slice_val = slice_val, 
    mask_map_cart = mask_map_cart,
    n_pixels = n_pixels,
    ylabel = r'$\lambda$', q_colors = q_colors, line_color = 'green', q_labels = q_labels,
    )
axes[2].axvline(guide_x, c = 'k', ls = '--', lw = 1)
axes[2].axvline(-guide_x, c = 'k', ls = '--', lw = 1)
axes[2].set_ylim([0,15])

# Pin the tick positions before relabelling so the labels cannot drift with the auto-locator.
axes[1].set_xticks([-20, -10, 0, 10, 20])
axes[1].set_xticklabels(['$20$','$10$','$0$','$-10$','$-20$'])
dx = 5/72.; dy = 0/72. 
offset = mpl.transforms.ScaledTranslation(dx, dy, fig.dpi_scale_trans)
for label in axes[1].xaxis.get_majorticklabels()[0:1]:
    label.set_transform(label.get_transform() + offset)
plt.tight_layout()

fig.savefig('figures/fig_alltemps_hist.pdf',format='pdf', bbox_inches='tight')

# Figure 13

In [None]:
# load data
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
# The context managers close the file handles, which a bare open() never does.
with open(ed_data_location+'data/figdata_12.p', 'rb') as fh:
    temp_samples, temp_sample_dict_cmask, blg_list = pickle.load(fh)

# load cartesian data
with open(ed_data_location+'data/figdata_12_cart.p', 'rb') as fh:
    mask_map_cart, exp_gp_samples_cart, nfw_cart, blg_cart, blg_pieces_cart = pickle.load(fh)

In [None]:
keys = list(temp_samples.keys())

# Axis label of every parameter that may appear in the corner plot. Columns and labels
# are both derived from this mapping, so a label can never end up on the wrong column
# (previously the columns followed the dict order while the labels were hard-coded),
# and the label count always matches the number of selected parameters.
corner_labels = {
    'S_blg': r'$S_{\texttt{blg}}^{10^{\circ}}$',
    'S_nfw': r'$S_{\texttt{nfw}}^{10^{\circ}}$',
    'gamma': r'$\gamma$',
}
corner_keys = [k for k in corner_labels if k in keys]

# shape (n_params, n_samples); corner expects one column per parameter, hence the .T below
temp_samples_tot_list = np.array([temp_samples[k] for k in corner_keys])

fig = corner.corner(temp_samples_tot_list.T, labels = [corner_labels[k] for k in corner_keys],
                    quantiles=[0.16, 0.5, 0.84], show_titles=True, title_kwargs={"fontsize": 12})
fig.savefig('figures/fig_alltemps_corner_gamma_1p2.pdf',format='pdf', bbox_inches='tight')

In [None]:
# Figure 13: posterior of the bulge-template weights (left panel) and 1D slices through
# the fitted rate maps (middle and right panels) for the gamma = 1.2 fit.
# NOTE(review): `cart_plot_1d_multi` and `n_pixels` are not defined in this cell, and
# `slice_val` is read for the middle panel before this cell assigns it, so the middle
# slice uses whatever value an earlier cell left in the kernel -- confirm this is intended.
fig, axes = plt.subplots(figsize=(12, 4), dpi= 120, nrows = 1, ncols = 3)

# left panel: one step histogram per column of theta_blg
for n in range(temp_samples['theta_blg'].shape[1]):
    axes[0].hist(temp_samples['theta_blg'][:,n], bins = 50, histtype = 'step', color = 'C' + str(n + 4), density = True, label = r'$\texttt{'+ef.gen_blg_name_(n)[0]+'}$')
axes[0].set_xlabel(r'$\theta_{\texttt{blg}}$')
axes[0].set_ylabel('Density')

# step histograms produce patch handles; show plain lines of the same color in the legend instead
handles, labels = axes[0].get_legend_handles_labels()
new_handles = [Line2D([], [], c=h.get_edgecolor()) for h in handles]
axes[0].legend(handles=new_handles, labels=labels,frameon=False,fontsize=12)
axes[0].set_ylim(0, 15)

# 16/50/84 percentiles along the first axis of each cartesian sample stack
q_list = []
q_list.append(np.percentile(exp_gp_samples_cart, [16,50,84], axis = 0))
q_list.append(np.percentile(blg_cart + nfw_cart, [16,50,84], axis = 0))
q_list.append(np.percentile(nfw_cart, [16,50,84], axis = 0))
q_list.append(np.percentile(blg_cart, [16,50,84], axis = 0))
idx = np.argmax(temp_samples['theta_blg'].mean(axis = 0)) # bulge template with the largest mean weight
q_list.append(np.percentile(blg_pieces_cart[idx], [16,50,84], axis = 0))

q_labels = ['GP', 'NFW + Blg', 'NFW', 'Blg', r'$\texttt{'+ef.gen_blg_name_(idx)[0]+'}$']
q_colors = ['C' + str(i) for i in range(len(q_list))]

# position of the dashed vertical guide lines drawn in both slice panels
roi_edge = np.sqrt(10.**2. - 3.375**2.)

# middle panel: slice in the default direction of cart_plot_1d_multi
plt.axes(axes[1])
cart_plot_1d_multi(q_list,
    slice_val = slice_val,
    mask_map_cart = mask_map_cart,
    n_pixels = n_pixels,
    ylabel = r'$\lambda$', q_colors = q_colors, line_color = 'green', q_labels = q_labels,
    )
axes[1].axvline(roi_edge, c = 'k', ls = '--', lw = 1)
axes[1].axvline(-roi_edge, c = 'k', ls = '--', lw = 1)
axes[1].set_ylim([0,15])

# right panel: vertical slice, drawn without legend entries
slice_val = 3.2
q_labels = [None, None, None, None, None]

plt.axes(axes[2])
cart_plot_1d_multi(q_list,
    slice_dir = 'vertical', slice_val = slice_val,
    mask_map_cart = mask_map_cart,
    n_pixels = n_pixels,
    ylabel = r'$\lambda$', q_colors = q_colors, line_color = 'green', q_labels = q_labels,
    )
axes[2].axvline(roi_edge, c = 'k', ls = '--', lw = 1)
axes[2].axvline(-roi_edge, c = 'k', ls = '--', lw = 1)
axes[2].set_ylim([0,15])

# relabel the middle panel's x ticks and nudge the first label to the right by 5 points
axes[1].set_xticklabels(['$20$','$10$','$0$','$-10$','$-20$'])
dx = 5/72.; dy = 0/72.
offset = mpl.transforms.ScaledTranslation(dx, dy, fig.dpi_scale_trans)
for label in axes[1].xaxis.get_majorticklabels()[0:1]:
    label.set_transform(label.get_transform() + offset)

plt.tight_layout()

fig.savefig('figures/fig_theta_blg_hist_gamma_1p2.pdf',format='pdf', bbox_inches='tight')

# Appendix A

# Figure A1

In [None]:
# load data (the file is closed as soon as unpickling finishes)
with open(ed_data_location+'data/appdata_degeneracy.p', 'rb') as pkl_file:
    temp_dict, temp_sample_dict_1, temp_sample_dict_2, losses_1, ll_list_1, losses_2, ll_list_2, rx1, rx2  = pickle.load(pkl_file)

In [None]:
# Loss histories of the two fits; the ELBO curves are plotted against the step index
# rescaled to (0, 1], the likelihood curves against the loaded rx1 / rx2 arrays.
fig = plt.figure(figsize=(12, 4), dpi= 120)
ax = fig.add_subplot(111)

# the loss arrays come from the data cell above
Nstep1, Nstep2 = len(losses_1), len(losses_2)

x1 = np.arange(1, Nstep1 + 1) / Nstep1
x2 = np.arange(1, Nstep2 + 1) / Nstep2

# (x, y, line properties) in drawing order, which also fixes the legend order
curves = [
    (x1, losses_1, dict(label = '$-$ELBO Fit 1 ', c = 'b', alpha = 0.3)),
    (x2, losses_2, dict(label = '$-$ELBO Fit 2', c = 'r', alpha = 0.3)),
    (rx1, -ll_list_1, dict(label = r'$-{\cal L}$ Fit 1', c = 'b')),
    (rx2, -ll_list_2, dict(label = r'$-{\cal L}$ Fit 2', c = 'r')),
]
for curve_x, curve_y, curve_props in curves:
    ax.plot(curve_x, curve_y, **curve_props)

ax.set_ylim(13000, 14000)
ax.set_xlabel('Normalized Step')
ax.set_ylabel('Loss Metric')
ax.legend(frameon=False)

plt.tight_layout()
fig.savefig('figures/fig_degeneracy_loss.pdf', format='pdf', bbox_inches='tight')

In [None]:
# Histograms of the total posterior counts of each template for the two fits
# (top panel: Fit 1, bottom panel: Fit 2). Dashed vertical lines mark the totals of the
# simulated templates in `temp_dict`.
# NOTE(review): `temp_sample_dict`, `mask_p` and `temp_names_sim` are not defined in this cell
# or in its data-loading cell; they must already exist in the kernel -- confirm.
fig, axes = plt.subplots(figsize=(12, 8), dpi= 120, nrows = 2, ncols = 1)
bins = np.logspace(3.,5.,150)

# template order and the color assigned to each template (later scan figures index into these)
all_temp_names = ['iso', 'psc', 'bub', 'pib', 'ics', 'blg', 'gp', 'nfw', 'dsk']
ccodes = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C5', 'C6', 'C7']
names = list(temp_sample_dict.keys())

# simulated template maps restricted to the pixels selected by ~mask_p
temp_sim_dict = {name: temp_dict[name][~mask_p] for name in temp_dict.keys() if name in all_temp_names}

ordered_names = [name for name in all_temp_names if name in names]
names_sim = temp_names_sim # this piece is provided by the "settings" file since we only save a dictionary with all the fit parameters
ordered_names_sim = [name for name in all_temp_names if name in names_sim]

fit_sample_dicts = [temp_sample_dict_1, temp_sample_dict_2]
fit_titles = ['Fit 1', 'Fit 2']

# Both panels are drawn by the same loop so that every element of a fit lands on that fit's
# own axes (previously the 'gp' truth line of Fit 2 was drawn on the Fit 1 panel).
for panel_ax, fit_samples, fit_title in zip(axes, fit_sample_dicts, fit_titles):
    # posterior histograms
    # NOTE(review): density = False, so the y axis shows counts per bin although it is labeled 'Density'.
    for name in ordered_names:
        idx = all_temp_names.index(name)
        ccode = ccodes[idx]
        temp_sum = fit_samples[name].sum(axis = 1) # sum over spatial bins
        panel_ax.hist(temp_sum, bins = bins, alpha = 1, label = r'\texttt{'+name+'}', density = False, histtype = 'step', color = ccode)

    # simulated totals
    for name in ordered_names_sim:
        if name == 'nfw':
            continue # nfw has no line of its own; it is added to blg below
        idx = all_temp_names.index(name)
        ccode = ccodes[idx]
        if name in ('gp', 'blg'):
            # both are compared with the combined blg + nfw total
            temp_sum_sim = temp_sim_dict['blg'].sum(axis = 0) + temp_sim_dict['nfw'].sum(axis = 0)
        else:
            temp_sum_sim = temp_sim_dict[name].sum(axis = 0)
        panel_ax.axvline(temp_sum_sim, linestyle='--', c = ccode)

    panel_ax.set_xscale('log')
    panel_ax.set_xlabel(r'$\mathrm{Counts}$')
    panel_ax.set_ylabel(r'$\mathrm{Density}$')
    panel_ax.text(0.9, 0.9, fit_title, horizontalalignment = 'center', verticalalignment = 'center', transform = panel_ax.transAxes, fontsize = 24)

# legend on the top panel only; step histograms produce patch handles, so show lines of the same color
handles, labels = axes[0].get_legend_handles_labels()
new_handles = [Line2D([], [], c=h.get_edgecolor()) for h in handles]
axes[0].legend(handles=new_handles, labels=labels,frameon=False,loc=2)

plt.savefig('figures/fig_degeneracy_hist.pdf',format='pdf',bbox_inches='tight')

# Figure A2

In [None]:
# load data (the file is closed as soon as unpickling finishes)
with open(ed_data_location+'data/appdata_roi_outer_radius_scan.p', 'rb') as pkl_file:
    outer_radius_list, q_dict = pickle.load(pkl_file)

In [None]:
# Figure A2: one line plus shaded band per template as a function of the outer radius.
# For each template, q_dict[name][1] is drawn as the line and q_dict[name][0] / q_dict[name][2]
# bound the band.
# NOTE(review): `all_temp_names` and `ccodes` are defined in the degeneracy-histogram cell above,
# which must have been run first.
fig, ax = plt.subplots(figsize=(12,5), dpi= 120)
for name in list(q_dict.keys()):
    idx = all_temp_names.index(name)
    ccode = ccodes[idx] # colors chosen to match colored hist plots
    ax.plot(outer_radius_list, q_dict[name][1], label=r'\texttt{'+name+'}', color = ccode)
    ax.fill_between(outer_radius_list, q_dict[name][0], q_dict[name][2], alpha=0.3, color = ccode)
ax.legend(frameon=False)
# raw string: '\c' is not a valid Python escape sequence
ax.set_xlabel(r'Outer Radius Boundary $(^\circ)$')
ax.set_ylabel(r'$(\lambda - \lambda_{\rm true}) / \lambda_{\rm true}$')
# dashed reference lines at radius 40 and at zero deviation
ax.axvline(x = 40, color = 'black', linestyle = '--')
ax.axhline(y = 0, color = 'black', linestyle = '--')
ax.set_xlim(30, 70)

plt.tight_layout()
plt.savefig('figures/fig_outer_roi_scan.pdf',format='pdf', bbox_inches='tight')

# Figure A3

In [None]:
# load data (each file is closed as soon as unpickling finishes)
with open(ed_data_location+'../figures/data/appdata_roi_inner_radius_scan_upto40.p', 'rb') as pkl_file:
    inner_radius_list_1, q_dict_1 = pickle.load(pkl_file)
with open(ed_data_location+'../figures/data/appdata_roi_inner_radius_scan_upto70.p', 'rb') as pkl_file:
    inner_radius_list_2, q_dict_2 = pickle.load(pkl_file)

In [None]:
# Figure A3: the same line-plus-band scan for two data sets, one per row
# (top: the 'upto40' file, bottom: the 'upto70' file).
# NOTE(review): `all_temp_names` and `ccodes` come from the degeneracy-histogram cell above.
fig, axes = plt.subplots(figsize=(12, 8), dpi= 120, nrows = 2, ncols = 1)

inner_radius_lists = [inner_radius_list_1, inner_radius_list_2]
q_dicts = [q_dict_1, q_dict_2]
inner_radius_bdry = [40, 70]

for i in range(2):
    inner_radius_list = inner_radius_lists[i]
    q_dict = q_dicts[i]
    for name in list(q_dict.keys()):
        idx = all_temp_names.index(name)
        ccode = ccodes[idx] # colors chosen to match colored hist plots
        axes[i].plot(inner_radius_list, q_dict[name][1], label=r'\texttt{'+name+'}', color = ccode)
        axes[i].fill_between(inner_radius_list, q_dict[name][0], q_dict[name][2], alpha=0.3, color = ccode)

    # raw string: '\c' is not a valid Python escape sequence
    axes[i].set_xlabel(r'Outer Radius of Buffer $(^\circ)$')
    if i==1:
        # legend on the bottom panel only
        axes[i].legend(frameon=False,loc='upper center')
    axes[i].set_ylabel(r'$(\lambda - \lambda_{\rm true}) / \lambda_{\rm true}$')
    axes[i].axvline(x = 30, color = 'black', linestyle = '--')
    axes[i].axhline(y = 0, color = 'black', linestyle = '--')
    # the x range stops 10 short of the boundary value listed for this panel
    axes[i].set_xlim(20.25, inner_radius_bdry[i] - 10)


axes[0].text(0.7, 0.9, r'Inner Radius of Buffer $20^\circ$ \\ \\ Outer Radius of Outer ROI $40^\circ$', horizontalalignment = 'left', verticalalignment = 'center', transform = axes[0].transAxes, fontsize = 16)
axes[1].text(0.7, 0.9, r'Inner Radius of Buffer $20^\circ$ \\ \\ Outer Radius of Outer ROI $70^\circ$', horizontalalignment = 'left', verticalalignment = 'center', transform = axes[1].transAxes, fontsize = 16)

fig.tight_layout()
plt.savefig('figures/fig_buffer_scan.pdf',format='pdf', bbox_inches='tight')

# Figure A4

In [None]:
# load data (the file is closed as soon as unpickling finishes)
with open(ed_data_location+'../figures/data/app_bubbles.p', 'rb') as pkl_file:
    bubbles_1, bubbles_2, bubbles_3 = pickle.load(pkl_file)

In [None]:
from healpy.newvisufunc import projview

In [None]:
def bub_map(m, title, subplot = 111, display_x_info = True, display_y_info = True):
    """Draw one map as a Cartesian panel cropped to +/- 40 degrees on both axes.

    The map is handed to healpy's ``projview`` with ``hold = True`` and without a
    colorbar; afterwards the pyplot view limits are set to +/- 40 degrees
    (converted to radians) and the grid is switched off.

    Parameters
    ----------
    m :
        Map forwarded unchanged to ``projview``.
    title : str
        Panel title.
    subplot : int, optional
        Forwarded to ``projview`` as ``sub``.
    display_x_info, display_y_info : bool, optional
        When False, the corresponding axis label is omitted and its tick
        labels are replaced by empty strings.

    Returns
    -------
    None
    """
    # set custom tick labels
    # raw string: '\c' is not a valid Python escape sequence
    degree = r'$^\circ$'
    # The zero tick is written as '$0$' so it is typeset like its neighbours
    # (the float 0. used to be rendered as '0.0').
    # NOTE(review): the leading 'dummy' entry is kept as-is; presumably it belongs to a
    # tick outside the cropped view -- confirm.
    pre_xtick_labels = ['dummy', '$40$', '$20$', '$0$', '$-20$', '$-40$']
    xtick_labels = [str(i) + degree for i in pre_xtick_labels]
    pre_ytick_labels = ['dummy', '$-40$', '$-20$', '$0$', '$20$', '$40$']
    ytick_labels = [str(i) + degree for i in pre_ytick_labels]

    # generate map
    projview(
        m,
        coord=["G"],
        flip = "astro",
        projection_type="cart",
        title = title,
        xlabel = r'$\ell$' if display_x_info else None,
        ylabel = '$b$' if display_y_info else None,
        xsize = 2000,
        latitude_grid_spacing = 20,
        longitude_grid_spacing = 20,
        custom_xtick_labels=xtick_labels if display_x_info else ['' for i in pre_xtick_labels],
        custom_ytick_labels=ytick_labels if display_y_info else ['' for i in pre_ytick_labels],
        graticule = True,
        graticule_labels = True,
        unit='$\\lambda$',
        cb_orientation = 'horizontal',
        override_plot_properties = {'cbar_pad': 0.1},
        hold = True,
        sub = subplot,
        cbar = False,
        fontsize={'xlabel': 18, 'ylabel': 18},
        )

    # crop the view; the limits are given in radians
    plt.xlim(np.deg2rad(-40),np.deg2rad(40))
    plt.ylim(np.deg2rad(-40),np.deg2rad(40))

    plt.grid(False)

In [None]:
# Figure A4: three maps side by side, each drawn with bub_map.
fig, ax_array = plt.subplots(figsize = (12,6), nrows = 1, ncols = 3)
ax1, ax2, ax3 = ax_array[0], ax_array[1], ax_array[2]

# view limits (radians) re-applied after each panel is drawn
view_lo, view_hi = np.deg2rad(-40), np.deg2rad(40)

# left panel
plt.axes(ax1)
m = bubbles_2
bub_map(m, title = 'A17', subplot = 131)
plt.xlim(view_lo, view_hi)
plt.ylim(view_lo, view_hi)

# middle panel, y-axis annotations suppressed
plt.axes(ax2)
m = bubbles_1
bub_map(m, title = 'M19', subplot = 132, display_y_info = False)
plt.xlim(view_lo, view_hi)
plt.ylim(view_lo, view_hi)

# add the boundary of template: white dashed square with half-width 20 degrees
# (segments in the order bottom, top, left, right)
half = np.deg2rad(20)
for seg_x, seg_y in (
    ([-half, half], [-half, -half]),
    ([-half, half], [half, half]),
    ([-half, -half], [-half, half]),
    ([half, half], [-half, half]),
):
    plt.plot(seg_x, seg_y, 'w--', lw = 0.5)

# right panel, y-axis annotations suppressed
plt.axes(ax3)
m = bubbles_3
bub_map(m, title = 'Concatenation', subplot = 133, display_y_info = False)
plt.xlim(view_lo, view_hi)
plt.ylim(view_lo, view_hi)

# add the concatenation boundary: red dashed square with half-width 18 degrees
half = np.deg2rad(18)
for seg_x, seg_y in (
    ([-half, half], [-half, -half]),
    ([-half, half], [half, half]),
    ([-half, -half], [-half, half]),
    ([half, half], [-half, half]),
):
    plt.plot(seg_x, seg_y, 'r--', lw = 0.5)

plt.tight_layout()
plt.savefig('figures/fig_bubbles.pdf', format='pdf', bbox_inches='tight', dpi=300)

# Figure A5

In [None]:
# load data (the file is closed as soon as unpickling finishes)
with open(ed_data_location+'data/appdata_inducing.p', 'rb') as pkl_file:
    xu_f, xu_f_r, u_sample, f_sample = pickle.load(pkl_file)

In [None]:
# Regular grid covering +/- 20 degrees in both directions: xsize columns, ysize rows.
xsize = 400
ysize = xsize // 2

# angular coordinates in radians; theta is centred on pi/2 and decreases along the grid
theta = np.linspace(start = np.pi*(0.5+20/180), stop = np.pi*(0.5-20/180), num = ysize)
phi   = np.linspace(start = -np.pi/180*20, stop = np.pi/180*20, num = xsize)

# plotting coordinates in radians; longitude is stored in descending order
longitude = np.radians(np.linspace(-20, 20, xsize))[::-1]
latitude = np.radians(np.linspace(-20, 20, ysize))

# project the map to a rectangular matrix xsize x ysize
PHI, THETA = np.meshgrid(phi, theta, indexing = 'xy')

In [None]:
# Figure A5. Left: points xu_f_r colored by u_sample on a white disc of radius 20.
# Right: f_sample placed on the unmasked pixels and resampled onto the regular grid,
# with the same points overlaid as empty markers.
# NOTE(review): `l_list`, `b_list`, `nside` and `mask` are not defined in this cell; the grid
# arrays (THETA, PHI, longitude, latitude) come from the cell above.
fig, axes = plt.subplots(figsize=(8/0.95, 4), dpi= 120, nrows = 1, ncols = 2)


circle1 = plt.Circle((0, 0), 20, color='w',fill=True)

axes[0].set_facecolor('grey')
axes[0].add_patch(circle1)
# point coordinates are converted from radians to degrees
axes[0].scatter(180/np.pi*xu_f_r[:,0], 180/np.pi*xu_f_r[:,1], c = u_sample, edgecolors = 'k', s = 10, lw = 0.25, cmap = 'viridis')
axes[0].plot(20*l_list,20*b_list, color="k", ls = "-", lw = 2.)
# raw strings: '\e' and '\c' are not valid Python escape sequences
axes[0].set_xlabel(r'$\ell~(^\circ)$')
axes[0].set_ylabel(r'$b~(^\circ)$')
axes[0].set_xticks([-20,-10,0,10,20])

axes[0].set_ylim(-20,20)
axes[0].set_xlim(-20,20)

axes[0].xaxis.set_inverted(True)

# nudge the fifth x tick label to the right by 5 points
dx = 5/72.; dy = 0/72.
offset = mpl.transforms.ScaledTranslation(dx, dy, fig.dpi_scale_trans)
for label in axes[0].xaxis.get_majorticklabels()[4:5]:
    label.set_transform(label.get_transform() + offset)

# fill a full-sky map with f_sample on the unmasked pixels and look it up on the grid
grid_pix = hp.ang2pix(nside, THETA, PHI)
temp_map_0= np.zeros(hp.nside2npix(nside))
temp_map_0[~mask] = f_sample
grid_0 = temp_map_0[grid_pix]

im = axes[1].pcolormesh(-longitude*180/np.pi, latitude*180/np.pi,grid_0,cmap='viridis')

# empty markers: no `c` is given, so no colormap is passed (it would be ignored with a warning)
axes[1].scatter(180/np.pi*xu_f_r[:,0], 180/np.pi*xu_f_r[:,1], facecolors = 'none', edgecolors = 'k', s = 10, lw = 0.25)
axes[1].plot(20*l_list,20*b_list, color="k", ls = "-", lw = 2., zorder = 11)
axes[1].set_xlabel(r'$\ell~(^\circ)$')
axes[1].set_ylabel(r'$b~(^\circ)$')
axes[1].set_xticks([-20,-10,0,10,20])

# load boundary and fill shapes
_, _, inner_roi_x, inner_roi_y, outer_roi_low_x, outer_roi_low_y, outer_roi_high_x, outer_roi_high_y, outer_roi_lim_x, outer_roi_lim_y = eplt.preprocess_map_shapes(rad = False)

# closed outline: inner boundary followed by the reversed lower outer boundary
annulus_roi_x = [inner_roi_x, outer_roi_low_x[::-1]]
annulus_roi_y = [inner_roi_y, outer_roi_low_y[::-1]]

axes[1].fill(np.ravel(annulus_roi_x), np.ravel(annulus_roi_y), color = 'gray', zorder = 10)
axes[1].set_ylim(-20,20)
axes[1].set_xlim(-20,20)

axes[1].xaxis.set_inverted(True)


dx = 5/72.; dy = 0/72.
offset = mpl.transforms.ScaledTranslation(dx, dy, fig.dpi_scale_trans)
for label in axes[1].xaxis.get_majorticklabels()[4:5]:
    label.set_transform(label.get_transform() + offset)

# reserve room on the right for a colorbar in its own axes
plt.tight_layout()
fig.subplots_adjust(right=0.9)
cbar_ax = fig.add_axes([0.91, 0.20, 0.015, 0.72])
fig.colorbar(im, cax=cbar_ax,label=r'$\log_{10}({\bf f})$')

fig.savefig('figures/fig_inducing.png', format='png', bbox_inches='tight',dpi=300)

# Figure B1

In [None]:
# load data (the pickle file is closed as soon as unpickling finishes)
with open(ed_data_location+'data/appdata_syndata_samples.p', 'rb') as pkl_file:
    samples, _ = pickle.load(pkl_file)
# the .npy file stores a pickled dict in a 0-d object array; .item() unwraps it
temp_dict = np.load(ed_data_location+'data/appdata_syndata_temp_dict.npy', allow_pickle = True).item()

In [None]:
# parameter names in the order stored in `samples`; printed so the label list below can be checked against them
names = list(samples)
print(names)

In [None]:
# LaTeX axis labels passed to corner.corner
# NOTE(review): presumably one label per entry of `names`, in the same order -- confirm against the printed list
names_dressed = [
    r'$S_{\texttt{' + short_name + r'}}$'
    for short_name in ('blg', 'bub', 'ics', 'iso', 'nfw', 'pib', 'psc')
]

In [None]:
# collect the samples into a float array with one row per entry of `names`
n_rows = len(names)
n_cols = len(samples[names[0]])
template_sample_array = np.zeros((n_rows, n_cols))
for i, name in enumerate(names):
    template_sample_array[i, :] = samples[name]

In [None]:
# Corner plot of the sampled parameters with the reference values from `temp_dict`
# overlaid in red.
fig = corner.corner(template_sample_array.T, labels=names_dressed, quantiles=[0.16, 0.5, 0.84], show_titles=True, title_kwargs={"fontsize": 12})

# fig.axes is a flat list; reshape it so that axes[row, col] addresses one panel
N_var = len(names)
axes = np.array(fig.axes).reshape((N_var, N_var))

# Reference value for every parameter that has one. 'S_gp' has no entry of its own
# and is compared with S_nfw + S_blg instead.
truth_values = {}
for name in names:
    if name == 'S_gp':
        truth_values[name] = temp_dict['S_nfw'] + temp_dict['S_blg']
    elif name in temp_dict:
        truth_values[name] = temp_dict[name]

# diagonal panels: vertical line at the reference value
for i, name in enumerate(names):
    ax = axes[i,i]
    if name == 'S_gp':
        ax.axvline(truth_values[name], color='red', linestyle='--')
    if name in temp_dict:
        ax.axvline(temp_dict[name], color='red', linestyle='--')

# lower-triangle panels: cross-hair and marker at the reference point; a panel is
# skipped when either of its two parameters has no reference value
for yi in range(len(names)):
    for xi in range(yi):
        name_x = names[xi]
        name_y = names[yi]
        if name_x not in truth_values or name_y not in truth_values:
            continue
        ax = axes[yi,xi]
        value_x = truth_values[name_x]
        value_y = truth_values[name_y]
        ax.axvline(value_x, color='red', linestyle='--')
        ax.axhline(value_y, color='red', linestyle='--')
        ax.plot(value_x, value_y, "sr")

plt.savefig('figures/fig_syndata.pdf',format='pdf',bbox_inches='tight')

# Figure B2

In [None]:
# load data (the file is closed as soon as unpickling finishes)
with open(ed_data_location+'data/app_temp_extract.p', 'rb') as pkl_file:
    temp_dict, temp_samples = pickle.load(pkl_file)

In [None]:
# parameter names in the order stored in `temp_samples`; printed so the label list below can be checked against them
names = list(temp_samples)
print(names)

In [None]:
# LaTeX axis labels passed to corner.corner
# NOTE(review): presumably one label per entry of `names`, in the same order -- confirm against the printed list
names_dressed = [
    r'$S_{\texttt{' + short_name + r'}}$' for short_name in ('blg', 'nfw')
] + [r'$\gamma$']

In [None]:
# collect the samples into a float array with one row per entry of `names`
n_rows = len(names)
n_cols = len(temp_samples[names[0]])
template_sample_array = np.zeros((n_rows, n_cols))
for i, name in enumerate(names):
    template_sample_array[i, :] = temp_samples[name]

In [None]:
# Corner plot of the sampled parameters with the reference values from `temp_dict`
# overlaid in red.
fig = corner.corner(template_sample_array.T, labels=names_dressed, quantiles=[0.16, 0.5, 0.84], show_titles=True, title_kwargs={"fontsize": 12})

# fig.axes is a flat list; reshape it so that axes[row, col] addresses one panel
N_var = len(names)
axes = np.array(fig.axes).reshape((N_var, N_var))

# Reference value for every parameter that has one. 'S_gp' has no entry of its own
# and is compared with S_nfw + S_blg instead.
truth_values = {}
for name in names:
    if name == 'S_gp':
        truth_values[name] = temp_dict['S_nfw'] + temp_dict['S_blg']
    elif name in temp_dict:
        truth_values[name] = temp_dict[name]

# diagonal panels: vertical line at the reference value
for i, name in enumerate(names):
    ax = axes[i,i]
    if name == 'S_gp':
        ax.axvline(truth_values[name], color='red', linestyle='--')
    if name in temp_dict:
        ax.axvline(temp_dict[name], color='red', linestyle='--')

# lower-triangle panels: cross-hair and marker at the reference point; a panel is
# skipped when either of its two parameters has no reference value
for yi in range(len(names)):
    for xi in range(yi):
        name_x = names[xi]
        name_y = names[yi]
        if name_x not in truth_values or name_y not in truth_values:
            continue
        ax = axes[yi,xi]
        value_x = truth_values[name_x]
        value_y = truth_values[name_y]
        ax.axvline(value_x, color='red', linestyle='--')
        ax.axhline(value_y, color='red', linestyle='--')
        ax.plot(value_x, value_y, "sr")

plt.savefig('figures/fig_temp_extract.pdf',format='pdf',bbox_inches='tight')

# Figure B3

In [None]:
# load data (the file is closed as soon as unpickling finishes)
with open(ed_data_location+'data/app_poiss_summary.p', 'rb') as pkl_file:
    temp_dict, corner_samples, S_roc_runs, gp_roc_runs = pickle.load(pkl_file)
# number of points on the 'Fraction of Posteriors' axis of the coverage plots below
n_actual_run = 100

In [None]:
# parameter names in the order stored in `corner_samples`; printed so the label list below can be checked against them
names = list(corner_samples)
print(names)

In [None]:
# LaTeX axis labels passed to corner.corner and reused as legend labels in the coverage plot
# NOTE(review): presumably one label per entry of `names`, in the same order -- confirm against the printed list
names_dressed = [
    r'$S_{\texttt{' + short_name + r'}}$'
    for short_name in ('bub', 'ics', 'iso', 'pib', 'psc', 'gp')
]

In [None]:
# collect the samples into a float array with one row per entry of `names`
n_rows = len(names)
n_cols = len(corner_samples[names[0]])
template_sample_array = np.zeros((n_rows, n_cols))
for i, name in enumerate(names):
    template_sample_array[i, :] = corner_samples[name]

In [None]:
# Corner plot of the sampled parameters with the reference values from `temp_dict`
# overlaid in red.
fig = corner.corner(template_sample_array.T, labels=names_dressed, quantiles=[0.16, 0.5, 0.84], show_titles=True, title_kwargs={"fontsize": 12})

# fig.axes is a flat list; reshape it so that axes[row, col] addresses one panel
N_var = len(names)
axes = np.array(fig.axes).reshape((N_var, N_var))

# Reference value for every parameter that has one. 'S_gp' has no entry of its own
# and is compared with S_nfw + S_blg instead.
truth_values = {}
for name in names:
    if name == 'S_gp':
        truth_values[name] = temp_dict['S_nfw'] + temp_dict['S_blg']
    elif name in temp_dict:
        truth_values[name] = temp_dict[name]

# diagonal panels: vertical line at the reference value
for i, name in enumerate(names):
    ax = axes[i,i]
    if name == 'S_gp':
        ax.axvline(truth_values[name], color='red', linestyle='--')
    if name in temp_dict:
        ax.axvline(temp_dict[name], color='red', linestyle='--')

# lower-triangle panels: cross-hair and marker at the reference point; a panel is
# skipped when either of its two parameters has no reference value
for yi in range(len(names)):
    for xi in range(yi):
        name_x = names[xi]
        name_y = names[yi]
        if name_x not in truth_values or name_y not in truth_values:
            continue
        ax = axes[yi,xi]
        value_x = truth_values[name_x]
        value_y = truth_values[name_y]
        ax.axvline(value_x, color='red', linestyle='--')
        ax.axhline(value_y, color='red', linestyle='--')
        ax.plot(value_x, value_y, "sr")

plt.savefig('figures/fig_poiss_summary_corner.pdf',format='pdf',bbox_inches='tight')

In [None]:
# Coverage plots. Left: one curve per entry of S_roc_runs whose key is in `names`.
# Right: percentile bands and median of gp_roc_runs along its first axis.
fig, axes = plt.subplots(figsize=(8, 4), dpi= 120, nrows = 1, ncols = 2)

view_keys = list(corner_samples.keys())
linestyles = ['-', '--', ':', '-.']

# common y grid for every curve ('Fraction of Posteriors')
run_fraction = np.linspace(0, 1, n_actual_run)
# tick positions that the hand-written x tick labels below refer to
coverage_ticks = [0, 0.25, 0.5, 0.75, 1]

axes[0].fill_between([0,1], [0,1], color='lightgray')
for i, (k, roc) in enumerate(S_roc_runs.items()):
    if k in names:
        # look the label up by the parameter's position in `names`, so that it stays
        # correct even if S_roc_runs has extra keys or a different key order
        label = names_dressed[names.index(k)]
        axes[0].plot(roc, run_fraction, label=label, color=f'C{i%10}', linestyle=linestyles[i//10])

axes[0].set(aspect=1)
axes[0].set_xlabel('Coverage of HDI Needed To Include Truth',fontsize=13)
axes[0].set(ylabel='Fraction of Posteriors')
axes[0].set( xlim = (0,1), ylim = (0,1))
axes[0].text(0.95, 0.05, 'overconfident', ha='right', va='center')
axes[0].text(0.05, 0.925, 'underconfident', ha='left', va='center')
axes[0].legend(frameon=False, loc='center left',
           ncol=1, fontsize = 10)

# rows of q_roc: 2.5, 16, 50, 84 and 97.5 percentiles
q_roc = np.percentile(gp_roc_runs, [2.5,16,50,84,97.5], axis = 0)
axes[1].fill_between([0,1], [0,1], color='lightgray')
axes[1].fill_betweenx(run_fraction, q_roc[1], q_roc[3], color='b', alpha = 0.2)
axes[1].fill_betweenx(run_fraction, q_roc[0], q_roc[4], color='b', alpha = 0.2)
axes[1].plot(q_roc[2], run_fraction, color='b', ls = '-', alpha = 1,)
# diagonal reference line
axes[1].plot(run_fraction, run_fraction, color='k', ls = '--', alpha = 1)

axes[1].set(aspect=1)
axes[1].set_xlabel('Coverage of HDI Needed To Include Truth',fontsize=13)
axes[1].set_ylabel('Fraction of Posteriors')
axes[1].set(xlim = (0,1), ylim = (0,1))
axes[1].text(0.95, 0.05, 'overconfident', ha='right', va='center')
axes[1].text(0.05, 0.925, 'underconfident', ha='left', va='center')

# pin the tick positions before assigning fixed label strings; otherwise the strings are
# attached to whatever ticks the automatic locator happens to choose
for panel in axes:
    panel.set_xticks(coverage_ticks)
    panel.set_xticklabels(['$0$','$0.25$','$0.5$','$0.75$','$1$'])

plt.tight_layout()
fig.savefig('figures/fig_poiss_summary_data.pdf', format='pdf',bbox_inches='tight')

# Figure B4

In [None]:
# load data (the file is closed as soon as unpickling finishes)
with open(ed_data_location+'data/app_poiss_summary_gp.p', 'rb') as pkl_file:
    p_tr, temp_samples_tot_list_r, keys, roc_runs = pickle.load(pkl_file)

In [None]:
# per-parameter views of the loaded arrays: column i of the sample array and entry i of
# p_tr both belong to keys[i]
corner_samples = dict((k, temp_samples_tot_list_r[:, i]) for i, k in enumerate(keys))
temp_dict = dict((k, p_tr[i]) for i, k in enumerate(keys))

In [None]:
# parameter names in the order stored in `temp_dict`; printed so the label list below can be checked against them
names = list(temp_dict)
print(names)

In [None]:
# LaTeX axis labels passed to corner.corner and reused as legend labels in the coverage plot
# NOTE(review): presumably one label per entry of `names`, in the same order -- confirm against the printed list
names_dressed = [
    r'$S_{\texttt{' + short_name + r'}}$' for short_name in ('blg', 'nfw')
] + [r'$\gamma$']

In [None]:
# collect the samples into a float array with one row per entry of `names`
n_rows = len(names)
n_cols = len(corner_samples[names[0]])
template_sample_array = np.zeros((n_rows, n_cols))
for i, name in enumerate(names):
    template_sample_array[i, :] = corner_samples[name]

In [None]:
# Corner plot of the sampled parameters with the reference values from `temp_dict`
# overlaid in red.
fig = corner.corner(template_sample_array.T, labels=names_dressed, quantiles=[0.16, 0.5, 0.84], show_titles=True, title_kwargs={"fontsize": 12})

# fig.axes is a flat list; reshape it so that axes[row, col] addresses one panel
N_var = len(names)
axes = np.array(fig.axes).reshape((N_var, N_var))

# Reference value for every parameter that has one. 'S_gp' has no entry of its own
# and is compared with S_nfw + S_blg instead.
truth_values = {}
for name in names:
    if name == 'S_gp':
        truth_values[name] = temp_dict['S_nfw'] + temp_dict['S_blg']
    elif name in temp_dict:
        truth_values[name] = temp_dict[name]

# diagonal panels: vertical line at the reference value
for i, name in enumerate(names):
    ax = axes[i,i]
    if name == 'S_gp':
        ax.axvline(truth_values[name], color='red', linestyle='--')
    if name in temp_dict:
        ax.axvline(temp_dict[name], color='red', linestyle='--')

# lower-triangle panels: cross-hair and marker at the reference point; a panel is
# skipped when either of its two parameters has no reference value
for yi in range(len(names)):
    for xi in range(yi):
        name_x = names[xi]
        name_y = names[yi]
        if name_x not in truth_values or name_y not in truth_values:
            continue
        ax = axes[yi,xi]
        value_x = truth_values[name_x]
        value_y = truth_values[name_y]
        ax.axvline(value_x, color='red', linestyle='--')
        ax.axhline(value_y, color='red', linestyle='--')
        ax.plot(value_x, value_y, "sr")

plt.savefig('figures/fig_poiss_summary_gp_corner.pdf',format='pdf',bbox_inches='tight')

In [None]:
# Coverage plot: one curve per entry of roc_runs whose key is in `keys`.
# NOTE(review): `n_actual_run` is assigned in the data cell of the previous figure,
# which must have been run first.
fig, axes = plt.subplots(figsize=(4, 4), dpi= 120)


view_keys = keys
linestyles = ['-', '--', ':', '-.']

# `names_dressed` is ordered like `keys`, so labels are looked up by position in `keys`
key_order = list(view_keys)
# common y grid for every curve ('Fraction of Posteriors')
run_fraction = np.linspace(0, 1, n_actual_run)

axes.fill_between([0,1], [0,1], color='lightgray')
for i, (k, roc) in enumerate(roc_runs.items()):
    if k in view_keys:
        # stays correct even if roc_runs has extra keys or a different key order
        label = names_dressed[key_order.index(k)]
        axes.plot(roc, run_fraction, label=label, color=f'C{i%10}', linestyle=linestyles[i//10])

axes.set(aspect=1)
axes.set_xlabel('Coverage of HDI Needed To Include Truth', fontsize=13)
axes.set_ylabel('Fraction of Posteriors',fontsize=13)
axes.set(xlim = (0,1), ylim = (0,1))
axes.text(0.95, 0.05, 'overconfident', ha='right', va='center')
axes.text(0.05, 0.95, 'underconfident', ha='left', va='center')
axes.legend(bbox_to_anchor=(0, 0.75), loc='upper left', bbox_transform=axes.transAxes, frameon=False,)
# pin the tick positions before assigning fixed label strings; otherwise the strings are
# attached to whatever ticks the automatic locator happens to choose
axes.set_xticks([0, 0.25, 0.5, 0.75, 1])
axes.set_xticklabels(['$0$','$0.25$','$0.5$','$0.75$','$1$'])

plt.tight_layout()
fig.savefig('figures/fig_poiss_summary_gp_data.pdf', format='pdf',bbox_inches='tight')

# Figure B5

In [None]:
# load data (the file is closed as soon as unpickling finishes)
with open(ed_data_location+'data/app_kernels.p', 'rb') as pkl_file:
    mask_map_cart, gp_true, exp_gp_samples_cart_list = pickle.load(pkl_file)

In [None]:
# Figure B5: one 1D slice panel per entry of exp_gp_samples_cart_list, titled with a kernel name.
# NOTE(review): `mask` and `n_pixels` are not defined in this cell.
fig, axes = plt.subplots(figsize=(16, 4), dpi= 120, nrows = 1, ncols = 4)

slice_val = 3.2

# simulated rate cartesian map; no raw-data overlay
sim_cart = ef.healpix_to_cart(gp_true, mask, n_pixels = n_pixels, nside = 128)
raw_cart = None

titles = ['Matern32', 'Matern52', 'ExpSquared', 'RationalQuadratic']

# 2.5/16/50/84/97.5 percentiles along the first axis of each sample stack
q_list = [
    np.percentile(kernel_samples, [2.5,16,50,84,97.5], axis = 0)
    for kernel_samples in exp_gp_samples_cart_list
]

for i, q in enumerate(q_list):
    panel = axes[i]
    # cart_plot_1d is called right after making this panel the current pyplot axes
    plt.axes(panel)
    eplt.cart_plot_1d(
        q,
        sim_cart = sim_cart,
        raw_cart = raw_cart,
        slice_val = slice_val,
        mask_map_cart = mask_map_cart,
        n_pixels = n_pixels,
        ylabel = r'$\lambda_{f}$',
        q_color = 'C' + str(i),
        line_color = 'black',
    )

    panel.set_title(titles[i])
    panel.set_ylim(-0.25, 27.5)

    # relabel the x ticks and nudge the first label to the right by 5 points
    panel.set_xticklabels(['$20$','$10$','$0$','$-10$','$-20$'])
    dx = 5/72.; dy = 0/72.
    offset = mpl.transforms.ScaledTranslation(dx, dy, fig.dpi_scale_trans)
    for label in panel.xaxis.get_majorticklabels()[:1]:
        label.set_transform(label.get_transform() + offset)

plt.tight_layout()
fig.savefig('figures/fig_kernels.pdf', format='pdf', bbox_inches='tight')

# Figure B6

In [None]:
# load data (the file is closed as soon as unpickling finishes)
with open(ed_data_location+'data/app_nu_scan.p', 'rb') as pkl_file:
    Nu_list, q_list, ordered_names = pickle.load(pkl_file)

In [None]:
# Figure B6: one line plus shaded band per template as a function of N_u.
# NOTE(review): `all_temp_names` and `ccodes` are defined in the degeneracy-histogram cell.
fig, ax = plt.subplots(figsize=(12,5), dpi= 120)
Ns = len(Nu_list)
for k, name in enumerate(ordered_names):
    temp_name = r'$\texttt{'+name+'}$'
    idx = all_temp_names.index(name)
    ccode = ccodes[idx]
    # entry 1 of each triple is drawn as the line, entries 0 and 2 bound the band
    low_list, mean_list, high_list = [], [], []
    for n in range(Ns):
        triple = q_list[n][name]
        low_list.append(triple[0])
        mean_list.append(triple[1])
        high_list.append(triple[2])
    ax.plot(Nu_list, mean_list, label = temp_name, color = ccode)
    ax.fill_between(Nu_list, low_list, high_list, alpha = 0.2, color = ccode)

ax.legend(frameon=False, ncol=2, loc='upper left')
ax.axhline(0, linestyle = '--', color = 'k')
ax.set_xlabel('$N_u$')
ax.set_ylabel(r'$(\lambda - \lambda_{\rm true}) / \lambda_{\rm true}$')
ax.set_xlim(min(Nu_list), max(Nu_list))
ax.set_ylim(-0.2, 0.25)

plt.tight_layout()
plt.savefig('figures/fig_nu_scan.pdf', format='pdf', bbox_inches='tight')

# Figure B7

In [None]:
# load data
# Each pickle holds a 3-tuple (scan values, per-scan quantile dicts, ordered_names).
# ordered_names from the second file overwrites the one from the first.
# Context managers close the file handles as soon as the payloads are read.
with open(ed_data_location+'data/app_gamma_scan.p', 'rb') as pkl_file:
    syn_gamma_list, q_gamma_list, ordered_names = pickle.load(pkl_file)
with open(ed_data_location+'data/app_blg_scan.p', 'rb') as pkl_file:
    syn_blg_id_list, q_blg_list, ordered_names = pickle.load(pkl_file)

In [None]:
# Figure B7: two stacked panels showing, per template, the central curve and a
# shaded band of (lambda - lambda_true) / lambda_true.
#   top    : scan over syn_gamma_list
#   bottom : scan over syn_blg_id_list
fig, axes = plt.subplots(figsize=(12,8), dpi= 120, nrows=2,ncols=1)

# (panel, x values, per-scan quantile dicts) - both panels share the drawing code
scan_panels = [
    (axes[0], syn_gamma_list, q_gamma_list),
    (axes[1], syn_blg_id_list, q_blg_list),
]
for panel, x_vals, q_scan in scan_panels:
    Ns = len(x_vals)
    for k in range(len(ordered_names)):
        name = ordered_names[k]
        temp_name = r'$\texttt{'+name+'}$'
        idx = all_temp_names.index(name)
        ccode = ccodes[idx]
        # q_scan[n][name] is read at indices 0 / 1 / 2 -> lower edge, curve, upper edge
        low_list = [q_scan[n][name][0] for n in range(Ns)]
        mean_list = [q_scan[n][name][1] for n in range(Ns)]
        high_list = [q_scan[n][name][2] for n in range(Ns)]
        panel.plot(x_vals, mean_list, label = temp_name, color = ccode)
        panel.fill_between(x_vals, low_list, high_list, alpha = 0.2, color = ccode)
    panel.set_ylabel(r'$(\lambda - \lambda_{\rm true}) / \lambda_{\rm true}$')
    panel.axhline(0, linestyle = '--', color = 'k')
    panel.set_xlim(min(x_vals), max(x_vals))

# Top panel only: legend, x label and fixed y range.
axes[0].legend(frameon=False,ncol=3,loc='upper left')
# raw string: '\g' is not a valid escape sequence in a normal string literal
axes[0].set_xlabel(r'$\gamma$')
axes[0].set_ylim(-0.22,0.25)

# Bottom panel only: one tick per scanned bulge id, labelled with its generated name.
blg_names = [r'$\texttt{'+ef.gen_blg_name_(blg_id)[0]+'}$' for blg_id in syn_blg_id_list]
axes[1].set_xticks(syn_blg_id_list)
axes[1].set_xticklabels(blg_names, rotation = 0)

plt.tight_layout()
plt.savefig('figures/fig_gce_scan.pdf',format='pdf', bbox_inches='tight')

# Figure B8

In [None]:
# load data
# The pickle holds a 7-tuple, unpacked in the order shown below.
# The context manager closes the file handle as soon as the payload is read.
with open(ed_data_location+'data/app_zero_gp.p', 'rb') as pkl_file:
    temp_dict, temp_sample_dict, temp_sample_dict_cmask, exp_gp_samples_cart, gp_true, sim_cart, mask_map_cart = pickle.load(pkl_file)

In [None]:
# Figure B8 (first part): three panels
#   [0] histograms of the per-sample total counts of each fitted template,
#       with dashed vertical lines at the simulated totals
#   [1] horizontal 1D slice through the GP quantile maps
#   [2] vertical 1D slice through the GP quantile maps
fig, axes = plt.subplots(figsize=(12, 4), dpi= 120, nrows = 1, ncols = 3)

# Set locally (same value as in the other slice cells) so this cell does not
# rely on a variable left over from a previously executed cell.
slice_val = 3.2

all_temp_names = ['iso', 'psc', 'bub', 'pib', 'ics', 'blg', 'gp', 'nfw', 'dsk']
ccodes = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C5', 'C6', 'C7']
names = list(temp_sample_dict.keys())

# simulated template maps restricted to the pixels where mask_p is False
temp_sim_dict = {name: temp_dict[name][~mask_p] for name in temp_dict.keys() if name in all_temp_names}

bins = np.logspace(2.,5.,150)

# --- panel 0: fitted templates, drawn in the canonical all_temp_names order ---
ordered_names = [name for name in all_temp_names if name in names]
for name in ordered_names:
    ccode = ccodes[all_temp_names.index(name)]
    temp_sum = temp_sample_dict[name].sum(axis = 1) # sum over spatial bins
    # NOTE(review): density=False plots raw sample counts although the y label
    # below reads "Density" - confirm which one is intended.
    axes[0].hist(temp_sum, bins = bins, alpha = 1, label = r'\texttt{'+name+'}', density = False, histtype = 'step', color = ccode)

names_sim = temp_names_sim # this piece is provided by the "settings" file since we only save a dictionary with all the fit parameters
ordered_names_sim = [name for name in all_temp_names if name in names_sim]

# simulated totals: one dashed line per simulated template, except 'blg'
for name in ordered_names_sim:
    ccode = ccodes[all_temp_names.index(name)]
    if name != 'blg':
        temp_sum_sim = temp_sim_dict[name].sum(axis = 0)
        axes[0].axvline(temp_sum_sim, linestyle='--', c = ccode)

# replace the step-histogram patches by plain lines in the legend
handles, labels = axes[0].get_legend_handles_labels()
new_handles = [Line2D([], [], c=h.get_edgecolor()) for h in handles]

axes[0].legend(handles=new_handles, labels=labels,frameon=False,loc=2)
axes[0].set_xscale('log')
axes[0].set_xlabel(r'$\mathrm{Counts}$')
axes[0].set_ylabel(r'$\mathrm{Density}$')

# --- panels 1 and 2: 1d slices of GCE ---
q = np.percentile(exp_gp_samples_cart, [2.5,16,50,84,97.5], axis = 0) # cartesian sample map quantiles
raw_cart = None

# plt.sca selects an existing Axes as current. The former
# plt.axes(ax, aspect='equal') call is not a documented call form.
# NOTE(review): the aspect keyword is dropped on the understanding that it was
# ignored for an existing Axes - confirm against the installed Matplotlib.
# NOTE(review): presumably cart_plot_1d draws on the current Axes - confirm.
plt.sca(axes[1])
eplt.cart_plot_1d(q, sim_cart = sim_cart, raw_cart = raw_cart,
    slice_dir = 'horizontal', slice_val = slice_val,
    mask_map_cart = mask_map_cart,
    n_pixels = n_pixels,
    ylabel = r'$\lambda_{f}$', q_color = 'darkorange', line_color = 'green')

# NOTE(review): assumes cart_plot_1d leaves exactly five major x ticks - confirm.
axes[1].set_xticklabels(['$20$','$10$','$0$','$-10$','$-20$'])

# shift the first x tick label 5 pt to the right (dx is in inches)
dx = 5/72.; dy = 0/72.
offset = mpl.transforms.ScaledTranslation(dx, dy, fig.dpi_scale_trans)
for label in axes[1].xaxis.get_majorticklabels()[0:1]:
    label.set_transform(label.get_transform() + offset)

plt.sca(axes[2])
eplt.cart_plot_1d(q, sim_cart = sim_cart, raw_cart = raw_cart,
    slice_dir = 'vertical', slice_val = slice_val,
    mask_map_cart = mask_map_cart,
    n_pixels = n_pixels,
    ylabel = r'$\lambda_{f}$', q_color = 'darkorange', line_color = 'green')


plt.tight_layout()
fig.savefig('figures/fig_zero_gp.pdf', format='pdf',bbox_inches='tight')

In [None]:
# Figure B8 (second part): two log10 maps of the 'gp' component, one per panel,
# with the region outline and the greyed-out annulus drawn on top.
fig, axes = plt.subplots(figsize=(8/0.9, 4), dpi= 120, nrows = 1, ncols = 2)

# load boundary and fill shapes
_, _, inner_roi_x, inner_roi_y, outer_roi_low_x, outer_roi_low_y, outer_roi_high_x, outer_roi_high_y, outer_roi_lim_x, outer_roi_lim_y = eplt.preprocess_map_shapes(rad = False)

annulus_roi_x = [inner_roi_x, outer_roi_low_x[::-1]]
annulus_roi_y = [inner_roi_y, outer_roi_low_y[::-1]]

# The HEALPix pixel index of every grid point does not change between panels.
grid_pix = hp.ang2pix(nside, THETA, PHI)

# shift of 5 pt to the right (dx is in inches) for the fifth x tick label
dx = 5/72.; dy = 0/72.
offset = mpl.transforms.ScaledTranslation(dx, dy, fig.dpi_scale_trans)

for i in range(2):

    # NOTE(review): entry i along the first axis of the 'gp' array - presumably
    # posterior sample i; confirm.
    q = temp_sample_dict_cmask['gp'][i]

    # scatter the values into a full-sky map (zero where mask is True), then
    # sample that map on the plotting grid
    temp_map_0= np.zeros(hp.nside2npix(nside))
    temp_map_0[~mask] = q
    grid_0 = temp_map_0[grid_pix]

    im = axes[i].pcolormesh(-longitude*180/np.pi, latitude*180/np.pi,np.log10(grid_0),cmap='viridis',
                  vmin=-2.0,vmax=1.)
    axes[i].plot(20*l_list,20*b_list, color="k", ls = "-", lw = 2., zorder = 11)
    # raw strings: '\e' and '\c' are not valid escape sequences in normal string literals
    axes[i].set_xlabel(r'$\ell~(^\circ)$')
    axes[i].set_ylabel(r'$b~(^\circ)$')
    axes[i].set_ylim(-20,20)
    axes[i].set_xlim(-20,20)
    axes[i].set_xticks([-20,-10,0,10,20])
    axes[i].xaxis.set_inverted(True)

    axes[i].fill(np.ravel(annulus_roi_x), np.ravel(annulus_roi_y), color = 'gray', zorder = 10)

    for label in axes[i].xaxis.get_majorticklabels()[4:5]:
        label.set_transform(label.get_transform() + offset)

plt.tight_layout()
# reserve the right-hand 10% of the figure for a shared colorbar (uses the last mesh)
fig.subplots_adjust(right=0.9)
cbar_ax = fig.add_axes([0.91, 0.2, 0.017, 0.72])
fig.colorbar(im, cax=cbar_ax,label=r'$\log_{10}(\lambda)$')
fig.savefig('figures/fig_zero_gp_map.png',format='png',bbox_inches='tight',dpi=300)

# Figure B9

In [None]:
# load data
# The pickle holds a 2-tuple (temp_dict, corner_samples_list).
# The context manager closes the file handle as soon as the payload is read.
with open(ed_data_location+'data/app_mask_res.p', 'rb') as pkl_file:
    temp_dict, corner_samples_list = pickle.load(pkl_file)

In [None]:
# LaTeX axis labels of the form $S_{\texttt{<tag>}}$, one per template tag, in this fixed order.
names_dressed = [
    r'$S_{\texttt{' + tag + r'}}$'
    for tag in ('bub', 'ics', 'iso', 'pib', 'psc', 'gp')
]

In [None]:
# Figure B9: corner plots for the four mask fractions overlaid on one figure,
# with the simulated ("true") values marked in red.
mask_fractions = ['68', '90', '95', '99']
labels = names_dressed

# corner.corner creates the figure on the first call (fig=None) and draws into
# the same figure on every later call.
fig = None
for n in range(len(mask_fractions)):
    corner_samples = corner_samples_list[n]
    names = list(corner_samples.keys())

    # rows = parameters, columns = samples; corner.corner receives the transpose
    template_sample_array = np.zeros((len(names), len(corner_samples['S_gp'])))
    for i in range(len(names)):
        template_sample_array[i] = corner_samples[names[i]]

    color = 'C' + str(n)
    fig = corner.corner(template_sample_array.T, labels=labels, show_titles=False, fig=fig, density = True,
                        color = color)
    # The loaded corner_samples_list is deliberately left unmodified: appending
    # to it here would grow the list every time this cell is executed.

N_var = len(names)
axes = np.array(fig.axes).reshape((N_var, N_var))

# Reference value per parameter. 'S_gp' is compared against S_nfw + S_blg;
# every other parameter uses its own entry in temp_dict when one exists.
# Parameters without a reference value get no marker.
truths = {name: temp_dict[name] for name in names if name in temp_dict}
truths['S_gp'] = temp_dict['S_nfw'] + temp_dict['S_blg']

# diagonal panels: vertical line at the reference value
for i, name in enumerate(names):
    if name in truths:
        axes[i,i].axvline(truths[name], color='red', linestyle='--')

# lower-triangle panels: cross-hair plus square marker at the reference point
for yi in range(len(names)):
    for xi in range(yi):
        name_x = names[xi]
        name_y = names[yi]
        if name_x not in truths or name_y not in truths:
            continue
        value_x = truths[name_x]
        value_y = truths[name_y]

        panel = axes[yi,xi]
        panel.axvline(value_x, color='red', linestyle='--')
        panel.axhline(value_y, color='red', linestyle='--')
        panel.plot(value_x, value_y, "sr")

# raw string: '\%' is not a valid escape sequence in a normal string literal
lines = [mpl.lines.Line2D([], [], color='C' + str(n), label=mask_fractions[n] + r'\%') for n in range(len(mask_fractions))]
plt.legend(handles=lines, bbox_to_anchor=(0., 1.0, 1., .0), loc=4,frameon=False)

fig.savefig('figures/fig_mask_res.png',format='png',bbox_inches='tight',dpi=300)

# Figure B10

In [None]:
# load data (~7 GB ; ~6.5 min)
# Context managers close the file handles as soon as the payloads are read.
with open(ed_data_location+'../diffuse_model_tests/plotting_data/multi/all_mismodelling_data.p', 'rb') as pkl_file:
    seq = pickle.load(pkl_file)
temp_sample_dict_list, temp_sample_dict_cmask_list, exp_gp_samples_cart_list, gp_true_cart_list, tot_samples_cart_list, model_residuals_cart_list, mask, mask_p, mask_map_cart = seq

# load synthetic data
with open(ed_data_location+'../diffuse_model_tests/plotting_data/multi/temp_dict_list.p', 'rb') as pkl_file:
    temp_dict_list, sim_cart_list = pickle.load(pkl_file)

# cartesian map of masks to keep track of masking for plots
# (this replaces the mask_map_cart that was unpacked from the pickle above)
mask_map = np.zeros((~mask_p).sum())
mask_map_cart = ef.healpix_to_cart(mask_map, mask_p, n_pixels = n_pixels, nside = 128, nan_fill = True) # doesn't matter what mask used

In [None]:
# plot total counts histogram for each fit
# row: different diffuse model ; column: different synthetic dataset

nrows = 3 ; ncols = 3
fig, axes = plt.subplots(figsize=(6*3, 6*3), dpi= 120, nrows = nrows, ncols = ncols)

# Settings that are identical for all nine panels.
all_temp_names = ['iso', 'psc', 'bub', 'pib', 'ics', 'blg', 'gp', 'nfw', 'dsk']
ccodes = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C5', 'C6', 'C7']
bins = np.logspace(2.,5.,150)

names_sim = temp_names_sim # this piece is provided by the "settings" file since we only save a dictionary with all the fit parameters
ordered_names_sim = [name for name in all_temp_names if name in names_sim]

for i in range(nrows):
    for j in range(ncols):
        # the two digits of mod_id are (row + 1, column + 1): 11, 12, ..., 33
        mod_id = 10 * (i + 1) + (j + 1)
        temp_sample_dict = temp_sample_dict_list[mod_id]
        temp_dict = temp_dict_list[mod_id]

        names = list(temp_sample_dict.keys())

        # simulated template maps restricted to the pixels where mask_p is False
        temp_sim_dict = {name: temp_dict[name][~mask_p] for name in temp_dict.keys() if name in all_temp_names}

        # fitted templates, drawn in the canonical all_temp_names order
        ordered_names = [name for name in all_temp_names if name in names]
        for name in ordered_names:
            ccode = ccodes[all_temp_names.index(name)]
            temp_sum = temp_sample_dict[name].sum(axis = 1) # sum over spatial bins
            # NOTE(review): density=False plots raw sample counts although the
            # y label reads "Density" - confirm which one is intended.
            axes[i,j].hist(temp_sum, bins = bins, alpha = 1, label = r'\texttt{'+name+'}', density = False, histtype = 'step', color = ccode)

        # simulated totals: one dashed line per simulated template, except 'blg'
        for name in ordered_names_sim:
            ccode = ccodes[all_temp_names.index(name)]
            if name != 'blg':
                temp_sum_sim = temp_sim_dict[name].sum(axis = 0)
                axes[i,j].axvline(temp_sum_sim, linestyle='--', c = ccode)

        axes[i,j].set_xscale('log')
        axes[i,j].set_xlabel(r'$\mathrm{Counts}$')
        axes[i,j].set_ylabel(r'$\mathrm{Density}$')

# The legend sits on the top-left panel, so build it from that panel's own
# histograms (not from whichever panel happened to be drawn last).
handles, labels = axes[0,0].get_legend_handles_labels()
new_handles = [Line2D([], [], c=h.get_edgecolor()) for h in handles]
axes[0,0].legend(handles=new_handles, labels=labels,frameon=False,loc=2)

cols = [r'{\bf O}', r'{\bf A}', r'{\bf F}']
rows = [r'{\bf O}', r'{\bf A}', r'{\bf F}']

# column headers on the top row
for c, ax in zip(cols, axes[0]):
    ax.set_title(c, fontsize=20)

# row headers: an invisible twin axis to the left of the first column carries the label
for r, ax in zip(rows, axes[:, 0]):
    ax2 = ax.twinx()
    # move extra axis to the left, with offset
    ax2.yaxis.set_label_position('left')
    ax2.spines['left'].set_position(('axes', -0.2))
    # hide spine and ticks, set group label
    ax2.spines['left'].set_visible(False)
    ax2.set_yticks([])
    ax2.set_ylabel(r, rotation=0,  fontsize=20,
                   ha='right', va='center')

fig.savefig('figures/fig_dif_norms_1.pdf',format='pdf', bbox_inches='tight')

In [None]:
# 3x3 grid of horizontal 1D slices through the GP quantile maps.
# cols / rows (panel headers) are defined in the fig_dif_norms_1 cell.
slice_val = 3.2  # y-value of slice

nrows = 3 ; ncols = 3
fig, axes = plt.subplots(figsize=(6*3, 6*3), dpi= 120, nrows = nrows, ncols = ncols)

# shift of 5 pt to the right (dx is in inches) for the first x tick label
dx = 5/72.; dy = 0/72.
offset = mpl.transforms.ScaledTranslation(dx, dy, fig.dpi_scale_trans)

for i in range(nrows):
    for j in range(ncols):
        # the two digits of mod_id are (row + 1, column + 1): 11, 12, ..., 33
        mod_id = 10 * (i + 1) + (j + 1)

        exp_gp_samples_cart = exp_gp_samples_cart_list[mod_id]
        gp_true = gp_true_cart_list[mod_id][:5938] # contains 10 copies due to carelessness when creating this

        # 1d slice of GCE
        q = np.percentile(exp_gp_samples_cart, [2.5,16,50,84,97.5], axis = 0) # cartesian sample map quantiles
        sim_cart = ef.healpix_to_cart(gp_true, mask, n_pixels = n_pixels, nside = 128) # simulated rate cartesian map
        raw_cart = None

        # plt.sca is the documented way to select an existing Axes as current.
        # NOTE(review): presumably cart_plot_1d draws on the current Axes - confirm.
        plt.sca(axes[i,j])
        eplt.cart_plot_1d(q, sim_cart = sim_cart, raw_cart = raw_cart,
            slice_dir = 'horizontal', slice_val = slice_val,
            mask_map_cart = mask_map_cart,
            n_pixels = n_pixels,
            ylabel = r'$\lambda_{f}$', q_color = 'darkorange', line_color = 'green')

        # NOTE(review): assumes cart_plot_1d leaves exactly five major x ticks - confirm.
        axes[i,j].set_xticklabels(['$20$','$10$','$0$','$-10$','$-20$'])
        for label in axes[i,j].xaxis.get_majorticklabels()[0:1]:
            label.set_transform(label.get_transform() + offset)

# column headers on the top row
for c, ax in zip(cols, axes[0]):
    ax.set_title(c, fontsize=20)

# row headers: an invisible twin axis to the left of the first column carries the label
for r, ax in zip(rows, axes[:, 0]):
    ax2 = ax.twinx()
    # move extra axis to the left, with offset
    ax2.yaxis.set_label_position('left')
    ax2.spines['left'].set_position(('axes', -0.2))
    # hide spine and ticks, set group label
    ax2.spines['left'].set_visible(False)
    ax2.set_yticks([])
    ax2.set_ylabel(r, rotation=0,  fontsize=20,
                   ha='right', va='center')

fig.savefig('figures/fig_dif_hgp_1.pdf', format='pdf',bbox_inches='tight')

In [None]:
# 3x3 grid of vertical 1D slices through the GP quantile maps.
# cols / rows (panel headers) are defined in the fig_dif_norms_1 cell.
slice_val = 3.2  # y-value of slice

nrows = 3 ; ncols = 3
fig, axes = plt.subplots(figsize=(6*3, 6*3), dpi= 120, nrows = nrows, ncols = ncols)

for i in range(nrows):
    for j in range(ncols):
        # the two digits of mod_id are (row + 1, column + 1): 11, 12, ..., 33
        mod_id = 10 * (i + 1) + (j + 1)

        exp_gp_samples_cart = exp_gp_samples_cart_list[mod_id]
        gp_true = gp_true_cart_list[mod_id][:5938] # contains 10 copies due to carelessness when creating this

        # 1d slice of GCE
        q = np.percentile(exp_gp_samples_cart, [2.5,16,50,84,97.5], axis = 0) # cartesian sample map quantiles
        sim_cart = ef.healpix_to_cart(gp_true, mask, n_pixels = n_pixels, nside = 128) # simulated rate cartesian map
        raw_cart = None

        # plt.sca is the documented way to select an existing Axes as current.
        # NOTE(review): presumably cart_plot_1d draws on the current Axes - confirm.
        plt.sca(axes[i,j])
        eplt.cart_plot_1d(q, sim_cart = sim_cart, raw_cart = raw_cart,
            slice_dir = 'vertical', slice_val = slice_val,
            mask_map_cart = mask_map_cart,
            n_pixels = n_pixels,
            ylabel = r'$\lambda_{f}$', q_color = 'darkorange', line_color = 'green')

# column headers on the top row
for c, ax in zip(cols, axes[0]):
    ax.set_title(c, fontsize=20)

# row headers: an invisible twin axis to the left of the first column carries the label
for r, ax in zip(rows, axes[:, 0]):
    ax2 = ax.twinx()
    # move extra axis to the left, with offset
    ax2.yaxis.set_label_position('left')
    ax2.spines['left'].set_position(('axes', -0.2))
    # hide spine and ticks, set group label
    ax2.spines['left'].set_visible(False)
    ax2.set_yticks([])
    ax2.set_ylabel(r, rotation=0,  fontsize=20,
                   ha='right', va='center')

fig.savefig('figures/fig_dif_vgp_1.pdf', format='pdf',bbox_inches='tight')

In [None]:
# plot the log10 map of the per-pixel median of the 'gp' samples for each fit
# row: different diffuse model ; column: different synthetic dataset
# cols / rows (panel headers) are defined in the fig_dif_norms_1 cell;
# annulus_roi_x / annulus_roi_y come from the fig_zero_gp_map cell.

nrows = 3 ; ncols = 3
fig, axes = plt.subplots(figsize=(6*3, 6*3), dpi= 120, nrows = nrows, ncols = ncols)

# The HEALPix pixel index of every grid point does not change between panels.
grid_pix = hp.ang2pix(nside, THETA, PHI)

# shift of 5 pt to the right (dx is in inches) for the fifth x tick label
dx = 5/72.; dy = 0/72.
offset = mpl.transforms.ScaledTranslation(dx, dy, fig.dpi_scale_trans)

for i in range(nrows):
    for j in range(ncols):
        # the two digits of mod_id are (row + 1, column + 1): 11, 12, ..., 33
        mod_id = 10 * (i + 1) + (j + 1)
        temp_sample_dict_cmask = temp_sample_dict_cmask_list[mod_id]

        # plt.sca is the documented way to select an existing Axes as current.
        plt.sca(axes[i,j])
        q = np.percentile(temp_sample_dict_cmask['gp'], 50, axis = 0)

        # scatter the medians into a full-sky map (zero where mask is True),
        # then sample that map on the plotting grid
        temp_map_0= np.zeros(hp.nside2npix(nside))
        temp_map_0[~mask] = q
        grid_0 = temp_map_0[grid_pix]

        im = axes[i,j].pcolormesh(-longitude*180/np.pi, latitude*180/np.pi,np.log10(grid_0),cmap='viridis',
                      vmin=None,vmax=None)
        axes[i,j].plot(20*l_list,20*b_list, color="k", ls = "-", lw = 2., zorder = 11)
        # raw strings: '\e' and '\c' are not valid escape sequences in normal string literals
        axes[i,j].set_xlabel(r'$\ell~(^\circ)$')
        axes[i,j].set_ylabel(r'$b~(^\circ)$')
        axes[i,j].set_ylim(-20,20)
        axes[i,j].set_xlim(-20,20)
        axes[i,j].set_xticks([-20,-10,0,10,20])
        axes[i,j].xaxis.set_inverted(True)

        axes[i,j].fill(np.ravel(annulus_roi_x), np.ravel(annulus_roi_y), color = 'gray', zorder = 10)

        for label in axes[i,j].xaxis.get_majorticklabels()[4:5]:
            label.set_transform(label.get_transform() + offset)

# column headers on the top row
for c, ax in zip(cols, axes[0]):
    ax.set_title(c, fontsize=20)

# row headers: an invisible twin axis to the left of the first column carries the label
for r, ax in zip(rows, axes[:, 0]):
    ax2 = ax.twinx()
    # move extra axis to the left, with offset
    ax2.yaxis.set_label_position('left')
    ax2.spines['left'].set_position(('axes', -0.2))
    # hide spine and ticks, set group label
    ax2.spines['left'].set_visible(False)
    ax2.set_yticks([])
    ax2.set_ylabel(r, rotation=0,  fontsize=20,
                   ha='right', va='center')

fig.savefig('figures/fig_dif_gp_1.png', format='png',bbox_inches='tight',dpi=300)

In [None]:
# 3x3 grid of horizontal 1D slices through the total-rate quantile maps.
# cols / rows (panel headers) are defined in the fig_dif_norms_1 cell.
slice_val = 3.2  # y-value of slice

nrows = 3 ; ncols = 3
fig, axes = plt.subplots(figsize=(6*3, 6*3), dpi= 120, nrows = nrows, ncols = ncols)

# Shift of 5 pt to the right (dx is in inches) for the first x tick label.
# dx / dy are defined here, with the value used in the other cells, so this
# cell no longer depends on variables left over from a previously run cell.
dx = 5/72.; dy = 0/72.
offset = mpl.transforms.ScaledTranslation(dx, dy, fig.dpi_scale_trans)

for i in range(nrows):
    for j in range(ncols):
        # the two digits of mod_id are (row + 1, column + 1): 11, 12, ..., 33
        mod_id = 10 * (i + 1) + (j + 1)

        tot_samples_cart = tot_samples_cart_list[mod_id]
        sim_cart = sim_cart_list[mod_id]

        # 1d slice of GCE
        q = np.percentile(tot_samples_cart, [2.5,16,50,84,97.5], axis = 0)
        raw_cart = None

        # plt.sca is the documented way to select an existing Axes as current.
        # NOTE(review): presumably cart_plot_1d draws on the current Axes - confirm.
        plt.sca(axes[i,j])
        eplt.cart_plot_1d(q, sim_cart = sim_cart, raw_cart = raw_cart,
            slice_dir = 'horizontal', slice_val = slice_val,
            ylim = [20., 90.],
            mask_map_cart = mask_map_cart,
            n_pixels = n_pixels,)

        # NOTE(review): assumes cart_plot_1d leaves exactly five major x ticks - confirm.
        axes[i,j].set_xticklabels(['$20$','$10$','$0$','$-10$','$-20$'])

        for label in axes[i,j].xaxis.get_majorticklabels()[0:1]:
            label.set_transform(label.get_transform() + offset)

# column headers on the top row
for c, ax in zip(cols, axes[0]):
    ax.set_title(c, fontsize=20)

# row headers: an invisible twin axis to the left of the first column carries the label
for r, ax in zip(rows, axes[:, 0]):
    ax2 = ax.twinx()
    # move extra axis to the left, with offset
    ax2.yaxis.set_label_position('left')
    ax2.spines['left'].set_position(('axes', -0.2))
    # hide spine and ticks, set group label
    ax2.spines['left'].set_visible(False)
    ax2.set_yticks([])
    ax2.set_ylabel(r, rotation=0,  fontsize=20,
                   ha='right', va='center')

fig.savefig('figures/fig_dif_htot_1.pdf',format='pdf', bbox_inches='tight')

In [None]:
# 3x3 grid of vertical 1D slices through the total-rate quantile maps.
# cols / rows (panel headers) are defined in the fig_dif_norms_1 cell.
slice_val = 3.2  # y-value of slice

nrows = 3 ; ncols = 3
fig, axes = plt.subplots(figsize=(6*3, 6*3), dpi= 120, nrows = nrows, ncols = ncols)

for i in range(nrows):
    for j in range(ncols):
        # the two digits of mod_id are (row + 1, column + 1): 11, 12, ..., 33
        mod_id = 10 * (i + 1) + (j + 1)

        tot_samples_cart = tot_samples_cart_list[mod_id]
        sim_cart = sim_cart_list[mod_id]

        # 1d slice of GCE
        q = np.percentile(tot_samples_cart, [2.5,16,50,84,97.5], axis = 0)
        raw_cart = None

        # plt.sca is the documented way to select an existing Axes as current.
        # NOTE(review): presumably cart_plot_1d draws on the current Axes - confirm.
        plt.sca(axes[i,j])
        eplt.cart_plot_1d(q, sim_cart = sim_cart, raw_cart = raw_cart,
            slice_dir = 'vertical', slice_val = slice_val,
            mask_map_cart = mask_map_cart,
            n_pixels = n_pixels,)

# column headers on the top row
for c, ax in zip(cols, axes[0]):
    ax.set_title(c, fontsize=20)

# row headers: an invisible twin axis to the left of the first column carries the label
for r, ax in zip(rows, axes[:, 0]):
    ax2 = ax.twinx()
    # move extra axis to the left, with offset
    ax2.yaxis.set_label_position('left')
    ax2.spines['left'].set_position(('axes', -0.2))
    # hide spine and ticks, set group label
    ax2.spines['left'].set_visible(False)
    ax2.set_yticks([])
    ax2.set_ylabel(r, rotation=0,  fontsize=20,
                   ha='right', va='center')

fig.savefig('figures/fig_dif_vtot_1.pdf',format='pdf', bbox_inches='tight')

In [None]:
# load data (~7 GB ; ~6.5 min)
# Context managers close the file handles as soon as the payloads are read.
with open(ed_data_location+'../diffuse_model_tests/plotting_data/multi/all_mismodelling_data_2.p', 'rb') as pkl_file:
    seq = pickle.load(pkl_file)
temp_sample_dict_list, temp_sample_dict_cmask_list, exp_gp_samples_cart_list, gp_true_cart_list, tot_samples_cart_list, model_residuals_cart_list, mask, mask_p, mask_map_cart = seq

# load real data
with open(ed_data_location+'../diffuse_model_tests/plotting_data/multi/corner_samples_dict_2.p', 'rb') as pkl_file:
    corner_samples_dict = pickle.load(pkl_file)
with open(ed_data_location+'../diffuse_model_tests/plotting_data/multi/temp_dict_list_2.p', 'rb') as pkl_file:
    temp_dict_list, sim_cart_list = pickle.load(pkl_file)

# correction to temp_sample_dict described in src
# (replaces the two sample-dict lists unpacked from the first pickle above)
with open(ed_data_location+'../diffuse_model_tests/plotting_data/multi/temp_dicts_theta_2.p', 'rb') as pkl_file:
    temp_sample_dict_list, temp_sample_dict_cmask_list = pickle.load(pkl_file)

In [None]:
# plot total counts histogram for each fit, plus the theta_ics / theta_pib
# posteriors underneath
# column: one entry of mod_ids ; rows: counts histogram (double height),
# theta_ics histogram, theta_pib histogram
# cols (column headers) is defined in the fig_dif_norms_1 cell.

mod_ids = list(temp_dict_list.keys())
data_models = ['O', 'A', 'F']

nrows = 3 ; ncols = 3
# Start from an empty figure: all panels are created by subplot2grid below.
# Creating a plt.subplots grid first would leave nine extra Axes underneath
# the subplot2grid panels unless Matplotlib removes the overlapped ones itself.
fig = plt.figure(figsize=(6*3, 9), dpi= 120)

# 4x3 layout: the first panel row spans two grid rows, the other two span one each
axes_0 = [plt.subplot2grid((4, 3), (0, i), rowspan = 2, colspan = 1, fig = fig) for i in range(ncols)]
axes_1 = [plt.subplot2grid((4, 3), (2, i), rowspan = 1, colspan = 1, fig = fig) for i in range(ncols)]
axes_2 = [plt.subplot2grid((4, 3), (3, i), rowspan = 1, colspan = 1, fig = fig) for i in range(ncols)]
axes = [axes_0, axes_1, axes_2]

# Settings that are identical for all columns.
all_temp_names = ['iso', 'psc', 'bub', 'pib', 'ics', 'blg', 'gp', 'nfw', 'dsk']
ccodes = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C5', 'C6', 'C7']
bins = np.logspace(2.,5.,150)
theta_bins = np.linspace(-0.05,1.05,100)

names_sim = temp_names_sim # this piece is provided by the "settings" file since we only save a dictionary with all the fit parameters
ordered_names_sim = [name for name in all_temp_names if name in names_sim]

for i in range(ncols):
    mod_id = mod_ids[i]
    temp_sample_dict = temp_sample_dict_list[mod_id]
    corner_samples = corner_samples_dict[mod_id]
    temp_dict = temp_dict_list[mod_id]

    names = list(temp_sample_dict.keys())

    # simulated template maps restricted to the pixels where mask_p is False
    temp_sim_dict = {name: temp_dict[name][~mask_p] for name in temp_dict.keys() if name in all_temp_names}

    # --- top panel: fitted templates, drawn in the canonical all_temp_names order ---
    ordered_names = [name for name in all_temp_names if name in names]
    for name in ordered_names:
        ccode = ccodes[all_temp_names.index(name)]
        temp_sum = temp_sample_dict[name].sum(axis = 1) # sum over spatial bins
        # NOTE(review): density=False plots raw sample counts although the
        # y label reads "Density" - confirm which one is intended.
        axes[0][i].hist(temp_sum, bins = bins, alpha = 1, label = r'\texttt{'+name+'}', density = False, histtype = 'step', color = ccode)

    # simulated totals: one dashed line per simulated template, except 'blg'
    for name in ordered_names_sim:
        ccode = ccodes[all_temp_names.index(name)]
        if name != 'blg':
            temp_sum_sim = temp_sim_dict[name].sum(axis = 0)
            axes[0][i].axvline(temp_sum_sim, linestyle='--', c = ccode)

    # only the first column carries the template legend (lines instead of step patches)
    if i == 0:
        handles, labels = axes[0][0].get_legend_handles_labels()
        new_handles = [Line2D([], [], c=h.get_edgecolor()) for h in handles]
        axes[0][0].legend(handles=new_handles, labels=labels,frameon=False,loc=2)

    axes[0][i].set_xscale('log')
    axes[0][i].set_xlabel(r'$\mathrm{Counts}$')
    axes[0][i].set_ylabel(r'$\mathrm{Density}$')

    # --- lower panels: theta posteriors ---
    # The first two digits of mod_id (1-based) select entries of data_models.
    data_mods_idx = [int(str(mod_id)[d]) - 1 for d in range(2)]
    data_mods = [data_models[idx] for idx in data_mods_idx]
    print(mod_id, data_mods)
    for k in range(len(data_mods)):
        color = 'C{}'.format(data_mods_idx[k] + 6)
        # column k of the theta arrays is labelled with data_mods[k]
        axes[1][i].hist(corner_samples['theta_ics'][:,k], bins = theta_bins,
                        histtype='step', color = color, label = r'$k$ = {}'.format(data_mods[k]), density = True,lw=2)
        axes[2][i].hist(corner_samples['theta_pib'][:,k], bins = theta_bins,
                        histtype='step', color = color, label = r'$k$ = {}'.format(data_mods[k]), density = True,lw=2)

    axes[1][i].set_xlabel(r'$\theta_{\texttt{ics}}^{k}$')
    axes[1][i].set_ylabel('Density')
    handles, labels = axes[1][i].get_legend_handles_labels()
    new_handles = [Line2D([], [], c=h.get_edgecolor()) for h in handles]
    axes[1][i].legend(handles=new_handles, labels=labels,frameon=False,loc='upper center')

    axes[2][i].set_xlabel(r'$\theta_{\texttt{pib}}^{k}$')
    axes[2][i].set_ylabel('Density')
    handles, labels = axes[2][i].get_legend_handles_labels()
    new_handles = [Line2D([], [], c=h.get_edgecolor()) for h in handles]
    axes[2][i].legend(handles=new_handles, labels=labels,frameon=False,loc=2)

# column headers on the top row
for c, ax in zip(cols, axes[0]):
    ax.set_title(c, fontsize=20)
plt.tight_layout()
fig.savefig('figures/fig_dif_norms_2.pdf',format='pdf', bbox_inches='tight')

In [None]:
# GP summary for each entry of mod_ids (one column each)
# row 0: horizontal 1D slice ; row 1: vertical 1D slice ;
# row 2: log10 map of the per-pixel median of the 'gp' samples
# cols (column headers) is defined in the fig_dif_norms_1 cell;
# annulus_roi_x / annulus_roi_y come from the fig_zero_gp_map cell.

slice_val = 3.2  # y-value of slice

nrows = 3 ; ncols = 3
fig, axes = plt.subplots(figsize=(6*3, 6*3), dpi= 120, nrows = nrows, ncols = ncols)

# The HEALPix pixel index of every grid point does not change between columns.
grid_pix = hp.ang2pix(nside, THETA, PHI)

# shift of 5 pt to the right (dx is in inches) for one x tick label per panel
dx = 5/72.; dy = 0/72.
offset = mpl.transforms.ScaledTranslation(dx, dy, fig.dpi_scale_trans)

for i in range(ncols):
    mod_id = mod_ids[i]

    exp_gp_samples_cart = exp_gp_samples_cart_list[mod_id]
    gp_true = gp_true_cart_list[mod_id][:5938] # contains 10 copies due to carelessness when creating this

    # 1d slices of GCE: the quantiles and the simulated map are shared by rows 0 and 1
    q = np.percentile(exp_gp_samples_cart, [2.5,16,50,84,97.5], axis = 0) # cartesian sample map quantiles
    sim_cart = ef.healpix_to_cart(gp_true, mask, n_pixels = n_pixels, nside = 128) # simulated rate cartesian map
    raw_cart = None

    # plt.sca is the documented way to select an existing Axes as current.
    # NOTE(review): presumably cart_plot_1d draws on the current Axes - confirm.
    plt.sca(axes[0,i])
    eplt.cart_plot_1d(q, sim_cart = sim_cart, raw_cart = raw_cart,
        slice_dir = 'horizontal', slice_val = slice_val,
        mask_map_cart = mask_map_cart,
        n_pixels = n_pixels,
        ylabel = r'$\lambda_{\rm gce}$', q_color = 'darkorange', line_color = 'green')

    axes[0,i].xaxis.set_inverted(True)

    for label in axes[0,i].xaxis.get_majorticklabels()[4:5]:
        label.set_transform(label.get_transform() + offset)

    plt.sca(axes[1,i])
    eplt.cart_plot_1d(q, sim_cart = sim_cart, raw_cart = raw_cart,
        slice_dir = 'vertical', slice_val = slice_val,
        mask_map_cart = mask_map_cart,
        n_pixels = n_pixels,
        ylabel = r'$\lambda_{f}$', q_color = 'darkorange', line_color = 'green')

    # --- row 2: median map ---
    temp_sample_dict_cmask = temp_sample_dict_cmask_list[mod_id]

    q = np.percentile(temp_sample_dict_cmask['gp'], 50, axis = 0)

    # scatter the medians into a full-sky map (zero where mask is True), then
    # sample that map on the plotting grid
    temp_map_0= np.zeros(hp.nside2npix(nside))
    temp_map_0[~mask] = q
    grid_0 = temp_map_0[grid_pix]

    im = axes[2,i].pcolormesh(-longitude*180/np.pi, latitude*180/np.pi,np.log10(grid_0),cmap='viridis',
                  vmin=None,vmax=None)
    axes[2,i].plot(20*l_list,20*b_list, color="k", ls = "-", lw = 2., zorder = 11)
    # raw strings: '\e' and '\c' are not valid escape sequences in normal string literals
    axes[2,i].set_xlabel(r'$\ell~(^\circ)$')
    axes[2,i].set_ylabel(r'$b~(^\circ)$')
    axes[2,i].set_ylim(-20,20)
    axes[2,i].set_xlim(-20,20)

    axes[2,i].fill(np.ravel(annulus_roi_x), np.ravel(annulus_roi_y), color = 'gray', zorder = 10)

    # NOTE(review): labels are attached to the automatically chosen ticks; this
    # assumes exactly five major x ticks on the (-20, 20) range - confirm.
    axes[2,i].set_xticklabels(['$20$','$10$','$0$','$-10$','$-20$'])
    for label in axes[2,i].xaxis.get_majorticklabels()[0:1]:
        label.set_transform(label.get_transform() + offset)

# column headers on the top row
for c, ax in zip(cols, axes[0]):
    ax.set_title(c, fontsize=20)
plt.tight_layout()
fig.savefig('figures/fig_dif_gp_2.png',format='png', bbox_inches='tight',dpi=300)

In [None]:
# 1d slices through the quantiles of tot_samples_cart for each fit
# row 0: horizontal slice ; row 1: vertical slice ; one column per entry of mod_ids

slice_val = 3.2  # slice position handed to cart_plot_1d (used for both slice directions)

nrows, ncols = 2, 3
fig, axes = plt.subplots(nrows = nrows, ncols = ncols, figsize = (6*3, 6*2), dpi = 120)

for i in range(ncols):
    mod_id = mod_ids[i]
    top_ax, bottom_ax = axes[0,i], axes[1,i]

    tot_samples_cart = tot_samples_cart_list[mod_id]
    sim_cart = sim_cart_list[mod_id]

    # quantiles over the sample axis, shared by both slices
    q = np.percentile(tot_samples_cart, [2.5,16,50,84,97.5], axis = 0)
    raw_cart = None

    # keyword arguments common to the two cart_plot_1d calls
    shared_kwargs = dict(sim_cart = sim_cart, raw_cart = raw_cart,
                         slice_val = slice_val,
                         mask_map_cart = mask_map_cart,
                         n_pixels = n_pixels)

    # horizontal slice, with a fixed y-range
    plt.axes(top_ax)
    eplt.cart_plot_1d(q, slice_dir = 'horizontal', ylim = [20., 90.], **shared_kwargs)

    # vertical slice, y-range left to cart_plot_1d
    plt.axes(bottom_ax)
    eplt.cart_plot_1d(q, slice_dir = 'vertical', **shared_kwargs)

    # relabel the top panel's x ticks and nudge the first label to the right
    top_ax.set_xticklabels(['$20$','$10$','$0$','$-10$','$-20$'])
    shift = mpl.transforms.ScaledTranslation(5/72., 0/72., fig.dpi_scale_trans)
    for tick_label in top_ax.xaxis.get_majorticklabels()[:1]:
        tick_label.set_transform(tick_label.get_transform() + shift)

# column titles on the top row
for title, top_panel in zip(cols, axes[0]):
    top_panel.set_title(title, fontsize=20)
plt.tight_layout()
fig.savefig('figures/fig_dif_tot_2.pdf',format='pdf', bbox_inches='tight')

In [None]:
# load data (~7 GB ; ~6.5 min)
# NOTE: pickle.load can execute arbitrary code stored in the file -- only point
# this cell at pickle files you produced yourself or otherwise trust.
# Each file is opened in a `with` block so its handle is closed as soon as the
# load finishes (the original `pickle.load(open(...))` left all four open).
multi_data_dir = ed_data_location+'../diffuse_model_tests/plotting_data/multi/'

with open(multi_data_dir+'all_mismodelling_data_3.p', 'rb') as pkl_file:
    seq = pickle.load(pkl_file)
temp_sample_dict_list, temp_sample_dict_cmask_list, exp_gp_samples_cart_list, gp_true_cart_list, tot_samples_cart_list, model_residuals_cart_list, mask, mask_p, mask_map_cart = seq

# load real data
with open(multi_data_dir+'corner_samples_dict_3.p', 'rb') as pkl_file:
    corner_samples_dict = pickle.load(pkl_file)
with open(multi_data_dir+'temp_dict_list_3.p', 'rb') as pkl_file:
    temp_dict_list, sim_cart_list = pickle.load(pkl_file)

# correction to temp_sample_dict described in src
# (replaces the two temp_sample_dict lists unpacked from `seq` above)
with open(multi_data_dir+'temp_dicts_theta_3.p', 'rb') as pkl_file:
    temp_sample_dict_list, temp_sample_dict_cmask_list = pickle.load(pkl_file)

In [None]:
# Template-normalization summary for each fit ; one column per entry of mod_ids
#   row 0 (double height): per-template histograms of the total counts in each
#                          sample, with dashed lines at the temp_dict totals
#   row 1: histograms of the theta_ics samples, one curve per entry of data_mods
#   row 2: histograms of the theta_pib samples, one curve per entry of data_mods

mod_ids = list(temp_dict_list.keys())
data_mods = ['O', 'A', 'F']

nrows = 3 ; ncols = 3
# Start from an empty figure: the panels are laid out with subplot2grid below.
# The original called plt.subplots here, creating a 3x3 grid of axes that the
# subplot2grid axes then overlapped; Matplotlib versions that no longer
# auto-remove overlapped axes would draw those stale axes underneath.
fig = plt.figure(figsize=(6*3, 9), dpi= 120)

# 4x3 layout so the top row can span two grid rows
axes_0 = [plt.subplot2grid((4, 3), (0, i), rowspan = 2, colspan = 1) for i in range(ncols)]
axes_1 = [plt.subplot2grid((4, 3), (2, i), rowspan = 1, colspan = 1) for i in range(ncols)]
axes_2 = [plt.subplot2grid((4, 3), (3, i), rowspan = 1, colspan = 1) for i in range(ncols)]
axes = [axes_0, axes_1, axes_2]

# Loop-invariant settings, built once instead of once per column.
all_temp_names = ['iso', 'psc', 'bub', 'pib', 'ics', 'blg', 'gp', 'nfw', 'dsk']
ccodes = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C5', 'C6', 'C7']  # color of all_temp_names[j] ; 'blg' and 'gp' share 'C5'
bins = np.logspace(2.,5.,150)
theta_bins = np.linspace(-0.05,1.05,100)

names_sim = temp_names_sim # this piece is provided by the "settings" file since we only save a dictionary with all the fit parameters
ordered_names_sim = [name for name in all_temp_names if name in names_sim]

for i in range(ncols):
    mod_id = mod_ids[i]
    temp_sample_dict = temp_sample_dict_list[mod_id]
    corner_samples = corner_samples_dict[mod_id]
    temp_dict = temp_dict_list[mod_id]

    names = list(temp_sample_dict.keys())

    temp_sim_dict = {name: temp_dict[name][~mask_p] for name in temp_dict.keys() if name in all_temp_names}

    # --- row 0: total counts per template ---
    ordered_names = [name for name in all_temp_names if name in names]
    for name in ordered_names:
        ccode = ccodes[all_temp_names.index(name)]
        temp_sum = temp_sample_dict[name].sum(axis = 1) # sum over spatial bins
        axes[0][i].hist(temp_sum, bins = bins, alpha = 1, label = r'\texttt{'+name+'}', density = False, histtype = 'step', color = ccode)

    # dashed reference line at each temp_dict total ; 'blg' gets no line
    for name in ordered_names_sim:
        if name != 'blg':
            ccode = ccodes[all_temp_names.index(name)]
            temp_sum_sim = temp_sim_dict[name].sum(axis = 0)
            axes[0][i].axvline(temp_sum_sim, linestyle='--', c = ccode)

    # legend only on the first column ; line handles replace the step-histogram patches
    if i == 0:
        handles, labels = axes[0][i].get_legend_handles_labels()
        new_handles = [Line2D([], [], c=h.get_edgecolor()) for h in handles]
        axes[0][0].legend(handles=new_handles, labels=labels,frameon=False,loc=2)

    axes[0][i].set_xscale('log')
    axes[0][i].set_xlabel(r'$\mathrm{Counts}$')
    # NOTE(review): the histograms above use density = False, so this axis shows
    # raw sample counts although it is labelled "Density" -- confirm the intent.
    axes[0][i].set_ylabel(r'$\mathrm{Density}$')

    # --- rows 1 and 2: theta histograms, one curve per data_mods entry ---
    # (row, key in corner_samples, x label, legend location)
    theta_panels = (
        (1, 'theta_ics', r'$\theta_{\texttt{ics}}^{k}$', 'upper right' if i == 0 else 'upper center'),
        (2, 'theta_pib', r'$\theta_{\texttt{pib}}^{k}$', 'upper center'),
    )
    for row, key, xlabel, legend_loc in theta_panels:
        ax = axes[row][i]
        for k in range(len(data_mods)):
            color = 'C{}'.format(k + 6)
            ax.hist(corner_samples[key][:,k], bins = theta_bins,
                    histtype='step', color = color, label = r'$k$ = {}'.format(data_mods[k]), density = True,lw=2)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Density')

        # build the legend once, after all curves exist (the original rebuilt it on every k)
        handles, labels = ax.get_legend_handles_labels()
        new_handles = [Line2D([], [], c=h.get_edgecolor()) for h in handles]
        ax.legend(handles=new_handles, labels=labels,frameon=False,loc=legend_loc)

# column titles on the top row
for title, ax in zip(cols, axes[0]):
    ax.set_title(title, fontsize=20)
plt.tight_layout()
fig.savefig('figures/fig_dif_norms_3.pdf',format='pdf', bbox_inches='tight')

In [None]:
# GP (GCE) posterior summary for each fit ; one column per entry of mod_ids
#   row 0: horizontal 1d slice through the cartesian GP sample quantiles
#   row 1: vertical 1d slice through the same quantiles
#   row 2: log10 map of the posterior-median 'gp' template

slice_val = 3.2  # slice position handed to cart_plot_1d (used for both slice directions)

nrows = 3 ; ncols = 3
fig, axes = plt.subplots(figsize=(6*3, 6*3), dpi= 120, nrows = nrows, ncols = ncols)

# Loop-invariant pieces, built once instead of once per column.
dx = 5/72.; dy = 0/72.
offset = mpl.transforms.ScaledTranslation(dx, dy, fig.dpi_scale_trans)  # small horizontal nudge for edge tick labels
grid_pix = hp.ang2pix(nside, THETA, PHI)  # healpix pixel index of every plotting-grid point

for i in range(ncols):
    mod_id = mod_ids[i]

    exp_gp_samples_cart = exp_gp_samples_cart_list[mod_id]
    gp_true = gp_true_cart_list[mod_id][:5938] # contains 10 copies due to carelessness when creating this

    # Inputs shared by both 1d slices of the GCE; the original recomputed them
    # verbatim before each slice.
    q = np.percentile(exp_gp_samples_cart, [2.5,16,50,84,97.5], axis = 0) # cartesian sample map quantiles
    sim_cart = ef.healpix_to_cart(gp_true, mask, n_pixels = n_pixels, nside = 128) # simulated rate cartesian map
    raw_cart = None

    # row 0: horizontal slice
    plt.axes(axes[0,i])
    eplt.cart_plot_1d(q, sim_cart = sim_cart, raw_cart = raw_cart,
        slice_dir = 'horizontal', slice_val = slice_val,
        mask_map_cart = mask_map_cart,
        n_pixels = n_pixels,
        ylabel = r'$\lambda_{f}$', q_color = 'darkorange', line_color = 'green')

    # NOTE(review): these labels are attached to whatever major ticks
    # cart_plot_1d left on the axis -- confirm there are exactly these five.
    axes[0,i].set_xticklabels(['$20$','$10$','$0$','$-10$','$-20$'])

    # nudge the first major tick label to the right
    for label in axes[0,i].xaxis.get_majorticklabels()[0:1]:
        label.set_transform(label.get_transform() + offset)

    # row 1: vertical slice
    plt.axes(axes[1,i])
    eplt.cart_plot_1d(q, sim_cart = sim_cart, raw_cart = raw_cart,
        slice_dir = 'vertical', slice_val = slice_val,
        mask_map_cart = mask_map_cart,
        n_pixels = n_pixels,
        ylabel = r'$\lambda_{f}$', q_color = 'darkorange', line_color = 'green')

    # row 2: posterior-median GP template scattered back onto a full healpix map
    temp_sample_dict_cmask = temp_sample_dict_cmask_list[mod_id]
    gp_median = np.percentile(temp_sample_dict_cmask['gp'], 50, axis = 0)

    temp_map_0 = np.zeros(hp.nside2npix(nside))
    temp_map_0[~mask] = gp_median  # pixels selected by `mask` keep the value 0
    grid_0 = temp_map_0[grid_pix]

    im = axes[2,i].pcolormesh(-longitude*180/np.pi, latitude*180/np.pi,np.log10(grid_0),cmap='viridis',
                  vmin=None,vmax=None)
    axes[2,i].plot(20*l_list,20*b_list, color="k", ls = "-", lw = 2., zorder = 11)
    # raw strings: '\e' and '\c' are not valid escape sequences in a normal string literal
    axes[2,i].set_xlabel(r'$\ell~(^\circ)$')
    axes[2,i].set_ylabel(r'$b~(^\circ)$')
    axes[2,i].set_ylim(-20,20)
    axes[2,i].set_xlim(-20,20)

    axes[2,i].fill(np.ravel(annulus_roi_x), np.ravel(annulus_roi_y), color = 'gray', zorder = 10)

    # The x data is -longitude, so the labels run from +20 (left) to -20 (right).
    # Pin the tick positions so the five hard-coded labels cannot drift away from
    # whatever the automatic locator would otherwise pick.
    axes[2,i].set_xticks([-20, -10, 0, 10, 20])
    axes[2,i].set_xticklabels(['$20$','$10$','$0$','$-10$','$-20$'])
    for label in axes[2,i].xaxis.get_majorticklabels()[0:1]:
        label.set_transform(label.get_transform() + offset)

# column titles on the top row
for title, ax in zip(cols, axes[0]):
    ax.set_title(title, fontsize=20)
plt.tight_layout()
fig.savefig('figures/fig_dif_gp_3.png',format='png', bbox_inches='tight',dpi=300)

In [None]:
# 1d slices through the quantiles of tot_samples_cart for each fit
# row 0: horizontal slice ; row 1: vertical slice ; one column per entry of mod_ids

slice_val = 3.2  # slice position handed to cart_plot_1d (used for both slice directions)

nrows, ncols = 2, 3
fig, axes = plt.subplots(nrows = nrows, ncols = ncols, figsize = (6*3, 6*2), dpi = 120)

for i in range(ncols):
    mod_id = mod_ids[i]
    top_ax, bottom_ax = axes[0,i], axes[1,i]

    tot_samples_cart = tot_samples_cart_list[mod_id]
    sim_cart = sim_cart_list[mod_id]

    # quantiles over the sample axis, shared by both slices
    q = np.percentile(tot_samples_cart, [2.5,16,50,84,97.5], axis = 0)
    raw_cart = None

    # keyword arguments common to the two cart_plot_1d calls
    shared_kwargs = dict(sim_cart = sim_cart, raw_cart = raw_cart,
                         slice_val = slice_val,
                         mask_map_cart = mask_map_cart,
                         n_pixels = n_pixels)

    # horizontal slice, with a fixed y-range
    plt.axes(top_ax)
    eplt.cart_plot_1d(q, slice_dir = 'horizontal', ylim = [20., 90.], **shared_kwargs)

    # vertical slice, y-range left to cart_plot_1d
    plt.axes(bottom_ax)
    eplt.cart_plot_1d(q, slice_dir = 'vertical', **shared_kwargs)

    # relabel the top panel's x ticks and nudge the first label to the right
    top_ax.set_xticklabels(['$20$','$10$','$0$','$-10$','$-20$'])
    shift = mpl.transforms.ScaledTranslation(5/72., 0/72., fig.dpi_scale_trans)
    for tick_label in top_ax.xaxis.get_majorticklabels()[:1]:
        tick_label.set_transform(tick_label.get_transform() + shift)

# column titles on the top row
for title, top_panel in zip(cols, axes[0]):
    top_panel.set_title(title, fontsize=20)
plt.tight_layout()
fig.savefig('figures/fig_dif_tot_3.pdf',format='pdf', bbox_inches='tight')