In [1]:
import matplotlib.pyplot as plt
import h5py
import numpy as np
import pandas as pd

import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
import pymc as pm
import arviz as az
import pytensor as pt
import pytensor.tensor as ptt

import icomo
from icomo import jax2pytensor

# Global matplotlib default: 8 pt font for every figure in this notebook
plt.rcParams['font.size'] = 8


In [2]:
def bootstrap_confidence_interval(data, n=1000, func=np.mean, alpha=0.05):
    """Asymmetric bootstrap error bars for ``func`` of a pandas Series.

    NaNs are dropped first. ``n`` resamples (with replacement) of the remaining
    values are drawn with the global ``np.random`` state; the returned pair is the
    distance from ``func(data)`` down to the lower and up to the upper
    ``alpha/2`` percentile of the bootstrap statistics, i.e. a ``[lower, upper]``
    yerr entry for matplotlib.
    """
    clean = data.dropna()
    boot_samples = np.random.choice(clean, size=(n, len(clean)), replace=True)
    boot_stats = [func(sample) for sample in boot_samples]
    lower, upper = np.percentile(boot_stats, [100*alpha/2, 100*(1-alpha/2)])
    point = func(clean)
    return [point - lower, upper - point]

def percentile_confidence_interval(data, alpha=0.05):
    """Asymmetric error bars from the empirical percentiles of a sample Series.

    Returns ``[mean - lower, upper - mean]`` where lower/upper are the
    ``alpha/2`` and ``1 - alpha/2`` percentiles of the non-NaN values and the
    mean is the pandas (NaN-skipping) mean.
    """
    center = data.mean()
    low, high = np.percentile(data.dropna(), [100*alpha/2, 100*(1-alpha/2)])
    return [center - low, high - center]

# System definition

In [None]:
# Simulation resolution: points (time steps) per second and the step length [s]
pps = 30
dt = 1 / pps

# Standard parameters (from model fit).
# Activation: a(t) = b - sum_k kappa(t - t_k) * r_k over past releases r_k, with
# kappa(t) = kappa0 * t^-alpha / pps * exp(-t / cutoff)   (see create_kernel)
_b = -0.7
_alpha = 0.3
_kappa0 = 0.1
_cutoff = 120.0  # [s]

def create_kernel(alpha, kappa0, cutoff, pps, length_factor=4):
    """Build the discretised power-law recovery kernel with exponential cutoff.

    kappa(t) = kappa0 * t**(-alpha) / pps * exp(-t / cutoff), sampled on the lags
    t = dt, 2*dt, ... (< length_factor * cutoff) with dt = 1 / pps.

    Parameters
    ----------
    alpha : float
        Power-law exponent.
    kappa0 : float
        Kernel amplitude.
    cutoff : float
        Time constant [s] of the exponential cutoff.
    pps : float
        Points (time steps) per second.
    length_factor : float, optional
        Kernel support in multiples of ``cutoff`` (default 4). Note that
        ``ptt_create_kernel`` truncates at 1 * cutoff.

    Returns
    -------
    kappa : np.ndarray
        Kernel values; index 0 is the lag of one time step.
    t : np.ndarray
        Lags [s] at which the kernel is evaluated.
    """
    dt = 1.0 / pps
    # start at dt, not 0: t**(-alpha) diverges at t = 0
    t = np.arange(dt, cutoff * length_factor, dt)
    kappa = kappa0 * np.power(t, -alpha) / pps
    kappa *= np.exp(-t / cutoff)
    return kappa, t

def _kappa(dt, pps):
    """Pure power-law kernel (no exponential cutoff) at lag ``dt`` [s].

    Uses the module-level defaults ``_kappa0`` and ``_alpha``; works element-wise
    on arrays as well as on scalars.
    """
    power_law = np.power(dt, -_alpha)
    return _kappa0 * power_law / pps

# Compare the truncated kernel with the pure power law on log-log axes
kappa, t = create_kernel(_alpha, _kappa0, _cutoff, pps)

fig, ax = plt.subplots(figsize=(6.5, 3.5))
ax.set_xscale('log')
ax.set_yscale('log')
dts = np.logspace(-2, 3, 100)
# _kappa is element-wise, so evaluate the whole lag grid at once;
# '{:.2f}' = two decimals in the legend label
ax.plot(dts, _kappa(dts, pps), label=r'${:.2f} t^{{-{:.2f}}}$'.format(_kappa0, _alpha), color='blue')
ax.plot(t, kappa, label='kernel', color='red')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Kappa')
ax.legend()
plt.tight_layout()
plt.show()

In [4]:
# Discrete-time release model (numpy reference implementation)
def get_release_rate(ts, args, points_per_second=None):
    """Simulate the power-law GLM release rate step by step.

    At step i the activation is a_i = B - sum_k kappa_k * r_{i-k} (k >= 1), the
    release probability is p = exp(a) / (1 + exp(a)) and the release is
    r_i = (1 - (1 - p)**4) * stim_rate[i] (written out as a polynomial below).

    Parameters
    ----------
    ts : array-like
        Time points; only ``len(ts)`` is used (one simulation step per entry).
    args : tuple
        ``(t_args, const_args)``: ``t_args`` is the stimulation rate per step
        (same length as ``ts``); ``const_args`` holds "B", "ALPHA", "KAPPA0", "CUTOFF".
    points_per_second : float, optional
        Resolution used to build the kernel; defaults to the module-level ``pps``.

    Returns
    -------
    releases : np.ndarray
        Release rate per step, shape ``(len(ts),)``.
    activations : np.ndarray
        Activation ``a`` per step, shape ``(len(ts),)``.
    """
    if points_per_second is None:
        points_per_second = pps

    t_args, const_args = args

    B = const_args["B"]
    ALPHA = const_args["ALPHA"]
    KAPPA0 = const_args["KAPPA0"]
    CUTOFF = const_args["CUTOFF"]

    kappa, _ = create_kernel(ALPHA, KAPPA0, CUTOFF, points_per_second)
    # flip so the newest release (end of the history window) meets the shortest lag
    kappa = np.flip(kappa)
    k_length = len(kappa)

    # convert once: element-wise indexing of a jax array inside the loop is very slow
    stim_rate = np.asarray(t_args)

    n_steps = len(ts)
    # the first k_length entries are zero padding (no releases before the start)
    releases = np.zeros(n_steps + k_length)
    activations = np.zeros(n_steps)
    for ind_t in range(n_steps):
        # recovery: kernel-weighted sum over the last k_length releases
        recovery = np.dot(releases[ind_t:ind_t + k_length], kappa)

        # release
        a = B - recovery
        activations[ind_t] = a

        p = np.exp(a) / (1 + np.exp(a))
        # equals 1 - (1 - p)**4; the expanded form stays accurate for very small p
        release_rate = - p**4 + 4*p**3 - 6*p**2 + 4*p
        releases[ind_t + k_length] = release_rate * stim_rate[ind_t]
    return releases[k_length:], activations

# Depression test

In [None]:
len_sim = 10 * 60 # s
# +1 so consecutive time points are exactly dt = 1/pps apart, matching the
# one-step-per-point simulation (same convention as run_experiment1/2)
num_points = int(len_sim * pps) + 1

### First set the time variables
t = np.linspace(0, len_sim, num_points) # timepoints of the simulation steps
t_stim = t # timepoints at which the stimulus is defined
t_out = t # timepoints at which the output is saved

### Set parameters: constant 50 Hz stimulation for the whole simulation
stim = 50 * np.ones(len(t_stim))
t_args = stim

const_args = {
    "ALPHA": _alpha,
    "B": _b,
    "KAPPA0": _kappa0,
    "CUTOFF": _cutoff,
}

# get the release rate
release_rate, activations = get_release_rate(t, (t_args, const_args))
print(release_rate)

In [None]:
# Compare the continuous power-law GLM with the discrete model (50 Hz depression test)
f = plt.figure(figsize=(3*2.5,1.8))
axs = f.subplots(1,3)

# load the discrete-model results; the context manager closes the HDF5 file again
filename = "../fig2/depression_test_50Hz.h5"
with h5py.File(filename, 'r') as dic:
    ves_per_spike = dic.get('ves_per_spike')[()]
    ps = dic.get('ps')[()]
    sizes = dic.get('sizes')[()]
    test_times = dic.get('test_times')[()]
ves_per_spike_mean = np.mean(ves_per_spike, axis=1)

# the stored parameter grids are floats: match with isclose rather than exact ==
i_p = np.flatnonzero(np.isclose(ps, 0.2))[0]
i_size = np.flatnonzero(np.isclose(sizes, 1.0))[0]
sim_data = ves_per_spike_mean[:,i_p,i_size]

# GLM: vesicles per spike = release rate / stimulation rate
glm_ves_per_spike = release_rate / stim

# compare discrete and continuous release (log-log)
axs[0].plot(test_times, sim_data,
    linestyle='None', marker='o', zorder=2, markersize=2.0, label="discrete model",
    color='cornflowerblue')
axs[0].plot(t_out, glm_ves_per_spike, label="continuous model")
axs[0].set_yscale("log")
axs[0].set_xscale("log")
axs[0].set_xlabel("Time [s]")
axs[0].set_ylabel("Ves. per spike")
axs[0].spines['top'].set_visible(False)
axs[0].spines['right'].set_visible(False)

# compare discrete and continuous release, linear scale, first 50 s
axs[1].plot(test_times, sim_data,
    linestyle='None', marker='o', zorder=2, markersize=2.0, label="full model",
    color='cornflowerblue')
inds = np.where(t_out <= 200)[0]
axs[1].plot(t_out[inds], glm_ves_per_spike[inds], label="power-law GLM")
axs[1].legend(frameon=False)
axs[1].set_xlabel("Time [s]")
axs[1].set_ylabel("Ves. per spike")
axs[1].set_xlim([-2,50])
axs[1].spines['top'].set_visible(False)
axs[1].spines['right'].set_visible(False)

# show the GLM activation (b - recovery) over time
axs[2].plot(t_out, activations)
axs[2].set_xlabel("Time [s]")
axs[2].set_ylabel("Activation")
axs[2].set_xlim([-2,50])
axs[2].spines['top'].set_visible(False)
axs[2].spines['right'].set_visible(False)

f.tight_layout()
f.savefig("discrete_GLM_comparison.pdf")

# Fit to experiment

In [None]:
# get experimental data
# Experiment 1: one column per condition, named "<depletiontime>_<pausetime>" (the
# model cell looks columns up by that name); columns may contain NaNs -- TODO confirm
excel_file = "exp1_data.xlsx" # change path to the location of your file
df1 = pd.read_excel(excel_file)

# load exhaustion experiment data from xlsx
# presumably column 0 is the "seconds" time axis and every other column one recording,
# so data2 has shape (n_recordings, n_timepoints) -- verify against the sheet
excel_file = "exp2_data.xlsx" 
df2 = pd.read_excel(excel_file)
data2 = df2[df2.columns[1:]].values.T
time = df2["seconds"].values
print(data2.shape)
print(time.shape)
# finite-difference fusion rate per recording, padded with a trailing zero so it keeps len(time) columns
experimental_fusion_rate = np.diff(data2) / np.diff(time)
experimental_fusion_rate = np.concatenate([experimental_fusion_rate, np.zeros((experimental_fusion_rate.shape[0],1))], axis=1)
print(experimental_fusion_rate.shape)

# stim starts at 32.35 s
# keep samples after stimulus onset and shift time so the first kept sample is t = 0.
# NOTE: the global `time` is read by run_experiment2 to decide how many points to output.
stim_start = 32.35
inds = time > stim_start
time = time[inds] - time[inds][0]
data2 = data2[:,inds]
# start at 0 not 1: subtract each recording's first kept value (baseline)
data2 = data2 - data2[:,[0]]
experimental_fusion_rate = experimental_fusion_rate[:,inds]
print(data2.shape)
print(time.shape)

# mean and standard error of the mean (np.std with ddof=0) across recordings, per time point
data2_means = np.mean(data2, axis=0)
data2_errs = np.std(data2, axis=0) / np.sqrt(data2.shape[0])

In [8]:
def ptt_create_kernel(alpha, kappa0, cutoff, pps):
    """PyTensor counterpart of ``create_kernel`` used inside the PyMC model.

    Returns only the kernel values
    kappa(t) = kappa0 * t**(-alpha) / pps * exp(-t / cutoff)
    on the lags t = dt, 2*dt, ... < cutoff (dt = 1 / pps). For efficiency the
    kernel is truncated at ``cutoff`` (``create_kernel`` extends to 4 * cutoff by
    default); the exponential cutoff factor is still applied. ``kappa0`` may be a
    PyMC random variable, in which case the result is symbolic.
    """
    lags = np.arange(1.0 / pps, cutoff, 1.0 / pps)
    return kappa0 * ptt.power(lags, -alpha) / pps * np.exp(-lags / cutoff)

# Symbolic (PyTensor) version of get_release_rate, used inside the PyMC model
def get_release_rate_pymc(pps, args):
    """Build the release-rate time series as a PyTensor ``scan`` graph.

    Parameters
    ----------
    pps : float
        Points (time steps) per second of the simulation.
    args : tuple
        ``(t_args, const_args)``: per-step stimulation rate and a dict with
        "B", "ALPHA", "KAPPA0" and "CUTOFF". "CUTOFF" must be a plain number
        because it fixes the history-buffer length via ``np.arange``.

    Returns
    -------
    Tensor with one release value per entry of ``t_args``.
    """
    stim_rate, const_args = args
    cutoff = const_args["CUTOFF"]
    kernel = ptt_create_kernel(const_args["ALPHA"], const_args["KAPPA0"], cutoff, pps)
    step_size = 1.0 / pps
    # same lag grid as ptt_create_kernel, evaluated numerically for a static buffer length
    history_length = len(np.arange(step_size, cutoff, step_size))

    def step(rate_now, history, kernel, b):
        # history[0] is the newest release, so it meets kernel[0] (lag of one step)
        activation = b - ptt.sum(history * kernel)
        p = ptt.exp(activation) / (1 + ptt.exp(activation))
        per_spike = - p**4 + 4*p**3 - 6*p**2 + 4*p  # = 1 - (1 - p)**4
        release = per_spike * rate_now
        # shift the history by one step and put the new release in front
        shifted = ptt.roll(history, 1)
        return ptt.set_subtensor(shifted[0], release)

    history_over_time, _ = pt.scan(
        fn=step,
        sequences=[stim_rate],
        non_sequences=[kernel, const_args["B"]],
        outputs_info=ptt.zeros(history_length),
    )
    # entry 0 of each stored history is the release produced at that step
    return history_over_time[:, 0]

def run_experiment1(stim_rate, depletiontime, pausetime, testtime, const_args, pps=10):
    """Simulate the depletion / pause / test protocol of experiment 1.

    Stimulation at ``stim_rate`` is on for t < depletiontime, off during the
    pause, and on again for t > depletiontime + pausetime.

    Parameters
    ----------
    stim_rate : float
        Stimulation rate [Hz] while stimulation is on.
    depletiontime, pausetime, testtime : float
        Durations [s] of the three protocol phases.
    const_args : dict
        Model parameters ("B", "ALPHA", "KAPPA0", "CUTOFF").
    pps : float, optional
        Simulation points per second (default 10).

    Returns
    -------
    Release-rate tensor restricted to the points strictly inside the test window.
    """
    # Grid points that lie exactly on a protocol boundary come out of linspace with
    # floating-point noise; this tolerance makes the strict inequalities below
    # classify them consistently (as "not inside") instead of by round-off.
    tol = 1e-6 / pps

    len_sim = depletiontime + pausetime + testtime
    # guard int() truncation when len_sim * pps is an integer up to round-off
    num_points = int(np.floor(len_sim * pps + 1e-6)) + 1

    ### First set the time variables
    t = np.linspace(0, len_sim, num_points)
    test_start = depletiontime + pausetime
    test_end = test_start + testtime
    t_out_inds = (t > test_start + tol) & (t < test_end - tol)

    # define stimulus: on during depletion and during the test phase
    stim = ((t < depletiontime - tol) | (t > test_start + tol)) * stim_rate

    # get the release rate
    release_rate = get_release_rate_pymc(pps, (stim, const_args))
    return release_rate[t_out_inds]

def run_experiment2(stim_rate, const_args, t_max=None):
    """Simulate experiment 2 (constant stimulation), sampled like the recordings.

    The model runs at pps = 10 * 1.7 = 17 points per second and every 10th point
    is kept, i.e. output at 1.7 Hz (matching ``pps_experiment = 1.7`` in the model
    cell). Output points later than ``t_max`` are dropped so the result lines up
    with the experimental time axis.

    Parameters
    ----------
    stim_rate : float
        Stimulation rate [Hz], applied for the whole simulation.
    const_args : dict
        Model parameters ("B", "ALPHA", "KAPPA0", "CUTOFF").
    t_max : float, optional
        Last time point [s] to keep. Defaults to ``time[-1]`` of the module-level
        experiment-2 time axis from the data-loading cell.

    Returns
    -------
    Tensor of release rates at the kept time points.
    """
    if t_max is None:
        t_max = time[-1]

    len_sim = 70 # s
    pps_factor = 10
    pps = pps_factor*1.7 # simulation points per second; keeping every pps_factor-th point gives 1.7 Hz
    # tolerance against floating-point noise in len_sim * pps and in linspace
    tol = 1e-6 / pps
    num_points = int(np.floor(len_sim * pps + 1e-6)) + 1

    ### First set the time variables
    t = np.linspace(0, len_sim, num_points)
    # the tolerance keeps a point that equals t_max up to round-off, so the output
    # length does not silently drop below the number of experimental samples
    t_out_inds = (np.arange(len(t)) % pps_factor == 0) & (t <= t_max + tol)

    # define stimulus
    stim = stim_rate * np.ones(len(t))

    # get the release rate
    release_rate = get_release_rate_pymc(pps, (stim, const_args))
    return release_rate[t_out_inds]

def get_lognormal_params(mean, std):
    """Convert a desired mean and standard deviation into LogNormal (mu, sigma).

    Inverts mean = exp(mu + sigma**2 / 2) and
    var = (exp(sigma**2) - 1) * exp(2 * mu + sigma**2).
    Works element-wise on arrays; in the model cell it is also called with a
    PyMC variable as ``mean``/``std``.
    """
    relative_variance = std**2 / mean**2
    sigma = np.sqrt(np.log(relative_variance + 1))
    mu = np.log(mean) - sigma**2 * 0.5
    return mu, sigma

def fm(x):
    """Format a protocol time for condition names: 0.4 -> 0.4, 4.0 -> 4 (so (4.0, 10.0) -> "4_10")."""
    return round(x, 1) if x % 1 else int(x)


with pm.Model() as model:

    # parameters kept fixed during the fit
    alpha = 0.3
    cutoff = _cutoff

    # informative priors for better convergence
    b1 = pm.Normal("b1", mu=20, sigma=20)
    mu, sigma = get_lognormal_params(12, 12)
    kappa01 = pm.LogNormal("kappa01", mu=mu, sigma=sigma)
    observation_factor1 = pm.HalfCauchy("observation_factor1", beta=0.01)

    b2 = pm.Normal("b2", mu=-1.5, sigma=3)
    mu, sigma = get_lognormal_params(3e-08, 3e-08)
    kappa02 = pm.LogNormal("kappa02", mu=mu, sigma=sigma)
    observation_factor2 = pm.HalfCauchy("observation_factor2", beta=0.1)

    # shared error scale; the per-experiment errors have mean error_model and std error_model/4
    error_model = pm.HalfCauchy("error_model", beta=0.2)
    mu, sigma = get_lognormal_params(error_model, error_model/4)
    error_model1 = pm.LogNormal("error_model1", mu=mu, sigma=sigma)
    error_model2 = pm.LogNormal("error_model2", mu=mu, sigma=sigma)

    ###### experiment 1: depletion / pause / test protocol

    const_args_var1 = {
        "B": b1,
        "ALPHA": alpha,
        "KAPPA0": kappa01,
        "CUTOFF": cutoff,
    }

    testtime = 2.0 # s
    depletiontimes = [0.4, 4.0, 40.0]
    pausetimes = [1.0, 10.0, 40.0, 100.0]
    stim_rate = 20.0 # Hz
    conditions = [(depletiontime, pausetime) for pausetime in pausetimes for depletiontime in depletiontimes]
    conditions.append((40.0, 200.0))

    sim = []
    sim_error = []
    obs = []

    for condition in conditions:
        depletiontime, pausetime = condition
        # get condition name
        condition_name = "{}_{}".format(fm(depletiontime), fm(pausetime))

        release_rate = run_experiment1(stim_rate, depletiontime, pausetime, testtime, const_args_var1)
        released = pm.math.mean(release_rate)

        # save sim data under condition name
        pm.Deterministic(condition_name, observation_factor1 * released)

        # Non-NaN measurements only: np.std / np.mean already skip NaNs in a Series,
        # but len() of the raw column would also count empty cells and make the
        # standard error of the mean too small.
        data = df1[condition_name].dropna()
        error_of_the_mean = np.std(data) / np.sqrt(len(data))
        data_mean = np.mean(data)
        sigma_error = pm.Deterministic("scaled_sigma_error_" + condition_name, error_model1 + error_of_the_mean)

        sim.append(observation_factor1 * released)
        sim_error.append(sigma_error)
        obs.append(data_mean)

    sim = pm.math.stack(sim)
    sim_error = pm.math.stack(sim_error)

    mu1 = sim
    sigma1 = sim_error
    observed1 = obs
    # number of experiment-1 observations; used to split the joint likelihood below
    n_exp1 = len(obs)

    ###### experiment 2: vesicle tagging under constant stimulation

    const_args_var2 = {
        "B": b2,
        "ALPHA": alpha,
        "KAPPA0": kappa02,
        "CUTOFF": cutoff,
    }

    stim_rate = 5.0 # Hz
    max_vesicles = 100

    release_rate = run_experiment2(stim_rate, const_args_var2)
    pps_experiment = 1.7
    fused = release_rate / pps_experiment
    # define ft = 1 - #tagged vesicles / #total vesicles
    inv_rel_fraction = 1 - fused / max_vesicles
    # then we can compute with cumprod the fraction of tagged vesicles
    ft = pm.math.cumprod(inv_rel_fraction)
    tagged = max_vesicles * (1 - ft)
    tagged = tagged - tagged[0]

    out = pm.Deterministic("experiment2", observation_factor2 * tagged)
    sim2_errors = pm.Deterministic("scaled_sigma_error_experiment2", error_model2 + data2_errs)

    mu2 = out
    sigma2 = sim2_errors
    observed2 = data2_means

    observed = np.concatenate([observed1, observed2])
    mu = ptt.concatenate([mu1, mu2])
    sigma = ptt.concatenate([sigma1, sigma2])

    def logp(observed, mu, sigma):
        """Element-wise Normal log-likelihood, evaluated separately for both experiments."""
        n = n_exp1

        observed1 = observed[:n]
        observed2 = observed[n:]
        mu1 = mu[:n]
        mu2 = mu[n:]
        sigma1 = sigma[:n]
        sigma2 = sigma[n:]
        p1 = pm.logp(pm.Normal.dist(mu1, sigma1), observed1)
        p2 = pm.logp(pm.Normal.dist(mu2, sigma2), observed2)
        return ptt.concatenate([p1, p2])

    pm.CustomDist("likelihood",
                    mu, sigma,
                    logp = logp,
                    observed = observed)
    
    

In [None]:
# Maximum a posteriori point estimate, used by the two MAP plotting cells below
with model:
    map_estimate = pm.find_MAP()
print(map_estimate)

In [None]:
# MAP prediction vs. experiment 1: one model point per condition on top of the data
condition_names_exp1 = ["{}_{}".format(fm(depletiontime), fm(pausetime)) for depletiontime, pausetime in conditions]
sim_df = pd.DataFrame({name: np.array([map_estimate[name]]) for name in condition_names_exp1})

fig = plt.figure(figsize=(3.5,2.8))
ax = fig.subplots(1,1)

# remove the right and top and bottom spines
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['bottom'].set_visible(False)

# experimental mean with bootstrap 95% CI, drawn once as a wide translucent bar and once as caps
errs_df = df1.apply(lambda x: bootstrap_confidence_interval(x), axis=0)
ax.errorbar(df1.columns, df1.mean(), yerr=errs_df, linestyle='None', marker='None', markersize=2, elinewidth=9, color='black', alpha=0.2)
ax.errorbar(df1.columns, df1.mean(), yerr=errs_df, linestyle='None', marker='_', markersize=9, elinewidth=0, color='black', alpha=0.2)

# the MAP is a single point estimate, so the model gets no error bars
ax.plot(sim_df.columns, sim_df.mean(), linestyle='None', marker='o', color='cornflowerblue')

ax.tick_params(axis='x', labelrotation=45)
ax.set_ylabel("change in dF/F")
fig.tight_layout()
fig.savefig("posterior_full_comparison.pdf", bbox_inches='tight')

In [None]:
# MAP prediction vs. experiment 2 (exhaustion / tagging time course)
simulation_results = np.array([map_estimate["experiment2"]])
simulation_results_mean = np.mean(simulation_results, axis=0)

# 95% band of the experimental mean (normal approximation; data2_errs is the SEM)
experiment_conf = data2_means + 1.96 * np.array([-data2_errs, data2_errs])

fig, axs = plt.subplots(1,1, figsize=(2.5,2))
axs.plot(time, data2_means, label="Experimental data", alpha=0.2, color='black')
axs.fill_between(time, experiment_conf[0], experiment_conf[1], alpha=0.2, color='black')
axs.plot(time, simulation_results_mean, label="Model prediction", color='cornflowerblue')
axs.set_xlabel("Time [s]")
axs.set_ylabel("Fluorescence increase")
axs.spines['top'].set_visible(False)
axs.spines['right'].set_visible(False)
axs.legend()

plt.tight_layout()
plt.savefig("exhaustion_comparison.pdf", bbox_inches='tight')
plt.show()

In [None]:
# sample the model with NUTS (numpyro backend) and store the trace on disk
SAMPLING_SEED = 1234  # fixed seed: re-running the notebook reproduces the same draws
trace = pm.sample(
    model=model,
    tune=400,
    draws=1000,
    cores=2,
    nuts_sampler_kwargs={"nuts_kwargs": {"max_tree_depth": 6}},
    nuts_sampler="numpyro",
    target_accept=0.8,
    random_seed=SAMPLING_SEED,
)

# save the trace
az.to_netcdf(trace, "trace_powerlaw.nc")


In [None]:
# load the trace written by the sampling cell
trace = az.from_netcdf("trace_powerlaw.nc")
# (to keep all chains, skip the .sel below)
# NOTE(review): only chain index 1 is kept for every downstream plot and summary;
# with a single chain, R-hat can only compare halves of that chain, not chains
# against each other -- confirm that dropping the other chain(s) is intended.
trace = trace.sel(chain=[1])
az.rhat(trace)

In [None]:
# Posterior predictions vs. experiment 1: model mean with 95% percentile interval per condition
condition_names_exp1 = ["{}_{}".format(fm(depletiontime), fm(pausetime)) for depletiontime, pausetime in conditions]
sim_df = pd.DataFrame({name: trace.posterior[name].to_numpy().flatten() for name in condition_names_exp1})

fig = plt.figure(figsize=(3.5,2.8))
ax = fig.subplots(1,1)

# remove the right and top and bottom spines
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['bottom'].set_visible(False)

# experimental mean with bootstrap 95% CI, drawn once as a wide translucent bar and once as caps
errs_df = df1.apply(lambda x: bootstrap_confidence_interval(x), axis=0)
ax.errorbar(df1.columns, df1.mean(), yerr=errs_df, linestyle='None', marker='None', markersize=2, elinewidth=9, color='black', alpha=0.2)
ax.errorbar(df1.columns, df1.mean(), yerr=errs_df, linestyle='None', marker='_', markersize=9, elinewidth=0, color='black', alpha=0.2)

# posterior mean with 2.5-97.5 percentile interval of the posterior draws
errs_sim_df = sim_df.apply(lambda x: percentile_confidence_interval(x), axis=0)
ax.errorbar(sim_df.columns, sim_df.mean(), yerr=errs_sim_df, linestyle='None', marker='o', color='cornflowerblue')

ax.tick_params(axis='x', labelrotation=45)
ax.set_ylabel("change in dF/F")
fig.tight_layout()
fig.savefig("posterior_GLM_exp1.pdf", bbox_inches='tight')


In [None]:
# Posterior prediction band vs. experiment 2 (exhaustion / tagging time course)
simulation_results = trace.posterior["experiment2"].to_numpy().reshape(-1, data2.shape[1])

simulation_results_mean = np.mean(simulation_results, axis=0)
simulation_results_conf = np.percentile(simulation_results, [2.5, 97.5], axis=0)

# 95% band of the experimental mean (normal approximation; data2_errs is the SEM)
experiment_conf = data2_means + 1.96 * np.array([-data2_errs, data2_errs])

fig, axs = plt.subplots(1,1, figsize=(2.5,2))
axs.plot(time, data2_means, label="Experimental data", alpha=0.2, color='black')
axs.fill_between(time, experiment_conf[0], experiment_conf[1], alpha=0.2, color='black')
axs.plot(time, simulation_results_mean, label="Model prediction", color='cornflowerblue')
axs.fill_between(time, simulation_results_conf[0], simulation_results_conf[1], alpha=0.2, color='cornflowerblue')
axs.set_xlabel("Time [s]")
axs.set_ylabel("Fluorescence increase")
axs.spines['top'].set_visible(False)
axs.spines['right'].set_visible(False)

plt.tight_layout()
plt.savefig("posterior_GLM_exp2.pdf", bbox_inches='tight')
plt.show()

In [None]:
def plot_cont(prior, ax=None):
    """Overlay the prior density of a continuous model variable on ``ax``.

    The x-range spans min..max of 1000 prior draws; the density exp(logp) is
    evaluated on 1000 points of that range and drawn in gray.

    Parameters
    ----------
    prior : model random variable (e.g. ``model["b1"]``)
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure is created when None.

    Returns
    -------
    The Axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()
    samples = pm.draw(prior, draws=1000)
    x = np.linspace(np.min(samples), np.max(samples), 1000)
    # Evaluate the symbolic log-density first, then exponentiate the numeric result.
    # Applying np.exp to the symbolic graph is not guaranteed to return an object
    # with .eval(), which previously made this fail (silently, via the caller's except).
    density = np.exp(pm.logp(prior, x).eval())
    ax.plot(x, density, color='gray')
    return ax

# Deterministic outputs (per-condition predictions, experiment2 and their
# scaled_sigma_error_* companions) are excluded from the parameter grid below
condition_names = ["{}_{}".format(fm(condition[0]), fm(condition[1])) for condition in conditions] + ["experiment2"]

# LaTeX titles; kappa01/kappa02 are sampled and plotted on their natural (linear) scale
name_dict = {
    "alpha1": r"$\alpha$",
    "b1": r"$b^{(1)}$",
    "b2": r"$b^{(2)}$",
    "kappa01": r"$\kappa_0^{(1)}$",
    "kappa02": r"$\kappa_0^{(2)}$",
    "observation_factor1": r"$\alpha_{obs}^{(1)}$",
    "error_model": r"$\sigma_{obs}$",
    "error_model1": r"$\sigma_{obs}^{(1)}$",
    "error_model2": r"$\sigma_{obs}^{(2)}$",
    "observation_factor2": r"$\alpha_{obs}^{(2)}$",
}

# for each free parameter plot the posterior (blue) and the prior density (gray)
names = [var.name for var in model.unobserved_RVs]
names = [name for name in names if not any(condition_name in name for condition_name in condition_names)]

sidelength = int(np.ceil(np.sqrt(len(names))))
# squeeze=False: always a 2-D array of Axes, even for a 1x1 grid
fig, axs = plt.subplots(sidelength, sidelength, figsize=(sidelength*2.5, sidelength*2), squeeze=False)
axs = axs.flatten()

for i, name in enumerate(names):
    ax = axs[i]

    pm.plot_posterior(trace, var_names=name, color='cornflowerblue', ax=ax)
    # fall back to the raw variable name for parameters without a LaTeX label
    ax.set_title(name_dict.get(name, name), fontdict={'fontsize': 15})
    # plot prior (best effort: report failures instead of silently hiding them)
    try:
        plot_cont(model[name], ax=ax)
    except Exception as err:
        print(f"could not plot prior of {name!r}: {err}")
    # scale to posterior
    ax.set_xlim((np.min(trace.posterior[name]), np.max(trace.posterior[name])))

# hide grid cells without a parameter
for ax in axs[len(names):]:
    ax.set_visible(False)

plt.tight_layout()
plt.savefig("posterior_GLM.pdf", bbox_inches='tight')


In [None]:
# Tabular posterior summary (means, sds, HDIs, ESS, R-hat) of every variable in the trace
print(az.summary(data=trace))

In [None]:
# add the pointwise log-likelihood group (used e.g. by az.loo / az.waic) if it is missing
with model:
    if "log_likelihood" not in trace.groups():
        # extend_inferencedata=True (the default) adds the group to `trace` in place
        pm.compute_log_likelihood(trace)

In [None]:
# persist the (chain-subset) trace including its log_likelihood group
az.to_netcdf(trace, "trace_powerlaw_with_likelihood.nc")