diff --git a/ext/FastPowerEnzymeExt.jl b/ext/FastPowerEnzymeExt.jl index c029690..8db09e5 100644 --- a/ext/FastPowerEnzymeExt.jl +++ b/ext/FastPowerEnzymeExt.jl @@ -13,25 +13,27 @@ function Enzyme.EnzymeRules.forward(config::FwdConfig, y = _y.val ret = func.val(x, y) T = typeof(ret) - if !(_x isa Const) - dxval = _x.dval * y * (fastpower(x,y - 1)) - else + if !(_x isa Const) + dxval = _x.dval * y * (fastpower(x, y - 1)) + else dxval = make_zero(_x.val) end - if !(_y isa Const) - dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : _y.dval*(fastpower(x,y))*log(x) - else + if !(_y isa Const) + dyval = x isa Real && x<=0 ? Base.oftype(float(x), NaN) : + _y.dval*(fastpower(x, y))*log(x) + else dyval = make_zero(_y.val) - end + end if RT <: DuplicatedNoNeed - return convert(T,dxval + dyval) + return convert(T, dxval + dyval) else return Duplicated(ret, convert(T, dxval + dyval)) end end -function EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.RevConfigWidth{1}, - func::Const{typeof(fastpower)}, ::Union{Type{<:Active}, Type{<:Const}}, x::Union{Const,Active}, y::Union{Const,Active}) +function EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(fastpower)}, ::Union{Type{<:Active}, Type{<:Const}}, + x::Union{Const, Active}, y::Union{Const, Active}) if EnzymeRules.needs_primal(config) primal = func.val(x.val, y.val) else @@ -40,12 +42,15 @@ function EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.RevConfigWidth{ return EnzymeRules.AugmentedReturn(primal, nothing, nothing) end -function EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1}, - func::Const{typeof(fastpower)}, dret, tape, _x::Union{Const,Active}, _y::Union{Const,Active}) +function EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(fastpower)}, dret, tape, _x::Union{Const, Active}, _y::Union{ + Const, Active}) x = _x.val y = _y.val - dxval = _x isa Const ? nothing : dret.val * y * (fastpower(x,y - 1)) - dyval = _y isa Const ? nothing : (x isa Real && x<=0 ? Base.oftype(float(x), NaN) : dret.val * (fastpower(x,y))*log(x)) + dxval = _x isa Const ? nothing : dret.val * y * (fastpower(x, y - 1)) + dyval = _y isa Const ? nothing : + (x isa Real && x<=0 ? Base.oftype(float(x), NaN) : + dret.val * (fastpower(x, y)) * log(x)) return (dxval, dyval) end diff --git a/ext/FastPowerMooncakeExt.jl b/ext/FastPowerMooncakeExt.jl index e050d73..d364faf 100644 --- a/ext/FastPowerMooncakeExt.jl +++ b/ext/FastPowerMooncakeExt.jl @@ -1,6 +1,6 @@ module FastPowerMooncakeExt using FastPower, Mooncake -Mooncake.@mooncake_overlay FastPower.fastpower(x,y) = x^y +Mooncake.@mooncake_overlay FastPower.fastpower(x, y) = x^y -end \ No newline at end of file +end diff --git a/src/FastPower.jl b/src/FastPower.jl index 5c93ebb..93bc99c 100644 --- a/src/FastPower.jl +++ b/src/FastPower.jl @@ -11,7 +11,7 @@ module FastPower c = 0.523692f0 # IEEE is sgn(1):exp(8):frac(23) representing # (1+frac)*2^(exp-127). 1+frac is called the significand - + # get exponent ux1i = reinterpret(UInt32, x) exp = Int32((ux1i & 0x7F800000) >> 23) @@ -24,7 +24,7 @@ module FastPower ux2i = (ux1i & 0x007FFFFF) | 0x3f000000 exp -= 0x7e # 126 instead of 127 compensates for division by 2 end - signif = reinterpret(Float32, ux2i) + signif = reinterpret(Float32, ux2i) quot = muladd(signif, a, b) / (signif + c) return muladd(signif - 1.0f0, quot, exp) end diff --git a/test/runtests.jl b/test/runtests.jl index 0526a2c..b391085 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -27,10 +27,9 @@ end @testset for RT in (Duplicated, DuplicatedNoNeed), Tx in (Const, Duplicated), Ty in (Const, Duplicated) - x = 1.0 y = 0.5 - test_forward(fastpower, RT, (x, Tx), (y, Ty), atol = 1e-4, rtol=1e-3) + test_forward(fastpower, RT, (x, Tx), (y, Ty), atol = 1e-4, rtol = 1e-3) end end @@ -38,11 +37,13 @@ end @testset for RT in (Active,), Tx in (Active, Const), Ty in (Active, Const) x = 1.0 y = 0.5 - test_reverse(fastpower, RT, (x, Tx), (y, Ty), atol = 1e-4, rtol=1e-3) + test_reverse(fastpower, RT, (x, Tx), (y, Ty), atol = 1e-4, rtol = 1e-3) end end -mooncake_derivative(f,x) = Mooncake.value_and_gradient!!(Mooncake.build_rrule(f, x), f, x)[2][2] +function mooncake_derivative(f, x) + Mooncake.value_and_gradient!!(Mooncake.build_rrule(f, x), f, x)[2][2] +end @testset "Fast pow - Other AD Engines" begin x = 1.5123233245141 y = 0.22352354326