# Pickle AbacusSummit Pk mocks for Barry 

In [None]:
# Import the necessary packages, set up the fiducial cosmology and evaluate the DESI template power spectra
import os
import pickle
import numpy as np
import scipy as sp
import pandas as pd
from astropy.io import ascii
import matplotlib.pyplot as plt
from scipy.interpolate import splrep, splev
from cosmoprimo import PowerSpectrumBAOFilter
from cosmoprimo.fiducial import DESI
from pypower import BaseMatrix, CatalogFFTPower, CatalogFFTCorr, PowerSpectrumMultipoles, PowerSpectrumSmoothWindow, PowerSpectrumSmoothWindowMatrix, PowerSpectrumOddWideAngleMatrix, setup_logging
from pycorr import TwoPointCorrelationFunction, project_to_multipoles

# Fiducial DESI cosmology; print a few derived parameters as a quick sanity check.
cosmo = DESI()
h2 = cosmo["h"]**2
print(cosmo["Omega_b"]*h2, cosmo["Omega_cdm"]*h2, cosmo["Omega_m"]*h2 - cosmo["Omega_b"]*h2)
print(cosmo["A_s"], cosmo["n_s"], cosmo["tau_reio"])
print(np.sum(cosmo["m_ncdm"]))

# Evaluate the default DESI template on a log-spaced k grid.
# NOTE(review): nothing in this cell writes pkv / pksmv to disk, despite the
# "save" wording used historically -- confirm whether a save step is missing.
k_min, k_max, k_num = 1e-4, 5, 2000
kl = np.logspace(np.log(k_min), np.log(k_max), k_num, base=np.e)

# Linear power spectrum interpolator at z=0 and its values on the grid.
pkz = cosmo.get_fourier().pk_interpolator()
pk = pkz.to_1d(z=0)
pkv = pk(kl)

# Smooth counterpart from the 'wallish2018' BAO filter, evaluated on the same grid.
pknow = PowerSpectrumBAOFilter(pk, engine='wallish2018').smooth_pk_interpolator()
pksmv = pknow(kl)

In [None]:
# Useful utility function to collate some Pk data
def collect_pk_data(pre_files, post_files, pre_cov_files, post_cov_files, pre_files_name, post_files_name, pre_cov_name, post_cov_name, zs, reconsmooth, mocks, rpcut, imaging):
    """Gather mocks, covariances and window matrices into one dict and pickle it.

    Parameters
    ----------
    pre_files, post_files : str or None
        Directories holding the pre-/post-recon mock measurements.
    pre_cov_files, post_cov_files : str or None
        Directories holding the pre-/post-recon covariance files.
        ``post_cov_files`` is also the directory handed to ``getwin``.
    pre_files_name, post_files_name : str or None
        File-name stems of the mock measurements (mock index and extension are
        appended by ``get_pk``). ``None`` skips that set of mocks.
    pre_cov_name, post_cov_name : str
        File names of the covariance matrices.
    zs : sequence of two floats
        Redshift bin edges; their mean is stored as the effective redshift.
    reconsmooth : number
        Reconstruction smoothing scale, stored and used in the output names.
    mocks : iterable of int
        Mock indices to load.
    rpcut : bool
        Whether the "rpcut2.5" tag is appended to the dataset name.
    imaging : str
        Imaging-systematics label appended to the dataset name.

    Returns
    -------
    dict
        The dictionary that was pickled.

    Raises
    ------
    ValueError
        If neither ``pre_files_name`` nor ``post_files_name`` is given.
    """

    pre_data, post_data = None, None

    # Use the function's own arguments here: the previous code read the
    # module-level names `pre_name` / `post_name`, which only exist when the
    # driver cell happens to define them.
    pre_mocks = get_pk(pre_files, pre_files_name, mocks, rpcut, imaging) if pre_files_name is not None else None
    post_mocks = get_pk(post_files, post_files_name, mocks, rpcut, imaging) if post_files_name is not None else None

    pre_cov = get_pk_cov(pre_cov_files, pre_cov_name, rpcut, imaging) if pre_cov_files is not None else None
    post_cov = get_pk_cov(post_cov_files, post_cov_name, rpcut, imaging) if post_cov_files is not None else None

    # The k binning and the naming come from the post-recon inputs when they
    # exist, and fall back to the pre-recon ones otherwise.
    ref_mocks = post_mocks if post_mocks is not None else pre_mocks
    ref_name = post_files_name if post_files_name is not None else pre_files_name
    if ref_mocks is None:
        raise ValueError("collect_pk_data needs at least one of pre_files_name / post_files_name")
    kout = ref_mocks[0]["k"].to_numpy()

    if post_files is not None:
        winmat, wam_reshape = getwin(kout, post_cov_files, ref_name, rpcut, imaging)
    else:
        winmat, wam_reshape = getwin_dummy(kout)

    rp = f" {imaging} rpcut2.5" if rpcut else f" {imaging}"
    # Everything after the first two "_"-separated tokens of the file stem.
    tag = ("_").join(ref_name.split("_")[2:])

    split = {
        "n_data": 1,
        "pre-recon data": pre_data,
        "pre-recon cov": pre_cov,
        "post-recon data": post_data,
        "post-recon cov": post_cov,
        "pre-recon mocks": pre_mocks,
        "post-recon mocks": post_mocks,
        "cosmology": {
            "om": cosmo["Omega_m"],
            "h0": cosmo["h"],
            "z": (zs[1]+zs[0])/2.0,
            "ob": cosmo["Omega_b"],
            "ns": cosmo["n_s"],
            "mnu": np.sum(cosmo["m_ncdm"]),
            "reconsmoothscale": reconsmooth,
        },
        "name": "DESI SecondGen " + f"sm{reconsmooth} " + tag + rp,
        "winfit": winmat,
        "winpk": None,  # We can set this to None; Barry will set it to zeroes given the length of the data vector.
        "m_mat": wam_reshape,
    }

    outfile = f"/global/cfs/cdirs/desi/users/chowlett/barry_inputs/DESI_SecondGen_grid003_pickledbyAW_sm{reconsmooth}_" + tag.lower() + ("_").join(rp.split(" ")) + "_pk.pkl"
    with open(outfile, "wb") as f:
        pickle.dump(split, f)

    return split

# Power Spectrum
def get_pk(loc, name, mocks, rpcut, imaging):
    """Load the power spectrum multipoles of every requested mock.

    Each mock is read from ``loc + name + <two-digit mock index> + ".npy"``.
    ``rpcut`` and ``imaging`` are accepted for signature parity with the other
    loaders but do not affect which file is read.

    Returns a list with one DataFrame per mock, with columns
    ``k, pk0, pk1, pk2, pk3, pk4``; the odd multipoles are filled with zeros.
    """

    # Overwrite the <k> with the bin centres as we now use a binning matrix to correct to <P(k)>
    bin_centres = np.linspace(0.0, 0.4, 80, endpoint=False) + 0.0025
    out_columns = ["k", "pk0", "pk1", "pk2", "pk3", "pk4"]

    pks = []
    for mock in mocks:
        poles = PowerSpectrumMultipoles.load(f"{loc}{name}{mock:02}.npy")
        # NOTE(review): presumably keeps the first 400 k-bins rebinned by a
        # factor of 5, modifying `poles` in place -- confirm against pypower.
        poles.slice(slice(0, 400, 5))

        # Stack k with the ell = 0, 2, 4 measurements and keep the real part.
        table = np.vstack(poles(ell=[0, 2, 4], return_k=True)).T.real
        df = pd.DataFrame(table, columns=["k", "pk0", "pk2", "pk4"])

        nrows = len(df["k"])
        df["pk1"] = np.zeros(nrows)
        df["pk3"] = np.zeros(nrows)
        df["nk"] = poles.nmodes  # attached to the frame but not part of the returned columns
        df["k"] = bin_centres

        pks.append(df[out_columns])

    return pks

# Power Spectrum covariance matrix.
def get_pk_cov(loc, name, rpcut, imaging):
    """Read a 3-multipole covariance and embed it in a 5-multipole matrix.

    The file ``loc + name`` is read as a whitespace-separated table (lines
    starting with ``#`` are ignored). Its number of rows divided by three is
    taken as the number of k bins per multipole, and its three blocks are
    treated as the even multipoles. The returned matrix starts as an identity
    of size ``5 * nks`` and has the input blocks copied to block positions
    0, 2 and 4, so the odd-multipole blocks stay as identity.

    Parameters
    ----------
    loc : str
        Directory prefix (concatenated directly with ``name``).
    name : str
        File name of the covariance matrix.
    rpcut, imaging
        Accepted for signature parity with the other loaders; unused here.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(5 * nks, 5 * nks)``.
    """

    infile = loc + name

    # sep=r"\s+" replaces the deprecated delim_whitespace=True keyword.
    cov_input = pd.read_csv(infile, comment="#", sep=r"\s+", header=None).to_numpy()
    nks = int(np.shape(cov_input)[0] / 3)

    cov = np.eye(5 * nks)
    for row in range(3):
        rows_in = slice(row * nks, (row + 1) * nks)
        rows_out = slice(2 * row * nks, (2 * row + 1) * nks)
        for col in range(3):
            cols_in = slice(col * nks, (col + 1) * nks)
            cols_out = slice(2 * col * nks, (2 * col + 1) * nks)
            cov[rows_out, cols_out] = cov_input[rows_in, cols_in]

    # Check the covariance matrix is invertible
    v = np.diag(cov @ np.linalg.inv(cov))
    if not np.all(np.isclose(v, 1)):
        print("ERROR, setting an inappropriate covariance matrix that is almost singular!!!!")

    return cov

# Reads in the window matrix and builds the matching wide-angle matrix
def getwin(ks, loc, name, rpcut, imaging):
    """Load a window matrix from disk and reshape it for Barry.

    The redshift tags are the first two "_"-separated tokens of ``name`` that
    contain "zm"; they select the file under ``loc + 'windows/'``. ``rpcut``
    and ``imaging`` are accepted but not used.

    Returns ``(winmat, matrix)``: ``winmat`` is ``{1: {...}}`` holding the
    input/output k values, a zero integral-constraint vector and the padded
    window matrix; ``matrix`` is the block matrix placing three input
    multipoles in slots 0, 2 and 4 of six.
    """

    z_tags = [token for token in name.split('_') if 'zm' in token]
    zlow, zhigh = z_tags[0], z_tags[1]

    infile = loc + 'windows/' + f'wmatrix_smooth_GCcomb_{zlow}_{zhigh}_combined_Grid003.npy'
    wam = BaseMatrix.load(infile)

    # Trim the output axis to a multiple of 5 so that it can be rebinned by 5.
    wam = wam[:, :int(len(wam.xout[0]) // 5 * 5)]
    wam.rebin_x(factorout=5, factorin=4)
    wam.select_x(xinlim=[np.min(wam.xin[0]), 0.4], xoutlim=[np.min(wam.xout[0]), 0.4])

    k_in = wam.xin[0]
    kout = wam.xout[0] if ks is None else ks
    n_in, n_out = len(k_in), len(kout)

    # This window function only has even multipoles as outputs and includes wide angle effects, so pad it with
    # zeros where the odd multipoles would be so Barry is happy, then create a dummy wide angle matrix.
    w_transform = np.zeros((5 * n_out, 6 * n_in))
    for j, out_block in enumerate(np.hsplit(wam.value, 3)):
        row0 = 2 * j * n_out
        for i in range(3):
            col0 = 2 * i * n_in
            w_transform[row0 : row0 + n_out, col0 : col0 + n_in] = out_block[i * n_in : (i + 1) * n_in, ].T

    # Stand-in for the conversion matrix M from Beutler 2019: identity blocks that copy the three
    # input multipoles into the even slots and leave the odd slots empty.
    matrix = np.zeros((6 * n_in, 3 * n_in))
    for i in range(3):
        matrix[2 * i * n_in : (2 * i + 1) * n_in, i * n_in : (i + 1) * n_in] = np.eye(n_in)

    # Visual sanity checks of the padded window, and of its product with the dummy matrix.
    plt.imshow(np.log10(np.fabs(w_transform)), aspect='auto')
    plt.show()

    plt.imshow((w_transform @ matrix).T, aspect='auto')
    plt.show()

    res = {"w_ks_input": k_in, "w_k0_scale": np.zeros(n_in), "w_transform": w_transform, "w_ks_output": kout}
    winmat = {1: res}   # Step size is one, but we could modify this to contain other stepsizes too.

    # Wide-angle effects are already in the window matrix, so the identity-block `matrix` is returned alongside it.
    return winmat, matrix

# Window function matrix. The window functions are stored in a dictionary of 'step sizes' i.e., how many bins get stuck together relative to the
# pk measurements so that we can rebin the P(k) at run time if required. Each step size is a dictionary with:
#    the input and output k binning (w_ks_input, w_ks_output), the window function matrix (w_transform) and integral constraint (w_k0_scale).
# The window function assumes 6 input and 5 output multipoles. For cubic sims, we can set the integral constraint to zero and window matrix to a binning matrix, as is done here.
def getwin_dummy(ks):
    """Build a binning-only window matrix and an identity-block wide-angle matrix.

    Parameters
    ----------
    ks : array_like
        Output bin centres; needs at least two entries. The bin width is taken
        from the first two entries and applied to every bin.

    Returns
    -------
    tuple
        ``({1: res}, matrix)``. ``res`` holds ``w_ks_input`` (500 log-spaced
        values from 1e-3 to 0.5), ``w_k0_scale`` (zeros), ``w_transform`` of
        shape ``(5 * len(ks), 6 * 500)`` and ``w_ks_output``. ``matrix`` has
        shape ``(6 * 500, 3 * 500)``.
    """

    ks = np.asarray(ks)
    dk = ks[1] - ks[0]
    ks_input = np.logspace(-3.0, np.log10(0.5), 500)
    n_out, n_in = ks.size, ks_input.size

    # np.trapz is deprecated in NumPy 2.x in favour of np.trapezoid; use whichever exists.
    trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz

    # The fine integration grid is the same for every basis vector, so build it
    # once: row i holds 100 points spanning output bin i.
    k_lo = ks - dk / 2
    k_hi = ks + dk / 2
    k_fine = np.linspace(k_lo, k_hi, 100, axis=-1)

    binmat = np.zeros((n_out, n_in))
    for ii in range(n_in):

        # Spline through the ii-th unit basis vector on the input grid
        pkvec = np.zeros_like(ks_input)
        pkvec[ii] = 1
        pkvec_spline = splrep(ks_input, pkvec)

        # k^2-weighted average of the spline over each output bin
        # (ext=3 returns the boundary value outside the input range).
        integrand = k_fine**2 * splev(k_fine, pkvec_spline, ext=3)
        binmat[:, ii] = trapezoid(integrand, x=k_fine, axis=-1) * 3 / (k_hi**3 - k_lo**3)

    # Same binning matrix on the diagonal for the 5 output multipoles; the 6th input block stays zero.
    w_transform = np.zeros((5 * n_out, 6 * n_in))
    for i in range(5):
        w_transform[i * n_out : (i + 1) * n_out, i * n_in : (i + 1) * n_in] = binmat

    # The conversion matrix M from Beutler 2019. Used to compute the odd multipole models given the even multipoles. In the absence of wide angle effects, or if we don't care about
    # the odd multipoles, we can set this to a block matrix with identity matrices in the appropriate places, as is done here.
    matrix = np.zeros((6 * n_in, 3 * n_in))
    for i in range(3):
        matrix[2 * i * n_in : (2 * i + 1) * n_in, i * n_in : (i + 1) * n_in] = np.eye(n_in)

    res = {"w_ks_input": ks_input, "w_k0_scale": np.zeros(n_out), "w_transform": w_transform, "w_ks_output": ks}
    return {1: res}, matrix  # Step size is one

# Plot the power spectrum multipoles, for sanity checking
def _plot_pk_panel(mocks, cov, title):
    """Draw one figure: mock-mean k*P_ell(k) with error bars plus every individual mock."""

    k = mocks[0]["k"]
    nk = len(k)
    nmocks = len(mocks)
    sigma = np.sqrt(np.diag(cov))

    # A covariance with five multipole blocks (as built by get_pk_cov) stores
    # ell = 0, 2, 4 in blocks 0, 2, 4; a three-block one stores them contiguously.
    stride = 2 if len(sigma) == 5 * nk else 1

    colours = ["r", "b", "g"]
    labels = [r"$P_{0}(k)$", r"$P_{2}(k)$", r"$P_{4}(k)$"]
    for m, pk in enumerate(["pk0", "pk2", "pk4"]):
        start = stride * m * nk
        plt.errorbar(
            k,
            k * np.mean([mock[pk] for mock in mocks], axis=0),
            yerr=k * sigma[start : start + nk],
            marker="o",
            ls="None",
            c=colours[m],
            label=labels[m],
        )
        # Faint line per mock; opacity shrinks as the number of mocks grows.
        for mock in mocks:
            plt.errorbar(k, k * mock[pk], marker="None", ls="-", c='k', alpha=1.0 / nmocks**(3.0/4.0))
    plt.xlabel(r"$k$")
    plt.ylabel(r"$k\,\times pk(k)$")
    plt.title(title)
    plt.legend(loc='upper right')
    plt.show()


def plot_pk(split, pre=True, post=True):
    """Plot the pre- and/or post-reconstruction multipoles stored in ``split``.

    Parameters
    ----------
    split : dict
        Dictionary with the keys "pre-recon mocks", "pre-recon cov",
        "post-recon mocks", "post-recon cov" and "name".
    pre, post : bool
        Which panels to draw. Each panel takes its k values and mock count
        from its own list of mocks.
    """

    if pre:
        _plot_pk_panel(split["pre-recon mocks"], split["pre-recon cov"], split["name"] + " Prerecon")

    if post:
        _plot_pk_panel(split["post-recon mocks"], split["post-recon cov"], split["name"] + " Postrecon")

In [None]:
# The catalogue version
import os
version = 1.2
ffa = "altmtl"               # Flavour of fibre assignment. Can be "ffa" for fast fiber assign, or "complete"
rpcut = False             # Whether or not to include the rpcut
imaging = "default_FKP"   # What form of imaging systematics to use. Can be "default_FKP", "default_FKP_addSN", or "default_FKP_addRF"
# NOTE(review): `version` and `ffa` are not referenced below (the path hard-codes "SecondGen_ffa") -- confirm intent.

# Every tracer we have, mapped to its list of [zmin, zmax] redshift bins.
tracers = {
    'BGS': [[0.1, 0.4]],
    'LRG': [[0.4, 0.6], [0.6, 0.8], [0.8, 1.1]],
    'ELG_LOP': [[0.8, 1.1], [1.1, 1.6]],
    'QSO': [[0.8, 2.1]],
}

# [first, last) range of mock indices to read for each tracer.
# While the mocks are still being processed, this allows us to skip over the missing entries
nmocks = {'BGS': [0, 25], 'LRG': [0, 25], 'ELG_LOP': [0, 25], 'QSO': [0, 25]}

# How reconstruction was run on each tracer: [smoothing scale, type of recon].
# NOTE(review): an earlier comment said QSO has no recon and should be None, but an entry is given here.
# NOTE(review): LRG and ELG_LOP list a smoothing scale of 10 while their file suffix below says "sm15" -- confirm which is right.
recon = {
    'BGS': [15, "IFTrecsym"],
    'LRG': [10, "IFTrecsym"],
    'ELG_LOP': [10, "IFTrecsym"],
    'QSO': [30, "IFTrecsym"],
}

# Reconstruction tag embedded in the post-recon file names of each tracer.
suffix = {
    'BGS': 'ifft_cellsize4.0_sm15_f0.665_b1.77_recsym_',
    'LRG': 'ifft_cellsize4.0_sm15_f0.823_b2.04_recsym_',
    'ELG_LOP': 'ifft_cellsize4.0_sm15_f0.893_b1.20_recsym_',
    'QSO': 'ifft_cellsize4.0_sm30_f0.944_b2.64_recsym_',
}

basepath = "/global/cfs/cdirs/desicollab/users/alexpzfz/KP4/fiducial_cosmo/CutSky/Pk/"

for t, zbins in tracers.items():
    for zs in zbins:

        # Measurements and covariances all live in the same per-tracer directory.
        pre_files = basepath + t + "/AbacusSummit_base_c000_SecondGen_ffa/"
        post_files = pre_files
        pre_cov_files = pre_files
        post_cov_files = pre_files

        ztag = f"{t}_GCcomb_zmin{zs[0]}_zmax{zs[1]}"

        pre_name = f"Pk_cutsky_{ztag}_Grid003_ph0"
        post_name = f"Pk_cutsky_{ztag}_{suffix[t]}Grid003_ph0"

        pre_cov_name = f"covariance_{ztag}_Grid003.txt"
        post_cov_name = f"covariance_{ztag}_{suffix[t]}Grid003.txt"

        data = collect_pk_data(
            pre_files,
            post_files,
            pre_cov_files,
            post_cov_files,
            pre_name,
            post_name,
            pre_cov_name,
            post_cov_name,
            zs,
            recon[t][0],
            range(*nmocks[t]),
            rpcut,
            imaging,
        )
        # Plot the data to check things
        plot_pk(data, post=post_name is not None, pre=pre_name is not None)
            
                    