Skip to content

Commit 6099bd3

Browse files
committed
fix: dispatches for GB
1 parent 3155cba commit 6099bd3

File tree

3 files changed

+27
-8
lines changed

3 files changed

+27
-8
lines changed

src/ConcreteRArray.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -393,14 +393,6 @@ buffer_on_cpu(::Any) = true
393393
buffer_on_cpu(x::ConcretePJRTArray) = all(XLA.buffer_on_cpu, x.data)
394394
buffer_on_cpu(x::ConcreteIFRTArray) = XLA.buffer_on_cpu(x.data)
395395

396-
function Ops.constant(x::AbstractConcreteArray; kwargs...)
397-
return Ops.constant(Base.convert(Array, x); kwargs...)
398-
end
399-
400-
function Ops.constant(x::AbstractConcreteNumber{T}; kwargs...) where {T}
401-
return Ops.constant(Base.convert(T, x); kwargs...)
402-
end
403-
404396
function Base.zero(x::ConcretePJRTArray{T,N}) where {T,N}
405397
return ConcretePJRTArray(
406398
zeros(T, size(x)...); client=XLA.client(x), device=XLA.device(x), x.sharding
@@ -464,3 +456,9 @@ function Base.mapreducedim!(
464456
fn(f, op, R, A)
465457
return R
466458
end
459+
460+
function Base.map!(f, R::Union{AnyConcreteIFRTArray,AnyConcretePJRTArray}, A::AbstractArray)
461+
fn = compile(Base.map!, (f, R, A))
462+
fn(f, R, A)
463+
return R
464+
end

src/Ops.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,16 @@ end
122122
end
123123
end
124124

125+
@noinline function constant(
126+
x::AbstractArray{T,N}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__)
127+
) where {T,N}
128+
return constant(collect(x); location)
129+
end
130+
131+
@noinline function constant(x::Reactant.AbstractConcreteArray; kwargs...)
132+
return constant(Base.convert(Array, x); kwargs...)
133+
end
134+
125135
@noinline function constant(
126136
x::T; location=mlir_stacktrace("constant", @__FILE__, @__LINE__)
127137
) where {T<:Number}
@@ -130,6 +140,10 @@ end
130140
return TracedRNumber{T}((), res.mlir_data)
131141
end
132142

143+
@noinline function constant(x::Reactant.AbstractConcreteNumber{T}; kwargs...) where {T}
144+
return constant(Base.convert(T, x); kwargs...)
145+
end
146+
133147
function fill(
134148
v, dims::Base.DimOrInd...; location=mlir_stacktrace("fill", @__FILE__, @__LINE__)
135149
)

src/TracedRArray.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,4 +1084,11 @@ function Base.findmax(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothin
10841084
return (values, linear_indices)
10851085
end
10861086

1087+
Base.map(f, x::AnyTracedRArray) = f.(x)
1088+
1089+
function Base.map!(f, y::AnyTracedRArray, x::AbstractArray)
1090+
y .= f.(x)
1091+
return y
1092+
end
1093+
10871094
end

0 commit comments

Comments
 (0)