Skip to content

Commit

Permalink
[JIT] Add TorchScript Compatibility for ANIModel and Ensemble (#307)
Browse files Browse the repository at this point in the history
* add test for scripted ensemble

* use torchani.nn.Sequential

* change OrderedDict to module list

* fix

* fix nn.py

* try more fix

* try

* more

* more

* fix more

* rename

* bring ensemble back

* make ANIModel iterable
  • Loading branch information
farhadrgh committed Oct 9, 2019
1 parent b59551d commit a1adceb
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 37 deletions.
17 changes: 17 additions & 0 deletions tests/test_energies.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,29 @@ def testPadding(self):


class TestEnergiesEnergyShifterJIT(TestEnergies):

def setUp(self):
super().setUp()
self.energy_shifter = torch.jit.script(self.energy_shifter)
self.nn = torchani.nn.Sequential(self.nnp, self.energy_shifter)
self.model = torchani.nn.Sequential(self.aev_computer, self.nnp, self.energy_shifter)


class TestEnergiesANIModelJIT(TestEnergies):

def setUp(self):
super().setUp()
self.nnp = torch.jit.script(self.nnp)
self.nn = torchani.nn.Sequential(self.nnp, self.energy_shifter)
self.model = torchani.nn.Sequential(self.aev_computer, self.nnp, self.energy_shifter)


class TestEnergiesJIT(TestEnergies):

def setUp(self):
super().setUp()
self.model = torch.jit.script(self.model)


if __name__ == '__main__':
unittest.main()
21 changes: 14 additions & 7 deletions tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,15 @@ class TestEnsemble(unittest.TestCase):
def setUp(self):
self.tol = 1e-5
self.conformations = 20
ani1x = torchani.models.ANI1x()
self.aev_computer = ani1x.aev_computer
self.model_iterator = ani1x.neural_networks
self.ensemble = torchani.nn.Sequential(self.aev_computer, self.model_iterator)

def _test_molecule(self, coordinates, species):
ani1x = torchani.models.ANI1x()
model_list = [torchani.nn.Sequential(self.aev_computer, m) for m in self.model_iterator]
coordinates.requires_grad_(True)
aev = ani1x.aev_computer
model_iterator = ani1x.neural_networks
model_list = [torch.nn.Sequential(aev, m) for m in model_iterator]
ensemble = torch.nn.Sequential(aev, model_iterator)

_, energy1 = ensemble((species, coordinates))
_, energy1 = self.ensemble((species, coordinates))
force1 = torch.autograd.grad(energy1.sum(), coordinates)[0]
energy2 = [m((species, coordinates))[1] for m in model_list]
energy2 = sum(energy2) / len(model_list)
Expand All @@ -42,5 +41,13 @@ def testGDB(self):
self._test_molecule(coordinates, species)


class TestEnsembleJIT(TestEnsemble):

def setUp(self):
super().setUp()
self.ensemble = torchani.nn.Sequential(self.aev_computer, self.model_iterator)
self.ensemble = torch.jit.script(self.ensemble)


if __name__ == '__main__':
unittest.main()
89 changes: 59 additions & 30 deletions torchani/nn.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import torch
from . import utils
from typing import Tuple


class ANIModel(torch.nn.ModuleList):
class ANIModel(torch.nn.Module):
"""ANI model that compute properties from species and AEVs.
Different atom types might have different modules, when computing
Expand All @@ -17,9 +16,6 @@ class ANIModel(torch.nn.ModuleList):
:attr:`modules`, which means, for example ``modules[i]`` must be
the module for atom type ``i``. Different atom types can share a
module by putting the same reference in :attr:`modules`.
reducer (:class:`collections.abc.Callable`): The callable that reduce
atomic outputs into molecular outputs. It must have signature
``(tensor, dim)->tensor``.
padding_fill (float): The value to fill output of padding atoms.
Padding values will participate in reducing, so this value should
be appropriately chosen so that it has no effect on the result. For
Expand All @@ -29,55 +25,88 @@ class ANIModel(torch.nn.ModuleList):
:obj:`math.inf`.
"""

def __init__(self, modules, reducer=torch.sum, padding_fill=0):
super(ANIModel, self).__init__(modules)
self.reducer = reducer
def __init__(self, modules, padding_fill=0):
super(ANIModel, self).__init__()
self.module_list = torch.nn.ModuleList(modules)
self.padding_fill = padding_fill

def __getitem__(self, i):
return self.module_list[i]

def forward(self, species_aev):
# type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
species, aev = species_aev
species_ = species.flatten()
present_species = utils.present_species(species)
aev = aev.flatten(0, 1)

output = torch.full_like(species_, self.padding_fill,
dtype=aev.dtype)
for i in present_species:
output = torch.full(species_.shape, self.padding_fill,
dtype=aev.dtype)
i = 0
for m in self.module_list:
mask = (species_ == i)
input_ = aev.index_select(0, mask.nonzero().squeeze())
output.masked_scatter_(mask, self[i](input_).squeeze())
i += 1
midx = mask.nonzero().flatten()
if midx.shape[0] > 0:
input_ = aev.index_select(0, midx)
output.masked_scatter_(mask, m(input_).flatten())
output = output.view_as(species)
return species, self.reducer(output, dim=1)
return species, torch.sum(output, dim=1)


class Ensemble(torch.nn.ModuleList):
class Ensemble(torch.nn.Module):
"""Compute the average output of an ensemble of modules."""

# FIXME: due to PyTorch bug, we have to hard code the
# ensemble size to 8.

# def __init__(self, modules):
# super(Ensemble, self).__init__()
# self.modules_list = torch.nn.ModuleList(modules)

# def forward(self, species_input):
# # type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
# outputs = [x(species_input)[1] for x in self.modules_list]
# species, _ = species_input
# return species, sum(outputs) / len(outputs)

def __init__(self, modules):
super(Ensemble, self).__init__()
assert len(modules) == 8
self.model0 = modules[0]
self.model1 = modules[1]
self.model2 = modules[2]
self.model3 = modules[3]
self.model4 = modules[4]
self.model5 = modules[5]
self.model6 = modules[6]
self.model7 = modules[7]

def __getitem__(self, i):
return [self.model0, self.model1, self.model2, self.model3,
self.model4, self.model5, self.model6, self.model7][i]

def forward(self, species_input):
# type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
outputs = [x(species_input)[1] for x in self]
species, _ = species_input
return species, sum(outputs) / len(outputs)
sum_ = self.model0(species_input)[1] + self.model1(species_input)[1] \
+ self.model2(species_input)[1] + self.model3(species_input)[1] \
+ self.model4(species_input)[1] + self.model5(species_input)[1] \
+ self.model6(species_input)[1] + self.model7(species_input)[1]
return species, sum_ / 8.0


class Sequential(torch.nn.Module):
"""Modified Sequential module that accept Tuple type as input"""

def __init__(self, *args):
def __init__(self, *modules):
super(Sequential, self).__init__()
if len(args) == 1 and isinstance(args[0], torch.OrderedDict):
for key, module in args[0].items():
self.add_module(key, module)
else:
for idx, module in enumerate(args):
self.add_module(str(idx), module)

def forward(self, input):
self.modules_list = torch.nn.ModuleList(modules)

def forward(self, input_):
# type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
for module in self._modules.values():
input = module(input)
return input
for module in self.modules_list:
input_ = module(input_)
return input_


class Gaussian(torch.nn.Module):
Expand Down

0 comments on commit a1adceb

Please sign in to comment.