Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3130,6 +3130,7 @@ end
[size(input, i) for i in (length(batch_shape) + 1):ndims(input)]...,
) for input in inputs
]
argprefix = gensym("batcharg")
mlir_fn_res = Reactant.TracedUtils.make_mlir_fn(
f,
(sample_inputs...,),
Expand All @@ -3138,11 +3139,35 @@ end
false;
args_in_result=:none,
do_transpose=false,
argprefix,
)

func = mlir_fn_res.f
@assert MLIR.IR.nregions(func) == 1

if mlir_fn_res.fnwrapped
# In the long-term we should be able to do per-argument batching.
# Rn we simply broadcast_in_dim the arguments to the correct shape.
final_inputs = TracedRArray[]
seenargs = Reactant.OrderedIdDict()
Reactant.make_tracer(
seenargs, f, (argprefix, 1), Reactant.TracedSetPath; toscalar=false
)
for (k, v) in seenargs
v isa Reactant.TracedType || continue
bcasted_arg = broadcast_in_dim(
v,
collect(Int64, (length(batch_shape) + 1):(ndims(v) + length(batch_shape))),
vcat(batch_shape, collect(Int64, size(v)));
location,
)
push!(final_inputs, bcasted_arg)
end
append!(final_inputs, inputs)
else
final_inputs = inputs
end

output_types = MLIR.IR.Type[]
for result in mlir_fn_res.linear_results
push!(
Expand All @@ -3154,7 +3179,7 @@ end
)
end

return batch(inputs, output_types, batch_shape; fn=func, location)
return batch(final_inputs, output_types, batch_shape; fn=func, location)
end

@noinline function batch(
Expand Down
15 changes: 15 additions & 0 deletions test/batching.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,18 @@ end

run_auto_batching_tests(naive_batched_matmul, x, y)
end

function batch_with_closure(x, y)
_fn(x) = x .+ y
return mapslices(_fn, x; dims=2)
end

@testset "Batching with closure" begin
x = Reactant.TestUtils.construct_test_array(Float32, 3, 256, 8)
y = Reactant.TestUtils.construct_test_array(Float32, 256)

x_ra = Reactant.to_rarray(x)
y_ra = Reactant.to_rarray(y)

@test @jit(batch_with_closure(x_ra, y_ra)) ≈ batch_with_closure(x, y)
end
Loading