Skip to content

Commit

Permalink
Add element names to ANIModel (#398)
Browse files Browse the repository at this point in the history
* Add element names to ANIModel

* nc trainer
  • Loading branch information
zasdfgbnm authored and farhadrgh committed Nov 20, 2019
1 parent 66c3743 commit 1055f1f
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 8 deletions.
5 changes: 4 additions & 1 deletion tests/test_neurochem.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ def testNeuroChemTrainer(self):
# test if loader construct correct model
self.assertEqual(trainer.aev_computer.aev_length, 384)
m = trainer.nn
H, C, N, O = m # noqa: E741
H = m['H']
C = m['C']
N = m['N']
O = m['O'] # noqa: E741
self.assertIsInstance(H[0], torch.nn.Linear)
self.assertListEqual(list(H[0].weight.shape), [160, 384])
self.assertIsInstance(H[1], torch.nn.CELU)
Expand Down
11 changes: 6 additions & 5 deletions torchani/neurochem/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ..utils import EnergyShifter, ChemicalSymbolsToInts
from ..aev import AEVComputer
from ..optim import AdamW
from collections import OrderedDict


class Constants(collections.abc.Mapping):
Expand Down Expand Up @@ -240,10 +241,10 @@ def load_model(species, dir_):
chemical symbols of each supported atom type in correct order.
dir_ (str): String for directory storing network configurations.
"""
models = []
models = OrderedDict()
for i in species:
filename = os.path.join(dir_, 'ANN-{}.nnf'.format(i))
models.append(load_atomic_network(filename))
models[i] = load_atomic_network(filename)
return ANIModel(models)


Expand Down Expand Up @@ -496,8 +497,8 @@ def init_params(m):
input_size, network_setup = network_setup
if input_size != self.aev_computer.aev_length:
raise ValueError('AEV size and input size does not match')
atomic_nets = {}
for atom_type in network_setup:
atomic_nets = OrderedDict()
for atom_type in self.consts.species:
layers = network_setup[atom_type]
modules = []
i = input_size
Expand Down Expand Up @@ -537,7 +538,7 @@ def init_params(m):
'unrecognized parameter in layer setup')
i = o
atomic_nets[atom_type] = torch.nn.Sequential(*modules)
self.nn = ANIModel([atomic_nets[s] for s in self.consts.species])
self.nn = ANIModel(atomic_nets)

# initialize weights and biases
self.nn.apply(init_params)
Expand Down
17 changes: 15 additions & 2 deletions torchani/nn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from collections import OrderedDict
from torch import Tensor
from typing import Tuple, NamedTuple, Optional

Expand All @@ -13,7 +14,7 @@ class SpeciesCoordinates(NamedTuple):
coordinates: Tensor


class ANIModel(torch.nn.ModuleList):
class ANIModel(torch.nn.ModuleDict):
"""ANI model that compute energies from species and AEVs.
Different atom types might have different modules, when computing
Expand All @@ -31,6 +32,18 @@ class ANIModel(torch.nn.ModuleList):
module by putting the same reference in :attr:`modules`.
"""

@staticmethod
def ensureOrderedDict(modules):
if isinstance(modules, OrderedDict):
return modules
od = OrderedDict()
for i, m in enumerate(modules):
od[str(i)] = m
return od

def __init__(self, modules):
super(ANIModel, self).__init__(self.ensureOrderedDict(modules))

def forward(self, species_aev: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies:
Expand All @@ -42,7 +55,7 @@ def forward(self, species_aev: Tuple[Tensor, Tensor],

output = aev.new_zeros(species_.shape)

for i, m in enumerate(self):
for i, (_, m) in enumerate(self.items()):
mask = (species_ == i)
midx = mask.nonzero().flatten()
if midx.shape[0] > 0:
Expand Down

0 comments on commit 1055f1f

Please sign in to comment.