In [None]:
import pandas as pd
import os
import json
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Show floats with four decimals whenever a DataFrame is rendered.
pd.options.display.float_format = "{:.4f}".format

# Global seaborn look for every figure produced below.
sns.set_theme(style="ticks", context="paper", font_scale=1.3)

In [None]:
# Directories that are scanned for ``*.json`` summary files. The paths are
# relative, so they resolve against the notebook's working directory.
# Judging by the directory names there is one per representation variant.
skipatom_sumdir = "../out_final/species_prop/skipatom/induced/summary"
skipspecies_induced_sumdir = "../out_final/species_prop/skipspecies/induced/summary"
skipspecies_sumdir = "../out_final/species_prop/skipspecies/non_induced/summary"

# Every directory in this list is searched when the results are loaded.
dir_list = [skipatom_sumdir, skipspecies_induced_sumdir, skipspecies_sumdir]


def json_loader(path: str) -> dict:
    """Read a JSON file and return its parsed content.

    Args:
        path: Location of the JSON file.

    Returns:
        The decoded JSON document (a dict for the summary files used here).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON is UTF-8 by specification; do not rely on the platform's default
    # encoding (which is not UTF-8 on every system).
    with open(path, encoding="utf-8") as f:
        return json.load(f)


# Collect the parsed content of every ``*.json`` file found in the summary
# directories into one flat list (one entry per file).
results = []

for d in dir_list:
    # os.listdir returns entries in arbitrary order; sort them so the row
    # order of the resulting table is the same on every run and machine.
    summaries = sorted(s for s in os.listdir(d) if s.endswith(".json"))

    results.extend(json_loader(f"{d}/{s}") for s in summaries)

In [None]:
# Build a single table from the loaded summaries (one row per entry of
# `results`) and preview the first rows.
df = pd.DataFrame(results)
df.head()

In [None]:
# Display names for the raw identifiers stored in the "representation" column.
mapper = dict(
    skipatominduced="SkipAtom induced",
    skipspecies="SkipSpecies",
    skipspecies_induced="SkipSpecies induced",
)

# Store the display name in column "r"; identifiers missing from `mapper`
# become NaN.
df["r"] = df["representation"].map(mapper)
df.head()

In [None]:
def heatmap_summary(df, pool, classification, ax=None, reg_task=None):
    """Draw a representation-by-dimension heatmap of the mean metric for one pooling.

    Rows of ``df`` whose ``pooling`` equals ``pool`` are pivoted into a table
    with one row per representation and one column per ``dimension``. The
    cells hold ``mean_auc`` (classification) or ``mean_mae`` (regression) and
    are annotated with the value formatted to three decimals.

    Args:
        df: Summary table with the columns ``representation``, ``pooling``,
            ``dimension`` and ``mean_auc`` or ``mean_mae``. It is not modified.
        pool: Pooling name to select, compared against the ``pooling`` column.
        classification: If true plot ``mean_auc`` (colour map ``Blues``),
            otherwise plot ``mean_mae`` (colour map ``Blues_r``).
        ax: Axes to draw on. A new 12x8 inch figure is created when omitted.
        reg_task: Name of the regression task; only used to pick the unit in
            the colour-bar label. ``"band gap"`` (case-insensitive) gives
            ``[eV]``, any other name ``[eV/atom]``, and ``None`` no unit.

    Returns:
        The axes the heatmap was drawn on.

    Raises:
        ValueError: If ``df`` has no rows for the requested pooling.
    """
    labels = {
        "skipatominduced": "SkipAtom\nInduced",
        "skipspecies": "SkipSpecies",
        "skipspecies_induced": "SkipSpecies\nInduced",
    }
    # Attach the display label on a copy so the caller's frame stays untouched.
    # Unknown representations keep their raw name instead of turning into NaN
    # (which the pivot would silently drop).
    labelled = df.assign(
        r=df["representation"].map(labels).fillna(df["representation"])
    )
    d = labelled[labelled["pooling"] == pool].reset_index(drop=True)
    if d.empty:
        raise ValueError(f"no rows with pooling == {pool!r} to plot")

    if classification:
        value_col, cmap, cbar_label = "mean_auc", "Blues", "Mean AUC"
    else:
        # Lower MAE is better, hence the reversed colour map.
        value_col, cmap = "mean_mae", "Blues_r"
        if reg_task is None:
            cbar_label = "Mean MAE"
        elif reg_task.lower() == "band gap":
            cbar_label = "Mean MAE [eV]"
        else:
            cbar_label = "Mean MAE [eV/atom]"

    if ax is None:
        _, ax = plt.subplots(figsize=(12, 8))

    means = d.pivot_table(index="r", columns="dimension", values=value_col)
    # Column-wise Series.map replaces the deprecated DataFrame.applymap and
    # works on old and new pandas versions alike.
    annotations = means.apply(lambda col: col.map(lambda x: f"{x:.3f}"))

    sns.heatmap(
        means,
        annot=annotations,
        fmt="",  # annotations are already formatted strings
        ax=ax,
        cmap=cmap,
        cbar_kws={"label": cbar_label},
        annot_kws={"size": 13},
    )
    return ax

In [None]:
def df_to_latex(df: pd.DataFrame):
    # Sort the dataframe
    dframe = df.sort_values(by=["dimension", "representation", "pooling"])
    # Get the task name
    task = df.iloc[0]["task"]

    # Remove irrelevant columns
    cols_clf = ["representation", "pooling", "dimension", "mean_auc", "std_auc"]
    cols_reg = ["representation", "pooling", "dimension", "mean_mae", "std_mae"]
    if df.iloc[0].classification:
        dframe["AUC"] = dframe[
            ["mean_auc", "std_auc"].apply(
                lambda x: "±".join(round(x, 3).astype(str)), 1
            )
        ]
        dframe = dframe[["representation", "pooling", "dimension", "AUC"]]
    else:
        dframe["MAE"] = dframe[["mean_mae", "std_mae"]].apply(
            lambda x: "±".join(round(x, 3).astype(str)), 1
        )
        dframe = dframe[["representation", "pooling", "dimension", "MAE"]]
    latex_table = dframe.to_latex(index=False, escape=False)

    return dframe, latex_table

In [None]:
# Split the combined results into one freshly indexed frame per task.
df_band_gap, df_formation_energy_per_atom, df_is_metal, df_is_magnetic = (
    df.query(f'task == "{task_name}"').reset_index(drop=True)
    for task_name in (
        "band_gap",
        "formation_energy_per_atom",
        "is_metal",
        "is_magnetic",
    )
)

In [None]:
# 2x2 figure of sum-pooling heatmaps, one panel per task, saved as a PDF.
dfs = [df_band_gap, df_formation_energy_per_atom, df_is_metal, df_is_magnetic]
tasks = ["Band gap", "$E_{form}$", "Metallic classification", "Magnetic classification"]
pool = "sum"
fig, axes = plt.subplots(2, 2, figsize=(12.5, 9))
# The loop variable is deliberately not called `df`: that would rebind the
# notebook-level `df` (the full results table) to the last task frame and
# break any earlier cell that is re-run afterwards.
for task_df, task, ax in zip(dfs, tasks, axes.flatten()):
    is_classification = task in ["Metallic classification", "Magnetic classification"]
    heatmap_summary(
        task_df,
        pool=pool,
        classification=is_classification,
        ax=ax,
        # The task name only selects the unit of the regression colour bar.
        reg_task=None if is_classification else task,
    )
    ax.set_title(f"{task} {pool}-pooling", fontsize=13)
    ax.set_xlabel("Dimension", fontsize=13)
    ax.set_ylabel("Representation", fontsize=13)

# Panel labels in figure coordinates (top-left corner of each panel).
fig.text(0.01, 0.99, "(a)", weight="bold")
fig.text(0.51, 0.99, "(b)", weight="bold")
fig.text(0.01, 0.49, "(c)", weight="bold")
fig.text(0.51, 0.49, "(d)", weight="bold")
plt.tight_layout()
plt.savefig(
    "../plots/Sumpool_heatmap_publication.pdf",
    bbox_inches="tight",
    transparent=True,
    dpi=600,
)
plt.show()

In [None]:
# 2x2 figure of max-pooling heatmaps, one panel per task, saved as a PDF.
dfs = [df_band_gap, df_formation_energy_per_atom, df_is_metal, df_is_magnetic]
tasks = ["Band gap", "$E_{form}$", "Metallic classification", "Magnetic classification"]
pool = "max"
fig, axes = plt.subplots(2, 2, figsize=(12.5, 9))
# The loop variable is deliberately not called `df`: that would rebind the
# notebook-level `df` (the full results table) to the last task frame and
# break any earlier cell that is re-run afterwards.
for task_df, task, ax in zip(dfs, tasks, axes.flatten()):
    is_classification = task in ["Metallic classification", "Magnetic classification"]
    heatmap_summary(
        task_df,
        pool=pool,
        classification=is_classification,
        ax=ax,
        # The task name only selects the unit of the regression colour bar.
        reg_task=None if is_classification else task,
    )
    ax.set_title(f"{task} {pool}-pooling", fontsize=13)
    ax.set_xlabel("Dimension", fontsize=13)
    ax.set_ylabel("Representation", fontsize=13)

plt.tight_layout()
plt.savefig(
    "../plots/maxpool_heatmap.pdf", bbox_inches="tight", transparent=True, dpi=300
)
plt.show()

In [None]:
# 2x2 figure of mean-pooling heatmaps, one panel per task, saved as a PDF.
dfs = [df_band_gap, df_formation_energy_per_atom, df_is_metal, df_is_magnetic]
tasks = ["Band gap", "$E_{form}$", "Metallic classification", "Magnetic classification"]
pool = "mean"
fig, axes = plt.subplots(2, 2, figsize=(12.5, 9))
# The loop variable is deliberately not called `df`: that would rebind the
# notebook-level `df` (the full results table) to the last task frame and
# break any earlier cell that is re-run afterwards.
for task_df, task, ax in zip(dfs, tasks, axes.flatten()):
    is_classification = task in ["Metallic classification", "Magnetic classification"]
    heatmap_summary(
        task_df,
        pool=pool,
        classification=is_classification,
        ax=ax,
        # The task name only selects the unit of the regression colour bar.
        reg_task=None if is_classification else task,
    )
    ax.set_title(f"{task} {pool}-pooling", fontsize=13)
    ax.set_xlabel("Dimension", fontsize=13)
    ax.set_ylabel("Representation", fontsize=13)

plt.tight_layout()
plt.savefig(
    "../plots/meanpool_heatmap.pdf", bbox_inches="tight", transparent=True, dpi=300
)
plt.show()

In [None]:
# Metal classification: one heatmap per pooling, stacked vertically.
fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(16, 16))
for pool, ax in zip(("max", "sum", "mean"), axes.ravel()):
    heatmap_summary(df_is_metal, pool, True, ax)
    ax.set_title(f"Metal classifications {pool}-pooling")

In [None]:
# Formation energy: one heatmap per pooling, stacked vertically.
fig, axes = plt.subplots(3, 1, figsize=(16, 16))
for ax, pool in zip(axes.flatten(), ["max", "sum", "mean"]):
    # Plot the formation-energy regression results that the title announces
    # (this cell previously re-plotted the metal classification frame).
    heatmap_summary(
        df_formation_energy_per_atom,
        pool=pool,
        classification=False,
        ax=ax,
        reg_task="$E_{form}$",  # any name other than "band gap" gives eV/atom
    )
    ax.set_title(f"$E_{{f}}$ predictions {pool}-pooling")