In [1]:
import numpy as np 
import pandas as pd
from numpyro.diagnostics import summary
from utils.helpers import pickle_load
import matplotlib.pylab as plt 

# Render every text element through LaTeX (needs a working TeX installation).
plt.rc('text', usetex=True)
# NOTE(review): 'family' selects the sans-serif font list, so the 'serif'
# entry (Palatino) only matters if family is switched to 'serif' — confirm intent.
plt.rc('font', family='sans-serif', serif=['Palatino'])

# Shared figure/font sizes, presumably used by plotting cells elsewhere in the notebook.
figSize, fontSize = (12, 8), 15

In [2]:
ANALYSIS = 'lsst'

# Number of tomographic bins for the 'wl' and 'gc' nuisance parameters
# (presumably weak lensing and galaxy clustering): LSST uses 5 and 10 bins,
# DES uses 4 and 5 bins.
_N_WL_BINS, _N_GC_BINS = (5, 10) if ANALYSIS == 'lsst' else (4, 5)

# Sampled parameter names, in the same order as the parameter columns of the
# Cobaya chain files: 5 cosmological parameters, then m_i and dz_wl_i (one per
# wl bin), a_ia and eta, then b_i and dz_gc_i (one per gc bin). Generated from
# the bin counts instead of two hand-maintained lists so they cannot drift apart.
KEYS = (
    ['sigma8', 'Omegac', 'Omegab', 'hubble', 'ns']
    + [f'm{i}' for i in range(1, _N_WL_BINS + 1)]
    + [f'dz_wl_{i}' for i in range(1, _N_WL_BINS + 1)]
    + ['a_ia', 'eta']
    + [f'b{i}' for i in range(1, _N_GC_BINS + 1)]
    + [f'dz_gc_{i}' for i in range(1, _N_GC_BINS + 1)]
)

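# Column layout of the LSST Cobaya chain files (CobayaLSST/{engine}_{i}/lsst.1.txt).
# cobaya_statistics keeps columns 2:-4, i.e. only the sampled parameters, which
# appear in the same order as KEYS (omegac/omegab here are Omegac/Omegab in KEYS).
# weight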
# minuslogpost          
# sigma8          
# omegac          
# omegab          
# hubble              
# ns              
# m1     
# m2              
# m3              
# m4              
# m5         
# dz_wl_1         
# dz_wl_2         
# dz_wl_3         
# dz_wl_4         
# dz_wl_5            
# a_ia             
# eta              
# b1              
# b2              
# b3              
# b4              
# b5             
# b6              
# b7              
# b8              
# b9             
# b10         
# dz_gc_1         
# dz_gc_2         
# dz_gc_3         
# dz_gc_4 
# dz_gc_5         
# dz_gc_6         
# dz_gc_7         
# dz_gc_8         
# dz_gc_9        
# dz_gc_10   
# minuslogprior 
# minuslogprior__0    
# chi2  
# chi2__LSSTlike


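# ---------------------
# Column layout of the DES Cobaya chain files
# (outputcobaya/testing/{engine}_{i}/output_prefix.1.txt).
# weight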
# minuslogpost          
# sigma8          
# omegac          
# omegab          
# hubble              
# ns              
# m1              
# m2              
# m3              
# m4         
# dz_wl_1         
# dz_wl_2         
# dz_wl_3         
# dz_wl_4            
# a_ia             
# eta              
# b1              
# b2              
# b3              
# b4              
# b5         
# dz_gc_1         
# dz_gc_2         
# dz_gc_3         
# dz_gc_4         
# dz_gc_5   
# minuslogprior 
# minuslogprior__0            
# chi2 
# chi2__my_likelihood

In [3]:
def summary_calculation(samples1: np.ndarray, samples2: np.ndarray, neval: int, keys=None) -> pd.DataFrame:
    """Convergence diagnostics for two runs of the same sampler, one row per parameter.

    Parameters
    ----------
    samples1, samples2 : np.ndarray
        2-D arrays of shape (num_draws, num_params), one per run. Column ``i``
        must hold parameter ``keys[i]``; extra trailing columns are ignored.
        Both arrays must have the same number of draws (callers trim them).
    neval : int or float
        Cost measure used to normalise the effective sample size (e.g. total
        Cobaya weight, number of emcee samples, or number of NUTS steps).
    keys : sequence of str, optional
        Parameter names in column order. Defaults to the global ``KEYS``.

    Returns
    -------
    pd.DataFrame
        Indexed by parameter name, with columns ``r_hat``, ``n_eff``, ``mean``,
        ``std`` (rounded to 3 decimals) and ``scaled_n_eff = n_eff / neval``.

    Raises
    ------
    ValueError
        If the inputs are not 2-D, have different numbers of draws, or have
        fewer columns than ``keys``.
    """
    if keys is None:
        keys = KEYS
    samples1 = np.asarray(samples1)
    samples2 = np.asarray(samples2)
    if samples1.ndim != 2 or samples2.ndim != 2:
        raise ValueError('samples1 and samples2 must be 2-D (draws, parameters) arrays')
    if samples1.shape[0] != samples2.shape[0]:
        raise ValueError(
            f'both runs must have the same number of draws, '
            f'got {samples1.shape[0]} and {samples2.shape[0]}')
    if len(keys) > min(samples1.shape[1], samples2.shape[1]):
        raise ValueError(
            f'expected at least {len(keys)} parameter columns, '
            f'got {samples1.shape[1]} and {samples2.shape[1]}')

    stat_names = ['r_hat', 'n_eff', 'mean', 'std']
    rows = {}
    for i, key in enumerate(keys):
        # Stack column i of both runs into a (2, num_draws) array; numpyro's
        # summary treats the leading axis as the chain axis and reports the
        # statistics of an unnamed array under 'Param:0'.
        chains = np.vstack([samples1[:, i], samples2[:, i]])
        stats = summary(chains)['Param:0']
        rows[key] = {name: float(stats[name]) for name in stat_names}

    # Build the whole table at once (columns fixed so an empty `keys` still works).
    record_df = pd.DataFrame.from_dict(rows, orient='index', columns=stat_names).round(3)
    record_df['scaled_n_eff'] = record_df['n_eff'] / neval
    return record_df

## Cobaya

In [4]:
def cobaya_statistics(engine = 'jaxcosmo'):
    """Convergence diagnostics for the two Cobaya runs of ``engine``.

    Loads chain 1 and chain 2, keeps only the sampled-parameter columns, trims
    both chains to a common length (keeping their last rows) and passes them
    to ``summary_calculation``, normalised by the total weight of both chains.

    Parameters
    ----------
    engine : str
        Directory prefix of the runs, e.g. 'jaxcosmo' or 'emulator'.

    Returns
    -------
    pd.DataFrame
        The table produced by ``summary_calculation``.

    Raises
    ------
    ValueError
        If a chain file does not have ``len(KEYS)`` parameter columns (e.g.
        when ``ANALYSIS`` does not match the files being read).
    """
    record_samples = []
    nlike = 0.0
    for i in (1, 2):
        if ANALYSIS != 'lsst':
            path = f'outputcobaya/testing/{engine}_{i}/output_prefix.1.txt'
        else:
            path = f'CobayaLSST/{engine}_{i}/lsst.1.txt'
        chain = np.loadtxt(path)
        # Column layout: weight, minuslogpost, [parameters in KEYS order], then
        # 4 trailing columns (minuslogprior, minuslogprior__0, chi2, chi2__LIKELIHOOD).
        weight = chain[:, 0]
        samples = chain[:, 2:-4]
        if samples.shape[1] != len(KEYS):
            raise ValueError(
                f'{path}: expected {len(KEYS)} parameter columns, got {samples.shape[1]}')
        nlike += weight.sum()
        record_samples.append(samples)

    # Keep the last rows so both chains have the same length.
    min_nsamples = min(samples.shape[0] for samples in record_samples)
    # NOTE(review): nlike sums the weights of every row of both chains, while
    # only the last min_nsamples rows of each are used below — confirm this
    # normalisation is intended.
    return summary_calculation(record_samples[0][-min_nsamples:],
                               record_samples[1][-min_nsamples:],
                               nlike)

In [5]:
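# Mean of the scaled_n_eff column (n_eff divided by each sampler's cost measure),
# recorded by hand for each analysis, sampler and engine. "EH" refers to the
# jaxcosmo engine and "Emulator" to the emulator engine.
## LSST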
# 0.0002872206804902297 (Cobaya EH)
# 0.026257363008476475 (NUTS EH - Glamdring)

# 0.00012025138105330559 (Cobaya Emulator)
# 0.027563007453589822 (NUTS emulator - Local Desktop)

## DES 
# 0.04482185394258569 (NUTS EH)
# 0.004543626076053639 (Cobaya EH)

# 0.03415634511934331 (NUTS Emulator)
# 0.004761498346259732 (Cobaya Emulator)

In [6]:
# LSST, EH (jaxcosmo engine): mean scaled_n_eff of NUTS divided by that of
# Cobaya, using the hand-recorded values above.
# NOTE(review): the NUTS value used here (0.02626) differs from
# df_nuts_jc['scaled_n_eff'].mean() computed further below (0.02752), so this
# ratio (91.4) disagrees with the dataframe-based one (95.8) — confirm which run is current.
0.026257363008476475 / 0.0002872206804902297

91.41877584740861

In [7]:
# LSST, emulator engine: mean scaled_n_eff of NUTS divided by that of Cobaya,
# using the hand-recorded values above.
0.027563007453589822 / 0.00012025138105330559

229.21156673761246

In [8]:
# DES, EH (jaxcosmo engine): mean scaled_n_eff of NUTS divided by that of
# Cobaya, using the hand-recorded values above.
0.04482185394258569 / 0.004543626076053639

9.864776104444681

In [9]:
# DES, emulator engine: mean scaled_n_eff of NUTS divided by that of Cobaya,
# using the hand-recorded values above.
0.03415634511934331 / 0.004761498346259732

7.173444709095387

## EMCEE

In [10]:
def emcee_stats(engine = 'jaxcosmo', discard: int = 0, thin: int = 1):
    """Convergence diagnostics for the two pickled emcee runs of ``engine``.

    Parameters
    ----------
    engine : str
        Prefix of the pickled samplers ('{engine}_emcee_1' / '{engine}_emcee_2').
    discard : int, optional
        Number of initial steps to drop from each run before computing the
        diagnostics (default 0, i.e. use the whole chain).
    thin : int, optional
        Keep every ``thin``-th step (default 1, i.e. no thinning).

    Returns
    -------
    pd.DataFrame
        The table produced by ``summary_calculation``.
    """
    sampler_1 = pickle_load('samples', f'{engine}_emcee_1')
    sampler_2 = pickle_load('samples', f'{engine}_emcee_2')

    # Cost measure: number of stored samples of both full runs, counted before
    # discard/thin so that dropping burn-in does not make a run look cheaper.
    nevals = sampler_1.flatchain.shape[0] + sampler_2.flatchain.shape[0]

    # With the defaults this is the same array as `flatchain`.
    samples_1 = sampler_1.get_chain(discard=discard, thin=thin, flat=True)
    samples_2 = sampler_2.get_chain(discard=discard, thin=thin, flat=True)

    return summary_calculation(samples_1, samples_2, nevals)

## NUTS

In [11]:
def nuts_stats(engine = 'jaxcosmo'):
    """Convergence diagnostics for a pickled numpyro NUTS run of ``engine``.

    Parameters
    ----------
    engine : str
        Engine name used in the pickle file name, e.g. 'jaxcosmo' or 'emulator'.

    Returns
    -------
    pd.DataFrame
        Indexed by the names in ``KEYS``, with columns ``r_hat``, ``n_eff``,
        ``mean``, ``std`` (rounded to 3 decimals) and ``scaled_n_eff``, the
        effective sample size per integration step over all chains.
    """
    if ANALYSIS == 'lsst':
        sampler = pickle_load('lsst', f'nuts_sampler_{engine}')
    else:
        sampler = pickle_load('samples', f'{engine}_nuts_small_ss_high_td')

    # Total 'num_steps' over every chain and draw. The extra fields may be JAX
    # arrays (int32 unless x64 is enabled); convert to NumPy and accumulate in
    # int64 so the total cannot overflow on long runs.
    steps_per_draw = np.asarray(sampler.get_extra_fields(group_by_chain=True)['num_steps'])
    num_steps = int(steps_per_draw.sum(dtype=np.int64))

    samples = sampler.get_samples(group_by_chain=True)

    stat_names = ['r_hat', 'n_eff', 'mean', 'std']
    rows = {}
    for key in KEYS:
        # samples[key] is grouped by chain, so numpyro's summary computes the
        # cross-chain r_hat / n_eff; an unnamed array is reported as 'Param:0'.
        stats = summary(samples[key])['Param:0']
        rows[key] = {name: float(stats[name]) for name in stat_names}

    record_df = pd.DataFrame.from_dict(rows, orient='index', columns=stat_names).round(3)
    record_df['scaled_n_eff'] = record_df['n_eff'] / num_steps
    return record_df

In [12]:
def nuts_stats_jaxcosmo():
    """Convergence diagnostics for the LSST jaxcosmo NUTS run stored as plain data.

    Reads the pickled ``info`` dict with keys 'samples' and 'steps'. Index 0
    and 1 of each are treated as the two runs, and both are passed to
    ``summary_calculation`` normalised by the total number of steps.

    Returns
    -------
    pd.DataFrame
        The table produced by ``summary_calculation``.
    """
    info = pickle_load('lsst', 'nuts_jaxcosmo_info')
    samples, steps = info['samples'], info['steps']

    # Total steps of both runs. Vectorised np.sum with an int64 accumulator
    # replaces the element-by-element builtin sum and cannot overflow a small
    # integer dtype (assumes steps holds integer step counts).
    num_steps = int(np.sum(steps[0], dtype=np.int64) + np.sum(steps[1], dtype=np.int64))

    # (num_draws, num_params) arrays with columns in KEYS order.
    samples_1 = np.column_stack([np.asarray(samples[key][0]) for key in KEYS])
    samples_2 = np.column_stack([np.asarray(samples[key][1]) for key in KEYS])
    return summary_calculation(samples_1, samples_2, num_steps)

In [13]:
def numpyro_model():
    """Empty placeholder for a model function named ``numpyro_model``; returns None.

    NOTE(review): nothing in this notebook calls it. Presumably it exists so
    that unpickling the saved numpyro samplers (which may reference a model of
    this name) succeeds — confirm before removing.
    """

In [14]:
if ANALYSIS != 'lsst':
    # DES: Cobaya, emcee and NUTS diagnostics for both engines.
    df_cobaya_jc = cobaya_statistics(engine='jaxcosmo')
    df_emcee_jc = emcee_stats(engine='jaxcosmo')
    df_nuts_jc = nuts_stats(engine='jaxcosmo')

    df_cobaya_emu = cobaya_statistics(engine='emulator')
    df_emcee_emu = emcee_stats(engine='emulator')
    df_nuts_emu = nuts_stats(engine='emulator')
else:
    # LSST: Cobaya and NUTS diagnostics only (emcee is not evaluated here).
    df_cobaya_emu = cobaya_statistics(engine='emulator')
    df_nuts_emu = nuts_stats(engine='emulator')

    df_cobaya_jc = cobaya_statistics(engine='jaxcosmo')
    # The jaxcosmo NUTS run was made on Glamdring with JAX 0.4.25, so its
    # results are stored separately and loaded by a dedicated helper.
    df_nuts_jc = nuts_stats_jaxcosmo()

In [15]:
# Mean scaled_n_eff over all parameters: NUTS, emulator engine.
df_nuts_emu.scaled_n_eff.mean()

0.02756300693233674

In [16]:
# Mean scaled_n_eff over all parameters: Cobaya, emulator engine.
df_cobaya_emu.scaled_n_eff.mean()

0.00012025138105330559

In [17]:
# Mean scaled_n_eff over all parameters: NUTS, jaxcosmo engine.
df_nuts_jc.scaled_n_eff.mean()

0.027523081065323297

In [18]:
# Mean scaled_n_eff over all parameters: Cobaya, jaxcosmo engine.
df_cobaya_jc.scaled_n_eff.mean()

0.0002872206804902297

In [19]:
# Efficiency ratio, emulator engine: NUTS mean scaled_n_eff / Cobaya mean scaled_n_eff.
df_nuts_emu.scaled_n_eff.mean() / df_cobaya_emu.scaled_n_eff.mean()

229.21156240291728

In [20]:
# Efficiency ratio, jaxcosmo engine: NUTS mean scaled_n_eff / Cobaya mean scaled_n_eff.
df_nuts_jc.scaled_n_eff.mean() / df_cobaya_jc.scaled_n_eff.mean()

95.825554825463

In [21]:
# First five parameters: NUTS, emulator engine.
df_nuts_emu.iloc[:5]

Unnamed: 0,r_hat,n_eff,mean,std,scaled_n_eff
sigma8,1.0,20873.033,0.816,0.007,0.021188
Omegac,1.0,12373.363,0.262,0.004,0.01256
Omegab,1.0,9912.097,0.052,0.002,0.010061
hubble,1.001,7967.451,0.654,0.012,0.008088
ns,1.001,13306.551,1.043,0.009,0.013507


In [22]:
# First five parameters: Cobaya, emulator engine.
df_cobaya_emu.iloc[:5]

Unnamed: 0,r_hat,n_eff,mean,std,scaled_n_eff
sigma8,1.003,473.472,0.817,0.007,6.6e-05
Omegac,1.0,709.851,0.262,0.005,9.9e-05
Omegab,1.003,970.206,0.052,0.002,0.000135
hubble,1.018,75.708,0.657,0.016,1.1e-05
ns,1.03,50.485,1.04,0.012,7e-06


In [23]:
# First five parameters: NUTS, jaxcosmo engine.
df_nuts_jc.iloc[:5]

Unnamed: 0,r_hat,n_eff,mean,std,scaled_n_eff
sigma8,1.0,17650.966,0.811,0.008,0.019791
Omegac,1.0,13293.174,0.274,0.006,0.014905
Omegab,1.0,9933.145,0.049,0.003,0.011137
hubble,1.0,8654.722,0.658,0.015,0.009704
ns,1.0,19155.291,1.027,0.007,0.021477


In [24]:
# First five parameters: Cobaya, jaxcosmo engine.
df_cobaya_jc.iloc[:5]

Unnamed: 0,r_hat,n_eff,mean,std,scaled_n_eff
sigma8,1.003,515.508,0.812,0.009,9.3e-05
Omegac,1.004,302.648,0.274,0.007,5.4e-05
Omegab,1.001,2543.359,0.049,0.003,0.000457
hubble,1.007,188.103,0.66,0.018,3.4e-05
ns,1.011,128.522,1.025,0.008,2.3e-05
