Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ function Base.rtoldefault(::Type{ConcreteRNumber{T}}) where {T}
return ConcreteRNumber(Base.rtoldefault(T))
end

Base.strides(x::ConcreteRArray) = Base.size_to_strides(1, size(x)...)

# Ensure the device and client are the same as the input
function Base.float(x::ConcreteRNumber{T}) where {T}
client = XLA.client(x.data)
Expand Down
69 changes: 49 additions & 20 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -956,19 +956,26 @@ function broadcast_in_dim(
end

@noinline function sort(
x::TracedRArray{T,N};
xs::TracedRArray...;
comparator,
dimension=1,
is_stable=false,
location=mlir_stacktrace("sort", @__FILE__, @__LINE__),
) where {T,N}
)
#C4:
@assert 0 < dimension <= ndims(x) "$x invalid dimension"
for x in xs
@assert 0 < dimension <= ndims(x) "$x invalid dimension"
end

(a, b) = (Reactant.ConcreteRNumber(T(0)), Reactant.ConcreteRNumber(T(0)))
sample_inputs = Vector{Reactant.ConcreteRNumber}(undef, length(xs) * 2)
for i in eachindex(xs)
T = Reactant.unwrapped_eltype(xs[i])
sample_inputs[2i - 1] = Reactant.ConcreteRNumber(T(0))
sample_inputs[2i] = Reactant.ConcreteRNumber(T(0))
end
func = Reactant.TracedUtils.make_mlir_fn(
comparator,
(a, b),
(sample_inputs...,),
(),
"comparator";
no_args_in_result=true,
Expand All @@ -993,30 +1000,52 @@ end
dimension = MLIR.IR.Attribute(dimension - 1)
is_stable = MLIR.IR.Attribute(is_stable)

res = MLIR.IR.result(
stablehlo.sort(
[x.mlir_data];
result_0=[mlir_type(TracedRArray{T,N}, size(x))],
dimension,
is_stable,
comparator,
location,
),
op = stablehlo.sort(
[x.mlir_data for x in xs];
result_0=[mlir_type(typeof(x), size(x)) for x in xs],
dimension,
is_stable,
comparator,
location,
)
return TracedRArray{T,N}((), res, size(x))
return [
TracedRArray{Reactant.unwrapped_eltype(xs[i]),ndims(xs[i])}(
(), MLIR.IR.result(op, i), size(xs[i])
) for i in eachindex(xs)
]
end

@noinline function top_k(
x::TracedRArray{T,N}, k; location=mlir_stacktrace("top_k", @__FILE__, @__LINE__)
x::TracedRArray{T,N},
k;
dimension::Integer=N,
location=mlir_stacktrace("top_k", @__FILE__, @__LINE__),
) where {T,N}
@assert 1 <= dimension <= N
if dimension != N # chlo.top_k performs the operation along the last dimension
pdims = collect(Int64, 1:N)
pdims[dimension] = N
pdims[N] = dimension
x = permutedims(x, pdims)
end

rsize = [size(x)[1:(end - 1)]..., k]
values = mlir_type(TracedRArray{T,N}, rsize)
indices = mlir_type(TracedRArray{Int32,N}, rsize)
op = chlo.top_k(x.mlir_data; values, indices, k, location)
return (;
values=TracedRArray{T,N}((), MLIR.IR.result(op, 1), rsize),
indices=TracedRArray{Int32,N}((), MLIR.IR.result(op, 2), rsize),
)
indices = add(
TracedRArray{Int32,N}((), MLIR.IR.result(op, 2), rsize),
constant(fill(Int32(1), Tuple(rsize))),
) # return the 1-indexed index
indices = convert(TracedRArray{Int64,N}, indices) # julia indexes with Int64 generally
values = TracedRArray{T,N}((), MLIR.IR.result(op, 1), rsize)

if dimension != N
values = permutedims(values, invperm(pdims))
indices = permutedims(indices, invperm(pdims))
end

return (; values, indices)
end

@noinline function iota(
Expand Down
251 changes: 249 additions & 2 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using ..Reactant:
ReactantPrimitive,
WrappedTracedRArray,
AnyTracedRArray,
AnyTracedRVector,
Ops,
MLIR,
ancestor,
Expand All @@ -19,10 +20,12 @@ using ..Reactant:
using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_traced_array

using ReactantCore: ReactantCore
using GPUArraysCore: GPUArraysCore
using GPUArraysCore: GPUArraysCore, @allowscalar

ReactantCore.is_traced(::TracedRArray) = true

Base.strides(x::TracedRArray) = Base.size_to_strides(1, size(x)...)

function Base.convert(::Type{TracedRArray{T,N}}, x::AbstractArray) where {T,N}
@assert ndims(x) == N
if x isa TracedRArray
Expand Down Expand Up @@ -86,6 +89,17 @@ function scalar_index_to_cartesian(idx::AbstractVector{T}, sz::NTuple{N,Int}) wh
return idxs
end

function scalar_index_to_cartesian(idx::T, sz::NTuple{N,Int}) where {T<:Number,N}
idx = idx - 1
idxs = (idx % T(sz[1]),)
idx = idx ÷ T(sz[1])
for i in 2:N
idxs = (idxs..., idx % T(sz[i]))
idx = idx ÷ T(sz[i])
end
return idxs
end

function Base.getindex(
a::TracedRArray{T,N}, indices::Union{Int,TracedRNumber{Int}}
) where {T,N}
Expand Down Expand Up @@ -509,7 +523,10 @@ function _copyto!(dest::AnyTracedRArray, bc::Broadcasted)

args = (TracedUtils.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args)

res = TracedUtils.elem_apply(bc.f, args...)
res = TracedUtils.promote_to(
TracedRArray{unwrapped_eltype(dest),ndims(dest)},
TracedUtils.elem_apply(bc.f, args...),
)
TracedUtils.set_mlir_data!(dest, res.mlir_data)
return dest
end
Expand Down Expand Up @@ -687,4 +704,234 @@ function overloaded_stack(dims::Union{Integer,Colon}, xs)
return cat(res...; dims)
end

# sort
function Base.sort(x::AnyTracedRArray; alg=missing, order=missing, kwargs...)
return sort!(copy(x); alg, order, kwargs...)
end
function Base.sort(x::AnyTracedRVector; alg=missing, order=missing, kwargs...)
return sort!(copy(x); alg, order, dims=1, kwargs...)
end

function Base.sort!(
x::AnyTracedRArray;
dims::Union{Integer,Nothing}=nothing,
lt=isless,
by=identity,
rev::Bool=false,
alg=missing,
order=missing,
)
if dims === nothing
@assert ndims(x) == 1
dims = 1
end

@assert alg === missing "Reactant doesn't support `alg` kwarg for `sort!`"
@assert order === missing "Reactant doesn't support `order` kwarg for `sort!`"

comparator = rev ? (a, b) -> !lt(by(a), by(b)) : (a, b) -> lt(by(a), by(b))
res = only(Ops.sort(materialize_traced_array(x); dimension=dims, comparator))
set_mlir_data!(x, get_mlir_data(res))
return x
end

function Base.sortperm(x::AnyTracedRArray; alg=missing, order=missing, kwargs...)
return sortperm!(similar(x, Int), x; alg, order, kwargs...)
end
function Base.sortperm(x::AnyTracedRVector; alg=missing, order=missing, kwargs...)
return sortperm!(similar(x, Int), x; alg, order, dims=1, kwargs...)
end

function Base.sortperm!(
ix::AnyTracedRArray{Int,N},
x::AnyTracedRArray{<:Any,N};
dims::Union{Integer,Nothing}=nothing,
lt=isless,
by=identity,
rev::Bool=false,
alg=missing,
order=missing,
) where {N}
if dims === nothing
@assert ndims(x) == 1
dims = 1
end

@assert alg === missing "Reactant doesn't support `alg` kwarg for `sortperm!`"
@assert order === missing "Reactant doesn't support `order` kwarg for `sortperm!`"

comparator =
rev ? (a, b, i1, i2) -> !lt(by(a), by(b)) : (a, b, i1, i2) -> lt(by(a), by(b))
idxs = Ops.constant(collect(LinearIndices(x)))
_, res = Ops.sort(materialize_traced_array(x), idxs; dimension=dims, comparator)
set_mlir_data!(ix, get_mlir_data(res))
return ix
end

function Base.partialsort(x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs...)
values, _ = overloaded_partialsort(x, k; kwargs...)
k = k .- minimum(k) .+ 1
k isa Integer && return @allowscalar(values[k])
return view(values, k)
end

function Base.partialsort!(x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs...)
values, _ = overloaded_partialsort(x, k; kwargs...)
kget = k .- minimum(k) .+ 1
val = @allowscalar(values[kget])
@allowscalar setindex!(x, val, k)
k isa Integer && return val
return view(x, k)
end

function Base.partialsortperm(
x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs...
)
idxs = overloaded_partialsort(x, k; kwargs...)[2]
k = k .- minimum(k) .+ 1
k isa Integer && return @allowscalar(idxs[k])
return view(idxs, k)
end

function Base.partialsortperm!(
ix::AnyTracedRVector{Int},
x::AnyTracedRVector,
k::Union{Integer,OrdinalRange};
kwargs...,
)
_, idxs = overloaded_partialsort(x, k; kwargs...)
kget = k .- minimum(k) .+ 1
val = @allowscalar(idxs[kget])
@allowscalar setindex!(ix, val, k)
k isa Integer && return val
return view(ix, k)
end

function overloaded_partialsort(
x::AnyTracedRVector,
k::Union{Integer,OrdinalRange};
by=identity,
rev::Bool=false,
lt=isless,
)
if lt !== isless || by !== identity
comparator =
rev ? (a, b, i1, i2) -> !lt(by(a), by(b)) : (a, b, i1, i2) -> lt(by(a), by(b))
idxs = Ops.constant(collect(LinearIndices(x)))
sorted_x, sorted_idxs = Ops.sort(
materialize_traced_array(x), idxs; dimension=1, comparator
)
return sorted_x[1:maximum(k)], sorted_idxs[1:maximum(k)]
end

# XXX: If `maxk` is beyond a threshold should we emit a sort directly?
!rev && (k = length(x) .- k .+ 1)
!(k isa Integer) && (k = maximum(k))
(; values, indices) = Ops.top_k(materialize_traced_array(x), k)
if !rev
values = Ops.reverse(values; dimensions=[1])
indices = Ops.reverse(indices; dimensions=[1])
end
return values, indices
end

# arg* functions
function Base.argmin(f::F, x::AnyTracedRArray) where {F}
idx = scalar_index_to_cartesian(argmin(f.(x)), size(x)) .+ 1
return @allowscalar x[idx...]
end

function Base.argmax(f::F, x::AnyTracedRArray) where {F}
idx = scalar_index_to_cartesian(argmax(f.(x)), size(x)) .+ 1
return @allowscalar x[idx...]
end

Base.argmin(x::AnyTracedRArray; kwargs...) = findmin(identity, x; kwargs...)[2]
Base.argmax(x::AnyTracedRArray; kwargs...) = findmax(identity, x; kwargs...)[2]

# find* functions
Base.findfirst(x::AnyTracedRArray) = findfirst(identity, x)
Base.findlast(x::AnyTracedRArray) = findlast(identity, x)

function Base.findfirst(f::Function, x::AnyTracedRArray)
fA = materialize_traced_array(vec(f.(x)))
(; indices) = Ops.top_k(fA, 1)
return @allowscalar indices[1]
end

function Base.findlast(f::Function, x::AnyTracedRArray)
fA = Ops.reverse(materialize_traced_array(vec(f.(x))); dimensions=[1])
(; indices) = Ops.top_k(fA, 1)
return length(x) - @allowscalar(indices[1]) + 1
end

Base.findmin(x::AnyTracedRVector) = findmin(identity, x; dims=1)
function Base.findmin(x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing)
return findmin(identity, x; dims)
end

Base.findmax(x::AnyTracedRVector) = findmax(identity, x; dims=1)
function Base.findmax(x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing)
return findmax(identity, x; dims)
end

## To avoid scalar indexing and constructing an array of tuples, we return the linear index
## instead of the cartesian index
function Base.findmin(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing)
if dims === nothing
if ndims(x) == 1
dims = 1
else
return findmin(f, vec(x); dims=1)
end
end

fx = Ops.negate(materialize_traced_array(f.(x)))
(; values, indices) = Ops.top_k(fx, 1; dimension=dims)

# Compute linear indices
strds = strides(x)
iotas = [Ops.iota(Int64, [size(indices)...]; iota_dimension=i) for i in 1:ndims(x)]
iotas[dims] = Ops.subtract(indices, Ops.constant(fill(Int64(1), size(indices))))
linear_indices = Ops.constant(fill(Int64(1), size(indices)))
for d in eachindex(iotas)
linear_indices = Ops.add(
linear_indices,
Ops.multiply(iotas[d], Ops.constant(fill(Int64(strds[d]), size(iotas[d])))),
)
end

values = Ops.negate(values)
ndims(x) == 1 && return @allowscalar (values[1], linear_indices[1])
return (values, linear_indices)
end

function Base.findmax(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing)
if dims === nothing
if ndims(x) == 1
dims = 1
else
return findmax(f, vec(x); dims=1)
end
end

fx = materialize_traced_array(f.(x))
(; values, indices) = Ops.top_k(fx, 1; dimension=dims)

# Compute linear indices
strds = strides(x)
iotas = [Ops.iota(Int64, [size(indices)...]; iota_dimension=i) for i in 1:ndims(x)]
iotas[dims] = Ops.subtract(indices, Ops.constant(fill(Int64(1), size(indices))))
linear_indices = Ops.constant(fill(Int64(1), size(indices)))
for d in eachindex(iotas)
linear_indices = Ops.add(
linear_indices,
Ops.multiply(iotas[d], Ops.constant(fill(Int64(strds[d]), size(iotas[d])))),
)
end

ndims(x) == 1 && return @allowscalar (values[1], linear_indices[1])
return (values, linear_indices)
end

end
Loading
Loading