# GUI to plot results from various codes

In [None]:
#NOTES:
# To install (in a python 3 virtual environment):
# - pip install numpy matplotlib ipywidgets
# - pip install widget_periodictable
# - jupyter nbextension enable --py widget_periodictable

In [None]:
# Use interactive plots (10x faster than creating PNGs): replot() creates the figure once
# and afterwards only clears and redraws its axes.
# NOTE(review): `%matplotlib notebook` (nbagg backend) is presumably only supported by the
# classic Jupyter Notebook; JupyterLab / Notebook 7 would need `%matplotlib widget` (ipympl).
%matplotlib notebook

In [None]:
# In notebook mode the default font is too large for these small, densely packed plots,
# so reduce its size (family and weight keep their defaults).
import matplotlib

font = dict(size=7)

matplotlib.rc('font', **font)

In [None]:
# The next cell stops Jupyter from putting long cell outputs into a vertically scrolling box.
# This matters for the final figure, since many plots are shown in the same output cell.

In [None]:
%%javascript
// Override the output area's scroll check so that it always answers "no": long outputs
// (the tall grid of plots) are then shown in full instead of in a scrolling box.
// NOTE(review): relies on the global `IPython` object of the classic Notebook front-end;
// presumably has no effect (or errors) in JupyterLab.
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

In [None]:
import json
import os
import numpy as np
import pylab as pl
import ipywidgets as ipw
import widget_periodictable
import matplotlib.colors as mcolors

In [None]:
## Functions (prettifiers)

In [None]:
def symmetrical_colormap(cmap_settings, new_name=None):
    '''
    Build a new colormap by concatenating a mirrored copy of a colormap with itself.

    The source colormap is sampled at 128 evenly spaced points; the reversed samples are
    followed by the original ones. For a sequential map running from light to dark
    (e.g. 'Reds') the result goes dark -> light -> dark, with the lightest colour in the
    middle, which suits a colour range centred on zero.

    :param cmap_settings: tuple of positional arguments for ``get_cmap``, i.e. ``(name,)``
        or ``(name, lut)``; ``cmap_settings[0]`` must be the colormap name.
    :param new_name: name of the returned colormap; defaults to ``'sym_' + cmap_settings[0]``.
    :return: a ``matplotlib.colors.LinearSegmentedColormap``.
    '''
    # ``matplotlib.cm.get_cmap`` (what ``pl.cm.get_cmap`` resolves to) was removed in
    # Matplotlib 3.9; ``pyplot.get_cmap`` (re-exported by pylab) accepts the same
    # (name, lut) arguments and exists in both old and new releases.
    cmap = pl.get_cmap(*cmap_settings)
    if not new_name:
        new_name = "sym_" + cmap_settings[0]  # ex: 'sym_Blues'

    # Number of samples taken from the source colormap (its "resolution"); 128 is fine.
    n = 128

    colors_r = cmap(np.linspace(0, 1, n))  # original order: right half of the new map
    colors_l = colors_r[::-1]  # mirrored copy: left half of the new map

    colors = np.vstack((colors_l, colors_r))
    return mcolors.LinearSegmentedColormap.from_list(new_name, colors)

def get_conf_nice(configuration_string):
    """
    Typeset a configuration string (e.g. 'X2O3') for matplotlib's mathtext.

    Every decimal digit becomes a LaTeX subscript, e.g. ``'X2O3' -> 'X$_2$O$_3$'``;
    all other characters are copied unchanged.
    """
    digits = "0123456789"
    return "".join(
        f"$_{character}$" if character in digits else character
        for character in configuration_string
    )

In [None]:
## Functions (maths 1 - utilities)

In [None]:
def birch_murnaghan(V,E0,V0,B0,B01):
    """
    Third-order Birch-Murnaghan energy-volume equation of state.

        E(V) = E0 + 9/16 * B0 * V0 * [ (r - 1)**3 * B01 + (r - 1)**2 * (6 - 4 r) ],
        r = (V0 / V)**(2/3)

    :param V: volume; a scalar or a numpy array (evaluated element-wise).
    :param E0: energy at the equilibrium volume.
    :param V0: equilibrium volume.
    :param B0: bulk modulus (in energy/volume units consistent with E0 and V0).
    :param B01: pressure derivative of the bulk modulus.
    :return: the energy, with the same shape as ``V``.
    """
    compression = (V0 / V) ** (2. / 3.)
    strain = compression - 1.
    bracket = strain**3 * B01 + strain**2 * (6. - 4. * compression)
    return E0 + 9. / 16. * B0 * V0 * bracket


def intE12sq(v0w,b0w,b1w,v0f,b0f,b1f,V1,V2):
    """
    Definite integral over [V1, V2] of (E1(V) - E2(V))**2 dV.

    E1 and E2 are Birch-Murnaghan curves with parameters (v0w, b0w, b1w) and
    (v0f, b0f, b1f); E0 does not enter (see antiderE12sq).
    """
    params = (v0w, b0w, b1w, v0f, b0f, b1f)
    return antiderE12sq(*params, V2) - antiderE12sq(*params, V1)

def antiderE12sq(v0w,b0w,b1w,v0f,b0f,b1f,V):
    """
    Antiderivative with respect to V of (E1(V) - E2(V))**2.

    E1 and E2 are Birch-Murnaghan curves (see birch_murnaghan) with parameters
    (V0, B0, B0') = (v0w, b0w, b1w) and (v0f, b0f, b1f) respectively. There is no E0
    argument: both curves are taken without their E0 offset, so only their shapes matter.

    Only differences F(V2) - F(V1) are meaningful (see intE12sq).

    :return: the antiderivative evaluated at volume V.
    """
    # Closed-form expression (a sum of powers of V); kept verbatim.
    antider = (81*(\
            6*b0w*b0f*(-16 + 3*b1w)*(-16 + 3*b1f)*V*v0w*(v0w/V)**(2/3)*v0f*(v0f/V)**(2/3) - \
            2*b0w*b0f*(-14 + 3*b1w)*(-16 + 3*b1f)*V*v0w*(v0w/V)**(4/3)*v0f*(v0f/V)**(2/3) - \
            2*b0w*b0f*(-16 + 3*b1w)*(-14 + 3*b1f)*V*v0w*(v0w/V)**(2/3)*v0f*(v0f/V)**(4/3) + \
            (6*b0w*b0f*(-14 + 3*b1w)*(-14 + 3*b1f)*V*v0w*(v0w/V)**(4/3)*v0f*(v0f/V)**(4/3))/5. + \
            V*(b0w*(-6 + b1w)*v0w - b0f*(-6 + b1f)*v0f)**2 - \
            (b0w*(-4 + b1w)*v0w**3 - b0f*(-4 + b1f)*v0f**3)**2/(3.*V**3) + \
            (3*b0f*(v0f/V)**(7/3)*(-2*b0w*(-14 + 3*b1f)*v0w*(-7*(-6 + b1w)*V**2 + \
            (-4 + b1w)*v0w**2) \
            - 7*b0f*(424 + 5*b1f*(-32 + 3*b1f))*V**2*\
            v0f + 2*b0f*(-4 + b1f)*(-14 + 3*b1f)*\
            v0f**3))/7. - \
            (3*b0f*(v0f/V)**(5/3)*\
            (-2*b0w*(-16 + 3*b1f)*v0w*\
            (5*(-6 + b1w)*V**2 + (-4 + b1w)*v0w**2) \
            + 10*b0f*(-6 + b1f)*(-16 + 3*b1f)*V**2*\
            v0f + b0f*(324 + 5*b1f*(-28 + 3*b1f))*\
            v0f**3))/5. +\
            (4*b0w**2*(124 + 5*(-10 + b1w)*b1w)*v0w**4 - \
            2*b0w*b0f*(-4 + b1w)*(-6 + b1f)*v0w**3*\
            v0f - 2*b0w*b0f*(-6 + b1w)*(-4 + b1f)*v0w*\
            v0f**3 + 4*b0f**2*(124 + 5*(-10 + b1f)*b1f)*v0f**4)/V + \
            (3*b0w*(v0w/V)**(7/3)*\
            (-7*b0w*(424 + 5*b1w*(-32 + 3*b1w))*V**2*\
            v0w + 2*b0w*(-4 + b1w)*(-14 + 3*b1w)*v0w**3 + \
            2*b0f*(-14 + 3*b1w)*v0f*\
            (7*(-6 + b1f)*V**2 - (-4 + b1f)*v0f**2)))/7. - \
            (3*b0w*(v0w/V)**(5/3)*\
            (10*b0w*(-6 + b1w)*(-16 + 3*b1w)*V**2*v0w + \
            b0w*(324 + 5*b1w*(-28 + 3*b1w))*v0w**3 - \
            2*b0f*(-16 + 3*b1w)*v0f*\
            (5*(-6 + b1f)*V**2 + (-4 + b1f)*v0f**2)))/5.))/256.

    return antider

def intEdV(V0,B0,B0pr,V1,V2):
    """
    Definite integral over [V1, V2] of the Birch-Murnaghan E(V) (without its E0 offset)
    with parameters (V0, B0, B0pr).
    """
    upper = antiderE(V0, B0, B0pr, V2)
    lower = antiderE(V0, B0, B0pr, V1)
    return upper - lower

def antiderE(V0,B0,B0pr,V):
    """
    Antiderivative with respect to V of the Birch-Murnaghan E(V) taken with E0 = 0.

    E(V) - E0 = 9 B0 V0 / 16 * [ (B0pr - 4) (V0/V)**2 + (14 - 3 B0pr) (V0/V)**(4/3)
                                 + (3 B0pr - 16) (V0/V)**(2/3) + (6 - B0pr) ],
    integrated term by term.
    """
    ratio = V0 / V
    bracket = (
        -((-6 + B0pr) * V)
        - ((-4 + B0pr) * V0**2) / V
        + 3 * (-14 + 3 * B0pr) * V0 * ratio**(1/3)
        + 3 * (-16 + 3 * B0pr) * V * ratio**(2/3)
    )
    return 9 * B0 * V0 * bracket / 16

def intE2dV(V0,B0,B0pr,V1,V2):
    """
    Definite integral over [V1, V2] of E(V)**2, with E the Birch-Murnaghan curve
    (without its E0 offset) of parameters (V0, B0, B0pr).
    """
    upper = antiderE2(V0, B0, B0pr, V2)
    lower = antiderE2(V0, B0, B0pr, V1)
    return upper - lower

def antiderE2(V0,B0,B0pr,V):
    """
    Antiderivative with respect to V of E(V)**2, where E is the Birch-Murnaghan curve
    with parameters (V0, B0, B0pr). There is no E0 argument: the curve is taken without
    its E0 offset.

    Only differences F(V2) - F(V1) are meaningful (see intE2dV).

    :return: the antiderivative evaluated at volume V.
    """
    # Closed-form expression (a sum of powers of V); kept verbatim.
    antider = (81*B0**2*V0**2*((-6 + B0pr)**2*V + \
            (4*(124 + 5*(-10 + B0pr)*B0pr)*V0**2)/V - \
            ((-4 + B0pr)**2*V0**4)/(3.*V**3) - \
            (3*(V0/V)**(2/3)* \
            (10*(-6 + B0pr)*(-16 + 3*B0pr)*V**2 + \
            (324 + 5*B0pr*(-28 + 3*B0pr))*V0**2))/(5.*V) \
            + (V0/V)**(1/3)* \
            (-3*(424 + 5*B0pr*(-32 + 3*B0pr))*V0 + \
            (6*(-4 + B0pr)*(-14 + 3*B0pr)*V0**3)/(7.*V**2))) \
            )/256.

    return antider

In [None]:
## Functions (maths 2 - functions used to compare the EoS results of two codes)

In [None]:
def delta(v0w, b0w, b1w, v0f, b0f, b1f, config_string):
    """
    Delta gauge between two Birch-Murnaghan equations of state (formula of the official
    DeltaTest repository).

    With E_w and E_f the two curves taken without their E0 offset (parameters
    (v0w, b0w, b1w) and (v0f, b0f, b1f): equilibrium volume, bulk modulus and its
    pressure derivative), this returns

        1000 * sqrt( integral_{Vi}^{Vf} (E_f(V) - E_w(V))**2 dV / (Vf - Vi) )

    over [Vi, Vf] = [0.94, 1.06] * (v0w + v0f) / 2. That is 1000 times the root-mean-square
    energy difference of the two curves in that window (meV when energies are in eV).

    Method: E(V) - E0 = a0 + a1*V**(-2/3) + a2*V**(-4/3) + a3*V**(-2). The squared difference
    is therefore a polynomial in V**(-2/3), whose antiderivative is
    F(V) = sum_n x[n] * V**(-(2n - 3)/3). The integral is F(Vf) - F(Vi).

    The signature matches all the other comparison functions (see
    quantity_for_comparison_map), which simplifies the calling code. 'config_string' is
    unused here.
    """
    Vi = 0.94 * (v0w + v0f) / 2.
    Vf = 1.06 * (v0w + v0f) / 2.

    # Coefficients of E(V) - E0 = a0 + a1*V**(-2/3) + a2*V**(-4/3) + a3*V**(-2)
    a3f = 9. * v0f**3. * b0f / 16. * (b1f - 4.)
    a2f = 9. * v0f**(7. / 3.) * b0f / 16. * (14. - 3. * b1f)
    a1f = 9. * v0f**(5. / 3.) * b0f / 16. * (3. * b1f - 16.)
    a0f = 9. * v0f * b0f / 16. * (6. - b1f)

    a3w = 9. * v0w**3. * b0w / 16. * (b1w - 4.)
    a2w = 9. * v0w**(7. / 3.) * b0w / 16. * (14. - 3. * b1w)
    a1w = 9. * v0w**(5. / 3.) * b0w / 16. * (3. * b1w - 16.)
    a0w = 9. * v0w * b0w / 16. * (6. - b1w)

    # x[n] multiplies V**(-(2n - 3)/3) in the antiderivative of (E_f - E_w)**2
    x = [
        (a0f - a0w)**2,
        6. * (a1f - a1w) * (a0f - a0w),
        -3. * (2. * (a2f - a2w) * (a0f - a0w) + (a1f - a1w)**2.),
        -2. * (a3f - a3w) * (a0f - a0w) - 2. * (a2f - a2w) * (a1f - a1w),
        -3. / 5. * (2. * (a3f - a3w) * (a1f - a1w) + (a2f - a2w)**2.),
        -6. / 7. * (a3f - a3w) * (a2f - a2w),
        -1. / 3. * (a3f - a3w)**2.,
    ]

    # zeros_like keeps numpy semantics (also if the inputs are arrays)
    Fi = np.zeros_like(Vi)
    Ff = np.zeros_like(Vf)
    for n in range(7):
        exponent = -(2. * n - 3.) / 3.
        Fi = Fi + x[n] * Vi**exponent
        Ff = Ff + x[n] * Vf**exponent

    return 1000. * np.sqrt((Ff - Fi) / (Vf - Vi))


def delta_per_atom(v0w, b0w, b1w, v0f, b0f, b1f, config_string):
    """
    Delta value (see ``delta``) divided by the number of atoms in the cell of
    ``config_string``.

    The atom counts are hard-coded below; an unknown configuration raises KeyError.
    NOTE(review): the X2O3 and X2O5 counts are twice the formula-unit size, presumably
    because those cells contain two formula units; confirm.
    The signature matches all the other comparison functions, which simplifies the
    calling code.
    """
    delta_value = delta(v0w, b0w, b1w, v0f, b0f, b1f, config_string)
    atoms_in_cell = {
        'XO': 2, 'XO2': 3, 'XO3': 4,
        'X2O': 3, 'X2O3': 10, 'X2O5': 14,
    }
    return delta_value / atoms_in_cell[config_string]


def delta2_SSR(v0w, b0w, b1w, v0f, b0f, b1f, config_string):
    """
    Alternative, dimensionless comparison metric 'Delta2' between two Birch-Murnaghan fits.

    On [Vi, Vf] = [0.94, 1.06] * (v0w + v0f) / 2 (the same window as ``delta``) it returns

        100 * int (E1 - E2)**2 dV / sqrt( int (E1 - <E1>)**2 dV * int (E2 - <E2>)**2 dV )

    E1 and E2 are the curves with parameters (v0w, b0w, b1w) and (v0f, b0f, b1f), taken
    without their E0 offset. <E> is the average of E over the window.

    The signature matches all the other comparison functions, which simplifies the
    calling code. 'config_string' is unused here.
    """
    # volume range
    Vi = 0.94 * (v0w + v0f) / 2.
    Vf = 1.06 * (v0w + v0f) / 2.
    deltaV = Vf - Vi

    intdiff2 = intE12sq(v0w, b0w, b1w, v0f, b0f, b1f, Vi, Vf)

    # Integral and average of each curve over the window. Each average must come from its
    # own curve; previously Eavg2 was computed from the first curve's parameters.
    int_e1 = intEdV(v0w, b0w, b1w, Vi, Vf)
    int_e2 = intEdV(v0f, b0f, b1f, Vi, Vf)
    Eavg1 = int_e1 / deltaV
    Eavg2 = int_e2 / deltaV

    # int (E - <E>)**2 dV = int E**2 dV - 2 <E> int E dV + deltaV <E>**2
    int3 = intE2dV(v0w, b0w, b1w, Vi, Vf) - 2 * Eavg1 * int_e1 + deltaV * Eavg1**2
    int4 = intE2dV(v0f, b0f, b1f, Vi, Vf) - 2 * Eavg2 * int_e2 + deltaV * Eavg2**2
    delta2 = intdiff2 / np.sqrt(int3 * int4)

    # x100 multiplier to align delta2 with what we usually consider a 'small' difference
    # in the original delta definition (in meV)
    return delta2 * 100


def V0_diff(v0w, b0w, b1w, v0f, b0f, b1f, config_string):
    """
    Signed relative difference of the equilibrium volumes in percent: 100 * ln(v0w / v0f).

    Positive when v0w > v0f. For small differences it is close to
    100 * (v0w - v0f) / v0f. The signature matches all the other comparison functions;
    only v0w and v0f are used.
    """
    volume_ratio = v0w / v0f
    return 100 * np.log(volume_ratio)


def B0_diff(v0w, b0w, b1w, v0f, b0f, b1f, config_string):
    """
    Signed relative difference of the bulk moduli in percent: 100 * ln(b0w / b0f).

    Positive when b0w > b0f. For small differences it is close to
    100 * (b0w - b0f) / b0f. The signature matches all the other comparison functions;
    only b0w and b0f are used.
    """
    modulus_ratio = b0w / b0f
    return 100 * np.log(modulus_ratio)

In [None]:
## Import results and variables definition

In [None]:
# Load every results-<code>.json file found in the current folder; <code> becomes the
# label of that data set. Files are read in sorted order, so the order of code_results
# (and anything derived from it) does not depend on the file system's listing order.
file_prefix = 'results-'
file_suffix = '.json'

results_folder = os.curdir

code_results = {}
for fname in sorted(os.listdir(results_folder)):
    if fname.startswith(file_prefix) and fname.endswith(file_suffix):
        label = fname[len(file_prefix):-len(file_suffix)]
        # JSON is UTF-8; don't depend on the platform's default encoding.
        with open(os.path.join(results_folder, fname), encoding='utf-8') as fhandle:
            code_results[label] = json.load(fhandle)

In [None]:
# Colors of the EoS curves: one color per <code>, cycling through the palette if there
# are more codes than colors. Codes are taken in sorted order, the same order used by the
# plots and the widgets, so the assignment does not depend on how code_results was filled.
colors = ['#1f78b4', '#33a02c', '#e31a1c', '#ff7f00', '#6a3d9a', '#b15928', '#a6cee3', '#b2df8a', '#fb9a99', '#fdbf6f', '#cab2d6', '#ffff99']
color_code_map = {
    plugin_name: colors[index % len(colors)]
    for index, plugin_name in enumerate(sorted(code_results))
}

In [None]:
# Quantities available for the heatmap that compares the codes. Each key is the name shown
# in the selection widget, and each value is the function computing it. Every function takes
# (v0w, b0w, b1w, v0f, b0f, b1f, config_string).
quantity_for_comparison_map = dict(
    delta=delta,
    delta_per_atom=delta_per_atom,
    delta2_SSR=delta2_SSR,
    B0_diff=B0_diff,
    V0_diff=V0_diff,
)

In [None]:
## Main function that creates the plots

In [None]:
def _extend_range(current_range, low, high):
    """Return the smallest (min, max) tuple covering both ``current_range`` (or None) and [low, high]."""
    if current_range is None:
        return (low, high)
    return (min(low, current_range[0]), max(high, current_range[1]))


def _plot_comparison_heatmap(ax, fit_data, code_names, selected_quantity, configuration):
    """
    Draw on ``ax`` an annotated heatmap of ``selected_quantity`` for every pair of selected fits.

    Entry [i, j] is the quantity computed with code i as the first ("w") set of parameters
    and code j as the second ("f") set, rounded to 2 decimals. The colour scale is
    symmetric around 0 and spans the largest absolute finite value of the whole matrix.
    Nothing is drawn when no selected code has fit data.
    """
    compare = quantity_for_comparison_map[selected_quantity]
    selected = [
        (name, fit_params)
        for name, (fit_params, _, plot_params) in zip(code_names, fit_data)
        if plot_params['selected']
    ]
    if not selected:
        return

    def _bm_params(fit_params):
        # (V0, B0, B0') in the argument order expected by the comparison functions
        return (fit_params['min_volume'], fit_params['bulk_modulus_ev_ang3'], fit_params['bulk_deriv'])

    matrix = np.array([
        [round(float(compare(*_bm_params(first), *_bm_params(second), configuration)), 2)
         for _, second in selected]
        for _, first in selected
    ])

    finite_abs = np.abs(matrix[np.isfinite(matrix)])
    vmax = finite_abs.max() if finite_abs.size else 0.
    if not vmax > 0:
        # All zeros (e.g. a single code compared with itself) or nothing finite: any positive
        # bound keeps 0 at the centre (lightest colour) of the symmetric colormap.
        vmax = 1.
    ax.imshow(matrix, cmap=symmetrical_colormap(("Reds", None)), vmin=-vmax, vmax=vmax)

    names = [name for name, _ in selected]
    ticks = np.arange(len(names))
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(names)
    ax.set_yticklabels(names)
    # Rotate the x tick labels and set their alignment.
    pl.setp(ax.get_xticklabels(), rotation=35, ha="right", rotation_mode="anchor")
    # Write the value in every cell.
    for i in range(len(names)):
        for j in range(len(names)):
            ax.text(j, i, matrix[i, j], ha="center", va="center", color="black")


def plot_for_element(code_results, element, configuration, selected_codes, selected_quantity, axes):
    """
    Plot the EoS of one element/configuration for every code, and compare the codes.

    The function looks at every code in ``code_results`` that has EoS points
    (``['eos_data'][f'{element}-{configuration}']``) and/or Birch-Murnaghan fit parameters
    (``['BM_fit_data'][...]``):
    - EoS points are plotted immediately on ``axes[0]``.
    - Fitted curves are plotted after the loop, on a common volume grid that spans all
      data sets.
    - Energies are shifted by the fitted E0, or by the minimum computed energy when there
      is no fit.
    - Codes not in ``selected_codes`` are drawn in faint black behind the others and are
      left out of the comparison.

    The quantity named ``selected_quantity`` (a key of ``quantity_for_comparison_map``)
    is computed for every pair of selected codes with fit data. It is shown as an
    annotated heatmap on ``axes[1]``.

    :param code_results: dict mapping a code name to its parsed results (not modified).
    :param element: chemical symbol, e.g. 'Si'.
    :param configuration: configuration string, e.g. 'XO2'.
    :param selected_codes: collection of the code names to highlight and compare.
    :param selected_quantity: key of ``quantity_for_comparison_map``.
    :param axes: pair of matplotlib axes: [0] for the EoS, [1] for the heatmap.
    """
    structure_key = f'{element}-{configuration}'

    # Fitted curves are collected here and drawn after the loop, once the full volume range
    # is known. Each entry is (fit_params, energy_shift, plot_params). Unselected codes are
    # inserted at the front so they are drawn first, i.e. in the background.
    # code_names_list is kept parallel to fit_data.
    fit_data = []
    code_names_list = []
    dense_volume_range = None  # Eventually a tuple (min_volume, max_volume)
    y_range = None  # Energy range of the visible (selected) EoS points

    for code_name in sorted(code_results):
        plugin_data = code_results[code_name]

        try:
            eos_data = plugin_data['eos_data'][structure_key]
        except KeyError:
            # No EoS points; there might still be fit parameters (e.g. reference data sets).
            eos_data = None
        try:
            fit_params = plugin_data['BM_fit_data'][structure_key]
        except KeyError:
            # No fit (e.g. the fit failed); the points are still plotted, see energy shift below.
            fit_params = None

        # Skip only if there is neither data nor fit.
        if eos_data is None and fit_params is None:
            continue

        volumes = energies = None
        if eos_data is not None:
            volumes, energies = np.array(eos_data).T.tolist()

        # Grow the volume range so that it covers every data set seen so far.
        if fit_params is not None:
            dense_volume_range = _extend_range(
                dense_volume_range, fit_params['min_volume'] * 0.97, fit_params['min_volume'] * 1.03)
        if volumes is not None:
            dense_volume_range = _extend_range(dense_volume_range, min(volumes), max(volumes))

        # Plotting style: unselected codes are faint black and sent to the background.
        is_selected = code_name in selected_codes
        if is_selected:
            curve_color = color_code_map[code_name]
            alpha = 1.
        else:
            curve_color = '#000000'
            alpha = 0.1

        # Energy shift (important to compare among codes!)
        warning_string = ''
        if fit_params is not None:
            # A missing E0 (hopefully only for fit-only data sets) is treated as 0, without
            # modifying the input dictionary.
            energy_shift = fit_params.get('E0')
            if energy_shift is None:
                energy_shift = 0.
        else:
            # No fit: use the lowest computed energy. Not correct in general, since the true
            # minimum may lie between grid points or even outside the sampled range.
            warning_string = " (WARNING NO FIT!)"
            energy_shift = min(energies)

        if fit_params is not None:
            position = len(fit_data) if is_selected else 0
            code_names_list.insert(position, code_name)
            fit_data.insert(position, (fit_params, energy_shift, {
                # One and only one legend entry per visible code: on the fit only when there
                # are no EoS points (which would carry it otherwise).
                'label': f'{code_name}{warning_string}' if eos_data is None and is_selected else None,
                'alpha': alpha,
                'curve_color': curve_color,
                'selected': is_selected,
            }))

        # Plot the EoS points straight away.
        if volumes is not None:
            label = f'{code_name}{warning_string}' if is_selected else None
            axes[0].plot(volumes, np.array(energies) - energy_shift, 'o', color=curve_color, label=label, alpha=alpha)
            if is_selected:
                y_range = _extend_range(y_range, min(energies) - energy_shift, max(energies) - energy_shift)

    # No code has any data for this element/configuration: leave the axes empty.
    if dense_volume_range is None:
        return

    dense_volumes = np.linspace(dense_volume_range[0], dense_volume_range[1], 100)
    for fit_params, energy_shift, plot_params in fit_data:
        # energy_shift is the fitted E0 (or 0 if missing), so the curve is shifted to start at 0.
        fit_energies = birch_murnaghan(
            V=dense_volumes,
            E0=energy_shift,
            V0=fit_params['min_volume'],
            B0=fit_params['bulk_modulus_ev_ang3'],
            B01=fit_params['bulk_deriv'],
        )
        axes[0].plot(
            dense_volumes,
            np.array(fit_energies) - energy_shift, '-',
            color=plot_params['curve_color'],
            alpha=plot_params['alpha'] * 0.5,
            label=plot_params['label'],
        )

    _plot_comparison_heatmap(axes[1], fit_data, code_names_list, selected_quantity, configuration)

    # Restrict the y range to the visible points (if any selected code had EoS points),
    # making sure that it includes zero.
    if y_range is not None:
        axes[0].set_ylim((min(y_range[0], 0), y_range[1]))
    axes[0].legend(loc='upper center')
    axes[0].set_xlabel("Cell volume ($\\AA^3$)")
    axes[0].set_ylabel("$E_{tot}$ (eV)")
    conf_nice = get_conf_nice(configuration)
    axes[0].set_title(f"{element} ({conf_nice})")
    axes[1].set_title(f"{element} ({conf_nice}) -- {selected_quantity}")

In [None]:
## Widgets definition and main call to the plot function

In [None]:
# Codes to highlight and compare (all selected at start-up).
ipw_codes = ipw.SelectMultiple(
    options=sorted(code_results),
    value=sorted(code_results),
    rows=15,
    description='Code plugins',
    disabled=False,
)

# Quantity used for the comparison heatmap.
style = {'description_width': 'initial'}
ipw_comp_quantity = ipw.Select(
    options=sorted(quantity_for_comparison_map),
    value="delta_per_atom",
    rows=15,
    description="Quantity for code comparison",
    style=style,
)

# Element selector; 'Si' is selected initially.
ipw_periodic = widget_periodictable.PTableWidget(
    states=1,
    selected_colors=["#a6cee3"],
    selected_elements={'Si': 0},
)

# All plots are rendered inside this output widget.
ipw_output = ipw.Output()

# Figure and grid of axes; created by the first call to replot() and reused afterwards.
fig = None
axes_list = None

def replot():
    """
    Redraw the plots for the selected element(s), codes and comparison quantity.

    The figure is a 6 x 2 grid with one row per configuration: EoS on the left,
    comparison heatmap on the right. It is created inside ``ipw_output`` on the first call
    and only cleared and reused on later calls.
    """
    global fig, axes_list
    configurations = ['XO', 'XO2', 'XO3', 'X2O', 'X2O3', 'X2O5']
    with ipw_output:
        if fig is not None:
            for row in axes_list:
                row[0].clear()
                row[1].clear()
        else:
            ipw_output.clear_output(wait=True)
            fig, axes_list = pl.subplots(6, 2, figsize=(10, 27), gridspec_kw={"hspace": 0.5})

        for element in sorted(ipw_periodic.selected_elements):
            # Each row is a pair of subplots: [0] hosts the EoS and fits, [1] the heatmap.
            for configuration, row in zip(configurations, axes_list):
                plot_for_element(
                    code_results=code_results,
                    element=element,
                    configuration=configuration,
                    selected_codes=ipw_codes.value,
                    selected_quantity=ipw_comp_quantity.value,
                    axes=row,
                )

def on_codes_change(event):
    """
    Replot when the selection of codes changes.

    The widget is observed without a trait filter, so a single user action may report
    several trait changes (e.g. 'index' as well as 'value'). Only the 'value' change
    triggers a replot, which avoids redundant full redraws; this is consistent with
    on_quantity_change.
    """
    if event['type'] == 'change' and event['name'] == 'value':
        replot()
        
def on_quantity_change(event):
    """Replot when a different comparison quantity is chosen (only 'value' changes count)."""
    if event['type'] != 'change' or event['name'] != 'value':
        return
    replot()

# Selection seen just before the periodic table switched to the {'Du': ...} state; used
# below to work out which element was newly clicked.
last_selected = ipw_periodic.selected_elements
def on_element_select(event):
    """
    Keep at most one element selected in the periodic table, then replot.

    A change of ``selected_elements`` to exactly {'Du': ...} is treated as an intermediate
    state: the previous selection is stored in ``last_selected`` and nothing else happens.
    On the next change away from that state, the handler does two things:
    - If the new selection does not contain exactly one element, it is reset to a single
      newly clicked element (or to nothing).
    - The plots are then redrawn.

    NOTE(review): the meaning of the 'Du' state comes from widget_periodictable and is only
    inferred here from the checks below; confirm against the widget's documentation.
    """
    global last_selected

    if event['name'] == 'selected_elements' and event['type'] == 'change':
        if tuple(event['new'].keys()) == ('Du', ):
            # Entering the intermediate state: remember what was selected before it.
            last_selected = event['old']
        elif tuple(event['old'].keys()) == ('Du', ):
            if len(event['new']) != 1:
                # Reset to only one element only if there is more than one selected,
                # to avoid infinite loops
                newly_selected = set(event['new']).difference(last_selected)
                # If this is empty it's ok, unselect all
                # If there is more than one, that's weird... to avoid problems, anyway, I pick one of the two
                # Assigning selected_elements re-enters this handler. In the nested call,
                # neither the old nor the new value is {'Du': ...}, so it does nothing.
                if newly_selected:
                    ipw_periodic.selected_elements = {list(newly_selected)[0]: 0}
                else:
                    ipw_periodic.selected_elements = {}
                # To have the correct 'last' value for next calls
                last_selected = ipw_periodic.selected_elements
            replot()

# React to user interaction with the three selectors.
ipw_codes.observe(on_codes_change)
ipw_comp_quantity.observe(on_quantity_change)
ipw_periodic.observe(on_element_select)

# Code and quantity selectors side by side, separated by seven non-breaking spaces (every
# HTML entity properly terminated with ';'); the periodic table goes below them.
display(ipw.HBox([ipw_codes, ipw.HTML("&nbsp;" * 7), ipw_comp_quantity]))
display(ipw_periodic)

In [None]:
# Display in a different cell, so if there is scrolling, it's independent of the top widgets
display(ipw_output)

In [None]:
# Trigger first plot
replot()