# Tree Grids Dataset

In [1]:
import os
import pickle
import sys

import numpy as np
import torch

sys.path.append("../gnnexp")
from models import GCNSynthetic

## Data

In [2]:
# Location of the pickled syn5 dataset, relative to this notebook.
filename = "../../cfgnnexplainer/data/gnn_explainer/syn5.pickle"

# NOTE(review): pickle.load can execute arbitrary code from the file it reads;
# only point this at a trusted file.
with open(filename, mode="rb") as file:
    treegrids_cfgnn = pickle.load(file)

In [3]:
# Show the Python type stored under each key of the loaded dataset.
for key in treegrids_cfgnn.keys():
    val = treegrids_cfgnn[key]
    print(f"{key:<10}: {type(val)}")

adj       : <class 'numpy.ndarray'>
feat      : <class 'numpy.ndarray'>
labels    : <class 'numpy.ndarray'>
train_idx : <class 'list'>
test_idx  : <class 'list'>


## Model

In [4]:
# Keyword arguments for GCNSynthetic. The pretrained weights are loaded into
# this model later in the notebook, so these sizes have to agree with them.
gcn_hparams = dict(nfeat=10, nhid=20, nout=20, nclass=2, dropout=0.0)
model = GCNSynthetic(**gcn_hparams)

In [5]:
# Load the pretrained syn5 weights. `map_location="cpu"` lets the file load on
# a CPU-only machine even if it was saved from a GPU; the rest of this notebook
# works on CPU tensors (the model output is converted straight to NumPy).
state_dict = torch.load(
    "../cfgnn_model_weights/gcn_3layer_syn5.pt",
    map_location="cpu",
)
model.load_state_dict(state_dict)
# Put the model in evaluation mode; as the cell's last expression this also
# displays the module summary.
model.eval()

GCNSynthetic(
  (gc1): GraphConvolution (10 -> 20)
  (gc2): GraphConvolution (20 -> 20)
  (gc3): GraphConvolution (20 -> 20)
  (lin): Linear(in_features=60, out_features=2, bias=True)
)

## Predictions

In [6]:
# Convert the dataset's NumPy feature and adjacency arrays to float32 tensors.
feat = torch.from_numpy(treegrids_cfgnn['feat']).float()
adj = torch.from_numpy(treegrids_cfgnn['adj']).float()

# Pure inference: disable autograd so no computation graph is built for the
# forward pass. The result is stored as a NumPy array of raw model outputs.
with torch.no_grad():
    preds = model(feat, adj).numpy()

## Checkpoint

In [7]:
# Existing syn1 checkpoint, used below as the template for the syn5 checkpoint.
bashapes_ckpt_path = "../data/syn1/eval_as_eval.pt"
ckpt_bashapes = torch.load(bashapes_ckpt_path)

# Its 'cg' entry is the dict whose key layout the syn5 data must match.
bashapes = ckpt_bashapes["cg"]

In [8]:
# Build the 'cg' dict for tree-grids from a shallow copy of the loaded dataset,
# so the original `treegrids_cfgnn` dict keeps its own keys untouched.
cg_treegrids = treegrids_cfgnn.copy()
cg_treegrids['pred'] = preds  # Not present in treegrids_cfgnn; filled from the model outputs.
cg_treegrids['label'] = cg_treegrids.pop('labels')  # Rename 'labels' -> 'label' to match the expected key.

# Start from a shallow copy of the syn1 checkpoint and swap in the syn5
# weights and data; every other entry is carried over unchanged.
ckpt_treegrids = ckpt_bashapes.copy()
ckpt_treegrids.update(model_state=state_dict, cg=cg_treegrids)

## Eval set

In [9]:
# Key under which this dataset's entry is stored in the eval-set mapping.
KEY = "syn5/tree-grid"

# NOTE(review): pickle.load can execute arbitrary code; only use trusted files.
eval_set_path = "../../eval_set.pkl"
with open(eval_set_path, "rb") as file:
    eval_set = pickle.load(file)

### Our eval set as part of the training set

In [10]:
# Display the keys of the loaded dataset dict as the cell output.
treegrids_cfgnn.keys()

dict_keys(['adj', 'feat', 'labels', 'train_idx', 'test_idx'])

In [11]:
# "Eval as train" split: every index along axis 1 of the labels array
# (presumably one per node -- confirm) is used for training.
# Build a flat list of ints, the same format the "eval set as the validation
# set" split produces; `[range(n)]` would instead be a one-element list
# wrapping a range object.
num_nodes = treegrids_cfgnn['labels'].shape[1]
train_set_indices = list(range(num_nodes))

# The held-out indices come from the shared eval-set file.
test_set_indices = eval_set[KEY]

In [12]:
# Overwrite the split stored in the checkpoint's 'cg' dict with the one
# computed in the previous cell.
ckpt_treegrids["cg"].update(
    train_idx=train_set_indices,
    test_idx=test_set_indices,
)

In [13]:
# Write the "eval as train" checkpoint. The output directory is derived from
# the file path so the two cannot drift apart, and is created if missing.
eval_as_train_path = "../data/syn5/eval_as_train.pt"
os.makedirs(os.path.dirname(eval_as_train_path), exist_ok=True)
torch.save(ckpt_treegrids, eval_as_train_path)

### Our eval set as the validation set

In [14]:
# "Eval as eval" split: train on every index that is not in the eval set and
# keep the eval set itself as the test indices.
held_out = eval_set[KEY]
num_nodes = treegrids_cfgnn['labels'].shape[1]
train_set_indices = [node for node in range(num_nodes) if node not in held_out]
test_set_indices = held_out

In [15]:
# Replace the split stored in the checkpoint's 'cg' dict with the held-out
# variant computed in the previous cell.
ckpt_treegrids["cg"].update(
    train_idx=train_set_indices,
    test_idx=test_set_indices,
)

In [16]:
# Write the "eval as eval" checkpoint. The output directory is derived from
# the file path so the two cannot drift apart, and is created if missing.
eval_as_eval_path = "../data/syn5/eval_as_eval.pt"
os.makedirs(os.path.dirname(eval_as_eval_path), exist_ok=True)
torch.save(ckpt_treegrids, eval_as_eval_path)

## Rough sanity checks

In [17]:
# Sanity check: the new checkpoint's 'cg' dict should expose exactly the same
# keys as the template checkpoint's. Comparing key sets (instead of zipping the
# two sorted key lists) also reports keys that exist on only one side, which
# zip() would silently drop when the lists differ in length.
keys_treegrids = set(ckpt_treegrids['cg'])
keys_bashapes = set(ckpt_bashapes['cg'])
if keys_treegrids != keys_bashapes:
    print("MISMATCH!")
    print("only in ckpt_treegrids['cg']:", sorted(keys_treegrids - keys_bashapes))
    print("only in ckpt_bashapes['cg']:", sorted(keys_bashapes - keys_treegrids))

In [18]:
# Ground-truth labels as stored in the checkpoint.
labels = ckpt_treegrids['cg']['label']

# Predicted class = index of the largest model output along the last axis.
# NOTE: this rebinds `preds`, which earlier held the raw model outputs that the
# checkpoint-building cell stores under 'pred'; re-running that cell after this
# one would store class indices instead.
preds = np.argmax(ckpt_treegrids['cg']['pred'], axis=-1)

In [19]:
# Accuracy = number of matching (label, prediction) entries divided by the
# size of axis 1 of `labels`; it covers every entry, not only a test split.
num_correct = np.sum(labels == preds)
num_total = labels.shape[1]
acc = num_correct / num_total
print(f"Accuracy: {100 * acc :.2f} %")

Accuracy: 84.97 %
