In [None]:
import pandas as pd
import os
import matplotlib.pyplot as plt
import seaborn as sns

# Global seaborn theme: "paper" context, "ticks" style, fonts scaled by 1.3.
sns.set_theme(context="paper", style="ticks", font_scale=1.3)

In [None]:
# Directories holding the per-run result CSVs that plot_tasks reads.
# SkipSpecies representation, non-induced variant.
species_dir = "../out_final/species_prop/skipspecies/non_induced/results"
# SkipSpecies representation, induced variant.
species_induced_dir = "../out_final/species_prop/skipspecies/induced/results"
# SkipAtom representation, induced variant.
atom_dir = "../out_final/species_prop/skipatom/induced/results"

In [None]:
def get_task(path, rep, pool, task, dim):
    """Return the sorted CSV file names in ``path`` that belong to one run setup.

    A file matches when its name ends with ``.csv`` and contains the fragment
    ``"{rep}_{pool}_dim{dim}_MP_{task}"`` anywhere in it (substring match).

    Parameters
    ----------
    path : directory to scan (not recursive).
    rep, pool, task, dim : values interpolated into the name fragment.

    Returns
    -------
    list of str
        Matching file names (not full paths) in ascending lexicographic order;
        empty when nothing matches.
    """
    entries = os.listdir(path)
    fragment = f"{rep}_{pool}_dim{dim}_MP_{task}"

    matches = [
        entry for entry in entries if entry.endswith(".csv") and fragment in entry
    ]
    matches.sort()
    return matches


def plot_losses(df, ax=None, classification=False, label=None):
    """Plot the per-epoch mean of a validation metric with a +/- 1 std band.

    Rows of ``df`` are grouped by the ``"epoch"`` column; the mean of the
    metric is drawn as a line and the band ``mean - std .. mean + std`` is
    shaded with ``alpha=0.3``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain an ``"epoch"`` column plus ``"val_auc"`` (when
        ``classification`` is true) or ``"val_mae"`` (otherwise).
    ax : axes-like object, optional
        Target providing ``plot`` and ``fill_between``. A new figure and axes
        are created with ``plt.subplots()`` only when this is ``None``.
    classification : bool
        Selects the ``"val_auc"`` column instead of ``"val_mae"``.
    label : str, optional
        Legend label attached to the mean line.

    Returns
    -------
    None
    """
    metric = "val_auc" if classification else "val_mae"

    # One groupby pass yields both statistics, aligned on the same epochs.
    stats = df.groupby("epoch")[metric].agg(["mean", "std"]).reset_index()
    epochs, mean, std = stats["epoch"], stats["mean"], stats["std"]

    # Compare against None explicitly: a truthiness test would discard a
    # caller-supplied axes object that happens to evaluate as falsy.
    if ax is None:
        _, ax = plt.subplots()

    ax.plot(epochs, mean, label=label)
    ax.fill_between(epochs, mean - std, mean + std, alpha=0.3)


def plot_tasks(task, ax, classification=False):
    """Overlay the validation curves of the three representations for ``task``.

    For each results directory the matching CSVs are looked up with
    ``get_task`` (pooling ``"sum"``, dimension 200) and the last file name in
    sorted order is loaded and passed to ``plot_losses`` on ``ax``.

    Parameters
    ----------
    task : task name used in the result file names.
    ax : axes that all three curves are drawn on.
    classification : forwarded to ``plot_losses`` to pick the metric column.

    Raises
    ------
    IndexError
        If a directory has no matching CSV (indexing the empty match list).
    """
    # (results directory, representation tag in the file name, legend label)
    runs = (
        (species_dir, "skipspecies", "SkipSpecies"),
        (species_induced_dir, "skipspecies", "SkipSpecies Induced"),
        (atom_dir, "skipatominduced", "SkipAtom"),
    )

    # Resolve every file first, then read them all, then plot, so a missing
    # or unreadable file aborts before anything is drawn.
    chosen = [
        get_task(directory, rep, "sum", task, 200)[-1] for directory, rep, _ in runs
    ]
    frames = [
        pd.read_csv(f"{directory}/{fname}")
        for (directory, _, _), fname in zip(runs, chosen)
    ]

    for (_, _, legend_label), frame in zip(runs, frames):
        plot_losses(frame, ax, classification=classification, label=legend_label)

In [None]:
# Tasks in panel order; the 2x2 grid below is filled row by row.
supported_tasks = ["band_gap", "formation_energy_per_atom", "is_metal", "is_magnetic"]
regression_tasks = ["band_gap", "formation_energy_per_atom"]
classification_tasks = ["is_metal", "is_magnetic"]

# Panel titles, keyed by task name.
task_to_string = {
    "band_gap": "Band gap",
    "formation_energy_per_atom": "$E_{form}$",
    "is_metal": "Metallic classification",
    "is_magnetic": "Magnetic classification",
}
# Y-axis labels for the regression panels; classification panels use "AUC".
y_axis_labels = {"band_gap": "MAE [eV]", "formation_energy_per_atom": "MAE [eV/atom]"}

# Single font size shared by axis labels, titles, ticks, legends and panel tags.
fs = 13

fig, axes = plt.subplots(2, 2, figsize=(12, 7.5))
for task, ax in zip(supported_tasks, axes.flatten()):
    is_classification = task in classification_tasks
    plot_tasks(task, ax, classification=is_classification)

    ax.set_ylabel("AUC" if is_classification else y_axis_labels[task], fontsize=fs)
    ax.set_xlabel("Epoch", fontsize=fs)
    ax.set_title(task_to_string[task], fontsize=fs)
    ax.tick_params(axis="both", labelsize=fs)
    ax.legend(frameon=False, fontsize=fs)

# Bold panel tags, positioned in figure coordinates.
panel_tags = (
    (0.05, 0.98, "(a)"),
    (0.55, 0.98, "(b)"),
    (0.05, 0.5, "(c)"),
    (0.55, 0.5, "(d)"),
)
for tag_x, tag_y, tag in panel_tags:
    fig.text(tag_x, tag_y, tag, fontsize=fs, weight="bold")

fig.tight_layout()
fig.savefig(
    "../plots/Validation_losses_12x7_5_publication.pdf",
    dpi=600,
    bbox_inches="tight",
    transparent=True,
)
plt.show()