In [1]:
import time
import matplotlib.pyplot as plt
import seaborn as sns 
import pandas as pd
from scipy.integrate import solve_ivp
import numpy as np

from scipy.stats import norm
import warnings


# Load input data; the relative paths resolve against the current working directory.
# NOTE(review): neither df_data nor df is referenced in the cells shown here -- confirm they are used later or drop them.
df_data = pd.read_csv('../NLME/Rahil.csv')
# Only the 'Dynamics' sheet of the workbook is read.
df = pd.read_excel('../Data/Murine Data.xlsx', sheet_name='Dynamics')

In [13]:
class TimeManager:
    """Wall-clock stopwatch used to enforce solver timeouts.

    The clock starts when the instance is created and can be restarted at any
    time with ``reset_start_time``.
    """

    def __init__(self):
        self.reset_start_time()

    def get_elapsed_time(self) -> float:
        """Seconds elapsed since the clock was (re)started."""
        return time.time() - self.start_time

    def check_timeout(self, timeout: int) -> bool:
        """Return True once strictly more than ``timeout`` seconds have elapsed."""
        return self.get_elapsed_time() > timeout

    def reset_start_time(self):
        """Restart the clock from the current moment."""
        self.start_time = time.time()

class Parameter:
    """A single named model parameter together with its sampling metadata.

    Attributes:
        name: identifier of the parameter.
        val: current value.
        min_bound, max_bound: sampling bounds (given in linear units).
        distribution: name of the sampling distribution (default 'uniform').
        sample: whether the parameter should be redrawn when sampling.
        space: space in which sampling happens (default 'real').
    """

    def __init__(self, name, val, min_bound, max_bound, distribution='uniform', sample=False, space='real'):
        self.name, self.val = name, val
        self.min_bound, self.max_bound = min_bound, max_bound
        self.distribution, self.sample, self.space = distribution, sample, space

    def __repr__(self):
        return (
            f"Parameter(name={self.name}, val={self.val}, "
            f"bounds=({self.min_bound}, {self.max_bound}), "
            f"distribution={self.distribution}, sample={self.sample}, space={self.space})"
        )

class Parameters:
    """Attribute-style container of named parameters.

    ``Parameters(beta=Parameter(...), k=Parameter(...))`` stores the keyword
    arguments in an internal dict. ``params.beta`` returns the stored object and
    ``params.beta = x`` stores ``x`` under ``'beta'``.
    """

    def __init__(self, **kwargs):
        self._parameters = kwargs

    def __getattr__(self, item):
        # Only invoked when normal attribute lookup fails. The backing dict is
        # read through __dict__ (not self._parameters) so that an instance which
        # has not been initialised yet -- e.g. created via __new__ while being
        # unpickled or copied -- raises AttributeError instead of recursing
        # into __getattr__('_parameters') forever.
        params = self.__dict__.get('_parameters', {})
        if item in params:
            return params[item]
        raise AttributeError(f"'Parameters' object has no attribute '{item}'")

    def __setattr__(self, key, value):
        # Every attribute except the backing dict itself is stored as a parameter.
        if key == '_parameters':
            super().__setattr__(key, value)
        else:
            self._parameters[key] = value

    def __getstate__(self):
        # Pickle/copy state is the name -> parameter mapping.
        return self._parameters

    def __setstate__(self, state):
        # Take a private copy: copy.copy() passes the original's dict here, and
        # storing it directly would make the copy and the original share (and
        # mutate) one mapping.
        self._parameters = dict(state)

    def get_sampled_parameters(self):
        """Return ``{name: param}`` for every parameter whose ``sample`` flag is truthy."""
        return {name: param for name, param in self._parameters.items() if param.sample}

    def items(self):
        """Return a view of ``(name, param)`` pairs."""
        return self._parameters.items()

    def __repr__(self):
        return f"Parameters({', '.join([f'{k}={v}' for k, v in self._parameters.items()])})"
    
class State:
    """Time history of one model state, used to look up delayed values.

    ``time_points`` / ``values`` start as ``[0.0]`` / ``[initial_value]`` and are
    extended by ``update_value``; ``get_value_at`` linearly interpolates them.
    """

    def __init__(self, label, initial_value=0.0):
        self.label = label
        self.initial_value = initial_value
        self.time_points = np.array([0.0])
        self.values = np.array([initial_value])

    def update_value(self, t, new_value):
        """Record ``new_value`` at time ``t`` while keeping ``time_points`` sorted.

        ODE solvers may evaluate the right-hand side at a time earlier than the
        last recorded one (e.g. after rejecting a step). Recorded points at or
        after ``t`` then belong to an abandoned trajectory, so they are dropped
        before appending. This keeps ``time_points`` non-decreasing, which
        ``np.interp`` in ``get_value_at`` requires; a repeated ``t`` replaces the
        earlier sample. Each call copies the arrays (O(n) per update).
        """
        keep = np.searchsorted(self.time_points, t, side='left')
        self.time_points = np.append(self.time_points[:keep], t)
        self.values = np.append(self.values[:keep], new_value)

    def get_latest_value(self):
        """Most recently recorded value."""
        return self.values[-1]

    def get_value_at(self, t_delay):
        """Linearly interpolate the history at ``t_delay`` (scalar or array).

        Times before the first point return ``initial_value``; times after the
        last point return the latest value.
        """
        return np.interp(t_delay, self.time_points, self.values, left=self.initial_value, right=self.values[-1])

    def reset(self):
        """Discard the history, restoring the single point ``(0.0, initial_value)``."""
        self.time_points = np.array([0.0])
        self.values = np.array([self.initial_value])

class States:
    """Collection of ``State`` histories plus the most recent delayed values.

    Attribute access is forwarded. ``states.T`` returns the latest recorded value
    of state ``'T'``, and ``states.I2_E`` returns the delayed value stored under
    ``'I2_E'`` by ``calculate_delayed_states``.
    """

    def __init__(self, states_config):
        # Each config dict is passed to State(**config): 'label' plus optional 'initial_value'.
        self.states = {config['label']: State(**config) for config in states_config}
        # Most recent delayed values, keyed '<affecting_state>_<dependent_state>'.
        self.tau = {}
        # Fixed label order used to map to/from solver state vectors.
        self.state_labels = [state.label for state in self.states.values()]

    def get_delayed_state(self, state_label):
        """Return the stored delayed value for ``state_label`` (0 if not computed yet)."""
        return self.tau.get(state_label, 0)

    def get_current_values_as_array(self):
        """Latest value of every state, in ``state_labels`` order."""
        return np.array([self.states[label].get_latest_value() for label in self.state_labels])

    def __getattr__(self, name):
        # Only invoked when normal lookup fails. The backing dicts are read via
        # __dict__ so a not-yet-initialised instance (e.g. during unpickling or
        # copy.deepcopy, before its state is restored) raises AttributeError
        # instead of recursing into __getattr__('states') forever.
        states = self.__dict__.get('states', {})
        if name in states:
            return states[name].get_latest_value()
        tau = self.__dict__.get('tau', {})
        if name in tau:
            return tau[name]
        raise AttributeError(f"'States' object has no attribute '{name}'")

    def update_states(self, t, new_values):
        """Record ``new_values`` (ordered like ``state_labels``) at time ``t``."""
        for label, value in zip(self.state_labels, new_values):
            self.states[label].update_value(t, value)

    def calculate_delayed_states(self, t, delays, p):
        """Refresh delayed values at time ``t``.

        ``delays`` is an iterable of ``(tau_name, dependent_state, affecting_state)``.
        For each entry, ``affecting_state`` is interpolated at
        ``t - getattr(p, tau_name).val`` and stored under
        ``'<affecting_state>_<dependent_state>'``.
        """
        for tau_name, dependent_state, affecting_state in delays:
            delay_time = getattr(p, tau_name).val
            delayed_value = self.states[affecting_state].get_value_at(t - delay_time)
            self.tau[f"{affecting_state}_{dependent_state}"] = delayed_value

    def reset_delayed_states(self):
        """Forget all stored delayed values."""
        self.tau.clear()

In [15]:
def SmartSolve(task):
    """Integrate ``funx`` with a cascade of solvers, falling back on failure.

    Parameters
    ----------
    task : tuple
        ``(funx, param_set, states_config, t_span)``.
        - ``funx`` is called as ``funx(t, y, param_set, states, time_manager, timeout=...)``.
          It must accept a 1-D ``y``; the solver is not told it is vectorized,
          because the model records every evaluation into ``states``.
        - ``t_span`` with more than two entries is the output grid. Integration
          runs from its first to its last entry, and the solution is reported at
          every entry, including both endpoints.

    Returns
    -------
    (sol, elapsed)
        ``sol`` is the ``solve_ivp`` result of the first solver that reaches the
        end of the interval, or ``None`` if every solver failed. ``elapsed`` is
        the total wall-clock seconds spent across all attempts.

    Raises
    ------
    TimeoutError
        If the last solver in the cascade exceeds its own time budget.
    """
    funx, param_set, states_config, t_span = task
    # (method, timeout in seconds). Each method gets its own budget, measured
    # from the start of that attempt.
    solvers_with_timeouts = [('RK23', 1), ('BDF', 3)]
    last_method = solvers_with_timeouts[-1][0]

    total_timer = TimeManager()   # overall elapsed time that is returned
    time_manager = TimeManager()  # per-attempt clock handed to funx

    def overflow_event(t, y):
        # Crosses zero (downwards) when any |state| exceeds 1e9.
        return 1E9 - np.max(np.abs(y))

    overflow_event.terminal = True
    overflow_event.direction = -1

    if np.shape(t_span)[0] > 2:
        t_eval = np.asarray(t_span, dtype=float)
        t_span = [t_eval[0], t_eval[-1]]
    else:
        t_eval = None

    for method, timeout in solvers_with_timeouts:
        # Fresh histories per attempt, so that a failed solver's trajectory does
        # not leak into the delayed-state lookups of the next one.
        states = States(states_config)
        y_initial = states.get_current_values_as_array()
        time_manager.reset_start_time()

        try:
            # Only options every method accepts are passed. solve_ivp has no
            # `min_step`, and `jac`/`jac_sparsity` are ignored by RK23; passing
            # them only produced "no effect" warnings.
            sol = solve_ivp(
                fun=lambda t, y: funx(t, y, param_set, states, time_manager, timeout=timeout),
                t_span=t_span, y0=y_initial, method=method, t_eval=t_eval,
                events=overflow_event, vectorized=False, rtol=1e-3, atol=1e-4, max_step=0.2,
            )
        except TimeoutError:
            if method == last_method:
                raise TimeoutError("Timeout occurred for all methods.")
            continue
        except Exception as e:
            print(f"Exception occurred in SmartSolve: {e}")
            continue

        # status 0: reached the end of t_span. status 1: a terminal event fired
        # (solve_ivp still reports success=True, but the run is truncated).
        # status -1: integration step failed.
        if sol.status == 0:
            return sol, total_timer.get_elapsed_time()
        if sol.status == 1:
            print(f"Overflow event triggered. Method: {method}")
        elif sol.status == -1:
            print(f"Integration step failed. Method: {method}")
        else:
            print(f"Unknown failure mode. Status: {sol.status}, Method: {method}")
        # Any failure: fall through to the next method.

    return None, total_timer.get_elapsed_time()
                
def generate_parameter_sets(n_sets, parameters, seed=None):
    """Draw ``n_sets`` randomised copies of ``parameters``.

    Every parameter whose ``sample`` flag is set is redrawn independently for
    each set; all other entries are carried over unchanged. Bounds are
    transformed by ``param.space``. For ``'log10'``, sampling happens between
    ``log10(min_bound)`` and ``log10(max_bound)`` and the draw is mapped back
    with ``10 ** x``. Any other space samples the bounds directly.

    Supported distributions:
      - ``'uniform'``: U(low, high).
      - ``'normal'``: N(mean=(low + high) / 2, sd=(high - low) / 6), i.e. the
        bounds sit at +/-3 sd. Draws are not clipped.

    Parameters
    ----------
    n_sets : int
        Number of parameter sets to generate.
    parameters : Parameters
        Template; its sampled entries supply the bounds and metadata.
    seed : int or None
        If given, NumPy's *global* RNG is seeded with it (``np.random.seed``).

    Returns
    -------
    list of Parameters
        ``n_sets`` new containers; sampled entries are fresh ``Parameter``
        objects with the drawn ``val``.

    Raises
    ------
    ValueError
        If a sampled parameter has an unsupported ``distribution``.
    """
    if seed is not None:
        np.random.seed(seed)
    sampled_params = parameters.get_sampled_parameters()
    param_names = list(sampled_params)
    samples = np.zeros((n_sets, len(param_names)))

    # Column i holds the draws for param_names[i].
    for i, param_name in enumerate(param_names):
        param = sampled_params[param_name]
        if param.space == 'log10':
            low, high = np.log10(param.min_bound), np.log10(param.max_bound)
        else:
            low, high = param.min_bound, param.max_bound

        if param.distribution == 'uniform':
            samples[:, i] = np.random.uniform(low, high, n_sets)
        elif param.distribution == 'normal':
            # Draw from the shared (possibly seeded) global stream. Reseeding per
            # parameter, as norm.rvs(random_state=seed) did, gave every normal
            # parameter the same standardized draws, i.e. perfectly correlated columns.
            samples[:, i] = np.random.normal(loc=(low + high) / 2, scale=(high - low) / 6, size=n_sets)
        else:
            raise ValueError(
                f"Unsupported distribution '{param.distribution}' for parameter "
                f"'{param_name}'; expected 'uniform' or 'normal'."
            )

        if param.space == 'log10':
            samples[:, i] = 10 ** samples[:, i]

    parameter_sets = []
    for row in samples:
        # Start from every template entry (sampled or not), then overwrite the sampled ones.
        new_param_dict = dict(parameters.items())
        for param_name, sampled_value in zip(param_names, row):
            template = sampled_params[param_name]
            new_param_dict[param_name] = Parameter(name=param_name,
                                                   val=sampled_value,
                                                   min_bound=template.min_bound,
                                                   max_bound=template.max_bound,
                                                   distribution=template.distribution,
                                                   sample=template.sample,
                                                   space=template.space)
        parameter_sets.append(Parameters(**new_param_dict))

    return parameter_sets

In [25]:
def DDModel(t: float, y: np.ndarray, p: Parameters, s: States, time_manager: TimeManager, timeout: int = None) -> np.ndarray:
    """Right-hand side of the delayed T / I1 / I2 / V / E model.

    Side effects:
      - ``y`` is clipped to non-negative values *in place* (the caller's array is modified).
      - The clipped values are recorded into ``s``'s histories at time ``t``.
      - The delayed value ``'I2_E'`` (I2 interpolated at ``t - p.tau_E.val``) is refreshed in ``s``.

    Returns
    -------
    np.ndarray
        ``[dT, dI1, dI2, dV, dE]`` in that order.

    Raises
    ------
    TimeoutError
        If ``timeout`` is given and ``time_manager`` reports it has been exceeded.
    """
    if timeout is not None and time_manager.check_timeout(timeout):
        raise TimeoutError('ODE Solver timeout.')

    # Clip unreal and negative values from all states, then log them.
    y[:] = y.clip(min=0.0)
    s.update_states(t, y)

    # I2 delayed by tau_E enters the E equation.
    s.calculate_delayed_states(t, [('tau_E', 'E', 'I2')], p)
    I2_tauE = s.get_delayed_state('I2_E')

    T, I1, I2, V, E = s.T, s.I1, s.I2, s.V, s.E

    infection = p.beta.val * T * V
    dT = -infection
    dI1 = infection - p.k.val * I1
    dI2 = p.k.val * I1 - p.delta.val * I2 - p.delta_E.val * E * I2 / (p.Kd.val + I2)
    dV = p.p.val * I2 - p.c.val * V
    dE = p.xi.val * I2 / (p.K_Ef.val + E) + p.eta.val * E * I2_tauE - p.d_E.val * E

    return np.array([dT, dI1, dI2, dV, dE])
    
# One config dict per model state (passed to State(**config)); list order
# fixes the order of entries in the solver's state vector.
states_config = [
    {'label': label, 'initial_value': initial_value}
    for label, initial_value in (
        ('T', 1E7),
        ('I1', 75.0),
        ('I2', 0.0),
        ('V', 0.0),
        ('E', 0.0),
    )
]

# Nominal values and sampling bounds for every model parameter. All are
# sampled uniformly; rows with space 'log10' are sampled uniformly in log10
# of the bounds by generate_parameter_sets, the others directly.
parameters = Parameters(**{
    name: Parameter(name=name, val=val, min_bound=lower, max_bound=upper,
                    distribution='uniform', sample=True, space=space)
    for name, val, lower, upper, space in (
        # name       val      min    max    space
        ('beta',    6.2E-5,  1E-6,  1E-4,  'log10'),
        ('k',       4.0,     3.5,   6.0,   'nat'),
        ('p',       1.0,     0.25,  1E2,   'log10'),
        ('c',       9.4,     2.5,   1E2,   'log10'),
        ('delta',   0.24,    0.05,  0.5,   'nat'),
        ('delta_E', 1.9,     0.1,   2.0,   'nat'),
        ('Kd',      4.34E2,  1E2,   1E5,   'log10'),
        ('xi',      2.6E4,   1E2,   1E5,   'log10'),
        ('K_Ef',    8.1E5,   1E3,   1E6,   'log10'),
        ('eta',     2.5E-7,  1E-7,  5E-7,  'log10'),
        ('tau_E',   3.6,     2.5,   4.0,   'nat'),
        ('d_E',     1.0,     0.5,   2.0,   'nat'),
    )
})

# Draw 10 random parameter sets (seeded, so reruns reproduce the same sets).
parameter_sets = generate_parameter_sets(10, parameters, seed=42)

# Output grid: integer time points 0..10 merged with a 50-point uniform grid on [0, 10].
t_span = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
t_fill = np.linspace(t_span[0], t_span[-1], 50)
t_span = np.unique(np.concatenate([t_span, t_fill]))

# sols[i] is the result for parameter_sets[i] (None when integration failed);
# integration_times holds the wall-clock seconds of calls that returned.
sols, integration_times = [], []
for param_set in parameter_sets:
    task = (DDModel, param_set, states_config, t_span)
    try:
        sol, integration_time = SmartSolve(task)
    except Exception as exc:
        # Catch Exception rather than everything, so Ctrl-C still stops the sweep,
        # and report what actually went wrong.
        print('Integration failed at', param_set, f'({type(exc).__name__}: {exc})')
        sols.append(None)  # keep sols index-aligned with parameter_sets
        continue
    sols.append(sol)
    integration_times.append(integration_time)

print(np.mean(integration_times) if integration_times else float('nan'))

0.06729986667633056
