In [None]:
def show_results(outputs):
    """Display each of a fixed set of workflow outputs that is present.

    ``outputs`` can be any container supporting ``in`` and ``[]`` lookup.
    Entries are shown in the order listed below, via the notebook's
    ``display`` function.
    """
    wanted = ('retrieved', 'output_parameters', 'output_structure', 'band_structure')
    for name in (key for key in wanted if key in outputs):
        display(outputs[name])

In [None]:
import os
import threading
import tempfile
from enum import Enum
from pprint import pprint
from time import sleep
from collections import deque

import ase
import nglview
import ipywidgets as ipw
import traitlets

from aiidalab_widgets_base import CodeDropdown, ProgressBarWidget

from aiida.engine import submit
from aiida.orm import load_node, Dict, Float,Str, StructureData
from aiida.plugins import WorkflowFactory


# Default pw.x input parameters shared by every workflow submitted below:
# only the plane-wave cutoffs are set (presumably in Rydberg, Quantum
# ESPRESSO's unit for ecutwfc/ecutrho -- verify).
# NOTE(review): creating an AiiDA Dict node here requires a loaded profile.
DEFAULT_PARAMETERS = Dict(dict={"SYSTEM": {"ecutwfc": 50., "ecutrho": 200.}})


class StagesApp(ipw.Accordion):
    """Accordion that guides the user through a fixed sequence of stages.

    ``stages`` is a sequence of ``(title, widget)`` pairs. Every widget must
    have a ``state`` trait holding a :class:`StagesApp.State` member and must
    start in ``State.INIT``. When a stage switches to ``SUCCESS`` its title
    gets a check mark and the following stage (if any) is opened and set to
    ``READY``. When a stage switches to ``FAIL`` its title gets a cross mark.
    A widget with a ``next_button`` attribute gets that button wired to open
    the stage that follows it.
    """

    _ICON_SEPARATOR = '\u2000'  # EN QUAD (a wide space) between title and status icon
    _ICON_SUCCESS = '\u2714'    # ✔
    _ICON_FAIL = '\u2715'       # ✕

    class State(Enum):
        "Values of the ``state`` trait that every stage widget must have."
        INIT = 0  # initialized state

        CONFIGURED = 1
        READY = 2
        ACTIVE = 3
        SUCCESS = 4

        # All error states have negative codes
        FAIL = -1

    def __init__(self, stages, **kwargs):
        """Build the accordion from ``(title, widget)`` pairs.

        Raises:
            ValueError: if fewer than two stages are given, or if a stage
                widget is not in ``State.INIT``.
        """
        # The stage-advancing logic only makes sense with more than one stage.
        if len(stages) <= 1:
            raise ValueError("StagesApp requires more than one stage.")
        self.stages = stages

        titles, widgets = zip(*stages)

        # Observe every widget's state and wire optional "next" buttons.
        for title, widget in zip(titles, widgets):
            if widget.state != self.State.INIT:
                raise ValueError(f"Stage {title!r} must start in State.INIT.")
            widget.observe(self._update_stage_state, names=['state'])
            if hasattr(widget, 'next_button'):
                widget.next_button.on_click(self._next_button_clicked)

        super().__init__(children=widgets, **kwargs)

        # Enumerated titles; the trailing separator marks where the status icon goes.
        for i, title in enumerate(titles):
            self.set_title(i, f"{i+1}. {title}{self._ICON_SEPARATOR}")

    def set_title_icon(self, index, icon):
        """Replace the status icon that follows the separator in title ``index``."""
        head, sep, _ = self.get_title(index).rpartition(self._ICON_SEPARATOR)
        self.set_title(index, head + sep + icon)

    def _update_stage_state(self, change):
        """React to a stage widget's ``state`` change (title icon, next stage)."""
        widget = change['owner']
        widget_index = self.children.index(widget)

        if change['new'] == self.State.SUCCESS:
            self.set_title_icon(widget_index, self._ICON_SUCCESS)
            next_index = widget_index + 1
            if next_index < len(self.children):
                # Open the following stage and enable it.
                self.selected_index = next_index
                self.children[next_index].state = self.State.READY
            else:
                # Last stage: keep it open, but do NOT reset its state, which
                # would re-enable its buttons after a successful run.
                self.selected_index = widget_index
        elif change['new'] == self.State.FAIL:
            self.set_title_icon(widget_index, self._ICON_FAIL)

    def _next_button_clicked(self, button):
        """Open the stage that follows the one owning ``button`` (clamped to the last)."""
        for index, child in enumerate(self.children):
            if getattr(child, 'next_button', None) is button:
                current = index
                break
        else:
            current = self.selected_index
        if current is None:  # accordion collapsed and owner unknown
            return
        self.selected_index = min(current + 1, len(self.children) - 1)


class StructureFileUploadWidget(ipw.VBox):
    """Let the user upload one structure file from their computer.

    The upload is published through the ``file`` trait as a
    ``(filename, content_bytes)`` tuple.
    """

    file = traitlets.Tuple(traitlets.Unicode(), traitlets.Bytes())

    def __init__(self):
        self.description = ipw.Label("Upload a structure file:")
        self.supported_formats = ipw.HTML(
            """<a href="https://wiki.fysik.dtu.dk/ase/_modules/ase/io/formats.html" target="_blank">
            Supported formats
            </a>""")
        self.file_upload = ipw.FileUpload(description="Select file", multiple=False)
        self.file_upload.observe(self.on_file_uploaded, names='value')

        children = [self.description, self.file_upload, self.supported_formats]
        super().__init__(children=children)

    def on_file_uploaded(self, change):
        """Publish the first uploaded file (if any) through ``file``."""
        first = next(iter(change['new'].items()), None)
        if first is not None:
            filename, item = first
            self.file = (filename, item['content'])

    def freeze(self):
        """Prevent further uploads."""
        self.file_upload.disabled = True
            

class SelectionStructureUploadWidget(ipw.Dropdown):
    """Dropdown that loads one of a fixed set of example structure files.

    ``options`` is an iterable of ``(label, path)`` pairs. When it is given,
    a hint entry ``(hint, None)`` is prepended so that no file is selected
    initially. Selecting an entry that carries a path reads that file and
    publishes ``(path, content_bytes)`` through the ``file`` trait.
    """

    file = traitlets.Tuple(traitlets.Unicode(), traitlets.Bytes())

    def __init__(self, options=None, hint='Select structure', **kwargs):
        if options is None:
            options = []
        else:
            # Build a new list instead of inserting into the caller's list.
            options = [(hint, None), *options]

        super().__init__(options=options, **kwargs)

        self.observe(self._on_select, 'index')

    def _on_select(self, change):
        assert change['name'] == 'index'
        index = change['new']
        if index is None:
            return
        fn_selected = self.options[index][1]
        if fn_selected is None:  # the hint entry has no file attached
            return
        with open(fn_selected, 'rb') as file:
            self.file = fn_selected, file.read()

    def reset(self):
        """Return the dropdown to its first entry (the hint, when present).

        The ``file`` trait is deliberately left untouched. It only accepts
        ``(str, bytes)``, and its observers would try to parse any
        placeholder value assigned here.
        """
        self.index = 0 if self.options else None

    def freeze(self):
        """Prevent further selections."""
        self.disabled = True


class StructureUploadComboWidget(ipw.VBox):
    """Stage widget for choosing the input structure of the app.

    Combines one or more structure importers (shown as tabs when there are
    several), an optional nglview 3D viewer, a read-only field showing the
    chemical formula of the selected structure, and a "Confirm" button.
    Each importer is a ``(title, widget)`` pair whose widget exposes a
    ``file`` trait holding ``(filename, content_bytes)``. The file is parsed
    with ASE and its first image becomes :attr:`structure`. Confirming
    freezes the importers and sets :attr:`state` to ``SUCCESS``.
    """

    state = traitlets.UseEnum(StagesApp.State)

    def __init__(self, data_importers=None, examples=None, viewer=True, **kwargs):
        self._structure = None
        self._structure_node = None

        if data_importers is None:
            self.data_importers = [('Upload', StructureFileUploadWidget())]
        else:
            # Copy, so that appending the examples tab below does not mutate
            # the caller's list.
            self.data_importers = list(data_importers)

        if examples:
            self.example_widget = SelectionStructureUploadWidget(options=examples)
            self.data_importers.append(('Examples', self.example_widget))
        else:
            self.example_widget = None

        if len(self.data_importers) > 1:
            self.structure_sources_tab = ipw.Tab(children=[s[1] for s in self.data_importers])
            for i, source in enumerate(self.data_importers):
                self.structure_sources_tab.set_title(i, source[0])
        else:
            # A single importer is shown directly, without a tab bar.
            self.structure_sources_tab = self.data_importers[0][1]
        self.structure_sources_tab.layout = ipw.Layout(min_width='600px')

        for data_importer in self.data_importers:
            data_importer[1].observe(self._on_structure_file_selection, 'file')

        if viewer:
            self.viewer = nglview.NGLWidget(width='300px', height='300px')
            self.viewer_box = ipw.Box(children=[self.viewer], layout=ipw.Layout(border='solid 1px'))
        else:
            self.viewer = None
            self.viewer_box = None

        self.structure_name_text = ipw.Text(
            placeholder='Structure',
            description='Selected:',
            disabled=True,
        )

        grid = ipw.GridspecLayout(1, 3)
        if self.viewer_box is None:
            # No viewer: let the importers span the whole row (a None child
            # cannot be placed in the grid).
            grid[:, :] = self.structure_sources_tab
        else:
            grid[:, :2] = self.structure_sources_tab
            grid[:, 2] = self.viewer_box

        self.confirm_button = ipw.Button(
            description='Confirm',
            button_style='success',
            icon='check-circle',
            disabled=True,
            layout=ipw.Layout(width='auto'),
        )
        self.confirm_button.on_click(self.confirm)

        super().__init__(
            children=[grid, self.structure_name_text, self.confirm_button],
            **kwargs)

    @traitlets.observe('state')
    def _observe_state(self, change):
        self.confirm_button.disabled = self.state is not StagesApp.State.READY

    @property
    def structure(self):
        """The selected structure as an ASE object, or ``None``."""
        return self._structure

    @structure.setter
    def structure(self, value):
        self._structure = value
        if self._structure is None:
            self.confirm_button.disabled = True
            self.structure_name_text.value = ""
        else:
            self.structure_name_text.value = "{}".format(self.structure.get_chemical_formula())
            self.confirm_button.disabled = False
        self.refresh_view()

    def _on_structure_file_selection(self, change):
        """Parse an importer's ``(filename, bytes)`` with ASE; print errors instead of raising."""
        assert change['name'] == 'file'
        # ``import ase`` alone does not necessarily load the ase.io submodule.
        import ase.io

        fn, data = change['new']
        _, ext = os.path.splitext(fn)
        try:
            # Keep the original extension: ASE infers the format from it.
            with tempfile.NamedTemporaryFile(suffix=ext) as file:
                file.write(data)
                file.flush()

                traj = ase.io.read(file.name, index=':')
            # TODO: REACT TO len(traj) > 1
        except Exception as error:
            print(error)
        else:
            self.structure = traj[0]

    def refresh_view(self):
        """Redraw the viewer (if any) with the current structure and its unit cell."""
        if self.viewer is None:
            return
        # Note: viewer.clear() only removes the 1st component (TODO: FIX UPSTREAM!)
        # Iterate over a copy: remove_component() mutates _ngl_component_ids.
        for comp_id in list(self.viewer._ngl_component_ids):
            self.viewer.remove_component(comp_id)
        if self.structure is not None:
            self.viewer.add_component(nglview.ASEStructure(self.structure))
            self.viewer.add_unitcell()

    def confirm(self, button):
        """Freeze all importers and mark this stage as successful."""
        # Freeze the importers themselves. With a single importer, the
        # container's children are plain widgets without freeze().
        for source in self.data_importers:
            freeze = getattr(source[1], 'freeze', None)
            if freeze is not None:
                freeze()
        self.confirm_button.disabled = True
        self.confirm_button.icon = 'check-circle'
        self.state = StagesApp.State.SUCCESS

    @property
    def structure_node(self):
        # NOTE(review): _structure_node is never assigned in this class, so
        # this is always None unless set from outside.
        return self._structure_node

    
def get_calc_job_output(process):
    """Yield new lines of the output file of the calculation run by ``process``.

    Polls until ``process`` is sealed. The file read is
    ``<remote_folder path>/<calc attribute 'output_filename'>`` of the running
    calculation returned by ``get_running_calcs``. When the workflow moves on
    to a new calculation, that calculation's file is followed from its first
    line. Advancing the generator blocks until a new line is available or the
    process is sealed.

    NOTE(review): the remote path is opened with the local filesystem, so this
    presumably only works when the remote folder is reachable from this
    machine (e.g. a localhost computer) -- confirm.
    """
    import time

    from aiidalab_widgets_base.process import get_running_calcs

    previous_calc_id = None
    num_lines = 0

    while not process.is_sealed:
        calc = None
        for calc in get_running_calcs(process):
            if calc.id == previous_calc_id:
                break
        else:
            if calc:
                previous_calc_id = calc.id
                num_lines = 0  # new calculation: its output starts from line 0

        new_lines = []
        if calc and 'remote_folder' in calc.outputs:
            f_path = os.path.join(calc.outputs.remote_folder.get_remote_path(), calc.attributes['output_filename'])
            if os.path.exists(f_path):
                with open(f_path) as fobj:
                    new_lines = fobj.readlines()[num_lines:]
                num_lines += len(new_lines)
                # Yield after closing the file so it is not held open between lines.
                yield from new_lines

        if not new_lines:
            time.sleep(0.1)  # avoid busy-polling while there is no new output
                      

class ProgressBarWidget(ipw.VBox):
    """A bar showing the progress of a process.

    NOTE: this definition shadows the ``ProgressBarWidget`` imported from
    aiidalab_widgets_base at the top of the file.
    """

    def __init__(self, **kwargs):
        self._process = None
        # Maps a process-state name (or None) to (bar value, bar style).
        self.correspondance = {
            None: (0, 'warning'),
            "created": (0, 'info'),
            "running": (1, 'info'),
            "waiting": (1, 'info'),
            "killed": (2, 'danger'),
            "excepted": (2, 'danger'),
            "finished": (2, 'success'),
        }
        self.bar = ipw.IntProgress(  # pylint: disable=blacklisted-name
            value=0,
            min=0,
            max=2,
            step=1,
            bar_style='warning',  # 'success', 'info', 'warning', 'danger' or ''
            orientation='horizontal',
            layout=ipw.Layout(width="auto")
        )
        self.state = ipw.HTML(
            description="Calculation state:", value='',
            style={'description_width': '100px'},
        )
        super().__init__(children=[self.state, self.bar], **kwargs)

    def update(self):
        """Update the bar and the state label from the current process state."""
        self.bar.value, self.bar.bar_style = self.correspondance[self.current_state]
        if self.current_state is None:
            self.state.value = 'N/A'
        else:
            self.state.value = self.current_state.capitalize()

    @property
    def process(self):
        """The monitored process node (``None`` if not set)."""
        return self._process

    @process.setter
    def process(self, value):
        self._process = value
        self.update()

    @property
    def current_state(self):
        """Value of the process's ``process_state``, or ``None`` when unknown."""
        if self.process is None:
            return None
        # process_state can be None before the state attribute has been set.
        process_state = self.process.process_state
        return None if process_state is None else process_state.value


class LogOutputWidget(ipw.VBox):
    """Show the last few lines of a log plus a collapsible full raw log.

    Args:
        title: label shown above the log.
        num_lines_shown: number of most recent lines shown in the preview.
    """

    def __init__(self, title='Output:', num_lines_shown=3, **kwargs):
        self.description = ipw.Label(value=title)
        self.last_lines = ipw.HTML()

        self.lines = []
        # Fixed-size window of the most recent (numbered) lines, pre-filled with blanks.
        self.lines_shown = deque([''] * num_lines_shown, maxlen=num_lines_shown)

        self.raw_log = ipw.Textarea(
            layout=ipw.Layout(
                width='auto',
                height='auto',
                display='flex',
                flex='1 1 auto',
            ),
            disabled=True)

        self.accordion = ipw.Accordion(children=[self.raw_log])
        self.accordion.set_title(0, 'Raw log')
        self.accordion.selected_index = None

        self.update()
        super().__init__(children=[self.description, self.last_lines, self.accordion], **kwargs)

    def clear(self):
        """Remove all lines and blank the preview."""
        self.lines = []
        # Push one blank entry per preview slot so every shown line scrolls out.
        self.lines_shown.extend([''] * self.lines_shown.maxlen)
        self.update()

    def format_code(self, text):
        """Wrap ``text`` (already HTML-safe) in a dark ``<pre>`` block."""
        return '<pre style="background-color: #1f1f2e; color: white;">{}</pre>'.format(text)

    def append_line(self, line):
        """Append ``line`` (stripped) and show it numbered in the preview."""
        self.lines.append(line.strip())
        self.lines_shown.append("{:03d}: {}".format(len(self.lines), line.strip()))
        self.update()

    def update(self):
        """Refresh the preview and raw-log widgets from the stored lines."""
        import html

        # Escape log text so characters like '<' are shown, not parsed as markup.
        self.last_lines.value = self.format_code(html.escape('\n'.join(self.lines_shown)))
        self.raw_log.value = '\n'.join(self.lines)
            
            
class ProcessStatusWidget(ipw.VBox):
    """Progress bar plus streaming output log for a single AiiDA process."""

    def __init__(self, **kwargs):
        self._process = None
        self._process_output = None  # generator of output lines, or None
        self.progress_bar = ProgressBarWidget()
        self.log_output = LogOutputWidget()

        super().__init__(children=[self.progress_bar, self.log_output], **kwargs)

    @property
    def process(self):
        """The monitored process node (``None`` until one is assigned)."""
        return self._process

    @process.setter
    def process(self, value):
        self._process = value
        # The output generator dereferences ``process.is_sealed`` as soon as
        # it is advanced, so only create it for an actual process.
        self._process_output = None if value is None else get_calc_job_output(value)

        self.progress_bar.process = value

    def update(self):
        """Append at most one new output line to the log, then refresh the bar.

        Advancing the output generator blocks until a new line is available
        or the process is sealed.
        """
        if self._process_output is not None:
            try:
                self.log_output.append_line(next(self._process_output))
            except (TypeError, StopIteration):
                # StopIteration: no more output. TypeError is tolerated as a
                # best-effort guard so monitoring continues.
                pass
        self.progress_bar.update()

        
class CodeSubmitWidget(ipw.VBox):
    """Stage widget that configures and submits one Quantum ESPRESSO workflow.

    Shows a tabbed configuration panel (code, pseudopotential family, compute
    resources), a process status panel, and Submit / Next / Skip buttons whose
    enabled state follows :attr:`state`. The actual submit/skip handlers are
    attached by the caller via ``submit_button.on_click`` /
    ``skip_button.on_click``.

    Args:
        allow_skip: show the "Skip" button.
        has_next: show the "Next" button.
    """

    state = traitlets.UseEnum(StagesApp.State)

    def _update_total_num_cpus(self, change):
        # Keep the read-only "# total cpus" field equal to nodes x cpus-per-node.
        self.total_num_cpus.value = self.number_of_nodes.value * self.cpus_per_node.value

    def __init__(self, allow_skip=False, has_next=False, **kwargs):
        self.code_group = CodeDropdown(input_plugin='quantumespresso.pw', text="Select code", path_to_root='.')

        extra = {
            'style': {'description_width': '150px'},
            'layout': {'max_width': '200px'}
        }

        self.number_of_nodes = ipw.BoundedIntText(
            value=1, step=1, min=1,
            description="# nodes",
            disabled=False,
            **extra)
        self.cpus_per_node = ipw.BoundedIntText(
            value=1, step=1, min=1,
            description="# cpus per node",
            **extra)
        self.total_num_cpus = ipw.BoundedIntText(
            value=1, step=1, min=1,
            description="# total cpus",
            disabled=True,
            **extra)

        self.resources = ipw.VBox(children=[
            ipw.Label("Resources:"),
            self.number_of_nodes,
            self.cpus_per_node,
            self.total_num_cpus,
        ])

        self.pseudo_family = ipw.ToggleButtons(
            options={
                'SSSP efficiency': 'SSSP_efficiency_v1.0',
                'SSSP accuracy': 'SSSP_precision_v1.0',
            },
            description='Pseudopotential family:',
            style={'description_width': 'initial'}
        )

        self.skip_button = ipw.Button(
            description='Skip',
            icon='fast-forward',
            button_style='info',
            layout=ipw.Layout(width='auto', flex="1 1 auto"),
            disabled=True)

        self.submit_button = ipw.Button(
            description='Submit',
            icon='play',
            button_style='success',
            layout=ipw.Layout(width='auto', flex="1 1 auto"),
            disabled=True)

        self.next_button = ipw.Button(
            description='Next',
            icon='step-forward',
            layout=ipw.Layout(width='auto', flex='1 1 auto'),
            disabled=True)

        # Hidden (not removed) so the button row keeps its layout.
        self.skip_button.layout.visibility = 'visible' if allow_skip else 'hidden'
        self.next_button.layout.visibility = 'visible' if has_next else 'hidden'
        self.buttons = ipw.HBox(children=[self.submit_button, self.next_button, self.skip_button])

        self.tabs = ipw.Tab(
            children=[self.code_group, self.pseudo_family, self.resources],
            layout=ipw.Layout(height='200px'),
        )
        self.tabs.set_title(0, 'Code')
        self.tabs.set_title(1, 'Pseudopotential')
        self.tabs.set_title(2, 'Compute resources')

        self.status = ProcessStatusWidget()

        self.accordion = ipw.Accordion(children=[self.tabs, self.status])
        self.accordion.set_title(0, 'Config')
        self.accordion.set_title(1, 'Status')

        super().__init__(children=[self.accordion, self.buttons], **kwargs)

        # Update the total # of CPUs int text:
        self.number_of_nodes.observe(self._update_total_num_cpus, 'value')
        self.cpus_per_node.observe(self._update_total_num_cpus, 'value')

    def get_settings(self):
        """Return the current configuration as a plain dict."""
        return {
            'num_nodes': self.number_of_nodes.value,
            'num_cpus_per_node': self.cpus_per_node.value,
            # NOTE(review): the submit handlers use ``code_group.selected_code``;
            # confirm that CodeDropdown actually provides ``value``.
            'code_group': self.code_group.value,
            'pseudo_family': self.pseudo_family.value,
        }

    @traitlets.observe('state')
    def _observe_state(self, change):
        self.skip_button.disabled = self.state is not StagesApp.State.READY
        self.submit_button.disabled = self.state is not StagesApp.State.READY
        self.next_button.disabled = self.state is not StagesApp.State.SUCCESS

        # Show the status panel while a process is running.
        if change['new'] == StagesApp.State.ACTIVE:
            self.accordion.selected_index = 1

    @property
    def options(self):
        """Scheduler options (``metadata.options``) built from the resources tab."""
        return {
            'max_wallclock_seconds': 3600*2,
            'resources': {
                'num_machines': self.number_of_nodes.value,
                'num_mpiprocs_per_machine': self.cpus_per_node.value}}
        

# Stage 1: choose the input structure (file upload tab plus an "Examples" tab).
structure_widget = StructureUploadComboWidget(
    examples=[("Silicon oxide", 'common/structures/SiO2.xyz')],
    viewer=True)

# Stages 2 and 3: configure and submit the relax and bands workflows.
# NOTE(review): ``disabled=True`` is forwarded to ipw.VBox, which presumably
# has no such trait; the buttons are actually gated by each stage's ``state``.
# The SCF ("base") stage is currently disabled: submit_base/skip_base below
# reference ``base_stage`` and would raise NameError if wired up.
# base_stage = CodeSubmitWidget(disabled=True, allow_skip=True, has_next=True)
relax_stage = CodeSubmitWidget(disabled=True, allow_skip=True, has_next=True)
compute_bands_stage = CodeSubmitWidget(disabled=True)

# All stage widgets must still be in State.INIT here (StagesApp checks this).
app = StagesApp(
    stages=[
        ('Select structure', structure_widget),
#         ('SCF', base_stage),
        ('Relax', relax_stage),
        ('Compute bands', compute_bands_stage)])

# ************* IMPLEMENT THE APP LOGIC ************* :

def monitor_base_stage_progress(base_stage, process_node):
    """Poll ``process_node`` until it is sealed, then set the stage's final state.

    Meant to run in a background thread. The stage's status widget is
    refreshed every 0.1 s. Afterwards the stage becomes ``SUCCESS`` if the
    process finished OK and ``FAIL`` otherwise, like the other stage monitors.
    """
    base_stage.status.process = process_node
    while not process_node.is_sealed:
        base_stage.status.update()
        sleep(0.1)
    base_stage.status.update()

    # TODO: also propagate a ``structure`` to base_stage (as skip_base does)
    # once the relax stage consumes it again.
    if process_node.is_finished_ok:
        base_stage.state = StagesApp.State.SUCCESS
    else:
        base_stage.state = StagesApp.State.FAIL


def monitor_relax_stage_progress(relax_stage, process_node):
    """Poll the relax process until it is sealed and record its outcome.

    Meant to run in a background thread. On completion the relaxed
    ``output_structure`` is stored on ``relax_stage.structure`` and the stage
    becomes ``SUCCESS``. Without that output the stage becomes ``FAIL``.
    """
    status = relax_stage.status
    status.process = process_node
    while not process_node.is_sealed:
        status.update()
        sleep(0.1)
    status.update()

    outputs = process_node.outputs
    if 'output_structure' not in outputs:
        relax_stage.state = StagesApp.State.FAIL
        return
    relax_stage.structure = outputs.output_structure
    relax_stage.state = StagesApp.State.SUCCESS

    
def monitor_compute_bands_stage_progress(compute_bands_stage, process_node):
    """Poll the bands process until it is sealed and record its outcome.

    Meant to run in a background thread. The stage becomes ``SUCCESS`` only
    if the process finished OK and produced a ``band_structure`` output, and
    ``FAIL`` otherwise.
    """
    compute_bands_stage.status.process = process_node
    while not process_node.is_sealed:
        compute_bands_stage.status.update()
        sleep(0.1)
    compute_bands_stage.status.update()

    # TODO: Trying to figure out what's going on here.
    pprint(process_node.outputs)

    if process_node.is_finished_ok and 'band_structure' in process_node.outputs:
        compute_bands_stage.state = StagesApp.State.SUCCESS
    else:
        compute_bands_stage.state = StagesApp.State.FAIL
    

def submit_base(button):
    """Submit a ``quantumespresso.pw.base`` (SCF) workflow and start monitoring it.

    On a submission error the error is printed and the stage is set to
    ``FAIL``. Otherwise the stage becomes ``ACTIVE`` and a background thread
    follows the process.
    """
    builder = WorkflowFactory('quantumespresso.pw.base').get_builder()

    builder.pw.code = base_stage.code_group.selected_code
    builder.pw.parameters = DEFAULT_PARAMETERS
    builder.pw.metadata.options = base_stage.options
    # Top-level ``kpoints_distance`` input (``kpoints`` itself expects a KpointsData).
    builder.kpoints_distance = Float(0.8)
    builder.pseudo_family = Str(base_stage.pseudo_family.value)
    builder.pw.structure = StructureData(ase=structure_widget.structure)

    try:
        process_node = load_node(submit(builder).id)
    except Exception as error:
        print(error)
        base_stage.state = StagesApp.State.FAIL
    else:
        base_stage.state = StagesApp.State.ACTIVE
        update_thread = threading.Thread(target=monitor_base_stage_progress, args=(base_stage, process_node))
        update_thread.start()

    
def skip_base(button):
    """Skip the SCF stage: pass the selected structure on unchanged."""
    structure = StructureData(ase=structure_widget.structure)
    base_stage.structure = structure
    base_stage.state = StagesApp.State.SUCCESS


def submit_relax(button):
    """Submit a ``quantumespresso.pw.relax`` workflow and start monitoring it.

    The stage becomes ``ACTIVE`` and a background thread follows the process.
    """
    stage = relax_stage
    builder = WorkflowFactory('quantumespresso.pw.relax').get_builder()
    builder.base.pw.code = stage.code_group.selected_code
    builder.base.pw.parameters = DEFAULT_PARAMETERS
    builder.base.pw.metadata.options = stage.options
    builder.base.kpoints_distance = Float(0.8)
    builder.base.pseudo_family = Str(stage.pseudo_family.value)
    # Once the SCF stage is re-enabled, the input would come from it:
    #     builder.structure = base_stage.structure
    builder.structure = StructureData(ase=structure_widget.structure)

    process_node = load_node(submit(builder).id)

    stage.state = StagesApp.State.ACTIVE

    threading.Thread(
        target=monitor_relax_stage_progress,
        args=(stage, process_node),
    ).start()

    
def skip_relax(button):
    """Skip relaxation: the selected structure is used as-is."""
    structure = StructureData(ase=structure_widget.structure)
    relax_stage.structure = structure
    relax_stage.state = StagesApp.State.SUCCESS

    
def submit_compute_bands(button):
    """Submit a ``quantumespresso.pw.bands`` workflow on the relaxed structure.

    The SCF and bands sub-steps share the stage's code, parameters, options
    and pseudopotential family. The stage becomes ``ACTIVE`` and a background
    thread follows the process.
    """
    stage = compute_bands_stage
    builder = WorkflowFactory('quantumespresso.pw.bands').get_builder()

    # SCF sub-step.
    builder.scf.pw.code = stage.code_group.selected_code
    builder.scf.pw.parameters = DEFAULT_PARAMETERS
    builder.scf.pw.metadata.options = stage.options
    builder.scf.kpoints_distance = Float(0.8)
    builder.scf.pseudo_family = Str(stage.pseudo_family.value)

    # Bands sub-step.
    builder.bands.pw.code = stage.code_group.selected_code
    builder.bands.pw.parameters = DEFAULT_PARAMETERS
    builder.bands.pw.metadata.options = stage.options
    builder.bands.pseudo_family = Str(stage.pseudo_family.value)

    builder.structure = relax_stage.structure

    process_node = load_node(submit(builder).id)
    stage.state = StagesApp.State.ACTIVE

    threading.Thread(
        target=monitor_compute_bands_stage_progress,
        args=(stage, process_node),
    ).start()


# Wire the stage buttons to the submit/skip handlers defined above.
# The SCF ("base") stage is disabled, so its handlers stay unwired.
# base_stage.skip_button.on_click(skip_base)
# base_stage.submit_button.on_click(submit_base)
relax_stage.skip_button.on_click(skip_relax)
relax_stage.submit_button.on_click(submit_relax)
compute_bands_stage.submit_button.on_click(submit_compute_bands)

# FOR TESTING PURPOSES ONLY!!
# structure_widget.example_widget.index = 1

# Last expression of the notebook cell, so the app is rendered as its output.
app