In [None]:
from pytfa.io.json import load_json_model
from skimpy.io.yaml import load_yaml_model
from skimpy.core.solution import ODESolutionPopulation
from skimpy.core.parameters import ParameterValuePopulation, \
    load_parameter_population
from skimpy.utils.namespace import QSSA
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import multiprocessing as mp
from skimpy.core.reactions import Reaction
import numpy as np
import configparser
import os
import sys
sys.path.append('../')
from utils.drug_ode_simulation import run_simulation_ic50, simulate_sample, ODESolution_ic50, produce_biomass_df, CellViabilitySolution
from utils.make_flux_fun_parallel import make_flux_fun_parallel
from utils.enzyme_degradation_class import make_enzymedegradation

In [None]:
TIME = np.linspace(0, 600, 5) # 20-30 times the doubling time of the cell
PHYSIOLOGY = 'WT'
# Reactions whose enzyme is knocked down together, and the label used when
# formatting the per-target solution file names.
TARGETS = ['HEX1', 'r0354', 'r0355']
TARGET_NAME = 'HEX'

# Read the shared configuration. ConfigParser.read() silently ignores files it
# cannot open (it returns the list of files actually read), which would
# otherwise surface below as an opaque KeyError on the first section lookup.
config = configparser.ConfigParser()
config_path = os.path.join('..', 'src', 'config.ini')
if not config.read(config_path):
    raise FileNotFoundError(f'Could not read config file: {os.path.abspath(config_path)}')

# Scaling parameters
CONCENTRATION_SCALING = float(config['scaling']['CONCENTRATION_SCALING'])  # 1 mol to 1 mumol
TIME_SCALING = float(config['scaling']['TIME_SCALING'])  # 1 hour to 1 min
DENSITY = float(config['scaling']['DENSITY'])  # g/L
GDW_GWW_RATIO = float(config['scaling']['GDW_GWW_RATIO'])  # Assumes 75% Water
flux_scaling_factor = 1e-3 * (GDW_GWW_RATIO * DENSITY) * CONCENTRATION_SCALING / TIME_SCALING

# NCPU
ncpu = int(config['global']['ncpu'])

# ODE parameters
time_limit = float(config['drug_metabolism']['time_limit'])
rtol = float(config['drug_metabolism']['rtol'])
atol = float(config['drug_metabolism']['atol'])

# Paths from config.ini, selected by the PHYSIOLOGY suffix
base_dir = config['paths']['base_dir']


def _physiology_path(key_prefix):
    """Return the absolute path stored under ``{key_prefix}_{PHYSIOLOGY}``, resolved against base_dir."""
    return os.path.abspath(os.path.join(base_dir, config['paths'][f'{key_prefix}_{PHYSIOLOGY}']))


path_to_kmodel = _physiology_path('path_to_kmodel')
path_to_tmodel = _physiology_path('path_to_tmodel')
path_to_samples = _physiology_path('path_to_samples')
path_to_params = _physiology_path('path_to_param_output')
path_to_max_eig = _physiology_path('path_to_lambda_values')
path_to_stratified_samples = _physiology_path('path_to_stratified_samples')
path_to_stratified_params = _physiology_path('path_to_stratified_params')
path_to_cell_viability_solutions = _physiology_path('path_to_cell_viability_solutions')

print('Loading kinetic model from:', path_to_kmodel)
kmodel = load_yaml_model(path_to_kmodel)
kmodel.prepare(mca=False)

# NOTE: this import rebinds make_flux_fun_parallel, which the import cell took
# from utils.make_flux_fun_parallel, to the version in precompile_nonlinear_funcs.
from utils.precompile_nonlinear_funcs import precompile_ode, make_flux_fun_parallel

# Precompile the QSSA ODE function for this physiology. The expressions and the
# compiled shared object go to the paths below (presumably a cache; see
# utils.precompile_nonlinear_funcs). Use the configured CPU count instead of a
# hard-coded 110 so the cell does not oversubscribe smaller machines.
precompile_ode(
    kmodel,
    QSSA,
    ncpu=ncpu,
    expressions_file=f'../../data/kin_logs/tmp_kmodel_expressions_{PHYSIOLOGY}.pkl',
    path_to_so_file=f'../../data/kin_logs/tmp_kmodel_ode_function_{PHYSIOLOGY}.so',
)


# Stratified parameter sets: every key of the population index is one model to simulate.
parameter_population = load_parameter_population(path_to_stratified_params)
samples_to_simulate = list(parameter_population._index.keys())

# Normalized enzyme concentrations a (1 = untouched enzyme level; smaller = stronger knock-down)
a_vals = [0.999, 0.99, 0.9, 0.6, 0.5, 0.4, 0.3, 0.2, 0.15, 0.1, 0.05, 0.01, 0.001, 1e-5, 1e-9]
ix_values = list(samples_to_simulate)

# Full (model index, enzyme level) grid in model-major order
args = [(ix, a) for ix in ix_values for a in a_vals]

In [None]:
# Load every pickled simulation result that exists on disk and wrap it as an ODESolution_ic50.
import os
import pickle
from tqdm.auto import tqdm  # same progress-bar flavour as the rest of the notebook

ic50_solutions = []
missing_solution_files = []
for model_ix, a_degradation in tqdm(args):
    filename = path_to_cell_viability_solutions.format(TARGET_NAME, model_ix, a_degradation)
    if not os.path.exists(filename):
        # Not every (model, enzyme level) pair has a result; remember it instead of skipping silently
        missing_solution_files.append(filename)
        continue
    # NOTE(review): pickle.load can execute arbitrary code; only load files produced by this pipeline.
    with open(filename, 'rb') as f:
        res = pickle.load(f)
    # Assumes res[0] holds (time, values) as written by the simulation step -- TODO confirm.
    ic50_solutions.append(ODESolution_ic50(res[0][1], res[0][0], model_ix, a_degradation))

if missing_solution_files:
    print(f'{len(missing_solution_files)} of {len(args)} solution files were not found and are skipped.')


# Pool initializer
def init_pool(kmodel):
    """Install the kinetic model as the module-global ``kmodel_`` in a pool worker.

    Passed as ``initializer`` to ``multiprocessing.Pool`` so each worker runs it
    once; functions in the same module can then read ``kmodel_``.
    """
    global kmodel_
    kmodel_ = kmodel

fluxes_dict_ic50 = {}
results = []

# Compute the biomass/flux table of every loaded solution in parallel.
# NOTE(review): the wrapper is defined in the notebook's __main__ and reads notebook
# globals (TARGETS, parameter_population), so this most likely needs the 'fork'
# start method -- confirm before running on macOS/Windows.
if __name__ == '__main__':
    def biomass_df_wrapper(res):
        """Run produce_biomass_df for one ODESolution_ic50 inside a pool worker."""
        # kmodel_ is the model installed in this worker by init_pool (initargs below)
        return produce_biomass_df(res, kmodel_, TARGETS, parameter_population)

    # ncpu // 5 is 0 when fewer than 5 CPUs are configured, and Pool(processes=0) raises ValueError.
    n_workers = max(1, ncpu // 5)
    with mp.Pool(processes=n_workers, initializer=init_pool, initargs=(kmodel,)) as pool:
        # imap preserves input order, so results line up with ic50_solutions
        results = list(tqdm(pool.imap(biomass_df_wrapper, ic50_solutions), total=len(ic50_solutions)))

# Pair every flux table with the time grid of the ODE solution it was computed from.
# Index the solutions by (model_ix, a_degradation) once instead of rescanning the
# whole list for each result (was O(n^2)); setdefault keeps the first match, as before.
ic50_by_key = {}
for sol in ic50_solutions:
    ic50_by_key.setdefault((sol.model_ix, sol.a_degradation), sol)

flux_ic50_solutions = []
for flux_sol in results:
    # flux_sol is indexed as (model_ix, a_degradation, fluxes) -- layout set by produce_biomass_df
    model_ix, a_degradation, fluxes = flux_sol[0], flux_sol[1], flux_sol[2]
    ic50_sol = ic50_by_key[(model_ix, a_degradation)]  # only needed for its time points
    flux_ic50_solutions.append(CellViabilitySolution(fluxes, ic50_sol.time, model_ix, a_degradation))

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from sklearn.metrics import r2_score
from tqdm.auto import tqdm
import seaborn as sns

def three_param_logistic(c, E_inf, EC50, HS):
    """Three-parameter logistic dose-response curve.

    Parameters
    ----------
    c : array_like or float
        Dose coordinate. In this notebook it is -ln(a), with a the normalized
        enzyme concentration.
    E_inf : float
        Response reached at high dose (c -> +inf for HS > 0).
    EC50 : float
        Dose at which the response is halfway between 1 and E_inf, in the units of c.
    HS : float
        Hill slope; larger values give a steeper transition.

    Returns
    -------
    numpy.ndarray or numpy scalar
        Response values: 1 well below EC50, E_inf well above it.
    """
    # Accept plain lists/tuples too; `list - float` would otherwise raise TypeError.
    c = np.asarray(c)
    return E_inf + (1 - E_inf) / (1 + np.exp(HS * (c - EC50)))

# Parameter box for curve_fit, ordered (E_inf, EC50, HS): all non-negative,
# E_inf (a viability fraction) at most 1 and the Hill slope at most 10.
lower_bounds = (0, 0, 0)
upper_bounds = (1, np.inf, 10)

# Dose axis: -ln(a) for each normalized enzyme level a, in a_vals order
xdata = [-np.log(x) for x in a_vals]

# Fitted parameters of every successfully fitted model (EC50 stored in linear a-space)
E_inf_values = []
EC50_values = []
HS_values = []

for i in ix_values:
    same_model_flux_sols = [element for element in flux_ic50_solutions if element.model_ix == i]
    if not same_model_flux_sols:
        continue

    # Align the viabilities with xdata by enzyme level instead of list position. A model
    # with a missing level would otherwise be fitted against shifted x values (or make
    # curve_fit fail on mismatched lengths), so such models are skipped.
    sol_by_a = {sol.a_degradation: sol for sol in same_model_flux_sols}
    missing_a = [a for a in a_vals if a not in sol_by_a]
    if missing_a:
        print(f'Skipping model {i}: no solution for a = {missing_a}.')
        continue

    cell_viability_values = []
    stop_flag = False
    for a in a_vals:
        sol = sol_by_a[a]
        # Fewer than 10 rows in any trajectory -> skip the whole model
        if len(sol.fluxes) < 10:
            print(f'Skipping model {i} with a = {sol.a_degradation} due to insufficient data points.')
            stop_flag = True
            continue
        # Cell viability: last biomass value relative to the first one
        viability = sol.fluxes['biomass'].iloc[-1] / sol.fluxes['biomass'].iloc[0]
        cell_viability_values.append(viability)
        print(f'ix = {i}, a = {sol.a_degradation}: {viability:.2f}')
    if stop_flag:
        continue

    # Fit the logistic function. RuntimeError: optimizer did not converge;
    # ValueError: non-finite viabilities (e.g. a zero initial biomass).
    try:
        popt, _ = curve_fit(three_param_logistic, xdata, cell_viability_values,
                            bounds=(lower_bounds, upper_bounds))
    except (RuntimeError, ValueError) as e:
        print(f'Failed to fit model {i}: {e}')
        continue

    E_inf_fit, EC50_fit, HS_fit = popt
    EC50_linear = np.exp(-EC50_fit)  # EC50 was fitted in -ln(a) space
    E_inf_values.append(E_inf_fit)
    EC50_values.append(EC50_linear)
    HS_values.append(HS_fit)
    print(f'Model ix = {i}')
    print(f'E_inf = {E_inf_fit:.2f}')
    print(f'EC50 in linear space = {EC50_linear:.2f}')
    print(f'HS = {HS_fit:.2f}')

    # Goodness of fit on the observed points
    r2 = r2_score(cell_viability_values, three_param_logistic(np.asarray(xdata), *popt))
    print(f'R^2 = {r2:.2f}')
    print('---------------------------------')

# Range of the fitted parameters across models (min()/max() need at least one fit)
if E_inf_values:
    print(f'E_inf range: {min(E_inf_values):.2f} - {max(E_inf_values):.2f}')
    print(f'EC50 range in linear space: {min(EC50_values):.2f} - {max(EC50_values):.2f}')
    print(f'HS range: {min(HS_values):.2f} - {max(HS_values):.2f}')
else:
    print('No model could be fitted; parameter ranges are undefined.')


In [None]:
# Remove outliers from each fitted-parameter list independently; the results are
# pandas Series, which the summary and plotting cells rely on (.mean/.std/.median).
from utils.remove_outliers import remove_outliers_row

E_inf_values, EC50_values, HS_values = (
    remove_outliers_row(pd.Series(values))
    for values in (E_inf_values, EC50_values, HS_values)
)

In [None]:
# Summary of the outlier-filtered parameters: range with median first, then mean +- std.
# Each entry is (label, qualifier inserted after the statistic name, values).
fit_param_summary = [
    ('E_inf', '', E_inf_values),
    ('EC50', ' in linear space', EC50_values),
    ('HS', '', HS_values),
]
for label, qualifier, series in fit_param_summary:
    print(f'{label} range{qualifier}: {min(series):.2f} - {max(series):.2f}'
          f' (median: {series.median():.2f})')
print('')
for label, qualifier, series in fit_param_summary:
    print(f'{label} average{qualifier}: {series.mean():.2f} +- {series.std():.2f}')

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from sklearn.metrics import r2_score
from tqdm.auto import tqdm
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.cm as cm

def three_param_logistic(c, E_inf, EC50, HS):
    """Three-parameter logistic dose-response curve.

    Parameters
    ----------
    c : array_like or float
        Dose coordinate. In this notebook it is -ln(a), with a the normalized
        enzyme concentration.
    E_inf : float
        Response reached at high dose (c -> +inf for HS > 0).
    EC50 : float
        Dose at which the response is halfway between 1 and E_inf, in the units of c.
    HS : float
        Hill slope; larger values give a steeper transition.

    Returns
    -------
    numpy.ndarray or numpy scalar
        Response values: 1 well below EC50, E_inf well above it.
    """
    # Accept plain lists/tuples too; `list - float` would otherwise raise TypeError.
    c = np.asarray(c)
    return E_inf + (1 - E_inf) / (1 + np.exp(HS * (c - EC50)))

# curve_fit box for (E_inf, EC50, HS): everything non-negative,
# E_inf capped at 1 (viability fraction) and the Hill slope capped at 10.
lower_bounds, upper_bounds = (0, 0, 0), (1, np.inf, 10)

# Dose axis used for fitting and plotting: -ln(a) for every enzyme level a
xdata = [-np.log(level) for level in a_vals]

# One shared figure; the dose axis (-ln a) is drawn on a log scale.
plt.figure(figsize=(16, 9))
plt.xscale('log')

# Per-model results for every model that passes all checks below
all_fitted_curves = []
all_parameters = []
all_cell_viability_data = []

# Accepted parameter window: range of the outlier-filtered fits (EC50 in linear a-space).
# Loop-invariant, so computed once.
E_inf_lo, E_inf_hi = min(E_inf_values), max(E_inf_values)
EC50_lo, EC50_hi = min(EC50_values), max(EC50_values)
HS_lo, HS_hi = min(HS_values), max(HS_values)

# Common grid on which every fitted curve is evaluated (required for the envelope)
x_smooth = np.linspace(min(xdata), max(xdata), 500)

# First pass: collect all valid model data
for i in ix_values:
    same_model_flux_sols = [element for element in flux_ic50_solutions if element.model_ix == i]

    # Every enzyme level is required so the viabilities line up position-by-position with xdata
    sol_by_a = {sol.a_degradation: sol for sol in same_model_flux_sols}
    if any(a not in sol_by_a for a in a_vals):
        print(f'Skipping model {i}: not every enzyme level has a solution.')
        continue

    cell_viability_values = []
    stop_flag = False
    for a in a_vals:
        sol = sol_by_a[a]
        if len(sol.fluxes) < 10:
            print(f'Skipping model {i} with a = {sol.a_degradation} due to insufficient data points.')
            stop_flag = True
            break
        # Cell viability: last biomass value relative to the first one
        cell_viability_values.append(sol.fluxes['biomass'].iloc[-1] / sol.fluxes['biomass'].iloc[0])
    if stop_flag:
        continue

    # Fit the logistic function; only the documented curve_fit failures are tolerated
    try:
        popt, _ = curve_fit(three_param_logistic, xdata, cell_viability_values,
                            bounds=(lower_bounds, upper_bounds))
    except (RuntimeError, ValueError) as e:
        print(f'Failed to fit model {i}: {e}')
        continue
    E_inf_fit, EC50_fit, HS_fit = popt

    # Keep only models whose parameters fall inside the outlier-filtered window
    if not (E_inf_lo <= E_inf_fit <= E_inf_hi
            and EC50_lo <= np.exp(-EC50_fit) <= EC50_hi
            and HS_lo <= HS_fit <= HS_hi):
        print(f'Skipping model {i} due to out of range parameters.')
        continue

    r2 = r2_score(cell_viability_values, three_param_logistic(np.asarray(xdata), *popt))

    print(f'Model ix = {i}')
    print(f'E_inf = {E_inf_fit:.2f}')
    print(f'EC50 in linear space = {np.exp(-EC50_fit):.2f}')
    print(f'HS = {HS_fit:.2f}')
    print(f'R^2 = {r2:.2f}')
    print('---------------------------------')

    all_fitted_curves.append(three_param_logistic(x_smooth, *popt))
    all_parameters.append([E_inf_fit, EC50_fit, HS_fit])  # EC50 kept in -ln(a) space
    all_cell_viability_data.append(cell_viability_values)

# Stack per-model results: each row of all_fitted_curves is one model's curve on x_smooth,
# each row of all_parameters is (E_inf, EC50 in -ln(a) space, HS).
all_fitted_curves = np.array(all_fitted_curves)
all_parameters = np.array(all_parameters)

if len(all_fitted_curves) == 0:
    # exit() would shut down the Jupyter kernel; abort only this cell instead.
    plt.close()
    raise RuntimeError("No valid curves found!")

# Point-wise spread across models: the shaded band runs from the lowest to the
# highest fitted curve at every x_smooth value.
envelope_min = all_fitted_curves.min(axis=0)
envelope_max = all_fitted_curves.max(axis=0)

# Mean of each fitted parameter across the plotted models. EC50 is still in
# -ln(a) units here, i.e. this is a geometric mean in a-space.
avg_E_inf, avg_EC50, avg_HS = (np.mean(all_parameters[:, col]) for col in range(3))

x_smooth = np.linspace(min(xdata), max(xdata), 500)
avg_curve = three_param_logistic(x_smooth, avg_E_inf, avg_EC50, avg_HS)

# Layers back to front: envelope (z=1), raw viabilities (z=2), mean curve (z=3)
plt.fill_between(x_smooth, envelope_min, envelope_max,
                 color='wheat', alpha=0.3, zorder=1)
for cell_viability_values in all_cell_viability_data:
    plt.plot(xdata, cell_viability_values, 'o',
             color='grey', alpha=0.3, markersize=3, zorder=2)
plt.plot(x_smooth, avg_curve,
         color='saddlebrown', linewidth=3, zorder=3)

# Label a subset of ticks with the original enzyme levels a instead of -ln(a)
ticks_to_show = [0, 1, 2, 6, 12, 14]
tick_positions = [xdata[k] for k in ticks_to_show]
tick_labels = [a_vals[k] for k in ticks_to_show]
plt.xticks(tick_positions, tick_labels)

plt.xlabel('Normalized Enzyme Concentration', fontsize=25)
plt.xticks(fontsize=20)
plt.ylabel('Cell Viability', fontsize=25)
plt.yticks(fontsize=20)

# Statistics box: mean +- std of the outlier-filtered parameters from the fitting
# cells (EC50 in linear a-space), not just of the curves drawn in this figure.
average_E_inf = E_inf_values.mean()
average_EC50 = EC50_values.mean()
average_HS = HS_values.mean()
summary_text = (
    f'$E_{{\\infty}}$: {average_E_inf:.2f} ± {E_inf_values.std():.2f}\n'
    f'$EC_{{50}}$: {average_EC50:.2f} ± {EC50_values.std():.2f}\n'
    f'$HS$: {average_HS:.2f} ± {HS_values.std():.2f}'
)
ax = plt.gca()
plt.text(0.05, 0.05, summary_text,
         transform=ax.transAxes, fontsize=20,
         bbox=dict(facecolor='white', alpha=0.5, edgecolor='black'))

plt.ylim(0, 1.25)
# Hide the top and right frame lines
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)

plt.tight_layout()
plt.show()