diff --git a/ext/ReactantNNlibExt/Implementations.jl b/ext/ReactantNNlibExt/Implementations.jl index 834a12c77b..287be37d42 100644 --- a/ext/ReactantNNlibExt/Implementations.jl +++ b/ext/ReactantNNlibExt/Implementations.jl @@ -394,32 +394,20 @@ function NNlib.batched_mul!( ) end + x = @opcall convert(TracedRArray{T2,3}, materialize_traced_array(x)) + y = @opcall convert(TracedRArray{T3,3}, materialize_traced_array(y)) + if size(x, 3) != size(y, 3) B = max(size(x, 3), size(y, 3)) if size(x, 3) == 1 - x = TracedUtils.broadcast_to_size(x, (size(x, 1), size(x, 2), B)) + x = @opcall broadcast_in_dim(x, [1, 2, 3], [size(x, 1), size(x, 2), B]) elseif size(y, 3) == 1 - y = TracedUtils.broadcast_to_size(y, (size(y, 1), size(y, 2), B)) - end - end - - x = permutedims(x, (3, 1, 2)) - y = permutedims(y, (3, 1, 2)) - - if size(x, 1) != size(y, 1) - B = max(size(x, 1), size(y, 1)) - if size(x, 1) == 1 - x = TracedUtils.broadcast_to_size(x, (B, size(x, 2), size(x, 3))) - elseif size(y, 1) == 1 - y = TracedUtils.broadcast_to_size(y, (B, size(y, 2), size(y, 3))) + y = @opcall broadcast_in_dim(y, [1, 2, 3], [size(y, 1), size(y, 2), B]) end end tmp = @opcall dot_general( - T1.(materialize_traced_array(x)), - T1.(materialize_traced_array(y)); - contracting_dimensions=([3], [2]), - batching_dimensions=([1], [1]), + x, y; contracting_dimensions=([2], [1]), batching_dimensions=([3], [3]) ) set_mlir_data!(res, get_mlir_data(permutedims(tmp, (2, 3, 1))))