## Covid dataset

In [None]:
from IPython.display import display
import matplotlib.pyplot as plt

from reggae.data_loaders import load_covid, DataHolder, scaled_barenco_data
from reggae.mcmc import create_chains, MetropolisHastings, Parameter
from reggae.utilities import discretise, logit, LogisticNormal
from reggae.plot import plotters
from reggae.models import TranscriptionLikelihood, Options, TranscriptionMixedSampler
from reggae.models.results import GenericResults

import tensorflow as tf
from tensorflow import math as tfm

import numpy as np
import pandas as pd
import arviz
from ipywidgets import IntProgress

np.set_printoptions(formatter={'float': lambda x: "{0:0.5f}".format(x)})
plt.style.use('ggplot')
%matplotlib inline
f64 = np.float64
np.set_printoptions(threshold=np.inf)
np.set_printoptions(formatter={'float': lambda x: "{0:0.4f}".format(x)})


As reported by https://jvi.asm.org/content/84/17/8369:

- Transcription factors (TFs): IRF1, IRF7, STAT1, and STAT2
- Genes: OAS1, RIG-I, Mx-1, ISG15, ISG20, and CXCL10
- RIG-I is ENSMPUG00000004319

Relevant GO terms and gene sets: GO:0035457, GO:0035458, GO:0035455, GO:0035456, GO:0034340, and HALLMARK_INTERFERON_ALPHA_RESPONSE.

A subset of genes involved in the recruitment and activation of the adaptive immune response includes: IL6, CXCL11, IFNG, IL7, CXCL9, and CXCL10.


In [None]:
# Load the Covid dataset: mRNA observations, TF observations and the
# observation times. Each observation entry is a pair that is unpacked into
# a frame-like object (m_df / f_df) and the observation array itself.
m_observed, f_observed, t = load_covid()
m_df, m_observed = m_observed
f_df, f_observed = f_observed

# NOTE(review): an earlier comment here claimed that m_observed has shape
# (replicates, genes, times), but the indexing below reads the gene count from
# axis 0 and the number of time points from axis 1 — confirm which layout
# load_covid() actually returns.
num_genes = m_observed.shape[0]
N_m = m_observed.shape[1]

# τ is the discretised time grid; common_indices selects the entries of τ
# that correspond to the observation times (used as τ[common_indices] when plotting).
τ, common_indices = discretise(t)
N_p = τ.shape[0]

time = (t, τ, tf.constant(common_indices))
data = DataHolder((m_observed, f_observed), None, time)

print(m_observed.shape)

In [None]:
# Sampler bookkeeping constants. T is also used further down to size the
# per-chain arrays in the convergence section.
T, burn_in = 1000, 0
store_every, report_every, tune_every = 1, 20, 50
num_chains = 4

# Model options, gathered in one mapping so they can be read at a glance.
model_settings = {
    'preprocessing_variance': False,
    'tf_mrna_present': True,
    'delays': False,
    'latent_function_metropolis': True,
    'kernel': 'rbf',
}
opt = Options(**model_settings)

model = TranscriptionMixedSampler(data, opt)


In [None]:
# Run the sampler; the result unpacks into the drawn samples and the
# acceptance flags that are summarised in the next cell.
# NOTE(review): T and burn_in are hard-coded here instead of using the values
# from the configuration cell (T = 1000, burn_in = 0) — confirm this is intended.
sampling_output = model.sample(T=300, burn_in=0)
samples, is_accepted = sampling_output


In [None]:
# Mean acceptance rate of every sampled state, shown as a one-row table of
# percentages. `is_accepted` is indexed by the position of the state in
# model.state_indices, exactly as before; the per-iteration debug print of the
# loop counter has been dropped.
pcs = [
    tf.reduce_mean(tf.cast(is_accepted[i], dtype=tf.float32)).numpy()
    for i, _ in enumerate(model.state_indices)
]

display(pd.DataFrame([[f'{100*pc:.02f}%' for pc in pcs]], columns=list(model.state_indices)))

In [None]:
# Collect the sampled chains for every state of the model.
σ2_f = None  # not read from the sampler here (see the commented-out line below)
kbar = model.samples[model.state_indices['kinetics']][0]
k_fbar = model.samples[model.state_indices['kinetics']][1]
fbar = model.samples[model.state_indices['fbar']]
σ2_m = model.samples[model.state_indices['σ2_m']]
# σ2_f = model.samples[model.state_indices['σ2_f']]
kernel_params = model.samples[model.state_indices['kernel_params']]
w = model.samples[model.state_indices['weights']][0]
w_0 = model.samples[model.state_indices['weights']][1]

# w = [1*tf.ones((num_genes, 1), dtype='float64')] # TODO
# w_0 = [tf.zeros(num_genes, dtype='float64')] # TODO
print(kbar[-1])

# Predictions from the 19 most recent samples. The weights are held at their
# latest sample for every prediction (see the todo).
m_preds = np.array([
    model.likelihood.predict_m(kbar[-i], k_fbar[-i], w[-1], fbar[-i], w_0[-1])  # todo w[-1]
    for i in range(1, 20)
])

# Softplus, log(1 + exp(x)). np.logaddexp(0, x) computes the same quantity
# without overflowing inside exp() when x is large.
f_samples = np.logaddexp(0, np.asarray(fbar))
k_f_samples = logit(k_fbar).numpy()
k_samples = logit(kbar).numpy()
rbf_params_samples = [logit(kernel_params[0]), logit(kernel_params[1])]
weights_samples = [logit(w), logit(w_0)]

plotters.generate_report(data, k_samples, k_f_samples, f_samples,
                         σ2_m, σ2_f, rbf_params_samples, m_preds, weights_samples,
                         plot_barenco=True, gene_names=m_df.index, num_hpd=20, replicate=2)


In [None]:
plot_barenco = True
# Trace of a single latent-function entry over all samples. It is drawn before
# the main figure is created, so it ends up on its own figure.
plt.plot(fbar[:, 0, 0])
fig = plt.figure(figsize=(13, 7))
# Softplus of the last 100 samples; np.logaddexp(0, x) equals log(1 + exp(x))
# but does not overflow for large x.
f_samples = np.logaddexp(0, np.asarray(fbar[-100:, 0, :]))
print(f_samples.shape)
if 'σ2_f' in model.params._fields:
    σ2_f = model.params.σ2_f.value
    plt.errorbar(τ[common_indices], f_observed[0], 2*np.sqrt(σ2_f[0]),
                 fmt='none', capsize=5, color='blue')
else:
    # The previous fallback, `σ2_f_pre`, is not defined anywhere in the visible
    # notebook and raised NameError. Without a variance no error bars are drawn.
    σ2_f = None

bounds = arviz.hpd(f_samples, credible_interval=0.95)
for i in range(1, 20):
    # Label only the first line so the legend gets a single 'Samples' entry.
    kwargs = {'label': 'Samples'} if i == 1 else {}
    plt.plot(τ, f_samples[-i], c='cadetblue', alpha=0.5, **kwargs)

# if plot_barenco:
#     barenco_f, _ = scaled_barenco_data(np.mean(f_samples[-10:], axis=0))
#     plt.scatter(τ[common_indices], barenco_f, marker='x', s=60, linewidth=3, label='Barenco et al.')

plt.scatter(τ[common_indices], f_observed[0], marker='x', s=70, linewidth=3, label='Observed')

plt.fill_between(τ, bounds[:, 0], bounds[:, 1], color='grey', alpha=0.3, label='95% credibility interval')
plt.xticks(t)
fig.axes[0].set_xticklabels(t)
plt.ylim((-1,5))
plt.xlabel('Time (h)')
plt.legend();

In [None]:

# Grouped bar chart of the kinetic parameters per gene, averaged over the
# last 10 samples.
plt.figure()
num_genes = kbar.shape[1]
k_latest = np.mean(logit(kbar[-10:]), axis=0)
print(k_latest)
# NOTE(review): columns 1, 2, 3 are taken to be basal rate, decay rate and
# sensitivity, as the names B / D / S suggest — confirm against the layout of
# the kinetics array.
B = k_latest[:, 1]
D = k_latest[:, 2]
S = k_latest[:, 3]

# The legend labels now match the plotted series: previously D was labelled
# 'Sensitivity' and S was labelled 'Decay rate'.
positions = np.arange(num_genes)
plt.bar(positions - 0.2, B, width=0.2, tick_label=m_df.index, label='Basal rate')
plt.bar(positions, D, width=0.2, tick_label=m_df.index, label='Decay rate')
plt.bar(positions + 0.2, S, width=0.2, tick_label=m_df.index, label='Sensitivity')
plt.yscale('log')
plt.title('Mechanistic Parameters')
plt.legend();


In [None]:
# One trace plot of the sampled noise variance per gene.
noise_fig = plt.figure(figsize=(12, 10))
# A figure-level title: plt.title() before any subplot exists only titles a
# throw-away axes, not the figure.
noise_fig.suptitle('Noise variances')

# Keep the previous column count (num_genes - 2) but never fewer than one
# column, and use only as many rows as the genes need. The old grid had
# num_genes rows, which left most of the figure empty and failed outright
# for num_genes <= 2.
n_cols = max(num_genes - 2, 1)
n_rows = -(-num_genes // n_cols)  # ceil(num_genes / n_cols)
for j in range(num_genes):
    ax = noise_fig.add_subplot(n_rows, n_cols, j + 1)
    ax.set_title(m_df.index[j])
    ax.plot(σ2_m[:, j])

# Leave room at the top for the figure title.
noise_fig.tight_layout(rect=[0, 0, 1, 0.96])

## Convergence Plots

In [None]:
# NOTE(review): `job` is not defined anywhere in the visible notebook; this
# section needs it (a sequence of per-chain results exposing
# `.acceptance_rates` and `.samples[key].get()`) to exist before it can run.
keys = job[0].acceptance_rates.keys()

# Stack every variable's samples across the entries of `job` into one array of
# shape (len(job), T, ...). All chains are joined with a single concatenate
# instead of growing the array with np.append on every iteration, which copies
# the whole array each time. The empty float64 seed is kept so that the
# resulting dtype and the check that each chain has T samples are unchanged.
variables = {}
for key in keys:
    sample_shape = job[0].samples[key].get().shape[1:]
    chains = [np.expand_dims(res.samples[key].get(), 0) for res in job]
    variables[key] = np.concatenate([np.empty((0, T, *sample_shape)), *chains], axis=0)

# Last 100 samples of 'L', one line per chain.
plt.plot(variables['L'][:,-100:].T);

mixes = {key: arviz.convert_to_inference_data(variables[key]) for key in keys}

#### Rhat
Rhat is the square root of the ratio of the estimated posterior variance (pooled over all chains) to the within-chain variance. If it exceeds 1.1, we consider the chains not to have mixed well. As the between-chain variance approaches the within-chain variance, Rhat tends to 1.

In [None]:
# R-hat of the latent function on its own.
Rhat = arviz.rhat(mixes['fbar'])

# Mean R-hat of every variable, in the order of `keys`.
mean_rhats = []
for key in keys:
    mean_rhats.append(np.mean(arviz.rhat(mixes[key]).x.values))
Rhats = np.array(mean_rhats)

# First row: the mean R-hat values; second row: whether each is below 1.1.
rhat_df = pd.DataFrame([list(Rhats), list(Rhats < 1.1)], columns=keys)

display(rhat_df)

#### Rank plots

Rank plots are histograms of the ranked posterior draws (ranked over all chains), plotted separately for each chain. If all of the chains are targeting the same posterior, we expect the ranks in each chain to be uniform, whereas if one chain has a different location or scale parameter, this will be reflected in a deviation from uniformity. If the rank plots of all chains look similar, this indicates good mixing of the chains.

Reference: "Rank-normalization, folding, and localization: An improved R-hat for assessing convergence of MCMC", arXiv preprint, https://arxiv.org/abs/1903.08008

In [None]:
# `L_mix` is never defined in this notebook; the inference data for 'L' is
# stored in the `mixes` dict built in the convergence section.
arviz.plot_rank(mixes['L'])

#### Effective sample sizes

Plot quantile, local or evolution of effective sample sizes (ESS).

In [None]:
# `L_mix` is never defined in this notebook; the inference data for 'L' is
# stored in the `mixes` dict built in the convergence section.
arviz.plot_ess(mixes['L'])

#### Monte-Carlo Standard Error

In [None]:
# `L_mix` is never defined in this notebook; the inference data for 'L' is
# stored in the `mixes` dict built in the convergence section.
arviz.plot_mcse(mixes['L'])


#### Parallel Plot
Plot parallel coordinates plot showing posterior points with and without divergences.

Described by https://arxiv.org/abs/1709.01449, suggested by Ari Hartikainen


In [None]:
# NOTE(review): `azl` is not defined anywhere in the visible notebook, so this
# cell raises NameError on a fresh kernel unless it is created elsewhere —
# confirm which inference data object was meant here.
arviz.plot_parallel(azl)


The step size is the standard deviation of the proposal distribution. If it is too small, the chain takes a long time to reach high-density areas; if it is too large, many of the proposed samples are rejected.

In [None]:
# Scratch check of prior log-probabilities for hand-picked values.
# NOTE(review): `TranscriptionCustom` is not imported in the visible notebook,
# so this cell raises NameError on a fresh kernel — confirm the intended class.
# NOTE(review): this cell rebinds the notebook-level names `model` and `fbar`,
# so re-running earlier cells afterwards will see these replacements.
model = TranscriptionCustom(data, opt)
# Hard-coded latent-function values (one row of 67 entries) used as the test input.
fbar = tf.constant([[0.54880, 0.56062, 0.54973, 0.53923, 0.53923, 0.53892, 0.54077, 0.54334, 0.54586, 0.54953,
  0.55446, 0.55897, 0.56173, 0.56379, 0.56612, 0.56790, 0.56781, 0.56570, 0.56276, 0.5914,
  0.55507, 0.55068, 0.533, 0.54217, 0.53848, 0.53549, 0.53362, 0.53301, 0.53338, 0.53438,
  0.53593, 0.53800, 0.54037, 0.54267, 0.54481, 0.54645, 0.54760, 0.54818, 0.54817, 0.54774,
  0.54694, 0.54583, 0.5455, 0.54336, 0.54219, 0.54136, 0.54083, 0.54063, 0.54097, 0.54148,
  0.54237, 0.54342, 0.54465, 0.510, 0.54804, 0.557, 0.5285, 0.5545, 0.55770, 0.5932,
  0.56021, 0.56030, 0.55959, 0.534, 0.5654, 0.5417, 0.5162]], dtype='float64')
print(fbar.shape)
#  1.4018160215788704 4.4271686734681195
# Log-probability of one value each under the priors of the V and L parameters.
print(model.params.V.prior.log_prob(1.3818811998078957))
print(model.params.L.prior.log_prob(3.9525443383480283))

# fbar = model.params.fbar.value
# Evaluate the fbar prior on the hard-coded values with the two extra arguments
# 0.6 and 0.99 (presumably kernel hyperparameters — TODO confirm).
print(model.params.fbar.prior(fbar, f64(0.6), f64(0.99)))
#  3.9525443383480283