## Custom Metropolis-Hastings MCMC Algorithm

In [None]:
from IPython.display import display
import matplotlib.pyplot as plt

from reggae.data_loaders import load_barenco_puma, load_3day_dros, DataHolder
from reggae.mcmc import create_chains, MetropolisHastings
from reggae.models import transcription
from reggae.utilities import get_rbf_dist, exp, mult, discretise

import numpy as np
import arviz
from multiprocessing import Pool

np.set_printoptions(formatter={'float': lambda x: "{0:0.5f}".format(x)})
plt.style.use('ggplot')
%matplotlib inline

In [None]:
# Load the PUMA-processed Barenco data. For the 3-day Drosophila data use
# `load_3day_dros()`, which returns only (m_observed, f_observed, t).
m_observed, f_observed, σ2_m_pre, σ2_f_pre, t = load_barenco_puma()

replicate = 0  # which experimental replicate to fit

# m_observed / f_observed arrive as (frame, values) pairs; m_df.index supplies
# the gene names used in the plots.
m_df, m_observed = m_observed
f_df, f_observed = f_observed

# Shape of m_observed = (replicates, genes, times); keep a single replicate.
m_observed = m_observed[replicate]
f_observed = np.atleast_2d(f_observed[replicate])
# Use the same replicate for the preprocessing variances as for the observations
# (previously hard-coded to [0], which silently mismatched for replicate != 0).
σ2_m_pre = σ2_m_pre[replicate]
σ2_f_pre = σ2_f_pre[replicate]

num_genes = m_observed.shape[0]
# τ is the discretised time grid; τ[common_indices] are the grid points at which
# the data are plotted below.
τ, common_indices = discretise(t)
N_p = τ.shape[0]
N_m = m_observed.shape[1]

# The plots scatter the observations at τ[common_indices], so each observed time
# point must map to exactly one grid point.
assert τ[common_indices].shape[0] == N_m, 'common_indices does not match the mRNA time points'
assert f_observed.shape[1] == N_m, 'f_observed and m_observed have different time points'

observations = (m_observed, f_observed)
noise_data = (σ2_m_pre, σ2_f_pre)
time = (t, τ, common_indices)

data = DataHolder(observations, noise_data, time)
# Model options (`opt`) are chosen in the sampling cells, so they are not set here.

# MCMC settings
T = 50              # presumably the number of MCMC iterations
store_every = 1
burn_in = 0
report_every = 20
num_chains = 4
tune_every = 50

In [None]:
# Model options for this run: the TF mRNA is not treated as observed.
opt = transcription.Options(preprocessing_variance=True, tf_mrna_present=False)

# Launch `num_chains` chains (previously hard-coded to 4, ignoring the setting).
job = create_chains(
    transcription.TranscriptionMCMC,
    [data, opt],
    {
        'T': T,
        'store_every': store_every,
        'burn_in': burn_in,
        'report_every': report_every,
        'tune_every': tune_every,
    },
    num_chains=num_chains,
)

# NOTE(review): results are fetched later with `.get()`, so create_chains is
# presumably asynchronous; 'Done' then only means the chains were launched.
print('Done')

In [None]:
# Show every object returned by create_chains, one per line.
for chain_job in job:
    print(chain_job)

In [None]:
# Begin MCMC: run a single chain in this process.
# The model is built here so the cell works on a fresh kernel; the constructor was
# previously commented out, leaving `transcription_model` undefined. It uses the
# `opt` defined in the create_chains cell.
transcription_model = transcription.TranscriptionMCMC(data, opt)
transcription_model.sample(T, store_every, burn_in, report_every)

print(transcription_model.acceptance_rates)

samples = transcription_model.samples


In [None]:
# Fetch the first chain's result (`.get()` presumably waits for it to finish).
first_chain_result = job[0].get()
print(first_chain_result)


The proposal step size is the standard deviation of the proposal distribution. If it is too small, the chain takes a long time to reach high-density regions; if it is too large, many proposals are rejected.

## Plots

In [None]:
# Acceptance-rate traces: one panel per sampled parameter, three panels per row.
parameter_names = list(transcription_model.acceptance_rates.keys())
acc_rates = samples['acc_rates']
# NOTE(review): assumes acc_rates[i] is the trace for the i-th key of
# transcription_model.acceptance_rates — confirm the two orders agree.

# ceil(n / 3) rows. Previously there was one row per parameter, which squashed
# all panels into the top third of the figure.
n_rows = max(1, -(-len(parameter_names) // 3))
plt.figure(figsize=(10, 3 * n_rows))
for i, name in enumerate(parameter_names):
    plt.subplot(n_rows, 3, i + 1)
    plt.plot(acc_rates[i])
    plt.title(name)
plt.tight_layout()

In [None]:
# Trace plots of δbar, L and V (NOTE(review): L and V are presumably the RBF
# lengthscale and variance — confirm).
fig, axes = plt.subplots(1, 3, figsize=(10, 3))
for ax, param in zip(axes, ['δbar', 'L', 'V']):
    ax.plot(samples[param].get())
    ax.set_title(param)
fig.tight_layout()

# Interaction weight traces, one line per gene. Fetch the samples once rather
# than once per gene.
w_samples = samples['w'].get()
plt.figure()
for j in range(num_genes):
    plt.plot(w_samples[:, j], label=m_df.index[j])
plt.legend()
plt.title('Interaction weights')

# Interaction bias traces, labelled the same way as the weights.
w_0_samples = samples['w_0'].get()
plt.figure()
for j in range(num_genes):
    plt.plot(w_0_samples[:, j], label=m_df.index[j])
plt.legend()
plt.title('Interaction bias')


### Transcription ODE kinetic parameters


In [None]:
# Trace plots of the transcription ODE kinetic parameters (a, b, d, s), one panel
# per gene.
N_TRACE = 20000  # show at most this many of the most recent samples
labels = ['a', 'b', 'd', 's']
kbar_samples = samples['kbar'].get()  # fetched once instead of once per gene

n_rows = -(-num_genes // 2)  # ceil(num_genes / 2): two panels per row
fig = plt.figure(figsize=(14, 3.5 * n_rows))
# suptitle, because plt.title before creating subplots titles a throwaway axes.
fig.suptitle('Transcription ODE kinetic parameters')
for j in range(num_genes):
    ax = fig.add_subplot(n_rows, 2, j + 1)
    k_param = kbar_samples[:, j][-N_TRACE:]
    for k, label in enumerate(labels):
        ax.plot(k_param[:, k], label=label)
    # Reference line at the mean of the plotted 's' samples.
    ax.axhline(np.mean(k_param[:, 3]), color='k', linestyle='--', label='mean s')
    ax.legend()
    ax.set_title(f'Gene {j} ({m_df.index[j]})')

fig.tight_layout()


In [None]:
plot_barenco = True  # overlay the reference estimates in the kinetics and f plots


def plot_kinetics(kbar, plot_barenco=False, gene_names=None):
    """Bar-chart the posterior basal rates, sensitivities and decay rates per gene.

    Parameters
    ----------
    kbar : numpy.ndarray
        Kinetic-parameter samples indexed ``kbar[sample, gene, k]`` with ``k``
        ordered (a, b, d, s). The values are exponentiated here, so they are
        stored in log space. Only the last 100 samples are summarised.
    plot_barenco : bool
        If True, draw rescaled reference values (from the Martino paper) next to
        the model estimates. The reference values are hard-coded for 5 genes.
    gene_names : sequence of str, optional
        Tick labels; defaults to ``m_df.index`` from the data-loading cell.

    Raises
    ------
    ValueError
        If ``plot_barenco`` is True and ``kbar`` does not describe exactly 5 genes.
    """
    # Average in log space, then exponentiate (the geometric mean of the rates).
    k_latest = np.exp(np.mean(kbar[-100:], axis=0))
    num_genes = k_latest.shape[0]  # derived locally rather than from a global
    if gene_names is None:
        gene_names = m_df.index

    basal = k_latest[:, 1]
    decay = k_latest[:, 2]
    sensitivity = k_latest[:, 3]
    estimates = [basal, sensitivity, decay]
    references = [None, None, None]

    if plot_barenco:
        if num_genes != 5:
            raise ValueError(
                f'Reference kinetics are only available for 5 genes, got {num_genes}')
        # From Martino paper ... do a rough rescaling so that the scales match.
        # NOTE(review): the permutation maps the paper's gene order onto this
        # dataset's order — confirm against m_df.index.
        order = [0, 4, 2, 3, 1]

        def rescale(values, target):
            """Rescale ``values`` so its mean equals the mean of ``target``."""
            return values / np.mean(values) * np.mean(target)

        references = [
            rescale(np.array([2.6, 1.5, 0.5, 0.2, 1.35])[order], basal),
            rescale((np.array([3, 0.8, 0.7, 1.8, 0.7]) / 1.8)[order], sensitivity),
            rescale((np.array([1.2, 1.6, 1.75, 3.2, 2.3]) * 0.8 / 3.2)[order], decay),
        ]

    titles = ['Basal rates', 'Sensitivities', 'Decay rates']
    x = np.arange(num_genes)
    fig, axes = plt.subplots(1, 3, figsize=(14, 5))
    # Distinct loop names so the basal-rate array is not shadowed.
    for ax, estimate, reference, title in zip(axes, estimates, references, titles):
        ax.bar(x - 0.2, estimate, width=0.4, label='Model')
        if reference is not None:
            ax.bar(x + 0.2, reference, width=0.4, color='blue', label='Reference')
            ax.legend()
        # Ticks centred between the paired bars, not under the model bar only.
        ax.set_xticks(x)
        ax.set_xticklabels(gene_names, rotation=45, ha='right')
        ax.set_title(title)
    fig.tight_layout()

# Summarise the kinetic-parameter samples as per-gene bar charts.
kbar = samples['kbar'].get()
plot_kinetics(kbar, plot_barenco=plot_barenco)

In [None]:
# Observed mRNA (crosses) against the model's predicted profile (grey), per gene.
m_pred = transcription_model.predict_m()
obs_positions = np.arange(N_p)[common_indices]  # grid positions of the observations

n_rows = -(-num_genes // 3)  # ceil(num_genes / 3); works for any number of genes
fig = plt.figure(figsize=(14, 3.5 * n_rows))
for j in range(num_genes):
    ax = fig.add_subplot(n_rows, 3, j + 1)
    ax.set_title(m_df.index[j])
    ax.scatter(obs_positions, m_observed[j], marker='x')
    ax.plot(m_pred[j, :], color='grey')
    ax.set_xticks(obs_positions)
    # One label per tick: the observation times. The old np.arange(t[-1]) had a
    # different length from the ticks, which mislabels the axis (and raises in
    # recent matplotlib).
    ax.set_xticklabels(t)
    ax.set_xlabel('Time (h)')

fig.tight_layout()

In [None]:
# Traces of the sampled mRNA noise variances σ2_m, one panel per gene.
# NOTE(review): with preprocessing_variance=True σ2_m may not be sampled — confirm
# that samples['σ2_m'] exists for the chosen options.
σ2_m_samples = samples['σ2_m'].get()  # fetched once instead of once per gene

# Three panels per row. The old `num_genes - 2` columns broke for <= 2 genes.
n_rows = -(-num_genes // 3)
fig = plt.figure(figsize=(12, 4 * n_rows))
fig.suptitle('Noise variances')
for j in range(num_genes):
    ax = fig.add_subplot(n_rows, 3, j + 1)
    ax.set_title(m_df.index[j])
    ax.plot(σ2_m_samples[:, j])

fig.tight_layout()

In [None]:
def scaled_barenco_data(f):
    """Return the hard-coded Barenco reference profile rescaled to the spread of ``f``.

    The table below holds three 7-point rows (presumably replicates); only the
    first row is used. It is divided by its own standard deviation and multiplied
    by the standard deviation of all entries of ``f``, so the result has the same
    spread as ``f``.

    Returns
    -------
    tuple
        ``(barencof, measured_p53)``: the rescaled 7-point profile, and a
        placeholder ``0`` for the measured p53 level (not currently computed).
    """
    reference_table = np.array([
        [0.0, 200.52011, 355.5216125, 205.7574913, 135.0911372, 145.1080997, 130.7046969],
        [0.0, 184.0994134, 308.47592, 232.1775328, 153.6595161, 85.7272235, 168.0910562],
        [0.0, 230.2262511, 337.5994811, 276.941654, 164.5044287, 127.8653452, 173.6112139],
    ])
    first_row = reference_table[0]
    target_spread = np.std(f)
    rescaled = first_row / np.std(first_row) * target_spread
    return rescaled, 0

def plot_f(f):
    """Plot a profile ``f`` on the discretised grid against the scaled Barenco reference.

    ``f`` is drawn at grid positions 0..N_p-1. The reference points are drawn at
    the observation positions ``common_indices`` and labelled with the observation
    times ``t``. Relies on the notebook globals ``N_p``, ``common_indices`` and ``t``.
    """
    fig = plt.figure(figsize=(13, 7))

    # scaled_barenco_data returns (profile, measured_p53). Previously the whole
    # tuple was passed to scatter; only the profile belongs on the plot.
    barencof, _ = scaled_barenco_data(f)
    obs_positions = np.arange(N_p)[common_indices]

    plt.plot(np.arange(N_p), f, color='grey')
    plt.scatter(obs_positions, barencof, marker='x')
    plt.xticks(obs_positions)
    # Label ticks with the actual observation times rather than assuming a 2 h spacing.
    fig.axes[0].set_xticklabels(t)
    plt.xlabel('Time (h)')
    


In [None]:
# Posterior of f = softplus(fbar): recent samples, 95% HPD band, and observations.
fig, ax = plt.subplots(figsize=(13, 7))

# softplus(x) = log(1 + exp(x)); logaddexp(0, x) computes it without overflowing
# for large x.
f_samples = np.logaddexp(0, np.array(samples['fbar'].get()[-50:]))

if 'σ2_f' in transcription_model.params._fields:
    σ2_f = transcription_model.params.σ2_f.value
else:
    σ2_f = σ2_f_pre

# 95% highest-density interval per grid point, shape (N_p, 2). arviz replaced
# hpd(credible_interval=...) with hdi(hdi_prob=...); support both.
if hasattr(arviz, 'hdi'):
    # Add a leading chain axis so arviz reads the array as (chain, draw, time).
    bounds = arviz.hdi(f_samples[np.newaxis], hdi_prob=0.95)
else:
    bounds = arviz.hpd(f_samples, credible_interval=0.95)

# Overlay the 19 most recent samples, newest first (as before). Slicing also
# copes with fewer than 19 samples.
for f_i in f_samples[-19:][::-1]:
    ax.plot(τ, f_i, c='blue', alpha=0.5)

obs_times = τ[common_indices]
if plot_barenco:
    barenco_f, _ = scaled_barenco_data(np.mean(f_samples[-10:], axis=0))
    ax.scatter(obs_times, barenco_f, marker='x', s=60, linewidth=3, label='Barenco')

ax.scatter(obs_times, f_observed[0], marker='x', s=60, linewidth=3, label='Observed')
ax.errorbar(obs_times, f_observed[0], 2*np.sqrt(σ2_f[0]), fmt='none', capsize=5, color='blue')

ax.fill_between(τ, bounds[:, 0], bounds[:, 1], color='grey', alpha=0.5, label='95% HPD')
# Tick each observation where it is plotted and label it with its time (previously
# ticks were hard-coded at 0, 2, 4, ... regardless of t).
ax.set_xticks(obs_times)
ax.set_xticklabels(t)
ax.set_xlabel('Time (h)')
ax.set_title('Posterior samples of softplus(fbar)')
ax.legend();