In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from tqdm.auto import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
from skimpy.utils.tabdict import TabDict
import os
import configparser
from pytfa.io.json import load_json_model
from skimpy.io.yaml import load_yaml_model
from skimpy.analysis.oracle.minimum_fluxes import MinFLuxVariable

# Read the notebook configuration from config.ini.
# NOTE: the path is relative to the current working directory, so it only
# resolves when the notebook is run from its own folder.
config = configparser.ConfigParser()
config_path = '../src/config.ini'
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed. Check that list so a missing file is reported
# directly instead of surfacing as a KeyError on the first section lookup.
if not config.read(config_path):
    raise FileNotFoundError(
        f"Could not read configuration file: {os.path.abspath(config_path)}"
    )

# Resolve model and data locations from the [paths] section, relative to base_dir
base_dir = config['paths']['base_dir']
path_to_kmodel = os.path.abspath(os.path.join(base_dir, config['paths']['path_to_kmodel_WT']))
path_to_tmodel = os.path.abspath(os.path.join(base_dir, config['paths']['path_to_tmodel_WT']))
path_to_ccc_df_WT = os.path.abspath(os.path.join(base_dir, config['paths']['path_to_cc_df_WT']))
path_to_ccc_df_MUT = os.path.abspath(os.path.join(base_dir, config['paths']['path_to_cc_df_MUT']))

In [None]:
# Load pytfa model
print('Loading TFA model from:', path_to_tmodel)
tmodel = load_json_model(path_to_tmodel)

# Load the kinetic model and prepare it for compiling the Jacobian
kmodel = load_yaml_model(path_to_kmodel)
kmodel.prepare()
# Look up one reactant entry per independent-variable index and store the
# resulting list on the model; it is used further down to filter the biomass
# precursors.
# NOTE(review): presumably `reactants.iloc(i)` returns a (name, value) pair so
# that `[0]` is the reactant name -- confirm against skimpy's TabDict.
ind_reactants = [kmodel.reactants.iloc(i)[0] for i in kmodel.independent_variables_ix]
kmodel.independent_variables_names = ind_reactants

In [None]:
# Parameter family whose control coefficients are analysed
PARAMETER_FOR_MCA = 'vmax_forward'

# Key -> symbol for every kinetic parameter whose name starts with that prefix
parameter_list = TabDict([
    (key, parameter.symbol)
    for key, parameter in kmodel.parameters.items()
    if parameter.name.startswith(PARAMETER_FOR_MCA)
])

# Biomass reactants that are also independent variables of the kinetic model
independent_names = kmodel.independent_variables_names
met_ids = [
    reactant.id
    for reactant in tmodel.reactions.biomass.reactants
    if reactant.id in independent_names
]

In [None]:
# NOTE(review): this hard-coded template replaces the `path_to_ccc_df_WT` value
# resolved from config.ini in the first cell; the MUT template is still taken
# from the config. The `{}` placeholder is filled with a metabolite id via
# `.format(met_id)` when the CSV is read. Confirm the override is intentional.
path_to_ccc_df_WT = '../../results/MCA_workflow/WT_results/WT_CCC_{}.csv'

In [None]:
def get_average_cccs(met_id, physiology='MUT'):
    """Return the row-wise mean of one metabolite's control-coefficient table.

    The CSV path is obtained by formatting a module-level template with
    ``met_id``: ``path_to_ccc_df_MUT`` when ``physiology == 'MUT'``, otherwise
    ``path_to_ccc_df_WT``. The first CSV column is used as the index and every
    row is averaged across the remaining columns.

    Parameters
    ----------
    met_id : str
        Metabolite identifier substituted into the path template.
    physiology : str, optional
        ``'MUT'`` (default) selects the mutant template; any other value
        selects the wild-type template.

    Returns
    -------
    pandas.Series or None
        The averaged values, named ``met_id``. ``None`` is returned when
        anything fails (the error is printed), so callers that test
        ``result is not None`` skip the failed metabolite instead of handing
        a non-pandas object to ``pd.concat``.
    """
    try:
        path_template = path_to_ccc_df_MUT if physiology == 'MUT' else path_to_ccc_df_WT
        global_ccc = pd.read_csv(path_template.format(met_id), index_col=0)
        average_ccc = global_ccc.mean(axis=1)
        # Name it after the metabolite so it becomes the column label on concat
        average_ccc.name = met_id
        return average_ccc
    except Exception as e:
        # Best effort: report the failure and let the caller skip this metabolite
        print(f"Error processing metabolite {met_id}: {e}")
        return None

# Average the MUT control coefficients of every precursor in parallel
parameter_index = pd.Index(parameter_list.keys())
averages_by_met = {}
with ProcessPoolExecutor(max_workers=40) as executor:
    futures = {executor.submit(get_average_cccs, met_id): met_id for met_id in met_ids}
    for future in tqdm(as_completed(futures), total=len(futures), desc='Getting key enzymes...'):
        result = future.result()
        # A failed worker does not return a Series (it reports the error itself);
        # skip it rather than passing a non-pandas object to pd.concat.
        if isinstance(result, pd.Series):
            averages_by_met[futures[future]] = result

# Assemble the frame with a single concat (instead of growing it inside the
# loop) and in met_ids order, so the column order does not depend on which
# worker finished first.
bbb_df = pd.concat(
    [pd.DataFrame(index=parameter_index)]
    + [averages_by_met[met_id] for met_id in met_ids if met_id in averages_by_met],
    axis=1,
)

In [None]:
# Define a function for WT physiology
def get_average_cccs_wt(met_id):
    """Call ``get_average_cccs`` for ``met_id`` with ``physiology='WT'``.

    Defined as a named one-argument function so it can be submitted to the
    process pool directly.
    """
    wt_average = get_average_cccs(met_id, physiology='WT')
    return wt_average

# Average the WT control coefficients of every precursor in parallel
wt_averages_by_met = {}
with ProcessPoolExecutor(max_workers=40) as executor:
    futures = {executor.submit(get_average_cccs_wt, met_id): met_id for met_id in met_ids}
    for future in tqdm(as_completed(futures), total=len(futures), desc='Getting key enzymes for WT...'):
        result = future.result()
        # A failed worker does not return a Series (it reports the error itself);
        # skip it rather than passing a non-pandas object to pd.concat.
        if isinstance(result, pd.Series):
            wt_averages_by_met[futures[future]] = result

# Assemble the WT frame with a single concat (instead of growing it inside the
# loop) and in met_ids order, so the column order does not depend on which
# worker finished first.
bbb_df_wt = pd.concat(
    [pd.DataFrame(index=parameter_index)]
    + [wt_averages_by_met[met_id] for met_id in met_ids if met_id in wt_averages_by_met],
    axis=1,
)

In [None]:
# Enzymes shown in the heatmaps, in display order
sorted_names = ['TRIOK', 'HMR_7748', 'r0354', 'PGI', 'ICDHyrm', 'ICDHxm',
                'r0426', 'PPAP', 'r0301', '3DSPHR', 'SERPT', 'METAT', 'PSP_L',
                'r0178', 'r0179', 'TMDS', 'NTD1', 'DGK1', 'ADSS', 'IMPD',
                'r0474', 'URIK1', 'GMPS2', 'HMR_4343', 'CYTK1', 'UMPK2',
                'MI1PP', 'NADH2_u10mi']

# De-duplicated precursor ids in first-seen order. A list is required here:
# pandas >= 2.0 rejects a set as a .loc indexer, and a set has arbitrary order.
plot_met_ids = list(dict.fromkeys(met_ids))

# MUT: keep the precursor columns, strip the parameter prefix from the row
# labels so they are plain enzyme ids, then keep only the enzymes of interest.
data_to_plot_MUT = bbb_df.loc[:, plot_met_ids]
data_to_plot_MUT.index = [i.split('vmax_forward_',)[1] for i in data_to_plot_MUT.index]
data_to_plot_MUT = data_to_plot_MUT.loc[sorted_names]

# WT: same selection
data_to_plot_WT = bbb_df_wt.loc[:, plot_met_ids]
data_to_plot_WT.index = [i.split('vmax_forward_',)[1] for i in data_to_plot_WT.index]
data_to_plot_WT = data_to_plot_WT.loc[sorted_names]


# Biomass precursors grouped by (category, subcategory). The order of the
# groups and of the ids inside each group is the row order of group_df.
_PRECURSOR_GROUPS = [
    ("Amino Acid", "Essential", [
        "his_L_c", "ile_L_c", "leu_L_c", "lys_L_c", "met_L_c",
        "phe_L_c", "thr_L_c", "trp_L_c", "val_L_c",
    ]),
    ("Amino Acid", "Non-Essential", [
        "ala_L_c", "arg_L_c", "asn_L_c", "asp_L_c", "cys_L_c", "gln_L_c",
        "glu_L_c", "gly_c", "pro_L_c", "ser_L_c", "tyr_L_c",
    ]),
    ("Ribonucleotide", "Nucleotides", ["atp_c", "ctp_c", "gtp_c", "utp_c"]),
    ("Deoxyribonucleotide", "Deoxynucleotides", ["datp_n", "dctp_n", "dgtp_n", "dttp_n"]),
    ("Lipid", "Lipids", [
        "chsterol_c", "clpn_hs_c", "pail_hs_c", "pchol_hs_c",
        "pe_hs_c", "pglyc_hs_c", "ps_hs_c", "sphmyln_hs_c",
    ]),
    ("Carbohydrate", "", ["g6p_c"]),
]

# One row per metabolite: (Metabolite, Category, Subcategory)
group_df = pd.DataFrame(
    [
        (metabolite, category, subcategory)
        for category, subcategory, metabolites in _PRECURSOR_GROUPS
        for metabolite in metabolites
    ],
    columns=["Metabolite", "Category", "Subcategory"],
)

# Metabolite -> subcategory lookup
met_to_subcat = dict(zip(group_df["Metabolite"], group_df["Subcategory"]))

# Keep only the known precursors that are not amino acids
existing_columns = [
    column
    for column in data_to_plot_WT.columns
    if column in met_to_subcat
    and met_to_subcat[column] not in ["Essential", "Non-Essential"]
]

# Display order of the subcategories
subcategory_order = [
    'Essential',
    'Deoxynucleotides',
    'Nucleotides',
    "Non-Essential",
    "Lipids",
    ""  # For g6p_c and any uncategorized items
]


def _subcategory_sort_key(metabolite):
    """Sort by position of the subcategory in subcategory_order, then by id."""
    return subcategory_order.index(met_to_subcat[metabolite]), metabolite


# Apply the same column order to both frames
sorted_columns = list(existing_columns)
sorted_columns.sort(key=_subcategory_sort_key)
data_to_plot_WT = data_to_plot_WT.loc[:, sorted_columns]
data_to_plot_MUT = data_to_plot_MUT.loc[:, sorted_columns]

# Column span of every subcategory, used to draw the brackets
from collections import OrderedDict
from matplotlib.patches import ConnectionPatch

# subcategory -> [first column position, last column position], in column order
subcat_to_indices = OrderedDict()
for position, column in enumerate(data_to_plot_WT.columns):
    # The first sighting opens the span; every sighting moves its end to here
    span = subcat_to_indices.setdefault(met_to_subcat[column], [position, position])
    span[1] = position

# Function to add brackets to an axis
def add_brackets(ax, subcat_to_indices, met_to_subcat):
    """Draw one labelled bracket per subcategory above the heatmap on ``ax``.

    Parameters
    ----------
    ax : object
        Axes-like object; only its ``hlines``, ``vlines`` and ``text`` methods
        are called.
    subcat_to_indices : mapping
        Subcategory label -> ``(start, end)`` column positions (inclusive).
        The entry with the empty-string label is skipped.
    met_to_subcat : mapping
        Accepted for interface compatibility; not used here.
    """
    display_names = {
        'Essential': 'Essential AAs',
        'Non-Essential': 'Non-Essential AAs',
    }
    bracket_y, text_y = -2, -2.5  # y positions of the bracket and of its label
    line_style = dict(color='black', linewidth=1.5, clip_on=False)

    for subcat, (first, last) in subcat_to_indices.items():
        if subcat == '':
            continue
        # Inset both ends slightly so neighbouring brackets do not touch
        left = first + 0.1
        right = last + 1 - 0.1
        # Horizontal bar
        ax.hlines(bracket_y, left, right, **line_style)
        # Vertical tips at both ends
        ax.vlines([left, right], bracket_y, bracket_y + 1, **line_style)
        # Label centred over the span
        ax.text((first + last + 1) / 2, text_y, display_names.get(subcat, subcat),
                ha='center', va='bottom', fontsize=20, rotation=45, clip_on=False)

# Create figure with subplots - reserve space for colorbar
fig = plt.figure(figsize=(35, 20))
plt.rcParams.update({'font.size': 22})  # globally increase font size

# Two equally sized heatmap axes on a 1x25 grid (columns 0-9 and 11-20, with
# one empty column in between); the remaining columns leave room for the colorbar.
ax1 = plt.subplot2grid((1, 25), (0, 0), colspan=10)  # Left heatmap (WT)
ax2 = plt.subplot2grid((1, 25), (0, 11), colspan=10)  # Right heatmap (MUT)

# Left heatmap: WT
sns.heatmap(
    data_to_plot_WT,
    cmap='seismic',
    center=0,
    vmin=-2,  # Set minimum value for color mapping
    vmax=2,   # Set maximum value for color mapping
    cbar=False,  # No colorbar for left plot
    ax=ax1,
    square=True
)

ax1.set_xlabel('')  # Remove individual x-axis label
ax1.set_ylabel('Enzymes', fontsize=26)
ax1.tick_params(axis='both', which='major', labelsize=18)

# Add brackets to left heatmap
add_brackets(ax1, subcat_to_indices, met_to_subcat)

# Right heatmap: MUT, same colour scale as the WT panel
im = sns.heatmap(
    data_to_plot_MUT,
    cmap='seismic',
    center=0,
    vmin=-2,  # Set minimum value for color mapping
    vmax=2,   # Set maximum value for color mapping
    cbar=False,  # The shared colorbar is added separately below
    ax=ax2,
    square=True
)
ax2.set_xlabel('')  # Remove individual x-axis label
ax2.set_ylabel('')  # Remove y-axis label from second heatmap
ax2.tick_params(axis='both', which='major', labelsize=18)


# Add brackets to right heatmap
add_brackets(ax2, subcat_to_indices, met_to_subcat)

# Dashed separators at the right edge of every subcategory block. The positions
# are derived from the column layout (last column of the block + 1) instead of
# being hard-coded, so they stay aligned when precursors are added or missing.
group_ends = [last + 1 for _, last in subcat_to_indices.values()]

for x in group_ends:
    ax1.axvline(x=x, color='black', linewidth=1, linestyle='--')
    ax2.axvline(x=x, color='black', linewidth=1, linestyle='--')

from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Create a smaller, thicker inset colorbar to the right of ax2
cax = inset_axes(
    ax2,
    width="6%",              # Thicker
    height="30%",            # Shorter
    loc='center left',
    bbox_to_anchor=(1.05, 0.35, 1, 1),  # Just outside the right edge of ax2
    bbox_transform=ax2.transAxes,
    borderpad=0
)

# Add colorbar; the end ticks are labelled as open ranges because the colour
# scale is clipped at vmin/vmax
cbar = plt.colorbar(im.collections[0], cax=cax)
cbar.set_label('Effect on each biomass precursor', fontsize=20, labelpad=6)
cbar.set_ticks([-2, 0, 2])
cbar.set_ticklabels(['< -2', '0', '> 2'])
cbar.ax.tick_params(labelsize=20, length=3, width=1)


plt.tight_layout()

# Font settings for exported figures (the export itself is currently disabled)
plt.rcParams['pdf.fonttype'] = 42  # For PDF
plt.rcParams['svg.fonttype'] = 'none'  # For SVG
# plt.savefig('../../results/MCA_workflow/MUT_CCC_bio_precursors_heatmap_double.pdf',
#             dpi=300,
#             transparent=True,
#             bbox_inches='tight',
#             pad_inches=0.1)
plt.show()

In [None]:
# Remove the amino-acid precursor columns ("Essential" and "Non-Essential"
# subcategories) from the MUT averages. errors='ignore' keeps the cell
# re-runnable and tolerates amino acids that have no column in bbb_df; without
# it, drop() raises KeyError for any label that is absent.
amino_acid_mets = group_df.loc[
    group_df['Subcategory'].isin(['Essential', 'Non-Essential']), 'Metabolite'
].tolist()
bbb_df = bbb_df.drop(columns=amino_acid_mets, errors='ignore')

In [None]:
# Remove the amino-acid precursor columns ("Essential" and "Non-Essential"
# subcategories) from the WT averages. errors='ignore' keeps the cell
# re-runnable and tolerates amino acids that have no column in bbb_df_wt;
# without it, drop() raises KeyError for any label that is absent.
amino_acid_mets = group_df.loc[
    group_df['Subcategory'].isin(['Essential', 'Non-Essential']), 'Metabolite'
].tolist()
bbb_df_wt = bbb_df_wt.drop(columns=amino_acid_mets, errors='ignore')

In [None]:
# Write the MUT and WT averages (amino-acid columns already dropped above) to
# CSV files in the current working directory.
# NOTE(review): the "remove_me" file names suggest scratch output -- confirm
# whether these files are still needed before keeping this cell.
bbb_df.to_csv('./ccc_remove_me.csv')
bbb_df_wt.to_csv('./ccc_wt_remove_me.csv')

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from collections import OrderedDict

# Enzymes shown in the heatmaps, in display order
sorted_names = ['TRIOK', 'HMR_7748', 'r0354', 'PGI', 'ICDHyrm', 'ICDHxm',
                'r0426', 'PPAP', 'r0301', '3DSPHR', 'SERPT', 'METAT', 'PSP_L',
                'r0178', 'r0179', 'TMDS', 'NTD1', 'DGK1', 'ADSS', 'IMPD',
                'r0474', 'URIK1', 'GMPS2', 'HMR_4343', 'CYTK1', 'UMPK2',
                'MI1PP', 'NADH2_u10mi']

# Precursor columns still present in both frames, de-duplicated in first-seen
# order. The amino-acid columns were dropped from bbb_df / bbb_df_wt in the
# preceding cells, so selecting every id in met_ids would raise KeyError.
# A list is also required because pandas >= 2.0 rejects a set as a .loc indexer.
available_met_ids = [
    met_id for met_id in dict.fromkeys(met_ids)
    if met_id in bbb_df.columns and met_id in bbb_df_wt.columns
]

# MUT: strip the parameter prefix from the row labels, keep the enzymes of
# interest, and relabel NADH2_u10mi for display.
data_to_plot_MUT = bbb_df.loc[:, available_met_ids]
data_to_plot_MUT.index = [i.split('vmax_forward_',)[1] for i in data_to_plot_MUT.index]
data_to_plot_MUT = data_to_plot_MUT.loc[sorted_names]
data_to_plot_MUT = data_to_plot_MUT.rename(index={"NADH2_u10mi": "MComplex1"})

# WT: same selection and relabelling
data_to_plot_WT = bbb_df_wt.loc[:, available_met_ids]
data_to_plot_WT.index = [i.split('vmax_forward_',)[1] for i in data_to_plot_WT.index]
data_to_plot_WT = data_to_plot_WT.loc[sorted_names]
data_to_plot_WT = data_to_plot_WT.rename(index={"NADH2_u10mi": "MComplex1"})

# Table of biomass precursors: one row per metabolite id with its category and
# subcategory. The subcategory drives the column ordering and the bracket labels
# below; g6p_c has an empty subcategory.
# NOTE: this rebuilds the same table that the earlier plotting cell defines, so
# the cell can be run on its own.
group_df = pd.DataFrame([
    ("his_L_c", "Amino Acid", "Essential"),
    ("ile_L_c", "Amino Acid", "Essential"),
    ("leu_L_c", "Amino Acid", "Essential"),
    ("lys_L_c", "Amino Acid", "Essential"),
    ("met_L_c", "Amino Acid", "Essential"),
    ("phe_L_c", "Amino Acid", "Essential"),
    ("thr_L_c", "Amino Acid", "Essential"),
    ("trp_L_c", "Amino Acid", "Essential"),
    ("val_L_c", "Amino Acid", "Essential"),
    ("ala_L_c", "Amino Acid", "Non-Essential"),
    ("arg_L_c", "Amino Acid", "Non-Essential"),
    ("asn_L_c", "Amino Acid", "Non-Essential"),
    ("asp_L_c", "Amino Acid", "Non-Essential"),
    ("cys_L_c", "Amino Acid", "Non-Essential"),
    ("gln_L_c", "Amino Acid", "Non-Essential"),
    ("glu_L_c", "Amino Acid", "Non-Essential"),
    ("gly_c", "Amino Acid", "Non-Essential"),
    ("pro_L_c", "Amino Acid", "Non-Essential"),
    ("ser_L_c", "Amino Acid", "Non-Essential"),
    ("tyr_L_c", "Amino Acid", "Non-Essential"),
    ("atp_c", "Ribonucleotide", "Nucleotides"),
    ("ctp_c", "Ribonucleotide", "Nucleotides"),
    ("gtp_c", "Ribonucleotide", "Nucleotides"),
    ("utp_c", "Ribonucleotide", "Nucleotides"),
    ("datp_n", "Deoxyribonucleotide", "Deoxynucleotides"),
    ("dctp_n", "Deoxyribonucleotide", "Deoxynucleotides"),
    ("dgtp_n", "Deoxyribonucleotide", "Deoxynucleotides"),
    ("dttp_n", "Deoxyribonucleotide", "Deoxynucleotides"),
    ("chsterol_c", "Lipid", "Lipids"),
    ("clpn_hs_c", "Lipid", "Lipids"),
    ("pail_hs_c", "Lipid", "Lipids"),
    ("pchol_hs_c", "Lipid", "Lipids"),
    ("pe_hs_c", "Lipid", "Lipids"),
    ("pglyc_hs_c", "Lipid", "Lipids"),
    ("ps_hs_c", "Lipid", "Lipids"),
    ("sphmyln_hs_c", "Lipid", "Lipids"),
    ("g6p_c", "Carbohydrate", "")
], columns=["Metabolite", "Category", "Subcategory"])

# Metabolite -> subcategory lookup
met_to_subcat = dict(zip(group_df["Metabolite"], group_df["Subcategory"]))

# Keep only the known precursors that are not amino acids
existing_columns = [
    name
    for name in data_to_plot_WT.columns
    if name in met_to_subcat
    and met_to_subcat[name] not in ["Essential", "Non-Essential"]
]

# Display order of the subcategories (the empty string is g6p_c's subcategory)
subcategory_order = [
    'Essential',
    'Deoxynucleotides',
    'Nucleotides',
    "Non-Essential",
    "Lipids",
    ""
]

# Sort on (subcategory rank, metabolite id) pairs, then strip the rank again
ranked_columns = sorted(
    (subcategory_order.index(met_to_subcat[name]), name)
    for name in existing_columns
)
sorted_columns = [name for _, name in ranked_columns]

# Apply the same column order to both frames
data_to_plot_WT = data_to_plot_WT.loc[:, sorted_columns]
data_to_plot_MUT = data_to_plot_MUT.loc[:, sorted_columns]

# --- Brackets data for Y-axis (since we'll plot transposed) ---
# The rows of the transposed frame are the columns of the original one, so the
# row span of every subcategory is read directly from the column labels.
subcat_to_indices_y = OrderedDict()
for row_position, metabolite in enumerate(data_to_plot_WT.columns):
    # The first sighting opens the span; every sighting moves its end to here
    row_span = subcat_to_indices_y.setdefault(
        met_to_subcat[metabolite], [row_position, row_position]
    )
    row_span[1] = row_position

def add_brackets_y(ax, subcat_to_indices, met_to_subcat):
    """Draw one labelled vertical bracket per subcategory along the y-axis of ``ax``.

    Parameters
    ----------
    ax : object
        Axes-like object; only its ``vlines``, ``hlines`` and ``text`` methods
        are called.
    subcat_to_indices : mapping
        Subcategory label -> ``(start, end)`` row positions (inclusive).
        The entry with the empty-string label is skipped.
    met_to_subcat : mapping
        Accepted for interface compatibility; not used here.
    """
    display_names = {
        'Essential': 'Essential AAs',
        'Non-Essential': 'Non-Essential AAs',
    }
    bracket_x, text_x = -3.7, -4  # x positions of the bracket and of its label
    line_style = dict(color='black', linewidth=1.5, clip_on=False)

    for subcat, (first, last) in subcat_to_indices.items():
        if subcat == '':
            continue
        # Inset both ends slightly so neighbouring brackets do not touch
        top = first + 0.1
        bottom = last + 1 - 0.1
        # Vertical bar
        ax.vlines(bracket_x, top, bottom, **line_style)
        # Horizontal tips at both ends
        ax.hlines([top, bottom], bracket_x, bracket_x + 0.5, **line_style)
        # Rotated label centred on the span
        ax.text(text_x, (first + last + 1) / 2, display_names.get(subcat, subcat),
                ha='center', va='center', fontsize=15, rotation=90, clip_on=False)

# --- Figure: stack vertically, transpose for plotting ---
# WT goes on top and MUT below; both frames are transposed so precursors run
# along the y-axis and enzymes along the x-axis.
fig = plt.figure(figsize=(35, 20))
plt.rcParams.update({'font.size': 22})

# Two axes on a 25x1 grid: rows 0-11 for WT and rows 13-24 for MUT (row 12 is left empty)
ax1 = plt.subplot2grid((25, 1), (0, 0), rowspan=12)   # WT top
ax2 = plt.subplot2grid((25, 1), (13, 0), rowspan=12)  # MUT bottom

# Drop the last two characters of every column name for display.
# NOTE(review): presumably this strips the compartment suffix (e.g. "_c", "_n")
# -- confirm that every precursor id ends with one.
data_to_plot_WT.columns = [col[:-2] for col in data_to_plot_WT.columns]

# Same trimming for the MUT frame so both heatmaps share the same labels
data_to_plot_MUT.columns = [col[:-2] for col in data_to_plot_MUT.columns]


# WT heatmap (transposed); colour scale fixed to [-2, 2] and centred on 0
im_wt = sns.heatmap(
    data_to_plot_WT.T,
    cmap='seismic',
    center=0,
    vmin=-2,
    vmax=2,
    cbar=False,
    ax=ax1,
    square=True
)
ax1.set_ylabel('Precursors', fontsize=26, labelpad=50)
ax1.tick_params(axis='both', which='major', labelsize=18)
# Hide x ticks and labels on the top heatmap; the bottom heatmap carries them
ax1.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
add_brackets_y(ax1, subcat_to_indices_y, met_to_subcat)

# MUT heatmap (transposed), same colour scale as the WT panel
im_mut = sns.heatmap(
    data_to_plot_MUT.T,
    cmap='seismic',
    center=0,
    vmin=-2,
    vmax=2,
    cbar=False,
    ax=ax2,
    square=True
)
ax2.set_xlabel('Enzymes', fontsize=26)
ax2.set_ylabel('Precursors', fontsize=26, labelpad=50)
ax2.tick_params(axis='both', which='major', labelsize=18)
add_brackets_y(ax2, subcat_to_indices_y, met_to_subcat)

# No dashed group separators are drawn in this layout: they would have to be
# horizontal lines after the transpose and are omitted to keep the figure minimal.

# Colorbar to the right of the bottom heatmap, slimmer
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
cax = inset_axes(
    ax2,
    width="3%",
    height="50%",
    loc='center left',
    bbox_to_anchor=(1.05, 0.35, 1, 1),
    bbox_transform=ax2.transAxes,
    borderpad=0
)
# The end ticks are labelled as open ranges because the colour scale is clipped
# at vmin/vmax
cbar = plt.colorbar(im_mut.collections[0], cax=cax)
cbar.set_label('Effect on each biomass precursor', fontsize=20, labelpad=6)
cbar.set_ticks([-2, 0, 2])
cbar.set_ticklabels(['< -2', '0', '> 2'])
cbar.ax.tick_params(labelsize=20, length=3, width=1)

plt.tight_layout()
plt.subplots_adjust(hspace=0.05)  # reduce vertical gap
# NOTE(review): this second tight_layout() call likely recomputes the subplot
# spacing and may undo the hspace set on the previous line -- confirm the
# intended order of these calls.
plt.tight_layout()

# Save figure; the font settings keep text editable in the exported PDF/SVG
plt.rcParams['pdf.fonttype'] = 42  # For PDF
plt.rcParams['svg.fonttype'] = 'none'  # For SVG
plt.savefig('../../results/MCA_workflow/MUT_CCC_bio_precursors_heatmap_double_flipped.pdf',
            dpi=300,
            transparent=True,
            bbox_inches='tight',
            pad_inches=0.1)
plt.show()
