In [1]:
# Install PyTorch Geometric from the package index via a shell escape.
# NOTE(review): the version is not pinned; the recorded output of this cell shows 2.0.4 was resolved — consider pinning it.
!pip install torch_geometric

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torch_geometric
  Downloading torch_geometric-2.0.4.tar.gz (407 kB)
[K     |████████████████████████████████| 407 kB 8.0 MB/s 
Building wheels for collected packages: torch-geometric
  Building wheel for torch-geometric (setup.py) ... [?25l[?25hdone
  Created wheel for torch-geometric: filename=torch_geometric-2.0.4-py3-none-any.whl size=616603 sha256=e7f95147b4904d44b3ee9cf88f971cb1b8798c08fae30614006a36196f025eee
  Stored in directory: /root/.cache/pip/wheels/18/a6/a4/ca18c3051fcead866fe7b85700ee2240d883562a1bc70ce421
Successfully built torch-geometric
Installing collected packages: torch-geometric
Successfully installed torch-geometric-2.0.4


In [2]:
import os
import torch
# Export the installed torch version so the shell commands below can expand
# ${TORCH} when building the wheel-index URL.
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

# Companion packages, looked up in the wheel index that matches this torch build.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# Installs torch-geometric again, this time straight from the GitHub repository.
# NOTE(review): presumably this supersedes the PyPI install from the first cell — confirm which version is intended.
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

1.11.0+cu113
[K     |████████████████████████████████| 7.9 MB 8.9 MB/s 
[K     |████████████████████████████████| 3.5 MB 8.2 MB/s 
[?25h  Building wheel for torch-geometric (setup.py) ... [?25l[?25hdone


In [3]:
import os
import torch
import matplotlib.pyplot as plt

from torch_geometric.nn import GCNConv
from torch_geometric.datasets import TUDataset
from torch_geometric.transforms import NormalizeFeatures
import torch_geometric.transforms as T

# Compose of T.GCNNorm() and T.NormalizeFeatures().
# NOTE(review): `transform` is built here, but the dataset below is created
# with transform=None, so it is not applied in this cell — confirm that is intentional.
transform = T.Compose([T.GCNNorm(), T.NormalizeFeatures()])

# Load the MUTAG dataset; files are stored under data/MUTAG.
dataset = TUDataset(root="data/MUTAG", name="MUTAG", transform=None)
print(dataset)
# Inspect the first graph of the dataset.
data = dataset[0]
print(data)

Downloading https://www.chrsmrrs.com/graphkerneldatasets/MUTAG.zip


MUTAG(188)
Data(edge_index=[2, 38], x=[17, 7], edge_attr=[38, 4], y=[1])


Extracting data/MUTAG/MUTAG/MUTAG.zip
Processing...
Done!


In [4]:
# Report how many target classes the dataset exposes.
print(f'Number of classes: {dataset.num_classes}')

Number of classes: 2


In [5]:
import numpy as np
# Seed the torch RNG before shuffling (presumably so the split is repeatable).
torch.manual_seed(12345)
# NOTE(review): rebinding `dataset` makes this cell non-idempotent — running
# it a second time shuffles the already-shuffled dataset again.
dataset = dataset.shuffle()

# First 80% of the shuffled graphs go to training, the remainder to testing.
split = int(0.8 * len(dataset))
train_dataset, test_dataset = dataset[:split], dataset[split:]

print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')

Number of training graphs: 150
Number of test graphs: 38


In [6]:
from torch_geometric.loader import DataLoader

# Mini-batch loaders of up to 64 graphs; only the training loader is created with shuffle=True.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Print a summary of every training batch.
# NOTE: the loop variable rebinds the notebook-level name `data`.
for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print()

Step 1:
Number of graphs in the current batch: 64
DataBatch(edge_index=[2, 2636], x=[1188, 7], edge_attr=[2636, 4], y=[64], batch=[1188], ptr=[65])

Step 2:
Number of graphs in the current batch: 64
DataBatch(edge_index=[2, 2506], x=[1139, 7], edge_attr=[2506, 4], y=[64], batch=[1139], ptr=[65])

Step 3:
Number of graphs in the current batch: 22
DataBatch(edge_index=[2, 852], x=[387, 7], edge_attr=[852, 4], y=[22], batch=[387], ptr=[23])



In [7]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool


class GCN(torch.nn.Module):
    """Graph-level classifier: three GCNConv layers, mean-pool readout, linear head.

    Args:
        hidden_channels: Output width of every GCNConv layer, and therefore
            the width of the pooled per-graph embedding.
        in_channels: Number of input node features. When omitted, it is read
            from the notebook-level ``dataset`` (``dataset.num_node_features``).
        num_classes: Number of output logits. When omitted, it is read from
            the notebook-level ``dataset`` (``dataset.num_classes``).
    """

    def __init__(self, hidden_channels, in_channels=None, num_classes=None):
        super().__init__()
        # Only fall back to the global `dataset` when explicit sizes were not
        # given, so the model can be built regardless of what `dataset` is
        # currently bound to in the notebook.
        if in_channels is None:
            in_channels = dataset.num_node_features
        if num_classes is None:
            num_classes = dataset.num_classes
        # Seeds the global torch RNG before the layers are created
        # (presumably for repeatable weight initialisation).
        torch.manual_seed(12345)
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        """Return the output of the final linear layer for the pooled node embeddings.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity passed to every GCNConv layer.
            batch: Passed to ``global_mean_pool`` to group nodes by graph.
        """
        # 1. Obtain node embeddings (no activation after the last convolution).
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier; dropout is only active in training mode.
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)

        return x

# Build the classifier with 64 hidden channels and print its layer summary.
model = GCN(hidden_channels=64)
print(model)

GCN(
  (conv1): GCNConv(7, 64)
  (conv2): GCNConv(64, 64)
  (conv3): GCNConv(64, 64)
  (lin): Linear(in_features=64, out_features=2, bias=True)
)


In [8]:
from IPython.display import Javascript
# Colab-specific: asks the output frame to limit its height to 300
# (presumably so the epoch log scrolls instead of filling the page).
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))

# Start from a freshly constructed model, replacing any earlier `model` binding.
model = GCN(hidden_channels=64)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# NOTE(review): the recorded log of this cell shows identical train/test accuracy
# for every epoch with SGD at lr=0.01 — check whether the model is actually
# learning, or whether the Adam line above was the intended optimizer.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

def train():
    """Run one optimisation pass over every mini-batch of ``train_loader``.

    Operates on the notebook-level ``model``, ``criterion`` and ``optimizer``;
    returns nothing.
    """
    model.train()

    for batch in train_loader:
        # Forward pass over all graphs of the mini-batch.
        logits = model(batch.x, batch.edge_index, batch.batch)
        # Loss -> gradients.
        criterion(logits, batch.y).backward()
        # Apply the update, then reset gradients for the next mini-batch.
        optimizer.step()
        optimizer.zero_grad()

def test(loader):
    """Return the accuracy of the notebook-level ``model`` on ``loader``.

    Args:
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``
            and ``y``; it must also expose ``dataset`` so the number of
            graphs can be taken from ``len(loader.dataset)``.

    Returns:
        Number of correctly classified graphs divided by ``len(loader.dataset)``.
    """
    model.eval()

    correct = 0
    # Evaluation never calls backward(), so skip autograd bookkeeping.
    with torch.no_grad():
        for data in loader:  # Iterate in batches over the training/test dataset.
            out = model(data.x, data.edge_index, data.batch)
            pred = out.argmax(dim=1)  # Index of the largest logit per graph.
            correct += int((pred == data.y).sum())  # Check against ground-truth labels.
    return correct / len(loader.dataset)  # Derive ratio of correct predictions.


# Train for 30 epochs; after each one, measure accuracy on both splits.
for epoch in range(1, 31):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

<IPython.core.display.Javascript object>

Epoch: 001, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 002, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 003, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 004, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 005, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 006, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 007, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 008, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 009, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 010, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 011, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 012, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 013, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 014, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 015, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 016, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 017, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 018, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 019, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 020, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 021, Train Acc: 0.6467, Test Acc:

In [9]:
# Spot-check the first test batch only (the loop breaks after one iteration).
for data in test_loader:
  out = model(data.x, data.edge_index, data.batch) 
  # NOTE(review): a softmax output lies in [0, 1], so `> 2` can never be True —
  # the threshold looks mistyped; confirm the intended value.
  print(F.softmax(out[0], dim=0)[0] > 2) 
  pred = out.argmax(dim=1) 
  # Predicted class of the first graph, then the number of graphs in the batch.
  print(pred[0].item()) 
  print(len(pred))
  break

tensor(False)
1
38


In [10]:
import os.path as osp
from re import sub
from sklearn import neighbors
import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.datasets import TUDataset
from torch_geometric.nn import GCNConv, GNNExplainer

from torch_geometric.loader import DataLoader

from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool
from torch_geometric.utils import subgraph
import numpy as np

import networkx as nx
from torch_geometric import utils

def expand(starts, ends, target, max_depth=1, prev=None):
    """Collect the nodes reachable from ``target`` within ``max_depth`` hops.

    Edges are given as two parallel sequences: edge ``i`` goes from
    ``starts[i]`` to ``ends[i]``.

    Args:
        starts: Source node of every edge.
        ends: Destination node of every edge, aligned with ``starts``.
        target: Node to expand from.
        max_depth: Number of hops to follow; values of 1 or less return only
            the direct neighbours.
        prev: Optional list that receives every node this call (and its
            recursive calls) expands. It is only appended to, never read.
            A fresh list is created when omitted, so separate calls no
            longer share one default list.

    Returns:
        ``np.ndarray`` of node ids without duplicates, in order of first
        appearance: the direct neighbours of ``target`` first, followed by
        the nodes found through each neighbour in turn. Self-loops are
        ignored, but ``target`` itself can appear when it is reached back
        through a neighbour. Empty when ``target`` has no outgoing edge.
    """
    if prev is None:
        prev = []
    direct = np.array([end for start, end in zip(starts, ends) if start == target and end != target])
    prev.append(target)
    collected = direct
    if max_depth > 1:
        # Iterate over the direct neighbours only; `collected` grows separately.
        for n in direct:
            collected = np.concatenate((collected, expand(starts, ends, target=n, max_depth=max_depth - 1, prev=prev)), axis=0)
    # np.unique sorts by value; re-sorting the first-occurrence indices
    # restores discovery order.
    first_seen = np.unique(collected, return_index=True)[1]
    return np.array([collected[i] for i in sorted(first_seen)])


def process_one_graph(data, max_depth=3, size_divisors=(2, 3, 4)):
    """Enumerate distinct candidate subgraphs of one graph, grown from every node.

    For each node, ``expand`` lists the surrounding nodes in discovery order.
    The first ``num_nodes // d`` of them (for every ``d`` in ``size_divisors``)
    form one candidate node set, which is cut out with ``data.subgraph``.

    Args:
        data: Graph object exposing ``x`` (one row per node), ``edge_index``
            (two rows: edge sources and edge destinations) and
            ``subgraph(subset=...)``.
        max_depth: Hop limit passed to ``expand``.
        size_divisors: Divisors of the node count that define the candidate
            subgraph sizes.

    Returns:
        List of ``(target_idx, subgraph)`` tuples, where ``target_idx`` is the
        node the subgraph was grown from. Each node set appears once for the
        whole graph, regardless of which start node produced it.
    """
    num_nodes = data.x.shape[0]
    subgraph_sizes = [num_nodes // divisor for divisor in size_divisors]
    start_nodes, end_nodes = np.array(data.edge_index)
    # Node sets already emitted; frozensets give constant-time membership tests.
    seen_node_sets = set()
    sub_graphs = []
    # Grow a neighbourhood from each node.
    for target_idx in range(num_nodes):
        nodes_to_keep = expand(starts=start_nodes, ends=end_nodes, target=target_idx, max_depth=max_depth, prev=[])
        if nodes_to_keep.shape[0] == 0:
            continue
        for size in subgraph_sizes:
            # Graphs with fewer nodes than a divisor would give an empty node set.
            if size < 1:
                continue
            # NOTE(review): expand() lists neighbours before the start node, so a
            # short prefix may not contain `target_idx` itself — confirm intended.
            chosen = nodes_to_keep[:size]
            node_set = frozenset(chosen.tolist())
            if node_set in seen_node_sets:
                continue
            seen_node_sets.add(node_set)
            _subset = torch.from_numpy(np.array(chosen))
            sub_graphs.append((target_idx, data.subgraph(subset=_subset)))
    return sub_graphs



# Load a fresh MUTAG instance (not shuffled in this cell) under ../data/TUDataset.
dataset_name = 'MUTAG'
# "__file__" is a string literal here, so the path is resolved relative to the
# current working directory.
path = osp.join(osp.dirname(osp.realpath("__file__")), '..', 'data', 'TUDataset')
# NOTE(review): `transform` is built but the dataset is created with transform=None — confirm intended.
transform = T.Compose([T.GCNNorm(), T.NormalizeFeatures()])
dataset = TUDataset(path, dataset_name, transform=None)

# pred_list[i]     -> model prediction for the i-th full graph
# all_subgraphs[i] -> list of (start_node, subgraph) tuples for the i-th graph
pred_list = []
all_subgraphs = []
# Make inference independent of whichever mode an earlier cell left the model
# in (dropout must be off), and skip autograd bookkeeping.
model.eval()
with torch.no_grad():
    for idx, data in enumerate(tqdm.tqdm(dataset)):
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1)
        pred_list.append(pred)

        subgraphs = process_one_graph(data)
        all_subgraphs.append(subgraphs)
    

# for target_idx, graph in tqdm.tqdm(all_subgraphs):
#     # nx.draw_networkx(utils.to_networkx(graph, remove_self_loops=True))
#     # plt.show()


Downloading https://www.chrsmrrs.com/graphkerneldatasets/MUTAG.zip
Extracting /data/TUDataset/MUTAG/MUTAG.zip
Processing...
Done!
100%|██████████| 188/188 [00:09<00:00, 20.72it/s]


In [11]:
print(type(all_subgraphs))
# Look at the first graph's entry only; index 1 selects its second
# (start_node, subgraph) tuple.
for data in all_subgraphs:
  print(data[1])
  break



<class 'list'>
(0, Data(edge_index=[2, 8], x=[5, 7], edge_attr=[8, 4], y=[1]))


In [12]:
print(type(test_dataset))
# Print only the first graph of the test split.
for test_data in test_dataset:
  print(test_data)
  break

<class 'torch_geometric.datasets.tu_dataset.TUDataset'>
Data(edge_index=[2, 50], x=[22, 7], edge_attr=[50, 4], y=[1])


In [23]:
# For every full graph, keep the candidate subgraphs for which the model
#   (a) predicts the same label it predicted for the full graph, and
#   (b) assigns that label a softmax confidence of at least `threshold`.
# min_sufficient_explanation[i] holds the retained (start_node, subgraph)
# tuples of the i-th graph, aligned with all_subgraphs and pred_list.
# NOTE(review): no minimality check is performed here — confirm the name matches the intent.
min_sufficient_explanation = []
threshold = 0.58  # NOTE(review): how this cutoff was chosen is not shown — confirm.
# Make inference independent of whichever mode an earlier cell left the model
# in (dropout must be off), and skip autograd bookkeeping.
model.eval()
with torch.no_grad():
    for idx, data_collection in enumerate(tqdm.tqdm(all_subgraphs)):
        sub_mse = []  # retained subgraphs of this one full graph
        for data in data_collection:
            out = model(data[1].x, data[1].edge_index, data[1].batch)
            # Class probabilities and predicted label for this single subgraph.
            confidence = F.softmax(out[0], dim=0)
            pred = out.argmax(dim=1)

            # Keep the subgraph when its label matches the full graph's
            # prediction and the model is confident enough about it.
            if torch.eq(pred, pred_list[idx]).item() and (confidence[pred.item()] >= threshold).item():
                sub_mse.append(data)

        min_sufficient_explanation.append(sub_mse)

  

100%|██████████| 188/188 [00:13<00:00, 14.19it/s]


In [22]:
# For the first ten graphs, compare how many subgraphs were kept with how many were generated.
for i in range(10):
  print(len(min_sufficient_explanation[i]))  # subgraphs kept for graph i
  print(len(all_subgraphs[i]))  # subgraphs generated for graph i
  print(len(min_sufficient_explanation[i]) < len(all_subgraphs[i]))  # True when at least one was filtered out
# Print the full candidate list of the first graph only.
for data in all_subgraphs:
  print(data)
  break

# Print how many subgraphs were kept for the first graph only.
for data in min_sufficient_explanation:
  print(len(data))
  break

# Reading the output: each group of three lines belongs to one graph.
# For example, the first three lines mean that 35 of the 43 subgraphs
# of the 0-th full graph are min sufficient explanations.

35
43
True
18
32
True
21
32
True
38
44
True
16
28
True
33
56
True
30
37
True
35
46
True
8
29
True
35
43
True
[(0, Data(edge_index=[2, 14], x=[7, 7], edge_attr=[14, 4], y=[1])), (0, Data(edge_index=[2, 8], x=[5, 7], edge_attr=[8, 4], y=[1])), (0, Data(edge_index=[2, 6], x=[4, 7], edge_attr=[6, 4], y=[1])), (1, Data(edge_index=[2, 14], x=[7, 7], edge_attr=[14, 4], y=[1])), (1, Data(edge_index=[2, 8], x=[5, 7], edge_attr=[8, 4], y=[1])), (2, Data(edge_index=[2, 16], x=[8, 7], edge_attr=[16, 4], y=[1])), (2, Data(edge_index=[2, 6], x=[4, 7], edge_attr=[6, 4], y=[1])), (3, Data(edge_index=[2, 8], x=[5, 7], edge_attr=[8, 4], y=[1])), (3, Data(edge_index=[2, 2], x=[4, 7], edge_attr=[2, 4], y=[1])), (4, Data(edge_index=[2, 14], x=[8, 7], edge_attr=[14, 4], y=[1])), (4, Data(edge_index=[2, 8], x=[5, 7], edge_attr=[8, 4], y=[1])), (4, Data(edge_index=[2, 2], x=[4, 7], edge_attr=[2, 4], y=[1])), (5, Data(edge_index=[2, 6], x=[4, 7], edge_attr=[6, 4], y=[1])), (6, Data(edge_index=[2, 14], x=[8, 7]

In [15]:
# Input:
#   pred_list: [tensor, tensor, tensor, ...], len = len(dataset)
#     each element is a one-element tensor holding the predicted label
#     of one graph in the TUDataset (MUTAG in this notebook)
#   all_subgraphs: [[subgraph, subgraph, ...], [subgraph, subgraph, ...], ...], len = len(dataset)
#     each element is the list of subgraphs obtained from one graph in the dataset
#     each subgraph is a tuple of the form
#     (k, Data(edge_index=[2, 8], x=[5, 7], edge_attr=[8, 4], y=[1]))
#     where k is the index of the node that this subgraph was expanded from
# Output:
#   min_sufficient_explanation: [[subgraph, subgraph, ...], [subgraph, subgraph, ...], ...]
#     len = len(dataset)
#     each element is the list of subgraphs that are the MSEs of one graph
#     each subgraph is a tuple of the form
#     (k, Data(edge_index=[2, 8], x=[5, 7], edge_attr=[8, 4], y=[1]))
#     where k is the index of the node that this subgraph was expanded from