Skip to content

Commit

Permalink
Fix dtype in self_energies (#347)
Browse files Browse the repository at this point in the history
* Update utils.py

* Update utils.py

* Update utils.py

* Update utils.py

* remove float32 dtypes from comp6.py
  • Loading branch information
farhadrgh authored and zasdfgbnm committed Oct 24, 2019
1 parent e7af3cb commit 41aa0f4
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 9 deletions.
12 changes: 4 additions & 8 deletions tools/comp6.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@


HARTREE2KCAL = 627.509
dtype = torch.float32

# parse command line arguments
parser = argparse.ArgumentParser()
Expand All @@ -21,7 +20,7 @@
parser = parser.parse_args()

# run benchmark
ani1x = torchani.models.ANI1x().to(dtype).to(parser.device)
ani1x = torchani.models.ANI1x().to(parser.device)


def recursive_h5_files(base):
Expand Down Expand Up @@ -80,14 +79,11 @@ def do_benchmark(model):
rmse_averager_force = Averager()
for i in tqdm.tqdm(dataset, position=0, desc="dataset"):
# read
coordinates = torch.tensor(
i['coordinates'], dtype=dtype, device=parser.device)
coordinates = torch.tensor(i['coordinates'], device=parser.device)
species = model.species_to_tensor(i['species']) \
.unsqueeze(0).expand(coordinates.shape[0], -1)
energies = torch.tensor(i['energies'], dtype=dtype,
device=parser.device)
forces = torch.tensor(i['forces'], dtype=dtype,
device=parser.device)
energies = torch.tensor(i['energies'], device=parser.device)
forces = torch.tensor(i['forces'], device=parser.device)
# compute
energies2, forces2 = by_batch(species, coordinates, model)
ediff = energies - energies2
Expand Down
2 changes: 1 addition & 1 deletion torchani/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def sae(self, species):
intercept = self.self_energies[-1]

self_energies = self.self_energies[species]
self_energies[species == torch.tensor(-1, device=species.device)] = torch.tensor(0, device=species.device)
self_energies[species == torch.tensor(-1, device=species.device)] = torch.tensor(0, device=species.device, dtype=torch.double)
return self_energies.sum(dim=1) + intercept

def subtract_from_dataset(self, atomic_properties, properties):
Expand Down

0 comments on commit 41aa0f4

Please sign in to comment.