# Create Checkpoint

In [None]:
import os
import pickle
import sys

import torch

sys.path.append("../gnnexp")
from models import GNN_Custom_Graph

In [None]:
# Dataset to build checkpoints for.
# OPTIONS: Mutagenicity, NCI1, IsCyclic
# TODO: the MUTAG dataset is different from the other baselines.
DATASET = "Mutagenicity"

## Data

In [None]:
# Load the baseline checkpoint and pull out its graph data dictionary ("cg").
# torch.load performs no autograd operations, so no torch.no_grad() guard is needed here.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True, which
# may reject checkpoints that contain non-tensor objects; pass weights_only=False only
# if loading fails and the file is trusted.
ckpt = torch.load(f"../ckpt/{DATASET}_base_h20_o20.pth.tar")
cg_dict = ckpt["cg"]  # graph data: "feat", "adj", "label" and the split index lists
# Node-feature width taken from the last axis of "feat";
# assumes feat is (graphs, nodes, features) — TODO confirm.
input_dim = cg_dict["feat"].shape[2]
num_classes = 2

In [None]:
# Largest graph index referenced by any of the checkpoint's own splits.
# (Despite the name, this is an index bound, not a class label.)
max_label = max(
    max(cg_dict[split_key]) for split_key in ("test_idx", "val_idx", "train_idx")
)
print(max_label)

In [None]:
# Load the baseline train/val/test split indices.
# NOTE(review): pickle.load can execute arbitrary code from a crafted file — only
# use this with a locally generated, trusted index file.
with open(f"../data/{DATASET}/index_{DATASET}.pkl", "rb") as file:
    indices = pickle.load(file)


def _clean_split(raw_split, max_index):
    """Return the unique indices of one split that are <= ``max_index``, sorted ascending.

    Indices beyond ``max_index`` (the largest graph index present in the
    checkpoint) are dropped. Sorting gives a deterministic order, unlike iterating
    a ``set``, so the index lists derived from these splits and saved to the
    checkpoints are reproducible.
    """
    return sorted({int(i) for i in raw_split if i <= max_index})


train_set_indices = _clean_split(indices["idx_train"], max_label)
val_set_indices = _clean_split(indices["idx_val"], max_label)
test_set_indices = _clean_split(indices["idx_test"], max_label)

In [None]:
# Number of baseline test graphs kept after filtering to indices present in the checkpoint.
len(test_set_indices)

## Model

In [None]:
# Graph-classification GNN whose input width matches the checkpoint's node features.
model = GNN_Custom_Graph(in_features=input_dim, h_features=128)
print(model)

## CFGNN model weights

In [None]:
# Pretrained weights for this dataset (loaded onto the model in the next cell
# and stored in the saved checkpoints as "model_state").
state_dict_cfgnn = torch.load(f"../graph_classification_model_weights/{DATASET}_weights.pt")

## Predictions on all graphs

In [None]:
# Load the pretrained weights into the freshly built model (PyTorch matches keys
# strictly by default) and switch the model to evaluation mode for inference.
model.load_state_dict(state_dict_cfgnn)
model.eval()

In [None]:
# Run the model on every graph in the checkpoint, one graph at a time, and
# collect its outputs alongside the ground-truth labels.
preds = []
labels = []

# Inference only: disabling autograd avoids building (and keeping alive) a
# computation graph for every forward pass.
with torch.no_grad():
    for graph_id in range(cg_dict["adj"].size(0)):
        # Add a leading batch dimension of 1 for the single graph.
        feat = cg_dict["feat"][graph_id, :].float().unsqueeze(0)
        adj = cg_dict["adj"][graph_id].float().unsqueeze(0)
        label = cg_dict["label"][graph_id]
        pred = model(feat, adj)
        preds.append(pred)
        labels.append(label)

# Presumably each model output holds a single value, giving 1-D tensors of
# length num_graphs — TODO confirm against GNN_Custom_Graph.
preds = torch.Tensor(preds)
labels = torch.Tensor(labels)

In [None]:
# Accuracy over every graph in the checkpoint. Predictions are rounded to the
# nearest integer before comparison — presumably the model outputs a class-1
# probability; TODO confirm.
(preds.round() == labels).sum()/len(labels)

In [None]:
# Cache the predictions in the checkpoint with a leading singleton axis; they are
# read back later as ckpt["cg"]["pred"][0][idx].
ckpt["cg"]["pred"] = preds[None].numpy()

## Our eval set as part of the training set

In [None]:
# For each split, keep only the graphs whose label is 1 AND whose rounded model
# prediction is also 1 (correctly classified positives).
train_set_1 = []
val_set_1 = []
test_set_1 = []

# Pair each source split with its destination list explicitly instead of
# building variable names with eval().
_split_pairs = (
    (train_set_indices, train_set_1),
    (val_set_indices, val_set_1),
    (test_set_indices, test_set_1),
)
_all_labels = ckpt["cg"]["label"]
_all_preds = ckpt["cg"]["pred"][0]  # strip the leading singleton axis

for _source_indices, _positives in _split_pairs:
    for idx in _source_indices:
        label = _all_labels[idx]
        pred = _all_preds[idx].round()
        if label == pred == 1:
            _positives.append(idx)

In [None]:
# Variant 1: every label-1 graph (train, val and test) goes into the training
# split, and the label-1 test graphs are also used as the test split.
ckpt["cg"]["test_idx"] = test_set_1
ckpt["cg"]["train_idx"] = [*train_set_1, *val_set_1, *test_set_1]
ckpt["model_state"] = state_dict_cfgnn

In [None]:
# Size of the variant-1 training split (label-1 graphs from all three splits).
len(ckpt['cg']['train_idx'])

In [None]:
# Write the variant-1 checkpoint, creating the dataset directory if needed.
save_dir = f"../data/{DATASET}"
os.makedirs(save_dir, exist_ok=True)
torch.save(ckpt, f"{save_dir}/eval_as_train.pt")

## Our eval set as the validation set

In [None]:
# Variant 2: label-1 train and val graphs form the training split; the label-1
# test graphs are held out as the test split.
ckpt["cg"]["test_idx"] = test_set_1
ckpt["cg"]["train_idx"] = [*train_set_1, *val_set_1]
ckpt["model_state"] = state_dict_cfgnn

In [None]:
# Size of the variant-2 training split (label-1 train + val graphs only).
len(ckpt['cg']['train_idx'])

In [None]:
# Write the variant-2 checkpoint, creating the dataset directory if needed.
save_dir = f"../data/{DATASET}"
os.makedirs(save_dir, exist_ok=True)
torch.save(ckpt, f"{save_dir}/eval_as_eval.pt")

## Rough check: model accuracy on the label-1 test subset

In [None]:
# Sanity check: re-run the model on the label-1 test graphs (test_set_1) and
# collect predictions and labels for the accuracy printed in the next cell.
test_preds = []
test_labels = []

# Iterate test_set_1 directly rather than rebinding the module-level `indices`,
# which holds the split dictionary loaded from the index pickle.
# Inference only: disable autograd so no computation graph is built.
with torch.no_grad():
    for graph_id in test_set_1:
        feat = cg_dict["feat"][graph_id, :].float().unsqueeze(0)
        adj = cg_dict["adj"][graph_id].float().unsqueeze(0)
        label = cg_dict["label"][graph_id]
        pred = model(feat, adj)
        test_preds.append(pred)
        test_labels.append(label)

test_preds = torch.Tensor(test_preds)
test_labels = torch.Tensor(test_labels)

In [None]:
# Fraction of label-1 test graphs whose rounded prediction matches the label.
# Every graph in test_set_1 was selected as correctly predicted, so this should print 100%.
num_correct = (test_preds.round() == test_labels).sum()
acc = num_correct / test_labels.size(0)
print(f"Test accuracy (label-1): {100 * acc:.2f} %")