In [17]:
import numpy as np
from numba import njit, prange
import time

@njit
def simulate_CLE(C_u0, V_injection, params_rho, params_psi, params_alpha, params_delta, params_kappa, params_beta, dt, sqrt_dt, rand_array, t_injection=1.0):
    """Simulate one trajectory of the chemical Langevin equation (CLE) for
    uninfected cells C_u, infected cells C_i and free virus V.

    Integrated with an Euler-Maruyama scheme driven by the pre-drawn
    standard-normal increments in ``rand_array``. Propensities:

        a1 = rho * C_u * (1 - (C_u + C_i) / kappa)   # logistic growth of C_u
        a2 = psi * V * C_u                           # C_u + V -> C_i
        a3 = alpha * C_i                             # loss of C_i, V gains beta per unit
        a4 = delta * V                               # loss of V

    Parameters
    ----------
    C_u0 : float
        Initial C_u (C_i and V start at 0).
    V_injection : float
        Value V is set to at the injection step.
    params_rho, params_psi, params_alpha, params_delta, params_kappa, params_beta : float
        Rate constants used in the propensities above.
    dt, sqrt_dt : float
        Time step and its square root (passed in to avoid recomputation).
    rand_array : 2-D float array, shape (N_steps, 4)
        Standard-normal increments; column j drives propensity a(j+1).
        The number of rows sets the number of time steps.
    t_injection : float, optional
        Injection time (default 1.0, the previously hard-coded value).
        Injection happens at step floor(t_injection / dt), computed so that
        float round-off cannot drop it to the preceding step.

    Returns
    -------
    ndarray of float64, shape (4,)
        ``[C_u, C_i, V, step]``. ``step`` is the index of the step at which
        the loop stopped: the step where ``C_u + C_i <= 0`` was detected, or
        ``N_steps - 1`` if the loop ran to completion (0 if ``rand_array``
        has no rows). A run that dies on the very last update also ends at
        ``N_steps - 1``, so test extinction with ``C_u + C_i <= 0`` on the
        returned state rather than with ``step``.
    """
    C_u = C_u0
    C_i = 0.
    V = 0.

    N_steps = rand_array.shape[0]
    # Tiny relative nudge: e.g. 0.3 / 0.1 == 2.9999999999999996 would
    # otherwise truncate to 2 instead of 3.
    injection_step = int(t_injection / dt * (1.0 + 1e-12))
    sqrt_beta = np.sqrt(params_beta)  # loop-invariant

    step = 0  # keeps `step` defined when rand_array has no rows
    for step in range(N_steps):
        # Stop once the cell population has gone extinct.
        if (C_u + C_i) <= 0.0:
            break

        if step == injection_step:
            V = V_injection

        a1 = params_rho * C_u * (1. - (C_u + C_i) / params_kappa)
        a2 = params_psi * V * C_u
        a3 = params_alpha * C_i
        a4 = params_delta * V

        # Noise terms of reactions that change several species; the same
        # increment is reused in every equation the reaction touches.
        a2_w2 = np.sqrt(a2) * sqrt_dt * rand_array[step, 1]
        a3_w3 = np.sqrt(a3) * sqrt_dt * rand_array[step, 2]

        # a1 turns negative above the carrying capacity; clamp before sqrt.
        dC_u = (a1 - a2) * dt + np.sqrt(np.maximum(a1, 0)) * sqrt_dt * rand_array[step, 0] - a2_w2
        dC_i = (a2 - a3) * dt + a2_w2 - a3_w3
        # NOTE(review): the drift adds beta * a3 to V, but the matching noise
        # uses sqrt(beta) * a3_w3; a standard CLE for C_i -> beta V would use
        # beta * a3_w3. Left unchanged — confirm which is intended.
        dV = (params_beta * a3 - a2 - a4) * dt + sqrt_beta * a3_w3 - a2_w2 - np.sqrt(a4) * sqrt_dt * rand_array[step, 3]

        # Clamp so populations never go negative.
        C_u = np.maximum(C_u + dC_u, 0)
        C_i = np.maximum(C_i + dC_i, 0)
        V = np.maximum(V + dV, 0)

    return np.array([C_u, C_i, V, step])

def run_simulation(N_simulations, C_u0, V_injection, params, dt, T, seed = 42):
    """Run ``N_simulations`` independent CLE trajectories and collect their final states.

    Parameters
    ----------
    N_simulations : int
        Number of independent trajectories.
    C_u0, V_injection : float
        Passed through to ``simulate_CLE``.
    params : dict
        Must contain the keys 'rho', 'psi', 'alpha', 'delta', 'kappa', 'beta'.
    dt : float
        Time step.
    T : float
        Time horizon; each trajectory gets floor(T / dt) steps, computed so
        that float round-off (e.g. 0.3 / 0.1 == 2.999...) cannot lose a step.
    seed : int, optional
        Seed of a private legacy ``np.random.RandomState``.

    Returns
    -------
    ndarray, shape (N_simulations, 4)
        Row i is ``simulate_CLE``'s output ``[C_u, C_i, V, step]`` for run i.

    Notes
    -----
    Increments are drawn one trajectory at a time from a private
    ``np.random.RandomState(seed)``. This reproduces exactly the numbers of
    ``np.random.seed(seed); np.random.normal(size=(N, S, 4))``, because the
    legacy generator fills arrays sequentially and carries its cached
    Gaussian across calls. Only one (S, 4) array is alive at a time instead
    of all N of them (3.2 GB for N=1000, S=100000), and the global NumPy RNG
    is left untouched.
    """
    sqrt_dt = np.sqrt(dt)
    # Relative nudge guards against float truncation of T / dt.
    n_steps = int(T / dt * (1.0 + 1e-12))
    rng = np.random.RandomState(seed)

    results = np.zeros((N_simulations, 4))

    start_run = time.time()
    # Report progress after each tenth of the runs.
    checkpoints = set(np.linspace(0, N_simulations, 11, dtype=int)[1:].tolist())
    for i in range(N_simulations):
        increments = rng.normal(size=(n_steps, 4))
        results[i] = simulate_CLE(C_u0, V_injection, params['rho'], params['psi'], params['alpha'], params['delta'], params['kappa'], params['beta'], dt, sqrt_dt, increments)

        if i + 1 in checkpoints:
            elapsed_time = time.time() - start_run
            percent_complete = ((i+1) / N_simulations) * 100
            estimated_total_time = (elapsed_time / (i+1)) * N_simulations
            time_remaining = estimated_total_time - elapsed_time
            print(f"{percent_complete:.0f}% complete. Estimated time remaining: {time_remaining:.2f} seconds.")

    return results

In [25]:
# Model rate constants. Keys match simulate_CLE's params_<name> arguments;
# how each enters the propensities there:
#   rho   - growth coefficient of uninfected cells (a1)
#   kappa - carrying capacity in the logistic factor of a1
#   psi   - infection rate constant (a2 = psi * V * C_u)
#   beta  - amount V gains per unit of the a3 reaction
#   alpha - loss rate constant of infected cells (a3 = alpha * C_i)
#   delta - decay rate constant of free virus (a4 = delta * V)
# TODO: record where these values come from.
params = dict(
    rho=0.5379080098179797,
    kappa=777.6924217880852,
    psi=1.0142995960744846e-15,
    beta=9999.6,
    alpha=0.5308450453281762,
    delta=13.937249271829442,
)


In [26]:
# Experiment configuration.
N_simulations = 1_000   # independent trajectories
C_u0 = 400              # initial uninfected cells
V_injection = 3.0e9     # value V is set to at the injection step (t = 1)
dt = 1e-3               # Euler-Maruyama time step
T = 100                 # time horizon; T / dt steps per trajectory

In [27]:
# Run the ensemble; seed=42 is run_simulation's default, spelled out so the
# reproducibility seed is visible here.
results = run_simulation(
    N_simulations, C_u0, V_injection, params, dt, T,
    seed=42,
)

10% complete. Estimated time remaining: 1.80 seconds.
20% complete. Estimated time remaining: 1.60 seconds.
30% complete. Estimated time remaining: 1.40 seconds.
40% complete. Estimated time remaining: 1.20 seconds.
50% complete. Estimated time remaining: 1.00 seconds.
60% complete. Estimated time remaining: 0.80 seconds.
70% complete. Estimated time remaining: 0.60 seconds.
80% complete. Estimated time remaining: 0.40 seconds.
90% complete. Estimated time remaining: 0.20 seconds.
100% complete. Estimated time remaining: 0.00 seconds.


In [None]:
# DONE: optimize the loop (simulate_CLE is compiled with numba @njit)
# TODO: calculate the extinction probability and extinction timepoint
#       (a run went extinct iff C_u + C_i <= 0 in its result row)
# TODO: optimal control problem

In [31]:
# Mean of the last column: the step index at which each run stopped.
# It equals N_steps - 1 only if no run stopped before the final step.
np.mean(results[:, 3])

99999.0