diff --git a/mlcalcdriver/calculators/ase_calculators/asespkcalculator.py b/mlcalcdriver/calculators/ase_calculators/asespkcalculator.py index 2d5ee1d..2344c23 100644 --- a/mlcalcdriver/calculators/ase_calculators/asespkcalculator.py +++ b/mlcalcdriver/calculators/ase_calculators/asespkcalculator.py @@ -7,12 +7,13 @@ class AseSpkCalculator(Calculator): - def __init__(self, model_dir, available_properties=None, device="cpu", **kwargs): + def __init__(self, model_dir, available_properties=None, device="cpu", md=False, **kwargs): Calculator.__init__(self, **kwargs) self.schnetpackcalculator = SchnetPackCalculator( model_dir=model_dir, available_properties=available_properties, device=device, + md=md, ) self.implemented_properties = self.schnetpackcalculator._get_available_properties() if "energy" in self.implemented_properties and "forces" not in self.implemented_properties: diff --git a/mlcalcdriver/calculators/schnetpack.py b/mlcalcdriver/calculators/schnetpack.py index b17848d..a5e3608 100644 --- a/mlcalcdriver/calculators/schnetpack.py +++ b/mlcalcdriver/calculators/schnetpack.py @@ -22,11 +22,7 @@ class SchnetPackCalculator(Calculator): """ def __init__( - self, - model_dir, - available_properties=None, - device="cpu", - units=eVA, + self, model_dir, available_properties=None, device="cpu", units=eVA, md=False, ): r""" Parameters @@ -47,6 +43,7 @@ def __init__( a SchnetPackCalculator. """ self.device = device + self.md = md try: self.model = load_model(model_dir, map_location=self.device) except Exception: @@ -65,6 +62,15 @@ def device(self): def device(self, device): self._device = str(device).lower() + @property + def md(self): + return self._md + + @md.setter + def md(self, md): + assert isinstance(md, bool) + self._md = md + def run( self, property, posinp=None, batch_size=128, ): @@ -148,14 +154,20 @@ def run( deriv2 = -1.0 * deriv2 pred.append({out_name: deriv2}) predictions = {} - if derivative: - predictions[property] = np.concatenate( - [batch[out_name].cpu().detach().numpy() for batch in pred] - ) + if self.md: + for p in ["energy", "forces"]: + predictions[p] = np.concatenate( + [batch[p].cpu().detach().numpy() for batch in pred] + ) else: - predictions[property] = np.concatenate( - [batch[init_property].cpu().detach().numpy() for batch in pred] - ) + if derivative: + predictions[property] = np.concatenate( + [batch[out_name].cpu().detach().numpy() for batch in pred] + ) + else: + predictions[property] = np.concatenate( + [batch[init_property].cpu().detach().numpy() for batch in pred] + ) return predictions def _get_available_properties(self): diff --git a/requirements_dev.txt b/requirements_dev.txt index cc9db37..a71d208 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -1,17 +1,16 @@ -pip==21.2.4 +pip==21.3 wheel==0.37.0 -black==21.9b0 -flake8==3.9.2 -numpy==1.21.2 +black==21.11b1 +flake8==4.0.1 +numpy==1.21.4 torch==1.9.1 -torchvision==0.10.1 -twine==3.4.2 +twine==3.6.0 pytest==6.2.5 -pytest-cov==2.12.1 +pytest-cov==3.0.0 pytest-sugar==0.9.4 -coveralls==3.2.0 +coveralls==3.3.1 ase==3.22.0 git+https://github.com/atomistic-machine-learning/schnetpack.git -sphinx==4.2.0 -sphinx-rtd-theme==0.5.2 +sphinx==4.3.1 +sphinx-rtd-theme==1.0.0 git+https://github.com/crossnox/m2r@dev#egg=m2r