Skip to content

Commit

Permalink
Atom mapping 2 (#592)
Browse files Browse the repository at this point in the history
New atom mapping architecture and general atom mapping method
  • Loading branch information
alongd committed Mar 19, 2023
2 parents 6d91cd7 + efe70e3 commit 2405f56
Show file tree
Hide file tree
Showing 13 changed files with 2,984 additions and 1,533 deletions.
3 changes: 3 additions & 0 deletions .github/labeler.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ tests:
'Module: Conformers':
- arc/species/conformers.py

'Module: Mapping':
- arc/mapping/*

'Module: Converter':
- arc/species/converter.py

Expand Down
6 changes: 3 additions & 3 deletions arc/checks/ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
)
from arc.imports import settings
from arc.species.converter import check_xyz_dict, displace_xyz, xyz_to_dmat
from arc.species.mapping import (get_atom_indices_of_labeled_atoms_in_an_rmg_reaction,
get_rmg_reactions_from_arc_reaction,
)
from arc.mapping.engine import (get_atom_indices_of_labeled_atoms_in_an_rmg_reaction,
get_rmg_reactions_from_arc_reaction,
)
from arc.statmech.factory import statmech_factory

if TYPE_CHECKING:
Expand Down
48 changes: 48 additions & 0 deletions arc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1665,3 +1665,51 @@ def dfs(mol: Molecule,
stack.append(mol.atoms.index(atom))
visited = sorted(visited) if sort_result else visited
return visited


def sort_atoms_in_descending_label_order(mol: 'Molecule')-> None:
"""
If all atoms in the molecule object has a label, This function reassign the
.atoms in Molecule with a list of atoms with the orders based on the labels of the atoms.
for example, [int(atom.label) for atom in mol.atoms] is [1, 4, 32, 7],
then the function will return the new atom with the order [1, 4, 7, 32]
Args:
mol: An rmg Molecule object, with labeld atoms
"""
if any(atom.label is None for atom in mol.atoms):
return None
try:
mol.atoms = sorted(mol.atoms, key = lambda x: int(x.label))
except ValueError:
logger.warning(f"Some atom(s) in molecule.atoms are not integers.\nGot {[atom.label for atom in mol.atoms]}")
return None


def is_xyz_mol_match(mol: 'Molecule',
xyz: dict) -> bool:
"""
A helper function that matches rmgpy.molecule.molecule.Molecule object to an xyz,
used in _scissors to match xyz and the cut products.
This function only checks the molecular formula.
Args:
mol: rmg Molecule object
xyz: coordinates of the cut product
Returns:
bool: ``True`` if the xyz and molecule match, ``False`` otherwise
"""
element_dict_mol = mol.get_element_count()

element_dict_xyz = dict()
for atom in xyz['symbols']:
if atom in element_dict_xyz:
element_dict_xyz[atom] += 1
else:
element_dict_xyz[atom] = 1

for element, count in element_dict_mol.items():
if element not in element_dict_xyz or element_dict_xyz[element] != count:
return False
return True
56 changes: 47 additions & 9 deletions arc/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
import copy
import datetime
import os
import random
import time
import unittest

import numpy as np
import pandas as pd
from random import shuffle

from rmgpy.molecule.molecule import Molecule
from rmgpy.species import Species
Expand All @@ -22,7 +22,7 @@
from arc.exceptions import InputError, SettingsError
from arc.imports import settings
from arc.rmgdb import make_rmg_database_object, load_families_only
from arc.species.mapping import get_rmg_reactions_from_arc_reaction
from arc.mapping.engine import get_rmg_reactions_from_arc_reaction
import arc.species.converter as converter
from arc.reaction import ARCReaction
from arc.species.species import ARCSpecies
Expand Down Expand Up @@ -897,7 +897,6 @@ def test_globalize_path(self):
globalized_string = common.globalize_path(string=string, project_directory='~/Code/runs/run_1/')
self.assertEqual(globalized_string, ' project_directory: ~/Code/runs/run_1/')


def test_estimate_orca_mem_cpu_requirement(self):
"""Test estimating memory and cpu requirements for an Orca job."""
num_heavy_atoms_0 = 0
Expand Down Expand Up @@ -1238,22 +1237,25 @@ def test_safe_copy_file(self):
common.safe_copy_file(source=source_path, destination=destination_path)
os.remove(destination_path)

def test_sort_atoms_in_decending_label_order(self):
"""tests the sort_atoms_in_decending_label_order function"""
def test_sort_atoms_in_descending_label_order(self):
"""tests the sort_atoms_in_descending_label_order function"""
mol = Molecule(smiles="C1CCCC1")
for index, atom in enumerate(mol.atoms):
atom.label = str(index)
random.shuffle(mol.atoms)
common.sort_atoms_in_descending_label_order(mol)
shuffle(mol.atoms)
common.sort_atoms_in_descending_label_order(mol=mol)
for index, atom in enumerate(mol.atoms):
self.assertEqual(str(index), atom.label)
mol = Molecule(smiles="NC1=NC=NC2=C1N=CN2")
for index, atom in enumerate(mol.atoms):
atom.label = str(index)
random.shuffle(mol.atoms)
common.sort_atoms_in_descending_label_order(mol)
shuffle(mol.atoms)
common.sort_atoms_in_descending_label_order(mol=mol)
for index, atom in enumerate(mol.atoms):
self.assertEqual(str(index), atom.label)
mol = Molecule(smiles = "C")
mol.atoms[0].label = "a"
self.assertIsNone(common.sort_atoms_in_descending_label_order(mol=mol))

def test_check_r_n_p_symbols_between_rmg_and_arc_rxns(self):
"""Test the _check_r_n_p_symbols_between_rmg_and_arc_rxns() function"""
Expand Down Expand Up @@ -1331,6 +1333,42 @@ def test_dfs(self):
self.assertEqual(visited, [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17])
visited = common.dfs(mol=mol, start=21)
self.assertEqual(visited, [5, 18, 19, 20, 21])

def test_is_xyz_mol_match(self):
"""test the is_xyz_mol_match function"""
xyz1 = {'coords': ((0.9177905887, 0.5194617797, 0.0),
(1.8140204898, 1.0381941417, 0.0),
(-0.4763167868, 0.7509348722, 0.0),
(0.999235086, -0.7048575683, 0.0),
(-1.4430010939, 0.0274543367, 0.0),
(-0.6371484821, -0.7497769134, 0.0),
(-2.0093636431, 0.0331190314, -0.8327683174),
(-2.0093636431, 0.0331190314, 0.8327683174)),
'isotopes': (14, 1, 1, 14, 14, 1, 1, 1),
'symbols': ('N', 'H', 'H', 'N', 'N', 'H', 'H', 'H')}
mols = converter.molecules_from_xyz(xyz1)
mol1 = mols[0] or mols[1]
self.assertTrue(common.is_xyz_mol_match(mol1, xyz1))
mol2 = Molecule(smiles = "CCCC")
self.assertFalse(common.is_xyz_mol_match(mol2, xyz1))
xyz2 = {'coords': ((-1.917881683438569, -0.2559899676506647, -0.18387537398950518),
(-0.46613900647877093, -0.5015648201803543, -0.5627969270693719),
(0.46613896124631365, 0.5015651413962865, 0.11351208539519701),
(1.9178817615848325, 0.2559897963130066, -0.2654088973727458),
(-2.2384553065778596, 0.7457958651105412, -0.4871989751724505),
(-2.06265761472002, -0.35084272427108437, 0.8970458843864058),
(-2.567308636432751, -0.9848304423976812, -0.6788069909092624),
(-0.3618598454997, -0.4315366161604987, -1.6518340624292893),
(-0.18705172809174478, -1.521998299852992, -0.27538426879116396),
(0.3618591035654003, 0.43153688022623043, 1.2025494067409448),
(0.18705199598927585, 1.5219988047901418, -0.17390048059784674),
(2.238454802004582, -0.7457965095921628, 0.037914346099216116),
(2.0626580972825423, 0.3508433121264869, -1.3463299029756333),
(2.5673090995664274, 0.9848295801427642, 0.22952353914113932)),
'isotopes': (12, 12, 12, 12, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
'symbols': ('C', 'C', 'C', 'C', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H')}
self.assertFalse(common.is_xyz_mol_match(mol1, xyz2))
self.assertTrue(common.is_xyz_mol_match(mol2, xyz2))

@classmethod
def tearDownClass(cls):
Expand Down
2 changes: 1 addition & 1 deletion arc/job/adapters/ts/heuristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from arc.job.factory import register_job_adapter
from arc.plotter import save_geo
from arc.species.converter import compare_zmats, relocate_zmat_dummy_atoms_to_the_end, zmat_from_xyz, zmat_to_xyz
from arc.species.mapping import map_arc_rmg_species, map_two_species
from arc.mapping.engine import map_arc_rmg_species, map_two_species
from arc.species.species import ARCSpecies, TSGuess, colliding_atoms
from arc.species.zmat import get_parameter_from_atom_indices, remove_1st_atom, up_param

Expand Down
52 changes: 45 additions & 7 deletions arc/job/adapters/ts/heuristics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,9 @@ def test_keeping_atom_order_in_ts(self):
self.assertIn(rxn_1.atom_map[1], [0, 1])
for index in [2, 3, 4, 5, 6, 7]:
self.assertIn(rxn_1.atom_map[index], [2, 3, 4, 5, 6, 16])
self.assertEqual(rxn_1.atom_map[8:], [7, 8, 9, 10, 13, 11, 12, 14, 15])
self.assertEqual(rxn_1.atom_map[8:12], [7, 8, 9, 10])
self.assertIn(tuple(rxn_1.atom_map[12:15]), itertools.permutations([13, 11, 12]))
self.assertIn(rxn_1.atom_map[15:], [[14, 15], [15, 14]])
heuristics_1 = HeuristicsAdapter(job_type='tsg',
reactions=[rxn_1],
testing=True,
Expand All @@ -933,7 +935,12 @@ def test_keeping_atom_order_in_ts(self):
ARCSpecies(label='C2H5', smiles='C[CH2]', xyz=self.c2h5_xyz)])
rxn_2.determine_family(rmg_database=self.rmgdb)
self.assertEqual(rxn_2.family.label, 'H_Abstraction')
self.assertEqual(rxn_2.atom_map, [11, 10, 9, 16, 15, 14, 12, 13, 0, 1, 2, 3, 6, 4, 5, 7, 8])
self.assertEqual(rxn_2.atom_map[:2], [11, 10])
self.assertIn(tuple(rxn_2.atom_map[2:5]), itertools.permutations([9, 16, 15]))
self.assertIn(tuple(rxn_2.atom_map[5:8]), itertools.permutations([12, 13, 14]))
self.assertEqual(rxn_2.atom_map[8:12], [0, 1, 2, 3])
self.assertIn(tuple(rxn_2.atom_map[12:15]), itertools.permutations([4, 5, 6]))
self.assertIn(tuple(rxn_2.atom_map[15:]), itertools.permutations([7, 8]))
heuristics_2 = HeuristicsAdapter(job_type='tsg',
reactions=[rxn_2],
testing=True,
Expand All @@ -952,7 +959,13 @@ def test_keeping_atom_order_in_ts(self):
p_species=[ARCSpecies(label='C2H5', smiles='C[CH2]', xyz=self.c2h5_xyz),
ARCSpecies(label='CCOOH', smiles='CCOO', xyz=self.ccooh_xyz)])
rxn_3.determine_family(rmg_database=self.rmgdb)
self.assertEqual(rxn_3.atom_map, [7, 8, 9, 10, 13, 11, 12, 14, 15, 1, 0, 16, 6, 5, 4, 2, 3])
self.assertEqual(rxn_3.atom_map[:4], [7, 8, 9, 10])
self.assertIn(tuple(rxn_3.atom_map[4:7]), itertools.permutations([11, 12, 13]))
self.assertIn(tuple(rxn_3.atom_map[7:9]), itertools.permutations([14, 15]))
self.assertEqual(rxn_3.atom_map[9:11], [1, 0])
self.assertIn(tuple(rxn_3.atom_map[11:14]), itertools.permutations([16, 5, 6]))
self.assertIn(tuple(rxn_3.atom_map[14:]), itertools.permutations([3, 4, 2]))

heuristics_3 = HeuristicsAdapter(job_type='tsg',
reactions=[rxn_3],
testing=True,
Expand All @@ -971,7 +984,12 @@ def test_keeping_atom_order_in_ts(self):
p_species=[ARCSpecies(label='CCOOH', smiles='CCOO', xyz=self.ccooh_xyz),
ARCSpecies(label='C2H5', smiles='C[CH2]', xyz=self.c2h5_xyz)])
rxn_4.determine_family(rmg_database=self.rmgdb)
self.assertEqual(rxn_4.atom_map, [0, 1, 2, 3, 6, 4, 5, 7, 8, 11, 10, 9, 16, 15, 14, 12, 13])
self.assertEqual(rxn_4.atom_map[:4], [0, 1, 2, 3])
self.assertIn(tuple(rxn_4.atom_map[4:7]), itertools.permutations([4, 5, 6]))
self.assertIn(tuple(rxn_4.atom_map[7:9]), itertools.permutations([7, 8]))
self.assertEqual(rxn_4.atom_map[9:11], [11, 10])
self.assertIn(tuple(rxn_4.atom_map[11:14]), itertools.permutations([9, 15, 16]))
self.assertIn(tuple(rxn_4.atom_map[14:]), itertools.permutations([12, 13, 14 ]))
heuristics_4 = HeuristicsAdapter(job_type='tsg',
reactions=[rxn_4],
testing=True,
Expand Down Expand Up @@ -1093,9 +1111,29 @@ def test_get_new_zmat2_map(self):
reactant_2=reactant_2,
reactants_reversed=True,
)
expected_new_map = {0: 12, 1: 13, 2: 'X24', 3: 14, 4: 15, 5: 16, 6: 'X25', 7: 17, 8: 'X26', 9: 18, 10: 19,
11: 20, 12: 21, 13: 22, 14: 'X27', 15: 23, 16: 'X28', 17: 2, 18: 3, 19: 1, 21: 4, 23: 0,
25: 7, 26: 6, 28: 5, 20: 'X8', 22: 'X9', 24: 'X10', 27: 'X11'}
# To determine if this test fails for atom-mapping related reasons, use the following xyz:
# xyz_7 = {'coords': ((-0.11052302098955041, -0.5106945989206113, -2.3628726319919022),
# (-0.11052302098955041, -0.5106945989206113, -1.16140301180269),
# (-0.11052302098955023, -0.5106945989206112, 0.3150305498367847),
# (1.2448888490560643, -0.9827789526552368, 0.8404002762169092),
# (-0.4375559903969747, 0.8159552435098156, 0.8744100775429131),
# (-0.7036838926552011, 1.8955361195204183, 1.3296134184916002),
# (-0.11052302098955026, -0.5106945989206114, -3.4285156134952786),
# (-1.0248180325342278, -1.3649565013173555, 0.7257981498364177),
# (1.4854985838822663, -1.9838179319127962, 0.46442407690321375),
# (1.2491645770965545, -1.0250999599192192, 1.9356267705316639),
# (-0.939726019056252, 2.853070310535801, 1.733355993511537)),
# 'isotopes': (12, 12, 12, 12, 12, 12, 1, 1, 1, 1, 1),
# 'symbols': ('C', 'C', 'C', 'C', 'C', 'C', 'H', 'H', 'H', 'H', 'H')}
# To generate a reaction, and check it's atom mapping!
# Another mapping option to try is:
# expected_new_map = {0: 12, 1: 13, 2: 'X24', 3: 14, 4: 15, 5: 16, 6: 'X25', 7: 17, 8: 'X26', 9: 18, 10: 19,
# 11: 20, 12: 21, 13: 22, 14: 'X27', 15: 23, 16: 'X28', 17: 2, 18: 3, 19: 1, 21: 4, 23: 0,
# 25: 7, 26: 6, 28: 5, 20: 'X8', 22: 'X9', 24: 'X10', 27: 'X11'}
expected_new_map = {0: 12, 1: 13, 2: 'X24', 3: 14, 4: 15, 5: 16, 6: 'X25', 7: 17, 8: 'X26', 9: 18, 10: 19,
11: 20, 12: 21, 13: 22, 14: 'X27', 15: 23, 16: 'X28', 17: 2, 18: 1, 19: 3, 21: 0, 23: 4,
25: 5, 26: 6, 28: 7, 20: 'X8', 22: 'X9', 24: 'X10', 27: 'X11'}

self.assertEqual(new_map, expected_new_map)

def test_get_new_map_based_on_zmat_1(self):
Expand Down

0 comments on commit 2405f56

Please sign in to comment.