# Initial Experiment

In this experiment, I am running all three explanation methods over a set of 10 instances which include:

- '9606.ENSP00000269228' # NPC1 which is known to have involvement in the cholesterol homeostasis process
- '9606.ENSP00000289989' # Closest to SREBP2 and has label 1 (expected)
- '9606.ENSP00000415836' # Far from SREBP2 and has label 0 (expected)
- '9606.ENSP00000216180' # Closest to SREBP2 but has label 0 (unexpected)
- '9606.ENSP00000359398' # Far from SREBP2 but has label 1 (Unexpected)
- '9606.ENSP00000346046' # False Negative (lowest confidence)
- '9606.ENSP00000473036' # True Negative (lowest confidence)
- '9606.ENSP00000449270' # False Positive (highest confidence)
- '9606.ENSP00000270176' # True Positive (highest confidence)

From each method, I will measuring the metrics:
 
- Runtime
- Fidelity 
- Stability


In [1]:
# install necessary packages

%pip install torch_geometric # install pytorch geometric
%pip install torchvision #install torchvision
%pip install matplotlib #install matplotlib
%pip install graphviz # install graphviz

%matplotlib inline

from torch_geometric.data import Data, DataLoader
from torch_geometric.explain import GNNExplainer,Explainer,GraphMaskExplainer,PGExplainer
import torch_geometric.transforms as T
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATv2Conv
from torch.nn import Linear,Softmax
import os
from tqdm import tqdm, trange
import pickle
from torch_geometric.explain.metric import fidelity, characterization_score

import matplotlib.pyplot as plt

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
# define GNN 

class SimpleGAT_1(torch.nn.Module):
    """
    A graph attention network with 4 graph layers and 2 linear layers.
    Uses v2 of graph attention that provides dynamic instead of static attention.
    The graph layer dimension and number of attention heads can be specified.
    """
    #torch.device('mps')
    def __init__(self, dataset, dim=8, num_heads=4):
        super(SimpleGAT_1, self).__init__()
        torch.manual_seed(seed=123)
        self.conv1 = GATv2Conv(in_channels=dataset.num_features, out_channels=dim, heads=num_heads,edge_dim=dataset.edge_attr.shape[1])
        self.conv2 = GATv2Conv(in_channels=dim * num_heads, out_channels=dim, heads=num_heads,edge_dim=dataset.edge_attr.shape[1])
        self.lin1 = Linear(dim * num_heads,dim)
        self.lin2 = Linear(dim,1)

    def forward(self, x, edge_index,edge_attr):
        h = self.conv1(x, edge_index,edge_attr).relu()
        h = self.conv2(h, edge_index,edge_attr).relu()
        h = self.lin1(h).relu()
        #print(h)
        h = F.dropout(h, p=0.1, training=self.training)
        out = self.lin2(h)[:,0]
        out = torch.sigmoid(out)
        return out

In [3]:
# load in data and fit GNN model 

data = torch.load('../SREBP2_0.pt')
model_path = '../SimpleGAT_1_model_lr_0.0001_dp_0.7.pth'
model = SimpleGAT_1(data,dim = 16)
model.load_state_dict(torch.load(model_path,map_location=torch.device('cpu')))
model.eval()

SimpleGAT_1(
  (conv1): GATv2Conv(165, 16, heads=4)
  (conv2): GATv2Conv(64, 16, heads=4)
  (lin1): Linear(in_features=64, out_features=16, bias=True)
  (lin2): Linear(in_features=16, out_features=1, bias=True)
)

In [4]:
# Use data to get feature labels and remove all elements of data that are not a feature label

feat_labels = list(data.to_dict().keys())
to_remove = ['edge_index', 
 'experimental',
 'database',
 'textmining',
 'combined_score',
 'binary_experimental',
 'binary_database',
 'binary_textmining',
 'num_nodes',
 'x',
 'y',
 'edge_attr',
 'num_classes',
 'label']

for element in to_remove:
    feat_labels.remove(element)

In [5]:
# extract node labels and testing nodes from node order 

""" 
- '9606.ENSP00000269228' # NPC1 which is known to have involvement in the cholesterol homeostasis process
- '9606.ENSP00000289989' # Closest to SREBP2 and has label 1 (expected)
- '9606.ENSP00000415836' # Far from SREBP2 and has label 0 (expected)
- '9606.ENSP00000216180' # Closest to SREBP2 but has label 0 (unexpected)
- '9606.ENSP00000359398' # Far from SREBP2 but has label 1 (Unexpected)
- '9606.ENSP00000346046' # False Negative (lowest confidence)
- '9606.ENSP00000473036' # True Negative (lowest confidence)
- '9606.ENSP00000449270' # False Positive (highest confidence)
- '9606.ENSP00000270176' # True Positive (highest confidence)
"""

with open('../node_order.pickle', 'rb') as f:
    node_order = pickle.load(f)

node_labels = list(node_order.keys())

test_nodes = [node_order['9606.ENSP00000269228'], # NPC1 which is known to have involvement in the cholesterol homeostasis process
                node_order['9606.ENSP00000289989'], # Closest to SREBP2 and has label 1 (expected)
                node_order['9606.ENSP00000415836'], # Far from SREBP2 and has label 0 (expected)
                node_order['9606.ENSP00000216180'], # Closest to SREBP2 but has label 0 (unexpected)
                node_order['9606.ENSP00000359398'], # Far from SREBP2 but has label 1 (Unexpected)
                node_order['9606.ENSP00000346046'], # False Negative (lowest confidence)
                node_order['9606.ENSP00000473036'], # True Negative (lowest confidence)
                node_order['9606.ENSP00000449270'], # False Positive (highest confidence)
                node_order['9606.ENSP00000270176'] ]# True Positive (highest confidence) 

# Experiments

In [6]:
# define explainers
gnn_attr = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    threshold_config=dict(
        threshold_type="topk",
        value=10 
    ),
    model_config=dict(
        mode='binary_classification',
        task_level='node',
        return_type='probs',
    ),
)

gnn_common = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='common_attributes',
    edge_mask_type='object',
    threshold_config=dict(
        threshold_type="topk",
        value=10 
    ),
    model_config=dict(
        mode='binary_classification',
        task_level='node',
        return_type='probs',
    ),
)

pg = Explainer(
    model=model,
    algorithm=PGExplainer(epochs=30, lr=0.003),
    explanation_type='phenomenon',
    edge_mask_type='object',
    threshold_config=dict(
        threshold_type="topk",
        value=10 
    ),
    model_config=dict(
        mode='binary_classification',
        task_level='node',
        return_type='probs',
    ),
)

gm_attr = Explainer(
    model=model,
    algorithm=GraphMaskExplainer(2, epochs=5),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    threshold_config=dict(
        threshold_type="topk",
        value=10 
    ),
    model_config=dict(
        mode='binary_classification',
        task_level='node',
        return_type='probs',
    ), 
)

gm_comm = Explainer(
    model=model,
    algorithm=GraphMaskExplainer(2, epochs=5),
    explanation_type='model',
    node_mask_type='common_attributes',
    edge_mask_type='object',
    threshold_config=dict(
        threshold_type="topk",
        value=10 
    ),
    model_config=dict(
        mode='binary_classification',
        task_level='node',
        return_type='probs',
    ),
)

In [7]:
def pos_fidelity_score(fids):
    return 1 - sum(fids)/len(fids)

def neg_fidelity_score(fids):
    return 1 - sum(fids)/len(fids)

In [12]:
import time
def trial(explainer):
    fid_pos = []
    fid_neg = []
    times = []
    explanations = []

    for node in test_nodes:
        start = time.time()
        explanation = explainer(data.x, data.edge_index, edge_attr=data.edge_attr, index = node)
        end = time.time()
        fid = fidelity(explainer, explanation)
        fid_pos.append(fid[0])
        fid_neg.append(fid[1])
        times.append(end - start)
        explanations.append(explanation)
        # explanation.visualize_subgraph(path=f"trial_figs/{explainer}_{node}.pdf", backend='graphviz')

    return {'positive_fid' : fid_pos, 'negative_fid': fid_neg, 'time': times, 'explanations': explanations}

In [14]:
%%time 

gnn_exp_attr = trial(gnn_attr)

CPU times: user 5h 7min 29s, sys: 6h 6min 10s, total: 11h 13min 39s
Wall time: 2h 39min 39s


In [19]:
%%time

gnn_exp_common = trial(gnn_common)

CPU times: user 5h 4min 29s, sys: 6h 28min 33s, total: 11h 33min 2s
Wall time: 2h 37min 16s


In [21]:
def pg_trial(explainer):
    fid_pos = []
    fid_neg = []
    times = []
    explanations = []

    for node in test_nodes:
        start = time.time()
        for epoch in range(30):
            loss = explainer.algorithm.train(epoch, model, data.x, data.edge_index,
                                         target=data.y, index=node, edge_attr = data.edge_attr)

        explanation = explainer(data.x, data.edge_index, target=data.y, index=node, edge_attr = data.edge_attr)
        end = time.time()
        fid = fidelity(explainer, explanation)
        fid_pos.append(fid[0])
        fid_neg.append(fid[1])
        times.append(end - start)
        explanations.append(explanation)

    return {'positive_fid' : fid_pos, 'negative_fid': fid_neg, 'time': times, 'explanations': explanations}


In [22]:
%%time 

pg_exp = pg_trial(pg)

CPU times: user 1h 10min 48s, sys: 1h 41min 55s, total: 2h 52min 43s
Wall time: 1h 8min 1s


In [23]:
%%time 

gm_exp_attr = trial(gm_attr)

Python(44796) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Train explainer for node(s) tensor([6179]) with layer 1: 100%|██████████| 5/5 [01:26<00:00, 17.32s/it]
Train explainer for node(s) tensor([6179]) with layer 0: 100%|██████████| 5/5 [03:52<00:00, 46.56s/it]
Explain: 100%|██████████| 2/2 [00:10<00:00,  5.25s/it]
Train explainer for node(s) tensor([17397]) with layer 1: 100%|██████████| 5/5 [01:30<00:00, 18.06s/it]
Train explainer for node(s) tensor([17397]) with layer 0: 100%|██████████| 5/5 [03:43<00:00, 44.79s/it]
Explain: 100%|██████████| 2/2 [00:05<00:00,  2.53s/it]
Train explainer for node(s) tensor([15315]) with layer 1: 100%|██████████| 5/5 [01:47<00:00, 21.43s/it]
Train explainer for node(s) tensor([15315]) with layer 0: 100%|██████████| 5/5 [04:01<00:00, 48.25s/it]
Explain: 100%|██████████| 2/2 [00:11<00:00,  5.90s/it]
Train explainer for node(s) tensor([9021]) with layer 1: 100%|██████████| 5/5 [01:30<00:00, 18.06s/it]
Train explai

CPU times: user 31min 46s, sys: 1h 2min 19s, total: 1h 34min 5s
Wall time: 49min 18s


In [24]:
gm_exp_common = trial(gm_comm)

Train explainer for node(s) tensor([6179]) with layer 1: 100%|██████████| 5/5 [01:08<00:00, 13.61s/it]
Train explainer for node(s) tensor([6179]) with layer 0: 100%|██████████| 5/5 [04:09<00:00, 49.97s/it]
Explain: 100%|██████████| 2/2 [00:09<00:00,  4.75s/it]
Train explainer for node(s) tensor([17397]) with layer 1: 100%|██████████| 5/5 [01:16<00:00, 15.30s/it]
Train explainer for node(s) tensor([17397]) with layer 0: 100%|██████████| 5/5 [04:07<00:00, 49.44s/it]
Explain: 100%|██████████| 2/2 [00:07<00:00,  3.66s/it]
Train explainer for node(s) tensor([15315]) with layer 1: 100%|██████████| 5/5 [01:07<00:00, 13.51s/it]
Train explainer for node(s) tensor([15315]) with layer 0: 100%|██████████| 5/5 [03:10<00:00, 38.13s/it]
Explain: 100%|██████████| 2/2 [00:05<00:00,  2.61s/it]
Train explainer for node(s) tensor([9021]) with layer 1: 100%|██████████| 5/5 [01:11<00:00, 14.24s/it]
Train explainer for node(s) tensor([9021]) with layer 0: 100%|██████████| 5/5 [03:16<00:00, 39.35s/it]
Explain

In [32]:
trial_exps = [gnn_exp_attr, gnn_exp_common, pg_exp, gm_exp_common, gm_exp_attr]
trial_names = ["GNNExp with Attr", "GNNExp with Common Attr", "PGExplainer", "GraphMask with Attr", "GraphMask with Common Attr"]

In [76]:
import pandas as pd
import numpy as np
trials = []
for i in range(len(trial_exps)):
    trial = trial_exps[i]
    trial_name = trial_names[i]
    pos_fid = pos_fidelity_score(trial['positive_fid'])
    neg_fid = neg_fidelity_score(trial['negative_fid'])
    avg_time = np.mean(trial['time']) / 60
    std_time = np.std(trial['time']) / 60
    d = {'explainer_type': trial_name, 'positive_fidelity': pos_fid, 'negative_fidelity': neg_fid, 'average_time_mins': avg_time, 'std_time': std_time}
    trials.append(d)


trials_df = pd.DataFrame(trials)
trials_df

Unnamed: 0,explainer_type,positive_fidelity,negative_fidelity,average_time_mins,std_time
0,GNNExp with Attr,0.666667,0.444444,17.702863,4.288051
1,GNNExp with Common Attr,0.333333,0.444444,17.436035,0.291227
2,PGExplainer,1.0,0.333333,7.509203,0.262723
3,GraphMask with Attr,0.333333,0.333333,4.812597,0.407326
4,GraphMask with Common Attr,0.333333,0.333333,5.376516,0.369671


In [59]:
def get_k_features(explanation, k, feat_labels):
    node_mask = explanation.get('node_mask')
    if node_mask is None:
        raise ValueError("The attribute 'node_mask' is not available ")
    if node_mask.dim() != 2 or node_mask.size(1) <= 1:
        raise ValueError(f"Cannot compute feature importance for "
                            f"object-level 'node_mask' "
                            f"(got shape {node_mask.size()})")
    score = node_mask.sum(dim=0)
    score = score.cpu().numpy()
    df = pd.DataFrame({'score': score}, index=feat_labels)
    df = df.sort_values('score', ascending=False)
    df = df.round(decimals=3)
    return df.head(k)
    

In [75]:
top_features = pd.DataFrame()


for trial in trial_exps:
    for exps in trial['explanations']:
        try:
            top_features = pd.concat([top_features, get_k_features(exps, 10, feat_labels)])
        except:
            continue

top_features['feature_name'] = top_features.index
top_features = top_features.reset_index()
top_features = top_features.groupby('feature_name').agg('count').reset_index()
top_features = top_features.sort_values('score', ascending=False)
top_features.head(10)


Unnamed: 0,feature_name,index,score
19,RNA line ab,13,13
36,cytoplasm,10,10
127,train_mask,10,10
49,go_feature_12,10,10
41,endomembrane system,8,8
117,plasma membrane,8,8
97,golgi apparatus,8,8
98,has_path,7,7
129,vesicles,7,7
42,endoplasmic reticulum,6,6
