## Import packages

In [1]:
import numpy as np
import torch
import torch.optim as optim
import os

## Define differentiable Gillespie algorithm

In [2]:
# Define the stoichiometry matrix for the reactions.
# Row k is the change applied to the state vector [promoter state, mRNA level]
# when reaction k fires.
stoic_matrix = torch.tensor([[2.0, 0.0],    # Reaction 1: Promoter state goes from -1 to +1
                             [0.0, 1.0],    # Reaction 2: mRNA is produced
                             [-2.0, 0.0],   # Reaction 3: Promoter state goes from +1 to -1
                             [0.0, -1.0]])  # Reaction 4: Degradation of mRNA


# Define a function to compute the state jump
def state_jump(reaction_index, stoic_matrix, sharpness=None):
    """
    Calculate a smoothed state jump vector for the selected reaction, where
    state vector -> state vector + state jump vector.

    Rather than indexing ``stoic_matrix`` with an integer (which would block
    gradients), every row k is weighted by
    exp(-sharpness * (reaction_index - k)**2) and the weighted rows are summed.
    For a reaction_index close to an integer k and a large sharpness, the
    result approaches row k of the matrix.

    Arguments:
        reaction_index: Selected (soft) reaction index.
        stoic_matrix: Stoichiometry matrix, one row per reaction.
        sharpness: Inverse width of the Gaussian weighting (``b_inv``). When
            None, the module-level ``b_inv`` is used (original behaviour).

    Returns:
        State jump vector with one entry per column of ``stoic_matrix``.
    """
    inv_width = b_inv if sharpness is None else sharpness
    offsets = reaction_index - torch.arange(stoic_matrix.shape[0])
    weights = torch.exp(-inv_width * offsets ** 2).view(-1, 1)
    return torch.sum(stoic_matrix * weights, dim=0)


# Define a function to select the reaction based on reaction selection thresholds
def reaction_selection(breaks, random_num, sharpness=None):
    """
    Select a (soft) reaction index from transition points and a random number.

    The transition points are the cumulative sums of the propensities divided
    by the total propensity. The returned index is a smooth count of how many
    transition points lie below ``random_num``: each comparison is replaced by
    sigmoid(sharpness * (random_num - break)).

    Arguments:
        breaks: Transition points in [0, 1].
        random_num: Random number in [0, 1].
        sharpness: Sigmoid steepness (``a_inv``). When None, the module-level
            ``a_inv`` is used (original behaviour).

    Returns:
        Soft index of the next reaction.
    """
    slope = a_inv if sharpness is None else sharpness
    return torch.sum(torch.sigmoid(slope * (random_num - breaks)))


# Define the Gillespie simulation function
def gillespie_simulation(poff, r, g, num_simulations, sim_time, a_inv, b_inv, c):
    """
    Perform differentiable Gillespie simulation for a 2-state promoter model.

    State vector: [promoter state (about -1 or +1), mRNA level]. The promoter
    switches -1 -> +1 at rate 1/poff - 1 and +1 -> -1 at rate 1. mRNA is
    produced at rate r and degraded at rate g * mRNA. The switching out of -1
    and the production propensities are gated by sigmoid(-c * promoter), i.e.
    they are active near the -1 state.
    NOTE(review): production is thus active in the state labelled OFF below;
    confirm the naming against the model definition.

    Arguments:
        poff: Promoter probability to be in OFF (-1) state. poff=koff/(koff+kon).
        r: Rate of mRNA production.
        g: Rate of mRNA degradation.
        num_simulations: Number of simulations to run.
        sim_time: Simulation time.
        a_inv: Inverse parameter for reaction selection (forwarded to
            ``reaction_selection``).
        b_inv: Inverse parameter for state jump calculation (forwarded to
            ``state_jump``).
        c: Sigmoid slope parameter for propensities.

    Returns:
        mean_final_state: Mean of the mRNA levels at the end of the simulation.
        variance: Variance of the mRNA levels at the end of the simulation,
            clamped at zero.
    """
    # Re-seed the global torch generator with a value drawn from it. The drawn
    # seed depends on the generator's current state, so runs are deterministic
    # only if the caller seeded torch beforehand.
    random_seed = torch.randint(1, 10000000, (1,))
    torch.manual_seed(random_seed)
    final_states = 0.0
    final_states_squared = 0.0

    for _ in range(num_simulations):
        # Start every trajectory with the promoter at -1 and no mRNA.
        levels = torch.stack([torch.tensor(-1.0), torch.tensor(0.0)])
        current_time = 0.0

        while current_time < sim_time:
            # Sigmoids act as smooth indicators of the promoter state:
            # sigmoid(-c*x) ~ 1 near x = -1, sigmoid(c*x) ~ 1 near x = +1.
            propensities = torch.stack([
                (1 / poff - 1.0) * torch.sigmoid(-c * levels[0]),       # promoter -1 -> +1
                r * torch.sigmoid(-c * levels[0]),                      # mRNA production
                torch.tensor([1.0]) * torch.sigmoid(c * levels[0]),     # promoter +1 -> -1
                g * levels[1],                                          # mRNA degradation
            ])
            total_propensity = propensities.sum()

            # Exponential waiting time; the clock is kept out of the graph.
            dt = -torch.log(torch.rand(1)) / total_propensity
            current_time += dt.item()
            if current_time >= sim_time:
                break

            # Soft reaction choice and smoothed jump, using THIS call's
            # a_inv / b_inv rather than module-level globals.
            breaks = (propensities[:-1] / total_propensity).cumsum(dim=0)
            reaction_index = reaction_selection(breaks, torch.rand(1), sharpness=a_inv)
            levels = levels + state_jump(reaction_index, stoic_matrix, sharpness=b_inv)
            # Keep the mRNA level non-negative; built out of place instead of
            # writing in place into a tensor that is part of the autograd graph.
            levels = torch.stack([levels[0], torch.relu(levels[1])])

        # Accumulate final states after each simulation
        final_states += levels[1]
        final_states_squared += levels[1] ** 2

    mean_final_state = final_states / num_simulations
    # E[x^2] - E[x]^2 can dip slightly below zero through float round-off,
    # which would turn a later square root into NaN.
    variance = torch.clamp(final_states_squared / num_simulations - mean_final_state ** 2, min=0.0)

    return mean_final_state, variance

# Define a function to write data to a file
def write_to_file(filename, *args):
    """
    Append a single line to ``filename``.

    Each argument is converted with str(); the pieces are separated by one
    space and the line ends with a newline. The file is created if missing.
    """
    with open(filename, 'a') as handle:
        print(*args, sep=' ', end='\n', file=handle)
        
# Define the loss function
def loss_function(mean_final_state, variance, target_mean, target_std):
    """
    Squared error of the simulated mean plus squared error of the simulated
    standard deviation, both measured against the target data.

    Arguments:
        mean_final_state: Simulated mean mRNA level.
        variance: Simulated variance of the mRNA level.
        target_mean: Target mean.
        target_std: Target standard deviation.

    Returns:
        (mean - target_mean)**2 + (sqrt(variance) - target_std)**2
    """
    # A variance computed as E[x^2] - E[x]^2 can be a tiny negative number
    # from round-off; its square root would be NaN (tensor) or complex
    # (float), so clamp it at zero first.
    if isinstance(variance, torch.Tensor):
        variance = torch.clamp(variance, min=0.0)
    elif isinstance(variance, (int, float)) and variance < 0:
        variance = 0.0
    std = variance ** 0.5
    return (mean_final_state - target_mean) ** 2 + (std - target_std) ** 2

## Gradient descent

In [None]:
# Set seed for reproducibility
torch.manual_seed(40)

# Define simulation hyperparameters
num_simulations = 50   # trajectories per loss evaluation
sim_time = 10.0        # simulated time horizon of each trajectory
a_inv = 200.0          # inverse parameter for the soft reaction selection
b_inv = 20.0           # inverse parameter for the soft state jump
c = 20.0               # sigmoid slope gating propensities on the promoter state
num_iterations = 130   # gradient-descent steps per sample

# Load the sample random rates from a file. Column 0 is used as kon (see the
# poff initialisation), 1 and 2 as reference values for r and g, and 3 / 4 as
# the target mean and standard deviation.
sample = np.load("random_rates.npy")

# Define the filename to write data
filename = "2state_model_with_deg.txt"

# Remove the file if it already exists (write_to_file appends)
if os.path.exists(filename):
    os.remove(filename)

# Loop through each sample rate set
for i in range(sample.shape[0]):

    # Random initial guesses: each value is the sample value scaled by a
    # log-uniform factor 10**u with u in [-1, 1).
    poff = torch.nn.Parameter((1.0 / (1.0 + sample[i, 0])) * (10 ** (-1 + 2 * torch.rand(1))))  # poff = 1/(1+kon)
    # Redraw r and g until g < r.
    # NOTE(review): g >= 0.1*sample[i, 2] and r < 10*sample[i, 1], so this loop
    # never terminates if sample[i, 2] >= 100 * sample[i, 1].
    while True:
        r = torch.nn.Parameter(sample[i, 1] * (10 ** (-1 + 2 * torch.rand(1))))
        g = torch.nn.Parameter(sample[i, 2] * (10 ** (-1 + 2 * torch.rand(1))))
        if g < r:
            break

    # Define the Adam optimizer
    optimizer = optim.Adam([poff, r, g], lr=0.1)

    # Set target mean and standard deviation
    target_mean = sample[i, 3]
    target_std = sample[i, 4]

    for iteration in range(num_iterations):

        # Forward differentiable Gillespie simulation
        mean_final_state, variance = gillespie_simulation(poff, r, g, num_simulations, sim_time, a_inv, b_inv, c)

        # Compute the loss for the current iteration
        loss = loss_function(mean_final_state, variance, target_mean, target_std)

        # Backward pass with gradient clipping, then an optimizer step
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_([poff, r, g], max_norm=0.2)
        optimizer.step()

        # Project the parameters back into their allowed ranges, in place and
        # outside autograd (instead of reassigning `.data`). The poff bounds
        # correspond to kon = 1/poff - 1 in [0.1, 100]; g is kept <= r.
        with torch.no_grad():
            poff.clamp_(min=1 / (1 + 100), max=1 / (1 + 0.1))
            r.clamp_(min=0.1, max=100.0)
            g.clamp_(min=0.1, max=r.item())

        # Log every iteration: sample index, iteration, kon, r, g, simulated
        # mean and std, poff*r/g (presumably the analytical mean — confirm),
        # and the loss.
        write_to_file(filename, i, iteration, (1 / poff.item()) - 1.0, r.item(), g.item(),
                      mean_final_state.item(), (variance ** 0.5).item(),
                      (poff * r / g).item(), loss.item())
