Skip to content

Commit

Permalink
improve analytical stress calculation (#387)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm authored and farhadrgh committed Nov 12, 2019
1 parent e784666 commit d400d8f
Showing 1 changed file with 6 additions and 20 deletions.
26 changes: 6 additions & 20 deletions torchani/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,6 @@ def __init__(self, species, aev_computer, model, energy_shifter, dtype=torch.flo
self.energy_shifter
).to(dtype)

@staticmethod
def strain(tensor, displacement, surface_normal_axis):
displacement_of_tensor = torch.zeros_like(tensor)
for axis in range(3):
displacement_of_tensor[..., axis] = tensor[..., surface_normal_axis] * displacement[axis]
return displacement_of_tensor

def calculate(self, atoms=None, properties=['energy'],
system_changes=ase.calculators.calculator.all_changes):
super(Calculator, self).calculate(atoms, properties, system_changes)
Expand All @@ -70,7 +63,7 @@ def calculate(self, atoms=None, properties=['energy'],
species = self.species_to_tensor(self.atoms.get_chemical_symbols()).to(self.device)
species = species.unsqueeze(0)
coordinates = torch.tensor(self.atoms.get_positions())
coordinates = coordinates.unsqueeze(0).to(self.device).to(self.dtype) \
coordinates = coordinates.to(self.device).to(self.dtype) \
.requires_grad_('forces' in properties)

if pbc_enabled:
Expand All @@ -79,20 +72,13 @@ def calculate(self, atoms=None, properties=['energy'],
atoms.set_positions(coordinates.detach().cpu().reshape(-1, 3).numpy())

if 'stress' in properties:
displacements = torch.zeros(3, 3, requires_grad=True,
dtype=self.dtype, device=self.device)
displacement_x, displacement_y, displacement_z = displacements
strain_x = self.strain(coordinates, displacement_x, 0)
strain_y = self.strain(coordinates, displacement_y, 1)
strain_z = self.strain(coordinates, displacement_z, 2)
coordinates = coordinates + strain_x + strain_y + strain_z
scaling = torch.eye(3, requires_grad=True, dtype=self.dtype, device=self.device)
coordinates = coordinates @ scaling
coordinates = coordinates.unsqueeze(0)

if pbc_enabled:
if 'stress' in properties:
strain_x = self.strain(cell, displacement_x, 0)
strain_y = self.strain(cell, displacement_y, 1)
strain_z = self.strain(cell, displacement_z, 2)
cell = cell + strain_x + strain_y + strain_z
cell = cell @ scaling
aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc).aevs
else:
aev = self.aev_computer((species, coordinates)).aevs
Expand All @@ -108,5 +94,5 @@ def calculate(self, atoms=None, properties=['energy'],

if 'stress' in properties:
volume = self.atoms.get_volume()
stress = torch.autograd.grad(energy.squeeze(), displacements)[0] / volume
stress = torch.autograd.grad(energy.squeeze(), scaling)[0] / volume
self.results['stress'] = stress.cpu().numpy()

0 comments on commit d400d8f

Please sign in to comment.