In [None]:
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')

from tensorflow import math as tfm
from tensorflow_probability import distributions as tfd
import tensorflow_probability as tfp
from reggae.data_loaders import scaled_barenco_data
from scipy.interpolate import interp1d
from sklearn import preprocessing
from reggae.utilities import discretise, logit, logistic, inverse_positivity
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from reggae.models import TranscriptionLikelihood, Options, TranscriptionMixedSampler
from reggae.data_loaders import load_barenco_puma, DataHolder
from reggae.data_loaders.artificial import artificial_dataset

np.set_printoptions(formatter={'float': lambda x: "{0:0.5f}".format(x)})
plt.style.use('ggplot')
%matplotlib inline
f64 = np.float64


In [None]:
# Transcription factor: a 13-knot profile (knot i <-> hour i, 0..12 h) labelled
# 'Gao et al.', cubic-interpolated onto 100 points and compared against the
# Barenco et al. estimate at the 2-hourly observation times.
A = np.array([0, 0.9, 1.9, 2.6, 2.7, 2.15, 1.3, 0.5, 0.15, 0.2, 0.6, 0.95, 1.1])
interp = interp1d(np.arange(A.shape[0]), A, kind='cubic')
A = interp(np.linspace(0, A.shape[0] - 1, 100))
f_b, _ = scaled_barenco_data(A)

obs_times = np.arange(0, 13, 2)  # 0, 2, ..., 12 h -> 7 observation times
# One Barenco value is expected per observation time; fail early with a clear message.
assert np.size(f_b) == obs_times.size, (
    f'expected {obs_times.size} Barenco points, got shape {np.shape(f_b)}')

plt.figure(figsize=(6, 4))
plt.plot(np.linspace(0, 12, 100), A, c='grey', label='Gao et al.')
# NOTE(review): 0.85 is a hand-chosen visual rescaling of the Barenco values — document its origin.
plt.scatter(obs_times, 0.85 * f_b, marker='x', color='cadetblue', label='Barenco et al.')
plt.xlabel('Time (h)')
plt.ylabel('Abundance (AU)')
plt.ylim(-0.2, 3.1)
plt.legend();

In [None]:
# Transcription factor: a second 13-knot profile (knot i <-> hour i, 0..12 h),
# cubic-interpolated onto 100 points and compared against the Barenco et al.
# estimate at the 2-hourly observation times (legend intentionally not shown).
A = np.array([0, 0.15, 0.5, 0.95, 1.25, 1.15, 0.9, 0.61, 0.5, 0.45, 0.44, 0.48, 0.5])
interp = interp1d(np.arange(A.shape[0]), A, kind='cubic')
A = interp(np.linspace(0, A.shape[0] - 1, 100))
f_b, _ = scaled_barenco_data(A)

obs_times = np.arange(0, 13, 2)  # 0, 2, ..., 12 h -> 7 observation times
# One Barenco value is expected per observation time; fail early with a clear message.
assert np.size(f_b) == obs_times.size, (
    f'expected {obs_times.size} Barenco points, got shape {np.shape(f_b)}')

plt.figure(figsize=(6, 4))
plt.plot(np.linspace(0, 12, 100), A, c='grey', label='Gao et al.')
plt.scatter(obs_times, f_b, marker='x', color='cadetblue', label='Barenco et al.')
plt.xlabel('Time (h)')
plt.ylim(-0.2, 1.5)
plt.ylabel('Abundance (AU)')
plt.tight_layout()

In [None]:
# Candidate per-gene kinetic parameter table, 4 values per row.
# The first 13 rows are given directly as values in (0, 1) and mapped through
# `logit`; the next 13 rows are positive values mapped through log(1 + x).
# The stacked (26, 4) array is then passed through `logistic`.
ktest_unit = np.array([
    [0.50563, 0.66, 0.893, 0.9273],
    [0.6402, 0.6335, 0.7390, 0.7714],
    [0.6202, 0.6935, 0.7990, 0.7114],
    [0.5328, 0.5603, 0.6498, 0.9244],
    [0.5328, 0.6603, 0.6798, 0.8244],
    [0.5939, 0.5821, 0.77716, 0.8387],
    [0.50, 0.68, 0.75716, 0.8587],
    [0.58, 0.67, 0.57, 0.95],
    [0.5553, 0.5734, 0.6462, 0.9068],
    [0.5750, 0.5548, 0.6380, 0.7347],
    [0.5373, 0.5277, 0.6319, 0.8608],
    [0.5372, 0.5131, 0.8000, 0.9004],
    [0.5145, 0.5818, 0.6801, 0.9129]])
ktest_positive = np.array([
    [1.319434062, 1.3962113525, 0.8245041865, 2.2684353378],
    [1.3080045137, 3.3992868747, 2.0189033658, 3.7460822389],
    [2.0189525448, 1.8480506624, 0.6805040228, 3.1039094120],
    [1.7758426875, 0.1907625023, 0.1925539427, 1.8306885751],
    [1.7207442227, 0.1252089546, 0.6297333943, 3.2567248923],
    [1.4878806850, 3.8623843570, 2.4816128746, 4.3931294404],
    [2.0853079514, 2.5115446790, 0.6560607356, 3.0945313562],
    [1.6144843688, 1.8651409657, 0.7785363895, 2.6845058360],
    [1.4858223122, 0.5396687493, 0.5842698019, 3.0026805243],
    [1.6610647522, 2.0486340884, 0.9863876546, 1.4300094581],
    [1.6027276189, 1.4320302060, 0.7175033248, 3.2151637970],
    [2.4912882714, 2.7935526605, 1.2438786874, 4.3944794204],
    [2.894114279, 1.4726280947, 0.7356719860, 2.2316019158]])
# np.log(1 + x) (not np.log1p) is kept so the values match the original exactly.
ktest = logistic(np.concatenate([logit(ktest_unit), np.log(1 + ktest_positive)], axis=0))

In [None]:
# Simulate an artificial dataset: `num_genes` target genes driven by `num_tfs`
# TFs with random interaction weights `w` (offsets `w_0` all zero) and one
# delay per TF, then build the sampler on it.
num_genes = 20
num_tfs = 3
tf.random.set_seed(1)
w = tf.random.normal([num_genes, num_tfs], mean=0.5, stddev=0.71, seed=42, dtype='float64')

Δ_delay = tf.constant([0, 4, 10], dtype='float64')

w_0 = tf.zeros(num_genes, dtype='float64')

opt = Options(preprocessing_variance=False,
              tf_mrna_present=True,
              kinetic_exponential=True,
              weights=True,
              initial_step_sizes={'logistic': 1e-8, 'latents': 10},
              delays=True)

# `true_kbar` is only assigned after this call (from `kinetics`), so it cannot be
# used as the input here. NOTE(review): `ktest` (26 x 4, built in the cell above
# and otherwise unused) is assumed to be the intended ground-truth kinetics — confirm.
data, fbar, kinetics = artificial_dataset(opt, TranscriptionLikelihood, num_genes=num_genes,
                                          weights=(w, w_0), delays=Δ_delay.numpy(), t_end=10,
                                          true_kbar=ktest[:num_genes])
true_kbar, true_k_fbar = kinetics
f_i = inverse_positivity(fbar)
t, τ, common_indices = data.t, data.τ, data.common_indices
print(τ.shape)
common_indices = common_indices.numpy()
N_p = τ.shape[0]  # number of points on the fine grid τ
N_m = t.shape[0]  # number of observation times

model = TranscriptionMixedSampler(data, opt)


In [None]:
# Simulated mRNA for six of the genes (relabelled 0..5 in the panel titles),
# predicted with all TF delays set to zero: the grey line is the prediction on
# the fine grid τ, crosses are the same prediction at the observation times t.
plt.figure(figsize=(10, 15))

lik = model.likelihood
Δ_nodelay = tf.constant([0, 0, 0], dtype='float64')

m_pred = lik.predict_m(true_kbar, true_k_fbar, w, fbar, w_0, Δ_nodelay).numpy()[0, [3, 4, 9, 11, 13, 19]]
print(m_pred.shape)
m_observed = m_pred[:, common_indices]
print(m_observed.shape)
num = 6  # number of genes shown; also used by the cells below
for j in range(num):
    ax = plt.subplot(num, 2, 1 + j)
    plt.title(f'Gene {j}')
    plt.scatter(t, m_observed[j], marker='x')
    plt.plot(τ, m_pred[j], color='grey')
    plt.xlabel('Time (h)')
    plt.xticks(np.arange(10))
    plt.ylabel('Abundance / (AU)')

plt.tight_layout()

# TF protein without and with the per-TF delays (computed, not plotted here).
p_nodelay = lik.calculate_protein(fbar, true_k_fbar, Δ_nodelay)
p = lik.calculate_protein(fbar, true_k_fbar, Δ_delay)

# Hand-specified TF profile (18 knots); only knots 0..12 are resampled to 100 points.
A = np.array([0.01, 0.2, 0.47, 0.51, 0.35, 0.19, 0.17, 0.24, 0.37, 0.47, 0.39, 0.2, 0.05, 0.025, 0.025, 0.01, 0.005, 0.005])
interp = interp1d(np.arange(A.shape[0]), A, kind='cubic')
A = interp(np.linspace(0, 12, 100))
f_b, _ = scaled_barenco_data(A)
print(f_b.shape, np.arange(0, 13, 2).shape)

# A single figure for the TF profile (no extra empty figure is opened first).
mul = 1.5
fig = plt.figure(figsize=(mul * 3.49, mul * 1.62))
# NOTE(review): assumes τ has the same 100 points as the resampled A — confirm.
plt.scatter(t, (2 * A)[common_indices], marker='x')
plt.plot(τ, 2 * A, c='tab:blue')
plt.xticks(np.arange(10))
plt.xlabel('Time (h)')
plt.ylabel('Abundance / (AU)')
# No legend: none of the artists above carries a label.
plt.tight_layout()

In [None]:
# Illustrative per-gene kinetic parameters (rows of `var`: basal rate, decay rate,
# sensitivity) drawn as grouped bars with hand-set error bars.
# Relies on `num` (number of genes) from the simulation cell above.
var = 2*np.array([[0.4, 0.7, 0.1, 0.3, 0.6, 0.4],
                  [0.3, 0.2, 0.6, 0.4, 0.6, 0.5],
                  [1, 1.1, 0.9, 0.85, 0.5, 1.2]])
errors = np.array([0.1, 0.4, 0.05, 0.2, 0.15, 0.35])
width = 0.15
gene_idx = np.arange(num)

# One entry per bar series: (row of `var`, x offset, colour, legend label,
# tick labels or None, multiplier applied to `errors`).
bar_specs = [
    (0, -2*width, 'firebrick', 'Basal rate', gene_idx, 1.0),
    (1, 0, 'tab:blue', 'Decay rate', gene_idx, 0.5),
    (2, 2*width, 'chocolate', 'Sensitivity', None, 1.4),
]

plt.figure(figsize=(7, 4.2))
# All bars first, then all error bars, in the same order as before.
for row, offset, colour, label, ticks, _ in bar_specs:
    plt.bar(gene_idx + offset, var[row, :], width=2*width, tick_label=ticks,
            color=colour, label=label)
for row, offset, _, _, _, err_scale in bar_specs:
    plt.errorbar(gene_idx + offset, var[row, :], err_scale*errors,
                 fmt='none', capsize=5, color='black')
plt.xlabel('Gene')
plt.legend()
plt.tight_layout()

In [None]:
import networkx as nx
import random

# Star graph: one 'TF' node with an edge to each of the `num` genes.
# Edge grey level encodes the min-max normalised sensitivity (var row 2),
# node size the min-max normalised basal rate (var row 0).
G = nx.DiGraph()
random.seed(42)
np.random.seed(42)  # presumably makes nx.spring_layout reproducible (it is called without a seed)

b = var[0, :]
s = var[2, :]
s_min = np.min(s)
s_diff = np.max(s) - s_min
b_min = np.min(b)
b_diff = np.max(b) - b_min
# Guard against division by zero when every gene has the same value.
if s_diff == 0:
    s_diff = 1.0
if b_diff == 0:
    b_diff = 1.0

nodes = ['TF'] + [f'g{j}' for j in range(num)]
node_colors = ['slategrey'] + ['chocolate'] * num
sizes = [1000]  # TF node size; gene sizes are appended below
edges, colors = [], []

plt.figure(figsize=(3.8, 3.8))
for j in range(num):
    weight = (s[j] - s_min) / s_diff
    if weight == 0:
        # Floor so the lowest-sensitivity edge is not drawn white ('1.0' grey level).
        weight = 0.2
    colors.append(f'{1-weight}')  # matplotlib grey-level string: darker = more sensitive
    edges.append(('TF', f'g{j}'))
    sizes.append(int(700 + 1700 * (b[j] - b_min) / b_diff))

G.add_nodes_from(nodes)
G.add_edges_from(edges)
pos = nx.spring_layout(G)
nx.draw(G, pos=pos, edge_color=colors, node_color=node_colors, node_size=sizes, with_labels=True)


In [None]:
# Reference results copied from a MATLAB run (not computed in this notebook).
# x: 100 x 5 array, one column per gene; each column is plotted below as a
#    trajectory over 100 evenly spaced points (presumably posterior-mean mRNA —
#    TODO confirm against the MATLAB source).
x = np.array([
[0.0969 , 0.0785 ,  0.0451 , 0.0001 , 0.0800],
[0.0998 , 0.0815 ,  0.0479 , 0.0014 , 0.0828],
[0.1088 , 0.0907 ,  0.0568 , 0.0054 , 0.0917],
[0.1245 , 0.1065 ,  0.0720 , 0.0123 , 0.1068],
[0.1470 , 0.1290 ,  0.0938 , 0.0220 , 0.1286],
[0.1766 , 0.1583 ,  0.1222 , 0.0346 , 0.1571],
[0.2133 , 0.1941 ,  0.1571 , 0.0498 , 0.1921],
[0.2568 , 0.2362 ,  0.1982 , 0.0676 , 0.2337],
[0.3071 , 0.2843 ,  0.2454 , 0.0879 , 0.2814],
[0.3638 , 0.3379 ,  0.2981 , 0.1103 , 0.3350],
[0.4265 , 0.3967 ,  0.3561 , 0.1348 , 0.3940],
[0.4950 , 0.4603 ,  0.4191 , 0.1611 , 0.4583],
[0.5691 , 0.5283 ,  0.4866 , 0.1890 , 0.5274],
[0.6486 , 0.6006 ,  0.5586 , 0.2186 , 0.6013],
[0.7335 , 0.6771 ,  0.6350 , 0.2497 , 0.6799],
[0.8239 , 0.7578 ,  0.7159 , 0.2825 , 0.7633],
[0.9199 , 0.8430 ,  0.8013 , 0.3169 , 0.8517],
[1.0218 , 0.9328 ,  0.8917 , 0.3530 , 0.9452],
[1.1300 , 1.0276 ,  0.9871 , 0.3910 , 1.0442],
[1.2445 , 1.1273 ,  1.0877 , 0.4310 , 1.1487],
[1.3655 , 1.2322 ,  1.1936 , 0.4728 , 1.2589],
[1.4928 , 1.3419 ,  1.3046 , 0.5165 , 1.3746],
[1.6260 , 1.4559 ,  1.4203 , 0.5618 , 1.4952],
[1.7642 , 1.5735 ,  1.5397 , 0.6083 , 1.6201],
[1.9064 , 1.6934 ,  1.6618 , 0.6555 , 1.7481],
[2.0508 , 1.8140 ,  1.7850 , 0.7028 , 1.8776],
[2.1956 , 1.9333 ,  1.9074 , 0.7492 , 2.0067],
[2.3385 , 2.0493 ,  2.0268 , 0.7939 , 2.1333],
[2.4770 , 2.1594 ,  2.1411 , 0.8358 , 2.2550],
[2.6086 , 2.2613 ,  2.2477 , 0.8740 , 2.3695],
[2.7308 , 2.3527 ,  2.3444 , 0.9075 , 2.4743],
[2.8412 , 2.4315 ,  2.4290 , 0.9355 , 2.5672],
[2.9378 , 2.4958 ,  2.4998 , 0.9572 , 2.6465],
[3.0189 , 2.5445 ,  2.5554 , 0.9722 , 2.7107],
[3.0836 , 2.5767 ,  2.5951 , 0.9802 , 2.7589],
[3.1312 , 2.5925 ,  2.6186 , 0.9814 , 2.7908],
[3.1620 , 2.5922 ,  2.6263 , 0.9759 , 2.8066],
[3.1766 , 2.5770 ,  2.6191 , 0.9644 , 2.8072],
[3.1762 , 2.5483 ,  2.5983 , 0.9475 , 2.7939],
[3.1625 , 2.5082 ,  2.5660 , 0.9262 , 2.7684],
[3.1377 , 2.4589 ,  2.5241 , 0.9014 , 2.7328],
[3.1039 , 2.4029 ,  2.4750 , 0.8744 , 2.6895],
[3.0635 , 2.3427 ,  2.4211 , 0.8460 , 2.6407],
[3.0190 , 2.2807 ,  2.3648 , 0.8175 , 2.5888],
[2.9727 , 2.2193 ,  2.3084 , 0.7898 , 2.5362],
[2.9268 , 2.1606 ,  2.2539 , 0.7637 , 2.4849],
[2.8833 , 2.1064 ,  2.2031 , 0.7400 , 2.4367],
[2.8438 , 2.0582 ,  2.1577 , 0.7193 , 2.3933],
[2.8099 , 2.0174 ,  2.1188 , 0.7021 , 2.3560],
[2.7827 , 1.9848 ,  2.0876 , 0.6888 , 2.3259],
[2.7632 , 1.9612 ,  2.0646 , 0.6795 , 2.3038],
[2.7519 , 1.9468 ,  2.0503 , 0.6743 , 2.2900],
[2.7491 , 1.9417 ,  2.0449 , 0.6732 , 2.2850],
[2.7549 , 1.9456 ,  2.0481 , 0.6760 , 2.2884],
[2.7689 , 1.9580 ,  2.0595 , 0.6824 , 2.3000],
[2.7904 , 1.9779 ,  2.0783 , 0.6919 , 2.3189],
[2.8185 , 2.0042 ,  2.1033 , 0.7040 , 2.3441],
[2.8517 , 2.0352 ,  2.1331 , 0.7180 , 2.3742],
[2.8884 , 2.0692 ,  2.1659 , 0.7330 , 2.4075],
[2.9265 , 2.1039 ,  2.1997 , 0.7481 , 2.4420],
[2.9637 , 2.1371 ,  2.2323 , 0.7623 , 2.4754],
[2.9977 , 2.1664 ,  2.2613 , 0.7747 , 2.5055],
[3.0258 , 2.1893 ,  2.2843 , 0.7840 , 2.5297],
[3.0457 , 2.2033 ,  2.2991 , 0.7894 , 2.5458],
[3.0551 , 2.2065 ,  2.3034 , 0.7900 , 2.5516],
[3.0519 , 2.1971 ,  2.2955 , 0.7851 , 2.5452],
[3.0345 , 2.1736 ,  2.2740 , 0.7742 , 2.5252],
[3.0020 , 2.1354 ,  2.2382 , 0.7570 , 2.4907],
[2.9537 , 2.0823 ,  2.1878 , 0.7336 , 2.4414],
[2.8900 , 2.0148 ,  2.1231 , 0.7041 , 2.3776],
[2.8117 , 1.9340 ,  2.0451 , 0.6693 , 2.3002],
[2.7200 , 1.8417 ,  1.9555 , 0.6299 , 2.2106],
[2.6171 , 1.7399 ,  1.8562 , 0.5868 , 2.1109],
[2.5053 , 1.6314 ,  1.7499 , 0.5414 , 2.0035],
[2.3874 , 1.5191 ,  1.6391 , 0.4947 , 1.8912],
[2.2665 , 1.4060 ,  1.5270 , 0.4482 , 1.7769],
[2.1455 , 1.2950 ,  1.4164 , 0.4031 , 1.6634],
[2.0274 , 1.1893 ,  1.3101 , 0.3607 , 1.5537],
[1.9150 , 1.0912 ,  1.2107 , 0.3220 , 1.4504],
[1.8107 , 1.0030 ,  1.1205 , 0.2879 , 1.3558],
[1.7166 , 0.9265 ,  1.0411 , 0.2591 , 1.2716],
[1.6342 , 0.8629 ,  0.9740 , 0.2360 , 1.1994],
[1.5645 , 0.8129 ,  0.9198 , 0.2189 , 1.1398],
[1.5079 , 0.7766 ,  0.8788 , 0.2077 , 1.0933],
[1.4645 , 0.7535 ,  0.8506 , 0.2021 , 1.0596],
[1.4337 , 0.7429 ,  0.8346 , 0.2017 , 1.0382],
[1.4145 , 0.7434 ,  0.8296 , 0.2058 , 1.0278],
[1.4056 , 0.7534 ,  0.8341 , 0.2138 , 1.0272],
[1.4054 , 0.7711 ,  0.8464 , 0.2247 , 1.0347],
[1.4121 , 0.7945 ,  0.8646 , 0.2378 , 1.0485],
[1.4238 , 0.8217 ,  0.8869 , 0.2520 , 1.0666],
[1.4387 , 0.8505 ,  0.9113 , 0.2667 , 1.0872],
[1.4548 , 0.8793 ,  0.9360 , 0.2809 , 1.1086],
[1.4705 , 0.9063 ,  0.9594 , 0.2940 , 1.1291],
[1.4844 , 0.9300 ,  0.9802 , 0.3054 , 1.1473],
[1.4951 , 0.9495 ,  0.9971 , 0.3146 , 1.1619],
[1.5016 , 0.9637 ,  1.0093 , 0.3214 , 1.1721],
[1.5032 , 0.9721 ,  1.0162 , 0.3254 , 1.1772],
[1.4994 , 0.9744 ,  1.0173 , 0.3267 , 1.1768],
[1.4900 , 0.9706 ,  1.0126 , 0.3252 , 1.1708]])

# f: 100-point TF profile (same grid as x), plotted below as the central line.
f = [0.0000, 0.0230,  0.0492,  0.0782,  0.1094,  0.1424,  0.1765,  0.2114,  0.2466,  0.2819,  0.3172,  0.3527,  0.3886,  0.4253,  0.4632,  0.5027,  0.5443,  0.5882,  0.6342,  0.6821,  0.7312,  0.7805,  0.8287,  0.8741,  0.9151,  0.9498,  0.9765,  0.9938,  1.0006,  0.9961,  0.9804,  0.9539,  0.9176,  0.8730,  0.8220,  0.7668,  0.7097,  0.6530,  0.5990,  0.5494,  0.5058,  0.4694,  0.4408,  0.4206,  0.4086,  0.4046,  0.4082,  0.4187,  0.4353,  0.4573,  0.4837,  0.5136,  0.5458,  0.5791,  0.6122,  0.6436,  0.6718,  0.6951,  0.7120,  0.7210,  0.7208,  0.7104,  0.6893,  0.6573,  0.6150,  0.5632,  0.5035,  0.4377,  0.3682,  0.2974,  0.2278,  0.1621,  0.1026,  0.0512,  0.0097, -0.0209, -0.0400, -0.0475, -0.0438, -0.0300, -0.0072,  0.0228,  0.0582,  0.0971,  0.1376,  0.1777,  0.2157,  0.2502,  0.2799,  0.3040,  0.3219,  0.3332,  0.3381,  0.3368,  0.3299,  0.3179,  0.3018,  0.2824,  0.2605,  0.2371]


# b: 100-point half-width of the band drawn around f as f ± b (labelled
#    '2 x st. dev' in the plot below).
b=[0.0020,0.0830,0.1508,0.2024,0.2373,0.2559,0.2591,0.2488,0.2283,0.2028,0.1804,0.1719,0.1851,0.2179,0.2612,0.3069,0.3493,0.3846,0.4102,0.4247,0.4277,0.4196,0.4021,0.3778,0.3510,0.3271,0.3123,0.3119,0.3270,0.3545,0.3887,0.4237,0.4548,0.4784,0.4920,0.4944,0.4852,0.4650,0.4355,0.3998,0.3622,0.3290,0.3077,0.3047,0.3214,0.3534,0.3935,0.4346,0.4716,0.5006,0.5191,0.5256,0.5197,0.5016,0.4729,0.4358,0.3943,0.3535,0.3207,0.3037,0.3075,0.3307,0.3668,0.4079,0.4472,0.4801,0.5033,0.5144,0.5125,0.4971,0.4689,0.4294,0.3811,0.3286,0.2787,0.2426,0.2334,0.2564,0.3031,0.3606,0.4194,0.4732,0.5179,0.5506,0.5693,0.5728,0.5606,0.5327,0.4899,0.4340,0.3683,0.2993,0.2407,0.2194,0.2578,0.3429,0.4521,0.5724,0.6971,0.8222]
# Plot the MATLAB reference results against the observed data.
# NOTE(review): `genes`, `Y` and `tfs` are not defined anywhere in this notebook
# (presumably the Barenco data, e.g. via `load_barenco_puma`) — load them in an
# earlier cell. `t` is whatever an earlier cell left behind; confirm it holds the
# 7 observation times used as tick labels.

# Positions of the 7 observation times on the 100-point index axis of x / f.
# Ticks and both scatters use this grid so crosses sit on their time labels.
obs_x = [100/6 * i for i in range(7)]

fig = plt.figure(figsize=(14, 17))
for j in range(5):
    ax = plt.subplot(531 + j)
    plt.plot(x[:, j], color='grey')
    plt.title(genes[0].index[j])
    plt.scatter(obs_x, Y[7*j:(j+1)*7], marker='x', label='Observed')
    plt.xticks(obs_x)
    ax.axes.set_xticklabels(t)
    plt.xlabel('Time (h)')
plt.tight_layout()

# TF profile f with an f ± b band, compared with the rescaled observed TF.
f = np.array(f)
b = np.array(b)
fig, ax = plt.subplots(figsize=(7, 5))
ax.set_xticks(obs_x)
ax.set_xticklabels(t)
ax.set_ylim((-0.4, 1.4))
ax.set_xlabel('Time (h)')
ax.plot(f, c='cadetblue')
ax.fill_between(np.arange(len(f)), f + b, f - b, color='grey', alpha=0.3, label='2 x st. dev')

# Rescale the observed TF so its mean matches the mean of f. Named so that it
# does not shadow the `tf` (tensorflow) alias used by earlier cells.
tf_raw = tfs[1][0]
tf_observed = tf_raw / np.mean(tf_raw) * np.mean(f)
ax.scatter(obs_x, tf_observed, marker='x', s=70, label='Observed')
ax.legend();