In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import pandas as pd
from matplotlib.patches import Wedge
from collections import OrderedDict
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import seaborn as sns

In [None]:
# Read the biomass FCC tables for the two physiologies (MUT first, WT second).
# In both files the first CSV column becomes the row index.
mut_biomass, wt_biomass = (
    pd.read_csv(fcc_path, index_col=0)
    for fcc_path in (
        '../../results/MCA_workflow/MUT_results/MUT_FCC_biomass.csv',
        '../../results/MCA_workflow/WT_results/WT_FCC_biomass.csv',
    )
)

In [None]:
# The 28 enzyme parameters shown in the figure; each label is 'vmax_forward_<reaction id>'
combined_top_enzymes = [
 'vmax_forward_TRIOK', 'vmax_forward_MI1PP', 'vmax_forward_PPAP', 'vmax_forward_r0301', 'vmax_forward_METAT',
 'vmax_forward_3DSPHR', 'vmax_forward_TMDS', 'vmax_forward_SERPT', 'vmax_forward_HMR_7748', 'vmax_forward_PSP_L',
 'vmax_forward_NADH2_u10mi', 'vmax_forward_DGK1', 'vmax_forward_NTD1', 'vmax_forward_r0354', 'vmax_forward_ADSS',
 'vmax_forward_r0178', 'vmax_forward_IMPD', 'vmax_forward_r0179', 'vmax_forward_PGI', 'vmax_forward_r0474',
 'vmax_forward_ICDHyrm', 'vmax_forward_GMPS2', 'vmax_forward_UMPK2', 'vmax_forward_ICDHxm', 'vmax_forward_URIK1',
 'vmax_forward_CYTK1', 'vmax_forward_HMR_4343', 'vmax_forward_r0426'
]

# Rows of each FCC table restricted to the selected enzymes
# (presumably one column per sampled model -- TODO confirm against the CSV layout)
mut_top = mut_biomass.loc[combined_top_enzymes]
wt_top = wt_biomass.loc[combined_top_enzymes]

# One row per enzyme; columns hold the row-wise mean and the 25% / 75% quantiles
# of the MUT and WT tables.
combined_df = pd.DataFrame(index=combined_top_enzymes, dtype=float)
summary_columns = [
    ('MUT_avg', mut_top.mean(axis=1)),
    ('WT_avg', wt_top.mean(axis=1)),
    ('MUT_lq', mut_top.quantile(0.25, axis=1)),
    ('WT_lq', wt_top.quantile(0.25, axis=1)),
    ('MUT_uq', mut_top.quantile(0.75, axis=1)),
    ('WT_uq', wt_top.quantile(0.75, axis=1)),
]
for column_name, column_values in summary_columns:
    combined_df[column_name] = column_values

# Order the rows by decreasing absolute WT average
wt_rank = combined_df['WT_avg'].abs().sort_values(ascending=False)
combined_df = combined_df.reindex(wt_rank.index)

# Strip the leading 'vmax_forward_' (13 characters) from every row label
prefix_length = len('vmax_forward_')
combined_df.index = [name[prefix_length:] for name in combined_df.index.tolist()]

# (enzyme id, subsystem name, hex colour) triples used to annotate combined_df.
# Entries whose enzyme id is not a row of combined_df are ignored below.
enzyme_category_mapping = [
    ("TRIOK", "Central Metabolism", "#FFFACD"),
    ("r0301", "Lipid Metabolism", "#FFB6C1"),
    ("3DSPHR", "Lipid Metabolism", "#FFB6C1"),
    ("TMDS", "Nucleotide Metabolism", "#AEC6CF"),
    ("SERPT", "Lipid Metabolism", "#FFB6C1"),
    ("PSP_L", "Amino Acid Metabolism", "#DABFFF"),
    ("HMR_7748", "Central Metabolism", "#FFFACD"),
    ("NTD1", "Nucleotide Metabolism", "#AEC6CF"),
    ("DGK1", "Nucleotide Metabolism", "#AEC6CF"),
    ("PPAP", "Lipid Metabolism", "#FFB6C1"),
    ("ADSS", "Nucleotide Metabolism", "#AEC6CF"),
    ("r0178", "Amino Acid Metabolism", "#DABFFF"),
    ("UMPK2", "Nucleotide Metabolism", "#AEC6CF"),
    ("ICDHyrm", "Central Metabolism", "#FFFACD"),
    ("IMPD", "Nucleotide Metabolism", "#AEC6CF"),
    ("r0179", "Amino Acid Metabolism", "#DABFFF"),
    ("URIK1", "Nucleotide Metabolism", "#AEC6CF"),
    ("GMPS2", "Nucleotide Metabolism", "#AEC6CF"),
    ("r0474", "Nucleotide Metabolism", "#AEC6CF"),
    ("CYTK1", "Nucleotide Metabolism", "#AEC6CF"),
    ("PGI", "Central Metabolism", "#FFFACD"),
    ("r0354", "Central Metabolism", "#FFFACD"),
    ("METAT", "Amino Acid Metabolism", "#DABFFF"),
    ("MI1PP", "Inositol Phosphate Metabolism", "#77DD77"),
    ("ICDHxm", "Central Metabolism", "#FFFACD"),
    ("CHOLK", "Lipid Metabolism", "#FFB6C1"),
    ("HMR_4343", "Nucleotide Metabolism", "#AEC6CF"),
    ("NADH2_u10mi", "Oxidative Phosphorylation", "#FFDAB9"),
    ("PFK", "Central Metabolism", "#FFFACD"),
    ("HPYRRy", "Amino Acid Metabolism", "#DABFFF"),
    ("MDH", "Central Metabolism", "#FFFACD"),
    ("r0426", "Central Metabolism", "#FFFACD"),
]

# Write 'Subsystem' and 'Color' for every mapped enzyme that is a row of combined_df
known_enzymes = combined_df.index
for enzyme_id, subsystem_name, hex_color in enzyme_category_mapping:
    if enzyme_id not in known_enzymes:
        continue
    combined_df.loc[enzyme_id, 'Subsystem'] = subsystem_name
    combined_df.loc[enzyme_id, 'Color'] = hex_color

# Display name: the row 'NADH2_u10mi' is shown as 'MComplex1'
combined_df = combined_df.rename(index={'NADH2_u10mi': 'MComplex1'})

# Order the rows for plotting: subsystems in the fixed order below, and within each
# subsystem by decreasing |WT_avg|.
#
# The original code first sorted by (Subsystem, |MUT_avg|) and then re-sorted by
# (Subsystem, |WT_avg|). The second multi-column sort is stable, so the first one only
# ever acted as a tie-break for rows with identical |WT_avg|. That tie-break is kept
# here as an explicit third key, so the resulting row order is unchanged while the
# redundant pass (and its misleading "abs(WT_avg)" comment) is gone.
desired_order = ["Central Metabolism", "Lipid Metabolism",
         "Amino Acid Metabolism", "Nucleotide Metabolism", 'Inositol Phosphate Metabolism', "Oxidative Phosphorylation"]

# An ordered categorical makes sort_values follow desired_order instead of alphabetical
# order. Any subsystem name missing from desired_order becomes NaN and sorts last.
combined_df['Subsystem'] = pd.Categorical(combined_df['Subsystem'], categories=desired_order, ordered=True)

combined_df = combined_df.sort_values(
    by=["Subsystem", "WT_avg", "MUT_avg"],
    ascending=[True, False, False],
    # rank the two average columns by magnitude; leave the categorical column untouched
    key=lambda col: col.abs() if col.name in ("WT_avg", "MUT_avg") else col,
)

combined_df

In [None]:
# Per-enzyme drug counts, keyed by the enzyme label used on the x axis.
# Each value is a two-element list; the plotting cell draws the first element as one
# pie wedge and (second - first) as the other, i.e. presumably [approved, total].
# The 'Total' entry feeds the summary pie. It is not the column-wise sum of the rows
# (those add up to 91 and 191) -- presumably a count of unique drugs; TODO confirm.
approved_count = {
    enzyme_label: [first_count, second_count]
    for enzyme_label, first_count, second_count in [
        ('TRIOK', 0, 0),
        ('HMR_7748', 2, 10),
        ('r0354', 4, 18),
        ('PGI', 2, 10),
        ('ICDHyrm', 2, 3),
        ('ICDHxm', 5, 6),
        ('r0426', 1, 6),
        ('PPAP', 0, 0),
        ('r0301', 3, 3),
        ('3DSPHR', 0, 0),
        ('SERPT', 1, 2),
        ('METAT', 2, 4),
        ('PSP_L', 4, 5),
        ('r0178', 4, 5),
        ('r0179', 4, 5),
        ('TMDS', 11, 23),
        ('NTD1', 1, 1),
        ('DGK1', 2, 3),
        ('ADSS', 2, 9),
        ('IMPD', 8, 17),
        ('r0474', 9, 16),
        ('URIK1', 1, 8),
        ('GMPS2', 3, 3),
        ('HMR_4343', 0, 0),
        ('CYTK1', 4, 8),
        ('UMPK2', 4, 8),
        ('MI1PP', 2, 5),
        ('MComplex1', 10, 13),
        ('Total', 60, 147),
    ]
}

In [None]:
# Two stacked panels: ax1 on top, ax2 below
fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(26, 20))

# ==================== FIRST SUBPLOT (WT vs MUT) ====================
bar_width = 0.7
gap = 0.75  # extra spacing inserted wherever the subsystem changes

# combined_df is built in an earlier cell; its row order defines the bar order.
group_labels = combined_df['Subsystem'].values

# Walk the rows once: consecutive enzymes are 1.5 apart, and an additional `gap`
# is added in front of the first enzyme of every new subsystem.
x_positions = []
current_x = 0
for label, prev_label in zip(group_labels, [None, *group_labels[:-1]]):
    if prev_label is not None and label != prev_label:
        current_x += gap
    x_positions.append(current_x)
    current_x += 1.5

x = np.array(x_positions)

# Dashed vertical separator halfway between the last bar of one subsystem
# and the first bar of the next
boundary_slots = [
    k for k in range(1, len(group_labels))
    if group_labels[k] != group_labels[k - 1]
]
for k in boundary_slots:
    ax1.axvline((x[k] + x[k - 1]) / 2, color='grey', linestyle='--', linewidth=1, alpha=0.5)

# Bars: WT to the left and MUT to the right of each enzyme position
ax1.bar(x - bar_width/2, combined_df['WT_avg'], width=bar_width, color='skyblue', edgecolor='black', linewidth=1, label='WT')
ax1.bar(x + bar_width/2, combined_df['MUT_avg'], width=bar_width, color='salmon', edgecolor='black', linewidth=1, label='MUT')


def _iqr_whiskers(center, lower_q, upper_q):
    """Return [[below], [above]] whisker lengths from `center` to the two quartiles.

    Axes.errorbar rejects negative lengths, which occur when the mean lies outside
    the inter-quartile range (skewed distributions). Such a whisker is clamped to 0
    instead of aborting the whole figure.
    """
    return [[max(center - lower_q, 0)], [max(upper_q - center, 0)]]


# Error bars: marker at the mean, whiskers reaching to the lower/upper quartile
for i, (wt, mut) in enumerate(zip(combined_df['WT_avg'], combined_df['MUT_avg'])):
    ax1.errorbar(x[i] - bar_width/2, wt,
                 yerr=_iqr_whiskers(wt, combined_df['WT_lq'].iloc[i], combined_df['WT_uq'].iloc[i]),
                 fmt='o', color='black', capsize=4, alpha=0.5)
    ax1.errorbar(x[i] + bar_width/2, mut,
                 yerr=_iqr_whiskers(mut, combined_df['MUT_lq'].iloc[i], combined_df['MUT_uq'].iloc[i]),
                 fmt='o', color='black', capsize=4, alpha=0.5)

# One tick per enzyme, labelled with the (vertical) enzyme name
ax1.set_xticks(x)
ax1.set_xticklabels(combined_df.index, rotation=90, fontsize=16)
ax1.set_xlim(x[0] - 1, x[-1] + 1)  # fixed limits: one unit of margin on either side
ax1.tick_params(axis='y', labelsize=16)

# Thin dashed reference line at zero
ax1.axhline(y=0, color='black', linestyle='--', alpha=0.5, lw=0.5)

# Open frame: hide the top and right spines
for side in ('top', 'right'):
    ax1.spines[side].set_visible(False)

# Vertical y-axis label, centred on the axis and shifted left of the tick labels
ax1.set_ylabel('Effect on growth rate', fontsize=20, labelpad=30, rotation=90)
ax1.yaxis.set_label_coords(-0.035, 0.5)

# WT and MUT legend.
# The handles use `facecolor`, not `color`: Patch's `color` argument sets BOTH face and
# edge colour and overrides `edgecolor` (matplotlib warns about it), so the black
# outline that the bars have would otherwise be missing from the legend swatches.
base_handles = [
    patches.Patch(facecolor='skyblue', edgecolor='black', linewidth=1, label='WT'),
    patches.Patch(facecolor='salmon', edgecolor='black', linewidth=1, label='MUT')
]

ax1.legend(base_handles, [h.get_label() for h in base_handles],
          fontsize=20, frameon=False, loc='lower left', ncol=1, handlelength=1.5)

# Add group brackets for first subplot.
# For every subsystem record [index of its first bar, index of its last bar];
# insertion order follows the order in which subsystems appear along the x axis.
subcat_to_indices = OrderedDict()
for i, label in enumerate(group_labels):
    if label not in subcat_to_indices:
        subcat_to_indices[label] = [i, i]
    else:
        subcat_to_indices[label][1] = i

bar_positions = ax1.get_xticks()

# Bracket geometry. The y values are in axes coordinates (1.0 = top of the panel),
# the overhang is in x data units.
y_bracket_level = 1.05   # height of the horizontal bar, slightly above the panel
y_tick_length = 0.03     # length of the two downward end ticks
bracket_overhang = 0.6   # how far the bracket extends beyond the first/last bar centre

# Blended transform: x is interpreted in data coordinates, y in axes coordinates.
# This replaces the manual data -> display -> axes conversion and keeps the brackets
# aligned with the bars even if the x-limits are changed later.
bracket_transform = ax1.get_xaxis_transform()

for label, (start_i, end_i) in subcat_to_indices.items():
    x0 = bar_positions[start_i] - bracket_overhang
    x1 = bar_positions[end_i] + bracket_overhang

    # Horizontal line of the bracket
    ax1.plot([x0, x1], [y_bracket_level, y_bracket_level],
             transform=bracket_transform, color='black', clip_on=False)
    # Left vertical tick
    ax1.plot([x0, x0], [y_bracket_level, y_bracket_level - y_tick_length],
             transform=bracket_transform, color='black', clip_on=False)
    # Right vertical tick
    ax1.plot([x1, x1], [y_bracket_level, y_bracket_level - y_tick_length],
             transform=bracket_transform, color='black', clip_on=False)

    # NOTE: no text label is drawn for the brackets here; `label` is only the dict key.


# ==================== SECOND SUBPLOT (Drug data with pie charts) ====================

# Input tables for the second panel
pivot_df = pd.read_csv('../../results/MCA_workflow/mapping_matrix_simple.csv')
group_labels_df = pd.read_csv('../../results/MCA_workflow/mapping_matrix_group_labels.csv')

# Show 'NADH2_u10mi' as 'MComplex1' in both tables (enzyme column and reaction column)
complex1_alias = {'NADH2_u10mi': 'MComplex1'}
pivot_df['Enzyme'] = pivot_df['Enzyme'].replace(complex1_alias)
group_labels_df['Reaction'] = group_labels_df['Reaction'].replace(complex1_alias)

# Missing entries are treated as zero
pivot_df = pivot_df.fillna(0)

# Row order of the input file; it is used as the bar order of the second panel
original_enzyme_order = pivot_df['Enzyme'].tolist()

# Every column except 'Enzyme' is treated as one ATC category
atc_categories = pivot_df.columns.drop('Enzyme').tolist()

# Long format: one row per (enzyme, ATC category) pair
melted_df = pivot_df.melt(id_vars=['Enzyme'], value_vars=atc_categories,
                          var_name='ATC Category', value_name='Drug Count')

# Attach the group-label columns to each enzyme (left join on the enzyme name)
group_labels_df = group_labels_df.rename(columns={'Reaction': 'Enzyme'})
merged_df = melted_df.merge(group_labels_df, on='Enzyme', how='left')

# Ordered categorical so that the enzymes keep the input order
merged_df['Enzyme'] = pd.Categorical(merged_df['Enzyme'], categories=original_enzyme_order, ordered=True)

# Back to wide format: rows = enzymes, columns = ATC categories, values = summed counts
plot_df = merged_df.pivot_table(index='Enzyme', columns='ATC Category', values='Drug Count', aggfunc='sum')
plot_df = plot_df.loc[original_enzyme_order]

# Subsystem of every enzyme, listed in bar order.
# NOTE(review): the x positions reused below come from the first panel, so this
# assumes the file lists the enzymes in the same order as combined_df -- confirm.
enzyme_subsystem_pairs = merged_df[['Enzyme', 'Subsystem']].drop_duplicates()
enzyme_subsystems = dict(zip(enzyme_subsystem_pairs['Enzyme'], enzyme_subsystem_pairs['Subsystem']))
group_labels_ordered = [enzyme_subsystems[enzyme_name] for enzyme_name in original_enzyme_order]

# Colours of the stacked segments: tab10 entries for all but the last ATC category,
# followed by light grey for the last one.
# NOTE(review): tab10 has 10 colours, so with more than 11 categories the zip below
# would stop early and the remaining categories would not be drawn.
palette = list(plt.cm.tab10.colors[:len(atc_categories) - 1])
colors = palette + ['#D3D3D3']

# Stack one segment per ATC category on the x positions of the first panel;
# bottom_values tracks the running height of each bar.
bottom_values = np.zeros(len(x))
for atc_cat_col, color in zip(atc_categories, colors):
    segment_heights = plot_df[atc_cat_col].values
    ax2.bar(x, segment_heights,
            bottom=bottom_values,
            width=bar_width,
            color=color,
            edgecolor='black',
            linewidth=1,
            label=atc_cat_col)
    bottom_values = bottom_values + segment_heights

# Y-axis label, positioned like the first panel's so the two labels line up
ax2.set_ylabel('Number of Drugs', fontsize=20, labelpad=30, rotation=90)
ax2.yaxis.set_label_coords(-0.035, 0.5)

# Fixed y-range 0..25 (bars taller than 25 would be cut off)
ax2.set_ylim(0, 25)

# One tick per enzyme, same x-limits as the first panel
ax2.set_xticks(x)
ax2.set_xticklabels(original_enzyme_order, rotation=90, fontsize=16, ha='center')
ax2.set_xlim(x[0] - 1, x[-1] + 1)

# Y tick label size
ax2.tick_params(axis='y', labelsize=16)

# Open frame: hide the top and right spines
ax2.spines['top'].set_visible(False)
ax2.spines['right'].set_visible(False)

# Frameless legend placed inside the panel
legend = ax2.legend(title='ATC Category', loc='upper left', bbox_to_anchor=(0.235, 1), fontsize=14)
legend.set_frame_on(False)

# Style the title through the public accessor rather than the private
# `_legend_title_box._text` attribute (get_title() returns that same Text object).
legend_title = legend.get_title()
legend_title.set_fontsize(18)
legend_title.set_ha('right')

# Dashed vertical separator halfway between neighbouring bars whose subsystem differs
subsystem_breaks = [
    k for k in range(1, len(group_labels_ordered))
    if group_labels_ordered[k] != group_labels_ordered[k - 1]
]
for k in subsystem_breaks:
    ax2.axvline((x[k] + x[k - 1]) / 2, color='grey', linestyle='--', linewidth=1, alpha=0.5)

# Small pie chart above every non-empty bar, built from approved_count (defined in an
# earlier cell). The former "'approved_count' not in locals()" guard was removed: it
# silently produced a figure without pies, while the summary pie further down uses
# approved_count unconditionally anyway.

pie_colors = ["#2b3591", "#ea5e57"]
explode = (0.05, 0)  # kept for compatibility; it is not passed to pie() in this cell

# Loop invariants. Both axis limits are fixed explicitly before this point.
y_axis_top = ax2.get_ylim()[1]
x_left, x_right = ax2.get_xlim()
# Side length (inches) of each inset: the bar width as a fraction of the x-range,
# multiplied by 30 (empirical scale factor taken over from the original code).
pie_box_inches = bar_width * (1.0 / (x_right - x_left)) * 30

for i, enzyme in enumerate(original_enzyme_order):
    first_count, second_count = approved_count[enzyme]
    # wedge 1 = first entry, wedge 2 = remainder (second entry minus first entry)
    pie_percentages = [first_count, second_count - first_count]
    pie_total = sum(pie_percentages)

    total_height = plot_df.loc[enzyme].sum()
    if total_height == 0:
        continue  # no bar for this enzyme, hence no pie
    if pie_total <= 0:
        continue  # an all-zero pie cannot be normalised (division by zero)

    # Centre the pie 8% of the y-range above the top of the stacked bar
    center_x = x[i]
    y_pie_center = total_height + 0.08 * y_axis_top

    axins = inset_axes(ax2,
                       width=pie_box_inches,
                       height=pie_box_inches,
                       loc='center',
                       bbox_to_anchor=(center_x, y_pie_center),
                       bbox_transform=ax2.transData,
                       borderpad=0)

    axins.pie(pie_percentages,
              colors=pie_colors,
              radius=1.6,
              startangle=90,
              counterclock=False,
              # label each wedge with its count (recovered from the percentage); hide zeros
              autopct=lambda p: f'{int(round(p * pie_total / 100))}' if p * pie_total / 100 > 0 else '',
              wedgeprops={'edgecolor': 'black', 'linewidth': 0.5},
              textprops={'fontsize': 16, 'color': '#F5F5F5'})

    axins.set_xticks([])
    axins.set_yticks([])
    axins.set_frame_on(False)

# Index span [first, last] of every subsystem along the second panel's x axis.
# NOTE: in the code visible here these values are only computed; no bracket is
# drawn on ax2 from them.
subcat_to_indices_2 = OrderedDict()
for position, subsystem in enumerate(group_labels_ordered):
    # the first sighting stores [position, position]; later sightings move the end index
    subcat_to_indices_2.setdefault(subsystem, [position, position])[1] = position

bar_positions_2 = ax2.get_xticks()
y_top = ax2.get_ylim()[1]  # upper y-limit of the second panel

# Summary pie in the upper-right corner of the second panel: an inset occupying
# 15% x 15% of the box (0.75, 0.75, 0.25, 0.25) given in ax2 axes coordinates.
pie_ax = inset_axes(ax2, width="15%", height="15%", loc='upper right',
                    bbox_to_anchor=(0.75, 0.75, 0.25, 0.25), bbox_transform=ax2.transAxes)

# Wedge sizes from the 'Total' entry: [first, second - first].
# With approved_count['Total'] == [60, 147] this is 60 (blue) and 87 (red).
total_first, total_second = approved_count['Total'][0], approved_count['Total'][1]
pie_data = [total_first, total_second - total_first]
pie_colors_summary = ["#2b3591", "#ea5e57"]  # blue, red
pie_labels = ['Approved', 'Total']  # defined but not passed to pie() below


def _summary_wedge_label(pct):
    """Turn a wedge percentage back into its count; empty string for an empty wedge."""
    return f'{int(round(pct * sum(pie_data) / 100))}' if pct * sum(pie_data) / 100 > 0 else ''


# NOTE(review): unlike the per-enzyme pies this one does not set counterclock=False,
# so its wedges run in the opposite direction -- confirm that this is intended.
wedges, texts, autotexts = pie_ax.pie(
    pie_data,
    radius=8,
    colors=pie_colors_summary,
    startangle=90,
    autopct=_summary_wedge_label,
    textprops={'fontsize': 18, 'color': 'white', 'weight': 'bold'},
    wedgeprops={'edgecolor': 'black', 'linewidth': 1},
)

# Bare inset: no ticks, no frame
pie_ax.set_xticks([])
pie_ax.set_yticks([])
pie_ax.set_frame_on(False)

# Layout: automatic tightening first, then an explicit bottom margin and extra
# vertical space between the two panels
plt.tight_layout()
plt.subplots_adjust(bottom=0.2, hspace=0.6)

# Font handling for vector output: Type 42 fonts in PDF, text left as text in SVG
plt.rcParams.update({
    'pdf.fonttype': 42,
    'svg.fonttype': 'none',
})

# Write the combined figure as a tightly cropped, transparent PDF
plt.savefig(
    '../../results/MCA_workflow/combined_figures_pastel.pdf',
    dpi=300,
    transparent=True,
    bbox_inches='tight',
    pad_inches=0.1,
)

plt.show()