# Element Similarity

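This notebook reproduces the element-similarity plots and tables shown in the paper by comparing several element embeddings (composition-based feature vectors, CBFVs).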

In [None]:
# Imports
import copy
import os
import warnings
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from AtomicEmbeddings.core import Embedding
from AtomicEmbeddings.core import Embedding, data_directory
from AtomicEmbeddings.plotter import multi_heatmap_plotter

# Global seaborn configuration: enlarge all seaborn text by a factor of 1.5 for
# every figure drawn after this cell.
sns.set(font_scale=1.5)

## Introduction

First, we load each element embedding (CBFV) as an `Embedding` object.

In [None]:
# Names of every CBFV we load; each name is passed to `Embedding.load_data`.
cbfvs = [
    "magpie", "mat2vec", "megnet16", "oliynyk",
    "mod_petti", "random_200", "skipatom",
]
# Map each CBFV name to its loaded `Embedding` object (insertion order = `cbfvs`).
element_embedddings = dict(zip(cbfvs, map(Embedding.load_data, cbfvs)))

We can reproduce part of Table I from the paper (the dimension of each embedding) by running the following code:

In [None]:
# Dimensionality of every loaded CBFV, one row per embedding (cf. Table I).
# Each value is wrapped in a one-element list so that `from_dict(orient="index")`
# produces a single "dimension" column.
element_embedddings_dim = {cbfv: [element_embedddings[cbfv].dim] for cbfv in cbfvs}

dim_df = pd.DataFrame.from_dict(
    element_embedddings_dim, orient="index", columns=["dimension"]
)
# Leave the frame as the cell's last expression so Jupyter renders it as a rich
# (HTML) table rather than the plain-text output of print().
dim_df

## II.B Similarity measures

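For the similarity analysis we keep four embeddings (magpie, mat2vec, megnet16 and random_200). We restrict each one to the first 83 elements (atomic numbers up to 83), so that all heatmaps cover the same set of elements.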

In [None]:
# Select the four embeddings to compare. Each is deep-copied so that trimming it
# below does not silently mutate the full embeddings held in
# `element_embedddings` from the earlier cell (keeps re-runs free of hidden state).
cbfvs_to_keep = ["magpie", "mat2vec", "megnet16", "random_200"]
element_vectors = {
    cbfv: copy.deepcopy(element_embedddings[cbfv]) for cbfv in cbfvs_to_keep
}

# Keep only elements with atomic number <= N_ELEMENTS.
N_ELEMENTS = 83

# ordered_periodic.txt holds one element symbol per line, ordered so that a
# prefix slice yields the first N_ELEMENTS elements.
symbols_path = os.path.join(data_directory, "element_data", "ordered_periodic.txt")
with open(symbols_path) as f:
    symbols = f.read().splitlines()
symbols = symbols[:N_ELEMENTS]
symbols_to_keep = set(symbols)  # loop-invariant: build it once

for cbfv in cbfvs_to_keep:
    embedding = element_vectors[cbfv]
    # Compute the removal set up front: we must not delete keys from the
    # embeddings dict while iterating over it.
    els_to_remove = set(embedding.element_list) - symbols_to_keep
    for el in els_to_remove:
        del embedding.embeddings[el]

    # Report the remaining count per embedding. It can be below N_ELEMENTS if an
    # embedding has no vector for some of the kept elements.
    print(f"{cbfv}: {len(embedding.element_list)} elements")

In [None]:
def heatmap_plotter(
    embedding: Embedding,
    metric: str,
    distance: bool = True,
    correlation: bool = False,
    figsize: Tuple[int, int] = (36, 24),
    ax: Optional[plt.Axes] = None,
    filename: Optional[str] = None,
    show: bool = True,
    **kwargs,
):
    """
    Plot a heatmap of the pairwise distances or correlations of an embedding.

    Parameters
    ----------
    embedding : Embedding
        The embedding to be plotted.
    metric : str
        The distance metric passed to ``embedding.distance_pivot_table``.
        Ignored when ``correlation`` is True.
    distance : bool, optional
        Whether to plot a distance heatmap, by default True
    correlation : bool, optional
        Whether to plot a Pearson correlation heatmap, by default False.
        Takes precedence over ``distance``.
    figsize : Tuple[int,int], optional
        The size of the figure created when ``ax`` is None, by default (36, 24)
    ax : Optional[plt.Axes], optional
        The axes to draw on; a new figure is created when None.
    filename : Optional[str], optional
        If given, save the figure containing the heatmap to this path,
        by default None
    show : bool, optional
        Whether to call ``plt.show()`` after plotting, by default True.
        Only honoured when this function created its own figure (``ax`` is
        None), so a partially drawn multi-panel figure is never shown.
    **kwargs
        Additional keyword arguments to pass to seaborn.heatmap. They override
        the defaults set here (e.g. ``cmap``).

    Raises
    ------
    ValueError
        If neither ``correlation`` nor ``distance`` is True.
    """
    # Compute the data first so an invalid request fails before a figure is made.
    if correlation:
        pivot = embedding.pearson_pivot_table()
    elif distance:
        pivot = embedding.distance_pivot_table(metric=metric)
    else:
        raise ValueError("At least one of `correlation` or `distance` must be True.")

    created_figure = ax is None
    if created_figure:
        _, ax = plt.subplots(figsize=figsize)

    # Pivot labels are tuples whose second item is used as the tick text
    # (presumably the element symbol — NOTE(review): confirm against Embedding).
    # seaborn draws the columns along the x-axis and the index along the y-axis.
    xlabels = [label[1] for label in pivot.columns]
    ylabels = [label[1] for label in pivot.index]

    heatmap_kwargs = {
        "cmap": "bwr",
        "square": True,
        "linecolor": "k",
        "cbar_kws": {"shrink": 0.5},
        "xticklabels": True,
        "yticklabels": True,
        **kwargs,  # caller-supplied options win over the defaults above
    }
    sns.heatmap(pivot, ax=ax, **heatmap_kwargs)

    ax.title.set_text(embedding.embedding_name)
    ax.set_xticklabels(xlabels)
    ax.set_yticklabels(ylabels)
    ax.set_xlabel("")
    ax.set_ylabel("")

    # Save before showing: some backends discard the figure once shown.
    if filename is not None:
        ax.figure.savefig(filename)
    if show and created_figure:
        plt.show()

In [None]:
def heatmap_plotter2(
    embedding: Embedding,
    metric: str,
    sortaxisby: str = "mendeleev",
    filename: Optional[str] = None,
    ax: Optional[plt.Axes] = None,
    show_axislabels: bool = True,
    show_plot: bool = True,
    **kwargs,
):
    """
    Plot a heatmap of the pairwise similarity/distance between an embedding's elements.

    Parameters
    ----------
    embedding : Embedding
        The embedding to be plotted.
    metric : str
        The metric to plot. Distance metrics ("euclidean", "manhattan",
        "cosine_distance", "chebyshev", "wasserstein", "energy") are passed to
        ``embedding.distance_pivot_table``. Correlation metrics ("spearman",
        "pearson", "cosine_similarity") all plot
        ``embedding.pearson_pivot_table()``; a warning is issued for the
        non-Pearson names.
    sortaxisby : str, optional
        The attribute to sort the axes by, by default "mendeleev". Only used
        for distance metrics; this notebook also uses "atomic_number".
    filename : Optional[str], optional
        If given, save the figure containing the heatmap to this path,
        by default None
    ax : Optional[plt.Axes], optional
        The axes to draw on; a new figure is created when None.
    show_axislabels : bool, optional
        Whether to show the element tick labels, by default True
    show_plot : bool, optional
        Whether to call ``plt.show()`` after plotting, by default True.
        Only honoured when this function created its own figure (``ax`` is
        None), so a partially drawn multi-panel figure is never shown.
    **kwargs
        Additional keyword arguments to pass to seaborn.heatmap. They override
        the defaults set here (e.g. ``cmap``).

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported metric names.
    """
    correlation_metrics = ["spearman", "pearson", "cosine_similarity"]
    distance_metrics = [
        "euclidean",
        "manhattan",
        "cosine_distance",
        "chebyshev",
        "wasserstein",
        "energy",
    ]

    # Compute the data first so an invalid metric fails before a figure is made.
    if metric in correlation_metrics:
        if metric != "pearson":
            # Only a Pearson table is available here; say so rather than
            # silently mislabelling the plot.
            warnings.warn(
                f"metric={metric!r} is not implemented here; "
                "plotting the Pearson correlation instead.",
                stacklevel=2,
            )
        pivot = embedding.pearson_pivot_table()
    elif metric in distance_metrics:
        pivot = embedding.distance_pivot_table(metric=metric, sortby=sortaxisby)
    else:
        raise ValueError(
            f"Unknown metric {metric!r}; expected one of "
            f"{correlation_metrics + distance_metrics}."
        )

    created_figure = ax is None
    if created_figure:
        _, ax = plt.subplots()

    heatmap_kwargs = {
        "cmap": "bwr",
        "square": True,
        "linecolor": "k",
        "cbar_kws": {"shrink": 0.5},
        "xticklabels": True,
        "yticklabels": True,
        **kwargs,  # caller-supplied options win over the defaults above
    }
    sns.heatmap(pivot, ax=ax, **heatmap_kwargs)
    ax.set_title(embedding.embedding_name, fontweight="bold")

    if show_axislabels:
        # Pivot labels are tuples whose second item is used as the tick text
        # (presumably the element symbol — NOTE(review): confirm against
        # Embedding). seaborn draws columns on the x-axis and the index on y.
        ax.set_xticklabels([label[1] for label in pivot.columns])
        ax.set_yticklabels([label[1] for label in pivot.index])
    else:
        # Removing the ticks also removes their labels.
        ax.set_xticks([])
        ax.set_yticks([])
    ax.set_xlabel("")
    ax.set_ylabel("")

    # Save before showing: some backends discard the figure once shown.
    if filename is not None:
        ax.figure.savefig(filename)
    if show_plot and created_figure:
        plt.show()

In [None]:
# 2x2 grid of Euclidean-distance heatmaps with elements ordered by atomic
# number. Panels are filled row by row, in this order.
panel_order = ["magpie", "megnet16", "mat2vec", "random_200"]

fig, axes = plt.subplots(2, 2, figsize=(10, 10))
for panel_name, panel_ax in zip(panel_order, axes.flat):
    heatmap_plotter2(
        embedding=element_vectors[panel_name],
        metric="euclidean",
        sortaxisby="atomic_number",
        show_axislabels=False,
        ax=panel_ax,
    )
plt.tight_layout()
plt.show()

### Euclidean distances


\begin{equation}
d_E(\textbf{A,B}) = 
\sqrt{
(A_1 - B_1)^2 
+ \cdots
+ (A_n - B_n)^2 }
\end{equation}

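We can use the Euclidean distance to compare the similarity of two elements: the smaller the distance between their vectors, the more similar the elements are in that embedding. The following code plots, for each embedding, a heatmap of the Euclidean distances between all pairs of elements, with elements ordered by atomic number.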

In [None]:
# One Euclidean-distance heatmap per selected embedding, laid out on a 2x2 grid.
multi_heatmap_plotter(
    element_vectors.values(),
    metric="euclidean",
    sortaxisby="atomic_number",
    # layout
    nrows=2,
    ncols=2,
    figsize=(10, 10),
    # display
    show_axislabels=False,
    show_plot=True,
)

### Manhattan distances

\begin{equation}
d_M(\textbf{A,B}) = 
\sum_{i=1}^n |A_i - B_i|
\end{equation}

We can also use the Manhattan distance to compare the similarity of two elements. The following code plots, for each embedding, a heatmap of the Manhattan distances between all pairs of elements, with elements ordered by atomic number.


In [None]:
# One Manhattan-distance heatmap per selected embedding, laid out on a 2x2 grid.
multi_heatmap_plotter(
    element_vectors.values(),
    metric="manhattan",
    sortaxisby="atomic_number",
    # layout
    nrows=2,
    ncols=2,
    figsize=(10, 10),
    # display
    show_axislabels=False,
    show_plot=True,
)