diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index 07c8a8e83c..90b1cd2ccc 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -109,7 +109,7 @@ function NNlib.conv!( #! format: on padding = Reactant.MLIR.IR.DenseElementsAttribute( - reshape(collect(padding), (num_spatial_dims, 2)) + reshape(collect(padding), (2, num_spatial_dims))' ) result_type = Reactant.MLIR.IR.TensorType(size(y), Reactant.MLIR.IR.Type(T)) @@ -163,7 +163,7 @@ function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N} end padding = Reactant.MLIR.IR.DenseElementsAttribute( - reshape([padding..., 0, 0, 0, 0], (N, 2)) + reshape([padding..., 0, 0, 0, 0], (2, N))' ) output_shape = (output_spatial_shapes..., size(x, N - 1), size(x, N)) @@ -306,7 +306,7 @@ function NNlib.make_causal_mask(x::AnyTracedRArray; dims::Int=2) len = size(x, dims) # directly generating booleans were causing an incorrect constant attribute generation # but the optimized IR removes the type case so we are probably ok - mask = MLIR.IR.DenseElementsAttribute(collect(triu(fill(1, (len, len)))')) + mask = MLIR.IR.DenseElementsAttribute(collect(triu(fill(1, (len, len))))) return Reactant.promote_to( TracedRArray{Bool,2}, TracedRArray{Int,2}( diff --git a/src/mlir/IR/Attribute.jl b/src/mlir/IR/Attribute.jl index 69bd9b99a1..d37e7c9862 100644 --- a/src/mlir/IR/Attribute.jl +++ b/src/mlir/IR/Attribute.jl @@ -492,6 +492,10 @@ function Base.fill(::Core.Type{Attribute}, value, shape) return Base.fill(value, shaped_type) end +to_row_major(x) = permutedims(x, ndims(x):-1:1) +to_row_major(x::AbstractVector) = x +to_row_major(x::AbstractArray{T,0}) where {T} = x + """ DenseElementsAttribute(array::AbstractArray) @@ -501,66 +505,86 @@ function DenseElementsAttribute(values::AbstractArray{Bool}) shaped_type = TensorType(size(values), Type(Bool)) return Attribute( API.mlirDenseElementsAttrBoolGet( - shaped_type, length(values), AbstractArray{Cint}(values) + shaped_type, length(values), AbstractArray{Cint}(to_row_major(values)) ), ) end function DenseElementsAttribute(values::AbstractArray{UInt8}) shaped_type = TensorType(size(values), Type(UInt8)) - return Attribute(API.mlirDenseElementsAttrUInt8Get(shaped_type, length(values), values)) + return Attribute( + API.mlirDenseElementsAttrUInt8Get(shaped_type, length(values), to_row_major(values)) + ) end function DenseElementsAttribute(values::AbstractArray{Int8}) shaped_type = TensorType(size(values), Type(Int8)) - return Attribute(API.mlirDenseElementsAttrInt8Get(shaped_type, length(values), values)) + return Attribute( + API.mlirDenseElementsAttrInt8Get(shaped_type, length(values), to_row_major(values)) + ) end function DenseElementsAttribute(values::AbstractArray{UInt16}) shaped_type = TensorType(size(values), Type(UInt16)) return Attribute( - API.mlirDenseElementsAttrUInt16Get(shaped_type, length(values), values) + API.mlirDenseElementsAttrUInt16Get( + shaped_type, length(values), to_row_major(values) + ), ) end function DenseElementsAttribute(values::AbstractArray{Int16}) shaped_type = TensorType(size(values), Type(Int16)) - return Attribute(API.mlirDenseElementsAttrInt16Get(shaped_type, length(values), values)) + return Attribute( + API.mlirDenseElementsAttrInt16Get(shaped_type, length(values), to_row_major(values)) + ) end function DenseElementsAttribute(values::AbstractArray{UInt32}) shaped_type = TensorType(size(values), Type(UInt32)) return Attribute( - API.mlirDenseElementsAttrUInt32Get(shaped_type, length(values), values) + API.mlirDenseElementsAttrUInt32Get( + shaped_type, length(values), to_row_major(values) + ), ) end function DenseElementsAttribute(values::AbstractArray{Int32}) shaped_type = TensorType(size(values), Type(Int32)) - return Attribute(API.mlirDenseElementsAttrInt32Get(shaped_type, length(values), values)) + return Attribute( + API.mlirDenseElementsAttrInt32Get(shaped_type, length(values), to_row_major(values)) + ) end function DenseElementsAttribute(values::AbstractArray{UInt64}) shaped_type = TensorType(size(values), Type(UInt64)) return Attribute( - API.mlirDenseElementsAttrUInt64Get(shaped_type, length(values), values) + API.mlirDenseElementsAttrUInt64Get( + shaped_type, length(values), to_row_major(values) + ), ) end function DenseElementsAttribute(values::AbstractArray{Int64}) shaped_type = TensorType(size(values), Type(Int64)) - return Attribute(API.mlirDenseElementsAttrInt64Get(shaped_type, length(values), values)) + return Attribute( + API.mlirDenseElementsAttrInt64Get(shaped_type, length(values), to_row_major(values)) + ) end function DenseElementsAttribute(values::AbstractArray{Float32}) shaped_type = TensorType(size(values), Type(Float32)) - return Attribute(API.mlirDenseElementsAttrFloatGet(shaped_type, length(values), values)) + return Attribute( + API.mlirDenseElementsAttrFloatGet(shaped_type, length(values), to_row_major(values)) + ) end function DenseElementsAttribute(values::AbstractArray{Float64}) shaped_type = TensorType(size(values), Type(Float64)) return Attribute( - API.mlirDenseElementsAttrDoubleGet(shaped_type, length(values), values) + API.mlirDenseElementsAttrDoubleGet( + shaped_type, length(values), to_row_major(values) + ), ) end @@ -569,16 +593,17 @@ end function DenseElementsAttribute(values::AbstractArray{Float16}) shaped_type = TensorType(size(values), Type(Float16)) return Attribute( - API.mlirDenseElementsAttrFloat16Get(shaped_type, length(values), values) + API.mlirDenseElementsAttrFloat16Get( + shaped_type, length(values), to_row_major(values) + ), ) end function DenseElementsAttribute(values::AbstractArray{<:Complex}) shaped_type = TensorType(size(values), Type(eltype(values))) - # TODO: row major return Attribute( API.mlirDenseElementsAttrRawBufferGet( - shaped_type, length(values) * Base.elsize(values), values + shaped_type, length(values) * Base.elsize(values), to_row_major(values) ), ) end @@ -592,7 +617,9 @@ function DenseElementsAttribute(values::AbstractArray{String}) # TODO may fail because `Type(String)` is not defined shaped_type = TensorType(size(values), Type(String)) return Attribute( - API.mlirDenseElementsAttrStringGet(shaped_type, length(values), values) + API.mlirDenseElementsAttrStringGet( + shaped_type, length(values), to_row_major(values) + ), ) end @@ -663,25 +690,29 @@ function DenseArrayAttribute end @llvmversioned min = v"16" DenseArrayAttribute( values::AbstractArray{Bool}; context::Context=context() -) = Attribute(API.mlirDenseBoolArrayGet(context, length(values), values)) +) = Attribute( + API.mlirDenseBoolArrayGet( + context, length(values), AbstractArray{Cint}(to_row_major(values)) + ), +) @llvmversioned min = v"16" DenseArrayAttribute( values::AbstractArray{Int8}; context::Context=context() -) = Attribute(API.mlirDenseI8ArrayGet(context, length(values), values)) +) = Attribute(API.mlirDenseI8ArrayGet(context, length(values), to_row_major(values))) @llvmversioned min = v"16" DenseArrayAttribute( values::AbstractArray{Int16}; context::Context=context() -) = Attribute(API.mlirDenseI16ArrayGet(context, length(values), values)) +) = Attribute(API.mlirDenseI16ArrayGet(context, length(values), to_row_major(values))) @llvmversioned min = v"16" DenseArrayAttribute( values::AbstractArray{Int32}; context::Context=context() -) = Attribute(API.mlirDenseI32ArrayGet(context, length(values), values)) +) = Attribute(API.mlirDenseI32ArrayGet(context, length(values), to_row_major(values))) @llvmversioned min = v"16" DenseArrayAttribute( values::AbstractArray{Int64}; context::Context=context() -) = Attribute(API.mlirDenseI64ArrayGet(context, length(values), values)) +) = Attribute(API.mlirDenseI64ArrayGet(context, length(values), to_row_major(values))) @llvmversioned min = v"16" DenseArrayAttribute( values::AbstractArray{Float32}; context::Context=context() -) = Attribute(API.mlirDenseF32ArrayGet(context, length(values), values)) +) = Attribute(API.mlirDenseF32ArrayGet(context, length(values), to_row_major(values))) @llvmversioned min = v"16" DenseArrayAttribute( values::AbstractArray{Float64}; context::Context=context() -) = Attribute(API.mlirDenseF64ArrayGet(context, length(values), values)) +) = Attribute(API.mlirDenseF64ArrayGet(context, length(values), to_row_major(values))) @llvmversioned min = v"16" Attribute(values::AbstractArray) = DenseArrayAttribute(values) diff --git a/test/basic.jl b/test/basic.jl index d188c276e9..3796b7275c 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -625,3 +625,18 @@ end @test y == x end end + +function f_row_major(x) + y = [1 2; 3 4; 5 6] + if x isa Reactant.TracedRArray + y = Reactant.promote_to(Reactant.TracedRArray{eltype(x),2}, y) + end + return x .+ y +end + +@testset "array attributes: row major" begin + x = zeros(Int, 3, 2) + x_ra = Reactant.to_rarray(x) + + @test @jit(f_row_major(x_ra)) ≈ f_row_major(x) +end