diff --git a/tests/test_integrators.py b/tests/test_integrators.py index 31ccc75b..a9c7c49b 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -381,7 +381,7 @@ def test_compute_cell_force_atoms_per_system(): state = ts.NPTLangevinState( positions=torch.zeros((72, 3)), - velocities=torch.zeros((72, 3)), + momenta=torch.zeros((72, 3)), energy=torch.zeros(2), forces=torch.zeros((72, 3)), masses=torch.ones(72), diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 1ab4e7c3..d61eb44c 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -19,8 +19,8 @@ from torch_sim.typing import StateDict -@dataclass -class NPTLangevinState(SimState): +@dataclass(kw_only=True) +class NPTLangevinState(MDState): """State information for an NPT system with Langevin dynamics. This class represents the complete state of a molecular system being integrated @@ -57,7 +57,6 @@ class NPTLangevinState(SimState): # System state variables energy: torch.Tensor forces: torch.Tensor - velocities: torch.Tensor stress: torch.Tensor # Cell variables @@ -66,24 +65,14 @@ class NPTLangevinState(SimState): cell_velocities: torch.Tensor cell_masses: torch.Tensor - _atom_attributes = SimState._atom_attributes | { # noqa: SLF001 - "forces", - "velocities", - } - _system_attributes = SimState._system_attributes | { # noqa: SLF001 + _system_attributes = MDState._system_attributes | { # noqa: SLF001 "stress", "cell_positions", "cell_velocities", "cell_masses", "reference_cell", - "energy", } - @property - def momenta(self) -> torch.Tensor: - """Calculate momenta from velocities and masses.""" - return self.velocities * self.masses.unsqueeze(-1) - def _npt_langevin_beta( state: NPTLangevinState, @@ -109,7 +98,7 @@ def _npt_langevin_beta( torch.Tensor: Random noise term for force calculation [n_particles, n_dim] """ # Generate system-specific noise with correct shape - noise = torch.randn_like(state.velocities) + noise = torch.randn_like(state.momenta) # Calculate the thermal noise amplitude by system batch_kT = kT @@ -370,7 +359,7 @@ def _npt_langevin_position_step( ) # Generate atom-specific noise - noise = torch.randn_like(state.velocities) + noise = torch.randn_like(state.momenta) batch_kT = kT if kT.ndim == 0: batch_kT = kT.expand(state.n_systems) @@ -444,7 +433,7 @@ def _npt_langevin_velocity_step( c_2 = dt_atoms.unsqueeze(-1) * ((a * forces) + state.forces) / M_2 # Generate atom-specific noise - noise = torch.randn_like(state.velocities) + noise = torch.randn_like(state.momenta) batch_kT = kT if kT.ndim == 0: batch_kT = kT.expand(state.n_systems) @@ -457,8 +446,9 @@ def _npt_langevin_velocity_step( # Random noise contribution c_3 = b * noise_term / state.masses.unsqueeze(-1) - # Update velocities with all contributions - state.velocities = c_1 + c_2 + c_3 + # Update momenta (velocities * masses) with all contributions + new_velocities = c_1 + c_2 + c_3 + state.momenta = new_velocities * state.masses.unsqueeze(-1) return state @@ -550,7 +540,7 @@ def npt_langevin_init( Args: model (ModelInterface): Neural network model that computes energies, forces, and stress. Must return a dict with 'energy', 'forces', and 'stress' keys. - state (SimState | StateDict): Either a SimState object or a dictionary + state (MDState | StateDict): Either a MDState object or a dictionary containing positions, masses, cell, pbc kT (torch.Tensor): Target temperature in energy units, either scalar or with shape [n_systems] @@ -630,7 +620,7 @@ def npt_langevin_init( # Create the initial state return NPTLangevinState( positions=state.positions, - velocities=momenta / state.masses.unsqueeze(-1), + momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], stress=model_output["stress"], @@ -755,7 +745,7 @@ def npt_langevin_step( return _npt_langevin_velocity_step(state, forces, dt, kT, alpha) -@dataclass +@dataclass(kw_only=True) class NPTNoseHooverState(MDState): """State information for an NPT system with Nose-Hoover chain thermostats. @@ -1259,7 +1249,7 @@ def _npt_nose_hoover_inner_step( # Update particle positions and forces positions = _npt_nose_hoover_exp_iL1( - state, state.velocities, cell_momentum / cell_mass, dt + state, state.momenta / state.masses.unsqueeze(-1), cell_momentum / cell_mass, dt ) state.positions = positions state.cell = cell @@ -1315,7 +1305,7 @@ def npt_nose_hoover_init( Args: model (ModelInterface): Model to compute forces and energies - state: Initial system state as SimState or dict containing positions, masses, + state: Initial system state as MDState or dict containing positions, masses, cell, and PBC information kT: Target temperature in energy units external_pressure: Target external pressure