Skip to content

Commit

Permalink
Add helper module to convert species from element id in periodic tabl…
Browse files Browse the repository at this point in the history
…e to 0, 1, 2, 3, ... format (#396)

* Init

* fix test

* flake8

* try fix

* Fix stupidity of len(self)
  • Loading branch information
zasdfgbnm authored and farhadrgh committed Nov 20, 2019
1 parent eb89457 commit 66c3743
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unittest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
python-version: [3.6, 3.7]
test-filenames: [
test_aev.py, test_aev_benzene_md.py, test_aev_nist.py, test_aev_tripeptide_md.py,
test_data_new.py, test_utils.py, test_ase.py, test_energies.py,
test_data_new.py, test_utils.py, test_ase.py, test_energies.py, test_nn.py,
test_neurochem.py, test_vibrational.py, test_ensemble.py, test_padding.py,
test_data.py, test_forces.py, test_structure_optim.py, test_jit_builtin_models.py]

Expand Down
33 changes: 33 additions & 0 deletions tests/test_nn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import unittest
import torch
import torchani


class TestSpeciesConverter(unittest.TestCase):

def setUp(self):
self.c = torchani.SpeciesConverter(['H', 'C', 'N', 'O'])

def testSpeciesConverter(self):
input_ = torch.tensor([
[1, 6, 7, 8, -1],
[1, 1, -1, 8, 1],
], dtype=torch.long)
expect = torch.tensor([
[0, 1, 2, 3, -1],
[0, 0, -1, 3, 0],
], dtype=torch.long)
dummy_coordinates = torch.empty(2, 5, 3)
output = self.c((input_, dummy_coordinates)).species
self.assertTrue(torch.allclose(output, expect))


class TestSpeciesConverterJIT(TestSpeciesConverter):

def setUp(self):
super().setUp()
self.c = torch.jit.script(self.c)


if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion tools/comp6.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def by_batch(species, coordinates, model):
energies = []
forces = []
for s, c in zip(species, coordinates):
_, e = model((s, c))
e = model((s, c)).energies
f, = torch.autograd.grad(e.sum(), c)
energies.append(e)
forces.append(f)
Expand Down
17 changes: 7 additions & 10 deletions torchani/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,21 @@
"""

from .utils import EnergyShifter
from .nn import ANIModel, Ensemble
from .nn import ANIModel, Ensemble, SpeciesConverter
from .aev import AEVComputer
from . import utils
from . import neurochem
from . import models
from . import optim
from pkg_resources import get_distribution, DistributionNotFound
import sys

try:
__version__ = get_distribution(__name__).version
except DistributionNotFound:
# package is not installed
pass

__all__ = ['AEVComputer', 'EnergyShifter', 'ANIModel', 'Ensemble',
__all__ = ['AEVComputer', 'EnergyShifter', 'ANIModel', 'Ensemble', 'SpeciesConverter',
'utils', 'neurochem', 'models', 'optim']

try:
Expand All @@ -48,10 +47,8 @@
except ImportError:
pass


if sys.version_info[0] > 2:
try:
from . import data # noqa: F401
__all__.append('data')
except ImportError:
pass
try:
from . import data # noqa: F401
__all__.append('data')
except ImportError:
pass
4 changes: 2 additions & 2 deletions torchani/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class ANI1x(BuiltinNet):
"""

def __init__(self):
super(ANI1x, self).__init__('ani-1x_8x.info')
super().__init__('ani-1x_8x.info')


class ANI1ccx(BuiltinNet):
Expand All @@ -210,4 +210,4 @@ class ANI1ccx(BuiltinNet):
"""

def __init__(self):
super(ANI1ccx, self).__init__('ani-1ccx_8x.info')
super().__init__('ani-1ccx_8x.info')
38 changes: 34 additions & 4 deletions torchani/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ class SpeciesEnergies(NamedTuple):
energies: Tensor


class SpeciesCoordinates(NamedTuple):
species: Tensor
coordinates: Tensor


class ANIModel(torch.nn.ModuleList):
"""ANI model that compute energies from species and AEVs.
Expand All @@ -26,9 +31,6 @@ class ANIModel(torch.nn.ModuleList):
module by putting the same reference in :attr:`modules`.
"""

def __init__(self, modules):
super(ANIModel, self).__init__(modules)

def forward(self, species_aev: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies:
Expand All @@ -54,7 +56,7 @@ class Ensemble(torch.nn.ModuleList):
"""Compute the average output of an ensemble of modules."""

def __init__(self, modules):
super(Ensemble, self).__init__(modules)
super().__init__(modules)
self.size = len(modules)

def forward(self, species_input: Tuple[Tensor, Tensor],
Expand Down Expand Up @@ -89,3 +91,31 @@ class Gaussian(torch.nn.Module):
"""Gaussian activation"""
def forward(self, x: Tensor) -> Tensor:
return torch.exp(- x * x)


class SpeciesConverter(torch.nn.Module):
"""Convert from element index in the periodic table to 0, 1, 2, 3, ..."""

periodic_table = """
H He
Li Be B C N O F Ne
Na Mg Al Si P S Cl Ar
K Ca Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr
Rb Sr Y Zr Nb Mo Tc Ru Rh Pd Ag Cd In Sn Sb Te I Xe
Cs Ba La Ce Pr Nd Pm Sm Eu Gd Tb Dy Ho Er Tm Yb Lu Hf Ta W Re Os Ir Pt Au Hg Tl Pb Bi Po At Rn
Fr Ra Ac Th Pa U Np Pu Am Cm Bk Cf Es Fm Md No Lr Rf Db Sg Bh Hs Mt Ds Rg Cn Nh Fl Mc Lv Ts Og
""".strip().split()

def __init__(self, species):
super().__init__()
rev_idx = {s: k for k, s in enumerate(self.periodic_table, 1)}
maxidx = max(rev_idx.values())
self.conv_tensor = torch.full((maxidx + 2,), -1, dtype=torch.long)
for i, s in enumerate(species):
self.conv_tensor[rev_idx[s]] = i

def forward(self, input_: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None):
species, coordinates = input_
return SpeciesCoordinates(self.conv_tensor[species], coordinates)

0 comments on commit 66c3743

Please sign in to comment.