# TODOs 08.01.19

New in this notebook:
- TODO: put widgets and plots into a rudimentary dashboard layout
- TODO: for band plot, enable switching between spin and normal plot
- TODO: add file chooser
- TODO: add 3D plot: crystal structure
- TODO: add BZ plot of kpath (Brillouin zone)

# Setup

## Setup masci-tools path

In [61]:
# IMPORTANT: we need to import stuff from masci-tools folder.
# Since masci-tools is not installed as a module (yet), the notebook kernel
# needs to be started in the masci-tools folder.
# If that has not happened for some reason, the masci-tools folder has to be
# added to sys.path manually.
import os
import sys

def _find_dir_upwards(start, dirname):
    """Return the closest directory named *dirname* on the path from *start* up to the root.

    *start* itself is checked first. Returns ``None`` when no path component
    has exactly that name, instead of looping forever at the filesystem root.
    """
    path = os.path.abspath(start)
    while True:
        if os.path.basename(path) == dirname:
            return path
        parent = os.path.dirname(path)
        if parent == path:  # reached the filesystem root without a match
            return None
        path = parent


cwd = os.getcwd()
dirname_mtools = "masci-tools"
# First try to find masci-tools among the ancestors of the working directory,
# so that no absolute path is needed. An exact component match is required:
# a mere substring (e.g. "masci-tools-dev") does not count.
path_mtools = _find_dir_upwards(cwd, dirname_mtools)
if path_mtools is None:
    # Fall back to a machine-specific absolute path; adjust it for your setup.
    path_mtools = "/home/johannes/Desktop/Studium/Kurse_RWTH/SiScLab/18W/repos/masci-tools"
    if not os.path.isdir(path_mtools):
        raise IOError("Could not find path to masci-tools. Please specify absolute path.")

# Found masci-tools: add it to sys.path (only once) so its modules can be imported.
if path_mtools not in sys.path:
    sys.path.append(path_mtools)

## Setup imports

In [62]:
# Jupyter, Python imports
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed, interact_manual
from IPython.display import display
import matplotlib.pyplot as plt
import numpy as np

# Interactive figures in a live IPython notebook session (Python 3):
# if run from jupyter-notebook: %matplotlib nbagg
# if run from jupyter-lab: %matplotlib widget
%matplotlib widget

# studentproject18ws imports
import os
import logging
from studenproject18ws.hdf.reader import Reader
from studenproject18ws.hdf.recipes import Recipes

## Read file, import data

- TODO: link with widget-based file-chooser

In [63]:
filename = 'banddos_4x4.hdf'
# Alternative input files (uncomment one to switch):
# filename = 'banddos.hdf'
# filename = 'banddos_Co.hdf'

# Input files are looked up in ../data/input relative to the working directory.
filepath = os.path.join('..', 'data', 'input', filename)

data = None
extractor = Reader(filepath=filepath)
with extractor as h5file:
    # Inside the context manager the HDF5 file is open: data attributes of type
    # h5py Dataset are still backed by the file (in-file access).
    data = extractor.read(recipe=Recipes.Bands)
    # Leaving the with-statement closes the HDF5 file and its datasets, so
    # convert all datasets to numpy ndarrays now (in-memory access) to keep
    # using `data` afterwards.
    data.move_datasets_to_memory()

# Define Widgets

## Define band plot

In [64]:
# Create the figure and axes at module level (not inside a function) so that
# they can be linked with the widgets.
fig, ax = plt.subplots(nrows=1, figsize=(10, 6))
fig.suptitle(f"BandStructure of {filename}")

FigureCanvasNbAgg()

Text(0.5,0.98,'BandStructure of banddos_4x4.hdf')

## Band plot

### Plot function: plot layout

In [65]:
def plot_setup():
    """Configure the current pyplot axes for a band-structure plot.

    Sets the special k-point tick labels, the energy axis label, the x-range
    [0, max k-distance] and a thin horizontal line at E = 0. Reads the
    module-level ``data`` object (``k_special_point_labels``,
    ``k_special_points`` and ``k_distances``).
    """
    labels = []
    for raw_label in data.k_special_point_labels:
        # Labels are decoded from UTF-8 bytes; accept already-decoded str too.
        if isinstance(raw_label, bytes):
            label = raw_label.decode("utf-8")
        else:
            label = str(raw_label)
        # "g" marks the Gamma point; render it as a LaTeX Gamma.
        # append() (not +=) keeps multi-character labels as a single tick label.
        labels.append(r"$\Gamma$" if label == "g" else label)

    k_max = max(data.k_distances)
    plt.xticks(data.k_special_points, labels)
    plt.ylabel("E(k) [eV]")
    plt.xlim(0, k_max)
    plt.hlines(0, 0, k_max, lw=0.1)

### Plot function: 2-character-selection

In [66]:
def plot_two_characters(mask_bands, mask_characters, mask_groups, spin, unfolding_weight_exponent, alpha=1):
    """Scatter-plot the bands, colored by the relative weight of two band characters.

    Exactly two characters must be selected in ``mask_characters``; otherwise a
    message is printed and nothing is plotted. For every point the color encodes
    the share of the first selected character in the combined weight of both,
    and the marker size scales with the combined weight. Points whose combined
    weight is not above 1e-4 are skipped to keep plotting fast.

    Uses the module-level ``data`` and ``ax`` objects.

    Parameters
    ----------
    mask_bands, mask_characters, mask_groups : sequence of bool
        Selection masks passed on to ``data.reshape_data``.
    spin : int
        Spin index passed on to ``data.reshape_data``.
    unfolding_weight_exponent : float
        Passed on to ``data.reshape_data``.
    alpha : float, optional
        Marker transparency (default 1).
    """
    characters = np.flatnonzero(mask_characters)
    if len(characters) != 2:
        print("plot_two_characters: tried to plot with other than 2 characters selected. not allowed!")
        # Previously execution continued here and either crashed (1 character)
        # or silently used only the first two characters (3-4 characters).
        return

    _, evs_first, weight_first = data.reshape_data(
        mask_bands, data._mask_characters([characters[0]]),
        mask_groups, spin, unfolding_weight_exponent)
    k_second, _, weight_second = data.reshape_data(
        mask_bands, data._mask_characters([characters[1]]),
        mask_groups, spin, unfolding_weight_exponent)

    tot_weight = weight_first + weight_second
    # Drop (near-)zero-weight points: they are practically invisible, and
    # filtering before the division below avoids 0/0 where both weights vanish.
    keep = tot_weight > 1e-4
    tot_weight = tot_weight[keep]
    k_plot = k_second[keep]
    e_plot = (evs_first[keep] - data.fermi_energy) * data.HARTREE_EV
    # Share of the first character; the factor 20 is kept from the original code
    # (scatter autoscales the color range by default, so it does not alter colors).
    rel = weight_first[keep] / tot_weight * 20

    ax.scatter(k_plot, e_plot, marker='o', c=rel, s=5 * tot_weight,
               lw=0, alpha=alpha, cmap=plt.cm.plasma)

### Plot function: normal

In [67]:
def plot(mask_bands, mask_characters, mask_groups, spin, unfolding_weight_exponent, isCharacterPlot=False, alpha=1,
         color="blue"):
    """Plot the band structure onto the module-level ``ax``.

    Parameters
    ----------
    mask_bands, mask_characters, mask_groups : sequence of bool
        Selection masks passed on to ``data.reshape_data``.
    spin : int
        Spin index passed on to ``data.reshape_data``.
    unfolding_weight_exponent : float
        Passed on to ``data.reshape_data``.
    isCharacterPlot : bool, optional
        If True, delegate to ``plot_two_characters`` (requires exactly two
        selected characters); otherwise draw a single-color band plot.
    alpha : float, optional
        Marker transparency (default 1). Previously this argument was ignored.
    color : optional
        Marker color of the single-color plot (default "blue").
    """
    if isCharacterPlot:
        plot_two_characters(mask_bands, mask_characters, mask_groups, spin,
                            unfolding_weight_exponent, alpha)
        return

    k_r, E_r, W_r = data.reshape_data(mask_bands, mask_characters, mask_groups, spin,
                                      unfolding_weight_exponent)
    # Only plot points whose weight exceeds 1e-4 (speeds up drawing).
    keep = W_r > 1e-4
    k_r, E_r, W_r = k_r[keep], E_r[keep], W_r[keep]
    ax.scatter(k_r, (E_r - data.fermi_energy) * data.HARTREE_EV,
               marker='o', c=color, s=5 * W_r, lw=0, alpha=alpha)

### Define user input arguments

In [68]:
# Atom groups: data.atom_group_keys, e.g. dict_keys([1, 2, 3, 4, 5]) for banddos.hdf.
# NOTE(review): converting these keys to a tuple in the reader may be cleaner.
def_groups = data.atom_group_keys
select_groups = widgets.SelectMultiple(
    options=def_groups,
    value=tuple(def_groups),
    description='Atom Groups',
    disabled=False,
)

# Band characters; all are selected initially.
def_characters = ['s', 'p', 'd', 'f']
select_characters = widgets.SelectMultiple(
    options=def_characters,
    value=tuple(def_characters),
    description='Band Character',
    disabled=False,
)

# The number of bands can be large, so a range slider is used instead of a
# selection slider. def_bands holds 0-based indices; the slider shows 1-based ones.
def_bands = list(range(data.eigenvalues.shape[2]))
select_bands = widgets.IntRangeSlider(
    value=[def_bands[0] + 1, def_bands[-1] + 1],
    min=def_bands[0] + 1,
    max=def_bands[-1] + 1,
    step=1,
    description='Bands',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
)

# Exponent applied to the unfolding weights, selectable in [0, 1].
select_exponent = widgets.FloatSlider(
    value=1.0, min=0, max=1.0, step=0.01,
    description='Unfolding weight exponent:',
    disabled=False, continuous_update=False,
    orientation='horizontal', readout=True, readout_format='.2f',
)

## Interact function

In [69]:
fig_scale = 0.5
fig_ratio = [10, 6]
fig, ax = plt.subplots(1, figsize=[fig_scale * el for el in fig_ratio])
plt.suptitle(f"BandStructure of {filename}")


@interact(bands=select_bands, characters=select_characters, groups=select_groups,
          unfolding_weight_exponent=select_exponent)
def update_plot(bands, characters, groups, unfolding_weight_exponent):
    """Redraw the band plot for the current widget selection.

    Parameters
    ----------
    bands : pair of int
        1-based, inclusive (first, last) band range from the range slider.
    characters : tuple of str
        Selected band characters (a subset of ``def_characters``).
    groups : tuple
        Selected atom-group keys (a subset of ``def_groups``).
    unfolding_weight_exponent : float
        Passed through to the plotting functions.
    """
    ax.clear()

    # The slider is 1-based and inclusive; band indices are 0-based.
    selected_bands = range(bands[0] - 1, bands[1])
    selected_groups = set(groups)

    # Boolean masks in the format expected by data.reshape_data.
    mask_characters = [c in characters for c in def_characters]
    mask_bands = [b in selected_bands for b in def_bands]
    mask_groups = [g in selected_groups for g in def_groups]

    # The two-character plot compares exactly two characters; for any other
    # selection fall back to the plain band plot instead of crashing (1 character)
    # or silently using only the first two (3-4 characters).
    isCharacterPlot = len(characters) == 2

    spin = 0  # spin index is fixed to 0 for now (see TODO on spin plots)

    plot_setup()
    plot(mask_bands, mask_characters, mask_groups, spin, unfolding_weight_exponent, isCharacterPlot)

FigureCanvasNbAgg()

interactive(children=(IntRangeSlider(value=(1, 387), continuous_update=False, description='Bands', max=387, mi…