diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 2141d08368..9590b08ab9 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -202,6 +202,7 @@ for (jlop, hloop) in ( (:(Base.:-), :negate), (:(Base.sin), :sine), (:(Base.cos), :cosine), + (:(Base.tan), :tan), (:(Base.tanh), :tanh), (:(Base.FastMath.tanh_fast), :tanh), (:(Base.exp), :exponential), @@ -214,6 +215,13 @@ for (jlop, hloop) in ( @eval $(jlop)(@nospecialize(lhs::TracedRNumber)) = Ops.$(hloop)(lhs) end +for (jlop, hloop) in + ((:(Base.sinpi), :sine), (:(Base.cospi), :cosine), (:(Base.tanpi), :tan)) + @eval $(jlop)(@nospecialize(lhs::TracedRNumber{T})) where {T} = Ops.$(hloop)(T(π) * lhs) +end + +Base.sincospi(x::TracedRNumber{T}) where {T} = Ops.sine(T(π) * x), Ops.cosine(T(π) * x) + Base.conj(x::TracedRNumber) = x Base.conj(x::TracedRNumber{<:Complex}) = Ops.conj(x) diff --git a/test/basic.jl b/test/basic.jl index ca5ce7729b..1783a44c36 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -885,3 +885,28 @@ end res = @jit fn(x_ra, Array(idxs_ra)) @test res ≈ fn(Array(x_ra), Array(idxs_ra)) end + +@testset "Common Trig Functions" begin + x = rand(Float32, 4, 16) + x_ra = Reactant.to_rarray(x) + + @testset for fn in (sinpi, cospi, tanpi, sin, cos, tan) + @test @jit(fn.(x_ra)) ≈ fn.(x) + @test @jit(fn.(x_ra)) isa ConcreteRArray{Float32,2} + end + + x = 0.235f0 + x_ra = Reactant.to_rarray(x; track_numbers=(Number,)) + + @testset for fn in (sinpi, cospi, tanpi, sin, cos, tan) + @test @jit(fn.(x_ra)) ≈ fn.(x) + @test @jit(fn.(x_ra)) isa ConcreteRNumber{Float32} + end + @testset for fn in (sincospi, sincos) + res = @jit fn(x_ra) + @test res[1] ≈ fn(x)[1] + @test res[2] ≈ fn(x)[2] + @test res[1] isa ConcreteRNumber{Float32} + @test res[2] isa ConcreteRNumber{Float32} + end +end