In [6]:
import time
import os
import pandas as pd
import numpy as np
import torch
import gc
import dgl
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv
from torch.autograd import Variable
from dgl.data import DGLDataset
from sklearn.utils import shuffle
from gcn import GCN
from my_dataset import MyDataset

In [11]:
my_batch_size = 30
my_dataset = MyDataset('./test_dataset.csv', my_batch_size)

from dgl.dataloading.pytorch import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler

num_examples = len(my_dataset)
print("dataset length:", num_examples)

test_sampler = SubsetRandomSampler(torch.arange(num_examples))
test_dataloader = GraphDataLoader(my_dataset, sampler=test_sampler, batch_size=my_batch_size, drop_last=False)

dataset length: 4649


In [36]:
modelPath = '../models/gcn1638302501.pkl'
model = torch.load(modelPath)
model

GCN(
  (conv1): GraphConv(in=44, out=80, normalization=both, activation=None)
  (conv2): GraphConv(in=80, out=160, normalization=both, activation=None)
  (conv3): GraphConv(in=160, out=112, normalization=both, activation=None)
  (conv4): GraphConv(in=112, out=160, normalization=both, activation=None)
  (conv5): GraphConv(in=160, out=176, normalization=both, activation=None)
  (conv6): GraphConv(in=176, out=96, normalization=both, activation=None)
  (conv7): GraphConv(in=96, out=144, normalization=both, activation=None)
  (conv8): GraphConv(in=144, out=96, normalization=both, activation=None)
  (conv9): GraphConv(in=96, out=128, normalization=both, activation=None)
  (conv10): GraphConv(in=128, out=96, normalization=both, activation=None)
  (conv11): GraphConv(in=96, out=160, normalization=both, activation=None)
  (dnn1): Linear(in_features=160, out_features=140, bias=True)
  (dnn2): Linear(in_features=140, out_features=1, bias=True)
)

In [37]:
num_correct = 0
num_tests = 0
FP = 0
FN = 0
device = torch.device("cuda:0")
for batched_graph, labels in test_dataloader:
    batched_graph, labels = batched_graph.to(device), labels.to(device)
    pred = model(batched_graph, batched_graph.ndata['h'].float()).squeeze(1).squeeze(1)
    # print(pred, labels)
    for i, p in enumerate(pred.round()):
        if p != labels[i]:
            FP += 1 if p == torch.tensor(1.0) else 0
            FN += 1 if p == torch.tensor(0.0) else 0

    num_correct += (pred.round() == labels).sum().item() # TP+TN
    num_tests += len(labels) # TP+TN+FP+FN
num_correct/num_tests

0.6205635620563562

# Find out the inside of acc

In [40]:
dfTest = pd.read_csv('./test_dataset.csv')
testFiles = dfTest['file_name']
testFiles[-10:-5]

2778    ../negative_graph_save8A_test/6ays_ligand_6.p
2779    ../negative_graph_save8A_test/6r8l_ligand_8.p
2780    ../negative_graph_save8A_test/6ajy_ligand_6.p
2781    ../negative_graph_save8A_test/6i65_ligand_4.p
2782    ../negative_graph_save8A_test/6r1b_ligand_4.p
Name: file_name, dtype: object

In [41]:
testFiles[:5]

0    ../positive_graph_save8A_test/6as8
1    ../positive_graph_save8A_test/6eeb
2    ../positive_graph_save8A_test/6dys
3    ../positive_graph_save8A_test/5zg1
4    ../positive_graph_save8A_test/6aro
Name: file_name, dtype: object

In [50]:
testPreds = []
for i, file in enumerate(testFiles):
    g = dgl.load_graphs(file)[0][0].to(device) # load_graphs returns tuple(graphs, labels)
    pred = model(g, g.ndata['h'].float()).squeeze(1).squeeze(1)
    testPreds.append(pred.round().item())
    print('\r' + str(i), end='')

2787

In [53]:
len(testFiles) == len(testPreds)

True

In [57]:
dfPred = pd.DataFrame(columns=['file_name', 'label'])
dfPred['file_name'] = testFiles
dfPred['label'] = testPreds
dfPred.head()

Unnamed: 0,file_name,label
0,../positive_graph_save8A_test/6as8,1.0
1,../positive_graph_save8A_test/6eeb,0.0
2,../positive_graph_save8A_test/6dys,1.0
3,../positive_graph_save8A_test/5zg1,0.0
4,../positive_graph_save8A_test/6aro,0.0


In [58]:
dfPred.to_csv('./test_preds.csv')