In [1]:
import torch
from torch_geometric.explain import GNNExplainer
#from torch_geometric.explain import CFExplainer
from torch_geometric.nn import GCNConv
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
# Import other necessary modules as needed


In [2]:
# !pip install jupyterthemes
# !jt -t chesterish
#!jt -r


In [3]:
%matplotlib inline

from torch_geometric.data import Data, DataLoader
from torch_geometric.datasets import TUDataset, Planetoid
from torch_geometric.nn import GCNConv, Set2Set
from torch_geometric.explain import GNNExplainer
import torch_geometric.transforms as T
import torch
import torch.nn.functional as F
import os
from tqdm import tqdm, trange

import matplotlib.pyplot as plt

In [4]:
import os.path as osp

import torch
import torch.nn.functional as F

from torch_geometric.datasets import Planetoid
from torch_geometric.explain import Explainer, GNNExplainer
from torch_geometric.nn import GCNConv

dataset = 'Cora'
path = os.path.join(os.getcwd(), 'data', 'Planetoid')
dataset = Planetoid(path, dataset)
data = dataset[0]


class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

for epoch in range(1, 201):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    #print(loss)
    loss.backward()
    optimizer.step()

In [5]:
torch.argmax(model(data.x, data.edge_index)[83])
data.y[83]

tensor(2)

In [6]:
torch.tensor(3)*torch.tensor(3)

tensor(9)

In [7]:
for name, param in model.named_parameters():
    print(f"Parameter '{name}': requires_grad={param.requires_grad}")

Parameter 'conv1.bias': requires_grad=True
Parameter 'conv1.lin.weight': requires_grad=True
Parameter 'conv2.bias': requires_grad=True
Parameter 'conv2.lin.weight': requires_grad=True


In [11]:
from math import sqrt
from typing import Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor
from torch.nn.parameter import Parameter
import torch.nn.functional as F


from torch_geometric.explain import ExplainerConfig, Explanation, ModelConfig
from torch_geometric.explain.algorithm import ExplainerAlgorithm
from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
from torch_geometric.explain.config import MaskType, ModelMode, ModelTaskLevel
from torch_geometric.utils import to_dense_adj
#from torch_geometric.utils import dense_adjacency


class CFExplainer(ExplainerAlgorithm):
    r"""The CF-Explainer model from the `"CF-GNNExplainer: Counterfactual Explanations for Graph Neural
Networks"
    <https://arxiv.org/abs/2102.03322>`_ paper for generating CF explanations for GNNs: 
    the minimal perturbation to the input (graph) data such that the prediction changes.

    .. note::

        For an example of using :class:`GNNExplainer`, see
        `examples/explain/gnn_explainer.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/explain/gnn_explainer.py>`_,
        `examples/explain/gnn_explainer_ba_shapes.py <https://github.com/
        pyg-team/pytorch_geometric/blob/master/examples/
        explain/gnn_explainer_ba_shapes.py>`_, and `examples/explain/
        gnn_explainer_link_pred.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/explain/gnn_explainer_link_pred.py>`_.

    Args:
        epochs (int, optional): The number of epochs to train.
            (default: :obj:`100`)
        lr (float, optional): The learning rate to apply.
            (default: :obj:`0.01`)
        **kwargs (optional): Additional hyper-parameters to override default
            settings in
            :attr:`~torch_geometric.explain.algorithm.GNNExplainer.coeffs`.
    """

    coeffs = {
        'edge_size': 0.005,
        'edge_reduction': 'sum',
        'beta' : .001,
        'node_feat_size': 1.0,
        'node_feat_reduction': 'mean',
        'edge_ent': 1.0,
        'node_feat_ent': 0.1,
        'EPS': 1e-15,
    }

    def __init__(self, epochs: int = 100, lr: float = 0.01, cf_optimizer = "SGD", n_momentum = 0, **kwargs):
        super().__init__()
        self.epochs = epochs
        self.lr = lr
        self.cf_optimizer = cf_optimizer
        self.n_momentum = n_momentum
        self.coeffs.update(kwargs)
        self.node_mask = self.hard_node_mask = None
        self.edge_mask = self.hard_edge_mask = None
        self.best_cf_example = None
        self.best_loss = np.inf
        
    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        index: int = None,
        **kwargs,
    ) -> Explanation:
        if isinstance(x, dict) or isinstance(edge_index, dict):
            raise ValueError(f"Heterogeneous graphs not yet supported in "
                             f"'{self.__class__.__name__}'")

        self._train(model, x, edge_index, index=index, **kwargs)

        # node_mask = self._post_process_mask(
        #     self.best_cf_example[0],
        #     self.hard_node_mask,
        #     apply_sigmoid=True,
        # )
        node_mask = self._post_process_mask(
            self.node_mask,
            self.hard_node_mask,
            apply_sigmoid=True,
        )
        edge_mask = self._post_process_mask(
            self.edge_mask,
            self.hard_edge_mask,
            apply_sigmoid=True,
        )

        self._clean_model(model)

        return Explanation(node_mask=node_mask, edge_mask=edge_mask)

    def supports(self) -> bool:
        return True
    
    def _train(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: int = None,
        **kwargs,
    ):
        self._initialize_masks(x, edge_index)
        parameters = []
        #if self.node_mask is not None:
        #    parameters.append(self.node_mask)
        if self.edge_mask is not None:
            set_masks(model, self.edge_mask, edge_index, apply_sigmoid=True)
            parameters.append(self.edge_mask)
    
        if self.cf_optimizer == "SGD" and self.n_momentum == 0.0:
            optimizer = torch.optim.SGD(parameters, lr=self.lr)
        elif self.cf_optimizer == "SGD" and self.n_momentum != 0.0:
            optimizer = torch.optim.SGD(parameters, lr=self.lr, nesterov=True, momentum=n_momentum)
        elif self.cf_optimizer == "Adadelta":
            optimizer = torch.optim.Adadelta(parameters, lr=self.lr)
        else:
            raise Exception("Optimizer is not currently supported.")
        
        original_prediction  = model(x, edge_index, **kwargs)
        print('org pred', original_prediction[index])
        print('target', target[index])
        for i in range(self.epochs):
            optimizer.zero_grad()
            h = x # if self.node_mask is None else x * self.node_mask.sigmoid()
            #discrete_edge_mask = torch.where(torch.sigmoid(self.edge_mask)>=0.5, 1, 0)
            set_masks(model, self.edge_mask, edge_index, apply_sigmoid=True)
            y_hat, y = model(h, edge_index, **kwargs), target
            
            if index is not None:
                y_hat, y = y_hat[index], y[index]

            loss = self._loss(y_hat, y)
            
            loss.backward()
            optimizer.step()

            # In the first iteration, we collect the nodes and edges that are
            # involved into making the prediction. These are all the nodes and
            # edges with gradient != 0 (without regularization applied).
            if i == 0 and self.node_mask is not None:
                if self.node_mask.grad is None:
                    raise ValueError("Could not compute gradients for node "
                                     "features. Please make sure that node "
                                     "features are used inside the model or "
                                     "disable it via `node_mask_type=None`.")
                self.hard_node_mask = self.node_mask.grad != 0.0
            if i == 0 and self.edge_mask is not None:
                if self.edge_mask.grad is None:
                    raise ValueError("Could not compute gradients for edges. "
                                     "Please make sure that edges are used "
                                     "via message passing inside the model or "
                                     "disable it via `edge_mask_type=None`.")
                self.hard_edge_mask = self.edge_mask.grad != 0.0
                print("self.hard_edge_mask", torch.sum(self.hard_edge_mask))

    def _initialize_masks(self, x: Tensor, edge_index: Tensor):
        node_mask_type = None
        edge_mask_type = self.explainer_config.edge_mask_type

        device = x.device
        (N, F), E = x.size(), edge_index.size(1)

        if node_mask_type is None:
            self.node_mask = None
        elif node_mask_type == MaskType.object:
            self.node_mask = Parameter(torch.ones(N, 1, device=device))
        elif node_mask_type == MaskType.attributes:
            self.node_mask = Parameter(torch.ones(N, F, device=device))
        elif node_mask_type == MaskType.common_attributes:
            self.node_mask = Parameter(torch.ones(1, F, device=device))
        else:
            assert False

        if edge_mask_type is None:
            self.edge_mask = None
        elif edge_mask_type == MaskType.object:
            self.edge_mask = Parameter(torch.ones(E, device=device))
        else:
            assert False



    def _loss(self, y_hat: Tensor, y: Tensor) -> Tensor:
        print('y_hat', y_hat)
#        # Calculate the sigmoid of self.edge_mask
#         edge_mask_sigmoid = torch.sigmoid(self.edge_mask)

#         # Calculate the absolute difference between 1.0 and edge_mask_sigmoid
#         edge_mask_penalty = torch.abs(1.0 - edge_mask_sigmoid)

#         # Threshold edge_mask_penalty elements at 0.5 using element-wise multiplication
#         threshold = 0.5
#         edge_mask_penalty = edge_mask_penalty * (edge_mask_sigmoid >= threshold).float()

#         # Calculate the negative log-likelihood loss
#         log_probs = torch.log(edge_mask_sigmoid)
#         loss_nll = -log_probs.mean()

#         # Additional term: penalize similarity between edge_mask_sigmoid and 1.0
#         separation_term = torch.abs(edge_mask_sigmoid - 1.0)

#         # Combine all loss terms
#         loss = loss_nll + separation_term.mean() + edge_mask_penalty.mean()

        if self.model_config.mode == ModelMode.binary_classification:
            loss = self._loss_binary_classification(y_hat, y)
        elif self.model_config.mode == ModelMode.multiclass_classification:
            loss = self._loss_multiclass_classification(y_hat, y)
        elif self.model_config.mode == ModelMode.regression:
            loss = self._loss_regression(y_hat, y)
        else:
            assert False
        if self.hard_edge_mask is not None:
            assert self.edge_mask is not None
            m = self.edge_mask[self.hard_edge_mask].sigmoid()
            edge_reduce = getattr(torch, self.coeffs['edge_reduction'])
            loss = loss + self.coeffs['edge_size'] * edge_reduce(m)
            ent = -m * torch.log(m + self.coeffs['EPS']) - (
                1 - m) * torch.log(1 - m + self.coeffs['EPS'])
            loss = loss + self.coeffs['edge_ent'] * ent.mean()

#         if self.hard_node_mask is not None:
#             assert self.node_mask is not None
#             m = self.node_mask[self.hard_node_mask].sigmoid()
#             node_reduce = getattr(torch, self.coeffs['node_feat_reduction'])
#             loss = loss + self.coeffs['node_feat_size'] * node_reduce(m)
#             ent = -m * torch.log(m + self.coeffs['EPS']) - (
#                 1 - m) * torch.log(1 - m + self.coeffs['EPS'])
#             loss = loss + self.coeffs['node_feat_ent'] * ent.mean()
        print(loss)
        return loss
#     def _loss(self, y_hat: Tensor, y: Tensor, edge_index, index) -> Tensor:
#         if self.model_config.mode == ModelMode.binary_classification:
#             loss = -1 * self._loss_binary_classification(y_hat, y)
#         elif self.model_config.mode == ModelMode.multiclass_classification:
#             loss = -1 * self._loss_multiclass_classification(y_hat, y)
#         elif self.model_config.mode == ModelMode.regression:
#             loss = -1 * self._loss_regression(y_hat, y)
#         else:
#             assert False
#         #         y_hat_discrete = torch.argmax(y_hat, dim=-1)  # Compute argmax along the class dimension
# #         y_discrete = torch.argmax(y, dim=-1)  # Compute argmax along the class dimension
        
# #         print("y_hat.requires_grad:", y_hat.requires_grad)
# #         print("y.requires_grad:", y.requires_grad)
# #         print("edge_index.requires_grad:", edge_index.requires_grad)
# #         #pred_same_indicator = torch.where(y_hat_discrete == y_discrete, torch.tensor(-1.0), torch.tensor(-1.0))
# #         loss_pred = - F.nll_loss(F.log_softmax(y_hat, dim=-1), y_discrete.long(), reduction='none')
        
# #         print("loss_pred.requires_grad:", loss_pred.requires_grad)
# #         print("self.edge_mask.requires_grad:", self.edge_mask.requires_grad)
# #         #loss_pred = F.nll_loss(F.log_softmax(y_hat, dim=-1), y_discrete.long(), reduction)
# # #         if self.model_config.mode == ModelMode.binary_classification:
# # #             loss_pred = self._loss_binary_classification(y_hat, y.long())
# # #         elif self.model_config.mode == ModelMode.multiclass_classification:
# # #             loss_pred = self._loss_multiclass_classification(y_hat, y.long())
# # # #         elif self.model_config.mode == ModelMode.regression:
# # # #             loss_pred = - self._loss_regression(y_hat, y)
# # #         else:
# # #             assert False
# #         print("self.edge mask:", self.edge_mask)
# #         discrete_edge_mask = torch.where(torch.sigmoid(self.edge_mask) > 0.5, torch.tensor(1.0), torch.tensor(0.0))
# #         print("discrete_edge_mask.requires_grad:", discrete_edge_mask.requires_grad)
# #         discrete_edge_mask.requires_grad = False
# #         print("discrete_edge_mask.requires_grad:", discrete_edge_mask.requires_grad)
# #         print("discrete edge mask", discrete_edge_mask)
# #         new_edge_index = edge_index * discrete_edge_mask

# #         loss_graph_dist = torch.sum(torch.abs(new_edge_index - edge_index)) / 2
# #         #print("loss_graph_dist", loss_graph_dist)
# #         loss_total = loss_pred
# #         #loss_total = pred_same_indicator * loss_pred + self.coeffs['beta'] * loss_graph_dist
# #         print("loss_total", loss_total)
#         loss_total = loss
#         print(loss_total)
#         return loss_total

    def _clean_model(self, model):
        clear_masks(model)
        self.node_mask = self.hard_node_mask = None
        self.edge_mask = self.hard_edge_mask = None


In [12]:


explainer = Explainer(
    model=model,
    algorithm=CFExplainer(epochs=20),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='log_probs',
    ),
)
node_index =83
print('True value', data.y[node_index])
explanation = explainer(data.x, data.edge_index, index=node_index)
print(f'Generated explanations in {explanation.available_explanations}')
# path = 'feature_importance.png'
# explanation.visualize_feature_importance(path, top_k=10)
# print(f"Feature importance plot has been saved to '{path}'")

path = 'subgraph.pdf'
explanation.visualize_graph(path)
print(f"Subgraph visualization plot has been saved to '{path}'")

for idx, i in enumerate(explanation['edge_mask']):
    if i > 0:
        print ("Weight", idx, i)

True value tensor(2)
org pred tensor([-7.0778, -5.5675, -0.0099, -6.1459, -7.2232, -6.2567, -7.8279],
       grad_fn=<SelectBackward0>)
target tensor(2)
y_hat tensor([-7.0778, -5.5675, -0.0099, -6.1459, -7.2232, -6.2567, -7.8279],
       grad_fn=<SelectBackward0>)
tensor(0.0099, grad_fn=<NllLossBackward0>)
self.hard_edge_mask tensor(8)
y_hat tensor([-7.0779, -5.5675, -0.0099, -6.1460, -7.2233, -6.2567, -7.8280],
       grad_fn=<SelectBackward0>)
tensor(0.6213, grad_fn=<AddBackward0>)
y_hat tensor([-7.0785, -5.5680, -0.0099, -6.1465, -7.2239, -6.2573, -7.8287],
       grad_fn=<SelectBackward0>)
tensor(0.6213, grad_fn=<AddBackward0>)
y_hat tensor([-7.0791, -5.5684, -0.0099, -6.1470, -7.2245, -6.2578, -7.8294],
       grad_fn=<SelectBackward0>)
tensor(0.6212, grad_fn=<AddBackward0>)
y_hat tensor([-7.0797, -5.5688, -0.0099, -6.1475, -7.2251, -6.2583, -7.8301],
       grad_fn=<SelectBackward0>)
tensor(0.6212, grad_fn=<AddBackward0>)
y_hat tensor([-7.0803, -5.5693, -0.0099, -6.1480, -7.2257,

In [42]:
for i in explanation['edge_mask']:
    if i > 0:
        print(i)

tensor(0.5606)
tensor(0.5632)
tensor(0.5487)
tensor(0.5514)
tensor(0.5621)
tensor(0.5524)
tensor(0.5479)
tensor(0.5542)


In [38]:
from math import sqrt
from typing import Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn.parameter import Parameter

from torch_geometric.explain import ExplainerConfig, Explanation, ModelConfig
from torch_geometric.explain.algorithm import ExplainerAlgorithm
from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
from torch_geometric.explain.config import MaskType, ModelMode, ModelTaskLevel


class GNNExplainer(ExplainerAlgorithm):
    coeffs = {
        'edge_size': 0.005,
        'edge_reduction': 'sum',
        'node_feat_size': 1.0,
        'node_feat_reduction': 'mean',
        'edge_ent': 1.0,
        'node_feat_ent': 0.1,
        'EPS': 1e-15,
    }

    def __init__(self, epochs: int = 100, lr: float = 0.01, **kwargs):
        super().__init__()
        self.epochs = epochs
        self.lr = lr
        self.coeffs.update(kwargs)

        self.node_mask = self.hard_node_mask = None
        self.edge_mask = self.hard_edge_mask = None

    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Explanation:
        if isinstance(x, dict) or isinstance(edge_index, dict):
            raise ValueError(f"Heterogeneous graphs not yet supported in "
                             f"'{self.__class__.__name__}'")

        self._train(model, x, edge_index, target=target, index=index, **kwargs)

        node_mask = self._post_process_mask(
            self.node_mask,
            self.hard_node_mask,
            apply_sigmoid=True,
        )
        edge_mask = self._post_process_mask(
            self.edge_mask,
            self.hard_edge_mask,
            apply_sigmoid=True,
        )

        self._clean_model(model)

        return Explanation(node_mask=node_mask, edge_mask=edge_mask)

    def supports(self) -> bool:
        return True

    def _train(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ):
        self._initialize_masks(x, edge_index)

        parameters = []
        if self.node_mask is not None:
            parameters.append(self.node_mask)
        if self.edge_mask is not None:
            set_masks(model, self.edge_mask, edge_index, apply_sigmoid=True)
            parameters.append(self.edge_mask)

        optimizer = torch.optim.Adam(parameters, lr=self.lr)

        for i in range(self.epochs):
            optimizer.zero_grad()

            h = x if self.node_mask is None else x * self.node_mask.sigmoid()
            y_hat, y = model(h, edge_index, **kwargs), target

            if index is not None:
                y_hat, y = y_hat[index], y[index]

            loss = self._loss(y_hat, y)

            loss.backward()
            optimizer.step()

            # In the first iteration, we collect the nodes and edges that are
            # involved into making the prediction. These are all the nodes and
            # edges with gradient != 0 (without regularization applied).
            if i == 0 and self.node_mask is not None:
                if self.node_mask.grad is None:
                    raise ValueError("Could not compute gradients for node "
                                     "features. Please make sure that node "
                                     "features are used inside the model or "
                                     "disable it via `node_mask_type=None`.")
                self.hard_node_mask = self.node_mask.grad != 0.0
            if i == 0 and self.edge_mask is not None:
                if self.edge_mask.grad is None:
                    raise ValueError("Could not compute gradients for edges. "
                                     "Please make sure that edges are used "
                                     "via message passing inside the model or "
                                     "disable it via `edge_mask_type=None`.")
                self.hard_edge_mask = self.edge_mask.grad != 0.0

    def _initialize_masks(self, x: Tensor, edge_index: Tensor):
        node_mask_type = self.explainer_config.node_mask_type
        edge_mask_type = self.explainer_config.edge_mask_type

        device = x.device
        (N, F), E = x.size(), edge_index.size(1)

        std = 0.1
        if node_mask_type is None:
            self.node_mask = None
        elif node_mask_type == MaskType.object:
            self.node_mask = Parameter(torch.randn(N, 1, device=device) * std)
        elif node_mask_type == MaskType.attributes:
            self.node_mask = Parameter(torch.randn(N, F, device=device) * std)
        elif node_mask_type == MaskType.common_attributes:
            self.node_mask = Parameter(torch.randn(1, F, device=device) * std)
        else:
            assert False

        if edge_mask_type is None:
            self.edge_mask = None
        elif edge_mask_type == MaskType.object:
            std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
            self.edge_mask = Parameter(torch.randn(E, device=device) * std)
        else:
            assert False

    def _loss(self, y_hat: Tensor, y: Tensor) -> Tensor:
        if self.model_config.mode == ModelMode.binary_classification:
            loss = self._loss_binary_classification(y_hat, y)
        elif self.model_config.mode == ModelMode.multiclass_classification:
            loss = self._loss_multiclass_classification(y_hat, y)
        elif self.model_config.mode == ModelMode.regression:
            loss = self._loss_regression(y_hat, y)
        else:
            assert False

        if self.hard_edge_mask is not None:
            assert self.edge_mask is not None
            m = self.edge_mask[self.hard_edge_mask].sigmoid()
            edge_reduce = getattr(torch, self.coeffs['edge_reduction'])
            loss = loss + self.coeffs['edge_size'] * edge_reduce(m)
            ent = -m * torch.log(m + self.coeffs['EPS']) - (
                1 - m) * torch.log(1 - m + self.coeffs['EPS'])
            loss = loss + self.coeffs['edge_ent'] * ent.mean()

        if self.hard_node_mask is not None:
            assert self.node_mask is not None
            m = self.node_mask[self.hard_node_mask].sigmoid()
            node_reduce = getattr(torch, self.coeffs['node_feat_reduction'])
            loss = loss + self.coeffs['node_feat_size'] * node_reduce(m)
            ent = -m * torch.log(m + self.coeffs['EPS']) - (
                1 - m) * torch.log(1 - m + self.coeffs['EPS'])
            loss = loss + self.coeffs['node_feat_ent'] * ent.mean()
        print(loss)
        return loss

    def _clean_model(self, model):
        clear_masks(model)
        self.node_mask = self.hard_node_mask = None
        self.edge_mask = self.hard_edge_mask = None

In [34]:
from math import sqrt
from typing import Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor
from torch.nn.parameter import Parameter
import torch.nn.functional as F


from torch_geometric.explain import ExplainerConfig, Explanation, ModelConfig
from torch_geometric.explain.algorithm import ExplainerAlgorithm
from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
from torch_geometric.explain.config import MaskType, ModelMode, ModelTaskLevel
from torch_geometric.utils import to_dense_adj
#from torch_geometric.utils import dense_adjacency


class CFExplainer(ExplainerAlgorithm):
    r"""The CF-Explainer model from the `"CF-GNNExplainer: Counterfactual Explanations for Graph Neural
Networks"
    <https://arxiv.org/abs/2102.03322>`_ paper for generating CF explanations for GNNs: 
    the minimal perturbation to the input (graph) data such that the prediction changes.

    .. note::

        For an example of using :class:`GNNExplainer`, see
        `examples/explain/gnn_explainer.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/explain/gnn_explainer.py>`_,
        `examples/explain/gnn_explainer_ba_shapes.py <https://github.com/
        pyg-team/pytorch_geometric/blob/master/examples/
        explain/gnn_explainer_ba_shapes.py>`_, and `examples/explain/
        gnn_explainer_link_pred.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/explain/gnn_explainer_link_pred.py>`_.

    Args:
        epochs (int, optional): The number of epochs to train.
            (default: :obj:`100`)
        lr (float, optional): The learning rate to apply.
            (default: :obj:`0.01`)
        **kwargs (optional): Additional hyper-parameters to override default
            settings in
            :attr:`~torch_geometric.explain.algorithm.GNNExplainer.coeffs`.
    """

    coeffs = {
        'edge_size': 0.005,
        'edge_reduction': 'sum',
        'beta' : .001,
        'node_feat_size': 1.0,
        'node_feat_reduction': 'mean',
        'edge_ent': 1.0,
        'node_feat_ent': 0.1,
        'EPS': 1e-15,
    }

    def __init__(self, epochs: int = 100, lr: float = 0.01, cf_optimizer = "SGD", n_momentum = 0, **kwargs):
        super().__init__()
        self.epochs = epochs
        self.lr = lr
        self.cf_optimizer = cf_optimizer
        self.n_momentum = n_momentum
        self.coeffs.update(kwargs)
        self.node_mask = self.hard_node_mask = None
        self.edge_mask = self.hard_edge_mask = None
        self.best_cf_example = None
        self.best_loss = np.inf
        
    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Explanation:
        if isinstance(x, dict) or isinstance(edge_index, dict):
            raise ValueError(f"Heterogeneous graphs not yet supported in "
                             f"'{self.__class__.__name__}'")

        self._train(model, x, edge_index, target=target, index=index, **kwargs)

        # node_mask = self._post_process_mask(
        #     self.best_cf_example[0],
        #     self.hard_node_mask,
        #     apply_sigmoid=True,
        # )
        edge_mask = self._post_process_mask(
            self.best_cf_example,
            self.hard_edge_mask,
            apply_sigmoid=True,
        )

        self._clean_model(model)

        return Explanation(node_mask=node_mask, edge_mask=edge_mask)

    def supports(self) -> bool:
        return True
    
    def _train(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ):
        self._initialize_masks(x, edge_index)

        parameters = []
        if self.node_mask is not None:
            parameters.append(self.node_mask)
        if self.edge_mask is not None:
            set_masks(model, self.edge_mask, edge_index, apply_sigmoid=True)
            parameters.append(self.edge_mask)
    
        if self.cf_optimizer == "SGD" and self.n_momentum == 0.0:
            optimizer = torch.optim.SGD(parameters, lr=self.lr)
        elif self.cf_optimizer == "SGD" and self.n_momentum != 0.0:
            optimizer = torch.optim.SGD(parameters, lr=self.lr, nesterov=True, momentum=n_momentum)
        elif self.cf_optimizer == "Adadelta":
            optimizer = torch.optim.Adadelta(parameters, lr=self.lr)
        else:
            raise Exception("Optimizer is not currently supported.")
        
        num_cf_examples = 0
        original_prediction  = model(x, edge_index, **kwargs)
        for i in range(self.epochs):
            optimizer.zero_grad()
            h = x if self.node_mask is None else x * self.node_mask.sigmoid()
            discrete_edge_mask = torch.where(torch.sigmoid(self.edge_mask)>=0.5, 1, 0)
            set_masks(model, discrete_edge_mask, edge_index, apply_sigmoid=False)
            y_hat, y = model(h, edge_index, **kwargs), original_prediction
            y_hat_discrete, y_discrete = y_hat.argmax(dim=1), y.argmax(dim=1)
            set_masks(model, self.edge_mask, edge_index, apply_sigmoid=True)

            if index is not None:
                y_hat, y = y_hat[index], y[index]

            loss = self._loss(y_hat, y, edge_index)

            if loss.item() < best_loss:
                best_loss = loss
                self.best_cf_example = self.edge_mask
            
            loss.backward()
            optimizer.step()

            # In the first iteration, we collect the nodes and edges that are
            # involved into making the prediction. These are all the nodes and
            # edges with gradient != 0 (without regularization applied).
            if i == 0 and self.node_mask is not None:
                if self.node_mask.grad is None:
                    raise ValueError("Could not compute gradients for node "
                                     "features. Please make sure that node "
                                     "features are used inside the model or "
                                     "disable it via `node_mask_type=None`.")
                self.hard_node_mask = self.node_mask.grad != 0.0
            if i == 0 and self.edge_mask is not None:
                if self.edge_mask.grad is None:
                    raise ValueError("Could not compute gradients for edges. "
                                     "Please make sure that edges are used "
                                     "via message passing inside the model or "
                                     "disable it via `edge_mask_type=None`.")
                self.hard_edge_mask = self.edge_mask.grad != 0.0

    def _initialize_masks(self, x: Tensor, edge_index: Tensor):
        node_mask_type = self.explainer_config.node_mask_type
        edge_mask_type = self.explainer_config.edge_mask_type

        device = x.device
        (N, F), E = x.size(), edge_index.size(1)

        std = 0.1
        if node_mask_type is None:
            self.node_mask = None
        elif node_mask_type == MaskType.object:
            self.node_mask = Parameter(torch.randn(N, 1, device=device) * std)
        elif node_mask_type == MaskType.attributes:
            self.node_mask = Parameter(torch.randn(N, F, device=device) * std)
        elif node_mask_type == MaskType.common_attributes:
            self.node_mask = Parameter(torch.randn(1, F, device=device) * std)
        else:
            assert False

        if edge_mask_type is None:
            self.edge_mask = None
        elif edge_mask_type == MaskType.object:
            std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
            self.edge_mask = Parameter(torch.randn(E, device=device) * std)
        else:
            assert False

    # def _initialize_masks(self, x: Tensor, edge_index: Tensor):
    #     node_mask_type = self.explainer_config.node_mask_type
    #     edge_mask_type = self.explainer_config.edge_mask_type

    #     device = x.device
    #     (N, F), E = x.size(), edge_index.size(1)

    #     if node_mask_type is None:
    #         self.node_mask = None
    #     elif node_mask_type == MaskType.object:
    #         self.node_mask = Parameter(torch.ones(N, 1, device=device))
    #     elif node_mask_type == MaskType.attributes:
    #         self.node_mask = Parameter(torch.ones(N, F, device=device))
    #     elif node_mask_type == MaskType.common_attributes:
    #         self.node_mask = Parameter(torch.ones(1, F, device=device))
    #     else:
    #         assert False


    #     if edge_mask_type is None:
    #         self.edge_mask = None
    #     elif edge_mask_type == MaskType.object:
    #         self.edge_mask = Parameter(torch.ones(E, device=device))
    #     else:
    #         assert False


    def _loss(self, y_hat: Tensor, y: Tensor, edge_index) -> Tensor:
        y_hat_discrete = y_hat.argmax(dim=1)
        y_discrete = y.argmax(dim=1)

        pred_same = (y_hat_discrete == y_discrete).float()
        
        # if self.model_config.mode == ModelMode.binary_classification:
        #     loss = self._loss_binary_classification(y_hat, y)
        # elif self.model_config.mode == ModelMode.multiclass_classification:
        #     loss = self._loss_multiclass_classification(y_hat, y)
        # elif self.model_config.mode == ModelMode.regression:
        #     loss = self._loss_regression(y_hat, y)
        # else:
        #     assert False
        # Want negative in front to maximize loss instead of minimizing it to find CFs
        discrete_edge_mask = torch.where(torch.sigmoid(self.edge_mask)>=0, 1, 0)

        loss_pred = - F.nll_loss(y_hat, y_discrete)
        adj = dense_adjacency(edge_index, edge_attr=None, num_nodes=None)
        discrete_adj = torch.where(torch.sigmoid(self.edge_mask) >= 0, 1, 0)
        loss_graph_dist = torch.sum(torch.abs(adj - discrete_adj)) / 2

        #loss_graph_dist = sum(sum(abs(to_dense_adj(edge_index) - to_dense_adj(discrete_edge_mask)))) / 2      # Number of edges changed (symmetrical)

		# Zero-out loss_pred with pred_same if prediction flips
        loss_total = pred_same * loss_pred + self.coeff['beta'] * loss_graph_dist


        # if self.hard_edge_mask is not None:
        #     assert self.edge_mask is not None
        #     m = self.edge_mask[self.hard_edge_mask].sigmoid()
        #     edge_reduce = getattr(torch, self.coeffs['edge_reduction'])
        #     loss = loss + self.coeffs['edge_size'] * edge_reduce(m)
        #     ent = -m * torch.log(m + self.coeffs['EPS']) - (
        #         1 - m) * torch.log(1 - m + self.coeffs['EPS'])
        #     loss = loss + self.coeffs['edge_ent'] * ent.mean()

        # if self.hard_node_mask is not None:
        #     assert self.node_mask is not None
        #     m = self.node_mask[self.hard_node_mask].sigmoid()
        #     node_reduce = getattr(torch, self.coeffs['node_feat_reduction'])
        #     loss16 = loss + self.coeffs['node_feat_size'] * node_reduce(m)
        #     ent = -m * torch.log(m + self.coeffs['EPS']) - (
        #         1 - m) * torch.log(1 - m + self.coeffs['EPS'])
        #     loss = loss + self.coeffs['node_feat_ent'] * ent.mean()

        return loss_total

    def _clean_model(self, model):
        clear_masks(model)
        self.node_mask = self.hard_node_mask = None
        self.edge_mask = self.hard_edge_mask = None
