Skip to content

Commit

Permalink
Separate out neighborlist computer (#119)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm committed Oct 26, 2018
1 parent 3913717 commit 7c25379
Showing 1 changed file with 59 additions and 44 deletions.
103 changes: 59 additions & 44 deletions torchani/aev.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,46 @@ def _cutoff_cosine(distances, cutoff):
)


def default_neighborlist(species, coordinates, cutoff):
    """Default neighborlist computer.

    Computes all pairwise displacement vectors, sorts each atom's
    neighbors by distance, and keeps the sorted neighbor columns in which
    at least one atom of at least one conformation has a neighbor within
    ``cutoff``.

    Arguments:
        species (:class:`torch.Tensor`): Tensor of shape
            ``(conformations, atoms)``. Entries equal to ``-1`` are
            treated as padding atoms and are never reported as being
            within the cutoff.
        coordinates (:class:`torch.Tensor`): Tensor of shape
            ``(conformations, atoms, 3)``.
        cutoff (float): Cutoff radius used to select neighbor columns.

    Returns:
        tuple: ``(neighbor_species, neighbor_distances,
        neighbor_coordinates)`` of shapes ``(conformations, atoms, N)``,
        ``(conformations, atoms, N)`` and ``(conformations, atoms, N, 3)``.
        ``N`` is shared by all atoms and conformations, so individual
        entries may still be farther than ``cutoff`` (distance ``inf`` and
        species ``-1`` for padding); callers are expected to mask those.
    """
    # Shape (conformations, atoms, atoms, 3) storing Rij vectors,
    # vec[c, i, j] = coordinates[c, i] - coordinates[c, j]
    vec = coordinates.unsqueeze(2) - coordinates.unsqueeze(1)

    # Shape (conformations, atoms, atoms) storing Rij distances
    distances = vec.norm(2, -1)

    # Push padding atoms to infinity so that they sort last and never
    # count as being inside the cutoff.
    padding_mask = (species == -1).unsqueeze(1)
    distances = distances.masked_fill(padding_mask, math.inf)

    # Sort the neighbors of every atom by increasing distance; ``indices``
    # remembers which atom each sorted column came from.
    distances, indices = distances.sort(-1)

    # A sorted column is kept if its smallest distance over all
    # conformations and atoms is within the cutoff. The first kept column
    # is dropped because it holds the zero distance of an atom to itself.
    min_distances, _ = distances.flatten(end_dim=1).min(0)
    in_cutoff = (min_distances <= cutoff).nonzero().flatten()[1:]
    indices = indices.index_select(-1, in_cutoff)

    # TODO: remove this workaround after gather support broadcasting
    atoms = coordinates.shape[1]
    species_ = species.unsqueeze(1).expand(-1, atoms, -1)
    neighbor_species = species_.gather(-1, indices)

    neighbor_distances = distances.index_select(-1, in_cutoff)

    # TODO: remove this workaround when gather support broadcasting
    # https://github.com/pytorch/pytorch/pull/9532
    indices_ = indices.unsqueeze(-1).expand(-1, -1, -1, 3)
    neighbor_coordinates = vec.gather(-2, indices_)
    return neighbor_species, neighbor_distances, neighbor_coordinates


class AEVComputer(torch.nn.Module):
r"""The AEV computer that takes coordinates as input and outputs aevs.
Arguments:
Rcr (:class:`torch.Tensor`): The scalar tensor of :math:`R_C` in
equation (2) when used at equation (3) in the `ANI paper`_.
Rca (:class:`torch.Tensor`): The scalar tensor of :math:`R_C` in
equation (2) when used at equation (4) in the `ANI paper`_.
Rcr (float): :math:`R_C` in equation (2) when used at equation (3)
in the `ANI paper`_.
Rca (float): :math:`R_C` in equation (2) when used at equation (4)
in the `ANI paper`_.
EtaR (:class:`torch.Tensor`): The 1D tensor of :math:`\eta` in
equation (3) in the `ANI paper`_.
ShfR (:class:`torch.Tensor`): The 1D tensor of :math:`R_s` in
Expand All @@ -33,16 +65,26 @@ class AEVComputer(torch.nn.Module):
ShfZ (:class:`torch.Tensor`): The 1D tensor of :math:`\theta_s` in
equation (4) in the `ANI paper`_.
num_species (int): Number of supported atom types.
neighborlist_computer (:class:`collections.abc.Callable`): The callable
(species:Tensor, coordinates:Tensor, cutoff:float)
-> Tuple[Tensor, Tensor, Tensor] that returns the species,
distances and relative coordinates of neighbor atoms. The input
species and coordinates tensor have the same shape convention as
the input of :class:`AEVComputer`. The returned neighbor
species and coordinates tensor must have shape ``(C, A, N)`` and
``(C, A, N, 3)`` correspondingly, where ``C`` is the number of
conformations in a chunk, ``A`` is the number of atoms, and ``N``
is the maximum number of neighbors that an atom could have.
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""

def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ,
num_species):
num_species, neighborlist_computer=default_neighborlist):
super(AEVComputer, self).__init__()
self.register_buffer('Rcr', Rcr)
self.register_buffer('Rca', Rca)
self.Rcr = Rcr
self.Rca = Rca
# convert constant tensors to a ready-to-broadcast shape
# shape convention (..., EtaR, ShfR)
self.register_buffer('EtaR', EtaR.view(-1, 1))
Expand All @@ -54,6 +96,7 @@ def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ,
self.register_buffer('ShfZ', ShfZ.view(1, 1, 1, -1))

self.num_species = num_species
self.neighborlist = neighborlist_computer

def radial_sublength(self):
"""Returns the length of radial subaev of a single species"""
Expand Down Expand Up @@ -147,33 +190,11 @@ def _terms_and_indices(self, species, coordinates):
cutoff radius are valid. The returned indices stores the source of data
before sorting.
"""

vec = coordinates.unsqueeze(2) - coordinates.unsqueeze(1)
"""Shape (conformations, atoms, atoms, 3) storing Rij vectors"""

distances = vec.norm(2, -1)
"""Shape (conformations, atoms, atoms) storing Rij distances"""

padding_mask = (species == -1).unsqueeze(1)
distances = distances.masked_fill(padding_mask, math.inf)

distances, indices = distances.sort(-1)

min_distances, _ = distances.flatten(end_dim=1).min(0)
inRcr = (min_distances <= self.Rcr).nonzero().flatten()[1:]
inRca = (min_distances <= self.Rca).nonzero().flatten()[1:]

distances = distances.index_select(-1, inRcr)
indices_r = indices.index_select(-1, inRcr)
max_cutoff = max([self.Rcr, self.Rca])
species_, distances, vec = self.neighborlist(species, coordinates,
max_cutoff)
radial_terms = self._radial_subaev_terms(distances)

indices_a = indices.index_select(-1, inRca)

# TODO: remove this workaround when gather support broadcasting
# https://github.com/pytorch/pytorch/pull/9532
_indices_a = indices_a.unsqueeze(-1).expand(-1, -1, -1, 3)
vec = vec.gather(-2, _indices_a)

vec = self._combinations(vec, -2)
angular_terms = self._angular_subaev_terms(*vec)

Expand All @@ -182,7 +203,7 @@ def _terms_and_indices(self, species, coordinates):
# (conformations, atoms, pairs, ``self.angular_sublength()``)
# (conformations, atoms, neighbors)
# (conformations, atoms, pairs)
return radial_terms, angular_terms, indices_r, indices_a
return radial_terms, angular_terms, species_

def _combinations(self, tensor, dim=0):
# TODO: remove this when combinations is merged into PyTorch
Expand All @@ -199,16 +220,14 @@ def _combinations(self, tensor, dim=0):
return tensor.index_select(dim, index1), \
tensor.index_select(dim, index2)

def _compute_mask_r(self, species, indices_r):
def _compute_mask_r(self, species_r):
"""Get mask of radial terms for each supported species from indices"""
species_r = species.gather(-1, indices_r)
mask_r = (species_r.unsqueeze(-1) ==
torch.arange(self.num_species, device=self.EtaR.device))
return mask_r

def _compute_mask_a(self, species, indices_a, present_species):
def _compute_mask_a(self, species_a, present_species):
"""Get mask of angular terms for each supported species from indices"""
species_a = species.gather(-1, indices_a)
species_a1, species_a2 = self._combinations(species_a, -1)
mask_a1 = (species_a1.unsqueeze(-1) == present_species).unsqueeze(-1)
mask_a2 = (species_a2.unsqueeze(-1).unsqueeze(-1) == present_species)
Expand Down Expand Up @@ -283,14 +302,10 @@ def forward(self, species_coordinates):

present_species = utils.present_species(species)

# TODO: remove this workaround after gather support broadcasting
atoms = coordinates.shape[1]
species_ = species.unsqueeze(1).expand(-1, atoms, -1)

radial_terms, angular_terms, indices_r, indices_a = \
radial_terms, angular_terms, species_ = \
self._terms_and_indices(species, coordinates)
mask_r = self._compute_mask_r(species_, indices_r)
mask_a = self._compute_mask_a(species_, indices_a, present_species)
mask_r = self._compute_mask_r(species_)
mask_a = self._compute_mask_a(species_, present_species)

radial, angular = self._assemble(radial_terms, angular_terms,
present_species, mask_r, mask_a)
Expand Down

0 comments on commit 7c25379

Please sign in to comment.