Skip to content

Commit

Permalink
Add example to customize JIT models (#401)
Browse files Browse the repository at this point in the history
* add example to customize JIT models

* cleanup

* fix list

* Update jit.py
  • Loading branch information
zasdfgbnm authored and farhadrgh committed Nov 23, 2019
1 parent 9833dd6 commit 124f239
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 2 deletions.
59 changes: 59 additions & 0 deletions examples/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@
# To begin with, let's first import the modules we will use:
import torch
import torchani
from typing import Tuple, Optional
from torch import Tensor

###############################################################################
# Scripting builtin model directly
# --------------------------------
#
# Let's now load the built-in ANI-1ccx models. The builtin ANI-1ccx contains 8
# models trained with diffrent initialization.
model = torchani.models.ANI1ccx(periodic_table_index=True)
Expand Down Expand Up @@ -53,3 +58,57 @@
energies_single_jit = loaded_compiled_model0((species, coordinates)).energies
print('Ensemble energy, eager mode vs loaded jit:', energies_ensemble.item(), energies_ensemble_jit.item())
print('Single network energy, eager mode vs loaded jit:', energies_single.item(), energies_single_jit.item())


###############################################################################
# Customize the model and script
# ------------------------------
#
# You could also customize the model you want to export. For example, let's do
# the following customization to the model:
#
# - uses double as dtype instead of float
# - don't care about periodic boundary condition
# - in addition to energies, allow returnsing optionally forces, and hessians
# - when indexing atom species, use its index in the periodic table instead of 0, 1, 2, 3, ...
#
# you could do the following:
class CustomModule(torch.nn.Module):

def __init__(self):
super().__init__()
self.model = torchani.models.ANI1x(periodic_table_index=True).double()
# self.model = torchani.models.ANI1x(periodic_table_index=True)[0].double()
# self.model = torchani.models.ANI1ccx(periodic_table_index=True).double()

def forward(self, species: Tensor, coordinates: Tensor, return_forces: bool = False,
return_hessians: bool = False) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
if return_forces or return_hessians:
coordinates.requires_grad_(True)

energies = self.model((species, coordinates)).energies

forces: Optional[Tensor] = None # noqa: E701
hessians: Optional[Tensor] = None
if return_forces or return_hessians:
grad = torch.autograd.grad([energies.sum()], [coordinates], create_graph=return_hessians)[0]
assert grad is not None
forces = -grad
if return_hessians:
hessians = torchani.utils.hessian(coordinates, forces=forces)
return energies, forces, hessians


custom_model = CustomModule()
compiled_custom_model = torch.jit.script(custom_model)
torch.jit.save(compiled_custom_model, 'compiled_custom_model.pt')
loaded_compiled_custom_model = torch.jit.load('compiled_custom_model.pt')
energies, forces, hessians = custom_model(species, coordinates, True, True)
energies_jit, forces_jit, hessians_jit = loaded_compiled_custom_model(species, coordinates, True, True)

print('Energy, eager mode vs loaded jit:', energies.item(), energies_jit.item())
print()
print('Force, eager mode vs loaded jit:\n', forces.squeeze(0), '\n', forces_jit.squeeze(0))
print()
torch.set_printoptions(sci_mode=False, linewidth=1000)
print('Hessian, eager mode vs loaded jit:\n', hessians.squeeze(0), '\n', hessians_jit.squeeze(0))
4 changes: 2 additions & 2 deletions torchani/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from typing import Tuple, Optional
from pkg_resources import resource_filename
from . import neurochem
from .nn import Sequential, SpeciesConverter
from .nn import Sequential, SpeciesConverter, SpeciesEnergies
from .aev import AEVComputer


Expand Down Expand Up @@ -97,7 +97,7 @@ def __init__(self, info_file, periodic_table_index=False):

def forward(self, species_coordinates: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
pbc: Optional[Tensor] = None) -> SpeciesEnergies:
"""Calculates predicted properties for minibatch of configurations
Args:
Expand Down

0 comments on commit 124f239

Please sign in to comment.