## Metropolis–Hastings and Custom HMC MCMC Samplers

In [None]:
from IPython.display import display
import matplotlib.pyplot as plt

from reggae.data_loaders import load_barenco_puma, load_3day_dros, DataHolder, scaled_barenco_data
from reggae.mcmc import create_chains, MetropolisHastings, Parameter
from reggae.utilities import get_rbf_dist, discretise
from reggae.plot import plotters
from reggae.models import TranscriptionLikelihood, Options
from reggae.models.results import GenericResults
from reggae.models.kernels import MixedKernel, FKernel, KbarKernel

import tensorflow as tf
from tensorflow import math as tfm
import tensorflow_probability as tfp
from tensorflow_probability import bijectors as tfb
from tensorflow_probability import distributions as tfd

import numpy as np
import pandas as pd
import arviz
from ipywidgets import IntProgress

# Show NumPy floats with five decimal places.
np.set_printoptions(formatter={'float': lambda x: "{0:0.5f}".format(x)})
plt.style.use('ggplot')
# IPython magic (notebook only): render figures inline.
%matplotlib inline
# Short alias for NumPy's float64, used for the float64 TF/TFP scalars below.
f64 = np.float64


In [None]:
# Each of m_observed / f_observed is unpacked below into (dataframe, array);
# m_df.index is used as the gene names in later plots.
m_observed, f_observed, σ2_m_pre, σ2_f_pre, t = load_barenco_puma()

# Alternative dataset:
# m_observed, f_observed, t = load_3day_dros()

replicate = 0  # index of the replicate to fit

m_df, m_observed = m_observed
f_df, f_observed = f_observed

# Shape of m_observed = (replicates, genes, times)
m_observed = m_observed[replicate]
f_observed = np.atleast_2d(f_observed[replicate])
# Take the preprocessing variances from the same replicate as the observations.
# NOTE(review): assumes σ2_*_pre are indexed by replicate on axis 0 like the data.
σ2_m_pre = σ2_m_pre[replicate]
σ2_f_pre = σ2_f_pre[replicate]

num_genes = m_observed.shape[0]
τ, common_indices = discretise(t)
N_p = τ.shape[0]            # points on the discretised time grid τ
N_m = m_observed.shape[1]   # number of observation times

data = (m_observed, f_observed)
noise_data = (σ2_m_pre, σ2_f_pre)
time = (t, τ, tf.constant(common_indices))

data = DataHolder(data, noise_data, time)


In [None]:
# Options shared by the models below, and a standalone likelihood for the data.
opt = Options(tf_mrna_present=True, preprocessing_variance=True)
lik = TranscriptionLikelihood(data, opt)


In [None]:
class GibbsKernel(tfp.mcmc.TransitionKernel):
    '''Placeholder for a Gibbs transition kernel; not implemented or used in this notebook.'''

In [None]:
import collections

# Containers for the model's Parameter objects. TupleParams is used when the
# preprocessing noise variances are used; TupleParams_pre adds a sampled σ2_f.
TupleParams = collections.namedtuple(
    'TupleParams', 'fbar δbar kbar σ2_m w w_0 L V')
TupleParams_pre = collections.namedtuple(
    'TupleParams_pre', TupleParams._fields + ('σ2_f',))

class TranscriptionCustom():
    '''Transcription model sampled with a custom mixture of TFP transition kernels.

    The constructor builds one `Parameter` per model quantity; `sample` gathers
    the kernels of the active parameters into a `MixedKernel` and runs
    `tfp.mcmc.sample_chain` on them.

    `data` is a `DataHolder` wrapping the observations (m, f), each of shape
    (num, time), their preprocessing noise variances, and the time tuple
    (t, τ, common_indices).
    '''
    def __init__(self, data: DataHolder, options: Options):
        self.data = data
        # Smallest gap between observation times; sets the lower bound of the L prior.
        min_dist = min(data.t[1:]-data.t[:-1])
        self.N_p = data.τ.shape[0]      # Number of points on the discretised grid τ
        self.N_m = data.t.shape[0]      # Number of observations

        self.num_tfs = data.f_obs.shape[0] # Number of TFs
        self.num_genes = data.m_obs.shape[0]

        self.likelihood = TranscriptionLikelihood(data, options)
        self.options = options
        # Location and scale of the Normal priors on δbar and kbar
        a = tf.constant(-0.5, dtype='float64')
        b2 = tf.constant(2., dtype='float64')
        # Position of each parameter group in the MCMC state list used by `sample`.
        # NOTE: 'w' (index 5) is not currently part of the sampled state.
        self.state_indices = {
            'δbar': 0,
            'kbar': 1,
            'fbar': 2,
            'rbf_params': 3,
            'σ2_m': 4,
            'w': 5,
        }

        # Each *_log_prob(all_states) factory below returns a closure computing the
        # log posterior of one parameter group, with the other groups read from
        # `all_states`. `self.params` is assigned at the end of __init__, which is
        # fine because the closures only run during sampling.

        # Interaction weights
        def w_log_prob(all_states):
            def w_log_prob_fn(wstar, w_0star):
                # NOTE(review): w_0star only enters through its prior here; it is
                # not passed to likelihood.genes — confirm this is intended.
                new_prob = tf.reduce_sum(self.likelihood.genes(
                    all_states=all_states,
                    state_indices=self.state_indices,
                    w=wstar))
                new_prob += tf.reduce_sum(self.params.w.prior.log_prob(wstar))
                new_prob += tf.reduce_sum(self.params.w_0.prior.log_prob(w_0star))
                return tf.reduce_sum(new_prob)
            return w_log_prob_fn
        w = Parameter('w', tfd.Normal(f64(0), f64(2)),
                      1*tf.ones((self.num_genes, self.num_tfs), dtype='float64'),
                      step_size=0.04, hmc_log_prob=w_log_prob, requires_all_states=True)
        w_0 = Parameter('w_0', tfd.Normal(f64(0), f64(2)), tf.zeros(self.num_genes, dtype='float64'))

        # Latent function
        fbar_kernel = FKernel(self.likelihood,
                              self.fbar_prior_params,
                              self.num_tfs, self.num_genes,
                              self.options.tf_mrna_present,
                              self.state_indices,
                              # Use this instance's grid size, not a notebook-global N_p.
                              0.3*tf.ones(self.N_p, dtype='float64'))
        fbar = Parameter('fbar', self.fbar_prior, 0.5*np.ones(self.N_p), kernel=fbar_kernel, requires_all_states=False)

        # GP hyperparameters (V: RBF variance, L: squared lengthscale); the state
        # entry for them is the pair [V, L], scored jointly by this log prob.
        def rbf_params_log_prob(all_states):
            def rbf_params_log_prob_fn(vstar, l2star):
                new_prob = self.params.fbar.prior(all_states[self.state_indices['fbar']], vstar, l2star)
                new_prob += self.params.V.prior.log_prob(vstar)
                new_prob += self.params.L.prior.log_prob(l2star)
                return tf.reduce_sum(new_prob)
            return rbf_params_log_prob_fn

        V = Parameter('V', tfd.InverseGamma(f64(0.01), f64(0.01)), tf.constant(f64(1)), step_size=0.001,
                      fixed=not options.tf_mrna_present, hmc_log_prob=rbf_params_log_prob, requires_all_states=True)
        L = Parameter('L', tfd.Uniform(f64(min_dist**2-0.2), f64(data.t[-1]**2)), tf.constant(f64(4)))

        # Distances on τ used by the RBF covariance in fbar_prior_params
        self.t_dist = get_rbf_dist(data.τ, self.N_p)

        # Translation kinetic parameters
        def δbar_log_prob(all_states):
            def δbar_log_prob_fn(state):
                new_prob = tf.reduce_sum(self.likelihood.genes(
                    all_states=all_states,
                    state_indices=self.state_indices,
                    δbar=state
                ))
                new_prob += self.params.δbar.prior.log_prob(state)
                return new_prob
            return δbar_log_prob_fn
        δbar = Parameter('δbar', tfd.Normal(a, b2), f64(-0.3), step_size=0.1,
                         hmc_log_prob=δbar_log_prob, requires_all_states=True)

        # White noise for genes
        def σ2_m_log_prob(all_states):
            def σ2_m_log_prob_fn(σ2_mstar):
                new_prob = self.likelihood.genes(
                    all_states=all_states,
                    state_indices=self.state_indices,
                    σ2_m=σ2_mstar
                ) + self.params.σ2_m.prior.log_prob(σ2_mstar)
                return tf.reduce_sum(new_prob)
            return σ2_m_log_prob_fn
        σ2_m = Parameter('σ2_m', tfd.InverseGamma(f64(0.01), f64(0.01)), 1e-4*tf.ones(self.num_genes, dtype='float64'),
                         hmc_log_prob=σ2_m_log_prob, requires_all_states=True, step_size=0.00001)

        # Transcription kinetic parameters
        def constrain_kbar(kbar, gene):
            '''Clip one row of kbar to [-10, 3] (in place for NumPy rows) and return it.

            `gene` (the row index) is currently unused.
            '''
            kbar[kbar < -10] = -10
            kbar[kbar > 3] = 3
            return kbar
        # One row per gene; columns are (a_j, b_j, d_j, s_j)
        kbar_initial = -0.1*np.float64(np.c_[
            np.ones(self.num_genes), # a_j
            np.ones(self.num_genes), # b_j
            np.ones(self.num_genes), # d_j
            np.ones(self.num_genes)  # s_j
        ])
        def kbar_log_prob(all_states):
            def kbar_log_prob_fn(kstar):
                new_prob = self.likelihood.genes(
                    all_states=all_states,
                    state_indices=self.state_indices,
                    kbar=kstar,
                ) + tf.reduce_sum(self.params.kbar.prior.log_prob(kstar))
                return tf.reduce_sum(new_prob)
            return kbar_log_prob_fn
        for j, k in enumerate(kbar_initial):
            kbar_initial[j] = constrain_kbar(k, j)
        kbar = Parameter('kbar', tfd.Normal(a, b2),
                         kbar_initial,
                         hmc_log_prob=kbar_log_prob,
                         constraint=constrain_kbar, step_size=0.05, requires_all_states=True)

        # σ2_f is a model parameter only when preprocessing variances are not used.
        if not options.preprocessing_variance:
            σ2_f = Parameter('σ2_f', tfd.InverseGamma(f64(0.01), f64(0.01)), 1e-4*np.ones(self.num_tfs), step_size=tf.constant(0.5, dtype='float64'))
            self.params = TupleParams_pre(fbar, δbar, kbar, σ2_m, w, w_0, L, V, σ2_f)
        else:
            self.params = TupleParams(fbar, δbar, kbar, σ2_m, w, w_0, L, V)

    def fbar_prior_params(self, v, l2):
        '''Return the zero mean and RBF covariance of the GP prior on fbar over τ.

        K = v * exp(-t_dist**2 / (2 * l2)) + 1e-5 * I, i.e. an RBF kernel with
        variance `v` and squared lengthscale `l2` plus diagonal jitter.
        '''
        jitter = tf.linalg.diag(1e-5 * tf.ones(self.N_p, dtype='float64'))
        K = tfm.multiply(v, tfm.exp(-tfm.square(self.t_dist)/(2*l2))) + jitter
        m = np.zeros(self.N_p)
        return m, K

    def fbar_prior(self, fbar, v, l2):
        '''Log density of `fbar` under the GP prior with hyperparameters (v, l2).

        If the covariance cannot be factorised (eager Cholesky failure), retry
        once with extra diagonal jitter; if that also fails, return -inf so the
        proposal is rejected. The try/except only takes effect in eager mode.
        '''
        m, K = self.fbar_prior_params(v, l2)
        try:
            return tfd.MultivariateNormalFullCovariance(m, K).log_prob(fbar)
        except tf.errors.InvalidArgumentError:
            # The extra jitter must be float64 to match K; a float32 jitter makes
            # the retry itself fail with a dtype error.
            jitter = tf.linalg.diag(1e-4 * tf.ones(self.N_p, dtype='float64'))
            try:
                return tfd.MultivariateNormalFullCovariance(m, K + jitter).log_prob(fbar)
            except tf.errors.InvalidArgumentError:
                tf.print("error")
                return tf.constant(-np.inf, dtype='float64')

    def sample(self, T=2000, store_every=10, burn_in=1000, report_every=100):
        '''Draw samples of the active parameters with `tfp.mcmc.sample_chain`.

        Args:
            T: number of results kept after burn-in.
            store_every: unused; kept for interface compatibility.
            burn_in: number of burn-in steps discarded by `sample_chain`.
            report_every: unused; kept for interface compatibility.

        Returns:
            (samples, acceptance): `samples` mirrors `current_state` below and is
            indexed by `self.state_indices` (the 'rbf_params' entry is a [V, L]
            pair). `acceptance` is the mean of the is_accepted flags traced from
            the first inner kernel (δbar) only.

        The raw samples and per-step flags are also stored on `self.samples`
        and `self.is_accepted`.
        '''
        print('----- Sampling Begins -----')

        progress = IntProgress(description='Running', min=0, max=T)
        display(progress)
        params = self.params

        print(params)
        # Order must match self.state_indices; w / w_0 are not sampled yet.
        active_params = [
            params.δbar,
            params.kbar,
            params.fbar,
            params.V,   # scores the [V, L] state entry jointly
            params.σ2_m,
        ]
        kernels = [param.kernel for param in active_params]
        send_all_states = [param.requires_all_states for param in active_params]

        current_state = [
            params.δbar.value,
            params.kbar.value,
            params.fbar.value,
            [params.V.value, params.L.value],
            params.σ2_m.value,
        ]
        mixed_kern = MixedKernel(kernels, send_all_states)

        def trace_fn(_, previous_kernel_results):
            # Trace acceptance of the first inner kernel only; some kernels keep
            # their is_accepted flag one results-level deeper.
            first = previous_kernel_results.inner_results[0]
            if hasattr(first, 'is_accepted'):
                return first.is_accepted
            return first.inner_results.is_accepted

        @tf.function
        def run_chain():
            # Run the chain (with burn-in).
            samples, is_accepted = tfp.mcmc.sample_chain(
                num_results=T,
                num_burnin_steps=burn_in,
                current_state=current_state,
                kernel=mixed_kern,
                trace_fn=trace_fn)
            return samples, is_accepted

        samples, is_accepted = run_chain()

        self.is_accepted = is_accepted
        self.samples = samples

        is_accepted = tf.reduce_mean(tf.cast(is_accepted, dtype=tf.float32))

        progress.value = T
        print('----- Finished -----')
        return samples, is_accepted
        
            
            

In [None]:
# Single-chain run settings. store_every / report_every are accepted but not
# used by TranscriptionCustom.sample; num_chains / tune_every are not used here.
T = 700
store_every = 1
burn_in = 0
report_every = 20
num_chains = 4
tune_every = 50

model = TranscriptionCustom(data, opt)

# Pass the settings above rather than literals so that editing them takes effect.
samples, is_accepted = model.sample(T=T, store_every=store_every, burn_in=burn_in,
                                    report_every=report_every)


In [None]:
# Split the chain into parameter groups using the model's state layout.
kbar = model.samples[model.state_indices['kbar']]
δbar = model.samples[model.state_indices['δbar']]
fbar = model.samples[model.state_indices['fbar']]
σ2_m = model.samples[model.state_indices['σ2_m']]
rbf_params = model.samples[model.state_indices['rbf_params']]  # [V samples, L samples]

# w / w_0 are not part of the sampled state, so use the model's fixed initial
# values as a single pseudo-sample. This keeps their shape in sync with the
# model (including model.num_tfs) instead of hard-coding it here.
# TODO: use model.samples[model.state_indices['w']] once w is sampled.
w = [model.params.w.value]
w_0 = [model.params.w_0.value]


In [None]:
plot_barenco = True

fig = plt.figure(figsize=(13, 7))
# Softplus of the last 50 fbar samples gives the positive TF activity;
# logaddexp(0, x) == log(1 + exp(x)) without overflowing for large x.
f_samples = np.logaddexp(0, np.array(fbar[-50:]))
print(f_samples.shape)
if 'σ2_f' in model.params._fields:
    # σ2_f is a model parameter: draw ±2 standard-deviation bars on the observations.
    σ2_f = model.params.σ2_f.value
    plt.errorbar(τ[common_indices], f_observed[0], 2*np.sqrt(σ2_f[0]),
                 fmt='none', capsize=5, color='blue')
else:
    σ2_f = σ2_f_pre

bounds = arviz.hpd(f_samples, credible_interval=0.95)
# Overlay up to the 19 most recent samples (fewer if fewer are available).
for i in range(1, min(20, len(f_samples) + 1)):
    f_i = f_samples[-i]
    kwargs = {'label': 'Samples'} if i == 1 else {}
    plt.plot(τ, f_i, c='blue', alpha=0.5, **kwargs)

if plot_barenco:
    barenco_f, _ = scaled_barenco_data(np.mean(f_samples[-10:], axis=0))
    plt.scatter(τ[common_indices], barenco_f, marker='x', s=60, linewidth=3, label='Barenco et al.')

plt.scatter(τ[common_indices], f_observed[0], marker='x', s=70, linewidth=4, label='Observed')

plt.fill_between(τ, bounds[:, 0], bounds[:, 1], color='grey', alpha=0.5, label='95% credibility interval')
plt.xticks(t)
fig.axes[0].set_xticklabels(t)
plt.ylim((-1,5))
plt.xlabel('Time (h)')
plt.legend();

In [None]:
# Plot genes: observed mRNA (crosses) against the prediction from the last sample.
plt.figure(figsize=(14, 17))

m_pred = model.likelihood.predict_m(kbar[-1], δbar[-1], w[-1], fbar[-1], w_0[-1])

# Three columns and at least five rows, growing with num_genes. The 3-digit
# shorthand subplot(531 + j) is invalid once j reaches 9.
n_rows = max(5, -(-num_genes // 3))
for j in range(num_genes):
    ax = plt.subplot(n_rows, 3, j + 1)
    plt.title(m_df.index[j])
    plt.scatter(common_indices, m_observed[j], marker='x')
    plt.plot(m_pred[j, :], color='grey')
    plt.xticks(np.arange(N_p)[common_indices])
    # One label per observation time, matching the tick positions set above.
    ax.set_xticklabels(t)
    plt.xlabel('Time (h)')

plt.tight_layout()

In [None]:
# Plot decay
# Trace plots of δbar and the RBF hyperparameters. rbf_params follows the
# [V, L] order of the sampled state, so the titles follow that order too.
plt.figure(figsize=(10, 8))
for i, (name, param) in enumerate(zip(['δbar', 'V', 'L'], [δbar, *rbf_params])):
    ax = plt.subplot(331 + i)
    plt.plot(param)
    ax.set_title(name)

### Plot transcription ODE kinetic params


In [None]:
# Convergence plot of the sampled kbar (transcription kinetic parameters) via reggae.plot.
plotters.plot_kinetics_convergence(kbar)

In [None]:
# Plot the kbar samples per gene; m_df presumably supplies the gene names and
# plot_barenco presumably toggles the Barenco et al. reference values — TODO confirm in reggae.plot.
plotters.plot_kinetics(m_df, kbar, plot_barenco)

In [None]:
# w / w_0 are sequences of per-sample arrays; stack them so samples lie on axis 0.
# Indexing the Python list directly (w[:, j]) raises TypeError.
w_samples = np.asarray(w)      # (samples, genes, tfs)
w_0_samples = np.asarray(w_0)  # (samples, genes)

plt.figure()
for j in range(num_genes):
    plt.plot(w_samples[:, j], label=m_df.index[j])
plt.legend()
plt.title('Interaction weights')

plt.figure()
for j in range(num_genes):
    plt.plot(w_0_samples[:, j])
plt.title('Interaction bias')


In [None]:
plt.figure(figsize=(12, 10))
# Figure-level title: plt.title() here would title a throwaway full-figure axes
# created by gca(), not the figure holding the subplots below.
plt.suptitle('Noise variances')
# Three columns and as many rows as needed; the previous num_genes - 2 columns
# is invalid for two or fewer genes.
n_rows = -(-num_genes // 3)
for j in range(num_genes):
    plt.subplot(n_rows, 3, j + 1)
    plt.title(m_df.index[j])
    plt.plot(σ2_m[:, j])

plt.tight_layout()

In [None]:

class TranscriptionHMC(MetropolisHastings):
    '''Noise-variance sampler built on `MetropolisHastings`.

    NOTE(review): despite the name, `iterate` only updates the noise variances
    (σ2_m, plus σ2_f in the Gibbs branch); other parameters are not updated here.
    '''

    def iterate(self):
        '''Run one sweep of noise-variance updates, modifying `self.params` in place.

        With `options.preprocessing_variance`, each gene's σ2_m gets its own
        Metropolis-Hastings step. Otherwise σ2_m (and σ2_f when
        `options.tf_mrna_present`) are drawn from their inverse-gamma
        conditional posteriors (Gibbs). Accepted moves are counted in
        `self.acceptance_rates`.
        '''
        params = self.params
        # Squared residuals of the current state; the Gibbs updates below use them.
        _, sq_diff_m = self.likelihood.genes(params, return_sq_diff=True)
        if self.options.tf_mrna_present:
            _, sq_diff_f = self.likelihood.tfs(params, params.fbar.value, return_sq_diff=True)

        # Noise variances
        if self.options.preprocessing_variance:
            σ2_m = params.σ2_m.value
            σ2_mstar = σ2_m.copy()
            for j in range(self.num_genes):
                sample = params.σ2_m.propose(σ2_m[j])
                σ2_mstar[j] = sample
                # Hastings correction terms: presumably proposal_dist(x).log_prob(y) = log q(y | x).
                old_q = params.σ2_m.proposal_dist(σ2_mstar[j]).log_prob(σ2_m[j])
                new_prob = self.likelihood.genes(params, σ2_m=σ2_mstar)[j] + params.σ2_m.prior.log_prob(σ2_mstar[j])

                new_q = params.σ2_m.proposal_dist(σ2_m[j]).log_prob(σ2_mstar[j])
                old_prob = self.likelihood.genes(params, σ2_m=σ2_m)[j] + params.σ2_m.prior.log_prob(σ2_m[j])

                if self.is_accepted(new_prob + old_q, old_prob + new_q):
                    # σ2_m is the same object as params.σ2_m.value (assuming a plain
                    # attribute), so this also updates σ2_m[j].
                    params.σ2_m.value[j] = sample
                    self.acceptance_rates['σ2_m'] += 1/self.num_genes
                else:
                    # Roll back so later genes are evaluated against the current state.
                    σ2_mstar[j] = σ2_m[j]
        else: # Use Gibbs sampling
            # Prior parameters
            α = params.σ2_m.prior.concentration
            β = params.σ2_m.prior.scale
            # Conditional posterior of inv gamma parameters.
            # NOTE(review): np.sum collapses every axis of sq_diff_m (presumably
            # genes x times), but α_post adds only N_m / 2 — confirm the intended
            # observation count.
            α_post = α + 0.5*self.N_m
            β_post = β + 0.5*np.sum(sq_diff_m)
            params.σ2_m.value = np.repeat(tfd.InverseGamma(α_post, β_post).sample(), self.num_genes)
            self.acceptance_rates['σ2_m'] += 1

            if self.options.tf_mrna_present: # (Step 5)
                # Prior parameters
                α = params.σ2_f.prior.concentration
                β = params.σ2_f.prior.scale
                # Conditional posterior of inv gamma parameters:
                α_post = α + 0.5*self.N_m
                β_post = β + 0.5*np.sum(sq_diff_f)
                params.σ2_f.value = np.repeat(tfd.InverseGamma(α_post, β_post).sample(), self.num_tfs)
                self.acceptance_rates['σ2_f'] += 1

    @classmethod
    def initialise_from_state(cls, args, state):
        '''Build a model from constructor `args` and copy the sampler state onto it.

        `state` must expose `acceptance_rates` and `samples`. The model is an
        instance of the class this is called on.
        '''
        model = cls(*args)
        model.acceptance_rates = state.acceptance_rates
        model.samples = state.samples
        return model

    def predict_m(self, kbar, δbar, w, fbar, w_0):
        '''Predict mRNA for the given parameter values via the likelihood.'''
        return self.likelihood.predict_m(kbar, δbar, w, fbar, w_0)

    def predict_m_with_current(self):
        '''Predict mRNA from the current parameter values.'''
        return self.likelihood.predict_m(self.params.kbar.value,
                                         self.params.δbar.value,
                                         self.params.w.value,
                                         self.params.fbar.value,
                                         self.params.w_0.value)



In [None]:
# Rebuild the options and construct the noise-variance (MH/Gibbs) model.
opt = Options(tf_mrna_present=True, preprocessing_variance=True)
model = TranscriptionHMC(data, opt)




In [None]:
T = 600
# Positional args presumably (T, store_every, burn_in, report_every) = (600, 1, 0, 1)
# — TODO confirm against MetropolisHastings.sample.
model.sample(T, 1, 0, 1)

print(model.acceptance_rates)
samples = model.samples
acceptance_rates = model.acceptance_rates


In [None]:
# Multi-chain sampling settings, consumed by the create_chains cell below.
(T, store_every, burn_in, report_every,
 num_chains, tune_every) = (1000, 1, 0, 20, 4, 50)


In [None]:
# Chains
# Run `num_chains` chains of transcription.TranscriptionMCMC with the settings
# from the previous cell; later cells index `job` per chain (job[0], for res in job).
# NOTE(review): `transcription` is not imported anywhere in the visible notebook,
# so this raises NameError as written — TODO import the module defining TranscriptionMCMC.
job = create_chains(
    transcription.TranscriptionMCMC, 
    [data, opt], 
    {
        'T': T, 
        'store_every': store_every, 
        'burn_in': burn_in,
        'report_every': report_every,
        'tune_every':tune_every
    }, 
    num_chains=num_chains)

    
print('Done')

## Convergence Plots

In [None]:
keys = job[0].acceptance_rates.keys()

# Stack each parameter's samples across chains -> (num_chains, num_samples, ...),
# in float64. One np.stack per key avoids re-copying the array on every chain,
# as the previous repeated np.append did.
variables = {
    key: np.stack([np.asarray(res.samples[key].get(), dtype=np.float64) for res in job])
    for key in keys
}

plt.plot(variables['L'][:, -100:].T)

mixes = {key: arviz.convert_to_inference_data(variables[key]) for key in keys}

#### R-hat
R-hat compares between-chain and within-chain variance: it is (the square root of) the ratio of the pooled estimate of the posterior variance to the within-chain variance. As the between-chain variance approaches the within-chain variance, R-hat tends to 1; if it exceeds 1.1 we consider that the chains have not mixed well.

In [None]:
# R-hat for fbar, then each parameter's mean R-hat with a flag against the 1.1 threshold.
Rhat = arviz.rhat(mixes['fbar'])

Rhats = np.array([arviz.rhat(mixes[key]).x.values.mean() for key in keys])

rhat_df = pd.DataFrame([list(Rhats), list(Rhats < 1.1)], columns=keys)

display(rhat_df)

#### Rank plots

Rank plots are histograms of the ranked posterior draws (ranked over all
    chains) plotted separately for each chain.
    If all of the chains are targeting the same posterior, we expect the ranks in each chain to be
    uniform, whereas if one chain has a different location or scale parameter, this will be
    reflected in the deviation from uniformity. If rank plots of all chains look similar, this
    indicates good mixing of the chains.

Vehtari, A., Gelman, A., Simpson, D., Carpenter, B. & Bürkner, P.-C. (2019). Rank-normalization, folding, and localization: An improved R-hat
    for assessing convergence of MCMC. arXiv preprint https://arxiv.org/abs/1903.08008

In [None]:
# Rank plot for the L samples across chains (`mixes` holds one InferenceData per parameter).
arviz.plot_rank(mixes['L'])

#### Effective sample sizes

Plot quantile, local or evolution of effective sample sizes (ESS).

In [None]:
# Effective sample size plot for the L samples (`mixes` holds one InferenceData per parameter).
arviz.plot_ess(mixes['L'])

#### Monte-Carlo Standard Error

In [None]:
# Monte Carlo standard error plot for the L samples (`mixes` holds one InferenceData per parameter).
arviz.plot_mcse(mixes['L'])


#### Parallel Plot
Parallel-coordinates plot showing posterior draws with and without divergences.

Described by https://arxiv.org/abs/1709.01449, suggested by Ari Hartikainen


In [None]:
# NOTE(review): `azl` is not defined anywhere in the visible notebook. This presumably
# expects a multi-variable InferenceData (parallel plots need at least two variables)
# — TODO confirm which object was intended.
arviz.plot_parallel(azl)


For random-walk proposals the step size is the proposal's standard deviation. If it is too small, the chain takes a long time to reach high-density regions; if it is too large, many proposals are rejected.

## Plots

In [None]:
# Inspect the first chain: its samples and acceptance rates, plus a model rebuilt from its state.
samples = job[0].samples
acceptance_rates = job[0].acceptance_rates

# NOTE(review): `transcription` is not imported in the visible notebook — TODO confirm module.
model = transcription.TranscriptionMCMC.initialise_from_state([data, opt], job[0])

In [None]:
# samples = transcription_model.samples
# Plot each parameter's acceptance-rate trace, one subplot per parameter.
acceptance_rates = model.acceptance_rates
plt.figure(figsize=(10,14))
parameter_names = acceptance_rates.keys()
# Assumes the rows of samples['acc_rates'] follow the key order of acceptance_rates.
acc_rates = samples['acc_rates']

for i, name in enumerate(parameter_names):
    plt.subplot(len(parameter_names), 3, i+1)
    deltas = acc_rates[i]
    plt.plot(deltas)
    plt.title(name)
plt.tight_layout()