In [None]:
import numpy as np
import pandas as pd
from skimpy.analysis.oracle import *
import matplotlib.pyplot as plt
import seaborn as sns
import sys
sys.path.append("../../../../NRAplus/NRAplus") # Adds higher directory to python modules path.
sys.path.append('../')
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import pandas as pd
from collections import OrderedDict
import plotly.graph_objects as go
import kaleido
import sys
import os
sys.path.append(os.path.abspath('..'))


from utils.nra_save_custom_json import load_json_nra_model
from pytfa.io.json import load_json_model

import configparser
import os

# Read the project configuration.
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed, so an empty result is turned into an explicit
# error here instead of an opaque KeyError on 'paths' further down.
config_file = os.path.abspath('../src/config.ini')  # relative to the current working directory
config = configparser.ConfigParser()
if not config.read(config_file):
    raise FileNotFoundError(f"Could not read configuration file: {config_file}")

# Every entry of the [paths] section other than base_dir is joined onto base_dir.
paths_section = config['paths']
base_dir = paths_section['base_dir']
path_to_nra_model = os.path.join(base_dir, paths_section['path_to_nra_model'])
path_to_solutions = os.path.join(base_dir, paths_section['path_to_solutions'])
path_to_essential_enzymes = os.path.join(base_dir, paths_section['path_to_essential_enzymes'])
path_to_tflink_database = os.path.join(base_dir, paths_section['path_to_tflink_database'])
path_to_recon_model = os.path.join(base_dir, paths_section['path_to_recon_model'])
path_to_gene_to_uniprot_mapping = os.path.join(base_dir, paths_section['path_to_gene_to_uniprot_mapping'])
path_to_essential_enzymes_to_TFs_mapping = os.path.join(base_dir, paths_section['path_to_essential_enzymes_to_TFs_mapping'])

In [None]:
# Load the NRA model from the JSON file configured as path_to_nra_model.
nra_model = load_json_nra_model(path_to_nra_model)

In [None]:
# Read the essential enzyme IDs, one per line.
# Blank lines (e.g. a trailing newline at the end of the file) are skipped so
# they do not turn into empty IDs, which would fail later when they are looked
# up in the solution table and in the model.
with open(path_to_essential_enzymes, 'r') as f:
    essential_enzymes_list = [line.strip() for line in f if line.strip()]

# Load the reference solution; '250' is substituted into the configured path
# template and the first CSV column becomes the index.
sol = pd.read_csv(path_to_solutions.format('250'), index_col=0)

In [None]:
# Build one record per essential enzyme: its ID, the subsystem of the matching
# model reaction, and whether the reference solution marks it as up- or
# down-regulated.
# The records are collected in a list and converted to a DataFrame once:
# DataFrame.append was removed in pandas 2.0, and before that it copied the
# whole frame on every call.
enzyme_records = []
for enzyme_id in essential_enzymes_list:
    # Rows 'EUU_<id>' and 'EDU_<id>' of the solution; a value of 1 selects the direction.
    var_up = sol.loc['EUU_' + enzyme_id].values[0]
    var_down = sol.loc['EDU_' + enzyme_id].values[0]
    if np.isclose(var_up, 1.0):
        regulation = 'up'
    elif np.isclose(var_down, 1.0):
        regulation = 'down'
    else:
        raise ValueError(f"Enzyme {enzyme_id} is not up or down regulated")

    enzyme_records.append({
        'Enzyme': enzyme_id,
        'Subsystem': nra_model.reactions.get_by_id(enzyme_id).subsystem,
        'Regulation': regulation,
    })

# Explicit columns keep the frame well-formed even when the list is empty;
# the enzyme name becomes the index.
essential_enzymes = pd.DataFrame(
    enzyme_records, columns=['Enzyme', 'Subsystem', 'Regulation']
).set_index('Enzyme')

# Print how many enzymes are up or down regulated
n_down = int((essential_enzymes['Regulation'] == 'down').sum())
n_up = int((essential_enzymes['Regulation'] == 'up').sum())
print(f"Number of down regulated enzymes: {n_down}")
print(f"Number of up regulated enzymes: {n_up}")

In [None]:
# Set the style for high-impact journals
plt.style.use('default')
sns.set_palette("Set2")

# Work on a copy so the relabelling below does not modify essential_enzymes
df = essential_enzymes.copy()

# Rename the long subsystem
# NOTE(review): the short label looks like a typo of 'Gly/Ser/Ala/Thr' - confirm
# before changing it; the same string is also used in `totals` below.
df['Subsystem'] = df['Subsystem'].replace(
    'Glycine, serine, alanine, and threonine metabolism',
    'Gly/Se/rAla/Theo metabolism'
)

# Lump subsystems with fewer than 8 enzymes into "Other"
subsystem_counts = df['Subsystem'].value_counts()
rare_subsystems = subsystem_counts[subsystem_counts < 8].index
df['Subsystem'] = df['Subsystem'].replace(rare_subsystems, 'Other')

# Count number of up- and down-regulated enzymes per subsystem
count_df = df.groupby(['Subsystem', 'Regulation']).size().reset_index(name='Count')

# Alphabetical subsystem order with 'Other' last.
# This list is not passed to the plot; the bar order comes from `totals`.
subsystem_order = [s for s in count_df['Subsystem'].unique() if s != 'Other']
subsystem_order.sort()
subsystem_order.append('Other')

colors = ['#2E86AB', '#A23B72']  # Blue and burgundy

# Hue levels are fixed explicitly. Without this the order follows the row order
# of count_df, so the colours and the legend entries could swap depending on
# the data.
hue_levels = ['down', 'up']
legend_labels = {'down': 'Downregulated', 'up': 'Upregulated'}

# Hand-written display order of the pathways (top to bottom)
totals = [
    "Gly/Se/rAla/Theo metabolism",
    "Arginine and proline metabolism",
    "Glutamate metabolism",
    "Glycolysis/gluconeogenesis",
    "Pentose phosphate pathway",
    "Citric acid cycle",
    "Urea cycle",
    "Pyrimidine synthesis",
    "NAD metabolism",
    "Other"
]


fig, ax = plt.subplots(figsize=(4, 6))  # taller than it is wide
sns.barplot(
    data=count_df,
    y='Subsystem',
    x='Count',
    hue='Regulation',
    hue_order=hue_levels,
    order=totals,
    palette=colors,
    ax=ax,
    edgecolor='black',
    linewidth=1
)

# swap labels
ax.set_ylabel('Metabolic Pathway', fontsize=12)
ax.set_xlabel('Number of Essential Enzymes', fontsize=12)

# keep pathway names horizontal
plt.yticks(rotation=0)
ax.tick_params(axis='y', which='both', length=0)  # ensure no little stubs


# Re-create the legend centred on the upper-right corner of the axes
legend = ax.legend(
    loc='center',
    bbox_to_anchor=(1.0, 1.0),
    frameon=False
)

# Rename legend entries by their current text rather than by position, so a
# label can never end up on the wrong colour.
for legend_text in legend.get_texts():
    legend_text.set_text(legend_labels.get(legend_text.get_text(), legend_text.get_text()))

# remove top and right spines
sns.despine(ax=ax, top=True, right=True)

# give a bit of extra right margin so the legend isn't cut off
plt.tight_layout(rect=[0, 0, 0.85, 1])

# plt.savefig(
#     '../../results/physiology_comparison/essential_enzymes_subystems_horizontal_legend.pdf',
#     dpi=300,
#     transparent=True,
#     bbox_inches='tight',
#     pad_inches=0.1
# )
plt.show()

# Figure 5E

In [None]:
# Load the TFLink database (tab-separated)
tf_link = pd.read_csv(path_to_tflink_database, sep='\t')

# cobra's JSON loader is imported under an alias so that it does not rebind the
# name `load_json_model`, which the first cell imports from pytfa.io.json.
from cobra.io.json import load_json_model as load_cobra_json_model  # FBA type import
recon = load_cobra_json_model(path_to_recon_model)

# Load gene to UNIPROT mapping (tab-separated)
gene_uniprot_mapping = pd.read_csv(path_to_gene_to_uniprot_mapping, sep='\t')

# The first column should be a string.
# The column is replaced by label: assigning through `.iloc[:, 0] = ...` is an
# in-place set on the existing column and is not a reliable way to change its dtype.
first_column = gene_uniprot_mapping.columns[0]
gene_uniprot_mapping[first_column] = gene_uniprot_mapping[first_column].astype(str)

In [None]:
# Trace BRCA1 -> transcription factor -> metabolic gene through TFLink.
#
# Each stage is a vectorised filter instead of a per-row loop that re-scanned
# TFLink, printed every matching frame and crashed in pd.concat when a stage
# found nothing. An empty stage now simply yields an empty table.

# All TFLink interactions in which BRCA1 is the transcription factor
brca1_total = tf_link[tf_link['Name.TF'] == 'BRCA1']

# Step 1 - TFLink interactions whose target is a gene of the metabolic model.
# Only the part of 'gene_number' before the first '.' is compared with the
# NCBI gene ID of the TFLink target.
model_ncbi_ids = {
    str(gene_number).split('.')[0]
    for gene_number in gene_uniprot_mapping['gene_number'].dropna()
}
genes_identified = (
    tf_link[tf_link['NCBI.GeneID.Target'].isin(model_ncbi_ids)]
    .drop_duplicates(ignore_index=True)
)

# Step 2 - now we go from those TFs as targets to BRCA1 as a TF: keep the BRCA1
# interactions whose target is one of the TFs found in step 1.
# The previous version compared against 'Name.TF' of brca1_total, which only
# contains 'BRCA1', so it merely tested whether BRCA1 itself showed up in step 1.
# The final table below holds the same rows as before whenever that test passed.
regulating_tfs = set(genes_identified['Name.TF'].dropna())
tfs_identified = (
    brca1_total[brca1_total['Name.Target'].isin(regulating_tfs)]
    .drop_duplicates(ignore_index=True)
)

# Step 3 - go back to the genes: step-1 interactions whose TF is a BRCA1 target
brca1_controlled_tfs = set(tfs_identified['Name.Target'].dropna())
brca1_specific_genes = (
    genes_identified[genes_identified['Name.TF'].isin(brca1_controlled_tfs)]
    .drop_duplicates(ignore_index=True)
)

# One-line summaries instead of dumping every matching frame
print(f"TFLink interactions targeting model genes: {len(genes_identified)}")
print(f"BRCA1 interactions targeting TFs of model genes: {len(tfs_identified)}")
print(f"Model-gene interactions downstream of BRCA1: {len(brca1_specific_genes)}")

In [None]:
# NCBI gene IDs of the BRCA1-linked targets, looked up once for both models
brca1_target_ids = brca1_specific_genes.loc[:, 'NCBI.GeneID.Target'].values


def _collect_reaction_ids(model_genes, gene_id_prefix):
    """Return the IDs of reactions attached to genes whose ID prefix is a BRCA1 target.

    `gene_id_prefix` turns a model gene ID into the value compared with the
    target IDs. Matching gene prefixes and newly found reaction IDs are printed
    as they are encountered; each reaction ID is collected once, in discovery order.
    """
    reaction_ids = []
    for gene in model_genes:
        prefix = gene_id_prefix(gene.id)
        if prefix not in brca1_target_ids:
            continue
        print(prefix)

        for reaction in gene.reactions:
            if reaction.id in reaction_ids:
                continue
            reaction_ids.append(reaction.id)
            print(reaction.id)
    return reaction_ids


# Recon model: the gene ID is cut at the first '_'
enz = _collect_reaction_ids(recon.genes, lambda gene_id: gene_id.split('_')[0])
print(f'Total number of reactions associated with BRCA1: {len(enz)}')

# Do the same for the small model: the gene ID is cut at the first '.'
enz_small = _collect_reaction_ids(nra_model.genes, lambda gene_id: gene_id.split('.', 1)[0])
print(f'Total number of reactions associated with BRCA1 in the small model: {len(enz_small)}')

In [None]:
# For each essential enzyme find the genes associated with it
# The three annotation columns start out empty (all NaN) with object dtype.
for annotation_column in ('genes', 'uniprot_genes', 'hugo_genes'):
    essential_enzymes[annotation_column] = pd.Series(dtype=object)


def _mapped_values(gene_ids, mapping_column):
    """Collect the distinct, non-missing `mapping_column` entries for `gene_ids`.

    A gene contributes only if it occurs in gene_uniprot_mapping['gene_number'];
    when it occurs several times the first matching row is used. The returned
    list follows the order of `gene_ids` and may be empty.
    """
    collected = []
    for gene_id in gene_ids:
        matching = gene_uniprot_mapping.loc[
            gene_uniprot_mapping['gene_number'] == gene_id, mapping_column
        ]
        if matching.empty:
            continue
        first_hit = matching.values[0]
        if pd.isna(first_hit) or first_hit in collected:
            continue
        collected.append(first_hit)
    return collected


# Iterate through each essential enzyme (reaction ID)
for enzyme_id in essential_enzymes.index:
    gene_ids = [gene.id for gene in nra_model.reactions.get_by_id(enzyme_id).genes]

    if not gene_ids:
        # Keep as NaN if no genes are associated
        essential_enzymes.at[enzyme_id, 'genes'] = np.nan
        continue

    # Join multiple gene IDs with ';'
    essential_enzymes.at[enzyme_id, 'genes'] = ';'.join(gene_ids)

    # UniProt IDs first, then HUGO symbols; multiple values are joined with ';'
    # and the cell stays NaN when nothing was found.
    for mapping_column, annotation_column in (
        ('uniprot_gname', 'uniprot_genes'),
        ('symbol', 'hugo_genes'),
    ):
        mapped = _mapped_values(gene_ids, mapping_column)
        essential_enzymes.at[enzyme_id, annotation_column] = ';'.join(mapped) if mapped else np.nan

In [None]:
# Check if TFLink returns more TF hits for the essential enzymes
essential_enzymes['TF_link'] = pd.Series(dtype=object)  # empty object-dtype column (all NaN)

tflink_target_ids = tf_link['UniprotID.Target']
for enzyme_id in essential_enzymes.index:
    joined_uniprot_ids = essential_enzymes.loc[enzyme_id, 'uniprot_genes']
    if pd.isna(joined_uniprot_ids):
        # No UniProt ID to look up; TF_link keeps its initial NaN
        continue

    # Gather the TF names of every UniProt ID of this enzyme. Names are unique
    # per UniProt ID but are not de-duplicated across IDs.
    collected_tfs = []
    for single_id in joined_uniprot_ids.split(';'):
        matching_tfs = tf_link.loc[tflink_target_ids == single_id, 'Name.TF']
        if matching_tfs.empty:
            continue
        collected_tfs.extend(matching_tfs.unique())

    essential_enzymes.at[enzyme_id, 'TF_link'] = ';'.join(collected_tfs) if collected_tfs else np.nan

In [None]:
# Genes over all essential enzymes (one entry per enzyme-gene pair)
gene_list = [
    gene
    for joined_genes in essential_enzymes['genes'].dropna()
    for gene in joined_genes.split(';')
]
print(f"{essential_enzymes.genes.notna().sum()}/{len(essential_enzymes)} essential enzymes have genes associated with them")
print(f"Number of genes: {len(gene_list)}")
print(f"Number of unique genes: {len(set(gene_list))}\n")

# TFs over all essential enzymes that have a TF_link entry
tfs_list_tf_link = [
    tf_name
    for joined_tfs in essential_enzymes['TF_link'].dropna()
    for tf_name in joined_tfs.split(';')
]
print(f"{essential_enzymes.TF_link.notna().sum()}/{len(essential_enzymes)} essential enzymes have TF_link TFs associated with them")
print(f"Number of TFs: {len(tfs_list_tf_link)}")
print(f"Number of unique TFs: {len(set(tfs_list_tf_link))}\n")

# Genes of the enzymes that have at least one TF_link TF
genes_tfs_tf_link = [
    gene
    for joined_genes in essential_enzymes[essential_enzymes.TF_link.notna()].genes
    for gene in joined_genes.split(';')
]
print(f"Number of genes associated with TF_link TFs: {len(genes_tfs_tf_link)}")
print(f"Number of unique genes associated with TF_link TFs: {len(set(genes_tfs_tf_link))}\n")
print('------------------------------------------------------')

print(f'In total {essential_enzymes.TF_link.notna().sum()}/{len(essential_enzymes)} essential enzymes are connected to {len(set(genes_tfs_tf_link))} unique genes and are regulated by {len(set(tfs_list_tf_link))} unique TFs (TF_link)')

In [None]:
# Write the annotated essential-enzyme table (index included) to the configured CSV path.
essential_enzymes.to_csv(path_to_essential_enzymes_to_TFs_mapping)

In [None]:
# Load the CSV file
# NOTE(review): the previous cell writes its table to
# path_to_essential_enzymes_to_TFs_mapping while this cell reads a hard-coded
# relative path - confirm that both refer to the same file.
df = pd.read_csv('../../results/physiology_comparison/essential_enzymes_with_genes_and_TFs.csv')

# Drop rows missing TF or gene. The explicit copy makes df_filtered an
# independent frame, so the column assignments below do not write into a slice
# of df (SettingWithCopyWarning).
df_filtered = df.dropna(subset=['TF_link', 'hugo_genes']).copy()

# Split semicolon-separated values into lists
df_filtered['hugo_genes'] = df_filtered['hugo_genes'].astype(str).str.split(';')
df_filtered['TF_link'] = df_filtered['TF_link'].astype(str).str.split(';')

# Explode all combinations: one row per (enzyme, gene, TF)
df_exploded = df_filtered.explode('hugo_genes').explode('TF_link')
df_exploded['hugo_genes'] = df_exploded['hugo_genes'].str.strip()
df_exploded['TF_link'] = df_exploded['TF_link'].str.strip()

# Give every gene of NADH2_u10mi one shared label. This is done before the TF
# filter so the relabelled rows are the ones carried forward.
# NOTE(review): the rows are relabelled but not de-duplicated, and the flow
# tables built later in this cell count rows per (TF_link, Enzyme) and
# (Enzyme, Subsystem); as written the relabel does not change those counts -
# confirm whether the genes were meant to be collapsed.
df_exploded.loc[df_exploded['Enzyme'] == 'NADH2_u10mi', 'hugo_genes'] = 'NADH2_u10mi genes'

# TFs shown in the diagram. The list is fixed by hand (nine entries despite the
# variable name); the earlier value_counts().nlargest(5) result was overwritten
# immediately and has been removed.
top_5_tfs_all_connections = ['SP1', 'MYC', 'CREB1', 'NFKB1', 'AR', 'HIF1A', 'STAT1', 'EGR1', 'TP53']  # Use the same top TFs as before

# Keep only the selected TFs and remove rows without Subsystem info
df_top = (
    df_exploded[df_exploded['TF_link'].isin(top_5_tfs_all_connections)]
    .dropna(subset=['Subsystem'])
    .copy()
)


# --- NODE ORDERING BY FLOW ---

def _link_counts(source_column, target_column):
    """Count the rows of df_top per (source, target) pair; the count is stored in 'value'."""
    return df_top.groupby([source_column, target_column]).size().reset_index(name='value')


def _flow_per_node(flows, node_column):
    """Sum 'value' per node of `node_column`, largest total first."""
    return flows.groupby(node_column)['value'].sum().sort_values(ascending=False)


# 1. Link tables. These are the links that get drawn, so their 'value' is what
#    the nodes are ranked by.
tf_enzyme_flows = _link_counts('TF_link', 'Enzyme')
enzyme_subsystem_flows = _link_counts('Enzyme', 'Subsystem')

# 2. Rank the nodes of each layer by total flow, biggest first.

# TFs: by their outgoing flow to enzymes
tf_outgoing_flow = _flow_per_node(tf_enzyme_flows, 'TF_link')
sorted_tfs = tf_outgoing_flow.index.tolist()

# Enzymes: by their incoming flow from TFs, restricted to enzymes present in df_top
enzyme_inflow_sum = _flow_per_node(tf_enzyme_flows, 'Enzyme')
relevant_enzymes_in_df_top = df_top['Enzyme'].unique().tolist()
ranked_enzymes = enzyme_inflow_sum.index
sorted_enzymes = ranked_enzymes[ranked_enzymes.isin(relevant_enzymes_in_df_top)].tolist()

# Subsystems: by their incoming flow from enzymes, restricted to subsystems present in df_top
subsystem_inflow_sum = _flow_per_node(enzyme_subsystem_flows, 'Subsystem')
relevant_subsystems_in_df_top = df_top['Subsystem'].unique().tolist()
ranked_subsystems = subsystem_inflow_sum.index
sorted_subsystems = ranked_subsystems[ranked_subsystems.isin(relevant_subsystems_in_df_top)].tolist()

# 3. All nodes in layer order (TFs, enzymes, subsystems); the position of a
#    name in this list is the node index used by the Sankey links.
nodes = [*sorted_tfs, *sorted_enzymes, *sorted_subsystems]
node_indices = dict(zip(nodes, range(len(nodes))))

# --- END NODE ORDERING ---

# Sankey links as three parallel lists of node indices and flow values
source, target, value = [], [], []

# TF -> Enzyme links first, then Enzyme -> Subsystem links. A link is added
# only when both of its end points exist in `node_indices`.
for flow_table, from_column, to_column in (
    (tf_enzyme_flows, 'TF_link', 'Enzyme'),
    (enzyme_subsystem_flows, 'Enzyme', 'Subsystem'),
):
    for from_name, to_name, flow in zip(
        flow_table[from_column], flow_table[to_column], flow_table['value']
    ):
        if from_name not in node_indices or to_name not in node_indices:
            continue
        source.append(node_indices[from_name])
        target.append(node_indices[to_name])
        value.append(flow)


# Layer membership of a node name; TFs take precedence over enzymes
tf_node_names = set(sorted_tfs)
enzyme_node_names = set(sorted_enzymes)

# Node colors: TF = orange, Enzyme = green, everything else (Subsystem) = blue
node_colors = [
    'rgba(255, 127, 14, 0.8)' if node in tf_node_names
    else 'rgba(44, 160, 44, 0.8)' if node in enzyme_node_names
    else 'rgba(31, 119, 180, 0.8)'
    for node in nodes
]

# Link colors follow the layer of the link's source node
link_colors = []
for source_name in (nodes[source_index] for source_index in source):
    if source_name in tf_node_names:
        link_colors.append('rgba(255, 127, 14, 0.3)')  # TF -> Enzyme (orange-ish)
    elif source_name in enzyme_node_names:
        link_colors.append('rgba(44, 160, 44, 0.3)')   # Enzyme -> Subsystem (green-ish)

# Node and link specifications of the Sankey trace
sankey_nodes = dict(
    pad=20,
    thickness=25,
    line=dict(color="black", width=0.8),
    label=nodes,  # labels in node-index order
    color=node_colors,
    hovertemplate='<b>%{label}</b><br>Total flow: %{value}<extra></extra>'
)
sankey_links = dict(
    source=source,
    target=target,
    value=value,
    color=link_colors,
    hovertemplate='<b>%{source.label}</b> → <b>%{target.label}</b><br>Flow: %{value}<extra></extra>'
)

# Create Sankey plot
fig = go.Figure(data=[go.Sankey(node=sankey_nodes, link=sankey_links)])

# Empty centred title, shared font, transparent backgrounds and a fixed canvas size
fig.update_layout(
    title={
        'text': "",
        'x': 0.5,
        'xanchor': 'center',
        'font': {'size': 16, 'family': 'Arial, sans-serif'}
    },
    font=dict(size=16, family='Arial, sans-serif'),
    plot_bgcolor='rgba(0,0,0,0)',    # transparent plotting area
    paper_bgcolor='rgba(0,0,0,0)',   # transparent paper/background
    width=1800,
    height=1200,
    margin=dict(l=50, r=50, t=80, b=50)
)

# Column titles in paper coordinates, coloured like their node layer:
# (x, y, text, colour)
layer_titles = [
    (0, 0.95, 'TFs', 'rgba(255,127,14,1)'),           # orange
    (0.5, 1.03, 'Enzymes', 'rgba(44,160,44,1)'),      # green
    (1.015, 1.03, 'Pathways', 'rgba(31,119,180,1)'),  # blue
]
for title_x, title_y, title_text, title_color in layer_titles:
    fig.add_annotation(dict(
        x=title_x,
        y=title_y,
        xref='paper',
        yref='paper',
        text=title_text,
        showarrow=False,
        font=dict(size=24, family='Arial, sans-serif', color=title_color)
    ))

# # Save the figure as png
# fig.write_image("sankey_plot.png", format="png", scale=3)

# # Save the figure as pdf
# fig.write_image("../../results/physiology_comparison/sankey_plot.pdf", format="pdf", scale=3)




fig.show()