## Repressilator: synthetic biological clock

### Importing packages

In [1]:
import numpy as np

import pyro.contrib.gp as gp
from SBIRR.SDE_solver import solve_sde_RK
from SBIRR.utils import plot_trajectories_2
import matplotlib.pyplot as plt
from SBIRR.Schrodinger import gpIPFP, SchrodingerTrain, gpdrift, nndrift
from SBIRR.GP import spatMultitaskGPModel
from SBIRR.gradient_field_NN import train_nn_gradient, GradFieldMLP2
import torch
import torch.nn as nn
import sys
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
import math
import time
import csv

from SBIRR.experiments_helper import get_settings # in ../package/SBIRR/experiments_helper.py

## Set up and load data

In [5]:
task_name = "repres"
# Load every snapshot (training and validation) and move each one to `device`.
# allow_pickle=True is needed for this saved format; only use it on trusted files.
Xs_val = np.load(f"./data/{task_name}_data.npy", allow_pickle=True)
Xs_val = [torch.tensor(X).to(device) for X in Xs_val]
n_snap = len(Xs_val)
# Train on the even-indexed snapshots (0, 2, 4, ...); odd ones stay held out.
# Slicing gives the same list as the former range(int(n_snap/2)+1) indexing when
# n_snap is odd, and no longer runs one index past the end when n_snap is even.
Xs = Xs_val[::2]
N_steps = len(Xs)
# Take the SDE settings and the reference drift families for this task.
sigma, dt, N, dts, oursnndrift, vanillanndrift = get_settings(task_name, N_steps)

In [6]:
# Seed torch's default generator (NumPy's RNG is not seeded here).
torch.random.manual_seed(0)

<torch._C.Generator at 0x7fd2aac64b30>

In [12]:
def gpIPFPmaker(ref_drift=None, N=N, device=device):
    """Construct a new ``gpIPFP`` object.

    Passed as a factory to ``SchrodingerTrain.iter_drift_fit``. ``N`` and
    ``device`` default to the notebook-level values bound when this cell runs;
    ``ref_drift`` is forwarded unchanged.
    """
    solver = gpIPFP(ref_drift=ref_drift, N=N, device=device)
    return solver
oursSchrodinger_train = SchrodingerTrain(Xs, dts, sigma)

# IRR: fit with our reference drift family (oursnndrift).
# NOTE(review): `iteration` presumably counts reference-drift refits and
# `ipfpiter` the IPFP sweeps per refit -- confirm in iter_drift_fit.
# perf_counter is monotonic and high-resolution, so the measured duration is
# unaffected by system clock adjustments (unlike time.time).
start_time = time.perf_counter()
_, backnforth = oursSchrodinger_train.iter_drift_fit(
    gpIPFPmaker,
    oursnndrift,
    ipfpiter=10,
    iteration=10,
)
end_time = time.perf_counter()
elapsed_time_irr = end_time - start_time

In [13]:
# Naive baseline: same trainer object, vanilla reference drift, one outer iteration.
# NOTE(review): reuses oursSchrodinger_train; confirm iter_drift_fit does not
# carry state over from the IRR run.
start_time = time.perf_counter()  # monotonic clock for interval timing
_, backnforth_vSB = oursSchrodinger_train.iter_drift_fit(
    gpIPFPmaker,
    vanillanndrift,
    ipfpiter=10,
    iteration=1,
)
end_time = time.perf_counter()
elapsed_time_naive = end_time - start_time

### Earth mover's distance on a held-out snapshot

In [None]:
from TrajectoryNet.optimal_transport.emd import earth_mover_distance


def _as_numpy(x):
    """Return ``x`` as a host NumPy array (torch tensors are detached and moved to CPU)."""
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def _held_out_emd(traj, snapshot):
    """EMD between validation snapshot ``snapshot`` and ``traj`` at the matching time step.

    Each trajectory uses its own time-grid length, so the two models are
    compared at the correct step even if their grids differ.
    """
    # Simulation steps per snapshot interval; assumes the grid spans all
    # len(Xs_val) snapshots evenly -- TODO confirm against dts.
    steps_per_snapshot = (traj.shape[1] - 1) // (len(Xs_val) - 1)
    # [..., :-1] drops the last column (presumably the time stamp -- confirm).
    simulated = traj[:, steps_per_snapshot * snapshot, :-1]
    # NOTE(review): assumes earth_mover_distance accepts NumPy arrays.
    return earth_mover_distance(_as_numpy(Xs_val[snapshot]), _as_numpy(simulated))


# Odd snapshots were excluded from training (Xs keeps even indices only).
i = 1
[_held_out_emd(backnforth, 2 * i + 1), _held_out_emd(backnforth_vSB, 2 * i + 1)]

### Visualizing

Visualizing the trajectories from our method (IRR)

In [None]:
# Validation snapshots together with the trajectories produced by IRR.
plt.rcParams['figure.figsize'] = [15, 10]
fig = plt.figure()
# One 3D axes only; previously an extra 2D axes was created, hidden and overlaid.
axs = fig.add_subplot(111, projection='3d')

# `cm` was never imported, and matplotlib.cm.get_cmap is removed in recent
# matplotlib releases; pyplot.get_cmap returns the same colormap.
cmap = plt.get_cmap('coolwarm')

# Plot the sample points, coloured by snapshot index. matplotlib cannot read
# CUDA tensors, so each snapshot is copied to host memory first.
for i, snapshot in enumerate(Xs_val):
    samples = snapshot.detach().cpu().numpy()
    # `color=` (not `c=`) so the RGBA tuple is never mistaken for per-point values.
    axs.scatter(samples[:, 0], samples[:, 1], samples[:, 2],
                label='validation samples at t=' + str(i), s=50,
                color=cmap(i / len(Xs_val)))

# Build the legend before drawing trajectories so it lists only the snapshots.
fig.legend(loc='center right', fontsize=15)

if torch.is_tensor(backnforth):
    paths = backnforth.detach().cpu().numpy()
else:
    paths = np.asarray(backnforth)
# One faint line per simulated particle (first axis indexes trajectories).
for path in paths:
    axs.plot(path[:, 0], path[:, 1], path[:, 2], color='black', alpha=0.1)

axs.set_title('Validation samples and interpolated trajectories sb-sde', fontsize=20)

axs.set_xlabel('Gene 1', labelpad=10, fontsize=20)
axs.set_ylabel('Gene 2', labelpad=10, fontsize=20)
axs.set_zlabel('Gene 3', labelpad=10, fontsize=20)

# Rotate the initial viewing angle.
axs.view_init(-140, 60)

plt.show()

In [None]:
# Validation snapshots together with the trajectories from the vanilla SB baseline.
plt.rcParams['figure.figsize'] = [15, 10]
fig = plt.figure()
# One 3D axes only; previously an extra 2D axes was created, hidden and overlaid.
axs = fig.add_subplot(111, projection='3d')

# `cm` was never imported, and matplotlib.cm.get_cmap is removed in recent
# matplotlib releases; pyplot.get_cmap returns the same colormap.
cmap = plt.get_cmap('coolwarm')

# Plot the sample points, coloured by snapshot index. Labels are set so the
# figure legend below is not empty. CUDA tensors are copied to host first.
for i, snapshot in enumerate(Xs_val):
    samples = snapshot.detach().cpu().numpy()
    axs.scatter(samples[:, 0], samples[:, 1], samples[:, 2],
                label='validation samples at t=' + str(i), s=50,
                color=cmap(i / len(Xs_val)))

# Build the legend before drawing trajectories so it lists only the snapshots.
fig.legend(loc='center right', fontsize=15)

if torch.is_tensor(backnforth_vSB):
    paths = backnforth_vSB.detach().cpu().numpy()
else:
    paths = np.asarray(backnforth_vSB)
# One faint line per simulated particle (first axis indexes trajectories).
for path in paths:
    axs.plot(path[:, 0], path[:, 1], path[:, 2], color='black', alpha=0.1)

axs.set_title('Validation samples and interpolated trajectories vanilla-sb', fontsize=20)

axs.set_xlabel('Gene 1', labelpad=10, fontsize=20)
axs.set_ylabel('Gene 2', labelpad=10, fontsize=20)
axs.set_zlabel('Gene 3', labelpad=10, fontsize=20)

# Rotate the initial viewing angle.
axs.view_init(-140, 60)

plt.show()