From ac18b88c83244a79a5a5103dd473b2cf9d403b1b Mon Sep 17 00:00:00 2001 From: hholb Date: Mon, 20 Oct 2025 09:59:43 -0600 Subject: [PATCH 1/6] NPTLangevinState inherits from MDState, update SimState references --- torch_sim/integrators/npt.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 1ab4e7c3..76d054d4 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -15,12 +15,11 @@ construct_nose_hoover_chain, ) from torch_sim.models.interface import ModelInterface -from torch_sim.state import SimState from torch_sim.typing import StateDict @dataclass -class NPTLangevinState(SimState): +class NPTLangevinState(MDState): """State information for an NPT system with Langevin dynamics. This class represents the complete state of a molecular system being integrated @@ -66,11 +65,11 @@ class NPTLangevinState(SimState): cell_velocities: torch.Tensor cell_masses: torch.Tensor - _atom_attributes = SimState._atom_attributes | { # noqa: SLF001 + _atom_attributes = MDState._atom_attributes | { # noqa: SLF001 "forces", "velocities", } - _system_attributes = SimState._system_attributes | { # noqa: SLF001 + _system_attributes = MDState._system_attributes | { # noqa: SLF001 "stress", "cell_positions", "cell_velocities", @@ -529,7 +528,7 @@ def _compute_cell_force( def npt_langevin_init( - state: SimState | StateDict, + state: MDState | StateDict, model: ModelInterface, *, kT: torch.Tensor, @@ -550,7 +549,7 @@ def npt_langevin_init( Args: model (ModelInterface): Neural network model that computes energies, forces, and stress. Must return a dict with 'energy', 'forces', and 'stress' keys. - state (SimState | StateDict): Either a SimState object or a dictionary + state (MDState | StateDict): Either a MDState object or a dictionary containing positions, masses, cell, pbc kT (torch.Tensor): Target temperature in energy units, either scalar or with shape [n_systems] @@ -593,8 +592,8 @@ def npt_langevin_init( if isinstance(b_tau, float): b_tau = torch.tensor(b_tau, device=device, dtype=dtype) - if not isinstance(state, SimState): - state = SimState(**state) + if not isinstance(state, MDState): + state = MDState(**state) # Get model output to initialize forces and stress model_output = model(state) @@ -1293,7 +1292,7 @@ def _npt_nose_hoover_inner_step( def npt_nose_hoover_init( - state: SimState | StateDict, + state: MDState | StateDict, model: ModelInterface, *, kT: torch.Tensor, @@ -1315,7 +1314,7 @@ def npt_nose_hoover_init( Args: model (ModelInterface): Model to compute forces and energies - state: Initial system state as SimState or dict containing positions, masses, + state: Initial system state as MDState or dict containing positions, masses, cell, and PBC information kT: Target temperature in energy units external_pressure: Target external pressure @@ -1367,8 +1366,8 @@ def npt_nose_hoover_init( dt, chain_length, chain_steps, sy_steps, t_tau ) - if not isinstance(state, SimState): - state = SimState(**state) + if not isinstance(state, MDState): + state = MDState(**state) _n_particles, dim = state.positions.shape n_systems = state.n_systems From 9032b8bc1d1f1c89d9734af7703af2cc805d7a87 Mon Sep 17 00:00:00 2001 From: hholb Date: Mon, 20 Oct 2025 11:02:02 -0600 Subject: [PATCH 2/6] revert changes to init functions --- torch_sim/integrators/npt.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 76d054d4..6e5a8e50 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -15,6 +15,7 @@ construct_nose_hoover_chain, ) from torch_sim.models.interface import ModelInterface +from torch_sim.state import SimState from torch_sim.typing import StateDict @@ -528,7 +529,7 @@ def _compute_cell_force( def npt_langevin_init( - state: MDState | StateDict, + state: SimState | StateDict, model: ModelInterface, *, kT: torch.Tensor, @@ -592,8 +593,8 @@ def npt_langevin_init( if isinstance(b_tau, float): b_tau = torch.tensor(b_tau, device=device, dtype=dtype) - if not isinstance(state, MDState): - state = MDState(**state) + if not isinstance(state, SimState): + state = SimState(**state) # Get model output to initialize forces and stress model_output = model(state) @@ -1292,7 +1293,7 @@ def _npt_nose_hoover_inner_step( def npt_nose_hoover_init( - state: MDState | StateDict, + state: SimState | StateDict, model: ModelInterface, *, kT: torch.Tensor, @@ -1366,8 +1367,8 @@ def npt_nose_hoover_init( dt, chain_length, chain_steps, sy_steps, t_tau ) - if not isinstance(state, MDState): - state = MDState(**state) + if not isinstance(state, SimState): + state = SimState(**state) _n_particles, dim = state.positions.shape n_systems = state.n_systems From 1a9c8808adf4bd1ec200f1ed59d2935aa1f11439 Mon Sep 17 00:00:00 2001 From: hholb Date: Mon, 20 Oct 2025 11:02:53 -0600 Subject: [PATCH 3/6] Add kw_only=True to NPT state dataclasses When inheriting from MDState (which ultimately inherits from SimState with a field that has a default value), Python's dataclass mechanism requires all fields without defaults to come before fields with defaults. Using kw_only=True resolves this ordering constraint while maintaining backward compatibility since all existing instantiations already use keyword arguments. --- torch_sim/integrators/npt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 6e5a8e50..7d09c2f9 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -19,7 +19,7 @@ from torch_sim.typing import StateDict -@dataclass +@dataclass(kw_only=True) class NPTLangevinState(MDState): """State information for an NPT system with Langevin dynamics. @@ -755,7 +755,7 @@ def npt_langevin_step( return _npt_langevin_velocity_step(state, forces, dt, kT, alpha) -@dataclass +@dataclass(kw_only=True) class NPTNoseHooverState(MDState): """State information for an NPT system with Nose-Hoover chain thermostats. From cddb0cd761d0318134c50a84e498b5e0d5784d85 Mon Sep 17 00:00:00 2001 From: hholb Date: Mon, 20 Oct 2025 11:07:16 -0600 Subject: [PATCH 4/6] Update NPTLangevinState to use momenta instead of velocities NPTLangevinState now follows the same pattern as MDState (its parent class): - Store momenta as a field - Compute velocities as a property (inherited from MDState) This makes NPTLangevinState consistent with other MD integrator states and properly leverages the MDState inheritance hierarchy. Changes: - Remove velocities field from NPTLangevinState dataclass - Remove custom momenta property (now inherited from MDState) - Update all step functions to work with momenta: - Convert momenta to velocities where needed: momenta / masses.unsqueeze(-1) - Convert velocities back to momenta for updates: velocities * masses.unsqueeze(-1) - Update noise generation to match momenta shape - Update npt_langevin_init to pass momenta instead of velocities - Update test fixture to use momenta field --- tests/test_integrators.py | 2 +- torch_sim/integrators/npt.py | 30 ++++++++++++------------------ 2 files changed, 13 insertions(+), 19 deletions(-) diff --git a/tests/test_integrators.py b/tests/test_integrators.py index 31ccc75b..a9c7c49b 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -381,7 +381,7 @@ def test_compute_cell_force_atoms_per_system(): state = ts.NPTLangevinState( positions=torch.zeros((72, 3)), - velocities=torch.zeros((72, 3)), + momenta=torch.zeros((72, 3)), energy=torch.zeros(2), forces=torch.zeros((72, 3)), masses=torch.ones(72), diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 7d09c2f9..4d3d602a 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -57,7 +57,6 @@ class NPTLangevinState(MDState): # System state variables energy: torch.Tensor forces: torch.Tensor - velocities: torch.Tensor stress: torch.Tensor # Cell variables @@ -68,7 +67,6 @@ class NPTLangevinState(MDState): _atom_attributes = MDState._atom_attributes | { # noqa: SLF001 "forces", - "velocities", } _system_attributes = MDState._system_attributes | { # noqa: SLF001 "stress", @@ -79,11 +77,6 @@ class NPTLangevinState(MDState): "energy", } - @property - def momenta(self) -> torch.Tensor: - """Calculate momenta from velocities and masses.""" - return self.velocities * self.masses.unsqueeze(-1) - def _npt_langevin_beta( state: NPTLangevinState, @@ -109,7 +102,7 @@ def _npt_langevin_beta( torch.Tensor: Random noise term for force calculation [n_particles, n_dim] """ # Generate system-specific noise with correct shape - noise = torch.randn_like(state.velocities) + noise = torch.randn_like(state.momenta) # Calculate the thermal noise amplitude by system batch_kT = kT @@ -370,7 +363,7 @@ def _npt_langevin_position_step( ) # Generate atom-specific noise - noise = torch.randn_like(state.velocities) + noise = torch.randn_like(state.momenta) batch_kT = kT if kT.ndim == 0: batch_kT = kT.expand(state.n_systems) @@ -381,9 +374,8 @@ def _npt_langevin_position_step( noise_term = noise_prefactor.unsqueeze(-1) * noise # Velocity and force contributions with random noise - c_3 = ( - state.velocities + dt_atoms.unsqueeze(-1) * state.forces / M_2 + noise_term / M_2 - ) + velocities = state.momenta / state.masses.unsqueeze(-1) + c_3 = velocities + dt_atoms.unsqueeze(-1) * state.forces / M_2 + noise_term / M_2 # Update positions with all contributions state.positions = c_1 + c_2 * c_3 @@ -438,13 +430,14 @@ def _npt_langevin_velocity_step( b = 1 / (1 + (alpha_atoms * dt_atoms) / M_2) # Velocity contribution with damping - c_1 = a * state.velocities + velocities = state.momenta / state.masses.unsqueeze(-1) + c_1 = a * velocities # Force contribution (average of initial and final forces) c_2 = dt_atoms.unsqueeze(-1) * ((a * forces) + state.forces) / M_2 # Generate atom-specific noise - noise = torch.randn_like(state.velocities) + noise = torch.randn_like(state.momenta) batch_kT = kT if kT.ndim == 0: batch_kT = kT.expand(state.n_systems) @@ -457,8 +450,9 @@ def _npt_langevin_velocity_step( # Random noise contribution c_3 = b * noise_term / state.masses.unsqueeze(-1) - # Update velocities with all contributions - state.velocities = c_1 + c_2 + c_3 + # Update momenta (velocities * masses) with all contributions + new_velocities = c_1 + c_2 + c_3 + state.momenta = new_velocities * state.masses.unsqueeze(-1) return state @@ -630,7 +624,7 @@ def npt_langevin_init( # Create the initial state return NPTLangevinState( positions=state.positions, - velocities=momenta / state.masses.unsqueeze(-1), + momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], stress=model_output["stress"], @@ -1259,7 +1253,7 @@ def _npt_nose_hoover_inner_step( # Update particle positions and forces positions = _npt_nose_hoover_exp_iL1( - state, state.velocities, cell_momentum / cell_mass, dt + state, state.momenta / state.masses.unsqueeze(-1), cell_momentum / cell_mass, dt ) state.positions = positions state.cell = cell From a03919ea1d4881808851c30cc6b0dbe50747d54c Mon Sep 17 00:00:00 2001 From: hholb Date: Mon, 20 Oct 2025 15:16:22 -0600 Subject: [PATCH 5/6] remove redundant lines based on feedback --- torch_sim/integrators/npt.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 4d3d602a..d9c81747 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -65,16 +65,12 @@ class NPTLangevinState(MDState): cell_velocities: torch.Tensor cell_masses: torch.Tensor - _atom_attributes = MDState._atom_attributes | { # noqa: SLF001 - "forces", - } _system_attributes = MDState._system_attributes | { # noqa: SLF001 "stress", "cell_positions", "cell_velocities", "cell_masses", "reference_cell", - "energy", } @@ -374,8 +370,9 @@ def _npt_langevin_position_step( noise_term = noise_prefactor.unsqueeze(-1) * noise # Velocity and force contributions with random noise - velocities = state.momenta / state.masses.unsqueeze(-1) - c_3 = velocities + dt_atoms.unsqueeze(-1) * state.forces / M_2 + noise_term / M_2 + c_3 = ( + state.velocities + dt_atoms.unsqueeze(-1) * state.forces / M_2 + noise_term / M_2 + ) # Update positions with all contributions state.positions = c_1 + c_2 * c_3 From d3094aa75433c6950d2d5ea7f620e174dd9db763 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Mon, 20 Oct 2025 17:27:38 -0400 Subject: [PATCH 6/6] Apply suggestion from @CompRhys Signed-off-by: Rhys Goodall --- torch_sim/integrators/npt.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index d9c81747..d61eb44c 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -427,8 +427,7 @@ def _npt_langevin_velocity_step( b = 1 / (1 + (alpha_atoms * dt_atoms) / M_2) # Velocity contribution with damping - velocities = state.momenta / state.masses.unsqueeze(-1) - c_1 = a * velocities + c_1 = a * state.velocities # Force contribution (average of initial and final forces) c_2 = dt_atoms.unsqueeze(-1) * ((a * forces) + state.forces) / M_2