Skip to content

Commit

Permalink
Merge pull request #51 from devmotion/fix_nlsolve
Browse files Browse the repository at this point in the history
Update to NLsolve changes in OrdinaryDiffEq
  • Loading branch information
ChrisRackauckas committed Nov 4, 2017
2 parents 57402b7 + 2be623c commit 9c5a0f2
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 40 deletions.
2 changes: 1 addition & 1 deletion REQUIRE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
julia 0.6
DiffEqBase 2.0.0
OrdinaryDiffEq 2.21.0
OrdinaryDiffEq 2.22.0
DataStructures 0.4.6
RecursiveArrayTools 0.2.0
Reexport
Expand Down
51 changes: 12 additions & 39 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,6 @@ Hereby u, uprev, uprev2, and function f are updated, if required.
assign_expr(::Val{name}, ::Type, ::Type) where {name} =
:($name = getfield(cache, $(Meta.quot(name))))

# update uhold
assign_expr(::Val{:uhold}, ::Type,
::Type{<:Union{OrdinaryDiffEq.GenericImplicitEulerCache,
OrdinaryDiffEq.GenericTrapezoidCache,
OrdinaryDiffEq.GenericIIF1Cache,
OrdinaryDiffEq.GenericIIF2Cache}}) =
:(uhold = vec(u))

# update matrix exponential
assign_expr(::Val{:expA}, ::Type, ::Type) =
:(A = f.f1; expA = expm(A*dt))
Expand Down Expand Up @@ -75,39 +67,20 @@ assign_expr(::Val{name}, ::Type{ForwardDiff.JacobianConfig{T,V,N,D}},
ForwardDiff.Chunk{$N}()))

# update implicit RHS
function assign_expr(::Val{name}, ::Type{<:OrdinaryDiffEq.ImplicitRHS}, ::Type) where name
nameq = Meta.quot(name)
:($name = OrdinaryDiffEq.ImplicitRHS(
f,
getfield(cache, $nameq).tmp,
t, t, t,
getfield(cache, $nameq).dual_cache))
end
assign_expr(::Val{name}, ::Type{<:OrdinaryDiffEq.ImplicitRHS}, ::Type) where name =
:($name = OrdinaryDiffEq.ImplicitRHS(f, cache.tmp, t, t, t, cache.dual_cache))
assign_expr(::Val{name}, ::Type{<:OrdinaryDiffEq.ImplicitRHS_Scalar}, ::Type) where name =
:($name = OrdinaryDiffEq.ImplicitRHS_Scalar(
f,
getfield(cache, $(Meta.quot(name))).tmp,
t, t, t))
function assign_expr(::Val{name}, ::Type{<:OrdinaryDiffEq.RHS_IIF}, ::Type) where name
nameq = Meta.quot(name)
:($name = OrdinaryDiffEq.RHS_IIF(
f,
getfield(cache, $nameq).tmp,
t, t,
getfield(cache, $nameq).dual_cache,
getfield(cache, $nameq).a))
end
function assign_expr(::Val{name}, ::Type{<:OrdinaryDiffEq.RHS_IIF_Scalar},
::Type) where name
nameq = Meta.quot(name)
:($name = OrdinaryDiffEq.RHS_IIF_Scalar(
f,
t, t,
getfield(cache, $nameq).tmp,
getfield(cache, $nameq).a))
end
:($name = OrdinaryDiffEq.ImplicitRHS_Scalar(f, zero(u), t, t, t))
assign_expr(::Val{name}, ::Type{<:OrdinaryDiffEq.RHS_IIF}, ::Type) where name =
:($name = OrdinaryDiffEq.RHS_IIF(f, cache.tmp, t, t, cache.tmp, cache.dual_cache))
assign_expr(::Val{name}, ::Type{<:OrdinaryDiffEq.RHS_IIF_Scalar}, ::Type) where name =
:($name = OrdinaryDiffEq.RHS_IIF_Scalar(f, zero(u), t, t,
getfield(cache, $(Meta.quot(name))).a))

# create new NLsolve differentiable function
assign_expr(::Val{name}, ::Type{<:NLsolve.DifferentiableMultivariateFunction},
::Type) where name =
::Type{<:OrdinaryDiffEq.OrdinaryDiffEqMutableCache}) where name =
:($name = alg.nlsolve(Val{:init},rhs,u))
assign_expr(::Val{name}, ::Type{<:NLsolve.DifferentiableMultivariateFunction},
::Type{<:OrdinaryDiffEq.OrdinaryDiffEqConstantCache}) where name =
:($name = alg.nlsolve(Val{:init},rhs,uhold))

0 comments on commit 9c5a0f2

Please sign in to comment.