In [None]:
import sys, os
import pandas as pd
import numpy as np
import cycler
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib.font_manager import FontProperties
legend_font = FontProperties(family='Arial', style='normal', size=9)
import seaborn as sns
from sklearn.metrics import r2_score
from scipy.stats import gaussian_kde


# ---- Experiment location ----
# NOTE(review): absolute, user-specific path; adjust it when running elsewhere.
base_path = "/home/tvanhout/oxides_ML/models/Experiments/RELAXED/tolerance_fixed/"
directory = "Db1_TiO2_base"

# Folder that the figure-saving cells below write into.
plt_path = os.path.join(base_path, directory, "figures")

# Load data
df_test_set = pd.read_csv(os.path.join(base_path, directory, "test_set.csv"))

# Keep only the rows flagged as the test split, ordered by molecule group.
df_uq = pd.read_csv(os.path.join(base_path, directory, "uq.csv"))
df_uq_test = df_uq[df_uq["split"] == "test"].sort_values("molecule_group")

# Create the output folder only after the inputs loaded successfully, so that
# fig.savefig(...) in later cells does not fail on a missing directory.
os.makedirs(plt_path, exist_ok=True)

df_test_set.info()

In [None]:
# Mean of the absolute-error column per "Molecule Group" (one row per group).
df_MAE = (
    df_test_set
    .groupby("Molecule Group")[["Abs_error_eV"]]
    .mean()
    .reset_index()
)

# Two panels side by side; the figure size is given in cm and converted to inches.
fig, axes = plt.subplots(1, 2, figsize=(18/2.54, 9/2.54), dpi=300)
# NOTE(review): the returned palette is discarded, so this call does not change
# the plot colours (sns.set_palette would).
sns.color_palette("hls", 12)

params = {'mathtext.default': 'regular'}
plt.rcParams.update(params)

ax_box, ax_bar = axes

# ---- Left panel: distribution of the signed error per family ----
sns.boxplot(
    data=df_uq_test,
    x="molecule_group",
    y="error",
    hue="molecule_group",
    ax=ax_box,
    linewidth=1,
    fliersize=1,
)
ax_box.set_title("Error Distribution by Adsorbate Family", fontsize=11)
ax_box.set_xlabel("Adsorbate Family", fontsize=10)
ax_box.set_ylabel(r"$\mathit{E}_{ads}^{DFT} - \mathit{E}_{ads}^{GNN}$ / eV", fontsize=10)
ax_box.set_ylim(-1.1, 1.1)

# ---- Right panel: MAE per family ----
sns.barplot(
    data=df_MAE,
    x="Molecule Group",
    y="Abs_error_eV",
    hue="Molecule Group",
    ax=ax_bar,
    linewidth=1,
    edgecolor="black",
)
ax_bar.set_title("Mean Absolute Error by Adsorbate Family", fontsize=11)
ax_bar.set_xlabel("Adsorbate Family", fontsize=10)
ax_bar.set_ylabel(r"MAE / eV", fontsize=10)
# Leave 0.05 eV of headroom above the tallest bar.
ymax = df_MAE["Abs_error_eV"].max() + 0.05
ax_bar.set_ylim(0, ymax)

# ---- Formatting shared by both panels ----
for panel in axes:
    panel.set_xticks([])  # Hide x tick labels
    panel.yaxis.set_major_locator(MaxNLocator(5))
    panel.tick_params(axis='y', labelsize=8)

# Adjust layout
plt.tight_layout()
plt.show()


In [None]:
# Standalone figure: signed prediction error per adsorbate family, with a legend
# placed outside the axes.
fig, ax = plt.subplots(figsize=(12/2.54, 8/2.54), dpi=300)
# NOTE(review): the returned palette is discarded, so this call does not change
# the plot colours (sns.set_palette would).
sns.color_palette("hls", 12)

params = {'mathtext.default': 'regular'}
plt.rcParams.update(params)

sns.boxplot(
    data=df_uq_test,
    x="molecule_group",
    y="error",
    hue="molecule_group",
    ax=ax,
    legend=True,
    linewidth=1,
    fliersize=1,
)

ax.set_title("Error Distribution by Adsorbate Family", fontsize=11)
ax.set_xlabel("Adsorbate Family", fontsize=10)
ax.set_ylabel(r"$\mathit{E}_{ads}^{DFT} - \mathit{E}_{ads}^{GNN}$ / eV", fontsize=10)
ax.set_xticks([])  # Hide x tick labels
ax.set_ylim(-1.1, 1.1)
ax.yaxis.set_major_locator(MaxNLocator(5))
ax.tick_params(axis='y', labelsize=8)

# Legend: rebuild it from the plotted handles, anchored just right of the axes.
handles, labels = ax.get_legend_handles_labels()
legend_options = dict(
    loc='center left',
    bbox_to_anchor=(1.02, 0.5),
    title="",
    fontsize=9,
    ncol=1,
    columnspacing=0.4,
    handletextpad=0.2,
    borderpad=0.35,
    framealpha=1,
)
family_legend = ax.legend(handles, labels, **legend_options)
for entry in family_legend.get_texts():
    entry.set_text(entry.get_text().capitalize())
legend_frame = family_legend.get_frame()
legend_frame.set_linewidth(0.5)
legend_frame.set_edgecolor("black")

# tight_layout() is not called in this cell; the saved files are cropped with
# bbox_inches='tight' instead.
plt.show()

# Save the figure as PNG and SVG
for filename in ("boxplot_error_distribution.png", "boxplot_error_distribution.svg"):
    fig.savefig(os.path.join(plt_path, filename), dpi=300, bbox_inches='tight')


In [None]:
# Mean of the absolute-error column per "Molecule Group" (one row per group).
df_MAE = (
    df_test_set
    .groupby("Molecule Group")[["Abs_error_eV"]]
    .mean()
    .reset_index()
)

# Standalone figure: MAE per adsorbate family as a bar chart.
fig, ax = plt.subplots(figsize=(12/2.54, 8/2.54), dpi=300)
# NOTE(review): the returned palette is discarded, so this call does not change
# the plot colours (sns.set_palette would).
sns.color_palette("hls", 12)

params = {'mathtext.default': 'regular'}
plt.rcParams.update(params)

sns.barplot(
    data=df_MAE,
    x="Molecule Group",
    y="Abs_error_eV",
    hue="Molecule Group",
    ax=ax,
    edgecolor="black",
    linewidth=1,
)

ax.set_title("Mean Absolute Error by Adsorbate Family", fontsize=11)
ax.set_xlabel("Adsorbate Family", fontsize=10)
ax.set_ylabel(r"MAE / eV", fontsize=10)
ax.set_xticks([])  # Hide x tick labels

# Leave 0.05 eV of headroom above the tallest bar.
ymax = df_MAE["Abs_error_eV"].max() + 0.05
ax.set_ylim(0, ymax)
ax.yaxis.set_major_locator(MaxNLocator(5))
ax.tick_params(axis='y', labelsize=8)

plt.tight_layout()
plt.show()

# Save the figure as PNG and SVG
for filename in ("barplot_MAE.png", "barplot_MAE.svg"):
    fig.savefig(os.path.join(plt_path, filename), dpi=300)

In [None]:
# Standalone figure: signed prediction error grouped by the "material" column
# of the test split (one box per material).
fig, ax = plt.subplots(figsize=(12/2.54, 8/2.54), dpi=300)
# NOTE(review): the returned palette is discarded, so this call does not change
# the plot colours (sns.set_palette would).
sns.color_palette("hls", 12)

params = {'mathtext.default': 'regular'}
plt.rcParams.update(params)

sns.boxplot(
    data=df_uq_test, x="material", y="error", hue="material",
    linewidth=1, ax=ax, fliersize=1, legend=True
)

ax.set_ylabel(r"$\mathit{E}_{ads}^{DFT} - \mathit{E}_{ads}^{GNN}$ / eV", fontsize=10)
# The x axis holds materials here, so label it to match the title
# (it previously read "Adsorbate Family", carried over from the family plot).
ax.set_xlabel("Metal Oxide", fontsize=10)
ax.set_title("Error Distribution by Metal Oxide", fontsize=11)
ax.set_xticks([])  # Hide x tick labels; the legend identifies each box
ax.set_ylim(-1.1, 1.1)
ax.yaxis.set_major_locator(MaxNLocator(5))
ax.tick_params(axis='y', labelsize=8)

# Legend formatting: two compact columns inside the axes, capitalised labels.
handles, labels = ax.get_legend_handles_labels()
material_legend = ax.legend(
    handles, labels, loc="lower left", title="", fontsize=9,
    ncol=2, columnspacing=0.4, handletextpad=0.2, borderpad=0.35, framealpha=1
)
for entry in material_legend.get_texts():
    entry.set_text(entry.get_text().capitalize())
material_legend.get_frame().set_linewidth(0.5)
material_legend.get_frame().set_edgecolor("black")

plt.tight_layout()
plt.show()

# Save the figure
fig.savefig(os.path.join(plt_path, "material_error_distribution.png"), dpi=300)
fig.savefig(os.path.join(plt_path, "material_error_distribution.svg"), dpi=300)


In [None]:
# Test-set summary: error metrics (MAE, RMSE, R2) together with the mean,
# median and standard deviation of the *predicted* energies.
predictions = df_test_set["Prediction_eV"]

mae = df_test_set["Abs_error_eV"].mean()
rmse = np.sqrt(df_test_set["Error_eV"].pow(2).mean())
r2 = r2_score(df_test_set["True_eV"], predictions)
mean = predictions.mean()
median = predictions.median()
std = predictions.std()
n = len(predictions)

# One (template, value) pair per printed line, in output order.
report = (
    ("Mean: {:.2f} eV", mean),
    ("Median: {:.2f} eV", median),
    ("Std: {:.2f} eV", std),
    ("MAE: {:.2f} eV", mae),
    ("RMSE: {:.2f} eV", rmse),
    ("R2: {:.2f}", r2),
    ("N: {}", n),
)
for template, value in report:
    print(template.format(value))

In [None]:
# Number of rows per molecule group.
# NOTE(review): counted over the full df_uq (all splits), not only the test
# split — confirm this is intended.
df_count = df_uq.groupby("molecule_group")["molecule"].count().reset_index()

# Statistics of the predicted energies, computed here so the annotation does
# not depend on `mean`/`median`/`std` left behind by another cell (other cells
# rebind those names to error statistics).
pred_energy = df_test_set["Prediction_eV"]
pred_mean = pred_energy.mean()
pred_median = pred_energy.median()
pred_std = pred_energy.std()

# Two panels side by side: counts (left) and predicted-energy density (right).
fig, ax = plt.subplots(1, 2, figsize=(18/2.54, 9/2.54), dpi=300)
# NOTE(review): the returned palette is discarded, so this call does not change
# the plot colours (sns.set_palette would).
sns.color_palette("hls", 12)

# ---- Left panel: count per adsorbate family ----
sns.barplot(
    data=df_count, x="molecule_group", y="molecule", hue="molecule_group",
    linewidth=1, edgecolor="black", ax=ax[0]
)
ax[0].set_ylabel(r"Count", fontsize=10)
ax[0].set_xlabel("Adsorbate Family", fontsize=10)
ax[0].set_title("Count by Adsorbate Family", fontsize=11)
ax[0].set_xticks([])  # Hide x tick labels
ymax = df_count["molecule"].max() + 10  # headroom above the tallest bar
ax[0].set_ylim(0, ymax)
ax[0].yaxis.set_major_locator(MaxNLocator(5))
ax[0].tick_params(axis='y', labelsize=8)

# ---- Right panel: stacked KDE of the predicted energies per family ----
sns.kdeplot(
    data=df_test_set, x="Prediction_eV", hue="Molecule Group",
    fill=True, ax=ax[1], alpha=0.5, multiple="stack", linewidth=0
)
for spine in ['right', 'top', 'left', 'bottom']:
    ax[1].spines[spine].set_linewidth(1.0)
    ax[1].spines[spine].set_color('black')

# Raw string: "\m" is not a valid escape sequence in a normal string literal.
ax[1].set_xlabel(r"$\mathit{E}_{ads}^{GNN}$ / eV")
ax[1].set_ylabel("Density")
ax[1].set_title("Predicted Energy distribution by Adsorbate Family")
ax[1].set_xlim(-5, 1)
ylim = 0.6
ax[1].set_ylim(0, ylim)

# Dashed markers spanning the full y range at the mean and the median.
ax[1].vlines(pred_mean, 0, ylim, colors='r', linestyles='dashed', label='mean')
ax[1].vlines(pred_median, 0, ylim, colors='g', linestyles='dashed', label='median')
ax[1].legend(fontsize=9)
ax[1].text(
    0.03, 0.95,
    "mean = {:.2f}\nmedian = {:.2f}\nstd = {:.2f}".format(pred_mean, pred_median, pred_std),
    transform=ax[1].transAxes, va='top',
    bbox=dict(boxstyle='round', facecolor='white', alpha=1.0, edgecolor='black'),
    fontsize=9
)

plt.tight_layout()
plt.show()

In [None]:
# Standalone figure: number of entries per adsorbate family.
# Reads df_count, which is built in an earlier cell of this notebook.
fig, ax = plt.subplots(figsize=(12/2.54, 8/2.54), dpi=300)
# NOTE(review): the returned palette is discarded, so this call does not change
# the plot colours (sns.set_palette would).
sns.color_palette("hls", 12)

sns.barplot(
    data=df_count,
    x="molecule_group",
    y="molecule",
    hue="molecule_group",
    ax=ax,
    edgecolor="black",
    linewidth=1,
)

ax.set_title("Count by Adsorbate Family", fontsize=11)
ax.set_xlabel("Adsorbate Family", fontsize=10)
ax.set_ylabel(r"Count", fontsize=10)
ax.set_xticks([])  # Hide x tick labels

# Leave a margin of 10 above the tallest bar.
ymax = df_count["molecule"].max() + 10
ax.set_ylim(0, ymax)
ax.yaxis.set_major_locator(MaxNLocator(5))
ax.tick_params(axis='y', labelsize=8)

plt.tight_layout()
plt.show()

# Save the figure as PNG and SVG
for filename in ("barplot_count.png", "barplot_count.svg"):
    fig.savefig(os.path.join(plt_path, filename), dpi=300)


In [None]:
# Standalone figure: stacked KDE of the predicted energies, one layer per
# "Molecule Group", with the overall mean and median marked.
fig, ax = plt.subplots(figsize=(12/2.54, 8/2.54), dpi=300)
# NOTE(review): the returned palette is discarded, so this call does not change
# the plot colours (sns.set_palette would).
sns.color_palette("hls", 12)

sns.kdeplot(
    data=df_test_set, x="Prediction_eV", hue="Molecule Group",
    fill=True, ax=ax, alpha=0.5, multiple="stack", linewidth=0
)

# Format axes and spines
for spine in ['right', 'top', 'left', 'bottom']:
    ax.spines[spine].set_linewidth(1.0)
    ax.spines[spine].set_color('black')

# Raw string: "\m" is not a valid escape sequence in a normal string literal.
ax.set_xlabel(r"$\mathit{E}_{ads}^{GNN}$ / eV")
ax.set_ylabel("Density")
ax.set_title("Predicted Energy Distribution by Adsorbate Family")
ax.set_xlim(-5, 1)
ylim = 0.6
ax.set_ylim(0, ylim)

# Statistics of the predicted energies (these rebind the notebook-level
# names mean / median / std).
mean = df_test_set["Prediction_eV"].mean()
median = df_test_set["Prediction_eV"].median()
std = df_test_set["Prediction_eV"].std()

# Full-height dashed lines at the mean and median.
ax.axvline(mean, 0, 1, color='r', linestyle='dashed', label='mean')
ax.axvline(median, 0, 1, color='g', linestyle='dashed', label='median')
ax.legend(fontsize=9)

# Add annotation box
ax.text(
    0.03, 0.95,
    "mean = {:.2f}\nmedian = {:.2f}\nstd = {:.2f}".format(mean, median, std),
    transform=ax.transAxes, va='top',
    bbox=dict(boxstyle='round', facecolor='white', alpha=1.0, edgecolor='black'),
    fontsize=9
)

plt.tight_layout()
plt.show()

# Save the figure
fig.savefig(os.path.join(plt_path, "energy_distribution.png"), dpi=300)
fig.savefig(os.path.join(plt_path, "energy_distribution.svg"), dpi=300)

In [None]:
# Test-set summary: error metrics (MAE, RMSE, R2) together with the mean,
# median and standard deviation of the *signed error* column.
signed_error = df_test_set["Error_eV"]

mae = df_test_set["Abs_error_eV"].mean()
rmse = np.sqrt(signed_error.pow(2).mean())
r2 = r2_score(df_test_set["True_eV"], df_test_set["Prediction_eV"])
mean = signed_error.mean()
median = signed_error.median()
std = signed_error.std()
n = len(signed_error)

# Build every output line first, then print them in order.
summary_lines = [
    "Mean: {:.2f} eV".format(mean),
    "Median: {:.2f} eV".format(median),
    "Std: {:.2f} eV".format(std),
    "MAE: {:.2f} eV".format(mae),
    "RMSE: {:.2f} eV".format(rmse),
    "R2: {:.2f}".format(r2),
    "N: {}".format(n),
]
for line in summary_lines:
    print(line)

In [None]:
# Two panels side by side: parity plot (left) and KDE of the signed error (right).
# Reads mae, rmse, r2 and n from the statistics cell that precedes this one.
fig, ax = plt.subplots(1, 2, figsize=(18/2.54, 9/2.54), dpi=300)
# NOTE(review): the returned palette is discarded, so this call does not change
# the plot colours (sns.set_palette would).
sns.color_palette("hls", 12)

params = {'mathtext.default': 'regular'}
plt.rcParams.update(params)

# Error statistics computed here so the right-hand annotation does not depend
# on `mean`/`median`/`std` left behind by whichever cell ran last.
signed_error = df_test_set["Error_eV"]
err_mean = signed_error.mean()
err_median = signed_error.median()
err_std = signed_error.std()

# ---- Left panel: parity plot ----
sns.scatterplot(x="True_eV", y="Prediction_eV", hue="Material", data=df_test_set, ax=ax[0], ec="k", s=15)
# Raw strings: "\m" is not a valid escape sequence in a normal string literal.
ax[0].set_ylabel(r'$\mathit{E}_{ads}^{GNN}$ / eV')
ax[0].set_xlabel(r'$\mathit{E}_{ads}^{DFT}$ / eV')
ax[0].set_title("Parity plot")
l_min = -6
l_max = 1.5
ax[0].set_xlim(l_min, l_max)
ax[0].set_ylim(l_min, l_max)
# y = x reference line, drawn behind the points.
ax[0].plot([l_min, l_max], [l_min, l_max], c="k", zorder=-1)
ax[0].xaxis.set_major_locator(MaxNLocator(5))
ax[0].yaxis.set_major_locator(MaxNLocator(5))

# Metrics box. The backslash of \mathit is escaped explicitly because the
# string also needs real "\n" newlines and therefore cannot be a raw string.
metrics_text = "MAE = {:.2f} eV\nRMSE = {:.2f} eV\n$\\mathit{{R}}^{{2}}$ = {:.2f}\nN = {}".format(mae, rmse, r2, n)
props = dict(boxstyle='round', facecolor='white', edgecolor='black')
ax[0].text(0.05, 0.95, metrics_text, transform=ax[0].transAxes, fontsize=9,
           verticalalignment='top', bbox=props)

# Legend: two compact columns, capitalised labels, thin black frame.
handles, labels = ax[0].get_legend_handles_labels()
parity_legend = ax[0].legend(
    handles, labels, loc="lower left", title="", fontsize=9,
    ncol=2, columnspacing=0.4, handletextpad=0.2, borderpad=0.35, framealpha=1
)
for entry in parity_legend.get_texts():
    entry.set_text(entry.get_text().capitalize())
parity_legend.get_frame().set_linewidth(0.5)
parity_legend.get_frame().set_edgecolor("black")

# ---- Right panel: KDE of the signed error ----
sns.kdeplot(signed_error, fill=True, ax=ax[1], alpha=0.5)
for spine in ['right', 'top', 'left', 'bottom']:
    ax[1].spines[spine].set_linewidth(1.0)
    ax[1].spines[spine].set_color('black')
ax[1].set_xlabel(r"$\mathit{E}_{ads}^{DFT} - \mathit{E}_{ads}^{GNN}$ / eV")
ax[1].set_ylabel("Density")
ax[1].set_title("Error distribution")
ax[1].set_xlim(-1.5, 1.5)
ylim = 4
ax[1].set_ylim(0, ylim)

# Dashed markers spanning the full y range at the mean and the median.
ax[1].vlines(err_mean, 0, ylim, colors='r', linestyles='dashed', label='mean')
ax[1].vlines(err_median, 0, ylim, colors='g', linestyles='dashed', label='median')
ax[1].legend(fontsize=9)
ax[1].text(
    0.03, 0.95,
    "mean = {:.2f}\nmedian = {:.2f}\nstd = {:.2f}".format(err_mean, err_median, err_std),
    transform=ax[1].transAxes, va='top',
    bbox=dict(boxstyle='round', facecolor='white', alpha=1.0, edgecolor='black'),
    fontsize=9
)

# Single layout pass once both panels are complete.
plt.tight_layout()
plt.show()

In [None]:
# Standalone parity plot: predicted vs. reference energy, coloured by "Material".
# Reads mae, rmse, r2 and n from the statistics cell earlier in the notebook.
fig, ax = plt.subplots(figsize=(12/2.54, 8/2.54), dpi=300)
# NOTE(review): the returned palette is discarded, so this call does not change
# the plot colours (sns.set_palette would).
sns.color_palette("hls", 12)

sns.scatterplot(
    x="True_eV", y="Prediction_eV", hue="Material", data=df_test_set,
    ax=ax, ec="k", s=15
)

params = {'mathtext.default': 'regular'}
plt.rcParams.update(params)

# Raw strings: "\m" is not a valid escape sequence in a normal string literal.
ax.set_ylabel(r'$\mathit{E}_{ads}^{GNN}$ / eV')
ax.set_xlabel(r'$\mathit{E}_{ads}^{DFT}$ / eV')
ax.set_title("Parity plot")

# Same limits on both axes so the diagonal is the line of perfect agreement.
l_min = -6
l_max = 1.5
ax.set_xlim(l_min, l_max)
ax.set_ylim(l_min, l_max)

# Diagonal line, drawn behind the points
ax.plot([l_min, l_max], [l_min, l_max], c="k", zorder=-1)

# Ticks
ax.xaxis.set_major_locator(MaxNLocator(5))
ax.yaxis.set_major_locator(MaxNLocator(5))

# Metrics box. The backslash of \mathit is escaped explicitly because the
# string also needs real "\n" newlines and therefore cannot be a raw string.
metrics_text = "MAE = {:.2f} eV\nRMSE = {:.2f} eV\n$\\mathit{{R}}^{{2}}$ = {:.2f}\nN = {}".format(mae, rmse, r2, n)
props = dict(boxstyle='round', facecolor='white', edgecolor='black')
ax.text(0.05, 0.95, metrics_text, transform=ax.transAxes, fontsize=9,
        verticalalignment='top', bbox=props)

# Legend formatting: two compact columns, capitalised labels, thin black frame.
handles, labels = ax.get_legend_handles_labels()
parity_legend = ax.legend(
    handles, labels, loc="lower left", title="", fontsize=9,
    ncol=2, columnspacing=0.4, handletextpad=0.2, borderpad=0.35, framealpha=1
)
for entry in parity_legend.get_texts():
    entry.set_text(entry.get_text().capitalize())
parity_legend.get_frame().set_linewidth(0.5)
parity_legend.get_frame().set_edgecolor("black")

plt.tight_layout()
plt.show()

# Save the figure
fig.savefig(os.path.join(plt_path, "parity_plot.png"), dpi=300)
fig.savefig(os.path.join(plt_path, "parity_plot.svg"), dpi=300)


In [None]:
# Standalone figure: KDE of the signed prediction error with the mean and
# median marked.
fig, ax = plt.subplots(figsize=(12/2.54, 8/2.54), dpi=300)
# NOTE(review): the returned palette is discarded, so this call does not change
# the plot colours (sns.set_palette would).
sns.color_palette("hls", 12)

sns.kdeplot(df_test_set["Error_eV"], fill=True, ax=ax, alpha=0.5)

# Spine formatting
for spine in ['right', 'top', 'left', 'bottom']:
    ax.spines[spine].set_linewidth(1.0)
    ax.spines[spine].set_color('black')

# Explicit Axes call instead of the pyplot state machine; no tick properties
# are passed, matching the original call.
ax.tick_params(axis="both")
# Raw string: "\m" is not a valid escape sequence in a normal string literal.
ax.set_xlabel(r"$\mathit{E}_{ads}^{DFT} - \mathit{E}_{ads}^{GNN}$ / eV")
ax.set_ylabel("Density")
ax.set_title("Error distribution")

ax.set_xlim(-1, 1)
ylim = 4
ax.set_ylim(0, ylim)

# Statistics of the signed error (these rebind the notebook-level names
# mean / median / std).
mean = df_test_set["Error_eV"].mean()
median = df_test_set["Error_eV"].median()
std = df_test_set["Error_eV"].std()

# Full-height dashed lines at the mean and median.
ax.axvline(mean, 0, 1, color='r', linestyle='dashed', label='mean')
ax.axvline(median, 0, 1, color='g', linestyle='dashed', label='median')
ax.legend(fontsize=9)

# Text box
ax.text(
    0.03, 0.95,
    "mean = {:.2f}\nmedian = {:.2f}\nstd = {:.2f}".format(mean, median, std),
    transform=ax.transAxes, va='top',
    bbox=dict(boxstyle='round', facecolor='white', alpha=1.0, edgecolor='black'),
    fontsize=9
)

plt.tight_layout()
plt.show()

# Save the figure
fig.savefig(os.path.join(plt_path, "error_distribution.png"), dpi=300)
fig.savefig(os.path.join(plt_path, "error_distribution.svg"), dpi=300)
