Skip to content

Commit

Permalink
Simplify triple_by_molecule (#368)
Browse files Browse the repository at this point in the history
* Simplify triple_by_molecule

* fix

* fix

* fix
  • Loading branch information
zasdfgbnm authored and farhadrgh committed Nov 7, 2019
1 parent 89ff3b4 commit ffb075e
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 131 deletions.
1 change: 0 additions & 1 deletion tools/training-benchmark-nsys-profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def enable_timers(model):
torchani.aev.compute_shifts = time_func('compute_shifts', torchani.aev.compute_shifts)
torchani.aev.neighbor_pairs = time_func('neighbor_pairs', torchani.aev.neighbor_pairs)
torchani.aev.triu_index = time_func('triu_index', torchani.aev.triu_index)
torchani.aev.convert_pair_index = time_func('convert_pair_index', torchani.aev.convert_pair_index)
torchani.aev.cumsum_from_zero = time_func('cumsum_from_zero', torchani.aev.cumsum_from_zero)
torchani.aev.triple_by_molecule = time_func('triple_by_molecule', torchani.aev.triple_by_molecule)
torchani.aev.compute_aev = time_func('compute_aev', torchani.aev.compute_aev)
Expand Down
95 changes: 0 additions & 95 deletions tools/training-benchmark-with-aevcache.py

This file was deleted.

1 change: 0 additions & 1 deletion tools/training-benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def hartree2kcal(x):
torchani.aev.compute_shifts = time_func('torchani.aev.compute_shifts', torchani.aev.compute_shifts)
torchani.aev.neighbor_pairs = time_func('torchani.aev.neighbor_pairs', torchani.aev.neighbor_pairs)
torchani.aev.triu_index = time_func('torchani.aev.triu_index', torchani.aev.triu_index)
torchani.aev.convert_pair_index = time_func('torchani.aev.convert_pair_index', torchani.aev.convert_pair_index)
torchani.aev.cumsum_from_zero = time_func('torchani.aev.cumsum_from_zero', torchani.aev.cumsum_from_zero)
torchani.aev.triple_by_molecule = time_func('torchani.aev.triple_by_molecule', torchani.aev.triple_by_molecule)
torchani.aev.compute_aev = time_func('torchani.aev.compute_aev', torchani.aev.compute_aev)
Expand Down
44 changes: 10 additions & 34 deletions torchani/aev.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,31 +174,6 @@ def triu_index(num_species):
return ret


def convert_pair_index(index):
# type: (torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
"""Let's say we have a pair:
index: 0 1 2 3 4 5 6 7 8 9 ...
elem1: 0 0 1 0 1 2 0 1 2 3 ...
elem2: 1 2 2 3 3 3 4 4 4 4 ...
This function convert index back to elem1 and elem2
To implement this, divide it into groups, the first group contains 1
elements, the second contains 2 elements, ..., the nth group contains
n elements.
Let's say we want to compute the elem1 and elem2 for index i. We first find
the number of complete groups contained in index 0, 1, ..., i - 1
(all inclusive, not including i), then i will be in the next group. Let's
say there are N complete groups, then these N groups contains
N * (N + 1) / 2 elements, solving for the largest N that satisfies
N * (N + 1) / 2 <= i, will get the N we want.
"""
n = (torch.sqrt(1.0 + 8.0 * index.to(torch.float)) - 1.0) / 2.0
n = torch.floor(n).to(torch.long)
num_elems = n * (n + 1) / 2
return index - num_elems, n + 1


def cumsum_from_zero(input_):
# type: (torch.Tensor) -> torch.Tensor
cumsum = torch.cumsum(input_, dim=0)
Expand All @@ -219,7 +194,6 @@ def triple_by_molecule(atom_index1, atom_index2):
are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)
"""
# convert representation from pair to central-others
n = atom_index1.shape[0]
ai1 = torch.cat([atom_index1, atom_index2])
sorted_ai1, rev_indices = ai1.sort()

Expand All @@ -228,17 +202,18 @@ def triple_by_molecule(atom_index1, atom_index2):
uniqued_central_atom_index = unique_results[0]
counts = unique_results[-1]

# do local combinations within unique key, assuming sorted
# compute central_atom_index
pair_sizes = (counts * (counts - 1) / 2).long()
total_size = pair_sizes.sum()
pair_indices = torch.repeat_interleave(pair_sizes)
central_atom_index = uniqued_central_atom_index.index_select(0, pair_indices)
cumsum = cumsum_from_zero(pair_sizes)
cumsum = cumsum.index_select(0, pair_indices)
sorted_local_pair_index = torch.arange(total_size, device=cumsum.device, dtype=torch.long) - cumsum
sorted_local_index1, sorted_local_index2 = convert_pair_index(sorted_local_pair_index)
cumsum = cumsum_from_zero(counts)
cumsum = cumsum.index_select(0, pair_indices)

# do local combinations within unique key, assuming sorted
m = counts.max().item() if counts.numel() > 0 else 0
n = pair_sizes.shape[0]
intra_pair_indices = torch.tril_indices(m, m, -1, device=ai1.device).t().unsqueeze(0).expand(n, -1, -1)
mask = (torch.arange(intra_pair_indices.shape[1], device=ai1.device) < pair_sizes.unsqueeze(1)).flatten()
sorted_local_index1, sorted_local_index2 = intra_pair_indices.flatten(0, 1)[mask, :].unbind(-1)
cumsum = cumsum_from_zero(counts).index_select(0, pair_indices)
sorted_local_index1 += cumsum
sorted_local_index2 += cumsum

Expand All @@ -247,6 +222,7 @@ def triple_by_molecule(atom_index1, atom_index2):
local_index2 = rev_indices[sorted_local_index2]

# compute mapping between representation of central-other to pair
n = atom_index1.shape[0]
sign1 = ((local_index1 < n).to(torch.long) * 2) - 1
sign2 = ((local_index2 < n).to(torch.long) * 2) - 1
return central_atom_index, local_index1 % n, local_index2 % n, sign1, sign2
Expand Down

0 comments on commit ffb075e

Please sign in to comment.