diff --git a/examples/scripts/7_Others/7.3_Batched_neighbor_list.py b/examples/scripts/7_Others/7.3_Batched_neighbor_list.py index 724231e3..1393b908 100644 --- a/examples/scripts/7_Others/7.3_Batched_neighbor_list.py +++ b/examples/scripts/7_Others/7.3_Batched_neighbor_list.py @@ -18,8 +18,8 @@ cutoff = torch.tensor(4.0, dtype=pos.dtype) self_interaction = False -# Fix: Ensure pbc has the correct shape [n_systems, 3] -pbc_tensor = torch.tensor([[pbc] * 3] * len(atoms_list), dtype=torch.bool) +# Ensure pbc has the correct shape [n_systems, 3] +pbc_tensor = torch.tensor(pbc).repeat(state.n_systems, 1) mapping, mapping_system, shifts_idx = torch_nl_linked_cell( pos, cell, pbc_tensor, cutoff, system_idx, self_interaction diff --git a/examples/tutorials/diff_sim.py b/examples/tutorials/diff_sim.py index 3835c926..79036469 100644 --- a/examples/tutorials/diff_sim.py +++ b/examples/tutorials/diff_sim.py @@ -117,7 +117,7 @@ class BaseState: positions: torch.Tensor cell: torch.Tensor - pbc: bool + pbc: torch.Tensor species: torch.Tensor @@ -133,14 +133,18 @@ def __init__( device: torch.device | None = None, dtype: torch.dtype = torch.float32, *, # Force keyword-only arguments - pbc: bool = True, + pbc: torch.Tensor | bool = True, cutoff: float | None = None, ) -> None: """Initialize a soft sphere model for multi-component systems.""" super().__init__() self.device = device or torch.device("cpu") self.dtype = dtype - self.pbc = pbc + self.pbc = ( + pbc + if isinstance(pbc, torch.Tensor) + else torch.tensor([pbc] * 3, dtype=torch.bool) + ) # Store species list and determine number of unique species self.species = species @@ -369,7 +373,7 @@ def simulation( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # Create the simulation environment. box_size = box_size_at_packing_fraction(diameter, packing_fraction) - cell = torch.eye(2) * box_size + cell = torch.eye(3) * box_size # Create the energy function. sigma = species_sigma(diameter) model = SoftSphereMultiModel(sigma_matrix=sigma, species=species) @@ -377,12 +381,17 @@ def simulation( # Randomly initialize the system. # Fix seed for reproducible random positions torch.manual_seed(seed) - R = torch.rand(N, 2) * box_size + R = torch.rand(N, 3) * box_size # Minimize to the nearest minimum. init_fn, apply_fn = gradient_descent(model, lr=0.1) - custom_state = BaseState(positions=R, cell=cell, species=species, pbc=True) + custom_state = BaseState( + positions=R, + cell=cell, + species=species, + pbc=torch.tensor([True] * 3, dtype=torch.bool), + ) state = init_fn(custom_state) for _ in range(simulation_steps): state = apply_fn(state) @@ -415,7 +424,7 @@ def simulation( seeds = torch.arange(1, 6) box_size_tensor = torch.zeros(len(diameters), len(seeds)) raft_energy_tensor = torch.zeros(len(diameters), len(seeds)) -bubble_positions_tensor = torch.zeros(len(diameters), len(seeds), N, 2) +bubble_positions_tensor = torch.zeros(len(diameters), len(seeds), N, 3) for i, d in enumerate(diameters): for j, s in enumerate(seeds): box_size, raft_energy, bubble_positions = simulation(d, s) @@ -468,7 +477,7 @@ def short_simulation( ) -> tuple[torch.Tensor, torch.Tensor]: diameter = diameter.requires_grad_(True) box_size = box_size_at_packing_fraction(diameter, packing_fraction) - cell = torch.eye(2) * box_size + cell = torch.eye(3) * box_size # Create the energy function. sigma = species_sigma(diameter) model = SoftSphereMultiModel(sigma_matrix=sigma, species=species) diff --git a/pyproject.toml b/pyproject.toml index 850f35b6..996757d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,8 +32,8 @@ dependencies = [ "tables>=3.10.2", "torch>=2", "tqdm>=4.67", - "vesin-torch>=0.3.7, <0.4.0", - "vesin>=0.3.7, <0.4.0", + "vesin-torch>=0.4.0, <0.5.0", + "vesin>=0.4.0, <0.5.0", ] [project.optional-dependencies] @@ -48,7 +48,7 @@ test = [ io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2025.6.14"] mace = ["mace-torch>=0.3.12"] mattersim = ["mattersim>=0.1.2"] -metatomic = ["metatomic-torch>=0.1.1", "metatrain[pet]>=2025.7"] +metatomic = ["metatomic-torch>=0.1.3", "metatrain[pet]>=2025.7"] orb = ["orb-models>=0.5.2"] sevenn = ["sevenn>=0.11.0"] graphpes = ["graph-pes>=0.1", "mace-torch>=0.3.12"] diff --git a/tests/models/test_soft_sphere.py b/tests/models/test_soft_sphere.py index 99c2ed64..a07c8282 100644 --- a/tests/models/test_soft_sphere.py +++ b/tests/models/test_soft_sphere.py @@ -350,8 +350,8 @@ def test_multispecies_cutoff_default() -> None: @pytest.mark.parametrize( ("flag_name", "flag_value"), [ - ("pbc", True), - ("pbc", False), + ("pbc", torch.tensor([True, True, True])), + ("pbc", torch.tensor([False, False, False])), ("compute_forces", False), ("compute_stress", True), ("per_atom_energies", True), diff --git a/tests/test_io.py b/tests/test_io.py index 26e4a34c..a2c25ab4 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -46,7 +46,7 @@ def test_multiple_structures_to_state(si_structure: Structure) -> None: assert state.positions.shape == (16, 3) assert state.masses.shape == (16,) assert state.cell.shape == (2, 3, 3) - assert state.pbc + assert torch.all(state.pbc) assert state.atomic_numbers.shape == (16,) assert state.system_idx.shape == (16,) assert torch.all( @@ -64,7 +64,7 @@ def test_single_atoms_to_state(si_atoms: Atoms) -> None: assert state.positions.shape == (8, 3) assert state.masses.shape == (8,) assert state.cell.shape == (1, 3, 3) - assert state.pbc + assert torch.all(state.pbc) assert state.atomic_numbers.shape == (8,) assert state.system_idx.shape == (8,) assert torch.all(state.system_idx == 0) @@ -79,7 +79,7 @@ def test_multiple_atoms_to_state(si_atoms: Atoms) -> None: assert state.positions.shape == (16, 3) assert state.masses.shape == (16,) assert state.cell.shape == (2, 3, 3) - assert state.pbc + assert torch.all(state.pbc) assert state.atomic_numbers.shape == (16,) assert state.system_idx.shape == (16,) assert torch.all( @@ -171,7 +171,7 @@ def test_multiple_phonopy_to_state(si_phonopy_atoms: Any) -> None: assert state.positions.shape == (16, 3) assert state.masses.shape == (16,) assert state.cell.shape == (2, 3, 3) - assert state.pbc + assert torch.all(state.pbc) assert state.atomic_numbers.shape == (16,) assert state.system_idx.shape == (16,) assert torch.all( @@ -246,7 +246,7 @@ def test_state_round_trip( assert torch.allclose(sim_state.cell, round_trip_state.cell) assert torch.all(sim_state.atomic_numbers == round_trip_state.atomic_numbers) assert torch.all(sim_state.system_idx == round_trip_state.system_idx) - assert sim_state.pbc == round_trip_state.pbc + assert torch.equal(sim_state.pbc, round_trip_state.pbc) if isinstance(intermediate_format[0], Atoms): # TODO: masses round trip for pmg and phonopy masses is not exact diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index 8a1e0f7a..72cdbf13 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -170,7 +170,7 @@ def test_primitive_neighbor_list( pos = torch.tensor(atoms.positions, device=DEVICE, dtype=DTYPE) row_vector_cell = torch.tensor(atoms.cell.array, device=DEVICE, dtype=DTYPE) - pbc = atoms.pbc.any() + pbc = torch.tensor(atoms.pbc, device=DEVICE, dtype=DTYPE) # Get the neighbor list using the appropriate function (jitted or non-jitted) # Note: No self-interaction @@ -178,7 +178,7 @@ def test_primitive_neighbor_list( quantities="ijS", positions=pos, cell=row_vector_cell, - pbc=(pbc, pbc, pbc), + pbc=pbc, cutoff=torch.tensor(cutoff, dtype=DTYPE, device=DEVICE), device=DEVICE, dtype=DTYPE, @@ -258,7 +258,7 @@ def test_neighbor_list_implementations( # Convert to torch tensors pos = torch.tensor(atoms.positions, device=DEVICE, dtype=DTYPE) row_vector_cell = torch.tensor(atoms.cell.array, device=DEVICE, dtype=DTYPE) - pbc = atoms.pbc.any() + pbc = torch.tensor(atoms.pbc, device=DEVICE, dtype=DTYPE) # Get the neighbor list from the implementation being tested mapping, shifts = nl_implementation( @@ -371,7 +371,7 @@ def test_primitive_neighbor_list_edge_cases() -> None: quantities="ijS", positions=pos, cell=cell, - pbc=pbc, + pbc=torch.tensor(pbc, device=DEVICE, dtype=DTYPE), cutoff=cutoff, device=DEVICE, dtype=DTYPE, @@ -383,7 +383,7 @@ def test_primitive_neighbor_list_edge_cases() -> None: quantities="ijS", positions=pos, cell=cell, - pbc=(True, True, True), + pbc=torch.Tensor([True, True, True]), cutoff=cutoff, device=DEVICE, dtype=DTYPE, @@ -404,7 +404,7 @@ def test_standard_nl_edge_cases() -> None: mapping, _shifts = neighbors.standard_nl( positions=pos, cell=cell, - pbc=pbc, + pbc=torch.tensor([pbc] * 3, device=DEVICE, dtype=DTYPE), cutoff=cutoff, ) assert len(mapping[0]) > 0 # Should find neighbors @@ -413,7 +413,7 @@ def test_standard_nl_edge_cases() -> None: mapping, _shifts = neighbors.standard_nl( positions=pos, cell=cell, - pbc=True, + pbc=torch.Tensor([True, True, True]), cutoff=cutoff, sort_id=True, ) @@ -430,13 +430,20 @@ def test_vesin_nl_edge_cases() -> None: # Test both implementations for nl_fn in (neighbors.vesin_nl, neighbors.vesin_nl_ts): # Test different PBC combinations - for pbc in (True, False): + for pbc in ( + torch.Tensor([True, True, True]), + torch.Tensor([False, False, False]), + ): mapping, _shifts = nl_fn(positions=pos, cell=cell, pbc=pbc, cutoff=cutoff) assert len(mapping[0]) > 0 # Should find neighbors # Test sort_id mapping, _shifts = nl_fn( - positions=pos, cell=cell, pbc=True, cutoff=cutoff, sort_id=True + positions=pos, + cell=cell, + pbc=torch.Tensor([True, True, True]), + cutoff=cutoff, + sort_id=True, ) # Check if indices are sorted assert torch.all(mapping[0][1:] >= mapping[0][:-1]) @@ -446,7 +453,10 @@ def test_vesin_nl_edge_cases() -> None: pos_f32 = pos.to(dtype=torch.float32) cell_f32 = cell.to(dtype=torch.float32) mapping, _shifts = nl_fn( - positions=pos_f32, cell=cell_f32, pbc=True, cutoff=cutoff + positions=pos_f32, + cell=cell_f32, + pbc=torch.Tensor([True, True, True]), + cutoff=cutoff, ) assert len(mapping[0]) > 0 # Should find neighbors @@ -528,7 +538,12 @@ def test_neighbor_lists_time_and_memory() -> None: self_interaction=False, ) else: - _mapping, _shifts = nl_fn(positions=pos, cell=cell, pbc=True, cutoff=cutoff) + _mapping, _shifts = nl_fn( + positions=pos, + cell=cell, + pbc=torch.Tensor([True, True, True]), + cutoff=cutoff, + ) end_time = time.perf_counter() execution_time = end_time - start_time @@ -551,4 +566,10 @@ def test_neighbor_lists_time_and_memory() -> None: assert cpu_memory_used < 5e8, ( f"{fn_name} used too much CPU memory: {cpu_memory_used / 1e6:.2f}MB" ) - assert execution_time < 0.8, f"{fn_name} took too long: {execution_time}s" + if nl_fn == neighbors.standard_nl: + # this function is just quite slow. So we have a higher tolerance. + # I tried removing "@jit.script" and it was still slow. + # (This nl function is just slow) + assert execution_time < 3, f"{fn_name} took too long: {execution_time}s" + else: + assert execution_time < 0.8, f"{fn_name} took too long: {execution_time}s" diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py index 215877c6..7f7252cf 100644 --- a/tests/test_trajectory.py +++ b/tests/test_trajectory.py @@ -33,7 +33,7 @@ def random_state() -> MDState: cell=torch.unsqueeze(torch.eye(3) * 10.0, 0), atomic_numbers=torch.ones(10, dtype=torch.int32), system_idx=torch.zeros(10, dtype=torch.int32), - pbc=True, + pbc=[True, True, False], ) @@ -93,7 +93,7 @@ def test_write_state_single( assert trajectory.get_array("positions").shape == (1, 10, 3) assert trajectory.get_array("atomic_numbers").shape == (1, 10) assert trajectory.get_array("cell").shape == (1, 3, 3) - assert trajectory.get_array("pbc").shape == (1,) + assert trajectory.get_array("pbc").shape == (3,) def test_write_state_multiple( @@ -106,7 +106,7 @@ def test_write_state_multiple( assert trajectory.get_array("positions").shape == (2, 10, 3) assert trajectory.get_array("atomic_numbers").shape == (1, 10) assert trajectory.get_array("cell").shape == (2, 3, 3) - assert trajectory.get_array("pbc").shape == (1,) + assert trajectory.get_array("pbc").shape == (3,) def test_optional_arrays(trajectory: TorchSimTrajectory, random_state: MDState) -> None: @@ -439,7 +439,7 @@ def test_get_atoms(trajectory: TorchSimTrajectory, random_state: MDState) -> Non np.testing.assert_allclose( atoms.get_atomic_numbers(), random_state.atomic_numbers.numpy() ) - assert atoms.pbc.all() == random_state.pbc + np.testing.assert_array_equal(atoms.pbc, random_state.pbc.detach().cpu().numpy()) def test_get_state(trajectory: TorchSimTrajectory, random_state: MDState) -> None: @@ -473,12 +473,13 @@ def test_get_state(trajectory: TorchSimTrajectory, random_state: MDState) -> Non assert state.positions.dtype == expected_dtype assert state.cell.dtype == expected_dtype assert state.atomic_numbers.dtype == torch.int # Should always be int + assert state.pbc.dtype == torch.bool # Should always be bool # Test values (convert to CPU for comparison) np.testing.assert_allclose(state.positions, random_state.positions) np.testing.assert_allclose(state.cell, random_state.cell) np.testing.assert_allclose(state.atomic_numbers, random_state.atomic_numbers) - assert state.pbc == random_state.pbc + assert torch.equal(state.pbc, random_state.pbc) def test_write_ase_trajectory( @@ -509,7 +510,7 @@ def test_write_ase_trajectory( np.testing.assert_allclose( atoms.get_atomic_numbers(), random_state.atomic_numbers.numpy() ) - assert atoms.pbc.all() == random_state.pbc + np.testing.assert_array_equal(atoms.pbc, random_state.pbc.numpy()) # Clean up ase_traj.close() diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 067abad8..16565b73 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1,4 +1,6 @@ # ruff: noqa: PT011 +import itertools + import numpy as np import pytest import torch @@ -195,7 +197,7 @@ def test_pbc_wrap_general_batch() -> None: @pytest.mark.parametrize( - "pbc", [[True, True, True], [True, True, False], [False, False, False], True, False] + "pbc", [*list(itertools.product([False, True], repeat=3)), True, False] ) @pytest.mark.parametrize("pretty_translation", [True, False]) def test_wrap_positions_matches_ase( @@ -892,10 +894,20 @@ def test_get_fractional_coordinates_batched() -> None: True, [[0.2, 0.0, 0.0], [0.0, 0.2, 0.0], [0.0, 0.0, 0.2]], ), + ( + [[2.2, 0.0, 0.0], [0.0, 2.2, 0.0], [0.0, 0.0, 2.2]], + torch.eye(3, dtype=DTYPE) * 2.0, + torch.tensor([True, False, True], dtype=torch.bool), + [[0.2, 0.0, 0.0], [0.0, 2.2, 0.0], [0.0, 0.0, 0.2]], + ), ], ) def test_minimum_image_displacement( - *, dr: list[list[float]], cell: torch.Tensor, pbc: bool, expected: list[list[float]] + *, + dr: list[list[float]], + cell: torch.Tensor, + pbc: bool | torch.Tensor, + expected: list[list[float]], ) -> None: """Test minimum_image_displacement with various inputs. diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index 92de1ccc..b03fb7a3 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -175,10 +175,13 @@ def position_step[T: MDState](state: T, dt: float | torch.Tensor) -> T: """ new_positions = state.positions + state.velocities * dt - if state.pbc: + if state.pbc.any(): # Split positions and cells by system new_positions = transforms.pbc_wrap_batched( - new_positions, state.cell, state.system_idx + new_positions, + state.cell, + state.system_idx, + state.pbc, ) state.positions = new_positions diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 4984baf4..db6e2b15 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -367,9 +367,9 @@ def _npt_langevin_position_step( state.positions = c_1 + c_2.unsqueeze(-1) * c_3 # Apply periodic boundary conditions if needed - if state.pbc: + if state.pbc.any(): state.positions = ts.transforms.pbc_wrap_batched( - state.positions, state.cell, state.system_idx + state.positions, state.cell, state.system_idx, state.pbc ) return state @@ -1030,9 +1030,9 @@ def _npt_nose_hoover_exp_iL1( # noqa: N802 new_positions = state.positions + new_positions # Apply periodic boundary conditions if needed - if state.pbc: + if state.pbc.any(): return ts.transforms.pbc_wrap_batched( - new_positions, state.current_cell, state.system_idx + new_positions, state.current_cell, state.system_idx, pbc=state.pbc ) return new_positions diff --git a/torch_sim/io.py b/torch_sim/io.py index a2081c20..253b29ed 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -117,7 +117,7 @@ def state_to_structures(state: "ts.SimState") -> list["Structure"]: # Create structure for this system struct = Structure( - lattice=Lattice(system_cell), + lattice=Lattice(system_cell, pbc=(state.pbc.tolist())), species=species, coords=system_positions, coords_are_cartesian=True, @@ -164,8 +164,11 @@ def state_to_phonopy(state: "ts.SimState") -> list["PhonopyAtoms"]: # Convert atomic numbers to chemical symbols symbols = [chemical_symbols[z] for z in system_numbers] + + # Note: pbc is not used in the init since it's always assumed to be true + # https://github.com/phonopy/phonopy/blob/develop/phonopy/structure/atoms.py#L140 phonopy_atoms = PhonopyAtoms( - symbols=symbols, positions=system_positions, cell=system_cell, pbc=state.pbc + symbols=symbols, positions=system_positions, cell=system_cell ) phonopy_atoms_list.append(phonopy_atoms) @@ -225,14 +228,14 @@ def atoms_to_state( ) # Verify consistent pbc - if not all(all(at.pbc) == all(atoms_list[0].pbc) for at in atoms_list): + if not all(np.all(np.equal(at.pbc, atoms_list[0].pbc)) for at in atoms_list[1:]): raise ValueError("All systems must have the same periodic boundary conditions") return ts.SimState( positions=positions, masses=masses, cell=cell, - pbc=all(atoms_list[0].pbc), + pbc=atoms_list[0].pbc, atomic_numbers=atomic_numbers, system_idx=system_idx, ) @@ -294,11 +297,15 @@ def structures_to_state( torch.arange(len(struct_list), device=device), atoms_per_system ) + # Verify consistent pbc + if not all(tuple(s.pbc) == tuple(struct_list[0].pbc) for s in struct_list[1:]): + raise ValueError("All systems must have the same periodic boundary conditions") + return ts.SimState( positions=positions, masses=masses, cell=cell, - pbc=True, # Structures are always periodic + pbc=struct_list[0].pbc, atomic_numbers=atomic_numbers, system_idx=system_idx, ) diff --git a/torch_sim/models/metatomic.py b/torch_sim/models/metatomic.py index 1e97cdda..2110f8c6 100644 --- a/torch_sim/models/metatomic.py +++ b/torch_sim/models/metatomic.py @@ -173,7 +173,6 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: # atomic_nums = sim_state.atomic_numbers cell = sim_state.row_vector_cell positions = sim_state.positions - pbc = sim_state.pbc # Check dtype (metatomic models require a specific input dtype) if positions.dtype != self._dtype: @@ -196,9 +195,6 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: # system_mask = sim_state.system_idx == sys_idx system_positions = positions[system_mask] system_cell = cell[sys_idx] - system_pbc = torch.tensor( - [pbc, pbc, pbc], device=self._device, dtype=torch.bool - ) system_atomic_numbers = atomic_nums[system_mask] # Create a System object for this system @@ -217,7 +213,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: # positions=system_positions, types=system_atomic_numbers, cell=system_cell, - pbc=system_pbc, + pbc=sim_state.pbc, ) ) diff --git a/torch_sim/models/nequip_framework.py b/torch_sim/models/nequip_framework.py index 916281a5..89f1ce56 100644 --- a/torch_sim/models/nequip_framework.py +++ b/torch_sim/models/nequip_framework.py @@ -345,11 +345,7 @@ def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: # "cell": sim_state.row_vector_cell, "batch": sim_state.system_idx, "num_atoms": sim_state.system_idx.bincount(), - "pbc": torch.tensor( - [sim_state.pbc, sim_state.pbc, sim_state.pbc], - dtype=torch.bool, - device=self.device, - ), + "pbc": sim_state.pbc, "atomic_numbers": sim_state.atomic_numbers, "atom_types": atomic_types, "edge_index": edge_index, diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index b998c7a5..f667b5c8 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -147,9 +147,6 @@ def state_to_atom_graphs( # noqa: PLR0915 ) # Orb uses row vector cell convention for neighbor list atomic_numbers = state.atomic_numbers.long() - # Create PBC tensor based on state.pbc - pbc = torch.tensor([state.pbc, state.pbc, state.pbc], dtype=torch.bool) - max_num_neighbors = max_num_neighbors or system_config.max_num_neighbors # Get atom embeddings for the model @@ -168,7 +165,7 @@ def state_to_atom_graphs( # noqa: PLR0915 atomic_numbers_embedding = atom_type_embedding.to(output_dtype) # Wrap positions into the central cell if needed - if wrap and (torch.any(row_vector_cell != 0) and torch.any(pbc)): + if wrap and (torch.any(row_vector_cell != 0) and torch.any(state.pbc)): positions = feat_util.batch_map_to_pbc_cell(positions, row_vector_cell, n_node) n_systems = state.system_idx.max().item() + 1 @@ -190,13 +187,13 @@ def state_to_atom_graphs( # noqa: PLR0915 atomic_numbers_per_system = atomic_numbers[system_mask] atomic_numbers_embedding_per_system = atomic_numbers_embedding[system_mask] cell_per_system = row_vector_cell[sys_idx] - pbc_per_system = pbc + pbc = state.pbc # Compute edges directly for this system edges, vectors, unit_shifts = feat_util.compute_pbc_radius_graph( positions=positions_per_system, cell=cell_per_system, - pbc=pbc_per_system, + pbc=pbc, radius=system_config.radius, max_number_neighbors=max_num_neighbors, edge_method=edge_method, @@ -230,7 +227,7 @@ def state_to_atom_graphs( # noqa: PLR0915 graph_feats = { "cell": cell_per_system, - "pbc": pbc_per_system, + "pbc": pbc, "lattice": lattice_per_system.to(device=positions_per_system.device), } diff --git a/torch_sim/models/soft_sphere.py b/torch_sim/models/soft_sphere.py index a9a1b97a..c641eebb 100644 --- a/torch_sim/models/soft_sphere.py +++ b/torch_sim/models/soft_sphere.py @@ -509,7 +509,7 @@ def __init__( device: torch.device | None = None, dtype: torch.dtype = torch.float64, *, # Force keyword-only arguments - pbc: bool = True, + pbc: torch.Tensor | bool = True, compute_forces: bool = True, compute_stress: bool = False, per_atom_energies: bool = False, @@ -538,8 +538,9 @@ def __init__( device (torch.device | None): Device for computations. If None, uses CPU. Defaults to None. dtype (torch.dtype): Data type for calculations. Defaults to torch.float32. - pbc (bool): Whether to use periodic boundary conditions. Defaults to - True. + pbc (torch.Tensor | bool): Boolean tensor of shape (3,) indicating periodic + boundary conditions in each axis. If None, all axes are assumed to be + periodic. Defaults to True. compute_forces (bool): Whether to compute forces. Defaults to True. compute_stress (bool): Whether to compute stress tensor. Defaults to False. per_atom_energies (bool): Whether to compute per-atom energy decomposition. @@ -597,7 +598,7 @@ def __init__( super().__init__() self._device = device or torch.device("cpu") self._dtype = dtype - self.pbc = pbc + self.pbc = torch.tensor([pbc] * 3) if isinstance(pbc, bool) else pbc self._compute_forces = compute_forces self._compute_stress = compute_stress self.per_atom_energies = per_atom_energies @@ -714,7 +715,7 @@ def unbatched_forward( # noqa: PLR0915 cell=cell, pbc=self.pbc, cutoff=self.cutoff, - sorti=False, + sort_id=False, ) # Get displacements between neighbor pairs dr_vec, distances = transforms.get_pair_displacements( diff --git a/torch_sim/neighbors.py b/torch_sim/neighbors.py index 491d72cb..5562a38a 100644 --- a/torch_sim/neighbors.py +++ b/torch_sim/neighbors.py @@ -11,7 +11,7 @@ @torch.jit.script def primitive_neighbor_list( # noqa: C901, PLR0915 quantities: str, - pbc: tuple[bool, bool, bool], + pbc: torch.Tensor, cell: torch.Tensor, positions: torch.Tensor, cutoff: torch.Tensor, @@ -42,8 +42,8 @@ def primitive_neighbor_list( # noqa: C901, PLR0915 between atom i and j). With the shift vector S, the distances D between atoms can be computed from: D = positions[j]-positions[i]+S.dot(cell) - pbc: 3-tuple indicating giving periodic boundaries in the three Cartesian - directions. + pbc: Boolean tensor of shape (3,) indicating periodic boundary conditions in + each axis. cell: Unit cell vectors according to the row vector convention, i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. positions: Atomic positions. Anything that can be converted to an ndarray of @@ -411,7 +411,7 @@ def primitive_neighbor_list( # noqa: C901, PLR0915 def standard_nl( positions: torch.Tensor, cell: torch.Tensor, - pbc: bool, # noqa: FBT001 + pbc: torch.Tensor, cutoff: torch.Tensor, sort_id: bool = False, # noqa: FBT001, FBT002 ) -> tuple[torch.Tensor, torch.Tensor]: @@ -437,7 +437,8 @@ def standard_nl( positions: Atomic positions tensor of shape (num_atoms, 3) cell: Unit cell vectors according to the row vector convention, i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. - pbc: Whether to use periodic boundary conditions (applied to all directions) + pbc: Boolean tensor of shape (3,) indicating periodic boundary conditions in + each axis. cutoff: Maximum distance for considering atoms as neighbors sort_id: If True, sort neighbors by first atom index for better memory access patterns @@ -462,7 +463,7 @@ def standard_nl( Notes: - The function uses primitive_neighbor_list internally but provides a simpler interface - - For non-periodic systems (pbc=False), shifts will be zero vectors + - For non-periodic systems, shifts will be zero vectors - The neighbor list includes both (i,j) and (j,i) pairs for complete force computation - Memory usage scales with system size and number of neighbors per atom @@ -476,7 +477,7 @@ def standard_nl( quantities="ijS", positions=positions, cell=cell, - pbc=(pbc, pbc, pbc), + pbc=pbc, cutoff=cutoff, device=device, dtype=dtype, @@ -501,7 +502,7 @@ def standard_nl( def vesin_nl_ts( positions: torch.Tensor, cell: torch.Tensor, - pbc: bool, # noqa: FBT001 + pbc: torch.Tensor, cutoff: torch.Tensor, sort_id: bool = False, # noqa: FBT001, FBT002 ) -> tuple[torch.Tensor, torch.Tensor]: @@ -515,7 +516,8 @@ def vesin_nl_ts( positions: Atomic positions tensor of shape (num_atoms, 3) cell: Unit cell vectors according to the row vector convention, i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. - pbc: Whether to use periodic boundary conditions (applied to all directions) + pbc: Boolean tensor of shape (3,) indicating periodic boundary conditions in + each axis. cutoff: Maximum distance (scalar tensor) for considering atoms as neighbors sort_id: If True, sort neighbors by first atom index for better memory access patterns @@ -533,7 +535,7 @@ def vesin_nl_ts( - Uses VesinNeighborListTorch for TorchScript compatibility - Requires CPU tensors in float64 precision internally - Returns tensors on the same device as input with original precision - - For non-periodic systems (pbc=False), shifts will be zero vectors + - For non-periodic systems, shifts will be zero vectors - The neighbor list includes both (i,j) and (j,i) pairs References: @@ -547,12 +549,13 @@ def vesin_nl_ts( # Convert tensors to CPU and float64 properly positions_cpu = positions.cpu().to(dtype=torch.float64) cell_cpu = cell.cpu().to(dtype=torch.float64) + periodic_cpu = pbc.to(dtype=torch.bool).cpu() # Only works on CPU and requires float64 i, j, S = neighbor_list_fn.compute( points=positions_cpu, box=cell_cpu, - periodic=pbc, + periodic=periodic_cpu, quantities="ijS", ) @@ -571,7 +574,7 @@ def vesin_nl_ts( def vesin_nl( positions: torch.Tensor, cell: torch.Tensor, - pbc: bool, # noqa: FBT001 + pbc: torch.Tensor, cutoff: float | torch.Tensor, sort_id: bool = False, # noqa: FBT001, FBT002 ) -> tuple[torch.Tensor, torch.Tensor]: @@ -585,7 +588,8 @@ def vesin_nl( positions: Atomic positions tensor of shape (num_atoms, 3) cell: Unit cell vectors according to the row vector convention, i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. - pbc: Whether to use periodic boundary conditions (applied to all directions) + pbc: Boolean tensor of shape (3,) indicating periodic boundary conditions in + each axis. cutoff: Maximum distance (scalar tensor) for considering atoms as neighbors sort_id: If True, sort neighbors by first atom index for better memory access patterns @@ -618,12 +622,13 @@ def vesin_nl( # Convert tensors to CPU and float64 without gradients positions_cpu = positions.detach().cpu().to(dtype=torch.float64) cell_cpu = cell.detach().cpu().to(dtype=torch.float64) + periodic_cpu = pbc.detach().to(dtype=torch.bool).cpu() # Only works on CPU and returns numpy arrays i, j, S = neighbor_list_fn.compute( points=positions_cpu, box=cell_cpu, - periodic=pbc, + periodic=periodic_cpu, quantities="ijS", ) i, j = ( @@ -778,9 +783,9 @@ def torch_nl_linked_cell( positions (torch.Tensor [n_atom, 3]): A tensor containing the positions of atoms wrapped inside their respective unit cells. - cell (torch.Tensor [3*n_structure, 3]): Unit cell vectors according to + cell (torch.Tensor [3*n_systems, 3]): Unit cell vectors according to the row vector convention, i.e. `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]`. - pbc (torch.Tensor [n_structure, 3] bool): + pbc (torch.Tensor [n_systems, 3] bool): A tensor indicating the periodic boundary conditions to apply. Partial PBC are not supported yet. system_idx (torch.Tensor [n_atom,] torch.long): diff --git a/torch_sim/state.py b/torch_sim/state.py index ccaffb27..c59c2efd 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -47,7 +47,9 @@ class SimState: stored as `[[a1, b1, c1], [a2, b2, c2], [a3, b3, c3]]` as opposed to the row vector convention `[[a1, a2, a3], [b1, b2, b3], [c1, c2, c3]]` used by ASE. - pbc (bool): Boolean indicating whether to use periodic boundary conditions + pbc (bool | list[bool] | torch.Tensor): indicates periodic boundary + conditions in each axis. If a boolean is provided, all axes are + assumed to have the same periodic boundary conditions. atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,) system_idx (torch.Tensor): Maps each atom index to its system index. Has shape (n_atoms,), must be unique consecutive integers starting from 0. @@ -80,7 +82,7 @@ class SimState: positions: torch.Tensor masses: torch.Tensor cell: torch.Tensor - pbc: bool # TODO: do all calculators support mixed pbc? + pbc: torch.Tensor | list[bool] | bool atomic_numbers: torch.Tensor system_idx: torch.Tensor | None = field(default=None) @@ -91,6 +93,11 @@ def system_idx(self) -> torch.Tensor: """A getter for system_idx that tells type checkers it's always defined.""" return self.system_idx + @property + def pbc(self) -> torch.Tensor: + """A getter for pbc that tells type checkers it's always defined.""" + return self.pbc + _atom_attributes: ClassVar[set[str]] = { "positions", "masses", @@ -102,17 +109,6 @@ def system_idx(self) -> torch.Tensor: def __post_init__(self) -> None: """Initialize the SimState and validate the arguments.""" - # Validate and process the state after initialization. - # data validation and fill system_idx - # should make pbc a tensor here - # if devices aren't all the same, raise an error, in a clean way - devices = { - attr: getattr(self, attr).device - for attr in ("positions", "masses", "cell", "atomic_numbers") - } - if len(set(devices.values())) > 1: - raise ValueError("All tensors must be on the same device") - # Check that positions, masses and atomic numbers have compatible shapes shapes = [ getattr(self, attr).shape[0] @@ -125,6 +121,11 @@ def __post_init__(self) -> None: f"masses {shapes[1]}, atomic_numbers {shapes[2]}" ) + if isinstance(self.pbc, bool): + self.pbc = [self.pbc] * 3 + if not isinstance(self.pbc, torch.Tensor): + self.pbc = torch.tensor(self.pbc, dtype=torch.bool, device=self.device) + initial_system_idx = self.system_idx if initial_system_idx is None: self.system_idx = torch.zeros( @@ -146,6 +147,21 @@ def __post_init__(self) -> None: f"Cell must have shape (n_systems, 3, 3), got {self.cell.shape}" ) + # if devices aren't all the same, raise an error, in a clean way + devices = { + attr: getattr(self, attr).device + for attr in ( + "positions", + "masses", + "cell", + "atomic_numbers", + "pbc", + "system_idx", + ) + } + if len(set(devices.values())) > 1: + raise ValueError("All tensors must be on the same device") + @property def wrap_positions(self) -> torch.Tensor: """Atomic positions wrapped according to periodic boundary conditions if pbc=True, diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index 3220b24b..ab841a85 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -512,6 +512,20 @@ def write_arrays( self.flush() + def write_global_array(self, name: str, array: np.ndarray | torch.Tensor) -> None: + """Write a global array to the trajectory file. + + This function is used to write a global array to the trajectory file. + """ + if isinstance(array, torch.Tensor): + array = array.cpu().detach().numpy() + + steps = [0] + if name not in self.array_registry: + self._initialize_array(name, array) + self._validate_array(name, array, steps) + self._serialize_array(name, array, steps) + def _initialize_array(self, name: str, array: np.ndarray) -> None: """Initialize a single array and add it to the registry. @@ -736,6 +750,7 @@ def write_state( # noqa: C901 if len(sub_states) != len(steps): raise ValueError(f"{len(sub_states)=} must match the {len(steps)=}") + # Initialize data dictionary with required arrays data = { "positions": torch.stack([s.positions for s in state]), @@ -776,7 +791,7 @@ def write_state( # noqa: C901 self.write_arrays({"atomic_numbers": state[0].atomic_numbers}, 0) if "pbc" not in self.array_registry: - self.write_arrays({"pbc": np.array(state[0].pbc)}, 0) + self.write_global_array("pbc", state[0].pbc) # Write all arrays to file self.write_arrays(data, steps) @@ -818,6 +833,8 @@ def _get_state_arrays(self, frame: int) -> dict[str, np.ndarray]: arrays["positions"] = self.get_array("positions", start=frame, stop=frame + 1)[0] def return_prop(self: Self, prop: str, frame: int) -> np.ndarray: + if prop == "pbc": + return self.get_array(prop, start=0, stop=3) if getattr(self._file.root.data, prop).shape[0] > 1: # Variable prop start, stop = frame, frame + 1 else: # Static prop @@ -887,13 +904,11 @@ def get_atoms(self, frame: int = -1) -> "Atoms": arrays = self._get_state_arrays(frame) - pbc = arrays.get("pbc", True) - return Atoms( numbers=np.ascontiguousarray(arrays["atomic_numbers"]), positions=np.ascontiguousarray(arrays["positions"]), cell=np.ascontiguousarray(arrays["cell"])[0], - pbc=pbc, + pbc=np.ascontiguousarray(arrays["pbc"]), ) def get_state( @@ -925,7 +940,7 @@ def get_state( positions=torch.tensor(arrays["positions"], device=device, dtype=dtype), masses=torch.tensor(arrays.get("masses", None), device=device, dtype=dtype), cell=torch.tensor(arrays["cell"], device=device, dtype=dtype), - pbc=bool(arrays.get("pbc", True)), + pbc=torch.tensor(arrays["pbc"], device=device, dtype=torch.bool), atomic_numbers=torch.tensor( arrays["atomic_numbers"], device=device, dtype=torch.int ), diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index 53570a66..2ab4ab2e 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -156,7 +156,10 @@ def pbc_wrap_general( def pbc_wrap_batched( - positions: torch.Tensor, cell: torch.Tensor, system_idx: torch.Tensor + positions: torch.Tensor, + cell: torch.Tensor, + system_idx: torch.Tensor, + pbc: torch.Tensor | bool = True, # noqa: FBT002 ) -> torch.Tensor: """Apply periodic boundary conditions to batched systems. @@ -171,10 +174,16 @@ def pbc_wrap_batched( lattice vectors as column vectors. system_idx (torch.Tensor): Tensor of shape (n_atoms,) containing system indices for each atom. + pbc (torch.Tensor | bool): Tensor of shape (3,) containing boolean values + indicating whether periodic boundary conditions are applied in each dimension. + Can also be a bool. Defaults to True. Returns: torch.Tensor: Wrapped positions in real space with same shape as input positions. """ + if isinstance(pbc, bool): + pbc = torch.tensor([pbc, pbc, pbc], dtype=torch.bool, device=positions.device) + # Validate inputs if not torch.is_floating_point(positions) or not torch.is_floating_point(cell): raise TypeError("Positions and lattice vectors must be floating point tensors.") @@ -202,7 +211,8 @@ def pbc_wrap_batched( frac_coords = torch.bmm(B_per_atom, positions.unsqueeze(2)).squeeze(2) # Wrap to reference cell [0,1) using modulo - wrapped_frac = frac_coords % 1.0 + wrapped_frac = frac_coords.clone() + wrapped_frac[:, pbc] = frac_coords[:, pbc] % 1.0 # Transform back to real space: r = A·f # Get the cell for each atom based on its system index @@ -216,19 +226,22 @@ def minimum_image_displacement( *, dr: torch.Tensor, cell: torch.Tensor | None = None, - pbc: bool = True, + pbc: torch.Tensor | bool = True, ) -> torch.Tensor: """Apply minimum image convention to displacement vectors. Args: dr (torch.Tensor): Displacement vectors [N, 3] or [N, N, 3]. cell (Optional[torch.Tensor]): Unit cell matrix [3, 3]. - pbc (bool): Whether to apply periodic boundary conditions. + pbc (Optional[torch.Tensor]): Boolean tensor of shape (3,) indicating + periodic boundary conditions in each dimension. Returns: torch.Tensor: Minimum image displacement vectors with same shape as input. """ - if cell is None or not pbc: + if isinstance(pbc, bool): + pbc = torch.tensor([pbc] * 3, dtype=torch.bool, device=dr.device) + if cell is None or not pbc.any(): return dr # Convert to fractional coordinates @@ -236,7 +249,7 @@ def minimum_image_displacement( dr_frac = torch.einsum("ij,...j->...i", cell_inv, dr) # Apply minimum image convention - dr_frac -= torch.round(dr_frac) + dr_frac -= torch.where(pbc, torch.round(dr_frac), torch.zeros_like(dr_frac)) # Convert back to cartesian return torch.einsum("ij,...j->...i", cell, dr_frac) @@ -246,7 +259,7 @@ def get_pair_displacements( *, positions: torch.Tensor, cell: torch.Tensor | None = None, - pbc: bool = True, + pbc: torch.Tensor | bool = True, pairs: tuple[torch.Tensor, torch.Tensor] | None = None, shifts: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -255,7 +268,8 @@ def get_pair_displacements( Args: positions (torch.Tensor): Atomic positions [N, 3]. cell (Optional[torch.Tensor]): Unit cell matrix [3, 3]. - pbc (bool): Whether to apply periodic boundary conditions. + pbc (Optional[torch.Tensor]): Boolean tensor of shape (3,) indicating + periodic boundary conditions in each dimension. pairs (Optional[Tuple[torch.Tensor, torch.Tensor]]): (i, j) indices for specific pairs to compute. shifts (Optional[torch.Tensor]): Shift vectors for periodic images [n_pairs, 3]. @@ -265,13 +279,15 @@ def get_pair_displacements( - Displacement vectors [n_pairs, 3]. - Distances [n_pairs]. """ + if isinstance(pbc, bool): + pbc = torch.tensor([pbc] * 3, dtype=torch.bool, device=positions.device) if pairs is None: # Create full distance matrix ri = positions.unsqueeze(0) # [1, N, 3] rj = positions.unsqueeze(1) # [N, 1, 3] dr = rj - ri # [N, N, 3] - if cell is not None and pbc: + if cell is not None and pbc.any(): dr = minimum_image_displacement(dr=dr, cell=cell, pbc=pbc) # Calculate distances @@ -287,7 +303,7 @@ def get_pair_displacements( i, j = pairs dr = positions[j] - positions[i] # [n_pairs, 3] - if cell is not None and pbc: + if cell is not None and pbc.any(): if shifts is not None: # Apply provided shifts dr = dr + torch.einsum("ij,kj->ki", cell, shifts)