Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion tests/test_integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@

import torch

from torch_sim.integrators import calculate_momenta, npt_langevin, nve, nvt_langevin
from torch_sim.integrators import (
NPTLangevinState,
calculate_momenta,
npt_langevin,
nve,
nvt_langevin,
)
from torch_sim.models.lennard_jones import LennardJonesModel
from torch_sim.quantities import calc_kT
from torch_sim.state import SimState, concatenate_states
Expand Down Expand Up @@ -346,3 +352,39 @@ def test_compare_single_vs_batched_integrators(
assert torch.allclose(ar_single_state.forces, final_state.forces)
assert torch.allclose(ar_single_state.masses, final_state.masses)
assert torch.allclose(ar_single_state.cell, final_state.cell)


def test_compute_cell_force_atoms_per_batch():
"""Test that compute_cell_force correctly scales by number of atoms per batch.

Covers fix in https://github.com/Radical-AI/torch-sim/pull/153."""
from torch_sim.integrators import _compute_cell_force

# Setup minimal state with two batches having 8:1 atom ratio
s1, s2 = torch.zeros(8, dtype=torch.long), torch.ones(64, dtype=torch.long)

state = NPTLangevinState(
positions=torch.zeros((72, 3)),
velocities=torch.zeros((72, 3)),
energy=torch.zeros(2),
forces=torch.zeros((72, 3)),
masses=torch.ones(72),
cell=torch.eye(3).repeat(2, 1, 1),
pbc=True,
batch=torch.cat([s1, s2]),
atomic_numbers=torch.ones(72, dtype=torch.long),
stress=torch.zeros((2, 3, 3)),
reference_cell=torch.eye(3).repeat(2, 1, 1),
cell_positions=torch.ones((2, 3, 3)),
cell_velocities=torch.zeros((2, 3, 3)),
cell_masses=torch.ones(2),
)

# Get forces and compare ratio
cell_force = _compute_cell_force(state, torch.tensor(0.0), torch.tensor([1.0, 1.0]))
force_ratio = (
torch.diagonal(cell_force[1]).mean() / torch.diagonal(cell_force[0]).mean()
)

# Force ratio should match atom ratio (8:1) with the fix
assert abs(force_ratio - 8.0) / 8.0 < 0.1
2 changes: 1 addition & 1 deletion tests/test_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,4 +579,4 @@ def test_neighbor_lists_time_and_memory(
assert cpu_memory_used < 5e8, (
f"{fn_name} used too much CPU memory: {cpu_memory_used / 1e6:.2f}MB"
)
assert execution_time < 0.2, f"{fn_name} took too long: {execution_time}s"
assert execution_time < 0.8, f"{fn_name} took too long: {execution_time}s"
9 changes: 9 additions & 0 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,15 @@ def test_concatenate_si_and_fe_states(
)
assert torch.all(concatenated.batch == expected_batch)

# check n_atoms_per_batch
assert torch.all(
concatenated.n_atoms_per_batch
== torch.tensor(
[si_sim_state.n_atoms, fe_supercell_sim_state.n_atoms],
device=concatenated.device,
)
)

# Check that positions match for each original state
assert torch.allclose(concatenated.positions[:si_atoms], si_sim_state.positions)
assert torch.allclose(
Expand Down
93 changes: 60 additions & 33 deletions torch_sim/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,8 @@ class MDState(SimState):

@property
def velocities(self) -> torch.Tensor:
"""Calculate velocities from momenta and masses.

Returns:
torch.Tensor: Velocities with shape [n_particles, n_dimensions]
"""Velocities calculated from momenta and masses with shape
[n_particles, n_dimensions].
"""
return self.momenta / self.masses.unsqueeze(-1)

Expand Down Expand Up @@ -563,6 +561,63 @@ def momenta(self) -> torch.Tensor:
return self.velocities * self.masses.unsqueeze(-1)


# Extracted out from npt_langevin body to test fix in https://github.com/Radical-AI/torch-sim/pull/153
def _compute_cell_force(
state: NPTLangevinState,
external_pressure: torch.Tensor,
kT: torch.Tensor,
) -> torch.Tensor:
"""Compute forces on the cell for NPT dynamics.

This function calculates the forces acting on the simulation cell
based on the difference between internal stress and external pressure,
plus a kinetic contribution. These forces drive the volume changes
needed to maintain constant pressure.

Args:
state (NPTLangevinState): Current NPT state
external_pressure (torch.Tensor): Target external pressure, either scalar or
tensor with shape [n_batches, n_dimensions, n_dimensions]
kT (torch.Tensor): Temperature in energy units, either scalar or
shape [n_batches]

Returns:
torch.Tensor: Force acting on the cell [n_batches, n_dim, n_dim]
"""
# Get current volumes for each batch
volumes = torch.linalg.det(state.cell) # shape: (n_batches,)

# Reshape for broadcasting
volumes = volumes.view(-1, 1, 1) # shape: (n_batches, 1, 1)

# Create pressure tensor (diagonal with external pressure)
if external_pressure.ndim == 0:
# Scalar pressure - create diagonal pressure tensors for each batch
pressure_tensor = external_pressure * torch.eye(
3, device=state.device, dtype=state.dtype
)
pressure_tensor = pressure_tensor.unsqueeze(0).expand(state.n_batches, -1, -1)
else:
# Already a tensor with shape compatible with n_batches
pressure_tensor = external_pressure

# Calculate virials from stress and external pressure
# Internal stress is negative of virial tensor divided by volume
virial = -volumes * state.stress + pressure_tensor * volumes

# Add kinetic contribution (kT * Identity)
batch_kT = kT
if kT.ndim == 0:
batch_kT = kT.expand(state.n_batches)

e_kin_per_atom = batch_kT.view(-1, 1, 1) * torch.eye(
3, device=state.device, dtype=state.dtype
).unsqueeze(0)

# Correct implementation with scaling by n_atoms_per_batch
return virial + e_kin_per_atom * state.n_atoms_per_batch.view(-1, 1, 1)


def npt_langevin( # noqa: C901, PLR0915
model: torch.nn.Module,
*,
Expand Down Expand Up @@ -748,35 +803,7 @@ def compute_cell_force(
Returns:
torch.Tensor: Force acting on the cell [n_batches, n_dim, n_dim]
"""
# Get current volumes for each batch
volumes = torch.linalg.det(state.cell) # shape: (n_batches,)

# Reshape for broadcasting
volumes = volumes.view(-1, 1, 1) # shape: (n_batches, 1, 1)

# Create pressure tensor (diagonal with external pressure)
if external_pressure.ndim == 0:
# Scalar pressure - create diagonal pressure tensors for each batch
pressure_tensor = external_pressure * torch.eye(3, device=device, dtype=dtype)
pressure_tensor = pressure_tensor.unsqueeze(0).expand(state.n_batches, -1, -1)
else:
# Already a tensor with shape compatible with n_batches
pressure_tensor = external_pressure

# Calculate virials from stress and external pressure
# Internal stress is negative of virial tensor divided by volume
virial = -volumes * state.stress + pressure_tensor * volumes

# Add kinetic contribution (kT * Identity)
batch_kT = kT
if kT.ndim == 0:
batch_kT = kT.expand(state.n_batches)

kinetic_term = batch_kT.view(-1, 1, 1) * torch.eye(
3, device=device, dtype=dtype
).unsqueeze(0)

return virial + kinetic_term
return _compute_cell_force(state, external_pressure, kT)

def cell_position_step(
state: NPTLangevinState,
Expand Down
27 changes: 4 additions & 23 deletions torch_sim/models/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,7 @@ def __init__(

@property
def device(self) -> torch.device:
"""The device of the model.

Returns:
The device of the model
"""
"""The device of the model."""
return self._device

@device.setter
Expand All @@ -120,11 +116,7 @@ def device(self, device: torch.device) -> None:

@property
def dtype(self) -> torch.dtype:
"""The data type of the model.

Returns:
The data type of the model
"""
"""The data type of the model."""
return self._dtype

@dtype.setter
Expand All @@ -136,11 +128,7 @@ def dtype(self, dtype: torch.dtype) -> None:

@property
def compute_stress(self) -> bool:
"""Whether the model computes stresses.

Returns:
Whether the model computes stresses
"""
"""Whether the model computes stresses."""
return self._compute_stress

@compute_stress.setter
Expand All @@ -152,11 +140,7 @@ def compute_stress(self, compute_stress: bool) -> None:

@property
def compute_forces(self) -> bool:
"""Whether the model computes forces.

Returns:
Whether the model computes forces
"""
"""Whether the model computes forces."""
return self._compute_forces

@compute_forces.setter
Expand All @@ -173,9 +157,6 @@ def memory_scales_with(self) -> Literal["n_atoms", "n_atoms_x_density"]:
Models with radial neighbor cutoffs scale with "n_atoms_x_density",
while models with a fixed number of neighbors scale with "n_atoms".
Default is "n_atoms_x_density" because most models are radial cutoff based.

Returns:
The metric that the model scales with
"""
return getattr(self, "_memory_scales_with", "n_atoms_x_density")

Expand Down
12 changes: 2 additions & 10 deletions torch_sim/properties/correlations.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,7 @@ def get_array(self) -> torch.Tensor:

@property
def is_full(self) -> bool:
"""Check if the buffer is full.

Returns:
True if buffer contains size elements, False otherwise
"""
"""Check if the buffer is full."""
return self.count == self.size


Expand Down Expand Up @@ -475,9 +471,5 @@ def __call__(self, state: SimState, _: Any = None) -> torch.Tensor:

@property
def vacf(self) -> torch.Tensor | None:
"""Get current VACF result.

Returns:
Current VACF if available
"""
"""Current VACF result."""
return self._avg
59 changes: 18 additions & 41 deletions torch_sim/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,68 +134,49 @@ def __post_init__(self) -> None:

@property
def wrap_positions(self) -> torch.Tensor:
"""Get positions wrapped into the primary unit cell.

Returns:
torch.Tensor: Atomic positions wrapped according to periodic boundary
conditions if pbc=True, otherwise returns unwrapped positions with
shape (n_atoms, 3).
"""Atomic positions wrapped according to periodic boundary conditions if pbc=True,
otherwise returns unwrapped positions with shape (n_atoms, 3).
"""
# TODO: implement a wrapping method
return self.positions

@property
def device(self) -> torch.device:
"""Get the device of the positions tensor.

Returns:
torch.device: The device where the tensor data is located
"""
"""The device where the tensor data is located."""
return self.positions.device

@property
def dtype(self) -> torch.dtype:
"""Get the data type of the positions tensor.

Returns:
torch.dtype: The data type of the positions tensor
"""
"""The data type of the positions tensor."""
return self.positions.dtype

@property
def n_atoms(self) -> int:
"""Get the total number of atoms in the system across all batches.

Returns:
int: Total number of atoms in the system
"""
"""Total number of atoms in the system across all batches."""
return self.positions.shape[0]

@property
def n_batches(self) -> int:
"""Get the number of batches in the system.
def n_atoms_per_batch(self) -> torch.Tensor:
"""Number of atoms per batch."""
return (
self.batch.bincount()
if self.batch is not None
else torch.tensor([self.n_atoms], device=self.device)
)

Returns:
int: Number of batches in the system
"""
@property
def n_batches(self) -> int:
"""Number of batches in the system."""
return torch.unique(self.batch).shape[0]

@property
def volume(self) -> torch.Tensor:
"""Get the volume of the system.

Returns:
torch.Tensor: Volume of the system with shape (n_batches,)
"""
"""Volume of the system."""
return torch.det(self.cell) if self.pbc else None

@property
def column_vector_cell(self) -> torch.Tensor:
"""Get the unit cell following the column vector convention.

Returns:
The unit cell in a column vector format
"""
"""Unit cell following the column vector convention."""
return self.cell

@column_vector_cell.setter
Expand All @@ -209,11 +190,7 @@ def column_vector_cell(self, value: torch.Tensor) -> None:

@property
def row_vector_cell(self) -> torch.Tensor:
"""Get the unit cell following the row vector convention.

Returns:
The unit cell in a row vector format
"""
"""Unit cell following the row vector convention."""
return self.cell.transpose(-2, -1)

@row_vector_cell.setter
Expand Down
Loading