Skip to content

Commit

Permalink
PhBaseWorkChain: fix set_qpoints step
Browse files Browse the repository at this point in the history
The `set_qpoints` step in the outline of the `PhBaseWorkChain` contained several errors
incorrectly assuming that the inputs of the `PhCalculation` are found in the
`self.ctx.inputs.ph` namespace of the context. These should actually be placed in the
`self.ctx.inputs`, which is where the `BaseRestartWorkChain` expects to find the inputs
of the process class it wraps. Here we correctly assign the inputs in the context.

`Additionally, the `set_qpoints` step would assume that the `qpoints_force_parity` input
of the PhBaseWorkChain is always present. However, this is not a required input, and
hence we take this in consideration in the `set_qpoints` logic.
  • Loading branch information
bastonero committed Feb 12, 2024
1 parent b517686 commit c353cc2
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 20 deletions.
9 changes: 4 additions & 5 deletions src/aiida_quantumespresso/workflows/ph/base.py
Expand Up @@ -177,27 +177,26 @@ def set_qpoints(self):
the case of the latter, the `KpointsData` will be constructed for the input `StructureData`
from the parent_folder using the `create_kpoints_from_distance` calculation function.
"""

try:
qpoints = self.inputs.qpoints
except AttributeError:

try:
structure = self.ctx.inputs.ph.parent_folder.creator.output.output_structure
structure = self.ctx.inputs.parent_folder.creator.output.output_structure
except AttributeError:
structure = self.ctx.inputs.ph.parent_folder.creator.inputs.structure
structure = self.ctx.inputs.parent_folder.creator.inputs.structure

inputs = {
'structure': structure,
'distance': self.inputs.qpoints_distance,
'force_parity': self.inputs.qpoints_force_parity,
'force_parity': self.inputs.get('qpoints_force_parity', orm.Bool(False)),
'metadata': {
'call_link_label': 'create_qpoints_from_distance'
}
}
qpoints = create_kpoints_from_distance(**inputs)

self.ctx.inputs.ph['qpoints'] = qpoints
self.ctx.inputs['qpoints'] = qpoints

def set_max_seconds(self, max_wallclock_seconds: None):
"""Set the `max_seconds` to a fraction of `max_wallclock_seconds` option to prevent out-of-walltime problems.
Expand Down
35 changes: 29 additions & 6 deletions tests/conftest.py
Expand Up @@ -591,18 +591,41 @@ def _generate_inputs_q2r():


@pytest.fixture
def generate_inputs_ph(fixture_sandbox, fixture_localhost, fixture_code, generate_remote_data, generate_kpoints_mesh):
def generate_inputs_ph(
generate_calc_job_node, generate_structure, fixture_localhost, fixture_code, generate_kpoints_mesh
):
"""Generate default inputs for a `PhCalculation."""

def _generate_inputs_ph():
"""Generate default inputs for a `PhCalculation."""
from aiida.orm import Dict
def _generate_inputs_ph(with_output_structure=False):
"""Generate default inputs for a `PhCalculation.
:param with_output_structure: whether the PwCalculation has a StructureData in its outputs.
This is needed to test some PhBaseWorkChain logics.
"""
from aiida.common import LinkType
from aiida.orm import Dict, RemoteData

from aiida_quantumespresso.utils.resources import get_default_options

pw_node = generate_calc_job_node(
entry_point_name='quantumespresso.pw', inputs={
'parameters': Dict(),
'structure': generate_structure()
}
)
remote_folder = RemoteData(computer=fixture_localhost, remote_path='/tmp')
remote_folder.base.links.add_incoming(pw_node, link_type=LinkType.CREATE, link_label='remote_folder')
remote_folder.store()
parent_folder = pw_node.outputs.remote_folder

if with_output_structure:
structure = generate_structure()
structure.base.links.add_incoming(pw_node, link_type=LinkType.CREATE, link_label='output_structure')
structure.store()

inputs = {
'code': fixture_code('quantumespresso.ph'),
'parent_folder': generate_remote_data(fixture_localhost, fixture_sandbox.abspath, 'quantumespresso.pw'),
'parent_folder': parent_folder,
'qpoints': generate_kpoints_mesh(2),
'parameters': Dict({'INPUTPH': {}}),
'metadata': {
Expand Down Expand Up @@ -806,7 +829,7 @@ def _generate_workchain_ph(exit_code=None, inputs=None, return_inputs=False):

if inputs is None:
ph_inputs = generate_inputs_ph()
qpoints = ph_inputs.get('qpoints')
qpoints = ph_inputs.pop('qpoints')
inputs = {'ph': ph_inputs, 'qpoints': qpoints}

if return_inputs:
Expand Down
42 changes: 33 additions & 9 deletions tests/workflows/ph/test_base.py
Expand Up @@ -10,15 +10,6 @@
from aiida_quantumespresso.workflows.ph.base import PhBaseWorkChain


@pytest.mark.usefixtures('aiida_profile')
def test_invalid_inputs(generate_workchain_ph, generate_inputs_ph):
"""Test `PhBaseWorkChain` validation methods."""
inputs = {'ph': generate_inputs_ph()}
message = r'Neither `qpoints` nor `qpoints_distance` were specified.'
with pytest.raises(ValueError, match=message):
generate_workchain_ph(inputs=inputs)


@pytest.fixture
def generate_ph_calc_job_node(generate_calc_job_node, fixture_localhost):
"""Generate a ``CalcJobNode`` that would have been created by a ``PhCalculation``."""
Expand All @@ -43,6 +34,15 @@ def _generate_ph_calc_job_node():
return _generate_ph_calc_job_node


@pytest.mark.usefixtures('aiida_profile')
def test_invalid_inputs(generate_workchain_ph, generate_inputs_ph):
"""Test `PhBaseWorkChain` validation methods."""
inputs = {'ph': generate_inputs_ph()}
message = r'Neither `qpoints` nor `qpoints_distance` were specified.'
with pytest.raises(ValueError, match=message):
generate_workchain_ph(inputs=inputs)


def test_setup(generate_workchain_ph):
"""Test `PhBaseWorkChain.setup`."""
process = generate_workchain_ph()
Expand All @@ -52,6 +52,30 @@ def test_setup(generate_workchain_ph):
assert isinstance(process.ctx.inputs, AttributeDict)


@pytest.mark.parametrize(
('with_output_structure', 'with_qpoints_distance'),
((False, False), (False, True), (True, True)),
)
def test_set_qpoints(generate_workchain_ph, generate_inputs_ph, with_output_structure, with_qpoints_distance):
"""Test `PhBaseWorkChain.set_qpoints`."""
inputs = {'ph': generate_inputs_ph(with_output_structure=with_output_structure)}
inputs['qpoints'] = inputs['ph'].pop('qpoints')

if with_qpoints_distance:
inputs.pop('qpoints')
inputs['qpoints_distance'] = orm.Float(0.5)

process = generate_workchain_ph(inputs=inputs)
process.setup()
process.set_qpoints()

assert 'qpoints' in process.ctx.inputs
assert isinstance(process.ctx.inputs['qpoints'], orm.KpointsData)

if not with_qpoints_distance:
assert process.ctx.inputs['qpoints'] == inputs['qpoints']


def test_handle_unrecoverable_failure(generate_workchain_ph):
"""Test `PhBaseWorkChain.handle_unrecoverable_failure`."""
process = generate_workchain_ph(exit_code=PhCalculation.exit_codes.ERROR_NO_RETRIEVED_FOLDER)
Expand Down

0 comments on commit c353cc2

Please sign in to comment.