In [3]:
# data_processing.py

import h5py
import numpy as np
from synthesizer.conversions import lnu_to_absolute_mag
import pandas as pd
import unyt
from unyt import erg, Hz, s
import cmasher as cmr
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
import sys
import glob

sys.path.append("/home/jovyan/camels/proj1/")
from setup_params import get_photometry, get_luminosity_function, get_colour_distribution, get_safe_name, get_colour_dir_name, get_magnitude_mask
from variables_config import get_config


def _report_skip(what, err):
    """Print a one-line warning about a skipped work item to stderr."""
    print(f"Warning: skipping {what}: {type(err).__name__}: {err}", file=sys.stderr)


def _save_uvlf(photo, band, uvlf_limits, n_bins_lf, lf_data_dir, category,
               sim_name, z_label, spec_type):
    """Compute the luminosity function of one band and write it as a TSV file."""
    phi, phi_sigma, hist, bin_lims = get_luminosity_function(
        photo, band, *uvlf_limits, n_bins=n_bins_lf
    )

    bin_centers = 0.5 * (bin_lims[1:] + bin_lims[:-1])
    uvlf_df = pd.DataFrame({
        'magnitude': bin_centers,
        'phi': phi,
        'phi_sigma': phi_sigma,
        'hist': hist
    })

    filter_system = get_safe_name(band, filter_system_only=True)
    output_dir = os.path.join(lf_data_dir[category][filter_system], z_label)
    os.makedirs(output_dir, exist_ok=True)

    uvlf_filename = f"UVLF_{sim_name}_{get_safe_name(band)}_{z_label}_{spec_type}.txt"
    uvlf_df.to_csv(os.path.join(output_dir, uvlf_filename), index=False, sep='\t')


def _save_colour(photo, band1, band2, colour_limits, n_bins_colour, mag_limits,
                 colour_data_dir, category, sim_name, z_label, spec_type):
    """Compute the (band1, band2) colour distribution and write it as a TSV file."""
    mask = get_magnitude_mask(photo, [band1, band2], mag_limits)
    colour_dist, bin_lims = get_colour_distribution(
        photo, band1, band2, *colour_limits,
        n_bins=n_bins_colour, mask=mask
    )

    bin_centers = 0.5 * (bin_lims[1:] + bin_lims[:-1])
    colour_df = pd.DataFrame({
        'colour': bin_centers,
        'distribution': colour_dist
    })

    colour_dir = get_colour_dir_name(band1, band2)
    output_dir = os.path.join(colour_data_dir[category], colour_dir, z_label)
    os.makedirs(output_dir, exist_ok=True)

    colour_filename = f"Colour_{sim_name}_{colour_dir}_{z_label}_{spec_type}.txt"
    colour_df.to_csv(os.path.join(output_dir, colour_filename), index=False, sep='\t')


def process_data(input_dir, redshift_values, uvlf_limits, n_bins_lf, lf_data_dir, 
                colour_limits, n_bins_colour, colour_data_dir, category, bands, 
                colour_pairs=None, mag_limits=None, simulation=None, dataset="CV"):
    """Compute and save luminosity functions and colour distributions.

    For every ``*_photometry.hdf5`` file in ``input_dir`` (in sorted order)
    and every snapshot in ``redshift_values``, photometry is loaded with
    ``get_photometry`` and then:

    * for each band in ``bands``, a luminosity function is written as a
      tab-separated ``UVLF_*.txt`` file under
      ``lf_data_dir[category][<filter system>]/<safe redshift label>/``;
    * for each ``(band1, band2)`` in ``colour_pairs`` whose bands are both
      present in the loaded photometry, a colour distribution is written as a
      tab-separated ``Colour_*.txt`` file under
      ``colour_data_dir[category]/<colour dir name>/<safe redshift label>/``.
      Pairs whose bands are not present are skipped without a warning.

    Args:
        input_dir: Directory scanned for ``*_photometry.hdf5`` files; also
            passed to ``get_photometry`` as ``photo_dir``.
        redshift_values: Mapping ``snap -> info``; only ``info['label']`` is read.
        uvlf_limits: Sequence unpacked positionally into ``get_luminosity_function``.
        n_bins_lf: Number of luminosity-function bins.
        lf_data_dir: Nested mapping ``category -> filter system -> directory``.
        colour_limits: Sequence unpacked positionally into ``get_colour_distribution``.
        n_bins_colour: Number of colour-distribution bins.
        colour_data_dir: Mapping ``category -> directory``.
        category: ``"intrinsic"`` selects intrinsic spectra; any other value
            selects ``"attenuated"``.
        bands: One band name, a sequence of band names, or ``None``. With
            ``None`` no luminosity functions are produced, and only the
            colour-pair bands are loaded.
        colour_pairs: Optional iterable of ``(band1, band2)`` tuples.
        mag_limits: Passed to ``get_magnitude_mask`` when selecting objects
            for colours.
        simulation: Simulation name. ``"SIMBA"`` maps to the ``"Simba"``
            file prefix and model name.
        dataset: Not used by this function; kept for interface compatibility.

    Failures are reported on stderr and skipped, so the rest of the batch
    still runs. A failure can occur while loading photometry for a
    file/snapshot, for an individual band, or for an individual colour pair.
    """
    sim_prefix = "Simba" if simulation == "SIMBA" else simulation
    spec_type = "intrinsic" if category == "intrinsic" else "attenuated"
    # Materialise once: it is re-iterated for every file and snapshot.
    pairs = list(colour_pairs) if colour_pairs else []

    if bands is None:
        lf_bands = []
        # No luminosity functions requested: load only the bands the colour
        # pairs need, otherwise colours could never be computed.
        filters_to_process = list(dict.fromkeys(b for pair in pairs for b in pair))
    else:
        lf_bands = [bands] if isinstance(bands, str) else bands
        filters_to_process = lf_bands

    photo_files = sorted(
        f for f in os.listdir(input_dir) if f.endswith('_photometry.hdf5')
    )
    if bands is None and not filters_to_process:
        return  # Nothing was requested.

    for filename in photo_files:
        sim_name = filename.replace(f'{sim_prefix}_', '').replace('_photometry.hdf5', '')

        for snap, redshift_info in redshift_values.items():
            try:
                z_label = get_safe_name(redshift_info['label'])
                photo = get_photometry(
                    sim_name=sim_name,
                    spec_type=spec_type,
                    snap=snap,
                    sps="BC03",
                    model=sim_prefix,
                    filters=filters_to_process,
                    photo_dir=input_dir
                )
            except Exception as err:
                _report_skip(f"{filename} snap {snap}", err)
                continue

            # Luminosity functions: one failing band does not abort the others.
            for band in lf_bands:
                try:
                    _save_uvlf(photo, band, uvlf_limits, n_bins_lf, lf_data_dir,
                               category, sim_name, z_label, spec_type)
                except Exception as err:
                    _report_skip(f"UVLF {band} for {filename} snap {snap}", err)

            # Colour distributions: only for pairs whose bands were loaded.
            for pair in pairs:
                try:
                    band1, band2 = pair
                    if band1 in photo and band2 in photo:
                        _save_colour(photo, band1, band2, colour_limits,
                                     n_bins_colour, mag_limits, colour_data_dir,
                                     category, sim_name, z_label, spec_type)
                except Exception as err:
                    _report_skip(f"colour {pair} for {filename} snap {snap}", err)

def process_all_data(input_dir, redshift_values, uvlf_limits, n_bins_lf, lf_data_dir, 
                    colour_limits, n_bins_colour, colour_data_dir, mag_limits, 
                    simulation=None, dataset="CV"):
    """Run ``process_data`` once per spectrum category.

    Band lists come from ``get_config(dataset=dataset, simulation=simulation)``:
    ``config["filters"][category]`` supplies the bands for each category, and
    ``config["colour_pairs"]`` supplies colour pairs shared by all categories.
    Every other argument is forwarded unchanged. Categories are processed in
    the order ``"attenuated"``, then ``"intrinsic"``.
    """
    cfg = get_config(dataset=dataset, simulation=simulation)
    filters_by_category = cfg["filters"]
    shared_colour_pairs = cfg["colour_pairs"]

    # Arguments that do not vary between categories.
    common_kwargs = dict(
        input_dir=input_dir,
        redshift_values=redshift_values,
        uvlf_limits=uvlf_limits,
        n_bins_lf=n_bins_lf,
        lf_data_dir=lf_data_dir,
        colour_limits=colour_limits,
        n_bins_colour=n_bins_colour,
        colour_data_dir=colour_data_dir,
        colour_pairs=shared_colour_pairs,
        mag_limits=mag_limits,
        simulation=simulation,
        dataset=dataset,
    )

    for cat in ("attenuated", "intrinsic"):
        # Luminosity functions and colour distributions for this category.
        process_data(category=cat, bands=filters_by_category[cat], **common_kwargs)


In [None]:
if __name__ == "__main__":
    # Simulation suites and dataset variants to run through the pipeline.
    simulations = ["IllustrisTNG", "SIMBA", "Astrid", "Swift-EAGLE"]
    datasets = ["CV"]

    # Config entries forwarded verbatim to process_all_data under the same names.
    _forwarded_keys = (
        "input_dir",
        "redshift_values",
        "uvlf_limits",
        "n_bins_lf",
        "lf_data_dir",
        "colour_limits",
        "n_bins_colour",
        "colour_data_dir",
        "mag_limits",
    )

    for simulation in simulations:
        print(f"\nProcessing {simulation}")
        for dataset in datasets:
            config = get_config(dataset=dataset, simulation=simulation)

            # Luminosity functions and colour distributions for both categories.
            process_all_data(
                **{key: config[key] for key in _forwarded_keys},
                simulation=simulation,
                dataset=dataset,
            )