From d41e5e9083e76baee8d6008bc082a4b1ed2c1aa2 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Tue, 8 Apr 2025 13:15:48 -0700 Subject: [PATCH 01/13] allow different batches to be run at different temperatures in nvt, test functionality --- tests/test_integrators.py | 91 ++++++++++++++++++++++++++++++++++++++- torch_sim/integrators.py | 18 ++++++-- 2 files changed, 103 insertions(+), 6 deletions(-) diff --git a/tests/test_integrators.py b/tests/test_integrators.py index 26d22d73..e22d551a 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -233,11 +233,56 @@ def test_npt_langevin(ar_double_sim_state: SimState, lj_model: LennardJonesModel assert pos_diff > 0.0001 # Systems should remain separated + +def test_npt_langevin_multi_kt(ar_double_sim_state: SimState, lj_model: LennardJonesModel): + dtype = torch.float64 + n_steps = 200 + dt = torch.tensor(0.001, dtype=dtype) + kT = torch.tensor([300.0, 10000.0], dtype=dtype) * MetalUnits.temperature + external_pressure = torch.tensor(0.0, dtype=dtype) * MetalUnits.pressure + + # Initialize integrator + init_fn, update_fn = npt_langevin( + model=lj_model, + dt=dt, + kT=kT, + external_pressure=external_pressure, + alpha=40 * dt, + ) + + # Run dynamics for several steps + state = init_fn(state=ar_double_sim_state, seed=42) + energies = [] + temperatures = [] + for _step in range(n_steps): + state = update_fn(state=state) + + # Calculate instantaneous temperature from kinetic energy + temp = calc_kT(state.momenta, state.masses, batch=state.batch) + energies.append(state.energy) + temperatures.append(temp / MetalUnits.temperature) + + # Convert temperatures list to tensor + temperatures_tensor = torch.stack(temperatures) + temperatures_list = [t.tolist() for t in temperatures_tensor.T] + + energies_tensor = torch.stack(energies) + energies_list = [t.tolist() for t in energies_tensor.T] + + # Basic sanity checks + assert len(energies_list[0]) == n_steps + assert len(temperatures_list[0]) == n_steps + + # Check temperature is roughly maintained for each trajectory + mean_temps = torch.mean(temperatures_tensor, dim=0) # Mean temp for each trajectory + assert torch.allclose(mean_temps, kT / MetalUnits.temperature, rtol=0.5) + + def test_nvt_langevin(ar_double_sim_state: SimState, lj_model: LennardJonesModel): dtype = torch.float64 n_steps = 100 dt = torch.tensor(0.001, dtype=dtype) - kT = torch.tensor(100.0, dtype=dtype) * MetalUnits.temperature + kT = torch.tensor(300, dtype=dtype) * MetalUnits.temperature # Initialize integrator init_fn, update_fn = nvt_langevin( @@ -273,7 +318,7 @@ def test_nvt_langevin(ar_double_sim_state: SimState, lj_model: LennardJonesModel mean_temps = torch.mean(temperatures_tensor, dim=0) # Mean temp for each trajectory for mean_temp in mean_temps: assert ( - abs(mean_temp - kT.item() / MetalUnits.temperature) < 150.0 + abs(mean_temp - kT.item() / MetalUnits.temperature) < 100.0 ) # Allow for thermal fluctuations # Check energy is stable for each trajectory @@ -291,6 +336,48 @@ def test_nvt_langevin(ar_double_sim_state: SimState, lj_model: LennardJonesModel assert pos_diff > 0.0001 # Systems should remain separated + +def test_nvt_langevin_multi_kt(ar_double_sim_state: SimState, lj_model: LennardJonesModel): + dtype = torch.float64 + n_steps = 200 + dt = torch.tensor(0.001, dtype=dtype) + kT = torch.tensor([300.0, 10000.0], dtype=dtype) * MetalUnits.temperature + + # Initialize integrator + init_fn, update_fn = nvt_langevin( + model=lj_model, + dt=dt, + kT=kT, + ) + + # Run dynamics for several steps + state = init_fn(state=ar_double_sim_state, seed=42) + energies = [] + temperatures = [] + for _step in range(n_steps): + state = update_fn(state=state) + + # Calculate instantaneous temperature from kinetic energy + temp = calc_kT(state.momenta, state.masses, batch=state.batch) + energies.append(state.energy) + temperatures.append(temp / MetalUnits.temperature) + + # Convert temperatures list to tensor + temperatures_tensor = torch.stack(temperatures) + temperatures_list = [t.tolist() for t in temperatures_tensor.T] + + energies_tensor = torch.stack(energies) + energies_list = [t.tolist() for t in energies_tensor.T] + + # Basic sanity checks + assert len(energies_list[0]) == n_steps + assert len(temperatures_list[0]) == n_steps + + # Check temperature is roughly maintained for each trajectory + mean_temps = torch.mean(temperatures_tensor, dim=0) # Mean temp for each trajectory + assert torch.allclose(mean_temps, kT / MetalUnits.temperature, rtol=0.5) + + def test_nve(ar_double_sim_state: SimState, lj_model: LennardJonesModel): dtype = torch.float64 n_steps = 100 diff --git a/torch_sim/integrators.py b/torch_sim/integrators.py index 746665e2..b75d8372 100644 --- a/torch_sim/integrators.py +++ b/torch_sim/integrators.py @@ -131,6 +131,7 @@ def batched_initialize_momenta( def calculate_momenta( positions: torch.Tensor, masses: torch.Tensor, + batch: torch.Tensor, kT: torch.Tensor, seed: int | None = None, ) -> torch.Tensor: @@ -156,6 +157,10 @@ def calculate_momenta( if seed is not None: generator.manual_seed(seed) + if len(kT.shape) > 0: + # kT is a tensor with shape (n_batches,) + kT = kT[batch] + # Generate random momenta from normal distribution momenta = torch.randn( positions.shape, device=device, dtype=dtype, generator=generator @@ -287,7 +292,7 @@ def nve_init( model_output = model(state) momenta = getattr( - state, "momenta", calculate_momenta(state.positions, state.masses, kT, seed) + state, "momenta", calculate_momenta(state.positions, state.masses, state.batch, kT, seed) ) initial_state = MDState( @@ -431,7 +436,12 @@ def ou_step( where c1 = exp(-gamma*dt) and c2 = sqrt(kT*(1-c1²)) """ c1 = torch.exp(-gamma * dt) - c2 = torch.sqrt(kT * (1 - c1**2)) + + if len(kT.shape) > 0: + # kT is a tensor with shape (n_batches,) + kT = kT[state.batch] + + c2 = torch.sqrt(kT * (1 - c1**2)).unsqueeze(-1) # Generate random noise from normal distribution noise = torch.randn_like(state.momenta, device=state.device, dtype=state.dtype) @@ -474,7 +484,7 @@ def langevin_init( model_output = model(state) momenta = getattr( - state, "momenta", calculate_momenta(state.positions, state.masses, kT, seed) + state, "momenta", calculate_momenta(state.positions, state.masses, state.batch, kT, seed) ) initial_state = MDState( @@ -1127,7 +1137,7 @@ def npt_init( # Initialize momenta if not provided momenta = getattr( - state, "momenta", calculate_momenta(state.positions, state.masses, kT, seed) + state, "momenta", calculate_momenta(state.positions, state.masses, state.batch, kT, seed) ) # Initialize cell parameters From c66ff96e2e662309b6e1115c22d5d0d3805e1adc Mon Sep 17 00:00:00 2001 From: orionarcher Date: Thu, 10 Apr 2025 07:52:47 -0700 Subject: [PATCH 02/13] add zeroed COM motion to calculate_momenta --- tests/test_integrators.py | 209 ++++++++++---------------------------- torch_sim/integrators.py | 82 ++++----------- 2 files changed, 74 insertions(+), 217 deletions(-) diff --git a/tests/test_integrators.py b/tests/test_integrators.py index e22d551a..be9661fb 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -3,8 +3,7 @@ import torch from torch_sim.integrators import ( - MDState, - batched_initialize_momenta, + calculate_momenta, npt_langevin, nve, nvt_langevin, @@ -15,161 +14,65 @@ from torch_sim.units import MetalUnits -def batched_initialize_momenta_loop( - positions: torch.Tensor, # shape: (n_batches, n_atoms_per_batch, 3) - masses: torch.Tensor, # shape: (n_batches, n_atoms_per_batch) - kT: torch.Tensor, # shape: (n_batches,) - seeds: torch.Tensor, # shape: (n_batches,) - device: torch.device, - dtype: torch.dtype, -) -> torch.Tensor: - """Initialize momenta for batched molecular dynamics. - - Args: - positions: Tensor of atomic positions with shape - (n_batches, n_atoms_per_batch, 3). - masses: Tensor of atomic masses with shape - (n_batches, n_atoms_per_batch). - kT: Tensor of temperature values in energy units for - each batch with shape (n_batches,). - seeds: Tensor of random seeds for each batch with shape (n_batches,). - device: The device on which to allocate the tensors (e.g., 'cpu' or 'cuda'). - dtype: The data type of the tensors (e.g., torch.float32). - - Returns: - torch.Tensor: initialized momenta with shape (n_batches, n_atoms_per_batch, 3). - """ - n_batches = positions.shape[0] - n_atoms_per_batch = positions.shape[1] - - # Initialize momenta tensor - momenta = torch.zeros((n_batches, n_atoms_per_batch, 3), dtype=dtype) - - # Create a generator for each batch using the provided seeds - generators = [torch.Generator(device=device).manual_seed(int(seed)) for seed in seeds] - - # Generate random momenta for each batch - for batch_idx in range(n_batches): - # Generate random velocities from normal distribution - batch_momenta = torch.randn( - (n_atoms_per_batch, 3), dtype=dtype, generator=generators[batch_idx] - ) - - # Scale by sqrt(mass * kT) - mass_factors = torch.sqrt(masses[batch_idx]).unsqueeze(-1) - kT_factor = torch.sqrt(kT[batch_idx]) - batch_momenta *= mass_factors * kT_factor - - # Remove center of mass motion if more than one atom - if n_atoms_per_batch > 1: - mean_momentum = torch.mean(batch_momenta, dim=0, keepdim=True) - batch_momenta = batch_momenta - mean_momentum - - momenta[batch_idx] = batch_momenta - - return momenta - - -def test_batched_initialize_momenta_loop(): - from torch_sim.unbatched.unbatched_integrators import initialize_momenta - - # Set random seed for reproducibility +def test_calculate_momenta_basic(device: torch.device): + """Test basic functionality of calculate_momenta.""" seed = 42 - - device = torch.device("cpu") dtype = torch.float64 - - n_batches = 3 - n_atoms_per_batch = 4 - - # Create test inputs - positions = torch.randn(n_batches, n_atoms_per_batch, 3, dtype=dtype) - masses = torch.rand(n_batches, n_atoms_per_batch, dtype=dtype) + 0.5 - kT = torch.tensor([0.1, 0.2, 0.3], dtype=dtype) - seeds = torch.arange(seed, seed + n_batches, dtype=torch.int64) - - # Run non-batched version first - unbatched_momenta = [] - for batch_idx in range(n_batches): - state = MDState( - positions=positions[batch_idx], - momenta=torch.zeros_like(positions[batch_idx]), - masses=masses[batch_idx], - forces=torch.zeros_like(positions[batch_idx]), - energy=torch.zeros(1, dtype=dtype), - atomic_numbers=torch.ones(n_atoms_per_batch, dtype=torch.int64), - cell=torch.eye(3, dtype=dtype), - pbc=False, - ) - state = initialize_momenta( - state, kT[batch_idx], device, dtype, seed=int(seeds[batch_idx]) - ) - unbatched_momenta.append(state.momenta) - - unbatched_momenta = torch.stack(unbatched_momenta) - - # Run batched version - batched_momenta = batched_initialize_momenta_loop( - positions=positions, - masses=masses, - kT=kT, - seeds=seeds, # seeds before device and dtype - device=device, - dtype=dtype, - ) - - assert torch.allclose(batched_momenta, unbatched_momenta, rtol=1e-6) - - -def test_batched_initialize_momenta(): - from torch_sim.unbatched.unbatched_integrators import initialize_momenta - + + # Create test inputs for 3 batches with 2 atoms each + n_atoms = 8 + positions = torch.randn(n_atoms, 3, dtype=dtype, device=device) + masses = torch.rand(n_atoms, dtype=dtype, device=device) + 0.5 + batch = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], device=device) # 3 batches with 2 atoms each + kT = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=dtype, device=device) + + # Run the function + momenta = calculate_momenta(positions, masses, batch, kT, seed=seed) + + # Basic checks + assert momenta.shape == positions.shape + assert momenta.dtype == dtype + assert momenta.device == device + + # Check that each batch has zero center of mass momentum + for b in range(4): + batch_mask = (batch == b) + batch_momenta = momenta[batch_mask] + com_momentum = torch.mean(batch_momenta, dim=0) + assert torch.allclose(com_momentum, torch.zeros(3, dtype=dtype, device=device), atol=1e-10) + + +def test_calculate_momenta_single_atoms(device: torch.device): + """Test that calculate_momenta preserves momentum for batches with single atoms.""" seed = 42 - device = torch.device("cpu") dtype = torch.float64 - - n_batches = 3 - n_atoms_per_batch = 4 - - # Create test inputs - positions = torch.randn(n_batches, n_atoms_per_batch, 3, dtype=dtype) - masses = torch.rand(n_batches, n_atoms_per_batch, dtype=dtype) + 0.5 - kT = torch.tensor([0.1, 0.2, 0.3], dtype=dtype) - seeds = torch.arange(seed, seed + n_batches, dtype=torch.int64) - - # Run non-batched version first - unbatched_momenta = [] - for batch_idx in range(n_batches): - # Use corresponding seed for each batch - - state = MDState( - positions=positions[batch_idx], - momenta=torch.zeros_like(positions[batch_idx]), - masses=masses[batch_idx], - forces=torch.zeros_like(positions[batch_idx]), - energy=torch.zeros(1, dtype=dtype), - atomic_numbers=torch.ones(n_atoms_per_batch, dtype=torch.int64), - cell=torch.eye(3, dtype=dtype), - pbc=False, - ) - state = initialize_momenta( - state, kT[batch_idx], device, dtype, seed=int(seeds[batch_idx]) - ) - unbatched_momenta.append(state.momenta) - - unbatched_momenta = torch.stack(unbatched_momenta) - - # Run batched version - batched_momenta = batched_initialize_momenta( - positions, masses, kT, seeds, device, dtype - ) - - assert torch.allclose(batched_momenta, unbatched_momenta, rtol=1e-6) - - # Verify center of mass momentum is zero for each batch - for batch_idx in range(n_batches): - com_momentum = torch.mean(batched_momenta[batch_idx], dim=0) - assert torch.allclose(com_momentum, torch.zeros(3, dtype=dtype), atol=1e-10) + + # Create test inputs with some batches having single atoms + positions = torch.randn(5, 3, dtype=dtype, device=device) + masses = torch.rand(5, dtype=dtype, device=device) + 0.5 + batch = torch.tensor([0, 1, 1, 2, 3], device=device) # Batches 0, 2, and 3 have single atoms + kT = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=dtype, device=device) + + # Generate momenta and save the raw values before COM correction + generator = torch.Generator(device=device).manual_seed(seed) + raw_momenta = torch.randn(positions.shape, device=device, dtype=dtype, generator=generator) * \ + torch.sqrt(masses * kT[batch]).unsqueeze(-1) + + # Run the function + momenta = calculate_momenta(positions, masses, batch, kT, seed=seed) + + # Check that single-atom batches have unchanged momenta + for b in [0, 2, 3]: # Single atom batches + batch_mask = (batch == b) + # The momentum should be exactly the same as the raw value for single atoms + assert torch.allclose(momenta[batch_mask], raw_momenta[batch_mask]) + + # Check that multi-atom batches have zero COM + for b in [1]: # Multi-atom batches + batch_mask = (batch == b) + batch_momenta = momenta[batch_mask] + com_momentum = torch.mean(batch_momenta, dim=0) + assert torch.allclose(com_momentum, torch.zeros(3, dtype=dtype, device=device), atol=1e-10) def test_npt_langevin(ar_double_sim_state: SimState, lj_model: LennardJonesModel): diff --git a/torch_sim/integrators.py b/torch_sim/integrators.py index b75d8372..8745386c 100644 --- a/torch_sim/integrators.py +++ b/torch_sim/integrators.py @@ -69,65 +69,6 @@ def velocities(self) -> torch.Tensor: return self.momenta / self.masses.unsqueeze(-1) -def batched_initialize_momenta( - positions: torch.Tensor, # shape: (n_batches, n_atoms_per_batch, 3) - masses: torch.Tensor, # shape: (n_batches, n_atoms_per_batch) - kT: torch.Tensor, # shape: (n_batches,) - seeds: torch.Tensor, # shape: (n_batches,) - device: torch.device, - dtype: torch.dtype, -) -> torch.Tensor: - """Initialize momenta for batched molecular dynamics simulations. - - Generates random momenta following the Maxwell-Boltzmann distribution at the - specified temperature for each batch. The center of mass motion is removed - for each batch with more than one atom to prevent system drift. - - Args: - positions (torch.Tensor): Atomic positions [n_batches, n_atoms_per_batch, 3] - masses (torch.Tensor): Atomic masses [n_batches, n_atoms_per_batch] - kT (torch.Tensor): Temperature in energy units [n_batches] - seeds (torch.Tensor): Random seeds [n_batches] - device (torch.device): Torch device for tensor operations - dtype (torch.dtype): Torch data type for tensor precision - - Returns: - torch.Tensor: Random momenta [n_batches, n_atoms_per_batch, 3] - scaled to the specified temperature - """ - n_atoms_per_batch = positions.shape[1] - - # Create a generator for each batch using the provided seeds - generators = [torch.Generator(device=device).manual_seed(int(seed)) for seed in seeds] - - # Generate random momenta for all batches at once - momenta = torch.stack( - [ - torch.randn((n_atoms_per_batch, 3), device=device, dtype=dtype, generator=gen) - for gen in generators - ] - ) - - # Scale by sqrt(mass * kT) - mass_factors = torch.sqrt(masses).unsqueeze(-1) # shape: (n_batches, n_atoms, 1) - kT_factors = torch.sqrt(kT).view(-1, 1, 1) # shape: (n_batches, 1, 1) - momenta *= mass_factors * kT_factors - - # Remove center of mass motion for batches with more than one atom - # Calculate mean momentum for each batch - mean_momentum = torch.mean(momenta, dim=1, keepdim=True) # shape: (n_batches, 1, 3) - - # Create a mask for batches with more than one atom - multi_atom_mask = torch.tensor(n_atoms_per_batch > 1, device=device, dtype=torch.bool) - - # Subtract mean momentum where needed (broadcasting handles the rest) - return torch.where( - multi_atom_mask.view(-1, 1, 1), # shape: (n_batches, 1, 1) - momenta - mean_momentum, - momenta, - ) - - def calculate_momenta( positions: torch.Tensor, masses: torch.Tensor, @@ -166,11 +107,24 @@ def calculate_momenta( positions.shape, device=device, dtype=dtype, generator=generator ) * torch.sqrt(masses * kT).unsqueeze(-1) - # Center the momentum if more than one particle - if positions.shape[0] > 1: - momenta = momenta - torch.mean(momenta, dim=0, keepdim=True) - - return momenta + batchwise_momenta = torch.zeros((batch[-1] + 1, momenta.shape[1]), device=device, dtype=dtype) + + # create 3 copies of batch + batch_3 = batch.view(-1, 1).repeat(1, 3) + bincount = torch.bincount(batch) + mean_momenta = torch.scatter_reduce( + batchwise_momenta, + dim=0, + index=batch_3, + src=momenta, + reduce="sum", + ) / bincount.view(-1, 1) + adjusted_momenta = torch.where( + torch.repeat_interleave(bincount > 1, bincount).view(-1, 1), + momenta - mean_momenta[batch], + momenta, + ) + return adjusted_momenta def momentum_step(state: MDState, dt: torch.Tensor) -> MDState: From 1cff9b968d35dee578eab73bb147d4a90abd32d3 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Thu, 10 Apr 2025 11:19:02 -0700 Subject: [PATCH 03/13] modify integrate function to convert temps to kTs earlier --- torch_sim/runners.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index c8cfdf54..8eea68e5 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -101,7 +101,7 @@ def integrate( *, integrator: Callable, n_steps: int, - temperature: float | ArrayLike, + temperature: float | list | torch.Tensor, timestep: float, trajectory_reporter: TrajectoryReporter | dict | None = None, autobatcher: ChunkingAutoBatcher | bool = False, @@ -138,9 +138,10 @@ def integrate( # initialize the state state: SimState = initialize_state(system, model.device, model.dtype) dtype, device = state.dtype, state.device + kTs = torch.tensor(temps, dtype=dtype, device=device) * unit_system.temperature init_fn, update_fn = integrator( model=model, - kT=torch.tensor(temps[0] * unit_system.temperature, dtype=dtype, device=device), + kT=kTs[0], dt=torch.tensor(timestep * unit_system.time, dtype=dtype, device=device), **integrator_kwargs, ) @@ -164,7 +165,7 @@ def integrate( # run the simulation for step in range(1, n_steps + 1): - state = update_fn(state, kT=temps[step - 1] * unit_system.temperature) + state = update_fn(state, kT=kTs[step - 1]) if trajectory_reporter: trajectory_reporter.report(state, step, model=model) From 0fe889dd75bc004d35324165a6d7f7ccc67f3cdc Mon Sep 17 00:00:00 2001 From: orionarcher Date: Thu, 10 Apr 2025 11:19:24 -0700 Subject: [PATCH 04/13] remove unused fixture --- tests/conftest.py | 32 -------------------------------- 1 file changed, 32 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 18a9a8ef..3b11d8db 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -293,35 +293,3 @@ def lj_model(device: torch.device, dtype: torch.dtype) -> LennardJonesModel: compute_stress=True, cutoff=2.5 * 3.405, ) - - -@pytest.fixture -def torchsim_trajectory( - si_sim_state: SimState, - lj_model: Any, - tmp_path: Path, - device: torch.device, - dtype: torch.dtype, -): - """Test NVE integration conserves energy.""" - # Initialize integrator - kT = torch.tensor(300.0, device=device, dtype=dtype) # Temperature in K - dt = torch.tensor(0.001, device=device, dtype=dtype) # Small timestep for stability - - state, update_fn = nve( - **asdict(si_sim_state), - model=lj_model, - dt=dt, - kT=kT, - ) - - reporter = TrajectoryReporter(tmp_path / "test.hdf5", state_frequency=1) - - # Run several steps - for step in range(10): - state = update_fn(state, dt) - reporter.report(state, step) - - yield reporter.trajectory - - reporter.close() From 782676170bba0f54c43ad70b74375a21b1f559b3 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Thu, 10 Apr 2025 11:24:48 -0700 Subject: [PATCH 05/13] allow calculate momenta to take float --- torch_sim/integrators.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch_sim/integrators.py b/torch_sim/integrators.py index 8745386c..aaa39bdc 100644 --- a/torch_sim/integrators.py +++ b/torch_sim/integrators.py @@ -73,7 +73,7 @@ def calculate_momenta( positions: torch.Tensor, masses: torch.Tensor, batch: torch.Tensor, - kT: torch.Tensor, + kT: torch.Tensor | float, seed: int | None = None, ) -> torch.Tensor: """Initialize particle momenta based on temperature. @@ -98,6 +98,9 @@ def calculate_momenta( if seed is not None: generator.manual_seed(seed) + if isinstance(kT, float): + kT = torch.tensor(kT, device=device, dtype=dtype) + if len(kT.shape) > 0: # kT is a tensor with shape (n_batches,) kT = kT[batch] From 82b15fba3fd4469108da3adbff32c7388a4caee4 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Thu, 10 Apr 2025 11:27:15 -0700 Subject: [PATCH 06/13] lint --- tests/conftest.py | 4 --- tests/test_integrators.py | 62 +++++++++++++++++++++------------------ torch_sim/integrators.py | 23 ++++++++++----- torch_sim/runners.py | 1 - 4 files changed, 50 insertions(+), 40 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 3b11d8db..f89933d9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,3 @@ -from dataclasses import asdict -from pathlib import Path from typing import Any import pytest @@ -13,9 +11,7 @@ from torch_sim.io import atoms_to_state from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.state import SimState, concatenate_states -from torch_sim.trajectory import TrajectoryReporter from torch_sim.unbatched.models.lennard_jones import UnbatchedLennardJonesModel -from torch_sim.unbatched.unbatched_integrators import nve @pytest.fixture diff --git a/tests/test_integrators.py b/tests/test_integrators.py index be9661fb..d7044f62 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -2,12 +2,7 @@ import torch -from torch_sim.integrators import ( - calculate_momenta, - npt_langevin, - nve, - nvt_langevin, -) +from torch_sim.integrators import calculate_momenta, npt_langevin, nve, nvt_langevin from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.quantities import calc_kT from torch_sim.state import SimState, concatenate_states @@ -18,61 +13,70 @@ def test_calculate_momenta_basic(device: torch.device): """Test basic functionality of calculate_momenta.""" seed = 42 dtype = torch.float64 - + # Create test inputs for 3 batches with 2 atoms each n_atoms = 8 positions = torch.randn(n_atoms, 3, dtype=dtype, device=device) masses = torch.rand(n_atoms, dtype=dtype, device=device) + 0.5 - batch = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], device=device) # 3 batches with 2 atoms each + batch = torch.tensor( + [0, 0, 1, 1, 2, 2, 3, 3], device=device + ) # 3 batches with 2 atoms each kT = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=dtype, device=device) - + # Run the function momenta = calculate_momenta(positions, masses, batch, kT, seed=seed) - + # Basic checks assert momenta.shape == positions.shape assert momenta.dtype == dtype assert momenta.device == device - + # Check that each batch has zero center of mass momentum for b in range(4): - batch_mask = (batch == b) + batch_mask = batch == b batch_momenta = momenta[batch_mask] com_momentum = torch.mean(batch_momenta, dim=0) - assert torch.allclose(com_momentum, torch.zeros(3, dtype=dtype, device=device), atol=1e-10) + assert torch.allclose( + com_momentum, torch.zeros(3, dtype=dtype, device=device), atol=1e-10 + ) def test_calculate_momenta_single_atoms(device: torch.device): """Test that calculate_momenta preserves momentum for batches with single atoms.""" seed = 42 dtype = torch.float64 - + # Create test inputs with some batches having single atoms positions = torch.randn(5, 3, dtype=dtype, device=device) masses = torch.rand(5, dtype=dtype, device=device) + 0.5 - batch = torch.tensor([0, 1, 1, 2, 3], device=device) # Batches 0, 2, and 3 have single atoms + batch = torch.tensor( + [0, 1, 1, 2, 3], device=device + ) # Batches 0, 2, and 3 have single atoms kT = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=dtype, device=device) - + # Generate momenta and save the raw values before COM correction generator = torch.Generator(device=device).manual_seed(seed) - raw_momenta = torch.randn(positions.shape, device=device, dtype=dtype, generator=generator) * \ - torch.sqrt(masses * kT[batch]).unsqueeze(-1) - + raw_momenta = torch.randn( + positions.shape, device=device, dtype=dtype, generator=generator + ) * torch.sqrt(masses * kT[batch]).unsqueeze(-1) + # Run the function momenta = calculate_momenta(positions, masses, batch, kT, seed=seed) - + # Check that single-atom batches have unchanged momenta for b in [0, 2, 3]: # Single atom batches - batch_mask = (batch == b) + batch_mask = batch == b # The momentum should be exactly the same as the raw value for single atoms assert torch.allclose(momenta[batch_mask], raw_momenta[batch_mask]) - + # Check that multi-atom batches have zero COM for b in [1]: # Multi-atom batches - batch_mask = (batch == b) + batch_mask = batch == b batch_momenta = momenta[batch_mask] com_momentum = torch.mean(batch_momenta, dim=0) - assert torch.allclose(com_momentum, torch.zeros(3, dtype=dtype, device=device), atol=1e-10) + assert torch.allclose( + com_momentum, torch.zeros(3, dtype=dtype, device=device), atol=1e-10 + ) def test_npt_langevin(ar_double_sim_state: SimState, lj_model: LennardJonesModel): @@ -136,8 +140,9 @@ def test_npt_langevin(ar_double_sim_state: SimState, lj_model: LennardJonesModel assert pos_diff > 0.0001 # Systems should remain separated - -def test_npt_langevin_multi_kt(ar_double_sim_state: SimState, lj_model: LennardJonesModel): +def test_npt_langevin_multi_kt( + ar_double_sim_state: SimState, lj_model: LennardJonesModel +): dtype = torch.float64 n_steps = 200 dt = torch.tensor(0.001, dtype=dtype) @@ -239,8 +244,9 @@ def test_nvt_langevin(ar_double_sim_state: SimState, lj_model: LennardJonesModel assert pos_diff > 0.0001 # Systems should remain separated - -def test_nvt_langevin_multi_kt(ar_double_sim_state: SimState, lj_model: LennardJonesModel): +def test_nvt_langevin_multi_kt( + ar_double_sim_state: SimState, lj_model: LennardJonesModel +): dtype = torch.float64 n_steps = 200 dt = torch.tensor(0.001, dtype=dtype) diff --git a/torch_sim/integrators.py b/torch_sim/integrators.py index aaa39bdc..03b3a4a7 100644 --- a/torch_sim/integrators.py +++ b/torch_sim/integrators.py @@ -85,6 +85,7 @@ def calculate_momenta( Args: positions (torch.Tensor): Particle positions [n_particles, n_dim] masses (torch.Tensor): Particle masses [n_particles] + batch (torch.Tensor): Batch indices [n_particles] kT (torch.Tensor): Temperature in energy units [n_batches] seed (int, optional): Random seed for reproducibility. Defaults to None. @@ -110,8 +111,10 @@ def calculate_momenta( positions.shape, device=device, dtype=dtype, generator=generator ) * torch.sqrt(masses * kT).unsqueeze(-1) - batchwise_momenta = torch.zeros((batch[-1] + 1, momenta.shape[1]), device=device, dtype=dtype) - + batchwise_momenta = torch.zeros( + (batch[-1] + 1, momenta.shape[1]), device=device, dtype=dtype + ) + # create 3 copies of batch batch_3 = batch.view(-1, 1).repeat(1, 3) bincount = torch.bincount(batch) @@ -122,12 +125,12 @@ def calculate_momenta( src=momenta, reduce="sum", ) / bincount.view(-1, 1) - adjusted_momenta = torch.where( + + return torch.where( torch.repeat_interleave(bincount > 1, bincount).view(-1, 1), momenta - mean_momenta[batch], momenta, ) - return adjusted_momenta def momentum_step(state: MDState, dt: torch.Tensor) -> MDState: @@ -249,7 +252,9 @@ def nve_init( model_output = model(state) momenta = getattr( - state, "momenta", calculate_momenta(state.positions, state.masses, state.batch, kT, seed) + state, + "momenta", + calculate_momenta(state.positions, state.masses, state.batch, kT, seed), ) initial_state = MDState( @@ -441,7 +446,9 @@ def langevin_init( model_output = model(state) momenta = getattr( - state, "momenta", calculate_momenta(state.positions, state.masses, state.batch, kT, seed) + state, + "momenta", + calculate_momenta(state.positions, state.masses, state.batch, kT, seed), ) initial_state = MDState( @@ -1094,7 +1101,9 @@ def npt_init( # Initialize momenta if not provided momenta = getattr( - state, "momenta", calculate_momenta(state.positions, state.masses, state.batch, kT, seed) + state, + "momenta", + calculate_momenta(state.positions, state.masses, state.batch, kT, seed), ) # Initialize cell parameters diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 8eea68e5..1a2ab44a 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -11,7 +11,6 @@ from itertools import chain import torch -from numpy.typing import ArrayLike from torch_sim.autobatching import ChunkingAutoBatcher, HotSwappingAutoBatcher from torch_sim.models.interface import ModelInterface From 3f3a2d3130139399dabeb93f152dfc7126542dd6 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Thu, 10 Apr 2025 11:29:56 -0700 Subject: [PATCH 07/13] only do kT[batch] call if it is a tensor --- torch_sim/integrators.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/torch_sim/integrators.py b/torch_sim/integrators.py index 03b3a4a7..451e5549 100644 --- a/torch_sim/integrators.py +++ b/torch_sim/integrators.py @@ -99,10 +99,7 @@ def calculate_momenta( if seed is not None: generator.manual_seed(seed) - if isinstance(kT, float): - kT = torch.tensor(kT, device=device, dtype=dtype) - - if len(kT.shape) > 0: + if isinstance(kT, torch.Tensor) and len(kT.shape) > 0: # kT is a tensor with shape (n_batches,) kT = kT[batch] @@ -399,7 +396,7 @@ def ou_step( """ c1 = torch.exp(-gamma * dt) - if len(kT.shape) > 0: + if isinstance(kT, torch.Tensor) and len(kT.shape) > 0: # kT is a tensor with shape (n_batches,) kT = kT[state.batch] From b9fca1173b187592e04cdc9067f8eaeaa6592874 Mon Sep 17 00:00:00 2001 From: comprhys Date: Sat, 12 Apr 2025 12:57:30 -0700 Subject: [PATCH 08/13] fix sio2 rattled system --- tests/conftest.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index f89933d9..686e896d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -199,8 +199,10 @@ def rattled_sio2_sim_state(sio2_sim_state: SimState) -> SimState: try: # Temporarily set a fixed seed torch.manual_seed(3) - weibull = torch.distributions.weibull.Weibull(scale=0.5, concentration=1.0) - shifts = weibull.sample((sim_state.n_atoms, 3)) + weibull = torch.distributions.weibull.Weibull(scale=0.1, concentration=1) + rnd = torch.randn_like(sim_state.positions) + rnd = rnd / torch.norm(rnd, dim=-1, keepdim=True) + shifts = weibull.sample(rnd.shape) * rnd sim_state.positions = sim_state.positions + shifts finally: # Restore the original RNG state From 1084aea1f9dd62cb2b090d4dae7ab7471414d96e Mon Sep 17 00:00:00 2001 From: orionarcher Date: Sat, 12 Apr 2025 12:59:13 -0700 Subject: [PATCH 09/13] throw error if sevenet is created with float64 --- torch_sim/models/sevennet.py | 3 +++ 1 file changed, 3 insertions(+) mode change 100644 => 100755 torch_sim/models/sevennet.py diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py old mode 100644 new mode 100755 index fb907689..96746e67 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -90,6 +90,9 @@ def __init__( if isinstance(self._device, str): self._device = torch.device(self._device) + if torch.dtype is not torch.float32: + raise ValueError("SevenNetModel currently only supports float32") + self._dtype = dtype self._memory_scales_with = "n_atoms_x_density" self._compute_stress = True From f6d4a4fda11d25bc9fa2b667ff78237c96ce22cf Mon Sep 17 00:00:00 2001 From: orionarcher Date: Sat, 12 Apr 2025 13:14:18 -0700 Subject: [PATCH 10/13] turn sevennet failure on float64 to warning --- torch_sim/models/sevennet.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index 96746e67..104824f5 100755 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -91,7 +91,11 @@ def __init__( self._device = torch.device(self._device) if torch.dtype is not torch.float32: - raise ValueError("SevenNetModel currently only supports float32") + warnings.warn( + "SevenNetModel currently only supports float32, but received different dtype", + UserWarning, + stacklevel=2, + ) self._dtype = dtype self._memory_scales_with = "n_atoms_x_density" From 870e7ae8e63fe360d504704dada884ae158edfa6 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Sat, 12 Apr 2025 13:50:22 -0700 Subject: [PATCH 11/13] try fixing OOM on state initialization --- torch_sim/runners.py | 52 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 48 insertions(+), 4 deletions(-) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 1a2ab44a..54f689af 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -144,7 +144,7 @@ def integrate( dt=torch.tensor(timestep * unit_system.time, dtype=dtype, device=device), **integrator_kwargs, ) - state = init_fn(state) + # state = init_fn(state) batch_iterator = _configure_batches_iterator(model, state, autobatcher) trajectory_reporter = _configure_reporter( @@ -155,6 +155,8 @@ def integrate( final_states: list[SimState] = [] og_filenames = trajectory_reporter.filenames if trajectory_reporter else None for state, batch_indices in batch_iterator: + state = init_fn(state) + # set up trajectory reporters if autobatcher and trajectory_reporter: # we must remake the trajectory reporter for each batch @@ -204,7 +206,6 @@ def _configure_hot_swapping_autobatcher( if isinstance(autobatcher, HotSwappingAutoBatcher): autobatcher.return_indices = True autobatcher.max_attempts = max_attempts - autobatcher.load_states(state) else: if autobatcher: memory_scales_with = model.memory_scales_with @@ -219,10 +220,46 @@ def _configure_hot_swapping_autobatcher( memory_scales_with=memory_scales_with, max_iterations=max_attempts, ) - autobatcher.load_states(state) return autobatcher +def _chunked_apply( + fn: Callable, + states: SimState, + model: ModelInterface, + max_memory_scaler: int, + memory_scales_with: str, +) -> SimState: + """Apply a function to a state in chunks. + + This prevents us from running out of memory when applying a function to a large + number of states. + + Args: + fn (Callable): The function to apply + states (SimState): The state to apply the function to + model (ModelInterface): The model to use for the autobatcher + max_memory_scaler (int): The maximum memory scaler + memory_scales_with (str): The memory scaling metric + + Returns: + A state with the function applied + """ + autobatcher = ChunkingAutoBatcher( + model=model, + return_indices=False, + max_memory_scaler=max_memory_scaler, + memory_scales_with=memory_scales_with, + ) + autobatcher.load_states(states) + initialized_states = [] + for batch in autobatcher: + initialized_states.append(fn(batch)) + + ordered_states = autobatcher.restore_original_order(initialized_states) + return concatenate_states(ordered_states) + + def generate_force_convergence_fn(force_tol: float = 1e-1) -> Callable: """Generate a force-based convergence function for the convergence_fn argument of the optimize function. @@ -314,12 +351,19 @@ def optimize( # initialize the state state: SimState = initialize_state(system, model.device, model.dtype) init_fn, update_fn = optimizer(model=model, **optimizer_kwargs) - state = init_fn(state) max_attempts = max_steps // steps_between_swaps autobatcher = _configure_hot_swapping_autobatcher( model, state, autobatcher, max_attempts ) + state = _chunked_apply( + init_fn, + state, + model, + autobatcher.max_memory_scaler, + autobatcher.memory_scales_with, + ) + autobatcher.load_states(state) trajectory_reporter = _configure_reporter( trajectory_reporter, properties=["potential_energy"], From e00f1e10cefaec5f797ab6d2c4d631b8bf66a199 Mon Sep 17 00:00:00 2001 From: orionarcher Date: Sat, 12 Apr 2025 13:56:02 -0700 Subject: [PATCH 12/13] lint --- torch_sim/models/sevennet.py | 3 ++- torch_sim/runners.py | 13 +++++-------- 2 files changed, 7 insertions(+), 9 deletions(-) mode change 100755 => 100644 torch_sim/models/sevennet.py diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py old mode 100755 new mode 100644 index 104824f5..3a0c96d0 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -92,7 +92,8 @@ def __init__( if torch.dtype is not torch.float32: warnings.warn( - "SevenNetModel currently only supports float32, but received different dtype", + "SevenNetModel currently only supports" + "float32, but received different dtype", UserWarning, stacklevel=2, ) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 54f689af..75c64609 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -227,8 +227,7 @@ def _chunked_apply( fn: Callable, states: SimState, model: ModelInterface, - max_memory_scaler: int, - memory_scales_with: str, + **batcher_kwargs: dict, ) -> SimState: """Apply a function to a state in chunks. @@ -239,8 +238,7 @@ def _chunked_apply( fn (Callable): The function to apply states (SimState): The state to apply the function to model (ModelInterface): The model to use for the autobatcher - max_memory_scaler (int): The maximum memory scaler - memory_scales_with (str): The memory scaling metric + **batcher_kwargs: Additional keyword arguments for the autobatcher Returns: A state with the function applied @@ -248,13 +246,12 @@ def _chunked_apply( autobatcher = ChunkingAutoBatcher( model=model, return_indices=False, - max_memory_scaler=max_memory_scaler, - memory_scales_with=memory_scales_with, + **batcher_kwargs, ) autobatcher.load_states(states) initialized_states = [] - for batch in autobatcher: - initialized_states.append(fn(batch)) + + initialized_states = [fn(batch) for batch in autobatcher] ordered_states = autobatcher.restore_original_order(initialized_states) return concatenate_states(ordered_states) From 31ea94a05f87bc222b293b7e5c36f5fedffc90cf Mon Sep 17 00:00:00 2001 From: orionarcher Date: Sat, 12 Apr 2025 14:03:34 -0700 Subject: [PATCH 13/13] fix args in optimize --- torch_sim/runners.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 75c64609..1ff09235 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -357,8 +357,8 @@ def optimize( init_fn, state, model, - autobatcher.max_memory_scaler, - autobatcher.memory_scales_with, + max_memory_scaler=autobatcher.max_memory_scaler, + memory_scales_with=autobatcher.memory_scales_with, ) autobatcher.load_states(state) trajectory_reporter = _configure_reporter(