Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ function NNlib.conv!(
#! format: on

padding = Reactant.MLIR.IR.DenseElementsAttribute(
reshape(collect(padding), (num_spatial_dims, 2))
reshape(collect(padding), (2, num_spatial_dims))'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't call '/adjoint because it will conjugate complex matrices

Suggested change
reshape(collect(padding), (2, num_spatial_dims))'
transpose(reshape(collect(padding), (2, num_spatial_dims)))

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not know about that. As you said, it does not apply here but I will be cautious in the future 👍

)
result_type = Reactant.MLIR.IR.TensorType(size(y), Reactant.MLIR.IR.Type(T))

Expand Down Expand Up @@ -163,7 +163,7 @@ function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N}
end

padding = Reactant.MLIR.IR.DenseElementsAttribute(
reshape([padding..., 0, 0, 0, 0], (N, 2))
reshape([padding..., 0, 0, 0, 0], (2, N))'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
reshape([padding..., 0, 0, 0, 0], (2, N))'
transpose(reshape([padding..., 0, 0, 0, 0], (2, N)))

)

output_shape = (output_spatial_shapes..., size(x, N - 1), size(x, N))
Expand Down Expand Up @@ -306,7 +306,7 @@ function NNlib.make_causal_mask(x::AnyTracedRArray; dims::Int=2)
len = size(x, dims)
# directly generating booleans were causing an incorrect constant attribute generation
# but the optimized IR removes the type case so we are probably ok
mask = MLIR.IR.DenseElementsAttribute(collect(triu(fill(1, (len, len)))'))
mask = MLIR.IR.DenseElementsAttribute(collect(triu(fill(1, (len, len)))))
return Reactant.promote_to(
TracedRArray{Bool,2},
TracedRArray{Int,2}(
Expand Down
75 changes: 53 additions & 22 deletions src/mlir/IR/Attribute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,10 @@ function Base.fill(::Core.Type{Attribute}, value, shape)
return Base.fill(value, shaped_type)
end

to_row_major(x) = permutedims(x, ndims(x):-1:1)
to_row_major(x::AbstractVector) = x
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the error logs it looks like this also needs a 0-dim specialization

to_row_major(x::AbstractArray{T,0}) where {T} = x

"""
DenseElementsAttribute(array::AbstractArray)

Expand All @@ -501,66 +505,86 @@ function DenseElementsAttribute(values::AbstractArray{Bool})
shaped_type = TensorType(size(values), Type(Bool))
return Attribute(
API.mlirDenseElementsAttrBoolGet(
shaped_type, length(values), AbstractArray{Cint}(values)
shaped_type, length(values), AbstractArray{Cint}(to_row_major(values))
),
)
end

function DenseElementsAttribute(values::AbstractArray{UInt8})
shaped_type = TensorType(size(values), Type(UInt8))
return Attribute(API.mlirDenseElementsAttrUInt8Get(shaped_type, length(values), values))
return Attribute(
API.mlirDenseElementsAttrUInt8Get(shaped_type, length(values), to_row_major(values))
)
end

function DenseElementsAttribute(values::AbstractArray{Int8})
shaped_type = TensorType(size(values), Type(Int8))
return Attribute(API.mlirDenseElementsAttrInt8Get(shaped_type, length(values), values))
return Attribute(
API.mlirDenseElementsAttrInt8Get(shaped_type, length(values), to_row_major(values))
)
end

function DenseElementsAttribute(values::AbstractArray{UInt16})
shaped_type = TensorType(size(values), Type(UInt16))
return Attribute(
API.mlirDenseElementsAttrUInt16Get(shaped_type, length(values), values)
API.mlirDenseElementsAttrUInt16Get(
shaped_type, length(values), to_row_major(values)
),
)
end

function DenseElementsAttribute(values::AbstractArray{Int16})
shaped_type = TensorType(size(values), Type(Int16))
return Attribute(API.mlirDenseElementsAttrInt16Get(shaped_type, length(values), values))
return Attribute(
API.mlirDenseElementsAttrInt16Get(shaped_type, length(values), to_row_major(values))
)
end

function DenseElementsAttribute(values::AbstractArray{UInt32})
shaped_type = TensorType(size(values), Type(UInt32))
return Attribute(
API.mlirDenseElementsAttrUInt32Get(shaped_type, length(values), values)
API.mlirDenseElementsAttrUInt32Get(
shaped_type, length(values), to_row_major(values)
),
)
end

function DenseElementsAttribute(values::AbstractArray{Int32})
shaped_type = TensorType(size(values), Type(Int32))
return Attribute(API.mlirDenseElementsAttrInt32Get(shaped_type, length(values), values))
return Attribute(
API.mlirDenseElementsAttrInt32Get(shaped_type, length(values), to_row_major(values))
)
end

function DenseElementsAttribute(values::AbstractArray{UInt64})
shaped_type = TensorType(size(values), Type(UInt64))
return Attribute(
API.mlirDenseElementsAttrUInt64Get(shaped_type, length(values), values)
API.mlirDenseElementsAttrUInt64Get(
shaped_type, length(values), to_row_major(values)
),
)
end

function DenseElementsAttribute(values::AbstractArray{Int64})
shaped_type = TensorType(size(values), Type(Int64))
return Attribute(API.mlirDenseElementsAttrInt64Get(shaped_type, length(values), values))
return Attribute(
API.mlirDenseElementsAttrInt64Get(shaped_type, length(values), to_row_major(values))
)
end

function DenseElementsAttribute(values::AbstractArray{Float32})
shaped_type = TensorType(size(values), Type(Float32))
return Attribute(API.mlirDenseElementsAttrFloatGet(shaped_type, length(values), values))
return Attribute(
API.mlirDenseElementsAttrFloatGet(shaped_type, length(values), to_row_major(values))
)
end

function DenseElementsAttribute(values::AbstractArray{Float64})
shaped_type = TensorType(size(values), Type(Float64))
return Attribute(
API.mlirDenseElementsAttrDoubleGet(shaped_type, length(values), values)
API.mlirDenseElementsAttrDoubleGet(
shaped_type, length(values), to_row_major(values)
),
)
end

Expand All @@ -569,16 +593,17 @@ end
function DenseElementsAttribute(values::AbstractArray{Float16})
shaped_type = TensorType(size(values), Type(Float16))
return Attribute(
API.mlirDenseElementsAttrFloat16Get(shaped_type, length(values), values)
API.mlirDenseElementsAttrFloat16Get(
shaped_type, length(values), to_row_major(values)
),
)
end

function DenseElementsAttribute(values::AbstractArray{<:Complex})
shaped_type = TensorType(size(values), Type(eltype(values)))
# TODO: row major
return Attribute(
API.mlirDenseElementsAttrRawBufferGet(
shaped_type, length(values) * Base.elsize(values), values
shaped_type, length(values) * Base.elsize(values), to_row_major(values)
),
)
end
Expand All @@ -592,7 +617,9 @@ function DenseElementsAttribute(values::AbstractArray{String})
# TODO may fail because `Type(String)` is not defined
shaped_type = TensorType(size(values), Type(String))
return Attribute(
API.mlirDenseElementsAttrStringGet(shaped_type, length(values), values)
API.mlirDenseElementsAttrStringGet(
shaped_type, length(values), to_row_major(values)
),
)
end

Expand Down Expand Up @@ -663,25 +690,29 @@ function DenseArrayAttribute end

@llvmversioned min = v"16" DenseArrayAttribute(
values::AbstractArray{Bool}; context::Context=context()
) = Attribute(API.mlirDenseBoolArrayGet(context, length(values), values))
) = Attribute(
API.mlirDenseBoolArrayGet(
context, length(values), AbstractArray{Cint}(to_row_major(values))
),
)
@llvmversioned min = v"16" DenseArrayAttribute(
values::AbstractArray{Int8}; context::Context=context()
) = Attribute(API.mlirDenseI8ArrayGet(context, length(values), values))
) = Attribute(API.mlirDenseI8ArrayGet(context, length(values), to_row_major(values)))
@llvmversioned min = v"16" DenseArrayAttribute(
values::AbstractArray{Int16}; context::Context=context()
) = Attribute(API.mlirDenseI16ArrayGet(context, length(values), values))
) = Attribute(API.mlirDenseI16ArrayGet(context, length(values), to_row_major(values)))
@llvmversioned min = v"16" DenseArrayAttribute(
values::AbstractArray{Int32}; context::Context=context()
) = Attribute(API.mlirDenseI32ArrayGet(context, length(values), values))
) = Attribute(API.mlirDenseI32ArrayGet(context, length(values), to_row_major(values)))
@llvmversioned min = v"16" DenseArrayAttribute(
values::AbstractArray{Int64}; context::Context=context()
) = Attribute(API.mlirDenseI64ArrayGet(context, length(values), values))
) = Attribute(API.mlirDenseI64ArrayGet(context, length(values), to_row_major(values)))
@llvmversioned min = v"16" DenseArrayAttribute(
values::AbstractArray{Float32}; context::Context=context()
) = Attribute(API.mlirDenseF32ArrayGet(context, length(values), values))
) = Attribute(API.mlirDenseF32ArrayGet(context, length(values), to_row_major(values)))
@llvmversioned min = v"16" DenseArrayAttribute(
values::AbstractArray{Float64}; context::Context=context()
) = Attribute(API.mlirDenseF64ArrayGet(context, length(values), values))
) = Attribute(API.mlirDenseF64ArrayGet(context, length(values), to_row_major(values)))

@llvmversioned min = v"16" Attribute(values::AbstractArray) = DenseArrayAttribute(values)

Expand Down
15 changes: 15 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -625,3 +625,18 @@ end
@test y == x
end
end

function f_row_major(x)
y = [1 2; 3 4; 5 6]
if x isa Reactant.TracedRArray
y = Reactant.promote_to(Reactant.TracedRArray{eltype(x),2}, y)
end
return x .+ y
end

@testset "array attributes: row major" begin
x = zeros(Int, 3, 2)
x_ra = Reactant.to_rarray(x)

@test @jit(f_row_major(x_ra)) ≈ f_row_major(x)
end
Loading