Skip to content

Commit

Permalink
Add example for using TorchScript (#328)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm authored and farhadrgh committed Oct 10, 2019
1 parent 3132928 commit 4bacf40
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Welcome to TorchANI's documentation!

examples/energy_force
examples/ase_interface
examples/jit
examples/vibration_analysis
examples/load_from_neurochem
examples/nnp_training
Expand Down
54 changes: 54 additions & 0 deletions examples/jit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
"""
Using TorchScript to serialize and deploy model
===============================================
Models in TorchANI's model zoo support TorchScript. TorchScript is a way to create
serializable and optimizable models from PyTorch code. It allows users to saved their
models from a Python process and loaded in a process where there is no Python dependency.
"""

###############################################################################
# To begin with, let's first import the modules we will use:
import torch
import torchani

###############################################################################
# Let's now load the built-in ANI-1ccx models. The builtin ANI-1ccx contains 8
# models trained with diffrent initialization.
model = torchani.models.ANI1ccx()

###############################################################################
# It is very easy to compile and save the model using `torch.jit`.
compiled_model = torch.jit.script(model)
torch.jit.save(compiled_model, 'compiled_model.pt')

###############################################################################
# Besides compiling the ensemble, it is also possible to compile a single network
compiled_model0 = torch.jit.script(model[0])
torch.jit.save(compiled_model0, 'compiled_model0.pt')

###############################################################################
# For testing purposes, we will now load the models we just saved and see if they
# produces the same output as the original model:
loaded_compiled_model = torch.jit.load('compiled_model.pt')
loaded_compiled_model0 = torch.jit.load('compiled_model0.pt')


###############################################################################
# We use the molecule below to test:
coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
[-0.83140486, 0.39370209, -0.26395324],
[-0.66518241, -0.84461308, 0.20759389],
[0.45554739, 0.54289633, 0.81170881],
[0.66091919, -0.16799635, -0.91037834]]])
species = model.species_to_tensor('CHHHH').unsqueeze(0)

###############################################################################
# And here is the result:
_, energies_ensemble = model((species, coordinates))
_, energies_single = model[0]((species, coordinates))
_, energies_ensemble_jit = loaded_compiled_model((species, coordinates))
_, energies_single_jit = loaded_compiled_model0((species, coordinates))
print('Ensemble energy, eager mode vs loaded jit:', energies_ensemble.item(), energies_ensemble_jit.item())
print('Single network energy, eager mode vs loaded jit:', energies_single.item(), energies_single_jit.item())

0 comments on commit 4bacf40

Please sign in to comment.