From 6aec75d3517942cc8ba13be10f762c02685561bb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 7 Dec 2024 12:44:22 +0530 Subject: [PATCH 01/18] fix: manually zero out the lower triangular and upper triangular values --- src/Ops.jl | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index 748d06f662..ceb6d5d962 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -13,6 +13,8 @@ using ..Reactant: mlir_type, mlir_stacktrace +using LinearAlgebra: triu!, tril! + struct Token mlir_data::MLIR.IR.Value end @@ -485,13 +487,20 @@ function cholesky( lower::Bool=false, location=mlir_stacktrace("cholesky", @__FILE__, @__LINE__), ) where {T,N} - lower = MLIR.IR.Attribute(lower) res = MLIR.IR.result( stablehlo.cholesky( - x.mlir_data; result=mlir_type(TracedRArray{T,N}, size(x)), lower, location + x.mlir_data; + result=mlir_type(TracedRArray{T,N}, size(x)), + lower=MLIR.IR.Attribute(lower), + location, ), ) - return TracedRArray{T,N}((), res, size(x)) + res = TracedRArray{T,N}((), res, size(x)) + + # See https://github.com/EnzymeAD/Reactant.jl/issues/338 for why we need to do this + lower ? tril!(res) : triu!(res) + + return res end function clamp( From ea32177a7abd54727c763f36fd94400e2dce6dd0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 7 Dec 2024 12:59:50 +0530 Subject: [PATCH 02/18] fix: only do it in tests --- src/ConcreteRArray.jl | 2 ++ src/Ops.jl | 9 +-------- src/Reactant.jl | 12 ++++++++++++ src/TracedRArray.jl | 8 -------- 4 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 42c454b31a..3febb4bcfa 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -8,6 +8,8 @@ mutable struct ConcreteRArray{T,N} <: RArray{T,N} shape::NTuple{N,Int} end +const WrappedConcreteRArray{T,N} = WrappedArray{T,N,ConcreteRArray,ConcreteRArray{T,N}} + mutable struct ConcreteRNumber{T} <: RNumber{T} data::XLA.AsyncBuffer end diff --git a/src/Ops.jl b/src/Ops.jl index ceb6d5d962..4bd92ceea9 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -13,8 +13,6 @@ using ..Reactant: mlir_type, mlir_stacktrace -using LinearAlgebra: triu!, tril! - struct Token mlir_data::MLIR.IR.Value end @@ -495,12 +493,7 @@ function cholesky( location, ), ) - res = TracedRArray{T,N}((), res, size(x)) - - # See https://github.com/EnzymeAD/Reactant.jl/issues/338 for why we need to do this - lower ? tril!(res) : triu!(res) - - return res + return TracedRArray{T,N}((), res, size(x)) end function clamp( diff --git a/src/Reactant.jl b/src/Reactant.jl index 0b73d3d96e..348a541f5d 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -89,6 +89,18 @@ function Enzyme.make_zero( return res end +function ancestor(x::AbstractArray) + p_x = parent(x) + p_x === x && return x + return ancestor(p_x) +end + +function get_ancestor_indices(x::AbstractArray, indices...) + p_x = parent(x) + p_x === x && return indices + return get_ancestor_indices(p_x, Base.reindex(parentindices(x), indices)...) +end + include("mlir/MLIR.jl") include("XLA.jl") include("Interpreter.jl") diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 97a29e56f9..0dd4b7261c 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -72,14 +72,6 @@ function set_mlir_data!(x::AnyTracedRArray, data) return x end -ancestor(x::TracedRArray) = x -ancestor(x::WrappedTracedRArray) = ancestor(parent(x)) - -get_ancestor_indices(::TracedRArray, indices...) = indices -function get_ancestor_indices(x::WrappedTracedRArray, indices...) - return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...) -end - function Base.getindex( a::TracedRArray{T,N}, index::Vararg{Union{Int,TracedRNumber{Int}},N} ) where {T,N} From 575dc75ded227497ff25aac435b850b68eb18d24 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 7 Dec 2024 13:13:53 +0530 Subject: [PATCH 03/18] revert: change in Ops.cholesky --- src/Ops.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index 4bd92ceea9..748d06f662 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -485,12 +485,10 @@ function cholesky( lower::Bool=false, location=mlir_stacktrace("cholesky", @__FILE__, @__LINE__), ) where {T,N} + lower = MLIR.IR.Attribute(lower) res = MLIR.IR.result( stablehlo.cholesky( - x.mlir_data; - result=mlir_type(TracedRArray{T,N}, size(x)), - lower=MLIR.IR.Attribute(lower), - location, + x.mlir_data; result=mlir_type(TracedRArray{T,N}, size(x)), lower, location ), ) return TracedRArray{T,N}((), res, size(x)) From fde9d335574e2d8d25a00a642352c8834bea5742 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 7 Dec 2024 13:15:15 +0530 Subject: [PATCH 04/18] revert: remove unnecessary changes --- src/ConcreteRArray.jl | 2 -- src/Reactant.jl | 12 ------------ src/TracedRArray.jl | 8 ++++++++ 3 files changed, 8 insertions(+), 14 deletions(-) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 3febb4bcfa..42c454b31a 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -8,8 +8,6 @@ mutable struct ConcreteRArray{T,N} <: RArray{T,N} shape::NTuple{N,Int} end -const WrappedConcreteRArray{T,N} = WrappedArray{T,N,ConcreteRArray,ConcreteRArray{T,N}} - mutable struct ConcreteRNumber{T} <: RNumber{T} data::XLA.AsyncBuffer end diff --git a/src/Reactant.jl b/src/Reactant.jl index 348a541f5d..0b73d3d96e 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -89,18 +89,6 @@ function Enzyme.make_zero( return res end -function ancestor(x::AbstractArray) - p_x = parent(x) - p_x === x && return x - return ancestor(p_x) -end - -function get_ancestor_indices(x::AbstractArray, indices...) - p_x = parent(x) - p_x === x && return indices - return get_ancestor_indices(p_x, Base.reindex(parentindices(x), indices)...) -end - include("mlir/MLIR.jl") include("XLA.jl") include("Interpreter.jl") diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 0dd4b7261c..97a29e56f9 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -72,6 +72,14 @@ function set_mlir_data!(x::AnyTracedRArray, data) return x end +ancestor(x::TracedRArray) = x +ancestor(x::WrappedTracedRArray) = ancestor(parent(x)) + +get_ancestor_indices(::TracedRArray, indices...) = indices +function get_ancestor_indices(x::WrappedTracedRArray, indices...) + return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...) +end + function Base.getindex( a::TracedRArray{T,N}, index::Vararg{Union{Int,TracedRNumber{Int}},N} ) where {T,N} From fb1459a942707470b1a6fb58f6ef93291f4fc1ed Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 8 Dec 2024 07:58:22 +0530 Subject: [PATCH 05/18] fix: preserve parent array tracking for reshape --- src/Ops.jl | 8 +++--- src/Reactant.jl | 4 --- src/TracedRArray.jl | 64 +++++++++++++-------------------------------- 3 files changed, 23 insertions(+), 53 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index 748d06f662..f1e7ea073b 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -253,9 +253,11 @@ function reshape( dims::Vector{Int}; location=mlir_stacktrace("reshape", @__FILE__, @__LINE__), ) where {T,N} - restype = mlir_type(TracedRArray{T,length(dims)}, dims) - res = MLIR.IR.result(stablehlo.reshape(x.mlir_data; result_0=restype, location)) - result = TracedRArray{T,length(dims)}((), res, dims) + # HLO reshape semantics collapse the opposite way + res1 = transpose(x, [N:-1:1...]) + restype = mlir_type(TracedRArray{T,length(dims)}, collect(Base.reverse(dims))) + res = MLIR.IR.result(stablehlo.reshape(res1.mlir_data; result_0=restype, location)) + result = TracedRArray{T,length(dims)}((), res, collect(Base.reverse(dims))) # NOTE this last `transpose` is required for consistency with Julia's column-major order # do not remove, as it will be optimized away by the compiler return transpose(result, [length(dims):-1:1...]) diff --git a/src/Reactant.jl b/src/Reactant.jl index 0b73d3d96e..06fd59affe 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -57,10 +57,6 @@ abstract type RNumber{T<:ReactantPrimitive} <: Number end Base.collect(A::RArray) = copy(A) -function Base.reshape(A::RArray, dims::Tuple{Vararg{Union{Int,Colon}}}) - return reshape(A, Base._reshape_uncolon(A, dims)) -end - function Enzyme.make_zero( ::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false) )::RT where {copy_if_inactive,RT<:RArray} diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 97a29e56f9..b0841dda6f 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -17,6 +17,13 @@ mutable struct TracedRArray{T,N} <: RArray{T,N} end end +function TracedRArray(data::MLIR.IR.Value) + data_type = MLIR.IR.type(data) + return TracedRArray{eltype(MLIR.IR.julia_type(data_type)),ndims(data_type)}( + (), data, size(data_type) + ) +end + ReactantCore.is_traced(::TracedRArray) = true new_traced_value(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), nothing, size(A)) @@ -55,6 +62,9 @@ const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}} materialize_traced_array(x::TracedRArray) = x materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...] +function materialize_traced_array(x::Base.ReshapedArray{T,N,<:TracedRArray}) where {T,N} + return Ops.reshape(parent(x), size(x)...) +end get_mlir_data(x::TracedRArray) = x.mlir_data get_mlir_data(x::AnyTracedRArray) = get_mlir_data(materialize_traced_array(x)) @@ -63,12 +73,14 @@ function set_mlir_data!(x::TracedRArray, data) x.mlir_data = data return x end +function set_mlir_data!(x::Base.ReshapedArray{T,N,<:TracedRArray}, data) where {T,N} + tdata = TracedRArray(data) + parent(x).mlir_data = Ops.reshape(tdata, size(parent(x))...).mlir_data + return x +end function set_mlir_data!(x::AnyTracedRArray, data) - data_type = MLIR.IR.type(data) - data = TracedRArray{eltype(MLIR.IR.julia_type(data_type)),ndims(data_type)}( - (), data, size(data_type) - ) - setindex!(x, data, axes(x)...) + tdata = TracedRArray(data) + setindex!(x, tdata, axes(x)...) return x end @@ -198,46 +210,6 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC # return print(io, X.mlir_data, ")") end -function Base.reshape(A::AnyTracedRArray{T,N}, dims::NTuple{NT,Int}) where {T,N,NT} - if prod(dims) != prod(size(A)) - throw( - DimensionMismatch( - "new shape $(dims) is incompatible with array size $(size(A))" - ), - ) - end - - # HLO reshape semantics collapse the opposite way - res1 = MLIR.IR.result( - MLIR.Dialects.stablehlo.transpose( - get_mlir_data(A); - permutation=MLIR.IR.DenseArrayAttribute([Int64(N - 1 - i) for i in 0:(N - 1)]), - ), - 1, - ) - - res2 = MLIR.IR.result( - MLIR.Dialects.stablehlo.reshape( - res1; - result_0=MLIR.IR.TensorType( - [Int64(i) for i in reverse(dims)], eltype(MLIR.IR.type(res1)) - ), - ), - ) - - res3 = MLIR.IR.result( - MLIR.Dialects.stablehlo.transpose( - res2; - permutation=MLIR.IR.DenseArrayAttribute([ - Int64(NT - 1 - i) for i in 0:(NT - 1) - ]), - ), - 1, - ) - - return TracedRArray{T,NT}((), res3, dims) -end - function Base.permutedims(A::AnyTracedRArray{T,N}, perm) where {T,N} return TracedRArray{T,N}( (), @@ -648,7 +620,7 @@ end function broadcast_to_size(arg::AnyTracedRArray, rsize) arg = materialize_traced_array(arg) - size(arg) == rsize && return arg + size(arg) == Tuple(rsize) && return arg return broadcast_to_size_internal(arg, rsize) end From 9df23d897fec1a2f854f7de4f766f33ad941e4eb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 8 Dec 2024 22:08:05 +0530 Subject: [PATCH 06/18] test: writing to a reshaped array --- test/wrapped_arrays.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test/wrapped_arrays.jl b/test/wrapped_arrays.jl index f2aafbe2cd..35c20f9e62 100644 --- a/test/wrapped_arrays.jl +++ b/test/wrapped_arrays.jl @@ -100,3 +100,16 @@ end y_ra = @jit(bypass_permutedims(x_ra)) @test @allowscalar(Array(y_ra)) ≈ bypass_permutedims(x) end + +function writeto_reshaped_array!(x) + z1 = similar(x) + z2 = reshape(z1, 1, 2, 3, 1) + @. z2 = 1.0 + return z1 +end + +@testset "writeto_reshaped_array!" begin + x = ConcreteRArray(rand(3, 2)) + y = @jit writeto_reshaped_array!(x) + @test all(isone, Array(y)) +end From b1504e99146c9994fc550557cb1ccb69d765ff8e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 8 Dec 2024 22:10:38 +0530 Subject: [PATCH 07/18] test: upsample_nearest --- test/nn/nnlib.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index 5f9c92ef74..295d256ef8 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -393,3 +393,10 @@ end NNlib.∇conv_filter(x, dy, conv_dims) end end + +@testset "Upsampling" begin + x = randn(Float32, 4, 4, 3, 2) + x_ra = Reactant.ConcreteRArray(x) + + @test @jit(NNlib.upsample_nearest(x_ra, (2, 2))) ≈ NNlib.upsample_nearest(x, (2, 2)) +end From 6d46c72177d24953a034b1925db817993a408dd8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 8 Dec 2024 22:59:24 +0530 Subject: [PATCH 08/18] fix: test failures due to wrappers --- ext/ReactantArrayInterfaceExt.jl | 4 ++-- src/Ops.jl | 4 ++-- src/TracedRArray.jl | 16 ++++++++++------ src/linear_algebra.jl | 2 +- 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/ext/ReactantArrayInterfaceExt.jl b/ext/ReactantArrayInterfaceExt.jl index f1d1659b93..ffd7ddc60e 100644 --- a/ext/ReactantArrayInterfaceExt.jl +++ b/ext/ReactantArrayInterfaceExt.jl @@ -2,7 +2,7 @@ module ReactantArrayInterfaceExt using ArrayInterface: ArrayInterface using Reactant: - Reactant, RArray, ConcreteRArray, ConcreteRNumber, TracedRNumber, TracedRArray + Reactant, RArray, ConcreteRArray, ConcreteRNumber, TracedRNumber, TracedRArray, Ops ArrayInterface.can_setindex(::Type{<:RArray}) = false ArrayInterface.fast_scalar_indexing(::Type{<:RArray}) = false @@ -14,7 +14,7 @@ function ArrayInterface.aos_to_soa(x::AbstractArray{<:ConcreteRNumber{T}}) where end function ArrayInterface.aos_to_soa(x::AbstractArray{<:TracedRNumber{T}}) where {T} - return reshape(vcat(x...), size(x)) + return Ops.reshape(vcat(x...), size(x)...) end end diff --git a/src/Ops.jl b/src/Ops.jl index f1e7ea073b..2f1c7f6eb9 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -254,13 +254,13 @@ function reshape( location=mlir_stacktrace("reshape", @__FILE__, @__LINE__), ) where {T,N} # HLO reshape semantics collapse the opposite way - res1 = transpose(x, [N:-1:1...]) + res1 = transpose(x, Int64[N:-1:1...]) restype = mlir_type(TracedRArray{T,length(dims)}, collect(Base.reverse(dims))) res = MLIR.IR.result(stablehlo.reshape(res1.mlir_data; result_0=restype, location)) result = TracedRArray{T,length(dims)}((), res, collect(Base.reverse(dims))) # NOTE this last `transpose` is required for consistency with Julia's column-major order # do not remove, as it will be optimized away by the compiler - return transpose(result, [length(dims):-1:1...]) + return transpose(result, Int64[length(dims):-1:1...]) end function get_dimension_size( diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index b0841dda6f..ec918223eb 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -17,6 +17,12 @@ mutable struct TracedRArray{T,N} <: RArray{T,N} end end +const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}} +const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}} +const AnyTracedRVector{T} = AnyTracedRArray{T,1} +const AnyTracedRMatrix{T} = AnyTracedRArray{T,2} +const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}} + function TracedRArray(data::MLIR.IR.Value) data_type = MLIR.IR.type(data) return TracedRArray{eltype(MLIR.IR.julia_type(data_type)),ndims(data_type)}( @@ -45,6 +51,10 @@ function TracedRArray{T,N}(rhs::TracedRArray{T0,N}) where {T,T0,N} end end +function TracedRArray{T,N}(rhs::WrappedTracedRArray{T0,N}) where {T0,T,N} + return TracedRArray{T,N}(materialize_traced_array(rhs)) +end + function TracedRArray{T,N}(rhs::AbstractArray{T0,N}) where {T0,T,N} attr = MLIR.IR.DenseElementsAttribute(collect(rhs)) return TracedRArray{T,N}( @@ -54,12 +64,6 @@ function TracedRArray{T,N}(rhs::AbstractArray{T0,N}) where {T0,T,N} ) end -const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}} -const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}} -const AnyTracedRVector{T} = AnyTracedRArray{T,1} -const AnyTracedRMatrix{T} = AnyTracedRArray{T,2} -const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}} - materialize_traced_array(x::TracedRArray) = x materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...] function materialize_traced_array(x::Base.ReshapedArray{T,N,<:TracedRArray}) where {T,N} diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index c7e72651d7..90dbd323ca 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -6,7 +6,7 @@ function LinearAlgebra.mul!( β::Number=false, ) where {T1,T2,T3} # TODO: The reshape operations are not getting optimized, we should directly call dot_general - rC = reshape(C, :, 1) + rC = Ops.reshape(C, length(C), 1) LinearAlgebra.mul!(rC, A, reshape(B, :, 1), α, β) C.mlir_data = get_mlir_data(vec(rC)) return C From f0bb779ad53d8f823769a3736a31b4d189e45311 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 8 Dec 2024 23:19:08 +0530 Subject: [PATCH 09/18] fix: handle lazy transpose/adjoint correctly --- src/TracedRArray.jl | 40 ++++++++++++++++++++++++++++------------ src/linear_algebra.jl | 6 ++++++ test/wrapped_arrays.jl | 26 ++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 12 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index ec918223eb..38005daf04 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -69,6 +69,14 @@ materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...] function materialize_traced_array(x::Base.ReshapedArray{T,N,<:TracedRArray}) where {T,N} return Ops.reshape(parent(x), size(x)...) end +function materialize_traced_array(x::LinearAlgebra.Transpose{T,<:TracedRArray{T}}) where {T} + px = parent(x) + A = ndims(px) == 1 ? reshape(px, :, 1) : px + return permutedims(A, (2, 1)) +end +function materialize_traced_array(x::LinearAlgebra.Adjoint{T,<:TracedRArray{T}}) where {T} + return conj(transpose(parent(x))) +end get_mlir_data(x::TracedRArray) = x.mlir_data get_mlir_data(x::AnyTracedRArray) = get_mlir_data(materialize_traced_array(x)) @@ -82,6 +90,26 @@ function set_mlir_data!(x::Base.ReshapedArray{T,N,<:TracedRArray}, data) where { parent(x).mlir_data = Ops.reshape(tdata, size(parent(x))...).mlir_data return x end +function set_mlir_data!(x::LinearAlgebra.Transpose{T,<:TracedRArray{T,N}}, data) where {T,N} + tdata = TracedRArray(data) + px = parent(x) + px.mlir_data = ( + if ndims(px) == 1 + Ops.reshape(tdata, length(tdata)) + else + Ops.transpose(tdata, [2, 1]) + end + ).mlir_data + return x +end +function set_mlir_data!(x::LinearAlgebra.Adjoint{T,<:TracedRArray{T,N}}, data) where {T,N} + tdata = TracedRArray(data) + px = parent(x) + transposed_data = + ndims(px) == 1 ? Ops.reshape(tdata, length(tdata)) : Ops.transpose(tdata, [2, 1]) + px.mlir_data = (T <: Real ? transposed_data : Ops.conj(transposed_data)).mlir_data + return x +end function set_mlir_data!(x::AnyTracedRArray, data) tdata = TracedRArray(data) setindex!(x, tdata, axes(x)...) @@ -279,12 +307,6 @@ function Base.imag(A::TracedRArray{Complex{T},N}) where {T,N} ) end -function Base.transpose(A::AnyTracedRVecOrMat) - A = ndims(A) == 1 ? reshape(A, :, 1) : A - return permutedims(A, (2, 1)) -end -Base.adjoint(A::AnyTracedRVecOrMat) = conj(transpose(A)) - promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N} = TracedRArray{T,N}(rhs) promote_to(::TracedRArray{T,N}, rhs) where {T,N} = promote_to(TracedRArray{T,N}, rhs) @@ -770,9 +792,3 @@ end Base.all(f::Function, x::TracedRArray) = mapreduce(f, &, x) Base.any(f::Function, x::TracedRArray) = mapreduce(f, |, x) - -# LinearAlgebra defines norm with some conditionals which cannot be traced directly -function LinearAlgebra.norm(x::TracedRArray{T,N}, p::Real=2) where {T,N} - isinf(p) && return maximum(abs, x) - return mapreduce(Base.Fix2(^, p), +, x)^(1 / p) -end diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index 90dbd323ca..85ed86930e 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -106,3 +106,9 @@ function LinearAlgebra.tril!(@nospecialize(X::TracedRArray{T,2}), k::Integer) wh X.mlir_data = Ops.select(idxs, X, zero(X)).mlir_data return X end + +# LinearAlgebra defines norm with some conditionals which cannot be traced directly +function LinearAlgebra.norm(x::TracedRArray{T,N}, p::Real=2) where {T,N} + isinf(p) && return maximum(abs, x) + return mapreduce(Base.Fix2(^, p), +, x)^(1 / p) +end diff --git a/test/wrapped_arrays.jl b/test/wrapped_arrays.jl index 35c20f9e62..100406084f 100644 --- a/test/wrapped_arrays.jl +++ b/test/wrapped_arrays.jl @@ -113,3 +113,29 @@ end y = @jit writeto_reshaped_array!(x) @test all(isone, Array(y)) end + +function write_to_transposed_array!(x) + z1 = similar(x) + z2 = transpose(z1) + @. z2 = 1.0 + return z1 +end + +@testset "write_to_transposed_array!" begin + x = ConcreteRArray(rand(3, 2)) + y = @jit write_to_transposed_array!(x) + @test all(isone, Array(y)) +end + +function write_to_adjoint_array!(x) + z1 = similar(x) + z2 = adjoint(z1) + @. z2 = 1.0 + return z1 +end + +@testset "write_to_adjoint_array!" begin + x = ConcreteRArray(rand(3, 2)) + y = @jit write_to_adjoint_array!(x) + @test all(isone, Array(y)) +end From 813cee92d9f62193919f5e5d1eaa97713d9fcf18 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 9 Dec 2024 10:51:53 +0530 Subject: [PATCH 10/18] fix: handle wrappers in NNlibExt correctly --- ext/ReactantNNlibExt.jl | 108 +++++++++++++++++----------------------- src/TracedRNumber.jl | 3 ++ 2 files changed, 50 insertions(+), 61 deletions(-) diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index b90fa60fb2..e7559d583c 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -3,7 +3,15 @@ module ReactantNNlibExt using NNlib using GPUArraysCore: @allowscalar using Reactant: - Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber + Reactant, + Ops, + TracedRArray, + AnyTracedRArray, + materialize_traced_array, + MLIR, + TracedRNumber, + get_mlir_data, + set_mlir_data! using ReactantCore: @trace using LinearAlgebra: LinearAlgebra, triu @@ -12,14 +20,7 @@ for (jlop, hloop) in ( (:(NNlib.sigmoid_fast), :logistic), (:(NNlib.sigmoid), :logistic), ) - @eval function $(jlop)(x::TracedRNumber{T}) where {T} - return TracedRNumber{T}( - (), - Reactant.MLIR.IR.result( - Reactant.MLIR.Dialects.stablehlo.$(hloop)(x.mlir_data), 1 - ), - ) - end + @eval $(jlop)(x::TracedRNumber) = Ops.$(hloop)(x) end function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N} @@ -82,13 +83,6 @@ function NNlib.conv!( kernel_input_dim = N - 1 kernel_output_dim = N - output_spatial_shapes = map(input_spatial_dims) do i - K = kernel_size[i] - pl, pr = padding[2i - 1], padding[2i] - d = dilation[i] - s = stride[i] - return (size(x, i) + pl + pr - d * (K - 1) - 1) ÷ s + 1 - end output_batch_dim = input_batch_dim output_feature_dim = input_feature_dim output_spatial_dims = input_spatial_dims @@ -119,8 +113,8 @@ function NNlib.conv!( end conv = Reactant.MLIR.Dialects.stablehlo.convolution( - x.mlir_data, - weight.mlir_data; + get_mlir_data(x), + get_mlir_data(weight); result_0=result_type, window_strides=collect(stride), padding, @@ -130,7 +124,7 @@ function NNlib.conv!( feature_group_count, batch_group_count=1, ) - y.mlir_data = Reactant.MLIR.IR.result(conv) + set_mlir_data!(y, Reactant.MLIR.IR.result(conv)) return y end @@ -165,7 +159,9 @@ function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N} output_shape = (output_spatial_shapes..., size(x, N - 1), size(x, N)) result_type = Reactant.MLIR.IR.TensorType(output_shape, Reactant.MLIR.IR.Type(T)) - unranked = Reactant.MLIR.IR.TensorType((), eltype(Reactant.MLIR.IR.type(x.mlir_data))) + unranked = Reactant.MLIR.IR.TensorType( + (), eltype(Reactant.MLIR.IR.type(get_mlir_data(x))) + ) body = let body = Reactant.MLIR.IR.Region(), loc = Reactant.MLIR.IR.Location(), @@ -189,7 +185,7 @@ function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N} Reactant.MLIR.Dialects.stablehlo.constant(; value=attr) ) reduction = Reactant.MLIR.Dialects.stablehlo.reduce_window( - [x.mlir_data], + [get_mlir_data(x)], [init_value]; result_0=[result_type], window_dimensions, @@ -205,10 +201,10 @@ end function NNlib.maxpool!( y::TracedRArray{T}, x::AnyTracedRArray, pdims::NNlib.PoolDims ) where {T} - y.mlir_data = - reduce_window( - Reactant.MLIR.Dialects.stablehlo.maximum, T.(x), pdims; init=typemin(T) - ).mlir_data + res = reduce_window( + Reactant.MLIR.Dialects.stablehlo.maximum, T.(x), pdims; init=typemin(T) + ) + set_mlir_data!(y, get_mlir_data(res)) return y end @@ -216,7 +212,7 @@ function NNlib.meanpool!( y::TracedRArray{T}, x::AnyTracedRArray, pdims::NNlib.PoolDims ) where {T} res = reduce_window(Reactant.MLIR.Dialects.stablehlo.add, T.(x), pdims; init=zero(T)) - y.mlir_data = (res ./ T(prod(NNlib.kernel_size(pdims)))).mlir_data + set_mlir_data!(y, get_mlir_data(res ./ T(prod(NNlib.kernel_size(pdims))))) return y end @@ -243,7 +239,7 @@ function NNlib.batched_mul!( B = max(size(x, 1), size(y, 1)) out_shape = (B, size(x, 2), size(y, 3)) - resty = MLIR.IR.TensorType(out_shape, eltype(MLIR.IR.type(res.mlir_data))) + resty = MLIR.IR.TensorType(out_shape, eltype(MLIR.IR.type(get_mlir_data(res)))) if size(x, 1) != size(y, 1) if size(x, 1) == 1 @@ -264,8 +260,8 @@ function NNlib.batched_mul!( (), MLIR.IR.result( MLIR.Dialects.stablehlo.dot_general( - x.mlir_data, - y.mlir_data; + get_mlir_data(x), + get_mlir_data(y); result_0=resty, dot_dimension_numbers=dot_dimension_numbers, precision_config=prec, @@ -274,28 +270,18 @@ function NNlib.batched_mul!( ), size(resty), ) - res.mlir_data = permutedims(tmp, (2, 3, 1)).mlir_data + set_mlir_data!(res, get_mlir_data(permutedims(tmp, (2, 3, 1)))) return res end function NNlib.pad_constant( - x::TracedRArray{T,N}, pad::NTuple{N,Tuple{Int,Int}}, value + x::AnyTracedRArray{T,N}, pad::NTuple{N,Tuple{Int,Int}}, value ) where {T,N} value = Reactant.promote_to(TracedRNumber{T}, value) - edge_padding_low = [i[1] for i in pad] - edge_padding_high = [i[2] for i in pad] - interior_padding = [0 for i in pad] - res = MLIR.IR.result( - MLIR.Dialects.stablehlo.pad( - x.mlir_data, - value.mlir_data; - edge_padding_low, - edge_padding_high, - interior_padding, - ), - 1, - ) - return TracedRArray{T,N}((), res, size(MLIR.IR.type(res))) + low = [i[1] for i in pad] + high = [i[2] for i in pad] + interior = [0 for i in pad] + return Ops.pad(materialize_traced_array(x), value; low, high, interior) end # XXX: reevaluate this manual optimization once @@ -305,7 +291,7 @@ function NNlib.gather!( src::AnyTracedRArray{T2,2}, idxs::Union{AbstractUnitRange{<:Number}}, ) where {T1,T2} - dst.mlir_data = src[:, idxs].mlir_data + set_mlir_data!(dst, get_mlir_data(src[:, idxs])) return dst end @@ -314,8 +300,8 @@ function NNlib.gather!( ) where {T1,T2} dims = NNlib.scatter_dims(src, dst, idxs) @assert dims == 1 # scatter_dims lets us do some size checks so we call that function - idxs = (Reactant.promote_to(TracedRArray{Int,1}, idxs) .- 1).mlir_data - slice_sizes = Reactant.promote_to(TracedRArray{Int,1}, [size(src, 1), 1]).mlir_data + idxs = get_mlir_data(Reactant.promote_to(TracedRArray{Int,1}, idxs) .- 1) + slice_sizes = get_mlir_data(Reactant.promote_to(TracedRArray{Int,1}, [size(src, 1), 1])) #! format: off dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet( @@ -331,11 +317,11 @@ function NNlib.gather!( res = MLIR.IR.result( Reactant.MLIR.Dialects.stablehlo.dynamic_gather( - src.mlir_data, idxs, slice_sizes; dimension_numbers + get_mlir_data(src), idxs, slice_sizes; dimension_numbers ), 1, ) - dst.mlir_data = res + set_mlir_data!(dst, res) return dst end @@ -354,7 +340,7 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr return reshape(res, start_sizes..., :) end res = reshape(cat(results...; dims=(dims + 1)), size(dst)) - dst.mlir_data = res.mlir_data + set_mlir_data!(dst, get_mlir_data(res)) return dst end @@ -363,7 +349,7 @@ dilate_shape(s, d) = max(0, 1 + d * (s - 1)) # see lax._conv_general_dilated_transpose_rhs # https://github.com/jax-ml/jax/blob/a1dfdc1d6164ad49afb337da9effd269d430d68b/jax/_src/lax/convolution.py#L495 function NNlib.∇conv_filter!( - dw::Reactant.TracedRArray{T,N}, + dw::TracedRArray{T,N}, x::AnyTracedRArray, dy::AnyTracedRArray, cdims::NNlib.DenseConvDims, @@ -437,8 +423,8 @@ function NNlib.∇conv_filter!( result_type = Reactant.MLIR.IR.TensorType(size(dw), Reactant.MLIR.IR.Type(T)) conv = MLIR.Dialects.stablehlo.convolution( - x.mlir_data, - dy.mlir_data; + get_mlir_data(x), + get_mlir_data(dy); result_0=result_type, window_strides=collect(dilation), padding, @@ -447,11 +433,12 @@ function NNlib.∇conv_filter!( feature_group_count, batch_group_count, ) - - dw.mlir_data = MLIR.IR.result(conv) + set_mlir_data!(dw, MLIR.IR.result(conv)) if !NNlib.flipkernel(cdims) - dw.mlir_data = Reactant.Ops.reverse(dw; dimensions=output_spatial_dims).mlir_data + set_mlir_data!( + dw, get_mlir_data(Reactant.Ops.reverse(dw; dimensions=output_spatial_dims)) + ) end return dw @@ -553,8 +540,8 @@ function NNlib.∇conv_data!( end conv = MLIR.Dialects.stablehlo.convolution( - dy.mlir_data, - w.mlir_data; + get_mlir_data(dy), + get_mlir_data(w); result_0=result_type, window_strides=1, padding, @@ -564,8 +551,7 @@ function NNlib.∇conv_data!( feature_group_count, batch_group_count=1, ) - - dx.mlir_data = MLIR.IR.result(conv) + set_mlir_data!(dx, MLIR.IR.result(conv)) return dx end diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index d3a6f3b990..a87ccb3340 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -12,6 +12,9 @@ mutable struct TracedRNumber{T} <: RNumber{T} end end +get_mlir_data(x::TracedRNumber) = x.mlir_data +set_mlir_data!(x::TracedRNumber, data) = (x.mlir_data = data) + ReactantCore.is_traced(::TracedRNumber) = true new_traced_value(::TracedRNumber{T}) where {T} = TracedRNumber{T}((), nothing) From 790e00e2ffdaefd03bcef13470cbb985aeed3b2a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 9 Dec 2024 11:49:22 +0530 Subject: [PATCH 11/18] fix: more reshaped wrappers handling --- ext/ReactantNNlibExt.jl | 45 +++++++++++++++------------------- src/TracedRArray.jl | 24 +++++++++++++++---- src/linear_algebra.jl | 53 ++++++++++------------------------------- 3 files changed, 50 insertions(+), 72 deletions(-) diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index e7559d583c..8bfa5de02a 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -216,9 +216,9 @@ function NNlib.meanpool!( return y end -NNlib.batched_transpose(x::AnyTracedRArray{T,3}) where {T} = permutedims(x, (2, 1, 3)) +NNlib.batched_transpose(x::AnyTracedRArray{T,3}) where {T} = PermutedDimsArray(x, (2, 1, 3)) function NNlib.batched_adjoint(x::AnyTracedRArray{T,3}) where {T} - y = permutedims(x, (2, 1, 3)) + y = NNlib.batched_transpose(x) conj!(y) return y end @@ -234,14 +234,21 @@ function NNlib.batched_mul!( ), ) end + + if size(x, 3) != size(y, 3) + B = max(size(x, 3), size(y, 3)) + if size(x, 3) == 1 + x = Reactant.broadcast_to_size(x, (size(x, 1), size(x, 2), B)) + elseif size(y, 3) == 1 + y = Reactant.broadcast_to_size(y, (size(y, 1), size(y, 2), B)) + end + end + x = permutedims(x, (3, 1, 2)) y = permutedims(y, (3, 1, 2)) - B = max(size(x, 1), size(y, 1)) - out_shape = (B, size(x, 2), size(y, 3)) - resty = MLIR.IR.TensorType(out_shape, eltype(MLIR.IR.type(get_mlir_data(res)))) - if size(x, 1) != size(y, 1) + B = max(size(x, 1), size(y, 1)) if size(x, 1) == 1 x = Reactant.broadcast_to_size(x, (B, size(x, 2), size(x, 3))) elseif size(y, 1) == 1 @@ -249,28 +256,14 @@ function NNlib.batched_mul!( end end - dot_dimension_numbers = MLIR.API.stablehloDotDimensionNumbersGet( - MLIR.IR.context(), 1, [0], 1, [0], 1, [2], 1, [1] - ) - - prec = MLIR.IR.Attribute( - MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT") - ) - tmp = TracedRArray{T1,3}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.dot_general( - get_mlir_data(x), - get_mlir_data(y); - result_0=resty, - dot_dimension_numbers=dot_dimension_numbers, - precision_config=prec, - ), - 1, - ), - size(resty), + tmp = Ops.dot_general( + T1.(materialize_traced_array(x)), + T1.(materialize_traced_array(y)); + contracting_dimensions=([3], [2]), + batching_dimensions=([1], [1]), ) set_mlir_data!(res, get_mlir_data(permutedims(tmp, (2, 3, 1)))) + return res end diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 38005daf04..f91ba0c77d 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -66,8 +66,10 @@ end materialize_traced_array(x::TracedRArray) = x materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...] -function materialize_traced_array(x::Base.ReshapedArray{T,N,<:TracedRArray}) where {T,N} - return Ops.reshape(parent(x), size(x)...) +function materialize_traced_array( + x::Adapt.WrappedReshapedArray{T,N,<:TracedRArray} +) where {T,N} + return Ops.reshape(materialize_traced_array(parent(x)), size(x)...) end function materialize_traced_array(x::LinearAlgebra.Transpose{T,<:TracedRArray{T}}) where {T} px = parent(x) @@ -77,6 +79,11 @@ end function materialize_traced_array(x::LinearAlgebra.Adjoint{T,<:TracedRArray{T}}) where {T} return conj(transpose(parent(x))) end +function materialize_traced_array( + x::PermutedDimsArray{T,N,perm,iperm,<:TracedRArray{T,N}} +) where {T,N,perm,iperm} + return permutedims(parent(x), perm) +end get_mlir_data(x::TracedRArray) = x.mlir_data get_mlir_data(x::AnyTracedRArray) = get_mlir_data(materialize_traced_array(x)) @@ -85,9 +92,9 @@ function set_mlir_data!(x::TracedRArray, data) x.mlir_data = data return x end -function set_mlir_data!(x::Base.ReshapedArray{T,N,<:TracedRArray}, data) where {T,N} - tdata = TracedRArray(data) - parent(x).mlir_data = Ops.reshape(tdata, size(parent(x))...).mlir_data +function set_mlir_data!(x::Adapt.WrappedReshapedArray{T,N,<:TracedRArray}, data) where {T,N} + res_mlir_data = Ops.reshape(TracedRArray(data), size(parent(x))...).mlir_data + set_mlir_data!(parent(x), res_mlir_data) return x end function set_mlir_data!(x::LinearAlgebra.Transpose{T,<:TracedRArray{T,N}}, data) where {T,N} @@ -110,6 +117,13 @@ function set_mlir_data!(x::LinearAlgebra.Adjoint{T,<:TracedRArray{T,N}}, data) w px.mlir_data = (T <: Real ? transposed_data : Ops.conj(transposed_data)).mlir_data return x end +function set_mlir_data!( + x::PermutedDimsArray{T,N,perm,iperm,<:TracedRArray{T,N}}, data +) where {T,N,perm,iperm} + tdata = TracedRArray(data) + parent(x).mlir_data = permutedims(tdata, iperm) + return x +end function set_mlir_data!(x::AnyTracedRArray, data) tdata = TracedRArray(data) setindex!(x, tdata, axes(x)...) diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index 85ed86930e..7ef16279b6 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -40,50 +40,21 @@ function LinearAlgebra.mul!( if size(A, 2) != size(B, 1) throw(DimensionMismatch("A has size $(size(A)), B has size $(size(B))")) end - resty = MLIR.IR.TensorType(size(C), MLIR.IR.Type(T1)) - dot_dimension_numbers = MLIR.API.stablehloDotDimensionNumbersGet( - MLIR.IR.context(), 0, [], 0, [], 1, [1], 1, [0] - ) - prec = MLIR.IR.Attribute( - MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT") - ) - precar = MLIR.IR.Attribute([prec, prec]) - res = MLIR.IR.result( - MLIR.Dialects.stablehlo.dot_general( - get_mlir_data(A), - get_mlir_data(B); - result_0=resty, - dot_dimension_numbers=dot_dimension_numbers, - precision_config=precar, - ), - 1, + + tmp = Ops.dot_general( + T1.(materialize_traced_array(A)), + T1.(materialize_traced_array(B)); + contracting_dimensions=([2], [1]), ) - if iszero(β) - if isone(α) - C.mlir_data = res - else - C.mlir_data = MLIR.IR.result( - MLIR.Dialects.stablehlo.multiply( - res, broadcast_to_size(T1(α), size(C)).mlir_data - ), - 1, - ) - end + + res = if iszero(β) + isone(α) ? tmp : Ops.multiply(tmp, broadcast_to_size(T1(α), size(C))) else - α_res = MLIR.IR.result( - MLIR.Dialects.stablehlo.multiply( - res, broadcast_to_size(T1(α), size(C)).mlir_data - ), - 1, - ) - β_C = MLIR.IR.result( - MLIR.Dialects.stablehlo.multiply( - C.mlir_data, broadcast_to_size(T1(β), size(C)).mlir_data - ), - 1, - ) - C.mlir_data = MLIR.IR.result(MLIR.Dialects.stablehlo.add(α_res, β_C), 1) + α_res = Ops.multiply(tmp, broadcast_to_size(T1(α), size(C))) + β_C = Ops.multiply(C, broadcast_to_size(T1(β), size(C))) + Ops.add(α_res, β_C) end + set_mlir_data!(C, get_mlir_data(res)) return C end From 3e3d234ad4df60f55a812a327f19e7f0781fc2eb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 9 Dec 2024 12:21:47 +0530 Subject: [PATCH 12/18] fix: dispatches to avoid ambiguity --- src/TracedRArray.jl | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index f91ba0c77d..e1c9e893c0 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -71,13 +71,15 @@ function materialize_traced_array( ) where {T,N} return Ops.reshape(materialize_traced_array(parent(x)), size(x)...) end -function materialize_traced_array(x::LinearAlgebra.Transpose{T,<:TracedRArray{T}}) where {T} +function materialize_traced_array( + x::LinearAlgebra.Transpose{T,TracedRArray{T,N}} +) where {T,N} px = parent(x) A = ndims(px) == 1 ? reshape(px, :, 1) : px return permutedims(A, (2, 1)) end -function materialize_traced_array(x::LinearAlgebra.Adjoint{T,<:TracedRArray{T}}) where {T} - return conj(transpose(parent(x))) +function materialize_traced_array(x::LinearAlgebra.Adjoint{T,TracedRArray{T,N}}) where {T,N} + return conj(materialize_traced_array(transpose(parent(x)))) end function materialize_traced_array( x::PermutedDimsArray{T,N,perm,iperm,<:TracedRArray{T,N}} @@ -97,7 +99,7 @@ function set_mlir_data!(x::Adapt.WrappedReshapedArray{T,N,<:TracedRArray}, data) set_mlir_data!(parent(x), res_mlir_data) return x end -function set_mlir_data!(x::LinearAlgebra.Transpose{T,<:TracedRArray{T,N}}, data) where {T,N} +function set_mlir_data!(x::LinearAlgebra.Transpose{T,TracedRArray{T,N}}, data) where {T,N} tdata = TracedRArray(data) px = parent(x) px.mlir_data = ( @@ -109,7 +111,7 @@ function set_mlir_data!(x::LinearAlgebra.Transpose{T,<:TracedRArray{T,N}}, data) ).mlir_data return x end -function set_mlir_data!(x::LinearAlgebra.Adjoint{T,<:TracedRArray{T,N}}, data) where {T,N} +function set_mlir_data!(x::LinearAlgebra.Adjoint{T,TracedRArray{T,N}}, data) where {T,N} tdata = TracedRArray(data) px = parent(x) transposed_data = @@ -118,7 +120,7 @@ function set_mlir_data!(x::LinearAlgebra.Adjoint{T,<:TracedRArray{T,N}}, data) w return x end function set_mlir_data!( - x::PermutedDimsArray{T,N,perm,iperm,<:TracedRArray{T,N}}, data + x::PermutedDimsArray{T,N,perm,iperm,TracedRArray{T,N}}, data ) where {T,N,perm,iperm} tdata = TracedRArray(data) parent(x).mlir_data = permutedims(tdata, iperm) @@ -562,10 +564,10 @@ function BroadcastStyle(::Type{<:AnyTracedRArray{T,N}}) where {T,N} end function Base.similar( - bc::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims + ::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims ) where {T<:ReactantPrimitive,N} @assert N isa Int - return TracedRArray{T,N}((), nothing, map(length, dims)) + return TracedRArray{T,length(dims)}((), nothing, map(length, dims)) end function Base.similar( From 284248d6d2889ebab6a892627692f022235e595e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 9 Dec 2024 13:46:41 +0530 Subject: [PATCH 13/18] fix: handle diagonal wrapper gracefully --- src/TracedRArray.jl | 17 ++++-- src/linear_algebra.jl | 93 ++++++++++++++++++++++++------ test/integration/linear_algebra.jl | 29 ++++++++++ test/wrapped_arrays.jl | 44 +++++++++++++- 4 files changed, 160 insertions(+), 23 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index e1c9e893c0..0e9bf6f779 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -20,7 +20,9 @@ end const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}} const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}} const AnyTracedRVector{T} = AnyTracedRArray{T,1} -const AnyTracedRMatrix{T} = AnyTracedRArray{T,2} +const AnyTracedRMatrix{T} = Union{ + AnyTracedRArray{T,2},LinearAlgebra.Diagonal{T,TracedRArray{T,1}} +} const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}} function TracedRArray(data::MLIR.IR.Value) @@ -86,6 +88,9 @@ function materialize_traced_array( ) where {T,N,perm,iperm} return permutedims(parent(x), perm) end +function materialize_traced_array(x::LinearAlgebra.Diagonal{T,TracedRArray{T,1}}) where {T} + return LinearAlgebra.diagm(parent(x)) +end get_mlir_data(x::TracedRArray) = x.mlir_data get_mlir_data(x::AnyTracedRArray) = get_mlir_data(materialize_traced_array(x)) @@ -122,13 +127,15 @@ end function set_mlir_data!( x::PermutedDimsArray{T,N,perm,iperm,TracedRArray{T,N}}, data ) where {T,N,perm,iperm} - tdata = TracedRArray(data) - parent(x).mlir_data = permutedims(tdata, iperm) + parent(x).mlir_data = permutedims(TracedRArray(data), iperm).mlir_data + return x +end +function set_mlir_data!(x::LinearAlgebra.Diagonal{T,TracedRArray{T,1}}, data) where {T} + parent(x).mlir_data = LinearAlgebra.diag(TracedRArray(data)).mlir_data return x end function set_mlir_data!(x::AnyTracedRArray, data) - tdata = TracedRArray(data) - setindex!(x, tdata, axes(x)...) + setindex!(x, TracedRArray(data), axes(x)...) return x end diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index 7ef16279b6..7451952173 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -1,10 +1,10 @@ function LinearAlgebra.mul!( - @nospecialize(C::TracedRArray{T1,1}), - @nospecialize(A::AnyTracedRArray{T2,2}), - @nospecialize(B::AnyTracedRArray{T3,1}), + @nospecialize(C::TracedRArray{T,1}), + @nospecialize(A::AnyTracedRMatrix), + @nospecialize(B::AnyTracedRVector), α::Number=true, β::Number=false, -) where {T1,T2,T3} +) where {T} # TODO: The reshape operations are not getting optimized, we should directly call dot_general rC = Ops.reshape(C, length(C), 1) LinearAlgebra.mul!(rC, A, reshape(B, :, 1), α, β) @@ -13,23 +13,23 @@ function LinearAlgebra.mul!( end function LinearAlgebra.mul!( - @nospecialize(C::TracedRArray{T1,2}), - @nospecialize(A::AnyTracedRArray{T2,2}), - @nospecialize(B::AnyTracedRArray{T3,1}), + @nospecialize(C::TracedRArray{T,2}), + @nospecialize(A::AnyTracedRMatrix), + @nospecialize(B::AnyTracedRVector), α::Number=true, β::Number=false, -) where {T1,T2,T3} +) where {T} LinearAlgebra.mul!(C, A, reshape(B, :, 1), α, β) return C end function LinearAlgebra.mul!( - @nospecialize(C::TracedRArray{T1,2}), - @nospecialize(A::AnyTracedRArray{T2,2}), - @nospecialize(B::AnyTracedRArray{T3,2}), + @nospecialize(C::TracedRArray{T,2}), + @nospecialize(A::AnyTracedRMatrix), + @nospecialize(B::AnyTracedRMatrix), α::Number=true, β::Number=false, -) where {T1,T2,T3} +) where {T} if size(C) != (size(A, 1), size(B, 2)) throw( DimensionMismatch( @@ -42,16 +42,16 @@ function LinearAlgebra.mul!( end tmp = Ops.dot_general( - T1.(materialize_traced_array(A)), - T1.(materialize_traced_array(B)); + T.(materialize_traced_array(A)), + T.(materialize_traced_array(B)); contracting_dimensions=([2], [1]), ) res = if iszero(β) - isone(α) ? tmp : Ops.multiply(tmp, broadcast_to_size(T1(α), size(C))) + isone(α) ? tmp : Ops.multiply(tmp, broadcast_to_size(T(α), size(C))) else - α_res = Ops.multiply(tmp, broadcast_to_size(T1(α), size(C))) - β_C = Ops.multiply(C, broadcast_to_size(T1(β), size(C))) + α_res = Ops.multiply(tmp, broadcast_to_size(T(α), size(C))) + β_C = Ops.multiply(C, broadcast_to_size(T(β), size(C))) Ops.add(α_res, β_C) end set_mlir_data!(C, get_mlir_data(res)) @@ -83,3 +83,62 @@ function LinearAlgebra.norm(x::TracedRArray{T,N}, p::Real=2) where {T,N} isinf(p) && return maximum(abs, x) return mapreduce(Base.Fix2(^, p), +, x)^(1 / p) end + +function LinearAlgebra.diag(x::AnyTracedRArray{T,2}, k::Integer=0) where {T} + y = materialize_traced_array(x) + + rows, cols = size(y) + (start_row, start_col) = k ≥ 0 ? (0, k) : (-k, 0) + diag_length = min(rows - start_row, cols - start_col) + + indices = stack(( + start_row:(start_row + diag_length - 1), start_col:(start_col + diag_length - 1) + )) + + # XXX: creating an empty array causes + # terminate called after throwing an instance of 'xla::XlaRuntimeError' + # what(): UNKNOWN: :0: error: 'tensor.empty' op unsupported op for export to XLA + # :0: note: see current operation: %0 = "tensor.empty"() : () -> tensor<0xf64> + length(indices) ≤ 0 && return promote_to(TracedRArray{T,1}, T[]) + + idxs = get_mlir_data(Reactant.promote_to(TracedRArray{Int,2}, indices)) + + #! format: off + dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet( + MLIR.IR.context(), + Int64(0), Int64[], + Int64(2), Int64[0, 1], + Int64(0), Int64[], + Int64(0), Int64[], + Int64(2), Int64[0, 1], + Int64(1) + ) + #! format: on + + slice_sizes = get_mlir_data(Reactant.promote_to(TracedRArray{Int,1}, [1, 1])) + res = MLIR.IR.result( + MLIR.Dialects.stablehlo.dynamic_gather( + get_mlir_data(y), idxs, slice_sizes; dimension_numbers + ), + 1, + ) + return TracedRArray{T,1}((), res, (diag_length,)) +end + +function LinearAlgebra.diagm(v::AnyTracedRArray{T,1}) where {T} + return LinearAlgebra.diagm(length(v), length(v), v) +end +function LinearAlgebra.diagm(m::Integer, n::Integer, v::AnyTracedRArray{T,1}) where {T} + m, n = LinearAlgebra.diagm_size((m, n), 0 => v) # size check + + v = materialize_traced_array(v) + D = length(v) + row_idxs = Ops.iota(Int, [D, D]; iota_dimension=1) + col_idxs = Ops.iota(Int, [D, D]; iota_dimension=2) + diag_indicator = Ops.compare(row_idxs, col_idxs; comparison_direction="EQ") + + mat = (v .+ zero(v)') .* diag_indicator + return Ops.pad( + mat, promote_to(TracedRNumber{T}, 0); high=[m - length(v), n - length(v)] + ) +end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 22fe07c1f6..0c6efc5fdb 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -113,3 +113,32 @@ end @jit(tril!(A_ra, -1)) @test A_ra ≈ tril(A, -1) end + +@testset "diag / diagm" begin + x = rand(2, 4) + x_ra = Reactant.to_rarray(x) + + @testset for k in (-size(x, 1) + 1):(size(x, 1) - 1) + @test @jit(diag(x_ra, k)) ≈ diag(x, k) + end + + x = rand(4) + x_ra = Reactant.to_rarray(x) + + @test @jit(diagm(x_ra)) ≈ diagm(x) + @test @jit(diagm(5, 4, x_ra)) ≈ diagm(5, 4, x) + @test @jit(diagm(4, 5, x_ra)) ≈ diagm(4, 5, x) + @test @jit(diagm(6, 6, x_ra)) ≈ diagm(6, 6, x) + @test_throws DimensionMismatch @jit(diagm(3, 3, x_ra)) +end + +# TODO: Currently Diagonal(x) * x goes down the generic matmul path but it should clearly be +# optimized +mul_diagonal(x) = Diagonal(x) * x + +@testset "mul_diagonal" begin + x = rand(4) + x_ra = Reactant.to_rarray(x) + + @test @jit(mul_diagonal(x_ra)) ≈ mul_diagonal(x) +end diff --git a/test/wrapped_arrays.jl b/test/wrapped_arrays.jl index 100406084f..f8726f0337 100644 --- a/test/wrapped_arrays.jl +++ b/test/wrapped_arrays.jl @@ -1,4 +1,4 @@ -using Reactant, Test, Statistics, NNlib +using Reactant, Test, Statistics, NNlib, LinearAlgebra function view_getindex_1(x) x = view(x, 2:3, 1:2, :) @@ -139,3 +139,45 @@ end y = @jit write_to_adjoint_array!(x) @test all(isone, Array(y)) end + +add_perm_dims(x) = x .+ PermutedDimsArray(x, (2, 1)) + +@testset "add_perm_dims" begin + x = rand(4, 4) + x_ra = Reactant.to_rarray(x) + + @test @jit(add_perm_dims(x_ra)) ≈ add_perm_dims(x) +end + +function write_to_permuted_dims_array!(x) + z1 = similar(x) + z2 = PermutedDimsArray(z1, (2, 1)) + @. z2 = 1.0 + return z1 +end + +@testset "write_to_permuted_dims_array!" begin + x = rand(4, 4) + x_ra = Reactant.to_rarray(x) + + @test @jit(write_to_permuted_dims_array!(x_ra)) ≈ write_to_permuted_dims_array!(x) +end + +function write_to_diagonal_array!(x) + z = Diagonal(x) + @. z = 1.0 + return z +end + +@testset "write_to_diagonal_array!" begin + x = rand(4, 4) + x_ra = Reactant.to_rarray(x) + y_ra = copy(x_ra) + + y = @jit(write_to_diagonal_array!(x_ra)) + y_res = @allowscalar Array(y) + @test x_ra ≈ y_ra + @test all(isone, diag(y_res)) + y_res[diagind(y_res)] .= 0 + @test all(iszero, y_res) +end From 2f22d40c75970766723eb3721ac82d00818c2ae5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 9 Dec 2024 15:11:41 +0530 Subject: [PATCH 14/18] fix: compile wrapped concrete array conversion to arrays --- src/ConcreteRArray.jl | 30 +++++++++++++----------------- test/wrapped_arrays.jl | 4 ++-- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 42c454b31a..0135f663ae 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -8,6 +8,9 @@ mutable struct ConcreteRArray{T,N} <: RArray{T,N} shape::NTuple{N,Int} end +const WrappedConcreteRArray{T,N} = WrappedArray{T,N,ConcreteRArray,ConcreteRArray{T,N}} +const AnyConcreteRArray{T,N} = Union{ConcreteRArray{T,N},WrappedConcreteRArray{T,N}} + mutable struct ConcreteRNumber{T} <: RNumber{T} data::XLA.AsyncBuffer end @@ -76,19 +79,6 @@ end Base.size(x::ConcreteRArray) = x.shape -function Base.reshape(A::ConcreteRArray{T,N}, dims::NTuple{NT,Int}) where {T,N,NT} - prod(dims) == prod(size(A)) || Base._throw_dmrsa(dims, prod(size(A))) - host = convert(Array{T,N}, A) - # HLO reshape semantics collapse the opposite so enforce on Julia Side - # until we later make the transpose/reshape/transpose - host = reshape(host, dims) - client = XLA.client(A.data) - device = XLA.device(A.data) - buffer = XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, host, device), nothing) - return ConcreteRArray{T,NT}(buffer, dims) - # ConcreteRArray{T, dims, NT}(XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, XLA.to_row_major(host), device), nothing)) -end - function Base.convert(::Type{T}, X::ConcreteRArray{ElType,N}) where {T<:Array,ElType,N} data = Array{ElType,N}(undef, size(X)...) # TODO replace for `similar`? XLA.await(X.data) @@ -99,7 +89,13 @@ function Base.convert(::Type{T}, X::ConcreteRArray{ElType,N}) where {T<:Array,El return data # XLA.from_row_major(data) end -Base.Array(x::ConcreteRArray) = convert(Array, x) +function Base.convert( + ::Type{T}, X::WrappedConcreteRArray{ElType,N} +) where {T<:Array,ElType,N} + fn = compile(materialize_traced_array, (X,)) + return convert(Array, fn(X)) +end +Base.Array(x::AnyConcreteRArray) = convert(Array, x) function synchronize(x::Union{ConcreteRArray,ConcreteRNumber}) XLA.synced_buffer(x.data) @@ -264,7 +260,7 @@ function Base.setindex!(a::ConcreteRArray{T}, v, args::Vararg{Int,N}) where {T,N end GPUArraysCore.assertscalar("setindex!(::ConcreteRArray, ::Any, ::Vararg{Int, N})") - fn = Reactant.compile(mysetindex!, (a, v, args...)) + fn = compile(mysetindex!, (a, v, args...)) fn(a, v, args...) return a end @@ -307,7 +303,7 @@ function Base.copy(bc::Base.Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteR return ConcreteRArray(aux) end - fn = Reactant.compile(Broadcast.BroadcastFunction(bc.f), (bc.args...,)) + fn = compile(Broadcast.BroadcastFunction(bc.f), (bc.args...,)) return fn(bc.args...) end @@ -323,7 +319,7 @@ function Base.mapreduce( dims=:, init=nothing, ) where {T,N} - fn = Reactant.compile(CallMapReduce(f, op, dims, init), (A,)) + fn = compile(CallMapReduce(f, op, dims, init), (A,)) return fn(A) end diff --git a/test/wrapped_arrays.jl b/test/wrapped_arrays.jl index f8726f0337..82390fc972 100644 --- a/test/wrapped_arrays.jl +++ b/test/wrapped_arrays.jl @@ -21,8 +21,8 @@ end x_ra = Reactant.to_rarray(x) @test @allowscalar(@jit(view_getindex_1(x_ra))) ≈ view_getindex_1(x) - @test @jit(view_getindex_2(x_ra)) ≈ view_getindex_2(x) - @test @jit(view_getindex_3(x_ra)) ≈ view_getindex_3(x) + @test Array(@jit(view_getindex_2(x_ra))) ≈ view_getindex_2(x) + @test Array(@jit(view_getindex_3(x_ra))) ≈ view_getindex_3(x) end function reshape_wrapper(x) From f47b123764f391d620dd7af9d1bbe5862497706a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 9 Dec 2024 15:46:01 +0530 Subject: [PATCH 15/18] feat: more wrapped ConcreteRArray handling --- src/ConcreteRArray.jl | 21 +++++++++++++++------ test/ops.jl | 4 ++-- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 0135f663ae..ceb0844026 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -77,6 +77,13 @@ function ConcreteRArray( ) end +ConcreteRArray(x::AnyConcreteRArray) = ConcreteRArray{eltype(x),ndims(x)}(x) +ConcreteRArray{T}(x::AnyConcreteRArray) where {T} = ConcreteRArray{T,ndims(x)}(x) +ConcreteRArray{T,N}(x::ConcreteRArray{T,N}) where {T,N} = x +function ConcreteRArray{T,N}(x::AnyConcreteRArray) where {T,N} + return ConcreteRArray(convert(Array{T,N}, x)) +end + Base.size(x::ConcreteRArray) = x.shape function Base.convert(::Type{T}, X::ConcreteRArray{ElType,N}) where {T<:Array,ElType,N} @@ -161,19 +168,21 @@ for T in (ConcreteRNumber, ConcreteRArray{<:Any,0}) end end -function Base.isapprox(x::ConcreteRArray, y::AbstractArray; kwargs...) +function Base.isapprox(x::AnyConcreteRArray, y::AbstractArray; kwargs...) return Base.isapprox(convert(Array, x), convert(Array, y); kwargs...) end -function Base.isapprox(x::AbstractArray, y::ConcreteRArray; kwargs...) +function Base.isapprox(x::AbstractArray, y::AnyConcreteRArray; kwargs...) return Base.isapprox(convert(Array, x), convert(Array, y); kwargs...) end -function Base.isapprox(x::ConcreteRArray, y::ConcreteRArray; kwargs...) +function Base.isapprox(x::AnyConcreteRArray, y::AnyConcreteRArray; kwargs...) return Base.isapprox(convert(Array, x), convert(Array, y); kwargs...) end -Base.:(==)(x::ConcreteRArray, y::AbstractArray) = convert(Array, x) == convert(Array, y) -Base.:(==)(x::AbstractArray, y::ConcreteRArray) = convert(Array, x) == convert(Array, y) -Base.:(==)(x::ConcreteRArray, y::ConcreteRArray) = convert(Array, x) == convert(Array, y) +Base.:(==)(x::AnyConcreteRArray, y::AbstractArray) = convert(Array, x) == convert(Array, y) +Base.:(==)(x::AbstractArray, y::AnyConcreteRArray) = convert(Array, x) == convert(Array, y) +function Base.:(==)(x::AnyConcreteRArray, y::AnyConcreteRArray) + return convert(Array, x) == convert(Array, y) +end function Base.show(io::IO, X::ConcreteRScalar{T}) where {T} if X.data == XLA.AsyncEmptyBuffer diff --git a/test/ops.jl b/test/ops.jl index 0600d0b861..57d9d5d55b 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -242,8 +242,8 @@ end @test a .* b ≈ @jit f1(a, b) @test reshape(kron(Array(b), Array(a)), 4, 4) ≈ @jit f2(a, b) - x = reshape(a, (2, 2)) - y = reshape(b, (2, 2)) + x = ConcreteRArray(reshape(a, (2, 2))) + y = ConcreteRArray(reshape(b, (2, 2))) @test x .* y ≈ @jit f3(x, y) @test Array(x) * Array(y) ≈ @jit f4(x, y) end From 7c715d47caabc1219254e270669e31af9ddf6dcc Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 9 Dec 2024 22:08:44 +0530 Subject: [PATCH 16/18] chore: apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com> --- src/TracedRNumber.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index a87ccb3340..ebe733ce62 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -13,7 +13,7 @@ mutable struct TracedRNumber{T} <: RNumber{T} end get_mlir_data(x::TracedRNumber) = x.mlir_data -set_mlir_data!(x::TracedRNumber, data) = (x.mlir_data = data) +set_mlir_data!(x::TracedRNumber, data) = (x.mlir_data = data; return x) ReactantCore.is_traced(::TracedRNumber) = true From 2e891ac654c5b15b23844ec4de204f1843356541 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 11 Dec 2024 08:15:56 +0530 Subject: [PATCH 17/18] refactor: rearrange the tests --- test/wrapped_arrays.jl | 81 +++++++++++++++++++----------------------- 1 file changed, 36 insertions(+), 45 deletions(-) diff --git a/test/wrapped_arrays.jl b/test/wrapped_arrays.jl index 82390fc972..f5418e5c80 100644 --- a/test/wrapped_arrays.jl +++ b/test/wrapped_arrays.jl @@ -94,11 +94,18 @@ function bypass_permutedims(x) return view(x, 2:3, 1:2, :) end +add_perm_dims(x) = x .+ PermutedDimsArray(x, (2, 1)) + @testset "PermutedDimsArray" begin x = rand(4, 4, 3) x_ra = Reactant.to_rarray(x) y_ra = @jit(bypass_permutedims(x_ra)) @test @allowscalar(Array(y_ra)) ≈ bypass_permutedims(x) + + x = rand(4, 4) + x_ra = Reactant.to_rarray(x) + + @test @jit(add_perm_dims(x_ra)) ≈ add_perm_dims(x) end function writeto_reshaped_array!(x) @@ -108,12 +115,6 @@ function writeto_reshaped_array!(x) return z1 end -@testset "writeto_reshaped_array!" begin - x = ConcreteRArray(rand(3, 2)) - y = @jit writeto_reshaped_array!(x) - @test all(isone, Array(y)) -end - function write_to_transposed_array!(x) z1 = similar(x) z2 = transpose(z1) @@ -121,12 +122,6 @@ function write_to_transposed_array!(x) return z1 end -@testset "write_to_transposed_array!" begin - x = ConcreteRArray(rand(3, 2)) - y = @jit write_to_transposed_array!(x) - @test all(isone, Array(y)) -end - function write_to_adjoint_array!(x) z1 = similar(x) z2 = adjoint(z1) @@ -134,21 +129,6 @@ function write_to_adjoint_array!(x) return z1 end -@testset "write_to_adjoint_array!" begin - x = ConcreteRArray(rand(3, 2)) - y = @jit write_to_adjoint_array!(x) - @test all(isone, Array(y)) -end - -add_perm_dims(x) = x .+ PermutedDimsArray(x, (2, 1)) - -@testset "add_perm_dims" begin - x = rand(4, 4) - x_ra = Reactant.to_rarray(x) - - @test @jit(add_perm_dims(x_ra)) ≈ add_perm_dims(x) -end - function write_to_permuted_dims_array!(x) z1 = similar(x) z2 = PermutedDimsArray(z1, (2, 1)) @@ -156,28 +136,39 @@ function write_to_permuted_dims_array!(x) return z1 end -@testset "write_to_permuted_dims_array!" begin - x = rand(4, 4) - x_ra = Reactant.to_rarray(x) - - @test @jit(write_to_permuted_dims_array!(x_ra)) ≈ write_to_permuted_dims_array!(x) -end - function write_to_diagonal_array!(x) z = Diagonal(x) @. z = 1.0 return z end -@testset "write_to_diagonal_array!" begin - x = rand(4, 4) - x_ra = Reactant.to_rarray(x) - y_ra = copy(x_ra) - - y = @jit(write_to_diagonal_array!(x_ra)) - y_res = @allowscalar Array(y) - @test x_ra ≈ y_ra - @test all(isone, diag(y_res)) - y_res[diagind(y_res)] .= 0 - @test all(iszero, y_res) +@testset "Preserve Aliasing with Parent" begin + @testset "$(aType)" for (aType, fn) in [ + ("ReshapedArray", writeto_reshaped_array!), + ("Transpose", write_to_transposed_array!), + ("Adjoint", write_to_adjoint_array!), + ] + x = ConcreteRArray(rand(3, 2)) + y = @jit fn(x) + @test all(isone, Array(y)) + end + + @testset "PermutedDimsArray" begin + x = rand(4, 4) + x_ra = Reactant.to_rarray(x) + @test @jit(write_to_permuted_dims_array!(x_ra)) ≈ write_to_permuted_dims_array!(x) + end + + @testset "Diagonal" begin + x = rand(4, 4) + x_ra = Reactant.to_rarray(x) + y_ra = copy(x_ra) + + y = @jit(write_to_diagonal_array!(x_ra)) + y_res = @allowscalar Array(y) + @test x_ra ≈ y_ra + @test all(isone, diag(y_res)) + y_res[diagind(y_res)] .= 0 + @test all(iszero, y_res) + end end From 107040d87a0d044c19cd2684aebffe7d1f941b68 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 11 Dec 2024 08:33:15 +0530 Subject: [PATCH 18/18] test: add test that fails on incorrect reshape dims ordering --- test/ops.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/ops.jl b/test/ops.jl index 57d9d5d55b..07f911e88b 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -521,6 +521,9 @@ end @testset "reshape" begin x = ConcreteRArray([1, 2, 3, 4]) @test reshape(Array(x), 2, 2) == @jit Ops.reshape(x, 2, 2) + + x = ConcreteRArray(collect(reshape(1:12, 2, 2, 3))) + @test reshape(Array(x), 3, 1, 4) == @jit Ops.reshape(x, 3, 1, 4) end @testset "reverse" begin