In [1]:
import os
from typing import Union
from math import sqrt

import torch
from torch import Tensor
from torch.nn.functional import mse_loss
from torch_geometric.utils.loop import add_remaining_self_loops
from torch_geometric.data import DataLoader
from torch_geometric.nn import MessagePassing

from dig.version import debug
from dig.xgraph.models.utils import subgraph
#from dig.xgraph.method.utils import symmetric_edge_mask_indirect_graph
from dig.xgraph.evaluation import XCollector

EPS = 1e-15

import import_ipynb
from influenceDataset import get_dataloader, influenceDataset
from gnnNets import get_gnnNets
from base_explainer import ExplainerBase

importing Jupyter notebook from influenceDataset.ipynb
importing Jupyter notebook from gnnNets.ipynb
importing Jupyter notebook from base_explainer.ipynb


In [2]:
# Pin GPU selection through the CUDA environment variables: devices are
# ordered by PCI bus id and only the device with index "3" is made visible.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "3",
})

# Use the (single visible) GPU when CUDA is usable, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [11]:
def cross_entropy_with_logit(y_pred: torch.Tensor, y_true: torch.Tensor, **kwargs):
    """Cross-entropy between raw logits and integer class labels.

    Args:
        y_pred (torch.Tensor): Unnormalized class scores (logits).
        y_true (torch.Tensor): Target class indices; cast to ``long`` before
            the loss is computed.
        **kwargs: Forwarded unchanged to
            :func:`torch.nn.functional.cross_entropy` (e.g. ``reduction``).

    Returns:
        torch.Tensor: The cross-entropy loss.
    """
    # The bare name `cross_entropy` was never imported in this notebook, so
    # the original call raised NameError; resolve it through `torch` instead.
    return torch.nn.functional.cross_entropy(y_pred, y_true.long(), **kwargs)


class GNNExplainer(ExplainerBase):
    r"""The GNN-Explainer model from the `"GNNExplainer: Generating
    Explanations for Graph Neural Networks"
    <https://arxiv.org/abs/1903.03894>`_ paper for identifying compact subgraph
    structures and small subsets node features that play a crucial role in a
    GNN’s node-predictions.

    In this notebook variant the fidelity term of the loss is the MSE between
    the masked model output and a supplied scalar ``prediction`` (see
    :meth:`__loss__`); the classification (cross-entropy) path is disabled.

    .. note:: For an example, see `benchmarks/xgraph
        <https://github.com/divelab/DIG/tree/dig/benchmarks/xgraph>`_.
    Args:
        model (torch.nn.Module): The GNN module to explain.
        epochs (int, optional): The number of epochs to train.
            (default: :obj:`100`)
        lr (float, optional): The learning rate to apply.
            (default: :obj:`0.01`)
        coff_edge_size (float, optional): Weight of the edge-mask size
            (sum) penalty. (default: :obj:`0.001`)
        coff_edge_ent (float, optional): Weight of the edge-mask entropy
            penalty. (default: :obj:`0.001`)
        coff_node_feat_size (float, optional): Weight of the node-feature-mask
            size (sum) penalty. (default: :obj:`1.0`)
        coff_node_feat_ent (float, optional): Weight of the node-feature-mask
            entropy penalty. (default: :obj:`0.1`)
        explain_graph (bool, optional): Whether to explain graph classification model
            (default: :obj:`False`)
        indirect_graph_symmetric_weights (bool, optional): If `True`, then the explainer
            will first realize whether this graph input has indirect edges,
            then makes its edge weights symmetric. (default: :obj:`False`)
            NOTE: the flag is stored but the symmetrization step is currently
            commented out in :meth:`forward`, so it has no effect.
    """
    def __init__(self,
                 model: torch.nn.Module,
                 epochs: int = 100,
                 lr: float = 0.01,
                 coff_edge_size: float = 0.001,
                 coff_edge_ent: float = 0.001,
                 coff_node_feat_size: float = 1.0,
                 coff_node_feat_ent: float = 0.1,
                 explain_graph: bool = False,
                 indirect_graph_symmetric_weights: bool = False):
        super(GNNExplainer, self).__init__(model, epochs, lr, explain_graph)
        self.coff_node_feat_size = coff_node_feat_size
        self.coff_node_feat_ent = coff_node_feat_ent
        self.coff_edge_size = coff_edge_size
        self.coff_edge_ent = coff_edge_ent
        self._symmetric_edge_mask_indirect_graph: bool = indirect_graph_symmetric_weights

    def __loss__(self, raw_preds: Tensor, x_label: Union[Tensor, int]):
        r"""Compute the mask-optimization objective.

        The objective is the MSE between ``raw_preds`` and ``x_label`` plus
        size and entropy regularizers on the sigmoid of the edge mask (and of
        the node-feature mask when ``self.mask_features`` is set).

        Args:
            raw_preds (Tensor): Model output under the current masks.
            x_label (Tensor or number): Single scalar target value; it is
                reshaped to ``(1, 1)`` before the MSE is taken.

        Returns:
            Tensor: Scalar loss.
        """
        # Build the target on the same device/dtype as the prediction rather
        # than on a notebook-global `device`, so the class works wherever the
        # model output lives.
        if isinstance(x_label, Tensor):
            target = x_label.detach().to(device=raw_preds.device, dtype=raw_preds.dtype)
        else:
            target = torch.tensor(x_label, device=raw_preds.device, dtype=raw_preds.dtype)
        loss = mse_loss(raw_preds, target.reshape(1, 1))

        # Disabled classification objective, kept for reference:
        # if self.explain_graph:
        #     loss = cross_entropy_with_logit(raw_preds, x_label)
        # else:
        #     loss = cross_entropy_with_logit(raw_preds[self.node_idx].reshape(1, -1), x_label)

        # Edge mask regularization: total size plus element-wise binary entropy.
        m = self.edge_mask.sigmoid()
        loss = loss + self.coff_edge_size * m.sum()
        ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
        loss = loss + self.coff_edge_ent * ent.mean()

        if self.mask_features:
            # Same regularizers applied to the node-feature mask.
            m = self.node_feat_mask.sigmoid()
            loss = loss + self.coff_node_feat_size * m.sum()
            ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
            loss = loss + self.coff_node_feat_ent * ent.mean()

        return loss

    def gnn_explainer_alg(self,
                          x: Tensor,
                          edge_index: Tensor,
                          edge_attr: Tensor,
                          ex_label: Tensor,
                          prediction,
                          mask_features: bool = False,
                          **kwargs
                          ) -> Tensor:
        r"""Optimize the masks for one graph and return the raw edge mask.

        Args:
            x (Tensor): Input node features.
            edge_index (Tensor): Edge index passed to the model.
            edge_attr (Tensor): Edge attributes passed to the model.
            ex_label (Tensor): Label being explained. Currently unused by the
                optimization; the loss is computed against ``prediction``.
            prediction: Scalar target value handed to :meth:`__loss__`.
            mask_features (bool, optional): Whether to also apply and
                regularize the node-feature mask. (default: :obj:`False`)
            **kwargs: Extra keyword arguments forwarded to the model call.

        Returns:
            Tensor: ``self.edge_mask.data`` after ``self.epochs`` steps
            (pre-sigmoid values).
        """
        # initialize a mask
        self.to(x.device)
        self.mask_features = mask_features

        # train to get the mask
        optimizer = torch.optim.Adam([self.node_feat_mask, self.edge_mask],
                                     lr=self.lr)

        for epoch in range(1, self.epochs + 1):

            if mask_features:
                h = x * self.node_feat_mask.view(1, -1).sigmoid()
            else:
                h = x
            raw_preds = self.model(x=h, edge_index=edge_index, edge_attr=edge_attr, **kwargs)
            loss = self.__loss__(raw_preds, prediction)
            if epoch % 20 == 0 and debug:
                print(f'Loss:{loss.item()}')

            optimizer.zero_grad()
            loss.backward()
            # NOTE(review): this clips the gradients of the model parameters,
            # while the optimizer above only updates the two masks, so the
            # clipping does not bound the mask updates. Kept as-is to avoid
            # changing numerical results; confirm whether the masks were meant
            # to be clipped instead.
            torch.nn.utils.clip_grad_value_(self.model.parameters(), clip_value=2.0)
            optimizer.step()

        return self.edge_mask.data

    def forward(self, x, edge_index, edge_attr, mask_features=False, target_label=None, prediction=None, **kwargs):
        r"""
        Run the explainer for a specific graph instance.
        Args:
            x (torch.Tensor): The graph instance's input node features.
            edge_index (torch.Tensor): The graph instance's edge index.
            edge_attr (torch.Tensor): The graph instance's edge attributes.
            mask_features (bool, optional): Whether to use feature mask. Not recommended.
                (Default: :obj:`False`)
            target_label (torch.Tensor or int, optional): if given then apply optimization only on this label
            prediction (optional): Scalar value the masked model output is
                regressed towards while the masks are optimized.
            **kwargs (dict):
                :obj:`node_idx` (int, list, tuple, torch.Tensor): The index of node that is pending to be explained.
                (for node classification; the node-level path is currently disabled)
                :obj:`sparsity` (float): The Sparsity we need to control to transform a
                soft mask to a hard mask.
                :obj:`num_classes` (int): The number of task's classes.
                :obj:`edge_masks` (list, optional): Precomputed edge masks; when given
                (and non-empty) the optimization is skipped.
        :rtype: (list, list, list)
        .. note::
            (edge_masks, hard_edge_masks, related_predictions):
            edge_masks is a list of soft edge-level explanations, one per explained label;
            hard_edge_masks are the corresponding masks after sparsity control;
            related_predictions is whatever :obj:`eval_related_pred` returns for
            the hard masks.
        """
        super().forward(x=x, edge_index=edge_index, edge_attr=edge_attr, **kwargs)
        self.model.eval()

        self_loop_edge_index, _ = add_remaining_self_loops(edge_index, num_nodes=self.num_nodes)

        # Disabled node-level path, kept for reference:
        # # Only operate on a k-hop subgraph around `node_idx`.
        # # Get subgraph and relabel the node, mapping is the relabeled given node_idx.
        # if not self.explain_graph:
        #     self.node_idx = node_idx = kwargs.get('node_idx')
        #     assert node_idx is not None, 'An node explanation needs kwarg node_idx, but got None.'
        #     if isinstance(node_idx, torch.Tensor) and not node_idx.dim():
        #         node_idx = node_idx.to(self.device).flatten()
        #     elif isinstance(node_idx, (int, list, tuple)):
        #         node_idx = torch.tensor([node_idx], device=self.device, dtype=torch.int64).flatten()
        #     else:
        #         raise TypeError(f'node_idx should be in types of int, list, tuple, '
        #                         f'or torch.Tensor, but got {type(node_idx)}')
        #     self.subset, _, _, self.hard_edge_mask = subgraph(
        #         node_idx, self.__num_hops__, self_loop_edge_index, relabel_nodes=True,
        #         num_nodes=None, flow=self.__flow__())
        #     self.new_node_idx = torch.where(self.subset == node_idx)[0]

        if kwargs.get('edge_masks'):
            edge_masks = kwargs.pop('edge_masks')
            self.__set_masks__(x, self_loop_edge_index)

        else:
            # Assume the mask we will predict
            labels = tuple(i for i in range(kwargs.get('num_classes')))
            ex_labels = tuple(torch.tensor([label]).to(self.device) for label in labels)

            # Accept the target label either as a tensor or as a plain number.
            if isinstance(target_label, Tensor):
                target_value = target_label.item()
            else:
                target_value = target_label

            # Calculate mask
            edge_masks = []
            for ex_label in ex_labels:
                if target_value is None or ex_label.item() == target_value:
                    self.__clear_masks__()
                    self.__set_masks__(x, self_loop_edge_index)
                    # Forward `mask_features`; previously the argument was
                    # accepted by forward() but silently dropped here.
                    edge_mask = self.gnn_explainer_alg(
                        x, edge_index, edge_attr, ex_label, prediction,
                        mask_features=mask_features).sigmoid()

                    #if self._symmetric_edge_mask_indirect_graph:
                    #    edge_mask = symmetric_edge_mask_indirect_graph(self_loop_edge_index, edge_mask)

                    edge_masks.append(edge_mask)

        # The original code has no .sigmoid() here. Without it an error occurs,
        # and since the other explanation methods apply it, it was added.
        hard_edge_masks = [self.control_sparsity(mask, sparsity=kwargs.get('sparsity')).sigmoid().to(self.device)
                           for mask in edge_masks]

        with torch.no_grad():
            related_preds = self.eval_related_pred(x, edge_index, edge_attr, hard_edge_masks, **kwargs)

        self.__clear_masks__()

        return edge_masks, hard_edge_masks, related_preds

    def __repr__(self):
        return f'{self.__class__.__name__}()'

In [12]:
def pipeline(data_root: str = '/data/URP',
             graph_dir: str = '/data/URP/graphs',
             model_path: str = '/data/URP/model/bestmodel.pt',
             sparsity: float = 0.1,
             epochs: int = 100,
             lr: float = 0.01):
    """Explain every graph in the influence dataset and print summary metrics.

    Loads the trained GNN, runs :class:`GNNExplainer` on each graph of the
    dataset, accumulates the results in an ``XCollector`` and prints its
    fidelity, inverse fidelity and sparsity.

    Args:
        data_root (str): First argument passed to ``influenceDataset``.
        graph_dir (str): Second argument passed to ``influenceDataset``.
        model_path (str): Path of the saved model ``state_dict``.
        sparsity (float): Sparsity passed to the explainer for the hard masks.
        epochs (int): Number of mask-optimization epochs per graph.
        lr (float): Learning rate of the mask optimization.
    """
    dataset = influenceDataset(data_root, graph_dir)

    # I may have removed add_self_loop in other code. Be careful.
    model = get_gnnNets(1, 1, {'gnn_latent_dim':[128,128,128], 'add_self_loop':False})
    # map_location keeps loading robust when the checkpoint was saved on a
    # device that is not visible to this process.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    # Switch to eval mode up front so the very first prediction is not
    # computed with train-mode layers.
    model.eval()

    gnn_explainer = GNNExplainer(model,
                                 epochs=epochs,
                                 lr=lr,
                                 explain_graph=True)
    gnn_explainer.device = device

    x_collector = XCollector()
    for data in dataset:
        data.edge_index, data.edge_attr = add_remaining_self_loops(data.edge_index, edge_attr=data.edge_attr, num_nodes=data.num_nodes)
        data = data.to(device)
        # The reference prediction is only used as a constant target/label,
        # so no autograd graph is needed for it.
        with torch.no_grad():
            pred = model(data)
        # Index of the largest output, used as the label slot for the
        # collector (index 0 when the model has a single output).
        prediction = pred.argmax(-1).item()

        _, hard_edge_masks, related_preds = \
            gnn_explainer(data.x, data.edge_index, data.edge_attr,
                          sparsity=sparsity,
                          num_classes=dataset.num_classes,
                          prediction=pred.item())
        x_collector.collect_data(hard_edge_masks, related_preds, label=prediction)

    print(f'Fidelity: {x_collector.fidelity:.4f}\n'
          f'Fidelity_inv: {x_collector.fidelity_inv: .4f}\n'
          f'Sparsity: {x_collector.sparsity:.4f}')

In [13]:
# Run the explanation pipeline over the dataset; it prints the aggregate
# fidelity, inverse-fidelity and sparsity metrics.
pipeline()

Fidelity: 228.0257
Fidelity_inv:  0.0138
Sparsity: 0.1000


In [32]:
# Test code (disabled): the snippet is wrapped in a string literal so that the
# cell does not execute it. It builds three tiny graphs (data11 and data12
# share the same features and edge attributes, data21 uses different edge
# attributes), explains each one with GNNExplainer and prints the results.
"""
from torch_geometric.data import Data
x = torch.tensor([[1], [0], [1], [1], [0]], dtype=torch.float)
edge_index = torch.tensor([[0, 1, 2, 3, 3], [1, 2, 1, 1, 4]], dtype=torch.long)
edge_attr1 = torch.tensor([[0], [0.8], [1], [0.5], [0.1]], dtype=torch.float)
edge_attr2 = torch.tensor([[1], [0.2], [0], [0.8], [0.1]], dtype=torch.float)

data11 = Data(x=x, edge_index=edge_index, edge_attr=edge_attr1)
data12 = Data(x=x, edge_index=edge_index, edge_attr=edge_attr1)
data21 = Data(x=x, edge_index=edge_index, edge_attr=edge_attr2)

model = get_gnnNets(1, 1, {'gnn_latent_dim':[128,128,128], 'add_self_loop':False})
model.load_state_dict(torch.load('../model/bestmodel.pt'))

gnngi_explainer = GNNExplainer(model, epochs=10, lr=0.01, explain_graph=True)

index = 0
for i, data in enumerate([data11,data12,data21]):
    index += 1
    data.edge_index, data.edge_attr = add_remaining_self_loops(data.edge_index, edge_attr=data.edge_attr, num_nodes=data.num_nodes)
    pred = model(data)
    walks, edge_masks, related_preds = \
        gnngi_explainer(data.x, data.edge_index, data.edge_attr,
                        sparsity=0.5,
                        num_classes=1, prediction=pred.item())
    print(index)
    print(related_preds)
    print(edge_masks)
    print()
"""
# The results differ even for identical data; this seems to be caused by the
# randomness involved when the GNN explainer trains its model for each data
# instance.

1
[{'zero': tensor([0.3027]), 'masked': tensor([360.6974]), 'maskout': tensor([360.6983]), 'origin': tensor([360.6983]), 'sparsity': tensor(0.5000)}]
[tensor([1., 0., 0., 0., 0., 1., 1., 1., 0., 1.])]
tensor([[360.6983]], grad_fn=<AddmmBackward0>)

2
[{'zero': tensor([0.3027]), 'masked': tensor([360.6974]), 'maskout': tensor([360.6982]), 'origin': tensor([360.6983]), 'sparsity': tensor(0.5000)}]
[tensor([0., 0., 1., 1., 0., 0., 0., 1., 1., 1.])]
tensor([[360.6983]], grad_fn=<AddmmBackward0>)

3
[{'zero': tensor([0.3027]), 'masked': tensor([360.6983]), 'maskout': tensor([360.6982]), 'origin': tensor([575.8438]), 'sparsity': tensor(0.5000)}]
[tensor([1., 1., 1., 1., 0., 1., 0., 0., 0., 0.])]
tensor([[575.8438]], grad_fn=<AddmmBackward0>)

