# Create Checkpoint

In [1]:
import os
import pickle
import sys

import torch

sys.path.append("../gnnexp")
from models import GNN_Custom_Graph

In [2]:
DATASET = "NCI1" # OPTIONS: Mutagenicity, NCI1, IsCyclic

#todo: MUTAG dataset is different from other baselines.

## Data

In [3]:
with torch.no_grad():
    ckpt = torch.load(f"../ckpt/{DATASET}_base_h20_o20.pth.tar")
    cg_dict = ckpt["cg"] # Get the graph data.
    input_dim = cg_dict["feat"].shape[2]
    num_classes = cg_dict["pred"].shape[2]

In [4]:
with open(f"prog_args_{DATASET}.pkl", "rb") as file:
    prog_args = pickle.load(file)

## Model

In [5]:
model = GNN_Custom_Graph(
    in_features=input_dim,
    h_features=128,
)
print(model)

GNN_Custom_Graph(
  (conv1): GraphConvolution (37 -> 128)
  (conv2): GraphConvolution (128 -> 128)
  (conv3): GraphConvolution (128 -> 128)
  (dense1): Linear(in_features=128, out_features=16, bias=True)
  (dense2): Linear(in_features=16, out_features=8, bias=True)
  (dense3): Linear(in_features=8, out_features=1, bias=True)
)


## CFGNN model weights

In [6]:
state_dict_cfgnn = torch.load(
    f"../graph_classification_model_weights/{DATASET}_weights.pt"
)

## Preds

In [7]:
model.load_state_dict(state_dict_cfgnn)
model.eval()

GNN_Custom_Graph(
  (conv1): GraphConvolution (37 -> 128)
  (conv2): GraphConvolution (128 -> 128)
  (conv3): GraphConvolution (128 -> 128)
  (dense1): Linear(in_features=128, out_features=16, bias=True)
  (dense2): Linear(in_features=16, out_features=8, bias=True)
  (dense3): Linear(in_features=8, out_features=1, bias=True)
)

In [8]:
preds = list()
for graph_id in range(cg_dict['adj'].size(0)):
    feat = cg_dict["feat"][graph_id, :].float().unsqueeze(0)
    adj = cg_dict["adj"][graph_id].float().unsqueeze(0)
    pred = model(feat, adj)
    preds.append(pred)
preds = torch.Tensor(preds)

In [9]:
cg_dict['pred'] = preds

## Eval set

In [10]:
with open(f"../data/{DATASET}/index_{DATASET}.pkl", "rb") as file:
    indices = pickle.load(file)

train_set_indices = [int(i) for i in indices['idx_train']]
val_set_indices = [int(i) for i in indices['idx_val']]
test_set_indices = [int(i) for i in indices['idx_test']]

## Our eval set as part of the training set

In [11]:
ckpt["cg"]["train_idx"] = train_set_indices + val_set_indices + test_set_indices
ckpt["cg"]["test_idx"] = test_set_indices
ckpt["model_state"] = state_dict_cfgnn

In [12]:
os.makedirs(f"../data/{DATASET}", exist_ok=True)
torch.save(ckpt, f"../data/{DATASET}/eval_as_train.pt")

## Our eval set as the validation set

In [13]:
ckpt["cg"]["train_idx"] = train_set_indices + val_set_indices
ckpt["cg"]["test_idx"] = test_set_indices
ckpt["model_state"] = state_dict_cfgnn

In [14]:
os.makedirs(f"../data/{DATASET}", exist_ok=True)
torch.save(ckpt, f"../data/{DATASET}/eval_as_eval.pt")

## Rough

In [15]:
temp = preds.round()

In [16]:
acc = (temp == cg_dict['label']).sum()/temp.size(0)
print(f"Accuracy: {100 * acc:.2f} %")

Accuracy: 68.68 %
