# Generalized Graph Dataset Analysis

This notebook provides a flexible framework for analyzing PyTorch Geometric graph datasets with configurable parameters for graph construction (tolerance, scaling factor, surface order) and node features.

## Configuration

Set your dataset parameters below. Modify these variables to analyze different datasets.

In [None]:
import os
import sys
import json

from torch_geometric.data import Data
import torch
from torch import load, save, tensor

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict

import pubchempy as pcp
import numpy as np

# Add src folder to the sys.path
src_path = "../src"
sys.path.insert(0, src_path)

from oxides_ml.dataset import OxidesGraphDataset
from oxides_ml.graph_tools import graph_plotter

### Dataset Configuration Parameters

In [None]:
# ============================================================================
# CONFIGURATION SECTION - Modify these parameters to analyze different datasets
# ============================================================================

# Dataset source and output directories
# Modify these paths to point to your VASP data and desired output directory.
# Environment variables take precedence over the placeholder defaults below.
VASP_DIRECTORY = os.environ.get(
    "VASP_DATA_DIR",
    "/path/to/your/VASP/data"  # Replace with your VASP data directory
)

GRAPH_DATASET_DIR = os.environ.get(
    "GRAPH_DATASET_DIR",
    "./graph_datasets"  # Replace with your desired output directory
)

# Graph structure parameters
TOLERANCE = 0.3  # Nearest-neighbor distance cutoff (Å)
SCALING_FACTOR = 1.25  # Unit cell scaling factor
SURFACE_ORDER = 2  # Order of nearest neighbors for surface classification (1, 2, 3, or custom)

# Node features to include in the dataset
INCLUDE_ADSORBATE_FLAG = False  # Flag indicating adsorbate vs surface atoms
INCLUDE_RADICAL = False  # Radical electron indicators
INCLUDE_VALENCE = False  # Valence electron counts
INCLUDE_CN = False  # Coordination number (highly recommended)
INCLUDE_MAGNETIZATION = False  # Spin polarization from DFT
INCLUDE_ADS_HEIGHT = False  # Adsorbate height above surface

# Target property and data options
TARGET_PROPERTY = "adsorption_energy"  # Property to predict (e.g., "adsorption_energy", "formation_energy")
USE_INITIAL_STATE = False  # Use initial structures (True) or relaxed structures (False)
USE_AUGMENTATION = False  # Data augmentation: include both initial and relaxed structures
FORCE_RELOAD = False  # Force reprocessing of graphs (slow, use for regenerating datasets)

# Number of cores for multiprocessing.
# os.cpu_count() returns None when the core count cannot be determined; fall
# back to a single core so a valid positive integer is always passed on as
# `ncores` when the dataset is built.
NUM_CORES = os.cpu_count() or 1

# ============================================================================
# NOTE: Set environment variables for different systems:
# export VASP_DATA_DIR=/path/to/your/data
# export GRAPH_DATASET_DIR=/path/to/output
# ============================================================================

print(f"VASP Directory: {VASP_DIRECTORY}")
print(f"Graph Dataset Directory: {GRAPH_DATASET_DIR}")

### Load and Display Dataset

In [None]:
# Assemble the nested parameter dictionary handed to OxidesGraphDataset from
# the configuration constants set in the configuration cell.
structure_params = {
    "tolerance": TOLERANCE,
    "scaling_factor": SCALING_FACTOR,
    "surface_order": SURFACE_ORDER,
}
feature_flags = {
    "adsorbate": INCLUDE_ADSORBATE_FLAG,
    "radical": INCLUDE_RADICAL,
    "valence": INCLUDE_VALENCE,
    "cn": INCLUDE_CN,
    "magnetization": INCLUDE_MAGNETIZATION,
    "ads_height": INCLUDE_ADS_HEIGHT,
}
graph_params = {
    "structure": structure_params,
    "features": feature_flags,
    "target": TARGET_PROPERTY,
}

# Echo the exact parameters used so the run is reproducible from the output.
print("Graph Parameters:")
print(json.dumps(graph_params, indent=2))

# Build the graph dataset from the VASP directory into GRAPH_DATASET_DIR.
# NOTE(review): presumably previously processed graphs are reused unless
# FORCE_RELOAD is True (per the configuration comment) — confirm in
# OxidesGraphDataset.
dataset = OxidesGraphDataset(
    VASP_DIRECTORY,
    GRAPH_DATASET_DIR,
    graph_params,
    ncores=NUM_CORES,
    initial_state=USE_INITIAL_STATE,
    augment=USE_AUGMENTATION,
    force_reload=FORCE_RELOAD,
)

print("\nDataset loaded successfully!")
print(f"Total number of graphs: {len(dataset)}")

## Dataset Inspection

In [None]:
# Inspect the first graph of the dataset and print its structure and metadata.
if len(dataset) == 0:
    # Fail with an actionable message instead of a bare index error.
    raise IndexError("Dataset is empty: no graphs were built from VASP_DIRECTORY")
first_graph = dataset[0]

# PyG's Data defines `x` as a property that returns None when no node-feature
# matrix is stored, so hasattr(first_graph, 'x') is always True and
# `.shape[1]` would then fail on None. Test for None (and a 2-D matrix) instead.
node_x = getattr(first_graph, "x", None)
if node_x is not None and len(node_x.shape) > 1:
    node_feat_dim = node_x.shape[1]
else:
    node_feat_dim = "N/A"

print("First graph properties:")
print(f"Keys: {list(first_graph.keys())}")
print(f"\nNumber of nodes: {first_graph.num_nodes}")
print(f"Number of edges: {first_graph.num_edges}")
print(f"Node feature dimension: {node_feat_dim}")
print(f"\nFormula: {first_graph.formula}")
print(f"Material: {first_graph.material}")
print(f"Adsorbate: {first_graph.adsorbate_name}")
print(f"Type: {first_graph.type}")
print(f"State: {first_graph.state}")
print(f"Target ({TARGET_PROPERTY}): {first_graph.target.item():.3f}")
print(f"Node features: {first_graph.node_feats}")

## Dataset Statistics

In [None]:
# Flatten per-graph metadata into a DataFrame: one row per graph, one column
# per attribute. Per-node/per-edge data and the target are excluded explicitly.
exclude_keys = {'edge_index', 'edge_attr', 'node_feats', 'x', 'elem', 'idx', 'adsorbate_indices', 'target', 'facet'}

rows = []
for data in dataset:
    row = {}
    for key in data.keys():
        if key in exclude_keys:
            continue
        value = data[key]
        if isinstance(value, torch.Tensor):
            # Only single-element tensors fit in a scalar table cell; .item()
            # raises on anything larger, so skip multi-element tensors that
            # are not listed in exclude_keys.
            if value.numel() != 1:
                continue
            value = value.item()
        row[key] = value
    rows.append(row)

df = pd.DataFrame(rows)

print(f"Dataset shape: {df.shape}")
print("\nDataset summary:")
print(df.describe())

## Graph Counts by Material, Type, and State

In [None]:
# Tabulate how many graphs fall into each material, type and state category.
category_counts = {}
for column, heading in (
    ("material", "Graphs by material:"),
    ("type", "\nGraphs by type:"),
    ("state", "\nGraphs by state:"),
):
    category_counts[column] = df[column].value_counts()
    print(heading)
    print(category_counts[column])

# Keep each per-column Series available under its own name for later cells.
material_counts = category_counts["material"]
type_counts = category_counts["type"]
state_counts = category_counts["state"]

## Target Property Statistics

In [None]:
# Gather the scalar target value of every graph as a plain Python list.
targets = [graph.target.item() for graph in dataset]

print(f"Target property: {TARGET_PROPERTY}")
# Summary statistics; np.std is the population standard deviation (ddof=0).
for label, reducer in (
    ("Mean", np.mean),
    ("Median", np.median),
    ("Std Dev", np.std),
    ("Min", np.min),
    ("Max", np.max),
):
    print(f"{label}: {reducer(targets):.4f} eV")

# Histogram of the target distribution.
fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(targets, bins=50, edgecolor='black')
ax.set(
    xlabel=f'{TARGET_PROPERTY} (eV)',
    ylabel='Count',
    title=f'Distribution of {TARGET_PROPERTY}',
)
plt.tight_layout()
plt.show()

## Graph Search and Visualization

In [None]:
# Helpers for locating graphs in the dataset by their metadata
def find_graphs_by_adsorbate(dataset, adsorbate_name):
    """Return the indices of all graphs whose ``adsorbate_name`` equals *adsorbate_name*.

    Graphs are visited in dataset order, so the indices come back ascending;
    an empty list means no graph matched.
    """
    return [
        position
        for position, graph in enumerate(dataset)
        if graph.adsorbate_name == adsorbate_name
    ]

def find_graphs_by_material(dataset, material):
    """Return the indices of all graphs whose ``material`` equals *material*.

    Indices are ascending (dataset order); an empty list means no match.
    """
    return [
        position
        for position, graph in enumerate(dataset)
        if graph.material == material
    ]

def find_graphs_by_criteria(dataset, material=None, adsorbate=None, energy_range=None):
    """Return the indices of graphs matching every supplied criterion.

    Args:
        dataset: Iterable of graphs exposing ``material``, ``adsorbate_name``
            and a single-element ``target`` (read via ``.item()``).
        material: Keep only graphs whose ``material`` equals this value.
        adsorbate: Keep only graphs whose ``adsorbate_name`` equals this value.
        energy_range: ``(min_e, max_e)`` pair; keep graphs whose target lies in
            the closed interval ``[min_e, max_e]``. May be a tuple, list or
            NumPy array.

    Criteria left as ``None`` are not applied. Indices are returned in
    ascending (dataset) order.

    Raises:
        ValueError: if ``energy_range`` does not unpack into exactly two values.
    """
    # Unpack once up front. `is not None` (rather than truthiness) is required
    # so a NumPy array range does not raise "truth value is ambiguous".
    if energy_range is not None:
        min_e, max_e = energy_range

    matches = []
    # Single pass: each graph is fetched from the dataset exactly once and
    # all filters are applied to it.
    for index, graph in enumerate(dataset):
        if material is not None and graph.material != material:
            continue
        if adsorbate is not None and graph.adsorbate_name != adsorbate:
            continue
        if energy_range is not None and not (min_e <= graph.target.item() <= max_e):
            continue
        matches.append(index)

    return matches

# Example: Find and visualize a specific graph
# Uncomment and modify to visualize graphs
# idx = find_graphs_by_adsorbate(dataset, "Acetylene")[0]
# graph_plotter(dataset[idx])

## Adsorbate Distribution

In [None]:
# Frequency of each adsorbate across all graphs (value_counts sorts descending).
adsorbate_counts = df['adsorbate_name'].value_counts()
top_adsorbates = adsorbate_counts.head(20)

print(f"\nNumber of unique adsorbates: {len(adsorbate_counts)}")
print("\nTop 20 most frequent adsorbates:")
print(top_adsorbates)

# Horizontal bar chart of the 20 most frequent adsorbates.
fig, ax = plt.subplots(figsize=(12, 8))
top_adsorbates.plot(kind='barh', ax=ax)
ax.set_xlabel('Count')
ax.set_title('Top 20 Most Frequent Adsorbates')
plt.tight_layout()
plt.show()