In [None]:
from pytfa.io.json import load_json_model
from skimpy.io.yaml import load_yaml_model
from skimpy.analysis.oracle.load_pytfa_solution import load_concentrations
from skimpy.core.solution import ODESolutionPopulation
from skimpy.core.parameters import ParameterValuePopulation, \
    load_parameter_population
from skimpy.utils.namespace import QSSA
from skimpy.utils.tabdict import TabDict
from scikits.odes import ode
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from skimpy.core.solution import ODESolution
from tqdm.auto import tqdm
from skimpy.analysis.oracle.load_pytfa_solution import load_fluxes, \
    load_concentrations
import seaborn as sns
import multiprocessing as mp
import sys
sys.path.append("..") # Adds higher directory to python modules path.
from utils.make_flux_fun_parallel import make_flux_fun_parallel
from utils.enzyme_degradation_class import make_enzymedegradation
from utils.remove_outliers import remove_outliers_parallel
from skimpy.core.reactions import Reaction
import time
import numpy as np
import sys
import os
import configparser
from skimpy.utils.general import sanitize_cobra_vars
from utils.remove_outliers import remove_outliers_row
from utils.drug_ode_simulation import simulate_sample, ODESolution_prior, ODESolution_post, FluxSolution, produce_flux_df
from scipy.stats import ttest_1samp
from statsmodels.stats.multitest import multipletests

sys.path.append('../')

# Simulation settings shared by the whole notebook
TIME = np.linspace(0, 600, 1200) # 20-30 times the doubling time of the cell
PHYSIOLOGY = 'MUT'
# TARGETS = ['HEX1', 'r0354', 'r0355']
TARGETS = ['TMDS']
TARGET_NAME = 'TMDS'

config = configparser.ConfigParser()
config_path = '../src/config.ini'
# ConfigParser.read() silently skips missing files and returns the list of files it
# managed to parse; fail here with a clear message instead of a KeyError further down.
if not config.read(config_path):
    raise FileNotFoundError(f'Could not read configuration file: {config_path}')

base_dir = config['paths']['base_dir']

NCPU = int(config['global']['ncpu'])

# Scaling parameters from config.ini
CONCENTRATION_SCALING = float(config['scaling']['CONCENTRATION_SCALING'])
TIME_SCALING = float(config['scaling']['TIME_SCALING'])
DENSITY = float(config['scaling']['DENSITY'])
GDW_GWW_RATIO = float(config['scaling']['GDW_GWW_RATIO'])
flux_scaling_factor = 1e-3 * (GDW_GWW_RATIO * DENSITY) * CONCENTRATION_SCALING / TIME_SCALING


def _physiology_path(key_stem):
    """Return the absolute path stored under ``[paths] <key_stem>_<PHYSIOLOGY>``.

    The configured value is joined onto ``base_dir`` before being made absolute.
    Raises ``KeyError`` if the option is missing from config.ini.
    """
    return os.path.abspath(os.path.join(base_dir, config['paths'][f'{key_stem}_{PHYSIOLOGY}']))


# Paths from config.ini using PHYSIOLOGY variable
path_to_kmodel = _physiology_path('path_to_kmodel')
path_to_tmodel = _physiology_path('path_to_tmodel')
path_to_samples = _physiology_path('path_to_samples')
path_to_params = _physiology_path('path_to_param_output')
path_to_max_eig = _physiology_path('path_to_lambda_values')

path_to_stratified_samples = _physiology_path('path_to_stratified_samples')
path_to_stratified_params = _physiology_path('path_to_stratified_params')
path_to_ode_conc_solutions = _physiology_path('path_to_ode_conc_solutions')
path_to_ode_flux_solutions = _physiology_path('path_to_ode_flux_solutions')
path_to_distances_metabolites = _physiology_path('path_to_distances_metabolites')
path_to_distances_reactions = _physiology_path('path_to_distances_reactions')

path_to_conc_fold_changes = _physiology_path('path_to_conc_fold_changes')
path_to_metabolite_enrichment_analysis = _physiology_path('path_to_metabolite_enrichment_analysis')
path_to_flux_fold_changes = _physiology_path('path_to_flux_fold_changes')

In [None]:
# NOTE(review): this line relies on IPython's automagic `cd` and on the kernel having been
# started in the directory that contains `kinetic_cancer_final/`. Because the path is
# relative, re-running the cell after the directory change will typically fail — confirm
# the intended start directory.
cd kinetic_cancer_final/scripts/notebooks/

# Load all the required data

In [None]:
# Thermodynamic model (JSON, read through pytfa) and kinetic model (YAML, read through skimpy)
tmodel = load_json_model(path_to_tmodel)
kmodel = load_yaml_model(path_to_kmodel)

# Full table of samples; the first CSV column is used as the row index
samples = pd.read_csv(path_to_samples, header=0, index_col=0)

# Stratified subset of samples and the matching kinetic parameter sets
samples_picked = pd.read_csv(path_to_stratified_samples, index_col=0)
parameter_population = load_parameter_population(path_to_stratified_params)

# Identifiers of the parameter sets, in the order stored in the population
samples_to_simulate = [sample_id for sample_id in parameter_population._index.keys()]

In [None]:
# Combined ODE concentration solutions for the selected target. The results folder is
# derived from PHYSIOLOGY (instead of a hard-coded 'MUT') so that it stays in sync with
# the output paths used by the plotting cells of this notebook.
final_res = pd.read_parquet(f'../../results/drug_target_simulation/{PHYSIOLOGY}_stratified/{TARGET_NAME}/solutions_{TARGET_NAME}_600_hrs_combined.parquet')

In [None]:
print('Loading ODE concentration solutions for post-processing...')

# Every sample must have exactly one block of rows (one solution_id) in final_res
if final_res.solution_id.nunique() != len(samples_to_simulate):
    raise ValueError('Number of samples to simulate does not match number of unique solution IDs in final_res.')

# Rebuild one solution object per sample: column 1 of final_res is passed as the time
# vector and columns 2 onwards as the concentrations.
solutions_raw = []
for ix, sol_id in zip(samples_to_simulate, range(final_res.solution_id.max()+1)):
    ids = np.where(final_res.solution_id == sol_id)[0]
    solutions_raw.append(ODESolution_post(final_res.iloc[ids,2:], final_res.iloc[ids,1], ix))

# Keep only the solutions whose last time point equals the end of the simulation grid
solutions = []
for ix, sol in zip(samples_to_simulate, solutions_raw):
    if sol.time[-1] != TIME[-1]:
        print('Solution {} did not converge in time ({})'.format(ix, sol.time[-1]))
        continue
    solutions.append(sol)

In [None]:
# Combined flux solutions for the selected target. The results folder is derived from
# PHYSIOLOGY (instead of a hard-coded 'MUT') so that it stays in sync with the output
# paths used by the plotting cells of this notebook.
total_flux_df = pd.read_parquet(f'../../results/drug_target_simulation/{PHYSIOLOGY}_stratified/{TARGET_NAME}/fluxes_{TARGET_NAME}_600_hrs_combined.parquet')

In [None]:
# Alternative loader: read the flux solutions from CSV in chunks.
# Running this cell rebinds `total_flux_df` to the CSV content.
filename = path_to_ode_flux_solutions.format(TARGET_NAME)

# Read only the header to learn the column names, then force compact dtypes
total_flux_df_cols = pd.read_csv(filename, index_col=0, nrows=0).columns

dtype_dict = {col: 'float32' for col in total_flux_df_cols}
dtype_dict['model_ix'] = 'string'

# Get total number of lines in the file (minus header); the context manager makes sure
# the file handle is closed again.
with open(filename) as csv_file:
    total_lines = sum(1 for _ in csv_file) - 1 # subtract header
chunksize = 1000

# Create the chunk iterator
chunk_iter = pd.read_csv(filename, chunksize=chunksize, index_col=0, dtype=dtype_dict)

# Accumulate chunks with tqdm progress bar (total = ceil(total_lines / chunksize))
n_chunks = (total_lines + chunksize - 1) // chunksize
df_chunks = []
for chunk in tqdm(chunk_iter, total=n_chunks, desc="Loading CSV with flux solutions"):
    df_chunks.append(chunk)

# Combine all chunks into a single DataFrame
df = pd.concat(df_chunks, ignore_index=True)

total_flux_df = df

In [None]:
# Load the flux solutions: one FluxSolution per model.
# A single groupby pass replaces filtering the whole frame twice per model.
# sort=False keeps the models in order of first appearance (same order as .unique()).
flux_solutions = []
for ix, model_df in tqdm(total_flux_df.groupby('model_ix', sort=False, observed=True)):
    sol = FluxSolution(model_df.drop(['model_ix','time'], axis=1).reset_index(drop=True), model_df.time, ix)
    flux_solutions.append(sol)


# Models whose biomass flux at the last time point exceeds the initial one by more than 0.1 %
malignant_models = []
for flux in flux_solutions:
    biomass = flux.fluxes.biomass
    if biomass.iloc[-1]/biomass[0] > 1.001:
        malignant_models.append(flux.model_ix)

# Show the trajectories of growth rates (3B)

In [None]:
# Readable names for the level codes that make up a group id of the form '<glycolysis>_<oxphos>'
glycolysis_map = {
    '0': 'low glycolysis',
    '1': 'average glycolysis',
    '2': 'high glycolysis',
}
oxphos_map = {
    '0': 'low oxphos',
    '1': 'average oxphos',
    '2': 'high oxphos',
}


def transform_group(value):
    """Translate a group id such as ``'0_2'`` into ``'low glycolysis, high oxphos'``.

    ``value`` must consist of exactly two codes separated by one underscore; the first
    is looked up in ``glycolysis_map`` and the second in ``oxphos_map``.
    """
    glycolysis_code, oxphos_code = value.split('_')
    return ", ".join((glycolysis_map[glycolysis_code], oxphos_map[oxphos_code]))

# Attach the readable label of every sample's group as a new column
samples_picked['group_levels'] = [transform_group(group_id) for group_id in samples_picked['group']]

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

def plot_color_array(hex_colors):
    """Show a 2-D grid of hex colour codes as an image with one cell per colour.

    ``hex_colors`` is iterated row by row; every entry is converted to an RGB triple.
    The figure is displayed with the axes hidden. Nothing is returned.
    """
    rgb_rows = []
    for row in hex_colors:
        rgb_rows.append([mcolors.hex2color(code) for code in row])
    color_codes_rgb = np.array(rgb_rows)

    # Draw the grid without any axis decoration and display it
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(color_codes_rgb, aspect='auto')
    ax.axis('off')
    plt.show()

# 3x3 bivariate palette. Read together with the mapping below: rows go from high (top) to
# low (bottom) oxphos, columns from low (left) to high (right) glycolysis.
colors = np.array([
    ["#5FA8A3", "#BC7C8F", "#AE3A4E"],
    ["#89A1C8", "#806A8A", "#77324C"],
    ["#4885C1", "#435786", "#3F2949"]
])

# Preview the palette
plot_color_array(colors)

# One colour per (glycolysis level, oxphos level) combination
group_level_colors = {
    'low glycolysis, low oxphos': '#4885C1',
    'low glycolysis, average oxphos': '#89A1C8',
    'low glycolysis, high oxphos': '#5FA8A3',
    'average glycolysis, low oxphos': '#435786',
    'average glycolysis, average oxphos': '#806A8A',
    'average glycolysis, high oxphos': '#BC7C8F',
    'high glycolysis, low oxphos': '#3F2949',
    'high glycolysis, average oxphos': '#77324C',
    'high glycolysis, high oxphos': '#AE3A4E',
}
samples_picked['color'] = samples_picked['group_levels'].map(group_level_colors)

In [None]:
def plot_trajectories(total_df, samples_picked, time_indices, t_span, save_path=None):
    """Plot representative normalised trajectories for the four extreme groups.

    For each of the groups '0_0', '0_2', '2_0' and '2_2' the model trajectory closest
    (Euclidean norm) to the group mean is drawn, together with a shaded band between
    the 25th and 75th percentile of the group and hollow markers at fixed hours.

    Parameters
    ----------
    total_df : pandas.DataFrame
        One column per model (column label = model id, whose part before the first
        comma is the integer steady-state id), one row per time point.
    samples_picked : pandas.DataFrame
        Indexed by steady-state id; must provide the columns ``group``,
        ``group_levels`` and ``color``.
    time_indices : array-like
        Time stamp (hours) of every row of ``total_df``, despite the name.
    t_span : float
        Only time points ``<= t_span`` are drawn.
    save_path : str, optional
        If given, the figure is written to this path.

    Raises
    ------
    ValueError
        If ``time_indices`` and ``total_df`` differ in length, or if no time point
        lies within ``t_span``.
    """
    groups_to_plot = ('0_0', '0_2', '2_0', '2_2')

    # Predefine the marker timepoints (hours)
    marker_hours = np.array([0, 1, 2, 4, 8, 12, 20, 29.9], dtype=float)

    # Normalize the data with respect to the first point
    normalized_df = total_df.div(total_df.iloc[0])

    # Positional mask of the visible time window. Working on plain arrays avoids any
    # dependence on how the index of the time vector is labelled.
    times = np.asarray(time_indices, dtype=float)
    if len(times) != len(normalized_df):
        raise ValueError('time_indices must have one entry per row of total_df.')
    visible = times <= t_span
    if not visible.any():
        raise ValueError('No time point lies within t_span.')
    t_vis = times[visible]

    # Initialize a figure
    plt.figure(figsize=(12, 8))
    ax = plt.gca()

    # Plot the trajectories for each cluster
    for group_id in sorted(samples_picked.group.unique()):
        if group_id not in groups_to_plot:
            continue

        group_rows = samples_picked.loc[samples_picked.group == group_id]
        color = group_rows.color.iloc[0]
        group_name = group_rows.group_levels.iloc[0]
        steady_states = group_rows.index

        # Models of this group, taken from the columns of the data that is plotted
        # (not from a notebook-level variable), so every selected model has data.
        models = [col for col in normalized_df.columns
                  if int(str(col).split(',')[0]) in steady_states]

        if len(models) == 0:
            print('No models for group {}'.format(group_id))
            continue

        # Extract the data for the current cluster
        cluster_data = normalized_df.loc[:, models]

        # Calculate aggregates
        average_trajectory = cluster_data.mean(axis=1)
        percentile_25 = cluster_data.quantile(0.25, axis=1)
        percentile_75 = cluster_data.quantile(0.75, axis=1)

        # Find the trajectory closest to the average
        dist_to_mean = cluster_data.apply(lambda col: np.linalg.norm(col - average_trajectory), axis=0)
        closest_trajectory = cluster_data.loc[:, dist_to_mean.idxmin()]
        y_vis = closest_trajectory.to_numpy()[visible]

        # Plot the closest trajectory
        plt.plot(t_vis, y_vis, color=color, label=f'Cluster {group_name}', linewidth=3)

        # Plot the error bounds (25 to 75 percentile)
        plt.fill_between(t_vis,
                         percentile_25.to_numpy()[visible],
                         percentile_75.to_numpy()[visible],
                         color=color, alpha=0.08)

        # Add markers only on the trajectory at specified hours (within range and span)
        mh = marker_hours[(marker_hours >= t_vis.min()) & (marker_hours <= min(t_span, t_vis.max()))]
        if mh.size > 0:
            # find nearest indices in t_vis for each marker hour
            idxs = np.abs(t_vis[:, None] - mh[None, :]).argmin(axis=0)
            plt.scatter(t_vis[idxs],
                        np.asarray(y_vis)[idxs],
                        s=40, marker='o', edgecolors=color, facecolors='white', linewidths=2, zorder=3)

    # Titles, labels
    plt.title('Growth rate', fontsize=30)
    plt.xlabel('Time (hours)', fontsize=25)
    plt.ylabel(r'Normalized growth', fontsize=25)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.xlim(0, t_span+0.1)
    plt.tight_layout()

    # Discrete dashed line at t=8 h (drug effect finishes)
    ax.axvline(x=8, color='gray', linestyle='--', linewidth=1, alpha=0.8, zorder=0)

    # Remove top and right spines
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Keep text editable in vector output
    plt.rcParams['pdf.fonttype'] = 42  # For PDF
    plt.rcParams['svg.fonttype'] = 'none'  # For SVG

    # Save the plot if save_path is provided
    if save_path:
        plt.savefig(save_path, transparent=True, bbox_inches='tight', pad_inches=0.1)
    plt.show()

# Upper limit (hours) of the time window shown in the trajectory plot
t_span = 30.1

# Biomass flux trajectory of every model, keyed by model id; models whose flux table has
# no 'biomass' column are skipped.
values = {}
for flux_sol in flux_solutions:
    try:
        biomass_trajectory = flux_sol.fluxes['biomass']
    except KeyError:
        continue
    values[flux_sol.model_ix] = biomass_trajectory
total_df = pd.DataFrame(values)


reaction = 'Growth rate'

# Draw the trajectories for the chosen window and store the figure as PDF
biomass_plot_path = f'../../results/drug_target_simulation/{PHYSIOLOGY}_stratified/{TARGET_NAME}/biomass_growth_rate_trajectories.pdf'
plot_trajectories(total_df, samples_picked, flux_solutions[0].time, t_span, save_path=biomass_plot_path)


# Show the final growth rate distribution (3G)

In [None]:
# Final-to-initial biomass flux ratio of every model whose ratio stays below 1.001
# (the list holds the ratios, not the model ids)
stunned_models = []
for flux in flux_solutions:
    growth_ratio = flux.fluxes.biomass.iloc[-1]/flux.fluxes.biomass[0]
    if growth_ratio < 1.001:
        stunned_models.append(growth_ratio)

plt.figure(figsize=(4, 3))
# histplot draws on the current axes and hands them back
ax = sns.histplot(stunned_models, color='grey', kde=True)

# The first line on the axes is the KDE curve; draw it in black
kde_line = ax.lines[0]
kde_line.set_color('black')

ax.set_xlabel('Final normalized growth rate', fontsize=18)
ax.set_ylabel('Number of models', fontsize=18)
plt.yticks(fontsize=15)
plt.xticks(fontsize=15)

# Have x axis ticks only at 0.0 0.5 1.0
plt.xticks([0.0, 0.5, 1.0], fontsize=15)

final_growth_pdf = f'../../results/drug_target_simulation/{PHYSIOLOGY}_stratified/{TARGET_NAME}/final_growth_rate_distribution.pdf'
plt.savefig(final_growth_pdf, transparent=True, bbox_inches='tight', pad_inches=0.1)

plt.show()

# Show the percentage change in concentration of the biomass precursors (3H)

In [None]:
# Find biomass building blocks (reactants of the biomass reaction) that are also in the
# kinetic model, using the sanitized ids that name the concentration columns
bbb_ids = [
    met_id
    for met_id in (sanitize_cobra_vars(met.id) for met in tmodel.reactions.biomass.reactants)
    if met_id in solutions[0].concentrations.columns
]

# Percentage change (last vs. first time point) of every building block in every solution.
# The columns are computed as float arrays and the frame is assembled once, instead of
# writing single cells into an object-dtype frame.
bbb_change_columns = {}
for sol in solutions:
    bbb_conc = sol.concentrations[bbb_ids]
    initial_conc = bbb_conc.iloc[0].to_numpy(dtype=float)
    final_conc = bbb_conc.iloc[-1].to_numpy(dtype=float)
    bbb_change_columns[str(sol.model_ix)] = (final_conc - initial_conc) / initial_conc * 100

bbb_changes = pd.DataFrame(bbb_change_columns, index=bbb_ids, dtype=float)

# Remove outliers
# NOTE(review): n_jobs is hard-coded although NCPU is read from config.ini — confirm
# whether the configured value should be used here.
bbb_changes = remove_outliers_parallel(bbb_changes, n_jobs=100)

In [None]:
# Rank the building blocks by mean % change (most negative first) and keep the top few
top_numbers = 5
top_bbbs = bbb_changes.mean(axis=1).sort_values(ascending=True)

top_ids = top_bbbs.index[:top_numbers]
top_means = top_bbbs.values[:top_numbers]

# Display names: metabolite id without its last two characters
top_names = [tmodel.metabolites.get_by_id(met_id).id[:-2] for met_id in top_ids]

plt.figure(figsize=(4, 3))
sns.barplot(x=top_names, y=top_means, color='grey')

# Error bars span the 25th to 75th percentile across models, expressed as distances
# from the plotted mean
top_metabolites_data = bbb_changes.loc[top_ids]
percentile_25 = top_metabolites_data.quantile(0.25, axis=1)
percentile_75 = top_metabolites_data.quantile(0.75, axis=1)

lower_error = top_means - percentile_25.values
upper_error = percentile_75.values - top_means

plt.errorbar(x=range(len(top_names)), y=top_means,
             yerr=[lower_error, upper_error],
             fmt='o', color='black', capsize=5)


plt.xlabel('Biomass precursor', fontsize=18)
plt.ylabel('% Change', fontsize=18)
plt.xticks(rotation=45, fontsize=15)
plt.yticks(fontsize=15)

deregulated_bbbs_pdf = f'../../results/drug_target_simulation/{PHYSIOLOGY}_stratified/{TARGET_NAME}/deregulated_BBBs.pdf'
plt.savefig(deregulated_bbbs_pdf, transparent=True, bbox_inches='tight', pad_inches=0.1)
plt.show()

# Make volcano plots for metabolites

In [None]:
# Overrides the path read from config.ini. Folder and file prefix are derived from
# PHYSIOLOGY (instead of a hard-coded 'MUT') to match the other result paths.
path_to_distances_metabolites = f'../../results/drug_target_simulation/{PHYSIOLOGY}_stratified/{TARGET_NAME}/{PHYSIOLOGY}_distances.csv'

In [None]:
# Distance of each metabolite from the targeted enzyme, indexed by metabolite id
distances = pd.read_csv(path_to_distances_metabolites.format(TARGET_NAME), index_col=0)

# Metabolites that take part in a targeted reaction get distance 0
for met_id in distances.index:
    met = tmodel.metabolites.get_by_id(met_id)
    if any(met in tmodel.reactions.get_by_id(enz_name).metabolites for enz_name in TARGETS):
        distances.loc[met_id] = 0

# Everything further away than 4 is collapsed into a single class encoded as 5
distances[distances > 4] = 5

# Use the same (sanitized) ids as the kinetic model
distances.index = [sanitize_cobra_vars(i) for i in distances.index]

# Keep only the metabolites that are in the kmodel
in_kmodel = [met_name in kmodel.reactants for met_name in distances.index]
distances = distances.loc[in_kmodel]

# Follow the order of kmodel.reactants; reactants without a distance become NaN rows
# and are removed again
distances = distances.reindex(kmodel.reactants).dropna()

In [None]:
# Calculate the log2-fold changes (last vs. first time point) for each metabolite in each
# solution; models listed in malignant_models are skipped.
sample_fold_changes = {}
for sol in solutions:
    if sol.model_ix in malignant_models:
        continue
    initial_conc = sol.concentrations.iloc[0,:]
    # Floor the final concentrations at 1e-15. clip() returns a new Series, so the
    # stored solution is left untouched.
    final_conc = sol.concentrations.iloc[-1,:].clip(lower=1e-15)
    sample_fold_changes[sol.model_ix] = np.log2(final_conc/initial_conc)

# Build the frame in one go (one column per model), aligned on the metabolite names
log2_fold_changes = pd.DataFrame(sample_fold_changes, index=solutions[0].concentrations.columns)


# Consider removing extreme outliers
# log2_fold_changes = remove_outliers_parallel(log2_fold_changes, multiplier=2.0, n_jobs=100)

# Drop the enzyme rows (names starting with 'E_')
enzyme_names = [i for i in log2_fold_changes.index if i.startswith('E_')]
log2_fold_changes = log2_fold_changes.drop(index=enzyme_names)

# Snapshot of the per-model columns only: all statistics below are computed from it, so
# the summary columns added afterwards can never leak into them (previously 'std' also
# counted the freshly added 'mean' column as a sample).
per_model_values = log2_fold_changes.copy()

# Calculate the mean and standard deviation of the log2-fold changes
log2_fold_changes['mean'] = per_model_values.mean(axis=1)
log2_fold_changes['std'] = per_model_values.std(axis=1)

# Calculate the p value of the log2-fold changes
# We want to reject the hypothesis that the mean log2-fold change is 0 (fold change is 1)
log2_fold_changes['p_value'] = ttest_1samp(per_model_values, 0, axis=1, nan_policy='omit')[1]

# Calculate the q value of the log2-fold changes
# We need to correct for multiple testing
# NOTE(review): NaN p-values are passed on unchanged here — confirm how multipletests
# treats them.
log2_fold_changes['q_value'] = multipletests(log2_fold_changes['p_value'], method='fdr_bh')[1]

# Attach the distances and order the rows by decreasing distance
log2_fold_changes['distance'] = distances
log2_fold_changes = log2_fold_changes.sort_values(by='distance', ascending=False)

In [None]:
# Discrete colour scale with one colour per distance class (0, 1, 2, 3, 4, >4)
from matplotlib.colors import ListedColormap
colors = ["#991f17", "#b04238", "#c86558", "#df8879", "#a4a2a8", '#b3bfd1']

cmap = ListedColormap(colors)


# Find the smallest positive q_value
min_positive_q_value = log2_fold_changes.loc[log2_fold_changes['q_value'] > 0, 'q_value'].min()

# Replace zero or negative q_values with the smallest positive q_value (so that
# -log10 stays finite); NaN q-values are left as they are.
qzero_ids = np.where(log2_fold_changes['q_value'] <= 0)[0]
print('These concentrations had a q value of 0: {}'.format(log2_fold_changes.index[qzero_ids]))
log2_fold_changes['q_value'] = log2_fold_changes['q_value'].mask(
    log2_fold_changes['q_value'] <= 0, min_positive_q_value)

# Upper end of the y-axis decorations (threshold lines, left spine, tick filter)
y_top = -np.log10(min_positive_q_value) + 5

# Make a volcano plot
fig, ax = plt.subplots(figsize=(12, 9))

# Scatter plot, coloured by the 'distance' column that already sits next to the plotted
# values. Looking each metabolite up in `distances` instead fails with a KeyError for
# metabolites that have no distance entry.
sc = ax.scatter(log2_fold_changes['mean'], -np.log10(log2_fold_changes['q_value']),
                c=log2_fold_changes['distance'], cmap=cmap, edgecolor='k', alpha=0.8, s=80)

# Add horizontal line for the q-value threshold
q_value_threshold = 0.01
ax.axhline(y=-np.log10(q_value_threshold), color='grey', linestyle='--', linewidth=1, alpha=0.4)

# Set a threshold for the log2-fold change
fold_change_threshold = 1
ax.vlines(x=fold_change_threshold, ymin=0, ymax=y_top, color='grey', linestyle='--', linewidth=1, alpha=0.8)
ax.vlines(x=-fold_change_threshold, ymin=0, ymax=y_top, color='grey', linestyle='--', linewidth=1, alpha=0.8)

# Rows that pass both thresholds
significant = log2_fold_changes[(log2_fold_changes['q_value'] < q_value_threshold) &
                                (abs(log2_fold_changes['mean']) > fold_change_threshold)]

# for i, row in significant.iterrows():
#     ax.text(row['mean']-0.1, -np.log10(row['q_value']), i, fontsize=8, ha='right')

# Color bar with one tick per distance class
cbar = plt.colorbar(sc, ax=ax, ticks=np.arange(6))
cbar.set_ticklabels(['0', '1', '2', '3', '4', '>4'])
cbar.set_label('Distance from enzymatic target', fontsize=16)
cbar.ax.tick_params(labelsize=14)


# Labels and title
ax.set_xlabel('mean concentration log2-fold change', fontsize=16)
ax.set_ylabel('-log10(q-value)', fontsize=16, labelpad=-60)
# ax.set_title('Volcano Plot of Metabolite Changes', fontsize=16)

# Let both axes cross at the origin
ax.spines['left'].set_position('zero')
ax.spines['bottom'].set_position('zero')

# Limit the left spine to [0, y_top] and remove 0 from y-axis ticks
ax.spines['left'].set_bounds(0, y_top)
yticks = ax.get_yticks()
ax.set_yticks([tick for tick in yticks if 0 < tick < y_top])

# Remove surrounding box
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)

# Increase the fontsize of the ticks
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)

# --- Robust labels: NW for positive, SE for negative ---

# Axis extents, used to express label offsets as fractions of the visible range
xmin, xmax = ax.get_xlim()
ymin, ymax = ax.get_ylim()
xspan = xmax - xmin
yspan = ymax - ymin

# Top 5 pos / neg
top_pos = log2_fold_changes[log2_fold_changes['mean'] > 0]['mean'].nlargest(5).index
top_neg = log2_fold_changes[log2_fold_changes['mean'] < 0]['mean'].nsmallest(5).index

neglog10q = -np.log10(log2_fold_changes.loc[top_pos.union(top_neg), 'q_value'])

def adjust_vertical_positions(points, direction="up", y_bounds=None):
    """Spread label positions vertically so that neighbours keep a minimum distance.

    Parameters
    ----------
    points : list of dict
        Label proposals; only the ``"ty"`` entry (label y position) is read and
        updated. The list is sorted and modified in place.
    direction : str
        ``"up"`` pushes colliding labels upwards (processing from the lowest label),
        anything else pushes them downwards (processing from the highest label).
    y_bounds : tuple of (float, float), optional
        Lower and upper y limit of the plotting area. Defaults to the notebook-level
        ``ymin`` / ``ymax`` of the current axes.

    Returns
    -------
    list of dict
        The same list object. An empty list is returned unchanged.
    """
    # Nothing to place (e.g. no positive or no negative fold changes)
    if not points:
        return points

    if y_bounds is None:
        lower, upper = ymin, ymax
    else:
        lower, upper = y_bounds
    span = upper - lower

    min_sep = 0.035 * span   # minimum vertical distance between two labels
    margin = 0.01 * span     # distance kept from the axis limits

    if direction == "up":
        points.sort(key=lambda d: d["ty"])
        # Push every label at least min_sep above the one below it
        for i in range(1, len(points)):
            if points[i]["ty"] < points[i-1]["ty"] + min_sep:
                points[i]["ty"] = points[i-1]["ty"] + min_sep
        # If the stack runs over the top, shift labels back down starting from the top
        max_ty = upper - margin
        overflow = points[-1]["ty"] - max_ty
        if overflow > 0:
            for i in reversed(range(len(points))):
                shift = min(overflow, points[i]["ty"] - (lower + margin + i * min_sep))
                points[i]["ty"] -= shift
                overflow -= shift
                if overflow <= 0:
                    break
    else:  # direction == "down"
        points.sort(key=lambda d: d["ty"], reverse=True)
        # Push every label at least min_sep below the one above it
        for i in range(1, len(points)):
            if points[i]["ty"] > points[i-1]["ty"] - min_sep:
                points[i]["ty"] = points[i-1]["ty"] - min_sep
        # If the stack runs below the bottom, shift labels back up starting from the top
        min_ty = lower + margin
        underflow = min_ty - points[-1]["ty"]
        if underflow > 0:
            for i in range(len(points)):
                shift = min(underflow, (upper - margin - (len(points)-1-i)*min_sep) - points[i]["ty"])
                points[i]["ty"] += shift
                underflow -= shift
                if underflow <= 0:
                    break
    return points

# Build label proposals: each one stores the data point (x, y) and the proposed text
# position (tx, ty), offset by a fraction of the axis span.
def _label_proposal(met, x_offset, y_offset):
    """Return the label dict for ``met`` with the text shifted by the given offsets."""
    x = log2_fold_changes.loc[met, 'mean']
    y = neglog10q.loc[met]
    return {"met": met, "x": x, "y": y, "tx": x + x_offset, "ty": y + y_offset}

# Positive changes: text to the upper left; negative changes: text to the lower right
labels_pos = [_label_proposal(met, -0.03 * xspan, 0.02 * yspan) for met in top_pos]
labels_neg = [_label_proposal(met, 0.03 * xspan, -0.02 * yspan) for met in top_neg]

# Resolve overlaps separately
labels_pos = adjust_vertical_positions(labels_pos, direction="up")
labels_neg = adjust_vertical_positions(labels_neg, direction="down")

# Clip within axes (1 % margin on every side)
for d in labels_pos + labels_neg:
    d["tx"] = np.clip(d["tx"], xmin + 0.01 * xspan, xmax - 0.01 * xspan)
    d["ty"] = np.clip(d["ty"], ymin + 0.01 * yspan, ymax - 0.01 * yspan)

# Annotate: NW anchoring for positive, SE anchoring for negative changes
for label_group, h_align, v_align in ((labels_pos, 'right', 'bottom'),
                                      (labels_neg, 'left', 'top')):
    for d in label_group:
        ax.annotate(
            d["met"],
            xy=(d["x"], d["y"]),
            xytext=(d["tx"], d["ty"]),
            ha=h_align, va=v_align,
            fontsize=10,
            bbox=dict(boxstyle='round,pad=0.2', fc='white', ec='none', alpha=0.8),
            arrowprops=dict(arrowstyle='-', lw=0.8, alpha=0.6)
        )
# --- end labels ---



plt.tight_layout()
# Keep text editable in vector output, then save the plot
plt.rcParams['pdf.fonttype'] = 42  # For PDF
plt.rcParams['svg.fonttype'] = 'none'  # For SVG
volcano_metabolites_pdf = f'../../results/drug_target_simulation/{PHYSIOLOGY}_stratified/{TARGET_NAME}/volcano_plot_metabolites.pdf'
plt.savefig(volcano_metabolites_pdf, transparent=True, bbox_inches='tight', pad_inches=0.1)

# plt.savefig(f'../../results/drug_target_simulation/{PHYSIOLOGY}_stratified/{TARGET_NAME}/volcano_plot_metabolites.svg', format='svg', bbox_inches='tight', transparent=True)


# Show plot
plt.show()


# Show the last five columns of the significant metabolites, largest mean change first
significant.sort_values('mean', ascending=False).iloc[:,-5:]

# Make volcano plots for fluxes

In [None]:
# Overrides the path read from config.ini. Folder and file prefix are derived from
# PHYSIOLOGY (instead of a hard-coded 'MUT') to match the other result paths.
path_to_distances_reactions = f'../../results/drug_target_simulation/{PHYSIOLOGY}_stratified/{TARGET_NAME}/{PHYSIOLOGY}_distances_fluxes.csv'

In [None]:
# Distance of each reaction from the targeted enzyme, indexed by reaction id
distances_reactions = pd.read_csv(path_to_distances_reactions.format(TARGET_NAME), index_col=0)

# The targeted reactions themselves are at distance 0
for enz_name in TARGETS:
    distances_reactions.loc[enz_name] = 0

# Cap all distances above 5 at 5
distances_reactions[distances_reactions > 5] = 5

# Keep only the reactions that are in the kmodel
in_kmodel = [rxn_id in kmodel.reactions for rxn_id in distances_reactions.index]
distances_reactions = distances_reactions.loc[in_kmodel]

# Follow the order of kmodel.reactions; reactions without a distance become NaN rows
# and are removed again
distances_reactions = distances_reactions.reindex(kmodel.reactions).dropna()

In [None]:
# Calculate the log10-fold changes (last vs. first time point) of every flux in every
# solution. The last three columns of the flux tables are excluded, and models listed in
# malignant_models are skipped.
flux_floor = 1e-10
flux_columns = flux_solutions[0].fluxes.columns[0:-3]

per_model_fold_changes = {}
for sol in flux_solutions:
    if sol.model_ix in malignant_models:
        continue
    # Work on plain float arrays: purely positional, and the stored solution is not
    # modified (final_flux is an explicit copy).
    initial_flux = sol.fluxes.iloc[0,:-3].to_numpy(dtype=float)
    final_flux = sol.fluxes.iloc[-1,:-3].to_numpy(dtype=float, copy=True)

    # Final fluxes with a magnitude below the floor are set to +-floor (sign kept);
    # exact zeros take the sign of the initial flux instead.
    near_zero = np.abs(final_flux) < flux_floor
    exactly_zero = final_flux == 0
    final_flux[near_zero] = flux_floor*np.sign(final_flux[near_zero])
    final_flux[exactly_zero] = flux_floor*np.sign(initial_flux[exactly_zero])

    per_model_fold_changes[sol.model_ix] = np.log10(np.abs(final_flux/initial_flux))

# Build the frame in one go (one column per model, rows in flux-column order)
log10_flux_fold_changes = pd.DataFrame(per_model_fold_changes, index=flux_columns)

# Consider removing extreme outliers
# log2_flux_fold_changes = remove_outliers_parallel(log2_flux_fold_changes, multiplier=2.0, n_jobs=100)

# Drop the degradation flux
# log10_flux_fold_changes = log10_flux_fold_changes.drop(index='exponentrial_degradation')

# Snapshot of the per-model columns only: all statistics below are computed from it, so
# the summary columns added afterwards can never leak into them (previously 'std' also
# counted the freshly added 'mean' column as a sample).
per_model_values = log10_flux_fold_changes.copy()

# Calculate the mean and standard deviation of the log10-fold changes
log10_flux_fold_changes['mean'] = per_model_values.mean(axis=1)
log10_flux_fold_changes['std'] = per_model_values.std(axis=1)

# Calculate the p value of the log10-fold changes
# We want to reject the hypothesis that the mean log10-fold change is 0 (fold change is 1)
log10_flux_fold_changes['p_value'] = ttest_1samp(per_model_values, 0, axis=1, nan_policy='omit')[1]

# Replace undefined p-values with 1: rows with zero standard deviation and rows whose
# p-value is NaN, so that no NaN is handed to the multiple-testing correction.
std_zero = log10_flux_fold_changes['std'] == 0
undefined_p = std_zero | log10_flux_fold_changes['p_value'].isna()
log10_flux_fold_changes.loc[undefined_p, 'p_value'] = 1


# Calculate the q value of the log10-fold changes
# We need to correct for multiple testing
log10_flux_fold_changes['q_value'] = multipletests(log10_flux_fold_changes['p_value'], method='fdr_bh')[1]

In [None]:
from matplotlib.colors import ListedColormap

# Discrete six-colour palette used to encode the distance from the target
colors = ["#991f17", "#b04238", "#c86558", "#df8879", "#a4a2a8", '#b3bfd1']
cmap = ListedColormap(colors)

# Attach the distance of every reaction and order the rows by descending distance
log10_flux_fold_changes['distance'] = distances_reactions
log10_flux_fold_changes = log10_flux_fold_changes.sort_values(by='distance', ascending=False)

# q-values of zero (or below) cannot be log-transformed:
# replace them with the smallest strictly positive q-value
q_values = log10_flux_fold_changes['q_value']
min_positive_q_value = q_values[q_values > 0].min()
log10_flux_fold_changes['q_value'] = q_values.mask(q_values <= 0, min_positive_q_value)

# Upper end of the y axis / threshold lines
y_axis_top = -np.log10(min_positive_q_value)+5

# Volcano plot: mean log10 fold change vs. -log10(q-value), coloured by distance
fig, ax = plt.subplots(figsize=(12, 9))
sc = ax.scatter(log10_flux_fold_changes['mean'],
                -np.log10(log10_flux_fold_changes['q_value']),
                c=log10_flux_fold_changes['distance'],
                cmap=cmap, edgecolor='k', alpha=0.8, s=80)

# Horizontal dashed line at the q-value threshold
q_value_threshold = 0.01
ax.axhline(y=-np.log10(q_value_threshold), color='grey', linestyle='--', linewidth=1, alpha=0.8)

# Vertical dashed lines at +/- the log10 fold-change threshold, starting at y = 0
fold_change_threshold = 1
for x_threshold in (fold_change_threshold, -fold_change_threshold):
    ax.vlines(x=x_threshold, ymin=0, ymax=y_axis_top,
              color='grey', linestyle='--', linewidth=1, alpha=0.8)

# Reactions passing both the q-value and the fold-change threshold
passes_q_value = log10_flux_fold_changes['q_value'] < q_value_threshold
passes_fold_change = log10_flux_fold_changes['mean'].abs() > fold_change_threshold
significant = log10_flux_fold_changes[passes_q_value & passes_fold_change]
# for i, row in significant.iterrows():
#     ax.text(row['mean'], -np.log10(row['q_value']), i, fontsize=8, ha='right')

# Colour bar; the tick positions are set a second time right below to cover all six labels
cbar = plt.colorbar(sc, ax=ax, ticks=range(5))
cbar.set_ticks(range(6))
cbar.set_label('Distance from enzymatic target', fontsize=16)
cbar.ax.tick_params(labelsize=14)
cbar.set_ticklabels(['0', '1', '2', '3', '4', '>4'])

# Axis labels
ax.set_xlabel('mean flux log10-fold change', fontsize=16)
ax.set_ylabel('-log10(q-value)', fontsize=16, labelpad=-50)
# ax.set_title('Volcano Plot of Flux Changes', fontsize=16)

# Let both axes cross at the origin
for side in ('left', 'bottom'):
    ax.spines[side].set_position('zero')

# Limit the y spine to [0, y_axis_top] and keep only the ticks strictly inside that range
ax.spines['left'].set_bounds(0, y_axis_top)
yticks = ax.get_yticks()
ax.set_yticks([tick for tick in yticks if 0 < tick < y_axis_top])

# Remove the surrounding box
for side in ('right', 'top'):
    ax.spines[side].set_visible(False)

# # --- Labels: top 5 pos NW, top 5 neg SE, with collision avoidance ---

# xmin, xmax = ax.get_xlim()
# ymin, ymax = ax.get_ylim()
# xspan = xmax - xmin
# yspan = ymax - ymin

# # Top 5 positive and negative by mean
# top_pos = log10_flux_fold_changes[log10_flux_fold_changes['mean'] > 0]['mean'].nlargest(3).index
# top_pos = pd.Index(['XYLULte', 'GLCNte'])
# top_neg = log10_flux_fold_changes[log10_flux_fold_changes['mean'] < 0]['mean'].nsmallest(5).index

# # Append custom names if present in the index

# neglog10q = -np.log10(log10_flux_fold_changes.loc[top_pos.union(top_neg), 'q_value'])

# def _adjust_vertical(points, direction="up"):
#     # Enforce minimum vertical spacing to avoid overlaps
#     min_sep = 0.035 * yspan
#     margin = 0.01 * yspan
#     if not points:
#         return points
#     if direction == "up":
#         points.sort(key=lambda d: d["ty"])
#         for i in range(1, len(points)):
#             if points[i]["ty"] < points[i-1]["ty"] + min_sep:
#                 points[i]["ty"] = points[i-1]["ty"] + min_sep
#         max_ty = ymax - margin
#         overflow = points[-1]["ty"] - max_ty
#         if overflow > 0:
#             for i in reversed(range(len(points))):
#                 floor_i = ymin + margin + i * min_sep
#                 shift = min(overflow, points[i]["ty"] - floor_i)
#                 points[i]["ty"] -= max(0, shift)
#                 overflow -= max(0, shift)
#                 if overflow <= 0:
#                     break
#     else:  # direction == "down"
#         points.sort(key=lambda d: d["ty"], reverse=True)
#         for i in range(1, len(points)):
#             if points[i]["ty"] > points[i-1]["ty"] - min_sep:
#                 points[i]["ty"] = points[i-1]["ty"] - min_sep
#         min_ty = ymin + margin
#         underflow = min_ty - points[-1]["ty"]
#         if underflow > 0:
#             for i in range(len(points)):
#                 ceil_i = ymax - margin - (len(points)-1-i) * min_sep
#                 shift = min(underflow, ceil_i - points[i]["ty"])
#                 points[i]["ty"] += max(0, shift)
#                 underflow -= max(0, shift)
#                 if underflow <= 0:
#                     break
#     return points

# # Propose positions
# labels_pos, labels_neg = [], []

# for met in top_pos:
#     x = log10_flux_fold_changes.loc[met, 'mean']
#     y = neglog10q.loc[met]
#     labels_pos.append({
#         "met": met,
#         "x": x, "y": y,
#         "tx": x - 0.03 * xspan,   # NW: left
#         "ty": y + 0.02 * yspan    # NW: above
#     })

# for met in top_neg:
#     x = log10_flux_fold_changes.loc[met, 'mean']
#     y = neglog10q.loc[met]
#     labels_neg.append({
#         "met": met,
#         "x": x, "y": y,
#         "tx": x + 0.03 * xspan,   # SE: right
#         "ty": y - 0.02 * yspan    # SE: below
#     })

# # Resolve overlaps within each group
# labels_pos = _adjust_vertical(labels_pos, direction="up")
# labels_neg = _adjust_vertical(labels_neg, direction="down")

# # Keep labels inside axes
# for group in (labels_pos, labels_neg):
#     for d in group:
#         d["tx"] = np.clip(d["tx"], xmin + 0.01 * xspan, xmax - 0.01 * xspan)
#         d["ty"] = np.clip(d["ty"], ymin + 0.01 * yspan, ymax - 0.01 * yspan)

# # Annotate
# for d in labels_pos:
#     ax.annotate(
#         d["met"], xy=(d["x"], d["y"]), xytext=(d["tx"], d["ty"]),
#         ha='right', va='bottom', fontsize=10,
#         bbox=dict(boxstyle='round,pad=0.2', fc='white', ec='none', alpha=0.8),
#         arrowprops=dict(arrowstyle='-', lw=0.8, alpha=0.6)
#     )

# for d in labels_neg:
#     ax.annotate(
#         d["met"], xy=(d["x"], d["y"]), xytext=(d["tx"], d["ty"]),
#         ha='left', va='top', fontsize=10,
#         bbox=dict(boxstyle='round,pad=0.2', fc='white', ec='none', alpha=0.8),
#         arrowprops=dict(arrowstyle='-', lw=0.8, alpha=0.6)
#     )
# # --- end labels ---



fig.tight_layout()

# Save the plot; the output folder is created first so that savefig
# does not fail when the results directory does not exist yet
volcano_plot_dir = f'../../results/drug_target_simulation/{PHYSIOLOGY}_stratified/{TARGET_NAME}'
os.makedirs(volcano_plot_dir, exist_ok=True)
fig.savefig(f'{volcano_plot_dir}/volcano_plot_fluxes.pdf',
            transparent=True,
            bbox_inches='tight',
            pad_inches=0.1)

# plt.savefig(f'../../results/drug_target_simulation/{PHYSIOLOGY}_stratified/{TARGET_NAME}/volcano_plot_fluxes.svg', format='svg', bbox_inches='tight', transparent=True)

# Show plot
plt.show()

# Display the significant reactions, most strongly decreased first.
# The summary columns are selected by name rather than by position.
significant.sort_values('mean', ascending=True)[['mean', 'std', 'p_value', 'q_value', 'distance']]

# Create "macroscopic" effect figures (3C, D, E, F)

In [None]:
# Flux table of every solution keyed by its model index, with a fresh 0..n-1 row index
fluxes_dict = {
    sol.model_ix: sol.fluxes.reset_index(drop=True)
    for sol in flux_solutions
}

In [None]:
def plot_trajectories(total_df, samples_picked, time_indices, t_span, save_path=None,
                      title=None, groups=('0_0', '0_2', '2_0', '2_2')):
    """Plot one representative normalized trajectory per sample group.

    Every column of ``total_df`` is divided by its first row. For each selected
    group, the trajectory with the smallest Euclidean distance to the group
    average is drawn as a line, the 25th-75th percentile band of the group is
    shaded, and open markers are placed on the line at a fixed set of hours.

    Parameters
    ----------
    total_df : pandas.DataFrame
        One column per simulated model and one row per time point. The part of
        each column label before the first comma is converted to ``int`` and
        matched against the index of ``samples_picked``.
    samples_picked : pandas.DataFrame
        Table with the columns ``group``, ``color`` and ``group_levels``; its
        index is matched against the model labels as described above.
    time_indices : array-like
        Time value of every row of ``total_df``. Must support element-wise
        comparison and boolean-mask indexing (e.g. a numpy array).
    t_span : float
        Only time points ``<= t_span`` are drawn.
    save_path : str or path-like, optional
        If given, the figure is written there; missing parent folders are created.
    title : str, optional
        Figure title. When omitted, the notebook-level variable ``reaction`` is
        used, which is how existing callers provide it.
    groups : collection of str, optional
        Ids of the groups to draw; all other groups are skipped.
    """
    import os
    from matplotlib.ticker import FormatStrFormatter

    if title is None:
        # Backward compatibility: callers used to set the global `reaction` before the call
        title = reaction

    # Normalize the data with respect to the first time point
    normalized_df = total_df.div(total_df.iloc[0])

    fig, ax = plt.subplots(figsize=(16, 9))

    # Restrict the time axis to the visible span
    visible = time_indices <= t_span
    t_vis = time_indices[visible]
    t_vis_np = np.asarray(t_vis, dtype=float)

    # Positions (within the visible time axis) closest to the predefined marker hours.
    # They only depend on the time axis, so they are computed once for all groups.
    marker_hours = np.array([0, 1, 2, 4, 8, 12, 20, 29.9], dtype=float)
    if t_vis_np.size > 0:
        in_range = (marker_hours >= t_vis_np.min()) & (marker_hours <= min(t_span, t_vis_np.max()))
        marker_idx = np.abs(t_vis_np[:, None] - marker_hours[in_range][None, :]).argmin(axis=0)
    else:
        marker_idx = np.array([], dtype=int)

    for group_id in sorted(samples_picked.group.unique()):
        if group_id not in groups:
            continue

        group_rows = samples_picked.loc[samples_picked.group == group_id]
        color = group_rows.color.iloc[0]
        steady_states = group_rows.index

        # Models of this group, taken from the columns of the data itself so that the
        # function does not depend on any notebook-level list of solutions
        models = [col for col in normalized_df.columns
                  if int(str(col).split(',')[0]) in steady_states]

        if len(models) == 0:
            print('No models for group {}'.format(group_id))
            continue

        cluster_data = normalized_df.loc[:, models]

        # Aggregates over the models of the group
        average_trajectory = cluster_data.mean(axis=1)
        percentile_25 = cluster_data.quantile(0.25, axis=1)
        percentile_75 = cluster_data.quantile(0.75, axis=1)

        # Representative trajectory: the one closest (Euclidean norm) to the group average
        distances = cluster_data.apply(lambda col: np.linalg.norm(col - average_trajectory), axis=0)
        closest_trajectory = cluster_data.loc[:, distances.idxmin()]
        y_vis = closest_trajectory[visible]

        group_name = group_rows.group_levels.iloc[0]
        ax.plot(t_vis, y_vis, color=color, label=f'Cluster {group_name}', linewidth=3)

        # Shaded band between the 25th and the 75th percentile
        ax.fill_between(t_vis,
                        percentile_25[visible],
                        percentile_75[visible],
                        color=color, alpha=0.08)

        # Open markers on the representative trajectory at the marker hours
        if marker_idx.size > 0:
            ax.scatter(t_vis_np[marker_idx],
                       np.asarray(y_vis)[marker_idx],
                       s=40, marker='o', edgecolors=color, facecolors='white', linewidths=2, zorder=3)

    # Titles, labels and tick formatting
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    ax.set_title(f'{title}', fontsize=30)
    ax.set_xlabel('Time (hours)', fontsize=25)
    ax.set_ylabel(r'Normalized Flux', fontsize=25)
    ax.tick_params(axis='both', labelsize=20)
    ax.set_xlim(0, t_span+0.1)
    # ax.set_ylim(0.99, 1.01)
    fig.tight_layout()

    # Discrete dashed line at t=8 h (drug effect finishes)
    ax.axvline(x=8, color='gray', linestyle='--', linewidth=1, alpha=0.8, zorder=0)

    # Remove top and right spines
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Font embedding settings for the exported file
    plt.rcParams['pdf.fonttype'] = 42  # For PDF
    plt.rcParams['svg.fonttype'] = 'none'  # For SVG

    if save_path:
        # Create the output folder first so that savefig cannot fail on a missing directory
        output_dir = os.path.dirname(save_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        fig.savefig(save_path, transparent=True, bbox_inches='tight', pad_inches=0.1)
    plt.show()

In [None]:
t_span = 20.1

# Ids of all reactions attached to lac_L_e, without the exchange reaction
lac_rxns = [rxn.id for rxn in tmodel.metabolites.lac_L_e.reactions]
lac_rxns.remove('EX_lac_L_e')

# Per sample: sum of all fluxes connected with lac_L_e.
# Absolute values are used because every reaction secretes lac_L_e.
values = {}
for ix in samples_to_simulate:
    try:
        lactate_fluxes = fluxes_dict[ix][lac_rxns]
    except KeyError:
        # No flux table for this sample (or a reaction column is missing): skip it
        continue
    values[ix] = lactate_fluxes.abs().sum(axis=1)
total_df = pd.DataFrame(values)

# Read by plot_trajectories as the figure title
reaction = 'Lactate Secretion Rate'

plot_trajectories(
    total_df,
    samples_picked,
    flux_solutions[0].time,
    t_span,
    save_path=f'../../results/drug_target_simulation/{PHYSIOLOGY}_stratified/{TARGET_NAME}/lactate_secretion_trajectories.pdf',
)

In [None]:
t_span = 20.1

# Per sample: time course of the 'O2t' flux
values = {}
for ix in samples_to_simulate:
    try:
        oxygen_flux = fluxes_dict[ix]['O2t']
    except KeyError:
        # No flux table for this sample (or the reaction column is missing): skip it
        continue
    values[ix] = oxygen_flux
total_df = pd.DataFrame(values)

# Read by plot_trajectories as the figure title
reaction = 'Oxygen Uptake Rate'

plot_trajectories(
    total_df,
    samples_picked,
    flux_solutions[0].time,
    t_span,
    save_path=f'../../results/drug_target_simulation/{PHYSIOLOGY}_stratified/{TARGET_NAME}/oxygen_uptake_trajectories.pdf',
)

In [None]:
t_span = 20.1

# Per sample: time course of the 'ATPS4mi' flux
values = {}
for ix in samples_to_simulate:
    try:
        atp_synthase_flux = fluxes_dict[ix]['ATPS4mi']
    except KeyError:
        # No flux table for this sample (or the reaction column is missing): skip it
        continue
    values[ix] = atp_synthase_flux
total_df = pd.DataFrame(values)

# Read by plot_trajectories as the figure title
reaction = 'Mitochondrial ATP synthesis'

plot_trajectories(
    total_df,
    samples_picked,
    flux_solutions[0].time,
    t_span,
    save_path=f'../../results/drug_target_simulation/{PHYSIOLOGY}_stratified/{TARGET_NAME}/oxphos_ATP_synthesis.pdf',
)

In [None]:
t_span = 20.1

# Per sample: summed absolute flux through the 'PYK' and 'PGK' reactions
values = {}
for ix in samples_to_simulate:
    try:
        glc_atp_rxns = ['PYK', 'PGK']
        glycolytic_fluxes = fluxes_dict[ix][glc_atp_rxns]
    except KeyError:
        # No flux table for this sample (or a reaction column is missing): skip it
        continue
    values[ix] = glycolytic_fluxes.abs().sum(axis=1)
total_df = pd.DataFrame(values)

# Read by plot_trajectories as the figure title
reaction = 'Glycolytic ATP synthesis'

plot_trajectories(
    total_df,
    samples_picked,
    flux_solutions[0].time,
    t_span,
    save_path=f'../../results/drug_target_simulation/{PHYSIOLOGY}_stratified/{TARGET_NAME}/glycolytic_ATP_synthesis.pdf',
)

# Metabolite pathway enrichment analysis

In [None]:
# Log2 fold change (final vs. initial time point) of every concentration, for each simulated model
conc_floor = 1e-15  # final concentrations below this value are raised to it

conc_fold_change_columns = {}
for sol in solutions:
    if sol.model_ix in malignant_models:
        continue
    initial_conc = sol.concentrations.iloc[0, :]
    # clip() returns a new Series, so the stored solution is never modified
    final_conc = sol.concentrations.iloc[-1, :].clip(lower=conc_floor)
    conc_fold_change_columns[sol.model_ix] = np.log2(final_conc / initial_conc)

# Build the table in one go; each Series is aligned on the variable names of the first solution
log2_fold_changes = pd.DataFrame(conc_fold_change_columns, index=solutions[0].concentrations.columns)

# Drop the rows whose name starts with 'E_' (enzyme variables)
columns_from_E = [i for i in log2_fold_changes.index if i.startswith('E_')]
log2_fold_changes = log2_fold_changes.drop(index=columns_from_E)

# Snapshot of the per-model columns only: every statistic below is computed from it,
# so the 'mean' column is never counted as an additional sample in 'std'
conc_samples = log2_fold_changes.copy()

# Mean and standard deviation of the log2 fold changes across models
log2_fold_changes['mean'] = conc_samples.mean(axis=1)
log2_fold_changes['std'] = conc_samples.std(axis=1)

# One-sample t-test per metabolite: reject the hypothesis that the mean
# log2 fold change is 0 (i.e. that the fold change is 1)
log2_fold_changes['p_value'] = ttest_1samp(conc_samples, 0, axis=1, nan_policy='omit')[1]

# Metabolites with zero variance cannot be tested; treat them as non-significant
# (same rule as in the flux analysis).
# NOTE(review): an undefined (NaN) p-value appears to propagate into the corrected
# values of the FDR step below - confirm against the statsmodels implementation
std_zero = log2_fold_changes['std'] == 0
log2_fold_changes.loc[std_zero, 'p_value'] = 1

# q-values: Benjamini-Hochberg correction for multiple testing
log2_fold_changes['q_value'] = multipletests(log2_fold_changes['p_value'], method='fdr_bh')[1]
log2_fold_changes.to_csv(path_to_conc_fold_changes.format(TARGET_NAME))

In [None]:
from scipy.stats import hypergeom

# Number of reactions per subsystem; only subsystems with 5 or more reactions are analysed
reactions_per_subsystem = {}
for rxn in tmodel.reactions:
    reactions_per_subsystem[rxn.subsystem] = reactions_per_subsystem.get(rxn.subsystem, 0) + 1
subsystems = pd.Series(reactions_per_subsystem, dtype='int')
subsystems = subsystems[subsystems >= 5].index

# For each subsystem: how many simulated metabolites take part in it,
# and how many of those change significantly
subsystem_changes = pd.DataFrame(0, index=subsystems, columns=['significant_changes', 'total_metabolites'])
q_value_threshold = 0.01
fold_change_threshold = 1

for met_id in log2_fold_changes.index: # For simulated metabolites
    mean_val = log2_fold_changes.loc[met_id, 'mean']
    q_val = log2_fold_changes.loc[met_id, 'q_value']
    is_significant = abs(mean_val) > fold_change_threshold and q_val < q_value_threshold

    # Ids carrying a leading underscore are looked up in the model without it
    if met_id.startswith('_'):
        met_id = met_id[1:]
    met = tmodel.metabolites.get_by_id(met_id)

    # A metabolite is counted once per subsystem, however many reactions of that
    # subsystem it takes part in. Counting it once per reaction would let the
    # subsystem size exceed the population size used in the hypergeometric test.
    for subsystem in {rxn.subsystem for rxn in met.reactions}:
        if subsystem in subsystem_changes.index:
            subsystem_changes.loc[subsystem, 'total_metabolites'] += 1
            if is_significant:
                subsystem_changes.loc[subsystem, 'significant_changes'] += 1

# Hypergeometric enrichment test: probability of observing at least `significant_changes`
# significant metabolites in a subsystem of `total_metabolites` metabolites, given
# `tot_significant_changes` significant metabolites among all simulated ones
tot_significant_changes = len(log2_fold_changes[(log2_fold_changes['q_value'] < q_value_threshold) &
                                                (abs(log2_fold_changes['mean']) > fold_change_threshold)])
subsystem_changes['p_value'] = hypergeom.sf(
    subsystem_changes['significant_changes'].to_numpy() - 1,
    len(log2_fold_changes),
    subsystem_changes['total_metabolites'].to_numpy(),
    tot_significant_changes,
)

subsystem_changes.to_csv(path_to_metabolite_enrichment_analysis.format(TARGET_NAME))

# Display the enriched subsystems (kept as the last expression so that the cell shows it)
subsystem_changes[subsystem_changes['p_value'] < 0.01]


# Flux pathway enrichment analysis

In [None]:
from scipy.stats import ttest_1samp

# Log10 fold change (final vs. initial time point) of every flux, for each simulated model.
# The last three columns of `sol.fluxes` are excluded from the analysis.
flux_floor = 1e-10  # final fluxes with a smaller magnitude are clamped to +/- this value

flux_ids = flux_solutions[0].fluxes.columns[0:-3]
flux_fold_change_columns = {}

for sol in flux_solutions:
    if sol.model_ix in malignant_models:
        continue
    initial_flux = sol.fluxes.iloc[0, :-3].to_numpy(dtype=float)
    # np.array() makes an explicit copy, so clamping never writes back into sol.fluxes
    final_flux = np.array(sol.fluxes.iloc[-1, :-3], dtype=float)

    # Both masks are taken before any value is modified
    is_zero = final_flux == 0
    is_tiny = np.abs(final_flux) < flux_floor
    # Tiny fluxes keep their own sign; exact zeros have no sign, so they take
    # the sign of the initial flux instead
    final_flux[is_tiny] = flux_floor * np.sign(final_flux[is_tiny])
    final_flux[is_zero] = flux_floor * np.sign(initial_flux[is_zero])

    # A zero initial flux yields inf/NaN here; silence the numpy warnings and keep the values
    with np.errstate(divide='ignore', invalid='ignore'):
        flux_fold_change_columns[sol.model_ix] = np.log10(np.abs(final_flux / initial_flux))

# Build the table in one go instead of appending one column per model
log10_flux_fold_changes = pd.DataFrame(flux_fold_change_columns, index=flux_ids)

# Snapshot of the per-model columns only: every statistic below is computed from it,
# so the 'mean' column is never counted as an additional sample in 'std'
flux_samples = log10_flux_fold_changes.copy()

# Mean and standard deviation of the log10 fold changes across models
log10_flux_fold_changes['mean'] = flux_samples.mean(axis=1)
log10_flux_fold_changes['std'] = flux_samples.std(axis=1)

# One-sample t-test per reaction: reject the hypothesis that the mean
# log10 fold change is 0 (i.e. that the fold change is 1)
log10_flux_fold_changes['p_value'] = ttest_1samp(flux_samples, 0, axis=1, nan_policy='omit')[1]

# Reactions with zero variance cannot be tested; treat them as non-significant
std_zero = log10_flux_fold_changes['std'] == 0
log10_flux_fold_changes.loc[std_zero, 'p_value'] = 1

# q-values: Benjamini-Hochberg correction for multiple testing
log10_flux_fold_changes['q_value'] = multipletests(log10_flux_fold_changes['p_value'], method='fdr_bh')[1]
log10_flux_fold_changes.to_csv(path_to_flux_fold_changes.format(TARGET_NAME))

In [None]:
from scipy.stats import hypergeom

# Count the reactions of every subsystem and keep the subsystems with 5 or more of them
subsystems = pd.Series(dtype='int')
for rxn in tmodel.reactions:
    subsystems[rxn.subsystem] = subsystems.get(rxn.subsystem, 0) + 1

subsystems = subsystems[subsystems >= 5].index

# For each subsystem: number of simulated reactions and number of significantly changed ones
subsystem_changes_flux = pd.DataFrame(0, index=subsystems, columns=['significant_changes', 'total_reactions'])
q_value_threshold = 0.01
fold_change_threshold = np.log10(2)

for rxn_id in log10_flux_fold_changes.index: # For simulated reactions
    mean_val = log10_flux_fold_changes.loc[rxn_id, 'mean']
    q_val = log10_flux_fold_changes.loc[rxn_id, 'q_value']
    # Ids carrying a leading underscore are looked up in the model without it
    if rxn_id.startswith('_'):
        rxn_id = rxn_id[1:]
    rxn = tmodel.reactions.get_by_id(rxn_id)
    if rxn.subsystem not in subsystem_changes_flux.index:
        continue
    subsystem_changes_flux.loc[rxn.subsystem, 'total_reactions'] += 1
    if abs(mean_val) > fold_change_threshold and q_val < q_value_threshold:
        subsystem_changes_flux.loc[rxn.subsystem, 'significant_changes'] += 1

# Hypergeometric test for the enrichment of significant flux changes in every subsystem
is_significant_flux = ((log10_flux_fold_changes['q_value'] < q_value_threshold) &
                       (abs(log10_flux_fold_changes['mean']) > fold_change_threshold))
tot_significant_changes_fluxes = len(log10_flux_fold_changes[is_significant_flux])

for i, row in subsystem_changes_flux.iterrows():
    enrichment_p_value = hypergeom.sf(row['significant_changes']-1,
                                      len(log10_flux_fold_changes),
                                      row['total_reactions'],
                                      tot_significant_changes_fluxes)
    subsystem_changes_flux.loc[i, 'p_value'] = enrichment_p_value

# Display the enriched subsystems
subsystem_changes_flux[subsystem_changes_flux['p_value'] < 0.01]