# Cross check labels

In [1]:
import pickle
import numpy as np

In [5]:
# Load both copies of the dataset so they can be cross-checked below.
# To compare the other dataset pair, swap BOTH commented-out paths together.

# --- cfgnn bashapes dataset ---
# Unpickled object is a dict with keys: adj, feat, labels, train_idx, test_idx
path_cfgnn = "../cfgnnexplainer/data/gnn_explainer/syn1.pickle"
# path_cfgnn = "../cfgnnexplainer/data/gnn_explainer/syn4.pickle"
with open(path_cfgnn, "rb") as handle:
    dataset_cfgnn = pickle.load(handle)

# --- cfsqr bashapes dataset ---
# Unpickled object is a positional container (noted as a tuple), indexed as:
#   0. adjacency_matrix
#   1. features
#   2. y_train      e.g. a row [0., 1., 0., 0.] means label 1
#   3. y_val
#   4. y_test
#   5. train_mask
#   6. val_mask
#   7. test_mask
#   8. e_labels     open question: how are these edge labels decided?
path_cfsqr = "datasets/BA_Shapes/syn_data.pkl"
# path_cfsqr = "datasets/Tree_Cycles/syn_data.pkl"
with open(path_cfsqr, "rb") as handle:
    dataset_cfsqr = pickle.load(handle)

In [6]:
# Same adjacency matrices?
# Count the entries where the two matrices disagree. Summing the raw signed
# difference is not sufficient: an extra edge in one graph (+1) and a missing
# edge elsewhere (-1) cancel to 0, and two different graphs would be reported
# as identical.
# NOTE(review): the CFGNN adjacency is indexed with [0] -- presumably it carries
# a leading batch dimension; confirm against the pickle's layout.
differences = int(np.count_nonzero(dataset_cfgnn['adj'][0] - dataset_cfsqr[0]))
print(f"Differences: {differences}")

if differences == 0:
    print("Nice, Same graphs!")
else:
    print("Oops!")

Differences: 0
Nice, Same graphs!


## Generate node-label mapping

### CFGNN

In [4]:
# Map node index -> label; the label array is flattened first so that the
# position in the flat array is used as the node id.
labels_cfgnn = dict(enumerate(dataset_cfgnn['labels'].flatten()))

### CFSQR

In [5]:
# Merge the train / val / test label matrices (tuple slots 2, 3, 4) into one
# matrix by OR-ing them element-wise after casting to int.
# NOTE(review): presumably each split's matrix is non-zero only on that split's
# own rows, so the OR yields one one-hot row per node -- confirm.
y_train_bits, y_val_bits, y_test_bits = (
    dataset_cfsqr[slot].astype(int) for slot in (2, 3, 4)
)
labels_matrix_cfsqr = y_train_bits | y_val_bits | y_test_bits

In [6]:
# Map node index -> class index, taking the position of the largest entry
# along the last axis of the merged label matrix as the class.
labels_cfsqr = dict(enumerate(labels_matrix_cfsqr.argmax(axis=-1)))

## Compare

Since the adjacency matrices are the same, there is no node which is present in one dataset and absent in the other.

In [7]:
# Walk every CFGNN node, look up the same node id in the CFSQR mapping and
# report each disagreement, followed by the total count.
mismatches = 0
for node, label_cfgnn in labels_cfgnn.items():
    label_cfsqr = labels_cfsqr[node]
    if label_cfgnn == label_cfsqr:
        continue  # labels agree, nothing to report
    mismatches += 1
    print(f"CFGNN: {node}:{label_cfgnn}")
    print(f"CFSQR: {node}:{label_cfsqr}")
print(f"Mismatches: {mismatches}")

Mismatches: 0


# CFSQR Black-box accuracy

Before running the following cells, make sure all the cells above were run for the same dataset (BA-Shapes or Tree-Cycles) as the output `FOLDER` selected below.

In [8]:
# Output folder of the run to evaluate; switch to the commented-out line to
# evaluate the other run instead.
FOLDER = "outputs/bashapes/bashapes-alp_0.0-1653544735"
# FOLDER = "outputs/treecycles/treecycles-alp_0.0-1653386395"

# Stored mapping, format: node_id: initial_blackbox_prediction
pred_label_path = f"{FOLDER}/pred_label_dict.pkl"
with open(pred_label_path, "rb") as handle:
    pred_label_dict = pickle.load(handle)

In [9]:
# Count the nodes whose black-box prediction differs from the CFSQR label,
# then report the share of agreeing nodes as a percentage.
mismatches = sum(
    1
    for node, predicted in pred_label_dict.items()
    if labels_cfsqr[node] != int(predicted)
)
n_predictions = len(pred_label_dict)

print(f"Test set accuracy: {100 * (1 - mismatches/n_predictions):.2f}%")

Test set accuracy: 97.37%
