Skip to content

Commit

Permalink
Use modern type annotations for aev.py (#372)
Browse files Browse the repository at this point in the history
* Use modern type annotations for aev.py

* commit
  • Loading branch information
zasdfgbnm committed Nov 7, 2019
1 parent 7499c8d commit 98bb123
Showing 1 changed file with 17 additions and 20 deletions.
37 changes: 17 additions & 20 deletions torchani/aev.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import torch
from torch import Tensor
import math
from typing import Tuple, Optional


def cutoff_cosine(distances, cutoff):
# type: (torch.Tensor, float) -> torch.Tensor
def cutoff_cosine(distances: Tensor, cutoff: float) -> Tensor:
# assuming all elements in distances are smaller than cutoff
return 0.5 * torch.cos(distances * (math.pi / cutoff)) + 0.5


def radial_terms(Rcr, EtaR, ShfR, distances):
# type: (float, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
def radial_terms(Rcr: float, EtaR: Tensor, ShfR: Tensor, distances: Tensor) -> Tensor:
"""Compute the radial subAEV terms of the center atom given neighbors
This correspond to equation (3) in the `ANI paper`_. This function just
Expand All @@ -36,8 +35,8 @@ def radial_terms(Rcr, EtaR, ShfR, distances):
return ret.flatten(start_dim=-2)


def angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
# type: (float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor,
ShfA: Tensor, vectors1: Tensor, vectors2: Tensor) -> Tensor:
"""Compute the angular subAEV terms of the center atom given neighbor pairs.
This correspond to equation (4) in the `ANI paper`_. This function just
Expand Down Expand Up @@ -72,8 +71,7 @@ def angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
return ret.flatten(start_dim=-4)


def compute_shifts(cell, pbc, cutoff):
# type: (torch.Tensor, torch.Tensor, float) -> torch.Tensor
def compute_shifts(cell: Tensor, pbc: Tensor, cutoff: float) -> Tensor:
"""Compute the shifts of unit cell along the given cell vectors to make it
large enough to contain all pairs of neighbor atoms with PBC under
consideration
Expand Down Expand Up @@ -115,8 +113,8 @@ def compute_shifts(cell, pbc, cutoff):
])


def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
# type: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
def neighbor_pairs(padding_mask: Tensor, coordinates: Tensor, cell: Tensor,
shifts: Tensor, cutoff: float) -> Tuple[Tensor, Tensor, Tensor]:
"""Compute pairs of atoms that are neighbors
Arguments:
Expand Down Expand Up @@ -164,8 +162,7 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
return molecule_index + atom_index1, molecule_index + atom_index2, shifts


def triu_index(num_species):
# type: (int) -> torch.Tensor
def triu_index(num_species: int) -> Tensor:
species1, species2 = torch.triu_indices(num_species, num_species).unbind(0)
pair_index = torch.arange(species1.shape[0], dtype=torch.long)
ret = torch.zeros(num_species, num_species, dtype=torch.long)
Expand All @@ -174,15 +171,13 @@ def triu_index(num_species):
return ret


def cumsum_from_zero(input_):
# type: (torch.Tensor) -> torch.Tensor
def cumsum_from_zero(input_: Tensor) -> Tensor:
cumsum = torch.cumsum(input_, dim=0)
cumsum = torch.cat([input_.new_zeros(1), cumsum[:-1]])
return cumsum


def triple_by_molecule(atom_index1, atom_index2):
# type: (torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
def triple_by_molecule(atom_index1: Tensor, atom_index2: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
"""Input: indices for pairs of atoms that are close to each other.
each pair only appear once, i.e. only one of the pairs (1, 2) and
(2, 1) exists.
Expand Down Expand Up @@ -228,8 +223,10 @@ def triple_by_molecule(atom_index1, atom_index2):
return central_atom_index, local_index1 % n, local_index2 % n, sign1, sign2


def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes):
# type: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[float, torch.Tensor, torch.Tensor, float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], Tuple[int, int, int, int, int, int]) > torch.Tensor
def compute_aev(species: Tensor, coordinates: Tensor, cell: Tensor,
shifts: Tensor, triu_index: Tensor,
constants: Tuple[float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor],
sizes: Tuple[int, int, int, int, int, int]) -> Tensor:
Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA = constants
num_species, radial_sublength, radial_length, angular_sublength, angular_length, aev_length = sizes
num_molecules = species.shape[0]
Expand Down Expand Up @@ -349,8 +346,8 @@ def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species):
def constants(self):
return self.Rcr, self.EtaR, self.ShfR, self.Rca, self.ShfZ, self.EtaA, self.Zeta, self.ShfA

def forward(self, input_, cell=None, pbc=None):
# type: (Tuple[torch.Tensor, torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
def forward(self, input_: Tuple[Tensor, Tensor], cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
"""Compute AEVs
Arguments:
Expand Down

0 comments on commit 98bb123

Please sign in to comment.