Skip to content

Commit

Permalink
ASE calculator (#121)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm committed Oct 27, 2018
1 parent 84fc8d8 commit f9db30c
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 12 deletions.
4 changes: 4 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ Utilities
.. autofunction:: torchani.utils.pad_coordinates
.. autofunction:: torchani.utils.present_species
.. autofunction:: torchani.utils.strip_redundant_padding
.. autoclass:: torchani.utils.ChemicalSymbolsToInts
:members:


NeuroChem
Expand All @@ -51,6 +53,8 @@ ASE Interface
.. automodule:: torchani.ase
.. autoclass:: torchani.ase.NeighborList
:members:
.. autoclass:: torchani.ase.Calculator
:members:

Ignite Helpers
==============
Expand Down
6 changes: 5 additions & 1 deletion torchani/aev.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ class AEVComputer(torch.nn.Module):
ShfZ (:class:`torch.Tensor`): The 1D tensor of :math:`\theta_s` in
equation (4) in the `ANI paper`_.
num_species (int): Number of supported atom types.
neighborlist_computer (:class:`collections.abc.Callable`): The callable
neighborlist_computer (:class:`collections.abc.Callable`): initial
value of :attr:`neighborlist`
Attributes:
neighborlist (:class:`collections.abc.Callable`): The callable
(species:Tensor, coordinates:Tensor, cutoff:float)
-> Tuple[Tensor, Tensor, Tensor] that returns the species,
distances and relative coordinates of neighbor atoms. The input
Expand Down
47 changes: 47 additions & 0 deletions torchani/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import torch
import ase.neighborlist
from . import utils
import ase.calculators.calculator
import ase.units


class NeighborList:
Expand Down Expand Up @@ -80,3 +82,48 @@ def __call__(self, species, coordinates, cutoff):
return neighbor_species.permute(0, 2, 1), \
neighbor_distances.permute(0, 2, 1), \
neighbor_vecs.permute(0, 2, 1, 3)


class Calculator(ase.calculators.calculator.Calculator):
"""TorchANI calculator for ASE
Arguments:
species (:class:`collections.abc.Sequence` of :class:`str`):
sequence of all supported species, in order.
aev_computer (:class:`torchani.AEVComputer`): AEV computer.
model (:class:`torchani.ANIModel` or :class:`torchani.Ensemble`):
neural network potential models.
energy_shifter (:class:`torchani.EnergyShifter`): Energy shifter.
"""

def __init__(self, species, aev_computer, model, energy_shifter):
self.species_to_tensor = utils.ChemicalSymbolsToInts(species)
self.aev_computer = aev_computer
self.model = model
self.energy_shifter = energy_shifter

self.device = self.aev_computer.EtaR.device
self.dtype = self.aev_computer.EtaR.dtype

self.whole = torch.nn.Sequential(
self.aev_computer,
self.model,
self.energy_shifter
)

def calculate(self, atoms=None, properties=['energy'],
system_changes=ase.calculators.calculator.all_changes):
super(Calculator, self).calculate(atoms, properties, system_changes)
self.aev_computer.neighbor_list = NeighborList(
cell=self.atoms.get_cell(), pbc=self.atoms.get_pbc())
species = self.species_to_tensor(self.atoms.get_chemical_symbols())
coordinates = self.atoms.get_positions(wrap=True).unsqueeze(0)
coordinates = torch.tensor(coordinates,
device=self.device,
dtype=self.dtype,
requires_grad=('forces' in properties))
_, energy = self.whole((species, coordinates)) * ase.units.Hartree
self.results['energy'] = energy.item()
if 'forces' in properties:
forces = -torch.autograd.grad(energy.squeeze(), coordinates)[0]
self.results['forces'] = forces.item()
16 changes: 6 additions & 10 deletions torchani/neurochem/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,18 @@
import timeit
from collections.abc import Mapping
from ..nn import ANIModel, Ensemble, Gaussian
from ..utils import EnergyShifter
from ..utils import EnergyShifter, ChemicalSymbolsToInts
from ..aev import AEVComputer
from ..ignite import Container, MSELoss, TransformedLoss, RMSEMetric, MAEMetric


class Constants(Mapping):
"""NeuroChem constants. Objects of this class can be used as arguments
to :class:`torchani.AEVComputer`, like ``torchani.AEVComputer(**consts)``.
Attributes:
species_to_tensor (:class:`ChemicalSymbolsToInts`): call to convert
string chemical symbols to 1d long tensor.
"""

def __init__(self, filename):
Expand All @@ -45,10 +49,7 @@ def __init__(self, filename):
except Exception:
raise ValueError('unable to parse const file')
self.num_species = len(self.species)
self.rev_species = {}
for i in range(len(self.species)):
s = self.species[i]
self.rev_species[s] = i
self.species_to_tensor = ChemicalSymbolsToInts(self.species)

def __iter__(self):
yield 'Rcr'
Expand All @@ -67,11 +68,6 @@ def __len__(self):
def __getitem__(self, item):
return getattr(self, item)

def species_to_tensor(self, species):
"""Convert species from squence of strings to 1D tensor"""
rev = [self.rev_species[s] for s in species]
return torch.tensor(rev, dtype=torch.long)


def load_sae(filename):
"""Returns an object of :class:`EnergyShifter` with self energies from
Expand Down
22 changes: 21 additions & 1 deletion torchani/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,5 +151,25 @@ def forward(self, species_energies):
return species, energies + sae


class ChemicalSymbolsToInts:
"""Helper that can be called to convert chemical symbol string to integers
Arguments:
all_species (:class:`collections.abc.Sequence` of :class:`str`):
sequence of all supported species, in order.
"""

def __init__(self, all_species):
self.rev_species = {}
for i in range(len(all_species)):
s = all_species[i]
self.rev_species[s] = i

def __call__(self, species):
"""Convert species from squence of strings to 1D tensor"""
rev = [self.rev_species[s] for s in species]
return torch.tensor(rev, dtype=torch.long)


__all__ = ['pad', 'pad_coordinates', 'present_species',
'strip_redundant_padding']
'strip_redundant_padding', 'ChemicalSymbolsToInts']

0 comments on commit f9db30c

Please sign in to comment.