diff --git a/examples/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py b/examples/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py new file mode 100644 index 00000000..c24a3c4b --- /dev/null +++ b/examples/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py @@ -0,0 +1,154 @@ +"""Lennard-Jones simulation in NPT ensemble using Nose-Hoover chain.""" + +import os + +import torch + +from torch_sim.quantities import kinetic_energy, temperature +from torch_sim.state import BaseState +from torch_sim.unbatched.models.lennard_jones import UnbatchedLennardJonesModel +from torch_sim.unbatched.unbatched_integrators import npt_langevin +from torch_sim.units import MetalUnits as Units +from torch_sim.units import UnitConversion + + +# Set up the device and data type +device = "cuda" if torch.cuda.is_available() else "cpu" +dtype = torch.float32 + +# Set random seed and deterministic behavior for reproducibility +torch.manual_seed(42) +if torch.cuda.is_available(): + torch.cuda.manual_seed_all(42) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + +# Set up the random number generator +generator = torch.Generator(device=device) +generator.manual_seed(42) # For reproducibility + +# Number of steps to run +N_steps = 100 if os.getenv("CI") else 10_000 + +# Create face-centered cubic (FCC) Argon +# 5.26 Å is a typical lattice constant for Ar +a_len = 5.26 # Lattice constant +PERIODIC = True # Flag to use periodic boundary conditions + +# Generate base FCC unit cell positions (scaled by lattice constant) +base_positions = torch.tensor( + [ + [0.0, 0.0, 0.0], # Corner + [0.0, 0.5, 0.5], # Face centers + [0.5, 0.0, 0.5], + [0.5, 0.5, 0.0], + ], + device=device, + dtype=dtype, +) + +# Create 4x4x4 supercell of FCC Argon manually +positions = [] +for i in range(4): + for j in range(4): + for k in range(4): + for base_pos in base_positions: + # Add unit cell position + offset for supercell + pos = base_pos + torch.tensor([i, j, k], device=device, dtype=dtype) + positions.append(pos) + +# Stack the positions into a tensor +positions = torch.stack(positions) + +# Scale by lattice constant +positions = positions * a_len + +# Create the cell tensor +cell = torch.tensor( + [[4 * a_len, 0, 0], [0, 4 * a_len, 0], [0, 0, 4 * a_len]], device=device, dtype=dtype +) + +# Create the atomic numbers tensor (Argon = 18) +atomic_numbers = torch.full((positions.shape[0],), 18, device=device, dtype=torch.int) +# Create the masses tensor (Argon = 39.948 amu) +masses = torch.full((positions.shape[0],), 39.948, device=device, dtype=dtype) + +# Initialize the Lennard-Jones model +# Parameters: +# - sigma: distance at which potential is zero (3.405 Å for Ar) +# - epsilon: depth of potential well (0.0104 eV for Ar) +# - cutoff: distance beyond which interactions are ignored (typically 2.5*sigma) +model = UnbatchedLennardJonesModel( + use_neighbor_list=False, + sigma=3.405, + epsilon=0.0104, + cutoff=2.5 * 3.405, + device=device, + dtype=dtype, + compute_force=True, + compute_stress=True, +) +state = BaseState( + positions=positions, + masses=masses, + cell=cell, + pbc=PERIODIC, + atomic_numbers=atomic_numbers, +) +# Run initial simulation and get results +results = model(state) + +dt = 0.001 * Units.time # Time step (1 ps) +kT = 200 * Units.temperature # Temperature (200 K) +target_pressure = 10000 * Units.pressure # Target pressure (10 kbar) + +npt_init, npt_update = npt_langevin( + model=model, + dt=dt, + kT=kT, + external_pressure=target_pressure, +) + +state = npt_init(state=state, seed=1) + + +def get_pressure( + stress: torch.Tensor, kinetic_energy: torch.Tensor, volume: torch.Tensor, dim: int = 3 +) -> torch.Tensor: + """Compute the pressure from the stress tensor. + + The stress tensor is defined as 1/volume * dU/de_ij + So the pressure is -1/volume * trace(dU/de_ij) + """ + return 1 / (dim) * ((2 * kinetic_energy / volume) - torch.trace(stress)) + + +# Run the simulation +for step in range(N_steps): + if step % 50 == 0: + temp = temperature(masses=state.masses, momenta=state.momenta) / Units.temperature + pressure = get_pressure( + model(state)["stress"], + kinetic_energy(masses=state.masses, momenta=state.momenta), + torch.linalg.det(state.cell), + ) + pressure = pressure.item() / Units.pressure + xx, yy, zz = state.cell[0, 0], state.cell[1, 1], state.cell[2, 2] + print( + f"{step=}: Temperature: {temp:.4f}, " + f"{pressure=:.4f}, " + f"cell xx yy zz: {xx.item():.4f}, {yy.item():.4f}, {zz.item():.4f}" + ) + state = npt_update(state, kT=kT, external_pressure=target_pressure) + +temp = temperature(masses=state.masses, momenta=state.momenta) / Units.temperature +print(f"Final temperature: {temp:.4f}") + + +stress = model(state)["stress"] +kinetic_energy = kinetic_energy(masses=state.masses, momenta=state.momenta) +volume = torch.linalg.det(state.cell) +pressure = get_pressure(stress, kinetic_energy, volume) +pressure = pressure.item() / Units.pressure +print(f"Final {pressure=:.4f}") +print(stress * UnitConversion.eV_per_Ang3_to_GPa) diff --git a/examples/3_Dynamics/3.12_MACE_NPT_Langevin.py b/examples/3_Dynamics/3.12_MACE_NPT_Langevin.py new file mode 100644 index 00000000..12ece8ba --- /dev/null +++ b/examples/3_Dynamics/3.12_MACE_NPT_Langevin.py @@ -0,0 +1,144 @@ +"""NPT simulation with MACE and Nose-Hoover thermostat.""" + +# /// script +# dependencies = [ +# "mace-torch>=0.3.11", +# ] +# /// + +import os + +import torch +from ase.build import bulk +from mace.calculators.foundations_models import mace_mp + +from torch_sim.neighbors import vesin_nl_ts +from torch_sim.quantities import kinetic_energy, temperature +from torch_sim.state import BaseState +from torch_sim.unbatched.models.mace import UnbatchedMaceModel +from torch_sim.unbatched.unbatched_integrators import ( + npt_langevin, + nvt_nose_hoover, + nvt_nose_hoover_invariant, +) +from torch_sim.units import MetalUnits as Units + + +# Set device and data type +device = "cuda" if torch.cuda.is_available() else "cpu" +dtype = torch.float32 + +# Option 1: Load the raw model from the downloaded model +mace_checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model" +loaded_model = mace_mp( + model=mace_checkpoint_url, + return_raw_model=True, + default_dtype=dtype, + device=device, +) + +# Option 2: Load from local file (comment out Option 1 to use this) +# MODEL_PATH = "../../../checkpoints/MACE/mace-mpa-0-medium.model" +# loaded_model = torch.load(MODEL_PATH, map_location=device) + +PERIODIC = True + +# Create diamond cubic Silicon +si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2)) + +# Prepare input tensors +positions = torch.tensor(si_dc.positions, device=device, dtype=dtype) +cell = torch.tensor(si_dc.cell.array, device=device, dtype=dtype) +atomic_numbers = torch.tensor(si_dc.get_atomic_numbers(), device=device, dtype=torch.int) +masses = torch.tensor(si_dc.get_masses(), device=device, dtype=dtype) + +# Print shapes for verification +print(f"Positions: {positions.shape}") +print(f"Cell: {cell.shape}") + +# Initialize the unbatched MACE model +model = UnbatchedMaceModel( + model=loaded_model, + device=device, + neighbor_list_fn=vesin_nl_ts, + periodic=PERIODIC, + compute_force=True, + compute_stress=True, + dtype=dtype, + enable_cueq=False, +) +state = BaseState( + positions=positions, + masses=masses, + cell=cell, + pbc=PERIODIC, + atomic_numbers=atomic_numbers, +) +# Run initial inference +results = model(state) + +N_steps_nvt = 20 if os.getenv("CI") else 2_000 +N_steps_npt = 20 if os.getenv("CI") else 2_000 +dt = 0.001 * Units.time # Time step (1 ps) +kT = 300 * Units.temperature # Initial temperature (300 K) +target_pressure = 10000 * Units.pressure # Target pressure (0 bar) + +nvt_init, nvt_update = nvt_nose_hoover(model=model, kT=kT, dt=dt) +state = nvt_init(state=state, seed=1) + +for step in range(N_steps_nvt): + if step % 10 == 0: + temp = temperature(masses=state.masses, momenta=state.momenta) / Units.temperature + invariant = nvt_nose_hoover_invariant(state, kT=kT).item() + print(f"{step=}: Temperature: {temp:.4f}: invariant: {invariant:.4f}, ") + state = nvt_update(state, kT=kT) + +npt_init, npt_update = npt_langevin( + model=model, kT=kT, dt=dt, external_pressure=target_pressure +) +state = npt_init(state=state, seed=1) + + +def get_pressure( + stress: torch.Tensor, kinetic_energy: torch.Tensor, volume: torch.Tensor, dim: int = 3 +) -> torch.Tensor: + """Compute the pressure from the stress tensor. + + The stress tensor is defined as 1/volume * dU/de_ij + So the pressure is -1/volume * trace(dU/de_ij) + """ + return 1 / dim * ((2 * kinetic_energy / volume) - torch.trace(stress)) + + +for step in range(N_steps_npt): + if step % 10 == 0: + temp = temperature(masses=state.masses, momenta=state.momenta) / Units.temperature + stress = model(state)["stress"] + volume = torch.det(state.cell) + pressure = ( + get_pressure( + stress, kinetic_energy(masses=state.masses, momenta=state.momenta), volume + ).item() + / Units.pressure + ) + xx, yy, zz = torch.diag(state.cell) + print( + f"{step=}: Temperature: {temp:.4f}, " + f"pressure: {pressure:.4f}, " + f"cell xx yy zz: {xx.item():.4f}, {yy.item():.4f}, {zz.item():.4f}" + ) + state = npt_update(state, kT=kT, external_pressure=target_pressure) + +final_temp = temperature(masses=state.masses, momenta=state.momenta) / Units.temperature +print(f"Final temperature: {final_temp:.4f} K") +final_stress = model(state)["stress"] +final_volume = torch.det(state.cell) +final_pressure = ( + get_pressure( + final_stress, + kinetic_energy(masses=state.masses, momenta=state.momenta), + final_volume, + ).item() + / Units.pressure +) +print(f"Final pressure: {final_pressure:.4f} bar") diff --git a/tests/test_integrators.py b/tests/test_integrators.py index ad6c1df3..15a0de09 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -2,7 +2,13 @@ import torch -from torch_sim.integrators import MDState, batched_initialize_momenta, nve, nvt_langevin +from torch_sim.integrators import ( + MDState, + batched_initialize_momenta, + npt_langevin, + nve, + nvt_langevin, +) from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.quantities import temperature from torch_sim.state import BaseState, concatenate_states @@ -167,6 +173,67 @@ def test_batched_initialize_momenta(): assert torch.allclose(com_momentum, torch.zeros(3, dtype=dtype), atol=1e-10) +def test_npt_langevin(ar_double_base_state: BaseState, lj_calculator: LennardJonesModel): + dtype = torch.float64 + n_steps = 200 + dt = torch.tensor(0.001, dtype=dtype) + kT = torch.tensor(100.0, dtype=dtype) * MetalUnits.temperature + external_pressure = torch.tensor(0.0, dtype=dtype) * MetalUnits.pressure + + # Initialize integrator + init_fn, update_fn = npt_langevin( + model=lj_calculator, + dt=dt, + kT=kT, + external_pressure=external_pressure, + alpha=40 * dt, + ) + + # Run dynamics for several steps + state = init_fn(state=ar_double_base_state, seed=42) + energies = [] + temperatures = [] + for _step in range(n_steps): + state = update_fn(state=state) + + # Calculate instantaneous temperature from kinetic energy + temp = temperature(state.momenta, state.masses, batch=state.batch) + energies.append(state.energy) + temperatures.append(temp / MetalUnits.temperature) + + # Convert temperatures list to tensor + temperatures_tensor = torch.stack(temperatures) + temperatures_list = [t.tolist() for t in temperatures_tensor.T] + + energies_tensor = torch.stack(energies) + energies_list = [t.tolist() for t in energies_tensor.T] + + # Basic sanity checks + assert len(energies_list[0]) == n_steps + assert len(temperatures_list[0]) == n_steps + + # Check temperature is roughly maintained for each trajectory + mean_temps = torch.mean(temperatures_tensor, dim=0) # Mean temp for each trajectory + for mean_temp in mean_temps: + assert ( + abs(mean_temp - kT.item() / MetalUnits.temperature) < 150.0 + ) # Allow for thermal fluctuations + + # Check energy is stable for each trajectory + for traj in energies_list: + energy_std = torch.tensor(traj).std() + assert energy_std < 1.0 # Adjust threshold as needed + + # Check positions and momenta have correct shapes + n_atoms = 8 + + # Verify the two systems remain distinct + pos_diff = torch.norm( + state.positions[:n_atoms].mean(0) - state.positions[n_atoms:].mean(0) + ) + assert pos_diff > 0.0001 # Systems should remain separated + + def test_nvt_langevin(ar_double_base_state: BaseState, lj_calculator: LennardJonesModel): dtype = torch.float64 n_steps = 100 @@ -191,7 +258,7 @@ def test_nvt_langevin(ar_double_base_state: BaseState, lj_calculator: LennardJon temp = temperature(state.momenta, state.masses, batch=state.batch) energies.append(state.energy) - temperatures.append(temp * 11606) + temperatures.append(temp / MetalUnits.temperature) # Convert temperatures list to tensor temperatures_tensor = torch.stack(temperatures) @@ -208,7 +275,7 @@ def test_nvt_langevin(ar_double_base_state: BaseState, lj_calculator: LennardJon mean_temps = torch.mean(temperatures_tensor, dim=0) # Mean temp for each trajectory for mean_temp in mean_temps: assert ( - abs(mean_temp - kT.item() * 11606) < 150.0 + abs(mean_temp - kT.item() / MetalUnits.temperature) < 150.0 ) # Allow for thermal fluctuations # Check energy is stable for each trajectory diff --git a/tests/unbatched/test_unbatched_integrators.py b/tests/unbatched/test_unbatched_integrators.py index c0c38fda..1a4c50fb 100644 --- a/tests/unbatched/test_unbatched_integrators.py +++ b/tests/unbatched/test_unbatched_integrators.py @@ -8,6 +8,7 @@ from torch_sim.unbatched.unbatched_integrators import ( MDState, calculate_momenta, + npt_langevin, npt_nose_hoover, npt_nose_hoover_invariant, nve, @@ -112,6 +113,53 @@ def test_nvt_nose_hoover_integrator( assert state.chain.positions.shape[0] == 3, "Should have 3 chain thermostats" +def test_npt_langevin_integrator( + ar_base_state: BaseState, unbatched_lj_calculator: Any +) -> None: + """Test Langevin thermostat maintains target temperature.""" + # Initialize integrator + target_temp = torch.tensor(100.0) * MetalUnits.temperature + dt = torch.tensor(0.001) * MetalUnits.time + external_pressure = torch.tensor(10000) * MetalUnits.pressure + n_steps = 4000 + dim = ar_base_state.positions.shape[1] + langevin_init, langevin_update = npt_langevin( + model=unbatched_lj_calculator, + dt=dt, + kT=target_temp, + external_pressure=external_pressure, + alpha=40 * dt, + ) + + # Remove batch dimension from cell + ar_base_state.cell = ar_base_state.cell.squeeze(0) + + state = langevin_init(state=ar_base_state, seed=42) + # Run equilibration + temperatures = torch.zeros(n_steps) + pressures = torch.zeros(n_steps) + for step in range(n_steps): + state = langevin_update(state, target_temp) + temp = temperature(state.momenta, state.masses) / MetalUnits.temperature + volume = torch.linalg.det(state.cell) + pressures[step] = ( + 1 + / (dim) + * ( + (2 * kinetic_energy(state.momenta, state.masses) / volume) + - torch.trace(state.stress) + ) + / MetalUnits.pressure + ) + temperatures[step] = temp + + average_temperature = torch.mean(temperatures[2000:]) + average_pressure = torch.mean(pressures[2000:]) + # Check temperature control + assert 120 > average_temperature > 80, "Temperature should be maintained" + assert 12000 > average_pressure > 8000, "Pressure should be maintained" + + def test_integrator_state_properties( ar_base_state: BaseState, unbatched_lj_calculator: Any ) -> None: diff --git a/torch_sim/integrators.py b/torch_sim/integrators.py index 40b98572..093b11ff 100644 --- a/torch_sim/integrators.py +++ b/torch_sim/integrators.py @@ -174,42 +174,6 @@ def position_step(state: MDState, dt: torch.Tensor) -> MDState: return state -def stochastic_step( - state: MDState, - dt: torch.Tensor, - kT: torch.Tensor, - gamma: torch.Tensor, -) -> MDState: - """Apply stochastic noise and friction for Langevin dynamics. - - This function implements the stochastic part of Langevin dynamics by applying - random noise and friction forces to particle momenta. The noise amplitude is - chosen to maintain the target temperature kT. - - Args: - state: Current system state containing positions, momenta, etc. - dt: Integration timestep - kT: Target temperature in energy units - gamma: Friction coefficient controlling noise strength - - Returns: - Updated state with new momenta after stochastic step - - Notes: - - Uses Ornstein-Uhlenbeck process for correct thermal sampling - - Noise amplitude scales with sqrt(mass) for equipartition - - Preserves detailed balance through fluctuation-dissipation relation - """ - c1 = torch.exp(-gamma * dt) - c2 = torch.sqrt(kT * (1 - c1**2)) - - # Generate random noise from normal distribution - noise = torch.randn_like(state.momenta, device=state.device, dtype=state.dtype) - new_momenta = c1 * state.momenta + c2 * torch.sqrt(state.masses).unsqueeze(-1) * noise - state.momenta = new_momenta - return state - - def nve( model: torch.nn.Module, *, @@ -365,6 +329,43 @@ def nvt_langevin( if isinstance(dt, float): dt = torch.tensor(dt, device=device, dtype=dtype) + def ou_step( + state: MDState, + dt: torch.Tensor, + kT: torch.Tensor, + gamma: torch.Tensor, + ) -> MDState: + """Apply stochastic noise and friction for Langevin dynamics. + + This function implements the stochastic part of Langevin dynamics by applying + random noise and friction forces to particle momenta. The noise amplitude is + chosen to maintain the target temperature kT. + + Args: + state: Current system state containing positions, momenta, etc. + dt: Integration timestep + kT: Target temperature in energy units + gamma: Friction coefficient controlling noise strength + + Returns: + Updated state with new momenta after stochastic step + + Notes: + - Uses Ornstein-Uhlenbeck process for correct thermal sampling + - Noise amplitude scales with sqrt(mass) for equipartition + - Preserves detailed balance through fluctuation-dissipation relation + """ + c1 = torch.exp(-gamma * dt) + c2 = torch.sqrt(kT * (1 - c1**2)) + + # Generate random noise from normal distribution + noise = torch.randn_like(state.momenta, device=state.device, dtype=state.dtype) + new_momenta = ( + c1 * state.momenta + c2 * torch.sqrt(state.masses).unsqueeze(-1) * noise + ) + state.momenta = new_momenta + return state + def langevin_init( state: BaseState | StateDict, kT: torch.Tensor = kT, @@ -431,7 +432,7 @@ def langevin_update( state = momentum_step(state, dt / 2) state = position_step(state, dt / 2) - state = stochastic_step(state, dt, kT, gamma) + state = ou_step(state, dt, kT, gamma) state = position_step(state, dt / 2) model_output = model(state) @@ -441,3 +442,687 @@ def langevin_update( return momentum_step(state, dt / 2) return langevin_init, langevin_update + + +@dataclass +class NPTLangevinState(BaseState): + """State information for an NPT system with Langevin dynamics. + + This class represents the complete state of a molecular system being integrated + in the NPT (constant particle number, pressure, temperature) ensemble using + Langevin dynamics. + + Attributes: + positions: Particle positions with shape [n_particles, n_dimensions] + momenta: Particle momenta with shape [n_particles, n_dimensions] + energy: Energy of the system + forces: Forces on particles with shape [n_particles, n_dimensions] + masses: Particle masses with shape [n_particles] + cell: Simulation cell matrix with shape [n_batches, n_dimensions, n_dimensions] + pbc: Whether to use periodic boundary conditions + stress: Stress tensor of the system with shape + [n_batches, n_dimensions, n_dimensions] + reference_cell: Original cell vectors used as reference for scaling + cell_positions: Cell positions (effectively the volume) + cell_velocities: Cell velocities (rate of volume change) + cell_masses: Masses associated with the cell degrees of freedom + """ + + # System state variables + energy: torch.Tensor + forces: torch.Tensor + velocities: torch.Tensor + stress: torch.Tensor + + # Cell variables + reference_cell: torch.Tensor + cell_positions: torch.Tensor + cell_velocities: torch.Tensor + cell_masses: torch.Tensor + + @property + def momenta(self) -> torch.Tensor: + """Calculate momenta from velocities and masses.""" + return self.velocities * self.masses.unsqueeze(-1) + + +def npt_langevin( # noqa: C901, PLR0915 + model: torch.nn.Module, + *, + dt: torch.Tensor, + kT: torch.Tensor, + external_pressure: torch.Tensor, + alpha: torch.Tensor | None = None, + cell_alpha: torch.Tensor | None = None, + b_tau: torch.Tensor | None = None, +) -> tuple[ + Callable[[BaseState | StateDict, torch.Tensor], NPTLangevinState], + Callable[[NPTLangevinState, torch.Tensor], NPTLangevinState], +]: + """Initialize and return an NPT (isothermal-isobaric) integrator using + Langevin dynamics. + + This function sets up integration in the NPT ensemble, where particle number (N), + pressure (P), and temperature (T) are conserved. It allows the simulation cell to + fluctuate to maintain the target pressure. + + Args: + model: Neural network model that computes energies, forces, and stress + dt: Integration timestep + kT: Target temperature in energy units + external_pressure: Target pressure to maintain + alpha: Friction coefficient for particle Langevin thermostat + cell_alpha: Friction coefficient for cell Langevin thermostat + b_tau: Barostat time constant + + Returns: + tuple: + - callable: Function to initialize the NPTLangevinState from input data + - callable: Update function that evolves system by one timestep + """ + device = model.device + dtype = model.dtype + + # Set default values if not provided + if alpha is None: + alpha = 1.0 / (100 * dt) # Default friction based on timestep + if cell_alpha is None: + cell_alpha = alpha # Use same friction for cell by default + if b_tau is None: + b_tau = 1 / (1000 * dt) # Default barostat time constant + + # Convert all parameters to tensors with correct device and dtype + if isinstance(alpha, float): + alpha = torch.tensor(alpha, device=device, dtype=dtype) + if isinstance(cell_alpha, float): + cell_alpha = torch.tensor(cell_alpha, device=device, dtype=dtype) + if isinstance(dt, float): + dt = torch.tensor(dt, device=device, dtype=dtype) + if isinstance(kT, float): + kT = torch.tensor(kT, device=device, dtype=dtype) + if isinstance(b_tau, float): + b_tau = torch.tensor(b_tau, device=device, dtype=dtype) + if isinstance(external_pressure, float): + external_pressure = torch.tensor(external_pressure, device=device, dtype=dtype) + + def beta( + state: NPTLangevinState, + alpha: torch.Tensor, + kT: torch.Tensor, + dt: torch.Tensor, + ) -> torch.Tensor: + """Calculate random noise term for Langevin dynamics. + + This function generates the stochastic force term for the Langevin thermostat, + following the fluctuation-dissipation theorem. + + Args: + state: Current NPT state + alpha: Friction coefficient + kT: Temperature in energy units + dt: Integration timestep + + Returns: + torch.Tensor: Random noise term for force calculation + """ + # Generate batch-specific noise with correct shape + noise = torch.randn_like(state.velocities) + + # Calculate the thermal noise amplitude by batch + batch_kT = kT + if kT.ndim == 0: + batch_kT = kT.expand(state.n_batches) + + # Map batch kT to atoms + atom_kT = batch_kT[state.batch] + + # Calculate the prefactor for each atom + # The standard deviation should be sqrt(2*alpha*kB*T*dt) + prefactor = torch.sqrt(2 * alpha * atom_kT * dt) + + return prefactor.unsqueeze(-1) * noise + + def cell_beta( + state: NPTLangevinState, + cell_alpha: torch.Tensor, + kT: torch.Tensor, + dt: torch.Tensor, + ) -> torch.Tensor: + """Generate random noise for cell fluctuations in NPT dynamics. + + This function creates properly scaled random noise for cell dynamics in NPT + simulations, following the fluctuation-dissipation theorem. + + Args: + state: Current NPT state + cell_alpha: Coupling parameter controlling noise strength + kT: System temperature in energy units + dt: Integration timestep + + Returns: + torch.Tensor: Scaled random noise for cell dynamics + """ + # Generate standard normal distribution (zero mean, unit variance) + noise = torch.randn_like(state.cell_positions, device=device, dtype=dtype) + + # Ensure cell_alpha and kT have batch dimension if they're scalars + if cell_alpha.ndim == 0: + cell_alpha = cell_alpha.expand(state.n_batches) + if kT.ndim == 0: + kT = kT.expand(state.n_batches) + + # Reshape for broadcasting + cell_alpha = cell_alpha.view(-1, 1, 1) # shape: (n_batches, 1, 1) + kT = kT.view(-1, 1, 1) # shape: (n_batches, 1, 1) + if dt.ndim == 0: + dt = dt.expand(state.n_batches).view(-1, 1, 1) + else: + dt = dt.view(-1, 1, 1) + + # Scale to satisfy the fluctuation-dissipation theorem + # The standard deviation should be sqrt(2*alpha*kB*T*dt) + scaling_factor = torch.sqrt(2.0 * cell_alpha * kT * dt) + + return scaling_factor * noise + + def compute_cell_force( + state: NPTLangevinState, + external_pressure: torch.Tensor, + kT: torch.Tensor, + ) -> torch.Tensor: + """Compute forces on the cell for NPT dynamics. + + This function calculates the forces acting on the simulation cell + based on the difference between internal stress and external pressure, + plus a kinetic contribution. + + Args: + state: Current NPT state + external_pressure: Target external pressure + kT: Temperature in energy units + + Returns: + torch.Tensor: Force acting on the cell + """ + # Get current volumes for each batch + volumes = torch.linalg.det(state.cell) # shape: (n_batches,) + + # Reshape for broadcasting + volumes = volumes.view(-1, 1, 1) # shape: (n_batches, 1, 1) + + # Create pressure tensor (diagonal with external pressure) + if external_pressure.ndim == 0: + # Scalar pressure - create diagonal pressure tensors for each batch + pressure_tensor = external_pressure * torch.eye(3, device=device, dtype=dtype) + pressure_tensor = pressure_tensor.unsqueeze(0).expand(state.n_batches, -1, -1) + else: + # Already a tensor with shape compatible with n_batches + pressure_tensor = external_pressure + + # Calculate virials from stress and external pressure + # Internal stress is negative of virial tensor divided by volume + virial = -volumes * state.stress + pressure_tensor * volumes + + # Add kinetic contribution (kT * Identity) + batch_kT = kT + if kT.ndim == 0: + batch_kT = kT.expand(state.n_batches) + + kinetic_term = batch_kT.view(-1, 1, 1) * torch.eye( + 3, device=device, dtype=dtype + ).unsqueeze(0) + + return virial + kinetic_term + + def cell_position_step( + state: NPTLangevinState, + dt: torch.Tensor, + pressure_force: torch.Tensor, + kT: torch.Tensor = kT, + cell_alpha: torch.Tensor = cell_alpha, + ) -> NPTLangevinState: + """Update the cell position in NPT dynamics. + + This function updates the cell position in NPT dynamics using the barostat force. + It applies a half-step update to the cell position based on the barostat force. + + Args: + state: Current NPT state + dt: Integration timestep + pressure_force: Pressure force for barostat + kT: Target temperature in energy units + cell_alpha: Cell friction coefficient + + Returns: + NPTLangevinState: Updated state with new cell positions + """ + # Calculate effective mass term + Q_2 = 2 * state.cell_masses.view(-1, 1, 1) # shape: (n_batches, 1, 1) + + # Ensure parameters have batch dimension + if dt.ndim == 0: + dt = dt.expand(state.n_batches) + if cell_alpha.ndim == 0: + cell_alpha = cell_alpha.expand(state.n_batches) + + # Reshape for broadcasting + dt_expanded = dt.view(-1, 1, 1) + cell_alpha_expanded = cell_alpha.view(-1, 1, 1) + + # Calculate damping factor for cell position update + cell_b = 1 / (1 + ((cell_alpha_expanded * dt_expanded) / Q_2)) + + # Deterministic velocity contribution + c_1 = cell_b * dt_expanded * state.cell_velocities + + # Force contribution + c_2 = cell_b * dt_expanded * dt_expanded * pressure_force / Q_2 + + # Random noise contribution (thermal fluctuations) + c_3 = ( + cell_b + * dt_expanded + * cell_beta(state=state, cell_alpha=cell_alpha, kT=kT, dt=dt) + / Q_2 + ) + + # Update cell positions with all contributions + state.cell_positions = state.cell_positions + c_1 + c_2 + c_3 + return state + + def cell_velocity_step( + state: NPTLangevinState, + F_p_n: torch.Tensor, + dt: torch.Tensor, + pressure_force: torch.Tensor, + cell_alpha: torch.Tensor, + kT: torch.Tensor, + ) -> NPTLangevinState: + """Update the cell velocities in NPT dynamics. + + This function updates the cell velocities using a Langevin-type integrator, + accounting for both deterministic forces and stochastic thermal noise. + + Args: + state: Current NPT state + F_p_n: Initial pressure force + dt: Integration timestep + pressure_force: Final pressure force + cell_alpha: Cell friction coefficient + kT: Temperature in energy units + + Returns: + NPTLangevinState: Updated state with new cell velocities + """ + # Ensure parameters have batch dimension + if dt.ndim == 0: + dt = dt.expand(state.n_batches) + if cell_alpha.ndim == 0: + cell_alpha = cell_alpha.expand(state.n_batches) + if kT.ndim == 0: + kT = kT.expand(state.n_batches) + + # Reshape for broadcasting - need to maintain 3x3 dimensions + dt_expanded = dt.view(-1, 1, 1) # shape: (n_batches, 1, 1) + cell_alpha_expanded = cell_alpha.view(-1, 1, 1) # shape: (n_batches, 1, 1) + + # Calculate cell masses per batch - reshape to match 3x3 cell matrices + cell_masses_expanded = state.cell_masses.view( + -1, 1, 1 + ) # shape: (n_batches, 1, 1) + + # These factors come from the Langevin integration scheme + a = (1 - (cell_alpha_expanded * dt_expanded) / cell_masses_expanded) / ( + 1 + (cell_alpha_expanded * dt_expanded) / cell_masses_expanded + ) + b = 1 / (1 + (cell_alpha_expanded * dt_expanded) / cell_masses_expanded) + + # Calculate the three terms for velocity update + # a will broadcast from (n_batches, 1, 1) to (n_batches, 3, 3) + c_1 = a * state.cell_velocities # Damped old velocity + + # Force contribution (average of initial and final forces) + c_2 = dt_expanded * ((a * F_p_n) + pressure_force) / (2 * cell_masses_expanded) + + # Generate batch-specific cell noise with correct shape (n_batches, 3, 3) + cell_noise = torch.randn_like(state.cell_velocities) + + # Calculate thermal noise amplitude + noise_prefactor = torch.sqrt( + 2 * cell_alpha_expanded * kT.view(-1, 1, 1) * dt_expanded + ) + noise_term = noise_prefactor * cell_noise / torch.sqrt(cell_masses_expanded) + + # Random noise contribution + c_3 = b * noise_term + + # Update velocities with all contributions + state.cell_velocities = c_1 + c_2 + c_3 + return state + + def langevin_position_step( + state: NPTLangevinState, + L_n: torch.Tensor, # This should be shape (n_batches,) + dt: torch.Tensor, + kT: torch.Tensor, + ) -> NPTLangevinState: + """Update the particle positions in NPT dynamics. + + This function updates particle positions accounting for both the changing + cell dimensions and the particle velocities/forces. + + Args: + state: Current NPT state + L_n: Previous cell length scale (n_batches,) + dt: Integration timestep + kT: Target temperature in energy units + + Returns: + NPTLangevinState: Updated state with new positions + """ + # Calculate effective mass term by batch + # Map masses to have batch dimension + M_2 = 2 * state.masses.unsqueeze(-1) # shape: (n_atoms, 1) + + # Calculate new cell length scale (cube root of volume for isotropic scaling) + L_n_new = torch.pow( + state.cell_positions.reshape(state.n_batches, -1)[:, 0], 1 / 3 + ) # shape: (n_batches,) + + # Map batch-specific L_n and L_n_new to atom-level using batch indices + # Make sure L_n is the right shape (n_batches,) before indexing + if L_n.ndim != 1 or L_n.shape[0] != state.n_batches: + # If L_n has wrong shape, calculate it again to ensure correct shape + L_n = torch.pow( + state.cell_positions.reshape(state.n_batches, -1)[:, 0], 1 / 3 + ) + + # Map batch values to atoms using batch indices + L_n_atoms = L_n[state.batch] # shape: (n_atoms,) + L_n_new_atoms = L_n_new[state.batch] # shape: (n_atoms,) + + # Calculate damping factor + alpha_atoms = alpha + if alpha.ndim > 0: + alpha_atoms = alpha[state.batch] + dt_atoms = dt + if dt.ndim > 0: + dt_atoms = dt[state.batch] + + b = 1 / (1 + ((alpha_atoms * dt_atoms) / M_2)) + + # Scale positions due to cell volume change + c_1 = (L_n_new_atoms / L_n_atoms).unsqueeze(-1) * state.positions + + # Time step factor with average length scale + c_2 = ( + (2 * L_n_new_atoms / (L_n_new_atoms + L_n_atoms)).unsqueeze(-1) + * b + * dt_atoms.unsqueeze(-1) + ) + + # Generate atom-specific noise + noise = torch.randn_like(state.velocities) + batch_kT = kT + if kT.ndim == 0: + batch_kT = kT.expand(state.n_batches) + atom_kT = batch_kT[state.batch] + + # Calculate noise prefactor according to fluctuation-dissipation theorem + noise_prefactor = torch.sqrt(2 * alpha_atoms * atom_kT * dt_atoms) + noise_term = noise_prefactor.unsqueeze(-1) * noise + + # Velocity and force contributions with random noise + c_3 = ( + state.velocities + + dt_atoms.unsqueeze(-1) * state.forces / M_2 + + noise_term / M_2 + ) + + # Update positions with all contributions + state.positions = c_1 + c_2 * c_3 + + # Apply periodic boundary conditions if needed + if state.pbc: + state.positions = pbc_wrap_batched(state.positions, state.cell, state.batch) + + return state + + def langevin_velocity_step( + state: NPTLangevinState, + forces: torch.Tensor, + dt: torch.Tensor, + kT: torch.Tensor, + ) -> NPTLangevinState: + """Update the particle velocities in NPT dynamics. + + This function updates particle velocities using a Langevin-type integrator, + accounting for both deterministic forces and stochastic thermal noise. + + Args: + state: Current NPT state + forces: Forces on particles + dt: Integration timestep + kT: Target temperature in energy units + + Returns: + NPTLangevinState: Updated state with new velocities + """ + # Calculate denominator for update equations + M_2 = 2 * state.masses.unsqueeze(-1) # shape: (n_atoms, 1) + + # Map batch parameters to atom level + alpha_atoms = alpha + if alpha.ndim > 0: + alpha_atoms = alpha[state.batch] + dt_atoms = dt + if dt.ndim > 0: + dt_atoms = dt[state.batch] + + # Calculate damping factors for Langevin integration + a = (1 - (alpha_atoms * dt_atoms) / M_2) / (1 + (alpha_atoms * dt_atoms) / M_2) + b = 1 / (1 + (alpha_atoms * dt_atoms) / M_2) + + # Velocity contribution with damping + c_1 = a * state.velocities + + # Force contribution (average of initial and final forces) + c_2 = dt_atoms.unsqueeze(-1) * ((a * forces) + state.forces) / M_2 + + # Generate atom-specific noise + noise = torch.randn_like(state.velocities) + batch_kT = kT + if kT.ndim == 0: + batch_kT = kT.expand(state.n_batches) + atom_kT = batch_kT[state.batch] + + # Calculate noise prefactor according to fluctuation-dissipation theorem + noise_prefactor = torch.sqrt(2 * alpha_atoms * atom_kT * dt_atoms) + noise_term = noise_prefactor.unsqueeze(-1) * noise + + # Random noise contribution + c_3 = b * noise_term / state.masses.unsqueeze(-1) + + # Update velocities with all contributions + state.velocities = c_1 + c_2 + c_3 + return state + + def npt_init( + state: BaseState | StateDict, + kT: torch.Tensor = kT, + seed: int | None = None, + **kwargs: Any, + ) -> NPTLangevinState: + """Initialize an NPT Langevin state from input data. + + This function creates the initial state for NPT Langevin dynamics, + setting up all necessary variables including cell parameters. + + Args: + state: Either a BaseState object or a dictionary containing positions, + masses, cell, pbc + kT: Temperature in energy units for initializing momenta + seed: Random seed for reproducibility + **kwargs: Additional state arguments + + Returns: + NPTLangevinState: Initialized state for NPT Langevin integration + """ + if not isinstance(state, BaseState): + state = BaseState(**state) + + # Override with kwargs if provided + atomic_numbers = kwargs.get("atomic_numbers", state.atomic_numbers) + batch = kwargs.get("batch", state.batch) + + # Get model output to initialize forces and stress + model_output = model(state) + + # Initialize momenta if not provided + momenta = getattr( + state, "momenta", calculate_momenta(state.positions, state.masses, kT, seed) + ) + + # Initialize cell parameters + reference_cell = state.cell.clone() + + # Calculate initial cell_positions (volume) + cell_positions = ( + torch.linalg.det(state.cell).unsqueeze(-1).unsqueeze(-1) + ) # shape: (n_batches, 1, 1) + + # Initialize cell velocities to zero + cell_velocities = torch.zeros((state.n_batches, 3, 3), device=device, dtype=dtype) + + # Calculate cell masses based on system size and temperature + # This follows standard NPT barostat mass scaling + n_atoms_per_batch = torch.bincount(state.batch) + batch_kT = ( + kT.expand(state.n_batches) + if isinstance(kT, torch.Tensor) and kT.ndim == 0 + else kT + ) + cell_masses = (n_atoms_per_batch + 1) * batch_kT * b_tau * b_tau + + # Create the initial state + return NPTLangevinState( + positions=state.positions, + velocities=momenta / state.masses.unsqueeze(-1), + energy=model_output["energy"], + forces=model_output["forces"], + stress=model_output["stress"], + masses=state.masses, + cell=state.cell, + pbc=state.pbc, + batch=batch, + atomic_numbers=atomic_numbers, + reference_cell=reference_cell, + cell_positions=cell_positions, + cell_velocities=cell_velocities, + cell_masses=cell_masses, + ) + + def npt_update( + state: NPTLangevinState, + dt: torch.Tensor = dt, + kT: torch.Tensor = kT, + external_pressure: torch.Tensor = external_pressure, + alpha: torch.Tensor = alpha, + cell_alpha: torch.Tensor = cell_alpha, + ) -> NPTLangevinState: + """Perform one complete NPT Langevin dynamics integration step. + + This function implements a modified integration scheme for NPT dynamics, + handling both atomic and cell updates with Langevin thermostats. + + Args: + state: Current NPT state + dt: Integration timestep + kT: Target temperature in energy units + external_pressure: Target external pressure + alpha: Position friction coefficient + cell_alpha: Cell friction coefficient + + Returns: + NPTLangevinState: Updated NPT state after one timestep + """ + # Convert any scalar parameters to tensors with batch dimension if needed + if isinstance(alpha, float): + alpha = torch.tensor(alpha, device=device, dtype=dtype) + if isinstance(kT, float): + kT = torch.tensor(kT, device=device, dtype=dtype) + if isinstance(cell_alpha, float): + cell_alpha = torch.tensor(cell_alpha, device=device, dtype=dtype) + if isinstance(dt, float): + dt = torch.tensor(dt, device=device, dtype=dtype) + + # Make sure parameters have batch dimension if they're scalars + batch_kT = kT.expand(state.n_batches) if kT.ndim == 0 else kT + + # Update barostat mass based on current temperature + # This ensures proper coupling between system and barostat + n_atoms_per_batch = torch.bincount(state.batch) + state.cell_masses = (n_atoms_per_batch + 1) * batch_kT * b_tau * b_tau + + # Compute model output for current state + model_output = model(state) + state.forces = model_output["forces"] + state.stress = model_output["stress"] + + # Store initial values for integration + forces = state.forces + F_p_n = compute_cell_force( + state=state, external_pressure=external_pressure, kT=kT + ) + L_n = torch.pow( + state.cell_positions.reshape(state.n_batches, -1)[:, 0], 1 / 3 + ) # shape: (n_batches,) + + # Step 1: Update cell position + state = cell_position_step(state=state, dt=dt, pressure_force=F_p_n, kT=kT) + + # Update cell (currently only isotropic fluctuations) + dim = state.positions.shape[1] # Usually 3 for 3D + V_0 = torch.linalg.det(state.reference_cell) # shape: (n_batches,) + V = state.cell_positions.reshape(state.n_batches, -1)[:, 0] # shape: (n_batches,) + + # Scale cell uniformly in all dimensions + scaling = (V / V_0) ** (1.0 / dim) # shape: (n_batches,) + + # Apply scaling to reference cell to get new cell + new_cell = torch.zeros_like(state.cell) + for b in range(state.n_batches): + new_cell[b] = scaling[b] * state.reference_cell[b] + + state.cell = new_cell + + # Step 2: Update particle positions + state = langevin_position_step(state=state, L_n=L_n, dt=dt, kT=kT) + + # Recompute model output after position updates + model_output = model(state) + state.energy = model_output["energy"] + state.forces = model_output["forces"] + state.stress = model_output["stress"] + + # Compute updated pressure force + F_p_n_new = compute_cell_force( + state=state, external_pressure=external_pressure, kT=kT + ) + + # Step 3: Update cell velocities + state = cell_velocity_step( + state=state, + F_p_n=F_p_n, + dt=dt, + pressure_force=F_p_n_new, + cell_alpha=cell_alpha, + kT=kT, + ) + + # Step 4: Update particle velocities + state = langevin_velocity_step(state=state, forces=forces, dt=dt, kT=kT) + + return state # noqa: RET504 + + return npt_init, npt_update diff --git a/torch_sim/unbatched/unbatched_integrators.py b/torch_sim/unbatched/unbatched_integrators.py index 080d1497..fc9c8ab7 100644 --- a/torch_sim/unbatched/unbatched_integrators.py +++ b/torch_sim/unbatched/unbatched_integrators.py @@ -146,46 +146,6 @@ def position_step(state: MDState, dt: torch.Tensor) -> MDState: return state -def stochastic_step( - state: MDState, - dt: torch.Tensor, - kT: torch.Tensor, - gamma: torch.Tensor, - device: torch.device, - dtype: torch.dtype, -) -> MDState: - """Apply stochastic noise and friction for Langevin dynamics. - - This function implements the stochastic part of Langevin dynamics by applying - random noise and friction forces to particle momenta. The noise amplitude is - chosen to maintain the target temperature kT. - - Args: - state: Current system state containing positions, momenta, etc. - dt: Integration timestep - kT: Target temperature in energy units - gamma: Friction coefficient controlling noise strength - device: Device to initialize momenta on - dtype: Data type to initialize momenta as - - Returns: - Updated state with new momenta after stochastic step - - Notes: - - Uses Ornstein-Uhlenbeck process for correct thermal sampling - - Noise amplitude scales with sqrt(mass) for equipartition - - Preserves detailed balance through fluctuation-dissipation relation - """ - c1 = torch.exp(-gamma * dt) - c2 = torch.sqrt(kT * (1 - c1**2)) - - # Generate random noise from normal distribution - noise = torch.randn_like(state.momenta, device=device, dtype=dtype) - new_momenta = c1 * state.momenta + c2 * torch.sqrt(state.masses).unsqueeze(-1) * noise - state.momenta = new_momenta - return state - - def velocity_verlet(state: MDState, dt: torch.Tensor, model: torch.nn.Module) -> MDState: """Perform one complete velocity Verlet integration step. @@ -380,6 +340,47 @@ def nvt_langevin( if isinstance(dt, float): dt = torch.tensor(dt, device=device, dtype=dtype) + def ou_step( + state: MDState, + dt: torch.Tensor, + kT: torch.Tensor, + gamma: torch.Tensor, + device: torch.device, + dtype: torch.dtype, + ) -> MDState: + """Apply stochastic noise and friction for Langevin dynamics. + + This function implements the stochastic part of Langevin dynamics by applying + random noise and friction forces to particle momenta. The noise amplitude is + chosen to maintain the target temperature kT. + + Args: + state: Current system state containing positions, momenta, etc. + dt: Integration timestep + kT: Target temperature in energy units + gamma: Friction coefficient controlling noise strength + device: Device to initialize momenta on + dtype: Data type to initialize momenta as + + Returns: + Updated state with new momenta after stochastic step + + Notes: + - Uses Ornstein-Uhlenbeck process for correct thermal sampling + - Noise amplitude scales with sqrt(mass) for equipartition + - Preserves detailed balance through fluctuation-dissipation relation + """ + c1 = torch.exp(-gamma * dt) + c2 = torch.sqrt(kT * (1 - c1**2)) + + # Generate random noise from normal distribution + noise = torch.randn_like(state.momenta, device=device, dtype=dtype) + new_momenta = ( + c1 * state.momenta + c2 * torch.sqrt(state.masses).unsqueeze(-1) * noise + ) + state.momenta = new_momenta + return state + def langevin_init( state: BaseState | StateDict, kT: torch.Tensor = kT, @@ -456,7 +457,7 @@ def langevin_update( state = momentum_step(state, dt / 2) state = position_step(state, dt / 2) - state = stochastic_step(state, dt, kT, gamma, device, dtype) + state = ou_step(state, dt, kT, gamma, device, dtype) state = position_step(state, dt / 2) model_output = model(state) @@ -468,6 +469,567 @@ def langevin_update( return langevin_init, langevin_update +@dataclass +class NPTLangevinState(BaseState): + """State information for an NPT system with Langevin dynamics. + + This class represents the complete state of a molecular system being integrated + in the NPT (constant particle number, pressure, temperature) ensemble using + Langevin dynamics. + + Attributes: + energy: Total energy of the system + forces: Forces acting on each particle + stress: Stress tensor of the system + velocities: Velocities of each particle + reference_cell: Original cell vectors used as reference for scaling + cell_positions: Cell positions (effectively the volume) + cell_velocities: Cell velocities (rate of volume change) + cell_masses: Masses associated with the cell degrees of freedom + """ + + # System state variables + energy: torch.Tensor + forces: torch.Tensor + stress: torch.Tensor + velocities: torch.Tensor + + # Cell variables + reference_cell: torch.Tensor + cell_positions: torch.Tensor + cell_velocities: torch.Tensor + cell_masses: torch.Tensor + + @property + def momenta(self) -> torch.Tensor: + """Calculate momenta from velocities and masses.""" + return self.masses.unsqueeze(-1) * self.velocities + + +def npt_langevin( # noqa: C901, PLR0915 + *, + model: torch.nn.Module, + dt: torch.Tensor, + kT: torch.Tensor, + external_pressure: torch.Tensor, + alpha: torch.Tensor | None = None, + cell_alpha: torch.Tensor | None = None, + b_tau: torch.Tensor | None = None, +) -> tuple[ + Callable[[BaseState | StateDict, torch.Tensor], MDState], + Callable[[MDState, torch.Tensor], MDState], +]: + """Initialize and return an NPT (canonical) integrator using Langevin dynamics. + + This function sets up integration in the NPT ensemble, where particle number (N), + pressure (P), and temperature (T) are conserved. It returns both an initialization + function and an update function for time evolution. + + Args: + model: Neural network model that computes energies and forces + dt: Integration timestep + kT: Target temperature in energy units + external_pressure: Target pressure for the system + alpha: Friction coefficient for position updates (default: 1/(100*dt)) + cell_alpha: Friction coefficient for cell updates (default: 1/(100*dt)) + b_tau: Pressure damping parameter (default: 1/(1000*dt)) + + Returns: + tuple: + - Callable[[BaseState | StateDict, torch.Tensor], MDState]: Function to + initialize the MDState from input data and kT + - Callable[[MDState, torch.Tensor], MDState]: Update function that evolves + system by one timestep + """ + device = model.device + dtype = model.dtype + + # Set default values for coupling parameters if not provided + alpha = alpha or 1 / (100 * dt) + cell_alpha = cell_alpha or 1 / (100 * dt) + b_tau = b_tau or 1 / (1000 * dt) + + # Convert float parameters to tensors with appropriate device and dtype + if isinstance(alpha, float): + alpha = torch.tensor(alpha, device=device, dtype=dtype) + + if isinstance(cell_alpha, float): + cell_alpha = torch.tensor(cell_alpha, device=device, dtype=dtype) + + if isinstance(b_tau, float): + b_tau = torch.tensor(b_tau, device=device, dtype=dtype) + + if isinstance(dt, float): + dt = torch.tensor(dt, device=device, dtype=dtype) + + def cell_beta( + state: NPTLangevinState, + cell_alpha: torch.Tensor, + kT: torch.Tensor, + dt: torch.Tensor, + device: torch.device = device, + dtype: torch.dtype = dtype, + ) -> torch.Tensor: + """Generate random noise for cell fluctuations in NPT dynamics. + + This function creates properly scaled random noise for cell dynamics in NPT + simulations, following the fluctuation-dissipation theorem. The noise amplitude + is scaled to maintain the target temperature. + + Args: + state: Current NPT state + cell_alpha: Coupling parameter controlling noise strength + kT: System temperature in energy units + dt: Integration timestep + device: Device to place the tensor on (CPU or GPU) + dtype: Data type for the tensor + + Returns: + torch.Tensor: Scaled random noise for cell dynamics + """ + # Generate standard normal distribution (zero mean, unit variance) + noise = torch.randn_like(state.cell_positions, device=device, dtype=dtype) + + # Scale to satisfy the fluctuation-dissipation theorem + # The standard deviation should be sqrt(2*alpha*kB*T*dt) + scaling_factor = torch.sqrt(2.0 * cell_alpha * kT * dt) + + return scaling_factor * noise + + def beta( + state: NPTLangevinState, + alpha: torch.Tensor, + kT: torch.Tensor, + dt: torch.Tensor, + device: torch.device = device, + dtype: torch.dtype = dtype, + ) -> torch.Tensor: + """Generate random noise for particle fluctuations in NPT dynamics. + + This function creates properly scaled random noise for particle dynamics in NPT + simulations, following the fluctuation-dissipation theorem. The noise amplitude + is scaled to maintain the target temperature. + + Args: + state: Current NPT state + alpha: Coupling parameter controlling noise strength + kT: System temperature in energy units + dt: Integration timestep + device: Device to place the tensor on (CPU or GPU) + dtype: Data type for the tensor + + Returns: + torch.Tensor: Scaled random noise for particle dynamics + """ + # Generate standard normal distribution (zero mean, unit variance) + noise = torch.randn_like(state.positions, device=device, dtype=dtype) + + # Scale to satisfy the fluctuation-dissipation theorem + # The standard deviation should be sqrt(2*alpha*kB*T*dt) + scaling_factor = torch.sqrt(2.0 * alpha * kT * dt) + + return scaling_factor * noise + + def cell_position_step( + state: NPTLangevinState, + dt: torch.Tensor, + pressure_force: torch.Tensor, + kT: torch.Tensor = kT, + cell_alpha: torch.Tensor = cell_alpha, + ) -> NPTLangevinState: + """Update the cell position in NPT dynamics. + + This function updates the cell position in NPT dynamics using the barostat force. + It applies a half-step update to the cell position based on the barostat force. + + Args: + state: Current NPT state + dt: Integration timestep + pressure_force: Pressure force for barostat + kT: Target temperature in energy units + cell_alpha: Cell friction coefficient + + Returns: + NPTLangevinState: Updated state with new cell positions + """ + # Calculate effective mass term + Q_2 = 2 * state.cell_masses + + # Calculate damping factor for cell position update + cell_b = 1 / (1 + ((cell_alpha * dt) / Q_2)) + + # Deterministic velocity contribution + c_1 = cell_b * dt * state.cell_velocities + + # Force contribution + c_2 = cell_b * dt * dt * pressure_force / Q_2 + + # Random noise contribution (thermal fluctuations) + c_3 = ( + cell_b + * dt + * cell_beta(state=state, cell_alpha=cell_alpha, kT=kT, dt=dt) + / Q_2 + ) + + # Update cell positions with all contributions + state.cell_positions = state.cell_positions + c_1 + c_2 + c_3 + return state + + def cell_velocity_step( + state: NPTLangevinState, + F_p_n: torch.Tensor, + dt: torch.Tensor, + pressure_force: torch.Tensor, + cell_alpha: torch.Tensor, + kT: torch.Tensor = kT, + ) -> NPTLangevinState: + """Update the cell momentum in NPT dynamics. + + This function updates the cell velocities based on the pressure forces and + thermal fluctuations, following the Langevin dynamics equations. + + Args: + state: Current NPT state + F_p_n: Previous pressure force + dt: Integration timestep + pressure_force: Updated pressure force + cell_alpha: Cell friction coefficient + kT: Target temperature in energy units + + Returns: + NPTLangevinState: Updated state with new cell velocities + """ + # Calculate denominator for update equations + Q_2 = 2 * state.cell_masses + + # Calculate damping factors + cell_a = (1 - (cell_alpha * dt) / Q_2) / (1 + (cell_alpha * dt) / Q_2) + cell_b = 1 / (1 + (cell_alpha * dt) / Q_2) + + # Deterministic velocity contribution + c_1 = cell_alpha * state.cell_velocities + + # Force contribution (average of initial and final forces) + c_2 = dt / Q_2 * (cell_a * F_p_n + pressure_force) + + # Random noise contribution (thermal fluctuations) + c_3 = ( + cell_b + * cell_beta(state=state, cell_alpha=cell_alpha, kT=kT, dt=dt) + / state.cell_masses + ) + + # Update cell velocities with all contributions + state.cell_velocities = c_1 + c_2 + c_3 + return state + + def compute_cell_force( + state: NPTLangevinState, + external_pressure: torch.Tensor, + kT: torch.Tensor, + ) -> torch.Tensor: + """Compute the cell force in NPT dynamics. + + This function calculates the force on the cell based on the difference between + the internal stress and the external pressure. + + Args: + state: Current NPT state + external_pressure: Target external pressure + kT: System temperature in energy units + + Returns: + torch.Tensor: Computed cell force for barostat + """ + # Kinetic contribution + N = state.n_atoms + volume = state.cell_positions + KE_cell = N * kT + + # Calculate internal pressure from stress tensor + # (average of diagonal elements for isotropic pressure) + internal_pressure = torch.trace(state.stress) / state.positions.shape[1] + + # Force is proportional to pressure difference + # F = V * (P_internal - P_external) + kinetic contribution + return KE_cell - (internal_pressure * volume) - (external_pressure * volume) + + def langevin_position_step( + state: NPTLangevinState, + L_n: torch.Tensor, + dt: torch.Tensor, + kT: torch.Tensor, + ) -> NPTLangevinState: + """Update the particle positions in NPT dynamics. + + This function updates the particle positions in NPT dynamics, accounting for + both the forces on particles and the cell volume changes. + + Args: + state: Current NPT state + L_n: Previous cell length scale + dt: Integration timestep + kT: Target temperature in energy units + + Returns: + NPTLangevinState: Updated state with new positions + """ + # Calculate effective mass term + M_2 = 2 * state.masses.unsqueeze(-1) + + # Calculate new cell length scale (cube root of volume for isotropic scaling) + L_n_new = torch.pow(state.cell_positions, 1 / 3) + + # Calculate damping factor + b = 1 / (1 + ((alpha * dt) / M_2)) + + # Scale positions due to cell volume change + c_1 = (L_n_new / L_n) * state.positions + + # Time step factor with average length scale + c_2 = (2 * L_n_new / (L_n_new + L_n)) * b * dt + + # Velocity and force contributions with random noise + c_3 = ( + state.velocities + + dt * state.forces / (M_2) + + 1 / (M_2) * beta(state=state, alpha=alpha, kT=kT, dt=dt) + ) + + # Update positions with all contributions + state.positions = c_1 + c_2 * c_3 + + # Apply periodic boundary conditions if needed + if state.pbc: + new_positions = pbc_wrap_general( + positions=state.positions, lattice_vectors=state.cell + ) + state.positions = new_positions + + return state + + def langevin_velocity_step( + state: NPTLangevinState, + forces: torch.Tensor, + dt: torch.Tensor, + kT: torch.Tensor, + device: torch.device = device, + dtype: torch.dtype = dtype, + ) -> NPTLangevinState: + """Update the particle velocities in NPT dynamics. + + This function updates the particle velocities based on the forces and + thermal fluctuations, following the Langevin dynamics equations. + + Args: + state: Current NPT state + forces: Forces on particles + dt: Integration timestep + kT: Target temperature in energy units + device: Device to place the tensor on (CPU or GPU) + dtype: Data type for the tensor + + Returns: + NPTLangevinState: Updated state with new velocities + """ + # Calculate denominator for update equations + M_2 = 2 * state.masses.unsqueeze(-1) + + # Calculate damping factors + a = (1 - (alpha * dt) / M_2) / (1 + (alpha * dt) / M_2) + b = 1 / (1 + (alpha * dt) / M_2) + + # Velocity contribution with damping + c_1 = a * state.velocities + + # Force contribution (average of initial and final forces) + c_2 = dt * ((a * forces) + state.forces) / M_2 + + # Random noise contribution (thermal fluctuations) + c_3 = ( + b + * beta(state=state, alpha=alpha, kT=kT, dt=dt, device=device, dtype=dtype) + / state.masses.unsqueeze(-1) + ) + + # Update velocities with all contributions + state.velocities = c_1 + c_2 + c_3 + return state + + def npt_init( + state: BaseState | StateDict, + kT: torch.Tensor = kT, + device: torch.device = device, + dtype: torch.dtype = dtype, + seed: int | None = None, + **kwargs: Any, + ) -> MDState: + """Initialize an NPT state from input data. + + This function creates an initial NPT state from the provided base state or + state dictionary, initializing all necessary variables for NPT simulation. + + Args: + state: Either a BaseState object or a dictionary containing positions, + momenta, cell, pbc + kT: Target temperature in energy units + device: Device to place the tensor on (CPU or GPU) + dtype: Data type for the tensor + seed: Random seed for reproducibility + **kwargs: Additional state arguments + + Returns: + MDState: Initialized state for NPT integration + """ + # Convert dictionary to BaseState if needed + if not isinstance(state, BaseState): + state = BaseState(**state) + + # Get atomic numbers from kwargs or state + atomic_numbers = kwargs.get("atomic_numbers", state.atomic_numbers) + + # Compute initial energy, forces, and stress + model_output = model(state) + + # Initialize cell variables + # Cell position is the volume + cell_positions = torch.linalg.det(state.cell) + # Initial cell velocity is zero + cell_velocities = torch.zeros_like(cell_positions) + # Cell mass depends on system size, temperature and barostat time constant + cell_masses = (state.n_atoms + 1) * kT * b_tau * b_tau + + # Initialize momenta (from kwargs or calculated) + momenta = kwargs.get( + "momenta", + calculate_momenta(state.positions, state.masses, kT, device, dtype, seed), + ) + + # Create and return the NPT state + return NPTLangevinState( + positions=state.positions, + velocities=momenta / state.masses.unsqueeze(-1), + cell=state.cell, + pbc=state.pbc, + masses=state.masses, + energy=model_output["energy"], + forces=model_output["forces"], + stress=model_output["stress"], + reference_cell=state.cell.clone(), + cell_positions=cell_positions, + cell_velocities=cell_velocities, + cell_masses=cell_masses, + atomic_numbers=atomic_numbers, + ) + + def npt_update( + state: NPTLangevinState, + dt: torch.Tensor = dt, + kT: torch.Tensor = kT, + external_pressure: torch.Tensor = external_pressure, + alpha: torch.Tensor = alpha, + cell_alpha: torch.Tensor = cell_alpha, + ) -> NPTLangevinState: + """Update the NPT state for one timestep. + + This function performs a single timestep of NPT integration for the given state. + It includes both position and cell updates, following a modified BAOAB scheme + adapted for NPT dynamics. + + Args: + state: Current NPT state + dt: Integration timestep + kT: Target temperature in energy units + external_pressure: Target external pressure + alpha: Position friction coefficient + cell_alpha: Cell friction coefficient + + Returns: + NPTLangevinState: Updated NPT state after one timestep + """ + # Convert float parameters to tensors if needed + if isinstance(alpha, float): + alpha = torch.tensor(alpha, device=device, dtype=dtype) + + if isinstance(kT, float): + kT = torch.tensor(kT, device=device, dtype=dtype) + + if isinstance(cell_alpha, float): + cell_alpha = torch.tensor(cell_alpha, device=device, dtype=dtype) + + if isinstance(dt, float): + dt = torch.tensor(dt, device=device, dtype=dtype) + + # Update barostat mass based on current temperature + # This ensures proper coupling as temperature changes + state.cell_masses = (state.n_atoms + 1) * kT * b_tau * b_tau + + # Compute model output for current state + model_output = model(state) + state.forces = model_output["forces"] + state.stress = model_output["stress"] + + # Store initial values for integration + forces = state.forces + F_p_n = compute_cell_force( + state=state, external_pressure=external_pressure, kT=kT + ) + L_n = torch.pow(state.cell_positions, 1 / 3) # Current length scale + + # Step 1: Update cell position + state = cell_position_step(state=state, dt=dt, pressure_force=F_p_n, kT=kT) + + # Update cell (currently only isotropic fluctuations) + dim = state.positions.shape[1] + V_0 = torch.det(state.reference_cell) + V = state.cell_positions + + # Scale cell uniformly in all dimensions + new_cell = (V / V_0) ** (1.0 / dim) * state.reference_cell + state.cell = new_cell + + # NOTE: Better to scale each dimension independently? + # state.cell = torch.tensor([[L_x, 0, 0], + # [0, L_x, 0], + # [0, 0, L_x]], device=device, dtype=dtype) + + # Step 2: Update particle positions + state = langevin_position_step(state=state, L_n=L_n, dt=dt, kT=kT) + + # state.positions = state.positions + dt * state.velocities + + # Recompute model output after position updates + model_output = model(state) + state.energy = model_output["energy"] + state.forces = model_output["forces"] + state.stress = model_output["stress"] + + # Compute updated pressure force + F_p_n_new = compute_cell_force( + state=state, external_pressure=external_pressure, kT=kT + ) + + # Step 3: Update cell velocities + state = cell_velocity_step( + state=state, + F_p_n=F_p_n, + dt=dt, + pressure_force=F_p_n_new, + cell_alpha=cell_alpha, + kT=kT, + ) + + # Step 4: Update particle velocities + state = langevin_velocity_step(state=state, forces=forces, dt=dt, kT=kT) + + # state.velocities = state.velocities + dt * forces / state.masses.unsqueeze(-1) + + return state # noqa: RET504 + + return npt_init, npt_update + + @dataclass class NoseHooverChain: """State information for a Nose-Hoover chain thermostat.