In [1]:
import os 
# Restrict JAX to the CPU backend; set in the environment before importing jax so it takes effect.
os.environ["JAX_PLATFORMS"] = "cpu"
import jax
# Enable 64-bit (float64) arrays in JAX; by default JAX uses 32-bit precision.
jax.config.update("jax_enable_x64", True)

import os, sys
import matplotlib.pyplot as plt
from timeit import timeit
from scipy.special import legendre
import numpy as np
import jax.numpy as jnp
from jax import grad, jit
from classy import Class
from pybird.correlator import Correlator
from jax import jacfwd, jacrev
import importlib 
import pybird 
from scipy.interpolate import interp1d
import jax.numpy as jnp
from train_pybird_emulators.emu_utils import emu_utils
from train_pybird_emulators.emu_utils.emu_utils import get_pgg_from_linps_and_f_and_A
from train_pybird_emulators.emu_utils.emu_utils import get_pgg_from_params

In [2]:
import h5py 

In [7]:
# Load entry '0' of the validation bank stored in lhc_pk_lins.h5 (presumably a
# linear P(k) sample, judging by the file name -- confirm against the bank generator).
# The file is opened explicitly read-only: h5py < 3.0 defaults to append mode ('a'),
# which would create or modify the file instead of just reading it.
pk_lin_bank_path = "/cluster/work/refregier/alexree/local_packages/train_pybird_emulators/src/train_pybird_emulators/data/training_data/validation_data/lhc_bank_Geff_z0p5to2p5/lhc_pk_lins.h5"
with h5py.File(pk_lin_bank_path, "r") as f:
    test = f['0'][:]

In [9]:
# Sanity check: shape of the bank entry loaded in the previous cell.
test.shape 

(200,)

In [2]:
# Passed to every pybird Correlator below as 'km', 'kr' and 'nd'
# (presumably the EFT nonlinear scales and the galaxy number density -- confirm in pybird docs).
km, kr, nd = 1.0, 1.0, 3e-4

In [3]:
# Passed to every pybird Correlator in this notebook as 'with_resum'; resummation is switched off.
resum = False

In [4]:
# Set up a baseline cosmology and its linear matter power spectrum at one redshift.
k_r = 0.7   # largest k of the linear-P(k) grid [h/Mpc]
k_l = 1e-4  # smallest k of the linear-P(k) grid [h/Mpc]
z = 0.57
# z = 2  # alternative redshift

kk = np.logspace(np.log10(k_l), np.log10(k_r), 1000)
M = Class()
cosmo = {'omega_b': 0.02235, 'omega_cdm': 0.120, 'h': 0.675, 'ln10^{10}A_s': 3.044, 'n_s': 0.965}
# cosmo = {'omega_b': 0.018, 'omega_cdm': 0.125, 'h': 0.9, 'ln10^{10}A_s': 2.044, 'n_s': 1.065}  # alternative cosmology

M.set(cosmo)
M.set({'output': 'mPk', 'P_k_max_h/Mpc': 10, 'z_max_pk': 2})
M.compute()
# CLASS's pk_lin takes k in 1/Mpc and returns Mpc^3: multiplying k by h and P by h^3
# means kk is in h/Mpc and pk_lin_0 is in (Mpc/h)^3.
pk_lin_0 = np.array([M.pk_lin(k * M.h(), z) * M.h()**3 for k in kk])

ipk_lin_0 = interp1d(kk, pk_lin_0, kind='cubic')
# Scale-independent growth factor D(z) and growth rate f(z) at the baseline redshift.
D1_0 = M.scale_independent_growth_factor(z)
f1_0 = M.scale_independent_growth_factor_f(z)
A_s = 1e-10 * np.exp(cosmo['ln10^{10}A_s'])
Omega0_m_0 = M.Omega0_m()

In [5]:
# Total matter density parameter of the baseline cosmology, as returned by CLASS's Omega0_m().
Omega0_m_0

0.3124279835390946

In [6]:
# Read the EFT parameters from a saved BOSS CMASS NGC fit file.
# The second ', \n'-separated chunk of the file, with '# ' stripped, is parsed as
# comma-separated 'key: value' pairs into a {name: float} dict.
outdir = "/cluster/work/refregier/alexree/local_packages/pybird_emu/data/eftboss/out"
eft_fit_path = os.path.join(outdir, 'fit_boss_onesky_pk_wc_cmass_ngc_l0.dat')
with open(eft_fit_path) as f:
    data_file = f.read()

eft_params_str = data_file.split(', \n')[1].replace("# ", "")
eft_params = dict(
    (key, float(value))
    for key, value in (pair.split(': ') for pair in eft_params_str.split(', '))
)

In [7]:
# Benchmark: full (non-emulated) pybird multipoles for the baseline cosmology,
# evaluated with the EFT parameters loaded above.
bird_settings_0 = {
    'output': 'bPk',
    'multipole': 3,
    'kmax': 0.4,
    'fftaccboost': 2,  # boosting the FFTLog precision (slower, but ~0.1% more precise -> let's emulate this)
    'with_resum': resum,
    'with_exact_time': True,
    'with_time': False,  # test without specifying this too
    'km': km,
    'kr': kr,
    'nd': nd,
    'eft_basis': 'eftoflss',
    'with_stoch': True,
}
linear_inputs_0 = {
    'kk': kk,
    'pk_lin': pk_lin_0,
    'D': D1_0,
    'f': f1_0,
    'z': np.array(z),
    'Omega0_m': Omega0_m_0,
}

N_bird = Correlator()
N_bird.set(bird_settings_0)
N_bird.compute(linear_inputs_0, do_core=True, do_survey_specific=True)
bpk_benchmark_0 = N_bird.get(eft_params)

loading matrices!
setting EdS time approximation


In [9]:
# Validation set: run the full (non-emulated) pybird pipeline for a random subset
# of a 5-parameter cosmology grid, each cosmology at a random redshift in [0, 2).
omega_cdm_range = np.linspace(0.1, 0.13, 5)
omega_b_range = np.linspace(0.02, 0.03, 5)
h_range = np.linspace(0.4, 1.2, 5)
ln10A_s_range = np.linspace(1, 5, 5)
n_s_range = np.linspace(0.9, 1.1, 5)

# Every combination of the ranges above: one row per cosmology, columns in the
# order (omega_cdm, omega_b, h, ln10A_s, n_s).
param_grid = np.array(np.meshgrid(omega_cdm_range, omega_b_range, h_range, ln10A_s_range, n_s_range)).T.reshape(-1, 5)
print("param grid shape", param_grid.shape)

# Seeded so the validation set is reproducible under Restart & Run All.
n_validation = 10
validation_seed = 0
rng = np.random.default_rng(validation_seed)
# Draw distinct rows from the WHOLE grid. Drawing from range(100) only reaches rows
# with ln10A_s = 1 and n_s = 0.9 (given the meshgrid/.T/reshape ordering) and can repeat rows.
inds = rng.choice(len(param_grid), size=n_validation, replace=False)

# The k grid is the same for every cosmology, so build it once.
kk = np.logspace(np.log10(k_l), np.log10(k_r), 1000)

results = []
for cosmo_params in param_grid[inds]:
    omega_cdm, omega_b, h, ln10A_s, n_s = cosmo_params
    cosmo = {'omega_b': omega_b, 'omega_cdm': omega_cdm, 'h': h, 'ln10^{10}A_s': ln10A_s, 'n_s': n_s}

    # Random redshift below z_max_pk = 2. Kept in z_sample (not z) so the baseline
    # redshift `z` used by the baseline emulator cell is not overwritten.
    z_sample = rng.uniform(0.0, 2.0)
    print("z", z_sample)

    # CLASS: linear P(k) [(Mpc/h)^3] on kk [h/Mpc] at z_sample, plus growth quantities.
    M = Class()
    M.set(cosmo)
    M.set({'output': 'mPk', 'P_k_max_h/Mpc': 10, 'z_max_pk': 2})
    try:
        M.compute()
        pk_lin = np.array([M.pk_lin(k * M.h(), z_sample) * M.h()**3 for k in kk])
        D1 = M.scale_independent_growth_factor(z_sample)
        f1 = M.scale_independent_growth_factor_f(z_sample)
        Omega0_m = M.Omega0_m()
    finally:
        # Free CLASS memory even if compute() raises.
        M.struct_cleanup()
        M.empty()

    # Full pybird in time-unspecified mode (the same setup used to build the emulator training data).
    N_bird = Correlator()
    N_bird.set(
        {
            "output": "bPk",
            "multipole": 3,
            "kmax": 0.4,  # new kmax so dont have to filter out the large ks!
            "fftaccboost": 2,
            "with_resum": resum,  # add resum in by hand below
            "with_exact_time": True,
            "with_time": False,  # time unspecified
            "with_emu": False,
            "km": km,
            "kr": kr,
            "nd": nd,
            "eft_basis": "eftoflss",
            "with_stoch": True,
        }
    )
    N_bird.compute({'kk': kk, 'pk_lin': pk_lin, 'D': D1, 'f': f1, 'z': z_sample, 'Omega0_m': Omega0_m},
                   do_core=True, do_survey_specific=True)
    bpk_benchmark = N_bird.get(eft_params)

    # Stored vector: [omega_cdm, omega_b, h, ln10A_s, n_s, Omega0_m, f, D, z];
    # the emulator loop reads back the last four entries.
    params_full = np.hstack((cosmo_params, [Omega0_m, f1, D1, z_sample]))
    results.append({'params': params_full, 'pk_lin': pk_lin, 'bpk_benchmark': bpk_benchmark})

param gird shape (3125, 5)
z 0.5429550696658834
loading matrices!
setting EdS time approximation
z 1.7395611013101886
loading matrices!
setting EdS time approximation
z 0.9673748664101047
loading matrices!
setting EdS time approximation
z 0.420186368771186
loading matrices!
setting EdS time approximation
z 0.04421050001994842
loading matrices!
setting EdS time approximation
z 0.1739488572447807
loading matrices!
setting EdS time approximation
z 1.5930557427563385
loading matrices!
setting EdS time approximation
z 1.6922710820712636
loading matrices!
setting EdS time approximation
z 0.4692795677241508
loading matrices!
setting EdS time approximation
z 1.91957103622271
loading matrices!
setting EdS time approximation


In [None]:
from pybird import config
# Turn on pybird's JAX mode before building the emulator-backed Correlator in the next cells.
# NOTE(review): `Correlator` was imported in the first cell, before this flag is set --
# confirm pybird reads the flag at Correlator construction/compute time, not at import time.
config.set_jax_enabled(True) # Enable JAX by setting the config Class
import jax.numpy as jnp 

In [None]:

# Emulator-backed pybird ("with_emu": True) for the baseline cosmology, to compare
# against bpk_benchmark_0. Settings match the benchmark apart from the emulator switch.
N_emu = Correlator()
N_emu.set(
    {
        "output": "bPk",
        "multipole": 3,
        "kmax": 0.4,  # new kmax so dont have to filter out the large ks!
        "fftaccboost": 2,
        "with_resum": resum,  # add resum in by hand below
        "with_exact_time": True,
        "with_time": False,  # time unspecified
        "with_emu": True,
        "km": km,
        "kr": kr,
        "nd": nd,
        "eft_basis": "eftoflss",
        "with_stoch": True,
    }
)

# The inputs are the baseline (_0) quantities, so `z` must still hold the baseline
# redshift here. NOTE(review): any later cell that reassigns the global `z`
# (e.g. a loop over random redshifts) breaks this cell on re-run.
N_emu.compute({'kk': jnp.array(kk), 'pk_lin': jnp.array(pk_lin_0), 'D': D1_0, 'f': f1_0, 'z': jnp.array(z), 'Omega0_m': Omega0_m_0},
              do_core=True, do_survey_specific=True)


In [None]:
# Emulated multipoles for the baseline cosmology, using the same EFT parameters as bpk_benchmark_0.
bpk_recovered_emu = N_emu.get(eft_params)

In [None]:
# Emulator / benchmark ratio of multipole index 0 (presumably the monopole) for the
# baseline cosmology; ideally ~1 in every k bin.
fig, ax = plt.subplots()
ax.plot(bpk_recovered_emu[0] / bpk_benchmark_0[0])
ax.axhline(1.0, color='k', lw=0.5)
ax.set(xlabel='k-bin index', ylabel=r'$P_0^{\rm emu} / P_0^{\rm bench}$',
       title='Baseline cosmology: emulator vs. full pybird');

### Test the emulator over a range of cosmologies

In [None]:
# Stored vector for the first validation cosmology: [omega_cdm, omega_b, h, ln10A_s, n_s, Omega0_m, f, D, z].
results[0]["params"]

In [None]:

# Emulator predictions for every validation cosmology, in the same order as `results`.
# Loop variables carry an _i suffix so they do not overwrite the baseline globals
# (z, f1, D1, pk_lin) that the baseline emulator cell reads on re-run.
emu_settings = {
    "output": "bPk",
    "multipole": 3,
    "kmax": 0.4,  # new kmax so dont have to filter out the large ks!
    "fftaccboost": 2,
    "with_resum": resum,  # add resum in by hand below
    "with_exact_time": True,
    "with_time": False,  # time unspecified
    "with_emu": True,
    "km": km,
    "kr": kr,
    "nd": nd,
    "eft_basis": "eftoflss",
    "with_stoch": True,
}

emu_out = []
for results_dict in results:
    pk_lin_i = results_dict["pk_lin"]
    # The last four stored entries are (Omega0_m, f, D, z).
    Omega_m_i, f1_i, D1_i, z_i = results_dict["params"][-4:]
    print(Omega_m_i, f1_i, D1_i, z_i)

    N_emu = Correlator()
    # Hand each Correlator its own copy so nothing it does to the dict leaks into the next iteration.
    N_emu.set(dict(emu_settings))
    N_emu.compute({'kk': jnp.array(kk), 'pk_lin': jnp.array(pk_lin_i), 'D': D1_i, 'f': f1_i, 'z': jnp.array(z_i), 'Omega0_m': Omega_m_i},
                  do_core=True, do_survey_specific=True)

    emu_out.append(N_emu.get(eft_params))

In [None]:
# Emulator / benchmark ratio of multipole index 0 for each validation cosmology,
# in the same orientation (emu / bench) as the baseline comparison plot.
assert len(results) == len(emu_out), "run the emulator loop for every validation cosmology first"

fig, ax = plt.subplots()
for res, bpk_emu in zip(results, emu_out):
    z_val = res['params'][-1]  # last stored entry is the redshift
    ax.plot(bpk_emu[0] / res['bpk_benchmark'][0], label=f'z = {z_val:.2f}')
ax.axhline(1.0, color='k', lw=0.5)
ax.set(xlabel='k-bin index', ylabel=r'$P_0^{\rm emu} / P_0^{\rm bench}$',
       title='Validation cosmologies: emulator vs. full pybird')
ax.legend(fontsize='small');

In [None]:
# Shape comparison on the shared grid kk [h/Mpc]: each validation linear P(k)
# normalised to its maximum (faint lines), against the normalised baseline (red dashed).
fig, ax = plt.subplots()
for res in results:
    ax.loglog(kk, res["pk_lin"] / np.amax(res["pk_lin"]), alpha=0.1)
# The baseline is drawn once, outside the loop, on top of the validation curves.
ax.loglog(kk, pk_lin_0 / np.amax(pk_lin_0), 'r--', label='baseline')
ax.set(xlabel=r'$k\ [h\,{\rm Mpc}^{-1}]$', ylabel=r'$P_{\rm lin}(k)\,/\,\max P_{\rm lin}$')
ax.legend();