In [None]:
# make cell as wide as window
from IPython.display import display, HTML
display(HTML(data="""
<style>
    div#notebook-container    { width: 95%; }
    div#menubar-container     { width: 65%; }
    div#maintoolbar-container { width: 99%; }
</style>
"""))

In [None]:
from ipywidgets import IntProgress
from IPython.display import display

In [None]:
# use Qt instead of Tk backend (on mac)
import matplotlib
try:
    matplotlib.use("Qt4Agg")
except (ValueError, ImportError) as err:
    # Qt4Agg no longer exists in recent Matplotlib releases (ValueError) and Qt may be
    # missing (ImportError). The visible cells of this notebook only use matplotlib's
    # colour utilities (colormaps, Normalize, to_hex), so the default backend is fine.
    print('Could not switch matplotlib backend to Qt4Agg (%s); keeping the default backend' % err)

import panel as pn

pn.extension()


### Overview of impurity properties

In [None]:
# general settings:
export_as_html = False


# start with progress bar
import ipywidgets as widgets
from IPython.display import display


def init_pbar(Nmax):
    """Create, display and return a horizontal 'Loading:' progress bar.

    Parameters
    ----------
    Nmax : int
        Number of steps; the bar starts at 0 and is full once its value reaches Nmax.

    Returns
    -------
    widgets.IntProgress
        The displayed bar; advance it by incrementing its ``value``.
    """
    bar_options = dict(
        value=0,
        min=0,
        max=Nmax,
        step=1,
        description='Loading:',
        bar_style='',  # one of 'success', 'info', 'warning', 'danger' or ''
        orientation='horizontal',
    )
    progress_bar = widgets.IntProgress(**bar_options)
    display(progress_bar)
    return progress_bar

# show progress bar (updated during the course of this script via incrementing pbar.value)
pbar = init_pbar(12)

In [None]:
# load aiida
from aiida import load_profile
profile = load_profile()

In [None]:
# load other modules
import markdown 
from aiida.orm import load_node
from aiida_kkr.tools import plot_kkr

from aiida_kkr.calculations import VoronoiCalculation
from aiida.orm import StructureData

from ase_notebook import AseView, ViewConfig

from numpy import sum, sqrt, array, mean, std

import numpy as np
from bokeh.models.widgets import Select

In [None]:
#update progress bar
pbar.value += 1

In [None]:
# initialize overview page (for EF in middle of gap region)

# define ipywidget elements (buttons etc.)


# impdata = 



In [None]:
from aiida.orm import Group, load_group

#Sb2Te3_slab_imps_group = Group(label='Sb2Te3_slab_imps')
#Sb2Te3_slab_imps_group.store()
Sb2Te3_slab_imps_group = load_group('Sb2Te3_slab_imps')
Sb2Te3_slab_imps_group

In [None]:
# plot overview on periodic table

In [None]:
#imp_properties_all = load_node('f8fb868c-fe36-4593-ac32-dfd45b83ca24')
#imp_properties_all = load_node('ce884e39-b74d-4873-9c8d-bba745e54233')
imp_properties_all = load_node('3e65d6eb-d25e-48c3-8727-c969da1aff42')
#imp_properties_all.get_dict()

In [None]:
#update progress bar
pbar.value += 1

In [None]:
# import bokeh stuff
from bokeh.models import (ColumnDataSource, LinearColorMapper, LogColorMapper, ColorBar, BasicTicker)
import bokeh.plotting as bk
from bokeh.io import output_file, output_notebook, show
from bokeh.sampledata.periodic_table import elements
from bokeh.transform import dodge

# matplotlib color scales
from matplotlib.colors import Normalize, LogNorm, to_hex
from matplotlib.cm import plasma, inferno, magma, viridis, ScalarMappable

# activate inline output to notebook, displays plot with 'show' call;
# with export_as_html the plot is written to an HTML file on the current user's Desktop
# instead (user-relative path rather than a hard-coded /Users/<name>/ path)
import os
html_export_path = os.path.expanduser('~/Desktop/Sb2Te3_imps_overview.html')
if export_as_html:
    output_file(html_export_path)
else:
    output_notebook(hide_banner=True)

In [None]:
#update progress bar
pbar.value += 1

In [None]:
# some settings for the periodic-table overview plot
width = 900         # plot width (minimum width; max is 2*width)
cmap_choice = 3      # colormap (0: plasma, 1: inferno, 2: magma, 3: viridis)
alpha = 0.65         # alpha value of the element boxes and of the color bar
extended = True      # show Lanthanides and Actinides in their own rows
log_scale = 0        # 0: linear color scale, 1: logarithmic color scale
cbar_height = 500    # height of the color bar (None keeps bokeh's default)
cbar_fontsize = 8   # font size (pt) of the color bar labels
cbar_standoff = 12   # distance of labels to the color bar
period_remove = []   # periods to remove from the plot, as strings (e.g. '7')
group_remove = []    # groups to remove from the plot, as strings (e.g. '18')

In [None]:
# pull the group/period/atomic-number/symbol columns out of bokeh's periodic-table sample data
g, p, a, s = (elements[column].values
              for column in ('group', 'period', 'atomic number', 'symbol'))

In [None]:
# add entry for vacancy (symbol 'X', atomic number '0'), placed in the 'La' row at group 2.
# NOTE: appending a str to the sample-data columns makes numpy build string arrays, so
# groups/periods/atomic_numbers hold strings (e.g. '26'); later cells call int(...) where
# numbers are needed.
symbols = array(list(s)+['X'])
groups = array(list(g)+['2'])
periods = array(list(p)+['La'])
atomic_numbers = array(list(a)+['0'])

In [None]:
list_of_elements = symbols

# Assign color palette: cmap_choice -> (matplotlib colormap, matching bokeh palette name)
_cmap_options = {
    0: (plasma, 'Plasma256'),
    1: (inferno, 'Inferno256'),
    2: (magma, 'Magma256'),
    3: (viridis, 'Viridis256'),
}
if cmap_choice not in _cmap_options:
    # previously an unknown choice left `cmap` undefined and failed much later with a NameError
    raise ValueError('cmap_choice must be one of %s, got %r' % (sorted(_cmap_options), cmap_choice))
cmap, bokeh_palette = _cmap_options[cmap_choice]

# Define periods and groups of the main table (categorical axis labels are strings)
period_label = ['1', '2', '3', '4', '5', '6', '7']
group_range = [str(x) for x in range(1, 19)]

# Remove any groups or periods selected in the settings cell
for gr in group_remove:
    group_range.remove(gr)
for pr in period_remove:
    period_label.remove(pr)

# add auxiliary period labels for a blank line and La and Ac series
period_label.extend(['blank', 'La', 'Ac'])

# plot La and Ac series? set element groups and periods accordingly.
# Moves 14 entries per series into the 'La' / 'Ac' rows at groups 4..17.
# NOTE(review): assumes the sample data is ordered by atomic number (index 56 -> La, 88 -> Ac)
if extended:
    for count, i in enumerate(range(56, 70)):
        periods[i] = 'La'
        groups[i] = str(count+4)

    for count, i in enumerate(range(88, 102)):
        periods[i] = 'Ac'
        groups[i] = str(count+4)

In [None]:
# sort data according to EF value (shift in meV encoded in the result label):
#   0    -> labels without an 'EF-' tag (or containing 'EF0')
#   -200 -> labels containing 'EF-200'
#   -400 -> labels containing 'EF-400'
imp_properties_dict = imp_properties_all.get_dict()  # fetch once instead of once per EF set
imp_properties_sorted = {}
for EFset in [0, -200, -400]:
    imp_properties_sorted[EFset] = {
        k: v for k, v in imp_properties_dict.items()
        if (EFset == 0 and 'EF-' not in k) or ('EF%i' % EFset in k)
    }

In [None]:
#update progress bar
pbar.value += 1

In [None]:
from pprint import pprint
#pprint([i for i in imp_properties_all.get_dict().keys()])

In [None]:
# Per-EF-set results; each list gets one entry per EF set, in the order [0, -200, -400, 'all']
color_list_all_allEF = []
data_values_allEF = []
data_values_str_allEF = []
data_allEF = []
dstr_allEF = []
data_elements_allEF = []

from numpy import arange


def _impname(key):
    """Return the impurity symbol encoded in a result label such as '...imp:<Sym>_...'."""
    return key.split(':')[1].split('_')[0]


def _sort_by_efset(values):
    """Split a {result label: value} dict into the EF sets 0, -200 and -400.

    Labels containing 'EF-200' / 'EF-400' go to the -200 / -400 bucket; labels
    without any 'EF-' tag go to the 0 bucket (same rule as `imp_properties_sorted`).
    Previously an `if ... if/else` chain also put every 'EF-200' entry into bucket 0.
    """
    buckets = {0: {}, -200: {}, -400: {}}
    for k, v in values.items():
        if 'EF-200' in k:
            buckets[-200][k] = v
        elif 'EF-400' in k:
            buckets[-400][k] = v
        elif 'EF-' not in k:
            buckets[0][k] = v
    return buckets


def _group_by_impurity(values):
    """Collect the values of a {result label: value} dict into lists keyed by impurity symbol."""
    grouped = {}
    for k, v in values.items():
        grouped.setdefault(_impname(k), []).append(v)
    return grouped


# load DOS in gap and charge-transfer values
all_DOSingap = load_node('afe3e960-c335-4eca-8477-4e83a0dfbb53').get_dict()
all_DOSingap_sorted = _sort_by_efset(all_DOSingap)
all_dc = load_node('04724d7f-4970-4a47-9447-64db5c6f11aa').get_dict()
all_dc_sorted = _sort_by_efset(all_dc)

# printf-style format of each of the 13 data columns (used for the hover strings)
_column_formats = ['%i',            # number of calculations
                   '%.3e', '%.3e',  # rms (mean/std)
                   '%.3f', '%.3f',  # spin moment (mean/std)
                   '%.3f', '%.3f',  # orbital moment (mean/std)
                   '%i',            # number of magnetic calculations
                   '%.2f',          # % magnetic
                   '%.2f', '%.2f',  # charge doping (mean/std)
                   '%.2f', '%.2f',  # DOS in gap (mean/std)
                  ]

# now extract data from dict to arrays
for EFset in [0, -200, -400, 'all']:

    # collect imp properties and the matching DOS-in-gap / charge-doping values
    if EFset != 'all':
        imp_properties = imp_properties_sorted[EFset]
        DOSinGap_values = all_DOSingap_sorted[EFset]
        dc_values = all_dc_sorted[EFset]
    else:
        imp_properties = imp_properties_all.get_dict()
        DOSinGap_values = all_DOSingap
        dc_values = all_dc
    # NOTE: a filter against a saved 'mapping_data_ok' list used to sit here; it was
    # disabled (`if 1:`), i.e. all entries were kept, so the dead code was dropped.

    # per-impurity statistics (insertion order = order of first appearance)
    num_impcalcs = {}
    rms_impcalcs = {}
    omom, smom = {}, {}
    nmag = {}
    for k, v in imp_properties.items():
        impname = _impname(k)
        if impname not in num_impcalcs:
            rms_impcalcs[impname] = []
            num_impcalcs[impname] = 0
            omom[impname], smom[impname] = [], []
            nmag[impname] = 0
        num_impcalcs[impname] += 1
        rms_impcalcs[impname].append(v.get('rms'))
        smom[impname].append(abs(v.get('spin_moment_imp')[2]))
        omom[impname].append(abs(v.get('orbital_moment_imp')[2]))
        # counted as magnetic if |spin moment z-component| exceeds 1e-4
        if abs(v.get('spin_moment_imp')[2]) > 10**-4:
            nmag[impname] += 1

    # extract values for charge doping and DOS in Gap, grouped per impurity
    DOSinGap = _group_by_impurity(DOSinGap_values)
    charge_doping = _group_by_impurity(dc_values)

    # now fill big data array: one row per impurity, 13 columns (see _column_formats);
    # -2 / -0.3 are placeholders for impurities without charge-doping / DOS-in-gap data
    data = []; data_elements = []; dstr = []
    for impname in num_impcalcs:
        data_elements.append(impname)
        data.append([num_impcalcs[impname],
                     mean(rms_impcalcs[impname]), std(rms_impcalcs[impname]),
                     mean(smom[impname]), std(smom[impname]),
                     mean(omom[impname]), std(omom[impname]),
                     nmag[impname], nmag[impname]/float(num_impcalcs[impname])*100.,
                     mean(charge_doping.get(impname, -2)), std(charge_doping.get(impname, -2)),
                     mean(DOSinGap.get(impname, -0.3)), std(DOSinGap.get(impname, -0.3))
                    ])
        dstr.append(list(_column_formats))
    data = array(data)
    dstr = array(dstr)
    data_elements = array(data_elements)

    # store in big arrays
    data_allEF.append(data)
    dstr_allEF.append(dstr)
    data_elements_allEF.append(data_elements)

    # set up one color scale / bokeh color mapper per data column
    color_scale = []
    color_mapper_all = []
    for color_component in range(len(data[0, :])):
        column = data[:, color_component]
        if log_scale == 0:
            ColorMapper = LinearColorMapper
            norm = Normalize(vmin=min(column), vmax=max(column))
        elif log_scale == 1:
            negative = column < 0
            if negative.any():
                # previously '...'+datum+'...' raised a TypeError (float + str) instead
                raise ValueError('Entry for element ' + str(data_elements[negative][0]) + ' is negative but'
                                 ' log-scale is selected')
            ColorMapper = LogColorMapper
            norm = LogNorm(vmin=min(column), vmax=max(column))
        else:
            raise ValueError('log_scale must be 0 or 1, got %r' % (log_scale,))

        color_mapper_all.append(ColorMapper(palette=bokeh_palette,
                                            low=min(column),
                                            high=max(column)))
        color_scale.append(ScalarMappable(norm=norm, cmap=cmap).to_rgba(column, alpha=None))

    # Define color for blank entries (elements without data)
    blank_color = '#c4c4c4'
    default_value = None
    n_columns = len(data[0, :])
    color_list_all = [[blank_color] * n_columns for _ in range(len(symbols))]
    data_values = [[None] * n_columns for _ in range(len(symbols))]
    data_values_str = [[''] * n_columns for _ in range(len(symbols))]

    # Compare elements in dataset with elements in periodic table and set color etc. accordingly
    idx = arange(len(symbols))
    for i, data_element in enumerate(data_elements):
        element_entry = idx[symbols == data_element]
        if len(element_entry) == 0:
            # skip unknown symbols; previously the stale index of the preceding element
            # was reused and that element's data silently overwritten
            print('WARNING: Invalid chemical symbol: '+data_element)
            continue
        element_index = element_entry[0]
        if color_list_all[element_index][0] != blank_color:
            print('WARNING: Multiple entries for element '+data_element)

        for j in range(n_columns):
            color_list_all[element_index][j] = to_hex(color_scale[j][i])

        # add data values that are shown by hover tool
        data_values[element_index] = data[i, :]
        for j in range(len(data[i])):
            data_values_str[element_index][j] = dstr[i, j] % data[i, j]

    color_list_all = array(color_list_all)
    data_values = array(data_values)
    data_values_str = array(data_values_str)

    # store big arrays
    color_list_all_allEF.append(color_list_all)
    data_values_allEF.append(data_values)
    data_values_str_allEF.append(data_values_str)

In [None]:
#update progress bar
pbar.value += 1

In [None]:
#Define figure properties for visualizing data
# One ColumnDataSource per EF set. iEF indexes the lists filled by the data-extraction
# cell, whose loop order is [0, -200, -400, 'all']; hence
#   source_allEF[0]   -> 'all' (bound to the plot as `source`, modified by the JS callbacks)
#   source_allEF[1:4] -> EF sets 0, -200, -400
#   source_allEF[4]   -> 'all' again, kept as the pristine copy for the EF buttons
# (value column, formatted-string column) for each of the 13 data columns, in column order
_value_columns = [
    ('num_impcalcs', 'num_impcalcs_str'),
    ('rms_mean', 'rms_mean_str'),
    ('rms_std', 'rms_std_str'),
    ('smom_mean', 'smom_mean_str'),
    ('smom_std', 'smom_std_str'),
    ('omom_mean', 'omom_mean_str'),
    ('omom_std', 'omom_std_str'),
    ('nummag', 'nummag_impcalcs_str'),
    ('percent_mag', 'percent_mag_impcalcs_str'),
    # additional fields for charge doping and DOS in Gap
    ('charge_doping_mean', 'charge_doping_mean_str'),
    ('charge_doping_std', 'charge_doping_std_str'),
    ('DOSinGap_mean', 'DOSinGap_mean_str'),
    ('DOSinGap_std', 'DOSinGap_std_str'),
]

source_allEF = []
for iEF in [3, 0, 1, 2, 3]:
    colors_EF = color_list_all_allEF[iEF]
    values_EF = data_values_allEF[iEF]
    values_str_EF = data_values_str_allEF[iEF]
    columns = dict(
        plot_component=[0 for x in groups],  # JS callbacks store the active color button here
        group=[str(x) for x in groups],
        period=[str(y) for y in periods],
        sym=symbols,
        atomic_number=atomic_numbers,
        type_color=colors_EF[:, 0],  # currently displayed colors (swapped by the JS callbacks)
    )
    # one color column per data column: type_color0 ... type_color12
    for j in range(len(_value_columns)):
        columns['type_color%i' % j] = colors_EF[:, j]
    for j, (value_key, str_key) in enumerate(_value_columns):
        columns[value_key] = values_EF[:, j]
        columns[str_key] = values_str_EF[:, j]
    source_allEF.append(ColumnDataSource(data=columns))

source = source_allEF[0] # 'all' EF set: the initially displayed data (EF button 'all' is active)

In [None]:
#update progress bar
pbar.value += 1

In [None]:
#Plot the periodic table with colors
# Figure with one categorical x tick per group and one y tick per period label
# (reversed so that period 1 is at the top); all glyphs below are drawn from `source`.
plot = bk.figure(x_range=group_range, 
           y_range=list(reversed(period_label)), 
           tools='pan,wheel_zoom,reset,save',
           title = 'Impurities in Sb2Te3 (6QL)'
          )

from bokeh.models import HoverTool
from bokeh.layouts import column, row
from bokeh.models import CustomJS
from bokeh.models.widgets import Button
from bokeh.models import RadioButtonGroup
from bokeh.models.widgets import Paragraph
from bokeh.models import Panel, Tabs

# Hover tooltip; every '@name' field refers to a column of `source`
hover = HoverTool()
hover.tooltips = [("Element", "@sym"), # things displayed by hover tool, needs to be in 'source' dict
         #("Zatom", "@atomic_number"),
         ("# impcalcs", "@num_impcalcs"), 
         #("rms", "@rms_mean (+/-@rms_std)"), 
         ("% magnetic", "@percent_mag"), 
         ("<spin mom>", "@smom_mean (+/-@smom_std)"), 
         ("<orb mom>", "@omom_mean (+/-@omom_std)"), 
         ("<charge doping>", "@charge_doping_mean (+/-@charge_doping_std)"), 
         ("<DOS in gap>", "@DOSinGap_mean (+/-@DOSinGap_std)"), 
        ]

plot.tools.append(hover)

# NOTE(review): `plot_width` is the pre-Bokeh-3 spelling of `width`; confirm against the installed Bokeh version.
plot.plot_width = width
plot.min_width = width
plot.max_width = width*2
plot.sizing_mode = 'scale_both'
plot.outline_line_color = None
plot.toolbar_location='above'
# coloured patches for the elements:
# (the 'type_color' column is what the JS callbacks swap to change the displayed quantity)
rects = plot.rect('group', 'period', 0.9, 0.9, source=source, alpha=alpha, color='type_color')
plot.axis.visible = False # show axis?
# common keyword arguments of all text glyphs inside the element boxes
text_props = {
    'source': source,
    'angle': 0,
    'color': 'black',
    'text_align': 'left',
    'text_baseline': 'middle'
}
# add text for all pairs of (x,y)=(group,period)
x = dodge("group", -0.4, range=plot.x_range)
y = dodge("period", 0.3, range=plot.y_range)
y2 = dodge("period", -0.3, range=plot.y_range) # to displat 'c_value' entry as well
# here add the texts inside atom boxes
plot.text(x=x, y='period', text='sym', text_font_style='bold', text_font_size='16pt', **text_props)
plot.text(x=x, y=y, text='atomic_number', text_font_size='11pt', **text_props)

# The dict is indexed with [0], so `txt` is always 'num_impcalcs': the small number at the
# bottom of each box is the number of impurity calculations.
txt = {0: 'num_impcalcs', 1: 'rms_mean', 2: 'rms_std', 3: 'smom_mean', 4: 'smom_std', 5: 'omom_mean', 
       6: 'omom_std', 7: 'nummag', 8: 'percent_mag', 9: 'charge_doping_mean', 10: 'DOSinGap_mean'}[0]
color_value = plot.text(x=x, y=y2, text=txt, text_font_size='8pt', name='color_value', **text_props) # uses y2



# deactivate grid
plot.grid.grid_line_color = None


# title of color bar, one per data column (same order as the 13 data columns)
names = ['# imps', '<rms>', 'rms std', 
          '<spin mom>', 'spin std', 
          '<orb mom>', 'orbital mom std', 
          '# magnetic', '% magnetic',
          '<Charge doping>', 'charge dop std', '<DOS in gap>', 'dos in gap std'
         ]
title_name = names[0]

# for log scale use this as ticker
# NOTE(review): LogTicker is imported but never used; the ticker is always a BasicTicker,
# even when log_scale == 1.
from bokeh.models import LogTicker, BasicTicker

ticker = BasicTicker(desired_num_ticks=10)

from bokeh.models import PrintfTickFormatter
formatter_int = PrintfTickFormatter(format='%i')
formatter_2f = PrintfTickFormatter(format='%.2f')
formatter_1e = PrintfTickFormatter(format='%.1e')

# add color bar (initially for column 0, '# imps')
color_bar = ColorBar(color_mapper=color_mapper_all[0],
    ticker=ticker, border_line_color=None,
    label_standoff=cbar_standoff, location=(0,0), orientation='vertical',
    scale_alpha=alpha, major_label_text_font_size=str(cbar_fontsize)+'pt',
    title=title_name, formatter=formatter_int, padding=20
)

if cbar_height is not None:
    color_bar.height = cbar_height

plot.add_layout(color_bar,'right')

# NOTE(review): the same ColorBar model is also added to this second figure, which is not
# used in the periodic-table layout built in this cell; confirm it is still needed.
color_bar_plot = bk.figure(title="My color bar title", title_location="right", 
                        plot_width=100, min_width=100, sizing_mode='stretch_both',
                        toolbar_location=None)

color_bar_plot.add_layout(color_bar, 'right')
color_bar_plot.title.align="center"
color_bar_plot.title.text_font_size = '12pt'


# one tick formatter per data column (same order as `names`).
# NOTE: the scatter-plot section later rebinds `formatters` to a dict keyed by quantity name.
formatters = [formatter_int, formatter_1e, formatter_1e, 
              formatter_2f, formatter_2f, formatter_2f,
              formatter_2f, formatter_int, formatter_2f,
              formatter_2f, formatter_2f, formatter_2f, formatter_2f,
             ]

def get_callback(index, formatter_list=formatters):
    """Build a CustomJS callback that colours the periodic table by data column `index`.

    Parameters
    ----------
    index : int
        Data column (0..12, same order as `names`) to display.
    formatter_list : list, optional
        Tick formatters per data column. Defaults to the list defined just above,
        bound at definition time because a later cell rebinds the global
        `formatters` to a dict keyed by scatter quantity names.

    Returns
    -------
    CustomJS
        Callback that copies column 'type_color<index>' into 'type_color' and
        updates the colour bar's mapper, title and tick formatter.
    """
    callback_args = dict(source=source, cbar=color_bar, cmap_update=color_mapper_all[index],
                         formatter=formatter_list[index], newtitle=names[index],
                         cname='type_color'+str(index))
    callback = CustomJS(args=callback_args, code="""
                                    source.data['type_color'] = source.data[cname]
                                    source.change.emit();
                                    cbar.color_mapper = cmap_update
                                    cbar.title = newtitle
                                    cbar.formatter = formatter
                                    cbar.change.emit();    
                                    //color_value.change.emit();
                                    """
                       )
    return callback



def get_callback_EFset(index):
    """Build a CustomJS callback that swaps the periodic-table data to another EF set.

    Parameters
    ----------
    index : {-200, 0, 200, 'all'}
        Label value of the Fermi-level button; mapped to an entry of `source_allEF`.

    Returns
    -------
    CustomJS
        Callback that replaces `source.data` with the selected source's data and
        re-emits change events for the source and the colour bar.
    """
    # button label -> position in source_allEF (entry 0 is the displayed, mutable copy)
    source_position = {-200: 1, 0: 2, 200: 3, 'all': 4}[index]
    newsource = source_allEF[source_position]
    # `cbar` must be passed in args: the JS code below references it
    # (previously missing, so the callback threw a ReferenceError)
    callback = CustomJS(args=dict(source=source, newsource=newsource, cbar=color_bar),
                        code="""
                             source.data = newsource.data
                             source.change.emit();
                             cbar.change.emit();
                             """
                       )
    return callback


# using regular buttons
#toggles = []
#for index in [0,1,3,5,7,8]:
#    toggles.append(Button(label=names[index], callback=get_callback(index), max_width=100))

# using radiobuttons
# Per colour button: label, colour mapper, tick formatter, colour-bar title and the
# source column holding the precomputed colours of that data column.
names_rb, cmap_update_rb, formatter_rb, cname_rb, newtitle_rb = [], [], [], [], []
#for index in [0,1,3,5,7,8]: #  num imps, rms, smom, omom, num mag, percent mag (see definition of txt dict)
for index in [0,3,5,8,9,11]: # num imps, smom, omom, percent mag, charge_doping, DOSinGap 
    names_rb.append(names[index])
    cmap_update_rb.append(color_mapper_all[index])
    formatter_rb.append(formatters[index])
    newtitle_rb.append(names[index])
    cname_rb.append('type_color'+str(index))

# JS: on a colour-button click, copy the chosen colour column into 'type_color', remember
# the active button in plot_component[0] and update the colour bar mapper/title/formatter.
# NOTE(review): `idx` is assigned without var/let/const, i.e. an implicit JS global; this only
# works while CustomJS code runs in non-strict mode.
callback_rb = CustomJS(args=dict(source=source, cbar=color_bar, cmap_update=cmap_update_rb, 
                                  formatter=formatter_rb, newtitle=newtitle_rb, 
                                  cname=cname_rb),
                       code="""
                            /// get index of active button
                            idx = cb_obj.active
                            
                            /// set color according to color name
                            source.data['type_color'] = source.data[cname[idx]]
                            /// update value of plot_component index
                            source.data['plot_component'][0] = idx
                            
                            /// update color mapper
                            cbar.color_mapper = cmap_update[idx]
                            /// update title of color bar
                            cbar.title = newtitle[idx]
                            /// update color bar formatter
                            cbar.formatter = formatter[idx]
                            
                            ///submit changes
                            source.change.emit();
                            cbar.change.emit();    
                            """
                       )
                       
# NOTE(review): the `callback=` keyword is the old Bokeh widget API; newer Bokeh versions
# attach JS via js_on_change('active', ...) / js_on_event — confirm against the installed version.
toggles = RadioButtonGroup(labels=names_rb, active=0, callback=callback_rb)
                       

# using ordinary buttons
#EFbuttons = []
#for EFvalue in [-200, 0, 200]:
#    EFbuttons.append(Button(label='EF= '+str(EFvalue)+'meV', max_width=210, callback=get_callback_EFset(EFvalue)))
#    print(get_callback_EFset(EFvalue))
    
# using radiobuttons
# allsources == source_allEF[1:5]. With the construction order of source_allEF this maps
#   'EF= -200meV' -> EF set 0 (labels without 'EF-'), 'EF= 0meV' -> 'EF-200' labels,
#   'EF= 200meV'  -> 'EF-400' labels, 'all' -> all labels.
# NOTE(review): the scatter section shifts EF by +0.2 eV (no 'EF-' -> +0.2, 'EF-200' -> 0,
# 'EF-400' -> -0.2), which suggests the -200 / +200 button labels are swapped — verify.
allsources = [source_allEF[{-200:1, 0:2, 200:3, 'all':4}[index]] for index in [-200,0,200,'all']] 
# JS: on an EF-button click, replace source.data with the selected set's data, then re-apply
# the colour column / colour bar of the currently active colour button.
# NOTE(review): `source.data = newsource.data` appears to alias the same data object, so later
# writes through `source` also modify the selected entry of allsources; `idx` and `newsource`
# are implicit JS globals (no var/let/const).
callback_EF = CustomJS(args=dict(source=source, allsources=allsources, colorbuttons=toggles,
                                 cbar=color_bar, cmap_update=cmap_update_rb, 
                                 formatter=formatter_rb, newtitle=newtitle_rb, 
                                 cname=cname_rb),
                       code="""
                            /// get value of active button
                            idx = cb_obj.active
                            
                            /// save value of colorbuttons active value (store in plot_component)
                            ///colorbuttons.active = source.data['plot_component'][0]
                            
                            /// now overwrite data of source with new source
                            newsource = allsources[idx]
                            source.data = newsource.data
                            
                            /// restore value of colorbuttons active button etc.
                            idx = colorbuttons.active
                            
                            
                            /// set color according to color name
                            source.data['type_color'] = source.data[cname[idx]]
                            /// update value of plot_component index
                            source.data['plot_component'][0] = idx
                            
                            /// update color mapper
                            cbar.color_mapper = cmap_update[idx]
                            /// update title of color bar
                            cbar.title = newtitle[idx]
                            /// update color bar formatter
                            cbar.formatter = formatter[idx]
                            
                            
                            
                            /// finally submit changes
                            source.change.emit();
                            colorbuttons.change.emit()
                            cbar.change.emit(); 
                            """
                       )
# button labels; active=3 ('all') matches the data initially bound to `source`
names_EF = ['EF= '+str(EFvalue)+'meV' for EFvalue in [-200,0,200]]+['all']
EFbuttons = RadioButtonGroup(labels=names_EF, active=3, callback=callback_EF)
  

EFtext = Paragraph(text="""Select Fermi level:""", width=110, align='start')
Colortext = Paragraph(text="""Select color scale:""", width=110, align='start')


# final layout: EF selector row, colour selector row, periodic table
buttonrow = row([Colortext]+[toggles])
layout_periodic_table = column(row([EFtext]+[EFbuttons]), buttonrow , plot, sizing_mode='scale_both')

#tab1 = Panel(child=layout_periodic_table, title='tab1')
#tab2 = Panel(child=plot, title='tab2')
#layout = Tabs(tabs=[tab1, tab2])

#show(layout)

In [None]:
pn.Pane(layout_periodic_table)

In [None]:
#update progress bar
pbar.value += 1

In [None]:
#-----

# Impurity properties

Scatter plots:
- charge doping vs. Zimp
- magnetism vs. layer (**TODO**: or another quantity?)

In [None]:
# Flatten every impurity result into one row:
# [zimp, zhost, ilayer, EF shift (eV), orbital moment (vector), rms, spin moment (vector), etot_Ry]
f = IntProgress(min=0, max=len(list(imp_properties_all.keys()))) # instantiate the bar
display(f)
out_all = []
for k, v in imp_properties_all.get_dict().items():
    f.value += 1
    # EF shift encoded in the label: none -> 0, 'EF-200' -> -0.2, any other EF tag -> -0.4
    if 'EF' not in k:
        efval = 0.0
    elif 'EF-200' in k:
        efval = -0.2
    else:
        efval = -0.4
    row_values = [v['zimp'], v['zhost'], v['ilayer'], efval, v['orbital_moment_imp'], v['rms'], v['spin_moment_imp'], v['etot_Ry']]
    out_all.append(row_values)
# Rows mix scalars with moment vectors, so they must be stored as an object array;
# NumPy >= 1.24 raises ValueError for such ragged input unless dtype=object is given.
out_all = array(out_all, dtype=object)
f.close()
out_all.shape

In [None]:
from aiida import load_profile
load_profile()
from aiida.orm import load_node, Code
from aiida.engine import submit
from aiida.plugins import DataFactory
Dict = DataFactory('dict')

from aiida_kkr.calculations import VoronoiCalculation
from numpy import load
import json

In [None]:
"""
out_all = load('extracted_energies.npy',allow_pickle=True, encoding='latin1')
with open('dict_data.json') as f:
    txt = f.readlines()
    
    out_dict_all = json.loads(txt[0])
    
out_all = out_all.reshape(-1,15)
out_all.shape
"""

In [None]:
#update progress bar
pbar.value += 1

In [None]:
# column index in `out_all` -> quantity name used for the scatter-plot data
mapping_imp_props = {0:'Zimp' ,
                     1:'Zhost' ,
                     2:'ilayer' ,
                     3:'EFshift' ,
                     4:'orbmom' ,
                     5:'rms' ,
                     6:'spinmom' ,
                     7:'etot_Ry',
                    }

# tick formatter per scatter quantity.
# NOTE: this rebinds `formatters`, which previously held the per-column list used by the
# periodic-table colour buttons; anything that indexes it by integer after this cell fails.
formatters = {'Zimp': formatter_int,
              'Zhost': formatter_int,
              'ilayer': formatter_int,
              'EFshift': formatter_2f,
              'orbmom': formatter_2f,
              'rms': formatter_1e,
              'spinmom': formatter_2f,
              'etot_Ry':formatter_1e,
             }

In [None]:
from matplotlib import cm

In [None]:



# default x / y / colour quantities of the scatter plot
name0x = 'Zimp'
name0y = 'spinmom'
name0c = 'None'

# column-wise view of out_all; the moment columns keep only their z component
dictdata = {}
for idx, name in mapping_imp_props.items():
    if name in ['orbmom', 'spinmom']:
        dictdata[name] = array([i[2] for i in out_all[:, idx]])
    else:
        dictdata[name] = out_all[:, idx]

# take uniform sign convention (spin moment positive, orbital moment sign relative to it)
dictdata['orbmom'] = dictdata['orbmom']*np.sign(dictdata['spinmom'])
dictdata['spinmom'] = dictdata['spinmom']*np.sign(dictdata['spinmom'])

# one packed key per scatter row: Zimp, Zhost, layer index and EF shift
mapping = dictdata['Zimp']*1000000+1000*dictdata['Zhost']+dictdata['ilayer']+dictdata['EFshift']
# packed key -> first row with that key (the row np.where(mapping == key)[0][0] would pick)
row_of_key = {}
for irow, packed_key in enumerate(mapping):
    row_of_key.setdefault(packed_key, irow)
allowed_keys = list(imp_properties_all.keys())
allowed_key_set = set(allowed_keys)


def _efshift_from_label(label):
    """EF shift (eV) encoded in a result label: -0.4 for 'EF-400', -0.2 for 'EF-200', else 0."""
    if 'EF-400' in label:
        return -0.4
    if 'EF-200' in label:
        return -0.2
    return 0


def _align_to_rows(values_by_label, default, description):
    """Reorder per-calculation values onto the rows of `dictdata`.

    Parameters
    ----------
    values_by_label : dict
        Result label (e.g. 'EF-200_imp:Fe_layer_Sb[3]') -> value.
    default : float
        Value for rows without a matching entry; used later as a 'missing' marker.
    description : str
        Label of the progress bar shown while processing.

    Returns
    -------
    numpy.ndarray
        One value per row of `dictdata`.
    """
    aligned = array([default for i in range(len(dictdata['ilayer']))])
    bar = IntProgress(description=description, min=0, max=len(list(values_by_label.keys()))) # instantiate the bar
    display(bar)
    for k, value in values_by_label.items():
        bar.value += 1
        if k not in allowed_key_set:
            continue
        impname = k.split(':')[1].split('_')[0]
        hostname = k.split('[')[0].split('_')[-1]
        ilayer = int(k.split('[')[1].split(']')[0])
        zimp = int(atomic_numbers[symbols==impname][0])
        zhost = int(atomic_numbers[symbols==hostname][0])
        packed_key = 1000000*zimp+1000*zhost+ilayer+_efshift_from_label(k)
        if packed_key in row_of_key:
            aligned[row_of_key[packed_key]] = value
    bar.close()
    return aligned


# add DOS in gap values (placeholder -0.03 marks rows without data)
all_DOSingap = load_node('afe3e960-c335-4eca-8477-4e83a0dfbb53').get_dict()
values_DOSinGap_sorted = _align_to_rows(all_DOSingap, -0.03, 'loading gap filling data')
dictdata['DOS_in_gap'] = values_DOSinGap_sorted

#update progress bar
pbar.value += 1

# add charge-transfer values (placeholder -0.02 marks rows without data)
all_dc = load_node('04724d7f-4970-4a47-9447-64db5c6f11aa').get_dict()
values_dc_sorted = _align_to_rows(all_dc, -0.02, 'loading charge doping data')
dictdata['charge_doping'] = values_dc_sorted

#update progress bar
pbar.value += 1

# add color dict entries
dictdata['color_None'] = array(['navy' for i in range(len(dictdata['spinmom']))]) # default value

In [None]:
# Build one hex colour array 'color_<name>' per scatter quantity (plasma colormap, alpha 0.6),
# normalised to that quantity's min/max; orbmom is coloured by its absolute value and
# None entries get `blank_color`.
# NOTE: EFshift is shifted by +0.2 in place, so re-running this cell without first re-running
# the cell that builds `dictdata` applies the shift twice.
cmap_scatter = cm.plasma
for name in list(mapping_imp_props.values())+['DOS_in_gap', 'charge_doping']:
    if name == 'EFshift':
        dictdata[name] = dictdata[name]+0.2
    val = dictdata[name]
    if name=='orbmom': val = abs(val)
    valid_values = [ival for ival in val if ival is not None]
    norm = Normalize(vmin = min(valid_values), vmax = max(valid_values))
    # one mappable per quantity instead of a new ScalarMappable for every single value
    mappable = ScalarMappable(norm=norm, cmap=cmap_scatter)
    colors = []
    for ival in val:
        if ival is not None:
            colors.append(to_hex(mappable.to_rgba(ival, alpha=0.6)))
        else:
            colors.append(blank_color)
    colors = array(colors)
    dictdata['color_'+name] = colors

In [None]:
#"""
# remove 'wrong' data from scatterplot: drop rows that still hold the 'missing' placeholders
# (-0.02 for charge doping, -0.03 for DOS in gap) set when the values were aligned to the rows
cd = dictdata['charge_doping']
gapfill = dictdata['DOS_in_gap']
#Etotdiff = dictdata['Etot_diff']
#ecore = dictdata['ecore']
mcd = np.where(cd==-0.02)[0]
mgf = np.where(gapfill==-0.03)[0]
#met = np.where(Etotdiff<-1e6)[0]
#mec = np.where(ecore>0)[0]
#notm = array(list(set(list(mcd)+list(mgf)+list(met)+list(mec))))
removed_rows = set(mcd) | set(mgf)
notm = array(list(removed_rows))
# set membership instead of scanning the `notm` array for every row (was O(n*k))
m = [i for i in range(len(cd)) if i not in removed_rows]
for k,v in dictdata.items():
    dictdata[k] = v[m]
print(len(notm), len(m))
#"""
"""
zimpout, zhostout, ilayerout, efout = (dictdata['Zimp'][notm],
                                         dictdata['Zhost'][notm],
                                         dictdata['ilayer'][notm],
                                         dictdata['EFshift'][notm])
remove_wrong = []
for i in range(len(zimpout)):
    if efout[i]==0:
        impnameout = ''
    elif efout[i]==-0.2:
        impnameout = 'EF-200_'
    elif efout[i]==-0.4:
        impnameout = 'EF-400_'
    imp = symbols[str(zimpout[i])==atomic_numbers][0]
    host = symbols[str(zhostout[i])==atomic_numbers][0]
    impnameout += 'imp:'+imp+'_layer_'+host+'['+str(ilayerout[i])+']'
    #print(zimpout[i], zhostout[i], ilayerout[i], efout[i], impnameout)
    remove_wrong.append(impnameout)
print(remove_wrong)
#"""

In [None]:
# Initial axes use columns name0x / name0y and colour column 'color_'+name0c;
# the resulting dict becomes the Bokeh data source shared by the plots below.
dictdata.update(
    x=dictdata[name0x],
    y=dictdata[name0y],
    color=dictdata['color_' + name0c],
)

source_scatter = ColumnDataSource(data=dictdata)

In [None]:
#mapping = dictdata['Zimp']*1000000+1000*dictdata['Zhost']+dictdata['ilayer']+dictdata['EFshift']
#np.save('mapping_data_ok.npy', mapping)

In [None]:
# Markdown-formatted headline with the number of keys in `imp_properties_all`
totimp_num_text = '#### Total number of impurities: {}'.format(len(list(imp_properties_all.keys())))
print(totimp_num_text)

In [None]:
#update progress bar (advance the loading bar created with init_pbar at the top of the notebook)
pbar.value += 1

In [None]:
height_plot = 500
height_hist = 150
width_plot = 700
width_hist = 200 #150

left_padding = 80

# Main scatter plot. Its 'x', 'y' and 'color' columns in `source_scatter` are
# swapped client-side by the Select widgets created at the end of this cell.
scatterplot = bk.figure(tools='pan,box_zoom,wheel_zoom,reset,lasso_select,box_select,undo,redo,save',
                        #title = 'Impurities in Sb2Te3 (6QL)',
                        plot_width=width_plot, plot_height=height_plot,
                        min_width=width_plot, min_height=height_plot,
                        max_width=2*width_plot, max_height=2*height_plot,
                        min_border_left=left_padding, sizing_mode='scale_both'
                       )

# Hover tooltips; every '@name' field must be a column of `source_scatter`.
hover = HoverTool()
hover.tooltips = [
    ("Imp", " @Zimp[@Zhost] (@ilayer)"),
    # ("Etot_diff", "@Etot_diff"),
    # ("Etot_Ry", "@etot_Ry"),
    ("EF shift", "@EFshift"),
    # ("rms", "@rms"),
    ("spin", "@spinmom"),
    ("orbmom", "@orbmom"),
    # ("Ecore", "@ecore"),
    ("charge doping", "@charge_doping"),
    ("DOS in gap", "@DOS_in_gap"),
]
scatterplot.add_tools(hover)

scatterplot.scatter('x', 'y', source=source_scatter, fill_alpha=0.1, color='color', size=10)
scatterplot.xaxis.axis_label = name0x
scatterplot.yaxis.axis_label = name0y

# JS callbacks: copy the selected column into 'x' / 'y' / 'color' and update the
# matching axis label and tick formatter (looked up in `formatters` by column name).
callback_change_x = CustomJS(args=dict(source=source_scatter, plot=scatterplot, formatters=formatters),
                             code="""
                                  /// get value of active button
                                  const val = cb_obj.value
                                  /// change x column
                                  source.data['x'] = source.data[val]
                                  source.change.emit()
                                  /// update x axis label
                                  plot.below[0].axis_label = val
                                  plot.below[0].formatter = formatters[val]
                                  plot.change.emit()
                                  """
                            )

callback_change_y = CustomJS(args=dict(source=source_scatter, plot=scatterplot, formatters=formatters),
                             code="""
                                  /// get value of active button
                                  const val = cb_obj.value
                                  /// change y column
                                  source.data['y'] = source.data[val]
                                  source.change.emit()
                                  /// update y axis label
                                  plot.left[0].axis_label = val
                                  plot.left[0].formatter = formatters[val]
                                  plot.change.emit()
                                  """
                            )

callback_change_color = CustomJS(args=dict(source=source_scatter),
                                 code="""
                                      /// get value of active button
                                      const val = cb_obj.value
                                      /// change color column
                                      source.data['color'] = source.data['color_'+val]
                                      source.change.emit()
                                      """
                                )

# columns selectable as x / y / colour (disabled: 'Etot_diff', 'ecore')
options_select = ['Zimp', 'Zhost', 'ilayer', 'EFshift', 'rms', 'spinmom', 'orbmom',
                  'etot_Ry', 'charge_doping', 'DOS_in_gap']

select_x = Select(title='Choose X:', value=name0x, options=options_select,
                  width=100, max_height=50)
select_y = Select(title='Choose Y:', value=name0y, options=options_select,
                  width=100, max_height=50)
select_color = Select(title='Choose color:', value=name0c, options=['None']+options_select,
                      width=100, max_height=50)

# The widget `callback=` property is deprecated (removed in Bokeh 2.0); register the
# JS callbacks via js_on_change, as done for the histogram callbacks.
select_x.js_on_change('value', callback_change_x)
select_y.js_on_change('value', callback_change_y)
select_color.js_on_change('value', callback_change_color)

layout = column(row(select_x, select_y, select_color), scatterplot, sizing_mode='scale_both')

#show(layout)

In [None]:
# add tap tool to scatter plot
# NOTE(review): the CustomJS below is created without any code, so the tap
# callback currently does nothing — placeholder to be filled in?

from bokeh.models import TapTool
callback_tap = CustomJS()
tap = TapTool(callback=callback_tap)

scatterplot.add_tools(tap)

In [None]:
#update progress bar (advance the loading bar created with init_pbar at the top of the notebook)
pbar.value += 1

In [None]:
# Create the marginal histograms: a vertical one (right of the scatter plot,
# sharing its y range) and a horizontal one (above it, sharing its x range).


dict_vhist_all, dict_hhist_all = {}, {}
for name, data in source_scatter.data.items():
    if 'color' in name:
        continue  # colour columns hold colour strings, nothing to histogram
    counts, edges = np.histogram(data, bins=50)
    # vertical histogram: bars grow to the right, bins run along y
    dict_vhist_all[name] = {'left': 0, 'right': counts, 'bottom': edges[:-1], 'top': edges[1:]}
    # horizontal histogram: bars grow upwards, bins run along x
    dict_hhist_all[name] = {'left': edges[:-1], 'right': edges[1:], 'top': counts}

# initial histogram data for the default x / y columns
dict_vhist = {key: dict_vhist_all[name0y][key] for key in ('bottom', 'top', 'right')}
dict_hhist = {key: dict_hhist_all[name0x][key] for key in ('left', 'top', 'right')}

src_hist_y = ColumnDataSource(data=dict_vhist)
src_hist_x = ColumnDataSource(data=dict_hhist)

# link y_range by passing scatterplot.y_range:
yhist = bk.figure(tools=['hover', 'box_zoom', 'reset'], toolbar_location='right',
                  plot_width=width_hist, plot_height=height_plot, #sizing_mode='scale_both',
                  #min_width=width_hist, min_height=height_plot,
                  #max_width=int(1.5*width_hist), max_height=2*height_plot
                  x_axis_location="below", y_axis_location="right", y_range=scatterplot.y_range)
# link x_range by passing scatterplot.x_range:
xhist = bk.figure(tools=['hover', 'box_zoom', 'reset'], toolbar_location='right',
                  plot_width=width_plot, plot_height=height_hist, #sizing_mode='scale_both',
                  #min_width=width_plot, min_height=height_hist,
                  #max_width=2*width_plot, max_height=int(1.5*height_hist), sizing_mode='stretch_both',
                  x_axis_location="above", y_axis_location="left",
                  min_border_left=left_padding, x_range=scatterplot.x_range)

yhist.quad(source=src_hist_y, left=0, bottom='bottom', top='top', right='right',
           fill_alpha=0.6, color='navy')
xhist.quad(source=src_hist_x, bottom=0, left='left', right='right', top='top',
           fill_alpha=0.6, color='navy')


xhist.xaxis.axis_label = name0x
xhist.yaxis.axis_label = 'counts'
yhist.xaxis.axis_label = 'counts'
yhist.yaxis.axis_label = name0y

xhist.y_range.start = 0
yhist.x_range.start = 0

# JS callbacks: swap in the precomputed histogram of the selected column and
# update the axis label/formatter (y axis on the right, x axis above).
callback_change_y_hist = CustomJS(args=dict(source=src_hist_y, dict_hist_all=dict_vhist_all,
                                            select=select_y, plot=yhist, formatters=formatters
                                           ),
                                  code="""
                                       /// get value of active button
                                       const val = select.value
                                       /// change y histogram
                                       source.data['bottom'] = dict_hist_all[val]['bottom']
                                       source.data['top'] = dict_hist_all[val]['top']
                                       source.data['right'] = dict_hist_all[val]['right']
                                       source.change.emit()
                                       /// update y axis label
                                       plot.right[0].axis_label = val
                                       plot.right[0].formatter = formatters[val]
                                       plot.change.emit()
                                       """
                                 )

callback_change_x_hist = CustomJS(args=dict(source=src_hist_x, dict_hist_all=dict_hhist_all,
                                            select=select_x, plot=xhist, formatters=formatters
                                           ),
                                  code="""
                                       /// get value of active button
                                       const val = select.value
                                       /// change x histogram
                                       source.data['left'] = dict_hist_all[val]['left']
                                       source.data['right'] = dict_hist_all[val]['right']
                                       source.data['top'] = dict_hist_all[val]['top']
                                       source.change.emit()
                                       /// update x axis label
                                       plot.above[0].axis_label = val
                                       plot.above[0].formatter = formatters[val]
                                       plot.change.emit()
                                       """
                                 )

# each histogram callback reads only its own selector, so attach it only there
select_x.js_on_change('value', callback_change_x_hist)
select_y.js_on_change('value', callback_change_y_hist)


select_buttons = column(select_x, select_y, select_color, sizing_mode='scale_width')

# 3x4 grid: x histogram on top, scatter plot below it, y histogram to the right
gspec = pn.GridSpec(sizing_mode='stretch_both', max_width=2000, max_height=1500)
gspec[0,0:3] = xhist
gspec[0, 3 ] = select_buttons
gspec[1:3,0:3] = scatterplot
gspec[1:3, 3] = yhist

layout_with_hist = gspec

#show(layout_with_hist)


#needs selenium and phantomjs:
#from bokeh.io import export_svgs
#
#scatterplot.output_backend = "svg"
#export_svgs(scatterplot, filename="scatter_plot.svg")

In [None]:
#update progress bar (advance the loading bar created with init_pbar at the top of the notebook)
pbar.value += 1

In [None]:
# collect tabs and show plots

#tab1 = Panel(child=layout_periodic_table, title='Periodic table')
#tab2 = Panel(child=layout_with_hist, title='Scatter plot')
#layout = Tabs(tabs=[tab1, tab2])

# Stack the page vertically: heading + periodic table, heading + scatter plot with histograms.
sections = [
    pn.pane.Markdown("## Average values for different impurity configurations"),
    layout_periodic_table,
    pn.pane.Markdown("## Scatter plot"),
    layout_with_hist,
]
layout = pn.Column(*sections, sizing_mode='stretch_both')

#show(layout)
imps_overview = pn.Row(layout)
#imps_overview

In [None]:
# last expression of the cell: renders the panel overview (periodic table + scatter plot)
imps_overview

In [None]:
#show(layout_periodic_table)

In [None]:
#show(layout_with_hist)

In [None]:
#close progressbar (the loading bar created with init_pbar at the top of the notebook)
pbar.close()

In [None]:
def plot_spin_vs_layers(zlim, tit, savepath=None):
    """Plot the impurity spin moments stored in `dictdata` against the impurity layer.

    Only species whose atomic number lies strictly between ``zlim[0]`` and
    ``zlim[1]`` and that have at least one |spin moment| > 1e-3 are drawn.
    Each species gets its own colour and a small horizontal offset. The marker
    encodes the value in dictdata['EFshift'] (0: diamond, -0.2: circle,
    -0.4: square), and only the circles carry the legend label (the atomic number).
    Sb layers are shaded violet and Te layers orange.

    Parameters
    ----------
    zlim : sequence of two numbers
        Exclusive lower and upper bound on the impurity atomic number.
    tit : str
        Figure title. If it contains '3d', dotted horizontal reference lines
        are drawn at 2.7 and 3.15.
    savepath : str, optional
        If given, the figure is additionally written to this path with ``savefig``.
    """
    from matplotlib.pyplot import (title, figure, plot, show, legend, axvspan, xlim,
                                   xticks, axvline, text, xlabel, ylabel, axhline)

    # colors for symbols; cycled when there are more species than colours
    # (indexing the colour list directly raised IndexError beyond 10 species)
    import matplotlib as mpl
    colors = mpl.rcParams['axes.prop_cycle'].by_key()['color']

    figure(figsize=(15,6))
    zimp = dictdata['Zimp']
    ilayer = dictdata['ilayer']
    spinmom = dictdata['spinmom']
    i=0; dx=0.15; dx0=-3*dx
    # sorted: deterministic colour/offset assignment in order of atomic number
    for z in sorted(set(zimp)):
        m = np.where(zimp==z)
        if any(abs(spinmom[m])>10**-3) and zlim[0] < z < zlim[1]:
            color = colors[i % len(colors)]
            x, y = ilayer[m]+dx0+i*dx, spinmom[m]
            # NOTE(review): these are the EFshift values as currently stored in dictdata
            # (the colour cell adds +0.2 to them in place)
            efshift = dictdata['EFshift'][m]
            for ief in [0, -0.2, -0.4]:
                # tolerant float comparison instead of exact ==
                m2 = np.where(np.isclose(efshift, ief))
                if ief == 0:
                    plot(x[m2], y[m2], 'd', color=color)
                elif ief == -0.2:
                    plot(x[m2], y[m2], 'o', label=z, color=color)
                else:
                    plot(x[m2], y[m2], 's', color=color)
            i+=1
    # highlight Sb atoms
    Sb_atoms = [12, 16, 22, 26, 32, 36]
    for iSb in Sb_atoms:
        axvspan(iSb-1.15, iSb+1.15, color='violet', alpha=0.2)
    # highlight Te atoms
    Te_atoms = [10, 14, 18, 20, 24, 28, 30, 34, 38, 40]
    for iTe in Te_atoms:
        axvspan(iTe-0.85, iTe+0.85, color='orange', alpha=0.2)
    # separator lines for QLs
    #for iQL in [9,19,29,39]:
    #    axvline(iQL, color='k', lw=1)
    if '3d' in tit:
        axhline(3.15, linestyle='dotted', color='green')
        axhline(2.7, linestyle='dotted', color='green')

    xlim(8,40.5)
    legend(loc=1)
    # label the layer positions with the atom type (and the QL they belong to)
    xticks([10,12,14,16,18,20,22,24,26,28,30,32,34,36,38], ['Te', 'Sb', 'Te\n1st QL', 'Sb', 'Te', 'Te', 'Sb', 'Te\n2nd QL', 'Sb', 'Te', 'Te', 'Sb', 'Te\n3rd QL', 'Sb', 'Te'])
    # axis labels
    ylabel('spin moment (mu_B)')
    xlabel('')
    # add text to indicate surface
    text(8.3, 1.8, 'surface', rotation=90)
    # add title
    title(tit)

    if savepath is not None:
        from matplotlib.pyplot import savefig
        savefig(savepath)

    show()

# plot for 3d impurities 
#plot_spin_vs_layers(zlim=[20,40], tit='3d impurities', savepath='/Users/ruess/Downloads/spin_mom_3d_imps.png')


In [None]:
# plot for 4d impurities
#plot_spin_vs_layers(zlim=[38,50],tit='4d impurities', savepath='/Users/ruess/Downloads/spin_mom_4d_imps.png')


In [None]:
def plot_spin_vs_layers_upright(zlim, tit, savepath=None):
    """Upright spin-moment-vs-layer plot of the impurities stored in `dictdata`.

    The spin moment is on the x axis and the layer on the inverted y axis, with
    large fonts and open markers. Only species whose atomic number lies strictly
    between ``zlim[0]`` and ``zlim[1]`` and that have at least one
    |spin moment| > 1e-3 are drawn. The marker shape encodes dictdata['EFshift']
    (0: triangle, -0.2: circle, -0.4: square). Only the circles carry a legend
    label: the element symbol for V-Co, otherwise the atomic number.

    NOTE: sets the global ``rcParams['font.size']`` to 20, which also affects
    every figure created afterwards.

    Parameters
    ----------
    zlim : sequence of two numbers
        Exclusive lower and upper bound on the impurity atomic number.
    tit : str
        Plot title (the call to ``title`` is currently disabled). If it contains
        '3d', dashed vertical reference lines are drawn at 2.7 and 3.15.
    savepath : str, optional
        If given, the figure is additionally written to this path with ``savefig``.
    """
    from matplotlib.pyplot import (title, figure, plot, show, legend, axhspan, ylim,
                                   yticks, axvline, text, xlabel, ylabel, xlim)

    markersize=16
    from matplotlib import rcParams
    #default font size=10, increase here
    rcParams.update({'font.size': 20})

    #alpha = 0.08
    alpha = 0.10

    # legend labels for the V-Co series; other species fall back to their atomic number
    # (the plain dict lookup raised KeyError for any other Z)
    symbol_labels = {23: 'V', 24: 'Cr', 25: 'Mn', 26: 'Fe', 27: 'Co'}

    # colors for symbols; cycled when there are more species than colours
    import matplotlib as mpl
    colors = mpl.rcParams['axes.prop_cycle'].by_key()['color']

    figure(figsize=(8,15))
    zimp = dictdata['Zimp']
    ilayer = dictdata['ilayer']
    spinmom = dictdata['spinmom']
    i=0; dx=0.15; dx0=-3*dx
    # sorted: deterministic colour/offset assignment in order of atomic number
    for z in sorted(set(zimp)):
        m = np.where(zimp==z)
        if any(abs(spinmom[m])>10**-3) and zlim[0] < z < zlim[1]:
            STYLE = {'markersize': markersize, 'color': colors[i % len(colors)],
                     'fillstyle': 'none', 'markeredgewidth': 2}
            x, y = ilayer[m]+dx0+i*dx, spinmom[m]
            # NOTE(review): these are the EFshift values as currently stored in dictdata
            # (the colour cell adds +0.2 to them in place)
            efshift = dictdata['EFshift'][m]
            for ief in [0, -0.2, -0.4]:
                # tolerant float comparison instead of exact ==
                m2 = np.where(np.isclose(efshift, ief))
                if ief == -0.20:
                    label = symbol_labels.get(z, str(z))
                    plot(y[m2], x[m2], 'o', label=label, **STYLE)
                elif ief == 0:
                    plot(y[m2], x[m2], '^', **STYLE)
                else:
                    plot(y[m2], x[m2], 's', **STYLE)
            i+=1
    # highlight Sb atoms
    Sb_atoms = [12, 16, 22, 26, 32, 36]
    for iSb in Sb_atoms:
        axhspan(iSb-1.15, iSb+1.15, color='violet', alpha=alpha)
    # highlight Te atoms
    Te_atoms = [10, 14, 18, 20, 24, 28, 30, 34, 38, 40]
    for iTe in Te_atoms:
        axhspan(iTe-0.85, iTe+0.85, color='orange', alpha=alpha)
    # separator lines for QLs
    #for iQL in [9,19,29,39]:
    #    axhline(iQL, color='k', lw=1)
    if '3d' in tit:
        axvline(3.15, linestyle='dashed', color=colors[0], lw=2)
        axvline(2.7, linestyle='dashed', color=colors[0], lw=2)
        #axvline(2.59, linestyle='dashed', color=colors[0], lw=2)
        #axvline(2.79, linestyle='dashed', color=colors[0], lw=2)

    #ylim(39.5,8)
    ylim(39.5,5.8)  # inverted: surface at the top
    xlim(1.3,4.6)
    legend(loc=9, ncol=3)
    # label the layer positions with the atom type
    yticks([10,12,14,16,18,20,22,24,26,28,30,32,34,36,38], ['Te', 'Sb', 'Te', 'Sb', 'Te',
                                                            'Te', 'Sb', 'Te', 'Sb', 'Te',
                                                            'Te', 'Sb', 'Te', 'Sb', 'Te'])
    # axis labels (raw string: '\,' is a LaTeX thin space, not a Python escape)
    xlabel(r'spin moment ($\,\mu_B$)')
    ylabel('')
    # add text to indicate surface and QLs
    #N = 29
    N = 10
    text(4.80, 8.5, 'surf.', rotation=0, horizontalalignment='center', verticalalignment='center')
    text(4.72, 14, '|'+N*'-'+' 1st QL '+N*'-'+'|', rotation=270, horizontalalignment='center', verticalalignment='center')
    text(4.72, 24, '|'+N*'-'+' 2nd QL '+N*'-'+'|', rotation=270, horizontalalignment='center', verticalalignment='center')
    text(4.72, 34, '|'+N*'-'+' 3rd QL '+N*'-'+'|', rotation=270, horizontalalignment='center', verticalalignment='center')
    # add title
    #title(tit)

    if savepath is not None:
        from matplotlib.pyplot import savefig
        savefig(savepath)

    show()

# upright plot for 3d impurities, used in paper
#plot_spin_vs_layers_upright(zlim=[22,28], tit='3d impurities', 
#                            savepath='/Users/ruess/Downloads/spin_mom_3d_imps_upright.png')