From 9eabe668274be18d649cbfbd6be117b5ae9f42c5 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Fri, 3 Jul 2026 22:15:01 +0100 Subject: [PATCH] Backport fix from PR #172. [ci skip] --- src/somd2/runner/_base.py | 104 ++++++++++++++++++++++++--- tests/runner/test_alchemical_ions.py | 96 ++++++++++++++++++++++++- 2 files changed, 189 insertions(+), 11 deletions(-) diff --git a/src/somd2/runner/_base.py b/src/somd2/runner/_base.py index 589da8d..4b3ef03 100644 --- a/src/somd2/runner/_base.py +++ b/src/somd2/runner/_base.py @@ -362,14 +362,27 @@ def __init__(self, system, config): # Create alchemical ions. ion_indices = [] if charge_diff != 0: - self._system, coalchemical_restraints, ion_indices = ( + # On restart, reuse the exact molecule(s) chosen as alchemical ions + # in the original run, rather than re-running the "furthest waters" + # search. This makes ion selection independent of GCMC state (or + # anything else that might change between runs), since the search + # is otherwise only reproducible by assumption, not by construction. + mol_indices = None + if self._config.restart: + mol_indices = self._load_alchemical_ion_indices() + + self._system, coalchemical_restraints, ion_indices, ion_mol_indices = ( self._create_alchemical_ions( self._system, charge_diff, restraint_distance=self._config.coalchemical_restraint_dist, + mol_indices=mol_indices, ) ) + # Keep the stored indices in sync for any future restart. + self._save_alchemical_ion_indices(ion_mol_indices) + # Add the coalchemical restraints to the extra args. if coalchemical_restraints is not None: self._config._extra_args["coalchemical_restraints"] = ( @@ -1045,8 +1058,49 @@ def _get_charge_difference(system): return perturbed - reference + def _save_alchemical_ion_indices(self, mol_indices): + """ + Persist the absolute molecule index of each alchemical ion to a small, + dedicated file in the output directory, independent of the per-window + (regular runner) or shared (repex) checkpoint formats. This allows a + restart to reuse the exact same ion(s) chosen in the original run. + + Parameters + ---------- + + mol_indices: [int] + The absolute molecule index of each alchemical ion. + """ + import numpy as _np + + path = self._config.output_directory / "alchemical_ions.npz" + _np.savez(path, mol_indices=_np.array(mol_indices, dtype=int)) + + def _load_alchemical_ion_indices(self): + """ + Load the absolute molecule index of each alchemical ion previously + stored by `_save_alchemical_ion_indices`, if present. + + Returns + ------- + + mol_indices: [int], None + The absolute molecule index of each alchemical ion, or None if no + stored indices are available (e.g. a fresh run, or a restart from + an output directory that predates this feature). + """ + import numpy as _np + + path = self._config.output_directory / "alchemical_ions.npz" + try: + return _np.load(path)["mol_indices"].tolist() + except Exception: + return None + @staticmethod - def _create_alchemical_ions(system, charge_diff, restraint_distance=None): + def _create_alchemical_ions( + system, charge_diff, restraint_distance=None, mol_indices=None + ): """ Internal function to create alchemical ions to maintain a constant charge. @@ -1059,6 +1113,14 @@ def _create_alchemical_ions(system, charge_diff, restraint_distance=None): charge_diff: int The charge difference between perturbed and reference states. + mol_indices: [int] + The absolute molecule index (position in `system.molecules()`) of + each water to convert into an alchemical ion. If provided, these + molecules are converted directly, bypassing the "furthest waters" + search. Used on restart to reproduce the exact same ion(s) chosen + in the original run, independent of any GCMC state or changes to + the search heuristic. Must have the same length as `abs(charge_diff)`. + Returns ------- @@ -1073,6 +1135,11 @@ def _create_alchemical_ions(system, charge_diff, restraint_distance=None): The perturbable-molecule index of each alchemical ion that was added, suitable for use with `LambdaSchedule.set_molecule_schedule `. + + ion_mol_indices: [int] + The absolute molecule index (position in `system.molecules()`, + prior to any conversion) of each alchemical ion that was added. + Suitable for passing back in as `mol_indices` on a restart. """ from sire.legacy.IO import createChlorineIon as _createChlorineIon @@ -1116,12 +1183,28 @@ def _create_alchemical_ions(system, charge_diff, restraint_distance=None): f"{len(system['water'].molecules())} available." ) - # Reference coordinates. - coords = system.molecules("property is_perturbable").coordinates() - coord_string = f"{coords[0].value()}, {coords[1].value()}, {coords[2].value()}" + if mol_indices is not None: + if len(mol_indices) != num_waters: + raise ValueError( + f"Number of stored alchemical-ion molecule indices " + f"({len(mol_indices)}) does not match the current charge " + f"difference ({num_waters} waters required)." + ) + + # Reuse the exact molecules chosen in the original run. + all_mols = system.molecules() + waters = [all_mols[idx] for idx in mol_indices] + else: + # Reference coordinates. + coords = system.molecules("property is_perturbable").coordinates() + coord_string = ( + f"{coords[0].value()}, {coords[1].value()}, {coords[2].value()}" + ) - # Find the furthest N waters from the perturbable molecule. - waters = system[f"furthest {num_waters} waters from {coord_string}"].molecules() + # Find the furthest N waters from the perturbable molecule. + waters = system[ + f"furthest {num_waters} waters from {coord_string}" + ].molecules() # Determine the water model. if waters[0].num_atoms() == 3: @@ -1141,6 +1224,10 @@ def _create_alchemical_ions(system, charge_diff, restraint_distance=None): # Store the molecule numbers of the alchemical ions. ion_numbers = [] + # Store the absolute molecule index of each alchemical ion (prior to + # conversion), for persisting across restarts. + ion_mol_indices = [] + # Create the ions. for water in waters: # Flag to indicate whether we need to reverse the alchemical ion @@ -1261,6 +1348,7 @@ def _create_alchemical_ions(system, charge_diff, restraint_distance=None): # Get the index of the perturbed water. index = numbers.index(water.number()) + ion_mol_indices.append(index) # Log that we are adding an alchemical ion. if is_reverse: @@ -1283,7 +1371,7 @@ def _create_alchemical_ions(system, charge_diff, restraint_distance=None): perturbable_mols = system.molecules()["perturbable"].molecules() ion_indices = [perturbable_mols.find(system[number]) for number in ion_numbers] - return system, restraints, ion_indices + return system, restraints, ion_indices, ion_mol_indices @staticmethod def _create_filenames(lambda_array, lambda_value, output_directory, restart=False): diff --git a/tests/runner/test_alchemical_ions.py b/tests/runner/test_alchemical_ions.py index f4a4bfe..2c5d7fe 100644 --- a/tests/runner/test_alchemical_ions.py +++ b/tests/runner/test_alchemical_ions.py @@ -1,7 +1,8 @@ import math +import pytest import tempfile -import pytest +from pathlib import Path from somd2.config import Config from somd2.runner import Runner @@ -15,20 +16,109 @@ def test_alchemical_ions(mols, request): mols = request.getfixturevalue(mols).clone() # Add 10 Cl- ions. - new_mols, _, ion_indices = Runner._create_alchemical_ions(mols, 10) + new_mols, _, ion_indices, ion_mol_indices = Runner._create_alchemical_ions(mols, 10) # Make sure the charge difference is correct. assert math.isclose(Runner._get_charge_difference(new_mols), -10.0, rel_tol=1e-6) # Make sure there is one perturbable-molecule index per ion. assert len(ion_indices) == 10 + assert len(ion_mol_indices) == 10 # Add 10 Na+ ions. - new_mols, _, ion_indices = Runner._create_alchemical_ions(mols, -10) + new_mols, _, ion_indices, ion_mol_indices = Runner._create_alchemical_ions( + mols, -10 + ) # Make sure the charge difference is correct. assert math.isclose(Runner._get_charge_difference(new_mols), 10.0, rel_tol=1e-6) assert len(ion_indices) == 10 + assert len(ion_mol_indices) == 10 + + +@pytest.mark.parametrize("mols", ["ethane_methanol", "ethane_methanol_ions"]) +def test_alchemical_ion_mol_indices_reproducible(mols, request): + """ + Ensure that passing the molecule indices returned by a previous call to + `_create_alchemical_ions` reproduces the exact same ion(s), bypassing the + "furthest waters" search entirely. This is what a restart relies on. + """ + mols = request.getfixturevalue(mols).clone() + + # Pick ions via the heuristic search, recording which molecule(s) were + # converted. + heuristic_mols, _, _, ion_mol_indices = Runner._create_alchemical_ions(mols, 3) + heuristic_ion_numbers = { + mol.number() + for mol in heuristic_mols.molecules()["perturbable"].molecules() + if mol.has_property("is_alchemical_ion") + } + + # Reuse the stored indices directly - should convert the exact same + # molecules, without running the search. + replayed_mols, _, _, replayed_mol_indices = Runner._create_alchemical_ions( + mols, 3, mol_indices=ion_mol_indices + ) + replayed_ion_numbers = { + mol.number() + for mol in replayed_mols.molecules()["perturbable"].molecules() + if mol.has_property("is_alchemical_ion") + } + + assert replayed_ion_numbers == heuristic_ion_numbers + assert replayed_mol_indices == ion_mol_indices + assert math.isclose( + Runner._get_charge_difference(replayed_mols), -3.0, rel_tol=1e-6 + ) + + +def test_alchemical_ion_mol_indices_mismatch_raises(ethane_methanol): + """A stored index count that doesn't match the charge difference should + raise a clear error, rather than silently converting the wrong number of + waters.""" + mols = ethane_methanol.clone() + + with pytest.raises(ValueError, match="does not match the current charge"): + Runner._create_alchemical_ions(mols, 3, mol_indices=[0, 1]) + + +def test_alchemical_ion_restart_reuses_same_ion(ethane_methanol_ions): + """ + Ensure that restarting a run picks the exact same alchemical ion as the + original run, via the persisted `alchemical_ions.npz` file, rather than + re-running the "furthest waters" search from scratch. + """ + mols = ethane_methanol_ions.clone() + + with tempfile.TemporaryDirectory() as tmpdir: + base_config = dict( + output_directory=tmpdir, + platform="cpu", + charge_difference=1, + ) + + # Fresh run: picks an ion via the heuristic search and persists its + # molecule index to alchemical_ions.npz. + runner1 = Runner(mols.clone(), Config(restart=False, **base_config)) + ion_number_1 = next( + mol.number() + for mol in runner1._system.molecules()["perturbable"].molecules() + if mol.has_property("is_alchemical_ion") + ) + + assert (Path(tmpdir) / "alchemical_ions.npz").exists() + + # "Restart": construct a new Runner against the same input and output + # directory. It should reuse the stored ion index rather than + # re-running the search. + runner2 = Runner(mols.clone(), Config(restart=True, **base_config)) + ion_number_2 = next( + mol.number() + for mol in runner2._system.molecules()["perturbable"].molecules() + if mol.has_property("is_alchemical_ion") + ) + + assert ion_number_1 == ion_number_2 @pytest.mark.parametrize("schedule_name", ["decouple", "annihilate"])