Skip to content

Commit

Permalink
Accelerate angular AEV computation and reduce memory cost (#290)
Browse files Browse the repository at this point in the history
* Accerate angular AEV computation and reduce memory cost

* reduce number of elementwise product
  • Loading branch information
zasdfgbnm authored and farhadrgh committed Aug 13, 2019
1 parent 560d37a commit 920666f
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions torchani/aev.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,8 @@
# @torch.jit.script
def cutoff_cosine(distances, cutoff):
# type: (Tensor, float) -> Tensor
return torch.where(
distances <= cutoff,
0.5 * torch.cos(math.pi * distances / cutoff) + 0.5,
torch.zeros_like(distances)
)
# assuming all elements in distances are smaller than cutoff
return 0.5 * torch.cos(distances * (math.pi / cutoff)) + 0.5


# @torch.jit.script
Expand Down Expand Up @@ -270,9 +267,8 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes
num_molecules = species.shape[0]
num_atoms = species.shape[1]
num_species_pairs = angular_length // angular_sublength
cutoff = max(Rcr, Rca)

atom_index1, atom_index2, shifts = neighbor_pairs(species == -1, coordinates, cell, shifts, cutoff)
atom_index1, atom_index2, shifts = neighbor_pairs(species == -1, coordinates, cell, shifts, Rcr)
species = species.flatten()
coordinates = coordinates.flatten(0, 1)
species1 = species[atom_index1]
Expand All @@ -291,6 +287,15 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes
radial_aev.index_add_(0, index2, radial_terms_)
radial_aev = radial_aev.reshape(num_molecules, num_atoms, radial_length)

# Rca is usually much smaller than Rcr, using neighbor list with cutoff=Rcr is a waste of resources
# Now we will get a smaller neighbor list that only cares about atoms with distances <= Rca
even_closer_indices = (distances <= Rca).nonzero().flatten()
atom_index1 = atom_index1.index_select(0, even_closer_indices)
atom_index2 = atom_index2.index_select(0, even_closer_indices)
species1 = species1.index_select(0, even_closer_indices)
species2 = species2.index_select(0, even_closer_indices)
vec = vec.index_select(0, even_closer_indices)

# compute angular aev
central_atom_index, pair_index1, pair_index2, sign1, sign2 = triple_by_molecule(atom_index1, atom_index2)
vec1 = vec.index_select(0, pair_index1) * sign1.unsqueeze(1).to(vec.dtype)
Expand Down Expand Up @@ -338,6 +343,7 @@ def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species):
super(AEVComputer, self).__init__()
self.Rcr = Rcr
self.Rca = Rca
assert Rca <= Rcr, "Current implementation of AEVComputer assumes Rca <= Rcr"
# convert constant tensors to a ready-to-broadcast shape
# shape convension (..., EtaR, ShfR)
self.register_buffer('EtaR', EtaR.view(-1, 1))
Expand Down

0 comments on commit 920666f

Please sign in to comment.