From 8f33b7e1db9bd28dafdaad52dde42947ae9d340e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 19 Nov 2025 08:29:33 +0000 Subject: [PATCH 1/4] Initial plan From 1fb01affc4a14fb64308d9d29bcb86eb0f801df7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 19 Nov 2025 08:45:37 +0000 Subject: [PATCH 2/4] Implement generalized batched_vec for N-D batches Co-authored-by: CarloLucibello <7014210+CarloLucibello@users.noreply.github.com> --- src/batched/batchedmul.jl | 26 +++++++++++++++++++++++++- test/batchedmul.jl | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/src/batched/batchedmul.jl b/src/batched/batchedmul.jl index ccd9b0e8..37fd9c49 100644 --- a/src/batched/batchedmul.jl +++ b/src/batched/batchedmul.jl @@ -164,11 +164,16 @@ _semi_batched_mul(A::Transpose{<:Number,<:AbstractMatrix}, B::AbstractArray{<:An """ batched_vec(A::Array{T,3}, B::Matrix) batched_vec(A::Array{T,3}, b::Vector) + batched_vec(A::AbstractArray, B::AbstractArray) -Batched matrix-vector multiplication: +Batched matrix-vector multiplication. For the 3D case: the result has `C[:,:,k] == A[:,:,k] * B[:,k]` for all `k`, or else `C[:,:,k] == A[:,:,k] * b` for `b::Vector`. +For the general N-D case where `ndims(A) == ndims(B) + 1`: +the result has `C[:,k...] == A[:,:,k...] * B[:,k...]` for all batch indices `k...`. +The batch dimensions must match: `size(A)[3:end] == size(B)[2:end]`. + With the same argument types, `batched_mul(A, B)` would regard `B` as a fixed matrix, not a batch of vectors. Both reshape and then call `batched_mul(::Array{T,3}, ::Array{T,3})`. @@ -181,8 +186,27 @@ julia> batched_vec(A,B) |> size julia> batched_vec(A,b) |> size (16, 32) + +julia> A4d, B3d = randn(16,8,10,32), randn(8,10,32); # 4D and 3D arrays + +julia> batched_vec(A4d, B3d) |> size +(16, 10, 32) ``` """ +function batched_vec(A::AbstractArray, B::AbstractArray) + ndims(A) == ndims(B) + 1 || throw(DimensionMismatch( + "batched_vec requires ndims(A) == ndims(B) + 1, got ndims(A)=$(ndims(A)) and ndims(B)=$(ndims(B))")) + size(A)[3:end] == size(B)[2:end] || throw(DimensionMismatch( + "batch dimensions must match: size(A)[3:end]=$(size(A)[3:end]) != size(B)[2:end]=$(size(B)[2:end])")) + + # Reshape B to add a singleton dimension for matrix multiplication + B_reshaped = reshape(B, size(B, 1), 1, size(B)[2:end]...) + # Perform batched multiplication + C = batched_mul(A, B_reshaped) + # Remove the singleton dimension + return dropdims(C, dims=2) +end + batched_vec(A::AbstractArray{T,3} where T, B::AbstractMatrix) = reshape(batched_mul(A, reshape(B, size(B,1), 1, size(B,2))), size(A,1), size(A,3)) diff --git a/test/batchedmul.jl b/test/batchedmul.jl index 1b8b08e1..042f4c1e 100644 --- a/test/batchedmul.jl +++ b/test/batchedmul.jl @@ -303,3 +303,41 @@ FiniteDifferences.to_vec(x::BatchedTranspose) = FiniteDifferences.to_vec(collect gradtest(batched_vec, randn(rng, M, P, B), randn(rng, P)) end + +@testset "batched_vec: N-D batches" begin + # Test 4D case: A is 4D, B is 3D + A4d = randn(4, 5, 3, 2) # (matrix_rows, matrix_cols, batch_dim1, batch_dim2) + B3d = randn(5, 3, 2) # (vector_length, batch_dim1, batch_dim2) + + C = batched_vec(A4d, B3d) + @test size(C) == (4, 3, 2) + + # Manual verification + for i in 1:3, j in 1:2 + @test C[:, i, j] ≈ A4d[:, :, i, j] * B3d[:, i, j] + end + + # Test 5D case: A is 5D, B is 4D + A5d = randn(3, 4, 2, 3, 2) # (matrix_rows, matrix_cols, batch1, batch2, batch3) + B4d = randn(4, 2, 3, 2) # (vector_length, batch1, batch2, batch3) + + C5 = batched_vec(A5d, B4d) + @test size(C5) == (3, 2, 3, 2) + + # Manual verification for a few cases + @test C5[:, 1, 1, 1] ≈ A5d[:, :, 1, 1, 1] * B4d[:, 1, 1, 1] + @test C5[:, 2, 3, 2] ≈ A5d[:, :, 2, 3, 2] * B4d[:, 2, 3, 2] + + # Test dimension mismatch errors + @test_throws DimensionMismatch batched_vec(randn(3, 4, 2), randn(4, 3)) # ndims mismatch + @test_throws DimensionMismatch batched_vec(randn(3, 4, 2, 3), randn(4, 2, 2)) # batch size mismatch + + # Test that 3D case still works (backward compatibility) + A3d_test = randn(4, 5, 3) + B2d_test = randn(5, 3) + C3_test = batched_vec(A3d_test, B2d_test) + @test size(C3_test) == (4, 3) + for k in 1:3 + @test C3_test[:, k] ≈ A3d_test[:, :, k] * B2d_test[:, k] + end +end From 4cc6de3beea10c70b3ce7110d05dc6728c782918 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 19 Nov 2025 12:07:45 +0100 Subject: [PATCH 3/4] Update src/batched/batchedmul.jl --- src/batched/batchedmul.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/batched/batchedmul.jl b/src/batched/batchedmul.jl index 37fd9c49..b5a451f9 100644 --- a/src/batched/batchedmul.jl +++ b/src/batched/batchedmul.jl @@ -162,8 +162,8 @@ _semi_batched_mul(A::Transpose{<:Number,<:AbstractMatrix}, B::AbstractArray{<:An batched_mul(batched_transpose(reshape(parent(A), size(parent(A))..., 1)), B) """ - batched_vec(A::Array{T,3}, B::Matrix) - batched_vec(A::Array{T,3}, b::Vector) + batched_vec(A::AbstractArray{T,3}, B::AbstractMatrix) + batched_vec(A::AbstractArray{T,3}, b::AbstractVector) batched_vec(A::AbstractArray, B::AbstractArray) Batched matrix-vector multiplication. For the 3D case: From 69f61e6cbc61010282733e47882df47b74c688ab Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 19 Nov 2025 12:08:12 +0100 Subject: [PATCH 4/4] Update test/batchedmul.jl --- test/batchedmul.jl | 8 -------- 1 file changed, 8 deletions(-) diff --git a/test/batchedmul.jl b/test/batchedmul.jl index 042f4c1e..b27f7ce0 100644 --- a/test/batchedmul.jl +++ b/test/batchedmul.jl @@ -332,12 +332,4 @@ end @test_throws DimensionMismatch batched_vec(randn(3, 4, 2), randn(4, 3)) # ndims mismatch @test_throws DimensionMismatch batched_vec(randn(3, 4, 2, 3), randn(4, 2, 2)) # batch size mismatch - # Test that 3D case still works (backward compatibility) - A3d_test = randn(4, 5, 3) - B2d_test = randn(5, 3) - C3_test = batched_vec(A3d_test, B2d_test) - @test size(C3_test) == (4, 3) - for k in 1:3 - @test C3_test[:, k] ≈ A3d_test[:, :, k] * B2d_test[:, k] - end end