In [None]:
using MomentClosure, Latexify, OrdinaryDiffEq, Catalyst

$$
\begin{aligned}
G &\stackrel{c_1}{\rightarrow} G+P, \\
G^* &\stackrel{c_2}{\rightarrow} G^*+P, \\
P &\stackrel{c_3}{\rightarrow} 0, \\
G+P &\underset{c_5}{\stackrel{c_4}{\rightleftharpoons}} G^*
\end{aligned}
$$

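The on/off gene states are merged into a single Bernoulli variable $g(t)$, which equals $1$ when the gene is in state $G$ and $0$ when it is in state $G^*$. Since there is exactly one gene copy, the number of $G^*$ is $1 - g(t)$, which is why the factor $(1-g)$ appears in the propensities. The number of proteins in the system is given by $p(t)$.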

### Using Catalyst.jl `ReactionSystem`

* $\rightarrow$ denotes a reaction that follows the law of mass action: only the rate coefficient needs to be given, and the full propensity function is constructed automatically.
* $\Rightarrow$ denotes a reaction that does not follow the law of mass action: the full propensity function must be given explicitly.
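For a mass-action reaction, the propensity is built from the rate coefficient and the reactant counts. The helper below is hypothetical (not a Catalyst or MomentClosure function) and assumes the usual stochastic convention: a falling-factorial product of reactant counts, optionally divided by the factorials of the stoichiometric coefficients. To our understanding, that division is what the `combinatoric_ratelaws` keyword used further below switches on or off.

```julia
# Hypothetical helper (not part of Catalyst/MomentClosure): mass-action propensity
#   a(x) = c * prod_i x_i (x_i - 1) ... (x_i - r_i + 1)   [divided by r_i! if combinatoric]
# x = species counts, r = reactant stoichiometric coefficients
function massaction_propensity(c, x, r; combinatoric = true)
    a = float(c)
    for (xi, ri) in zip(x, r)
        for k in 0:ri-1
            a *= (xi - k)                      # falling factorial of the count
        end
        combinatoric && (a /= factorial(ri))   # binomial coefficient when enabled
    end
    return a
end

# Reaction c₄: g + p → 0 with g = 1, p = 3 (all coefficients are 1, so the flag is irrelevant)
massaction_propensity(1.0, (1, 3), (1, 1); combinatoric = false)   # 3.0
massaction_propensity(1.0, (1, 3), (1, 1); combinatoric = true)    # 3.0

# Hypothetical dimerisation 2A → B with A = 5: here the flag matters
massaction_propensity(1.0, (5,), (2,); combinatoric = true)    # 10.0
massaction_propensity(1.0, (5,), (2,); combinatoric = false)   # 20.0
```

The $\Rightarrow$ reactions bypass this construction entirely: their rate expression, e.g. `c₂*(1-g)`, is used as the propensity as written.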

In [None]:
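@parameters c₁, c₂, c₃, c₄, c₅
rn = @reaction_network begin
    (c₁), g → g+p            # G → G+P (mass action)
    (c₂*(1-g)), 0 ⇒ p        # G* → G*+P, propensity c₂(1-g) given in full
    (c₃), p → 0              # P → 0 (mass action)
    (c₄), g+p → 0            # G+P → G*: g and p both decrease by one
    (c₅*(1-g)), 0 ⇒ g+p      # G* → G+P, propensity c₅(1-g) given in full
end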

Check the propensity functions and the net stoichiometry matrix:

In [None]:
propensities(rn, combinatoric_ratelaws=false)

In [None]:
netstoichmat(rn)
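As a cross-check, the net stoichiometry matrix and propensities can be written down by hand from the reaction list. This sketch assumes rows are species $(g, p)$ and columns are the five reactions in the order they appear in `rn`; depending on the Catalyst version, `netstoichmat` may print this or the transposed layout. The names `S` and `propensity_vec` are illustrative. Multiplying `S` by the propensity vector gives the drift, i.e. the expected rate of change of $(g, p)$ in a given state; averaging it over the distribution gives the mean equations.

```julia
# Net stoichiometry by hand: rows = species (g, p), columns = reactions 1 to 5 of `rn`
S = [0  0   0  -1  1;    # g
     1  1  -1  -1  1]    # p

# Propensities with combinatoric_ratelaws = false; c = (c₁, ..., c₅)
propensity_vec(g, p, c) = [c[1]*g,         # G  → G + P
                           c[2]*(1 - g),   # G* → G* + P
                           c[3]*p,         # P  → 0
                           c[4]*g*p,       # G + P → G*
                           c[5]*(1 - g)]   # G* → G + P

c = [0.01, 40.0, 1.0, 1.0, 1.0]            # same values as `pmap` further below

# Drift S * a(x) in two example states
S * propensity_vec(1, 0, c)                # [0.0, 0.01]: gene in state G, no protein
S * propensity_vec(0, 2, c)                # [1.0, 39.0]: gene in state G*, two proteins
```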

### Moment equations

Generate the raw moment equations up to third order.

The keyword argument `combinatoric_ratelaws` controls whether binomial coefficients are included when constructing the propensity functions of mass-action reactions; `false` leaves them out. It makes no difference in this model, because every mass-action reactant has a stoichiometric coefficient of 1, so every binomial coefficient equals 1.

Alternatively, central moment equations can be generated using `generate_central_moment_eqs(rn, 3, 5, combinatoric_ratelaws=false)`.

In [None]:
raw_eqs = generate_raw_moment_eqs(rn, 3, combinatoric_ratelaws=false)
latexify(raw_eqs)

We solve for moments up to order `m_order = 3`, while the equations involve moments up to order `exp_order = 5`.
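As a hand check on the lowest-order equations, apply $\frac{d}{dt}\langle x \rangle = \sum_r \nu_r \langle a_r \rangle$ to the five reactions, with net changes $\nu_r$ and propensities $a_r$ read off the reaction network and $\mu_{ij} = \langle g^i p^j \rangle$ (so $\mu_{00} = 1$). We expect the first two equations in `latexify(raw_eqs)` to agree up to notation. Both already contain the second-order moment $\mu_{11} = \langle g p \rangle$, which is why the hierarchy does not close on its own and a closure is needed.

$$
\begin{aligned}
\frac{d\mu_{10}}{dt} &= -c_4\,\mu_{11} + c_5\,(1-\mu_{10}), \\
\frac{d\mu_{01}}{dt} &= c_1\,\mu_{10} + c_2\,(1-\mu_{10}) - c_3\,\mu_{01} - c_4\,\mu_{11} + c_5\,(1-\mu_{10}).
\end{aligned}
$$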

Since $g$ is a Bernoulli variable, $g^k = g$ for any $k \geq 1$. Using this property eliminates redundant moments and their equations, showing how the system simplifies:

In [None]:
binary_vars = [1]
bernoulli_eqs = bernoulli_moment_eqs(raw_eqs, binary_vars)
latexify(bernoulli_eqs)
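Because $g^k = g$ for $k \geq 1$, any mixed moment $\langle g^i p^j \rangle$ with $i \geq 1$ equals $\langle g\, p^j \rangle$. The sketch below applies this rule to the moment indices up to third order to show which moments remain distinct. It is our own illustration of the idea behind `bernoulli_moment_eqs`, not its implementation, and `reduce_bernoulli` is a hypothetical name.

```julia
using Statistics

# Moment index (i, j) stands for ⟨gⁱ pʲ⟩; g is binary, so gⁱ = g for i ≥ 1
reduce_bernoulli((i, j)) = (min(i, 1), j)

# All moments of total order 1 to 3
idxs = [(i, n - i) for n in 1:3 for i in n:-1:0]
distinct = unique(reduce_bernoulli.(idxs))
# the 9 moments collapse to 6: (1,0), (0,1), (1,1), (0,2), (1,2), (0,3)

# Numerical sanity check of gⁱ = g on arbitrary binary/integer samples
gs = [1, 0, 1, 1]; ps = [2, 5, 3, 0]
mean(gs .^ 2 .* ps) == mean(gs .* ps)    # true
```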

### Closing the moment equations

Finally, we apply the selected closure method, conditional derivative matching, to the system of raw moment equations, declaring $g$ (the first species) as the binary variable via `binary_vars`:

In [None]:
closed_raw_eqs = moment_closure(raw_eqs, "conditional derivative matching", binary_vars)
latexify(closed_raw_eqs)

We can also print the closure functions for each higher-order moment:

In [None]:
latexify(closed_raw_eqs, :closure)
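The conditioning idea can be illustrated by hand: mixed moments involving $g$ are rewritten as conditional moments of $p$ given $g = 1$, e.g. $\langle g p^j \rangle = \langle g \rangle\, \mathbb{E}[p^j \mid g = 1]$, and a derivative-matching closure is then applied to those conditional moments. As a stand-in, the sketch uses only the lowest-order univariate derivative-matching formula $\mathbb{E}[x^3] \approx (\mathbb{E}[x^2]/\mathbb{E}[x])^3$. The formulas MomentClosure uses for `m_order = 3` close higher-order moments and differ, so compare with the printed closure functions rather than treating this as the package's implementation. All function names are hypothetical.

```julia
# Hypothetical helpers illustrating a conditional closure (not MomentClosure internals).
# μ10 = ⟨g⟩, μ11 = ⟨g p⟩, μ12 = ⟨g p²⟩

# Conditional moments E[pʲ | g = 1] = ⟨g pʲ⟩ / ⟨g⟩
cond_moments(μ10, μ11, μ12) = (μ11 / μ10, μ12 / μ10)

# Univariate derivative matching truncated at second order: E[x³] ≈ (E[x²]/E[x])³
dm_third(m1, m2) = (m2 / m1)^3

# Closure for the mixed moment ⟨g p³⟩ = ⟨g⟩ E[p³ | g = 1]
function close_g_p3(μ10, μ11, μ12)
    m1, m2 = cond_moments(μ10, μ11, μ12)
    return μ10 * dm_third(m1, m2)
end

close_g_p3(0.5, 1.0, 2.5)   # E[p|g=1] = 2, E[p²|g=1] = 5, so 0.5 * 2.5^3 = 7.8125
```

Note the divisions: closures of this type are undefined when the conditioning moments vanish.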

### Numerical solution

The closed moment equations can be solved using DifferentialEquations.jl, or with OrdinaryDiffEq.jl alone, which is more lightweight and sufficient for this particular case.

In [None]:
# PARAMETER INITIALISATION
pmap = [c₁ => 0.01,
        c₂ => 40,
        c₃ => 1,
        c₄ => 1,
        c₅ => 1]

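# DETERMINISTIC INITIAL CONDITIONS
# p starts at 0.001 rather than 0, presumably to keep moments that appear in
# denominators of the closure functions away from zero
u0map = [:g => 1., :p => 0.001]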

# time interval to solve on
tspan = (0., 1000.0)
dt = 1

@time oprob = ODEProblem(closed_raw_eqs, u0map, tspan, pmap);
@time sol_CDM = solve(oprob, Tsit5(), saveat=dt);
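For comparison, the simplest closure one could apply to the two mean equations is the mean-field approximation $\langle g p \rangle \approx \langle g \rangle \langle p \rangle$, which ignores all correlations. This is not the method used above, only a hand-written baseline. The sketch uses the standard in-place `ODEProblem(f!, u0, tspan, p)` interface of OrdinaryDiffEq.jl with the same parameter values and initial conditions as the cell above; `meanfield!` is a hypothetical name.

```julia
using OrdinaryDiffEq

# Mean equations with the mean-field closure ⟨g p⟩ ≈ ⟨g⟩⟨p⟩ (a baseline, not CDM)
function meanfield!(du, u, c, t)
    g, p = u
    gp = g * p                                   # mean-field closure of ⟨g p⟩
    du[1] = -c[4]*gp + c[5]*(1 - g)
    du[2] = c[1]*g + c[2]*(1 - g) - c[3]*p - c[4]*gp + c[5]*(1 - g)
    return nothing
end

c = [0.01, 40.0, 1.0, 1.0, 1.0]                  # c₁ ... c₅, as in pmap
u0 = [1.0, 0.001]                                # g, p, as in u0map
prob_mf = ODEProblem(meanfield!, u0, (0.0, 1000.0), c)
sol_mf = solve(prob_mf, Tsit5(), saveat = 1)
```

Plotting `sol_mf[2, :]` next to `sol_CDM[2, :]` gives a sense of how much the correlation between gene state and protein number matters in this model.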

In [None]:
using Plots

plot(sol_CDM.t, sol_CDM[1,:], 
    label  = "CDM", 
    legend = true,
    xlabel = "Time [s]",
    ylabel = "Mean gene number",
    lw=2,
    legendfontsize=8,
    xtickfontsize=10,
    ytickfontsize=10,
    dpi=100)

In [None]:
plot(sol_CDM.t, sol_CDM[2,:], 
    label  = "CDM",
    legend = :bottomright,
    xlabel = "Time [s]",
    ylabel = "Mean protein number",
    lw=2,
    legendfontsize=8,
    xtickfontsize=10,
    ytickfontsize=10,
    dpi=100)

In [None]:
# σ_p = √(⟨p²⟩ − ⟨p⟩²), with rows 2 and 4 of sol_CDM holding ⟨p⟩ and ⟨p²⟩
std_CDM = sqrt.(sol_CDM[4,2:end] .- sol_CDM[2,2:end].^2)
plot(sol_CDM.t[2:end], std_CDM, 
    label  = "CDM", 
    legend = true,
    xlabel = "Time [s]",
    ylabel = "Standard deviation of the protein number",
    lw=2,
    legendfontsize=8,
    xtickfontsize=10,
    ytickfontsize=10,
    dpi=100)
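The standard deviation follows from the first two raw moments, $\sigma_p = \sqrt{\mu_{02} - \mu_{01}^2}$. The cell above drops the first time point, presumably because the initial condition is deterministic: the variance is zero there, and round-off can make it slightly negative, which `sqrt` rejects with a `DomainError`. A hypothetical helper that clamps such round-off to zero instead:

```julia
# Hypothetical helper: σ = √(μ₂ − μ₁²), with small negative round-off clamped to 0
std_from_raw(μ1, μ2) = sqrt(max(μ2 - μ1^2, zero(μ2)))

std_from_raw(2.0, 5.0)     # 1.0
std_from_raw(1.0, 0.999)   # 0.0 (negative "variance" clamped)

# Applied to the whole trajectory, assuming rows 2 and 4 of sol_CDM are ⟨p⟩ and ⟨p²⟩:
# std_CDM_all = std_from_raw.(sol_CDM[2, :], sol_CDM[4, :])
```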

### Stochastic simulation algorithm (SSA)

In [None]:
using JumpProcesses

# initial conditions [g, p]
u0map = [:g => 1, :p => 0]

# time interval to solve on
tspan = (0., 1000.)

# create a discrete problem to encode that our species are integer valued
dprob = DiscreteProblem(rn, u0map, tspan, pmap)

# create a JumpProblem and specify Gillespie's Direct Method as the solver:
jprob = JumpProblem(rn, dprob, Direct(), save_positions=(false, false))
# set save_positions to (false, false); otherwise the time of every reaction occurrence is saved

dt = 1 # time resolution at which numerical solution is saved

# solve an ensemble of 1000 SSA trajectories
ensembleprob  = EnsembleProblem(jprob)
@time sol_SSA = solve(ensembleprob, SSAStepper(), saveat=dt, trajectories=1000);
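`Direct()` selects Gillespie's direct method. To make explicit what each trajectory does, here is a minimal hand-written version for this model: draw an exponential waiting time with rate $a_0 = \sum_r a_r$, pick reaction $r$ with probability $a_r / a_0$, and apply its net change. It is only an illustration (slow, no intermediate states saved), not the JumpProcesses implementation; `ssa_direct` is a hypothetical name.

```julia
using Random, Statistics

# Hand-written Gillespie direct method for the gene model (illustration only).
# Returns the state (g, p) at time tend.
function ssa_direct(g, p, c, tend; rng = Random.default_rng())
    Δ = ((0, 1), (0, 1), (0, -1), (-1, -1), (1, 1))   # (Δg, Δp) for reactions 1 to 5
    t = 0.0
    while true
        a = (c[1]*g, c[2]*(1 - g), c[3]*p, c[4]*g*p, c[5]*(1 - g))
        a0 = sum(a)
        a0 == 0 && break                  # nothing can fire: absorbing state
        t += randexp(rng) / a0            # waiting time ~ Exponential(rate a0)
        t > tend && break                 # next event would fall after tend
        r = rand(rng) * a0                # select reaction j with probability a[j]/a0
        j, acc = 1, a[1]
        while acc <= r && j < length(a)
            j += 1
            acc += a[j]
        end
        g += Δ[j][1]
        p += Δ[j][2]
    end
    return g, p
end

c = [0.01, 40.0, 1.0, 1.0, 1.0]
rng = MersenneTwister(1)
finals = [ssa_direct(1, 0, c, 50.0; rng = rng) for _ in 1:200]
mean(last.(finals))                       # sample mean protein number at t = 50
```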

Sample moments up to a chosen order can be computed from the SSA ensemble:

In [None]:
@time SSA_μ = get_raw_moments(sol_SSA, 2);
@time SSA_M = get_central_moments(sol_SSA, 2);
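At each saved time point, the sample moments are averages across the ensemble of trajectories. A hand-written version for a single time point, assuming normalisation by the number of samples $n$ (we have not checked whether `get_central_moments` divides by $n$ or $n-1$); the helper names are hypothetical:

```julia
using Statistics

# Hypothetical helpers: sample moments of (g, p) across trajectories at one time point
raw_moment(g, p, i, j) = mean(g .^ i .* p .^ j)                               # ⟨gⁱ pʲ⟩
central_moment(g, p, i, j) = mean((g .- mean(g)) .^ i .* (p .- mean(p)) .^ j)

gs = [1, 0, 1, 0]             # gene state in four trajectories
ps = [2, 4, 6, 8]             # protein number in the same trajectories
raw_moment(gs, ps, 0, 1)      # 5.0
raw_moment(gs, ps, 1, 1)      # 2.0
central_moment(gs, ps, 0, 2)  # 5.0
central_moment(gs, ps, 1, 1)  # -0.5
```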

In [None]:
plot(sol_CDM.t, [sol_CDM[1,:], SSA_μ[1,0]], 
    label  = ["CDM" "SSA"], 
    legend = true,
    xlabel = "Time [s]",
    ylabel = "Mean gene number",
    lw=2,
    legendfontsize=8,
    xtickfontsize=10,
    ytickfontsize=10,
    dpi=100)

In [None]:
plot(sol_CDM.t, [sol_CDM[2,:], SSA_μ[0,1]], 
    label  = ["CDM" "SSA"],
    legend = :bottomright,
    xlabel = "Time [s]",
    ylabel = "Mean protein number",
    lw=2,
    legendfontsize=8,
    xtickfontsize=10,
    ytickfontsize=10,
    dpi=100)

In [None]:
std_CDM = sqrt.(sol_CDM[4,2:end] .- sol_CDM[2,2:end].^2)
std_p_SSA = sqrt.(SSA_M[0,2][2:end])
plot(sol_CDM.t[2:end], [std_CDM, std_p_SSA], 
    label  = ["CDM" "SSA"], 
    legend = true,
    xlabel = "Time [s]",
    ylabel = "Standard deviation of the protein number",
    lw=2,
    legendfontsize=8,
    xtickfontsize=10,
    ytickfontsize=10,
    dpi=100)