Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions torch_sim/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,3 +987,28 @@
print(msg)
# Fall back to scipy implementation
return matrix_log_scipy(matrix).to(sim_dtype)


def batched_vdot(
x: torch.Tensor, y: torch.Tensor, batch_indices: torch.Tensor
) -> torch.Tensor:
"""Computes batched vdot (sum of element-wise product) for groups of vectors.
If is_sum_sq is True, computes sum of x_i * x_i (squared norm components).

Args:
x: Tensor of shape [N_total_entities, D] (e.g., forces, velocities).
y: Tensor of shape [N_total_entities, D].
batch_indices: Tensor of shape [N_total_entities] indicating batch membership.

Returns:
Tensor: shape [n_batches] where each element is the sum(x_i * y_i)
(or sum(x_i * x_i) if is_sum_sq) for entities belonging to that batch,
summed over all components D and all entities in the batch.
"""
if x.ndim != 2 or batch_indices.ndim != 1 or x.shape[0] != batch_indices.shape[0]:
raise ValueError(f"Invalid input shapes: {x.shape=}, {batch_indices.shape=}")

Check warning on line 1009 in torch_sim/math.py

View check run for this annotation

Codecov / codecov/patch

torch_sim/math.py#L1009

Added line #L1009 was not covered by tests

output = torch.zeros(batch_indices.max() + 1, dtype=x.dtype, device=x.device)
output.scatter_add_(dim=0, index=batch_indices, src=(x * y).sum(dim=1))

return output
Loading
Loading