# Binder Demo Backend

# Setup

## Add masci-tools to sys.path

Since `studentproject18w` is not installed as a module, we need to add its parent folder to `sys.path` manually.
Because we are running on mybinder.org rather than locally, we don't know the absolute path.
So we assume the notebook lives somewhere inside the `masci-tools` folder and walk up the directory tree until we find the folder that contains `studentproject18w`. That folder must be `masci-tools`.


In [None]:
import os
import sys

# Make the (not pip-installed) `studentproject18w` package importable by
# putting the folder that contains it (masci-tools) on sys.path.
dir_stupro = "studentproject18w"
dir_here = os.getcwd()
dir_masci = None

# Walk up from the current working directory until a folder containing
# `dir_stupro` is found. The filesystem root is detected by
# dirname(path) == path, which (unlike comparing against "/") also stops at
# Windows drive roots such as "C:\\" and includes the root in the search.
# Paths are inspected directly instead of os.chdir()-ing around, so the
# working directory is never changed -- not even when the search fails.
_candidate = dir_here
while True:
    try:
        _entries = os.listdir(_candidate)
    except PermissionError:
        # Unreadable ancestor: skip it and keep walking up.
        _entries = ()
    if dir_stupro in _entries:
        dir_masci = _candidate
        break
    _parent = os.path.dirname(_candidate)
    if _parent == _candidate:
        break  # reached the filesystem root without finding it
    _candidate = _parent

found_dir_masci = dir_masci is not None
if not found_dir_masci:
    raise ImportError(
        "Did not find masci-tools/{}. If not installed as module, "
        "need location of this folder.".format(dir_stupro))

# add (once) to sys path
if dir_masci not in sys.path:
    sys.path.append(dir_masci)

## set example input files

The demo input file is fixed and cannot be changed until `fileupload` is integrated into the dashboard.

In [None]:
### 2018 ############################################
# Demo input selection: exactly one of the configurations below should be
# active. Each one names an HDF5 band file and (optionally) DOS files, all
# given relative to the demo input folder.

# # NO DOS file:
# fig_widget_title = "Si, C defect"
# fig_title = ""
# filename = 'banddos.hdf'
# filenames_dos = []

# # NO DOS file:
# fig_widget_title = "Sodium"
# fig_title = ""
# filename = 'banddos_sodium.hdf'
# filenames_dos = []

# Active configuration -- Co, with 2 DOS files:
fig_widget_title = "Co conductor"
fig_title = ""
filename = os.path.join('Co', 'banddos_Co.hdf')
filenames_dos = [os.path.join('Co', dos_name) for dos_name in ('DOS.1', 'DOS.2')]

# Resolve every input file relative to the demo input folder.
demo_input_dir = 'binder_demo_input'
filepath = os.path.join(demo_input_dir, filename)
filepaths_dos = [os.path.join(demo_input_dir, dos_file) for dos_file in filenames_dos]

## Setup imports

In [None]:
# Jupyter, Python imports
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed, interact_manual
from IPython.display import display
import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy as np
import traitlets
from tkinter import Tk, filedialog

#  python 3interactive figures in a live IPython notebook session
# if run from jupyter-notebook: MAGICmatplotlib nbagg
# if run from jupyter-lab: MAGICmatplotlib widget
%matplotlib widget

# studentproject18ws imports
import logging
from studentproject18w.hdf.reader import Reader
from studentproject18w.hdf.recipes import Recipes
from studentproject18w.plot.matplot import BandDOSPlot, BandPlot

In [None]:
# Re-select the ipympl "widget" backend (it was already selected in the
# imports cell). Reloading the backend ensures the interactive plot works the
# first time; at least on the author's system it had to be loaded twice.
%matplotlib widget

## Read file, import data

- TODO: link with a widget-based file chooser.

In [None]:
# Read the HDF5 file at `filepath` with the FleurBands recipe and keep the
# result in `data` for the rest of the notebook.
data = None
reader = Reader(filepath=filepath)
# NOTE: `h5file` (the context manager's value) is not used; reading goes
# through the Reader object itself.
with reader as h5file:
    data = reader.read(recipe=Recipes.FleurBands)
    #
    # Note:
    # Inside the with statement (context manager),
    # all data attributes that are type h5py Dataset are available (in-file access).
    # When the statement is left, the HDF5 file gets closed and the datasets are closed.
    #
    # Use data outside the with-statement (in-memory access: all HDF5 datasets converted to numpy ndarrays):
    data.move_datasets_to_memory()

## Init data <--> plot interface

In [None]:
# init interface plot <--> data
# BandDOSPlot receives pyplot, the data read above and the list of DOS file
# paths; later cells read plter.icdv (widget specs / selection conversion)
# and plter.filepaths_dos (presumably the DOS paths passed here -- confirm).
plter = BandDOSPlot(plt, data, filepaths_dos)

# Define Widgets

### Define user input arguments: band plot

In [None]:
# Options shared by every slider of the band plot: redraw only when the
# handle is released (continuous_update=False), not while dragging.
slider_common_kwargs = dict(
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
)


def slider_bounds_kwargs(spec):
    """Translate one ``plter.icdv`` slider spec into widget keyword arguments.

    ``spec`` must provide ``initial``, ``min``, ``max``, ``step`` and ``label``.
    """
    return dict(
        value=spec.initial,
        min=spec.min,
        max=spec.max,
        step=spec.step,
        description=spec.label,
    )


# Energy window (y-limits). Its upper bound is the spec maximum plus 1.
ylim_kwargs = slider_bounds_kwargs(plter.icdv.ylim)
ylim_kwargs['max'] = plter.icdv.ylim.max + 1
select_ylim = widgets.FloatRangeSlider(
    **ylim_kwargs,
    **slider_common_kwargs,
    readout_format='.2f',
    layout=widgets.Layout(width="95%")
)

# The number of bands can be large, so a range slider is used instead of a
# selection widget. All further sliders share select_ylim's layout.
select_bands = widgets.IntRangeSlider(
    **slider_bounds_kwargs(plter.icdv.bands_slider),
    **slider_common_kwargs,
    readout_format='d',
    layout=select_ylim.layout
)

select_exponent = widgets.FloatSlider(
    **slider_bounds_kwargs(plter.icdv.exponent),
    **slider_common_kwargs,
    readout_format='.2f',
    layout=select_ylim.layout
)

select_marker_size = widgets.FloatSlider(
    **slider_bounds_kwargs(plter.icdv.marker_size),
    **slider_common_kwargs,
    readout_format='.2f',
    layout=select_ylim.layout
)

# "Compare 2 characters" toggle: starts disabled; the observers in the
# "Connect interdependent widgets" cell decide when it becomes usable.
select_compare_characters = widgets.Checkbox(
    value=False,
    description="",
    disabled=True,
)

# Orbital characters: everything selected initially.
select_characters = widgets.SelectMultiple(
    options=plter.icdv.characters,
    value=tuple(plter.icdv.characters),
    description='',
    disabled=False,
)

# Atom groups, shown by label but carrying the group value; all selected initially.
select_groups_options = list(zip(plter.icdv.group_labels, plter.icdv.groups))
select_groups = widgets.SelectMultiple(
    options=select_groups_options,
    value=tuple(plter.icdv.groups),
    description='',
    disabled=False,
)

select_ignore_apg = widgets.Checkbox(
    value=False,
    description="Ignore $N_g$",
    disabled=False
)

# Spin selection as a range slider so that it can live in the slider box
# (it replaced an earlier SelectMultiple version).
# NOTE(review): IntRangeSlider expects a (lower, upper) pair as `value`;
# confirm that plter.icdv.spins[0] has that form.
select_spins = widgets.IntRangeSlider(
    value=plter.icdv.spins[0],
    min=min(plter.icdv.spins),
    max=max(plter.icdv.spins),
    step=1,
    description='Spins',
    **slider_common_kwargs,
    layout=select_ylim.layout
)

### Define user input arguments: DOS plot
TODO: these should only be visible if at least one DOS file is loaded.

In [None]:
def make_dos_checkbox(description, value):
    """Return an enabled DOS-option checkbox with the given label and initial state."""
    return widgets.Checkbox(value=value, description=description, disabled=False)


# Which contributions the DOS plot shows; groups and interstitial are on by default.
select_dos_groups = make_dos_checkbox("Groups", True)
select_dos_interstitial = make_dos_checkbox("Interstitial", True)
select_dos_characters = make_dos_checkbox("All characters", False)



### Connect interdependent widgets

In [None]:
def selected_spins(spin_range):
    """Collapse a degenerate spin range to a single spin.

    (1, 1) -> (1,) and (0, 0) -> (0,); a real range such as (0, 1) is
    returned unchanged.
    """
    if len(set(spin_range)) <= 1:
        return (spin_range[0],)
    return spin_range


def update_compare_characters_enabled(characters, spin_range):
    """Enable "Compare 2" only for exactly 2 characters and a single spin.

    With both spins selected the comparison would still work (the plot then
    just uses the down spin), but it would be confusing, so it is disabled.
    Both observers below use this one rule, so the checkbox state never
    depends on which widget changed last.
    """
    single_spin = len(selected_spins(spin_range)) != 2
    select_compare_characters.disabled = not (len(characters) == 2 and single_spin)


def on_character_selection_change(change):
    """Observer for select_characters: re-evaluate whether "Compare 2" is allowed."""
    update_compare_characters_enabled(change.new, select_spins.value)


select_characters.observe(on_character_selection_change, names='value')


def on_spin_selection_change(change):
    """Observer for select_spins: re-evaluate whether "Compare 2" is allowed."""
    update_compare_characters_enabled(select_characters.value, change.new)


select_spins.observe(on_spin_selection_change, names='value')


def on_compare_character(change):
    """While "Compare 2" is checked, freeze the character and spin selections."""
    select_characters.disabled = change.new
    select_spins.disabled = change.new


select_compare_characters.observe(on_compare_character, names='value')

# Define Dashboard Layout

## v02: 4 HBoxes, box1=vertical sliders, (spin,DOS) integrated

In [None]:
# One shared Layout that lets widgets size themselves (overrides the defaults).
layout_auto = widgets.Layout(width='auto')


def override_layout(widgetss, set_desc_width_style=False, remove_description=False):
    """Give every widget in ``widgetss`` the shared auto-width layout.

    Parameters
    ----------
    widgetss : iterable of widgets
        Widgets to restyle ('ss' because ipywidgets is imported as ``widgets``).
    set_desc_width_style : bool
        If True, set ``description_width: initial`` so long descriptions are
        not shortened to "..." and stay aligned.
    remove_description : bool
        If True, blank out each widget's description.
    """
    for widget in widgetss:
        widget.layout = layout_auto
        if remove_description:
            widget.description = ''
        if set_desc_width_style:
            widget.style = {'description_width': 'initial'}


def hbox_widget_label(wig, lab, wig_left=True):
    """Return an HBox with widget ``wig`` and label ``lab`` (``wig`` left if ``wig_left``)."""
    return widgets.HBox([wig, lab] if wig_left else [lab, wig])


# --- Box 1: sliders, stacked vertically ---
selects_slider = [
    select_ylim,
    select_bands,
    select_exponent,
    select_marker_size,
]
# The spin slider is only shown if there is more than one spin.
if len(plter.icdv.spins) > 1:
    selects_slider.append(select_spins)
override_layout(selects_slider)
# The boxes sit side by side: sliders + characters (20%) + groups (20%),
# plus DOS (20%) when DOS files are loaded. Narrow the slider box when the
# DOS box is present so the widths add up to 100% in both cases.
box_sliders_width = '40%' if plter.filepaths_dos else '60%'
layout_box_sliders = widgets.Layout(
    display='flex',
    flex_flow='column',
    align_items='stretch',
    border='solid',
    width=box_sliders_width
)
box_sliders = widgets.Box(
    children=selects_slider,
    layout=layout_box_sliders
)


# --- Box 2: characters + "Compare 2" toggle ---
label_characters = widgets.Label("Characters")
label_compare_characters = widgets.Label("Compare 2")  # not placed in the box
selects_character = [
    label_characters,
    select_characters,
    select_compare_characters
]
# Label the checkbox through its own description; 'initial' width keeps it
# from being shortened to "..." and misaligned.
select_compare_characters.description = 'Compare 2'
select_compare_characters.style = {'description_width': 'initial'}

override_layout(selects_character)
select_characters.description = ''
layout_box_characters = widgets.Layout(
    display='flex',
    flex_flow='column',
    align_items='stretch',
    border='solid',
    width='20%'
)
box_characters = widgets.Box(
    children=selects_character,
    layout=layout_box_characters
)


# --- Box 3: atom groups + "Ignore N_g" ---
label_groups = widgets.Label("Atom Groups")
selects_group = [
    label_groups,
    select_groups,
    select_ignore_apg,
]
# 'initial' description width: otherwise the description is shortened to
# "..." and misaligned.
select_ignore_apg.style = {'description_width': 'initial'}

override_layout(selects_group)
select_groups.description = ''
box_groups = widgets.Box(
    children=selects_group,
    layout=layout_box_characters
)


# --- Box 4: DOS options ---
label_dos = widgets.Label("DOS")
selects_dos = [
    label_dos,
    select_dos_interstitial,
    select_dos_groups,
    select_dos_characters
]
override_layout(selects_dos, set_desc_width_style=True)
box_dos = widgets.Box(
    children=selects_dos,
    layout=layout_box_characters
)

# The DOS box is only added when DOS files were loaded.
dashboard_children = [box_sliders, box_characters, box_groups]
if plter.filepaths_dos:
    dashboard_children.append(box_dos)
dashboard_controls = widgets.HBox(dashboard_children)
# display(dashboard_controls)

In [None]:
import ipyvolume as ipv

In [None]:
import ipyvolume as ipv  # imported before first use so this cell is self-contained

# Hidden twin of select_groups that drives the 3D atoms plot; a two-way link
# keeps both selections in sync.
select_groups_dos = widgets.SelectMultiple(
    options=select_groups_options,
    value=tuple(plter.icdv.groups),
    description='',
    disabled=False,
)
# Hide it through its Layout: plain attributes such as `widget.display` are
# not widget traits, are never synced to the front end and have no effect.
select_groups_dos.layout.display = 'none'
select_groups_dos.layout.visibility = 'hidden'
link_groups = widgets.link((select_groups_dos, 'value'), (select_groups, 'value'))
fig_ipv = ipv.figure()
box_ipv = ipv.pylab.gcc()

# Atom coordinates; assumes data.atoms_position is (n_atoms, 3) -- TODO confirm.
x, y, z = data.atoms_position.T


def update_atoms_plot(groups):
    """Scatter all atoms, highlighting those whose group is in ``groups``."""
    select_groups.description = ""
    # Indices of atoms belonging to any selected group
    # (assumes data.atoms_group is a 1-D array with one group per atom).
    selected_atoms = np.flatnonzero(np.isin(data.atoms_group, groups)).tolist()

    # NOTE(review): this adds a scatter to the current ipyvolume figure on
    # every call rather than replacing the previous one -- confirm intended.
    ipv.scatter(x, y, z, size=5, marker="sphere", selected=selected_atoms, size_selected=8)
    ipv.show()


interactive_atoms_plot = interactive(update_atoms_plot, groups=select_groups_dos)


# Dashboard: Run

In [None]:
# Set up the figure; its size is derived from fig_scale and fig_ratio.
fig_scale = 0.8
fig_ratio = [12, 6]
figsize = [fig_scale * side for side in fig_ratio]

(fig, ax_bands, ax_dos) = plter.setup_figure(fig_ratio, fig_scale, fig_title=fig_widget_title)


def update_plot(characters, groups, bands, spins,
                unfolding_weight_exponent, marker_size,
                compare_characters, ylim, ignore_atoms_per_group,
                dos_groups, dos_interstitial, dos_characters):
    """Redraw the band/DOS figure for the current dashboard selection.

    Called by ``interactive`` with the current value of every linked widget.
    """
    # A degenerate spin range means one spin: (1, 1) -> (1,), (0, 0) -> (0,).
    if len(set(spins)) <= 1:
        spins = (spins[0],)

    masks = plter.icdv.convert_selections(bands, characters, groups)
    mask_bands, mask_characters, mask_groups = masks

    plter.plot_bandDOS(
        mask_bands, mask_characters, mask_groups, spins,
        unfolding_weight_exponent, compare_characters,
        ignore_atoms_per_group, marker_size,
        dos_groups, dos_interstitial, dos_characters,
        dos_fix_xlim=True, ylim=ylim
    )

    # Keep the selection lists free of descriptions after each redraw.
    select_characters.description = ''
    select_groups.description = ''

    plt.title(fig_title)
    plt.show()


# Map every dashboard widget onto the matching update_plot argument.
interactive_update_plot = interactive(
    update_plot,
    characters=select_characters,
    groups=select_groups,
    bands=select_bands,
    spins=select_spins,
    unfolding_weight_exponent=select_exponent,
    marker_size=select_marker_size,
    compare_characters=select_compare_characters,
    ylim=select_ylim,
    ignore_atoms_per_group=select_ignore_apg,
    dos_groups=select_dos_groups,
    dos_interstitial=select_dos_interstitial,
    dos_characters=select_dos_characters
)

display(dashboard_controls)

# interactive_atoms_plot's children are the SelectMultiple followed by the
# Output widget holding the ipyvolume plot: keep only the Output.
if len(interactive_atoms_plot.children) > 1:
    interactive_atoms_plot.children = (interactive_atoms_plot.children[1],)

## 3D Atoms plot

In [None]:
# 3D view of all atoms, highlighting the groups chosen in select_groups
# (ipv is imported in an earlier cell).
x, y, z = data.atoms_position.T


@interact(groups=select_groups)
def update_atoms_plot(groups):
    """Draw a fresh ipyvolume scatter of all atoms, highlighting the selected groups."""
    # Flatten every index array np.where returns, group by group.
    selected_atoms = [
        atom_index
        for group in groups
        for index_array in np.where(data.atoms_group == group)
        for atom_index in index_array
    ]

    ipv.figure()
    ipv.scatter(x, y, z, size=5, marker="sphere",
                selected=selected_atoms, size_selected=8)
    ipv.show()