In [None]:
import numpy as np
import pandas as pd
import sys, os
from random import shuffle
from tqdm import tqdm
import torch
import torch.nn as nn
from models.gat import GATNet
from models.gat_gcn import GAT_GCN
from models.gcn import GCNNet
from models.ginconv import GINConvNet
from lifelines.utils import concordance_index
from utils import *

In [2]:
def predicting(model, device, loader):
    """Run ``model`` over every batch of ``loader`` and collect labels and predictions.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(data)`` on each batch; it is put into eval mode first.
    device : torch.device
        Device each batch is moved to before the forward pass.
    loader : iterable of batches
        Must expose ``loader.dataset`` (only used for the progress message).
        Each batch must support ``.to(device)`` and carry its targets in ``.y``.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Flattened ground-truth labels and flattened predictions, in loader order.
        Both are empty when the loader yields no batches.
    """
    model.eval()
    print('Make prediction for {} samples...'.format(len(loader.dataset)))
    # Collect per-batch tensors and concatenate once at the end: growing a tensor
    # with torch.cat inside the loop re-copies everything seen so far on every
    # batch, which is quadratic in the number of batches.
    pred_chunks = []
    label_chunks = []
    with torch.no_grad():
        for data in tqdm(loader):
            data = data.to(device)
            output = model(data)
            pred_chunks.append(output)
            label_chunks.append(data.y.view(-1, 1).to(device))
    if not pred_chunks:
        # No batches at all: fall back to empty float tensors so the caller
        # still receives two (empty) arrays instead of a torch.cat error.
        pred_chunks, label_chunks = [torch.Tensor()], [torch.Tensor()]
    total_preds = torch.cat(pred_chunks, 0)
    total_labels = torch.cat(label_chunks, 0)
    return total_labels.cpu().numpy().flatten(), total_preds.cpu().numpy().flatten()

In [12]:
# Choose the test set and the trained checkpoint to evaluate.
# This GCNNet checkpoint turned out to work well on bindingdb_ic50.
AVAILABLE_DATASETS = ['davis', 'kiba', 'Ki', 'Kd', 'IC50', 'davis2', 'bdtdc_ic50', 'bdtdc_kd',
                      'bdtdc_ki', 'bindingdb_ic50', 'bindingdb_ki', 'bindingdb_kd']
DATASET = 'bindingdb_ic50'  # selected by name instead of a positional index into the list above
assert DATASET in AVAILABLE_DATASETS, 'unknown dataset: ' + DATASET
datasets = [DATASET]

modeling = GCNNet  # one of GINConvNet, GATNet, GAT_GCN, GCNNet; must match the checkpoint
model_st = modeling.__name__
# Checkpoint folder for BindingDB-trained models; change it when evaluating other datasets.
RESULTS_DIR = 'GraphDTA_Results/BindingDB'
# Derive the file name from architecture and dataset so they cannot drift apart.
model_file_name = '{0}/{1}/model_{1}_{2}.model'.format(RESULTS_DIR, model_st, DATASET)

In [13]:
TEST_BATCH_SIZE = 512
# Raw BindingDB affinity splits are stored under data/bindingdb with a 'bindingDB_' prefix;
# every other dataset name is looked up under data/.
BINDINGDB_SPLITS = {'Ki', 'Kd', 'IC50'}

# Evaluation device: first GPU when available, otherwise CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

# The checkpoint does not depend on the dataset, so build and load the model once.
# map_location=device (instead of a hard-coded 'cuda:0') lets this run on CPU-only machines.
model = modeling()
model.to(device)
model.load_state_dict(torch.load(model_file_name, map_location=device))

# Per-dataset metrics: [rmse, mse, pearson, spearman, concordance index].
# `ret` additionally keeps the metrics of the last evaluated dataset.
results = {}
for dataset in datasets:
    print('Testing ' + dataset)
    if dataset in BINDINGDB_SPLITS:
        data_root = 'data/bindingdb'
        split_name = 'bindingDB_' + dataset + '_test'
    else:
        data_root = 'data'
        split_name = dataset + '_test'
    processed_data_file_test = data_root + '/processed/' + split_name + '.pt'
    if not os.path.isfile(processed_data_file_test):
        print('please run create_data.py to prepare data in pytorch format!')
        continue

    test_data = TestbedDataset(root=data_root, dataset=split_name)
    # make data PyTorch mini-batch processing ready
    test_loader = DataLoader(test_data, batch_size=TEST_BATCH_SIZE, shuffle=False)

    # evaluate the trained model on the test split
    G, P = predicting(model, device, test_loader)
    ret = [rmse(G, P), mse(G, P), pearson(G, P), spearman(G, P), concordance_index(G, P)]
    results[dataset] = ret

        

Testing bindingdb_ic50
Pre-processed data found: data/processed/bindingdb_ic50_test.pt, loading ...
Make prediction for 219906 samples...


100%|██████████| 430/430 [00:57<00:00,  7.41it/s]


In [14]:
# Metrics of the last evaluated dataset, labeled in the order the evaluation cell computes them.
METRIC_NAMES = ['rmse', 'mse', 'pearson', 'spearman', 'ci']
for metric_name, metric_value in zip(METRIC_NAMES, ret):
    print('{:<9} {:.6f}'.format(metric_name, float(metric_value)))

[0.7003693730054906, 0.49051726, 0.8867872533814504, 0.8846194353694959, 0.858644259970728]
