diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 04f9ad4d4..1492b405e 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -10,14 +10,13 @@ export @non_differentiable, @scalar_rule, @thunk, @not_implemented # definition export canonicalize, extern, unthunk # differential operations export add!! # gradient accumulation operations # differentials -export Tangent, NoTangent, InplaceableThunk, One, Thunk, ZeroTangent, AbstractZero, AbstractThunk +export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk include("compat.jl") include("debug_mode.jl") include("differentials/abstract_differential.jl") include("differentials/abstract_zero.jl") -include("differentials/one.jl") include("differentials/thunks.jl") include("differentials/composite.jl") include("differentials/notimplemented.jl") diff --git a/src/deprecated.jl b/src/deprecated.jl index a9152f5a1..e69de29bb 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -1,5 +0,0 @@ -Base.@deprecate_binding AbstractDifferential AbstractTangent -Base.@deprecate_binding Composite Tangent -Base.@deprecate_binding Zero ZeroTangent -Base.@deprecate_binding DoesNotExist NoTangent -Base.@deprecate_binding NO_FIELDS NoTangent() diff --git a/src/differential_arithmetic.jl b/src/differential_arithmetic.jl index 863015574..0046707ae 100644 --- a/src/differential_arithmetic.jl +++ b/src/differential_arithmetic.jl @@ -8,7 +8,7 @@ Thus we can avoid any ambiguities. Notice: The precedence goes: - `NotImplemented, NoTangent, ZeroTangent, One, AbstractThunk, Tangent, Any` + `NotImplemented, NoTangent, ZeroTangent, AbstractThunk, Tangent, Any` Thus each of the @eval loops create most definitions of + and * defines the combination this type with all types of lower precidence. This means each eval loops is 1 item smaller than the previous. @@ -21,7 +21,7 @@ Base.:+(::ZeroTangent, x::NotImplemented) = x Base.:+(x::NotImplemented, ::NotImplemented) = x Base.:*(::NotImplemented, ::ZeroTangent) = ZeroTangent() Base.:*(::ZeroTangent, ::NotImplemented) = ZeroTangent() -for T in (:NoTangent, :One, :AbstractThunk, :Tangent, :Any) +for T in (:NoTangent, :AbstractThunk, :Tangent, :Any) @eval Base.:+(x::NotImplemented, ::$T) = x @eval Base.:+(::$T, x::NotImplemented) = x @eval Base.:*(x::NotImplemented, ::$T) = x @@ -58,7 +58,7 @@ Base.:*(x::NotImplemented, ::NotImplemented) = throw(NotImplementedException(x)) function LinearAlgebra.dot(x::NotImplemented, ::NotImplemented) return throw(NotImplementedException(x)) end -for T in (:NoTangent, :One, :AbstractThunk, :Tangent, :Any) +for T in (:NoTangent, :AbstractThunk, :Tangent, :Any) @eval Base.:-(x::NotImplemented, ::$T) = throw(NotImplementedException(x)) @eval Base.:-(::$T, x::NotImplemented) = throw(NotImplementedException(x)) @eval Base.:*(::$T, x::NotImplemented) = throw(NotImplementedException(x)) @@ -71,7 +71,7 @@ Base.:-(::NoTangent, ::NoTangent) = NoTangent() Base.:-(::NoTangent) = NoTangent() Base.:*(::NoTangent, ::NoTangent) = NoTangent() LinearAlgebra.dot(::NoTangent, ::NoTangent) = NoTangent() -for T in (:One, :AbstractThunk, :Tangent, :Any) +for T in (:AbstractThunk, :Tangent, :Any) @eval Base.:+(::NoTangent, b::$T) = b @eval Base.:+(a::$T, ::NoTangent) = a @eval Base.:-(::NoTangent, b::$T) = -b @@ -111,7 +111,7 @@ Base.:-(::ZeroTangent, ::ZeroTangent) = ZeroTangent() Base.:-(::ZeroTangent) = ZeroTangent() Base.:*(::ZeroTangent, ::ZeroTangent) = ZeroTangent() LinearAlgebra.dot(::ZeroTangent, ::ZeroTangent) = ZeroTangent() -for T in (:One, :AbstractThunk, :Tangent, :Any) +for T in (:AbstractThunk, :Tangent, :Any) @eval Base.:+(::ZeroTangent, b::$T) = b @eval Base.:+(a::$T, ::ZeroTangent) = a @eval Base.:-(::ZeroTangent, b::$T) = -b @@ -127,33 +127,11 @@ end Base.real(::ZeroTangent) = ZeroTangent() Base.imag(::ZeroTangent) = ZeroTangent() -Base.real(::One) = One() -Base.imag(::One) = ZeroTangent() - Base.complex(::ZeroTangent) = ZeroTangent() Base.complex(::ZeroTangent, ::ZeroTangent) = ZeroTangent() Base.complex(::ZeroTangent, i::Real) = complex(oftype(i, 0), i) Base.complex(r::Real, ::ZeroTangent) = complex(r) -Base.complex(::One) = One() -Base.complex(::ZeroTangent, ::One) = im -Base.complex(::One, ::ZeroTangent) = One() - -Base.:+(a::One, b::One) = extern(a) + extern(b) -Base.:*(::One, ::One) = One() -for T in (:AbstractThunk, :Tangent, :Any) - if T != :Tangent - @eval Base.:+(a::One, b::$T) = extern(a) + b - @eval Base.:+(a::$T, b::One) = a + extern(b) - end - - @eval Base.:*(::One, b::$T) = b - @eval Base.:*(a::$T, ::One) = a -end - -LinearAlgebra.dot(::One, x::Number) = x -LinearAlgebra.dot(x::Number, ::One) = conj(x) # see definition of Frobenius inner product - Base.:+(a::AbstractThunk, b::AbstractThunk) = unthunk(a) + unthunk(b) Base.:*(a::AbstractThunk, b::AbstractThunk) = unthunk(a) * unthunk(b) for T in (:Tangent, :Any) diff --git a/src/differentials/one.jl b/src/differentials/one.jl deleted file mode 100644 index 6cbbeeeb7..000000000 --- a/src/differentials/one.jl +++ /dev/null @@ -1,16 +0,0 @@ -""" - One() -The Differential which is the multiplicative identity. -Basically, this represents `1`. -""" -struct One <: AbstractTangent - One() = (Base.depwarn("`One()` is deprecated; use `true` instead", :One); return new()) -end - -extern(x::One) = true # true is a strong 1. - -Base.Broadcast.broadcastable(::One) = Ref(One()) -Base.Broadcast.broadcasted(::Type{One}) = One() - -Base.iterate(x::One) = (x, nothing) -Base.iterate(::One, ::Any) = nothing diff --git a/test/deprecated.jl b/test/deprecated.jl index da8b17748..e69de29bb 100644 --- a/test/deprecated.jl +++ b/test/deprecated.jl @@ -1,74 +0,0 @@ -# Define some rules to test One on -dummy_identity(x) = x -@scalar_rule(dummy_identity(x), One()) - -very_nice(x, y) = x + y -@scalar_rule(very_nice(x, y), (One(), One())) - -@testset "deprecations" begin - @test ChainRulesCore.AbstractDifferential === ChainRulesCore.AbstractTangent - @test Zero === ZeroTangent - @test DoesNotExist === NoTangent - @test Composite === Tangent - @test_deprecated NO_FIELDS - @test_deprecated One() -end - -@testset "One()" begin - - o = One() - @test extern(o) === true - @test o + o == 2 - @test o + 1 == 2 - @test 1 + o == 2 - @test o * o == o - @test o * 17 == 17 - @test 6 * o == 6 - @test dot(2 + im, o) == 2 - im - @test dot(o, 2 + im) == 2 + im - for x in o - @test x === o - end - @test broadcastable(o) isa Ref{One} - @test conj(o) == o - - @test reim(o) === (One(), ZeroTangent()) - @test real(o) === One() - @test imag(o) === ZeroTangent() - - @test complex(o) === o - @test complex(o, ZeroTangent()) === o - @test complex(ZeroTangent(), o) === im - - @test frule((nothing, nothing, 5.0), Core._apply, dummy_identity, 4.0) == (4.0, 5.0) - - @testset "broadcasting One" begin - sx = @SVector [1, 2] - sy = @SVector [3, 4] - - # Test that @scalar_rule and `One()` play nice together, w.r.t broadcasting - @inferred frule((ZeroTangent(), sx, sy), very_nice, 1, 2) - end - - @testset "interaction with other types" begin - c = Tangent{Foo}(y=1.5, x=2.5) - @test One() * c === c - @test c * One() === c - - z = ZeroTangent() - @test zero(One()) === z - @test zero(One) === z - - ni = ChainRulesCore.NotImplemented( - @__MODULE__, LineNumberNode(@__LINE__, @__FILE__), "error" - ) - @test ni + One() === ni - @test One() + ni === ni - E = ChainRulesCore.NotImplementedException - @test_throws E ni - One() - @test_throws E One() - ni - @test_throws E One() * ni - @test_throws E dot(ni, One()) - @test_throws E dot(One(), ni) - end -end