In [1]:
import gc
import os
import time

import dgl
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data import DGLDataset
from dgl.dataloading.pytorch import GraphDataLoader
from dgl.nn.pytorch import GraphConv
from sklearn.utils import shuffle
from torch.autograd import Variable

from gcn import GCN
from my_dataset import MyDataset

Using backend: pytorch


In [2]:
# Build the test dataset and a batched loader over it.
# Evaluation does not need shuffling, so batches are taken in dataset order;
# this keeps runs deterministic and reproducible.
my_batch_size = 30
my_dataset = MyDataset('./test_dataset.csv', my_batch_size)

num_examples = len(my_dataset)
print("dataset length:", num_examples)

test_dataloader = GraphDataLoader(my_dataset, batch_size=my_batch_size, shuffle=False, drop_last=False)

dataset length: 4850


In [3]:
# Load the trained model checkpoint. torch.load unpickles the whole saved
# object, so the GCN class (imported from gcn above) must be importable.
# Only load trusted files: unpickling can execute arbitrary code.
# NOTE(review): on torch >= 2.6, torch.load defaults to weights_only=True,
# under which a full-module pickle like this one needs weights_only=False.
modelPath = '../models/gcn1639381295GCNWithNewTestDataset.pkl'
model = torch.load(modelPath)
model.eval()  # inference mode: disables dropout / training-time behaviour
model

GCN(
  (gat): GATConv(
    (fc): Linear(in_features=44, out_features=44, bias=False)
    (feat_drop): Dropout(p=0.0, inplace=False)
    (attn_drop): Dropout(p=0.0, inplace=False)
    (leaky_relu): LeakyReLU(negative_slope=0.2)
  )
  (conv1): GraphConv(in=44, out=80, normalization=both, activation=None)
  (conv2): GraphConv(in=80, out=160, normalization=both, activation=None)
  (conv3): GraphConv(in=160, out=112, normalization=both, activation=None)
  (conv4): GraphConv(in=112, out=160, normalization=both, activation=None)
  (conv5): GraphConv(in=160, out=176, normalization=both, activation=None)
  (conv6): GraphConv(in=176, out=96, normalization=both, activation=None)
  (conv7): GraphConv(in=96, out=144, normalization=both, activation=None)
  (conv8): GraphConv(in=144, out=96, normalization=both, activation=None)
  (conv9): GraphConv(in=96, out=128, normalization=both, activation=None)
  (conv10): GraphConv(in=128, out=96, normalization=both, activation=None)
  (conv11): GraphConv(in=9

In [4]:
# Evaluate on the whole test dataloader. Produces the overall accuracy and counts of:
#   FP - rounded prediction is 1 but the label differs
#   FN - rounded prediction is 0 but the label differs
# Assumes 1-D (batch,) labels with values 0/1, and a model output that rounds
# to 0/1 (e.g. a sigmoid probability) -- TODO confirm against gcn.GCN.
num_correct = 0
num_tests = 0
FP = 0
FN = 0
device = torch.device("cuda:0")
model.eval()  # inference mode for any dropout / normalization layers
with torch.no_grad():  # inference only: don't build autograd graphs
    for batched_graph, labels in test_dataloader:
        batched_graph, labels = batched_graph.to(device), labels.to(device)
        pred = model(batched_graph, batched_graph.ndata['h'].float()).squeeze(1).squeeze(1)
        rounded = pred.round()
        wrong = rounded != labels
        # Count errors with tensor masks in one pass, not a per-element Python loop.
        FP += (wrong & (rounded == 1.0)).sum().item()
        FN += (wrong & (rounded == 0.0)).sum().item()
        num_correct += (rounded == labels).sum().item()  # TP + TN
        num_tests += len(labels)                          # TP + TN + FP + FN
num_correct / num_tests if num_tests else float('nan')

0.7521649484536083

# Look inside the accuracy: per-file predictions and misclassified files

In [6]:
# Read the CSV listing the test graph files and peek at a slice of the paths.
dfTest = pd.read_csv('./test_dataset.csv')
testFiles = dfTest.loc[:, 'file_name']
testFiles.iloc[-10:-5]

2936    ../negative_graph_save/4tz2_ligand_4
2937    ../negative_graph_save/2x95_ligand_9
2938    ../negative_graph_save/5vja_ligand_6
2939             ../positive_graph_save/2za0
2940             ../positive_graph_save/2vmf
Name: file_name, dtype: object

In [7]:
testFiles.head(5)  # first five test-file paths

0    ../negative_graph_save/3iof_ligand_1
1    ../negative_graph_save/4bao_ligand_6
2             ../positive_graph_save/6f20
3    ../negative_graph_save/2vot_ligand_3
4    ../negative_graph_save/2brb_ligand_4
Name: file_name, dtype: object

In [8]:
# Predict every test file individually, in CSV order, so each prediction
# can be matched back to its file name. Prints a running index as progress.
testPreds = []
with torch.no_grad():  # inference only: don't build autograd graphs
    for i, graph_path in enumerate(testFiles):
        # load_graphs returns (list_of_graphs, label_dict); take the first graph
        g = dgl.load_graphs(graph_path)[0][0].to(device)
        pred = model(g, g.ndata['h'].float()).squeeze(1).squeeze(1)
        testPreds.append(pred.round().item())
        print('\r' + str(i), end='')

2945

In [9]:
len(testPreds) == len(testFiles)  # sanity check: one prediction per test file

True

In [10]:
# One row per test file: its path and the rounded model prediction.
dfPred = pd.DataFrame({'file_name': testFiles, 'label': testPreds})
dfPred.head()

Unnamed: 0,file_name,label
0,../negative_graph_save/3iof_ligand_1,0.0
1,../negative_graph_save/4bao_ligand_6,0.0
2,../positive_graph_save/6f20,1.0
3,../negative_graph_save/2vot_ligand_3,0.0
4,../negative_graph_save/2brb_ligand_4,0.0


In [12]:
def getNegativeName(file_name):
    """Return the '<id>_ligand_<n>' stem of a negative-sample graph path.

    Examples:
        '../negative_graph_save8A_test/6ny0_ligand_7.p' -> '6ny0_ligand_7'
        '../negative_graph_save/1abc_ligand_12'         -> '1abc_ligand_12'

    Args:
        file_name: path to a graph file.

    Returns:
        The matched stem, or '' if the path does not contain 'negative' or
        has no '<4 alphanumerics>_ligand_<digits>' component.
    """
    import re  # local import keeps this cell self-contained

    if 'negative' not in file_name:
        return ''
    # Use a pattern rather than fixed offsets around 'ligand', so multi-digit
    # ligand indices (ligand_12) stay whole, and a path without 'ligand'
    # yields '' instead of an arbitrary slice.
    match = re.search(r'[A-Za-z0-9]{4}_ligand_\d+', file_name)
    return match.group(0) if match else ''

In [14]:
example_negative_path = '../negative_graph_save8A_test/6ny0_ligand_7.p'
getNegativeName(example_negative_path)

'6ny0_ligand_7'

In [15]:
# Positive graph files appear to be named by a bare 4-character id; the last four characters extract it.
sample_positive_path = './positive_graph_save/6f20'
sample_positive_path[-4:]

'6f20'

In [11]:
dfPred.to_csv(path_or_buf='./test_preds.csv')  # persist per-file predictions (index included)

In [22]:
# Collect every misclassified test file under a short identifier:
#   negative sample predicted 1.0 -> false positive (FP)
#   positive sample predicted 0.0 -> false negative (FN)
# Rows are gathered as records and turned into a DataFrame once:
# DataFrame.append was removed in pandas 2.0 and is quadratic in a loop.
false_records = []
for fname, pred in zip(dfPred['file_name'], dfPred['label']):
    is_false_positive = 'negative' in fname and pred == 1.0
    is_false_negative = 'positive' in fname and pred == 0.0
    if is_false_positive or is_false_negative:
        if 'positive' in fname:
            # Strip directory and any extension rather than assuming the id is the last 4 chars.
            short_name = os.path.splitext(os.path.basename(fname))[0]
        else:
            short_name = getNegativeName(fname)
        false_records.append({'file_name': short_name, 'label': pred})
dfFalse = pd.DataFrame(false_records, columns=['file_name', 'label'])

len(dfFalse)

633

In [23]:
dfFalse.to_csv(path_or_buf='pred_false.csv')  # persist the misclassified files (index included)