# GUI to plot results from various codes

In [None]:
#NOTES:
# To install (in a python 3 virtual environment):
# - pip install numpy matplotlib ipywidgets tqdm
# - pip install widget_periodictable
# - jupyter nbextension enable --py widget_periodictable

In [None]:
# Use interactive plots (10x faster than creating PNGs)
# NOTE(review): `notebook` selects matplotlib's nbagg backend, which presumably needs the
# classic Jupyter Notebook front-end; under JupyterLab `%matplotlib widget` (ipympl) may be
# required instead -- confirm for the target environment.
%matplotlib notebook

In [None]:
# In notebook mode the default font size is too large for the many small panels
# drawn below: shrink it globally for every figure.
import matplotlib

font = dict(size=7)
matplotlib.rc('font', **font)

In [None]:
# The next cell prevents output cells from getting a vertical scrollbar.
# This is important for the final plot, as we have a lot of plots in the same notebook cell.

In [None]:
%%javascript
// Replace the OutputArea hook that decides whether an output should become scrollable,
// so that it always answers "no" and long outputs are shown in full.
// NOTE(review): relies on the global `IPython` JS object of the classic Notebook front-end;
// presumably unavailable in JupyterLab -- confirm.
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

In [None]:
import json
import os
import numpy as np
import pylab as pl
import ipywidgets as ipw
import widget_periodictable
import matplotlib.colors as mcolors
import quantities_for_comparison as qc

In [None]:
## Functions (prettifiers)

In [None]:
def symmetrical_colormap(cmap_settings, new_name=None, n_colors=128):
    '''
    Build a colormap made of the given colormap preceded by its mirror image.

    The new colormap runs through the colors of the original one in reverse order and then
    in the original order. Its centre therefore corresponds to the low end of the original
    colormap, and both extremes correspond to its high end. With a symmetric range
    (vmin=-x, vmax=x), values of equal magnitude get the same color regardless of sign.

    :param cmap_settings: tuple of positional arguments for ``get_cmap``, e.g.
        ``("Reds", None)`` (colormap name, size of the lookup table or None).
    :param new_name: name of the new colormap; defaults to ``"sym_" + cmap_settings[0]``.
    :param n_colors: number of colors sampled from the original colormap for each half
        (controls the smoothness of the result; 128 is fine for heatmaps).
    :return: a ``matplotlib.colors.LinearSegmentedColormap``.
    '''
    # `pl.get_cmap` (pyplot) is used instead of `pl.cm.get_cmap`: matplotlib.cm.get_cmap was
    # deprecated in matplotlib 3.7 and removed in 3.9.
    cmap = pl.get_cmap(*cmap_settings)
    if not new_name:
        new_name = "sym_" + cmap_settings[0]  # ex: 'sym_Blues'

    colors_r = cmap(np.linspace(0, 1, n_colors))  # 'right' half: the original color order
    colors_l = colors_r[::-1]                      # 'left' half: the same colors, reversed

    # Concatenate the two halves and build the new colormap from the color list
    colors = np.vstack((colors_l, colors_r))
    return mcolors.LinearSegmentedColormap.from_list(new_name, colors)

def get_conf_nice(configuration_string):
    """
    Return `configuration_string` with every digit typeset as a LaTeX (mathtext) subscript.

    Example: 'X2O3' -> 'X$_2$O$_3$'; characters that are not digits are copied unchanged.
    """
    digit_chars = "0123456789"
    return "".join(
        f"$_{char}$" if char in digit_chars else char
        for char in configuration_string
    )

In [None]:
# Load the results of every <code> that has a results-<code>.json file in the current folder.
# Only files whose name contains "unaries" are kept; the dictionary key is the part of the
# file name between the prefix and the suffix.
file_prefix = 'results-'
file_suffix = '.json'

results_folder = os.curdir

code_results = {}
for fname in os.listdir(results_folder):
    if not (fname.startswith(file_prefix) and fname.endswith(file_suffix)):
        continue
    label = fname[len(file_prefix):-len(file_suffix)]
    if "unaries" not in fname:
        continue
    with open(os.path.join(results_folder, fname)) as fhandle:
        code_results[label] = json.load(fhandle)

In [None]:
# Define the colors of the EoS curves and associate one color with each <code>.
# Codes are taken in sorted order (the same order used by the code selector and by
# plot_for_element), so a code's color does not depend on the arbitrary order in which
# os.listdir() returned the results files. Colors are recycled when there are more codes
# than colors.
colors = ['#1f78b4', '#33a02c', '#e31a1c', '#ff7f00', '#6a3d9a', '#b15928', '#a6cee3', '#b2df8a', '#fb9a99', '#fdbf6f', '#cab2d6', '#ffff99']
color_code_map = {
    plugin_name: colors[index % len(colors)]
    for index, plugin_name in enumerate(sorted(code_results))
}

In [None]:
# Map a name (key of the dictionary) to a function (value of the dictionary), allowing the selection
# of the quantity to use for the heatmap plot that compares the codes.
# The keys are the options of the `ipw_comp_quantity` selector. plot_for_element calls the selected
# function as func(V0_1, B0_1, B01_1, V0_2, B0_2, B01_2, prefactor, weight_b0, weight_b1) and
# converts the result with float(). NOTE(review): for the "Prefactor*..." entries the result is
# presumably scaled by the `Prefactor` widget value -- see quantities_for_comparison for the
# exact definitions.
quantity_for_comparison_map = {
    "delta_per_formula_unit (meV)": qc.delta, 
    #"delta_per_atom": qc.delta_per_atom,
    "Prefactor*epsilon": qc.epsilon,
    "Prefactor*B0_rel_diff": qc.B0_rel_diff, 
    "Prefactor*V0_rel_diff": qc.V0_rel_diff,
    "Prefactor*B1_rel_diff": qc.B1_rel_diff,
    "Prefactor*|relerr_vec(weight_b0,weight_b1)|": qc.rel_errors_vec_length
}

In [None]:
## Main function that creates the plots

In [None]:
def _merge_range(current_range, low, high):
    """Return the smallest (min, max) tuple covering both `current_range` (None = empty) and [low, high]."""
    if current_range is None:
        return (low, high)
    return (min(low, current_range[0]), max(high, current_range[1]))


def plot_for_element(code_results, element, configuration, selected_codes, selected_quantity, prefactor, b0_w, b1_w, axes, max_val=None):
    """
    Plot the EOS data of every code for one element/configuration, plus a heatmap comparing the codes.

    Codes are processed in sorted order. For each code, the EOS points (key 'eos_data') and the
    Birch-Murnaghan curve (key 'BM_fit_data') are drawn on ``axes[0]``. Volumes and energies are
    divided by that code's own ``qc.get_volume_scaling_to_formula_unit`` factor. For every pair of
    *selected* codes that have fit parameters, ``quantity_for_comparison_map[selected_quantity]``
    is evaluated. The results are shown as an annotated heatmap on ``axes[1]``.

    :param code_results: dict mapping code label -> loaded results JSON (keys used:
        'num_atoms_in_sim_cell', and optionally 'eos_data' / 'BM_fit_data', all indexed by
        '<element>-<configuration>').
    :param element: chemical symbol, e.g. 'Si'.
    :param configuration: configuration string, e.g. 'X/FCC'.
    :param selected_codes: codes to highlight. The others are drawn in faint black and left out
        of the heatmap.
    :param selected_quantity: key of ``quantity_for_comparison_map``.
    :param prefactor: forwarded to the comparison function.
    :param b0_w: forwarded to the comparison function (weight for B0).
    :param b1_w: forwarded to the comparison function (weight for B1).
    :param axes: pair of matplotlib axes: [0] for EOS points and fits, [1] for the heatmap.
    :param max_val: symmetric limit of the heatmap color scale. If None (or 0), the largest
        absolute value of the heatmap is used.
    """
    data_key = f'{element}-{configuration}'

    # The EOS points are plotted straight away in the codes loop. The fitted curves are plotted
    # afterwards, so they all share the same volume range.
    # Each entry: (code_name, BM fit parameters, energy shift, volume scaling, plot params).
    fit_data = []

    # Both ranges are per formula unit (already divided by each code's scaling), so codes that
    # use different simulation cells are merged consistently.
    dense_volume_range = None  # Will eventually be a tuple with (min_volume, max_volume)
    y_range = None

    for code_name in sorted(code_results):
        reference_plugin_data = code_results[code_name]

        try:
            eos_data = reference_plugin_data['eos_data'][data_key]
        except KeyError:
            # This code does not have eos data, but it might have the birch murnaghan parameters
            # (for instance reference data sets).
            eos_data = None

        try:
            ref_BM_fit_data = reference_plugin_data['BM_fit_data'][data_key]
        except KeyError:
            # Fit data missing (maybe the fit failed). The points are still plotted, using the
            # minimum energy as reference (see below).
            ref_BM_fit_data = None

        # Skip only if neither EOS points nor fit are present
        if eos_data is None and ref_BM_fit_data is None:
            continue

        # Looked up only for codes that have data for this element/configuration. Codes without
        # it would otherwise abort the whole plot with a KeyError.
        scaling_ref_plugin = qc.get_volume_scaling_to_formula_unit(
            reference_plugin_data['num_atoms_in_sim_cell'][data_key],
            element, configuration
        )

        # Widen the volume range so that it includes the relevant data of every code
        if eos_data is not None:
            volumes, energies = np.array(eos_data).T.tolist()
            dense_volume_range = _merge_range(
                dense_volume_range,
                min(volumes) / scaling_ref_plugin,
                max(volumes) / scaling_ref_plugin
            )
        if ref_BM_fit_data is not None:
            v0_per_formula_unit = ref_BM_fit_data['min_volume'] / scaling_ref_plugin
            dense_volume_range = _merge_range(
                dense_volume_range, v0_per_formula_unit * 0.97, v0_per_formula_unit * 1.03
            )

        # Plotting style: unselected codes are drawn in faint black, in the background
        is_selected = code_name in selected_codes
        curve_color = color_code_map[code_name] if is_selected else '#000000'
        alpha = 1. if is_selected else 0.1

        # Set energy shift (important to compare among codes!!!)
        warning_string = ''
        if ref_BM_fit_data is not None:
            if ref_BM_fit_data.get('E0') is None:
                # All fit parameters but E0 are present (hopefully only for fit-only data sets):
                # zero is the right choice. Work on a copy so the loaded results are not modified.
                ref_BM_fit_data = dict(ref_BM_fit_data, E0=0.)
            energy_shift = ref_BM_fit_data['E0']
        else:
            # No fit data: use the minimum of the energies. Not correct in general, because the
            # exact minimum may lie between grid points or even outside the sampled range.
            warning_string = " (WARNING NO FIT!)"
            energy_shift = min(energies)

        # Collect the fitting data to plot later (only then is the final range known)
        if ref_BM_fit_data is not None:
            fit_entry = (code_name, ref_BM_fit_data, energy_shift, scaling_ref_plugin, {
                # Show the label on the fit only if no eos data is visible (one and only one label
                # per code), and never for hidden plots
                'label': f'{code_name}{warning_string}' if eos_data is None and is_selected else None,
                'alpha': alpha,
                'curve_color': curve_color,
                'selected': is_selected,
            })
            if is_selected:
                fit_data.append(fit_entry)
            else:
                # Unselected codes go first, so their curves are drawn below the selected ones
                fit_data.insert(0, fit_entry)

        # Plot EOS points straight away
        if eos_data is not None:
            scaled_vol = np.array(volumes) / scaling_ref_plugin
            scaled_en = np.array(energies) / scaling_ref_plugin
            scaled_en_shift = energy_shift / scaling_ref_plugin
            # Don't show the label for hidden plots
            label = f'{code_name}{warning_string}' if is_selected else None
            axes[0].plot(scaled_vol, scaled_en - scaled_en_shift, 'o', color=curve_color, label=label, alpha=alpha)
            if is_selected:
                # y_range is already per formula unit: merge it without rescaling it again
                y_range = _merge_range(
                    y_range, scaled_en.min() - scaled_en_shift, scaled_en.max() - scaled_en_shift
                )

    # No code has data for this element/configuration: nothing to plot
    if dense_volume_range is None:
        return

    dense_volumes = np.linspace(dense_volume_range[0], dense_volume_range[1], 100)

    # Plot all fits. Each fit is evaluated on its own simulation-cell volumes, then rescaled.
    for _, fit, energy_shift, scaling, plot_params in fit_data:
        fit_energies = qc.birch_murnaghan(
            V=dense_volumes * scaling,
            E0=fit['E0'],
            V0=fit['min_volume'],
            B0=fit['bulk_modulus_ev_ang3'],
            B01=fit['bulk_deriv']
        )
        axes[0].plot(
            dense_volumes,
            (np.array(fit_energies) - energy_shift) / scaling, '-',
            color=plot_params['curve_color'],
            alpha=plot_params['alpha'] * 0.5,
            label=plot_params['label']
        )

    # Compare every selected code with every other selected code
    comparison_func = quantity_for_comparison_map[selected_quantity]
    selected_fits = [entry for entry in fit_data if entry[4]['selected']]
    heatmap_codes = [entry[0] for entry in selected_fits]
    heatmap_rows = []
    for _, fit_1, _, scaling_1, _ in selected_fits:
        row = []
        for _, fit_2, _, scaling_2, _ in selected_fits:
            value = comparison_func(
                fit_1['min_volume'] / scaling_1, fit_1['bulk_modulus_ev_ang3'], fit_1['bulk_deriv'],
                fit_2['min_volume'] / scaling_2, fit_2['bulk_modulus_ev_ang3'], fit_2['bulk_deriv'],
                prefactor, b0_w, b1_w
            )
            row.append(round(float(value), 2))
        heatmap_rows.append(row)

    # Plot the heatmap (only if at least one selected code has fit data)
    if heatmap_rows:
        to_plot = np.array(heatmap_rows)
        # Symmetric color scale; taken from the whole matrix when no explicit limit is given
        maxim = max_val or float(np.abs(to_plot).max())
        axes[1].imshow(to_plot, cmap=symmetrical_colormap(("Reds", None)), vmin=-maxim, vmax=maxim)
        tick_positions = np.arange(len(heatmap_codes))
        axes[1].set_xticks(tick_positions)
        axes[1].set_yticks(tick_positions)
        axes[1].set_xticklabels(heatmap_codes)
        axes[1].set_yticklabels(heatmap_codes)
        # Rotate the tick labels and set their alignment.
        pl.setp(axes[1].get_xticklabels(), rotation=35, ha="right", rotation_mode="anchor")
        # Annotate each cell with its value
        for i, row in enumerate(heatmap_rows):
            for j, value in enumerate(row):
                axes[1].text(j, i, value, ha="center", va="center", color="black")

    # Some labels and visual choices.
    # Set the y range to (visible) points only, if at least one of the selected codes had EOS data points
    if y_range is not None:
        # Make sure that the minimum is zero (or negative if needed)
        axes[0].set_ylim((min(y_range[0], 0), y_range[1]))
    axes[0].legend(loc='upper center')
    axes[0].set_xlabel("Cell volume per formula unit ($\\AA^3$)")
    axes[0].set_ylabel("$E-TS$ per formula unit (eV)")
    conf_nice = get_conf_nice(configuration)
    axes[0].set_title(f"{element} ({conf_nice})")
    axes[1].set_title(f"{element} ({conf_nice}) -- {selected_quantity}")

In [None]:
## Widgets definition and main call to the plot function

In [None]:
# Numeric parameters forwarded to the comparison functions (only some quantities use them)
ipw_pref = ipw.FloatText(value=100, description='Prefactor', disabled=False)
ipw_b0 = ipw.FloatText(value=0.1, description='weight_b0', disabled=False)
ipw_b1 = ipw.FloatText(value=0.01, description='weight_b1', disabled=False)

# Codes to highlight in the plots; all of them are selected at start-up
ipw_codes = ipw.SelectMultiple(
    options=sorted(code_results),
    value=sorted(code_results),
    rows=15,
    description='Code plugins',
    disabled=False,
)

# Quantity used to build the comparison heatmap
ipw_comp_quantity = ipw.Select(
    options=sorted(quantity_for_comparison_map),
    value="Prefactor*V0_rel_diff",
    rows=15,
)

# Periodic table used to pick the element to plot (Si selected initially)
ipw_periodic = widget_periodictable.PTableWidget(
    states=1,
    selected_colors=["#a6cee3"],
    disabled_elements=['Bk', 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr'],
    selected_elements={'Si': 0},
)

# Output area hosting the figure, and the figure state reused across calls to replot()
ipw_output = ipw.Output()

fig = None
axes_list = None

def replot():
    """
    Redraw all plots for the element(s) currently selected in the periodic table.

    The figure has 4 rows x 2 columns, one row per configuration: EOS on the left, heatmap on
    the right. It is created inside `ipw_output` on the first call and reused afterwards,
    clearing all its axes before drawing again.
    """
    global fig, axes_list
    with ipw_output:
        if fig is not None:
            for eos_axes, heatmap_axes in axes_list:
                eos_axes.clear()
                heatmap_axes.clear()
        else:
            ipw_output.clear_output(wait=True)
            fig, axes_list = pl.subplots(4, 2, figsize=(10, 20), gridspec_kw={"hspace": 0.5})

        configurations = ['X/SC', 'X/BCC', 'X/FCC', 'X/Diamond']
        for element in sorted(ipw_periodic.selected_elements):
            # Each row of axes_list is a pair: [0] hosts EoS and fit, [1] the comparison heatmap
            for configuration, axes in zip(configurations, axes_list):
                plot_for_element(
                    code_results=code_results,
                    element=element,
                    configuration=configuration,
                    selected_codes=ipw_codes.value,
                    selected_quantity=ipw_comp_quantity.value,
                    prefactor=ipw_pref.value,
                    b0_w=ipw_b0.value,
                    b1_w=ipw_b1.value,
                    axes=axes
                )

def on_codes_change(event):
    """
    Observer of `ipw_codes`: replot when the set of selected codes changes.

    observe() is registered without `names`, so it is notified of every trait change. A
    selection widget also reports e.g. `index`/`label` for the same user action. Only `value`
    changes are handled, to avoid several identical redraws.
    """
    if event['type'] == 'change' and event['name'] == 'value':
        replot()


def on_quantity_change(event):
    """Observer of `ipw_comp_quantity`: replot when the selected comparison quantity changes."""
    if event['type'] == 'change' and event['name'] == 'value':
        replot()


def on_pref_or_weights_change(event):
    """Observer of the prefactor/weight widgets: replot when one of their values changes."""
    if event['type'] == 'change' and event['name'] == 'value':
        replot()
        
# Selection of the periodic table before the last click, used by on_element_select to work
# out which element was newly clicked.
last_selected = ipw_periodic.selected_elements
def on_element_select(event):
    """
    Observer of `ipw_periodic`: keep at most one element selected and redraw the plots.

    Only changes of the `selected_elements` trait are handled. The logic hinges on a selection
    made only of the 'Du' key. NOTE(review): presumably a placeholder state that PTableWidget
    goes through while the user clicks -- confirm against widget_periodictable.

    - When the selection becomes exactly {'Du': ...}, the previous selection is saved in
      `last_selected`.
    - When the selection leaves the {'Du': ...} state, and it does not contain exactly one
      element, it is replaced by one element absent from `last_selected`, or emptied if there
      is none. `last_selected` is then updated. In both sub-cases replot() is called.
    """
    global last_selected

    if event['name'] == 'selected_elements' and event['type'] == 'change':
        if tuple(event['new'].keys()) == ('Du', ):
            last_selected = event['old']
        elif tuple(event['old'].keys()) == ('Du', ):
            # Leaving the placeholder state: this is where the user's click takes effect
            if len(event['new']) != 1:
                # Reset the selection only if it is not already a single element; the
                # re-assignment below notifies this observer again, and the guard avoids infinite loops
                newly_selected = set(event['new']).difference(last_selected)
                # If this is empty it's ok, unselect all
                # If there is more than one, that's weird... to avoid problems, anyway, pick an arbitrary one (set order)
                if newly_selected:
                    ipw_periodic.selected_elements = {list(newly_selected)[0]: 0}
                else:
                    ipw_periodic.selected_elements = {}
                # To have the correct 'last' value for next calls
                last_selected = ipw_periodic.selected_elements
            replot()

# Connect each widget to the callback that redraws the plots
ipw_codes.observe(on_codes_change)
ipw_comp_quantity.observe(on_quantity_change)
for weight_widget in (ipw_pref, ipw_b0, ipw_b1):
    weight_widget.observe(on_pref_or_weights_change)
ipw_periodic.observe(on_element_select)

# Link to the notebook describing the comparison quantities
link = ipw.HTML(value="<a href=./descr.ipynb target='_blank'>here</a>")

# Lay out the controls: codes and quantity side by side, then the parameters, then the table
spacer = ipw.HTML("&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp")
display(ipw.HBox([ipw_codes, spacer, ipw.Label('Quantity for comparison'), ipw_comp_quantity]))
display(ipw.HBox([ipw.Label("For a description of the quantities for comparison, click"), link]))
display(ipw.Label('The following values are used only for some of the `Quantity for comparison` listed above. Look at the quantities name to understand the relevant values for each quantity.'))
display(ipw.HBox([ipw_pref, ipw_b0, ipw_b1]))
display(ipw_periodic)

In [None]:
# Display the output widget (where replot() draws the figure) in a different cell, so if
# there is scrolling, it's independent of the top widgets
display(ipw_output)

In [None]:
# Trigger first plot, using the initial widget values (e.g. 'Si' selected in the periodic table)
replot()

In [None]:
import io
import tqdm
import contextlib

def plot_all(max_val=None, only_elements=None):
    """
    Save one PNG per element (``<element>.png`` in the current folder).

    Each image holds the EOS plots and comparison heatmaps for the SC, BCC, FCC and Diamond
    configurations. The codes, comparison quantity, prefactor and weights are read from the
    current values of the widgets.

    :param max_val: symmetric color-scale limit for the heatmaps, forwarded to plot_for_element
        (None: each heatmap uses the largest absolute value of its own data).
    :param only_elements: optional container of element symbols. If given, only those elements
        (among the supported ones) are plotted.
    """
    all_elements = [
        "Ac", "Ag", "Al", "Am", "Ar", "As", "At", "Au", "B", "Ba", "Be",
        "Bi", "Br", "C", "Ca", "Cd", "Ce", "Cl", "Cm", "Co", "Cr", "Cs",
        "Cu", "Dy", "Er", "Eu", "F", "Fe", "Fr", "Ga", "Gd", "Ge", "H",
        "He", "Hf", "Hg", "Ho", "I", "In", "Ir", "K", "Kr", "La", "Li",
        "Lu", "Mg", "Mn", "Mo", "N", "Na", "Nb", "Nd", "Ne", "Ni", "Np",
        "O", "Os", "P", "Pa", "Pb", "Pd", "Pm", "Po", "Pr", "Pt", "Pu", "Ra",
        "Rb", "Re", "Rh", "Rn", "Ru", "S", "Sb", "Sc", "Se", "Si", "Sm",
        "Sn", "Sr", "Ta", "Tb", "Tc", "Te", "Th", "Ti", "Tl", "Tm", "U",
        "V", "W", "Xe", "Y", "Yb", "Zn", "Zr"
    ]

    if only_elements is not None:
        all_elements = [elem for elem in all_elements if elem in only_elements]

    configurations = ['X/SC', 'X/BCC', 'X/FCC', 'X/Diamond']
    num_rows = 3
    num_cols = 4
    # Each panel is a pair of adjacent axes: (EOS plot, comparison heatmap)
    num_panels = num_rows * num_cols // 2
    assert num_cols % 2 == 0
    assert len(configurations) <= num_panels

    def get_axes_from_list(pos, axes_list):
        """Return the (EOS axes, heatmap axes) pair of the `pos`-th panel, filling the grid row by row."""
        num_columns = len(axes_list[0])
        panels_per_row = num_columns // 2
        assert num_columns % 2 == 0
        assert pos < len(axes_list) * panels_per_row

        row = pos // panels_per_row
        column = (pos % panels_per_row) * 2
        return (axes_list[row, column], axes_list[row, column + 1])

    # Avoid interactive creation of figures; remember the current mode to restore it afterwards
    was_interactive = pl.isinteractive()
    pl.ioff()

    try:
        for element in tqdm.tqdm(all_elements):
            # Any text printed while the figure is created is discarded
            silenced_output = io.StringIO()
            with contextlib.redirect_stdout(silenced_output):
                fig, axes_list = pl.subplots(num_rows, num_cols, figsize=(5 * num_cols, 4.5 * num_rows), gridspec_kw={"hspace": 0.5})

            try:
                for idx, configuration in enumerate(configurations):
                    axes = get_axes_from_list(pos=idx, axes_list=axes_list)
                    plot_for_element(
                        code_results=code_results,
                        element=element,
                        configuration=configuration,
                        selected_codes=ipw_codes.value,
                        selected_quantity=ipw_comp_quantity.value,
                        prefactor=ipw_pref.value,
                        b0_w=ipw_b0.value,
                        b1_w=ipw_b1.value,
                        axes=axes,
                        max_val=max_val
                    )
                # Hide the panels not used by any configuration instead of leaving empty frames
                for idx in range(len(configurations), num_panels):
                    for unused_axes in get_axes_from_list(pos=idx, axes_list=axes_list):
                        unused_axes.set_axis_off()
                fig.savefig(f"{element}.png")
            finally:
                # Always release the figure, even if plotting or saving fails
                pl.close(fig)
    finally:
        # Restore the interactive mode that was active before the call
        if was_interactive:
            pl.ion()

In [None]:
## Examples - you can set the max value for the color bar, or let each plot have a different maximum defined by
## the maximum value of the data for that element and configuration.

## In addition, especially for testing, you can decide to create the files only for a few elements
## Finally, you can decide which metric (and parameters) to use directly above, with the widgets.

## Running the function will generate files, in the current folder, named Ac.png, Ag.png, Al.png, ...

#plot_all(max_val=2., only_elements= ["Ac", "Ag", "Al", "Am", "Ar"])
#plot_all(max_val=1.)

In [None]:
## Otherwise, uncomment this to have a button generate the plots

#def on_generate_click(button):
#    button.disabled = True    
#    try:
#        plot_all(max_val=2.)
#    finally:
#        button.disabled = False
#
#generate_button = ipw.Button(description="Generate all PNGs", status="success")
#generate_button.on_click(on_generate_click)
#display(generate_button)