# Create Checkpoint

In [1]:
import os
import pickle
import sys

import torch

# Extend the module search path so the project-local `models` module can be
# imported on the next line. The path is relative, so it is resolved against
# the kernel's current working directory.
# NOTE(review): presumably `models.py` lives in ../gnnexp — confirm.
sys.path.append("../gnnexp")
from models import GNN_Custom_NCI1

In [2]:
# Dataset name; interpolated into the checkpoint, index, weights and output
# paths used by the cells below.
DATASET = "NCI1" # OPTIONS: Mutagenicity, NCI1, IsCyclic
# NOTE(review): the model built below is always GNN_Custom_NCI1 — confirm it
# is the right architecture before switching DATASET to another option.
# TODO: MUTAG dataset is different from other baselines.

## Data

In [3]:
# Load the base checkpoint for the selected dataset and pull out the graph
# data stored under its "cg" key.
base_ckpt_path = f"../ckpt/{DATASET}_base_h20_o20.pth.tar"
with torch.no_grad():
    ckpt = torch.load(base_ckpt_path)

cg_dict = ckpt["cg"]
# The model's input width is read from axis 2 of the "feat" entry.
input_dim = cg_dict["feat"].shape[2]
num_classes = 2

In [4]:
# Distinct label values and how often each occurs (`cg_dict` is `ckpt["cg"]`).
cg_dict["label"].unique(return_counts=True)

(tensor([0, 1], dtype=torch.int32), tensor([2052, 2051]))

In [7]:
# Shape of the stacked adjacency tensor.
cg_dict["adj"].shape

torch.Size([4103, 100, 100])

In [8]:
# Largest graph index referenced by any of the checkpoint's three splits.
# Despite its name, `max_label` is an index bound, not a class label; it is
# used below to filter the external index file.
split_maxima = [
    max(cg_dict["test_idx"]),
    max(cg_dict["val_idx"]),
    max(cg_dict["train_idx"]),
]
max_label = max(split_maxima)
print(max_label)

4102


In [10]:
# Read the pickled split definition for the selected dataset.
# NOTE: pickle.load can execute arbitrary code — only open trusted index files.
index_path = f"../data/{DATASET}/index.pkl"
with open(index_path, "rb") as index_file:
    indices = pickle.load(index_file)


def _unique_ints_up_to(raw_indices, upper_bound):
    """Return the distinct entries of `raw_indices` that are <= `upper_bound`, cast to int."""
    return list(set(int(i) for i in raw_indices if i <= upper_bound))


# Drop duplicates and any index beyond what the checkpoint's splits reference.
train_set_indices = _unique_ints_up_to(indices["idx_train"], max_label)
val_set_indices = _unique_ints_up_to(indices["idx_val"], max_label)
test_set_indices = _unique_ints_up_to(indices["idx_test"], max_label)

In [11]:
# Number of test indices left after de-duplication and the max_label filter.
len(test_set_indices)

2004

## Model

In [12]:
# Instantiate the classifier. The input width comes from the loaded graph
# features; the hidden width is fixed here.
hidden_features = 128
model = GNN_Custom_NCI1(in_features=input_dim, h_features=hidden_features)
print(model)

GNN_Custom_NCI1(
  (conv1): GraphConvolution (37 -> 128)
  (conv2): GraphConvolution (128 -> 128)
  (dense1): Linear(in_features=128, out_features=16, bias=True)
  (dense2): Linear(in_features=16, out_features=8, bias=True)
  (dense3): Linear(in_features=8, out_features=2, bias=True)
)


## CFGNN model weights

In [13]:
# Load the saved weights for the selected dataset; they are applied to `model`
# in the next cell and stored on the checkpoint as "model_state" later.
cfgnn_weights_path = f"../graph_classification_model_weights/{DATASET}_weights.pt"
state_dict_cfgnn = torch.load(cfgnn_weights_path)

## Preds

In [14]:
# Copy the loaded weights into the model, then switch it to evaluation mode.
# `model.eval()` is the cell's last expression, so the model repr is displayed.
model.load_state_dict(state_dict_cfgnn)
model.eval()

GNN_Custom_NCI1(
  (conv1): GraphConvolution (37 -> 128)
  (conv2): GraphConvolution (128 -> 128)
  (dense1): Linear(in_features=128, out_features=16, bias=True)
  (dense2): Linear(in_features=16, out_features=8, bias=True)
  (dense3): Linear(in_features=8, out_features=2, bias=True)
)

In [15]:
# Run the model once per graph in the checkpoint and collect its outputs
# alongside the ground-truth labels.
#
# The forward passes are wrapped in torch.no_grad(): this is inference only,
# and without it every output keeps its autograd graph alive until the final
# detach, which wastes memory across thousands of graphs.
num_graphs = cg_dict["adj"].size(0)
preds = []
labels = []

with torch.no_grad():
    for graph_id in range(num_graphs):
        # unsqueeze(0) adds a leading batch axis of size 1 for the model call.
        feat = cg_dict["feat"][graph_id, :].float().unsqueeze(0)
        adj = cg_dict["adj"][graph_id].float().unsqueeze(0)
        preds.append(model(feat, adj))
        labels.append(cg_dict["label"][graph_id])

# One row per graph, one column per class (num_classes is set when the
# checkpoint is loaded).
preds = torch.cat(preds).detach().reshape(-1, num_classes)
labels = torch.Tensor(labels)

In [18]:
# Accuracy over every graph in the checkpoint (all splits together).
# NOTE(review): the recorded run printed 53.33 % on roughly balanced labels,
# i.e. near chance — confirm the weights and input preprocessing match.
predicted_classes = preds.argmax(dim=-1)
num_correct = (predicted_classes == labels).sum()
acc = num_correct / len(labels)
print(f"Accuracy: {100 * acc:.2f} %")

Accuracy: 53.33 %


In [19]:
# Store the model outputs on the checkpoint as a NumPy array with a leading
# axis of size 1.
preds_with_leading_axis = preds.unsqueeze(0)
ckpt["cg"]["pred"] = preds_with_leading_axis.numpy()

## Our eval set as part of the training set

In [20]:
# For each split, keep only the graphs whose ground-truth label is 1 AND whose
# predicted class (argmax over the stored outputs) is also 1.
#
# The splits are looked up through an explicit mapping instead of building
# variable names and statements as strings for eval(), which is fragile, hard
# to read and opaque to linters.
split_indices = {
    "train": train_set_indices,
    "val": val_set_indices,
    "test": test_set_indices,
}

graph_labels = ckpt["cg"]["label"]
# Stored predictions carry a leading axis of size 1; drop it, then take the
# predicted class of every graph in one call.
predicted_classes = ckpt["cg"]["pred"][0].argmax(axis=-1)

correct_positives = {
    split_name: [
        int(idx)
        for idx in members
        if graph_labels[idx] == 1 and predicted_classes[idx] == 1
    ]
    for split_name, members in split_indices.items()
}

train_set_1 = correct_positives["train"]
val_set_1 = correct_positives["val"]
test_set_1 = correct_positives["test"]

In [21]:
# "Eval as train" variant: the training split is the selected train graphs
# followed by the selected test graphs; val and test are kept as selected.
graph_data = ckpt["cg"]
graph_data["train_idx"] = [*train_set_1, *test_set_1]
graph_data["val_idx"] = val_set_1
graph_data["test_idx"] = test_set_1
# Ship the loaded weights inside the checkpoint.
ckpt["model_state"] = state_dict_cfgnn

In [22]:
# Size of the training split currently stored on the checkpoint.
len(ckpt['cg']['train_idx'])

245

In [23]:
# Write the checkpoint (with the "eval as train" splits) to the dataset's data
# directory, creating the directory first if it does not exist.
eval_as_train_dir = f"../data/{DATASET}"
eval_as_train_path = f"../data/{DATASET}/eval_as_train.pt"
os.makedirs(eval_as_train_dir, exist_ok=True)
torch.save(ckpt, eval_as_train_path)

## Our eval set as the validation set

In [24]:
# "Eval as eval" variant: each split holds only its own selected graphs, so
# the test graphs are no longer part of the training split.
graph_data = ckpt["cg"]
graph_data["train_idx"] = train_set_1
graph_data["val_idx"] = val_set_1
graph_data["test_idx"] = test_set_1
# Ship the loaded weights inside the checkpoint.
ckpt["model_state"] = state_dict_cfgnn

In [25]:
# Size of the training split currently stored on the checkpoint.
len(ckpt['cg']['train_idx'])

106

In [20]:
# Write the checkpoint (with the "eval as eval" splits) to the dataset's data
# directory, creating the directory first if it does not exist.
eval_as_eval_dir = f"../data/{DATASET}"
eval_as_eval_path = f"../data/{DATASET}/eval_as_eval.pt"
os.makedirs(eval_as_eval_dir, exist_ok=True)
torch.save(ckpt, eval_as_eval_path)

## Rough sanity check

In [23]:
# Sanity check: re-run the model on the graphs of each split currently stored
# on the checkpoint and report per-split accuracy.
#
# NOTE: when the splits contain only graphs selected because label == pred,
# this reports 100 % by construction; it verifies the selection and says
# nothing about how well the model generalises.
#
# Forward passes run under torch.no_grad() because this is inference only.
# An empty split is reported explicitly instead of raising ZeroDivisionError.
with torch.no_grad():
    for set_ in ['train', 'val', 'test']:
        split_idx = ckpt['cg'][f'{set_}_idx']
        if len(split_idx) == 0:
            print(f"{set_.capitalize()} accuracy: n/a (empty split)")
            continue

        correct = 0
        for graph_idx in split_idx:
            # unsqueeze(0) adds a leading batch axis of size 1.
            feat = ckpt['cg']['feat'][graph_idx, :].float().unsqueeze(0)
            adj = ckpt['cg']['adj'][graph_idx].float().unsqueeze(0)
            label = ckpt['cg']['label'][graph_idx].long().unsqueeze(0)
            pred = model(feat, adj).argmax(dim=-1)
            if label == pred:
                correct += 1
        print(f"{set_.capitalize()} accuracy: "
            f"{100 * correct/len(split_idx)} %")

Train accuracy: 100.0 %
Val accuracy: 100.0 %
Test accuracy: 100.0 %
