From 3f715d1bc1bd6e4ed8c2c00bc38b0027beede30e Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Sun, 21 Sep 2025 23:19:14 +0200 Subject: [PATCH 1/4] Make `RArray` a subtype of `DenseArray` --- src/Types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Types.jl b/src/Types.jl index cc257c4ebf..26bb998247 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -2,7 +2,7 @@ abstract type RNumber{T<:ReactantPrimitive} <: Number end abstract type AbstractConcreteNumber{T} <: RNumber{T} end -abstract type RArray{T,N} <: AbstractArray{T,N} end +abstract type RArray{T,N} <: DenseArray{T,N} end abstract type AbstractConcreteArray{T,N} <: RArray{T,N} end From 8988f3f70af7e976ffef798c4a6bbf5eeabeddb2 Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Mon, 22 Sep 2025 11:18:35 +0200 Subject: [PATCH 2/4] Make `AnyTracedRArray` a `DenseArray` subtype --- src/Types.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Types.jl b/src/Types.jl index 26bb998247..a5c50a9d9d 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -118,7 +118,7 @@ end @leaf TracedUnitRange Adapt.parent_type(::Type{TracedUnitRange{T}}) where {T} = TracedUnitRange{T} -const AnyTracedRArray{T,N} = AbstractArray{TracedRNumber{T},N} +const AnyTracedRArray{T,N} = DenseArray{TracedRNumber{T},N} const AnyTracedRVector{T} = AnyTracedRArray{T,1} const AnyTracedRMatrix{T} = AnyTracedRArray{T,2} const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}} From 04e3d42e2c259d37b269b61ac50336a091a141fc Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Mon, 22 Sep 2025 21:36:20 +0200 Subject: [PATCH 3/4] Fix permutedims errors --- src/TracedRArray.jl | 11 +++++++++-- src/Types.jl | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 916d0d13ee..3a82e5b66e 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -265,8 +265,15 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC return print(io, "TracedRArray{", T, ",", N, "N}(", X.paths, ", size=", size(X), ")") end -function Base.permutedims(A::AnyTracedRArray{T,N}, perm) where {T,N} - return @opcall transpose(materialize_traced_array(A), Int64[perm...]) +for ArrayType in ( + :(AnyTracedRArray{T,N}), + :(TracedRArray{T,N}), + :(SubArray{<:TracedRNumber{T},N,<:TracedRArray}), + :(Base.ReshapedArray{<:TracedRNumber{T},N,<:TracedRArray}) + ) + @eval function Base.permutedims(A::$ArrayType, perm) where {T,N} + return @opcall transpose(materialize_traced_array(A), Int64[perm...]) + end end for (jlop, hloop, hlocomp, merge) in diff --git a/src/Types.jl b/src/Types.jl index a5c50a9d9d..26bb998247 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -118,7 +118,7 @@ end @leaf TracedUnitRange Adapt.parent_type(::Type{TracedUnitRange{T}}) where {T} = TracedUnitRange{T} -const AnyTracedRArray{T,N} = DenseArray{TracedRNumber{T},N} +const AnyTracedRArray{T,N} = AbstractArray{TracedRNumber{T},N} const AnyTracedRVector{T} = AnyTracedRArray{T,1} const AnyTracedRMatrix{T} = AnyTracedRArray{T,2} const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}} From 68a6d9620172aeecd3f0931f8398c6a83ad2e97f Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Mon, 22 Sep 2025 23:45:19 +0200 Subject: [PATCH 4/4] Fix errors --- src/ConcreteRArray.jl | 2 +- src/TracedRArray.jl | 27 +++++++++++++++++++++++++++ src/Types.jl | 5 +++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 662889e49e..908bf429ef 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -200,7 +200,7 @@ for jlop in ( :(Base.:^), :(Base.:(==)), ), - T in (AbstractConcreteNumber, AbstractConcreteArray{<:Any,0}) + T in (AbstractConcreteNumber, AbstractConcreteArray{<:Number,0}) @eval begin $(jlop)(x::$(T), y::$(T)) = $(jlop)(to_number(x), to_number(y)) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 3a82e5b66e..565d158122 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -23,6 +23,33 @@ Base.strides(x::TracedRArray) = Base.size_to_strides(1, size(x)...) Base.IndexStyle(::Type{<:TracedRArray}) = Base.IndexLinear() +Base.elsize(::Type{TracedRArray{T,N}}) where {T,N} = sizeof(T) + +const ArrayTypesAlias = ( + :(TracedRArray), + :(SubArray{<:TracedRNumber,<:Any,<:TracedRArray}), + :(Base.ReshapedArray{<:TracedRNumber,<:Any,<:TracedRArray}), + :(SubArray{<:TracedRNumber,<:Any,<:Base.ReshapedArray{<:TracedRNumber,<:Any,<:TracedRArray}}), + :(Base.ReshapedArray{<:TracedRNumber,<:Any,<:SubArray{<:TracedRNumber,<:Any,<:TracedRArray}}), +) +for ArrayType1 in ArrayTypesAlias + for ArrayType2 in ArrayTypesAlias + @eval Base.mightalias(::$ArrayType1, ::$ArrayType2) = false + end +end + +# Base.mightalias(::TracedRArray, ::TracedRArray) = false +# Base.mightalias( +# ::SubArray{<:TracedRNumber,<:Any,<:TracedRArray}, +# ::SubArray{<:TracedRNumber,<:Any,<:TracedRArray}, +# ) = false +# Base.mightalias( +# ::SubArray{<:TracedRNumber,<:Any,<:TracedRArray}, ::TracedRArray +# ) = false +# Base.mightalias( +# ::TracedRArray, ::SubArray{<:TracedRNumber,<:Any,<:TracedRArray} +# ) = false + # This is required otherwise we will copy a tracedrarray each time # we use it Base.convert(T::Type{<:TracedRArray}, x::AbstractArray) = Reactant.promote_to(T, x) diff --git a/src/Types.jl b/src/Types.jl index 26bb998247..ab92991786 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -52,6 +52,11 @@ mutable struct TracedRNumber{T} <: RNumber{T} end end +Base.elsize(::Type{TracedRNumber{T}}) where {T} = sizeof(T) +Base.elsize(::Type{RNumber{T}}) where {T} = sizeof(T) +Base.elsize(::Type{<:AbstractConcreteNumber{T}}) where {T} = sizeof(T) +Base.elsize(::Type{<:AbstractConcreteArray{T}}) where {T} = sizeof(T) + function repath(x::TracedRNumber{T}, paths) where {T} return TracedRNumber{T}(paths, x.mlir_data) end