In [1]:
import numpy as np 
import pandas as pd
from numpyro.diagnostics import summary
from utils.helpers import pickle_load
import matplotlib.pylab as plt 

# Global matplotlib settings: render all text through LaTeX.
plt.rc('text', usetex=True)
# NOTE(review): the active family is 'sans-serif' while only the 'serif' font
# list is set to Palatino; presumably one of the two was meant to match -- confirm.
plt.rc('font', family='sans-serif', serif=['Palatino'])

# Default figure size and font size for plots.
figSize = (12, 8)
fontSize = 15

In [2]:
# Name of the analysis whose chains are summarised; it selects the parameter
# list below and is also read by cobaya_statistics to pick the chain files.
ANALYSIS = 'lsst'

# KEYS lists the sampled parameter names in column order. The statistics
# helpers look samples up by position (arrays) or by name (NUTS sample dicts),
# so both the order and the exact spelling matter.
if ANALYSIS == 'lsst':
    KEYS = (
        ['sigma8', 'omegac', 'omegab', 'hubble', 'ns']
        + ['m1', 'm2', 'm3', 'm4', 'm5']
        + ['dz_wl_1', 'dz_wl_2', 'dz_wl_3', 'dz_wl_4', 'dz_wl_5']
        + ['a_ia', 'eta']
        + ['b1', 'b2', 'b3', 'b4', 'b5', 'b6', 'b7', 'b8', 'b9', 'b10']
        + ['dz_gc_1', 'dz_gc_2', 'dz_gc_3', 'dz_gc_4', 'dz_gc_5']
        + ['dz_gc_6', 'dz_gc_7', 'dz_gc_8', 'dz_gc_9', 'dz_gc_10']
    )
else:
    # Note: 'Omegac' / 'Omegab' are capitalised in this branch only.
    KEYS = (
        ['sigma8', 'Omegac', 'Omegab', 'hubble', 'ns']
        + ['m1', 'm2', 'm3', 'm4']
        + ['dz_wl_1', 'dz_wl_2', 'dz_wl_3', 'dz_wl_4']
        + ['a_ia', 'eta']
        + ['b1', 'b2', 'b3', 'b4', 'b5']
        + ['dz_gc_1', 'dz_gc_2', 'dz_gc_3', 'dz_gc_4', 'dz_gc_5']
    )

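# Column order matching the 'lsst' KEYS above, with two leading and four
# trailing bookkeeping columns (presumably the Cobaya chain-file header -- TODO confirm):
# weight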
# minuslogpost          
# sigma8          
# omegac          
# omegab          
# hubble              
# ns              
# m1     
# m2              
# m3              
# m4              
# m5         
# dz_wl_1         
# dz_wl_2         
# dz_wl_3         
# dz_wl_4         
# dz_wl_5            
# a_ia             
# eta              
# b1              
# b2              
# b3              
# b4              
# b5             
# b6              
# b7              
# b8              
# b9             
# b10         
# dz_gc_1         
# dz_gc_2         
# dz_gc_3         
# dz_gc_4 
# dz_gc_5         
# dz_gc_6         
# dz_gc_7         
# dz_gc_8         
# dz_gc_9        
# dz_gc_10   
# minuslogprior 
# minuslogprior__0    
# chi2  
# chi2__LSSTlike


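# ---------------------
# Column order matching the non-'lsst' KEYS (note: listed here as lowercase
# 'omegac' / 'omegab', whereas that KEYS branch spells them 'Omegac' / 'Omegab'):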
# weight    
# minuslogpost          
# sigma8          
# omegac          
# omegab          
# hubble              
# ns              
# m1              
# m2              
# m3              
# m4         
# dz_wl_1         
# dz_wl_2         
# dz_wl_3         
# dz_wl_4            
# a_ia             
# eta              
# b1              
# b2              
# b3              
# b4              
# b5         
# dz_gc_1         
# dz_gc_2         
# dz_gc_3         
# dz_gc_4         
# dz_gc_5   
# minuslogprior 
# minuslogprior__0            
# chi2 
# chi2__my_likelihood

In [3]:
def summary_calculation(samples1: np.ndarray, samples2: np.ndarray, neval: int) -> pd.DataFrame:
    """Build a per-parameter diagnostics table from two sample arrays.

    For every name in the module-level ``KEYS`` (matched to array columns by
    position), the corresponding column of each array is stacked into a
    ``(2, n_rows)`` array and passed to numpyro's ``summary``.

    Parameters
    ----------
    samples1, samples2 : np.ndarray
        Two-dimensional sample arrays with one column per entry of ``KEYS``.
        Both must have the same number of rows, otherwise ``np.vstack`` fails.
    neval : int
        Denominator used to scale the effective sample size.

    Returns
    -------
    pd.DataFrame
        One row per parameter with columns ``r_hat``, ``n_eff``, ``mean`` and
        ``std`` (rounded to 3 decimals) plus ``scaled_n_eff``, which is the
        rounded ``n_eff`` divided by ``neval``.
    """
    wanted = ['r_hat', 'n_eff', 'mean', 'std']

    per_parameter = []
    for column, name in enumerate(KEYS):
        stacked = np.vstack([samples1[:, column], samples2[:, column]])
        # summary() files a bare array under 'Param:0'; re-label it with the
        # parameter name so it becomes the column header of the frame.
        stats = {name: summary(stacked)['Param:0']}
        per_parameter.append(pd.DataFrame(stats).round(3).loc[wanted])

    # Frames are (statistic x parameter); transpose to get one row per parameter.
    table = pd.concat(per_parameter, axis=1).T
    table['scaled_n_eff'] = table['n_eff'] / neval
    return table

## Cobaya

In [4]:
def cobaya_statistics(engine: str = 'jaxcosmo') -> pd.DataFrame:
    """Summarise the two Cobaya chains produced with the given engine.

    Reads the chain text files of runs 1 and 2 (the location depends on the
    module-level ``ANALYSIS``), trims both chains to the length of the shorter
    one by keeping their last rows, and passes them to ``summary_calculation``.

    Parameters
    ----------
    engine : str
        Prefix of the run directories, e.g. ``'jaxcosmo'`` or ``'emulator'``.

    Returns
    -------
    pd.DataFrame
        Table returned by ``summary_calculation``; ``scaled_n_eff`` is
        ``n_eff`` divided by the summed weights (column 0) of both chain files.
    """
    chains = []
    nlike = 0
    for run in (1, 2):
        if ANALYSIS != 'lsst':
            chain_path = f'outputcobaya/testing/{engine}_{run}/output_prefix.1.txt'
        else:
            chain_path = f'CobayaLSST/{engine}_{run}/lsst.1.txt'
        table = np.loadtxt(chain_path)

        # Column 0 holds the row weights. Use the vectorised ndarray sum rather
        # than the built-in sum(), which iterates element by element.
        nlike += table[:, 0].sum()

        # Drop the two leading and the four trailing columns so that only the
        # sampled-parameter columns remain.
        chains.append(table[:, 2:-4])

    # NOTE(review): every row enters the statistics once, regardless of its
    # weight. If the weights are multiplicities of repeated states (as in
    # Cobaya's MCMC output -- confirm), mean/std/n_eff are computed on the
    # unweighted rows; verify that this is intended.
    n_common = min(len(chain) for chain in chains)

    return summary_calculation(chains[0][-n_common:], chains[1][-n_common:], nlike)

In [None]:
## LSST
# Values recorded by hand (presumably mean scaled_n_eff from earlier runs -- TODO confirm):
# 0.0002872206804902297 (Cobaya EH)
# 0.026257363008476475 (NUTS EH - Glamdring)

In [12]:
# Ratio of the value labelled "NUTS EH - Glamdring" to the one labelled
# "Cobaya EH" in the preceding cell's comments.
0.026257363008476475 / 0.0002872206804902297

91.41877584740861

## EMCEE

In [5]:
def emcee_stats(engine = 'jaxcosmo'):
    """Summarise the two pickled emcee runs produced with the given engine.

    Loads ``{engine}_emcee_1`` and ``{engine}_emcee_2`` through
    ``pickle_load('samples', ...)`` and feeds their ``flatchain`` arrays to
    ``summary_calculation``. The flattened chains are used as stored: no
    discard or thinning is applied here.

    The scaling denominator is the total number of rows in both flattened
    chains. Both chains must have the same number of rows, because
    ``summary_calculation`` stacks them column by column.
    """
    first_run = pickle_load('samples', f'{engine}_emcee_1')
    second_run = pickle_load('samples', f'{engine}_emcee_2')

    chain_a, chain_b = first_run.flatchain, second_run.flatchain
    total_rows = chain_a.shape[0] + chain_b.shape[0]

    return summary_calculation(chain_a, chain_b, total_rows)

## NUTS

In [5]:
def nuts_stats(engine: str = 'jaxcosmo') -> pd.DataFrame:
    """Summarise a pickled NUTS run and scale n_eff by the total step count.

    Loads the sampler stored as ``nuts_sampler_{engine}`` through
    ``pickle_load('lsst', ...)``, runs numpyro's ``summary`` on the
    chain-grouped samples of every parameter in the module-level ``KEYS``, and
    divides ``n_eff`` by the sum of the sampler's ``'num_steps'`` extra field
    over all chains and draws.

    Parameters
    ----------
    engine : str
        Suffix of the pickled sampler name, e.g. ``'jaxcosmo'`` or ``'emulator'``.

    Returns
    -------
    pd.DataFrame
        One row per parameter with columns ``r_hat``, ``n_eff``, ``mean`` and
        ``std`` (rounded to 3 decimals) plus ``scaled_n_eff``.

    Raises
    ------
    KeyError
        If ``'num_steps'`` is missing from the sampler's extra fields, or a
        name in ``KEYS`` is missing from its samples.
    """
    sampler = pickle_load('lsst', f'nuts_sampler_{engine}')

    # Total of the per-draw step counts over every chain and draw.
    step_counts = sampler.get_extra_fields(group_by_chain=True)['num_steps']
    num_steps = step_counts.sum().item()

    samples = sampler.get_samples(group_by_chain=True)
    wanted = ['r_hat', 'n_eff', 'mean', 'std']

    per_parameter = []
    for key in KEYS:
        # summary() files a bare array under 'Param:0'; re-label it with the
        # parameter name so it becomes the column header of the frame.
        stats = {key: summary(samples[key])['Param:0']}
        per_parameter.append(pd.DataFrame(stats).round(3).loc[wanted])

    # Frames are (statistic x parameter); transpose to get one row per parameter.
    record_df = pd.concat(per_parameter, axis = 1).T
    record_df['scaled_n_eff'] = record_df['n_eff'] / num_steps
    return record_df

In [7]:
# Per-parameter summary tables for each sampler (Cobaya, emcee, NUTS),
# first for the 'jaxcosmo' engine and then for the 'emulator' engine.
df_cobaya_jc, df_emcee_jc, df_nuts_jc = (
    cobaya_statistics(engine='jaxcosmo'),
    emcee_stats(engine='jaxcosmo'),
    nuts_stats(engine='jaxcosmo'),
)

df_cobaya_emu, df_emcee_emu, df_nuts_emu = (
    cobaya_statistics(engine='emulator'),
    emcee_stats(engine='emulator'),
    nuts_stats(engine='emulator'),
)

In [8]:
# NUTS, 'emulator' engine: scaled_n_eff (n_eff / total num_steps) averaged over all parameters.
df_nuts_emu['scaled_n_eff'].mean()

0.03415634511934331

In [9]:
# Cobaya, 'emulator' engine: scaled_n_eff (n_eff / summed chain weights) averaged over all parameters.
df_cobaya_emu['scaled_n_eff'].mean()

0.004761498346259732

In [10]:
# NUTS, 'jaxcosmo' engine: scaled_n_eff (n_eff / total num_steps) averaged over all parameters.
df_nuts_jc['scaled_n_eff'].mean()

0.04482185394258569

In [11]:
# Cobaya, 'jaxcosmo' engine: scaled_n_eff (n_eff / summed chain weights) averaged over all parameters.
df_cobaya_jc['scaled_n_eff'].mean()

0.004543626076053639

In [12]:
# Ratio of mean scaled_n_eff, NUTS over Cobaya, for the 'emulator' engine.
df_nuts_emu['scaled_n_eff'].mean() / df_cobaya_emu['scaled_n_eff'].mean()

7.173444709095387

In [13]:
# Ratio of mean scaled_n_eff, NUTS over Cobaya, for the 'jaxcosmo' engine.
df_nuts_jc['scaled_n_eff'].mean() / df_cobaya_jc['scaled_n_eff'].mean()

9.864776104444681