Skip to content

Commit

Permalink
Refactor nlsys equation storage
Browse files Browse the repository at this point in the history
Disallow intermediate equations.
  • Loading branch information
HarrisonGrodin committed Feb 2, 2019
1 parent 61f4914 commit 40b0370
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 27 deletions.
5 changes: 2 additions & 3 deletions src/systems/diffeqs/diffeqsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ function DiffEqSystem(eqs, iv)
end


function calculate_jacobian(sys::DiffEqSystem, simplify=true)
function calculate_jacobian(sys::DiffEqSystem)
isempty(sys.jac[]) || return sys.jac[] # use cached Jacobian, if possible
rhs = [eq.rhs for eq in sys.eqs]

Expand All @@ -56,15 +56,14 @@ function calculate_jacobian(sys::DiffEqSystem, simplify=true)
end

system_eqs(sys::DiffEqSystem) = collect(Equation, sys.eqs)
system_extras(::DiffEqSystem) = Equation[]
system_vars(sys::DiffEqSystem) = sys.dvs
system_params(sys::DiffEqSystem) = sys.ps


function generate_ode_iW(sys::DiffEqSystem, simplify=true)
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in eachindex(sys.dvs)]
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in eachindex(sys.ps)]
jac = calculate_jacobian(sys, simplify)
jac = calculate_jacobian(sys)

gam = Parameter(:gam)

Expand Down
31 changes: 15 additions & 16 deletions src/systems/nonlinear/nonlinear_system.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
export NonlinearSystem


struct NLEq
rhs::Expression
end
function Base.convert(::Type{NLEq}, eq::Equation)
isequal(eq.lhs, Constant(0)) || throw(ArgumentError("nonzero lhs received"))
return NLEq(eq.rhs)
end
Base.convert(::Type{Equation}, eq::NLEq) = Equation(0, eq.rhs)

struct NonlinearSystem <: AbstractSystem
eqs::Vector{Equation}
eqs::Vector{NLEq}
vs::Vector{Variable}
ps::Vector{Variable}
end
Expand All @@ -13,22 +22,12 @@ function NonlinearSystem(eqs)
end


function calculate_jacobian(sys::NonlinearSystem, simplify=true)
sys_eqs, calc_eqs = system_eqs(sys), filter(iscalc, sys.eqs)
rhs = [eq.rhs for eq in sys_eqs]

for calc_eq calc_eqs
find_replace!.(rhs, calc_eq.lhs, calc_eq.rhs)
end

sys_exprs = calculate_jacobian(rhs,sys.vs)
sys_exprs = Expression[expand_derivatives(expr) for expr in sys_exprs]
sys_exprs
function calculate_jacobian(sys::NonlinearSystem)
rhs = [eq.rhs for eq in sys.eqs]
jac = expand_derivatives.(calculate_jacobian(rhs, sys.vs))
return jac
end

iscalc(eq) = !isequal(eq.lhs, Constant(0))

system_eqs(sys::NonlinearSystem) = filter(!iscalc, sys.eqs)
system_extras(sys::NonlinearSystem) = filter(eq -> isa(eq.lhs, Variable), sys.eqs)
system_eqs(sys::NonlinearSystem) = collect(Equation, sys.eqs)
system_vars(sys::NonlinearSystem) = sys.vs
system_params(sys::NonlinearSystem) = sys.ps
10 changes: 4 additions & 6 deletions src/systems/systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,27 @@ export generate_jacobian, generate_function
abstract type AbstractSystem end

function system_eqs end
function system_extras end
function system_vars end
function system_params end

function generate_jacobian(sys::AbstractSystem, simplify = true)
function generate_jacobian(sys::AbstractSystem)
vs, ps = system_vars(sys), system_params(sys)
var_exprs = [:($(vs[i].name) = u[$i]) for i in eachindex(vs)]
param_exprs = [:($(ps[i].name) = p[$i]) for i in eachindex(ps)]
jac = calculate_jacobian(sys, simplify)
jac = calculate_jacobian(sys)
jac_exprs = [:(J[$i,$j] = $(convert(Expr, jac[i,j]))) for i in 1:size(jac,1), j in 1:size(jac,2)]
exprs = vcat(var_exprs, param_exprs, vec(jac_exprs))
block = expr_arr_to_block(exprs)
:((J,u,p,t) -> $(block))
end

function generate_function(sys::AbstractSystem; version::FunctionVersion = ArrayFunction)
sys_eqs, calc_eqs = system_eqs(sys), system_extras(sys)
sys_eqs = system_eqs(sys)
vs, ps = system_vars(sys), system_params(sys)

var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(vs)]
param_pairs = [(p.name, :(p[$i])) for (i, p) enumerate(ps)]
calc_pairs = [(eq.lhs.name, convert(Expr, eq.rhs)) for eq calc_eqs]
(ls, rs) = collect(zip(var_pairs..., param_pairs..., calc_pairs...))
(ls, rs) = collect(zip(var_pairs..., param_pairs...))

var_eqs = Expr(:(=), build_expr(:tuple, ls), build_expr(:tuple, rs))
sys_exprs = build_expr(:tuple, [convert(Expr, eq.rhs) for eq sys_eqs])
Expand Down
3 changes: 1 addition & 2 deletions test/system_construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,7 @@ f = @eval eval(nlsys_func)

# Intermediate calculations
# Define a nonlinear system
eqs = [a ~ y-x,
0 ~ σ*a,
eqs = [0 ~ σ*a,
0 ~ x*-z)-y,
0 ~ x*y - β*z]
ns = NonlinearSystem(eqs,[x,y,z],[σ,ρ,β])
Expand Down

0 comments on commit 40b0370

Please sign in to comment.