## p53 network - REGGaE

In [None]:
## TensorFlow SETUP
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')
# tf.debugging.set_log_device_placement(True)
from tensorflow import math as tfm

from timeit import default_timer as timer
from IPython.display import display
import matplotlib.pyplot as plt

from reggae.data_loaders import load_barenco_puma, DataHolder, barenco_params
from reggae.mcmc import create_chains, MetropolisHastings, Parameter
from reggae.utilities import discretise, logit, LogisticNormal, inverse_positivity
from reggae.plot import plotters
from reggae.models import TranscriptionLikelihood, Options, TranscriptionMixedSampler
from reggae.models.results import GenericResults

import numpy as np
import pandas as pd
import arviz
from ipywidgets import IntProgress
from IPython.display import HTML
plt.rcParams['animation.ffmpeg_path'] = 'C:\\Users\\Jacob\\Documents\\ffmpeg-static\\bin\\ffmpeg.exe'

np.set_printoptions(formatter={'float': lambda x: "{0:0.5f}".format(x)})
plt.style.use('ggplot')
%matplotlib inline
f64 = np.float64
np.set_printoptions(threshold=np.inf)
np.set_printoptions(formatter={'float': lambda x: "{0:0.4f}".format(x)})


In [None]:
# Load the Barenco p53 data. m_observed / f_observed each come back as a
# (frame, array) pair; the frames' .index supplies gene / TF names for plots.
m_observed, f_observed, σ2_m_pre, σ2_f_pre, t = load_barenco_puma()

# Alternative dataset:
# m_observed, f_observed, t = load_3day_dros()

# Replicate index used by the single-replicate plots further down.
replicate = 0

m_df, m_observed = m_observed
f_df, f_observed = f_observed
# Shape of m_observed = (replicates, genes, times)
σ2_m_pre = f64(σ2_m_pre)
σ2_f_pre = f64(σ2_f_pre)

num_genes = m_observed.shape[1]
# Discretise the observation times t onto a finer grid τ; common_indices
# index the observed time points within τ.
τ, common_indices = discretise(t, num_disc=13)
N_p = τ.shape[0]
N_m = m_observed.shape[1]

noise_data = (σ2_m_pre, σ2_f_pre)
time = (t, τ, tf.constant(common_indices))
data = DataHolder((m_observed, f_observed), noise_data, time)


In [None]:
# Model options for this run; `name` (used to save/load the sampler) records
# whether TF mRNA is modelled.
option_kwargs = dict(
    preprocessing_variance=True,
    tf_mrna_present=False,
    delays=False,
    weights=False,
    translation=False,
    initial_conditions=False,
    initial_step_sizes={'logistic': 0.00001, 'latents': 10},
    kernel='rbf',
)
opt = Options(**option_kwargs)
if opt.tf_mrna_present:
    name = 'p53'
else:
    name = 'p53-notf'
model = TranscriptionMixedSampler(data, opt)

In [None]:
# Initialise from saved model: rebuild the sampler stored under `name`
# (this replaces any `model` already in the namespace) and take its
# is_accepted attribute.
loaded_model = TranscriptionMixedSampler.load(name, [data, opt])
model = loaded_model
is_accepted = loaded_model.is_accepted

In [None]:
# Run 500 sampling iterations (no burn-in) and report the wall-clock time.
start = timer()
samples, is_accepted = model.sample(T=500, burn_in=0)
end = timer()
print('Time taken: {:.04f}s'.format(end - start))


In [None]:
# Save the sampler under `name`, then echo the name that was used.
model.save(name)
print(f'{name}')

In [None]:
# Launch TensorBoard on logs/reggae.
# Before launching, first clear the `.tensorboard-info` folder in the user's
# temp directory (originally C:\Users\Jacob\AppData\Local\Temp\.tensorboard-info).
%load_ext tensorboard
%tensorboard --logdir logs/reggae

In [None]:
# Profile a single sampling iteration. The outputs go into separate names so
# they do not overwrite `samples` / `is_accepted` from the main 500-iteration
# run, which the acceptance-rate summary cell reads.
profile_samples, profile_is_accepted = model.sample(T=1, burn_in=0, profile=True)

<table>
    <tr><th>Processor</th><th># Iterations</th><th>Time</th><th>Note</th></tr>
    <tr><td>CPU</td><td>20</td><td>54.7s</td><td></td></tr>
    <tr><td>CPU</td><td>20</td><td>38.4s</td><td>Merged weight and kinetics</td></tr>
    <tr><td>CPU</td><td>200</td><td>436.2s</td><td></td></tr>
    <tr><td>CPU</td><td>200</td><td>396.1s</td><td>No prob update</td></tr>
    <tr><td>CPU</td><td>800</td><td>901.9s</td><td>No prob update, merged w,k</td></tr>
    <tr><td>CPU</td><td>1000</td><td>900s</td><td>no initial conditions, no protein</td></tr>
    <tr><td>GPU</td><td>2</td><td>40.8s</td><td></td></tr>
</table>

In [None]:
# Fraction of accepted proposals for each sampled parameter, shown as percentages.
pcs = [
    tf.reduce_mean(tf.cast(is_accepted[i], dtype=tf.float32)).numpy()
    for i, _ in enumerate(model.state_indices)
]
acceptance_table = pd.DataFrame([[f'{100*pc:.02f}%' for pc in pcs]],
                                columns=list(model.state_indices))
display(acceptance_table)

σ2_f = None
# Collected chains; k_latest averages the last 100 draws of k over the sample axis.
results = model.results()
k_latest = np.mean(results.k[-100:], axis=0)


In [None]:
### BARENCO
# Barenco et al. reference kinetics, rescaled so that each column's mean
# matches the mean of the model's averaged estimates (k_latest).
barenco = barenco_params()
barenco_scaled = barenco / np.mean(barenco, axis=0)
if not opt.tf_mrna_present:
    k_true = barenco_scaled * np.mean(k_latest, axis=0)
else:
    # With TF mRNA present, the first column of k_latest is left out of the
    # scaling and k_true gets a leading column of zeros.
    k_true = barenco_scaled * np.mean(k_latest[:, 1:], axis=0)
    k_true = np.c_[np.zeros(num_genes), k_true]

In [None]:
### COMPARISON TO MH
# Fixed kinetic estimates (presumably from a Metropolis-Hastings run), rescaled
# like the Barenco values. NOTE: this overwrites `k_true`; the plot label set
# in PlotOptions still reads 'Barenco et al.' — TODO(review) confirm which
# reference each figure should use.
mh_k = np.array([
    [0.28157, 2.48264, 9.05267],
    [0.07394, 2.64863, 7.10769],
    [0.54263, 8.78634, 19.73215],
    [0.26399, 8.04371, 7.49752],
    [0.23321, 3.66627, 11.41177],
])
k_true = mh_k / mh_k.mean(axis=0) * np.mean(k_latest, axis=0)


In [None]:
# Plot configuration shared by the summary figures below.
plot_opt = plotters.PlotOptions(
    num_plot_genes=10,
    num_plot_tfs=10,
    gene_names=m_df.index,
    tf_names=f_df.index,
    for_report=True,
    protein_present=False,
    tf_present=False,
    kernel_names=model.kernel_selector.names(),
    true_label='Barenco et al.',
    model_label='REGGaE',
    num_kinetic_avg=100,
)
plotter = plotters.Plotter(data, plot_opt)


In [None]:
# Draw 20 latent samples from `results` for the summary plot.
num_latent_draws = 20
m_preds = model.sample_latents(results, num_latent_draws)

In [None]:
# Summary figure for the replicate selected by `replicate` in the data cell,
# compared against the reference kinetics in k_true.
plotter.summary(results, m_preds, replicate=replicate, true_k=k_true)

In [None]:
# Trace of the last 2000 draws of k for gene 0, parameter 1.
print(results.k.shape)
k_trace = results.k[-2000:, 0, 1]
plotter.plot_convergence(k_trace, lims=(0, 3.5), fig_height=4.5, fig_width=6.5)

In [None]:
# What-if check: override one kinetic parameter of the final sample, predict
# mRNA with it and compare against the observations.
# Copy first: if results.kbar is a NumPy array, results.kbar[-1] is a view and
# assigning into it would silently alter the stored chain used by later plots.
gene_index = 3
kbar = np.array(results.kbar[-1], copy=True)
print(kbar[gene_index, 0])
kbar[gene_index, 0] = 0.57
m_pred = model.likelihood.predict_m(kbar, None, results.wbar[-1], results.fbar[-1],
                                   results.w_0bar[-1], None)
# NOTE(review): assumes m_pred is laid out like m_observed, i.e. (replicates, genes, times).
plt.plot(τ, m_pred[replicate, gene_index])
plt.scatter(t, m_observed[replicate, gene_index])

In [None]:
# Plot sampled kinetics (k, k_f) against k_true with HPD offsets.
# NOTE(review): `hpds` is only created further down (in the normalised
# kinetics bar-chart cell), so on a fresh Restart & Run All this cell raises
# NameError — run that cell first. TODO confirm those offsets are what
# plot_kinetics expects for `true_hpds`.
print(hpds[:, 2].swapaxes(0,1).shape)

plotter.plot_kinetics(results.k, results.k_f, true_k=k_true, true_hpds=hpds);

In [None]:
# Convergence diagnostics for the sampled results (see Plotter.convergence_summary).
plotter.convergence_summary(results)

In [None]:
# Latent TF samples for the replicate selected by `replicate` in the data cell;
# sample_gap=5 thins the plotted draws (presumably every 5th sample).
plotter.plot_tfs(results.f, replicate=replicate, sample_gap=5)

In [None]:
# Render the animation returned by plotter.anim_latent inline as HTML.
HTML(plotter.anim_latent(results))

In [None]:
from sklearn import preprocessing
# Per-gene kinetic parameters: mean of the last 50 draws, L2-normalised per
# gene, plotted next to the rescaled Barenco et al. values with 95% HPD bars.
k = results.k
num_avg_draws = 50
k_mean = np.mean(k[-num_avg_draws:], axis=0)
print(k_mean)
# preprocessing.normalize scales each row (gene) to unit L2 norm.
k_latest = preprocessing.normalize(k_mean)
print(k_latest)
num_genes = results.k.shape[1]
plot_labels = ['Initial Conditions', 'Basal rates', 'Decay rates', 'Sensitivities']

# 95% HPD bounds of every parameter of every gene, on the raw (un-normalised) scale.
hpd_bounds = np.array([
    arviz.hpd(k[-num_avg_draws:, j, :], credible_interval=0.95)
    for j in range(num_genes)
])
# Error bars must be on the same scale as the normalised bars: take the
# distance of each bound from the raw mean and divide by the per-gene L2 norm
# that normalize() applied (zero norms are treated as 1, as sklearn does).
# Subtracting the normalised mean from raw-scale bounds would mix two scales.
row_norms = np.linalg.norm(k_mean, axis=1)
row_norms[row_norms == 0] = 1.0
hpds = np.abs(hpd_bounds - np.expand_dims(k_mean, 2)) / row_norms[:, None, None]

width = 18 if num_genes > 10 else 10
plt.figure(figsize=(width, 16))
comparison_label = 'Barenco et al.'
# Rescale Barenco et al. so each column's mean matches the normalised model means.
true_data = barenco / np.mean(barenco, axis=0) * np.mean(k_latest, axis=0)
# Drop 'Initial Conditions' so labels line up with the plotted columns.
# NOTE(review): assumes k has no initial-condition column here — confirm.
plot_labels = plot_labels[1:]
gene_positions = np.arange(num_genes)
# Loop variable is not named `k`, which holds the sample trace above.
for param_idx in range(k_latest.shape[1]):
    plt.subplot(4, 2, param_idx + 1)
    plt.bar(gene_positions - 0.2, k_latest[:, param_idx], width=0.4,
            tick_label=m_df.index, label='Model')
    plt.bar(gene_positions + 0.2, true_data[:, param_idx], width=0.4,
            color='blue', align='center', label=comparison_label)
    plt.title(plot_labels[param_idx])
    plt.errorbar(gene_positions - 0.2, k_latest[:, param_idx],
                 hpds[:, param_idx].swapaxes(0, 1),
                 fmt='none', capsize=5, color='black')
    plt.legend()
    plt.xticks(rotation=70)
plt.tight_layout(h_pad=2.0)


In [None]:
# Draw 20 protein samples and plot the slice p_samples[:, 0].
p_samples = model.sample_proteins(results, 20)
print(p_samples.shape)
first_slice = p_samples[:, 0]
plotter.plot_samples(first_slice, [''], 4, color='orangered')

## Run only the latent sampler

In [None]:
from reggae.mcmc.kernels import MixedKernel, LatentKernel
import tensorflow_probability as tfp

# Run only the latent sampler: the mixed kernel is built with
# skip=[True, False, True], presumably so only the LatentKernel proposes updates.
iters = 50000
all_states = [p.value for p in model.active_params]

latent_kern = LatentKernel(model.data, model.options, model.likelihood, model.kernel_selector,
                           model.state_indices, 2*tf.ones(N_p, dtype='float64'))
kernels = [model.active_params[0].kernel, latent_kern, model.active_params[2].kernel]
mixed_kern = MixedKernel(kernels, [True, False, False], iters, skip=[True, False, True])


def trace_fn(_current_state, previous_kernel_results):
    """Trace only the is_accepted field of the kernel results."""
    return previous_kernel_results.is_accepted


@tf.function
def run_chain():
    """Draw `iters` samples from `mixed_kern` starting at `all_states` (no burn-in)."""
    samples, is_accepted = tfp.mcmc.sample_chain(
        num_results=iters,
        num_burnin_steps=0,
        current_state=all_states,
        kernel=mixed_kern,
        trace_fn=trace_fn)
    return samples, is_accepted


chain_result = run_chain();



In [None]:
# Compare the last sample of state 1 of the latent-only chain (mapped through
# inverse_positivity) with the observed TF data at the observed time points.
f = chain_result[0][1][0]
print(f.shape)
print(data.f_obs[0, 0])
plt.scatter(τ[common_indices], data.f_obs[0, 0])
plt.plot(τ, inverse_positivity(f[-1, 0, 0]))

In [None]:

# Log-scale bar chart of the mechanistic parameters: logit of the last 10
# draws of results.kbar, averaged over draws.
# Uses the full trace results.kbar (draws, genes, params); the global `kbar`
# may hold a single 2-D sample, whose shape[1] is not the gene count.
plt.figure()
kbar_samples = results.kbar
num_genes = kbar_samples.shape[1]
k_latest = np.mean(logit(kbar_samples[-10:]), axis=0)
print(k_latest)
# Column order follows the kinetics labels used in this notebook:
# 1 = basal rate, 2 = decay rate, 3 = sensitivity.
B = k_latest[:, 1]
D = k_latest[:, 2]
S = k_latest[:, 3]

gene_positions = np.arange(num_genes)
plt.bar(gene_positions - 0.2, B, width=0.2, tick_label=m_df.index, label='Basal rate')
plt.bar(gene_positions, D, width=0.2, tick_label=m_df.index, label='Decay rate')
plt.bar(gene_positions + 0.2, S, width=0.2, tick_label=m_df.index, label='Sensitivity')
plt.yscale('log')
plt.title('Mechanistic Parameters')
plt.legend()


## Convergence Plots

In [None]:
# Stack each parameter's per-chain traces into (chains, draws, ...) arrays and
# wrap them as ArviZ InferenceData for the diagnostics below.
# NOTE(review): `job` (a list of per-chain results exposing acceptance_rates
# and samples[key].get()) is not created anywhere in this notebook — this
# section expects it from a separate multi-chain run.
keys = job[0].acceptance_rates.keys()

# np.stack replaces the np.append loop: no quadratic re-copying and no
# dependency on an externally defined draw count `T`.
variables = {
    key: np.stack([res.samples[key].get() for res in job], axis=0)
    for key in keys
}

plt.plot(variables['L'][:, -100:].T)

mixes = {key: arviz.convert_to_inference_data(variables[key]) for key in keys}

#### Rhat
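$\hat{R}$ (R-hat) compares between-chain and within-chain variance: it is the square root of the ratio of the pooled estimate of the posterior variance to the average within-chain variance. As the between-chain variance approaches the within-chain variance, $\hat{R}$ tends to 1; if it exceeds 1.1, we consider that the chains have not mixed well.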

In [None]:
Rhat = arviz.rhat(mixes['fbar'])

# Mean R-hat per parameter ('x' is the variable name ArviZ gives a bare NumPy
# array); values below 1.1 are treated as converged.
Rhats = np.array([
    np.mean(arviz.rhat(mixes[key]).x.values)
    for key in keys
])
converged = Rhats < 1.1
rhat_df = pd.DataFrame([list(Rhats), list(converged)], columns=keys)

display(rhat_df)

#### Rank plots

Rank plots are histograms of the ranked posterior draws (ranked over all chains), plotted separately for each chain.
If all chains target the same posterior, the ranks within each chain should be uniformly distributed; if one chain has a different location or scale, this shows up as a deviation from uniformity. If the rank plots of all chains look similar, the chains have mixed well.

Reference: Vehtari, A., Gelman, A., Simpson, D., Carpenter, B., & Bürkner, P.-C. (2019). Rank-normalization, folding, and localization: An improved R-hat for assessing convergence of MCMC. arXiv preprint https://arxiv.org/abs/1903.08008

In [None]:
# Rank plots of the 'L' trace, using the InferenceData built as mixes['L']
# (`L_mix` is never defined in this notebook).
arviz.plot_rank(mixes['L'])

#### Effective sample sizes

Plot the quantile, local, or evolution effective sample size (ESS) estimates.

In [None]:
# Effective sample size plot for the 'L' trace, using the InferenceData built
# as mixes['L'] (`L_mix` is never defined in this notebook).
arviz.plot_ess(mixes['L'])

#### Monte-Carlo Standard Error

In [None]:
# Monte-Carlo standard error plot for the 'L' trace, using the InferenceData
# built as mixes['L'] (`L_mix` is never defined in this notebook).
arviz.plot_mcse(mixes['L'])


#### Parallel Plot
Parallel coordinates plot showing posterior draws with and without divergences.

Described in https://arxiv.org/abs/1709.01449 (suggested by Ari Hartikainen).


In [None]:
# Parallel-coordinates plot of posterior draws.
# NOTE(review): `azl` is not defined anywhere in this notebook (likely a stale
# name from an earlier session) — TODO point this at the intended
# InferenceData, e.g. one of the entries of `mixes`.
arviz.plot_parallel(azl)


The step size is the standard deviation of the proposal distribution. If it is too small, the chain takes a long time to reach high-density regions; if it is too large, many proposals are rejected.