## Metropolis Hastings Custom HMC MC Algorithm

In [None]:
from IPython.display import display
import matplotlib.pyplot as plt

from reggae.data_loaders import load_barenco_puma, load_3day_dros, DataHolder, scaled_barenco_data
from reggae.mcmc import create_chains, MetropolisHastings, Parameter
from reggae.utilities import get_rbf_dist, discretise, logit, LogisticNormal
from reggae.plot import plotters
from reggae.models import TranscriptionLikelihood, Options
from reggae.models import transcription_nuts_merge
from reggae.models.results import GenericResults
from reggae.models.kernels import MixedKernel, FKernel, KbarKernel

import tensorflow as tf
from tensorflow import math as tfm
import tensorflow_probability as tfp
from tensorflow_probability import bijectors as tfb
from tensorflow_probability import distributions as tfd

import numpy as np
import pandas as pd
import arviz
from ipywidgets import IntProgress

# Print floating point array entries with five decimal places.
np.set_printoptions(formatter={'float': lambda x: "{0:0.5f}".format(x)})
# Use the ggplot style sheet for every figure in this notebook.
plt.style.use('ggplot')
# IPython magic (not plain Python): render figures inline in the notebook.
%matplotlib inline
# Shorthand used throughout to cast constants to 64-bit floats.
f64 = np.float64


In [None]:
# Sanity check: evaluate the LogisticNormal(-1, 1) log-density on a grid
# over [-2, 2], print the values and plot them.
l = LogisticNormal(-1, 1)

xs = np.linspace(-2, 2, num=100)
ll = [l.log_prob(point).numpy() for point in xs]
print(ll)
plt.plot(xs, ll)
# plt.ylim(0, 0.5)
# plt.ylim(0, 5)

In [None]:
# Load the Barenco (PUMA) dataset and wrap it in a DataHolder.
#df, genes, genes_se, m_observed, f_observed, σ2_m_pre, σ2_f_pre, t = load_barenco_puma()
m_observed, f_observed, σ2_m_pre, σ2_f_pre, t = load_barenco_puma()

# m_observed, f_observed, t = load_3day_dros()

# Each observation entry is a (dataframe, array) pair.
m_df, m_observed = m_observed
f_df, f_observed = f_observed

# Shape of m_observed = (replicates, genes, times)
replicate = 0
m_observed = m_observed[replicate]
f_observed = np.atleast_2d(f_observed[replicate])
# NOTE(review): the noise variances are indexed with a literal 0 rather than
# `replicate` -- confirm whether their first axis is the replicate axis.
σ2_m_pre = σ2_m_pre[0]
σ2_f_pre = σ2_f_pre[0]

# Discretised time grid and the positions of the observed times in it.
τ, common_indices = discretise(t)

num_genes = m_observed.shape[0]
N_m = m_observed.shape[1]
N_p = τ.shape[0]

noise_data = (σ2_m_pre, σ2_f_pre)
time = (t, τ, tf.constant(common_indices))
data = DataHolder((m_observed, f_observed), noise_data, time)


In [None]:
# Model options shared by the likelihood and the sampler below.
opt = Options(preprocessing_variance=True, tf_mrna_present=True)
lik = transcription_nuts_merge.TranscriptionLikelihood(data, opt)

# Largest per-row variance of the observed TF data (axis 1 is reduced).
tf_variances = np.var(data.f_obs, axis=1)
print(max(tf_variances))

In [None]:
class GibbsKernel(tfp.mcmc.TransitionKernel):
    '''
    Transition kernel sketch that resamples the noise variances σ2_m (and
    σ2_f) from inverse-gamma conditional posteriors.

    NOTE(review): this class looks unfinished. It defines no __init__, so
    `self.options`, `self.N_m`, `self.num_genes`, `self.num_tfs` and
    `self.acceptance_rates` are never set here, and none of the other
    TransitionKernel methods are implemented in this block.
    '''
    
    def one_step(self, current_state, previous_kernel_results, all_states):
        '''
        Update the noise variances in place when they are not preprocessed.

        Does nothing when `self.options.preprocessing_variance` is set.
        None of the three arguments is read in the body, and nothing is
        returned.

        NOTE(review): `params`, `sq_diff_m` and `sq_diff_f` are not defined
        in this method; presumably they were meant to come from the model
        and the current residuals -- confirm before use.
        '''

        if self.options.preprocessing_variance:
            pass
        else: # Use Gibbs sampling
            # Prior parameters
            α = params.σ2_m.prior.concentration
            β = params.σ2_m.prior.scale
            # Conditional posterior of inv gamma parameters:
            α_post = α + 0.5*self.N_m
            β_post = β + 0.5*np.sum(sq_diff_m)
            # print(α.shape, sq_diff.shape)
            # print('val', β_post.shape, params.σ2_m.value)
            # The draw is repeated num_genes times with np.repeat.
            params.σ2_m.value = np.repeat(tfd.InverseGamma(α_post, β_post).sample(), self.num_genes)
            # Counted as accepted on every call.
            self.acceptance_rates['σ2_m'] += 1
            
            if self.options.tf_mrna_present: # (Step 5)
                # Prior parameters
                α = params.σ2_f.prior.concentration
                β = params.σ2_f.prior.scale
                # Conditional posterior of inv gamma parameters:
                α_post = α + 0.5*self.N_m
                β_post = β + 0.5*np.sum(sq_diff_f)
                # print(α.shape, sq_diff.shape)
                # print('val', β_post.shape, params.σ2_m.value)
                # The draw is repeated num_tfs times with np.repeat.
                params.σ2_f.value = np.repeat(tfd.InverseGamma(α_post, β_post).sample(), self.num_tfs)
                self.acceptance_rates['σ2_f'] += 1


In [None]:
import collections

# Fields shared by both parameter containers, in positional order.
_COMMON_PARAM_FIELDS = [
    'fbar', 'δbar', 'kbar', 'σ2_m', 'w', 'w_0', 'L', 'V', 'kinetics',
]

# Same container with an extra trailing σ2_f field.
TupleParams_pre = collections.namedtuple(
    'TupleParams_pre', _COMMON_PARAM_FIELDS + ['σ2_f'])
# Container without σ2_f.
TupleParams = collections.namedtuple('TupleParams', _COMMON_PARAM_FIELDS)

class TranscriptionCustom():
    '''
    Transcriptional regulation model sampled with a mixture of MCMC kernels.

    One `Parameter` is built per model quantity (kinetics, latent function
    fbar, RBF hyperparameters, noise variances, interaction weights); the
    active ones are sampled jointly through a `MixedKernel` in `sample`.

    data is a DataHolder; the attributes read here are `t`, `τ`, `f_obs`
    and `m_obs`.
    options is an Options instance; `tf_mrna_present` and
    `preprocessing_variance` are read here.
    '''
    def __init__(self, data: DataHolder, options: Options):
        self.data = data
        self.samples = None
        min_dist = min(data.t[1:]-data.t[:-1])
        self.N_p = data.τ.shape[0]
        self.N_m = data.t.shape[0]      # Number of observations

        self.num_tfs = data.f_obs.shape[0] # Number of TFs
        self.num_genes = data.m_obs.shape[0]

        self.likelihood = transcription_nuts_merge.TranscriptionLikelihood(data, options)
        self.options = options
        # Position of each sampled block inside the chain state list,
        # keyed by the `Parameter` name.
        self.state_indices = {
            'kinetics': 0,
            'fbar': 1,
            'rbf_params': 2,
            'σ2_m': 3,
            'w': 4,
        }
        logistic_step_size = 0.001

        # The log-prob closures below read self.params lazily; it is only
        # assigned at the end of __init__.

        # Interaction weights
        def w_log_prob(all_states):
            def w_log_prob_fn(wstar, w_0star):
                new_prob = tf.reduce_sum(self.likelihood.genes(
                    all_states=all_states,
                    state_indices=self.state_indices,
                    w=wstar))
                new_prob += tf.reduce_sum(self.params.w.prior.log_prob(wstar))
                new_prob += tf.reduce_sum(self.params.w_0.prior.log_prob(w_0star))
                return tf.reduce_sum(new_prob)
            return w_log_prob_fn
        w = Parameter('w', tfd.Normal(f64(0), f64(2)),
                      1*tf.ones((self.num_genes, self.num_tfs), dtype='float64'),
                      step_size=0.04, hmc_log_prob=w_log_prob, requires_all_states=True)
        w_0 = Parameter('w_0', tfd.Normal(f64(0), f64(2)), tf.zeros(self.num_genes, dtype='float64'))

        # Latent function
        # Step sizes are sized from the instance's own grid length rather
        # than a module-level N_p, so the class works outside the notebook.
        fbar_kernel = FKernel(self.likelihood,
                              self.fbar_prior_params,
                              self.num_tfs, self.num_genes,
                              self.options.tf_mrna_present,
                              self.state_indices,
                              0.2*tf.ones(self.N_p, dtype='float64'))
        fbar = Parameter('fbar', self.fbar_prior, 0.5*tf.ones((self.num_tfs, self.N_p), dtype='float64'),
                         kernel=fbar_kernel, requires_all_states=False)

        # GP hyperparameters
        def rbf_params_log_prob(all_states):
            def rbf_params_log_prob_fn(vbar, l2bar):
                v = logit(vbar, nan_replace=self.params.V.prior.b)
                l2 = logit(l2bar, nan_replace=self.params.L.prior.b)

                # The fbar prior takes the untransformed values and applies
                # the same transform itself (see fbar_prior_params).
                new_prob = self.params.fbar.prior(all_states[self.state_indices['fbar']], vbar, l2bar)
                new_prob += self.params.V.prior.log_prob(v)
                new_prob += self.params.L.prior.log_prob(l2)
                return tf.reduce_sum(new_prob)
            return rbf_params_log_prob_fn

        # V carries the joint 'rbf_params' state (two values); L only
        # supplies a prior and has no value of its own.
        V = Parameter('rbf_params', LogisticNormal(f64(1e-4), f64(1+max(np.var(data.f_obs, axis=1))),allow_nan_stats=False),
                      [tf.constant(f64(0.8)), tf.constant(f64(0.98))], step_size=logistic_step_size,
                      fixed=not options.tf_mrna_present, hmc_log_prob=rbf_params_log_prob, requires_all_states=True)
        L = Parameter('L', LogisticNormal(f64(min_dist**2-0.2), f64(data.t[-1]**2), allow_nan_stats=False), None)

        self.t_dist = get_rbf_dist(data.τ, self.N_p)

        # Translation kinetic parameters
        def δbar_log_prob(all_states):
            def δbar_log_prob_fn(state):
                new_prob = tf.reduce_sum(self.likelihood.genes(
                    all_states=all_states,
                    state_indices=self.state_indices,
                    δbar=state
                ))
                new_prob += self.params.δbar.prior.log_prob(logit(state))
                return new_prob
            return δbar_log_prob_fn
        # One initial value per TF; a reshape of a scalar would only work
        # for a single TF.
        δbar = Parameter('δbar', LogisticNormal(0.1, 8), 0.6*tf.ones(self.num_tfs, dtype='float64'), step_size=logistic_step_size,
                         hmc_log_prob=δbar_log_prob, requires_all_states=True)

        # White noise for genes
        def σ2_m_log_prob(all_states):
            def σ2_m_log_prob_fn(σ2_mstar):
                new_prob = self.likelihood.genes(
                    all_states=all_states,
                    state_indices=self.state_indices,
                    σ2_m=σ2_mstar
                ) + self.params.σ2_m.prior.log_prob(logit(σ2_mstar))
                return tf.reduce_sum(new_prob)
            return σ2_m_log_prob_fn
        σ2_m = Parameter('σ2_m', LogisticNormal(f64(1e-5), f64(max(np.var(data.f_obs, axis=1)))), 1e-4*tf.ones(self.num_genes, dtype='float64'),
                         hmc_log_prob=σ2_m_log_prob, requires_all_states=True, step_size=logistic_step_size)

        # Transcription kinetic parameters
        def constrain_kbar(kbar, gene):
            '''Constrains a given row in kbar to [-10, 3] (in place).'''
            kbar[kbar < -10] = -10
            kbar[kbar > 3] = 3
            return kbar
        kbar_initial = 0.6*np.float64(np.c_[ # was -0.1
            np.ones(self.num_genes), # a_j
            np.ones(self.num_genes), # b_j
            np.ones(self.num_genes), # d_j
            np.ones(self.num_genes)  # s_j
        ])
        def kbar_log_prob(all_states):
            def kbar_log_prob_fn(kbar, δbar):
                k = logit(kbar)
                new_prob = self.likelihood.genes(
                    all_states=all_states,
                    state_indices=self.state_indices,
                    kbar=kbar,
                    δbar=δbar
                )
                new_prob += tf.reduce_sum(self.params.kbar.prior.log_prob(k))

                new_prob += self.params.δbar.prior.log_prob(logit(δbar))

                return tf.reduce_sum(new_prob)
            return kbar_log_prob_fn
        for j, k in enumerate(kbar_initial):
            kbar_initial[j] = constrain_kbar(k, j)
        kbar = Parameter('kbar', LogisticNormal(0.01, 8),
                         kbar_initial, constraint=constrain_kbar)

        # kbar and δbar are sampled together as one 'kinetics' block.
        kinetics = Parameter('kinetics', None,
                         [kbar.value, δbar.value],
                         hmc_log_prob=kbar_log_prob,
                         constraint=constrain_kbar, step_size=logistic_step_size, requires_all_states=True)

        if not options.preprocessing_variance:
            σ2_f = Parameter('σ2_f', tfd.InverseGamma(f64(0.01), f64(0.01)), 1e-4*np.ones(self.num_tfs), step_size=tf.constant(0.5, dtype='float64'))
            self.params = TupleParams_pre(fbar, δbar, kbar, σ2_m, w, w_0, L, V, kinetics, σ2_f)
        else:
            self.params = TupleParams(fbar, δbar, kbar, σ2_m, w, w_0, L, V, kinetics)

    def fbar_prior_params(self, vbar, l2bar):
        '''
        Return the mean vector and RBF covariance matrix of the fbar prior.

        vbar and l2bar are passed through `logit` (NaNs replaced by the
        `b` attribute of the V and L priors) before building the kernel
        from the time-distance matrix self.t_dist.
        '''
        v = logit(vbar, nan_replace=self.params.V.prior.b)
        l2 = logit(l2bar, nan_replace=self.params.L.prior.b)

        jitter = tf.linalg.diag(1e-10 * tf.ones(self.N_p, dtype='float64'))
        K = tfm.multiply(v, tfm.exp(-tfm.square(self.t_dist)/(2*l2))) + jitter
        m = tf.zeros((self.N_p), dtype='float64')
        return m, K

    def fbar_prior(self, fbar, v, l2):
        '''
        Log-density of fbar under the multivariate normal prior built by
        fbar_prior_params; v and l2 are forwarded to it unchanged.
        '''
        m, K = self.fbar_prior_params(v, l2)
        # Extra diagonal jitter on top of the 1e-10 added in
        # fbar_prior_params, applied before the Cholesky factorisation.
        jitter = tf.linalg.diag(1e-6 *tf.ones(self.N_p, dtype='float64'))

        return tfd.MultivariateNormalTriL(loc=m, scale_tril=tf.linalg.cholesky(K+jitter)).log_prob(fbar)

    def sample(self, T=2000, store_every=10, burn_in=1000, report_every=100, num_chains=4):
        '''
        Run the mixed kernel for T results after burn_in burn-in steps.

        The last draw of every active parameter becomes its new value, and
        the draws are appended to self.samples when sampling is resumed.

        store_every, report_every and num_chains are accepted for interface
        compatibility but are not used.

        Returns (samples, is_accepted) as produced by
        tfp.mcmc.sample_chain for this call only.
        '''
        print('----- Sampling Begins -----')

        f = IntProgress(description='Running', min=0, max=T) # instantiate the bar
        display(f)
        params = self.params

        # Order must match self.state_indices.
        active_params = [
            params.kinetics,
            params.fbar,
            params.V,
            params.σ2_m,
            #params.w,
        ]
        kernels = [param.kernel for param in active_params]
        send_all_states = [param.requires_all_states for param in active_params]

        current_state = [
            params.kinetics.value,
            params.fbar.value,
            [*params.V.value],
            params.σ2_m.value,
            #[params.w.value, params.w_0.value]
        ]
        mixed_kern = MixedKernel(kernels, send_all_states)

        def trace_fn(a, previous_kernel_results):
            return previous_kernel_results.is_accepted

        # Run the chain (with burn-in).
        @tf.function
        def run_chain():
            samples, is_accepted = tfp.mcmc.sample_chain(
                  num_results=T,
                  num_burnin_steps=burn_in,
                  current_state=current_state,
                  kernel=mixed_kern,
                  trace_fn=trace_fn)

            return samples, is_accepted

        samples, is_accepted = run_chain()

        add_to_previous = (self.samples is not None)
        for param in active_params:
            index = self.state_indices[param.name]
            param_samples = samples[index]
            if isinstance(param_samples, (list, tuple)):
                # Multi-part block (e.g. kinetics): keep the latest draw of
                # each part, and extend each part separately because the
                # parts need not share a shape.
                param.value = [part[-1] for part in param_samples]
                if add_to_previous:
                    self.samples[index] = [
                        tf.concat([old, new], axis=0)
                        for old, new in zip(self.samples[index], param_samples)
                    ]
            else:
                param.value = param_samples[-1]
                if add_to_previous:
                    self.samples[index] = tf.concat([self.samples[index], param_samples], axis=0)

        if not add_to_previous:
            self.samples = samples
        self.is_accepted = is_accepted
        f.value = T
        print('----- Finished -----')
        return samples, is_accepted

    @staticmethod
    def initialise_from_state(args, state):
        '''
        Build a model from constructor args and copy the stored samples
        (and acceptance rates, when present) from a previous state object.
        '''
        model = TranscriptionCustom(*args)
        # This class never sets acceptance_rates itself, so only copy the
        # attribute when the saved state actually carries it.
        if hasattr(state, 'acceptance_rates'):
            model.acceptance_rates = state.acceptance_rates
        model.samples = state.samples
        return model

    def predict_m(self, kbar, δbar, w, fbar, w_0):
        '''Delegate to the likelihood's predict_m with the given values.'''
        return self.likelihood.predict_m(kbar, δbar, w, fbar, w_0)

    def predict_m_with_current(self):
        '''Delegate to the likelihood's predict_m with the current parameter values.'''
        return self.likelihood.predict_m(self.params.kbar.value,
                                         self.params.δbar.value,
                                         self.params.w.value,
                                         self.params.fbar.value,
                                         self.params.w_0.value)


In [None]:
# Print arrays in full, with four decimal places for floats.
np.set_printoptions(
    threshold=np.inf,
    formatter={'float': "{0:0.4f}".format},
)

# Sampler settings (only T=200 / burn_in=0 are passed explicitly below).
T = 1000
store_every = 1
burn_in = 0
report_every = 20
num_chains = 4
tune_every = 50

model = TranscriptionCustom(data, opt)

samples, is_accepted = model.sample(burn_in=0, T=200)


In [None]:
# Mean acceptance rate of the first four state blocks, shown as percentages.
pcs = []
for i, param in enumerate(model.state_indices):
    print(i)
    if i == 4:
        break
    accepted = tf.cast(is_accepted[i], dtype=tf.float32)
    pcs.append(tf.reduce_mean(accepted).numpy())

percentages = [f'{100*pc:.02f}%' for pc in pcs]
display(pd.DataFrame([percentages], columns=list(model.state_indices)[:-1]))

In [None]:
# Pull each sampled block out of the stored chain state.
kinetics_samples = model.samples[model.state_indices['kinetics']]
kbar = kinetics_samples[0]
δbar = kinetics_samples[1]
fbar = model.samples[model.state_indices['fbar']]
print(fbar.shape, δbar.shape)
σ2_m = model.samples[model.state_indices['σ2_m']]
rbf_params = model.samples[model.state_indices['rbf_params']]

# w = model.samples[model.state_indices['w']][0]
# w_0 = model.samples[model.state_indices['w']][1]

# Placeholder weights while w is not being sampled.
w = [1*tf.ones((num_genes, 1), dtype='float64')] # TODO
w_0 = [tf.zeros(num_genes, dtype='float64')] # TODO

# Predictions from the 19 most recent draws, newest first.
m_preds = np.array([
    model.likelihood.predict_m(kbar[-i], δbar[-i], w[-1], fbar[-i][0], w_0[-1])  # todo w[-1]
    for i in range(1, 20)
])

# Map the raw samples through softplus / logit before reporting.
f_samples = np.log(1+np.exp(fbar))
δ_samples = logit(δbar)
k_samples = logit(kbar)
v_samples = logit(rbf_params[0])
l2_samples = logit(rbf_params[1])
rbf_params_samples = [v_samples, l2_samples]

plotters.generate_report(
    data, k_samples, δ_samples, f_samples,
    σ2_m, rbf_params_samples, m_preds,
    plot_barenco=True, gene_names=m_df.index,
)


In [None]:
# Plot the sampled latent function against the observed TF data.
# plot_barenco is only referenced by the commented-out block further down.
plot_barenco = True
# Trace of the first grid point of the first TF, drawn on its own implicit
# figure before the main figure is created.
plt.plot(fbar[:, 0, 0])
fig = plt.figure(figsize=(13, 7))
# Softplus of the last 100 fbar draws for the first TF.
f_samples = np.log(1+np.exp(fbar[-100:, 0,:]))
print(f_samples.shape)
# Error bars are only drawn when the model carries a σ2_f parameter.
if 'σ2_f' in model.params._fields:
    σ2_f = model.params.σ2_f.value
    plt.errorbar(τ[common_indices], f_observed[0], 2*np.sqrt(σ2_f[0]), 
                 fmt='none', capsize=5, color='blue')
else:
    σ2_f = σ2_f_pre
    
# NOTE(review): arviz.hpd / credible_interval are the names used by older
# arviz releases -- confirm against the installed version.
bounds = arviz.hpd(f_samples, credible_interval=0.95)
# Overlay the 19 most recent draws; only the first gets a legend label.
for i in range(1,20):
    f_i = f_samples[-i]
#     plt.plot(f_i)
#     f_i[0] = 0
    kwargs = {}
    if i == 1:
        kwargs = {'label':'Samples'}
    plt.plot(τ, f_i, c='cadetblue', alpha=0.5, **kwargs)

# if plot_barenco:
#     barenco_f, _ = scaled_barenco_data(np.mean(f_samples[-10:], axis=0))
#     plt.scatter(τ[common_indices], barenco_f, marker='x', s=60, linewidth=3, label='Barenco et al.')

plt.scatter(τ[common_indices], f_observed[0], marker='x', s=70, linewidth=3, label='Observed')

plt.fill_between(τ, bounds[:, 0], bounds[:, 1], color='grey', alpha=0.3, label='95% credibility interval')
# Tick the x axis at the observation times.
plt.xticks(t)
fig.axes[0].set_xticklabels(t)
plt.ylim((-1,5))
plt.xlabel('Time (h)')
plt.legend();

In [None]:

# Grouped bar chart of the kinetic rates per gene, averaged over the last
# ten kinetics draws.
plt.figure()
num_genes = kbar.shape[1]
k_latest = np.mean(logit(kbar[-10:]), axis=0)
print(k_latest)
# kbar columns are initialised in the order a_j, b_j, d_j, s_j, so columns
# 1, 2, 3 are taken as basal rate, decay rate and sensitivity.
B = k_latest[:,1]
D = k_latest[:,2]
S = k_latest[:,3]

# The legend labels now follow the variables: D is the decay rate and S the
# sensitivity (the two labels were previously swapped).
positions = np.arange(num_genes)
plt.bar(positions-0.2, B, width=0.2, tick_label=m_df.index, label='Basal rate')
plt.bar(positions, D, width=0.2, tick_label=m_df.index, label='Decay rate')
plt.bar(positions+0.2, S, width=0.2, tick_label=m_df.index, label='Sensitivity')
plt.yscale('log')
plt.title('Mechanistic Parameters')
plt.legend()


In [None]:
# Trace plots of the interaction weights and biases, one line per gene.
# NOTE(review): the results cell above sets `w` and `w_0` to one-element
# Python lists (marked TODO), which do not support the `[:, j]` indexing
# used here -- presumably this cell expects the sampled w / w_0 arrays.
plt.figure()

for j in range(num_genes):
    plt.plot(w[:, j], label=m_df.index[j])
plt.legend()
plt.title('Interaction weights')

plt.figure()
for j in range(num_genes):
    plt.plot(w_0[:,j])
plt.title('Interaction bias')


In [None]:
# One trace plot of the sampled noise variance per gene.
plt.figure(figsize=(12, 10))
plt.title('Noise variances')
for j in range(num_genes):
    # Panels are laid out on a num_genes x (num_genes-2) grid.
    ax = plt.subplot(num_genes, num_genes-2, j+1)
    plt.title(m_df.index[j])
    plt.plot(σ2_m[:,j])

plt.tight_layout()

## Convergence Plots

In [None]:
# Stack the samples of every chain in `job` into one array per variable and
# convert each to arviz inference data.
# NOTE(review): `job` is not defined anywhere in the visible notebook;
# presumably it is a list of per-chain results exposing `acceptance_rates`
# and `samples[key].get()` -- confirm before running this cell.
keys = job[0].acceptance_rates.keys()

# Start from an empty (0, T, ...) array per key and append one chain at a time.
variables = {key : np.empty((0, T, *job[0].samples[key].get().shape[1:])) for key in keys}

for res in job:
    for key in keys:
        variables[key] = np.append(variables[key], np.expand_dims(res.samples[key].get(), 0), axis=0)

# Last 100 samples of 'L', one line per chain.
plt.plot(variables['L'][:,-100:].T)

mixes = {key: arviz.convert_to_inference_data(variables[key]) for key in keys}

#### Rhat
Rhat is the ratio of the posterior variance to the within-chain variance. If the ratio exceeds 1.1, we consider the chains not to have mixed well. As the between-chain variance tends to the within-chain variance, Rhat tends to 1.

In [None]:
Rhat = arviz.rhat(mixes['fbar'])

# Mean R-hat per variable, plus a row flagging which fall below 1.1.
Rhats = np.array([arviz.rhat(mixes[key]).x.values.mean() for key in keys])
below_threshold = Rhats < 1.1

rhat_df = pd.DataFrame([list(Rhats), list(below_threshold)], columns=keys)

display(rhat_df)

#### Rank plots

Rank plots are histograms of the ranked posterior draws (ranked over all
    chains) plotted separately for each chain.
    If all of the chains are targeting the same posterior, we expect the ranks in each chain to be
    uniform, whereas if one chain has a different location or scale parameter, this will be
    reflected in the deviation from uniformity. If rank plots of all chains look similar, this
    indicates good mixing of the chains.

Rank-normalization, folding, and localization: An improved R-hat
    for assessing convergence of MCMC. arXiv preprint https://arxiv.org/abs/1903.08008

In [None]:
# NOTE(review): `L_mix` is not defined in the visible notebook; presumably
# the inference data for 'L' (cf. mixes['L']) -- confirm.
arviz.plot_rank(L_mix)

#### Effective sample sizes

Plot quantile, local or evolution of effective sample sizes (ESS).

In [None]:
# NOTE(review): `L_mix` is not defined in the visible notebook; presumably
# the inference data for 'L' (cf. mixes['L']) -- confirm.
arviz.plot_ess(L_mix)

#### Monte-Carlo Standard Error

In [None]:
# NOTE(review): `L_mix` is not defined in the visible notebook; presumably
# the inference data for 'L' (cf. mixes['L']) -- confirm.
arviz.plot_mcse(L_mix)


#### Parallel Plot
Plot parallel coordinates plot showing posterior points with and without divergences.

Described by https://arxiv.org/abs/1709.01449, suggested by Ari Hartikainen


In [None]:
# NOTE(review): `azl` is not defined in the visible notebook -- confirm
# which inference data object this is meant to plot.
arviz.plot_parallel(azl)


The step size acts as the standard deviation of the proposal: if it is too small, the chain takes a long time to reach high-density areas; if it is too large, many of the proposed samples are rejected.

In [None]:
# Debug cell: evaluate the hyperparameter priors and the fbar prior at
# hand-picked values on a freshly constructed model.
model = TranscriptionCustom(data, opt)
# Hard-coded fbar row vector used as the test point for the fbar prior.
fbar = tf.constant([[0.54880, 0.56062, 0.54973, 0.53923, 0.53923, 0.53892, 0.54077, 0.54334, 0.54586, 0.54953,
  0.55446, 0.55897, 0.56173, 0.56379, 0.56612, 0.56790, 0.56781, 0.56570, 0.56276, 0.5914,
  0.55507, 0.55068, 0.533, 0.54217, 0.53848, 0.53549, 0.53362, 0.53301, 0.53338, 0.53438,
  0.53593, 0.53800, 0.54037, 0.54267, 0.54481, 0.54645, 0.54760, 0.54818, 0.54817, 0.54774,
  0.54694, 0.54583, 0.5455, 0.54336, 0.54219, 0.54136, 0.54083, 0.54063, 0.54097, 0.54148,
  0.54237, 0.54342, 0.54465, 0.510, 0.54804, 0.557, 0.5285, 0.5545, 0.55770, 0.5932,
  0.56021, 0.56030, 0.55959, 0.534, 0.5654, 0.5417, 0.5162]], dtype='float64')
print(fbar.shape)
#  1.4018160215788704 4.4271686734681195
# Log-density of the V and L priors at fixed values.
print(model.params.V.prior.log_prob(1.3818811998078957))
print(model.params.L.prior.log_prob(3.9525443383480283))

# fbar = model.params.fbar.value
# Log-density of the fbar prior with hyperparameter arguments 0.6 and 0.99.
print(model.params.fbar.prior(fbar, f64(0.6), f64(0.99)))
#  3.9525443383480283