From 41d4127661c892aa270ce1339f9b9e6ba8de2366 Mon Sep 17 00:00:00 2001 From: Utkarsh Date: Fri, 29 Jan 2021 03:30:21 +0530 Subject: [PATCH 1/2] Add helper functions for AbstractSciMLFunction --- src/scimlfunctions.jl | 121 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 475d31aa3..9e2b44a7c 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -1110,6 +1110,18 @@ has_paramjac(f::AbstractDiffEqFunction) = __has_paramjac(f) && f.paramjac !== no has_syms(f::AbstractDiffEqFunction) = __has_syms(f) && f.syms !== nothing has_colorvec(f::AbstractDiffEqFunction) = __has_colorvec(f) && f.colorvec !== nothing +has_invW(f::AbstractSciMLFunction) = false +has_analytic(f::AbstractSciMLFunction) = __has_analytic(f) && f.analytic !== nothing +has_jac(f::AbstractSciMLFunction) = __has_jac(f) && f.jac !== nothing +has_jvp(f::AbstractSciMLFunction) = __has_jvp(f) && f.jvp !== nothing +has_vjp(f::AbstractSciMLFunction) = __has_vjp(f) && f.vjp !== nothing +has_tgrad(f::AbstractSciMLFunction) = __has_tgrad(f) && f.tgrad !== nothing +has_Wfact(f::AbstractSciMLFunction) = __has_Wfact(f) && f.Wfact !== nothing +has_Wfact_t(f::AbstractSciMLFunction) = __has_Wfact_t(f) && f.Wfact_t !== nothing +has_paramjac(f::AbstractSciMLFunction) = __has_paramjac(f) && f.paramjac !== nothing +has_syms(f::AbstractSciMLFunction) = __has_syms(f) && f.syms !== nothing +has_colorvec(f::AbstractSciMLFunction) = __has_colorvec(f) && f.colorvec !== nothing + # TODO: find an appropriate way to check `has_*` has_jac(f::Union{SplitFunction,SplitSDEFunction}) = has_jac(f.f1) has_jvp(f::Union{SplitFunction,SplitSDEFunction}) = has_jvp(f.f1) @@ -1770,6 +1782,115 @@ function Base.convert(::Type{SDDEFunction{iip}},f,g) where iip Wfact_t=Wfact_t,paramjac=paramjac,syms=syms,colorvec=colorvec) end +function Base.convert(::Type{NonlinearFunction}, f) + if __has_analytic(f) + analytic = f.analytic + else + analytic = nothing + end + if __has_jac(f) + jac = f.jac + else + jac = nothing + end + if __has_jvp(f) + jvp = f.jvp + else + jvp = nothing + end + if __has_vjp(f) + vjp = f.vjp + else + vjp = nothing + end + if __has_tgrad(f) + tgrad = f.tgrad + else + tgrad = nothing + end + if __has_Wfact(f) + Wfact = f.Wfact + else + Wfact = nothing + end + if __has_Wfact_t(f) + Wfact_t = f.Wfact_t + else + Wfact_t = nothing + end + if __has_paramjac(f) + paramjac = f.paramjac + else + paramjac = nothing + end + if __has_syms(f) + syms = f.syms + else + syms = nothing + end + if __has_colorvec(f) + colorvec = f.colorvec + else + colorvec = nothing + end + NonlinearFunction(f;analytic=analytic,tgrad=tgrad,jac=jac,jvp=jvp,vjp=vjp,Wfact=Wfact, + Wfact_t=Wfact_t,paramjac=paramjac,syms=syms,colorvec=colorvec) +end +function Base.convert(::Type{NonlinearFunction{iip}},f) where iip + if __has_analytic(f) + analytic = f.analytic + else + analytic = nothing + end + if __has_jac(f) + jac = f.jac + else + jac = nothing + end + if __has_jvp(f) + jvp = f.jvp + else + jvp = nothing + end + if __has_vjp(f) + vjp = f.vjp + else + vjp = nothing + end + if __has_tgrad(f) + tgrad = f.tgrad + else + tgrad = nothing + end + if __has_Wfact(f) + Wfact = f.Wfact + else + Wfact = nothing + end + if __has_Wfact_t(f) + Wfact_t = f.Wfact_t + else + Wfact_t = nothing + end + if __has_paramjac(f) + paramjac = f.paramjac + else + paramjac = nothing + end + if __has_syms(f) + syms = f.syms + else + syms = nothing + end + if __has_colorvec(f) + colorvec = f.colorvec + else + colorvec = nothing + end + NonlinearFunction{iip,RECOMPILE_BY_DEFAULT}(f;analytic=analytic,tgrad=tgrad,jac=jac,jvp=jvp,vjp=vjp,Wfact=Wfact, + Wfact_t=Wfact_t,paramjac=paramjac,syms=syms,colorvec=colorvec) +end + struct IncrementingODEFunction{iip,F} <: AbstractODEFunction{iip} f::F end From 772f7ee5f6a73b0e2216c6441afa2248ad276d65 Mon Sep 17 00:00:00 2001 From: Utkarsh Date: Fri, 29 Jan 2021 03:38:57 +0530 Subject: [PATCH 2/2] Remove redundant AbstractDiffEqFunction functions --- src/scimlfunctions.jl | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 9e2b44a7c..f9a0f659e 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -1098,18 +1098,6 @@ __has_analytic(f) = isdefined(f, :analytic) __has_colorvec(f) = isdefined(f, :colorvec) # compatibility -has_invW(f::AbstractDiffEqFunction) = false -has_analytic(f::AbstractDiffEqFunction) = __has_analytic(f) && f.analytic !== nothing -has_jac(f::AbstractDiffEqFunction) = __has_jac(f) && f.jac !== nothing -has_jvp(f::AbstractDiffEqFunction) = __has_jvp(f) && f.jvp !== nothing -has_vjp(f::AbstractDiffEqFunction) = __has_vjp(f) && f.vjp !== nothing -has_tgrad(f::AbstractDiffEqFunction) = __has_tgrad(f) && f.tgrad !== nothing -has_Wfact(f::AbstractDiffEqFunction) = __has_Wfact(f) && f.Wfact !== nothing -has_Wfact_t(f::AbstractDiffEqFunction) = __has_Wfact_t(f) && f.Wfact_t !== nothing -has_paramjac(f::AbstractDiffEqFunction) = __has_paramjac(f) && f.paramjac !== nothing -has_syms(f::AbstractDiffEqFunction) = __has_syms(f) && f.syms !== nothing -has_colorvec(f::AbstractDiffEqFunction) = __has_colorvec(f) && f.colorvec !== nothing - has_invW(f::AbstractSciMLFunction) = false has_analytic(f::AbstractSciMLFunction) = __has_analytic(f) && f.analytic !== nothing has_jac(f::AbstractSciMLFunction) = __has_jac(f) && f.jac !== nothing