In [None]:
import pandas as pd
import numpy as np

# Condition labels iterated over (outer loop) when the result tables are loaded;
# each one fills the first placeholder of the path template.
PHYSIOLOGIES = ['WT', 'MUT']
# Target labels iterated over (inner loop) when the result tables are loaded;
# each one fills the second placeholder of the path template.
TARGET_NAMES = ['TMDS', 'HEX', 'TRIOK']
# NOTE(review): this template is reassigned in a later cell before the loading
# loop reads it, so this value appears to be superseded - confirm and remove if stale.
path_to_metabolite_enrichment_analysis = 'data/drug_target_simulation/{}_{}/conc_fold_changes/metabolite_enrichment_analysis.csv'

In [None]:
# Pathway names in the order they are selected from each enrichment table
# (the loading loop indexes every table with `.loc[reordered_index]`, so each
# name listed here must be present in the table's index).
# The inline headings group the names by category; the same grouping is spelled
# out again in `category_definitions` in the plotting cells.
reordered_index = [
    # Central carbon metabolism
    'Glycolysis/gluconeogenesis',
    'Pyruvate metabolism',
    'Citric acid cycle',
    'Pentose phosphate pathway',
    'Glyoxylate and dicarboxylate metabolism',
    'Urea cycle',

    # Amino acid metabolism
    'Alanine and aspartate metabolism',
    'Arginine and proline metabolism',
    'Beta-Alanine metabolism',
    'Glutamate metabolism',
    'Histidine metabolism',
    'Lysine metabolism',
    'Glycine, serine, alanine, and threonine metabolism',
    'Methionine and cysteine metabolism',
    'Phenylalanine metabolism',
    'Tryptophan metabolism',
    'Tyrosine metabolism',
    'Valine, leucine, and isoleucine metabolism',

    # Nucleotide metabolism
    'Purine synthesis',
    'Pyrimidine synthesis',
    'Purine catabolism',
    'Pyrimidine catabolism',
    'Nucleotide interconversion',

    # Lipid metabolism
    'Sphingolipid metabolism',
    'Fatty acid synthesis',
    'Fatty acid oxidation',
    'Cholesterol metabolism',
    'Squalene and cholesterol synthesis',
    'Bile acid synthesis',
    'Glycerophospholipid metabolism',

    # Cofactor and vitamin metabolism
    'Folate metabolism',
    'Tetrahydrobiopterin metabolism',
    'NAD metabolism',
    'Inositol phosphate metabolism',

    # Energy metabolism
    # (the loading loop renames the raw 'ETC_Rxns' label to this name)
    'Oxidative Phosphorylation',

    # Detoxification
    'ROS detoxification',
    'Glutathione metabolism',
]

In [None]:
# Path template for the per-condition enrichment tables. The loading loop fills
# the two placeholders with (physiology, target), in that order. This assignment
# replaces the template defined next to PHYSIOLOGIES / TARGET_NAMES.
# NOTE(review): the path is relative, so it resolves against the kernel's
# working directory - confirm the notebook is run from the expected folder.
path_to_metabolite_enrichment_analysis = '../../results/drug_target_simulation/{}_stratified/{}/pathway_enrichment_with_metabolites.csv'

In [None]:
# Collect one enrichment table per (physiology, target) pair. The physiology
# varies slowest, so the list order is: every target for the first physiology,
# then every target for the next one.
subsystem_changes = []
excluded_prefixes = ['Transport', 'Miscellaneous', 'Exchange/demand reaction']
for ph, target in [(p, t) for p in PHYSIOLOGIES for t in TARGET_NAMES]:
    csv_path = path_to_metabolite_enrichment_analysis.format(ph, target)
    data = pd.read_csv(csv_path, index_col=0)

    # Snap counts that are numerically zero to an exact 0.
    data['significant_changes'] = data['significant_changes'].apply(
        lambda count: 0 if np.isclose(count, 0) else count
    )

    # Use the display name for the electron-transport-chain subsystem.
    data.index = data.index.str.replace('ETC_Rxns', 'Oxidative Phosphorylation')

    # Discard every row whose label starts with one of the excluded prefixes.
    for prefix in excluded_prefixes:
        data = data.loc[~data.index.str.startswith(prefix)]

    # Keep only the pathways of interest, in display order.
    data = data.loc[reordered_index]
    subsystem_changes.append(data)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from matplotlib.patches import Rectangle


# One plot column per loaded table, in loading order.
dataframes = subsystem_changes

# Column labels are built from the same constants, in the same nesting order
# (physiology outer, target inner), as the loop that filled subsystem_changes,
# so the labels cannot drift out of sync with the data columns.
titles = ["{} {}".format(target, ph) for ph in PHYSIOLOGIES for target in TARGET_NAMES]
if len(titles) != len(dataframes):
    raise ValueError(
        "Expected {} enrichment tables (one per physiology/target pair) but "
        "found {}; re-run the loading cell.".format(len(titles), len(dataframes))
    )

# Fraction of changed metabolites in the first table; only needed for the
# optional "sort by fraction" row ordering that is commented out below.
fraction_0 = subsystem_changes[0]["significant_changes"] / subsystem_changes[0]["total_metabolites"]
# sorted_index = fraction_0.sort_values(ascending=False).index
sorted_index = subsystem_changes[0].index

# Align every table to the row order of the first one.
dataframes = [df.reindex(sorted_index) for df in dataframes]
subsystem_names = sorted_index
y_positions = np.arange(len(subsystem_names))

# All p-values pooled across tables, with exact zeros replaced by the smallest
# positive float.
# NOTE(review): nothing below reads all_p_values - the colour scale is fixed to
# [0, 1] rather than derived from the data. Remove if it is not needed elsewhere.
all_p_values = np.concatenate([
    df["p_value"].replace(0, np.nextafter(0, 1)).values
    for df in dataframes
])
norm = plt.Normalize(0, 1)          # p-values always between 0 and 1
cmap = plt.cm.viridis_r             # reversed: small p => darker colour

# Figure width grows with the number of columns (3.5 per table); height is fixed.
fig, ax = plt.subplots(figsize=(3.5 * len(dataframes), 20))
# Each table gets one column of width `box_size`; column centres are
# `x_spacing` apart, which leaves a 0.1 gap between neighbouring panels.
box_size  = 1.3
x_spacing = box_size + 0.1
x_offsets = [i * x_spacing for i in range(len(dataframes))]

# Light grey panel behind each column: centred on the column and spanning
# y = -0.5 .. len(subsystem_names) - 0.5, i.e. half a row beyond the first and
# last marker positions.
for x_offset in x_offsets:
    rect = Rectangle(
        (x_offset - box_size / 2, -0.5),
        width=box_size,
        height=len(subsystem_names),
        edgecolor='gray',
        facecolor='whitesmoke',
        linewidth=1,
        zorder=1
    )
    ax.add_patch(rect)

# Main scatter: one vertical column of markers per table. The first pathway in
# `subsystem_names` is placed at y = 0.
for df, x_offset in zip(dataframes, x_offsets):
    # NOTE(review): a row with total_metabolites == 0 would give nan/inf here -
    # confirm the input tables never contain such rows.
    fraction_changed = df["significant_changes"].values / df["total_metabolites"].values
    # Exact zeros are replaced by the smallest positive float before colouring.
    p_values = df["p_value"].replace(0, np.nextafter(0, 1)).values
    x_positions = np.full(len(subsystem_names), x_offset)

    # `scatter` is rebound on every iteration; after the loop it refers to the
    # last column. Every column shares the same `cmap` and `norm`.
    scatter = ax.scatter(
        x=x_positions,
        y=y_positions,
        s=fraction_changed * 1000,  # marker size scales linearly with the changed fraction
        c=p_values,
        cmap=cmap,
        norm=norm,
        edgecolor='black',
        zorder=3
    )

# One y tick per pathway and one x tick per column.
ax.set_yticks(y_positions)
ax.set_yticklabels(subsystem_names, fontsize=10)
ax.set_xticks(x_offsets)
ax.set_xticklabels(titles, fontsize=12, rotation=45, ha='right')
# Equal data scaling on both axes.
ax.set_aspect('equal')

# -----------------------------------------------------
# Colour-bar (horizontal, above the plot)
# NOTE(review): the axes position is read here, but tight_layout() is called
# further down in this cell; if that call moves the main axes, the colour-bar
# and size legend stay where they were placed here - confirm the alignment in
# the rendered figure.
pos                = ax.get_position()
cbar_width_ratio   = 0.5          # 50 % of the data-axes width
cbar_height        = 0.01
# Centre the bar horizontally over the main axes ...
cbar_x_start       = pos.x0 + (pos.width * (1 - cbar_width_ratio)) / 2
# ... and place it slightly above the top edge of the main axes.
cbar_y_start       = pos.y1 + 0.005

cbar_ax = fig.add_axes([cbar_x_start, cbar_y_start,
                        pos.width * cbar_width_ratio, cbar_height])
# `scatter` is the collection drawn for the last column; all columns share the
# same cmap/norm.
cbar = fig.colorbar(scatter, cax=cbar_ax, orientation='horizontal')
cbar.set_label('p-value', fontsize=12)
# Ticks below the bar, label above it.
cbar.ax.xaxis.set_ticks_position('bottom')
cbar.ax.xaxis.set_label_position('top')

# -----------------------------------------------------
# Reference-size circles (marker-size legend)
# Example fractions to illustrate, with their display labels.
ref_fractions = [0.1, 0.3, 0.5]      # 10 %, 30 %, 50 %
ref_labels    = ['10%', '30%', '50%']

# Small helper axes placed just right of the colour-bar and vertically centred
# on it; the axes frame is hidden.
ref_circle_x_start = cbar_x_start + pos.width * cbar_width_ratio + 0.02
ref_circle_y_center = cbar_y_start + cbar_height / 2
ref_ax = fig.add_axes([ref_circle_x_start, ref_circle_y_center - 0.015,
                       0.12, 0.03])
ref_ax.set_xlim(0, 1)
ref_ax.set_ylim(-1, 1)
ref_ax.axis('off')

for i, (frac, label) in enumerate(zip(ref_fractions, ref_labels)):
    x_pos = i * 0.32 + 0.16          # circle centres at 0.16, 0.48, 0.80
    ref_ax.scatter(
        x_pos, 0,
        s=frac * 1000,               # same factor (1000) as the main scatter
        c='lightgray',
        edgecolor='black',
        linewidth=1,
        zorder=3
    )
    # Percentage label below each circle (data coordinates of ref_ax).
    ref_ax.text(x_pos, -0.7, label,
                ha='center', va='top', fontsize=10)

# Legend heading, positioned in axes-fraction coordinates of ref_ax.
ref_ax.text(0.5, 0.8, 'Fraction of significant\nchanges',
            ha='center', va='bottom',
            fontsize=10, weight='bold',
            transform=ref_ax.transAxes)
# -----------------------------------------------------

sns.despine(left=True, bottom=True)
# NOTE(review): tight_layout() is called later in this cell and recomputes the
# subplot parameters, so this right-hand margin may not survive - confirm.
plt.subplots_adjust(right=0.85)     # keep room on right for brackets

# =====================================================
# Category -> member pathways. Used to work out which rows each category
# bracket spans; names that are not present in `subsystem_names` are ignored by
# calculate_category_positions().
# NOTE(review): this repeats the grouping already written as comments in
# `reordered_index`; keep the two in sync when editing either.
category_definitions = {
    "Central carbon metabolism": [
        'Glycolysis/gluconeogenesis',
        'Pyruvate metabolism',
        'Citric acid cycle',
        'Pentose phosphate pathway',
        'Glyoxylate and dicarboxylate metabolism',
        'Urea cycle'
    ],
    "Amino acid metabolism": [
        'Alanine and aspartate metabolism',
        'Arginine and proline metabolism',
        'Beta-Alanine metabolism',
        'Glutamate metabolism',
        'Histidine metabolism',
        'Lysine metabolism',
        'Glycine, serine, alanine, and threonine metabolism',
        'Methionine and cysteine metabolism',
        'Phenylalanine metabolism',
        'Tryptophan metabolism',
        'Tyrosine metabolism',
        'Valine, leucine, and isoleucine metabolism'
    ],
    "Nucleotide metabolism": [
        'Purine synthesis',
        'Pyrimidine synthesis',
        'Purine catabolism',
        'Pyrimidine catabolism',
        'Nucleotide interconversion'
    ],
    "Lipid metabolism": [
        'Sphingolipid metabolism',
        'Fatty acid synthesis',
        'Fatty acid oxidation',
        'Cholesterol metabolism',
        'Squalene and cholesterol synthesis',
        'Bile acid synthesis',
        'Glycerophospholipid metabolism'
    ],
    "Cofactor and vitamin metabolism": [
        'Folate metabolism',
        'Tetrahydrobiopterin metabolism',
        'NAD metabolism',
        'Inositol phosphate metabolism'
    ],
    "Energy metabolism": ['Oxidative Phosphorylation'],
    "Detoxification":    ['ROS detoxification', 'Glutathione metabolism']
}

def calculate_category_positions(subsystem_names, category_definitions):
    """Locate the row span covered by each pathway category.

    Parameters
    ----------
    subsystem_names : iterable of hashable
        Pathway names in plotting order; a name's position in this sequence
        is its row index. Lists, pandas ``Index`` objects and one-shot
        iterables are all accepted.
    category_definitions : dict
        Maps a category label to the list of pathway names belonging to it.

    Returns
    -------
    dict
        ``{category: (lowest_row, highest_row)}`` for every category that has
        at least one pathway present in ``subsystem_names``. Categories with
        no matching pathway are left out.
    """
    # Build the name -> row lookup once, instead of re-creating a list and
    # scanning it for every pathway. setdefault keeps the first occurrence of
    # a repeated name, which matches list.index().
    row_of = {}
    for row, name in enumerate(subsystem_names):
        row_of.setdefault(name, row)

    category_positions = {}
    for category, pathways in category_definitions.items():
        rows = [row_of[p] for p in pathways if p in row_of]
        if rows:
            category_positions[category] = (min(rows), max(rows))
    return category_positions

category_positions = calculate_category_positions(
    subsystem_names, category_definitions
)

# Brackets sit to the right of the last column; labels sit just beyond them.
x_bracket = x_offsets[-1] + box_size / 2 + 0.3
x_text    = x_bracket + 0.1

for category, (start, end) in category_positions.items():
    mid = (start + end) / 2
    lower_edge = start - 0.4
    upper_edge = end + 0.4

    # Category label, vertically centred on the rows it covers.
    ax.annotate(category, xy=(x_bracket, mid), xytext=(x_text, mid),
                va='center', ha='left', fontsize=10, weight='bold',
                annotation_clip=False)

    # Spine first, then the two short ticks pointing back toward the grid.
    bracket_segments = [
        ([x_bracket, x_bracket], [lower_edge, upper_edge]),
        ([x_bracket, x_bracket - 0.1], [lower_edge, lower_edge]),
        ([x_bracket, x_bracket - 0.1], [upper_edge, upper_edge]),
    ]
    for seg_x, seg_y in bracket_segments:
        ax.plot(seg_x, seg_y, color='black', lw=1.2, clip_on=False)

# Leave the top 12 % of the figure free for the colour-bar.
plt.tight_layout(rect=[0, 0, 1, 0.88])

# Font settings for PDF/SVG export, then display.
plt.rcParams.update({'pdf.fonttype': 42, 'svg.fonttype': 'none'})
# plt.savefig('../../results/drug_target_simulation/pathway_enrichment_analysis.pdf',
#             transparent=True, bbox_inches='tight', pad_inches=0.1)
plt.show()


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from matplotlib.patches import Rectangle


# One plot column per loaded table, in loading order.
dataframes = subsystem_changes

# Column labels are built from the same constants, in the same nesting order
# (physiology outer, target inner), as the loop that filled subsystem_changes,
# so the labels cannot drift out of sync with the data columns.
titles = ["{} {}".format(target, ph) for ph in PHYSIOLOGIES for target in TARGET_NAMES]
if len(titles) != len(dataframes):
    raise ValueError(
        "Expected {} enrichment tables (one per physiology/target pair) but "
        "found {}; re-run the loading cell.".format(len(titles), len(dataframes))
    )

# Fraction of changed metabolites in the first table; only needed for the
# optional "sort by fraction" row ordering that is commented out below.
fraction_0 = subsystem_changes[0]["significant_changes"] / subsystem_changes[0]["total_metabolites"]
# sorted_index = fraction_0.sort_values(ascending=False).index
sorted_index = subsystem_changes[0].index

# Align every table to the row order of the first one.
dataframes = [df.reindex(sorted_index) for df in dataframes]
subsystem_names = sorted_index
y_positions = np.arange(len(subsystem_names))

# All p-values pooled across tables, with exact zeros replaced by the smallest
# positive float.
# NOTE(review): nothing below reads all_p_values - the colour scale is fixed to
# [0, 1] rather than derived from the data. Remove if it is not needed elsewhere.
all_p_values = np.concatenate([
    df["p_value"].replace(0, np.nextafter(0, 1)).values
    for df in dataframes
])
norm = plt.Normalize(0, 1)          # p-values always between 0 and 1
cmap = plt.cm.viridis_r             # reversed: small p => darker colour

# Figure width grows with the number of columns (3.5 per table); height is fixed.
fig, ax = plt.subplots(figsize=(3.5 * len(dataframes), 20))
# Each table gets one column of width `box_size`; column centres are
# `x_spacing` apart, which leaves a 0.1 gap between neighbouring panels.
box_size  = 1.3
x_spacing = box_size + 0.1
x_offsets = [i * x_spacing for i in range(len(dataframes))]

# Light grey panel behind each column: centred on the column and spanning
# y = -0.5 .. len(subsystem_names) - 0.5, i.e. half a row beyond the first and
# last marker positions.
for x_offset in x_offsets:
    rect = Rectangle(
        (x_offset - box_size / 2, -0.5),
        width=box_size,
        height=len(subsystem_names),
        edgecolor='gray',
        facecolor='whitesmoke',
        linewidth=1,
        zorder=1
    )
    ax.add_patch(rect)

# Main scatter: one vertical column of markers per table. The first pathway in
# `subsystem_names` is placed at y = 0.
for df, x_offset in zip(dataframes, x_offsets):
    # NOTE(review): a row with total_metabolites == 0 would give nan/inf here -
    # confirm the input tables never contain such rows.
    fraction_changed = df["significant_changes"].values / df["total_metabolites"].values
    # Exact zeros are replaced by the smallest positive float before colouring.
    p_values = df["p_value"].replace(0, np.nextafter(0, 1)).values
    x_positions = np.full(len(subsystem_names), x_offset)

    # `scatter` is rebound on every iteration; after the loop it refers to the
    # last column. Every column shares the same `cmap` and `norm`.
    scatter = ax.scatter(
        x=x_positions,
        y=y_positions,
        s=fraction_changed * 1000,  # marker size scales linearly with the changed fraction
        c=p_values,
        cmap=cmap,
        norm=norm,
        edgecolor='black',
        zorder=3
    )

# One y tick per pathway and one x tick per column.
ax.set_yticks(y_positions)
ax.set_yticklabels(subsystem_names, fontsize=10)
ax.set_xticks(x_offsets)
ax.set_xticklabels(titles, fontsize=12, rotation=45, ha='right')
# Equal data scaling on both axes.
ax.set_aspect('equal')

# -----------------------------------------------------
# Colour-bar (horizontal, above the plot)
# NOTE(review): the axes position is read here, but tight_layout() is called
# further down in this cell; if that call moves the main axes, the colour-bar
# and size legend stay where they were placed here - confirm the alignment in
# the rendered figure.
pos                = ax.get_position()
cbar_width_ratio   = 0.5          # 50 % of the data-axes width
cbar_height        = 0.01
# Centre the bar horizontally over the main axes ...
cbar_x_start       = pos.x0 + (pos.width * (1 - cbar_width_ratio)) / 2
# ... and place it slightly above the top edge of the main axes.
cbar_y_start       = pos.y1 + 0.005

cbar_ax = fig.add_axes([cbar_x_start, cbar_y_start,
                        pos.width * cbar_width_ratio, cbar_height])
# `scatter` is the collection drawn for the last column; all columns share the
# same cmap/norm.
cbar = fig.colorbar(scatter, cax=cbar_ax, orientation='horizontal')
cbar.set_label('p-value', fontsize=12)
# Ticks below the bar, label above it.
cbar.ax.xaxis.set_ticks_position('bottom')
cbar.ax.xaxis.set_label_position('top')

# -----------------------------------------------------
# Reference-size circles (marker-size legend)
# Example fractions to illustrate, with their display labels.
ref_fractions = [0.1, 0.3, 0.5]      # 10 %, 30 %, 50 %
ref_labels    = ['10%', '30%', '50%']

# Small helper axes placed just right of the colour-bar and vertically centred
# on it; the axes frame is hidden.
ref_circle_x_start = cbar_x_start + pos.width * cbar_width_ratio + 0.02
ref_circle_y_center = cbar_y_start + cbar_height / 2
ref_ax = fig.add_axes([ref_circle_x_start, ref_circle_y_center - 0.015,
                       0.12, 0.03])
ref_ax.set_xlim(0, 1)
ref_ax.set_ylim(-1, 1)
ref_ax.axis('off')

for i, (frac, label) in enumerate(zip(ref_fractions, ref_labels)):
    x_pos = i * 0.32 + 0.16          # circle centres at 0.16, 0.48, 0.80
    ref_ax.scatter(
        x_pos, 0,
        s=frac * 1000,               # same factor (1000) as the main scatter
        c='lightgray',
        edgecolor='black',
        linewidth=1,
        zorder=3
    )
    # Percentage label below each circle (data coordinates of ref_ax).
    ref_ax.text(x_pos, -0.7, label,
                ha='center', va='top', fontsize=10)

# Legend heading, positioned in axes-fraction coordinates of ref_ax.
ref_ax.text(0.5, 0.8, 'Fraction of significant\nchanges',
            ha='center', va='bottom',
            fontsize=10, weight='bold',
            transform=ref_ax.transAxes)
# -----------------------------------------------------

sns.despine(left=True, bottom=True)
# NOTE(review): in this cell the category brackets are drawn on the LEFT of the
# grid and tight_layout() is called afterwards, so this right-hand margin looks
# like a leftover from the right-bracket variant - confirm and remove.
plt.subplots_adjust(right=0.85)     # right-hand margin

# =====================================================
# Category -> member pathways. Used to work out which rows each category
# bracket spans; names that are not present in `subsystem_names` are ignored by
# calculate_category_positions().
# NOTE(review): this repeats the grouping already written as comments in
# `reordered_index`; keep the two in sync when editing either.
category_definitions = {
    "Central carbon metabolism": [
        'Glycolysis/gluconeogenesis',
        'Pyruvate metabolism',
        'Citric acid cycle',
        'Pentose phosphate pathway',
        'Glyoxylate and dicarboxylate metabolism',
        'Urea cycle'
    ],
    "Amino acid metabolism": [
        'Alanine and aspartate metabolism',
        'Arginine and proline metabolism',
        'Beta-Alanine metabolism',
        'Glutamate metabolism',
        'Histidine metabolism',
        'Lysine metabolism',
        'Glycine, serine, alanine, and threonine metabolism',
        'Methionine and cysteine metabolism',
        'Phenylalanine metabolism',
        'Tryptophan metabolism',
        'Tyrosine metabolism',
        'Valine, leucine, and isoleucine metabolism'
    ],
    "Nucleotide metabolism": [
        'Purine synthesis',
        'Pyrimidine synthesis',
        'Purine catabolism',
        'Pyrimidine catabolism',
        'Nucleotide interconversion'
    ],
    "Lipid metabolism": [
        'Sphingolipid metabolism',
        'Fatty acid synthesis',
        'Fatty acid oxidation',
        'Cholesterol metabolism',
        'Squalene and cholesterol synthesis',
        'Bile acid synthesis',
        'Glycerophospholipid metabolism'
    ],
    "Cofactor and vitamin metabolism": [
        'Folate metabolism',
        'Tetrahydrobiopterin metabolism',
        'NAD metabolism',
        'Inositol phosphate metabolism'
    ],
    "Energy metabolism": ['Oxidative Phosphorylation'],
    "Detoxification":    ['ROS detoxification', 'Glutathione metabolism']
}

def calculate_category_positions(subsystem_names, category_definitions):
    """Locate the row span covered by each pathway category.

    Parameters
    ----------
    subsystem_names : iterable of hashable
        Pathway names in plotting order; a name's position in this sequence
        is its row index. Lists, pandas ``Index`` objects and one-shot
        iterables are all accepted.
    category_definitions : dict
        Maps a category label to the list of pathway names belonging to it.

    Returns
    -------
    dict
        ``{category: (lowest_row, highest_row)}`` for every category that has
        at least one pathway present in ``subsystem_names``. Categories with
        no matching pathway are left out.
    """
    # Build the name -> row lookup once, instead of re-creating a list and
    # scanning it for every pathway. setdefault keeps the first occurrence of
    # a repeated name, which matches list.index().
    row_of = {}
    for row, name in enumerate(subsystem_names):
        row_of.setdefault(name, row)

    category_positions = {}
    for category, pathways in category_definitions.items():
        rows = [row_of[p] for p in pathways if p in row_of]
        if rows:
            category_positions[category] = (min(rows), max(rows))
    return category_positions

category_positions = calculate_category_positions(subsystem_names,
                                                  category_definitions)

# =====================================================
# Category brackets, drawn on the LEFT side of the grid.

# Brackets sit just left of the first column; labels sit just beyond them and
# are right-aligned so that they end at the bracket.
x_bracket = x_offsets[0] - box_size / 2 - 0.3
x_text    = x_bracket - 0.1

for category, (start, end) in category_positions.items():
    mid = (start + end) / 2
    # Category label, vertically centred on the rows it covers.
    ax.annotate(category, xy=(x_bracket, mid), xytext=(x_text, mid),
                va='center', ha='right', fontsize=10, weight='bold',
                annotation_clip=False)
    # Main vertical bracket, extended 0.4 rows beyond the first/last row.
    ax.plot([x_bracket, x_bracket], [start - 0.4, end + 0.4],
            color='black', lw=1.2, clip_on=False)
    # Short horizontal ticks pointing toward the grid (i.e. to the right).
    ax.plot([x_bracket, x_bracket + 0.1],
            [start - 0.4, start - 0.4], color='black', lw=1.2,
            clip_on=False)
    ax.plot([x_bracket, x_bracket + 0.1],
            [end + 0.4, end + 0.4], color='black', lw=1.2,
            clip_on=False)

# Single layout pass: keep 8 % on the left for the brackets and their labels
# and 12 % at the top for the colour-bar. Every tight_layout() call recomputes
# the subplot parameters from scratch, so the additional
# `tight_layout(rect=[0, 0, 1, 0.88])` call that used to follow discarded the
# left margin reserved here.
plt.tight_layout(rect=[0.08, 0, 1, 0.88])

# Font settings for PDF/SVG export, then display.
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['svg.fonttype'] = 'none'
# plt.savefig('../../results/drug_target_simulation/pathway_enrichment_analysis.pdf',
#             transparent=True, bbox_inches='tight', pad_inches=0.1)
plt.show()
