Skip to content

Commit

Permalink
Implementing NacParamsWorkChain QE support
Browse files Browse the repository at this point in the history
  • Loading branch information
atztogo committed May 30, 2021
1 parent b48aaaa commit dd5688d
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 38 deletions.
35 changes: 26 additions & 9 deletions aiida_phonopy/common/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
PotcarData = DataFactory('vasp.potcar')


def get_calcjob_inputs(calculator_settings, structure, label=None):
def get_calcjob_inputs(calculator_settings, structure,
calc_type=None, label=None, ctx=None):
"""Return builder inputs of a calculation."""
return _get_calcjob_inputs(calculator_settings, structure, label=label)
return _get_calcjob_inputs(calculator_settings, structure,
calc_type=calc_type, label=label, ctx=ctx)


@calcfunction
Expand All @@ -35,12 +37,17 @@ def get_nac_calcjob_inputs(calculator_settings, unitcell):


def _get_calcjob_inputs(calculator_settings, structure, calc_type=None,
label=None):
label=None, ctx=None):
"""Return builder inputs of a calculation."""
if calc_type is None:
settings = calculator_settings
if 'sequence' in calculator_settings.keys():
key = calculator_settings['sequence'][ctx.iteration - 1]
settings = calculator_settings[key]
else:
settings = calculator_settings
else:
settings = Dict(dict=calculator_settings[calc_type])

code = Code.get_from_string(settings['code_string'])
plugin_name = code.get_input_plugin_name()
if plugin_name == 'vasp.vasp':
Expand Down Expand Up @@ -68,6 +75,16 @@ def _get_calcjob_inputs(calculator_settings, structure, calc_type=None,
'code': code}
builder_inputs = {'kpoints': _get_kpoints(settings, structure),
'pw': pw}
elif plugin_name == 'quantumespresso.ph':
qpoints = KpointsData()
qpoints.set_kpoints_mesh([1, 1, 1], offset=[0, 0, 0])
ph = {'metadata': {'options': _get_options(settings),
'label': label},
'qpoints': qpoints,
'parameters': _get_parameters(settings),
'parent_folder': ctx['nac_params_1'].outputs.remote_folder,
'code': code}
builder_inputs = {'ph': ph}
else:
raise RuntimeError("Code could not be found.")

Expand All @@ -80,7 +97,7 @@ def get_calculator_process(code_string):
plugin_name = code.get_input_plugin_name()
if plugin_name == 'vasp.vasp':
return WorkflowFactory(plugin_name)
elif plugin_name == 'quantumespresso.pw':
elif plugin_name in ('quantumespresso.pw', 'quantumespresso.ph'):
return WorkflowFactory(plugin_name + ".base")
else:
raise RuntimeError("Code could not be found.")
Expand Down Expand Up @@ -162,7 +179,7 @@ def _get_parameters(settings):
@calcfunction
def get_vasp_settings(settings):
"""Update VASP settings."""
if 'parser_settings' in settings.dict:
if 'parser_settings' in settings.keys():
parser_settings_dict = settings['parser_settings']
else:
parser_settings_dict = {}
Expand All @@ -174,10 +191,10 @@ def get_vasp_settings(settings):
def _get_kpoints(settings, structure):
kpoints = KpointsData()
kpoints.set_cell_from_structure(structure)
if 'kpoints_density' in settings.dict:
if 'kpoints_density' in settings.keys():
kpoints.set_kpoints_mesh_from_density(settings['kpoints_density'])
elif 'kpoints_mesh' in settings.dict:
if 'kpoints_offset' in settings.dict:
elif 'kpoints_mesh' in settings.keys():
if 'kpoints_offset' in settings.keys():
kpoints_offset = settings['kpoints_offset']
else:
kpoints_offset = [0.0, 0.0, 0.0]
Expand Down
4 changes: 2 additions & 2 deletions aiida_phonopy/workflows/forces.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _get_energy(outputs, code_string):
return None
elif plugin_name == 'quantumespresso.pw':
if ('output_parameters' in outputs and
'energy' in outputs.output_parameters.dict):
'energy' in outputs.output_parameters.keys()):
energy_data = get_qe_energy(outputs.output_parameters)
return energy_data

Expand All @@ -83,7 +83,7 @@ def get_qe_energy(output_parameters):
"""Return VASP energy ArrayData."""
energy_data = ArrayData()
energy_data.set_array('energy', np.array(
[output_parameters.dict.energy, ], dtype=float))
[output_parameters['energy'], ], dtype=float))
energy_data.label = 'energy'
return energy_data

Expand Down
66 changes: 42 additions & 24 deletions aiida_phonopy/workflows/nac_params.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Workflow to calculate NAC params."""

from aiida.engine import WorkChain, calcfunction
from aiida.engine import WorkChain, calcfunction, if_, while_
from aiida.plugins import DataFactory
from aiida_phonopy.common.builders import (
get_calcjob_builder, get_calcjob_inputs, get_calculator_process)
Expand Down Expand Up @@ -62,8 +62,11 @@ def define(cls, spec):
default=lambda: Float(1e-5))

spec.outline(
cls.run_calculation,
cls.finalize
cls.initialize,
while_(cls.continue_calculation)(
cls.run_calculation,
),
cls.finalize,
)

spec.output('nac_params', valid_type=ArrayData, required=True)
Expand All @@ -77,35 +80,50 @@ def define(cls, spec):
message=('dielectric constant could not be retrieved '
'from calculaton.'))

def continue_calculation(self):
"""Return boolen for outline."""
if self.ctx.iteration >= self.ctx.max_iteration:
return False
self.ctx.iteration += 1
return True

def initialize(self):
"""Initialize outline control parameters."""
self.report('initialization')
self.ctx.iteration = 0
if 'sequence' in self.inputs.calculator_settings.keys():
self.ctx.max_iteration = len(
self.inputs.calculator_settings['sequence'])
else:
self.ctx.max_iteration = 1

def run_calculation(self):
"""Born charges and dielectric constant calculation."""
self.report('Calculate born charges and dielectric constant')
"""Run NAC params calculation."""
self.report('calculation iteration %d/%d'
% (self.ctx.iteration, self.ctx.max_iteration))
label = "nac_params_%d" % self.ctx.iteration
process_inputs = get_calcjob_inputs(self.inputs.calculator_settings,
self.inputs.structure,
label=self.metadata.label)
# builder = get_calcjob_builder(
# self.inputs.structure,
# self.inputs.calculator_settings['code_string'],
# builder_inputs,
# label='born_and_epsilon')
# future = self.submit(builder)
CalculatorProcess = get_calculator_process(
self.inputs.calculator_settings['code_string'])
ctx=self.ctx,
label=label)
if 'sequence' in self.inputs.calculator_settings.keys():
i = self.ctx.iteration - 1
key = self.inputs.calculator_settings['sequence'][i]
code_string = self.inputs.calculator_settings[key]['code_string']
else:
code_string = self.inputs.calculator_settings['code_string']
CalculatorProcess = get_calculator_process(code_string)
future = self.submit(CalculatorProcess, **process_inputs)
self.report('born_and_epsilon: {}'.format(future.pk))
self.to_context(**{'calc': future})
self.report('nac_params: {}'.format(future.pk))
self.to_context(**{label: future})

def finalize(self):
"""Finalize NAC params calculation."""
self.report('Create nac params data')
self.report('finalization')

calc = self.ctx.calc
if type(calc) is dict:
calc_dict = calc
structure = calc['structure']
else:
calc_dict = calc.outputs
structure = calc.inputs.structure
calc = self.ctx['nac_params_1']
calc_dict = calc.outputs
structure = calc.inputs.structure

if 'born_charges' not in calc_dict:
return self.exit_codes.ERROR_NO_BORN_EFFECTIVE_CHARGES
Expand Down
6 changes: 3 additions & 3 deletions aiida_phonopy/workflows/phonopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def run_phonopy(self):

def is_nac(self):
"""Return boolen for outline."""
if 'is_nac' in self.inputs.phonon_settings.dict:
if 'is_nac' in self.inputs.phonon_settings.keys():
return self.inputs.phonon_settings['is_nac']
else:
False
Expand Down Expand Up @@ -206,7 +206,7 @@ def initialize_structures(self):
if 'code_string' not in self.inputs:
raise RuntimeError("code_string has to be specified.")

if 'supercell_matrix' not in self.inputs.phonon_settings.dict:
if 'supercell_matrix' not in self.inputs.phonon_settings.keys():
raise RuntimeError(
"supercell_matrix was not found in phonon_settings.")

Expand Down Expand Up @@ -234,7 +234,7 @@ def initialize_postprocess_settings(self):
"""Set default settings and create supercells and primitive cell."""
self.report('initialize_postprocess_settings')

if 'mesh' in self.inputs.phonon_settings.dict:
if 'mesh' in self.inputs.phonon_settings.keys():
mesh = self.inputs.phonon_settings['mesh']
else:
mesh = 100.0
Expand Down

0 comments on commit dd5688d

Please sign in to comment.