In [None]:
import ipywidgets as ipw
import traitlets
from aiida.engine import submit
from aiida.orm import load_node, Dict, Float, Str, StructureData, WorkChainNode, BandsData
from aiida.plugins import WorkflowFactory

from wizard import WizardApp, WizardAppStep
from structures import StructureUploadComboWidget
from codes import CodeSubmitWidget


# Default pw.x input parameters, reused by both the relax and the bands submissions
# below: SYSTEM-namelist plane-wave cutoffs with ecutrho = 4 x ecutwfc
# (units presumably Ry, as pw.x expects -- confirm).
DEFAULT_PARAMETERS = Dict(dict={
    'SYSTEM': {
        'ecutwfc': 50.0,
        'ecutrho': 200.0,
    },
})


class RelaxSubmitWidget(CodeSubmitWidget):
    """Wizard step that submits a ``quantumespresso.pw.relax`` work chain.

    The step consumes ``input_structure`` and publishes the relaxed geometry as
    ``output_structure``. Calling ``skip`` forwards the input structure
    unchanged instead of running a relaxation.
    """

    # Structure to relax.
    input_structure = traitlets.Instance(StructureData, allow_none=True)
    # Relaxed structure, or the unchanged input structure after ``skip``.
    output_structure = traitlets.Instance(StructureData, allow_none=True)

    def _update_state(self):
        """Derive the wizard state; defer to the base class once a process is attached."""
        if self.process is not None:
            super()._update_state()
            return

        has_input = self.input_structure is not None
        has_output = self.output_structure is not None
        if has_input:
            # Ready to launch, or already done (e.g. after ``skip``).
            self.state = WizardApp.State.SUCCESS if has_output else WizardApp.State.READY
        else:
            # An output without any input is treated as inconsistent.
            self.state = WizardApp.State.FAIL if has_output else WizardApp.State.INIT

    @traitlets.observe('input_structure')
    def _observe_input_structure(self, change):
        """Recompute the state whenever the input structure changes."""
        self._update_state()

    @traitlets.observe('output_structure')
    def _observe_output_structure(self, change):
        """Recompute the state whenever the output structure changes."""
        self._update_state()

    @traitlets.observe('process')
    def _observe_process(self, change):
        """Copy the structure input of a newly attached process into ``input_structure``.

        Returns whatever the base-class observer returns (the process node or None).
        """
        with self.hold_trait_notifications():
            node = super()._observe_process(change)
            if node is not None:
                self.input_structure = node.inputs.structure
        return node

    def skip(self, _):
        """Complete the step without relaxing, by forwarding the input structure."""
        self.output_structure = self.input_structure

    def _refresh_outputs_keys(self):
        """Pick up ``output_structure`` from the process, or flag failure once it is sealed."""
        node = super()._refresh_outputs_keys()
        if node is None:
            return
        with self.hold_trait_notifications():
            if 'output_structure' in node.outputs:
                self.output_structure = node.outputs.output_structure
            elif node.is_sealed:
                # Sealed without ever producing a relaxed structure.
                self.state = WizardApp.State.FAIL

    def submit(self, _=None):
        """Build and submit a ``quantumespresso.pw.relax`` work chain for ``input_structure``."""
        assert self.input_structure is not None

        builder = WorkflowFactory('quantumespresso.pw.relax').get_builder()
        base = builder.base
        base.pw.code = self.code_group.selected_code
        base.pw.parameters = DEFAULT_PARAMETERS
        base.pw.metadata.options = self.options
        base.kpoints_distance = Float(0.8)
        base.pseudo_family = Str(self.pseudo_family.value)
        builder.structure = self.input_structure

        # The name ``submit`` resolves to the module-level function, not this method.
        self.process = submit(builder)

    def reset(self):
        """Detach the process and clear the relaxed structure in one notification batch."""
        with self.hold_trait_notifications():
            self.process = None
            self.output_structure = None
        

class ComputeBandsSubmitWidget(CodeSubmitWidget):
    """Wizard step that submits a ``quantumespresso.pw.bands`` work chain.

    The step consumes ``input_structure`` and publishes the resulting
    ``band_structure`` output of the work chain.
    """

    # Structure for which the band structure is computed.
    input_structure = traitlets.Instance(StructureData, allow_none=True)
    # ``band_structure`` output of the submitted work chain, once available.
    band_structure = traitlets.Instance(BandsData, allow_none=True)

    def _update_state(self):
        """Derive the wizard state; defer to the base class once a process is attached."""
        if self.process is not None:
            super()._update_state()
            return

        has_input = self.input_structure is not None
        has_bands = self.band_structure is not None
        if has_input:
            self.state = WizardApp.State.SUCCESS if has_bands else WizardApp.State.READY
        else:
            # A band structure without any input structure is treated as inconsistent.
            self.state = WizardApp.State.FAIL if has_bands else WizardApp.State.INIT

    @traitlets.observe('input_structure')
    def _observe_input_structure(self, change):
        """Recompute the state whenever the input structure changes."""
        self._update_state()

    @traitlets.observe('band_structure')
    def _observe_band_structure(self, change):
        """Recompute the state whenever the band structure changes."""
        self._update_state()

    @traitlets.observe('process')
    def _observe_process(self, change):
        """Sync traits with the attached process.

        When a process is attached, its structure input is copied into
        ``input_structure``. When the process is cleared, ``band_structure`` is
        cleared too. Returns whatever the base-class observer returns.
        """
        with self.hold_trait_notifications():
            node = super()._observe_process(change)
            if node is not None:
                self.input_structure = node.inputs.structure
            else:
                self.band_structure = None
        return node

    def _refresh_outputs_keys(self):
        """Pick up ``band_structure`` from the process, or flag failure once it is sealed."""
        node = super()._refresh_outputs_keys()
        if node is None:
            return
        with self.hold_trait_notifications():
            if 'band_structure' in node.outputs:
                self.band_structure = node.outputs.band_structure
            elif node.is_sealed:
                # Sealed without ever producing a band structure.
                self.state = WizardApp.State.FAIL

    def submit(self, _=None):
        """Build and submit a ``quantumespresso.pw.bands`` work chain for ``input_structure``."""
        assert self.input_structure is not None

        builder = WorkflowFactory('quantumespresso.pw.bands').get_builder()

        # Self-consistent sub-step.
        scf = builder.scf
        scf.pw.code = self.code_group.selected_code
        scf.pw.parameters = DEFAULT_PARAMETERS
        scf.pw.metadata.options = self.options
        scf.kpoints_distance = Float(0.8)
        scf.pseudo_family = Str(self.pseudo_family.value)

        # Band-structure sub-step (no k-point distance is set here).
        bands = builder.bands
        bands.pw.code = self.code_group.selected_code
        bands.pw.parameters = DEFAULT_PARAMETERS
        bands.pw.metadata.options = self.options
        bands.pseudo_family = Str(self.pseudo_family.value)

        builder.structure = self.input_structure

        # The name ``submit`` resolves to the module-level function, not this method.
        self.process = submit(builder)

    def reset(self):
        """Detach the process; the ``process`` observer then clears ``band_structure``."""
        self.process = None

# --- Wizard steps -----------------------------------------------------------
structure_selection_step = StructureUploadComboWidget(
    examples=[
        ('Diamond', 'miscellaneous/structures/diamond.cif'),
        ('Gallium arsenide', 'miscellaneous/structures/GaAs.xyz'),
        ('Silicon', 'miscellaneous/structures/Si.xyz'),
        ('Silicon oxide', 'miscellaneous/structures/SiO2.xyz'),
    ],
    viewer=True,
)
relax_step = RelaxSubmitWidget(allow_skip=True, has_next=True)
compute_bands_step = ComputeBandsSubmitWidget()

# --- Data flow: confirmed structure -> relaxation -> band structure ---------
# The selection widget's value is wrapped via StructureData(ase=...); None passes through.
ipw.dlink(
    (structure_selection_step, 'confirmed_structure'),
    (relax_step, 'input_structure'),
    transform=lambda atoms: StructureData(ase=atoms) if atoms is not None else None,
)
ipw.dlink(
    (relax_step, 'output_structure'),
    (compute_bands_step, 'input_structure'),
)

# --- Mirror the relax step's code/pseudo/resource settings onto the bands step ---
ipw.dlink(
    (relax_step.code_group.dropdown, 'value'),
    (compute_bands_step.code_group.dropdown, 'value'),
)
ipw.dlink(
    (relax_step.pseudo_family, 'value'),
    (compute_bands_step.pseudo_family, 'value'),
)
ipw.dlink(
    (relax_step.number_of_nodes, 'value'),
    (compute_bands_step.number_of_nodes, 'value'),
)
ipw.dlink(
    (relax_step.cpus_per_node, 'value'),
    (compute_bands_step.cpus_per_node, 'value'),
)

# --- Assemble the wizard ----------------------------------------------------
app = WizardApp(
    steps=[
        ('Select structure', structure_selection_step),
        ('Relax', relax_step),
        ('Compute bands', compute_bands_step),
    ],
)
# Refresh the step titles from each submit step's callbacks
# (presumably invoked when the step's process/state changes -- confirm in `codes`).
relax_step.callbacks.append(lambda _: app._update_titles())
compute_bands_step.callbacks.append(lambda _: app._update_titles())


# TODO: REMOVE THESE LINES AFTER TESTING
# Debug shortcut: when uncommented, it preselects an example structure and attaches
# previously run processes by UUID. These UUIDs presumably exist only in the original
# developer's AiiDA database, so they will fail to load elsewhere.
# with app.hold_sync():
#     structure_selection_step.example_widget.index = 1
#     structure_selection_step.confirm()
#     relax_step.process = load_node('3d649e1f-cbe2-4009-bdf1-b44f7aa9a0f6')
#     compute_bands_step.process = load_node('7da3beb2-0b9e-489c-b406-dc51fe9c8e96')
# END OF REMOVE LINES AFTER TESTING

# Last expression of the notebook cell: the wizard is rendered as the cell output.
app