# Graph Classification Baselines

In [2]:
import os
import pickle
import sys

import numpy as np
import pandas as pd
import torch

sys.path.append("../gnnexp")
from models import GNN_Custom_Mutag, GNN_Custom_NCI1, GNN_Custom_IsCyclic

In [3]:
# Notebook configuration: which dataset to evaluate and which checkpoint split to load.
DATASET = "NCI1"  # OPTIONS: Mutagenicity, NCI1, IsCyclic
EVAL = "train"  # OPTIONS: eval, train
# TODO: MUTAG dataset is different from other baselines.

# Explanation folder name for each dataset that has one configured.
EXPLANATION_FOLDERS = {
    "NCI1": "nci1_dc_top20",
}

# Fail here with a clear message instead of a NameError in a later cell when
# DATASET is switched to an option that has no explanation folder configured.
if DATASET not in EXPLANATION_FOLDERS:
    raise ValueError(
        f"No explanation folder configured for DATASET={DATASET!r}; "
        f"add an entry to EXPLANATION_FOLDERS (known: {sorted(EXPLANATION_FOLDERS)})."
    )
EXPLANATION_FOLDER = EXPLANATION_FOLDERS[DATASET]

## Data

In [3]:
# Load every explanation CSV whose file name contains "label", keyed by the
# integer formed from the digits in that file name (the graph index).
PATH = f"../explanation/{EXPLANATION_FOLDER}"
explanations = {}
for filename in os.listdir(PATH):
    if 'label' in filename:
        digits = ''.join(ch for ch in filename if ch.isdigit())
        # Header-less CSV -> 2-D NumPy array of explanation weights.
        explanations[int(digits)] = pd.read_csv(f"{PATH}/{filename}", header=None).to_numpy()

In [4]:
# Number of graphs for which a "label" explanation file was loaded.
len(explanations)

536

In [5]:
# Load the saved checkpoint for the selected dataset / split.
# map_location="cpu" lets a checkpoint that was saved from a GPU be opened on a
# CPU-only machine; the rest of this notebook never moves anything to a GPU.
# NOTE: torch.load unpickles the file, so only open checkpoints you trust.
ckpt = torch.load(f"../data/{DATASET}/eval_as_{EVAL}.pt", map_location="cpu")
cg_dict = ckpt["cg"]  # get computation graph
# Size of the third axis of the feature tensor.
# presumably the layout is (graph, node, feature) — TODO confirm
input_dim = cg_dict["feat"].shape[2]

## Model

In [6]:
# Build the classifier with the checkpoint's input width and restore its trained weights.
# NOTE(review): GNN_Custom_Graph is not among the names imported from `models` in the
# import cell (GNN_Custom_Mutag, GNN_Custom_NCI1, GNN_Custom_IsCyclic). Unless it is
# defined somewhere not shown, this cell raises NameError on a fresh kernel — confirm
# which class matches this checkpoint.
model = GNN_Custom_Graph(in_features=input_dim, h_features=128)
model.load_state_dict(ckpt["model_state"])
model.eval()  # switch the module to evaluation mode before inference

GNN_Custom_Graph(
  (conv1): GraphConvolution (37 -> 128)
  (conv2): GraphConvolution (128 -> 128)
  (conv3): GraphConvolution (128 -> 128)
  (dense1): Linear(in_features=128, out_features=16, bias=True)
  (dense2): Linear(in_features=16, out_features=8, bias=True)
  (dense3): Linear(in_features=8, out_features=1, bias=True)
)

## Predictions

In [7]:
# Run the model on every test graph and collect the rounded prediction and the
# true label for each one.
predictions = list()
labels = list()
# Inference only: disable autograd so no graph is recorded per forward pass.
with torch.no_grad():
    for graph_idx in cg_dict['test_idx']:
        # unsqueeze(0) adds a leading batch dimension of size 1.
        feat = cg_dict["feat"][graph_idx, :].float().unsqueeze(0)
        adj = cg_dict["adj"][graph_idx].float().unsqueeze(0)
        label = cg_dict["label"][graph_idx].float().unsqueeze(0)
        proba = model(feat, adj)
        # Store plain Python floats so the final tensors are built from
        # scalars rather than from a list of one-element tensors.
        predictions.append(proba.round().item())
        labels.append(label.item())
predictions = torch.tensor(predictions)
labels = torch.tensor(labels)

In [8]:
# Accuracy in percent: share of test graphs whose rounded prediction equals the label.
num_correct = (predictions == labels).sum()
100 * num_correct / len(labels)

tensor(100.)

## Fidelity & Explanation size

In [21]:
# For each explained graph, pick the top_k highest-weighted entries of the upper
# triangle (diagonal included) of its explanation matrix as (row, col) pairs.
top_indices = dict()  # They are from the Upper Triangular part only.
top_k = 20
for graph_id, graph in explanations.items():
    upper = np.triu(graph)
    # Positions in the flattened matrix (not 2-D positions), ordered by
    # descending weight.
    flat_top_indices = (-upper.flatten()).argsort()[:top_k]
    # Map the flat positions back to 2-D using the matrix's actual shape, so
    # the decoding is correct even if the matrix is not square.
    index_rows, index_cols = np.unravel_index(flat_top_indices, upper.shape)
    # NOTE(review): when fewer than top_k upper-triangular weights are
    # non-zero, the remaining slots are zero-weight positions — confirm
    # whether those should count as explanation edges downstream.
    top_indices[graph_id] = list(zip(index_rows, index_cols))

In [22]:
# Counterfactual search: for every correctly predicted label-1 graph, delete its
# top-ranked explanation edges one at a time until the model's prediction flips.
explanation_size_list = list()  # edges removed before the flip, one entry per flipped graph
cf_found_count = 0  # number of graphs whose prediction flipped

# Inference only: no autograd bookkeeping is needed for the forward passes below.
with torch.no_grad():
    for graph_id in explanations:
        feat = cg_dict["feat"][graph_id, :].float().unsqueeze(0)
        adj = cg_dict["adj"][graph_id].float().unsqueeze(0)
        label = cg_dict["label"][graph_id].long().unsqueeze(0)

        # get original prediction
        original_prediction = model(feat, adj).round()

        # work on correctly predicted label 1 graphs only.
        if not (label == 1 and original_prediction == 1):
            continue

        # Edit a copy so the adjacency held in cg_dict is left untouched.
        new_adj = adj.clone()

        # size_count is the number of edges removed so far (1-based).
        for size_count, (row, col) in enumerate(top_indices[graph_id], start=1):
            # remove the edge in both the upper and the lower triangular part
            new_adj[0][row, col] = 0.0
            new_adj[0][col, row] = 0.0

            # make the prediction on the pruned graph
            new_prediction = model(feat, new_adj).round()

            # if the label flipped: record the size and stop
            if original_prediction != new_prediction:
                cf_found_count += 1
                explanation_size_list.append(size_count)
                break

In [23]:
# Summarise the counterfactual search: fidelity and explanation-size statistics.
# NOTE(review): the denominator is every loaded explanation, including graphs
# the search cell skips (label != 1 or mispredicted) — confirm this is the
# intended population.
num_explanations = len(explanations)
fidelity = 1 - cf_found_count / num_explanations

sizes = np.asarray(explanation_size_list)
exp_size_mean = sizes.mean()
exp_size_std = sizes.std()

print(f"Fidelity: {fidelity:.2f}")
print(f"Explanation size: mean={exp_size_mean:.2f}, std={exp_size_std:.2f}")

Fidelity: 0.20
Explanation size: mean=2.53, std=2.44


## Rough

In [6]:
def _load_matrix(csv_path):
    """Read a header-less CSV file and return its contents as a NumPy array."""
    frame = pd.read_csv(csv_path, header=None)
    return frame.to_numpy()


# Scratch inspection of the two files stored for graph 108 of the IsCyclic run.
label_108 = _load_matrix("../explanation/iscyclic_top20/graph_idx_108_label.csv")
pred_108 = _load_matrix("../explanation/iscyclic_top20/graph_idx_108_pred.csv")

In [8]:
# Sum of all entries of the "label" matrix loaded for graph 108.
label_108.sum()

31.653016686439514

In [11]:
# First row of the "label" matrix loaded for graph 108.
label_108[0]

array([0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.75364321, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.     

In [9]:
# Sum of all entries of the "pred" matrix loaded for graph 108.
pred_108.sum()

103.88614672608674

In [10]:
# First row of the "pred" matrix loaded for graph 108.
pred_108[0]

array([0.024986  , 0.024986  , 0.024986  , 0.024986  , 0.024986  ,
       0.024986  , 0.024986  , 0.024986  , 0.024986  , 0.024986  ,
       0.024986  , 0.024986  , 0.024986  , 0.024986  , 0.024986  ,
       0.024986  , 0.024986  , 0.024986  , 0.024986  , 0.024986  ,
       0.024986  , 0.024986  , 0.024986  , 0.024986  , 0.00050399,
       0.00050399, 0.00050399, 0.00050399, 0.00050399, 0.00050399,
       0.00050399, 0.00050399, 0.00050399, 0.00050399, 0.00050399,
       0.00050399, 0.00050399, 0.00050399, 0.00050399, 0.00050399,
       0.00050399, 0.00050399, 0.00050399, 0.00050399, 0.00050399,
       0.00050399, 0.00050399, 0.00050399, 0.00050399, 0.00050399,
       0.00050399, 0.00050399, 0.00050399, 0.00050399, 0.00050399,
       0.00050399, 0.00050399, 0.00050399, 0.00050399, 0.00050399,
       0.00050399, 0.00050399, 0.00050399, 0.00050399, 0.00050399,
       0.00050399, 0.00050399, 0.00050399, 0.00050399, 0.00050399,
       0.00050399, 0.00050399, 0.00050399, 0.00050399, 0.00050