In [1]:
# Imports
import time
from math import ceil
from joblib import Parallel, delayed
import multiprocessing
import matplotlib.pyplot as plt
import seaborn as sns 
import pandas as pd
from scipy.integrate import solve_ivp
import numpy as np
from scipy.stats import norm
import warnings
from datetime import datetime

df_Rahil = pd.read_excel("../Data/References/Rahil_2020/Table8.xlsx", usecols=list(range(0, 143)))
df_Rahil = df_Rahil.loc[df_Rahil['DAY'] <= 8]
all_ids = df_Rahil['VOLUNTEER'].unique()

In [15]:
class TimeManager:
    def __init__(self):
        self.start_time = time.time()

    def check_timeout(self, timeout: int) -> bool:
        return (time.time() - self.start_time) > timeout

    def reset_start_time(self):
        self.start_time = time.time()
    
    def get_elapsed_time(self) -> float:
        return time.time() - self.start_time

class Parameter:
    def __init__(self, name, val, l_lim=None, u_lim=None, dist='uniform', mode='fixed', space='log10'):
        self.name = name
        self.val = val
        self.l_lim = l_lim
        self.u_lim = u_lim
        self.dist = dist
        self.mode = mode
        self.space = space

    def __repr__(self):
        return f"Parameter(name={self.name}, val={self.val}, bounds=({self.l_lim}, {self.u_lim}), dist={self.dist}, mode={self.mode}, space={self.space})"

    def sample_value(self):
        if self.dist == 'uniform':
            return np.random.uniform(self.l_lim, self.u_lim)
        elif self.dist == 'loguniform':
            return 10**np.random.uniform(np.log10(self.l_lim), np.log10(self.u_lim))
        else:
            raise ValueError(f"Unknown distribution type: {self.dist}")

class Parameters:
    def __init__(self, **kwargs):
        self._parameters = kwargs

    def __getattr__(self, item):
        if item in self._parameters:
            return self._parameters[item]
        raise AttributeError(f"'Parameters' object has no attribute '{item}'")  # Base case to prevent recursion

    def __setattr__(self, key, value):
        if key == '_parameters':
            super().__setattr__(key, value)
        else:
            self._parameters[key] = value

    def __getstate__(self):
        # Return the state to be pickled
        return self._parameters

    def __setstate__(self, state):
        # Restore state from the unpickled state
        self._parameters = state

    def get_sampled_parameters(self):
        return {name: param for name, param in self._parameters.items() if param.mode == 'sample'}

    def items(self):
        return self._parameters.items()

    def __repr__(self):
        return f"Parameters({', '.join([f'{k}={v}' for k, v in self._parameters.items()])})"

    def load_parameters_from_file(self, path, usecols=None):
        df = pd.read_excel(path, usecols=usecols)
        
        # Filter out rows where 'ID' contains non-numeric values like 'Mean', 'STD', etc.
        df = df[~df['id'].astype(str).str.contains('Mean|STD|Min|Max', regex=True)]
        
        file_parameters = []

        # Loop through each row in the DataFrame
        for index, row in df.iterrows():
            param_dict = {}
            
            # Loop through each parameter in the initial template
            for name, param in self._parameters.items():
                if param.mode == 'file':
                    if name in df.columns and not pd.isna(row[name]):
                        # Use the CSV value if the column exists and the value is not NaN
                        param_dict[name] = Parameter(name=name, val=row[name], l_lim=param.l_lim, u_lim=param.u_lim,
                                                     dist=param.dist, mode='fixed', space=param.space)
                    else:
                        # Use the default value from the initial template
                        print(f'Parameter {name} in file mode resorted to the default fixed value.')
                        param_dict[name] = param
                else:
                    param_dict[name] = param

            # Create a new Parameters object for this row
            file_parameters.append(Parameters(**param_dict))

        return file_parameters

class State:
    def __init__(self, label, initial_value=0.0):
        self.label = label
        self.initial_value = initial_value
        self.time_points = np.array([0.0])
        self.values = np.array([initial_value])

    def update_value(self, t, new_value):
        self.time_points = np.append(self.time_points, t)
        self.values = np.append(self.values, new_value)

    def get_latest_value(self):
        return self.values[-1]

    def get_value_at(self, t_delay):
        return np.interp(t_delay, self.time_points, self.values, left=self.initial_value, right=self.values[-1])

    def reset(self):
        self.time_points = np.array([0.0])
        self.values = np.array([self.initial_value])

class States:
    def __init__(self, states_config):
        self.states = {config['label']: State(**config) for config in states_config}
        self.tau = {}
        self.state_labels = [state.label for state in self.states.values()]

    def get_current_values_as_array(self):
        return np.array([self.states[label].get_latest_value() for label in self.state_labels])

    def get_delayed_state(self, state_label):
        return self.tau.get(state_label, 0)

    def __getattr__(self, name):
        if name in self.states:
            return self.states[name].get_latest_value()
        elif name in self.tau:
            return self.tau[name]
        else:
            raise AttributeError(f"'States' object has no attribute '{name}'")

    def update_states(self, t, new_values):
        # Vectorized state update
        for label, value in zip(self.state_labels, new_values):
            self.states[label].update_value(t, value)

    def calculate_delayed_states(self, t, delays, p):
        # Batch process delayed state calculations
        for delay_info in delays:
            tau, dependent_state, affecting_state = delay_info
            delay_time = getattr(p, tau).val
            delayed_value = self.states[affecting_state].get_value_at(t - delay_time)
            delayed_state_label = f"{affecting_state}_{dependent_state}"
            self.tau[delayed_state_label] = delayed_value

    def reset_delayed_states(self):
        self.tau.clear()

def SmartSolve(task):
    funx, param_set, states_config, t_span = task
    states = States(states_config)
    y_initial = states.get_current_values_as_array()
    solvers_with_timeouts = [('RK45', 1.5), ('BDF', 5.0)]

    time_manager = TimeManager() # Process hang monitoring init
    warnings.filterwarnings("ignore", message="The following arguments have no effect for a chosen solver:*") # Suppress warnings
    
    # Custom event for overflow detection
    def overflow_event(t, y):
        return 1E12 - max(abs(yi) for yi in y)  
    overflow_event.terminal = True
    overflow_event.direction = -1

    if np.shape(t_span)[0] > 2:
        t_eval = t_span[1:-1]  # Assuming this should include all but the first and last for evaluation
        t_span = [t_span[0], t_span[-1]]
    else:
        t_eval = None
    
    for method, timeout in solvers_with_timeouts:
        try:
            sol = solve_ivp(
                fun=lambda t, y: funx(t, y, param_set, states, time_manager, timeout=timeout),
                t_span=t_span, y0=y_initial, method=method, t_eval=t_eval, dense_output=False,
                events=overflow_event, vectorized=True, rtol=1e-5, atol=1e-6, max_step=0.1, min_step=1e-5
            )
            print('test')
            if sol.success:
                elapsed_time = time_manager.get_elapsed_time()
                return sol, elapsed_time, param_set
            else:
                continue
        except TimeoutError:
            continue
        except Exception:
            continue

    return None, time_manager.get_elapsed_time(), param_set

def batched_parallel_execution(funx, parameter_sets, states_config, t_span, min_tasks_per_core=1):
    total_cores = max(1, multiprocessing.cpu_count() - 2)
    total_tasks = len(parameter_sets)
    optimal_batches = max(1, min(total_cores, ceil(total_tasks / min_tasks_per_core)))
    batch_size = ceil(total_tasks / optimal_batches)

    batches = [parameter_sets[i:i + batch_size] for i in range(0, total_tasks, batch_size)]

    def process_batch(batch):
        results = []
        for param_set in batch:
            task = (funx, param_set, states_config, t_span)
            result = SmartSolve(task)
            results.append(result)
        return results

    cpu_count = optimal_batches
    aggregated_results = Parallel(n_jobs=cpu_count)(
        delayed(process_batch)(batch) for batch in batches
    )

    return [item for sublist in aggregated_results for item in sublist]

def execute(funx, parameter_sets, states_config, t_span, parallel=True):
    sol_list = []
    integration_times = []
    used_param_sets = []

    def handle_solution(sol, integration_time, param_set, sol_list, integration_times, used_param_sets):
            sol_list.append(sol)
            integration_times.append(integration_time)
            used_param_sets.append(param_set)

    if parallel:
        results = batched_parallel_execution(funx, parameter_sets, states_config, t_span)

        for result in results:
            sol, integration_time, param_set = result
            handle_solution(sol, integration_time, param_set, sol_list, integration_times, used_param_sets)
    else:
        for param_set in parameter_sets:
            result = SmartSolve((funx, param_set, states_config, t_span))
            sol, integration_time, param_set = result
            handle_solution(sol, integration_time, param_set, sol_list, integration_times, used_param_sets)

    results_dict = {
        "run_id": datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
        "parameter_sets": used_param_sets,
        "sol_list": sol_list,
        "integration_times": integration_times
    }

    return results_dict

def load_parameters_from_file(self, path, usecols=None):
        df = pd.read_excel(path, usecols=usecols)
        
        # Filter out rows where 'ID' contains non-numeric values like 'Mean', 'STD', etc.
        df = df[~df['id'].astype(str).str.contains('Mean|STD|Min|Max', regex=True)]
        
        file_parameters = []

        # Loop through each row in the DataFrame
        for index, row in df.iterrows():
            param_dict = {}
            
            # Loop through each parameter in the initial template
            for name, param in self._parameters.items():
                if param.mode == 'file':
                    if name in df.columns and not pd.isna(row[name]):
                        # Use the CSV value if the column exists and the value is not NaN
                        param_dict[name] = Parameter(name=name, val=row[name], l_lim=param.l_lim, u_lim=param.u_lim,
                                                     dist=param.dist, mode='fixed', space=param.space)
                    else:
                        # Use the default value from the initial template
                        print(f'Parameter {name} in file mode resorted to the default fixed value.')
                        param_dict[name] = param
                else:
                    param_dict[name] = param

            # Create a new Parameters object for this row
            file_parameters.append(Parameters(**param_dict))

        return file_parameters

def generate_parameter_sets(n_sets, parameters, include_initial=False):
    sampled_params = parameters.get_sampled_parameters()
    
    num_vars = len(sampled_params)
    param_names = list(sampled_params.keys())
    samples = np.zeros((n_sets, num_vars))
    
    # Generate samples for each parameter
    for i, param_name in enumerate(param_names):
        param = sampled_params[param_name]
        if param.space == 'log10':
            low, high = np.log10(param.l_lim), np.log10(param.u_lim)
        else:
            low, high = param.l_lim, param.u_lim

        if param.dist == 'uniform':
            samples[:, i] = np.random.uniform(low, high, n_sets)
        elif param.dist == 'normal':
            mean, stddev = (low + high) / 2, (high - low) / 6
            samples[:, i] = norm.rvs(loc=mean, scale=stddev, size=n_sets)

        if param.space == 'log10':
            samples[:, i] = 10 ** samples[:, i]
            
    parameter_sets = []
    
    for set_row in samples:
        # Create a dictionary for the new parameter values
        new_param_dict = parameters._parameters.copy()  # Assuming _parameters is accessible; adjust as needed

        # Update the dictionary with the new sampled values
        for i, param_name in enumerate(param_names):
            sampled_value = set_row[i]
            new_param_dict[param_name] = Parameter(name=param_name,
                                                  val=sampled_value,
                                                  l_lim=sampled_params[param_name].l_lim,
                                                  u_lim=sampled_params[param_name].u_lim,
                                                  dist=sampled_params[param_name].dist,
                                                  sample=sampled_params[param_name].sample,
                                                  space=sampled_params[param_name].space)

        # Create a new Parameters instance with the updated dictionary
        new_parameters = Parameters(**new_param_dict)
        parameter_sets.append(new_parameters)
        
    if include_initial:
        parameter_sets.append(parameters)
    return parameter_sets

def display_legend(color_dict):
    # Creating a figure just for the legend
    fig, ax = plt.subplots(figsize=(5, len(color_dict) * 0.3))  # Adjust size as needed
    ax.axis('off')  # Turn off axis
    patches = [plt.Line2D([0], [0], color=color, marker='o', linestyle='', markersize=10, label=f'{volunteer}') for volunteer, color in color_dict.items()]
    legend = ax.legend(handles=patches, loc='center', frameon=False)
    plt.show()
    
def unique_colors(ids):
    num_colors = len(ids)
    colors = sns.color_palette("husl", num_colors) 
    return dict(zip(ids, colors))

In [16]:
#best_ids = [103, 107, 108, 111, 112, 207, 301, 302, 307, 308, 311, 312]
best_ids = [111, 112, 207, 302, 308, 312]
best_Rahil = df_Rahil[df_Rahil['VOLUNTEER'].isin(best_ids)]
best_color_dict = unique_colors(best_ids)
t_span = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
t_fill = np.linspace(t_span[0], t_span[-1], 50)
t_span = np.unique(np.concatenate([t_span, t_fill])) 

def MonoModel(t: float, y: np.ndarray, p: Parameters, s: States, time_manager: TimeManager, timeout: int = None) -> np.ndarray:
    if timeout is not None and time_manager.check_timeout(timeout): 
        raise TimeoutError('ODE Solver timeout.')

    y[:] = y.clip(min=0.0)
    for label, value in zip(s.state_labels, y):
        s.states[label].update_value(t, value)

    s.calculate_delayed_states(t, [('tau_E', 'E', 'I2'), 
                                ('tau_EM', 'EM', 'E')
                                ], p)
    I2_tauE = s.get_delayed_state('I2_E')
    E_tauEM = s.get_delayed_state('E_EM') 

    return np.array([
        -p.beta.val * s.T * s.V,  # T
        p.beta.val * s.T * s.V - p.k.val * s.I1,  # I1
        p.k.val * s.I1 - p.delta.val * s.I2, #- p.delta_E.val * s.E * s.I2 / (p.Kd.val + s.I2),  # I2
        p.p.val * s.I2 - p.c.val * s.V,  # V
        p.xi.val * s.I2 / (p.K_Ef.val + s.E) + p.eta.val * s.E * I2_tauE - p.d_E.val * s.E,  # E
        p.zet.val * E_tauEM # EM
    ])
    
states_config = [
    {'label': 'T', 'initial_value': 4E8},
    {'label': 'I1', 'initial_value': 75.0},
    {'label': 'I2', 'initial_value': 0.0},
    {'label': 'V', 'initial_value': 0.0},
    {'label': 'E', 'initial_value': 0.0},
    {'label': 'EM', 'initial_value': 0.0}
]

parameters_config = Parameters(
    beta=Parameter(name='beta', val=8.14E-6, l_lim=3E-6, u_lim=4.8E-4, mode='file'),
    k=Parameter(name='k', val=4.0, l_lim=3.5, u_lim=6.0, mode='fixed'),
    p=Parameter(name='p', val=0.78, l_lim=0.1, u_lim=3, mode='file'),
    c=Parameter(name='c', val=4.5, l_lim=1, u_lim=16, mode='file'),
    delta=Parameter(name='delta', val=11.71, l_lim=2.4, u_lim=22, mode='file'),
    delta_E=Parameter(name='delta_E', val=40, l_lim=0.1, u_lim=3.0, mode='fixed'),
    Kd=Parameter(name='Kd', val=434, l_lim=1E2, u_lim=1E5, mode='fixed'),
    xi=Parameter(name='xi', val=5E4, l_lim=1E2, u_lim=1E5, mode='fit'),
    K_Ef=Parameter(name='K_Ef', val=1E5, l_lim=1E3, u_lim=1E6, mode='fixed'),
    eta=Parameter(name='eta', val=2E-7, l_lim=2E-7, u_lim=3E-7, mode='fit'),
    tau_E=Parameter(name='tau_E', val=2.0, l_lim=3.0, u_lim=4.0, mode='fixed'),
    d_E=Parameter(name='d_E', val=1.0, l_lim=0.5, u_lim=2.0, mode='fixed'),
    zet=Parameter(name='zet', val=0.22, l_lim=0.01, u_lim=0.5, mode='fixed'),
    tau_EM=Parameter(name='tau_EM', val=3.5, l_lim=3.0, u_lim=4.0, mode='fixed')
)

parameters_file_path = "../NLME/Baccam_individual_parameters.xlsx"
parameters_list = parameters_config.load_parameters_from_file(parameters_file_path, usecols="A:F") 
            
if __name__ == '__main__':
   results = execute(MonoModel, parameters_list, states_config, t_span)

In [20]:
def compute_means(data, config, scale_factor=1.0):
    if scale_factor is None:
        scale_factor = 1.0
    data_copy = data.copy()
    data_copy.loc[:, 'scaled_data'] = data_copy[config['data_key']] * scale_factor
    data_copy.loc[:, 'log_data'] = np.log10(data_copy['scaled_data'].replace(0, np.nan))  # Avoid log10(0)
    grouped_volunteer_data = data_copy.groupby('DAY')['log_data'].agg(['mean', 'std']).reset_index()
    grouped_volunteer_data['mean_exp'] = 10 ** grouped_volunteer_data['mean']
    grouped_volunteer_data['std_exp_lower'] = 10 ** (grouped_volunteer_data['mean'] - grouped_volunteer_data['std'])
    grouped_volunteer_data['std_exp_upper'] = 10 ** (grouped_volunteer_data['mean'] + grouped_volunteer_data['std'])
    return grouped_volunteer_data

def compute_ode_means(ode_results, config, states_config, t_span, volunteer_offsets, ids):
    sol_list = ode_results["sol_list"]
    states = States(states_config)
    sol_data = []

    interpolated_times = np.unique(np.concatenate((np.linspace(t_span[0], t_span[-1], num=200), t_span[1:-2])))

    for index, (sol_t, sol_y) in enumerate(sol_list):
        if index<len(ids):
            key_index = states.state_labels.index(config['sol_key'])
            sol_values = sol_y[key_index, :]

            volunteer_id = ids[index]
            if config['sol_key'] != 'V':
                offset = volunteer_offsets[volunteer_id][config['sol_key']]
                sol_values += offset

            interpolated_values = np.interp(interpolated_times, sol_t, sol_values)
            interpolated_log_values = np.log10(np.maximum(interpolated_values, 0.01))
            sol_data.append(interpolated_log_values)  # Store log-transformed values

    sol_data_array = np.stack(sol_data)
    mean_sol_log = np.mean(sol_data_array, axis=0)
    std_sol_log = np.std(sol_data_array, axis=0)

    mean_sol_nat = 10**mean_sol_log
    std_plus_std_nat = 10**(mean_sol_log + std_sol_log)
    std_minus_std_nat = 10**(mean_sol_log - std_sol_log)

    return interpolated_times, mean_sol_nat, std_minus_std_nat, std_plus_std_nat

def extract_offsets(data, subplot_config, ids):
    offsets = {volunteer: {} for volunteer in ids}
    for config in subplot_config:
        if config['sol_key'] != 'V':  # Skip virus
            for volunteer in ids:
                volunteer_data = data[data['VOLUNTEER'] == volunteer]
                scale_factor = config.get('data_scale', 1.0)
                y_limits = config.get('ylims', [0, np.inf])
                scaled_data = scale_factor * volunteer_data[config['data_key']]
                
                valid_values = scaled_data[(scaled_data > y_limits[0]) & (scaled_data < y_limits[1])]
                if not valid_values.empty:
                    first_valid_value = valid_values.iloc[0]
                    offsets[volunteer][config['sol_key']] = first_valid_value
    return offsets

def preprocess_data(data, subplot_config, ids):
    volunteer_offsets = extract_offsets(data, subplot_config, ids)
    return volunteer_offsets

def preprocess_ode_results(ode_results, subplot_config, t_span, volunteer_offsets):
    interpolated_results = {}
    for idx, (volunteer_id, ode_result) in enumerate(zip(best_ids, ode_results['sol_list'])):
        interp_data = {}
        for config in subplot_config:
            state_label = config['sol_key']
            interp_values = np.interp(t_span, ode_result.t, ode_result.y[states_config.index(state_label)])
            interp_data[state_label] = interp_values + volunteer_offsets[volunteer_id].get(state_label, 0)
        interpolated_results[volunteer_id] = interp_data
    return interpolated_results

def compute_sse(volunteer_data, ode_times, ode_values, y_limits):
    sse = 0
    count = 0  # To keep track of valid points for SSE calculation
    
    for time_point, data_point in zip(volunteer_data['DAY'], volunteer_data['scaled_data']):
        if time_point == 0:
            continue  # Skip initial condition points
        if y_limits[0] <= data_point <= y_limits[1]:
            ode_value = np.interp(time_point, ode_times, ode_values)
            
            if data_point <= 0:
                data_point = np.nan
            
            if ode_value <= 0:
                ode_value = np.nan
            
            if not np.isnan(data_point) and not np.isnan(ode_value):
                sse += (np.log10(data_point) - np.log10(ode_value)) ** 2
                count += 1
    
    if count == 0:
        return np.inf, 0  # No valid data/ODE
    else:
        return sse, count

def compute_all_sses(data, subplot_config, interpolated_results):
    ids = data['VOLUNTEER'].unique()
    sse_results = {volunteer: {} for volunteer in ids}
    
    for volunteer in ids:
        for config in subplot_config:
            scale_factor = config.get('data_scale', 1.0) if config['data_key'] else 1.0
            if scale_factor is None:
                scale_factor = 1.0
            if config['data_key']:
                volunteer_data = data[data['VOLUNTEER'] == volunteer].copy()
                volunteer_data['scaled_data'] = volunteer_data[config['data_key']] * scale_factor  # Apply scaling
                
                if config['sol_key']:
                    interpolated_values_for_sse = interpolated_results[volunteer][config['sol_key']]  # mean_sol_nat
                    y_limits = [1, np.inf] if 'ylims' not in config else config['ylims']
                    sse, count = compute_sse(volunteer_data, t_span, interpolated_values_for_sse, y_limits)
                    sse_results[volunteer][config['title']] = {'sse': round(sse, 2), 'count': count}
    
    return sse_results

def format_sse_results(sse_results):
    total_sse_by_id = {}
    total_sse_by_state = {}
    grand_total_sse = 0
    total_data_points = 0
    
    for volunteer, states in sse_results.items():
        total_sse_by_id[volunteer] = round(sum(state['sse'] for state in states.values()), 2)
        for state, result in states.items():
            if state not in total_sse_by_state:
                total_sse_by_state[state] = {'sse': 0, 'count': 0}
            total_sse_by_state[state]['sse'] += result['sse']
            total_sse_by_state[state]['count'] += result['count']
            grand_total_sse += result['sse']
            total_data_points += result['count']
    
    total_sse_by_state = {k: {'sse': round(v['sse'], 2), 'count': v['count']} for k, v in total_sse_by_state.items()}
    grand_total_sse = round(grand_total_sse, 2)
    
    return total_sse_by_id, total_sse_by_state, grand_total_sse, total_data_points, sse_results

subplot_configuration = [
    {'title': 'Virus', 'data_key': 'V', 'sol_key': 'V', 'data_scale': None, 'ylims': [1,1E9]},
    {'title': 'CD8 T Effectors', 'data_key': 'E', 'sol_key': 'E', 'data_scale': 5.6E7, 'ylims': [1E5, 1E7]}
]

volunteer_offsets = preprocess_data(best_Rahil, subplot_configuration, best_ids)
interpolated_results = preprocess_ode_results(results, subplot_configuration, t_span, volunteer_offsets)
sse_results = compute_all_sses(best_Rahil, subplot_configuration, interpolated_results)
total_sse_by_id, total_sse_by_state, grand_total_sse, total_data_points, formatted_sse_results = format_sse_results(sse_results)

# Display Results
sse_by_id_state_df = pd.DataFrame.from_dict({(i,j): formatted_sse_results[i][j] 
                            for i in formatted_sse_results.keys() 
                            for j in formatted_sse_results[i].keys()},
                        orient='index')

sse_by_id_df = pd.DataFrame.from_dict(total_sse_by_id, orient='index', columns=['Total SSE'])
sse_by_state_df = pd.DataFrame.from_dict(total_sse_by_state, orient='index').rename(columns={'sse': 'Total SSE', 'count': 'Data Points'})

print("\nSSE Results by State:")
print(sse_by_state_df.to_string())


ValueError: 'V' is not in list

In [None]:
def plot(data, ode_results, subplot_config, color_dict, data_means, ode_means, interpolated_times, means=True, individual_plots=False):   
    ids = data['VOLUNTEER'].unique()
    t_span = [0, 8]
    x_ticks = np.arange(t_span[0], t_span[1], 1)
    n_subplots = len(subplot_config)
    n_rows, n_cols = int(np.ceil(np.sqrt(n_subplots))), int(np.ceil(n_subplots / np.sqrt(n_subplots)))
    base_font_size = 12

    if individual_plots:
        for volunteer in ids:
            fig, axs = plt.subplots(n_rows, n_cols, figsize=(9, 9), squeeze=False)
            axs = axs.flatten()
            
            for i, config in enumerate(subplot_config):
                ax = axs[i]
                ax.set_title(config['title'], fontsize=base_font_size + 2)
                scale_factor = config.get('data_scale', 1.0) if config['data_key'] else 1.0
                if scale_factor is None: scale_factor = 1.0

                if means and config['data_key']:
                    mean_data = data_means[config['title']]
                    ax.errorbar(mean_data['DAY'], mean_data['mean_exp'], 
                                yerr=[mean_data['mean_exp'] - mean_data['std_exp_lower'],
                                      mean_data['std_exp_upper'] - mean_data['mean_exp']], 
                                fmt='o', color='black', capsize=3, label='Mean ± STD', alpha=0.4)

                if config['data_key']:
                    volunteer_data = data[data['VOLUNTEER'] == volunteer].copy()
                    volunteer_data['scaled_data'] = volunteer_data[config['data_key']] * scale_factor  # Apply scaling
                    ax.scatter(volunteer_data['DAY'], volunteer_data['scaled_data'], color=color_dict[volunteer], alpha=1.0, zorder=5)

                if means and config['sol_key']:
                    interpolated_times, mean_sol_nat, std_minus_std_nat, std_plus_std_nat = ode_means[config['title']]
                    ax.plot(interpolated_times, mean_sol_nat, 'k-', lw=1, alpha=0.5, zorder=5)  # Mean solution
                    ax.fill_between(interpolated_times, std_minus_std_nat, std_plus_std_nat, color='gray', alpha=0.2)  # Standard deviation

                if config['sol_key']:
                    sol_list = ode_results["sol_list"]
                    states = States(states_config)
                    
                    for index, (sol_t, sol_y) in enumerate(sol_list):
                        if index<len(ids):
                            if ids[index] == volunteer:
                                key_index = states.state_labels.index(config['sol_key'])
                                sol_values = sol_y[key_index, :]
                                color = color_dict[volunteer]
                                interpolated_values = np.interp(interpolated_times, sol_t, sol_values)
                                
                                ax.plot(interpolated_times, interpolated_values, color=color, alpha=0.5)

                                if 'interpolated_values' not in globals():
                                    globals()['interpolated_values'] = {}
                                if volunteer not in globals()['interpolated_values']:
                                    globals()['interpolated_values'][volunteer] = {}
                                globals()['interpolated_values'][volunteer][config['sol_key']] = interpolated_values

                y_limits = ax.get_ylim()
                if config['data_key']:
                    interpolated_values_for_sse = globals()['interpolated_values'][volunteer].get(config['sol_key'], np.zeros_like(interpolated_times))
                    sse, count = compute_sse(volunteer_data, interpolated_times, interpolated_values_for_sse, y_limits)
                    if count > 0:
                        legend_label = f"{volunteer}: SSE={sse:.2f} (n={count})"
                        ax.scatter(volunteer_data['DAY'], volunteer_data['scaled_data'], color=color_dict[volunteer], label=legend_label, alpha=1.0, zorder=5)

                if config['ylims']:
                    ax.set_ylim(config['ylims'])
                ax.set_yscale('log')
                y_low, y_high = ax.get_ylim()
                if y_low < 1:
                    ax.set_ylim(bottom=1)
                ax.set_xticks(x_ticks)
                ax.set_xlim(t_span)
                
                ax.margins(x=0.05)
                ax.set_xlabel('Days Post Infection')
                ax.set_ylabel('Level')
                ax.legend()

            for j in range(i + 1, len(axs)):
                axs[j].set_visible(False)
            plt.tight_layout()
            plt.show()
    else:
        fig, axs = plt.subplots(n_rows, n_cols, figsize=(9, 9), squeeze=False)
        axs = axs.flatten()

        for i, config in enumerate(subplot_config):
            ax = axs[i]
            ax.set_title(config['title'], fontsize=base_font_size + 2)

            if config['data_key'] and config['data_key'] in data.columns:
                scale_factor = config.get('data_scale', 1.0) if config['data_key'] else 1.0
                if scale_factor is None: scale_factor = 1.0
                for volunteer, color in color_dict.items():
                    volunteer_data = data[data['VOLUNTEER'] == volunteer].copy()
                    volunteer_data['scaled_data'] = volunteer_data[config['data_key']] * scale_factor  # Apply scaling
                    ax.scatter(volunteer_data['DAY'], volunteer_data['scaled_data'], color=color, alpha=1.0, zorder=5)

                if means:
                    mean_data = data_means[config['title']]
                    ax.errorbar(mean_data['DAY'], mean_data['mean_exp'], 
                                yerr=[mean_data['mean_exp'] - mean_data['std_exp_lower'],
                                      mean_data['std_exp_upper'] - mean_data['mean_exp']], 
                                fmt='o', color='black', capsize=5, label='Mean ± STD', alpha=0.4)

            if config['sol_key'] and ode_results:
                sol_list = ode_results["sol_list"]
                states = States(states_config)
                
                for index, (sol_t, sol_y) in enumerate(sol_list):
                    key_index = states.state_labels.index(config['sol_key'])
                    sol_values = sol_y[key_index, :]  
                    ode_id = ids[index]
                    color = color_dict[ode_id]
                    
                    interpolated_values = np.interp(interpolated_times, sol_t, sol_values)
                    
                    ax.plot(interpolated_times, interpolated_values, color=color, alpha=0.5)

                    if 'interpolated_values' not in globals():
                        globals()['interpolated_values'] = {}
                    if ode_id not in globals()['interpolated_values']:
                        globals()['interpolated_values'][ode_id] = {}
                    globals()['interpolated_values'][ode_id][config['sol_key']] = interpolated_values

                if means:
                    interpolated_times, mean_sol_nat, std_minus_std_nat, std_plus_std_nat = ode_means[config['title']]
                    ax.plot(interpolated_times, mean_sol_nat, 'k-', lw=1, alpha=0.5)  # Mean solution
                    ax.fill_between(interpolated_times, std_minus_std_nat, std_plus_std_nat, color='gray', alpha=0.25)  # Standard deviation

            y_limits = ax.get_ylim()
            if config['data_key']:
                for volunteer, color in color_dict.items():
                    volunteer_data = data[data['VOLUNTEER'] == volunteer].copy()
                    volunteer_data['scaled_data'] = volunteer_data[config['data_key']] * scale_factor  # Apply scaling
                    interpolated_values_for_sse = globals()['interpolated_values'][volunteer].get(config['sol_key'], np.zeros_like(interpolated_times))
                    sse, count = compute_sse(volunteer_data, interpolated_times, interpolated_values_for_sse, y_limits)
                    if count > 0:
                        legend_label = f"{volunteer}: SSE={sse:.2f} (n={count})"
                        ax.scatter(volunteer_data['DAY'], volunteer_data['scaled_data'], color=color, label=legend_label, alpha=1.0, zorder=5)

            if config['ylims']:
                ax.set_ylim(config['ylims']) 
            ax.set_yscale('log')
            y_low, y_high = ax.get_ylim()
            if y_low < 1:
                ax.set_ylim(bottom=1)
            ax.set_xticks(x_ticks)
            ax.set_xlim(t_span)
            
            ax.margins(x=0.05)
            ax.set_xlabel('Days Post Infection')
            ax.set_ylabel('Level')

        for j in range(i + 1, len(axs)):
            axs[j].set_visible(False)
        plt.tight_layout()
        plt.show()  
        
plot(best_Rahil, ode_results, subplot_configuration, best_color_dict, data_means, ode_means, interpolated_times, means=True, individual_plots=True)

In [None]:
# Monolith
def MonoModel(t: float, y: np.ndarray, p: Parameters, s: States, time_manager: TimeManager, timeout: int = None) -> np.ndarray:
    if timeout is not None and time_manager.check_timeout(timeout): 
        raise TimeoutError('ODE Solver timeout.')

    y[:] = y.clip(min=0.0)
    for label, value in zip(s.state_labels, y):
        s.states[label].update_value(t, value)

    s.calculate_delayed_states(t, [('tau_E', 'E', 'I2'), 
                                ('tau_EM', 'EM', 'E'),
                                #('tau_A', 'A', 'I2'),
                                #('tau_B', 'B', 'I1'),
                                #('tau_CA', 'CA', 'I2'),
                                #('tau_CB', 'CB', 'I1')
                                ], p)
    I2_tauE = s.get_delayed_state('I2_E')
    #I2_tauA = s.get_delayed_state('I2_A')
    #I1_tauB = s.get_delayed_state('I1_B')
    #I2_tauCA = s.get_delayed_state('I2_CA')
    #I1_tauCB = s.get_delayed_state('I1_CB')
    E_tauEM = s.get_delayed_state('E_EM') 

    return np.array([
        -p.beta.val * s.T * s.V,  # T
        p.beta.val * s.T * s.V - p.k.val * s.I1,  # I1
        p.k.val * s.I1 - p.delta.val * s.I2, #- p.delta_E.val * s.E * s.I2 / (p.Kd.val + s.I2),  # I2
        p.p.val * s.I2 - p.c.val * s.V,  # V
        p.xi.val * s.I2 / (p.K_Ef.val + s.E) + p.eta.val * s.E * I2_tauE - p.d_E.val * s.E,  # E
        p.zet.val * E_tauEM, # EM
        #p.alpha.val * I2_tauA + p.gamma.val * s.CA - p.dA.val * s.A, # A
        #p.iota.val * I1_tauB + p.kappa.val * s.CB - p.dB.val * s.B, # B
        #p.theta.val * I2_tauCA - p.dCA.val * s.CA, # CA
        #p.lamda.val * I1_tauCB - p.dCB.val * s.CB # CB
        
    ])
    
parameters = Parameters(
    beta=Parameter(name='beta', val=8.14E-6, l_lim=3E-6, u_lim=4.8E-4, mode='file'),
    k=Parameter(name='k', val=4.0, l_lim=3.5, u_lim=6.0, mode='file'),
    p=Parameter(name='p', val=0.78, l_lim=0.1, u_lim=3, mode='file'),
    c=Parameter(name='c', val=4.5, l_lim=1, u_lim=16, mode='file'),
    delta=Parameter(name='delta', val=11.71, l_lim=2.4, u_lim=22, mode='file'),
    delta_E=Parameter(name='delta_E', val=40, l_lim=0.1, u_lim=3.0, sample=False),
    Kd=Parameter(name='Kd', val=434, l_lim=1E2, u_lim=1E5, sample=False),
    xi=Parameter(name='xi', val=5E4, l_lim=1E2, u_lim=1E5, sample=False),
    K_Ef=Parameter(name='K_Ef', val=1E5, l_lim=1E3, u_lim=1E6, sample=False),
    eta=Parameter(name='eta', val=2E-7, l_lim=2E-7, u_lim=3E-7, sample=False),
    tau_E=Parameter(name='tau_E', val=2.0, l_lim=3.0, u_lim=4.0, sample=False),
    d_E=Parameter(name='d_E', val=1.0, l_lim=0.5, u_lim=2.0, sample=False),
    zet=Parameter(name='zet', val=0.22, l_lim=0.01, u_lim=0.5, sample=False),
    tau_EM=Parameter(name='tau_EM', val=3.5, l_lim=3.0, u_lim=4.0, sample=False),
    #alpha=Parameter(name='alpha', val=2.75E-3, l_lim=1E-4, u_lim=1E-2, sample=False),
    #gamma=Parameter(name='gamma', val=7.5, l_lim=1E-2, u_lim=1E1, sample=False),
    #tau_A=Parameter(name='tau_A', val=0.75, l_lim=3.0, u_lim=4.0, sample=False),
    #dA=Parameter(name='dA', val=75.0, l_lim=40.0, u_lim=100.0, sample=False),
    #iota=Parameter(name='iota', val=2.5E-5, l_lim=0.01, u_lim=0.5, sample=False),
    #kappa=Parameter(name='kappa', val=0.1, l_lim=0.01, u_lim=0.5, sample=False),
    #tau_B=Parameter(name='tau_B', val=0.85, l_lim=3.0, u_lim=5.0, sample=False),
    #dB=Parameter(name='dB', val=1.0, l_lim=0.5, u_lim=5.0, sample=False),
    #theta=Parameter(name='theta', val=1E-3, l_lim=1E-4, u_lim=1E-2, sample=False),
    #dCA=Parameter(name='dCA', val=10.0, l_lim=0.5, u_lim=2.0, sample=False),
    #lamda=Parameter(name='lamda', val=1E-3, l_lim=1E-4, u_lim=1E-2, sample=False),
    #tau_CA=Parameter(name='tau_CA', val=6.0, l_lim=3.0, u_lim=5.0, sample=False),
    #dCB=Parameter(name='dCB', val=9.4, l_lim=0.5, u_lim=2.0, sample=False),
    #tau_CB=Parameter(name='tau_CB', val=2.5, l_lim=3.0, u_lim=5.0, sample=False)
)