In [1]:
from typing import List, Tuple, Dict
from math import sqrt
import numpy as np

import torch
from torch import Tensor
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils.loop import add_self_loops, remove_self_loops
from torch_geometric.data import Data, Batch

from dig.xgraph.models.utils import subgraph

import import_ipynb
from gnnNets import GNNPool

  from .autonotebook import tqdm as notebook_tqdm


importing Jupyter notebook from gnnNets.ipynb


In [2]:
"""
__edge_mask__ -> _edge_mask

edge_attr를 생성하도록 수정
__subgraph__() : kargs에서 자연스럽게 처리. visualize나 subgraphX에 사용. 난 사용할 일 없긴 할 듯
eval_related_pred() : 파라미터에 edge_attr추가
forward() : 파라미터에 edge_attr 추가
batch_input() : batch를 생성하는것 같은데 flowx에서만 사용된다. 수정안함.
visualize_graph() : 안쓸듯. 수정안함.
batch_input(), visualize_graph()는 일단 삭제.
"""
class ExplainerBase(nn.Module):
    """Common scaffolding for edge-mask based explainers of a GNN ``model``.

    This variant passes ``edge_attr`` to every model call and installs the
    learnable mask on the ``_explain`` / ``_edge_mask`` attributes of each
    ``MessagePassing`` layer found in ``model``.

    Args:
        model: Network to explain. Its ``MessagePassing`` sub-modules are
            collected into ``self.mp_layers``.
        epochs: Stored on the instance; not used by this base class.
        lr: Stored on the instance; not used by this base class.
        explain_graph: ``True`` for graph-level explanations, ``False`` for
            node-level ones (changes the hop count and sparsity control).
        molecule: Stored on the instance; not used by this base class.
    """

    def __init__(self, model: nn.Module, epochs: int = 0, lr: float = 0, explain_graph: bool = False,
                 molecule: bool = False):
        super().__init__()
        self.model = model
        self.lr = lr
        self.epochs = epochs
        self.explain_graph = explain_graph
        self.molecule = molecule
        self.mp_layers = [module for module in self.model.modules() if isinstance(module, MessagePassing)]
        self.num_layers = len(self.mp_layers)

        self.ori_pred = None
        self.ex_labels = None
        self.edge_mask = None
        self.hard_edge_mask = None

        # Filled in by ``forward`` from the tensors it receives.
        self.num_edges = None
        self.num_nodes = None
        self.device = None

    def __set_masks__(self, x: Tensor, edge_index: Tensor, init="normal"):
        """Create random feature/edge masks and attach the edge mask to every message-passing layer.

        Args:
            x: Node feature matrix; its two dimensions give the node and feature counts.
            edge_index: Edge index tensor; its second dimension gives the edge count.
            init: Accepted for interface compatibility; not used.
        """
        (N, F), E = x.size(), edge_index.size(1)

        self.node_feat_mask = torch.nn.Parameter(torch.randn(F, requires_grad=True, device=self.device) * 0.1)

        std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
        self.edge_mask = torch.nn.Parameter(torch.randn(E, requires_grad=True, device=self.device) * std)

        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                module._explain = True
                module._edge_mask = self.edge_mask

    def __clear_masks__(self):
        """Detach the masks from all message-passing layers and drop them from the explainer."""
        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                module._explain = False
                module._edge_mask = None
        # ``__set_masks__`` creates ``node_feat_mask`` (singular); reset that one so
        # no stale parameter survives. The misspelt attribute is still assigned
        # because earlier versions exposed it.
        self.node_feat_mask = None
        self.node_feat_masks = None
        self.edge_mask = None

    @property
    def __num_hops__(self):
        """Hop count for subgraph extraction: ``-1`` for graph-level, else the number of MP layers."""
        if self.explain_graph:
            return -1
        return self.num_layers

    def __flow__(self):
        """Return the ``flow`` of the first message-passing layer, or ``'source_to_target'`` if there is none."""
        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                return module.flow
        return 'source_to_target'

    def __subgraph__(self, node_idx: int, x: Tensor, edge_index: Tensor, **kwargs):
        """Extract the subgraph around ``node_idx`` and slice the accompanying tensors accordingly.

        Args:
            node_idx: Index of the centre node.
            x: Node feature matrix.
            edge_index: Edge index tensor.
            **kwargs: Extra tensors; must contain ``edge_attr``. A tensor whose first
                dimension equals the node count is indexed by the kept nodes, and one
                whose first dimension equals the edge count by the kept-edge mask.
                The node-count check takes precedence when both counts coincide.

        Returns:
            ``(x, edge_index, mapping, edge_mask, kwargs)`` restricted to the subgraph.

        Raises:
            AssertionError: If ``edge_attr`` is missing from ``kwargs``.
        """
        assert 'edge_attr' in kwargs, "edge_attr must be passed as a keyword argument"

        num_nodes, num_edges = x.size(0), edge_index.size(1)

        subset, edge_index, mapping, edge_mask = subgraph(
            node_idx, self.__num_hops__, edge_index, relabel_nodes=True,
            num_nodes=num_nodes, flow=self.__flow__())

        x = x[subset]
        for key, item in kwargs.items():
            if torch.is_tensor(item) and item.size(0) == num_nodes:
                item = item[subset]
            elif torch.is_tensor(item) and item.size(0) == num_edges:
                item = item[edge_mask]
            kwargs[key] = item

        return x, edge_index, mapping, edge_mask, kwargs

    def forward(self,
                x: Tensor,
                edge_index: Tensor,
                edge_attr: Tensor,
                **kwargs
                ):
        """Record edge count, node count and device of the input graph.

        ``edge_attr`` and ``kwargs`` are accepted but not used here. Returns ``None``.
        """
        self.num_edges = edge_index.shape[1]
        self.num_nodes = x.shape[0]
        self.device = x.device

    def control_sparsity(self, mask: Tensor, sparsity=None, **kwargs):
        r"""Binarise ``mask`` so that only the highest-scoring entries stay active.

        :param mask: mask that needs to be transformed (not modified in place)
        :param sparsity: fraction of entries to suppress, e.g. 0.7 or 0.5;
            defaults to 0.7 when ``None``
        :return: a copy of ``mask`` where the top ``1 - sparsity`` share of entries
            is set to ``+inf`` and every other entry to ``-inf``. For node-level
            explanations only entries selected by ``self.hard_edge_mask`` are
            ranked; entries outside it are always ``-inf``.
        :raises AssertionError: node-level mode without ``self.hard_edge_mask``
        """
        if sparsity is None:
            sparsity = 0.7

        trans_mask = mask.clone()

        if not self.explain_graph:
            assert self.hard_edge_mask is not None, "hard_edge_mask must be set for node-level explanations"
            mask_indices = torch.where(self.hard_edge_mask)[0]
            sub_mask = mask[self.hard_edge_mask]
            _, sub_indices = torch.sort(sub_mask, descending=True)
            split_point = int((1 - sparsity) * sub_mask.shape[0])
            # Map positions within the sub-mask back to positions in the full mask.
            important_indices = mask_indices[sub_indices[:split_point]]
            trans_mask[:] = - float('inf')
            trans_mask[important_indices] = float('inf')
        else:
            _, indices = torch.sort(mask, descending=True)
            split_point = int((1 - sparsity) * mask.shape[0])
            trans_mask[indices[:split_point]] = float('inf')
            trans_mask[indices[split_point:]] = - float('inf')

        return trans_mask

    def eval_related_pred(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor, edge_masks: List[Tensor], **kwargs):
        """Run the model under several mask settings for each candidate edge mask.

        For every mask in ``edge_masks`` the model is evaluated with an all-ones
        mask (``origin``), the mask itself (``masked``), its complement
        ``1 - mask`` (``maskout``) and an all-zeros mask (``zero``).

        ``self.edge_mask`` must already be a parameter (see ``__set_masks__``);
        only its ``.data`` is swapped so the layers keep referring to the same
        object. All masks are cleared before returning.

        Args:
            x: Node feature matrix.
            edge_index: Edge index tensor.
            edge_attr: Edge attributes, passed to the model.
            edge_masks: One edge mask per explained label.
            **kwargs: Passed to the model unchanged. ``node_idx`` (if present and
                not ``None``) selects the output row; otherwise row 0 is used.

        Returns:
            A list with one dict per mask holding the raw model outputs under the
            keys ``zero``, ``masked``, ``maskout``, ``origin`` plus ``sparsity``
            (the fraction of zero entries in the mask, restricted to
            ``self.hard_edge_mask`` when that is set).
        """
        node_idx = kwargs.get('node_idx')
        node_idx = 0 if node_idx is None else node_idx  # graph level: 0, node level: node_idx
        related_preds = []

        for edge_mask in edge_masks:
            if self.hard_edge_mask is not None:
                considered = edge_mask[self.hard_edge_mask]
            else:
                considered = edge_mask
            sparsity = 1.0 - (considered != 0).sum() / considered.size(0)

            self.edge_mask.data = torch.ones(edge_mask.size(), device=self.device)
            ori_pred = self.model(x=x, edge_index=edge_index, edge_attr=edge_attr, **kwargs)

            self.edge_mask.data = edge_mask
            masked_pred = self.model(x=x, edge_index=edge_index, edge_attr=edge_attr, **kwargs)

            # mask out important elements for fidelity calculation
            self.edge_mask.data = 1.0 - edge_mask  # keep Parameter's id
            maskout_pred = self.model(x=x, edge_index=edge_index, edge_attr=edge_attr, **kwargs)

            # zero_mask
            self.edge_mask.data = torch.zeros(edge_mask.size(), device=self.device)
            zero_mask_pred = self.model(x=x, edge_index=edge_index, edge_attr=edge_attr, **kwargs)

            # Raw outputs are stored; no softmax / label selection is applied here.
            related_preds.append({'zero': zero_mask_pred[node_idx],
                                  'masked': masked_pred[node_idx],
                                  'maskout': maskout_pred[node_idx],
                                  'origin': ori_pred[node_idx],
                                  'sparsity': sparsity})

        self.__clear_masks__()
        return related_preds

In [3]:
"""
edge_attr를 사용하도록 수정
extract_step() : argument에 edge_attr 추가. model 호출에 edge_attr 추가
eval_related_pred() : 파라미터에 edge_attr 추가

"""
class WalkBase(ExplainerBase):
    """Explainer base for walk-based methods, with ``edge_attr`` passed to the model.

    Adds layer-wise step extraction via forward hooks, enumeration of edge
    walks, aggregation of walk scores into edge scores, and two context
    managers that install per-layer masks on the message-passing layers.
    """

    def __init__(self, model: nn.Module, epochs: int = 0, lr: float = 0, explain_graph: bool = False, molecule: bool = False):
        super().__init__(model, epochs, lr, explain_graph, molecule)

    def extract_step(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor, detach: bool = True, split_fc: bool = False):
        """Run the model once and group the hooked modules into walk steps and FC steps.

        A forward hook is registered on every leaf module and on every
        ``MessagePassing`` module. Each recorded call becomes part of a step
        dict ``{'input': ..., 'module': [...], 'output': ...}``. A new step
        starts at each ``MessagePassing`` or ``GNNPool`` module; with
        ``split_fc`` a new step also starts at each ``nn.Linear`` after the
        pooling layer.

        Args:
            x: Node feature matrix.
            edge_index: Edge index tensor.
            edge_attr: Edge attributes, passed to the model.
            detach: If ``True`` store detached clones of inputs/outputs,
                otherwise the live tensors.
            split_fc: If ``True`` return the post-pooling part as a list of
                steps instead of a single step.

        Returns:
            ``(walk_steps, fc_steps)`` when ``split_fc`` is ``True`` (``fc_steps``
            is a list), otherwise ``(walk_steps, fc_step)`` (a single step dict).
        """
        layer_extractor = []
        hooks = []

        def register_hook(module: nn.Module):
            if not list(module.children()) or isinstance(module, MessagePassing):
                hooks.append(module.register_forward_hook(forward_hook))

        def forward_hook(module: nn.Module, inputs: Tuple[Tensor], output: Tensor):
            # inputs contains x and edge_index; only the first positional input is kept
            if detach:
                layer_extractor.append((module, inputs[0].clone().detach(), output.clone().detach()))
            else:
                layer_extractor.append((module, inputs[0], output))

        # --- register hooks ---
        self.model.apply(register_hook)

        # The prediction itself is not needed; the call only triggers the hooks.
        self.model(x, edge_index, edge_attr)

        for hook in hooks:
            hook.remove()

        # --- divide layer sets ---

        walk_steps = []
        fc_steps = []
        pool_flag = False
        step = {'input': None, 'module': [], 'output': None}
        for layer in layer_extractor:
            if isinstance(layer[0], (MessagePassing, GNNPool)):
                if isinstance(layer[0], GNNPool):
                    pool_flag = True
                if step['module'] and step['input'] is not None:
                    walk_steps.append(step)
                step = {'input': layer[1], 'module': [], 'output': None}
            if pool_flag and split_fc and isinstance(layer[0], nn.Linear):
                if step['module']:
                    fc_steps.append(step)
                step = {'input': layer[1], 'module': [], 'output': None}
            step['module'].append(layer[0])
            step['output'] = layer[2]

        for walk_step in walk_steps:
            if hasattr(walk_step['module'][0], 'nn') and walk_step['module'][0].nn is not None:
                # We don't allow any outside nn during message flow process in GINs
                walk_step['module'] = [walk_step['module'][0]]

        if split_fc:
            if step['module']:
                fc_steps.append(step)
            return walk_steps, fc_steps

        # Without splitting, whatever follows the last walk step is one FC step.
        return walk_steps, step

    def walks_pick(self,
                   edge_index: Tensor,
                   pick_edge_indices: List,
                   walk_indices: List = None,
                   num_layers=0
                   ):
        """Enumerate edge walks of length ``num_layers`` starting from the given edges.

        Args:
            edge_index: Edge index tensor with two rows.
            pick_edge_indices: Edge indices allowed as the next edge of the walk.
            walk_indices: Prefix of the walk built so far. It is extended and
                restored in place during the search. ``None`` (the default)
                starts from an empty walk.
            num_layers: Walk length at which a walk is emitted.

        Returns:
            A list of walks, each a list of edge indices.
        """
        # A fresh list per call: a mutable default would be shared between
        # calls and left dirty if a call raised.
        if walk_indices is None:
            walk_indices = []

        walk_indices_list = []
        for edge_idx in pick_edge_indices:

            # Adding one edge
            walk_indices.append(edge_idx)
            # The second row entry of the picked edge is where the next edge must start.
            _, new_src = edge_index[:, edge_idx]
            # ``.cpu()`` keeps this working when ``edge_index`` lives on an accelerator.
            next_edge_indices = np.array((edge_index[0, :] == new_src).nonzero().view(-1).cpu())

            # Finding next edge
            if len(walk_indices) >= num_layers:
                # return one walk
                walk_indices_list.append(walk_indices.copy())
            else:
                walk_indices_list += self.walks_pick(edge_index, next_edge_indices, walk_indices, num_layers)

            # remove the last edge
            walk_indices.pop(-1)

        return walk_indices_list

    def eval_related_pred(self, x, edge_index, edge_attr, masks, **kwargs):
        """Run the model under several mask settings for each candidate edge mask.

        ``self.edge_mask`` is iterated as a collection of per-layer parameters
        (as installed by ``connect_mask``); every layer receives the same data.
        For each mask in ``masks`` the model is evaluated with all-ones
        (``origin``), the mask (``masked``), ``1 - mask`` (``maskout``) and
        all-zeros (``zero``).

        Args:
            x: Node feature matrix.
            edge_index: Edge index tensor.
            edge_attr: Edge attributes, passed to the model.
            masks: One edge mask per explained label.
            **kwargs: Passed to the model unchanged. ``node_idx`` (if present and
                not ``None``) selects the output row; otherwise row 0 is used.

        Returns:
            A list with one dict per mask holding the raw model outputs under the
            keys ``zero``, ``masked``, ``maskout``, ``origin`` plus ``sparsity``.
        """
        node_idx = kwargs.get('node_idx')
        node_idx = 0 if node_idx is None else node_idx  # graph level: 0, node level: node_idx

        related_preds = []

        for edge_mask in masks:
            if self.hard_edge_mask is not None:
                considered = edge_mask[self.hard_edge_mask]
            else:
                considered = edge_mask
            sparsity = 1.0 - (considered != 0).sum() / considered.size(0)

            # origin pred
            for mask in self.edge_mask:
                mask.data = torch.ones(edge_mask.size(), device=self.device)
            ori_pred = self.model(x=x, edge_index=edge_index, edge_attr=edge_attr, **kwargs)

            for mask in self.edge_mask:
                mask.data = edge_mask
            masked_pred = self.model(x=x, edge_index=edge_index, edge_attr=edge_attr, **kwargs)

            # mask out important elements for fidelity calculation
            for mask in self.edge_mask:
                mask.data = 1.0 - edge_mask
            maskout_pred = self.model(x=x, edge_index=edge_index, edge_attr=edge_attr, **kwargs)

            # zero_mask
            for mask in self.edge_mask:
                mask.data = torch.zeros(edge_mask.size(), device=self.device)
            zero_mask_pred = self.model(x=x, edge_index=edge_index, edge_attr=edge_attr, **kwargs)

            # Store related predictions for further evaluation (raw outputs,
            # no softmax / label selection is applied here).
            related_preds.append({'zero': zero_mask_pred[node_idx],
                                  'masked': masked_pred[node_idx],
                                  'maskout': maskout_pred[node_idx],
                                  'origin': ori_pred[node_idx],
                                  'sparsity': sparsity})

        return related_preds

    def explain_edges_with_loop(self, x: Tensor, walks: Dict[str, Tensor], ex_label):
        """Aggregate walk scores into one score per edge id.

        Args:
            x: Unused; kept for interface compatibility.
            walks: Mapping with ``'ids'`` (edge ids per walk, one row per walk)
                and ``'score'`` (walk scores, indexed ``[walk, label]``).
            ex_label: Column of ``walks['score']`` to aggregate.

        Returns:
            A 1-D tensor with, per edge id, the sum of the scores of the walks
            containing it (weighted by how often it occurs in each walk). Ids
            that appear in no walk get ``-inf``. The tensor covers
            ``num_edges`` ids, or ``num_edges + num_nodes`` ids when the walks
            reference ids beyond ``num_edges - 1``.
        """
        walks_ids = walks['ids']
        walks_score = walks['score'][:walks_ids.shape[0], ex_label].reshape(-1)

        if walks_ids.max() <= self.num_edges - 1:  # num_edges includes the self-loop
            num_ids = self.num_edges
        else:
            num_ids = self.num_edges + self.num_nodes

        # Row i holds, per walk, how many times edge id i occurs in that walk.
        idx_ensemble = torch.cat([(walks_ids == i).int().sum(dim=1).unsqueeze(0) for i in range(num_ids)], dim=0)

        # 1 where the edge is on at least one walk, else 0; used to look up the
        # penalty: inf for untouched edges, 0 for touched ones.
        hard_edge_attr_mask = (idx_ensemble.sum(1) > 0).long()
        hard_edge_attr_mask_value = torch.tensor([float('inf'), 0], dtype=torch.float, device=self.device)[hard_edge_attr_mask]
        edge_attr = (idx_ensemble * (walks_score.unsqueeze(0))).sum(1)

        return edge_attr - hard_edge_attr_mask_value

    class connect_mask(object):
        """Context manager that installs one random edge-mask parameter per message-passing layer.

        On entry ``cls.edge_mask`` becomes a list of ``cls.num_layers``
        parameters, each of length
        ``batch * (cls.num_edges + cls.num_nodes)`` where ``batch`` is
        ``cls.x_batch_size`` if that attribute exists, else 1. On exit only
        the ``_explain`` flag of the layers is switched off.
        """

        def __init__(self, cls):
            self.cls = cls

        def __enter__(self):
            owner = self.cls
            batch_size = owner.x_batch_size if hasattr(owner, 'x_batch_size') else 1
            mask_len = batch_size * (owner.num_edges + owner.num_nodes)

            owner.edge_mask = [nn.Parameter(torch.randn(mask_len)) for _ in range(owner.num_layers)]

            for idx, module in enumerate(owner.mp_layers):
                module._explain = True
                module._edge_mask = owner.edge_mask[idx]

        def __exit__(self, *args):
            for module in self.cls.mp_layers:
                module._explain = False

    class temp_mask(object):
        """Context manager that sets ``layer_edge_mask`` on each message-passing layer.

        On entry layer ``i`` receives ``temp_edge_mask[i]`` and has
        ``__explain_flow__`` set to ``True``; on exit the flag is reset to
        ``False`` (the mask attribute is left in place).
        """

        def __init__(self, cls, temp_edge_mask):
            self.cls = cls
            self.temp_edge_mask = temp_edge_mask

        def __enter__(self):
            for idx, module in enumerate(self.cls.mp_layers):
                module.__explain_flow__ = True
                module.layer_edge_mask = self.temp_edge_mask[idx]

        def __exit__(self, *args):
            for module in self.cls.mp_layers:
                module.__explain_flow__ = False