Skip to content

Commit

Permalink
[JIT] Add TorchScript compatibility for AEVComputer (#303)
Browse files Browse the repository at this point in the history
* make aev,model compatible with jit

* add type annotation to nn

* flake8 fix

* refactor AEVComputer

* fix doc

* an example with padding

* use Optional type instead of padding

* fix

* fix

* make pbc and cell keyword arguments in test_aev

* fix

* make pbc and cell keyword arguments in ase

* fix

* fix

* fix dtype

* fix

* aev_computer dtype to double

* change test files to have aev_computer with keyword argument

* fix JIT types

* add TestAEVJIT

* fix LGTM alerts

* fix TestAEVJIT

* Update aev.py

workaround for dtype in `torch.arange`

* More arange bugs

* Even more arange

* fix LGTM alert
  • Loading branch information
farhadrgh committed Sep 6, 2019
1 parent 3957d19 commit f2170e2
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 59 deletions.
20 changes: 13 additions & 7 deletions tests/test_aev.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def testBenzeneMD(self):
species = self.transform(species)
expected_radial = self.transform(expected_radial)
expected_angular = self.transform(expected_angular)
_, aev = self.aev_computer((species, coordinates, cell, pbc))
_, aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc)
self.assertAEVEqual(expected_radial, expected_angular, aev, 5e-5)

def testTripeptideMD(self):
Expand Down Expand Up @@ -245,6 +245,12 @@ def aev_forward_wrapper(coords):
)


class TestAEVJIT(TestAEV):
def setUp(self):
super().setUp()
self.aev_computer = torch.jit.script(self.aev_computer)


class TestPBCSeeEachOther(unittest.TestCase):
def setUp(self):
self.ani1x = torchani.models.ANI1x()
Expand All @@ -262,11 +268,11 @@ def testTranslationalInvariancePBC(self):
species = torch.tensor([[1, 0, 0, 0, 0]], dtype=torch.long)
pbc = torch.ones(3, dtype=torch.bool)

_, aev = self.aev_computer((species, coordinates, cell, pbc))
_, aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc)

for _ in range(100):
translation = torch.randn(3, dtype=torch.double)
_, aev2 = self.aev_computer((species, coordinates + translation, cell, pbc))
_, aev2 = self.aev_computer((species, coordinates + translation), cell=cell, pbc=pbc)
self.assertTrue(torch.allclose(aev, aev2))

def testPBCConnersSeeEachOther(self):
Expand Down Expand Up @@ -363,7 +369,7 @@ def setUp(self):
self.center_coordinates = self.coordinates + 0.5 * (self.v1 + self.v2 + self.v3)
ani1x = torchani.models.ANI1x()
self.aev_computer = ani1x.aev_computer.to(torch.double)
_, self.aev = self.aev_computer((self.species, self.center_coordinates, self.cell, self.pbc))
_, self.aev = self.aev_computer((self.species, self.center_coordinates), cell=self.cell, pbc=self.pbc)

def assertInCell(self, coordinates):
coordinates_cell = coordinates @ self.inv_cell
Expand All @@ -385,7 +391,7 @@ def testCornerSurfaceAndEdge(self):
self.assertNotInCell(coordinates)
coordinates = torchani.utils.map2central(self.cell, coordinates, self.pbc)
self.assertInCell(coordinates)
_, aev = self.aev_computer((self.species, coordinates, self.cell, self.pbc))
_, aev = self.aev_computer((self.species, coordinates), cell=self.cell, pbc=self.pbc)
self.assertGreater(aev.abs().max().item(), 0)
self.assertTrue(torch.allclose(aev, self.aev))

Expand All @@ -402,7 +408,7 @@ def setUp(self):
species_to_tensor = torchani.utils.ChemicalSymbolsToInts(['H', 'C', 'N', 'O'])
self.species = species_to_tensor(benzene.get_chemical_symbols()).unsqueeze(0)
self.coordinates = torch.tensor(benzene.get_positions()).unsqueeze(0).float()
_, self.aev = self.aev_computer((self.species, self.coordinates, self.cell, self.pbc))
_, self.aev = self.aev_computer((self.species, self.coordinates), cell=self.cell, pbc=self.pbc)
self.natoms = self.aev.shape[1]

def testRepeat(self):
Expand All @@ -416,7 +422,7 @@ def testRepeat(self):
self.coordinates + 3 * c1,
], dim=1)
cell2 = torch.stack([4 * c1, c2, c3])
_, aev2 = self.aev_computer((species2, coordinates2, cell2, self.pbc))
_, aev2 = self.aev_computer((species2, coordinates2), cell=cell2, pbc=self.pbc)
for i in range(3):
aev3 = aev2[:, i * self.natoms: (i + 1) * self.natoms, :]
self.assertTrue(torch.allclose(self.aev, aev3, atol=tolerance))
Expand Down
8 changes: 5 additions & 3 deletions tests/test_energies.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ class TestEnergies(unittest.TestCase):
def setUp(self):
self.tolerance = 5e-5
ani1x = torchani.models.ANI1x()
aev_computer = ani1x.aev_computer
self.aev_computer = ani1x.aev_computer
nnp = ani1x.neural_networks[0]
shift_energy = ani1x.energy_shifter
self.model = torch.nn.Sequential(aev_computer, nnp, shift_energy)
self.nn = torch.nn.Sequential(nnp, shift_energy)
self.model = torch.nn.Sequential(self.aev_computer, nnp, shift_energy)

def random_skip(self):
return False
Expand Down Expand Up @@ -56,7 +57,8 @@ def testBenzeneMD(self):
coordinates = self.transform(coordinates)
species = self.transform(species)
energies = self.transform(energies)
_, energies_ = self.model((species, coordinates, cell, pbc))
_, aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc)
_, energies_ = self.nn((species, aev))
max_diff = (energies - energies_).abs().max().item()
self.assertLess(max_diff, tolerance)

Expand Down
7 changes: 4 additions & 3 deletions tests/test_forces.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def setUp(self):
self.tolerance = 1e-5
ani1x = torchani.models.ANI1x()
self.aev_computer = ani1x.aev_computer
nnp = ani1x.neural_networks[0]
self.model = torch.nn.Sequential(self.aev_computer, nnp)
self.nnp = ani1x.neural_networks[0]
self.model = torch.nn.Sequential(self.aev_computer, self.nnp)

def random_skip(self):
return False
Expand Down Expand Up @@ -82,7 +82,8 @@ def testBenzeneMD(self):
coordinates = self.transform(coordinates)
species = self.transform(species)
forces = self.transform(forces)
_, energies_ = self.model((species, coordinates, cell, pbc))
_, aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc)
_, energies_ = self.nnp((species, aev))
derivative = torch.autograd.grad(energies_.sum(),
coordinates)[0]
max_diff = (forces + derivative).abs().max().item()
Expand Down
80 changes: 39 additions & 41 deletions torchani/aev.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,15 @@
import torch
from . import _six # noqa:F401
import math
from typing import Tuple
from typing import Tuple, Optional


# @torch.jit.script
def cutoff_cosine(distances, cutoff):
# type: (torch.Tensor, float) -> torch.Tensor
# assuming all elements in distances are smaller than cutoff
return 0.5 * torch.cos(distances * (math.pi / cutoff)) + 0.5


# @torch.jit.script
def radial_terms(Rcr, EtaR, ShfR, distances):
# type: (float, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
"""Compute the radial subAEV terms of the center atom given neighbors
Expand Down Expand Up @@ -40,7 +38,6 @@ def radial_terms(Rcr, EtaR, ShfR, distances):
return ret.flatten(start_dim=-2)


# @torch.jit.script
def angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
# type: (float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
"""Compute the angular subAEV terms of the center atom given neighbor pairs.
Expand Down Expand Up @@ -77,8 +74,8 @@ def angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
return ret.flatten(start_dim=-4)


# @torch.jit.script
def compute_shifts(cell, pbc, cutoff):
# type: (torch.Tensor, torch.Tensor, float) -> torch.Tensor
"""Compute the shifts of unit cell along the given cell vectors to make it
large enough to contain all pairs of neighbor atoms with PBC under
consideration
Expand All @@ -95,14 +92,13 @@ def compute_shifts(cell, pbc, cutoff):
:class:`torch.Tensor`: long tensor of shifts. the center cell and
symmetric cells are not included.
"""
# type: (torch.Tensor, torch.Tensor, float) -> torch.Tensor
reciprocal_cell = cell.inverse().t()
inv_distances = reciprocal_cell.norm(2, -1)
num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long)
num_repeats = torch.where(pbc, num_repeats, torch.zeros_like(num_repeats))
r1 = torch.arange(1, num_repeats[0] + 1, device=cell.device)
r2 = torch.arange(1, num_repeats[1] + 1, device=cell.device)
r3 = torch.arange(1, num_repeats[2] + 1, device=cell.device)
r1 = torch.arange(1, num_repeats[0] + 1, device=cell.device, dtype=torch.long)
r2 = torch.arange(1, num_repeats[1] + 1, device=cell.device, dtype=torch.long)
r3 = torch.arange(1, num_repeats[2] + 1, device=cell.device, dtype=torch.long)
o = torch.zeros(1, dtype=torch.long, device=cell.device)
return torch.cat([
torch.cartesian_prod(r1, r2, r3),
Expand All @@ -121,8 +117,8 @@ def compute_shifts(cell, pbc, cutoff):
])


# @torch.jit.script
def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
# type: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
"""Compute pairs of atoms that are neighbors
Arguments:
Expand All @@ -135,21 +131,19 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
cutoff (float): the cutoff inside which atoms are considered pairs
shifts (:class:`torch.Tensor`): tensor of shape (?, 3) storing shifts
"""
# type: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

coordinates = coordinates.detach()
cell = cell.detach()
num_atoms = padding_mask.shape[1]
all_atoms = torch.arange(num_atoms, device=cell.device)
all_atoms = torch.arange(num_atoms, device=cell.device, dtype=torch.long)

# Step 2: center cell
p1_center, p2_center = torch.combinations(all_atoms).unbind(-1)
shifts_center = shifts.new_zeros(p1_center.shape[0], 3)
shifts_center = torch.zeros((p1_center.shape[0], 3), dtype=shifts.dtype, device=shifts.device)

# Step 3: cells with shifts
# shape convention (shift index, molecule index, atom index, 3)
num_shifts = shifts.shape[0]
all_shifts = torch.arange(num_shifts, device=cell.device)
all_shifts = torch.arange(num_shifts, device=cell.device, dtype=torch.long)
shift_index, p1, p2 = torch.cartesian_prod(all_shifts, all_atoms, all_atoms).unbind(-1)
shifts_outide = shifts.index_select(0, shift_index)

Expand All @@ -172,19 +166,19 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
return molecule_index + atom_index1, molecule_index + atom_index2, shifts


# torch.jit.script
def triu_index(num_species):
species = torch.arange(num_species)
# type: (int) -> torch.Tensor
species = torch.arange(num_species, dtype=torch.long)
species1, species2 = torch.combinations(species, r=2, with_replacement=True).unbind(-1)
pair_index = torch.arange(species1.shape[0])
pair_index = torch.arange(species1.shape[0], dtype=torch.long)
ret = torch.zeros(num_species, num_species, dtype=torch.long)
ret[species1, species2] = pair_index
ret[species2, species1] = pair_index
return ret


# torch.jit.script
def convert_pair_index(index):
# type: (torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
"""Let's say we have a pair:
index: 0 1 2 3 4 5 6 7 8 9 ...
elem1: 0 0 1 0 1 2 0 1 2 3 ...
Expand All @@ -208,15 +202,15 @@ def convert_pair_index(index):
return index - num_elems, n + 1


# torch.jit.script
def cumsum_from_zero(input_):
# type: (torch.Tensor) -> torch.Tensor
cumsum = torch.cumsum(input_, dim=0)
cumsum = torch.cat([input_.new_tensor([0]), cumsum[:-1]])
cumsum = torch.cat([torch.tensor([0], dtype=input_.dtype, device=input_.device), cumsum[:-1]])
return cumsum


# torch.jit.script
def triple_by_molecule(atom_index1, atom_index2):
# type: (torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
"""Input: indices for pairs of atoms that are close to each other.
each pair only appear once, i.e. only one of the pairs (1, 2) and
(2, 1) exists.
Expand All @@ -233,16 +227,18 @@ def triple_by_molecule(atom_index1, atom_index2):
sorted_ai1, rev_indices = ai1.sort()

# sort and compute unique key
uniqued_central_atom_index, counts = torch.unique_consecutive(sorted_ai1, return_counts=True)
unique_results = torch.unique_consecutive(sorted_ai1, return_inverse=True, return_counts=True)
uniqued_central_atom_index = unique_results[0]
counts = unique_results[-1]

# do local combinations within unique key, assuming sorted
pair_sizes = counts * (counts - 1) // 2
pair_sizes = (counts * (counts - 1) / 2).long()
total_size = pair_sizes.sum()
pair_indices = torch.repeat_interleave(pair_sizes)
central_atom_index = uniqued_central_atom_index.index_select(0, pair_indices)
cumsum = cumsum_from_zero(pair_sizes)
cumsum = cumsum.index_select(0, pair_indices)
sorted_local_pair_index = torch.arange(total_size, device=cumsum.device) - cumsum
sorted_local_pair_index = torch.arange(total_size, device=cumsum.device, dtype=torch.long) - cumsum
sorted_local_index1, sorted_local_index2 = convert_pair_index(sorted_local_pair_index)
cumsum = cumsum_from_zero(counts)
cumsum = cumsum.index_select(0, pair_indices)
Expand All @@ -259,8 +255,8 @@ def triple_by_molecule(atom_index1, atom_index2):
return central_atom_index, local_index1 % n, local_index2 % n, sign1, sign2


# torch.jit.script
def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes):
# type: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[float, torch.Tensor, torch.Tensor, float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], Tuple[int, int, int, int, int, int]) -> torch.Tensor
Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA = constants
num_species, radial_sublength, radial_length, angular_sublength, angular_length, aev_length = sizes
num_molecules = species.shape[0]
Expand All @@ -279,7 +275,7 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes

# compute radial aev
radial_terms_ = radial_terms(Rcr, EtaR, ShfR, distances)
radial_aev = radial_terms_.new_zeros(num_molecules * num_atoms * num_species, radial_sublength)
radial_aev = torch.zeros((num_molecules * num_atoms * num_species, radial_sublength), dtype=radial_terms_.dtype, device=radial_terms_.device)
index1 = atom_index1 * num_species + species2
index2 = atom_index2 * num_species + species1
radial_aev.index_add_(0, index1, radial_terms_)
Expand All @@ -302,7 +298,7 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes
species1_ = torch.where(sign1 == 1, species2[pair_index1], species1[pair_index1])
species2_ = torch.where(sign2 == 1, species2[pair_index2], species1[pair_index2])
angular_terms_ = angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vec1, vec2)
angular_aev = angular_terms_.new_zeros(num_molecules * num_atoms * num_species_pairs, angular_sublength)
angular_aev = torch.zeros((num_molecules * num_atoms * num_species_pairs, angular_sublength), dtype=angular_terms_.dtype, device=angular_terms_.device)
index = central_atom_index * num_species_pairs + triu_index[species1_, species2_]
angular_aev.index_add_(0, index, angular_terms_)
angular_aev = angular_aev.reshape(num_molecules, num_atoms, angular_length)
Expand Down Expand Up @@ -380,23 +376,24 @@ def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species):
def constants(self):
return self.Rcr, self.EtaR, self.ShfR, self.Rca, self.ShfZ, self.EtaA, self.Zeta, self.ShfA

# @torch.jit.script_method
def forward(self, input_):
def forward(self, input_, cell=None, pbc=None):
# type: (Tuple[torch.Tensor, torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
"""Compute AEVs
Arguments:
input_ (tuple): Can be one of the following two cases:
If you don't care about periodic boundary conditions at all,
then input can be a tuple of two tensors: species and coordinates.
species must have shape ``(C, A)`` and coordinates must have
shape ``(C, A, 3)``, where ``C`` is the number of molecules
in a chunk, and ``A`` is the number of atoms.
then input can be a tuple of two tensors: species and coordinates.
species must have shape ``(C, A)`` and coordinates must have shape
``(C, A, 3)``, where ``C`` is the number of molecules in a chunk
and ``A`` is the number of atoms.
If you want to apply periodic boundary conditions, then the input
would be a tuple of four tensors: species, coordinates, cell, pbc
where species and coordinates are the same as described above, cell
is a tensor of shape (3, 3) of the three vectors defining unit cell:
would be a tuple of two tensors (species, coordinates) plus two keyword
arguments, ``cell=...`` and ``pbc=...``, where species and coordinates are
the same as described above, and cell is a tensor of shape (3, 3) holding
the three vectors defining the unit cell:
.. code-block:: python
Expand All @@ -412,13 +409,14 @@ def forward(self, input_):
unchanged, and AEVs is a tensor of shape
``(C, A, self.aev_length())``
"""
if len(input_) == 2:
species, coordinates = input_
species, coordinates = input_

if cell is None and pbc is None:
cell = self.default_cell
shifts = self.default_shifts
else:
assert len(input_) == 4
species, coordinates, cell, pbc = input_
assert (cell is not None and pbc is not None)
cutoff = max(self.Rcr, self.Rca)
shifts = compute_shifts(cell, pbc, cutoff)

return species, compute_aev(species, coordinates, cell, shifts, self.triu_index, self.constants(), self.sizes)
12 changes: 7 additions & 5 deletions torchani/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ def __init__(self, species, aev_computer, model, energy_shifter, dtype=torch.flo
self.species_to_tensor = utils.ChemicalSymbolsToInts(species)
# aev_computer.neighborlist will be changed later, so we need a copy to
# make sure we do not change the original object
self.aev_computer = copy.deepcopy(aev_computer)
aev_computer = copy.deepcopy(aev_computer)
self.aev_computer = aev_computer.to(dtype)
self.model = copy.deepcopy(model)
self.energy_shifter = copy.deepcopy(energy_shifter)
self.overwrite = overwrite

self.device = self.aev_computer.EtaR.device
self.dtype = dtype

self.whole = torch.nn.Sequential(
self.aev_computer,
self.nn = torch.nn.Sequential(
self.model,
self.energy_shifter
).to(dtype)
Expand Down Expand Up @@ -93,9 +93,11 @@ def calculate(self, atoms=None, properties=['energy'],
strain_y = self.strain(cell, displacement_y, 1)
strain_z = self.strain(cell, displacement_z, 2)
cell = cell + strain_x + strain_y + strain_z
_, energy = self.whole((species, coordinates, cell, pbc))
_, aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc)
else:
_, energy = self.whole((species, coordinates))
_, aev = self.aev_computer((species, coordinates))

_, energy = self.nn((species, aev))
energy *= ase.units.Hartree
self.results['energy'] = energy.item()
self.results['free_energy'] = energy.item()
Expand Down

0 comments on commit f2170e2

Please sign in to comment.