diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index ee0b4b6f74..b90370aa1a 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -38,7 +38,7 @@ jobs: - name: "Run CompatHelper" run: | import CompatHelper - CompatHelper.main() + CompatHelper.main(; subdirs=[".", "test", "lib/ReactantCore"]) shell: julia --color=yes {0} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/Project.toml b/Project.toml index f0762b756a..db90c4658b 100644 --- a/Project.toml +++ b/Project.toml @@ -39,10 +39,10 @@ ArrayInterface = "7.10" CEnum = "0.4, 0.5" Downloads = "1.6" Enzyme = "0.13.21" -EnzymeCore = "0.8.6, 0.8.7, 0.8.8" +EnzymeCore = "0.8.8" GPUArraysCore = "0.1.6, 0.2" LinearAlgebra = "1.10" -NNlib = "0.9.24" +NNlib = "0.9.26" OrderedCollections = "1" Preferences = "1.4" ReactantCore = "0.1.2" diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 46efa14a3f..42c454b31a 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -215,7 +215,7 @@ function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N} end XLA.await(a.data) - if XLA.BufferOnCPU(a.data.buffer) + if buffer_on_cpu(a) buf = a.data.buffer GC.@preserve buf begin ptr = Base.unsafe_convert(Ptr{T}, XLA.UnsafeBufferPointer(buf)) @@ -246,7 +246,7 @@ function Base.setindex!(a::ConcreteRArray{T}, v, args::Vararg{Int,N}) where {T,N end XLA.await(a.data) - if XLA.BufferOnCPU(a.data.buffer) + if buffer_on_cpu(a) buf = a.data.buffer GC.@preserve buf begin ptr = Base.unsafe_convert(Ptr{T}, XLA.UnsafeBufferPointer(buf)) @@ -289,15 +289,52 @@ end # TODO replace this copy for `setindex!` maybe? how to copy data to already existing buffer? (i.e. `copyto!`) function Base.copy(bc::Base.Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteRArray}}) - ElType = Base.Broadcast.combine_eltypes(bc.f, bc.args) - if !Base.isconcretetype(ElType) - throw( - ErrorException( - "`copy` on `ConcreteRArray` for non-concrete eltype is not implemented" - ), - ) + for x in bc.args + x isa ConcreteRArray && XLA.await(x.data) end - aux = copyto!(similar(Array{ElType}, axes(bc)), bc) - return ConcreteRArray(aux) + all_on_cpu = all(buffer_on_cpu, bc.args) + if all_on_cpu + ElType = Base.Broadcast.combine_eltypes(bc.f, bc.args) + if !Base.isconcretetype(ElType) + throw( + ErrorException( + "`copy` on `ConcreteRArray` for non-concrete eltype is not implemented" + ), + ) + end + aux = copyto!(similar(Array{ElType}, axes(bc)), bc) + return ConcreteRArray(aux) + end + + fn = Reactant.compile(Broadcast.BroadcastFunction(bc.f), (bc.args...,)) + return fn(bc.args...) +end + +function Base.copyto!(dest::ConcreteRArray, src::ConcreteRArray) + dest.data = src.data + return dest +end + +function Base.mapreduce( + @nospecialize(f), + @nospecialize(op), + @nospecialize(A::ConcreteRArray{T,N}); + dims=:, + init=nothing, +) where {T,N} + fn = Reactant.compile(CallMapReduce(f, op, dims, init), (A,)) + return fn(A) +end + +struct CallMapReduce{Fn,Op,Dims,Init} + f::Fn + op::Op + dims::Dims + init::Init end + +(f::CallMapReduce)(A) = Base.mapreduce(f.f, f.op, A; f.dims, f.init) + +buffer_on_cpu(::Any) = true +buffer_on_cpu(x::ConcreteRArray) = XLA.BufferOnCPU(x.data.buffer) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 12c5123b75..d3a6f3b990 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -240,6 +240,8 @@ for (jlop, hloop) in ( (:(Base.FastMath.exp_fast), :exponential), (:(Base.log), :log), (:(Base.sqrt), :sqrt), + (:(Base.ceil), :ceil), + (:(Base.floor), :floor), ) @eval function $(jlop)(@nospecialize(lhs::TracedRNumber{T})) where {T} OutTy = $(hloop === :abs) ? real(T) : T diff --git a/test/Project.toml b/test/Project.toml index c9526e3624..4b50a487fc 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -19,3 +19,24 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +ArrayInterface = "7.10" +BenchmarkTools = "1.5" +Enzyme = "0.13.21" +FFTW = "1.8" +Flux = "0.15" +Functors = "0.5" +InteractiveUtils = "1.10" +LinearAlgebra = "1.10" +Lux = "1.4.1" +LuxLib = "1.3" +MLUtils = "0.4.4" +NNlib = "0.9.26" +OneHotArrays = "0.2.6" +Optimisers = "0.4" +Random = "1.10" +SafeTestsets = "0.1" +SpecialFunctions = "2.4" +Statistics = "1.10" +Test = "1.10" diff --git a/test/ops.jl b/test/ops.jl index eb4981b80e..0437b2723f 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -82,14 +82,18 @@ end end @testset "cholesky" begin - g(x) = Ops.cholesky(x; lower=true) + # cholesky in stablehlo for the other triangle is implementation defined. + # See https://github.com/EnzymeAD/Reactant.jl/issues/338 for more details. + g1(x) = triu(Ops.cholesky(x)) + g2(x) = tril(Ops.cholesky(x; lower=true)) + x = ConcreteRArray([ 10.0 2.0 3.0 2.0 5.0 6.0 3.0 6.0 9.0 ]) - @test cholesky(Array(x)).U ≈ @jit Ops.cholesky(x) - @test transpose(cholesky(Array(x)).U) ≈ @jit g(x) + @test cholesky(Array(x)).U ≈ @jit g1(x) + @test transpose(cholesky(Array(x)).U) ≈ @jit g2(x) x = ConcreteRArray( [ @@ -98,8 +102,9 @@ end 3.0+4.0im 3.0+2.0im 9.0+0.0im ], ) - @test cholesky(Array(x)).U ≈ @jit Ops.cholesky(x) - @test adjoint(cholesky(Array(x)).U) ≈ @jit g(x) + + @test cholesky(Array(x)).U ≈ @jit g1(x) + @test adjoint(cholesky(Array(x)).U) ≈ @jit g2(x) end @testset "clamp" begin @@ -210,13 +215,14 @@ end ] # NOTE `LinearAlgebra.dot` is not equal to `sum(a .* b)` on complex numbers due to conjugation @test sum(a .* b) ≈ @jit f1(a, b) - @test kron(reshape(a, length(a), 1), reshape(b, 1, length(b))) ≈ @jit fouter(a, b) + @test kron(reshape(Array(a), length(a), 1), reshape(Array(b), 1, length(b))) ≈ + @jit fouter(a, b) @test a .* b ≈ @jit fouter_batch1(a, b) end a = ConcreteRArray([1 2; 3 4]) b = ConcreteRArray([5 6; -7 -8]) - @test a' * b == @jit f1(a, b) + @test Array(a)' * Array(b) == @jit f1(a, b) end @testset "einsum" begin @@ -239,7 +245,7 @@ end x = reshape(a, (2, 2)) y = reshape(b, (2, 2)) @test x .* y ≈ @jit f3(x, y) - @test x * y ≈ @jit f4(x, y) + @test Array(x) * Array(y) ≈ @jit f4(x, y) end end