From a279324de2bb569ccac041205dcc3312d95f6fad Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 29 Oct 2025 09:23:25 -0500 Subject: [PATCH 1/2] feat: enable auto-batching passes --- src/Compiler.jl | 1 + test/batching.jl | 56 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/src/Compiler.jl b/src/Compiler.jl index 4ce10365a4..f608a798e4 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -937,6 +937,7 @@ function optimization_passes( "broadcastindim_slice_to_batch", "reducewindow_slice_to_batch", "elementwise_slice_to_batch", + "greedy_while_loop_batch_fission", ], ) end diff --git a/test/batching.jl b/test/batching.jl index 231a8aa8e0..e967cab168 100644 --- a/test/batching.jl +++ b/test/batching.jl @@ -29,3 +29,59 @@ end @test @jit(f3(A_ra, 1)) ≈ A .+ 1 end + +# Auto-Batching +function run_auto_batching_tests(f::F, args...) where {F} + @testset "$(nameof(F))" begin + @testset "Correctness" begin + res1 = @jit f(args...) + res2 = @jit compile_options = CompileOptions(; + disable_auto_batching_passes=true + ) f(args...) + @test res1 ≈ res2 + end + + @testset "No while loops" begin + hlo = repr( + @code_hlo compile_options = CompileOptions(; + disable_auto_batching_passes=true + ) f(args...) + ) + @test occursin("stablehlo.while", hlo) + + hlo = repr(@code_hlo f(args...)) + @test !occursin("stablehlo.while", hlo) + end + end +end + +function looped_reduction(y, x) + z = copy(y) + @trace for i in 1:size(x, 2) + z[:, i, :] = dropdims(sum(abs2, x[:, i, :, :]; dims=3); dims=3) + end + return z +end + +@testset "Loop of Reduces => Single Reduction" begin + x = Reactant.to_rarray(rand(Float32, 3, 256, 5, 7)) + y = Reactant.to_rarray(rand(Float32, 3, 260, 5)) + + run_auto_batching_tests(looped_reduction, y, x) +end + +function naive_batched_matmul(x, y) + @assert size(x, 3) == size(y, 3) + z = similar(x, size(x, 1), size(y, 2), size(x, 3)) + @trace for i in 1:size(x, 3) + z[:, :, i] = x[:, :, i] * y[:, :, i] + end + return z +end + +@testset "Naive Batched Matmul => Single Dot General" begin + x = Reactant.to_rarray(rand(Float32, 3, 256, 5)) + y = Reactant.to_rarray(rand(Float32, 256, 7, 5)) + + run_auto_batching_tests(naive_batched_matmul, x, y) +end From 2ace961b723f8df50cda1b78f8d05fd920a37944 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 29 Oct 2025 11:40:41 -0400 Subject: [PATCH 2/2] chore: bump reactant version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 6010bbbed5..733e1c2797 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Reactant" uuid = "3c362404-f566-11ee-1572-e11a4b42c853" authors = ["William Moses ", "Valentin Churavy ", "Sergio Sánchez Ramírez ", "Paul Berg ", "Avik Pal ", "Mosè Giordano "] -version = "0.2.172" +version = "0.2.173" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"