Skip to content

Commit

Permalink
Use namedtuple to improve API while still maintaining backward compat…
Browse files Browse the repository at this point in the history
…ibility (#380)

* Use namedtuple to improve API

* improve
  • Loading branch information
zasdfgbnm authored and farhadrgh committed Nov 8, 2019
1 parent 92c307d commit 004f5a5
Show file tree
Hide file tree
Showing 11 changed files with 46 additions and 31 deletions.
1 change: 0 additions & 1 deletion examples/ase_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

###############################################################################
# To begin with, let's first import the modules we will use:
from __future__ import print_function
from ase.lattice.cubic import Diamond
from ase.md.langevin import Langevin
from ase.optimize import BFGS
Expand Down
3 changes: 1 addition & 2 deletions examples/energy_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

###############################################################################
# To begin with, let's first import the modules we will use:
from __future__ import print_function
import torch
import torchani

Expand Down Expand Up @@ -43,7 +42,7 @@

###############################################################################
# Now let's compute energy and force:
_, energy = model((species, coordinates))
energy = model((species, coordinates)).energies
derivative = torch.autograd.grad(energy.sum(), coordinates)[0]
force = -derivative

Expand Down
8 changes: 4 additions & 4 deletions examples/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@

###############################################################################
# And here is the result:
_, energies_ensemble = model((species, coordinates))
_, energies_single = model[0]((species, coordinates))
_, energies_ensemble_jit = loaded_compiled_model((species, coordinates))
_, energies_single_jit = loaded_compiled_model0((species, coordinates))
energies_ensemble = model((species, coordinates)).energies
energies_single = model[0]((species, coordinates)).energies
energies_ensemble_jit = loaded_compiled_model((species, coordinates)).energies
energies_single_jit = loaded_compiled_model0((species, coordinates)).energies
print('Ensemble energy, eager mode vs loaded jit:', energies_ensemble.item(), energies_ensemble_jit.item())
print('Single network energy, eager mode vs loaded jit:', energies_single.item(), energies_single_jit.item())
4 changes: 2 additions & 2 deletions examples/load_from_neurochem.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@

###############################################################################
# Now let's compute energies using the ensemble directly:
_, energy = nnp1((species, coordinates))
energy = nnp1((species, coordinates)).energies
derivative = torch.autograd.grad(energy.sum(), coordinates)[0]
force = -derivative
print('Energy:', energy.item())
Expand All @@ -89,7 +89,7 @@

###############################################################################
# We can do the same thing with the single model:
_, energy = nnp2((species, coordinates))
energy = nnp2((species, coordinates)).energies
derivative = torch.autograd.grad(energy.sum(), coordinates)[0]
force = -derivative
print('Energy:', energy.item())
Expand Down
4 changes: 2 additions & 2 deletions examples/nnp_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def validate():
true_energies = batch_y['energies']
predicted_energies = []
for chunk_species, chunk_coordinates in batch_x:
_, chunk_energies = model((chunk_species, chunk_coordinates))
chunk_energies = model((chunk_species, chunk_coordinates)).energies
predicted_energies.append(chunk_energies)
predicted_energies = torch.cat(predicted_energies)
total_mse += mse_sum(predicted_energies, true_energies).item()
Expand Down Expand Up @@ -343,7 +343,7 @@ def validate():

for chunk_species, chunk_coordinates in batch_x:
num_atoms.append((chunk_species >= 0).to(true_energies.dtype).sum(dim=1))
_, chunk_energies = model((chunk_species, chunk_coordinates))
chunk_energies = model((chunk_species, chunk_coordinates)).energies
predicted_energies.append(chunk_energies)

num_atoms = torch.cat(num_atoms)
Expand Down
4 changes: 2 additions & 2 deletions examples/nnp_training_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def validate():
true_energies = batch_y['energies']
predicted_energies = []
for chunk_species, chunk_coordinates in batch_x:
_, chunk_energies = model((chunk_species, chunk_coordinates))
chunk_energies = model((chunk_species, chunk_coordinates)).energies
predicted_energies.append(chunk_energies)
predicted_energies = torch.cat(predicted_energies)
total_mse += mse_sum(predicted_energies, true_energies).item()
Expand Down Expand Up @@ -299,7 +299,7 @@ def validate():
# that we could compute force from it
chunk_coordinates.requires_grad_(True)

_, chunk_energies = model((chunk_species, chunk_coordinates))
chunk_energies = model((chunk_species, chunk_coordinates)).energies

# We can use torch.autograd.grad to compute force. Remember to
# create graph so that the loss of the force can contribute to
Expand Down
2 changes: 1 addition & 1 deletion examples/vibration_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
# To do vibration analysis, we first need to generate a graph that computes
# energies from species and coordinates. The code to generate a graph of energy
# is the same as the code to compute energy:
_, energies = model((species, coordinates))
energies = model((species, coordinates)).energies

###############################################################################
# We can now use the energy graph to compute analytical Hessian matrix:
Expand Down
14 changes: 10 additions & 4 deletions torchani/aev.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import torch
from torch import Tensor
import math
from typing import Tuple, Optional
from typing import Tuple, Optional, NamedTuple
from torch.jit import Final


class SpeciesAEV(NamedTuple):
species: Tensor
aevs: Tensor


def cutoff_cosine(distances: Tensor, cutoff: float) -> Tensor:
# assuming all elements in distances are smaller than cutoff
return 0.5 * torch.cos(distances * (math.pi / cutoff)) + 0.5
Expand Down Expand Up @@ -356,7 +361,7 @@ def constants(self):
return self.Rcr, self.EtaR, self.ShfR, self.Rca, self.ShfZ, self.EtaA, self.Zeta, self.ShfA

def forward(self, input_: Tuple[Tensor, Tensor], cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
pbc: Optional[Tensor] = None) -> SpeciesAEV:
"""Compute AEVs
Arguments:
Expand Down Expand Up @@ -384,7 +389,7 @@ def forward(self, input_: Tuple[Tensor, Tensor], cell: Optional[Tensor] = None,
for that direction.
Returns:
tuple: Species and AEVs. species are the species from the input
NamedTuple: Species and AEVs. species are the species from the input
unchanged, and AEVs is a tensor of shape
``(C, A, self.aev_length())``
"""
Expand All @@ -398,4 +403,5 @@ def forward(self, input_: Tuple[Tensor, Tensor], cell: Optional[Tensor] = None,
cutoff = max(self.Rcr, self.Rca)
shifts = compute_shifts(cell, pbc, cutoff)

return species, compute_aev(species, coordinates, cell, shifts, self.triu_index, self.constants(), self.sizes)
aev = compute_aev(species, coordinates, cell, shifts, self.triu_index, self.constants(), self.sizes)
return SpeciesAEV(species, aev)
6 changes: 3 additions & 3 deletions torchani/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,11 @@ def calculate(self, atoms=None, properties=['energy'],
strain_y = self.strain(cell, displacement_y, 1)
strain_z = self.strain(cell, displacement_z, 2)
cell = cell + strain_x + strain_y + strain_z
_, aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc)
aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc).aevs
else:
_, aev = self.aev_computer((species, coordinates))
aev = self.aev_computer((species, coordinates)).aevs

_, energy = self.nn((species, aev))
energy = self.nn((species, aev)).energies
energy *= ase.units.Hartree
self.results['energy'] = energy.item()
self.results['free_energy'] = energy.item()
Expand Down
17 changes: 11 additions & 6 deletions torchani/nn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import torch
from torch import Tensor
from typing import Tuple
from typing import Tuple, NamedTuple


class SpeciesEnergies(NamedTuple):
species: Tensor
energies: Tensor


class ANIModel(torch.nn.Module):
Expand All @@ -26,7 +31,7 @@ def __init__(self, modules):
def __getitem__(self, i):
return self.module_list[i]

def forward(self, species_aev: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
def forward(self, species_aev: Tuple[Tensor, Tensor]) -> SpeciesEnergies:
species, aev = species_aev
species_ = species.flatten()
aev = aev.flatten(0, 1)
Expand All @@ -40,7 +45,7 @@ def forward(self, species_aev: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
input_ = aev.index_select(0, midx)
output.masked_scatter_(mask, m(input_).flatten())
output = output.view_as(species)
return species, torch.sum(output, dim=1)
return SpeciesEnergies(species, torch.sum(output, dim=1))


class Ensemble(torch.nn.Module):
Expand All @@ -51,12 +56,12 @@ def __init__(self, modules):
self.modules_list = torch.nn.ModuleList(modules)
self.size = len(self.modules_list)

def forward(self, species_input: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
def forward(self, species_input: Tuple[Tensor, Tensor]) -> SpeciesEnergies:
sum_ = 0
for x in self.modules_list:
sum_ += x(species_input)[1]
species, _ = species_input
return species, sum_ / self.size
return SpeciesEnergies(species, sum_ / self.size)

def __getitem__(self, i):
return self.modules_list[i]
Expand All @@ -69,7 +74,7 @@ def __init__(self, *modules):
super(Sequential, self).__init__()
self.modules_list = torch.nn.ModuleList(modules)

def forward(self, input_: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
def forward(self, input_: Tuple[Tensor, Tensor]):
for module in self.modules_list:
input_ = module(input_)
return input_
Expand Down
14 changes: 10 additions & 4 deletions torchani/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import math
import numpy as np
from collections import defaultdict
from typing import Tuple
from typing import Tuple, NamedTuple
from .nn import SpeciesEnergies


def pad(species):
Expand Down Expand Up @@ -211,12 +212,12 @@ def subtract_from_dataset(self, atomic_properties, properties):
properties['energies'] = energies
return atomic_properties, properties

def forward(self, species_energies: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
def forward(self, species_energies: Tuple[Tensor, Tensor]) -> SpeciesEnergies:
"""(species, molecular energies)->(species, molecular energies + sae)
"""
species, energies = species_energies
sae = self.sae(species).to(energies.device)
return species, energies.to(sae.dtype) + sae
return SpeciesEnergies(species, energies.to(sae.dtype) + sae)


class ChemicalSymbolsToInts:
Expand Down Expand Up @@ -269,6 +270,11 @@ def hessian(coordinates, energies=None, forces=None):
], dim=1)


class FreqsModes(NamedTuple):
freqs: Tensor
modes: Tensor


def vibrational_analysis(masses, hessian, unit='cm^-1'):
"""Computing the vibrational wavenumbers from hessian."""
if unit != 'cm^-1':
Expand All @@ -292,7 +298,7 @@ def vibrational_analysis(masses, hessian, unit='cm^-1'):
# converting from sqrt(hartree / (amu * angstrom^2)) to cm^-1
wavenumbers = frequencies * 17092
modes = (eigenvectors.t() * inv_sqrt_mass).reshape(frequencies.numel(), -1, 3)
return wavenumbers, modes
return FreqsModes(wavenumbers, modes)


__all__ = ['pad', 'pad_atomic_properties', 'present_species', 'hessian',
Expand Down

0 comments on commit 004f5a5

Please sign in to comment.