Skip to content

Commit

Permalink
Remove unnecessary import (#296)
Browse files Browse the repository at this point in the history
* Remove unnecessary import

* fix
  • Loading branch information
farhadrgh authored and zasdfgbnm committed Aug 20, 2019
1 parent 1455cb3 commit f825c99
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions torchani/aev.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,19 @@
import torch
from . import _six # noqa:F401
import math
from torch import Tensor
from typing import Tuple


# @torch.jit.script
def cutoff_cosine(distances, cutoff):
# type: (Tensor, float) -> Tensor
# type: (torch.Tensor, float) -> torch.Tensor
# assuming all elements in distances are smaller than cutoff
return 0.5 * torch.cos(distances * (math.pi / cutoff)) + 0.5


# @torch.jit.script
def radial_terms(Rcr, EtaR, ShfR, distances):
# type: (float, Tensor, Tensor, Tensor) -> Tensor
# type: (float, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
"""Compute the radial subAEV terms of the center atom given neighbors
This correspond to equation (3) in the `ANI paper`_. This function just
Expand Down Expand Up @@ -43,7 +42,7 @@ def radial_terms(Rcr, EtaR, ShfR, distances):

# @torch.jit.script
def angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
# type: (float, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tensor
# type: (float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
"""Compute the angular subAEV terms of the center atom given neighbor pairs.
This correspond to equation (4) in the `ANI paper`_. This function just
Expand Down Expand Up @@ -96,7 +95,7 @@ def compute_shifts(cell, pbc, cutoff):
:class:`torch.Tensor`: long tensor of shifts. the center cell and
symmetric cells are not included.
"""
# type: (Tensor, Tensor, float) -> Tensor
# type: (torch.Tensor, torch.Tensor, float) -> torch.Tensor
reciprocal_cell = cell.inverse().t()
inv_distances = reciprocal_cell.norm(2, -1)
num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long)
Expand Down Expand Up @@ -136,7 +135,7 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
cutoff (float): the cutoff inside which atoms are considered pairs
shifts (:class:`torch.Tensor`): tensor of shape (?, 3) storing shifts
"""
# type: (Tensor, Tensor, Tensor, Tensor, float) -> Tuple[Tensor, Tensor, Tensor, Tensor]
# type: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

coordinates = coordinates.detach()
cell = cell.detach()
Expand Down

0 comments on commit f825c99

Please sign in to comment.