diff --git a/src/ast.jl b/src/ast.jl index c4f8088..5683c87 100644 --- a/src/ast.jl +++ b/src/ast.jl @@ -259,9 +259,7 @@ function process_call!(flatAST::FlatAST, ex, new_var=nothing) # TODO: Use @match here! if op in (:+, :*) && length(ex.args) > 3 - return flatten!(flatAST, - :( ($op)($(ex.args[2]), ($op)($(ex.args[3:end]...) )) ) - ) + return flatten!(flatAST, :( ($op)($(ex.args[2]), ($op)($(ex.args[3:end]...) )) )) end top_args = [] diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index ba13f33..717a2ab 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -35,6 +35,7 @@ end minus_rev(a,b,c) = minus_rev(promote(a,b,c)...) +minus_rev(a::Interval, b::Interval) = (b = -a; return (a, b)) # a = -b function mul_rev(a::Interval, b::Interval, c::Interval) # a = b * c # a = a ∩ (b * c) @@ -48,6 +49,8 @@ mul_rev(a,b,c) = mul_rev(promote(a,b,c)...) Base.iseven(x::Interval) = isinteger(x) && iseven(round(Int, x.lo)) +Base.isodd(x::Interval) = isinteger(x) && isodd(round(Int, x.lo)) + function power_rev(a::Interval, b::Interval, c::Interval) # a = b^c, log(a) = c.log(b), b = a^(1/c) @@ -65,6 +68,12 @@ function power_rev(a::Interval, b::Interval, c::Interval) # a = b^c, log(a) = b = hull(b1, b2) + elseif isodd(c) + b1 = b ∩ ( (a ∩ (0..∞)) ^(inv(c) )) # positive part + b2 = b ∩ (- ( (-(a ∩ (-∞..0)))^(inv(c)) ) ) # negative part + + b = hull(b1, b2) + else b = b ∩ ( a^(inv(c) )) diff --git a/test/runtests.jl b/test/runtests.jl index 11e6e2f..4a40c4f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -158,3 +158,9 @@ end @test C3(A, x) == IntervalBox(sqrt(A / 16)) end + +@testset "power_rev for odd power" begin + x = -∞..∞ + a = -8..27 + power_rev(a, x, 3)[2] == Interval(-2.0000000000000004, 3.0000000000000004) +end