diff --git a/tests/test_padding.py b/tests/test_padding.py index 8fdf893db..33955b0f4 100644 --- a/tests/test_padding.py +++ b/tests/test_padding.py @@ -80,29 +80,6 @@ def testTensorSpecies(self): self.assertEqual((atomic_properties['species'] - expected_species).abs().max().item(), 0) self.assertEqual(atomic_properties['coordinates'].abs().max().item(), 0) - def testPadSpecies(self): - species1 = torch.tensor([ - [0, 2, 3, 1], - [0, 2, 3, 1], - [0, 2, 3, 1], - [0, 2, 3, 1], - [0, 2, 3, 1], - ]) - species2 = torch.tensor([[3, 2, 0, 1, 0]]).expand(2, 5) - species = torchani.utils.pad([species1, species2]) - self.assertEqual(species.shape[0], 7) - self.assertEqual(species.shape[1], 5) - expected_species = torch.tensor([ - [0, 2, 3, 1, -1], - [0, 2, 3, 1, -1], - [0, 2, 3, 1, -1], - [0, 2, 3, 1, -1], - [0, 2, 3, 1, -1], - [3, 2, 0, 1, 0], - [3, 2, 0, 1, 0], - ]) - self.assertEqual((species - expected_species).abs().max().item(), 0) - def testPresentSpecies(self): species = torch.tensor([0, 1, 1, 0, 3, 7, -1, -1]) present_species = torchani.utils.present_species(species) diff --git a/torchani/utils.py b/torchani/utils.py index b65cba9fc..c5c9d67b3 100644 --- a/torchani/utils.py +++ b/torchani/utils.py @@ -8,33 +8,6 @@ from .nn import SpeciesEnergies -def pad(species): - """Put different species together into single tensor. - - If the species are from molecules of different number of total atoms, then - ghost atoms with atom type -1 will be added to make it fit into the same - shape. - - Arguments: - species (:class:`collections.abc.Sequence`): sequence of species. - Species must be of shape ``(N, A)``, where ``N`` is the number of - 3D structures, ``A`` is the number of atoms. - - Returns: - :class:`torch.Tensor`: species batched together. - """ - max_atoms = max([s.shape[1] for s in species]) - padded_species = [] - for s in species: - natoms = s.shape[1] - if natoms < max_atoms: - padding = torch.full((s.shape[0], max_atoms - natoms), -1, - dtype=torch.long, device=s.device) - s = torch.cat([s, padding], dim=1) - padded_species.append(s) - return torch.cat(padded_species) - - def pad_atomic_properties(atomic_properties, padding_values=defaultdict(lambda: 0.0, species=-1)): """Put a sequence of atomic properties together into single tensor. @@ -350,6 +323,6 @@ def vibrational_analysis(masses, hessian, mode_type='MDU', unit='cm^-1'): return VibAnalysis(wavenumbers, modes, fconstants, rmasses) -__all__ = ['pad', 'pad_atomic_properties', 'present_species', 'hessian', +__all__ = ['pad_atomic_properties', 'present_species', 'hessian', 'vibrational_analysis', 'strip_redundant_padding', 'ChemicalSymbolsToInts']