In [13]:
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk
import PySimpleGUI as sg
import numpy as np
import PySimpleGUI as sg
import matplotlib.pyplot as plt
import import_ipynb
from ase.visualize.plot import plot_atoms
from ase import Atoms
import math
from utils import *

sg.theme("DarkBlue7")
plt.style.use('dark_background')

# Background color of the active PySimpleGUI theme; the plot figures use it as their face color.
color = sg.theme_background_color()

# Per-file plotting state shared by the plotting functions and their click handlers,
# keyed by file id.
plotting_fit_data, plotting_rdf_data, plotting_view_data = {}, {}, {}
selected_rows_dict = {}

In [26]:
import colorsys


def adjust_lightness(color, amount=0.5):
    """Return ``color`` with its HLS lightness scaled by ``amount``, as ``'#RRGGBB'``.

    Parameters:
        color: A matplotlib named color (e.g. ``'red'``), a hex string, or an RGB(A)
            sequence -- anything ``matplotlib.colors.to_rgb`` accepts.
        amount (float): Factor applied to the lightness channel. Values above 1
            lighten and values below 1 darken; the result is clamped to [0, 1].

    Returns:
        str: Upper-case hex color string ``'#RRGGBB'``.
    """
    import matplotlib.colors as mc

    try:
        # Resolve named colors ("red", "navy", ...) to their hex value.
        c = mc.cnames[color]
    except (KeyError, TypeError):
        # Not a named color (hex string, RGB tuple, or an unhashable sequence such
        # as a list): let to_rgb interpret it directly.
        c = color
    hue, lightness, saturation = colorsys.rgb_to_hls(*mc.to_rgb(c))
    rgb = colorsys.hls_to_rgb(hue, max(0, min(1, amount * lightness)), saturation)
    # round() rather than int(): truncation maps 127.5 (50% grey) and float noise
    # such as 127.9999 down by one step.
    red, green, blue = (round(channel * 255) for channel in rgb)
    return '#{:02X}{:02X}{:02X}'.format(red, green, blue)


# Theme background with its lightness tripled (clamped); used to color the matplotlib toolbar.
color2 = adjust_lightness(color, amount=3)
color2

'#110A31'

In [7]:
# Helper used by all of the plotting functions. It embeds a matplotlib figure and its
# navigation toolbar in a PySimpleGUI window. Takes, in order, the Tk widget that hosts the
# figure, the matplotlib figure to draw, and the Tk widget that hosts the toolbar.
def draw_figure_w_toolbar(canvas, fig, canvas_toolbar):
    """Embed ``fig`` in the Tk widget ``canvas`` with a themed toolbar in ``canvas_toolbar``.

    Widgets left in either container by a previous call are destroyed first.

    Parameters:
        canvas: Tk widget used as the master of the FigureCanvasTkAgg.
        fig (matplotlib.figure.Figure): Figure to draw.
        canvas_toolbar: Tk widget that receives the navigation toolbar.

    Returns:
        None
    """
    for container in (canvas, canvas_toolbar):
        if container.children:
            for old_widget in container.winfo_children():
                old_widget.destroy()

    agg = FigureCanvasTkAgg(fig, master=canvas)
    agg.draw()

    nav = Toolbar(agg, canvas_toolbar)
    # Each toolbar button is a separate widget and needs the theme color set explicitly.
    for child in nav.winfo_children():
        child.config(background=color2)
    nav.update()

    agg.get_tk_widget().pack(side='right', fill='both', expand=1)

In [3]:
class Toolbar(NavigationToolbar2Tk):
    """Matplotlib Tk navigation toolbar whose own background is set to ``color2``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Theme the toolbar frame; its child buttons are not recolored here.
        self.config(background=color2)

In [1]:
# Plots four scatter plots (a 2x2 grid) in a tab of the main window. Takes the log and
# unit_types data, the canvases of the figure and the toolbar, the values dictionary that
# contains the values of all elements in the window, file_id (the absolute path of the file
# from which the plotted data is retrieved), and the figure size in pixels.
def plot_overview(log_unit_types_dict, fig_cv, toolbar_cv, values, file_id, size):
    """
    Plot a 2x2 grid of scatter plots for one file and embed it in the window.

    Panel i (1..4, filled row by row) plots ``combo_y{i}_{file_id}`` against
    ``combo_x{i}_{file_id}`` for the run selected in ``combo_n{i}_{file_id}``.

    Parameters:
        log_unit_types_dict (dict): Maps file_id to a ``(log, unit_types)`` pair;
            ``log.get(quantity, run_index)`` supplies the plotted data and
            ``unit_types[run_index]`` is passed to ``units`` for the axis labels.
        fig_cv: Tk widget that hosts the figure canvas (passed on to draw_figure_w_toolbar).
        toolbar_cv: Tk widget that hosts the navigation toolbar.
        values (dict): Window values dictionary holding the combo box selections.
        file_id (str): Key of the file being plotted; also the suffix of the combo keys.
        size (tuple): (width, height) of the figure in pixels; divided by the figure DPI
            to obtain the size in inches.

    Returns:
        None
    """
    log, unit_types = log_unit_types_dict[file_id]

    # Read every panel's selection before creating the figure, so a bad selection
    # raises without leaving an open figure behind.
    panels = []
    for i in range(1, 5):
        run = values[f'combo_n{i}_{file_id}'] - 1  # 1-based selection -> 0-based index
        xq = values[f'combo_x{i}_{file_id}']
        yq = values[f'combo_y{i}_{file_id}']
        panels.append((xq, yq, unit_types[run], log.get(xq, run), log.get(yq, run)))

    fig, axs = plt.subplots(2, 2)
    fig.set_facecolor(color)  # TODO: check if this is needed
    # axs.flat visits (0, 0), (0, 1), (1, 0), (1, 1): panels are placed row-major.
    for ax, (xq, yq, ut, x, y) in zip(axs.flat, panels):
        ax.scatter(x, y)
        ax.set(xlabel=f'{xq} {units(ut, xq)}',
               ylabel=f'{yq} {units(ut, yq)}')

    # Detach this exact figure from pyplot (it is embedded in Tk instead). Looking up
    # figure number 1 is unreliable: if another pyplot figure is still open, number 1
    # is not the figure just built.
    plt.close(fig)
    dpi = fig.get_dpi()
    size_x, size_y = size
    fig.set_size_inches(size_x / float(dpi), size_y / float(dpi))
    # Lay out after resizing so label padding is computed for the final size.
    fig.tight_layout()
    draw_figure_w_toolbar(fig_cv, fig, toolbar_cv)

In [2]:
def plot_rdf(log_unit_types_dict, fig_cv, toolbar_cv, values, file_id, win, size):
    """
    Plot a scatter of two logged quantities next to a radial distribution function panel.

    The left panel shows ``combo_y_rdf`` against ``combo_x_rdf`` for the run chosen in
    ``combo_n_rdf``. Clicking it (when ``file_id`` is in ``rdf_dict``) marks the point
    returned by ``find_rdf_point`` in red and plots the matching RDF, column chosen by
    ``combo_col_rdf``, in the right panel. Clicking the right panel marks the nearest
    RDF point and writes its coordinates into the ``rdf_xy_{file_id}`` element.

    Parameters:
        log_unit_types_dict (dict): Maps file_id to a ``(log, unit_types)`` pair.
        fig_cv: Tk widget that hosts the figure canvas (passed on to draw_figure_w_toolbar).
        toolbar_cv: Tk widget that hosts the navigation toolbar.
        values (dict): Window values dictionary; the ``combo_*_rdf_{file_id}`` entries are read.
        file_id (str): Key of the file being plotted.
        win: The PySimpleGUI window, indexed by element key.
        size (tuple): (width, height) of the figure in pixels; divided by the figure DPI.

    Returns:
        numpy.ndarray: The two Axes, ``[scatter_axes, rdf_axes]``.
    """
    c_id = plotting_rdf_data[file_id]['c_id']
    reset_rdf_data(plotting_rdf_data, file_id, c_id)

    n = values[f'combo_n_rdf_{file_id}'] - 1  # 1-based selection -> 0-based index
    xq = values[f'combo_x_rdf_{file_id}']
    yq = values[f'combo_y_rdf_{file_id}']

    log, unit_types = log_unit_types_dict[file_id]
    ut = unit_types[n]
    x = log.get(xq, n)
    y = log.get(yq, n)
    ux = units(ut, xq)
    uy = units(ut, yq)

    fig, axs = plt.subplots(1, 2)
    fig.set_facecolor(color)  # TODO: check if this is needed
    axs[0].scatter(x, y)
    axs[0].set(xlabel=f'{xq} {ux}', ylabel=f'{yq} {uy}')
    axs[1].set(xlabel='Distance ($\\AA$)', ylabel='RDF')

    # Detach this exact figure from pyplot (it is embedded in Tk instead). Looking up
    # figure number 1 could return a different, older figure, leaving the click handler's
    # `axs` pointing at axes that are not on the drawn canvas.
    plt.close(fig)
    dpi = fig.get_dpi()
    size_x, size_y = size
    fig.set_size_inches(size_x / float(dpi), size_y / float(dpi))
    # Lay out after resizing so label padding is computed for the final size.
    fig.tight_layout()
    draw_figure_w_toolbar(fig_cv, fig, toolbar_cv)

    def on_click_rdf(event):
        x_click = event.xdata
        y_click = event.ydata
        if event.inaxes is axs[0] and file_id in rdf_dict:
            remove_red_point(plotting_rdf_data, fig, file_id, 'point1')
            # The RDF panel is about to be cleared and redrawn; drop its marker through the
            # helper as well, so the stored 'point2' artist does not go stale.
            remove_red_point(plotting_rdf_data, fig, file_id, 'point2')
            x_np, y_np, rdf_data = find_rdf_point(x, y, x_click, y_click, file_id, log.get('Step', n))
            # NOTE(review): `values` is the snapshot passed when the plot was built, so a column
            # change made afterwards is not seen until the plot is rebuilt — confirm intended.
            rdf_col_number = get_rdf_col_number(values[f'combo_col_rdf_{file_id}'])

            x_rdf = rdf_data[:, 1]
            y_rdf = rdf_data[:, rdf_col_number + 1]
            plotting_rdf_data[file_id]['x'] = x_rdf
            plotting_rdf_data[file_id]['y'] = y_rdf
            axs[1].cla()
            axs[1].plot(x_rdf, y_rdf)
            axs[1].set(xlabel='Distance ($\\AA$)', ylabel='RDF')
            plotting_rdf_data[file_id]['point1'] = axs[0].plot(
                x_np, y_np, color='red', marker='o')
            fig.canvas.draw()
        elif (event.inaxes is axs[1]
              and np.size(plotting_rdf_data[file_id]['x'])
              and np.size(plotting_rdf_data[file_id]['y'])):
            x_rdf = plotting_rdf_data[file_id]['x']
            y_rdf = plotting_rdf_data[file_id]['y']
            x_rdf_np, y_rdf_np = find_nearest_point(x_rdf, y_rdf, x_click, y_click)
            win[f'rdf_xy_{file_id}'].update(
                f'X={format_number(x_rdf_np)} Y={format_number(y_rdf_np)}')
            remove_red_point(plotting_rdf_data, fig, file_id, 'point2')

            plotting_rdf_data[file_id]['point2'] = axs[1].plot(
                x_rdf_np, y_rdf_np, color='red', marker='o')
            fig.canvas.draw()

    # NOTE(review): `fig` is brand new, so a handler registered by a previous call lives on the
    # previous figure's canvas; confirm disconnect_on_click is a harmless no-op here.
    disconnect_on_click(fig, plotting_rdf_data[file_id]['c_id'])
    c_id = fig.canvas.mpl_connect("button_press_event", on_click_rdf)
    plotting_rdf_data[file_id]['c_id'] = c_id
    return axs

In [3]:
def plot_fit(log_unit_types_dict, fig_cv, toolbar_cv, values, file_id, win, size):
    """
    Plot a scatter of two logged quantities and let the user fit straight lines to it.

    Every two clicks inside the axes define a pair of x boundaries. The data points
    nearest those boundaries are passed to ``linear_fit``. The resulting line
    ``y = a*x + b`` is drawn across the full x range and appended as a row of the
    ``table_fit_{file_id}`` table.

    Parameters:
        log_unit_types_dict (dict): Maps file_id to a ``(log, unit_types)`` pair.
        fig_cv: Tk widget that hosts the figure canvas (passed on to draw_figure_w_toolbar).
        toolbar_cv: Tk widget that hosts the navigation toolbar.
        values (dict): Window values dictionary; the ``combo_*_fit_{file_id}`` entries are read.
        file_id (str): Key of the file being plotted.
        win: The PySimpleGUI window containing ``table_fit_{file_id}``.
        size (tuple): (width, height) of the figure in pixels; divided by the figure DPI.

    Returns:
        matplotlib.axes.Axes: The single Axes of the plot.
    """
    log, unit_types = log_unit_types_dict[file_id]
    c_id = plotting_fit_data[file_id]['c_id']
    reset_fit_data(plotting_fit_data, selected_rows_dict, file_id, c_id, win)

    n = values[f'combo_n_fit_{file_id}'] - 1  # 1-based selection -> 0-based index
    xq = values[f'combo_x_fit_{file_id}']
    yq = values[f'combo_y_fit_{file_id}']
    ut = unit_types[n]
    x = log.get(xq, n)
    y = log.get(yq, n)
    ux = units(ut, xq)
    uy = units(ut, yq)

    new_headings = ['Line', f'Slope {slope_unit(ux, uy)}',
                    f'Intercept {single_unit(uy)}', f'Boundary 1 {single_unit(ux)}',
                    f'Boundary 2 {single_unit(ux)}']
    table = win[f'table_fit_{file_id}'].Widget
    update_headings(table, new_headings)

    fig, axs = plt.subplots(1)
    plotting_fit_data[file_id]['fit_axs'] = axs
    plotting_fit_data[file_id]['fit_fig'] = fig
    axs.scatter(x, y, color='mediumpurple')
    axs.set(xlabel=f'{xq} {ux}', ylabel=f'{yq} {uy}')
    fig.set_facecolor(color)  # TODO: change color

    # Detach this exact figure from pyplot (it is embedded in Tk instead). Looking up
    # figure number 1 could return a different figure than the one stored as 'fit_fig'.
    plt.close(fig)
    dpi = fig.get_dpi()
    x_size, y_size = size
    fig.set_size_inches(x_size / float(dpi), y_size / float(dpi))
    # Lay out after resizing so label padding is computed for the final size.
    fig.tight_layout()
    draw_figure_w_toolbar(fig_cv, fig, toolbar_cv)

    def on_click_fit(event):
        boundaries = plotting_fit_data[file_id]['boundries']
        # xdata is None for clicks outside the axes.
        if event.xdata is not None:
            boundaries.append(event.xdata)
        if len(boundaries) < 2:
            return
        first, second = boundaries[0], boundaries[1]
        # Clear before fitting: if linear_fit raises, the list must not stay at two
        # entries, or no later pair of clicks would ever trigger a fit again.
        boundaries.clear()

        # Index of the data point whose x is closest to each clicked boundary
        # (first occurrence on ties).
        x_values = np.asarray(x)
        boundary1_index = np.abs(x_values - first).argmin()
        boundary2_index = np.abs(x_values - second).argmin()
        # NOTE(review): indices are passed in click order; confirm linear_fit copes with
        # boundary1_index > boundary2_index and with equal indices.
        a, b = linear_fit(x, y, boundary1_index, boundary2_index)

        # linspace includes x_max and, unlike arange with step (max-min)/1000, does not
        # fail when all x values are equal.
        x_line = np.linspace(x_values.min(), x_values.max(), 1000)
        label = len(plotting_fit_data[file_id]['lines']) + 1
        line, = axs.plot(x_line, a * x_line + b, label=f'{label}')
        axs.legend()
        fig.canvas.draw()
        plotting_fit_data[file_id]['lines'].append(line)

        fit_table_rows = plotting_fit_data[file_id]['table_rows']
        fit_table_rows.append(
            [str(label), format_number(a), format_number(b),
             format_number(x[boundary1_index]), format_number(x[boundary2_index])])
        fit_table = win[f'table_fit_{file_id}']
        fit_table.update(fit_table_rows)

        # Select the newest row together with the one before it (just row 0 for the first fit).
        number_of_rows = len(fit_table_rows)
        new_selected_rows = list(range(max(0, number_of_rows - 2), number_of_rows))
        selected_rows_dict[file_id] = new_selected_rows
        fit_table.update(select_rows=new_selected_rows)

    disconnect_on_click(fig, plotting_fit_data[file_id]['c_id'])
    c_id = fig.canvas.mpl_connect("button_press_event", on_click_fit)
    plotting_fit_data[file_id]['c_id'] = c_id

    return axs

In [4]:
def plot_view(log_unit_types_dict, fig_cv, toolbar_cv, values, file_id, win, size=(1000, 500)):
    """
    Plot a scatter of two logged quantities next to a view of the atomic structure.

    Clicking a point in the left panel (when ``file_id`` is in ``view_dict``) does three things:
    picks the stored snapshot whose step is closest to the clicked point's step, draws it with
    ``plot_atoms`` in the right panel, and marks the corresponding data point in red.

    Parameters:
        log_unit_types_dict (dict): Maps file_id to a ``(log, unit_types)`` pair.
        fig_cv: Tk widget that hosts the figure canvas (passed on to draw_figure_w_toolbar).
        toolbar_cv: Tk widget that hosts the navigation toolbar.
        values (dict): Window values dictionary; the ``combo_*_view_{file_id}`` entries are read.
        file_id (str): Key of the file being plotted.
        win: The PySimpleGUI window containing ``file_view_{file_id}``.
        size (tuple): (width, height) of the figure in pixels. Defaults to (1000, 500),
            the size that used to be hard-coded here.

    Returns:
        numpy.ndarray: The two Axes, ``[scatter_axes, structure_axes]``.
    """
    plotting_view_data[file_id] = {'point1': None}

    log, unit_types = log_unit_types_dict[file_id]
    n = values[f'combo_n_view_{file_id}'] - 1  # 1-based selection -> 0-based index
    xq = values[f'combo_x_view_{file_id}']
    yq = values[f'combo_y_view_{file_id}']
    ut = unit_types[n]
    x = log.get(xq, n)
    y = log.get(yq, n)
    ux = units(ut, xq)
    uy = units(ut, yq)

    fig, axs = plt.subplots(1, 2)
    axs[0].scatter(x, y)
    # NOTE(review): the other plot_* functions label axes as f'{q} {unit}' without adding
    # parentheses; confirm which form matches what units() returns.
    axs[0].set(xlabel=xq + " (" + ux + ")",
               ylabel=yq + " (" + uy + ")")
    axs[1].set_axis_off()
    fig.set_facecolor(color)

    # Detach this exact figure from pyplot (it is embedded in Tk instead). Looking up
    # figure number 1 could return a different, older figure.
    plt.close(fig)
    dpi = fig.get_dpi()
    size_x, size_y = size
    # Keep this close to the canvas size in pixels; a mismatch makes the mouse position
    # reported while hovering over the figure drift.
    fig.set_size_inches(size_x / float(dpi), size_y / float(dpi))
    # Lay out after resizing so label padding is computed for the final size.
    fig.tight_layout()
    draw_figure_w_toolbar(fig_cv, fig, toolbar_cv)

    # Structures with more than max_atoms atoms are cropped to a sphere of crop_radius
    # around the centre of mass and recentred in a cubic box of side box_length.
    max_atoms = 1000
    crop_radius = 5.0
    box_length = 10.0

    def on_click_view(event):
        view_element = win[f'file_view_{file_id}']
        view_element.set_cursor('watch')
        try:
            if event.inaxes is not axs[0] or file_id not in view_dict:
                return

            point1 = plotting_view_data[file_id]['point1']
            if point1 is not None:
                point1.pop(0).remove()
                plotting_view_data[file_id]['point1'] = None

            x_cp, _ = find_nearest_point(x, y, event.xdata, event.ydata)
            index = np.where(x == x_cp)[0][0]
            step = log.get('Step', n)
            # Snapshots exist only for some steps; jump to the closest stored one.
            available_steps = np.array([int(k) for k in view_dict[file_id]])
            key = find_nearest_value(available_steps, step[index])
            new_index = np.where(step == key)[0]

            atoms = view_dict[file_id][str(key)]
            if len(atoms) > max_atoms:
                offsets = atoms.positions - atoms.get_center_of_mass()
                keep = np.linalg.norm(offsets, axis=1) < crop_radius
                # Carry the atomic numbers over so the cropped structure keeps its
                # chemical identity (positions alone would give unnamed atoms).
                atoms = Atoms(numbers=atoms.numbers[keep],
                              positions=offsets[keep] + box_length / 2,
                              cell=[box_length] * 3)

            axs[1].cla()
            axs[1].set_axis_off()
            plot_atoms(atoms, axs[1], radii=1.0, rotation=('20x,30y,10z'))
            plotting_view_data[file_id]['point1'] = axs[0].plot(
                x[new_index], y[new_index], color='red', marker='o')
            fig.canvas.draw()
        finally:
            # Always restore the cursor, including for clicks that do nothing or raise.
            view_element.set_cursor('arrow')

    # `fig` is new on every call, so the previous handler stays on the previous figure.
    fig.canvas.mpl_connect("button_press_event", on_click_view)

    return axs