From 1036d45b5d3b782a3f695d050e8fc226572e5289 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 6 Mar 2022 13:23:47 +0100 Subject: [PATCH] Implement `complex` for dual arguments (#90) * Implement complex wrappers * Test complex implementations * Increment version number * Overload complex for types --- Project.toml | 2 +- src/dual.jl | 7 +++++++ test/automatic_differentiation_test.jl | 8 ++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 284a596..ce89417 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DualNumbers" uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" -version = "0.6.6" +version = "0.6.7" [deps] Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" diff --git a/src/dual.jl b/src/dual.jl index 0812821..ec00ca4 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -39,6 +39,13 @@ dual(x::ReComp, y::ReComp) = Dual(x, y) dual(x::ReComp) = Dual(x) dual(z::Dual) = z +function Base.complex(x::Dual, y::Dual) + dual(complex(value(x), value(y)), complex(epsilon(x), epsilon(y))) +end +Base.complex(x::Real, y::Dual) = complex(dual(x), y) +Base.complex(x::Dual, y::Real) = complex(x, dual(y)) +Base.complex(::Type{Dual{T}}) where {T} = Dual{complex(T)} + const realpart = value const dualpart = epsilon diff --git a/test/automatic_differentiation_test.jl b/test/automatic_differentiation_test.jl index 93ad8ba..00baa0d 100644 --- a/test/automatic_differentiation_test.jl +++ b/test/automatic_differentiation_test.jl @@ -200,6 +200,14 @@ test(x, y) = x^2 + y @test epsilon(Dual(-2.0,1.0)^Dual(2.0,0.0)) == -4 +# test complex and dual mixing +@test complex(dual(1, 2), dual(3, 4)) == dual(complex(1, 3), complex(2, 4)) +@test complex(1, dual(2, 3)) == dual(complex(1, 2), complex(0, 3)) +@test complex(dual(1, 2), 3) == dual(complex(1, 3), complex(2, 0)) +@test complex(dual(1, 2)) == dual(complex(1, 0), complex(2, 0)) +@test complex(Dual128) === DualComplex256 +@test complex(Dual64) === DualComplex128 + # test for flipsign flipsign(Dual(1.0,1.0),2.0) == Dual(1.0,1.0) flipsign(Dual(1.0,1.0),-2.0) == Dual(-1.0,-1.0)