# Create Checkpoint

In [18]:
import os
import pickle
import sys

import torch

sys.path.append("../gnnexp")
from models import GNN_Custom_Graph

In [19]:
# Name of the dataset to process; it is interpolated into every dataset-specific path used below.
DATASET = "NCI1" # OPTIONS: Mutagenicity, NCI1, IsCyclic

# TODO: the MUTAG dataset is different from other baselines.

## Data

In [20]:
# Load the baseline checkpoint for DATASET and pull out the graph data it carries.
# map_location="cpu" keeps every loaded tensor on the CPU, which is where the
# model in this notebook is built and run, so a checkpoint written from a GPU
# session also loads on a CPU-only machine. Loading creates no autograd history,
# so no `torch.no_grad()` guard is needed around it.
ckpt = torch.load(f"../ckpt/{DATASET}_base_h20_o20.pth.tar", map_location="cpu")
cg_dict = ckpt["cg"]  # Get the graph data.
# Feature width is axis 2 of "feat" (presumably graphs x nodes x features — TODO confirm).
input_dim = cg_dict["feat"].shape[2]
num_classes = 2

In [21]:
# Restore the pickled program arguments stored for this dataset.
# NOTE(review): unpickling can execute arbitrary code — only load trusted files.
prog_args_path = f"prog_args_{DATASET}.pkl"
with open(prog_args_path, "rb") as args_file:
    prog_args = pickle.load(args_file)

In [22]:
# Largest graph index referenced by any split stored in the checkpoint.
# (Despite its name, `max_label` is an index bound, not a class label.)
split_maxima = [max(cg_dict[split_key]) for split_key in ('test_idx', 'val_idx', 'train_idx')]
max_label = max(split_maxima)
print(max_label)

4102


In [23]:
def _split_indices(raw_indices, upper_bound):
    """Return the unique indices of one split as a sorted list of ints.

    Args:
        raw_indices: Iterable of integer-like graph indices for one split.
        upper_bound: Largest index to keep; larger entries are dropped.

    Returns:
        Ascending list of distinct `int` indices that are <= `upper_bound`.
        Sorting makes the order deterministic instead of depending on set
        iteration order.
    """
    return sorted({int(i) for i in raw_indices if i <= upper_bound})


# Load the split definition shared with the other baselines.
# NOTE(review): unpickling can execute arbitrary code — only load trusted files.
with open(f"../data/{DATASET}/index_{DATASET}.pkl", "rb") as file:
    indices = pickle.load(file)

# Keep only indices that exist in the checkpoint's graph data (<= max_label).
train_set_indices = _split_indices(indices['idx_train'], max_label)
val_set_indices = _split_indices(indices['idx_val'], max_label)
test_set_indices = _split_indices(indices['idx_test'], max_label)

## Model

In [24]:
# Build the graph classifier; the input width comes from the checkpoint's
# feature tensor and the hidden width is fixed at 128.
hidden_dim = 128
model = GNN_Custom_Graph(in_features=input_dim, h_features=hidden_dim)
print(model)

GNN_Custom_Graph(
  (conv1): GraphConvolution (37 -> 128)
  (conv2): GraphConvolution (128 -> 128)
  (conv3): GraphConvolution (128 -> 128)
  (dense1): Linear(in_features=128, out_features=16, bias=True)
  (dense2): Linear(in_features=16, out_features=8, bias=True)
  (dense3): Linear(in_features=8, out_features=1, bias=True)
)


## CFGNN model weights

In [25]:
# Load the CFGNN weights for DATASET. map_location="cpu" lets weights saved
# from a GPU session load on a CPU-only machine and keeps the state dict that is
# later written into the new checkpoints device-independent.
state_dict_cfgnn = torch.load(
    f"../graph_classification_model_weights/{DATASET}_weights.pt",
    map_location="cpu",
)

## Predictions

In [26]:
# Copy the CFGNN weights into the freshly built model.
model.load_state_dict(state_dict_cfgnn)
# Switch to evaluation mode; as the last expression of the cell this also displays the model.
model.eval()

GNN_Custom_Graph(
  (conv1): GraphConvolution (37 -> 128)
  (conv2): GraphConvolution (128 -> 128)
  (conv3): GraphConvolution (128 -> 128)
  (dense1): Linear(in_features=128, out_features=16, bias=True)
  (dense2): Linear(in_features=16, out_features=8, bias=True)
  (dense3): Linear(in_features=8, out_features=1, bias=True)
)

In [27]:
# Score every graph in the train, validation and test splits with the CFGNN model.
preds = list()
labels = list()
# NOTE: from here on `indices` is the flat evaluation list, no longer the
# split dict loaded from the index pickle.
indices = train_set_indices + val_set_indices + test_set_indices
# Inference only: disabling autograd avoids building and retaining a
# computation graph for every stored prediction.
with torch.no_grad():
    for graph_id in indices:
        # Add a leading batch axis of size 1 for the single graph being scored.
        feat = cg_dict["feat"][graph_id, :].float().unsqueeze(0)
        adj = cg_dict["adj"][graph_id].float().unsqueeze(0)
        pred = model(feat, adj)
        label = cg_dict['label'][graph_id]
        preds.append(pred)
        labels.append(label)
# Both tensors are ordered like `indices`, not by graph id.
preds = torch.Tensor(preds)
labels = torch.Tensor(labels)

In [28]:
# Store the predictions in the checkpoint's graph dict, addressed by graph id.
# `preds` is ordered like the evaluation list (train + val + test), whereas
# "feat", "adj", "label" and the saved *_idx lists are all addressed by graph
# id, so each prediction is scattered back into its own graph's slot.
eval_indices = torch.as_tensor(
    train_set_indices + val_set_indices + test_set_indices, dtype=torch.long
)
num_graphs = cg_dict["feat"].shape[0]
# Graphs that were not evaluated keep NaN so they cannot be mistaken for a real prediction.
preds_by_graph = torch.full((num_graphs,), float("nan"))
preds_by_graph[eval_indices] = preds
cg_dict['pred'] = preds_by_graph.unsqueeze(0).numpy()
# Keep labels in graph-id order (only cast to float) so that label[graph_id]
# stays valid and re-running the prediction cell reads the correct labels.
cg_dict['label'] = torch.as_tensor(cg_dict['label']).float()

## Our eval set as part of the training set

In [29]:
# Variant 1: the evaluation (test) graphs are also part of the training indices.
ckpt["cg"].update(
    {
        "train_idx": train_set_indices + val_set_indices + test_set_indices,
        "test_idx": test_set_indices,
    }
)
# Ship the CFGNN weights as the checkpoint's model state.
ckpt["model_state"] = state_dict_cfgnn

In [30]:
# Make sure the dataset's output folder exists, then write the variant-1 checkpoint.
output_dir = f"../data/{DATASET}"
os.makedirs(output_dir, exist_ok=True)
torch.save(ckpt, f"../data/{DATASET}/eval_as_train.pt")

## Our eval set as the validation set

In [31]:
# Variant 2: the evaluation (test) graphs are held out of the training indices.
# NOTE(review): the section heading mentions the validation set, but only
# "train_idx" and "test_idx" are rewritten here — confirm "val_idx" is meant to stay as loaded.
held_out_splits = {
    "train_idx": train_set_indices + val_set_indices,
    "test_idx": test_set_indices,
}
for split_key, split_value in held_out_splits.items():
    ckpt["cg"][split_key] = split_value
ckpt["model_state"] = state_dict_cfgnn

In [32]:
# Make sure the dataset's output folder exists, then write the variant-2 checkpoint.
output_dir = f"../data/{DATASET}"
os.makedirs(output_dir, exist_ok=True)
torch.save(ckpt, f"../data/{DATASET}/eval_as_eval.pt")

## Rough accuracy check

In [33]:
# Round each prediction to the nearest integer to get a hard class label
# (assumes the model outputs values in [0, 1] — TODO confirm).
temp = preds.round()

In [34]:
# Fraction of evaluated graphs whose rounded prediction equals the stored label.
num_correct = (temp == labels).sum()
num_evaluated = temp.size(0)
acc = num_correct / num_evaluated
print(f"Accuracy: {100 * acc:.2f} %")

Accuracy: 68.68 %
