diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index bb11cc7c8..7026c0065 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -294,7 +294,7 @@ macro non_differentiable(sig_expr) primal_name, orig_args = Iterators.peel(sig_expr.args) constrained_args = _constrain_and_name.(orig_args, :Any) - primal_sig_parts = [:(::typeof($primal_name)), constrained_args...] + primal_sig_parts = [:(::Core.Typeof($primal_name)), constrained_args...] unconstrained_args = _unconstrain.(constrained_args) diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl index 59a2429ad..4a77c237e 100644 --- a/test/demos/forwarddiffzero.jl +++ b/test/demos/forwarddiffzero.jl @@ -11,7 +11,7 @@ using Test # Define the AD # Note that we never directly define Dual Number Arithmetic on Dual numbers -# instead it is automatically defined from the `frules` +# instead it is automatically defined from the `frules` struct Dual <: Real primal::Float64 partial::Float64 @@ -30,7 +30,8 @@ Base.to_power_type(x::Dual) = x function define_dual_overload(sig) sig = Base.unwrap_unionall(sig) # Not really handling most UnionAlls opT, argTs = Iterators.peel(sig.parameters) - fieldcount(opT) == 0 || return # not handling functors + opT isa Type{<:Type} && return # not handling constructors + fieldcount(opT) == 0 || return # not handling functors all(Float64 <: argT for argT in argTs) || return # only handling purely Float64 ops. N = length(sig.parameters) - 1 # skip the op @@ -65,7 +66,7 @@ function ChainRulesCore.frule((_, Δx, Δy), ::typeof(*), x::Number, y::Number) end # Manual refresh needed as new rule added in same file as AD after the `on_new_rule` call -refresh_rules(); +refresh_rules(); @testset "ForwardDiffZero" begin foo(x) = x + x diff --git a/test/demos/reversediffzero.jl b/test/demos/reversediffzero.jl index adeebef9a..2f1a5e3e1 100644 --- a/test/demos/reversediffzero.jl +++ b/test/demos/reversediffzero.jl @@ -59,7 +59,8 @@ Base.to_power_type(x::Tracked) = x function define_tracked_overload(sig) sig = Base.unwrap_unionall(sig) # not really handling most UnionAll opT, argTs = Iterators.peel(sig.parameters) - fieldcount(opT) == 0 || return # not handling functors + opT isa Type{<:Type} && return # not handling constructors + fieldcount(opT) == 0 || return # not handling functors all(Float64 <: argT for argT in argTs) || return # only handling purely Float64 ops. N = length(sig.parameters) - 1 # skip the op @@ -116,7 +117,7 @@ function ChainRulesCore.rrule(::typeof(*), x::Number, y::Number) end # Manual refresh needed as new rule added in same file as AD after the `on_new_rule` call -refresh_rules(); +refresh_rules(); @testset "ReversedDiffZero" begin foo(x) = x + x diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index d02c6fad1..6ca7c769f 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -21,6 +21,14 @@ macro test_macro_throws(err_expr, expr) end end +# struct need to be defined outside of tests for julia 1.0 compat +struct NonDiffExample + x +end + +struct NonDiffCounterExample + x +end @testset "rule_definition_tools.jl" begin @testset "@non_differentiable" begin @@ -98,6 +106,25 @@ end end end + @testset "Constructors" begin + @non_differentiable NonDiffExample(::Any) + + @test isequal( + frule((Zero(), 1.2), NonDiffExample, 2.0), + (NonDiffExample(2.0), DoesNotExist()) + ) + + res, pullback = rrule(NonDiffExample, 2.0) + @test res == NonDiffExample(2.0) + @test pullback(1.2) == (NO_FIELDS, DoesNotExist()) + + # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/213 + # problem was that `@nondiff Foo(x)` was also defining rules for other types. + # make sure that isn't happenning + @test frule((Zero(), 1.2), NonDiffCounterExample, 2.0) === nothing + @test rrule(NonDiffCounterExample, 2.0) === nothing + end + @testset "Not supported (Yet)" begin # Varargs are not supported @test_macro_throws ErrorException @non_differentiable vararg1(xs...) @@ -115,7 +142,7 @@ end @testset "@scalar_rule with multiple output" begin simo(x) = (x, 2x) @scalar_rule(simo(x), 1f0, 2f0) - + y, simo_pb = rrule(simo, π) @test simo_pb((10f0, 20f0)) == (NO_FIELDS, 50f0)