Added XGNN implementation to PyG #1

Merged: 126 commits, Dec 14, 2023

Commits (126):
daef679
initial commit
Nov 7, 2023
f58195f
added base files
Nov 7, 2023
c67a3c2
a
Nov 7, 2023
f8dd136
a
Nov 7, 2023
38597cf
a
Nov 7, 2023
f7e36e7
a
Nov 7, 2023
d168a79
XGNNExplainer example
BP4769 Nov 7, 2023
30b5a20
XGNNExplanation example
BP4769 Nov 7, 2023
027e201
Expaliner expanded to handle GenerativeExplanation
BP4769 Nov 7, 2023
8ce4ff9
changes
Nov 7, 2023
956d7e5
changes
Nov 7, 2023
82af31b
changes
Nov 7, 2023
5d1b38a
changes
Nov 7, 2023
3f4bf75
Updated explanation documentation
amadejp Nov 7, 2023
9cd21b0
modified: torch_geometric/explain/algorithm/xgnn_explainer.py
Nov 7, 2023
0c126dc
testing
BP4769 Nov 7, 2023
c830ba3
testing
BP4769 Nov 7, 2023
c856513
adapted generativeExplanaton
Nov 7, 2023
b4b6b1f
a
Nov 7, 2023
949b287
Commit to pull
BP4769 Nov 14, 2023
b47450a
Code cleaned
BP4769 Nov 14, 2023
bed350c
XGNNGenerator documentation update
amadejp Nov 14, 2023
6c97973
a
Nov 14, 2023
25d3bd1
class name reformat fix
amadejp Nov 14, 2023
c4d0862
XGNNExplainer docs update
amadejp Nov 15, 2023
9bc2f53
masks in GraphGenerator
BP4769 Nov 30, 2023
afafc2e
XGNNTrainer name fixed
BP4769 Nov 30, 2023
e5423fb
added datasets and classes, pretrained model
Nov 30, 2023
3af3d48
adjustment
Nov 30, 2023
9c86ddd
added pretrained model to example
Nov 30, 2023
ee675ee
test XGNNTrainer names fixed
BP4769 Nov 30, 2023
e49cf84
Merge branch 'dev' of https://github.com/SimonBele/pytorch_geometric …
BP4769 Nov 30, 2023
897a636
calculate reward function
amadejp Nov 30, 2023
4710d93
bla
Nov 30, 2023
e1e6b04
commit
BP4769 Nov 30, 2023
e5539f0
commit
BP4769 Nov 30, 2023
3a5acaf
calculate reward function
amadejp Nov 30, 2023
1e7b813
Merge branch 'dev' of https://github.com/SimonBele/pytorch_geometric …
amadejp Nov 30, 2023
1c9b0db
train_generative_model development
amadejp Nov 30, 2023
88ad513
new ExplanationType added
BP4769 Nov 30, 2023
af5c9c5
fixing semantic/other errors
amadejp Nov 30, 2023
5d17d86
debugging rl example
amadejp Dec 5, 2023
8019da1
dev
Dec 5, 2023
2448c0c
debugging
BP4769 Dec 5, 2023
e07939c
Update xgnn_explainer.py
amadejp Dec 5, 2023
8b5659a
dev
Dec 5, 2023
88f63e8
Merge branch 'dev' of github.com:SimonBele/pytorch_geometric into dev
Dec 5, 2023
6ede98f
a
Dec 5, 2023
597135c
debugging GraphGenerator and RL example in general
amadejp Dec 5, 2023
d221bc9
Merge branch 'dev' of https://github.com/SimonBele/pytorch_geometric …
amadejp Dec 5, 2023
01fac18
debigging
BP4769 Dec 7, 2023
125b5cf
refactored
Dec 7, 2023
a427234
mutag added, seperate directory created
Dec 7, 2023
45a1eaf
debugging
BP4769 Dec 7, 2023
a007d52
Merge branch 'dev' of https://github.com/SimonBele/pytorch_geometric …
BP4769 Dec 7, 2023
026465a
fixes
Dec 7, 2023
4ddfb25
a
Dec 7, 2023
d587682
objectview
Dec 7, 2023
d8bfaca
Update xgnn_explainer.py
amadejp Dec 7, 2023
5126563
candidate set extracted from data
BP4769 Dec 7, 2023
b74fd8d
debugging
amadejp Dec 7, 2023
321e4a0
Update xgnn_explainer.py
amadejp Dec 7, 2023
889d08d
Update xgnn_explainer.py
amadejp Dec 7, 2023
7502d10
debugging RL example
amadejp Dec 7, 2023
5add1fd
froze pretrained
Dec 7, 2023
2f20331
custom softmax function added
amadejp Dec 7, 2023
ae5b32d
simon hated doing this
BP4769 Dec 7, 2023
a9dcdf5
another model added
BP4769 Dec 11, 2023
ddccc22
code made prettier
BP4769 Dec 11, 2023
bf7af4d
debugging
BP4769 Dec 11, 2023
bf8f0bb
added weight initialization
Dec 12, 2023
ff0c21d
better training
Dec 12, 2023
160e71b
added pretrained model
Dec 12, 2023
95439e4
Fix XGNNExplainer to handle missing candidate_set argument
BP4769 Dec 12, 2023
351855f
Merge branch 'dev' of https://github.com/SimonBele/pytorch_geometric …
BP4769 Dec 12, 2023
490f95e
debugging + Add initial_node_type parameter to RLGenExplainer constru…
BP4769 Dec 12, 2023
787553f
refactored xgnnexplainer
SimonBele Dec 12, 2023
76c7618
refactor fix
amadejp Dec 12, 2023
b2a09b1
debugging
BP4769 Dec 12, 2023
111b0ad
Merge branch 'dev' of https://github.com/SimonBele/pytorch_geometric …
BP4769 Dec 12, 2023
b8c5aeb
a
SimonBele Dec 12, 2023
758e741
Merge branch 'dev' of github.com:SimonBele/pytorch_geometric into dev
SimonBele Dec 12, 2023
a3e38fc
fixed validate
SimonBele Dec 12, 2023
694483a
xgnn model fixing
amadejp Dec 12, 2023
a7a7280
model args and .pth updated
amadejp Dec 12, 2023
c60bd05
single graph prediction support in pretrained
SimonBele Dec 12, 2023
64c243d
rl example gnn_output updated
amadejp Dec 12, 2023
a621ee0
quickfix
amadejp Dec 12, 2023
ebf365f
Fix bug in login functionality
BP4769 Dec 12, 2023
4fcd85d
added generatemodel base class
SimonBele Dec 12, 2023
f272748
changed generativeExplanation
SimonBele Dec 12, 2023
2b7eba1
merge conflict resolved + minor changes
BP4769 Dec 12, 2023
c9a4509
added base class
SimonBele Dec 13, 2023
a95e6a0
Merge branch 'dev' of github.com:SimonBele/pytorch_geometric into dev
SimonBele Dec 13, 2023
dcfaa50
fixed imports
SimonBele Dec 13, 2023
3f9ed5c
changed generative explanation, explanation set sampler
SimonBele Dec 13, 2023
30259da
a
SimonBele Dec 13, 2023
3e6aa0e
aaa
SimonBele Dec 13, 2023
427cd7e
updated graphGenerator forward logic
amadejp Dec 13, 2023
17861cb
a
SimonBele Dec 13, 2023
c15efd8
no list anymore
BP4769 Dec 13, 2023
e736c31
aAaaaa
SimonBele Dec 13, 2023
20c1abc
aAaaaa
SimonBele Dec 13, 2023
4195c2d
small hotfixes
amadejp Dec 13, 2023
b173463
Update xgnn_explainer.py
amadejp Dec 13, 2023
3576feb
Added graph visualisations
BP4769 Dec 13, 2023
d5089e2
Merge branch 'dev' of https://github.com/SimonBele/pytorch_geometric …
BP4769 Dec 13, 2023
1c5e73f
Refactor graph sampling and training process
amadejp Dec 13, 2023
41ab8e4
something works :)
amadejp Dec 13, 2023
b64a03b
Graph visualisations improved
BP4769 Dec 13, 2023
fd6d8ca
merge
BP4769 Dec 13, 2023
82d6852
Code cleanup + example almost completed with multiple graphs of diffe…
BP4769 Dec 13, 2023
4946a90
Refactor xgnn_explainer.py: Clean up code, remove unused functions, a…
amadejp Dec 13, 2023
ff17e40
Delete unnecessary code and unused imports
amadejp Dec 13, 2023
bbc377d
Directory cleanup
amadejp Dec 13, 2023
73d7aaa
Better visualisation
BP4769 Dec 13, 2023
301c499
Merge branch 'dev' of https://github.com/SimonBele/pytorch_geometric …
BP4769 Dec 13, 2023
08cc25f
prettier code
SimonBele Dec 13, 2023
0299668
Visualisation finnished
BP4769 Dec 13, 2023
ceaf166
Merge branch 'dev' of https://github.com/SimonBele/pytorch_geometric …
BP4769 Dec 13, 2023
953b28d
Class & Function descriptions reformat
amadejp Dec 13, 2023
58167b2
code & dir cleanup
amadejp Dec 14, 2023
1530062
Added pytest for generative explanation
amadejp Dec 14, 2023
6094353
PEP8 reformat
amadejp Dec 14, 2023
6165fd5
Update CHANGELOG.md
amadejp Dec 14, 2023
d6a24a2
Merge branch 'master' into dev
SimonBele Dec 14, 2023
1 change: 0 additions & 1 deletion .gitignore
@@ -28,6 +28,5 @@ examples/**/*.png
 examples/**/*.pdf
 benchmark/results/
 .mypy_cache/
-
 !torch_geometric/data/
 !test/data/
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -38,6 +38,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added the `RCDD` dataset ([#8196](https://github.com/pyg-team/pytorch_geometric/pull/8196))
 - Added distributed `GAT + ogbn-products` example targeting XPU device ([#8032](https://github.com/pyg-team/pytorch_geometric/pull/8032))
 - Added the option to skip explanations of certain message passing layers via `conv.explain = False` ([#8216](https://github.com/pyg-team/pytorch_geometric/pull/8216))
+- Added XGNN implementation for graph explanation to `explain` module

 ### Changed
4 changes: 4 additions & 0 deletions docs/source/modules/explain.rst
@@ -55,6 +55,10 @@ Explanations
    :show-inheritance:
    :members:

+.. autoclass:: torch_geometric.explain.GenerativeExplanation
+   :show-inheritance:
+   :members:
+
 Explainer Algorithms
 --------------------
Binary file added examples/explain/xgnn/mutag_model.pth
Binary file not shown.
539 changes: 539 additions & 0 deletions examples/explain/xgnn/xgnn_explainer.py

Large diffs are not rendered by default.
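
Since the 539-line example file is not rendered above, here is a minimal sketch (not part of the diff) of the overall shape such a script takes, based on the `Explainer`/`XGNNExplainer` API exercised by the tests added in this PR. `SketchExplainer` and the placeholder model are hypothetical; the real example trains an RL-based graph generator with a reward computed from the pretrained MUTAG classifier, which is omitted here.

import torch
from torch_geometric.data import Data
from torch_geometric.explain import Explainer, XGNNExplainer


class SketchExplainer(XGNNExplainer):
    # Hypothetical stand-in: the real example trains a graph generator with
    # policy-gradient RL; here we only return a fixed graph to show the contract.
    def train_generative_model(self, model_to_explain, for_class, **kwargs):
        return Data(x=torch.eye(3), edge_index=torch.tensor([[0, 1], [1, 2]]))


model = torch.nn.Linear(7, 1)  # placeholder for the pretrained MUTAG classifier

explainer = Explainer(
    model=model,
    algorithm=SketchExplainer(epochs=10, lr=0.01),
    explanation_type='generative',
    model_config=dict(
        mode='binary_classification',
        task_level='graph',
        return_type='probs',
    ),
)

# x / edge_index are unused by the generative path; the class of interest
# is selected via `for_class`.
explanation = explainer(x=torch.tensor([]), edge_index=torch.tensor([[], []]),
                        for_class=1)
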

65 changes: 65 additions & 0 deletions examples/explain/xgnn/xgnn_model.py
@@ -0,0 +1,65 @@
from torch_geometric.nn import GCNConv
import torch
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool
from torch.nn.parameter import Parameter
import math

### GCN to predict graph property
class GCN_Graph(torch.nn.Module):
    def __init__(self, input_dim, output_dim, dropout, emb=False):
        super(GCN_Graph, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        self.dropout = dropout
        self.convs = torch.nn.ModuleList([
            GCNConv(in_channels=input_dim, out_channels=32),
            GCNConv(in_channels=32, out_channels=48),
            GCNConv(in_channels=48, out_channels=64),
        ])

        self.pool = global_mean_pool  # global averaging to obtain graph representation

        self.fc1 = torch.nn.Linear(64, 32)
        self.fc2 = torch.nn.Linear(32, output_dim)

        self.loss = torch.nn.BCEWithLogitsLoss()
        self.reset_parameters()

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
            stdv = 1. / math.sqrt(conv.lin.weight.size(1))
            torch.nn.init.uniform_(conv.lin.weight, -stdv, stdv)

            conv.bias = Parameter(torch.FloatTensor(conv.out_channels))
            conv.bias.data.uniform_(-stdv, stdv)

        self.fc1.reset_parameters()
        self.fc2.reset_parameters()

    def forward(self, data):
        # Extract important attributes of our mini-batch
        x, edge_index = data.x, data.edge_index

        for i in range(len(self.convs)):
            x = F.relu(self.convs[i](x, edge_index))
            if i < len(self.convs) - 1:  # do not apply dropout on last layer
                x = F.dropout(x, p=self.dropout, training=self.training)

        # Check if 'batch' attribute is present
        if hasattr(data, 'batch'):
            batch = data.batch
        else:
            # For a single graph, use a zero tensor as the batch vector,
            # where its size equals the number of nodes.
            batch = torch.zeros(data.num_nodes, dtype=torch.long, device=x.device)

        x = self.pool(x, batch)

        x = F.relu(self.fc1(x))
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.fc2(x)
        # x = F.sigmoid(x)
        # x = F.softmax(x, dim=1)
        return x
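
As a quick usage note (not part of the diff): assuming the file above is importable as `xgnn_model`, the classifier can be run on a single toy graph roughly as follows; the node features and edges here are placeholders, not MUTAG data.

import torch
from torch_geometric.data import Data
from xgnn_model import GCN_Graph  # assumes the file above is on the path

model = GCN_Graph(input_dim=7, output_dim=1, dropout=0.1)
model.eval()

# A toy graph with 5 nodes and 7-dimensional (MUTAG-style one-hot) features.
data = Data(x=torch.rand(5, 7),
            edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]))

with torch.no_grad():
    logit = model(data)          # shape [1, 1]: one logit for the whole graph
    prob = torch.sigmoid(logit)  # probability of the positive class
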
125 changes: 125 additions & 0 deletions examples/explain/xgnn/xgnn_train.py
@@ -0,0 +1,125 @@
from torch_geometric.data import Batch
from torch_geometric.datasets import TUDataset
import torch
import torch.optim as optim
import numpy as np
from tqdm import trange
import copy
from tqdm.auto import trange
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
from xgnn_model import GCN_Graph

seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

def create_single_batch(dataset):
    data_list = [data for data in dataset]
    batched_data = Batch.from_data_list(data_list)
    return batched_data

def test(test_dataset, model):
    model.eval()
    with torch.no_grad():
        logits = model(test_dataset).squeeze()  # Logits for each graph
        probabilities = torch.sigmoid(logits)   # Convert logits to probabilities
        predictions = probabilities > 0.5       # Convert probabilities to binary predictions
        correct = (predictions == test_dataset.y).float()  # Assumes labels are 0 or 1
        accuracy = correct.mean()

    return accuracy


def train(dataset, args, train_indices, val_indices, test_indices):
    # Split dataset into training and testing (validation is not used here)
    train_dataset = create_single_batch([dataset[i] for i in train_indices]).to(device)
    test_dataset = create_single_batch([dataset[i] for i in test_indices]).to(device)

    # Model initialization
    model = GCN_Graph(args.input_dim, output_dim=1, dropout=args.dropout).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Training loop
    losses = []
    test_accs = []
    best_acc = 0
    best_model = None
    for epoch in trange(args.epochs, desc="Training", unit="Epoch"):
        model.train()
        opt.zero_grad()

        pred = model(train_dataset)
        label = train_dataset.y.float()
        loss = model.loss(pred.squeeze(), label)
        loss.backward()
        opt.step()
        total_loss = loss.item()
        losses.append(total_loss)

        # Test accuracy
        if epoch % 10 == 0:
            test_acc = test(test_dataset, model)

            test_accs.append(test_acc)
            if test_acc > best_acc:
                best_acc = test_acc
                best_model = copy.deepcopy(model)
        else:
            test_accs.append(test_accs[-1])

    return test_accs, losses, best_model, best_acc

class objectview(object):
    def __init__(self, d):
        self.__dict__ = d

device = 'cuda' if torch.cuda.is_available() else 'cpu'

args = {'device': device,
        'dropout': 0.1,
        'epochs': 5000,
        'input_dim': 7,
        'opt': 'adam',
        'opt_restart': 0,
        'weight_decay': 1e-4,
        'lr': 0.007}

args = objectview(args)

dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')
num_graphs = len(dataset)

# Define split percentages
train_percentage = 0.7
val_percentage = 0.0

# Calculate split sizes
train_size = int(num_graphs * train_percentage)
val_size = int(num_graphs * val_percentage)
test_size = num_graphs - train_size - val_size

# Create shuffled indices
indices = np.random.permutation(num_graphs)
train_indices = indices[:train_size]
val_indices = indices[train_size:train_size + val_size]
test_indices = indices[train_size + val_size:]

test_accs, losses, best_model, best_acc = train(dataset, args, train_indices, val_indices, test_indices)

try:
    torch.save(best_model.state_dict(), 'examples/explain/xgnn/mutag_model.pth')
    print("Model saved successfully.")
except Exception as e:
    print("Error saving model:", e)

print("Maximum test set accuracy: {0}".format(max(test_accs)))
print("Minimum loss: {0}".format(min(losses)))

plt.title(dataset.name)
plt.plot(losses, label="training loss")
plt.plot(test_accs, label="test accuracy")
plt.legend()
plt.show()
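
For completeness (again not part of the diff), the checkpoint saved above would typically be restored for use by the XGNN example roughly as follows; the constructor arguments must match those used during training (input_dim=7, output_dim=1, dropout=0.1), and freezing the parameters mirrors the "froze pretrained" commit. This is a sketch, not the example's exact loading code.

import torch
from xgnn_model import GCN_Graph  # assumes xgnn_model.py is importable

# Re-create the architecture with the training hyperparameters, restore the
# saved weights, and freeze the model so only the generator is trained.
model = GCN_Graph(input_dim=7, output_dim=1, dropout=0.1)
state_dict = torch.load('examples/explain/xgnn/mutag_model.pth',
                        map_location='cpu')
model.load_state_dict(state_dict)
model.eval()
for param in model.parameters():
    param.requires_grad = False
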
40 changes: 40 additions & 0 deletions test/explain/algorithm/test_xgnn_explainer.py
@@ -0,0 +1,40 @@
import pytest
import torch
from torch_geometric.explain import XGNNExplainer, GenerativeExplanation
from abc import abstractmethod

# Mock subclass of XGNNExplainer for testing
class MockXGNNExplainer(XGNNExplainer):
    def train_generative_model(self, model, for_class, **kwargs):
        return None

@pytest.fixture
def model():
    return torch.nn.Linear(3, 2)

def test_xgnn_explainer_initialization():
    explainer = MockXGNNExplainer(epochs=200, lr=0.005)
    assert explainer.epochs == 200
    assert explainer.lr == 0.005

def test_xgnn_explainer_forward(model):
    explainer = MockXGNNExplainer()
    x = torch.rand(10, 3)
    edge_index = torch.randint(0, 10, (2, 30))
    target = torch.randint(0, 2, (10,))

    explanation = explainer(model, x, edge_index, target=target, for_class=1)
    assert isinstance(explanation, GenerativeExplanation)

    # Test ValueError for missing 'for_class' argument
    with pytest.raises(ValueError):
        explainer(model, x, edge_index, target=target)

def test_xgnn_explainer_abstract_method():
    class IncompleteExplainer(XGNNExplainer):
        pass
    explainer = IncompleteExplainer()

    # Calling the unimplemented abstract method should raise NotImplementedError
    with pytest.raises(NotImplementedError):
        explainer.train_generative_model(None, for_class=0)
3 changes: 2 additions & 1 deletion test/explain/test_explainer.py
@@ -60,7 +60,8 @@ def test_forward(data, target, explanation_type):
     assert isinstance(explanation, Explanation)
     assert 'x' in explanation
     assert 'edge_index' in explanation
-    assert 'target' in explanation
+    if explanation_type != ExplanationType.generative:  # target is not used for generative explanation
+        assert 'target' in explanation
     assert 'node_mask' in explanation.available_explanations
     assert explanation.node_mask.size() == data.x.size()
96 changes: 96 additions & 0 deletions test/explain/test_generative_explanation.py
@@ -0,0 +1,96 @@
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.explain import Explainer, XGNNExplainer


# Mock model for testing
class MLP_Graph(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(MLP_Graph, self).__init__()
        self.fc1 = nn.Linear(input_dim, 8)
        self.fc2 = nn.Linear(8, output_dim)

    def forward(self, x):
        # Flatten the graph representation
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# Mock explainer algorithm
class ExampleExplainer(XGNNExplainer):
    def __init__(self, epochs, lr, candidate_set, validity_args):
        super(ExampleExplainer, self).__init__()
        self.epochs = epochs
        self.lr = lr
        self.candidate_set = candidate_set
        self.validity_args = validity_args

    def train_generative_model(self, model_to_explain, for_class):
        # For simplicity, this example does not include actual training logic
        for epoch in range(self.epochs):
            # Placeholder for training logic
            pass

        return Data()


# Mock graph generator
class ExampleGraphGenerator():
    def __init__(self, graph):
        self.graph = graph

    def sample(self):
        # has to return a list of Data objects
        return [Data(), Data(), Data()]


# Fixture for setting up XGNNExplainer
@pytest.fixture
def setup_xgnn_explainer():
    mock_model = MLP_Graph(input_dim=7, output_dim=1)

    explainer = Explainer(
        model=mock_model,
        algorithm=ExampleExplainer(
            epochs=10,
            lr=0.01,
            candidate_set={'C': torch.tensor([1, 0, 0, 0, 0, 0, 0])},  # Simplified candidate set
            validity_args={'C': 4}),
        explanation_type='generative',
        model_config=dict(
            mode='binary_classification',
            task_level='graph',
            return_type='probs',
        )
    )

    class_index = 1
    x = torch.tensor([])
    edge_index = torch.tensor([[], []])

    return explainer, x, edge_index, class_index


# Test output of XGNNExplainer
def test_explainer_output(setup_xgnn_explainer):
    explainer, x, edge_index, class_index = setup_xgnn_explainer
    explanation = explainer(x, edge_index, for_class=class_index)

    # Check if explanation is of type Data
    assert isinstance(explanation, Data), "Explanation is not of type Data"


# Test output of ExampleGraphGenerator
def test_sampler_output():
    sampled_graphs = ExampleGraphGenerator(Data()).sample()

    # Check if sampled_graphs is a list of Data objects
    assert isinstance(sampled_graphs, list), "Sampled graphs is not a list"
    assert all(isinstance(graph, Data) for graph in sampled_graphs), \
        "Sampled graphs is not a list of Data objects"
4 changes: 3 additions & 1 deletion torch_geometric/explain/__init__.py
@@ -1,5 +1,5 @@
 from .config import ExplainerConfig, ModelConfig, ThresholdConfig
-from .explanation import Explanation, HeteroExplanation
+from .explanation import Explanation, HeteroExplanation, GenerativeExplanation, ExplanationSetSampler
 from .algorithm import *  # noqa
 from .explainer import Explainer
 from .metric import *  # noqa
@@ -11,4 +11,6 @@
     'Explanation',
     'HeteroExplanation',
     'Explainer',
+    'GenerativeExplanation',
+    'ExplanationSetSampler',
 ]
2 changes: 2 additions & 0 deletions torch_geometric/explain/algorithm/__init__.py
@@ -5,6 +5,7 @@
 from .pg_explainer import PGExplainer
 from .attention_explainer import AttentionExplainer
 from .graphmask_explainer import GraphMaskExplainer
+from .xgnn_explainer import XGNNExplainer

 __all__ = classes = [
     'ExplainerAlgorithm',
@@ -14,4 +15,5 @@
     'PGExplainer',
     'AttentionExplainer',
     'GraphMaskExplainer',
+    'XGNNExplainer'
 ]