diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index e9d9c02d7f..d42ca08b6f 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -242,7 +242,7 @@ function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N} return convert(Array, a)[args...] end -function mysetindex!(a, v, args::Vararg{Int,N}) where {N} +function mysetindex!(a, v, args::Vararg{Any,N}) where {N} setindex!(a, v, args...) return nothing end @@ -353,3 +353,28 @@ end function Ops.constant(x::ConcreteRNumber{T}; kwargs...) where {T} return Ops.constant(Base.convert(T, x); kwargs...) end + +Base.zero(x::ConcreteRArray{T,N}) where {T,N} = ConcreteRArray(zeros(T, size(x)...)) + +function Base.fill!(a::ConcreteRArray{T,N}, val) where {T,N} + if a.data == XLA.AsyncEmptyBuffer + throw("Cannot setindex! to empty buffer") + end + + XLA.await(a.data) + if buffer_on_cpu(a) + buf = a.data.buffer + GC.@preserve buf begin + ptr = Base.unsafe_convert(Ptr{T}, XLA.UnsafeBufferPointer(buf)) + for start in 1:length(a) + unsafe_store!(ptr, val, start) + end + end + return a + end + + idxs = ntuple(Returns(Colon()), N) + fn = compile(mysetindex!, (a, val, idxs...)) + fn(a, val, idxs...) + return a +end diff --git a/test/basic.jl b/test/basic.jl index 5eff286eda..31fd841b42 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -664,3 +664,15 @@ end ifelse(ConcreteRNumber(false), ConcreteRNumber(1.0f0), ConcreteRNumber(0.0f0)) ) isa ConcreteRNumber{Float32} end + +@testset "fill! and zero on ConcreteRArray" begin + x_ra = Reactant.to_rarray(rand(3, 4)) + + z = zero(x_ra) + @test z isa ConcreteRArray + @test size(z) == size(x_ra) + @test all(iszero, Array(z)) + + fill!(z, 1.0) + @test all(==(1.0), Array(z)) +end