## Package Import, Defining Functions and Global Variables

In [1]:
### Todo ###
############    

    # Writeup
    # Introduce dummy data; remove the GII category
    
    # Configurable marker size and opacity
    # Make points clickable/selectable by transition metal?
    # Add more data (which may be NaN) to the TOOLTIP, e.g.
    #   DFT bandgap
    
    # Add more elements to the HTML output (description, pictures), e.g. by importing a JS script into the HTML
    # Draggable threshold (instead of the fixed y <= 0.2 shaded region)
    # Selection on the histogram (maybe write the histogram code in JS)
    # Configurable colors for predicted d states? Rows of the periodic table?
    
import numpy as np
import pandas as pd
from scipy.stats import gaussian_kde
import matplotlib.pyplot as plt

from bokeh.plotting import *
from bokeh.layouts import row, column, Spacer
from bokeh.models.annotations import Label
from bokeh.models import ColumnDataSource, CDSView, CustomJS, Slider, Button, Div,\
                         Select, LegendItem, Legend, Title
from bokeh.models.filters import Filter, GroupFilter
from bokeh import events
from bokeh.io import show
from bokeh.models.tools import HoverTool
from bokeh.palettes import Category10

output_notebook()
output_file("materials_data_explorer.html", title='High-Dimensional Materials Data Explorer')

### Set GLOBAL VARIABLES ##########################################################################
# Toolbars: scatter panels allow selection, KDE panels only navigation
SCATTER_TOOLS = "pan,wheel_zoom,box_zoom,box_select,lasso_select"
KDE_TOOLS = "pan,wheel_zoom,box_zoom"
FIG_TITLE = "High-Dimensional Materials Data Explorer"
SIZING_MODE = 'stretch_both'
STD_FONT_SIZE = 14
KDE_HEIGHT = 180

# Anion symbol, plural name and scatter marker, listed in their original (unsorted) order.
# Category10 colors are assigned in this same order; then names, colors and markers are all
# permuted by `order` (N, O, F, P, S, Cl, Se = increasing atomic number) so each anion keeps
# its color/marker pairing.
order = [5,3,4,6,0,2,1] #order by atomic number
_anion_table = [
    ('S',  'Sulfides',   'x'),
    ('Se', 'Selenides',  'plus'),
    ('Cl', 'Chlorides',  'triangle'),
    ('O',  'Oxides',     'circle'),
    ('F',  'Fluorides',  'dash'),
    ('N',  'Nitrides',   'asterisk'),
    ('P',  'Phosphides', 'inverted_triangle'),
]
_base_colors = Category10[len(_anion_table)]
ANION_DICT = {_anion_table[idx][0]: _anion_table[idx][1] for idx in order}
ANION_COLORS = [_base_colors[idx] for idx in order]
ANION_MARKERS = [_anion_table[idx][2] for idx in order]

# Short column key -> human-readable label (used for axis labels and the Select widgets)
COL_LENGTHEN = {
    'mm_dist':'Metal-Metal Distance (\u212B)', 'normed_dist':'Normalized M-M Distance',
    'delta':'M-M Distance - Alloy Bond Length', 'gii':'Global Instability Index',
    'sg_num':'Spacegroup #', 'tm_row':'Transition Metal Row', 'mn':'Mendeleev Number',
    'sg_sym':'Spacegroup','tm':'Transition Metal', 'ionic_r':'Ionic Radius (\u212B)', 
    'oxi':'Oxidation State', 'anions':'Anion', 'formula':'Chemical Formula',
    'd_state':'# of $d$ Electrons', 'cn':'Transition Metal Coord. #'}
# Reverse lookup: human-readable label -> short column key (used by the JS callbacks)
COL_SHORTEN = dict(zip(COL_LENGTHEN.values(), COL_LENGTHEN.keys()))

# Columns selectable as plot axes, and their labels in the same order
AXIS_OPTIONS = ['gii','mm_dist','normed_dist','delta','ionic_r', 'sg_num','tm_row','mn'] 
PRETTY_AXIS_OPTIONS = [COL_LENGTHEN[short_key] for short_key in AXIS_OPTIONS]
INIT_X = 'mm_dist'
INIT_Y = 'gii'

# Fixed (start, end) axis range shown for each selectable column
RANGE_DICT = dict(
    gii=(-0.05, 2),
    mm_dist=(1.4, 6),
    normed_dist=(0.4, 2.5),
    delta=(-1.8, 5),
    ionic_r=(0.2, 1.6),
    sg_num=(0, 230),
    tm_row=(3.9, 6.1),
    mn=(17, 78),
)
###################################################################################################


def format_scatter_plot(plot, x_col=None, y_col=None):
    """Apply the standard look to a main scatter panel (modifies `plot` in place).

    Draws black reference lines through x = 0 and y = 0, labels both axes with the
    long-form names from COL_LENGTHEN, applies the shared font sizes and sets the border.

    Args:
        plot: Bokeh figure to style.
        x_col: short column key used for the x-axis label; defaults to INIT_X.
        y_col: short column key used for the y-axis label; defaults to INIT_Y.
    """
    from bokeh.models import Span  # local import keeps this helper self-contained

    x_col = INIT_X if x_col is None else x_col
    y_col = INIT_Y if y_col is None else y_col

    ### Reference axes through the origin
    # Spans are annotations, not glyph renderers: they cover the whole plot at any zoom level
    # and are never hit-tested by hover/selection tools, unlike finite line glyphs.
    plot.add_layout(Span(location=0, dimension='width', line_color='black'))   # y = 0
    plot.add_layout(Span(location=0, dimension='height', line_color='black'))  # x = 0
    ## Shaded area (disabled)
    # plot.patch([-999,-999,999,999],[-999,0.2,0.2,-999],alpha=0.2, line_width=0, 
    #         legend_label='Y \u2264 0.2')
    ### Fonts
    plot.xaxis.axis_label = COL_LENGTHEN[x_col]
    plot.yaxis.axis_label = COL_LENGTHEN[y_col]
    plot.xaxis.axis_label_text_font_style = "normal"
    plot.yaxis.axis_label_text_font_style = "normal"
    plot.xaxis.axis_label_text_font_size = f'{STD_FONT_SIZE}pt'
    plot.yaxis.axis_label_text_font_size = f'{STD_FONT_SIZE}pt'
    plot.xaxis.major_label_text_font_size = f'{int(STD_FONT_SIZE*0.9)}pt'
    plot.yaxis.major_label_text_font_size = f'{int(STD_FONT_SIZE*0.9)}pt'
    ### Spacing
    plot.min_border = 10
    

def format_kde_plot(plot, vertical=False):
    """Style a marginal KDE panel in place.

    Horizontal panel (vertical=False): sits above the scatter plot, has a fixed height,
    stretches in width, hides its x axis and gets the enlarged title font.
    Vertical panel (vertical=True): sits beside the scatter plot, has a fixed width,
    stretches in height and hides its y axis.
    In both cases the density axis is labelled 'Renormalized KDE' and has no minor ticks.
    """
    ### Title, sizing, axes
    if vertical:
        plot.width = KDE_HEIGHT
        plot.sizing_mode = 'stretch_height'
        density_axis, data_axis = plot.xaxis, plot.yaxis
    else:
        plot.title.text_font_size = f'{int(STD_FONT_SIZE*1.2)}pt'
        plot.height = KDE_HEIGHT
        plot.sizing_mode = 'stretch_width'
        density_axis, data_axis = plot.yaxis, plot.xaxis
    density_axis.minor_tick_line_color = None
    data_axis.visible = False
    density_axis.axis_label = 'Renormalized KDE'

    ### Fonts (identical for both axes)
    label_size = f'{STD_FONT_SIZE}pt'
    tick_size = f'{int(STD_FONT_SIZE*0.9)}pt'
    for axis in (plot.xaxis, plot.yaxis):
        axis.axis_label_text_font_style = "normal"
        axis.axis_label_text_font_size = label_size
        axis.major_label_text_font_size = tick_size
    plot.min_border = 10

## Data Import, Preparation, Filtering

In [2]:
### Import data
raw_df = pd.read_csv('data/features_icsd_tmetal-compounds.csv')

### Filter out heteroanion, hetero-transition-metal compounds, Rh compounds (mistake in BVPs)
# All masks are row-wise, so combining them equals applying the filters one after another.
# Comparisons against NaN evaluate to False for eq/lt and True for ne, which reproduces the
# previous step-by-step filters exactly (e.g. a row with a missing tm1/tm2 is dropped).
keep = (
    raw_df.heteroanion.eq(False)    # single-anion compounds only
    & raw_df.tm1.eq(raw_df.tm2)     # single transition metal (hetero-TM compounds removed)
    & raw_df.anions.ne('Sb')        # remove the SINGLE antimonide
    & raw_df.tm1.ne('Rh')           # Filter Rh compounds (BVP param has mistake)
    & raw_df.gii.lt(10)             # Filter extremely high GII
)

### Create new columns, select only useful columns, shorten names
# `.assign` on the filtered frame returns a new DataFrame, so no column is written into a
# slice of another frame (which raised pandas' SettingWithCopyWarning before).
USEFUL_COLS = ['formula', 'sg_sym', 'sg_num',
               'anions', 'tm1', 'tm_row', 'cn1',
               'mm_dist', 'delta', 'normed_dist', 'ionic_r_1',
               'oxi1', 'd_state1', 'pred_d1',
               'mn1', 'gii', 'n_elems']
df = (
    raw_df[keep]
    .assign(tm_row=lambda d: d.n1 + 1)  # presumably n1 is the d-shell quantum number (3d -> row 4); TODO confirm
    .loc[:, USEFUL_COLS]
    .dropna()
    .rename(columns={'tm1':'tm', 'cn1':'cn','ionic_r_1':'ionic_r','oxi1':'oxi',
                     'd_state1':'d_state', 'mn1':'mn'})
)
source = ColumnDataSource(data=df)

### Generate KDEs for all columns, anions
kde_dict = {} # {(col, anion, 'vals' | 'kde'): array}

def calc_kde(col_data, n_points=500):
    """Gaussian KDE of one column, evaluated on an even grid and max-normalized.

    Args:
        col_data: 1-D numeric data (pandas Series, list or array).
        n_points: number of evenly spaced grid points between min(col_data) and
            max(col_data), inclusive. Defaults to 500.

    Returns:
        (vals, kde): the evaluation grid and the estimated density at each grid
        point, scaled so that kde.max() == 1 (max-normalization for easier display).

    Raises:
        ValueError: if col_data has fewer than two values or all values are equal;
            no Gaussian kernel can be fitted to such data.
    """
    data = np.asarray(col_data, dtype=float)
    # Fail early with a readable message instead of scipy's singular-covariance error
    if data.size < 2 or data.min() == data.max():
        raise ValueError(
            f'calc_kde needs at least two distinct values to estimate a density '
            f'(got {data.size} value(s))')
    kernel = gaussian_kde(data)
    vals = np.linspace(data.min(), data.max(), n_points)
    kde = kernel(vals)
    kde = kde / kde.max()  # max-normalization for easier display
    return (vals, kde)

# Split the data by anion once, then evaluate the KDE of every (column, anion) pair.
# Keys are inserted as (col, anion, 'vals') then (col, anion, 'kde'), column-major.
anion_frames = {symbol: df[df.anions == symbol] for symbol in ANION_DICT}
for col in AXIS_OPTIONS:
    for anion in ANION_DICT.keys():
        grid, density = calc_kde(anion_frames[anion][col])
        kde_dict[(col, anion, 'vals')] = grid
        kde_dict[(col, anion, 'kde')] = density

# The tuple keys become MultiIndex columns; Bokeh flattens them by joining with '_'
# (e.g. 'mm_dist_O_vals'), which is the naming the KDE lines and JS callbacks use.
kde_df = pd.DataFrame.from_dict(kde_dict)
kde_source = ColumnDataSource(data=kde_df)

## Plotting

In [3]:
### Build two identical panels (left and right). Each panel is a scatter plot plus a horizontal
### (top) and a vertical (side) marginal KDE plot that share the scatter plot's x / y range.

def _make_scatter_panel(x_col=INIT_X, y_col=INIT_Y):
    """Create a scatter figure with one renderer per anion.

    Returns (figure, renderers). Renderers follow ANION_DICT order and are kept so the
    JS callbacks can re-point their fields and the legend can toggle them.
    """
    plot = figure(tools=SCATTER_TOOLS, sizing_mode=SIZING_MODE, x_range=RANGE_DICT[x_col],
                  y_range=RANGE_DICT[y_col])
    renderers = []
    for anion, color, marker in zip(ANION_DICT.keys(), ANION_COLORS, ANION_MARKERS):
        view = CDSView(source=source, filters=[GroupFilter(column_name='anions', group=anion)])
        renderers.append(plot.scatter(x=x_col, y=y_col, source=source, view=view,
                                      fill_alpha=0.01, line_alpha=0.5,
                                      color=color, marker=marker, size=8))
    return plot, renderers


def _make_kde_panel(col, linked_range, vertical=False):
    """Create a marginal KDE figure with one line per anion for column `col`.

    Horizontal (vertical=False): `linked_range` is shared as the x range; grid on x,
    density on y. Vertical (vertical=True): `linked_range` is shared as the y range;
    grid on y, density on x. Lines read the '<col>_<anion>_vals' / '<col>_<anion>_kde'
    columns of kde_source, whose names the JS callbacks parse.
    Returns (figure, line renderers) in ANION_DICT order.
    """
    if vertical:
        plot = figure(tools=KDE_TOOLS, y_range=linked_range)
    else:
        plot = figure(tools=KDE_TOOLS, x_range=linked_range)
    lines = []
    for anion, color in zip(ANION_DICT.keys(), ANION_COLORS):
        grid_field, density_field = f'{col}_{anion}_vals', f'{col}_{anion}_kde'
        if vertical:
            coords = dict(x=density_field, y=grid_field)
        else:
            coords = dict(x=grid_field, y=density_field)
        lines.append(plot.line(**coords, line_width=3, line_alpha=0.5, color=color,
                               source=kde_source))
    return plot, lines


### Left panel
scatter_plot, points_ls = _make_scatter_panel()
hkde_plot, hkde_ls = _make_kde_panel(INIT_X, scatter_plot.x_range)
vkde_plot, vkde_ls = _make_kde_panel(INIT_Y, scatter_plot.y_range, vertical=True)

### Right panel (same construction, independent ranges and renderers)
rscatter_plot, rpoints_ls = _make_scatter_panel()
rhkde_plot, rhkde_ls = _make_kde_panel(INIT_X, rscatter_plot.x_range)
rvkde_plot, rvkde_ls = _make_kde_panel(INIT_Y, rscatter_plot.y_range, vertical=True)
    

### Create universal legend
# One entry per anion. Each entry lists that anion's renderers in all six plots (both scatters
# and the four KDE panels), so with click_policy='hide' one click toggles the anion everywhere.
legend_items = []
for symbol, *anion_renderers in zip(ANION_DICT.keys(), points_ls, hkde_ls, vkde_ls,
                                    rpoints_ls, rhkde_ls, rvkde_ls):
    legend_items.append(LegendItem(label=symbol, renderers=anion_renderers))

# The legend must live inside a figure, so host it in an otherwise-empty "dummy" figure
dum_fig = figure(height=60, min_width=420, toolbar_location=None, outline_line_alpha=0,
                 sizing_mode='stretch_width')
hidden_parts = (dum_fig.grid[0], dum_fig.ygrid[0], dum_fig.xaxis[0], dum_fig.yaxis[0])
for part in hidden_parts:
    part.visible = False

# Legend items may only refer to glyphs that are among the hosting figure's renderers
all_renderers = points_ls + hkde_ls + vkde_ls + rpoints_ls + rhkde_ls + rvkde_ls
dum_fig.renderers += all_renderers
# Move the dummy figure's x range far away from all data so none of those glyphs show here
dum_fig.x_range.end = 100005
dum_fig.x_range.start = 100000

shared_legend = Legend(items=legend_items, click_policy='hide', location='top_center',
                       border_line_alpha=0, glyph_height=40, glyph_width=30,
                       label_text_font_size=f'{int(0.9*STD_FONT_SIZE)}pt',
                       orientation='horizontal')
dum_fig.add_layout(shared_legend)
                            
    
### Make callbacks
# JavaScript bodies for the axis <Select> callbacks. These are plain (non-f) strings, so braces
# are written single; doubled braces would be emitted literally as nested JS blocks.
# Each CustomJS concatenates a scatter snippet with a KDE snippet. The KDE snippet reuses the
# `column` variable declared by the scatter snippet, so always join them in that order.
# Bokeh supplies cb_obj (the Select). The callback args must provide col_dict, points_ls,
# range_dict, ax, source, kde_source, x_range + hkde_ls (x axis) or y_range + vkde_ls (y axis).
xaxis_code="""
        var col_name = cb_obj.value;
        var column = col_dict[col_name];
        // Change scatter plot axes
        points_ls.forEach(function(points) {
            points.glyph.x.field = column;
        });
        x_range.start = range_dict[column][0];
        x_range.end = range_dict[column][1];
        ax.axis_label = col_name;
        source.change.emit();
"""

yaxis_code="""
        var col_name = cb_obj.value;
        var column = col_dict[col_name];
        // Change scatter plot axes
        points_ls.forEach(function(points) {
            points.glyph.y.field = column;
        });
        y_range.start = range_dict[column][0];
        y_range.end = range_dict[column][1];
        ax.axis_label = col_name;
        source.change.emit();"""

# KDE line fields look like '<column>_<anion>_vals' / '<column>_<anion>_kde'; the anion is the
# second-to-last '_'-separated token, which also works for columns containing '_'.
hkde_axis_code="""
        // Change horiz KDE axis
        hkde_ls.forEach(function(line) {
            var current_col = line.glyph.x.field;
            var anion = current_col.split('_')[current_col.split('_').length - 2];
            var new_x = `${column}_${anion}_vals`;
            var new_y = `${column}_${anion}_kde`;
            line.glyph.x.field = new_x;
            line.glyph.y.field = new_y;
        });
        kde_source.change.emit();"""

vkde_axis_code="""
        // Change vertical KDE axis
        vkde_ls.forEach(function(line) {
            var current_col = line.glyph.y.field;
            var anion = current_col.split('_')[current_col.split('_').length - 2];
            var new_y = `${column}_${anion}_vals`;
            var new_x = `${column}_${anion}_kde`;
            line.glyph.x.field = new_x;
            line.glyph.y.field = new_y;
        });
        kde_source.change.emit();"""

def _make_axis_callback(plot, points, kde_lines, vertical=False):
    """Build the CustomJS that runs when an axis <Select> of `plot` changes.

    vertical=False wires the x axis: the scatter glyphs' x field, plot.x_range, the x-axis
    label and the horizontal KDE lines. vertical=True wires the y axis and the vertical
    KDE lines. The JS derives each KDE line's anion from its field name itself, so no
    anion argument is passed.
    """
    if vertical:
        axis_args = dict(ax=plot.yaxis[0], y_range=plot.y_range, vkde_ls=kde_lines)
        code = yaxis_code + vkde_axis_code
    else:
        axis_args = dict(ax=plot.xaxis[0], x_range=plot.x_range, hkde_ls=kde_lines)
        code = xaxis_code + hkde_axis_code
    return CustomJS(args=dict(source=source, kde_source=kde_source, points_ls=points,
                              col_dict=COL_SHORTEN, range_dict=RANGE_DICT, **axis_args),
                    code=code)


callbackx = _make_axis_callback(scatter_plot, points_ls, hkde_ls)
callbacky = _make_axis_callback(scatter_plot, points_ls, vkde_ls, vertical=True)
rcallbackx = _make_axis_callback(rscatter_plot, rpoints_ls, rhkde_ls)
rcallbacky = _make_axis_callback(rscatter_plot, rpoints_ls, rvkde_ls, vertical=True)

### Axis data selector tools
def _axis_select(label, init_col, on_change):
    """Dropdown of PRETTY_AXIS_OPTIONS, preselected to `init_col`, running `on_change` on 'value'."""
    widget = Select(title=label, value=COL_LENGTHEN[init_col], options=PRETTY_AXIS_OPTIONS,
                    min_width=20, max_width=500, sizing_mode='stretch_width')
    widget.js_on_change('value', on_change)
    return widget


xaxis_select = _axis_select("X axis:", INIT_X, callbackx)
yaxis_select = _axis_select("Y axis:", INIT_Y, callbacky)

rxaxis_select = _axis_select("X axis:", INIT_X, rcallbackx)
ryaxis_select = _axis_select("Y axis:", INIT_Y, rcallbacky)

### Tooltips
# Restrict hovering to the scatter renderers. With the default renderers='auto', the tool also
# fires on any other renderer in the plot (e.g. reference lines), whose data sources lack the
# @formula/@gii/... columns, so the tooltip shows "???" fields.
# Each plot also gets its own HoverTool instance instead of one tool shared by two toolbars.
SCATTER_TOOLTIP = """
    <div>
        <h3><center>@formula</center></h3>
        <div><strong>Spacegroup:    </strong>@sg_sym (@sg_num)</div>
        <div><strong>M-M Dist.:     </strong>@mm_dist \u212B</div>
        <div><strong>GII:           </strong>@gii</div>
        <div><strong>Clustered TM:  </strong>@tm</div>
        <div><strong>TM Coord. #:   </strong>@cn</div>
    </div>
"""
hover = HoverTool(renderers=points_ls, tooltips=SCATTER_TOOLTIP)
rhover = HoverTool(renderers=rpoints_ls, tooltips=SCATTER_TOOLTIP)
scatter_plot.add_tools(hover)
rscatter_plot.add_tools(rhover)


### Layout
# Each panel gets its own corner Spacer rather than one Spacer instance placed in both KDE
# rows; reusing a single layout model in two places of one layout tree is fragile in Bokeh.
spacer = Spacer(height=KDE_HEIGHT, width=KDE_HEIGHT)
rspacer = Spacer(height=KDE_HEIGHT, width=KDE_HEIGHT)

format_scatter_plot(scatter_plot)
format_kde_plot(hkde_plot)
format_kde_plot(vkde_plot, vertical=True)
format_scatter_plot(rscatter_plot)
format_kde_plot(rhkde_plot)
format_kde_plot(rvkde_plot, vertical=True)

title = Div(text=f'<h1>{FIG_TITLE}</h1>', align='center', height_policy='min', margin=(-10,0,-10,0))

# Each panel: top row = horizontal KDE + corner spacer; middle row = scatter + vertical KDE
top = Row(hkde_plot, spacer, sizing_mode='stretch_width')
middle = Row(scatter_plot, vkde_plot, sizing_mode='stretch_width')

rtop = Row(rhkde_plot, rspacer, sizing_mode='stretch_width')
rmiddle = Row(rscatter_plot, rvkde_plot, sizing_mode='stretch_width')

# Bottom row: left axis selects | shared legend | right axis selects
lcontrol = Row(yaxis_select, xaxis_select, sizing_mode='stretch_width')
rcontrol = Row(ryaxis_select, rxaxis_select, sizing_mode='stretch_width')
controls = Row(lcontrol, dum_fig, rcontrol, sizing_mode='stretch_width')
lplots = Column(top, middle, sizing_mode='stretch_width')
rplots = Column(rtop, rmiddle, sizing_mode='stretch_width')
layout = Column(title, Row(lplots, rplots, sizing_mode='stretch_width'), controls, sizing_mode='stretch_width')

show(layout)