In [None]:
from elementembeddings.core import SpeciesEmbedding
from elementembeddings.plotter import dimension_plotter, heatmap_plotter
import seaborn as sns
import pandas as pd
from elementembeddings.utils.species import get_sign, parse_species
from pymatviz.ptable import ptable_heatmap
import matplotlib.pyplot as plt
import smact

In [None]:
# Load the SkipSpecies embedding (the "induced", dim200 CSV) as a SpeciesEmbedding.
# The file also contains some neutral species, which are filtered out further
# down in this cell by a helper function.
# Author's note on the file's location (kept verbatim):
# Year_final/skipspecies/data/skipspecies_2022_10_28/csv_files/induced/skipspecies_2022_10_28_induced_dim200.csv
skipspecies200induced_path = "../data/skipspecies_2022_10_28/csv_files/induced/skipspecies_2022_10_28_induced_dim200.csv"
spec = SpeciesEmbedding.from_csv(skipspecies200induced_path)


def remove_neutral_species(embedding):
    """Return ``embedding`` without its neutral (zero oxidation state) species.

    A species label is treated as neutral when it contains at least one digit
    and every digit in it is ``0`` (e.g. ``"Ar0"`` or ``"Xe0+"``). Checking all
    digits, rather than testing for the character ``"0"`` anywhere in the
    label, keeps charged species whose magnitude merely contains a zero
    (e.g. a ``10+`` state).

    Args:
        embedding: Object exposing ``species_list`` (an iterable of species
            label strings) and ``remove_species(list_of_labels)``.

    Returns:
        Whatever ``embedding.remove_species`` returns when given the list of
        neutral species labels.
    """

    def _is_neutral(label):
        # Collect only the ASCII digits: they encode the charge magnitude.
        digits = [ch for ch in label if ch in "0123456789"]
        # No digits at all (e.g. "Na+") is not a zero charge.
        return bool(digits) and all(ch == "0" for ch in digits)

    neutral_species = [s for s in embedding.species_list if _is_neutral(s)]
    return embedding.remove_species(neutral_species)


# Rebind `spec` to the result of the filter; the cells below use this
# filtered embedding rather than the raw one loaded from the CSV.
spec = remove_neutral_species(spec)

In [None]:
# SMACT supplies the list of element symbols. For each one, look up the
# oxidation states present in the embedding and build
# element_symbol -> [oxidation_states].
d = {}
for el in smact.element_dictionary():
    try:
        d[el] = spec.get_element_oxi_states(el)
    except Exception:
        # Best effort: an element whose lookup fails is left out of the
        # heatmap. Catch `Exception` rather than using a bare `except` so that
        # KeyboardInterrupt / SystemExit still propagate.
        continue

# element_symbol -> number of oxidation states found for that element
d_num = {symbol: len(states) for symbol, states in d.items()}

# Pymatviz draws the counts as a periodic-table heatmap.
ptable_heatmap(
    d_num,
    cbar_title="Number of oxidation states",
    colorscale="Blues",
)
plt.savefig("../plots/num_oxi_states_periodic_table.pdf", dpi=600)

In [None]:
# Read the species occurrence table; the "Species" and "Count" columns are
# used below.
spec_occur_df = pd.read_csv("species_occurence.csv")

# Labels of the 100 species with the highest "Count", most frequent first.
top100_list = (
    spec_occur_df
    .sort_values(by="Count", ascending=False)["Species"]
    .head(100)
    .tolist()
)


# Plot 2-D projections of the embedding side by side: UMAP in panel (a) and
# t-SNE in panel (b), restricted to the species in `top100_list`.
sns.set_theme(context="paper", style="ticks", font_scale=1.3)
fig, axes = plt.subplots(1, 2, figsize=(16, 8), dpi=600)
scatter_params = {"s": 80}
tsne_reducer_params = {"random_state": 42, "perplexity": 10}
umap_reducer_params = {"random_state": 42}

# Both panels share every argument except the reducer, its parameters and the
# target axes, so draw them from one loop instead of two copied calls.
panels = (
    ("umap", umap_reducer_params, axes[0]),
    ("tsne", tsne_reducer_params, axes[1]),
)
for reducer, reducer_params, panel_ax in panels:
    dimension_plotter(
        spec,
        reducer=reducer,
        adjusttext=1,
        include_species=top100_list,
        reducer_params=reducer_params,
        ax=panel_ax,
        scatter_params=scatter_params,
    )

# Blank each panel's title and replace its legend with an empty one; a single
# shared legend is added further down.
for ax in axes.flatten():
    ax.set_title("")
    ax.legend("")

# Panel labels, positioned in figure coordinates.
fig.text(0.1, 0.82, "(a)", weight="bold")
fig.text(0.52, 0.82, "(b)", weight="bold")

# Take the legend entries from the right-hand panel explicitly, rather than
# through the variable left over from the loop above.
handles, labels = axes[1].get_legend_handles_labels()
# Entries 0 and 10 are skipped — presumably section-title entries added by the
# plotter rather than real categories; TODO confirm against the rendered legend.
legend = axes[0].legend(
    handles=handles[1:10] + handles[11:],
    labels=labels[1:10] + labels[11:],
    loc="upper center",
    bbox_to_anchor=(1, 1.2),  # anchored at the right edge of panel (a), above the axes
    ncols=6,
    frameon=False,
)
# Leave head-room at the top of the figure for the shared legend.
plt.subplots_adjust(top=0.8)
plt.savefig(
    "../plots/tSNE_UMAP_skipspecies_2022_10_28_dim200_top100.pdf",
    dpi=600,
    bbox_inches="tight",
    transparent=True,
)
plt.show()