Skip to content

Commit

Permalink
Fix CUDA support in ANIModel and EnergyShifter (#341)
Browse files Browse the repository at this point in the history
* fix device type

* Update utils.py

* Update utils.py

fix typo!
  • Loading branch information
farhadrgh committed Oct 17, 2019
1 parent 6ee3673 commit 9036b44
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion torchani/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def forward(self, species_aev):
aev = aev.flatten(0, 1)

output = torch.full(species_.shape, self.padding_fill,
dtype=aev.dtype)
dtype=aev.dtype, device=species.device)
i = 0
for m in self.module_list:
mask = (species_ == i)
Expand Down
2 changes: 1 addition & 1 deletion torchani/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def sae(self, species):
intercept = self.self_energies[-1]

self_energies = self.self_energies[species]
self_energies[species == torch.tensor(-1)] = torch.tensor(0)
self_energies[species == torch.tensor(-1, device=species.device)] = torch.tensor(0, device=species.device)
return self_energies.sum(dim=1) + intercept

def subtract_from_dataset(self, atomic_properties, properties):
Expand Down

0 comments on commit 9036b44

Please sign in to comment.