In [None]:
from typing import Dict
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from monty.serialization import loadfn, dumpfn
import os
import pandas as pd
from matplotlib.colors import LogNorm
from matplotlib.ticker import LogFormatter

In [None]:
# Path to the gzipped JSON property dataset, relative to the notebook's working directory.
data_path = "../data/oxi-mp_property_dataset.json.gz"

In [None]:
# Deserialise the dataset with monty's loadfn; later cells treat the result as
# a pandas DataFrame (they call .head(), .query() and .groupby() on it).
df = loadfn(data_path)
df.head()

In [None]:
def num_els(comp_dict) -> int:
    """Return the number of distinct keys in a composition mapping.

    Args:
        comp_dict: A sized, dict-like composition (one key per element or
            species).

    Returns:
        The number of keys in ``comp_dict``.
    """
    # len() on a mapping already counts its keys; no need to build a keys view.
    return len(comp_dict)


# Add one count column per composition column: the number of keys in each
# row's composition mapping, as computed by num_els.
for count_col, source_col in (("num_els", "composition"), ("num_sp", "oxi_composition")):
    df[count_col] = df[source_col].apply(num_els)
df.head()

In [None]:
# Rows whose composition has exactly 4 keys and whose oxi_composition has exactly 6.
df[(df["num_els"] == 4) & (df["num_sp"] == 6)]

In [None]:
# Inspect the oxi_composition entry at index 110018.
# NOTE(review): presumably one of the rows surfaced by the num_els==4 / num_sp==6 query — confirm.
df["oxi_composition"][110018]

In [None]:
# Apply the seaborn theme globally; it stays in effect for later figures too.
sns.set_theme(context="paper", style="ticks", font_scale=1.3)

# Two side-by-side count plots with a shared y-axis:
# num_els on the left, num_sp on the right.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 0.7 * 9), sharey=True)
count_columns = ["num_els", "num_sp"]
x_labels = ["Number of elements", "Number of species"]
y_labels = ["Number of compounds", ""]

for panel, column in zip(axes.ravel(), count_columns):
    sns.countplot(df, x=column, ax=panel)
    # Annotate every bar with its count.
    panel.bar_label(panel.containers[0])

for panel, x_text, y_text in zip(axes, x_labels, y_labels):
    panel.set_xlabel(x_text)
    panel.set_ylabel(y_text)

plt.tight_layout()
plt.savefig("mp_dataset_distribution.pdf", dpi=300, bbox_inches="tight", transparent=True)
plt.show()

In [None]:
# Build the three-panel figure on a 2x4 grid: each top panel spans two
# columns, and the bottom panel sits in the two middle columns.
fig = plt.figure(figsize=(12, 10))
grid_shape = (2, 4)
ax1 = plt.subplot2grid(grid_shape, (0, 0), colspan=2, fig=fig)
ax2 = plt.subplot2grid(grid_shape, (0, 2), colspan=2, sharey=ax1, fig=fig)
ax3 = plt.subplot2grid(grid_shape, (1, 1), colspan=2, fig=fig)
axes = [ax1, ax2, ax3]

# Bottom panel: num_sp against num_els, drawn with translucent markers.
sns.scatterplot(df, x="num_els", y="num_sp", alpha=0.4, s=100, ax=ax3, rasterized=True)

# Top panels: count plots with every bar annotated with its count.
for count_ax, column in ((ax1, "num_els"), (ax2, "num_sp")):
    sns.countplot(df, x=column, ax=count_ax)
    count_ax.bar_label(count_ax.containers[0])

ax1.set_xlabel("Number of elements")
ax1.set_ylabel("Number of compounds")
ax2.set_xlabel("Number of species")
ax2.set_ylabel("")

ax3.set_xlabel("Number of elements")
ax3.set_ylabel("Number of species")
ax3.set_yticks(range(0, 11))
ax3.set_xticks(range(0, 10))
# Dotted diagonal marking where the two counts are equal.
diagonal = range(0, 11)
ax3.plot(diagonal, diagonal, "k:", label="$N_{elements}=N_{species}$")
ax3.legend()

# Panel letters, positioned in figure coordinates.
for x_pos, y_pos, letter in ((0.05, 1, "(a)"), (0.52, 1, "(b)"), (0.26, 0.46, "(c)")):
    fig.text(x_pos, y_pos, letter, weight="bold")

plt.tight_layout()
plt.savefig("MP_skip_training_dataset.pdf", dpi=300, bbox_inches="tight", transparent=True)
plt.show()

In [None]:
# Three-panel summary figure: count plots of num_els and num_sp on the top
# row and a log-coloured heatmap of their joint counts on the bottom row.

# Build the joint count table in this cell so it runs on a fresh kernel; the
# table was otherwise only defined in a later cell (hidden-state dependency).
# Rows are num_sp, columns are num_els; absent combinations are filled with 0.
pivot = df.groupby(["num_sp", "num_els"]).size().unstack().fillna(0)

# Create a layout with subplot2grid
fig = plt.figure(figsize=(1.2 * 12, 1.2 * 10))

# Define the grid for subplots: each top panel spans two of the four columns,
# the bottom panel occupies the two middle columns.
ax1 = plt.subplot2grid((2, 4), (0, 0), colspan=2)
ax2 = plt.subplot2grid((2, 4), (0, 2), colspan=2, sharey=ax1)
ax3 = plt.subplot2grid((2, 4), (1, 1), colspan=2)
axes = [ax1, ax2, ax3]

# Top row: count plots, each bar annotated with its count.
for ax, dat in zip([ax1, ax2], ["num_els", "num_sp"]):
    sns.countplot(df, x=dat, ax=ax)
    ax.bar_label(ax.containers[0])

# Bottom row: annotated heatmap of the joint counts with a logarithmic colour scale.
# NOTE(review): pivot holds zeros for absent combinations, which a log scale
# cannot represent — confirm those cells render as intended.
sns.heatmap(
    pivot,
    annot=True,
    cmap="Blues",
    fmt=".0f",
    robust=True,
    linewidth=0.5,
    cbar_kws={"label": "Number of structures"},
    norm=LogNorm(),
    ax=ax3,
)
# Label the colour bar at whole powers of ten with plain-number tick labels.
cbar = ax3.collections[0].colorbar
cbar.set_ticks([1, 10, 100, 1000, 10000])
cbar.set_ticklabels(["1", "10", "100", "1000", "10000"])

# Axis labels
axes[0].set_xlabel("Number of elements")
axes[1].set_xlabel("Number of species")
axes[0].set_ylabel("Number of compounds")
axes[1].set_ylabel("")

axes[2].set_xlabel("Number of elements")
axes[2].set_ylabel("Number of species")
# Keep the heatmap's row labels horizontal.
axes[2].tick_params(axis="y", rotation=0)

# Panel letters, positioned in figure coordinates.
fig.text(0.05, 1, "(a)", weight="bold")
fig.text(0.52, 1, "(b)", weight="bold")
fig.text(0.26, 0.46, "(c)", weight="bold")

plt.tight_layout()
# Save before showing the figure.
plt.savefig(
    "MP_skip_training_dataset_alt.pdf", dpi=300, bbox_inches="tight", transparent=True
)
plt.show()

In [None]:
# Count rows for every (num_sp, num_els) pair, then spread num_els across the
# columns; pairs that never occur become 0 instead of NaN.
counts_by_pair = df.groupby(["num_sp", "num_els"]).size()
pivot = counts_by_pair.unstack().fillna(0)
pivot.head()

In [None]:
# Standalone annotated heatmap of the `pivot` table with a logarithmic colour scale.
fig, ax = plt.subplots(figsize=(12, 12))
heatmap_options = dict(
    annot=True,
    cmap="Blues",
    fmt=".0f",
    robust=True,
    linewidth=0.5,
    cbar_kws={"label": "Number of structures"},
    norm=LogNorm(),
)
sns.heatmap(pivot, ax=ax, **heatmap_options)

# Label the colour bar at whole powers of ten with plain-number tick labels.
decade_ticks = [1, 10, 100, 1000, 10000]
cbar = ax.collections[0].colorbar
cbar.set_ticks(decade_ticks)
cbar.set_ticklabels([str(tick) for tick in decade_ticks])

# Keep the row labels horizontal.
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

In [None]:
# Headline statistics: row count and numbers of distinct formulas/compositions.
# Compositions are compared through their string representation before counting.
unique_dicts_count_comp = df["composition"].apply(str).nunique()
unique_dicts_count_oxi_comp = df["oxi_composition"].apply(str).nunique()

print(f"Data set size: {df.shape[0]}")
print(f"Unique formula: {df.formula_pretty.nunique()}")
print(f"Unique elemental compositions: {unique_dicts_count_comp}")
print(f"Unique ionic compositions: {unique_dicts_count_oxi_comp}")

In [None]:
# Normalise compositions


def _get_fractional_composition(el_dict: str) -> Dict[str, float]:
    elamt = {}
    natoms = 0
    for el, v in el_dict.items():
        elamt[el] = v
        natoms += abs(v)
    return {el: elamt[el] / natoms for el in elamt}


# Sanity check: normalise the oxi_composition entry at index 0.
_get_fractional_composition(df["oxi_composition"][0])

In [None]:
# Number of distinct oxi_composition entries after normalising to fractions;
# each normalised dict is compared through its string representation.
normalised_oxi_strings = df["oxi_composition"].apply(_get_fractional_composition).apply(str)
len(normalised_oxi_strings.unique())

In [None]:
# How often each normalised oxi_composition (as a string) occurs in the dataset.
(
    df["oxi_composition"]
    .apply(_get_fractional_composition)
    .apply(str)
    .value_counts()
)

In [None]:
# Keep one row per formula_pretty. Rows are sorted by formation energy within
# each formula first, so drop_duplicates (which keeps the first occurrence)
# retains the lowest-energy entry.
sorted_by_energy = df.sort_values(by=["formula_pretty", "formation_energy_per_atom"])
unique_df = sorted_by_energy.drop_duplicates(subset="formula_pretty", ignore_index=True)
print(f"Data size with only lowest energy polymorphs: {len(unique_df)}")
unique_df.head()

In [None]:
# Sum of the is_magnetic column over the de-duplicated dataset.
unique_df["is_magnetic"].sum()

In [None]:
# Mean of the band_gap column over the de-duplicated dataset.
unique_df["band_gap"].mean()

In [None]:
# Quick default pandas plot of the rows with a strictly positive band gap.
unique_df[unique_df["band_gap"] > 0].plot()

In [None]:
# Scratch check of the label formatting applied to property names; evaluates to "Is metal".
"is_metal".replace("_", " ").capitalize()

In [None]:
# 2x2 overview of the de-duplicated dataset: count plots for the properties
# listed in class_props, histograms for the rest.
props = ["formation_energy_per_atom", "band_gap", "is_metal", "is_magnetic"]
class_props = ["is_metal", "is_magnetic"]
units = {"formation_energy_per_atom": "eV/atom", "band_gap": "eV"}
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12, 9))

for panel, prop in zip(axes.ravel(), props):
    # Turn the column name into a readable label, e.g. "band_gap" -> "Band gap".
    x_label = prop.replace("_", " ").capitalize()
    if prop not in class_props:
        # Histogram with a log-scaled count axis and the unit appended to the label.
        sns.histplot(unique_df, x=prop, ax=panel, rasterized=True)
        panel.set_yscale("log")
        x_label += f" [{units[prop]}]"
    else:
        sns.countplot(unique_df, x=prop, ax=panel)
    panel.set_xlabel(x_label)
    panel.set_ylabel("Number of compounds")

# Panel letters, positioned in figure coordinates.
panel_letters = (
    (0.05, 1, "(a)"),
    (0.52, 1, "(b)"),
    (0.05, 0.52, "(c)"),
    (0.52, 0.52, "(d)"),
)
for x_pos, y_pos, letter in panel_letters:
    fig.text(x_pos, y_pos, letter, weight="bold")

plt.tight_layout()
plt.savefig("Property_dataset.pdf", dpi=300, bbox_inches="tight", transparent=True)
plt.show()

In [None]:
# Export the de-duplicated dataset (one row per formula_pretty) with monty's dumpfn.
dumpfn(unique_df, fn="../data/oxi-mp_property_dataset_unique_formulas.json.gz")