Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/scripts/7_Others/7.3_Batched_neighbor_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
cutoff = torch.tensor(4.0, dtype=pos.dtype)
self_interaction = False

# Fix: Ensure pbc has the correct shape [n_systems, 3]
pbc_tensor = torch.tensor([[pbc] * 3] * len(atoms_list), dtype=torch.bool)
# Ensure pbc has the correct shape [n_systems, 3]
pbc_tensor = torch.tensor(pbc).repeat(state.n_systems, 1)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this repeat is needed since torch_nl_linked_cell accepts (num_systems, 3) for the pbc. it doens't just accept a tensor of shape (3,)


mapping, mapping_system, shifts_idx = torch_nl_linked_cell(
pos, cell, pbc_tensor, cutoff, system_idx, self_interaction
Expand Down
25 changes: 17 additions & 8 deletions examples/tutorials/diff_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class BaseState:

positions: torch.Tensor
cell: torch.Tensor
pbc: bool
pbc: torch.Tensor
species: torch.Tensor


Expand All @@ -133,14 +133,18 @@ def __init__(
device: torch.device | None = None,
dtype: torch.dtype = torch.float32,
*, # Force keyword-only arguments
pbc: bool = True,
pbc: torch.Tensor | bool = True,
cutoff: float | None = None,
) -> None:
"""Initialize a soft sphere model for multi-component systems."""
super().__init__()
self.device = device or torch.device("cpu")
self.dtype = dtype
self.pbc = pbc
self.pbc = (
pbc
if isinstance(pbc, torch.Tensor)
else torch.tensor([pbc] * 3, dtype=torch.bool)
)

# Store species list and determine number of unique species
self.species = species
Expand Down Expand Up @@ -369,20 +373,25 @@ def simulation(
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Create the simulation environment.
box_size = box_size_at_packing_fraction(diameter, packing_fraction)
cell = torch.eye(2) * box_size
cell = torch.eye(3) * box_size
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not against diff_sim being in 3D, but was there any reason to be in 2D in the first place?

Copy link
Collaborator Author

@curtischong curtischong Nov 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

¯\(ツ)

# Create the energy function.
sigma = species_sigma(diameter)
model = SoftSphereMultiModel(sigma_matrix=sigma, species=species)
model = torch.compile(model)
# Randomly initialize the system.
# Fix seed for reproducible random positions
torch.manual_seed(seed)
R = torch.rand(N, 2) * box_size
R = torch.rand(N, 3) * box_size

# Minimize to the nearest minimum.
init_fn, apply_fn = gradient_descent(model, lr=0.1)

custom_state = BaseState(positions=R, cell=cell, species=species, pbc=True)
custom_state = BaseState(
positions=R,
cell=cell,
species=species,
pbc=torch.tensor([True] * 3, dtype=torch.bool),
)
state = init_fn(custom_state)
for _ in range(simulation_steps):
state = apply_fn(state)
Expand Down Expand Up @@ -415,7 +424,7 @@ def simulation(
seeds = torch.arange(1, 6)
box_size_tensor = torch.zeros(len(diameters), len(seeds))
raft_energy_tensor = torch.zeros(len(diameters), len(seeds))
bubble_positions_tensor = torch.zeros(len(diameters), len(seeds), N, 2)
bubble_positions_tensor = torch.zeros(len(diameters), len(seeds), N, 3)
for i, d in enumerate(diameters):
for j, s in enumerate(seeds):
box_size, raft_energy, bubble_positions = simulation(d, s)
Expand Down Expand Up @@ -468,7 +477,7 @@ def short_simulation(
) -> tuple[torch.Tensor, torch.Tensor]:
diameter = diameter.requires_grad_(True)
box_size = box_size_at_packing_fraction(diameter, packing_fraction)
cell = torch.eye(2) * box_size
cell = torch.eye(3) * box_size
# Create the energy function.
sigma = species_sigma(diameter)
model = SoftSphereMultiModel(sigma_matrix=sigma, species=species)
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ dependencies = [
"tables>=3.10.2",
"torch>=2",
"tqdm>=4.67",
"vesin-torch>=0.3.7, <0.4.0",
"vesin>=0.3.7, <0.4.0",
"vesin-torch>=0.4.0, <0.5.0",
"vesin>=0.4.0, <0.5.0",
]

[project.optional-dependencies]
Expand All @@ -48,7 +48,7 @@ test = [
io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2025.6.14"]
mace = ["mace-torch>=0.3.12"]
mattersim = ["mattersim>=0.1.2"]
metatomic = ["metatomic-torch>=0.1.1", "metatrain[pet]>=2025.7"]
metatomic = ["metatomic-torch>=0.1.3", "metatrain[pet]>=2025.7"]
orb = ["orb-models>=0.5.2"]
sevenn = ["sevenn>=0.11.0"]
graphpes = ["graph-pes>=0.1", "mace-torch>=0.3.12"]
Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_soft_sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,8 @@ def test_multispecies_cutoff_default() -> None:
@pytest.mark.parametrize(
("flag_name", "flag_value"),
[
("pbc", True),
("pbc", False),
("pbc", torch.tensor([True, True, True])),
("pbc", torch.tensor([False, False, False])),
("compute_forces", False),
("compute_stress", True),
("per_atom_energies", True),
Expand Down
10 changes: 5 additions & 5 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_multiple_structures_to_state(si_structure: Structure) -> None:
assert state.positions.shape == (16, 3)
assert state.masses.shape == (16,)
assert state.cell.shape == (2, 3, 3)
assert state.pbc
assert torch.all(state.pbc)
assert state.atomic_numbers.shape == (16,)
assert state.system_idx.shape == (16,)
assert torch.all(
Expand All @@ -64,7 +64,7 @@ def test_single_atoms_to_state(si_atoms: Atoms) -> None:
assert state.positions.shape == (8, 3)
assert state.masses.shape == (8,)
assert state.cell.shape == (1, 3, 3)
assert state.pbc
assert torch.all(state.pbc)
assert state.atomic_numbers.shape == (8,)
assert state.system_idx.shape == (8,)
assert torch.all(state.system_idx == 0)
Expand All @@ -79,7 +79,7 @@ def test_multiple_atoms_to_state(si_atoms: Atoms) -> None:
assert state.positions.shape == (16, 3)
assert state.masses.shape == (16,)
assert state.cell.shape == (2, 3, 3)
assert state.pbc
assert torch.all(state.pbc)
assert state.atomic_numbers.shape == (16,)
assert state.system_idx.shape == (16,)
assert torch.all(
Expand Down Expand Up @@ -171,7 +171,7 @@ def test_multiple_phonopy_to_state(si_phonopy_atoms: Any) -> None:
assert state.positions.shape == (16, 3)
assert state.masses.shape == (16,)
assert state.cell.shape == (2, 3, 3)
assert state.pbc
assert torch.all(state.pbc)
assert state.atomic_numbers.shape == (16,)
assert state.system_idx.shape == (16,)
assert torch.all(
Expand Down Expand Up @@ -246,7 +246,7 @@ def test_state_round_trip(
assert torch.allclose(sim_state.cell, round_trip_state.cell)
assert torch.all(sim_state.atomic_numbers == round_trip_state.atomic_numbers)
assert torch.all(sim_state.system_idx == round_trip_state.system_idx)
assert sim_state.pbc == round_trip_state.pbc
assert torch.equal(sim_state.pbc, round_trip_state.pbc)

if isinstance(intermediate_format[0], Atoms):
# TODO: masses round trip for pmg and phonopy masses is not exact
Expand Down
45 changes: 33 additions & 12 deletions tests/test_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,15 +170,15 @@ def test_primitive_neighbor_list(
pos = torch.tensor(atoms.positions, device=DEVICE, dtype=DTYPE)
row_vector_cell = torch.tensor(atoms.cell.array, device=DEVICE, dtype=DTYPE)

pbc = atoms.pbc.any()
pbc = torch.tensor(atoms.pbc, device=DEVICE, dtype=DTYPE)

# Get the neighbor list using the appropriate function (jitted or non-jitted)
# Note: No self-interaction
idx_i, idx_j, shifts_tensor = neighbor_list_fn(
quantities="ijS",
positions=pos,
cell=row_vector_cell,
pbc=(pbc, pbc, pbc),
pbc=pbc,
cutoff=torch.tensor(cutoff, dtype=DTYPE, device=DEVICE),
device=DEVICE,
dtype=DTYPE,
Expand Down Expand Up @@ -258,7 +258,7 @@ def test_neighbor_list_implementations(
# Convert to torch tensors
pos = torch.tensor(atoms.positions, device=DEVICE, dtype=DTYPE)
row_vector_cell = torch.tensor(atoms.cell.array, device=DEVICE, dtype=DTYPE)
pbc = atoms.pbc.any()
pbc = torch.tensor(atoms.pbc, device=DEVICE, dtype=DTYPE)

# Get the neighbor list from the implementation being tested
mapping, shifts = nl_implementation(
Expand Down Expand Up @@ -371,7 +371,7 @@ def test_primitive_neighbor_list_edge_cases() -> None:
quantities="ijS",
positions=pos,
cell=cell,
pbc=pbc,
pbc=torch.tensor(pbc, device=DEVICE, dtype=DTYPE),
cutoff=cutoff,
device=DEVICE,
dtype=DTYPE,
Expand All @@ -383,7 +383,7 @@ def test_primitive_neighbor_list_edge_cases() -> None:
quantities="ijS",
positions=pos,
cell=cell,
pbc=(True, True, True),
pbc=torch.Tensor([True, True, True]),
cutoff=cutoff,
device=DEVICE,
dtype=DTYPE,
Expand All @@ -404,7 +404,7 @@ def test_standard_nl_edge_cases() -> None:
mapping, _shifts = neighbors.standard_nl(
positions=pos,
cell=cell,
pbc=pbc,
pbc=torch.tensor([pbc] * 3, device=DEVICE, dtype=DTYPE),
cutoff=cutoff,
)
assert len(mapping[0]) > 0 # Should find neighbors
Expand All @@ -413,7 +413,7 @@ def test_standard_nl_edge_cases() -> None:
mapping, _shifts = neighbors.standard_nl(
positions=pos,
cell=cell,
pbc=True,
pbc=torch.Tensor([True, True, True]),
cutoff=cutoff,
sort_id=True,
)
Expand All @@ -430,13 +430,20 @@ def test_vesin_nl_edge_cases() -> None:
# Test both implementations
for nl_fn in (neighbors.vesin_nl, neighbors.vesin_nl_ts):
# Test different PBC combinations
for pbc in (True, False):
for pbc in (
torch.Tensor([True, True, True]),
torch.Tensor([False, False, False]),
):
mapping, _shifts = nl_fn(positions=pos, cell=cell, pbc=pbc, cutoff=cutoff)
assert len(mapping[0]) > 0 # Should find neighbors

# Test sort_id
mapping, _shifts = nl_fn(
positions=pos, cell=cell, pbc=True, cutoff=cutoff, sort_id=True
positions=pos,
cell=cell,
pbc=torch.Tensor([True, True, True]),
cutoff=cutoff,
sort_id=True,
)
# Check if indices are sorted
assert torch.all(mapping[0][1:] >= mapping[0][:-1])
Expand All @@ -446,7 +453,10 @@ def test_vesin_nl_edge_cases() -> None:
pos_f32 = pos.to(dtype=torch.float32)
cell_f32 = cell.to(dtype=torch.float32)
mapping, _shifts = nl_fn(
positions=pos_f32, cell=cell_f32, pbc=True, cutoff=cutoff
positions=pos_f32,
cell=cell_f32,
pbc=torch.Tensor([True, True, True]),
cutoff=cutoff,
)
assert len(mapping[0]) > 0 # Should find neighbors

Expand Down Expand Up @@ -528,7 +538,12 @@ def test_neighbor_lists_time_and_memory() -> None:
self_interaction=False,
)
else:
_mapping, _shifts = nl_fn(positions=pos, cell=cell, pbc=True, cutoff=cutoff)
_mapping, _shifts = nl_fn(
positions=pos,
cell=cell,
pbc=torch.Tensor([True, True, True]),
cutoff=cutoff,
)

end_time = time.perf_counter()
execution_time = end_time - start_time
Expand All @@ -551,4 +566,10 @@ def test_neighbor_lists_time_and_memory() -> None:
assert cpu_memory_used < 5e8, (
f"{fn_name} used too much CPU memory: {cpu_memory_used / 1e6:.2f}MB"
)
assert execution_time < 0.8, f"{fn_name} took too long: {execution_time}s"
if nl_fn == neighbors.standard_nl:
# this function is just quite slow. So we have a higher tolerance.
# I tried removing "@jit.script" and it was still slow.
# (This nl function is just slow)
assert execution_time < 3, f"{fn_name} took too long: {execution_time}s"
else:
assert execution_time < 0.8, f"{fn_name} took too long: {execution_time}s"
13 changes: 7 additions & 6 deletions tests/test_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def random_state() -> MDState:
cell=torch.unsqueeze(torch.eye(3) * 10.0, 0),
atomic_numbers=torch.ones(10, dtype=torch.int32),
system_idx=torch.zeros(10, dtype=torch.int32),
pbc=True,
pbc=[True, True, False],
)


Expand Down Expand Up @@ -93,7 +93,7 @@ def test_write_state_single(
assert trajectory.get_array("positions").shape == (1, 10, 3)
assert trajectory.get_array("atomic_numbers").shape == (1, 10)
assert trajectory.get_array("cell").shape == (1, 3, 3)
assert trajectory.get_array("pbc").shape == (1,)
assert trajectory.get_array("pbc").shape == (3,)


def test_write_state_multiple(
Expand All @@ -106,7 +106,7 @@ def test_write_state_multiple(
assert trajectory.get_array("positions").shape == (2, 10, 3)
assert trajectory.get_array("atomic_numbers").shape == (1, 10)
assert trajectory.get_array("cell").shape == (2, 3, 3)
assert trajectory.get_array("pbc").shape == (1,)
assert trajectory.get_array("pbc").shape == (3,)


def test_optional_arrays(trajectory: TorchSimTrajectory, random_state: MDState) -> None:
Expand Down Expand Up @@ -439,7 +439,7 @@ def test_get_atoms(trajectory: TorchSimTrajectory, random_state: MDState) -> Non
np.testing.assert_allclose(
atoms.get_atomic_numbers(), random_state.atomic_numbers.numpy()
)
assert atoms.pbc.all() == random_state.pbc
np.testing.assert_array_equal(atoms.pbc, random_state.pbc.detach().cpu().numpy())


def test_get_state(trajectory: TorchSimTrajectory, random_state: MDState) -> None:
Expand Down Expand Up @@ -473,12 +473,13 @@ def test_get_state(trajectory: TorchSimTrajectory, random_state: MDState) -> Non
assert state.positions.dtype == expected_dtype
assert state.cell.dtype == expected_dtype
assert state.atomic_numbers.dtype == torch.int # Should always be int
assert state.pbc.dtype == torch.bool # Should always be bool

# Test values (convert to CPU for comparison)
np.testing.assert_allclose(state.positions, random_state.positions)
np.testing.assert_allclose(state.cell, random_state.cell)
np.testing.assert_allclose(state.atomic_numbers, random_state.atomic_numbers)
assert state.pbc == random_state.pbc
assert torch.equal(state.pbc, random_state.pbc)


def test_write_ase_trajectory(
Expand Down Expand Up @@ -509,7 +510,7 @@ def test_write_ase_trajectory(
np.testing.assert_allclose(
atoms.get_atomic_numbers(), random_state.atomic_numbers.numpy()
)
assert atoms.pbc.all() == random_state.pbc
np.testing.assert_array_equal(atoms.pbc, random_state.pbc.numpy())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe you should transfer first in cpu() in case we ever run this on gpu?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is fine because the random_state is automatically made on cpu by default. Maybe we can do it in a separate PR if people run these tests and see it fail? (trying to keep this one slim)


# Clean up
ase_traj.close()
Expand Down
16 changes: 14 additions & 2 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# ruff: noqa: PT011
import itertools

import numpy as np
import pytest
import torch
Expand Down Expand Up @@ -195,7 +197,7 @@ def test_pbc_wrap_general_batch() -> None:


@pytest.mark.parametrize(
"pbc", [[True, True, True], [True, True, False], [False, False, False], True, False]
"pbc", [*list(itertools.product([False, True], repeat=3)), True, False]
)
@pytest.mark.parametrize("pretty_translation", [True, False])
def test_wrap_positions_matches_ase(
Expand Down Expand Up @@ -892,10 +894,20 @@ def test_get_fractional_coordinates_batched() -> None:
True,
[[0.2, 0.0, 0.0], [0.0, 0.2, 0.0], [0.0, 0.0, 0.2]],
),
(
[[2.2, 0.0, 0.0], [0.0, 2.2, 0.0], [0.0, 0.0, 2.2]],
torch.eye(3, dtype=DTYPE) * 2.0,
torch.tensor([True, False, True], dtype=torch.bool),
[[0.2, 0.0, 0.0], [0.0, 2.2, 0.0], [0.0, 0.0, 0.2]],
),
],
)
def test_minimum_image_displacement(
*, dr: list[list[float]], cell: torch.Tensor, pbc: bool, expected: list[list[float]]
*,
dr: list[list[float]],
cell: torch.Tensor,
pbc: bool | torch.Tensor,
expected: list[list[float]],
) -> None:
"""Test minimum_image_displacement with various inputs.

Expand Down
Loading