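Simulate data from the cell2fate model: starting from posterior estimates fitted to the dentate gyrus dataset, scale one parameter at a time (splicing rate, degradation rate, overdispersion, detection efficiency) by several factors and save each simulated count dataset.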

In [1]:
from typing import Optional
import numpy as np
import pandas as pd
import pyro
import pyro.distributions as dist
import torch
from pyro.nn import PyroModule
from scvi import REGISTRY_KEYS
import pandas as pd
from scvi.nn import one_hot
from cell2fate.utils import G_a, G_b, mu_mRNA_continousAlpha_globalTime_twoStates
from pyro.infer import config_enumerate
from pyro.ops.indexing import Vindex
from torch.distributions import constraints

Global seed set to 0


In [2]:
# ---- Model dimensions ----
# n_obs, n_vars and n_modules are placeholders: they are re-derived from the
# loaded posterior further down.
n_obs = 3000
n_vars = 3000
n_batch = 1
n_extra_categoricals = None
n_modules = 10

# ---- Prior hyperparameters of the cell2fate model ----
# Overdispersion of the count likelihood.
stochastic_v_ag_hyp_prior = {"alpha": 6.0, "beta": 3.0}
s_overdispersion_factor_hyp_prior = {
    "alpha_mean": 100.0, "beta_mean": 1.0,
    "alpha_sd": 1.0, "beta_sd": 0.1,
}

# Module / factor structure and switching times.
factor_prior = {"rate": 1.0, "alpha": 1.0, "states_per_gene": 3.0}
t_switch_alpha_prior = {"mean": 1000.0, "alpha": 1000.0}
Tmax_prior = {"mean": 50.0, "sd": 50.0}

# Kinetic rates (the misspelt `degredation` name is kept as-is).
splicing_rate_hyp_prior = {
    "mean": 1.0, "alpha": 5.0,
    "mean_hyp_alpha": 10.0, "alpha_hyp_alpha": 20.0,
}
degredation_rate_hyp_prior = {
    "mean": 1.0, "alpha": 5.0,
    "mean_hyp_alpha": 10.0, "alpha_hyp_alpha": 20.0,
}
activation_rate_hyp_prior = {
    "mean_hyp_prior_mean": 2.0, "mean_hyp_prior_sd": 0.33,
    "sd_hyp_prior_mean": 0.33, "sd_hyp_prior_sd": 0.1,
}

# Detection efficiency (integer values kept as integers).
detection_hyp_prior = {"alpha": 10.0, "mean_alpha": 1.0, "mean_beta": 1.0}
detection_i_prior = {"mean": 1, "alpha": 100}
detection_gi_prior = {"mean": 1, "alpha": 200}

# Additive per-gene background.
gene_add_alpha_hyp_prior = {"alpha": 9.0, "beta": 3.0}
gene_add_mean_hyp_prior = {"alpha": 1.0, "beta": 100.0}

Use the posterior parameter estimates fitted to the dentate gyrus dataset as the starting point for all simulations:

In [3]:
import pickle

# Posterior estimates of the cell2fate model for the dentate gyrus dataset.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
posterior_path = '/nfs/team283/aa16/data/fate_benchmarking/' + 'DentateGyrusPosterior.pickle'
with open(posterior_path, 'rb') as posterior_file:
    posterior = pickle.load(posterior_file)

In [4]:
# Unpack the posterior and simulate one baseline dataset (no parameter scaled).
# Re-derive the dimensions from the loaded posterior; this overrides the
# placeholder n_modules / n_obs / n_vars set in the configuration cell.
n_modules = posterior['A_mgON'].shape[0]
n_obs = np.shape(posterior['t_c'])[0]
n_vars = np.shape(posterior['gamma_g'])[1]
zeros = torch.zeros(n_obs, n_vars)
zero = torch.tensor(0.)
ones = torch.ones((1, 1))
batch_size = np.shape(posterior['t_c'])[0]
# Every cell is assigned to batch 0 (the batch index is all zeros).
batch_index = torch.zeros(n_obs, 1)
obs2sample = one_hot(batch_index, n_batch)
# ===================== Kinetic rates ======================= #
# Splicing rate:
splicing_alpha = posterior['splicing_alpha']
splicing_mean = posterior['splicing_mean']
beta_g = torch.tensor(posterior['beta_g'])
# Degradation rate (the misspelt `degredation_*` names are kept unchanged):
degredation_alpha = posterior['degredation_alpha']
degredation_mean = posterior['degredation_mean']
gamma_g = torch.tensor(posterior['gamma_g'])
# Transcription rate contribution of each module. A_mgOFF is a fixed small
# constant (presumably the transcription rate in the OFF state):
factor_level_g = posterior['factor_level_g']
g_fg = posterior['g_fg']
A_mgON = torch.tensor(posterior['A_mgON'])
A_mgOFF = 10**(-5)
# Activation and deactivation rates:
lam_mu = posterior['lam_mu']
lam_sd = posterior['lam_sd']
lam_m_mu = posterior['lam_m_mu']
lam_mi = torch.tensor(posterior['lam_mi'])
# ===================== Time ======================= #
# Global time for each cell:
T_max = posterior['Tmax']
t_c_loc = posterior['t_c_loc']
t_c_scale = posterior['t_c_scale']
t_c = posterior['t_c']
T_c = torch.tensor(posterior['T_c'])
# Switch-on time for each module. t_mON (cumulative sum of t_delta starting at 0)
# is not used in this cell; the simulation below uses T_mON from the posterior.
t_delta = torch.tensor(posterior['t_delta'])
t_mON = torch.cumsum(torch.concat([zero.unsqueeze(0), t_delta[:-1]]), dim=0).unsqueeze(0).unsqueeze(0)
T_mON = torch.tensor(posterior['T_mON'])
# Switch-off time for each module:
t_mOFF = posterior['t_mOFF']
T_mOFF = torch.tensor(posterior['T_mOFF'])
# =========== Mean expression according to the RNA velocity model =========== #
# Sum the contribution of every module. The last axis has size 2 (presumably
# unspliced / spliced -- confirm against cell2fate.utils).
mu_total = torch.stack([zeros, zeros], axis=-1)
for m in range(n_modules):
    mu_total += mu_mRNA_continousAlpha_globalTime_twoStates(
        A_mgON[m, :], A_mgOFF, beta_g, gamma_g, lam_mi[m, ...],
        T_c[:, :, 0], T_mON[:, :, m], T_mOFF[:, :, m], zeros)
mu_expression = pyro.deterministic('mu_expression', mu_total)
# ============= Detection efficiency of spliced and unspliced counts ============= #
# Cell-specific relative detection efficiency with hierarchical prior across batches:
detection_mean_y_e = posterior['detection_mean_y_e']
detection_hyp_prior_alpha = posterior['detection_hyp_prior_alpha']
beta = detection_hyp_prior_alpha / (obs2sample @ detection_mean_y_e)  # not used in this cell
detection_y_c = torch.tensor(posterior['detection_y_c'])
detection_y_i = torch.tensor(posterior['detection_y_i'])
detection_y_gi = torch.tensor(posterior['detection_y_gi'])
# Additive background term s_g_gene_add (mapped from batches to cells below):
s_g_gene_add_alpha_hyp = posterior['s_g_gene_add_alpha_hyp']
s_g_gene_add_mean = posterior['s_g_gene_add_mean']
s_g_gene_add_alpha_e_inv = posterior['s_g_gene_add_alpha_e_inv']
s_g_gene_add = torch.tensor(posterior['s_g_gene_add'])
# Overdispersion: GammaPoisson concentration = 1 / stochastic_v_ag_inv**2.
stochastic_v_ag_hyp = posterior['stochastic_v_ag_hyp']
stochastic_v_ag_inv = torch.tensor(posterior['stochastic_v_ag_inv'])
stochastic_v_ag = (ones / stochastic_v_ag_inv.pow(2))
# ===================== Expected expression ======================= #
mu = (mu_expression + torch.einsum('cbi,bgi->cgi', obs2sample.unsqueeze(dim=-1), s_g_gene_add)) * \
    detection_y_c * detection_y_i * detection_y_gi
# ===================== Data likelihood ======================= #
# GammaPoisson with rate = concentration / mu has mean mu and
# variance mu + mu**2 / concentration.
data = pyro.sample("data_target", dist.GammaPoisson(concentration=stochastic_v_ag,
                                                    rate=stochastic_v_ag / mu))

Splicing rate: multiply the splicing rates (`beta_g`) by each factor, simulate counts and save each dataset:

In [7]:
# Splicing-rate sweep: scale beta_g by each factor, simulate counts and pickle
# them as .../SimulatedData/Beta<factor>.pickle.
# Only beta_g (and therefore the mean expression) changes between datasets, so
# every other input is read from the posterior once, before the loop. Inputs are
# re-read here rather than taken from earlier cells, so this cell does not inherit
# parameters left scaled by another sweep.
multiplication_factor = (1., 0.25, 0.5, 2., 4.)
output_prefix = '/nfs/team283/aa16/data/fate_benchmarking/SimulatedData/Beta'
# Reseeded before every draw so each saved dataset is reproducible on its own,
# whichever cells were executed before this one.
simulation_seed = 0

# ---- Loop-invariant inputs ----
n_obs = np.shape(posterior['t_c'])[0]
n_vars = np.shape(posterior['gamma_g'])[1]
ones = torch.ones((1, 1))
obs2sample = one_hot(torch.zeros(n_obs, 1), n_batch)
gamma_g = torch.tensor(posterior['gamma_g'])
A_mgON = torch.tensor(posterior['A_mgON'])
A_mgOFF = 10**(-5)
lam_mi = torch.tensor(posterior['lam_mi'])
T_c = torch.tensor(posterior['T_c'])
T_mON = torch.tensor(posterior['T_mON'])
T_mOFF = torch.tensor(posterior['T_mOFF'])
detection_y_c = torch.tensor(posterior['detection_y_c'])
detection_y_i = torch.tensor(posterior['detection_y_i'])
detection_y_gi = torch.tensor(posterior['detection_y_gi'])
s_g_gene_add = torch.tensor(posterior['s_g_gene_add'])
# Additive background mapped from batches to cells.
gene_add = torch.einsum('cbi,bgi->cgi', obs2sample.unsqueeze(dim=-1), s_g_gene_add)
stochastic_v_ag_inv = torch.tensor(posterior['stochastic_v_ag_inv'])
stochastic_v_ag = (ones / stochastic_v_ag_inv.pow(2))

for i, factor in enumerate(multiplication_factor):
    print(i, factor)
    beta_g = torch.tensor(posterior['beta_g']) * factor
    # Mean expression according to the RNA velocity model, summed over modules.
    zeros = torch.zeros(n_obs, n_vars)
    mu_total = torch.stack([zeros, zeros], axis=-1)
    for m in range(n_modules):
        mu_total += mu_mRNA_continousAlpha_globalTime_twoStates(
            A_mgON[m, :], A_mgOFF, beta_g, gamma_g, lam_mi[m, ...],
            T_c[:, :, 0], T_mON[:, :, m], T_mOFF[:, :, m], zeros)
    mu_expression = pyro.deterministic('mu_expression', mu_total)
    mu = (mu_expression + gene_add) * detection_y_c * detection_y_i * detection_y_gi
    pyro.set_rng_seed(simulation_seed)
    data = pyro.sample("data_target", dist.GammaPoisson(concentration=stochastic_v_ag,
                                                        rate=stochastic_v_ag / mu))
    with open(output_prefix + str(factor) + '.pickle', 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

0
1
2
3
4


Degradation rate: multiply the degradation rates (`gamma_g`) by each factor, simulate counts and save each dataset:

In [8]:
# Degradation-rate sweep: scale gamma_g by each factor, simulate counts and
# pickle them as .../SimulatedData/Gamma<factor>.pickle.
# Only gamma_g (and therefore the mean expression) changes between datasets, so
# every other input is read from the posterior once, before the loop. Inputs are
# re-read here rather than taken from earlier cells, so this cell does not inherit
# parameters left scaled by another sweep (e.g. beta_g from the splicing sweep).
multiplication_factor = (1., 0.25, 0.5, 2., 4.)
output_prefix = '/nfs/team283/aa16/data/fate_benchmarking/SimulatedData/Gamma'
# Reseeded before every draw so each saved dataset is reproducible on its own,
# whichever cells were executed before this one.
simulation_seed = 0

# ---- Loop-invariant inputs ----
n_obs = np.shape(posterior['t_c'])[0]
n_vars = np.shape(posterior['gamma_g'])[1]
ones = torch.ones((1, 1))
obs2sample = one_hot(torch.zeros(n_obs, 1), n_batch)
beta_g = torch.tensor(posterior['beta_g'])
A_mgON = torch.tensor(posterior['A_mgON'])
A_mgOFF = 10**(-5)
lam_mi = torch.tensor(posterior['lam_mi'])
T_c = torch.tensor(posterior['T_c'])
T_mON = torch.tensor(posterior['T_mON'])
T_mOFF = torch.tensor(posterior['T_mOFF'])
detection_y_c = torch.tensor(posterior['detection_y_c'])
detection_y_i = torch.tensor(posterior['detection_y_i'])
detection_y_gi = torch.tensor(posterior['detection_y_gi'])
s_g_gene_add = torch.tensor(posterior['s_g_gene_add'])
# Additive background mapped from batches to cells.
gene_add = torch.einsum('cbi,bgi->cgi', obs2sample.unsqueeze(dim=-1), s_g_gene_add)
stochastic_v_ag_inv = torch.tensor(posterior['stochastic_v_ag_inv'])
stochastic_v_ag = (ones / stochastic_v_ag_inv.pow(2))

for i, factor in enumerate(multiplication_factor):
    print(i, factor)
    gamma_g = torch.tensor(posterior['gamma_g']) * factor
    # Mean expression according to the RNA velocity model, summed over modules.
    zeros = torch.zeros(n_obs, n_vars)
    mu_total = torch.stack([zeros, zeros], axis=-1)
    for m in range(n_modules):
        mu_total += mu_mRNA_continousAlpha_globalTime_twoStates(
            A_mgON[m, :], A_mgOFF, beta_g, gamma_g, lam_mi[m, ...],
            T_c[:, :, 0], T_mON[:, :, m], T_mOFF[:, :, m], zeros)
    mu_expression = pyro.deterministic('mu_expression', mu_total)
    mu = (mu_expression + gene_add) * detection_y_c * detection_y_i * detection_y_gi
    pyro.set_rng_seed(simulation_seed)
    data = pyro.sample("data_target", dist.GammaPoisson(concentration=stochastic_v_ag,
                                                        rate=stochastic_v_ag / mu))
    with open(output_prefix + str(factor) + '.pickle', 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

0
1
2
3
4


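Overdispersion: multiply the GammaPoisson concentration (`stochastic_v_ag`) by each factor, simulate counts and save each dataset. Because the variance is mu + mu²/concentration, a factor above 1 means *less* overdispersion.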

In [12]:
# Overdispersion sweep: scale the GammaPoisson concentration stochastic_v_ag by
# each factor, simulate counts and pickle them as
# .../SimulatedData/Overdispersion<factor>.pickle.
# The counts have mean mu and variance mu + mu**2 / stochastic_v_ag, so a factor
# > 1 gives LESS overdispersion and a factor < 1 gives more.
# The mean expression mu does not depend on the factor, so it is computed once.
# Inputs are re-read from the posterior rather than taken from earlier cells, so
# this cell does not inherit parameters left scaled by another sweep.
multiplication_factor = (1., 0.25, 0.5, 2., 4.)
output_prefix = '/nfs/team283/aa16/data/fate_benchmarking/SimulatedData/Overdispersion'
# Reseeded before every draw so each saved dataset is reproducible on its own,
# whichever cells were executed before this one.
simulation_seed = 0

# ---- Loop-invariant inputs ----
n_obs = np.shape(posterior['t_c'])[0]
n_vars = np.shape(posterior['gamma_g'])[1]
zeros = torch.zeros(n_obs, n_vars)
ones = torch.ones((1, 1))
obs2sample = one_hot(torch.zeros(n_obs, 1), n_batch)
beta_g = torch.tensor(posterior['beta_g'])
gamma_g = torch.tensor(posterior['gamma_g'])
A_mgON = torch.tensor(posterior['A_mgON'])
A_mgOFF = 10**(-5)
lam_mi = torch.tensor(posterior['lam_mi'])
T_c = torch.tensor(posterior['T_c'])
T_mON = torch.tensor(posterior['T_mON'])
T_mOFF = torch.tensor(posterior['T_mOFF'])
# Mean expression according to the RNA velocity model, summed over modules.
mu_total = torch.stack([zeros, zeros], axis=-1)
for m in range(n_modules):
    mu_total += mu_mRNA_continousAlpha_globalTime_twoStates(
        A_mgON[m, :], A_mgOFF, beta_g, gamma_g, lam_mi[m, ...],
        T_c[:, :, 0], T_mON[:, :, m], T_mOFF[:, :, m], zeros)
mu_expression = pyro.deterministic('mu_expression', mu_total)
detection_y_c = torch.tensor(posterior['detection_y_c'])
detection_y_i = torch.tensor(posterior['detection_y_i'])
detection_y_gi = torch.tensor(posterior['detection_y_gi'])
s_g_gene_add = torch.tensor(posterior['s_g_gene_add'])
mu = (mu_expression + torch.einsum('cbi,bgi->cgi', obs2sample.unsqueeze(dim=-1), s_g_gene_add)) * \
    detection_y_c * detection_y_i * detection_y_gi
stochastic_v_ag_inv = torch.tensor(posterior['stochastic_v_ag_inv'])

for i, factor in enumerate(multiplication_factor):
    print(i, factor)
    stochastic_v_ag = (ones / stochastic_v_ag_inv.pow(2)) * factor
    pyro.set_rng_seed(simulation_seed)
    data = pyro.sample("data_target", dist.GammaPoisson(concentration=stochastic_v_ag,
                                                        rate=stochastic_v_ag / mu))
    with open(output_prefix + str(factor) + '.pickle', 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

0
1
2
3
4


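Detection efficiency: multiply the cell-specific detection efficiency (`detection_y_c`) by each factor, simulate counts and save each dataset: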

In [13]:
# Detection-efficiency sweep: scale the cell-specific detection efficiency
# detection_y_c by each factor, simulate counts and pickle them as
# .../SimulatedData/DetectionEfficiency<factor>.pickle.
# The RNA velocity mean expression does not depend on the factor, so it is
# computed once; only detection_y_c and mu change inside the loop. Inputs are
# re-read from the posterior rather than taken from earlier cells, so this cell
# does not inherit parameters left scaled by another sweep.
multiplication_factor = (1., 0.25, 0.5, 2., 4.)
output_prefix = '/nfs/team283/aa16/data/fate_benchmarking/SimulatedData/DetectionEfficiency'
# Reseeded before every draw so each saved dataset is reproducible on its own,
# whichever cells were executed before this one.
simulation_seed = 0

# ---- Loop-invariant inputs ----
n_obs = np.shape(posterior['t_c'])[0]
n_vars = np.shape(posterior['gamma_g'])[1]
zeros = torch.zeros(n_obs, n_vars)
ones = torch.ones((1, 1))
obs2sample = one_hot(torch.zeros(n_obs, 1), n_batch)
beta_g = torch.tensor(posterior['beta_g'])
gamma_g = torch.tensor(posterior['gamma_g'])
A_mgON = torch.tensor(posterior['A_mgON'])
A_mgOFF = 10**(-5)
lam_mi = torch.tensor(posterior['lam_mi'])
T_c = torch.tensor(posterior['T_c'])
T_mON = torch.tensor(posterior['T_mON'])
T_mOFF = torch.tensor(posterior['T_mOFF'])
# Mean expression according to the RNA velocity model, summed over modules.
mu_total = torch.stack([zeros, zeros], axis=-1)
for m in range(n_modules):
    mu_total += mu_mRNA_continousAlpha_globalTime_twoStates(
        A_mgON[m, :], A_mgOFF, beta_g, gamma_g, lam_mi[m, ...],
        T_c[:, :, 0], T_mON[:, :, m], T_mOFF[:, :, m], zeros)
mu_expression = pyro.deterministic('mu_expression', mu_total)
s_g_gene_add = torch.tensor(posterior['s_g_gene_add'])
# Expected expression before detection efficiencies are applied.
mu_before_detection = mu_expression + torch.einsum(
    'cbi,bgi->cgi', obs2sample.unsqueeze(dim=-1), s_g_gene_add)
detection_y_i = torch.tensor(posterior['detection_y_i'])
detection_y_gi = torch.tensor(posterior['detection_y_gi'])
stochastic_v_ag_inv = torch.tensor(posterior['stochastic_v_ag_inv'])
stochastic_v_ag = (ones / stochastic_v_ag_inv.pow(2))

for i, factor in enumerate(multiplication_factor):
    print(i, factor)
    detection_y_c = torch.tensor(posterior['detection_y_c']) * factor
    # Same multiplication order as the unscaled model.
    mu = mu_before_detection * detection_y_c * detection_y_i * detection_y_gi
    pyro.set_rng_seed(simulation_seed)
    data = pyro.sample("data_target", dist.GammaPoisson(concentration=stochastic_v_ag,
                                                        rate=stochastic_v_ag / mu))
    with open(output_prefix + str(factor) + '.pickle', 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

0
1
2
3
4
