Skip to content

Commit

Permalink
Clean up some junk inside ANIModel introduced due to JIT (#362)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm authored and farhadrgh committed Nov 6, 2019
1 parent 86500df commit be58b3f
Showing 1 changed file with 2 additions and 11 deletions.
13 changes: 2 additions & 11 deletions torchani/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,11 @@ class ANIModel(torch.nn.Module):
:attr:`modules`, which means, for example ``modules[i]`` must be
the module for atom type ``i``. Different atom types can share a
module by putting the same reference in :attr:`modules`.
padding_fill (float): The value to fill output of padding atoms.
Padding values will participate in reducing, so this value should
be appropriately chosen so that it has no effect on the result. For
example, if the reducer is :func:`torch.sum`, then
:attr:`padding_fill` should be 0, and if the reducer is
:func:`torch.min`, then :attr:`padding_fill` should be
:obj:`math.inf`.
"""

def __init__(self, modules, padding_fill=0):
def __init__(self, modules):
    """Register the per-atom-type networks as submodules.

    Arguments:
        modules: iterable of modules, one per atom type; ``modules[i]``
            is applied to atoms of type ``i``, and the same module object
            may appear for several types to share parameters.
    """
    super(ANIModel, self).__init__()
    # ModuleList registration makes the wrapped modules' parameters
    # visible to optimizers and to ``state_dict``.
    self.module_list = torch.nn.ModuleList(modules)
self.padding_fill = padding_fill

def __getitem__(self, i):
return self.module_list[i]
Expand All @@ -39,8 +31,7 @@ def forward(self, species_aev):
species_ = species.flatten()
aev = aev.flatten(0, 1)

output = torch.full(species_.shape, self.padding_fill,
dtype=aev.dtype, device=species.device)
output = aev.new_zeros(species_.shape)

for i, m in enumerate(self.module_list):
mask = (species_ == i)
Expand Down

0 comments on commit be58b3f

Please sign in to comment.