In [1]:
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import torch_geometric.transforms as T
from torch_geometric.datasets import BAShapes
from torch_geometric.nn import GCN, GNNExplainer
from torch_geometric.utils import k_hop_subgraph

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed PyTorch's global RNG and the scikit-learn split so that the
# train/test partition and the model's initial weights are the same on
# every "Restart & Run All". Other RNG sources (e.g. NumPy or Python's
# `random`, if library code draws from them) are not seeded here.
SEED = 42
torch.manual_seed(SEED)

# Load the BAShapes graph with the GCNNorm transform applied to it.
# NOTE(review): the model below is built with `normalize=False` and is fed
# `data.edge_weight` explicitly, presumably because the transform already
# produced the normalised edge weights -- confirm against the PyG docs.
dataset = BAShapes(transform=T.GCNNorm())
data = dataset[0]

# Stratified 80/20 split over node indices; `random_state` pins the split.
idx = torch.arange(data.num_nodes)
train_idx, test_idx = train_test_split(
    idx,
    train_size=0.8,
    stratify=data.y,
    random_state=SEED,
)

# Use the GPU when one is available, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)

# Three-layer GCN with 20 hidden channels and one output per class.
model = GCN(
    data.num_node_features,
    hidden_channels=20,
    num_layers=3,
    out_channels=dataset.num_classes,
    normalize=False,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.005)

In [3]:
def train():
    """Run one full-batch optimisation step and return the training loss.

    Uses the notebook-level ``model``, ``optimizer``, ``data`` and
    ``train_idx``. The loss is the cross-entropy between the model output
    and ``data.y``, both indexed by ``train_idx``.

    Returns:
        float: the loss computed before the parameter update.
    """
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index, edge_weight=data.edge_weight)
    loss = F.cross_entropy(out[train_idx], data.y[train_idx])
    loss.backward()
    # Gradient clipping has to run after backward() and before step():
    # clipping right after zero_grad() only sees cleared gradients and
    # therefore has no effect on the update.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
    optimizer.step()
    return float(loss)


@torch.no_grad()
def test():
    model.eval()
    pred = model(data.x, data.edge_index, edge_weight=data.edge_weight).argmax(dim=-1)

    train_correct = int((pred[train_idx] == data.y[train_idx]).sum())
    train_acc = train_correct / train_idx.size(0)

    test_correct = int((pred[test_idx] == data.y[test_idx]).sum())
    test_acc = test_correct / test_idx.size(0)

    return train_acc, test_acc


# Total number of optimisation steps and how often accuracy is reported.
NUM_EPOCHS = 2000
LOG_EVERY = 200

for epoch in range(1, NUM_EPOCHS + 1):
    loss = train()
    if epoch % LOG_EVERY != 0:
        continue
    # Only evaluate on logging epochs; the printed loss is the one returned
    # by the training step of this same epoch.
    train_acc, test_acc = test()
    print(f'Epoch: {epoch:04d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, Test: {test_acc:.4f}')

Epoch: 0200, Loss: 1.2757, Train: 0.4286, Test: 0.4286
Epoch: 0400, Loss: 1.2551, Train: 0.4286, Test: 0.4286
Epoch: 0600, Loss: 1.2107, Train: 0.4286, Test: 0.4286
Epoch: 0800, Loss: 1.1368, Train: 0.4286, Test: 0.4286
Epoch: 1000, Loss: 1.0079, Train: 0.6250, Test: 0.6214
Epoch: 1200, Loss: 0.8570, Train: 0.7982, Test: 0.8286
Epoch: 1400, Loss: 0.7577, Train: 0.8750, Test: 0.8786
Epoch: 1600, Loss: 0.6964, Train: 0.8768, Test: 0.8786
Epoch: 1800, Loss: 0.6390, Train: 0.8839, Test: 0.8786
Epoch: 2000, Loss: 0.5966, Train: 0.8857, Test: 0.8786


In [4]:
# Switch the trained network to evaluation mode before explaining it.
model.eval()
expl = GNNExplainer(model, epochs=300, return_type='raw', log=False)

# True for every edge whose two endpoints differ, i.e. it drops self-loops.
loop_mask = data.edge_index[0] != data.edge_index[1]

# Indices of the nodes flagged by `data.expl_mask`, as plain Python ints.
explained_nodes = data.expl_mask.nonzero(as_tuple=False).view(-1).tolist()

# Collect per-node ground-truth edge labels and explainer edge scores; they
# are concatenated later to compute a single explanation ROC AUC.
targets, preds = [], []
for node_idx in tqdm(explained_nodes):
    # Only the edge mask (second return value) of the explanation is used.
    _, expl_edge_mask = expl.explain_node(
        node_idx,
        data.x,
        data.edge_index,
        edge_weight=data.edge_weight,
    )
    expl_edge_mask = expl_edge_mask[loop_mask]

    # Fourth element returned by k_hop_subgraph for the 3-hop neighbourhood
    # of this node, restricted to the non-self-loop edges.
    subgraph = k_hop_subgraph(node_idx, num_hops=3, edge_index=data.edge_index)
    subgraph_edge_mask = subgraph[3][loop_mask]

    targets.append(data.edge_label[subgraph_edge_mask].cpu())
    preds.append(expl_edge_mask[subgraph_edge_mask].cpu())

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [01:07<00:00,  1.12s/it]


In [5]:
# Flatten the per-node lists into one label tensor and one score tensor,
# then score the explainer's edge ranking against the edge labels.
y_true = torch.cat(targets)
y_score = torch.cat(preds)
auc = roc_auc_score(y_true, y_score)

print(f'ROC AUC: {auc:.4f}')

ROC AUC: 0.4330
