In [1]:
import opencosmo as oc
import numpy as np
import pandas as pd
from scipy import interpolate
from astropy import table
import h5py
from pathlib import Path
from astropy.cosmology import LambdaCDM
from scipy.interpolate import interp1d
from diffsky.experimental import lc_utils
from diffsky.data_loaders.hacc_utils import lightcone_utils
import jax.random as jran


In [21]:
class Desi_Selector:
    """Derive per-redshift-bin selection thresholds on a simulated catalogue.

    For a given DESI tracer the class reads a tabulated n(z), re-bins it onto
    a finer redshift grid, compares it with the number of simulated objects
    per unit ``sim_area`` in each bin, and converts the resulting ratio into a
    percentile threshold on one column of the simulated catalogue.

    Only the ``'qso'`` tracer has a complete re-binning implementation;
    ``produce_mock`` is still a stub.

    Parameters
    ----------
    desi_tracer : str
        One of ``'bgs'``, ``'lrg'``, ``'elg'``, ``'qso'``.
    path_desi_tracer : str or path-like
        ECSV table with the tracer n(z); read with ``astropy.table``.
    path_sim : str or path-like
        Directory that contains the simulated catalogues.
    sfh_model : str
        Key selecting the simulated catalogue: ``'tng'``, ``'um'`` or ``'gal'``.
    path_threshold : str or path-like
        Location of the threshold catalogue (only used by
        ``load_threshold_column``).
    sim_area : float, optional
        Divisor applied to the simulated counts per redshift bin.
        NOTE(review): presumably the light-cone area in deg^2 -- confirm.
    z_range : sequence of two floats, optional
        Lower and upper redshift kept when loading the simulation.
    z_grid_points : int, optional
        Number of edges of the fine redshift grid (``z_grid_points - 1`` bins).

    Raises
    ------
    ValueError
        If ``desi_tracer`` is not one of the supported tracers.
    """

    # Column of the simulated catalogue on which the threshold is computed.
    _THRESHOLD_COLUMNS = {
        'bgs': 'lsst_g',
        'lrg': 'log_halo_mass',
        'elg': 'log_sfr',
        'qso': 'black_hole_mass',
    }

    # Columns requested from the simulated catalogue for each tracer.
    _SIM_COLUMNS = {
        'bgs': ['ra', 'dec', 'redshift_true', 'lsst_g'],
        'lrg': ['ra', 'dec', 'redshift_true', 'logmp_obs'],
        'elg': ['ra', 'dec', 'redshift_true', 'logsm_obs', 'logssfr_obs'],
        'qso': ['ra', 'dec', 'redshift_true', 'black_hole_mass',
                'black_hole_eddington_ratio'],
    }

    # Catalogue name for each supported ``sfh_model`` key.
    _SFH_MODEL_FILES = {
        'tng': 'tng_2025_11_06',
        'um': 'smdpl_dr1_2025_11_07',
        'gal': 'galacticus_in_plus_ex_situ_2025_11_10',
    }

    # Redshift width of the bins of the tabulated QSO n(z).
    _QSO_DZ = 0.050

    def __init__(self,
                 desi_tracer,
                 path_desi_tracer,
                 path_sim,
                 sfh_model,
                 path_threshold,
                 sim_area=1121,
                 z_range=[0, 2],
                 z_grid_points=481,
                 ):

        if desi_tracer not in self._THRESHOLD_COLUMNS:
            raise ValueError(
                f"Unknown desi_tracer {desi_tracer!r}; expected one of "
                f"{sorted(self._THRESHOLD_COLUMNS)}"
            )

        self.desi_tracer = desi_tracer
        self.path_desi_tracer = path_desi_tracer
        self.path_sim = path_sim
        self.sfh_model = sfh_model
        self.path_threshold = path_threshold
        self.sim_area = sim_area
        # Copy so that instances never share (or mutate) the default list.
        self.z_range = list(z_range)
        self.z_grid_points = z_grid_points

        # column to use for threshold calculation
        self.threshold_col = self._THRESHOLD_COLUMNS[desi_tracer]

    def load_sim_cat(self):
        """Load the simulated catalogue for this tracer as a DataFrame.

        Opens the catalogue selected by ``sfh_model`` inside ``path_sim``,
        keeps the tracer-specific columns, restricts it to ``z_range`` and
        converts it to pandas.  For ``'lrg'`` the column ``logmp_obs`` is
        renamed to ``log_halo_mass``; for ``'elg'`` a ``log_sfr`` column is
        added as ``logsm_obs + logssfr_obs``.

        Returns
        -------
        pandas.DataFrame

        Raises
        ------
        ValueError
            If ``sfh_model`` is not a known key.
        """
        if self.sfh_model not in self._SFH_MODEL_FILES:
            raise ValueError(
                f"Unknown sfh_model {self.sfh_model!r}; expected one of "
                f"{sorted(self._SFH_MODEL_FILES)}"
            )

        # Join with a path separator instead of plain string concatenation,
        # which silently produced e.g. '.tng_2025_11_06' for path_sim='.'.
        sim_file = Path(self.path_sim) / self._SFH_MODEL_FILES[self.sfh_model]
        dataset = oc.open(str(sim_file))

        dataset = dataset.select(self._SIM_COLUMNS[self.desi_tracer])
        dataset = dataset.with_redshift_range(self.z_range[0], self.z_range[1])
        sim_cat = dataset.data.to_pandas()

        if self.desi_tracer == 'lrg':
            sim_cat = sim_cat.rename(columns={'logmp_obs': 'log_halo_mass'})

        if self.desi_tracer == 'elg':
            sim_cat['log_sfr'] = sim_cat['logsm_obs'] + sim_cat['logssfr_obs']

        return sim_cat

    def rebin_desi_tracer(self, sim_cat):
        """Re-bin the tabulated tracer n(z) and compare it with the simulation.

        Parameters
        ----------
        sim_cat : pandas.DataFrame
            Simulated catalogue with a ``redshift_true`` column.

        Returns
        -------
        tuple of numpy.ndarray
            ``(new_z_bin_min, new_z_center, new_z_bin_max, z_frac)`` on the
            fine grid.  ``z_frac`` is the interpolated tracer n(z) divided by
            the simulated counts per ``sim_area``; it is NaN in bins where the
            simulation has no objects.

        Raises
        ------
        NotImplementedError
            For every tracer other than ``'qso'``.
        ValueError
            If the fine grid has fewer bins than the tabulated n(z).
        """
        if self.desi_tracer != 'qso':
            raise NotImplementedError(
                f"rebin_desi_tracer is only implemented for 'qso', "
                f"not {self.desi_tracer!r}"
            )

        tracer_data = table.Table.read(self.path_desi_tracer, format='ascii.ecsv')
        z_bin_center = np.asarray(tracer_data['z'], dtype=float)
        z_bin_min = z_bin_center - self._QSO_DZ / 2
        z_bin_max = z_bin_center + self._QSO_DZ / 2

        # Average of the two tabulated columns.
        nz_north = np.asarray(tracer_data['n_z_north'], dtype=float)
        nz_south = np.asarray(tracer_data['n_z_south'], dtype=float)
        nz_avg = (nz_north + nz_south) / 2

        # Fine grid: z_grid_points edges -> z_grid_points - 1 bins.
        zgrid = np.linspace(np.min(z_bin_min), np.max(z_bin_max), self.z_grid_points)
        new_z_bin_min = zgrid[:-1]
        new_z_bin_max = zgrid[1:]
        new_z_center = (new_z_bin_max + new_z_bin_min) / 2

        # Number of fine bins per tabulated bin; the interpolated n(z) is
        # divided by it.
        # NOTE(review): int() truncates, so the scaling is only exact when
        # (z_grid_points - 1) is a multiple of the number of tabulated bins.
        repeat_n = int((self.z_grid_points - 1) / len(z_bin_center))
        if repeat_n < 1:
            raise ValueError(
                "z_grid_points - 1 must be at least the number of tabulated "
                f"n(z) bins ({len(z_bin_center)})"
            )

        interp_func = interp1d(z_bin_center, nz_avg, fill_value=0, bounds_error=False)
        interp_nz_avg = interp_func(new_z_center) / repeat_n

        values, _ = np.histogram(sim_cat['redshift_true'], bins=zgrid)
        values_sim = values / self.sim_area

        # Leave NaN where the simulation is empty instead of dividing by zero.
        z_frac = np.full_like(interp_nz_avg, np.nan, dtype=float)
        np.divide(interp_nz_avg, values_sim, out=z_frac, where=values_sim > 0)

        return (new_z_bin_min, new_z_center, new_z_bin_max, z_frac)

    def generate_threshold(self, sim_cat, new_z_bin_min, new_z_center, new_z_bin_max, z_frac):
        """Compute one threshold on ``threshold_col`` per redshift bin.

        In each bin the threshold is the ``100 - 100 * z_frac`` percentile of
        ``sim_cat[self.threshold_col]``, i.e. the value above which a fraction
        ``z_frac`` of the simulated objects in that bin lies.

        Parameters
        ----------
        sim_cat : pandas.DataFrame
            Simulated catalogue with ``redshift_true`` and ``threshold_col``.
        new_z_bin_min, new_z_center, new_z_bin_max : array-like
            Lower edges, centres and upper edges of the redshift bins.
        z_frac : array-like
            Target fraction per bin; values outside [0, 1] are clipped.

        Returns
        -------
        list of float
            One threshold per bin; NaN for bins that contain no simulated
            objects or whose ``z_frac`` is not finite.  The list is empty for
            tracers other than ``'qso'``.
        """
        thres_list = []

        if self.desi_tracer != 'qso':
            return thres_list

        redshift = sim_cat['redshift_true']

        for i in range(len(new_z_center)):
            this_zmin = new_z_bin_min[i]
            this_zmax = new_z_bin_max[i]

            # Half-open [zmin, zmax) so objects on an interior edge are not
            # dropped, consistent with the np.histogram counts behind z_frac.
            in_bin = np.logical_and(redshift >= this_zmin, redshift < this_zmax)
            this_cat = sim_cat[in_bin]

            if len(this_cat) == 0:
                print(f"Empty bin: zmin={this_zmin}, zmax={this_zmax}")
                thres_list.append(np.nan)
                continue

            this_frac = z_frac[i]
            if not np.isfinite(this_frac):
                thres_list.append(np.nan)
                continue

            # np.percentile rejects q outside [0, 100]; this happens whenever
            # the target fraction exceeds 1 (or is negative).
            this_q = float(np.clip(100 - this_frac * 100, 0, 100))
            this_thres = np.percentile(a=this_cat[self.threshold_col], q=this_q)
            thres_list.append(this_thres)

        return thres_list

    def load_threshold_column(self):
        """Read the ``threshold_col`` dataset from an HDF5 catalogue.

        This replaces an unfinished second ``generate_threshold`` definition
        that shadowed the one above.
        NOTE(review): the original stub did not say which file to open; this
        assumes ``path_threshold`` is an HDF5 file holding a dataset named
        after ``threshold_col`` -- confirm.

        Returns
        -------
        numpy.ndarray
        """
        with h5py.File(self.path_threshold, 'r') as catalog:
            threshold_all = catalog[self.threshold_col][...]

        return threshold_all

    def prepare_threshold(self):
        """Load the simulation and build the inputs of ``generate_threshold``.

        Returns
        -------
        tuple
            ``(sim_cat, new_z_bin_min, new_z_center, new_z_bin_max, z_frac)``.
        """
        sim_cat = self.load_sim_cat()
        new_z_bin_min, new_z_center, new_z_bin_max, z_frac = self.rebin_desi_tracer(sim_cat)

        return (sim_cat, new_z_bin_min, new_z_center, new_z_bin_max, z_frac)

    def produce_mock(self):
        """Produce the mock catalogue (not implemented yet)."""
        raise NotImplementedError("produce_mock is not implemented yet")

    # TODO: def measure_auto_correlation(self):

    def run(self):
        """Run the pipeline: prepare inputs, compute thresholds, build the mock.

        The thresholds are stored in ``self.thresholds`` before
        ``produce_mock`` is called, so they remain available while that step
        is still unimplemented.
        """
        sim_cat, new_z_bin_min, new_z_center, new_z_bin_max, z_frac = self.prepare_threshold()

        self.thresholds = self.generate_threshold(
            sim_cat, new_z_bin_min, new_z_center, new_z_bin_max, z_frac
        )

        self.produce_mock()
        

In [25]:
# The class is `Desi_Selector`, and its keywords are `path_sim` and
# `z_grid_points` (not `sim_path` / `Z_GRID_POINTS`).  The constructor also
# requires `path_desi_tracer`, `sfh_model` and `path_threshold`; their real
# values are not known here, so they are placeholders to be filled in.
elg_selector = Desi_Selector(
    desi_tracer='elg',
    path_desi_tracer=None,  # TODO: path to the DESI ELG n(z) table
    path_sim='.',
    sfh_model=None,  # TODO: one of 'tng', 'um', 'gal'
    path_threshold=None,  # TODO: path to the threshold catalogue
    z_grid_points=401,
)