***Using our drug repurposing model: GDRnet***

In [7]:
import torch
import torch.nn as nn
import numpy as np
import pickle
import openpyxl
import re

Defining the required functions ----------

In [8]:
def predict(net,dis_list): #list of diseases - in Disease::MESH:D###### format
  """Score every candidate drug against each disease in dis_list with net.

  Args:
    net: model called as net(X, ax, a2x, batch) that returns
      (node embeddings, one logit per (drug, disease) row of batch),
      e.g. GDRnet.
    dis_list: disease node names, e.g. ["Disease::MESH:D003920"]; each must
      be a key of the global nodes_mapping.

  Returns:
    (embed, dictionaries_norm)
      embed: the embeddings returned by the last forward pass, or None when
        dis_list is empty.
      dictionaries_norm: one dict per disease, mapping drug name (first field
        of the drug-details entry, falling back to the raw node name) to the
        z-scored logit.

  Uses the notebook globals nodes_mapping, input_features, ax, a2x and device.
  Runs under torch.no_grad(), so no autograd graph is recorded.
  """
  dis_batches,drug_dict = get_disease_batches(nodes_mapping,dis_list)
  dis_batches = torch.LongTensor(dis_batches)
  # Invert nodes_mapping once (first name wins, matching get_node_name);
  # calling get_node_name per drug would rebuild two lists of all nodes each time.
  id_to_name = {}
  for name, node_id in nodes_mapping.items():
    id_to_name.setdefault(node_id, name)
  # The graph inputs are identical for every disease: move them to the device once.
  x_dev = input_features.to(device)
  ax_dev = ax.to(device)
  a2x_dev = a2x.to(device)
  embed = None
  dictionaries_norm = []
  with torch.no_grad():  # inference only -- no gradients needed
    for batch in dis_batches:
      embed,logits = net(x_dev,ax_dev,a2x_dev,batch.to(device))
      probs = standardize(logits)
      dct_norm = dict()
      # Column 0 of each batch row is the drug node id (column 1 is the disease).
      for (drug_id, _), score in zip(batch.tolist(), probs.tolist()):
        x = id_to_name[drug_id]
        # Prefer the human-readable drug name when the compound is in the sheet.
        key = drug_dict[x][0] if x in drug_dict else x
        dct_norm[key] = score
      dictionaries_norm.append(dct_norm)
  return embed,dictionaries_norm

def load_variable(filename):
  """Unpickle and return the object stored in the file `filename`.

  The file is opened in a `with` block so the handle is always closed,
  even if unpickling raises.

  WARNING: pickle.load can execute arbitrary code embedded in the file --
  only use this on files from a trusted source.
  """
  with open(filename,'rb') as fh:
    return pickle.load(fh)

def get_node_name(id, mapping=None):
  """Return the node name whose id equals `id` (reverse lookup in nodes_mapping).

  Args:
    id: node id to look for (compared with ==).
    mapping: name -> id dict to search; defaults to the global nodes_mapping.

  Returns:
    The first name (in dict order) mapped to `id`.

  Raises:
    ValueError: if no node has this id.

  This is a linear scan that stops at the first match. For many lookups,
  build an inverted {id: name} dict once instead.
  """
  if mapping is None:
    mapping = nodes_mapping
  for name, node_id in mapping.items():
    if node_id == id:
      return name
  raise ValueError(f"{id!r} is not a node id in nodes_mapping")

def get_node_id(name, mapping=None):
  """Return the integer node id for node `name`.

  Args:
    name: node name, e.g. "Disease::MESH:D003920".
    mapping: name -> id dict; defaults to the global nodes_mapping.

  Raises:
    ValueError: if `name` is not a node (ValueError rather than KeyError,
      so callers that caught the list.index error keep working).
  """
  if mapping is None:
    mapping = nodes_mapping
  # Direct dict lookup instead of scanning two materialized lists.
  if name not in mapping:
    raise ValueError(f"{name!r} is not a node in nodes_mapping")
  return mapping[name]

def load_model_on_cpu(model,path):
  """Load the state_dict saved at `path` into `model` and return the same model.

  Tensors are mapped to the CPU while loading, so a checkpoint saved on a
  GPU machine can be read without CUDA. The model can be moved afterwards
  with .to(device).
  """
  cpu = torch.device('cpu')
  state = torch.load(path, map_location=cpu)
  model.load_state_dict(state)
  return model

def get_disease_batches(nodes_mapping,disease_list,drug_dict=None): #disease_id in the form like Disease::MESH..
  """Build one batch of (drug id, disease id) pairs per requested disease.

  Args:
    nodes_mapping: dict of node name -> integer node id.
    disease_list: disease node names, e.g. "Disease::MESH:D003920".
    drug_dict: optional pre-loaded result of get_drug_name_desc_dict();
      the Excel sheet is read only when this is omitted.

  Returns:
    (batches, drug_dict): batches[k] is a list of (drug_id, disease_id)
    tuples (one per candidate drug) for disease_list[k]. drug_dict is the
    drug-details dict that was used.

  Raises:
    ValueError: if a disease is not a key of nodes_mapping.
  """
  dct = get_drug_name_desc_dict() if drug_dict is None else drug_dict
  # Candidate drugs are compound nodes that also appear in the drug-details sheet.
  # NOTE: in r"Compound+" the '+' only repeats the final 'd', so this is effectively
  # a case-insensitive substring test for "compound".
  # The drug set could be filtered further here, e.g. dropping withdrawn or
  # experimental-only drugs using the comma-separated status in dct[key][1].
  drugs = [node_id for key, node_id in nodes_mapping.items()
           if re.search(r"Compound+", key, re.I) and key in dct]
  batches = []
  for disease in disease_list:
    # Resolve the disease in the mapping that was passed in, not the global one.
    if disease not in nodes_mapping:
      raise ValueError(f"{disease!r} is not a node in nodes_mapping")
    disease_id = nodes_mapping[disease]
    batches.append([(drug, disease_id) for drug in drugs])
  return batches, dct

def get_drug_name_desc_dict(path='Drug_details.xlsx'):
  """Read the drug-details workbook into {column-1 value: (col 2, col 3, col 4)}.

  Args:
    path: location of the workbook. Defaults to 'Drug_details.xlsx' in the
      working directory; pass a different path if the file lives elsewhere.

  Returns:
    dict keyed by the first-column value of every row of the active sheet,
    with a tuple of the next three column values. Elsewhere in this notebook,
    element [0] is used as the drug's display name. A commented-out filter
    treats element [1] as a comma-separated status list -- TODO confirm
    against the sheet. Any header row is included as an ordinary entry.
    Rows with an empty first column are skipped.
  """
  sheet = openpyxl.load_workbook(path).active
  dct = dict()
  for row in sheet.iter_rows(min_row=1, max_row=sheet.max_row,
                             min_col=1, max_col=4, values_only=True):
    key = row[0]
    if key is None:
      continue  # blank row -- would otherwise add a None key
    dct[key] = tuple(row[1:4])
  return dct

def standardize(t):
  """Z-score a tensor: subtract its mean, then divide by its standard deviation.

  torch.std's default (Bessel-corrected) estimator is used. A tensor with a
  single element or zero spread yields non-finite values.
  """
  centred = t - t.mean()
  return centred / t.std()

def get_rank(dct,key):
  """Return the 1-based rank of `key` when dct's entries are ordered by descending value.

  Ties follow an ascending stable sort that is then reversed, so among equal
  scores the entry inserted later ranks higher.

  Raises:
    KeyError: if `key` is not in dct. An empty dict raises this too, instead
      of reporting a meaningless rank.
  """
  ranked = sorted(dct.items(), key=lambda t: t[1])[::-1]
  for rank, (name, _) in enumerate(ranked, start=1):
    if name == key:
      return rank
  raise KeyError(key)

Model definition / blueprint -------

In [9]:
# Shared activation modules. GDRnet uses tanh in its encoder and L_Relu after
# combine1; sig and Relu are not referenced by GDRnet.
tanh = nn.Tanh()
Relu = nn.ReLU()
sig = nn.Sigmoid()
L_Relu = nn.LeakyReLU(negative_slope=0.2)

class GDRnet(nn.Module):
  """Graph drug-repurposing network.

  Encoder: X, ax and a2x (the feature matrix and its propagated versions,
  see the loading cell) are each projected 400 -> 250 with a linear layer +
  tanh. The three projections are concatenated, then mixed by combine1 +
  LeakyReLU into a 250-d embedding per node.
  Decoder: each (drug, disease) pair is scored as
  t[drug] . layer8(t[disease]).
  """
  def __init__(self):
    super(GDRnet, self).__init__()
    decoder_dim = 250
    input_dim = 400
    r = 3  # number of concatenated encoder branches (X, ax, a2x)
    self.theta0 = nn.Linear(input_dim,decoder_dim) 
    self.theta1 = nn.Linear(input_dim,decoder_dim)
    self.theta2 = nn.Linear(input_dim,decoder_dim)
    self.combine1 = nn.Linear(decoder_dim*r,decoder_dim) 
    self.layer8 = nn.Linear(decoder_dim,decoder_dim)
    # layer9 is unused in forward, but must stay so the parameter names match the
    # saved checkpoint (load_state_dict is strict by default).
    self.layer9 = nn.Linear(decoder_dim,decoder_dim)

  def decoder(self,t,batch): 
    """Score every (drug, disease) row of `batch`.

    Args:
      t: node embeddings, indexed by node id along dim 0.
      batch: integer tensor of shape (n_pairs, 2); column 0 holds drug node
        ids and column 1 disease node ids.

    Returns:
      1-D tensor of n_pairs scores on t's device (also kept in self.t_new).
    """
    drug_emb = t[batch[:, 0]]
    dis_emb = self.layer8(t[batch[:, 1]])
    # Row-wise dot product: the same value as torch.dot(drug, layer8(disease))
    # for each pair, computed in one batched op instead of a per-pair Python
    # loop with host syncs.
    self.t_new = (drug_emb * dis_emb).sum(dim=1)
    return self.t_new

  def forward(self,X,ax,a2x,batch):
    """Encode all nodes, then decode the pairs in `batch`.

    Returns:
      (c, scores): c holds the node embeddings, one 250-d row per node;
      scores holds one score per row of batch.
    """
    t1 = tanh(self.theta0(X))
    t2 = tanh(self.theta1(ax))
    t3 = tanh(self.theta2(a2x))
    c = torch.cat((t1,t2,t3),dim=1)
    c = L_Relu(self.combine1(c))
    t1 = self.decoder(c,batch)
    return c,t1
  

In [10]:
# Run on a CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

Loading our pre-trained model--------


In [11]:
# Paths below are relative to the working directory -- adjust them if the files live elsewhere.
# NOTE: load_variable() unpickles these files, and pickle can execute arbitrary
# code, so only load them from a trusted source.
input_features = load_variable("input_features.p") 
nodes_mapping = load_variable("nodes_mapping.p")
A_tilda = load_variable("A_tilda.p")
# Propagated features: ax = A_tilda X and a2x = A_tilda (A_tilda X).
# `@` is always matrix multiplication. `*` means matmul only for np.matrix /
# scipy.sparse *matrix* objects and is element-wise for plain ndarrays and
# scipy.sparse arrays. Presumably A_tilda is the normalized
# (n_nodes x n_nodes) adjacency -- TODO confirm.
ax = A_tilda @ np.array(input_features)
a2x = A_tilda @ ax
a2x = torch.tensor(a2x,dtype=torch.float)
ax = torch.tensor(ax,dtype=torch.float)
# Weights are mapped to the CPU while loading, then the model is moved to `device`.
empty_model = GDRnet()
net = load_model_on_cpu(empty_model,"DR_model").to(device)

We provide lists of all ~4k diseases and ~8k drugs in our dataset, on which our model is trained (the disease names, in the format expected by `predict`, are listed in "Disease_list.xlsx"). Drugs can be predicted for any of these diseases.

In [35]:
# Give a list of diseases in the same form as in the "Disease_list.xlsx".
# Results:
#   embeddings - our 250-dimensional node embeddings for all the entities in the graph
#   drugs      - a list of dictionaries (one per disease given); in each dict the
#                keys are drug names and the values the corresponding scores
query_diseases = ["Disease::MESH:D003920"]
embeddings, drugs = predict(net, query_diseases)
print(drugs)

[{'Tavaborole': -2.1057610511779785, 'Calcium citrate': -0.4683394134044647, 'Human Varicella-Zoster Immune Globulin': 0.4643632471561432, 'Dasabuvir': -0.2541722357273102, 'Vinorelbine': 0.616840660572052, 'Dydrogesterone': -0.06728821247816086, 'Fenbufen': -0.36650538444519043, 'Glisoxepide': 1.13666832447052, 'Flubendazole': -0.010115300305187702, 'Thalidomide': 0.8006444573402405, 'Echothiophate': -2.1691439151763916, 'Valaciclovir': -0.11793768405914307, 'Fenoterol': -0.19378606975078583, 'Anakinra': -0.89090496301651, 'Tolazoline': -0.20104973018169403, 'Gadoxetic acid': -0.24765901267528534, 'Human thrombin': -1.0409824848175049, 'Olaparib': -1.4412636756896973, 'Mirabegron': 0.025839947164058685, 'Terbinafine': 1.2895681858062744, 'Pralatrexate': -1.0633047819137573, 'Amiodarone': 1.5354702472686768, 'Nicergoline': 0.5524407029151917, 'Brimonidine': -0.21688887476921082, 'Vasopressin': -0.14083033800125122, 'Bosutinib': 0.53678297996521, 'Sulfisoxazole': 0.390613317489624, 'Pro

In [33]:
# Check the rank (1 = highest score) of any drug in the predicted list for the first queried disease.
# NOTE: the saved output below came from execution In[33], i.e. before predict() was
# re-run in In[35], so it may not correspond to the current `drugs`.
get_rank(drugs[0], key="Acetylsalicylic acid")

2

In [36]:
# Top 30 predicted drugs for the first queried disease -- here MESH:D003920
# (Diabetes Mellitus), the disease passed to predict() above.
# Highest score first; ties use the same reversed-sort order as get_rank.
sorted(drugs[0].items(),key=lambda t:t[1])[::-1][:30]

[('Insulin lispro', 3.7640163898468018),
 ('Insulin aspart', 3.306720495223999),
 ('Insulin glargine', 3.1753931045532227),
 ('Glycyrrhizic acid', 3.064267158508301),
 ('Insulin glulisine', 3.041670560836792),
 ('Colesevelam', 3.024714231491089),
 ('Insulin human', 2.9247312545776367),
 ('Insulin detemir', 2.837725877761841),
 ('Insulin degludec', 2.747065544128418),
 ('Pramlintide', 2.6770401000976562),
 ('Lanreotide', 2.57383394241333),
 ('Cholestyramine', 2.5061159133911133),
 ('Simvastatin', 2.5053722858428955),
 ('Pravastatin', 2.4693946838378906),
 ('Insulin pork', 2.4605274200439453),
 ('Glyburide', 2.426546335220337),
 ('Miglitol', 2.395211696624756),
 ('Rosuvastatin', 2.3418991565704346),
 ('Lovastatin', 2.301260471343994),
 ('Colestipol', 2.286648750305176),
 ('Repaglinide', 2.249419927597046),
 ('Niacin', 2.2303683757781982),
 ('Liraglutide', 2.156116485595703),
 ('Pegvisomant', 2.1022632122039795),
 ('Liothyronine', 2.092202663421631),
 ('Fluvastatin', 2.0828566551208496),
