From b0a442bef5812d40462fd0febe7335de13c606ec Mon Sep 17 00:00:00 2001 From: Simone Surace Date: Thu, 7 Jul 2022 17:15:24 +0200 Subject: [PATCH 1/7] Add `iszero(x)` branch to `xlogy` and `xlog1py` --- src/rules.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 45104aa..f8fe89c 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -248,11 +248,15 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule LogExpFunctions.logmxp1(x) = :((1 - $x) / $x) # binary -@define_diffrule LogExpFunctions.xlogy(x, y) = :(log($y)), :($x / $y) +@define_diffrule LogExpFunctions.xlogy(x, y) = + :(log($y)), + :(z = $x / $y; iszero($x) && !isnan($y) ? zero(z) : z) @define_diffrule LogExpFunctions.logaddexp(x, y) = :(exp($x - LogExpFunctions.logaddexp($x, $y))), :(exp($y - LogExpFunctions.logaddexp($x, $y))) @define_diffrule LogExpFunctions.logsubexp(x, y) = :(z = LogExpFunctions.logsubexp($x, $y); $x > $y ? exp($x - z) : -exp($x - z)), :(z = LogExpFunctions.logsubexp($x, $y); $x > $y ? -exp($y - z) : exp($y - z)) -@define_diffrule LogExpFunctions.xlog1py(x, y) = :(log1p($y)), :($x / (1 + $y)) +@define_diffrule LogExpFunctions.xlog1py(x, y) = + :(log1p($y)), + :(z = $x / (1 + $y); iszero($x) && !isnan($y) ? zero(z) : z) From fd1902da9f7e1488da73eb6f690515ffb8d55d30 Mon Sep 17 00:00:00 2001 From: Simone Surace Date: Fri, 8 Jul 2022 00:42:57 +0200 Subject: [PATCH 2/7] Treat `y` singularity --- src/rules.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index f8fe89c..38ab883 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -250,7 +250,7 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) # binary @define_diffrule LogExpFunctions.xlogy(x, y) = :(log($y)), - :(z = $x / $y; iszero($x) && !isnan($y) ? zero(z) : z) + :(z = $x / $y; iszero($x) && !isnan($y) ? zero(z) : iszero($y) ? oftype(z, NaN) : z) @define_diffrule LogExpFunctions.logaddexp(x, y) = :(exp($x - LogExpFunctions.logaddexp($x, $y))), :(exp($y - LogExpFunctions.logaddexp($x, $y))) @define_diffrule LogExpFunctions.logsubexp(x, y) = @@ -259,4 +259,4 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule LogExpFunctions.xlog1py(x, y) = :(log1p($y)), - :(z = $x / (1 + $y); iszero($x) && !isnan($y) ? zero(z) : z) + :(yp1 = 1 + $y; z = $x / yp1; iszero($x) && !isnan($y) ? zero(z) : iszero(yp1) ? oftype(z, NaN) : z) From d3ae34099e79efa3b061296122d17d0c27abf167 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Fri, 8 Jul 2022 08:28:25 +0200 Subject: [PATCH 3/7] Update src/rules.jl Co-authored-by: David Widmann --- src/rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index 38ab883..f83cb19 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -250,7 +250,7 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) # binary @define_diffrule LogExpFunctions.xlogy(x, y) = :(log($y)), - :(z = $x / $y; iszero($x) && !isnan($y) ? zero(z) : iszero($y) ? oftype(z, NaN) : z) + :(z = $x / $y; iszero($x) && !isnan($y) ? zero(z) : z) @define_diffrule LogExpFunctions.logaddexp(x, y) = :(exp($x - LogExpFunctions.logaddexp($x, $y))), :(exp($y - LogExpFunctions.logaddexp($x, $y))) @define_diffrule LogExpFunctions.logsubexp(x, y) = From 09a4e0e342cf38ae74f733075358d1eb57ec463b Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace <51025924+simsurace@users.noreply.github.com> Date: Fri, 8 Jul 2022 08:28:40 +0200 Subject: [PATCH 4/7] Update src/rules.jl Co-authored-by: David Widmann --- src/rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index f83cb19..f8fe89c 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -259,4 +259,4 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule LogExpFunctions.xlog1py(x, y) = :(log1p($y)), - :(yp1 = 1 + $y; z = $x / yp1; iszero($x) && !isnan($y) ? zero(z) : iszero(yp1) ? oftype(z, NaN) : z) + :(z = $x / (1 + $y); iszero($x) && !isnan($y) ? zero(z) : z) From 88422d0ff4550b322b6a99424b1b6c531a3cb758 Mon Sep 17 00:00:00 2001 From: Simone Surace Date: Fri, 19 Aug 2022 11:21:48 +0200 Subject: [PATCH 5/7] Remove whitespace --- src/rules.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index f8fe89c..6e4a1fa 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -256,7 +256,6 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule LogExpFunctions.logsubexp(x, y) = :(z = LogExpFunctions.logsubexp($x, $y); $x > $y ? exp($x - z) : -exp($x - z)), :(z = LogExpFunctions.logsubexp($x, $y); $x > $y ? -exp($y - z) : exp($y - z)) - @define_diffrule LogExpFunctions.xlog1py(x, y) = :(log1p($y)), :(z = $x / (1 + $y); iszero($x) && !isnan($y) ? zero(z) : z) From f557e82a7aa4b2bf2bbf5cd9610636b7b44e4dbf Mon Sep 17 00:00:00 2001 From: Simone Surace Date: Fri, 19 Aug 2022 11:22:24 +0200 Subject: [PATCH 6/7] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 3e04e2f..83f1256 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DiffRules" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.11.0" +version = "1.11.1" [deps] IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" From d324416f65a7c2af1ecbe4b365b1ce3fde393ff2 Mon Sep 17 00:00:00 2001 From: Simone Surace Date: Wed, 24 Aug 2022 11:19:59 +0200 Subject: [PATCH 7/7] Add tests --- test/runtests.jl | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index fa75616..3312f7f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -141,6 +141,40 @@ for xtype in [:Float64, :BigFloat] end end +# Test `iszero(x)` branch of `xlogy` +derivs = DiffRules.diffrule(:LogExpFunctions, :xlogy, :x, :y) +for xytype in [:Float32, :Float64, :BigFloat] + @eval begin + let + x = zero($xytype) + y = rand($xytype) + dx, dy = $(derivs[1]), $(derivs[2]) + @test iszero(dy) + + y = one($xytype) + dx, dy = $(derivs[1]), $(derivs[2]) + @test iszero(dy) + end + end +end + +# Test `iszero(x)` branch of `xlog1py` +derivs = DiffRules.diffrule(:LogExpFunctions, :xlog1py, :x, :y) +for xytype in [:Float32, :Float64, :BigFloat] + @eval begin + let + x = zero($xytype) + y = rand($xytype) + dx, dy = $(derivs[1]), $(derivs[2]) + @test iszero(dy) + + y = -one($xytype) + dx, dy = $(derivs[1]), $(derivs[2]) + @test iszero(dy) + end + end +end + end @testset "diffrules" begin