# Create Datasets

In [16]:
import os
import pickle

import torch

In [17]:
# Short dataset identifier. It is interpolated into the checkpoint path, the
# CFGNN weights path and the output directory, and selects the eval-set key.
DATASET = "syn4"

## Data

In [18]:
# Load the base-model checkpoint for DATASET and unpack its computation graph
# ("cg") into notebook-level names that the later cells read.
with torch.no_grad():
    ckpt = torch.load('../ckpt/%s_base_h20_o20.pth.tar' % DATASET)
    cg_dict = ckpt["cg"]  # computation graph stored next to the model state

    # Feature width is read from the third axis of the full "feat" array.
    input_dim = cg_dict["feat"].shape[2]

    # Only entry 0 of each array is used.
    # NOTE(review): presumably a leading batch axis of size 1 -- confirm.
    adj, label = cg_dict["adj"][0], cg_dict["label"][0]
    features = torch.tensor(cg_dict["feat"][0], dtype=torch.float)

    # NOTE(review): assumes labels are 0-based integer class ids -- confirm.
    num_class = max(label) + 1

In [19]:
# Display the top-level entries stored in the loaded checkpoint.
ckpt.keys()

dict_keys(['epoch', 'model_type', 'optimizer', 'model_state', 'optimizer_state', 'cg'])

In [20]:
# Display the entries of the checkpoint's computation-graph dictionary.
cg_dict.keys()

dict_keys(['adj', 'feat', 'label', 'pred', 'train_idx', 'test_idx'])

## CFGNN model weights

In [21]:
# Load the CFGNN weights saved for DATASET; they later replace
# ckpt["model_state"] before the checkpoint is re-saved.
state_dict_cfgnn = torch.load(f"../cfgnn_model_weights/gcn_3layer_{DATASET}.pt")
# Print each entry's name (left-aligned, padded to 10 chars) and its size.
for key, val in state_dict_cfgnn.items():
    print(f"{key:<10} : {val.size()}")

gc1.weight : torch.Size([10, 20])
gc1.bias   : torch.Size([20])
gc2.weight : torch.Size([20, 20])
gc2.bias   : torch.Size([20])
gc3.weight : torch.Size([20, 20])
gc3.bias   : torch.Size([20])
lin.weight : torch.Size([2, 60])
lin.bias   : torch.Size([2])


## Eval set

In [22]:
# Load the evaluation-set dictionary (one entry per dataset) from disk.
# NOTE(review): pickle.load can execute arbitrary code while unpickling --
# only run this against an eval_set.pkl from a trusted source.
with open("../../eval_set.pkl", "rb") as file:
    eval_set = pickle.load(file)

In [23]:
# Display which dataset keys are available in the evaluation set.
eval_set.keys()

dict_keys(['syn1/ba-shapes', 'syn4/tree-cycles', 'syn5/tree-grid'])

In [24]:
# Translate the short DATASET identifier into the matching key of `eval_set`.
# The values must match the keys stored in eval_set.pkl exactly; note that the
# syn5 key is "syn5/tree-grid" (singular), not "syn5/tree-grids".
_EVAL_SET_KEYS = {
    "syn1": "syn1/ba-shapes",
    "syn4": "syn4/tree-cycles",
    "syn5": "syn5/tree-grid",
}
if DATASET not in _EVAL_SET_KEYS:
    # Fail here with a clear message instead of leaving KEY undefined (or
    # stale from a previous run of this cell with another DATASET).
    raise ValueError(
        f"Unsupported DATASET {DATASET!r}; expected one of {sorted(_EVAL_SET_KEYS)}"
    )
KEY = _EVAL_SET_KEYS[DATASET]

## Our eval set as part of the training set

In [25]:
# "Eval as train" split: every node index is a training index, so the eval
# nodes are trained on as well as tested on.
# list(range(n)) yields a flat list of ints -- the same form the held-out
# split builds -- whereas [range(n)] would be a one-element list holding a
# range object.
train_set_indices = list(range(label.shape[0]))
test_set_indices = eval_set[KEY]

### Save

In [26]:
# Swap the CFGNN weights into the checkpoint and overwrite the stored
# train/test index splits of its computation graph with the ones built above.
ckpt["model_state"] = state_dict_cfgnn
ckpt["cg"].update(
    {
        "train_idx": train_set_indices,
        "test_idx": test_set_indices,
    }
)

In [27]:
# Make sure the per-dataset output directory exists (no error if it already
# does), then write the modified checkpoint as the "eval as train" variant.
os.makedirs(f"../data/{DATASET}", exist_ok=True)
torch.save(ckpt, f"../data/{DATASET}/eval_as_train.pt")

## Our eval set as part of the validation set

In [28]:
# "Eval as eval" split: train on every node index except the eval nodes,
# which are used only for testing.
# The held-out ids are collected into a set once, so each membership test is
# O(1) instead of re-scanning eval_set[KEY] for every node. int() makes the
# ids hash like the plain ints produced by range().
# NOTE(review): assumes eval_set[KEY] is a flat iterable of integer node ids
# -- confirm.
held_out = {int(i) for i in eval_set[KEY]}
train_set_indices = [i for i in range(label.shape[0]) if i not in held_out]
test_set_indices = eval_set[KEY]

### Save

In [29]:
# Swap the CFGNN weights into the checkpoint and overwrite the stored
# train/test index splits of its computation graph with the held-out split.
ckpt["model_state"] = state_dict_cfgnn
ckpt["cg"].update(
    {
        "train_idx": train_set_indices,
        "test_idx": test_set_indices,
    }
)

In [30]:
# Make sure the per-dataset output directory exists (no error if it already
# does), then write the modified checkpoint as the "eval as eval" variant.
os.makedirs(f"../data/{DATASET}", exist_ok=True)
torch.save(ckpt, f"../data/{DATASET}/eval_as_eval.pt")