Skip to content

Commit

Permalink
Merge pull request #276 from OMalenfantThuot/md
Browse files Browse the repository at this point in the history
Md
  • Loading branch information
OMalenfantThuot committed Dec 1, 2021
2 parents 792bcb1 + f023c71 commit 5e57cbf
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 23 deletions.
3 changes: 2 additions & 1 deletion mlcalcdriver/calculators/ase_calculators/asespkcalculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@

class AseSpkCalculator(Calculator):

def __init__(self, model_dir, available_properties=None, device="cpu", **kwargs):
def __init__(self, model_dir, available_properties=None, device="cpu", md=False, **kwargs):
Calculator.__init__(self, **kwargs)
self.schnetpackcalculator = SchnetPackCalculator(
model_dir=model_dir,
available_properties=available_properties,
device=device,
md=md,
)
self.implemented_properties = self.schnetpackcalculator._get_available_properties()
if "energy" in self.implemented_properties and "forces" not in self.implemented_properties:
Expand Down
36 changes: 24 additions & 12 deletions mlcalcdriver/calculators/schnetpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,7 @@ class SchnetPackCalculator(Calculator):
"""

def __init__(
self,
model_dir,
available_properties=None,
device="cpu",
units=eVA,
self, model_dir, available_properties=None, device="cpu", units=eVA, md=False,
):
r"""
Parameters
Expand All @@ -47,6 +43,7 @@ def __init__(
a SchnetPackCalculator.
"""
self.device = device
self.md = md
try:
self.model = load_model(model_dir, map_location=self.device)
except Exception:
Expand All @@ -65,6 +62,15 @@ def device(self):
def device(self, device):
self._device = str(device).lower()

@property
def md(self):
return self._md

@md.setter
def md(self, md):
assert isinstance(md, bool)
self._md = md

def run(
self, property, posinp=None, batch_size=128,
):
Expand Down Expand Up @@ -148,14 +154,20 @@ def run(
deriv2 = -1.0 * deriv2
pred.append({out_name: deriv2})
predictions = {}
if derivative:
predictions[property] = np.concatenate(
[batch[out_name].cpu().detach().numpy() for batch in pred]
)
if self.md:
for p in ["energy", "forces"]:
predictions[p] = np.concatenate(
[batch[p].cpu().detach().numpy() for batch in pred]
)
else:
predictions[property] = np.concatenate(
[batch[init_property].cpu().detach().numpy() for batch in pred]
)
if derivative:
predictions[property] = np.concatenate(
[batch[out_name].cpu().detach().numpy() for batch in pred]
)
else:
predictions[property] = np.concatenate(
[batch[init_property].cpu().detach().numpy() for batch in pred]
)
return predictions

def _get_available_properties(self):
Expand Down
19 changes: 9 additions & 10 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
pip==21.2.4
pip==21.3
wheel==0.37.0
black==21.9b0
flake8==3.9.2
numpy==1.21.2
black==21.11b1
flake8==4.0.1
numpy==1.21.4
torch==1.9.1
torchvision==0.10.1
twine==3.4.2
twine==3.6.0
pytest==6.2.5
pytest-cov==2.12.1
pytest-cov==3.0.0
pytest-sugar==0.9.4
coveralls==3.2.0
coveralls==3.3.1
ase==3.22.0
git+https://github.com/atomistic-machine-learning/schnetpack.git
sphinx==4.2.0
sphinx-rtd-theme==0.5.2
sphinx==4.3.1
sphinx-rtd-theme==1.0.0
git+https://github.com/crossnox/m2r@dev#egg=m2r

0 comments on commit 5e57cbf

Please sign in to comment.