In [1]:
import numpy as np
from ase.calculators.calculator import Calculator
import warnings

class Ensemble_Calculator(Calculator):
    implemented_properties = ['energy', 'forces']

    def __init__(self, calculators: list, *args, **kwargs):
        super().__init__(*args, **kwargs)

        num_models = len(calculators)

        # Testing input:
        if num_models == 0:
            raise ValueError('Provided list of calculators is empty (length is 0)')
        
        # Testing type of list
        nun_ASE_calcs = 0
        for calc in calculators:
            if not isinstance(calc, Calculator):
                nun_ASE_calcs += 1
        if nun_ASE_calcs > 0:
            warnings.warn(
                f"{nun_ASE_calcs} out of {len(calculators)} elements of the provided calculators-list "
                "are not ASE calculators. This may result in the failure of this calculator.",
                UserWarning
            )

        self.calculators = calculators
        self.potential_energy_variance = None
        self.forces_variances = None
        self.num_models = num_models

    def calculate(self, atoms, properties, system_changes):
        super().calculate(atoms, properties, system_changes)

        # Check which properties need to be calculated
        energy = 0.0
        forces = None

        if 'energy' in properties:
            # Perform energy calculation here (replace this with your actual calculation)
            energy = self.__calculate_potential_energy(self.atoms)

        if 'forces' in properties:
            # Perform forces calculation here (replace this with your actual calculation)
            forces = self.__calculate_forces(self.atoms)

        # Store the calculated values
        self.results = {'energy': energy, 'forces': forces}

    def get_potential_energy_variance(self):
        return self.potential_energy_variance
    
    def get_potential_energy_standard_deviation(self):
        return np.sqrt(self.potential_energy_variance)
    
    def get_forces_variances(self):
        pass

    def get_forces_standard_deviations(self):
        pass

    def __calculate_potential_energy(self, atoms):

        calc_energies = []
        for calc in self.calculators:
            atoms_copy = atoms.copy()
            atoms.calc = calc
            calc_energies.append(atoms.get_potential_energy())
        

        
        return 0
    
    def __calculate_forces(self, atoms):

        calc_forces = []
        for calc in self.calculators:
            atoms_copy = atoms.copy()
            atoms.calc = calc
            calc_forces.append(atoms.get_forces())
        
        

        return 0



# from yourmodule import SimpleCalculator  # Make sure to replace 'yourmodule' with the actual module name



# Now you can access the energy and forces


In [2]:
from ase.calculators.emt import EMT
from ase import Atoms

atoms = Atoms('H2', positions=[(0, 0, 0), (0, 0, 1)])

calcs = [EMT(), EMT(), EMT()]

ensamble = Ensemble_Calculator(calcs)

atoms.calc = ensamble

atoms.get_forces()

0

In [3]:

import sys
from ase.calculators.calculator import Calculator
from ase.calculators.emt import EMT

from ase_ensemble_calculator.ensemble_calculator import Ensemble_Calculator as ES



calcs = [EMT(), EMT(), EMT()] # 3 EMT calculators which are ASE calculator objects
ensamble = ES(calcs) # Setup Ensemble_Calculator, which sould work for ASE calculators

ensamble.num_calculators

3

In [4]:
atoms = Atoms('H2', positions=[(0, 0, 0), (0, 0, 1)])
atoms.calc = EMT()

forces = []
for i in range(3):
    forces.append(atoms.get_forces() + i)

print(forces)


x = np.vstack((atoms.get_forces(), atoms.get_forces() + 1))

M = len(forces)
print(M)

mean_force = np.mean(forces, axis = 0)

print(mean_force,
       '\n----')

x = (forces - mean_force)**2
print(x,'\n----')

x = np.sum(x, axis = 0)
print(x,'\n----')

x = np.mean(x, axis = 1)
print(x,'\n----')

x = np.sqrt(x)
print(x,'\n----')


[array([[ 0.        ,  0.        ,  8.03580661],
       [ 0.        ,  0.        , -8.03580661]]), array([[ 1.        ,  1.        ,  9.03580661],
       [ 1.        ,  1.        , -7.03580661]]), array([[ 2.        ,  2.        , 10.03580661],
       [ 2.        ,  2.        , -6.03580661]])]
3
[[ 1.          1.          9.03580661]
 [ 1.          1.         -7.03580661]] 
----
[[[1. 1. 1.]
  [1. 1. 1.]]

 [[0. 0. 0.]
  [0. 0. 0.]]

 [[1. 1. 1.]
  [1. 1. 1.]]] 
----
[[2. 2. 2.]
 [2. 2. 2.]] 
----
[2. 2.] 
----
[1.41421356 1.41421356] 
----


In [5]:

forces = [
    np.array([[-1,-1,-1],[-1,-1,-1]]),
    np.array([[0,0,0],[-1,-1,-1]]),
    np.array([[1,1,1],[-1,-1,-1]])
]

forces = [
    np.array([[-3,-3,-3],[-1,-1,-1]]),
    np.array([[3,3,3],[-1,-1,-1]])
]


print(len(forces))

mean_force = np.mean(forces, axis = 0)

x = (forces - mean_force)**2
print(x,'\n----')

x = np.sum(x, axis = 0)
print(x,'\n----')

x = np.sum(x, axis = 1) / (len(forces)*3)
print(x,'\n----')

x = np.sqrt(x)
print(x,'\n----')

2
[[[9. 9. 9.]
  [0. 0. 0.]]

 [[9. 9. 9.]
  [0. 0. 0.]]] 
----
[[18. 18. 18.]
 [ 0.  0.  0.]] 
----
[9. 0.] 
----
[3. 0.] 
----


In [6]:
def test_same_EMT_calculations():

    pos = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0]], dtype=np.float64)
    pos += np.random.rand(5,3) * 0.2
    atoms = Atoms('H2C3', positions = pos)

    calcs = [EMT(), EMT(), EMT()] # 3 EMT calculators which are ASE calculator objects
    ensamble = ES(calcs) # Setup Ensemble_Calculator, which sould work for ASE calculators

    atoms.calc = ensamble

    ens_forces = atoms.get_forces()
    ens_energy = atoms.get_potential_energy()

    
    atoms.calc = EMT()

    emt_energy = atoms.get_potential_energy()
    emt_forces = atoms.get_forces()
    
    print(ensamble.get_forces_variances())

    assert np.allclose(ens_energy, emt_energy)
    assert np.allclose(ens_forces, emt_forces)
    

test_same_EMT_calculations()

8.71233187379983
8.71233187379983
[[-14.18557039 -15.80080304  -7.37813862]
 [ -1.99864314 -12.26794482   6.38568944]
 [  0.11336578  15.87888272  -0.89151758]
 [  1.00083392  11.73111375   2.84335409]
 [ 15.07001382   0.45875139  -0.95938733]]
[[-14.18557039 -15.80080304  -7.37813862]
 [ -1.99864314 -12.26794482   6.38568944]
 [  0.11336578  15.87888272  -0.89151758]
 [  1.00083392  11.73111375   2.84335409]
 [ 15.07001382   0.45875139  -0.95938733]]
