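# Create Checkpoint

Load the original `{DATASET}` checkpoint, swap in the CFGNN model weights and that model's predictions, and save two variants that differ only in the train/test split: `eval_as_train.pt` (eval nodes are also training nodes) and `eval_as_eval.pt` (eval nodes are held out of training).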

In [1]:
import os
import pickle
import sys

import torch

sys.path.append("../gnnexp")
from models import GCNSynthetic

In [2]:
# Short id of the synthetic benchmark to process: "syn1", "syn4" or "syn5".
DATASET: str = "syn1"

## Data

In [3]:
# Load the original checkpoint and unpack the graph data it carries.
# No torch.no_grad() needed: nothing here builds an autograd graph.
# NOTE(review): torch.load unpickles arbitrary objects — only use on trusted files.
# Recent torch releases default to weights_only=True, which may reject this
# checkpoint (it stores non-tensor objects); pass weights_only=False if so.
ckpt = torch.load(f"../ckpt/{DATASET}_base_h20_o20.pth.tar")
cg_dict = ckpt["cg"]  # Graph data; the same dict object as ckpt["cg"], not a copy.
# feat is rank-3 (indexed [0] below); presumably (graph, node, feature) — TODO confirm.
input_dim = cg_dict["feat"].shape[2]
adj = cg_dict["adj"][0]
label = cg_dict["label"][0]
features = torch.tensor(cg_dict["feat"][0], dtype=torch.float)
# Plain Python int so layer sizes are not numpy/torch scalars.
num_class = int(label.max()) + 1

In [6]:
# Saved argument namespace (presumably from the original training run); only
# hidden_dim and output_dim are read below.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
prog_args_path = f"prog_args_{DATASET}.pkl"
with open(prog_args_path, "rb") as file:
    prog_args = pickle.load(file)

## Model

In [7]:
# Rebuild the architecture with sizes matching the saved weights. Dropout is 0.0;
# in this notebook the model is only used for a forward pass.
model_kwargs = dict(
    nfeat=input_dim,
    nhid=prog_args.hidden_dim,
    nout=prog_args.output_dim,
    nclass=num_class,
    dropout=0.0,
)
model = GCNSynthetic(**model_kwargs)

## CFGNN model weights

In [8]:
# State dict of the CFGNN-trained 3-layer GCN. It is loaded into `model` and also
# stored as ckpt["model_state"]. map_location="cpu" lets weights saved on a GPU
# load on CPU-only machines, and keeps the saved checkpoint device-neutral.
state_dict_cfgnn = torch.load(
    f"../cfgnn_model_weights/gcn_3layer_{DATASET}.pt",
    map_location="cpu",
)

## Predictions

In [10]:
# Copy the CFGNN weights in (strict: keys must match exactly) and switch to
# inference mode; the trailing expression displays the model summary.
model.load_state_dict(state_dict_cfgnn, strict=True)
model.eval()

GCNSynthetic(
  (gc1): GraphConvolution (10 -> 20)
  (gc2): GraphConvolution (20 -> 20)
  (gc3): GraphConvolution (20 -> 20)
  (lin): Linear(in_features=60, out_features=4, bias=True)
)

In [12]:
# Forward pass over the stored feature/adjacency arrays. Autograd is disabled:
# the result is written into the checkpoint below, and a tensor carrying a graph
# would keep intermediate activations alive and be saved with requires_grad=True.
with torch.no_grad():
    preds = model(
        torch.from_numpy(cg_dict['feat']).float(),
        torch.from_numpy(cg_dict['adj']).float(),
    )

In [15]:
# cg_dict is the same object as ckpt["cg"], so this replaces the checkpoint's
# stored predictions with the CFGNN model's.
cg_dict.update(pred=preds)

## Eval set

In [17]:
# Evaluation nodes per benchmark, keyed by "<dataset>/<name>" (presumably node
# indices — they are used as test_idx below).
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
eval_set_path = "../../eval_set.pkl"
with open(eval_set_path, "rb") as file:
    eval_set = pickle.load(file)

In [18]:
# Show which benchmark keys the eval set provides; KEY below must be one of these.
eval_set.keys()

dict_keys(['syn1/ba-shapes', 'syn4/tree-cycles', 'syn5/tree-grid'])

In [19]:
# Map the short dataset id to its key in eval_set. Note the syn5 key is the
# singular "syn5/tree-grid", as listed by eval_set.keys().
DATASET_TO_EVAL_KEY = {
    "syn1": "syn1/ba-shapes",
    "syn4": "syn4/tree-cycles",
    "syn5": "syn5/tree-grid",
}
# Fail loudly instead of leaving KEY undefined (or stale from an earlier run).
if DATASET not in DATASET_TO_EVAL_KEY:
    raise ValueError(
        f"Unsupported DATASET {DATASET!r}; expected one of {sorted(DATASET_TO_EVAL_KEY)}"
    )
KEY = DATASET_TO_EVAL_KEY[DATASET]

## Our eval set as part of the training set

In [20]:
# Train on every index 0..len(label)-1 (eval nodes included); eval nodes are the test set.
# list(range(...)) yields the indices themselves; the previous `[range(n)]` was a
# one-element list holding a range object.
train_set_indices = list(range(label.shape[0]))
test_set_indices = eval_set[KEY]

### Save

In [21]:
# Attach the split and the CFGNN weights to the checkpoint (mutates ckpt in place).
ckpt["cg"].update(train_idx=train_set_indices, test_idx=test_set_indices)
ckpt["model_state"] = state_dict_cfgnn

In [22]:
# Write the "eval nodes also used for training" checkpoint.
out_dir = f"../data/{DATASET}"
os.makedirs(out_dir, exist_ok=True)
torch.save(ckpt, f"{out_dir}/eval_as_train.pt")

## Our eval set held out of the training set (stored as `test_idx`)

In [23]:
# Hold the eval nodes out of training and use them as the test set.
# Precompute an int set so membership is O(1); the original list scan was
# O(N * |eval|). int() also normalises numpy/torch scalar elements so hashing works.
held_out_indices = {int(i) for i in eval_set[KEY]}
train_set_indices = [i for i in range(label.shape[0]) if i not in held_out_indices]
test_set_indices = eval_set[KEY]

### Save

In [24]:
# Overwrite the split on the checkpoint with the held-out variant. The CFGNN
# weights are re-attached so this cell does not rely on the earlier one.
cg_split = ckpt["cg"]
cg_split["train_idx"] = train_set_indices
cg_split["test_idx"] = test_set_indices
ckpt["model_state"] = state_dict_cfgnn

In [25]:
# Write the "eval nodes held out of training" checkpoint.
eval_ckpt_path = f"../data/{DATASET}/eval_as_eval.pt"
os.makedirs(os.path.dirname(eval_ckpt_path), exist_ok=True)
torch.save(ckpt, eval_ckpt_path)

## Rough