# RDME/CME_compare
This is the code used for extracting the data from RDME-ODE simulation results and comparing various trajectories of species with and without ER.

In [1]:
%run env.ipynb
import pickle
import os
import numpy as np
from jLM.RDME import File as RDMEFile
import jLM
import json
import matplotlib.pyplot as plt
import seaborn as sns
from traj_analysis_rdme import *
from tqdm import tqdm
import pandas as pd
import logging
from pyLM import *
from pyLM.units import *
from pySTDLM import *
from pySTDLM import PostProcessing
cme_traj_dir = "/data2/2024_Yeast_GS/my_current_code/my_cme_ode/output/02102025/"
rdme_traj_dir = "/data2/2024_Yeast_GS/my_current_code/rdme_ode_results/20250120_11_1_60_Normal_newtomo"
fig_dir = os.path.join(rdme_traj_dir, 'figures_rdmecme_comparison/')


if not os.path.exists(fig_dir):
    os.makedirs(fig_dir)
# Configure logging
log_file = os.path.join(fig_dir, 'run_log.log')
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ]
)


logging.info(f"This is the file to compare between RDME-ODE and CME-ODE data: {rdme_traj_dir} and {cme_traj_dir}")



2025-02-27 19:39:31,862 - INFO - This is the file to compare between RDME-ODE and CME-ODE data: /data2/2024_Yeast_GS/my_current_code/rdme_ode_results/20250120_11_1_60_Normal_newtomo and /data2/2024_Yeast_GS/my_current_code/my_cme_ode/output/02102025/


default, get data required

In [2]:


rdme_files = [f for f in os.listdir(rdme_traj_dir) if f.startswith('yeast') and f.endswith('.lm')]
cme_files = [f for f in os.listdir(cme_traj_dir) if f.startswith('gal') and f.endswith('.lm')]
traj_suff = "_ode.jsonl"

logging.info(f"RDME-ODE files: {rdme_files}")
logging.info(f"CME-ODE files: {cme_files}")

# Initialize dictionaries to store data for each species
rdme_species_data = {}
rdme_ode_data = {}
rdmeTs = None
odeTs = None
cmeTs = None

# Process RDME files
for traj_file in tqdm(rdme_files, desc="Processing RDME files", unit="file"):
    logging.info(f"Processing ER file: {traj_file}")
    traj, odeTraj, region_traj = get_traj(rdme_traj_dir, traj_file, traj_suff)
    curr_rdmeTs, rdmeYs, curr_odeTs, odeYs, _, _ = get_data_for_plot(traj, odeTraj, region_traj=None, sparse_factor=1)
    NAV = 6.022e23 * (traj.reg.cytoplasm.volume + traj.reg.nucleoplasm.volume + traj.reg.plasmaMembrane.volume)
    if rdmeTs is None:
        rdmeTs = curr_rdmeTs
        odeTs = curr_odeTs

    for species, data in rdmeYs.items():
        if species not in rdme_species_data:
            rdme_species_data[species] = []
        rdme_species_data[species].append(data)

    for species, data in odeYs.items():
        if species not in rdme_ode_data:
            rdme_ode_data[species] = []
        rdme_ode_data[species].append(data)

# Calculate and save RDME statistics
rdme_results = []
for species, trajectories in rdme_species_data.items():
    trajectories_array = np.array(trajectories)
    avg = np.mean(trajectories_array, axis=0)
    std = np.std(trajectories_array, axis=0)
    
    rdme_results.append({
        'Species': f"RDME_{species}",
        'Time': ','.join(map(str, rdmeTs)),
        'Average': ','.join(map(str, avg)),
        'Std': ','.join(map(str, std))
    })

for species, trajectories in rdme_ode_data.items():
    trajectories_array = np.array(trajectories)
    avg = np.mean(trajectories_array, axis=0)
    std = np.std(trajectories_array, axis=0)
    
    rdme_results.append({
        'Species': f"ODE_{species}",
        'Time': ','.join(map(str, odeTs)),
        'Average': ','.join(map(str, avg)),
        'Std': ','.join(map(str, std))
    })


2025-02-27 19:39:31,887 - INFO - RDME-ODE files: ['yeast1.13.1mt_multi_20250105_2_t60.0minGAE11.1mMdt_4gpu_gpu4.lm', 'yeast1.13.1mt_multi_20250115_3_t60.0minGAE11.1mMdt_4gpu_gpu4.lm', 'yeast1.13.1mt_multi_20241229_1_t60.0minGAE11.1mMdt_4gpu_gpu4.lm']
2025-02-27 19:39:31,887 - INFO - CME-ODE files: ['gal_cme_ode_gaecomparison_11.1_gia0_rep50_delta1_time60.lm']
Processing RDME files:   0%|          | 0/3 [00:00<?, ?file/s]2025-02-27 19:39:31,890 - INFO - Processing ER file: yeast1.13.1mt_multi_20250105_2_t60.0minGAE11.1mMdt_4gpu_gpu4.lm
Processing RDME files:  33%|███▎      | 1/3 [00:01<00:03,  1.72s/file]2025-02-27 19:39:33,609 - INFO - Processing ER file: yeast1.13.1mt_multi_20250115_3_t60.0minGAE11.1mMdt_4gpu_gpu4.lm
Processing RDME files:  67%|██████▋   | 2/3 [00:03<00:01,  1.71s/file]2025-02-27 19:39:35,316 - INFO - Processing ER file: yeast1.13.1mt_multi_20241229_1_t60.0minGAE11.1mMdt_4gpu_gpu4.lm
Processing RDME files: 100%|██████████| 3/3 [00:05<00:00,  1.71s/file]


In [3]:
# Calculate and save CME statistics
cme_results = []
cme_traj = PostProcessing.openLMFile(os.path.join(cme_traj_dir + cme_files[0]))
cme_species_list = PostProcessing.getSpecies(cme_traj)


# Reorganize the species list based on the given criteria
cme_species_list = sorted(cme_species_list, key=lambda x: (not x[0].startswith('DG'), not (x[0].startswith('R') or x[0] == 'reporter_rna'), not (x[0].startswith('G') and not x[0].startswith('GA')or x[0] == 'reporter'), x[0].startswith('GA')))
GA_species_list = ['GAI']
general_species_list = [species for species in cme_species_list if species[0] not in GA_species_list]

logging.info("CME species list:")
logging.info(general_species_list)
logging.info(GA_species_list)
logging.info(f"total number of species: {len(GA_species_list)} + {len(general_species_list)}")

avg_list_general = []
var_list_general = []
time_list_general = []

avg_list_GA = []
var_list_GA = []
time_list_GA = []

for species in general_species_list:
    avg, var, times = PostProcessing.getAvgVarTrace(cme_traj, species)
    avg_list_general.append(avg)
    var_list_general.append(np.sqrt(var))
    time_list_general.append(times)
# this unit conversion somehow not working for GAI
if len(GA_species_list) == 1:
    species = GA_species_list[0]
    avg, var, times = PostProcessing.getAvgVarTrace(cme_traj, species)
    count2concentration = 4.65e-8  #molecule/cell to mM
    avg_list_GA.append(avg*count2concentration)
    var_list_GA.append(np.sqrt(var)*count2concentration)
    time_list_GA.append(times)
else:
    for species in GA_species_list:
        avg, var, times = PostProcessing.getAvgVarTrace(cme_traj, species)
        count2concentration = 4.65e-8  #molecule/cell to mM
        avg_list_GA.append(avg*count2concentration)
        var_list_GA.append(np.sqrt(var)*count2concentration)
        time_list_GA.append(times)
        
        
for species, avg, std, times in zip(general_species_list + GA_species_list, 
                                    avg_list_general + avg_list_GA, 
                                    var_list_general + var_list_GA, 
                                    time_list_general + time_list_GA):
    species_name = species[0] if isinstance(species, list) else species
    cme_results.append({
                        'Species': species_name,
                        'Time': ','.join(map(str, times)),
                        'Average': ','.join(map(str, avg)),
                        'Std': ','.join(map(str, std))})

2025-02-27 19:39:37,220 - INFO - CME species list:
2025-02-27 19:39:37,221 - INFO - ['R1', 'R2', 'R3', 'R4', 'R80', 'G1', 'G2', 'G3', 'G3i', 'G4', 'G4d', 'G80', 'G80C', 'G80d', 'G80Cd', 'G80G3i', 'GAI', 'G2GAI', 'G2GAE', 'G1GAI', 'reporter_rna', 'reporter', 'DG1', 'DG1_G4d', 'DG1_G4d_G80d', 'DG2', 'DG2_G4d', 'DG2_G4d_G80d', 'DG3', 'DG3_G4d', 'DG3_G4d_G80d', 'DGrep', 'DGrep_G4d', 'DGrep_G4d_G80d', 'DG80', 'DG80_G4d', 'DG80_G4d_G80d']
2025-02-27 19:39:37,222 - INFO - ['GAI']
2025-02-27 19:39:37,222 - INFO - total number of species: 1 + 37
2025-02-27 19:39:37,224 - INFO - names: ['R1', 'R2', 'R3', 'R4', 'reporter_rna', 'R80', 'G1', 'G2', 'G3', 'G3i', 'G4', 'G4d', 'reporter', 'G80', 'G80C', 'G80d', 'G80Cd', 'G80G3i', 'GAI', 'DG1', 'DG1_G4d', 'DG1_G4d_G80d', 'DG2', 'DG2_G4d', 'DG2_G4d_G80d', 'DG3', 'DG3_G4d', 'DG3_G4d_G80d', 'DGrep', 'DGrep_G4d', 'DGrep_G4d_G80d', 'DG80', 'DG80_G4d', 'DG80_G4d_G80d', 'G2GAI', 'G2GAE', 'G1GAI']
2025-02-27 19:39:37,225 - INFO - names: ['R1', 'R2', 'R3', 'R4',

In [4]:

# Save to CSV files
rdme_df = pd.DataFrame(rdme_results)
cme_df = pd.DataFrame(cme_results)

rdme_csv_path = os.path.join(fig_dir, 'rdme_species_statistics.csv')
cme_csv_path = os.path.join(fig_dir, 'cme_species_statistics.csv')

rdme_df.to_csv(rdme_csv_path, index=False)
cme_df.to_csv(cme_csv_path, index=False)

logging.info(f"RDME statistics saved to: {rdme_csv_path}")
logging.info(f"CME statistics saved to: {cme_csv_path}")

2025-02-27 19:39:40,011 - INFO - RDME statistics saved to: /data2/2024_Yeast_GS/my_current_code/rdme_ode_results/20250120_11_1_60_Normal_newtomo/figures_rdmecme_comparison/rdme_species_statistics.csv
2025-02-27 19:39:40,012 - INFO - CME statistics saved to: /data2/2024_Yeast_GS/my_current_code/rdme_ode_results/20250120_11_1_60_Normal_newtomo/figures_rdmecme_comparison/cme_species_statistics.csv


plot comparison graphs, this part can run separately

In [5]:
# Read the saved statistics
rdme_df = pd.read_csv(os.path.join(fig_dir, 'rdme_species_statistics.csv'))
cme_df = pd.read_csv(os.path.join(fig_dir, 'cme_species_statistics.csv'))

# Function to convert string of comma-separated values to numpy array
def str_to_array(s):
    return np.array([float(x) for x in s.split(',')])

# Debug: Print available species
print("Available species in RDME:", rdme_df['Species'].tolist())
print("Available species in CME:", cme_df['Species'].tolist())

# plot based on CME species
cme_species = set(cme_df['Species'].unique())


# Plot settings
plt.style.use('default')
plt.rcParams['figure.figsize'] = [10, 6]
plt.rcParams['figure.dpi'] = 600
plt.rcParams['font.size'] = 18  # Increase base font size
plt.rcParams['axes.titlesize'] = 28  # Increase title font size
plt.rcParams['axes.labelsize'] = 18  # Increase axis label font size
plt.rcParams['xtick.labelsize'] = 18  # Increase tick label font size
plt.rcParams['ytick.labelsize'] = 18  # Increase tick label font size
plt.rcParams['legend.fontsize'] = 18  # Increase legend font size



Available species in RDME: ['RDME_DGrep', 'RDME_DGrep_G4d', 'RDME_DGrep_G4d_G80d', 'RDME_Rrep', 'RDME_Grep', 'RDME_DG1', 'RDME_DG1_G4d', 'RDME_DG1_G4d_G80d', 'RDME_R1', 'RDME_G1', 'RDME_DG2', 'RDME_DG2_G4d', 'RDME_DG2_G4d_G80d', 'RDME_R2', 'RDME_G2', 'RDME_DG3', 'RDME_DG3_G4d', 'RDME_DG3_G4d_G80d', 'RDME_R3', 'RDME_G3', 'RDME_G3i', 'RDME_DG4', 'RDME_R4', 'RDME_G4', 'RDME_G4d', 'RDME_DG80', 'RDME_DG80_G4d', 'RDME_DG80_G4d_G80d', 'RDME_R80', 'RDME_G80', 'RDME_G80d', 'RDME_G80d_G3i', 'RDME_ribosome', 'RDME_ribosomeR1', 'RDME_ribosomeR2', 'RDME_ribosomeR3', 'RDME_ribosomeR4', 'RDME_ribosomeR80', 'RDME_ribosomeGrep', 'ODE_GAI', 'ODE_G1', 'ODE_G1GAI', 'ODE_G2GAI', 'ODE_G2GAE', 'ODE_G2']
Available species in CME: ['R1', 'R2', 'R3', 'R4', 'R80', 'G1', 'G2', 'G3', 'G3i', 'G4', 'G4d', 'G80', 'G80C', 'G80d', 'G80Cd', 'G80G3i', 'GAI', 'G2GAI', 'G2GAE', 'G1GAI', 'reporter_rna', 'reporter', 'DG1', 'DG1_G4d', 'DG1_G4d_G80d', 'DG2', 'DG2_G4d', 'DG2_G4d_G80d', 'DG3', 'DG3_G4d', 'DG3_G4d_G80d', 'DGrep',

In [6]:
# Create plots for each species
for species_name in cme_species:
    fig, ax = plt.subplots()
    
    # Safely get data
    matching_rows = rdme_df[rdme_df['Species'].str.contains(species_name)]

    # If multiple rows are found, prioritize the one starting with 'RDME'
    if not matching_rows.empty:
        rdme_species_data = matching_rows[matching_rows['Species'].str.startswith('RDME')]
        # If no match starts with 'RDME', default to the full list of matches
        if rdme_species_data.empty:
            rdme_species_data = matching_rows
    
    cme_species_data = cme_df[cme_df['Species'] == species_name]
    
    if len(rdme_species_data) == 0 or len(cme_species_data) == 0:
        print(f"Skipping {species_name} - data not found")
        continue
        
    er_data = rdme_species_data.iloc[0]
    noer_data = cme_species_data.iloc[0]
    
    time = str_to_array(er_data['Time'])
    er_avg = str_to_array(er_data['Average'])
    er_std = str_to_array(er_data['Std'])
    noer_avg = str_to_array(noer_data['Average'])
    noer_std = str_to_array(noer_data['Std'])
    
    # Extract the part after underscore for display
    display_name = species_name.split('_', 1)[1] if '_' in species_name else species_name
    # Replace any subsequent underscores with colons
    display_name = display_name.replace('_', ':')
    
    # Plot ER
    ax.plot(time, er_avg, label=f'RDME-ODE', linestyle='-')
    ax.fill_between(time, er_avg - er_std, er_avg + er_std, alpha=0.2)
    
    # Plot NOER
    ax.plot(time, noer_avg, label=f'CME-ODE', linestyle='--')
    ax.fill_between(time, noer_avg - noer_std, noer_avg + noer_std, alpha=0.2)
    
    # Customize plot
    ax.set_xlabel('Time (min)')
    ax.set_ylabel('Counts')
    # ax.set_title(f'{display_name} Comparison')
    ax.legend(framealpha=0.3, loc='upper right')
    ax.grid(False)
    
    # Save figure
    # plt.tight_layout()
    fig_path = os.path.join(fig_dir, f'{species_name}_comparison.png')
    plt.savefig(fig_path, dpi=600, bbox_inches='tight')
    print(f"Saved plot for {display_name}")
    plt.close()

print(f"\nPlots saved in: {fig_dir}")

Saved plot for R1
Saved plot for G4
Saved plot for G4d
Saved plot for G4d
Saved plot for GAI
Saved plot for DG80
Saved plot for R4
Saved plot for DG3
Saved plot for G4d
Saved plot for DG1
Saved plot for R80
Saved plot for G1
Saved plot for reporter
Saved plot for G80C
Saved plot for DGrep
Saved plot for G4d
Saved plot for G4d
Saved plot for G4d:G80d
Saved plot for G80
Saved plot for rna
Saved plot for G4d:G80d
Saved plot for G2
Saved plot for G4d:G80d
Saved plot for G3i
Saved plot for G3
Saved plot for G4d:G80d
Saved plot for G4d
Saved plot for G2GAE
Saved plot for DG2
Saved plot for G80Cd
Saved plot for G4d:G80d
Saved plot for G1GAI
Saved plot for R2
Saved plot for G2GAI
Saved plot for R3
Saved plot for G80G3i
Saved plot for G80d

Plots saved in: /data2/2024_Yeast_GS/my_current_code/rdme_ode_results/20250120_11_1_60_Normal_newtomo/figures_rdmecme_comparison/


This is for G2 total

In [7]:
# Create combined G2 species plot
fig, ax = plt.subplots()

# List of species to combine
g2_species = ['G2', 'G2GAE', 'G2GAI']

# Initialize arrays for RDME and CME data
rdme_combined_avg = None
rdme_combined_var = None
cme_combined_avg = None
cme_combined_var = None
time = None

# For tracking which species are actually used
rdme_species_used = []
cme_species_used = []

# Combine RDME data
for species_name in g2_species:
    matching_rows = rdme_df[rdme_df['Species'].str.contains(species_name)]
    if not matching_rows.empty:
        # Modified to use ODE data instead of RDME data
        rdme_species_data = matching_rows[matching_rows['Species'].str.startswith('ODE')]
        
        if len(rdme_species_data) > 0:
            er_data = rdme_species_data.iloc[0]
            # Track which species are being used
            rdme_species_used.append(er_data['Species'])
            
            curr_avg = str_to_array(er_data['Average'])
            curr_std = str_to_array(er_data['Std'])
            curr_var = curr_std ** 2  # Convert std to variance
            
            if rdme_combined_avg is None:
                time = str_to_array(er_data['Time'])
                rdme_combined_avg = curr_avg
                rdme_combined_var = curr_var
            else:
                rdme_combined_avg += curr_avg
                rdme_combined_var += curr_var  # Variances add for independent variables

# Combine CME data
for species_name in g2_species:
    cme_species_data = cme_df[cme_df['Species'] == species_name]
    
    if len(cme_species_data) > 0:
        noer_data = cme_species_data.iloc[0]
        # Track which species are being used
        cme_species_used.append(noer_data['Species'])
        
        curr_avg = str_to_array(noer_data['Average'])
        curr_std = str_to_array(noer_data['Std'])
        curr_var = curr_std ** 2  # Convert std to variance
        
        if cme_combined_avg is None:
            cme_combined_avg = curr_avg
            cme_combined_var = curr_var
        else:
            cme_combined_avg += curr_avg
            cme_combined_var += curr_var  # Variances add for independent variables

# Print which species were actually used
print("RDME species used in G2 total:", rdme_species_used)
print("CME species used in G2 total:", cme_species_used)

# Convert combined variances back to standard deviations
rdme_combined_std = np.sqrt(rdme_combined_var)
cme_combined_std = np.sqrt(cme_combined_var)

# Plot RDME
ax.plot(time, rdme_combined_avg, label='RDME-ODE', linestyle='-')
ax.fill_between(time, rdme_combined_avg - rdme_combined_std, 
                rdme_combined_avg + rdme_combined_std, alpha=0.2)

# Plot CME
ax.plot(time, cme_combined_avg, label='CME-ODE', linestyle='--')
ax.fill_between(time, cme_combined_avg - cme_combined_std, 
                cme_combined_avg + cme_combined_std, alpha=0.2)

# Customize plot
ax.set_xlabel('Time (min)')
ax.set_ylabel('Counts')
# ax.set_title('Total G2 Species Comparison (G2 + G2GAE + G2GAI)')
ax.legend(framealpha=0.3, loc='upper right')
ax.grid(False)

# Save figure
plt.tight_layout()
fig_path = os.path.join(fig_dir, 'G2_total_comparison.png')
plt.savefig(fig_path, dpi=600, bbox_inches='tight')
print(f"Saved combined G2 total plot")
plt.close()

RDME species used in G2 total: ['ODE_G2GAI', 'ODE_G2GAE', 'ODE_G2GAI']
CME species used in G2 total: ['G2', 'G2GAE', 'G2GAI']
Saved combined G2 total plot


Save GAI total plot 

In [8]:
# Create combined GAI species plot
fig, ax = plt.subplots()

# List of species to combine
gai_species = ['GAI', 'G1GAI', 'G3i', 'G2GAI']

# Initialize arrays for RDME and CME data
rdme_combined_avg = None
rdme_combined_var = None
cme_combined_avg = None
cme_combined_var = None
time = None

# For tracking which species are actually used
rdme_species_used = []
cme_species_used = []

# Combine RDME data
for species_name in gai_species:
    matching_rows = rdme_df[rdme_df['Species'].str.contains(species_name)]
    
    if not matching_rows.empty:
        rdme_species_data = matching_rows[matching_rows['Species'].str.startswith('RDME')]
        if rdme_species_data.empty:
            rdme_species_data = matching_rows
            
        if len(rdme_species_data) > 0:
            er_data = rdme_species_data.iloc[0]
            # Track which species are being used
            rdme_species_used.append(er_data['Species'])
            
            curr_avg = str_to_array(er_data['Average'])
            curr_std = str_to_array(er_data['Std'])
            # Convert counts to mM
            curr_avg = curr_avg / NAV * 1e3  # NAV*1e3 for RDME conversion
            curr_std = curr_std / NAV * 1e3
            curr_var = curr_std ** 2  # Convert std to variance
            
            if rdme_combined_avg is None:
                time = str_to_array(er_data['Time'])
                rdme_combined_avg = curr_avg
                rdme_combined_var = curr_var
            else:
                rdme_combined_avg += curr_avg
                rdme_combined_var += curr_var  # Variances add for independent variables

# Combine CME data
for species_name in gai_species:
    cme_species_data = cme_df[cme_df['Species'] == species_name]
    
    if len(cme_species_data) > 0:
        noer_data = cme_species_data.iloc[0]
        # Track which species are being used
        cme_species_used.append(noer_data['Species'])
        
        curr_avg = str_to_array(noer_data['Average'])
        curr_std = str_to_array(noer_data['Std'])
        # Convert counts to mM if not already converted
        # if species_name != 'GAI':  # GAI is already converted in the CME processing
        count2concentration = 4.65e-8  # molecule/cell to mM
        curr_avg = curr_avg * count2concentration
        curr_std = curr_std * count2concentration
        curr_var = curr_std ** 2  # Convert std to variance
        
        if cme_combined_avg is None:
            cme_time = str_to_array(noer_data['Time'])
            cme_combined_avg = curr_avg
            cme_combined_var = curr_var
        else:
            cme_combined_avg += curr_avg
            cme_combined_var += curr_var  # Variances add for independent variables

# Print which species were actually used
print("RDME species used in GAI total:", rdme_species_used)
print("CME species used in GAI total:", cme_species_used)

# Convert combined variances back to standard deviations
if rdme_combined_var is not None:
    rdme_combined_std = np.sqrt(rdme_combined_var)
if cme_combined_var is not None:
    cme_combined_std = np.sqrt(cme_combined_var)

# Plot RDME if data exists
if rdme_combined_avg is not None and time is not None:
    ax.plot(time, rdme_combined_avg, label='RDME-ODE', linestyle='-')
    ax.fill_between(time, rdme_combined_avg - rdme_combined_std, 
                    rdme_combined_avg + rdme_combined_std, alpha=0.2)

# Plot CME if data exists
if cme_combined_avg is not None and 'cme_time' in locals():
    ax.plot(cme_time, cme_combined_avg, label='CME-ODE', linestyle='--')
    ax.fill_between(cme_time, cme_combined_avg - cme_combined_std, 
                    cme_combined_avg + cme_combined_std, alpha=0.2)

# Add horizontal line for GAE = 11.1mM with a more fitting color
ax.axhline(y=11.1, color='gray', linestyle='-.', linewidth=2, label='GAE')
ax.text(time[0]*1.05, 10.8, '11.1 mM', color='gray', fontsize=16, va='top', ha='left')
# Customize plot
ax.set_xlabel('Time (min)')
ax.set_ylabel('Concentration (mM)')
# ax.set_title('Total GAI Species Comparison (GAI + G1GAI + G3i + G2GAI)')
ax.legend(framealpha=0.3, loc='upper right')
ax.grid(False)

# Save figure
# plt.tight_layout()
fig_path = os.path.join(fig_dir, 'GAI_total_comparison.png')
plt.savefig(fig_path, dpi=600, bbox_inches='tight')
print(f"Saved combined GAI total plot")
plt.close()

RDME species used in GAI total: ['ODE_GAI', 'ODE_G1GAI', 'RDME_G3i', 'ODE_G2GAI']
CME species used in GAI total: ['GAI', 'G1GAI', 'G3i', 'G2GAI']
Saved combined GAI total plot
