Skip to content

Commit

Permalink
guess_elements() function (#264)
Browse files Browse the repository at this point in the history
* new workflows.base.guess_elements() to get elements for RDKit because the MDAnalysis element guesser
  is not very good: here we use the accurate masses from the TPR file to guess any elements where the
  MDA guesser has problems (as indicated by mass discrepancy)
* rtol in guess_elements() to tune how close masses must match
  GROMACS masses are not exactly identical to MDAnalysis masses.
  We need a relative tolerance of at least 1e-3 to reliably
  match them.
* manually set all zero masses to DUMMY (MW is detected as DUMMY, DUMMY as D, so
  this makes all dummies DUMMY for consistency)
* hard coded ATOL=1e-6 for detecting dummies with very small discrepancy from 0
* add tests for known problem cases, no problems, and masses that are slightly off
* updated CHANGES
* updated docs

---------

Co-authored-by: Oliver Beckstein <orbeckst@gmail.com>
  • Loading branch information
cadeduckworth and orbeckst committed Aug 12, 2023
1 parent 10922e8 commit b8842a1
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 9 deletions.
2 changes: 2 additions & 0 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ Enhancements
* new workflows module (#217)
* new automated dihedral analysis workflow (detect dihedrals with SMARTS,
analyze with EnsembleAnalysis, and generate seaborn violinplots) (#217)
* new workflows.base.guess_elements() function to guess elements
from masses (PR #264)
* add new exit_on_error=False|True argument to run.runMD_or_exit() so
that failures just raise exceptions and not call sys.exit() (PR #249)

Expand Down
60 changes: 52 additions & 8 deletions mdpow/tests/test_workflows_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,39 +7,83 @@

import pybol
import pytest
import numpy as np
from numpy.testing import assert_equal
import pandas as pd
import MDAnalysis as mda

from . import RESOURCES, MANIFEST, STATES
from pkg_resources import resource_filename
from mdpow.workflows import base

@pytest.fixture
def molname_workflows_directory(tmp_path):
    """Assemble the ``'workflows'`` manifest entry into a temporary directory.

    The fixture is function-scoped (the pytest default), so every test gets
    its own freshly assembled copy under pytest's ``tmp_path``.

    :returns: the ``tmp_path`` directory into which the data were assembled
    """
    # NOTE: pytest.fixture must be applied exactly once; stacking
    # ``@pytest.fixture(scope='function')`` on top of ``@pytest.fixture``
    # is rejected by pytest and adds nothing because 'function' is the default.
    m = pybol.Manifest(str(MANIFEST))
    m.assemble('workflows', tmp_path)
    return tmp_path

class TestWorkflowsBase(object):

@pytest.fixture
def universe(request):
    """Build a minimal MDAnalysis universe from ``request.param``.

    ``request.param`` must be a ``(masses, names)`` pair of equally long
    sequences; it is supplied through indirect parametrization.

    :returns: an empty :class:`MDAnalysis.Universe` with one atom per name
              and only the ``names`` and ``masses`` topology attributes set
    """
    # NOTE: pytest.fixture is applied once only; the default scope is
    # already 'function', so no explicit scope argument is needed.
    masses, names = request.param
    # build minimal test universe
    u = mda.Universe.empty(n_atoms=len(names))
    u.add_TopologyAttr("names", names)
    u.add_TopologyAttr("masses", masses)
    return u

@pytest.mark.parametrize(
    "universe,elements",
    [
        # carbon-like names with digits/suffixes, next to N, Cl and a
        # massless dummy
        (
            (np.array([12.011, 14.007, 0, 12.011, 35.45, 12.011]),
             np.array(["C", "Nx", "DUMMY", "C0S", "Cl123", "C0U"])),
            np.array(['C', 'N', 'DUMMY', 'C', 'CL', 'C']),
        ),
        # same kind of input without the "C0S"/"C0U" names
        (
            (np.array([12.011, 14.007, 0, 35.45]),
             np.array(["C", "Nx", "DUMMY", "Cl123"])),
            np.array(['C', 'N', 'DUMMY', 'CL']),
        ),
        # water oxygen, massless "MW" site and ion-like names
        (
            (np.array([15.999, 0, 40.08, 40.08, 40.08, 24.305, 132.9]),
             np.array(["OW", "MW", "C0", "CAL", "CA2+", "MG2+", "CES"])),
            np.array(['O', 'DUMMY', 'CA', 'CA', 'CA', 'MG', 'CS']),
        ),
        # masses that are slightly off (16 instead of 15.999, 1e-6 instead of 0)
        (
            (np.array([16, 1e-6, 40.085, 133]),
             np.array(["OW", "MW", "CA2+", "CES"])),
            np.array(['O', 'DUMMY', 'CA', 'CS']),
        ),
    ],
    indirect=["universe"],
)
def test_guess_elements(universe, elements):
    """Elements guessed from the universe's atoms must equal the expected array."""
    assert_equal(base.guess_elements(universe.atoms), elements)



class TestWorkflowsBase(object):
@pytest.fixture
def SM_tmp_dir(self, molname_workflows_directory):
    """Expose the ``molname_workflows_directory`` fixture value under a shorter name."""
    return molname_workflows_directory

@pytest.fixture
def csv_input_data(self):
    """Locate and load the reference ``project_paths.csv``.

    :returns: tuple ``(csv_path, csv_df)`` with the path below
              ``STATES['workflows']`` and the file read into a
              :class:`pandas.DataFrame` whose index has been reset
    """
    # pytest.fixture is applied once only; 'function' is the default scope
    csv_path = STATES['workflows'] / 'project_paths.csv'
    csv_df = pd.read_csv(csv_path).reset_index(drop=True)
    return csv_path, csv_df

@pytest.fixture
def test_df_data(self):
    """Return a two-row DataFrame with ``molecule`` and ``resname`` columns.

    Both columns contain ``'SM25'`` and ``'SM26'``; the index is reset.
    """
    # pytest.fixture is applied once only; 'function' is the default scope
    test_dict = {'molecule' : ['SM25', 'SM26'],
                 'resname' : ['SM25', 'SM26']}
    test_df = pd.DataFrame(test_dict).reset_index(drop=True)
    return test_df

@pytest.fixture
def project_paths_data(self, SM_tmp_dir):
    """Return ``base.project_paths()`` evaluated on the assembled test directory.

    :returns: whatever :func:`mdpow.workflows.base.project_paths` returns for
              ``parent_directory=SM_tmp_dir``
    """
    # pytest.fixture is applied once only; 'function' is the default scope
    project_paths = base.project_paths(parent_directory=SM_tmp_dir)
    return project_paths
Expand All @@ -62,15 +106,15 @@ def test_project_paths_csv_input(self, csv_input_data):
def test_dihedral_analysis_figdir_requirement(self, project_paths_data, caplog):
    """The dihedral workflow must raise AssertionError when no figdir is passed."""
    caplog.clear()
    caplog.set_level(logging.ERROR, logger='mdpow.workflows.base')

    # change resname to match topology (every SAMPL7 resname is 'UNK')
    # only necessary for this dataset, not necessary for normal use
    project_paths_data['resname'] = 'UNK'

    expected_message = "figdir MUST be set, even though it is a kwarg. Will be changed with #244"
    with pytest.raises(AssertionError, match=expected_message):
        base.automated_project_analysis(
            project_paths_data,
            solvents=('water',),
            ensemble_analysis='DihedralAnalysis',
        )

Expand Down
71 changes: 71 additions & 0 deletions mdpow/workflows/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@
.. autofunction:: project_paths
.. autofunction:: automated_project_analysis
.. autofunction:: guess_elements
"""

import os
import re
import logging

import numpy as np
import pandas as pd
from MDAnalysis.topology import guessers, tables

logger = logging.getLogger('mdpow.workflows.base')

Expand Down Expand Up @@ -173,3 +176,71 @@ def automated_project_analysis(project_paths, ensemble_analysis, **kwargs):

logger.info('all analyses completed')
return

def guess_elements(atoms, rtol=1e-3):
    """guess elements for atoms from masses

    Given masses, we perform a reverse lookup on
    :data:`MDAnalysis.topology.tables.masses` to find the corresponding
    element. Only atoms where the standard MDAnalysis guesser finds elements
    with masses contradicting the topology masses are corrected.

    .. Note:: This function *requires* correct masses to be present.
              No sanity checks because MDPOW always uses TPR files that
              contain correct masses.

    :arguments:
        *atoms*
            MDAnalysis AtomGroup *with masses defined*

    :keywords:
        *rtol*
            relative tolerance for a match (as used in :func:`numpy.isclose`);
            atol=1e-6 is at a fixed value, which means that "zero" is only
            recognized for values <= 1e-6

            .. note:: In order to reliably match GROMACS masses, *rtol* should
                      be at least 1e-3.

    :returns:
        *elements*
            array of guessed element symbols, in same order as `atoms`

    :raises:
        *AssertionError*
            if a mass that contradicts the name-based guess does not match
            any reference mass within the tolerances

    .. rubric:: Example

    As an example we guess elements and then set them for all atoms of a
    universe ``u``::

        elements = guess_elements(u.atoms)
        u.add_TopologyAttr("elements", elements)

    """
    ATOL = 1e-6

    names = atoms.names
    masses = atoms.masses

    mda_elements = np.fromiter(tables.masses.keys(), dtype="U5")
    mda_masses = np.fromiter(tables.masses.values(), dtype=np.float64)

    guessed_elements = guessers.guess_types(names)
    guessed_masses = np.array([guessers.get_atom_mass(n) for n in guessed_elements])
    problems = np.logical_not(np.isclose(masses, guessed_masses, atol=ATOL, rtol=rtol))

    if problems.any():
        # match only problematic masses against the MDA reference masses;
        # matches[i, j] is True if problem atom i is compatible with reference j
        problem_masses = masses[problems]
        matches = np.isclose(problem_masses[:, np.newaxis], mda_masses,
                             atol=ATOL, rtol=rtol)

        # We should normally find a match for each problem but just in case,
        # check every atom individually and give useful information for
        # debugging. The exception is raised explicitly (instead of using an
        # assert statement) so that the check is not stripped by "python -O";
        # the exception type stays AssertionError for backward compatibility.
        unmatched = np.logical_not(matches.any(axis=1))
        if unmatched.any():
            raise AssertionError(
                "Not all masses could be assigned an element, "
                f"missing names {set(names[problems][unmatched])}")

        # If a mass is compatible with more than one reference entry (e.g. for
        # a generous rtol), take the closest reference mass so that exactly one
        # element is assigned per atom and the order of atoms is preserved.
        deviation = np.abs(problem_masses[:, np.newaxis] - mda_masses)
        deviation[np.logical_not(matches)] = np.inf
        guessed_elements[problems] = mda_elements[np.argmin(deviation, axis=1)]

    # manually fix some dummies that are labelled "D": set ALL zero masses to DUMMY
    guessed_elements[np.isclose(masses, 0, atol=ATOL)] = "DUMMY"

    return np.array(guessed_elements)
5 changes: 4 additions & 1 deletion mdpow/workflows/dihedrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import svgutils.compose
import svgutils.transform

from .base import guess_elements
from ..analysis import ensemble, dihedral

logger = logging.getLogger('mdpow.workflows.dihedrals')
Expand Down Expand Up @@ -181,11 +182,13 @@ def rdkit_conversion(u, resname):
"""


try:
solute = u.select_atoms(f'resname {resname}')
mol = solute.convert_to('RDKIT')
except AttributeError:
u.add_TopologyAttr("elements", guess_types(u.atoms.names))
guessed_elements = guess_elements(u.atoms)
u.add_TopologyAttr("elements", guessed_elements)
solute = u.select_atoms(f'resname {resname}')
mol = solute.convert_to('RDKIT')

Expand Down

0 comments on commit b8842a1

Please sign in to comment.