# Create Checkpoint

In [None]:
import os
import pickle
import sys

import torch

sys.path.append("../gnnexp")
from models import GCNSynthetic

In [None]:
# Short name of the synthetic dataset to process. It is interpolated into the
# checkpoint, program-args and weights file names and into the output folder,
# and it selects the eval-set key (only "syn1" and "syn4" are mapped below).
DATASET = "syn1"

## Data

In [None]:
# Load the base checkpoint and pull out the graph data stored under "cg".
# map_location="cpu" lets the file load on machines that lack the device it
# was saved from; nothing in this notebook moves data to a GPU.
# NOTE(review): newer PyTorch releases default torch.load to
# weights_only=True, which may reject a checkpoint holding non-tensor
# objects -- confirm against the installed version.
ckpt = torch.load(f"../ckpt/{DATASET}_base_h20_o20.pth.tar", map_location="cpu")
cg_dict = ckpt["cg"]  # Get the graph data.
# Size of axis 2 of "feat"; passed to the model as its input width.
input_dim = cg_dict["feat"].shape[2]
# Entry 0 along the leading axis of each stored array.
adj = cg_dict["adj"][0]
label = cg_dict["label"][0]
features = torch.tensor(cg_dict["feat"][0], dtype=torch.float)
# Largest label value plus one, as a plain Python int so it can be used
# directly as a layer size.
num_class = int(label.max()) + 1

In [None]:
# Restore the pickled program arguments that belong to the selected dataset.
# NOTE: unpickling can execute arbitrary code; only open trusted files.
prog_args_path = f"prog_args_{DATASET}.pkl"
with open(prog_args_path, "rb") as args_file:
    prog_args = pickle.load(args_file)

In [None]:
# Print one summary line per entry of the checkpoint's "cg" dictionary:
# key, Python type, element dtype (or first-element type), and shape (or length).
print(f'{"KEY":<10}: {"OBJECT":<25}: {"TYPE":<15}: SHAPE/LEN\n')
for key, val in ckpt['cg'].items():
    try:
        # Array-like entries expose both `dtype` and `shape`.
        type_col, size_col = str(val.dtype), val.shape
    except AttributeError:
        # Entries without those attributes are described by the type of
        # their first element and their length instead.
        type_col, size_col = str(type(val[0])), len(val)
    print(f"{key:<10}: {str(type(val)):<25}: {type_col:<15}: {size_col}")

## Model

In [None]:
# Assemble the constructor arguments first, then build the network.
# Layer widths come from the loaded program args; dropout is set to 0.0.
model_kwargs = {
    "nfeat": input_dim,
    "nhid": prog_args.hidden_dim,
    "nout": prog_args.output_dim,
    "nclass": num_class,
    "dropout": 0.0,
}
model = GCNSynthetic(**model_kwargs)

In [None]:
# CFGNN model weights
# Load onto the CPU whatever device the weights were saved from: the visible
# cells never move the model to a GPU, and the predictions are later turned
# into a NumPy array, which requires CPU tensors.
state_dict_cfgnn = torch.load(
    f"../cfgnn_model_weights/gcn_3layer_{DATASET}.pt",
    map_location="cpu",
)
# Optional debugging aid: list every parameter name with its size.
# for key, val in state_dict_cfgnn.items():
#     print(f"{key:<10} : {val.size()}")

## Preds

In [None]:
# Put the CFGNN weights into the model and switch it to evaluation mode.
model.load_state_dict(state_dict_cfgnn)
model.eval()
# Inference only: skip autograd bookkeeping so no graph is built for a
# result that is never back-propagated through.
with torch.no_grad():
    preds = model(
        torch.from_numpy(cg_dict['feat']).float(),
        torch.from_numpy(cg_dict['adj']).float()
    )

In [None]:
# Keep the model outputs next to the graph data, converted to a NumPy array.
pred_array = preds.detach().numpy()
cg_dict["pred"] = pred_array

## Eval set

In [None]:
# Read the pickled evaluation set and show which keys it provides.
# NOTE: unpickling can execute arbitrary code; only open trusted files.
with open("../../eval_set.pkl", "rb") as eval_file:
    eval_set = pickle.load(eval_file)
print(eval_set.keys())

In [None]:
# Translate the short dataset name into the key used inside `eval_set`.
EVAL_SET_KEYS = {
    "syn1": "syn1/ba-shapes",
    "syn4": "syn4/tree-cycles",
}
# Fail here with a clear message rather than leaving KEY undefined (or
# stale from a previous run) when DATASET is not one of the mapped names.
if DATASET not in EVAL_SET_KEYS:
    raise ValueError(
        f"Unsupported DATASET {DATASET!r}; expected one of {sorted(EVAL_SET_KEYS)}"
    )
KEY = EVAL_SET_KEYS[DATASET]

### Our eval set as part of the training set

In [None]:
# Train on every index; the de-duplicated eval-set entries are the test set,
# so in this variant the test indices also appear in the training indices.
num_labels = label.shape[0]
train_set_indices = list(range(num_labels))
test_set_indices = list(set(eval_set[KEY]))

In [None]:
# Record the split and the CFGNN weights in the checkpoint about to be saved.
ckpt["cg"].update(
    {
        "train_idx": train_set_indices,
        "test_idx": test_set_indices,
    }
)
ckpt["model_state"] = state_dict_cfgnn

In [None]:
# Make sure the per-dataset output folder exists, then write the checkpoint.
out_dir = f"../data/{DATASET}"
os.makedirs(out_dir, exist_ok=True)
torch.save(ckpt, f"{out_dir}/eval_as_train.pt")

### Our eval set as the validation set

In [None]:
# Hold the eval-set entries out of training and use them as the test set.
# The set is built once so each membership test is a hash lookup instead of
# re-reading eval_set[KEY] for every index.
# NOTE(review): assumes the entries of eval_set[KEY] hash and compare like
# plain integers -- confirm against the pickled data.
held_out = set(eval_set[KEY])
train_set_indices = [i for i in range(label.shape[0]) if i not in held_out]
test_set_indices = list(held_out)

In [None]:
# Overwrite the split stored in the checkpoint with the held-out variant and
# attach the CFGNN weights.
cg_section = ckpt["cg"]
cg_section["train_idx"] = train_set_indices
cg_section["test_idx"] = test_set_indices
ckpt["model_state"] = state_dict_cfgnn

In [None]:
# Make sure the per-dataset output folder exists, then write the checkpoint.
target_dir = f"../data/{DATASET}"
os.makedirs(target_dir, exist_ok=True)
torch.save(ckpt, f"{target_dir}/eval_as_eval.pt")

## Rough