In [1]:
import os
from math import sqrt
from typing import Any, Callable, List, Tuple, Union, Dict

import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.data import DataLoader
from torch_geometric.utils.loop import add_remaining_self_loops

from dig.xgraph.models.utils import subgraph
from dig.xgraph.evaluation import XCollector
from dig.xgraph.models.ext.deeplift.layer_deep_lift import DeepLift

import import_ipynb
from influenceDataset import get_dataloader, influenceDataset
from gnnNets import get_gnnNets
from base_explainer import WalkBase

importing Jupyter notebook from influenceDataset.ipynb
importing Jupyter notebook from gnnNets.ipynb
importing Jupyter notebook from base_explainer.ipynb


In [2]:
# Pin this notebook to a single physical GPU (numbered in PCI bus order).
# These variables only take effect if set before CUDA is initialised in this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "3",
})

# Fall back to the CPU when no visible GPU is available.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(device)

cuda


In [3]:
"""
TODO: This also adds self-loops. Why?

I modified DeepLIFT to use edge_attr without knowing exactly how DeepLIFT works; need to verify that the modification is correct.
"""
class DeepLIFT(WalkBase):
    r"""
    An implementation of DeepLIFT on graph in
    `Learning Important Features Through Propagating Activation Differences <https://arxiv.org/abs/1704.02685>`_.

    This variant forwards ``edge_attr`` to the wrapped model, which is called as
    ``model(x, edge_index, edge_attr, batch)``.

    Args:
        model (torch.nn.Module): The target model prepared to explain.
        explain_graph (bool, optional): Whether to explain graph classification model.
            (default: :obj:`False`)

    .. note:: For node classification model, the :attr:`explain_graph` flag is False.
        For an example, see `benchmarks/xgraph
        <https://github.com/divelab/DIG/tree/dig/benchmarks/xgraph>`_.
    """

    def __init__(self, model: nn.Module, explain_graph: bool = False):
        super().__init__(model=model, explain_graph=explain_graph)

    def forward(self,
                x: Tensor,
                edge_index: Tensor,
                edge_attr: Tensor,
                **kwargs
                ):
        r"""
        Run the explainer for a specific graph instance.

        Args:
            x (torch.Tensor): The graph instance's input node features.
            edge_index (torch.Tensor): The graph instance's edge index.
            edge_attr (torch.Tensor): The graph instance's edge attributes. They are
                duplicated for the all-zero reference copy of the graph and passed
                to the model.
            **kwargs (dict):
                :obj:`num_classes` (int, required): number of output classes to explain.
                :obj:`sparsity` (float): passed to :meth:`control_sparsity` to turn a
                soft mask into a hard mask.
                :obj:`node_idx` (int or torch.Tensor): node to explain; required when
                :attr:`explain_graph` is False.
                :obj:`edge_masks` (list, optional): precomputed soft edge masks; when
                non-empty, the gradient-based attribution loop is skipped and only the
                hard masks are derived from them.

        :rtype: (list, list, list)

        .. note::
            (edge_masks, hard_edge_masks, related_predictions):
            edge_masks holds one soft edge-level explanation per class;
            hard_edge_masks holds the sigmoid of each sparsity-controlled mask;
            related_predictions is whatever :meth:`eval_related_pred` returns
            (presumably one dictionary of predictions per class).
        """
        # Base-class setup. self.num_nodes and self.device, used below, are
        # presumably initialised by WalkBase.forward.
        super().forward(x=x, edge_index=edge_index, edge_attr=edge_attr, **kwargs)
        self.model.eval()

        # Edge scores are computed over edge_index with any missing self-loops
        # appended. NOTE(review): the resulting masks line up with edge_index only
        # when edge_index already holds a self-loop for every node, placed after
        # the other edges. `pipeline` guarantees this by calling
        # add_remaining_self_loops first.
        self_loop_edge_index, _ = add_remaining_self_loops(edge_index, num_nodes=self.num_nodes)

        num_nodes = x.shape[0]

        # Node-level explanations: rows of `out` for the target node in the real
        # graph and in the reference copy.
        # NOTE(review): DIG's original version also restricted this case to a
        # k-hop subgraph; that step was disabled in this notebook and is not
        # reproduced here.
        target_rows = None
        if not self.explain_graph:
            node_idx = kwargs.get('node_idx')
            assert node_idx is not None, 'node_idx is required when explain_graph=False'
            node_idx = torch.as_tensor(node_idx, device=self.device).reshape(-1)
            target_rows = torch.cat([node_idx, node_idx + num_nodes])

        ex_labels = tuple(torch.tensor([label]).to(self.device)
                          for label in range(kwargs.get('num_classes')))

        # Always remove the key so it is never forwarded to eval_related_pred below.
        edge_masks = kwargs.pop('edge_masks', None)

        # DeepLIFT hooks are registered on self.model. They must be removed even if
        # attribution fails; otherwise they stay attached to the model and could
        # alter later forward passes.
        shap = DeepLift(self.model)
        self.model.apply(shap._register_hooks)
        try:
            # Batch the real graph (batch id 0) with an all-zero reference copy
            # (batch id 1). The copy's nodes are shifted by num_nodes.
            inp_with_ref = torch.cat(
                [x, torch.zeros(x.shape, device=self.device, dtype=torch.float)], dim=0
            ).requires_grad_(True)
            edge_index_with_ref = torch.cat([edge_index, edge_index + num_nodes], dim=1)
            edge_attr_with_ref = torch.cat([edge_attr, edge_attr])
            batch = torch.arange(2, dtype=torch.long, device=self.device).repeat_interleave(num_nodes)
            out = self.model(inp_with_ref, edge_index_with_ref, edge_attr_with_ref, batch)

            if edge_masks:
                hard_edge_masks = [self.control_sparsity(mask, kwargs.get('sparsity')).sigmoid()
                                   for mask in edge_masks]
            else:
                edge_masks = []
                hard_edge_masks = []
                # Input minus reference is the same for every class; compute it once.
                inp, inp_ref = torch.chunk(inp_with_ref, 2)
                delta = inp - inp_ref
                for ex_label in ex_labels:
                    if self.explain_graph:
                        f = torch.unbind(out[:, ex_label])
                    else:
                        f = torch.unbind(out[target_rows, ex_label])

                    (m, ) = torch.autograd.grad(outputs=f, inputs=inp_with_ref, retain_graph=True)
                    # Per-node attribution: gradient on the real half times
                    # (input - reference), summed over features. reshape(-1) rather
                    # than squeeze() keeps a 1-D tensor for single-node graphs.
                    node_attr = (torch.chunk(m, 2)[0] * delta).sum(1).reshape(-1)

                    # Edge score = mean attribution of the edge's two endpoint nodes.
                    score_mask = (node_attr[self_loop_edge_index[0]]
                                  + node_attr[self_loop_edge_index[1]]) / 2
                    edge_masks.append(score_mask.detach())
                    hard_mask = self.control_sparsity(score_mask, kwargs.get('sparsity')).sigmoid()
                    hard_edge_masks.append(hard_mask.detach())
        finally:
            shap._remove_hooks()

        # Store related predictions for further evaluation.
        with torch.no_grad():
            with self.connect_mask(self):
                related_preds = self.eval_related_pred(x, edge_index, edge_attr, hard_edge_masks, **kwargs)
        return edge_masks, hard_edge_masks, related_preds

In [4]:
def pipeline(dataset_root: str = '/data/URP',
             graph_dir: str = '/data/URP/graphs',
             model_path: str = '/data/URP/model/bestmodel.pt',
             sparsity: float = 0.1):
    """Explain every graph in the influence dataset with DeepLIFT and print metrics.

    For each graph:
    - self-loops are added;
    - DeepLIFT edge masks are computed at ``sparsity``;
    - the masks and related predictions are fed to an XCollector, together with
      the model's argmax prediction as the label.

    Fidelity, Fidelity_inv and Sparsity are printed at the end.

    Args:
        dataset_root: first argument passed to ``influenceDataset``.
        graph_dir: second argument passed to ``influenceDataset``.
        model_path: path of the saved ``state_dict`` for the GNN.
        sparsity: sparsity level handed to the explainer.
    """
    dataset = influenceDataset(dataset_root, graph_dir)

    model = get_gnnNets(1, 1, {'gnn_latent_dim': [128, 128, 128], 'add_self_loop': False})
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine too.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)

    deep_lift = DeepLIFT(model, explain_graph=True)

    x_collector = XCollector()
    for data in dataset:
        # The model is built with add_self_loop=False, presumably why self-loops
        # (with matching edge_attr rows) are added here instead.
        data.edge_index, data.edge_attr = add_remaining_self_loops(
            data.edge_index, edge_attr=data.edge_attr, num_nodes=data.num_nodes)
        data = data.to(device)

        edge_masks, hard_edge_masks, related_preds = deep_lift(
            data.x, data.edge_index, data.edge_attr,
            sparsity=sparsity,
            num_classes=dataset.num_classes)
        edge_masks = [edge_mask.to('cpu') for edge_mask in edge_masks]

        # Re-assert eval mode before predicting, in case the explainer changed it.
        model.eval()
        with torch.no_grad():
            prediction = model(data).argmax(-1).item()
        x_collector.collect_data(edge_masks, related_preds, label=prediction)

    print(f'Fidelity: {x_collector.fidelity:.4f}\n'
          f'Fidelity_inv: {x_collector.fidelity_inv:.4f}\n'
          f'Sparsity: {x_collector.sparsity:.4f}')

In [5]:
# Run the DeepLIFT explanation + evaluation loop over the whole dataset and print the metrics.
pipeline()

Fidelity: 3.8500
Fidelity_inv:  0.0173
Sparsity: 0.1000


In [4]:
# Scratch test (disabled: the code is wrapped in a string literal, so it never runs).
# Builds three 5-node graphs that share x and edge_index but differ in edge_attr
# (data11 and data12 both use edge_attr1; data21 uses edge_attr2). For each graph it
# adds self-loops and prints DeepLIFT's related predictions and masks, apparently to
# check that the explanation reacts to edge_attr.
# NOTE(review): DeepLIFT.forward returns (edge_masks, hard_edge_masks, related_preds),
# so here `walks` receives the soft masks and `edge_masks` receives the hard masks.
# NOTE(review): `i` and `pred` are unused.
"""
from torch_geometric.data import Data
x = torch.tensor([[1], [0], [1], [1], [0]], dtype=torch.float)
edge_index = torch.tensor([[0, 1, 2, 3, 3], [1, 2, 1, 1, 4]], dtype=torch.long)
edge_attr1 = torch.tensor([[0], [0.8], [1], [0.5], [0.1]], dtype=torch.float)
edge_attr2 = torch.tensor([[1], [0.2], [1], [0.8], [1]], dtype=torch.float)
data11 = Data(x=x, edge_index=edge_index, edge_attr=edge_attr1)
data12 = Data(x=x, edge_index=edge_index, edge_attr=edge_attr1)
data21 = Data(x=x, edge_index=edge_index, edge_attr=edge_attr2)

model = get_gnnNets(1, 1, {'gnn_latent_dim':[128,128,128], 'add_self_loop':False})
model.load_state_dict(torch.load('../model/bestmodel.pt'))

gnngi_explainer = DeepLIFT(model, explain_graph=True)

index = 0
for i, data in enumerate([data11,data12,data21]):
    index += 1
    data.edge_index, data.edge_attr = add_remaining_self_loops(data.edge_index, edge_attr=data.edge_attr, num_nodes=data.num_nodes)
    pred = model(data)
    walks, edge_masks, related_preds = \
        gnngi_explainer(data.x, data.edge_index, data.edge_attr,
                        sparsity=0.5,
                        num_classes=1)
    print(index)
    print(related_preds)
    print(edge_masks)
    print()
"""

1
[{'zero': tensor([0.3027]), 'masked': tensor([360.6983]), 'maskout': tensor([156.1122]), 'origin': tensor([360.6983]), 'sparsity': tensor(0.5000)}]
[tensor([1., 0., 0., 1., 1., 1., 0., 0., 1., 0.])]

2
[{'zero': tensor([0.3027]), 'masked': tensor([360.6983]), 'maskout': tensor([156.1122]), 'origin': tensor([360.6983]), 'sparsity': tensor(0.5000)}]
[tensor([1., 0., 0., 1., 1., 1., 0., 0., 1., 0.])]

3
[{'zero': tensor([0.3027]), 'masked': tensor([360.6983]), 'maskout': tensor([65.3457]), 'origin': tensor([631.2159]), 'sparsity': tensor(0.5000)}]
[tensor([1., 0., 0., 1., 0., 1., 0., 1., 1., 0.])]

