From e0e7bac5421f61e8f964f810c3bbea97e8b6bee1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 19 Aug 2025 18:03:16 -0400 Subject: [PATCH 1/2] feat: FillArrays support --- Project.toml | 5 ++- ext/ReactantFillArraysExt.jl | 87 ++++++++++++++++++++++++++++++++++++ src/Compiler.jl | 2 +- src/ConcreteRArray.jl | 1 + src/Reactant.jl | 9 ++++ src/xla/PJRT/Buffer.jl | 4 +- 6 files changed, 105 insertions(+), 3 deletions(-) create mode 100644 ext/ReactantFillArraysExt.jl diff --git a/Project.toml b/Project.toml index 5d450beaa2..29297148c3 100644 --- a/Project.toml +++ b/Project.toml @@ -32,6 +32,7 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DLFP8Types = "f4c16678-4a16-415b-82ef-ed337c5d6c7c" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" Float8s = "81dfefd7-55b0-40c6-a251-db853704e186" GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" @@ -54,6 +55,7 @@ ReactantAbstractFFTsExt = "AbstractFFTs" ReactantArrayInterfaceExt = "ArrayInterface" ReactantCUDAExt = ["CUDA", "GPUCompiler", "KernelAbstractions", "LLVM"] ReactantDLFP8TypesExt = "DLFP8Types" +ReactantFillArraysExt = "FillArrays" ReactantFloat8sExt = "Float8s" ReactantKernelAbstractionsExt = "KernelAbstractions" ReactantMPIExt = "MPI" @@ -77,6 +79,7 @@ Downloads = "1.6" EnumX = "1" Enzyme = "0.13.49" EnzymeCore = "0.8.11" +FillArrays = "1.13" Float8s = "0.1" Functors = "0.5" GPUArraysCore = "0.2" @@ -103,9 +106,9 @@ Scratch = "1.2" Sockets = "1.10" SpecialFunctions = "2.4" Statistics = "1.10" -unzip_jll = "6" YaoBlocks = "0.13, 0.14" julia = "1.10" +unzip_jll = "6" [extras] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/ext/ReactantFillArraysExt.jl b/ext/ReactantFillArraysExt.jl new file mode 100644 index 0000000000..017fd8ca82 --- /dev/null +++ b/ext/ReactantFillArraysExt.jl @@ -0,0 +1,87 @@ +module ReactantFillArraysExt + +using Reactant: Reactant, TracedUtils, TracedRNumber, Ops, Sharding, unwrapped_eltype +using ReactantCore: ReactantCore +using FillArrays: FillArrays, AbstractFill, Fill, Ones, Zeros, OneElement +using GPUArraysCore: @allowscalar + +# Tracing +Reactant._parent_type(T::Type{<:AbstractFill}) = T +Reactant._parent_type(T::Type{<:OneElement}) = T + +Base.@nospecializeinfer function Reactant.traced_type_inner( + @nospecialize(FA::Type{Fill{T,N,Axes}}), + seen, + mode::Reactant.TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding), + @nospecialize(runtime) +) where {T,N,Axes} + # T will be a number so we need to trace it + return Fill{Reactant.traced_type_inner(T, seen, mode, Number, sharding, runtime),N,Axes} +end + +Base.@nospecializeinfer function Reactant.make_tracer( + seen, @nospecialize(prev::Fill{T,N,Axes}), @nospecialize(path), mode; kwargs... +) where {T,N,Axes} + return Fill( + Reactant.make_tracer( + seen, prev.value, (path..., 1), mode; kwargs..., track_numbers=Number + ), + prev.axes, + ) +end + +Base.@nospecializeinfer function Reactant.traced_type_inner( + @nospecialize(FA::Type{OneElement{T,N,I,A}}), + seen, + mode::Reactant.TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding), + @nospecialize(runtime) +) where {T,N,I,A} + # T will be a number so we need to trace it + return OneElement{ + Reactant.traced_type_inner(T, seen, mode, Number, sharding, runtime),N,I,A + } +end + +Base.@nospecializeinfer function Reactant.make_tracer( + seen, @nospecialize(prev::OneElement{T,N,I,A}), @nospecialize(path), mode; kwargs... +) where {T,N,I,A} + return OneElement( + Reactant.make_tracer( + seen, prev.val, (path..., 1), mode; kwargs..., track_numbers=Number + ), + prev.ind, + prev.axes, + ) +end + +# Materialize into a dense array +function ReactantCore.materialize_traced_array(x::Fill{T}) where {T} + return TracedUtils.broadcast_to_size( + TracedUtils.promote_to(TracedRNumber{unwrapped_eltype(T)}, x.value), size(x) + ) +end + +function ReactantCore.materialize_traced_array(x::Ones{T}) where {T} + return TracedUtils.broadcast_to_size(unwrapped_eltype(T)(1), size(x)) +end + +function ReactantCore.materialize_traced_array(x::Zeros{T}) where {T} + return TracedUtils.broadcast_to_size(unwrapped_eltype(T)(0), size(x)) +end + +function ReactantCore.materialize_traced_array(x::OneElement{T}) where {T} + y = TracedUtils.broadcast_to_size(unwrapped_eltype(T)(0), size(x)) + @allowscalar setindex!(y, x.val, x.ind...) + return y +end + +# some functions to avoid bad performance +function Base.similar(::OneElement{<:TracedRNumber}, ::Type{T}, dims::Dims) where {T} + return TracedUtils.broadcast_to_size(unwrapped_eltype(T)(0), dims) +end + +end diff --git a/src/Compiler.jl b/src/Compiler.jl index 2dce2dac1b..7fcf817af6 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -288,7 +288,7 @@ function create_result( sym = Symbol("result", var_idx[]) var_idx[] += 1 - @assert haskey(result_stores, path) + @assert haskey(result_stores, path) "Expected $(path) in $(keys(result_stores))" restore = result_stores[path] delete!(result_stores, path) if path_to_shard_info !== nothing && haskey(path_to_shard_info, path) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index e4d7656c79..60d36aad70 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -83,6 +83,7 @@ for T in Base.uniontypes(ReactantPrimitive) end function Base.convert(::Type{T}, x::AbstractConcreteNumber) where {T<:Number} + T == typeof(x) && return x return convert(T, to_number(x)) end diff --git a/src/Reactant.jl b/src/Reactant.jl index 6b739eaccb..b9cd2ef129 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -42,11 +42,20 @@ function ancestor(T::Type{<:AbstractArray}) p_T == T && return T return ancestor(p_T) end + if applicable(_parent_type, T) + p_T = _parent_type(T) + p_T == T && return T + return ancestor(p_T) + end @warn "`Adapt.parent_type` is not implemented for $(T). Assuming $T isn't a wrapped \ array." maxlog = 1 return T end +# A lot of packages don't define `Adapt.parent_type`. We use `_parent_type` as a way to +# define the parent type of an array without type-piracy. +function _parent_type end + include("accelerators/Accelerators.jl") using .Accelerators.TPU: has_tpu diff --git a/src/xla/PJRT/Buffer.jl b/src/xla/PJRT/Buffer.jl index d1e035482d..5fec870c44 100644 --- a/src/xla/PJRT/Buffer.jl +++ b/src/xla/PJRT/Buffer.jl @@ -167,7 +167,9 @@ function XLA.buffer_on_cpu(buffer::Buffer) end function XLA.to_host(buffer::Buffer, data, sharding) - GC.@preserve buffer begin + @assert data !== C_NULL + @assert buffer.buffer !== C_NULL + GC.@preserve buffer data begin @ccall MLIR.API.mlir_c.BufferToHost( buffer.buffer::Ptr{Cvoid}, data::Ptr{Cvoid} )::Cvoid From 93dd2e8b61634c790fe24821c30b9936aa11aaae Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 19 Aug 2025 20:37:04 -0400 Subject: [PATCH 2/2] fix: Ones/Zeros --- ext/ReactantFillArraysExt.jl | 58 +++++++++++++++++++++++++++------- test/Project.toml | 1 + test/integration/fillarrays.jl | 29 +++++++++++++++++ test/runtests.jl | 1 + 4 files changed, 77 insertions(+), 12 deletions(-) create mode 100644 test/integration/fillarrays.jl diff --git a/ext/ReactantFillArraysExt.jl b/ext/ReactantFillArraysExt.jl index 017fd8ca82..7f6d26be2d 100644 --- a/ext/ReactantFillArraysExt.jl +++ b/ext/ReactantFillArraysExt.jl @@ -9,16 +9,20 @@ using GPUArraysCore: @allowscalar Reactant._parent_type(T::Type{<:AbstractFill}) = T Reactant._parent_type(T::Type{<:OneElement}) = T -Base.@nospecializeinfer function Reactant.traced_type_inner( - @nospecialize(FA::Type{Fill{T,N,Axes}}), - seen, - mode::Reactant.TraceMode, - @nospecialize(track_numbers::Type), - @nospecialize(sharding), - @nospecialize(runtime) -) where {T,N,Axes} - # T will be a number so we need to trace it - return Fill{Reactant.traced_type_inner(T, seen, mode, Number, sharding, runtime),N,Axes} +for AT in (Fill, Ones, Zeros) + @eval Base.@nospecializeinfer function Reactant.traced_type_inner( + @nospecialize(FA::Type{$(AT){T,N,Axes}}), + seen, + mode::Reactant.TraceMode, + @nospecialize(track_numbers::Type), + @nospecialize(sharding), + @nospecialize(runtime) + ) where {T,N,Axes} + # T will be a number so we need to trace it + return $(AT){ + Reactant.traced_type_inner(T, seen, mode, Number, sharding, runtime),N,Axes + } + end end Base.@nospecializeinfer function Reactant.make_tracer( @@ -32,6 +36,34 @@ Base.@nospecializeinfer function Reactant.make_tracer( ) end +Base.@nospecializeinfer function Reactant.make_tracer( + seen, + @nospecialize(prev::Ones{T,N,Axes}), + @nospecialize(path), + mode; + @nospecialize(sharding = Sharding.NoSharding()), + @nospecialize(runtime = nothing), + kwargs..., +) where {T,N,Axes} + return Ones( + Reactant.traced_type_inner(T, seen, mode, Number, sharding, runtime), prev.axes + ) +end + +Base.@nospecializeinfer function Reactant.make_tracer( + seen, + @nospecialize(prev::Zeros{T,N,Axes}), + @nospecialize(path), + mode; + @nospecialize(sharding = Sharding.NoSharding()), + @nospecialize(runtime = nothing), + kwargs..., +) where {T,N,Axes} + return Zeros( + Reactant.traced_type_inner(T, seen, mode, Number, sharding, runtime), prev.axes + ) +end + Base.@nospecializeinfer function Reactant.traced_type_inner( @nospecialize(FA::Type{OneElement{T,N,I,A}}), seen, @@ -80,8 +112,10 @@ function ReactantCore.materialize_traced_array(x::OneElement{T}) where {T} end # some functions to avoid bad performance -function Base.similar(::OneElement{<:TracedRNumber}, ::Type{T}, dims::Dims) where {T} - return TracedUtils.broadcast_to_size(unwrapped_eltype(T)(0), dims) +for AT in (Fill, Ones, Zeros, OneElement) + @eval function Base.similar(x::$AT{<:TracedRNumber}, ::Type{T}, dims::Dims) where {T} + return TracedUtils.broadcast_to_size(unwrapped_eltype(T)(0), dims) + end end end diff --git a/test/Project.toml b/test/Project.toml index 8754fc59fe..35bfe7b725 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ DLFP8Types = "f4c16678-4a16-415b-82ef-ed337c5d6c7c" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" Float8s = "81dfefd7-55b0-40c6-a251-db853704e186" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" diff --git a/test/integration/fillarrays.jl b/test/integration/fillarrays.jl new file mode 100644 index 0000000000..f7d7359fe6 --- /dev/null +++ b/test/integration/fillarrays.jl @@ -0,0 +1,29 @@ +using Reactant, Test, FillArrays + +fn(x, y) = (2 .* x .- 3) * y' + +@testset "Fill" begin + x = Fill(2.0f0, 4, 5) + rx = Reactant.to_rarray(x) + + @test @jit(fn(rx, rx)) ≈ fn(x, x) + + @testset "Ones" begin + y = Ones(Float32, 4, 5) + ry = Reactant.to_rarray(y) + @test @jit(fn(rx, ry)) ≈ fn(x, y) + end + + @testset "Zeros" begin + y = Zeros(Float32, 4, 5) + ry = Reactant.to_rarray(y) + @test @jit(fn(rx, ry)) ≈ fn(x, y) + end +end + +@testset "OneElement" begin + x = OneElement(3.4f0, (3, 4), (32, 32)) + rx = Reactant.to_rarray(x) + + @test @jit(fn(rx, rx)) ≈ fn(x, x) +end diff --git a/test/runtests.jl b/test/runtests.jl index 411cf443ea..ae1ed0200e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -50,6 +50,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Random" include("integration/random.jl") @safetestset "Python" include("integration/python.jl") @safetestset "Optimisers" include("integration/optimisers.jl") + @safetestset "FillArrays" include("integration/fillarrays.jl") end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"