## Import packages

In [3]:
import numpy as np
import torch
import torch.optim as optim
import os

## Load the 5DL1 promoter data points (Science paper dataset)

In [4]:
# Read the 5DL1 promoter measurements (two columns, split below).
data = np.load("science_data_5DL1.npy")
x_data, y_data = data[:, 0], data[:, 1]

# np.unique returns the sorted distinct x values together with the index of
# the first occurrence of each; the y value at that index is kept.
unique_x, unique_indices = np.unique(x_data, return_index=True)
unique_y = y_data[unique_indices]
unique_data = np.column_stack((unique_x, unique_y))

# Fitting targets as double-precision tensors. Column 0 is used directly as
# the mean; the variance is column 0 multiplied by column 1 (so column 1 is
# presumably a Fano factor, variance / mean -- TODO confirm with the data source).
mean_data = torch.from_numpy(unique_data[:, 0]).double()
var_data = torch.from_numpy(unique_data[:, 1]).double().mul(mean_data)

## Differentiable Gillespie algorithm

In [5]:
# Stoichiometry: one row per reaction, columns are (promoter state, mRNA count).
# Promoter switches change the state by +/-2 because it toggles between -1 and +1.
stoic_matrix = torch.tensor(
    [
        [2.0, 0.0],   # R1: promoter state -1 -> +1
        [0.0, 1.0],   # R2: one mRNA produced
        [-2.0, 0.0],  # R3: promoter state +1 -> -1
        [0.0, -1.0],  # R4: one mRNA degraded
    ]
)

def state_jump(reaction_index, stoic_matrix, sharpness=None):
    """
    Return a smooth (differentiable) state-jump vector for a selected reaction.

    Indexing ``stoic_matrix[reaction_index]`` has no gradient with respect to a
    continuous ``reaction_index``. Instead, every row ``k`` of the stoichiometry
    matrix is weighted by ``exp(-sharpness * (reaction_index - k)**2)`` and the
    weighted rows are summed. For a large ``sharpness`` and a ``reaction_index``
    close to an integer ``k``, the result approaches row ``k``. The result is
    used as: state vector -> state vector + state jump vector.

    Arguments:
        reaction_index: Selected (possibly non-integer) reaction index.
        stoic_matrix: Stoichiometry matrix, one row per reaction.
        sharpness: Inverse width of the Gaussian weighting. When omitted, the
            module-level hyperparameter ``b_inv`` is used (previous behavior).

    Returns:
        State jump vector with one entry per column of ``stoic_matrix``.
    """
    if sharpness is None:
        sharpness = b_inv  # module-level hyperparameter defined in the optimisation cell
    rows = torch.arange(stoic_matrix.shape[0], dtype=stoic_matrix.dtype)
    weights = torch.exp(-sharpness * (reaction_index - rows) ** 2)
    # Broadcast one weight per row across all columns, then sum over reactions.
    return torch.sum(stoic_matrix * weights.view(-1, 1), dim=0)

def reaction_selection(breaks, random_num, sharpness=None):
    """
    Return a smooth (differentiable) index of the next reaction.

    Exact Gillespie selection counts how many transition points lie below the
    random number. Here each step ``random_num > breaks[i]`` is replaced by
    ``sigmoid(sharpness * (random_num - breaks[i]))`` and the results are summed.
    For a large ``sharpness`` this gives a value close to the integer index.

    Arguments:
        breaks: Transition points in [0, 1], i.e. the cumulative sum of rates
            divided by the total rate. The final point, which is always 1, is
            omitted by ``gillespie_simulation``.
        random_num: Random number in [0, 1].
        sharpness: Sigmoid steepness. When omitted, the module-level
            hyperparameter ``a_inv`` is used (previous behavior).

    Returns:
        Scalar tensor holding the (approximate) index of the next reaction.
    """
    if sharpness is None:
        sharpness = a_inv  # module-level hyperparameter defined in the optimisation cell
    return torch.sum(torch.sigmoid(sharpness * (random_num - breaks)))

def gillespie_simulation(poff_values, r, g):
    """
    Perform a differentiable Gillespie simulation for a 2-state promoter model.

    The first component of the system state is the promoter state (-1 or +1)
    and the second is the mRNA level. According to the propensities below:
      * mRNA is produced at rate ``r`` only while the promoter is (approximately)
        at -1, since ``sigmoid(-c * state)`` is about 1 there;
      * the -1 -> +1 switch has rate ``1/poff - 1`` and the +1 -> -1 switch has
        rate 1, so ``poff`` is the stationary occupancy of the -1 (mRNA-producing)
        state;
      * mRNA degrades at rate ``g`` per molecule.
    NOTE(review): earlier documentation calls ``poff`` the probability of the
    promoter being OFF, i.e. koff/(kon+koff). Here it is the occupancy of the
    producing state, so confirm the intended naming (e.g. "repressor off").

    Reaction choice and state jumps go through the smooth ``reaction_selection``
    and ``state_jump`` helpers, so gradients reach ``poff_values``, ``r`` and
    ``g``. Time is advanced with ``.item()`` and is not differentiated.

    Also reads the module-level ``num_simulations``, ``sim_time``, ``c`` and
    ``stoic_matrix``.

    Arguments:
        poff_values: 1-D tensor with one ``poff`` value per data point.
        r: Rate of mRNA production.
        g: Rate of mRNA degradation.

    Returns:
        mean_final_states: Mean mRNA level at ``sim_time``, one per entry of poff_values.
        variances: Variance of the mRNA level at ``sim_time``, one per entry of
            poff_values, clamped to be non-negative.
    """
    # Reseed from the current global RNG stream. This adds no reproducibility by
    # itself (the drawn seed is already fixed by the global seed); it is kept
    # only to preserve the existing random-number sequence.
    random_seed = torch.randint(1, 10000000, (1,))
    torch.manual_seed(random_seed)

    # Size outputs from the argument rather than the global data array.
    num_points = len(poff_values)
    mean_final_states = torch.empty(num_points)
    variances = torch.empty(num_points)

    for n in range(num_points):
        poff = poff_values[n].unsqueeze(0)

        final_states = 0.0
        final_states_squared = 0.0

        for _ in range(num_simulations):
            # Initial state: promoter at -1, no mRNA.
            levels = torch.stack([torch.tensor(-1.0), torch.tensor(0.0)])
            current_time = 0.0

            while current_time < sim_time:
                # Reaction propensities; the sigmoids gate reactions on the promoter state.
                propensities = torch.stack([(1/poff-1.0) * torch.sigmoid(-c*levels[0]),
                                            r * torch.sigmoid(-c*levels[0]),
                                            torch.tensor([1.0]) * torch.sigmoid(c * levels[0]),
                                            g * levels[1]])
                propensities = torch.relu(propensities)

                total_propensity = propensities.sum()

                # Exponential waiting time until the next reaction.
                dt = -torch.log(torch.rand(1)) / total_propensity
                current_time += dt.item()

                if current_time >= sim_time:
                    break

                # Smoothly select the reaction and apply its (smoothed) stoichiometry.
                breaks = (propensities[:-1] / total_propensity).cumsum(dim=0)
                reaction_index = reaction_selection(breaks, torch.rand(1))
                levels = levels + state_jump(reaction_index, stoic_matrix)
                levels[1] = torch.relu(levels[1])  # keep the mRNA level non-negative

            # Accumulate mRNA level and its square at the final time.
            final_states += levels[1]
            final_states_squared += levels[1] ** 2

        mean_final_state = final_states / num_simulations
        variance = final_states_squared / num_simulations - mean_final_state ** 2
        # E[x^2] - E[x]^2 can dip slightly below zero through floating-point cancellation.
        variance = torch.clamp(variance, min=0.0)

        mean_final_states[n] = mean_final_state
        variances[n] = variance

    return mean_final_states, variances

def loss_function(mean_final_states, variances, target_mean=None, target_var=None, eps=1e-12):
    """
    Squared-error loss between simulated and measured mRNA statistics.

    For each data point, the loss adds the squared error of the mean and the
    squared error of the standard deviation (square root of the variance).
    The result is averaged over data points.

    Arguments:
        mean_final_states: Simulated mean mRNA levels.
        variances: Simulated mRNA variances.
        target_mean: Measured means. Defaults to the module-level ``mean_data``.
        target_var: Measured variances. Defaults to the module-level ``var_data``.
        eps: Small constant added under the square root of the simulated
            variance so the value and its gradient stay finite at zero.

    Returns:
        Scalar loss tensor.
    """
    if target_mean is None:
        target_mean = mean_data
    if target_var is None:
        target_var = var_data
    # Simulated variances can be zero or slightly negative (E[x^2] - E[x]^2
    # cancellation). sqrt would then give NaN or an infinite gradient, so clamp
    # at zero and add a tiny eps.
    model_std = torch.sqrt(torch.clamp(variances, min=0.0) + eps)
    target_std = target_var ** 0.5
    return torch.mean((mean_final_states - target_mean) ** 2 + (model_std - target_std) ** 2)

# Append-only text logger used to record optimisation progress.
def write_to_file(filename, *args):
    """
    Append one line to ``filename``.

    The line holds ``str()`` of each argument, separated by single spaces and
    ended by a newline. The file is created if it does not exist.
    """
    with open(filename, 'a') as handle:
        print(*args, file=handle)


## Gradient descent


In [None]:
# Set seed for reproducibility
torch.manual_seed(42)

# Simulation / optimisation hyperparameters. These are module-level globals
# read by gillespie_simulation, reaction_selection and state_jump.
num_iterations = 1000
num_simulations = 200   # Gillespie trajectories per data point
sim_time = 0.2          # time at which the mRNA level is recorded
a_inv = 200.0           # sigmoid steepness used by reaction_selection
b_inv = 20.0            # Gaussian sharpness used by state_jump
c = 20.0                # sigmoid steepness gating propensities on the promoter state
save_every = 1          # log parameters and loss every `save_every` iterations

# Initial parameters: poff evenly spaced (one per data point); r and g random.
poff_values = torch.nn.Parameter(torch.linspace(0.03, 0.97, len(unique_data)))
r = torch.nn.Parameter((1e+2) * torch.rand(1))
g = torch.nn.Parameter((1e+1) * torch.rand(1))

# Adam over every trainable parameter.
optimizer = optim.Adam([poff_values, r, g], lr=0.1)

# Output files; write_to_file appends, so remove results from earlier runs.
filename1 = "learning_science_5DL1_poff.txt"
filename2 = "learning_science_5DL1.txt"
for stale_file in (filename1, filename2):
    if os.path.exists(stale_file):
        os.remove(stale_file)

# Main optimization loop
for iteration in range(num_iterations):

    # Forward differentiable Gillespie simulation and loss.
    mean_final_states, variances = gillespie_simulation(poff_values, r, g)
    loss = loss_function(mean_final_states, variances)

    # Backward pass with gradient clipping against exploding gradients.
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_([poff_values, r, g], max_norm=1.0)
    optimizer.step()

    # Project parameters back onto their valid ranges. Done in place under
    # no_grad so no autograd history is recorded for the projection.
    with torch.no_grad():
        r.clamp_(min=1.0)
        g.clamp_(min=1.0, max=r.item())
        poff_values.clamp_(min=0.01, max=0.98)
        # Keep poff ascending, presumably to match the ascending order of
        # mean_data (np.unique sorts the x values).
        poff_values.copy_(torch.sort(poff_values).values)

    # Save the parameters after the update. Note that `loss` was computed with
    # the parameters from before this update step.
    if iteration % save_every == 0:
        write_to_file(filename2, iteration, r.item(), g.item(), r.item() / g.item(), loss.item())
        write_to_file(filename1, poff_values.tolist(), loss.item())