# SARS-CoV-2 Knowledge Graph 

#### Please Change the Flag Below to Change the Embedding Algorithm used.

#### DRKG = True will only use DRKG embedding
#### DRKG = False will use the DRKG + SARS-CoV-2 knowledge graph embedding.

In [None]:
DRGK = True

In [None]:
Load packages

In [227]:
import numpy as np
import pandas as pd
import time
import re
import math
import random
import pickle

from sklearn.model_selection import train_test_split
from sklearn import metrics 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.nn.modules import Module
from torch.utils.data import Dataset, DataLoader

from torch_geometric.data import Data, DataLoader
from torch_geometric.utils import train_test_split_edges
from torch_geometric.utils import add_remaining_self_loops, add_self_loops
from torch_geometric.utils import to_undirected
from torch_geometric.nn import GCNConv, SAGEConv,GAE, VGAE

In [228]:
data_path='data/'
exp_id='vDRKG'
device_id='cpu' #'cpu' if CPU, device number if GPU
embedding_size=64

Load preprocessed files

In [229]:
le=pickle.load(open(data_path+'LabelEncoder_'+exp_id+'.pkl', 'rb'))
edge_index=pickle.load(open(data_path+'edge_index_'+exp_id+'.pkl','rb'))
node_feature_np=pickle.load(open(data_path+'node_feature_'+exp_id+'.pkl','rb'))

In [230]:
node_feature=torch.tensor(node_feature_np, dtype=torch.float)

In [231]:
edge=torch.tensor(edge_index[['node1', 'node2']].values, dtype=torch.long)

In [232]:
edge_attr_dict={'gene-drug':0,'gene-gene':1,'bait-gene':2, 'gene-phenotype':3, 'drug-phenotype':4}
edge_index['type']=edge_index['type'].apply(lambda x: edge_attr_dict[x])

In [233]:
edge_index['type'].value_counts()

0    29005
1     6174
3     2053
4     1365
2      247
Name: type, dtype: int64

In [234]:
edge_attr=torch.tensor(edge_index['type'].values,dtype=torch.long)

In [235]:
data = Data(x=node_feature,
            edge_index=edge.t().contiguous(),
            edge_attr=edge_attr
           )

In [236]:
data.num_features, data.num_nodes,data.num_edges

(400, 15700, 38844)

In [237]:
edge_attr.size()

torch.Size([38844])

In [238]:
data.has_isolated_nodes(), data.is_directed()

(False, True)

## Batch

In [239]:
def train_test_split_edges(data, val_ratio=0.05, test_ratio=0.1):
    r"""Splits the edges of a :obj:`torch_geometric.data.Data` object
    into positive and negative train/val/test edges, and adds attributes of
    `train_pos_edge_index`, `train_neg_adj_mask`, `val_pos_edge_index`,
    `val_neg_edge_index`, `test_pos_edge_index`, and `test_neg_edge_index`
    to :attr:`data`.

    Args:
        data (Data): The data object.
        val_ratio (float, optional): The ratio of positive validation
            edges. (default: :obj:`0.05`)
        test_ratio (float, optional): The ratio of positive test
            edges. (default: :obj:`0.1`)

    :rtype: :class:`torch_geometric.data.Data`
    """

    assert 'batch' not in data  # No batch-mode.

    num_nodes = data.num_nodes
    row, col = data.edge_index
    #data.edge_index = None
    attr = data.edge_attr

    # Return upper triangular portion.
    #mask = row < col
    #row, col = row[mask], col[mask]

    n_v = int(math.floor(val_ratio * row.size(0)))
    n_t = int(math.floor(test_ratio * row.size(0)))

    # Positive edges.
    perm = torch.randperm(row.size(0))
    row, col = row[perm], col[perm]
    attr=attr[perm]

    r, c = row[:n_v], col[:n_v]
    data.val_pos_edge_index = torch.stack([r, c], dim=0)
    data.val_pos_edge_attr = attr[:n_v]
    
    r, c = row[n_v:n_v + n_t], col[n_v:n_v + n_t]
    data.test_pos_edge_index = torch.stack([r, c], dim=0)
    data.test_post_edge_attr = attr[n_v:n_v + n_t]

    r, c = row[n_v + n_t:], col[n_v + n_t:]
    data.train_pos_edge_index = torch.stack([r, c], dim=0)
    data.train_pos_edge_attr = attr[n_v+n_t:]

    # Negative edges.
    neg_adj_mask = torch.ones(num_nodes, num_nodes, dtype=torch.uint8)
    neg_adj_mask = neg_adj_mask.triu(diagonal=1).to(torch.bool)
    neg_adj_mask[row, col] = 0

    neg_row, neg_col = neg_adj_mask.nonzero().t()
    perm = random.sample(range(neg_row.size(0)),
                         min(n_v + n_t, neg_row.size(0)))
    perm = torch.tensor(perm)
    perm = perm.to(torch.long)
    neg_row, neg_col = neg_row[perm], neg_col[perm]

    neg_adj_mask[neg_row, neg_col] = 0
    data.train_neg_adj_mask = neg_adj_mask

    row, col = neg_row[:n_v], neg_col[:n_v]
    data.val_neg_edge_index = torch.stack([row, col], dim=0)

    row, col = neg_row[n_v:n_v + n_t], neg_col[n_v:n_v + n_t]
    data.test_neg_edge_index = torch.stack([row, col], dim=0)

    return data

In [240]:
device=torch.device(device_id)

In [241]:
data_split=train_test_split_edges(data, test_ratio=0.1, val_ratio=0)
x,train_pos_edge_index,train_pos_edge_attr = data_split.x.to(device), data_split.train_pos_edge_index.to(device), data_split.train_pos_edge_attr.to(device)


In [242]:
train_pos_edge_index, train_pos_edge_attr=add_remaining_self_loops(train_pos_edge_index,train_pos_edge_attr)

In [243]:
pd.Series(train_pos_edge_attr.cpu().numpy()).value_counts()

0    26098
1    21222
3     1852
4     1219
2      220
dtype: int64

In [244]:
x,train_pos_edge_index,train_pos_edge_attr = Variable(x),Variable(train_pos_edge_index),Variable(train_pos_edge_attr)

## Learning models

Define VGAE model

In [245]:
class Encoder_VGAE(nn.Module):
    def __init__(self, in_channels, out_channels, isClassificationTask=False):
        super(Encoder_VGAE, self).__init__()
        self.isClassificationTask=isClassificationTask
        self.conv_gene_drug=  SAGEConv(in_channels, 2*out_channels, )
        self.conv_gene_gene = SAGEConv(in_channels, 2*out_channels, )
        self.conv_bait_gene = SAGEConv(in_channels, 2*out_channels, )
        self.conv_gene_phenotype = SAGEConv(in_channels, 2*out_channels, )
        self.conv_drug_phenotype = SAGEConv(in_channels, 2*out_channels)

        self.bn = nn.BatchNorm1d(5*2*out_channels)
        #variational encoder
        self.conv_mu = SAGEConv(5*2*out_channels, out_channels, )
        self.conv_logvar = SAGEConv(5*2*out_channels, out_channels,)

    def forward(self,x,edge_index,edge_attr):
        
        x = F.dropout(x, training=self.training)
        
        index_gene_drug=(edge_attr==0).nonzero().reshape(1,-1)[0]
        edge_index_gene_drug=edge_index[:, index_gene_drug]
        
        index_gene_gene=(edge_attr==1).nonzero().reshape(1,-1)[0]
        edge_index_gene_gene=edge_index[:, index_gene_gene]
        
        index_bait_gene=(edge_attr==2).nonzero().reshape(1,-1)[0]
        edge_index_bait_gene=edge_index[:, index_bait_gene]
        
        index_gene_phenotype=(edge_attr==3).nonzero().reshape(1,-1)[0]
        edge_index_gene_phenotype=edge_index[:, index_gene_phenotype]
        
        index_drug_phenotype=(edge_attr==4).nonzero().reshape(1,-1)[0]
        edge_index_drug_phenotype=edge_index[:, index_drug_phenotype]
        
        
        x_gene_drug = F.dropout(F.relu(self.conv_gene_drug(x,edge_index_gene_drug)), p=0.5, training=self.training, )
        x_gene_gene = F.dropout(F.relu(self.conv_gene_gene(x,edge_index_gene_gene)), p=0.5, training=self.training)
        x_bait_gene = F.dropout(F.relu(self.conv_bait_gene(x,edge_index_bait_gene)), p=0.1, training=self.training)
        x_gene_phenotype = F.dropout(F.relu(self.conv_gene_phenotype(x,edge_index_gene_phenotype)), training=self.training)
        x_drug_phenotype = F.dropout(F.relu(self.conv_drug_phenotype(x,edge_index_drug_phenotype)), training=self.training)

        x=self.bn(torch.cat([x_gene_drug,x_gene_gene,x_bait_gene,x_gene_phenotype,x_drug_phenotype],dim=1))        
        
        return self.conv_mu(x,edge_index), self.conv_logvar(x,edge_index)

In [246]:
model=VGAE(Encoder_VGAE(node_feature.shape[1], embedding_size)).to(device)
optimizer=torch.optim.Adam(model.parameters())

In [247]:
def train():
    model.train()
    optimizer.zero_grad()
    z = model.encode(x, train_pos_edge_index, train_pos_edge_attr)
    loss = model.recon_loss(z, train_pos_edge_index)
    loss = loss + (1 / data.num_nodes) * model.kl_loss()
    loss.backward()
    optimizer.step()
    print(loss.item())
    
def test(pos_edge_index, neg_edge_index):
    model.eval()
    with torch.no_grad():
        z=model.encode(x, train_pos_edge_index,train_pos_edge_attr)
    return model.test(z, pos_edge_index, neg_edge_index)

In [248]:
#DRKG's accuracy for comparison
model.test(x,data_split.test_pos_edge_index, data_split.test_neg_edge_index )

(0.3808762559116542, 0.48294000002768667)

In [268]:
%%time
for epoch in range(1, 30):
    train()
    auc, ap = test(data_split.test_pos_edge_index, data_split.test_neg_edge_index)
    print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))


8.304447174072266
Epoch: 001, AUC: 0.7941, AP: 0.8012
8.137896537780762
Epoch: 002, AUC: 0.7901, AP: 0.7968
7.99877405166626
Epoch: 003, AUC: 0.7859, AP: 0.7923
8.243403434753418
Epoch: 004, AUC: 0.7817, AP: 0.7878
8.236063003540039
Epoch: 005, AUC: 0.7777, AP: 0.7836
8.122429847717285
Epoch: 006, AUC: 0.7738, AP: 0.7797
8.168563842773438
Epoch: 007, AUC: 0.7701, AP: 0.7762
8.181551933288574
Epoch: 008, AUC: 0.7666, AP: 0.7728
8.180973052978516
Epoch: 009, AUC: 0.7632, AP: 0.7696
8.245285034179688
Epoch: 010, AUC: 0.7596, AP: 0.7662
8.06427001953125
Epoch: 011, AUC: 0.7563, AP: 0.7632
8.210638999938965
Epoch: 012, AUC: 0.7534, AP: 0.7605
8.157368659973145
Epoch: 013, AUC: 0.7505, AP: 0.7580
8.216042518615723
Epoch: 014, AUC: 0.7480, AP: 0.7558
8.018542289733887
Epoch: 015, AUC: 0.7449, AP: 0.7531
7.957126140594482
Epoch: 016, AUC: 0.7426, AP: 0.7512
8.137089729309082
Epoch: 017, AUC: 0.7403, AP: 0.7491
8.189958572387695
Epoch: 018, AUC: 0.7381, AP: 0.7473
8.212522506713867
Epoch: 019, 

Node embedding

In [269]:
model.eval()
z=model.encode(x, data.edge_index.to(device), data.edge_attr.to(device))
z_np = z.squeeze().detach().cpu().numpy()
z_np_DRKG = pickle.load(open(data_path+'node_feature_'+exp_id+'.pkl','rb'))

Save the new embedding 

In [270]:
pickle.dump(z_np, open(data_path+'node_embedding_'+exp_id+'.pkl', 'wb'))

Save the torch model

In [271]:
torch.save(model.state_dict(), data_path+'VAE_encoders_'+exp_id+'.pkl')

### Result 1
#### Link prediction result. Format: (AUROC, AUPRC)

In [272]:
model.load_state_dict(torch.load(data_path+'VAE_encoders_'+exp_id+'.pkl'))
model.eval()
#print(type(z))

print("link prediction for DRKG's accuracy for comparison: (AUROC, AUPRC)")
print(model.test(x,data_split.test_pos_edge_index, data_split.test_neg_edge_index ))
print("link prediction result for SARS-CoV-2 knowledge graph embedding + general embedding: (AUROC, AUPRC)")
print(model.test(z,data_split.test_pos_edge_index, data_split.test_neg_edge_index ))


<class 'torch.Tensor'>
(0.796145771132142, 0.8167289046418051)
(0.3808762559116542, 0.48294000002768667)


# Ranking model

In [273]:
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.naive_bayes import GaussianNB
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn import metrics
from sklearn.metrics import roc_auc_score, roc_curve, average_precision_score

In [274]:
topk=300
types=np.array([item.split('_')[0] for item in le.classes_ ])

Load drugs under clinical trial

In [275]:
#label
trials=pd.read_excel(data_path+'literature-mining/All_trails_5_24.xlsx',header=1,index_col=0)
trials_drug=set([drug.strip().upper() for lst in trials.loc[trials['study_category'].apply(lambda x: 'drug' in x.lower()),'intervention'].apply(lambda x: re.split(r'[+|/|,]',x.replace(' vs. ', '/').replace(' vs ', '/').replace(' or ', '/').replace(' with and without ', '/').replace(' /wo ', '/').replace(' /w ', '/').replace(' and ', '/').replace(' - ', '/').replace(' (', '/').replace(') ', '/'))).values for drug in lst])
drug_labels=[1 if drug.split('_')[1] in trials_drug else 0 for drug in le.classes_[types=='drug'] ]

## BPR loss NN

In [276]:
seed=70
indices = np.arange(len(drug_labels))
if DRKG:
    X_train, X_test, y_train, y_test,indices_train,indices_test=train_test_split(z_np_DRKG[types=='drug'],drug_labels,indices, test_size=0.5,random_state=seed,)
else:
    X_train, X_test, y_train, y_test,indices_train,indices_test=train_test_split(z_np[types=='drug'],drug_labels,indices, test_size=0.5,random_state=seed,)

In [277]:
#Variable wrapping for torch.tensor
_X_train, _y_train=Variable(torch.tensor(X_train,dtype=torch.float).to(device)), Variable(torch.tensor(y_train,dtype=torch.float).to(device))
_X_test, _y_test=Variable(torch.tensor(X_test,dtype=torch.float).to(device)), Variable(torch.tensor(y_test,dtype=torch.float).to(device))

In [278]:
class Classifier(nn.Module):
    def __init__(self,embedding_dim=embedding_size):
        super(Classifier, self).__init__() 
        self.fc1=nn.Linear(embedding_dim,embedding_dim)
        self.fc2=nn.Linear(embedding_dim,1)
        self.bn=nn.BatchNorm1d(embedding_dim)
    def forward(self, x):
        residual1 = x
        x = F.dropout(x, training=self.training)
        x= self.bn(F.dropout(F.relu(self.fc1(x)),training=self.training))
        x += residual1  
        return self.fc2(x)        

In [279]:
from torch.utils.data import BatchSampler, WeightedRandomSampler
class BPRLoss(nn.Module):
    def __init__(self, num_neg_samples):
        super(BPRLoss, self).__init__()
        self.num_neg_samples=num_neg_samples
    
    def forward(self, output, label):
        positive_output=output[label==1]
        negative_output=output[label!=1]
        
        #negative sample proportional to the high values
        negative_sampler=WeightedRandomSampler(negative_output-min(negative_output), num_samples=self.num_neg_samples*len(positive_output),replacement=True)
        negative_sample_output=negative_output[torch.tensor(list(BatchSampler(negative_sampler, batch_size=len(positive_output),drop_last=True)),dtype=torch.long).t()]
        return -(positive_output.view(-1,1)-negative_sample_output).sigmoid().log().mean()


In [280]:
clf=Classifier(64).to(device)
optimizer=torch.optim.Adam(clf.parameters())
criterion=BPRLoss(num_neg_samples=15)

In [281]:
best_auprc=0
for epoch in range(30):
    clf.train()
    optimizer.zero_grad()
    out = clf(_X_train)
    loss=criterion(out.squeeze(), _y_train)
    loss.backward()
    optimizer.step()   
    print('training loss',loss.item())

    clf.eval()
    print('test loss', criterion(clf(_X_test).squeeze(), _y_test).item())
    prob=torch.sigmoid(clf(_X_test)).cpu().detach().numpy().squeeze()
    auprc=metrics.average_precision_score(y_test,prob)
    if auprc>best_auprc:
        best_auproc=auprc
        torch.save(clf, data_path+'nn_clf-temp.pt')


training loss 0.8287684917449951
test loss 0.6613861322402954
training loss 0.7583634853363037
test loss 0.6544188261032104
training loss 0.7924999594688416
test loss 0.6489272713661194
training loss 0.657692551612854
test loss 0.6385542154312134
training loss 0.6243553161621094
test loss 0.6282044649124146
training loss 0.6163519620895386
test loss 0.6291022300720215
training loss 0.5931350588798523
test loss 0.6193060278892517
training loss 0.6124421954154968
test loss 0.6188808083534241
training loss 0.608545184135437
test loss 0.6102878451347351
training loss 0.6020726561546326
test loss 0.6002886295318604
training loss 0.5439878702163696
test loss 0.5953081846237183
training loss 0.5547744035720825
test loss 0.5803726315498352
training loss 0.45607510209083557
test loss 0.5776352286338806
training loss 0.4593876302242279
test loss 0.5610921382904053
training loss 0.529018759727478
test loss 0.5616405010223389
training loss 0.5259349942207336
test loss 0.5544975399971008
training l

In [282]:
clf.load_state_dict(torch.load(data_path+'nn_clf-temp.pt').state_dict())

<All keys matched successfully>

### Result 2&3 - AUROC and AUPRC for proposed ranking model

In [283]:
#Compute AUC
clf.eval()

prob=torch.sigmoid(clf(_X_test)).cpu().detach().numpy().squeeze()
print("AUROC", metrics.roc_auc_score(y_test,prob))
print("AUPRC", metrics.average_precision_score(y_test,prob))

AUROC 0.8593405083581407
AUPRC 0.18809515636348195


In [284]:
top_items_idx=np.argsort(-clf(torch.tensor(z_np[types=='drug'],dtype=torch.float).to(device)).squeeze().detach().cpu().numpy())

## Baseline models

### Result 2&3 - AUROC and AUPRC for baseline models

In [285]:
clf=LogisticRegression().fit(X_train,y_train)
print("Logit AUROC", roc_auc_score(y_test,clf.predict_proba(X_test)[:,1]))
print("Logit AUPRC", average_precision_score(y_test,clf.predict_proba(X_test)[:,1]))

Logit AUROC 0.872319670254179
Logit AUPRC 0.21037455021954116


In [286]:
clf=GradientBoostingClassifier().fit(X_train,y_train)
print("XGBoost AUROC", roc_auc_score(y_test,clf.predict_proba(X_test)[:,1]))
print("XGBoost AUPRC", average_precision_score(y_test,clf.predict_proba(X_test)[:,1]))

XGBoost AUROC 0.8278223036409436
XGBoost AUPRC 0.18680309215602448


In [287]:
clf=RandomForestClassifier().fit(X_train,y_train)
print("rf AUROC", roc_auc_score(y_test,clf.predict_proba(X_test)[:,1]))
print("rf AUPRC", average_precision_score(y_test,clf.predict_proba(X_test)[:,1]))

rf AUROC 0.839335928555072
rf AUPRC 0.15397971302113533


In [288]:
clf=make_pipeline(StandardScaler(), SVC(gamma='auto',probability=True)).fit(X_train,y_train)
print("svm AUROC", roc_auc_score(y_test,clf.predict_proba(X_test)[:,1]))
print("svm AUPRC", average_precision_score(y_test,clf.predict_proba(X_test)[:,1]))

svm AUROC 0.6729608426837645
svm AUPRC 0.1743949510729016


In [222]:
#top_items_idx=np.argsort(-clf.predict_proba(z_np[types=='drug'])[:,1])

Save the high-ranked drugs into csv file

In [198]:
topk_drugs=pd.DataFrame([(rank, drug.split('_')[1]) for rank,drug in enumerate(le.inverse_transform((types=='drug').nonzero()[0][top_items_idx])[:topk+1])], columns=['rank', 'drug'])
topk_drugs['under_trials']=topk_drugs['drug'].isin(trials_drug).astype(int)
topk_drugs['is_used_in_training']=topk_drugs['drug'].isin(np.array([drug.split('_')[1] for drug in le.classes_[types=='drug']])[indices_train]).astype(int)
topk_drugs.to_csv('top300_drugs.csv')