## Metropolis Hastings Custom HMC MC Algorithm

In [None]:
from IPython.display import display
import matplotlib.pyplot as plt

from reggae.data_loaders import load_barenco_puma, load_3day_dros, DataHolder
from reggae.mcmc import create_chains, MetropolisHastings
from reggae.models import transcription_mh
from reggae.utilities import get_rbf_dist, exp, mult, discretise

import tensorflow as tf
import numpy as np
import pandas as pd
import arviz
from multiprocessing import Pool

np.set_printoptions(formatter={'float': lambda x: "{0:0.5f}".format(x)})
plt.style.use('ggplot')
%matplotlib inline

In [None]:
# Load the Barenco PUMA data set and bundle it into a DataHolder.
#df, genes, genes_se, m_observed, f_observed, σ2_m_pre, σ2_f_pre, t = load_barenco_puma()
m_observed, f_observed, σ2_m_pre, σ2_f_pre, t = load_barenco_puma()

# m_observed, f_observed, t = load_3day_dros()

# Index of the replicate whose observations are used below.
replicate = 0

# Each loader output is unpacked into a `*_df` object and the observation array.
m_df, m_observed = m_observed 
f_df, f_observed = f_observed

# Shape of m_observed = (replicates, genes, times)
m_observed = m_observed[replicate]
f_observed = np.atleast_2d(f_observed[replicate])
# NOTE(review): the variances are indexed with a literal 0 rather than `replicate`;
# the two only agree while replicate == 0 -- confirm which is intended.
σ2_m_pre = σ2_m_pre[0]
σ2_f_pre = σ2_f_pre[0]

num_genes = m_observed.shape[0]
# τ is presumably the discretised time grid and common_indices the positions of
# the observation times within it -- confirm against reggae.utilities.discretise.
τ, common_indices = discretise(t)
N_p = τ.shape[0]
N_m = m_observed.shape[1]

# `data` is first a plain tuple and is then rebound to the DataHolder built from it.
data = (m_observed, f_observed)
noise_data = (σ2_m_pre, σ2_f_pre)
time = (t, τ, tf.constant(common_indices))

data = DataHolder(data, noise_data, time)


In [None]:
from tensorflow import math as tfm
from tensorflow_probability import bijectors as tfb
from tensorflow_probability import distributions as tfd
import tensorflow_probability as tfp
from ipywidgets import IntProgress

from reggae.mcmc import MetropolisHastings, Parameter
from reggae.data_loaders import DataHolder
from reggae.utilities import get_rbf_dist, exp, mult, jitter_cholesky

import numpy as np
from scipy.special import expit

# Shorthand for building double-precision scalars.
f64 = np.float64


class Options:
    '''
    Model switches.

    preprocessing_variance: whether the pre-processing variances are used.
    tf_mrna_present: whether TF mRNA observations are available.
    '''

    def __init__(self, preprocessing_variance=True, tf_mrna_present=True):
        self.tf_mrna_present = tf_mrna_present
        self.preprocessing_variance = preprocessing_variance
        
class TranscriptionLikelihood():
    '''
    Gaussian log-likelihoods of the observed gene (mRNA) and TF data under the
    transcription ODE model.

    The kinetic parameters are passed in log-space (kbar, δbar) and the latent
    function as fbar; `predict_m` maps them to positive values itself.
    '''
    def __init__(self, data: DataHolder, options: Options):
        self.options = options
        self.data = data
        self.preprocessing_variance = options.preprocessing_variance
        self.num_genes = data.m_obs.shape[0]

    @tf.function
    def predict_m(self, kbar, δbar, w, fbar, w_0):
        '''
        Predicts the mRNA levels on the discretised time grid `self.data.τ`.

        kbar: log kinetic parameters; columns 0..3 are read as (a_j, b_j, d_j, s_j).
        δbar: log of the translation ODE degradation rate.
        w:    interaction weights; only column 0 (the first TF) is used.
        fbar: untransformed latent function, mapped through log(1 + exp(.)).
        w_0:  interaction biases, one per gene.
        Returns m_pred with one row per gene and one column per grid point.
        '''
        # Take relevant parameters out of log-space
        a_j, b_j, d_j, s_j = (tf.reshape(tfm.exp(kbar[:, i]), (-1, 1)) for i in range(4))
        δ = tfm.exp(δbar)
        f_i = tfm.log(1+tfm.exp(fbar))
        τ = self.data.τ

        # Calculate p_i vector
        Δ = τ[1]-τ[0]
        sum_term = tfm.multiply(tfm.exp(δ*τ), f_i)
        p_i = tf.concat([[f64(0)], 0.5*Δ*tfm.cumsum(sum_term[:-1] + sum_term[1:])], axis=0) # Trapezoid rule
        p_i = tfm.multiply(tfm.exp(-δ*τ), p_i)

        # Calculate m_pred
        interactions = w[:, 0][:, None]*tfm.log(p_i+1e-100) + w_0[:, None]
        G = tfm.sigmoid(interactions) # TF Activation Function (sigmoid)
        sum_term = G * tfm.exp(d_j*τ)
        # The leading zero column is sized by the number of genes rather than a
        # hard-coded 5, so the model also works for other gene counts.
        integrals = tf.concat(
            [tf.zeros((self.num_genes, 1), dtype='float64'),
             0.5*Δ*tfm.cumsum(sum_term[:, :-1] + sum_term[:, 1:], axis=1)],
            axis=1) # Trapezoid rule

        exp_dt = tfm.exp(-d_j*τ)
        integrals = tfm.multiply(exp_dt, integrals)
        m_pred = b_j/d_j + tfm.multiply((a_j-b_j/d_j), exp_dt) + s_j*integrals

        return m_pred

    def genes(self, params=None, δbar=None,
                     fbar=None, 
                     kbar=None, 
                     w=None,
                     w_0=None,
                     σ2_m=None, return_sq_diff=False):
        '''
        Computes likelihood of the genes.
        If any of the optional args are None, they are replaced by their current value in params.

        Returns the per-gene log-likelihood, and additionally the squared
        residuals when return_sq_diff is True.
        '''
        δbar = params.δbar.value if δbar is None else δbar
        fbar = params.fbar.value if fbar is None else fbar
        kbar = params.kbar.value if kbar is None else kbar
        w = params.w.value if w is None else w
        σ2_m = params.σ2_m.value if σ2_m is None else σ2_m
        w_0 = params.w_0.value if w_0 is None else w_0

        lik, sq_diff = self._genes(δbar, fbar, kbar, w, w_0, σ2_m)

        if return_sq_diff:
            return lik, sq_diff
        return lik

    @tf.function
    def _genes(self, δbar, fbar, kbar, w, w_0, σ2_m):
        '''Graph-compiled core of `genes`: returns (per-gene log-likelihood, squared residuals).'''
        m_pred = self.predict_m(kbar, δbar, w, fbar, w_0)

        # Compare only at the grid columns selected by common_indices.
        sq_diff = tfm.square(self.data.m_obs - tf.transpose(tf.gather(tf.transpose(m_pred),self.data.common_indices)))
        variance = tf.reshape(σ2_m, (-1, 1))
        if self.preprocessing_variance:
            variance = variance + self.data.σ2_m_pre # add PUMA variance
        log_lik = -0.5*tfm.log(2*np.pi*(variance)) - 0.5*sq_diff/variance
        log_lik = tf.reduce_sum(log_lik, axis=1)
        return log_lik, sq_diff

    def tfs(self, params, fbar, return_sq_diff=False): 
        '''
        Computes log-likelihood of the transcription factors.
        TODO this should be for the i-th TF

        The variance is the pre-processing variance when that option is on,
        otherwise the current value of the σ2_f parameter.
        '''
        assert self.options.tf_mrna_present
        if not self.preprocessing_variance:
            σ2_f = params.σ2_f.value
            # tf.reshape accepts both numpy arrays and tensors (a tensor has no .reshape).
            variance = tf.reshape(σ2_f, (-1, 1))
        else:
            variance = self.data.σ2_f_pre
        # Same transform as in predict_m, kept in TensorFlow so it also works on tensors.
        f_pred = tfm.log(1+tfm.exp(fbar))
        f_pred = tf.reshape(f_pred, (1, -1)) #np.atleast_2d f_pred[:, self.data.common_indices]
        sq_diff = tfm.square(self.data.f_obs - tf.transpose(tf.gather(tf.transpose(f_pred),self.data.common_indices)))

        log_lik = -0.5*tfm.log(2*np.pi*variance) - 0.5*sq_diff/variance
        log_lik = tf.reduce_sum(log_lik, axis=1)
        if return_sq_diff:
            return log_lik, sq_diff
        return log_lik
    
# Notebook-level options and likelihood for the loaded data: pre-processing
# variances are used and TF mRNA observations are treated as available.
opt = Options(preprocessing_variance=True, tf_mrna_present=True)
lik = TranscriptionLikelihood(data, opt)

In [None]:
import tensorflow_probability as tfp
import collections
# Results carried by MixedKernel: one entry per inner kernel, in kernel order.
# (grads_target_log_prob, step_size, log_accept_ratio and is_accepted are not
# tracked at this level.)
MixedKernelResults = collections.namedtuple('MixedKernelResults', 'inner_results')

# Minimal results record used by the hand-written kernels.
GenericResults = collections.namedtuple(
    'GenericResults', 'target_log_prob is_accepted')

class MixedKernel(tfp.mcmc.TransitionKernel): # TODO simplify all states: just send all states and keep dict of indices
    '''
    Transition kernel that updates a list-valued state part by part, using
    kernels[i] for current_state[i].

    kernels:         list of inner kernels, one per state part.
    send_all_states: list of bools; when send_all_states[i] is True the i-th
                     kernel additionally receives the whole state list as a
                     third argument to one_step / second to bootstrap_results.
    '''
    def __init__(self, kernels, send_all_states):
        self.kernels = kernels
        self.send_all_states = send_all_states
        self.num_kernels = len(kernels)

    def one_step(self, current_state, previous_kernel_results):
        '''
        Advances every state part once.

        Every inner kernel is given the state as it was at the start of the
        step (`current_state`), not the parts already updated in this step.
        Returns (new_state, MixedKernelResults) where both lists follow the
        kernel order. An inner result is whatever that kernel returns, e.g. a
        NUTS kernel yields a results tuple with target_log_prob, step_size,
        is_accepted, ...
        '''
        new_state = list()
        inner_results = list()

        for i in range(self.num_kernels):
            step_args = [current_state[i], previous_kernel_results.inner_results[i]]
            if self.send_all_states[i]:
                step_args.append(current_state)
            result_state, kernel_results = self.kernels[i].one_step(*step_args)

            new_state.append(result_state)
            inner_results.append(kernel_results)

        return new_state, MixedKernelResults(inner_results)

    def bootstrap_results(self, init_state):
        """Returns an object with the same type as returned by `one_step(...)[1]`.
        Args:
        init_state: `Tensor` or Python `list` of `Tensor`s representing the
        initial state(s) of the Markov chain(s).
        Returns:
        kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        """
        inner_kernels_bootstraps = list()
        for i in range(self.num_kernels):
            bootstrap_args = [init_state[i]]
            if self.send_all_states[i]:
                bootstrap_args.append(init_state)
            inner_kernels_bootstraps.append(
                self.kernels[i].bootstrap_results(*bootstrap_args))

        return MixedKernelResults(inner_kernels_bootstraps)

    @property
    def is_calibrated(self):
        # TransitionKernel declares is_calibrated as a property, so expose the
        # flag itself rather than a bound method.
        return True


In [None]:

class FKernel(tfp.mcmc.TransitionKernel):
    '''
    Work-in-progress transition kernel for the latent function fbar
    (untransformed TF mRNA).

    NOTE(review): this class is incomplete as written:
      * `__init__` reads `self.N_p`, which is never assigned on this class;
      * `one_step` reads `params`, `self.options`, `self.is_accepted` and
        `self.acceptance_rates`, none of which are defined in this class;
      * `one_step` has no return statement, and `bootstrap_results` /
        `is_calibrated` are not implemented.
    '''
    def __init__(self, likelihood, fbar_prior_params, num_tfs):
        # fbar_prior_params is called below as f(v, l2) and must return (m, K).
        self.fbar_prior_params = fbar_prior_params
        self.num_tfs = num_tfs
        self.likelihood = likelihood
        # Per-point scale of the auxiliary draw z_i around the current fbar.
        self.h_f = 0.35*tf.ones(self.N_p, dtype='float64')

    def one_step(self, current_state, previous_kernel_results, all_states):
        # Untransformed tf mRNA vectors F (Step 1)
        fbar = current_state
        for i in range(self.num_tfs):
            # Gibbs step
            z_i = tf.reshape(tfd.MultivariateNormalDiag(fbar, self.h_f).sample(), (1, -1))
            # MH
            # Proposal fstar is Gaussian with mean z_i @ invKsigmaK and a
            # covariance factor obtained from K - K @ invKsigmaK.
            m, K = self.fbar_prior_params(params.V.value, params.L.value)
            invKsigmaK = tf.matmul(tf.linalg.inv(K+tf.linalg.diag(self.h_f)), K) # (C_i + hI)C_i
            L = jitter_cholesky(K-tf.matmul(K, invKsigmaK))
            c_mu = tf.matmul(z_i, invKsigmaK)
            fstar = tf.matmul(tf.random.normal((1, L.shape[0]), dtype='float64'), L) + c_mu
            fstar = tf.reshape(fstar, (-1, ))
            # Log-probability of the proposal: gene likelihood plus, when TF
            # mRNA is observed, the TF likelihood.
            new_m_likelihood = self.likelihood.genes(params, fbar=fstar)
            new_f_likelihood = 0 
            if self.options.tf_mrna_present:
                new_f_likelihood = self.likelihood.tfs(params, fstar)
            new_prob = tf.reduce_sum(new_m_likelihood) + new_f_likelihood
            old_prob = previous_kernel_results.target_log_prob #tf.reduce_sum(old_m_likelihood) + old_f_likelihood

            if self.is_accepted(new_prob, old_prob):
                params.fbar.value = fstar
                old_m_likelihood = new_m_likelihood
                old_f_likelihood = new_f_likelihood
                self.acceptance_rates['fbar'] += 1/self.num_tfs

def metropolis_is_accepted(new_log_prob, old_log_prob):
    '''
    Metropolis accept/reject decision for a proposal.

    Draws u ~ Uniform[0, 1) with shape (1,) and accepts when
    u < min(1, exp(new_log_prob - old_log_prob)). Returns a boolean tensor.
    '''
    acceptance_ratio = tfm.exp(new_log_prob - old_log_prob)
    threshold = tfm.minimum(f64(1), acceptance_ratio)
    u = tf.random.uniform((1,), dtype='float64')
    return u < threshold

# Position of δbar in the state list handed to MixedKernel; KbarKernel uses it
# to look up δbar in `all_states`.
δbar_state_index = 1
# Number of points in the time grid τ.
N_p = τ.shape[0]

class KbarKernel(tfp.mcmc.TransitionKernel):
    '''
    Metropolis kernel for the transcription kinetic parameters kbar, updating
    one gene (row of kbar) at a time.

    likelihood: object whose `genes(...)` returns one log-likelihood per gene.
    prop_dist:  callable mapping the current row to a proposal distribution.
    prior_dist: distribution giving the prior log-probability of a row's entries.
    num_genes:  number of rows of kbar.
    '''
    def __init__(self, likelihood, prop_dist, prior_dist, num_genes):
        self.prop_dist = prop_dist
        self.prior_dist = prior_dist
        self.num_genes = num_genes
        self.likelihood = likelihood

    def _gene_likelihoods(self, kbar, all_states):
        '''Per-gene log-likelihood for `kbar`; all other parameters except δbar are fixed placeholders.'''
        return self.likelihood.genes(
            δbar=all_states[δbar_state_index],
            fbar=0.5*tf.ones(N_p, dtype='float64'),       # TODO
            kbar=kbar,
            w=1*tf.ones((self.num_genes, 1), dtype='float64'), # TODO
            w_0=tf.zeros(self.num_genes, dtype='float64'),     # TODO
            σ2_m=1e-4*tf.ones(self.num_genes, dtype='float64') # TODO
        )

    def one_step(self, current_state, previous_kernel_results, all_states):
        '''
        Proposes and accepts/rejects a new row for every gene in turn.

        Returns the updated kbar and a GenericResults holding the list of
        per-gene log-probabilities of the retained rows.
        '''
        kstar = tf.identity(current_state)
        probs = list()
        for j in range(self.num_genes):
            sample = self.prop_dist(kstar[j]).sample()
            proposed = tf.concat([kstar[:j], [sample], kstar[j+1:]], axis=0)

            new_prob = (self._gene_likelihoods(proposed, all_states)[j]
                        + tf.reduce_sum(self.prior_dist.log_prob(sample)))
            old_prob = previous_kernel_results.target_log_prob[j]

            # metropolis_is_accepted returns shape (1,); tf.cond documents a
            # scalar predicate, so make it one explicitly.
            is_accepted = tf.reshape(metropolis_is_accepted(new_prob, old_prob), ())

            prob = tf.cond(is_accepted, lambda: new_prob, lambda: old_prob)
            # On rejection row j keeps its previous value, i.e. kstar is unchanged.
            kstar = tf.cond(is_accepted, lambda: proposed, lambda: kstar)
            probs.append(prob)

        return kstar, GenericResults(probs, True) #TODO not just return true

    def bootstrap_results(self, init_state, all_states):
        '''Initial results: per-gene likelihood plus prior log-probability of init_state.'''
        # The likelihood does not depend on j, so evaluate it once for all genes.
        gene_likelihoods = self._gene_likelihoods(init_state, all_states)
        probs = [
            gene_likelihoods[j] + tf.reduce_sum(self.prior_dist.log_prob(init_state[j]))
            for j in range(self.num_genes)
        ]

        return GenericResults(probs, True) #TODO automatically adjust

    @property
    def is_calibrated(self):
        # TransitionKernel declares is_calibrated as a property.
        return True



In [None]:
# Parameter containers. TupleParams_pre carries the extra σ2_f field and is the
# variant built when the pre-processing variances are not used.
TupleParams_pre = collections.namedtuple(
    'TupleParams_pre', 'fbar δbar kbar σ2_m w w_0 L V σ2_f')
TupleParams = collections.namedtuple(
    'TupleParams', 'fbar δbar kbar σ2_m w w_0 L V')

class TranscriptionCustom():
    '''
    Data is a tuple (m, f) of shapes (num, time)
    time is a tuple (t, τ, common_indices)

    Builds the model parameters (priors, initial values, proposals) and samples
    a subset of them with a MixedKernel through `tfp.mcmc.sample_chain`.
    '''
    def __init__(self, data: DataHolder, options: Options):
        self.data = data
        min_dist = min(data.t[1:]-data.t[:-1])
        self.N_p = data.τ.shape[0]
        self.N_m = data.t.shape[0]      # Number of observations

        self.num_tfs = data.f_obs.shape[0] # Number of TFs
        self.num_genes = data.m_obs.shape[0]

        self.likelihood = TranscriptionLikelihood(data, options)
        self.options = options
        # Adaptable variances
        a = tf.constant(-0.5, dtype='float64')
        b2 = tf.constant(2., dtype='float64')

        # Interaction weights
        w_0 = Parameter('w_0', tfd.Normal(f64(0), f64(2)), np.zeros(self.num_genes), step_size=0.5*tf.ones(self.num_genes, dtype='float64'))
        w_0.proposal_dist=lambda mu, j:tfd.Normal(mu, w_0.step_size[j])
        w = Parameter('w', tfd.Normal(f64(0), f64(2)), 1*np.ones((self.num_genes, self.num_tfs)), step_size=0.5*tf.ones(self.num_genes, dtype='float64'))
        w.proposal_dist=lambda mu, j:tfd.Normal(mu, w.step_size[j]) #) w_j) # At the moment this is the same as w_j0 (see pg.8)

        # Latent function
        fbar = Parameter('fbar', self.fbar_prior, 0.5*np.ones(self.N_p))

        # GP hyperparameters
        @tf.function
        def V_log_prob(vstar, l2star):
            # Joint log-probability of (V, L): GP prior of the current fbar
            # plus the two hyperpriors. self.params is resolved at call time.
            new_prob = self.params.fbar.prior(self.params.fbar.value, vstar, l2star)
            new_prob += self.params.V.prior.log_prob(vstar)
            new_prob += self.params.L.prior.log_prob(l2star)
            return tf.reduce_sum(new_prob)
        V = Parameter('V', tfd.InverseGamma(f64(0.01), f64(0.01)), f64(1), step_size=0.05, 
                      fixed=not options.tf_mrna_present, hmc_log_prob=V_log_prob)
        L = Parameter('L', tfd.Uniform(f64(min_dist**2-0.5), f64(data.t[-1]**2)), f64(4), step_size=0.1)
        L.proposal_dist=lambda l2: tfd.TruncatedNormal(l2, L.step_size, low=0, high=100) #l2_i
        self.t_dist = get_rbf_dist(data.τ, self.N_p)

        # Translation kinetic parameters
        def δbar_log_prob(state):
            new_prob = tf.reduce_sum(self.likelihood.genes(self.params, δbar=state)) 
            new_prob += self.params.δbar.prior.log_prob(state)
            return new_prob

        δbar = Parameter('δbar', tfd.Normal(a, b2), f64(-0.3), step_size=0.3, hmc_log_prob=δbar_log_prob)

        # White noise for genes
        σ2_m = Parameter('σ2_m', tfd.InverseGamma(f64(0.01), f64(0.01)), 1e-4*np.ones(self.num_genes), step_size=0.5)
        σ2_m.proposal_dist=lambda mu: tfd.TruncatedNormal(mu, σ2_m.step_size, low=0, high=5)
        # Transcription kinetic parameters
        def constrain_kbar(kbar, gene):
            '''Constrains a given row in kbar (clamped in place to [-10, 3]); `gene` is unused.'''
            kbar[kbar < -10] = -10
            kbar[kbar > 3] = 3
            return kbar
        kbar_initial = -0.1*np.float64(np.c_[
            np.ones(self.num_genes), # a_j
            np.ones(self.num_genes), # b_j
            np.ones(self.num_genes), # d_j
            np.ones(self.num_genes)  # s_j
        ])
        def kbar_log_prob(*kstar):
            # Not passed to any Parameter at the moment; kept for an HMC update of kbar.
            # The likelihood call does not depend on j, so it is evaluated once,
            # and the gene count comes from the model rather than a notebook global.
            gene_likelihoods = self.likelihood.genes(self.params, kbar=np.array(kstar))
            new_prob = 0
            for j in range(self.num_genes):
                new_prob += gene_likelihoods[j]
                new_prob += tf.reduce_sum(self.params.kbar.prior.log_prob(kstar[j]))
            return tf.reduce_sum(new_prob)

        for j, k in enumerate(kbar_initial):
            kbar_initial[j] = constrain_kbar(k, j)
        kbar = Parameter('kbar',
            tfd.Normal(a, b2), 
            kbar_initial,
            constraint=constrain_kbar, step_size=0.25*tf.ones(4, dtype='float64'))
        kbar.proposal_dist=lambda mu: tfd.MultivariateNormalDiag(mu, kbar.step_size)

        # σ2_f is only a sampled parameter when no pre-processing variance is used.
        if not options.preprocessing_variance:
            σ2_f = Parameter('σ2_f', tfd.InverseGamma(f64(0.01), f64(0.01)), 1e-4*np.ones(self.num_tfs), step_size=tf.constant(0.5, dtype='float64'))
            self.params = TupleParams_pre(fbar, δbar, kbar, σ2_m, w, w_0, L, V, σ2_f)
        else:
            self.params = TupleParams(fbar, δbar, kbar, σ2_m, w, w_0, L, V)

    def fbar_prior_params(self, v, l2):
        '''Returns the GP prior mean (zeros) and RBF covariance K for variance v and squared length-scale l2.'''
        jitter = tf.linalg.diag(1e-5 * tf.ones(self.N_p, dtype='float64'))
        K = tfm.multiply(v, tfm.exp(-tfm.square(self.t_dist)/(2*l2))) + jitter
        m = np.zeros(self.N_p)
        return m, K

    def fbar_prior(self, fbar, v, l2):
        '''
        Log-probability of fbar under the GP prior.

        If evaluating with K fails, retries once with additional jitter on the
        diagonal; if that fails too, returns -inf.
        '''
        m, K = self.fbar_prior_params(v, l2)
        try:
            return tfd.MultivariateNormalFullCovariance(m, K).log_prob(fbar)
        except Exception:
            # float64 to match K; a float32 jitter cannot be added to K and
            # would make this fallback fail unconditionally.
            jitter = tf.linalg.diag(1e-4 * tf.ones(self.N_p, dtype='float64'))
            try:
                return np.float64(tfd.MultivariateNormalFullCovariance(m, K+jitter).log_prob(fbar))
            except Exception:
                return tf.constant(-np.inf, dtype='float64')


    def sample(self, T=20000, store_every=10, burn_in=1000, report_every=100, tune_every=50):
        '''
        Runs the mixed kernel and summarises the first state part.

        Only T is used (for the progress bar); store_every, burn_in,
        report_every and tune_every are currently ignored, and the chain length
        is fixed at 100 results after 10 burn-in steps.
        Returns (sample_mean, sample_stddev, is_accepted) for the test kernel.
        '''
        print('----- Sampling Begins -----')

        self.acceptance_rates = {param.name: 0. for param in self.params} # Reset acceptance rates
        f = IntProgress(description='Running', min=0, max=T) # instantiate the bar
        display(f)
        params = self.params

        a = tf.constant(-0.5, dtype='float64')
        b2 = tf.constant(2., dtype='float64')

        def unnormalized_log_prob(x):
            return -x - x**2.

        test_kernel = tfp.mcmc.NoUTurnSampler(target_log_prob_fn=unnormalized_log_prob, step_size=1.)

        # One kernel and one send_all_states flag per entry of `current_state` below.
        # TODO: add FKernel(self.likelihood, self.fbar_prior_params, self.num_tfs)
        # together with an fbar state entry once FKernel is implemented.
        kernels = [
            test_kernel,
            params.δbar.kernel, # Translation ODE parameters
            KbarKernel(self.likelihood, params.kbar.proposal_dist, tfd.Normal(a, b2), self.num_genes),
        ]
        send_all_states = [
            False,
            False,
            True,
        ]

        mixed_kern = MixedKernel(kernels, send_all_states)
        def trace_fn(a, pkr):
            # Trace only the acceptance flag of the first (test) kernel.
            return pkr.inner_results[0].is_accepted
        num_results = int(1e2)
        num_burnin_steps = int(1e1)

        # Run the chain (with burn-in).
        @tf.function
        def run_chain():
            samples, is_accepted = tfp.mcmc.sample_chain(
                  num_results=num_results,
                  num_burnin_steps=num_burnin_steps,
                  current_state=[1., params.δbar.value, params.kbar.value],
                  kernel=mixed_kern,
                  trace_fn=trace_fn)

            return samples, is_accepted

        samples, is_accepted = run_chain()

        print(samples)
        sample_mean = tf.reduce_mean(samples[0])
        sample_stddev = tf.math.reduce_std(samples[0])
        is_accepted = tf.reduce_mean(tf.cast(is_accepted, dtype=tf.float32))

        f.value = T
        print('----- Finished -----')
        return sample_mean, sample_stddev, is_accepted
        
            
            


In [None]:
# Build the custom-kernel model and run it with the default sampling arguments.
model = TranscriptionCustom(data, opt)

sample_mean, sample_stddev, is_accepted = model.sample()

# The three values are tensors; .numpy() extracts them for formatting.
print('mean:{:.4f}  stddev:{:.4f}  acceptance:{:.4f}'.format(
    sample_mean.numpy(), sample_stddev.numpy(), is_accepted.numpy()))
# Output recorded from an earlier run:
#mean:-0.4878  stddev:0.7141  acceptance:0.8933

In [None]:


class TranscriptionHMC(MetropolisHastings):
    '''
    Metropolis-Hastings sampler in which the δbar and GP hyperparameter (V, L)
    updates are delegated to TFP kernels through `tfp.mcmc.sample_chain`.

    NOTE(review): `self.params`, `self.likelihood`, `self.options`,
    `self.num_tfs`, `self.num_genes`, `self.N_m`, `self.acceptance_rates` and
    `self.is_accepted` are not set in this class; presumably they come from
    MetropolisHastings -- confirm.
    '''

    @tf.function
    def iter_delta(self, δbar, kernel):
        '''Runs a 2-result chain for δbar; returns (samples, is_accepted) of the last TF iteration.'''
        trace_fn = lambda _, pkr: pkr.is_accepted
        for i in range(self.num_tfs):# TODO make for self.num_tfs > 1

            samples, is_accepted = tfp.mcmc.sample_chain(
                num_results=2,
                num_burnin_steps=0,
                current_state=δbar,
                kernel=kernel,
                trace_fn=trace_fn)

        return samples, is_accepted

    @tf.function
    def iter_rbf_params(self, v, l2, kernel):
        '''Runs a 2-result chain jointly for the GP hyperparameters [v, l2].'''
        return tfp.mcmc.sample_chain(
            num_results=2,
            num_burnin_steps=0,
            current_state=[v, l2],
            kernel=kernel,
            trace_fn=lambda _, pkr: pkr.is_accepted)

    def iterate(self):
        '''
        Performs one sweep over the parameters: δbar, interaction weights and
        biases, noise variances, then the GP hyperparameters.

        Step 3 (transcription kinetic parameters kbar) is currently disabled.
        '''
        params = self.params
        # Compute likelihood for comparison
        old_m_likelihood, sq_diff_m  = self.likelihood.genes(params, return_sq_diff=True)
        old_f_likelihood = 0
        if self.options.tf_mrna_present:
            old_f_likelihood, sq_diff_f  = self.likelihood.tfs(params, params.fbar.value, return_sq_diff=True)

        if self.options.tf_mrna_present: # (Step 2)
            # Log of translation ODE degradation rates
            δbar = params.δbar.value
            samples, is_accepted = self.iter_delta(δbar, params.δbar.kernel)
            if is_accepted[-1]:
                params.δbar.value = samples[-1] #δstar
                self.acceptance_rates['δbar'] += 1/self.num_tfs

        # Interaction weights and biases (note: should work for self.num_tfs > 1) (Step 4)
        w = params.w.value
        w_0 = params.w_0.value
        wstar = w.copy()
        w_0star = w_0.copy()
        for j in range(self.num_genes):
            sample_0 = params.w_0.propose(w_0[j], j)
            sample = params.w.propose(wstar[j], j)
            wstar[j] = sample
            w_0star[j] = sample_0
            new_prob = (self.likelihood.genes(params, w=wstar, w_0=w_0star)[j]
                        + tf.reduce_sum(params.w.prior.log_prob(sample))
                        + params.w_0.prior.log_prob(sample_0))
            old_prob = (old_m_likelihood[j]
                        + tf.reduce_sum(params.w.prior.log_prob(w[j,:]))
                        + params.w_0.prior.log_prob(w_0[j]))
            if self.is_accepted(new_prob, old_prob):
                params.w.value[j] = sample
                params.w_0.value[j] = sample_0
                self.acceptance_rates['w'] += 1/self.num_genes
                self.acceptance_rates['w_0'] += 1/self.num_genes
            else:
                # Undo both halves of the rejected proposal so later genes are
                # evaluated against the current values.
                wstar[j] = params.w.value[j]
                w_0star[j] = params.w_0.value[j]

        # Noise variances
        if self.options.preprocessing_variance:
            σ2_m = params.σ2_m.value
            σ2_mstar = σ2_m.copy()
            for j in range(self.num_genes):
                sample = params.σ2_m.propose(σ2_m[j])
                σ2_mstar[j] = sample
                # The proposal is a truncated normal, so both proposal
                # densities enter the acceptance comparison.
                old_q = params.σ2_m.proposal_dist(σ2_mstar[j]).log_prob(σ2_m[j])
                new_prob = self.likelihood.genes(params, σ2_m=σ2_mstar)[j] +params.σ2_m.prior.log_prob(σ2_mstar[j])

                new_q = params.σ2_m.proposal_dist(σ2_m[j]).log_prob(σ2_mstar[j])
                old_prob = self.likelihood.genes(params, σ2_m=σ2_m)[j] + params.σ2_m.prior.log_prob(σ2_m[j])

                if self.is_accepted(new_prob + old_q, old_prob + new_q):
                    params.σ2_m.value[j] = sample
                    self.acceptance_rates['σ2_m'] += 1/self.num_genes
                else:
                    σ2_mstar[j] = σ2_m[j]
        else: # Use Gibbs sampling
            # Prior parameters
            α = params.σ2_m.prior.concentration
            β = params.σ2_m.prior.scale
            # Conditional posterior of inv gamma parameters:
            α_post = α + 0.5*self.N_m
            β_post = β + 0.5*np.sum(sq_diff_m)
            params.σ2_m.value = np.repeat(tfd.InverseGamma(α_post, β_post).sample(), self.num_genes)
            self.acceptance_rates['σ2_m'] += 1

            if self.options.tf_mrna_present: # (Step 5)
                # Prior parameters
                α = params.σ2_f.prior.concentration
                β = params.σ2_f.prior.scale
                # Conditional posterior of inv gamma parameters:
                α_post = α + 0.5*self.N_m
                β_post = β + 0.5*np.sum(sq_diff_f)
                params.σ2_f.value = np.repeat(tfd.InverseGamma(α_post, β_post).sample(), self.num_tfs)
                self.acceptance_rates['σ2_f'] += 1

        # Length scales and variances of GP kernels
        l2 = params.L.value
        v = params.V.value
        for i in range(self.num_tfs):
            samples, is_accepted = self.iter_rbf_params(v, l2, params.V.kernel)

            if is_accepted[-1]:
                params.V.value = samples[0][-1]
                params.L.value = samples[1][-1]
                self.acceptance_rates['V'] += 1/self.num_tfs
                self.acceptance_rates['L'] += 1/self.num_tfs

    @staticmethod
    def initialise_from_state(args, state):
        '''Builds a new model from constructor `args` and copies acceptance rates and samples from `state`.'''
        # Construct this class; the name TranscriptionMCMC is not defined in this notebook.
        model = TranscriptionHMC(*args)
        model.acceptance_rates = state.acceptance_rates
        model.samples = state.samples
        return model

    def predict_m(self, kbar, δbar, w, fbar, w_0):
        '''Predicts mRNA levels for the given parameters (delegates to the likelihood).'''
        return self.likelihood.predict_m(kbar, δbar, w, fbar, w_0)

    def predict_m_with_current(self):
        '''Predicts mRNA levels using the current parameter values.'''
        return self.likelihood.predict_m(self.params.kbar.value, 
                                         self.params.δbar.value, 
                                         self.params.w.value, 
                                         self.params.fbar.value,
                                         self.params.w_0.value)



In [None]:
# Rebind `opt` and `model`: from here on `model` is the TranscriptionHMC sampler.
opt = Options(preprocessing_variance=True, tf_mrna_present=True)
model = TranscriptionHMC(data, opt)




In [None]:
# Sample with the HMC model and keep the results for the plots further down.
# NOTE(review): the positional arguments presumably mean (T, store_every,
# burn_in, report_every) -- confirm against MetropolisHastings.sample.
T = 600
model.sample(T, 1, 0, 1)

print(model.acceptance_rates)
samples = model.samples
acceptance_rates = model.acceptance_rates


In [None]:
# Settings passed to create_chains in the next cell.
T = 1000
store_every = 1
burn_in = 0
report_every = 20
num_chains = 4
tune_every = 50


In [None]:
# Chains
job = create_chains(
    transcription.TranscriptionMCMC, 
    [data, opt], 
    {
        'T': T, 
        'store_every': store_every, 
        'burn_in': burn_in,
        'report_every': report_every,
        'tune_every':tune_every
    }, 
    num_chains=num_chains)

    
print('Done')

## Convergence Plots

In [None]:
# Stack the stored samples of every chain into one array per parameter, with the
# chain index as the leading axis.
keys = job[0].acceptance_rates.keys()

# Start from an empty array with zero chains and append one chain at a time.
variables = {key : np.empty((0, T, *job[0].samples[key].get().shape[1:])) for key in keys}

for res in job:
    for key in keys:
        variables[key] = np.append(variables[key], np.expand_dims(res.samples[key].get(), 0), axis=0)

# Last 100 samples of 'L', one line per chain.
plt.plot(variables['L'][:,-100:].T)

# One ArviZ inference-data object per parameter, used by the diagnostics below.
mixes = {key: arviz.convert_to_inference_data(variables[key]) for key in keys}

#### Rhat
Rhat is the square root of the ratio of the estimated posterior variance to the within-chain variance. If it exceeds 1.1, we consider the chains not to have mixed well. As the between-chain variance approaches the within-chain variance, Rhat tends to 1.

In [None]:
# Rhat for fbar (not used again in the visible cells).
Rhat = arviz.rhat(mixes['fbar'])

# Mean Rhat per parameter, averaged over that parameter's elements.
Rhats = np.array([np.mean(arviz.rhat(mixes[key]).x.values) for key in keys])

# First row: mean Rhat; second row: whether it is below the 1.1 threshold.
rhat_df = pd.DataFrame([[*Rhats], [*(Rhats < 1.1)]], columns=keys)

display(rhat_df)

#### Rank plots

Rank plots are histograms of the ranked posterior draws (ranked over all chains), plotted separately for each chain. If all of the chains are targeting the same posterior, we expect the ranks in each chain to be uniform, whereas if one chain has a different location or scale parameter, this will be reflected in a deviation from uniformity. If the rank plots of all chains look similar, this indicates good mixing of the chains.

Reference: Vehtari, A., Gelman, A., Simpson, D., Carpenter, B. and Bürkner, P.-C. (2019). Rank-normalization, folding, and localization: An improved R-hat for assessing convergence of MCMC. arXiv preprint, https://arxiv.org/abs/1903.08008

In [None]:
# Rank plot for the parameter 'L'; `L_mix` is not defined in this notebook, the
# per-parameter inference data lives in `mixes`.
arviz.plot_rank(mixes['L'])

#### Effective sample sizes

Plot quantile, local or evolution of effective sample sizes (ESS).

In [None]:
# Effective sample size plot for the parameter 'L'; `L_mix` is not defined in
# this notebook, the per-parameter inference data lives in `mixes`.
arviz.plot_ess(mixes['L'])

#### Monte-Carlo Standard Error

In [None]:
# Monte-Carlo standard error plot for the parameter 'L'; `L_mix` is not defined
# in this notebook, the per-parameter inference data lives in `mixes`.
arviz.plot_mcse(mixes['L'])


#### Parallel Plot
Plot parallel coordinates plot showing posterior points with and without divergences.

Described by https://arxiv.org/abs/1709.01449, suggested by Ari Hartikainen


In [None]:
# NOTE(review): `azl` is not defined anywhere in the visible notebook -- confirm
# which inference-data object should be plotted here.
arviz.plot_parallel(azl)


The step size is the standard deviation of the proposal distribution. If it is too small, the chain takes a long time to reach high-density areas; if it is too large, many of the proposed samples are rejected.

## Plots

In [None]:
# Continue from the first chain: reuse its samples and acceptance rates.
samples = job[0].samples
acceptance_rates = job[0].acceptance_rates

# NOTE(review): the name `transcription` is never imported in this notebook (only
# `transcription_mh` is) -- confirm which module/class should be used here.
model = transcription.TranscriptionMCMC.initialise_from_state([data, opt], job[0])

In [None]:
# Begin MCMC
# Runs T further iterations on the restored model and refreshes `samples`.
T = 1000
model.sample(T, 1, 0, 1)

print(model.acceptance_rates)
samples = model.samples

In [None]:
# Plot the stored acceptance-rate trace of every parameter in its own subplot.
# samples = transcription_model.samples
acceptance_rates = model.acceptance_rates
plt.figure(figsize=(10,14))
parameter_names = acceptance_rates.keys()
acc_rates = samples['acc_rates']

# acc_rates is indexed by position, so its rows are assumed to follow the key
# order of acceptance_rates -- TODO confirm.
for i, name in enumerate(parameter_names):
    plt.subplot(len(parameter_names), 3, i+1)
    deltas = acc_rates[i]
    plt.plot(deltas)
    plt.title(name)
plt.tight_layout()

In [None]:
# Plot decay
# Trace plots of δbar, L and V in the first three cells of a 3x3 grid.
plt.figure(figsize=(10, 8))
for i, param in enumerate(['δbar', 'L', 'V']):
    ax = plt.subplot(331+i)
    plt.plot(samples[param].get())
    ax.set_title(param)
#'σ', 'w']):

# Trace of the interaction weights, one line per gene, labelled from m_df.index.
plt.figure()
for j in range(num_genes):
    plt.plot(samples['w'].get()[:, j], label=m_df.index[j])
plt.legend()
plt.title('Interaction weights')

# Trace of the interaction biases, one line per gene.
plt.figure()
for j in range(num_genes):
    plt.plot(samples['w_0'].get()[:,j])
plt.title('Interaction bias')


### Plot transcription ODE kinetic params


In [None]:
# Trace plots of the four log kinetic parameters (a, b, d, s), one subplot per gene.
plt.figure(figsize=(14, 14))
# Figure-level title: plt.title here would attach to an implicit axes created
# before the subplots rather than to the figure.
plt.suptitle('Transcription ODE kinetic parameters')
labels = ['a', 'b', 'd', 's']
for j in range(num_genes):
    ax = plt.subplot(num_genes, 2, j+1)
    k_param = samples['kbar'].get()[:, j]

    for k in range(4):
        plt.plot(k_param[-20000:, k], label=labels[k])
    # Horizontal line at the mean of the last 200 samples of column 3 (labelled 's').
    plt.axhline(np.mean(k_param[-200:, 3]))
    plt.legend()
    ax.set_title(f'Gene {j}')

plt.tight_layout()


In [None]:
plot_barenco = True
def plot_kinetics(kbar, plot_barenco=False):
    """Draw bar charts of basal rates, sensitivities and decay rates per gene.

    The last 100 entries of ``kbar`` are averaged along the first axis and
    exponentiated; columns 1, 2 and 3 of the result are plotted as basal
    rates, decay rates and sensitivities respectively.

    Args:
        kbar: array of kinetic-parameter samples; assumed to be indexed as
            [sample, gene, parameter] — TODO confirm against the sampler.
        plot_barenco: when true, also draw the hard-coded reference values
            (five entries each), rescaled so that their mean matches the mean
            of the corresponding model estimate.

    Relies on the notebook globals ``num_genes`` and ``m_df``.
    """
    plt.figure(figsize=(14, 14))
    k_latest = np.exp(np.mean(kbar[-100:], axis=0))
    basal, decay, sensitivity = k_latest[:, 1], k_latest[:, 2], k_latest[:, 3]

    gene_order = [0, 4, 2, 3, 1]

    def rescaled_reference(values, estimate):
        # From Martino paper ... do a rough rescaling so that the scales match.
        reordered = values[gene_order]
        return reordered/np.mean(reordered)*np.mean(estimate)

    # (panel title, model estimate, raw reference values) in plotting order.
    panels = [
        ('Basal rates', basal, np.array([2.6, 1.5, 0.5, 0.2, 1.35])),
        ('Sensitivities', sensitivity, np.array([3, 0.8, 0.7, 1.8, 0.7])/1.8),
        ('Decay rates', decay, np.array([1.2, 1.6, 1.75, 3.2, 2.3])*0.8/3.2),
    ]

    positions = np.arange(num_genes)
    for slot, (label, estimate, raw_reference) in enumerate(panels, start=1):
        plt.subplot(3, 3, slot)
        plt.bar(positions-0.2, estimate, width=0.4, tick_label=m_df.index, label='Model')
        if plot_barenco:
            reference = rescaled_reference(raw_reference, estimate)
            plt.bar(positions+0.2, reference, width=0.4, color='blue', align='center', label='Barenco et al.')
        plt.title(label)
        plt.legend()

# Fetch the stored 'kbar' samples and plot the kinetic-rate bar charts,
# overlaying the reference values when plot_barenco is set.
kbar = samples['kbar'].get()
plot_kinetics(kbar, plot_barenco)

In [None]:
# Plot genes
# Compare each gene's observations with the prediction obtained from the most
# recent sample of every parameter.
plt.figure(figsize=(14, 17))
kbar = samples['kbar'].get()[-1]
δbar = samples['δbar'].get()[-1]
w = samples['w'].get()[-1]
fbar = samples['fbar'].get()[-1]
w_0 = samples['w_0'].get()[-1]
m_pred = model.predict_m(kbar, δbar, w, fbar, w_0)

# Positions on the discretised grid that coincide with observations.
tick_positions = np.arange(N_p)[common_indices]
print(tick_positions)

# Keep the original 5x3 layout, adding rows only if there are more than 15 genes.
n_rows = max(5, -(-num_genes // 3))
for j in range(num_genes):
    # Three-argument form: the 3-digit shorthand (531+j) stops working once
    # more than nine panels are requested.
    ax = plt.subplot(n_rows, 3, j+1)
    plt.title(m_df.index[j])
    plt.scatter(common_indices, m_observed[j], marker='x')
    plt.plot(m_pred[j,:], color='grey')
    # Label every observation tick with its observation time. The previous
    # np.arange(t[-1]) labels neither matched the number of ticks nor the
    # times they mark. Assumes len(t) equals the number of observation ticks
    # — TODO confirm.
    plt.xticks(tick_positions)
    ax.set_xticklabels(t)
    plt.xlabel('Time (h)')

plt.tight_layout()

In [None]:
# Trace plot of the 'σ2_m' samples for every gene, one panel per gene.
plt.figure(figsize=(12, 10))
# Use a figure-level title: plt.title at this point would attach to an implicit
# full-figure axes that the subplots created below then cover or replace.
plt.suptitle('Noise variances')
# Same grid as before (num_genes rows, num_genes-2 columns), but never fewer
# than one column so that fewer than three genes still produce a valid grid.
n_cols = max(num_genes - 2, 1)
# Fetch the samples once instead of once per gene.
σ2_m_samples = samples['σ2_m'].get()
for j in range(num_genes):
    ax = plt.subplot(num_genes, n_cols, j+1)
    plt.title(m_df.index[j])
    plt.plot(σ2_m_samples[:, j])

plt.tight_layout()

In [None]:
def scaled_barenco_data(f):
    """Return the hard-coded Barenco profile rescaled to the spread of ``f``.

    Only the first of the three hard-coded rows is used. It is divided by its
    own standard deviation and multiplied by the standard deviation of ``f``.

    Args:
        f: array-like whose standard deviation sets the output scale.

    Returns:
        A tuple ``(barencof, measured_p53)``: the rescaled seven-point profile
        and a placeholder ``0`` (the measured-p53 computation is disabled).
    """
    replicates = np.array([
        [0.0, 200.52011, 355.5216125, 205.7574913, 135.0911372, 145.1080997, 130.7046969],
        [0.0, 184.0994134, 308.47592, 232.1775328, 153.6595161, 85.7272235, 168.0910562],
        [0.0, 230.2262511, 337.5994811, 276.941654, 164.5044287, 127.8653452, 173.6112139],
    ])
    reference = replicates[0]
    target_std = np.std(f)

    barencof = reference/np.std(reference)*target_std
    measured_p53 = 0  # placeholder value
    return barencof, measured_p53

def plot_f(f):
    """Plot ``f`` together with the rescaled Barenco profile.

    ``f`` is drawn as a grey line over the integer grid ``0..N_p-1``; the
    Barenco profile, rescaled to the spread of ``f``, is drawn as crosses at
    the grid positions selected by ``common_indices``.

    Args:
        f: values to plot against ``np.arange(N_p)``.

    Relies on the notebook globals ``N_p``, ``N_m`` and ``common_indices``.
    """
    fig = plt.figure(figsize=(13, 7))

    # scaled_barenco_data returns (profile, measured_p53); only the profile is
    # plotted, so unpack it rather than passing the whole tuple to scatter.
    barencof, _ = scaled_barenco_data(f)
    grid = np.arange(N_p)
    observed_positions = grid[common_indices]
    plt.plot(grid, f, color='grey')
    plt.scatter(observed_positions, barencof, marker='x')
    plt.xticks(observed_positions)
    fig.axes[0].set_xticklabels(np.arange(N_m)*2)
    plt.xlabel('Time (h)')
    


In [None]:
# Plot recent samples of f with an interval band, the observations and,
# optionally, the rescaled Barenco profile.
fig = plt.figure(figsize=(13, 7))
# Map the last 50 'fbar' samples through log(1 + exp(.)).
f_samples = np.log(1+np.exp(np.array(samples['fbar'].get()[-50:])))
obs_times = τ[common_indices]

# Error bars are only drawn when σ2_f is one of the model's parameters;
# otherwise σ2_f falls back to the preprocessing value.
noise_is_sampled = 'σ2_f' in model.params._fields
if not noise_is_sampled:
    σ2_f = σ2_f_pre
else:
    σ2_f = model.params.σ2_f.value
    plt.errorbar(obs_times, f_observed[0], 2*np.sqrt(σ2_f[0]),
                 fmt='none', capsize=5, color='blue')

# NOTE(review): newer ArviZ releases expose this as arviz.hdi(..., hdi_prob=...)
# — confirm against the installed version.
bounds = arviz.hpd(f_samples, credible_interval=0.95)

# Overlay the 19 most recent samples; only the first carries a legend label.
for back in range(1, 20):
    label_kwargs = {'label':'Samples'} if back == 1 else {}
    plt.plot(τ, f_samples[-back], c='blue', alpha=0.5, **label_kwargs)

if plot_barenco:
    recent_mean = np.mean(f_samples[-10:], axis=0)
    barenco_f, _ = scaled_barenco_data(recent_mean)
    plt.scatter(obs_times, barenco_f, marker='x', s=60, linewidth=3, label='Barenco et al.')

plt.scatter(obs_times, f_observed[0], marker='x', s=70, linewidth=4, label='Observed')

lower, upper = bounds[:, 0], bounds[:, 1]
plt.fill_between(τ, lower, upper, color='grey', alpha=0.5, label='95% credibility interval')
plt.xticks(np.arange(N_m)*2)
fig.axes[0].set_xticklabels(t)
plt.xlabel('Time (h)')
plt.legend();

In [None]:
# Scratch cell: 100 evenly spaced values from 0 to 12 (inclusive); the result
# is only displayed, not stored.
np.linspace(0,12,100)