## Import packages

In [1]:
import numpy as np
import torch
import os
import scipy.stats 

## Define the differentiable Gillespie algorithm

In [2]:
# Stoichiometry matrix: row k is the change applied to the state vector
# (promoter state, mRNA count) when reaction k fires. Rows are numbered from 0,
# the same numbering used for the (soft) reaction index in this notebook.
stoic_matrix = torch.tensor(
    [
        [2.0, 0.0],   # reaction 0: promoter state goes from -1 to +1
        [0.0, 1.0],   # reaction 1: one mRNA is produced
        [-2.0, 0.0],  # reaction 2: promoter state goes from +1 to -1
        [0.0, -1.0],  # reaction 3: one mRNA is degraded
    ]
)

# Define a function to compute the state jump
def state_jump(reaction_index, stoic_matrix):
    """
    Calculate state jump vector based on the selected reaction index and stoichiometry matrix, where, 
    state vector -> state vector + state jump vector.

    Arguments:
        reaction_index: Selected reaction index
        stoic_matrix: Stoichiometry matrix

    Returns:
        State jump vector
    """
    return torch.sum(stoic_matrix * (torch.exp(-b_inv* (reaction_index - torch.arange(stoic_matrix.shape[0]))**2)).view(-1, 1), dim=0)

# Define a function to select the reaction based on reaction selection thresholds
def reaction_selection(breaks, random_num):
    """
    Select reaction based on the transition points and a random number. Transition points are 
    given by the ratio of cumulative sum of rates and the total rate.
    
    Arguments:
        breaks: Transition points between [0,1]
        random_num: Random number in [0,1]

    Returns:
        Index of the next reaction
    """
    return torch.sum(torch.sigmoid(a_inv * (random_num - breaks)))

# Define the Gillespie simulation function
def gillespie_simulation(kon, r, g, num_simulations, sim_time, a_inv, b_inv, c):
    """
    Perform differentiable Gillespie simulation for a 2-state promoter model.
    
    Arguments:
        kon: Rate of promoter switching from -1 to +1.
        r: Rate of mRNA production.
        g: Rate of mRNA degradation.
        num_simulations: Number of simulations to run.
        sim_time: Simulation time.
        a_inv: Inverse parameter for reaction selection.
        b_inv : Inverse parameter for state jump calculation.
        c: Sigmoid slope parameter for propensities.
        
    Returns:
        mean_final_state: Mean of the mRNA levels at the end of the simulation.
        variance: Variance of the mRNA levels at the end of the simulation.
    """
    # Initialize random seed for reproducibility
    random_seed = torch.randint(1, 10000000, (1,))
    print (random_seed)
    torch.manual_seed(random_seed)
    final_states = 0.0
    final_states_squared = 0.0

    # Main simulation loop
    for j in range(num_simulations):
        # Initialize 'levels':
        # The first component of 'levels' is the promoter state, initialized to -1
        # The second component of 'levels' is the mRNA count, initailized to 0.
        levels = torch.stack([torch.tensor(-1.0), torch.tensor(0.0)])
        current_time = 0.0

        # Main simulation loop
        while current_time < sim_time:
            # Calculate reaction propensities
            propensities = torch.stack([kon*torch.sigmoid(-c*levels[0]),                 # Rate of promoter state switching from -1 to +1
                                        r*torch.sigmoid(-c*levels[0]),                   # Rate of mRNA production
                                        torch.tensor([1.0])*torch.sigmoid(c*levels[0]),  # Rate of promoter state switching from +1 to -1
                                        g*levels[1]])                                    # Rate of mRNA degradation

            # Calculate total propensity
            total_propensity = propensities.sum()

            # Generate a random number to determine time to next reaction
            dt = -torch.log(torch.rand(1)) / total_propensity
            current_time += dt.item()

            # Check if the simulation exceeds sim_time. If it exceeds, quit the simulation.
            if current_time >= sim_time:
                break

            # Update state vector
            breaks = (propensities[:-1] / total_propensity).cumsum(dim=0)
            reaction_index = reaction_selection(breaks, torch.rand(1))
            levels = levels + state_jump(reaction_index, stoic_matrix)
            levels[1] = torch.relu(levels[1]) 

        # Accumulate final states after each sumulation
        final_states += levels[1]
        final_states_squared += levels[1] ** 2

    # Calculate mean and variance of mRNA levels (from the accumulated final states)
    mean_final_state = final_states / num_simulations
    variance = final_states_squared / num_simulations - mean_final_state ** 2

    # Return mean mRNA level and variance
    return mean_final_state, variance

## Non-degenerate case
### Load the learned parameters

In [None]:
# Load the data from learning progress file
file_path = '2state_model_no_deg.txt'  

data = np.loadtxt(file_path) 

# Initialize a dictionary to store the parameter values for the minimum loss
min_rows = {}

# Iterate through each row in the data
for row in data:
    index = int(row[0])
    # Update the minimum row for the index as needed
    if index not in min_rows or row[-1] < min_rows[index][-1]:
        min_rows[index] = row

# Extract the parameter values for minimum loss
diff_gillespie_min = [row[[2, 3, 4, 5,6]] for row in min_rows.values()]
diff_gillespie_min = np.array(diff_gillespie_min)

### Obtain the error bars

In [None]:
def write_to_file(filename, *args):
    with open(filename, 'a') as file:
        file.write(' '.join(map(str, args)) + '\n')
 
# File to store errors
filename = "errors_on_mean_and_std_no_deg.txt"
if os.path.exists(filename):
    os.remove(filename)

# Load true parameter values
true_values=np.load("random_rates.npy")

# Hyperparameters
num_simulations=50
a_inv=200.0
b_inv=20.0
c=20.0
sim_time=10.0

# Set random seed for reproducibility
torch.manual_seed(42)

# Do forward Gillespie for 100 different random seeds
for n in range (20):
    means=[]
    variances=[]
    for i in range (100):
        a,b= gillespie_simulation(torch.tensor([diff_gillespie_min[n, 0]]), torch.tensor([diff_gillespie_min[n, 1]]), 
                                  torch.tensor([diff_gillespie_min[n,2]]), num_simulations, sim_time, a_inv, b_inv, c)
        means.append(a.item())
        variances.append((b.item())**0.5)
      
    # Write the errors on mean and std deviation to file
    write_to_file(filename, true_values[n,-2], diff_gillespie_min[n, -2], np.var(means)**0.5, true_values[n,-1], diff_gillespie_min[n, -1], np.var(variances)**0.5)

## Degenerate case
### Load the learned parameters

In [None]:
# Load the data from learning progress file
file_path = '2state_model_with_deg.txt'  

data = np.loadtxt(file_path) 

# Initialize a dictionary to store the parameter values for the minimum loss
min_rows = {}

# Iterate through each row in the data
for row in data:
    index = int(row[0])
    # Update the minimum row for the index as needed
    if index not in min_rows or row[-1] < min_rows[index][-1]:
        min_rows[index] = row

# Extract the parameter values for minimum loss
diff_gillespie_min = [row[[2, 3, 4, 5,6]] for row in min_rows.values()]
diff_gillespie_min = np.array(diff_gillespie_min)

### Obtain the error bars


In [None]:
def write_to_file(filename, *args):
    with open(filename, 'a') as file:
        file.write(' '.join(map(str, args)) + '\n')
 
# File to store errors
filename = "errors_on_mean_and_std_with_deg.txt"
if os.path.exists(filename):
    os.remove(filename)

# Load true parameter values
true_values=np.load("random_rates.npy")

# Hyperparameters
num_simulations=50
a_inv=200.0
b_inv=20.0
c=20.0
sim_time=10.0

# Set random seed for reproducibility
torch.manual_seed(42)

# Do forward Gillespie for 100 different random seeds
for n in range (20):
    means=[]
    variances=[]
    for i in range (100):
        a,b= gillespie_simulation(torch.tensor([diff_gillespie_min[n, 0]]), torch.tensor([diff_gillespie_min[n, 1]]), 
                                  torch.tensor([diff_gillespie_min[n,2]]), num_simulations, sim_time, a_inv, b_inv, c)
        means.append(a.item())
        variances.append((b.item())**0.5)
        
    # Write the errors on mean and std deviation to file
    write_to_file(filename, true_values[n,-2], diff_gillespie_min[n, -2], np.var(means)**0.5, true_values[n,-1], diff_gillespie_min[n, -1], np.var(variances)**0.5)