From 059e57d6757c19dad7458e679a84891ad3fa796f Mon Sep 17 00:00:00 2001 From: Timo Reents Date: Thu, 17 Jul 2025 16:54:24 +0200 Subject: [PATCH] Fix concatenation of states in `InFlightAutoBatcher` The `velocities` and `cell_velocities` are initialized to `None` in the `(FrechetCell)FIREState`. However, when using the `InFlightAutoBatcher` during an optimization, the current and new states are concatenated in `torch_sim.state.concatenate_states`. When trying to merge states that were already processed for a few iterations (i.e., velocities are not None anymore) and newly initialized ones, an error is raised because the code tries to merge a `Tensor` with a `None`. Here, we initialize the `(cell_)velocities` as tensors full of `nan` instead, so that one can merge already processed and newly initialized states. During the first initialization, the `fire` methods look for `nan` rows and replace them with zeros. --- tests/test_autobatching.py | 2 +- torch_sim/optimizers.py | 44 +++++++++++++++++++++++++++----------- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index 7be28997..73ade42e 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -490,7 +490,7 @@ def convergence_fn(state: ts.SimState) -> bool: break # run 10 steps, arbitrary number - for _ in range(10): + for _ in range(5): state = fire_update(state) convergence_tensor = convergence_fn(state) diff --git a/torch_sim/optimizers.py b/torch_sim/optimizers.py index 98a83d64..de8e839d 100644 --- a/torch_sim/optimizers.py +++ b/torch_sim/optimizers.py @@ -589,7 +589,9 @@ def fire_init( atomic_numbers=state.atomic_numbers.clone(), system_idx=state.system_idx.clone(), pbc=state.pbc, - velocities=None, + velocities=torch.full( + state.positions.shape, torch.nan, device=device, dtype=dtype + ), forces=forces, energy=energy, # Optimization attributes @@ -862,13 +864,17 @@ def fire_init( atomic_numbers=state.atomic_numbers.clone(), system_idx=state.system_idx.clone(), pbc=state.pbc, - velocities=None, + velocities=torch.full( + state.positions.shape, torch.nan, device=device, dtype=dtype + ), forces=forces, energy=energy, stress=stress, # Cell attributes cell_positions=torch.zeros(n_systems, 3, 3, device=device, dtype=dtype), - cell_velocities=None, + cell_velocities=torch.full( + cell_forces.shape, torch.nan, device=device, dtype=dtype + ), cell_forces=cell_forces, cell_masses=cell_masses, # Optimization attributes @@ -1161,13 +1167,17 @@ def fire_init( atomic_numbers=state.atomic_numbers, system_idx=state.system_idx, pbc=state.pbc, - velocities=None, + velocities=torch.full( + state.positions.shape, torch.nan, device=device, dtype=dtype + ), forces=forces, energy=energy, stress=stress, # Cell attributes cell_positions=cell_positions, - cell_velocities=None, + cell_velocities=torch.full( + cell_forces.shape, torch.nan, device=device, dtype=dtype + ), cell_forces=cell_forces, cell_masses=cell_masses, # Optimization attributes @@ -1244,15 +1254,19 @@ def _vv_fire_step( # noqa: C901, PLR0915 dtype = state.positions.dtype deform_grad_new: torch.Tensor | None = None - if state.velocities is None: - state.velocities = torch.zeros_like(state.positions) + nan_velocities = state.velocities.isnan().any(dim=1) + if nan_velocities.any(): + state.velocities[nan_velocities] = torch.zeros_like( + state.positions[nan_velocities] + ) if is_cell_optimization: if not isinstance(state, AnyFireCellState): raise ValueError( f"Cell optimization requires one of {get_args(AnyFireCellState)}." ) - state.cell_velocities = torch.zeros( - (n_systems, 3, 3), device=device, dtype=dtype + nan_cell_velocities = state.cell_velocities.isnan().any(dim=(1, 2)) + state.cell_velocities[nan_cell_velocities] = torch.zeros_like( + state.cell_positions[nan_cell_velocities] ) alpha_start_system = torch.full( @@ -1461,16 +1475,20 @@ def _ase_fire_step( # noqa: C901, PLR0915 cur_deform_grad = None # Initialize cur_deform_grad to prevent UnboundLocalError - if state.velocities is None: - state.velocities = torch.zeros_like(state.positions) + nan_velocities = state.velocities.isnan().any(dim=1) + if nan_velocities.any(): + state.velocities[nan_velocities] = torch.zeros_like( + state.positions[nan_velocities] + ) forces = state.forces if is_cell_optimization: if not isinstance(state, AnyFireCellState): raise ValueError( f"Cell optimization requires one of {get_args(AnyFireCellState)}." ) - state.cell_velocities = torch.zeros( - (n_systems, 3, 3), device=device, dtype=dtype + nan_cell_velocities = state.cell_velocities.isnan().any(dim=(1, 2)) + state.cell_velocities[nan_cell_velocities] = torch.zeros_like( + state.cell_positions[nan_cell_velocities] ) cur_deform_grad = state.deform_grad() else: