In [1]:
import pandas as pd
from CRN_Simulation_Inference.RB_method_for_model_identification.RBForModelIdentification import RBForModelIdentification
import numpy as np
import matplotlib.pyplot as plt  
from tqdm import tqdm
import torch
from ElenaDataManagement import count_samples_for_supersampling, get_X_Y_sampling_times, sample_trajectory_on_times, CRN_simulations_to_dataloaders, run_SSA_for_filtering


# --- CRN model: a single-species birth-death process ---
# Reactions: 0 -> mRNA (propensity k) and mRNA -> 0 (propensity g*mRNA).

sigma = 0.1  # read (at call time) by the observation-noise-intensity callable below

species_names = ['mRNA']
parameters_names = ['k','g']
reaction_names = ['mRNA prod.', 'mRNA deg.']
# One row per species, one column per reaction: production adds one mRNA, degradation removes one.
stoichiometric_matrix = [[1, -1]]
# NOTE(review): lambda argument names are kept verbatim; presumably RBForModelIdentification
# matches them against species/parameter names — confirm before renaming any of them.
propensities = [
    lambda k: k,              # mRNA production
    lambda g, mRNA: g*mRNA,   # mRNA degradation
]

# Admissible ranges, indexed by species / parameter name, with 'min' and 'max' columns.
range_of_species = pd.DataFrame({'min': [0], 'max': [120]}, index=species_names)
range_of_parameters = pd.DataFrame({'min': [0, 0], 'max': [150, 150]}, index=parameters_names)
# 100 grid points per parameter.
discretization_size_parameters = pd.DataFrame([100] * len(parameters_names), index=parameters_names)

# Observation model: the measured quantity is the mRNA count itself.
h_function = [
    lambda mRNA: mRNA
]
# Late-binding closure: returns whatever the module-level `sigma` is when it is called.
observation_noise_intensity = [
    lambda : sigma
]

maximum_size_of_each_follower_subsystem = 20000

MI = RBForModelIdentification(
    species_names=species_names,
    parameters_names=parameters_names,
    reaction_names=reaction_names,
    stoichiometric_matrix=stoichiometric_matrix,
    propensities=propensities,
    range_of_species=range_of_species,
    range_of_parameters=range_of_parameters,
    discretization_size_parameters=discretization_size_parameters,
    h_function=h_function,
    observation_noise_intensity=observation_noise_intensity,
    maximum_size_of_each_follower_subsystem=maximum_size_of_each_follower_subsystem,
)

In [2]:
# this can be done quickly by running 
import math
tf = 1.
# Get a trajectory of the system
parameter_values_sets = []

parameter_values_sets.append({'k': 1, 'g': 1})

parameter_set_index = 0
parameter_values = parameter_values_sets[parameter_set_index]
#initial_state = {'M': 0, 'P': 0}
initial_states = []

for i in range(3):
    initial_states.append({'mRNA': i})

n_samples = 1000  # p in the paper
batch_size = 1000 #int(math.sqrt(n_samples))

n_Y_measurements = 2 # n in the paper
n_X_measurements_between_Y_measurements = 100 # m_bar-2 in the paper

dataset = run_SSA_for_filtering(MI, initial_states, parameter_values, tf, n_Y_measurements, n_X_measurements_between_Y_measurements, n_samples=n_samples)
train_dataset, val_dataset, Xtimes, Ytimes = CRN_simulations_to_dataloaders(dataset, batch_size, test_split=0.2)

print("--- check batch sizes ---")
print("training : ", [x.shape for x in next(iter(train_dataset))])
print("validation : ", [x.shape for x in next(iter(val_dataset))])

100%|██████████| 1000/1000 [00:00<00:00, 1107.12it/s]


--- check batch sizes ---
training :  [torch.Size([800, 102, 1]), torch.Size([800, 2, 1]), torch.Size([800, 102, 2])]
validation :  [torch.Size([200, 102, 1]), torch.Size([200, 2, 1]), torch.Size([200, 102, 2])]


In [None]:
from OtherNetworks import MLP, RNNEncoder
from DeepCME import FilteringDeepCME, TemporalFeatureExtractor
from copy import deepcopy

def deepcopy_flatten(x):
    """Deep-copy ``x`` and, for RNN encoders, re-flatten the copy's RNN weights.

    ``isinstance`` is used instead of an exact ``type(...) ==`` comparison, so
    subclasses of ``RNNEncoder`` are flattened as well.

    Args:
        x: object to copy. When it is an ``RNNEncoder``, it is assumed to expose
            an ``RNN`` submodule with ``flatten_parameters()`` — presumably a
            torch RNN module; TODO confirm against OtherNetworks.

    Returns:
        An independent deep copy of ``x``.
    """
    other = deepcopy(x)
    if isinstance(x, RNNEncoder):
        # Restores the compact weight layout for the copied RNN; per the PyTorch
        # docs this only has an effect on GPU with cuDNN, otherwise it is a no-op.
        other.RNN.flatten_parameters()
    return other

# --- Dimensions, g functions and time grids for the FilteringDeepCME chain ---
r = 1                           # number of temporal features
n = MI.get_number_of_species()  # number of species
O = 1                           # number of observed species

# Functions g(x) passed to every model of the chain.
g_functions = [
    lambda x: x[:, 0],           # first state component
    lambda x: x[:, 0]*0. + 1.,   # constant 1, same shape as x[:, 0]
]
R = len(g_functions)              # number of g functions
K = MI.get_number_of_reactions()  # number of reactions

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
temporal_feature_extractor = TemporalFeatureExtractor(r, tf, device=device)

# No learned networks are plugged in here: the chain below passes None for them.
x_encoder = y_encoder = backbone = baseline_net = None

# Grid over [0, tf].
tau_list = torch.linspace(0, tf, n_X_measurements_between_Y_measurements+2)
# Times of the Y measurements.
measurement_times = torch.linspace(0, tf, n_Y_measurements)
# Grid over one inter-measurement interval of length tf/(n_Y_measurements-1).
tau_times = torch.linspace(0, tf/(n_Y_measurements-1), n_X_measurements_between_Y_measurements+2)

# --- Observation map and Gaussian likelihood ---
from ElenaLosses import likelihood_GaussianNoise_vmap_compatible

def h_fun(x):
    """Observation map for the birth-death model: the identity."""
    return x

# NOTE(review): sigma_prior is used as a covariance entry here, while cell 1 uses
# sigma = 0.1 as the observation noise intensity — confirm the mismatch is intended.
sigma_prior = 0.0001
likelihood = likelihood_GaussianNoise_vmap_compatible
likelihood_parameters = {'noise_covariance': torch.tensor([[sigma_prior]]).to(device)}
# Determinant and inverse are precomputed and passed alongside the covariance.
likelihood_parameters.update(
    noise_covariance_determinant=torch.det(likelihood_parameters['noise_covariance']),
    noise_covariance_inverse=torch.inverse(likelihood_parameters['noise_covariance']),
)

# --- Build the chain back-to-front ---
# Position n_Y_measurements-2 is built first and has no successor; every later model is
# linked to the one built just before it, so chain[-1] ends up at position 0.
chain = []
for i in reversed(range(n_Y_measurements - 1)):
    print(i)
    chain.append(FilteringDeepCME(
        None, None, None, None, tau_times, measurement_times, g_functions, None, R, K, O,
        position_in_the_chain=i,
        n_NN_in_chain=n_Y_measurements-1,
        device=device,
        h_transform=h_fun,
        likelihood=likelihood,
        likelihood_parameters=likelihood_parameters,
        next_in_chain=chain[-1] if chain else None,
        use_exact_poisson_for_debugging=True,
    ))


0


In [4]:
# run with the poisson martingale

def test_loop(model, train_loader):
    training_loss = []
    
    for i, (X, Y, R) in enumerate(train_loader):
        X = X.to(model.device)
        Y = Y.to(model.device)
        R = R.to(model.device)
        def closure():
            loss = model.poisson_loss(X, Y, R)
            training_loss.append((len(train_loader) + i, loss.item()))
            return loss
        closure()
        print(f'Testing, batch {i}, Loss {training_loss[-1][1]}')

    return training_loss

def chain_training_loop(chain, train_loader):
    """Evaluate each model of ``chain`` in order and freeze it afterwards.

    No optimisation happens here: each model is evaluated with ``test_loop``,
    then ``model.freeze()`` is called.

    Args:
        chain: sequence of models, each exposing ``position_in_the_chain``,
            ``freeze()`` and whatever ``test_loop`` needs.
        train_loader: loader of ``(X, Y, R)`` batches forwarded to ``test_loop``.

    Returns:
        List with one entry per model, in ``chain`` order. Each entry is the
        list of ``(step, loss)`` tuples returned by ``test_loop``.
    """
    losses_per_model = []
    for model in chain:
        print(f"+++++ Testing model {model.position_in_the_chain} +++++")
        losses_per_model.append(test_loop(model, train_loader))
        model.freeze()
    return losses_per_model


# Assigned so the cell does not echo the loss lists as its output.
chain_losses = chain_training_loop(chain, train_dataset)

+++++ Testing model 0 +++++
torch.Size([800, 1]) torch.Size([800, 2])
Testing, batch 0, Loss 0.0008820827933959663


In [5]:
# Sanity check of the exact (analytical) martingale of the model at position 0 (chain[-1]).
# NOTE(review): the arguments look like (start time, end time, state, observation) — confirm
# against FilteringDeepCME.analytical_martingale.
results = chain[-1].analytical_martingale(
    torch.tensor(0.),
    torch.tensor(1.),
    torch.tensor([[6.]], device=device),
    torch.tensor([[[3.]]], device=device),
)
print(results)
# Ratio of the outputs for g0(x) = x[:, 0] and g1(x) = 1, assuming the last axis indexes g_functions.
print(results[0, 0, 0] / results[0, 0, 1])

tensor([[[1.1793, 0.3931]]], device='cuda:0')
tensor(3., device='cuda:0')


  out = torch.vmap(lambda hx, Sigma, s_det, Sigma_inv, Y : torch.exp(-1/2 * (Y - hx) @ Sigma_inv @ (Y - hx).T), in_dims=(0, None, None, None, None))(hx, Sigma, s_det, Sigma_inv, Y)
