# Produce Plots for Publication

This notebook uses the self-built surface classification library `surface_classify`, available at https://github.com/FelixWodaczek/surface_classify.git.

It imports some precalculated values and analyses them to draw Figure 1.

## Imports

In [None]:
import sys
sys.path.append("../py_src")

from glob import glob
import os

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import patches
from matplotlib import ticker
from matplotlib import colors as mcolors
from matplotlib import cm
from matplotlib.colors import ListedColormap
from adjustText import adjust_text

import sort_neigh

from ase.io import read as ase_read
from ase.neighborlist import natural_cutoffs, NeighborList
from dscribe.descriptors import LMBTR, SOAP

## Standard Values

In [None]:
# Dopant whose results are analysed; also the name of its data sub-directory.
simu_type = 'rh'
target_dir = "../test_data/230118_finalres_newlocalstruct_preclass/"
# Pd results live in the sibling 'pd' sub-directory.
pd_target_dir = target_dir+'pd'
target_dir = target_dir + simu_type
# presumably the total number of atoms in the nanoparticle — TODO confirm
n_particles = 1577
# Number of dopant atoms; used below to normalise counts and as the number of tracked atoms.
n_rhod = 15


# Descriptor and kernel settings, forwarded unchanged to the classifiers below.
r_cut=5.2
n_max=4
l_max=3
sigma=1.
gamma_kernel=1.

In [None]:
### Renaming Dict for better labels:
# Maps raw category names to human-readable labels for legends and colourbars.
# Categories without an entry keep their raw name.
pretty_label = True
# The dict is always defined (empty when pretty labels are switched off) so
# that later unconditional lookups such as `prettylabel_dict.get(...)` cannot
# fail with a NameError.
prettylabel_dict = {
    "12": "12 neighbours",
    "11": "11 neighbours",
    "3_111_ad_atom": "(111) adatom",
    "4_111_ad_atom_pair": "(111) adatom pair",
    "4_100_ad_atom": "(100) adatom",
    "5_211_ad_atom": "(211) adatom",
    "5_110_ad_atom": "(110) adatom",
    "5_111_terrace": "(111) terrace",
    "6_100-110_interface": "(100)-(110)\ninterface",
    "6_100_terrace": "(100) terrace",
    "7_110": "(110)",
    "7_211": "(211)",
    "8_100": "(100)",
    "8_111_vacant_site": "(111) vacancy",
    "9_111": "(111)",
    "9_invalid": "Subsurface 1",
    "9_invalid_2": "Subsurface 2",
} if pretty_label else {}

## Subfigure: Single Snapshot of Nanoparticle

Analyse the sites available on a nanoparticle containing only Cu.

### Define Target

The target is a single nanoparticle stored in a `.lammpstrj` file.
It contains only copper in an optimised configuration.

In [None]:
# Directory and trajectory file holding the single nanoparticle snapshot.
only_cu_dir = target_dir + "/cunanoparticle"
only_cu_path = only_cu_dir + "/cusingle.lammpstrj"

Build a classifier using previously set values.
This builds the SOAP dictionary from `../src/localstructures_final_mc` to later use for classifying sites.

In [None]:
# Classifier based on the reference local structures stored on disk.
rh_classifier = sort_neigh.NeighbourClassifier(
    local_structures_path=os.path.abspath("../src/localstructures_final_mc"),
    non_class_max=14
)
# Load the identifiers using the descriptor settings defined above.
rh_classifier.load_identifiers(
    r_cut=r_cut, n_max=n_max, l_max=l_max, 
    sigma=sigma, gamma_kernel=gamma_kernel,
)


### Run Analysis via Classifier
Use the surface classifier to determine what sites are identified at which position.

In [None]:
# Read the snapshot and keep the atom positions for the plots below.
full_particle = ase_read(only_cu_path)
at_pos = full_particle.get_positions()

In [None]:
# NOTE: `view` is not used in the visible part of this notebook.
from ase.visualize import view
# Classification mode handed to the classifier for every atom.
mode="class_all"

# Define a neighbor list to get neighborhood of each copper atom
cut_off = natural_cutoffs(full_particle, mult=.98)# mult=0.98)
neighbour_list = NeighborList(cut_off, bothways=True, self_interaction=False)
neighbour_list.update(full_particle)

# Empty arrays for storing the categorisation of each copper atom
# cu_cat_counter counts atoms per class id; categories and neighbours hold the
# class id and the n_neigh value returned by the classifier for every atom.
cu_cat_counter = np.zeros(shape=(rh_classifier.n_classes), dtype=np.int32)
categories = np.zeros((len(full_particle),), dtype=np.int32)
neighbours = np.zeros((len(full_particle),), dtype=np.int32)

# Go through every single copper atom 
# ind_soaps keeps one descriptor row per atom; it is reused further below as
# input for the dimensionality reduction.
ind_soaps = np.zeros((len(full_particle), rh_classifier.descr.get_number_of_features()))
for index in range(len(full_particle)):
    # Build a small particle of the atom and its neighbours, with the atom itself at position 0.
    neighbour_indices, trash = neighbour_list.get_neighbors(index)
    neighbour_indices = np.append(np.array([index]), neighbour_indices, axis=0)
    neighbour_particle = full_particle[neighbour_indices]
    
    # Make center atom Rh
    neighbour_particle.symbols[:] = 'Cu'
    neighbour_particle.symbols[0] = 'Rh'
    
    # Create SOAPs of environment and use classifier
    ind_soaps[index] = rh_classifier.descr.create(neighbour_particle, centers=[0])
    n_neigh, class_id = rh_classifier.classify(neighbour_particle, mode=mode, ensure_position=False)

    # Store the result for this atom.
    cu_cat_counter[class_id] += 1
    neighbours[index] = int(n_neigh)
    categories[index] = int(class_id)

In [None]:
# Print how many atoms were assigned to every category.
for cat_index, cat_count in enumerate(cu_cat_counter):
    print(rh_classifier.id_to_cat(cat_index), ': %u' % cat_count)

### Unsupervised Machine Learning

Use unsupervised ML to perform the same task and compare results afterwards.
Uses another classifier which instead of building a dictionary builds a machine learned classifier.

In [None]:
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA, KernelPCA
from sklearn.cluster import KMeans, DBSCAN, Birch
from dscribe.descriptors import SOAP

# Train the unsupervised classifier on the per-atom descriptors, using a
# two-component PCA for dimensionality reduction and Birch with 8 clusters.
ml_classifier = sort_neigh.USMLClassifier()
n_clust = ml_classifier._train_on_data(
    ind_soaps,
    dim_red=PCA(n_components=2), clusterer=Birch(n_clusters=8)
)
# Reuse the descriptor of the dictionary-based classifier.
ml_classifier.descr = rh_classifier.descr

# soaps = ml_classifier.descr.create(full_particle)
# Project the per-atom descriptors with the classifier's dimensionality reduction.
reduced = ml_classifier.dim_red.transform(ind_soaps)

### Plotting
Create view of nanoparticle and reduced mapping to make sure everything classified correctly

In [None]:
def found_labels(categories, labels:list=None):
    """From a vector of int categories, find the labels that are actually in it.

    Args:
        categories (np.ndarray): Vector of integers, each one an index into labels.
        labels (list, optional): Labels for each integer. If None is given, simply uses the numbers itself. Defaults to None.

    Returns:
        list: Labels in labels that were found within categories, in the order of labels.
            Empty if categories is empty.
    """
    # Collect the distinct category values once instead of scanning the whole
    # array again for every single label.
    present = set(np.unique(categories).tolist())
    if not present:
        # Nothing to report; this also avoids np.max failing on an empty array.
        return []

    if labels is None:
        labels = np.arange(0, np.max(categories)+1).tolist()

    return [label for ii_label, label in enumerate(labels) if ii_label in present]


In [None]:
%matplotlib inline
# Reset matplotlib to its defaults so settings from other cells do not leak in.
plt.rcParams.update(plt.rcParamsDefault)

# Selection of atoms to draw. Ellipsis keeps every atom; a boolean mask such
# as `neighbours < 17` can be used instead.
cond = np.s_[...]

# Take classifications and use as colors for plotting on scatterplot
colors = categories[cond].copy()

# Names of all classes known to the classifier, reduced to those that occur.
c_labels = [rh_classifier.id_to_cat(cat_id) for cat_id in range(rh_classifier.n_classes)]
c_labels = found_labels(colors, c_labels)

# Move 11 and 12 neighbours to back (they are expected to be the first two labels)
order_by_neighbors = True
if order_by_neighbors:
    c_labels = c_labels[2:] + c_labels[:2]

# Create a categorised colormap for found categories and save it in a dictionary.
# plt.get_cmap is used because matplotlib.cm.get_cmap was removed from Matplotlib.
cat_clabels = c_labels.copy()
cat_cmap = plt.get_cmap('tab10', len(c_labels))
cat_norm = mcolors.Normalize(vmin=-0.5, vmax=len(c_labels)-0.5)
label_color_dict = {c_label: cat_cmap(cat_norm(ii_label)) for ii_label, c_label in enumerate(c_labels)}
label_color_dict['Non-surface'] = label_color_dict['12']

plot_colors = [label_color_dict[rh_classifier.id_to_cat(col)] for col in colors]

# Plot particle
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(projection='3d')

sc = ax.scatter(
    at_pos[cond, 0], at_pos[cond, 1], at_pos[cond, 2], c=plot_colors, alpha=1,
    s=800, edgecolors="k"
)

# Tick labels of the colourbar, prettified where a replacement is known.
if pretty_label:
    plot_labels = [prettylabel_dict.get(c_label, c_label) for c_label in c_labels]
else:
    plot_labels = c_labels

ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)

# The mappable is not attached to any Axes, so the Axes to take space from
# has to be named explicitly.
bar = fig.colorbar(cm.ScalarMappable(norm=cat_norm, cmap=cat_cmap), ax=ax)
bar.set_ticks(np.arange(len(c_labels)))
bar.set_ticklabels(plot_labels)
ax.set_axis_off()

plt.tight_layout()
# fig.savefig("threed_particle.pdf", format='pdf')
plt.show()

## Subfigure: Nanoparticle on PCA map

Add the nanoparticle onto a PCA map built from the SOAPs found on the nanoparticle.

In [None]:
# Descriptors and names of the classifier's dictionary entries.
soaps_from_classifier, labels = rh_classifier.get_soaps()
# Labels drawn next to the dictionary points. Fall back to the raw names when
# pretty labels are disabled, so that p_labels is always defined.
if pretty_label:
    p_labels = [prettylabel_dict.get(blabel, blabel) for blabel in labels]
else:
    p_labels = list(labels)
print(labels)
print(soaps_from_classifier.shape)
# Project the dictionary descriptors with the same dimensionality reduction as the atoms.
soap_prediction = ml_classifier.dim_red.transform(soaps_from_classifier)

In [None]:
# Reset matplotlib to its defaults before drawing.
plt.rcParams.update(plt.rcParamsDefault)

# NOTE: compare_soaps is not passed to base_pca_particle by the calls in this
# notebook; the function falls back to its own default of True.
compare_soaps = True
fontsize = 12

def base_pca_particle(fig, ax,
    particle_positions, particle_colors, # Particle and its colors
    pca_points, pca_colors, 
    pca_dict_vals, dict_labels, dict_prettylabels,
    label_color_dict, 
    ms, fontsize, **kwargs
):
    """Plotting function for creating a pca map of soaps for a particle.

    Args:
        fig (plt.Figure): Figure for plotting.
        ax (plt.axis): Axis for plotting.
        particle_positions (np.ndarray): (n_atoms, 3) array of atom positions
        particle_colors (np.ndarray): (n_atoms,) array of colors for plotting 3d particle
        pca_points (np.ndarray): (n_atoms, 2) PCA of SOAPs of atoms
        pca_colors (np.ndarray): (n_atoms,) colors on pca map of every atom
        pca_dict_vals (np.ndarray): PCA of SOAPs of dictionary
        dict_labels (list): labels of dictionary
        dict_prettylabels (list): prettified labels of dictionary
        label_color_dict (dict): color dictionary for dictionary values
        ms (int or float): markersize on PCA scatterplot
        fontsize (int or float): font size in figure
        **kwargs: Optional settings:
            compare_soaps (bool): draw and label the dictionary points. Defaults to True.
            ms_mult (float): marker size multiplier for dictionary points. Defaults to 1.3.
            expand_text, expand_points, force_text, force_points, arrowprops:
                forwarded to adjust_text.
            p_bounds (list): bounds of the inset axis. Defaults to [0.05, 0.372, 0.5, 0.5].
            p_alpha: alpha of the atoms in the inset. Defaults to 1.
            p_ssize: marker size(s) of the atoms in the inset. Defaults to 200.
            p_xlimmult (float): relative change of the inset axis span; negative values
                zoom in. Defaults to -0.15.

    Returns:
        tuple (plt.Figure, plt.axis, plt.axis): (figure, pca axis, inset 3d scatterplot axis)
    """
    # Plot pca scatterplot
    ax.scatter(
        pca_points[:, 0], pca_points[:, 1],
        c=pca_colors, s=ms, rasterized=True
    )

    # If soaps from classifier dictionary should be drawn, do that here
    if kwargs.get("compare_soaps", True):
        # Dictionary entries without a known colour are drawn in grey.
        plot_cols = [label_color_dict.get(label, "grey") for label in dict_labels]

        # Draw on pca scatterplot
        ax.scatter(pca_dict_vals[:, 0], pca_dict_vals[:, 1], c=plot_cols, s=ms*kwargs.get('ms_mult', 1.3), edgecolors='k')

        # Add labels to scatterplot. Iterate over the dictionary passed in as
        # an argument rather than over a global variable, so that every call
        # labels its own dictionary points.
        texts = [
            ax.text(
                pca_dict_vals[ii_label, 0], pca_dict_vals[ii_label, 1],
                dict_prettylabels[ii_label], ha='right', fontsize=fontsize
            )
            for ii_label in range(len(dict_labels))
        ]

        # Adjust labels using adjustText library, controllable via kwargs
        adjust_text(
            texts, pca_dict_vals[:, 0], pca_dict_vals[:, 1], ax=ax,
            expand_text=kwargs.get('expand_text', (1.11, 1.25)),
            expand_points=kwargs.get('expand_points', (1.4, 1.7)),
            force_text=kwargs.get('force_text', (0.01, 0.25)),
            force_points=kwargs.get('force_points', (1.21, 1.21)),
            only_move={'text':'xy'},
            arrowprops=kwargs.get('arrowprops', dict(arrowstyle='-', lw=2, color='k', alpha=.5))
        )

    ax.tick_params(axis='x', labelsize=fontsize)
    ax.tick_params(axis='y', labelsize=fontsize)

    # 3d view of the particle as an inset of the pca map
    axins = ax.inset_axes(bounds=kwargs.get('p_bounds', [0.05, 0.372, 0.5, 0.5]), projection='3d')

    axins.scatter(
        particle_positions[:, 0], particle_positions[:, 1], particle_positions[:, 2], c=particle_colors, alpha=kwargs.get("p_alpha", 1),
        s=kwargs.get('p_ssize', 200), edgecolors="k", rasterized=True
    )
    axins.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
    axins.set_axis_off()

    # Use the rescaled x limits on all three axes.
    x_lims = axins.get_xlim()
    x_span = x_lims[1] - x_lims[0]
    x_limmult = kwargs.get("p_xlimmult", -0.15)
    x_lims = (x_lims[0]-(x_span*x_limmult/2), x_lims[1]+(x_span*x_limmult/2))
    axins.set_xlim(x_lims)
    axins.set_ylim(x_lims)
    axins.set_zlim(x_lims)

    return fig, ax, axins

# Stand-alone version of the PCA map with the classified particle as inset.
fig, ax = plt.subplots(1, 1, figsize=(9, 6))

fig, ax, axin = base_pca_particle(fig, ax,
    particle_positions=at_pos, particle_colors=plot_colors,
    pca_points=reduced, pca_colors=plot_colors,
    pca_dict_vals=soap_prediction, dict_labels=labels, dict_prettylabels=p_labels,
    label_color_dict=label_color_dict, ms=70, fontsize=fontsize
)

# The mappable is not attached to any Axes, so the Axes to take space from
# has to be named explicitly.
bar = fig.colorbar(cm.ScalarMappable(norm=cat_norm, cmap=cat_cmap), ax=ax)
bar.set_ticks(np.arange(len(c_labels)))
bar.set_ticklabels(plot_labels, fontsize=fontsize)

fig.savefig("fig1b_classification_pca.pdf", format='pdf', dpi=300, bbox_inches='tight')
plt.show()

## Subfigure: Development of Rh through Trajectory
After defining all the available positions in the nanoparticle, analyse where the Rh sit in each timestep.

### Run Evaluation on Files
Pre-existing evaluation is loaded from .txt files.
Use the classifier in `sort_neigh` for creating and saving these .txt files.

In [None]:
%matplotlib inline
plt.rcParams.update(plt.rcParamsDefault)

def _is_bulk_category(cat):
    """Return True if cat is a purely numeric category name with 10 <= value < 99."""
    try:
        return 10 <= int(cat) < 99
    except (TypeError, ValueError):
        # Named categories such as "9_111" are never counted as bulk.
        return False


def load_all(target_dir):
    """Load all trajectory classifications in a directory structure built like:

    /target_dir
        /mc
            x.lammpstrj # Trajectory (not strictly necessary, can also be blank files.)
            x_soap_sorted_counts.txt # Analysis for trajectory
            y.lammpstrj
            y_soap_sorted_counts.txt
            z.lammpstrj
            z_soap_sorted_counts.txt
        /mcmd
            x.lammpstrj
            x_soap_sorted_counts.txt
            y.lammpstrj
            y_soap_sorted_counts.txt
            z.lammpstrj
            z_soap_sorted_counts.txt

    Every result is stored under the key "<folder>_<file>", e.g. "mcmd_x".
    Purely numeric categories from 10 up to (excluding) 99 are summed into an
    additional last category called 'Non-surface' and set to zero themselves.

    Args:
        target_dir (string, path): top directory for saved files

    Returns:
        tuple (dict of dicts, int): (dictionary containing one entry per trajectory with "sorted_counts", "timesteps" and "sorted_cats", number of categories)
    """
    target_folders = [
        target_dir+"/mc",
        target_dir+"/mcmd"
    ]

    results_dict = {}
    n_cats = 0

    bulkify = True # Sum up 11 and 12 neighbours to "bulk"

    for target_folder in target_folders:
        for target_file in glob(target_folder+"/*.lammpstrj"):
            target_file = os.path.abspath(target_file)
            file_dir = os.path.dirname(target_file)
            only_file = os.path.basename(target_file).split(".")[0]
            save_txt_path = os.path.join(file_dir, only_file+"_soap_sorted_counts.txt")

            # Name of the containing folder, independent of the path separator.
            dir_name = os.path.basename(file_dir)
            cur_key = '_'.join([dir_name, only_file])

            sorted_counts, timesteps, sorted_cats = sort_neigh.NeighbourSort.load_sort_cat(save_txt_path)

            if bulkify:
                # Integer dtype keeps the index array valid even when no
                # bulk category is present.
                bulk_args = np.array(
                    [ii_scat for ii_scat, scat in enumerate(sorted_cats) if _is_bulk_category(scat)],
                    dtype=np.intp
                )
                bulk_slicer = np.s_[:, bulk_args]
                # Append the summed bulk counts as a new last column ...
                sorted_counts = np.append(
                    sorted_counts, np.sum(
                        sorted_counts[bulk_slicer], axis=-1
                    )[:, np.newaxis], axis=-1
                )
                sorted_cats.append('Non-surface')
                # ... and clear the columns that were summed up.
                sorted_counts[bulk_slicer] = 0

            results_dict[cur_key] = {
                "sorted_counts": sorted_counts,
                "timesteps": timesteps,
                "sorted_cats": sorted_cats,
            }
            n_cats = max(n_cats, len(sorted_cats))

    return results_dict, n_cats

# Load the saved classification results from both result directories.
results_dict, n_cats = load_all(target_dir=target_dir)
pd_results_dict, pd_ncats = load_all(target_dir=pd_target_dir)

Reshape Data for better handling

In [None]:
def reshape_res_dict(results_dict):
    """Turn the loaded results into percentage tables with one row per entry.

    Relies on the globally defined `n_cats` and `n_rhod`. Entries are visited
    in sorted key order and written to row (position % 3). Entries whose key
    contains "mcmd" fill the mcmd tables, all others fill the md tables.

    Args:
        results_dict (dict of dicts): Entries holding "sorted_counts" and "sorted_cats".

    Returns:
        tuple: (md_counts, md_errs, mcmd_counts, mcmd_errs, sorted_cats), where
            sorted_cats is taken from the first entry in sorted key order.
    """
    block_size = 1000
    table_shape = (3 , n_cats)
    md_counts = np.zeros(table_shape, dtype=np.float64)
    md_errs = np.zeros(table_shape, dtype=np.float64)
    mcmd_counts = np.zeros(table_shape, dtype=np.float64)
    mcmd_errs = np.zeros(table_shape, dtype=np.float64)

    for position, (key, subdict) in enumerate(sorted(results_dict.items())):
        if position == 0: # For first dictionary
            sorted_cats = subdict["sorted_cats"]

        # Pick the pair of tables this entry belongs to.
        if "mcmd" in key:
            count_table, err_table = mcmd_counts, mcmd_errs
        else:
            count_table, err_table = md_counts, md_errs

        traj_counts = subdict["sorted_counts"]
        row = position % 3
        # Number of rows in the counts times number of dopant atoms.
        n_total = traj_counts.shape[0] * n_rhod

        count_table[row, :] = np.sum(traj_counts, axis=0)
        count_table[row, :] = count_table[row, :] / n_total
        count_table[row, :] *= 100

        err_table[row, :] = sort_neigh.NeighbourSort.block_average(traj_counts, block_size=block_size)
        err_table[row, :] = err_table[row, :] / n_total
        err_table[row, :] *= 100

    return md_counts, md_errs, mcmd_counts, mcmd_errs, sorted_cats

# Percentage tables of the first result set, plus masks that are True for
# every category with a non-zero entry in at least one row.
md_counts, md_errs, mcmd_counts, mcmd_errs, sorted_cats = reshape_res_dict(results_dict=results_dict)
md_notzero = np.any(md_counts != 0, axis=0)
mcmd_notzero = np.any(mcmd_counts != 0, axis=0)

# The same for the Pd result set.
pd_md_counts, pd_md_errs, pd_mcmd_counts, pd_mcmd_errs, pd_sorted_cats = reshape_res_dict(results_dict=pd_results_dict)
pd_md_notzero = np.any(pd_md_counts != 0, axis=0)
pd_mcmd_notzero = np.any(pd_mcmd_counts != 0, axis=0)

In [None]:
plt.rcParams.update(plt.rcParamsDefault)

from matplotlib.lines import Line2D

def draw_rhpd_twoax(axes, rh_lines, rh_errs, pd_lines, pd_errs, all_not_zero, label_color_dict, **kwargs):
    """Make plot containing percentage of sites as lineplots.
    This function is not cleanly written and relies on globally defined variables
    (`sorted_cats`, `pretty_label` and `prettylabel_dict`).

    Args:
        axes (list of plt.axis): Figure axes.
        rh_lines (np.ndarray): Lines for left plot.
        rh_errs (np.ndarray): Errors for left plot.
        pd_lines (np.ndarray): Lines for right plot.
        pd_errs (np.ndarray): Errors for right plot.
        all_not_zero (np.ndarray of bool): Must be shaped (rh_lines.shape[1],), with True wherever any of the lines are non-zero
        label_color_dict (dict): Color dict for each category in lines. Categories without an entry are drawn in gray.
        **kwargs: Optional settings:
            fontsize (int or float): base font size. Defaults to 12.
            ms (int or float): marker size. Defaults to 10.
            lw (int or float): line width. Defaults to 3.

    Returns:
        tuple (list of plt.axis, list of legend elements)
    """
    temperatures = [400, 500, 600]
    fontsize = kwargs.get('fontsize', 12)
    ms = kwargs.get('ms', 10)
    lw = kwargs.get('lw', 3)

    # Only plot entries that are non-zero in some line.
    nz_slicer = np.s_[:, all_not_zero]
    rh_lines = rh_lines[nz_slicer]
    rh_errs = rh_errs[nz_slicer]
    pd_lines = pd_lines[nz_slicer]
    pd_errs = pd_errs[nz_slicer]

    total_nonzero = int(np.count_nonzero(all_not_zero))
    x_range = np.arange(3)

    # Categories that are drawn, in the order of the global sorted_cats.
    draw_cats = [cat for ii_cat, cat in enumerate(sorted_cats) if all_not_zero[ii_cat]]
    draw_colors = [label_color_dict.get(draw_cat, 'gray') for draw_cat in draw_cats]

    # Get labels from pretty label dict
    if pretty_label:
        plot_labels = [prettylabel_dict.get(nz_cat, nz_cat) for nz_cat in draw_cats]
    else:
        plot_labels = draw_cats

    legend_elements = []
    for ii_line in range(total_nonzero):
        # Skip points that are exactly zero.
        nz_tf = rh_lines[:, ii_line] != 0
        axes[0].errorbar(
            x_range[nz_tf], rh_lines[:, ii_line][nz_tf], yerr=rh_errs[:, ii_line][nz_tf],
            color=draw_colors[ii_line], marker='o',
            markersize=ms, linewidth=lw,
            label=plot_labels[ii_line]
        )
        nz_tf = pd_lines[:, ii_line] != 0
        axes[1].errorbar(
            x_range[nz_tf], pd_lines[:, ii_line][nz_tf], yerr=pd_errs[:, ii_line][nz_tf],
            color=draw_colors[ii_line], marker='o', markersize=ms, linewidth=lw,
            label=plot_labels[ii_line]
        )
        # Marker-only handle for building a legend outside of this function.
        legend_elements.append(Line2D(
            [0], [0],
            linestyle="",
            color=draw_colors[ii_line], fillstyle='left',
            marker='o', markersize=ms,
            lw=lw, label=plot_labels[ii_line]))

    for axis in axes:
        axis.set_yscale('log')

        axis.set_xticks(x_range)
        axis.set_xticklabels(temperatures, fontsize=fontsize*0.8)

    return axes, legend_elements

In [None]:
# Work on copies so the tables returned by reshape_res_dict stay untouched.
draw_rhcounts = mcmd_counts.copy()
draw_rherrs = mcmd_errs.copy()
draw_pdcounts = pd_mcmd_counts.copy()
draw_pderrs = pd_mcmd_errs.copy()

# A category is drawn if it is non-zero in either of the two data sets.
mcmd_all_not_zero = np.logical_or(mcmd_notzero, pd_mcmd_notzero)
total_nonzero = np.sum(mcmd_all_not_zero.astype(np.int8))
x_range = np.arange(3)

# add categories to bulk that are not initially in there
bulk_cats = ["9_invalid", "9_invalid_2"]
for ii_cat, cat in enumerate(sorted_cats):
    if cat in bulk_cats:
        print("Bulkified: ", cat)
        # Add the category to the last column and clear its own column.
        draw_rhcounts[:, -1] += draw_rhcounts[:, ii_cat]
        draw_rherrs[:, -1] += draw_rherrs[:, ii_cat]
        draw_pdcounts[:, -1] += draw_pdcounts[:, ii_cat]
        draw_pderrs[:, -1] += draw_pderrs[:, ii_cat]

        draw_rhcounts[:, ii_cat] = 0
        draw_rherrs[:, ii_cat] = 0
        draw_pdcounts[:, ii_cat] = 0
        draw_pderrs[:, ii_cat] = 0

        mcmd_all_not_zero[ii_cat] = False

# Categories that remain to be drawn.
draw_cats = [cat for ii_cat, cat in enumerate(sorted_cats) if mcmd_all_not_zero[ii_cat]]

# One colour per drawn category.
# plt.get_cmap is used because matplotlib.cm.get_cmap was removed from Matplotlib.
data_cat_cmap = plt.get_cmap('tab10', len(draw_cats))
data_cat_norm = mcolors.Normalize(vmin=-0.5, vmax=len(draw_cats)-0.5)
data_label_color_dict = {draw_cat: data_cat_cmap(data_cat_norm(ii_label)) for ii_label, draw_cat in enumerate(draw_cats)}
# Everything that was merged into 'Non-surface' shares its colour.
data_label_color_dict['12'] = data_label_color_dict['Non-surface']
data_label_color_dict['11'] = data_label_color_dict['Non-surface']
data_label_color_dict['10'] = data_label_color_dict['Non-surface']
for bcat in bulk_cats:
    data_label_color_dict[bcat] = data_label_color_dict['Non-surface']

## Subfigure: PCA map of Trajectory

This builds a PCA projection of all SOAPs from a 600K trajectory of Pd as a dopant.
On it, a snapshot of the nanoparticle with classified Pd sites is to be plotted, which is imported here.
For this cell to work, at least one trajectory must be available; currently this is `test_data/230118_finalres_newlocalstruct_preclass/pd/mcmd/600.lammpstrj`.
To recompute the classification and PCA projection, set the `new_soaps` flag to `True`; otherwise the previously saved `example_catids.npy` and `example_pca.npy` files are loaded.

In [None]:
class VerboseSorter(sort_neigh.NeighbourSort):
    """Quick child class of sorter, which classifies but also returns a pca projection of each analysed site.
    The only element that changes is copy pasted except for also projecting onto a given dimensionality reduction.
    """
    def get_cats_dimred(self, dim_red, cutoff_mult=0.9, last_n=15, **kwargs):
        """Classify the last `last_n` atoms in every timestep and project their descriptors.

        Args:
            dim_red: Object with a `transform` method; the first two components of its output are stored.
            cutoff_mult (float, optional): Multiplier passed to `natural_cutoffs`. Defaults to 0.9.
            last_n (int, optional): Number of atoms, counted from the end of the atom list, that are analysed. Defaults to 15.
            **kwargs: Forwarded to `self.classifier.classify`.

        Returns:
            tuple (np.ndarray, np.ndarray): (class ids shaped (timesteps*last_n,), projections shaped (timesteps*last_n, 2)), ordered by timestep and then by atom index.
        """
        # Flat result arrays, entry (step*last_n)+i belongs to the i-th tracked atom in timestep step.
        cat_ids = np.zeros((self.timesteps*last_n,), dtype=np.int8)
        pcas = np.zeros((self.timesteps*last_n, 2), dtype=np.float32)

        # Indices of the last `last_n` atoms in ascending order.
        init_particle = self.particle_trajectory[0]
        particle_len = len(init_particle)
        particle_range = (particle_len - np.arange(last_n)[::-1])-1

        # Cutoffs are taken from the first frame, the neighbour list is updated for every frame below.
        cut_off = natural_cutoffs(init_particle, mult=cutoff_mult)
        neighbour_list = NeighborList(cut_off, bothways=True, self_interaction=False)
        neighbour_list.update(init_particle)

        for step in self.progressbar(range(self.timesteps), "At Timestep:", size=40):
            cur_particle = self.particle_trajectory[step]
            neighbour_list.update(cur_particle)

            for ii_ind, index in enumerate(particle_range):
                # Small particle of the tracked atom (at position 0) and its neighbours.
                neighbour_indices, dists = neighbour_list.get_neighbors(index)
                neighbour_indices = np.append(np.array([index]), neighbour_indices, axis=0)

                # Relabel the environment: centre atom Rh, all neighbours Cu.
                neighbour_particle = cur_particle[neighbour_indices]
                neighbour_particle.symbols[:] = 'Cu'
                neighbour_particle.symbols[0] = 'Rh'
                n_neighbours = len(neighbour_particle) - 1
                
                n_neigh, class_id = self.classifier.classify(neighbour_particle, **kwargs)
                soaps = self.classifier.descr.create(neighbour_particle, centers=[0])

                cat_ids[(step*last_n)+ii_ind] = class_id
                pcas[(step*last_n)+ii_ind, :] = dim_red.transform(soaps)[..., :2]

        return cat_ids, pcas

# Set to True to recompute classification and projection; otherwise the saved .npy files are loaded.
new_soaps = False
new_sorter = VerboseSorter(
    local_structures_path=os.path.abspath("../src/localstructures_final_mcmd_pd"),
    non_class_max=14, r_cut=r_cut, n_max=n_max, l_max=l_max,
    sigma=sigma, gamma_kernel=gamma_kernel,
)
traj_path = "../test_data/230118_finalres_newlocalstruct_preclass/pd/mcmd/600.lammpstrj"
new_sorter.load_particle(traj_path)
if new_soaps:
    # Analyse the last n_rhod atoms of every timestep and save the results next to the trajectory.
    example_catids, example_pca = new_sorter.get_cats_dimred(dim_red=ml_classifier.dim_red, last_n=n_rhod, mode='pre_group')
    np.save(os.path.join(os.path.dirname(traj_path), "example_catids.npy"), example_catids)
    np.save(os.path.join(os.path.dirname(traj_path), "example_pca.npy"), example_pca)
else:
    # Load the results saved next to the trajectory by an earlier run.
    example_catids, example_pca = np.load(os.path.join(os.path.dirname(traj_path), "example_catids.npy")), np.load(os.path.join(os.path.dirname(traj_path), "example_pca.npy"))



In [None]:
# Colour of every analysed site, looked up through its category name.
example_colors = [data_label_color_dict[new_sorter.classifier.id_to_cat(catid)] for catid in example_catids]

# Timestep of the snapshot that is drawn as the 3d inset.
example_ts = 3412
example_pos = new_sorter.particle_trajectory[example_ts].get_positions()
cu_col = 'peru'

# Go through particle, if a Rh is found get classification from previous classification
# NOTE(review): this checks for the symbol 'Rh' although the trajectory is
# described as Pd doped — confirm which symbol the loaded trajectory uses.
example_pcolors, example_sizes = [], []
rh_counter = 0
for symb in new_sorter.particle_trajectory[example_ts].get_chemical_symbols():
    if symb == 'Rh':
        # Entry (timestep*n_rhod)+k holds the k-th tracked atom of that timestep.
        example_pcolors.append(example_colors[(example_ts*n_rhod)+rh_counter])
        example_sizes.append(125)
        rh_counter += 1
    else:
        # All other atoms are drawn as small dots in a single colour.
        example_pcolors.append(mcolors.to_rgba(cu_col, alpha=1))
        example_sizes.append(0.5)

example_soaps_from_classifier, example_labels = new_sorter.classifier.get_soaps()
# Fall back to the raw names when pretty labels are disabled, so that
# example_plabels is always defined.
if pretty_label:
    example_plabels = [prettylabel_dict.get(blabel, blabel) for blabel in example_labels]
else:
    example_plabels = list(example_labels)
example_soap_prediction = ml_classifier.dim_red.transform(example_soaps_from_classifier)

## Plot Figure

Combine all the subfigures tested in the previous parts into one.

In [None]:
# Start from matplotlib's defaults, then switch to serif fonts rendered via LaTeX.
plt.rcParams.update(plt.rcParamsDefault)
if True:
    fontsize = 12
    font = {'family': 'serif', 'size': fontsize}
    plt.rc('font', **font)
    plt.rc('text', usetex=True)
    plt.rc('text.latex', preamble=r'\usepackage{amsmath}')

fig = plt.figure(figsize=(11, 6))
fontsize = 12
ms = 7

# All layout values below are fractions of the figure size, as used by fig.add_axes.
marg = 0.02 # margin between figures
soap_vmarg = 0.05 # margin between soap plots
center_marg = 0.04 # margin between left and right side

rsw = 0.6 # width of right side
lsw = 1 - rsw - center_marg # width of left side

lmarg = 0.025 # margin between logplots
lth = 0.55 # 0.6 # total height of logplots
ltw = (lsw - (3*marg))/3

# Left side: two logplots next to each other, followed by a narrow colorbar axis.
rate_axes = []

rate_axes.append(fig.add_axes((0, 0, ltw, lth))) # First logplot 
rate_axes.append(fig.add_axes((ltw+marg, 0, ltw, lth))) # Second logplot

# here is where the colorbar goes
cax = fig.add_axes(((ltw+marg)*2, lmarg, ltw/4, lth-lmarg*4))

# Right side: two PCA maps stacked on top of each other.
soap_axes = []
soap_axes.append(fig.add_axes((lsw+center_marg, ((1-soap_vmarg)/2)+soap_vmarg, rsw, (1-soap_vmarg)/2)))
soap_axes.append(fig.add_axes((lsw+center_marg, 0, rsw, (1-soap_vmarg)/2)))

# Draw the percentage lineplots into the two left axes.
plt_ax, legend_elements = draw_rhpd_twoax([rate_axes[0], rate_axes[1]], 
    rh_lines=draw_rhcounts, rh_errs=draw_rherrs, 
    pd_lines=draw_pdcounts, pd_errs=draw_pderrs, 
    all_not_zero=mcmd_all_not_zero, 
    label_color_dict=data_label_color_dict, 
    fontsize=fontsize, ms=ms
)
rate_axes[0] = plt_ax[0]
rate_axes[1] = plt_ax[1]
    
# Only the left plot carries y tick labels and the y axis label.
rate_axes[1].set_yticklabels([], visible=False)
rate_axes[0].set_title('Rhodium', fontsize=fontsize, visible=True)
rate_axes[1].set_title('Palladium', fontsize=fontsize, visible=True)
rate_axes[0].set_ylabel(r'Percentage of Total Configurations (\%)', fontsize=fontsize, loc="bottom")
# rate_axes[0].yaxis.set_label_coords(-0.1,-0.3)
rate_axes[1].set_ylabel('')

# Same grid, y range and x label for both plots.
for rax in rate_axes:
    rax.grid(axis='y', which='major', visible=True, color='k', alpha=0.8)
    rax.grid(axis='y', which='minor', visible=False) # , alpha=0.6, linestyle='--')
    rax.set_ylim([1e-3, 1.15e2])
    rax.set_xlabel('Temperature (K)', fontsize=fontsize)

# Tick labels of the colorbar, prettified where a replacement is known.
pretty_cats = []
for draw_cat in draw_cats:
    pretty_cats.append(prettylabel_dict.get(draw_cat, draw_cat))

# Categorical colorbar in its own axis, one tick per drawn category.
bar = fig.colorbar(cm.ScalarMappable(norm=data_cat_norm, cmap=data_cat_cmap), cax=cax)
bar.set_ticks(np.arange(len(draw_cats)))
bar.set_ticklabels(pretty_cats, fontsize=fontsize)

# Colours of the atoms of the single snapshot, taken from the colour dict of the rate plots.
red_colors = [data_label_color_dict[rh_classifier.id_to_cat(col)] for col in categories]
# Upper panel: PCA map and 3d view of the single snapshot.
fig, soap_axes[0], axin = base_pca_particle(fig, soap_axes[0],
    particle_positions=at_pos, particle_colors=red_colors,
    pca_points=reduced, pca_colors=red_colors,
    pca_dict_vals=soap_prediction, dict_labels=labels, dict_prettylabels=p_labels,
    label_color_dict=data_label_color_dict, 
    ms=50, fontsize=fontsize-2,
    p_ssize=130, p_xlimmult=-0.25,
    p_bounds=[0.0, 0.472, 0.5, 0.5],
    expand_text=(0.5, 0.9), expand_points=(1.6, 1.6), force_text=(0.6, 0.6), force_points=(0.4, 0.4)
)
# Lower panel: PCA map of the analysed trajectory sites and 3d view of the example snapshot.
fig, soap_axes[1], axin_1 = base_pca_particle(fig, soap_axes[1],
    particle_positions=example_pos, particle_colors=example_pcolors, # p_alpha=example_palphas,
    pca_points=example_pca, pca_colors=example_colors,
    pca_dict_vals=example_soap_prediction, dict_labels=example_labels, dict_prettylabels=example_plabels,
    label_color_dict=data_label_color_dict,
    ms=50, fontsize=fontsize-2,
    p_bounds=[0.0, 0.472, 0.5, 0.5],
    p_ssize=example_sizes, p_xlimmult=-0.25,
    expand_text=(0.8, 0.9), expand_points=(1.6, 1.6), force_text=(0.6, 0.6), force_points=(0.4, 0.4)
)

# soap_axes[0].set(yticklabels=[], xticklabels=[])
# Inward pointing ticks and a limited number of major ticks on both PCA maps.
soap_axes[0].tick_params(axis='both', which='major', direction='in', labelleft=True, labelbottom=True)
soap_axes[1].tick_params(axis='both', which='major', direction='in', labelleft=True, labelbottom=True)
soap_axes[0].xaxis.set_major_locator(ticker.MaxNLocator(4))
soap_axes[0].yaxis.set_major_locator(ticker.MaxNLocator(5))
soap_axes[1].xaxis.set_major_locator(ticker.MaxNLocator(4))
soap_axes[1].yaxis.set_major_locator(ticker.MaxNLocator(5))

fig.savefig("fig1_classification_pca_nosplit.pdf", format='pdf', dpi=300, bbox_inches='tight')
plt.show()

### Plot additional snapshot of Pd

A single non-classified snapshot of Pd trajectory, differentiating between copper and palladium sites.

In [None]:
# Colours for copper atoms (brown) and all other atoms (blue).
atom_brown = '#be7535ff'
atom_blue = '#197180ff'

fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(projection='3d')

# Every atom that is not Cu gets the second colour.
atom_symbols = new_sorter.particle_trajectory[example_ts].get_chemical_symbols()
atom_colors = [atom_brown if symbol=='Cu' else atom_blue for symbol in atom_symbols]

ax.scatter(
    example_pos[:, 0], example_pos[:, 1], example_pos[:, 2], c=atom_colors, alpha=1,
    s=2000, edgecolors="k", rasterized=True
)
ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
ax.set_axis_off()
# Shrink the x limits by 15 % of their span and use them on all three axes.
x_lims = ax.get_xlim()
x_span = x_lims[1] - x_lims[0]
x_limmult = -.15
x_lims = (x_lims[0]-(x_span*x_limmult/2), x_lims[1]+(x_span*x_limmult/2))
ax.set_xlim(x_lims)
ax.set_ylim(x_lims)
ax.set_zlim(x_lims)

fig.savefig("pd_snapshot_marked.pdf", format='pdf', dpi=300, bbox_inches='tight')

plt.show()

# SI Figures

In [None]:
# Work on copies so the original MD counts/errors stay untouched.
md_draw_rhcounts, md_draw_rherrs = md_counts.copy(), md_errs.copy()
md_draw_pdcounts, md_draw_pderrs = pd_md_counts.copy(), pd_md_errs.copy()

# A category counts as non-zero if it is non-zero for either the Rh or Pd data.
md_all_not_zero = np.logical_or(md_notzero, pd_md_notzero)
# NOTE(review): this total is taken before the bulk categories are cleared from
# the mask below — confirm that is the intended count.
md_total_nonzero = md_all_not_zero.astype(np.int8).sum()

# Fold categories that are not initially part of the bulk into the last column,
# zero their own column and drop them from the non-zero mask.
bulk_cats = ["9_invalid", "9_invalid_2"]
for ii_cat, cat in enumerate(sorted_cats):
    if cat not in bulk_cats:
        continue

    print("Bulkified: ", cat)
    for draw_values in (md_draw_rhcounts, md_draw_rherrs, md_draw_pdcounts, md_draw_pderrs):
        # Errors are added linearly, exactly like the counts.
        draw_values[:, -1] += draw_values[:, ii_cat]
        draw_values[:, ii_cat] = 0

    md_all_not_zero[ii_cat] = False


In [None]:
# Reset matplotlib to its defaults, then switch to serif fonts rendered by LaTeX.
plt.rcParams.update(plt.rcParamsDefault)
fontsize = 12
font = dict(family='serif', size=fontsize)
plt.rc('font', **font)
plt.rc('text', usetex=True)
plt.rc('text.latex', preamble=r'\usepackage{amsmath}')

fig = plt.figure(figsize=(11, 3))

# Layout constants; all rectangles below are handed to fig.add_axes as
# (left, bottom, width, height).
marg = 0.05  # margin between figures
lsw = 1  # presumably the total width available to the columns; only used for ltw
lmarg = 0.025  # margin between logplots
lth = 1  # total height of logplots
ltw = (lsw - (5*marg))/5  # width of a single column
lm = 3  # ratio of upper to lower logplot
llh = (lth-lmarg)/(lm+1)  # height of lower logplot
luh = llh*lm  # height of upper logplot
col_step = ltw+marg  # horizontal offset between neighbouring columns

# Eight axes in four columns, stored as [top, bottom] per column:
# indices 0/1 and 2/3 left of the colorbar, 4/5 and 6/7 right of it.
rate_axes = []
for col in (0, 1):
    rate_axes.append(fig.add_axes((col_step*col, llh+lmarg, ltw, luh)))  # top
    rate_axes.append(fig.add_axes((col_step*col, 0, ltw, llh)))  # bottom

# here is where the colorbar goes (the slot of the third column)
cax = fig.add_axes((col_step*2-(0.5*marg), lmarg, ltw/2, lth-lmarg*4))

for col in (3, 4):
    rate_axes.append(fig.add_axes((col_step*col, llh+lmarg, ltw, luh)))  # top
    rate_axes.append(fig.add_axes((col_step*col, 0, ltw, llh)))  # bottom

# Every dataset is drawn twice: once onto the pair of top axes and once onto the
# pair of bottom axes, which later get different y-ranges. Each entry is
# (first axis index, second axis index, rh counts, rh errs, pd counts, pd errs, mask).
mcmd_data = (draw_rhcounts, draw_rherrs, draw_pdcounts, draw_pderrs, mcmd_all_not_zero)
md_data = (md_draw_rhcounts, md_draw_rherrs, md_draw_pdcounts, md_draw_pderrs, md_all_not_zero)
panel_specs = (
    (0, 2) + mcmd_data,
    (1, 3) + mcmd_data,
    (4, 6) + md_data,
    (5, 7) + md_data,
)
for first_ind, second_ind, panel_rh, panel_rherr, panel_pd, panel_pderr, panel_mask in panel_specs:
    plt_ax, legend_elements = draw_rhpd_twoax(
        [rate_axes[first_ind], rate_axes[second_ind]],
        rh_lines=panel_rh, rh_errs=panel_rherr,
        pd_lines=panel_pd, pd_errs=panel_pderr,
        all_not_zero=panel_mask,
        label_color_dict=data_label_color_dict,
        fontsize=fontsize, ms=ms
    )
    # Keep whatever axes the helper hands back in place of the ones passed in.
    rate_axes[first_ind] = plt_ax[0]
    rate_axes[second_ind] = plt_ax[1]

# Colorbar listing the drawn categories, one tick per category.
cat_mappable = cm.ScalarMappable(norm=data_cat_norm, cmap=data_cat_cmap)
bar = fig.colorbar(cat_mappable, cax=cax)
bar.set_ticks(np.arange(len(draw_cats)))
bar.set_ticklabels(pretty_cats, fontsize=fontsize)

### Beautification of logplots
# Upper axes: large-value range, no x label, title or x tick labels.
for upper_ax in rate_axes[0:8:2]:
    upper_ax.set_ylim([5e-1, 1.15e2])
    upper_ax.set_xlabel('', visible=False)
    upper_ax.set_title('', visible=False)
    upper_ax.set_xticklabels('', visible=False)

# Lower axes: small-value range and the shared x label.
for lower_ax in rate_axes[1:8:2]:
    lower_ax.set_ylim([1e-4, 1e-1])
    lower_ax.set_title('', visible=False)
    lower_ax.set_xlabel('Temperature (K)', fontsize=fontsize)

# Horizontal grid lines: solid for major ticks, dashed for minor ticks.
for rax in rate_axes:
    rax.grid(axis='y', which='major', visible=True, color='k', alpha=0.8)
    rax.grid(axis='y', which='minor', visible=True, alpha=0.6, linestyle='--')

# The two columns right of the colorbar carry their y ticks on the right side.
for right_ax in rate_axes[4:8]:
    right_ax.yaxis.tick_right()
    right_ax.yaxis.set_label_position("right")

# The two columns adjacent to the colorbar show no y tick labels.
for mid_ax in rate_axes[2:6]:
    mid_ax.set_yticklabels([], visible=False)

# Column titles go on the upper axes, after the blanket title hiding above.
# NOTE(review): the last two columns plot the md_* arrays but are titled "MC" —
# confirm the intended label.
column_titles = ('Rhodium MCMD', 'Palladium MCMD', 'Rhodium MC', 'Palladium MC')
for title_ax, column_title in zip(rate_axes[0:8:2], column_titles):
    title_ax.set_title(column_title, visible=True)

fig.savefig("sifig1_allclass.pdf", format='pdf', dpi=300, bbox_inches='tight')
plt.show()