Skip to content

Commit

Permalink
Use @ syntax (#371)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm authored and farhadrgh committed Nov 7, 2019
1 parent ffb075e commit 3af8c49
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchani/aev.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
shifts_all = torch.cat([shifts_center, shifts_outide])
p1_all = torch.cat([p1_center, p1])
p2_all = torch.cat([p2_center, p2])
shift_values = torch.mm(shifts_all.to(cell.dtype), cell)
shift_values = shifts_all.to(cell.dtype) @ cell

# step 5, compute distances, and find all pairs within cutoff
distances = (coordinates.index_select(1, p1_all) - coordinates.index_select(1, p2_all) + shift_values).norm(2, -1)
Expand Down Expand Up @@ -241,7 +241,7 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes
coordinates = coordinates.flatten(0, 1)
species1 = species[atom_index1]
species2 = species[atom_index2]
shift_values = torch.mm(shifts.to(cell.dtype), cell)
shift_values = shifts.to(cell.dtype) @ cell

vec = coordinates.index_select(0, atom_index1) - coordinates.index_select(0, atom_index2) + shift_values
distances = vec.norm(2, -1)
Expand Down

0 comments on commit 3af8c49

Please sign in to comment.