From 5e727d69a0f7af645b3a566b033a70f0576cadc5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 12 Mar 2025 15:02:29 -0400 Subject: [PATCH 1/3] fix: dispatches for GB [skip ci] --- src/ConcreteRArray.jl | 14 ++++++-------- src/Ops.jl | 14 ++++++++++++++ src/TracedRArray.jl | 7 +++++++ 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 15ab869c6c..411542aec7 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -393,14 +393,6 @@ buffer_on_cpu(::Any) = true buffer_on_cpu(x::ConcretePJRTArray) = all(XLA.buffer_on_cpu, x.data) buffer_on_cpu(x::ConcreteIFRTArray) = XLA.buffer_on_cpu(x.data) -function Ops.constant(x::AbstractConcreteArray; kwargs...) - return Ops.constant(Base.convert(Array, x); kwargs...) -end - -function Ops.constant(x::AbstractConcreteNumber{T}; kwargs...) where {T} - return Ops.constant(Base.convert(T, x); kwargs...) -end - function Base.zero(x::ConcretePJRTArray{T,N}) where {T,N} return ConcretePJRTArray( zeros(T, size(x)...); client=XLA.client(x), device=XLA.device(x), x.sharding @@ -464,3 +456,9 @@ function Base.mapreducedim!( fn(f, op, R, A) return R end + +function Base.map!(f, R::Union{AnyConcreteIFRTArray,AnyConcretePJRTArray}, A::AbstractArray) + fn = compile(Base.map!, (f, R, A)) + fn(f, R, A) + return R +end diff --git a/src/Ops.jl b/src/Ops.jl index 39a38f9505..bade8937e2 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -122,6 +122,16 @@ end end end +@noinline function constant( + x::AbstractArray{T,N}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__) +) where {T,N} + return constant(collect(x); location) +end + +@noinline function constant(x::Reactant.AbstractConcreteArray; kwargs...) + return constant(Base.convert(Array, x); kwargs...) +end + @noinline function constant( x::T; location=mlir_stacktrace("constant", @__FILE__, @__LINE__) ) where {T<:Number} @@ -130,6 +140,10 @@ end return TracedRNumber{T}((), res.mlir_data) end +@noinline function constant(x::Reactant.AbstractConcreteNumber{T}; kwargs...) where {T} + return constant(Base.convert(T, x); kwargs...) +end + function fill( v, dims::Base.DimOrInd...; location=mlir_stacktrace("fill", @__FILE__, @__LINE__) ) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 662b9d1e9f..0c543a5c4b 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -1084,4 +1084,11 @@ function Base.findmax(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothin return (values, linear_indices) end +Base.map(f, x::AnyTracedRArray) = f.(x) + +function Base.map!(f, y::AnyTracedRArray, x::AbstractArray) + y .= f.(x) + return y +end + end From d58af5eae774c2c63beed6e50af6e9a9fbfef94d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 12 Mar 2025 21:01:46 -0500 Subject: [PATCH 2/3] fix: map compile --- src/ConcreteRArray.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 411542aec7..f2c7906d8a 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -457,8 +457,13 @@ function Base.mapreducedim!( return R end +function mymap!(f, R, A) + map!(f, R, A) + return nothing +end + function Base.map!(f, R::Union{AnyConcreteIFRTArray,AnyConcretePJRTArray}, A::AbstractArray) - fn = compile(Base.map!, (f, R, A)) + fn = compile(mymap!, (f, R, A)) fn(f, R, A) return R end From c1fae620b55b46f2f5991ff360c141680c707d58 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 12 Mar 2025 21:05:07 -0500 Subject: [PATCH 3/3] test: map! --- test/basic.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/basic.jl b/test/basic.jl index d94d16e4ad..1b6203ce85 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -938,3 +938,14 @@ end rv ) end + +@testset "map!" begin + x = randn(Float32, 2, 3) + y = zeros(Float32, 2, 3) + + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + + @test Array(@jit(map!(abs2, y_ra, x_ra))) ≈ map!(abs2, y, x) + @test Array(y_ra) ≈ y +end