Skip to content

Commit

Permalink
Support calculating the XPS spectra of the atoms specific by indices (#…
Browse files Browse the repository at this point in the history
…958)

Implements the XPS workchain to support calculating the XPS spectra of the atoms specific by indices. This is useful for large systems with low symmetry, e.g. supported nanoparticles.
  • Loading branch information
superstar54 committed Sep 19, 2023
1 parent 17e173f commit fc1a940
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 67 deletions.
@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
"""CalcFunction to create structures with a marked atom for each site in a list."""
from aiida import orm
from aiida.common import ValidationError
from aiida.engine import calcfunction
from aiida.orm.nodes.data.structure import Kind, Site, StructureData


@calcfunction
def get_marked_structures(structure, atoms_list, marker='X'):
"""Read a StructureData object and return structures for XPS calculations.
:param atoms_list: the atoms_list of atoms to be marked.
:param marker: a Str node defining the name of the marked atom Kind. Default is 'X'.
:returns: StructureData objects for the generated structure.
"""
marker = marker.value
elements_present = [kind.symbol for kind in structure.kinds]
if marker in elements_present:
raise ValidationError(
f'The marker ("{marker}") should not match an existing Kind in '
f'the input structure ({elements_present}.'
)

output_params = {}
result = {}

for index in atoms_list.get_list():
marked_structure = StructureData()
kinds = {kind.name: kind for kind in structure.kinds}
marked_structure.set_cell(structure.cell)

for i, site in enumerate(structure.sites):
if i == index:
marked_kind = Kind(name=marker, symbols=site.kind_name)
marked_site = Site(kind_name=marked_kind.name, position=site.position)
marked_structure.append_kind(marked_kind)
marked_structure.append_site(marked_site)
output_params[f'site_{index}'] = {'symbol': site.kind_name, 'multiplicity': 1}
else:
if site.kind_name not in [kind.name for kind in marked_structure.kinds]:
marked_structure.append_kind(kinds[site.kind_name])
new_site = Site(kind_name=site.kind_name, position=site.position)
marked_structure.append_site(new_site)
result[f'site_{index}'] = marked_structure

result['output_parameters'] = orm.Dict(dict=output_params)

return result
171 changes: 104 additions & 67 deletions src/aiida_quantumespresso/workflows/xps.py
Expand Up @@ -28,21 +28,21 @@ def validate_inputs(inputs, _):
"""Validate the inputs before launching the WorkChain."""
structure = inputs['structure']
elements_present = [kind.name for kind in structure.kinds]
absorbing_elements_list = sorted(inputs['elements_list'])
abs_atom_marker = inputs['abs_atom_marker'].value
if abs_atom_marker in elements_present:
raise ValidationError(
f'The marker given for the absorbing atom ("{abs_atom_marker}") matches an existing Kind in the '
f'input structure ({elements_present}).'
)

if inputs['calc_binding_energy'].value:
ce_list = sorted(inputs['correction_energies'].get_dict().keys())
if ce_list != absorbing_elements_list:
raise ValidationError(
f'The ``correction_energies`` provided ({ce_list}) does not match the list of'
f' absorbing elements ({absorbing_elements_list})'
)
if 'elements_list' in inputs:
absorbing_elements_list = sorted(inputs['elements_list'])
if inputs['calc_binding_energy'].value:
ce_list = sorted(inputs['correction_energies'].get_dict().keys())
if ce_list != absorbing_elements_list:
raise ValidationError(
f'The ``correction_energies`` provided ({ce_list}) does not match the list of'
f' absorbing elements ({absorbing_elements_list})'
)


class XpsWorkChain(ProtocolMixin, WorkChain):
Expand Down Expand Up @@ -81,7 +81,7 @@ def define(cls, spec):
spec.expose_inputs(
PwBaseWorkChain,
namespace='ch_scf',
exclude=('kpoints', 'pw.structure'),
exclude=('pw.structure', ),
namespace_options={
'help': ('Input parameters for the basic xps workflow (core-hole SCF).'),
'validator': None
Expand Down Expand Up @@ -170,6 +170,14 @@ def define(cls, spec):
'The list of elements to be considered for analysis, each must be valid elements of the periodic table.'
)
)
spec.input(
'atoms_list',
valid_type=orm.List,
required=False,
help=(
'The indices of atoms to be considered for analysis.'
)
)
spec.input(
'calc_binding_energy',
valid_type=orm.Bool,
Expand Down Expand Up @@ -233,12 +241,14 @@ def define(cls, spec):
spec.output(
'supercell_structure',
valid_type=orm.StructureData,
required=False,
help=('The supercell of ``outputs.standardized_structure`` used to generate structures for'
' XPS sub-processes.')
)
spec.output(
'symmetry_analysis_data',
valid_type=orm.Dict,
required=False,
help='The output parameters from ``get_xspectra_structures()``.'
)
spec.output(
Expand Down Expand Up @@ -366,8 +376,8 @@ def get_treatment_filepath(cls):
@classmethod
def get_builder_from_protocol(
cls, code, structure, pseudos, core_hole_treatments=None, protocol=None,
overrides=None, elements_list=None, options=None,
structure_preparation_settings=None, **kwargs
overrides=None, elements_list=None, atoms_list=None, options=None,
structure_preparation_settings=None, correction_energies=None, **kwargs
):
"""Return a builder prepopulated with inputs selected according to the chosen protocol.
Expand All @@ -386,9 +396,6 @@ def get_builder_from_protocol(
"""

inputs = cls.get_protocol_inputs(protocol, overrides)
calc_binding_energy = kwargs.pop('calc_binding_energy', False)
correction_energies = kwargs.pop('correction_energies', orm.Dict())

pw_args = (code, structure, protocol)
# xspectra_args = (pw_code, xs_code, structure, protocol, upf2plotcore_code)

Expand All @@ -412,8 +419,11 @@ def get_builder_from_protocol(
builder.ch_scf = ch_scf
builder.structure = structure
builder.abs_atom_marker = abs_atom_marker
builder.calc_binding_energy = calc_binding_energy
builder.correction_energies = correction_energies
if correction_energies:
builder.correction_energies = orm.Dict(correction_energies)
builder.calc_binding_energy = orm.Bool(True)
else:
builder.calc_binding_energy = orm.Bool(False)
builder.clean_workdir = orm.Bool(inputs['clean_workdir'])
core_hole_pseudos = {}
gipaw_pseudos = {}
Expand All @@ -434,6 +444,12 @@ def get_builder_from_protocol(
for element in elements_list:
core_hole_pseudos[element] = pseudos[element]['core_hole']
gipaw_pseudos[element] = pseudos[element]['gipaw']
elif atoms_list:
builder.atoms_list = orm.List(atoms_list)
for index in atoms_list:
element = structure.sites[index].kind_name
core_hole_pseudos[element] = pseudos[element]['core_hole']
gipaw_pseudos[element] = pseudos[element]['gipaw']
# if no elements list is given, we instead initalise the pseudos dict with all
# elements in the structure
else:
Expand All @@ -453,12 +469,18 @@ def get_builder_from_protocol(

def setup(self):
"""Init required context variables."""
custom_elements_list = self.inputs.get('elements_list', None)
if not custom_elements_list:
elements_list = self.inputs.get('elements_list', None)
atoms_list = self.inputs.get('atoms_list', None)
if elements_list:
self.ctx.elements_list = elements_list.get_list()
self.ctx.atoms_list = None
elif atoms_list:
self.ctx.atoms_list = atoms_list.get_list()
self.ctx.elements_list = None
else:
structure = self.inputs.structure
self.ctx.elements_list = [Kind.symbol for Kind in structure.kinds]
else:
self.ctx.elements_list = custom_elements_list.get_list()



def should_run_relax(self):
Expand Down Expand Up @@ -511,48 +533,59 @@ def prepare_structures(self):
formatted as {<variable_name> : <parameter>} for each variable in the
``get_symmetry_dataset()`` method.
"""
from aiida_quantumespresso.workflows.functions.get_marked_structures import get_marked_structures
from aiida_quantumespresso.workflows.functions.get_xspectra_structures import get_xspectra_structures

elements_list = orm.List(self.ctx.elements_list)
inputs = {
'absorbing_elements_list' : elements_list,
'absorbing_atom_marker' : self.inputs.abs_atom_marker,
'metadata' : {
'call_link_label' : 'get_xspectra_structures'
input_structure = self.inputs.structure if 'relax' not in self.inputs else self.ctx.relaxed_structure
if self.ctx.elements_list:
elements_list = orm.List(self.ctx.elements_list)
inputs = {
'absorbing_elements_list' : elements_list,
'absorbing_atom_marker' : self.inputs.abs_atom_marker,
'metadata' : {
'call_link_label' : 'get_xspectra_structures'
}
} # populate this further once the schema for WorkChain options is figured out
if 'structure_preparation_settings' in self.inputs:
optional_cell_prep = self.inputs.structure_preparation_settings
for key, node in optional_cell_prep.items():
inputs[key] = node
if 'spglib_settings' in self.inputs:
spglib_settings = self.inputs.spglib_settings
inputs['spglib_settings'] = spglib_settings
else:
spglib_settings = None

result = get_xspectra_structures(input_structure, **inputs)

supercell = result.pop('supercell')
out_params = result.pop('output_parameters')
if out_params.get_dict().get('structure_is_standardized', None):
standardized = result.pop('standardized_structure')
self.out('standardized_structure', standardized)

# structures_to_process = {Key : Value for Key, Value in result.items()}
for site in ['output_parameters', 'supercell', 'standardized_structure']:
result.pop(site, None)
self.ctx.supercell = supercell
self.ctx.equivalent_sites_data = out_params['equivalent_sites_data']
self.out('supercell_structure', supercell)
self.out('symmetry_analysis_data', out_params)
elif self.ctx.atoms_list:
atoms_list = orm.List(self.ctx.atoms_list)
inputs = {
'atoms_list' : atoms_list,
'marker' : self.inputs.abs_atom_marker,
'metadata' : {
'call_link_label' : 'get_marked_structures'
}
}
} # populate this further once the schema for WorkChain options is figured out
if 'structure_preparation_settings' in self.inputs:
optional_cell_prep = self.inputs.structure_preparation_settings
for key, node in optional_cell_prep.items():
inputs[key] = node
if 'spglib_settings' in self.inputs:
spglib_settings = self.inputs.spglib_settings
inputs['spglib_settings'] = spglib_settings
else:
spglib_settings = None

if 'relax' in self.inputs:
relaxed_structure = self.ctx.relaxed_structure
result = get_xspectra_structures(relaxed_structure, **inputs)
else:
result = get_xspectra_structures(self.inputs.structure, **inputs)

supercell = result.pop('supercell')
out_params = result.pop('output_parameters')
if out_params.get_dict().get('structure_is_standardized', None):
standardized = result.pop('standardized_structure')
self.out('standardized_structure', standardized)

# structures_to_process = {Key : Value for Key, Value in result.items()}
for site in ['output_parameters', 'supercell', 'standardized_structure']:
result.pop(site, None)
result = get_marked_structures(input_structure, **inputs)
self.ctx.supercell = input_structure
self.ctx.equivalent_sites_data = result.pop('output_parameters').get_dict()
structures_to_process = {f'{Key.split("_")[0]}_{Key.split("_")[1]}' : Value for Key, Value in result.items()}
self.ctx.supercell = supercell
self.report(f'structures_to_process: {structures_to_process}')
self.ctx.structures_to_process = structures_to_process
self.ctx.equivalent_sites_data = out_params['equivalent_sites_data']

self.out('supercell_structure', supercell)
self.out('symmetry_analysis_data', out_params)

def should_run_gs_scf(self):
"""If the 'calc_binding_energy' input namespace is True, we run a scf calculation for the supercell."""
Expand All @@ -566,9 +599,9 @@ def run_gs_scf(self):
inputs.metadata.call_link_label = 'supercell_xps'

inputs = prepare_process_inputs(PwBaseWorkChain, inputs)
equivalent_sites_data = self.ctx.equivalent_sites_data
for site in equivalent_sites_data:
abs_element = equivalent_sites_data[site]['symbol']
# pseudos for all elements to be calculated should be replaced
for site in self.ctx.equivalent_sites_data:
abs_element = self.ctx.equivalent_sites_data[site]['symbol']
inputs.pw.pseudos[abs_element] = self.inputs.gipaw_pseudos[abs_element]
running = self.submit(PwBaseWorkChain, **inputs)

Expand Down Expand Up @@ -600,7 +633,6 @@ def run_all_scf(self):
equivalent_sites_data = self.ctx.equivalent_sites_data
abs_atom_marker = self.inputs.abs_atom_marker.value


for site in structures_to_process:
inputs = AttributeDict(self.exposed_inputs(PwBaseWorkChain, namespace='ch_scf'))
structure = structures_to_process[site]
Expand Down Expand Up @@ -630,9 +662,10 @@ def run_all_scf(self):

core_hole_pseudo = self.inputs.core_hole_pseudos[abs_element]
inputs.pw.pseudos[abs_atom_marker] = core_hole_pseudo
# all element in the elements_list should be replaced
for element in self.inputs.elements_list:
inputs.pw.pseudos[element] = self.inputs.gipaw_pseudos[element]
# pseudos for all elements to be calculated should be replaced
for key in self.ctx.equivalent_sites_data:
abs_element = self.ctx.equivalent_sites_data[key]['symbol']
inputs.pw.pseudos[abs_element] = self.inputs.gipaw_pseudos[abs_element]
# remove pseudo if the only element is replaced by the marker
inputs.pw.pseudos = {kind.name: inputs.pw.pseudos[kind.name] for kind in structure.kinds}

Expand Down Expand Up @@ -674,11 +707,15 @@ def results(self):
kwargs['correction_energies'] = self.inputs.correction_energies
kwargs['metadata'] = {'call_link_label' : 'compile_final_spectra'}

equivalent_sites_data = orm.Dict(dict=self.ctx.equivalent_sites_data)
elements_list = orm.List(list=self.ctx.elements_list)
if self.ctx.elements_list:
elements_list = orm.List(list=self.ctx.elements_list)
else:
symbols = {value['symbol'] for value in self.ctx.equivalent_sites_data.values()}
elements_list = orm.List(list(symbols))
voight_gamma = self.inputs.voight_gamma
voight_sigma = self.inputs.voight_sigma

equivalent_sites_data = orm.Dict(dict=self.ctx.equivalent_sites_data)
result = get_spectra_by_element(
elements_list,
equivalent_sites_data,
Expand Down
19 changes: 19 additions & 0 deletions tests/workflows/functions/test_get_marked_structures.py
@@ -0,0 +1,19 @@
# -*- coding: utf-8 -*-
"""Tests for the `get_marked_structure` class."""


def test_get_marked_structure():
"""Test the get_marked_structure function."""
from aiida.orm import List, StructureData
from ase.build import molecule

from aiida_quantumespresso.workflows.functions.get_marked_structures import get_marked_structures

mol = molecule('CH3CH2OH')
mol.center(vacuum=2.0)
structure = StructureData(ase=mol)
indices = List(list=[0, 1, 2])
output = get_marked_structures(structure, indices)
assert len(output) == 4
assert output['site_0'].get_site_kindnames() == ['X', 'C', 'O', 'H', 'H', 'H', 'H', 'H', 'H']
assert output['site_1'].get_site_kindnames() == ['C', 'X', 'O', 'H', 'H', 'H', 'H', 'H', 'H']

0 comments on commit fc1a940

Please sign in to comment.