# In [2]:
import numpy as np
import random
import os
import pickle
from scipy.integrate import solve_ivp
from multiprocessing import Process, Queue
from concurrent.futures import ProcessPoolExecutor, TimeoutError
from tqdm import tqdm
import time
import warnings
# Global numerics: sampled rate constants span ten orders of magnitude, so
# library warnings and NumPy overflow/underflow/invalid-op warnings are
# silenced wholesale.
warnings.filterwarnings("ignore")
np.seterr(over="ignore", under="ignore", invalid="ignore")

# Rate constants are drawn as 10 ** U(LOG_MIN, LOG_MAX).
LOG_MIN, LOG_MAX = -5, 5

# Simulation grid: [0, SIM_T_MAX] with SIM_N_POINTS output times;
# TRAIN_SAMPLE_POINTS of them are kept per simulated run.
SIM_T_MAX = 100.0
SIM_N_POINTS = 100
TRAIN_SAMPLE_POINTS = 21

# Dataset sizes per mechanism.
N_TRAIN_PER_MECH, N_VAL_PER_MECH, N_TEST_PER_MECH = 300, 50, 20

MECHANISMS = ["M%d" % i for i in range(1, 21)]

# Seed both global RNGs (stdlib `random` and NumPy's legacy global RNG).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

def round_sig(x, sig=3):
    """Round every element of *x* to *sig* significant figures.

    Args:
        x: scalar or array-like of real numbers.
        sig: number of significant figures to keep.

    Returns:
        A float ``np.ndarray`` with the shape of ``np.array(x)`` (0-d for a
        scalar). Zeros and non-finite values (nan, +/-inf) are passed through
        unchanged.
    """
    x = np.array(x)

    def _round_scalar(val):
        # log10 is undefined at 0 and int() cannot take nan/inf: pass them through.
        if val == 0 or not np.isfinite(val):
            return float(val)
        digits = sig - int(np.floor(np.log10(abs(val)))) - 1
        return float(np.round(val, digits))

    # otypes is pinned to float. Without it np.vectorize infers the output
    # dtype from the FIRST result, so a leading 0 (returned as int) truncated
    # every other element to an integer, and size-0 inputs raised ValueError.
    vfunc = np.vectorize(_round_scalar, otypes=[float])
    return vfunc(x)


def sample_theta_pool(n_total, log_min=LOG_MIN, log_max=LOG_MAX, dim=1, seed=None):
    """Draw *n_total* log-uniform parameter vectors, rounded to 3 significant figures.

    Each entry is ``10 ** u`` with ``u ~ U(log_min, log_max)``, drawn from a
    fresh ``np.random.default_rng(seed)``; with ``seed=None`` the draw is not
    reproducible. Returns an array of shape ``(n_total, dim)``.
    """
    exponents = np.random.default_rng(seed).uniform(log_min, log_max, size=(n_total, dim))
    return round_sig(np.power(10, exponents))

def maximum_yield_criterion(ode_func, theta, y0=None, yield_lower=0.5, yield_upper=1.0, t_max=SIM_T_MAX):
    """Check that substrate conversion at *t_max* lies in ``[yield_lower, yield_upper]``.

    Integrates ``ode_func(t, y, theta)`` with LSODA (rtol=1e-6, atol=1e-9) on
    ``[0, t_max]`` at 10 output times. State index 0 is treated as the
    substrate S; conversion is ``(S(0) - S(t_max)) / (S(0) + 1e-8)``.

    If *y0* is None the state defaults to ``[1.0, 0.0, ...]`` with
    ``len(theta)`` entries. NOTE(review): that assumes the state has as many
    species as theta has rate constants, which is not true for every
    mechanism; callers in this file pass y0 explicitly.

    Returns False (never raises) on solver failure, NaNs, or any exception.
    """
    if y0 is None:
        y0 = [1.0] + [0.0] * (len(theta) - 1)
    grid = np.linspace(0, t_max, 10)
    try:
        sol = solve_ivp(
            lambda t, y: ode_func(t, y, theta),
            (0, t_max),
            y0,
            t_eval=grid,
            method="LSODA",
            rtol=1e-6,
            atol=1e-9,
        )
        if not sol.success or np.isnan(sol.y).any():
            return False
        substrate = sol.y[0]
        s_first, s_last = substrate[0], substrate[-1]
        conversion = (s_first - s_last) / (s_first + 1e-8)
        return yield_lower <= conversion <= yield_upper
    except Exception:
        return False

def filter_theta_pool(n_required, filter_func, ode_func, dim, y0_func=None, max_round=50, batch_factor=2, max_batch=40):
    """Rejection-sample rate-constant vectors until *n_required* pass all checks.

    Each round draws ``min(batch_factor * missing, max_batch)`` candidates as
    ``10 ** U(LOG_MIN, LOG_MAX)`` from the global NumPy RNG. A candidate is
    accepted when ``filter_func(theta, ode_func)`` (if given) returns truthy
    and ``maximum_yield_criterion`` passes. Any exception raised while
    checking a candidate rejects it (deliberate best-effort).

    Args:
        n_required: number of accepted vectors to return.
        filter_func: optional ``(theta, ode_func) -> bool`` post-filter.
        ode_func: ODE right-hand side ``f(t, y, theta)``.
        dim: number of rate constants per vector.
        y0_func: optional zero-argument callable returning the initial state
            for the yield check. Presumably a single state vector; TODO
            confirm against callers.
        max_round: maximum number of sampling rounds.
        batch_factor: oversampling factor relative to the number still missing.
        max_batch: upper bound on candidates per round (previously hard-coded
            to 40).

    Returns:
        ``np.ndarray`` of shape ``(n_required, dim)``.

    Raises:
        RuntimeError: if fewer than *n_required* vectors pass within
            *max_round* rounds.
    """
    name = ode_func.__name__
    passed = []
    rounds = 0
    total_tested = 0
    while len(passed) < n_required and rounds < max_round:
        start_time = time.time()
        # Small batches keep the per-round progress log responsive.
        batch_size = min(batch_factor * (n_required - len(passed)), max_batch)
        thetas = 10 ** np.random.uniform(LOG_MIN, LOG_MAX, size=(batch_size, dim))
        print(f"[INFO] {name} | Round {rounds+1} | Param range: [{10**LOG_MIN:.1e}, {10**LOG_MAX:.1e}] | Batch size: {batch_size}")
        print(f"[INFO] {name} | Sample θ example: {np.array2string(thetas[0], precision=2, separator=', ')}")
        accepted_this_round = 0
        tested_this_round = 0
        for theta in thetas:
            tested_this_round += 1
            try:
                # NOTE(review): the fallback state has `dim` entries (number of
                # rate constants, not species) and, for the ODEs in this file,
                # no catalyst at index 2. Pass y0_func for meaningful checks.
                y0 = y0_func() if y0_func else [1.0] + [0.0] * (dim - 1)
                if filter_func is not None and not filter_func(theta, ode_func):
                    continue
                if not maximum_yield_criterion(ode_func, theta, y0):
                    continue
            except Exception:
                continue
            passed.append(theta)
            accepted_this_round += 1
            if len(passed) >= n_required:
                break
        # Count only the candidates actually evaluated; the loop may stop early,
        # and adding the full batch size overstated the denominator.
        total_tested += tested_this_round
        elapsed = time.time() - start_time
        print(f"[INFO] {name} | Round {rounds+1} | Accepted {len(passed)} (+{accepted_this_round} this round) / Tested {total_tested} (pass rate {len(passed)/max(total_tested, 1):.3%}) | Time: {elapsed:.2f}s")
        rounds += 1
    if len(passed) < n_required:
        raise RuntimeError(f"Only {len(passed)} samples found after {rounds} rounds (need {n_required}) for {name}.")
    return np.array(passed[:n_required])


def single_ode_job(ode_func_name, theta, y0, n_points):
    """Integrate ``ODE_FUNCS[ode_func_name]`` on ``[0, SIM_T_MAX]``.

    Uses LSODA (rtol=1e-6, atol=1e-9, max_step=1.0) with *n_points* evenly
    spaced output times. Returns ``(t, y)`` when the solver succeeds and all
    values are finite, otherwise ``None``. Exceptions raised while solving
    become ``None``. An unknown name raises ``KeyError``, because the lookup
    happens outside the guarded section.
    """
    rhs = ODE_FUNCS[ode_func_name]
    try:
        times = np.linspace(0, SIM_T_MAX, n_points)
        sol = solve_ivp(
            lambda t, y: rhs(t, y, theta),
            [0, SIM_T_MAX],
            y0,
            t_eval=times,
            method="LSODA",
            rtol=1e-6,
            atol=1e-9,
            max_step=1.0,
        )
        usable = sol.success and np.all(np.isfinite(sol.y))
    except Exception:
        return None
    return (sol.t, sol.y) if usable else None

def _ode_job_worker(result_queue, ode_func_name, theta, y0, n_points):
    """Child-process entry point: run ``single_ode_job`` and post its result."""
    result_queue.put(single_ode_job(ode_func_name, theta, y0, n_points))


def simulate_trajectory(ode_func, theta, y0, n_points=SIM_N_POINTS, timeout=3):
    """Integrate a registered ODE in a child process with a hard wall-clock limit.

    The RHS is located by value in ``ODE_FUNCS``. Its name, not the function
    itself, is shipped to a fresh ``multiprocessing.Process`` that runs
    ``single_ode_job``.

    Args:
        ode_func: a function registered as a value in ``ODE_FUNCS``.
        theta: rate constants forwarded to the RHS.
        y0: initial state.
        n_points: number of output times on ``[0, SIM_T_MAX]``.
        timeout: seconds to wait for the result before giving up.

    Returns:
        ``(t, y)`` from ``single_ode_job``, or ``None`` if the integration
        failed or did not finish within *timeout* seconds. A timed-out child
        is terminated.

    Raises:
        IndexError: if *ode_func* is not registered in ``ODE_FUNCS``.
    """
    import queue  # only for queue.Empty, raised by multiprocessing.Queue.get on timeout

    names = [name for name, fn in ODE_FUNCS.items() if fn == ode_func]
    if not names:
        raise IndexError(f"{getattr(ode_func, '__name__', ode_func)!r} is not registered in ODE_FUNCS")

    # Bug fix: the previous `with ProcessPoolExecutor(...)` version never
    # enforced the timeout. Leaving the `with` block calls shutdown(wait=True),
    # which blocks until the hung job finishes. A plain Process can be killed.
    result_queue = Queue()
    worker = Process(
        target=_ode_job_worker,
        args=(result_queue, names[0], theta, y0, n_points),
        daemon=True,
    )
    worker.start()
    try:
        # Read before join(): a child still flushing a large array into the
        # pipe would otherwise never exit.
        return result_queue.get(timeout=timeout)
    except queue.Empty:
        return None
    finally:
        if worker.is_alive():
            worker.terminate()  # may corrupt result_queue; it is discarded below
        worker.join()
        result_queue.close()
        
def subsample_points(arr, t_full, n_points):
    """Pick *n_points* distinct columns of *arr* at random, kept in time order.

    Column indices are drawn without replacement from the global NumPy RNG
    and sorted ascending. Returns ``(arr[:, idx], t_full[idx])``.
    """
    picked = np.random.choice(arr.shape[1], n_points, replace=False)
    picked.sort()
    return arr[:, picked], t_full[picked]

def extract_SPcat(traj):
    """Return a copy of rows 0, 1 and 2 of the 2-D trajectory *traj*.

    In every ODE state layout in this file those rows are S, P and cat.
    Advanced indexing makes the result a copy, not a view.
    """
    first_three = np.arange(3)
    return traj[first_three, :]

def generate_sample(theta, ode_func, init_func, train=True):
    """Simulate one parameter vector under every initial state from *init_func*.

    For each initial state the trajectory is simulated, ``TRAIN_SAMPLE_POINTS``
    random time points are kept, and, when ``train`` is False, Gaussian noise
    is added with ``sigma = U(0.005, 0.05) * max(subsampled values)``. The
    maximum is taken over all species, not just S/P/cat. Rows 0-2 are then
    kept and transposed to (time, species).

    Args:
        theta: rate constants.
        ode_func: registered ODE right-hand side.
        init_func: zero-argument callable returning a list of initial states.
            Element 2 of each state is reported as the catalyst loading.
        train: if False, add measurement noise (test protocol).

    Returns:
        ``(cat0s, sample_matrix, t_concat)``: the loadings as an array, the
        per-run profiles stacked vertically, and the concatenated sample
        times. Returns ``(None, None, None)`` if any simulation returns None.
    """
    initial_states = init_func()
    loadings = [state[2] for state in initial_states]
    blocks = []
    times = []
    for state in initial_states:
        result = simulate_trajectory(ode_func, theta, state)
        if result is None:
            return None, None, None
        t_full, traj = result
        observed, t_obs = subsample_points(traj, t_full, TRAIN_SAMPLE_POINTS)
        if not train:
            sigma = random.uniform(0.005, 0.05) * np.max(observed)
            observed = observed + np.random.normal(0.0, sigma, size=observed.shape)
        blocks.append(extract_SPcat(observed).T)
        times.append(t_obs)
    return np.array(loadings), np.vstack(blocks), np.concatenate(times)


def generate_dataset_given_theta(theta_array, ode_func, train_init, test_init, mode='train', desc="", max_replacements=None):
    """Simulate one sample per row of *theta_array*, replacing failed rows.

    Bug fix: the previous version looped over ``range(len(theta_array))``,
    which is fixed before the array was extended, so the replacement vectors
    it appended were never simulated and failures silently shrank the dataset.

    Args:
        theta_array: 2-D array with one parameter vector per row. ``shape[1]``
            sets the dimension of replacement draws.
        ode_func: ODE right-hand side passed to ``generate_sample``.
        train_init: initial-state generator used unless ``mode == 'test'``.
        test_init: initial-state generator used (with noise) when
            ``mode == 'test'``.
        mode: ``'test'`` selects the noisy test protocol; anything else selects
            the noise-free protocol.
        desc: progress-bar label.
        max_replacements: cap on the number of replacement vectors drawn for
            failed simulations. Defaults to ``len(theta_array)``; ``0``
            disables replacement.

    Returns:
        ``(x1, x2_list, t_list)``: ``np.array`` of per-sample catalyst
        loadings, the list of stacked profiles, and the list of sample times.
        Fewer than ``len(theta_array)`` samples are returned only when the
        replacement budget runs out.
    """
    n_target = len(theta_array)
    if max_replacements is None:
        max_replacements = n_target
    is_test = mode == 'test'
    init_func = test_init if is_test else train_init
    candidates = list(theta_array)
    replacements = 0
    x1_list, x2_list, t_list = [], [], []
    bar = tqdm(total=n_target, desc=desc, leave=False, ncols=80, bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]")
    try:
        idx = 0
        while len(x1_list) < n_target:
            if idx >= len(candidates):
                budget = max_replacements - replacements
                if budget <= 0:
                    break
                n_extra = min(n_target - len(x1_list), budget)
                # NOTE(review): replacements come straight from sample_theta_pool
                # and are NOT re-checked by the mechanism's yield/post-filters.
                # Seeding from the global RNG keeps runs reproducible under SEED.
                extra = sample_theta_pool(
                    n_extra, LOG_MIN, LOG_MAX, np.shape(theta_array)[1],
                    seed=int(np.random.randint(0, 2**31 - 1)),
                )
                candidates.extend(extra)
                replacements += n_extra
            theta = candidates[idx]
            idx += 1
            try:
                cat0, mat, t_sub = generate_sample(theta, ode_func, init_func, train=not is_test)
            except Exception:
                continue
            if cat0 is None:
                continue
            x1_list.append(cat0)
            x2_list.append(mat)
            t_list.append(t_sub)
            bar.update(1)
    finally:
        bar.close()
    return np.array(x1_list), x2_list, t_list


# Per-mechanism registries keyed by mechanism name ("M1", "M2", ...):
# RHS functions, train/test initial-state generators, theta length, and
# optional post-filters.
ODE_FUNCS, TRAIN_INIT_FUNCS, TEST_INIT_FUNCS, MECH_DIM, POST_FILTER_FUNCS = (
    {}, {}, {}, {}, {},
)

# In [3]:
def ode_M1(t, y, theta):
    """Right-hand side for M1: S + cat <=> catS (k1, k_1); catS <=> cat + P (k2, k_2).

    y = [S, P, cat, catS], theta = [k1, k_1, k2, k_2]. *t* is unused.
    Returns the derivatives as a list in state order.
    """
    S, P, cat, catS = y
    k1, k_1, k2, k_2 = theta
    capture = (k1 * S + k_2 * P) * cat   # free cat bound by S or P
    release = (k_1 + k2) * catS          # catS broken down to S or P
    return [
        k_1 * catS - k1 * S * cat,
        k2 * catS - k_2 * cat * P,
        release - capture,
        capture - release,
    ]

def train_inits_M1():
    """Return four training initial states [S, P, cat, catS] for M1.

    Runs 1-3 start from pure substrate (S=1, P=0) with catalyst loadings
    drawn from U(0.01, 0.1) and rounded to 4 decimals. Run 4 uses
    S ~ U(0.4, 0.8) (rounded to 4 decimals), P = 1 - S, and a randomly chosen
    one of the three loadings scaled by S. Draws from the global NumPy RNG.
    """
    loadings = np.random.uniform(0.01, 0.1, size=3)
    s_mix = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mix = 1.0 - s_mix
    reused = loadings[np.random.randint(0, 3)]
    states = [[1.0, 0.0, np.round(c, 4), 0.0] for c in loadings]
    states.append([s_mix, p_mix, np.round(reused * s_mix, 4), 0.0])
    return states

def test_inits_M1():
    """Return four test initial states [S, P, cat, catS] for M1.

    Runs 1-3 use pure substrate with one low, one mid and one high loading,
    drawn from U(0.01, 0.02), U(0.045, 0.055) and U(0.09, 0.10), rounded to
    4 decimals. Run 4 is an S/P mixture as in the training protocol.
    """
    loadings = [
        np.random.uniform(0.01, 0.02),
        np.random.uniform(0.045, 0.055),
        np.random.uniform(0.09, 0.10),
    ]
    s_mix = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mix = 1.0 - s_mix
    reused = loadings[np.random.randint(0, 3)]
    states = [[1.0, 0.0, np.round(c, 4), 0.0] for c in loadings]
    states.append([s_mix, p_mix, np.round(reused * s_mix, 4), 0.0])
    return states

# Register M1 (theta = k1, k_1, k2, k_2).
ODE_FUNCS.update(M1=ode_M1)
TRAIN_INIT_FUNCS.update(M1=train_inits_M1)
TEST_INIT_FUNCS.update(M1=test_inits_M1)
MECH_DIM.update(M1=4)

# In [4]:
def ode_M2(t, y, theta):
    """Right-hand side for M2: M1 catalysis plus reversible dimerisation 2 cat <=> cat2.

    Reactions: S + cat <=> catS (k1, k_1); catS <=> cat + P (k2, k_2);
    2 cat -> cat2 at rate k_3 * cat**2; cat2 -> 2 cat at rate k3 * cat2.
    y = [S, P, cat, catS, cat2], theta = [k1, k_1, k2, k_2, k3, k_3].
    *t* is unused.

    Conserves cat + catS + 2*cat2 and S + P + catS.
    """
    S, P, cat, catS, cat2 = y
    k1, k_1, k2, k_2, k3, k_3 = theta
    dS_dt = k_1 * catS - k1 * S * cat
    dP_dt = k2 * catS - k_2 * cat * P
    dcat_dt = (k_1 + k2) * catS + 2 * k3 * cat2 - (k1 * S + k_2 * P + 2 * k_3 * cat) * cat
    dcatS_dt = (k1 * S + k_2 * P) * cat - (k_1 + k2) * catS
    # Bug fix: dissociation used k_3 instead of k3, contradicting dcat_dt
    # and breaking conservation of total catalyst.
    dcat2_dt = k_3 * cat * cat - k3 * cat2
    return [dS_dt, dP_dt, dcat_dt, dcatS_dt, dcat2_dt]

def train_inits_M2():
    """Return four training initial states [S, P, cat, catS, cat2] for M2.

    Runs 1-3: pure substrate with loadings from U(0.01, 0.1), rounded to 4
    decimals. Run 4: S ~ U(0.4, 0.8) (rounded), P = 1 - S, with a randomly
    chosen loading scaled by S. Draws from the global NumPy RNG.
    """
    loadings = np.random.uniform(0.01, 0.1, size=3)
    s_mix = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mix = 1.0 - s_mix
    reused = loadings[np.random.randint(0, 3)]
    states = [[1.0, 0.0, np.round(c, 4), 0.0, 0.0] for c in loadings]
    states.append([s_mix, p_mix, np.round(reused * s_mix, 4), 0.0, 0.0])
    return states

def test_inits_M2():
    """Return four test initial states [S, P, cat, catS, cat2] for M2.

    Loadings are drawn low/mid/high from U(0.01, 0.02), U(0.045, 0.055) and
    U(0.09, 0.10). Run 4 is an S/P mixture as in the training protocol.
    """
    loadings = [
        np.random.uniform(0.01, 0.02),
        np.random.uniform(0.045, 0.055),
        np.random.uniform(0.09, 0.10),
    ]
    s_mix = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mix = 1.0 - s_mix
    reused = loadings[np.random.randint(0, 3)]
    states = [[1.0, 0.0, np.round(c, 4), 0.0, 0.0] for c in loadings]
    states.append([s_mix, p_mix, np.round(reused * s_mix, 4), 0.0, 0.0])
    return states


def M2_cat2_post_filter(theta, ode_func):
    """Keep *theta* only if state index 4 (cat2) builds up mid-reaction.

    Integrates one reference run, y0 = [1.0, 0.0, 0.05, 0.0, 0.0], with LSODA
    on t in [0, 100] at 100 output points. Substrate conversion is normalised
    by the conversion reached at t = 100. Among points whose normalised
    conversion lies in [0.2, 0.5], the maximum of index 4 must exceed 10% of
    the 0.05 loading.

    Returns False on a failed/NaN solve or when no point falls in the window.
    Exceptions from ``solve_ivp`` propagate to the caller.
    """
    loading = 0.05
    sol = solve_ivp(
        lambda t, y: ode_func(t, y, theta),
        (0, 100),
        [1.0, 0.0, loading, 0.0, 0.0],
        t_eval=np.linspace(0, 100, 100),
        method="LSODA",
    )
    if not sol.success or np.isnan(sol.y).any():
        return False
    substrate = sol.y[0]
    progress = (substrate[0] - substrate) / (substrate[0] - substrate[-1] + 1e-10)
    in_window = (progress >= 0.2) & (progress <= 0.5)
    if not in_window.any():
        return False
    return bool(sol.y[4][in_window].max() > 0.1 * loading)

# Register M2 (theta = k1, k_1, k2, k_2, k3, k_3) with its cat2 post-filter.
ODE_FUNCS.update(M2=ode_M2)
TRAIN_INIT_FUNCS.update(M2=train_inits_M2)
TEST_INIT_FUNCS.update(M2=test_inits_M2)
MECH_DIM.update(M2=6)
POST_FILTER_FUNCS.update(M2=M2_cat2_post_filter)

# In [5]:
def ode_M3(t, y, theta):
    """Right-hand side for M3: catalysis by the dimer cat2.

    Reactions: 2 cat -> cat2 at rate k3 * cat**2, cat2 -> 2 cat at k_3 * cat2;
    cat2 + S <=> cat2S (k1, k_1); cat2S <=> cat2 + P (k2, k_2).
    y = [S, P, cat, cat2, cat2S], theta = [k1, k_1, k2, k_2, k3, k_3].
    *t* is unused.
    """
    S, P, cat, cat2, cat2S = y
    k1, k_1, k2, k_2, k3, k_3 = theta
    release = (k_1 + k2) * cat2S
    return [
        k_1 * cat2S - k1 * S * cat2,
        k2 * cat2S - k_2 * cat2 * P,
        2 * k_3 * cat2 - 2 * k3 * cat * cat,
        release + k3 * cat * cat - (k1 * S + k_2 * P + k_3) * cat2,
        (k1 * S + k_2 * P) * cat2 - release,
    ]

def train_inits_M3():
    """Return four training initial states [S, P, cat, cat2, cat2S] for M3.

    Runs 1-3: pure substrate with monomer loadings from U(0.01, 0.1), rounded
    to 4 decimals. Run 4: S ~ U(0.4, 0.8) (rounded), P = 1 - S, with a
    randomly chosen loading scaled by S. Draws from the global NumPy RNG.
    """
    loadings = np.random.uniform(0.01, 0.1, size=3)
    s_mix = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mix = 1.0 - s_mix
    reused = loadings[np.random.randint(0, 3)]
    states = [[1.0, 0.0, np.round(c, 4), 0.0, 0.0] for c in loadings]
    states.append([s_mix, p_mix, np.round(reused * s_mix, 4), 0.0, 0.0])
    return states

def test_inits_M3():
    """Return four test initial states [S, P, cat, cat2, cat2S] for M3.

    Loadings are drawn low/mid/high from U(0.01, 0.02), U(0.045, 0.055) and
    U(0.09, 0.10). Run 4 is an S/P mixture as in the training protocol.
    """
    loadings = [
        np.random.uniform(0.01, 0.02),
        np.random.uniform(0.045, 0.055),
        np.random.uniform(0.09, 0.10),
    ]
    s_mix = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mix = 1.0 - s_mix
    reused = loadings[np.random.randint(0, 3)]
    states = [[1.0, 0.0, np.round(c, 4), 0.0, 0.0] for c in loadings]
    states.append([s_mix, p_mix, np.round(reused * s_mix, 4), 0.0, 0.0])
    return states

def M3_cat_post_filter(theta, ode_func):
    """Keep *theta* only if free monomer (index 2) stays above 5% of the loading early on.

    Integrates one reference run, y0 = [1.0, 0.0, 0.05, 0.0, 0.0], with LSODA
    on t in [0, 100] at 100 output points. "Early" means points whose
    substrate conversion, normalised by the conversion at t = 100, is below
    0.1. The minimum of index 2 over those points must exceed 0.05 * 0.05.

    Returns False on a failed/NaN solve or when no point is "early".
    """
    loading = 0.05
    sol = solve_ivp(
        lambda t, y: ode_func(t, y, theta),
        (0, 100),
        [1.0, 0.0, loading, 0.0, 0.0],
        t_eval=np.linspace(0, 100, 100),
        method="LSODA",
    )
    if not sol.success or np.isnan(sol.y).any():
        return False
    substrate = sol.y[0]
    progress = (substrate[0] - substrate) / (substrate[0] - substrate[-1] + 1e-10)
    early = progress < 0.1
    if not early.any():
        return False
    return bool(sol.y[2][early].min() > 0.05 * loading)

# Register M3 (theta = k1, k_1, k2, k_2, k3, k_3) with its monomer post-filter.
ODE_FUNCS.update(M3=ode_M3)
TRAIN_INIT_FUNCS.update(M3=train_inits_M3)
TEST_INIT_FUNCS.update(M3=test_inits_M3)
MECH_DIM.update(M3=6)
POST_FILTER_FUNCS.update(M3=M3_cat_post_filter)

# In [6]:
def ode_M4(t, y, theta):
    """Right-hand side for M4: M1-type catalysis with a co-species X.

    Reactions: cat + S -> catS + X (k1); catS + X -> cat + S (k_1);
    catS + X -> cat + P (k2); cat + P -> catS + X (k_2).
    y = [S, P, cat, catS, X], theta = [k1, k_1, k2, k_2]. *t* is unused.
    dX/dt equals dcatS/dt, so X - catS is conserved.
    """
    S, P, cat, catS, X = y
    k1, k_1, k2, k_2 = theta
    capture = (k1 * S + k_2 * P) * cat
    release = (k_1 + k2) * X * catS
    return [
        k_1 * X * catS - k1 * S * cat,
        k2 * X * catS - k_2 * cat * P,
        release - capture,
        capture - release,
        capture - release,
    ]

def train_inits_M4():
    """Return four training initial states [S, P, cat, catS, X] for M4 (X starts at 1.0).

    Runs 1-3: pure substrate with loadings from U(0.01, 0.1), rounded to 4
    decimals. Run 4: S ~ U(0.4, 0.8) (rounded), P = 1 - S, with a randomly
    chosen loading scaled by S. Draws from the global NumPy RNG.
    """
    loadings = np.random.uniform(0.01, 0.1, size=3)
    s_mix = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mix = 1.0 - s_mix
    reused = loadings[np.random.randint(0, 3)]
    states = [[1.0, 0.0, np.round(c, 4), 0.0, 1.0] for c in loadings]
    states.append([s_mix, p_mix, np.round(reused * s_mix, 4), 0.0, 1.0])
    return states

def test_inits_M4():
    """Return four test initial states [S, P, cat, catS, X] for M4 (X starts at 1.0).

    Loadings are drawn low/mid/high from U(0.01, 0.02), U(0.045, 0.055) and
    U(0.09, 0.10). Run 4 is an S/P mixture as in the training protocol.
    """
    loadings = [
        np.random.uniform(0.01, 0.02),
        np.random.uniform(0.045, 0.055),
        np.random.uniform(0.09, 0.10),
    ]
    s_mix = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mix = 1.0 - s_mix
    reused = loadings[np.random.randint(0, 3)]
    states = [[1.0, 0.0, np.round(c, 4), 0.0, 1.0] for c in loadings]
    states.append([s_mix, p_mix, np.round(reused * s_mix, 4), 0.0, 1.0])
    return states

# Register M4 (theta = k1, k_1, k2, k_2).
ODE_FUNCS.update(M4=ode_M4)
TRAIN_INIT_FUNCS.update(M4=train_inits_M4)
TEST_INIT_FUNCS.update(M4=test_inits_M4)
MECH_DIM.update(M4=4)

# In [7]:
def ode_M5(t, y, theta):
    """Right-hand side for M5: catalyst-assisted catS <=> catP conversion.

    Reactions: S + cat <=> catS (k1, k_1); catS + cat <=> catP + cat
    (k2 forward, k_2 reverse; both first order in free cat);
    catP <=> cat + P (k3 release, k_3 rebinding).
    y = [S, P, cat, catS, catP], theta = [k1, k_1, k2, k_2, k3, k_3].
    *t* is unused.
    """
    S, P, cat, catS, catP = y
    k1, k_1, k2, k_2, k3, k_3 = theta
    return [
        k_1 * catS - k1 * S * cat,                                   # S
        k3 * catP - k_3 * cat * P,                                   # P
        k_1 * catS + k3 * catP - (k1 * S + k_3 * P) * cat,           # cat
        (k1 * S + k_2 * catP) * cat - (k_1 + k2 * cat) * catS,       # catS
        k2 * catS * cat + k_3 * P * cat - (k3 + k_2 * cat) * catP,   # catP
    ]

def train_inits_M5():
    """Return four training initial states [S, P, cat, catS, catP] for M5.

    Runs 1-3: pure substrate with loadings from U(0.01, 0.1), rounded to 4
    decimals. Run 4: S ~ U(0.4, 0.8) (rounded), P = 1 - S, with a randomly
    chosen loading scaled by S. Draws from the global NumPy RNG.
    """
    loadings = np.random.uniform(0.01, 0.1, size=3)
    s_mix = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mix = 1.0 - s_mix
    reused = loadings[np.random.randint(0, 3)]
    states = [[1.0, 0.0, np.round(c, 4), 0.0, 0.0] for c in loadings]
    states.append([s_mix, p_mix, np.round(reused * s_mix, 4), 0.0, 0.0])
    return states

def test_inits_M5():
    """Return four test initial states [S, P, cat, catS, catP] for M5.

    Loadings are drawn low/mid/high from U(0.01, 0.02), U(0.045, 0.055) and
    U(0.09, 0.10). Run 4 is an S/P mixture as in the training protocol.
    """
    loadings = [
        np.random.uniform(0.01, 0.02),
        np.random.uniform(0.045, 0.055),
        np.random.uniform(0.09, 0.10),
    ]
    s_mix = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mix = 1.0 - s_mix
    reused = loadings[np.random.randint(0, 3)]
    states = [[1.0, 0.0, np.round(c, 4), 0.0, 0.0] for c in loadings]
    states.append([s_mix, p_mix, np.round(reused * s_mix, 4), 0.0, 0.0])
    return states

# Register M5 (theta = k1, k_1, k2, k_2, k3, k_3).
ODE_FUNCS.update(M5=ode_M5)
TRAIN_INIT_FUNCS.update(M5=train_inits_M5)
TEST_INIT_FUNCS.update(M5=test_inits_M5)
MECH_DIM.update(M5=6)

# In [8]:
def ode_M6(t, y, theta):
    """Right-hand side for M6: pre-catalyst activation followed by M1-type turnover.

    Reactions: cat -> cat* (k3, irreversible); cat* + S <=> cat*S (k1, k_1);
    cat*S <=> cat* + P (k2, k_2).
    y = [S, P, cat, cat_star, cat_starS], theta = [k1, k_1, k2, k_2, k3].
    *t* is unused.
    """
    S, P, cat, cat_star, cat_starS = y
    k1, k_1, k2, k_2, k3 = theta
    capture = (k1 * S + k_2 * P) * cat_star
    release = (k_1 + k2) * cat_starS
    return [
        k_1 * cat_starS - k1 * S * cat_star,
        k2 * cat_starS - k_2 * P * cat_star,
        -k3 * cat,
        k3 * cat + release - capture,
        capture - release,
    ]

def train_inits_M6():
    """Return four training initial states [S, P, cat, cat_star, cat_starS] for M6.

    Runs 1-3: pure substrate with loadings from U(0.03, 0.07), i.e. 3-7 mol%,
    rounded to 4 decimals. Run 4: S ~ U(0.4, 0.8) (rounded), P = 1 - S, with
    a randomly chosen loading scaled by S. Draws from the global NumPy RNG.
    """
    loadings = np.random.uniform(0.03, 0.07, size=3)
    s_mix = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mix = 1.0 - s_mix
    reused = loadings[np.random.randint(0, 3)]
    states = [[1.0, 0.0, np.round(c, 4), 0.0, 0.0] for c in loadings]
    states.append([s_mix, p_mix, np.round(reused * s_mix, 4), 0.0, 0.0])
    return states

def test_inits_M6():
    """Return four test initial states [S, P, cat, cat_star, cat_starS] for M6.

    Loadings are drawn from U(0.03, 0.04), U(0.05, 0.06) and U(0.06, 0.07).
    Run 4 is an S/P mixture as in the training protocol.
    """
    loadings = [
        np.random.uniform(0.03, 0.04),
        np.random.uniform(0.05, 0.06),
        np.random.uniform(0.06, 0.07),
    ]
    s_mix = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mix = 1.0 - s_mix
    reused = loadings[np.random.randint(0, 3)]
    states = [[1.0, 0.0, np.round(c, 4), 0.0, 0.0] for c in loadings]
    states.append([s_mix, p_mix, np.round(reused * s_mix, 4), 0.0, 0.0])
    return states

def M6_activecat_post_filter(theta, ode_func):
    """Accept *theta* if, for some loading, 10-90% of the catalyst is active at 20% conversion.

    For each loading in (0.03, 0.05, 0.07) the ODE is integrated with LSODA
    from [1.0, 0.0, loading, 0.0, 0.0] over t in [0, 100] (100 points).
    Failed or NaN runs are skipped. "Active" is index 3 + index 4, read at the
    point whose normalised substrate conversion is closest to 0.2.
    """
    grid = np.linspace(0, 100, 100)
    for loading in (0.03, 0.05, 0.07):
        sol = solve_ivp(
            lambda t, y: ode_func(t, y, theta),
            (0, 100),
            [1.0, 0.0, loading, 0.0, 0.0],
            t_eval=grid,
            method="LSODA",
        )
        if not sol.success or np.isnan(sol.y).any():
            continue
        substrate = sol.y[0]
        progress = (substrate[0] - substrate) / (substrate[0] - substrate[-1] + 1e-10)
        k = np.argmin(np.abs(progress - 0.2))
        active = sol.y[3][k] + sol.y[4][k]
        if 0.1 * loading <= active <= 0.9 * loading:
            return True
    return False

# Register M6 (theta = k1, k_1, k2, k_2, k3) with its active-catalyst post-filter.
ODE_FUNCS.update(M6=ode_M6)
TRAIN_INIT_FUNCS.update(M6=train_inits_M6)
TEST_INIT_FUNCS.update(M6=test_inits_M6)
MECH_DIM.update(M6=5)
POST_FILTER_FUNCS.update(M6=M6_activecat_post_filter)

# In [9]:
def ode_M7(t, y, theta):
    """Right-hand side for M7: catalysis that requires two bound substrates.

    Reactions: cat + S <=> catS (k3, k_3); catS + S <=> catS2 (k1, k_1);
    catS2 <=> catS + P (k2, k_2).
    y = [S, P, cat, catS, catS2], theta = [k1, k_1, k2, k_2, k3, k_3].
    *t* is unused.
    """
    S, P, cat, catS, catS2 = y
    k1, k_1, k2, k_2, k3, k_3 = theta
    first_on = k3 * S * cat     # cat + S -> catS
    first_off = k_3 * catS      # catS -> cat + S
    return [
        k_1 * catS2 - k1 * S * catS + first_off - first_on,
        k2 * catS2 - k_2 * catS * P,
        first_off - first_on,
        first_on - first_off + (k_1 + k2) * catS2 - (k1 * S + k_2 * P) * catS,
        (k1 * S + k_2 * P) * catS - (k_1 + k2) * catS2,
    ]

def train_inits_M7():
    """Return four training initial states [S, P, cat, catS, catS2] for M7.

    Runs 1-3: pure substrate with loadings from U(0.03, 0.07), rounded to 4
    decimals. Run 4: S ~ U(0.4, 0.8) (rounded), P = 1 - S, with a randomly
    chosen loading scaled by S. Draws from the global NumPy RNG.
    """
    loadings = np.random.uniform(0.03, 0.07, size=3)
    s_mix = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mix = 1.0 - s_mix
    reused = loadings[np.random.randint(0, 3)]
    states = [[1.0, 0.0, np.round(c, 4), 0.0, 0.0] for c in loadings]
    states.append([s_mix, p_mix, np.round(reused * s_mix, 4), 0.0, 0.0])
    return states

def test_inits_M7():
    """Return four test initial states [S, P, cat, catS, catS2] for M7.

    Loadings are drawn from U(0.03, 0.04), U(0.05, 0.06) and U(0.06, 0.07).
    Run 4 is an S/P mixture as in the training protocol.
    """
    loadings = [
        np.random.uniform(0.03, 0.04),
        np.random.uniform(0.05, 0.06),
        np.random.uniform(0.06, 0.07),
    ]
    s_mix = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mix = 1.0 - s_mix
    reused = loadings[np.random.randint(0, 3)]
    states = [[1.0, 0.0, np.round(c, 4), 0.0, 0.0] for c in loadings]
    states.append([s_mix, p_mix, np.round(reused * s_mix, 4), 0.0, 0.0])
    return states

def M7_activecat_post_filter(theta, ode_func):
    """Accept *theta* if, for some loading, 10-80% of the catalyst is substrate-bound at 20% conversion.

    For each loading in (0.03, 0.05, 0.07) the ODE is integrated with LSODA
    from [1.0, 0.0, loading, 0.0, 0.0] over t in [0, 100] (100 points).
    Failed or NaN runs are skipped. The bound amount is index 3 + index 4
    (catS + catS2 in ode_M7's layout), read at the point whose normalised
    substrate conversion is closest to 0.2.
    """
    grid = np.linspace(0, 100, 100)
    for loading in (0.03, 0.05, 0.07):
        sol = solve_ivp(
            lambda t, y: ode_func(t, y, theta),
            (0, 100),
            [1.0, 0.0, loading, 0.0, 0.0],
            t_eval=grid,
            method="LSODA",
        )
        if not sol.success or np.isnan(sol.y).any():
            continue
        substrate = sol.y[0]
        progress = (substrate[0] - substrate) / (substrate[0] - substrate[-1] + 1e-10)
        k = np.argmin(np.abs(progress - 0.2))
        active = sol.y[3][k] + sol.y[4][k]
        if 0.1 * loading <= active <= 0.8 * loading:
            return True
    return False

# Register M7 (theta = k1, k_1, k2, k_2, k3, k_3) with its active-catalyst post-filter.
ODE_FUNCS.update(M7=ode_M7)
TRAIN_INIT_FUNCS.update(M7=train_inits_M7)
TEST_INIT_FUNCS.update(M7=test_inits_M7)
MECH_DIM.update(M7=6)
POST_FILTER_FUNCS.update(M7=M7_activecat_post_filter)

# In [17]:
def ode_M8(t, y, theta):
    """Right-hand side for M8: activation by reversible ligand loss, then M1-type turnover.

    Reactions: cat <=> cat* + L (k3 forward, k_3 reverse);
    cat* + S <=> cat*S (k1, k_1); cat*S <=> cat* + P (k2, k_2).
    y = [S, P, cat, cat_star, cat_starS, L],
    theta = [k1, k_1, k2, k_2, k3, k_3]. *t* is unused.
    """
    S, P, cat, cat_star, cat_starS, L = y
    k1, k_1, k2, k_2, k3, k_3 = theta
    shed = k3 * cat                 # cat -> cat* + L
    rebind = k_3 * L * cat_star     # cat* + L -> cat
    return [
        k_1 * cat_starS - k1 * S * cat_star,
        k2 * cat_starS - k_2 * cat_star * P,
        rebind - shed,
        shed - rebind + (k_1 + k2) * cat_starS - (k1 * S + k_2 * P) * cat_star,
        (k1 * S + k_2 * P) * cat_star - (k_1 + k2) * cat_starS,
        shed - rebind,
    ]

def train_inits_M8():
    """Return four training initial states [S, P, cat, cat_star, cat_starS, L] for M8 (L starts at 1.0).

    Runs 1-3: pure substrate with loadings from U(0.03, 0.07), rounded to 4
    decimals. Run 4: S ~ U(0.4, 0.8) (rounded), P = 1 - S, with a randomly
    chosen loading scaled by S. Draws from the global NumPy RNG.
    """
    loadings = np.random.uniform(0.03, 0.07, size=3)
    s_mix = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mix = 1.0 - s_mix
    reused = loadings[np.random.randint(0, 3)]
    states = [[1.0, 0.0, np.round(c, 4), 0.0, 0.0, 1.0] for c in loadings]
    states.append([s_mix, p_mix, np.round(reused * s_mix, 4), 0.0, 0.0, 1.0])
    return states

def test_inits_M8():
    """Return four test initial states [S, P, cat, cat_star, cat_starS, L] for M8 (L starts at 1.0).

    Loadings are drawn from U(0.03, 0.04), U(0.05, 0.06) and U(0.06, 0.07).
    Run 4 is an S/P mixture as in the training protocol.
    """
    loadings = [
        np.random.uniform(0.03, 0.04),
        np.random.uniform(0.05, 0.06),
        np.random.uniform(0.06, 0.07),
    ]
    s_mix = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mix = 1.0 - s_mix
    reused = loadings[np.random.randint(0, 3)]
    states = [[1.0, 0.0, np.round(c, 4), 0.0, 0.0, 1.0] for c in loadings]
    states.append([s_mix, p_mix, np.round(reused * s_mix, 4), 0.0, 0.0, 1.0])
    return states

def M8_activecat_post_filter(theta, ode_func):
    """Accept *theta* if, for some loading, 10-90% of the catalyst is active at 50% conversion.

    For each loading in (0.03, 0.05, 0.07) the ODE is integrated with LSODA
    from [1.0, 0.0, loading, 0.0, 0.0, 1.0] over t in [0, 100] (100 points).
    Failed or NaN runs are skipped. "Active" is index 3 + index 4, read at the
    point whose normalised substrate conversion is closest to 0.5.
    """
    grid = np.linspace(0, 100, 100)
    for loading in (0.03, 0.05, 0.07):
        sol = solve_ivp(
            lambda t, y: ode_func(t, y, theta),
            (0, 100),
            [1.0, 0.0, loading, 0.0, 0.0, 1.0],
            t_eval=grid,
            method="LSODA",
        )
        if not sol.success or np.isnan(sol.y).any():
            continue
        substrate = sol.y[0]
        progress = (substrate[0] - substrate) / (substrate[0] - substrate[-1] + 1e-10)
        k = np.argmin(np.abs(progress - 0.5))
        active = sol.y[3][k] + sol.y[4][k]
        if 0.1 * loading <= active <= 0.9 * loading:
            return True
    return False

# Register M8 (theta = k1, k_1, k2, k_2, k3, k_3) with its active-catalyst post-filter.
ODE_FUNCS.update(M8=ode_M8)
TRAIN_INIT_FUNCS.update(M8=train_inits_M8)
TEST_INIT_FUNCS.update(M8=test_inits_M8)
MECH_DIM.update(M8=6)
POST_FILTER_FUNCS.update(M8=M8_activecat_post_filter)

# In [19]:
def ode_M9(t, y, theta):
    """Right-hand side for M9: M1-type catalysis with first-order deactivation.

    Reactions: S + cat <=> catS (k1, k_1); catS <=> cat + P (k2, k_2);
    cat -> inact_cat (k_3, irreversible).
    y = [S, P, cat, catS, inact_cat], theta = [k1, k_1, k2, k_2, k_3].
    *t* is unused.
    """
    S, P, cat, catS, _inact_cat = y
    k1, k_1, k2, k_2, k_3 = theta
    release = (k_1 + k2) * catS
    return [
        k_1 * catS - k1 * S * cat,
        k2 * catS - k_2 * cat * P,
        release - (k1 * S + k_2 * P + k_3) * cat,
        (k1 * S + k_2 * P) * cat - release,
        k_3 * cat,
    ]

def train_inits_M9():
    """Return four training initial states [S, P, cat, catS, inact_cat] for M9.

    Runs 1-3: pure substrate with loadings from U(0.03, 0.07), rounded to 4
    decimals. Run 4: S ~ U(0.4, 0.8) (rounded), P = 1 - S, with a randomly
    chosen loading scaled by S. Draws from the global NumPy RNG.
    """
    loadings = np.random.uniform(0.03, 0.07, size=3)
    s_mix = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mix = 1.0 - s_mix
    reused = loadings[np.random.randint(0, 3)]
    states = [[1.0, 0.0, np.round(c, 4), 0.0, 0.0] for c in loadings]
    states.append([s_mix, p_mix, np.round(reused * s_mix, 4), 0.0, 0.0])
    return states

def test_inits_M9():
    """Return four test initial states [S, P, cat, catS, inact_cat] for M9.

    Loadings are drawn from U(0.03, 0.04), U(0.05, 0.06) and U(0.06, 0.07).
    Run 4 is an S/P mixture as in the training protocol.
    """
    loadings = [
        np.random.uniform(0.03, 0.04),
        np.random.uniform(0.05, 0.06),
        np.random.uniform(0.06, 0.07),
    ]
    s_mix = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mix = 1.0 - s_mix
    reused = loadings[np.random.randint(0, 3)]
    states = [[1.0, 0.0, np.round(c, 4), 0.0, 0.0] for c in loadings]
    states.append([s_mix, p_mix, np.round(reused * s_mix, 4), 0.0, 0.0])
    return states

def M9_activecat_post_filter(theta, ode_func):
    """Accept *theta* if, for some loading, 50-90% of the catalyst is still active at 50% conversion.

    For each loading in (0.03, 0.05, 0.07) the ODE is integrated with LSODA
    from [1.0, 0.0, loading, 0.0, 0.0] over t in [0, 100] (100 points).
    Failed or NaN runs are skipped. "Active" is index 2 + index 3 (cat + catS),
    read at the point whose normalised substrate conversion is closest to 0.5.
    """
    grid = np.linspace(0, 100, 100)
    for loading in (0.03, 0.05, 0.07):
        sol = solve_ivp(
            lambda t, y: ode_func(t, y, theta),
            (0, 100),
            [1.0, 0.0, loading, 0.0, 0.0],
            t_eval=grid,
            method="LSODA",
        )
        if not sol.success or np.isnan(sol.y).any():
            continue
        substrate = sol.y[0]
        progress = (substrate[0] - substrate) / (substrate[0] - substrate[-1] + 1e-10)
        k = np.argmin(np.abs(progress - 0.5))
        active = sol.y[2][k] + sol.y[3][k]
        if 0.5 * loading <= active <= 0.9 * loading:
            return True
    return False

# Register M9 (theta = k1, k_1, k2, k_2, k_3) with its active-catalyst post-filter.
ODE_FUNCS.update(M9=ode_M9)
TRAIN_INIT_FUNCS.update(M9=train_inits_M9)
TEST_INIT_FUNCS.update(M9=test_inits_M9)
MECH_DIM.update(M9=5)
POST_FILTER_FUNCS.update(M9=M9_activecat_post_filter)

# In [22]:
def ode_M10(t, y, theta):
    """Right-hand side for M10: M1-type catalysis poisoned by an inhibitor.

    Reactions: S + cat <=> catS (k1, k_1); catS <=> cat + P (k2, k_2);
    cat + inhibitor -> inact_catI (k_3, irreversible).
    y = [S, P, cat, catS, inhibitor, inact_catI],
    theta = [k1, k_1, k2, k_2, k_3]. *t* is unused.
    """
    S, P, cat, catS, inhibitor, _inact_catI = y
    k1, k_1, k2, k_2, k_3 = theta
    poisoning = k_3 * inhibitor * cat
    release = (k_1 + k2) * catS
    return [
        k_1 * catS - k1 * S * cat,
        k2 * catS - k_2 * cat * P,
        release - (k1 * S + k_2 * P + k_3 * inhibitor) * cat,
        (k1 * S + k_2 * P) * cat - release,
        -poisoning,
        poisoning,
    ]

def train_inits_M10():
    """Return four training initial states [S, P, cat, catS, inhibitor, inact_catI] for M10.

    The inhibitor starts at 1.0. Runs 1-3: pure substrate with loadings from
    U(0.03, 0.07), rounded to 4 decimals. Run 4: S ~ U(0.4, 0.8) (rounded),
    P = 1 - S, with a randomly chosen loading scaled by S. Draws from the
    global NumPy RNG.
    """
    loadings = np.random.uniform(0.03, 0.07, size=3)
    s_mix = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mix = 1.0 - s_mix
    reused = loadings[np.random.randint(0, 3)]
    states = [[1.0, 0.0, np.round(c, 4), 0.0, 1.0, 0.0] for c in loadings]
    states.append([s_mix, p_mix, np.round(reused * s_mix, 4), 0.0, 1.0, 0.0])
    return states

def test_inits_M10():
    """Return four test initial states [S, P, cat, catS, inhibitor, inact_catI] for M10.

    The inhibitor starts at 1.0. Loadings are drawn from U(0.03, 0.04),
    U(0.05, 0.06) and U(0.06, 0.07). Run 4 is an S/P mixture as in the
    training protocol.
    """
    loadings = [
        np.random.uniform(0.03, 0.04),
        np.random.uniform(0.05, 0.06),
        np.random.uniform(0.06, 0.07),
    ]
    s_mix = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mix = 1.0 - s_mix
    reused = loadings[np.random.randint(0, 3)]
    states = [[1.0, 0.0, np.round(c, 4), 0.0, 1.0, 0.0] for c in loadings]
    states.append([s_mix, p_mix, np.round(reused * s_mix, 4), 0.0, 1.0, 0.0])
    return states

def M10_inactcat_post_filter(theta, ode_func):
    """Accept *theta* if, for some loading, at least 10% of the catalyst is inhibited at 50% conversion.

    For each loading in (0.03, 0.05, 0.07) the ODE is integrated with LSODA
    from [1.0, 0.0, loading, 0.0, 1.0, 0.0] over t in [0, 100] (100 points).
    Failed or NaN runs are skipped. Index 5 (inact_catI in ode_M10's layout)
    is read at the point whose normalised substrate conversion is closest
    to 0.5.
    """
    grid = np.linspace(0, 100, 100)
    for loading in (0.03, 0.05, 0.07):
        sol = solve_ivp(
            lambda t, y: ode_func(t, y, theta),
            (0, 100),
            [1.0, 0.0, loading, 0.0, 1.0, 0.0],
            t_eval=grid,
            method="LSODA",
        )
        if not sol.success or np.isnan(sol.y).any():
            continue
        substrate = sol.y[0]
        progress = (substrate[0] - substrate) / (substrate[0] - substrate[-1] + 1e-10)
        k = np.argmin(np.abs(progress - 0.5))
        if sol.y[5][k] >= 0.1 * loading:
            return True
    return False

# Register M10 (theta = k1, k_1, k2, k_2, k_3) with its inhibited-catalyst post-filter.
ODE_FUNCS.update(M10=ode_M10)
TRAIN_INIT_FUNCS.update(M10=train_inits_M10)
TEST_INIT_FUNCS.update(M10=test_inits_M10)
MECH_DIM.update(M10=5)
POST_FILTER_FUNCS.update(M10=M10_inactcat_post_filter)

# In [24]:
def ode_M11(t, y, theta):
    """Right-hand side for M11: M1-type catalysis with substrate-induced deactivation.

    Reactions: S + cat <=> catS (k1, k_1); catS <=> cat + P (k2, k_2);
    cat + S -> inact_catS (k_3, irreversible).
    y = [S, P, cat, catS, inact_catS], theta = [k1, k_1, k2, k_2, k_3].
    *t* is unused.
    """
    S, P, cat, catS, _inact_catS = y
    k1, k_1, k2, k_2, k_3 = theta
    release = (k_1 + k2) * catS
    return [
        k_1 * catS - (k1 + k_3) * S * cat,
        k2 * catS - k_2 * cat * P,
        release - (k1 * S + k_2 * P + k_3 * S) * cat,
        (k1 * S + k_2 * P) * cat - release,
        k_3 * S * cat,
    ]

def train_inits_M11():
    """Return four training initial states [S, P, cat, catS, inact_catS] for M11.

    Runs 1-3: pure substrate with loadings from U(0.03, 0.07), rounded to 4
    decimals. Run 4: S ~ U(0.4, 0.8) (rounded), P = 1 - S, with a randomly
    chosen loading scaled by S. Draws from the global NumPy RNG.
    """
    loadings = np.random.uniform(0.03, 0.07, size=3)
    s_mix = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mix = 1.0 - s_mix
    reused = loadings[np.random.randint(0, 3)]
    states = [[1.0, 0.0, np.round(c, 4), 0.0, 0.0] for c in loadings]
    states.append([s_mix, p_mix, np.round(reused * s_mix, 4), 0.0, 0.0])
    return states

def test_inits_M11():
    """Return four test initial states [S, P, cat, catS, inact_catS] for M11.

    Loadings are drawn from U(0.03, 0.04), U(0.05, 0.06) and U(0.06, 0.07).
    Run 4 is an S/P mixture as in the training protocol.
    """
    loadings = [
        np.random.uniform(0.03, 0.04),
        np.random.uniform(0.05, 0.06),
        np.random.uniform(0.06, 0.07),
    ]
    s_mix = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mix = 1.0 - s_mix
    reused = loadings[np.random.randint(0, 3)]
    states = [[1.0, 0.0, np.round(c, 4), 0.0, 0.0] for c in loadings]
    states.append([s_mix, p_mix, np.round(reused * s_mix, 4), 0.0, 0.0])
    return states

def M11_activecat_post_filter(theta, ode_func):
    """Accept *theta* if, for some loading, 50-90% of the catalyst is still active at 50% conversion.

    For each loading in (0.03, 0.05, 0.07) the ODE is integrated with LSODA
    from [1.0, 0.0, loading, 0.0, 0.0] over t in [0, 100] (100 points).
    Failed or NaN runs are skipped. "Active" is index 2 + index 3 (cat + catS),
    read at the point whose normalised substrate conversion is closest to 0.5.
    """
    grid = np.linspace(0, 100, 100)
    for loading in (0.03, 0.05, 0.07):
        sol = solve_ivp(
            lambda t, y: ode_func(t, y, theta),
            (0, 100),
            [1.0, 0.0, loading, 0.0, 0.0],
            t_eval=grid,
            method="LSODA",
        )
        if not sol.success or np.isnan(sol.y).any():
            continue
        substrate = sol.y[0]
        progress = (substrate[0] - substrate) / (substrate[0] - substrate[-1] + 1e-10)
        k = np.argmin(np.abs(progress - 0.5))
        active = sol.y[2][k] + sol.y[3][k]
        if 0.5 * loading <= active <= 0.9 * loading:
            return True
    return False

# Register M11 (theta = k1, k_1, k2, k_2, k_3) with its active-catalyst post-filter.
ODE_FUNCS.update(M11=ode_M11)
TRAIN_INIT_FUNCS.update(M11=train_inits_M11)
TEST_INIT_FUNCS.update(M11=test_inits_M11)
MECH_DIM.update(M11=5)
POST_FILTER_FUNCS.update(M11=M11_activecat_post_filter)

# In [26]:
def ode_M12(t, y, theta):
    """Right-hand side for M12: M1-type catalysis with product-induced deactivation.

    Reactions: S + cat <=> catS (k1, k_1); catS <=> cat + P (k2, k_2);
    cat + P -> inact_catP (k_3, irreversible).
    y = [S, P, cat, catS, inact_catP], theta = [k1, k_1, k2, k_2, k_3].
    *t* is unused.
    """
    S, P, cat, catS, _inact_catP = y
    k1, k_1, k2, k_2, k_3 = theta
    release = (k_1 + k2) * catS
    return [
        k_1 * catS - k1 * S * cat,
        k2 * catS - (k_2 + k_3) * P * cat,
        release - (k1 * S + k_2 * P + k_3 * P) * cat,
        (k1 * S + k_2 * P) * cat - release,
        k_3 * P * cat,
    ]

def train_inits_M12():
    """Return four training initial states [S, P, cat, catS, inact_catP] for M12.

    Runs 1-3: pure substrate with loadings from U(0.03, 0.07), rounded to 4
    decimals. Run 4: S ~ U(0.4, 0.8) (rounded), P = 1 - S, with a randomly
    chosen loading scaled by S. Draws from the global NumPy RNG.
    """
    loadings = np.random.uniform(0.03, 0.07, size=3)
    s_mix = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mix = 1.0 - s_mix
    reused = loadings[np.random.randint(0, 3)]
    states = [[1.0, 0.0, np.round(c, 4), 0.0, 0.0] for c in loadings]
    states.append([s_mix, p_mix, np.round(reused * s_mix, 4), 0.0, 0.0])
    return states

def test_inits_M12():
    """Return four test initial states [S, P, cat, catS, inact_catP] for M12.

    Loadings are drawn from U(0.03, 0.04), U(0.05, 0.06) and U(0.06, 0.07).
    Run 4 is an S/P mixture as in the training protocol.
    """
    loadings = [
        np.random.uniform(0.03, 0.04),
        np.random.uniform(0.05, 0.06),
        np.random.uniform(0.06, 0.07),
    ]
    s_mix = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mix = 1.0 - s_mix
    reused = loadings[np.random.randint(0, 3)]
    states = [[1.0, 0.0, np.round(c, 4), 0.0, 0.0] for c in loadings]
    states.append([s_mix, p_mix, np.round(reused * s_mix, 4), 0.0, 0.0])
    return states

def M12_activecat_post_filter(theta, ode_func):
    """Accept *theta* if, for some loading, 50-90% of the catalyst is still active at 50% conversion.

    For each loading in (0.03, 0.05, 0.07) the ODE is integrated with LSODA
    from [1.0, 0.0, loading, 0.0, 0.0] over t in [0, 100] (100 points).
    Failed or NaN runs are skipped. "Active" is index 2 + index 3 (cat + catS),
    read at the point whose normalised substrate conversion is closest to 0.5.
    """
    grid = np.linspace(0, 100, 100)
    for loading in (0.03, 0.05, 0.07):
        sol = solve_ivp(
            lambda t, y: ode_func(t, y, theta),
            (0, 100),
            [1.0, 0.0, loading, 0.0, 0.0],
            t_eval=grid,
            method="LSODA",
        )
        if not sol.success or np.isnan(sol.y).any():
            continue
        substrate = sol.y[0]
        progress = (substrate[0] - substrate) / (substrate[0] - substrate[-1] + 1e-10)
        k = np.argmin(np.abs(progress - 0.5))
        active = sol.y[2][k] + sol.y[3][k]
        if 0.5 * loading <= active <= 0.9 * loading:
            return True
    return False

# Register M12 (theta = k1, k_1, k2, k_2, k_3) with its active-catalyst post-filter.
ODE_FUNCS.update(M12=ode_M12)
TRAIN_INIT_FUNCS.update(M12=train_inits_M12)
TEST_INIT_FUNCS.update(M12=test_inits_M12)
MECH_DIM.update(M12=5)
POST_FILTER_FUNCS.update(M12=M12_activecat_post_filter)

# In [28]:
def ode_M13(t, y, theta):
    """Right-hand side for M13: catalytic turnover with free-catalyst dimerisation.

    cat + S <-> catS (k1 / k_1) and catS <-> cat + P (k2 / k_2).  Free catalyst
    is additionally consumed at 2 * k_3 * cat**2 while inact_cat2 forms at
    k_3 * cat**2.

    Args:
        t: time (unused; the system is autonomous).
        y: state [S, P, cat, catS, inact_cat2].
        theta: rate constants [k1, k_1, k2, k_2, k_3].

    Returns:
        List of the five time derivatives in state order.
    """
    S, P, cat, catS, _inact_cat2 = y
    k1, k_1, k2, k_2, k_3 = theta
    binding = k1 * S + k_2 * P      # pseudo-first-order capture of free cat by S or P
    release = (k_1 + k2) * catS     # catS breaking down back to free cat
    return [
        k_1 * catS - k1 * S * cat,
        k2 * catS - k_2 * cat * P,
        release - (binding + 2 * k_3 * cat) * cat,
        binding * cat - release,
        k_3 * cat**2,
    ]
def train_inits_M13():
    """Return four initial-state vectors [S, P, cat, catS, inact_cat2] for an M13 training sample.

    Runs 1-3 start from S = 1, P = 0 with a catalyst loading drawn uniformly
    from [0.03, 0.07) and rounded to 4 d.p.  Run 4 starts from
    S0 ~ U[0.4, 0.8) (4 d.p.), P0 = 1 - S0, reusing the loading of a randomly
    chosen run 1-3 scaled by S0.  All draws use the global ``np.random`` state.
    """
    loadings = np.random.uniform(0.03, 0.07, size=3)
    s_mid = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mid = 1.0 - s_mid
    ref = np.random.randint(0, 3)

    s_full = 1.0
    runs = [[s_full, 0.0, np.round(c * s_full, 4), 0.0, 0.0] for c in loadings]
    runs.append([s_mid, p_mid, np.round(loadings[ref] * s_mid, 4), 0.0, 0.0])
    return runs

def test_inits_M13():
    """Return four initial-state vectors [S, P, cat, catS, inact_cat2] for an M13 test sample.

    Runs 1-3 start from S = 1, P = 0 with catalyst loadings drawn from the
    separate bands [0.03, 0.04), [0.05, 0.06) and [0.06, 0.07) (rounded to
    4 d.p.).  Run 4 starts from S0 ~ U[0.4, 0.8) (4 d.p.), P0 = 1 - S0,
    reusing a randomly chosen run's loading scaled by S0.  Draws use the
    global ``np.random`` state.
    """
    bands = ((0.03, 0.04), (0.05, 0.06), (0.06, 0.07))
    loadings = [np.random.uniform(lo, hi) for lo, hi in bands]
    s_mid = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mid = 1.0 - s_mid
    ref = np.random.randint(0, 3)

    s_full = 1.0
    runs = [[s_full, 0.0, np.round(c * s_full, 4), 0.0, 0.0] for c in loadings]
    runs.append([s_mid, p_mid, np.round(loadings[ref] * s_mid, 4), 0.0, 0.0])
    return runs

def M13_activecat_post_filter(theta, ode_func):
    """Return True if ``theta`` gives the required level of active catalyst for M13.

    For each nominal loading cat0 in (0.03, 0.05, 0.07), ``ode_func`` is
    integrated with LSODA on t in [0, 100] (100 output points) from
    y0 = [1, 0, cat0, 0, 0].  Failed or NaN-containing runs are skipped.
    A run passes when the active catalyst (cat + catS) at the grid point whose
    conversion is closest to half of the final conversion lies in
    [0.5 * cat0, 0.8 * cat0].  The first passing loading short-circuits the rest.
    """
    grid = np.linspace(0, 100, 100)

    def rhs(t, y):
        return ode_func(t, y, theta)

    def passes(cat0):
        sol = solve_ivp(rhs, (0, 100), [1.0, 0.0, cat0, 0.0, 0.0], t_eval=grid, method="LSODA")
        if not sol.success or np.any(np.isnan(sol.y)):
            return False
        substrate = sol.y[0]
        # Conversion relative to the end point, so 0.5 means "half-way to the final conversion".
        rel_conv = (substrate[0] - substrate) / (substrate[0] - substrate[-1] + 1e-10)
        k = np.argmin(np.abs(rel_conv - 0.5))
        active = sol.y[2][k] + sol.y[3][k]
        return 0.5 * cat0 <= active <= 0.8 * cat0

    return any(passes(c) for c in (0.03, 0.05, 0.07))

# Register M13 in the lookup tables read by generate_simulation_dataset: the ODE
# right-hand side, the train/test initial-condition generators, the dimension of
# the sampled rate-constant vector theta, and the acceptance filter passed to
# filter_theta_pool.
ODE_FUNCS["M13"] = ode_M13
TRAIN_INIT_FUNCS["M13"] = train_inits_M13
TEST_INIT_FUNCS["M13"] = test_inits_M13
MECH_DIM["M13"] = 5
POST_FILTER_FUNCS["M13"] = M13_activecat_post_filter

In [30]:
def ode_M14(t, y, theta):
    """Right-hand side for M14: catalytic turnover with first-order catS deactivation.

    cat + S <-> catS (k1 / k_1), catS <-> cat + P (k2 / k_2) and
    catS -> inact_catS (k_3).

    Args:
        t: time (unused; the system is autonomous).
        y: state [S, P, cat, catS, inact_catS].
        theta: rate constants [k1, k_1, k2, k_2, k_3].

    Returns:
        List of the five time derivatives in state order.
    """
    S, P, cat, catS, _inact_catS = y
    k1, k_1, k2, k_2, k_3 = theta
    binding = k1 * S + k_2 * P      # pseudo-first-order capture of free cat by S or P
    release = (k_1 + k2) * catS     # catS breaking down back to free cat
    return [
        k_1 * catS - k1 * S * cat,
        k2 * catS - k_2 * cat * P,
        release - binding * cat,
        binding * cat - (k_1 + k2 + k_3) * catS,
        k_3 * catS,
    ]

def train_inits_M14():
    """Return four initial-state vectors [S, P, cat, catS, inact_catS] for an M14 training sample.

    Runs 1-3 start from S = 1, P = 0 with a catalyst loading drawn uniformly
    from [0.03, 0.07) and rounded to 4 d.p.  Run 4 starts from
    S0 ~ U[0.4, 0.8) (4 d.p.), P0 = 1 - S0, reusing the loading of a randomly
    chosen run 1-3 scaled by S0.  All draws use the global ``np.random`` state.
    """
    loadings = np.random.uniform(0.03, 0.07, size=3)
    s_mid = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mid = 1.0 - s_mid
    ref = np.random.randint(0, 3)

    s_full = 1.0
    runs = [[s_full, 0.0, np.round(c * s_full, 4), 0.0, 0.0] for c in loadings]
    runs.append([s_mid, p_mid, np.round(loadings[ref] * s_mid, 4), 0.0, 0.0])
    return runs

def test_inits_M14():
    """Return four initial-state vectors [S, P, cat, catS, inact_catS] for an M14 test sample.

    Runs 1-3 start from S = 1, P = 0 with catalyst loadings drawn from the
    separate bands [0.03, 0.04), [0.05, 0.06) and [0.06, 0.07) (rounded to
    4 d.p.).  Run 4 starts from S0 ~ U[0.4, 0.8) (4 d.p.), P0 = 1 - S0,
    reusing a randomly chosen run's loading scaled by S0.  Draws use the
    global ``np.random`` state.
    """
    bands = ((0.03, 0.04), (0.05, 0.06), (0.06, 0.07))
    loadings = [np.random.uniform(lo, hi) for lo, hi in bands]
    s_mid = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mid = 1.0 - s_mid
    ref = np.random.randint(0, 3)

    s_full = 1.0
    runs = [[s_full, 0.0, np.round(c * s_full, 4), 0.0, 0.0] for c in loadings]
    runs.append([s_mid, p_mid, np.round(loadings[ref] * s_mid, 4), 0.0, 0.0])
    return runs

def M14_activecat_post_filter(theta, ode_func):
    """Return True if ``theta`` gives the required level of active catalyst for M14.

    For each nominal loading cat0 in (0.03, 0.05, 0.07), ``ode_func`` is
    integrated with LSODA on t in [0, 100] (100 output points) from
    y0 = [1, 0, cat0, 0, 0].  Failed or NaN-containing runs are skipped.
    A run passes when the active catalyst (cat + catS) at the grid point whose
    conversion is closest to half of the final conversion lies in
    [0.5 * cat0, 0.9 * cat0].  The first passing loading short-circuits the rest.
    """
    grid = np.linspace(0, 100, 100)

    def rhs(t, y):
        return ode_func(t, y, theta)

    def passes(cat0):
        sol = solve_ivp(rhs, (0, 100), [1.0, 0.0, cat0, 0.0, 0.0], t_eval=grid, method="LSODA")
        if not sol.success or np.any(np.isnan(sol.y)):
            return False
        substrate = sol.y[0]
        # Conversion relative to the end point, so 0.5 means "half-way to the final conversion".
        rel_conv = (substrate[0] - substrate) / (substrate[0] - substrate[-1] + 1e-10)
        k = np.argmin(np.abs(rel_conv - 0.5))
        active = sol.y[2][k] + sol.y[3][k]
        return 0.5 * cat0 <= active <= 0.9 * cat0

    return any(passes(c) for c in (0.03, 0.05, 0.07))

# Register M14 in the lookup tables read by generate_simulation_dataset: the ODE
# right-hand side, the train/test initial-condition generators, the dimension of
# the sampled rate-constant vector theta, and the acceptance filter passed to
# filter_theta_pool.
ODE_FUNCS["M14"] = ode_M14
TRAIN_INIT_FUNCS["M14"] = train_inits_M14
TEST_INIT_FUNCS["M14"] = test_inits_M14
MECH_DIM["M14"] = 5
POST_FILTER_FUNCS["M14"] = M14_activecat_post_filter

In [32]:
def ode_M15(t, y, theta):
    """Right-hand side for M15: catalytic turnover with inhibitor trapping of catS.

    cat + S <-> catS (k1 / k_1), catS <-> cat + P (k2 / k_2), and
    catS + inhibitor -> inact_catSI (second order, k_3), consuming inhibitor.

    Args:
        t: time (unused; the system is autonomous).
        y: state [S, P, cat, catS, inhibitor, inact_catSI].
        theta: rate constants [k1, k_1, k2, k_2, k_3].

    Returns:
        List of the six time derivatives in state order.
    """
    S, P, cat, catS, inhibitor, _inact_catSI = y
    k1, k_1, k2, k_2, k_3 = theta
    binding = k1 * S + k_2 * P      # pseudo-first-order capture of free cat by S or P
    release = (k_1 + k2) * catS     # catS breaking down back to free cat
    return [
        k_1 * catS - k1 * S * cat,
        k2 * catS - k_2 * cat * P,
        release - binding * cat,
        binding * cat - (k_1 + k2 + k_3 * inhibitor) * catS,
        -k_3 * inhibitor * catS,
        k_3 * inhibitor * catS,
    ]

def train_inits_M15():
    """Return four initial-state vectors [S, P, cat, catS, inhibitor, inact_catSI] for an M15 training sample.

    Every run starts with inhibitor = 1.0.  Runs 1-3 start from S = 1, P = 0
    with a catalyst loading drawn uniformly from [0.03, 0.07) and rounded to
    4 d.p.  Run 4 starts from S0 ~ U[0.4, 0.8) (4 d.p.), P0 = 1 - S0, reusing
    the loading of a randomly chosen run 1-3 scaled by S0.  All draws use the
    global ``np.random`` state.
    """
    loadings = np.random.uniform(0.03, 0.07, size=3)
    s_mid = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mid = 1.0 - s_mid
    ref = np.random.randint(0, 3)

    s_full = 1.0
    inhibitor0 = 1.0
    runs = [[s_full, 0.0, np.round(c * s_full, 4), 0.0, inhibitor0, 0.0] for c in loadings]
    runs.append([s_mid, p_mid, np.round(loadings[ref] * s_mid, 4), 0.0, inhibitor0, 0.0])
    return runs

def test_inits_M15():
    """Return four initial-state vectors [S, P, cat, catS, inhibitor, inact_catSI] for an M15 test sample.

    Every run starts with inhibitor = 1.0.  Runs 1-3 start from S = 1, P = 0
    with catalyst loadings drawn from the separate bands [0.03, 0.04),
    [0.05, 0.06) and [0.06, 0.07) (rounded to 4 d.p.).  Run 4 starts from
    S0 ~ U[0.4, 0.8) (4 d.p.), P0 = 1 - S0, reusing a randomly chosen run's
    loading scaled by S0.  Draws use the global ``np.random`` state.
    """
    bands = ((0.03, 0.04), (0.05, 0.06), (0.06, 0.07))
    loadings = [np.random.uniform(lo, hi) for lo, hi in bands]
    s_mid = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mid = 1.0 - s_mid
    ref = np.random.randint(0, 3)

    s_full = 1.0
    inhibitor0 = 1.0
    runs = [[s_full, 0.0, np.round(c * s_full, 4), 0.0, inhibitor0, 0.0] for c in loadings]
    runs.append([s_mid, p_mid, np.round(loadings[ref] * s_mid, 4), 0.0, inhibitor0, 0.0])
    return runs

def M15_inactcat_post_filter(theta, ode_func):
    """Return True if ``theta`` produces enough inhibitor-trapped catalyst for M15.

    For each nominal loading cat0 in (0.03, 0.05, 0.07), ``ode_func`` is
    integrated with LSODA on t in [0, 100] (100 output points) from
    y0 = [1, 0, cat0, 0, 1, 0].  Failed or NaN-containing runs are skipped.
    A run passes when state 5 (inact_catSI), at the grid point whose conversion
    is closest to half of the final conversion, exceeds 0.1 * cat0.  The first
    passing loading short-circuits the rest.
    """
    grid = np.linspace(0, 100, 100)

    def rhs(t, y):
        return ode_func(t, y, theta)

    def passes(cat0):
        sol = solve_ivp(rhs, (0, 100), [1.0, 0.0, cat0, 0.0, 1.0, 0.0], t_eval=grid, method="LSODA")
        if not sol.success or np.any(np.isnan(sol.y)):
            return False
        substrate = sol.y[0]
        # Conversion relative to the end point, so 0.5 means "half-way to the final conversion".
        rel_conv = (substrate[0] - substrate) / (substrate[0] - substrate[-1] + 1e-10)
        k = np.argmin(np.abs(rel_conv - 0.5))
        return sol.y[5][k] > 0.1 * cat0

    return any(passes(c) for c in (0.03, 0.05, 0.07))

# Register M15 in the lookup tables read by generate_simulation_dataset: the ODE
# right-hand side, the train/test initial-condition generators, the dimension of
# the sampled rate-constant vector theta (5 constants; the state itself has 6
# entries), and the acceptance filter passed to filter_theta_pool.
ODE_FUNCS["M15"] = ode_M15
TRAIN_INIT_FUNCS["M15"] = train_inits_M15
TEST_INIT_FUNCS["M15"] = test_inits_M15
MECH_DIM["M15"] = 5
POST_FILTER_FUNCS["M15"] = M15_inactcat_post_filter

In [34]:
def ode_M16(t, y, theta):
    """Right-hand side for M16: catalytic turnover with substrate-induced catS deactivation.

    cat + S <-> catS (k1 / k_1), catS <-> cat + P (k2 / k_2), and
    catS + S -> inact_catS2 (k_3 * S * catS), which also consumes S.

    Args:
        t: time (unused; the system is autonomous).
        y: state [S, P, cat, catS, inact_catS2].
        theta: rate constants [k1, k_1, k2, k_2, k_3].

    Returns:
        List of the five time derivatives in state order.
    """
    S, P, cat, catS, _inact_catS2 = y
    k1, k_1, k2, k_2, k_3 = theta
    binding = k1 * S + k_2 * P      # pseudo-first-order capture of free cat by S or P
    release = (k_1 + k2) * catS     # catS breaking down back to free cat
    return [
        k_1 * catS - k1 * S * cat - k_3 * S * catS,
        k2 * catS - k_2 * cat * P,
        release - binding * cat,
        binding * cat - (k_1 + k2 + k_3 * S) * catS,
        k_3 * S * catS,
    ]

def train_inits_M16():
    """Return four initial-state vectors [S, P, cat, catS, inact_catS2] for an M16 training sample.

    Runs 1-3 start from S = 1, P = 0 with a catalyst loading drawn uniformly
    from [0.03, 0.07) and rounded to 4 d.p.  Run 4 starts from
    S0 ~ U[0.4, 0.8) (4 d.p.), P0 = 1 - S0, reusing the loading of a randomly
    chosen run 1-3 scaled by S0.  All draws use the global ``np.random`` state.
    """
    loadings = np.random.uniform(0.03, 0.07, size=3)
    s_mid = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mid = 1.0 - s_mid
    ref = np.random.randint(0, 3)

    s_full = 1.0
    runs = [[s_full, 0.0, np.round(c * s_full, 4), 0.0, 0.0] for c in loadings]
    runs.append([s_mid, p_mid, np.round(loadings[ref] * s_mid, 4), 0.0, 0.0])
    return runs

def test_inits_M16():
    """Return four initial-state vectors [S, P, cat, catS, inact_catS2] for an M16 test sample.

    Runs 1-3 start from S = 1, P = 0 with catalyst loadings drawn from the
    separate bands [0.03, 0.04), [0.05, 0.06) and [0.06, 0.07) (rounded to
    4 d.p.).  Run 4 starts from S0 ~ U[0.4, 0.8) (4 d.p.), P0 = 1 - S0,
    reusing a randomly chosen run's loading scaled by S0.  Draws use the
    global ``np.random`` state.
    """
    bands = ((0.03, 0.04), (0.05, 0.06), (0.06, 0.07))
    loadings = [np.random.uniform(lo, hi) for lo, hi in bands]
    s_mid = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mid = 1.0 - s_mid
    ref = np.random.randint(0, 3)

    s_full = 1.0
    runs = [[s_full, 0.0, np.round(c * s_full, 4), 0.0, 0.0] for c in loadings]
    runs.append([s_mid, p_mid, np.round(loadings[ref] * s_mid, 4), 0.0, 0.0])
    return runs

def M16_activecat_post_filter(theta, ode_func):
    """Return True if ``theta`` gives the required level of active catalyst for M16.

    For each nominal loading cat0 in (0.03, 0.05, 0.07), ``ode_func`` is
    integrated with LSODA on t in [0, 100] (100 output points) from
    y0 = [1, 0, cat0, 0, 0].  Failed or NaN-containing runs are skipped.
    A run passes when the active catalyst (cat + catS) at the grid point whose
    conversion is closest to half of the final conversion lies in
    [0.5 * cat0, 0.9 * cat0].  The first passing loading short-circuits the rest.
    """
    grid = np.linspace(0, 100, 100)

    def rhs(t, y):
        return ode_func(t, y, theta)

    def passes(cat0):
        sol = solve_ivp(rhs, (0, 100), [1.0, 0.0, cat0, 0.0, 0.0], t_eval=grid, method="LSODA")
        if not sol.success or np.any(np.isnan(sol.y)):
            return False
        substrate = sol.y[0]
        # Conversion relative to the end point, so 0.5 means "half-way to the final conversion".
        rel_conv = (substrate[0] - substrate) / (substrate[0] - substrate[-1] + 1e-10)
        k = np.argmin(np.abs(rel_conv - 0.5))
        active = sol.y[2][k] + sol.y[3][k]
        return 0.5 * cat0 <= active <= 0.9 * cat0

    return any(passes(c) for c in (0.03, 0.05, 0.07))

# Register M16 in the lookup tables read by generate_simulation_dataset: the ODE
# right-hand side, the train/test initial-condition generators, the dimension of
# the sampled rate-constant vector theta, and the acceptance filter passed to
# filter_theta_pool.
ODE_FUNCS["M16"] = ode_M16
TRAIN_INIT_FUNCS["M16"] = train_inits_M16
TEST_INIT_FUNCS["M16"] = test_inits_M16
MECH_DIM["M16"] = 5
POST_FILTER_FUNCS["M16"] = M16_activecat_post_filter

In [36]:
def ode_M17(t, y, theta):
    """Right-hand side for M17: catalytic turnover with product-induced catS deactivation.

    cat + S <-> catS (k1 / k_1), catS <-> cat + P (k2 / k_2), and
    catS + P -> inact_catSP (k_3 * P * catS), which also consumes P.

    Args:
        t: time (unused; the system is autonomous).
        y: state [S, P, cat, catS, inact_catSP].
        theta: rate constants [k1, k_1, k2, k_2, k_3].

    Returns:
        List of the five time derivatives in state order.
    """
    S, P, cat, catS, _inact_catSP = y
    k1, k_1, k2, k_2, k_3 = theta
    binding = k1 * S + k_2 * P      # pseudo-first-order capture of free cat by S or P
    release = (k_1 + k2) * catS     # catS breaking down back to free cat
    return [
        k_1 * catS - k1 * S * cat,
        k2 * catS - k_2 * cat * P - k_3 * P * catS,
        release - binding * cat,
        binding * cat - (k_1 + k2 + k_3 * P) * catS,
        k_3 * P * catS,
    ]

def train_inits_M17():
    """Return four initial-state vectors [S, P, cat, catS, inact_catSP] for an M17 training sample.

    Runs 1-3 start from S = 1, P = 0 with a catalyst loading drawn uniformly
    from [0.03, 0.07) and rounded to 4 d.p.  Run 4 starts from
    S0 ~ U[0.4, 0.8) (4 d.p.), P0 = 1 - S0, reusing the loading of a randomly
    chosen run 1-3 scaled by S0.  All draws use the global ``np.random`` state.
    """
    loadings = np.random.uniform(0.03, 0.07, size=3)
    s_mid = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mid = 1.0 - s_mid
    ref = np.random.randint(0, 3)

    s_full = 1.0
    runs = [[s_full, 0.0, np.round(c * s_full, 4), 0.0, 0.0] for c in loadings]
    runs.append([s_mid, p_mid, np.round(loadings[ref] * s_mid, 4), 0.0, 0.0])
    return runs

def test_inits_M17():
    """Return four initial-state vectors [S, P, cat, catS, inact_catSP] for an M17 test sample.

    Runs 1-3 start from S = 1, P = 0 with catalyst loadings drawn from the
    separate bands [0.03, 0.04), [0.05, 0.06) and [0.06, 0.07) (rounded to
    4 d.p.).  Run 4 starts from S0 ~ U[0.4, 0.8) (4 d.p.), P0 = 1 - S0,
    reusing a randomly chosen run's loading scaled by S0.  Draws use the
    global ``np.random`` state.
    """
    bands = ((0.03, 0.04), (0.05, 0.06), (0.06, 0.07))
    loadings = [np.random.uniform(lo, hi) for lo, hi in bands]
    s_mid = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mid = 1.0 - s_mid
    ref = np.random.randint(0, 3)

    s_full = 1.0
    runs = [[s_full, 0.0, np.round(c * s_full, 4), 0.0, 0.0] for c in loadings]
    runs.append([s_mid, p_mid, np.round(loadings[ref] * s_mid, 4), 0.0, 0.0])
    return runs

def M17_activecat_post_filter(theta, ode_func):
    """Return True if ``theta`` gives the required level of active catalyst for M17.

    For each nominal loading cat0 in (0.03, 0.05, 0.07), ``ode_func`` is
    integrated with LSODA on t in [0, 100] (100 output points) from
    y0 = [1, 0, cat0, 0, 0].  Failed or NaN-containing runs are skipped.
    A run passes when the active catalyst (cat + catS) at the grid point whose
    conversion is closest to half of the final conversion lies in
    [0.5 * cat0, 0.9 * cat0].  The first passing loading short-circuits the rest.
    """
    grid = np.linspace(0, 100, 100)

    def rhs(t, y):
        return ode_func(t, y, theta)

    def passes(cat0):
        sol = solve_ivp(rhs, (0, 100), [1.0, 0.0, cat0, 0.0, 0.0], t_eval=grid, method="LSODA")
        if not sol.success or np.any(np.isnan(sol.y)):
            return False
        substrate = sol.y[0]
        # Conversion relative to the end point, so 0.5 means "half-way to the final conversion".
        rel_conv = (substrate[0] - substrate) / (substrate[0] - substrate[-1] + 1e-10)
        k = np.argmin(np.abs(rel_conv - 0.5))
        active = sol.y[2][k] + sol.y[3][k]
        return 0.5 * cat0 <= active <= 0.9 * cat0

    return any(passes(c) for c in (0.03, 0.05, 0.07))

# Register M17 in the lookup tables read by generate_simulation_dataset: the ODE
# right-hand side, the train/test initial-condition generators, the dimension of
# the sampled rate-constant vector theta, and the acceptance filter passed to
# filter_theta_pool.
ODE_FUNCS["M17"] = ode_M17
TRAIN_INIT_FUNCS["M17"] = train_inits_M17
TEST_INIT_FUNCS["M17"] = test_inits_M17
MECH_DIM["M17"] = 5
POST_FILTER_FUNCS["M17"] = M17_activecat_post_filter

In [38]:
def ode_M18(t, y, theta):
    """Right-hand side for M18: catalytic turnover with catS dimerisation (2 catS -> cat2S2).

    cat + S <-> catS (k1 / k_1), catS <-> cat + P (k2 / k_2), and two catS
    combine into the inactive dimer inact_cat2S2 with rate k_3 * catS**2.

    Args:
        t: time (unused; the system is autonomous).
        y: state [S, P, cat, catS, inact_cat2S2].
        theta: rate constants [k1, k_1, k2, k_2, k_3].

    Returns:
        List of the five time derivatives in state order.  Total catalyst
        (cat + catS + 2 * inact_cat2S2) and total substrate
        (S + P + catS + 2 * inact_cat2S2) are conserved.
    """
    S, P, cat, catS, inact_cat2S2 = y
    k1, k_1, k2, k_2, k_3 = theta
    dS_dt    = k_1 * catS - k1 * S * cat
    dP_dt    = k2 * catS - k_2 * cat * P
    dcat_dt  = (k_1 + k2) * catS - (k1 * S + k_2 * P) * cat
    # Each dimerisation event consumes TWO catS, so catS is lost at
    # 2 * k_3 * catS**2 while the dimer forms at k_3 * catS**2 (the same 2:1
    # stoichiometry ode_M13 uses for free-catalyst dimerisation).  A 1:1 loss
    # would create catalyst and substrate mass out of nothing.
    dcatS_dt = (k1 * S + k_2 * P) * cat - (k_1 + k2 + 2 * k_3 * catS) * catS
    dinact_cat2S2_dt = k_3 * catS ** 2
    return [dS_dt, dP_dt, dcat_dt, dcatS_dt, dinact_cat2S2_dt]

def train_inits_M18():
    """Return four initial-state vectors [S, P, cat, catS, inact_cat2S2] for an M18 training sample.

    Runs 1-3 start from S = 1, P = 0 with a catalyst loading drawn uniformly
    from [0.03, 0.07) and rounded to 4 d.p.  Run 4 starts from
    S0 ~ U[0.4, 0.8) (4 d.p.), P0 = 1 - S0, reusing the loading of a randomly
    chosen run 1-3 scaled by S0.  All draws use the global ``np.random`` state.
    """
    loadings = np.random.uniform(0.03, 0.07, size=3)
    s_mid = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mid = 1.0 - s_mid
    ref = np.random.randint(0, 3)

    s_full = 1.0
    runs = [[s_full, 0.0, np.round(c * s_full, 4), 0.0, 0.0] for c in loadings]
    runs.append([s_mid, p_mid, np.round(loadings[ref] * s_mid, 4), 0.0, 0.0])
    return runs

def test_inits_M18():
    """Return four initial-state vectors [S, P, cat, catS, inact_cat2S2] for an M18 test sample.

    Runs 1-3 start from S = 1, P = 0 with catalyst loadings drawn from the
    separate bands [0.03, 0.04), [0.05, 0.06) and [0.06, 0.07) (rounded to
    4 d.p.).  Run 4 starts from S0 ~ U[0.4, 0.8) (4 d.p.), P0 = 1 - S0,
    reusing a randomly chosen run's loading scaled by S0.  Draws use the
    global ``np.random`` state.
    """
    bands = ((0.03, 0.04), (0.05, 0.06), (0.06, 0.07))
    loadings = [np.random.uniform(lo, hi) for lo, hi in bands]
    s_mid = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mid = 1.0 - s_mid
    ref = np.random.randint(0, 3)

    s_full = 1.0
    runs = [[s_full, 0.0, np.round(c * s_full, 4), 0.0, 0.0] for c in loadings]
    runs.append([s_mid, p_mid, np.round(loadings[ref] * s_mid, 4), 0.0, 0.0])
    return runs

def M18_activecat_post_filter(theta, ode_func):
    """Return True if ``theta`` gives the required level of active catalyst for M18.

    For each nominal loading cat0 in (0.03, 0.05, 0.07), ``ode_func`` is
    integrated with LSODA on t in [0, 100] (100 output points) from
    y0 = [1, 0, cat0, 0, 0].  Failed or NaN-containing runs are skipped.
    A run passes when the active catalyst (cat + catS) at the grid point whose
    conversion is closest to half of the final conversion lies in
    [0.5 * cat0, 0.8 * cat0].  The first passing loading short-circuits the rest.
    """
    grid = np.linspace(0, 100, 100)

    def rhs(t, y):
        return ode_func(t, y, theta)

    def passes(cat0):
        sol = solve_ivp(rhs, (0, 100), [1.0, 0.0, cat0, 0.0, 0.0], t_eval=grid, method="LSODA")
        if not sol.success or np.any(np.isnan(sol.y)):
            return False
        substrate = sol.y[0]
        # Conversion relative to the end point, so 0.5 means "half-way to the final conversion".
        rel_conv = (substrate[0] - substrate) / (substrate[0] - substrate[-1] + 1e-10)
        k = np.argmin(np.abs(rel_conv - 0.5))
        active = sol.y[2][k] + sol.y[3][k]
        return 0.5 * cat0 <= active <= 0.8 * cat0

    return any(passes(c) for c in (0.03, 0.05, 0.07))

# Register M18 in the lookup tables read by generate_simulation_dataset: the ODE
# right-hand side, the train/test initial-condition generators, the dimension of
# the sampled rate-constant vector theta, and the acceptance filter passed to
# filter_theta_pool.
ODE_FUNCS["M18"] = ode_M18
TRAIN_INIT_FUNCS["M18"] = train_inits_M18
TEST_INIT_FUNCS["M18"] = test_inits_M18
MECH_DIM["M18"] = 5
POST_FILTER_FUNCS["M18"] = M18_activecat_post_filter

In [40]:
def ode_M19(t, y, theta):
    """Right-hand side for M19: catalytic turnover with cat + catS -> inact_cat2S.

    cat + S <-> catS (k1 / k_1), catS <-> cat + P (k2 / k_2), and free
    catalyst combining with catS into inact_cat2S at k_3 * cat * catS.

    Args:
        t: time (unused; the system is autonomous).
        y: state [S, P, cat, catS, inact_cat2S].
        theta: rate constants [k1, k_1, k2, k_2, k_3].

    Returns:
        List of the five time derivatives in state order.
    """
    S, P, cat, catS, _inact_cat2S = y
    k1, k_1, k2, k_2, k_3 = theta
    binding = k1 * S + k_2 * P      # pseudo-first-order capture of free cat by S or P
    release = (k_1 + k2) * catS     # catS breaking down back to free cat
    return [
        k_1 * catS - k1 * S * cat,
        k2 * catS - k_2 * cat * P,
        release - (binding + k_3 * catS) * cat,
        binding * cat - (k_1 + k2 + k_3 * cat) * catS,
        k_3 * cat * catS,
    ]

def train_inits_M19():
    """Return four initial-state vectors [S, P, cat, catS, inact_cat2S] for an M19 training sample.

    Runs 1-3 start from S = 1, P = 0 with a catalyst loading drawn uniformly
    from [0.03, 0.07) and rounded to 4 d.p.  Run 4 starts from
    S0 ~ U[0.4, 0.8) (4 d.p.), P0 = 1 - S0, reusing the loading of a randomly
    chosen run 1-3 scaled by S0.  All draws use the global ``np.random`` state.
    """
    loadings = np.random.uniform(0.03, 0.07, size=3)
    s_mid = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mid = 1.0 - s_mid
    ref = np.random.randint(0, 3)

    s_full = 1.0
    runs = [[s_full, 0.0, np.round(c * s_full, 4), 0.0, 0.0] for c in loadings]
    runs.append([s_mid, p_mid, np.round(loadings[ref] * s_mid, 4), 0.0, 0.0])
    return runs

def test_inits_M19():
    """Return four initial-state vectors [S, P, cat, catS, inact_cat2S] for an M19 test sample.

    Runs 1-3 start from S = 1, P = 0 with catalyst loadings drawn from the
    separate bands [0.03, 0.04), [0.05, 0.06) and [0.06, 0.07) (rounded to
    4 d.p.).  Run 4 starts from S0 ~ U[0.4, 0.8) (4 d.p.), P0 = 1 - S0,
    reusing a randomly chosen run's loading scaled by S0.  Draws use the
    global ``np.random`` state.
    """
    bands = ((0.03, 0.04), (0.05, 0.06), (0.06, 0.07))
    loadings = [np.random.uniform(lo, hi) for lo, hi in bands]
    s_mid = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mid = 1.0 - s_mid
    ref = np.random.randint(0, 3)

    s_full = 1.0
    runs = [[s_full, 0.0, np.round(c * s_full, 4), 0.0, 0.0] for c in loadings]
    runs.append([s_mid, p_mid, np.round(loadings[ref] * s_mid, 4), 0.0, 0.0])
    return runs

def M19_activecat_post_filter(theta, ode_func):
    """Return True if ``theta`` gives the required level of active catalyst for M19.

    For each nominal loading cat0 in (0.03, 0.05, 0.07), ``ode_func`` is
    integrated with LSODA on t in [0, 100] (100 output points) from
    y0 = [1, 0, cat0, 0, 0].  Failed or NaN-containing runs are skipped.
    A run passes when the active catalyst (cat + catS) at the grid point whose
    conversion is closest to half of the final conversion lies in
    [0.5 * cat0, 0.8 * cat0].  The first passing loading short-circuits the rest.
    """
    grid = np.linspace(0, 100, 100)

    def rhs(t, y):
        return ode_func(t, y, theta)

    def passes(cat0):
        sol = solve_ivp(rhs, (0, 100), [1.0, 0.0, cat0, 0.0, 0.0], t_eval=grid, method="LSODA")
        if not sol.success or np.any(np.isnan(sol.y)):
            return False
        substrate = sol.y[0]
        # Conversion relative to the end point, so 0.5 means "half-way to the final conversion".
        rel_conv = (substrate[0] - substrate) / (substrate[0] - substrate[-1] + 1e-10)
        k = np.argmin(np.abs(rel_conv - 0.5))
        active = sol.y[2][k] + sol.y[3][k]
        return 0.5 * cat0 <= active <= 0.8 * cat0

    return any(passes(c) for c in (0.03, 0.05, 0.07))

# Register M19 in the lookup tables read by generate_simulation_dataset: the ODE
# right-hand side, the train/test initial-condition generators, the dimension of
# the sampled rate-constant vector theta, and the acceptance filter passed to
# filter_theta_pool.
ODE_FUNCS["M19"] = ode_M19
TRAIN_INIT_FUNCS["M19"] = train_inits_M19
TEST_INIT_FUNCS["M19"] = test_inits_M19
MECH_DIM["M19"] = 5
POST_FILTER_FUNCS["M19"] = M19_activecat_post_filter

In [42]:
def ode_M20(t, y, theta):
    """Right-hand side for M20: catalytic turnover with two first-order deactivation paths.

    cat + S <-> catS (k1 / k_1), catS <-> cat + P (k2 / k_2),
    cat -> inact_cat (k_3) and catS -> inact_catS (k_4).

    Args:
        t: time (unused; the system is autonomous).
        y: state [S, P, cat, catS, inact_cat, inact_catS].
        theta: rate constants [k1, k_1, k2, k_2, k_3, k_4].

    Returns:
        List of the six time derivatives in state order.
    """
    S, P, cat, catS, _inact_cat, _inact_catS = y
    k1, k_1, k2, k_2, k_3, k_4 = theta
    binding = k1 * S + k_2 * P      # pseudo-first-order capture of free cat by S or P
    release = (k_1 + k2) * catS     # catS breaking down back to free cat
    return [
        k_1 * catS - k1 * S * cat,
        k2 * catS - k_2 * cat * P,
        release - (binding + k_3) * cat,
        binding * cat - (k_1 + k2 + k_4) * catS,
        k_3 * cat,
        k_4 * catS,
    ]

def train_inits_M20():
    """Return four initial-state vectors [S, P, cat, catS, inact_cat, inact_catS] for an M20 training sample.

    Runs 1-3 start from S = 1, P = 0 with a catalyst loading drawn uniformly
    from [0.03, 0.07) and rounded to 4 d.p.  Run 4 starts from
    S0 ~ U[0.4, 0.8) (4 d.p.), P0 = 1 - S0, reusing the loading of a randomly
    chosen run 1-3 scaled by S0.  All draws use the global ``np.random`` state.
    """
    loadings = np.random.uniform(0.03, 0.07, size=3)
    s_mid = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mid = 1.0 - s_mid
    ref = np.random.randint(0, 3)

    s_full = 1.0
    runs = [[s_full, 0.0, np.round(c * s_full, 4), 0.0, 0.0, 0.0] for c in loadings]
    runs.append([s_mid, p_mid, np.round(loadings[ref] * s_mid, 4), 0.0, 0.0, 0.0])
    return runs

def test_inits_M20():
    """Return four initial-state vectors [S, P, cat, catS, inact_cat, inact_catS] for an M20 test sample.

    Runs 1-3 start from S = 1, P = 0 with catalyst loadings drawn from the
    separate bands [0.03, 0.04), [0.05, 0.06) and [0.06, 0.07) (rounded to
    4 d.p.).  Run 4 starts from S0 ~ U[0.4, 0.8) (4 d.p.), P0 = 1 - S0,
    reusing a randomly chosen run's loading scaled by S0.  Draws use the
    global ``np.random`` state.
    """
    bands = ((0.03, 0.04), (0.05, 0.06), (0.06, 0.07))
    loadings = [np.random.uniform(lo, hi) for lo, hi in bands]
    s_mid = np.round(np.random.uniform(0.4, 0.8), 4)
    p_mid = 1.0 - s_mid
    ref = np.random.randint(0, 3)

    s_full = 1.0
    runs = [[s_full, 0.0, np.round(c * s_full, 4), 0.0, 0.0, 0.0] for c in loadings]
    runs.append([s_mid, p_mid, np.round(loadings[ref] * s_mid, 4), 0.0, 0.0, 0.0])
    return runs

def M20_post_filter(theta, ode_func):
    """Return True if ``theta`` shows both deactivation pathways of M20 at work.

    For each nominal loading cat0 in (0.03, 0.05, 0.07), ``ode_func`` is
    integrated with LSODA on t in [0, 100] (100 output points) from
    y0 = [1, 0, cat0, 0, 0, 0].  Failed or NaN-containing runs are skipped.
    At the grid point whose conversion is closest to half of the final
    conversion, a run passes when all of these hold:
      * active catalyst (cat + catS) lies in [0.5 * cat0, 0.9 * cat0];
      * inact_cat (state 4) exceeds 0.05 * cat0;
      * inact_catS (state 5) exceeds 0.05 * cat0.
    The first passing loading short-circuits the rest.
    """
    grid = np.linspace(0, 100, 100)

    def rhs(t, y):
        return ode_func(t, y, theta)

    def passes(cat0):
        sol = solve_ivp(rhs, (0, 100), [1.0, 0.0, cat0, 0.0, 0.0, 0.0], t_eval=grid, method="LSODA")
        if not sol.success or np.any(np.isnan(sol.y)):
            return False
        substrate = sol.y[0]
        # Conversion relative to the end point, so 0.5 means "half-way to the final conversion".
        rel_conv = (substrate[0] - substrate) / (substrate[0] - substrate[-1] + 1e-10)
        k = np.argmin(np.abs(rel_conv - 0.5))
        active = sol.y[2][k] + sol.y[3][k]
        dead_free = sol.y[4][k]
        dead_bound = sol.y[5][k]
        return (0.5 * cat0 <= active <= 0.9 * cat0) and (dead_free > 0.05 * cat0) and (dead_bound > 0.05 * cat0)

    return any(passes(c) for c in (0.03, 0.05, 0.07))

# Register M20 in the lookup tables read by generate_simulation_dataset: the ODE
# right-hand side, the train/test initial-condition generators, the dimension of
# the sampled rate-constant vector theta (6 constants: k1, k_1, k2, k_2, k_3,
# k_4), and the acceptance filter passed to filter_theta_pool.
ODE_FUNCS["M20"] = ode_M20
TRAIN_INIT_FUNCS["M20"] = train_inits_M20
TEST_INIT_FUNCS["M20"] = test_inits_M20
MECH_DIM["M20"] = 6
POST_FILTER_FUNCS["M20"] = M20_post_filter

In [None]:
def reformat_x2_for_paper(x2_array, t_list, n_points=21, n_profiles=4):
    """Lay out each sample's kinetic profiles side by side as [t, S, P] column triples.

    Each sample ``x2_array[i]`` stacks ``n_profiles`` profiles of ``n_points``
    rows along axis 0; only its first two columns (S and P) are kept.
    ``t_list[i]`` holds the matching time stamps and may be a list or an array.

    Args:
        x2_array: sequence of per-sample 2-D arrays (or nested lists).
        t_list: per-sample time stamps, indexed in parallel with ``x2_array``.
        n_points: rows per profile.
        n_profiles: number of consecutive profiles per sample.

    Returns:
        Array of shape ``(len(x2_array), n_points, 3 * n_profiles)`` whose
        columns are ``[t, S, P]`` for profile 0, then profile 1, and so on.

    Raises:
        ValueError: if ``x2_array`` is empty (nothing to stack).
    """
    blocks = []
    for i, x2_sample in enumerate(x2_array):
        x2_sample = np.asarray(x2_sample)
        # Time stamps may arrive as plain lists, which have no .reshape.
        t_sample = np.asarray(t_list[i])
        profiles = []
        for j in range(n_profiles):
            rows = slice(j * n_points, (j + 1) * n_points)
            # Named t_col rather than `time` so the imported time module is not shadowed.
            t_col = t_sample[rows].reshape(-1, 1)
            profiles.append(np.hstack([t_col, x2_sample[rows, :2]]))
        blocks.append(np.hstack(profiles))
    return np.stack(blocks)

def _dump_pickle(obj, output_dir, filename):
    """Pickle ``obj`` to ``os.path.join(output_dir, filename)``, closing the file on exit."""
    with open(os.path.join(output_dir, filename), 'wb') as f:
        pickle.dump(obj, f)


def _subsample_test_profiles(x2_test, t_test_all, n_pts, noise_pct, n_points=21, n_profiles=4):
    """Build the sparse (optionally noisy) test-set x2 array for one setting.

    For every sample, ``n_pts`` distinct rows are drawn without replacement
    (then sorted) from each of the ``n_profiles`` consecutive
    ``n_points``-row profiles.  The matching time stamps and the first two x2
    columns (S, P) are kept.  When ``noise_pct`` > 0, Gaussian noise with
    std = ``noise_pct`` % of the largest selected S/P value of that profile is
    added.  Draws come from the global ``np.random`` state in
    sample -> profile order.

    Returns:
        Array of shape (n_samples, n_pts, 3 * n_profiles) with [t, S, P]
        column triples per profile.
    """
    test_profiles = []
    for i in range(x2_test.shape[0]):
        t_sample = np.asarray(t_test_all[i])
        profiles = []
        for j in range(n_profiles):
            idx_full = np.arange(j * n_points, (j + 1) * n_points)
            idx = np.sort(np.random.choice(idx_full, n_pts, replace=False))
            # Named t_col rather than `time` so the imported time module is not shadowed.
            t_col = t_sample[idx].reshape(-1, 1)
            SP = x2_test[i][idx, :2]
            if noise_pct > 0:
                std = noise_pct / 100 * np.max(SP)
                SP = SP + np.random.normal(0, std, size=SP.shape)
            profiles.append(np.hstack([t_col, SP]))
        test_profiles.append(np.hstack(profiles))
    return np.stack(test_profiles)


def generate_simulation_dataset(output_dir='simulation_data'):
    """Simulate every registered mechanism and pickle train/validation/test arrays.

    For each mechanism in MECHANISMS that has an entry in ODE_FUNCS, a pool of
    N_TRAIN_PER_MECH + N_VAL_PER_MECH + N_TEST_PER_MECH rate-constant vectors
    is drawn and split in that order.  The pool comes from filter_theta_pool
    when a post-filter is registered, otherwise from sample_theta_pool.  The
    mechanism's index in MECHANISMS is its class label.

    Files written to ``output_dir`` (all ``*_M1_M20_train_val_test_set_final.pkl``):
    x1/y/t for train, val and test; x2 train/val laid out by
    reformat_x2_for_paper; and x2 test as a nested dict
    ``{n_pts: {noise_pct: array}}`` for n_pts in (2, 6, 20) and
    noise_pct in (0, 1, 5).
    """
    os.makedirs(output_dir, exist_ok=True)
    n_points = 21  # rows per simulated kinetic profile; each sample holds 4 consecutive profiles
    suffix = '_M1_M20_train_val_test_set_final.pkl'

    x1_train_all, x1_val_all, x1_test_all = [], [], []
    x2_train_all, x2_val_all, x2_test_all = [], [], []
    t_train_all, t_val_all, t_test_all = [], [], []
    y_train_all, y_val_all, y_test_all = [], [], []

    for label, mech in enumerate(tqdm(MECHANISMS, desc="Generating mechanisms", ncols=80, bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]")):
        if mech not in ODE_FUNCS:
            continue
        ode_func = ODE_FUNCS[mech]
        train_init_func = TRAIN_INIT_FUNCS[mech]
        test_init_func = TEST_INIT_FUNCS[mech]
        dim = MECH_DIM[mech]
        filter_func = POST_FILTER_FUNCS.get(mech)
        n_total = N_TRAIN_PER_MECH + N_VAL_PER_MECH + N_TEST_PER_MECH

        if filter_func:
            theta_pool = filter_theta_pool(n_total, filter_func, ode_func, dim, max_round=40, batch_factor=20)
        else:
            theta_pool = sample_theta_pool(n_total, log_min=LOG_MIN, log_max=LOG_MAX, dim=dim, seed=SEED + label)

        if len(theta_pool) < n_total:
            # filter_theta_pool gives up after max_round rounds even when short;
            # np.split below would then silently shrink the val/test splits.
            tqdm.write(f"[WARN] {mech} | only {len(theta_pool)}/{n_total} θ vectors available; "
                       "validation/test splits will be smaller than configured")

        theta_train, theta_val, theta_test = np.split(theta_pool, [N_TRAIN_PER_MECH, N_TRAIN_PER_MECH + N_VAL_PER_MECH])
        x1_train, x2_train, t_train = generate_dataset_given_theta(theta_train, ode_func, train_init_func, test_init_func, mode='train', desc=f"Train set ({mech})")
        x1_val, x2_val, t_val = generate_dataset_given_theta(theta_val, ode_func, train_init_func, test_init_func, mode='train', desc=f"Validation set ({mech})")
        x1_test, x2_test, t_test = generate_dataset_given_theta(theta_test, ode_func, train_init_func, test_init_func, mode='test', desc=f"Test set ({mech})")

        x1_train_all.append(x1_train)
        x1_val_all.append(x1_val)
        x1_test_all.append(x1_test)
        x2_train_all.extend(x2_train)
        x2_val_all.extend(x2_val)
        x2_test_all.extend(x2_test)
        t_train_all.extend(t_train)
        t_val_all.extend(t_val)
        t_test_all.extend(t_test)
        y_train_all.append(np.full(len(x1_train), label))
        y_val_all.append(np.full(len(x1_val), label))
        y_test_all.append(np.full(len(x1_test), label))

    x1_train = np.concatenate(x1_train_all)
    x1_val = np.concatenate(x1_val_all)
    x1_test = np.concatenate(x1_test_all)
    y_train = np.concatenate(y_train_all).reshape(-1, 1)
    y_val = np.concatenate(y_val_all).reshape(-1, 1)
    y_test = np.concatenate(y_test_all).reshape(-1, 1)
    x2_train = np.stack(x2_train_all)
    x2_val = np.stack(x2_val_all)
    x2_test = np.stack(x2_test_all)
    x2_train_final = reformat_x2_for_paper(x2_train, t_train_all, n_points=n_points)
    x2_val_final = reformat_x2_for_paper(x2_val, t_val_all, n_points=n_points)

    _dump_pickle(x2_train_final, output_dir, 'x2_train' + suffix)
    _dump_pickle(x2_val_final, output_dir, 'x2_val' + suffix)

    # The nesting order (n_pts outer, noise inner) fixes the sequence of global
    # RNG draws; keep it so earlier test sets can be reproduced from SEED.
    x2_test_dict = {}
    for n_pts in [2, 6, 20]:
        x2_test_dict[n_pts] = {}
        for noise_pct in [0, 1, 5]:
            x2_test_dict[n_pts][noise_pct] = _subsample_test_profiles(
                x2_test, t_test_all, n_pts, noise_pct, n_points=n_points)

    for name, obj in (
        ('x2_test', x2_test_dict),
        ('x1_train', x1_train),
        ('x1_val', x1_val),
        ('x1_test', x1_test),
        ('y_train', y_train),
        ('y_val', y_val),
        ('y_test', y_test),
        ('t_train', t_train_all),
        ('t_val', t_val_all),
        ('t_test', t_test_all),
    ):
        _dump_pickle(obj, output_dir, name + suffix)

if __name__ == '__main__':
    import multiprocessing
    import sys
    import traceback

    # force=True: in a notebook/REPL this cell can run again, or processes may
    # already have been started, and a plain set_start_method() call would then
    # raise RuntimeError("context has already been set").
    multiprocessing.set_start_method('spawn', force=True)
    try:
        generate_simulation_dataset()
    except Exception as e:
        print("Dataset generation failed:", e)
        # Keep the full traceback; the one-line message alone hides where it failed.
        traceback.print_exc()
        sys.exit(1)

Generating mechanisms:   0%|                                    | 0/20 [00:00<?]
Train set (M1):   0%|                                          | 0/300 [00:00<?][A
Train set (M1):   1%|▎                                     | 2/300 [00:00<00:20][A
Train set (M1):   1%|▌                                     | 4/300 [00:00<00:20][A
Train set (M1):   2%|▊                                     | 6/300 [00:00<00:19][A
Train set (M1):   3%|█                                     | 8/300 [00:00<00:19][A
Train set (M1):   3%|█▏                                   | 10/300 [00:00<00:19][A
Train set (M1):   4%|█▍                                   | 12/300 [00:00<00:18][A
Train set (M1):   5%|█▋                                   | 14/300 [00:00<00:19][A
Train set (M1):   5%|█▉                                   | 16/300 [00:01<00:19][A
Train set (M1):   6%|██▏                                  | 18/300 [00:01<00:19][A
Train set (M1):   7%|██▍                                  | 20/300 [00:01<00:19

[INFO] ode_M2 | Round 1 | Param range: [1.0e-05, 1.0e+05] | Batch size: 40
[INFO] ode_M2 | Sample θ example: [1.24e+00, 9.27e+04, 1.38e+03, 1.42e+01, 1.15e-02, 1.73e+01]
[INFO] ode_M2 | Round 1 | Accepted 0 (+0 this round) / Tested 40 (pass rate 0.000%) | Time: 0.21s
[INFO] ode_M2 | Round 2 | Param range: [1.0e-05, 1.0e+05] | Batch size: 40
[INFO] ode_M2 | Sample θ example: [6.99e-04, 4.19e-02, 2.21e-02, 1.83e-05, 1.75e-05, 2.07e+03]
[INFO] ode_M2 | Round 2 | Accepted 0 (+0 this round) / Tested 80 (pass rate 0.000%) | Time: 0.24s
[INFO] ode_M2 | Round 3 | Param range: [1.0e-05, 1.0e+05] | Batch size: 40
[INFO] ode_M2 | Sample θ example: [6.31e-03, 3.83e+03, 8.36e-03, 1.28e+04, 3.46e+02, 1.12e+03]
[INFO] ode_M2 | Round 3 | Accepted 0 (+0 this round) / Tested 120 (pass rate 0.000%) | Time: 0.19s
[INFO] ode_M2 | Round 4 | Param range: [1.0e-05, 1.0e+05] | Batch size: 40
[INFO] ode_M2 | Sample θ example: [3.92e-03, 8.84e-04, 3.77e-05, 2.66e-02, 5.04e-03, 4.62e+04]
[INFO] ode_M2 | Round 4 |