# Initial Experiment

In this experiment, I run all three explanation methods (GNNExplainer, PGExplainer and GraphMaskExplainer) over a set of 9 instances, which include:

- '9606.ENSP00000269228' # NPC1 which is known to have involvement in the cholesterol homeostasis process
- '9606.ENSP00000289989' # Closest to SREBP2 and has label 1 (expected)
- '9606.ENSP00000415836' # Far from SREBP2 and has label 0 (expected)
- '9606.ENSP00000216180' # Closest to SREBP2 but has label 0 (unexpected)
- '9606.ENSP00000359398' # Far from SREBP2 but has label 1 (unexpected)
- '9606.ENSP00000346046' # False Negative (lowest confidence)
- '9606.ENSP00000473036' # True Negative (lowest confidence)
- '9606.ENSP00000449270' # False Positive (highest confidence)
- '9606.ENSP00000270176' # True Positive (highest confidence)

For each method, I measure the following metrics:
 
- Runtime
- Fidelity 
- Stability


In [46]:
# install necessary packages

%pip install torch_geometric # install pytorch geometric
%pip install torchvision #install torchvision
%pip install matplotlib #install matplotlib
%pip install graphviz # install graphviz

%matplotlib inline

from torch_geometric.data import Data, DataLoader
from torch_geometric.explain import GNNExplainer,Explainer,GraphMaskExplainer,PGExplainer
import torch_geometric.transforms as T
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATv2Conv
from torch.nn import Linear,Softmax
import os
from tqdm import tqdm, trange
import pickle
from torch_geometric.explain.metric import fidelity, characterization_score

import matplotlib.pyplot as plt

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [9]:
# define GNN 

class SimpleGAT_1(torch.nn.Module):
    """
    A graph attention network with 2 GATv2 graph layers and 2 linear layers.

    Uses v2 of graph attention (GATv2Conv), which provides dynamic instead of
    static attention. Each graph layer concatenates ``num_heads`` heads of
    width ``dim``, and both graph layers also consume the edge features.

    Args:
        dataset: graph object providing ``num_features`` and ``edge_attr``;
            it is only used to size the input and edge-feature dimensions.
        dim: per-head output width of each graph layer and width of ``lin1``.
        num_heads: number of attention heads per graph layer.
        dropout: dropout probability applied after ``lin1`` while training.
    """

    def __init__(self, dataset, dim=8, num_heads=4, dropout=0.1):
        super().__init__()
        # Seeds torch's *global* RNG so weight initialisation is reproducible.
        # NOTE(review): this also resets the RNG for everything that runs after
        # the model is constructed (e.g. random explainer-mask initialisation).
        torch.manual_seed(seed=123)
        edge_dim = dataset.edge_attr.shape[1]
        self.dropout = dropout
        self.conv1 = GATv2Conv(in_channels=dataset.num_features, out_channels=dim,
                               heads=num_heads, edge_dim=edge_dim)
        # Heads are concatenated, so the next layer sees dim * num_heads inputs.
        self.conv2 = GATv2Conv(in_channels=dim * num_heads, out_channels=dim,
                               heads=num_heads, edge_dim=edge_dim)
        self.lin1 = Linear(dim * num_heads, dim)
        self.lin2 = Linear(dim, 1)

    def forward(self, x, edge_index, edge_attr):
        """Return one sigmoid probability per node.

        Args:
            x: node feature matrix with ``dataset.num_features`` columns.
            edge_index: graph connectivity, passed to both GATv2 layers.
            edge_attr: edge features with ``dataset.edge_attr.shape[1]`` columns.

        Returns:
            1-D tensor of probabilities, one per row of ``x``.
        """
        h = self.conv1(x, edge_index, edge_attr).relu()
        h = self.conv2(h, edge_index, edge_attr).relu()
        h = self.lin1(h).relu()
        h = F.dropout(h, p=self.dropout, training=self.training)
        # lin2 has a single output unit; drop that axis so the output is 1-D.
        return torch.sigmoid(self.lin2(h)[:, 0])

In [8]:
# Load the graph and the pre-trained GNN weights (no training happens here).
# Both are mapped to the CPU so the data and the model always end up on the
# same device, whatever device they were saved from.
# NOTE(review): on torch >= 2.6 torch.load defaults to weights_only=True, which
# cannot unpickle a PyG Data object; pass weights_only=False for this trusted file.

data = torch.load('../SREBP2_0.pt', map_location=torch.device('cpu'))
model_path = '../SimpleGAT_1_model_lr_0.0001_dp_0.7.pth'
# dim must match the saved checkpoint (16 units per head, 4 heads).
model = SimpleGAT_1(data, dim=16)
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()

SimpleGAT_1(
  (conv1): GATv2Conv(165, 16, heads=4)
  (conv2): GATv2Conv(64, 16, heads=4)
  (lin1): Linear(in_features=64, out_features=16, bias=True)
  (lin2): Linear(in_features=16, out_features=1, bias=True)
)

In [15]:
# Use data to get feature labels: keep every key stored on `data` except the
# graph structure, edge-level attributes, targets and bookkeeping entries below.
# NOTE(review): assumes the remaining keys are exactly the node-feature names,
# in the same order as the columns of data.x — confirm.

to_remove = ['edge_index',
 'experimental',
 'database',
 'textmining',
 'combined_score',
 'binary_experimental',
 'binary_database',
 'binary_textmining',
 'num_nodes',
 'x',
 'y',
 'edge_attr',
 'num_classes',
 'label']

data_keys = list(data.to_dict().keys())
# Fail loudly, listing every absent key at once, if the data schema has changed.
missing_keys = [key for key in to_remove if key not in data_keys]
if missing_keys:
    raise ValueError(f"expected keys not found on data: {missing_keys}")
excluded_keys = set(to_remove)
feat_labels = [key for key in data_keys if key not in excluded_keys]

In [17]:
# extract node labels and testing nodes from node order

# Only unpickle this trusted, locally generated file: pickle.load can execute
# arbitrary code embedded in the data.
with open('../node_order.pickle', 'rb') as f:
    node_order = pickle.load(f)

node_labels = list(node_order.keys())

# Proteins to explain, mapped to the reason each one was picked. The dict order
# is the order of `test_nodes`, so the list and its descriptions cannot drift.
test_proteins = {
    '9606.ENSP00000269228': 'NPC1, known to be involved in cholesterol homeostasis',
    '9606.ENSP00000289989': 'closest to SREBP2 and has label 1 (expected)',
    '9606.ENSP00000415836': 'far from SREBP2 and has label 0 (expected)',
    '9606.ENSP00000216180': 'closest to SREBP2 but has label 0 (unexpected)',
    '9606.ENSP00000359398': 'far from SREBP2 but has label 1 (unexpected)',
    '9606.ENSP00000346046': 'false negative (lowest confidence)',
    '9606.ENSP00000473036': 'true negative (lowest confidence)',
    '9606.ENSP00000449270': 'false positive (highest confidence)',
    '9606.ENSP00000270176': 'true positive (highest confidence)',
}

# node_order turns each protein ID into the node index handed to the explainers.
test_nodes = [node_order[protein_id] for protein_id in test_proteins]

# Experiments

In [58]:
# define explainers
# Every explainer below shares the same edge-mask type, top-k thresholding and
# model description, so they are all built through one factory to stay in sync.

# Number of mask entries each explanation keeps (threshold_type='topk').
TOP_K = 10


def make_explainer(algorithm, explanation_type='model', node_mask_type=None):
    """Wrap ``algorithm`` in an Explainer for ``model`` with the shared settings.

    Args:
        algorithm: a PyG explainer algorithm instance (GNNExplainer, ...).
        explanation_type: 'model' or 'phenomenon'.
        node_mask_type: None, 'attributes' or 'common_attributes'.

    Returns:
        A ``torch_geometric.explain.Explainer`` that learns an 'object' edge
        mask, keeps the TOP_K highest mask values, and treats ``model`` as a
        node-level binary classifier returning probabilities (matching the
        sigmoid output of SimpleGAT_1.forward).
    """
    return Explainer(
        model=model,
        algorithm=algorithm,
        explanation_type=explanation_type,
        node_mask_type=node_mask_type,
        edge_mask_type='object',
        threshold_config=dict(threshold_type='topk', value=TOP_K),
        model_config=dict(
            mode='binary_classification',
            task_level='node',
            return_type='probs',
        ),
    )


# GNNExplainer: optimises the masks for 200 epochs per explanation.
gnn_attr = make_explainer(GNNExplainer(epochs=200), node_mask_type='attributes')
gnn_common = make_explainer(GNNExplainer(epochs=200), node_mask_type='common_attributes')

# PGExplainer is configured without a node mask and explains the 'phenomenon'
# (a supplied target). NOTE(review): per the PyG docs it must first be trained
# with `pg.algorithm.train(...)` for its 30 epochs before it can explain anything.
pg = make_explainer(PGExplainer(epochs=30, lr=0.003), explanation_type='phenomenon')

# GraphMaskExplainer's first argument is the number of message-passing layers;
# SimpleGAT_1 has two (conv1, conv2).
gm_attr = make_explainer(GraphMaskExplainer(2, epochs=5), node_mask_type='attributes')
gm_comm = make_explainer(GraphMaskExplainer(2, epochs=5), node_mask_type='common_attributes')

In [47]:
def pos_fidelity_score(fids):
    """Average PyG fidelity+ over a set of explained nodes.

    ``torch_geometric.explain.metric.fidelity`` already returns
    fid+ = 1 - (fraction of predictions unchanged when the explanation is
    removed), for which higher is better. Applying another ``1 -`` here would
    invert the metric, so the plain mean is returned. Like ``neg_fidelity_score``,
    a higher value therefore means a better explanation.

    Args:
        fids: iterable of per-node fidelity+ values.

    Returns:
        The mean of ``fids`` as a float.

    Raises:
        ValueError: if ``fids`` is empty.
    """
    values = list(fids)  # also accepts generators / other one-shot iterables
    if not values:
        raise ValueError("pos_fidelity_score needs at least one fidelity value")
    return sum(values) / len(values)

def neg_fidelity_score(fids):
    """Turn a collection of per-node fidelity- values into a higher-is-better score.

    A low fidelity- is desirable, so one minus its mean is reported instead.

    Args:
        fids: sized collection of per-node fidelity- values (must be non-empty;
            an empty one raises ZeroDivisionError).

    Returns:
        ``1 - mean(fids)``.
    """
    average = sum(fids) / len(fids)
    return 1 - average

In [71]:
import time
def trial(explainer, name=None, target=None, fig_dir="trial_figs"):
    """Explain every node in ``test_nodes`` with ``explainer`` and record metrics.

    For each node, this times the explanation call and computes PyG's
    (fidelity+, fidelity-) pair. It also saves a graphviz PDF of the
    explanation's thresholded edge mask.

    Args:
        explainer: a configured ``torch_geometric.explain.Explainer``.
        name: label used in figure filenames. Defaults to the algorithm class
            name plus the node-mask type, e.g. ``"GNNExplainer_attributes"``.
        target: labels passed to the explainer. It is required by PyG for
            ``explanation_type='phenomenon'``; if omitted for such an explainer,
            ``data.y`` is used.
        fig_dir: directory for the per-node PDF figures; created if missing.

    Returns:
        dict with per-node lists under ``'positive_fid'``, ``'negative_fid'``,
        ``'time'`` (seconds spent in the explanation call only) and
        ``'explanations'``.
    """
    if name is None:
        # str(explainer) is not a guaranteed readable, filename-safe or unique
        # label (it cannot tell gnn_attr from gnn_common), so build one.
        mask_type = getattr(explainer.node_mask_type, 'value', 'none')
        name = f"{type(explainer.algorithm).__name__}_{mask_type}"

    explanation_type = getattr(explainer.explanation_type, 'value', explainer.explanation_type)
    if target is None and explanation_type == 'phenomenon':
        # NOTE(review): assumes data.y holds the ground-truth node labels — confirm.
        target = data.y

    os.makedirs(fig_dir, exist_ok=True)

    fid_pos = []
    fid_neg = []
    times = []
    explanations = []

    for node in test_nodes:
        # Only the explanation call is timed; fidelity and plotting are excluded.
        # perf_counter is monotonic and high-resolution, unlike time.time.
        start = time.perf_counter()
        explanation = explainer(data.x, data.edge_index, target=target,
                                edge_attr=data.edge_attr, index=node)
        times.append(time.perf_counter() - start)

        pos_fid, neg_fid = fidelity(explainer, explanation)
        fid_pos.append(pos_fid)
        fid_neg.append(neg_fid)
        explanations.append(explanation)

        # `Explanation` has no `visualize_subgraph` (that belonged to the legacy
        # GNNExplainer API); `visualize_graph` draws the edge mask instead.
        explanation.visualize_graph(path=os.path.join(fig_dir, f"{name}_{node}.pdf"),
                                    backend='graphviz')

    return {'positive_fid': fid_pos, 'negative_fid': fid_neg, 'time': times, 'explanations': explanations}

In [72]:
# Run every explainer over the test nodes; results are keyed by Explainer object.
trial_data = {}
exps = [gnn_attr, gnn_common, pg, gm_attr, gm_comm]

# PGExplainer is parametric: its edge-mask network must be trained for
# `pg.algorithm.epochs` epochs before PyG lets it produce explanations
# (otherwise it raises "not yet fully trained"). The training time is recorded
# separately, so the per-node runtimes from `trial` stay comparable across methods.
# NOTE(review): this trains on the test nodes themselves and uses data.y as the
# target (assumes data.y holds the node labels) — confirm both choices.
pg_train_start = time.perf_counter()
for epoch in range(pg.algorithm.epochs):
    for node in test_nodes:
        pg.algorithm.train(epoch, model, data.x, data.edge_index,
                           target=data.y, index=node, edge_attr=data.edge_attr)
pg_train_time = time.perf_counter() - pg_train_start

for exp in exps:
    trial_data[exp] = trial(exp)

KeyboardInterrupt: 