In [None]:
from reggae.models import TranscriptionLikelihood, Options
from reggae.data_loaders import load_barenco_puma, load_3day_dros, DataHolder, scaled_barenco_data
from reggae.utilities import get_rbf_dist, discretise, logit, logistic, LogisticNormal
from reggae.mcmc import create_chains, MetropolisHastings, Parameter
from reggae.plot import plotters
from reggae.models.kernels import MixedKernel, FKernel, KbarKernel

from scipy.interpolate import interp1d
from sklearn import preprocessing

import tensorflow as tf
from tensorflow import math as tfm
from tensorflow_probability import distributions as tfd
import tensorflow_probability as tfp

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

# Print NumPy floats with five decimal places.
np.set_printoptions(formatter={'float': lambda x: "{0:0.5f}".format(x)})
plt.style.use('ggplot')
%matplotlib inline
# Shorthand used throughout the notebook to cast scalars/arrays to float64.
f64 = np.float64


In [None]:
# Problem dimensions for the synthetic experiment.
num_genes, num_tfs = 5, 3

# Observation times `t`, plus the grid `τ` and the `common_indices` returned by
# `discretise`; the indices are used later to pick the observed points out of
# quantities defined on the `τ` grid.
t = np.arange(10)
τ, common_indices = discretise(t)
N_m, N_p = t.shape[0], τ.shape[0]

# Time bundle handed to DataHolder; the indices are wrapped in a TF constant.
time = (t, τ, tf.constant(common_indices))

opt = Options(tf_mrna_present=True, preprocessing_variance=False)

In [None]:
# Transcription factor

# Coarse TF profiles, one value per integer time point 0..9.
A = np.array([0.01, 0.3, 0.47, 0.51, 0.4, 0.37, 0.47, 0.32, 0.16, 0.025])
B = np.array([0.01, 0.1, 0.22, 0.44, 0.53, 0.41, 0.23, 0.13, 0.05, 0.013])
C = np.array([0.01, 0.02, 0.03, 0.05, 0.08, 0.16, 0.4, 0.36, 0.23, 0.02])

# Resample every profile onto a grid with as many points as τ.
# A and B use cubic interpolation, C uses quadratic.
fine_grid = np.linspace(0, 9, τ.shape[0])
resampled = []
for coarse, spline_kind in ((A, 'cubic'), (B, 'cubic'), (C, 'quadratic')):
    interp = interp1d(np.arange(coarse.shape[0]), coarse, kind=spline_kind)
    resampled.append(interp(fine_grid))
A, B, C = resampled

δbar = logistic(f64(np.array([1.5, 1.5, 1.5])))

# Stack the profiles (one row per TF), normalise the rows and scale by 5.
fbar = 5*preprocessing.normalize(np.array([A, B, C]))
print(fbar.shape)
tf_labels = ['A', 'B', 'C']
plt.title('TFs')

# Take observations: keep only the grid points listed in common_indices.
f_observed = tf.stack([fbar[i][common_indices] for i in range(num_tfs)])

grid_positions = np.arange(τ.shape[0])
observed_positions = np.arange(N_p)[common_indices]
for i in range(num_tfs):
    plt.plot(grid_positions, fbar[i], label=f'TF {i}')
    plt.scatter(observed_positions, f_observed[i], marker='x')

plt.legend()

# Move to the unconstrained representation: log(exp(f) - 1) is the inverse of
# the softplus applied on the next line, so f_i recovers the positive profile.
fbar = tfm.log(tfm.exp(fbar) - 1)

f_i = tfm.log(1 + tfm.exp(fbar))


In [None]:
# Interaction weights and biases are held fixed for the synthetic data.
w = 1*tf.ones((num_genes, num_tfs), dtype='float64') # TODO
w_0 = tf.zeros(num_genes, dtype='float64') # TODO

# Ground-truth kinetic parameters, one row per gene, stored after `logistic`.
# NOTE(review): the column order presumably matches kbar_initial in
# TranscriptionCustom (a_j, b_j, d_j, s_j) -- confirm.
# An earlier table of values that was assigned to `true_kbar` and then
# immediately overwritten has been dropped; it never reached any computation.
true_kbar = logistic(np.array([[0.2148, 0.2192, 0.5675, 2.7892],
                               [0.3172, 0.8753, 0.9294, 1.0161],
                               [0.1133, 0.2964, 0.7675, 2.6464],
                               [0.4087, 0.4333, 1.2619, 0.9654],
                               [0.2490, 0.1860, 0.7592, 3.3507]]))
print(true_kbar)

# A likelihood built on placeholder (all-ones) observations: it is only used
# here to call predict_m with the true parameters.
temp_data = DataHolder((np.ones((num_genes, N_m)), np.ones((num_tfs, N_m))), None, time)
temp_lik = TranscriptionLikelihood(temp_data, opt)

m_pred = temp_lik.predict_m(true_kbar, δbar, w, fbar, w_0)

# Take observations: convert once, then keep the points listed in common_indices.
m_pred_values = m_pred.numpy()
m_observed = tf.stack([m_pred_values[i][common_indices] for i in range(num_genes)])

observed_positions = np.arange(N_p)[common_indices]
plt.figure(figsize=(10, 13))
for j in range(num_genes):
    # Explicit (rows, cols, index) form: identical layout to the previous
    # three-digit code for num_genes < 10, and still valid beyond that.
    plt.subplot(num_genes, 2, j + 1)
    plt.scatter(observed_positions, m_observed[j], marker='x')
    plt.title(f'Gene {j}')
    plt.plot(m_pred[j], color='grey')

plt.tight_layout()



In [None]:
# Wrap the synthetic mRNA and TF observations together with the time bundle.
data = DataHolder((m_observed, f_observed), None, time)

opt = Options(tf_mrna_present=True, preprocessing_variance=False)
lik = TranscriptionLikelihood(data, opt)

plt.title('TF Proteins')

# Protein curves computed by the likelihood from the current `fbar` and `δbar`.
p = lik.calculate_protein(fbar, δbar)

for tf_index in range(num_tfs):
    plt.plot(p[tf_index], label=f'Protein {tf_index}')
plt.legend();

In [None]:
import collections
# Containers for the model's Parameter objects.
# TupleParams_pre carries one extra trailing field, σ2_f, and is the variant
# TranscriptionCustom builds when options.preprocessing_variance is False.
TupleParams_pre = collections.namedtuple(
    'TupleParams_pre', 'fbar δbar kbar σ2_m w w_0 L V σ2_f')
TupleParams = collections.namedtuple(
    'TupleParams', 'fbar δbar kbar σ2_m w w_0 L V')

class TranscriptionCustom():
    '''
    Data is a tuple (m, f) of shapes (num, time)
    time is a tuple (t, τ, common_indices)
    '''
    def __init__(self, data: DataHolder, options: Options):
        '''
        Build the likelihood and the Parameter objects sampled by `sample`.

        Parameters
        ----------
        data : DataHolder
            Read here for `t`, `τ`, `m_obs` and `f_obs`.
        options : Options
            Read here for `tf_mrna_present` and `preprocessing_variance`.

        Notes
        -----
        The `*_log_prob` closures defined below look up `self.params`, which is
        only assigned on the last lines of this method. They are handed to
        `Parameter` as factories and are presumably not invoked until sampling;
        confirm against `Parameter` if that changes.
        '''
        self.data = data
        self.samples = None
        min_dist = min(data.t[1:]-data.t[:-1])
        self.N_p = data.τ.shape[0]
        self.N_m = data.t.shape[0]      # Number of observations

        self.num_tfs = data.f_obs.shape[0] # Number of TFs
        self.num_genes = data.m_obs.shape[0]

        self.likelihood = TranscriptionLikelihood(data, options)
        self.options = options
        # Position of each parameter group inside the MCMC state list.
        # Keys are Parameter names; `sample` uses them to index the results.
        self.state_indices = {
            'δbar': 0,
            'kbar': 1,
            'fbar': 2,
            'rbf_params': 3,
            'σ2_m': 4,
            'w': 5,
        }
        logistic_step_size = 0.00001

        # Interaction weights
        def w_log_prob(all_states):
            # NOTE(review): w_0star only enters through its prior; it is not
            # passed to the likelihood call below -- confirm this is intended.
            def w_log_prob_fn(wstar, w_0star):
                new_prob = tf.reduce_sum(self.likelihood.genes(
                    all_states=all_states,
                    state_indices=self.state_indices,
                    w=wstar))
                new_prob += tf.reduce_sum(self.params.w.prior.log_prob(wstar))
                new_prob += tf.reduce_sum(self.params.w_0.prior.log_prob(w_0star))
                return tf.reduce_sum(new_prob)
            return w_log_prob_fn
        w = Parameter('w', tfd.Normal(f64(0), f64(2)),
                      1*tf.ones((self.num_genes, self.num_tfs), dtype='float64'),
                      step_size=0.04, hmc_log_prob=w_log_prob, requires_all_states=True)
        w_0 = Parameter('w_0', tfd.Normal(f64(0), f64(2)), tf.zeros(self.num_genes, dtype='float64'))

        # Latent function
        # The last argument is sized from self.N_p (derived from `data`) rather
        # than from a notebook-level global, so the class does not depend on
        # which cells happened to run before it.
        fbar_kernel = FKernel(self.likelihood,
                              self.fbar_prior_params,
                              self.num_tfs, self.num_genes,
                              self.options.tf_mrna_present,
                              self.state_indices,
                              0.1*tf.ones(self.N_p, dtype='float64'))
        fbar = Parameter('fbar', self.fbar_prior, 0.5*tf.ones((self.num_tfs, self.N_p), dtype='float64'),
                         kernel=fbar_kernel, requires_all_states=False)

        # GP hyperparameters
        def rbf_params_log_prob(all_states):
            def rbf_params_log_prob_fn(vbar, l2bar):
                v = logit(vbar, nan_replace=self.params.V.prior.b)
                l2 = logit(l2bar, nan_replace=self.params.L.prior.b)

                # fbar_prior applies its own transform, so it receives the
                # untransformed vbar / l2bar.
                new_prob = self.params.fbar.prior(all_states[self.state_indices['fbar']], vbar, l2bar)
                new_prob += self.params.V.prior.log_prob(v)
                new_prob += self.params.L.prior.log_prob(l2)
                return tf.reduce_sum(new_prob)
            return rbf_params_log_prob_fn

        # V is registered under the name 'rbf_params' and carries both
        # hyperparameter values; L only supplies a prior (its value is None).
        V = Parameter('rbf_params', LogisticNormal(f64(1e-4), f64(1+max(np.var(data.f_obs, axis=1))),allow_nan_stats=False),
                      [tf.constant(f64(0.8)), tf.constant(f64(0.98))], step_size=logistic_step_size,
                      fixed=not options.tf_mrna_present, hmc_log_prob=rbf_params_log_prob, requires_all_states=True)
        L = Parameter('L', LogisticNormal(f64(min_dist**2-0.2), f64(data.t[-1]**2), allow_nan_stats=False), None)

        self.t_dist = get_rbf_dist(data.τ, self.N_p)

        # Translation kinetic parameters
        def δbar_log_prob(all_states):
            def δbar_log_prob_fn(state):
                new_prob = tf.reduce_sum(self.likelihood.genes(
                    all_states=all_states,
                    state_indices=self.state_indices,
                    δbar=state
                ))
                new_prob += self.params.δbar.prior.log_prob(logit(state))
                return new_prob
            return δbar_log_prob_fn
        δbar = Parameter('δbar', LogisticNormal(0.1, 8), 0.6*tf.ones((self.num_tfs,), dtype='float64'), step_size=logistic_step_size,
                         hmc_log_prob=δbar_log_prob, requires_all_states=True)

        # White noise for genes
        def σ2_m_log_prob(all_states):
            def σ2_m_log_prob_fn(σ2_mstar):
                new_prob = self.likelihood.genes(
                    all_states=all_states,
                    state_indices=self.state_indices,
                    σ2_m=σ2_mstar
                ) + self.params.σ2_m.prior.log_prob(logit(σ2_mstar))
                return tf.reduce_sum(new_prob)
            return σ2_m_log_prob_fn
        # NOTE(review): the upper bound of this mRNA-noise prior is computed
        # from data.f_obs, not data.m_obs -- confirm this is intended.
        σ2_m = Parameter('σ2_m', LogisticNormal(f64(1e-5), f64(max(np.var(data.f_obs, axis=1)))), 1e-4*tf.ones(self.num_genes, dtype='float64'),
                         hmc_log_prob=σ2_m_log_prob, requires_all_states=True, step_size=logistic_step_size)

        # Transcription kinetic parameters
        def constrain_kbar(kbar, gene):
            '''Clips a given row of kbar to [-10, 3] in place and returns it.'''
            kbar[kbar < -10] = -10
            kbar[kbar > 3] = 3
            return kbar
        kbar_initial = 0.6*np.float64(np.c_[ # was -0.1
            np.ones(self.num_genes), # a_j
            np.ones(self.num_genes), # b_j
            np.ones(self.num_genes), # d_j
            np.ones(self.num_genes)  # s_j
        ])
        def kbar_log_prob(all_states):
            def kbar_log_prob_fn(kstar):
                k = logit(kstar)
                new_prob = self.likelihood.genes(
                    all_states=all_states,
                    state_indices=self.state_indices,
                    kbar=kstar,
                ) + tf.reduce_sum(self.params.kbar.prior.log_prob(k))
                return tf.reduce_sum(new_prob)
            return kbar_log_prob_fn
        for j, k in enumerate(kbar_initial):
            kbar_initial[j] = constrain_kbar(k, j)
        kbar = Parameter('kbar', LogisticNormal(0.01, 8),
                         kbar_initial,
                         hmc_log_prob=kbar_log_prob,
                         constraint=constrain_kbar, step_size=logistic_step_size, requires_all_states=True)

        # σ2_f is only created (and only has a slot in the tuple) when the
        # preprocessing variance is not used.
        if not options.preprocessing_variance:
            σ2_f = Parameter('σ2_f', tfd.InverseGamma(f64(0.01), f64(0.01)), 1e-4*np.ones(self.num_tfs), step_size=tf.constant(0.5, dtype='float64'))
            self.params = TupleParams_pre(fbar, δbar, kbar, σ2_m, w, w_0, L, V, σ2_f)
        else:
            self.params = TupleParams(fbar, δbar, kbar, σ2_m, w, w_0, L, V)
            
    def fbar_prior_params(self, vbar, l2bar):
        '''
        Mean vector and covariance matrix of the prior over the latent function.

        `vbar` and `l2bar` are first passed through `logit` (with the upper
        bound `b` of the matching prior as the NaN replacement). The covariance
        is v * exp(-t_dist**2 / (2 * l2)) plus a 1e-10 diagonal term; the mean
        is a zero vector of length N_p.
        '''
        amplitude = logit(vbar, nan_replace=self.params.V.prior.b)
        lengthscale_sq = logit(l2bar, nan_replace=self.params.L.prior.b)

        num_points = self.N_p
        mean = tf.zeros((num_points), dtype='float64')

        squared_exponential = tfm.exp(-tfm.square(self.t_dist)/(2*lengthscale_sq))
        # Small diagonal term added to the kernel matrix.
        nugget = tf.linalg.diag(1e-10 * tf.ones(num_points, dtype='float64'))
        covariance = tfm.multiply(amplitude, squared_exponential) + nugget
        return mean, covariance

    def fbar_prior(self, fbar, v, l2):
        '''
        Log-density of `fbar` under the multivariate normal prior whose mean and
        covariance come from `fbar_prior_params(v, l2)`.

        A further 1e-6 diagonal term is added to the covariance before the
        Cholesky factorisation, on top of the 1e-10 term already included by
        `fbar_prior_params`.
        '''
        mean, covariance = self.fbar_prior_params(v, l2)
        stabiliser = tf.linalg.diag(1e-6 *tf.ones(self.N_p, dtype='float64'))
        cholesky_factor = tf.linalg.cholesky(covariance + stabiliser)
        prior = tfd.MultivariateNormalTriL(loc=mean, scale_tril=cholesky_factor)
        return prior.log_prob(fbar)


    def sample(self, T=2000, store_every=10, burn_in=1000, report_every=100, num_chains=4):
        '''
        Run the mixed MCMC kernel and record the draws on the model.

        Parameters
        ----------
        T : int
            Number of results requested from `tfp.mcmc.sample_chain`.
        burn_in : int
            Number of burn-in steps passed to `tfp.mcmc.sample_chain`.
        store_every, report_every, num_chains
            Accepted for interface compatibility; not used by this method.

        Returns
        -------
        (samples, is_accepted) for this call only.

        Side effects
        ------------
        * Each active Parameter's `value` is set to its last draw.
        * `self.samples` holds the draws accumulated over every call so far
          (concatenated along the first axis); `self.is_accepted` holds the
          trace of the most recent call.
        '''
        print('----- Sampling Begins -----')

        params = self.params

        # The order here must agree with self.state_indices, because the
        # results are looked up by state_indices[param.name] below.
        active_params = [
            params.δbar,
            params.kbar,
            params.fbar,
            params.V,
            params.σ2_m,
            #params.w,
        ]
        kernels = [param.kernel for param in active_params]
        send_all_states = [param.requires_all_states for param in active_params]

        current_state = [
            params.δbar.value,
            params.kbar.value,
            params.fbar.value,
            [*params.V.value],
            params.σ2_m.value,
            #[params.w.value, params.w_0.value]
        ]
        mixed_kern = MixedKernel(kernels, send_all_states)

        def trace_fn(a, previous_kernel_results):
            return previous_kernel_results.is_accepted

        # Run the chain (with burn-in).
        @tf.function
        def run_chain():
            samples, is_accepted = tfp.mcmc.sample_chain(
                  num_results=T,
                  num_burnin_steps=burn_in,
                  current_state=current_state,
                  kernel=mixed_kern,
                  trace_fn=trace_fn)

            return samples, is_accepted

        samples, is_accepted = run_chain()

        add_to_previous = (self.samples is not None)
        for param in active_params:
            index = self.state_indices[param.name]
            param_samples = samples[index]
            # A multi-part state (e.g. the two RBF hyperparameters) comes back
            # as a sequence with one sample tensor per part.
            is_multipart = isinstance(param_samples, (list, tuple))

            if is_multipart:
                param.value = [part[-1] for part in param_samples]
            else:
                param.value = param_samples[-1]

            if add_to_previous:
                if is_multipart:
                    # Extend each part separately; concatenating the sequences
                    # directly would stack the parts into a single tensor.
                    self.samples[index] = [
                        tf.concat([previous, new], axis=0)
                        for previous, new in zip(self.samples[index], param_samples)
                    ]
                else:
                    self.samples[index] = tf.concat([self.samples[index], param_samples], axis=0)

        if not add_to_previous:
            # Store a shallow copy so that later calls, which assign into
            # self.samples, do not mutate the list returned to the caller.
            self.samples = list(samples)
        self.is_accepted = is_accepted
        print('----- Finished -----')
        return samples, is_accepted
        
            
    @staticmethod
    def initialise_from_state(args, state):
        model = TranscriptionMCMC(*args)
        model.acceptance_rates = state.acceptance_rates
        model.samples = state.samples
        return model

    def predict_m(self, kbar, δbar, w, fbar, w_0):
        return self.likelihood.predict_m(kbar, δbar, w, fbar, w_0)

    def predict_m_with_current(self):
        return self.likelihood.predict_m(self.params.kbar.value, 
                                         self.params.δbar.value, 
                                         self.params.w.value, 
                                         self.params.fbar.value,
                                         self.params.w_0.value)


In [None]:
# Build the sampler on the synthetic data and draw 400 samples without burn-in.
model = TranscriptionCustom(data=data, options=opt)
samples, is_accepted = model.sample(burn_in=0, T=400)


In [None]:
# Pull every sampled quantity out of the model by its state index.
# Note that this rebinds `fbar`, `δbar`, `w` and `w_0`, which earlier cells
# used for the ground-truth values.
state_index = model.state_indices
kbar = model.samples[state_index['kbar']]
δbar = model.samples[state_index['δbar']]
fbar = model.samples[state_index['fbar']]
print(logit(kbar[-1]))
σ2_m = model.samples[state_index['σ2_m']]
rbf_params = model.samples[state_index['rbf_params']]

# w = model.samples[model.state_indices['w']][0]
# w_0 = model.samples[model.state_indices['w']][1]

w = [1*tf.ones((num_genes, num_tfs), dtype='float64')] # TODO
w_0 = [tf.zeros(num_genes, dtype='float64')] # TODO

# Mean acceptance for the first five state entries (everything except 'w').
num_reported = min(5, len(model.state_indices))
pcs = [
    tf.reduce_mean(tf.cast(is_accepted[position], dtype=tf.float32)).numpy()
    for position in range(num_reported)
]

display(pd.DataFrame([[f'{100*pc:.02f}%' for pc in pcs]], columns=list(model.state_indices)[:-1]))

plt.plot(fbar[:, 0, 0])

# Predictions from the 19 most recent draws.
# NOTE(review): only fbar[-back][0] is passed to predict_m -- confirm the
# leading axis of each fbar draw is what predict_m expects here.
m_preds = np.array([
    model.likelihood.predict_m(kbar[-back], δbar[-back], w[-1], fbar[-back][0], w_0[-1]) #todo w[-1]
    for back in range(1, 20)
])

# Transform the traces back before plotting.
f_samples = np.log(1+np.exp(fbar))
δ_samples = logit(δbar)
k_samples = logit(kbar)
rbf_params_samples = [logit(rbf_params[part]) for part in (0, 1)]

print(true_kbar.shape)
true_k = [logit(true_kbar[:,column]).numpy() for column in (1, 2, 3)]
print(true_k)
plotters.generate_report(data, k_samples, δ_samples, f_samples,
                         σ2_m, rbf_params_samples, m_preds, plot_barenco=False, true_k=true_k)


In [None]:
# Display the ground-truth kinetic parameters as stored (i.e. after `logistic`).
true_kbar