diff --git a/src/Ops.jl b/src/Ops.jl
index 15a15ee393..22bf67679b 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -3130,6 +3130,7 @@ end
             [size(input, i) for i in (length(batch_shape) + 1):ndims(input)]...,
         ) for input in inputs
     ]
+    argprefix = gensym("batcharg")
     mlir_fn_res = Reactant.TracedUtils.make_mlir_fn(
         f,
         (sample_inputs...,),
@@ -3138,11 +3139,35 @@ end
         false;
         args_in_result=:none,
         do_transpose=false,
+        argprefix,
     )
 
     func = mlir_fn_res.f
     @assert MLIR.IR.nregions(func) == 1
 
+    if mlir_fn_res.fnwrapped
+        # Traced values captured by `f` are passed as extra arguments ahead of `inputs`.
+        # Per-argument batching is not supported yet, so broadcast each to the batch shape.
+        final_inputs = TracedRArray[]
+        seenargs = Reactant.OrderedIdDict()
+        Reactant.make_tracer(
+            seenargs, f, (argprefix, 1), Reactant.TracedSetPath; toscalar=false
+        )
+        for (_, v) in seenargs
+            v isa Reactant.TracedType || continue
+            bcasted_arg = broadcast_in_dim(
+                v,
+                collect(Int64, (length(batch_shape) + 1):(ndims(v) + length(batch_shape))),
+                vcat(batch_shape, collect(Int64, size(v)));
+                location,
+            )
+            push!(final_inputs, bcasted_arg)
+        end
+        append!(final_inputs, inputs)
+    else
+        final_inputs = inputs
+    end
+
     output_types = MLIR.IR.Type[]
     for result in mlir_fn_res.linear_results
         push!(
@@ -3154,7 +3179,7 @@ end
         )
     end
 
-    return batch(inputs, output_types, batch_shape; fn=func, location)
+    return batch(final_inputs, output_types, batch_shape; fn=func, location)
 end
 
 @noinline function batch(
diff --git a/test/batching.jl b/test/batching.jl
index 13298a5394..b799373471 100644
--- a/test/batching.jl
+++ b/test/batching.jl
@@ -85,3 +85,18 @@ end
 
     run_auto_batching_tests(naive_batched_matmul, x, y)
 end
+
+function batch_with_closure(x, y)
+    _fn(slice) = slice .+ y
+    return mapslices(_fn, x; dims=2)
+end
+
+@testset "Batching with closure" begin
+    x = Reactant.TestUtils.construct_test_array(Float32, 3, 256, 8)
+    y = Reactant.TestUtils.construct_test_array(Float32, 256)
+
+    x_ra = Reactant.to_rarray(x)
+    y_ra = Reactant.to_rarray(y)
+
+    @test @jit(batch_with_closure(x_ra, y_ra)) ≈ batch_with_closure(x, y)
+end