In [290]:
import os
import ipywidgets as ipw
import nglview
from pprint import pprint
from traitlets import Bool, Bytes, Dict, Unicode, Tuple, Instance
import ase
import tempfile

from traitlets import Int, Bool
from aiidalab_widgets_base import CodeDropdown


class StructureFileUploadWidget(ipw.VBox):
    """Let the user upload a single structure file from their machine.

    The uploaded file is published on the ``file`` trait as a
    ``(filename, raw_bytes)`` tuple.
    """

    # (filename, raw file contents) of the most recently uploaded file.
    file = Tuple(Unicode(), Bytes())

    def __init__(self):
        self.description = ipw.Label("Upload a structure file:")
        self.file_upload = ipw.FileUpload(
            description="Select file",
            multiple=False,
        )
        self.supported_formats = ipw.HTML(
            """<a href="https://wiki.fysik.dtu.dk/ase/_modules/ase/io/formats.html" target="_blank">
            Supported formats
            </a>""")
        self.file_upload.observe(self.on_file_uploaded, names='value')

        layout_children = [self.description, self.file_upload, self.supported_formats]
        super().__init__(children=layout_children)

    def on_file_uploaded(self, change):
        """Publish the first entry of the upload's value as ``(filename, content)``."""
        first_entry = next(iter(change['new'].items()), None)
        if first_entry is not None:
            name, details = first_entry
            self.file = (name, details['content'])

    def freeze(self):
        """Disable further uploads."""
        self.file_upload.disabled = True
            

class SelectionStructureUploadWidget(ipw.Dropdown):
    """Dropdown that loads one of a fixed set of structure files from disk.

    Selecting an entry reads the referenced file and publishes it on the
    ``file`` trait as a ``(filename, raw_bytes)`` tuple.
    """

    # (path, raw file contents) of the most recently selected entry.
    file = Tuple(Unicode(), Bytes())

    def __init__(self, options=None, hint='Select structure', **kwargs):
        """
        Args:
            options: list of ``(label, path)`` pairs. When given, a leading
                ``(hint, None)`` placeholder entry is prepended; the caller's
                list itself is not modified.
            hint: label of the placeholder entry.
            **kwargs: forwarded to ``ipw.Dropdown``.
        """
        if options is None:
            options = []
        else:
            # Build a new list instead of inserting in place, so the caller's
            # list is left untouched (and re-using it does not stack hints).
            options = [(hint, None)] + list(options)

        super().__init__(options=options, **kwargs)

        self.observe(self._on_select, 'index')

    def _on_select(self, change):
        """Read the newly selected file and publish it on ``file``."""
        index = change['new']
        # ``index`` is None when nothing is selected.
        if index is None:
            return
        fn_selected = self.options[index][1]
        # The placeholder hint entry carries no path, so there is nothing to load.
        if fn_selected is None:
            return
        with open(fn_selected, 'rb') as file:
            self.file = fn_selected, file.read()

    def freeze(self):
        """Disable the dropdown so the selection can no longer change."""
        self.disabled = True
        

class StructureUploadComboWidget(ipw.VBox):
    """Combine structure importers with an optional viewer and a Confirm button.

    Each importer widget is expected to expose a ``file`` trait holding a
    ``(filename, raw_bytes)`` tuple and a ``freeze()`` method. When an importer
    publishes a file, it is parsed with ASE and the first frame becomes
    ``structure``. That frame is shown in the NGL viewer (if enabled) and
    enables the Confirm button. Pressing Confirm freezes all importers and sets
    ``confirmed`` to True.
    """

    # Set to True once the user has confirmed the selected structure.
    confirmed = Bool()

    def __init__(self, data_importers=None, examples=None, viewer=True, **kwargs):
        """
        Args:
            data_importers: list of ``(tab_title, widget)`` pairs. Defaults to a
                single file-upload importer. The list is copied, not modified.
            examples: optional list of ``(label, path)`` pairs; when non-empty,
                an extra 'Examples' importer is added.
            viewer: whether to create an NGL viewer for the selected structure.
            **kwargs: forwarded to ``ipw.VBox``.
        """
        self._structure = None

        if data_importers is None:
            self.data_importers = [('Upload', StructureFileUploadWidget())]
        else:
            # Copy so that appending the examples entry does not mutate the caller's list.
            self.data_importers = list(data_importers)

        if examples:
            example_widget = SelectionStructureUploadWidget(options=list(examples))
            self.data_importers.append(('Examples', example_widget))

        if len(self.data_importers) > 1:
            self.structure_sources_tab = ipw.Tab(children=[s[1] for s in self.data_importers])
            for i, source in enumerate(self.data_importers):
                self.structure_sources_tab.set_title(i, source[0])
        else:
            # A single importer is shown directly, without a surrounding Tab.
            self.structure_sources_tab = self.data_importers[0][1]

        for data_importer in self.data_importers:
            data_importer[1].observe(self._on_structure_file_selection, 'file')

        if viewer:
            self.viewer = nglview.NGLWidget(width='300px', height='300px')
            self.viewer_box = ipw.Box(children=[self.viewer], layout=ipw.Layout(border='solid 1px'))
        else:
            self.viewer = None
            self.viewer_box = None

        self.structure_name_text = ipw.Text(
            placeholder='Structure',
            description='Selected:',
            disabled=True,
        )

        self.structure_sources_tab.layout = ipw.Layout(min_width='600px')

        grid = ipw.GridspecLayout(1, 3)
        grid[:, :2] = self.structure_sources_tab
        if self.viewer_box is not None:
            # GridspecLayout needs a widget here; leave the column empty without a viewer.
            grid[:, 2] = self.viewer_box

        self.confirm_button = ipw.Button(
            description='Confirm',
            button_style='success',
            disabled=True,
            layout=ipw.Layout(width='auto'),
        )
        self.confirm_button.on_click(self.confirm)

        super().__init__(
            children=[grid, self.structure_name_text, self.confirm_button],
            **kwargs)

    @property
    def structure(self):
        """The currently selected structure (first frame read by ASE), or None."""
        return self._structure

    @structure.setter
    def structure(self, value):
        self._structure = value
        if self._structure is None:
            self.confirm_button.disabled = True
            self.structure_name_text.value = ""
        else:
            self.structure_name_text.value = "{}".format(self.structure.get_chemical_formula())
            self.confirm_button.disabled = False
        self.refresh_view()

    def _on_structure_file_selection(self, change):
        """Parse a newly published ``(filename, bytes)`` pair with ASE."""
        # Make sure the ase.io submodule is loaded; `import ase` alone may not import it.
        import ase.io

        fn, data = change['new']
        _, ext = os.path.splitext(fn)
        try:
            # ASE infers the file format from the extension, so keep it on the temp file.
            with tempfile.NamedTemporaryFile(suffix=ext) as file:
                file.write(data)
                file.flush()

                traj = ase.io.read(file.name, index=':')
        except Exception as error:
            # Best effort: report unreadable files without breaking the widget.
            print(error)
        else:
            # TODO: REACT TO len(traj) > 1 (only the first frame is used).
            self.structure = traj[0]

    def refresh_view(self):
        """Redraw the viewer with the current structure; no-op without a viewer."""
        if self.viewer is None:
            return
        # Note: viewer.clear() only removes the 1st component (TODO: FIX UPSTREAM!)
        # Iterate over a copy: remove_component() shrinks this list in place.
        for comp_id in list(self.viewer._ngl_component_ids):
            self.viewer.remove_component(comp_id)
        if self.structure is not None:
            self.viewer.add_component(nglview.ASEStructure(self.structure))
            self.viewer.add_unitcell()

    def confirm(self, button):
        """Freeze all importers, lock the button and flag the structure as confirmed."""
        # Freeze the importer widgets themselves. With a single importer,
        # structure_sources_tab IS that importer, so its children would be
        # plain sub-widgets rather than importers.
        for data_importer in self.data_importers:
            data_importer[1].freeze()
        self.confirm_button.disabled = True
        self.confirm_button.icon = 'check-circle'
        self.confirmed = True


class CodeSubmitWidget(ipw.VBox):
    """Select a code, a pseudopotential family and compute resources, plus a Submit button.

    ``get_settings()`` collects the current choices into a dict.
    """

    def __init__(self, **kwargs):
        """
        Args:
            **kwargs: forwarded to ``ipw.VBox``.
        """
        self.code_group = CodeDropdown(input_plugin='quantumespresso.pw', text="Select code", path_to_root='.')

        extra = {
            'style': {'description_width': '150px'},
            'layout': {'max_width': '200px'}
        }

        self.number_of_nodes = ipw.BoundedIntText(
            value=1, step=1, min=1,
            description="# nodes",
            disabled=False,
            **extra)
        self.cpus_per_node = ipw.BoundedIntText(
            value=1, step=1, min=1,
            description="# cpus per node",
            **extra)
        # Read-only product of the two fields above. A plain IntText is used
        # because BoundedIntText caps its value at `max` (100 by default),
        # which would silently clamp large totals.
        self.total_num_cpus = ipw.IntText(
            value=1,
            description="# total cpus",
            disabled=True,
            **extra)

        self.resources = ipw.VBox(children=[
            ipw.Label("Resources:"),
            self.number_of_nodes,
            self.cpus_per_node,
            self.total_num_cpus,
        ])

        self.pseudo_family = ipw.ToggleButtons(
            options={
                'SSSP efficiency': 'SSSP_efficiency_v1.0',
                'SSSP accuracy': 'SSSP_precision_v1.0',
            },
            description='Pseudopotential family:',
            style={'description_width': 'initial'}
        )

        self.submit_button = ipw.Button(
            description='Submit',
            icon='play-circle',
            button_style='success',
            layout=ipw.Layout(width='auto'),
            disabled=True,
        )

        super().__init__(children=[
                self.code_group,
                self.pseudo_family,
                self.resources,
                self.submit_button,
            ], **kwargs)

        # Keep the total # of cpus in sync with its two factors.
        self.number_of_nodes.observe(self._update_total_num_cpus, 'value')
        self.cpus_per_node.observe(self._update_total_num_cpus, 'value')

    def _update_total_num_cpus(self, change):
        """Recompute the total cpu count as nodes * cpus-per-node."""
        self.total_num_cpus.value = self.number_of_nodes.value * self.cpus_per_node.value

    def get_settings(self):
        """Return the current selections as a dict.

        Keys: 'num_nodes', 'num_cpus_per_node', 'code_group', 'pseudo_family'.
        """
        return {
            'num_nodes': self.number_of_nodes.value,
            'num_cpus_per_node': self.cpus_per_node.value,
            # NOTE(review): confirm CodeDropdown exposes `.value` in the installed
            # aiidalab_widgets_base version (some versions use `selected_code`).
            'code_group': self.code_group.value,
            'pseudo_family': self.pseudo_family.value,
        }
        

class StagesApp(ipw.VBox):
    """Show a sequence of titled stages as the sections of an accordion."""

    # Index of the current stage, -1 by default. The value is passed as the
    # trait's default; `.tag(default=...)` would only attach metadata.
    stage = Int(-1)

    def __init__(self, stages=None, **kwargs):
        """
        Args:
            stages: sequence of ``(title, widget)`` pairs, one accordion section
                each. Defaults to no stages.
            **kwargs: forwarded to ``ipw.VBox``.
        """
        self.stages = [] if stages is None else stages

        self.accordion = ipw.Accordion(children=[stage[1] for stage in self.stages])
        for i, stage in enumerate(self.stages):
            self.accordion.set_title(i, stage[0])

        super().__init__(children=[self.accordion], **kwargs)
           
structure_widget = StructureUploadComboWidget(
    examples=[("Silicon oxide", 'miscellaneous/structures/SiO2.xyz')],
    viewer=True)

# NOTE(review): CodeSubmitWidget forwards these keyword arguments to ipw.VBox,
# which has no `kind` or `disabled` trait, so they appear to be ignored;
# confirm what they are meant to configure.
base_stage = CodeSubmitWidget(kind='quantumespresso.pw.base', disabled=True)
scf_stage = CodeSubmitWidget(kind='quantumespresso.pw.base')
compute_bands_stage = CodeSubmitWidget(kind='quantumespresso.pw.bands')

app = ipw.Accordion(
    children=[structure_widget, base_stage, scf_stage, compute_bands_stage])
# Titles are matched to the children by position and numbered from 1.
# NOTE(review): this labels base_stage 'SCF' and scf_stage 'Relax'; confirm
# the mapping is intended.
for i, title in enumerate([
        'Select structure',
        'SCF',
        'Relax',
        'Compute Bands',
    ]):
    app.set_title(i, "{}. {}".format(i+1, title))


def structure_confirmed(changes):
    """Open the second accordion section once the structure is confirmed."""
    # Only advance when the flag turns True, not if it is ever reset to False.
    if changes['new']:
        app.selected_index = 1


structure_widget.observe(structure_confirmed, 'confirmed')

app

Accordion(children=(StructureUploadComboWidget(children=(GridspecLayout(children=(Tab(children=(StructureFileU…