From 1af92593324d95b78a9af7665af915e10c9b8364 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 11 Mar 2025 19:20:25 -0400 Subject: [PATCH] feat: more dispatches for GB --- src/ConcreteRArray.jl | 14 ++++++-------- src/Ops.jl | 14 ++++++++++++++ src/TracedRArray.jl | 7 +++++++ 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 15ab869c6c..411542aec7 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -393,14 +393,6 @@ buffer_on_cpu(::Any) = true buffer_on_cpu(x::ConcretePJRTArray) = all(XLA.buffer_on_cpu, x.data) buffer_on_cpu(x::ConcreteIFRTArray) = XLA.buffer_on_cpu(x.data) -function Ops.constant(x::AbstractConcreteArray; kwargs...) - return Ops.constant(Base.convert(Array, x); kwargs...) -end - -function Ops.constant(x::AbstractConcreteNumber{T}; kwargs...) where {T} - return Ops.constant(Base.convert(T, x); kwargs...) -end - function Base.zero(x::ConcretePJRTArray{T,N}) where {T,N} return ConcretePJRTArray( zeros(T, size(x)...); client=XLA.client(x), device=XLA.device(x), x.sharding @@ -464,3 +456,9 @@ function Base.mapreducedim!( fn(f, op, R, A) return R end + +function Base.map!(f, R::Union{AnyConcreteIFRTArray,AnyConcretePJRTArray}, A::AbstractArray) + fn = compile(Base.map!, (f, R, A)) + fn(f, R, A) + return R +end diff --git a/src/Ops.jl b/src/Ops.jl index 8311f96fd8..df1e42f27e 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -119,6 +119,16 @@ end end end +@noinline function constant( + x::AbstractArray{T,N}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__) +) where {T,N} + return constant(collect(x); location) +end + +@noinline function constant(x::Reactant.AbstractConcreteArray; kwargs...) + return constant(Base.convert(Array, x); kwargs...) +end + @noinline function constant( x::T; location=mlir_stacktrace("constant", @__FILE__, @__LINE__) ) where {T<:Number} @@ -127,6 +137,10 @@ end return TracedRNumber{T}((), res.mlir_data) end +@noinline function constant(x::Reactant.AbstractConcreteNumber{T}; kwargs...) where {T} + return constant(Base.convert(T, x); kwargs...) +end + function fill( v, dims::Base.DimOrInd...; location=mlir_stacktrace("fill", @__FILE__, @__LINE__) ) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 066bea0680..507b44a8a7 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -1015,4 +1015,11 @@ function Base.findmax(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothin return (values, linear_indices) end +Base.map(f, x::AnyTracedRArray) = f.(x) + +function Base.map!(f, y::AnyTracedRArray, x::AbstractArray) + y .= f.(x) + return y +end + end