## Import packages

In [1]:
import numpy as np
import torch
import os
from scipy.stats import entropy

## Exact Gillespie simulation

In [2]:
# Define the stoichiometry matrix for the reactions
stoic_matrix = torch.tensor([[2.0, 0.0],    # Reaction 1: Promoter state goes from -1 to +1
                             [0.0, 1.0],    # Reaction 2: mRNA is produced
                             [-2.0, 0.0],   # Reaction 3: Promoter state goes from +1 to -1
                             [0.0, -1.0]])  # Reaction 4: Degradation of mRNA

def gillespie_simulation_exact(kon, r, g, num_simulations, sim_time):
    """
    Perform exact Gillespie simulation for a 2-state promoter model.

    Each trajectory stores the promoter state (-1 or +1) and the mRNA count.
    All ``num_simulations`` trajectories are advanced in parallel until every
    trajectory's clock has reached ``sim_time``. The global torch RNG is
    re-seeded with 42 on every call, so results are reproducible. This also
    resets the caller's RNG state.

    Arguments:
        kon: Rate of promoter switching from -1 to +1.
        r: Rate of mRNA production (active while the promoter is in state -1).
        g: Rate of mRNA degradation per mRNA molecule.
        num_simulations: Number of independent trajectories to run.
        sim_time: Simulation end time.

    Returns:
        mean_final_state: Mean of the mRNA levels at the end of the simulation.
        variance: Population variance (E[x^2] - E[x]^2) of the final mRNA levels.
        levels: List of final mRNA levels, one per simulation.
    """
    # Seed for reproducibility
    torch.manual_seed(42)

    # Column 0: promoter state, initialized to -1. Column 1: mRNA level, initialized to 0.
    levels = torch.stack([
        torch.full((num_simulations,), -1.0),
        torch.zeros(num_simulations),
    ], dim=1)
    current_time = torch.zeros(num_simulations)
    zero = torch.tensor(0.0)  # heaviside value at exactly 0; states are +-1, so never hit

    # Main simulation loop
    while True:
        # Only trajectories whose clock has not yet reached sim_time are advanced
        active = current_time < sim_time
        if not active.any():
            break
        n_active = int(active.sum())

        in_minus = torch.heaviside(-levels[:, 0], zero)  # 1.0 where promoter state is -1
        in_plus = torch.heaviside(levels[:, 0], zero)    # 1.0 where promoter state is +1
        propensities = torch.stack([
            kon * in_minus,    # Promoter state switching from -1 to +1
            r * in_minus,      # mRNA production
            in_plus,           # Promoter state switching from +1 to -1 (rate fixed at 1)
            g * levels[:, 1],  # mRNA degradation
        ], dim=1)[active]

        cumulative_propensity = propensities.cumsum(dim=1)
        # Take the total from the last cumulative entry. A separately computed
        # sum() can differ from it by rounding. If the selection threshold then
        # exceeded every cumulative entry, argmax over the all-False row would
        # silently return reaction 0.
        total_propensity = cumulative_propensity[:, -1]

        # Exponentially distributed waiting time until the next reaction
        dt = -torch.log(torch.rand(n_active)) / total_propensity
        current_time[active] += dt

        # Pick the first reaction whose cumulative propensity exceeds a uniform
        # draw on [0, total); argmax returns the first maximal (True) index.
        random_nums = torch.rand(n_active) * total_propensity
        reactions = (cumulative_propensity > random_nums.unsqueeze(1)).int().argmax(dim=1)

        # NOTE(review): the reaction that carries a trajectory past sim_time is
        # still applied, as in gillespie_simulation_diff. Confirm this is intended.
        levels[active] += stoic_matrix[reactions]

    # Statistics of the final mRNA levels
    mrna = levels[:, 1]
    mean_final_state = mrna.mean()
    variance = (mrna ** 2).mean() - mean_final_state ** 2

    return mean_final_state.item(), variance.item(), mrna.tolist()


## Differentiable Gillespie simulation

In [3]:
# Define the stoichiometry matrix for the reactions
stoic_matrix = torch.tensor([[2.0, 0.0],    # Reaction 1: Promoter state goes from -1 to +1
                             [0.0, 1.0],    # Reaction 2: mRNA is produced
                             [-2.0, 0.0],   # Reaction 3: Promoter state goes from +1 to -1
                             [0.0, -1.0]])  # Reaction 4: Degradation of mRNA


def _module_default(name, value):
    """Return ``value``, or the module-level global ``name`` when ``value`` is None.

    Earlier versions of the helpers below read module-level ``a_inv``/``b_inv``
    globals instead of taking them as arguments; this keeps such callers working.
    """
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        raise NameError(f"name '{name}' is not defined; pass it explicitly") from None


# Define a function to compute the state jump
def state_jump(reaction_indices, stoic_matrix, b_inv=None):
    """
    Calculate state jump vectors from (soft) reaction indices and the stoichiometry matrix, where
    state vector -> state vector + state jump vector.

    Each reaction row k is weighted by exp(-b_inv * (index - k)**2), so the jump is a smooth
    blend of stoichiometry rows. It approaches the single row of the nearest integer index as
    b_inv grows.

    Arguments:
        reaction_indices: 1-D tensor of selected (possibly non-integer) reaction indices
        stoic_matrix: Stoichiometry matrix, one row per reaction
        b_inv: Sharpness of the weighting. If omitted, the module-level ``b_inv`` is used.

    Returns:
        State jump vectors, one row per entry of reaction_indices
    """
    b_inv = _module_default("b_inv", b_inv)
    # (N, 1) - (1, R) broadcasts to the (N, R) matrix of index distances
    labels = torch.arange(stoic_matrix.shape[0]).unsqueeze(0)
    weights = torch.exp(-b_inv * (reaction_indices.unsqueeze(1) - labels) ** 2)
    return torch.matmul(weights, stoic_matrix)


# Define a function to select the reaction based on reaction selection thresholds
def reaction_selection(breaks, random_nums, a_inv=None):
    """
    Select reactions from the transition points and random numbers. The transition points are
    the cumulative sums of rates divided by the total rate.

    The result is a sum of sigmoids, which counts (smoothly) how many transition points lie
    below each random number. It approaches the hard reaction index as a_inv grows.

    Arguments:
        breaks: Transition points in [0, 1], one row per simulation
        random_nums: Random numbers in [0, 1], one per simulation
        a_inv: Sharpness of the sigmoids. If omitted, the module-level ``a_inv`` is used.

    Returns:
        (Soft) index of the next reaction for each simulation
    """
    a_inv = _module_default("a_inv", a_inv)
    return torch.sum(torch.sigmoid(a_inv * (random_nums.unsqueeze(1) - breaks)), dim=1)


def gillespie_simulation_diff(kon, r, g, num_simulations, sim_time, a_inv, b_inv, c):
    """
    Perform differentiable Gillespie simulation for a 2-state promoter model.

    Hard steps in the exact algorithm are replaced by smooth functions. The propensity gates
    use sigmoid(+-c * promoter_state), reaction selection uses sigmoids sharpened by a_inv,
    and state jumps use a Gaussian-shaped weighting sharpened by b_inv.

    Arguments:
        kon: Rate of promoter switching from -1 to +1.
        r: Rate of mRNA production.
        g: Rate of mRNA degradation.
        num_simulations: Number of simulations to run.
        sim_time: Simulation time.
        a_inv: Inverse parameter for reaction selection.
        b_inv: Inverse parameter for state jump calculation.
        c: Sigmoid parameter for propensities.

    Returns:
        mean_final_state: Mean of the mRNA levels at the end of the simulation.
        variance: Population variance of the mRNA levels at the end of the simulation.
        levels: List of mRNA levels from each simulation.
    """
    # Re-seed from a seed drawn off the current global RNG. Runs are therefore
    # reproducible only if the caller seeded torch beforehand.
    random_seed = torch.randint(1, 10000000, (1,))
    torch.manual_seed(random_seed)

    # Column 0: promoter state (starts at -1). Column 1: mRNA level (starts at 0).
    levels = torch.stack([torch.full((num_simulations,), -1.0), torch.zeros(num_simulations)], dim=1)
    current_time = torch.zeros(num_simulations)

    # Main simulation loop
    while True:
        # Mask to keep track of active simulations
        active = current_time < sim_time
        if not active.any():
            break

        # Smooth gate that is ~1 when the promoter state is near -1
        minus_gate = torch.sigmoid(-c * levels[:, 0])
        propensities = torch.stack([
            kon * minus_gate,                 # Promoter state switching from -1 to +1
            r * minus_gate,                   # mRNA production
            torch.sigmoid(c * levels[:, 0]),  # Promoter state switching from +1 to -1
            g * levels[:, 1],                 # mRNA degradation
        ], dim=1)
        total_propensity = propensities.sum(dim=1)

        # Time until the next reaction (drawn for all simulations, applied to active ones)
        dt = -torch.log(torch.rand(num_simulations)) / total_propensity
        current_time[active] += dt[active]

        # Reaction selection thresholds in [0, 1]
        breaks = propensities[active, :-1].cumsum(dim=1) / total_propensity[active].unsqueeze(1)
        random_nums = torch.rand(active.sum())

        # Use this call's a_inv/b_inv, not whatever module-level globals happen to exist
        reaction_index = reaction_selection(breaks, random_nums, a_inv=a_inv)
        stoic_updates = state_jump(reaction_index, stoic_matrix, b_inv=b_inv)

        # Apply the state jumps and keep mRNA levels non-negative
        levels[active] += stoic_updates
        levels[:, 1] = torch.relu(levels[:, 1])

    # Statistics of the final mRNA levels
    mrna = levels[:, 1]
    mean_final_state = mrna.mean()
    variance = (mrna ** 2).mean() - mean_final_state ** 2

    return mean_final_state.item(), variance.item(), mrna.tolist()

## Compute the relative Jensen–Shannon divergence (JSD) as a function of a_inv = 1/a

In [4]:
# Define a function to write data to a file
def write_to_file(filename, *args):
    """Append one line to ``filename``: each of ``args`` converted with ``str``, separated by spaces."""
    with open(filename, 'a') as handle:
        print(*args, file=handle)

# Define the filename to write data
filename = "a_inv_vs_jsd.txt"

# Start from an empty results file, since write_to_file appends
if os.path.exists(filename):
    os.remove(filename)

# Set random seed for reproducibility.
# NOTE: gillespie_simulation_exact re-seeds torch with 42 on entry, so this seed
# currently has no effect on the results below.
torch.manual_seed(40)

# Define simulation hyperparameters
num_simulations = 5000
sim_time = 10.0
b_inv = 20.0
c = 20.0

# Define rate constants
kon = torch.tensor([0.5])
r = torch.tensor([10.0])
g = torch.tensor([1.0])

# Define a list of a_inv values
a_inv_list = np.logspace(np.log10(0.1), np.log10(10000), 50)

# Forward exact Gillespie simulation (reference distribution)
a1, b1, c1 = gillespie_simulation_exact(kon, r, g, num_simulations, sim_time)

# Loop through each value of a_inv
for a_inv in a_inv_list:

    # Forward differentiable Gillespie simulation
    a2, b2, c2 = gillespie_simulation_diff(kon, r, g, num_simulations, sim_time, a_inv, b_inv, c)

    # Convert samples to histograms with bin size equal to 1
    bin_edges = np.arange(np.floor(min([min(c1), min(c2)])) - 0.5, np.ceil(max([max(c1), max(c2)])) + 1.5, 1)
    hist_c1, _ = np.histogram(c1, bins=bin_edges, density=True)
    hist_c2, _ = np.histogram(c2, bins=bin_edges, density=True)

    # Jensen–Shannon divergence: mean KL divergence of each distribution from their mixture.
    # scipy.stats.entropy(p, q) on its own is the (asymmetric, possibly infinite) KL divergence.
    # The mixture is nonzero wherever either histogram is, so no epsilon padding is needed.
    mixture = 0.5 * (hist_c1 + hist_c2)
    jsd_value = 0.5 * (entropy(hist_c1, mixture) + entropy(hist_c2, mixture))

    # Relative JSD: normalize by the Shannon entropy of the exact distribution
    rel_jsd_value = jsd_value / entropy(hist_c1)

    # Save relative Jensen–Shannon divergence value to a file
    write_to_file(filename, a_inv, rel_jsd_value)