In [None]:
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')
from tensorflow import math as tfm
import tensorflow_probability as tfp
from timeit import default_timer as timer

from sklearn.datasets import make_friedman2
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import DotProduct, WhiteKernel, RBF
from scipy.spatial.distance import pdist, cdist, squareform
import numpy as np
import tensorflow as tf
import gpflow
from reggae.data_loaders import load_covid, DataHolder, scaled_barenco_data
from reggae.data_loaders.artificial import artificial_dataset
from reggae.mcmc import create_chains, MetropolisHastings, Parameter
from reggae.utilities import discretise, logit, LogisticNormal, jitter_cholesky, inverse_positivity, logistic
from reggae.plot import plotters
from reggae.models import TranscriptionLikelihood, Options, TranscriptionMixedSampler
from reggae.models.results import GenericResults, SampleResults
from tensorflow_probability import distributions as tfd

import matplotlib.pyplot as plt
from IPython.display import HTML
import os

# Point matplotlib's animation writer at a locally installed static ffmpeg build,
# but only when that binary is actually present. On any other machine the
# rcParam keeps matplotlib's default instead of a path that does not exist.
_LOCAL_FFMPEG = 'C:\\Users\\Jacob\\Documents\\ffmpeg-static\\bin\\ffmpeg.exe'
if os.path.isfile(_LOCAL_FFMPEG):
    plt.rcParams['animation.ffmpeg_path'] = _LOCAL_FFMPEG

# Shorthand for building float64 scalars that mix cleanly with float64 tensors.
f64 = np.float64
# %load_ext tensorboard
# %tensorboard --logdir logs/reggae

In [None]:
# Size of the simulated problem.
num_genes = 12
num_tfs = 3

# Seed TensorFlow globally (and per-op below) before drawing the weights.
tf.random.set_seed(1)
w = tf.random.normal(
    shape=[num_genes, num_tfs],
    mean=0.5,
    stddev=0.71,
    dtype='float64',
    seed=42,
)

# One delay entry per transcription factor, all zero here.
Δ_delay = tf.constant([0, 0, 0], dtype='float64')

# One zero bias weight per gene.
w_0 = tf.zeros(num_genes, dtype='float64')

# Ground-truth kinetic parameters: a hard-coded table of 22 rows x 4 columns that is
# passed through reggae's `logistic` helper and then cut down to the first
# `num_genes` rows.
# NOTE(review): the meaning of the four columns is not visible in this notebook —
# confirm against `artificial_dataset` before relying on a particular ordering.
true_kbar = logistic((np.array([
    [1.6319434062, 1.3962113525, 0.8245041865, 2.2684353378],
    [1.7080045137, 3.3992868747, 2.0189033658, 3.7460822389],
    [2.4189525448, 1.8480506624, 0.6805040228, 3.1039094120],
    [1.7758426875, 0.1907625023, 0.1925539427, 1.8306885751],
    [1.7207442227, 0.1252089546, 0.6297333943, 3.2567248923],
    [1.4878806850, 3.8623843570, 2.4816128746, 4.3931294404],
    [2.5853079514, 2.5115446790, 0.6560607356, 3.0945313562],
    [1.6144843688, 1.8651409657, 0.7785363895, 2.6845058360],
    [1.4858223122, 0.5396687493, 0.5842698019, 3.0026805243],
    [1.6610647522, 2.0486340884, 0.9863876546, 1.4300094581],
    [1.6027276189, 1.4320302060, 0.7175033248, 3.2151637970],
    [2.1912882714, 2.7935526605, 1.2438786874, 4.3944794204],
    [1.3894114279, 1.4726280947, 0.7356719860, 2.2316019158],
 [1.7927833839, 1.0405867396, 0.4055775218, 2.9888350247],
 [1.0429721112, 0.1011544950, 0.7330443670, 3.1936843755],
 [1.2519286771, 2.0617880701, 1.0759649567, 3.9406060364],
 [1.4297185709, 1.3578824015, 0.6037986912, 2.6512418604],
 [1.9344878813, 1.4235867760, 0.8226320338, 4.2847217252],
 [1.4325562449, 1.1940752177, 1.0556928599, 4.1850449557],
 [0.8911103971, 1.3560009300, 0.5643954823, 3.4300182328],
 [1.0269654997, 1.0788097511, 0.5268448648, 4.4793299593],
 [0.8378220502, 1.8148234459, 1.0167440138, 4.4903387696]]
)))
# Keep one row per simulated gene.
true_kbar = true_kbar[:num_genes]
# Model options used for the artificial (simulated) experiment.
opt = Options(
    preprocessing_variance=False,
    tf_mrna_present=True,
    kinetic_exponential=True,
    weights=True,
    initial_step_sizes={'logistic': 0.00001, 'latents': 6},
    delays=True,
)

# Simulate a dataset from the weights, delays and kinetics defined above.
data, fbar, kinetics = artificial_dataset(
    opt,
    num_genes=num_genes,
    weights=(w, w_0),
    delays=Δ_delay.numpy(),
    true_kbar=true_kbar,
)
# The generator hands back the kinetics it used; from here on `true_kbar`
# refers to that returned value rather than the table defined earlier.
true_kbar, true_k_fbar = kinetics
f_i = inverse_positivity(fbar)

# Time grids held by the dataset, and their lengths.
t, τ = data.t, data.τ
common_indices = data.common_indices.numpy()
N_p, N_m = τ.shape[0], t.shape[0]

def expand(x):
    """Return `x` as an array with a new leading axis of length 1."""
    return np.asanyarray(x)[np.newaxis, ...]
# Bundle the ground-truth quantities into a SampleResults, giving each array a
# leading axis of length 1 via `expand`.
# NOTE(review): the arguments (including the three `None`s) are positional, so they
# depend on SampleResults' field order in reggae.models.results — confirm.
# NOTE(review): Δ_delay is passed without `expand`, unlike the other arrays —
# verify this is intended.
true_results = SampleResults(opt, expand(fbar), expand(true_kbar), expand(true_k_fbar), Δ_delay, 
                             None, expand(logistic(w)), expand(logistic(w_0)), None, None)
# Sampler built from the simulated data and the options above.
model = TranscriptionMixedSampler(data, opt)


In [None]:
# Compare predicted mRNA for the simulated weights against a copy in which the
# first gene's weight row has been replaced by hand.
# NOTE(review): Δ_nodelay is defined here but not used in this cell; both
# predictions below are made with Δ_delay.
Δ_nodelay = tf.constant([0, 0, 0], dtype='float64')
w_ = w.numpy()
w_[0] = np.array([0.0065790198, 0.00473748, 1.00084])
# NOTE(review): this prints row 0 of the original tensor `w`, not of the edited
# copy `w_` — confirm which one was meant.
print(w[0])
m_pred = model.likelihood.predict_m(true_kbar, true_k_fbar, logistic(w), fbar, logistic(w_0), Δ_delay)
m_pred_ = model.likelihood.predict_m(true_kbar, true_k_fbar, logistic(w_), fbar, logistic(w_0), Δ_delay)
print(m_pred.shape)

# Overlay the two predictions at index [0, 0].
plt.plot(m_pred[0, 0])
plt.plot(m_pred_[0, 0])

In [None]:
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
tf.enable_v2_behavior()

# Rebinds `tfd` (already imported at the top of the notebook) to the same module.
tfd = tfp.distributions

# Toy target: a correlated 2-D Gaussian in float32.
dtype = np.float32
true_mean = dtype([0, 0])
true_cov = dtype([[1, 0.5],
                  [0.5, 1]])
num_results = 100
num_chains = 2

# Target distribution is defined through the Cholesky decomposition `L`:
L = tf.linalg.cholesky(true_cov)
target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=L)

# Initial state of the chain
init_state = np.ones([num_chains, 2], dtype=dtype)

# Build two transition kernels over the same target: a No-U-Turn sampler and a
# Random Walk Metropolis sampler (default proposal). Both are run below for
# `num_results` iterations on `num_chains` chains.
nuts = tfp.mcmc.NoUTurnSampler(
        target_log_prob_fn=target.log_prob,
        step_size=1,
        seed=54)
rwm = tfp.mcmc.RandomWalkMetropolis(
        target_log_prob_fn=target.log_prob,
        seed=54)

# Samples per sampler, in the order [rwm, nuts].
sampless = list()

# The timer brackets the whole loop, so the printed time covers BOTH samplers.
start = timer()

for sampler in [rwm, nuts]:
    print(sampler)
    # The traced acceptance flags are discarded (`_`); only the samples are kept.
    samples, _ = tfp.mcmc.sample_chain(
        num_results=num_results,
        current_state=init_state,
        kernel=sampler,
        num_burnin_steps=50,
        trace_fn=lambda _, pkr: pkr.is_accepted,
        num_steps_between_results=1,  # Thinning.
        parallel_iterations=1)

    sampless.append(samples)
end = timer()
print(f'Time taken: {(end - start):.04f}s')

In [None]:
import numpy as np
from scipy.interpolate import griddata
import matplotlib.pyplot as plt
import numpy.ma as ma
from numpy.random import uniform, seed
from matplotlib import cm
def gauss(x,y,Sigma,mu):
    """Evaluate an unnormalised Gaussian bump at the points (x[i], y[i]).

    Args:
        x, y: 1-D arrays of equal length n holding the point coordinates.
        Sigma: (2, 2) covariance-like matrix; its inverse defines the quadratic form.
        mu: length-2 numpy array, the centre of the bump.

    Returns:
        1-D array of length n with exp(-(p - mu)^T Sigma^{-1} (p - mu)) per point.

    The quadratic form is evaluated per point with einsum. The earlier version
    formed the full n x n product and then kept only its diagonal, which costs
    O(n^2) time and memory for the same n values.

    NOTE(review): the exponent is -q rather than -q/2, so the level sets are
    those of a Gaussian with covariance Sigma/2 — confirm this is intended.
    """
    X = np.vstack((x, y)).T
    centred = X - mu[None, ...]
    quad_form = np.einsum('ij,jk,ik->i', centred, np.linalg.inv(Sigma), centred)
    return np.exp(-1 * quad_form)

def plot_countour(x,y,z, title):
    """Draw black contour lines of scattered data (x, y, z) on the current axes.

    The scattered values are interpolated (cubic) onto a regular 100 x 100 grid
    over [-2.1, 2.1]^2, contoured at fixed levels, and the view is clipped to
    [-2, 2] on both axes.
    """
    contour_levels = [0.2, 0.4, 0.6, 0.8, 1.0]

    # Regular grid the scattered data is resampled onto.
    grid_x = np.linspace(-2.1, 2.1, 100)
    grid_y = np.linspace(-2.1, 2.1, 100)
    grid_z = griddata((x, y), z, (grid_x[None, :], grid_y[:, None]), method='cubic')

    # Contour the gridded data.
    plt.contour(grid_x, grid_y, grid_z, len(contour_levels),
                levels=contour_levels, colors='k', linewidths=0.5)

    plt.xlim(-2, 2)
    plt.ylim(-2, 2)
    plt.title(title)
plt.style.use('ggplot')

# Samples from the earlier sampling cell: index 0 is Random Walk Metropolis
# (thinned here to every 4th draw), index 1 is NUTS.
rwm = sampless[0][::4, :, :]
nuts = sampless[1]
samplist = [rwm, nuts]
titles = ['', '']
print(rwm.shape, nuts.shape)

# Scattered evaluation points for the target contours. These do not depend on
# the sampler, so they are drawn once (with a fixed seed) instead of once per plot.
seed(1234)
npts = 1000
x = uniform(-2, 2, npts)
y = uniform(-2, 2, npts)
z = gauss(x, y, Sigma=np.asarray(true_cov), mu=np.asarray(true_mean))

# One figure per sampler. The extra `plt.figure` that used to precede this loop
# only produced an empty figure and has been removed.
for i in range(2):
    plt.figure(figsize=(5, 3))

    # Average over the chain axis (axis 1) and plot the resulting path.
    samp = tf.reduce_mean(samplist[i], axis=1)
    plt.plot(samp[:, 0], samp[:, 1], c='slategrey')
    plot_countour(x, y, z, titles[i])


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D

# Our 2-dimensional distribution will be over variables X and Y
N = 60
# Regular N x N evaluation grid: X spans [-3, 3], Y spans [-3, 4].
X, Y = np.meshgrid(np.linspace(-3, 3, N), np.linspace(-3, 4, N))

# Mean vector and covariance matrix
mu = np.array([0., 1.])
Sigma = np.array([[ 1. , -0.5], [-0.5,  1.5]])

# Stack the two coordinate grids along a new last axis: pos[i, j] = (X[i, j], Y[i, j]).
pos = np.stack((X, Y), axis=-1)

def multivariate_gaussian(pos, mu, Sigma):
    """Evaluate the multivariate normal density N(mu, Sigma) at every point in pos.

    pos packs the coordinates x_1, ..., x_k into its last axis; any leading
    axes are preserved in the result.
    """
    dim = mu.shape[0]
    centred = pos - mu
    precision = np.linalg.inv(Sigma)

    # Squared Mahalanobis distance (x-mu)^T Sigma^{-1} (x-mu), vectorised over
    # all leading axes of pos.
    mahalanobis_sq = np.einsum('...k,kl,...l->...', centred, precision, centred)

    # Normalising constant sqrt((2*pi)^dim * |Sigma|).
    norm_const = np.sqrt((2*np.pi)**dim * np.linalg.det(Sigma))

    return np.exp(-mahalanobis_sq / 2) / norm_const

# The distribution on the variables X, Y packed into pos.
Z = multivariate_gaussian(pos, mu, Sigma)

# Create a surface plot and projected filled contour plot under it.
fig = plt.figure()
# `fig.gca(projection='3d')` is no longer accepted by current Matplotlib
# releases; `add_subplot` is the supported way to create a 3D axes.
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, Z, rstride=3, cstride=3, linewidth=1, antialiased=True,
                cmap=cm.viridis)

cset = ax.contourf(X, Y, Z, zdir='z', offset=-0.15, cmap=cm.viridis)
# NOTE(review): a 1000-pt-wide red line will cover much of the axes — confirm
# this is intentional and not leftover debugging.
plt.axhline(0.3, color='red', linewidth=1000)
# Adjust the limits, ticks and view angle
ax.set_zlim(-0.15,0.2)
ax.set_zticks(np.linspace(0,0.2,5))
ax.view_init(27, -21)

plt.show()


In [None]:
# Scratch version of a coordinate-wise (Gibbs-style) update: each entry of the
# state is resampled in turn from a discrete distribution over 0..10 using
# inverse-transform sampling on the cumulative probabilities.
current_state = tf.ones(3, dtype='float64')
num_tfs = current_state.shape[0]
new_state = current_state
# The same candidate values 0..10, once as numpy (for the Python loop) and once
# as a tensor (for indexing the chosen value).
Δrange = np.arange(0, 10+1, dtype='float64')
Δrange_tf = tf.range(0, 10+1, dtype='float64')

for i in range(num_tfs):
    # Generate normalised cumulative distribution
    probs = list()
    # One-hot mask selecting the coordinate being resampled.
    mask = np.zeros((num_tfs, ), dtype='float64')
    mask[i] = 1

    for Δ in Δrange:
        # Candidate state with coordinate i replaced by Δ. It is computed but
        # not used by the placeholder score on the next line.
        test_state = (1-mask) * new_state + mask * Δ

        # Placeholder log-score 1 + i*Δ (the trailing comment hints at a prior term).
        probs.append(f64(1+i*Δ)) #+ tf.reduce_sum(self.prior.log_prob(Δ)))

    # Subtract the maximum before exponentiating, then normalise to sum to 1.
    probs =  tf.stack(probs) - tfm.reduce_max(probs)
    probs = tfm.exp(probs)
    probs = probs / tfm.reduce_sum(probs)
    cumsum = tfm.cumsum(probs)
    # tf.print(cumsum)
    u = np.random.uniform()
    # Locate the smallest cumulative value strictly greater than u.
    # NOTE(review): if no entry of cumsum exceeds u the selection is empty and
    # index[0][0] will fail — confirm whether that can occur through rounding.
    index = tf.where(cumsum == tf.reduce_min(cumsum[(cumsum - u) > 0]))
    chosen = Δrange_tf[index[0][0]]
    # Write the sampled value into coordinate i only.
    new_state = (1-mask) * new_state + mask * chosen
    print(cumsum)
print()


In [None]:
# Zero-mean normal priors from which the initial network weights and biases are
# drawn (standard deviation 0.1 for weights, 1.0 for biases).
weight_prior = tfd.Normal(0.0, 0.1)
bias_prior = tfd.Normal(0.0, 1.0)  # near-uniform

def get_initial_state(weight_prior, bias_prior, num_features, layers=None):
    """Generate a starting point for a Markov chain over the weights and
    biases of a fully connected neural network.

    Arguments:
        weight_prior -- object with `.sample(shape)`; draws each weight matrix
        bias_prior -- object with `.sample(shape)`; draws each bias vector
        num_features {int} -- size of the input layer
    Keyword Arguments:
        layers {tuple or list} -- number of nodes in each layer after the
            input layer; the last entry must be 2. When None, the default is
            (num_features // 2, num_features // 5, num_features // 10, 2).
    Returns:
        list -- [weights_0, biases_0, weights_1, biases_1, ...], one pair per
            layer transition.

    The caller's `layers` is never modified: the input size is prepended to a
    fresh tuple, so tuples (as documented) and lists are both accepted.
    """
    if layers is None:
        # NOTE(review): for small num_features the integer divisions can yield
        # zero-width hidden layers (e.g. 4 // 5 == 0) — confirm this is acceptable.
        layers = (
            num_features,
            num_features // 2,
            num_features // 5,
            num_features // 10,
            2,
        )
    else:
        # make sure the last layer has two nodes, so that output can be split into
        # predictive mean and learned loss attenuation (see https://arxiv.org/abs/1703.04977)
        # which the network learns individually
        assert layers[-1] == 2
        layers = (num_features, *layers)

    architecture = []
    for n_in, n_out in zip(layers[:-1], layers[1:]):
        weights = weight_prior.sample((n_in, n_out))
        biases = bias_prior.sample(n_out)
        architecture.extend((weights, biases))
    return architecture

# Draw an initial state for a network with 4 input features and inspect it.
initial_state = get_initial_state(weight_prior, bias_prior, 4)

# These prints previously referenced `state`, which no visible cell defines;
# they now inspect the value computed on the line above.
print(len(initial_state))
print(initial_state[0].shape)

In [None]:
import gpflow
def plotkernelsample(k, ax, xmin=-15, xmax=15, num_points=100, num_samples=3):
    """Plot random functions drawn from a zero-mean GP prior with kernel `k`.

    Args:
        k: callable kernel; `k(xx)` must return the (num_points, num_points)
            covariance matrix for the column vector of inputs `xx`.
        ax: axes-like object providing `plot` and `set_title`.
        xmin, xmax: range of the evenly spaced input grid.
        num_points: number of grid points (default 100, as before).
        num_samples: number of functions to draw (default 3, as before).

    Draws from numpy's global random state, so seed it beforehand for a
    reproducible figure. The axes title is the kernel's class name.
    """
    xx = np.linspace(xmin, xmax, num_points)[:, None]
    K = k(xx)
    draws = np.random.multivariate_normal(np.zeros(num_points), K, num_samples)
    ax.plot(xx, draws.T)
    ax.set_title(k.__class__.__name__)


# Fixed seed so the sampled functions are the same on every run.
np.random.seed(27)
f, axes = plt.subplots(nrows=2, ncols=1, sharex=True, sharey=True, figsize=(12, 6))

# One panel per kernel, top to bottom.
rbf_kernel = gpflow.kernels.RBF()
arccos_kernel = gpflow.kernels.ArcCosine(0, weight_variances=1.0, bias_variance=0.5)
for kern, axis in zip((rbf_kernel, arccos_kernel), axes):
    plotkernelsample(kern, axis)


## Sample from F kernel

In [None]:
# load_covid yields a (dataframe, array) pair for each of m and f, plus the time vector.
m_observed, f_observed, t = load_covid()
m_df, m_observed = m_observed
f_df, f_observed = f_observed

# NOTE(review): an earlier comment here gave the shape of m_observed as
# (replicates, genes, times), yet axis 0 is read as the gene count and axis 1
# as the number of time points — confirm which is right.
num_genes = m_observed.shape[0]
N_m = m_observed.shape[1]

# Finer time grid τ and the indices it shares with the observed times.
τ, common_indices = discretise(t, num_disc=10)
N_p = τ.shape[0]

time = (t, τ, tf.constant(common_indices))
data = DataHolder((m_observed, f_observed), None, time)

opt = Options(
    preprocessing_variance=False,
    tf_mrna_present=True,
    delays=False,
    initial_step_sizes={'nuts': 0.00005, 'fbar': 0.01},
    kernel='rbf',
)

model = TranscriptionMixedSampler(data, opt)
# np.set_printoptions(formatter={'float': lambda x: "{0:0.5f}".format(x)})
print(N_p)

In [None]:
# Column vector of discretised times and the squared norm of each row.
X = τ.reshape((-1,1))
X_norm = np.sum(X ** 2, axis = -1)

def kf(v, l):
    """Kernel matrix v * exp(-d2 / l) over the rows of the module-level X,
    where d2 holds the pairwise squared Euclidean distances."""
    sq_dists = X_norm[:, None] + X_norm[None, :] - 2 * np.dot(X, X.T)
    return v * np.exp(-(f64(1)/l) * sq_dists)

K = kf(1, 10)
plt.imshow(K)

# Invert K after adding 1 to its diagonal, then check that the inverse admits a
# Cholesky factorisation (the factor itself is not kept).
unit_diag = tf.linalg.diag(1*np.ones(N_p, dtype='float64'))
iK = tf.linalg.inv(K + unit_diag)
tf.linalg.cholesky(iK)

def add_diag(A, B):
    """Return A plus a diagonal matrix built from the diagonal of B
    (off-diagonal entries of B are ignored)."""
    return A + tf.linalg.diag(tf.linalg.diag_part(B))


In [None]:
# Kernel callable obtained from the model's kernel selector.
kernel = model.kernel_selector()

@tf.function
def run():
    """Scratch experiment: propose new kernel hyperparameters, build a latent
    proposal from the Cholesky factors of two precision-like matrices, and
    print the intermediate quantities.

    Takes no arguments and has no return statement; everything it computes is
    local and only surfaces through print / tf.print.
    """
    step_size = 0.1 *tf.ones((N_p), dtype='float64')

    # All-zero starting state of shape (1, 2, N_p).
    current_state =  0*tf.ones((1, 2, N_p), dtype='float64')

    # Diagonal matrix with the step sizes on its diagonal.
    S = tf.linalg.diag(step_size)

    # Current kernel: both hyperparameters are length-2 vectors (1s and 5s).
    kernel_params = (f64(1)*tf.ones(2, dtype='float64'), f64(5)*tf.ones(2, dtype='float64'))
    _, K = kernel(*kernel_params)#kf(*kernel_params)

    # Small diagonal jitter before inversion.
    K = K+tf.linalg.diag(1e-7*tf.ones(N_p, dtype='float64'))
    print(K.shape)
    # Propose new params
    v = model.kernel_selector.proposal(0, tf.ones(2, dtype='float64')).sample()
    l2 = model.kernel_selector.proposal(1, 2*tf.ones(2, dtype='float64')).sample()
    m_, K_ = kernel(v, l2)
    K_ = K_+tf.linalg.diag(1e-7*tf.ones(N_p, dtype='float64'))

    # current_state = tf.zeros(N_p, dtype='float64')#tfd.MultivariateNormalFullCovariance(0, K).sample()
    new_state = tf.identity(current_state)

    # plt.plot(current_state)
    # plt.figure()
    fbar = current_state
    fstar = tf.zeros_like(fbar)
    for r in range(1):
        fbar = new_state[r]

        iK = tf.linalg.inv(K)
        iK_ = tf.linalg.inv(K_)

        # add_diag keeps only the diagonal of its second argument, so the
        # off-diagonal entries of 1/S (divisions by zero) never enter the sum.
        U_invR = tf.linalg.cholesky(add_diag(iK, 1/S))
        U_invR = tf.transpose(U_invR, [0, 2, 1])
        # The proposed kernel's factor goes through jitter_cholesky instead of
        # a plain Cholesky.
        U_invR_ = jitter_cholesky(add_diag(iK_, 1/S))
        U_invR_ = tf.transpose(U_invR_, [0, 2, 1])

        # Auxiliary draw centred on the current latent values.
        gg = tfd.MultivariateNormalDiag(fbar, step_size).sample()
        print('gg', gg.shape)
        print('U_invR', U_invR.shape)

        Sinv_g = gg / step_size

        # f = tf.zeros((self.num_tfs, 1), dtype='float64')
        nu = tf.linalg.matvec(U_invR, fbar) - tf.squeeze(tf.linalg.solve(tf.transpose(U_invR, [0, 2, 1]), tf.expand_dims(Sinv_g, -1)), -1)
        f = tf.linalg.solve(U_invR_, tf.expand_dims(nu, -1)) + tf.linalg.cholesky_solve(tf.transpose(U_invR_, [0, 2, 1]), tf.expand_dims(Sinv_g, -1))
        f = tf.squeeze(f, -1)
            # mask = np.zeros((self.num_tfs, 1), dtype='float64')
            # mask[i] = 1
            # f = (1-mask) * f + mask * f_i

        # NOTE(review): the mask has leading size 3 while new_state has leading
        # size 1, so this product broadcasts to a leading size of 3 — confirm intended.
        mask = np.zeros((3, 1, 1), dtype='float64')
        mask[r] = 1
        new_state = (1-mask) * new_state + mask * f
    hyp = [v, l2]
    for i in range(2):
        tf.print(f.shape)
        # Placeholder "probabilities": plain sums rather than log-densities.
        prob = tf.reduce_sum(new_state[i])
        other_prob = tf.reduce_sum(nu)
        tf.print(prob, other_prob)
        is_accepted = tf.less(other_prob, prob)
        return_value = tf.zeros((1, 2, N_p), dtype='float64')
        tf.print(is_accepted)
        # The comparison result above is overwritten by this random acceptance test.
        is_accepted = tf.random.uniform((1,), dtype='float64') < tf.math.minimum(f64(1), prob)
        tf.print('here', is_accepted[0])
        if not is_accepted[0]:
            hyp[0] = f64(3)
            tf.print('hello')
            mask = np.zeros((1, 2, 1), dtype='float64')
            mask[:, i] = 1
            test_state = (1-mask) * current_state + mask * new_state
            return_value = tf.ones(10, dtype='float64')
        tf.print(return_value)
run()

# Plot the two rows of new_state[0] against τ, each in its own figure.
# NOTE(review): `run` has no return statement and its `new_state` is a local
# variable, so the `new_state` read here is whatever module-level value an
# earlier cell left behind — confirm this is the intended data.
plt.figure()
print(new_state.shape)
plt.plot(τ, new_state[0, 0])
plt.figure()
plt.plot(τ, new_state[0, 1])


In [None]:
# One-hot mask over two hyperparameter slots, selecting index 1.
hyp_mask = np.zeros((2,), dtype='float64')
hyp_mask[1] =1
# Zeroes the masked slot of `v`.
# NOTE(review): in the visible cells `v` is only assigned inside `run`, so this
# line relies on a module-level `v` defined elsewhere — verify.
(1-hyp_mask) * v

In [None]:
# Scratch Gibbs-style proposal for the latent function; the length 155 is hard-coded.
step_size = 1 *tf.ones((155), dtype='float64')

# Untransformed tf mRNA vectors F (Step 1)
# `old_probs` is created here but not used further in this cell.
old_probs = list()
new_state = tf.ones((1, 1, 155), dtype='float64')
current_state = 0.1*tf.ones((1, 1, 155), dtype='float64')

S = tf.linalg.diag(step_size)

kernel_params = (f64(10), f64(5))
# Only K is used below; the first returned value `m` is not.
m, K = model.kernel_selector()(*kernel_params)
for r in range(1):
    # Gibbs step
    fbar = current_state[r]
    # Auxiliary draw centred on the current latent values.
    z_i = tfd.MultivariateNormalDiag(fbar, step_size).sample()
    fstar = tf.zeros_like(fbar)

    for i in range(1):
        invKsigmaK = tf.matmul(tf.linalg.inv(K[i]+tf.linalg.diag(step_size)), K[i]) # (K_i + S)^{-1} K_i
        # Factor of K_i - K_i (K_i + S)^{-1} K_i.
        L = jitter_cholesky(K[i]-tf.matmul(K[i], invKsigmaK))
        c_mu = tf.matmul(z_i[i, None], invKsigmaK)
        nu = tf.random.normal((1, L.shape[0]), dtype='float64')
        # Factor times a standard-normal draw, shifted by c_mu.
        fstar_i = tf.linalg.matvec(L, nu) + c_mu # 0.5 1.5
        mask = np.zeros((1, 1), dtype='float64')
        mask[i] = 1
        fstar = (1-mask) * fstar + mask * fstar_i

    mask = np.zeros((1, 1, 1), dtype='float64')
    mask[r] = 1
    # Computed but not used further in this cell.
    test_state = (1-mask) * new_state + mask * fstar
    plt.plot(inverse_positivity(fstar[0]))



### LogisticNormal distribution

In [None]:
from reggae.utilities import LogisticNormal
class LogisticNormal():
    """Logit-normal distribution rescaled from (0, 1) to the interval (a, b).

    Inputs are mapped to the unit interval with (x - a) / (b - a) and then
    scored by a `tfd.LogitNormal(loc, scale)`.

    NOTE(review): neither method applies the 1 / (b - a) change-of-variables
    factor, so the values are densities in the unit-interval variable — confirm
    this is intended where absolute values matter.
    """

    def __init__(self, a, b, loc=f64(0), scale=f64(1), allow_nan_stats=True):
        """Store the interval bounds and build the underlying LogitNormal."""
        self.a = a
        self.b = b
        self.dist = tfd.LogitNormal(loc, scale, allow_nan_stats=allow_nan_stats)

    def _to_unit_interval(self, x):
        """Map x from (a, b) onto (0, 1)."""
        return (x-self.a)/(self.b-self.a)

    def log_prob(self, x):
        """Log-density at x; NaN results are replaced by a floor of -100."""
        log_prob = self.dist.log_prob(self._to_unit_interval(x))
        log_prob = tf.where(
            tf.math.is_nan(log_prob),
            -1e2*tf.ones([], log_prob.dtype),
            log_prob)

        return log_prob

    def prob(self, x):
        """Density at x; NaN results are replaced by 0.

        A density cannot be negative, so the -100 floor that is appropriate for
        log_prob must not be reused here.
        """
        prob = self.dist.prob(self._to_unit_interval(x))
        prob = tf.where(
            tf.math.is_nan(prob),
            tf.zeros([], prob.dtype),
            prob)

        return prob
# Evaluate the density on a grid that extends beyond the distribution's (-1, 1)
# interval on both sides, and plot it with the y-axis limited to [-4, 4].
x = np.linspace(-2, 2, 100)
# x = logit(x)
dist = LogisticNormal(-1, 1)
y = dist.prob(x)
plt.ylim(-4, 4)
plt.plot(x, y)

### Inverse transform sampling

In [None]:
from scipy.stats import norm
import numpy as np
from matplotlib import pyplot as plt

# Standard normal CDF on [-3, 3].
x = np.linspace(-3, 3, 100)
y = norm.cdf(x)

# Uniform draw to invert, and the point where the CDF reaches it. The marker
# position is computed from `u` so the two cannot drift apart if `u` changes.
u = 0.68
x_u = norm.ppf(u)

fig = plt.figure()
plt.plot(x, y)
plt.scatter(x_u, u, color='red', marker='.', s=120)
plt.axhline(u, color='grey')
plt.xlabel('x')
plt.ylabel('CDF')



In [None]:
# Fit a GP regression model with a Matern-5/2 kernel to (X, y).
# NOTE(review): X and y come from earlier cells (module-level state) — verify
# that they have matching lengths when the notebook is run top to bottom.
y = y.reshape(-1, 1)
k = gpflow.kernels.Matern52()
m = gpflow.models.GPR(data=(X, y), kernel=k, mean_function=None)
# Rebinds `opt`, which earlier cells use for the reggae Options object.
opt = gpflow.optimizers.Scipy()
def objective_closure():
    # Negative log marginal likelihood, i.e. the quantity to minimise.
    return - m.log_marginal_likelihood()

# Run at most 20 optimiser iterations over the model's trainable variables.
opt_logs = opt.minimize(objective_closure,
                        m.trainable_variables,
                        options=dict(maxiter=20))



In [None]:
# Evaluate the fitted model's latent-function prediction at the training inputs.
m.predict_f(X)

In [None]:
# Rebuild a squared-exponential kernel matrix by hand to compare with a fitted
# sklearn regressor's kernel.
# NOTE(review): `gpr` is not defined in any visible cell — it presumably was a
# fitted GaussianProcessRegressor from a deleted cell; verify before running.
print(gpr.kernel_(X).shape, gpr.kernel_)#K = self.kernel_(self.X_train_)
print (X.shape, gpr.X_train_.shape)

# Condensed vector of pairwise squared Euclidean distances.
dists = pdist(X / 1, metric='sqeuclidean')
K = np.exp(-.5 * dists)
print(K.shape)
# convert from the condensed distance vector to a square matrix
K = squareform(K)
print(K.shape)
# Set the diagonal to 1 (the value of exp(0)).
np.fill_diagonal(K, 1)


In [None]:
# Experiment: build pairwise time grids with broadcasting/tiling and compare a
# triple loop against a vectorised expression.
# With t = [0, 2, 4, 6]: t_dist[i, j] = t[j] - t[i], t_[i, j] = t[i] and
# t_prime[i, j] = t[j].
# Note that this cell rebinds the module-level names `t` and `m`.
t = tf.cast(tf.range(4)*2, tf.float64)
t_dist = tf.expand_dims(t, axis=0) - tf.expand_dims(t, axis=1)
t_ = tf.transpose(tf.reshape(tf.tile(t, [4]), [ 4, tf.shape(t)[0]]))
t_prime = tf.reshape(tf.tile(t, [4]), [ 4, tf.shape(t)[0]])

# Not used further in this cell.
D = tf.ones(4)

print(t)
print(t_dist)
print(t_)
print(t_prime-t_)

m = [1,2,3,4]
# Accumulate mk * (t'-t) + t' over every mk in m, for all pairs (t, t').
# The loop variables rebind `t_` and `t_prime` from matrices to scalars.
result = np.zeros((4, 4))
for i, t_ in enumerate(t):
    for j, t_prime in enumerate(t):
        for mk in m:
            result[i, j] += mk * (t_prime - t_) + t_prime
    
print('Result 1:')
print(result)
print()
print('Result 2:')

# NOTE(review): this is not the same quantity as Result 1. The list `m`
# broadcasts along the last axis and `add[i, j]` equals t[i], so entry [i, j] is
# m[j] * (t[j] - t[i]) + t[i], whereas Result 1 sums over all of m and adds t[j].
add = tf.transpose(tf.reshape(tf.tile(t, [4]), [ 4, tf.shape(t)[0]]))
result = np.zeros((4, 4))
result += m*t_dist + add
print(result)

In [None]:
import numpy as np
# Squared-exponential kernel value for a single pair of times. Only the last
# expression of the cell is displayed, so this value is computed and discarded.
t = 3
tprime = 3
l = 2
np.exp(-((t-tprime)**2)/(l**2))

# Column vector of four time points; the final line selects the third row.
times   =   np.array([2.0,4.0, 6.0, 8.0])[:,None]
times.shape
times[2]

In [None]:
# Sizes read as module-level globals by the Kern class defined below:
# time points per gene and number of genes.
num_times=3
num_genes=2
from gpflow.utilities import print_summary, positive
from tensorflow_probability import bijectors as tfb
from tensorflow import math as tm
import math
# pi as a float64 tensor so it combines with float64 parameters.
PI = tf.constant(math.pi, dtype='float64')

class Kern(gpflow.kernels.Kernel):
    """Block-structured covariance between `num_genes` outputs, each observed at
    `num_times` time points.

    The block for genes (j, k) is
        S_j * S_k * l * sqrt(pi) / 2 * [h_kj(t', t) + h_jk(t, t')]
    with
        h_kj(t', t) = exp(gamma_k^2) / (D_j + D_k) *
            { exp(-D_k (t' - t)) [erf((t' - t)/l - gamma_k) + erf(t/l + gamma_k)]
              - exp(-(D_k t' + D_j t)) [erf(t'/l - gamma_k) + erf(gamma_k)] }
    and gamma_k = D_k * l / 2. This is the covariance of the outputs of first-order
    linear ODEs driven by a shared squared-exponential GP.

    Two terms previously departed from that expression and are corrected here:
    the prefactor was exp(gamma_k)**2 (i.e. exp(2*gamma_k)) instead of
    exp(gamma_k**2), and the second exponential used D_j instead of D_j * t.

    NOTE(review): the time grid is generated internally as 0, 2, 4, ... by
    `get_distance_matrix`; the values in X passed to `K` are not used, only its
    length. `num_times` and `num_genes` are module-level globals.
    NOTE(review): D and S are presumably per-gene decay rates and sensitivities —
    confirm against the model this kernel is meant to implement.
    """

    def __init__(self):
        super().__init__(active_dims=[0])
        self.lengthscale = gpflow.Parameter(1.0, transform=positive())
        # Constrain D and S to (0, 3): the chain applies the sigmoid first
        # (giving (0, 1)) and then the affine map (scale 3, shift 0).
        affine = tfb.AffineScalar(shift=tf.cast(0., tf.float64),
                                  scale=tf.cast(3.-0., tf.float64))
        sigmoid = tfb.Sigmoid()
        logistic = tfb.Chain([affine, sigmoid])

        # One D per gene; the first is held fixed at 0.8.
        self.D = [gpflow.Parameter(0.7, transform=logistic, dtype=tf.float64) for _ in range(num_genes)]
        self.D[0].trainable = False
        self.D[0].assign(0.8)

        # One S per gene; the first is held fixed at 1.
        self.S = [gpflow.Parameter(0.7, transform=logistic, dtype=tf.float64) for _ in range(num_genes)]
        self.S[0].trainable = False
        self.S[0].assign(1)

    def K(self, X, X2=None):
        """Assemble the full covariance matrix block by block.

        With X2 None, returns the (len(X), len(X)) matrix whose (j, k) block is
        `k_xx(j, k)`. Otherwise builds a (len(X), num_times) cross-covariance.
        """
        block_size = num_times
        if X2 is None:
            shape = [X.shape[0],X.shape[0]]
            K_xx = tf.zeros(shape, dtype='float64')
            for j in range(num_genes):
                for k in range(num_genes):
                    # mask is 0 inside block (j, k) and 1 everywhere else.
                    mask = np.ones(shape)
                    mask[j*block_size:(j+1)*block_size, 
                         k*block_size:(k+1)*block_size] = 0

                    # Zero-pad the block so that it sits at its position in
                    # the full matrix.
                    pad_top = j*block_size
                    pad_left = k*block_size
                    pad_right = 0 if k == num_genes-1 else shape[0]-block_size-pad_left
                    pad_bottom = 0 if j == num_genes-1 else shape[0]-block_size-pad_top
                    other = tf.pad(self.k_xx(j, k),
                                   tf.constant([
                                       [pad_top,pad_bottom],
                                       [pad_left,pad_right]
                                   ]), 'CONSTANT'
                                  )
                    # Keep what is already assembled outside the block and
                    # write the padded block inside it.
                    K_xx = K_xx * mask + other * (1 - mask)


            return K_xx
        else:
            print('K not none K_xf\n')
            shape = [X.shape[0],num_times]
            K_xf = tf.zeros(shape, dtype='float64')
            for j in range(num_genes):
                mask = np.ones(shape)
                other = np.zeros(shape)
                mask[j*block_size:(j+1)*block_size] = 0
                # NOTE(review): `k_xf` is not defined in this class, so unless
                # it is inherited this branch raises AttributeError — verify.
                other[j*block_size:(j+1)*block_size] = self.k_xf(j, X)

                K_xf = K_xf * mask + other * (1-mask) 
            return K_xf


    def gamma(self, k):
        """gamma_k = D_k * lengthscale / 2."""
        return self.D[k]*self.lengthscale/2

    def h(self, k, j, tprime, t):
        """Scalar h_kj(t', t) from the class docstring, for one pair of times."""
        l = self.lengthscale

        # exp(gamma_k^2): the square applies to gamma, not to the exponential.
        multiplier = tm.exp(self.gamma(k)**2) / (self.D[j]+self.D[k])
        first_erf_term = tm.erf((tprime-t)/l - self.gamma(k)) + tm.erf(t/l + self.gamma(k))
        second_erf_term = tm.erf(tprime/l - self.gamma(k)) + tm.erf(self.gamma(k))
        # The second exponential is exp(-(D_k t' + D_j t)).
        return multiplier * (tm.exp(-self.D[k]*(tprime-t)) * first_erf_term - \
                             tm.exp(-self.D[k]*tprime-self.D[j]*t) * second_erf_term)


    def h_quick(self, k, j, primefirst=True):
        """Vectorised h_kj over the whole (num_times, num_times) grid.

        With primefirst=False the roles of t and t' are exchanged, which yields
        the h_jk(t, t') term needed by `k_xx`.
        """
        l = self.lengthscale
        t_prime, t_, t_dist = self.get_distance_matrix(primefirst=primefirst, size=num_times)
            
        # exp(gamma_k^2): the square applies to gamma, not to the exponential.
        multiplier = tm.exp(self.gamma(k)**2) / (self.D[j]+self.D[k])
        first_erf_term = tm.erf(t_dist/l - self.gamma(k)) + tm.erf(t_/l + self.gamma(k))
        second_erf_term = tm.erf(t_prime/l - self.gamma(k)) + tm.erf(self.gamma(k))
        
        # The second exponential is exp(-(D_k t' + D_j t)).
        return multiplier * (tf.multiply(tm.exp(-tm.multiply(self.D[k],t_dist)) , first_erf_term) - \
                             tf.multiply(tm.exp(-tm.multiply(self.D[k],t_prime)-tm.multiply(self.D[j],t_)) , second_erf_term))
    

    def k_xx(self, j, k):
        '''k_xx(t, tprime): the (num_times, num_times) covariance block for genes j and k.'''
        mult = self.S[j]*self.S[k]*self.lengthscale*0.5*tm.sqrt(PI)
        return mult*(self.h_quick(k, j) + self.h_quick(j, k, primefirst=False))

    def k_xx_(self, j, k):
        '''k_xx(t, tprime): loop-based reference for `k_xx`, filled entry by entry and printed.'''
        k_xx = np.zeros((num_times, num_times))
        for tprime in range(num_times):
            for t in range(num_times):
                mult = self.S[j]*self.S[k]*self.lengthscale*0.5*tm.sqrt(PI)
                # Grid indices are doubled to match the 0, 2, 4, ... time grid.
                k_xx[t,tprime] = mult*(self.h(k, j, tprime*2, t*2) + self.h(j, k, t*2, tprime*2))
        print(k_xx)
        return k_xx

    def get_distance_matrix(self, primefirst=True, size=7):
        """Return (t_prime, t_, t_prime - t_) as (size, size) grids over the
        times 0, 2, ..., 2*(size - 1).

        With primefirst=True, t_prime varies along columns and t_ along rows;
        with primefirst=False the two are swapped.
        """
        t = tf.cast(tf.range(size)*2, tf.float64)
        t_ = tf.transpose(tf.reshape(tf.tile(t, [size]), [ size, tf.shape(t)[0]]))
        t_prime = tf.reshape(tf.tile(t, [size]), [ size, tf.shape(t)[0]])
        if not primefirst:
            t_prime = tf.transpose(tf.reshape(tf.tile(t, [size]), [ size, tf.shape(t)[0]]))
            t_ = tf.reshape(tf.tile(t, [size]), [ size, tf.shape(t)[0]])

        return t_prime, t_, t_prime-t_


In [None]:
# Time points 0, 2, 4, ... repeated once per gene and laid out as one flat vector.
X = np.tile(np.arange(num_times, dtype='float64') * 2, num_genes)
print(X)

# Build the kernel and show its full covariance matrix.
k = Kern()
display(k.K(X))

In [None]:
array([[0.        , 0.19252085, 0.03952372, 0.        , 0.15377981, 0.03846892],
       [0.19252085, 1.09337319, 0.3065648 , 0.13141296, 0.78650799, 0.24865445],
       [0.03952372, 0.3065648 , 1.34475164, 0.02697853, 0.21804878, 0.98498977],
       [0.        , 0.13141296, 0.02697853, 0.        , 0.10543729, 0.02637576],
       [0.15377981, 0.78650799, 0.21804878, 0.10543729, 0.56618178, 0.17712899],
       [0.03846892, 0.24865445, 0.98498977, 0.02637576, 0.17712899, 0.72327064]])>
6, 6