In [1]:
import os
import math

import torch
from torch import Tensor
import torch.nn as nn
from torch_geometric.utils.loop import add_remaining_self_loops
from torch_geometric.data import DataLoader

from dig.xgraph.models.utils import subgraph
from dig.xgraph.method import GNN_GI
from dig.xgraph.evaluation import XCollector

import import_ipynb
from influenceDataset import get_dataloader, influenceDataset
from gnnNets import get_gnnNets, GNNPool
from base_explainer import WalkBase

  from .autonotebook import tqdm as notebook_tqdm


importing Jupyter notebook from influenceDataset.ipynb
importing Jupyter notebook from gnnNets.ipynb
importing Jupyter notebook from base_explainer.ipynb


In [None]:
# Number GPUs by PCI bus id and make only device "3" visible to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "3",
})

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [2]:
class GNN_GI(WalkBase):
    r"""
    An implementation of GNN-GI in
    `Higher-Order Explanations of Graph Neural Networks via Relevant Walks <https://arxiv.org/abs/2006.03589>`_.

    Args:
        model (torch.nn.Module): The target model prepared to explain.
        explain_graph (bool, optional): Whether to explain graph classification model.
            (default: :obj:`False`)

    .. note:: Only graph-level explanation (:attr:`explain_graph` set to :obj:`True`) is
        supported by :meth:`forward`. The node-level set-up (sub-graph extraction around
        ``node_idx``) is not implemented in this class, so :meth:`forward` raises
        :class:`NotImplementedError` when :attr:`explain_graph` is :obj:`False`.

    """

    def __init__(self, model: nn.Module, explain_graph: bool = False):
        super().__init__(model=model, explain_graph=explain_graph)

    def forward(self,
                x: Tensor,
                edge_index: Tensor,
                edge_attr: Tensor,
                **kwargs
                ):
        r"""
        Run the explainer for a specific graph instance.

        Args:
            x (torch.Tensor): The graph instance's input node features.
            edge_index (torch.Tensor): The graph instance's edge index.
            edge_attr (torch.Tensor): The graph instance's edge attributes. They are
                forwarded unchanged to the base-class helpers.
            **kwargs (dict):
                :obj:`num_classes` (int): The number of task's classes. Required.
                :obj:`sparsity` (float): The Sparsity we need to control to transform a
                soft mask to a hard mask. When omitted, :obj:`None` is handed to
                ``control_sparsity``.
                :obj:`walks` (dict, optional): Pre-computed walks with the keys
                ``'ids'`` and ``'score'``; when given (and non-empty) the walk
                scores are not recomputed.

        :rtype: (dict, list, list)

        Raises:
            NotImplementedError: If :attr:`explain_graph` is :obj:`False`.
            TypeError: If :obj:`num_classes` is not supplied.

        .. note::
            (walks, edge_masks, related_predictions):
            walks is a dictionary including walks' edge indices and corresponding explained scores;
            edge_masks is a list of edge-level explanation for each class;
            related_predictions is a list of dictionary for each class
            as returned by ``eval_related_pred``.

        """
        # The node-level branch needs a target node index and its sub-graph, neither of
        # which is prepared by this class; fail early with a clear message instead of a
        # NameError deep inside the walk computation.
        if not self.explain_graph:
            raise NotImplementedError(
                "GNN_GI in this notebook only supports graph-level explanation; "
                "construct it with explain_graph=True.")

        num_classes = kwargs.get('num_classes')
        if num_classes is None:
            raise TypeError("forward() missing required keyword argument: 'num_classes'")

        super().forward(x, edge_index, edge_attr, **kwargs)
        self.model.eval()
        self_loop_edge_index, _ = add_remaining_self_loops(edge_index, edge_attr=edge_attr, num_nodes=self.num_nodes)

        # detach=False keeps the autograd graph alive so that gradients w.r.t. the
        # per-layer edge weights can be taken below.
        walk_steps, fc_step = self.extract_step(x, edge_index, edge_attr, detach=False)
        labels = tuple(range(num_classes))

        if kwargs.get('walks'):
            # Remove the entry so it is not forwarded again through **kwargs below.
            walks = kwargs.pop('walks')

        else:
            def compute_walk_score(adjs, r, allow_edges, walk_idx=None):
                """Recursively score every walk by chaining gradient * input over ``adjs``.

                ``walk_idx`` is the list of edge ids already fixed for the current walk.
                Finished walks are appended to the enclosing ``walk_indices`` /
                ``walk_scores`` lists.
                """
                walk_idx = [] if walk_idx is None else walk_idx
                if not adjs:
                    walk_indices.append(walk_idx)
                    walk_scores.append(r.detach())
                    return
                # create_graph=True: the result is differentiated again one level deeper.
                (grads,) = torch.autograd.grad(outputs=r, inputs=adjs[0], create_graph=True)
                for i in allow_edges:
                    # Edges that may precede edge i: those whose index-1 endpoint equals
                    # the index-0 endpoint of edge i.
                    next_allow_edges = torch.where(
                        self_loop_edge_index[1] == self_loop_edge_index[0][i])[0].tolist()
                    new_r = grads[i] * adjs[0][i]
                    compute_walk_score(adjs[1:], new_r, next_allow_edges, [i] + walk_idx)

            walk_scores_tensor_list = [None for _ in labels]
            for label in labels:
                # Graph-level output lives in row 0; every edge may end a walk.
                f = torch.unbind(fc_step['output'][0, label].unsqueeze(0))
                allow_edges = list(range(self_loop_edge_index.shape[1]))

                adjs = [walk_step['module'][0].edge_weight for walk_step in walk_steps]
                # Reverse in place so the recursion starts from the last walk step.
                adjs.reverse()
                walk_indices = []
                walk_scores = []

                compute_walk_score(adjs, f, allow_edges)
                walk_scores_tensor_list[label] = torch.stack(walk_scores, dim=0).view(-1, 1)

            # ``walk_indices`` is rebuilt for each label; the list from the last label is used.
            walks = {'ids': torch.tensor(walk_indices, device=self.device),
                     'score': torch.cat(walk_scores_tensor_list, dim=1)}

        # --- Apply edge mask evaluation ---
        with torch.no_grad():
            with self.connect_mask(self):
                ex_labels = tuple(torch.tensor([label]).to(self.device) for label in labels)
                edge_masks = []
                hard_edge_masks = []
                for ex_label in ex_labels:
                    # Use a dedicated name: rebinding ``edge_attr`` here would make the
                    # evaluation below receive explanation scores instead of the graph's
                    # real edge attributes.
                    edge_scores = self.explain_edges_with_loop(x, walks, ex_label)
                    # detach() shares storage, so the in-place fix below is also seen
                    # by ``edge_scores`` when it is passed to control_sparsity.
                    edge_mask = edge_scores.detach()
                    valid_mask = (edge_mask != - math.inf)
                    edge_mask[edge_mask == - math.inf] = edge_mask[valid_mask].min() - 1  # replace the negative inf
                    edge_masks.append(edge_mask)
                    hard_edge_masks.append(self.control_sparsity(edge_scores, kwargs.get('sparsity')).sigmoid())

                related_preds = self.eval_related_pred(x, edge_index, edge_attr, hard_edge_masks, **kwargs)

        return walks, edge_masks, related_preds

In [4]:
def pipeline():
    """Explain every graph in the dataset with GNN-GI and print the aggregated metrics.

    Loads the trained model from ``../model/bestmodel.pt``, runs the explainer on each
    graph (after adding the missing self loops), collects the edge masks and related
    predictions in an ``XCollector`` and finally prints fidelity, inverse fidelity and
    sparsity. Relies on the notebook-level ``device``.
    """
    dataset = influenceDataset('/data/URP','/data/URP/graphs')
    # TODO: only the test split should be used here; fix later.

    model = get_gnnNets(1, 1, {'gnn_latent_dim':[128,128,128], 'add_self_loop':False})
    # map_location lets a checkpoint saved on another device (e.g. a different CUDA
    # index, or CUDA when only the CPU is available) be loaded onto ``device``.
    model.load_state_dict(torch.load('../model/bestmodel.pt', map_location=device))
    model.to(device)
    # Inference mode is loop-invariant, so set it once up front.
    model.eval()

    gnngi_explainer = GNN_GI(model, explain_graph=True)

    x_collector = XCollector()
    for i, data in enumerate(dataset):
        print(i)
        data.edge_index, data.edge_attr = add_remaining_self_loops(data.edge_index, edge_attr=data.edge_attr, num_nodes=data.num_nodes)
        data = data.to(device)
        _, edge_masks, related_preds = \
            gnngi_explainer(data.x, data.edge_index, data.edge_attr,
                            sparsity=0.5,
                            num_classes=dataset.num_classes)

        # Only the predicted class index is needed, so skip building an autograd graph.
        with torch.no_grad():
            prediction = model(data).argmax(-1).item()
        x_collector.collect_data(edge_masks, related_preds, label=prediction)

    print(f'Fidelity: {x_collector.fidelity:.4f}\n'
          f'Fidelity_inv: {x_collector.fidelity_inv: .4f}\n'
          f'Sparsity: {x_collector.sparsity:.4f}')

In [5]:
# Run the explanation pipeline over the dataset and print the aggregated metrics.
# NOTE(review): the saved output of this cell ends in a KeyboardInterrupt after the
# first index was printed, so no metrics were recorded for this run.
pipeline()

0


KeyboardInterrupt: 

In [17]:
# Smoke test on three small hand-built graphs that share nodes and edges but differ in
# edge attributes (data11 and data12 both use edge_attr1, data21 uses edge_attr2).
# The code is currently disabled: it is wrapped in a string literal, so nothing runs.
"""
from torch_geometric.data import Data
x = torch.tensor([[1], [0], [1], [1], [0]], dtype=torch.float)
edge_index = torch.tensor([[0, 1, 2, 3, 3], [1, 2, 1, 1, 4]], dtype=torch.long)
edge_attr1 = torch.tensor([[0], [0.8], [1], [0.5], [0.1]], dtype=torch.float)
edge_attr2 = torch.tensor([[1], [0.2], [1], [0.8], [1]], dtype=torch.float)
data11 = Data(x=x, edge_index=edge_index, edge_attr=edge_attr1)
data12 = Data(x=x, edge_index=edge_index, edge_attr=edge_attr1)
data21 = Data(x=x, edge_index=edge_index, edge_attr=edge_attr2)

model = get_gnnNets(1, 1, {'gnn_latent_dim':[128,128,128], 'add_self_loop':False})
model.load_state_dict(torch.load('../model/bestmodel.pt'))

gnngi_explainer = GNN_GI(model, explain_graph=True)

index = 0
for i, data in enumerate([data11,data12,data21]):
    index += 1
    data.edge_index, data.edge_attr = add_remaining_self_loops(data.edge_index, edge_attr=data.edge_attr, num_nodes=data.num_nodes)
    pred = model(data)
    walks, edge_masks, related_preds = \
        gnngi_explainer(data.x, data.edge_index, data.edge_attr,
                        sparsity=0.5,
                        num_classes=1)
    print(index)
    print(related_preds)
    print(edge_masks)
    print()
"""

1
[{'zero': tensor([0.3027]), 'masked': tensor([360.6984]), 'maskout': tensor([0.3027]), 'origin': tensor([360.6984]), 'sparsity': tensor(0.5000)}]
[tensor([  0.0000,   0.0000,   0.0000,   0.0000,   0.0000, 551.1729,   0.0000,
          0.0000, 551.1729,   0.0000])]

2
[{'zero': tensor([0.3027]), 'masked': tensor([360.6984]), 'maskout': tensor([0.3027]), 'origin': tensor([360.6984]), 'sparsity': tensor(0.5000)}]
[tensor([  0.0000,   0.0000,   0.0000,   0.0000,   0.0000, 551.1729,   0.0000,
          0.0000, 551.1729,   0.0000])]

3
[{'zero': tensor([0.3027]), 'masked': tensor([360.6984]), 'maskout': tensor([10.4759]), 'origin': tensor([539.5463]), 'sparsity': tensor(0.5000)}]
[tensor([259.4229,  22.4284, 199.1229, 207.5383,   0.0000, 426.5988, 197.8501,
        276.7264, 341.2790,   0.0000])]

