In [1]:
import sys
import os

# Make the repository root (the parent of the notebook's working directory) importable,
# so the `src.*` packages used below resolve.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))

# Guarded so that re-running this cell does not append a duplicate entry each time.
if project_root not in sys.path:
    sys.path.append(project_root)

In [2]:
import torch

from src.models.cgcnn import CGCNN
from src.xai.wrappers import CGCNNWrapper
from src.data.loader import get_loaders
from src.utils.io import load_model
from src.xai.gnn_explainer import explain_prediction
from src.utils.config import load_config

In [3]:
# Experiment configuration. In this notebook only `cfg.model.*` is read
# (the CGCNN constructor in the next cell).
yaml_pth = "../configs/test.yaml"

cfg = load_config(
    yaml_pth,
)

In [None]:
# Build the data splits and restore the trained CGCNN checkpoint for inference.
# NOTE(review): the loader settings below are hard-coded rather than read from `cfg`;
# confirm they match the training run so that `test_loader` covers the same held-out split.
train_loader, val_loader, test_loader = get_loaders(
    "../data/processed",
    target=["formation_energy_per_atom"],
    batch_size=16,
    num_workers=0,
    train_ratio=0.8,
    val_ratio=0.1,
    seed=42,
)

model = CGCNN(
    node_input_dim=cfg.model.node_input_dim,
    edge_input_dim=cfg.model.edge_input_dim,
    hidden_dim=cfg.model.hidden_dim,
    num_layers=cfg.model.num_layers,
    output_dim=cfg.model.output_dim,
)

# Fall back to CPU instead of failing on a hard-coded "cuda" when no GPU is present.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_model(model, "../checkpoints/cgcnn_test_best.pth", device=device)

# Inference only: switch off training-mode behaviour (e.g. dropout / batch-norm updates,
# if the model has such layers). `eval()` returns the module itself, so nothing is displayed.
model = model.eval()

In [None]:
# The sample to explain: the first item of the first test mini-batch.
# NOTE(review): assumes indexing a collated batch yields a single graph (PyG Batch); confirm.
SAMPLE_IDX = 0

data_iter = iter(test_loader)
first_batch = next(data_iter)
data = first_batch[SAMPLE_IDX]

print(data)

In [None]:
from src.xai.gnn_explainer import build_explainer

wrapped_model = CGCNNWrapper(model)

# Move the sample to whatever device the model lives on instead of assuming CUDA.
# Assumes `model` is a torch.nn.Module with at least one parameter.
device = next(model.parameters()).device
data = data.to(device)

explainer = build_explainer(wrapped_model, epochs=100)
explanation = explainer(
    x=data.x,
    edge_index=data.edge_index,
    edge_attr=data.edge_attr,
    target=data.y,
)

# Inspect the result: the first few edge-importance scores.
# Accessing a missing mask attribute raises, so look it up defensively; whether an edge
# mask is produced depends on how `build_explainer` configures the explainer.
edge_mask_raw = getattr(explanation, "edge_mask", None)
if edge_mask_raw is None:
    print("No edge mask produced; available:", explanation.available_explanations)
else:
    edge_mask = edge_mask_raw.detach().cpu()
    print("Edge importance:", edge_mask[:10])

In [None]:
from torch_geometric.explain import Explainer, GNNExplainer

# Graph-level explanation + visualisation of the same test sample used above.
# (`sample` was previously never defined, so this cell raised NameError on a fresh kernel.)
device = next(model.parameters()).device  # assumes `model` is an nn.Module with parameters
sample = data.to(device)

# A single graph sliced out of a Batch carries no `batch` vector. In PyG 2.x `Data.batch`
# is a property returning None in that case, so `hasattr` is always True and cannot be
# used as the check -- test for None instead and build an all-zeros (single-graph) batch.
batch = getattr(sample, "batch", None)
if batch is None:
    batch = torch.zeros(sample.x.size(0), dtype=torch.long, device=sample.x.device)

explainer = build_explainer(wrapped_model, epochs=100)
explanation = explainer(
    x=sample.x,
    edge_index=sample.edge_index,
    edge_attr=sample.edge_attr,
    batch=batch,
    target=sample.y,
)
# The node-classification call copied from the PyG example (`index=10`, without
# edge_attr/target) used to overwrite this graph-level explanation; it has been removed.
print(f'Generated explanations in {explanation.available_explanations}')

# Each plot needs a specific mask; which masks exist depends on `build_explainer`'s config.
if "node_mask" in explanation.available_explanations:
    path = 'feature_importance.png'
    explanation.visualize_feature_importance(path, top_k=10)
    print(f"Feature importance plot has been saved to '{path}'")
else:
    print("Skipping feature importance plot: explanation has no node_mask")

if "edge_mask" in explanation.available_explanations:
    path = 'subgraph.pdf'
    explanation.visualize_graph(path)
    print(f"Subgraph visualization plot has been saved to '{path}'")
else:
    print("Skipping subgraph plot: explanation has no edge_mask")

In [None]:
# End-to-end helper from src.xai.gnn_explainer: explain `data` with a freshly wrapped
# model and save the resulting figure next to the repository root.
explain_args = {
    "model": CGCNNWrapper(model),
    "data": data,
    "save_path": "../explain_sample0.png",
}
explain_prediction(**explain_args)