diff --git a/ext/ReactantOneHotArraysExt.jl b/ext/ReactantOneHotArraysExt.jl
index 18d31dd8d1..86d7099db4 100644
--- a/ext/ReactantOneHotArraysExt.jl
+++ b/ext/ReactantOneHotArraysExt.jl
@@ -1,10 +1,18 @@
 module ReactantOneHotArraysExt
 
-using OneHotArrays: OneHotArray
-using Reactant: Reactant, TracedRArray, TracedRNumber, Ops
+using GPUArraysCore: @allowscalar
+using OneHotArrays: OneHotArrays, OneHotArray
+using Reactant: Reactant, AnyTracedRArray, TracedRArray, TracedRNumber
 using ReactantCore: ReactantCore
 using Reactant.Ops: @opcall
 
+__compatible_eltype(::Type{T}, ::Type{U}) where {T,U} = T
+function __compatible_eltype(::Type{TracedRNumber{T}}, ::Type{TracedRNumber{U}}) where {T,U}
+    return TracedRNumber{T}
+end
+__compatible_eltype(::Type{TracedRNumber{T}}, ::Type{U}) where {T,U} = T
+__compatible_eltype(::Type{T}, ::Type{TracedRNumber{U}}) where {T,U} = TracedRNumber{T}
+
 function Reactant.traced_type_inner(
     @nospecialize(_::Type{OneHotArray{T,N,Np1,I}}),
     seen,
@@ -14,12 +22,7 @@ function Reactant.traced_type_inner(
     @nospecialize(runtime)
 ) where {T,N,Np1,I}
     I2 = Reactant.traced_type_inner(I, seen, mode, track_numbers, sharding, runtime)
-    T2 = if eltype(I2) <: Reactant.TracedRNumber && !(T <: Reactant.TracedRNumber)
-        Reactant.TracedRNumber{T}
-    else
-        T
-    end
-    return OneHotArray{T2,N,Np1,I2}
+    return OneHotArray{__compatible_eltype(T, eltype(I2)),N,Np1,I2}
 end
 
 function ReactantCore.materialize_traced_array(r::OneHotArray)
@@ -45,4 +48,69 @@ function Base.Array(
     return Array(reshape(Array(r.indices), 1, size(r.indices)...) .== 1:(r.nlabels))
 end
 
+function OneHotArrays.onehotbatch(data::AnyTracedRArray{<:Any,N}, labels) where {N}
+    # TODO: add checkbounds once we support that with TracedRNumber
+    labels_expanded = @opcall broadcast_in_dim(
+        Reactant.promote_to(
+            TracedRArray{Reactant.unwrapped_eltype(labels),1},
+            ReactantCore.materialize_traced_array(vec(labels)),
+        ),
+        Int64[1],
+        [length(labels), size(data)...],
+    )
+    data = ReactantCore.materialize_traced_array(reshape(data, 1, size(data)...))
+    indices = UInt32.(@opcall(findfirst(data .== labels_expanded; dimension=1)))
+    return OneHotArray{TracedRNumber{UInt32},N,N + 1,typeof(indices)}(
+        indices, length(labels)
+    )
+end
+
+function OneHotArrays.onehotbatch(
+    data::AnyTracedRArray{<:Integer,N}, labels::AbstractUnitRange{<:Integer}
+) where {N}
+    # TODO: add checkbounds once we support that with TracedRNumber
+    indices = map(
+        TracedRNumber{UInt32} ∘ Base.Fix2(+, 1 - first(labels)),
+        ReactantCore.materialize_traced_array(data),
+    )
+    return OneHotArray{TracedRNumber{UInt32},N,N + 1,typeof(indices)}(
+        indices, length(labels)
+    )
+end
+
+function OneHotArrays.onecold(y::AnyTracedRArray{T,1}, labels=1:length(y)) where {T}
+    nl = length(labels)
+    ny = length(y)
+    nl == ny || throw(
+        DimensionMismatch(
+            "onecold got $nl labels for a vector of length $ny, these must agree"
+        ),
+    )
+    imax = argmax(y)
+    # TODO: error if ymax is nan
+    labels_arr = Reactant.promote_to(
+        TracedRArray{Reactant.unwrapped_eltype(labels),1}, labels
+    )
+    return @allowscalar labels_arr[imax]
+end
+
+function OneHotArrays.onecold(y::AnyTracedRArray{T}, labels=1:size(y, 1)) where {T}
+    nl = length(labels)
+    ny = size(y, 1)
+    nl == ny || throw(
+        DimensionMismatch(
+            "onecold got $nl labels for an array with first dimension of size $ny, these must agree",
+        ),
+    )
+    labels_arr = Reactant.promote_to(
+        TracedRArray{Reactant.unwrapped_eltype(labels),1}, labels
+    )
+    labels_expanded = @opcall broadcast_in_dim(
+        labels_arr, Int64[1], Int64[nl, size(y)[2:end]...]
+    )
+    return ReactantCore.materialize_traced_array(
+        vec(getindex(labels_expanded, argmax(y; dims=1)))
+    )
+end
+
 end
diff --git a/test/integration/onehotarrays.jl b/test/integration/onehotarrays.jl
index f860c424bb..98e27c28c7 100644
--- a/test/integration/onehotarrays.jl
+++ b/test/integration/onehotarrays.jl
@@ -31,3 +31,33 @@ end
         @test res_ra ≈ res
     end
 end
+
+@testset "onehotbatch/onecold" begin
+    x = Int32[10, 20, 30, 10, 10]
+    x_ra = Reactant.to_rarray(x)
+    labels = Int32(10):Int32(10):Int32(40)
+    res_ra = @jit onehotbatch(x_ra, labels)
+    res = onehotbatch(x, labels)
+    @test Array(res_ra) ≈ res
+
+    x = rand(10:10:40, 2, 3, 5)
+    x_ra = Reactant.to_rarray(x)
+    labels = reshape([10, 20, 30, 40], 2, 2)
+    res = onehotbatch(x, labels)
+    res_ra = @jit onehotbatch(x_ra, labels)
+    @test Array(res_ra) ≈ res
+
+    x = Int32[1, 2, 3, 1, 1]
+    x_ra = Reactant.to_rarray(x)
+    labels = Int32(1):Int32(4)
+    res_ra = @jit onehotbatch(x_ra, labels)
+    res = onehotbatch(x, labels)
+    @test Array(res_ra) ≈ res
+
+    vec_ra = Reactant.to_rarray(Float32[0.3, 0.2, 0.5])
+    @test @jit(onecold(vec_ra)) == 3
+
+    dense_ra = Reactant.to_rarray(Array(res))
+    oc_res = onecold(res)
+    @test @jit(onecold(dense_ra)) == oc_res
+end
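
For reference, a minimal usage sketch of the two new overloads, distilled from the tests above. It assumes Reactant and OneHotArrays are both loaded so the extension activates, and a working Reactant backend; variable names are illustrative only.

using Reactant, OneHotArrays

# One-hot encode a traced integer vector against a label range; this dispatches
# to the new OneHotArrays.onehotbatch method for AnyTracedRArray defined above.
x = Int32[10, 20, 30, 10, 10]
labels = Int32(10):Int32(10):Int32(40)
x_ra = Reactant.to_rarray(x)
oh_ra = @jit onehotbatch(x_ra, labels)
@assert Array(oh_ra) ≈ onehotbatch(x, labels)   # matches the plain CPU result

# Decode a traced score vector back to the index of its largest entry via the
# new onecold method (default labels 1:length(y)).
probs_ra = Reactant.to_rarray(Float32[0.3, 0.2, 0.5])
@assert @jit(onecold(probs_ra)) == 3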