From a18e0756930442ed461b93478a414ea78d4b49b0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 13 Jan 2025 16:31:03 -0500 Subject: [PATCH 01/13] feat: implement sort --- src/TracedRArray.jl | 17 +++++++++++++++++ src/TracedRNumber.jl | 1 + 2 files changed, 18 insertions(+) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 8b7835bcd5..9421bb85e0 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -687,4 +687,21 @@ function overloaded_stack(dims::Union{Integer,Colon}, xs) return cat(res...; dims) end +# sort +Base.sort(x::AnyTracedRArray; kwargs...) = sort!(copy(x); kwargs...) + +function Base.sort!( + x::AnyTracedRArray; + dims::Integer, + lt=isless, + by=identity, + rev::Bool=false, + kwargs..., # TODO: implement `order` and `alg` kwargs +) + comparator = rev ? (a, b) -> !lt(by(a), by(b)) : (a, b) -> lt(by(a), by(b)) + res = Ops.sort(materialize_traced_array(x); dimension=dims, comparator) + set_mlir_data!(x, get_mlir_data(res)) + return x +end + end diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 9dcc6dd39a..365728f5de 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -129,6 +129,7 @@ for (jlop, hloop, hlocomp) in ( (:(Base.:(>)), :compare, "GT"), (:(Base.:(<=)), :compare, "LE"), (:(Base.:(<)), :compare, "LT"), + (:(Base.isless), :compare, "LT"), ) @eval begin function $(jlop)( From de3aacd63113ce5d8d8e8c6d73a692bc4816aa25 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 14 Jan 2025 02:04:47 -0500 Subject: [PATCH 02/13] feat: generalize Ops.sort to take in multiple args --- src/Ops.jl | 41 ++++++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index 0d94cb75d4..3f2eb8f5a6 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -956,19 +956,26 @@ function broadcast_in_dim( end @noinline function sort( - x::TracedRArray{T,N}; + xs::TracedRArray...; comparator, dimension=1, is_stable=false, location=mlir_stacktrace("sort", @__FILE__, @__LINE__), -) where {T,N} +) #C4: - @assert 0 < dimension <= ndims(x) "$x invalid dimension" + for x in xs + @assert 0 < dimension <= ndims(x) "$x invalid dimension" + end - (a, b) = (Reactant.ConcreteRNumber(T(0)), Reactant.ConcreteRNumber(T(0))) + sample_inputs = Vector{Reactant.ConcreteRNumber}(undef, length(xs) * 2) + for i in eachindex(xs) + T = Reactant.unwrapped_eltype(xs[i]) + sample_inputs[2i - 1] = Reactant.ConcreteRNumber(T(0)) + sample_inputs[2i] = Reactant.ConcreteRNumber(T(0)) + end func = Reactant.TracedUtils.make_mlir_fn( comparator, - (a, b), + (sample_inputs...,), (), "comparator"; no_args_in_result=true, @@ -993,17 +1000,21 @@ end dimension = MLIR.IR.Attribute(dimension - 1) is_stable = MLIR.IR.Attribute(is_stable) - res = MLIR.IR.result( - stablehlo.sort( - [x.mlir_data]; - result_0=[mlir_type(TracedRArray{T,N}, size(x))], - dimension, - is_stable, - comparator, - location, - ), + op = stablehlo.sort( + [x.mlir_data for x in xs]; + result_0=[mlir_type(typeof(x), size(x)) for x in xs], + dimension, + is_stable, + comparator, + location, ) - return TracedRArray{T,N}((), res, size(x)) + res = [ + TracedRArray{Reactant.unwrapped_eltype(xs[i]),ndims(xs[i])}( + (), MLIR.IR.result(op, i), size(xs[i]) + ) for i in eachindex(xs) + ] + length(res) == 1 && return only(res) # Kept for backwards compatibility + return res end @noinline function top_k( From 99f3975248447604f9265501f36a06e0d2240f0e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 14 Jan 2025 02:05:57 -0500 Subject: [PATCH 03/13] feat: implement perm related functions --- src/TracedRArray.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 9421bb85e0..2847023432 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -704,4 +704,23 @@ function Base.sort!( return x end +Base.sortperm(x::AnyTracedRArray; kwargs...) = sortperm!(similar(x, Int), x; kwargs...) + +function Base.sortperm!( + ix::AnyTracedRArray{Int,N}, + x::AnyTracedRArray{<:Any,N}; + dims::Integer, + lt=isless, + by=identity, + rev::Bool=false, + kwargs..., # TODO: implement `order` and `alg` kwargs +) where {N} + comparator = + rev ? (a, b, i1, i2) -> !lt(by(a), by(b)) : (a, b, i1, i2) -> lt(by(a), by(b)) + idxs = Ops.constant(collect(LinearIndices(x))) + _, res = Ops.sort(materialize_traced_array(x), idxs; dimension=dims, comparator) + set_mlir_data!(ix, get_mlir_data(res)) + return ix +end + end From 7155b79d9a7eabc5d80675f215edbf9928d330f9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 14 Jan 2025 02:35:35 -0500 Subject: [PATCH 04/13] feat: implement partialsort --- src/TracedRArray.jl | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 2847023432..1029a4f8b5 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -10,6 +10,7 @@ using ..Reactant: ReactantPrimitive, WrappedTracedRArray, AnyTracedRArray, + AnyTracedRVector, Ops, MLIR, ancestor, @@ -19,7 +20,7 @@ using ..Reactant: using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_traced_array using ReactantCore: ReactantCore -using GPUArraysCore: GPUArraysCore +using GPUArraysCore: GPUArraysCore, @allowscalar ReactantCore.is_traced(::TracedRArray) = true @@ -723,4 +724,34 @@ function Base.sortperm!( return ix end +function Base.partialsort(x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs...) + return partialsort!(copy(x), k; kwargs...) +end + +function Base.partialsort!( + x::AnyTracedRVector, + k::Union{Integer,OrdinalRange}; + by=identity, + rev::Bool=false, + lt=isless, +) + # TODO: general `lt` support + @assert lt === isless "Only `isless` is supported for now in `partialsort!`" + + by_x = by.(x) + if k isa Integer + !rev && (k = length(x) - k + 1) + (; values, indices) = Ops.top_k(materialize_traced_array(by_x), k) + by === identity && return @allowscalar values[k] + return @allowscalar x[indices[k] + 1] + else + klist = collect(Int64, k) + !rev && (klist = length(x) .- klist .+ 1) + maxk = maximum(klist) + (; values, indices) = Ops.top_k(materialize_traced_array(by_x), maxk) + by === identity && return values[klist] + return x[indices[klist] .+ 1] + end +end + end From 3c4931af961fdfabe95fb1c9d6a38268123fa0e8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 14 Jan 2025 14:42:21 -0500 Subject: [PATCH 05/13] feat: implement argmin and argmax --- src/ConcreteRArray.jl | 2 ++ src/Ops.jl | 6 +++++- src/TracedRArray.jl | 49 ++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 53 insertions(+), 4 deletions(-) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 20f4d6a83c..55fa038396 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -28,6 +28,8 @@ function Base.rtoldefault(::Type{ConcreteRNumber{T}}) where {T} return ConcreteRNumber(Base.rtoldefault(T)) end +Base.strides(x::ConcreteRArray) = Base.size_to_strides(1, size(x)...) + # Ensure the device and client are the same as the input function Base.float(x::ConcreteRNumber{T}) where {T} client = XLA.client(x.data) diff --git a/src/Ops.jl b/src/Ops.jl index 3f2eb8f5a6..2873501791 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1024,9 +1024,13 @@ end values = mlir_type(TracedRArray{T,N}, rsize) indices = mlir_type(TracedRArray{Int32,N}, rsize) op = chlo.top_k(x.mlir_data; values, indices, k, location) + indices = add( + TracedRArray{Int32,N}((), MLIR.IR.result(op, 2), rsize), + constant(fill(Int32(1), Tuple(rsize))), + ) # return the 1-indexed index return (; values=TracedRArray{T,N}((), MLIR.IR.result(op, 1), rsize), - indices=TracedRArray{Int32,N}((), MLIR.IR.result(op, 2), rsize), + indices, ) end diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 1029a4f8b5..478cbde439 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -24,6 +24,8 @@ using GPUArraysCore: GPUArraysCore, @allowscalar ReactantCore.is_traced(::TracedRArray) = true +Base.strides(x::TracedRArray) = Base.size_to_strides(1, size(x)...) + function Base.convert(::Type{TracedRArray{T,N}}, x::AbstractArray) where {T,N} @assert ndims(x) == N if x isa TracedRArray @@ -510,7 +512,10 @@ function _copyto!(dest::AnyTracedRArray, bc::Broadcasted) args = (TracedUtils.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args) - res = TracedUtils.elem_apply(bc.f, args...) + res = TracedUtils.promote_to( + TracedRArray{unwrapped_eltype(dest),ndims(dest)}, + TracedUtils.elem_apply(bc.f, args...), + ) TracedUtils.set_mlir_data!(dest, res.mlir_data) return dest end @@ -743,15 +748,53 @@ function Base.partialsort!( !rev && (k = length(x) - k + 1) (; values, indices) = Ops.top_k(materialize_traced_array(by_x), k) by === identity && return @allowscalar values[k] - return @allowscalar x[indices[k] + 1] + return @allowscalar x[indices[k]] else klist = collect(Int64, k) !rev && (klist = length(x) .- klist .+ 1) maxk = maximum(klist) (; values, indices) = Ops.top_k(materialize_traced_array(by_x), maxk) by === identity && return values[klist] - return x[indices[klist] .+ 1] + return x[indices[klist]] + end +end + +function Base.argmin(x::AnyTracedRArray; kwargs...) + return argmax(Ops.negate(materialize_traced_array(x)); kwargs...) +end + +function Base.argmax(x::AnyTracedRVector) + (; indices) = Ops.top_k(materialize_traced_array(x), 1) + return @allowscalar indices[1] +end + +# To avoid scalar indexing and constructing an array of tuples, we return the linear index +# instead of the cartesian index +function Base.argmax(x::AnyTracedRArray{T,N}; dims::Integer) where {T,N} + strds = strides(x) + + if dims != N # chlo.top_k performs the operation along the last dimension + pdims = collect(Int64, 1:N) + pdims[dims] = N + pdims[N] = dims + pdims = Tuple(pdims) + x = permutedims(x, pdims) + end + (; indices) = Ops.top_k(materialize_traced_array(x), 1) + indices = Ops.convert(TracedRArray{Int64,N}, indices) + dims != N && (indices = permutedims(indices, invperm(pdims))) + + # Compute linear indices + iotas = [Ops.iota(Int64, [size(indices)...]; iota_dimension=i) for i in 1:N] + iotas[dims] = Ops.subtract(indices, Ops.constant(fill(Int64(1), size(indices)))) + linear_indices = Ops.constant(fill(Int64(1), size(indices))) + for d in 1:N + linear_indices = Ops.add( + linear_indices, + Ops.multiply(iotas[d], Ops.constant(fill(Int64(strds[d]), size(iotas[d])))), + ) end + return linear_indices end end From 3bd1e8c0d7e3392d2ecb4b55592dda1b4d97c61b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 14 Jan 2025 15:01:17 -0500 Subject: [PATCH 06/13] fix: general support for other kwargs --- src/Ops.jl | 5 +--- src/TracedRArray.jl | 68 +++++++++++++++++++++++++++++++++++++++------ 2 files changed, 61 insertions(+), 12 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index 2873501791..06108b15f3 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1028,10 +1028,7 @@ end TracedRArray{Int32,N}((), MLIR.IR.result(op, 2), rsize), constant(fill(Int32(1), Tuple(rsize))), ) # return the 1-indexed index - return (; - values=TracedRArray{T,N}((), MLIR.IR.result(op, 1), rsize), - indices, - ) + return (; values=TracedRArray{T,N}((), MLIR.IR.result(op, 1), rsize), indices) end @noinline function iota( diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 478cbde439..cd224c2282 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -694,7 +694,9 @@ function overloaded_stack(dims::Union{Integer,Colon}, xs) end # sort -Base.sort(x::AnyTracedRArray; kwargs...) = sort!(copy(x); kwargs...) +function Base.sort(x::AnyTracedRArray; alg=missing, order=missing, kwargs...) + return sort!(copy(x); alg, order, kwargs...) +end function Base.sort!( x::AnyTracedRArray; @@ -702,15 +704,21 @@ function Base.sort!( lt=isless, by=identity, rev::Bool=false, - kwargs..., # TODO: implement `order` and `alg` kwargs + alg=missing, + order=missing, ) + @assert alg === missing "Reactant doesn't support `alg` kwarg for `sort!`" + @assert order === missing "Reactant doesn't support `order` kwarg for `sort!`" + comparator = rev ? (a, b) -> !lt(by(a), by(b)) : (a, b) -> lt(by(a), by(b)) res = Ops.sort(materialize_traced_array(x); dimension=dims, comparator) set_mlir_data!(x, get_mlir_data(res)) return x end -Base.sortperm(x::AnyTracedRArray; kwargs...) = sortperm!(similar(x, Int), x; kwargs...) +function Base.sortperm(x::AnyTracedRArray; alg=missing, order=missing, kwargs...) + return sortperm!(similar(x, Int), x; alg, order, kwargs...) +end function Base.sortperm!( ix::AnyTracedRArray{Int,N}, @@ -719,8 +727,12 @@ function Base.sortperm!( lt=isless, by=identity, rev::Bool=false, - kwargs..., # TODO: implement `order` and `alg` kwargs + alg=missing, + order=missing, ) where {N} + @assert alg === missing "Reactant doesn't support `alg` kwarg for `sortperm!`" + @assert order === missing "Reactant doesn't support `order` kwarg for `sortperm!`" + comparator = rev ? (a, b, i1, i2) -> !lt(by(a), by(b)) : (a, b, i1, i2) -> lt(by(a), by(b)) idxs = Ops.constant(collect(LinearIndices(x))) @@ -743,19 +755,59 @@ function Base.partialsort!( # TODO: general `lt` support @assert lt === isless "Only `isless` is supported for now in `partialsort!`" + # XXX: If `maxk` is beyond a threshold should we emit a sort directly? + by_x = by.(x) + if k isa Integer + !rev && (k = length(x) - k + 1) + (; values, indices) = Ops.top_k(materialize_traced_array(by_x), k) + res = by === identity ? @allowscalar(values[k]) : @allowscalar(x[indices[k]]) + @allowscalar setindex!(ix, res, k) + return res + else + klist = collect(Int64, k) + !rev && (klist = length(x) .- klist .+ 1) + maxk = maximum(klist) + (; values, indices) = Ops.top_k(materialize_traced_array(by_x), maxk) + res = by === identity ? values[klist] : x[indices[klist]] + setindex!(ix, res, klist) + return res + end +end + +function Base.partialsortperm( + x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs... +) + return partialsortperm!(similar(x, Int), x, k; kwargs...) +end + +function Base.partialsortperm!( + ix::AnyTracedRVector{Int}, + x::AnyTracedRVector, + k::Union{Integer,OrdinalRange}; + by=identity, + rev::Bool=false, + lt=isless, +) + # TODO: general `lt` support + @assert lt === isless "Only `isless` is supported for now in `partialsortperm!`" + by_x = by.(x) + # XXX: If `maxk` is beyond a threshold should we emit a sort directly? if k isa Integer !rev && (k = length(x) - k + 1) (; values, indices) = Ops.top_k(materialize_traced_array(by_x), k) - by === identity && return @allowscalar values[k] - return @allowscalar x[indices[k]] + indices = Ops.convert(TracedRArray{Int64,1}, indices) + idx = @allowscalar indices[k] + @allowscalar setindex!(ix, idx, k) + return idx else klist = collect(Int64, k) !rev && (klist = length(x) .- klist .+ 1) maxk = maximum(klist) (; values, indices) = Ops.top_k(materialize_traced_array(by_x), maxk) - by === identity && return values[klist] - return x[indices[klist]] + indices = Ops.convert(TracedRArray{Int64,1}, indices) + setindex!(ix, indices[klist], klist) + return indices[klist] end end From 2af9a040e4665bf413c28955624dad117edbd1f8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 14 Jan 2025 15:16:49 -0500 Subject: [PATCH 07/13] feat: keep lazy indexing --- src/TracedRArray.jl | 43 ++++++++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index cd224c2282..a87bf7a7b7 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -777,11 +777,29 @@ end function Base.partialsortperm( x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs... ) - return partialsortperm!(similar(x, Int), x, k; kwargs...) + idxs = overloaded_partialsortperm(x, k; kwargs...) + k isa Integer && return @allowscalar idxs[k] + return view(idxs, k) end function Base.partialsortperm!( ix::AnyTracedRVector{Int}, + x::AnyTracedRVector, + k::Union{Integer,OrdinalRange}; + kwargs..., +) + idxs = overloaded_partialsortperm(x, k; kwargs...) + + if k isa Integer + @allowscalar setindex!(ix, idxs[k], k) + return idxs + else + setindex!(ix, idxs[k], k) + return view(ix, k) + end +end + +function overloaded_partialsortperm( x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; by=identity, @@ -791,24 +809,11 @@ function Base.partialsortperm!( # TODO: general `lt` support @assert lt === isless "Only `isless` is supported for now in `partialsortperm!`" - by_x = by.(x) - # XXX: If `maxk` is beyond a threshold should we emit a sort directly? - if k isa Integer - !rev && (k = length(x) - k + 1) - (; values, indices) = Ops.top_k(materialize_traced_array(by_x), k) - indices = Ops.convert(TracedRArray{Int64,1}, indices) - idx = @allowscalar indices[k] - @allowscalar setindex!(ix, idx, k) - return idx - else - klist = collect(Int64, k) - !rev && (klist = length(x) .- klist .+ 1) - maxk = maximum(klist) - (; values, indices) = Ops.top_k(materialize_traced_array(by_x), maxk) - indices = Ops.convert(TracedRArray{Int64,1}, indices) - setindex!(ix, indices[klist], klist) - return indices[klist] - end + # XXX: If `maxk` is beyond a threshold should we emit a sort directly? Or do a neg + !rev && (k = length(x) .- k .+ 1) + !(k isa Integer) && (k = maximum(k)) + (; indices) = Ops.top_k(materialize_traced_array(by.(x)), k) + return Ops.convert(TracedRArray{Int64,1}, indices) end function Base.argmin(x::AnyTracedRArray; kwargs...) From 02a1bfd1729da6fa780413e90b973fb11639c177 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 14 Jan 2025 20:50:41 -0500 Subject: [PATCH 08/13] feat: support lt and by by directly emitting sort --- src/TracedRArray.jl | 87 ++++++++++++++++++++------------------------- 1 file changed, 39 insertions(+), 48 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index a87bf7a7b7..a6d297bf27 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -697,6 +697,9 @@ end function Base.sort(x::AnyTracedRArray; alg=missing, order=missing, kwargs...) return sort!(copy(x); alg, order, kwargs...) end +function Base.sort(x::AnyTracedRVector; alg=missing, order=missing, kwargs...) + return sort!(copy(x); alg, order, dims=1, kwargs...) +end function Base.sort!( x::AnyTracedRArray; @@ -719,6 +722,9 @@ end function Base.sortperm(x::AnyTracedRArray; alg=missing, order=missing, kwargs...) return sortperm!(similar(x, Int), x; alg, order, kwargs...) end +function Base.sortperm(x::AnyTracedRVector; alg=missing, order=missing, kwargs...) + return sortperm!(similar(x, Int), x; alg, order, dims=1, kwargs...) +end function Base.sortperm!( ix::AnyTracedRArray{Int,N}, @@ -742,44 +748,52 @@ function Base.sortperm!( end function Base.partialsort(x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs...) - return partialsort!(copy(x), k; kwargs...) + values, _ = overloaded_partialsort(x, k; kwargs...) + k = k .- minimum(k) .+ 1 + return view(values, k) +end + +function Base.partialsort!(x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs...) + values, _ = overloaded_partialsort(x, k; kwargs...) + kget = k .- minimum(k) .+ 1 + val = @allowscalar(values[kget]) + @allowscalar setindex!(x, val, k) + return val end -function Base.partialsort!( +function overloaded_partialsort( x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; by=identity, rev::Bool=false, lt=isless, ) - # TODO: general `lt` support - @assert lt === isless "Only `isless` is supported for now in `partialsort!`" + if lt !== isless || by !== identity + comparator = + rev ? (a, b, i1, i2) -> !lt(by(a), by(b)) : (a, b, i1, i2) -> lt(by(a), by(b)) + idxs = Ops.constant(collect(LinearIndices(x))) + sorted_x, sorted_idxs = Ops.sort( + materialize_traced_array(x), idxs; dimension=1, comparator + ) + return sorted_x[1:maximum(k)], sorted_idxs[1:maximum(k)] + end # XXX: If `maxk` is beyond a threshold should we emit a sort directly? - by_x = by.(x) - if k isa Integer - !rev && (k = length(x) - k + 1) - (; values, indices) = Ops.top_k(materialize_traced_array(by_x), k) - res = by === identity ? @allowscalar(values[k]) : @allowscalar(x[indices[k]]) - @allowscalar setindex!(ix, res, k) - return res - else - klist = collect(Int64, k) - !rev && (klist = length(x) .- klist .+ 1) - maxk = maximum(klist) - (; values, indices) = Ops.top_k(materialize_traced_array(by_x), maxk) - res = by === identity ? values[klist] : x[indices[klist]] - setindex!(ix, res, klist) - return res + !rev && (k = length(x) .- k .+ 1) + !(k isa Integer) && (k = maximum(k)) + (; values, indices) = Ops.top_k(materialize_traced_array(x), k) + indices = Ops.convert(TracedRArray{Int64,1}, indices) + if !rev + values = Ops.reverse(values; dimensions=[1]) + indices = Ops.reverse(indices; dimensions=[1]) end + return values, indices end function Base.partialsortperm( x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs... ) - idxs = overloaded_partialsortperm(x, k; kwargs...) - k isa Integer && return @allowscalar idxs[k] - return view(idxs, k) + return view(overloaded_partialsort(x, k; kwargs...)[2], k) end function Base.partialsortperm!( @@ -788,32 +802,9 @@ function Base.partialsortperm!( k::Union{Integer,OrdinalRange}; kwargs..., ) - idxs = overloaded_partialsortperm(x, k; kwargs...) - - if k isa Integer - @allowscalar setindex!(ix, idxs[k], k) - return idxs - else - setindex!(ix, idxs[k], k) - return view(ix, k) - end -end - -function overloaded_partialsortperm( - x::AnyTracedRVector, - k::Union{Integer,OrdinalRange}; - by=identity, - rev::Bool=false, - lt=isless, -) - # TODO: general `lt` support - @assert lt === isless "Only `isless` is supported for now in `partialsortperm!`" - - # XXX: If `maxk` is beyond a threshold should we emit a sort directly? Or do a neg - !rev && (k = length(x) .- k .+ 1) - !(k isa Integer) && (k = maximum(k)) - (; indices) = Ops.top_k(materialize_traced_array(by.(x)), k) - return Ops.convert(TracedRArray{Int64,1}, indices) + _, idxs = overloaded_partialsort(x, k; kwargs...) + @allowscalar setindex!(ix, idxs[k], k) + return view(ix, k) end function Base.argmin(x::AnyTracedRArray; kwargs...) From 3ae42c85a66786797a20ee535abb21bc31dc1878 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 14 Jan 2025 21:44:48 -0500 Subject: [PATCH 09/13] feat: more argmin/argmax support and testing --- src/TracedRArray.jl | 29 +++++++++++++++++++++++++++-- test/ops.jl | 9 ++++++++- test/runtests.jl | 1 + test/sorting.jl | 41 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 77 insertions(+), 3 deletions(-) create mode 100644 test/sorting.jl diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index a6d297bf27..d9e1bf4c10 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -89,6 +89,17 @@ function scalar_index_to_cartesian(idx::AbstractVector{T}, sz::NTuple{N,Int}) wh return idxs end +function scalar_index_to_cartesian(idx::T, sz::NTuple{N,Int}) where {T <: Number, N} + idx = idx - 1 + idxs = (idx % T(sz[1]),) + idx = idx ÷ T(sz[1]) + for i in 2:N + idxs = (idxs..., idx % T(sz[i])) + idx = idx ÷ T(sz[i]) + end + return idxs +end + function Base.getindex( a::TracedRArray{T,N}, indices::Union{Int,TracedRNumber{Int}} ) where {T,N} @@ -807,18 +818,32 @@ function Base.partialsortperm!( return view(ix, k) end +function Base.argmin(f::F, x::AnyTracedRArray) where {F} + idx = scalar_index_to_cartesian(argmin(f.(x)), size(x)) .+ 1 + return @allowscalar x[idx...] +end + +function Base.argmax(f::F, x::AnyTracedRArray) where {F} + idx = scalar_index_to_cartesian(argmax(f.(x)), size(x)) .+ 1 + return @allowscalar x[idx...] +end + function Base.argmin(x::AnyTracedRArray; kwargs...) return argmax(Ops.negate(materialize_traced_array(x)); kwargs...) end function Base.argmax(x::AnyTracedRVector) (; indices) = Ops.top_k(materialize_traced_array(x), 1) - return @allowscalar indices[1] + return @allowscalar TracedRNumber{Int64}(indices[1]) end # To avoid scalar indexing and constructing an array of tuples, we return the linear index # instead of the cartesian index -function Base.argmax(x::AnyTracedRArray{T,N}; dims::Integer) where {T,N} +function Base.argmax( + x::AnyTracedRArray{T,N}; dims::Union{Integer,Nothing}=nothing +) where {T,N} + dims === nothing && return argmax(vec(x)) + strds = strides(x) if dims != N # chlo.top_k performs the operation along the last dimension diff --git a/test/ops.jl b/test/ops.jl index 0223eeed97..4f735ee6fe 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -916,7 +916,14 @@ end @testset "top_k" begin x = ConcreteRArray([1, 2, 3, 4]) - @test (; values=[4, 3], indices=[3, 2]) == @jit Ops.top_k(x, 2) + @test (; values=[4, 3], indices=[4, 3]) == @jit Ops.top_k(x, 2) + + x = ConcreteRArray([NaN, 123, 456, 789, 121]) + res = @jit Ops.top_k(x, 2) + true_res = (; values=[NaN, 789], indices=[1, 4]) + @test res.indices == true_res.indices + @test @allowscalar isnan(res.values[1]) + @test @allowscalar res.values[2] == 789 end @testset "zeta" begin diff --git a/test/runtests.jl b/test/runtests.jl index 8cc161fef9..7d188fe3dd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -56,6 +56,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Shortcuts to MLIR ops" include("ops.jl") @safetestset "Wrapped Arrays" include("wrapped_arrays.jl") @safetestset "Control Flow" include("control_flow.jl") + @safetestset "Sorting" include("sorting.jl") end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" diff --git a/test/sorting.jl b/test/sorting.jl new file mode 100644 index 0000000000..151c73beb1 --- /dev/null +++ b/test/sorting.jl @@ -0,0 +1,41 @@ +using Reactant, Test + +@testset "sort" begin end + +@testset "sortperm" begin end + +@testset "partialsort" begin end + +@testset "partialsortperm" begin end + +@testset "argmin / argmax" begin + x = rand(2, 3) + x_ra = Reactant.to_rarray(x) + + linargmin(x) = LinearIndices(x)[argmin(x)] + linargmax(x) = LinearIndices(x)[argmax(x)] + + @test linargmin(x) == @jit(argmin(x_ra)) + @test linargmax(x) == @jit(argmax(x_ra)) + + x = rand(2, 3, 4) + x_ra = Reactant.to_rarray(x) + + linargmin(x, d) = LinearIndices(x)[argmin(x; dims=d)] + linargmax(x, d) = LinearIndices(x)[argmax(x; dims=d)] + argmindims(x, d) = argmin(x; dims=d) + argmaxdims(x, d) = argmax(x; dims=d) + + @test linargmin(x, 1) == @jit(argmindims(x_ra, 1)) + @test linargmax(x, 1) == @jit(argmaxdims(x_ra, 1)) + @test linargmin(x, 2) == @jit(argmindims(x_ra, 2)) + @test linargmax(x, 2) == @jit(argmaxdims(x_ra, 2)) + @test linargmin(x, 3) == @jit(argmindims(x_ra, 3)) + @test linargmax(x, 3) == @jit(argmaxdims(x_ra, 3)) + + x = randn(2, 3, 4) + x_ra = Reactant.to_rarray(x) + + @test argmin(abs2, x) == @jit(argmin(abs2, x_ra)) + @test argmax(abs2, x) == @jit(argmax(abs2, x_ra)) +end From 9031e4ac61eb2506de088dfca31b0512c82811c6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 14 Jan 2025 22:45:50 -0500 Subject: [PATCH 10/13] fix: always return tuple from sort --- src/Ops.jl | 4 +--- src/TracedRArray.jl | 2 +- test/ops.jl | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index 06108b15f3..04f393152c 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1008,13 +1008,11 @@ end comparator, location, ) - res = [ + return [ TracedRArray{Reactant.unwrapped_eltype(xs[i]),ndims(xs[i])}( (), MLIR.IR.result(op, i), size(xs[i]) ) for i in eachindex(xs) ] - length(res) == 1 && return only(res) # Kept for backwards compatibility - return res end @noinline function top_k( diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index d9e1bf4c10..c33a96f28a 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -725,7 +725,7 @@ function Base.sort!( @assert order === missing "Reactant doesn't support `order` kwarg for `sort!`" comparator = rev ? (a, b) -> !lt(by(a), by(b)) : (a, b) -> lt(by(a), by(b)) - res = Ops.sort(materialize_traced_array(x); dimension=dims, comparator) + res = only(Ops.sort(materialize_traced_array(x); dimension=dims, comparator)) set_mlir_data!(x, get_mlir_data(res)) return x end diff --git a/test/ops.jl b/test/ops.jl index 4f735ee6fe..45ba719165 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -689,7 +689,7 @@ end end @testset "sort" begin - basic_sort(x, dimension) = Ops.sort(x; comparator=(a, b) -> a < b, dimension) + basic_sort(x, dimension) = only(Ops.sort(x; comparator=(a, b) -> a < b, dimension)) @testset for i in 1:3 t_size = tuple(fill(10, (i,))...) x = randn(t_size) From b92d65ba6b8e49574eed79691a4065059a12af27 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 14 Jan 2025 23:44:02 -0500 Subject: [PATCH 11/13] feat: findmin/findmax/findlast/findfirst --- src/Ops.jl | 23 ++++++++++- src/TracedRArray.jl | 93 ++++++++++++++++++++++++++++++++------------ src/TracedRNumber.jl | 13 +++++++ 3 files changed, 102 insertions(+), 27 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index 04f393152c..8a2c15798f 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1016,8 +1016,19 @@ end end @noinline function top_k( - x::TracedRArray{T,N}, k; location=mlir_stacktrace("top_k", @__FILE__, @__LINE__) + x::TracedRArray{T,N}, + k; + dimension::Integer=N, + location=mlir_stacktrace("top_k", @__FILE__, @__LINE__), ) where {T,N} + @assert 1 <= dimension <= N + if dimension != N # chlo.top_k performs the operation along the last dimension + pdims = collect(Int64, 1:N) + pdims[dimension] = N + pdims[N] = dimension + x = permutedims(x, pdims) + end + rsize = [size(x)[1:(end - 1)]..., k] values = mlir_type(TracedRArray{T,N}, rsize) indices = mlir_type(TracedRArray{Int32,N}, rsize) @@ -1026,7 +1037,15 @@ end TracedRArray{Int32,N}((), MLIR.IR.result(op, 2), rsize), constant(fill(Int32(1), Tuple(rsize))), ) # return the 1-indexed index - return (; values=TracedRArray{T,N}((), MLIR.IR.result(op, 1), rsize), indices) + indices = convert(TracedRArray{Int64,N}, indices) # julia indexes with Int64 generally + values = TracedRArray{T,N}((), MLIR.IR.result(op, 1), rsize) + + if dimension != N + values = permutedims(values, invperm(pdims)) + indices = permutedims(indices, invperm(pdims)) + end + + return (; values, indices) end @noinline function iota( diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index c33a96f28a..b1d0f72e3b 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -89,7 +89,7 @@ function scalar_index_to_cartesian(idx::AbstractVector{T}, sz::NTuple{N,Int}) wh return idxs end -function scalar_index_to_cartesian(idx::T, sz::NTuple{N,Int}) where {T <: Number, N} +function scalar_index_to_cartesian(idx::T, sz::NTuple{N,Int}) where {T<:Number,N} idx = idx - 1 idxs = (idx % T(sz[1]),) idx = idx ÷ T(sz[1]) @@ -793,7 +793,6 @@ function overloaded_partialsort( !rev && (k = length(x) .- k .+ 1) !(k isa Integer) && (k = maximum(k)) (; values, indices) = Ops.top_k(materialize_traced_array(x), k) - indices = Ops.convert(TracedRArray{Int64,1}, indices) if !rev values = Ops.reverse(values; dimensions=[1]) indices = Ops.reverse(indices; dimensions=[1]) @@ -818,6 +817,7 @@ function Base.partialsortperm!( return view(ix, k) end +# arg* functions function Base.argmin(f::F, x::AnyTracedRArray) where {F} idx = scalar_index_to_cartesian(argmin(f.(x)), size(x)) .+ 1 return @allowscalar x[idx...] @@ -828,46 +828,89 @@ function Base.argmax(f::F, x::AnyTracedRArray) where {F} return @allowscalar x[idx...] end -function Base.argmin(x::AnyTracedRArray; kwargs...) - return argmax(Ops.negate(materialize_traced_array(x)); kwargs...) +Base.argmin(x::AnyTracedRArray; kwargs...) = findmin(identity, x; kwargs...)[2] +Base.argmax(x::AnyTracedRArray; kwargs...) = findmax(identity, x; kwargs...)[2] + +# find* functions +Base.findfirst(x::AnyTracedRArray) = findfirst(identity, x) +Base.findlast(x::AnyTracedRArray) = findlast(identity, x) + +function Base.findfirst(f::Function, x::AnyTracedRArray) + fA = f.(x) + (; indices) = Ops.top_k(materialize_traced_array(fA), 1) + return @allowscalar indices[1] end -function Base.argmax(x::AnyTracedRVector) - (; indices) = Ops.top_k(materialize_traced_array(x), 1) - return @allowscalar TracedRNumber{Int64}(indices[1]) +function Base.findlast(f::Function, x::AnyTracedRArray) + fA = Ops.reverse(materialize_traced_array(vec(f.(x))); dimensions=[1]) + (; indices) = Ops.top_k(materialize_traced_array(fA), 1) + return length(x) - @allowscalar(indices[1]) + 1 end -# To avoid scalar indexing and constructing an array of tuples, we return the linear index -# instead of the cartesian index -function Base.argmax( - x::AnyTracedRArray{T,N}; dims::Union{Integer,Nothing}=nothing -) where {T,N} - dims === nothing && return argmax(vec(x)) +Base.findmin(x::AnyTracedRVector) = findmin(identity, x; dims=1) +function Base.findmin(x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing) + return findmin(identity, x; dims) +end + +Base.findmax(x::AnyTracedRVector) = findmax(identity, x; dims=1) +function Base.findmax(x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing) + return findmax(identity, x; dims) +end + +## To avoid scalar indexing and constructing an array of tuples, we return the linear index +## instead of the cartesian index +function Base.findmin(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing) + if dims === nothing + if ndims(x) == 1 + dims = 1 + else + return findmin(f, vec(x); dims=1) + end + end + + fx = Ops.negate(materialize_traced_array(f.(x))) + (; values, indices) = Ops.top_k(fx, 1; dimension=dims) + # Compute linear indices strds = strides(x) + iotas = [Ops.iota(Int64, [size(indices)...]; iota_dimension=i) for i in 1:ndims(x)] + iotas[dims] = Ops.subtract(indices, Ops.constant(fill(Int64(1), size(indices)))) + linear_indices = Ops.constant(fill(Int64(1), size(indices))) + for d in eachindex(iotas) + linear_indices = Ops.add( + linear_indices, + Ops.multiply(iotas[d], Ops.constant(fill(Int64(strds[d]), size(iotas[d])))), + ) + end + + return (Ops.negate(values), linear_indices) +end - if dims != N # chlo.top_k performs the operation along the last dimension - pdims = collect(Int64, 1:N) - pdims[dims] = N - pdims[N] = dims - pdims = Tuple(pdims) - x = permutedims(x, pdims) +function Base.findmax(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing) + if dims === nothing + if ndims(x) == 1 + dims = 1 + else + return findmax(f, vec(x); dims=1) + end end - (; indices) = Ops.top_k(materialize_traced_array(x), 1) - indices = Ops.convert(TracedRArray{Int64,N}, indices) - dims != N && (indices = permutedims(indices, invperm(pdims))) + + fx = materialize_traced_array(f.(x)) + (; values, indices) = Ops.top_k(fx, 1; dimension=dims) # Compute linear indices - iotas = [Ops.iota(Int64, [size(indices)...]; iota_dimension=i) for i in 1:N] + strds = strides(x) + iotas = [Ops.iota(Int64, [size(indices)...]; iota_dimension=i) for i in 1:ndims(x)] iotas[dims] = Ops.subtract(indices, Ops.constant(fill(Int64(1), size(indices)))) linear_indices = Ops.constant(fill(Int64(1), size(indices))) - for d in 1:N + for d in eachindex(iotas) linear_indices = Ops.add( linear_indices, Ops.multiply(iotas[d], Ops.constant(fill(Int64(strds[d]), size(iotas[d])))), ) end - return linear_indices + + return (values, linear_indices) end end diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 365728f5de..8f01d4bdb9 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -247,9 +247,22 @@ Base.conj(x::TracedRNumber{<:Complex}) = Ops.conj(x) Base.real(x::TracedRNumber) = x Base.real(x::TracedRNumber{<:Complex}) = Ops.real(x) +Base.isreal(::TracedRNumber) = false +Base.isreal(::TracedRNumber{<:Real}) = true + Base.imag(x::TracedRNumber) = zero(x) Base.imag(x::TracedRNumber{<:Complex}) = Ops.imag(x) +Base.iseven(x::TracedRNumber) = iseven(real(x)) +function Base.iseven(x::TracedRNumber{<:Real}) + return iszero( + rem( + TracedUtils.promote_to(TracedRNumber{Int}, x), + TracedUtils.promote_to(TracedRNumber{Int}, 2), + ), + ) +end + for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRNumber)) @eval Base.clamp(x::TracedRNumber, min::$(minT), max::$(maxT)) = Ops.clamp(min, x, max) end From c98e2c3b5e66456d13c28c56cc473b9e73ef60eb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 15 Jan 2025 11:57:51 -0500 Subject: [PATCH 12/13] fix: more tests and fixes for find functions --- src/TracedRArray.jl | 11 ++++++---- test/sorting.jl | 51 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index b1d0f72e3b..e6bea2be8b 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -836,14 +836,14 @@ Base.findfirst(x::AnyTracedRArray) = findfirst(identity, x) Base.findlast(x::AnyTracedRArray) = findlast(identity, x) function Base.findfirst(f::Function, x::AnyTracedRArray) - fA = f.(x) - (; indices) = Ops.top_k(materialize_traced_array(fA), 1) + fA = materialize_traced_array(vec(f.(x))) + (; indices) = Ops.top_k(fA, 1) return @allowscalar indices[1] end function Base.findlast(f::Function, x::AnyTracedRArray) fA = Ops.reverse(materialize_traced_array(vec(f.(x))); dimensions=[1]) - (; indices) = Ops.top_k(materialize_traced_array(fA), 1) + (; indices) = Ops.top_k(fA, 1) return length(x) - @allowscalar(indices[1]) + 1 end @@ -883,7 +883,9 @@ function Base.findmin(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothin ) end - return (Ops.negate(values), linear_indices) + values = Ops.negate(values) + ndims(x) == 1 && return @allowscalar (values[1], linear_indices[1]) + return (values, linear_indices) end function Base.findmax(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing) @@ -910,6 +912,7 @@ function Base.findmax(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothin ) end + ndims(x) == 1 && return @allowscalar (values[1], linear_indices[1]) return (values, linear_indices) end diff --git a/test/sorting.jl b/test/sorting.jl index 151c73beb1..f1881ba78a 100644 --- a/test/sorting.jl +++ b/test/sorting.jl @@ -39,3 +39,54 @@ using Reactant, Test @test argmin(abs2, x) == @jit(argmin(abs2, x_ra)) @test argmax(abs2, x) == @jit(argmax(abs2, x_ra)) end + +@testset "findmin / findmax" begin + xvec = randn(10) + xvec_ra = Reactant.to_rarray(xvec) + + x = randn(2, 3) + x_ra = Reactant.to_rarray(x) + + function fwithlinindices(g, f, x; kwargs...) + values, indices = g(f, x; kwargs...) + return values, LinearIndices(x)[indices] + end + + @test fwithlinindices(findmin, identity, x) == @jit(findmin(x_ra)) + @test fwithlinindices(findmax, identity, x) == @jit(findmax(x_ra)) + @test fwithlinindices(findmin, identity, xvec) == @jit(findmin(xvec_ra)) + @test fwithlinindices(findmax, identity, xvec) == @jit(findmax(xvec_ra)) + + fmindims(x, d) = findmin(x; dims=d) + fmindims(f, x, d) = findmin(f, x; dims=d) + fmaxdims(x, d) = findmax(x; dims=d) + fmaxdims(f, x, d) = findmax(f, x; dims=d) + + @test fwithlinindices(findmin, identity, x; dims=1) == @jit(fmindims(x_ra, 1)) + @test fwithlinindices(findmax, identity, x; dims=1) == @jit(fmaxdims(x_ra, 1)) + @test fwithlinindices(findmin, identity, x; dims=2) == @jit(fmindims(x_ra, 2)) + @test fwithlinindices(findmax, identity, x; dims=2) == @jit(fmaxdims(x_ra, 2)) + @test fwithlinindices(findmin, abs2, x; dims=1) == @jit(fmindims(abs2, x_ra, 1)) + @test fwithlinindices(findmax, abs2, x; dims=1) == @jit(fmaxdims(abs2, x_ra, 1)) + @test fwithlinindices(findmin, abs2, x; dims=2) == @jit(fmindims(abs2, x_ra, 2)) + @test fwithlinindices(findmax, abs2, x; dims=2) == @jit(fmaxdims(abs2, x_ra, 2)) +end + +@testset "findfirst / findlast" begin + x = rand(Bool, 3, 4) + x_ra = Reactant.to_rarray(x) + + ffirstlinindices(x) = LinearIndices(x)[findfirst(x)] + ffirstlinindices(f, x) = LinearIndices(x)[findfirst(f, x)] + flastlinindices(x) = LinearIndices(x)[findlast(x)] + flastlinindices(f, x) = LinearIndices(x)[findlast(f, x)] + + @test ffirstlinindices(x) == @jit(findfirst(x_ra)) + @test flastlinindices(x) == @jit(findlast(x_ra)) + + x = rand(1:256, 3, 4) + x_ra = Reactant.to_rarray(x) + + @test ffirstlinindices(iseven, x) == @jit(findfirst(iseven, x_ra)) + @test flastlinindices(iseven, x) == @jit(findlast(iseven, x_ra)) +end From dafa186c5cfbe81d11fce37e1f5e6829cd292477 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 15 Jan 2025 14:35:16 -0500 Subject: [PATCH 13/13] test: sort and partial sort functions --- src/TracedRArray.jl | 58 +++++++++++++++--------- test/sorting.jl | 106 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 140 insertions(+), 24 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index e6bea2be8b..2c135c4e38 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -714,13 +714,18 @@ end function Base.sort!( x::AnyTracedRArray; - dims::Integer, + dims::Union{Integer,Nothing}=nothing, lt=isless, by=identity, rev::Bool=false, alg=missing, order=missing, ) + if dims === nothing + @assert ndims(x) == 1 + dims = 1 + end + @assert alg === missing "Reactant doesn't support `alg` kwarg for `sort!`" @assert order === missing "Reactant doesn't support `order` kwarg for `sort!`" @@ -740,13 +745,18 @@ end function Base.sortperm!( ix::AnyTracedRArray{Int,N}, x::AnyTracedRArray{<:Any,N}; - dims::Integer, + dims::Union{Integer,Nothing}=nothing, lt=isless, by=identity, rev::Bool=false, alg=missing, order=missing, ) where {N} + if dims === nothing + @assert ndims(x) == 1 + dims = 1 + end + @assert alg === missing "Reactant doesn't support `alg` kwarg for `sortperm!`" @assert order === missing "Reactant doesn't support `order` kwarg for `sortperm!`" @@ -761,6 +771,7 @@ end function Base.partialsort(x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs...) values, _ = overloaded_partialsort(x, k; kwargs...) k = k .- minimum(k) .+ 1 + k isa Integer && return @allowscalar(values[k]) return view(values, k) end @@ -769,7 +780,31 @@ function Base.partialsort!(x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kget = k .- minimum(k) .+ 1 val = @allowscalar(values[kget]) @allowscalar setindex!(x, val, k) - return val + k isa Integer && return val + return view(x, k) +end + +function Base.partialsortperm( + x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs... +) + idxs = overloaded_partialsort(x, k; kwargs...)[2] + k = k .- minimum(k) .+ 1 + k isa Integer && return @allowscalar(idxs[k]) + return view(idxs, k) +end + +function Base.partialsortperm!( + ix::AnyTracedRVector{Int}, + x::AnyTracedRVector, + k::Union{Integer,OrdinalRange}; + kwargs..., +) + _, idxs = overloaded_partialsort(x, k; kwargs...) + kget = k .- minimum(k) .+ 1 + val = @allowscalar(idxs[kget]) + @allowscalar setindex!(ix, val, k) + k isa Integer && return val + return view(ix, k) end function overloaded_partialsort( @@ -800,23 +835,6 @@ function overloaded_partialsort( return values, indices end -function Base.partialsortperm( - x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs... -) - return view(overloaded_partialsort(x, k; kwargs...)[2], k) -end - -function Base.partialsortperm!( - ix::AnyTracedRVector{Int}, - x::AnyTracedRVector, - k::Union{Integer,OrdinalRange}; - kwargs..., -) - _, idxs = overloaded_partialsort(x, k; kwargs...) - @allowscalar setindex!(ix, idxs[k], k) - return view(ix, k) -end - # arg* functions function Base.argmin(f::F, x::AnyTracedRArray) where {F} idx = scalar_index_to_cartesian(argmin(f.(x)), size(x)) .+ 1 diff --git a/test/sorting.jl b/test/sorting.jl index f1881ba78a..d54699fa57 100644 --- a/test/sorting.jl +++ b/test/sorting.jl @@ -1,12 +1,110 @@ using Reactant, Test -@testset "sort" begin end +@testset "sort & sortperm" begin + x = randn(10) + x_ra = Reactant.to_rarray(x) -@testset "sortperm" begin end + srt_rev(x) = sort(x; rev=true) + srtperm_rev(x) = sortperm(x; rev=true) + srt_by(x) = sort(x; by=abs2) + srtperm_by(x) = sortperm(x; by=abs2) + srt_lt(x) = sort(x; lt=(a, b) -> a > b) + srtperm_lt(x) = sortperm(x; lt=(a, b) -> a > b) + + @test @jit(sort(x_ra)) == sort(x) + @test @jit(srt_rev(x_ra)) == srt_rev(x) + @test @jit(srt_lt(x_ra)) == srt_lt(x) + @test @jit(srt_by(x_ra)) == srt_by(x) + @test @jit(sortperm(x_ra)) == sortperm(x) + @test @jit(srtperm_rev(x_ra)) == srtperm_rev(x) + @test @jit(srtperm_lt(x_ra)) == srtperm_lt(x) + @test @jit(srtperm_by(x_ra)) == srtperm_by(x) + + x = rand(10) + x_ra = Reactant.to_rarray(x) + @jit sort!(x_ra) + @test x_ra == sort(x) -@testset "partialsort" begin end + x = rand(10) + x_ra = Reactant.to_rarray(x) + ix = similar(x_ra, Int) + @jit sortperm!(ix, x_ra) + @test ix == sortperm(x) + + x = rand(10, 4, 3) + x_ra = Reactant.to_rarray(x) -@testset "partialsortperm" begin end + srt(x, d) = sort(x; dims=d) + srt_rev(x, d) = sort(x; dims=d, rev=true) + srt_by(x, d) = sort(x; dims=d, by=abs2) + srt_lt(x, d) = sort(x; dims=d, lt=(a, b) -> a > b) + srtperm(x, d) = sortperm(x; dims=d) + srtperm_rev(x, d) = sortperm(x; dims=d, rev=true) + srtperm_by(x, d) = sortperm(x; dims=d, by=abs2) + srtperm_lt(x, d) = sortperm(x; dims=d, lt=(a, b) -> a > b) + + @testset for d in 1:ndims(x) + @test @jit(srt(x_ra, d)) == srt(x, d) + @test @jit(srtperm(x_ra, d)) == srtperm(x, d) + @test @jit(srt_rev(x_ra, d)) == srt_rev(x, d) + @test @jit(srtperm_rev(x_ra, d)) == srtperm_rev(x, d) + @test @jit(srt_by(x_ra, d)) == srt_by(x, d) + @test @jit(srtperm_by(x_ra, d)) == srtperm_by(x, d) + @test @jit(srt_lt(x_ra, d)) == srt_lt(x, d) + @test @jit(srtperm_lt(x_ra, d)) == srtperm_lt(x, d) + end +end + +@testset "partialsort & partialsortperm" begin + x = randn(10) + x_ra = Reactant.to_rarray(x) + + @test @jit(partialsort(x_ra, 1:5)) == partialsort(x, 1:5) + @test @jit(partialsortperm(x_ra, 1:5)) == partialsortperm(x, 1:5) + @test @jit(partialsort(x_ra, 4)) == partialsort(x, 4) + @test @jit(partialsortperm(x_ra, 4)) == partialsortperm(x, 4) + + psrt_rev(x, k) = partialsort(x, k; rev=true) + psrtperm_rev(x, k) = partialsortperm(x, k; rev=true) + psrt_by(x, k) = partialsort(x, k; by=abs2) + psrtperm_by(x, k) = partialsortperm(x, k; by=abs2) + psrt_lt(x, k) = partialsort(x, k; lt=(a, b) -> a > b) + psrtperm_lt(x, k) = partialsortperm(x, k; lt=(a, b) -> a > b) + + @test @jit(psrt_rev(x_ra, 1:5)) == psrt_rev(x, 1:5) + @test @jit(psrtperm_rev(x_ra, 1:5)) == psrtperm_rev(x, 1:5) + @test @jit(psrt_by(x_ra, 1:5)) == psrt_by(x, 1:5) + @test @jit(psrtperm_by(x_ra, 1:5)) == psrtperm_by(x, 1:5) + @test @jit(psrt_lt(x_ra, 1:5)) == psrt_lt(x, 1:5) + @test @jit(psrtperm_lt(x_ra, 1:5)) == psrtperm_lt(x, 1:5) + + x = randn(10) + x_ra = Reactant.to_rarray(x) + @jit partialsort!(x_ra, 1:5) + partialsort!(x, 1:5) + @test Array(x_ra)[1:5] == x[1:5] + + x = randn(10) + x_ra = Reactant.to_rarray(x) + @jit partialsort!(x_ra, 3) + partialsort!(x, 3) + @test @allowscalar(x_ra[3]) == x[3] + + x = randn(10) + x_ra = Reactant.to_rarray(x) + + ix = similar(x, Int) + ix_ra = Reactant.to_rarray(ix) + @jit partialsortperm!(ix_ra, x_ra, 1:5) + partialsortperm!(ix, x, 1:5) + @test Array(ix_ra)[1:5] == ix[1:5] + + ix = similar(x, Int) + ix_ra = Reactant.to_rarray(ix) + @jit partialsortperm!(ix_ra, x_ra, 3) + partialsortperm!(ix, x, 3) + @test @allowscalar(ix_ra[3]) == ix[3] +end @testset "argmin / argmax" begin x = rand(2, 3)