In [None]:
# Math, numpy and torch
import math
import numpy as np
import torch

# SBI
from sbi.inference import prepare_for_sbi, simulate_for_sbi
from sbi import utils as utils
from sbi import analysis as analysis
from sbi.utils import RestrictionEstimator

# SciPy
from scipy.integrate import solve_ivp
from scipy.signal import find_peaks

In [None]:
# Constants
# Seed the global torch and NumPy generators up front so anything drawing from
# them later in the notebook is repeatable.
torch.manual_seed(0)
np.random.seed(0)
tf = 50    # final time of every simulation
dt = 0.01  # sampling step of the stored trajectories
N = round(tf/dt)  # number of samples in the time grid
# Build the grid from the integer sample count. np.arange with a float step can
# come out one sample longer or shorter than expected because of rounding,
# whereas this form always yields exactly N points: 0, dt, ..., (N-1)*dt.
t = np.arange(N) * dt

tspan = [0, tf]  # integration interval handed to solve_ivp
c0 = [0, 0]      # initial state [x(0), y(0)]

In [None]:
# Functions
def system(t, z, k1, k2, k3, k4, Ac, xi):
    """
    Right-hand side of the two-variable ODE system.

        dx/dt = -k1*x - k2*y + Ac
        dy/dt =  k3*x - k4*y + xi*Ac

    Args:
        t (float): time; not used by the equations, but part of the
            fun(t, z, ...) call signature
        z (sequence): current state, unpacked as (x, y)
        k1 (float): k1 constant
        k2 (float): k2 constant
        k3 (float): k3 constant
        k4 (float): k4 constant
        Ac (float): Ac constant
        xi (float): xi constant

    Returns:
        list: the two derivatives [dx/dt, dy/dt]
    """
    x_now, y_now = z
    dx_dt = - k1*x_now - k2*y_now + Ac
    dy_dt = k3*x_now - k4*y_now + xi*Ac
    return [dx_dt, dy_dt]

def valid_simulation(x, dt):
    """Returns True or False depending on whether the simulation is valid or not.

    A trace is considered valid when all of the following hold:
      * it is non-empty, finite and never negative;
      * find_peaks (height=0.1, prominence=0.1) reports at most one peak;
      * its last sample is close to the sample 5 time units earlier, i.e. the
        system has (probably) reached a steady state x_ss;
      * its maximum Xpeak differs from x_ss by at least 1e-2;
      * after the maximum, some sample is approximately (atol=1e-3) equal to
        x_ss + 0.37*(Xpeak - x_ss), so a decay time could be read off;
      * after the maximum it never drops below x_ss - 3.25e-2*Xpeak.

    Args:
        x (array): The system state variable, sampled every dt.
        dt (float): Sampling step of x.

    Returns:
        bool: True if the simulation passes every check, False otherwise.
    """
    x = np.asarray(x)

    # Empty or non-finite traces cannot be assessed at all
    if x.size == 0 or not np.all(np.isfinite(x)):
        return False

    if np.any(x < 0): # Not a valid simulation
        return False

    # Check if the system presents oscillations
    peaks, _ = find_peaks(x, height=0.1, prominence=0.1)
    if len(peaks) > 1:
        # The system has oscillations
        return False

    # Peak: the global maximum is used instead of find_peaks so that it is not
    # subject to the minimum height / prominence thresholds above.
    # The peak position is kept as an integer sample index. Converting it to a
    # time and dividing by dt again may land one sample early after rounding.
    peak_idx = int(np.argmax(x))
    Xpeak = x[peak_idx]

    # Steady state: the last value must be close to the value 5 time units
    # before the end. A trace too short to contain that sample is rejected.
    steady_idx = int(-5/dt)
    if not (-x.size <= steady_idx < x.size):
        return False
    # NOTE(review): the *relative* tolerance is itself scaled by Xpeak, so it
    # grows with the peak height - confirm this is intended (vs. abs_tol).
    tol = Xpeak * 1e-1
    if not math.isclose(x[-1], x[steady_idx], rel_tol=tol):
        return False
    x_ss = x[-1]

    # If the maximum is indistinguishable from the steady state there is no
    # peak, and therefore no decay time (tau2)
    tol_tau2 = 1e-2
    if x_ss - tol_tau2 < Xpeak < x_ss + tol_tau2:
        return False

    # tau2 exists only if some sample after the peak sits at 37 % of the
    # overshoot above the steady state
    tol = 1e-3
    near_target = np.flatnonzero(np.isclose(x, x_ss + (Xpeak - x_ss)*0.37, atol=tol))
    if not np.any(near_target > peak_idx):
        return False

    # Check that after the peak there isn't a value of x lower than the
    # steady state value minus a tolerance
    if np.any(x[peak_idx:] < x_ss - 3.25e-2*Xpeak):
        return False

    return True
    
def calculate_summary_statistics(x, y, dt):
    """Calculates the summary statistics (attributes) of the system.

    Args:
        x (array): The first state variable, sampled every dt.
        y (array): The second state variable, sampled every dt.
        dt (float): Sampling step of x and y.

    Returns:
        A numpy array [tau1, tau2, x_ss, Xpeak, y_fin, tau1_Y] where
          * tau1   - time at which x first reaches 63 % of its maximum
          * tau2   - first time after the peak at which x is approximately
                     (atol=1e-3) x_ss + 0.37*(Xpeak - x_ss); NaN if there is
                     no such sample or no distinguishable peak
          * x_ss   - last value of x if it is close to the value 5 time units
                     earlier, NaN otherwise
          * Xpeak  - maximum of x
          * y_fin  - last value of y
          * tau1_Y - time at which y first reaches 63 % of its maximum
        All six entries are NaN when the trace is rejected outright: empty or
        non-finite input, negative x, more than one prominent peak, or x
        undershooting the steady state after the peak.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    # Returned whenever the whole simulation is rejected
    rejected = np.full(6, math.nan)

    # Empty or non-finite traces cannot be summarised
    if x.size == 0 or y.size == 0:
        return rejected
    if not (np.all(np.isfinite(x)) and np.all(np.isfinite(y))):
        return rejected

    if np.any(x < 0): # Not a valid simulation
        return rejected

    # Check if the system presents oscillations
    peaks, _ = find_peaks(x, height=0.1, prominence=0.1)
    if len(peaks) > 1:
        # The system has oscillations
        return rejected

    # Peak: the global maximum is used instead of find_peaks so that it is not
    # subject to the minimum height / prominence thresholds above.
    # The peak position is kept as an integer sample index. Converting it to a
    # time and dividing by dt again may land one sample early after rounding.
    peak_idx = int(np.argmax(x))
    Xpeak = x[peak_idx]

    # The maximum itself always satisfies the threshold, so a match exists
    tau1 = np.flatnonzero(x >= Xpeak*0.63)[0] * dt

    # Steady state: the last value must be close to the value 5 time units
    # before the end. A trace too short to contain that sample has no x_ss.
    steady_idx = int(-5/dt)
    # NOTE(review): the *relative* tolerance is itself scaled by Xpeak, so it
    # grows with the peak height - confirm this is intended (vs. abs_tol).
    tol = Xpeak * 1e-1
    if -x.size <= steady_idx < x.size and math.isclose(x[-1], x[steady_idx], rel_tol=tol):
        x_ss = x[-1]
    else:
        x_ss = math.nan

    # tau2: first sample after the peak that sits at 37 % of the overshoot
    # above the steady state. It stays NaN when there is no steady state, when
    # the peak is indistinguishable from it, or when no sample matches.
    tau2 = math.nan
    tol = 1e-3
    tol_tau2 = 1e-2
    has_peak = not (x_ss - tol_tau2 < Xpeak < x_ss + tol_tau2)
    if has_peak and not math.isnan(x_ss):
        near_target = np.flatnonzero(np.isclose(x, x_ss + (Xpeak - x_ss)*0.37, atol=tol))
        after_peak = near_target[near_target > peak_idx]
        if after_peak.size > 0:
            tau2 = after_peak[0] * dt

    # Check that after the peak there isn't a value of x lower than the
    # steady state value minus a tolerance (never true when x_ss is NaN)
    if np.any(x[peak_idx:] < x_ss - 3.25e-2*Xpeak):
        return rejected

    y_fin = y[-1]

    Ypeak = np.max(y)
    reached = np.flatnonzero(y >= Ypeak*0.63)
    # No sample can match when the maximum of y is negative
    tau1_Y = reached[0] * dt if reached.size > 0 else math.nan

    return np.array([tau1, tau2, x_ss, Xpeak, y_fin, tau1_Y])

def simulate_system(params):
    """Integrates the ODE system with SciPy's solve_ivp and returns both trajectories.

    Args:
        params (array): array of parameters. [k1, k2, k3, k4, Ac, Xi]

    Returns:
        A numpy array whose first row is the x trajectory and whose second row
        is the y trajectory, evaluated on the module-level time grid t.
    """
    # Pass the six parameters to `system` explicitly, one by one
    rate_args = tuple(params[i] for i in range(6))
    solution = solve_ivp(system, tspan, c0, args=rate_args, t_eval=t)
    x_traj = solution.y[0]
    y_traj = solution.y[1]
    return np.array([x_traj, y_traj])

def simulate_system_stats(params):
    """Simulates the ODE system and reduces the result to summary statistics.

    Args:
        params (array): array of parameters. [k1, k2, k3, k4, Ac, Xi]

    Returns:
        A torch tensor with the output of calculate_summary_statistics for the
        simulated x and y trajectories.
    """
    # Reuse simulate_system so both simulators integrate the ODE in exactly
    # the same way instead of duplicating the solver call.
    x_traj, y_traj = simulate_system(params)
    stats = torch.as_tensor(calculate_summary_statistics(x_traj, y_traj, dt))
    return stats

In [None]:
# Lower and upper bounds of the uniform prior, one entry per parameter in the
# order [k1, k2, k3, k4, Ac, xi].
prior_min = [0, 0, 0, 0, 0, 0]
prior_max = [2, 2, 0.5, 0.5, 2, 1]

# Cast both bounds to float32 explicitly: the all-integer lower bound would
# otherwise become an int64 tensor while the upper bound is float32.
prior = utils.torchutils.BoxUniform(
    low=torch.as_tensor(prior_min, dtype=torch.float32),
    high=torch.as_tensor(prior_max, dtype=torch.float32),
)

In [None]:
# Adapt both simulators (full trajectories / summary statistics) and the prior
# for use with sbi. The RestrictionEstimator is created in the cell that runs
# the restriction rounds, so it is not built a second time here.
simulator, prior = prepare_for_sbi(simulate_system, prior)
simulator_statistics, prior = prepare_for_sbi(simulate_system_stats, prior)

In [None]:
# Iteratively restrict the prior: each round simulates summary statistics from
# the latest proposal, hands them to the restriction estimator and derives a
# new, narrower proposal from it.
restriction_estimator = RestrictionEstimator(prior=prior)
proposals = [prior]

num_rounds = 5
last_round = num_rounds - 1
for r in range(num_rounds):
    params, x = simulate_for_sbi(
        simulator_statistics,
        proposals[-1],
        5000,
        num_workers=4,
        show_progress_bar=False,
    )
    restriction_estimator.append_simulations(params, x)

    # NOTE(review): the classifier is not retrained in the last round, so the
    # restricted prior built below never sees the final batch of simulations.
    # Confirm this is intended, since restricted_prior is used afterwards.
    if r != last_round:
        classifier = restriction_estimator.train()
    proposals.append(restriction_estimator.restrict_prior())

all_theta, all_x, _ = restriction_estimator.get_simulations()
restricted_prior = restriction_estimator.restrict_prior()

In [None]:
# We want full curves now (not summary statistics), drawn from the restricted prior
new_params, new_x = simulate_for_sbi(
    simulator,
    proposal=restricted_prior,
    num_simulations=50000,
)

In [None]:
# Plot 6 graphs with 10 randomly chosen curves each
import matplotlib.pyplot as plt

fig, axs = plt.subplots(3, 2, figsize=(15, 15))
fig.suptitle("Simulations")
# Draw on the axes created by plt.subplots. Calling plt.subplot(2, 3, ...) on
# this figure would add a second, differently shaped grid on top of them.
for ax in axs.flat:
    for _ in range(10):
        random_index = np.random.randint(0, len(new_x))
        # Index 0 selects the first stored trajectory of the simulation
        # (x, given how simulate_system stacks its output)
        ax.plot(t, new_x[random_index][0], label=f"{random_index}")
    # Limits and legend only need to be set once per panel
    ax.set_ylim(0, 2)
    ax.set_xlabel("t")
    ax.set_ylabel("x")
    ax.legend(loc="upper right")

In [None]:
# Large batch of full-trajectory simulations from the restricted prior.
# Note that this rebinds `params` and `x`, which were last assigned in the
# restriction rounds.
params, x = simulate_for_sbi(
    simulator,
    proposal=restricted_prior,
    num_simulations=500_000,
    num_workers=6,
    simulation_batch_size=10_000,
    seed=0,
)

In [None]:
# Persist the (parameters, trajectories) pair so the simulations can be reloaded
# later without being recomputed
torch.save(obj=(params, x), f='simulations_500k_restricted.pt')

In [None]:
# Plot 6 graphs with 10 randomly chosen curves each from the large batch
fig, axs = plt.subplots(3, 2, figsize=(15, 15))
fig.suptitle("Simulations")
# Draw on the axes created by plt.subplots. Calling plt.subplot(2, 3, ...) on
# this figure would add a second, differently shaped grid on top of them.
for ax in axs.flat:
    for _ in range(10):
        random_index = np.random.randint(0, len(x))
        # Index 0 selects the first stored trajectory of the simulation
        # (x, given how simulate_system stacks its output)
        ax.plot(t, x[random_index][0], label=f"{random_index}")
    # Limits and legend only need to be set once per panel
    ax.set_ylim(0, 2)
    ax.set_xlabel("t")
    ax.set_ylabel("x")
    ax.legend(loc="upper right")