# Graph Classification Baselines

In [15]:
import os
import pickle
import sys

import numpy as np
import pandas as pd
import torch

sys.path.append("../gnnexp")
from models import GNN_Custom_Graph

In [16]:
DATASET = "NCI1" # OPTIONS: Mutagenicity, NCI1, IsCyclic
EVAL = "eval"
#todo: MUTAG dataset is different from other baselines.

if DATASET == 'NCI1':
    EXPLANATION_FOLDER = "nci1_dc_top20"

## Data

In [17]:
explanations = dict()
PATH = f"../explanation/{EXPLANATION_FOLDER}"
for filename in os.listdir(PATH):
    if 'label' not in filename:
        continue
    graph_idx = ''.join(filter(lambda i: i.isdigit(), filename))
    explanations[int(graph_idx)] = pd.read_csv(f"{PATH}/{filename}", header=None).to_numpy()

In [18]:
ckpt = torch.load(f"../data/{DATASET}/eval_as_{EVAL}.pt")
cg_dict = ckpt["cg"] # get computation graph
input_dim = cg_dict["feat"].shape[2]

## Model

In [19]:
model = GNN_Custom_Graph(in_features=input_dim, h_features=128)
model.load_state_dict(ckpt["model_state"])
model.eval()

GNN_Custom_Graph(
  (conv1): GraphConvolution (37 -> 128)
  (conv2): GraphConvolution (128 -> 128)
  (conv3): GraphConvolution (128 -> 128)
  (dense1): Linear(in_features=128, out_features=16, bias=True)
  (dense2): Linear(in_features=16, out_features=8, bias=True)
  (dense3): Linear(in_features=8, out_features=1, bias=True)
)

## Predictions

In [20]:
predictions = list()
labels = list()
for graph_idx in cg_dict['test_idx']:
    feat = cg_dict["feat"][graph_idx, :].float().unsqueeze(0)
    adj = cg_dict["adj"][graph_idx].float().unsqueeze(0) # - explanations[graph_idx]
    label = cg_dict["label"][graph_idx].float().unsqueeze(0)
    proba = model(feat, adj)
    predictions.append(proba.round())
    labels.append(label)
predictions = torch.Tensor(predictions)
labels = torch.Tensor(labels)    

## Fidelity

In [38]:
misclassifications = (predictions != labels).sum()
fidelity = 1 - misclassifications/len(predictions)
print("\n===============")
print(f"Fidelity: {fidelity:.2f}")


Fidelity: 0.49


## Explanation size

In [None]:
exp_size = list()
for graph in explanations.values():
    exp_size.append((graph ))

## Rough

In [65]:
cg_dict.keys()

dict_keys(['adj', 'feat', 'label', 'gid', 'pred', 'train_idx', 'val_idx', 'test_idx'])

In [77]:
(
    cg_dict['label'][torch.Tensor(cg_dict['val_idx']).squeeze().long()]
    == torch.Tensor(cg_dict['pred']).squeeze()[
        torch.Tensor(cg_dict['val_idx']).squeeze().long()
    ].round().squeeze()
).sum() / cg_dict['label'].size(0)

tensor(0.0665)