In [None]:
using MomentClosure, Catalyst, Distributions, JumpProcesses, DiffEqBase, OrdinaryDiffEq, DiffEqBase.EnsembleAnalysis, Plots, Latexify

In [None]:
# Model 1: a single gene g with bursty protein (P) production. The gene is
# switched off at a rate quadratic in P (k_off*P^2), i.e. P represses its own gene.
# Proteins are produced in bursts of size m,
# where m is a geometric random variable with mean b.
# Note that if b is the mean burst size, then p = 1/(1+b).
# Implemented by first registering the distribution with Symbolics
# (so that Distributions.Geometric can be called on a symbolic parameter).
@register_symbolic Distributions.Geometric(b)
@parameters b
# symbolic burst size m ~ Geometric(1/(1+b)), interpolated into the network via $m
m = rand(Distributions.Geometric(1/(1+b)))

rn = @reaction_network begin
      @parameters k_on k_off k_p γ_p
      k_on*(1-g), 0 --> g  # G* -> G  (gene activation)
      k_off*P^2, g --> 0   # G -> G*  (gene deactivation, quadratic in P)
      k_p, g --> g + $m*P  # G -> G + mP, m ~ Geometric(1/(1+b))
      γ_p, P --> 0         # P -> ∅
end

In [None]:
# raw moment equations of `rn`, truncated at third order
eqs = generate_raw_moment_eqs(rn, 3)
eqs |> latexify

In [None]:
latexify(eqs) |> println  # raw LaTeX source of the moment equations

In [None]:
# Same model as above, but the geometric burst size is parametrised directly by
# its success probability p (p = 1/(1+b)); this redefines `m`, `rn` and `eqs`.
@parameters p
m = rand(Distributions.Geometric(p))

rn = @reaction_network begin
    @parameters k_on k_off k_p γ_p
    k_on*(1-g), 0 --> g    # gene activation (propensity k_on*(1-g))
    k_off*P^2, g --> 0     # gene deactivation, quadratic in protein P
    k_p, g --> g + $m*P    # burst of m ~ Geometric(p) proteins from the active gene
    γ_p, P --> 0           # protein degradation
end

# raw moment equations up to third order
eqs = generate_raw_moment_eqs(rn, 3)
latexify(eqs)

In [None]:
latexify(eqs) |> println  # raw LaTeX source of the p-parametrised moment equations

In [None]:
# indices of the species whose molecule numbers are binary (0 or 1):
# species 1 is the gene state g
binary_vars = [1]
# use Bernoulli-variable identities (e.g. gᵏ = g) to simplify the moment equations
clean_eqs = bernoulli_moment_eqs(eqs, binary_vars)
clean_eqs |> latexify

In [None]:
latexify(clean_eqs) |> println  # raw LaTeX of the simplified equations

In [None]:
# normal closure, with the species in `binary_vars` treated as Bernoulli variables
normal_eqs = moment_closure(eqs, "normal", binary_vars)
# display the closure functions
latexify(normal_eqs, :closure)

In [None]:
latexify(normal_eqs, :closure) |> println  # raw LaTeX of the normal closure

In [None]:
# derivative matching closure, with the species in `binary_vars` treated as Bernoulli
dm_eqs = moment_closure(eqs, "derivative matching", binary_vars)
# display the closure functions
latexify(dm_eqs, :closure)

In [None]:
latexify(dm_eqs, :closure) |> println  # raw LaTeX of the derivative matching closure

In [None]:
# conditional gaussian closure (conditioned on the binary species in `binary_vars`)
cond_gaussian_eqs = moment_closure(eqs, "conditional gaussian", binary_vars)
# display the closure functions
latexify(cond_gaussian_eqs, :closure)

In [None]:
latexify(cond_gaussian_eqs, :closure) |> println  # raw LaTeX of the conditional gaussian closure

In [None]:
# conditional derivative matching closure (conditioned on the binary species)
cond_dm_eqs = moment_closure(eqs, "conditional derivative matching", binary_vars)
# display the closure functions
latexify(cond_dm_eqs, :closure)

In [None]:
# print the raw LaTeX of the conditional derivative matching closure computed
# in the preceding cell (cond_dm_eqs, not the plain derivative matching dm_eqs)
println(latexify(cond_dm_eqs, :closure))

In [None]:
# PARAMETER INITIALISATION

# target mean protein number and mean burst size
mean_p = 200
mean_b = 70
# geometric success probability for a burst size with mean mean_b: p = 1/(1+b)
p_val = 1/(1+mean_b)

# degradation and gene switching rates
γ_p_val = 1
k_off_val = 0.001
k_on_val = 0.05

# production rate expressed through the target means and rates above
k_p_val = mean_p * γ_p_val * (k_off_val * mean_p^2 + k_on_val) / (k_on_val * mean_b)

# parameter values keyed by the symbols declared in `rn`
pmap = [
    :k_on => k_on_val,
    :k_off => k_off_val,
    :k_p => k_p_val,
    :γ_p => γ_p_val,
    :p => p_val,
]

# initial state: active gene and a single protein
u₀ = [1, 1]

# simulation time window
tspan = (0.0, 6.0);

In [None]:
# convert the reaction network into a system of jump processes
jsys = convert(JumpSystem, rn; combinatoric_ratelaws=false)
jsys = complete(jsys)

# create a discrete problem setting the simulation parameters
dprob = DiscreteProblem(jsys, u₀, tspan, pmap)

# create a JumpProblem compatible with ReactionSystemMod
jprob = JumpProblem(jsys, dprob, Direct(), save_positions=(false, false))

# simulate 2×10⁴ SSA trajectories
ensembleprob  = EnsembleProblem(jprob)
@time sol_SSA = solve(ensembleprob, SSAStepper(), saveat=0.1, trajectories=20000)
# compute the means and variances
means_ssa, vars_ssa = timeseries_steps_meanvar(sol_SSA);

In [None]:
# compare protein mean / standard deviation from every closure against the SSA
plt_m = plot()   # protein mean
plt_std = plot() # protein standard deviation

# deterministic initial conditions for the moment ODEs (shared by all closures)
u₀map = deterministic_IC(u₀, dm_eqs)

for closure_name in ("normal", "derivative matching",
                     "conditional gaussian", "conditional derivative matching")
    # closing the equations is cheap, so every closure is simply re-applied here
    closed = moment_closure(eqs, closure_name, binary_vars)
    moment_sol = solve(ODEProblem(closed, u₀map, tspan, pmap),
                       AutoTsit5(Rosenbrock23()); saveat=0.01)

    # element 2 is μ₀₁ (protein mean) and element 4 is μ₀₂ (its second raw moment)
    plot!(plt_m, moment_sol; idxs=[2], label=closure_name)
    protein_std = sqrt.(moment_sol[4, :] .- moment_sol[2, :] .^ 2)
    plot!(plt_std, moment_sol.t, protein_std; label=closure_name)
end

plot!(plt_m; xlabel="Time [hr]", ylabel="Protein mean level")
plot!(plt_m, means_ssa.t, means_ssa[2, :]; label="SSA", linestyle=:dash, color="gray")
plot!(plt_std; xlabel="Time [hr]", ylabel="Protein standard deviation")
plot!(plt_std, vars_ssa.t, sqrt.(vars_ssa[2, :]); label="SSA", linestyle=:dash, color="gray");

In [None]:
plot(plt_m; lw=2)  # protein mean: closures vs SSA

In [None]:
#savefig("../docs/src/assets/gene_1_means.svg")

In [None]:
plot(plt_std; lw=2)  # protein standard deviation: closures vs SSA

In [None]:
#savefig("../docs/src/assets/gene_1_stds.svg")

In [None]:
# Model 2: two genes g_x and g_y with bursty production of proteins x and y.
# Protein y enters the activation rate of g_x; protein x enters the
# deactivation rate of g_y.
# NOTE(review): repeats the Geometric registration from the first model;
# presumably kept so this cell can be run on its own — confirm.
@register_symbolic Distributions.Geometric(b)
@parameters b_x b_y
# symbolic burst sizes with means b_x and b_y (success probability 1/(1+mean))
m = rand(Distributions.Geometric(1/(1+b_x)))
l = rand(Distributions.Geometric(1/(1+b_y)))

rn = @reaction_network begin
    @parameters kx_on kx_off ky_on ky_off k_x γ_x k_y γ_y
    kx_on*(1-g_x)*y, 0 --> g_x  # 0   -> g_x  (activation, proportional to y)
    kx_off,          g_x --> 0  # g_x -> 0
    ky_on*(1-g_y),   0 --> g_y  # 0   -> g_y
    ky_off*x,        g_y --> 0  # g_y -> 0    (deactivation, proportional to x)
    k_x*g_x,         0 --> $m*x # 0 -> mx, m ~ Geometric(mean=b_x)
    γ_x,             x --> 0    # x -> 0
    k_y*g_y,         0 --> $l*y # 0 -> ly, l ~ Geometric(mean=b_y)
    γ_y,             y --> 0    # y -> 0
end

# both g_x and g_y are Bernoulli random variables
binary_vars = [1, 2];

In [None]:
# Parameter initialisation

# target mean protein numbers and mean burst sizes
mean_x = 100
mean_y = 100
mean_b_x = 5
mean_b_y = 5

# degradation rates
γ_x_val = 1
γ_y_val = 1

# gene switching rates
kx_on_val = 0.05
kx_off_val = 4
ky_on_val = 0.3
ky_off_val = 0.05

# production rates expressed through the target means and rates above
# NOTE(review): these mirror the first model's formula (switching ∝ P^2), but in
# `rn` g_x activation is linear in y and g_y deactivation is linear in x, so the
# mean_y^2 / mean_x^2 terms may not be intended — confirm before reuse.
k_x_val = mean_x * γ_x_val * (kx_off_val * mean_y^2 + kx_on_val) / (kx_on_val * mean_b_x)
k_y_val = mean_y * γ_y_val * (ky_off_val * mean_x^2 + ky_on_val) / (ky_on_val * mean_b_y)

# It is unclear whether Soltani et al. (2015) used this parameter set: with it,
# X numbers jump to millions and the SSA becomes extremely slow, so the
# production rates are rescaled (otherwise the rate coefficients are too high).
k_x_val *= 0.00003
k_y_val *= 0.01

# parameter values keyed by the symbols declared in `rn`
pmap = [
    :kx_on => kx_on_val,
    :kx_off => kx_off_val,
    :ky_on => ky_on_val,
    :ky_off => ky_off_val,
    :k_x => k_x_val,
    :k_y => k_y_val,
    :γ_x => γ_x_val,
    :γ_y => γ_y_val,
    :b_x => mean_b_x,
    :b_y => mean_b_y,
]

# initial gene states and protein numbers, order [g_x, g_y, x, y]
u₀ = [1, 1, 1, 1]

# time interval to solve on
tspan = (0.0, 12.0);

In [None]:
# raw moment equations of the two-gene network, truncated at fourth order
# (output suppressed)
eqs = generate_raw_moment_eqs(rn, 4);

In [None]:
# can compare to results in Soltani et al. (2015)
# derivative matching closure; can be compared with Soltani et al. (2015)
closed_eqs = moment_closure(eqs, "derivative matching", binary_vars)
latexify(closed_eqs, :closure)  # display the closure functions

In [None]:
# conditional derivative matching closure (conditioned on g_x and g_y)
closed_eqs = moment_closure(eqs, "conditional derivative matching", binary_vars)
latexify(closed_eqs, :closure)  # display the closure functions

In [None]:
# normal closure of the fourth-order system
closed_eqs = moment_closure(eqs, "normal", binary_vars)
latexify(closed_eqs, :closure);  # output suppressed: the expression is very long

In [None]:
# conditional gaussian closure (conditioned on g_x and g_y)
closed_eqs = moment_closure(eqs, "conditional gaussian", binary_vars)
latexify(closed_eqs, :closure)  # display the closure functions

In [None]:
using Sundials # provides the stiff CVODE_BDF integrator

# integrate the normal-closure moment ODEs, saving every 0.1 time units
closed_eqs = moment_closure(eqs, "normal", binary_vars)
u₀map = deterministic_IC(u₀, closed_eqs)
oprob = ODEProblem(closed_eqs, u₀map, tspan, pmap)
sol = solve(oprob, CVODE_BDF(); saveat=0.1);

In [None]:
# integrate the conditional-gaussian moment ODEs with CVODE_BDF as well
closed_eqs = moment_closure(eqs, "conditional gaussian", binary_vars)
u₀map = deterministic_IC(u₀, closed_eqs)
oprob = ODEProblem(closed_eqs, u₀map, tspan, pmap)
sol = solve(oprob, CVODE_BDF(); saveat=0.1);

In [None]:
# SSA reference: completed jump system, sampled with Gillespie's Direct method
# (individual jumps are not saved, only the `saveat` grid)
jsys = complete(convert(JumpSystem, rn; combinatoric_ratelaws=false))
dprob = DiscreteProblem(jsys, u₀, tspan, pmap)
jprob = JumpProblem(jsys, dprob, Direct(); save_positions=(false, false))

# 10⁴ trajectories recorded every 0.1 time units, reduced to pointwise mean/variance
ensembleprob = EnsembleProblem(jprob)
@time sol_SSA = solve(ensembleprob, SSAStepper(); saveat=0.1, trajectories=10000)
means_ssa, vars_ssa = timeseries_steps_meanvar(sol_SSA);

In [None]:
# Compare the activator (y) mean and standard deviation predicted by the two
# derivative-matching closures against the SSA ensemble statistics.
plt_m = plot()   # plot mean activator protein number
plt_std = plot() # plot standard deviation of activator protein number

for closure in ["derivative matching", "conditional derivative matching"]

    closed_eqs = moment_closure(eqs, closure, binary_vars)

    u₀map = deterministic_IC(u₀, closed_eqs)
    oprob = ODEProblem(closed_eqs, u₀map, tspan, pmap)
    sol = solve(oprob, Tsit5(), saveat=0.1)

    # μ₀₀₀₁ is the 4th and μ₀₀₀₂ is the 12th element in sol (can check with closed_eqs.odes.states)
    # `idxs` replaces the deprecated `vars` plot keyword: (0, 4) plots element 4 against time
    plt_m = plot!(plt_m, sol, idxs=(0, 4), label=closure)
    plt_std = plot!(plt_std, sol.t, sqrt.(sol[12, :] .- sol[4, :].^2), label=closure)
end

plt_m = plot!(plt_m, xlabel="Time [hr]", ylabel="Activator mean level")
plt_m = plot!(plt_m, means_ssa.t, means_ssa[4,:], label="SSA", linestyle=:dash, color="gray")
plt_std = plot!(plt_std, xlabel="Time [hr]", ylabel="Activator standard deviation")
plt_std = plot!(plt_std, vars_ssa.t, sqrt.(vars_ssa[4,:]), label="SSA", linestyle=:dash, color="gray");

In [None]:
plot(plt_m; lw=2)  # activator mean: closures vs SSA

In [None]:
#savefig("../docs/src/assets/gene_2_means.svg")

In [None]:
plot(plt_std; lw=2, xlims=(0.0, 12.0))  # activator standard deviation: closures vs SSA

In [None]:
#savefig("../docs/src/assets/gene_2_stds.svg")