# AiiDA by example: Computing an equation of state


:::{admonition} Learning Objectives
:class: learning-objectives

In this section we will present a complete example of an AiiDA workflow, which defines the sequence of calculations needed to compute the equation of state (total energy as a function of volume) of aluminium.

How to set up the input data and the details of the workflow execution will be discussed in subsequent sections.
Here we simply give an initial overview of what it means to run an AiiDA workflow.

:::

In [None]:
%load_ext aiida
%aiida

In [None]:
from aiida import engine, orm
from aiida_quantumespresso.calculations.pw import PwCalculation
from ase.build import bulk

In [None]:
# pw.x namelist parameters shared by every self-consistent-field run of the
# workflow. Pseudopotentials are attached to each calculation through
# ``builder.pseudos``, so no ``pseudo_dir`` entry is set here.
scf_inputs = dict(
    CONTROL=dict(calculation='scf'),
    SYSTEM=dict(
        occupations='smearing',
        smearing='cold',
        degauss=0.02,
    ),
)

# Scheduler resources requested for each calculation: one MPI process on one machine.
resources = dict(num_machines=1, num_mpiprocs_per_machine=1)

@engine.calcfunction
def rescale_list(structure: orm.StructureData, factor_list: orm.List):
    """Create one uniformly rescaled copy of ``structure`` per scale factor.

    Each factor multiplies the cell vectors, and the atomic positions are scaled
    together with the cell (``scale_atoms=True``), so a factor ``f`` changes the
    cell volume by ``f**3``.

    :param structure: the reference structure to rescale.
    :param factor_list: the factors to multiply the cell vectors by.
    :returns: a dictionary mapping ``'structure_<index>'`` to the rescaled
        ``StructureData``, where ``<index>`` is the position of the factor in
        ``factor_list``.
    """
    # Convert to ASE once; each iteration rescales its own copy so the
    # factors never accumulate on a shared object.
    reference = structure.get_ase()
    reference_cell = reference.get_cell()

    scaled_structure_dict = {}
    for index, scaling_factor in enumerate(factor_list.get_list()):
        ase_structure = reference.copy()
        ase_structure.set_cell(reference_cell * scaling_factor, scale_atoms=True)
        scaled_structure_dict[f'structure_{index}'] = orm.StructureData(ase=ase_structure)

    return scaled_structure_dict

@engine.calcfunction
def create_eos_dictionary(**kwargs) -> orm.Dict:
    """Gather per-structure results into a single equation-of-state node.

    :param kwargs: one node per rescaled structure, keyed by its label; each is
        indexed with ``'volume'`` and ``'energy'``.
    :returns: a ``Dict`` mapping every label to its ``(volume, energy)`` pair.
    """
    collected = {}
    for name in kwargs:
        entry = kwargs[name]
        collected[name] = (entry['volume'], entry['energy'])
    return orm.Dict(collected)




class EquationOfStateWorkChain(engine.WorkChain):
    """WorkChain to compute an equation of state using Quantum ESPRESSO.

    The input structure is rescaled by each of the given scale factors, an SCF
    ``PwCalculation`` is submitted for every rescaled structure, and the
    resulting volumes and total energies are collected in the ``eos_dict``
    output.
    """

    @classmethod
    def define(cls, spec):
        """Specify inputs, outline, outputs and exit codes."""
        super().define(spec)
        # ``AbstractCode`` is the common base of AiiDA code types, so this accepts
        # installed/portable codes as well as the legacy ``orm.Code``.
        spec.input("code", valid_type=orm.AbstractCode)
        spec.input("structure", valid_type=orm.StructureData)
        spec.input("scale_factors", valid_type=orm.List)

        spec.outline(
            cls.run_eos,
            cls.results,
        )
        spec.output("eos_dict", valid_type=orm.Dict)
        spec.exit_code(
            400,
            "ERROR_SUB_PROCESS_FAILED",
            message="The PwCalculation for `{label}` did not finish successfully.",
        )

    def run_eos(self):
        """Submit one SCF ``PwCalculation`` per rescaled structure.

        The submitted calculations are placed in the context under the labels
        produced by ``rescale_list`` so that ``results`` can read them back.
        """
        rescaled_structures = rescale_list(self.inputs.structure, self.inputs.scale_factors)

        # These inputs are identical for every rescaled structure: build them once.
        parameters = orm.Dict(scf_inputs)
        pseudo_family = orm.load_group('SSSP/1.3/PBEsol/efficiency')
        kpoints = orm.KpointsData()
        kpoints.set_kpoints_mesh([2, 2, 2])

        calcjob_dict = {}
        for label, rescaled_structure in rescaled_structures.items():
            builder = PwCalculation.get_builder()
            builder.code = self.inputs.code
            builder.structure = rescaled_structure
            builder.parameters = parameters
            builder.pseudos = pseudo_family.get_pseudos(structure=rescaled_structure)
            builder.kpoints = kpoints
            builder.metadata.options.resources = resources

            calcjob_dict[label] = self.submit(builder)

        self.ctx.labels = list(calcjob_dict)
        self.report(f'submitted {len(calcjob_dict)} PwCalculations')

        # Wait for every submitted calculation and store it in the context under its label.
        return engine.ToContext(**calcjob_dict)

    def results(self):
        """Collect the volume and energy of each calculation into ``eos_dict``.

        Returns the ``ERROR_SUB_PROCESS_FAILED`` exit code if any calculation did
        not finish successfully, since its ``output_parameters`` may be missing.
        """
        eos_results = {}
        for label in self.ctx.labels:
            calcjob = self.ctx[label]
            if not calcjob.is_finished_ok:
                self.report(
                    f'PwCalculation<{calcjob.pk}> for {label} failed '
                    f'with exit status {calcjob.exit_status}'
                )
                return self.exit_codes.ERROR_SUB_PROCESS_FAILED.format(label=label)

            output_parameters = calcjob.outputs.output_parameters.get_dict()
            # Plain floats are enough here; ``create_eos_dictionary`` reads them back by key.
            eos_results[label] = orm.Dict({
                'energy': output_parameters['energy'],
                'volume': output_parameters['volume'],
            })

        self.report(f'collected energies and volumes for {len(eos_results)} structures')
        self.out('eos_dict', create_eos_dictionary(**eos_results))


In [None]:
# Conventional cubic cell of bulk aluminium with lattice parameter a = 4.05 Å.
structure = orm.StructureData(ase=bulk('Al', a=4.05, cubic=True))

# Run the workflow in this interpreter and wait for it to finish; ``results``
# is the dictionary of its outputs (here the ``eos_dict`` node).
results = engine.run(
    EquationOfStateWorkChain,
    code=orm.load_code("qe-7.3-pw@localhost"),
    structure=structure,
    scale_factors=orm.List([0.9, 0.95, 1.0, 1.05, 1.1]),
)

In [None]:
# Unwrap the ``eos_dict`` output node into a plain Python dictionary.
eos_node = results['eos_dict']
eos_dict = eos_node.get_dict()

In [None]:
# Display the (volume, energy) pair collected for each structure label.
eos_dict

In [None]:
from matplotlib import pyplot as plt

# Each value is a (volume, energy) pair. Sort by volume so the curve is drawn
# left to right regardless of the order in which the keys were stored.
volumes, energies = zip(*sorted(eos_dict.values()))

plt.plot(volumes, energies, marker='o')
plt.xlabel('Volume (Å$^3$)')
plt.ylabel('Total energy (eV)')
plt.show()