diff --git a/tests/conftest.py b/tests/conftest.py index 18a9a8ef..686e896d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,3 @@ -from dataclasses import asdict -from pathlib import Path from typing import Any import pytest @@ -13,9 +11,7 @@ from torch_sim.io import atoms_to_state from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.state import SimState, concatenate_states -from torch_sim.trajectory import TrajectoryReporter from torch_sim.unbatched.models.lennard_jones import UnbatchedLennardJonesModel -from torch_sim.unbatched.unbatched_integrators import nve @pytest.fixture @@ -203,8 +199,10 @@ def rattled_sio2_sim_state(sio2_sim_state: SimState) -> SimState: try: # Temporarily set a fixed seed torch.manual_seed(3) - weibull = torch.distributions.weibull.Weibull(scale=0.5, concentration=1.0) - shifts = weibull.sample((sim_state.n_atoms, 3)) + weibull = torch.distributions.weibull.Weibull(scale=0.1, concentration=1) + rnd = torch.randn_like(sim_state.positions) + rnd = rnd / torch.norm(rnd, dim=-1, keepdim=True) + shifts = weibull.sample(rnd.shape) * rnd sim_state.positions = sim_state.positions + shifts finally: # Restore the original RNG state @@ -293,35 +291,3 @@ def lj_model(device: torch.device, dtype: torch.dtype) -> LennardJonesModel: compute_stress=True, cutoff=2.5 * 3.405, ) - - -@pytest.fixture -def torchsim_trajectory( - si_sim_state: SimState, - lj_model: Any, - tmp_path: Path, - device: torch.device, - dtype: torch.dtype, -): - """Test NVE integration conserves energy.""" - # Initialize integrator - kT = torch.tensor(300.0, device=device, dtype=dtype) # Temperature in K - dt = torch.tensor(0.001, device=device, dtype=dtype) # Small timestep for stability - - state, update_fn = nve( - **asdict(si_sim_state), - model=lj_model, - dt=dt, - kT=kT, - ) - - reporter = TrajectoryReporter(tmp_path / "test.hdf5", state_frequency=1) - - # Run several steps - for step in range(10): - state = update_fn(state, dt) - 
def test_calculate_momenta_basic(device: torch.device):
    """Check shape, dtype, device and per-batch centering of calculate_momenta.

    Uses four batches of two atoms each with a different kT per batch and
    verifies that the mean momentum of every batch is zero.
    """
    seed = 42
    dtype = torch.float64

    # Four batches with two atoms each; kT holds one temperature per batch.
    n_atoms = 8
    positions = torch.randn(n_atoms, 3, dtype=dtype, device=device)
    masses = torch.rand(n_atoms, dtype=dtype, device=device) + 0.5
    batch = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3], device=device)
    kT = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=dtype, device=device)
    n_batches = kT.shape[0]

    momenta = calculate_momenta(positions, masses, batch, kT, seed=seed)

    # The output must mirror the layout of the positions tensor.
    assert momenta.shape == positions.shape
    assert momenta.dtype == dtype
    assert momenta.device == device

    # Every batch here has more than one atom, so each must be centered.
    zero = torch.zeros(3, dtype=dtype, device=device)
    for b in range(n_batches):
        com_momentum = torch.mean(momenta[batch == b], dim=0)
        assert torch.allclose(com_momentum, zero, atol=1e-10)


def test_calculate_momenta_single_atoms(device: torch.device):
    """Check that single-atom batches keep their raw, uncentered momenta.

    The raw draw is reproduced here with the same seed, shape and scaling so
    it can be compared against the function output.
    """
    seed = 42
    dtype = torch.float64

    # Batches 0, 2 and 3 hold a single atom; batch 1 holds two atoms.
    positions = torch.randn(5, 3, dtype=dtype, device=device)
    masses = torch.rand(5, dtype=dtype, device=device) + 0.5
    batch = torch.tensor([0, 1, 1, 2, 3], device=device)
    kT = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=dtype, device=device)

    # Reference momenta before any center-of-mass correction.
    generator = torch.Generator(device=device).manual_seed(seed)
    raw_momenta = torch.randn(
        positions.shape, device=device, dtype=dtype, generator=generator
    ) * torch.sqrt(masses * kT[batch]).unsqueeze(-1)

    momenta = calculate_momenta(positions, masses, batch, kT, seed=seed)

    # Single-atom batches must be returned exactly as drawn.
    for b in (0, 2, 3):
        batch_mask = batch == b
        assert torch.allclose(momenta[batch_mask], raw_momenta[batch_mask])

    # The only multi-atom batch must have zero mean momentum.
    zero = torch.zeros(3, dtype=dtype, device=device)
    for b in (1,):
        com_momentum = torch.mean(momenta[batch == b], dim=0)
        assert torch.allclose(com_momentum, zero, atol=1e-10)
def test_npt_langevin_multi_kt(
    ar_double_sim_state: SimState, lj_model: LennardJonesModel
):
    """Run NPT Langevin dynamics with a two-entry kT tensor.

    Integrates for 200 steps and checks that the mean instantaneous
    temperature of each trajectory is within 50% (relative) of its target.
    """
    dtype = torch.float64
    n_steps = 200
    dt = torch.tensor(0.001, dtype=dtype)
    # One target per entry; presumably one per system in the fixture.
    kT = torch.tensor([300.0, 10000.0], dtype=dtype) * MetalUnits.temperature
    external_pressure = torch.tensor(0.0, dtype=dtype) * MetalUnits.pressure

    init_fn, update_fn = npt_langevin(
        model=lj_model,
        dt=dt,
        kT=kT,
        external_pressure=external_pressure,
        alpha=40 * dt,
    )

    state = init_fn(state=ar_double_sim_state, seed=42)
    energy_trace = []
    temperature_trace = []
    for _ in range(n_steps):
        state = update_fn(state=state)

        # Instantaneous temperature from the current momenta, per batch.
        inst_kT = calc_kT(state.momenta, state.masses, batch=state.batch)
        energy_trace.append(state.energy)
        temperature_trace.append(inst_kT / MetalUnits.temperature)

    # Stack to (n_steps, ...) and transpose into per-trajectory lists.
    temperatures_tensor = torch.stack(temperature_trace)
    energies_tensor = torch.stack(energy_trace)
    temperatures_list = [row.tolist() for row in temperatures_tensor.T]
    energies_list = [row.tolist() for row in energies_tensor.T]

    # One sample per step must have been recorded.
    assert len(energies_list[0]) == n_steps
    assert len(temperatures_list[0]) == n_steps

    # Mean over steps, compared against the requested targets.
    mean_temps = torch.mean(temperatures_tensor, dim=0)
    assert torch.allclose(mean_temps, kT / MetalUnits.temperature, rtol=0.5)


def test_nvt_langevin_multi_kt(
    ar_double_sim_state: SimState, lj_model: LennardJonesModel
):
    """Run NVT Langevin dynamics with a two-entry kT tensor.

    Integrates for 200 steps and checks that the mean instantaneous
    temperature of each trajectory is within 50% (relative) of its target.
    """
    dtype = torch.float64
    n_steps = 200
    dt = torch.tensor(0.001, dtype=dtype)
    # One target per entry; presumably one per system in the fixture.
    kT = torch.tensor([300.0, 10000.0], dtype=dtype) * MetalUnits.temperature

    init_fn, update_fn = nvt_langevin(
        model=lj_model,
        dt=dt,
        kT=kT,
    )

    state = init_fn(state=ar_double_sim_state, seed=42)
    energy_trace = []
    temperature_trace = []
    for _ in range(n_steps):
        state = update_fn(state=state)

        # Instantaneous temperature from the current momenta, per batch.
        inst_kT = calc_kT(state.momenta, state.masses, batch=state.batch)
        energy_trace.append(state.energy)
        temperature_trace.append(inst_kT / MetalUnits.temperature)

    # Stack to (n_steps, ...) and transpose into per-trajectory lists.
    temperatures_tensor = torch.stack(temperature_trace)
    energies_tensor = torch.stack(energy_trace)
    temperatures_list = [row.tolist() for row in temperatures_tensor.T]
    energies_list = [row.tolist() for row in energies_tensor.T]

    # One sample per step must have been recorded.
    assert len(energies_list[0]) == n_steps
    assert len(temperatures_list[0]) == n_steps

    # Mean over steps, compared against the requested targets.
    mean_temps = torch.mean(temperatures_tensor, dim=0)
    assert torch.allclose(mean_temps, kT / MetalUnits.temperature, rtol=0.5)
- - Args: - positions (torch.Tensor): Atomic positions [n_batches, n_atoms_per_batch, 3] - masses (torch.Tensor): Atomic masses [n_batches, n_atoms_per_batch] - kT (torch.Tensor): Temperature in energy units [n_batches] - seeds (torch.Tensor): Random seeds [n_batches] - device (torch.device): Torch device for tensor operations - dtype (torch.dtype): Torch data type for tensor precision - - Returns: - torch.Tensor: Random momenta [n_batches, n_atoms_per_batch, 3] - scaled to the specified temperature - """ - n_atoms_per_batch = positions.shape[1] - - # Create a generator for each batch using the provided seeds - generators = [torch.Generator(device=device).manual_seed(int(seed)) for seed in seeds] - - # Generate random momenta for all batches at once - momenta = torch.stack( - [ - torch.randn((n_atoms_per_batch, 3), device=device, dtype=dtype, generator=gen) - for gen in generators - ] - ) - - # Scale by sqrt(mass * kT) - mass_factors = torch.sqrt(masses).unsqueeze(-1) # shape: (n_batches, n_atoms, 1) - kT_factors = torch.sqrt(kT).view(-1, 1, 1) # shape: (n_batches, 1, 1) - momenta *= mass_factors * kT_factors - - # Remove center of mass motion for batches with more than one atom - # Calculate mean momentum for each batch - mean_momentum = torch.mean(momenta, dim=1, keepdim=True) # shape: (n_batches, 1, 3) - - # Create a mask for batches with more than one atom - multi_atom_mask = torch.tensor(n_atoms_per_batch > 1, device=device, dtype=torch.bool) - - # Subtract mean momentum where needed (broadcasting handles the rest) - return torch.where( - multi_atom_mask.view(-1, 1, 1), # shape: (n_batches, 1, 1) - momenta - mean_momentum, - momenta, - ) - - def calculate_momenta( positions: torch.Tensor, masses: torch.Tensor, - kT: torch.Tensor, + batch: torch.Tensor, + kT: torch.Tensor | float, seed: int | None = None, ) -> torch.Tensor: """Initialize particle momenta based on temperature. 
@@ -143,6 +85,7 @@ def calculate_momenta( Args: positions (torch.Tensor): Particle positions [n_particles, n_dim] masses (torch.Tensor): Particle masses [n_particles] + batch (torch.Tensor): Batch indices [n_particles] kT (torch.Tensor): Temperature in energy units [n_batches] seed (int, optional): Random seed for reproducibility. Defaults to None. @@ -156,16 +99,35 @@ def calculate_momenta( if seed is not None: generator.manual_seed(seed) + if isinstance(kT, torch.Tensor) and len(kT.shape) > 0: + # kT is a tensor with shape (n_batches,) + kT = kT[batch] + # Generate random momenta from normal distribution momenta = torch.randn( positions.shape, device=device, dtype=dtype, generator=generator ) * torch.sqrt(masses * kT).unsqueeze(-1) - # Center the momentum if more than one particle - if positions.shape[0] > 1: - momenta = momenta - torch.mean(momenta, dim=0, keepdim=True) + batchwise_momenta = torch.zeros( + (batch[-1] + 1, momenta.shape[1]), device=device, dtype=dtype + ) - return momenta + # create 3 copies of batch + batch_3 = batch.view(-1, 1).repeat(1, 3) + bincount = torch.bincount(batch) + mean_momenta = torch.scatter_reduce( + batchwise_momenta, + dim=0, + index=batch_3, + src=momenta, + reduce="sum", + ) / bincount.view(-1, 1) + + return torch.where( + torch.repeat_interleave(bincount > 1, bincount).view(-1, 1), + momenta - mean_momenta[batch], + momenta, + ) def momentum_step(state: MDState, dt: torch.Tensor) -> MDState: @@ -287,7 +249,9 @@ def nve_init( model_output = model(state) momenta = getattr( - state, "momenta", calculate_momenta(state.positions, state.masses, kT, seed) + state, + "momenta", + calculate_momenta(state.positions, state.masses, state.batch, kT, seed), ) initial_state = MDState( @@ -431,7 +395,12 @@ def ou_step( where c1 = exp(-gamma*dt) and c2 = sqrt(kT*(1-c1²)) """ c1 = torch.exp(-gamma * dt) - c2 = torch.sqrt(kT * (1 - c1**2)) + + if isinstance(kT, torch.Tensor) and len(kT.shape) > 0: + # kT is a tensor with shape (n_batches,) + 
kT = kT[state.batch] + + c2 = torch.sqrt(kT * (1 - c1**2)).unsqueeze(-1) # Generate random noise from normal distribution noise = torch.randn_like(state.momenta, device=state.device, dtype=state.dtype) @@ -474,7 +443,9 @@ def langevin_init( model_output = model(state) momenta = getattr( - state, "momenta", calculate_momenta(state.positions, state.masses, kT, seed) + state, + "momenta", + calculate_momenta(state.positions, state.masses, state.batch, kT, seed), ) initial_state = MDState( @@ -1127,7 +1098,9 @@ def npt_init( # Initialize momenta if not provided momenta = getattr( - state, "momenta", calculate_momenta(state.positions, state.masses, kT, seed) + state, + "momenta", + calculate_momenta(state.positions, state.masses, state.batch, kT, seed), ) # Initialize cell parameters diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index fb907689..3a0c96d0 100644 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -90,6 +90,14 @@ def __init__( if isinstance(self._device, str): self._device = torch.device(self._device) + if torch.dtype is not torch.float32: + warnings.warn( + "SevenNetModel currently only supports" + "float32, but received different dtype", + UserWarning, + stacklevel=2, + ) + self._dtype = dtype self._memory_scales_with = "n_atoms_x_density" self._compute_stress = True diff --git a/torch_sim/runners.py b/torch_sim/runners.py index c8cfdf54..1ff09235 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -11,7 +11,6 @@ from itertools import chain import torch -from numpy.typing import ArrayLike from torch_sim.autobatching import ChunkingAutoBatcher, HotSwappingAutoBatcher from torch_sim.models.interface import ModelInterface @@ -101,7 +100,7 @@ def integrate( *, integrator: Callable, n_steps: int, - temperature: float | ArrayLike, + temperature: float | list | torch.Tensor, timestep: float, trajectory_reporter: TrajectoryReporter | dict | None = None, autobatcher: ChunkingAutoBatcher | bool = False, 
def _chunked_apply(
    fn: Callable,
    states: SimState,
    model: ModelInterface,
    **batcher_kwargs: dict,
) -> SimState:
    """Apply a function to a state in chunks.

    The state is split by a ChunkingAutoBatcher, ``fn`` is called on each
    chunk, and the results are put back in their original order and
    concatenated. This prevents us from running out of memory when applying
    a function to a large number of states.

    Args:
        fn (Callable): The function to apply to each chunk
        states (SimState): The state to apply the function to
        model (ModelInterface): The model to use for the autobatcher
        **batcher_kwargs: Additional keyword arguments forwarded to the
            ChunkingAutoBatcher constructor

    Returns:
        A single state with the function applied, in the original order
    """
    autobatcher = ChunkingAutoBatcher(
        model=model,
        return_indices=False,
        **batcher_kwargs,
    )
    autobatcher.load_states(states)

    # Iterating the batcher yields one chunk at a time.
    processed_chunks = [fn(chunk) for chunk in autobatcher]

    # Undo the batcher's reordering before merging back into one state.
    ordered_states = autobatcher.restore_original_order(processed_chunks)
    return concatenate_states(ordered_states)