In [None]:
# Loading all the packages we will need
using Catalyst, MomentClosure, OrdinaryDiffEq, JumpProcesses, 
      DiffEqBase.EnsembleAnalysis, Plots
using Plots.Measures: mm

# p53–Mdm2 network; species x, y₀, y are plotted as "p53", "pre-Mdm2", "Mdm2" below.
# → for a mass-action rate (the rate constant is multiplied by the substrate amounts)
# ⇒ for a non mass-action rate (the rate expression is used exactly as written)
# NOTE(review): later cells index species positionally (x first, matching the
# `[x, y₀, y]` ordering used for `u0map`) — presumably Catalyst orders species by
# first appearance, so keep the reaction order unchanged.
rn = @reaction_network begin
    @parameters k₁ k₂ k₃ k₄ k₅ k₆ k₇
    (k₁), 0 → x                 # production of x at constant rate k₁
    (k₂), x → 0                 # degradation of x (mass action in x)
    (k₃*x*y/(x+k₇)), x ⇒ 0      # y-dependent degradation of x, saturating in x
    (k₄*x), 0 ⇒ y₀              # production of y₀ at rate k₄*x
    (k₅), y₀ → y                # conversion of y₀ into y
    (k₆), y → 0                 # degradation of y (mass action in y)
end

In [None]:
# inspect how the species of the network are indexed (positional `idxs` below rely on it)
rn |> speciesmap

In [None]:
# parameter values for the rate constants k₁ … k₇
pmap = [
    :k₁ => 90,
    :k₂ => 0.002,
    :k₃ => 1.7,
    :k₄ => 1.1,
    :k₅ => 0.93,
    :k₆ => 0.96,
    :k₇ => 0.01,
]

# initial molecule numbers of the species x, y₀ and y
u0map = [:x => 70, :y₀ => 30, :y => 60]

# time interval to solve on (shared by the SSA and the moment ODEs below)
tspan = (0.0, 200.0)

In [None]:
# Stochastic simulation of the same network with the `Direct()` SSA aggregator;
# `combinatoric_ratelaws = false` matches the moment expansions generated below.
jsys = complete(convert(JumpSystem, rn; combinatoric_ratelaws = false))
dprob = DiscreteProblem(jsys, u0map, tspan, pmap)

# no saving around individual jumps: only the `saveat` grid is stored
jprob = JumpProblem(jsys, dprob, Direct(); save_positions = (false, false))
ensembleprob = EnsembleProblem(jprob)

@time sol_SSA = solve(ensembleprob, SSAStepper(); saveat = 0.2, trajectories = 1000);

In [None]:
#using JLD2 

#@save "sol_SSA_long_run.jld2" sol_SSA
#@load "sol_SSA_long_run.jld2" sol_SSA

In [None]:
# plot one SSA trajectory (number 666 of the ensemble) over t ∈ [0, 100]
plot(sol_SSA[666];
     tspan = (0, 100),
     labels = ["p53" "pre-Mdm2" "Mdm2"],
     linecolor = [1 3 2],
     lw = 2,
     xlabel = "Time [h]",
     ylabel = "Number of molecules",
     size = (700, 400))

In [None]:
#savefig("../docs/src/assets/p53-Mdm2_single_SSA.svg")

In [None]:
# ensemble means and variances of every species at each saved time point
means_SSA, vars_SSA = timeseries_steps_meanvar(sol_SSA)
plot(means_SSA;
     labels = ["p53" "pre-Mdm2" "Mdm2"],
     linecolor = [1 3 2],
     lw = 2,
     xlabel = "Time [h]",
     ylabel = "Number of molecules",
     size = (700, 400))

In [None]:
#savefig("../docs/src/assets/p53-Mdm2_means_SSA.svg")

In [None]:
# SSA molecule-number distributions of x, y₀ and y at t = 25.0
data = componentwise_vectors_timepoint(sol_SSA, 25.0)
h1, h2, h3 = map(zip(data, ("x", "y₀", "y"))) do (samples, name)
    histogram(samples; normalize = true, xlabel = name, ylabel = "P($name)")
end

plot(h1, h2, h3;
     layout = (1, 3), legend = false, size = (1050, 250), guidefontsize = 10,
     left_margin = 5mm, bottom_margin = 7mm)

In [None]:
#savefig("../docs/src/assets/p53-Mdm2_distribution.svg")

In [None]:
# drop the reference to the SSA ensemble so the GC can reclaim its memory;
# only the summaries `means_SSA` / `vars_SSA` are used from here on
global sol_SSA = nothing

In [None]:
# Second-order moment expansion: mean of p53 (first unknown) for every closure
# and truncation order q = 3…6, compared against the SSA mean.

closures = ["normal", "log-normal", "gamma"]

# initialise separate plot for each closure
plts = [plot() for _ in eachindex(closures)]

for q in 3:6
    println(q)
    # the moment equations depend only on q; each closure is applied to the same set
    eqs = generate_central_moment_eqs(rn, 2, q, combinatoric_ratelaws=false)
    for (closure, plt) in zip(closures, plts)
        println(closure)
        closed_eqs = moment_closure(eqs, closure)
        oprob = ODEProblem(closed_eqs, u0map, tspan, pmap)

        sol = solve(oprob, Tsit5(), saveat=0.1)
        # `plot!` mutates `plt` in place, so there is nothing to rebind
        plot!(plt, sol, idxs=[1], lw=2, label="q = $q")
    end
end

for plt in plts
    plot!(plt, xlabel="Time [h]", ylabel="Mean number of p53 molecules")
    plot!(plt, means_SSA.t, means_SSA[1, :], lw=2, linestyle=:dash, label="SSA", color="gray")
end

In [None]:
# second-order expansion with the normal closure (first entry of `closures`)
plot(plts[1]; leftmargin = 2mm, size = (750, 450))

In [None]:
#savefig("../docs/src/assets/p53-Mdm2_normal_2nd_order.svg")

In [None]:
# normal closure, zoomed in on the initial dampening (t ≤ 40)
plot(plts[1]; xlims = (0.0, 40.0))

In [None]:
# second-order expansion with the log-normal closure
plot(plts[2]; size = (750, 450), leftmargin = 2mm)

In [None]:
#savefig("../docs/src/assets/p53-Mdm2_log-normal_2nd_order.svg")

In [None]:
# log-normal closure, zoomed in on t ≤ 50 with thicker lines
plot(plts[2]; lw = 3, xlims = (0.0, 50.0))

In [None]:
#savefig("../docs/src/assets/p53-Mdm2_log-normal_2nd_order_ZOOM.svg")

In [None]:
# second-order expansion with the gamma closure
plot(plts[3]; leftmargin = 2mm, size = (750, 450))

In [None]:
#savefig("../docs/src/assets/p53-Mdm2_gamma_2nd_order.svg")

In [None]:
# gamma closure, zoomed in on t ≤ 40 with thicker lines
plot(plts[3]; lw = 3, xlims = (0.0, 40.0))

In [None]:
# Variance of p53 (M₂₀₀) for the second-order expansion with q = 4 and q = 6.
# Simply rerunning the same calculations for the variance as they are quite fast.
plt = plot()

for q in [4, 6]
    println(q)
    eqs = generate_central_moment_eqs(rn, 2, q, combinatoric_ratelaws=false)
    for closure in closures
        println(closure)
        closed_eqs = moment_closure(eqs, closure)
        oprob = ODEProblem(closed_eqs, u0map, tspan, pmap)
        sol = solve(oprob, Tsit5(), saveat=0.1)
        # index of M₂₀₀ can be checked with `unknowns(closed_eqs)`.
        # `plot!` mutates the global `plt`; do not rebind it here — an assignment
        # inside a top-level loop makes `plt` a new local when run as a script.
        plot!(plt, sol, idxs=[4], lw=2, label="$closure q = $q")
    end
end

plot!(plt, xlabel="Time [h]", ylabel="Variance of p53 molecule number", legend=:topleft)
plot!(plt, means_SSA.t, vars_SSA[1, :], lw=2, linestyle=:dash, label="SSA", color="gray")
plot(plt, size=(750, 450))

In [None]:
#savefig("../docs/src/assets/p53-Mdm2_variances_2nd_order.svg")

In [None]:
# checking whether third-order moment expansion with odd q values is unstable (answer: yes it is)
eqs = generate_central_moment_eqs(rn, 3, 5, combinatoric_ratelaws=false)
closed_eqs = moment_closure(eqs, "log-normal")
oprob = ODEProblem(closed_eqs, u0map, tspan, pmap)

sol = solve(oprob, Tsit5(), saveat=0.1)
# time vs. the first unknown (mean p53 number); `idxs` replaces the deprecated `vars`
plot(sol, idxs=(0, 1), lw=3)

In [None]:
# Third-order (m = 3) moment expansion with q = 4 and q = 6: mean and variance of
# p53 for every closure, compared with the SSA estimates.
closures = ["zero", "normal", "log-normal", "gamma"]

plt_means = [plot() for _ in 1:2]
plt_vars  = [plot() for _ in 1:2]

m = 3
q_vals = [4, 6]

for (q, plt_m, plt_v) in zip(q_vals, plt_means, plt_vars)

    eqs = generate_central_moment_eqs(rn, m, q, combinatoric_ratelaws=false)
    for closure in closures
        println(closure)
        closed_eqs = moment_closure(eqs, closure)
        oprob = ODEProblem(closed_eqs, u0map, tspan, pmap)

        sol = solve(oprob, Tsit5(), saveat=0.1)
        # unknown 1 is the p53 mean, unknown 4 its variance (M₂₀₀);
        # `idxs` replaces the deprecated `vars` keyword of the solution plot recipe
        plot!(plt_m, sol, idxs=(0, 1), label=closure)
        plot!(plt_v, sol, idxs=(0, 4), label=closure)
    end

    plot!(plt_m, means_SSA.t, means_SSA[1, :], title="m = $m, q = $q",
          linestyle=:dash, label="SSA", color="gray", legend=false)

    plot!(plt_v, vars_SSA.t, vars_SSA[1, :], linestyle=:dash, label="SSA", color="gray", legend=false)

end

plt_means[1] = plot(plt_means[1], ylabel="Mean p53 molecule number")
plt_vars[1] = plot(plt_vars[1], ylabel="Variance of p53 molecule number", legend=:topleft)
plot(plt_means..., plt_vars..., size=(1250, 750), lw=1.5, xlabel="Time [h]",
     guidefontsize=10, titlefontsize=12, legendfontsize=8, leftmargin=4mm, bottommargin=2mm)

In [None]:
#savefig("../docs/src/assets/p53-Mdm2_3rd_order_expansion.svg")

In [None]:
# checking the trajectories using the QNDF (or ode15s) solver which is the default MEANS solver
# slight differences remaining between the trajectory obtained here and the one obtained using MEANS
# indicate that the difference lies in the implementation (inclusion of higher-order moment
# information in the closure functions)

eqs = generate_central_moment_eqs(rn, 3, 4, combinatoric_ratelaws=false)
closed_eqs = moment_closure(eqs, "log-normal")
oprob = ODEProblem(closed_eqs, u0map, tspan, pmap)

sol = solve(oprob, QNDF(), saveat=0.1)
# time vs. the first unknown (mean p53 number); `idxs` replaces the deprecated `vars`
plot(sol, idxs=(0, 1), lw=3)

In [None]:
# Fifth-order (m = 5, q = 6) moment expansion: mean of p53 for every closure vs. the SSA.
plt = plot()
closures = ["zero", "normal", "log-normal", "gamma"]

eqs = generate_central_moment_eqs(rn, 5, 6, combinatoric_ratelaws=false)
# faster to store than recompute in case we want to try different solvers/params
oprobs = Dict{String, ODEProblem}()

for closure in closures
    println(closure)
    closed_eqs = moment_closure(eqs, closure)
    oprobs[closure] = ODEProblem(closed_eqs, u0map, tspan, pmap)
    sol = solve(oprobs[closure], Tsit5(), saveat=0.1)

    # mutate the global `plt` in place; rebinding it inside this top-level loop
    # would turn it into a new local variable when the code runs as a script
    plot!(plt, sol, idxs=(0, 1), label=closure)
end

plot!(plt, xlabel="Time [h]", ylabel="Mean p53 molecule number")
plot!(plt, means_SSA.t, means_SSA[1, :], linestyle=:dash, label="SSA", color="gray")
plot(plt, size=(750, 450), lw=2, xlims=tspan)

In [None]:
#savefig("../docs/src/assets/p53-Mdm2_5th_order_expansion.svg")

In [None]:
# obtain both means and variances for fifth order moment expansion
# (commented out; kept in line with the API used by the cells above:
#  `combinatoric_ratelaws`, `u0map`/`pmap` passed to `ODEProblem`, `idxs` instead of `vars`)
#=
plt_m = plot()
plt_v = plot()

closures = ["zero", "normal", "log-normal", "gamma"]

eqs = generate_central_moment_eqs(rn, 5, 6, combinatoric_ratelaws=false)
# faster to store than recompute in case we want to try different solvers/params
oprobs = Dict{String, ODEProblem}()

for closure in closures
    println(closure)
    closed_eqs = moment_closure(eqs, closure)

    oprobs[closure] = ODEProblem(closed_eqs, u0map, tspan, pmap)
    sol = solve(oprobs[closure], Tsit5(), saveat=0.1)

    # mutate in place: rebinding globals inside a top-level loop breaks in script mode
    plot!(plt_m, sol, idxs=(0, 1), label=closure)
    plot!(plt_v, sol, idxs=(0, 4), label=closure)
end

plot!(plt_m, xlabel="Time [h]", ylabel="Mean p53 molecule number", legend=false)
plot!(plt_m, means_SSA.t, means_SSA[1, :], linestyle=:dash, label="SSA", color="gray")

plot!(plt_v, xlabel="Time [h]", ylabel="Variance of p53 molecule number", legend=:bottomleft)
plot!(plt_v, vars_SSA.t, vars_SSA[1, :], linestyle=:dash, label="SSA", color="gray")
plot(plt_m, plt_v, size=(1200, 400), lw=2, leftmargin=5mm, bottommargin=5mm, guidefontsize=10, legendfontsize=10)
=#

In [None]:
# Re-solve the stored fifth-order problems (`oprobs`) on tspan = (0, 150)
# (commented out; `idxs` replaces the deprecated `vars`, and `plot!` mutates
#  `plt_m` in place instead of rebinding it inside a top-level loop)
#=
plt_m = plot(xlabel="Time [h]", ylabel="Mean p53 molecule number")
for closure in closures
    oprob_long = remake(oprobs[closure], tspan=(0., 150.))
    sol = solve(oprob_long, Tsit5(), saveat=0.1)
    plot!(plt_m, sol, idxs=(0, 1), label=closure)
end
plot(plt_m, lw=2)
=#