Skip to content

Commit

Permalink
Minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
atztogo committed Aug 7, 2020
1 parent 37077b0 commit 8898b76
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 43 deletions.
49 changes: 23 additions & 26 deletions aiida_phonopy/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,32 +63,6 @@ def generate_phono3py_cells(phonon_settings,
return return_vals


@calcfunction
def check_imported_supercell_structure(supercell_ref,
supercell_calc,
symmetry_tolerance):
symprec = symmetry_tolerance.value
cell_diff = np.subtract(supercell_ref.cell, supercell_calc.cell)
if (np.abs(cell_diff) > symprec).any():
succeeded = Bool(False)
succeeded.label = "False"
return succeeded

positions_ref = [site.position for site in supercell_ref.sites]
positions_calc = [site.position for site in supercell_calc.sites]
diff = np.subtract(positions_ref, positions_calc)
diff -= np.rint(diff)
dist = np.sqrt(np.sum(np.dot(diff, supercell_ref.cell) ** 2, axis=1))
if (dist > symprec).any():
succeeded = Bool(False)
succeeded.label = "False"
return succeeded

succeeded = Bool(True)
succeeded.label = "True"
return succeeded


@calcfunction
def get_vasp_force_sets_dict(**forces_dict):
forces = []
Expand Down Expand Up @@ -239,6 +213,29 @@ def get_data_from_node_id(node_id):
raise RuntimeError("Forces or NAC params were not found.")


def compare_structures(cell_ref, cell_calc, symmetry_tolerance):
symprec = symmetry_tolerance.value
cell_diff = np.subtract(cell_ref.cell, cell_calc.cell)
if (np.abs(cell_diff) > symprec).any():
succeeded = Bool(False)
succeeded.label = "False"
return succeeded

positions_ref = [site.position for site in cell_ref.sites]
positions_calc = [site.position for site in cell_calc.sites]
diff = np.subtract(positions_ref, positions_calc)
diff -= np.rint(diff)
dist = np.sqrt(np.sum(np.dot(diff, cell_ref.cell) ** 2, axis=1))
if (dist > symprec).any():
succeeded = Bool(False)
succeeded.label = "False"
return succeeded

succeeded = Bool(True)
succeeded.label = "True"
return succeeded


def get_mesh_property_data(ph, mesh):
ph.set_mesh(mesh)
ph.run_total_dos()
Expand Down
19 changes: 9 additions & 10 deletions aiida_phonopy/workflows/phono3py.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from aiida.engine import WorkChain
from aiida.plugins import WorkflowFactory, DataFactory
from aiida.orm import Float, Bool, Str, Code
from aiida.orm import Float, Bool
from aiida.engine import if_
from aiida_phonopy.common.builders import (
get_calcjob_builder, get_immigrant_builder)
from aiida_phonopy.common.utils import (
generate_phono3py_cells, get_nac_params,
get_vasp_force_sets_dict, collect_vasp_forces_and_energies,
check_imported_supercell_structure)
compare_structures)


PhonopyWorkChain = WorkflowFactory('phonopy.phonopy')
Expand Down Expand Up @@ -37,7 +37,7 @@ def define(cls, spec):
cls.initialize,
if_(cls.import_calculations_from_files)(
cls.read_force_and_nac_calculations_from_files,
cls.check_imported_supercell_structures,
cls.check_imported_structures,
).else_(
cls.run_force_and_nac_calculations,
),
Expand Down Expand Up @@ -244,7 +244,7 @@ def read_force_and_nac_calculations_from_files(self):
self.report('{} pk = {}'.format(label, future.pk))
self.to_context(**{label: future})

def check_imported_supercell_structures(self):
def check_imported_structures(self):
self.report('check imported supercell structures')

msg = ("Immigrant failed because of inconsistency of supercell"
Expand All @@ -259,7 +259,7 @@ def check_imported_supercell_structures(self):
calc_dict = calc.inputs
supercell_ref = self.ctx.supercells["supercell_%s" % num]
supercell_calc = calc_dict['structure']
if not check_imported_supercell_structure(
if not compare_structures(
supercell_ref,
supercell_calc,
self.inputs.symmetry_tolerance):
Expand All @@ -275,10 +275,9 @@ def check_imported_supercell_structures(self):
supercell_ref = self.ctx.phonon_supercells[
"phonon_supercell_%s" % num]
supercell_calc = calc_dict['structure']
if not check_imported_supercell_structure(
supercell_ref,
supercell_calc,
self.inputs.symmetry_tolerance):
if not compare_structures(supercell_ref,
supercell_calc,
self.inputs.symmetry_tolerance):
raise RuntimeError(msg)

def create_force_sets(self):
Expand Down Expand Up @@ -319,7 +318,7 @@ def create_nac_params(self):
"in the calculation. Please check the calculation setting.")

kwargs = {}
if self.import_calculations():
if self.import_calculations_from_files():
kwargs['primitive'] = self.ctx.primitive
self.ctx.nac_params = get_nac_params(
calc_dict['born_charges'],
Expand Down
13 changes: 6 additions & 7 deletions aiida_phonopy/workflows/phonopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
get_calcjob_builder, get_immigrant_builder)
from aiida_phonopy.common.utils import (
get_force_constants, get_nac_params, get_phonon,
generate_phonopy_cells, check_imported_supercell_structure,
generate_phonopy_cells, compare_structures,
from_node_id_to_aiida_node_id, get_data_from_node_id,
get_vasp_force_sets_dict, collect_vasp_forces_and_energies)

Expand Down Expand Up @@ -108,7 +108,7 @@ def define(cls, spec):
if_(cls.import_calculations_from_nodes)(
cls.read_calculation_data_from_nodes,
),
cls.check_imported_supercell_structures,
cls.check_imported_structures,
).else_(
cls.run_force_and_nac_calculations,
),
Expand Down Expand Up @@ -274,7 +274,7 @@ def read_calculation_data_from_nodes(self):
# self.ctx[label]['dielectrics'] -> ArrayData()('epsilon')
self.ctx[label] = get_data_from_node_id(aiida_node_id)

def check_imported_supercell_structures(self):
def check_imported_structures(self):
self.report('check imported supercell structures')

msg = ("Immigrant failed because of inconsistency of supercell"
Expand All @@ -289,10 +289,9 @@ def check_imported_supercell_structures(self):
calc_dict = calc.inputs
supercell_ref = self.ctx.supercells["supercell_%s" % num]
supercell_calc = calc_dict['structure']
if not check_imported_supercell_structure(
supercell_ref,
supercell_calc,
self.inputs.symmetry_tolerance):
if not compare_structures(supercell_ref,
supercell_calc,
self.inputs.symmetry_tolerance):
raise RuntimeError(msg)

def postprocess_of_dry_run(self):
Expand Down

0 comments on commit 8898b76

Please sign in to comment.