Skip to content

Commit

Permalink
__constants__ is deprecated by torch.jit (#378)
Browse files Browse the repository at this point in the history
* __constants__ is deprecated

* commit
  • Loading branch information
zasdfgbnm committed Nov 8, 2019
1 parent d32081e commit 81e6150
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions torchani/aev.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from torch import Tensor
import math
from typing import Tuple, Optional
from torch.jit import Final


def cutoff_cosine(distances: Tensor, cutoff: float) -> Tensor:
Expand Down Expand Up @@ -226,9 +227,9 @@ def triple_by_molecule(atom_index1: Tensor, atom_index2: Tensor) -> Tuple[Tensor
def compute_aev(species: Tensor, coordinates: Tensor, cell: Tensor,
shifts: Tensor, triu_index: Tensor,
constants: Tuple[float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor],
sizes: Tuple[int, int, int, int, int, int]) -> Tensor:
sizes: Tuple[int, int, int, int, int]) -> Tensor:
Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA = constants
num_species, radial_sublength, radial_length, angular_sublength, angular_length, aev_length = sizes
num_species, radial_sublength, radial_length, angular_sublength, angular_length = sizes
num_molecules = species.shape[0]
num_atoms = species.shape[1]
num_species_pairs = angular_length // angular_sublength
Expand Down Expand Up @@ -300,15 +301,24 @@ class AEVComputer(torch.nn.Module):
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""
__constants__ = ['Rcr', 'Rca', 'num_species', 'radial_sublength',
'radial_length', 'angular_sublength', 'angular_length',
'aev_length']
Rcr: Final[float]
Rca: Final[float]
num_species: Final[int]

radial_sublength: Final[int]
radial_length: Final[int]
angular_sublength: Final[int]
angular_length: Final[int]
aev_length: Final[int]
sizes: Final[Tuple[int, int, int, int, int]]

def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species):
super(AEVComputer, self).__init__()
self.Rcr = Rcr
self.Rca = Rca
assert Rca <= Rcr, "Current implementation of AEVComputer assumes Rca <= Rcr"
self.num_species = num_species

# convert constant tensors to a ready-to-broadcast shape
# shape convension (..., EtaR, ShfR)
self.register_buffer('EtaR', EtaR.view(-1, 1))
Expand All @@ -319,7 +329,6 @@ def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species):
self.register_buffer('ShfA', ShfA.view(1, 1, -1, 1))
self.register_buffer('ShfZ', ShfZ.view(1, 1, 1, -1))

self.num_species = num_species
# The length of radial subaev of a single species
self.radial_sublength = self.EtaR.numel() * self.ShfR.numel()
# The length of full radial aev
Expand All @@ -330,7 +339,7 @@ def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species):
self.angular_length = (self.num_species * (self.num_species + 1)) // 2 * self.angular_sublength
# The length of full aev
self.aev_length = self.radial_length + self.angular_length
self.sizes = self.num_species, self.radial_sublength, self.radial_length, self.angular_sublength, self.angular_length, self.aev_length
self.sizes = self.num_species, self.radial_sublength, self.radial_length, self.angular_sublength, self.angular_length

self.register_buffer('triu_index', triu_index(num_species).to(device=self.EtaR.device))

Expand Down

0 comments on commit 81e6150

Please sign in to comment.