Skip to content

Commit

Permalink
Merge pull request #270 from OMalenfantThuot/ensemble
Browse files Browse the repository at this point in the history
Ensemble
  • Loading branch information
OMalenfantThuot committed Nov 15, 2021
2 parents 666f692 + 3247b86 commit 792bcb1
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 3 deletions.
4 changes: 2 additions & 2 deletions mlcalcdriver/base/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,10 @@ def run(self, property, batch_size=128, finite_difference=False):

for pred in predictions.keys():
# Future proofing, will probably need some work
if pred in ["energy", "gap"]:
if pred in ["energy", "energy_std", "gap"]:
if self.calculator.units["energy"] == "hartree":
predictions[pred] *= HA_TO_EV
elif pred == "forces":
elif pred in ["forces", "forces_std"]:
if self.calculator.units["energy"] == "hartree":
predictions[pred] *= HA_TO_EV
if self.calculator.units["positions"] == "atomic":
Expand Down
3 changes: 2 additions & 1 deletion mlcalcdriver/calculators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
`_get_available_properties()` methods must be defined, similarly
to the :class:`SchnetPackCalculator` class.
"""
from .calculator import *
from .calculator import Calculator
from .schnetpack import SchnetPackCalculator
from .ensemble import Ensemble, EnsembleCalculator, AseEnsembleCalculator
110 changes: 110 additions & 0 deletions mlcalcdriver/calculators/ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import torch
import numpy as np
import mlcalcdriver.base as base
from mlcalcdriver.globals import eVA
import mlcalcdriver.calculators as mlc
from ase.calculators.calculator import Calculator, all_changes


class Ensemble:
def __init__(self, modelpaths, device="cpu", units=eVA):
self.modelpaths = modelpaths
self.models = self._load_models(device, units)

@property
def modelpaths(self):
return self._modelpaths

@modelpaths.setter
def modelpaths(self, modelpaths):
if not isinstance(modelpaths, (list, tuple, set)):
raise TypeError("The modelpaths should be given in a list, tuple, or set.")
self._modelpaths = modelpaths

@property
def models(self):
return self._models

@models.setter
def models(self, models):
self._models = models

def _load_models(self, device, units):
models = []
for path in self.modelpaths:
try:
models.append(
mlc.SchnetPackCalculator(path, device=device, units=units)
)
except Exception:
raise Exception
return models

def run(self, property, posinp=None):
results = []
for i, model in enumerate(self.models):
job = base.Job(posinp=posinp, calculator=model)
job.run(property, batch_size=1)
results.append(job.results[property][np.newaxis, ...])

result = np.mean(np.concatenate(results, axis=0), axis=0)
result_std = np.std(np.concatenate(results, axis=0), axis=0)
return {property: result, property + "_std": result_std}


class EnsembleCalculator(mlc.Calculator):
def __init__(self, modelpaths, device="cpu", available_properties=None, units=eVA):
self.ensemble = Ensemble(modelpaths, device=device, units=units)
super(EnsembleCalculator, self).__init__(
available_properties=available_properties, units=units
)

@property
def ensemble(self):
return self._ensemble

@ensemble.setter
def ensemble(self, ensemble):
self._ensemble = ensemble

def run(self, property, posinp=None, batch_size=None):
return self.ensemble.run(property, posinp=posinp)

def _get_available_properties(self):
all_props = [model.available_properties for model in self.ensemble.models]
avail_prop = []
for prop in all_props[0]:
if all(prop in el for el in all_props):
avail_prop.append(prop)
return avail_prop


class AseEnsembleCalculator(Calculator):
def __init__(self, modelpaths, available_properties=None, device="cpu", **kwargs):
Calculator.__init__(self, **kwargs)
self.ensemblecalc = EnsembleCalculator(
modelpaths=modelpaths,
device=device,
available_properties=available_properties,
)
self.implemented_properties = (
self.ensemblecalc._get_available_properties()
)
if (
"energy" in self.implemented_properties
and "forces" not in self.implemented_properties
):
self.implemented_properties.append("forces")

def calculate(self, atoms=None, properties=["energy"], system_changes=all_changes):
if self.calculation_required(atoms, properties):
Calculator.calculate(self, atoms)
posinp = base.Posinp.from_ase(atoms)

job = base.Job(posinp=posinp, calculator=self.ensemblecalc)
for prop in properties:
job.run(prop)
results = {}
for prop, result in zip(job.results.keys(), job.results.values()):
results[prop] = np.squeeze(result)
self.results = results

0 comments on commit 792bcb1

Please sign in to comment.