In [1]:
import numpy as np
from numba import njit, prange
import time
import matplotlib.pyplot as plt

@njit
def simulate_CLE(C_u0, n_hidden, V_injection, params_rho, params_psi, params_phi, params_alpha, params_delta, params_kappa, params_beta, dt, sqrt_dt, rand_array, extinction_threshold):
    """Simulate one stochastic trajectory of the cell/virus model with Euler-Maruyama steps.

    Compartments: C_u (uninfected cells), C_i (newly infected cells), C_is
    (n_hidden intermediate infected stages), C_l (last infected stage) and V
    (free virus). Propensities evaluated every step:
        growth       a1_ * C   with a1_ = rho * (1 - total_cells / kappa)
        infection    psi * V * C_u   (C_u -> C_i, removes one V)
        progression  phi * stage     (C_i -> C_is[0] -> ... -> C_l)
        release      alpha * C_l     (removes C_l, adds beta * alpha * C_l to V)
        decay        delta * V
    Each state change is drift * dt plus sqrt(propensity) * sqrt_dt * N(0, 1),
    and every state is clipped at 0 after each step.

    Parameters
    ----------
    C_u0 : initial uninfected cell count.
    n_hidden : number of intermediate infected stages (0 links C_i directly to C_l).
    V_injection : virus level *set* (not added) at step round(1 / dt), i.e. t ~ 1.
    params_rho, params_psi, params_phi, params_alpha, params_delta, params_kappa,
    params_beta : model rates, as used in the propensities above.
    dt, sqrt_dt : time step and its square root.
    rand_array : (N_steps, n_hidden + 5) standard normals; N_steps sets the horizon.
        col 0             growth noise (shared by every compartment)
        col 1             infection
        col 2 + k         progression out of stage k (k = 0 is C_i; k = n_hidden feeds C_l)
        col 3 + n_hidden  release from C_l
        col 4 + n_hidden  virus decay
    extinction_threshold : the run stops as soon as total cells <= this value.

    Returns
    -------
    1-D float64 array of length n_hidden + 5:
    [C_u, C_i, C_is[0], ..., C_is[n_hidden-1], C_l, V, final_time], where
    final_time is the extinction time, or N_steps * dt if the horizon was reached.
    """
    C_u = float(C_u0)
    C_i = 0.0
    C_is = np.zeros(n_hidden, dtype=np.float64)
    dC_is = np.zeros(n_hidden, dtype=np.float64)
    C_l = 0.0
    V = 0.0

    N_steps = rand_array.shape[0]
    # round() instead of int() truncation, so 1/dt = 9.999999... still maps to step 10.
    injection_step = int(round(1.0 / dt))
    release_col = 3 + n_hidden
    decay_col = 4 + n_hidden
    # Step at which the run stopped. It stays N_steps when no extinction occurs
    # (the state then corresponds to time N_steps * dt), and it is always bound,
    # even for N_steps == 0.
    final_step = N_steps

    for step in range(N_steps):

        sum_C = C_u + C_i + C_is.sum() + C_l

        # Extinction is judged on cells only; V is not part of the criterion.
        if sum_C <= extinction_threshold:
            final_step = step
            break

        if step == injection_step:
            V = V_injection

        a1_ = params_rho * (1. - sum_C / params_kappa)
        a2 = params_psi * V * C_u
        a3 = params_phi * C_i
        a4s = params_phi * C_is
        a5 = params_alpha * C_l
        a6 = params_delta * V

        # Infection and C_i-outflow noise terms are each shared by the two
        # compartments a single event touches (loss in one, gain in the other).
        a2_w2 = np.sqrt(a2) * sqrt_dt * rand_array[step, 1]
        a3_w3 = np.sqrt(a3) * sqrt_dt * rand_array[step, 2]

        # NOTE(review): column 0 drives the growth noise of *every* compartment,
        # so their growth fluctuations are perfectly correlated. Confirm this is intended.
        # Negative a1_ (above capacity) contributes drift only; its noise is clipped to 0.
        dC_u = (a1_ * C_u - a2) * dt + np.sqrt(np.maximum(a1_ * C_u, 0)) * sqrt_dt * rand_array[step, 0] - a2_w2
        dC_i = (a1_ * C_i + a2 - a3) * dt + np.sqrt(np.maximum(a1_ * C_i, 0)) * sqrt_dt * rand_array[step, 0] + a2_w2 - a3_w3
        for i in range(n_hidden):
            if i == 0:
                dC_is[i] = (a1_ * C_is[i] + a3 - a4s[i]) * dt + np.sqrt(np.maximum(a1_ * C_is[i], 0)) * sqrt_dt * rand_array[step, 0] \
                    + a3_w3 - np.sqrt(a4s[i]) * sqrt_dt * rand_array[step, 3+i]
            else:
                dC_is[i] = (a1_ * C_is[i] + a4s[i-1] - a4s[i]) * dt + np.sqrt(np.maximum(a1_ * C_is[i], 0)) * sqrt_dt * rand_array[step, 0] \
                    + np.sqrt(a4s[i-1]) * sqrt_dt * rand_array[step, 3+i-1] - np.sqrt(a4s[i]) * sqrt_dt * rand_array[step, 3+i]

        # Inflow into C_l comes from the last hidden stage, or directly from C_i
        # when n_hidden == 0 (there is then no hidden stage to index).
        if n_hidden > 0:
            a_in_l = a4s[n_hidden - 1]
            w_in_l = np.sqrt(a_in_l) * sqrt_dt * rand_array[step, 2 + n_hidden]
        else:
            a_in_l = a3
            w_in_l = a3_w3
        dC_l = (a1_ * C_l + a_in_l - a5) * dt + np.sqrt(np.maximum(a1_ * C_l, 0)) * sqrt_dt * rand_array[step, 0] \
            + w_in_l - np.sqrt(a5) * sqrt_dt * rand_array[step, release_col]
        # NOTE(review): release noise on V is sqrt(beta * a5). A single reaction
        # C_l -> beta V would instead give beta * sqrt(a5) with the same increment.
        # Confirm which model is intended.
        dV = (params_beta * a5 - a2 - a6) * dt + np.sqrt(params_beta * a5) * sqrt_dt * rand_array[step, release_col] \
            - a2_w2 - np.sqrt(a6) * sqrt_dt * rand_array[step, decay_col]

        C_u = np.maximum(C_u + dC_u, 0)
        C_i = np.maximum(C_i + dC_i, 0)
        for i in range(n_hidden):
            C_is[i] = np.maximum(C_is[i] + dC_is[i], 0)
        C_l = np.maximum(C_l + dC_l, 0)
        V = np.maximum(V + dV, 0)

    # Fill the output element-wise: `[*C_is]` list unpacking of an array is not
    # reliably supported in numba's nopython mode.
    out = np.empty(n_hidden + 5, dtype=np.float64)
    out[0] = C_u
    out[1] = C_i
    out[2:2 + n_hidden] = C_is
    out[2 + n_hidden] = C_l
    out[3 + n_hidden] = V
    out[4 + n_hidden] = final_step * dt
    return out

def run_simulation(N_simulations, C_u0, n_hidden, V_injection, params, dt, T, extinction_threshold, seed = 42):
    """Run N_simulations independent simulate_CLE trajectories and stack their outputs.

    Parameters
    ----------
    N_simulations : number of trajectories.
    C_u0, n_hidden, V_injection, dt, extinction_threshold : forwarded to simulate_CLE.
    params : dict with keys 'rho', 'psi', 'phi', 'alpha', 'delta', 'kappa', 'beta'.
    T : time horizon; each trajectory gets round(T / dt) steps.
    seed : seed of the legacy NumPy RandomState that generates all the noise.

    Returns
    -------
    Array of shape (N_simulations, n_hidden + 5); row i is simulate_CLE's output
    for trajectory i.
    """
    sqrt_dt = np.sqrt(dt)
    # round() so that e.g. T / dt == 2.9999999999999996 yields 3 steps, not 2.
    n_steps = int(round(T / dt))
    n_noise = n_hidden + 5

    # A private RandomState(seed) produces the same normal stream as
    # np.random.seed(seed) + np.random.normal, without resetting the global RNG
    # that other cells may rely on.
    rng = np.random.RandomState(seed)

    results = np.zeros((N_simulations, n_hidden + 5))

    start_run = time.time()
    checkpoints = {N_simulations // 2}
    for i in range(N_simulations):
        # Draw one trajectory's noise at a time. Consecutive draws reproduce the
        # slices of a single (N_simulations, n_steps, n_noise) draw, but peak
        # memory is n_steps * n_noise floats instead of N_simulations times that
        # (about 8 GB for 1000 x 100000 x 10 float64).
        noise = rng.normal(size=(n_steps, n_noise))
        results[i] = simulate_CLE(C_u0, n_hidden, V_injection, params['rho'], params['psi'], params['phi'], params['alpha'], params['delta'], params['kappa'], params['beta'], dt, sqrt_dt, noise, extinction_threshold)

        # Report progress and a time estimate at the 50% checkpoint
        if i + 1 in checkpoints:
            elapsed_time = time.time() - start_run
            percent_complete = ((i+1) / N_simulations) * 100
            estimated_total_time = (elapsed_time / (i+1)) * N_simulations
            time_remaining = estimated_total_time - elapsed_time
            print(f"{percent_complete:.0f}% complete. Estimated time remaining: {time_remaining:.2f} seconds.")

    return results

# Test the simulation with optimal parameters

In [15]:
# Optimal (baseline) model parameters, as consumed by simulate_CLE:
params_og = dict(
    rho=0.5778525462995705,     # growth rate in rho * (1 - total_cells / kappa)
    kappa=772.7473978753562,    # cell-count scale in the logistic growth factor
    psi=0.0001000147678614947,  # infection coefficient: psi * V * C_u
    phi=3.6546551247446124,     # per-stage progression rate of infected cells
    beta=1864.3310577533284,    # virus added to V per unit of release (beta * alpha * C_l)
    alpha=1.0894545550320296,   # release rate of the last infected stage C_l
    delta=12.14514283598,       # virus decay rate: delta * V
)

In [16]:
# Ensemble size and initial condition
N_simulations = 1000    # independent stochastic trajectories per parameter set
C_u0 = 400              # initial uninfected cell count
n_hidden = 5            # intermediate infected stages between C_i and C_l
V_injection = 3.0e9     # virus level set at the injection step (t ~ 1)

# Time discretisation: step size and horizon
dt = 1e-3
T = 100

# A trajectory stops early once its total cell count falls to this level
extinction_threshold = 1.0e-6

In [None]:
# Baseline ensemble at the optimal parameters; the number of hidden stages comes from the config cell.
results = run_simulation(N_simulations, C_u0, n_hidden, V_injection, params_og, dt, T, extinction_threshold)

# Sensitivity analysis of extinction scenarios: virus-related parameters

In [18]:
def check_extinction_prob(results, extinction_threshold):
    """Return the fraction of runs whose cell population ended extinct.

    `results` rows follow the simulate_CLE / run_simulation layout
    [C_u, C_i, C_is..., C_l, V, final_time]. Only the cell columns are summed.
    Leftover virus V and the time column must not mask an extinction, and `<=`
    matches the stopping rule inside simulate_CLE. Returns NaN for zero runs.
    """
    results = np.asarray(results, dtype=np.float64)
    n_runs = results.shape[0]
    if n_runs == 0:
        return np.nan
    # Drop the last two columns (V, final_time); everything before them is a cell count.
    total_cells = results[:, :-2].sum(axis=1)
    return np.count_nonzero(total_cells <= extinction_threshold) / n_runs

def check_extinction_params(param, param_range, extinction_threshold, seed=42):
    """Sweep one model parameter and estimate the extinction probability at each value.

    All other settings (params_og, N_simulations, C_u0, n_hidden, V_injection,
    dt, T) are read from the notebook globals, so the config cells must run
    first. The same `seed`, and therefore the same noise realisations, is reused
    for every value in `param_range`.

    Raises KeyError if `param` is not a key of params_og. Without this check, a
    typo would add an unused key and silently sweep the baseline.
    Returns a float array aligned with `param_range`.
    """
    if param not in params_og:
        raise KeyError(f"Unknown parameter {param!r}; expected one of {sorted(params_og)}")
    extinction_probs = np.zeros(len(param_range))
    params = params_og.copy()
    for i, p in enumerate(param_range):
        params[param] = p
        results = run_simulation(N_simulations, C_u0, n_hidden, V_injection, params, dt, T, extinction_threshold, seed)
        extinction_probs[i] = check_extinction_prob(results, extinction_threshold)

        print(f"Parameter {param} = {p}: calculation complete.")
    return extinction_probs

In [None]:
# Sweep the infection coefficient psi over 17 log-spaced values between 1e-8 and 1e8
psi_values = np.logspace(start=-8, stop=8, num=17)
extinction_probs_psi = check_extinction_params('psi', psi_values, extinction_threshold)
extinction_probs_psi

In [None]:
# Extinction probability against psi, on a logarithmic x axis
fig, ax = plt.subplots()
ax.plot(psi_values, extinction_probs_psi)
ax.set_xscale('log')
ax.set_xlabel('psi')
ax.set_ylabel('extinction probability')
ax.set_title('Extinction probability as a function of psi')
plt.show()

In [None]:
# vary phi, the per-stage progression rate of infected cells (C_i -> hidden stages -> C_l), from 1e-8 to 1e8
phi_values = np.logspace(-8, 8, 17)
extinction_probs_phi = check_extinction_params('phi', phi_values, extinction_threshold)
extinction_probs_phi

In [None]:
# Extinction probability against phi, on a logarithmic x axis
fig, ax = plt.subplots()
ax.plot(phi_values, extinction_probs_phi)
ax.set_xscale('log')
ax.set_xlabel('phi')
ax.set_ylabel('extinction probability')
ax.set_title('Extinction probability as a function of phi')
plt.show()

In [None]:
# Sweep beta, the virus released per unit of C_l release (drift beta * alpha * C_l),
# over 17 log-spaced values between 1e-8 and 1e8
beta_values = np.logspace(start=-8, stop=8, num=17)
extinction_probs_beta = check_extinction_params('beta', beta_values, extinction_threshold)
extinction_probs_beta

In [None]:
# Extinction probability against beta, on a logarithmic x axis
fig, ax = plt.subplots()
ax.plot(beta_values, extinction_probs_beta)
ax.set_xscale('log')
ax.set_xlabel('beta')
ax.set_ylabel('extinction probability')
ax.set_title('Extinction probability as a function of beta')
plt.show()