From 304a9f7cf156b3daaccf5de34d0696c5e388533e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 28 Oct 2025 07:31:49 -0500 Subject: [PATCH 1/7] fix: try always downloading libtpu on ci --- .github/workflows/CommonCI.yml | 1 + src/accelerators/TPU.jl | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/.github/workflows/CommonCI.yml b/.github/workflows/CommonCI.yml index 6a888f9e4b..8bdfc25228 100644 --- a/.github/workflows/CommonCI.yml +++ b/.github/workflows/CommonCI.yml @@ -212,3 +212,4 @@ env: JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager XLA_FLAGS: "--xla_force_host_platform_device_count=12" JULIA_DEBUG: "Reactant,Reactant_jll" + REACTANT_FORCE_DOWNLOAD_LIBTPU: "true" diff --git a/src/accelerators/TPU.jl b/src/accelerators/TPU.jl index a0b5564914..8b696ca53c 100644 --- a/src/accelerators/TPU.jl +++ b/src/accelerators/TPU.jl @@ -9,6 +9,7 @@ using unzip_jll: unzip const libtpu_dir = Ref{Union{Nothing,String}}(nothing) const RUNNING_IN_CLOUD_TPU_VM = Ref(false) +const FORCE_DOWNLOAD_LIBTPU = Ref(false) const LIBTPU_VERSION = "0.0.28.dev20251027" const LIBTPU_SO = "libtpu-$(replace(string(LIBTPU_VERSION), '.' => '_')).so" @@ -19,6 +20,12 @@ function __init__() setup_libtpu!() cloud_tpu_init!() end + + # TODO: we should have a way to checking that the downloaded libtpu doesn't match + # the expected version. + FORCE_DOWNLOAD_LIBTPU[] = parse( + Bool, get(ENV, "REACTANT_FORCE_DOWNLOAD_LIBTPU", "false") + ) end end From 5b241ad7a96f96e8cb14c084885cfd1e304625b9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 28 Oct 2025 07:47:37 -0500 Subject: [PATCH 2/7] feat: version check for libtpu --- .github/workflows/CommonCI.yml | 1 - src/accelerators/TPU.jl | 10 +++------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/.github/workflows/CommonCI.yml b/.github/workflows/CommonCI.yml index 8bdfc25228..6a888f9e4b 100644 --- a/.github/workflows/CommonCI.yml +++ b/.github/workflows/CommonCI.yml @@ -212,4 +212,3 @@ env: JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager XLA_FLAGS: "--xla_force_host_platform_device_count=12" JULIA_DEBUG: "Reactant,Reactant_jll" - REACTANT_FORCE_DOWNLOAD_LIBTPU: "true" diff --git a/src/accelerators/TPU.jl b/src/accelerators/TPU.jl index 8b696ca53c..0b08de4488 100644 --- a/src/accelerators/TPU.jl +++ b/src/accelerators/TPU.jl @@ -9,7 +9,9 @@ using unzip_jll: unzip const libtpu_dir = Ref{Union{Nothing,String}}(nothing) const RUNNING_IN_CLOUD_TPU_VM = Ref(false) -const FORCE_DOWNLOAD_LIBTPU = Ref(false) + +const LIBTPU_VERSION = "0.0.28.dev20251027" +const LIBTPU_SO = "libtpu-$(replace(string(LIBTPU_VERSION), '.' => '_')).so" const LIBTPU_VERSION = "0.0.28.dev20251027" const LIBTPU_SO = "libtpu-$(replace(string(LIBTPU_VERSION), '.' => '_')).so" @@ -20,12 +22,6 @@ function __init__() setup_libtpu!() cloud_tpu_init!() end - - # TODO: we should have a way to checking that the downloaded libtpu doesn't match - # the expected version. - FORCE_DOWNLOAD_LIBTPU[] = parse( - Bool, get(ENV, "REACTANT_FORCE_DOWNLOAD_LIBTPU", "false") - ) end end From 7ba923a57b9e00014c7eed6de8b80770ca9f399a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 27 Oct 2025 19:23:49 -0500 Subject: [PATCH 3/7] feat: support onehotbatch/onecold --- ext/ReactantOneHotArraysExt.jl | 69 ++++++++++++++++++++++++++++---- test/integration/onehotarrays.jl | 19 +++++++++ 2 files changed, 80 insertions(+), 8 deletions(-) diff --git a/ext/ReactantOneHotArraysExt.jl b/ext/ReactantOneHotArraysExt.jl index 18d31dd8d1..1f2d08258a 100644 --- a/ext/ReactantOneHotArraysExt.jl +++ b/ext/ReactantOneHotArraysExt.jl @@ -1,10 +1,18 @@ module ReactantOneHotArraysExt -using OneHotArrays: OneHotArray -using Reactant: Reactant, TracedRArray, TracedRNumber, Ops +using GPUArraysCore: @allowscalar +using OneHotArrays: OneHotArrays, OneHotArray +using Reactant: Reactant, AnyTracedRArray, TracedRArray, TracedRNumber using ReactantCore: ReactantCore using Reactant.Ops: @opcall +__compatible_eltype(::Type{T}, ::Type{U}) where {T,U} = T +function __compatible_eltype(::Type{TracedRNumber{T}}, ::Type{TracedRNumber{U}}) where {T,U} + return TracedRNumber{T} +end +__compatible_eltype(::Type{TracedRNumber{T}}, ::Type{U}) where {T,U} = T +__compatible_eltype(::Type{T}, ::Type{TracedRNumber{U}}) where {T,U} = TracedRNumber{T} + function Reactant.traced_type_inner( @nospecialize(_::Type{OneHotArray{T,N,Np1,I}}), seen, @@ -14,12 +22,7 @@ function Reactant.traced_type_inner( @nospecialize(runtime) ) where {T,N,Np1,I} I2 = Reactant.traced_type_inner(I, seen, mode, track_numbers, sharding, runtime) - T2 = if eltype(I2) <: Reactant.TracedRNumber && !(T <: Reactant.TracedRNumber) - Reactant.TracedRNumber{T} - else - T - end - return OneHotArray{T2,N,Np1,I2} + return OneHotArray{__compatible_eltype(T, eltype(I2)),N,Np1,I2} end function ReactantCore.materialize_traced_array(r::OneHotArray) @@ -45,4 +48,54 @@ function Base.Array( return Array(reshape(Array(r.indices), 1, size(r.indices)...) .== 1:(r.nlabels)) end +function OneHotArrays.onehotbatch( + data::AnyTracedRArray{<:Integer,N}, labels::AbstractVector{<:Integer} +) where {N} + # TODO: add checkbounds once we support that with TracedRNumber + indices = + UInt32.( + map( + Base.Fix2(+, 1 - first(labels)), ReactantCore.materialize_traced_array(data) + ) + ) + return indices + return OneHotArray{TracedRNumber{UInt32},N,N + 1,typeof(indices)}( + indices, length(labels) + ) +end + +function OneHotArrays.onecold(y::AnyTracedRArray{T,1}, labels=1:length(y)) where {T} + nl = length(labels) + ny = length(y) + nl == ny || throw( + DimensionMismatch( + "onecold got $nl labels for a vector of length $ny, these must agree" + ), + ) + imax = argmax(y) + # TODO: error if ymax is nan + labels_arr = Reactant.promote_to( + TracedRArray{Reactant.unwrapped_eltype(labels),1}, labels + ) + return @allowscalar labels_arr[imax] +end +function OneHotArrays.onecold(y::AnyTracedRArray{T}, labels=1:size(y, 1)) where {T} + nl = length(labels) + ny = size(y, 1) + nl == ny || throw( + DimensionMismatch( + "onecold got $nl labels for an array with first dimension of size $ny, these must agree", + ), + ) + labels_arr = Reactant.promote_to( + TracedRArray{Reactant.unwrapped_eltype(labels),1}, labels + ) + labels_expanded = @opcall broadcast_in_dim( + labels_arr, Int64[1], Int64[nl, size(y)[2:end]...] + ) + return ReactantCore.materialize_traced_array( + vec(getindex(labels_expanded, argmax(y; dims=1))) + ) +end + end diff --git a/test/integration/onehotarrays.jl b/test/integration/onehotarrays.jl index f860c424bb..0f1380206f 100644 --- a/test/integration/onehotarrays.jl +++ b/test/integration/onehotarrays.jl @@ -31,3 +31,22 @@ end @test res_ra ≈ res end end + +using Reactant, Test, OneHotArrays, Random + +@testset "onehotbatch/onecold" begin + x = Int32[10, 20, 30, 10, 10] + x_ra = Reactant.to_rarray(x) + labels = Int32.(10:10:40) + + res_ra = @jit onehotbatch(x_ra, labels) # XXX: broken?? + res = onehotbatch(x, labels) + @test Array(res_ra) ≈ res + + vec_ra = Reactant.to_rarray(Float32[0.3, 0.2, 0.5]) + @test @jit(onecold(vec_ra)) == 3 + + dense_ra = Reactant.to_rarray(Array(res)) + oc_res = onecold(res) + @test @jit(onecold(dense_ra)) == oc_res +end From 96a462b486b7c43daf2e2ef1fa4cb623fe544d01 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 28 Oct 2025 09:16:56 -0500 Subject: [PATCH 4/7] fix: onehotbatch --- ext/ReactantOneHotArraysExt.jl | 28 ++++++++++++++++++++-------- test/integration/onehotarrays.jl | 19 +++++++++++++++---- 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/ext/ReactantOneHotArraysExt.jl b/ext/ReactantOneHotArraysExt.jl index 1f2d08258a..10f028c66d 100644 --- a/ext/ReactantOneHotArraysExt.jl +++ b/ext/ReactantOneHotArraysExt.jl @@ -48,17 +48,28 @@ function Base.Array( return Array(reshape(Array(r.indices), 1, size(r.indices)...) .== 1:(r.nlabels)) end +function OneHotArrays.onehotbatch(data::AnyTracedRArray{<:Any,N}, labels) where {N} + # TODO: add checkbounds once we support that with TracedRNumber + labels_expanded = @opcall broadcast_in_dim( + Reactant.promote_to( + TracedRArray{Reactant.unwrapped_eltype(labels),1}, + ReactantCore.materialize_traced_array(vec(labels)) + ), + Int64[1], + [length(labels), size(data)...], + ) + data = ReactantCore.materialize_traced_array(reshape(data, 1, size(data)...)) + return mapslices(findfirst, data .== labels_expanded; dims=Tuple(2:(N + 1))) +end + function OneHotArrays.onehotbatch( - data::AnyTracedRArray{<:Integer,N}, labels::AbstractVector{<:Integer} + data::AnyTracedRArray{<:Integer,N}, labels::AbstractUnitRange{<:Integer} ) where {N} # TODO: add checkbounds once we support that with TracedRNumber - indices = - UInt32.( - map( - Base.Fix2(+, 1 - first(labels)), ReactantCore.materialize_traced_array(data) - ) - ) - return indices + indices = map( + TracedRNumber{UInt32} ∘ Base.Fix2(+, 1 - first(labels)), + ReactantCore.materialize_traced_array(data), + ) return OneHotArray{TracedRNumber{UInt32},N,N + 1,typeof(indices)}( indices, length(labels) ) @@ -79,6 +90,7 @@ function OneHotArrays.onecold(y::AnyTracedRArray{T,1}, labels=1:length(y)) where ) return @allowscalar labels_arr[imax] end + function OneHotArrays.onecold(y::AnyTracedRArray{T}, labels=1:size(y, 1)) where {T} nl = length(labels) ny = size(y, 1) diff --git a/test/integration/onehotarrays.jl b/test/integration/onehotarrays.jl index 0f1380206f..98e27c28c7 100644 --- a/test/integration/onehotarrays.jl +++ b/test/integration/onehotarrays.jl @@ -32,14 +32,25 @@ end end end -using Reactant, Test, OneHotArrays, Random - @testset "onehotbatch/onecold" begin x = Int32[10, 20, 30, 10, 10] x_ra = Reactant.to_rarray(x) - labels = Int32.(10:10:40) + labels = Int32(10):Int32(10):Int32(40) + res_ra = @jit onehotbatch(x_ra, labels) + res = onehotbatch(x, labels) + @test Array(res_ra) ≈ res - res_ra = @jit onehotbatch(x_ra, labels) # XXX: broken?? + x = rand(10:10:40, 2, 3, 5) + x_ra = Reactant.to_rarray(x) + labels = reshape([10, 20, 30, 40], 2, 2) + res = onehotbatch(x, labels) + res_ra = @jit onehotbatch(x_ra, labels) + @test Array(res_ra) ≈ res + + x = Int32[1, 2, 3, 1, 1] + x_ra = Reactant.to_rarray(x) + labels = Int32(1):Int32(4) + res_ra = @jit onehotbatch(x_ra, labels) res = onehotbatch(x, labels) @test Array(res_ra) ≈ res From 378cf6376dc2b5847633f3ac0ff83be75793d5bd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 28 Oct 2025 12:44:26 -0400 Subject: [PATCH 5/7] Apply suggestion from @github-actions[bot] Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- ext/ReactantOneHotArraysExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/ReactantOneHotArraysExt.jl b/ext/ReactantOneHotArraysExt.jl index 10f028c66d..e3781cb584 100644 --- a/ext/ReactantOneHotArraysExt.jl +++ b/ext/ReactantOneHotArraysExt.jl @@ -53,7 +53,7 @@ function OneHotArrays.onehotbatch(data::AnyTracedRArray{<:Any,N}, labels) where labels_expanded = @opcall broadcast_in_dim( Reactant.promote_to( TracedRArray{Reactant.unwrapped_eltype(labels),1}, - ReactantCore.materialize_traced_array(vec(labels)) + ReactantCore.materialize_traced_array(vec(labels)), ), Int64[1], [length(labels), size(data)...], From 0280ef4790d5debee75b06f2ece61932e3940dc8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 28 Oct 2025 19:33:17 -0500 Subject: [PATCH 6/7] fix: use batched findfirst --- ext/ReactantOneHotArraysExt.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ext/ReactantOneHotArraysExt.jl b/ext/ReactantOneHotArraysExt.jl index e3781cb584..86d7099db4 100644 --- a/ext/ReactantOneHotArraysExt.jl +++ b/ext/ReactantOneHotArraysExt.jl @@ -59,7 +59,10 @@ function OneHotArrays.onehotbatch(data::AnyTracedRArray{<:Any,N}, labels) where [length(labels), size(data)...], ) data = ReactantCore.materialize_traced_array(reshape(data, 1, size(data)...)) - return mapslices(findfirst, data .== labels_expanded; dims=Tuple(2:(N + 1))) + indices = UInt32.(@opcall(findfirst(data .== labels_expanded; dimension=1))) + return OneHotArray{TracedRNumber{UInt32},N,N + 1,typeof(indices)}( + indices, length(labels) + ) end function OneHotArrays.onehotbatch( From f1069721b62d1005d98f5a432a5a1b7407a59f8d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 29 Oct 2025 08:32:02 -0400 Subject: [PATCH 7/7] Apply suggestion from @avik-pal --- src/accelerators/TPU.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/accelerators/TPU.jl b/src/accelerators/TPU.jl index 0b08de4488..a0b5564914 100644 --- a/src/accelerators/TPU.jl +++ b/src/accelerators/TPU.jl @@ -13,9 +13,6 @@ const RUNNING_IN_CLOUD_TPU_VM = Ref(false) const LIBTPU_VERSION = "0.0.28.dev20251027" const LIBTPU_SO = "libtpu-$(replace(string(LIBTPU_VERSION), '.' => '_')).so" -const LIBTPU_VERSION = "0.0.28.dev20251027" -const LIBTPU_SO = "libtpu-$(replace(string(LIBTPU_VERSION), '.' => '_')).so" - function __init__() @static if !Sys.isapple() if !Reactant.precompiling() && has_tpu()