Skip to content

Commit

Permalink
Misc improvements (#117)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm committed Oct 6, 2018
1 parent e4fe2a5 commit 3913717
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 10 deletions.
2 changes: 1 addition & 1 deletion tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def setUp(self):

def _test_molecule(self, coordinates, species):
builtins = torchani.neurochem.Builtins()
coordinates = torch.tensor(coordinates, requires_grad=True)
coordinates.requires_grad_(True)
aev = builtins.aev_computer
ensemble = builtins.models
models = [torch.nn.Sequential(aev, m) for m in ensemble]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_forces.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def testIsomers(self):
datafile = os.path.join(path, 'test_data/{}'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, _, _, _, forces = pickle.load(f)
coordinates = torch.tensor(coordinates, requires_grad=True)
coordinates.requires_grad_(True)
_, energies = self.model((species, coordinates))
derivative = torch.autograd.grad(energies.sum(),
coordinates)[0]
Expand All @@ -36,7 +36,7 @@ def testPadding(self):
datafile = os.path.join(path, 'test_data/{}'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, _, _, _, forces = pickle.load(f)
coordinates = torch.tensor(coordinates, requires_grad=True)
coordinates.requires_grad_(True)
species_coordinates.append((species, coordinates))
coordinates_forces.append((coordinates, forces))
species, coordinates = torchani.utils.pad_coordinates(
Expand Down
9 changes: 2 additions & 7 deletions torchani/aev.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,7 @@ def _terms_and_indices(self, species, coordinates):
"""Shape (conformations, atoms, atoms) storing Rij distances"""

padding_mask = (species == -1).unsqueeze(1)
distances = torch.where(
padding_mask,
torch.tensor(math.inf, dtype=self.EtaR.dtype,
device=self.EtaR.device),
distances)
distances = distances.masked_fill(padding_mask, math.inf)

distances, indices = distances.sort(-1)

Expand All @@ -172,11 +168,10 @@ def _terms_and_indices(self, species, coordinates):
radial_terms = self._radial_subaev_terms(distances)

indices_a = indices.index_select(-1, inRca)
new_shape = list(indices_a.shape) + [3]

# TODO: remove this workaround when gather support broadcasting
# https://github.com/pytorch/pytorch/pull/9532
_indices_a = indices_a.unsqueeze(-1).expand(*new_shape)
_indices_a = indices_a.unsqueeze(-1).expand(-1, -1, -1, 3)
vec = vec.gather(-2, _indices_a)

vec = self._combinations(vec, -2)
Expand Down

0 comments on commit 3913717

Please sign in to comment.