diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 475d31aa3..f9a0f659e 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -1098,17 +1098,17 @@ __has_analytic(f) = isdefined(f, :analytic) __has_colorvec(f) = isdefined(f, :colorvec) # compatibility -has_invW(f::AbstractDiffEqFunction) = false -has_analytic(f::AbstractDiffEqFunction) = __has_analytic(f) && f.analytic !== nothing -has_jac(f::AbstractDiffEqFunction) = __has_jac(f) && f.jac !== nothing -has_jvp(f::AbstractDiffEqFunction) = __has_jvp(f) && f.jvp !== nothing -has_vjp(f::AbstractDiffEqFunction) = __has_vjp(f) && f.vjp !== nothing -has_tgrad(f::AbstractDiffEqFunction) = __has_tgrad(f) && f.tgrad !== nothing -has_Wfact(f::AbstractDiffEqFunction) = __has_Wfact(f) && f.Wfact !== nothing -has_Wfact_t(f::AbstractDiffEqFunction) = __has_Wfact_t(f) && f.Wfact_t !== nothing -has_paramjac(f::AbstractDiffEqFunction) = __has_paramjac(f) && f.paramjac !== nothing -has_syms(f::AbstractDiffEqFunction) = __has_syms(f) && f.syms !== nothing -has_colorvec(f::AbstractDiffEqFunction) = __has_colorvec(f) && f.colorvec !== nothing +has_invW(f::AbstractSciMLFunction) = false +has_analytic(f::AbstractSciMLFunction) = __has_analytic(f) && f.analytic !== nothing +has_jac(f::AbstractSciMLFunction) = __has_jac(f) && f.jac !== nothing +has_jvp(f::AbstractSciMLFunction) = __has_jvp(f) && f.jvp !== nothing +has_vjp(f::AbstractSciMLFunction) = __has_vjp(f) && f.vjp !== nothing +has_tgrad(f::AbstractSciMLFunction) = __has_tgrad(f) && f.tgrad !== nothing +has_Wfact(f::AbstractSciMLFunction) = __has_Wfact(f) && f.Wfact !== nothing +has_Wfact_t(f::AbstractSciMLFunction) = __has_Wfact_t(f) && f.Wfact_t !== nothing +has_paramjac(f::AbstractSciMLFunction) = __has_paramjac(f) && f.paramjac !== nothing +has_syms(f::AbstractSciMLFunction) = __has_syms(f) && f.syms !== nothing +has_colorvec(f::AbstractSciMLFunction) = __has_colorvec(f) && f.colorvec !== nothing # TODO: find an appropriate way to check `has_*` has_jac(f::Union{SplitFunction,SplitSDEFunction}) = has_jac(f.f1) @@ -1770,6 +1770,115 @@ function Base.convert(::Type{SDDEFunction{iip}},f,g) where iip Wfact_t=Wfact_t,paramjac=paramjac,syms=syms,colorvec=colorvec) end +function Base.convert(::Type{NonlinearFunction}, f) + if __has_analytic(f) + analytic = f.analytic + else + analytic = nothing + end + if __has_jac(f) + jac = f.jac + else + jac = nothing + end + if __has_jvp(f) + jvp = f.jvp + else + jvp = nothing + end + if __has_vjp(f) + vjp = f.vjp + else + vjp = nothing + end + if __has_tgrad(f) + tgrad = f.tgrad + else + tgrad = nothing + end + if __has_Wfact(f) + Wfact = f.Wfact + else + Wfact = nothing + end + if __has_Wfact_t(f) + Wfact_t = f.Wfact_t + else + Wfact_t = nothing + end + if __has_paramjac(f) + paramjac = f.paramjac + else + paramjac = nothing + end + if __has_syms(f) + syms = f.syms + else + syms = nothing + end + if __has_colorvec(f) + colorvec = f.colorvec + else + colorvec = nothing + end + NonlinearFunction(f;analytic=analytic,tgrad=tgrad,jac=jac,jvp=jvp,vjp=vjp,Wfact=Wfact, + Wfact_t=Wfact_t,paramjac=paramjac,syms=syms,colorvec=colorvec) +end +function Base.convert(::Type{NonlinearFunction{iip}},f) where iip + if __has_analytic(f) + analytic = f.analytic + else + analytic = nothing + end + if __has_jac(f) + jac = f.jac + else + jac = nothing + end + if __has_jvp(f) + jvp = f.jvp + else + jvp = nothing + end + if __has_vjp(f) + vjp = f.vjp + else + vjp = nothing + end + if __has_tgrad(f) + tgrad = f.tgrad + else + tgrad = nothing + end + if __has_Wfact(f) + Wfact = f.Wfact + else + Wfact = nothing + end + if __has_Wfact_t(f) + Wfact_t = f.Wfact_t + else + Wfact_t = nothing + end + if __has_paramjac(f) + paramjac = f.paramjac + else + paramjac = nothing + end + if __has_syms(f) + syms = f.syms + else + syms = nothing + end + if __has_colorvec(f) + colorvec = f.colorvec + else + colorvec = nothing + end + NonlinearFunction{iip,RECOMPILE_BY_DEFAULT}(f;analytic=analytic,tgrad=tgrad,jac=jac,jvp=jvp,vjp=vjp,Wfact=Wfact, + Wfact_t=Wfact_t,paramjac=paramjac,syms=syms,colorvec=colorvec) +end + struct IncrementingODEFunction{iip,F} <: AbstractODEFunction{iip} f::F end