In [1]:
# plotting.py
import h5py
import numpy as np
from synthesizer.conversions import lnu_to_absolute_mag
import pandas as pd
import unyt
from unyt import erg, Hz, s
import cmasher as cmr
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
import sys
import glob

sys.path.append("/home/jovyan/camels/proj1/")
from variables_config import get_config

sys.path.append("/home/jovyan/camels/proj2/")
from setup_params_1P import  get_safe_name, get_colour_dir_name


_SIMULATION_COLOURS = {
    "IllustrisTNG": "blue",
    "SIMBA": "green",
    "Astrid": "red",
    "Swift-EAGLE": "orange",
}
_FALLBACK_COLOUR = "gray"


def get_simulation_color(simulation):
    """Return the matplotlib colour name used for ``simulation`` in every plot.

    Simulation names not listed in ``_SIMULATION_COLOURS`` fall back to ``"gray"``.
    """
    try:
        return _SIMULATION_COLOURS[simulation]
    except KeyError:
        return _FALLBACK_COLOUR

def _read_tsv_tables(file_list):
    """Read each tab-separated file in ``file_list`` with pandas, in sorted path order.

    Files that raise while being read are reported and skipped, so one bad file
    does not abort the whole load. Returns a (possibly empty) list of DataFrames.
    """
    tables = []
    # glob returns paths in arbitrary order; sorting makes the result (and the
    # "first file" that supplies the bin centres) deterministic.
    for filename in sorted(file_list):
        try:
            tables.append(pd.read_csv(filename, sep='\t'))
        except Exception as e:
            print(f"Error loading {filename}: {e}")
    return tables


def load_simulation_data(simulations, config):
    """Load CV-set UVLF and colour-distribution tables for several simulations.

    Parameters
    ----------
    simulations : iterable of str
        Simulation names. Each is passed to ``get_config(dataset="CV", simulation=...)``
        to obtain its own directories, filters, colour pairs and redshifts.
    config : dict
        Not consulted; a per-simulation config is fetched instead. Kept for
        backward compatibility with existing callers.

    Returns
    -------
    (uvlf_data, colour_data) : tuple of dict
        ``uvlf_data[simulation][redshift][category][band_key]`` ->
        ``{'magnitude': 1-D array, 'phi': array stacked over files (axis 0)}``.
        ``colour_data[simulation][colour_key][redshift][category]`` ->
        ``{'colour': 1-D array, 'distribution': array stacked over files (axis 0)}``.
        ``category`` is ``'intrinsic'`` or ``'attenuated'``; ``band_key`` is
        ``get_safe_name(band)`` and ``colour_key`` is ``'<band_key1>-<band_key2>'``.
        Only redshifts / categories for which at least one file was found are stored.
    """
    all_uvlf_data = {}
    all_colour_data = {}
    categories = ('intrinsic', 'attenuated')

    for simulation in simulations:
        sim_config = get_config(dataset="CV", simulation=simulation)
        sim_uvlf_data = {}
        sim_colour_data = {}

        # --- UV luminosity functions ---
        for redshift_info in sim_config["redshift_values"].values():
            z_label = redshift_info['label']
            z_data = {}
            for category in categories:
                band_tables = {}
                for band in sim_config["filters"][category]:
                    band_key = get_safe_name(band)
                    filter_system = get_safe_name(band, filter_system_only=True)
                    data_dir = os.path.join(sim_config["lf_data_dir"][category][filter_system],
                                            get_safe_name(z_label))
                    pattern = f"UVLF_CV_*_{band_key}_{z_label}_{category}.txt"
                    tables = _read_tsv_tables(glob.glob(os.path.join(data_dir, pattern)))
                    if tables:
                        band_tables[band_key] = {
                            # NOTE(review): assumes every file shares the first file's magnitude bins.
                            'magnitude': tables[0]['magnitude'].values,
                            'phi': np.array([df['phi'].values for df in tables]),
                        }
                if band_tables:
                    z_data[category] = band_tables
            # Only keep redshifts for which at least one file was found.
            if z_data:
                sim_uvlf_data[redshift_info['redshift']] = z_data

        # --- Colour distributions ---
        for band1, band2 in sim_config["colour_pairs"]:
            colour_dir = get_colour_dir_name(band1, band2)
            colour_key = f'{get_safe_name(band1)}-{get_safe_name(band2)}'
            colour_entry = sim_colour_data.setdefault(colour_key, {})

            for category in categories:
                for redshift_info in sim_config["redshift_values"].values():
                    z_label = redshift_info['label']
                    data_dir = os.path.join(sim_config["colour_data_dir"][category],
                                            colour_dir,
                                            get_safe_name(z_label))
                    pattern = f"Colour_CV_*_{colour_dir}_{z_label}_{category}.txt"
                    tables = _read_tsv_tables(glob.glob(os.path.join(data_dir, pattern)))
                    if tables:
                        colour_entry.setdefault(redshift_info['redshift'], {})[category] = {
                            'colour': tables[0]['colour'].values,
                            'distribution': np.array([df['distribution'].values for df in tables]),
                        }

        all_uvlf_data[simulation] = sim_uvlf_data
        all_colour_data[simulation] = sim_colour_data

    return all_uvlf_data, all_colour_data



  from synthesizer.filters import Filter, FilterCollection


In [2]:

def plot_combined_colours_mean_only(all_sims_data, redshifts, output_dir, colour_pairs=None):
    """Plot mean attenuated colour distributions of all simulations, one panel per redshift.

    For each colour pair one figure is produced with a panel per entry of
    ``redshifts``. In each panel every simulation contributes the mean over files
    (axis 0 of ``'distribution'``) of its attenuated distribution. Figures are
    saved to ``output_dir`` as ``combined_colours_CV_mean_<band1>_<band2>.pdf``
    and ``.png``.

    Parameters
    ----------
    all_sims_data : dict
        ``{simulation: {colour_key: {redshift: {category: {'colour': array,
        'distribution': array}}}}}`` with ``colour_key`` =
        ``'<get_safe_name(band1)>-<get_safe_name(band2)>'``.
    redshifts : sequence of dict
        Each item provides a ``'redshift'`` key; panels follow this order.
    output_dir : str
        Directory the figures are written to (must already exist).
    colour_pairs : sequence of (str, str), optional
        Band pairs to plot. Defaults to ``[('GALEX FUV', 'GALEX NUV')]``.
    """
    if colour_pairs is None:
        colour_pairs = [('GALEX FUV', 'GALEX NUV')]
    if not redshifts:
        # plt.subplots cannot build a figure with zero panels.
        return

    for band1, band2 in colour_pairs:
        colour_key = f'{get_safe_name(band1)}-{get_safe_name(band2)}'
        display_name = f'{band1} - {band2}'

        # squeeze=False always yields a 2-D array, so one redshift needs no special case.
        fig, axes = plt.subplots(1, len(redshifts), figsize=(20, 6), squeeze=False)
        axes = axes[0]

        # First line drawn per simulation; feeds a single legend that also lists
        # simulations absent from the first panel.
        legend_handles = {}

        for ax, redshift_info in zip(axes, redshifts):
            z = redshift_info['redshift']

            for sim_name, sim_colour_data in all_sims_data.items():
                if colour_key not in sim_colour_data:
                    print(f"  Color key {colour_key} not found. Available keys: {list(sim_colour_data.keys())}")
                    continue

                if z not in sim_colour_data[colour_key]:
                    print(f"  Redshift {z} not found for {colour_key}")
                    continue

                color = get_simulation_color(sim_name)
                z_data = sim_colour_data[colour_key][z]

                for category in ['attenuated']:
                    if category not in z_data:
                        print(f"  Category {category} not found")
                        continue

                    print(f"  Found data for {sim_name}, z={z}, {category}")

                    colours = z_data[category]['colour']
                    # Average the per-file distributions bin by bin.
                    mean_dist = np.mean(z_data[category]['distribution'], axis=0)

                    label = sim_name if sim_name not in legend_handles else None
                    line, = ax.plot(colours, mean_dist, '-', color=color, linewidth=2, label=label)
                    if label is not None:
                        legend_handles[sim_name] = line

            ax.set_xlabel(f'{display_name} [mag]', fontsize=10)
            ax.set_xlim(-0.5, 3.5)
            ax.set_ylim(0, 2.0)
            ax.grid(True, alpha=0.3)
            ax.text(0.95, 0.95, f'z = {z}', transform=ax.transAxes,
                    fontsize=10, ha='right', va='top')
            ax.tick_params(axis='both', which='major', labelsize=8)

        axes[0].set_ylabel('Normalised Count', fontsize=10)
        if legend_handles:
            axes[0].legend(handles=list(legend_handles.values()),
                           loc='upper left', fontsize=9, frameon=True)

        fig.subplots_adjust(wspace=0.1, top=0.9)
        safe_name = f'{get_safe_name(band1)}_{get_safe_name(band2)}'
        for ext in ('pdf', 'png'):
            fig.savefig(os.path.join(output_dir, f'combined_colours_CV_mean_{safe_name}.{ext}'),
                        bbox_inches='tight', dpi=300)
        plt.close(fig)

In [3]:


def plot_combined_uvlf_mean_only(all_sims_data, redshifts, output_dir, bands=None):
    """Plot mean UV luminosity functions of all simulations, one figure per band.

    Each figure has one panel per entry of ``redshifts``. In every panel each
    simulation contributes its attenuated (solid) and intrinsic (dashed,
    semi-transparent) mean ``phi``, i.e. the average over files (axis 0) of the
    stored ``'phi'`` arrays, restricted to bins where that mean exceeds -5.9.
    Figures are saved to ``output_dir`` as ``combined_uvlf_CV_mean_<band_key>.pdf``
    and ``.png``.

    Parameters
    ----------
    all_sims_data : dict
        ``{simulation: {redshift: {category: {band_key: {'magnitude': array,
        'phi': array}}}}}`` with ``band_key`` = ``get_safe_name(band)``.
    redshifts : sequence of dict
        Each item provides a ``'redshift'`` key; panels follow this order.
    output_dir : str
        Directory the figures are written to (must already exist).
    bands : sequence of str, optional
        Bands to plot. Defaults to ``['GALEX_FUV', 'GALEX_NUV']``.
    """
    if bands is None:
        bands = ['GALEX_FUV', 'GALEX_NUV']
    if not redshifts:
        # plt.subplots cannot build a figure with zero panels.
        return

    for band in bands:
        band_key = get_safe_name(band)
        # squeeze=False always yields a 2-D array, so one redshift needs no special case.
        fig, axes = plt.subplots(1, len(redshifts), figsize=(20, 6), squeeze=False)
        axes = axes[0]

        # First attenuated line per simulation; feeds a single legend that also
        # lists simulations absent from the first panel.
        legend_handles = {}

        for ax, redshift_info in zip(axes, redshifts):
            z = redshift_info['redshift']

            for sim_name, sim_data in all_sims_data.items():
                if z not in sim_data:
                    continue

                color = get_simulation_color(sim_name)
                z_data = sim_data[z]

                for category in ['intrinsic', 'attenuated']:
                    if category not in z_data:
                        continue

                    if band_key not in z_data[category]:
                        print(f"Band {band_key} not found for {sim_name}, z={z}, {category}")
                        continue

                    band_data = z_data[category][band_key]
                    magnitudes = band_data['magnitude']
                    mean_phi = np.mean(band_data['phi'], axis=0)
                    # Bins with mean <= -5.9 are dropped; presumably this is the
                    # empty-bin floor in the input files (TODO confirm).
                    valid_points = mean_phi > -5.9
                    if not np.any(valid_points):
                        continue

                    is_attenuated = category == 'attenuated'
                    label = sim_name if is_attenuated and sim_name not in legend_handles else None
                    line, = ax.plot(magnitudes[valid_points], mean_phi[valid_points],
                                    '-' if is_attenuated else '--', color=color,
                                    alpha=1.0 if is_attenuated else 0.5,
                                    linewidth=2, label=label)
                    if label is not None:
                        legend_handles[sim_name] = line

            ax.set_xlabel('M$_{UV}$ [AB mag]', fontsize=10)
            ax.set_ylim(-5, -1.6)
            ax.set_xlim(-26.5, -14.5)
            ax.grid(True, alpha=0.3)
            ax.text(0.95, 0.05, f'z = {z}', transform=ax.transAxes,
                    fontsize=10, ha='right', va='top')
            ax.tick_params(axis='both', which='major', labelsize=8)

        # Raw string: '\p' in a plain literal is an invalid escape sequence.
        axes[0].set_ylabel(r'log$_{10}$ $\phi$ [Mpc$^{-3}$ mag$^{-1}$]', fontsize=10)
        if legend_handles:
            axes[0].legend(handles=list(legend_handles.values()),
                           loc='upper left', fontsize=9, frameon=True)

        fig.subplots_adjust(wspace=0.1, top=0.9)
        for ext in ('pdf', 'png'):
            fig.savefig(os.path.join(output_dir, f'combined_uvlf_CV_mean_{band_key}.{ext}'),
                        bbox_inches='tight', dpi=300)
        plt.close(fig)



In [4]:

# def plot_redshift_evolution_lf(all_sims_data, redshifts, output_dir, bands=['GALEX_FUV', 'GALEX_NUV']):
#     """
#     Create a multi-panel plot showing redshift evolution of luminosity functions
#     for different bands and simulations, similar to Figure 4 in Lovell et al. 2024.
#     """
#     n_bands = len(bands)
#     n_sims = len(all_sims_data)
    
#     # Create a colormap for redshifts
#     cmap = plt.cm.viridis
#     z_values = [z_info['redshift'] for z_info in redshifts]
#     z_values.sort()  # Ensure they're in ascending order
#     z_norm = plt.Normalize(min(z_values), max(z_values))
    
#     # Create figure with a grid of panels
#     fig, axes = plt.subplots(n_bands, n_sims, figsize=(4*n_sims, 3*n_bands), 
#                            sharex='col', sharey='row')
    
#     # Ensure axes is 2D even with a single band or simulation
#     if n_bands == 1 and n_sims == 1:
#         axes = np.array([[axes]])
#     elif n_bands == 1:
#         axes = np.array([axes])
#     elif n_sims == 1:
#         axes = np.array([[ax] for ax in axes])
    
#     # Band labels for the right side of the plot
#     band_labels = {
#         'GALEX_FUV': 'GALEX FUV',
#         'GALEX_NUV': 'GALEX NUV',
#         'SDSS_g': 'SDSS g',
#         'SDSS_i': 'SDSS i',
#         'UKIRT_K': 'UKIRT K'
#     }
    
#     # Loop through each simulation and band
#     for sim_idx, (sim_name, sim_data) in enumerate(all_sims_data.items()):
#         for band_idx, band in enumerate(bands):
#             band_key = get_safe_name(band)
#             ax = axes[band_idx, sim_idx]
            
#             # Set title for top row only
#             if band_idx == 0:
#                 ax.set_title(sim_name, fontsize=14)


#             # Adjust the y-labels spacing
#             if sim_idx == 0:
#                 ax.set_ylabel(r'$\phi$ / Mpc$^{-3}$ dex$^{-1}$', fontsize=12, labelpad=10)  # Add labelpad
            
#             # Add band label with more spacing
#             if sim_idx == n_sims-1:
#                 ax.text(1.10, 0.5, band_labels.get(band, band),  # Increase from 1.02 to 1.10
#                        transform=ax.transAxes, rotation=270, 
#                        fontsize=12, va='center')
            
 
#             # Set y-label for leftmost column only
#             if sim_idx == 0:
#                 ax.set_ylabel(r'$\phi$ / Mpc$^{-3}$ dex$^{-1}$', fontsize=12)
            
#             # Set x-label for bottom row only
#             if band_idx == n_bands-1:
#                 ax.set_xlabel(r'$M_{\mathrm{AB}}$', fontsize=12)
            
#             # Loop through redshifts in ascending order
#             for z_info in sorted(redshifts, key=lambda x: x['redshift']):
#                 z = z_info['redshift']
                
#                 if z not in sim_data:
#                     continue
                
#                 z_data = sim_data[z]
                
#                 # Plot both intrinsic and attenuated data
#                 for category in ['intrinsic', 'attenuated']:
#                     if category not in z_data:
#                         continue
                    
#                     # Check if we have data for this specific band
#                     if band_key not in z_data[category]:
#                         continue
                    
#                     # Get band-specific data
#                     band_data = z_data[category][band_key]
#                     magnitudes = band_data['magnitude']
#                     phi_arrays = band_data['phi']
                    
#                     # Calculate mean phi
#                     mean_phi = np.mean(phi_arrays, axis=0)
                    
#                     # Plot with color based on redshift
#                     color = cmap(z_norm(z))
#                     linestyle = '--' if category == 'intrinsic' else '-'
#                     alpha = 0.3 if category == 'intrinsic' else 0.7
#                     ax.plot(magnitudes, mean_phi, linestyle, color=color, alpha=alpha, 
#                            linewidth=1.5 if category == 'attenuated' else 1)
            
#             # Set axis limits
#             ax.set_ylim(-5.0, -2.0)
#             ax.set_xlim(-27, -15)
            
#             # Add band label to right side of rightmost panels
#             if sim_idx == n_sims-1:
#                 ax.text(1.02, 0.5, band_labels.get(band, band), 
#                        transform=ax.transAxes, rotation=270, 
#                        fontsize=12, va='center')
            
#             # Add grid
#             ax.grid(True, alpha=0.3)
    
#     # Add colorbar for redshift
#     sm = plt.cm.ScalarMappable(cmap=cmap, norm=z_norm)
#     sm.set_array([])

#     # Use fig.colorbar instead of manually defining axes
#     cbar = fig.colorbar(sm, ax=axes, orientation='horizontal', fraction=0.05, pad=0.25)
#     cbar.set_label('Redshift', fontsize=12)

#     # Adjust layout for better spacing
#     plt.subplots_adjust(right=0.88, hspace=0.3, wspace=0.2, bottom=0.25)  # Increase bottom to 0.15 for spacing


#     # # Add colorbar at the bottom center with fixed width
#     # # This uses figure-relative coordinates
#     # fig_width = 0.5  # Width of colorbar as fraction of figure width
#     # fig_bottom = 0.07  # Distance from bottom of figure
#     # fig_height = 0.02  # Height of colorbar
#     # # Center the colorbar at the bottom of the figure
#     # cbar_width = 0.5 * fig_width  # Make it 50% of figure width
#     # cbar_height = 0.015
#     # cbar_left = (fig_width - cbar_width) / (2 * fig_width)  # Center it horizontally
    
#     # # Add thin colorbar at bottom
    
#     # # Center the colorbar
#     # cbar_ax = fig.add_axes([0.5 - fig_width/2, fig_bottom, fig_width, -1])
#     # cbar = fig.colorbar(sm, cax=cbar_ax, orientation='horizontal')
#     # cbar.set_label('Redshift', fontsize=12)

#    # # Reposition colorbar with more space
#    #  cbar_ax = fig.add_axes([0.15, 0.03, 0.3, 0.015])  # Increase bottom value from 0.00 to 0.03

#     # cbar_ax = fig.add_axes([cbar_centre, 0.03, cbar_width/fig_width, cbar_height])
#     # cbar = fig.colorbar(sm, cax=cbar_ax, orientation='horizontal')
#     # cbar.set_label('Redshift', fontsize=12)

#     # # Center the colorbar
#     # cbar_ax = fig.add_axes([0.5 - fig_width/2, fig_bottom, fig_width, fig_height])
#     # cbar = fig.colorbar(sm, cax=cbar_ax, orientation='horizontal')
#     # cbar.set_label('Redshift', fontsize=12)


#     # Adjust layout
#     # plt.subplots_adjust(right=0.92, hspace=0.3, wspace=0.2, bottom=0.1)
#     # Improve overall spacing
#     # plt.subplots_adjust(right=0.88, hspace=0.3, wspace=0.2, bottom=0.12)  # Reduce right from 0.92 to 0.88

#     # Save figure
#     pdf_path = os.path.join(output_dir, 'redshift_evolution_lf.pdf')
#     png_path = os.path.join(output_dir, 'redshift_evolution_lf.png')
    
#     # plt.savefig(pdf_path, dpi=300, bbox_inches='tight')
#     # plt.savefig(png_path, dpi=300, bbox_inches='tight')
#     plt.savefig(pdf_path, dpi=300, bbox_inches='tight', pad_inches=0.3)
#     plt.savefig(png_path, dpi=300, bbox_inches='tight', pad_inches=0.3)
#     plt.close()

In [21]:
import matplotlib.pyplot as plt
import numpy as np
import os

def plot_redshift_evolution_lf(all_sims_data, redshifts, output_dir, bands=None):
    """Grid of mean luminosity functions coloured by redshift (cf. Lovell et al. 2024, Fig. 4).

    Rows are bands and columns are simulations. In each panel every redshift
    contributes its attenuated (solid) and intrinsic (dashed, fainter) mean
    ``phi``, i.e. the average over files (axis 0) of the stored ``'phi'`` arrays.
    Curves are coloured with viridis by redshift, and a shared horizontal
    colourbar sits at the bottom. Saved to ``output_dir`` as
    ``redshift_evolution_lf.pdf`` and ``.png``.

    Parameters
    ----------
    all_sims_data : dict
        ``{simulation: {redshift: {category: {band_key: {'magnitude': array,
        'phi': array}}}}}`` with ``band_key`` = ``get_safe_name(band)``.
    redshifts : sequence of dict
        Each item provides a ``'redshift'`` key; curves are drawn in ascending redshift.
    output_dir : str
        Directory the figures are written to (must already exist).
    bands : sequence of str, optional
        Bands, one row each. Defaults to ``['GALEX_FUV', 'GALEX_NUV']``.

    Returns None without plotting when there are no simulations, redshifts or bands.
    """
    if bands is None:
        bands = ['GALEX_FUV', 'GALEX_NUV']
    if not all_sims_data or not redshifts or not bands:
        print("plot_redshift_evolution_lf: nothing to plot (no simulations, redshifts or bands)")
        return

    n_bands = len(bands)
    n_sims = len(all_sims_data)

    # Sort once: used for the colour normalisation and for the drawing order.
    redshifts_sorted = sorted(redshifts, key=lambda x: x['redshift'])
    cmap = plt.cm.viridis
    z_norm = plt.Normalize(redshifts_sorted[0]['redshift'], redshifts_sorted[-1]['redshift'])

    # squeeze=False always yields a 2-D (n_bands, n_sims) array of axes.
    fig, axes = plt.subplots(n_bands, n_sims, figsize=(4*n_sims, 3*n_bands),
                             sharex='col', sharey='row', squeeze=False)

    # Band labels for the right side of the plot
    band_labels = {
        'GALEX_FUV': 'GALEX FUV',
        'GALEX_NUV': 'GALEX NUV'
    }

    for sim_idx, (sim_name, sim_data) in enumerate(all_sims_data.items()):
        for band_idx, band in enumerate(bands):
            # The loader stores bands under get_safe_name(band), as the other UVLF plot expects.
            band_key = get_safe_name(band)
            ax = axes[band_idx, sim_idx]

            if band_idx == 0:
                ax.set_title(sim_name, fontsize=14)

            # Same quantity the combined UVLF plot labels as log10 phi per magnitude.
            if sim_idx == 0:
                ax.set_ylabel(r'log$_{10}$ $\phi$ / Mpc$^{-3}$ mag$^{-1}$', fontsize=12)

            if band_idx == n_bands - 1:
                ax.set_xlabel(r'$M_{\mathrm{AB}}$', fontsize=12)

            for z_info in redshifts_sorted:
                z = z_info['redshift']
                if z not in sim_data:
                    continue

                z_data = sim_data[z]
                color = cmap(z_norm(z))

                for category in ['intrinsic', 'attenuated']:
                    if category not in z_data or band_key not in z_data[category]:
                        continue

                    band_data = z_data[category][band_key]
                    mean_phi = np.mean(band_data['phi'], axis=0)

                    is_attenuated = category == 'attenuated'
                    ax.plot(band_data['magnitude'], mean_phi,
                            '-' if is_attenuated else '--', color=color,
                            alpha=0.7 if is_attenuated else 0.3,
                            linewidth=1.5 if is_attenuated else 1)

            ax.set_ylim(-5.0, -2.0)
            ax.set_xlim(-27, -15)

            # Band name written vertically to the right of the last column.
            if sim_idx == n_sims - 1:
                ax.text(1.05, 0.5, band_labels.get(band, band),
                        transform=ax.transAxes, rotation=270,
                        fontsize=12, va='center')

            ax.grid(True, alpha=0.3)

    # Shared redshift colourbar in a dedicated axes below the grid (figure coordinates).
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=z_norm)
    sm.set_array([])
    cbar_ax = fig.add_axes([0.3, 0.025, 0.4, 0.015])
    cbar = fig.colorbar(sm, cax=cbar_ax, orientation='horizontal')
    cbar.set_label('Redshift', fontsize=12)

    fig.subplots_adjust(right=0.88, bottom=0.13, hspace=0.1, wspace=0.1)

    for ext in ('pdf', 'png'):
        fig.savefig(os.path.join(output_dir, f'redshift_evolution_lf.{ext}'),
                    dpi=300, bbox_inches='tight', pad_inches=0.1)
    plt.close(fig)


In [22]:


def main():
    """Load the CV-set UVLF and colour data for all simulations and write the selected plots.

    The ``MAKE_*`` flags choose which figure groups are produced. Output goes to
    ``combined/`` and ``evolution/`` sub-directories of the plots directory, which
    are created if missing. The plots directory defaults to the CAMELS project
    path and can be overridden with the ``CAMELS_CV_PLOTS_DIR`` environment variable.
    """
    simulations = ["IllustrisTNG", "SIMBA", "Astrid", "Swift-EAGLE"]
    # Redshifts are read from the first simulation's config only.
    # NOTE(review): assumes every simulation shares the same snapshots; confirm.
    config = get_config(dataset="CV", simulation=simulations[0])
    base_plots_dir = os.environ.get(
        "CAMELS_CV_PLOTS_DIR", "/home/jovyan/camels/proj1/CV_set/CV_outputs/plots"
    )

    # Choose which figure groups to produce
    MAKE_INDIVIDUAL_PLOTS = False
    MAKE_COMBINED_PLOTS = False
    MAKE_REDSHIFT_EVOLUTION = True

    # Shared by all plotting functions
    bands = ['GALEX_FUV', 'GALEX_NUV']
    colour_pairs = [('GALEX FUV', 'GALEX NUV')]

    uvlf_data, colour_data = load_simulation_data(simulations, config)
    # One ascending-redshift list reused by every plotting call (none of them mutate it).
    redshifts = sorted(config["redshift_values"].values(), key=lambda x: x['redshift'])

    # Summary of which colour keys / redshifts were loaded per simulation
    print("Inspecting color data structure:")
    for sim_name, sim_data in colour_data.items():
        print(f"Simulation: {sim_name}")
        print(f"  Keys: {list(sim_data.keys())}")
        for color_key in sim_data:
            print(f"  Color: {color_key}")
            print(f"    Redshifts: {list(sim_data[color_key].keys())}")

    if MAKE_INDIVIDUAL_PLOTS:
        # NOTE(review): create_individual_plots is not defined in the cells shown;
        # it must be provided by another cell/module before enabling this flag.
        create_individual_plots(simulations, config)

    if MAKE_COMBINED_PLOTS:
        combined_plot_dir = os.path.join(base_plots_dir, "combined")
        os.makedirs(combined_plot_dir, exist_ok=True)

        # One figure per band, then one per colour pair
        plot_combined_uvlf_mean_only(uvlf_data, redshifts, combined_plot_dir, bands=bands)
        plot_combined_colours_mean_only(colour_data, redshifts, combined_plot_dir,
                                        colour_pairs=colour_pairs)

    if MAKE_REDSHIFT_EVOLUTION:
        evolution_plot_dir = os.path.join(base_plots_dir, "evolution")
        os.makedirs(evolution_plot_dir, exist_ok=True)

        # Redshift evolution grid (Figure 4 style)
        plot_redshift_evolution_lf(uvlf_data, redshifts, evolution_plot_dir, bands=bands)


if __name__ == "__main__":
    main()

Inspecting color data structure:
Simulation: IllustrisTNG
  Keys: ['GALEX_FUV-GALEX_NUV']
  Color: GALEX_FUV-GALEX_NUV
    Redshifts: [2.0, 1.48, 1.05, 0.1]
Simulation: SIMBA
  Keys: ['GALEX_FUV-GALEX_NUV']
  Color: GALEX_FUV-GALEX_NUV
    Redshifts: [2.0, 1.48, 1.05, 0.1]
Simulation: Astrid
  Keys: ['GALEX_FUV-GALEX_NUV']
  Color: GALEX_FUV-GALEX_NUV
    Redshifts: [2.0, 1.48, 1.05, 0.1]
Simulation: Swift-EAGLE
  Keys: ['GALEX_FUV-GALEX_NUV']
  Color: GALEX_FUV-GALEX_NUV
    Redshifts: [2.0, 1.48, 1.05, 0.1]
