In [None]:
# Imports
import seaborn as sns
import psutil
import multiprocessing as mp
from pymatgen.analysis.structure_matcher import StructureMatcher
import smact
import pandas as pd
from pymatgen.core import Composition
import matplotlib.pyplot as plt
import re
from pymatgen.ext.matproj import MPRester
import os
from pymatgen.analysis.local_env import CrystalNN
from typing import Tuple
from monty.serialization import loadfn, dumpfn


def parse_species(species: str) -> Tuple[str, int]:
    """
    Parses a species string into its atomic symbol and oxidation state.

    The sign is taken from a "+" or "-" in the string. A sign without a
    number (e.g. "Cu+" or "Cl-") is read as a charge of magnitude 1, while an
    explicit zero (e.g. "Fe0+") is kept as oxidation state 0.

    :param species: the species string, e.g. "O2-", "Na1+", "Cu+" or "Fe"
    :return: a tuple of the atomic symbol and oxidation state
    :raises ValueError: if the string does not start with an element symbol

    """

    ele_match = re.match(r"[A-Za-z]+", species)
    if ele_match is None:
        raise ValueError(f"Could not find an element symbol in species {species!r}")
    ele = ele_match.group(0)

    charge_match = re.search(r"\d+", species)
    if charge_match:
        # An explicit number always wins, including an explicit 0.
        magnitude = int(charge_match.group(0))
    elif "+" in species or "-" in species:
        # Handle cases of X+ or X- (instead of X1+ or X1-)
        magnitude = 1
    else:
        magnitude = 0

    ox_state = -magnitude if "-" in species else magnitude
    return ele, ox_state


# Materials Project API key, read from the environment (None when unset).
API_KEY = os.getenv("MP_API_KEY")
# NOTE(review): presumably set to quiet TensorFlow's C++ logging - confirm it is still needed.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"


def co_ordination_num(structure):
    """
    Returns the coordination number of the first site of a structure.

    :param structure: the structure handed to CrystalNN.get_cn
    :return: the value CrystalNN.get_cn returns for site index 0

    """
    first_site_index = 0
    return CrystalNN().get_cn(structure, first_site_index)

## Crystal structure prediction (CSP)

We aim to test the embeddings for the task of predicting the structure of binary compounds in four different structure types:
* CsCl
* Rocksalt
* Zincblende
* Wurtzite

Let's first get example structures for each of these types from the Materials Project.

For CsCl, we can use the structure of CsCl itself (mp-22865). For rocksalt, we can use the structure of NaCl (mp-22862). For zincblende, we can use the cubic structure of ZnS (mp-10695). For wurtzite, we can use the hexagonal structure of ZnS (mp-560588).

In [None]:
# Download one reference structure per prototype from the Materials Project.
# Order of IDs: CsCl, NaCl (rocksalt), ZnS (zincblende), ZnS (wurtzite).
with MPRester(API_KEY) as mpr:
    (
        cscl_struct,
        rock_salt_struct,
        zinc_blende_struct,
        wurtzite_struct,
    ) = [
        mpr.get_structure_by_material_id(material_id)
        for material_id in ("mp-22865", "mp-22862", "mp-10695", "mp-560588")
    ]

## Binary materials dataset

We will obtain a dataset of binary, non-metal materials from the Materials Project. We will use the pymatgen library to obtain the structures of these materials.

In [None]:
# Columns requested from the summary endpoint; everything else is left out.
summary_fields = [
    "material_id",
    "formula_pretty",
    "spacegroup.symbol",
    "crystal_system",
    "energy_above_hull",
    "database_IDs",
    "possible_species",
    "structure",
]

# Search for 1:1 formulas ("*1*1") flagged as non-theoretical and non-metallic.
with MPRester(API_KEY) as mpr:
    docs = mpr.summary.search(
        formula="*1*1",
        theoretical=False,
        is_metal=False,
        fields=summary_fields,
    )

# Turn every returned document into a plain dict so pandas can tabulate it.
docs_list = [entry.dict(exclude={"fields_not_requested"}) for entry in docs]
df = pd.DataFrame(docs_list)
df.head()

In [None]:
# How many entries share each reduced formula.
df["formula_pretty"].value_counts()

In [None]:
SM = StructureMatcher(attempt_supercell=True)


def determine_structure_type(structure):
    """
    Labels a structure by the first reference prototype it matches.

    The prototypes are tried in the order CsCl, rock salt, zinc blende,
    wurtzite, using StructureMatcher.fit_anonymous.

    :param structure: the structure to classify
    :return: "cscl", "rock salt", "zinc blende", "wurtzite" or "other"

    """
    prototypes = (
        ("cscl", cscl_struct),
        ("rock salt", rock_salt_struct),
        ("zinc blende", zinc_blende_struct),
        ("wurtzite", wurtzite_struct),
    )
    for label, reference in prototypes:
        if SM.fit_anonymous(structure, reference):
            return label
    return "other"

In [None]:
# Classify every structure in parallel, one worker per CPU reported by psutil.
n_workers = psutil.cpu_count()
with mp.Pool(processes=n_workers) as pool:
    structure_types = pool.map(determine_structure_type, df.structure)
df["structure_type"] = structure_types

In [None]:
# Save the labelled table.
dumpfn(df, fn="df_structure_types.json")

In [None]:
# Count each structure type once and reuse the counts for both axes.
type_counts = df["structure_type"].value_counts()

fig, ax = plt.subplots(figsize=(10, 5))
g = sns.barplot(x=type_counts.index, y=type_counts, ax=ax)
ax.set(
    xlabel="Structure Type",
    ylabel="Count",
    title="Structure Type Distribution",
)
ax.bar_label(g.containers[0], fmt="%d")
fig.show()

In [None]:
# Keep one entry per formula: after sorting, the first row of each formula is
# the one with the lowest energy above hull.
df2 = (
    df.sort_values(by=["formula_pretty", "energy_above_hull"])
    .drop_duplicates(subset=["formula_pretty"], keep="first")
    .reset_index(drop=True)
)
print(
    f"By removing the duplicate formulae, we have reduced the dataset from {len(df)} to {len(df2)} entries."
)

# Discard every entry that matched none of the four reference prototypes.
is_known_type = df2["structure_type"] != "other"
structure_df = df2[is_known_type].reset_index(drop=True)
print(
    f"By removing the 'other' structure types, we have reduced the dataset from {len(df2)} to {len(structure_df)} entries."
)
structure_df.head()

In [None]:
# Keep only the entries that list at least one possible species.
has_species = structure_df["possible_species"].map(len) > 0
structure_df = structure_df[has_species].reset_index(drop=True)
print(len(structure_df))
structure_df.head()

In [None]:
def get_cation(species_list):
    """
    Returns the first species in the list whose string contains "+".

    :param species_list: an iterable of species strings
    :return: the (symbol, oxidation state) tuple from parse_species
    :raises IndexError: if no species string contains "+"

    """
    cations = [parse_species(label) for label in species_list if "+" in label]
    return cations[0]


def get_anion(species_list):
    """
    Returns the first species in the list whose string contains "-".

    :param species_list: an iterable of species strings
    :return: the (symbol, oxidation state) tuple from parse_species
    :raises IndexError: if no species string contains "-"

    """
    anions = [parse_species(label) for label in species_list if "-" in label]
    return anions[0]


# Store the parsed (symbol, oxidation state) tuple of the first cation and
# the first anion found among each entry's possible species.
species_column = structure_df["possible_species"]
structure_df["cation"] = species_column.apply(get_cation)
structure_df["anion"] = species_column.apply(get_anion)

structure_df.head()

In [None]:
# Coordination number of each structure, computed by co_ordination_num (CrystalNN).
structure_df["cn"] = structure_df["structure"].apply(co_ordination_num)


# Calculate the radius ratio
def radius_ratio(cation, anion):
    """
    Calculates the cation/anion ratio of average ionic radii.

    :param cation: a tuple of the cation's atomic symbol and oxidation state
    :param anion: a tuple of the anion's atomic symbol and oxidation state
    :return: the cation's average ionic radius divided by the anion's, or
        NaN when either radius is unavailable

    """
    cat = smact.Species(cation[0], cation[1])
    an = smact.Species(anion[0], anion[1])

    cat_radius = cat.average_ionic_radius
    an_radius = an.average_ionic_radius

    # A missing radius (None) would make the division raise TypeError and
    # abort the whole DataFrame.apply. Return NaN instead so such rows
    # survive as nulls, which the notebook filters out afterwards.
    if cat_radius is None or an_radius is None:
        return float("nan")

    return cat_radius / an_radius


# Evaluate the radius ratio row by row from the parsed cation/anion tuples,
# then save the enriched table.
structure_df["radius_ratio"] = structure_df.apply(
    lambda row: radius_ratio(row["cation"], row["anion"]),
    axis=1,
)
dumpfn(structure_df, fn="df_final_structure.json")
structure_df.head()

In [None]:
# Keep only the rows for which a radius ratio could be computed.
has_ratio = structure_df["radius_ratio"].notnull()
rr_df = structure_df[has_ratio].reset_index(drop=True)
print(
    f"Number of unique compositions with guessible oxidation states and radius ratio: {len(rr_df)}"
)
rr_df.head()

In [None]:
def rr_predict_cn(x):
    """
    Predicts a coordination number from a radius ratio.

    :param x: the radius ratio
    :return: the coordination number of the first band whose upper limit is
        not exceeded by x, or 12 if x lies above every limit

    """
    # (inclusive upper limit, coordination number), in ascending order.
    bands = (
        (0.155, 2),
        (0.225, 3),
        (0.414, 4),
        (0.732, 6),
        (1.0, 8),
    )
    for upper_limit, coordination in bands:
        if x <= upper_limit:
            return coordination
    return 12


# Coordination number predicted from each row's radius ratio.
rr_df["rr_predict_cn"] = rr_df["radius_ratio"].apply(rr_predict_cn)


def compare_cn(x):
    """
    Checks whether the computed and predicted coordination numbers agree.

    :param x: a mapping with the keys "cn" and "rr_predict_cn"
    :return: True if both values are equal, otherwise False

    """
    return bool(x["cn"] == x["rr_predict_cn"])


# Flag the rows where the prediction agrees with the computed coordination number.
rr_df["compare_cn"] = rr_df.apply(compare_cn, axis="columns")
rr_df.head()

In [None]:
# Share of rows (in percent) for each value of compare_cn.
agreement_pct = rr_df["compare_cn"].value_counts(normalize=True) * 100

fig, ax = plt.subplots()
a = ax.bar(agreement_pct.index, agreement_pct)
ax.set_xticks([0, 1])
ax.set_xticklabels(["False", "True"])
ax.set_ylabel("Percentage of compositions")
ax.set_xlabel("Comparison of coordination number")
ax.bar_label(a, fmt="%.1f")
fig.show()

In [None]:
# Radius-ratio limits for each coordination environment. The 0.414 and 0.732
# boundaries are the same thresholds that rr_predict_cn applies, so the table
# and the prediction function describe the same rule.
d = {
    "Radius Ratio Upper": [1.0, 1.0, 0.732, 0.732, 0.414, 0.225],
    "Radius Ratio Lower": [1.0, 0.732, 0.414, 0.414, 0.225, 0.155],
    "Coordination Number": [12, 8, 6, 4, 4, 3],
    "Coordination": [
        "CCP/HCCP",
        "Cubic",
        "Octahedral",
        "Square Planar",
        "Tetrahedral",
        "Triangular",
    ],
}

RR_table = pd.DataFrame(data=d)
RR_table