# Referee Comments

## Imports

In [None]:
import os
import sys
import pickle
sys.path.append("../py_src")

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import patches
from matplotlib import ticker
from matplotlib import colors as mcolors
from matplotlib import cm
from matplotlib.lines import Line2D
from matplotlib.colors import ListedColormap
from adjustText import adjust_text

import sort_neigh

from ase import Atoms
from ase.io import read as ase_read
from ase.neighborlist import natural_cutoffs, NeighborList

In [None]:
### Renaming Dict for better labels:
# Maps raw category identifiers to the labels shown in plots. The dict is always
# defined so that code consulting it keeps working when pretty_label is False.
pretty_label = True
prettylabel_dict = {
    "12": "12 neighbours",
    "11": "11 neighbours",
    "3_111_ad_atom": "(111) adatom",
    "4_111_ad_atom_pair": "(111) adatom pair",
    "4_100_ad_atom": "(100) adatom",
    "5_211_ad_atom": "(211) adatom",
    "5_110_ad_atom": "(110) adatom",
    "5_111_terrace": "(111) terrace",
    "6_100-110_interface": "(100)-(110)\ninterface",
    "6_100_terrace": "(100) terrace",
    "7_110": "(110)",
    "7_211": "(211)",
    "8_100": "(100)",
    "8_111_vacant_site": "(111) vacancy",
    "9_111": "(111)",
    "9_invalid": "Subsurface 1",
    "9_invalid_2": "Subsurface 2",
}

# Get colorbar from fig 1
# NOTE(review): pickle.load can execute arbitrary code; only load a file that
# was produced by a trusted run of the figure-1 notebook.
filename = 'fig1_cbar.pkl'

# The context manager closes the file on exit, so no explicit close() is needed.
with open(filename, 'rb') as file:
    data_cat_cmap, data_cat_norm, data_label_color_dict, draw_cats = pickle.load(file)

# Colorbar tick labels: pretty names where available, raw category names otherwise.
if pretty_label:
    pretty_draw_cats = [prettylabel_dict.get(draw_cat, draw_cat) for draw_cat in draw_cats]
else:
    pretty_draw_cats = list(draw_cats)


## Analyse 800K Trajectory

In [None]:
def draw_rhbars(ax, plot_lines, plot_errs, all_not_zero, label_color_dict, sorted_cats, **kwargs):
    """Draw the percentage of sites per category as error-bar lines on a log axis.

    This function relies on the globally defined ``pretty_label`` flag and
    ``prettylabel_dict`` mapping to turn category names into display labels.

    Args:
        ax (plt.axis): Axis to draw on.
        plot_lines (np.ndarray): Values to plot, indexed as [temperature, category].
        plot_errs (np.ndarray): Errors belonging to ``plot_lines``, indexed the same way.
        all_not_zero (np.ndarray of bool): One entry per category (``plot_lines.shape[1]``),
            True for every category that should be drawn.
        label_color_dict (dict): Color for each category name. Categories without an
            entry are drawn in gray.
        sorted_cats (sequence of str): Category names in the column order of ``plot_lines``.
        **kwargs: Optional settings: ``fontsize`` (default 12), ``ms`` marker size
            (default 10), ``lw`` line width (default 3) and ``temperatures``, the
            x tick labels with one entry per row of ``plot_lines`` (default ``[800]``).

    Returns:
        tuple (plt.axis, list of legend elements)
    """
    temperatures = kwargs.get('temperatures', [800])
    fontsize = kwargs.get('fontsize', 12)
    ms = kwargs.get('ms', 10)
    lw = kwargs.get('lw', 3)

    # Only plot entries that are non-zero in some line.
    nz_slicer = np.s_[:, all_not_zero]
    rh_lines = plot_lines[nz_slicer]
    rh_errs = plot_errs[nz_slicer]

    total_nonzero = int(np.count_nonzero(all_not_zero))
    x_range = np.arange(len(temperatures))

    # Names of the categories that survive the mask, in column order.
    draw_cats = [cat for ii_cat, cat in enumerate(sorted_cats) if all_not_zero[ii_cat]]

    # Categories without a dedicated color fall back to gray.
    draw_colors = [label_color_dict.get(draw_cat, 'gray') for draw_cat in draw_cats]

    # Get labels from pretty label dict
    if pretty_label:
        plot_labels = [prettylabel_dict.get(nz_cat, nz_cat) for nz_cat in draw_cats]
    else:
        plot_labels = draw_cats

    legend_elements = []
    for ii_line in range(total_nonzero):
        # Exact zeros are left out of the line (presumably because they cannot
        # be shown on the logarithmic y axis set below).
        nz_tf = rh_lines[:, ii_line] != 0
        ax.errorbar(
            x_range[nz_tf], rh_lines[:, ii_line][nz_tf], yerr=rh_errs[:, ii_line][nz_tf],
            color=draw_colors[ii_line], marker='o',
            markersize=ms, linewidth=lw,
            label=plot_labels[ii_line]
        )
        # Marker-only proxy artist so callers can build a separate legend.
        legend_elements.append(Line2D(
            [0], [0],
            linestyle="",
            color=draw_colors[ii_line], fillstyle='left',
            marker='o', markersize=ms,
            lw=lw, label=plot_labels[ii_line]))

    ax.set_yscale('log')

    ax.set_xticks(x_range)
    ax.set_xticklabels(temperatures, fontsize=fontsize*0.8)

    return ax, legend_elements

In [None]:
# Descriptor and kernel settings passed to the sorter here and reused by later cells.
r_cut = 5.2  # other values left in place by the author: 4.2, 2.7
n_max = 4
l_max = 3
sigma = 1.0
gamma_kernel = 1.0

# presumably the number of rhodium atoms in the particle — TODO confirm
n_rhod = 15

target_path = "../test_data/230523_refreply_800.lammpstrj"

sorter = sort_neigh.NeighbourSort(
    local_structures_path=os.path.abspath("../src/localstructures_final_mcmd_rh"),
    r_cut=r_cut,
    n_max=n_max,
    l_max=l_max,
    sigma=sigma,
    gamma_kernel=gamma_kernel,
)

# The counts file sits next to the trajectory and reuses its base name
# (everything before the first dot).
fname = os.path.basename(target_path).partition('.')[0]
save_txt_path = os.path.join(os.path.dirname(target_path), f"{fname}_soap_sorted_counts.txt")

In [None]:
# Switch for re-running the classification of the trajectory at target_path.
# It is left False, so the counts file at save_txt_path must already exist for
# the load_sort_cat call further down.
new_trajectory = False
if new_trajectory:
    sorter.load_particle(target_path)
    # presumably last_n=n_rhod restricts the classification to the last n_rhod
    # atoms of each frame — confirm against sort_neigh.NeighbourSort
    cat_counter = sorter.create_local_structure(last_n=n_rhod, create_subfolders=False, mode="class_all", cutoff_mult=0.98)

    # Write the category counts to save_txt_path.
    file_name = sorter.sort_save_cat(
        file_name=save_txt_path, cat_counter=cat_counter
    )

In [None]:
# Read back the stored category counts and estimate their errors by block averaging.
sorted_counts, timesteps, sorted_cats = sort_neigh.NeighbourSort.load_sort_cat(save_txt_path)
sorted_errs = sort_neigh.NeighbourSort.block_average(sorted_counts, block_size=1000)

# Express totals and errors as a percentage of (number of rows * n_rhod).
n_samples = sorted_counts.shape[0] * n_rhod
rh_totals = np.sum(sorted_counts, axis=0) / n_samples * 100
plot_errs = sorted_errs / n_samples * 100

# Bulkify but also remove some misclassifications (1-5 events in trajectory)
# NOTE(review): the target index -3 is hard-coded; presumably it is the bulk
# ("12") entry of sorted_cats — confirm against the loaded category order.
bulk_cats = ["9_invalid", "9_invalid_2", "10", "11", "13", "3_111_ad_atom"]
for ii_cat, cat in enumerate(sorted_cats):
    if cat not in bulk_cats:
        continue

    print("Bulkified: ", cat)

    # Fold value and error into the target entry ...
    rh_totals[-3] += rh_totals[ii_cat]
    plot_errs[-3] += plot_errs[ii_cat]

    # ... and clear the source entry so it is not drawn on its own.
    rh_totals[ii_cat] = 0
    plot_errs[ii_cat] = 0

In [None]:
# Start from matplotlib's defaults, then use a serif font rendered through LaTeX.
plt.rcParams.update(plt.rcParamsDefault)
fontsize = 12
font = {'family': 'serif', 'size': fontsize}
plt.rc('font', **font)
plt.rc('text', usetex=True)
plt.rc('text.latex', preamble=r'\usepackage{amsmath}')

fig, ax = plt.subplots(figsize=(2, 4))

# The totals are 1-D; add a leading axis so they form a single row for draw_rhbars.
ax, legend_elements = draw_rhbars(
    ax,
    plot_lines=rh_totals[None, ...],
    plot_errs=plot_errs[None, ...],
    all_not_zero=(rh_totals != 0),
    label_color_dict=data_label_color_dict,
    sorted_cats=sorted_cats,
    fontsize=12,
    ms=7,
)

# Category colorbar built from the norm and colormap loaded from the fig. 1 pickle.
cat_mappable = cm.ScalarMappable(norm=data_cat_norm, cmap=data_cat_cmap)
bar = fig.colorbar(cat_mappable, ax=ax)
bar.set_ticks(np.arange(len(draw_cats)))
bar.set_ticklabels(pretty_draw_cats, fontsize=fontsize)

ax.set_title('Rhodium', fontsize=fontsize, visible=True)
ax.set_xlabel("Temperature (K)")
ax.set_ylabel(
    r'Percentage of Total Configurations (\%)',
    fontsize=fontsize,
    loc="bottom",
)

fig.savefig("refreply_800K.pdf", format='pdf', dpi=300, bbox_inches='tight')

plt.show()

## Classify Higher Index Surfaces

This shows that the nearest-neighbour environments of atoms on higher-index surfaces behave like those on lower-index surfaces.

In [None]:
def cu_classify(full_particle: Atoms, neighbour_list: NeighborList, rh_classifier: sort_neigh.NeighbourClassifier, mode="class_all"):
    """Classify the local environment of every atom in ``full_particle``.

    For each atom, a cluster of the atom plus its neighbours from
    ``neighbour_list`` is cut out. All cluster atoms are relabelled 'Cu' and the
    central atom 'Rh' before the cluster is passed to ``rh_classifier``.

    Args:
        full_particle (Atoms): Structure whose atoms are classified.
        neighbour_list (NeighborList): Neighbour list queried per atom index.
        rh_classifier (sort_neigh.NeighbourClassifier): Classifier providing
            ``descr``, ``n_classes`` and ``classify``.
        mode (str): Forwarded to ``rh_classifier.classify``.

    Returns:
        tuple: (number of atoms per class, neighbour count per atom,
        class id per atom, descriptor vector per atom)
    """
    n_atoms = len(full_particle)

    # Per-class tally and per-atom results.
    cu_cat_counter = np.zeros(shape=(rh_classifier.n_classes), dtype=np.int32)
    categories = np.zeros((n_atoms,), dtype=np.int32)
    neighbours = np.zeros((n_atoms,), dtype=np.int32)
    ind_soaps = np.zeros((n_atoms, rh_classifier.descr.get_number_of_features()))

    for centre in range(n_atoms):
        # The second return value of get_neighbors is not needed here.
        neigh_ids, _unused = neighbour_list.get_neighbors(centre)

        # Put the central atom first so that it sits at index 0 of the cluster.
        cluster_ids = np.concatenate((np.array([centre]), neigh_ids), axis=0)
        cluster = full_particle[cluster_ids]

        # Relabel: every atom copper, the centre rhodium.
        cluster.symbols[:] = 'Cu'
        cluster.symbols[0] = 'Rh'

        # Descriptor of the centre atom, then its classification.
        ind_soaps[centre] = rh_classifier.descr.create(cluster, centers=[0])
        n_neigh, class_id = rh_classifier.classify(cluster, mode=mode, ensure_position=False)

        cu_cat_counter[class_id] += 1
        neighbours[centre] = int(n_neigh)
        categories[centre] = int(class_id)

    return cu_cat_counter, neighbours, categories, ind_soaps

In [None]:
# Build the classifier used for the high-index surfaces. Note that it reads
# "../src/localstructures_final_mc", a different directory from the one given
# to the sorter above ("../src/localstructures_final_mcmd_rh").
# NOTE(review): the meaning of non_class_max=14 is not visible here — confirm
# against sort_neigh.NeighbourClassifier.
rh_classifier = sort_neigh.NeighbourClassifier(
    local_structures_path=os.path.abspath("../src/localstructures_final_mc"),
    non_class_max=14
)
# Load the identifiers with the same descriptor and kernel settings
# (r_cut, n_max, l_max, sigma, gamma_kernel) that were passed to the sorter.
rh_classifier.load_identifiers(
    r_cut=r_cut, n_max=n_max, l_max=l_max, 
    sigma=sigma, gamma_kernel=gamma_kernel,
)

In [None]:
%matplotlib inline
# Classify every atom of the (211) and (311) supercells with cu_classify and
# collect the results per cell (index 0: 211, index 1: 311).
high_index_211 = ase_read('supercell-211.xyz')
high_index_311 = ase_read('supercell-311.xyz')

hi_catcounts = []
hi_neighbours = []
hi_cats = []
hi_soaps = []
# hi_pcas and hi_plot_colors are never filled in this cell.
hi_pcas = []
hi_plot_colors = []

for high_index_cell in [high_index_211, high_index_311]:
    # Define a neighbor list to get neighborhood of each copper atom
    # NOTE(review): ase.neighborlist.natural_cutoffs appears to treat extra
    # keyword arguments as per-element cutoff overrides, so `pbc=` most likely
    # has no effect here — confirm against the ASE documentation and set the
    # periodicity on the Atoms object instead if that was the intent.
    cut_off = natural_cutoffs(high_index_cell, mult=.98, pbc=[True, True, False])
    high_index_nl = NeighborList(cut_off, bothways=True, self_interaction=False)
    high_index_nl.update(high_index_cell)

    high_index_cu_cat_counter, high_index_neighbours, high_index_categories, high_index_ind_soaps = cu_classify(high_index_cell, high_index_nl, rh_classifier, mode='pre_group')

    hi_catcounts.append(high_index_cu_cat_counter)
    hi_neighbours.append(high_index_neighbours)
    hi_cats.append(high_index_categories)
    hi_soaps.append(high_index_ind_soaps)
    
    # Disabled debug plot: 3D scatter of the scaled positions coloured by category id.
    if False:
        fig = plt.figure(figsize=(8, 8))
        ax = fig.add_subplot(projection='3d')
        pos = high_index_cell.get_scaled_positions()

        ax.scatter(
            pos[:, 0], pos[:, 1], pos[:, 2], c=high_index_categories, alpha=1,
            s=500, edgecolors="k", rasterized=True
        )
        ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
        ax.set_axis_off()
        plt.show()

    # hi_plot_colors.append([label_color_dict[rh_classifier.id_to_cat(col)] for col in high_index_categories])


In [None]:
# Window in scaled coordinates: drop a margin of xy_cut on each side in x and y,
# and keep only atoms with scaled z between 0.58 and 1.
xy_cut = 0.2
borders = np.array([[xy_cut, 1-xy_cut], [xy_cut, 1-xy_cut], [0.58, 1]], dtype=np.float32).T
lower_bounds, upper_bounds = borders

# An atom is kept when all three scaled coordinates lie strictly inside the window.
lpos = high_index_211.get_scaled_positions(wrap=True)
lcond = np.all((lower_bounds < lpos) & (lpos < upper_bounds), axis=-1)
rpos = high_index_311.get_scaled_positions(wrap=True)
rcond = np.all((lower_bounds < rpos) & (rpos < upper_bounds), axis=-1)

# The 7 neighbour structures are hard enough to classify as it is, but this is just for
# demonstration, so '7_110' is simply drawn in the colour of '7_211' here for clarity.
plot_label_color_dict = data_label_color_dict.copy()
plot_label_color_dict['7_110'] = plot_label_color_dict['7_211']

# Colours of the kept atoms, looked up through their category names.
lcol = [
    plot_label_color_dict[rh_classifier.id_to_cat(cat_id)]
    for cat_id, keep in zip(hi_cats[0], lcond)
    if keep
]
rcol = [
    plot_label_color_dict[rh_classifier.id_to_cat(cat_id)]
    for cat_id, keep in zip(hi_cats[1], rcond)
    if keep
]

In [None]:
# Start from matplotlib's defaults, then use a serif font rendered through LaTeX.
plt.rcParams.update(plt.rcParamsDefault)
fontsize = 12
font = {'family': 'serif', 'size': fontsize}
plt.rc('font', **font)
plt.rc('text', usetex=True)
plt.rc('text.latex', preamble=r'\usepackage{amsmath}')

fig = plt.figure(figsize=(9, 4))

ms = 1500

# Layout in figure coordinates: 3D panel on the left, a narrow vertically
# centred colorbar in the middle, 3D panel on the right.
cb_x = 0.05
cb_y = 0.6
l_ax = fig.add_axes((0, 0, 0.4, 1), projection='3d')
l_ax.view_init(elev=30, azim=45, roll=15)
cb_ax = fig.add_axes((0.4, (1-cb_y)/2, cb_x, cb_y))
r_ax = fig.add_axes((0.6, 0, 0.4, 1), projection='3d')
r_ax.view_init(elev=30, azim=20, roll=0)

# Scatter the selected atoms of each cell, coloured by their category.
for panel, frac_pos, keep, point_colors in (
    (l_ax, lpos, lcond, lcol),
    (r_ax, rpos, rcond, rcol),
):
    panel.scatter(
        frac_pos[keep, 0], frac_pos[keep, 1], frac_pos[keep, 2],
        c=point_colors, alpha=1,
        s=ms, edgecolors="k", rasterized=True
    )

# Category colorbar built from the norm and colormap loaded from the fig. 1 pickle.
cat_mappable = cm.ScalarMappable(norm=data_cat_norm, cmap=data_cat_cmap)
bar = fig.colorbar(cat_mappable, cax=cb_ax)
bar.set_ticks(np.arange(len(draw_cats)))
bar.set_ticklabels(pretty_draw_cats, fontsize=fontsize)

# Hide the 3D axes decorations and title each panel.
for ax, panel_title in ((l_ax, '211'), (r_ax, '311')):
    ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
    ax.set_axis_off()
    ax.set_title(panel_title)

fig.savefig("refreply_highindex.pdf", format='pdf', dpi=300, bbox_inches='tight')
plt.show()