Skip to content

Commit

Permalink
Phonopy WorkChain QE support
Browse files Browse the repository at this point in the history
  • Loading branch information
atztogo committed May 31, 2021
1 parent dd5688d commit dd066db
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 58 deletions.
39 changes: 31 additions & 8 deletions aiida_phonopy/common/builders.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Utilities related to process builder or inputs dist."""

from aiida.engine import calcfunction
from aiida.plugins import DataFactory, WorkflowFactory, CalculationFactory
from aiida.common import InputValidationError
Expand Down Expand Up @@ -36,6 +38,23 @@ def get_nac_calcjob_inputs(calculator_settings, unitcell):
return _get_calcjob_inputs(calculator_settings, unitcell, 'nac')


def get_plugin_names(calculator_settings):
"""Return plugin names of calculators."""
code_strings = []
if 'sequence' in calculator_settings.keys():
for key in calculator_settings['sequence']:
code_strings.append(calculator_settings[key]['code_string'])
else:
code_strings.append(calculator_settings['code_string'])

plugin_names = []
for code_string in code_strings:
code = Code.get_from_string(code_string)
plugin_names.append(code.get_input_plugin_name())

return plugin_names


def _get_calcjob_inputs(calculator_settings, structure, calc_type=None,
label=None, ctx=None):
"""Return builder inputs of a calculation."""
Expand Down Expand Up @@ -82,7 +101,7 @@ def _get_calcjob_inputs(calculator_settings, structure, calc_type=None,
'label': label},
'qpoints': qpoints,
'parameters': _get_parameters(settings),
'parent_folder': ctx['nac_params_1'].outputs.remote_folder,
'parent_folder': ctx.nac_params_calcs[0].outputs.remote_folder,
'code': code}
builder_inputs = {'ph': ph}
else:
Expand All @@ -91,14 +110,17 @@ def _get_calcjob_inputs(calculator_settings, structure, calc_type=None,
return builder_inputs


def get_calculator_process(code_string):
def get_calculator_process(code_string=None, plugin_name=None):
"""Return WorkChain or CalcJob."""
code = Code.get_from_string(code_string)
plugin_name = code.get_input_plugin_name()
if plugin_name == 'vasp.vasp':
return WorkflowFactory(plugin_name)
elif plugin_name in ('quantumespresso.pw', 'quantumespresso.ph'):
return WorkflowFactory(plugin_name + ".base")
if plugin_name is None:
code = Code.get_from_string(code_string)
_plugin_name = code.get_input_plugin_name()
else:
_plugin_name = plugin_name
if _plugin_name == 'vasp.vasp':
return WorkflowFactory(_plugin_name)
elif _plugin_name in ('quantumespresso.pw', 'quantumespresso.ph'):
return WorkflowFactory(_plugin_name + ".base")
else:
raise RuntimeError("Code could not be found.")

Expand Down Expand Up @@ -134,6 +156,7 @@ def get_calcjob_builder(structure, code_string, builder_inputs, label=None):
def get_immigrant_builder(calculation_folder,
calculator_settings,
calc_type=None):
"""Return VASP immigrant builder."""
if calc_type:
code = Code.get_from_string(
calculator_settings[calc_type]['code_string'])
Expand Down
6 changes: 4 additions & 2 deletions aiida_phonopy/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""General utilities."""

import numpy as np
from aiida.engine import calcfunction
from aiida.plugins import DataFactory
Expand Down Expand Up @@ -164,7 +166,7 @@ def get_phonon_properties(structure,
total_dos, pdos, thermal_properties = get_mesh_property_data(ph, mesh)

# Band structure
bs = get_bands_data(ph)
bs = _get_bands_data(ph)

return {'dos': total_dos,
'pdos': pdos,
Expand Down Expand Up @@ -277,7 +279,7 @@ def get_thermal_properties(thermal_properties):
return tprops


def get_bands_data(ph):
def _get_bands_data(ph):
ph.auto_band_structure()
labels = [x.replace('$', '').replace('\\', '').replace('mathrm{', '').replace('}', '').upper()
for x in ph.band_structure.labels]
Expand Down
4 changes: 2 additions & 2 deletions aiida_phonopy/workflows/forces.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _get_forces(outputs, code_string):
return None
elif plugin_name == 'quantumespresso.pw':
if ('output_trajectory' in outputs and
'forces' in outputs.output_trajectory.get_arraynames()):
'forces' in outputs.output_trajectory.get_arraynames()): # noqa: E129 E501
forces_data = get_qe_forces(outputs.output_trajectory)
else:
return None
Expand Down Expand Up @@ -62,7 +62,7 @@ def _get_energy(outputs, code_string):
return None
elif plugin_name == 'quantumespresso.pw':
if ('output_parameters' in outputs and
'energy' in outputs.output_parameters.keys()):
'energy' in outputs.output_parameters.keys()): # noqa: E129
energy_data = get_qe_energy(outputs.output_parameters)
return energy_data

Expand Down
121 changes: 75 additions & 46 deletions aiida_phonopy/workflows/nac_params.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Workflow to calculate NAC params."""

from aiida.engine import WorkChain, calcfunction, if_, while_
from aiida.engine import WorkChain, calcfunction, while_, append_
from aiida.plugins import DataFactory
from aiida_phonopy.common.builders import (
get_calcjob_builder, get_calcjob_inputs, get_calculator_process)
get_calcjob_inputs, get_calculator_process, get_plugin_names)
from aiida_phonopy.common.utils import phonopy_atoms_from_structure
from phonopy.structure.symmetry import symmetrize_borns_and_epsilon

Expand All @@ -13,9 +13,7 @@
StructureData = DataFactory('structure')


@calcfunction
def get_nac_params(born_charges, epsilon, nac_structure, symmetry_tolerance,
primitive=None):
def _get_nac_params(ctx, symmetry_tolerance):
"""Obtain Born effective charges and dielectric constants in primitive cell.
When Born effective charges and dielectric constants are calculated within
Expand All @@ -25,27 +23,74 @@ def get_nac_params(born_charges, epsilon, nac_structure, symmetry_tolerance,
needs information of the structure where those values were calcualted and
the target primitive cell structure.
Two kargs parameters
primitive : StructureData
symmetry_tolerance : Float
"""
borns = born_charges.get_array('born_charges')
eps = epsilon.get_array('epsilon')
if ctx.plugin_names[0] == 'vasp.vasp':
calc = ctx.nac_params_calcs[0]
nac_params = get_vasp_nac_params(calc.outputs.born_charges,
calc.outputs.dielectrics,
calc.inputs.structure,
symmetry_tolerance)
elif ctx.plugin_names[0] == 'quantumespresso.pw':
pw_calc = ctx.nac_params_calcs[0]
ph_calc = ctx.nac_params_calcs[1]
nac_params = get_qe_nac_params(ph_calc.outputs.output_parameters,
pw_calc.inputs.pw.structure,
symmetry_tolerance)
else:
nac_params = None
return nac_params


@calcfunction
def get_qe_nac_params(output_parameters,
structure,
symmetry_tolerance,
primitive=None):
"""Return NAC params ArrayData created from QE results."""
nac_params = _get_nac_params_array(
output_parameters['effective_charges_eu'],
output_parameters['dielectric_constant'],
structure,
symmetry_tolerance.value,
primitive=primitive)
return nac_params


@calcfunction
def get_vasp_nac_params(born_charges,
epsilon,
structure,
symmetry_tolerance,
primitive=None):
"""Return NAC params ArrayData created from VASP results."""
nac_params = _get_nac_params_array(born_charges.get_array('born_charges'),
epsilon.get_array('epsilon'),
structure,
symmetry_tolerance.value,
primitive=primitive)
return nac_params


nac_cell = phonopy_atoms_from_structure(nac_structure)
kwargs = {}
kwargs['symprec'] = symmetry_tolerance.value
def _get_nac_params_array(born_charges,
epsilon,
structure,
symmetry_tolerance,
primitive=None):
phonopy_cell = phonopy_atoms_from_structure(structure)
if primitive is not None:
kwargs['primitive'] = phonopy_atoms_from_structure(primitive)
phonopy_primitive = phonopy_atoms_from_structure(primitive)
else:
phonopy_primitive = None
borns_, epsilon_ = symmetrize_borns_and_epsilon(
borns, eps, nac_cell, **kwargs)

born_charges,
epsilon,
phonopy_cell,
symprec=symmetry_tolerance,
primitive=phonopy_primitive)
nac_params = ArrayData()
nac_params.set_array('born_charges', borns_)
nac_params.set_array('epsilon', epsilon_)
nac_params.label = 'born_charges & epsilon'

return nac_params


Expand All @@ -72,13 +117,8 @@ def define(cls, spec):
spec.output('nac_params', valid_type=ArrayData, required=True)

spec.exit_code(
1001, 'ERROR_NO_BORN_EFFECTIVE_CHARGES',
message=('Born effecti charges could not be retrieved '
'from calculaton.'))
spec.exit_code(
1002, 'ERROR_NO_DIELECTRIC_CONSTANT',
message=('dielectric constant could not be retrieved '
'from calculaton.'))
1001, 'ERROR_NO_NAC_PARAMS',
message=('NAC params could not be retrieved from calculaton.'))

def continue_calculation(self):
"""Return boolen for outline."""
Expand All @@ -97,6 +137,9 @@ def initialize(self):
else:
self.ctx.max_iteration = 1

self.ctx.plugin_names = get_plugin_names(
self.inputs.calculator_settings)

def run_calculation(self):
"""Run NAC params calculation."""
self.report('calculation iteration %d/%d'
Expand All @@ -106,33 +149,19 @@ def run_calculation(self):
self.inputs.structure,
ctx=self.ctx,
label=label)
if 'sequence' in self.inputs.calculator_settings.keys():
i = self.ctx.iteration - 1
key = self.inputs.calculator_settings['sequence'][i]
code_string = self.inputs.calculator_settings[key]['code_string']
else:
code_string = self.inputs.calculator_settings['code_string']
CalculatorProcess = get_calculator_process(code_string)
i = self.ctx.iteration - 1
CalculatorProcess = get_calculator_process(
plugin_name=self.ctx.plugin_names[i])
future = self.submit(CalculatorProcess, **process_inputs)
self.report('nac_params: {}'.format(future.pk))
self.to_context(**{label: future})
self.to_context(nac_params_calcs=append_(future))

def finalize(self):
"""Finalize NAC params calculation."""
self.report('finalization')

calc = self.ctx['nac_params_1']
calc_dict = calc.outputs
structure = calc.inputs.structure

if 'born_charges' not in calc_dict:
return self.exit_codes.ERROR_NO_BORN_EFFECTIVE_CHARGES

if 'dielectrics' not in calc_dict:
return self.exit_codes.ERROR_NO_DIELECTRIC_CONSTANT
nac_params = _get_nac_params(self.ctx, self.inputs.symmetry_tolerance)
if nac_params is None:
return self.exit_codes.ERROR_NO_NAC_PARAMS

nac_params = get_nac_params(calc_dict['born_charges'],
calc_dict['dielectrics'],
structure,
self.inputs.symmetry_tolerance)
self.out('nac_params', nac_params)

0 comments on commit dd066db

Please sign in to comment.