In [None]:
import arviz
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_probability as tfp
from matplotlib import pyplot as plt
from scipy.interpolate import interp1d
from sklearn import preprocessing
from tensorflow import math as tfm
from tensorflow_probability import distributions as tfd

from reggae.data_loaders import load_barenco_puma, load_3day_dros, DataHolder, scaled_barenco_data
from reggae.mcmc import create_chains, Parameter
from reggae.models import TranscriptionLikelihood, Options, TranscriptionMixedSampler
from reggae.models.kernels import DeltaKernel
from reggae.models.results import GenericResults
from reggae.plot import plotters
from reggae.utilities import get_rbf_dist, discretise, logit, logistic, LogisticNormal

# Print floats with five decimal places so printed parameter arrays stay readable.
np.set_printoptions(formatter={'float': lambda x: "{0:0.5f}".format(x)})
plt.style.use('ggplot')
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Shorthand used throughout the notebook to cast values to 64-bit floats.
f64 = np.float64


In [None]:
# Problem dimensions and time grids for the synthetic experiment.
num_genes = 5
# Ten observation times 0..9 (later plots label this axis 'Time / hr').
t = np.arange(10)
# discretise returns a grid τ plus the positions of the observation times
# inside it; common_indices is used below to index arrays of length N_p.
# NOTE(review): presumably τ is a finer refinement of t — confirm in reggae.utilities.
τ, common_indices = discretise(t)
time = (t, τ, tf.constant(common_indices))
# Model switches shared by the likelihood and the sampler constructed below.
opt = Options(preprocessing_variance=False, 
              tf_mrna_present=True, 
              delays=True, 
              latent_function_metropolis=True)

# N_p: number of points on the τ grid; N_m: number of observation times.
N_p = τ.shape[0]
N_m = t.shape[0]

num_tfs = 3



In [None]:
# Transcription factor

# Hand-specified profiles for three TFs, 16 knots each (knot positions 0..15).
A = np.array([0.01, 0.3, 0.47, 0.51, 0.4, 0.37, 0.47, 0.32, 0.16, 0.1, 0.05, 0.025, 0.025, 0.01, 0.005, 0.005])
B = np.array([0.01, 0.1, 0.22, 0.44, 0.53, 0.41, 0.23, 0.13, 0.09, 0.035, 0.022, 0.02, 0.015, 0.01, 0.005, 0.005])
C = np.array([0.01, 0.02, 0.03, 0.05, 0.08, 0.16, 0.4, 0.36, 0.23, 0.12, 0.05, 0.025, 0.025, 0.01, 0.005, 0.005])
# Resample each profile onto a grid with as many points as τ. A and B use
# cubic interpolation, C uses quadratic.
# NOTE(review): the knots run to position 15 but the resampling stops at 14,
# so the final knot is never sampled — confirm this is intended.
interp = interp1d(np.arange(A.shape[0]), A, kind='cubic')
A = interp(np.linspace(0,14, τ.shape[0]))
interp = interp1d(np.arange(B.shape[0]), B, kind='cubic')
B = interp(np.linspace(0,14, τ.shape[0]))
interp = interp1d(np.arange(C.shape[0]), C, kind='quadratic')
C = interp(np.linspace(0,14, τ.shape[0]))

# TF-side kinetic parameters: one row per TF with values (0.1, 2), passed
# through reggae's logistic helper.
# NOTE(review): presumably this maps to the sampler's transformed scale — confirm.
k_fbar = logistic(f64(np.array([[0.1, 0.1, 0.1], [2, 2, 2]]).T))
print(k_fbar)
fbar = np.array([A, B, C])
# sklearn's normalize rescales each row to unit L2 norm by default; the
# result is then multiplied by 8.
fbar = 8*preprocessing.normalize(fbar)
print(fbar.shape)
tf_labels = ['A', 'B', 'C']
plt.title('TFs')

# Take observations: the value of each TF profile at the observation indices.
f_observed = tf.stack([fbar[i][common_indices] for i in range(num_tfs)])

# Plot every full profile as a line and its observed points as crosses.
for i in range(num_tfs):
    plt.plot(np.arange(τ.shape[0]), fbar[i], label=f'TF {i}')
    plt.scatter(np.arange(N_p)[common_indices], f_observed[i], marker='x')

plt.legend()

# Inverse softplus, log(exp(x) - 1): from here on fbar holds the transformed
# profile rather than the plotted one.
fbar = tfm.log((tfm.exp(fbar)-1))

# Softplus of fbar, log(1 + exp(x)), which undoes the transform above.
f_i = tfm.log(1+tfm.exp(fbar))


In [None]:
# Interaction weights (genes x TFs) and per-gene bias used to generate the data.
w = 1*tf.ones((num_genes, num_tfs), dtype='float64') # TODO
w_0 = tf.zeros(num_genes, dtype='float64') # TODO
# Ground-truth gene kinetics: one row per gene, four parameters each.
# An earlier assignment to true_kbar that was overwritten before any use has
# been dropped; these are the values the synthetic data are generated from.
true_kbar = (np.array([[0.50563, 0.66, 0.893, 0.9273],
                       [0.6402, 0.6335, 0.7390, 0.7714],
                       [0.5328, 0.5603, 0.6498, 0.9244],
                       [0.5939, 0.5821, 0.77716, 0.8387],
                       [0.58, 0.67, 0.57, 0.95]]))

print(logit(true_kbar))

# Placeholder data (all ones) is enough to build a likelihood object whose
# predict_m can be used to generate the synthetic expression curves.
temp_data = DataHolder((np.ones((num_genes, N_m)), np.ones((num_tfs, N_m))), None, time)
temp_lik = TranscriptionLikelihood(temp_data, opt)
Δ_nodelay = tf.constant([0, 0, 0], dtype='float64')

# Predicted mRNA with zero delay and with the true per-TF delays [1, 4, 10].
m_pred_nodelay = temp_lik.predict_m(true_kbar, k_fbar, logistic(w), fbar, logistic(w_0), Δ_nodelay)
Δ = tf.constant([1, 4, 10], dtype='float64')
m_pred = temp_lik.predict_m(true_kbar, k_fbar, logistic(w), fbar, logistic(w_0), Δ)

# Observations: the predicted curves read off at the observation indices.
m_observed = tf.stack([m_pred.numpy()[i][common_indices] for i in range(num_genes)])
m_observed_nodelay = tf.stack([m_pred_nodelay.numpy()[i][common_indices] for i in range(num_genes)])

def plot_genes(tup1, tup2):
    """Draw one panel per gene comparing two sets of mRNA observations.

    Both arguments are indexed as ``(curves, observations, label)``: item 1
    holds the per-gene observed values (scattered at the observation
    positions), item 2 is the legend label, and item 0 holds the per-gene
    curves drawn in grey. ``tup1[0]`` may be ``None`` to skip its curve;
    ``tup2[0]`` is always drawn.

    Reads the notebook globals ``num_genes``, ``N_p`` and ``common_indices``
    and draws onto the current matplotlib figure using a 6x2 subplot grid.
    """
    obs_positions = np.arange(N_p)[common_indices]
    first_curves = tup1[0]

    for gene in range(num_genes):
        panel = plt.subplot(621 + gene)
        plt.title(f'Gene {gene}')

        plt.scatter(obs_positions, tup1[1][gene], marker='x', label=tup1[2])
        if first_curves is not None:
            plt.plot(first_curves[gene], color='grey')

        plt.scatter(obs_positions, tup2[1][gene], marker='x', label=tup2[2])
        plt.plot(tup2[0][gene], color='grey')

        # Ten ticks spaced 11 grid points apart, relabelled 0..9.
        plt.xticks(np.arange(0, 101, 11))
        panel.set_xticklabels(np.arange(10))
        plt.xlabel('Time / hr')
        plt.legend()

    plt.tight_layout()
    
# One tall figure: a panel per gene (delay vs. no delay), then the TF proteins.
plt.figure(figsize=(8, 14))

plot_genes((m_pred_nodelay, m_observed_nodelay, 'no delay'), (m_pred, m_observed, 'delay'))

# Bundle the delayed mRNA observations and the TF observations for inference.
# Built in a single step so `data` is never bound to a bare tuple.
data = DataHolder((m_observed, f_observed), None, time)

lik = TranscriptionLikelihood(data, opt)

# Sixth cell of the 6x2 grid, next to the five gene panels.
ax = plt.subplot(626)
plt.title('TF Proteins')
p_nodelay = lik.calculate_protein(fbar, k_fbar, Δ_nodelay)
p = lik.calculate_protein(fbar, k_fbar, Δ)

colors = ['black', 'darkslateblue', 'orangered']
delay_colors = ['grey', 'cadetblue', 'lightcoral']
for i in range(num_tfs):
    # Delayed protein in the muted colour, undelayed one on top with a label.
    plt.plot(p[i], color=delay_colors[i])
    plt.plot(p_nodelay[i], label=f'Protein {i}', color=colors[i], alpha=0.8)

# Axis decoration does not depend on the TF, so it is applied once after
# all curves are drawn rather than on every loop pass.
plt.xticks(np.arange(0, 101, 11))
ax.set_xticklabels(np.arange(10));
plt.xlabel('Time / hr')
plt.legend();

In [None]:
# Largest variance of data.f_obs taken along axis 1, as a float64.
# NOTE(review): f_obs is presumably the TF observations handed to DataHolder,
# which would make this the largest per-TF variance — confirm.
f64(max(np.var(data.f_obs, axis=1)))


In [None]:
# Build the sampler over the synthetic data with the same options used for the likelihood.
model = TranscriptionMixedSampler(data, opt)


In [None]:
# Set the sampler's delay parameter before sampling. The data were generated
# with Δ = [1, 4, 10], so only the first entry differs from the truth here.
model.params.Δ.value = tf.constant([0, 4, 10], dtype='float64')
# Run the sampler with T=600 and no burn-in.
# NOTE(review): no random seed is set anywhere, so results vary between runs.
samples, is_accepted = model.sample(T=600, burn_in=0)


In [None]:
# Pull every sampled quantity out of the chain by its state index.
# Note that k_fbar, fbar, Δ, w and w_0 are rebound here: from this cell on
# they hold sample traces, not the ground-truth values defined earlier.
chain = model.samples
lookup = model.state_indices

kinetics = chain[lookup['kinetics']]
kbar, k_fbar = kinetics[0], kinetics[1]
fbar = chain[lookup['fbar']]
σ2_m = chain[lookup['σ2_m']]
rbf_params = chain[lookup['rbf_params']]
Δ = chain[lookup['Δ']]
# Most recent delay sample, rounded.
print(tf.round(Δ[-1]))
weights = chain[lookup['weights']]
w, w_0 = weights[0], weights[1]
σ2_f = chain[lookup['σ2_f']]

# TODO: the weights could instead be pinned to their generating values
# (ones for w, zeros for w_0).

# Mean acceptance rate per state part, in the order of the state indices.
pcs = [
    tf.reduce_mean(tf.cast(is_accepted[position], dtype=tf.float32)).numpy()
    for position, _ in enumerate(lookup)
]

display(pd.DataFrame([[f'{100*pc:.02f}%' for pc in pcs]], columns=list(lookup)))


def moving_average(a, n=3):
    """Return the simple moving average of ``a`` over windows of ``n`` points.

    Args:
        a: sequence of numbers; anything ``np.cumsum`` accepts. Input with
            more than one dimension is flattened by ``np.cumsum``.
        n: window length, at least 1.

    Returns:
        A float array of length ``max(len(a) - n + 1, 0)`` whose k-th entry
        is the mean of ``a[k:k + n]``.

    Raises:
        ValueError: if ``n`` is smaller than 1. A negative window would
            otherwise slip through the slicing and yield meaningless values.
    """
    if n < 1:
        raise ValueError(f'window length n must be at least 1, got {n}')
    ret = np.cumsum(a, dtype=float)
    # Subtracting the running total from n steps earlier leaves the sum of
    # each length-n window.
    ret[n:] = ret[n:] - ret[:-n]
    return ret[n - 1:] / n


# Smooth each TF's sampled delay with a 5-sample moving average so the
# drift of the chain is easier to see than in the raw trace.
plt.title('Moving Average of Deltas')
for i in range(num_tfs):
    plt.plot(moving_average(Δ[:, i], 5), label=i)
#     plt.plot(Δ[:, i], label=i)
plt.legend()
# Separate figure: trace of a single fbar entry (indices 0, 0) across samples.
plt.figure()
plt.plot(fbar[:, 0, 0])



In [None]:
# Predicted mRNA curves for the last 99 samples, stored oldest first so that
# m_preds[-1] is the prediction from the most recent sample and m_preds[-k:]
# covers the k most recent ones.
# TODO: the weights are held at their final sample (w[-1], w_0[-1]) for every
# prediction instead of using the weights drawn alongside each sample.
m_preds = np.array([
    model.likelihood.predict_m(kbar[-i], k_fbar[-i], w[-1], fbar[-i], w_0[-1], Δ[-i])
    for i in range(99, 0, -1)
])

# Softplus maps the sampled fbar back to the positive scale of the TF profiles.
f_samples = np.log(1+np.exp(fbar))
# Transform sampled parameters with reggae's logit helper for reporting.
# NOTE(review): presumably this undoes the logistic transform applied to the
# ground-truth values — confirm in reggae.utilities.
k_f_samples = logit(k_fbar).numpy()
k_samples = logit(kbar).numpy()
rbf_params_samples = [logit(rbf_params[0]), logit(rbf_params[1])] 
σ2_m_samples = logit(σ2_m)
σ2_f_samples = σ2_f.numpy()
# Ground-truth values for comparison: the TF kinetics before the logistic
# transform, and the gene kinetics passed through logit like the samples.
true_k_f = f64(np.array([[0.1, 0.1, 0.1], [2, 2, 2]]).T)
true_k = logit(true_kbar[:,:]).numpy()
# The report receives every tenth of the last 300 fbar samples.
plotters.generate_report(data, k_samples, k_f_samples, f_samples[-300::10], σ2_m_samples,σ2_f_samples, rbf_params_samples, m_preds, 
                         plot_barenco=False, num_hpd=50, true_k=true_k, true_k_f=true_k_f)


In [None]:
# Compare predictions for gene 3 under the sampled delays and under an
# alternative delay vector in which the first TF's delay is set to 10.
fig = plt.figure(figsize=(6, 4))
Δ_other = tf.constant([10, 4, 10], dtype='float64')
m_pred_ = model.likelihood.predict_m(kbar[-1], k_fbar[-1], w[-1], fbar[-1], w_0[-1], Δ_other)
m_pred = model.likelihood.predict_m(kbar[-1], k_fbar[-1], w[-1], fbar[-1], w_0[-1], Δ[-1])

# arviz is imported in the notebook's top import cell.
plt.scatter(np.arange(N_p)[common_indices], m_observed_nodelay[3], marker='x', label='no delay')
plt.scatter(np.arange(N_p)[common_indices], m_observed[3], marker='x', s=70, linewidth=3, label='observations')
# 'prediction 1' is the mean of the last five entries of m_preds;
# 'prediction 2' uses the alternative delays Δ_other.
plt.plot(np.mean(m_preds[-5:, 3], axis=0), color='darkslateblue', label='prediction 1')
plt.plot(m_pred_[3], color='orangered', label='prediction 2')

print(m_preds.shape)
# 95% interval across all stored predictions for gene 3.
# NOTE(review): newer ArviZ releases expose this as `hdi` with `hdi_prob`;
# the call is kept as written for the installed version.
bounds = arviz.hpd(m_preds[:, 3, :], credible_interval=0.95)
plt.fill_between(np.arange(N_p), bounds[:, 0], bounds[:, 1], color='grey', alpha=0.3, label='95% credibility interval')
plt.legend(loc=2)
plt.xticks(np.arange(0, 101, 11))
fig.axes[0].set_xticklabels(np.arange(10));
plt.xlabel('Time / hr')

In [None]:
params = model.params
def compute_prob(delta):
    """Print and return the score of a candidate delay vector.

    The score is the sum of ``model.likelihood.genes`` evaluated at the
    sampler's current parameter values (with the last sampled weights
    ``w[-1]`` and ``w_0[-1]``) plus the summed log-density of ``delta`` under
    an ``Exponential(0.3)`` prior.
    NOTE(review): ``genes`` presumably returns per-gene log-likelihoods,
    which would make this an unnormalised log-posterior — confirm.

    Args:
        delta: float64 tensor with one delay per TF.

    Returns:
        The scalar score tensor, which is also printed.
    """
    log_likelihood = tf.reduce_sum(model.likelihood.genes(
            k_fbar=params.k_fbar.value,
            kbar=params.kbar.value,
            fbar=params.fbar.value,
            w=w[-1],
            w_0=w_0[-1],
            σ2_m=params.σ2_m.value,
            Δ=delta,
    ))
    log_prior = tf.reduce_sum(tfd.Exponential(f64(0.3)).log_prob(delta))
    prob = log_likelihood + log_prior
    print(prob)
    # Returned as well as printed so callers can compare candidates in code.
    return prob

# Score the sampler's current delays and a few hand-picked alternatives;
# compute_prob prints each score.
print(params.Δ.value)
compute_prob(params.Δ.value)
compute_prob(tf.constant([0, 4, 8], dtype='float64'))
compute_prob(tf.constant([0, 0, 8], dtype='float64'))
compute_prob(tf.constant([0, 10, 8], dtype='float64'))

# Same likelihood call for one fixed delay vector, but with the noise
# argument replaced by logistic(10 * σ2_m) and without summing the result.
print(model.likelihood.genes(
            k_fbar=params.k_fbar.value,
            kbar=params.kbar.value, 
            fbar=params.fbar.value, 
            w=w[-1],
            w_0=w_0[-1],
            σ2_m=logistic(10*params.σ2_m.value),
            Δ=tf.constant([0, 10, 8], dtype='float64'),
        ))
print(params.k_fbar.value.shape)
# Average the last 200 kinetic samples to get a point estimate.
k_latest = np.mean(k_samples[-200:], axis=0)
k_f_latest = np.mean(k_f_samples[-200:], axis=0)
print(k_latest.shape, k_f_latest.shape)
# NOTE(review): the arguments 0 and 10 are presumably the lower and upper
# delay bounds — confirm against DeltaKernel's signature.
d = DeltaKernel(model.likelihood, 0, 10, model.state_indices, None)

print(params.kinetics.value)
# Hand-built sampler state: the averaged kinetics replace the sampler's
# current kinetics; all other parts use the sampler's current values.
# NOTE(review): the order presumably has to match model.state_indices — confirm.
current_state = [
    [logistic(k_latest), logistic(k_f_latest)],
#     params.kinetics.value,
    params.fbar.value, 
    [*params.V.value],
    params.σ2_m.value,
    params.Δ.value,
    params.weights.value,
]
ds = list()

# Take 100 single kernel steps, each from the same starting delays and
# state, keeping the first element of each result as a numpy array.
for i in range(100):
    ds.append(d.one_step(params.Δ.value, GenericResults([500], True), current_state)[0].numpy())
ds = np.array(ds)

# First figure: histogram of the sampled delays over the last 300
# iterations, one panel per TF.
plt.figure(figsize=(13, 6))
for i in range(num_tfs):
    plt.subplot(331+i)
    plt.hist(tf.cast(Δ[-300:, i], 'int32'))
    plt.title(f'TF {i}')

# Second figure: histogram of the delays proposed by the standalone kernel
# steps, with their mean marked. The figure is created once, outside the
# loops, so each set of three panels shares a single figure.
plt.figure(figsize=(13, 6))
for i in range(num_tfs):
    plt.subplot(331+i)
    plt.hist(tf.cast(ds[:, i], 'int32'))
    plt.axvline(np.mean(ds[:, i]))
    plt.title(f'TF {i}')


In [None]:
# Manual walk-through of one sweep that resamples each TF's delay in turn
# from its discrete conditional distribution over the candidates 0..10.
num_tfs = 3
new_state = tf.constant([7, 4, 8], dtype='float64')
Δrange = np.arange(0, 10+1, dtype='float64')
Δrange_tf = tf.range(0, 10+1, dtype='float64')
for i in range(num_tfs):
    # One-hot mask selecting the delay that is being resampled.
    mask = np.zeros((num_tfs, ), dtype='float64')
    mask[i] = 1

    # Score every candidate for delay i with the other delays held fixed.
    # The loop variable is deliberately not called Δ, which holds the
    # sampled delay trace used by other cells.
    probs = list()
    for Δ_candidate in Δrange:
        test_state = (1-mask) * new_state + mask * Δ_candidate
        print(test_state)
        # Likelihood only; the Exponential prior on the delay is left out.
        probs.append(tf.reduce_sum(model.likelihood.genes(
            k_fbar=params.k_fbar.value,
            kbar=params.kbar.value,
            fbar=params.fbar.value,
            w=w[-1],
            w_0=w_0[-1],
            σ2_m=params.σ2_m.value,
            Δ=test_state,
        )))

    probs = tf.stack(probs)
    print(probs)
    # Subtract the maximum before exponentiating for numerical stability,
    # then normalise to a probability vector.
    probs = tfm.exp(probs - tf.reduce_max(probs))
    probs = probs / tf.reduce_sum(probs)

    # Generate the normalised cumulative distribution and draw from it by
    # inversion: pick the first candidate whose cumulative mass exceeds u.
    cumsum = tfm.cumsum(probs)
    print('cumsum', cumsum)
    u = np.random.uniform()
    index = tf.where(cumsum == tf.reduce_min(cumsum[(cumsum - u) > 0]))
    chosen = Δrange_tf[index[0][0]]
    new_state = (1-mask) * new_state + mask * chosen
    print(chosen)

In [None]:
# Posterior summary for a single TF profile: the sample mean, the observed
# points and a 95% interval over the last 400 samples.
fig = plt.figure(figsize=(6, 4.2))
horizontal_subplots = 21 if num_tfs > 1 else 11

# Index of the TF to plot, stated explicitly. This cell used to read the
# loop variable `i` left over from an earlier cell, which ended on the last
# TF; the same TF is selected here without relying on that leftover state.
tf_index = num_tfs - 1

kwargs = {'label':'Samples'}
plt.plot(τ, np.mean(f_samples[-400:, tf_index], axis=0), c='grey', alpha=1, **kwargs)

plt.scatter(τ[common_indices], data.f_obs[tf_index], marker='x', s=60, linewidth=2, color='tab:blue', label='Observed')

# HPD:
bounds = arviz.hpd(f_samples[-400:, tf_index, :], credible_interval=0.95)

plt.fill_between(τ, bounds[:, 0], bounds[:, 1], color='grey', alpha=0.3, label='95% credibility interval')
plt.xticks(t)
fig.axes[0].set_xticklabels(t)
plt.xlabel('Time / hr')
plt.legend()
plt.tight_layout()
