In [None]:
import os
import numpy as np
import time

import torch
from torch_geometric.data import Data
from torch_geometric.explain import Explainer, CaptumExplainer

import import_ipynb
from dataset import get_data
from constants import *
from utils import *

In [None]:
def RumorGuard_G(adj_list, seed_idx, prob, del_edge_num, model_name=None, gnn_latent_dim=[128,128,128,128,128,128], gpu_num='cpu', **kwargs):
    """Greedily choose edges whose removal minimises the model's prediction.

    In each of up to ``del_edge_num`` rounds every edge whose first attribute
    is non-zero is temporarily set to 0, the model is evaluated, and the edge
    giving the smallest ``model(data)[0]`` is permanently zeroed.

    Args:
        adj_list: Graph passed to ``adj2Data`` (format defined by that helper).
        seed_idx: Seed nodes passed to ``adj2Data``.
        prob: Unused here; kept so the signature matches the other
            ``RumorGuard_*`` functions.
        del_edge_num: Maximum number of edges to select.
        model_name: Forwarded to ``load_model``.
        gnn_latent_dim: Forwarded to ``load_model``.
        gpu_num: Forwarded to ``set_gpu``.
        **kwargs: Ignored.

    Returns:
        ``(mask, time_alg)`` where ``mask`` is a list of ``[u, v, p]`` entries
        (``u``/``v`` are the 0-dim tensors read from ``edge_index``, ``p`` the
        edge's original first attribute) in selection order, and ``time_alg``
        holds the elapsed seconds since the start after each selection.
        Both lists are shorter than ``del_edge_num`` if the graph runs out of
        candidate edges.
    """
    time_start = time.time()
    time_alg = []

    device = set_gpu(gpu_num)
    model = load_model(model_name, device, gnn_latent_dim=gnn_latent_dim)

    # convert adj_list to Data
    data = adj2Data(adj_list, seed_idx)
    data = data.to(device)
    edge_index = data.edge_index
    edge_attr = data.edge_attr

    mask = []
    # Only forward evaluations are needed, so autograd bookkeeping is skipped.
    with torch.no_grad():
        for _ in range(del_edge_num):
            pred_min = float('inf')
            idx_min = None
            for i in range(len(edge_attr)):
                p = edge_attr[i][0].item()
                if p == 0:
                    # Already removed (or never active): not a candidate.
                    continue
                # Probe: zero the edge in place, evaluate, then restore it.
                edge_attr[i][0] = 0
                pred = model(data)[0].item()
                edge_attr[i][0] = p
                if pred < pred_min:
                    pred_min = pred
                    idx_min = (edge_index[0, i], edge_index[1, i], i, p)

            if idx_min is None:
                # No candidate edge left (or no finite prediction); previously
                # this fell through and crashed on ``idx_min[0]``.
                break

            mask.append([idx_min[0], idx_min[1], idx_min[3]])
            edge_attr[idx_min[2]] = 0

            time_alg.append(time.time() - time_start)

    return mask, time_alg

In [None]:
# RumorGuard_O is implemented based on GNNExplainer of PyG

from math import sqrt
from typing import Optional, Tuple, Union
import numpy as np

import torch
from torch import Tensor
from torch.nn.parameter import Parameter
import torch.optim as optim

from torch_geometric.explain import ExplainerConfig, Explanation, ModelConfig
from torch_geometric.explain.algorithm import ExplainerAlgorithm
from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
from torch_geometric.explain.config import MaskType, ModelMode, ModelTaskLevel


class RumorGuard_O_class(ExplainerAlgorithm):
    """Optimisation-based edge scorer structured like PyG's GNNExplainer.

    One learnable logit per edge is installed on the model through
    ``set_masks`` and trained with Adam to minimise the loss in ``_loss``:
    a normalised prediction-reduction term, a quadratic penalty that pulls
    the number of "switched-off" edges towards ``del_edge_num``, and an
    entropy term on the sigmoid mask.

    The hard node/edge masks are never populated by this class (they stay
    ``None``), so every edge takes part in the regularisation terms.
    """

    def __init__(self, epochs=100, lr=0.01, del_edge_num=10, size_coeff=1, ent_coeff=1, lr_gamma=1, log_filename=None, **kwargs):
        """Store the optimisation hyper-parameters.

        Args:
            epochs: Number of optimisation steps.
            lr: Initial Adam learning rate.
            del_edge_num: Target number of removed edges used by the size term.
            size_coeff: Weight of the size penalty.
            ent_coeff: Weight of the mask-entropy penalty.
            lr_gamma: Factor the learning rate is multiplied by after each step.
            log_filename: If truthy, per-epoch statistics are written to this
                path and the final sigmoid edge mask to a sibling ``.npy`` file.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.epochs = epochs
        self.lr = lr
        self.eps = 1e-15  # keeps log() finite in the entropy term
        self.del_edge_num = del_edge_num

        self.node_mask = self.hard_node_mask = None
        self.edge_mask = self.hard_edge_mask = None

        self.size_coeff = size_coeff
        self.ent_coeff = ent_coeff
        self.lr_gamma = lr_gamma
        self.log_filename = log_filename

    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Explanation:
        """Train the masks and return them as an ``Explanation``.

        The returned ``edge_mask`` is the *negated* sigmoid mask, so a
        descending sort (as done in ``RumorGuard_O``) ranks the edges the
        optimiser suppressed most first.

        Raises:
            ValueError: If ``x`` or ``edge_index`` is a dict (heterogeneous graph).
        """
        if isinstance(x, dict) or isinstance(edge_index, dict):
            raise ValueError(f"Heterogeneous graphs not yet supported in "
                             f"'{self.__class__.__name__}'")

        try:
            self._train(model, x, edge_index, target=target, index=index, **kwargs)

            node_mask = self._post_process_mask(
                self.node_mask,
                self.hard_node_mask,
                apply_sigmoid=True,
            )
            edge_mask = self._post_process_mask(
                self.edge_mask,
                self.hard_edge_mask,
                apply_sigmoid=True,
            )
        finally:
            # Always remove the masks installed on the model, even when
            # training fails, so the model is not left patched.
            self._clean_model(model)

        return Explanation(node_mask=node_mask, edge_mask=-edge_mask)

    def supports(self) -> bool:
        """Report that every explainer/model configuration is accepted."""
        return True

    def _train(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ):
        """Optimise the masks for ``self.epochs`` steps and optionally log.

        Side effects: sets ``self.infl_orig``, ``self.seed_size`` and the
        masks, installs the edge mask on ``model``, and (when
        ``self.log_filename`` is set) writes the text log and ``.npy`` mask.
        """
        # Reference values used to normalise the prediction term of the loss.
        with torch.no_grad(): self.infl_orig = sum(model(x, edge_index, **kwargs)).item()
        # Tensor reduction instead of a Python-level loop over the rows of x.
        self.seed_size = x.sum().item()

        record = []

        self._initialize_masks(x, edge_index)

        parameters = []
        if self.node_mask is not None:
            parameters.append(self.node_mask)
        if self.edge_mask is not None:
            set_masks(model, self.edge_mask, edge_index, apply_sigmoid=True)
            parameters.append(self.edge_mask)

        optimizer = torch.optim.Adam(parameters, lr=self.lr)
        scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda epoch: self.lr_gamma)

        # NOTE: epochs are numbered from 1. The GNNExplainer-style
        # "collect hard masks when i == 0" step could therefore never run and
        # has been dropped; the hard masks intentionally remain None.
        for i in range(1, self.epochs + 1):
            optimizer.zero_grad()

            h = x if self.node_mask is None else x * self.node_mask.sigmoid()
            y_hat, y = model(h, edge_index, **kwargs), target

            if index is not None:
                y_hat, y = y_hat[index], y[index]

            loss, pred, size, ent = self._loss(y_hat, y, i)
            record.append((i, loss.item(), pred.item(), size.item(), ent.item()))

            loss.backward()
            torch.nn.utils.clip_grad_norm_(parameters, 0.01)
            optimizer.step()

            scheduler.step()

        if self.log_filename:
            with open(self.log_filename, 'w') as f:
                f.write('seed infl_orig\n')
                f.write(f'{self.seed_size} {self.infl_orig}\n')
                f.write('epoch pred size ent loss\n')
                for epoch, loss, pred, size, ent in record:
                    f.write(f'{epoch} {pred} {size} {ent} {loss}\n')
            mask = self.edge_mask.cpu().detach().sigmoid().numpy()
            # Swap the real extension (whatever its length) for ".npy".
            mask_path = os.path.splitext(self.log_filename)[0] + '.npy'
            with open(mask_path, 'wb') as f: np.save(f, mask)

    def _initialize_masks(self, x: Tensor, edge_index: Tensor):
        """Create the learnable masks requested by the explainer config.

        Also records the number of nodes (``self.n``) and edges (``self.m``).
        Node masks start as small Gaussian noise; the edge mask starts at
        logit 3 for every edge (sigmoid(3) is about 0.95, i.e. edges start
        almost fully "on").
        """
        node_mask_type = self.explainer_config.node_mask_type
        edge_mask_type = self.explainer_config.edge_mask_type

        device = x.device
        (N, F), E = x.size(), edge_index.size(1)
        self.m = E
        self.n = N

        std = 0.1
        if node_mask_type is None:
            self.node_mask = None
        elif node_mask_type == MaskType.object:
            self.node_mask = Parameter(torch.randn(N, 1, device=device) * std)
        elif node_mask_type == MaskType.attributes:
            self.node_mask = Parameter(torch.randn(N, F, device=device) * std)
        elif node_mask_type == MaskType.common_attributes:
            self.node_mask = Parameter(torch.randn(1, F, device=device) * std)
        else:
            assert False

        if edge_mask_type is None:
            self.edge_mask = None
        elif edge_mask_type == MaskType.object:
            self.edge_mask = torch.nn.Parameter(torch.full((E,), 3.0, device=device))
        else:
            assert False

    def _loss(self, y_hat: Tensor, y: Tensor, epoch) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Compute the training objective.

        Args:
            y_hat: Model output; only ``y_hat[0]`` is used.
            y: Target (unused by this loss).
            epoch: Current epoch number (unused by this loss).

        Returns:
            ``(loss, pred, size, ent)`` where ``pred`` is ``y_hat[0]``,
            ``size`` is ``m - sum(sigmoid(edge_mask))`` and ``ent`` is the
            mean binary entropy of the sigmoid mask.
        """
        pred = y_hat[0]

        logits = self.edge_mask
        if self.hard_edge_mask is not None:
            logits = logits[self.hard_edge_mask]
        mask = logits.sigmoid()
        # "Soft" count of removed edges: total edges minus retained mass.
        size = self.m - mask.sum()

        ent = -mask * torch.log(mask + self.eps) - (1 - mask) * torch.log(1 - mask + self.eps)
        ent = ent.mean()

        # NOTE(review): the first term divides by (infl_orig - seed_size);
        # this is undefined if the two are equal — confirm that cannot happen.
        loss = -(self.infl_orig-pred)/(self.infl_orig-self.seed_size) + self.size_coeff*(size/self.del_edge_num-1)**2 + self.ent_coeff*ent

        return loss, pred, size, ent

    def _clean_model(self, model):
        """Remove injected masks from ``model`` and reset this object's masks."""
        clear_masks(model)
        self.node_mask = self.hard_node_mask = None
        self.edge_mask = self.hard_edge_mask = None


def RumorGuard_O(adj_list, seed_idx, prob, del_edge_num, model_name=None, gnn_latent_dim=[128,128,128,128,128,128], gpu_num='cpu', **kwargs):
    """Rank edges with the optimisation-based explainer and return the top ones.

    Args:
        adj_list: ``adj_list[u]`` is an iterable of ``(v, p)`` pairs, one per
            outgoing edge ``u -> v`` with attribute ``p``.
        seed_idx: Index (or indices) of seed nodes; marked 1 in the node feature.
        prob: Per-node values; their sum is used as the explanation target.
        del_edge_num: Number of edges to return; also forwarded to the
            explainer algorithm.
        model_name: Forwarded to ``load_model``.
        gnn_latent_dim: Forwarded to ``load_model``.
        gpu_num: Forwarded to ``set_gpu``.
        **kwargs: Forwarded to ``RumorGuard_O_class`` (e.g. ``epochs``, ``lr``).

    Returns:
        ``(edges, elapsed)``: an array with up to ``del_edge_num`` rows
        ``(u, v, p)`` ordered by descending explanation score, and the
        wall-clock seconds spent in this call.
    """
    time_alg = -time.time()
    kwargs['del_edge_num'] = del_edge_num

    device = set_gpu(gpu_num)
    model = load_model(model_name, device, gnn_latent_dim=gnn_latent_dim)

    explainer = Explainer(
        model=model,
        algorithm=RumorGuard_O_class(**kwargs),
        explanation_type='phenomenon',
        edge_mask_type='object',
        model_config=dict(
            mode='regression',
            task_level='graph',
            return_type='raw'
        )
    )

    n = len(adj_list)
    is_seed = np.zeros(n, dtype=int)
    is_seed[seed_idx] = 1

    # Flatten the adjacency list once; row k of `edge` matches column k of
    # `edge_index` and row k of `edge_attr`.
    edge = [(u, v, p) for u in range(n) for v, p in adj_list[u]]
    # Explicit dtype/shape so an edgeless graph does not yield a float index
    # tensor or a 1-D attribute tensor.
    edge_index = torch.tensor(
        [[u for u, _, _ in edge], [v for _, v, _ in edge]],
        dtype=torch.long,
    )
    edge_attr = torch.tensor([[p] for _, _, p in edge]).reshape(-1, 1)
    seed = torch.from_numpy(np.expand_dims(is_seed, axis=-1)).float()
    prob = torch.from_numpy(np.expand_dims(prob, axis=-1)).float()
    data = Data(x=seed, edge_index=edge_index, edge_attr=edge_attr, y=prob)
    edge = np.array(edge).reshape(-1, 3)

    data = data.to(device)
    # Target is the sum of `prob` over nodes, shaped with a leading batch dim.
    explanation = explainer(data.x, data.edge_index, edge_attr=data.edge_attr, target=torch.unsqueeze(torch.sum(data.y, dim=0), 0))
    mask = explanation.edge_mask.cpu()

    _, indices = torch.sort(mask, descending=True)
    edge = edge[indices.numpy()]

    time_alg += time.time()
    return edge[:del_edge_num], time_alg

In [None]:
def RumorGuard_I(adj_list, seed_idx, prob, del_edge_num, model_name=None, gnn_latent_dim=[128,128,128,128,128,128], gpu_num='cpu', **kwargs):
    """Rank edges by Captum ``Saliency`` attribution and return the top ones.

    Args:
        adj_list: ``adj_list[u]`` is an iterable of ``(v, p)`` pairs, one per
            outgoing edge ``u -> v`` with attribute ``p``.
        seed_idx: Index (or indices) of seed nodes; marked 1 in the node feature.
        prob: Per-node values; their sum is used as the explanation target.
        del_edge_num: Number of edges to return.
        model_name: Forwarded to ``load_model``.
        gnn_latent_dim: Forwarded to ``load_model``.
        gpu_num: Forwarded to ``set_gpu``.
        **kwargs: Only ``prod_p`` is read; when it equals ``True`` each edge
            score is multiplied by that edge's attribute ``p`` before ranking.

    Returns:
        ``(edges, elapsed)``: an array with up to ``del_edge_num`` rows
        ``(u, v, p)`` ordered by descending score, and the wall-clock seconds
        spent in this call.
    """
    time_alg = -time.time()

    device = set_gpu(gpu_num)
    model = load_model(model_name, device, gnn_latent_dim=gnn_latent_dim)

    explainer = Explainer(
        model=model,
        algorithm=CaptumExplainer('Saliency'),
        explanation_type='phenomenon',
        edge_mask_type='object',
        model_config=dict(
            mode='regression',
            task_level='graph',
            return_type='raw',
        )
    )

    n = len(adj_list)
    is_seed = np.zeros(n, dtype=int)
    is_seed[seed_idx] = 1

    # Flatten the adjacency list once; row k of `edge` matches column k of
    # `edge_index` and row k of `edge_attr`.
    edge = [(u, v, p) for u in range(n) for v, p in adj_list[u]]
    # Explicit dtype/shape so an edgeless graph does not yield a float index
    # tensor or a 1-D attribute tensor.
    edge_index = torch.tensor(
        [[u for u, _, _ in edge], [v for _, v, _ in edge]],
        dtype=torch.long,
    )
    edge_attr = torch.tensor([[p] for _, _, p in edge]).reshape(-1, 1)
    seed = torch.from_numpy(np.expand_dims(is_seed, axis=-1)).float()
    prob = torch.from_numpy(np.expand_dims(prob, axis=-1)).float()
    data = Data(x=seed, edge_index=edge_index, edge_attr=edge_attr, y=prob)
    edge = np.array(edge).reshape(-1, 3)

    data = data.to(device)
    # Target is the sum of `prob` over nodes, shaped with a leading batch dim.
    explanation = explainer(data.x, data.edge_index, edge_attr=data.edge_attr, target=torch.unsqueeze(torch.sum(data.y, dim=0), 0))
    mask = explanation.edge_mask.cpu()

    if 'prod_p' in kwargs and kwargs['prod_p']==True:
        # Drop only the trailing feature axis so the result stays 1-D.
        mask = mask * edge_attr.cpu().squeeze(-1)

    _, indices = torch.sort(mask, descending=True)
    edge = edge[indices.numpy()]

    time_alg += time.time()
    return edge[:del_edge_num], time_alg