Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 96 additions & 8 deletions src/somd2/runner/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,14 +362,27 @@ def __init__(self, system, config):
# Create alchemical ions.
ion_indices = []
if charge_diff != 0:
self._system, coalchemical_restraints, ion_indices = (
# On restart, reuse the exact molecule(s) chosen as alchemical ions
# in the original run, rather than re-running the "furthest waters"
# search. This makes ion selection independent of GCMC state (or
# anything else that might change between runs), since the search
# is otherwise only reproducible by assumption, not by construction.
mol_indices = None
if self._config.restart:
mol_indices = self._load_alchemical_ion_indices()

self._system, coalchemical_restraints, ion_indices, ion_mol_indices = (
self._create_alchemical_ions(
self._system,
charge_diff,
restraint_distance=self._config.coalchemical_restraint_dist,
mol_indices=mol_indices,
)
)

# Keep the stored indices in sync for any future restart.
self._save_alchemical_ion_indices(ion_mol_indices)

# Add the coalchemical restraints to the extra args.
if coalchemical_restraints is not None:
self._config._extra_args["coalchemical_restraints"] = (
Expand Down Expand Up @@ -1045,8 +1058,49 @@ def _get_charge_difference(system):

return perturbed - reference

def _save_alchemical_ion_indices(self, mol_indices):
    """
    Write the absolute molecule index of every alchemical ion to the file
    "alchemical_ions.npz" inside the configured output directory.

    The file is dedicated to this purpose and is kept separate from the
    per-window (regular runner) and shared (repex) checkpoint formats, so
    that a restart can pick up exactly the ion(s) used by the original run.

    Parameters
    ----------

    mol_indices: [int]
        The absolute molecule index of each alchemical ion.
    """
    import numpy as _np

    # Resolve the destination first, then convert the indices to an integer
    # array and store them under the "mol_indices" key of the archive.
    target = self._config.output_directory / "alchemical_ions.npz"
    index_array = _np.array(mol_indices, dtype=int)
    _np.savez(target, mol_indices=index_array)

def _load_alchemical_ion_indices(self):
    """
    Load the absolute molecule index of each alchemical ion previously
    stored by `_save_alchemical_ion_indices`, if present.

    A missing file is treated as the normal "nothing stored" case and is
    handled silently. A file that exists but cannot be read (corrupt
    archive, missing "mol_indices" key, permission problem, ...) also
    yields None, but a RuntimeWarning is emitted so that the loss of the
    stored indices does not go unnoticed.

    Returns
    -------

    mol_indices: [int], None
        The absolute molecule index of each alchemical ion, or None if no
        stored indices are available (e.g. a fresh run, or a restart from
        an output directory that predates this feature).
    """
    import warnings as _warnings

    import numpy as _np

    path = self._config.output_directory / "alchemical_ions.npz"

    try:
        # An .npz archive is opened lazily and keeps a file handle alive
        # until it is closed, so use it as a context manager.
        with _np.load(path) as archive:
            return archive["mol_indices"].tolist()
    except FileNotFoundError:
        # Nothing was stored: fresh run or pre-feature output directory.
        return None
    except Exception as error:
        # Remain best-effort (the caller copes with None), but make the
        # failure visible rather than silently discarding the stored ions.
        _warnings.warn(
            f"Unable to read stored alchemical ion indices from '{path}' "
            f"({error!r}); treating them as unavailable.",
            RuntimeWarning,
            stacklevel=2,
        )
        return None

@staticmethod
def _create_alchemical_ions(system, charge_diff, restraint_distance=None):
def _create_alchemical_ions(
system, charge_diff, restraint_distance=None, mol_indices=None
):
"""
Internal function to create alchemical ions to maintain a constant charge.

Expand All @@ -1059,6 +1113,14 @@ def _create_alchemical_ions(system, charge_diff, restraint_distance=None):
charge_diff: int
The charge difference between perturbed and reference states.

mol_indices: [int]
The absolute molecule index (position in `system.molecules()`) of
each water to convert into an alchemical ion. If provided, these
molecules are converted directly, bypassing the "furthest waters"
search. Used on restart to reproduce the exact same ion(s) chosen
in the original run, independent of any GCMC state or changes to
the search heuristic. Must have the same length as `abs(charge_diff)`.

Returns
-------

Expand All @@ -1073,6 +1135,11 @@ def _create_alchemical_ions(system, charge_diff, restraint_distance=None):
The perturbable-molecule index of each alchemical ion that was
added, suitable for use with
`LambdaSchedule.set_molecule_schedule <sire.cas.LambdaSchedule>`.

ion_mol_indices: [int]
The absolute molecule index (position in `system.molecules()`,
prior to any conversion) of each alchemical ion that was added.
Suitable for passing back in as `mol_indices` on a restart.
"""

from sire.legacy.IO import createChlorineIon as _createChlorineIon
Expand Down Expand Up @@ -1116,12 +1183,28 @@ def _create_alchemical_ions(system, charge_diff, restraint_distance=None):
f"{len(system['water'].molecules())} available."
)

# Reference coordinates.
coords = system.molecules("property is_perturbable").coordinates()
coord_string = f"{coords[0].value()}, {coords[1].value()}, {coords[2].value()}"
if mol_indices is not None:
if len(mol_indices) != num_waters:
raise ValueError(
f"Number of stored alchemical-ion molecule indices "
f"({len(mol_indices)}) does not match the current charge "
f"difference ({num_waters} waters required)."
)

# Reuse the exact molecules chosen in the original run.
all_mols = system.molecules()
waters = [all_mols[idx] for idx in mol_indices]
else:
# Reference coordinates.
coords = system.molecules("property is_perturbable").coordinates()
coord_string = (
f"{coords[0].value()}, {coords[1].value()}, {coords[2].value()}"
)

# Find the furthest N waters from the perturbable molecule.
waters = system[f"furthest {num_waters} waters from {coord_string}"].molecules()
# Find the furthest N waters from the perturbable molecule.
waters = system[
f"furthest {num_waters} waters from {coord_string}"
].molecules()

# Determine the water model.
if waters[0].num_atoms() == 3:
Expand All @@ -1141,6 +1224,10 @@ def _create_alchemical_ions(system, charge_diff, restraint_distance=None):
# Store the molecule numbers of the alchemical ions.
ion_numbers = []

# Store the absolute molecule index of each alchemical ion (prior to
# conversion), for persisting across restarts.
ion_mol_indices = []

# Create the ions.
for water in waters:
# Flag to indicate whether we need to reverse the alchemical ion
Expand Down Expand Up @@ -1261,6 +1348,7 @@ def _create_alchemical_ions(system, charge_diff, restraint_distance=None):

# Get the index of the perturbed water.
index = numbers.index(water.number())
ion_mol_indices.append(index)

# Log that we are adding an alchemical ion.
if is_reverse:
Expand All @@ -1283,7 +1371,7 @@ def _create_alchemical_ions(system, charge_diff, restraint_distance=None):
perturbable_mols = system.molecules()["perturbable"].molecules()
ion_indices = [perturbable_mols.find(system[number]) for number in ion_numbers]

return system, restraints, ion_indices
return system, restraints, ion_indices, ion_mol_indices

@staticmethod
def _create_filenames(lambda_array, lambda_value, output_directory, restart=False):
Expand Down
96 changes: 93 additions & 3 deletions tests/runner/test_alchemical_ions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import math
import pytest
import tempfile

import pytest
from pathlib import Path

from somd2.config import Config
from somd2.runner import Runner
Expand All @@ -15,20 +16,109 @@ def test_alchemical_ions(mols, request):
mols = request.getfixturevalue(mols).clone()

# Add 10 Cl- ions.
new_mols, _, ion_indices = Runner._create_alchemical_ions(mols, 10)
new_mols, _, ion_indices, ion_mol_indices = Runner._create_alchemical_ions(mols, 10)

# Make sure the charge difference is correct.
assert math.isclose(Runner._get_charge_difference(new_mols), -10.0, rel_tol=1e-6)

# Make sure there is one perturbable-molecule index per ion.
assert len(ion_indices) == 10
assert len(ion_mol_indices) == 10

# Add 10 Na+ ions.
new_mols, _, ion_indices = Runner._create_alchemical_ions(mols, -10)
new_mols, _, ion_indices, ion_mol_indices = Runner._create_alchemical_ions(
mols, -10
)

# Make sure the charge difference is correct.
assert math.isclose(Runner._get_charge_difference(new_mols), 10.0, rel_tol=1e-6)
assert len(ion_indices) == 10
assert len(ion_mol_indices) == 10


@pytest.mark.parametrize("mols", ["ethane_methanol", "ethane_methanol_ions"])
def test_alchemical_ion_mol_indices_reproducible(mols, request):
    """
    Feeding the molecule indices returned by one call to
    `_create_alchemical_ions` back in via `mol_indices` must convert the
    very same molecules, without going through the "furthest waters"
    search. Restarts depend on this.
    """

    def alchemical_ion_numbers(system):
        # Molecule numbers of all perturbable molecules that carry the
        # "is_alchemical_ion" property.
        perturbable = system.molecules()["perturbable"].molecules()
        return {
            mol.number()
            for mol in perturbable
            if mol.has_property("is_alchemical_ion")
        }

    mols = request.getfixturevalue(mols).clone()

    # First pass: let the heuristic search choose the ions and record
    # which molecules were converted.
    searched_system, _, _, ion_mol_indices = Runner._create_alchemical_ions(mols, 3)
    expected_numbers = alchemical_ion_numbers(searched_system)

    # Second pass: hand the recorded indices straight back in.
    replayed_system, _, _, replayed_indices = Runner._create_alchemical_ions(
        mols, 3, mol_indices=ion_mol_indices
    )
    replayed_numbers = alchemical_ion_numbers(replayed_system)

    assert replayed_numbers == expected_numbers
    assert replayed_indices == ion_mol_indices

    charge_diff = Runner._get_charge_difference(replayed_system)
    assert math.isclose(charge_diff, -3.0, rel_tol=1e-6)


def test_alchemical_ion_mol_indices_mismatch_raises(ethane_methanol):
    """
    Supplying fewer stored indices than the charge difference requires
    must raise a ValueError with a clear message instead of quietly
    converting the wrong number of waters.
    """
    system = ethane_methanol.clone()

    # Two indices for a charge difference of three.
    stored_indices = [0, 1]
    expected_message = "does not match the current charge"

    with pytest.raises(ValueError, match=expected_message):
        Runner._create_alchemical_ions(system, 3, mol_indices=stored_indices)


def test_alchemical_ion_restart_reuses_same_ion(ethane_methanol_ions):
    """
    A restarted run must end up with the same alchemical ion as the run it
    restarts from, by reading the persisted `alchemical_ions.npz` file
    instead of repeating the "furthest waters" search.
    """

    def first_alchemical_ion_number(runner):
        # Number of the first perturbable molecule in the runner's system
        # that carries the "is_alchemical_ion" property.
        perturbable = runner._system.molecules()["perturbable"].molecules()
        return next(
            mol.number()
            for mol in perturbable
            if mol.has_property("is_alchemical_ion")
        )

    mols = ethane_methanol_ions.clone()

    with tempfile.TemporaryDirectory() as tmpdir:
        shared_options = {
            "output_directory": tmpdir,
            "platform": "cpu",
            "charge_difference": 1,
        }

        # Fresh run: the ion comes from the heuristic search and its
        # molecule index is written to alchemical_ions.npz.
        fresh_runner = Runner(mols.clone(), Config(restart=False, **shared_options))
        fresh_ion_number = first_alchemical_ion_number(fresh_runner)

        stored_file = Path(tmpdir) / "alchemical_ions.npz"
        assert stored_file.exists()

        # "Restart": a second Runner built from the same input and output
        # directory should pick up the stored index.
        restarted_runner = Runner(
            mols.clone(), Config(restart=True, **shared_options)
        )
        restarted_ion_number = first_alchemical_ion_number(restarted_runner)

        assert fresh_ion_number == restarted_ion_number


@pytest.mark.parametrize("schedule_name", ["decouple", "annihilate"])
Expand Down