In [None]:
import os
import threading
import tempfile
from enum import Enum
from pprint import pprint
from time import sleep
from collections import deque

import ase
import nglview
import ipywidgets as ipw
import traitlets
from IPython.display import clear_output

from aiidalab_widgets_base import CodeDropdown, ProgressBarWidget, viewer

from aiida.engine import submit, ProcessState
from aiida.orm import load_node, Dict, Float,Str, StructureData, ProcessNode, WorkChainNode, BandsData
from aiida.plugins import WorkflowFactory


# Input parameters handed to every work chain submitted from this app
# (see the ``submit`` methods below); only the SYSTEM cutoffs are set.
DEFAULT_PARAMETERS = Dict(
    dict={
        "SYSTEM": {
            "ecutwfc": 50.0,
            "ecutrho": 200.0,
        },
    },
)


class WizardApp(ipw.VBox):
    """Accordion-based wizard made of a sequence of step widgets.

    Every step widget must expose a ``state`` trait holding a member of
    ``WizardApp.State``. The accordion titles show an icon for the state of
    each step, and a footer with *Previous step* / *Reset* / *Next step*
    buttons is enabled or disabled depending on the selected step.
    """

    class State(Enum):
        "Every step within the WizardApp must have this traitlet."
        INIT = 0  # implicit default value

        CONFIGURED = 1
        READY = 2
        ACTIVE = 3
        SUCCESS = 4

        # All error states have negative codes
        FAIL = -1

    # The state of a step can no longer be changed once
    # it is one of the following sealed states.
    SEALED_STATES = [State.SUCCESS, State.FAIL]

    # U+2000 is EN QUAD (a wide space), meant to be placed between title and
    # icon. NOTE: _update_titles() currently uses a plain space instead.
    ICON_SEPARATOR = '\u2000'

    ICONS = {
        State.INIT: '\u25cb',
        State.CONFIGURED: '\u25ce',
        State.READY: '\u25ce',
        State.ACTIVE: '\u25d4',
        State.SUCCESS: '\u25cf',
        State.FAIL: '\u25cd',
    }

    @classmethod
    def icons(cls):
        """Return a copy of ``ICONS`` whose ACTIVE entry is a spinner frame.

        The frame is selected from the current wall-clock time (four frames
        per second), so repeated calls animate the icon.
        """
        from time import time
        spinner_frames = ['\u25dc', '\u25dd', '\u25de', '\u25df']
        ret = cls.ICONS.copy()
        ret[cls.State.ACTIVE] = spinner_frames[int((time() * 4) % 4)]
        return ret

    def __init__(self, steps, **kwargs):
        """Build the wizard.

        :param steps: sequence of ``(title, widget)`` pairs; each widget must
            have a ``state`` trait.
        :param kwargs: forwarded to ``ipw.VBox``.
        """
        # The number of steps must be greater than one
        # for this app's logic to make sense.
        assert len(steps) > 1
        self.steps = steps

        # Unzip the steps to titles and widgets.
        self.titles, widgets = zip(*steps)

        # Initialize the accordion with the widgets ...
        self.accordion = ipw.Accordion(children=widgets)
        self._update_titles()
        self.accordion.observe(self._observe_selected_index, 'selected_index')

        # Watch for changes to each step's state
        for widget in widgets:
            assert widget.has_trait('state')
            widget.observe(self._update_step_state, names=['state'])

        self.reset_button = ipw.Button(
            description='Reset',
            icon='undo',
            layout=ipw.Layout(width='auto', flex='1 1 auto'),
            disabled=True)
        self.reset_button.on_click(self._on_click_reset_step)

        # Create a back-button, to switch to the previous step when possible:
        self.back_button = ipw.Button(
            description='Previous step',
            icon='step-backward',
            layout=ipw.Layout(width='auto', flex='1 1 auto'),
            disabled=True)
        self.back_button.on_click(self._on_click_back_button)

        # Create a next-button, to switch to the next step when appropriate:
        self.next_button = ipw.Button(
            description='Next step',
            icon='step-forward',
            layout=ipw.Layout(width='auto', flex='1 1 auto'),
            disabled=True)
        self.next_button.on_click(self._on_click_next_button)

        self.footer = ipw.HBox(children=[self.back_button, self.reset_button, self.next_button])

        super().__init__(children=[self.footer, self.accordion], **kwargs)

    def _update_titles(self):
        """Refresh every accordion title with the icon of its step's state."""
        # Build the icon table once so all titles of one refresh share the
        # same spinner frame.
        icons = self.icons()
        for i, (title, widget) in enumerate(zip(self.titles, self.accordion.children)):
            icon = icons.get(widget.state, str(widget.state).upper())
            self.accordion.set_title(i, f"{icon} Step {i+1}: {title}")

    def _update_step_state(self, change):
        """React to a state change of any step: refresh titles and buttons."""
        self._update_titles()
        self._update_buttons()

    def _observe_selected_index(self, change):
        "Activate/deactivate the next-button based on which step is selected."
        # Registered explicitly on the accordion in __init__; this VBox has
        # no 'selected_index' trait of its own to observe.
        self._update_buttons()

    def _update_buttons(self):
        """Enable/disable the footer buttons for the currently selected step."""
        with self.hold_trait_notifications():
            index = self.accordion.selected_index
            if index is None:
                # No step is expanded: nothing to navigate from or to reset.
                self.back_button.disabled = True
                self.next_button.disabled = True
                self.reset_button.disabled = True
            else:
                first_step_selected = index == 0
                last_step_selected = index + 1 == len(self.accordion.children)
                selected_widget = self.accordion.children[index]
                next_widget = None if last_step_selected else self.accordion.children[index + 1]

                self.back_button.disabled = \
                    first_step_selected or selected_widget.state != self.State.READY
                self.next_button.disabled = \
                    last_step_selected or selected_widget.state != self.State.SUCCESS

                self.reset_button.disabled = not (  # reset possible when:
                    hasattr(selected_widget, 'reset')
                    and selected_widget.state in self.SEALED_STATES
                    and (last_step_selected or next_widget.state not in self.SEALED_STATES)
                )

    def _on_click_reset_step(self, _):
        """Call ``reset()`` on the currently selected step."""
        with self.hold_sync():
            self.accordion.children[self.accordion.selected_index].reset()

    def _on_click_back_button(self, _):
        """Select the previous step."""
        self.accordion.selected_index -= 1

    def _on_click_next_button(self, _):
        """Select the next step."""
        self.accordion.selected_index += 1


class WizardAppStep(traitlets.HasTraits):
    """Mixin providing the ``state`` trait that ``WizardApp`` asserts on every step."""
        
    # Enum-valued trait restricted to the members of WizardApp.State.
    state = traitlets.UseEnum(WizardApp.State)
        
        
class StructureFileUploadWidget(ipw.VBox):
    """File-upload control publishing the upload as a ``(name, content)`` tuple.

    The ``file`` trait is updated from the first entry of the upload
    widget's value whenever that value changes.
    """

    file = traitlets.Tuple(traitlets.Unicode(), traitlets.Bytes())

    def __init__(self):
        self.description = ipw.Label("Upload a structure file:")
        self.file_upload = ipw.FileUpload(description="Select file", multiple=False)
        self.file_upload.observe(self.on_file_uploaded, names='value')
        self.supported_formats = ipw.HTML(
            """<a href="https://wiki.fysik.dtu.dk/ase/_modules/ase/io/formats.html" target="_blank">
            Supported formats
            </a>""")

        widgets = [self.description, self.file_upload, self.supported_formats]
        super().__init__(children=widgets)

    def on_file_uploaded(self, change):
        """Copy the first uploaded entry (if any) into the ``file`` trait."""
        first_entry = next(iter(change['new'].items()), None)
        if first_entry is not None:
            filename, payload = first_entry
            self.file = (filename, payload['content'])

    def freeze(self):
        """Disable the upload control."""
        self.file_upload.disabled = True

    def reset(self):
        """Re-enable the upload control."""
        self.file_upload.disabled = False
            

class SelectionStructureUploadWidget(ipw.Dropdown):
    """Dropdown of example files publishing the selection as ``(name, content)``.

    Each option is a ``(label, path)`` pair. Selecting an option reads the
    file at ``path`` and stores ``(path, bytes)`` in the ``file`` trait; the
    hint entry (whose value is ``None``) yields ``('', None)``.
    """

    file = traitlets.Tuple(traitlets.Unicode(), traitlets.Bytes(allow_none=True))

    def __init__(self, options=None, hint='Select structure', **kwargs):
        """
        :param options: iterable of ``(label, path)`` pairs, or ``None`` for
            an empty dropdown.
        :param hint: label of the placeholder entry put in front of the options.
        :param kwargs: forwarded to ``ipw.Dropdown``.
        """
        if options is None:
            options = []
        else:
            # Build a new list: the caller's list must not be modified.
            options = [(hint, None), *options]

        super().__init__(options=options, **kwargs)

    @traitlets.observe('index')
    def _observe_index(self, change):
        """Load the file belonging to the newly selected option."""
        index = change['new']
        fn = None if index is None else self.options[index][1]
        if fn is None:
            self.file = '', None
        else:
            with open(fn, 'rb') as file:
                self.file = fn, file.read()

    def reset(self):
        """Select the hint entry again and re-enable the dropdown."""
        with self.hold_trait_notifications():
            # Index 0 only exists when there is at least one option.
            self.index = 0 if self.options else None
            self.disabled = False

    def freeze(self):
        """Disable the dropdown."""
        self.disabled = True


class StructureUploadComboWidget(ipw.VBox, WizardAppStep):
    """Wizard step for selecting and confirming a structure.

    Structure files come from one or more "data importers": widgets with a
    ``file`` trait holding ``(name, content)`` and with ``freeze()`` and
    ``reset()`` methods. A selected file is parsed with ASE into
    ``structure``; pressing *Confirm* copies it to ``confirmed_structure``.
    """

    structure = traitlets.Instance(ase.atoms.Atoms, allow_none=True)
    confirmed_structure = traitlets.Instance(ase.atoms.Atoms, allow_none=True)

    def __init__(self, data_importers=None, examples=None, viewer=True, **kwargs):
        """
        :param data_importers: list of ``(title, widget)`` pairs; defaults to
            a single file-upload importer.
        :param examples: list of ``(label, path)`` pairs offered in an
            additional 'Examples' importer.
        :param viewer: whether to show an NGL viewer next to the importers.
        :param kwargs: forwarded to ``ipw.VBox``.
        """
        if data_importers is None:
            self.data_importers = [('Upload', StructureFileUploadWidget())]
        else:
            # Copy: appending the examples importer must not alter the
            # caller's list.
            self.data_importers = list(data_importers)

        if examples:
            self.example_widget = SelectionStructureUploadWidget(options=list(examples))
            self.data_importers.append(('Examples', self.example_widget))
        else:
            self.example_widget = None

        if len(self.data_importers) > 1:
            self.structure_sources_tab = ipw.Tab(children=[s[1] for s in self.data_importers])
            for i, source in enumerate(self.data_importers):
                self.structure_sources_tab.set_title(i, source[0])
        else:
            # A single importer is shown directly, without a tab bar.
            self.structure_sources_tab = self.data_importers[0][1]

        for data_importer in self.data_importers:
            data_importer[1].observe(self._on_structure_file_selection, 'file')

        if viewer:
            self.viewer = nglview.NGLWidget(width='300px', height='300px')
            self.viewer_box = ipw.Box(children=[self.viewer], layout=ipw.Layout(border='solid 1px'))
        else:
            self.viewer = None
            self.viewer_box = None

        self.structure_name_text = ipw.Text(
            placeholder='Structure',
            description='Selected:',
            disabled=True,
        )

        self.structure_sources_tab.layout = ipw.Layout(min_width='600px')

        grid = ipw.GridspecLayout(1, 3)
        if self.viewer_box is None:
            # Without a viewer the importers take the full width.
            grid[:, :] = self.structure_sources_tab
        else:
            grid[:, :2] = self.structure_sources_tab
            grid[:, 2] = self.viewer_box

        self.confirm_button = ipw.Button(
            description='Confirm',
            button_style='success',
            icon='check-circle',
            disabled=True,
            layout=ipw.Layout(width='auto'),
        )
        self.confirm_button.on_click(self.confirm)

        super().__init__(
            children=[grid, self.structure_name_text, self.confirm_button], **kwargs)

    @traitlets.default('state')
    def _default_state(self):
        return WizardApp.State.READY

    def _update_state(self):
        """Derive ``state`` from ``structure`` and ``confirmed_structure``."""
        if self.structure is None:
            if self.confirmed_structure is None:
                self.state = WizardApp.State.READY
            else:
                self.state = WizardApp.State.FAIL
        else:
            if self.confirmed_structure is None:
                self.state = WizardApp.State.CONFIGURED
            else:
                self.state = WizardApp.State.SUCCESS

    @traitlets.observe('structure')
    def _observe_structure(self, change):
        """Show the formula of the new structure and refresh state and view."""
        structure = change['new']
        with self.hold_trait_notifications():
            if structure is None:
                self.structure_name_text.value = ""
            else:
                self.structure_name_text.value = str(structure.get_chemical_formula())
            self._update_state()
            self.refresh_view()

    @traitlets.observe('confirmed_structure')
    def _observe_confirmed_structure(self, change):
        with self.hold_trait_notifications():
            self._update_state()

    @traitlets.observe('state')
    def _observe_state(self, change):
        """Freeze the importers on SUCCESS; enable *Confirm* only when CONFIGURED."""
        with self.hold_trait_notifications():
            state = change['new']
            if state is WizardApp.State.SUCCESS:
                self.freeze()
            self.confirm_button.disabled = self.state != WizardApp.State.CONFIGURED

    def freeze(self):
        """Disable every data importer."""
        # Iterate the importers themselves: with a single importer,
        # structure_sources_tab *is* that importer and its children are
        # plain widgets without a freeze() method.
        for _, importer in self.data_importers:
            importer.freeze()

    def _on_structure_file_selection(self, change):
        """Parse the newly selected file and store its first frame in ``structure``."""
        # ``import ase`` alone does not guarantee that the ``ase.io``
        # submodule is loaded.
        import ase.io

        assert change['name'] == 'file'
        fn, data = change['new']
        if data is None:
            self.structure = None
        else:
            basename, ext = os.path.splitext(fn)
            try:
                # The suffix is kept on the temporary file name for ASE's
                # reader.
                with tempfile.NamedTemporaryFile(suffix=ext) as file:
                    file.write(data)
                    file.flush()

                    traj = ase.io.read(file.name, index=':')
                # TODO: REACT TO len(traj) > 1
            except Exception as error:
                # Best effort: report the problem, keep the current structure.
                print(error)
            else:
                self.structure = traj[0]

    def refresh_view(self):
        """Replace the content of the viewer (if any) with the current structure."""
        if self.viewer is None:
            return
        # Note: viewer.clear() only removes the 1st component (TODO: FIX UPSTREAM!)
        # Iterate over a copy, since removing a component changes the id list.
        for comp_id in list(self.viewer._ngl_component_ids):
            self.viewer.remove_component(comp_id)
        if self.structure is not None:
            self.viewer.add_component(nglview.ASEStructure(self.structure))
            self.viewer.add_unitcell()

    def confirm(self, button=None):
        """Mark the currently selected structure as confirmed."""
        with self.hold_trait_notifications():
            self.confirmed_structure = self.structure

    def reset(self):  # unconfirm
        """Reset all importers and clear both structure traits."""
        with self.hold_trait_notifications():
            for _, importer in self.data_importers:
                importer.reset()
            self.confirmed_structure = None
            self.structure = None
    
def get_calc_job_output(process):
    """Yield the output lines written by the running calculations of ``process``.

    The generator follows one running calculation at a time (as reported by
    ``get_running_calcs``), reads the file named by its ``output_filename``
    attribute inside the remote folder and yields only lines not yielded
    before. It finishes once ``process`` is sealed.

    :param process: process node to follow; ``None`` yields nothing.
    """
    if process is None:
        return

    from aiidalab_widgets_base.process import get_running_calcs
    previous_calc_id = None
    num_lines = 0

    while not process.is_sealed:
        calc = None
        for calc in get_running_calcs(process):
            if calc.id == previous_calc_id:
                break
        else:
            if calc:
                # Switched to another calculation: its output file starts
                # from scratch, so the line offset must start over as well.
                previous_calc_id = calc.id
                num_lines = 0

        new_lines = []
        if calc and 'remote_folder' in calc.outputs:
            f_path = os.path.join(calc.outputs.remote_folder.get_remote_path(), calc.attributes['output_filename'])
            if os.path.exists(f_path):
                with open(f_path) as fobj:
                    new_lines = fobj.readlines()[num_lines:]
                # Yield after closing, so the file is not held open while
                # the generator is suspended.
                num_lines += len(new_lines)
                yield from new_lines

        if not new_lines:
            # Nothing new yet: back off instead of polling in a tight loop.
            sleep(0.1)
                      

class ProgressBarWidget(ipw.VBox):
    """A bar showing the progress of a process.

    The bar value and colour are looked up in ``correspondance`` from the
    process state; the state name is shown next to the bar.
    """

    def __init__(self, **kwargs):
        self._process = None
        # process state value -> (bar value, bar style)
        self.correspondance = {
            None: (0, 'warning'),
            "created": (0, 'info'),
            "running": (1, 'info'),
            "waiting": (1, 'info'),
            "killed": (2, 'danger'),
            "excepted": (2, 'danger'),
            "finished": (2, 'success'),
        }
        self.bar = ipw.IntProgress(  # pylint: disable=blacklisted-name
            value=0,
            min=0,
            max=2,
            step=1,
            bar_style='warning',  # 'success', 'info', 'warning', 'danger' or ''
            orientation='horizontal',
            layout=ipw.Layout(width="auto")
        )
        self.state = ipw.HTML(
            description="Calculation state:", value='',
            style={'description_width': '100px'},
        )
        super().__init__(children=[self.state, self.bar], **kwargs)

    def update(self):
        """Update the bar."""
        # Read the state once so bar and label always agree.
        current_state = self.current_state
        self.bar.value, self.bar.bar_style = self.correspondance[current_state]
        if current_state is None:
            self.state.value = 'N/A'
        else:
            self.state.value = current_state.capitalize()

    @property
    def process(self):
        """The process whose state is displayed (may be ``None``)."""
        return self._process

    @process.setter
    def process(self, value):
        self._process = value
        self.update()

    @property
    def current_state(self):
        """State value of the process, or ``None`` when it is not available."""
        if self.process is None:
            return None
        process_state = self.process.process_state
        # A process may not have a state yet; treat that like "no process".
        return None if process_state is None else process_state.value


class LogOutputWidget(ipw.VBox):
    """Show the last few lines of a log plus the full log in a collapsed accordion."""

    def __init__(self, title='Output:', num_lines_shown=3, **kwargs):
        """
        :param title: label displayed above the log excerpt.
        :param num_lines_shown: number of most recent lines in the excerpt.
        :param kwargs: forwarded to ``ipw.VBox``.
        """
        self.description = ipw.Label(value=title)
        self.last_lines = ipw.HTML()

        self.lines = []
        self.lines_shown = deque([''] * num_lines_shown, maxlen=num_lines_shown)

        self.raw_log = ipw.Textarea(
            layout=ipw.Layout(
                width='auto',
                height='auto',
                display='flex',
                flex='1 1 auto',
            ),
            disabled=True)

        self.accordion = ipw.Accordion(children=[self.raw_log])
        self.accordion.set_title(0, 'Raw log')
        self.accordion.selected_index = None

        self._update()
        super().__init__(children=[self.description, self.last_lines, self.accordion], **kwargs)

    def clear(self):
        """Drop all stored lines and refresh the display."""
        self.lines.clear()
        self.lines_shown.clear()
        self._update()

    def append_line(self, line):
        """Append ``line`` (stripped of surrounding whitespace) to the log."""
        stripped = line.strip()
        self.lines.append(stripped)
        self.lines_shown.append("{:03d}: {}".format(len(self.lines), stripped))
        self._update()

    @staticmethod
    def _format_code(text):
        """Wrap ``text`` in a styled ``<pre>`` element, escaping HTML characters."""
        # The result is rendered by an HTML widget, so characters such as
        # '<' or '&' in the log must not be interpreted as markup.
        from html import escape
        return '<pre style="background-color: #1f1f2e; color: white;">{}</pre>'.format(escape(text))

    def _update(self):
        """Render the excerpt (padded to a constant height) and the raw log."""
        with self.hold_trait_notifications():
            lines_to_show = self.lines_shown.copy()
            while len(lines_to_show) < self.lines_shown.maxlen:
                lines_to_show.append(' ')

            self.last_lines.value = self._format_code('\n'.join(lines_to_show))
            self.raw_log.value = '\n'.join(self.lines)
            
            
class ProcessStatusWidget(ipw.VBox):
    """Progress bar, log excerpt and identifier of a single process."""

    process = traitlets.Instance(ProcessNode, allow_none=True)

    def __init__(self, **kwargs):
        self._process_output = None  # Generator of output lines of the process

        # Widgets
        self.progress_bar = ProgressBarWidget()
        self.log_output = LogOutputWidget()
        self.process_id_text = ipw.Text(
            value='',
            description='Process:',
            layout=ipw.Layout(width='auto', flex="1 1 auto"),
            disabled=True,
        )
        # Leave the text field empty (instead of the literal "None") while no
        # process is set.
        ipw.dlink(
            (self, 'process'), (self.process_id_text, 'value'),
            transform=lambda proc: '' if proc is None else str(proc))

        super().__init__(children=[self.progress_bar, self.log_output, self.process_id_text], **kwargs)

    @traitlets.observe('process')
    def _observe_process(self, change):
        """Attach the sub-widgets to the new process; clear the log if it changed."""
        process = change['new']
        with self.hold_trait_notifications():
            # Without a process there is no output to follow.
            self._process_output = None if process is None else get_calc_job_output(process)
            self.progress_bar.process = process
            if process != change['old']:
                self.log_output.clear()

    def update(self):
        """Append the next available output line (if any) and refresh the bar."""
        if self._process_output is not None:
            try:
                self.log_output.append_line(next(self._process_output))
            except (TypeError, StopIteration):
                pass
        self.progress_bar.update()

        
class CodeSubmitWidget(ipw.VBox, WizardAppStep):
    """Base wizard step that configures, submits and monitors a process.

    The widget consists of a *Config* section (code, pseudopotential family,
    compute resources), a *Status* section and a *Results* section, plus
    *Submit* and *Skip* buttons. Subclasses implement ``submit()`` and may
    define a ``skip()`` method to enable the *Skip* button.
    """

    process = traitlets.Instance(ProcessNode, allow_none=True)

    # Subclasses override this with a ``skip(self, _)`` method to get a
    # working 'Skip' button; a falsy value hides the button.
    skip = False

    def _update_total_num_cpus(self, change):
        """Keep the read-only total in sync with nodes x cpus per node."""
        self.total_num_cpus.value = self.number_of_nodes.value * self.cpus_per_node.value

    def __init__(self, **kwargs):
        self.code_group = CodeDropdown(input_plugin='quantumespresso.pw', text="Select code", path_to_root='.')

        extra = {
            'style': {'description_width': '150px'},
            'layout': {'max_width': '200px'}
        }

        self.number_of_nodes = ipw.BoundedIntText(
            value=1, step=1, min=1,
            description="# nodes",
            disabled=False,
            **extra)
        self.cpus_per_node = ipw.BoundedIntText(
            value=1, step=1, min=1,
            description="# cpus per node",
            **extra)
        self.total_num_cpus = ipw.BoundedIntText(
            value=1, step=1, min=1,
            description="# total cpus",
            disabled=True,
            **extra)

        # Update the total # of CPUs int text:
        self.number_of_nodes.observe(self._update_total_num_cpus, 'value')
        self.cpus_per_node.observe(self._update_total_num_cpus, 'value')

        self.resources = ipw.VBox(children=[
            ipw.Label("Resources:"),
            self.number_of_nodes,
            self.cpus_per_node,
            self.total_num_cpus,
        ])

        self.pseudo_family = ipw.ToggleButtons(
            options={
                'SSSP efficiency': 'SSSP_efficiency_v1.0',
                'SSSP accuracy': 'SSSP_precision_v1.0',
            },
            description='Pseudopotential family:',
            style={'description_width': 'initial'}
        )

        # Clicking on the 'submit' button will trigger the execution of the
        # submit() method.
        self.submit_button = ipw.Button(
            description='Submit',
            icon='play',
            button_style='success',
            layout=ipw.Layout(width='auto', flex="1 1 auto"),
            disabled=True)
        self.submit_button.on_click(self._on_submit_button_clicked)

        # The 'skip' button is only shown when the skip() method is implemented.
        self.skip_button = ipw.Button(
            description='Skip',
            icon='fast-forward',
            button_style='info',
            layout=ipw.Layout(width='auto', flex="1 1 auto"),
            disabled=True)
        if self.skip:  # skip() method is implemented
            self.skip_button.on_click(self.skip)  # connect with skip_button
        else:  # skip() not implemented
            # hide the button
            self.skip_button.layout.visibility = 'hidden'

        # Place all buttons at the footer of the widget.
        self.buttons = ipw.HBox(children=[self.submit_button, self.skip_button])

        self.config_tabs = ipw.Tab(
            children=[self.code_group, self.pseudo_family, self.resources],
            layout=ipw.Layout(height='200px'),
        )
        self.config_tabs.set_title(0, 'Code')
        self.config_tabs.set_title(1, 'Pseudopotential')
        self.config_tabs.set_title(2, 'Compute resources')

        self.process_status = ProcessStatusWidget()
        ipw.dlink((self, 'process'), (self.process_status, 'process'))

        self.outputs_keys = ipw.Dropdown()
        self.outputs_keys.observe(self._refresh_outputs_view, names=['options', 'value'])
        self.output_control = ipw.HBox(children=[self.outputs_keys])

        self.output_area = ipw.Output(
            layout={
                'width': 'auto',
                'height': 'auto',
                'border': '1px solid black'})
        self.results_view = ipw.VBox(children=[self.output_control, self.output_area])

        self.accordion = ipw.Accordion(children=[self.config_tabs, self.process_status, self.results_view])
        self.accordion.set_title(0, 'Config')
        self.accordion.set_title(1, 'Status')
        self.accordion.set_title(2, 'Results (0)')

        self._freeze_config()

        # Callables invoked with this widget on every monitoring cycle.
        self.callbacks = []

        super().__init__(children=[self.accordion, self.buttons], **kwargs)

    def _update_state(self):
        "Update state based on the process state."
        if self.process is None:
            self.state = WizardApp.State.INIT
        else:
            process_state = load_node(self.process.id).process_state
            if process_state in (ProcessState.CREATED, ProcessState.RUNNING, ProcessState.WAITING):
                self.state = WizardApp.State.ACTIVE
            elif process_state in (ProcessState.EXCEPTED, ProcessState.KILLED):
                self.state = WizardApp.State.FAIL
            elif process_state is ProcessState.FINISHED:
                self.state = WizardApp.State.SUCCESS

    @traitlets.observe('state')
    def _observe_state(self, change):
        """Enable the buttons and the configuration only in the READY state."""
        is_ready = self.state is WizardApp.State.READY
        self.skip_button.disabled = not is_ready
        self.submit_button.disabled = not is_ready

        if change['new'] == WizardApp.State.ACTIVE:
            # Show the 'Status' section while the process is active.
            self.accordion.selected_index = 1

        self._freeze_config(change['new'] != WizardApp.State.READY)

    def _freeze_config(self, value=True):
        """Disable (``value=True``) or enable the configuration inputs."""
        self.code_group.dropdown.disabled = value
        self.number_of_nodes.disabled = value
        self.cpus_per_node.disabled = value
        self.pseudo_family.disabled = value

    @property
    def options(self):
        """Scheduler options built from the resources selected in the UI."""
        return {
            'max_wallclock_seconds': 3600*2,
            'resources': {
                'num_machines': self.number_of_nodes.value,
                'num_mpiprocs_per_machine': self.cpus_per_node.value}}

    def _refresh_outputs_keys(self):
        """Refresh the outputs dropdown; return the reloaded node or ``None``."""
        if self.process is None:
            self.outputs_keys.options = []
        else:
            process_node = load_node(self.process.id)
            self.outputs_keys.options = [str(o) for o in process_node.outputs]
            return process_node

    def _refresh_outputs_view(self, change=None):
        """Display the output selected in the dropdown in the results area."""
        # Imported explicitly rather than relying on the ``display`` builtin
        # that the notebook injects.
        from IPython.display import display

        self.accordion.set_title(2, f"Results ({len(self.outputs_keys.options)})")

        if change is None or change['name'] == 'options':
            selection_key = self.outputs_keys.value
        else:
            selection_key = change['new']

        with self.output_area:
            # Clear first to ensure that we are not showing the wrong thing.
            clear_output()

            if selection_key is not None:
                # Load data
                process_node = load_node(self.process.id)
                data = process_node.outputs[selection_key]
                output_viewer_widget = viewer(data)

                display(output_viewer_widget)

                if data and isinstance(data, StructureData):
                    output_viewer_widget._viewer.handle_resize()

    def _monitor_process(self):
        """Poll the process until it is sealed, then refresh the outputs.

        Runs in a background thread started by ``_observe_process``.
        """
        assert self.process is not None
        process_node = load_node(self.process.id)

        while not process_node.is_sealed:
            self.process_status.update()
            for callback in self.callbacks:
                callback(self)
            sleep(0.1)

        with self.hold_trait_notifications():
            self.process_status.update()
            self._refresh_outputs_keys()

        return process_node

    @traitlets.observe('process')
    def _observe_process(self, change):
        """Start following a newly assigned process.

        Returns the reloaded process node, or ``None`` when the process was
        unset; subclasses use the return value.
        """
        process = change['new']
        with self.hold_trait_notifications():
            if process is None:
                self._refresh_outputs_keys()
            else:
                process_node = load_node(process.id)
                self.process_status.process = process_node
                if process_node.is_sealed:
                    self._refresh_outputs_keys()
                else:
                    self.state = WizardApp.State.ACTIVE
                    monitor_thread = threading.Thread(target=self._monitor_process)
                    monitor_thread.start()
                return process_node

    def _on_submit_button_clicked(self, _):
        """Switch to ACTIVE and submit; restore the state if submitting fails."""
        self.submit_button.disabled = True
        self.state = WizardApp.State.ACTIVE
        try:
            self.submit()
        except Exception:
            # Do not leave the step stuck in ACTIVE when nothing was
            # submitted: recompute the state, then let the error propagate.
            self._update_state()
            raise

    def submit(self, _=None):
        """Submit the process; must be implemented by subclasses.

        The optional argument allows use as a button callback and matches
        the argument-less call in ``_on_submit_button_clicked``.
        """
        raise NotImplementedError()
    
    
class RelaxSubmitWidget(CodeSubmitWidget):
    """Wizard step that relaxes ``input_structure`` into ``output_structure``."""

    input_structure = traitlets.Instance(StructureData, allow_none=True)
    output_structure = traitlets.Instance(StructureData, allow_none=True)

    def _update_state(self):
        """Derive the state from the structures, or from the process if one is set."""
        if self.process is not None:
            super()._update_state()
            return

        # (input available, output available) -> state
        state_table = {
            (False, False): WizardApp.State.INIT,
            (False, True): WizardApp.State.FAIL,
            (True, False): WizardApp.State.READY,
            (True, True): WizardApp.State.SUCCESS,
        }
        key = (self.input_structure is not None, self.output_structure is not None)
        self.state = state_table[key]

    @traitlets.observe('input_structure')
    def _observe_input_structure(self, change):
        self._update_state()

    @traitlets.observe('output_structure')
    def _observe_output_structure(self, change):
        self._update_state()

    @traitlets.observe('process')
    def _observe_process(self, change):
        """Follow the new process and adopt its input structure."""
        with self.hold_trait_notifications():
            node = super()._observe_process(change)
            if node is not None:
                self.input_structure = node.inputs.structure
        return node

    def skip(self, _):
        """Skip the relaxation: pass the input structure on unchanged."""
        self.output_structure = self.input_structure

    def _refresh_outputs_keys(self):
        """Refresh the outputs and pick up the relaxed structure, if present."""
        node = super()._refresh_outputs_keys()
        if node is None:
            return
        with self.hold_trait_notifications():
            if 'output_structure' in node.outputs:
                self.output_structure = node.outputs.output_structure
            elif node.is_sealed:
                # Sealed without producing a structure.
                self.state = WizardApp.State.FAIL

    def submit(self, _=None):
        """Build and submit the relax work chain for ``input_structure``."""
        assert self.input_structure is not None

        workchain = WorkflowFactory('quantumespresso.pw.relax')
        builder = workchain.get_builder()
        builder.structure = self.input_structure

        base = builder.base
        base.pw.code = self.code_group.selected_code
        base.pw.parameters = DEFAULT_PARAMETERS
        base.pw.metadata.options = self.options
        base.kpoints_distance = Float(0.8)
        base.pseudo_family = Str(self.pseudo_family.value)

        self.process = submit(builder)

    def reset(self):
        """Forget the process and the output structure."""
        with self.hold_trait_notifications():
            self.process = None
            self.output_structure = None
        

class ComputeBandsSubmitWidget(CodeSubmitWidget):
    """Wizard step that computes the ``band_structure`` of ``input_structure``."""

    input_structure = traitlets.Instance(StructureData, allow_none=True)
    band_structure = traitlets.Instance(BandsData, allow_none=True)

    def _update_state(self):
        """Derive the state from input/bands, or from the process if one is set."""
        if self.process is not None:
            super()._update_state()
            return

        has_bands = self.band_structure is not None
        if self.input_structure is None:
            self.state = WizardApp.State.FAIL if has_bands else WizardApp.State.INIT
        else:
            self.state = WizardApp.State.SUCCESS if has_bands else WizardApp.State.READY

    @traitlets.observe('input_structure')
    def _observe_input_structure(self, change):
        self._update_state()

    @traitlets.observe('band_structure')
    def _observe_band_structure(self, change):
        self._update_state()

    @traitlets.observe('process')
    def _observe_process(self, change):
        """Follow the new process; clear the bands when the process is unset."""
        with self.hold_trait_notifications():
            node = super()._observe_process(change)
            if node is not None:
                self.input_structure = node.inputs.structure
            else:
                self.band_structure = None
        return node

    def _refresh_outputs_keys(self):
        """Refresh the outputs and pick up the band structure, if present."""
        node = super()._refresh_outputs_keys()
        if node is None:
            return
        with self.hold_trait_notifications():
            if 'band_structure' in node.outputs:
                self.band_structure = node.outputs.band_structure
            elif node.is_sealed:
                # Sealed without producing a band structure.
                self.state = WizardApp.State.FAIL

    def submit(self, _=None):
        """Build and submit the bands work chain for ``input_structure``."""
        assert self.input_structure is not None

        workchain = WorkflowFactory('quantumespresso.pw.bands')
        builder = workchain.get_builder()

        # The scf and bands steps share code, parameters, options and
        # pseudopotential family.
        scf = builder.scf
        scf.pw.code = self.code_group.selected_code
        scf.pw.parameters = DEFAULT_PARAMETERS
        scf.pw.metadata.options = self.options
        scf.kpoints_distance = Float(0.8)
        scf.pseudo_family = Str(self.pseudo_family.value)

        bands = builder.bands
        bands.pw.code = self.code_group.selected_code
        bands.pw.parameters = DEFAULT_PARAMETERS
        bands.pw.metadata.options = self.options
        bands.pseudo_family = Str(self.pseudo_family.value)

        builder.structure = self.input_structure

        self.process = submit(builder)

    def reset(self):
        """Forget the process."""
        self.process = None

# Create the application steps
structure_selection_step = StructureUploadComboWidget(
    examples=[
        ('Diamond', 'miscellaneous/structures/diamond.cif'),
        ('Gallium arsenide', 'miscellaneous/structures/GaAs.xyz'),
        ('Silicon', 'miscellaneous/structures/Si.xyz'),
        ('Silicon oxide', 'miscellaneous/structures/SiO2.xyz'),
    ],
    viewer=True)
# No keyword arguments: the step widgets define no 'allow_skip'/'has_next'
# options, and unknown keywords are only passed on to the widget base class.
relax_step = RelaxSubmitWidget()
compute_bands_step = ComputeBandsSubmitWidget()

# Link the application steps
ipw.dlink(
    (structure_selection_step, 'confirmed_structure'), (relax_step, 'input_structure'),
    transform=lambda atoms: None if atoms is None else StructureData(ase=atoms))
ipw.dlink((relax_step, 'output_structure'), (compute_bands_step, 'input_structure'))

# Propagate the configuration from the relax step to the compute band gaps step.
ipw.dlink((relax_step.code_group.dropdown, 'value'), (compute_bands_step.code_group.dropdown, 'value'))
ipw.dlink((relax_step.pseudo_family, 'value'), (compute_bands_step.pseudo_family, 'value'))
ipw.dlink((relax_step.number_of_nodes, 'value'), (compute_bands_step.number_of_nodes, 'value'))
ipw.dlink((relax_step.cpus_per_node, 'value'), (compute_bands_step.cpus_per_node, 'value'))

# Add the application steps to the application
app = WizardApp(
    steps=[
        ('Select structure', structure_selection_step),
        ('Relax', relax_step),
        ('Compute bands', compute_bands_step)])
# Refresh the step titles on every monitoring cycle of a running step.
relax_step.callbacks.append(lambda _: app._update_titles())
compute_bands_step.callbacks.append(lambda _: app._update_titles())


# TODO: REMOVE THESE LINES AFTER TESTING
# with app.hold_sync():
#     structure_selection_step.example_widget.index = 1
#     structure_selection_step.confirm()
#     relax_step.process = load_node('3d649e1f-cbe2-4009-bdf1-b44f7aa9a0f6')
#     compute_bands_step.process = load_node('7da3beb2-0b9e-489c-b406-dc51fe9c8e96')
# END OF REMOVE LINES AFTER TESTING

# Last expression of the cell: displays the application.
app