Skip to content

Commit

Permalink
Merge pull request #542 from ReactionMechanismGenerator/scissors_fix
Browse files Browse the repository at this point in the history
Scissors fix
  • Loading branch information
alongd committed Aug 15, 2022
2 parents 05e5ab6 + bb985c3 commit 58b5698
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 12 deletions.
15 changes: 15 additions & 0 deletions arc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1616,3 +1616,18 @@ def safe_copy_file(source: str,
break
if i >= max_cycles:
break


def sort_atoms_in_decending_label_order(mol: 'Molecule') -> None:
    """
    A helper function, useful in the context of atom mapping.
    Re-assign the ``.atoms`` attribute of ``mol`` with the same atoms,
    ordered by the integer value of their labels.
    Note that, despite the function name, the resulting order is ascending.
    For example, if [int(atom.label) for atom in mol.atoms] is [1, 4, 32, 7],
    the atoms are re-ordered so that it becomes [1, 4, 7, 32].

    Args:
        mol (Molecule): An RMG Molecule object with labeled atoms. Its ``.atoms`` attribute is re-assigned.

    Raises:
        ValueError: If an atom label cannot be converted to an integer
                    (``mol.atoms`` is left unchanged in this case).
    """
    try:
        ordered_atoms = sorted(mol.atoms, key=lambda atom: int(atom.label))
    except ValueError as e:
        # Re-raise with the offending labels, the bare int() message does not say which molecule failed.
        labels = [atom.label for atom in mol.atoms]
        raise ValueError(f'Cannot sort atoms by label, all atom labels must be integers, got: {labels}') from e
    # Only assign after a successful sort, so a failure never leaves a partially updated molecule.
    mol.atoms = ordered_atoms

18 changes: 18 additions & 0 deletions arc/commonTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import copy
import datetime
import os
import random
import time
import unittest

Expand Down Expand Up @@ -1186,6 +1187,23 @@ def test_calc_rmsd(self):
rmsd_5 = common.calc_rmsd(a_5, b_5)
self.assertAlmostEqual(rmsd_5, 3.1622776601683795)

def test_sort_atoms_in_decending_label_order(self):
    """Test that sort_atoms_in_decending_label_order() restores the label order of shuffled atoms."""
    # One monocyclic and one fused bicyclic molecule, each processed the same way.
    for smiles in ("C1CCCC1", "NC1=NC=NC2=C1N=CN2"):
        mol = Molecule(smiles=smiles)
        # Label every atom by its current position, then scramble the atom order.
        for position, atom in enumerate(mol.atoms):
            atom.label = str(position)
        random.shuffle(mol.atoms)
        common.sort_atoms_in_decending_label_order(mol)
        # After sorting, each atom must be back at the position matching its label.
        for position, atom in enumerate(mol.atoms):
            self.assertEqual(str(position), atom.label)

def test_check_r_n_p_symbols_between_rmg_and_arc_rxns(self):
"""Test the _check_r_n_p_symbols_between_rmg_and_arc_rxns() function"""
arc_rxn = ARCReaction(r_species=[ARCSpecies(label='CH4', smiles='C'), ARCSpecies(label='OH', smiles='[OH]')],
Expand Down
8 changes: 4 additions & 4 deletions arc/species/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def map_h_abstraction(rxn: 'ARCReaction',
)
spc_r1_h2.final_xyz = spc_r1_h2.get_xyz() # Scissors requires the .final_xyz attribute to be populated.
try:
spc_r1_h2_cuts = spc_r1_h2.scissors()
spc_r1_h2_cuts = spc_r1_h2.scissors(sort_atom_labels=True)
except SpeciesError:
return None
spc_r1_h2_cut = [spc for spc in spc_r1_h2_cuts if spc.label != 'H'][0] \
Expand All @@ -279,7 +279,7 @@ def map_h_abstraction(rxn: 'ARCReaction',
)
spc_r3_h2.final_xyz = spc_r3_h2.get_xyz()
try:
spc_r3_h2_cuts = spc_r3_h2.scissors()
spc_r3_h2_cuts = spc_r3_h2.scissors(sort_atom_labels=True)
except SpeciesError:
return None
spc_r3_h2_cut = [spc for spc in spc_r3_h2_cuts if spc.label != 'H'][0] \
Expand Down Expand Up @@ -447,7 +447,7 @@ def map_intra_h_migration(rxn: 'ARCReaction',
)
spc_r.final_xyz = spc_r.get_xyz() # Scissors requires the .final_xyz attribute to be populated.
try:
spc_r_dot = [spc for spc in spc_r.scissors() if spc.label != 'H'][0]
spc_r_dot = [spc for spc in spc_r.scissors(sort_atom_labels=True) if spc.label != 'H'][0]
except SpeciesError:
return None
spc_p = ARCSpecies(label='P',
Expand All @@ -457,7 +457,7 @@ def map_intra_h_migration(rxn: 'ARCReaction',
)
spc_p.final_xyz = spc_p.get_xyz()
try:
spc_p_dot = [spc for spc in spc_p.scissors() if spc.label != 'H'][0]
spc_p_dot = [spc for spc in spc_p.scissors(sort_atom_labels=True) if spc.label != 'H'][0]
except SpeciesError:
return None
map_ = map_two_species(spc_r_dot, spc_p_dot, backend=backend)
Expand Down
44 changes: 36 additions & 8 deletions arc/species/species.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
rmg_mol_from_dict_repr,
rmg_mol_to_dict_repr,
timedelta_from_str,
sort_atoms_in_decending_label_order,
)
from arc.exceptions import InputError, RotorError, SpeciesError, TSError
from arc.imports import settings
Expand Down Expand Up @@ -248,6 +249,7 @@ class ARCSpecies(object):
rxn_index (int): The reaction index which is the respective key to the Scheduler rxn_dict.
arkane_file (str): Path to the Arkane Species file generated in processor.
yml_path (str): Path to an Arkane YAML file representing a species (for loading the object).
keep_mol (bool): Label to prevent the generation of a new Molecule object.
checkfile (str): The local path to the latest checkfile by Gaussian for the species.
external_symmetry (int): The external symmetry of the species (not including rotor symmetries).
optical_isomers (int): Whether (=2) or not (=1) the species has chiral center/s.
Expand Down Expand Up @@ -317,6 +319,7 @@ def __init__(self,
ts_number: Optional[int] = None,
xyz: Optional[Union[list, dict, str]] = None,
yml_path: Optional[str] = None,
keep_mol: bool = False,
):
self.t1 = None
self.ts_number = ts_number
Expand Down Expand Up @@ -347,6 +350,7 @@ def __init__(self,
self.checkfile = checkfile
self.transport_data = TransportData()
self.yml_path = yml_path
self.keep_mol = keep_mol
self.fragments = fragments
self.original_label = None
self.chosen_ts = None
Expand Down Expand Up @@ -1543,7 +1547,8 @@ def mol_from_xyz(self,
f'{self.mol.copy(deep=True).to_smiles()}\n'
f'{self.mol.copy(deep=True).to_adjacency_list()}')
raise SpeciesError(f'XYZ and the 2D graph representation for {self.label} are not compliant.')
self.mol = perceived_mol
if not self.keep_mol:
self.mol = perceived_mol
else:
mol_s, mol_b = molecules_from_xyz(xyz, multiplicity=self.multiplicity, charge=self.charge)
if mol_b is not None and len(mol_b.atoms) == self.number_of_atoms:
Expand Down Expand Up @@ -1764,13 +1769,25 @@ def check_xyz_isomorphism(self,
logger.warning('Allowing nonisomorphic 2D')
return isomorphic

def scissors(self) -> list:
def label_atoms(self):
    """
    Assign each atom in ``self.mol`` a label equal to its 0-based index in ``self.mol.atoms``.
    The label is stored as a string in the ``atom.label`` attribute.
    """
    for position, atom in enumerate(self.mol.atoms):
        atom.label = f'{position}'

def scissors(self,
sort_atom_labels: bool = False) -> list:
"""
Cut chemical bonds to create new species from the original one according to the .bdes attribute,
preserving the 3D geometry other than the scissioned bond.
If one of the scission-resulting species is a hydrogen atom, it will be returned last, labeled as 'H'.
Other species labels will be <original species label>_BDE_index1_index2_X, where "X" is either "A" or "B",
and the indices are 1-indexed.
Args:
sort_atom_labels (bool, optional): Boolean flag, determines whether the atoms are labeled by index and the atoms of the resulting species are sorted by these labels.
Returns: list
The scission-resulting species.
Expand All @@ -1794,21 +1811,26 @@ def scissors(self) -> list:
atom_indices_reverse = (atom_indices[1], atom_indices[0])
if atom_indices not in self.bdes and atom_indices_reverse not in self.bdes:
self.bdes.append(atom_indices)
if sort_atom_labels:
self.label_atoms()
resulting_species = list()
for index_tuple in self.bdes:
new_species_list = self._scissors(indices=index_tuple)
new_species_list = self._scissors(indices=index_tuple, sort_atom_labels=sort_atom_labels)
for new_species in new_species_list:
if new_species.label not in [existing_species.label for existing_species in resulting_species]:
# Mainly checks that the H species doesn't already exist.
resulting_species.append(new_species)
return resulting_species

def _scissors(self, indices: tuple) -> list:
def _scissors(self,
indices: tuple,
sort_atom_labels: bool = True) -> list:
"""
Cut a chemical bond to create two new species from the original one, preserving the 3D geometry.
Args:
indices (tuple): The atom indices between which to cut (1-indexed, atoms must be bonded).
sort_atom_labels (bool, optional): Boolean flag, determines whether the atoms of each resulting fragment are sorted by their labels.
Returns: list
The scission-resulting species, a list of either one or two species, if the scissored location is linear,
Expand Down Expand Up @@ -1860,6 +1882,10 @@ def _scissors(self, indices: tuple) -> list:
logger.warning(f'Scissors were requested to remove a non-single bond in {self.label}.')
mol_copy.remove_bond(bond)
mol_splits = mol_copy.split()
if sort_atom_labels:
for split in mol_splits:
sort_atoms_in_decending_label_order(split)

if len(mol_splits) == 1: # If cutting leads to only one split, then the split is cyclic.
spc1 = ARCSpecies(label=self.label + '_BDE_' + str(indices[0] + 1) + '_' + str(indices[1] + 1) + '_cyclic',
mol=mol_splits[0],
Expand Down Expand Up @@ -1902,8 +1928,8 @@ def _scissors(self, indices: tuple) -> list:
else:
raise SpeciesError(f'Could not figure out which atom should gain a radical '
f'due to scission in {self.label}')
mol1.update(raise_atomtype_exception=False)
mol2.update(raise_atomtype_exception=False)
mol1.update(log_species=False, raise_atomtype_exception=False, sort_atoms=False)
mol2.update(log_species=False, raise_atomtype_exception=False, sort_atoms=False)

# match xyz to mol:
if len(mol1.atoms) != len(mol2.atoms):
Expand Down Expand Up @@ -1934,7 +1960,8 @@ def _scissors(self, indices: tuple) -> list:
multiplicity=mol1.multiplicity,
charge=mol1.get_net_charge(),
compute_thermo=False,
e0_only=True)
e0_only=True,
keep_mol=True)
spc1.generate_conformers()
spc1.rotors_dict = None
spc2 = ARCSpecies(label=label2,
Expand All @@ -1943,7 +1970,8 @@ def _scissors(self, indices: tuple) -> list:
multiplicity=mol2.multiplicity,
charge=mol2.get_net_charge(),
compute_thermo=False,
e0_only=True)
e0_only=True,
keep_mol=True)
spc2.generate_conformers()
spc2.rotors_dict = None

Expand Down
33 changes: 33 additions & 0 deletions arc/species/speciesTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,6 +1241,13 @@ def test_from_dict(self):
spc = ARCSpecies(species_dict=species_dict)
self.assertTrue(spc.is_ts)

def test_label_atoms(self):
    """Test that label_atoms() labels every atom with its index in mol.atoms."""
    labeled_spc = self.spc6.copy()  # Work on a copy so the shared fixture is not modified.
    labeled_spc.label_atoms()
    for position, atom in enumerate(labeled_spc.mol.atoms):
        self.assertEqual(str(position), atom.label)

def test_copy(self):
"""Test the copy() method."""
spc_copy = self.spc6.copy()
Expand Down Expand Up @@ -2028,6 +2035,32 @@ def test_scissors(self):
self.assertTrue(cycle_scissors[0].mol.is_isomorphic(ARCSpecies(label="check",smiles ="[CH2+]C[CH2+]").mol))
self.assertEqual(len(cycle_scissors), 1)

benzyl_alcohol = ARCSpecies(label='benzyl_alcohol', smiles='c1ccccc1CO',
xyz="""O 2.64838903 0.03033680 1.02963866
C 2.08223673 -0.09327854 -0.26813441
C 0.58011672 -0.03951284 -0.19914397
C -0.09047623 1.18918897 -0.26985124
C -0.16442536 -1.21163631 -0.00891767
C -1.48186739 1.24136671 -0.17379396
C -2.21381021 0.06846364 0.00253129
C -1.55574847 -1.15724814 0.08689917
H 2.47222737 0.71379644 -0.89724902
H 2.41824638 -1.03876722 -0.70676479
H 0.46950724 2.11319745 -0.39919154
H -1.99496868 2.19776599 -0.23432175
H -3.29735459 0.10998171 0.07745660
H -2.12646340 -2.07132623 0.22966400
H 0.33744916 -2.17418164 0.06678265
H 1.91694170 0.12185320 1.66439598""")
benzyl_alcohol.bdes = [(7, 13)]
benzyl_alcohol.final_xyz = benzyl_alcohol.get_xyz()
species = benzyl_alcohol.scissors(sort_atom_labels=True)
for spc in species:
if spc.label != 'H':
for i, atom in enumerate(spc.mol.atoms):
if atom.radical_electrons:
self.assertEqual(i, 6)

def test_net_charged_species(self):
"""Test that we can define, process, and manipulate ions"""
nh4 = ARCSpecies(label='NH4', smiles='[NH4+]', charge=1)
Expand Down

0 comments on commit 58b5698

Please sign in to comment.