# Visualization dataset

In [46]:
import os
import pickle
import sys

import torch
import numpy as np
import pandas as pd

sys.path.append("../gnnexp")
from models import GCNSynthetic

In [47]:
DATASET = "syn5"
EVAL = "train"

In [48]:
if DATASET == "syn1":
    if EVAL == "eval":
        OUTPUTS = "output/syn1/1660599579"
    else:
        OUTPUTS = "output/syn1/1660600121"
elif DATASET == "syn4":
    if EVAL == "eval":
        OUTPUTS = "output/syn4/1660599591"
    else:
        OUTPUTS = "output/syn4/1660600106"
elif DATASET == "syn5":
    if EVAL == "eval":
        OUTPUTS = "output/syn5/1660599659"
    else:
        OUTPUTS = "output/syn5/1660600177"

## Data

In [49]:
with open(f"../{OUTPUTS}/original_sub_data.pkl", "rb") as file:
    sub_data = pickle.load(file)
sub_labels = dict()
for node in sub_data:
    new_idx = sub_data[node]['node_idx_new']
    sub_labels[node] = int(sub_data[node]['sub_label'][new_idx])

explanations = dict()
PATH = f"../explanation/{DATASET}_top6"
for filename in os.listdir(PATH):
    if 'pred' not in filename:
        continue
    node_idx = ''.join(filter(lambda i: i.isdigit(), filename))
    explanations[int(node_idx)] = pd.read_csv(
        f"{PATH}/{filename}", header=None).to_numpy()

## Model

In [50]:
ckpt = torch.load(f"../data/{DATASET}/eval_as_{EVAL}.pt")
cg_dict = ckpt["cg"]
input_dim = cg_dict["feat"].shape[2]
num_classes = cg_dict["pred"].shape[2]
feat = torch.from_numpy(cg_dict["feat"]).float()
adj = torch.from_numpy(cg_dict["adj"]).float()
label = torch.from_numpy(cg_dict["label"]).long()
with open(f"../tests/prog_args_{DATASET}.pkl", "rb") as file:
    prog_args = pickle.load(file)

model = GCNSynthetic(
    nfeat=input_dim,
    nhid=prog_args.hidden_dim,
    nout=prog_args.output_dim,
    nclass=num_classes,
    dropout=0.0,
)
model.load_state_dict(ckpt["model_state"])
model.eval()

GCNSynthetic(
  (gc1): GraphConvolution (10 -> 20)
  (gc2): GraphConvolution (20 -> 20)
  (gc3): GraphConvolution (20 -> 20)
  (lin): Linear(in_features=60, out_features=2, bias=True)
)

## Top k cfs

In [51]:
top_indices = dict() # They are from the Upper Triangular part only.
top_k = 6
for graph_id, graph in explanations.items():
    # print(graph_id, graph.shape)
    # triu: Upper Triangular
    # abs: These are indices of the flattended version, not of the 2D version.
    # * Following what gem's authors have done:
    # * explanation = adjecency * explanation_mask
    explanation = sub_data[graph_id]['org_adj'].squeeze(0) * graph
    triu_abs_top_indices = (-np.triu(explanation).flatten()).argsort()[:top_k]
    index_rows = triu_abs_top_indices // explanation.shape[0]
    index_cols = triu_abs_top_indices % explanation.shape[0]
    triu_top_k_indices = [(r,c) for r,c in zip(index_rows, index_cols)]
    top_indices[graph_id] = triu_top_k_indices

## New Dataset

In [52]:
new_dataset = dict()
for graph_id, data in explanations.items():
    feat = sub_data[graph_id]["sub_feat"].float()
    adj = sub_data[graph_id]["org_adj"].float()
    label = sub_labels[graph_id]
    target = sub_data[graph_id]['node_idx_new']
    original_prediction = int(model(feat, adj).squeeze(0).argmax(dim=-1)[target])
    if label == 0:
        continue
    go = False
    if label == original_prediction:
        go = True
    if not go:
        continue
    
    adj = adj.squeeze(0).numpy()
    cfs = top_indices[graph_id]
    cf_list = list()
    for cf in cfs:
        src, dest = cf
        cf_list.append([src, dest, 'del'])
    new_dataset[graph_id] = {"target": target, "adj": adj, "cfs": cf_list}

In [53]:
with open(f"{DATASET}_gem_{EVAL}", "wb") as file:
    pickle.dump(new_dataset, file)

## Rough

In [54]:
sub_data

{514: {'node_idx_new': 2,
  'sub_feat': tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        