In [None]:
import sys, os
import pandas as pd
import numpy as np
import cycler
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib.font_manager import FontProperties
# Arial, normal style, 9 pt (name suggests legend text; not used in the cells shown here)
legend_font = FontProperties(size=9, style='normal', family='Arial')
import seaborn as sns
from sklearn.metrics import r2_score
from scipy.stats import gaussian_kde
from io import StringIO
from matplotlib.font_manager import FontProperties

# Custom font: Arial, normal style, 11 pt
arial_font = FontProperties(size=11, style='normal', family='Arial')

In [None]:
# Base path to dataset and figures. Defaults to the original location; set the
# NCV_BASE_PATH environment variable to run the notebook on another machine.
base_path = os.environ.get("NCV_BASE_PATH", "/home/tvanhout/oxides_ML/models/NCV/")

# Directory containing the dataset
directory = "Db1"

# The plotting cell saves figures here; create it up front so savefig does not
# fail when the folder is missing.
plt_path = os.path.join(base_path, directory, "figures")
os.makedirs(plt_path, exist_ok=True)

# Load data
df = pd.read_csv(os.path.join(base_path, directory, "summary.csv"))

# The first column packs every per-prediction field into one comma-separated
# string (its header is "System,Material,...Abs_error_eV"); "run" is a separate column.
data_col = df.columns[0]
field_names = [
    "System", "Material", "Surface", "Molecule Group", "Molecule",
    "State", "Dissociation", "True_eV", "Prediction_eV",
    "Error_eV", "Abs_error_eV",
]

# Split all packed strings in one vectorized call instead of iterrows()
split_rows = df[data_col].str.split(",")

# A field that itself contains a comma (or a missing cell) would shift every
# following column; fail with the offending row instead of a shape error.
bad_rows = split_rows.str.len() != len(field_names)
if bad_rows.any():
    first_bad = bad_rows.idxmax()
    raise ValueError(
        f"{int(bad_rows.sum())} row(s) of {data_col!r} do not split into "
        f"{len(field_names)} fields; first offender (index {first_bad}): "
        f"{df.loc[first_bad, data_col]!r}"
    )

# One record per prediction: the split fields followed by the run id
all_dfs = [fields + [run_id] for fields, run_id in zip(split_rows, df["run"])]

# Convert to DataFrame
df_summary = pd.DataFrame(all_dfs, columns=field_names + ["run"])

# Convert numeric columns
numeric_cols = ["True_eV", "Prediction_eV", "Error_eV", "Abs_error_eV"]
df_summary[numeric_cols] = df_summary[numeric_cols].astype(float)

df_summary.head()

In [None]:
# Nested-CV MAE: average each run's MAE, so every run carries the same weight
# however many predictions it contributed.
mae_per_run = df_summary.groupby("run")["Abs_error_eV"].mean()
mae_nested = mae_per_run.mean()

# Pooled statistics over all rows of df_summary.
# NOTE: unlike the MAE above, RMSE and R² are not run-averaged, so the two kinds
# of metric weight runs differently when run sizes differ.
signed_err = df_summary["Error_eV"]
rmse = np.sqrt(signed_err.pow(2).mean())
r2 = r2_score(df_summary["True_eV"], df_summary["Prediction_eV"])
mean, median, std = signed_err.mean(), signed_err.median(), signed_err.std()
n = df_summary.shape[0]

# Assemble the report once and print it in a single call
report = "\n".join([
    "Statistics based on nested cross-validation results:\n",
    f"Mean Error: {mean:.2f} eV",
    f"Median Error: {median:.2f} eV",
    f"Standard Deviation: {std:.2f} eV",
    f"Nested MAE (mean of per-run MAEs): {mae_nested:.2f} eV",
    f"RMSE: {rmse:.2f} eV",
    f"R² Score: {r2:.2f}",
    f"Total Predictions (N): {n}",
])
print(report)

In [None]:
# Shared y/x label for every panel that shows the signed error
error_label = r"$\mathit{E}_{ads}^{DFT} - \mathit{E}_{ads}^{GNN}$ / eV"


def _capitalize_first(label):
    """Upper-case only the first character of ``label``.

    ``str.capitalize`` also lower-cases the remainder, which would turn a
    formula such as "TiO2" into "Tio2"; this leaves the remainder untouched.
    """
    return label[:1].upper() + label[1:]


def _style_legend(ax, **legend_kwargs):
    """Redraw the legend of ``ax`` with capitalised labels and a thin black frame.

    ``legend_kwargs`` are forwarded to ``Axes.legend`` (e.g. ``loc``, ``ncol``);
    font size is fixed at 9. Returns the new ``Legend``.
    """
    handles, labels = ax.get_legend_handles_labels()
    legend = ax.legend(handles, labels, fontsize=9, **legend_kwargs)
    for label_text in legend.get_texts():
        label_text.set_text(_capitalize_first(label_text.get_text()))
    legend.get_frame().set_linewidth(0.5)
    legend.get_frame().set_edgecolor("black")
    return legend


def _error_boxplot(ax, data, group_col, xlabel, title, ylim, show_legend):
    """Box plot of ``Error_eV`` on ``ax``, grouped and coloured by ``group_col``.

    Outlier markers are hidden (``fliersize=0``) and x ticks are removed, so the
    groups are distinguished only by colour (and by the legend when shown).
    ``ylim`` is a ``(low, high)`` pair.
    """
    sns.boxplot(data=data.sort_values(group_col), x=group_col, y="Error_eV", hue=group_col,
                linewidth=1, ax=ax, fliersize=0, legend=show_legend)
    ax.set_ylabel(error_label)
    ax.set_xlabel(xlabel)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_ylim(*ylim)
    ax.yaxis.set_major_locator(MaxNLocator(5))


# Set figure and axes (18 x 14 cm). Panel spacing is computed by tight_layout()
# below, which overrides any subplots_adjust call, so none is made here.
fig, axs = plt.subplots(2, 2, figsize=(18/2.54, 14/2.54), dpi=300)

# === TOP LEFT: Parity plot ===
ax_parity = axs[0, 0]
sns.scatterplot(x="True_eV", y="Prediction_eV", hue="Material",
                data=df_summary.sort_values("Material"), ax=ax_parity, ec="k", s=15)
ax_parity.set_ylabel(r'$\mathit{E}_{ads}^{GNN}$ / eV')
ax_parity.set_xlabel(r'$\mathit{E}_{ads}^{DFT}$ / eV')
ax_parity.set_title("Parity plot")
l_min, l_max = -8, 4  # fixed axis range in eV; points outside it are not shown
ax_parity.set_xlim(l_min, l_max)
ax_parity.set_ylim(l_min, l_max)
ax_parity.plot([l_min, l_max], [l_min, l_max], c="k", zorder=-1)  # y = x reference line
ax_parity.xaxis.set_major_locator(MaxNLocator(5))
ax_parity.yaxis.set_major_locator(MaxNLocator(5))
stats_text = (
    f"MAE = {mae_nested:.2f} eV\n"
    f"RMSE = {rmse:.2f} eV\n"
    f"$\\mathit{{R}}^2$ = {r2:.2f}\n"
    f"N = {n}"
)
props = dict(boxstyle='round', facecolor='white', edgecolor='black')
ax_parity.text(0.05, 0.95, stats_text, transform=ax_parity.transAxes, fontsize=9,
               verticalalignment='top', bbox=props)
_style_legend(ax_parity, loc="lower left", ncol=2)

# === TOP RIGHT: KDE of error ===
ax_kde = axs[0, 1]
errors = df_summary["Error_eV"]
sns.kdeplot(errors, fill=True, ax=ax_kde, alpha=0.5)
ax_kde.set_xlabel(error_label)
ax_kde.set_ylabel("Density")
ax_kde.set_title("Error distribution")
ax_kde.set_xlim(-1.5, 1.5)
ylim = 4  # fixed density ceiling; taller peaks are cut off
ax_kde.set_ylim(0, ylim)
# Same pandas reductions as the printed statistics, so the lines match them
ax_kde.vlines(errors.mean(), 0, ylim, colors='r', linestyles='dashed', label='mean')
ax_kde.vlines(errors.median(), 0, ylim, colors='g', linestyles='dashed', label='median')
ax_kde.legend(fontsize=9)
ax_kde.text(0.03, 0.95, "std = {:.2f}".format(std),
            transform=ax_kde.transAxes, va='top',
            bbox=dict(boxstyle='round', facecolor='white', alpha=1.0, edgecolor='black'),
            fontsize=9)

# === BOTTOM LEFT: Boxplot by molecule group (no legend) ===
_error_boxplot(axs[1, 0], df_summary, "Molecule Group", "Adsorbate Family",
               "Error Distribution by Adsorbate Family", (-1.1, 1.1), show_legend=False)

# === BOTTOM RIGHT: Boxplot by material ===
_error_boxplot(axs[1, 1], df_summary, "Material", "Metal Oxide",
               "Error Distribution by Metal Oxide", (-1.5, 1.5), show_legend=True)
legend = _style_legend(axs[1, 1], loc="lower left", ncol=2)

# === Save or show ===
fig.tight_layout()
os.makedirs(plt_path, exist_ok=True)  # savefig does not create missing directories
for ext in ("svg", "png"):
    fig.savefig(os.path.join(plt_path, f"combined_2x2_plots.{ext}"), dpi=300, bbox_inches="tight")
plt.show()
