# Intro to NNCME with a Gene Expression Example - Simple Version

This is a simplified step-by-step guide to implementing a gene expression example with NNCME. \
By following this notebook, you can use the VAN to obtain data for the gene expression system, including the joint distribution of the species. \
We also provide optional code for generating Gillespie trajectories and plotting the results. \
A more detailed guide is available in `Detailed Gene Expression Example.ipynb`.

## 1. Gene Expression System

The gene expression system involves two species: mRNA ($r$) and protein ($p$).

<table>
<td> 
<img src="https://github.com/jiadeyu0602/CheatSheet/raw/master/GE_panela.svg" width="120"/> <br>
</td> 
</table>

The model can be written as the chemical reactions:
\begin{align}
\begin{split}
DNA\stackrel{k_{r}}{\longrightarrow}r,\quad
r\stackrel{k_{p}}{\longrightarrow}r+p,\quad
r\stackrel{\gamma_{r}}{\longrightarrow}\emptyset,\quad
p\stackrel{\gamma_{p}}{\longrightarrow}\emptyset, 
\end{split}
\end{align}
where $k_r$, $k_p$, $\gamma_r$ and $\gamma_p$ are rate constants.

The concentrations of mRNA and protein are given by the following ODEs: 

\begin{align}
\frac{dr}{dt}=k_r-\gamma_{r}r\\
\frac{dp}{dt}=k_{p}r-\gamma_{p}p
\end{align}
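
Because every propensity in this system is linear in the copy numbers, these ODEs also describe the mean counts of the stochastic model, so their solution gives a quick reference for the averages plotted in Section 5. The block below is a minimal sketch, assuming the rate values used later in this notebook ($k_r=k_p=\gamma_r=0.1$, $\gamma_p=0.002$) and the initial condition $r(0)=p(0)=0$. The function name `mean_counts` is illustrative and not part of NNCME.

```python
import numpy as np

def mean_counts(t, kr=0.1, kp=0.1, gr=0.1, gp=0.002):
    """Closed-form solution of the two ODEs for r(0) = p(0) = 0 (requires gr != gp)."""
    t = np.asarray(t, dtype=float)
    r_ss = kr / gr  # steady-state mRNA level
    r = r_ss * (1.0 - np.exp(-gr * t))
    p = kp * r_ss * ((1.0 - np.exp(-gp * t)) / gp
                     - (np.exp(-gr * t) - np.exp(-gp * t)) / (gp - gr))
    return r, p

t = np.linspace(0.0, 3600.0, 61)
r, p = mean_counts(t)

# Steady states: kr/gr for mRNA and kp*kr/(gr*gp) for protein.
r_ss = 0.1 / 0.1
p_ss = 0.1 * r_ss / 0.002
print(r[-1], p[-1], r_ss, p_ss)
```

With these rates the steady states are 1 mRNA and 50 proteins on average, and both curves are close to those values by $t=3600$.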

The stoichiometric matrix of this reaction network, with one row per species ($r$, $p$) and one column per reaction in the order listed above, is

$$V=
\begin{bmatrix} 
    1 & 0 & -1 & 0 \\ 
    0 & 1 & 0 & -1 
\end{bmatrix}.
$$
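
Column $j$ of $V$ is the change in $(r, p)$ when reaction $j$ fires once. The sketch below rebuilds $V$ as products minus reactants, the same decomposition that `GeneExp.py` uses in the next section (`ReactionMatRight - ReactionMatLeft`), and applies one reaction to a state. NumPy is used here for brevity, and the names `left`, `right`, and `state` are illustrative.

```python
import numpy as np

# Rows: species (r, p). Columns: reactions (transcription, translation,
# mRNA degradation, protein degradation).
left = np.array([[0, 1, 1, 0],    # reactant counts of r in each reaction
                 [0, 0, 0, 1]])   # reactant counts of p
right = np.array([[1, 1, 0, 0],   # product counts of r
                  [0, 1, 0, 0]])  # product counts of p

V = right - left
print(V)

state = np.array([3, 10])              # 3 mRNA, 10 proteins
after_translation = state + V[:, 1]    # translation adds one protein, mRNA unchanged
print(after_translation)
```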

## 2. Create a GeneExp.py
Define the gene expression system in a `.py` file. You can change the initial conditions, the species number constraint, the reaction rates, and the stoichiometric matrix in the function `rates`. **For a new system, copy `GeneExp.py` and change the parameters below. Nothing needs to be changed in the functions `init` and `Propensity`.**

In [None]:
import numpy as np
import torch

class GeneExp:
    
    def rates(self):  
        
        self.L=2
        IniDistri='delta'
        initialD=np.array([0,0]).reshape(1,self.L) # the parameter for the initial delta distribution
        r=torch.zeros(4) #Reaction rates
        r[0] = 0.1 #k_r: transcription
        r[1] = 0.1 #k_p: translation
        r[2] = 0.1 #gamma_r: mRNA degradation
        r[3] = 0.002 #gamma_p: protein degradation
        
        # Stoichiometric matrix = ReactionMatRight - ReactionMatLeft #SpeciesXReactions    
        ReactionMatLeft=torch.as_tensor([(0, 1,1,0), (0,0,0,1)]).to(self.device)#SpeciesXReactions
        ReactionMatRight=torch.as_tensor([(1, 1,0,0), (0,1,0,0)]).to(self.device)#SpeciesXReactions

        MConstrain=np.zeros(1,dtype=int)
        conservation=np.ones(1,dtype=int)
        
        return IniDistri,initialD,r,ReactionMatLeft,ReactionMatRight,MConstrain,conservation
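
For reference, the reactant matrix `ReactionMatLeft` and the rate vector are enough to write down mass-action propensities: reaction $j$ fires at rate $c_j\prod_i \binom{n_i}{\text{left}_{ij}}$. The sketch below assumes this standard form. It is not a copy of NNCME's `Propensity` function, and `mass_action_propensity` is a hypothetical helper written with NumPy rather than torch.

```python
import numpy as np
from math import comb

def mass_action_propensity(n, rates, left):
    """Propensity of each reaction for copy numbers n (one entry per species)."""
    n_species, n_reactions = left.shape
    a = np.array(rates, dtype=float)
    for j in range(n_reactions):
        for i in range(n_species):
            # comb(n, 0) == 1, so species that are not reactants drop out.
            a[j] *= comb(int(n[i]), int(left[i, j]))
    return a

rates = [0.1, 0.1, 0.1, 0.002]            # k_r, k_p, gamma_r, gamma_p
left = np.array([[0, 1, 1, 0],
                 [0, 0, 0, 1]])
a = mass_action_propensity([2, 30], rates, left)
print(a)  # k_r, k_p*r, gamma_r*r, gamma_p*p for r=2, p=30
```

For this system every reactant coefficient is 0 or 1, so the result reduces to the four propensities used in the Gillespie cell of Section 4.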

## 3. Implement the code after providing the GeneExp.py file

PC users (Windows): run `MasterEq.py` after adapting it to the gene expression example as shown below. Adjust the hyperparameters if needed. This step produces the system data learned by the VAN, which contains the joint distribution of the species.

In [None]:
###MasterEq.py
from args import args
import numpy as np
from main import Test

###Add models----------------------------------
from GeneExp import GeneExp #import the GeneExp class

##Set parameters-------------------------------
args.Model='GeneExp' #Model name
args.L=2 #Species number
args.M=int(100) #Upper limit of the molecule number
args.batch_size=1000 #Number of batch samples
args.Tstep=1001 #Number of time steps for iterating the chemical master equation
args.delta_t=0.1 #Time step length

args.net ='rnn'
args.max_stepAll=5000 #Maximum number of steps at the first time step
args.max_stepLater=100 #Maximum number of steps at later time steps
args.net_depth=1 # including output layer and not input data layer
args.net_width=16
args.print_step=20

###Add model command------------------------------- 
if args.Model=='GeneExp': 
    model = GeneExp(**vars(args))   

#Run model-----------------------------------------        
if __name__ == '__main__':
    Test(model)    
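
The settings above use `Tstep=1001` with `delta_t=0.1`, while the VAN data file loaded in Section 5 is named `..._T36001_dt0.1_...` and the Gillespie run in Section 4 covers `T=3600`. Assuming uniform steps, so that `Tstep` points spaced by `delta_t` span `(Tstep-1)*delta_t` time units (an assumption about NNCME's convention that this notebook does not state), the step count for a target horizon could be estimated as follows. `steps_for_horizon` is a hypothetical helper.

```python
def steps_for_horizon(t_final, delta_t):
    """Uniform steps of length delta_t needed to reach t_final, plus the initial point."""
    return int(round(t_final / delta_t)) + 1

print(steps_for_horizon(100.0, 0.1))   # 1001
print(steps_for_horizon(3600.0, 0.1))  # 36001
```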


## 4. Optional: Run Gillespie simulation
To evaluate the accuracy of the distribution learned by the VAN, we can compare the resulting marginal distribution of each species with that obtained from the Gillespie algorithm.

In [None]:
import numpy as np
import biocircuits
import matplotlib.pyplot as plt

def GeneExp_propensity(
    propensities, population, t, kr, kp, yr, yp
):
    #species
    r, p = population
    propensities[0] = kr
    propensities[1] = kp*r
    propensities[2] = yr*r
    propensities[3] = yp*p

#the stoichiometric matrix       
GeneExp_update = np.array(
    [
        [ 1, 0 ], 
        [ 0, 1 ],
        [-1, 0 ],
        [ 0,-1 ], 
    ],dtype=int)

#the reaction rates 
kr = 0.1 #kr
kp = 0.1 #kp
yr = 0.1 #yr
yp = 0.002 #yp 

GeneExp_args = (kr, kp, yr, yp)

#initial number of species 
r0=0
p0=0
GeneExp_pop_0 = np.array([r0,p0], dtype=float) # same initial counts as used for the VAN

#simulation time length
T=3600
time_points = np.linspace(0, T, int(T/60))
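
For readers who do not have `biocircuits` installed, the sampling scheme can be sketched directly in NumPy. The block below is a minimal direct-method Gillespie simulation with the same propensities and update matrix as above. It is a stand-in for illustration, not the `biocircuits` implementation, and `ssa_direct` and `geneexp_propensity` are hypothetical names. A shorter horizon and fewer runs are used so that it finishes quickly.

```python
import numpy as np

def ssa_direct(propensity, update, x0, time_points, rng):
    """One trajectory of the direct-method SSA, sampled at time_points."""
    x = np.array(x0, dtype=int)
    t = time_points[0]
    out = np.empty((len(time_points), len(x)), dtype=int)
    k = 0
    while k < len(time_points):
        a = propensity(x)
        a_tot = a.sum()
        # Time of the next reaction; infinite if nothing can fire.
        t_next = t + rng.exponential(1.0 / a_tot) if a_tot > 0 else np.inf
        # The state is constant until t_next, so record it at every
        # sample time that falls before the jump.
        while k < len(time_points) and time_points[k] < t_next:
            out[k] = x
            k += 1
        if k == len(time_points):
            break
        j = rng.choice(len(a), p=a / a_tot)  # index of the reaction that fires
        x = x + update[j]
        t = t_next
    return out

def geneexp_propensity(x, kr=0.1, kp=0.1, gr=0.1, gp=0.002):
    r, p = x
    return np.array([kr, kp * r, gr * r, gp * p])

update = np.array([[1, 0], [0, 1], [-1, 0], [0, -1]])  # one row per reaction
rng = np.random.default_rng(0)
sample_times = np.linspace(0.0, 200.0, 21)
traj = np.array([ssa_direct(geneexp_propensity, update, [0, 0], sample_times, rng)
                 for _ in range(200)])
print(traj.shape)             # (runs, time points, species)
print(traj[:, -1, 0].mean())  # average mRNA count at the last sample time
```

The average mRNA count at the last sample time should fluctuate around $k_r/\gamma_r = 1$, up to sampling noise from the 200 runs.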

In [None]:
Run=1 #run Gillespie or not
times=1000 #Gillespie simulation times

out_filename = 'GeneExp_times'+str(times)+'_T'+str(T)+'_dis'+str(r0)+'_'+str(p0) #filename to save
if Run==1:
    
    r_total=np.empty(shape=(0,len(time_points)))#to save the time evolution of mRNA number in each simulation (dimension: times*time_points)
    p_total=np.empty(shape=(0,len(time_points)))#to save the time evolution of protein number in each simulation (dimension: times*time_points)
    
    for i in range(times):
        
        # Perform the Gillespie simulation
        pop = biocircuits.gillespie_ssa(
            GeneExp_propensity,
            GeneExp_update,
            GeneExp_pop_0,
            time_points,
            args=GeneExp_args,
        )
        
        r_total=np.vstack((r_total,pop[0,:,0]))
        p_total=np.vstack((p_total,pop[0,:,1]))
        
    np.savez('{}'.format(out_filename),np.array(times),np.array(time_points),r_total,p_total) #save data    
else: #load existing data file
    data=np.load(out_filename+'.npz')
    print(list(data))    
    time_points = data['arr_1']
    r_total = data['arr_2']
    p_total = data['arr_3']
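
Once `r_total` is available, the marginal distribution of mRNA at a given time point is a normalized histogram over the simulation runs. The sketch below uses synthetic Poisson samples as a stand-in for one column of `r_total` (an assumption made so that the block runs on its own). It compares the histogram with the Poisson distribution of mean $k_r/\gamma_r = 1$, which is the known steady state of a process with constant production and linear degradation. `empirical_marginal` is a hypothetical helper.

```python
import numpy as np
from math import exp, factorial

def empirical_marginal(counts, max_count):
    """Fraction of samples at each integer count 0..max_count."""
    counts = np.asarray(counts).astype(int)
    hist = np.bincount(counts, minlength=max_count + 1)[: max_count + 1]
    return hist / counts.size

rng = np.random.default_rng(1)
samples = rng.poisson(lam=1.0, size=10000)  # stand-in for r_total[:, -1]
p_emp = empirical_marginal(samples, max_count=10)
p_ref = np.array([exp(-1.0) / factorial(n) for n in range(11)])  # Poisson(1)
print(np.abs(p_emp - p_ref).max())  # largest pointwise deviation
```

The same helper could be applied to a column of `p_total`, with `max_count` set to the molecule number limit used for the VAN.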

## 5. Plot Results
Plot the results for the gene expression system. Further details can be found in the Supplementary Information of the manuscript.

In [2]:
# load data of the VAN and Gillespie
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

jet = plt.get_cmap('jet')
jet_12_colors = jet(np.linspace(0, 1, 15))
plt.rc('font', size=48)

path1="GeneExpression/GeneExp1_times10000_T3600_dis0_0.npz" # Gillespie data path
path2="GeneExpression/Data_GeneExp1_M100_T36001_dt0.1_batch1000.npz" # VAN data path

data1=np.load(path1)
times = data1['arr_0']
time_points = data1['arr_1']
tfinal = time_points[-1]
step = time_points[1]-time_points[0]

rna_total = data1['arr_2']
prot_total = data1['arr_3']
    
rna_total_mean=np.mean(rna_total,0)
rna_total_std=np.std(rna_total,0)
prot_total_mean=np.mean(prot_total,0)
prot_total_std=np.std(prot_total,0)

data2=np.load(path2, allow_pickle=True)
argsSave = data2['arr_1']
delta_t=argsSave[1]
print_step= argsSave[7]
SampleSum=data2['arr_5']
delta_T= data2['arr_6']
TimePoins=np.cumsum(delta_T)[np.arange(SampleSum.shape[0])*print_step]*delta_t

Here we only show how to plot the time evolution of the average species counts. More plots can be found in the other Jupyter notebooks.

In [None]:
# The time evolution of the average counts of mRNA and protein, from the VAN (dots) and the Gillespie simulation (lines).
#####curve-----------------------------------

markersize0=12

# Create the figure and axes in one call so that no empty extra figure is left open
fig, axes = plt.subplots(1,1, dpi=400, edgecolor='k', linewidth=8)
plt.plot(time_points/3600,rna_total_mean,linewidth=5,color=jet_12_colors[3,:])
plt.plot(TimePoins[0:-1:13]/3600,np.mean(SampleSum[:,:,0],axis=1)[0:-1:13],
          marker='o',linestyle = 'None',color=jet_12_colors[3,:],markersize=markersize0)
plt.xlabel("Time (hr)")
plt.ylabel("mRNA")
plt.ylim(top=1.2)
fig.set_size_inches(9, 8)
plt.title('Average Count',fontsize=56)
plt.savefig('GE_panela1.svg', bbox_inches="tight", dpi=400)

# Create the figure and axes in one call so that no empty extra figure is left open
fig, axes = plt.subplots(1,1, dpi=400, edgecolor='k', linewidth=8)
plt.plot(time_points/3600,prot_total_mean,linewidth=5,color=jet_12_colors[12,:])
plt.plot(TimePoins[0:-1:13]/3600,np.mean(SampleSum[:,:,1],axis=1)[0:-1:13],
          marker='o',linestyle = 'None',color=jet_12_colors[12,:],markersize=markersize0)
plt.xlabel("Time (hr)")
plt.ylabel("Protein")
plt.title('Average Count',fontsize=56)
plt.ylim(top=55)
fig.set_size_inches(9, 8)
plt.savefig('GE_panela2.svg', bbox_inches="tight", dpi=400)

<table>
<td> 
<img src="https://github.com/jiadeyu0602/CheatSheet/raw/master/GE_panela1.svg" width="200"/> <br>
</td> 
<td> 
<img src="https://github.com/jiadeyu0602/CheatSheet/raw/master/GE_panela2.svg" width="200"/> <br>
</td> 
</table>
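
Because the VAN and Gillespie averages are stored on different time grids, a numerical comparison needs interpolation onto common points. The block below is a minimal sketch that uses synthetic stand-in arrays rather than the data files above. The arrays `t_ssa`, `mean_ssa`, `t_van`, `mean_van` and the helper `max_abs_gap` are hypothetical.

```python
import numpy as np

def max_abs_gap(t_ref, y_ref, t_test, y_test):
    """Largest absolute difference after interpolating the reference onto t_test."""
    y_ref_on_test = np.interp(t_test, t_ref, y_ref)
    return np.max(np.abs(y_ref_on_test - y_test))

# Stand-in curves: a saturating mean on a 60-point grid, and the same curve
# on a coarser grid with a constant offset of 0.01.
t_ssa = np.linspace(0.0, 3600.0, 60)
mean_ssa = 1.0 - np.exp(-0.1 * t_ssa)
t_van = np.linspace(0.0, 3600.0, 25)
mean_van = 1.0 - np.exp(-0.1 * t_van) + 0.01

gap = max_abs_gap(t_ssa, mean_ssa, t_van, mean_van)
print(gap)
```

With the real data, `time_points` and `rna_total_mean` would play the role of the reference, and `TimePoins` with `np.mean(SampleSum[:,:,0],axis=1)` the role of the test curve.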