# In [2]:
import os
os.environ["JULIA_NUM_THREADS"] = "1"
from julia.api import Julia
julia = Julia(sysimage="sysimage.so")
from julia import Main
Main.include("memory_model.jl")
import matplotlib.pyplot as plt
import numpy as np
import re
import matplotlib as mpl

# (regex, replacement) pairs applied in order. Compound expressions must come
# before the single names they contain; otherwise e.g. "delta_E" would already
# have been rewritten to "δE" and the compound pattern could never match.
# `\b` treats "_" as a word character, so r"\bdelta\b" does not match inside
# "delta_E" and r"\btau\b" does not match inside "tau_memory".
_FIT_PARAM_REPLACEMENTS = (
    (r"\bdelta_E/K_delta_E\b", "Log10 [δE/KδE]"),  # delta_E/K_delta_E -> Log10 [δE/KδE]
    (r"\bdelta/delta_E\b", "δ/δE"),  # delta/delta_E -> δ divided by δ subscript E
    (r"\bK_delta_E\b", "K(δE)"),  # K_delta_E -> K(δE)
    (r"\bdelta_E\b", "δE"),  # delta_E -> δ subscript E
    (r"\bdelta\b", "δ"),  # delta -> δ
    (r"\bd_E\b", "dE"),  # d_E -> d subscript E
    (r"\btau_memory\b", "τm"),  # tau_memory -> τ subscript m
    (r"\btau\b", "τ"),  # tau -> τ
    (r"\bxi\b", "ξ"),  # xi -> ξ
    (r"\bbeta\b", "β"),  # beta -> β
    (r"\beta\b", "η"),  # word boundary + "eta" + word boundary: eta -> η
    (r"\bzeta\b", "ζ"),  # zeta -> ζ
)


def format_fit_params(fit_parameters):
    """Return a display label for one or more model parameter names.

    ``fit_parameters`` is either a single string or a list/tuple of strings;
    a sequence is joined with single spaces (an empty one gives ""). Known
    parameter names are then replaced with their Greek / compact symbols,
    e.g. "delta_E" -> "δE", "tau_memory" -> "τm", and the compound
    "delta_E/K_delta_E" -> "Log10 [δE/KδE]". Any other text is kept as-is.

    >>> format_fit_params(["beta", "eta"])
    'β η'
    >>> format_fit_params("delta_E/K_delta_E")
    'Log10 [δE/KδE]'
    """
    if isinstance(fit_parameters, (list, tuple)):
        # join handles the one-element case too, and avoids IndexError on [].
        fit_params_str = " ".join(fit_parameters)
    else:
        fit_params_str = fit_parameters

    for pattern, replacement in _FIT_PARAM_REPLACEMENTS:
        fit_params_str = re.sub(pattern, replacement, fit_params_str)

    return fit_params_str

def sweep(t_span, y0, Zero_conditions, base_params, sweep_param, n_steps, percent=None, log=None):
    """Run a one-at-a-time sweep of ``sweep_param`` and plot the model response.

    ``sweep_param`` is varied over ``n_steps`` values around its value in
    ``base_params`` while all other parameters stay fixed. Exactly one of
    ``percent`` or ``log`` selects the range:

    * ``percent``: linearly spaced values from ``base * (1 - percent/100)``
      to ``base * (1 + percent/100)``.
    * ``log``: values evenly spaced in log10 space from ``base * 10**-log``
      to ``base * 10**log``; requires a positive base value.

    Every value is simulated with ``Main.tmap_LCTModel``. State rows 3, 4 and 5
    are drawn as log10 curves in three panels (virus, CD8+ T effector cells,
    CD8+ T memory cells). ``Zero_conditions["E0"]`` / ``["M0"]`` are added to
    the effector / memory rows, and every curve is floored at 1 before the
    log. Values below the base value use the "cool" colormap; values at or
    above it use "spring". The unperturbed base run is drawn in black.

    Args:
        t_span: (start, end) integration interval; converted to floats.
        y0: initial state vector; converted to floats.
        Zero_conditions: mapping providing the "E0" and "M0" offsets.
        base_params: parameter dict. Its values are passed to the model in
            dict order (presumably the order the Julia model expects).
        sweep_param: key of ``base_params`` to vary.
        n_steps: number of sweep values.
        percent: half-width of a linear sweep, in percent of the base value.
        log: half-width of a logarithmic sweep, in decades.

    Raises:
        ValueError: if neither or both of ``percent`` / ``log`` are given, or
            if a log sweep is requested for a non-positive base value.
    """
    # Convert inputs to plain floats before handing them to Julia.
    t_span = tuple(map(float, t_span))
    y0 = [float(y) for y in y0]

    base_value = base_params[sweep_param]

    if percent is not None and log is None:
        # Linear sweep over base_value * (1 -/+ percent / 100).
        sweep_range = (base_value * (1 - percent / 100), base_value * (1 + percent / 100))
        sweep_values = np.linspace(sweep_range[0], sweep_range[1], n_steps)
    elif log is not None and percent is None:
        if base_value <= 0:
            # log10 of a non-positive bound is -inf/NaN. Every simulation would
            # get NaN parameters, and plotting would then fail on empty subsets.
            raise ValueError(
                f"A log sweep requires a positive base value; {sweep_param!r} is {base_value}."
            )
        # From base_value / 10**log to base_value * 10**log, evenly spaced in log10.
        sweep_range = (base_value * (10 ** -log), base_value * (10 ** log))
        sweep_values = 10 ** np.linspace(np.log10(sweep_range[0]), np.log10(sweep_range[1]), n_steps)
    else:
        raise ValueError("Specify either 'percent' or 'log', but not both.")

    def simulate(params):
        # Returns (t, y) as float64 arrays; y is indexed as y[state_row, time].
        t_values, y_values = Main.tmap_LCTModel(t_span, y0, list(params.values()))
        return np.array(t_values, dtype=np.float64), np.array(y_values, dtype=np.float64)

    solutions = []
    for value in sweep_values:
        params = base_params.copy()
        params[sweep_param] = value
        t_values, y_values = simulate(params)
        solutions.append({"t": t_values, "y": y_values})

    # The baseline does not depend on the panel, so simulate it once, not per panel.
    baseline_t, baseline_y = simulate(base_params)

    fig, axs = plt.subplots(1, 3, figsize=(18, 6))

    # Separate color maps for values below / at-or-above the base value.
    cmap_decrease = plt.get_cmap("cool")
    cmap_increase = plt.get_cmap("spring")

    decreased = sweep_values[sweep_values < base_value]
    increased = sweep_values[sweep_values >= base_value]

    # One side can be empty (e.g. percent=0 or log=0). Fall back to a
    # degenerate range there instead of calling .min()/.max() on an empty array.
    norm_decrease = mpl.colors.Normalize(
        vmin=decreased.min() if decreased.size else base_value, vmax=base_value
    )
    norm_increase = mpl.colors.Normalize(
        vmin=base_value, vmax=increased.max() if increased.size else base_value
    )

    # Additive offset per panel: none for virus, E0 for effector, M0 for memory.
    offsets = (0.0, Zero_conditions["E0"], Zero_conditions["M0"])

    def log_trace(y, panel):
        # Panel k shows state row k + 3; floor at 1 so log10 is >= 0.
        return np.log10(np.maximum(y[panel + 3, :] + offsets[panel], 1))

    base_font_size = 20

    for panel, ax in enumerate(axs):
        for value, solution in zip(sweep_values, solutions):
            if value < base_value:
                color = cmap_decrease(norm_decrease(value))
            else:
                color = cmap_increase(norm_increase(value))
            ax.plot(solution["t"], log_trace(solution["y"], panel), color=color)

        ax.plot(baseline_t, log_trace(baseline_y, panel), color="black", linewidth=2, label="Cohort")

    axs[0].set_ylabel('Log$_{10}$ Virus (copies/mL)', fontsize=base_font_size)
    axs[1].set_ylabel('Log$_{10}$ CD8$^+$T Effector cells', fontsize=base_font_size)
    axs[2].set_ylabel('Log$_{10}$ CD8$^+$T Memory cells', fontsize=base_font_size)

    for ax in axs:
        ax.set_xlabel("Time (d)", fontsize=base_font_size)
        ax.tick_params(labelsize=base_font_size)

    # Title over the centre (T effector) panel names the swept parameter.
    axs[1].set_title(format_fit_params(sweep_param), fontsize=base_font_size, pad=20)

    # Leave room on the right for the two gradient colorbars.
    fig.subplots_adjust(right=0.8)
    plt.subplots_adjust(hspace=0.3, wspace=0.3)

    if decreased.size:
        cbar_ax_decrease = fig.add_axes([0.85, 0.1, 0.02, 0.35])
        mpl.colorbar.ColorbarBase(cbar_ax_decrease, cmap=cmap_decrease, norm=norm_decrease, orientation='vertical')
        cbar_ax_decrease.set_title("Decrease", fontsize=base_font_size)

    if increased.size:
        cbar_ax_increase = fig.add_axes([0.85, 0.55, 0.02, 0.35])
        mpl.colorbar.ColorbarBase(cbar_ax_increase, cmap=cmap_increase, norm=norm_increase, orientation='vertical')
        cbar_ax_increase.set_title("Increase", fontsize=base_font_size)

    plt.show()

# Problem setup for the one-at-a-time parameter sweeps.
t_span = (0.0, 8)
y0 = [4E7, 75.0, 0.0, 0.0, 0.0, 0.0] + [0.0] * 13
# Keyword form keeps the same key order, which is the order the values are
# handed to the model.
base_params = dict(
    beta=0.000049,
    k=4.0,
    p=2.83,
    c=141.0,
    delta=1.33,
    xi=0.18,
    a=7.43,
    d_E=0.17,
    delta_E=0.27,
    K_delta_E=100000,
    zeta=0.0046,
    eta=0,
    K_I1=1.0,
    tau_memory=0.26,
)
Zero_conditions = dict(E0=1.26e5, M0=7497)
n_steps = 50
percent = 50
log = 2

# eta and K_I1 are excluded from the sweeps.
for param in base_params:
    if param in ("eta", "K_I1"):
        continue
    sweep_param = param
    # A linear (+/- percent %) sweep, then a logarithmic (+/- log decades) sweep.
    sweep(t_span, y0, Zero_conditions, base_params, sweep_param, n_steps, percent=percent)
    sweep(t_span, y0, Zero_conditions, base_params, sweep_param, n_steps, log=log)

# Cell output: KeyboardInterrupt

# In [5]:
from SALib.sample import fast_sampler
from SALib.analyze import fast
from scipy.integrate import trapezoid  # Import trapezoid from scipy

def efast_sensitivity_analysis(
    param_names,
    param_bounds,
    base_params,
    t_span,
    y0,
    Zero_conditions,
    n_samples=65,
    M=4
):
    """Run an eFAST global sensitivity analysis of the LCT model and plot it.

    The parameters in ``param_names`` are sampled within ``param_bounds`` with
    SALib's FAST sampler; every other entry of ``base_params`` stays fixed.
    Each sample is simulated with ``Main.tmap_LCTModel`` and reduced to eight
    output metrics (listed in ``metric_labels`` below). First-order (S1) and
    total-order (ST) indices are then computed for each metric and drawn as
    bar charts with their confidence intervals.

    Args:
        param_names: names of the parameters to vary (keys of ``base_params``).
        param_bounds: one ``(lower, upper)`` pair per entry of ``param_names``.
        base_params: full parameter dict. Its values are passed to the model in
            dict order (presumably the order the Julia model expects).
        t_span: (start, end) integration interval; converted to floats.
        y0: initial state vector; converted to floats.
        Zero_conditions: accepted for signature parity with ``sweep`` but not
            used. Adding a constant offset to E would not change the
            variance-based indices anyway.
        n_samples: FAST sample size N per parameter (SALib requires N > 4*M**2).
        M: FAST interference parameter.

    Returns:
        ``(results, Y)``. ``results`` maps metric index -> SALib result dict
        (with at least "S1", "ST", "S1_conf", "ST_conf"). ``Y`` holds one row
        of metric values per FAST sample.
    """
    metric_labels = [
        "Peak V",
        "CAUC V",
        "Time of Peak V",
        "Peak CD8TE",
        "CAUC CD8TE",
        "Time of Peak CD8TE",
        "CAUC I1+I2",
        "Immune Delay (Peak E - Peak V)"
    ]
    n_metrics = len(metric_labels)

    # Convert inputs to plain floats before handing them to Julia, as sweep() does.
    t_span = tuple(map(float, t_span))
    y0 = [float(y) for y in y0]

    def run_model_and_metrics(params_dict):
        """Simulate with ``params_dict`` overriding ``base_params``; return the metrics.

        The returned list is ordered like ``metric_labels``.
        """
        # Overridden keys keep their original position, so the value order is unchanged.
        sim_params = {**base_params, **params_dict}

        t_array, y_array = Main.tmap_LCTModel(t_span, y0, list(sim_params.values()))
        t_array = np.array(t_array, dtype=np.float64)
        y_array = np.array(y_array, dtype=np.float64)  # indexed as y[state_row, time]

        V = y_array[3, :]   # virus
        E = y_array[4, :]   # effector CD8+ T cells
        # I1, I2 compartments: rows 6 and 7 as in the original analysis.
        # TODO(review): confirm these indices against the model definition.
        I1 = y_array[6, :]
        I2 = y_array[7, :]

        i_peak_V = np.argmax(V)
        i_peak_E = np.argmax(E)
        time_peak_V = t_array[i_peak_V]
        time_peak_E = t_array[i_peak_E]

        return [
            V[i_peak_V],                     # 0 Peak V
            trapezoid(V, x=t_array),         # 1 cumulative AUC of V(t)
            time_peak_V,                     # 2 time of peak V
            E[i_peak_E],                     # 3 Peak CD8TE
            trapezoid(E, x=t_array),         # 4 cumulative AUC of E(t)
            time_peak_E,                     # 5 time of peak CD8TE
            trapezoid(I1 + I2, x=t_array),   # 6 cumulative AUC of I1 + I2
            time_peak_E - time_peak_V,       # 7 immune delay
        ]

    problem = {
        "num_vars": len(param_names),
        "names": param_names,
        "bounds": param_bounds,
    }

    # SALib's FAST sampler returns n_samples * len(param_names) rows, one
    # column per entry of param_names.
    param_values = fast_sampler.sample(problem, N=n_samples, M=M)

    Y = np.zeros((len(param_values), n_metrics))
    for i, row in enumerate(param_values):
        Y[i, :] = run_model_and_metrics(dict(zip(param_names, row)))

    results = {
        metric_idx: fast.analyze(problem, Y[:, metric_idx], M=M, print_to_console=False)
        for metric_idx in range(n_metrics)
    }

    # One panel per metric: S1 and ST bars side by side for every parameter.
    fig, axes = plt.subplots(nrows=4, ncols=2, figsize=(14, 16))
    fig.suptitle("eFAST Sensitivity Indices for 8 Metrics", fontsize=18)

    indices = np.arange(len(param_names))
    width = 0.35

    # axes.flat walks the grid row by row: (0,0), (0,1), (1,0), ...
    for metric_idx, (ax, label) in enumerate(zip(axes.flat, metric_labels)):
        Si = results[metric_idx]
        ax.bar(indices - width / 2, Si["S1"], width, yerr=Si["S1_conf"], label="First-order")
        ax.bar(indices + width / 2, Si["ST"], width, yerr=Si["ST_conf"], label="Total-order")

        ax.set_title(label, fontsize=14)
        ax.set_xticks(indices)
        ax.set_xticklabels(param_names, rotation=45, ha="right")
        ax.set_ylabel("Sensitivity index")
        ax.legend()

    plt.tight_layout(rect=[0, 0, 1, 0.96])  # leave space for suptitle
    plt.show()

    return results, Y

# Problem setup for the eFAST analysis.
t_span = (0.0, 8)
y0 = [4E7, 75.0, 0.0, 0.0, 0.0, 0.0] + [0.0] * 13
# Keyword form keeps the same key order, which is the order the values are
# handed to the model.
base_params = dict(
    beta=0.000049,
    k=4.0,
    p=2.83,
    c=141.0,
    delta=1.33,
    xi=0.18,
    a=7.43,
    d_E=0.17,
    delta_E=0.27,
    K_delta_E=100000,
    zeta=0.0046,
    eta=0,
    K_I1=1.0,
    tau_memory=0.26,
)
Zero_conditions = dict(E0=1.26e5, M0=7497)
sens_percent = 50
param_names = ['beta', 'p', 'c', 'delta', 'xi', 'a', 'delta_E', 'K_delta_E']

# Each varied parameter gets a +/- sens_percent % range around its base
# value; the lower edge is clipped at zero.
param_bounds = [
    (
        max(0, base_params[name] * (1 - sens_percent / 100)),
        base_params[name] * (1 + sens_percent / 100),
    )
    for name in param_names
]

results, all_outputs = efast_sensitivity_analysis(
    param_names, param_bounds, base_params, t_span, y0, Zero_conditions,
    n_samples=65, M=4,
)