diff --git a/CodeEntropy/levels.py b/CodeEntropy/levels.py index d5169c2..62523e6 100644 --- a/CodeEntropy/levels.py +++ b/CodeEntropy/levels.py @@ -62,7 +62,7 @@ def select_levels(self, data_container): "united_atom" ) # every molecule has at least one atom - atoms_in_fragment = fragments[molecule].select_atoms("not name H*") + atoms_in_fragment = fragments[molecule].select_atoms("prop mass > 1.1") number_residues = len(atoms_in_fragment.residues) if len(atoms_in_fragment) > 1: @@ -337,19 +337,23 @@ def get_beads(self, data_container, level): atom_group = "resindex " + str(residue) list_of_beads.append(data_container.select_atoms(atom_group)) - # NOTE this could cause problems for hydrogen or helium molecules if level == "united_atom": list_of_beads = [] - heavy_atoms = data_container.select_atoms("not name H*") - for atom in heavy_atoms: - atom_group = ( - "index " - + str(atom.index) - + " or (name H* and bonded index " - + str(atom.index) - + ")" - ) - list_of_beads.append(data_container.select_atoms(atom_group)) + heavy_atoms = data_container.select_atoms("prop mass > 1.1") + if len(heavy_atoms) == 0: + # molecule without heavy atoms would be a hydrogen molecule + list_of_beads.append(data_container.select_atoms("all")) + else: + # Select one heavy atom and all light atoms bonded to it + for atom in heavy_atoms: + atom_group = ( + "index " + + str(atom.index) + + " or ((prop mass <= 1.1) and bonded index " + + str(atom.index) + + ")" + ) + list_of_beads.append(data_container.select_atoms(atom_group)) logger.debug(f"List of beads: {list_of_beads}") @@ -421,7 +425,7 @@ def get_axes(self, data_container, level, index=0): # Rotation # for united atoms use heavy atoms bonded to the heavy atom atom_set = data_container.select_atoms( - f"not name H* and bonded index {index}" + f"(prop mass > 1.1) and bonded index {index}" ) if len(atom_set) == 0: @@ -432,7 +436,7 @@ def get_axes(self, data_container, level, index=0): atom_group = data_container.select_atoms(f"index {index}") center = atom_group.positions[0] - # get vector for average position of hydrogens + # get vector for average position of bonded atoms vector = self.get_avg_pos(atom_set, center) # use spherical coordinates function to get rotational axes @@ -1125,7 +1129,7 @@ def build_conformational_states( ) heavy_res = ( entropy_manager._run_manager.new_U_select_atom( - res_container, "not name H*" + res_container, "prop mass > 1.1" ) ) states = self.compute_dihedral_conformations( diff --git a/tests/test_CodeEntropy/test_levels.py b/tests/test_CodeEntropy/test_levels.py index 1079f9a..03e739c 100644 --- a/tests/test_CodeEntropy/test_levels.py +++ b/tests/test_CodeEntropy/test_levels.py @@ -421,6 +421,28 @@ def test_get_beads_united_atom_level(self): data_container.select_atoms.call_count, 4 ) # 1 for heavy_atoms + 3 beads + def test_get_beads_hydrogen_molecule(self): + """ + Test `get_beads` for 'united_atom' level. + Should return one bead for molecule with no heavy atoms. + """ + level_manager = LevelManager() + + data_container = MagicMock() + heavy_atoms = [] + data_container.select_atoms.side_effect = [ + heavy_atoms, + "hydrogen", + ] + + result = level_manager.get_beads(data_container, level="united_atom") + + self.assertEqual(len(result), 1) + self.assertEqual(result, ["hydrogen"]) + self.assertEqual( + data_container.select_atoms.call_count, 2 + ) # 1 for heavy_atoms + 1 beads + def test_get_axes_united_atom_no_bonds(self): """ Test `get_axes` for 'united_atom' level when no bonded atoms are found.