In [None]:
from torch_geometric.data import Data
import torch
from torch import load, save, tensor

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict

import pubchempy as pcp
import sys
import re
import os

# Add src folder to the sys.path
src_path = "../src"
sys.path.insert(0, src_path)

from oxides_ml.dataset import OxidesGraphDataset
from oxides_ml.graph_tools import graph_plotter

# ============================================================================
# CONFIGURATION - Replace these paths with your data locations
# Or set environment variables: VASP_DATA_DIR, GRAPH_DATASET_DIR
# ============================================================================

# Environment variables win; the literal strings are only placeholder defaults.
vasp_directory = os.environ.get("VASP_DATA_DIR", "/path/to/VASP/database_3/oxide_adsorbates")
graph_dataset_dir = os.environ.get("GRAPH_DATASET_DIR", "./models/test_graph_datasets")

# Nested options forwarded to OxidesGraphDataset. Every optional node-feature
# flag is switched off here.
structure_params = dict(tolerance=0.3, scaling_factor=1.25, surface_order=2)
feature_flags = dict.fromkeys(
    ("adsorbate", "radical", "valence", "cn", "magnetization", "ads_height"),
    False,
)
graph_params = {
    "structure": structure_params,
    "features": feature_flags,
    "target": "adsorption_energy",
}

# Echo the resolved locations so the reader can see which data was used.
for label, resolved_path in (
    ("VASP directory", vasp_directory),
    ("Graph dataset directory", graph_dataset_dir),
):
    print(f"{label}: {resolved_path}")

dataset = OxidesGraphDataset(
    vasp_directory,
    graph_dataset_dir,
    graph_params,
    initial_state=False,
    force_reload=True,  # presumably forces re-processing on every run — TODO confirm in OxidesGraphDataset
    augment=False,
)
len(dataset)

In [None]:
# Attribute names stored on a graph; useful when choosing which keys to tabulate.
first_graph = dataset[0]
first_graph.keys()

In [None]:
# Visualise every graph whose adsorbate is acetylene.
target_adsorbate = "Acetylene"

# Iterate the dataset directly (as the DataFrame cell does). This loads each
# graph once, instead of indexing dataset[i] twice per iteration.
for graph_data in dataset:
    if graph_data.adsorbate_name == target_adsorbate:
        graph_plotter(graph_data)

In [None]:
# Keys to exclude from the DataFrame
exclude_keys = {'edge_index', 'edge_attr', 'node_feats', 'x', 'elem', 'idx', 'adsorbate_indices', 'target', 'facet'}  # Add or remove keys as needed


def graph_to_row(data, exclude=()):
    """Flatten one graph's attributes into a dict usable as a DataFrame row.

    Non-tensor attributes are copied unchanged. A tensor attribute is kept only
    when it holds a single value (0-d, or 1-d of length 1); it is converted to a
    Python scalar with ``.item()``. Any other tensor is dropped.

    Args:
        data: Graph object exposing ``keys()`` and item access (e.g. a PyG ``Data``).
        exclude: Attribute names to skip entirely.

    Returns:
        dict: Mapping of attribute name to scalar / non-tensor value.
    """
    row = {}
    for key in data.keys():
        if key in exclude:
            continue
        value = data[key]
        if not isinstance(value, torch.Tensor):
            row[key] = value
        elif value.ndim == 0 or (value.ndim == 1 and len(value) == 1):
            row[key] = value.item()
        # Multi-element tensors are intentionally skipped. To tabulate a list-valued
        # attribute (e.g. 'adsorbate_indices'), remove it from exclude_keys AND add
        # explicit handling here; excluded keys never reach this point.
    return row


df = pd.DataFrame([graph_to_row(data, exclude_keys) for data in dataset])

def flag_physisorbed(df, height_threshold=3, energy_threshold=-0.5):
    """
    Return a copy of ``df`` with a boolean 'physisorbed' column.

    A row is flagged as physisorbed when both conditions hold:
    the adsorbate height is at least ``height_threshold``, and the adsorption
    energy is at least ``energy_threshold`` (i.e. it binds no more strongly
    than that bound).

    Args:
        df (pd.DataFrame): DataFrame with 'ads_height' and 'ads_energy' columns.
        height_threshold (float): Min height for physisorption [Å].
        energy_threshold (float): Lower bound on adsorption energy for physisorption [eV].
            Rows with a more negative (stronger-binding) energy are not flagged.

    Returns:
        pd.DataFrame: A new dataframe with the 'physisorbed' column added. The input is not modified.

    Raises:
        KeyError: If 'ads_height' or 'ads_energy' is missing from ``df``.
    """
    missing = [col for col in ("ads_height", "ads_energy") if col not in df.columns]
    if missing:
        raise KeyError(
            f"flag_physisorbed requires column(s) {missing}; available columns: {list(df.columns)}"
        )

    physisorbed = (df['ads_height'] >= height_threshold) & (df['ads_energy'] >= energy_threshold)
    # assign() returns a new frame, so re-running the calling cell is idempotent.
    return df.assign(physisorbed=physisorbed)

df = flag_physisorbed(df)

# Jupyter only auto-renders a cell's LAST expression, so a bare df.head() here
# would be silently discarded. display() is a builtin inside IPython/Jupyter kernels.
display(df.head())

# Per-column non-null counts for the IrO2 subset.
df[df['material'] == "IrO2"].count()

In [None]:
# Global adsorption-energy distribution: all samples (left) and split by the
# 'physisorbed' flag (right).
fig, axes = plt.subplots(1, 2, figsize=(20, 6))  # 1 row, 2 columns

# First plot. No hue is assigned, so no palette is passed; seaborn would ignore it and warn.
sns.histplot(df, x="ads_energy", bins=30, kde=True, edgecolor='black', zorder=10, ax=axes[0])
axes[0].set_title("Global Adsorption Energy Distribution")

# Second plot, coloured by binding type.
sns.histplot(df, x="ads_energy", hue='physisorbed', bins=30, kde=True, palette="Set1", edgecolor='black', zorder=10, ax=axes[1])
axes[1].set_title("Global Adsorption Energy Distribution by Binding Type")

# Axis styling shared by both histograms.
for ax in axes:
    ax.set_xlabel("Adsorption Energy (eV)")
    ax.set_ylabel("Frequency")
    ax.grid(True)

# Adjust layout
plt.tight_layout()

plt.show()

In [None]:
# Per-material breakdown of adsorption energies.
material = 'RuO2'
df_RuO2 = df[df['material'] == material].sort_values(by=['adsorbate_group']).reset_index(drop=True)

fig, axes = plt.subplots(1, 3, figsize=(20, 6))  # 1 row, 3 columns

# First plot: overall distribution. No hue, so no palette; seaborn would ignore it and warn.
sns.histplot(df_RuO2, x="ads_energy", bins=30, kde=True, edgecolor='black', zorder=10, ax=axes[0])
axes[0].set_title(f"{material} Adsorption Energy Distribution")
axes[0].set_xlabel("Adsorption Energy (eV)")
axes[0].set_ylabel("Frequency")
axes[0].grid(True)

# Second plot: distribution split by the physisorbed flag.
sns.histplot(df_RuO2, x="ads_energy", hue='physisorbed', bins=30, kde=True, palette="Set1", edgecolor='black', zorder=10, ax=axes[1])
axes[1].set_title(f"{material} Adsorption Energy Distribution by Binding Type")
axes[1].set_xlabel("Adsorption Energy (eV)")
axes[1].set_ylabel("Frequency")
axes[1].grid(True)

# Third plot: x is the boolean 'physisorbed' column, so label the axis accordingly.
sns.boxplot(df_RuO2, x="physisorbed", y="ads_energy", hue="physisorbed", palette="Set1", legend=False, ax=axes[2])
axes[2].set_title(f"{material} Adsorption Energy Variation by Binding Type")
axes[2].set_ylabel("Adsorption Energy (eV)")
axes[2].set_xlabel("Physisorbed")
axes[2].grid(True)

# Adjust layout
plt.tight_layout()

plt.show()

In [None]:
# Per-material breakdown of adsorption energies.
material = 'IrO2'
df_IrO2 = df[df['material'] == material].sort_values(by=['adsorbate_group']).reset_index(drop=True)

fig, axes = plt.subplots(1, 3, figsize=(20, 6))  # 1 row, 3 columns

# First plot: overall distribution. No hue, so no palette; seaborn would ignore it and warn.
sns.histplot(df_IrO2, x="ads_energy", bins=30, kde=True, edgecolor='black', zorder=10, ax=axes[0])
axes[0].set_title(f"{material} Adsorption Energy Distribution")
axes[0].set_xlabel("Adsorption Energy (eV)")
axes[0].set_ylabel("Frequency")
axes[0].grid(True)

# Second plot: distribution split by the physisorbed flag.
sns.histplot(df_IrO2, x="ads_energy", hue='physisorbed', bins=30, kde=True, palette="Set1", edgecolor='black', zorder=10, ax=axes[1])
axes[1].set_title(f"{material} Adsorption Energy Distribution by Binding Type")
axes[1].set_xlabel("Adsorption Energy (eV)")
axes[1].set_ylabel("Frequency")
axes[1].grid(True)

# Third plot: x is the boolean 'physisorbed' column, so label the axis accordingly.
sns.boxplot(df_IrO2, x="physisorbed", y="ads_energy", hue="physisorbed", palette="Set1", legend=False, ax=axes[2])
axes[2].set_title(f"{material} Adsorption Energy Variation by Binding Type")
axes[2].set_ylabel("Adsorption Energy (eV)")
axes[2].set_xlabel("Physisorbed")
axes[2].grid(True)

# Adjust layout
plt.tight_layout()

plt.show()