diff --git a/.gitignore b/.gitignore index c38a9ed6d2d1..70acc22727f5 100644 --- a/.gitignore +++ b/.gitignore @@ -28,6 +28,5 @@ examples/**/*.png examples/**/*.pdf benchmark/results/ .mypy_cache/ - !torch_geometric/data/ !test/data/ diff --git a/CHANGELOG.md b/CHANGELOG.md index ab342b3e598a..7a35a1af9ecd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added the `RCDD` dataset ([#8196](https://github.com/pyg-team/pytorch_geometric/pull/8196)) - Added distributed `GAT + ogbn-products` example targeting XPU device ([#8032](https://github.com/pyg-team/pytorch_geometric/pull/8032)) - Added the option to skip explanations of certain message passing layers via `conv.explain = False` ([#8216](https://github.com/pyg-team/pytorch_geometric/pull/8216)) +- Added XGNN implementation for graph explanation to `explain` module ### Changed diff --git a/docs/source/modules/explain.rst b/docs/source/modules/explain.rst index 27128c2a137d..9b8a10e4dedc 100644 --- a/docs/source/modules/explain.rst +++ b/docs/source/modules/explain.rst @@ -55,6 +55,10 @@ Explanations :show-inheritance: :members: +.. autoclass:: torch_geometric.explain.GenerativeExplanation + :show-inheritance: + :members: + Explainer Algorithms -------------------- diff --git a/examples/explain/xgnn/mutag_model.pth b/examples/explain/xgnn/mutag_model.pth new file mode 100644 index 000000000000..d8c54c99e4d5 Binary files /dev/null and b/examples/explain/xgnn/mutag_model.pth differ diff --git a/examples/explain/xgnn/xgnn_explainer.py b/examples/explain/xgnn/xgnn_explainer.py new file mode 100644 index 000000000000..60b9e6019362 --- /dev/null +++ b/examples/explain/xgnn/xgnn_explainer.py @@ -0,0 +1,539 @@ +import copy +import os.path as osp +import os + +import torch +import torch.nn.functional as F +from torch_geometric.data import Data +from torch_geometric.explain import Explainer, XGNNExplainer, ExplanationSetSampler +from torch_geometric.nn import GCNConv +from torch_geometric.datasets import TUDataset +from torch_geometric.utils import to_networkx +import networkx as nx +import matplotlib.pyplot as plt +from tqdm import trange +from xgnn_model import GCN_Graph + +def masked_softmax(vector, mask): + """ + Apply softmax to only selected elements of the vector, as indicated by the mask. + The output will be a probability distribution where unselected elements are 0. + + Args: + vector (torch.Tensor): Input vector. + mask (torch.Tensor): Mask indicating which elements to softmax. + + Returns: + torch.Tensor: Softmaxed vector. + """ + mask = mask.bool() + masked_vector = vector.masked_fill(~mask, float('-inf')) + softmax_result = F.softmax(masked_vector, dim=0) + + return softmax_result + +class GraphGenerator(torch.nn.Module, ExplanationSetSampler): + """ + Graph generator that generates a new graph state from a given graph state. + + Inherits: + torch.nn.Module: Base class for all neural network modules. + ExplanationSetSampler: Base class for sampling from an explanation set. + + Args: + candidate_set (dict): Set of candidate nodes for graph generation. + dropout (float): Dropout rate for regularization. + initial_node_type (str, optional): Initial node type for graph initialization. 
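+
+    Example (an illustrative sketch; the two-feature candidate set below is
+    hypothetical, any one-hot encoded candidate set works)::
+
+        candidate_set = {'C': torch.tensor([1., 0.]),
+                         'N': torch.tensor([0., 1.])}
+        generator = GraphGenerator(candidate_set, dropout=0.1)
+        graphs = generator.sample(num_samples=2, max_steps=5)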
+    """
+    def __init__(self, candidate_set, dropout, initial_node_type = None):
+        super(GraphGenerator, self).__init__()
+        self.candidate_set = candidate_set
+        self.initial_node_type = initial_node_type
+        self.dropout = dropout
+        num_node_features = len(next(iter(self.candidate_set.values())))
+        self.gcn_layers = torch.nn.ModuleList([
+            GCNConv(num_node_features, 16),
+            GCNConv(16, 24),
+            GCNConv(24, 32),
+        ])
+
+        self.mlp_start_node = torch.nn.Sequential(
+            torch.nn.Linear(32, 16),
+            torch.nn.Linear(16, 1),
+            torch.nn.ReLU6()
+        )
+        self.mlp_end_node = torch.nn.Sequential(
+            torch.nn.Linear(32, 24),
+            torch.nn.Linear(24, 1),
+            torch.nn.ReLU6()
+        )
+
+    def initialize_graph_state(self, graph_state):
+        r"""
+        Initializes the graph state in place with a single node. If no
+        initial node type was given, one is drawn uniformly at random from
+        the candidate set.
+
+        Args:
+            graph_state (torch_geometric.data.Data): The graph state to initialize.
+        """
+        if self.initial_node_type is None:
+            keys = list(self.candidate_set.keys())
+            self.initial_node_type = keys[torch.randint(len(keys), (1,)).item()]
+
+        feature = self.candidate_set[self.initial_node_type].unsqueeze(0)
+        edge_index = torch.tensor([], dtype=torch.long).view(2, -1)
+        node_type = [self.initial_node_type,]
+        # update graph state
+        graph_state.x = feature
+        graph_state.edge_index = edge_index
+        graph_state.node_type = node_type
+
+    def forward(self, graph_state):
+        """
+        Generates a new graph state from the given graph state.
+
+        Args:
+            graph_state (torch_geometric.data.Data): The graph state to generate a new graph state from.
+
+        Returns:
+            ((torch.Tensor, torch.Tensor), (torch.Tensor, torch.Tensor)): The logits and one-hot encodings for the start and end node.
+            torch_geometric.data.Data: The new graph state.
+        """
+        graph_state = copy.deepcopy(graph_state)
+        if graph_state.x.shape[0] == 0:
+            self.initialize_graph_state(graph_state)  # initialize graph state if it is empty
+
+        # concatenate graph_state features with candidate_set features
+        node_features_graph = graph_state.x.detach().clone()
+        candidate_features = torch.stack(list(self.candidate_set.values()))
+        node_features = torch.cat((node_features_graph, candidate_features), dim=0).float()
+        node_edges = graph_state.edge_index.detach().clone()
+
+        # compute node encodings with GCN layers
+        node_encodings = node_features
+        for gcn_layer in self.gcn_layers:
+            node_encodings = F.relu6(gcn_layer(node_encodings, node_edges))
+            node_encodings = F.dropout(node_encodings, self.dropout, training=self.training)
+
+        # get start node probabilities and mask out the candidate nodes, so
+        # that the start node is always sampled from the current graph
+        start_node_logits = self.mlp_start_node(node_encodings)
+        candidate_set_mask = torch.ones_like(start_node_logits)
+        candidate_set_indices = torch.arange(node_features_graph.shape[0], node_encodings.shape[0])
+        candidate_set_mask[candidate_set_indices] = 0
+        start_node_probs = masked_softmax(start_node_logits, candidate_set_mask).squeeze()
+
+        # sample start node
+        p_start = torch.distributions.Categorical(start_node_probs)
+        start_node = p_start.sample()
+
+        # get end node probabilities and mask out start node
+        end_node_logits = self.mlp_end_node(node_encodings)
+        start_node_mask = torch.ones_like(end_node_logits)
+        start_node_mask[start_node] = 0
+        end_node_probs = masked_softmax(end_node_logits, start_node_mask).squeeze()
+
+        # sample end node
+        end_node = torch.distributions.Categorical(end_node_probs).sample()
+        num_nodes_graph = graph_state.x.shape[0]
+        if end_node >= num_nodes_graph:
+            # add new 
node features to graph state + graph_state.x = torch.cat([graph_state.x, node_features[end_node].unsqueeze(0).float()], dim=0) + graph_state.node_type.append(list(self.candidate_set.keys())[end_node - num_nodes_graph]) + new_edge = torch.tensor([[start_node], [num_nodes_graph]]) + else: + new_edge = torch.tensor([[start_node], [end_node]]) + graph_state.edge_index = torch.cat((graph_state.edge_index, new_edge), dim=1) + + # one hot encoding of start and end node + start_node_one_hot = torch.eye(start_node_probs.shape[0])[start_node] + end_node_one_hot = torch.eye(end_node_probs.shape[0])[end_node] + + return ((start_node_logits.squeeze(), start_node_one_hot), (end_node_logits.squeeze(), end_node_one_hot)), graph_state + + def sample(self, num_samples: int, **kwargs): + """ + Samples a number of graphs from the generator. + + Args: + num_samples (int): The number of graphs to sample. + **kwargs: Additional keyword arguments. + + Raises: + ValueError: If neither num_nodes nor max_steps is specified. + + Returns: + List[torch_geometric.data.Data]: The list of sampled graphs. + """ + # extract num_nodes and max_steps from kwargs or set them to None + num_nodes = kwargs.get('num_nodes', None) + max_steps = kwargs.get('max_steps', None) + + # check that either num_nodes or max_steps is not None + if num_nodes is None and max_steps is None: + raise ValueError("Either num_nodes or max_steps must be specified") + + # create empty graph state + empty_graph = Data(x=torch.tensor([]), edge_index=torch.tensor([]), node_type=[]) + current_graph_state = copy.deepcopy(empty_graph) + + # sample graphs + sampled_graphs = [] + + max_steps_reached = False + num_nodes_reached = False + self.eval() + for _ in range(num_samples): + step = 0 + while not max_steps_reached and not num_nodes_reached: + G = copy.deepcopy(current_graph_state) + ((p_start, a_start), (p_end, a_end)), current_graph_state = self.forward(G) + step += 1 + # check if max_steps is reached (if max_steps is None, this will never be True) + max_steps_reached = max_steps is not None and step >= max_steps + # check if num_nodes is reached + num_nodes_reached = num_nodes is not None and current_graph_state.x.shape[0] > num_nodes + # add sampled graph to list + sampled_graphs.append(G) + # reset current graph state + current_graph_state = copy.deepcopy(empty_graph) + # reset max_steps_reached and num_nodes_reached + max_steps_reached = False + num_nodes_reached = False + return sampled_graphs + +class RLGenExplainer(XGNNExplainer): + """ RL-based generator for graph explanations using XGNN. + + Inherits: + XGNNExplainer: Base class for explanation generation using XGNN method. + + Args: + epochs (int): Number of training epochs. + lr (float): Learning rate. + candidate_set (dict): Set of candidate nodes for graph generation. + validity_args (dict): Arguments for graph validity check. + initial_node_type (str, optional): Initial node type for graph initialization. + """ + def __init__(self, epochs, lr, candidate_set, validity_args, initial_node_type = None): + super(RLGenExplainer, self).__init__(epochs, lr) + self.candidate_set = candidate_set + self.graph_generator = GraphGenerator(candidate_set, 0.1, initial_node_type) + self.max_steps = 10 + self.lambda_1 = 1 + self.lambda_2 = 1 + self.num_classes = 2 + self.validity_args = validity_args + + + def reward_tf(self, pre_trained_gnn, graph_state, num_classes): + r""" + Computes the reward for the given graph state by evaluating the graph with the pre-trained GNN. 
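+
+        Following the reward design from the XGNN paper, the returned value
+        is the predicted class probability shifted by the chance level,
+        i.e. p - 1/num_classes, so graphs scored above chance receive a
+        positive reward. In this binary setup the sigmoid output is the
+        probability of class 1; the callers negate the reward when the
+        target class is 0.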
+
+        Args:
+            pre_trained_gnn (torch.nn.Module): The pre-trained GNN to use for computing the reward.
+            graph_state (torch_geometric.data.Data): The graph state to compute the reward for.
+            num_classes (int): The number of classes in the dataset.
+
+        Returns:
+            torch.Tensor: The reward for the given graph state.
+        """
+        with torch.no_grad():
+            gnn_output = pre_trained_gnn(graph_state)
+            probability_of_target_class = torch.sigmoid(gnn_output).squeeze()
+
+        return probability_of_target_class - 1 / num_classes
+
+    def rollout_reward(self, intermediate_graph_state, pre_trained_gnn, target_class, num_classes, num_rollouts=5):
+        r"""
+        Computes the rollout reward for the given graph state.
+
+        Args:
+            intermediate_graph_state (torch_geometric.data.Data): The intermediate graph state to compute the rollout reward for.
+            pre_trained_gnn (torch.nn.Module): The pre-trained GNN to use for computing the reward.
+            target_class (int): The target class to explain.
+            num_classes (int): The number of classes in the dataset.
+            num_rollouts (int): The number of rollouts to perform.
+
+        Returns:
+            torch.Tensor: The average rollout reward for the given graph state.
+        """
+        final_rewards = []
+        for _ in range(num_rollouts):
+            # make a copy of the intermediate graph state and grow it into a final graph
+            intermediate_graph_state_copy = copy.deepcopy(intermediate_graph_state)
+            _, final_graph = self.graph_generator(intermediate_graph_state_copy)
+            # Evaluate the final graph
+            reward = self.reward_tf(pre_trained_gnn, final_graph, num_classes)
+            if target_class == 0:
+                # reward_tf scores class 1; in the binary case,
+                # p(class 0) - 1/2 = -(p(class 1) - 1/2)
+                reward = -reward
+            final_rewards.append(reward)
+
+            del intermediate_graph_state_copy
+
+        average_final_reward = sum(final_rewards) / len(final_rewards)
+        return average_final_reward
+
+    def evaluate_graph_validity(self, graph_state):
+        r"""
+        Evaluates the validity of the given graph state. Dataset-specific graph rules are implemented here.
+
+        Args:
+            graph_state (torch_geometric.data.Data): The graph state to evaluate.
+
+        Returns:
+            int: The graph validity score. 0 if the graph is valid, -1 otherwise.
+        """
+        # For MUTAG, node degrees cannot exceed the valency of the atom type
+        degrees = torch.bincount(graph_state.edge_index.flatten(), minlength=graph_state.num_nodes)
+        node_type_valencies = torch.tensor([self.validity_args[type_] for type_ in graph_state.node_type])
+        if torch.any(degrees > node_type_valencies):
+            return -1
+        return 0
+
+    def calculate_reward(self, graph_state, pre_trained_gnn, target_class, num_classes):
+        r"""
+        Calculates the reward for the given graph state.
+
+        Args:
+            graph_state (torch_geometric.data.Data): The graph state to compute the reward for.
+            pre_trained_gnn (torch.nn.Module): The pre-trained GNN to use for computing the reward.
+            target_class (int): The target class to explain.
+            num_classes (int): The number of classes in the dataset.
+
+        Returns:
+            torch.Tensor: The final reward for the given graph state.
+        """
+        intermediate_reward = self.reward_tf(pre_trained_gnn, graph_state, num_classes)
+        if target_class == 0:
+            # see rollout_reward: flip the sign of the class-1 reward for class 0
+            intermediate_reward = -intermediate_reward
+        final_graph_reward = self.rollout_reward(graph_state, pre_trained_gnn, target_class, num_classes)
+        # Compute graph validity score (R_tr), based on the specific graph rules of the dataset
+        graph_validity_score = self.evaluate_graph_validity(graph_state)
+        reward = intermediate_reward + self.lambda_1 * final_graph_reward + self.lambda_2 * graph_validity_score
+        return reward
+
+    def train_generative_model(self, model_to_explain, for_class):
+        """
+        Trains the generative model for the given number of epochs, using an RL (policy gradient) approach.
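+
+        The generator is optimized with a REINFORCE-style policy gradient:
+        the cross-entropy losses of the sampled start and end nodes are
+        scaled by the negated step reward, i.e.
+        loss = -reward * (CE_start + CE_end), so actions that lead to
+        positive rewards are reinforced.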
+
+        Args:
+            model_to_explain (torch.nn.Module): The model to explain.
+            for_class (int): The class to explain.
+
+        Returns:
+            GraphGenerator: The trained graph generator.
+        """
+        optimizer = torch.optim.Adam(self.graph_generator.parameters(), lr = self.lr, betas=(0.9, 0.99))
+        losses = []
+        for epoch in trange(self.epochs):
+
+            # create empty graph state
+            empty_graph = Data(x=torch.tensor([]), edge_index=torch.tensor([]), node_type=[])
+            current_graph_state = empty_graph
+
+            for step in range(self.max_steps):
+                self.graph_generator.train()
+                optimizer.zero_grad()
+
+                new_graph_state = copy.deepcopy(current_graph_state)
+                ((p_start, a_start), (p_end, a_end)), new_graph_state = self.graph_generator(new_graph_state)
+                reward = self.calculate_reward(new_graph_state, model_to_explain, for_class, self.num_classes)
+
+                LCE_start = F.cross_entropy(p_start, a_start)
+                LCE_end = F.cross_entropy(p_end, a_end)
+                loss = -reward * (LCE_start + LCE_end)
+
+                loss.backward()
+                optimizer.step()
+
+                # only keep the step if it improved the graph
+                if reward > 0:
+                    current_graph_state = new_graph_state
+
+            losses.append(loss.item())
+
+        return self.graph_generator
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# Load pretrained model
+
+args = {'device': device,
+        'dropout': 0.1,
+        'epochs': 1000,
+        'input_dim' : 7,
+        'opt': 'adam',
+        'opt_scheduler': 'none',
+        'opt_restart': 0,
+        'weight_decay': 5e-5,
+        'lr': 0.001}
+
+class objectview(object):
+    def __init__(self, d):
+        self.__dict__ = d
+
+args = objectview(args)
+
+model = GCN_Graph(args.input_dim, output_dim=1, dropout=args.dropout).to(device)
+
+# Freeze the pre-trained model so that only the generator is trained
+for param in model.parameters():
+    param.requires_grad = False
+
+# build the checkpoint path in an OS-independent way
+path = os.path.join('examples', 'explain', 'xgnn', 'mutag_model.pth')
+
+model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
+model.to(device)
+
+# Train generative model
+
+# load the dataset the model was trained on (MUTAG)
+dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')
+
+# graph validity check for number of edges depending on atom
+max_valency = {'C': 4, 'N': 5, 'O': 2, 'F': 1, 'I': 7, 'Cl': 7, 'Br': 5}
+
+# node type map that maps each node type to a one-hot encoded torch tensor
+candidate_set = {'C': torch.tensor([1, 0, 0, 0, 0, 0, 0]),
+                 'N': torch.tensor([0, 1, 0, 0, 0, 0, 0]),
+                 'O': torch.tensor([0, 0, 1, 0, 0, 0, 0]),
+                 'F': torch.tensor([0, 0, 0, 1, 0, 0, 0]),
+                 'I': torch.tensor([0, 0, 0, 0, 1, 0, 0]),
+                 'Cl': torch.tensor([0, 0, 0, 0, 0, 1, 0]),
+                 'Br': torch.tensor([0, 0, 0, 0, 0, 0, 1])}
+
+explainer = Explainer(
+    model = model,
+    algorithm = RLGenExplainer(epochs = 100,
+                               lr = 0.01,
+                               candidate_set = candidate_set,
+                               validity_args = max_valency,
+                               initial_node_type = 'C'),
+    explanation_type = 'generative',
+    model_config = dict(
+        mode = 'binary_classification',
+        task_level = 'graph',
+        return_type = 'probs',
+    ),
+)
+
+# choose target class
+class_index = 1
+
+# empty x and edge_index tensors: we explain the model itself rather than a
+# specific input graph, but the Explainer interface still requires them
+x = torch.tensor([])
+edge_index = torch.tensor([[], []])
+
+explanation_mutagenic = explainer(x, edge_index, for_class=class_index)
+explanation_set_mutagenic = explanation_mutagenic.explanation_set
+
+###########################
+### SAMPLE SINGLE GRAPH ###
+###########################
+sampled_graph = explanation_set_mutagenic.sample(num_samples=1, 
num_nodes=10)[0]
+# visualize sampled graph with the DEFAULT built-in method
+explanation_mutagenic.visualize_explanation_graph(sampled_graph, path='examples/explain/xgnn/sample_graph', backend='networkx')
+
+
+##############################
+### SAMPLE MULTIPLE GRAPHS ###
+##############################
+node_color_dict = {'C': '#0173B2',
+                   'N': '#DE8F05',
+                   'O': '#029E73',
+                   'F': '#D55E00',
+                   'I': '#CC78BC',
+                   'Cl': '#CA9161',
+                   'Br': '#FBAFE4'} # colorblind palette
+
+# generate graphs of multiple sizes for the mutagenic class
+sampled_graphs = []
+scores = []
+for i in range(3, 11):
+    sampled_graphs_i = explanation_set_mutagenic.sample(num_samples=5, num_nodes=i)
+    scores_i = []
+    for sampled_graph in sampled_graphs_i:
+        score = model(sampled_graph)
+        probability_score = torch.sigmoid(score)
+        scores_i.append(probability_score.item())
+
+    # choose graph with best score
+    best_graph_index = scores_i.index(max(scores_i))
+    scores.append(scores_i[best_graph_index])
+    sampled_graphs.append(sampled_graphs_i[best_graph_index])
+
+# visualize sampled graphs with a CUSTOM method
+fig, axes = plt.subplots(2, 4, figsize=(16, 8))
+axes = axes.flatten()
+
+for i, (sampled_graph, score) in enumerate(zip(sampled_graphs, scores)):
+    G = to_networkx(sampled_graph, to_undirected=True)
+    node_type_dict = dict(enumerate(sampled_graph.node_type))
+    nx.set_node_attributes(G, node_type_dict, 'node_type')
+    labels = nx.get_node_attributes(G, 'node_type')
+    node_color = [node_color_dict[key] for key in nx.get_node_attributes(G, 'node_type').values()]
+    pos = nx.spring_layout(G)
+    # set the subplot title to the maximum graph size and the model's score
+    axes[i].set_title("Max_num_nodes = {}\nprobability = {:.5f}".format(i+3, score), loc="center", fontsize=10)
+    nx.draw(G, pos=pos, ax=axes[i], cmap=plt.get_cmap('coolwarm'), node_color=node_color, labels=labels, font_color='white')
+
+# plt.savefig('examples/explain/xgnn/sample_graphs_custom.png')
+fig.suptitle("Sampled explanation graphs for mutagenic class", fontsize=16)
+plt.show()
+plt.close()
+
+
+# generate graphs of multiple sizes for the non-mutagenic class
+class_index = 0
+
+explanation_non_mutagenic = explainer(x, edge_index, for_class=class_index)
+explanation_set_non_mutagenic = explanation_non_mutagenic.explanation_set
+
+sampled_graphs = []
+scores = []
+for i in range(3, 11):
+    sampled_graphs_i = explanation_set_non_mutagenic.sample(num_samples=5, num_nodes=i)
+    scores_i = []
+    for sampled_graph in sampled_graphs_i:
+        score = model(sampled_graph)
+        probability_score = 1 - torch.sigmoid(score)
+        scores_i.append(probability_score.item())
+
+    # choose graph with best score
+    best_graph_index = scores_i.index(max(scores_i))
+    scores.append(scores_i[best_graph_index])
+    sampled_graphs.append(sampled_graphs_i[best_graph_index])
+
+# visualize sampled graphs with a CUSTOM method
+fig, axes = plt.subplots(2, 4, figsize=(16, 8))
+axes = axes.flatten()
+
+for i, (sampled_graph, score) in enumerate(zip(sampled_graphs, scores)):
+    G = to_networkx(sampled_graph, to_undirected=True)
+    node_type_dict = dict(enumerate(sampled_graph.node_type))
+    nx.set_node_attributes(G, node_type_dict, 'node_type')
+    labels = nx.get_node_attributes(G, 'node_type')
+    node_color = [node_color_dict[key] for key in nx.get_node_attributes(G, 'node_type').values()]
+    pos = nx.spring_layout(G)
+    # plot graph with score as title
+    axes[i].set_title("Max_num_nodes = {}\nprobability = {:.5f}".format(i+3, score), loc="center", fontsize=10)
+    nx.draw(G, pos=pos, ax=axes[i], cmap=plt.get_cmap('coolwarm'), node_color=node_color, labels=labels, 
font_color='white') + +# plt.savefig('examples/explain/xgnn/sample_graphs_custom.png') +fig.suptitle("Sampled explanation graphs for non-mutagenic class", fontsize=16) +plt.show() + + + + + diff --git a/examples/explain/xgnn/xgnn_model.py b/examples/explain/xgnn/xgnn_model.py new file mode 100644 index 000000000000..43aaba09318a --- /dev/null +++ b/examples/explain/xgnn/xgnn_model.py @@ -0,0 +1,65 @@ +from torch_geometric.nn import GCNConv +import torch +import torch.nn.functional as F +from torch_geometric.nn import global_mean_pool +from torch.nn.parameter import Parameter +import math + +### GCN to predict graph property +class GCN_Graph(torch.nn.Module): + def __init__(self, input_dim, output_dim, dropout, emb = False): + super(GCN_Graph, self).__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + + self.dropout = dropout + self.convs = torch.nn.ModuleList([GCNConv(in_channels = input_dim, out_channels = 32), + GCNConv(in_channels = 32, out_channels = 48), + GCNConv(in_channels = 48, out_channels = 64)]) + + self.pool = global_mean_pool # global averaging to obtain graph representation + + self.fc1 = torch.nn.Linear(64, 32) + self.fc2 = torch.nn.Linear(32, output_dim) + + self.loss = torch.nn.BCEWithLogitsLoss() + self.reset_parameters() + + def reset_parameters(self): + for conv in self.convs: + conv.reset_parameters() + stdv = 1. / math.sqrt(conv.lin.weight.size(1)) + torch.nn.init.uniform_(conv.lin.weight, -stdv, stdv) + + conv.bias = Parameter(torch.FloatTensor(conv.out_channels)) + conv.bias.data.uniform_(-stdv, stdv) + + self.fc1.reset_parameters() + self.fc2.reset_parameters() + + def forward(self, data): + # Extract important attributes of our mini-batch + x, edge_index = data.x, data.edge_index + + for i in range(len(self.convs)): + x = F.relu(self.convs[i](x, edge_index)) + if i < len(self.convs) - 1: # do not apply dropout on last layer + x = F.dropout(x, p=self.dropout, training=self.training) + + # Check if 'batch' attribute is present + if hasattr(data, 'batch'): + batch = data.batch + else: + # For a single graph, use a zero tensor as the batch vector, + # where its size equals the number of nodes. 
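+            # (all nodes are then assigned batch index 0, i.e. one graph)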
+            batch = torch.zeros(data.num_nodes, dtype=torch.long, device=x.device)
+
+        x = self.pool(x, batch)
+
+        x = F.relu(self.fc1(x))
+        x = F.dropout(x, self.dropout, training=self.training)
+        x = self.fc2(x)
+        return x
\ No newline at end of file
diff --git a/examples/explain/xgnn/xgnn_train.py b/examples/explain/xgnn/xgnn_train.py
new file mode 100644
index 000000000000..94c883a9e855
--- /dev/null
+++ b/examples/explain/xgnn/xgnn_train.py
@@ -0,0 +1,125 @@
+from torch_geometric.data import Batch
+from torch_geometric.datasets import TUDataset
+import torch
+import numpy as np
+import copy
+from tqdm.auto import trange
+import matplotlib.pyplot as plt
+from xgnn_model import GCN_Graph
+
+seed = 42
+np.random.seed(seed)
+torch.manual_seed(seed)
+
+def create_single_batch(dataset):
+    data_list = [data for data in dataset]
+    batched_data = Batch.from_data_list(data_list)
+    return batched_data
+
+def test(test_dataset, model):
+    model.eval()
+    with torch.no_grad():
+        logits = model(test_dataset).squeeze()  # Logits for each graph
+        probabilities = torch.sigmoid(logits)  # Convert logits to probabilities
+        predictions = probabilities > 0.5  # Convert probabilities to binary predictions
+        correct = (predictions == test_dataset.y).float()  # Assumes labels are 0 or 1
+        accuracy = correct.mean()
+
+    return accuracy
+
+
+def train(dataset, args, train_indices, val_indices, test_indices):
+    # Split dataset into training and testing (validation is not used here)
+    train_dataset = create_single_batch([dataset[i] for i in train_indices]).to(device)
+    test_dataset = create_single_batch([dataset[i] for i in test_indices]).to(device)
+
+    # Model initialization
+    model = GCN_Graph(args.input_dim, output_dim=1, dropout=args.dropout).to(device)
+    opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+
+    # Training loop
+    losses = []
+    test_accs = []
+    best_acc = 0
+    best_model = None
+    for epoch in trange(args.epochs, desc="Training", unit="Epoch"):
+        model.train()
+        opt.zero_grad()
+
+        pred = model(train_dataset)
+        label = train_dataset.y.float()
+        loss = model.loss(pred.squeeze(), label)
+        loss.backward()
+        opt.step()
+        total_loss = loss.item()
+        losses.append(total_loss)
+
+        # Evaluate test accuracy every 10 epochs; otherwise carry the last value forward
+        if epoch % 10 == 0:
+            test_acc = test(test_dataset, model)
+            test_accs.append(test_acc)
+            if test_acc > best_acc:
+                best_acc = test_acc
+                best_model = copy.deepcopy(model)
+        else:
+            test_accs.append(test_accs[-1])
+
+    return test_accs, losses, best_model, best_acc
+
+class objectview(object):
+    def __init__(self, d):
+        self.__dict__ = d
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+args = {'device': device,
+        'dropout': 0.1,
+        'epochs': 5000,
+        'input_dim' : 7,
+        'opt': 'adam',
+        'opt_restart': 0,
+        'weight_decay': 1e-4,
+        'lr': 0.007}
+
+args = objectview(args)
+
+dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')
+num_graphs = len(dataset)
+
+# Define split percentages
+train_percentage = 0.7
+val_percentage = 0.0
+
+# Calculate split sizes
+train_size = int(num_graphs * train_percentage)
+val_size = int(num_graphs * val_percentage)
+test_size = num_graphs - train_size - val_size
+
+# Create shuffled indices
+indices = np.random.permutation(num_graphs)
+train_indices = indices[:train_size]
+val_indices = indices[train_size:train_size + val_size]
+test_indices = indices[train_size + val_size:]
+
+test_accs, 
losses, best_model, best_acc = train(dataset, args, train_indices, val_indices, test_indices)
+
+try:
+    torch.save(best_model.state_dict(), 'examples/explain/xgnn/mutag_model.pth')
+    print("Model saved successfully.")
+except Exception as e:
+    print("Error saving model:", e)
+
+print("Maximum test set accuracy: {0}".format(max(test_accs)))
+print("Minimum loss: {0}".format(min(losses)))
+
+plt.title(dataset.name)
+plt.plot(losses, label="training loss")
+plt.plot(test_accs, label="test accuracy")
+plt.legend()
+plt.show()
\ No newline at end of file
diff --git a/test/explain/algorithm/test_xgnn_explainer.py b/test/explain/algorithm/test_xgnn_explainer.py
new file mode 100644
index 000000000000..a2a8f0bfce79
--- /dev/null
+++ b/test/explain/algorithm/test_xgnn_explainer.py
@@ -0,0 +1,40 @@
+import pytest
+import torch
+from torch_geometric.explain import XGNNExplainer, GenerativeExplanation
+
+# Mock subclass of XGNNExplainer for testing
+class MockXGNNExplainer(XGNNExplainer):
+    def train_generative_model(self, model, for_class, **kwargs):
+        return None
+
+@pytest.fixture
+def model():
+    return torch.nn.Linear(3, 2)
+
+def test_xgnn_explainer_initialization():
+    explainer = MockXGNNExplainer(epochs=200, lr=0.005)
+    assert explainer.epochs == 200
+    assert explainer.lr == 0.005
+
+def test_xgnn_explainer_forward(model):
+    explainer = MockXGNNExplainer()
+    x = torch.rand(10, 3)
+    edge_index = torch.randint(0, 10, (2, 30))
+    target = torch.randint(0, 2, (10,))
+
+    explanation = explainer(model, x, edge_index, target=target, for_class=1)
+    assert isinstance(explanation, GenerativeExplanation)
+
+    # Test ValueError for missing 'for_class' argument
+    with pytest.raises(ValueError):
+        explainer(model, x, edge_index, target=target)
+
+def test_xgnn_explainer_abstract_method():
+    class IncompleteExplainer(XGNNExplainer):
+        pass
+
+    # Instantiation fails because the abstract
+    # 'train_generative_model' method is not implemented:
+    with pytest.raises(TypeError):
+        IncompleteExplainer()
diff --git a/test/explain/test_explainer.py b/test/explain/test_explainer.py
index c8eaf7cb3909..07aeed1effa8 100644
--- a/test/explain/test_explainer.py
+++ b/test/explain/test_explainer.py
@@ -60,7 +60,8 @@ def test_forward(data, target, explanation_type):
     assert isinstance(explanation, Explanation)
     assert 'x' in explanation
     assert 'edge_index' in explanation
-    assert 'target' in explanation
+    if explanation_type != ExplanationType.generative:  # target is unused for generative explanations
+        assert 'target' in explanation
     assert 'node_mask' in explanation.available_explanations
 
     assert explanation.node_mask.size() == data.x.size()
diff --git a/test/explain/test_generative_explanation.py b/test/explain/test_generative_explanation.py
new file mode 100644
index 000000000000..f99229e7cb52
--- /dev/null
+++ b/test/explain/test_generative_explanation.py
@@ -0,0 +1,96 @@
+import pytest
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_geometric.data import Data
+from torch_geometric.explain import Explainer, XGNNExplainer
+
+
+# Mock model for testing
+class MLP_Graph(nn.Module):
+    def __init__(self, input_dim, output_dim):
+        super(MLP_Graph, self).__init__()
+        self.fc1 = nn.Linear(input_dim, 8)
+        self.fc2 = nn.Linear(8, output_dim)
+
+    def forward(self, x):
+        # Flatten the graph representation
+        x = x.view(x.size(0), -1)
+        x = F.relu(self.fc1(x))
+        x = self.fc2(x)
+        return x
+
+
+# Mock explainer algorithm 
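+# (its train_generative_model skips real training and simply returns an
+# empty Data object in place of a trained generator, which is enough to
+# exercise the Explainer plumbing in these tests)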
+class ExampleExplainer(XGNNExplainer): + def __init__(self, epochs, lr, candidate_set, validity_args): + super(ExampleExplainer, self).__init__() + self.epochs = epochs + self.lr = lr + self.candidate_set = candidate_set + self.validity_args = validity_args + + def train_generative_model(self, model_to_explain, for_class): + # For simplicity, this example does not include actual training logic + + for epoch in range(self.epochs): + # Placeholder for training logic + pass + + return Data() + + +# Mock graph generator +class ExampleGraphGenerator(): + def __init__(self, graph): + self.graph = graph + + def sample(self): + # has to return a list of Data objects + return [Data(), Data(), Data()] + + +# Fixture for setting up XGNNExplainer +@pytest.fixture +def setup_xgnn_explainer(): + mock_model = MLP_Graph(input_dim=7, output_dim=1) + + explainer = Explainer( + model = mock_model, + algorithm = ExampleExplainer(epochs = 10, + lr = 0.01, + candidate_set={'C': torch.tensor([1, 0, 0, 0, 0, 0, 0])}, # Simplified candidate set + validity_args={'C': 4}), + explanation_type = 'generative', + model_config = dict( + mode = 'binary_classification', + task_level = 'graph', + return_type = 'probs', + ) + ) + + class_index = 1 + x = torch.tensor([]) + edge_index = torch.tensor([[], []]) + + return explainer, x, edge_index, class_index + + +# Test output of XGNNExplainer +def test_explainer_output(setup_xgnn_explainer): + explainer, x, edge_index, class_index = setup_xgnn_explainer + explanation = explainer(x, edge_index, for_class=class_index) + + # Check if explanation is of type Data + assert isinstance(explanation, Data), "Explanation is not of type Data" + + +# Test output of ExampleExplainer +def test_sampler_output(): + sampled_graphs = ExampleGraphGenerator(Data()).sample() + + # Check if sampled_graphs is a list of Data objects + assert isinstance(sampled_graphs, list), "Sampled graphs is not a list" + assert all(isinstance(graph, Data) for graph in sampled_graphs), "Sampled graphs is not a list of Data objects" + + \ No newline at end of file diff --git a/torch_geometric/explain/__init__.py b/torch_geometric/explain/__init__.py index bca9b5d51f9c..e97c26e8b78f 100644 --- a/torch_geometric/explain/__init__.py +++ b/torch_geometric/explain/__init__.py @@ -1,5 +1,5 @@ from .config import ExplainerConfig, ModelConfig, ThresholdConfig -from .explanation import Explanation, HeteroExplanation +from .explanation import Explanation, HeteroExplanation, GenerativeExplanation, ExplanationSetSampler from .algorithm import * # noqa from .explainer import Explainer from .metric import * # noqa @@ -11,4 +11,6 @@ 'Explanation', 'HeteroExplanation', 'Explainer', + 'GenerativeExplanation', + 'ExplanationSetSampler', ] diff --git a/torch_geometric/explain/algorithm/__init__.py b/torch_geometric/explain/algorithm/__init__.py index a462a5777edb..e1c86719a254 100644 --- a/torch_geometric/explain/algorithm/__init__.py +++ b/torch_geometric/explain/algorithm/__init__.py @@ -5,6 +5,7 @@ from .pg_explainer import PGExplainer from .attention_explainer import AttentionExplainer from .graphmask_explainer import GraphMaskExplainer +from .xgnn_explainer import XGNNExplainer __all__ = classes = [ 'ExplainerAlgorithm', @@ -14,4 +15,5 @@ 'PGExplainer', 'AttentionExplainer', 'GraphMaskExplainer', + 'XGNNExplainer' ] diff --git a/torch_geometric/explain/algorithm/xgnn_explainer.py b/torch_geometric/explain/algorithm/xgnn_explainer.py new file mode 100644 index 000000000000..3c3f2261d5ac --- /dev/null +++ 
b/torch_geometric/explain/algorithm/xgnn_explainer.py
@@ -0,0 +1,111 @@
+from abc import abstractmethod
+from typing import Optional, Union, Dict
+import torch
+from torch import Tensor
+from torch_geometric.explain.algorithm import ExplainerAlgorithm
+from torch_geometric.explain import (GenerativeExplanation,
+                                     ExplanationSetSampler)
+
+
+class XGNNExplainer(ExplainerAlgorithm):
+    r"""The XGNN-Explainer model from the `"XGNN: Towards Model-Level
+    Explanations of Graph Neural Networks"
+    <https://arxiv.org/abs/2006.02587>`_ paper for training a graph generator
+    so that the generated graph patterns maximize a certain prediction of the
+    model.
+
+    XGNN-Explainer interprets GNN models by training a graph generator, which
+    creates graph patterns that maximize the predictive outcome for a specific
+    class. This approach provides a model-level explanation, offering insights
+    into what input patterns lead to certain predictions by GNNs.
+
+    .. note::
+
+        For an example of using :class:`XGNNExplainer`, see
+        `examples/explain/xgnn/xgnn_explainer.py
+        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/explain/xgnn/xgnn_explainer.py>`_.
+
+    The explainer trains a generative model, iterating over steps to
+    progressively build a graph that maximizes the class score of the target
+    class. The generative model can be customized and is a key component of
+    the explanation process.
+
+    Args:
+        epochs (int, optional): The number of epochs for training the
+            generative model. (default: :obj:`100`)
+        lr (float, optional): The learning rate for training the generative
+            model. (default: :obj:`0.01`)
+        **kwargs (optional): Additional hyper-parameters to configure the
+            training process of the generative model.
+    """
+
+    def __init__(self, epochs: int = 100, lr: float = 0.01, **kwargs):
+        super().__init__()
+        self.epochs = epochs
+        self.lr = lr
+        self.config = kwargs
+
+    def forward(
+        self,
+        model: torch.nn.Module,
+        x: Union[Tensor, Dict[str, Tensor]],
+        edge_index: Union[Tensor, Dict[str, Tensor]],
+        *,
+        target: Tensor,
+        index: Optional[Union[int, Tensor]] = None,
+        **kwargs,
+    ) -> GenerativeExplanation:
+        r"""Computes the generative explanation for the given class.
+
+        Args:
+            model (torch.nn.Module): The model to explain.
+            x (Union[Tensor, Dict[str, Tensor]]): The input node features.
+            edge_index (Union[Tensor, Dict[str, Tensor]]): The edge indices.
+            target (Tensor): The target tensor for the explanation.
+            index (Optional[Union[int, Tensor]], optional): The index of the
+                node or graph to explain. Since this explainer produces
+                model-level explanations, it must be :obj:`None`; the
+                argument only exists for interface compatibility with other
+                explainer algorithms.
+            **kwargs: Additional keyword arguments. Must contain
+                :obj:`for_class`, the class to generate explanations for.
+
+        Returns:
+            GenerativeExplanation: The explanation result.
+        """
+        # Validate the 'for_class' argument
+        for_class = kwargs.pop('for_class', None)
+        if for_class is None:
+            raise ValueError("The 'for_class' argument must be provided")
+
+        if isinstance(x, dict) or isinstance(edge_index, dict):
+            raise ValueError(f"Heterogeneous graphs not yet supported in "
+                             f"'{self.__class__.__name__}'")
+
+        if index is not None:
+            raise ValueError(
+                f"Index not supported in '{self.__class__.__name__}'")
+
+        generative_model = self.train_generative_model(model,
+                                                       for_class=for_class,
+                                                       **kwargs)
+
+        return GenerativeExplanation(model=model,
+                                     explanation_set=generative_model)
+
+    @abstractmethod
+    def train_generative_model(self, model, for_class,
+                               **kwargs) -> ExplanationSetSampler:
+        r"""Abstract method to train the generative model.
+        Must be implemented in subclasses.
+
+        Args:
+            model: The model to explain. 
+ for_class: The class for which the explanation is generated. + """ + raise NotImplementedError( + "The method train_generative_model must be implemented in subclasses" + ) + + def supports(self) -> bool: + return True diff --git a/torch_geometric/explain/config.py b/torch_geometric/explain/config.py index 21f16fdf9b12..9a1909da9548 100644 --- a/torch_geometric/explain/config.py +++ b/torch_geometric/explain/config.py @@ -9,6 +9,7 @@ class ExplanationType(Enum): """Enum class for the explanation type.""" model = 'model' phenomenon = 'phenomenon' + generative = 'generative' class MaskType(Enum): @@ -100,8 +101,9 @@ def __init__( f"'object' (got '{edge_mask_type.value}')") if node_mask_type is None and edge_mask_type is None: - raise ValueError("Either 'node_mask_type' or 'edge_mask_type' " - "must be provided") + if ExplanationType(explanation_type) is not ExplanationType.generative: + raise ValueError("Either 'node_mask_type' or 'edge_mask_type' " + "must be provided") self.explanation_type = ExplanationType(explanation_type) self.node_mask_type = node_mask_type diff --git a/torch_geometric/explain/explainer.py b/torch_geometric/explain/explainer.py index 63ec5ad4c0c5..5b4b1bfbaadb 100644 --- a/torch_geometric/explain/explainer.py +++ b/torch_geometric/explain/explainer.py @@ -8,6 +8,8 @@ ExplainerAlgorithm, Explanation, HeteroExplanation, + GenerativeExplanation, + ) from torch_geometric.explain.algorithm.utils import ( clear_masks, @@ -142,17 +144,18 @@ def get_masked_prediction( out = self.get_prediction(x, edge_index, **kwargs) clear_masks(self.model) + return out def __call__( self, - x: Union[Tensor, Dict[NodeType, Tensor]], - edge_index: Union[Tensor, Dict[EdgeType, Tensor]], + x: Optional[Union[Tensor, Dict[NodeType, Tensor]]], + edge_index: Optional[Union[Tensor, Dict[EdgeType, Tensor]]], *, target: Optional[Tensor] = None, index: Optional[Union[int, Tensor]] = None, **kwargs, - ) -> Union[Explanation, HeteroExplanation]: + ) -> Union[Explanation, HeteroExplanation, GenerativeExplanation]: r"""Computes the explanation of the GNN for the given inputs and target. @@ -195,7 +198,10 @@ def __call__( f"type '{self.explanation_type.value}'") prediction = self.get_prediction(x, edge_index, **kwargs) target = self.get_target(prediction) - + elif self.explanation_type == ExplanationType.generative: + pass + + if isinstance(index, int): index = torch.tensor([index]) @@ -213,6 +219,9 @@ def __call__( self.model.train(training) + if isinstance(explanation, GenerativeExplanation): + return explanation + # Add explainer objectives to the `Explanation` object: explanation._model_config = self.model_config explanation.prediction = prediction diff --git a/torch_geometric/explain/explanation.py b/torch_geometric/explain/explanation.py index 8897a32d166a..fd58dedb5dc0 100644 --- a/torch_geometric/explain/explanation.py +++ b/torch_geometric/explain/explanation.py @@ -1,6 +1,3 @@ -import copy -from typing import Dict, List, Optional, Union - import torch from torch import Tensor @@ -10,6 +7,10 @@ from torch_geometric.typing import EdgeType, NodeType from torch_geometric.visualization import visualize_graph +from torch_geometric.data.batch import Batch +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Union +import copy class ExplanationMixin: @property @@ -399,3 +400,90 @@ def _visualize_score( plt.show() plt.close() + + +class ExplanationSetSampler(ABC): + r""" + Serves as a base class for sampling from an "Explanation Set" of a neural network. 
+    This set comprises data points that maximize the network's activation. It
+    can be extended by generative models or fixed-size datasets that perform
+    the actual sampling.
+    """
+
+    @abstractmethod
+    def sample(self, num_samples: int, **kwargs):
+        r"""
+        Abstract method to sample data points from the Explanation Set.
+
+        Args:
+            num_samples (int): The number of samples to generate.
+        """
+        raise NotImplementedError(
+            "The method sample must be implemented in subclasses")
+
+
+class GenerativeExplanation(Data):
+    r"""Holds a generative, model-level explanation of a homogeneous graph
+    neural network.
+
+    The generative explanation object is a :obj:`~torch_geometric.data.Data`
+    object and holds the explanation set.
+
+    Args:
+        explanation_set (ExplanationSetSampler): The explanation set used to
+            explain NN activations; can be a finite set, a generative model,
+            or anything else that can sample from the abstract explanation
+            set.
+        is_finite (bool): Whether the explanation set is finite.
+    """
+
+    def validate(self, raise_on_error: bool = True) -> bool:
+        r"""Validates the correctness of the :class:`GenerativeExplanation` object."""
+        status = super().validate(raise_on_error)
+        explanation_set = self.get("explanation_set")
+        is_finite = self.get('is_finite')
+
+        if explanation_set is None or is_finite is None:
+            if raise_on_error:
+                raise ValueError("Both 'explanation_set' and 'is_finite' must be set.")
+            status = False
+        return status
+
+    def is_finite(self) -> bool:
+        r"""Returns whether the Explanation Set is finite."""
+        is_finite = self.get('is_finite')
+        return is_finite
+
+    def get_explanation_set(self, **kwargs):
+        r"""
+        Retrieves the Explanation Set. If the set is not finite, expects
+        'num_samples' in kwargs so that a finite subset of the explanation
+        set can be sampled.
+
+        Args:
+            **kwargs: Keyword arguments, expected to contain 'num_samples'
+                for infinite sets.
+
+        Raises:
+            ValueError: If the Explanation Set is infinite and 'num_samples'
+                is not provided.
+        """
+        explanation_set = self.get("explanation_set")
+        if not isinstance(explanation_set, ExplanationSetSampler):
+            raise TypeError("'explanation_set' must extend ExplanationSetSampler")
+
+        if not self.is_finite():
+            if 'num_samples' not in kwargs:
+                raise ValueError("Expected 'num_samples' argument for an infinite Explanation Set.")
+
+        return explanation_set.sample(**kwargs)
+
+    def visualize_explanation_graph(self, graph_state, path: Optional[str] = None,
+                                    backend: Optional[str] = None):
+        r"""
+        Visualizes the given explanation graph; all edges are drawn with
+        equal weight.
+
+        Args:
+            graph_state: The state of the graph to be visualized.
+            path (Optional[str]): The path to where the plot is saved.
+                If set to :obj:`None`, will visualize the plot on-the-fly.
+            backend (Optional[str]): The graph drawing backend to use for
+                visualization (:obj:`"graphviz"`, :obj:`"networkx"`).
+                If set to :obj:`None`, will use the most appropriate
+                visualization backend based on available system packages.
+        """
+        edge_index = graph_state.edge_index
+        visualize_graph(edge_index, path=path, backend=backend)
\ No newline at end of file