In [1]:
import opencosmo as oc
import numpy as np
import pandas as pd
from scipy import interpolate
from astropy import table
import h5py
from pathlib import Path
from astropy.cosmology import LambdaCDM
from scipy.interpolate import interp1d
from diffsky.experimental import lc_utils
from diffsky.data_loaders.hacc_utils import lightcone_utils
import jax.random as jran


In [21]:
class Desi_Selector:

    def __init__(self, 
                 desi_tracer,
                 path_desi_tracer,
                 path_sim,
                 sfh_model,
                 path_threshold,
                 sim_area=1121,
                 z_range =[0,2]
                 z_grid_points=481,
                 ):

        self.desi_tracer = desi_tracer
        self.path_desi_tracer = path_desi_tracer
        self.path_sim = path_sim
        self.sfh_model = sfh_model
        self.path_threshold = path_threshold
        self.sim_area = sim_area
        self.z_range = z_range
        self.z_grid_points = z_grid_points

        # column to use for threshold calculation
        if desi_tracer == 'bgs':
            self.threshold_col = 'lsst_g'

        elif desi_tracer == 'lrg':
            self.threshold_col = 'log_halo_mass'

        elif desi_tracer == 'elg':
            self.threshold_col = 'log_sfr'

        elif desi_tracer == 'qso':
            self.threshold_col = 'black_hole_mass'
    
    
    def load_sim_cat(self):
        
        dict_sfh_models = {'tng': 'tng_2025_11_06', 
                           'um': 'smdpl_dr1_2025_11_07',
                           'gal': 'galacticus_in_plus_ex_situ_2025_11_10'}
        
        dataset = oc.open(self.path_sim+dict_sfh_models[self.sfh_model])

        if desi_tracer == 'bgs':
            columns = ['ra', 'dec', 'redshift_true', 'lsst_g']

        if desi_tracer == 'lrg':
            columns = ['ra', 'dec', 'redshift_true', 'logmp_obs']

        if desi_tracer == 'elg':
            columns = ['ra', 'dec', 'redshift_true', 'logsm_obs', 'logssfr_obs']

        if desi_tracer == 'qso':
            columns = ['ra', 'dec', 'redshift_true', 'black_hole_mass', 'black_hole_eddington_ratio']
        
        
        dataset = dataset.select(columns)
        dataset = dataset.with_redshift_range(self.z_range[0], self.z_range[1])
        sim_cat = dataset.data.to_pandas()

        if desi_tracer == 'lrg':
            sim_cat.rename(columns={'logmp_obs': 'log_halo_mass'}, inplace=True))

        if desi_tracer == 'elg':
            sim_cat['log_sfr'] = sim_cat['logsm_obs'] + sim_cat['logssfr_obs']

        return sim_cat

    
    def rebin_desi_tracer(self, sim_cat):

        if desi_tracer == 'qso':
            
            tracer_data = table.Table.read(self.path_desi_tracer,  format='ascii.ecsv')
            z_bin_center = tracer_data['z']
            z_bin_min = z_bin_center - 0.050/2
            z_bin_max = z_bin_center + 0.050/2
            
            nz_north = qso_data['n_z_north']
            nz_south = qso_data['n_z_south']
            nz_avg = (nz_north + nz_south) / 2 
            zgrid = np.linspace(np.min(z_bin_min), np.max(z_bin_max), self.z_grid_points)

            repeat_n = int((self.z_grid_points-1)/len(z_bin_center))
            new_z_bin_min = np.linspace(np.min(z_bin_min), np.max(z_bin_max),  self.z_grid_points)[:-1]
            new_z_bin_max = np.linspace(np.min(z_bin_min), np.max(z_bin_max),  self.z_grid_points)[1:]
            new_z_center = (new_z_bin_max + new_z_bin_min) / 2

            interp_func = interp1d(z_bin_center, nz_avg, fill_value=0, bounds_error=False)
            interp_nz_avg = interp_func(new_z_center) / repeat_n

            values, edges = np.histogram(sim_cat['redshift_true'], bins=zgrid)
            values_sim = values / self.sim_area
            z_frac = interp_nz_avg / values_sim

            return (new_z_bin_min, new_z_center, new_z_bin_max, z_frac)
        
    
    def generate_threshold(self, sim_cat, new_z_bin_min, new_z_center, new_z_bin_max, z_frac):
        thres_list = []

        if desi_tracer == 'qso':
            
            for i in range(len(new_z_center)):
                print(i)
                this_zmin = new_z_bin_min[i]
                this_zmax = new_z_bin_max[i]
            
                this_cat = sim_cat[np.logical_and(sim_cat['redshift_true']>this_zmin, sim_cat['redshift_true']<this_zmax)]
            
                if len(this_cat) == 0:
                    print(f"Empty bin: zmin={this_zmin}, zmax={this_zmax}")
                
                this_thres = np.percentile(a = this_cat[self.threshold_col], q = 100-z_frac[i]*100)
                thres_list.append(this_thres)
            
            thres_arr = np.array(thres_list)

        return thres_list
        
        
    def generate_threshold(self):

        catalog = h5py.read .... 
        
        threshold_all = catalog[self.threshold_col]


    def produce_mock(self):



    # def measure_auto_correlation(self):
        


    def run(self):

        self.prepare_threshold()

        self.generate_threshold()

        self.produce_mock()
        

In [25]:
elg_selector = desi_selector(desi_tracer = 'elg', sim_path = '.', Z_GRID_POINTS = 401)

In [26]:
elg_selector.run()

'sfr'

In [None]:
lrg_selector = desi_selector(desi_tracer = 'lrg', sim_path = '.', Z_GRID_POINTS = 401)

In [None]:
def get_desi_proxy_thresholds():



def get_desi_interp(path_desi_data=path_desi_data, thres_list=sfr_thres):
    
    desi_data = table.Table.read(path_desi_data, format='ascii.ecsv')
    zmin = desi_data['ZMIN']
    zmax = desi_data['ZMAX']
    z_bin_centers = (zmin + zmax) / 2
    thres_of_z = interpolate.interp1d(z_bin_centers, thres_list,  fill_value=9E11, bounds_error=False)

    return thres_of_z


def get_sim_cat(healpix_id=None,
                z_range=[[0,1], [1,2]]):
    """
    Load the galaxy data from the HDF5 file for a specific Healpix ID and redshift range.
    """
    cat_list = []

    for z in z_range:

        filepath = f'/global/cfs/cdirs/lsst/shared/xgal/skysim/skysim5000_v1.1.1/z_{z[0]}_{z[1]}.step_all.healpix_{healpix_id}.hdf5'

        with h5py.File(filepath, 'r') as file:
            properties = file['galaxyProperties']
            redshift = np.array(properties['redshift'])
            redshift_hubble = np.array(properties['redshiftHubble'])
            distance = cosmo.comoving_distance(redshift).value
            distance_hubble = cosmo.comoving_distance(redshift_hubble).value
            sfr = np.array(properties['baseDC2']['sfr'])
            sfr_tot = np.array(properties['totalStarFormationRate'])
            stellar_mass = np.array(properties['totalMassStellar'])
            blackhole_mass = np.array(properties['blackHoleMass'])
            gal_id = np.array(properties['galaxyID'])
            mag_u = np.array(properties['LSST_filters']['magnitude:LSST_u:observed:dustAtlas']) # mags with no MW extinction corrections
            mag_g = np.array(properties['LSST_filters']['magnitude:LSST_g:observed:dustAtlas'])
            mag_r = np.array(properties['LSST_filters']['magnitude:LSST_r:observed:dustAtlas'])
            mag_i = np.array(properties['LSST_filters']['magnitude:LSST_i:observed:dustAtlas'])
            mag_z = np.array(properties['LSST_filters']['magnitude:LSST_z:observed:dustAtlas'])
            mag_y = np.array(properties['LSST_filters']['magnitude:LSST_y:observed:dustAtlas'])
            ra = np.array(properties['ra'])
            dec = np.array(properties['dec'])
            ra_true = np.array(properties['ra_true'])
            dec_true = np.array(properties['dec_true'])


            array_list = np.column_stack([redshift, redshift_hubble, distance, distance_hubble, sfr, sfr_tot, stellar_mass, blackhole_mass, gal_id, mag_u,
                                        mag_g, mag_r, mag_i, mag_z, mag_y
                                          , ra, dec, ra_true, dec_true])
            NEAR_0 = 3
            NEAR_360 = 357
            sim_cat_in = pd.DataFrame(array_list, columns=['redshift', 'redshift_hubble', 'distance', 'distance_hubble', 'sfr','sfr_tot','stellar_mass',
                                                           'blackhole_mass','gal_id','mag_u', 'mag_g', 'mag_r', 'mag_i', 'mag_z', 'mag_y', 'ra', 'dec', 'ra_true', 'dec_true'])
            edge_mask = np.logical_and(sim_cat_in['ra'] < NEAR_360, sim_cat_in['ra'] > NEAR_0)
            cat_list.append(sim_cat_in[edge_mask])
    
            

        sim_cat = pd.concat(cat_list)
        

    return sim_cat



def get_intermediate_random_cat(sim_cat, healpix_id=None, NUM_POINTS=NUM_POINTS):
    """
    Generate a random catalog of RA and DEC and match them with the simulated galaxy catalog.
    """
        ra_min = np.min(sim_cat['ra')
        ra_max = np.max(sim_cat['ra'])
        
        rand_ra = ra_min + (ra_max - ra_min)*np.random.random(size=NUM_POINTS)
        cth_min = np.min(np.sin(np.radians(sim_cat['dec'])))
        cth_max = np.max(np.sin(np.radians(sim_cat['dec'])))
        cth_rand = cth_min + (cth_max - cth_min)*np.random.random(size=NUM_POINTS)
        rand_dec = np.degrees(np.arcsin(cth_rand))
        
        # Convert to HEALPix θ (colatitude) and φ (longitude)
        rand_theta = np.radians(90.0 - rand_dec)  
        rand_phi = np.radians(rand_ra)            
        
        # Get HEALPix pixel id
        rand_pix_id = hp.ang2pix(NSIDE, rand_theta, rand_phi, nest=False) 
        rand_hpix_mask = rand_pix_id == healpix_id
        
        
        rand_cols_list = np.column_stack([rand_ra[rand_hpix_mask], rand_dec[rand_hpix_mask]])
        rand_cat = pd.DataFrame(rand_cols_list, columns=['ra', 'dec'])

        return rand_cat


def get_mock_cat(sim_cat, mag, healpix_id=None, mag_lim_faint=26, mag_lim_bright=None,
                 desi_proxy_column='sfr_tot', NSIDE=32)

        theta = np.radians(90.0 - sim_cat['dec'])  
        phi = np.radians(sim_cat['ra'])            
        
        # Get HEALPix pixel id
        pix_id = hp.ang2pix(NSIDE, theta, phi, nest=False) 
        hpix_mask = pix_id == healpix_id


        conditions = [np.ones_like(mag, dtype=bool)]
        if mag_lim_bright is not None:
            bright_cut = mag > mag_lim_bright
            conditions.append(bright_cut)

        if mag_lim_faint is not None:
            faint_cut = mags < mag_lim_faint
            conditions.append(faint_cut)
        
        tracer_masks = np.logical_and(conditions, hpix_mask)
        sim_cat_masked = sim_cat[tracer_masks]
        
        
        threshold_all = thres_of_z(sim_cat_masked['redshift'])
        desi_proxy_mask = sim_cat_masked[f'{desi_proxy_column}'] > threshold_all
        mock_cat = sim_cat_masked[desi_proxy_mask]

        return mock_cat


def get_final_rand_cat(rand_cat_intermediate, mock_cat, RAND_TO_DATA_RATIO=10):
        
        # add distances and redshift to random catalog
        rand_cat = rand_cat_intermediate.reset_index(drop=True) 
        mock_cat_temp = mock_cat.reset_index(drop=True).sample(len(rand_cat), replace=True)

        rand_cat['distance'] = mock_cat_temp['distance'].to_numpy()
        rand_cat['distance_hubble'] = mock_cat_temp['distance_hubble'].to_numpy()
        rand_cat['redshift'] = mock_cat_temp['redshift'].to_numpy()
        
        # mask the random cat to have around ten times as many objects as the mock catalog
        num_to_keep = len(mock_cat_temp) * RAND_TO_DATA_RATIO # number of objects to keep for randoms
        rand_cat_final = rand_cat.iloc[:NUM_TO_KEEP]
    
    return rand_cat_final


In [None]:

class DesiEmulator:
    healpix_ids_path = '/global/homes/y/yoki/roman/desi_like_samples/skysim_5000/data/healpix_ids/id_nums_exclude_edges.npy'
    healpix_ids = np.load(healpix_ids_path)
    TOTAL_DESI_AREA = 14000
    NSIDE=32
    AREA_PER_HEALPIX = 57.071968/17
    NEAR_O = 3
    NEAR_360 = 357
    
    def __init__(self,
                 desi_data_path = '/global/homes/y/yoki/roman/desi_like_samples/skysim_5000/data/desi_sv_data/desi_elg_ts_zenodo/main-800coaddefftime1200-nz-zenodo.ecsv',
                 galaxy_type='ELG',
                 total_target_density=1930 + 1950 + 1900,
                 target_density_regions=3, 
                 RAND_TO_DATA_RATIO = 10
                ):
        
        self.total_target_density = total_target_density
        self.target_density_regions = target_density_regions
        self.desi_data_path = desi_data_path
        self.galaxy_type = galaxy_type
        self.RAND_TO_DATA_RATIO = RAND_TO_DATA_RATIO
        self.desi_data = table.Table.read(desi_data_path, format='ascii.ecsv')

        self.zmin = self.desi_data['ZMIN']
        self.zmax = self.desi_data['ZMAX']
        self.z_bin_centers = (self.zmin + self.zmax) / 2

        self.TARG_DENS_AVG = (total_target_density) / 3
        self.NUM_POINTS = int(self.TARG_DENS_AVG * AREA_PER_HEALPIX * 40) # Factor of 40 is there for overkill 
    
        if self.galaxy_type == 'ELG':
            self.selection_column = 'sfr_tot'
            self.selection_column_dc2 = 'totalStarFormationRate'

        def save_sim_cat_(healpix_id=None):
    # print(f'{worker_index} is processing {healpix_id}')
    

        cat_list = []
        z_range_skysim = [[0,1], [1,2]]

        for z in z_range_skysim:
        
        filepath = '/global/cfs/cdirs/lsst/shared/xgal/skysim/skysim5000_v1.1.1'
        h5_filename = f'/z_{z[0]}_{z[1]}.step_all.healpix_{healpix_id}.hdf5' # assuming all healpix files have same root file 
        h5f = filepath + h5_filename
    
    
        with h5py.File(h5f, 'r') as file:
            
         
            
            properties = file['galaxyProperties']
            redshift = np.array(properties['redshift'])
            # r = cosmo.comoving_distance(redshift).value # units of Mpc
            sfr = np.array(properties['baseDC2']['sfr'])
            sfr_tot = np.array(properties['totalStarFormationRate'])
            # stellar_mass = np.array(properties['totalMassStellar'])
            # blackhole_mass = np.array(properties['blackHoleMass'])
            # gal_id = np.array(properties['galaxyID'])
            # mag_u = np.array(properties['LSST_filters']['magnitude:LSST_u:observed:dustAtlas']) # mags with no MW extinction corrections
            mag_g = np.array(properties['LSST_filters']['magnitude:LSST_g:observed:dustAtlas'])
            # mag_r = np.array(properties['LSST_filters']['magnitude:LSST_r:observed:dustAtlas'])
            # mag_i = np.array(properties['LSST_filters']['magnitude:LSST_i:observed:dustAtlas'])
            # mag_z = np.array(properties['LSST_filters']['magnitude:LSST_z:observed:dustAtlas'])
            # mag_y = np.array(properties['LSST_filters']['magnitude:LSST_y:observed:dustAtlas'])
            ra = np.array(properties['ra'])
            # dec = np.array(properties['dec'])
            # ra_true = np.array(properties['ra_true'])
            # dec_true = np.array(properties['dec_true'])


            array_list = np.column_stack([redshift, sfr, sfr_tot, ra, mag_g])

            sim_cat_in = pd.DataFrame(array_list, columns=['redshift', 'sfr','sfr_tot', 'ra', 'mag_g'])
            elg_mag_cut = sim_cat_in['mag_g'] < G_MAG_CUT
            edge_mask = np.logical_and(sim_cat_in['ra'] < NEAR_360, sim_cat_in['ra'] > NEAR_0)
            masks = np.logical_and(elg_mag_cut, edge_mask)
            cat_list.append(sim_cat_in[masks])

        sim_cat = pd.concat(cat_list)
    
    def get_selection_threshold(self, bins):
        cat_list = []
        
        
        for hpix in self.healpix_ids:
        
            input_file_name = f'sim_cat_hpix_{hpix}.parquet'
            sim_cat_path = f'/global/homes/y/yoki/roman/desi_like_samples/skysim_5000/data/sim_data/{input_file_name}'
            sim_cat_loop = pd.read_parquet(path=sim_cat_path)
            cat_list.append(sim_cat_loop)
        
    
        sim_cat = pd.concat(cat_list)

        if galaxy_type == 'ELG'

        
        elif galaxy_type == 'LRG'

    

    

    def generate(self, healpix_id=None, z_range=[[0, 1], [1, 2]]):
        z_range = 
        cat_list = []

        for z in z_range:
            filepath = f'/global/cfs/cdirs/lsst/shared/xgal/skysim/skysim5000_v1.1.1/z_{z[0]}_{z[1]}.step_all.healpix_{healpix_id}.hdf5'
            with h5py.File(filepath, 'r') as file:
                properties = file['galaxyProperties']
                sim_cat_in = pd.DataFrame({
                    'redshift': np.array(properties['redshift']),
                    'redshift_hubble': np.array(properties['redshiftHubble']),
                    'distance': cosmo.comoving_distance(np.array(properties['redshift'])).value,
                    'distance_hubble': cosmo.comoving_distance(np.array(properties['redshiftHubble'])).value,
                    'sfr': np.array(properties['baseDC2']['sfr']),
                    'sfr_tot': np.array(properties['totalStarFormationRate']),
                    'stellar_mass': np.array(properties['totalMassStellar']),
                    'blackhole_mass': np.array(properties['blackHoleMass']),
                    'gal_id': np.array(properties['galaxyID']),
                    'mag_u': np.array(properties['LSST_filters']['magnitude:LSST_u:observed:dustAtlas']),
                    'mag_g': np.array(properties['LSST_filters']['magnitude:LSST_g:observed:dustAtlas']),
                    'mag_r': np.array(properties['LSST_filters']['magnitude:LSST_r:observed:dustAtlas']),
                    'mag_i': np.array(properties['LSST_filters']['magnitude:LSST_i:observed:dustAtlas']),
                    'mag_z': np.array(properties['LSST_filters']['magnitude:LSST_z:observed:dustAtlas']),
                    'mag_y': np.array(properties['LSST_filters']['magnitude:LSST_y:observed:dustAtlas']),
                    'ra': np.array(properties['ra']),
                    'dec': np.array(properties['dec']),
                    'ra_true': np.array(properties['ra_true']),
                    'dec_true': np.array(properties['dec_true']),
                })

                edge_mask = np.logical_and(sim_cat_in['ra'] < self.NEAR_360, sim_cat_in['ra'] > self.NEAR_0)
                cat_list.append(sim_cat_in[edge_mask])

        sim_cat = pd.concat(cat_list)

        # Random catalog generation
        ra_min, ra_max = sim_cat['ra'].min(), sim_cat['ra'].max()
        rand_ra = np.random.uniform(ra_min, ra_max, size=self.NUM_POINTS)
        cth_min, cth_max = np.sin(np.radians(sim_cat['dec'].min())), np.sin(np.radians(sim_cat['dec'].max()))
        cth_rand = np.random.uniform(cth_min, cth_max, size=self.NUM_POINTS)
        rand_dec = np.degrees(np.arcsin(cth_rand))

        rand_theta = np.radians(90.0 - rand_dec)
        rand_phi = np.radians(rand_ra)
        rand_pix_id = hp.ang2pix(self.NSIDE, rand_theta, rand_phi, nest=False)
        rand_hpix_mask = rand_pix_id == healpix_id
        rand_cat = pd.DataFrame({'ra': rand_ra[rand_hpix_mask], 'dec': rand_dec[rand_hpix_mask]})

        elg_theta = np.radians(90.0 - sim_cat['dec'])
        elg_phi = np.radians(sim_cat['ra'])
        elg_pix_id = hp.ang2pix(self.NSIDE, elg_theta, elg_phi, nest=False)
        elg_hpix_mask = elg_pix_id == healpix_id

        sim_cat_gmag_mask = sim_cat['mag_g'] < self.LOP_ELG_MAG_CUTOFF
        elg_masks = np.logical_and(sim_cat_gmag_mask, elg_hpix_mask)
        sim_cat_masked = sim_cat[elg_masks]

        threshold_all = self.thres_of_z(sim_cat_masked['redshift'])
        sfr_mask = sim_cat_masked['sfr_tot'] > threshold_all
        mock_elg_cat = sim_cat_masked[sfr_mask]

        rand_cat = rand_cat.reset_index(drop=True)
        mock_elg_sampled = mock_elg_cat.reset_index(drop=True).sample(len(rand_cat), replace=True)

        rand_cat['distance'] = mock_elg_sampled['distance'].values
        rand_cat['distance_hubble'] = mock_elg_sampled['distance_hubble'].values
        rand_cat['redshift'] = mock_elg_sampled['redshift'].values

        num_to_keep = len(mock_elg_cat) * self.RAND_TO_DATA_RATIO
        rand_cat_final = rand_cat.iloc[:num_to_keep]

        # rand_cat_final.to_parquet(f"{rand_output_dir}/rand_elg_cat_hpix_{healpix_id}.parquet")
        # mock_elg_cat.to_parquet(f"{mock_output_dir}/mock_elg_cat_hpix_{healpix_id}.parquet")
