In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from tqdm.auto import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
from skimpy.utils.tabdict import TabDict
import os
import configparser
from pytfa.io.json import load_json_model
from skimpy.io.yaml import load_yaml_model
from skimpy.analysis.oracle.minimum_fluxes import MinFLuxVariable
import numpy as np
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../utils")))
from joblib import Parallel, delayed
import pickle

UPPER_IX = 1000
PHYSIOLOGY = 'WT'

# Read configuration from config.ini. The path is relative, so it is resolved
# against the current working directory of the kernel.
config = configparser.ConfigParser()
config_path = '../src/config.ini'
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it did parse; fail loudly here instead of with a KeyError below.
if not config.read(config_path):
    raise FileNotFoundError(
        f"Could not read configuration file: {os.path.abspath(config_path)}"
    )

# Path to data and model from config, using base_dir
base_dir = config['paths']['base_dir']

# Physiology-specific inputs (keys are suffixed with PHYSIOLOGY)
path_to_kmodel = os.path.abspath(os.path.join(base_dir, config['paths'][f'path_to_kmodel_{PHYSIOLOGY}']))
path_to_tmodel = os.path.abspath(os.path.join(base_dir, config['paths'][f'path_to_tmodel_{PHYSIOLOGY}']))
path_to_samples = os.path.abspath(os.path.join(base_dir, config['paths'][f'path_to_samples_{PHYSIOLOGY}']))
path_to_fcc = os.path.abspath(os.path.join(base_dir, config['paths'][f'path_to_fcc_{PHYSIOLOGY}']))
# Path templates for the per-metabolite control-coefficient tables of both
# physiologies; they are filled in later with str.format(name)
path_to_fcc_df_WT = os.path.abspath(os.path.join(base_dir, config['paths']['path_to_cc_df_WT']))
path_to_fcc_df_MUT = os.path.abspath(os.path.join(base_dir, config['paths']['path_to_cc_df_MUT']))

In [None]:
# Load the pytfa model and the table of samples (first CSV column is the index)
print('Loading TFA model from:', path_to_tmodel)
tmodel = load_json_model(path_to_tmodel)
samples = pd.read_csv(path_to_samples, index_col=0)

# Map every biomass building block (BBB) to the ids of its producing reactions
bbb_producing_reactions = {}
for bbb in tmodel.reactions.biomass.reactants:
    # A BBB whose id ends in '_n' is replaced by the metabolite with the same
    # base id and the '_c' suffix
    if bbb.id.endswith('_n'):
        print(f'Converting {bbb.id} to cytosolic metabolite.')
        bbb = tmodel.metabolites.get_by_id(bbb.id[:-2] + '_c')
    # Water is not tracked
    if bbb.id not in ['h2o_c']:
        # Keep the reactions whose lower bound and stoichiometric coefficient
        # for this BBB are both nonzero and of the same sign
        reactions = [
            r.id
            for r in bbb.reactions
            if r.lower_bound*r.metabolites[bbb] > 0
        ]
        bbb_producing_reactions[bbb.id] = reactions

In [None]:
def get_fccs_for_reactions(ss, rxns, path = path_to_fcc):
    '''Collect the FCC slices of the given reactions for one steady state.

    The pickled FCC files are addressed through ``path.format(ss, i, i + 9)``
    for i = 0, 10, ..., 90. Files that cannot be opened are skipped, and so
    are reactions for which ``slice_by('flux', rxn)`` raises ``KeyError``.

    Parameters
    ----------
    ss : steady-state identifier; formatted into ``path`` and used as the
        ``'<ss>,'`` prefix of every column label.
    rxns : iterable of reaction ids to slice out.
    path : format string with three positional fields (ss, first, last).

    Returns
    -------
    dict mapping each reaction id that yielded at least one slice to the
    column-wise concatenation of its slices.
    '''
    collected = {}

    for first in range(0, 100, 10):
        try:
            # NOTE: pickle.load can execute arbitrary code; only point `path`
            # at files you trust.
            with open(path.format(ss, first, first + 9), 'rb') as handle:
                fcc = pickle.load(handle)
            for rxn in rxns:
                try:
                    rxn_slice = fcc.slice_by('flux', rxn)
                    rxn_slice.columns = [f'{ss},' + str(col) for col in rxn_slice.columns]
                    collected.setdefault(rxn, []).append(rxn_slice)
                except KeyError:
                    # This reaction may not be present
                    pass
        except OSError:
            # FileNotFoundError is an OSError: missing/unreadable file, move on
            pass

    # Concatenate all slices per reaction
    merged = {}
    for rxn, frames in collected.items():
        merged[rxn] = pd.concat(frames, axis=1) if frames else pd.DataFrame()
    return merged

def parallel_get_fccs(max_ss=600, rxns=['biomass'], n_jobs=100):
    '''Load the FCCs of ``rxns`` for steady states 0 .. max_ss - 1 in parallel.

    Each steady state is handled by ``get_fccs_for_reactions`` in a joblib
    worker; the non-empty per-state frames are then joined column-wise.

    Parameters
    ----------
    max_ss : number of steady states to load (``range(max_ss)``).
    rxns : reaction ids passed on to ``get_fccs_for_reactions``.
    n_jobs : number of joblib workers.

    Returns
    -------
    dict mapping each reaction id with at least one non-empty frame to the
    column-wise concatenation of its frames over all steady states.
    '''
    jobs = (
        delayed(get_fccs_for_reactions)(ss, rxns)
        for ss in tqdm(range(max_ss), desc='Loading FCCs')
    )
    per_state_results = Parallel(n_jobs=n_jobs)(jobs)

    # Combine results for each reaction, ignoring empty frames
    gathered = {}
    for state_result in per_state_results:
        for rxn, df in state_result.items():
            if df.empty:
                continue
            gathered.setdefault(rxn, []).append(df)

    combined = {}
    for rxn, frames in gathered.items():
        combined[rxn] = pd.concat(frames, axis=1) if frames else pd.DataFrame()
    return combined

def process_reactions(met, rxn_ids, max_ss=700, remove_outliers=True, path_to_save = path_to_fcc_df_WT):
    '''Compute and save the flux-weighted average FCCs of the reactions producing ``met``.

    For every FCC column (labelled ``'<ss>,<col>'``) the saved value is
    ``sum_r(w_r * fcc_r) / sum_r(w_r)`` over the reactions ``r`` with FCC data,
    where ``w_r`` is ``samples[r.id] - samples[r.reverse_id]`` at row ``ss``.
    Columns whose weights sum to zero are written as zeros.

    Parameters
    ----------
    met : metabolite id; the output is named ``met + '_producing_reactions'``.
    rxn_ids : iterable of reaction ids producing ``met``.
    max_ss : number of steady states forwarded to ``parallel_get_fccs``.
    remove_outliers : if True, each reaction's FCC frame is passed through
        ``remove_outliers_parallel`` first.
    path_to_save : format string with one field for the output name.

    Returns
    -------
    None. The result is written as CSV to ``path_to_save.format(...)``. Nothing
    is done if that file already exists. Any exception is caught and reported
    with ``print`` so that a batch over many metabolites keeps running.
    '''
    rxn_ids = list(rxn_ids)
    name_for_saving = met+'_producing_reactions'
    output_file = path_to_save.format(name_for_saving)
    if os.path.exists(output_file):
        print(f"Skipping {met} as it already exists.")
        return

    try:
        all_fccs = parallel_get_fccs(max_ss=max_ss, rxns=rxn_ids)
        print('Finished loading FCCs')
        weights = []
        fccs = []
        for rxn, df in all_fccs.items():
            print(rxn)
            if df.empty:
                print(f"No FCC data found for {rxn}")
                continue
            # Remove outliers
            # NOTE(review): remove_outliers_parallel is not defined or imported in
            # the visible part of this notebook (presumably it lives in ../utils) -
            # confirm it is in scope, otherwise this raises NameError (caught below).
            if remove_outliers:
                df = remove_outliers_parallel(df, n_jobs=100)
            fccs.append(df.reindex(sorted(df.columns), axis=1))

            # Net flux of the reaction in every sample: forward minus reverse variable
            reaction = tmodel.reactions.get_by_id(rxn)
            weights.append(samples.loc[:, reaction.id] - samples.loc[:, reaction.reverse_id])

        # Without any usable frame np.stack would raise an unhelpful
        # "need at least one array" error; report the real cause instead.
        if not fccs:
            print(f"No FCC data found for any reaction producing {met}; nothing saved.")
            return

        # Stack all fccs into a 3D numpy array: shape (n_reactions, n_rows, n_cols)
        # NOTE(review): this assumes every frame has the same shape and column
        # order as the first one - confirm this still holds after outlier removal.
        fcc_array = np.stack([fcc.values for fcc in fccs])

        # Column names and index are taken from the first dataframe
        columns = fccs[0].columns
        index = fccs[0].index
        n_rows = len(index)

        final_data = []
        for col_idx, col_name in enumerate(columns):
            # Column labels look like '<ss>,<col>'; ss selects the sample row
            ss = int(col_name.split(',')[0])
            col_values = fcc_array[:, :, col_idx]  # shape: (n_reactions, n_rows)

            # Weight of every reaction at this ss
            w = np.array([weight[ss] for weight in weights])  # shape: (n_reactions,)

            numerator = np.tensordot(w, col_values, axes=(0, 0))  # shape: (n_rows,)
            denominator = w.sum()

            if denominator != 0:
                final_data.append(numerator / denominator)
            else:
                # Avoid dividing by zero when the weights cancel out
                final_data.append(np.zeros(n_rows))

        # Stack column-wise and convert to DataFrame
        final_fcc = pd.DataFrame(np.column_stack(final_data), columns=columns, index=index)

        # Save the final FCCs for this metabolite
        final_fcc.to_csv(output_file)
        print(f"Finished processing {met} with {len(rxn_ids)} reactions.")

    except Exception as e:
        # Deliberate best-effort: report and let the caller move on
        print(f"Failed for reactions {rxn_ids}: {e}")


In [None]:
# Average FCC on the synthesis rate of every BBB (wild type): one column per
# BBB, one row per row of the per-metabolite CSV.
average_fccs_WT = pd.DataFrame(columns=bbb_producing_reactions.keys())
from tqdm import tqdm
for met in bbb_producing_reactions:
    name_for_saving = 'synthesis_rate_' + met
    filename = path_to_fcc_df_WT.format(name_for_saving)

    # First, get column names to define dtypes
    total_flux_df_cols = pd.read_csv(filename, index_col=0, nrows=0).columns
    dtype_dict = {col: 'float32' for col in total_flux_df_cols}
    dtype_dict['model_ix'] = 'string'

    # Count total number of lines for progress bar; the context manager makes
    # sure the file handle is closed again
    with open(filename) as csv_file:
        total_lines = sum(1 for _ in csv_file) - 1  # Subtract header
    chunksize = 1000
    # Ceiling division: exact number of chunks, also when total_lines is a
    # multiple of chunksize
    n_chunks = -(-total_lines // chunksize)

    # Read in chunks using index_col=0 to preserve the original index
    chunk_iter = pd.read_csv(filename, chunksize=chunksize, index_col=0, dtype=dtype_dict)

    df_chunks = []
    for chunk in tqdm(chunk_iter, total=n_chunks, desc="Loading CSV"):
        df_chunks.append(chunk)

    # Concatenate chunks WITHOUT ignore_index so original index is preserved
    total_flux_df = pd.concat(df_chunks)

    # Keep the average FCCs (mean over the columns of the table)
    average_fccs_WT[met] = total_flux_df.mean(axis=1)

In [None]:
# Average FCC on the synthesis rate of every BBB (mutant): one column per BBB,
# one row per row of the per-metabolite CSV.
average_fccs_MUT = pd.DataFrame(columns=bbb_producing_reactions.keys())
from tqdm import tqdm
for met in bbb_producing_reactions:
    name_for_saving = 'synthesis_rate_' + met
    filename = path_to_fcc_df_MUT.format(name_for_saving)

    # First, get column names to define dtypes
    total_flux_df_cols = pd.read_csv(filename, index_col=0, nrows=0).columns
    dtype_dict = {col: 'float32' for col in total_flux_df_cols}
    dtype_dict['model_ix'] = 'string'

    # Count total number of lines for progress bar; the context manager makes
    # sure the file handle is closed again
    with open(filename) as csv_file:
        total_lines = sum(1 for _ in csv_file) - 1  # Subtract header
    chunksize = 1000
    # Ceiling division: exact number of chunks, also when total_lines is a
    # multiple of chunksize
    n_chunks = -(-total_lines // chunksize)

    # Read in chunks using index_col=0 to preserve the original index
    chunk_iter = pd.read_csv(filename, chunksize=chunksize, index_col=0, dtype=dtype_dict)

    df_chunks = []
    for chunk in tqdm(chunk_iter, total=n_chunks, desc="Loading CSV"):
        df_chunks.append(chunk)

    # Concatenate chunks WITHOUT ignore_index so original index is preserved
    total_flux_df = pd.concat(df_chunks)

    # Keep the average FCCs (mean over the columns of the table)
    average_fccs_MUT[met] = total_flux_df.mean(axis=1)

In [None]:
# Remove the columns of essential amino acids and non essential amino acids
group_df = pd.DataFrame([
    ("his_L_c", "Amino Acid", "Essential"),
    ("ile_L_c", "Amino Acid", "Essential"),
    ("leu_L_c", "Amino Acid", "Essential"),
    ("lys_L_c", "Amino Acid", "Essential"),
    ("met_L_c", "Amino Acid", "Essential"),
    ("phe_L_c", "Amino Acid", "Essential"),
    ("thr_L_c", "Amino Acid", "Essential"),
    ("trp_L_c", "Amino Acid", "Essential"),
    ("val_L_c", "Amino Acid", "Essential"),
    ("ala_L_c", "Amino Acid", "Non-Essential"),
    ("arg_L_c", "Amino Acid", "Non-Essential"),
    ("asn_L_c", "Amino Acid", "Non-Essential"),
    ("asp_L_c", "Amino Acid", "Non-Essential"),
    ("cys_L_c", "Amino Acid", "Non-Essential"),
    ("gln_L_c", "Amino Acid", "Non-Essential"),
    ("glu_L_c", "Amino Acid", "Non-Essential"),
    ("gly_c", "Amino Acid", "Non-Essential"),
    ("pro_L_c", "Amino Acid", "Non-Essential"),
    ("ser_L_c", "Amino Acid", "Non-Essential"),
    ("tyr_L_c", "Amino Acid", "Non-Essential"),
    ("atp_c", "Ribonucleotide", "Nucleotides"),
    ("ctp_c", "Ribonucleotide", "Nucleotides"),
    ("gtp_c", "Ribonucleotide", "Nucleotides"),
    ("utp_c", "Ribonucleotide", "Nucleotides"),
    ("datp_c", "Deoxyribonucleotide", "Deoxynucleotides"),
    ("dctp_c", "Deoxyribonucleotide", "Deoxynucleotides"),
    ("dgtp_c", "Deoxyribonucleotide", "Deoxynucleotides"),
    ("dttp_c", "Deoxyribonucleotide", "Deoxynucleotides"),
    ("chsterol_c", "Lipid", "Lipids"),
    ("clpn_hs_c", "Lipid", "Lipids"),
    ("pail_hs_c", "Lipid", "Lipids"),
    ("pchol_hs_c", "Lipid", "Lipids"),
    ("pe_hs_c", "Lipid", "Lipids"),
    ("pglyc_hs_c", "Lipid", "Lipids"),
    ("ps_hs_c", "Lipid", "Lipids"),
    ("sphmyln_hs_c", "Lipid", "Lipids"),
    ("g6p_c", "Carbohydrate", "")
], columns=["Metabolite", "Category", "Subcategory"])

# All amino-acid columns (essential and non-essential) in one list
amino_acid_columns = group_df.loc[
    group_df['Subcategory'].isin(['Essential', 'Non-Essential']), 'Metabolite'
].tolist()

# errors='ignore' keeps this cell re-runnable: on a second execution the
# columns are already gone, which would otherwise raise a KeyError. It also
# tolerates amino acids that are not among the loaded columns.
average_fccs_MUT = average_fccs_MUT.drop(columns=amino_acid_columns, errors='ignore')
average_fccs_WT = average_fccs_WT.drop(columns=amino_acid_columns, errors='ignore')

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from collections import OrderedDict

# Subset and clean index for average_fccs
mca_targets = [
 'vmax_forward_TRIOK', 'vmax_forward_MI1PP', 'vmax_forward_PPAP', 'vmax_forward_r0301', 'vmax_forward_METAT',
 'vmax_forward_3DSPHR', 'vmax_forward_TMDS', 'vmax_forward_SERPT', 'vmax_forward_HMR_7748', 'vmax_forward_PSP_L',
 'vmax_forward_NADH2_u10mi', 'vmax_forward_DGK1', 'vmax_forward_NTD1', 'vmax_forward_r0354', 'vmax_forward_ADSS',
 'vmax_forward_r0178', 'vmax_forward_IMPD', 'vmax_forward_r0179', 'vmax_forward_PGI', 'vmax_forward_r0474',
 'vmax_forward_ICDHyrm', 'vmax_forward_GMPS2', 'vmax_forward_UMPK2', 'vmax_forward_ICDHxm', 'vmax_forward_URIK1',
 'vmax_forward_CYTK1', 'vmax_forward_HMR_4343', 'vmax_forward_r0426'
]

# Row order used in the figure (target names without the 'vmax_forward_' prefix)
sorted_names = ['TRIOK', 'HMR_7748', 'r0354', 'PGI', 'ICDHyrm', 'ICDHxm', 'r0426', 'PPAP', 'r0301', '3DSPHR', 'SERPT',
                 'METAT', 'PSP_L', 'r0178', 'r0179', 'TMDS', 'NTD1', 'DGK1', 'ADSS', 'IMPD', 'r0474', 'URIK1', 'GMPS2',
                   'HMR_4343', 'CYTK1', 'UMPK2', 'MI1PP', 'NADH2_u10mi']


def _rows_for_heatmap(average_fccs):
    '''Return the MCA target rows of ``average_fccs``, renamed and ordered for plotting.

    Selects the rows listed in ``mca_targets``, strips the 'vmax_forward_'
    prefix from the row labels, orders the rows as in ``sorted_names`` and
    renames 'NADH2_u10mi' to 'MComplex1'.
    '''
    # Index with the list itself: set indexers are deprecated since pandas 1.5
    # and raise a TypeError from pandas 2.0 on. The row order of this
    # intermediate result does not matter, it is fixed by sorted_names below.
    subset = average_fccs.loc[mca_targets]
    subset.index = [i.split('vmax_forward_',)[1] for i in subset.index]
    subset = subset.loc[sorted_names]
    # Replace NADH2_u10mi in the index with MComplex1
    return subset.rename(index={"NADH2_u10mi": "MComplex1"})


data_to_plot_MUT = _rows_for_heatmap(average_fccs_MUT)
data_to_plot_WT = _rows_for_heatmap(average_fccs_WT)

# Biomass building blocks grouped as (category, subcategory, metabolite ids)
_bbb_groups = [
    ("Amino Acid", "Essential", [
        "his_L_c", "ile_L_c", "leu_L_c", "lys_L_c", "met_L_c",
        "phe_L_c", "thr_L_c", "trp_L_c", "val_L_c",
    ]),
    ("Amino Acid", "Non-Essential", [
        "ala_L_c", "arg_L_c", "asn_L_c", "asp_L_c", "cys_L_c", "gln_L_c",
        "glu_L_c", "gly_c", "pro_L_c", "ser_L_c", "tyr_L_c",
    ]),
    ("Ribonucleotide", "Nucleotides", ["atp_c", "ctp_c", "gtp_c", "utp_c"]),
    ("Deoxyribonucleotide", "Deoxynucleotides", ["datp_c", "dctp_c", "dgtp_c", "dttp_c"]),
    ("Lipid", "Lipids", [
        "chsterol_c", "clpn_hs_c", "pail_hs_c", "pchol_hs_c",
        "pe_hs_c", "pglyc_hs_c", "ps_hs_c", "sphmyln_hs_c",
    ]),
    ("Carbohydrate", "", ["g6p_c"]),
]

# Flatten into one row per metabolite: (Metabolite, Category, Subcategory)
group_df = pd.DataFrame(
    [
        (metabolite, category, subcategory)
        for category, subcategory, metabolites in _bbb_groups
        for metabolite in metabolites
    ],
    columns=["Metabolite", "Category", "Subcategory"],
)

# Metabolite id -> subcategory lookup
met_to_subcat = dict(zip(group_df["Metabolite"], group_df["Subcategory"]))
# Only columns with a known subcategory take part in the figure
existing_columns = [col for col in data_to_plot_WT.columns if col in met_to_subcat]

# Left-to-right order of the subcategories in the heatmaps
subcategory_order = [
    'Essential',
    'Deoxynucleotides',
    'Nucleotides',
    "Non-Essential",
    "Lipids",
    ""
]
_subcategory_rank = {name: rank for rank, name in enumerate(subcategory_order)}

# Sort by subcategory first and alphabetically within a subcategory
sorted_columns = sorted(
    existing_columns,
    key=lambda x: (_subcategory_rank[met_to_subcat[x]], x)
)
data_to_plot_WT = data_to_plot_WT[sorted_columns]
data_to_plot_MUT = data_to_plot_MUT[sorted_columns]

# First and last column position of every subcategory, for bracket plotting
subcat_to_indices = OrderedDict()
for i, col in enumerate(data_to_plot_WT.columns):
    subcat = met_to_subcat[col]
    # A new subcategory starts as [i, i]; afterwards only its end is moved
    span = subcat_to_indices.setdefault(subcat, [i, i])
    span[1] = i
        
# Function to add brackets to an axis
def add_brackets(ax, subcat_to_indices, met_to_subcat):
    '''Draw a labelled bracket over every named subcategory span on ``ax``.

    Parameters
    ----------
    ax : axis offering ``hlines``, ``vlines`` and ``text``.
    subcat_to_indices : mapping of subcategory label to (start, end) column
        positions, both inclusive. The empty label '' gets no bracket.
    met_to_subcat : accepted for interface compatibility; not used here.

    'Essential' and 'Non-Essential' are displayed with an ' AAs' suffix.
    '''
    display_names = {
        'Essential': 'Essential AAs',
        'Non-Essential': 'Non-Essential AAs',
    }
    # Negative y values: the brackets sit above the top edge of the heatmap
    line_y = -2
    label_y = -2.5
    for label, (start, end) in subcat_to_indices.items():
        if label == '':
            continue
        caption = display_names.get(label, label)
        # Bracket ends are pulled in slightly from the cell edges
        left = start + 0.1
        right = end + 1 - 0.1
        # Horizontal bar across the span
        ax.hlines(line_y, left, right, color='black', linewidth=1.5, clip_on=False)
        # Tips at both ends, from the bar to one unit further down
        ax.vlines([left, right], line_y, line_y + 1, color='black', linewidth=1.5, clip_on=False)
        # Caption anchored at the middle of the span
        ax.text((start + end + 1) / 2, label_y, caption,
                ha='center', va='bottom', fontsize=20, rotation=45, clip_on=False)

# Positions of the dashed vertical separators: the left edge of every labelled
# subcategory block except the first block. Derived from subcat_to_indices, so
# nothing is hardcoded and it adapts to whatever columns are present.
last_end = None  # not read anywhere in this cell
ordered_blocks = list(subcat_to_indices.items())
separator_positions = [
    start
    for idx, (label, (start, end)) in enumerate(ordered_blocks)
    if label != '' and idx > 0
]

# Plot settings
VMIN, VMAX = -2, 2  # keep the same dynamic range as the previous figure for comparability

plt.rcParams['font.size'] = 22

# One wide figure on a 1 x 25 grid: WT in columns 0-9, MUT in columns 11-20
fig = plt.figure(figsize=(35, 20))
ax1 = plt.subplot2grid((1, 25), (0, 0), colspan=10)   # WT
ax2 = plt.subplot2grid((1, 25), (0, 11), colspan=10)  # MUT

# Both panels share the colour scale; the colorbar is added separately below
heatmap_style = dict(
    cmap='seismic',
    center=0,
    vmin=VMIN,
    vmax=VMAX,
    cbar=False,
    square=True,
)

# Left heatmap (WT)
im_wt = sns.heatmap(data_to_plot_WT, ax=ax1, **heatmap_style)
ax1.set_xlabel('')
ax1.set_ylabel('Enzymes', fontsize=26)

# Right heatmap (MUT)
im_mut = sns.heatmap(data_to_plot_MUT, ax=ax2, **heatmap_style)
ax2.set_xlabel('')
ax2.set_ylabel('')

# Decorations that are identical on both panels
for ax in (ax1, ax2):
    ax.tick_params(axis='both', which='major', labelsize=18)
    add_brackets(ax, subcat_to_indices, met_to_subcat)
    # Draw dashed separators at subcategory boundaries
    for x in separator_positions:
        ax.axvline(x=x, color='black', linewidth=1, linestyle='--')

# Inset colorbar to the right of the MUT heatmap
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
cax = inset_axes(
    ax2,
    width="6%",
    height="30%",
    loc='center left',
    bbox_to_anchor=(1.05, 0.35, 1, 1),
    bbox_transform=ax2.transAxes,
    borderpad=0
)
# Use the right heatmap's mappable for the colorbar
cbar = plt.colorbar(im_mut.collections[0], cax=cax)
cbar.set_label('Effect on each biomass precursor', fontsize=20, labelpad=6)
# Only the two ends and the centre are labelled; the ends are marked as
# open-ended because the colour scale is limited to [VMIN, VMAX]
cbar.set_ticks([VMIN, 0, VMAX])
cbar.set_ticklabels([f'< {VMIN}', '0', f'> {VMAX}'])
cbar.ax.tick_params(labelsize=20, length=3, width=1)

plt.tight_layout()

plt.show()