diff --git a/tests/test_integrators.py b/tests/test_integrators.py index d7044f62..2bfd3669 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -2,7 +2,13 @@ import torch -from torch_sim.integrators import calculate_momenta, npt_langevin, nve, nvt_langevin +from torch_sim.integrators import ( + NPTLangevinState, + calculate_momenta, + npt_langevin, + nve, + nvt_langevin, +) from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.quantities import calc_kT from torch_sim.state import SimState, concatenate_states @@ -346,3 +352,39 @@ def test_compare_single_vs_batched_integrators( assert torch.allclose(ar_single_state.forces, final_state.forces) assert torch.allclose(ar_single_state.masses, final_state.masses) assert torch.allclose(ar_single_state.cell, final_state.cell) + + +def test_compute_cell_force_atoms_per_batch(): + """Test that compute_cell_force correctly scales by number of atoms per batch. + + Covers fix in https://github.com/Radical-AI/torch-sim/pull/153.""" + from torch_sim.integrators import _compute_cell_force + + # Setup minimal state with two batches having 8:1 atom ratio + s1, s2 = torch.zeros(8, dtype=torch.long), torch.ones(64, dtype=torch.long) + + state = NPTLangevinState( + positions=torch.zeros((72, 3)), + velocities=torch.zeros((72, 3)), + energy=torch.zeros(2), + forces=torch.zeros((72, 3)), + masses=torch.ones(72), + cell=torch.eye(3).repeat(2, 1, 1), + pbc=True, + batch=torch.cat([s1, s2]), + atomic_numbers=torch.ones(72, dtype=torch.long), + stress=torch.zeros((2, 3, 3)), + reference_cell=torch.eye(3).repeat(2, 1, 1), + cell_positions=torch.ones((2, 3, 3)), + cell_velocities=torch.zeros((2, 3, 3)), + cell_masses=torch.ones(2), + ) + + # Get forces and compare ratio + cell_force = _compute_cell_force(state, torch.tensor(0.0), torch.tensor([1.0, 1.0])) + force_ratio = ( + torch.diagonal(cell_force[1]).mean() / torch.diagonal(cell_force[0]).mean() + ) + + # Force ratio should match atom ratio (8:1) with the fix + assert abs(force_ratio - 8.0) / 8.0 < 0.1 diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index a70dfd5d..4c9531ed 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -579,4 +579,4 @@ def test_neighbor_lists_time_and_memory( assert cpu_memory_used < 5e8, ( f"{fn_name} used too much CPU memory: {cpu_memory_used / 1e6:.2f}MB" ) - assert execution_time < 0.2, f"{fn_name} took too long: {execution_time}s" + assert execution_time < 0.8, f"{fn_name} took too long: {execution_time}s" diff --git a/tests/test_state.py b/tests/test_state.py index d09ac075..26fec0ea 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -153,6 +153,15 @@ def test_concatenate_si_and_fe_states( ) assert torch.all(concatenated.batch == expected_batch) + # check n_atoms_per_batch + assert torch.all( + concatenated.n_atoms_per_batch + == torch.tensor( + [si_sim_state.n_atoms, fe_supercell_sim_state.n_atoms], + device=concatenated.device, + ) + ) + # Check that positions match for each original state assert torch.allclose(concatenated.positions[:si_atoms], si_sim_state.positions) assert torch.allclose( diff --git a/torch_sim/integrators.py b/torch_sim/integrators.py index 451e5549..8bca00aa 100644 --- a/torch_sim/integrators.py +++ b/torch_sim/integrators.py @@ -61,10 +61,8 @@ class MDState(SimState): @property def velocities(self) -> torch.Tensor: - """Calculate velocities from momenta and masses. - - Returns: - torch.Tensor: Velocities with shape [n_particles, n_dimensions] + """Velocities calculated from momenta and masses with shape + [n_particles, n_dimensions]. """ return self.momenta / self.masses.unsqueeze(-1) @@ -563,6 +561,63 @@ def momenta(self) -> torch.Tensor: return self.velocities * self.masses.unsqueeze(-1) +# Extracted out from npt_langevin body to test fix in https://github.com/Radical-AI/torch-sim/pull/153 +def _compute_cell_force( + state: NPTLangevinState, + external_pressure: torch.Tensor, + kT: torch.Tensor, +) -> torch.Tensor: + """Compute forces on the cell for NPT dynamics. + + This function calculates the forces acting on the simulation cell + based on the difference between internal stress and external pressure, + plus a kinetic contribution. These forces drive the volume changes + needed to maintain constant pressure. + + Args: + state (NPTLangevinState): Current NPT state + external_pressure (torch.Tensor): Target external pressure, either scalar or + tensor with shape [n_batches, n_dimensions, n_dimensions] + kT (torch.Tensor): Temperature in energy units, either scalar or + shape [n_batches] + + Returns: + torch.Tensor: Force acting on the cell [n_batches, n_dim, n_dim] + """ + # Get current volumes for each batch + volumes = torch.linalg.det(state.cell) # shape: (n_batches,) + + # Reshape for broadcasting + volumes = volumes.view(-1, 1, 1) # shape: (n_batches, 1, 1) + + # Create pressure tensor (diagonal with external pressure) + if external_pressure.ndim == 0: + # Scalar pressure - create diagonal pressure tensors for each batch + pressure_tensor = external_pressure * torch.eye( + 3, device=state.device, dtype=state.dtype + ) + pressure_tensor = pressure_tensor.unsqueeze(0).expand(state.n_batches, -1, -1) + else: + # Already a tensor with shape compatible with n_batches + pressure_tensor = external_pressure + + # Calculate virials from stress and external pressure + # Internal stress is negative of virial tensor divided by volume + virial = -volumes * state.stress + pressure_tensor * volumes + + # Add kinetic contribution (kT * Identity) + batch_kT = kT + if kT.ndim == 0: + batch_kT = kT.expand(state.n_batches) + + e_kin_per_atom = batch_kT.view(-1, 1, 1) * torch.eye( + 3, device=state.device, dtype=state.dtype + ).unsqueeze(0) + + # Correct implementation with scaling by n_atoms_per_batch + return virial + e_kin_per_atom * state.n_atoms_per_batch.view(-1, 1, 1) + + def npt_langevin( # noqa: C901, PLR0915 model: torch.nn.Module, *, @@ -748,35 +803,7 @@ def compute_cell_force( Returns: torch.Tensor: Force acting on the cell [n_batches, n_dim, n_dim] """ - # Get current volumes for each batch - volumes = torch.linalg.det(state.cell) # shape: (n_batches,) - - # Reshape for broadcasting - volumes = volumes.view(-1, 1, 1) # shape: (n_batches, 1, 1) - - # Create pressure tensor (diagonal with external pressure) - if external_pressure.ndim == 0: - # Scalar pressure - create diagonal pressure tensors for each batch - pressure_tensor = external_pressure * torch.eye(3, device=device, dtype=dtype) - pressure_tensor = pressure_tensor.unsqueeze(0).expand(state.n_batches, -1, -1) - else: - # Already a tensor with shape compatible with n_batches - pressure_tensor = external_pressure - - # Calculate virials from stress and external pressure - # Internal stress is negative of virial tensor divided by volume - virial = -volumes * state.stress + pressure_tensor * volumes - - # Add kinetic contribution (kT * Identity) - batch_kT = kT - if kT.ndim == 0: - batch_kT = kT.expand(state.n_batches) - - kinetic_term = batch_kT.view(-1, 1, 1) * torch.eye( - 3, device=device, dtype=dtype - ).unsqueeze(0) - - return virial + kinetic_term + return _compute_cell_force(state, external_pressure, kT) def cell_position_step( state: NPTLangevinState, diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 4c4c8f29..2bf7c619 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -104,11 +104,7 @@ def __init__( @property def device(self) -> torch.device: - """The device of the model. - - Returns: - The device of the model - """ + """The device of the model.""" return self._device @device.setter @@ -120,11 +116,7 @@ def device(self, device: torch.device) -> None: @property def dtype(self) -> torch.dtype: - """The data type of the model. - - Returns: - The data type of the model - """ + """The data type of the model.""" return self._dtype @dtype.setter @@ -136,11 +128,7 @@ def dtype(self, dtype: torch.dtype) -> None: @property def compute_stress(self) -> bool: - """Whether the model computes stresses. - - Returns: - Whether the model computes stresses - """ + """Whether the model computes stresses.""" return self._compute_stress @compute_stress.setter @@ -152,11 +140,7 @@ def compute_stress(self, compute_stress: bool) -> None: @property def compute_forces(self) -> bool: - """Whether the model computes forces. - - Returns: - Whether the model computes forces - """ + """Whether the model computes forces.""" return self._compute_forces @compute_forces.setter @@ -173,9 +157,6 @@ def memory_scales_with(self) -> Literal["n_atoms", "n_atoms_x_density"]: Models with radial neighbor cutoffs scale with "n_atoms_x_density", while models with a fixed number of neighbors scale with "n_atoms". Default is "n_atoms_x_density" because most models are radial cutoff based. - - Returns: - The metric that the model scales with """ return getattr(self, "_memory_scales_with", "n_atoms_x_density") diff --git a/torch_sim/properties/correlations.py b/torch_sim/properties/correlations.py index 667125cf..6cfbe579 100644 --- a/torch_sim/properties/correlations.py +++ b/torch_sim/properties/correlations.py @@ -92,11 +92,7 @@ def get_array(self) -> torch.Tensor: @property def is_full(self) -> bool: - """Check if the buffer is full. - - Returns: - True if buffer contains size elements, False otherwise - """ + """Check if the buffer is full.""" return self.count == self.size @@ -475,9 +471,5 @@ def __call__(self, state: SimState, _: Any = None) -> torch.Tensor: @property def vacf(self) -> torch.Tensor | None: - """Get current VACF result. - - Returns: - Current VACF if available - """ + """Current VACF result.""" return self._avg diff --git a/torch_sim/state.py b/torch_sim/state.py index a7754881..12fa22e7 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -134,68 +134,49 @@ def __post_init__(self) -> None: @property def wrap_positions(self) -> torch.Tensor: - """Get positions wrapped into the primary unit cell. - - Returns: - torch.Tensor: Atomic positions wrapped according to periodic boundary - conditions if pbc=True, otherwise returns unwrapped positions with - shape (n_atoms, 3). + """Atomic positions wrapped according to periodic boundary conditions if pbc=True, + otherwise returns unwrapped positions with shape (n_atoms, 3). """ # TODO: implement a wrapping method return self.positions @property def device(self) -> torch.device: - """Get the device of the positions tensor. - - Returns: - torch.device: The device where the tensor data is located - """ + """The device where the tensor data is located.""" return self.positions.device @property def dtype(self) -> torch.dtype: - """Get the data type of the positions tensor. - - Returns: - torch.dtype: The data type of the positions tensor - """ + """The data type of the positions tensor.""" return self.positions.dtype @property def n_atoms(self) -> int: - """Get the total number of atoms in the system across all batches. - - Returns: - int: Total number of atoms in the system - """ + """Total number of atoms in the system across all batches.""" return self.positions.shape[0] @property - def n_batches(self) -> int: - """Get the number of batches in the system. + def n_atoms_per_batch(self) -> torch.Tensor: + """Number of atoms per batch.""" + return ( + self.batch.bincount() + if self.batch is not None + else torch.tensor([self.n_atoms], device=self.device) + ) - Returns: - int: Number of batches in the system - """ + @property + def n_batches(self) -> int: + """Number of batches in the system.""" return torch.unique(self.batch).shape[0] @property def volume(self) -> torch.Tensor: - """Get the volume of the system. - - Returns: - torch.Tensor: Volume of the system with shape (n_batches,) - """ + """Volume of the system.""" return torch.det(self.cell) if self.pbc else None @property def column_vector_cell(self) -> torch.Tensor: - """Get the unit cell following the column vector convention. - - Returns: - The unit cell in a column vector format - """ + """Unit cell following the column vector convention.""" return self.cell @column_vector_cell.setter @@ -209,11 +190,7 @@ def column_vector_cell(self, value: torch.Tensor) -> None: @property def row_vector_cell(self) -> torch.Tensor: - """Get the unit cell following the row vector convention. - - Returns: - The unit cell in a row vector format - """ + """Unit cell following the row vector convention.""" return self.cell.transpose(-2, -1) @row_vector_cell.setter diff --git a/torch_sim/trajectory.py b/torch_sim/trajectory.py index ad0f7c97..d709f631 100644 --- a/torch_sim/trajectory.py +++ b/torch_sim/trajectory.py @@ -171,12 +171,7 @@ def load_new_trajectories( @property def array_registry(self) -> dict[str, tuple[tuple[int, ...], np.dtype]]: - """Get the registry of array shapes and dtypes. - - Returns: - dict[str, tuple[tuple[int, ...], np.dtype]]: Dictionary mapping array names to - tuples of (shape, dtype) - """ + """Registry of array shapes and dtypes.""" # Return the registry from the first trajectory if self.trajectories: return self.trajectories[0].array_registry @@ -918,11 +913,7 @@ def get_state( @property def metadata(self) -> dict: - """Get the metadata for the trajectory. - - Returns: - dict: Metadata for the trajectory - """ + """Metadata for the trajectory.""" attrs = self._file.root.metadata._v_attrs return {name: getattr(attrs, name) for name in attrs._f_list()} diff --git a/torch_sim/unbatched/unbatched_integrators.py b/torch_sim/unbatched/unbatched_integrators.py index 78f0ed76..01c6c6ce 100644 --- a/torch_sim/unbatched/unbatched_integrators.py +++ b/torch_sim/unbatched/unbatched_integrators.py @@ -40,10 +40,8 @@ class MDState(SimState): @property def velocities(self) -> torch.Tensor: - """Calculate velocities from momenta and masses. - - Returns: - The velocities of the particles + """Velocities calculated from momenta and masses with shape + [n_particles, n_dimensions]. """ return self.momenta / self.masses.unsqueeze(-1) @@ -526,10 +524,8 @@ class NPTLangevinState(SimState): @property def momenta(self) -> torch.Tensor: - """Calculate momenta from velocities and masses. - - Returns: - The momenta of the particles + """Calculate momenta from velocities and masses with shape + [n_particles, n_dimensions]. """ return self.masses.unsqueeze(-1) * self.velocities @@ -1342,10 +1338,8 @@ class NVTNoseHooverState(MDState): @property def velocities(self) -> torch.Tensor: - """Calculate velocities from momenta and masses. - - Returns: - torch.Tensor: Particle velocities with shape [n_particles, n_dimensions] + """Velocities calculated from momenta and masses with shape + [n_particles, n_dimensions]. """ return self.momenta / self.masses.unsqueeze(-1)