# In [None]:
from math import sqrt
from typing import Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn.parameter import Parameter
import torch.optim as optim

from torch_geometric.explain import ExplainerConfig, Explanation, ModelConfig
from torch_geometric.explain.algorithm import ExplainerAlgorithm
from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
from torch_geometric.explain.config import MaskType, ModelMode, ModelTaskLevel


class PGExplainer_onlyconc(ExplainerAlgorithm):
    r"""Edge explainer that learns one free logit per edge and trains it with
    concrete (relaxed Bernoulli) sampling under an annealed temperature.

    In every epoch a relaxed edge mask is sampled from the learnable logits
    (:meth:`_concrete_sample`), injected into the model via
    :func:`set_masks`, and the logits are updated with Adam to minimise

    .. math::
        -\frac{I_0 - \hat{y}_0}{I_0 - S}
        + \lambda_{\text{size}}
          \left(\frac{E - \sum_e \sigma(m_e)}{k} - 1\right)^2
        + \lambda_{\text{ent}} \, H(\sigma(m))

    where :math:`I_0` is the summed model output on the unmasked graph,
    :math:`S` is the sum of :obj:`x`, :math:`\hat{y}_0` is the first entry of
    the masked model output, :math:`E` is the number of edges, :math:`m` is
    the sampled mask, :math:`k` is :obj:`del_edge_num` and :math:`H` is the
    mean binary entropy.

    .. note::

        The sampling and temperature schedule appear to follow the
        PGExplainer formulation (`"Parameterized Explainer for Graph Neural
        Network" <https://arxiv.org/abs/2011.04573>`_), but without the
        explainer MLP -- the edge logits are optimised directly.

    .. note::

        Only edge masks are trained. If :obj:`node_mask_type` is set, a node
        mask is allocated but never optimised or applied, so the first epoch
        raises a :class:`ValueError`.

    Args:
        epochs (int, optional): The number of epochs to train.
            (default: :obj:`100`)
        lr (float, optional): The learning rate to apply.
            (default: :obj:`0.01`)
        del_edge_num (int, optional): Target value :math:`k` for the soft
            count :math:`E - \sum_e \sigma(m_e)` in the size penalty.
            (default: :obj:`10`)
        size_coeff (float, optional): Weight of the size penalty.
            (default: :obj:`1`)
        ent_coeff (float, optional): Weight of the entropy penalty.
            (default: :obj:`1`)
        lr_gamma (float, optional): Factor the learning rate is multiplied
            by after every epoch. (default: :obj:`1`)
        **kwargs (optional): Additional hyper-parameters to override default
            settings in :attr:`coeffs` for this instance only.
    """

    # Class-level defaults. Only 'temp', 'bias' and 'EPS' are read by the
    # methods of this class; the remaining keys are kept so that external
    # code reading them (or overriding them through **kwargs) keeps working.
    coeffs = {
        'edge_size': 0.005,
        'edge_reduction': 'sum',
        'node_feat_size': 1.0,
        'temp': [5.0, 2.0],
        'bias': 0.0,
        'node_feat_reduction': 'mean',
        'edge_ent': 1.0,
        'node_feat_ent': 0.1,
        'EPS': 1e-15,
    }

    def __init__(self, epochs: int = 100, lr: float = 0.01,
                 del_edge_num: int = 10, size_coeff: float = 1,
                 ent_coeff: float = 1, lr_gamma: float = 1, **kwargs):
        super().__init__()
        self.epochs = epochs
        self.lr = lr
        # Build a per-instance copy instead of updating the class-level dict
        # in place; otherwise overrides passed to one explainer would leak
        # into every other instance (and into the class default itself).
        self.coeffs = {**type(self).coeffs, **kwargs}
        self.del_edge_num = del_edge_num

        self.node_mask = self.hard_node_mask = None
        self.edge_mask = self.hard_edge_mask = None
        # Learnable per-edge logits; created in `_initialize_masks`.
        self.edge_weight = None

        self.size_coeff = size_coeff
        self.ent_coeff = ent_coeff
        self.lr_gamma = lr_gamma

    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Explanation:
        r"""Trains the edge logits and returns the resulting explanation.

        Args:
            model (torch.nn.Module): The model to explain.
            x (Tensor): The node features (must be two-dimensional).
            edge_index (Tensor): The edge indices.
            target (Tensor): The target passed through to the loss.
            index (int or Tensor, optional): If given, selects the entries
                of the model output and of :obj:`target` used in the loss.
            **kwargs (optional): Additional arguments passed to
                :obj:`model`.

        Returns:
            Explanation: Holds :obj:`node_mask` and :obj:`edge_mask`, where
            :obj:`edge_mask` is the *negated* post-processed sigmoid of the
            learned logits.

        Raises:
            ValueError: If :obj:`x` or :obj:`edge_index` is a dictionary
                (heterogeneous graph), or if training cannot obtain
                gradients for the masks.
        """
        if isinstance(x, dict) or isinstance(edge_index, dict):
            raise ValueError(f"Heterogeneous graphs not yet supported in "
                             f"'{self.__class__.__name__}'")

        self._train(model, x, edge_index, target=target, index=index, **kwargs)

        # Report the deterministic logits rather than the last noisy sample.
        self.edge_mask = self.edge_weight

        node_mask = self._post_process_mask(
            self.node_mask,
            self.hard_node_mask,
            apply_sigmoid=True,
        )
        edge_mask = self._post_process_mask(
            self.edge_mask,
            self.hard_edge_mask,
            apply_sigmoid=True,
        )

        self._clean_model(model)

        # NOTE(review): the mask is negated, presumably so that edges whose
        # mask was driven towards zero (i.e. selected for deletion) rank
        # highest -- confirm against the code consuming the explanation.
        return Explanation(node_mask=node_mask, edge_mask=-edge_mask)

    def supports(self) -> bool:
        r"""Reports every explainer/model configuration as supported."""
        return True

    def _train(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ):
        r"""Optimises :attr:`edge_weight` for :attr:`epochs` epochs.

        Raises:
            ValueError: If no edge mask is configured, or if no gradient
                reaches the node mask or the edge logits in the first epoch.
        """
        # Reference quantities used to normalise the prediction term of the
        # loss. NOTE(review): the built-in `sum` followed by `.item()` only
        # works when the result has a single element (e.g. a single feature
        # column) -- confirm the expected shapes.
        with torch.no_grad():
            self.infl_orig = sum(model(x, edge_index, **kwargs)).item()
        self.seed_size = sum(x).item()

        self._initialize_masks(x, edge_index)

        if self.edge_weight is None:
            # Without this check the optimiser would be handed `None` and
            # fail with an unrelated error message.
            raise ValueError(
                f"'{self.__class__.__name__}' only learns edge masks. "
                f"Please set `edge_mask_type='object'`.")

        parameters = [self.edge_weight]

        optimizer = torch.optim.Adam(parameters, lr=self.lr)
        scheduler = optim.lr_scheduler.MultiplicativeLR(
            optimizer, lr_lambda=lambda epoch: self.lr_gamma)

        for i in range(self.epochs):
            # Draw a fresh relaxed mask at the annealed temperature and
            # inject it into the model's message passing layers.
            clear_masks(model)
            temperature = self._get_temperature(i)
            self.edge_mask = self._concrete_sample(self.edge_weight,
                                                   temperature)
            set_masks(model, self.edge_mask, edge_index, apply_sigmoid=True)

            optimizer.zero_grad()

            y_hat, y = model(x, edge_index, **kwargs), target

            if index is not None:
                y_hat, y = y_hat[index], y[index]

            loss, pred, size, ent = self._loss(y_hat, y)
            if i % 100 == 0:
                print(f'{i} {pred.item():.2f} {size.item():.2f} {ent.item():.4f} {loss.item():.2f}')

            loss.backward()
            torch.nn.utils.clip_grad_norm_(parameters, 0.01)
            optimizer.step()
            # Apply the multiplicative decay; without this call `lr_gamma`
            # would have no effect on training.
            scheduler.step()

            # In the first iteration, we collect the nodes and edges that are
            # involved into making the prediction. These are all the nodes and
            # edges with gradient != 0 (without regularization applied).
            if i == 0 and self.node_mask is not None:
                if self.node_mask.grad is None:
                    raise ValueError("Could not compute gradients for node "
                                     "features. Please make sure that node "
                                     "features are used inside the model or "
                                     "disable it via `node_mask_type=None`.")
                self.hard_node_mask = self.node_mask.grad != 0.0
            if i == 0 and self.edge_mask is not None:
                if self.edge_weight.grad is None:
                    raise ValueError("Could not compute gradients for edges. "
                                     "Please make sure that edges are used "
                                     "via message passing inside the model or "
                                     "disable it via `edge_mask_type=None`.")
                self.hard_edge_mask = self.edge_weight.grad != 0.0

    def _initialize_masks(self, x: Tensor, edge_index: Tensor):
        r"""Allocates the node mask and the edge logits for a new run and
        records the number of nodes (:attr:`n`) and edges (:attr:`m`).
        """
        node_mask_type = self.explainer_config.node_mask_type
        edge_mask_type = self.explainer_config.edge_mask_type

        device = x.device
        (N, F), E = x.size(), edge_index.size(1)
        self.m = E
        self.n = N

        std = 0.1
        if node_mask_type is None:
            self.node_mask = None
        elif node_mask_type == MaskType.object:
            self.node_mask = Parameter(torch.randn(N, 1, device=device) * std)
        elif node_mask_type == MaskType.attributes:
            self.node_mask = Parameter(torch.randn(N, F, device=device) * std)
        elif node_mask_type == MaskType.common_attributes:
            self.node_mask = Parameter(torch.randn(1, F, device=device) * std)
        else:
            assert False

        if edge_mask_type is None:
            self.edge_mask = None
            # Drop logits left over from a previous run so that they cannot
            # be trained against a graph they were not created for.
            self.edge_weight = None
        elif edge_mask_type == MaskType.object:
            # One unconstrained logit per edge, drawn from N(0, 1).
            self.edge_weight = torch.nn.Parameter(
                torch.randn(E, requires_grad=True, device=device))
            temperature = self._get_temperature(0)
            self.edge_mask = self._concrete_sample(self.edge_weight,
                                                   temperature)
        else:
            assert False

    def _loss(self, y_hat: Tensor,
              y: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        r"""Computes the training objective for the current sampled mask.

        Args:
            y_hat (Tensor): The model output under the current mask; only
                its first entry is used.
            y (Tensor): The target (not used by this objective).

        Returns:
            A tuple :obj:`(loss, pred, size, ent)` holding the total loss,
            the first entry of :obj:`y_hat`, the soft count
            :math:`E - \sum_e \sigma(m_e)` and the mean binary entropy of
            the mask.
        """
        pred = y_hat[0]

        # During the first epoch `hard_edge_mask` is still `None`; indexing
        # with `None` merely adds a leading axis, so every edge is included.
        mask = self.edge_mask[self.hard_edge_mask].sigmoid()
        size = self.m - mask.sum()

        eps = self.coeffs['EPS']
        ent = -mask * torch.log(mask + eps) - (1 - mask) * torch.log(1 - mask + eps)
        ent = ent.mean()

        # Reward a drop of the prediction relative to the unmasked output,
        # pull the soft count towards `del_edge_num`, and push the mask
        # entries towards 0/1 through the entropy term.
        pred_term = -(self.infl_orig - pred) / (self.infl_orig - self.seed_size)
        size_term = self.size_coeff * (size / self.del_edge_num - 1)**2
        loss = pred_term + size_term + self.ent_coeff * ent

        return loss, pred, size, ent

    def _clean_model(self, model):
        r"""Removes the injected masks from :obj:`model` and resets the
        per-run mask state (the learned :attr:`edge_weight` is kept).
        """
        clear_masks(model)
        self.node_mask = self.hard_node_mask = None
        self.edge_mask = self.hard_edge_mask = None

    def _get_temperature(self, epoch: int) -> float:
        r"""Returns the temperature for :obj:`epoch`, interpolated
        geometrically from :obj:`coeffs['temp'][0]` (at epoch 0) towards
        :obj:`coeffs['temp'][1]` (at epoch :attr:`epochs`).
        """
        temp = self.coeffs['temp']
        return temp[0] * pow(temp[1] / temp[0], epoch / self.epochs)

    def _concrete_sample(self, logits: Tensor,
                         temperature: float = 1.0) -> Tensor:
        r"""Draws a relaxed sample by adding logistic noise to :obj:`logits`
        and dividing by :obj:`temperature`.

        The uniform noise is restricted to :obj:`[bias, 1 - bias)`, where
        :obj:`bias` is :obj:`coeffs['bias']`. The result is still a logit;
        callers apply the sigmoid.
        """
        bias = self.coeffs['bias']
        eps = (1 - 2 * bias) * torch.rand_like(logits) + bias
        return (eps.log() - (1 - eps).log() + logits) / temperature