Compatibility with DynamicQuantities.jl – use oneunit(::T) instead of oneunit(::Type{T}) #993

Open
MilesCranmer opened this issue Oct 25, 2023 · 3 comments


MilesCranmer commented Oct 25, 2023

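I'm trying out a DynamicQuantities.jl example with DifferentialEquations.jl, but I'm running into issues because the solver calls oneunit(::Type{T}) rather than oneunit(::T). DynamicQuantities stores dimensions in the value rather than in the type, so it cannot build a dimensionful 1 from the type alone. I think changing to oneunit(::T) will make things compatible with both DynamicQuantities and Unitful.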

julia> using DynamicQuantities, DifferentialEquations

julia> f(u, p, t) = u * t;

julia> problem = ODEProblem(f, [1.0u"km/s"], (0.0u"s", 1.0u"s"));

julia> sol = solve(problem)
ERROR: Cannot create a dimensionful 1 for a `AbstractUnionQuantity` type without knowing the dimensions. Please use `oneunit(::AbstractUnionQuantity)` instead.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] oneunit(::Type{Quantity{Float64, Dimensions{DynamicQuantities.FixedRational{Int32, 25200}}}})
    @ DynamicQuantities ~/Documents/DynamicQuantities.jl/src/utils.jl:140
  [3] __init(prob::ODEProblem{…}, alg::CompositeAlgorithm{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Quantity{…}, dtmin::Nothing, dtmax::Quantity{…}, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Nothing, reltol::Nothing, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::@Kwargs{…})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/yppG9/src/solve.jl:174
  [4] __solve(::ODEProblem{…}, ::CompositeAlgorithm{…}; kwargs::@Kwargs{})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/yppG9/src/solve.jl:5
  [5] solve_call(_prob::ODEProblem{…}, args::CompositeAlgorithm{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/xSmHR/src/solve.jl:571
  [6] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::SciMLBase.NullParameters, args::CompositeAlgorithm{…}; kwargs::@Kwargs{})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/xSmHR/src/solve.jl:1033
  [7] solve(prob::ODEProblem{…}, args::CompositeAlgorithm{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/xSmHR/src/solve.jl:943
  [8] __solve(::ODEProblem{…}, ::Nothing; default_set::Bool, kwargs::@Kwargs{})
    @ DifferentialEquations ~/.julia/packages/DifferentialEquations/Tu7HS/src/default_solve.jl:14
  [9] __solve
    @ DifferentialEquations ~/.julia/packages/DifferentialEquations/Tu7HS/src/default_solve.jl:1 [inlined]
 [10] #__solve#63
    @ DiffEqBase ~/.julia/packages/DiffEqBase/xSmHR/src/solve.jl:1314 [inlined]
 [11] __solve
    @ DiffEqBase ~/.julia/packages/DiffEqBase/xSmHR/src/solve.jl:1307 [inlined]
 [12] solve_call(::ODEProblem{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/xSmHR/src/solve.jl:571
 [13] solve_call(::ODEProblem{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/xSmHR/src/solve.jl:537
 [14] solve_up(prob::SciMLBase.AbstractDEProblem, sensealg::Any, u0::Any, p::Any, args::Vararg{Any}; kwargs...)
    @ DiffEqBase ~/.julia/packages/DiffEqBase/xSmHR/src/solve.jl:1037 [inlined]
 [15] solve(::ODEProblem{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/xSmHR/src/solve.jl:943
 [16] solve(::ODEProblem{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/xSmHR/src/solve.jl:933
 [17] top-level scope
    @ REPL[19]:1
Some type information was truncated. Use `show(err)` to see complete types.
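
To illustrate the difference between the two call forms without depending on either package, here is a minimal sketch using only Base. `ToyQuantity` and `unit_of_start` are hypothetical names made up for this example; they are not part of DynamicQuantities or any SciML package, and the real `Quantity` type stores its dimensions differently.

```julia
# Toy stand-in for a quantity whose dimensions live in the value, not the type.
# `ToyQuantity` and `unit_of_start` are hypothetical, for illustration only.
struct ToyQuantity{T<:Real}
    value::T
    dims::Symbol   # e.g. :s or :km; a real package would store exponents
end

# The instance method can read the dimensions from the value it is given.
Base.oneunit(q::ToyQuantity{T}) where {T} = ToyQuantity(one(T), q.dims)

# The type method has no dimensions to read, so all it can do is fail.
Base.oneunit(::Type{<:ToyQuantity}) =
    error("cannot build a dimensionful 1 from the type alone")

# Generic caller written against the instance form.
unit_of_start(tspan) = oneunit(first(tspan))

tspan = (ToyQuantity(0.0, :s), ToyQuantity(1.0, :s))
u = unit_of_start(tspan)        # carries the :s dimension from the value
v = unit_of_start((0.0, 1.0))   # plain floats keep working unchanged

# The type form throws for the toy quantity.
type_form_failed = try
    oneunit(typeof(first(tspan)))
    false
catch
    true
end
```

Generic code written as `oneunit(x)` works for plain numbers and for value-carried dimensions alike, while `oneunit(typeof(x))` only works when the type alone determines the unit.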
@ChrisRackauckas
Member

I pushed it along and got pretty far:

using DynamicQuantities, OrdinaryDiffEq, RecursiveArrayTools

function RecursiveArrayTools.recursive_unitless_bottom_eltype(
        a::Type{<:DynamicQuantities.Quantity{T}}) where {T}
    T
end

function RecursiveArrayTools.recursive_unitless_eltype(a::Type{<:DynamicQuantities.Quantity{T}}) where T
    T
end

DiffEqBase.value(x::DynamicQuantities.Quantity) = x.value
@inline function DiffEqBase.UNITLESS_ABS2(x::AbstractArray)
    mapreduce(DiffEqBase.UNITLESS_ABS2, DiffEqBase.abs2_and_sum, x, init = zero(real(first(DiffEqBase.value(x)))))
end
@inline function DiffEqBase.UNITLESS_ABS2(x::DynamicQuantities.Quantity)
    abs(DiffEqBase.value(x))
end

function DiffEqBase.abs2_and_sum(x::DynamicQuantities.Quantity, y::Float64)
    reduce(Base.add_sum, DiffEqBase.value(x), init = zero(real(DiffEqBase.value(x)))) +
    reduce(Base.add_sum, y, init = zero(real(DiffEqBase.value(eltype(y)))))
end

DiffEqBase.recursive_length(u::Array) = length(u)
Base.sign(x::DynamicQuantities.Quantity) = Base.sign(DiffEqBase.value(x))

function DiffEqBase.prob2dtmin(prob; use_end_time = true)
    DiffEqBase.prob2dtmin(prob.tspan, oneunit(first(prob.tspan)), use_end_time)
end

DiffEqBase.NAN_CHECK(x::DynamicQuantities.Quantity) = isnan(x)
Base.zero(x::Array{T}) where {T<:DynamicQuantities.Quantity} = zero.(x)

@inline function DiffEqBase.calculate_residuals(ũ, u₀, u₁, α, ρ, internalnorm, t)
    @. DiffEqBase.calculate_residuals(ũ, u₀, u₁, α, ρ, internalnorm, t)
end

f(u, p, t) = u / t;
problem = ODEProblem(f, [1.0u"km/s"], (0.0u"s", 1.0u"s"));
sol = solve(problem, Tsit5(), dt = 0.1u"s")

with just one internal modification. Two interface breaks are weird though:

First one:

julia> typeof(one(0.0u"s"))
Quantity{Float64, Dimensions{DynamicQuantities.FixedRational{Int32, 25200}}}

that should just be Float64?

Second, there's something odd in broadcasting that I haven't isolated yet.
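
For reference on the first point, this is the `one` / `oneunit` convention as it appears for a dimensionful type in Julia's own standard library. It is a sketch using `Dates.Day` only, and it says nothing about what DynamicQuantities does or should do.

```julia
using Dates  # standard library; Day is a dimensionful type shipped with Julia

d = Day(3)

m = one(d)       # multiplicative identity: a plain integer, no unit attached
w = oneunit(d)   # "one of these": keeps the Day type

# one(x) * x must give back x, which is why one(x) drops the unit.
same = m * d == d
```

Here `one` returns a unitless number and `oneunit` returns a `Day`, which is the behavior the "should just be Float64?" question above is expecting for a `Float64`-valued quantity.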

@MilesCranmer
Author

Thanks, nice work!

Regarding one, see the discussion in SymbolicML/DynamicQuantities.jl#40. That discussion resulted in the package BaseType.jl, which exists specifically for getting the base numeric type. But maybe an interim solution is to allow a Float64 return value; I'm not sure.

Also, one alternative to this sort of modification is some of the ideas in SymbolicML/DynamicQuantities.jl#76.

@MilesCranmer
Author

Here is the PR to implement these changes: SymbolicML/DynamicQuantities.jl#74

So I think the missing part is switching to oneunit(::T) and one(::T) in OrdinaryDiffEq.jl?

julia> sol = solve(problem, Tsit5(), dt = 0.1u"s")
ERROR: Cannot create a dimensionful 1 for a `UnionAbstractQuantity` type without knowing the dimensions. Please use `oneunit(::UnionAbstractQuantity)` instead.
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] oneunit(::Type{Quantity{Float64, Dimensions{DynamicQuantities.FixedRational{Int32, 25200}}}})
   @ DynamicQuantities ~/Documents/DynamicQuantities.jl/src/utils.jl:191
 [3] __init(prob::ODEProblem{…}, alg::Tsit5{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Quantity{…}, dtmin::Nothing, dtmax::Quantity{…}, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Nothing, reltol::Nothing, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::@Kwargs{})
   @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/qxpST/src/solve.jl:220
 [4] __solve(::ODEProblem{…}, ::Tsit5{…}; kwargs::@Kwargs{})
   @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/qxpST/src/solve.jl:5
 [5] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{})
   @ DiffEqBase ~/.julia/packages/DiffEqBase/NYLhl/src/solve.jl:557
 [6] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::SciMLBase.NullParameters, args::Tsit5{…}; kwargs::@Kwargs{})
   @ DiffEqBase ~/.julia/packages/DiffEqBase/NYLhl/src/solve.jl:1006
 [7] solve(prob::ODEProblem{…}, args::Tsit5{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{})
   @ DiffEqBase ~/.julia/packages/DiffEqBase/NYLhl/src/solve.jl:929
 [8] top-level scope
   @ REPL[21]:1
Some type information was truncated. Use `show(err)` to see complete types.
