In [23]:
import numpy as np
from numba import njit, prange

def simulate_CLE(C_u0, V_injection, params, dt, sqrt_dt, T, rand_array, t_injection=1.0):
    """Simulate one trajectory of the chemical Langevin equation (CLE) model.

    Euler-Maruyama integration of three populations ``C_u``, ``C_i`` and ``V``
    (presumably uninfected cells, infected cells and virus). Each step uses
    the propensities

        a1 = rho * C_u * (1 - (C_u + C_i) / kappa)   adds to C_u
        a2 = psi * V * C_u                           moves C_u -> C_i, removes V
        a3 = alpha * C_i                             removes C_i, adds beta*a3 to V
        a4 = delta * V                               removes V

    Every state is clipped at 0 after each step. Integration stops after the
    last step below ``T``, or as soon as ``C_u + C_i`` reaches 0, whichever
    comes first.

    Parameters
    ----------
    C_u0 : float
        Initial ``C_u``. ``C_i`` and ``V`` start at 0.
    V_injection : float
        Value ``V`` is set to at the step starting at ``t_injection``.
    params : dict
        Rates keyed 'rho', 'kappa', 'psi', 'beta', 'alpha', 'delta'.
    dt, sqrt_dt : float
        Time step and its square root. The caller passes ``sqrt_dt`` so it is
        computed once.
    T : float
        Time horizon. A step is taken at ``k * dt`` for every ``k * dt < T``.
    rand_array : numpy.ndarray
        Pre-drawn standard-normal increments with at least one row per step.
        Row ``k`` is used at step ``k``; columns 0..3 drive a1..a4.
    t_injection : float, default 1.0
        Time of the injection. This was previously hard-coded to 1.

    Returns
    -------
    list
        ``[C_u, C_i, V, time]``. ``time = steps_taken * dt`` is the time at
        which the returned state holds. It is the end of the step that caused
        extinction, or about ``T`` if the populations survived.

    Raises
    ------
    ValueError
        If ``dt <= 0`` or ``rand_array`` has fewer rows than steps needed.
    """
    if dt <= 0:
        raise ValueError(f"dt must be positive, got {dt}")

    # Hoist dict lookups out of the hot loop.
    rho, kappa = params['rho'], params['kappa']
    psi, alpha, delta = params['psi'], params['alpha'], params['delta']
    beta = params['beta']
    sqrt_beta = np.sqrt(beta)

    # Number of steps k with k*dt < T. Rounding T/dt first absorbs float noise
    # so that an exact multiple (e.g. 100/0.001) does not gain a spurious step.
    n_steps = int(np.ceil(round(T / dt, 9)))
    if rand_array.shape[0] < n_steps:
        raise ValueError(
            f"rand_array has {rand_array.shape[0]} rows but {n_steps} steps are needed")
    # Compare integer steps rather than an accumulated float time.
    injection_step = int(round(t_injection / dt))

    C_u = C_u0
    C_i = 0.
    V = 0.
    step = 0
    while step < n_steps and (C_u + C_i) > 0:
        if step == injection_step:
            V = V_injection

        a1 = rho * C_u * (1 - (C_u + C_i) / kappa)
        a2 = psi * V * C_u
        a3 = alpha * C_i
        a4 = delta * V

        noise = rand_array[step]
        # Reactions 2 and 3 change several species, so their noise increment is
        # shared between those species' equations.
        a2_w2 = np.sqrt(a2) * sqrt_dt * noise[1]
        a3_w3 = np.sqrt(a3) * sqrt_dt * noise[2]

        # a1 turns negative once C_u + C_i > kappa, so clip it before the sqrt.
        # Use the builtin max: np.max(a1, 0) would mean "max over axis 0" and
        # return a1 unchanged, producing NaN.
        dC_u = (a1 - a2) * dt + np.sqrt(max(a1, 0.)) * sqrt_dt * noise[0] - a2_w2
        dC_i = (a2 - a3) * dt + a2_w2 - a3_w3
        # NOTE(review): the drift adds beta * a3 to V, but the matching noise
        # uses sqrt(beta) * a3_w3 and reuses C_i's increment. A single reaction
        # with stoichiometry beta would give beta * a3_w3 in the standard CLE.
        # Confirm this scaling is intended.
        dV = ((beta * a3 - a2 - a4) * dt + sqrt_beta * a3_w3 - a2_w2
              - np.sqrt(a4) * sqrt_dt * noise[3])

        C_u = max(C_u + dC_u, 0)
        C_i = max(C_i + dC_i, 0)
        V = max(V + dV, 0)

        step += 1

    return [C_u, C_i, V, step * dt]

def run_simulation(N_simulations, C_u0, V_injection, params, dt, T, seed = 42):
    """Run ``N_simulations`` independent CLE trajectories with reproducible noise.

    Parameters
    ----------
    N_simulations : int
        Number of trajectories.
    C_u0, V_injection, params, dt, T
        Forwarded unchanged to ``simulate_CLE``.
    seed : int, default 42
        Seed of the private legacy ``RandomState`` that draws the noise.

    Returns
    -------
    numpy.ndarray, shape (N_simulations, 4)
        Row ``i`` holds ``simulate_CLE``'s ``[C_u, C_i, V, time]`` for run ``i``.

    Notes
    -----
    Noise is drawn one trajectory at a time (``n_steps x 4`` normals) rather
    than as one ``(N_simulations, n_steps, 4)`` block. With N=1000, T=100 and
    dt=0.001, that block is 4e8 float64 values (about 3.2 GB).

    ``RandomState(seed)`` produces the same stream as ``np.random.seed(seed)``
    followed by ``np.random.normal``. The legacy generator also yields the same
    sequence however the draws are chunked. So each trajectory receives exactly
    the numbers a single bulk draw would give it, and NumPy's global RNG is
    never reseeded.
    """
    sqrt_dt = np.sqrt(dt)
    # One noise row per step k with k*dt < T (rounding absorbs float noise in T/dt).
    n_steps = int(np.ceil(round(T / dt, 9)))
    rng = np.random.RandomState(seed)

    results = np.zeros((N_simulations, 4))
    for i in range(N_simulations):
        noise = rng.normal(size=(n_steps, 4))
        results[i] = simulate_CLE(C_u0, V_injection, params, dt, sqrt_dt, T, noise)

    return results

In [24]:
params = dict(
    rho=0.5379080098179797,       # growth coefficient in a1 = rho*C_u*(1 - (C_u+C_i)/kappa)
    kappa=777.6924217880852,      # population scale in a1's logistic factor
    psi=1.0142995960744846e-15,   # coefficient in a2 = psi*V*C_u
    beta=999997793935659.6,       # multiplier of a3 in V's drift
    alpha=0.5308450453281762,     # coefficient in a3 = alpha*C_i
    delta=13.937249271829442,     # coefficient in a4 = delta*V
)


In [25]:
# Experiment settings passed to run_simulation:
#   N_simulations  number of independent trajectories
#   C_u0           initial C_u (C_i and V start at 0)
#   V_injection    value V is set to at the injection step
#   dt, T          integration step and time horizon (same time units)
N_simulations, C_u0, V_injection = 1_000, 400, 3e9
dt, T = 0.001, 100

In [26]:
# One row per trajectory: [C_u, C_i, V, time]. The seed shown is run_simulation's default.
results = run_simulation(
    N_simulations, C_u0, V_injection, params,
    dt=dt, T=T, seed=42,
)

In [None]:
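# TODO: calculate the extinction probability and the extinction times
#       (a run went extinct iff results[:, 0] + results[:, 1] == 0; for those
#       rows results[:, 3] is the time extinction occurred)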
# TODO: optimal control problem

In [27]:
# Columns: C_u, C_i, V, final time (simulate_CLE's return order); one row per run.
results

array([[0.00000000e+00, 0.00000000e+00, 1.33510155e+12, 1.43330000e+01],
       [0.00000000e+00, 0.00000000e+00, 1.16818506e+12, 1.75580000e+01],
       [0.00000000e+00, 0.00000000e+00, 1.51237595e+12, 1.90870000e+01],
       ...,
       [0.00000000e+00, 0.00000000e+00, 1.39849064e+12, 1.73060000e+01],
       [0.00000000e+00, 0.00000000e+00, 2.42891667e+12, 1.56100000e+01],
       [0.00000000e+00, 0.00000000e+00, 7.41631962e+11, 1.56240000e+01]])