diff --git a/Project.toml b/Project.toml index 3e04e2f..83f1256 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DiffRules" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.11.0" +version = "1.11.1" [deps] IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" diff --git a/src/rules.jl b/src/rules.jl index 45104aa..6e4a1fa 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -248,11 +248,14 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x) @define_diffrule LogExpFunctions.logmxp1(x) = :((1 - $x) / $x) # binary -@define_diffrule LogExpFunctions.xlogy(x, y) = :(log($y)), :($x / $y) +@define_diffrule LogExpFunctions.xlogy(x, y) = + :(log($y)), + :(z = $x / $y; iszero($x) && !isnan($y) ? zero(z) : z) @define_diffrule LogExpFunctions.logaddexp(x, y) = :(exp($x - LogExpFunctions.logaddexp($x, $y))), :(exp($y - LogExpFunctions.logaddexp($x, $y))) @define_diffrule LogExpFunctions.logsubexp(x, y) = :(z = LogExpFunctions.logsubexp($x, $y); $x > $y ? exp($x - z) : -exp($x - z)), :(z = LogExpFunctions.logsubexp($x, $y); $x > $y ? -exp($y - z) : exp($y - z)) - -@define_diffrule LogExpFunctions.xlog1py(x, y) = :(log1p($y)), :($x / (1 + $y)) +@define_diffrule LogExpFunctions.xlog1py(x, y) = + :(log1p($y)), + :(z = $x / (1 + $y); iszero($x) && !isnan($y) ? zero(z) : z) diff --git a/test/runtests.jl b/test/runtests.jl index fa75616..3312f7f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -141,6 +141,40 @@ for xtype in [:Float64, :BigFloat] end end +# Test `iszero(x)` branch of `xlogy` +derivs = DiffRules.diffrule(:LogExpFunctions, :xlogy, :x, :y) +for xytype in [:Float32, :Float64, :BigFloat] + @eval begin + let + x = zero($xytype) + y = rand($xytype) + dx, dy = $(derivs[1]), $(derivs[2]) + @test iszero(dy) + + y = one($xytype) + dx, dy = $(derivs[1]), $(derivs[2]) + @test iszero(dy) + end + end +end + +# Test `iszero(x)` branch of `xlog1py` +derivs = DiffRules.diffrule(:LogExpFunctions, :xlog1py, :x, :y) +for xytype in [:Float32, :Float64, :BigFloat] + @eval begin + let + x = zero($xytype) + y = rand($xytype) + dx, dy = $(derivs[1]), $(derivs[2]) + @test iszero(dy) + + y = -one($xytype) + dx, dy = $(derivs[1]), $(derivs[2]) + @test iszero(dy) + end + end +end + end @testset "diffrules" begin