In [None]:
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')

from tensorflow import math as tfm
from tensorflow_probability import distributions as tfd
import tensorflow_probability as tfp

from reggae.models import TranscriptionLikelihood, Options, TranscriptionMixedSampler
from reggae.data_loaders import DataHolder
from reggae.data_loaders.artificial import artificial_dataset
from reggae.utilities import logit, logistic, inverse_positivity
from reggae.plot import plotters
from reggae.models.results import SampleResults, GenericResults

from matplotlib import pyplot as plt
from IPython.display import HTML
plt.rcParams['animation.ffmpeg_path'] = 'C:\\Users\\Jacob\\Documents\\ffmpeg-static\\bin\\ffmpeg.exe'
import arviz

import pandas as pd
import numpy as np

np.set_printoptions(formatter={'float': lambda x: "{0:0.5f}".format(x)})
plt.style.use('ggplot')
%matplotlib inline
f64 = np.float64


In [None]:
# Ground-truth network used to simulate the artificial dataset.
num_genes = 13
num_tfs = 3
tf.random.set_seed(1)
# True interaction weights: one row per gene, one column per TF.
w = tf.random.normal([num_genes, num_tfs], mean=0.5, stddev=0.71, seed=42, dtype='float64')

# True delay of each TF (one entry per TF).
Δ_delay = tf.constant([0, 4, 10], dtype='float64')

# True per-gene weights w_0 (all zero; presumably the bias term — confirm in reggae).
w_0 = tf.zeros(num_genes, dtype='float64')

# The ground-truth kinetics `true_kbar` / `true_k_fbar` are produced by
# `artificial_dataset` below and unpacked from `kinetics`.
opt = Options(preprocessing_variance=False,
              tf_mrna_present=True,
              kinetic_exponential=True,
              weights=True,
              initial_step_sizes={'logistic': 1e-8, 'latents': 10},
              delays=True)

data, fbar, kinetics = artificial_dataset(opt, TranscriptionLikelihood, num_genes=num_genes,
                                          weights=(w, w_0), delays=Δ_delay.numpy(), t_end=10)
true_kbar, true_k_fbar = kinetics
f_i = inverse_positivity(fbar)
t, τ, common_indices = data.t, data.τ, data.common_indices
common_indices = common_indices.numpy()  # numpy index array, used for fancy indexing later
N_p = τ.shape[0]  # number of points in τ
N_m = t.shape[0]  # number of points in t

def expand(x):
    """Return `x` as an array with an extra leading axis of length 1.

    An input of shape ``s`` comes back with shape ``(1, *s)``; scalars become shape ``(1,)``.
    """
    return np.asanyarray(x)[np.newaxis, ...]
# Ground truth wrapped as a SampleResults whose arrays carry an extra leading axis of
# length 1, so it can go through the same plotting calls as the posterior results.
true_results = SampleResults(
    opt,
    expand(fbar), expand(true_kbar), expand(true_k_fbar),
    Δ_delay,
    None,
    expand(w), expand(w_0),
    None, None,
)
model = TranscriptionMixedSampler(data, opt)


In [None]:
# Transcription factors: latent trajectories f_i (lines) and observations data.f_obs (crosses).
tf_labels = ['A', 'B', 'C']
plt.title('TFs')
for i in range(num_tfs):
    line, = plt.plot(τ, f_i[0, i], label=f'TF {i}')
    # Reuse the line's colour so each TF's observations visibly belong to its trajectory.
    plt.scatter(t, data.f_obs[0, i], marker='x', color=line.get_color())
plt.xticks(np.arange(0, 12))
plt.xlabel('Time / hr')
plt.ylabel('Abundance (AU)')
plt.legend()
print(τ.shape)

In [None]:
def plot_genes(num, tup1, tup2):
    """Plot two mRNA series for each of the first `num` genes, two genes per figure row.

    `tup1` and `tup2` are ``(prediction, observed, label)`` triples. For gene ``j``,
    ``observed[j]`` is scattered at x = ``np.arange(N_p)[common_indices]`` with legend
    entry ``label``. ``prediction[j]`` is drawn as a grey line against its index, unless
    ``prediction`` is None (for either triple).

    Reads the notebook globals ``N_p`` and ``common_indices`` and draws into the current figure.
    """
    # Two genes per row. The grid needs ceil(num / 2) rows; a `num`-row grid left its
    # bottom half empty.
    n_rows = max(1, (num + 1) // 2)
    obs_x = np.arange(N_p)[common_indices]
    for j in range(num):
        ax = plt.subplot(n_rows, 2, 1 + j)
        plt.title(f'Gene {j}')
        for prediction, observed, label in (tup1, tup2):
            plt.scatter(obs_x, observed[j], marker='x', label=label)
            if prediction is not None:
                plt.plot(prediction[j], color='grey')
        # NOTE(review): ticks every 11 grid points labelled 0..11 assume 11 grid points
        # per hour along the x axis — confirm against data.τ.
        plt.xticks(np.arange(0, 123, 11))
        ax.set_xticklabels(np.arange(12))
        plt.xlabel('Time / hr')
        plt.legend()

    plt.tight_layout()
    
# mRNA simulated from the true parameters, with and without the true TF delays.
plt.figure(figsize=(10, 28))

lik = model.likelihood
Δ_nodelay = tf.zeros(num_tfs, dtype='float64')  # every TF delay set to zero
m_pred = lik.predict_m(true_kbar, true_k_fbar, w, fbar, w_0, Δ_delay)
m_pred_nodelay = lik.predict_m(true_kbar, true_k_fbar, w, fbar, w_0, Δ_nodelay)
# No-delay predictions restricted to the observation indices (presumably aligned with data.m_obs).
m_observed_nodelay = tf.stack([m_pred_nodelay.numpy()[:, i, common_indices] for i in range(num_genes)], axis=1)

plot_genes(num_genes, (m_pred_nodelay[0], m_observed_nodelay[0], 'no delay'),
           (m_pred[0], data.m_obs[0], 'delay'))

# Protein per TF: delayed (lighter colour) vs. undelayed (darker colour).
fig = plt.figure(figsize=(6, 4))
p_nodelay = lik.calculate_protein(fbar, true_k_fbar, Δ_nodelay)
p = lik.calculate_protein(fbar, true_k_fbar, Δ_delay)

colors = ['black', 'darkgreen', 'orangered']
delay_colors = ['grey', 'green', 'lightcoral']
for i in range(num_tfs):
    plt.plot(np.arange(N_p), p[0, i], color=delay_colors[i], label=f'Protein {i} (delayed)')
    plt.plot(np.arange(N_p), p_nodelay[0, i], label=f'Protein {i}', color=colors[i], alpha=0.8)
# Tick configuration does not depend on the TF, so it is set once after the loop.
plt.xticks(np.arange(0, 123, 11))
fig.axes[0].set_xticklabels(np.arange(12))
plt.xlabel('Time / hr')
plt.legend()
plt.tight_layout()

In [None]:
# Load the saved sampler run 'w-fix' ('delay-w' is presumably another saved run).
model = TranscriptionMixedSampler.load('w-fix', [data, opt])
is_accepted, samples = model.is_accepted, model.samples

In [None]:
# Burn-in of 1500 passed to model.results, only when the chain length read from
# samples[0][0] exceeds 5000; otherwise no burn-in.
burnin = 1500 if samples[0][0].shape[0] > 5000 else 0
results = model.results(burnin=burnin)

# Acceptance rate of each sampled parameter, shown as a one-row percentage table.
pcs = [
    tf.reduce_mean(tf.cast(is_accepted[idx], dtype=tf.float32)).numpy()
    for idx, _ in enumerate(model.state_indices)
]
acceptance_table = pd.DataFrame(
    [[f'{100*pc:.02f}%' for pc in pcs]],
    columns=list(model.state_indices),
)
display(acceptance_table)

def moving_average(a, n=3):
    """Trailing simple moving average of `a` over windows of `n` consecutive values.

    Computed from a cumulative sum, so it runs in O(len(a)). The result has
    ``len(a) - n + 1`` entries (empty when ``n > len(a)``); entry ``k`` is ``mean(a[k:k + n])``.

    Raises:
        ValueError: if `n` is not a positive integer.
    """
    if int(n) != n or n < 1:
        raise ValueError(f'window size n must be a positive integer, got {n!r}')
    n = int(n)
    ret = np.cumsum(a, dtype=float)
    # Window sum ending at index i is cumsum[i] - cumsum[i - n].
    ret[n:] = ret[n:] - ret[:-n]
    return ret[n - 1:] / n

# Trace of each TF's delay Δ, optionally smoothed with `moving_average`.
print(tf.round(results.Δ[-1]))  # rounded delays of the last sample in `results`
trace_length = 3000      # plot only the last `trace_length` samples
smoothing_window = None  # e.g. 5 to plot a moving average instead of the raw trace

plt.title('Trace of Deltas' if smoothing_window is None else 'Moving Average of Deltas')
for i in range(num_tfs):
    delta_trace = results.Δ[-trace_length:, i]
    if smoothing_window is not None:
        delta_trace = moving_average(delta_trace, smoothing_window)
    plt.plot(delta_trace, label=i)
plt.xlabel('Sample')
plt.ylabel('Δ')
plt.legend(title='TF')


In [None]:
# mRNA predictions drawn via sample_latents(results, 50). Later cells index them as
# m_preds[sample, replicate, gene, time] — TODO confirm the axis meanings in reggae.
m_preds = model.sample_latents(results, 50)

In [None]:
# Ground-truth kinetics mapped through exp(logit(·)) and logit(·), for comparison with
# the sampled kinetics; true_k_f gets an extra leading axis of length 1.
true_k = np.exp(logit(true_kbar).numpy())
true_k_f = logit(true_k_fbar).numpy()[np.newaxis]

plot_opt = plotters.PlotOptions(
    num_plot_genes=10,
    num_plot_tfs=10,
    for_report=False,
    kernel_names=model.kernel_selector.names(),
)
plotter = plotters.Plotter(data, plot_opt)

In [None]:
# Per-sample error against the ground truth: sampled latents at the observation indices
# vs. the observed TF data, and sampled kinetics vs. true_k.
f_pred = results.f
# Select the observation indices along the last axis (same as the former
# transpose → gather(axis=0) → transpose).
f_pred_obs = tf.gather(f_pred, data.common_indices, axis=-1)
sq_diff = tfm.square(data.f_obs - f_pred_obs)
print(sq_diff.shape)
k = results.k
error = tf.square(k - tf.expand_dims(true_k, 0))

# One row of two panels; a 2x2 grid left the bottom row empty.
fig, (ax_latent, ax_kinetic) = plt.subplots(1, 2, figsize=(10, 4))
ax_latent.plot(tf.reduce_sum(sq_diff, axis=[1, 2, 3]))
ax_latent.set(title='latent error', xlabel='Sample', ylabel='Sum of squared errors')
ax_kinetic.plot(tf.reduce_sum(error, axis=[1, 2]))
ax_kinetic.set(title='kinetic error', xlabel='Sample', ylabel='Sum of squared errors')
plt.tight_layout()

In [None]:
# Traces from results.weights[0] (presumably indexed [sample, gene, tf] — TODO confirm):
# column 0 over all samples, then column 1 over the final 20 samples in a new figure.
weight_samples = results.weights[0]
plt.plot(weight_samples[:, 0])
plt.figure()
plt.plot(weight_samples[-20:, 1])

# Optional (slow): animate sampled mRNA trajectories.
# m_samples = model.sample_latents(results, 1000, step=10)
# HTML(plotter.anim_latent(m_samples, index=0, interval=1))

In [None]:
# TF latents (index 0 on axis 1, presumably replicate 0): posterior mean and 95% HPD band
# over the last `tail_length` samples, with the observations.
tail_length = 500
# NOTE(review): presumably samples[1][0] holds the raw latent-function samples; the
# original also mentioned `results.f[:, 0]` as an alternative source — confirm.
f_samples = inverse_positivity(samples[1][0][:, 0]).numpy()
print(f_samples.shape)
plt.figure(figsize=(9, 5))
colors = ['slategrey', 'orchid', 'tab:blue']
for i in range(num_tfs):
    color = colors[i % len(colors)]  # cycle if there are more TFs than colours
    f_tail = f_samples[-tail_length:, i]
    bounds = arviz.hpd(f_tail, credible_interval=0.95)
    label = '95% credibility interval' if i == 0 else None
    plt.fill_between(data.τ, bounds[:, 0], bounds[:, 1], color='grey', alpha=0.3, label=label)
    plt.scatter(data.t, data.f_obs[0, i], marker='x', color=color)
    plt.plot(data.τ, np.mean(f_tail, axis=0), color=color, label=f'TF {i}')
plt.legend()
plt.ylabel('Abundance (AU)')
plt.xlabel('Time (h)');

In [None]:
# reggae summary plots for the posterior, given the ground-truth kinetics (true_k, true_k_f) for comparison.
plotter.summary(results, m_preds, true_k=true_k, true_k_f=true_k_f)

In [None]:
# reggae convergence diagnostics for the post-burn-in results.
plotter.convergence_summary(results)


In [None]:
# Network plot (plot_grn) for the posterior `results`, then for the ground truth
# `true_results`, with identical settings so the two figures are comparable.
print(results.weights[0].shape)
print(plotter.opt.gene_names)
plt.figure(figsize=(12, 5))
plotter.plot_grn(results, use_sensitivities=False, log=True)
plt.figure(figsize=(12, 5))
print(true_results.k.shape)
plotter.plot_grn(true_results, use_sensitivities=False, log=True)

In [None]:
# Animation of the sampled latents results.f. This rebinds `f_samples`, which an earlier
# cell set to transformed raw samples. The meaning of index=2 is defined by anim_latent
# and is not visible here.
f_samples = results.f
HTML(plotter.anim_latent(f_samples, index=2))

In [None]:
# Snapshot of the sampler's parameters and the current values of its active parameters,
# used to evaluate the likelihood by hand below.
params = model.params
all_states = [active.value for active in model.active_params]

def compute_prob(delta, delay_rate=f64(0.3)):
    """Print and return an unnormalised log-probability score for the delays `delta`.

    The score has two parts:
    - the summed output of ``model.likelihood.genes`` at the current values of all other
      parameters (``all_states``);
    - an Exponential(`delay_rate`) log-density summed over the entries of `delta`. This is
      a hand-specified prior, not necessarily the model's own Δ prior.

    Reads the notebook globals ``model``, ``all_states`` and ``params``.

    Args:
        delta: float64 delays, one per TF, e.g. ``tf.constant([0, 4, 8], dtype='float64')``.
        delay_rate: rate of the Exponential prior on each delay (default 0.3, as before).

    Returns:
        The scalar score tensor (also printed, as before).
    """
    log_lik = tf.reduce_sum(model.likelihood.genes(
        all_states=all_states,
        state_indices=model.state_indices,
        σ2_m=params.σ2_m.value,
        Δ=delta,
    ))
    log_prior = tf.reduce_sum(tfd.Exponential(delay_rate).log_prob(delta))
    prob = log_lik + log_prior
    print(prob)
    return prob

# Score the current delays, then a few hand-picked alternatives, for comparison.
print(params.Δ.value)
compute_prob(params.Δ.value)
compute_prob(tf.constant([0, 4, 8], dtype='float64'))
compute_prob(tf.constant([0, 0, 8], dtype='float64'))
compute_prob(tf.constant([0, 10, 8], dtype='float64'))

# NOTE(review): unlike compute_prob, this passes σ2_m=logistic(10 * σ2_m.value) rather
# than the raw parameter value — confirm which form `genes` expects.
print(model.likelihood.genes(
            all_states=all_states, 
            state_indices=model.state_indices,
            σ2_m=logistic(10*params.σ2_m.value),
            Δ=tf.constant([0, 10, 8], dtype='float64'),
        ))

In [None]:
# One gene: observations, the no-delay simulation, posterior predictions, and a prediction
# at the sampler's current parameter values but with alternative delays Δ_other.
gene = 3  # index of the gene to plot
fig = plt.figure(figsize=(6, 4))
Δ_other = tf.constant([10, 4, 10], dtype='float64')
m_pred_ = model.likelihood.predict_m(model.params.kinetics.value[0],
                                     model.params.kinetics.value[1],
                                     model.params.weights.value[0],
                                     model.params.latents.value[0],
                                     model.params.weights.value[1],
                                     Δ_other)
# Same prediction with the current delays. The whole per-TF delay vector is passed;
# `Δ.value[-1]` would hand predict_m only the last TF's delay.
m_pred = model.likelihood.predict_m(model.params.kinetics.value[0],
                                    model.params.kinetics.value[1],
                                    model.params.weights.value[0],
                                    model.params.latents.value[0],
                                    model.params.weights.value[1],
                                    model.params.Δ.value)

plt.scatter(t, m_observed_nodelay[0, gene], marker='x', label='no delay')
plt.scatter(t, data.m_obs[0, gene], marker='x', s=70, linewidth=3, label='observations')
plt.plot(τ, np.mean(m_preds[-5:, 0, gene], axis=0), color='darkslateblue', label='mean of last 5 samples')
plt.plot(τ, m_pred_[0, gene], color='orangered', label='prediction with Δ_other')

print(m_preds.shape)
bounds = arviz.hpd(m_preds[:, 0, gene, :], credible_interval=0.95)
plt.fill_between(τ, bounds[:, 0], bounds[:, 1], color='grey', alpha=0.3, label='95% credibility interval')
plt.legend(loc=2)
plt.xlabel('Time / hr')

In [None]:
# Gene `j`: prediction at the sampler's current parameter values (model.params.*.value)
# under the sampled delays vs. the true delays, with the HPD band of `m_preds`.
print(model.params.Δ.value)
delta_model = model.params.Δ.value
delta_true = Δ_delay
deltas = [delta_model, delta_true]
plt.figure(figsize=(7, 4))
labels = ['Model', 'True']
colors = ['chocolate', 'slategrey']
j = 11  # gene to plot
# Width of the HPD band. The legend label is derived from it so the two cannot disagree.
credible_interval = 0.95

for i in range(len(deltas)):
    pred = model.likelihood.predict_m(model.params.kinetics.value[0],
                                      model.params.kinetics.value[1],
                                      model.params.weights.value[0],
                                      model.params.latents.value[0],
                                      model.params.weights.value[1],
                                      deltas[i])
    plt.plot(data.τ, pred[0, j], label=labels[i], color=colors[i])
# The observations do not depend on the delays, so they are drawn once.
plt.scatter(data.t, data.m_obs[0, j], marker='x')

bounds = arviz.hpd(m_preds[:, 0, j, :], credible_interval=credible_interval)
plt.fill_between(data.τ, bounds[:, 0], bounds[:, 1], color='grey', alpha=0.3,
                 label=f'{credible_interval:.0%} credibility interval')
plt.xlabel('Time (h)')
plt.xticks(np.arange(0, 10))
plt.legend()

In [None]:
# Protein trajectories from sample_proteins(results, 20), first replicate index, drawn
# by reggae's plot_samples.
# NOTE(review): the meaning of the empty label list and the positional `4` is defined in
# plot_samples — confirm there.
p_samples = model.sample_proteins(results, 20)
plt.figure(figsize=(10, 5))
plotter.plot_samples(p_samples[:,0], ['', '', ''], 4, color='orangered')

In [None]:
# One TF's latent trajectory (first replicate index): posterior mean, HPD band and observations.
# The TF index is set explicitly. 1 reproduces the value this cell used to pick up from a
# loop variable leaked by an earlier cell on Restart & Run All.
tf_index = 1
mean_window = 400    # final samples averaged for the posterior-mean curve
hpd_window = 2000    # final samples used for the HPD band
credible_interval = 0.95

fig = plt.figure(figsize=(6, 4.2))
plt.plot(τ, np.mean(results.f[-mean_window:, 0, tf_index], axis=0), c='grey', alpha=1,
         label=f'Posterior mean (last {mean_window} samples)')

# np.asarray allows fancy indexing even if τ is a TensorFlow tensor.
obs_times = np.asarray(τ)[common_indices]
plt.scatter(obs_times, data.f_obs[0, tf_index], marker='x', s=60, linewidth=2, color='tab:blue', label='Observed')

bounds = arviz.hpd(results.f[-hpd_window:, 0, tf_index], credible_interval=credible_interval)
plt.fill_between(τ, bounds[:, 0], bounds[:, 1], color='grey', alpha=0.3,
                 label=f'{credible_interval:.0%} credibility interval')
plt.xticks(t)
fig.axes[0].set_xticklabels(t)
plt.xlabel('Time / hr')
plt.ylabel('Abundance (AU)')
plt.legend()
plt.tight_layout()
