# NCI1

In [1]:
from math import sqrt
import sys

import dgl
import torch

sys.path.append("../")
from gcn import GCNGraph_NCI1
from utils.preprocessing.nci1_preprocessing \
    import nci1_preprocessing

## Data

In [2]:
# Raw NCI1 files, addressed relative to the notebook's working directory
# instead of a user-specific absolute path, so the notebook also runs on
# other checkouts. This uses the same "../../../data/NCI1" root that the
# split index is read from.
# NOTE(review): assumes the notebook is executed from its own directory,
# three levels below the repository root -- confirm.
dataset_dir = "../../../data/NCI1/raw"
dataset = nci1_preprocessing(dataset_dir)

processing


In [3]:
import pickle

# Load the stored split indices (the printed keys are idx_train / idx_val /
# idx_test). NOTE: unpickling can execute arbitrary code, so only point this
# at a trusted, locally generated file.
index_path = "../../../data/NCI1/index.pkl"
with open(index_path, "rb") as index_file:
    index = pickle.load(index_file)
print(index.keys())

dict_keys(['idx_train', 'idx_val', 'idx_test'])


In [4]:
# Materialise each split as a tuple of dataset items, in the order
# train / validation / test, by looking up the stored indices.
train_dataset, val_dataset, test_dataset = (
    tuple(dataset[i] for i in index[split_key])
    for split_key in ('idx_train', 'idx_val', 'idx_test')
)

## Model

In [5]:
# The input width is read from the second dimension of the first graph's
# node feature tensor; the hidden width is fixed at 128.
num_node_feats = dataset.graphs[0].ndata['feat'].size(1)
model = GCNGraph_NCI1(in_feats=num_node_feats, h_feats=128)
print(model)

GCNGraph_NCI1(
  (conv1): GraphConvLayer()
  (conv2): GraphConvLayer()
  (dense1): Linear(in_features=128, out_features=16, bias=True)
  (dense2): Linear(in_features=16, out_features=8, bias=True)
  (dense3): Linear(in_features=8, out_features=2, bias=True)
)


## Load weights

In [6]:
# Deserialize the saved parameters onto the CPU. Without map_location a
# checkpoint written from a GPU would fail to load on a CPU-only machine;
# the model in this notebook is never moved off the CPU, so CPU tensors
# are what load_state_dict needs.
state_dict = torch.load("nci1_weights.pt", map_location="cpu")
# Print every parameter name with its shape as a quick sanity check
# against the model architecture.
for key, val in state_dict.items():
    print(f"{key:<15}: {val.size()}")

conv1.weight   : torch.Size([37, 128])
conv1.bias     : torch.Size([128])
conv2.weight   : torch.Size([128, 128])
conv2.bias     : torch.Size([128])
dense1.weight  : torch.Size([16, 128])
dense1.bias    : torch.Size([16])
dense2.weight  : torch.Size([8, 16])
dense2.bias    : torch.Size([8])
dense3.weight  : torch.Size([2, 8])
dense3.bias    : torch.Size([2])


In [7]:
# Copy the saved parameters into the freshly constructed model.
model.load_state_dict(state_dict)
# Switch the model to evaluation mode. As the last expression in the cell,
# the returned module is what the notebook displays.
model.eval()

GCNGraph_NCI1(
  (conv1): GraphConvLayer()
  (conv2): GraphConvLayer()
  (dense1): Linear(in_features=128, out_features=16, bias=True)
  (dense2): Linear(in_features=16, out_features=8, bias=True)
  (dense3): Linear(in_features=8, out_features=2, bias=True)
)

## Eval

In [10]:
def test(dataset):
    model.eval()

    correct = 0
    for data in dataset:
        graph, label = data
        out = model(
            graph,
            graph.ndata['feat'].float(),
            graph.edata['weight'].float()
        )
        pred = out.argmax(dim=-1)
        correct += int((pred == label.long()).sum())
    return correct / len(dataset)

In [11]:
test_acc = test(test_dataset)
print(f"Test accuracy: {100 * test_acc:.2f} %")

Test accuracy: 68.20 %


## Rough