Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions examples/scripts/7_Others/7.3_Batched_neighbor_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,26 @@
# Fix: Ensure pbc has the correct shape [n_systems, 3]
pbc_tensor = torch.tensor([[pbc] * 3] * len(atoms_list), dtype=torch.bool)

Comment on lines 25 to 27
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Broadcast error: pbc_tensor gets wrong shape

[[pbc] * 3] * len(atoms_list) expands both the leading and 2nd dimensions, yielding shape (n_systems, 3, 3) instead of (n_systems, 3).
This will break downstream neighbour-list functions expecting (n_systems, 3) or (3,).

Suggested fix:

-# Fix: Ensure pbc has the correct shape [n_systems, 3]
-pbc_tensor = torch.tensor([[pbc] * 3] * len(atoms_list), dtype=torch.bool)
+# Ensure pbc has shape (n_systems, 3)
+if isinstance(pbc, torch.Tensor) and pbc.dim() == 1:
+    # pbc is a (3,) tensor – broadcast over systems
+    pbc_tensor = pbc.unsqueeze(0).repeat(len(atoms_list), 1)
+else:
+    # pbc is a python bool or already (3,) list
+    pbc_tensor = torch.tensor([pbc] * len(atoms_list), dtype=torch.bool)

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In examples/scripts/7_Others/7.3_Batched_neighbor_list.py around lines 25 to 27,
the construction of pbc_tensor incorrectly expands dimensions resulting in shape
(n_systems, 3, 3) instead of the required (n_systems, 3). To fix this, create a
list of pbc repeated n_systems times without duplicating the inner list three
times, then convert it to a tensor with dtype=torch.bool to ensure the shape is
(n_systems, 3).

mapping, mapping_batch, shifts_idx = torch_nl_linked_cell(
mapping, mapping_system, shifts_idx = torch_nl_linked_cell(
cutoff, pos, cell, pbc_tensor, system_idx, self_interaction
)
cell_shifts = transforms.compute_cell_shifts(cell, shifts_idx, mapping_batch)
cell_shifts = transforms.compute_cell_shifts(cell, shifts_idx, mapping_system)
dds = transforms.compute_distances_with_cell_shifts(pos, mapping, cell_shifts)

print(mapping.shape)
print(mapping_batch.shape)
print(mapping_system.shape)
print(shifts_idx.shape)
print(cell_shifts.shape)
print(dds.shape)

mapping_n2, mapping_batch_n2, shifts_idx_n2 = torch_nl_n2(
mapping_n2, mapping_system_n2, shifts_idx_n2 = torch_nl_n2(
cutoff, pos, cell, pbc_tensor, system_idx, self_interaction
)
cell_shifts_n2 = transforms.compute_cell_shifts(cell, shifts_idx_n2, mapping_batch_n2)
cell_shifts_n2 = transforms.compute_cell_shifts(cell, shifts_idx_n2, mapping_system_n2)
dds_n2 = transforms.compute_distances_with_cell_shifts(pos, mapping_n2, cell_shifts_n2)

print(mapping_n2.shape)
print(mapping_batch_n2.shape)
print(mapping_system_n2.shape)
print(shifts_idx_n2.shape)
print(cell_shifts_n2.shape)
print(dds_n2.shape)
6 changes: 3 additions & 3 deletions examples/tutorials/low_level_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@
"""
`SimState` objects can be passed directly to the model and it will compute
the properties of the systems in the batch. The properties will be returned
either batchwise, like the energy, or atomwise, like the forces.
either systemwise, like the energy, or atomwise, like the forces.

Note that the energy here refers to the potential energy of the system.
"""
Expand All @@ -116,9 +116,9 @@
model_outputs = model(state)
print(f"Model outputs: {', '.join(list(model_outputs))}")

print(f"Energy is a batchwise property with shape: {model_outputs['energy'].shape}")
print(f"Energy is a systemwise property with shape: {model_outputs['energy'].shape}")
print(f"Forces are an atomwise property with shape: {model_outputs['forces'].shape}")
print(f"Stress is a batchwise property with shape: {model_outputs['stress'].shape}")
print(f"Stress is a systemwise property with shape: {model_outputs['stress'].shape}")


# %% [markdown]
Expand Down
22 changes: 11 additions & 11 deletions examples/tutorials/state_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@

# %% [markdown]
"""
SimState attributes fall into three categories: atomwise, batchwise, and global.
SimState attributes fall into three categories: atomwise, systemwise, and global.

* Atomwise attributes are tensors with shape (n_atoms, ...), these are `positions`,
`masses`, `atomic_numbers`, and `batch`. Names are plural.
* Batchwise attributes are tensors with shape (n_systems, ...), this is just `cell` for
`masses`, `atomic_numbers`, and `system_idx`. Names are plural.
* Systemwise attributes are tensors with shape (n_systems, ...), this is just `cell` for
the base SimState. Names are singular.
* Global attributes have any other shape or type, just `pbc` here. Names are singular.

Expand Down Expand Up @@ -112,7 +112,7 @@
f"Multi-state has {multi_state.n_atoms} total atoms across {multi_state.n_systems} systems"
)

# we can see how the shapes of batchwise, atomwise, and global properties change
# we can see how the shapes of atomwise, systemwise, and global properties change
print(f"Positions shape: {multi_state.positions.shape}")
print(f"Cell shape: {multi_state.cell.shape}")
print(f"PBC: {multi_state.pbc}")
Expand Down Expand Up @@ -142,7 +142,7 @@

SimState supports many convenience operations for manipulating batched states. Slicing
is supported through fancy indexing, e.g. `state[[0, 1, 2]]` will return a new state
containing only the first three batches. The other operations are available through the
containing only the first three systems. The other operations are available through the
`pop`, `split`, `clone`, and `to` methods.
"""

Expand Down Expand Up @@ -182,19 +182,19 @@
# %% [markdown]
"""

You can extract specific batches from a batched state using Python's slicing syntax.
You can extract specific systems from a batched state using Python's slicing syntax.
This is extremely useful for analyzing specific systems or for implementing complex
workflows where different systems need separate processing:

The slicing interface follows Python's standard indexing conventions, making it
intuitive to use. Behind the scenes, TorchSim is creating a new SimState with only the
selected batches, maintaining all the necessary properties and relationships.
selected systems, maintaining all the necessary properties and relationships.

Note the difference between these operations:
- `split()` returns all batches as separate states but doesn't modify the original
- `pop()` removes specified batches from the original state and returns them as
- `split()` returns all systems as separate states but doesn't modify the original
- `pop()` removes specified systems from the original state and returns them as
separate states
- `__getitem__` (slicing) creates a new state with specified batches without modifying
- `__getitem__` (slicing) creates a new state with specified systems without modifying
the original

This flexibility allows you to structure your simulation workflows in the most
Expand All @@ -203,7 +203,7 @@
### Splitting and Popping Batches

SimState provides methods to split a batched state into separate states or to remove
specific batches:
specific systems:
"""

# %% [markdown]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_autobatching.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def test_binning_auto_batcher(
# Get batches until None is returned
batches = list(batcher)

# Check we got the expected number of batches
# Check we got the expected number of systems
assert len(batches) == len(batcher.batched_states)

# Test restore_original_order
Expand Down
48 changes: 24 additions & 24 deletions tests/test_integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,66 +20,66 @@ def test_calculate_momenta_basic(device: torch.device):
seed = 42
dtype = torch.float64

# Create test inputs for 3 batches with 2 atoms each
# Create test inputs for 3 systems with 2 atoms each
n_atoms = 8
positions = torch.randn(n_atoms, 3, dtype=dtype, device=device)
masses = torch.rand(n_atoms, dtype=dtype, device=device) + 0.5
batch = torch.tensor(
system_idx = torch.tensor(
[0, 0, 1, 1, 2, 2, 3, 3], device=device
) # 3 batches with 2 atoms each
) # 3 systems with 2 atoms each
kT = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=dtype, device=device)

# Run the function
momenta = calculate_momenta(positions, masses, batch, kT, seed=seed)
momenta = calculate_momenta(positions, masses, system_idx, kT, seed=seed)

# Basic checks
assert momenta.shape == positions.shape
assert momenta.dtype == dtype
assert momenta.device == device

# Check that each batch has zero center of mass momentum
# Check that each system has zero center of mass momentum
for b in range(4):
batch_mask = batch == b
batch_momenta = momenta[batch_mask]
com_momentum = torch.mean(batch_momenta, dim=0)
system_mask = system_idx == b
system_momenta = momenta[system_mask]
com_momentum = torch.mean(system_momenta, dim=0)
assert torch.allclose(
com_momentum, torch.zeros(3, dtype=dtype, device=device), atol=1e-10
)


def test_calculate_momenta_single_atoms(device: torch.device):
"""Test that calculate_momenta preserves momentum for batches with single atoms."""
"""Test that calculate_momenta preserves momentum for systems with single atoms."""
seed = 42
dtype = torch.float64

# Create test inputs with some batches having single atoms
# Create test inputs with some systems having single atoms
positions = torch.randn(5, 3, dtype=dtype, device=device)
masses = torch.rand(5, dtype=dtype, device=device) + 0.5
batch = torch.tensor(
system_idx = torch.tensor(
[0, 1, 1, 2, 3], device=device
) # Batches 0, 2, and 3 have single atoms
) # systems 0, 2, and 3 have single atoms
kT = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=dtype, device=device)

# Generate momenta and save the raw values before COM correction
generator = torch.Generator(device=device).manual_seed(seed)
raw_momenta = torch.randn(
positions.shape, device=device, dtype=dtype, generator=generator
) * torch.sqrt(masses * kT[batch]).unsqueeze(-1)
) * torch.sqrt(masses * kT[system_idx]).unsqueeze(-1)

# Run the function
momenta = calculate_momenta(positions, masses, batch, kT, seed=seed)
momenta = calculate_momenta(positions, masses, system_idx, kT, seed=seed)

# Check that single-atom batches have unchanged momenta
for b in [0, 2, 3]: # Single atom batches
batch_mask = batch == b
# Check that single-atom systems have unchanged momenta
for b in [0, 2, 3]: # Single atom systems
system_mask = system_idx == b
# The momentum should be exactly the same as the raw value for single atoms
assert torch.allclose(momenta[batch_mask], raw_momenta[batch_mask])
assert torch.allclose(momenta[system_mask], raw_momenta[system_mask])

# Check that multi-atom batches have zero COM
for b in [1]: # Multi-atom batches
batch_mask = batch == b
batch_momenta = momenta[batch_mask]
com_momentum = torch.mean(batch_momenta, dim=0)
# Check that multi-atom systems have zero COM
for b in [1]: # Multi-atom systems
system_mask = system_idx == b
system_momenta = momenta[system_mask]
com_momentum = torch.mean(system_momenta, dim=0)
assert torch.allclose(
com_momentum, torch.zeros(3, dtype=dtype, device=device), atol=1e-10
)
Expand Down Expand Up @@ -378,7 +378,7 @@ def test_compute_cell_force_atoms_per_system():
Covers fix in https://github.com/Radical-AI/torch-sim/pull/153."""
from torch_sim.integrators.npt import _compute_cell_force

# Setup minimal state with two batches having 8:1 atom ratio
# Setup minimal state with two systems having 8:1 atom ratio
s1, s2 = torch.zeros(8, dtype=torch.long), torch.ones(64, dtype=torch.long)

state = NPTLangevinState(
Expand Down
14 changes: 7 additions & 7 deletions tests/test_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,13 +342,13 @@ def test_torch_nl_implementations(
)

# Get the neighbor list from the implementation being tested
mapping, mapping_batch, shifts_idx = nl_implementation(
mapping, mapping_system, shifts_idx = nl_implementation(
cutoff, pos, row_vector_cell, pbc, batch, self_interaction
)

# Calculate distances
cell_shifts = transforms.compute_cell_shifts(
row_vector_cell, shifts_idx, mapping_batch
row_vector_cell, shifts_idx, mapping_system
)
dds = transforms.compute_distances_with_cell_shifts(pos, mapping, cell_shifts)
dds = np.sort(dds.numpy())
Expand Down Expand Up @@ -496,30 +496,30 @@ def test_strict_nl_edge_cases(

# Test with no cell shifts
mapping = torch.tensor([[0], [1]], device=device, dtype=torch.long)
batch_mapping = torch.tensor([0], device=device, dtype=torch.long)
system_mapping = torch.tensor([0], device=device, dtype=torch.long)
shifts_idx = torch.zeros((1, 3), device=device, dtype=torch.long)

new_mapping, new_batch, new_shifts = neighbors.strict_nl(
cutoff=1.5,
positions=pos,
cell=cell,
mapping=mapping,
batch_mapping=batch_mapping,
system_mapping=system_mapping,
shifts_idx=shifts_idx,
)
assert len(new_mapping[0]) > 0 # Should find neighbors

# Test with different batch mappings
mapping = torch.tensor([[0, 1], [1, 0]], device=device, dtype=torch.long)
batch_mapping = torch.tensor([0, 1], device=device, dtype=torch.long)
system_mapping = torch.tensor([0, 1], device=device, dtype=torch.long)
shifts_idx = torch.zeros((2, 3), device=device, dtype=torch.long)

new_mapping, new_batch, new_shifts = neighbors.strict_nl(
cutoff=1.5,
positions=pos,
cell=cell,
mapping=mapping,
batch_mapping=batch_mapping,
system_mapping=system_mapping,
shifts_idx=shifts_idx,
)
assert len(new_mapping[0]) > 0 # Should find neighbors
Expand Down Expand Up @@ -559,7 +559,7 @@ def test_neighbor_lists_time_and_memory(
system_idx = torch.zeros(n_atoms, dtype=torch.long, device=device)
# Fix pbc tensor shape
pbc = torch.tensor([[True, True, True]], device=device)
mapping, mapping_batch, shifts_idx = nl_fn(
mapping, mapping_system, shifts_idx = nl_fn(
cutoff, pos, cell, pbc, system_idx, self_interaction=False
)
else:
Expand Down
14 changes: 7 additions & 7 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,9 +1183,9 @@ def test_compute_cell_shifts_basic() -> None:
"""Test compute_cell_shifts function."""
cell = torch.eye(3).unsqueeze(0) * 2.0
shifts_idx = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
batch_mapping = torch.tensor([0, 0])
system_mapping = torch.tensor([0, 0])

cell_shifts = tst.compute_cell_shifts(cell, shifts_idx, batch_mapping)
cell_shifts = tst.compute_cell_shifts(cell, shifts_idx, system_mapping)

expected = torch.tensor([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
torch.testing.assert_close(cell_shifts, expected)
Expand Down Expand Up @@ -1272,16 +1272,16 @@ def test_build_linked_cell_neighborhood_basic() -> None:
cutoff = 1.5
n_atoms = torch.tensor([2, 2])

mapping, batch_mapping, cell_shifts_idx = tst.build_linked_cell_neighborhood(
mapping, system_mapping, cell_shifts_idx = tst.build_linked_cell_neighborhood(
positions, cell, pbc, cutoff, n_atoms, self_interaction=False
)

# Check that atoms in the same structure are neighbors
assert mapping.shape[1] >= 2 # At least 2 neighbor pairs

# Verify batch_mapping has correct length
assert batch_mapping.shape[0] == mapping.shape[1]
# Verify system_mapping has correct length
assert system_mapping.shape[0] == mapping.shape[1]

# Verify that there are neighbors from both batches
assert torch.any(batch_mapping == 0)
assert torch.any(batch_mapping == 1)
assert torch.any(system_mapping == 0)
assert torch.any(system_mapping == 1)
6 changes: 3 additions & 3 deletions torch_sim/models/graphpes.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ def state_to_atomic_graph(state: ts.SimState, cutoff: torch.Tensor) -> AtomicGra
graphs = []

for i in range(state.n_systems):
batch_mask = state.system_idx == i
R = state.positions[batch_mask]
Z = state.atomic_numbers[batch_mask]
system_mask = state.system_idx == i
R = state.positions[system_mask]
Z = state.atomic_numbers[system_mask]
cell = state.row_vector_cell[i]
nl, shifts = vesin_nl_ts(
R,
Expand Down
Loading