Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/test_integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def test_compute_cell_force_atoms_per_system():

state = ts.NPTLangevinState(
positions=torch.zeros((72, 3)),
velocities=torch.zeros((72, 3)),
momenta=torch.zeros((72, 3)),
energy=torch.zeros(2),
forces=torch.zeros((72, 3)),
masses=torch.ones(72),
Expand Down
38 changes: 14 additions & 24 deletions torch_sim/integrators/npt.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from torch_sim.typing import StateDict


@dataclass
class NPTLangevinState(SimState):
@dataclass(kw_only=True)
class NPTLangevinState(MDState):
"""State information for an NPT system with Langevin dynamics.

This class represents the complete state of a molecular system being integrated
Expand Down Expand Up @@ -57,7 +57,6 @@ class NPTLangevinState(SimState):
# System state variables
energy: torch.Tensor
forces: torch.Tensor
velocities: torch.Tensor
stress: torch.Tensor

# Cell variables
Expand All @@ -66,24 +65,14 @@ class NPTLangevinState(SimState):
cell_velocities: torch.Tensor
cell_masses: torch.Tensor

_atom_attributes = SimState._atom_attributes | { # noqa: SLF001
"forces",
"velocities",
}
_system_attributes = SimState._system_attributes | { # noqa: SLF001
_system_attributes = MDState._system_attributes | { # noqa: SLF001
"stress",
"cell_positions",
"cell_velocities",
"cell_masses",
"reference_cell",
"energy",
}

@property
def momenta(self) -> torch.Tensor:
"""Calculate momenta from velocities and masses."""
return self.velocities * self.masses.unsqueeze(-1)


def _npt_langevin_beta(
state: NPTLangevinState,
Expand All @@ -109,7 +98,7 @@ def _npt_langevin_beta(
torch.Tensor: Random noise term for force calculation [n_particles, n_dim]
"""
# Generate system-specific noise with correct shape
noise = torch.randn_like(state.velocities)
noise = torch.randn_like(state.momenta)

# Calculate the thermal noise amplitude by system
batch_kT = kT
Expand Down Expand Up @@ -370,7 +359,7 @@ def _npt_langevin_position_step(
)

# Generate atom-specific noise
noise = torch.randn_like(state.velocities)
noise = torch.randn_like(state.momenta)
batch_kT = kT
if kT.ndim == 0:
batch_kT = kT.expand(state.n_systems)
Expand Down Expand Up @@ -444,7 +433,7 @@ def _npt_langevin_velocity_step(
c_2 = dt_atoms.unsqueeze(-1) * ((a * forces) + state.forces) / M_2

# Generate atom-specific noise
noise = torch.randn_like(state.velocities)
noise = torch.randn_like(state.momenta)
batch_kT = kT
if kT.ndim == 0:
batch_kT = kT.expand(state.n_systems)
Expand All @@ -457,8 +446,9 @@ def _npt_langevin_velocity_step(
# Random noise contribution
c_3 = b * noise_term / state.masses.unsqueeze(-1)

# Update velocities with all contributions
state.velocities = c_1 + c_2 + c_3
# Update momenta (velocities * masses) with all contributions
new_velocities = c_1 + c_2 + c_3
state.momenta = new_velocities * state.masses.unsqueeze(-1)
return state


Expand Down Expand Up @@ -550,7 +540,7 @@ def npt_langevin_init(
Args:
model (ModelInterface): Neural network model that computes energies, forces,
and stress. Must return a dict with 'energy', 'forces', and 'stress' keys.
state (SimState | StateDict): Either a SimState object or a dictionary
state (MDState | StateDict): Either a MDState object or a dictionary
containing positions, masses, cell, pbc
kT (torch.Tensor): Target temperature in energy units, either scalar or
with shape [n_systems]
Expand Down Expand Up @@ -630,7 +620,7 @@ def npt_langevin_init(
# Create the initial state
return NPTLangevinState(
positions=state.positions,
velocities=momenta / state.masses.unsqueeze(-1),
momenta=momenta,
energy=model_output["energy"],
forces=model_output["forces"],
stress=model_output["stress"],
Expand Down Expand Up @@ -755,7 +745,7 @@ def npt_langevin_step(
return _npt_langevin_velocity_step(state, forces, dt, kT, alpha)


@dataclass
@dataclass(kw_only=True)
class NPTNoseHooverState(MDState):
"""State information for an NPT system with Nose-Hoover chain thermostats.

Expand Down Expand Up @@ -1259,7 +1249,7 @@ def _npt_nose_hoover_inner_step(

# Update particle positions and forces
positions = _npt_nose_hoover_exp_iL1(
state, state.velocities, cell_momentum / cell_mass, dt
state, state.momenta / state.masses.unsqueeze(-1), cell_momentum / cell_mass, dt
)
state.positions = positions
state.cell = cell
Expand Down Expand Up @@ -1315,7 +1305,7 @@ def npt_nose_hoover_init(

Args:
model (ModelInterface): Model to compute forces and energies
state: Initial system state as SimState or dict containing positions, masses,
state: Initial system state as MDState or dict containing positions, masses,
cell, and PBC information
kT: Target temperature in energy units
external_pressure: Target external pressure
Expand Down