In [1]:
import lammps_logfile
import scipy
import PySimpleGUI as sg
import random
import numpy as np
import PySimpleGUI as sg
from ase.io.lammpsrun import read_lammps_dump_text
import math
from ase import Atoms

# Module-level caches filled by the loader functions in this notebook.
# rdf_dict[file_id] maps a timestep key (str) to the rows parsed by read_RDF_file.
rdf_dict = dict()
# view_dict[file_id] maps a timestep string to the frame stored by read_dump_file.
view_dict = dict()

In [2]:
def _scan_log_units(handle):
    """Return one unit style per run found in an open, seekable text log."""
    units = []
    command_unit = None
    handle.seek(0)
    for line in handle:
        if 'units' in line:
            tokens = line.split()
            if len(tokens) < 2:
                # A lone "units" token names no style; keep scanning instead of
                # failing with an IndexError.
                continue
            command_unit = tokens[1]
            break
        elif 'Unit style' in line:
            units.append(line.partition(': ')[2].rstrip())

    if command_unit is not None:
        # The style from the "units" line is reported once for every
        # 'Per MPI rank' line in the file.
        handle.seek(0)
        run_count = sum(1 for line in handle if 'Per MPI rank' in line)
        units.extend([command_unit] * run_count)
    return units


def read_log_file(file):
    """
    Reads a log file and returns the log and units.

    Parameters:
        file (str or file-like object): The path to the log file or an open,
            seekable text file object.

    Returns:
        log (lammps_logfile.File): The parsed log file.
        units (list): A list of unit types (e.g., 'metal') of simulations in the file.
    """
    log = lammps_logfile.File(file)
    if hasattr(file, 'seek'):
        units = _scan_log_units(file)
    else:
        # A path was given: open it ourselves so the unit scan can read it.
        with open(file) as handle:
            units = _scan_log_units(handle)
    return log, units

In [3]:
# Method used to load LAMMPS RDF files.
def read_RDF_file(file, file_id):
    if file_id not in list(rdf_dict.keys()):
        rdf_dict[file_id] = {}
    data_dict = rdf_dict[file_id]
    try:
        x = []
        col_number = []
        missing_col_keys = []
        redundant_col_keys = []
        for line in file:
            if not line.startswith("#"):
                t = line.strip('\n').split()
                t = [float(x) for x in t]
                if len(t) == 2:
                    key, N = str(int(t[0])), t[1]
                else:
                    length = len(t)
                    col_number.append(length)
                    if length == col_number[0]:
                        x.append(t)
                    elif length < col_number[0]:
                        r = col_number[0]-length
                        missing_col_keys.append(key)
                        [t.append(0.0) for i in range(r)]
                        x.append(t)
                    else:
                        r = length-col_number[0]
                        redundant_col_keys.append(key)
                        [t.pop(-i) for i in range(1, r+1)]
                        x.append(t)

                if len(x) == N:
                    data_dict.update({key: x})
                    x = []

        if len(missing_col_keys) > 0:
            sg.popup_ok(
                f'A mismatch in number of columns occurred in timestep(s)\n{", ".join(missing_col_keys)}\nwhile attempting to load\n{file.name}.\nMissing rows were filled with 0.0.', title='Warning')
        if len(redundant_col_keys) > 0:
            sg.popup_ok(
                f'A mismatch in number of columns occurred in timestep(s)\n{", ".join(redundant_col_keys)}\nwhile attempting to load\n{file.name}.\nRedundant rows were removed.', title='Warning')

        rdf_dict[file_id] = data_dict
        return rdf_dict, col_number[0]
    except ValueError:
        sg.popup_ok(
            f'An error occurred while attempting to load\n{file.name}.', title='Error')
    except TypeError:
        sg.popup_ok(
            f'An error occurred while attempting to load\n{file.name}.', title='Error')

In [4]:
def read_dump_file(file, file_id):
    """
    Loads every frame of a LAMMPS text dump into view_dict[file_id].

    The file is scanned once for the values following 'ITEM: TIMESTEP' lines,
    then parsed once with read_lammps_dump_text to obtain all frames. Frames are
    stored under their timestep string. A warning popup is shown when the last
    stored frame holds more than 5000 atoms.

    Parameters:
        file (file-like object): Open, seekable text dump; `name` is used in popups.
        file_id (str): Key under which the frames are stored in view_dict.

    Returns:
        dict: view_dict on success, or None when loading failed (an error popup
        is shown in that case).
    """
    try:
        frames = view_dict.setdefault(file_id, {})
        timesteps = []
        for line in file:
            if 'ITEM: TIMESTEP' in line:
                timesteps.append(file.readline().strip())
        if not timesteps:
            raise ValueError('no timesteps found in dump file')

        # Parse the file a single time; asking for one index per call would
        # re-read the whole dump once for every timestep.
        file.seek(0)
        images = read_lammps_dump_text(file, index=slice(None))

        last_frame = None
        for key, atoms in zip(timesteps, images):
            frames[key] = atoms
            last_frame = atoms
        if last_frame is None:
            raise ValueError('no frames could be parsed from dump file')

        number_of_atoms = len(last_frame)
        if number_of_atoms > 5000:
            sg.popup_ok(
                f'Dump file\n{file.name}\ncontains more than 5000 atoms. A smaller section of the simulation box will be shown fot the sake of clarity and performance.', title='Warning')
        return view_dict
    except (ValueError, TypeError):
        sg.popup_ok(
            f'An error occurred while attempting to load\n{file.name}.', title='Error')

In [5]:
def linear_fit(x, y, bound1_id, bound2_id):
    """
    Performs a linear regression on a given set of data points within specified bounds.

    The bounds may be given in either order; both bound indices are included in
    the fitted range.

    Parameters:
        x (ndarray): The x-coordinates of the data points.
        y (ndarray): The y-coordinates of the data points.
        bound1_id (int): The index of the first bound data point.
        bound2_id (int): The index of the second bound data point.

    Returns:
        tuple: The slope (a) and y-intercept (b) of the linear regression line.
        When the bounds are fewer than three indices apart, no fit is performed
        and the placeholder (1, 1) is returned.
    """
    # Import the submodule explicitly: a bare `import scipy` does not guarantee
    # that `scipy.stats` is available as an attribute on every SciPy version.
    from scipy import stats

    lower_bound_id, higher_bound_id = sorted((bound1_id, bound2_id))
    if higher_bound_id - lower_bound_id <= 2:
        return 1, 1

    x_new = x[lower_bound_id:higher_bound_id + 1]
    y_new = y[lower_bound_id:higher_bound_id + 1]
    a, b = stats.linregress(x_new, y_new)[:2]
    return a, b

In [6]:
# Display units per unit style and thermo quantity; '' means no unit is shown.
units_dict = {'metal': {'Step': '', 'Time': 'ps', 'TotEng': 'eV', 'Temp': 'K', 'Volume': '$\\AA^3$', 'Density': '$g/cm^3$', 'Press': 'bar', 'Ndanger': '', 'PotEng': 'eV', 'KinEng': 'eV', 'Enthalpy': 'eV'},
              'real': {'Step': '', 'Time': 'fs', 'TotEng': 'kcal/mol', 'Temp': 'K', 'Volume': '$\\AA^3$', 'Density': '$g/cm^3$', 'Press': 'atm', 'Ndanger': '', 'PotEng': 'kcal/mol', 'KinEng': 'kcal/mol', 'Enthalpy': 'kcal/mol'},
              'SI': {'Step': '', 'Time': 's', 'TotEng': 'J', 'Temp': 'K', 'Volume': '$m^3$', 'Density': '$kg/m^3$', 'Press': 'Pa', 'Ndanger': '', 'PotEng': 'J', 'KinEng': 'J', 'Enthalpy': 'J'},
              'LJ': {'Step': '', 'Time': '', 'TotEng': '', 'Temp': '', 'Volume': '', 'Density': '', 'Press': '', 'Ndanger': '', 'PotEng': '', 'KinEng': '', 'Enthalpy': ''}}


def units(ut, q):
    """
    Returns the units for the given quantity `q` based on the unit type `ut`.

    The unit type is matched exactly first and then case-insensitively, so the
    lowercase style names written in log files ('si', 'lj') find the 'SI' and
    'LJ' tables. Unknown unit types and unknown quantities yield ''.

    Parameters:
        ut (str): The unit type of the quantity.
        q (str): The name of the quantity.

    Returns:
        str: The unit wrapped in parentheses, or '' when there is none.
    """
    table = units_dict.get(ut)
    if table is None:
        wanted = str(ut).lower()
        table = next(
            (entry for name, entry in units_dict.items() if name.lower() == wanted),
            {},
        )

    unit = table.get(q, '')
    return f'({unit})' if unit != '' else ''

In [7]:
def slope_unit(ux, uy):
    """
    Returns the slope unit based on the given x and y units.

    A slope is dy/dx, so its unit is the y unit divided by the x unit. Input
    units may be wrapped in parentheses and may contain '$' markers; both are
    stripped before the result is assembled.

    Parameters:
        ux (str): The unit for the x-axis ('' when dimensionless).
        uy (str): The unit for the y-axis ('' when dimensionless).

    Returns:
        str: The slope unit in parentheses, or '' when both inputs are empty.
    """
    x_unit = ux.replace('(', '').replace(')', '')
    y_unit = uy.replace('(', '').replace(')', '')

    if ux == '' and uy == '':
        result = ''
    elif ux == '':
        # Dimensionless x: the slope carries the y unit unchanged.
        result = f'({y_unit})'
    elif uy == '':
        # Dimensionless y: the slope is "per x unit".
        result = f'(1/{x_unit})'
    else:
        result = f'({y_unit}/{x_unit})'

    return result.replace('$', '')

def single_unit(unit):
    """
    Strips the '$' markers from a unit string.

    Parameters:
        unit (str): The unit to format.

    Returns:
        str: The unit without any '$' characters; the input is returned as-is
        when it contains none.
    """
    return unit.replace('$', '') if '$' in unit else unit

In [8]:
def format_number(number, prec=4):
    """
    Rounds the given number to `prec` decimals, counted after any leading zeros.

    For magnitudes of 1 or more the number is rounded to `prec` decimal places.
    For magnitudes below 1 the zeros directly after the decimal point are not
    counted, so small values keep `prec` meaningful digits. This also holds for
    negative values and for values whose string form uses exponent notation.

    Parameters:
        number (float): The number to format.
        prec (int, optional): The number of decimals to keep after the leading
            zeros. Defaults to 4.

    Returns:
        float: The rounded number. Integers and values whose text form has
        neither a decimal point nor an exponent (e.g. inf, nan) are returned
        unchanged.
    """
    text = str(number).lower()
    if isinstance(number, int) or ('.' not in text and 'e' not in text):
        return number

    magnitude = abs(float(number))
    if 0 < magnitude < 1:
        # Zeros between the decimal point and the first non-zero digit,
        # e.g. 0.00123 -> 2. Zeros further right are real digits.
        leading_zeros = -math.floor(math.log10(magnitude)) - 1
    else:
        leading_zeros = 0

    return float(format(number, f'.{prec + leading_zeros}f'))

In [9]:
def intersection(a1, b1, a2, b2):
    """
    Calculates the intersection point of two lines in a 2D coordinate system.

    Parameters:
        a1 (float): The slope of the first line.
        b1 (float): The y-intercept of the first line.
        a2 (float): The slope of the second line.
        b2 (float): The y-intercept of the second line.

    Returns:
        tuple: The coordinates of the intersection point (x, y). When the slopes
        are equal, or the result is not finite, a warning popup is shown; equal
        slopes yield (nan, nan).
    """
    slope_difference = a1 - a2
    if slope_difference == 0:
        # Equal slopes: dividing would raise ZeroDivisionError for Python floats.
        x_intersection = y_intersection = float('nan')
    else:
        x_intersection = (b2 - b1) / slope_difference
        y_intersection = a1 * x_intersection + b1

    if not (math.isfinite(x_intersection) and math.isfinite(y_intersection)):
        sg.popup_ok(
            'Error occurred while calculating intersection point. Lines are almost parallel.', title='Warning')

    return x_intersection, y_intersection


In [10]:
def find_nearest_point(x, y, x0, y0):
    """
    Finds the nearest point from the (x,y) data to a target point.

    Distances are measured after scaling each axis by the range of its data, so
    both axes weigh equally. An axis whose values are all equal has zero range;
    it is left unscaled instead of dividing by zero.

    Parameters:
        x (ndarray): The array of numeric x values.
        y (ndarray): The array of numeric y values.
        x0 (float): The 'x' value of the target point.
        y0 (float): The 'y' value of the target point.

    Returns:
        tuple: The (x, y) values of the data point nearest to the target point.
        On ties the first such point is returned.
    """
    xs = np.asarray(x, dtype=float)
    ys = np.asarray(y, dtype=float)

    x_norm = abs(np.max(xs) - np.min(xs))
    y_norm = abs(np.max(ys) - np.min(ys))
    # A constant axis would give 0/0 = NaN distances and no point would ever be
    # selected; scaling by 1 keeps that axis comparable.
    if x_norm == 0:
        x_norm = 1.0
    if y_norm == 0:
        y_norm = 1.0

    distances = np.sqrt(((xs - x0) / x_norm) ** 2 + ((ys - y0) / y_norm) ** 2)
    nearest_point_index = int(np.argmin(distances))

    return x[nearest_point_index], y[nearest_point_index]

In [12]:
def find_nearest_value(array, value):
    """
    Returns the element of `array` that lies closest to `value`.

    Parameters:
        array (ndarray): The array of numeric values (any array-like is accepted).
        value (float): The target value to find the nearest value to.

    Returns:
        float: The nearest value in the array to the target value; on ties the
        first such element is returned.
    """
    candidates = np.asarray(array)
    offsets = np.abs(candidates - value)
    return candidates[np.argmin(offsets)]

In [13]:
def update_headings(table, headings):
    """
    Sets new heading texts on the fit table's columns.

    The headings are paired, in order, with the fixed column identifiers
    'Line', 'Slope', 'Intercept', 'Boundry 1' and 'Boundry 2'; surplus entries
    on either side are ignored.

    Parameters:
        table (TableWidget): The table widget to update; its `heading` method
            is called once per paired column.
        headings (list): The list of new headings.

    Returns:
        None
    """
    column_ids = ('Line', 'Slope', 'Intercept', 'Boundry 1', 'Boundry 2')
    labels_by_column = dict(zip(column_ids, headings))
    for column_id, label in labels_by_column.items():
        table.heading(column_id, text=label)

In [14]:
def setup_tab(keywords, n, window, file_id, plotting_rdf_data, plotting_fit_data, plotting_view_data):
    """
    Setup the tab with default values for combo boxes based on the provided keywords. Then trigger events that plot all the plots inside the tab.

    The preferred default axis pairs are (Temp, TotEng), (time, Temp),
    (time, Press) and (time, Ndanger), where time is 'Time' if available and
    otherwise 'Step'. Any pair that names a quantity missing from `keywords`
    is replaced by two randomly chosen keywords.

    Parameters:
        keywords (list): A list of keywords.
        n (int): The number of runs in the log file.
        window: The window object.
        file_id (str): The ID of the file.
        plotting_rdf_data (dict): A dictionary of plotting data for RDF plots.
        plotting_fit_data (dict): A dictionary of plotting data for fit plots.
        plotting_view_data (dict): A dictionary of plotting data for view plots.

    Returns:
        None
    """
    combo_n = np.arange(1, n+1).tolist()
    first_run = combo_n[0] if combo_n else ''

    # Each membership test is written out: `'Step' or 'Time' in keywords` is
    # always true because the non-empty string 'Step' is truthy.
    if 'Time' in keywords:
        time_or_step = 'Time'
    elif 'Step' in keywords:
        time_or_step = 'Step'
    else:
        time_or_step = ''

    default_xys = [('Temp', 'TotEng'), (time_or_step, 'Temp'), (time_or_step, 'Press'), (time_or_step, 'Ndanger')]

    for i, (x_name, y_name) in enumerate(default_xys):
        # Both names must exist; `(a and b) not in keywords` checks only one.
        if x_name not in keywords or y_name not in keywords:
            if len(keywords) >= 2:
                default_xys[i] = tuple(random.sample(keywords, 2))
            elif len(keywords) == 1:
                default_xys[i] = (keywords[0], keywords[0])
            else:
                default_xys[i] = ('', '')

    for i in range(1, 5):
        window[f'combo_n{i}_{file_id}'].update(
            value=first_run, values=combo_n)
        window[f'combo_x{i}_{file_id}'].update(
            value=default_xys[i-1][0], values=keywords)
        window[f'combo_y{i}_{file_id}'].update(
            value=default_xys[i-1][1], values=keywords)

    for i in ('fit', 'rdf', 'view'):
        window[f'combo_x_{i}_{file_id}'].update(
            value=default_xys[0][0], values=keywords)

        window[f'combo_y_{i}_{file_id}'].update(
            value=default_xys[0][1], values=keywords)

        window[f'combo_n_{i}_{file_id}'].update(
            value=first_run, values=combo_n)

    # Start every plot family without a connected click handler.
    plotting_rdf_data[file_id] = {'c_id': None}
    plotting_fit_data[file_id] = {'c_id': None}
    plotting_view_data[file_id] = {'c_id': None}

    for i in ('overview', 'fit', 'rdf', 'view'):
        window.write_event_value(f'plot_{i}_{file_id}', None)


In [15]:
def reset_rdf_data(plotting_rdf_data, file_id, c_id):
    """
    Replaces the RDF plotting state of a file with a fresh, empty one.

    Parameters:
        plotting_rdf_data (dict): Per-file RDF plotting state, modified in place.
        file_id (str): The ID of the file whose state is reset.
        c_id: Value stored under 'c_id' in the new state.

    Returns:
        None
    """
    plotting_rdf_data[file_id] = {
        'c_id': c_id,
        'point1': None,
        'point2': None,
        'x': np.array([]),
        'y': np.array([]),
    }

In [None]:
def reset_fit_data(plotting_fit_data, selected_rows_dict, file_id, c_id, win):
    """
    Clears the fit state of a file and empties the related window elements.

    Parameters:
        plotting_fit_data (dict): Per-file fit plotting state, modified in place.
        selected_rows_dict (dict): Per-file selected table rows, modified in place.
        file_id (str): The ID of the file whose state is reset.
        c_id: Value stored under 'c_id' in the new state.
        win: The window object holding the intersection text and fit table.

    Returns:
        None
    """
    fresh_state = {
        'c_id': c_id,
        'boundries': [],
        'lines': [],
        'table_rows': [],
        'intersection_line': None,
    }
    plotting_fit_data[file_id] = fresh_state
    selected_rows_dict[file_id] = []

    # Blank the widgets that displayed the previous fit.
    win[f'intersection_value_{file_id}'].update('')
    win[f'table_fit_{file_id}'].update(values=[])

In [16]:
def remove_red_point(plotting_rdf_data, fig, file_id, name):
    """
    Removes a stored marker from the figure, if one is present.

    Parameters:
        plotting_rdf_data (dict): Per-file RDF plotting state, modified in place.
        fig: The figure whose canvas is redrawn after the removal.
        file_id (str): The ID of the file.
        name (str): Key of the stored marker inside the file's state.

    Returns:
        None
    """
    artists = plotting_rdf_data[file_id][name]
    if artists is None:
        # Nothing is drawn for this key, so there is nothing to clean up.
        return

    # The stored value is a list-like whose first item is the drawn artist.
    artists.pop(0).remove()
    plotting_rdf_data[file_id][name] = None
    fig.canvas.draw()

In [17]:
def find_rdf_point(x, y, x_click, y_click, file_id, steps):
    """
    Maps a click on the plot to the stored RDF data of the closest timestep.

    The data point nearest to the click is located, its step is looked up in
    `steps`, and the stored RDF timestep closest to that step is selected.

    Parameters:
        x (ndarray): The plotted x values.
        y (ndarray): The plotted y values.
        x_click (float): The x coordinate of the click.
        y_click (float): The y coordinate of the click.
        file_id (str): The ID of the file whose RDF data is used.
        steps (ndarray): The step of every plotted sample, aligned with x and y.

    Returns:
        tuple: (x values, y values, rdf_data) where the x and y values are the
        entries of x and y at the positions whose step equals the selected RDF
        timestep, and rdf_data is that timestep's rows as an array.
    """
    x_np, y_np = find_nearest_point(x, y, x_click, y_click)

    # Match on both coordinates: x alone is ambiguous when the same x value
    # occurs at several steps, which would select the wrong sample.
    matches = np.where((x == x_np) & (y == y_np))[0]
    if matches.size == 0:
        matches = np.where(x == x_np)[0]
    key = steps[matches[0]]

    available_steps = np.array([int(step) for step in rdf_dict[file_id].keys()])
    key = find_nearest_value(available_steps, key)
    new_index = np.where(steps == key)[0]

    rdf_data = np.array(rdf_dict[file_id][str(key)])
    return x[new_index], y[new_index], rdf_data

In [18]:
def get_rdf_col_number(value):
    """
    Converts a column selection to an integer, defaulting to 1.

    Parameters:
        value: The selected value; '' means nothing was selected.

    Returns:
        int: int(value), or 1 when value is the empty string.
    """
    return 1 if value == '' else int(value)

In [19]:
def disconnect_on_click(fig, c_id):
    """
    Disconnects a previously registered canvas callback.

    Parameters:
        fig: The figure whose canvas holds the callback.
        c_id: The connection id to disconnect, or None when nothing is connected.

    Returns:
        None
    """
    # Test for None rather than truthiness: connection ids are integers, so an
    # id of 0 would otherwise be silently left connected.
    if c_id is not None:
        fig.canvas.mpl_disconnect(c_id)

In [1]:
def choose_atoms(atoms, center, size):
    """
    Cuts a smaller box out of `atoms` and returns it as a new Atoms object.

    `center` and `size` are percentages of the cell lengths: the sub-box is
    centred at center % of each cell length and spans size % of each cell
    length. Atoms strictly inside the sub-box are kept and shifted so that the
    sub-box starts at the origin.

    Parameters:
        atoms: The source structure; its cell lengths, positions and atomic
            numbers are read.
        center: Centre of the sub-box in percent of the cell lengths
            (presumably an array with one entry per axis - confirm with callers).
        size (float): Edge length of the sub-box in percent of the cell lengths.

    Returns:
        Atoms: The selected atoms in a cell of size % of the original lengths.
    """
    lengths = atoms.cell.lengths()
    fraction = size / 100

    box_center = center / 100 * lengths
    half_extent = fraction * lengths / 2
    lower = box_center - half_extent
    upper = box_center + half_extent

    coords = atoms.positions
    # Strict inequalities: atoms exactly on a face are left out.
    inside = ((coords < upper) & (coords > lower)).all(axis=1)
    indices = np.where(inside)[0]

    return Atoms(
        positions=coords[indices] - box_center + half_extent,
        numbers=atoms.get_atomic_numbers()[indices],
        cell=list(fraction * lengths),
    )