Skip to content

Commit

Permalink
Bring new_zeros back (#353)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm authored and farhadrgh committed Nov 5, 2019
1 parent 123e476 commit f12a27b
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions torchani/aev.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):

# Step 2: center cell
p1_center, p2_center = torch.combinations(all_atoms).unbind(-1)
shifts_center = torch.zeros((p1_center.shape[0], 3), dtype=shifts.dtype, device=shifts.device)
shifts_center = shifts.new_zeros((p1_center.shape[0], 3))

# Step 3: cells with shifts
# shape convention (shift index, molecule index, atom index, 3)
Expand Down Expand Up @@ -205,7 +205,7 @@ def convert_pair_index(index):
def cumsum_from_zero(input_):
# type: (torch.Tensor) -> torch.Tensor
cumsum = torch.cumsum(input_, dim=0)
cumsum = torch.cat([torch.tensor([0], dtype=input_.dtype, device=input_.device), cumsum[:-1]])
cumsum = torch.cat([input_.new_zeros(1), cumsum[:-1]])
return cumsum


Expand Down Expand Up @@ -275,7 +275,7 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes

# compute radial aev
radial_terms_ = radial_terms(Rcr, EtaR, ShfR, distances)
radial_aev = torch.zeros((num_molecules * num_atoms * num_species, radial_sublength), dtype=radial_terms_.dtype, device=radial_terms_.device)
radial_aev = radial_terms_.new_zeros((num_molecules * num_atoms * num_species, radial_sublength))
index1 = atom_index1 * num_species + species2
index2 = atom_index2 * num_species + species1
radial_aev.index_add_(0, index1, radial_terms_)
Expand All @@ -298,7 +298,7 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes
species1_ = torch.where(sign1 == 1, species2[pair_index1], species1[pair_index1])
species2_ = torch.where(sign2 == 1, species2[pair_index2], species1[pair_index2])
angular_terms_ = angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vec1, vec2)
angular_aev = torch.zeros((num_molecules * num_atoms * num_species_pairs, angular_sublength), dtype=angular_terms_.dtype, device=angular_terms_.device)
angular_aev = angular_terms_.new_zeros((num_molecules * num_atoms * num_species_pairs, angular_sublength))
index = central_atom_index * num_species_pairs + triu_index[species1_, species2_]
angular_aev.index_add_(0, index, angular_terms_)
angular_aev = angular_aev.reshape(num_molecules, num_atoms, angular_length)
Expand Down

0 comments on commit f12a27b

Please sign in to comment.