diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 20f4d6a83c..55fa038396 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -28,6 +28,8 @@ function Base.rtoldefault(::Type{ConcreteRNumber{T}}) where {T} return ConcreteRNumber(Base.rtoldefault(T)) end +Base.strides(x::ConcreteRArray) = Base.size_to_strides(1, size(x)...) + # Ensure the device and client are the same as the input function Base.float(x::ConcreteRNumber{T}) where {T} client = XLA.client(x.data) diff --git a/src/Ops.jl b/src/Ops.jl index 0d94cb75d4..8a2c15798f 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -956,19 +956,26 @@ function broadcast_in_dim( end @noinline function sort( - x::TracedRArray{T,N}; + xs::TracedRArray...; comparator, dimension=1, is_stable=false, location=mlir_stacktrace("sort", @__FILE__, @__LINE__), -) where {T,N} +) #C4: - @assert 0 < dimension <= ndims(x) "$x invalid dimension" + for x in xs + @assert 0 < dimension <= ndims(x) "$x invalid dimension" + end - (a, b) = (Reactant.ConcreteRNumber(T(0)), Reactant.ConcreteRNumber(T(0))) + sample_inputs = Vector{Reactant.ConcreteRNumber}(undef, length(xs) * 2) + for i in eachindex(xs) + T = Reactant.unwrapped_eltype(xs[i]) + sample_inputs[2i - 1] = Reactant.ConcreteRNumber(T(0)) + sample_inputs[2i] = Reactant.ConcreteRNumber(T(0)) + end func = Reactant.TracedUtils.make_mlir_fn( comparator, - (a, b), + (sample_inputs...,), (), "comparator"; no_args_in_result=true, @@ -993,30 +1000,52 @@ end dimension = MLIR.IR.Attribute(dimension - 1) is_stable = MLIR.IR.Attribute(is_stable) - res = MLIR.IR.result( - stablehlo.sort( - [x.mlir_data]; - result_0=[mlir_type(TracedRArray{T,N}, size(x))], - dimension, - is_stable, - comparator, - location, - ), + op = stablehlo.sort( + [x.mlir_data for x in xs]; + result_0=[mlir_type(typeof(x), size(x)) for x in xs], + dimension, + is_stable, + comparator, + location, ) - return TracedRArray{T,N}((), res, size(x)) + return [ + TracedRArray{Reactant.unwrapped_eltype(xs[i]),ndims(xs[i])}( + (), MLIR.IR.result(op, i), size(xs[i]) + ) for i in eachindex(xs) + ] end @noinline function top_k( - x::TracedRArray{T,N}, k; location=mlir_stacktrace("top_k", @__FILE__, @__LINE__) + x::TracedRArray{T,N}, + k; + dimension::Integer=N, + location=mlir_stacktrace("top_k", @__FILE__, @__LINE__), ) where {T,N} + @assert 1 <= dimension <= N + if dimension != N # chlo.top_k performs the operation along the last dimension + pdims = collect(Int64, 1:N) + pdims[dimension] = N + pdims[N] = dimension + x = permutedims(x, pdims) + end + rsize = [size(x)[1:(end - 1)]..., k] values = mlir_type(TracedRArray{T,N}, rsize) indices = mlir_type(TracedRArray{Int32,N}, rsize) op = chlo.top_k(x.mlir_data; values, indices, k, location) - return (; - values=TracedRArray{T,N}((), MLIR.IR.result(op, 1), rsize), - indices=TracedRArray{Int32,N}((), MLIR.IR.result(op, 2), rsize), - ) + indices = add( + TracedRArray{Int32,N}((), MLIR.IR.result(op, 2), rsize), + constant(fill(Int32(1), Tuple(rsize))), + ) # return the 1-indexed index + indices = convert(TracedRArray{Int64,N}, indices) # julia indexes with Int64 generally + values = TracedRArray{T,N}((), MLIR.IR.result(op, 1), rsize) + + if dimension != N + values = permutedims(values, invperm(pdims)) + indices = permutedims(indices, invperm(pdims)) + end + + return (; values, indices) end @noinline function iota( diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 8b7835bcd5..2c135c4e38 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -10,6 +10,7 @@ using ..Reactant: ReactantPrimitive, WrappedTracedRArray, AnyTracedRArray, + AnyTracedRVector, Ops, MLIR, ancestor, @@ -19,10 +20,12 @@ using ..Reactant: using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_traced_array using ReactantCore: ReactantCore -using GPUArraysCore: GPUArraysCore +using GPUArraysCore: GPUArraysCore, @allowscalar ReactantCore.is_traced(::TracedRArray) = true +Base.strides(x::TracedRArray) = Base.size_to_strides(1, size(x)...) + function Base.convert(::Type{TracedRArray{T,N}}, x::AbstractArray) where {T,N} @assert ndims(x) == N if x isa TracedRArray @@ -86,6 +89,17 @@ function scalar_index_to_cartesian(idx::AbstractVector{T}, sz::NTuple{N,Int}) wh return idxs end +function scalar_index_to_cartesian(idx::T, sz::NTuple{N,Int}) where {T<:Number,N} + idx = idx - 1 + idxs = (idx % T(sz[1]),) + idx = idx ÷ T(sz[1]) + for i in 2:N + idxs = (idxs..., idx % T(sz[i])) + idx = idx ÷ T(sz[i]) + end + return idxs +end + function Base.getindex( a::TracedRArray{T,N}, indices::Union{Int,TracedRNumber{Int}} ) where {T,N} @@ -509,7 +523,10 @@ function _copyto!(dest::AnyTracedRArray, bc::Broadcasted) args = (TracedUtils.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args) - res = TracedUtils.elem_apply(bc.f, args...) + res = TracedUtils.promote_to( + TracedRArray{unwrapped_eltype(dest),ndims(dest)}, + TracedUtils.elem_apply(bc.f, args...), + ) TracedUtils.set_mlir_data!(dest, res.mlir_data) return dest end @@ -687,4 +704,234 @@ function overloaded_stack(dims::Union{Integer,Colon}, xs) return cat(res...; dims) end +# sort +function Base.sort(x::AnyTracedRArray; alg=missing, order=missing, kwargs...) + return sort!(copy(x); alg, order, kwargs...) +end +function Base.sort(x::AnyTracedRVector; alg=missing, order=missing, kwargs...) + return sort!(copy(x); alg, order, dims=1, kwargs...) +end + +function Base.sort!( + x::AnyTracedRArray; + dims::Union{Integer,Nothing}=nothing, + lt=isless, + by=identity, + rev::Bool=false, + alg=missing, + order=missing, +) + if dims === nothing + @assert ndims(x) == 1 + dims = 1 + end + + @assert alg === missing "Reactant doesn't support `alg` kwarg for `sort!`" + @assert order === missing "Reactant doesn't support `order` kwarg for `sort!`" + + comparator = rev ? (a, b) -> !lt(by(a), by(b)) : (a, b) -> lt(by(a), by(b)) + res = only(Ops.sort(materialize_traced_array(x); dimension=dims, comparator)) + set_mlir_data!(x, get_mlir_data(res)) + return x +end + +function Base.sortperm(x::AnyTracedRArray; alg=missing, order=missing, kwargs...) + return sortperm!(similar(x, Int), x; alg, order, kwargs...) +end +function Base.sortperm(x::AnyTracedRVector; alg=missing, order=missing, kwargs...) + return sortperm!(similar(x, Int), x; alg, order, dims=1, kwargs...) +end + +function Base.sortperm!( + ix::AnyTracedRArray{Int,N}, + x::AnyTracedRArray{<:Any,N}; + dims::Union{Integer,Nothing}=nothing, + lt=isless, + by=identity, + rev::Bool=false, + alg=missing, + order=missing, +) where {N} + if dims === nothing + @assert ndims(x) == 1 + dims = 1 + end + + @assert alg === missing "Reactant doesn't support `alg` kwarg for `sortperm!`" + @assert order === missing "Reactant doesn't support `order` kwarg for `sortperm!`" + + comparator = + rev ? (a, b, i1, i2) -> !lt(by(a), by(b)) : (a, b, i1, i2) -> lt(by(a), by(b)) + idxs = Ops.constant(collect(LinearIndices(x))) + _, res = Ops.sort(materialize_traced_array(x), idxs; dimension=dims, comparator) + set_mlir_data!(ix, get_mlir_data(res)) + return ix +end + +function Base.partialsort(x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs...) + values, _ = overloaded_partialsort(x, k; kwargs...) + k = k .- minimum(k) .+ 1 + k isa Integer && return @allowscalar(values[k]) + return view(values, k) +end + +function Base.partialsort!(x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs...) + values, _ = overloaded_partialsort(x, k; kwargs...) + kget = k .- minimum(k) .+ 1 + val = @allowscalar(values[kget]) + @allowscalar setindex!(x, val, k) + k isa Integer && return val + return view(x, k) +end + +function Base.partialsortperm( + x::AnyTracedRVector, k::Union{Integer,OrdinalRange}; kwargs... +) + idxs = overloaded_partialsort(x, k; kwargs...)[2] + k = k .- minimum(k) .+ 1 + k isa Integer && return @allowscalar(idxs[k]) + return view(idxs, k) +end + +function Base.partialsortperm!( + ix::AnyTracedRVector{Int}, + x::AnyTracedRVector, + k::Union{Integer,OrdinalRange}; + kwargs..., +) + _, idxs = overloaded_partialsort(x, k; kwargs...) + kget = k .- minimum(k) .+ 1 + val = @allowscalar(idxs[kget]) + @allowscalar setindex!(ix, val, k) + k isa Integer && return val + return view(ix, k) +end + +function overloaded_partialsort( + x::AnyTracedRVector, + k::Union{Integer,OrdinalRange}; + by=identity, + rev::Bool=false, + lt=isless, +) + if lt !== isless || by !== identity + comparator = + rev ? (a, b, i1, i2) -> !lt(by(a), by(b)) : (a, b, i1, i2) -> lt(by(a), by(b)) + idxs = Ops.constant(collect(LinearIndices(x))) + sorted_x, sorted_idxs = Ops.sort( + materialize_traced_array(x), idxs; dimension=1, comparator + ) + return sorted_x[1:maximum(k)], sorted_idxs[1:maximum(k)] + end + + # XXX: If `maxk` is beyond a threshold should we emit a sort directly? + !rev && (k = length(x) .- k .+ 1) + !(k isa Integer) && (k = maximum(k)) + (; values, indices) = Ops.top_k(materialize_traced_array(x), k) + if !rev + values = Ops.reverse(values; dimensions=[1]) + indices = Ops.reverse(indices; dimensions=[1]) + end + return values, indices +end + +# arg* functions +function Base.argmin(f::F, x::AnyTracedRArray) where {F} + idx = scalar_index_to_cartesian(argmin(f.(x)), size(x)) .+ 1 + return @allowscalar x[idx...] +end + +function Base.argmax(f::F, x::AnyTracedRArray) where {F} + idx = scalar_index_to_cartesian(argmax(f.(x)), size(x)) .+ 1 + return @allowscalar x[idx...] +end + +Base.argmin(x::AnyTracedRArray; kwargs...) = findmin(identity, x; kwargs...)[2] +Base.argmax(x::AnyTracedRArray; kwargs...) = findmax(identity, x; kwargs...)[2] + +# find* functions +Base.findfirst(x::AnyTracedRArray) = findfirst(identity, x) +Base.findlast(x::AnyTracedRArray) = findlast(identity, x) + +function Base.findfirst(f::Function, x::AnyTracedRArray) + fA = materialize_traced_array(vec(f.(x))) + (; indices) = Ops.top_k(fA, 1) + return @allowscalar indices[1] +end + +function Base.findlast(f::Function, x::AnyTracedRArray) + fA = Ops.reverse(materialize_traced_array(vec(f.(x))); dimensions=[1]) + (; indices) = Ops.top_k(fA, 1) + return length(x) - @allowscalar(indices[1]) + 1 +end + +Base.findmin(x::AnyTracedRVector) = findmin(identity, x; dims=1) +function Base.findmin(x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing) + return findmin(identity, x; dims) +end + +Base.findmax(x::AnyTracedRVector) = findmax(identity, x; dims=1) +function Base.findmax(x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing) + return findmax(identity, x; dims) +end + +## To avoid scalar indexing and constructing an array of tuples, we return the linear index +## instead of the cartesian index +function Base.findmin(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing) + if dims === nothing + if ndims(x) == 1 + dims = 1 + else + return findmin(f, vec(x); dims=1) + end + end + + fx = Ops.negate(materialize_traced_array(f.(x))) + (; values, indices) = Ops.top_k(fx, 1; dimension=dims) + + # Compute linear indices + strds = strides(x) + iotas = [Ops.iota(Int64, [size(indices)...]; iota_dimension=i) for i in 1:ndims(x)] + iotas[dims] = Ops.subtract(indices, Ops.constant(fill(Int64(1), size(indices)))) + linear_indices = Ops.constant(fill(Int64(1), size(indices))) + for d in eachindex(iotas) + linear_indices = Ops.add( + linear_indices, + Ops.multiply(iotas[d], Ops.constant(fill(Int64(strds[d]), size(iotas[d])))), + ) + end + + values = Ops.negate(values) + ndims(x) == 1 && return @allowscalar (values[1], linear_indices[1]) + return (values, linear_indices) +end + +function Base.findmax(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothing) + if dims === nothing + if ndims(x) == 1 + dims = 1 + else + return findmax(f, vec(x); dims=1) + end + end + + fx = materialize_traced_array(f.(x)) + (; values, indices) = Ops.top_k(fx, 1; dimension=dims) + + # Compute linear indices + strds = strides(x) + iotas = [Ops.iota(Int64, [size(indices)...]; iota_dimension=i) for i in 1:ndims(x)] + iotas[dims] = Ops.subtract(indices, Ops.constant(fill(Int64(1), size(indices)))) + linear_indices = Ops.constant(fill(Int64(1), size(indices))) + for d in eachindex(iotas) + linear_indices = Ops.add( + linear_indices, + Ops.multiply(iotas[d], Ops.constant(fill(Int64(strds[d]), size(iotas[d])))), + ) + end + + ndims(x) == 1 && return @allowscalar (values[1], linear_indices[1]) + return (values, linear_indices) +end + end diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 9dcc6dd39a..8f01d4bdb9 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -129,6 +129,7 @@ for (jlop, hloop, hlocomp) in ( (:(Base.:(>)), :compare, "GT"), (:(Base.:(<=)), :compare, "LE"), (:(Base.:(<)), :compare, "LT"), + (:(Base.isless), :compare, "LT"), ) @eval begin function $(jlop)( @@ -246,9 +247,22 @@ Base.conj(x::TracedRNumber{<:Complex}) = Ops.conj(x) Base.real(x::TracedRNumber) = x Base.real(x::TracedRNumber{<:Complex}) = Ops.real(x) +Base.isreal(::TracedRNumber) = false +Base.isreal(::TracedRNumber{<:Real}) = true + Base.imag(x::TracedRNumber) = zero(x) Base.imag(x::TracedRNumber{<:Complex}) = Ops.imag(x) +Base.iseven(x::TracedRNumber) = iseven(real(x)) +function Base.iseven(x::TracedRNumber{<:Real}) + return iszero( + rem( + TracedUtils.promote_to(TracedRNumber{Int}, x), + TracedUtils.promote_to(TracedRNumber{Int}, 2), + ), + ) +end + for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRNumber)) @eval Base.clamp(x::TracedRNumber, min::$(minT), max::$(maxT)) = Ops.clamp(min, x, max) end diff --git a/test/ops.jl b/test/ops.jl index 0223eeed97..45ba719165 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -689,7 +689,7 @@ end end @testset "sort" begin - basic_sort(x, dimension) = Ops.sort(x; comparator=(a, b) -> a < b, dimension) + basic_sort(x, dimension) = only(Ops.sort(x; comparator=(a, b) -> a < b, dimension)) @testset for i in 1:3 t_size = tuple(fill(10, (i,))...) x = randn(t_size) @@ -916,7 +916,14 @@ end @testset "top_k" begin x = ConcreteRArray([1, 2, 3, 4]) - @test (; values=[4, 3], indices=[3, 2]) == @jit Ops.top_k(x, 2) + @test (; values=[4, 3], indices=[4, 3]) == @jit Ops.top_k(x, 2) + + x = ConcreteRArray([NaN, 123, 456, 789, 121]) + res = @jit Ops.top_k(x, 2) + true_res = (; values=[NaN, 789], indices=[1, 4]) + @test res.indices == true_res.indices + @test @allowscalar isnan(res.values[1]) + @test @allowscalar res.values[2] == 789 end @testset "zeta" begin diff --git a/test/runtests.jl b/test/runtests.jl index 8cc161fef9..7d188fe3dd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -56,6 +56,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Shortcuts to MLIR ops" include("ops.jl") @safetestset "Wrapped Arrays" include("wrapped_arrays.jl") @safetestset "Control Flow" include("control_flow.jl") + @safetestset "Sorting" include("sorting.jl") end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" diff --git a/test/sorting.jl b/test/sorting.jl new file mode 100644 index 0000000000..d54699fa57 --- /dev/null +++ b/test/sorting.jl @@ -0,0 +1,190 @@ +using Reactant, Test + +@testset "sort & sortperm" begin + x = randn(10) + x_ra = Reactant.to_rarray(x) + + srt_rev(x) = sort(x; rev=true) + srtperm_rev(x) = sortperm(x; rev=true) + srt_by(x) = sort(x; by=abs2) + srtperm_by(x) = sortperm(x; by=abs2) + srt_lt(x) = sort(x; lt=(a, b) -> a > b) + srtperm_lt(x) = sortperm(x; lt=(a, b) -> a > b) + + @test @jit(sort(x_ra)) == sort(x) + @test @jit(srt_rev(x_ra)) == srt_rev(x) + @test @jit(srt_lt(x_ra)) == srt_lt(x) + @test @jit(srt_by(x_ra)) == srt_by(x) + @test @jit(sortperm(x_ra)) == sortperm(x) + @test @jit(srtperm_rev(x_ra)) == srtperm_rev(x) + @test @jit(srtperm_lt(x_ra)) == srtperm_lt(x) + @test @jit(srtperm_by(x_ra)) == srtperm_by(x) + + x = rand(10) + x_ra = Reactant.to_rarray(x) + @jit sort!(x_ra) + @test x_ra == sort(x) + + x = rand(10) + x_ra = Reactant.to_rarray(x) + ix = similar(x_ra, Int) + @jit sortperm!(ix, x_ra) + @test ix == sortperm(x) + + x = rand(10, 4, 3) + x_ra = Reactant.to_rarray(x) + + srt(x, d) = sort(x; dims=d) + srt_rev(x, d) = sort(x; dims=d, rev=true) + srt_by(x, d) = sort(x; dims=d, by=abs2) + srt_lt(x, d) = sort(x; dims=d, lt=(a, b) -> a > b) + srtperm(x, d) = sortperm(x; dims=d) + srtperm_rev(x, d) = sortperm(x; dims=d, rev=true) + srtperm_by(x, d) = sortperm(x; dims=d, by=abs2) + srtperm_lt(x, d) = sortperm(x; dims=d, lt=(a, b) -> a > b) + + @testset for d in 1:ndims(x) + @test @jit(srt(x_ra, d)) == srt(x, d) + @test @jit(srtperm(x_ra, d)) == srtperm(x, d) + @test @jit(srt_rev(x_ra, d)) == srt_rev(x, d) + @test @jit(srtperm_rev(x_ra, d)) == srtperm_rev(x, d) + @test @jit(srt_by(x_ra, d)) == srt_by(x, d) + @test @jit(srtperm_by(x_ra, d)) == srtperm_by(x, d) + @test @jit(srt_lt(x_ra, d)) == srt_lt(x, d) + @test @jit(srtperm_lt(x_ra, d)) == srtperm_lt(x, d) + end +end + +@testset "partialsort & partialsortperm" begin + x = randn(10) + x_ra = Reactant.to_rarray(x) + + @test @jit(partialsort(x_ra, 1:5)) == partialsort(x, 1:5) + @test @jit(partialsortperm(x_ra, 1:5)) == partialsortperm(x, 1:5) + @test @jit(partialsort(x_ra, 4)) == partialsort(x, 4) + @test @jit(partialsortperm(x_ra, 4)) == partialsortperm(x, 4) + + psrt_rev(x, k) = partialsort(x, k; rev=true) + psrtperm_rev(x, k) = partialsortperm(x, k; rev=true) + psrt_by(x, k) = partialsort(x, k; by=abs2) + psrtperm_by(x, k) = partialsortperm(x, k; by=abs2) + psrt_lt(x, k) = partialsort(x, k; lt=(a, b) -> a > b) + psrtperm_lt(x, k) = partialsortperm(x, k; lt=(a, b) -> a > b) + + @test @jit(psrt_rev(x_ra, 1:5)) == psrt_rev(x, 1:5) + @test @jit(psrtperm_rev(x_ra, 1:5)) == psrtperm_rev(x, 1:5) + @test @jit(psrt_by(x_ra, 1:5)) == psrt_by(x, 1:5) + @test @jit(psrtperm_by(x_ra, 1:5)) == psrtperm_by(x, 1:5) + @test @jit(psrt_lt(x_ra, 1:5)) == psrt_lt(x, 1:5) + @test @jit(psrtperm_lt(x_ra, 1:5)) == psrtperm_lt(x, 1:5) + + x = randn(10) + x_ra = Reactant.to_rarray(x) + @jit partialsort!(x_ra, 1:5) + partialsort!(x, 1:5) + @test Array(x_ra)[1:5] == x[1:5] + + x = randn(10) + x_ra = Reactant.to_rarray(x) + @jit partialsort!(x_ra, 3) + partialsort!(x, 3) + @test @allowscalar(x_ra[3]) == x[3] + + x = randn(10) + x_ra = Reactant.to_rarray(x) + + ix = similar(x, Int) + ix_ra = Reactant.to_rarray(ix) + @jit partialsortperm!(ix_ra, x_ra, 1:5) + partialsortperm!(ix, x, 1:5) + @test Array(ix_ra)[1:5] == ix[1:5] + + ix = similar(x, Int) + ix_ra = Reactant.to_rarray(ix) + @jit partialsortperm!(ix_ra, x_ra, 3) + partialsortperm!(ix, x, 3) + @test @allowscalar(ix_ra[3]) == ix[3] +end + +@testset "argmin / argmax" begin + x = rand(2, 3) + x_ra = Reactant.to_rarray(x) + + linargmin(x) = LinearIndices(x)[argmin(x)] + linargmax(x) = LinearIndices(x)[argmax(x)] + + @test linargmin(x) == @jit(argmin(x_ra)) + @test linargmax(x) == @jit(argmax(x_ra)) + + x = rand(2, 3, 4) + x_ra = Reactant.to_rarray(x) + + linargmin(x, d) = LinearIndices(x)[argmin(x; dims=d)] + linargmax(x, d) = LinearIndices(x)[argmax(x; dims=d)] + argmindims(x, d) = argmin(x; dims=d) + argmaxdims(x, d) = argmax(x; dims=d) + + @test linargmin(x, 1) == @jit(argmindims(x_ra, 1)) + @test linargmax(x, 1) == @jit(argmaxdims(x_ra, 1)) + @test linargmin(x, 2) == @jit(argmindims(x_ra, 2)) + @test linargmax(x, 2) == @jit(argmaxdims(x_ra, 2)) + @test linargmin(x, 3) == @jit(argmindims(x_ra, 3)) + @test linargmax(x, 3) == @jit(argmaxdims(x_ra, 3)) + + x = randn(2, 3, 4) + x_ra = Reactant.to_rarray(x) + + @test argmin(abs2, x) == @jit(argmin(abs2, x_ra)) + @test argmax(abs2, x) == @jit(argmax(abs2, x_ra)) +end + +@testset "findmin / findmax" begin + xvec = randn(10) + xvec_ra = Reactant.to_rarray(xvec) + + x = randn(2, 3) + x_ra = Reactant.to_rarray(x) + + function fwithlinindices(g, f, x; kwargs...) + values, indices = g(f, x; kwargs...) + return values, LinearIndices(x)[indices] + end + + @test fwithlinindices(findmin, identity, x) == @jit(findmin(x_ra)) + @test fwithlinindices(findmax, identity, x) == @jit(findmax(x_ra)) + @test fwithlinindices(findmin, identity, xvec) == @jit(findmin(xvec_ra)) + @test fwithlinindices(findmax, identity, xvec) == @jit(findmax(xvec_ra)) + + fmindims(x, d) = findmin(x; dims=d) + fmindims(f, x, d) = findmin(f, x; dims=d) + fmaxdims(x, d) = findmax(x; dims=d) + fmaxdims(f, x, d) = findmax(f, x; dims=d) + + @test fwithlinindices(findmin, identity, x; dims=1) == @jit(fmindims(x_ra, 1)) + @test fwithlinindices(findmax, identity, x; dims=1) == @jit(fmaxdims(x_ra, 1)) + @test fwithlinindices(findmin, identity, x; dims=2) == @jit(fmindims(x_ra, 2)) + @test fwithlinindices(findmax, identity, x; dims=2) == @jit(fmaxdims(x_ra, 2)) + @test fwithlinindices(findmin, abs2, x; dims=1) == @jit(fmindims(abs2, x_ra, 1)) + @test fwithlinindices(findmax, abs2, x; dims=1) == @jit(fmaxdims(abs2, x_ra, 1)) + @test fwithlinindices(findmin, abs2, x; dims=2) == @jit(fmindims(abs2, x_ra, 2)) + @test fwithlinindices(findmax, abs2, x; dims=2) == @jit(fmaxdims(abs2, x_ra, 2)) +end + +@testset "findfirst / findlast" begin + x = rand(Bool, 3, 4) + x_ra = Reactant.to_rarray(x) + + ffirstlinindices(x) = LinearIndices(x)[findfirst(x)] + ffirstlinindices(f, x) = LinearIndices(x)[findfirst(f, x)] + flastlinindices(x) = LinearIndices(x)[findlast(x)] + flastlinindices(f, x) = LinearIndices(x)[findlast(f, x)] + + @test ffirstlinindices(x) == @jit(findfirst(x_ra)) + @test flastlinindices(x) == @jit(findlast(x_ra)) + + x = rand(1:256, 3, 4) + x_ra = Reactant.to_rarray(x) + + @test ffirstlinindices(iseven, x) == @jit(findfirst(iseven, x_ra)) + @test flastlinindices(iseven, x) == @jit(findlast(iseven, x_ra)) +end