## Import packages

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import os
import random
import torch.optim as optim

## Desired response curve

In [None]:
filename = "loop_model_forward_exact_curve.txt"

# Results are appended row by row further down, so any file left over from a
# previous run is deleted first (a missing file is simply ignored).
try:
    os.remove(filename)
except FileNotFoundError:
    pass

# Function to write data to a file
def write_to_file(filename, *args):
    """
    Append a single line to ``filename``.

    Each positional argument is converted with str(); the values are joined by
    one space and the line ends with a newline. The file is created if needed.
    """
    with open(filename, 'a') as handle:
        print(*args, sep=' ', end='\n', file=handle)

# Function to perform exact Gillespie simulation for the dorsal concentration c 
def gillespie_simulation(c, num_simulations=600, sim_time=20.0):
    """
    Exact (non-differentiable) Gillespie simulation of the 4-state loop
    promoter model at dorsal concentration ``c``.

    Each trajectory starts in state 0 and hops to state + 1 or state - 1
    (modulo 4) with the rates returned by ``k_for`` / ``k_back`` until its
    clock reaches ``sim_time``. Waiting time spent in states 2 or 3 counts as
    active time; a trajectory's mRNA production rate is active_time divided by
    its total time (which includes the final interval that overshoots
    ``sim_time``).

    NOTE(review): ``k_for``/``k_back`` are looked up at call time; a later cell
    in this notebook redefines them with a different signature, so this cell
    must not be re-run after that one.

    Arguments:
        c: Dorsal concentration, passed to k_for / k_back.
        num_simulations: Number of independent trajectories to average.
        sim_time: Simulated time per trajectory.

    Returns:
        0-d tensor: mean mRNA production rate over all trajectories
        (nan when num_simulations is 0).
    """
    mrna_prod_rates = []

    for _ in range(num_simulations):
        state = 0                         # current promoter state (0..3)
        current_time = torch.tensor(0.0)  # elapsed simulated time
        active_time = torch.tensor(0.0)   # time spent in states 2 or 3

        while current_time < sim_time:
            forward_rate = k_for(c, state)    # rate of state -> state + 1
            backward_rate = k_back(c, state)  # rate of state -> state - 1
            total_rate = forward_rate + backward_rate

            # Exponentially distributed waiting time with rate total_rate
            dt = (1/total_rate) * torch.log(1/torch.rand(1))
            current_time = current_time + dt

            # The waiting time is spent in the state occupied *before* the
            # jump. Only that state is needed, so no per-step history tensor
            # (and no O(n^2) re-concatenation) is kept.
            if state in (2, 3):
                active_time = active_time + dt

            # Move forward or backward with probability proportional to the rates
            if torch.rand(1) < forward_rate / total_rate:
                state = (state + 1) % 4
            else:
                state = (state - 1) % 4

        mrna_prod_rates.append((active_time / current_time).view(1))

    if not mrna_prod_rates:
        return torch.mean(torch.tensor([]))
    return torch.mean(torch.cat(mrna_prod_rates))

# Function to calculate forward rates
def k_for(c, state):
    """
    Forward transition rate (state -> state + 1, modulo 4) of the exact loop model.

    Reads the module-level parameters k_b, k_a, k_u, k_i, n_ab and n_ua.

    Arguments:
        c: Dorsal concentration; only the 0 -> 1 step depends on it.
        state: Current promoter state, one of 0, 1, 2, 3.

    Returns:
        The forward rate out of ``state``.

    Raises:
        ValueError: if ``state`` is not 0, 1, 2 or 3.
    """
    if state == 0:
        return c * k_b
    if state == 1:
        return n_ab * k_a
    if state == 2:
        return n_ua * k_u
    if state == 3:
        return k_i
    # Without this, an invalid state fell through and failed with an
    # UnboundLocalError on the return statement.
    raise ValueError(f"promoter state must be 0, 1, 2 or 3, got {state!r}")

# Function to calculate backward rates
def k_back(c, state):
    """
    Backward transition rate (state -> state - 1, modulo 4) of the exact loop model.

    Reads the module-level parameters k_a, k_u, k_i, k_b, n_ib and n_ba.

    Arguments:
        c: Dorsal concentration; only the 3 -> 2 step depends on it.
        state: Current promoter state, one of 0, 1, 2, 3.

    Returns:
        The backward rate out of ``state``.

    Raises:
        ValueError: if ``state`` is not 0, 1, 2 or 3.
    """
    if state == 0:
        return k_a
    if state == 1:
        return k_u
    if state == 2:
        return n_ib * k_i
    if state == 3:
        return c * n_ba * k_b
    # Without this, an invalid state fell through and failed with an
    # UnboundLocalError on the return statement.
    raise ValueError(f"promoter state must be 0, 1, 2 or 3, got {state!r}")

# Set the value of parameters
# Dorsal protein concentration range: 10 log-spaced points from 10 to 5000
c_range = torch.logspace(np.log10(10), np.log10(5000), 10)

# Ground-truth rate constants, read as globals by k_for / k_back
k_b, k_u, k_a, k_i = 0.02, 2.0, 0.3, 2.5
# Ground-truth multiplicative factors, also read by k_for / k_back
n_ib, n_ab, n_ba, n_ua = 0.3, 3.0, 1.5, 1.5

torch.manual_seed(42)
# One exact-simulation estimate per concentration, appended as a "c mean" row
for c in c_range:
    mean = gillespie_simulation(c)
    write_to_file(filename, c.item(), mean.item())

## Define differentiable Gillespie algorithm

In [None]:
# Stoichiometric matrix: one row per reaction, one column for the single
# tracked quantity (the promoter state index).
#   row 0 -> Reaction1: promoter state moves forward  (+1)
#   row 1 -> Reaction2: promoter state moves backward (-1)
stoic_matrix = torch.tensor([[1.0], [-1.0]])

# Function to calculate state jump
def state_jump(reaction_index, stoic_matrix):
    """
    Differentiable state jump for a (possibly non-integer) reaction index.

    Each row j of ``stoic_matrix`` is weighted by the Gaussian
    exp(-b_inv * (reaction_index - j)**2), and the weighted rows are summed.
    For an integer ``reaction_index`` and large ``b_inv`` this approaches
    ``stoic_matrix[reaction_index]``. The new state is state + returned vector.

    ``b_inv`` is read from the module-level global of that name.

    Arguments:
        reaction_index: Selected (soft) reaction index.
        stoic_matrix: Stoichiometry matrix, one row per reaction.

    Returns:
        State jump vector (one entry per column of ``stoic_matrix``).
    """
    row_ids = torch.arange(stoic_matrix.shape[0])
    offsets = reaction_index - row_ids
    weights = torch.exp(-b_inv * offsets ** 2).view(-1, 1)
    return (stoic_matrix * weights).sum(dim=0)

# Function to select reaction
def reaction_selection(breaks, random_num):
    """
    Differentiable reaction choice from a uniform random number.

    Every breakpoint contributes sigmoid(a_inv * (random_num - break)), which
    is close to 1 when ``random_num`` lies above that breakpoint and close to
    0 otherwise. The sum is therefore a smooth count of the breakpoints below
    ``random_num``, i.e. a soft index of the selected reaction.

    ``a_inv`` is read from the module-level global of that name.

    Arguments:
        breaks: Transition points in [0, 1] (cumulative rate fractions).
        random_num: Random number in [0, 1].

    Returns:
        0-d tensor: soft index of the next reaction.
    """
    soft_steps = torch.sigmoid(a_inv * (random_num - breaks))
    return soft_steps.sum()

# Function for Gillespie simulation
def gillespie_simulation(k_b, k_u, k_a, k_i, n_ib, n_ab, n_ba, n_ua, num_simulations, sim_time, a_inv, b_inv, cc):
    """
    Perform differentiable Gillespie simulation for the 4-state loop promoter model.

    For every concentration c in the module-level ``c_range``, run
    ``num_simulations`` trajectories starting in state 0. Reaction choice and
    the state update use sigmoid relaxations, so gradients flow back to the
    rate parameters. A trajectory's mRNA production rate is its (soft) time
    spent in states 2 or 3 divided by its total simulated time. As in the
    exact simulator, the total time includes the final waiting interval that
    overshoots ``sim_time``.

    Arguments:
        k_b, k_u, k_a, k_i, n_ib, n_ab, n_ba, n_ua: Rates involved in the loop model.
        num_simulations: Number of simulations per concentration.
        sim_time: Simulation time.
        a_inv: Inverse parameter for reaction selection.
        b_inv: Inverse parameter for state jump calculation.
        cc: Sigmoid slope parameter for propensities (rates).

    NOTE(review): a_inv, b_inv and cc are not forwarded. reaction_selection,
    state_jump, k_for and k_back read module-level globals of the same names,
    so those globals must hold the intended values when this is called.

    Returns:
        mean_values: 1-D tensor with the mean mRNA production rate for each
        entry of c_range.
    """
    # Re-seed from a draw of the current generator (deterministic given the caller's seed)
    random_seed = torch.randint(1, 10000000, (1,))
    torch.manual_seed(random_seed)
    mean_values = []  # one 1-element tensor per concentration c

    for c in c_range:
        mrna_prod_rates = []  # one 1-element tensor per trajectory
        for _ in range(num_simulations):
            state = torch.tensor(0.0)     # current (soft) promoter state
            # Tensor from the start: if the very first waiting time already
            # exceeds sim_time, a Python float here would break torch.cat below.
            active_time = torch.zeros(1)  # soft time spent in state 2 or 3
            current_time = 0.0            # total time (detached from the graph)

            while current_time < sim_time:
                # Reaction propensities (rates) out of the current state
                forward_rate = k_for(c, k_b, k_u, k_a, k_i, n_ab, n_ba, n_ua, state)
                backward_rate = k_back(c, k_b, k_u, k_a, k_i, n_ib, n_ba, state)
                propensities = torch.stack([forward_rate, backward_rate])
                total_propensity = propensities.sum()

                # Exponentially distributed waiting time
                dt = -torch.log(torch.rand(1)) / total_propensity
                current_time += dt.item()

                # dt is spent in the current (pre-jump) state; the sigmoid is a
                # soft indicator of state > 1.5. It is added before the
                # end-of-time check so that the final, overshooting interval is
                # counted in both numerator and denominator. Previously it was
                # counted only in current_time, which biased the ratio low.
                active_time = active_time + dt * torch.sigmoid(200 * (state - 1.5))

                if current_time >= sim_time:
                    break

                # Soft reaction selection and state update
                breaks = (propensities[:-1] / total_propensity).cumsum(dim=0)  # Transition points in [0,1]
                reaction_index = reaction_selection(breaks, torch.rand(1))
                state = state + state_jump(reaction_index, stoic_matrix)
                state = state % 4  # Implement periodicity
                state = torch.relu(state) - torch.relu(state - 3) - 3.0 * torch.sigmoid(200 * (state - 3.5))

            # mRNA production rate of this trajectory
            mrna_prod_rates.append(active_time / current_time)

        rates = torch.cat(mrna_prod_rates) if mrna_prod_rates else torch.tensor([])
        mean_values.append(torch.mean(rates).view(1))

    return torch.cat(mean_values) if mean_values else torch.tensor([])

# Function for forward rates
def k_for(c, k_b, k_u, k_a, k_i, n_ab, n_ba, n_ua, state):
    """
    Smooth forward rate for a continuous promoter state.

    Sigmoid steps centred at 0.5, 1.5 and 2.5 (slope: module-level global
    ``cc``) interpolate between the per-state forward rates c*k_b (state 0),
    n_ab*k_a (state 1), n_ua*k_u (state 2) and k_i (state 3). ``n_ba`` is
    accepted but not used.
    """
    step_01 = torch.sigmoid(cc * (state - 0.5))
    step_12 = torch.sigmoid(cc * (state - 1.5))
    step_23 = torch.sigmoid(cc * (state - 2.5))
    return ((n_ab * k_a - c * k_b) * step_01
            + (n_ua * k_u - n_ab * k_a) * step_12
            + (k_i - n_ua * k_u) * step_23
            + (c * k_b))

# Function for backward rates
def k_back(c, k_b, k_u, k_a, k_i, n_ib, n_ba, state):
    """
    Smooth backward rate for a continuous promoter state.

    Sigmoid steps centred at 0.5, 1.5 and 2.5 (slope: module-level global
    ``cc``) interpolate between the per-state backward rates k_a (state 0),
    k_u (state 1), n_ib*k_i (state 2) and c*n_ba*k_b (state 3).
    """
    step_01 = torch.sigmoid(cc * (state - 0.5))
    step_12 = torch.sigmoid(cc * (state - 1.5))
    step_23 = torch.sigmoid(cc * (state - 2.5))
    return ((k_u - k_a) * step_01
            + (n_ib * k_i - k_u) * step_12
            + (c * n_ba * k_b - n_ib * k_i) * step_23
            + (k_a))

# Define the loss function 
def loss_function(mean_values, target_mean_values):
    """
    Mean squared error between simulated and target mean mRNA production rates.

    Arguments:
        mean_values: Simulated values (tensor).
        target_mean_values: Target values, broadcast-compatible with mean_values.

    Returns:
        0-d tensor: average of the squared differences.
    """
    residuals = mean_values - target_mean_values
    return residuals.pow(2).mean()

# Function to write results to file
def write_to_file(filename, *args):
    """
    Append one line to ``filename`` made of str() of every argument,
    separated by single spaces and terminated by a newline.
    """
    line = ' '.join(str(value) for value in args)
    with open(filename, 'a') as out:
        out.write(line + '\n')

## Gradient descent on the 7-dimensional parameter space 

In [None]:
# Load the desired mean mRNA production rates from a file
data = torch.tensor(np.loadtxt("loop_model_forward_exact_curve.txt"))
c_values = data[:, 0] # Dorsal concentration (unscaled, as written by the exact simulator)
target_mean_values = data[:, 1] # mean mRNA production rate

filename = "loop_model_learning_results.txt" # File to store the learning progress

# Results are appended each iteration, so start from an empty file
if os.path.exists(filename):
    os.remove(filename)

# Seed for reproducibility of the random initial guesses
random.seed(1)

# Initializing the parameter values. k_b is held fixed (no gradient);
# the other seven parameters are learned.
k_b_tensor = torch.tensor(2.0)
k_u_tensor = torch.tensor(random.uniform(0.1, 10.0), requires_grad=True)
k_a_tensor = torch.tensor(random.uniform(0.1, 1.0), requires_grad=True)
k_i_tensor = torch.tensor(random.uniform(0.1, 10.0), requires_grad=True)
n_ib_tensor = torch.tensor(random.uniform(0.1, 1.0), requires_grad=True)
n_ab_tensor = torch.tensor(random.uniform(1.0, 10.0), requires_grad=True)
n_ba_tensor = torch.tensor(random.uniform(0.1, 10.0), requires_grad=True)
n_ua_tensor = torch.tensor(random.uniform(0.1, 10.0), requires_grad=True)

# Define the optimizer
optimizer = optim.Adam([k_u_tensor, k_a_tensor, k_i_tensor, n_ib_tensor, n_ab_tensor,
                        n_ba_tensor, n_ua_tensor], lr=0.1)

# Lower bounds re-imposed after every optimizer step
lower_bounds = [
    (k_u_tensor, 0.01),
    (k_a_tensor, 0.01),
    (k_i_tensor, 0.01),
    (n_ib_tensor, 0.01),
    (n_ab_tensor, 1.0),
    (n_ba_tensor, 0.01),
    (n_ua_tensor, 0.01),
]

# Hyperparameters. a_inv, b_inv and cc are also read as globals by
# reaction_selection, state_jump, k_for and k_back.
num_iterations = 2000
num_simulations = 20
sim_time = 20.0
a_inv = 200.0
b_inv = 20.0
cc = 20.0

# Dorsal protein concentration range, read as a global by gillespie_simulation.
# It must line up point-for-point with c_values from the target file. The 0.01
# rescaling of c is offset by k_b_tensor = 2.0: 0.01 * 2.0 equals the k_b = 0.02
# used to generate the target curve, so every c*k_b product is unchanged.
c_range = 0.01*torch.logspace(np.log10(10), np.log10(5000), 10)

torch.manual_seed(42)
for iteration in range(num_iterations):

    # Perform Gillespie simulation to compute mean values. The original call
    # omitted the five trailing required arguments and raised a TypeError.
    mean_values = gillespie_simulation(
        k_b_tensor, k_u_tensor, k_a_tensor, k_i_tensor, n_ib_tensor,
        n_ab_tensor, n_ba_tensor, n_ua_tensor,
        num_simulations, sim_time, a_inv, b_inv, cc,
    )

    # Compute the loss for the current iteration
    loss = loss_function(mean_values, target_mean_values)

    # Zero the gradients, backpropagate, and update the parameters
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Enforce valid parameter constraints (in place on .data, outside autograd)
    for param, lower in lower_bounds:
        param.data = torch.clamp(param.data, min=lower)

    # Write the results to the file after each iteration
    write_to_file(filename, iteration, k_b_tensor.item(), k_u_tensor.item(), k_a_tensor.item(), k_i_tensor.item(), n_ib_tensor.item(), n_ab_tensor.item(), n_ba_tensor.item(), n_ua_tensor.item(),  loss.item())
    print (k_u_tensor.item(), k_a_tensor.item(), k_i_tensor.item(), loss.item())