## GUI to compare equation-of-state (EOS) results from various codes

In [None]:
#NOTES:
# To install (in a python 3 virtual environment):
# - pip install numpy matplotlib ipywidgets
# - pip install widget_periodictable
# - jupyter nbextension enable --py widget_periodictable

In [1]:
import json
import os
import numpy as np
import pylab as pl
import ipywidgets as ipw
import widget_periodictable

In [2]:
def get_conf_nice(configuration_string):
    """Return ``configuration_string`` with every digit typeset as a LaTeX subscript.

    Each digit character ``d`` becomes ``$_d$`` (e.g. ``'X2O3'`` -> ``'X$_2$O$_3$'``);
    every other character is copied unchanged.
    """
    return "".join(
        f"$_{symbol}$" if symbol in "0123456789" else symbol
        for symbol in configuration_string
    )


def birch_murnaghan(V,E0,V0,B0,B01):
    """Third-order Birch-Murnaghan energy as a function of volume.

    E(V) = E0 + 9/16 * B0 * V0 * [ (eta - 1)^3 * B01 + (eta - 1)^2 * (6 - 4 * eta) ],
    with eta = (V0 / V)^(2/3).

    :param V: volume; scalar or numpy array (evaluated element-wise).
    :param E0: energy at the equilibrium volume (E(V0) == E0).
    :param V0: equilibrium volume.
    :param B0: bulk modulus, in energy/volume units consistent with E0 and V0.
    :param B01: pressure derivative of the bulk modulus.
    :return: energy (same shape as ``V``).
    """
    eta = (V0/V)**(2./3.)
    strain = eta - 1.
    prefactor = 9./16. * B0 * V0
    bracket = strain**3 * B01 + strain**2 * (6. - 4. * eta)
    return E0 + prefactor * bracket

In [3]:
# Get all results from all codes
file_prefix = 'results-'
file_suffix = '.json'

results_folder = os.curdir

code_results = {}
for fname in os.listdir(results_folder):
    if fname.startswith(file_prefix) and fname.endswith(file_suffix):
        label = fname[len(file_prefix):-len(file_suffix)]
        with open(os.path.join(results_folder, fname)) as fhandle:
            code_results[label] = json.load(fhandle)

#print(f"Found results for {len(code_results)} codes")

In [4]:
# Colour cycle for the selected codes' curves.
# These appear to be the 12 ColorBrewer "Paired" colours: dark shades first, then light ones.
colors = [
    # dark shades
    '#1f78b4', '#33a02c', '#e31a1c', '#ff7f00', '#6a3d9a', '#b15928',
    # light shades
    '#a6cee3', '#b2df8a', '#fb9a99', '#fdbf6f', '#cab2d6', '#ffff99',
]

In [17]:
def _extend_range(current_range, low, high):
    """Return the smallest ``(min, max)`` interval containing ``current_range`` and ``[low, high]``.

    ``current_range`` may be None (no range yet), in which case ``(low, high)`` is returned.
    """
    if current_range is None:
        return (low, high)
    return (min(low, current_range[0]), max(high, current_range[1]))


def _get_entry(plugin_data, section, key):
    """Return ``plugin_data[section][key]``, or None if either level is missing."""
    try:
        return plugin_data[section][key]
    except KeyError:
        return None


def plot_for_element(code_results, element, configuration, selected_codes, axes):
    """Plot, on ``axes``, each code's EOS points and Birch-Murnaghan fit for one element/configuration.

    Codes in ``selected_codes`` get a colour from ``colors`` and a legend entry. All other
    codes are drawn in faint black, underneath, without a label. Each code's energies are
    shifted by its fitted ``E0``, or, without a fit, by its lowest computed energy.

    :param code_results: dict mapping a code label to its parsed results. Each value may hold
        ``'eos_data'`` and ``'BM_fit_data'`` dicts keyed by ``'<element>-<configuration>'``.
        The former maps to a list of ``[volume, energy]`` points. The latter maps to a dict
        with ``'E0'``, ``'min_volume'``, ``'bulk_modulus_ev_ang3'`` and ``'bulk_deriv'``.
        This dict is not modified.
    :param element: chemical symbol, e.g. ``'Si'``.
    :param configuration: configuration name, e.g. ``'XO2'``.
    :param selected_codes: collection of code labels to highlight.
    :param axes: matplotlib ``Axes`` to draw on.

    If no code has data for this element and configuration, nothing is drawn.
    """
    key = f'{element}-{configuration}'

    # Fits are collected here and drawn only at the end, once the volume range
    # spanned by all codes is known.
    fit_data = []
    color_idx = 0
    dense_volume_range = None  # (min_volume, max_volume) over all codes, once known
    y_range = None  # (min, max) shifted energy of the *selected* codes' points

    for code_name in sorted(code_results):
        plugin_data = code_results[code_name]

        # An empty list of points is treated as missing (it could not be unpacked below).
        eos_data = _get_entry(plugin_data, 'eos_data', key) or None
        # The fit can be missing even when points exist (e.g. the fit failed);
        # the points are still plotted in that case.
        ref_BM_fit_data = _get_entry(plugin_data, 'BM_fit_data', key)

        if eos_data is None and ref_BM_fit_data is None:
            # This code has nothing for this element and configuration.
            continue

        is_selected = code_name in selected_codes
        if is_selected:
            curve_color = colors[color_idx % len(colors)]
            color_idx += 1
            alpha = 1.
        else:
            # Unselected codes: faint black, drawn underneath, no legend entry.
            curve_color = '#000000'
            alpha = 0.1

        if eos_data is not None:
            volumes, energies = np.array(eos_data).T.tolist()

        warning_string = ''
        if ref_BM_fit_data is not None:
            # Work on a copy, so that the caller's ``code_results`` is never modified.
            ref_BM_fit_data = dict(ref_BM_fit_data)
            if ref_BM_fit_data.get('E0') is None:  # Either unset or set to None
                # Any float works for the curve: it is shifted back by this same value.
                ref_BM_fit_data['E0'] = 0.
            energy_shift = ref_BM_fit_data['E0']
            # Make sure +-3% around the fitted equilibrium volume is shown.
            dense_volume_range = _extend_range(
                dense_volume_range,
                ref_BM_fit_data['min_volume'] * 0.97,
                ref_BM_fit_data['min_volume'] * 1.03)
            entry = (ref_BM_fit_data, energy_shift, {
                # One legend entry per selected code: the fit carries it only
                # when there are no points to carry it.
                'label': code_name if eos_data is None and is_selected else None,
                'alpha': alpha,
                'curve_color': curve_color,
            })
            # Unselected codes go first, so that they are drawn underneath.
            fit_data.insert(len(fit_data) if is_selected else 0, entry)
        else:
            # Without a fit, shift by the lowest computed energy. This is only
            # approximate: the true minimum may fall between grid points, or
            # even outside the computed volume range.
            warning_string = " (WARNING NO FIT!)"
            energy_shift = min(energies)

        if eos_data is not None:
            dense_volume_range = _extend_range(dense_volume_range, min(volumes), max(volumes))
            axes.plot(volumes, np.array(energies) - energy_shift, 'o', color=curve_color,
                      label=f'{code_name}{warning_string}' if is_selected else None,
                      alpha=alpha)
            if is_selected:
                # Only visible (selected) points determine the y range.
                y_range = _extend_range(
                    y_range, min(energies) - energy_shift, max(energies) - energy_shift)

    if dense_volume_range is None:
        # No code has any data for this element/configuration: nothing to draw.
        return

    dense_volumes = np.linspace(dense_volume_range[0], dense_volume_range[1], 100)

    # Plot all fits (unselected codes first).
    for fit_params, energy_shift, plot_params in fit_data:
        fit_energies = birch_murnaghan(
            V=dense_volumes,
            E0=fit_params['E0'],
            V0=fit_params['min_volume'],
            B0=fit_params['bulk_modulus_ev_ang3'],
            B01=fit_params['bulk_deriv'],
        )
        axes.plot(
            dense_volumes,
            np.array(fit_energies) - energy_shift, '-',
            color=plot_params['curve_color'],
            # Fit curves are drawn more faintly than the points.
            alpha=plot_params['alpha'] * 0.5,
            label=plot_params['label'],
        )

    # Restrict the y range to the visible (selected) points, if any. Extend it
    # down to zero (or below, if needed).
    if y_range is not None:
        axes.set_ylim((min(y_range[0], 0), y_range[1]))

In [20]:
# Multi-select list of the code plugins whose results were found; all start selected.
_code_labels = sorted(code_results)
ipw_codes = ipw.SelectMultiple(
    options=_code_labels,
    value=list(_code_labels),  # Select all
    rows=15,
    description='Code plugins',
    disabled=False,
)

# Periodic table used to pick the element to plot; 'Si' is selected initially.
ipw_periodic = widget_periodictable.PTableWidget(
    states=1,
    selected_colors=["#a6cee3"],
    selected_elements={'Si': 0},
)

# Output area that replot() draws the figures into.
ipw_output = ipw.Output()

def replot():
    """Redraw, inside ``ipw_output``, one 3x2 figure per element selected in ``ipw_periodic``.

    Each figure has one panel per configuration (``X`` presumably standing for the
    element). Each panel is drawn by ``plot_for_element`` using the codes currently
    selected in ``ipw_codes``.
    """
    with ipw_output:
        # wait=True: the old output is cleared only when new output arrives (less flicker).
        ipw_output.clear_output(wait=True)
        for element in sorted(ipw_periodic.selected_elements.keys()):
            fig, axes_list = pl.subplots(3, 2, figsize=(14, 20))

            for configuration, axes in zip(
                    ['XO', 'XO2', 'XO3', 'X2O', 'X2O3', 'X2O5'],
                    axes_list.flatten()
                ):
                plot_for_element(code_results=code_results, element=element, configuration=configuration, selected_codes=ipw_codes.value, axes=axes)

                # Only add a legend when something is labelled; this avoids
                # matplotlib's "No artists with labels" warning on empty panels.
                if axes.get_legend_handles_labels()[0]:
                    axes.legend(loc='upper center')
                # A volume is in cubic angstrom.
                axes.set_xlabel("Cell volume ($\\AA^3$)")
                axes.set_ylabel("$E_{tot}$ (eV)")
                conf_nice = get_conf_nice(configuration)
                axes.set_title(f"{element} ({conf_nice})")

        pl.show()

def on_codes_change(event):
    """Observer for ``ipw_codes``: redraw the plots on every 'change' notification."""
    if event['type'] != 'change':
        return
    replot()

# Selection that was in effect just before the latest transient 'Du'-only state
# (see on_element_select); used to find out which element was newly added.
last_selected = ipw_periodic.selected_elements
def on_element_select(event):
    """Observer for ``ipw_periodic``: keep at most one element selected, then redraw.

    Judging from the checks below, a user click appears to arrive as two
    notifications. First the selection becomes just ``{'Du': ...}``. Then it goes
    from ``{'Du': ...}`` to the real new selection.
    NOTE(review): 'Du' looks like a dummy element used internally by
    widget_periodictable; confirm against its source.

    The selection seen before the 'Du' state is stored in ``last_selected``. If the
    new selection holds several elements, only one newly added element is kept.
    """
    global last_selected

    if event['name'] == 'selected_elements' and event['type'] == 'change':
        if tuple(event['new'].keys()) == ('Du', ):
            # Entering the transient 'Du' state: remember the selection before it.
            last_selected = event['old']
        elif tuple(event['old'].keys()) == ('Du', ):
            # Leaving the 'Du' state: event['new'] is the updated selection.
            if len(event['new']) != 1:
                # Reset to a single element only when there is more than one selected.
                # The reassignment below fires another notification whose 'old' is
                # not the 'Du' state, so neither branch runs again (no infinite loop).
                newly_selected = set(event['new']).difference(last_selected)
                # If this is empty, that's fine: unselect all.
                # If it holds more than one element (unexpected), keep an arbitrary
                # one of them (set iteration order).
                if newly_selected:
                    ipw_periodic.selected_elements = {list(newly_selected)[0]: 0}
                else:
                    ipw_periodic.selected_elements = {}
                # Keep 'last' in sync with what was just set, for the next calls.
                last_selected = ipw_periodic.selected_elements
            replot()

# Redraw only when the *selection* of codes changes. SelectMultiple keeps several
# traits in sync ('index', 'value', 'label'). With the default names=All, a single
# selection change would therefore trigger several identical redraws.
ipw_codes.observe(on_codes_change, names='value')
# on_element_select already ignores anything but 'selected_elements'; filtering
# here avoids calling it for unrelated trait changes.
ipw_periodic.observe(on_element_select, names='selected_elements')

display(ipw_codes)
display(ipw_periodic)

SelectMultiple(description='Code plugins', index=(0, 1, 2, 3, 4, 5, 6, 7, 8), options=('abinit-NC', 'abinit-PA…

PTableWidget(allElements=['H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 'Mg', 'Al', 'Si', 'P', '…

REPLOTTING


In [21]:
# Show the plot output area in its own cell. Scrolling through the (possibly
# long) plots then does not move the selector widgets displayed above.
display(ipw_output)

Output()

In [23]:
# Trigger the first plot for the initial widget state: element 'Si' and all codes selected.
replot()

REPLOTTING
