In [None]:
from matplotlib import pyplot as plt
import matplotlib as mpl
plt.style.use("ggplot")
mpl.rcParams['font.size'] = 16
mpl.rcParams['legend.fontsize'] = 14
from scipy.integrate import solve_ivp
from scipy.integrate._ivp.ivp import OdeResult
from scipy.optimize import curve_fit
import seaborn as sns
import numpy as np
import tqdm, os, glob
import itertools

from ars_247_2025_final_project.tram_model import TRAM_Model
from ars_247_2025_final_project.tools import pulse_model
from ars_247_2025_final_project.analysis import analyze_single_var_axis, analyze_multiple_var_axis
from ars_247_2025_final_project.vis import visualize_single_var_axis, visualize_two_var_grid

In [None]:
# Output locations, resolved relative to the notebook's working directory.
OUTPUT_DIR = os.path.abspath("..")
OUTPUT_PNG_DIR = os.path.join(OUTPUT_DIR, "Final_Project_PNG", "Fig2")
print(f"Saving files to {OUTPUT_DIR}...")
OUTPUT_NPZ_DIR = os.path.join(OUTPUT_DIR, "npz")
# The cached sweep results that the save/load cell reads and writes.
LINEAR_NPZ_PATH = os.path.join(OUTPUT_NPZ_DIR, "linear_parameter_space_results.npz")

# The parameter sweep is expensive, so offer to reuse cached results.
# Only the exact file the load cell reads is checked. A check for "any *.npz"
# would prompt even when that file is absent, and answering 'y' would then
# fail at np.load.
USE_EXISTING_NPZ = False
if os.path.isfile(LINEAR_NPZ_PATH):
    answer = ""
    while answer not in ("y", "n"):
        # Normalize so that "Y", " y " etc. are accepted.
        answer = input(f"Detected '.npz' file in {OUTPUT_NPZ_DIR}. Would you like to use it as your data? (y/n)").strip().lower()
    USE_EXISTING_NPZ = answer == "y"

In [None]:
HOUR = 3600  # seconds
MINUTE = 60  # seconds

# Gaussian input-pulse definitions; all times are in seconds.
# Note: `float // int` is floor division that yields a float (e.g. 5400.0 // 4 == 1350.0).

# transcriptional pulse: peak at 1.5 h, width 0.5 h
transcriptional_pulse = {
    'peak_time': 1.5 * HOUR,
    'peak_value': 1,
    'width': 0.5 * HOUR
}

# phosphorylation pulse: the transcriptional timings scaled down 4x
phosphorylation_pulse = {
    'peak_time': 1.5 * HOUR // 4,  # 1350 s = 22.5 min
    'peak_value': 1,
    'width': 0.5 * HOUR // 4  # 450 s = 7.5 min
}

# optogenetic pulse: the transcriptional timings scaled down 10x
optogenetic_pulse = {
    'peak_time': 1.5 * HOUR // 10,  # 540 s = 9 min
    'peak_value': 1,
    'width': 0.5 * HOUR // 10  # 180 s = 3 min
}

# Baseline system parameters; the sweeps below are multiples of these values.
og_system_params = {
    'PROM': 1,
    'V_A': 0.47, #um/s - mean for dynein / kinesin
    'R_M': 2.5, #um - mean for mammalian cells is 2.5; best results w/ 20
    'K_TEV': 0.015, #/s - physiological is ~0.15
    'K_TVMV': 0.015, #/s - physiological is ~0.15
    'K_TRAM_DEG': 0.00001, #/s - got interesting results with this being 0.01; physiological is ~0.00001
    'K_TF_DEG': 0.01 #/s - got interesting results with this being 0.1; physiological is ~0.01 w/ fast degron
}

# Working copy, so the baseline dict is never mutated.
system_params = og_system_params.copy()

## Figure 2

In [None]:
# Multiplicative sweep factors applied to every baseline parameter:
# [1/6, 1/4, 1/2, 1, 2, 4, 6].
# The baseline factor must be exactly 1.0 so that the unmodified og_system_params
# values appear in the sweep. The analysis receives them as `default_params`,
# presumably to locate the baseline row. `np.ndarray([1])` would NOT give 1.0:
# it allocates an uninitialized length-1 array.
sweep_factors = np.arange(2, 8, 2)  # [2, 4, 6]
linear_parameter_space = np.concatenate([1 / sweep_factors[::-1], [1.0], sweep_factors])
assert np.count_nonzero(linear_parameter_space == 1.0) == 1, "baseline factor 1.0 missing from sweep"

speed_ranges = linear_parameter_space * og_system_params['V_A']
protease_ranges = linear_parameter_space * og_system_params['K_TEV']
deg_tram_ranges = linear_parameter_space * og_system_params['K_TRAM_DEG']
deg_tf_ranges = linear_parameter_space * og_system_params['K_TF_DEG']

if not USE_EXISTING_NPZ:
    # Every combination of the four sweeps: one row per (V_A, K_TEV, K_TRAM_DEG, K_TF_DEG).
    all_combinations = np.array(list(itertools.product(speed_ranges, protease_ranges, deg_tram_ranges, deg_tf_ranges)))
    models_comb = []
    T = int(24 * HOUR)  # simulated end time; also the number of stored samples per state
    N_DOMAINS = 6
    # NOTE(review): this buffer is (n_combinations, 4 * N_DOMAINS + 1, T) float64.
    # For the default 7-point sweep that is 7**4 * 25 * 86400 * 8 bytes ≈ 41 GB.
    # Check available memory before running.
    solns_comb = np.zeros(shape=(len(all_combinations), 4 * N_DOMAINS + 1, T))
    for idx, (speedx, protx, degrx, degtx) in enumerate(tqdm.tqdm(all_combinations)):
        # Build a fresh parameter dict for each model. Mutating one shared dict
        # would make every model kept in `models_comb` see the final combination's
        # values if TRAM_Model stores a reference to `sys_config`.
        system_params = {
            **og_system_params,
            'V_A': speedx,
            'K_TEV': protx,
            'K_TVMV': protx,  # both protease rates are swept together
            'K_TRAM_DEG': degrx,
            'K_TF_DEG': degtx,
        }

        this_model = TRAM_Model(
            n_domains=N_DOMAINS,
            pulse='gaussian',
            pulse_params=optogenetic_pulse,
            sys_config=system_params
        )

        models_comb.append(this_model)
        # Keep the first T columns of the 2-D solution (presumably states x time samples).
        solns_comb[idx, ...] = this_model.solve_tram_ivp(0, T)['y'][:, :T]

else:
    print("User elected to use existing NPZ as input data, skipping IVP solving step...")

Save the sweep results (parameter ranges, combinations, and solutions) to an NPZ file, or reload them from the existing file.

In [None]:
linear_npz_path = os.path.join(OUTPUT_NPZ_DIR, "linear_parameter_space_results.npz")
if not USE_EXISTING_NPZ:
    # Bundle the per-parameter ranges, the full combination table and the solutions.
    npz_save_dict = {
        'V_A': speed_ranges,
        'K_TEV': protease_ranges,
        'K_TRAM_DEG': deg_tram_ranges,
        'K_TF_DEG': deg_tf_ranges,
        'n_ranges': 4,
        'combinations': all_combinations,
        'solutions': solns_comb
    }

    # np.savez does not create missing parent directories.
    os.makedirs(OUTPUT_NPZ_DIR, exist_ok=True)
    np.savez(
        file=linear_npz_path,
        **npz_save_dict
    )
else:
    # np.load returns an NpzFile that supports the same key lookups as the dict above;
    # arrays are read from disk on access.
    npz_save_dict = np.load(linear_npz_path)

### Figure 2a

In [None]:
# Reference time for the downstream delay analysis: the time at which the
# optogenetic input-pulse trajectory peaks.
p_model = pulse_model('gaussian', params=optogenetic_pulse)
pulse_window = 4 * HOUR
# Sample at integer seconds so that a sample index equals a time in seconds.
# This presumably matches the 1-sample-per-second layout of the stored solutions.
pulse_soln = solve_ivp(
    fun=p_model,
    t_span=[0, pulse_window],
    y0=[0],
    t_eval=np.arange(pulse_window),
    max_step=1,
)
# argmax must run over the state trajectory. Applied to the OdeResult container
# itself (a 0-d object array) it always returns 0.
# NOTE(review): assumes pulse_model returns the pulse's time derivative, so that
# y[0] is the pulse itself. Confirm against ars_247_2025_final_project.tools.
p_peak_time = int(pulse_soln.t[np.argmax(pulse_soln.y[0])])

In [None]:
# Single-parameter sweep along the V_A axis. The other parameters are supplied
# as `default_params`; delays are measured relative to `p_peak_time`.
key, key_range, delays, amplitudes = analyze_single_var_axis(
    input_data=npz_save_dict,
    key='V_A',
    default_params=og_system_params,
    start_time=p_peak_time,
)

In [None]:
# Plot the V_A sweep results (delays and amplitudes), then save the current figure.
visualize_single_var_axis(key=key, delays=delays, amplitudes=amplitudes)
plt.savefig(os.path.join(OUTPUT_DIR, "Fig2a-V_A.png"))

In [None]:
# Single-parameter sweep along the K_TEV axis. The other parameters are supplied
# as `default_params`; delays are measured relative to `p_peak_time`.
key, key_range, delays, amplitudes = analyze_single_var_axis(
    input_data=npz_save_dict,
    key='K_TEV',
    default_params=og_system_params,
    start_time=p_peak_time,
)

In [None]:
# Plot the K_TEV sweep results (delays and amplitudes), then save the current figure.
visualize_single_var_axis(key=key, delays=delays, amplitudes=amplitudes)
plt.savefig(os.path.join(OUTPUT_DIR, "Fig2a-K_TEV.png"))

In [None]:
# Single-parameter sweep along the K_TRAM_DEG axis. The other parameters are
# supplied as `default_params`; delays are measured relative to `p_peak_time`.
key, key_range, delays, amplitudes = analyze_single_var_axis(
    input_data=npz_save_dict,
    key='K_TRAM_DEG',
    default_params=og_system_params,
    start_time=p_peak_time,
)

In [None]:
# Plot the K_TRAM_DEG sweep results (delays and amplitudes), then save the current figure.
visualize_single_var_axis(key=key, delays=delays, amplitudes=amplitudes)
plt.savefig(os.path.join(OUTPUT_DIR, "Fig2a-K_TRAM_DEG.png"))

In [None]:
# Single-parameter sweep along the K_TF_DEG axis. The other parameters are
# supplied as `default_params`; delays are measured relative to `p_peak_time`.
key, key_range, delays, amplitudes = analyze_single_var_axis(
    input_data=npz_save_dict,
    key='K_TF_DEG',
    default_params=og_system_params,
    start_time=p_peak_time,
)

In [None]:
# Plot the K_TF_DEG sweep results (delays and amplitudes), then save the current figure.
visualize_single_var_axis(key=key, delays=delays, amplitudes=amplitudes)
plt.savefig(os.path.join(OUTPUT_DIR, "Fig2a-K_TF_DEG.png"))

### Multivariate analysis

In [None]:
# Two-parameter sweep over K_TF_DEG and K_TEV. The call hands the keys back
# alongside their ranges and the delay/amplitude results.
keys, key_ranges, delays, amplitudes = analyze_multiple_var_axis(
    input_data=npz_save_dict,
    keys=["K_TF_DEG", "K_TEV"],
    default_params=og_system_params,
    start_time=p_peak_time,
)

In [None]:
# Heat map of the output peak amplitude over the two parameters in `keys`.
visualize_two_var_grid(
    data=amplitudes,
    keys=keys,
    plt_title="Output Peak Amplitude",
    cbar_step=50,  # colorbar tick spacing, presumably in expression units (REU)
    cbar_title="Expression (REU)"
)
plt.savefig(os.path.join(OUTPUT_DIR, "Fig2b-Amp.png"))

In [None]:
# Two-parameter sweep over V_A and K_TEV. The call hands the keys back
# alongside their ranges and the delay/amplitude results.
keys, key_ranges, delays, amplitudes = analyze_multiple_var_axis(
    input_data=npz_save_dict,
    keys=["V_A", "K_TEV"],
    default_params=og_system_params,
    start_time=p_peak_time,
)

In [None]:
# Heat map of the output peak delay over the two parameters in `keys`.
visualize_two_var_grid(
    data=delays,
    keys=keys,
    plt_title="Output Peak Delay",
    cbar_step=HOUR / 4,  # colorbar tick spacing of 900 s (15 min); delays presumably in seconds
    cbar_title="Time (hrs)",
    divide_cbar_label_by=HOUR  # show colorbar tick labels in hours
)
# NOTE(review): the companion amplitude grid is saved as "Fig2b-Amp.png";
# confirm whether this file should be "Fig2b-Delay.png".
plt.savefig(os.path.join(OUTPUT_DIR, "Fig2a-Delay.png"))