Skip to content

Commit

Permalink
Delete pad (#393)
Browse files Browse the repository at this point in the history
* Delete pad from utils

* Delete pad_species

Co-authored-by: Farhad Ramezanghorbani <farhadrgh@users.noreply.github.com>
Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>
  • Loading branch information
3 people committed Feb 4, 2020
1 parent 168b059 commit bf8c39e
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 51 deletions.
23 changes: 0 additions & 23 deletions tests/test_padding.py
Expand Up @@ -80,29 +80,6 @@ def testTensorSpecies(self):
self.assertEqual((atomic_properties['species'] - expected_species).abs().max().item(), 0)
self.assertEqual(atomic_properties['coordinates'].abs().max().item(), 0)

def testPadSpecies(self):
species1 = torch.tensor([
[0, 2, 3, 1],
[0, 2, 3, 1],
[0, 2, 3, 1],
[0, 2, 3, 1],
[0, 2, 3, 1],
])
species2 = torch.tensor([[3, 2, 0, 1, 0]]).expand(2, 5)
species = torchani.utils.pad([species1, species2])
self.assertEqual(species.shape[0], 7)
self.assertEqual(species.shape[1], 5)
expected_species = torch.tensor([
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[3, 2, 0, 1, 0],
[3, 2, 0, 1, 0],
])
self.assertEqual((species - expected_species).abs().max().item(), 0)

def testPresentSpecies(self):
species = torch.tensor([0, 1, 1, 0, 3, 7, -1, -1])
present_species = torchani.utils.present_species(species)
Expand Down
29 changes: 1 addition & 28 deletions torchani/utils.py
Expand Up @@ -8,33 +8,6 @@
from .nn import SpeciesEnergies


def pad(species):
"""Put different species together into single tensor.
If the species are from molecules of different number of total atoms, then
ghost atoms with atom type -1 will be added to make it fit into the same
shape.
Arguments:
species (:class:`collections.abc.Sequence`): sequence of species.
Species must be of shape ``(N, A)``, where ``N`` is the number of
3D structures, ``A`` is the number of atoms.
Returns:
:class:`torch.Tensor`: species batched together.
"""
max_atoms = max([s.shape[1] for s in species])
padded_species = []
for s in species:
natoms = s.shape[1]
if natoms < max_atoms:
padding = torch.full((s.shape[0], max_atoms - natoms), -1,
dtype=torch.long, device=s.device)
s = torch.cat([s, padding], dim=1)
padded_species.append(s)
return torch.cat(padded_species)


def pad_atomic_properties(atomic_properties, padding_values=defaultdict(lambda: 0.0, species=-1)):
"""Put a sequence of atomic properties together into single tensor.
Expand Down Expand Up @@ -350,6 +323,6 @@ def vibrational_analysis(masses, hessian, mode_type='MDU', unit='cm^-1'):
return VibAnalysis(wavenumbers, modes, fconstants, rmasses)


__all__ = ['pad', 'pad_atomic_properties', 'present_species', 'hessian',
__all__ = ['pad_atomic_properties', 'present_species', 'hessian',
'vibrational_analysis', 'strip_redundant_padding',
'ChemicalSymbolsToInts']

0 comments on commit bf8c39e

Please sign in to comment.