In [1]:
import dgl
import torch
import numpy as np
import pandas as pd
import networkx as nx
from tqdm.notebook import tqdm

Using backend: pytorch


### Making the drug-combination files

In [None]:
# Eval combinations
# Load the TWOSIDE evaluation table; the next cell reads its 'Drug1' and 'Drug2' columns.
eval_df = pd.read_csv('../data/TWOSIDE-evaluation-PSE-964.csv', sep=',')
eval_df

In [None]:
# Every drug name exactly once: names from 'Drug1' first, then the names that
# only occur in 'Drug2', each in order of first appearance. drop_duplicates keeps
# the first occurrence, so the order matches an "append if unseen" loop without
# the quadratic list-membership test.
drugs_list = (
    pd.concat([eval_df['Drug1'], eval_df['Drug2']], ignore_index=True)
    .drop_duplicates()
    .tolist()
)

len(drugs_list)

In [None]:
%%time
from urllib.request import urlopen

f = open('../data/Eval_drugs_964.tsv', 'a')
f.write('Drug_name\tPubChemID\tSMILES\n')

for drug in tqdm(drugs_list):
    drug_name = drug
    
    if ' ' in drug:
        drug = drug.replace(' ', '%20')
        
    try:
        url1 = 'https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/'+drug+'/property/CanonicalSMILES/TXT'
        url2 = 'https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/'+drug+'/cids/TXT'
        res1 = urlopen(url1)
        smiles = str(res1.read())[2:-3]
        res2 = urlopen(url2)
        drugid = str(res2.read())[2:-3].split('\\n')[0]
        
        row = drug_name + '\t' + drugid +'\t'+ smiles + '\n'
        f.write(row)
    except:
        #row = drug_name + '\t' + '-\n'
        #f.write(row)
        pass

f.close()

In [None]:
# Read back the tab-separated drug table written by the previous cell to inspect it.
df = pd.read_csv('../data/Eval_drugs_964.tsv', sep='\t')
df

### Making the DTI and DrugID files

#### DrugID

In [None]:
#DrugID
# Load the drug table and keep only the name and PubChem ID columns.
eval_drugs = pd.read_csv('../data/Eval_drugs_964.tsv', sep='\t')
eval_drugs = eval_drugs[['Drug_name','PubChemID']]
eval_drugs

In [None]:
# Display the current drug table (same frame as shown by the previous cell).
eval_drugs

In [None]:
# Distinct PubChem IDs, taken from the table sorted by 'PubChemID'.
drugs = eval_drugs.sort_values(by='PubChemID')['PubChemID'].unique().tolist()
drugs

In [None]:
# Give every distinct PubChem ID a GraphID, numbered 1..N in ascending PubChem ID order.
drugs = eval_drugs.sort_values(by='PubChemID')['PubChemID'].unique().tolist()
# PubChemID -> GraphID, starting at 1. enumerate avoids an O(n) list.index call per drug.
dic = {drug: graph_id for graph_id, drug in enumerate(drugs, start=1)}
eval_drugs['GraphID'] = eval_drugs['PubChemID'].map(dic)  # GraphIDs
eval_drugs

In [None]:
# Order the rows by their GraphID.
eval_drugs = eval_drugs.sort_values(by='GraphID')
eval_drugs

In [None]:
# Rename to the column names used in the saved file ('Name', 'DrugID') and fix the column order.
eval_drugs = eval_drugs.rename({'Drug_name':'Name','PubChemID':'DrugID'}, axis=1)
eval_drugs = eval_drugs[['GraphID','DrugID','Name']]
eval_drugs

In [None]:
# Save the GraphID / DrugID / Name table; the evaluation dataset below reads this file.
eval_drugs.to_csv('../data/Eval_DrugID.csv', index=False, sep = ',')

#### DTI

In [None]:
# Load the drug-target table; later cells use its 'GeneID' and 'DrugID' columns.
eval_dti = pd.read_csv('../data/Eval_affinity_cut_83.37.csv', sep=',')
eval_dti

In [None]:
# Load the gene table; the next cell uses the row order of its 'Name' column.
gene_id = pd.read_csv('../data/GeneID.csv', sep=',')
gene_id

In [None]:
# Map each gene name to its 1-based row position in GeneID.csv.
genes = gene_id['Name'].tolist()
gene_dic = {}
for position, gene in enumerate(genes, start=1):
    # setdefault keeps the first position of a repeated name, exactly what
    # list.index returned, without rescanning the list for every gene.
    gene_dic.setdefault(gene, position)
# Genes missing from GeneID.csv become NaN in 'ProteinID'.
eval_dti['ProteinID'] = eval_dti['GeneID'].map(gene_dic)  # ProteinIDs
eval_dti

In [None]:
# Save the drug -> protein pairs; the evaluation dataset below reads this file.
eval_dti[['DrugID','ProteinID']].to_csv('../data/Eval_DTI_full.csv', index=False, sep = ',')

## Evaluation 

In [1]:
import os
import dgl
import time
import torch
import pandas as pd
import torch.nn as nn
from tqdm.notebook import tqdm
from sklearn import metrics
import torch.nn.functional as F
from dgl.data import DGLDataset
# Select GPU 0 only when CUDA is usable; on a CPU-only machine the unconditional
# call raises, although nothing in this notebook moves the model or graphs to the GPU.
if torch.cuda.is_available():
    torch.cuda.set_device(0)

Using backend: pytorch


In [2]:
class PSE_eval(DGLDataset):
    """Evaluation dataset of drug pairs, each drug represented by a PPI subgraph.

    After construction:
      * ``graphs`` / ``labels`` hold one DGL graph and one drug name per drug;
        every node carries its protein's feature row in ``ndata['PSE']``.
      * ``comb_graphs`` / ``comb_labels`` hold, for every row of the combination
        table whose two drugs both have a graph, the pair ``[g1, g2]`` and the
        tensor of the remaining columns of that row.

    Args:
        max_drugs: number of leading rows of Eval_DrugID.csv to build graphs for
            (default 100, the limit this notebook was run with). ``None`` uses
            every drug.
    """

    def __init__(self, max_drugs=100):
        # Must be set before super().__init__(): DGLDataset's constructor is what
        # invokes process().
        self.max_drugs = max_drugs
        super().__init__(name='PSE_eval')

    def process(self):
        """Read the CSV inputs and build the per-drug graphs and the drug-pair samples."""
        features = pd.read_csv('../data/GNN-GSE_full_pkd_norm.csv', index_col='ProteinID', sep=',')
        drug_comb = pd.read_csv('../data/Eval_TWOSIDE-evaluation-PSE-964.csv', sep=',')
        edges = pd.read_csv('../data/GNN-PPI-net.csv', sep=',')
        dti = pd.read_csv('../data/Eval_DTI_full.csv', sep=',')
        DrugID = pd.read_csv('../data/Eval_DrugID.csv', sep=',')
        print('data loaded!')

        def drug2ppi(drug):
            """Return (edge frame, protein ids) of the PPI subgraph among the drug's targets.

            The edge frame has columns graph_id, src, dst (local 0-based node ids)
            and src_prot, dst_prot (actual protein ids). ``prot_ids[k]`` is the
            protein behind local node ``k``.
            """
            genes = dti.loc[dti['DrugID'] == drug, 'ProteinID'].tolist()
            both_targets = edges['protein1'].isin(genes) & edges['protein2'].isin(genes)
            # .copy() so the column assignments below act on an independent frame.
            df = (
                edges.loc[both_targets, ['protein1', 'protein2']]
                .rename(columns={'protein1': 'src_prot', 'protein2': 'dst_prot'})
                .copy()
            )
            # Number nodes over BOTH endpoints (sources first, then destinations
            # not seen as a source). Numbering from sources alone leaves a
            # destination-only protein without a local id and undercounts the nodes.
            prot_ids = pd.unique(
                pd.concat([df['src_prot'], df['dst_prot']], ignore_index=True)
            ).tolist()
            local_id = {prot: idx for idx, prot in enumerate(prot_ids)}
            df['graph_id'] = DrugID.loc[DrugID['DrugID'] == drug, 'GraphID'].tolist()[0]
            df['src'] = df['src_prot'].map(local_id).astype('int64')  # local ids
            df['dst'] = df['dst_prot'].map(local_id).astype('int64')  # local ids
            return df[['graph_id', 'src', 'dst', 'src_prot', 'dst_prot']], prot_ids

        self.graphs = []
        self.labels = []
        self.comb_graphs = []
        self.comb_labels = []

        # Node features: one float32 row per protein, looked up through its ProteinID.
        feature_matrix = torch.tensor(features.to_numpy(), dtype=torch.float32)
        feature_row = {prot: row for row, prot in enumerate(features.index)}

        drug_ids = DrugID['DrugID'].tolist()
        if self.max_drugs is not None:
            drug_ids = drug_ids[:self.max_drugs]

        # One graph per drug.
        for drug in tqdm(drug_ids):
            edges_of_id, prot_ids = drug2ppi(drug)
            src = edges_of_id['src'].to_numpy()
            dst = edges_of_id['dst'].to_numpy()
            label = DrugID.loc[DrugID['DrugID'] == drug, 'Name'].tolist()[0]

            g = dgl.graph((src, dst), num_nodes=len(prot_ids))

            # Attach the feature row of every node's protein in a single assignment.
            rows = torch.tensor([feature_row[prot] for prot in prot_ids], dtype=torch.long)
            g.ndata['PSE'] = feature_matrix[rows]

            self.graphs.append(g)
            self.labels.append(label)

        # Drug name -> graph; the first graph wins if a name repeats.
        graph_of = {}
        for name, graph in zip(self.labels, self.graphs):
            graph_of.setdefault(name, graph)

        # Columns 0 and 1 of the combination table are the two drug names, the
        # remaining columns are the values to predict.
        pse_values = drug_comb.iloc[:, 2:].to_numpy()
        for name1, name2, pse in zip(drug_comb.iloc[:, 0], drug_comb.iloc[:, 1], pse_values):
            g1 = graph_of.get(name1)  # Drug1 graph
            g2 = graph_of.get(name2)  # Drug2 graph
            if g1 is None or g2 is None:
                # At least one drug has no graph: skip the pair.
                continue
            # Build the label first so the two lists can never get out of step.
            pair_label = torch.tensor(pse)  # PSE values
            self.comb_graphs.append([g1, g2])
            self.comb_labels.append(pair_label)

    def __getitem__(self, i):
        """Return ([graph1, graph2], label tensor) for the i-th drug pair."""
        return self.comb_graphs[i], self.comb_labels[i]

    def __len__(self):
        """Number of drug pairs in the dataset."""
        return len(self.comb_graphs)
    
print('\nCreating the Evaluation Dataset ...\n')
# Instantiate the evaluation dataset defined above.
dataset = PSE_eval()

print('\nEvaluation Dataset created!\n')

print('\ndataset is compiled! \n')


Creating the Evaluation Dataset ...

data loaded!


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))



Evaluation Dataset created!


dataset is compiled! 



In [3]:
print('\nCreating eval batches ...\n')
# Making the batches
from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler

num_examples = len(dataset)
# The sampler is given every index of the dataset, so the whole dataset is evaluated.
eval_sampler = SubsetRandomSampler(torch.arange(num_examples))
# batch_size=1: one drug pair per batch; drop_last=False keeps every pair.
eval_dataloader = GraphDataLoader(dataset, sampler=eval_sampler, batch_size=1, drop_last=False)

print('\nEval batches are created!\n')


Creating eval batches ...


Eval batches are created!



In [4]:
# GNN Model: Siamese GCN 
from dgl.nn import GraphConv

class GCN(nn.Module):
    """Two-layer graph convolutional network with a mean-over-nodes readout.

    Args:
        in_feats: size of each input node feature vector.
        h_feats: size of the hidden node representation.
        num_classes: size of the per-graph output vector.
    """

    def __init__(self, in_feats, h_feats, num_classes):
        super().__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        """Return the ReLU of the node-averaged second-layer output of graph `g`.

        Side effect: the second-layer node outputs are stored in ``g.ndata['h']``.
        """
        hidden = F.relu(self.conv1(g, in_feat))
        g.ndata['h'] = self.conv2(g, hidden)
        pooled = dgl.mean_nodes(g, 'h')
        return F.relu(pooled)

In [35]:
# Evaluation

# Specify a path
PATH = "../results/entire_model_V2.pt"
conf = '../results/state_dict_model_V2.pt'

# Load
model = GCN(964,200,964)
# map_location='cpu': the model stays on the CPU in this notebook, so the weights
# must load even on a machine without the GPU they may have been saved from.
model.load_state_dict(torch.load(conf, map_location='cpu'))
model.eval()
print('\nEvaluating \n')

#f = open("Eval.txt", "a")

def predict(g1, g2):  # graph1, graph2
    """Return the pair prediction: the normalized sum of both graph outputs, halved."""
    pred1 = model(g1, g1.ndata['PSE'].float())
    pred2 = model(g2, g2.ndata['PSE'].float())
    pred = F.normalize(pred1+pred2)/2
    return(pred)


def _ratio(numerator, denominator):
    """Return numerator/denominator, or 0.0 when the denominator is zero.

    A pair with no predicted positives (or no actual positives) would otherwise
    abort the whole evaluation with ZeroDivisionError.
    """
    return numerator / denominator if denominator else 0.0


all_acc = []
all_prec = []
all_spec = []
all_mcc = []
# Inference only: no_grad avoids building an autograd graph for every pair.
with torch.no_grad():
    for batched_graph, labels in eval_dataloader:
        g1 = batched_graph[0]
        g2 = batched_graph[1]
        pred = predict(g1, g2)

        # Threshold: half of the mean predicted value for this pair.
        tr = 0.5*pred.mean().item()

        # Confusion counts over every (pair, side effect) entry, computed on
        # whole tensors. Entries whose label is neither 0 nor 1 are not counted.
        is_pos = labels == 1
        is_neg = labels == 0
        above = pred >= tr
        below = pred < tr
        TP = int((is_pos & above).sum())
        FN = int((is_pos & below).sum())
        FP = int((is_neg & above).sum())
        TN = int((is_neg & below).sum())

        # Validation metrics (percentages, except MCC)
        acc = _ratio((TP+TN)*100, TP+FP+FN+TN)
        prec = _ratio(TP*100, TP+FP)
        recall = _ratio(TP*100, TP+FN)
        F1 = _ratio(2*(recall*prec), recall+prec)
        spec = _ratio(TN*100, FP+TN)
        mcc = _ratio(TP*TN-FP*FN, ((TP+FN)*(TP+FP)*(TN+FP)*(TN+FN))**0.5)
        # Rescaled from [-1, 1] to [0, 1]; this rescaled value is what gets
        # printed under "MCC" and stored in all_mcc.
        mcc2 = (mcc+1)/2
        #sim = ((F.cosine_similarity(pred.float(),labels.float())).mean().tolist())*100

        x = [TP,TN,FP,FN]
        print('TP:%s,TN:%s,FP:%s,FN:%s' % (x[0],x[1],x[2],x[3]))
        msg2 = 'Accuracy: %s | Precision: %s | Recall: %s | F1: %s | Specificity: %s | MCC: %s\n' %(
                round(acc,4),round(prec,4),round(recall,4),round(F1,4),round(spec,4),round(mcc2,4))
        all_acc.append(acc)
        all_prec.append(prec)
        all_spec.append(spec)
        all_mcc.append(mcc2)
        print(msg2)



Evaluating 

TP:27,TN:853,FP:7,FN:77
Accuracy: 91.2863 | Precision: 79.4118 | Recall: 25.9615 | F1: 39.1304 | Specificity: 99.186 | MCC: 0.7115

TP:2,TN:931,FP:28,FN:3
Accuracy: 96.7842 | Precision: 6.6667 | Recall: 40.0 | F1: 11.4286 | Specificity: 97.0803 | MCC: 0.5767

TP:10,TN:912,FP:24,FN:18
Accuracy: 95.6432 | Precision: 29.4118 | Recall: 35.7143 | F1: 32.2581 | Specificity: 97.4359 | MCC: 0.6509

TP:13,TN:904,FP:17,FN:30
Accuracy: 95.1245 | Precision: 43.3333 | Recall: 30.2326 | F1: 35.6164 | Specificity: 98.1542 | MCC: 0.6687

TP:21,TN:882,FP:13,FN:48
Accuracy: 93.6722 | Precision: 61.7647 | Recall: 30.4348 | F1: 40.7767 | Specificity: 98.5475 | MCC: 0.7025

TP:24,TN:859,FP:6,FN:75
Accuracy: 91.5975 | Precision: 80.0 | Recall: 24.2424 | F1: 37.2093 | Specificity: 99.3064 | MCC: 0.7058

TP:19,TN:875,FP:15,FN:55
Accuracy: 92.7386 | Precision: 55.8824 | Recall: 25.6757 | F1: 35.1852 | Specificity: 98.3146 | MCC: 0.6731

TP:11,TN:919,FP:23,FN:11
Accuracy: 96.473 | Precision: 32.35

TP:4,TN:919,FP:30,FN:11
Accuracy: 95.7469 | Precision: 11.7647 | Recall: 26.6667 | F1: 16.3265 | Specificity: 96.8388 | MCC: 0.5789

TP:22,TN:818,FP:12,FN:112
Accuracy: 87.1369 | Precision: 64.7059 | Recall: 16.4179 | F1: 26.1905 | Specificity: 98.5542 | MCC: 0.6404

TP:8,TN:908,FP:26,FN:22
Accuracy: 95.0207 | Precision: 23.5294 | Recall: 26.6667 | F1: 25.0 | Specificity: 97.2163 | MCC: 0.6124

TP:4,TN:921,FP:30,FN:9
Accuracy: 95.9544 | Precision: 11.7647 | Recall: 30.7692 | F1: 17.0213 | Specificity: 96.8454 | MCC: 0.5863

TP:10,TN:894,FP:24,FN:36
Accuracy: 93.7759 | Precision: 29.4118 | Recall: 21.7391 | F1: 25.0 | Specificity: 97.3856 | MCC: 0.6105

TP:24,TN:865,FP:10,FN:65
Accuracy: 92.2199 | Precision: 70.5882 | Recall: 26.9663 | F1: 39.0244 | Specificity: 98.8571 | MCC: 0.7026

TP:8,TN:910,FP:26,FN:20
Accuracy: 95.2282 | Precision: 23.5294 | Recall: 28.5714 | F1: 25.8065 | Specificity: 97.2222 | MCC: 0.6174

TP:8,TN:911,FP:26,FN:19
Accuracy: 95.332 | Precision: 23.5294 | Recall: 

TP:5,TN:927,FP:29,FN:3
Accuracy: 96.6805 | Precision: 14.7059 | Recall: 62.5 | F1: 23.8095 | Specificity: 96.9665 | MCC: 0.6462

TP:14,TN:905,FP:20,FN:25
Accuracy: 95.332 | Precision: 41.1765 | Recall: 35.8974 | F1: 38.3562 | Specificity: 97.8378 | MCC: 0.6802

TP:17,TN:903,FP:17,FN:27
Accuracy: 95.4357 | Precision: 50.0 | Recall: 38.6364 | F1: 43.5897 | Specificity: 98.1522 | MCC: 0.7081

TP:6,TN:926,FP:28,FN:4
Accuracy: 96.6805 | Precision: 17.6471 | Recall: 60.0 | F1: 27.2727 | Specificity: 97.065 | MCC: 0.6567

TP:19,TN:870,FP:15,FN:60
Accuracy: 92.2199 | Precision: 55.8824 | Recall: 24.0506 | F1: 33.6283 | Specificity: 98.3051 | MCC: 0.6662

TP:4,TN:922,FP:30,FN:8
Accuracy: 96.0581 | Precision: 11.7647 | Recall: 33.3333 | F1: 17.3913 | Specificity: 96.8487 | MCC: 0.5907

TP:8,TN:914,FP:26,FN:16
Accuracy: 95.6432 | Precision: 23.5294 | Recall: 33.3333 | F1: 27.5862 | Specificity: 97.234 | MCC: 0.6291

TP:1,TN:923,FP:29,FN:11
Accuracy: 95.8506 | Precision: 3.3333 | Recall: 8.3333 | 

TP:8,TN:899,FP:26,FN:31
Accuracy: 94.0871 | Precision: 23.5294 | Recall: 20.5128 | F1: 21.9178 | Specificity: 97.1892 | MCC: 0.5945

TP:5,TN:926,FP:29,FN:4
Accuracy: 96.5768 | Precision: 14.7059 | Recall: 55.5556 | F1: 23.2558 | Specificity: 96.9634 | MCC: 0.6369

TP:7,TN:911,FP:27,FN:19
Accuracy: 95.2282 | Precision: 20.5882 | Recall: 26.9231 | F1: 23.3333 | Specificity: 97.1215 | MCC: 0.6056

TP:27,TN:823,FP:7,FN:107
Accuracy: 88.1743 | Precision: 79.4118 | Recall: 20.1493 | F1: 32.1429 | Specificity: 99.1566 | MCC: 0.681

TP:3,TN:924,FP:31,FN:6
Accuracy: 96.1618 | Precision: 8.8235 | Recall: 33.3333 | F1: 13.9535 | Specificity: 96.7539 | MCC: 0.5784

TP:5,TN:910,FP:29,FN:20
Accuracy: 94.917 | Precision: 14.7059 | Recall: 20.0 | F1: 16.9492 | Specificity: 96.9116 | MCC: 0.5729

TP:16,TN:893,FP:14,FN:41
Accuracy: 94.2946 | Precision: 53.3333 | Recall: 28.0702 | F1: 36.7816 | Specificity: 98.4564 | MCC: 0.6802

TP:24,TN:899,FP:6,FN:35
Accuracy: 95.7469 | Precision: 80.0 | Recall: 40.67

In [33]:
# Mean accuracy (percent) over all evaluated drug pairs.
sum(all_acc)/len(all_acc)

91.48335036176738

In [34]:
# Mean precision (percent) over all evaluated drug pairs.
sum(all_prec)/len(all_prec)

39.9514333209374

In [31]:
# Mean specificity (percent) over all evaluated drug pairs.
sum(all_spec)/len(all_spec)

99.37736665920714

In [19]:
# Mean of the rescaled MCC values, (mcc+1)/2, stored in all_mcc by the evaluation loop.
sum(all_mcc)/len(all_mcc)

0.6244535919931011

### Visualization

In [None]:
import numpy as np
from scipy import interp
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve, auc

In [None]:
# y_test / y_score / n_classes are not defined anywhere else in the notebook, so
# build them here from the evaluation data: one row per drug pair, one column
# per side effect (true labels and predicted scores respectively).
truth_rows = []
score_rows = []
with torch.no_grad():
    for batched_graph, labels in eval_dataloader:
        scores = predict(batched_graph[0], batched_graph[1])
        truth_rows.append(labels.reshape(len(labels), -1).cpu().numpy())
        score_rows.append(scores.reshape(len(scores), -1).cpu().numpy())
y_test = np.vstack(truth_rows)
y_score = np.vstack(score_rows)
n_classes = y_test.shape[1]

# Compute ROC curve and ROC area for each class
# NOTE(review): a class whose labels are all 0 or all 1 has no defined ROC curve;
# confirm how the installed scikit-learn reports that case before relying on its AUC.
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Compute micro-average ROC curve and ROC area
fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

In [None]:
# Plot the ROC curve stored under key 2 of fpr/tpr (class index 2) together
# with its area from roc_auc.
plt.figure()
lw = 2
plt.plot(fpr[2], tpr[2], color='darkorange',
         lw=lw, label='ROC curve (area = %0.2f)' % roc_auc[2])
# Dashed diagonal from (0, 0) to (1, 1) as a visual reference line.
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic example')
plt.legend(loc="lower right")
plt.show()