Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Code cleanup and restructuring #27

Merged
merged 2 commits into from Jul 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 3 additions & 10 deletions README.md
Expand Up @@ -16,8 +16,6 @@ using Petri
using LabelledArrays
using OrdinaryDiffEq
using Plots
using Catlab.Graphics.Graphiz
import Catlab.Graphics.Graphviz: Graph
```

The SIR model represents the epidemiological dynamics of an infectious disease that causes immunity in its victims. There are three *states:* `Suceptible ,Infected, Recovered`. These states interact through two *transitions*. Infection has the form `S+I -> 2I` where a susceptible person meets an infected person and results in two infected people. The second transition is recovery `I -> R` where an infected person recovers spontaneously.
Expand All @@ -37,11 +35,8 @@ u0 = LVector(S=100.0, I=1, R=0)
# define the parameters of the model, each rate corresponds to a transition
p = LVector(inf=0.05, rec=0.35)

# evaluate the expression to create a runnable function
f = toODE(sir)

# this is regular OrdinaryDiffEq problem setup
prob = ODEProblem(f,u0,(0.0,365.0),p)
prob = ODEProblem(sir,u0,(0.0,365.0),p)
sol = OrdinaryDiffEq.solve(prob,Tsit5())

# generate a graphviz visualization of the model
Expand All @@ -65,8 +60,7 @@ seir = Petri.Model([:S,:E,:I,:R],LVector(
rec=(LVector(I=1), LVector(R=1))))
u0 = LVector(S=100.0, E=1, I=0, R=0)
p = (exp=0.35, inf=0.05, rec=0.05)
f = toODE(seir)
prob = ODEProblem(f,u0,(0.0,365.0),p)
prob = ODEProblem(seir,u0,(0.0,365.0),p)
sol = OrdinaryDiffEq.solve(prob,Tsit5())
plt = plot(sol)
```
Expand All @@ -85,8 +79,7 @@ seirs = Petri.Model([:S,:E,:I,:R],LVector(
deg=(LVector(R=1), LVector(S=1))))
u0 = LVector(S=100.0, E=1, I=0, R=0)
p = LVector(exp=0.35, inf=0.05, rec=0.07, deg=0.3)
f = toODE(seirs)
prob = ODEProblem(f,u0,(0.0,365.0),p)
prob = ODEProblem(seirs,u0,(0.0,365.0),p)
sol = OrdinaryDiffEq.solve(prob,Tsit5())
plt = plot(sol)
```
Expand Down
Binary file added docs/src/assets/logo.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 1 addition & 3 deletions examples/epidemiologyModels.jl
Expand Up @@ -4,8 +4,6 @@ using LabelledArrays
using StochasticDiffEq
using OrdinaryDiffEq
using Plots
using Catlab.Graphics.Graphviz
import Catlab.Graphics.Graphviz: Graph

@show "SIR"

Expand Down Expand Up @@ -75,4 +73,4 @@ sol = StochasticDiffEq.solve(prob,SRA1(),callback=cb)

plot(sol)

Graph(seird)
Graph(seird)
48 changes: 0 additions & 48 deletions examples/invokelatest.jl

This file was deleted.

60 changes: 3 additions & 57 deletions src/Petri.jl
Expand Up @@ -6,64 +6,10 @@ Provides a modeling framework for representing and solving stochastic petri nets
"""
module Petri

export Model, Problem, NullPetri, solve, vectorfields, NullPetri

function funcindex!(list, key, f, vals...)
setindex!(list, f(getindex(list, key),vals...), key)
end
export Model, Problem, NullPetri, vectorfields, Graph

include("types.jl")

"""
NullPetri(n::Int)

create a Petri net of ``n`` states with no transitions
"""
NullPetri(n::Int) = Model(collect(1:n), Vector{Tuple{Dict{Int, Int},Dict{Int, Int}}}())

"""
solve(p::Problem)

Evaluate petri net problem and return the final state
"""
function solve(p::AbstractProblem)
state = p.initial
for i in 1:p.steps
state = step(p, state)
end
return state
end

function validTransition(state, δ)
ins = first(δ)
all(s->getindex(state,s) >= getindex(ins,s), keys(ins))
end

function step(p::Problem, state)
ks = keys(p.model.Δ)
i = rand(1:length(ks))
δ = getindex(p.model.Δ, getindex(ks, i))
if validTransition(state, δ)
return apply(state, δ)
else
return state
end
end

function apply(state, δ)
ins = first(δ)
outs = last(δ)
out = deepcopy(state)
for k in keys(ins)
funcindex!(out, k, -, getindex(ins, k))
end
for k in keys(outs)
funcindex!(out, k, +, getindex(outs, k))
end
return out
end

include("vectorfields.jl")
include("solvers.jl")
include("visualization.jl")

end #Module
end
47 changes: 26 additions & 21 deletions src/vectorfields.jl → src/solvers.jl
Expand Up @@ -4,11 +4,11 @@ using StochasticDiffEq
import OrdinaryDiffEq: ODEProblem
import StochasticDiffEq: SDEProblem

funcindex!(list, key, f, vals...) = list[key] = f(list[key],vals...)
valueat(x::Number, t) = x
valueat(f::Function, t) = f(t)

"""
vectorfields(m::Model)
""" vectorfields(m::Model)

Convert a petri model into a differential equation function that can
be passed into DifferentialEquation.jl or OrdinaryDiffEq.jl solvers
Expand All @@ -19,46 +19,52 @@ function vectorfields(m::Model)
ϕ = Dict()
f(du, u, p, t) = begin
for k in keys(T)
ins = first(getindex(T, k))
setindex!(ϕ, reduce((x,y)->x*getindex(u,y)/getindex(ins,y), keys(ins); init=valueat(getindex(p, k),t)), k)
ins = first(T[k])
ϕ[k] = reduce((x,y)->x*u[y]/ins[y], keys(ins); init=valueat(p[k],t))
end
for s in S
setindex!(du, 0, s)
du[s] = 0
end
for k in keys(T)
ins = first(getindex(T, k))
outs = last(getindex(T, k))
ins = first(T[k])
outs = last(T[k])
for s in keys(ins)
funcindex!(du, s, -, getindex(ϕ, k) * getindex(ins, s))
funcindex!(du, s, -, ϕ[k] * ins[s])
end
for s in keys(outs)
funcindex!(du, s, +, getindex(ϕ, k) * getindex(outs, s))
funcindex!(du, s, +, ϕ[k] * outs[s])
end
end
return du
end
return f
end

function ODEProblem(m::Model,u0,tspan,β)
return ODEProblem(vectorfields(m), u0, tspan, β)
end
""" ODEProblem(m::Model, u0, tspan, β)

Generate an OrdinaryDiffEq ODEProblem
"""
ODEProblem(m::Model, u0, tspan, β) = ODEProblem(vectorfields(m), u0, tspan, β)

function statecb(s)
cond = (u,t,integrator) -> u[s]
aff = (integrator) -> integrator.u[s] = 0.0
return ContinuousCallback(cond, aff)
end

function SDEProblem(m::Model,u0,tspan,β)
""" SDEProblem(m::Model, u0, tspan, β)

Generate an StochasticDiffEq SDEProblem and an appropriate CallbackSet
"""
function SDEProblem(m::Model, u0, tspan, β)
S = m.S
T = m.Δ
ϕ = Dict()
Spos = Dict(S[k]=>k for k in keys(S))
Tpos = Dict(keys(T)[k]=>k for k in keys(keys(T)))
nu = zeros(Float64, length(S), length(T))
for k in keys(T)
l,r = getindex(T, k)
l,r = T[k]
for i in keys(l)
nu[Spos[i],Tpos[k]] -= l[i]
end
Expand All @@ -69,13 +75,13 @@ function SDEProblem(m::Model,u0,tspan,β)
noise(du, u, p, t) = begin
sum_u = sum(u)
for k in keys(T)
ins = first(getindex(T, k))
ϕ[k] = reduce((x,y)->x*getindex(u,y)/(sum_u*getindex(ins,y)), keys(ins); init=valueat(getindex(p, k),t))
ins = first(T[k])
ϕ[k] = reduce((x,y)->x*u[y]/(sum_u*ins[y]), keys(ins); init=valueat(p[k],t))
end

for k in keys(T)
l,r = getindex(T, k)
rate = sqrt(abs(getindex(ϕ, k)))
l,r = T[k]
rate = sqrt(abs(ϕ[k]))
for i in keys(l)
du[Spos[i],Tpos[k]] = -rate
end
Expand All @@ -85,7 +91,6 @@ function SDEProblem(m::Model,u0,tspan,β)
end
return du
end
prob_sde = SDEProblem(vectorfields(m),noise,u0,tspan,β,noise_rate_prototype=nu)
cb = CallbackSet([statecb(s) for s in S]...)
return prob_sde, cb
return SDEProblem(vectorfields(m),noise,u0,tspan,β,noise_rate_prototype=nu),
CallbackSet([statecb(s) for s in S]...)
end
18 changes: 3 additions & 15 deletions src/types.jl
Expand Up @@ -14,21 +14,9 @@ end

Model(s::S, Δ) where S<:UnitRange = Model(collect(s), Δ)

function ==(x::Petri.Model,y::Petri.Model)
return x.S == y.S && x.Δ == y.Δ
end

abstract type AbstractProblem end

"""
Problem{M<:Model, S, N}
NullPetri(n::Int)

Structure for representing a petri net problem

represented by a petri net model, initial state, and number of steps
create a Petri net of ``n`` states with no transitions
"""
struct Problem{M<:Model, S, N} <: AbstractProblem
model::M
initial::S
steps::N
end
NullPetri(n::Int) = Model(collect(1:n), Vector{Tuple{Dict{Int, Number},Dict{Int, Number}}}())
6 changes: 3 additions & 3 deletions src/visualization.jl
Expand Up @@ -9,7 +9,7 @@ edge_attrs = Attributes(:splines=>"splines")
function edgify(δ, transition, reverse::Bool)
attr = Attributes()
return map(collect(keys(δ))) do k
weight = "$(getindex(δ, k))"
weight = "$(δ[k])"
state = "$k"
attr = Attributes(:label=>weight, :labelfontsize=>"6")
return Edge(reverse ? ["T_$transition", "S_$state"] :
Expand All @@ -29,8 +29,8 @@ function Graph(model::Model)

stmts = vcat(statenodes, transnodes)
edges = map(ks) do k
vcat(edgify(first(getindex(model.Δ, k)), k, false),
edgify(last(getindex(model.Δ, k)), k, true))
vcat(edgify(first(model.Δ[k]), k, false),
edgify(last(model.Δ[k]), k, true))
end |> flatten |> collect
stmts = vcat(stmts, edges)
g = Graphviz.Graph("G", true, stmts, graph_attrs, node_attrs,edge_attrs)
Expand Down
4 changes: 1 addition & 3 deletions test/runtests.jl
Expand Up @@ -20,6 +20,4 @@ using LabelledArrays
@test seir != y
end

include("stochastic.jl")

include("vectorfields.jl")
include("solvers.jl")