# Additional Simulations

In [122]:
from zebrafish_ms2_paper.gillespie_simulations_delay import Params, simulate_multiple_copies, sim_ms2, hill_function
from zebrafish_ms2_paper.trace_analysis import binarize_trace, get_on_and_off_times, get_burst_durations, get_burst_inactive_durations, extract_traces
from zebrafish_ms2_paper.utils import pboc_rc, style_axes, colors, fontsize
import matplotlib.pyplot as plt
from matplotlib import rc, rcParams
import numpy as np
from multiprocessing import Pool
import pandas as pd
import pickle
from copy import deepcopy
from scipy.interpolate import interp1d

In [2]:
# Apply the shared plotting style imported from zebrafish_ms2_paper.utils.
rcParams.update(pboc_rc)
# Font type 42 (TrueType) keeps text editable in exported PDFs.
rcParams['pdf.fonttype'] = 42

In [3]:
# Display the colour palette imported from zebrafish_ms2_paper.utils.
colors

{'green': '#7AA974',
 'light_green': '#BFD598',
 'pale_green': '#DCECCB',
 'yellow': '#EAC264',
 'light_yellow': '#F3DAA9',
 'pale_yellow': '#FFEDCE',
 'blue': '#738FC1',
 'light_blue': '#A9BFE3',
 'pale_blue': '#C9D7EE',
 'red': '#D56C55',
 'light_red': '#E8B19D',
 'pale_red': '#F1D4C9',
 'purple': '#AB85AC',
 'light_purple': '#D4C2D9',
 'dark_green': '#7E9D90',
 'dark_brown': '#905426'}

In [204]:
# Notebook-wide plotting and simulation settings.
fontsize = 9  # overrides the `fontsize` imported from zebrafish_ms2_paper.utils
linewidth = 2
run_sim = True  # gates the simulation cells (`if run_sim:`)
n_replicates = 15  # NOTE(review): also assigned (to 20) in the fixed-model-params cell; whichever cell ran last wins
markersize = 8
#bins = np.linspace(3, 55, 12)
bins = np.arange(3, 55)  # histogram bin edges: unit-width bins from 3 to 54

In [5]:
# Use the interactive Qt backend for matplotlib figures.
%matplotlib qt

In [38]:
"""fixed model params"""
elongation_time = 0.29
delta_t = 1.0
w = elongation_time / delta_t
sigma = 0.2
n_replicates = 20


In [121]:
# Scratch check: np.roll shifts elements toward higher indices and wraps the tail around to the front.
np.roll(np.arange(10), 3)

array([7, 8, 9, 0, 1, 2, 3, 4, 5, 6])

In [145]:
# Inspect the fine time grid built in the delay-shift cell (that cell must have been run first).
tgrid

array([0.000000e+00, 1.000000e-04, 2.000000e-04, ..., 2.401007e+02,
       2.401008e+02, 2.401009e+02])

In [146]:
# Inspect the time vector returned by simulate_multiple_copies (the simulation cell must have been run first).
tvec

array([0.00000000e+00, 9.68687847e-02, 1.14875085e-01, ...,
       2.39979869e+02, 2.39990290e+02, 2.40100865e+02])

In [159]:
# Run one simulation with the current `p` (which must already exist from a parameter cell).
X, tvec, p = simulate_multiple_copies(p)
# need to sample protein on a regular grid, shift it, then resample it on the given tvec
if p.delay > 0:
    fine_dt = 0.0001
    tgrid = np.arange(0, tvec[-1], fine_dt)
    tgrid[-1] = tvec[-1]  # make the grid end exactly at the last simulated time
    # hold the last simulated protein value between simulation time points
    protein_grid = interp1d(tvec, X[:, -1], kind='previous')(tgrid)
    delay_index = int(np.round(p.delay / fine_dt))
    # np.roll wraps: the first delay_index samples are taken from the end of the trace
    delayed_protein = np.roll(protein_grid, delay_index)
    repressing_protein = interp1d(tgrid, delayed_protein)(tvec)
else:
    repressing_protein = X[:, -1]

# Drop everything up to burn_in_time and re-zero the time axis.
burn_in_time = 45
X = X[tvec > burn_in_time]
repressing_protein = repressing_protein[tvec > burn_in_time]
tvec = tvec[tvec > burn_in_time] - burn_in_time
production_rate = p.transcription_rate_0 + p.transcription_rate_1 * hill_function(repressing_protein, p.KD_transcription_rate, p.n)
state = X[:, 0]



In [160]:
# Detection threshold from the mean and std of log10 of the positive production rates.
# NOTE(review): tvec was already shifted by burn_in_time in the previous cell, so this mask
# removes a further burn_in_time of the trace from the statistics — confirm that is intended.
relevant_production_rate = production_rate[np.where(tvec > burn_in_time)[0]]
m = np.mean(np.log10(relevant_production_rate[relevant_production_rate > 0]))
s = np.std(np.log10(relevant_production_rate[relevant_production_rate > 0]))
detection_threshold = 0.22 * 10 ** (m + s) * w

# Simulate the MS2 signal on the uniform time grid returned by sim_ms2.
ms2, uniform_times = sim_ms2(state, tvec, production_rate, w, delta_t, sigma, detection_threshold)


In [161]:
# Quick look at the simulated MS2 trace from the previous cell.
plt.figure()
plt.plot(uniform_times, ms2)

[<matplotlib.lines.Line2D at 0x7f7dde60a660>]

In [163]:
"""function for computing interval distributions"""
def compute_burst_intervals(p):
    """Simulate one trace for parameter set ``p`` and extract its burst intervals.

    The protein level (last column of the simulated state) is shifted by
    ``p.delay`` when the delay is positive, the first ``burn_in_time`` of the
    trace is discarded, and an MS2 signal is simulated from the resulting
    production rate. The MS2 trace is binarized and converted to on/off times.

    Parameters
    ----------
    p : Params
        Simulation parameters passed to ``simulate_multiple_copies``.

    Returns
    -------
    active_durations, inactive_durations, periods
        Outputs of ``get_burst_durations`` and ``get_burst_inactive_durations``,
        and ``np.diff`` of the detected on-times.
    """
    X, tvec, p = simulate_multiple_copies(p)
    protein = X[:, -1]
    if p.delay > 0:
        # Sample the protein level on a regular fine grid, shift it by the
        # delay, then resample it back onto the simulation time vector.
        fine_dt = 0.0001
        tgrid = np.arange(0, tvec[-1], fine_dt)
        tgrid[-1] = tvec[-1]
        protein_on_grid = interp1d(tvec, protein, kind='previous')(tgrid)
        n_shift = int(np.round(p.delay / fine_dt))
        # NOTE(review): np.roll wraps the last n_shift samples to the front;
        # they are removed by the burn-in below only while p.delay < burn_in_time.
        shifted_protein = np.roll(protein_on_grid, n_shift)
        repressing_protein = interp1d(tgrid, shifted_protein)(tvec)
    else:
        repressing_protein = protein

    # Discard the burn-in period and re-zero the time axis.
    burn_in_time = 45
    after_burn_in = tvec > burn_in_time
    X = X[after_burn_in]
    repressing_protein = repressing_protein[after_burn_in]
    tvec = tvec[after_burn_in] - burn_in_time

    repression = hill_function(repressing_protein, p.KD_transcription_rate, p.n)
    production_rate = p.transcription_rate_0 + p.transcription_rate_1 * repression
    state = X[:, 0]

    elongation_time = 0.29
    delta_t = 1.0
    w = elongation_time / delta_t
    sigma = 0.2

    # NOTE(review): tvec has already been shifted by burn_in_time above, so this
    # mask drops a further burn_in_time from the threshold statistics — confirm.
    late_rates = production_rate[tvec > burn_in_time]
    log_rates = np.log10(late_rates[late_rates > 0])
    m = np.mean(log_rates)
    s = np.std(log_rates)
    detection_threshold = 0.22 * 10 ** (m + s) * w

    ms2, uniform_times = sim_ms2(state, tvec, production_rate, w, delta_t, sigma, detection_threshold)

    inferred_state = binarize_trace(ms2, uniform_times, thresh=1e-1, window_size=3)
    on_times, off_times = get_on_and_off_times(inferred_state, uniform_times)

    return (
        get_burst_durations(on_times, off_times),
        get_burst_inactive_durations(on_times, off_times),
        np.diff(on_times),
    )

def init_pool_processes():
    """Pool initializer: re-seed NumPy's global RNG inside each worker process."""
    # Seeding with None pulls fresh entropy, so workers do not replay the
    # random state they would otherwise inherit from the parent process.
    np.random.seed(None)
    

"""more functions for computing interval distributions. we also need func and init_pool_processes from the cell above."""
def compute_distributions(traces, bins, timepoints=None):
    """Compute probability densities of burst durations and quiet intervals.

    Parameters
    ----------
    traces : iterable
        Each item unpacks to ``(t_arr, inten_arr, nucleus)``. ``t_arr`` is cast
        to int and used to index ``timepoints``.
    bins : array-like
        Histogram bin edges passed to ``np.histogram``.
    timepoints : array-like, optional
        Lookup array mapping the indices in ``t_arr`` to times. When omitted,
        the notebook-level ``non_blank_timepoints`` is used, as before; that
        name must then exist in the kernel.

    Returns
    -------
    prob_dens_durations, prob_dens_quiets : np.ndarray
        Normalized histograms (counts / total / bin width). Entries are NaN
        when no intervals were collected.
    """
    pulse_durations = []
    pulse_quiets = []
    for t_arr, inten_arr, _nucleus in traces:
        # Resolve the lookup lazily so an empty `traces` never touches the global.
        lookup = non_blank_timepoints if timepoints is None else timepoints
        t_arr = lookup[t_arr.astype('int')]
        state = binarize_trace(inten_arr, t_arr, thresh=1.0, window_size=5)
        on_times, off_times = get_on_and_off_times(state, t_arr)
        # Only traces with more than two detected on-times contribute.
        if len(on_times) > 2:
            pulse_quiets.extend(get_burst_inactive_durations(on_times, off_times))
            pulse_durations.extend(get_burst_durations(on_times, off_times))

    counts, bins = np.histogram(pulse_durations, bins=bins)
    prob_dens_durations = counts / np.sum(counts) / np.diff(bins)

    counts, bins = np.histogram(pulse_quiets, bins=bins)
    prob_dens_quiets = counts / np.sum(counts) / np.diff(bins)

    return prob_dens_durations, prob_dens_quiets


def bootstrap_simulated_distributions(intervals, bins, n_bootstraps):
    """Bootstrap the probability density of ``intervals`` over ``bins``.

    Parameters
    ----------
    intervals : sequence of float
        Observed interval lengths; resampled with replacement.
    bins : array-like
        Histogram bin edges passed to ``np.histogram``.
    n_bootstraps : int or float
        Number of bootstrap resamples; cast to int once so integral floats
        (e.g. ``100.0`` or ``1e2``) work for both the allocation and the loop.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_bootstraps, len(bins) - 1)``; each row is one
        resampled density (counts / total / bin width).

    Raises
    ------
    ValueError
        If ``intervals`` is empty and at least one resample is requested.
    """
    n_bootstraps = int(n_bootstraps)
    samples = np.asarray(intervals)
    n_samples = len(samples)
    if n_samples == 0 and n_bootstraps > 0:
        # np.random.randint would fail here with an opaque "low >= high".
        raise ValueError('cannot bootstrap an empty collection of intervals')

    interval_dist_arr = np.zeros((n_bootstraps, len(bins) - 1))
    for i in range(n_bootstraps):
        these_ids = np.random.randint(0, n_samples, n_samples, dtype='int')
        counts, bins = np.histogram(samples[these_ids], bins=bins)
        interval_dist_arr[i] = counts / np.sum(counts) / np.diff(bins)

    return interval_dist_arr

## Run simulations
Here's the code for running the simulations that produce the interval distributions. To just make the plots, skip ahead to the cells where the data is loaded.

In [224]:
"""amplitude + duration regulation"""
save = True
if run_sim:
    # Base parameter set for this model.
    p = Params()
    p.initial_state = np.array([1])
    p.Tmax = 240
    p.k_off0 = 0.0
    p.k_off1 = 0.4
    p.k_on0 = 0.055
    p.k_on1 = 0.0
    p.transcription_rate_0 = 0

    p.translation_rate = 4.5
    p.transcription_rate_1 = 10

    p.mrna_decay_rate = 0.23
    p.protein_decay_rate = 0.23
    p.delay = 7.5

    p.KD_k_on = 10
    p.KD_k_off = 1100
    p.KD_transcription_rate = 100
    p.n = 3

    # One independent parameter copy per replicate, each with its own protein
    # decay rate drawn from a normal distribution and clipped at zero.
    decay_scale = 0.23 / 3
    p_arr = []
    for _ in range(n_replicates * 4):
        tmp_p = deepcopy(p)
        tmp_p.protein_decay_rate = np.clip(np.random.normal(loc=0.23, scale=decay_scale), a_min=0, a_max=np.inf)
        p_arr.append(tmp_p)

    with Pool(processes=15, initializer=init_pool_processes) as pool:
        res = pool.map(compute_burst_intervals, p_arr)

    # Flatten the per-replicate results into one list per interval type.
    active_durations = [item for sublist in res for item in sublist[0]]
    inactive_durations = [item for sublist in res for item in sublist[1]]
    periods = [item for sublist in res for item in sublist[2]]

    # Save only results produced by this run. Outside the run_sim guard, a stale
    # kernel could overwrite this file with another model's intervals.
    if save:
        with open(r'/home/brandon/Documents/Code/zebrafish-ms2-paper/data/delay_sims/amp_and_dur_reg_intervals.pkl', 'wb') as f:
            pickle.dump([active_durations, inactive_durations, periods], f)






In [65]:
# Load the saved intervals and plot the active (left) and inactive (right) interval distributions.
active_intervals, inactive_intervals, periods = load_intervals(r'/home/brandon/Documents/Code/zebrafish-ms2-paper/data/delay_sims/amp_and_dur_reg_intervals.pkl')
fig, axs = plt.subplots(1, 2)
# plot_interval_dists returns (ax, prob_dens), so `ax` is bound to that tuple here.
ax = plot_interval_dists(active_intervals, bins, ax=axs[0])
ax = plot_interval_dists(inactive_intervals, bins, ax=axs[1])


In [225]:
"""amp, frequency, and dur regulation"""
save = True
if run_sim:
    # Base parameter set for this model.
    p = Params()
    p.initial_state = np.array([1])
    p.Tmax = 240
    p.k_off0 = 0.0
    p.k_off1 = 0.4
    p.k_on0 = 0.0
    p.k_on1 = 0.5
    p.transcription_rate_0 = 0

    p.translation_rate = 4.5
    p.transcription_rate_1 = 10

    p.mrna_decay_rate = 0.23
    p.protein_decay_rate = 0.23

    p.delay = 7.5

    p.KD_k_on = 80
    p.KD_k_off = 1300
    p.KD_transcription_rate = 100

    p.n = 3

    # One independent parameter copy per replicate. With decay_scale = 0 the
    # drawn protein decay rate is always 0.23.
    decay_scale = 0
    p_arr = []
    for _ in range(n_replicates * 4):
        tmp_p = deepcopy(p)
        tmp_p.protein_decay_rate = np.clip(np.random.normal(loc=0.23, scale=decay_scale), a_min=0, a_max=np.inf)
        p_arr.append(tmp_p)

    with Pool(processes=15, initializer=init_pool_processes) as pool:
        res = pool.map(compute_burst_intervals, p_arr)

    # Flatten the per-replicate results into one list per interval type.
    active_durations = [item for sublist in res for item in sublist[0]]
    inactive_durations = [item for sublist in res for item in sublist[1]]
    periods = [item for sublist in res for item in sublist[2]]

    # Save only results produced by this run. Outside the run_sim guard, a stale
    # kernel could overwrite this file with another model's intervals.
    if save:
        with open(r'/home/brandon/Documents/Code/zebrafish-ms2-paper/data/delay_sims/amp_freq_dur_reg_intervals.pkl', 'wb') as f:
            pickle.dump([active_durations, inactive_durations, periods], f)



In [67]:
# Load the saved intervals and plot the active (left) and inactive (right) interval distributions.
active_intervals, inactive_intervals, periods = load_intervals(r'/home/brandon/Documents/Code/zebrafish-ms2-paper/data/delay_sims/amp_freq_dur_reg_intervals.pkl')
fig, axs = plt.subplots(1, 2)
# plot_interval_dists returns (ax, prob_dens), so `ax` is bound to that tuple here.
ax = plot_interval_dists(active_intervals, bins, ax=axs[0])
ax = plot_interval_dists(inactive_intervals, bins, ax=axs[1])

In [226]:
"""amp and freq regulation"""
save = True
if run_sim:
    # Base parameter set for this model.
    p = Params()
    p.initial_state = np.array([1])
    p.Tmax = 240
    p.k_off0 = 0.08
    p.k_off1 = 0.0
    p.k_on0 = 0.0
    p.k_on1 = 0.5
    p.transcription_rate_0 = 0

    p.translation_rate = 4.5
    p.transcription_rate_1 = 10

    p.mrna_decay_rate = 0.23
    p.protein_decay_rate = 0.23

    p.delay = 7.5

    p.KD_k_on = 80
    p.KD_k_off = 1300
    p.KD_transcription_rate = 100

    p.n = 3

    # One independent parameter copy per replicate. With decay_scale = 0 the
    # drawn protein decay rate is always 0.23.
    decay_scale = 0
    p_arr = []
    for _ in range(n_replicates * 4):
        tmp_p = deepcopy(p)
        tmp_p.protein_decay_rate = np.clip(np.random.normal(loc=0.23, scale=decay_scale), a_min=0, a_max=np.inf)
        p_arr.append(tmp_p)

    with Pool(processes=15, initializer=init_pool_processes) as pool:
        res = pool.map(compute_burst_intervals, p_arr)

    # Flatten the per-replicate results into one list per interval type.
    active_durations = [item for sublist in res for item in sublist[0]]
    inactive_durations = [item for sublist in res for item in sublist[1]]
    periods = [item for sublist in res for item in sublist[2]]

    # Save only results produced by this run. Outside the run_sim guard, a stale
    # kernel could overwrite this file with another model's intervals.
    if save:
        with open(r'/home/brandon/Documents/Code/zebrafish-ms2-paper/data/delay_sims/amp_freq_reg_intervals.pkl', 'wb') as f:
            pickle.dump([active_durations, inactive_durations, periods], f)



In [71]:
# Load the saved intervals and plot the active (left) and inactive (right) interval distributions.
active_intervals, inactive_intervals, periods = load_intervals(r'/home/brandon/Documents/Code/zebrafish-ms2-paper/data/delay_sims/amp_freq_reg_intervals.pkl')
fig, axs = plt.subplots(1, 2)
# plot_interval_dists returns (ax, prob_dens), so `ax` is bound to that tuple here.
ax = plot_interval_dists(active_intervals, bins, ax=axs[0])
ax = plot_interval_dists(inactive_intervals, bins, ax=axs[1])

In [227]:
"""no regulation"""
save = True
if run_sim:
    # Base parameter set for this model. The regulated transcription term is
    # switched off (transcription_rate_1 = 0) and there is no delay.
    p = Params()
    p.number_of_random_numbers_to_pregenerate = 1e6
    p.initial_state = np.array([1])
    p.Tmax = 240
    p.k_off0 = 0.08
    p.k_off1 = 0.0
    p.k_on0 = 0.055
    p.k_on1 = 0.0
    p.transcription_rate_0 = 10

    p.translation_rate = 4.5
    p.transcription_rate_1 = 0

    p.mrna_decay_rate = 0.23
    p.protein_decay_rate = 0.23
    p.delay = 0

    p.KD_k_on = 80
    p.KD_k_off = 1300
    p.n = 3

    # One independent parameter copy per replicate. With decay_scale = 0 the
    # drawn protein decay rate is always 0.23.
    decay_scale = 0
    p_arr = []
    for _ in range(n_replicates * 4):
        tmp_p = deepcopy(p)
        tmp_p.protein_decay_rate = np.clip(np.random.normal(loc=0.23, scale=decay_scale), a_min=0, a_max=np.inf)
        p_arr.append(tmp_p)

    with Pool(processes=15, initializer=init_pool_processes) as pool:
        res = pool.map(compute_burst_intervals, p_arr)

    # Flatten the per-replicate results into one list per interval type.
    active_durations = [item for sublist in res for item in sublist[0]]
    inactive_durations = [item for sublist in res for item in sublist[1]]
    periods = [item for sublist in res for item in sublist[2]]

    # Save only results produced by this run. Outside the run_sim guard, a stale
    # kernel could overwrite this file with another model's intervals.
    if save:
        with open(r'/home/brandon/Documents/Code/zebrafish-ms2-paper/data/delay_sims/no_reg_intervals.pkl', 'wb') as f:
            pickle.dump([active_durations, inactive_durations, periods], f)



In [73]:
# Load the saved intervals and plot the active (left) and inactive (right) interval distributions.
active_intervals, inactive_intervals, periods = load_intervals(r'/home/brandon/Documents/Code/zebrafish-ms2-paper/data/delay_sims/no_reg_intervals.pkl')
fig, axs = plt.subplots(1, 2)
# plot_interval_dists returns (ax, prob_dens), so `ax` is bound to that tuple here.
ax = plot_interval_dists(active_intervals, bins, ax=axs[0])
ax = plot_interval_dists(inactive_intervals, bins, ax=axs[1])

In [228]:
"""amp regulation with fast bursting"""
save = True
if run_sim:
    # Base parameter set for this model.
    p = Params()
    p.initial_state = np.array([1])
    p.Tmax = 240
    p.k_off0 = 2
    p.k_off1 = 0.0
    p.k_on0 = 1
    p.k_on1 = 0.0
    p.transcription_rate_0 = 0

    p.translation_rate = 4.5
    p.transcription_rate_1 = 10

    p.mrna_decay_rate = 0.23
    p.protein_decay_rate = 0.23

    p.delay = 7.5

    p.KD_k_on = 80
    p.KD_k_off = 1300
    p.KD_transcription_rate = 100

    p.n = 3

    # One independent parameter copy per replicate. With decay_scale = 0 the
    # drawn protein decay rate is always 0.23.
    decay_scale = 0
    p_arr = []
    for _ in range(n_replicates * 4):
        tmp_p = deepcopy(p)
        tmp_p.protein_decay_rate = np.clip(np.random.normal(loc=0.23, scale=decay_scale), a_min=0, a_max=np.inf)
        p_arr.append(tmp_p)

    with Pool(processes=15, initializer=init_pool_processes) as pool:
        res = pool.map(compute_burst_intervals, p_arr)

    # Flatten the per-replicate results into one list per interval type.
    active_durations = [item for sublist in res for item in sublist[0]]
    inactive_durations = [item for sublist in res for item in sublist[1]]
    periods = [item for sublist in res for item in sublist[2]]

    # Save only results produced by this run. Outside the run_sim guard, a stale
    # kernel could overwrite this file with another model's intervals.
    if save:
        with open(r'/home/brandon/Documents/Code/zebrafish-ms2-paper/data/delay_sims/amp_reg_fast_bursting_intervals.pkl', 'wb') as f:
            pickle.dump([active_durations, inactive_durations, periods], f)


In [102]:
# Load the saved intervals and plot the active (left) and inactive (right) interval distributions.
active_intervals, inactive_intervals, periods = load_intervals(r'/home/brandon/Documents/Code/zebrafish-ms2-paper/data/delay_sims/amp_reg_fast_bursting_intervals.pkl')
fig, axs = plt.subplots(1, 2)
# plot_interval_dists returns (ax, prob_dens), so `ax` is bound to that tuple here.
ax = plot_interval_dists(active_intervals, bins, ax=axs[0])
ax = plot_interval_dists(inactive_intervals, bins, ax=axs[1])

In [99]:
# Re-run a single simulation with whatever `p` is currently in the kernel
# and plot the first three columns of the state array against time.
p.Tmax = 240
X, tvec, p = simulate_multiple_copies(p)
fig, axs = plt.subplots(3, 1)
for i in range(3):
    axs[i].plot(tvec, X[:, i])

In [100]:
# Scratch check: simulate and binarize an MS2 trace for the single simulation above.
burn_in_time = 0
# NOTE(review): this slices X by row index, not by time; it is a no-op only because burn_in_time is 0.
X = X[burn_in_time:]
# Uses the undelayed protein level X[:, -1]; no delay shift is applied in this cell.
production_rate = p.transcription_rate_0 + p.transcription_rate_1 * hill_function(X[:,-1], p.KD_transcription_rate, p.n)
state = X[:, 0]

elongation_time = 0.29
delta_t = 1.0
w = elongation_time / delta_t
sigma = 0.2
# Detection threshold from the mean and std of log10 of the positive production rates.
relevant_production_rate = production_rate[np.where(tvec > burn_in_time)[0]]
m = np.mean(np.log10(relevant_production_rate[relevant_production_rate > 0]))
s = np.std(np.log10(relevant_production_rate[relevant_production_rate > 0]))
detection_threshold = 0.22 * 10 ** (m + s) * w

ms2, uniform_times = sim_ms2(state, tvec, production_rate, w, delta_t, sigma, detection_threshold)
inferred_state = binarize_trace(ms2, uniform_times, thresh=1e-1, window_size=3)

# Overlay the normalized MS2 trace and the inferred on/off state.
plt.figure()
plt.plot(uniform_times, ms2 / np.max(ms2))
plt.plot(uniform_times, inferred_state)

[<matplotlib.lines.Line2D at 0x7f7de1eba270>]

In [116]:
# Scratch check: does the current `p` instance carry a `dt` attribute?
'dt' in p.__dict__.keys()

True

In [177]:
"""solve the deterministic DDE using the Euler method"""

def dde_model_derivatives(mrna, protein, delayed_protein, p):
    """Return the time derivatives of mRNA and protein for the delayed-repression model.

    Parameters
    ----------
    mrna, protein : float
        Current mRNA and protein levels.
    delayed_protein : float
        Protein level one delay earlier; it represses transcription through a
        Hill term with threshold ``p.KD_transcription_rate`` and exponent ``p.n``.
    p : object
        Needs ``transcription_rate_1``, ``KD_transcription_rate``, ``n``,
        ``mrna_decay_rate``, ``translation_rate`` and ``protein_decay_rate``.

    Returns
    -------
    dmrna_dt, dprotein_dt : float
    """
    repression = 1 + (delayed_protein / p.KD_transcription_rate) ** p.n
    dmrna_dt = p.transcription_rate_1 / repression - p.mrna_decay_rate * mrna
    dprotein_dt = p.translation_rate * mrna - p.protein_decay_rate * protein

    return dmrna_dt, dprotein_dt


def solve_dde(p):
    """Integrate the delay differential equation with the forward Euler method.

    Parameters
    ----------
    p : object
        Needs ``Tmax``, ``dt``, ``delay``, ``initial_mrna``, ``initial_protein``
        and the rate attributes read by ``dde_model_derivatives``.

    Returns
    -------
    mrna, protein, t_arr : np.ndarray
        Trajectories sampled on ``t_arr = np.arange(0, p.Tmax, p.dt)``.
    """
    t_arr = np.arange(0, p.Tmax, p.dt)
    delay_index = int(np.round(p.delay / p.dt))
    mrna = np.zeros_like(t_arr)
    protein = np.zeros_like(t_arr)
    mrna[0] = p.initial_mrna
    protein[0] = p.initial_protein

    for i in range(1, len(t_arr)):
        # The Euler step is evaluated at the previous time point i - 1, so the
        # delayed protein must be read at (i - 1) - delay_index. Reading at
        # i - delay_index shortens the delay by one step and, for zero delay,
        # reads protein[i] before it has been computed. Before t = 0 the
        # history is held at the initial value (index clamped to 0).
        delayed_index = max(i - 1 - delay_index, 0)
        delayed_protein = protein[delayed_index]

        dmrna_dt, dprotein_dt = dde_model_derivatives(mrna[i - 1], protein[i - 1], delayed_protein, p)
        mrna[i] = mrna[i - 1] + p.dt * dmrna_dt
        protein[i] = protein[i - 1] + p.dt * dprotein_dt

    return mrna, protein, t_arr
        
def compute_dde_intervals(p):
    """Solve the deterministic DDE for ``p`` and extract burst intervals from a simulated MS2 trace.

    The protein trajectory is shifted by the delay, the first ``burn_in_time``
    is discarded, and the production rate is
    ``p.transcription_rate_1 * hill_function(delayed_protein, ...)``. The
    promoter state passed to ``sim_ms2`` is all ones.

    Parameters
    ----------
    p : Params
        Needs the attributes read by ``solve_dde`` plus ``KD_transcription_rate`` and ``n``.

    Returns
    -------
    active_durations, inactive_durations, periods
        Outputs of ``get_burst_durations`` and ``get_burst_inactive_durations``,
        and ``np.diff`` of the detected on-times.
    """
    _, protein, t_arr = solve_dde(p)
    # Round the delay measured in steps (delay / dt), matching solve_dde.
    # Rounding p.delay first turns e.g. 7.5 into 8 before the division.
    delay_index = int(np.round(p.delay / p.dt))
    # np.roll wraps the last delay_index samples to the front; the burn-in mask
    # below removes them as long as p.delay < burn_in_time.
    delayed_protein = np.roll(protein, delay_index)

    # Discard the burn-in period and re-zero the time axis.
    burn_in_time = 45
    after_burn_in = t_arr > burn_in_time
    delayed_protein = delayed_protein[after_burn_in]
    t_arr = t_arr[after_burn_in] - burn_in_time

    production_rate = p.transcription_rate_1 * hill_function(delayed_protein, p.KD_transcription_rate, p.n)
    state = np.ones_like(production_rate)

    elongation_time = 0.29
    delta_t = 1.0
    w = elongation_time / delta_t
    sigma = 0.2
    # Detection threshold from the mean and std of log10 of the production rate.
    log_rates = np.log10(production_rate)
    m = np.mean(log_rates)
    s = np.std(log_rates)
    detection_threshold = 0.22 * 10 ** (m + s) * w

    ms2, uniform_times = sim_ms2(state, t_arr, production_rate, w, delta_t, sigma, detection_threshold)

    inferred_state = binarize_trace(ms2, uniform_times, thresh=1e-1, window_size=3)
    on_times, off_times = get_on_and_off_times(inferred_state, uniform_times)
    active_durations = get_burst_durations(on_times, off_times)
    inactive_durations = get_burst_inactive_durations(on_times, off_times)
    periods = np.diff(on_times)

    return active_durations, inactive_durations, periods


In [178]:
# Parameter set for a single deterministic DDE run; all promoter switching rates are zero.
p = Params()
p.initial_state = np.array([1])
p.Tmax = 240
p.k_off0 = 0.0
p.k_off1 = 0.0
p.k_on0 = 0.0
p.k_on1 = 0.0
p.transcription_rate_0 = 0

p.translation_rate = 4.5
p.transcription_rate_1 = 10.0 / len(p.initial_state)  # initial_state has one entry here, so this is 10.0

p.mrna_decay_rate = 0.23
p.protein_decay_rate = 0.23

p.delay = 7.5

p.KD_transcription_rate = 100
p.n = 3
p.dt = 0.01  # Euler step used by solve_dde

# solve_dde also reads p.initial_mrna and p.initial_protein, which are not set in this cell.
active_durations, inactive_durations, periods = compute_dde_intervals(p)


In [182]:
# Deterministic DDE model: run compute_dde_intervals n_replicates * 4 times in parallel and save the pooled intervals.
save = True
if run_sim:
    p = Params()
    p.initial_state = np.array([1])
    p.Tmax = 240
    p.k_off0 = 0.0
    p.k_off1 = 0.0
    p.k_on0 = 0.0
    p.k_on1 = 0.0
    p.transcription_rate_0 = 0

    p.translation_rate = 4.5
    p.transcription_rate_1 = 10.0 / len(p.initial_state)

    p.mrna_decay_rate = 0.23
    p.protein_decay_rate = 0.23

    p.delay = 7.5

    p.KD_transcription_rate = 100
    p.n = 3
    p.dt = 0.01
    # Every replicate uses the same parameter object; presumably replicates differ only
    # through noise added inside sim_ms2 — TODO confirm.
    p_arr = [p] * n_replicates * 4

    with Pool(processes=15, initializer=init_pool_processes) as pool:
        res = pool.map(compute_dde_intervals, p_arr)

        # Flatten the per-replicate results into one list per interval type.
        active_durations = [item for sublist in res for item in sublist[0]]
        inactive_durations = [item for sublist in res for item in sublist[1]]
        periods = [item for sublist in res for item in sublist[2]]

    # Saving is nested under run_sim, so only freshly computed results are written.
    if save:
        with open(r'/home/brandon/Documents/Code/zebrafish-ms2-paper/data/delay_sims/dde_intervals.pkl', 'wb') as f:
            pickle.dump([active_durations, inactive_durations, periods], f)

In [183]:
# Load the saved intervals and plot the active (left) and inactive (right) interval distributions.
active_intervals, inactive_intervals, periods = load_intervals(r'/home/brandon/Documents/Code/zebrafish-ms2-paper/data/delay_sims/dde_intervals.pkl')
fig, axs = plt.subplots(1, 2)
# plot_interval_dists returns (ax, prob_dens), so `ax` is bound to that tuple here.
ax = plot_interval_dists(active_intervals, bins, ax=axs[0])
ax = plot_interval_dists(inactive_intervals, bins, ax=axs[1])

In [111]:
# Plot the deterministic mRNA (left) and protein (right) trajectories for the current `p`.
mrna, protein, t_arr = solve_dde(p)
plt.figure()
plt.subplot(121)
plt.plot(t_arr, mrna)
plt.subplot(122)
plt.plot(t_arr, protein)

[<matplotlib.lines.Line2D at 0x7f7de218f7a0>]

In [220]:
def plot_interval_dists(intervals, 
                        bins,
                        ax=None,
                        xlabel=None, 
                        ylabel=None, 
                        xticks=(0, 15, 30), 
                        yticks=(0, 0.05, 0.10),
                        n_bootstraps=100,
                        title=None,
                        color='k',
                       ):
    """Plot the probability density of ``intervals`` with a bootstrap uncertainty band.

    Parameters
    ----------
    intervals : sequence of float
        Interval lengths to histogram.
    bins : array-like
        Histogram bin edges; the curve is drawn at the left edge of each bin.
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure is created when omitted.
    xlabel, ylabel : str, optional
        Axis labels. When a label is None the corresponding tick labels are hidden.
    xticks, yticks : sequence of float
        Tick positions.
    n_bootstraps : int
        Number of bootstrap resamples used for the uncertainty band.
    title : str, optional
        Bold axes title.
    color : color
        Line and band colour.

    Returns
    -------
    ax, prob_dens
        The styled axes and the density array.

    Notes
    -----
    Reads the notebook-level ``linewidth`` and ``fontsize`` settings.
    """
    if ax is None:
        _, ax = plt.subplots()

    counts, bins = np.histogram(intervals, bins=bins)
    prob_dens = counts / np.sum(counts) / np.diff(bins)

    # The band half-width is the standard deviation across bootstrap resamples.
    resampled = bootstrap_simulated_distributions(intervals, bins, n_bootstraps=n_bootstraps)
    prob_dens_uncertainty = np.std(resampled, axis=0)

    left_edges = bins[:-1]
    lower = prob_dens - prob_dens_uncertainty
    upper = prob_dens + prob_dens_uncertainty
    ax.fill_between(left_edges, lower, upper, facecolor=color, alpha=0.5)
    ax.plot(left_edges, prob_dens, '-', linewidth=linewidth, color=color)

    # Tick labels are shown only on axes that also carry an axis label.
    xtick_labels = xticks if xlabel is not None else []
    ytick_labels = yticks if ylabel is not None else []

    ax.set_xticks(xticks, labels=xtick_labels)
    ax.set_xlabel(xlabel, fontsize=fontsize)
    ax.set_yticks(yticks, labels=ytick_labels)
    ax.set_ylabel(ylabel, fontsize=fontsize)
    if title is not None:
        ax.set_title(title, fontsize=fontsize, fontweight='bold')

    ax = style_axes(ax, fontsize=fontsize)

    return ax, prob_dens

def load_intervals(file_name, idx=None):
    """Load active intervals, inactive intervals, and periods from a .pkl file.

    Parameters
    ----------
    file_name : str
        Path to the pickle file. Only load files you trust: unpickling can
        execute arbitrary code.
    idx : optional
        When the file stores intervals for several models, the index (or key)
        of the model of interest. When None the file itself is the triple.

    Returns
    -------
    active_intervals, inactive_intervals, periods
    """
    with open(file_name, 'rb') as f:
        loaded = pickle.load(f)

    selected = loaded if idx is None else loaded[idx]
    active_intervals, inactive_intervals, periods = selected

    return active_intervals, inactive_intervals, periods
        
def plot_simulated_dists(axd, keys, data_dict, dict_keys, counter=0, titles=None, xlabel=None, ylabel=None, colors=None):
    """Draw active and inactive interval distributions for each model in ``dict_keys``.

    Each model uses two consecutive entries of ``keys`` (and of ``colors``,
    when given): first the active intervals, then the inactive intervals.

    Parameters
    ----------
    axd : mapping
        Maps panel keys to matplotlib Axes.
    keys : sequence
        Panel keys in plotting order, indexed by the running ``counter``.
    data_dict : mapping
        Maps each entry of ``dict_keys`` to ``(active, inactive, periods)``.
    dict_keys : sequence
        Models to plot, in order.
    counter : int
        Index into ``keys`` / ``colors`` of the first panel to use.
    titles : sequence of str, optional
        One title per model; it is applied to both of that model's panels.
    xlabel, ylabel : str, optional
        Forwarded to ``plot_interval_dists``.
    colors : sequence, optional
        One colour per panel; black when omitted.

    Returns
    -------
    int
        The counter after the last panel drawn.

    Notes
    -----
    Reads the notebook-level ``bins``.
    """
    for model_index, model_key in enumerate(dict_keys):
        title = None if titles is None else titles[model_index]
        active_intervals, inactive_intervals, _ = data_dict[model_key]

        # First the active-interval panel, then the inactive-interval panel.
        for panel_intervals in (active_intervals, inactive_intervals):
            ax = axd[keys[counter]]
            color = 'k' if colors is None else colors[counter]
            ax, prob_dens = plot_interval_dists(panel_intervals, bins, ax, xlabel=xlabel, ylabel=ylabel, title=title, color=color, yticks=(0, 0.2, 0.4))
            ax.set_xlim([0, 40])
            ax.set_ylim([0, 0.5])
            counter += 1

    return counter

# Plot all active and inactive interval distributions

In [230]:
"""plot grouped by active vs inactive interval"""
# 4 x 5 mosaic. `keys` below pairs each model's active panel (rows 1-2) with its
# inactive panel two rows further down (rows 3-4).
f, axd = plt.subplot_mosaic([['a', 'b', 'c', 'd', 'e'], 
                             ['f', 'g', 'h', 'i', 'j'],
                             ['k', 'l', 'm', 'n', 'o'], 
                             ['p', 'q', 'r', 's', 't']], figsize=(6.5, 8))
keys = ['a', 'k', 'b', 'l', 'c', 'm', 'd', 'n', 'e', 'o', 'f', 'p', 'g', 'q', 'h', 'r', 'i', 's', 'j', 't']
# One title per model, in the same order as dict_keys.
titles = ['deterministic txn.\nrate regulation', 
          'unregulated\nbursts', 
          'stochastic txn.\nrate regulation',
          'amplitude\nregulation (fast)',
          'frequency\nregulation', 
          'duration\nregulation', 
          'freq. + dur.\nregulation', 
          'amp. + freq.\nregulation',
          'amp. + dur.\nregulation',
          'amp. + freq. + dur.\nregulation']
dict_keys = ['dde', 
         'random', 
         'amp_no_bursts', 
         'amp_fast_bursts', 
         'freq', 
         'dur', 
         'freq_dur', 
         'amp_freq',
         'amp_dur',
         'amp_freq_dur']
# Alternate blue (active panel) and purple (inactive panel) for each model.
sim_colors = (colors['blue'], colors['purple'],) * 10
counter = 0
xticks = [0, 30, 60]  # NOTE(review): not used anywhere in this cell

# This file is written by the "assemble all the burst intervals" cell, which must be run first.
path_to_data_dict = r'/home/brandon/Documents/Code/zebrafish-ms2-paper/data/delay_sims/2025_09_05_all_intervals.pkl'
with open(path_to_data_dict, 'rb') as file:
    data_dict = pickle.load(file)
    
counter = plot_simulated_dists(axd, keys, data_dict, dict_keys, counter=0, xlabel='time (min)', titles=titles, colors=sim_colors)


# Only the left-most column gets y-axis labels and tick labels.
ylabel_keys = ['a', 'f', 'k', 'p']
for ykey in ylabel_keys:
    axd[ykey].set_ylabel('probability \ndensity (1/min)', fontsize=fontsize)
    axd[ykey].set_yticks((0, 0.2, 0.4), labels=(0, 0.2, 0.4))

plt.gcf().tight_layout(w_pad=-0.15)

In [231]:
# Save the current figure as a PDF.
plt.savefig(r'/home/brandon/Documents/Code/zebrafish-ms2-paper/figures/2025_09_05_all_simulated_interval_dists_delay_sims_colors.pdf')

In [229]:
"""assemble all the burst intervals into a new dictionary"""
# Collect the per-model interval pickles into one dictionary keyed by short model name.
# The plotting cell that loads 2025_09_05_all_intervals.pkl depends on this cell having been run.
save = True
sim_dir = r'/home/brandon/Documents/Code/zebrafish-ms2-paper/data/delay_sims'
# NOTE(review): amp_reg, freq_reg, dur_reg and freq_and_dur_reg files are not written by any
# cell visible in this notebook section; they must already exist on disk.
file_names = ['dde_intervals.pkl', 
         'no_reg_intervals.pkl', 
         'amp_reg_intervals.pkl', 
         'amp_reg_fast_bursting_intervals.pkl', 
         'freq_reg_intervals.pkl', 
         'dur_reg_intervals.pkl', 
         'freq_and_dur_reg_intervals.pkl', 
         'amp_freq_reg_intervals.pkl',
         'amp_and_dur_reg_intervals.pkl',
         'amp_freq_dur_reg_intervals.pkl']

# out_names[i] is the dictionary key for file_names[i]; the two lists must stay in the same order.
out_names = ['dde', 
         'random', 
         'amp_no_bursts', 
         'amp_fast_bursts', 
         'freq', 
         'dur', 
         'freq_dur', 
         'amp_freq',
         'amp_dur',
         'amp_freq_dur']

out_dict = {}
for i, name in enumerate(file_names):
    with open(sim_dir + '/' + name, 'rb') as file:
        intervals = pickle.load(file)
    out_dict[out_names[i]] = intervals

if save:
    with open(r'/home/brandon/Documents/Code/zebrafish-ms2-paper/data/delay_sims/2025_09_05_all_intervals.pkl', 'wb') as file:
        pickle.dump(out_dict, file)