From e0906bc6a10cfab2ddc71fd62cb8fa038cd5cd8a Mon Sep 17 00:00:00 2001 From: qfl3x Date: Tue, 16 Sep 2025 13:16:00 +0200 Subject: [PATCH 01/14] WIP FixedSizeArrays support --- Project.toml | 3 +++ ext/ReactantFixedSizeArraysExt.jl | 34 +++++++++++++++++++++++++++++++ src/Tracing.jl | 4 ++++ test/Project.toml | 2 ++ 4 files changed, 43 insertions(+) create mode 100644 ext/ReactantFixedSizeArraysExt.jl diff --git a/Project.toml b/Project.toml index 51446928e5..d03494fb52 100644 --- a/Project.toml +++ b/Project.toml @@ -47,6 +47,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df" +FixedSizeArrays = "3821ddf9-e5b5-40d5-8e25-6813ab96b5e2" [sources] ReactantCore = {path = "lib/ReactantCore"} @@ -57,6 +58,7 @@ ReactantArrayInterfaceExt = "ArrayInterface" ReactantCUDAExt = ["CUDA", "GPUCompiler", "KernelAbstractions", "LLVM"] ReactantDLFP8TypesExt = "DLFP8Types" ReactantFillArraysExt = "FillArrays" +ReactantFixedSizeArraysExt = "FixedSizeArrays" ReactantFloat8sExt = "Float8s" ReactantKernelAbstractionsExt = "KernelAbstractions" ReactantMPIExt = "MPI" @@ -82,6 +84,7 @@ EnumX = "1" Enzyme = "0.13.74" EnzymeCore = "0.8.13" FillArrays = "1.13" +FixedSizeArrays = "1.2.0" Float8s = "0.1" Functors = "0.5" GPUArraysCore = "0.2" diff --git a/ext/ReactantFixedSizeArraysExt.jl b/ext/ReactantFixedSizeArraysExt.jl new file mode 100644 index 0000000000..ee110df86e --- /dev/null +++ b/ext/ReactantFixedSizeArraysExt.jl @@ -0,0 +1,34 @@ +module ReactantFixedSizeArraysExt + +using FixedSizeArrays +using Reactant +using Reactant: TracedRArray, TracedRNumber, Ops +using ReactantCore: ReactantCore + +function Reactant.traced_type_inner( + @nospecialize(_::Type{FixedSizeArrays.FixedSizeArray{T, N, Memory{I}}}), + seen, + @nospecialize(mode::Reactant.TraceMode), + @nospecialize(track_numbers::Type), + @nospecialize(sharding), + @nospecialize(runtime) +) where {T, N, I} + T2 = Reactant.TracedRNumber{T} + I2 = Reactant.TracedRNumber{I} + return FixedSizeArrays.FixedSizeArray{T2, N, Memory{I2}} +end + +Base.@nospecializeinfer function Reactant.make_tracer( + seen, + @nospecialize(prev::FixedSizeArrays.FixedSizeArray{T, N, Memory{I}}), + @nospecialize(path), + mode; kwargs... +) where {T, N, I} + return FixedSizeArrays.FixedSizeArray( + Reactant.make_tracer( + seen, parent(prev), (path..., 1), mode; kwargs..., track_numbers=Number + ) + ) +end + +end diff --git a/src/Tracing.jl b/src/Tracing.jl index 015a2ccff7..fb223714a2 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -1148,6 +1148,10 @@ function make_tracer( @nospecialize(runtime = nothing), kwargs..., ) + @show prev + @show path + @show seen + @show typeof(prev) return make_tracer_unknown( seen, prev, path, mode; track_numbers, sharding, runtime, kwargs... ) diff --git a/test/Project.toml b/test/Project.toml index ba3ce8fc73..d7096ece04 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -9,6 +9,7 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +FixedSizeArrays = "3821ddf9-e5b5-40d5-8e25-6813ab96b5e2" Float8s = "81dfefd7-55b0-40c6-a251-db853704e186" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" @@ -28,6 +29,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random123 = "74087812-796a-5b5d-8853-05524746bad3" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" From 460dd4c06db41f6899d39242f19b3b2ae303f461 Mon Sep 17 00:00:00 2001 From: qfl3x Date: Tue, 16 Sep 2025 13:17:27 +0200 Subject: [PATCH 02/14] Removed debug messages --- src/Tracing.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/Tracing.jl b/src/Tracing.jl index fb223714a2..015a2ccff7 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -1148,10 +1148,6 @@ function make_tracer( @nospecialize(runtime = nothing), kwargs..., ) - @show prev - @show path - @show seen - @show typeof(prev) return make_tracer_unknown( seen, prev, path, mode; track_numbers, sharding, runtime, kwargs... ) From 73205e2cf3047d740d6a7681d07179de3f81b45a Mon Sep 17 00:00:00 2001 From: qfl3x Date: Wed, 17 Sep 2025 11:39:21 +0200 Subject: [PATCH 03/14] FixedSizeArray -> ConcreteArray --- ext/ReactantFixedSizeArraysExt.jl | 17 +++--- src/Tracing.jl | 89 +++++++++++++++++++++++++++++++ src/Types.jl | 27 ++++++++++ src/xla/PJRT/Buffer.jl | 16 ++++++ 4 files changed, 139 insertions(+), 10 deletions(-) diff --git a/ext/ReactantFixedSizeArraysExt.jl b/ext/ReactantFixedSizeArraysExt.jl index ee110df86e..f72073882a 100644 --- a/ext/ReactantFixedSizeArraysExt.jl +++ b/ext/ReactantFixedSizeArraysExt.jl @@ -6,28 +6,25 @@ using Reactant: TracedRArray, TracedRNumber, Ops using ReactantCore: ReactantCore function Reactant.traced_type_inner( - @nospecialize(_::Type{FixedSizeArrays.FixedSizeArray{T, N, Memory{I}}}), + @nospecialize(_::Type{FixedSizeArrays.FixedSizeArrayDefault{T, N}}), seen, @nospecialize(mode::Reactant.TraceMode), @nospecialize(track_numbers::Type), @nospecialize(sharding), @nospecialize(runtime) -) where {T, N, I} +) where {T, N} T2 = Reactant.TracedRNumber{T} - I2 = Reactant.TracedRNumber{I} - return FixedSizeArrays.FixedSizeArray{T2, N, Memory{I2}} + return FixedSizeArrays.FixedSizeArrayDefault{T2, N} end Base.@nospecializeinfer function Reactant.make_tracer( seen, - @nospecialize(prev::FixedSizeArrays.FixedSizeArray{T, N, Memory{I}}), + @nospecialize(prev::FixedSizeArrays.FixedSizeArrayDefault{T, N}), @nospecialize(path), mode; kwargs... -) where {T, N, I} - return FixedSizeArrays.FixedSizeArray( - Reactant.make_tracer( - seen, parent(prev), (path..., 1), mode; kwargs..., track_numbers=Number - ) +) where {T, N} + return Reactant.make_tracer( + seen, parent(prev), (path..., 1), mode; kwargs..., track_numbers=Number ) end diff --git a/src/Tracing.jl b/src/Tracing.jl index 015a2ccff7..efd3868b4e 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -1812,6 +1812,95 @@ Base.@nospecializeinfer function make_tracer( return res end + +Base.@nospecializeinfer function make_tracer( + seen, + @nospecialize(prev::Memory), + @nospecialize(path), + mode; + @nospecialize(track_numbers::Type = Union{}), + @nospecialize(sharding = Sharding.NoSharding()), + @nospecialize(runtime = nothing), + @nospecialize(device = nothing), + @nospecialize(client = nothing), + kwargs..., +) + RT = Core.Typeof(prev) + # XXX: If someone wants to shard the same array with different shardings, we need to + # somehow handle this correctly... Right now we just use the first sharding. + if mode != NoStopTracedTrack && haskey(seen, prev) + if mode == TracedToTypes + visited = seen[prev] + push!(path, visited) + return nothing + end + return seen[prev] + end + if eltype(RT) <: ReactantPrimitive + if mode == ArrayToConcrete + runtime isa Val{:PJRT} && + (return seen[prev] = ConcretePJRTArray(prev; sharding, device, client)) + runtime isa Val{:IFRT} && + (return seen[prev] = ConcreteIFRTArray(prev; sharding, device, client)) + error("Unsupported runtime $runtime") + elseif mode == TracedToTypes + # Original array can get mutated so we store a copy: + push!(path, copy(prev)) + seen[prev] = VisitedObject(length(seen) + 1) + return nothing + end + elseif mode == TracedToTypes + push!(path, RT) + for I in eachindex(prev) + if isassigned(prev, I) + pv = prev[I] + make_tracer( + seen, + pv, + path, + mode; + track_numbers, + sharding, + runtime, + device, + client, + kwargs..., + ) + end + end + return nothing + end + TT = traced_type(eltype(RT), Val(mode), track_numbers, sharding, runtime) + newa = Array{TT,ndims(RT)}(undef, size(prev)) + seen[prev] = newa + same = true + for I in eachindex(prev) + if isassigned(prev, I) + pv = prev[I] + nv = make_tracer( + seen, + pv, + append_path(path, I), + mode; + track_numbers, + sharding=Base.getproperty(sharding, I), + runtime, + device, + client, + kwargs..., + ) + if pv !== nv + same = false + end + @inbounds newa[I] = nv + end + end + if same + seen[prev] = prev + return prev + end + return newa +end Base.@nospecializeinfer function make_tracer( seen, @nospecialize(prev::Sharding.Mesh), diff --git a/src/Types.jl b/src/Types.jl index cc257c4ebf..72201864a2 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -228,6 +228,20 @@ function ConcretePJRTArray( return ConcretePJRTArray{T,N,nsharded,typeof(shardinfo)}(sharded_data, shape, shardinfo) end +function ConcretePJRTArray( + data::Memory{T}; + client::Union{Nothing,XLA.PJRT.Client}=nothing, + idx::Union{Int,Nothing}=nothing, + device::Union{Nothing,XLA.PJRT.Device}=nothing, + sharding::Sharding.AbstractSharding=Sharding.NoSharding(), +) where {T} + theclient, thedevice = _select_client_and_device(client, idx, device, sharding) + sharded_data, shardinfo = sharding(theclient, thedevice, data) + shape = size(data) + nsharded = length(sharded_data) + return ConcretePJRTArray{T,1,nsharded,typeof(shardinfo)}(sharded_data, shape, shardinfo) +end + Base.wait(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = foreach(wait, x.data) XLA.client(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = XLA.client(x.data) function XLA.device(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) @@ -356,6 +370,19 @@ function ConcreteIFRTArray( return ConcreteIFRTArray{T,N,typeof(shardinfo)}(sharded_data, shape, shardinfo, padding) end +function ConcreteIFRTArray( + data::Memory{T}; + client::Union{Nothing,XLA.IFRT.Client}=nothing, + idx::Union{Int,Nothing}=nothing, + device::Union{Nothing,XLA.IFRT.Device}=nothing, + sharding::Sharding.AbstractSharding=Sharding.NoSharding(), +) where {T} + theclient, thedevice = _select_client_and_device(client, idx, device, sharding) + sharded_data, shardinfo, padding = sharding(theclient, nothing, data) + shape = size(data) + return ConcreteIFRTArray{T,1,typeof(shardinfo)}(sharded_data, shape, shardinfo) +end + # Assemble data from multiple arrays. Needed in distributed setting where each process wont # have enough host memory to hold all the arrays. We assume that the data is only provided # for all of the addressable devices. diff --git a/src/xla/PJRT/Buffer.jl b/src/xla/PJRT/Buffer.jl index 2b36292c93..174ca15c02 100644 --- a/src/xla/PJRT/Buffer.jl +++ b/src/xla/PJRT/Buffer.jl @@ -21,6 +21,22 @@ function Buffer(client::Client, array::Array{T,N}, device::Device) where {T,N} return Buffer(buffer) end + +function Buffer(client::Client, memory::Memory{T}, device::Device) where {T} + sizear = collect(Int64, reverse(size(memory))) + buffer = GC.@preserve memory sizear begin + @ccall MLIR.API.mlir_c.ArrayFromHostBuffer( + client.client::Ptr{Cvoid}, + pointer(memory)::Ptr{T}, + XLA.primitive_type(T)::UInt64, + 1::Csize_t, + pointer(sizear)::Ptr{Int64}, + device.device::Ptr{Cvoid}, + )::Ptr{Cvoid} + end + return Buffer(buffer) +end + function Base.similar(a::Buffer) buffer = GC.@preserve a begin @ccall MLIR.API.mlir_c.UninitPJRTBuffer( From 15eca0beefefc69ff2829cb7d6848b0aa411b3a5 Mon Sep 17 00:00:00 2001 From: qfl3x Date: Wed, 17 Sep 2025 11:46:44 +0200 Subject: [PATCH 04/14] Project cleanup --- test/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index d7096ece04..243590158a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -29,7 +29,6 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random123 = "74087812-796a-5b5d-8853-05524746bad3" -Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" From cbd5e7533c75c0a3ae689ee5befdce163927695c Mon Sep 17 00:00:00 2001 From: qfl3x Date: Wed, 17 Sep 2025 12:20:37 +0200 Subject: [PATCH 05/14] Safeguard Memory for v1.10 --- src/Tracing.jl | 138 +++++++++++++++++++++-------------------- src/Types.jl | 50 ++++++++------- src/xla/PJRT/Buffer.jl | 27 ++++---- 3 files changed, 111 insertions(+), 104 deletions(-) diff --git a/src/Tracing.jl b/src/Tracing.jl index efd3868b4e..ec63f241bf 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -1812,95 +1812,97 @@ Base.@nospecializeinfer function make_tracer( return res end - -Base.@nospecializeinfer function make_tracer( - seen, - @nospecialize(prev::Memory), - @nospecialize(path), - mode; - @nospecialize(track_numbers::Type = Union{}), - @nospecialize(sharding = Sharding.NoSharding()), - @nospecialize(runtime = nothing), - @nospecialize(device = nothing), - @nospecialize(client = nothing), - kwargs..., -) - RT = Core.Typeof(prev) - # XXX: If someone wants to shard the same array with different shardings, we need to - # somehow handle this correctly... Right now we just use the first sharding. - if mode != NoStopTracedTrack && haskey(seen, prev) - if mode == TracedToTypes - visited = seen[prev] - push!(path, visited) - return nothing +if isdefined(Base, :Memory) + Base.@nospecializeinfer function make_tracer( + seen, + @nospecialize(prev::Memory), + @nospecialize(path), + mode; + @nospecialize(track_numbers::Type = Union{}), + @nospecialize(sharding = Sharding.NoSharding()), + @nospecialize(runtime = nothing), + @nospecialize(device = nothing), + @nospecialize(client = nothing), + kwargs..., + ) + RT = Core.Typeof(prev) + # XXX: If someone wants to shard the same array with different shardings, we need to + # somehow handle this correctly... Right now we just use the first sharding. + if mode != NoStopTracedTrack && haskey(seen, prev) + if mode == TracedToTypes + visited = seen[prev] + push!(path, visited) + return nothing + end + return seen[prev] end - return seen[prev] - end - if eltype(RT) <: ReactantPrimitive - if mode == ArrayToConcrete - runtime isa Val{:PJRT} && - (return seen[prev] = ConcretePJRTArray(prev; sharding, device, client)) - runtime isa Val{:IFRT} && - (return seen[prev] = ConcreteIFRTArray(prev; sharding, device, client)) - error("Unsupported runtime $runtime") + if eltype(RT) <: ReactantPrimitive + if mode == ArrayToConcrete + runtime isa Val{:PJRT} && + (return seen[prev] = ConcretePJRTArray(prev; sharding, device, client)) + runtime isa Val{:IFRT} && + (return seen[prev] = ConcreteIFRTArray(prev; sharding, device, client)) + error("Unsupported runtime $runtime") + elseif mode == TracedToTypes + # Original array can get mutated so we store a copy: + push!(path, copy(prev)) + seen[prev] = VisitedObject(length(seen) + 1) + return nothing + end elseif mode == TracedToTypes - # Original array can get mutated so we store a copy: - push!(path, copy(prev)) - seen[prev] = VisitedObject(length(seen) + 1) + push!(path, RT) + for I in eachindex(prev) + if isassigned(prev, I) + pv = prev[I] + make_tracer( + seen, + pv, + path, + mode; + track_numbers, + sharding, + runtime, + device, + client, + kwargs..., + ) + end + end return nothing end - elseif mode == TracedToTypes - push!(path, RT) + TT = traced_type(eltype(RT), Val(mode), track_numbers, sharding, runtime) + newa = Array{TT,ndims(RT)}(undef, size(prev)) + seen[prev] = newa + same = true for I in eachindex(prev) if isassigned(prev, I) pv = prev[I] - make_tracer( + nv = make_tracer( seen, pv, - path, + append_path(path, I), mode; track_numbers, - sharding, + sharding=Base.getproperty(sharding, I), runtime, device, client, kwargs..., ) + if pv !== nv + same = false + end + @inbounds newa[I] = nv end end - return nothing - end - TT = traced_type(eltype(RT), Val(mode), track_numbers, sharding, runtime) - newa = Array{TT,ndims(RT)}(undef, size(prev)) - seen[prev] = newa - same = true - for I in eachindex(prev) - if isassigned(prev, I) - pv = prev[I] - nv = make_tracer( - seen, - pv, - append_path(path, I), - mode; - track_numbers, - sharding=Base.getproperty(sharding, I), - runtime, - device, - client, - kwargs..., - ) - if pv !== nv - same = false - end - @inbounds newa[I] = nv + if same + seen[prev] = prev + return prev end + return newa end - if same - seen[prev] = prev - return prev - end - return newa end + Base.@nospecializeinfer function make_tracer( seen, @nospecialize(prev::Sharding.Mesh), diff --git a/src/Types.jl b/src/Types.jl index 72201864a2..23d0ed5626 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -228,18 +228,20 @@ function ConcretePJRTArray( return ConcretePJRTArray{T,N,nsharded,typeof(shardinfo)}(sharded_data, shape, shardinfo) end -function ConcretePJRTArray( - data::Memory{T}; - client::Union{Nothing,XLA.PJRT.Client}=nothing, - idx::Union{Int,Nothing}=nothing, - device::Union{Nothing,XLA.PJRT.Device}=nothing, - sharding::Sharding.AbstractSharding=Sharding.NoSharding(), -) where {T} - theclient, thedevice = _select_client_and_device(client, idx, device, sharding) - sharded_data, shardinfo = sharding(theclient, thedevice, data) - shape = size(data) - nsharded = length(sharded_data) - return ConcretePJRTArray{T,1,nsharded,typeof(shardinfo)}(sharded_data, shape, shardinfo) +if isdefined(Base, :Memory) + function ConcretePJRTArray( + data::Memory{T}; + client::Union{Nothing,XLA.PJRT.Client}=nothing, + idx::Union{Int,Nothing}=nothing, + device::Union{Nothing,XLA.PJRT.Device}=nothing, + sharding::Sharding.AbstractSharding=Sharding.NoSharding(), + ) where {T} + theclient, thedevice = _select_client_and_device(client, idx, device, sharding) + sharded_data, shardinfo = sharding(theclient, thedevice, data) + shape = size(data) + nsharded = length(sharded_data) + return ConcretePJRTArray{T,1,nsharded,typeof(shardinfo)}(sharded_data, shape, shardinfo) + end end Base.wait(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = foreach(wait, x.data) @@ -370,17 +372,19 @@ function ConcreteIFRTArray( return ConcreteIFRTArray{T,N,typeof(shardinfo)}(sharded_data, shape, shardinfo, padding) end -function ConcreteIFRTArray( - data::Memory{T}; - client::Union{Nothing,XLA.IFRT.Client}=nothing, - idx::Union{Int,Nothing}=nothing, - device::Union{Nothing,XLA.IFRT.Device}=nothing, - sharding::Sharding.AbstractSharding=Sharding.NoSharding(), -) where {T} - theclient, thedevice = _select_client_and_device(client, idx, device, sharding) - sharded_data, shardinfo, padding = sharding(theclient, nothing, data) - shape = size(data) - return ConcreteIFRTArray{T,1,typeof(shardinfo)}(sharded_data, shape, shardinfo) +if isdefined(Base, :Memory) + function ConcreteIFRTArray( + data::Memory{T}; + client::Union{Nothing,XLA.IFRT.Client}=nothing, + idx::Union{Int,Nothing}=nothing, + device::Union{Nothing,XLA.IFRT.Device}=nothing, + sharding::Sharding.AbstractSharding=Sharding.NoSharding(), + ) where {T} + theclient, thedevice = _select_client_and_device(client, idx, device, sharding) + sharded_data, shardinfo, padding = sharding(theclient, nothing, data) + shape = size(data) + return ConcreteIFRTArray{T,1,typeof(shardinfo)}(sharded_data, shape, shardinfo) + end end # Assemble data from multiple arrays. Needed in distributed setting where each process wont diff --git a/src/xla/PJRT/Buffer.jl b/src/xla/PJRT/Buffer.jl index 174ca15c02..7e4749bd8e 100644 --- a/src/xla/PJRT/Buffer.jl +++ b/src/xla/PJRT/Buffer.jl @@ -21,20 +21,21 @@ function Buffer(client::Client, array::Array{T,N}, device::Device) where {T,N} return Buffer(buffer) end - -function Buffer(client::Client, memory::Memory{T}, device::Device) where {T} - sizear = collect(Int64, reverse(size(memory))) - buffer = GC.@preserve memory sizear begin - @ccall MLIR.API.mlir_c.ArrayFromHostBuffer( - client.client::Ptr{Cvoid}, - pointer(memory)::Ptr{T}, - XLA.primitive_type(T)::UInt64, - 1::Csize_t, - pointer(sizear)::Ptr{Int64}, - device.device::Ptr{Cvoid}, - )::Ptr{Cvoid} +if isdefined(Base, :Memory) + function Buffer(client::Client, memory::Memory{T}, device::Device) where {T} + sizear = collect(Int64, reverse(size(memory))) + buffer = GC.@preserve memory sizear begin + @ccall MLIR.API.mlir_c.ArrayFromHostBuffer( + client.client::Ptr{Cvoid}, + pointer(memory)::Ptr{T}, + XLA.primitive_type(T)::UInt64, + 1::Csize_t, + pointer(sizear)::Ptr{Int64}, + device.device::Ptr{Cvoid}, + )::Ptr{Cvoid} + end + return Buffer(buffer) end - return Buffer(buffer) end function Base.similar(a::Buffer) From 702ac1d369edd6ae3440d9193d13ff3d04d9bebd Mon Sep 17 00:00:00 2001 From: qfl3x Date: Wed, 17 Sep 2025 12:48:52 +0200 Subject: [PATCH 06/14] Formatter --- ext/ReactantFixedSizeArraysExt.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/ext/ReactantFixedSizeArraysExt.jl b/ext/ReactantFixedSizeArraysExt.jl index f72073882a..ffb1bed2c5 100644 --- a/ext/ReactantFixedSizeArraysExt.jl +++ b/ext/ReactantFixedSizeArraysExt.jl @@ -6,26 +6,27 @@ using Reactant: TracedRArray, TracedRNumber, Ops using ReactantCore: ReactantCore function Reactant.traced_type_inner( - @nospecialize(_::Type{FixedSizeArrays.FixedSizeArrayDefault{T, N}}), + @nospecialize(_::Type{FixedSizeArrays.FixedSizeArrayDefault{T,N}}), seen, @nospecialize(mode::Reactant.TraceMode), @nospecialize(track_numbers::Type), @nospecialize(sharding), @nospecialize(runtime) -) where {T, N} +) where {T,N} T2 = Reactant.TracedRNumber{T} - return FixedSizeArrays.FixedSizeArrayDefault{T2, N} + return FixedSizeArrays.FixedSizeArrayDefault{T2,N} end Base.@nospecializeinfer function Reactant.make_tracer( seen, - @nospecialize(prev::FixedSizeArrays.FixedSizeArrayDefault{T, N}), + @nospecialize(prev::FixedSizeArrays.FixedSizeArrayDefault{T,N}), @nospecialize(path), - mode; kwargs... -) where {T, N} + mode; + kwargs..., +) where {T,N} return Reactant.make_tracer( seen, parent(prev), (path..., 1), mode; kwargs..., track_numbers=Number ) end - + end From 658a3d7b02ab102cb898de94f1696ab71be713a6 Mon Sep 17 00:00:00 2001 From: qfl3x Date: Thu, 18 Sep 2025 11:29:39 +0200 Subject: [PATCH 07/14] support for N dim arrays, tests --- ext/ReactantFixedSizeArraysExt.jl | 5 +++-- test/integration/fixedsizearrays.jl | 17 +++++++++++++++++ test/memory.jl | 11 +++++++++++ test/runtests.jl | 4 ++++ 4 files changed, 35 insertions(+), 2 deletions(-) create mode 100644 test/integration/fixedsizearrays.jl create mode 100644 test/memory.jl diff --git a/ext/ReactantFixedSizeArraysExt.jl b/ext/ReactantFixedSizeArraysExt.jl index ffb1bed2c5..012a984eb7 100644 --- a/ext/ReactantFixedSizeArraysExt.jl +++ b/ext/ReactantFixedSizeArraysExt.jl @@ -24,9 +24,10 @@ Base.@nospecializeinfer function Reactant.make_tracer( mode; kwargs..., ) where {T,N} - return Reactant.make_tracer( + shape = size(prev) + return reshape(Reactant.make_tracer( seen, parent(prev), (path..., 1), mode; kwargs..., track_numbers=Number - ) + ), shape) end end diff --git a/test/integration/fixedsizearrays.jl b/test/integration/fixedsizearrays.jl new file mode 100644 index 0000000000..b77e06d1c3 --- /dev/null +++ b/test/integration/fixedsizearrays.jl @@ -0,0 +1,17 @@ + +using Reactant, Test, FixedSizeArrays + +fn(x, y) = (2 .* x .- 3) * y' + +@testset "FixedSizeArrays" begin + @testset "1D" begin + x = FixedSizeArray(fill(3.0f0, 100)) + rx = Reactant.to_rarray(x) + @test @jit(fn(rx, rx)) ≈ fn(x, x) + end + @testset "2D" begin + x = FixedSizeArray(fill(3.0f0, (4,5))) + rx = Reactant.to_rarray(x) + @test @jit(fn(rx, rx)) ≈ fn(x, x) + end +end diff --git a/test/memory.jl b/test/memory.jl new file mode 100644 index 0000000000..e86305bf07 --- /dev/null +++ b/test/memory.jl @@ -0,0 +1,11 @@ +using Reactant, Test + +fn(x,y) = sin.(x) .+ cos.(y) + +@testset "Memory test" begin + x = Memory{Float32}(fill(2.0f0, 10)) + x_ra = Reactant.to_rarray(x) + + @test @jit(fn(x_ra,x_ra)) ≈ fn(x,x) +end + diff --git a/test/runtests.jl b/test/runtests.jl index 11ccf5d862..8159581b69 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,6 +38,9 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Config" include("config.jl") @safetestset "Batching" include("batching.jl") @safetestset "QA" include("qa.jl") + if isdefined(Base, :Memory) + @safetestset "Memory" include("memory.jl") + end end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" @@ -52,6 +55,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Python" include("integration/python.jl") @safetestset "Optimisers" include("integration/optimisers.jl") @safetestset "FillArrays" include("integration/fillarrays.jl") + @safetestset "FixedSizeArrays" include("integration/fixedsizearrays.jl") end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks" From 7e79d150b5f3f676e58f90f7b64c81ac569634f6 Mon Sep 17 00:00:00 2001 From: qfl3x Date: Thu, 18 Sep 2025 11:53:28 +0200 Subject: [PATCH 08/14] formatting --- ext/ReactantFixedSizeArraysExt.jl | 9 ++++++--- src/Types.jl | 11 ++++++++--- test/integration/fixedsizearrays.jl | 2 +- test/memory.jl | 5 ++--- 4 files changed, 17 insertions(+), 10 deletions(-) diff --git a/ext/ReactantFixedSizeArraysExt.jl b/ext/ReactantFixedSizeArraysExt.jl index 012a984eb7..ba5045e07a 100644 --- a/ext/ReactantFixedSizeArraysExt.jl +++ b/ext/ReactantFixedSizeArraysExt.jl @@ -25,9 +25,12 @@ Base.@nospecializeinfer function Reactant.make_tracer( kwargs..., ) where {T,N} shape = size(prev) - return reshape(Reactant.make_tracer( - seen, parent(prev), (path..., 1), mode; kwargs..., track_numbers=Number - ), shape) + return reshape( + Reactant.make_tracer( + seen, parent(prev), (path..., 1), mode; kwargs..., track_numbers=Number + ), + shape, + ) end end diff --git a/src/Types.jl b/src/Types.jl index 23d0ed5626..b55fbfc7d7 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -240,7 +240,9 @@ if isdefined(Base, :Memory) sharded_data, shardinfo = sharding(theclient, thedevice, data) shape = size(data) nsharded = length(sharded_data) - return ConcretePJRTArray{T,1,nsharded,typeof(shardinfo)}(sharded_data, shape, shardinfo) + return ConcretePJRTArray{T,1,nsharded,typeof(shardinfo)}( + sharded_data, shape, shardinfo + ) end end @@ -503,8 +505,11 @@ elseif XLA.REACTANT_XLA_RUNTIME == "IFRT" ConcreteIFRTArray end -@inline ConcreteRArray{T}(::UndefInitializer, shape::Integer...; kwargs...) where {T} = - ConcreteRArray{T}(undef, Dims(shape); kwargs...) +@inline ConcreteRArray{T}(::UndefInitializer, shape::Integer...; kwargs...) where {T} = ConcreteRArray{ + T +}( + undef, Dims(shape); kwargs... +) """ ConcreteRNumber( diff --git a/test/integration/fixedsizearrays.jl b/test/integration/fixedsizearrays.jl index b77e06d1c3..3200df2e57 100644 --- a/test/integration/fixedsizearrays.jl +++ b/test/integration/fixedsizearrays.jl @@ -10,7 +10,7 @@ fn(x, y) = (2 .* x .- 3) * y' @test @jit(fn(rx, rx)) ≈ fn(x, x) end @testset "2D" begin - x = FixedSizeArray(fill(3.0f0, (4,5))) + x = FixedSizeArray(fill(3.0f0, (4, 5))) rx = Reactant.to_rarray(x) @test @jit(fn(rx, rx)) ≈ fn(x, x) end diff --git a/test/memory.jl b/test/memory.jl index e86305bf07..d2a9e13558 100644 --- a/test/memory.jl +++ b/test/memory.jl @@ -1,11 +1,10 @@ using Reactant, Test -fn(x,y) = sin.(x) .+ cos.(y) +fn(x, y) = sin.(x) .+ cos.(y) @testset "Memory test" begin x = Memory{Float32}(fill(2.0f0, 10)) x_ra = Reactant.to_rarray(x) - @test @jit(fn(x_ra,x_ra)) ≈ fn(x,x) + @test @jit(fn(x_ra, x_ra)) ≈ fn(x, x) end - From 08a425ed1dddd5d2a01f2586df2fdbf4be2aee8b Mon Sep 17 00:00:00 2001 From: qfl3x Date: Thu, 18 Sep 2025 16:01:50 +0200 Subject: [PATCH 09/14] IFRT support for Memory and FixedSizeArrays --- src/xla/IFRT/Array.jl | 80 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/src/xla/IFRT/Array.jl b/src/xla/IFRT/Array.jl index 5a618dd131..7520de5aa6 100644 --- a/src/xla/IFRT/Array.jl +++ b/src/xla/IFRT/Array.jl @@ -38,6 +38,30 @@ function Array( return Array(buffer) end +if isdefined(Base, :Memory) + function Array( + client::Client, + memory::Base.Memory{T}, + device::Device=XLA.default_device(client), + memory_kind::AbstractString=string(convert(MemoryKind, XLA.default_memory(device))), + ) where {T<:Reactant.ReactantPrimitive} + sizear = collect(Int64, reverse(size(memory))) + buffer = GC.@preserve memory sizear begin + @ccall MLIR.API.mlir_c.ifrt_client_make_single_shard_array_from_host_buffer( + client.client::Ptr{Cvoid}, + pointer(memory)::Ptr{T}, + XLA.primitive_type(T)::UInt64, + 1::Csize_t, + sizear::Ptr{Int64}, + 0::Cint, # kAlwaysCopy + device.device::Ptr{Cvoid}, + string(memory_kind)::Cstring, + )::Ptr{Cvoid} + end + return Array(buffer) + end +end + function Array( client::Client, array::Base.Array{T,N}, sharding::Sharding ) where {T<:Reactant.ReactantPrimitive,N} @@ -143,6 +167,45 @@ function Array( return Array(buffer) end +if isdefined(Base, :Memory) + function Array( + client::Client, memory::Base.Memory{T}, sharding::Sharding + ) where {T<:Reactant.ReactantPrimitive} + all_devices = XLA.devices(sharding) + all_logical_device_ids = collect(Int64, 0:(length(all_devices) - 1)) + hlo_sharding = convert(XLA.HloSharding, sharding) + + slices, _ = XLA.sharding_to_concrete_array_indices( + hlo_sharding, size(memory), all_logical_device_ids + ) + + seen_slice = Dict{NTuple{N,UnitRange{Int64}},Int}() + host_buffers = Base.Array{T,1}[] + addressable_shard_indices = Vector{Int64}[] + + cur_shard = 0 + for (slice, device) in zip(slices, all_devices) + XLA.is_addressable(device) || continue + + if haskey(seen_slice, slice) + idx = seen_slice[slice] + push!(addressable_shard_indices[idx], cur_shard) + else + host_buffer = let slice = memory[slice...] + slice isa Number ? collect(slice) : slice + end + push!(host_buffers, host_buffer) + push!(addressable_shard_indices, Int64[cur_shard]) + seen_slice[slice] = length(host_buffers) + end + + cur_shard += 1 + end + + return Array(client, host_buffers, addressable_shard_indices, size(memory), sharding) + end +end + function Array( client::Client, array::Base.Array{T,N}, sharding ) where {T<:Reactant.ReactantPrimitive,N} @@ -158,6 +221,23 @@ function Array( return Array(client, array, ifrt_sharding) end +if isdefined(Base, :Memory) + function Array( + client::Client, memory::Base.Memory{T}, sharding + ) where {T<:Reactant.ReactantPrimitive} + @assert sharding isa Reactant.Sharding.AbstractSharding + if !(sharding isa Reactant.Sharding.HloSharding) + sharding = Reactant.Sharding.HloSharding(sharding, size(memory)) + end + + (; hlo_sharding, mesh) = sharding + devices = XLA.get_device.((client,), mesh.device_ids) + ifrt_sharding = Sharding([devices...], hlo_sharding) + + return Array(client, memory, ifrt_sharding) + end +end + @inline function XLA.free_buffer(buffer::Array) if buffer.buffer != C_NULL @ccall MLIR.API.mlir_c.ifrt_free_array(buffer.buffer::Ptr{Cvoid})::Cvoid From 8b9f24ddc3bf2cd7c0db76b36979981649722e75 Mon Sep 17 00:00:00 2001 From: qfl3x Date: Thu, 18 Sep 2025 16:03:43 +0200 Subject: [PATCH 10/14] Formatter pass --- src/xla/IFRT/Array.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/xla/IFRT/Array.jl b/src/xla/IFRT/Array.jl index 7520de5aa6..adcc1cea40 100644 --- a/src/xla/IFRT/Array.jl +++ b/src/xla/IFRT/Array.jl @@ -202,7 +202,9 @@ if isdefined(Base, :Memory) cur_shard += 1 end - return Array(client, host_buffers, addressable_shard_indices, size(memory), sharding) + return Array( + client, host_buffers, addressable_shard_indices, size(memory), sharding + ) end end From dc00c6a94d0233f142b462c12da95fbe2c00fe30 Mon Sep 17 00:00:00 2001 From: qfl3x Date: Fri, 19 Sep 2025 15:29:20 +0200 Subject: [PATCH 11/14] Refactored Array and Memory common code --- src/Tracing.jl | 97 ++++++++------------------------------- src/Types.jl | 41 ++++++++++------- src/xla/IFRT/Array.jl | 100 ++++++++++++++--------------------------- src/xla/PJRT/Buffer.jl | 20 +++------ 4 files changed, 85 insertions(+), 173 deletions(-) diff --git a/src/Tracing.jl b/src/Tracing.jl index ec63f241bf..fe014bca5f 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -1516,11 +1516,11 @@ Base.@nospecializeinfer function make_tracer( ) end -Base.@nospecializeinfer function make_tracer( +Base.@nospecializeinfer function make_tracer_array( seen, - @nospecialize(prev::Array), + @nospecialize(prev::AbstractArray), @nospecialize(path), - mode; + mode, @nospecialize(track_numbers::Type = Union{}), @nospecialize(sharding = Sharding.NoSharding()), @nospecialize(runtime = nothing), @@ -1605,6 +1605,21 @@ Base.@nospecializeinfer function make_tracer( return newa end +Base.@nospecializeinfer function make_tracer( + seen, + @nospecialize(prev::Array), + @nospecialize(path), + mode; + @nospecialize(track_numbers::Type = Union{}), + @nospecialize(sharding = Sharding.NoSharding()), + @nospecialize(runtime = nothing), + @nospecialize(device = nothing), + @nospecialize(client = nothing), + kwargs..., +) + return make_tracer_array(seen, prev, path, mode, track_numbers, sharding, runtime, device, client, kwargs...) +end + Base.@nospecializeinfer function make_tracer( seen, @nospecialize(prev::Dict{Key,Value}), @@ -1825,81 +1840,7 @@ if isdefined(Base, :Memory) @nospecialize(client = nothing), kwargs..., ) - RT = Core.Typeof(prev) - # XXX: If someone wants to shard the same array with different shardings, we need to - # somehow handle this correctly... Right now we just use the first sharding. - if mode != NoStopTracedTrack && haskey(seen, prev) - if mode == TracedToTypes - visited = seen[prev] - push!(path, visited) - return nothing - end - return seen[prev] - end - if eltype(RT) <: ReactantPrimitive - if mode == ArrayToConcrete - runtime isa Val{:PJRT} && - (return seen[prev] = ConcretePJRTArray(prev; sharding, device, client)) - runtime isa Val{:IFRT} && - (return seen[prev] = ConcreteIFRTArray(prev; sharding, device, client)) - error("Unsupported runtime $runtime") - elseif mode == TracedToTypes - # Original array can get mutated so we store a copy: - push!(path, copy(prev)) - seen[prev] = VisitedObject(length(seen) + 1) - return nothing - end - elseif mode == TracedToTypes - push!(path, RT) - for I in eachindex(prev) - if isassigned(prev, I) - pv = prev[I] - make_tracer( - seen, - pv, - path, - mode; - track_numbers, - sharding, - runtime, - device, - client, - kwargs..., - ) - end - end - return nothing - end - TT = traced_type(eltype(RT), Val(mode), track_numbers, sharding, runtime) - newa = Array{TT,ndims(RT)}(undef, size(prev)) - seen[prev] = newa - same = true - for I in eachindex(prev) - if isassigned(prev, I) - pv = prev[I] - nv = make_tracer( - seen, - pv, - append_path(path, I), - mode; - track_numbers, - sharding=Base.getproperty(sharding, I), - runtime, - device, - client, - kwargs..., - ) - if pv !== nv - same = false - end - @inbounds newa[I] = nv - end - end - if same - seen[prev] = prev - return prev - end - return newa + return make_tracer_array(seen, prev, path, mode, track_numbers, sharding, runtime, device, client, kwargs...) end end diff --git a/src/Types.jl b/src/Types.jl index b55fbfc7d7..09d47d5fb9 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -214,8 +214,8 @@ function ConcretePJRTArray{T,N}( return ConcretePJRTArray{T,N,D,typeof(sharding)}(data, shape, sharding) end -function ConcretePJRTArray( - data::Array{T,N}; +function make_concrete_PJRT_array( + data::AbstractArray{T,N}, client::Union{Nothing,XLA.PJRT.Client}=nothing, idx::Union{Int,Nothing}=nothing, device::Union{Nothing,XLA.PJRT.Device}=nothing, @@ -228,6 +228,16 @@ function ConcretePJRTArray( return ConcretePJRTArray{T,N,nsharded,typeof(shardinfo)}(sharded_data, shape, shardinfo) end +function ConcretePJRTArray( + data::Array{T,N}; + client::Union{Nothing,XLA.PJRT.Client}=nothing, + idx::Union{Int,Nothing}=nothing, + device::Union{Nothing,XLA.PJRT.Device}=nothing, + sharding::Sharding.AbstractSharding=Sharding.NoSharding(), +) where {T,N} + return make_concrete_PJRT_array(data, client, idx, device, sharding) +end + if isdefined(Base, :Memory) function ConcretePJRTArray( data::Memory{T}; @@ -236,13 +246,7 @@ if isdefined(Base, :Memory) device::Union{Nothing,XLA.PJRT.Device}=nothing, sharding::Sharding.AbstractSharding=Sharding.NoSharding(), ) where {T} - theclient, thedevice = _select_client_and_device(client, idx, device, sharding) - sharded_data, shardinfo = sharding(theclient, thedevice, data) - shape = size(data) - nsharded = length(sharded_data) - return ConcretePJRTArray{T,1,nsharded,typeof(shardinfo)}( - sharded_data, shape, shardinfo - ) + return make_concrete_PJRT_array(data, client, idx, device, sharding) end end @@ -360,8 +364,8 @@ function ConcreteIFRTArray{T,N}( return ConcreteIFRTArray{T,N,typeof(sharding)}(data, shape, sharding) end -function ConcreteIFRTArray( - data::Array{T,N}; +function make_concrete_IFRT_array( + data::AbstractArray{T,N}, client::Union{Nothing,XLA.IFRT.Client}=nothing, idx::Union{Int,Nothing}=nothing, device::Union{Nothing,XLA.IFRT.Device}=nothing, @@ -374,6 +378,16 @@ function ConcreteIFRTArray( return ConcreteIFRTArray{T,N,typeof(shardinfo)}(sharded_data, shape, shardinfo, padding) end +function ConcreteIFRTArray( + data::Array{T,N}; + client::Union{Nothing,XLA.IFRT.Client}=nothing, + idx::Union{Int,Nothing}=nothing, + device::Union{Nothing,XLA.IFRT.Device}=nothing, + sharding::Sharding.AbstractSharding=Sharding.NoSharding(), +) where {T,N} + return make_concrete_IFRT_array(data, client, idx, device, sharding) +end + if isdefined(Base, :Memory) function ConcreteIFRTArray( data::Memory{T}; @@ -382,10 +396,7 @@ if isdefined(Base, :Memory) device::Union{Nothing,XLA.IFRT.Device}=nothing, sharding::Sharding.AbstractSharding=Sharding.NoSharding(), ) where {T} - theclient, thedevice = _select_client_and_device(client, idx, device, sharding) - sharded_data, shardinfo, padding = sharding(theclient, nothing, data) - shape = size(data) - return ConcreteIFRTArray{T,1,typeof(shardinfo)}(sharded_data, shape, shardinfo) + return make_concrete_IFRT_array(data, client, idx, device, sharding) end end diff --git a/src/xla/IFRT/Array.jl b/src/xla/IFRT/Array.jl index adcc1cea40..a2fe87faae 100644 --- a/src/xla/IFRT/Array.jl +++ b/src/xla/IFRT/Array.jl @@ -16,11 +16,11 @@ function Array( return Array(client, fill(array), device, memory_kind) end -function Array( +function make_array_singleshard( client::Client, - array::Base.Array{T,N}, - device::Device=XLA.default_device(client), - memory_kind::AbstractString=string(convert(MemoryKind, XLA.default_memory(device))), + array::AbstractArray{T,N}, + device::Device, + memory_kind::AbstractString, ) where {T<:Reactant.ReactantPrimitive,N} sizear = collect(Int64, reverse(size(array))) buffer = GC.@preserve array sizear begin @@ -38,6 +38,15 @@ function Array( return Array(buffer) end +function Array( + client::Client, + array::Base.Array{T,N}, + device::Device=XLA.default_device(client), + memory_kind::AbstractString=string(convert(MemoryKind, XLA.default_memory(device))), +) where {T<:Reactant.ReactantPrimitive,N} + return make_array_singleshard(client, array, device, memory_kind) +end + if isdefined(Base, :Memory) function Array( client::Client, @@ -45,25 +54,12 @@ if isdefined(Base, :Memory) device::Device=XLA.default_device(client), memory_kind::AbstractString=string(convert(MemoryKind, XLA.default_memory(device))), ) where {T<:Reactant.ReactantPrimitive} - sizear = collect(Int64, reverse(size(memory))) - buffer = GC.@preserve memory sizear begin - @ccall MLIR.API.mlir_c.ifrt_client_make_single_shard_array_from_host_buffer( - client.client::Ptr{Cvoid}, - pointer(memory)::Ptr{T}, - XLA.primitive_type(T)::UInt64, - 1::Csize_t, - sizear::Ptr{Int64}, - 0::Cint, # kAlwaysCopy - device.device::Ptr{Cvoid}, - string(memory_kind)::Cstring, - )::Ptr{Cvoid} - end - return Array(buffer) + return make_array_singleshard(client, memory, device, memory_kind) end end -function Array( - client::Client, array::Base.Array{T,N}, sharding::Sharding +function make_array_sharding( + client::Client, array::AbstractArray{T,N}, sharding::Sharding ) where {T<:Reactant.ReactantPrimitive,N} all_devices = XLA.devices(sharding) all_logical_device_ids = collect(Int64, 0:(length(all_devices) - 1)) @@ -99,6 +95,12 @@ function Array( return Array(client, host_buffers, addressable_shard_indices, size(array), sharding) end +function Array( + client::Client, array::Base.Array{T,N}, sharding::Sharding +) where {T<:Reactant.ReactantPrimitive,N} + make_array_sharding(client, array, sharding) +end + function Array( client::Client, host_buffers::Vector{Base.Array{T,N}}, @@ -106,7 +108,7 @@ function Array( array_shape, sharding::Sharding, ) where {T<:Reactant.ReactantPrimitive,N} - # Construct using the slower path, the faster path is only implemented for IFRT-Proxy + # make using the slower path, the faster path is only implemented for IFRT-Proxy # and seems to cause issues with IFRT-PJRT all_addressable_devices = filter(XLA.is_addressable, XLA.devices(sharding)) @@ -171,45 +173,12 @@ if isdefined(Base, :Memory) function Array( client::Client, memory::Base.Memory{T}, sharding::Sharding ) where {T<:Reactant.ReactantPrimitive} - all_devices = XLA.devices(sharding) - all_logical_device_ids = collect(Int64, 0:(length(all_devices) - 1)) - hlo_sharding = convert(XLA.HloSharding, sharding) - - slices, _ = XLA.sharding_to_concrete_array_indices( - hlo_sharding, size(memory), all_logical_device_ids - ) - - seen_slice = Dict{NTuple{N,UnitRange{Int64}},Int}() - host_buffers = Base.Array{T,1}[] - addressable_shard_indices = Vector{Int64}[] - - cur_shard = 0 - for (slice, device) in zip(slices, all_devices) - XLA.is_addressable(device) || continue - - if haskey(seen_slice, slice) - idx = seen_slice[slice] - push!(addressable_shard_indices[idx], cur_shard) - else - host_buffer = let slice = memory[slice...] - slice isa Number ? collect(slice) : slice - end - push!(host_buffers, host_buffer) - push!(addressable_shard_indices, Int64[cur_shard]) - seen_slice[slice] = length(host_buffers) - end - - cur_shard += 1 - end - - return Array( - client, host_buffers, addressable_shard_indices, size(memory), sharding - ) + make_array_sharding(client, memory, sharding) end end -function Array( - client::Client, array::Base.Array{T,N}, sharding +function make_array_ifrt_sharding( + client::Client, array::Base.AbstractArray{T,N}, sharding ) where {T<:Reactant.ReactantPrimitive,N} @assert sharding isa Reactant.Sharding.AbstractSharding if !(sharding isa Reactant.Sharding.HloSharding) @@ -223,20 +192,17 @@ function Array( return Array(client, array, ifrt_sharding) end +function Array( + client::Client, array::Base.Array{T,N}, sharding +) where {T<:Reactant.ReactantPrimitive,N} + return make_array_ifrt_sharding(client, array, sharding) +end + if isdefined(Base, :Memory) function Array( client::Client, memory::Base.Memory{T}, sharding ) where {T<:Reactant.ReactantPrimitive} - @assert sharding isa Reactant.Sharding.AbstractSharding - if !(sharding isa Reactant.Sharding.HloSharding) - sharding = Reactant.Sharding.HloSharding(sharding, size(memory)) - end - - (; hlo_sharding, mesh) = sharding - devices = XLA.get_device.((client,), mesh.device_ids) - ifrt_sharding = Sharding([devices...], hlo_sharding) - - return Array(client, memory, ifrt_sharding) + return make_array_ifrt_sharding(client, memory, sharding) end end diff --git a/src/xla/PJRT/Buffer.jl b/src/xla/PJRT/Buffer.jl index 7e4749bd8e..2979350cd6 100644 --- a/src/xla/PJRT/Buffer.jl +++ b/src/xla/PJRT/Buffer.jl @@ -6,7 +6,8 @@ mutable struct Buffer <: XLA.AbstractBuffer end end -function Buffer(client::Client, array::Array{T,N}, device::Device) where {T,N} +function make_buffer_array( + client::Client, array::Array{T,N}, device::Device) where {T,N} sizear = collect(Int64, reverse(size(array))) buffer = GC.@preserve array sizear begin @ccall MLIR.API.mlir_c.ArrayFromHostBuffer( @@ -21,20 +22,13 @@ function Buffer(client::Client, array::Array{T,N}, device::Device) where {T,N} return Buffer(buffer) end +function Buffer(client::Client, array::Array{T,N}, device::Device) where {T,N} + return make_buffer_array(client, array, device) +end + if isdefined(Base, :Memory) function Buffer(client::Client, memory::Memory{T}, device::Device) where {T} - sizear = collect(Int64, reverse(size(memory))) - buffer = GC.@preserve memory sizear begin - @ccall MLIR.API.mlir_c.ArrayFromHostBuffer( - client.client::Ptr{Cvoid}, - pointer(memory)::Ptr{T}, - XLA.primitive_type(T)::UInt64, - 1::Csize_t, - pointer(sizear)::Ptr{Int64}, - device.device::Ptr{Cvoid}, - )::Ptr{Cvoid} - end - return Buffer(buffer) + return make_buffer_array(client, array, device) end end From ff2d39b0dc0fe6604e70c3680829c2781b4ccd61 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 19 Sep 2025 10:23:25 -0400 Subject: [PATCH 12/14] Update src/xla/PJRT/Buffer.jl --- src/xla/PJRT/Buffer.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/xla/PJRT/Buffer.jl b/src/xla/PJRT/Buffer.jl index 2979350cd6..781ec325c1 100644 --- a/src/xla/PJRT/Buffer.jl +++ b/src/xla/PJRT/Buffer.jl @@ -28,7 +28,7 @@ end if isdefined(Base, :Memory) function Buffer(client::Client, memory::Memory{T}, device::Device) where {T} - return make_buffer_array(client, array, device) + return make_buffer_array(client, memory, device) end end From 995c8a3868f569980a4d7c778e68482a6656e1d8 Mon Sep 17 00:00:00 2001 From: qfl3x Date: Mon, 22 Sep 2025 07:54:47 +0200 Subject: [PATCH 13/14] Fixed make_buffer_array --- src/xla/PJRT/Buffer.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/xla/PJRT/Buffer.jl b/src/xla/PJRT/Buffer.jl index 781ec325c1..acfe514a4a 100644 --- a/src/xla/PJRT/Buffer.jl +++ b/src/xla/PJRT/Buffer.jl @@ -7,7 +7,7 @@ mutable struct Buffer <: XLA.AbstractBuffer end function make_buffer_array( - client::Client, array::Array{T,N}, device::Device) where {T,N} + client::Client, array::AbstractArray{T,N}, device::Device) where {T,N} sizear = collect(Int64, reverse(size(array))) buffer = GC.@preserve array sizear begin @ccall MLIR.API.mlir_c.ArrayFromHostBuffer( From c1842c320bcc1f1df75e1f8bc362d461cd79b9e0 Mon Sep 17 00:00:00 2001 From: qfl3x Date: Tue, 23 Sep 2025 11:06:23 +0200 Subject: [PATCH 14/14] formatting --- src/Tracing.jl | 17 +++++++++++++++-- src/xla/IFRT/Array.jl | 9 +++------ src/xla/PJRT/Buffer.jl | 3 ++- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/Tracing.jl b/src/Tracing.jl index fe014bca5f..41ecd8e601 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -1617,7 +1617,9 @@ Base.@nospecializeinfer function make_tracer( @nospecialize(client = nothing), kwargs..., ) - return make_tracer_array(seen, prev, path, mode, track_numbers, sharding, runtime, device, client, kwargs...) + return make_tracer_array( + seen, prev, path, mode, track_numbers, sharding, runtime, device, client, kwargs... + ) end Base.@nospecializeinfer function make_tracer( @@ -1840,7 +1842,18 @@ if isdefined(Base, :Memory) @nospecialize(client = nothing), kwargs..., ) - return make_tracer_array(seen, prev, path, mode, track_numbers, sharding, runtime, device, client, kwargs...) + return make_tracer_array( + seen, + prev, + path, + mode, + track_numbers, + sharding, + runtime, + device, + client, + kwargs..., + ) end end diff --git a/src/xla/IFRT/Array.jl b/src/xla/IFRT/Array.jl index a2fe87faae..c45dfc34c4 100644 --- a/src/xla/IFRT/Array.jl +++ b/src/xla/IFRT/Array.jl @@ -17,10 +17,7 @@ function Array( end function make_array_singleshard( - client::Client, - array::AbstractArray{T,N}, - device::Device, - memory_kind::AbstractString, + client::Client, array::AbstractArray{T,N}, device::Device, memory_kind::AbstractString ) where {T<:Reactant.ReactantPrimitive,N} sizear = collect(Int64, reverse(size(array))) buffer = GC.@preserve array sizear begin @@ -98,7 +95,7 @@ end function Array( client::Client, array::Base.Array{T,N}, sharding::Sharding ) where {T<:Reactant.ReactantPrimitive,N} - make_array_sharding(client, array, sharding) + return make_array_sharding(client, array, sharding) end function Array( @@ -173,7 +170,7 @@ if isdefined(Base, :Memory) function Array( client::Client, memory::Base.Memory{T}, sharding::Sharding ) where {T<:Reactant.ReactantPrimitive} - make_array_sharding(client, memory, sharding) + return make_array_sharding(client, memory, sharding) end end diff --git a/src/xla/PJRT/Buffer.jl b/src/xla/PJRT/Buffer.jl index acfe514a4a..a8a6725717 100644 --- a/src/xla/PJRT/Buffer.jl +++ b/src/xla/PJRT/Buffer.jl @@ -7,7 +7,8 @@ mutable struct Buffer <: XLA.AbstractBuffer end function make_buffer_array( - client::Client, array::AbstractArray{T,N}, device::Device) where {T,N} + client::Client, array::AbstractArray{T,N}, device::Device +) where {T,N} sizear = collect(Int64, reverse(size(array))) buffer = GC.@preserve array sizear begin @ccall MLIR.API.mlir_c.ArrayFromHostBuffer(