In [None]:
import re #regex
import pickle
from multiprocessing import Pool, cpu_count
from git import Repo #for directory convenience

import numpy as np
from scipy.stats import nbinom as neg_binom
from mpmath import hyp2f1
from mpmath import ln as arbprec_ln
from scipy.special import gammaln
import pandas as pd

import emcee
import arviz as az

import matplotlib.pyplot as plt
import bokeh
import bokeh.io
import datashader as ds
# import datashader.bokeh_ext
import bebi103.viz

from srep.data_loader import load_FISH_by_promoter

# Configure bokeh so that figures passed to bokeh.io.show() render inline
# in the notebook output cells.
bokeh.io.output_notebook()

In [None]:
# Find the enclosing git repository, starting at the current working
# directory and searching upward through parent directories.
repo = Repo("./", search_parent_directories=True)
# repo_rootdir holds the absolute path to the top-level of our repo
repo_rootdir = repo.working_tree_dir

In [None]:
# first load data using module util
# (one DataFrame is unpacked per requested group name; presumably "unreg"
# and "reg" mean unregulated / regulated promoters -- confirm against
# srep.data_loader)
df_unreg, df_reg = load_FISH_by_promoter(("unreg", "reg"))
# pull out one specific promoter for convenience for prior pred check & SBC
# Rows are selected by their `experiment` label. Note the label is spelled
# "O2_1ngmL" (capital L) while the variable name uses a lower-case l.
df_UV5 = df_unreg[df_unreg["experiment"] == "UV5"]
df_O2_1ngml = df_reg[df_reg["experiment"] == "O2_1ngmL"]


In [None]:
def _log_hyp2f1(a, b, c, z):
    """Return ln(2F1(a, b; c; z)) for one set of scalar arguments.

    Both the hypergeometric function and the logarithm are evaluated with
    mpmath (`hyp2f1` and `ln`, the latter imported as `arbprec_ln`).
    """
    return arbprec_ln(hyp2f1(a, b, c, z))


# Wrap the scalar helper in a numpy ufunc (4 inputs, 1 output) so it can be
# called elementwise on arrays. np.frompyfunc ufuncs return object arrays,
# so callers cast the result back to float.
np_log_hyp = np.frompyfunc(_log_hyp2f1, 4, 1)

def log_p_m_bursty_rep(mRNA, k_burst, mean_burst, kR_on, kR_off):
    """Log probability of the given mRNA counts for the bursty promoter
    with repression.

    S&S's alpha & beta are strictly positive, so every factor in their
    SI eq 75 & 76 must be >= 0, including 2F1; the log of the whole
    expression can therefore be taken without tracking signs. Some args of
    2F1 _might_ be negative, but none of the args of any of the Gamma
    functions can possibly be < 0.

    This function can't handle repressor rates = 0; code that case
    separately!

    Parameters
    ----------
    mRNA : array of mRNA counts at which to evaluate the distribution
    k_burst : burst rate
    mean_burst : mean burst size
    kR_on, kR_off : repressor on / off rates

    Returns
    -------
    The log probabilities, cast to float with `.astype(float)`.
    """
    # Effective parameters a, b, c of the 2F1 generating function;
    # recall the gen fcn is 2F1(a, b, c, b(z-1)).
    total_rate = k_burst + kR_off + kR_on
    root = np.sqrt(total_rate**2 - 4*k_burst*kR_off)
    a = (total_rate + root) / 2.0
    b = (total_rate - root) / 2.0
    c = kR_on + kR_off

    # Accumulate the log prefactor term by term (log-gamma ratios, the
    # factorial, and the mean_burst power), then add the log of 2F1.
    log_prob = gammaln(a + mRNA) - gammaln(a)
    log_prob = log_prob + gammaln(b + mRNA) - gammaln(b)
    log_prob = log_prob - gammaln(c + mRNA) + gammaln(c)
    log_prob = log_prob - gammaln(1 + mRNA) + mRNA * np.log(mean_burst)
    log_prob = log_prob + np_log_hyp(a + mRNA, b + mRNA, c + mRNA, -mean_burst)
    # np_log_hyp yields an object array; convert back to plain floats.
    return log_prob.astype(float)

In [None]:
# %%timeit
# Quick check of the distribution: probabilities for mRNA counts 0..39 at one
# hand-picked parameter set (k_burst=5.2, mean_burst=3.6, kR_on=1.5, kR_off=1).
# Uncomment the magic on the first line to time the evaluation instead.
np.exp(log_p_m_bursty_rep(np.arange(40), 5.2,3.6,1.5,1))

In [None]:
log_like_repressed((5.2,3.6,1.5,1), data_rep)

In [None]:
def log_like_repressed(params, data_rep):
    """Conv wrapper for log likelihood for 2-state promoter w/
    transcription bursts and repression.
    
    data : array-like. n x 2
        data[:, 0] = unique mRNA counts
        data[:, 1] = frequency of each mRNA count

    Note the data pre-processing here, credit to Manuel for this observation:
    'NOTE: The likelihood asks for unique mRNA entries and their corresponding 
    counts to speed up the process of computing the probability distribution. 
    Instead of computing the probability of 3 mRNAs n times, it computes it 
    once and multiplies the value by n.'
    """
    mRNA, counts = data_rep[0], data_rep[1]
#     return np.sum(counts * log_p_m_bursty_rep(mRNA, *params))
    return np.sum(data_rep[1] * log_p_m_bursty_rep(data_rep[0], *params))

In [None]:
def log_like_constitutive(params, data_uv5):
    """Log likelihood of mRNA counts under a negative binomial
    distribution with parameters k_burst and mean_burst.

    params : sequence of 4 values
        (k_burst, mean_burst, kR_on, kR_off). The two repressor rates are
        ignored here; they are accepted so the same parameter vector can be
        passed to every likelihood.
    data_uv5 : pair of arrays, e.g. from np.unique(x, return_counts=True)
        data_uv5[0] = unique mRNA counts
        data_uv5[1] = frequency of each mRNA count

    Returns the sum over unique counts of frequency * log pmf.
    """
    k_burst, mean_burst, _, _ = params
    mRNA, counts = data_uv5[0], data_uv5[1]
    # change vars for scipy's parametrization: nbinom(n, p) has mean
    # n*(1-p)/p, so p = 1/(1 + mean_burst) gives mean k_burst*mean_burst
    p = 1.0 / (1.0 + mean_burst)
    # Use the public logpmf rather than the private _logpmf so scipy's
    # argument and support checks are applied.
    return np.sum(counts * neg_binom.logpmf(mRNA, k_burst, p))

In [None]:
def log_prior(params):
    """Flat (improper, unnormalized) log prior over the four rate parameters.

    params : sequence of exactly 4 values
        (k_burst, mean_burst, kR_on, kR_off)

    Returns 0.0 when every parameter lies strictly inside its open interval
    (0, 20) for k_burst, mean_burst and kR_off, and (0, 40) for kR_on;
    returns -np.inf otherwise.
    """
    k_burst, mean_burst, kR_on, kR_off = params
    # each parameter paired with the upper edge of its allowed interval
    bounded = (
        (k_burst, 20),
        (mean_burst, 20),
        (kR_on, 40),
        (kR_off, 20),
    )
    in_support = all(0 < value < upper for value, upper in bounded)
    return 0.0 if in_support else -np.inf

In [None]:
def log_posterior(params, data_uv5, data_rep, log_sampling=False):
    """log posterior fcn. check prior and then
    farm out data to the respective likelihoods.

    params : sequence of 4 values
        (k_burst, mean_burst, kR_on, kR_off), or their log10 values when
        log_sampling is True.
    data_uv5 : data passed to log_like_constitutive
    data_rep : data passed to log_like_repressed
    log_sampling : bool, default False
        If True, `params` are interpreted as log10 of the parameters and
        are transformed back before the prior and likelihoods are evaluated
        (so the prior bounds apply to the back-transformed values).

    Returns -np.inf when the prior is -np.inf, otherwise the sum of the log
    prior and both log likelihoods.
    """
    # Boolean logic to sample in linear or in log scale
    # Credit to Manuel for this
    if log_sampling:
        # Cast to a float array first: `10**params` fails for a tuple or list
        # and for integer arrays containing negative exponents.
        params = np.power(10.0, np.asarray(params, dtype=float))
    lp = log_prior(params)
    # Outside the prior support: skip the likelihood evaluations entirely.
    if lp == -np.inf:
        return -np.inf
    return (lp + log_like_constitutive(params, data_uv5)
            + log_like_repressed(params, data_rep))

In [None]:
# sampler dimensions: number of parameters, walkers, burn-in steps and
# sampling steps passed to emcee below
n_dim = 4
n_walkers = 20
n_burn = 10
n_steps = 400

# slice data for the sampler: each is a pair
# (sorted unique mRNA counts, number of occurrences of each count)
data_uv5 = np.unique(df_UV5['mRNA_cell'], return_counts=True)
data_rep = np.unique(df_O2_1ngml['mRNA_cell'], return_counts=True)

# init walkers from a dedicated, seeded generator so the starting positions
# are reproducible across kernel restarts. This only fixes p0; the sampler's
# own random stream is not seeded here.
p0_seed = 42
p0_rng = np.random.RandomState(p0_seed)
p0 = np.zeros([n_walkers, n_dim])
p0[:,0] = p0_rng.uniform(5,6, n_walkers) # k_burst
p0[:,1] = p0_rng.uniform(3,4, n_walkers) # mean_burst
p0[:,2] = p0_rng.uniform(0,10, n_walkers) # kR_on
p0[:,3] = p0_rng.uniform(0,10, n_walkers) # kR_off

In [None]:
# Run the ensemble sampler with a multiprocessing pool handed to emcee.
# Only the two data sets are passed as extra args, so log_posterior runs with
# its default log_sampling=False (linear-scale parameters).
with Pool() as pool:
    # instantiate sampler
    sampler = emcee.EnsembleSampler(
        n_walkers, n_dim, log_posterior, pool=pool, args=(data_uv5, data_rep),
    )
    print("starting burn-in...")
    # store=False: the burn-in steps are not kept in the chain; the returned
    # state is unpacked and its positions seed the production run below.
    pos, prob, state = sampler.run_mcmc(p0, n_burn, store=False, progress=True)
    print("starting actual sampling...")
    # NOTE(review): per the emcee docs, thin_by=10 should store every 10th
    # step, so n_steps stored samples cost 10*n_steps proposals -- confirm for
    # the installed emcee version.
    _ = sampler.run_mcmc(pos, n_steps, progress=True, thin_by=10);

In [None]:
# Trace plots: one panel per parameter with every walker overlaid, to eyeball
# how the chains move over the stored steps.
# The literal 4 here and the length of `labels` must match n_dim.
fig, axes = plt.subplots(4, figsize=(10, 7), sharex=True)
# indexed below as samples[step, walker, parameter]
samples = sampler.get_chain()
labels = ["k_burst", "b", "kR_on", "kR_off"]
for i in range(n_dim):
    ax = axes[i]
    # all walkers for parameter i, drawn as translucent black lines
    ax.plot(samples[:, :, i], "k", alpha=0.3)
    ax.set_xlim(0, len(samples))
    ax.set_ylabel(labels[i])
    # fixed label position so the y-labels line up across panels
    ax.yaxis.set_label_coords(-0.1, 0.5)

axes[-1].set_xlabel("step number");

In [None]:
# Estimate the autocorrelation time of the stored chain for each parameter.
# NOTE(review): emcee may raise AutocorrError when the chain is short
# relative to the estimated autocorrelation time -- confirm for this run.
sampler.get_autocorr_time()

In [None]:
# Convert the sampler's chain into an ArviZ object for plotting. var_names
# follows the parameter order used in log_prior; 'b' is mean_burst.
emcee_output = az.from_emcee(sampler, var_names=['k_burst', 'b', 'kR_on', 'kR_off'])

In [None]:
# Corner plot of the posterior samples, shown inline via bokeh.
# (plot_ecdf=True presumably draws ECDFs for the single-parameter marginals
# -- confirm against bebi103.viz.corner.)
bokeh.io.show(bebi103.viz.corner(emcee_output, plot_ecdf=True))