Skip to content

Commit

Permalink
Support periodic table indexing in builtin models (#399)
Browse files Browse the repository at this point in the history
* Support  periodic table Indexing in builtin models

* flake8

* more

* fix

* fix cuda
  • Loading branch information
zasdfgbnm authored and farhadrgh committed Nov 21, 2019
1 parent 1055f1f commit 493731b
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 59 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unittest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
python-version: [3.6, 3.7]
test-filenames: [
test_aev.py, test_aev_benzene_md.py, test_aev_nist.py, test_aev_tripeptide_md.py,
test_data_new.py, test_utils.py, test_ase.py, test_energies.py, test_nn.py,
test_data_new.py, test_utils.py, test_ase.py, test_energies.py, test_periodic_table_indexing.py,
test_neurochem.py, test_vibrational.py, test_ensemble.py, test_padding.py,
test_data.py, test_forces.py, test_structure_optim.py, test_jit_builtin_models.py]

Expand Down
10 changes: 7 additions & 3 deletions examples/energy_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,18 @@

###############################################################################
# Let's now manually specify the device we want TorchANI to run:
device = torch.device('cpu')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

###############################################################################
# Let's now load the built-in ANI-1ccx models. The builtin ANI-1ccx contains 8
# models trained with diffrent initialization. Predicting the energy and force
# using the average of the 8 models outperform using a single model, so it is
# always recommended to use an ensemble, unless the speed of computation is an
# issue in your application.
model = torchani.models.ANI1ccx()
#
# The ``periodic_table_index`` arguments tells TorchANI to use element index
# in periodic table to index species.
model = torchani.models.ANI1ccx(periodic_table_index=True).to(device)

###############################################################################
# Now let's define the coordinate and species. If you just want to compute the
Expand All @@ -40,7 +43,8 @@
[0.45554739, 0.54289633, 0.81170881],
[0.66091919, -0.16799635, -0.91037834]]],
requires_grad=True, device=device)
species = model.species_to_tensor('CHHHH').to(device).unsqueeze(0)
# In periodic table, C = 6 and H = 1
species = torch.tensor([[6, 1, 1, 1, 1]], device=device)

###############################################################################
# Now let's compute energy and force:
Expand Down
5 changes: 3 additions & 2 deletions examples/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
###############################################################################
# Let's now load the built-in ANI-1ccx models. The builtin ANI-1ccx contains 8
# models trained with diffrent initialization.
model = torchani.models.ANI1ccx()
model = torchani.models.ANI1ccx(periodic_table_index=True)

###############################################################################
# It is very easy to compile and save the model using `torch.jit`.
Expand All @@ -42,7 +42,8 @@
[-0.66518241, -0.84461308, 0.20759389],
[0.45554739, 0.54289633, 0.81170881],
[0.66091919, -0.16799635, -0.91037834]]])
species = model.species_to_tensor('CHHHH').unsqueeze(0)
# In periodic table, C = 6 and H = 1
species = torch.tensor([[6, 1, 1, 1, 1]])

###############################################################################
# And here is the result:
Expand Down
33 changes: 0 additions & 33 deletions tests/test_nn.py

This file was deleted.

64 changes: 64 additions & 0 deletions tests/test_periodic_table_indexing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import unittest
import torch
import torchani


class TestSpeciesConverter(unittest.TestCase):

def setUp(self):
self.c = torchani.SpeciesConverter(['H', 'C', 'N', 'O'])

def testSpeciesConverter(self):
input_ = torch.tensor([
[1, 6, 7, 8, -1],
[1, 1, -1, 8, 1],
], dtype=torch.long)
expect = torch.tensor([
[0, 1, 2, 3, -1],
[0, 0, -1, 3, 0],
], dtype=torch.long)
dummy_coordinates = torch.empty(2, 5, 3)
output = self.c((input_, dummy_coordinates)).species
self.assertTrue(torch.allclose(output, expect))


class TestSpeciesConverterJIT(TestSpeciesConverter):

def setUp(self):
super().setUp()
self.c = torch.jit.script(self.c)


class TestBuiltinNetPeriodicTableIndex(unittest.TestCase):

def setUp(self):
self.model1 = torchani.models.ANI1x()
self.model2 = torchani.models.ANI1x(periodic_table_index=True)
self.coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
[-0.83140486, 0.39370209, -0.26395324],
[-0.66518241, -0.84461308, 0.20759389],
[0.45554739, 0.54289633, 0.81170881],
[0.66091919, -0.16799635, -0.91037834]]],
requires_grad=True)
self.species1 = self.model1.species_to_tensor('CHHHH').unsqueeze(0)
self.species2 = torch.tensor([[6, 1, 1, 1, 1]])

def testCH4Ensemble(self):
energy1 = self.model1((self.species1, self.coordinates)).energies
energy2 = self.model2((self.species2, self.coordinates)).energies
derivative1 = torch.autograd.grad(energy1.sum(), self.coordinates)[0]
derivative2 = torch.autograd.grad(energy2.sum(), self.coordinates)[0]
self.assertTrue(torch.allclose(energy1, energy2))
self.assertTrue(torch.allclose(derivative1, derivative2))

def testCH4Single(self):
energy1 = self.model1[0]((self.species1, self.coordinates)).energies
energy2 = self.model2[0]((self.species2, self.coordinates)).energies
derivative1 = torch.autograd.grad(energy1.sum(), self.coordinates)[0]
derivative2 = torch.autograd.grad(energy2.sum(), self.coordinates)[0]
self.assertTrue(torch.allclose(energy1, energy2))
self.assertTrue(torch.allclose(derivative1, derivative2))


if __name__ == '__main__':
unittest.main()
38 changes: 27 additions & 11 deletions torchani/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from typing import Tuple, Optional
from pkg_resources import resource_filename
from . import neurochem
from .nn import Sequential
from .nn import Sequential, SpeciesConverter
from .aev import AEVComputer


Expand Down Expand Up @@ -61,10 +61,15 @@ class BuiltinNet(torch.nn.Module):
aev_computer (:class:`torchani.AEVComputer`): AEV computer with
builtin constants
neural_networks (:class:`torchani.Ensemble`): Ensemble of ANIModel networks
periodic_table_index (bool): Whether to use element number in periodic table
to index species. If set to `False`, then indices must be `0, 1, 2, ..., N - 1`
where `N` is the number of parametrized species.
"""

def __init__(self, info_file):
def __init__(self, info_file, periodic_table_index=False):
super(BuiltinNet, self).__init__()
self.periodic_table_index = periodic_table_index

package_name = '.'.join(__name__.split('.')[:-1])
info_file = 'resources/' + info_file
self.info_file = resource_filename(package_name, info_file)
Expand All @@ -84,6 +89,7 @@ def __init__(self, info_file):

self.consts = neurochem.Constants(self.const_file)
self.species = self.consts.species
self.species_converter = SpeciesConverter(self.species)
self.aev_computer = AEVComputer(**self.consts)
self.energy_shifter = neurochem.load_sae(self.sae_file)
self.neural_networks = neurochem.load_model_ensemble(
Expand All @@ -105,6 +111,8 @@ def forward(self, species_coordinates: Tuple[Tensor, Tensor],
.. note:: The coordinates, and cell are in Angstrom, and the energies
will be in Hartree.
"""
if self.periodic_table_index:
species_coordinates = self.species_converter(species_coordinates)
species_aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc)
species_energies = self.neural_networks(species_aevs)
return self.energy_shifter(species_energies)
Expand All @@ -124,11 +132,19 @@ def __getitem__(self, index):
ret: (:class:`Sequential`): Sequential model ready for
calculations
"""
ret = Sequential(
self.aev_computer,
self.neural_networks[index],
self.energy_shifter
)
if self.periodic_table_index:
ret = Sequential(
self.species_converter,
self.aev_computer,
self.neural_networks[index],
self.energy_shifter
)
else:
ret = Sequential(
self.aev_computer,
self.neural_networks[index],
self.energy_shifter
)

def ase(**kwargs):
"""Attach an ase calculator """
Expand Down Expand Up @@ -189,8 +205,8 @@ class ANI1x(BuiltinNet):
https://aip.scitation.org/doi/abs/10.1063/1.5023802
"""

def __init__(self):
super().__init__('ani-1x_8x.info')
def __init__(self, *args, **kwargs):
super().__init__('ani-1x_8x.info', *args, **kwargs)


class ANI1ccx(BuiltinNet):
Expand All @@ -209,5 +225,5 @@ class ANI1ccx(BuiltinNet):
https://doi.org/10.26434/chemrxiv.6744440.v1
"""

def __init__(self):
super().__init__('ani-1ccx_8x.info')
def __init__(self, *args, **kwargs):
super().__init__('ani-1ccx_8x.info', *args, **kwargs)
8 changes: 1 addition & 7 deletions torchani/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ def __init__(self, modules):
def forward(self, species_aev: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies:
assert cell is None
assert pbc is None
species, aev = species_aev
species_ = species.flatten()
aev = aev.flatten(0, 1)
Expand All @@ -75,8 +73,6 @@ def __init__(self, modules):
def forward(self, species_input: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies:
assert cell is None
assert pbc is None
sum_ = 0
for x in self:
sum_ += x(species_input)[1]
Expand All @@ -95,8 +91,6 @@ def forward(self, input_: Tuple[Tensor, Tensor],
pbc: Optional[Tensor] = None):
for module in self:
input_ = module(input_, cell=cell, pbc=pbc)
cell = None
pbc = None
return input_


Expand All @@ -123,7 +117,7 @@ def __init__(self, species):
super().__init__()
rev_idx = {s: k for k, s in enumerate(self.periodic_table, 1)}
maxidx = max(rev_idx.values())
self.conv_tensor = torch.full((maxidx + 2,), -1, dtype=torch.long)
self.register_buffer('conv_tensor', torch.full((maxidx + 2,), -1, dtype=torch.long))
for i, s in enumerate(species):
self.conv_tensor[rev_idx[s]] = i

Expand Down
2 changes: 0 additions & 2 deletions torchani/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,6 @@ def forward(self, species_energies: Tuple[Tensor, Tensor],
pbc: Optional[Tensor] = None) -> SpeciesEnergies:
"""(species, molecular energies)->(species, molecular energies + sae)
"""
assert cell is None
assert pbc is None
species, energies = species_energies
sae = self.sae(species).to(energies.device)
return SpeciesEnergies(species, energies.to(sae.dtype) + sae)
Expand Down

0 comments on commit 493731b

Please sign in to comment.