In [12]:
import torch
from torch_geometric.explain import GNNExplainer
#from torch_geometric.explain import CFExplainer
from torch_geometric.nn import GCNConv
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
# Import other necessary modules as needed


In [15]:
%matplotlib inline

from torch_geometric.data import Data, DataLoader
from torch_geometric.datasets import TUDataset, Planetoid
from torch_geometric.nn import GCNConv, Set2Set
from torch_geometric.explain import GNNExplainer
import torch_geometric.transforms as T
import torch
import torch.nn.functional as F
import os
from tqdm import tqdm, trange

import matplotlib.pyplot as plt

In [16]:
#Load the dataset
dataset = 'cora'
path = os.path.join(os.getcwd(), 'data', 'Planetoid')
train_dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())

# Since the dataset is comprised of a single huge graph, we extract that graph by indexing 0.
data = train_dataset[0]

# Since there is only 1 graph, the train/test split is done by masking regions of the graph. We split the last 500+500 nodes as val and test, and use the rest as the training data.
data.train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
data.train_mask[:data.num_nodes - 1000] = 1
data.val_mask = None
data.test_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
data.test_mask[data.num_nodes - 500:] = 1

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [17]:
node_idx = 10
x, edge_index = data.x, data.edge_index
explainer = GNNExplainer(model, epochs=200)
node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index)

TypeError: __init__() got multiple values for argument 'epochs'

In [None]:
epochs = 200
dim = 16
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(num_features=train_dataset.num_features, dim=dim, num_classes=train_dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-3)
def test(model, data):
    model.eval()
    logits, accs = model(data.x, data.edge_index, data), []
    for _, mask in data('train_mask', 'test_mask'):
        pred = logits[mask].max(1)[1]
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    return accs
loss = 999.0
train_acc = 0.0
test_acc = 0.0

t = trange(epochs, desc="Stats: ", position=0)

for epoch in t:

    model.train()
    
    loss = 0

    data = data.to(device)
    optimizer.zero_grad()
    log_logits = model(data.x, data.edge_index, data)

    # Since the data is a single huge graph, training on the training set is done by masking the nodes that are not in the training set.
    loss = F.nll_loss(log_logits[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

    # validate
    train_acc, test_acc = test(model, data)
    train_loss = loss
    
    t.set_description('[Train_loss:{:.6f} Train_acc: {:.4f}, Test_acc: {:.4f}]'.format(loss, train_acc, test_acc))

In [13]:
# Step 1: Load the benchmark dataset
dataset = TUDataset(root='data', name='PROTEINS')
data = dataset[0]  # Choose the first graph in the dataset for simplicity

# Step 2: Prepare the data
loader = DataLoader(dataset, batch_size=1, shuffle=True)

# Step 3: Define a GNN model
class GNNModel(torch.nn.Module):
    def __init__(self):
        super(GNNModel, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)
        self.explainer = GNNExplainer(self.conv2)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

# Step 4: Train the GNN model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GNNModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()

for data in loader:
    data = data.to(device)
    optimizer.zero_grad()
    output = model(data.x, data.edge_index)
    loss = F.nll_loss(output, data.y)
    loss.backward()
    optimizer.step()

# Step 5: Explain node importance with GNN-Explainer
model.eval()
explainer = GNNExplainer(model)
node_idx = 0  # Choose the index of the node to explain
node_feat_mask = explainer.explain_node(data.to(device), node_idx)

# Use the node_feat_mask for visualization or further analysis
print(node_feat_mask)


ValueError: Expected input batch_size (5) to match target batch_size (1).

In [None]:
from math import sqrt
from typing import Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn.parameter import Parameter
import torch.nn.functional as F


from torch_geometric.explain import ExplainerConfig, Explanation, ModelConfig
from torch_geometric.explain.algorithm import ExplainerAlgorithm
from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
from torch_geometric.explain.config import MaskType, ModelMode, ModelTaskLevel
from torch_geometric.utils import to_dense_adj
from torch_geometric.utils import dense_adjacency


class CFExplainer(ExplainerAlgorithm):
    r"""The CF-Explainer model from the `"CF-GNNExplainer: Counterfactual Explanations for Graph Neural
Networks"
    <https://arxiv.org/abs/2102.03322>`_ paper for generating CF explanations for GNNs: 
    the minimal perturbation to the input (graph) data such that the prediction changes.

    .. note::

        For an example of using :class:`GNNExplainer`, see
        `examples/explain/gnn_explainer.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/explain/gnn_explainer.py>`_,
        `examples/explain/gnn_explainer_ba_shapes.py <https://github.com/
        pyg-team/pytorch_geometric/blob/master/examples/
        explain/gnn_explainer_ba_shapes.py>`_, and `examples/explain/
        gnn_explainer_link_pred.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/explain/gnn_explainer_link_pred.py>`_.

    Args:
        epochs (int, optional): The number of epochs to train.
            (default: :obj:`100`)
        lr (float, optional): The learning rate to apply.
            (default: :obj:`0.01`)
        **kwargs (optional): Additional hyper-parameters to override default
            settings in
            :attr:`~torch_geometric.explain.algorithm.GNNExplainer.coeffs`.
    """

    coeffs = {
        'edge_size': 0.005,
        'edge_reduction': 'sum',
        'beta' : .001,
        'node_feat_size': 1.0,
        'node_feat_reduction': 'mean',
        'edge_ent': 1.0,
        'node_feat_ent': 0.1,
        'EPS': 1e-15,
    }

    def __init__(self, epochs: int = 100, lr: float = 0.01, cf_optimizer = "SGD", n_momentum = 0, **kwargs):
        super().__init__()
        self.epochs = epochs
        self.lr = lr
        self.cf_optimizer = cf_optimizer
        self.n_momentum = n_momentum
        self.coeffs.update(kwargs)
        self.node_mask = self.hard_node_mask = None
        self.edge_mask = self.hard_edge_mask = None
        self.best_cf_example = None
        self.best_loss = np.inf
        
    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Explanation:
        if isinstance(x, dict) or isinstance(edge_index, dict):
            raise ValueError(f"Heterogeneous graphs not yet supported in "
                             f"'{self.__class__.__name__}'")

        self._train(model, x, edge_index, target=target, index=index, **kwargs)

        # node_mask = self._post_process_mask(
        #     self.best_cf_example[0],
        #     self.hard_node_mask,
        #     apply_sigmoid=True,
        # )
        edge_mask = self._post_process_mask(
            self.best_cf_example,
            self.hard_edge_mask,
            apply_sigmoid=True,
        )

        self._clean_model(model)

        return Explanation(node_mask=node_mask, edge_mask=edge_mask)

    def supports(self) -> bool:
        return True
    
    def _train(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ):
        self._initialize_masks(x, edge_index)

        parameters = []
        if self.node_mask is not None:
            parameters.append(self.node_mask)
        if self.edge_mask is not None:
            set_masks(model, self.edge_mask, edge_index, apply_sigmoid=True)
            parameters.append(self.edge_mask)
    
        if self.cf_optimizer == "SGD" and n_momentum == 0.0:
            optimizer = torch.optim.SGD(parameters, lr=self.lr)
        elif self.cf_optimizer == "SGD" and n_momentum != 0.0:
            optimizer = torch.optim.SGD(parameters, lr=self.lr, nesterov=True, momentum=n_momentum)
        elif self.cf_optimizer == "Adadelta":
            optimizer = torch.optim.Adadelta(parameters, lr=self.lr)
        else:
            raise Exception("Optimizer is not currently supported.")
        
        num_cf_examples = 0
        original_prediction  = model(x, edge_index, **kwargs)
        for i in range(self.epochs):
            optimizer.zero_grad()
            h = x if self.node_mask is None else x * self.node_mask.sigmoid()
            discrete_edge_mask = torch.where(torch.sigmoid(self.edge_mask)>=0.5, 1, 0)
            set_masks(model, discrete_edge_mask, edge_index, apply_sigmoid=False)
            y_hat, y = model(h, edge_index, **kwargs), original_prediction
            y_hat_discrete, y_discrete = y_hat.argmax(dim=1), y.argmax(dim=1)
            set_masks(model, self.edge_mask, edge_index, apply_sigmoid=True)

            if index is not None:
                y_hat, y = y_hat[index], y[index]

            loss = self._loss(y_hat, y, edge_index)

            if loss.item() < best_loss:
                best_loss = loss
                self.best_cf_example = self.edge_mask
            
            loss.backward()
            optimizer.step()

            # In the first iteration, we collect the nodes and edges that are
            # involved into making the prediction. These are all the nodes and
            # edges with gradient != 0 (without regularization applied).
            if i == 0 and self.node_mask is not None:
                if self.node_mask.grad is None:
                    raise ValueError("Could not compute gradients for node "
                                     "features. Please make sure that node "
                                     "features are used inside the model or "
                                     "disable it via `node_mask_type=None`.")
                self.hard_node_mask = self.node_mask.grad != 0.0
            if i == 0 and self.edge_mask is not None:
                if self.edge_mask.grad is None:
                    raise ValueError("Could not compute gradients for edges. "
                                     "Please make sure that edges are used "
                                     "via message passing inside the model or "
                                     "disable it via `edge_mask_type=None`.")
                self.hard_edge_mask = self.edge_mask.grad != 0.0

    def _initialize_masks(self, x: Tensor, edge_index: Tensor):
        node_mask_type = self.explainer_config.node_mask_type
        edge_mask_type = self.explainer_config.edge_mask_type

        device = x.device
        (N, F), E = x.size(), edge_index.size(1)

        std = 0.1
        if node_mask_type is None:
            self.node_mask = None
        elif node_mask_type == MaskType.object:
            self.node_mask = Parameter(torch.randn(N, 1, device=device) * std)
        elif node_mask_type == MaskType.attributes:
            self.node_mask = Parameter(torch.randn(N, F, device=device) * std)
        elif node_mask_type == MaskType.common_attributes:
            self.node_mask = Parameter(torch.randn(1, F, device=device) * std)
        else:
            assert False

        if edge_mask_type is None:
            self.edge_mask = None
        elif edge_mask_type == MaskType.object:
            std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
            self.edge_mask = Parameter(torch.randn(E, device=device) * std)
        else:
            assert False

    # def _initialize_masks(self, x: Tensor, edge_index: Tensor):
    #     node_mask_type = self.explainer_config.node_mask_type
    #     edge_mask_type = self.explainer_config.edge_mask_type

    #     device = x.device
    #     (N, F), E = x.size(), edge_index.size(1)

    #     if node_mask_type is None:
    #         self.node_mask = None
    #     elif node_mask_type == MaskType.object:
    #         self.node_mask = Parameter(torch.ones(N, 1, device=device))
    #     elif node_mask_type == MaskType.attributes:
    #         self.node_mask = Parameter(torch.ones(N, F, device=device))
    #     elif node_mask_type == MaskType.common_attributes:
    #         self.node_mask = Parameter(torch.ones(1, F, device=device))
    #     else:
    #         assert False


    #     if edge_mask_type is None:
    #         self.edge_mask = None
    #     elif edge_mask_type == MaskType.object:
    #         self.edge_mask = Parameter(torch.ones(E, device=device))
    #     else:
    #         assert False


    def _loss(self, y_hat: Tensor, y: Tensor, edge_index) -> Tensor:
        y_hat_discrete = y_hat.argmax(dim=1)
        y_discrete = y.argmax(dim=1)

        pred_same = (y_hat_discrete == y_discrete).float()
        
        # if self.model_config.mode == ModelMode.binary_classification:
        #     loss = self._loss_binary_classification(y_hat, y)
        # elif self.model_config.mode == ModelMode.multiclass_classification:
        #     loss = self._loss_multiclass_classification(y_hat, y)
        # elif self.model_config.mode == ModelMode.regression:
        #     loss = self._loss_regression(y_hat, y)
        # else:
        #     assert False
        # Want negative in front to maximize loss instead of minimizing it to find CFs
        discrete_edge_mask = torch.where(torch.sigmoid(self.edge_mask)>=0, 1, 0)

        loss_pred = - F.nll_loss(y_hat, y_discrete)
        adj = dense_adjacency(edge_index, edge_attr=None, num_nodes=None)
        discrete_adj = torch.where(torch.sigmoid(self.edge_mask) >= 0, 1, 0)
        loss_graph_dist = torch.sum(torch.abs(adj - discrete_adj)) / 2

        #loss_graph_dist = sum(sum(abs(to_dense_adj(edge_index) - to_dense_adj(discrete_edge_mask)))) / 2      # Number of edges changed (symmetrical)

		# Zero-out loss_pred with pred_same if prediction flips
        loss_total = pred_same * loss_pred + self.coeff['beta'] * loss_graph_dist


        # if self.hard_edge_mask is not None:
        #     assert self.edge_mask is not None
        #     m = self.edge_mask[self.hard_edge_mask].sigmoid()
        #     edge_reduce = getattr(torch, self.coeffs['edge_reduction'])
        #     loss = loss + self.coeffs['edge_size'] * edge_reduce(m)
        #     ent = -m * torch.log(m + self.coeffs['EPS']) - (
        #         1 - m) * torch.log(1 - m + self.coeffs['EPS'])
        #     loss = loss + self.coeffs['edge_ent'] * ent.mean()

        # if self.hard_node_mask is not None:
        #     assert self.node_mask is not None
        #     m = self.node_mask[self.hard_node_mask].sigmoid()
        #     node_reduce = getattr(torch, self.coeffs['node_feat_reduction'])
        #     loss16 = loss + self.coeffs['node_feat_size'] * node_reduce(m)
        #     ent = -m * torch.log(m + self.coeffs['EPS']) - (
        #         1 - m) * torch.log(1 - m + self.coeffs['EPS'])
        #     loss = loss + self.coeffs['node_feat_ent'] * ent.mean()

        return loss_total

    def _clean_model(self, model):
        clear_masks(model)
        self.node_mask = self.hard_node_mask = None
        self.edge_mask = self.hard_edge_mask = None