In [1]:
import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import SNAPDataset

torch.manual_seed(0)

# Load the dataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

orig_transform = T.Compose(
    [
        T.ToDevice(device),
        T.RemoveIsolatedNodes(),
    ]
)

transform = T.Compose(
    [
        T.ToDevice(device),
        T.RemoveIsolatedNodes(),
        T.RandomLinkSplit(
            num_val=0.05,
            num_test=0.1,
            is_undirected=True,
            add_negative_train_samples=False,
        ),
    ]
)

dataset = SNAPDataset(
    root="./data/SNAPDataset", name="ego-facebook", transform=transform
)
train_data, val_data, test_data = dataset[0]

  from .autonotebook import tqdm as notebook_tqdm
Downloading https://snap.stanford.edu/data/facebook.tar.gz
Extracting data\SNAPDataset\ego-facebook\raw\facebook.tar.gz
Processing...
Done!


In [3]:
import warnings

import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, roc_auc_score
from torch import nn
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling

warnings.filterwarnings("ignore", category=UserWarning)


class SimpleNet(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    def decode(self, z, edge_label_index):
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)

    def forward(self, x, edge_index, edge_label_index, data=None):
        z = self.encode(x, edge_index)
        out = self.decode(z, edge_label_index)
        return torch.hstack((-out, out)).T


class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # TODO: look into SAGEConv, GATConv, GINConv, comparison between
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

        self.W1 = nn.Linear(out_channels * 2, out_channels)
        self.W2 = nn.Linear(out_channels, 1)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    def decode(self, z, edge_label_index):
        z1 = torch.cat((z[edge_label_index[0]], z[edge_label_index[1]]), dim=1)
        out1 = self.W2(F.relu(self.W1(z1)).squeeze())

        z2 = torch.cat((z[edge_label_index[1]], z[edge_label_index[0]]), dim=1)
        out2 = self.W2(F.relu(self.W1(z2)).squeeze())

        return (out1 + out2) / 2

    def forward(self, x, edge_index, edge_label_index, data=None):
        z = self.encode(x, edge_index)
        out = self.decode(z, edge_label_index)
        return torch.hstack((-out, out)).T


def train(model, optimizer, criterion, data):
    model.train()
    optimizer.zero_grad()
    z = model.encode(data.x, data.edge_index)

    # We perform a new round of negative sampling for every training epoch:
    neg_edge_index = negative_sampling(
        edge_index=data.edge_index,
        num_nodes=data.num_nodes,
        num_neg_samples=data.edge_label_index.shape[1],
        method="sparse",
    )

    edge_label_index = torch.cat([data.edge_label_index, neg_edge_index], dim=-1)
    edge_label = torch.cat(
        [data.edge_label, data.edge_label.new_zeros(neg_edge_index.size(1))], dim=0
    )

    out = model.decode(z, edge_label_index).view(-1)
    loss = criterion(out, edge_label)
    loss.backward()
    optimizer.step()
    return loss


@torch.no_grad()
def test(model, data):
    model.eval()
    z = model.encode(data.x, data.edge_index)
    out = model.decode(z, data.edge_label_index).view(-1).sigmoid()
    a, b = data.edge_label.cpu().numpy(), out.cpu().numpy()
    c = (out > 0.5).float().cpu().numpy()
    return roc_auc_score(a, b), accuracy_score(a, c)


def train_simple_model(epochs):
    simple_model = SimpleNet(dataset.num_features, 128, 32).to(device)
    simple_optimizer = torch.optim.Adam(params=simple_model.parameters(), lr=3e-3)
    simple_criterion = nn.BCEWithLogitsLoss()

    best_val_auc = final_test_auc = final_test_acc = 0
    best_model_dict = simple_model.state_dict()
    for epoch in range(1, epochs + 1):
        loss = train(simple_model, simple_optimizer, simple_criterion, train_data)
        val_auc, val_acc = test(simple_model, val_data)
        test_auc, test_acc = test(simple_model, test_data)
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            final_test_auc = test_auc
            final_test_acc = test_acc
            best_model_dict = simple_model.state_dict()
        if epoch % 50 == 0:
            print(
                f"Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f} {val_acc:.4f}, Test: {test_auc:.4f} {test_acc:.4f}"
            )

    print(f"Final Test: {final_test_auc:.4f} {final_test_acc:.4f}")
    print()

    return simple_model, best_model_dict


def train_model(epochs):
    model = Net(dataset.num_features, 128, 32).to(device)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-3)
    criterion = torch.nn.BCEWithLogitsLoss()

    best_val_auc = final_test_auc = final_test_acc = 0
    best_model_dict = model.state_dict()
    for epoch in range(1, epochs + 1):
        loss = train(model, optimizer, criterion, train_data)
        val_auc, val_acc = test(model, val_data)
        test_auc, test_acc = test(model, test_data)
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            final_test_auc = test_auc
            final_test_acc = test_acc
            best_model_dict = model.state_dict()
        if epoch % 50 == 0:
            print(
                f"Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f} {val_acc:.4f}, Test: {test_auc:.4f} {test_acc:.4f}"
            )

    print(f"Final Test: {final_test_auc:.4f} {final_test_acc:.4f}")
    print()

    return model, best_model_dict

epochs = 1000

simple_model, _ = train_simple_model(epochs)
model, _ = train_model(epochs)

Epoch: 050, Loss: 0.4899, Val: 0.9342 0.7887, Test: 0.9022 0.7491
Epoch: 100, Loss: 0.4601, Val: 0.9398 0.8099, Test: 0.9011 0.7439
Epoch: 150, Loss: 0.4547, Val: 0.9446 0.7852, Test: 0.9059 0.7439
Epoch: 200, Loss: 0.4549, Val: 0.9267 0.8063, Test: 0.8937 0.7404
Epoch: 250, Loss: 0.4351, Val: 0.9289 0.8169, Test: 0.8961 0.7544
Epoch: 300, Loss: 0.4271, Val: 0.9305 0.8099, Test: 0.8965 0.7544
Epoch: 350, Loss: 0.4293, Val: 0.9234 0.8063, Test: 0.8909 0.7491
Epoch: 400, Loss: 0.4257, Val: 0.9271 0.7923, Test: 0.8909 0.7404
Epoch: 450, Loss: 0.4275, Val: 0.9159 0.8063, Test: 0.8822 0.7491
Epoch: 500, Loss: 0.4137, Val: 0.9168 0.8169, Test: 0.8813 0.7614
Epoch: 550, Loss: 0.4199, Val: 0.9149 0.8028, Test: 0.8810 0.7491
Epoch: 600, Loss: 0.4223, Val: 0.9051 0.8169, Test: 0.8756 0.7509
Epoch: 650, Loss: 0.4117, Val: 0.9117 0.8099, Test: 0.8785 0.7649
Epoch: 700, Loss: 0.4144, Val: 0.9142 0.8099, Test: 0.8825 0.7667
Epoch: 750, Loss: 0.4168, Val: 0.9085 0.8063, Test: 0.8746 0.7667
Epoch: 800

In [4]:
class Explainer:
    def __init__(self, pred_model, x, edge_index):
        self.pred_model = pred_model
        self.x = x
        self.edge_index = edge_index

    def explain_edge(self, node_idx_1, node_idx_2):
        raise NotImplementedError

In [5]:
import numpy as np
import torch
from numpy import ndarray
from torch_geometric.utils import k_hop_subgraph


def edge_centered_subgraph(node_idx_1, node_idx_2, x, edge_index, num_hops):
    num_nodes = x.size(0)

    subset_1, _, _, edge_mask_1 = k_hop_subgraph(
        node_idx_1, num_hops, edge_index, num_nodes=num_nodes
    )
    subset_2, _, _, edge_mask_2 = k_hop_subgraph(
        node_idx_2, num_hops, edge_index, num_nodes=num_nodes
    )

    # Combines two node-centered subgraphs
    temp_node_idx = edge_index[0].new_full((num_nodes,), -1)  # full size
    edge_mask = edge_mask_1 | edge_mask_2
    edge_index = edge_index[:, edge_mask]  # filters out edges
    subset = torch.cat((subset_1, subset_2)).unique()
    temp_node_idx[subset] = torch.arange(subset.size(0), device=edge_index.device)
    edge_index = temp_node_idx[edge_index]  # maps edge_index to [0, n]
    x = x[subset]  # filters out nodes
    mapping = torch.tensor(
        [
            (subset == node_idx_1).nonzero().item(),
            (subset == node_idx_2).nonzero().item(),
        ]
    )

    return x, edge_index, mapping, subset, edge_mask


def get_neighbors(edge_index, node_idx_1: int, node_idx_2: int) -> ndarray:
    node_1_neighbors = set(
        edge_index[:, (edge_index[0] == node_idx_1)][1].cpu().numpy()
    )
    node_2_neighbors = set(
        edge_index[:, (edge_index[0] == node_idx_2)][1].cpu().numpy()
    )
    neighbors = np.array(list(node_1_neighbors.union(node_2_neighbors)))
    return neighbors


def mask_nodes(x, mask):
    new_x = x.clone()
    new_x[~mask] = 0
    return new_x

In [24]:
from math import sqrt

import torch
from torch_geometric.nn import GNNExplainer as PyG_GNNExplainer
from torch_geometric.nn.models.explainer import set_masks

EPS = 1e-15


class _GNNExplainer(PyG_GNNExplainer):
    coeffs = {
        "edge_size": 0.10,
        "edge_reduction": "sum",
        "edge_ent": 1,
    }

    def _initialize_masks(self, x, edge_index, sub_edge_mask=None):
        (N, F), E = x.size(), edge_index.size(1)
        self.node_feat_mask = torch.nn.Parameter(100 * torch.ones(1, F))

        std = torch.nn.init.calculate_gain("relu") * sqrt(2.0 / (2 * N))
        if sub_edge_mask is None:
            self.edge_mask = torch.nn.Parameter(torch.randn(E) * std)
        else:
            E_1, mask = sub_edge_mask.sum(), 100 * torch.ones(E)
            mask[sub_edge_mask] = torch.randn(E_1) * std
            self.edge_mask = torch.nn.Parameter(mask)

    """
    def _loss(self, log_logits, prediction, node_idx=None):
        error_loss = -log_logits[prediction]

        m = self.edge_mask[self.sub_edge_mask].sigmoid()
        edge_reduce = getattr(torch, self.coeffs["edge_reduction"])
        edge_size_loss = edge_reduce(m)
        ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
        edge_ent_loss = ent.mean()

        loss = (
            error_loss
            + self.coeffs["edge_size"] * edge_size_loss
            + self.coeffs["edge_ent"] * edge_ent_loss
        )
        
        print(
            round(error_loss.item(), 4), "  \t", 
            round(edge_size_loss.item(), 4), "  \t", 
            round(edge_ent_loss.item(), 4), "  \t",
            round(loss.item(), 4), "  \t"
        )

        return loss
    """
    
    def _loss(self, log_logits, prediction, node_idx = None):
        error_loss = -log_logits[prediction].clip(-6, 6)
                
        m = self.edge_mask[self.sub_edge_mask].sigmoid()
        edge_reduce = getattr(torch, self.coeffs['edge_reduction'])
        edge_size_loss = edge_reduce(m)
        ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
        edge_ent_loss = ent.mean()

        loss = -error_loss * (1 - torch.mean(m)) - 1 * edge_ent_loss
        
        print(
            round(error_loss.item(), 4), "  \t", 
            round(edge_size_loss.item(), 4), "  \t", 
            round(edge_ent_loss.item(), 4), "  \t",
            round(loss.item(), 4), "  \t"
        )
        
        return loss

    def explain_edge(self, node_idx_1, node_idx_2, x, edge_index):
        self.model.eval()
        self._clear_masks()

        num_edges = edge_index.size(1)

        # Only operate on a k-hop subgraph around `node_idx_1` and `node_idx_2.
        (
            x,
            edge_index,
            mapping,
            _,
            hard_edge_mask,
        ) = edge_centered_subgraph(node_idx_1, node_idx_2, x, edge_index, self.num_hops)

        # Only optimizes the edges from neighbors to node_1/node_2, other direction not needed for prediction
        self.sub_edge_mask = (edge_index[1] == mapping[0]) | (
            edge_index[1] == mapping[1]
        )
        edge_label_index = mapping.unsqueeze(1)

        # Get the initial prediction
        prediction = self.get_initial_prediction(
            x, edge_index, edge_label_index=edge_label_index
        )

        self._initialize_masks(x, edge_index, self.sub_edge_mask)
        self.to(x.device)

        set_masks(self.model, self.edge_mask, edge_index, apply_sigmoid=True)
        optimizer = torch.optim.Adam([self.edge_mask], lr=self.lr)

        for epoch in range(1, self.epochs + 1):
            optimizer.zero_grad()
            out = self.model(
                x=x, edge_index=edge_index, edge_label_index=edge_label_index
            )
            loss = self.get_loss(out, prediction, mapping).mean()
            loss.backward()
            optimizer.step()

        edge_mask = self.edge_mask.new_zeros(num_edges)
        edge_mask[hard_edge_mask] = self.edge_mask.detach().sigmoid()

        self._clear_masks()

        return edge_mask


class GNNExplainer(Explainer):
    def __init__(self, pred_model, x, edge_index, epochs=200, lr=1e-2):
        super().__init__(pred_model, x, edge_index)
        self.explainer = _GNNExplainer(pred_model, epochs=epochs, lr=lr)

    def explain_edge(self, node_idx_1, node_idx_2):
        edge_mask = self.explainer.explain_edge(
            node_idx_1, node_idx_2, self.x, self.edge_index
        )

        output = {}
        edge_filter = (self.edge_index[1] == node_idx_1) | (
            self.edge_index[1] == node_idx_2
        )
        temp_edge_index = self.edge_index[0, edge_filter].cpu().numpy()
        temp_edge_mask = edge_mask[edge_filter].cpu().numpy()
        for node_idx, weight in zip(temp_edge_index, temp_edge_mask):
            output[node_idx] = (output.get(node_idx, weight) + weight) / 2

        return output

In [25]:
def sample_gnnexplainer(model, x, edge_index, node_idx_1, node_idx_2):
    # GNNExplainer, 200 queries per explanation
    gnnexplainer = GNNExplainer(model, x, edge_index, epochs=200, lr=0.1)
    output = gnnexplainer.explain_edge(node_idx_1, node_idx_2)

    print("GNNExplainer Output")
    for node_idx, weight in sorted(output.items(), key=lambda x: -x[1]):
        print(node_idx, "\t", round(weight, 4))
    print()

    return output

x, edge_index = test_data.x, test_data.edge_index 
node_idx_1, node_idx_2 = 24, 187

print("Nodes:", node_idx_1, node_idx_2)
print(
    "Neighorhood Size:", get_neighbors(edge_index, node_idx_1, node_idx_2).shape[0]
)
print()


sample_gnnexplainer(model, x, edge_index, node_idx_1, node_idx_2);

Nodes: 24 187
Neighorhood Size: 81

0.5718   	 51.7357   	 0.6925   	 -0.9798   	
0.8111   	 49.2917   	 0.6911   	 -1.1177   	
1.0504   	 46.8579   	 0.6871   	 -1.2643   	
1.2891   	 44.4966   	 0.681   	 -1.4185   	
1.5253   	 42.1776   	 0.6725   	 -1.5792   	
1.7579   	 39.9029   	 0.6619   	 -1.7453   	
1.9857   	 37.6796   	 0.6492   	 -1.9154   	
2.2075   	 35.5158   	 0.6345   	 -2.0882   	
2.4218   	 33.4199   	 0.6181   	 -2.2617   	
2.6284   	 31.3996   	 0.6002   	 -2.4351   	
2.8251   	 29.4617   	 0.581   	 -2.6058   	
2.978   	 27.6165   	 0.5608   	 -2.748   	
3.1015   	 25.8753   	 0.5402   	 -2.87   	
3.1905   	 24.2541   	 0.5197   	 -2.9661   	
3.2551   	 22.7587   	 0.4997   	 -3.0425   	
3.3184   	 21.398   	 0.4807   	 -3.1163   	
3.3799   	 20.1619   	 0.4626   	 -3.1873   	
3.4391   	 19.0401   	 0.4455   	 -3.255   	
3.4921   	 18.0228   	 0.4293   	 -3.3163   	
3.539   	 17.1133   	 0.4144   	 -3.371   	
3.5806   	 16.3086   	 0.4008   	 -3.4199   	
3.6199  

4.8541   	 14.1997   	 0.1571   	 -4.3484   	
4.8551   	 14.2145   	 0.1569   	 -4.3485   	
4.8553   	 14.2106   	 0.1568   	 -4.3486   	
4.8553   	 14.2025   	 0.1566   	 -4.3488   	
4.8552   	 14.1907   	 0.1563   	 -4.3491   	
4.856   	 14.1963   	 0.1562   	 -4.3493   	
4.8563   	 14.1969   	 0.156   	 -4.3494   	
4.8575   	 14.2137   	 0.1559   	 -4.3495   	
4.8577   	 14.2112   	 0.1557   	 -4.3496   	
4.8582   	 14.2123   	 0.1555   	 -4.3498   	
4.8587   	 14.2165   	 0.1554   	 -4.3499   	
4.8584   	 14.2025   	 0.1552   	 -4.3501   	
4.8571   	 14.1722   	 0.1549   	 -4.3501   	
4.8568   	 14.1613   	 0.1547   	 -4.3502   	
4.8576   	 14.1682   	 0.1546   	 -4.3505   	
4.858   	 14.1698   	 0.1545   	 -4.3506   	
4.8582   	 14.1667   	 0.1543   	 -4.3507   	
4.8593   	 14.1805   	 0.1542   	 -4.3509   	
4.8601   	 14.1883   	 0.154   	 -4.3511   	
4.8606   	 14.1905   	 0.1539   	 -4.3512   	
4.8618   	 14.2091   	 0.1538   	 -4.3513   	
GNNExplainer Output
346 	 0.9928
341 	