From 1cdaa8b93fa8139f0f2173e1386b3a14f454bd0f Mon Sep 17 00:00:00 2001 From: bastonero Date: Wed, 27 Mar 2024 17:39:17 +0000 Subject: [PATCH] Bands from protocol: fix `bands_kpoints` overrides The `bands_kpoints` wouldn't be override if specified in the overrides and passed to the PwBandsWorkChain.get_builder_from_protocol. If specified, it is chosen instead of the `bands_kpoints_distance`. --- src/aiida_quantumespresso/workflows/pw/bands.py | 5 ++++- tests/workflows/protocols/pw/test_bands.py | 12 ++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/aiida_quantumespresso/workflows/pw/bands.py b/src/aiida_quantumespresso/workflows/pw/bands.py index e6c5f77cc..d4d0a32ef 100644 --- a/src/aiida_quantumespresso/workflows/pw/bands.py +++ b/src/aiida_quantumespresso/workflows/pw/bands.py @@ -163,7 +163,10 @@ def get_builder_from_protocol(cls, code, structure, protocol=None, overrides=Non builder.bands = bands builder.clean_workdir = orm.Bool(inputs['clean_workdir']) builder.nbands_factor = orm.Float(inputs['nbands_factor']) - builder.bands_kpoints_distance = orm.Float(inputs['bands_kpoints_distance']) + if 'bands_kpoints' in inputs: + builder.bands_kpoints = inputs['bands_kpoints'] + else: + builder.bands_kpoints_distance = orm.Float(inputs['bands_kpoints_distance']) return builder diff --git a/tests/workflows/protocols/pw/test_bands.py b/tests/workflows/protocols/pw/test_bands.py index 8a5fa4283..9ffbf0c62 100644 --- a/tests/workflows/protocols/pw/test_bands.py +++ b/tests/workflows/protocols/pw/test_bands.py @@ -74,6 +74,18 @@ def test_relax_type(fixture_code, generate_structure): assert 'CELL' not in builder.relax['base']['pw']['parameters'].get_dict() +def test_bands_kpoints_overrides(fixture_code, generate_structure, generate_kpoints_mesh): + """Test specifying bands kpoints ``overrides`` for the ``get_builder_from_protocol()`` method.""" + code = fixture_code('quantumespresso.pw') + structure = generate_structure('silicon') + + bands_kpoints = generate_kpoints_mesh(3) + overrides = {'bands_kpoints': bands_kpoints} + builder = PwBandsWorkChain.get_builder_from_protocol(code, structure, overrides=overrides) + assert builder.bands_kpoints == bands_kpoints # pylint: disable=no-member + assert 'bands_kpoints_distance' not in builder + + def test_options(fixture_code, generate_structure): """Test specifying ``options`` for the ``get_builder_from_protocol()`` method.""" code = fixture_code('quantumespresso.pw')