Skip to content

Commit

Permalink
Merge pull request #999 from SciML/ChrisRackauckas-patch-7
Browse files Browse the repository at this point in the history
Check dualness by fields instead of properties
  • Loading branch information
ChrisRackauckas committed Feb 9, 2024
2 parents c05806f + ec80140 commit 1facd29
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 50 deletions.
92 changes: 46 additions & 46 deletions src/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,27 +52,21 @@ end
reduce_tup(op, map(f, x))
end
# For other container types, we probably just want to call `mapreduce`
@inline diffeqmapreduce(f::F, op::OP, x) where {F, OP} = mapreduce(f, op, x)
@inline diffeqmapreduce(f::F, op::OP, x) where {F, OP} = mapreduce(f, op, x, init=Any)

struct DualEltypeChecker{T}
struct DualEltypeChecker{T, T2}
x::T
counter::Int
DualEltypeChecker(x::T, counter::Int) where {T} = new{T}(x, counter + 1)
counter::T2
end

getval(::Val{I}) where I = I
getval(::Type{Val{I}}) where I = I
getval(I::Int) = I

function (dec::DualEltypeChecker)(::Val{Y}) where {Y}
isdefined(dec.x, Y) || return Any
dec.counter >= DUALCHECK_RECURSION_MAX && return Any
anyeltypedual(getproperty(dec.x, Y), dec.counter)
end

# use `getfield` on `Pairs`, see https://github.com/JuliaLang/julia/pull/39448
if VERSION >= v"1.7"
function (dec::DualEltypeChecker{<:Base.Pairs})(::Val{Y}) where {Y}
isdefined(dec.x, Y) || return Any
dec.counter >= DUALCHECK_RECURSION_MAX && return Any
anyeltypedual(getfield(dec.x, Y), dec.counter)
end
getval(dec.counter) >= DUALCHECK_RECURSION_MAX && return Any
anyeltypedual(getfield(dec.x, Y), Val{getval(dec.counter)})
end

# Untyped dispatch: catch composite types, check all of their fields
Expand All @@ -98,37 +92,43 @@ upconversion is not done automatically, the user is required to upconvert all in
themselves, for an example of how this can be confusing to a user see
https://discourse.julialang.org/t/typeerror-in-julia-turing-when-sampling-for-a-forced-differential-equation/82937
"""
function anyeltypedual(x, counter = 0)
if propertynames(x) === ()
Any
@generated function anyeltypedual(x, ::Type{Val{counter}} = Val{0}) where counter
x = x.name === Core.Compiler.typename(Type) ? x.parameters[1] : x
if x <: ForwardDiff.Dual
:($x)
elseif fieldnames(x) === ()
:(Any)
elseif counter < DUALCHECK_RECURSION_MAX
diffeqmapreduce(DualEltypeChecker(x, counter), promote_dual,
map(Val, propertynames(x)))
T = diffeqmapreduce(x->anyeltypedual(x, Val{counter+1}), promote_dual,
x.parameters)
if T === Any || isconcretetype(T)
:($T)
else
:(diffeqmapreduce(DualEltypeChecker($x, $counter+1), promote_dual,
map(Val, fieldnames($(typeof(x))))))
end
else
Any
:(Any)
end
end

# Opt out since these are using for preallocation, not differentiation
anyeltypedual(x::Union{ForwardDiff.AbstractConfig, Module}, counter = 0) = Any
anyeltypedual(x::Type{T}, counter = 0) where {T <: ForwardDiff.AbstractConfig} = Any
anyeltypedual(x::SciMLBase.RecipesBase.AbstractPlot, counter = 0) = Any

if VERSION >= v"1.7"
anyeltypedual(x::Returns, counter = 0) = anyeltypedual(x.value, counter)
end
anyeltypedual(x::Union{ForwardDiff.AbstractConfig, Module}, ::Type{Val{counter}} = Val{0}) where {counter} = Any
anyeltypedual(x::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {T <: ForwardDiff.AbstractConfig} = Any
anyeltypedual(x::SciMLBase.RecipesBase.AbstractPlot, ::Type{Val{counter}} = Val{0}) where {counter} = Any
anyeltypedual(x::Returns, ::Type{Val{counter}} = Val{0}) where {counter} = anyeltypedual(x.value, Val{counter})

if isdefined(PreallocationTools, :FixedSizeDiffCache)
anyeltypedual(x::PreallocationTools.FixedSizeDiffCache, counter = 0) = Any
anyeltypedual(x::PreallocationTools.FixedSizeDiffCache, ::Type{Val{counter}} = Val{0}) where {counter} = Any
end

Base.@pure function __anyeltypedual(::Type{T}) where {T}
hasproperty(T, :parameters) ?
mapreduce(anyeltypedual, promote_dual, T.parameters; init = Any) : T
end
anyeltypedual(::Type{T}, counter = 0) where {T} = __anyeltypedual(T)
anyeltypedual(::Type{T}, counter = 0) where {T <: ForwardDiff.Dual} = T
function anyeltypedual(::Type{T}, counter = 0) where {T <: Union{AbstractArray, Set}}
anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {T} = __anyeltypedual(T)
anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {T <: ForwardDiff.Dual} = T
function anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {T <: Union{AbstractArray, Set}}
anyeltypedual(eltype(T))
end
Base.@pure function __anyeltypedual_ntuple(::Type{T}) where {T <: NTuple}
Expand All @@ -141,22 +141,22 @@ Base.@pure function __anyeltypedual_ntuple(::Type{T}) where {T <: NTuple}
mapreduce(anyeltypedual, promote_dual, T.parameters; init = Any)
end
end
anyeltypedual(::Type{T}, counter = 0) where {T <: NTuple} = __anyeltypedual_ntuple(T)
anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {T <: NTuple} = __anyeltypedual_ntuple(T)

# Any in this context just means not Dual
anyeltypedual(x::SciMLBase.NullParameters, counter = 0) = Any
anyeltypedual(x::Number, counter = 0) = anyeltypedual(typeof(x))
anyeltypedual(x::Union{String, Symbol}, counter = 0) = typeof(x)
anyeltypedual(x::SciMLBase.NullParameters, ::Type{Val{counter}} = Val{0}) where {counter} = Any
anyeltypedual(x::Number, ::Type{Val{counter}} = Val{0}) where {counter} = anyeltypedual(typeof(x))
anyeltypedual(x::Union{String, Symbol}, ::Type{Val{counter}} = Val{0}) where {counter} = typeof(x)
function anyeltypedual(x::Union{Array{T}, AbstractArray{T}, Set{T}},
counter = 0) where {
::Type{Val{counter}} = Val{0}) where {counter} where {
T <:
Union{Number,
Symbol,
String}}
anyeltypedual(T)
end
function anyeltypedual(x::Union{Array{T}, AbstractArray{T}, Set{T}},
counter = 0) where {
::Type{Val{counter}} = Val{0}) where {counter} where {
T <: Union{
AbstractArray{
<:Number,
Expand All @@ -167,26 +167,26 @@ function anyeltypedual(x::Union{Array{T}, AbstractArray{T}, Set{T}},
anyeltypedual(eltype(x))
end
function anyeltypedual(x::Union{Array{T}, AbstractArray{T}, Set{T}},
counter = 0) where {N, T <: NTuple{N, <:Number}}
::Type{Val{counter}} = Val{0}) where {counter} where {N, T <: NTuple{N, <:Number}}
anyeltypedual(eltype(x))
end

# Try to avoid this dispatch because it can lead to type inference issues when !isconcrete(eltype(x))
function anyeltypedual(x::AbstractArray, counter = 0)
function anyeltypedual(x::AbstractArray, ::Type{Val{counter}} = Val{0}) where {counter}
if isconcretetype(eltype(x))
anyeltypedual(eltype(x))
elseif !isempty(x) && all(i -> isassigned(x, i), 1:length(x)) &&
counter < DUALCHECK_RECURSION_MAX
counter += 1
mapreduce(y -> anyeltypedual(y, counter), promote_dual, x)
_counter = Val{counter+1}
mapreduce(y -> anyeltypedual(y, _counter), promote_dual, x)
else
# This fallback to Any is required since otherwise we cannot handle `undef` in all cases
# misses cases of
Any
end
end

function anyeltypedual(x::Set, counter = 0)
function anyeltypedual(x::Set, ::Type{Val{counter}} = Val{0}) where {counter}
if isconcretetype(eltype(x))
anyeltypedual(eltype(x))
else
Expand All @@ -195,18 +195,18 @@ function anyeltypedual(x::Set, counter = 0)
end
end

function anyeltypedual(x::Tuple, counter = 0)
function anyeltypedual(x::Tuple, ::Type{Val{counter}} = Val{0}) where {counter}
# Handle the empty tuple case separately for inference and to avoid mapreduce error
if x === ()
Any
else
diffeqmapreduce(anyeltypedual, promote_dual, x)
end
end
function anyeltypedual(x::Dict, counter = 0)
function anyeltypedual(x::Dict, ::Type{Val{counter}} = Val{0}) where {counter}
isempty(x) ? eltype(values(x)) : mapreduce(anyeltypedual, promote_dual, values(x))
end
function anyeltypedual(x::NamedTuple, counter = 0)
function anyeltypedual(x::NamedTuple, ::Type{Val{counter}} = Val{0}) where {counter}
isempty(x) ? Any : diffeqmapreduce(anyeltypedual, promote_dual, values(x))
end
@inline function promote_u0(u0, p, t0)
Expand Down
37 changes: 33 additions & 4 deletions test/forwarddiff_dual_detection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ p_possibilities17 = [
(Mod, ForwardDiff.Dual(2.0)), (() -> 2.0, ForwardDiff.Dual(2.0)),
(Base.pointer([2.0]), ForwardDiff.Dual(2.0)),
]
VERSION >= v"1.7" &&
push!(p_possibilities17, Returns((a = 2, b = 1.3, c = ForwardDiff.Dual(2.0f0))))
push!(p_possibilities17, Returns((a = 2, b = 1.3, c = ForwardDiff.Dual(2.0f0))))

for p in p_possibilities17
@show p
Expand Down Expand Up @@ -262,5 +261,35 @@ for p in p_possibilities_configs_not_inferred
end

# use `getfield` on `Pairs`, see https://github.com/JuliaLang/julia/pull/39448
VERSION >= v"1.7" &&
@test_nowarn DiffEqBase.DualEltypeChecker(pairs((;)), 0)(Val(:data))
@test_nowarn DiffEqBase.DualEltypeChecker(pairs((;)), 0)(Val(:data))

# https://discourse.julialang.org/t/type-instability-with-differentialequations-jl-when-using-nested-structs/109764/5
struct Fit
m₁::Float64
c₁::Float64
m₂::Float64
c₂::Float64

function Fit()
m₁ = 1.595
c₁ = 3.438
m₂ = 1.075
c₂ = 3.484

new(m₁, c₁, m₂, c₂)
end
end

struct EOS
fit::Fit

function EOS()
fit = Fit()

new(fit)
end
end

p = EOS()
@test !(DiffEqBase.anyeltypedual(p) <: ForwardDiff.Dual)
@inferred DiffEqBase.anyeltypedual(p)

0 comments on commit 1facd29

Please sign in to comment.