In [1]:
# 02_TransE_l2_AD_drug_repurposing
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on January 8, 2023
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on March 11, 2023
#
# This script shows how to use our pretrained model (TransE_l2) for drug repurposing (Alzheimer's disease).
#
# Required packages:
#          csv
#          numpy
#          torch
#
# Required files:
#          ./prerequisites/infer_drug.tsv -> FDA-approved drugs in Drugbank
#          ./prerequisites/ad_drugs.txt
#          ../../data/drkg/entities.tsv
#          ../../data/drkg/relations.tsv
#          ../01-model/ckpts/TransE_l2_All_DRKG_0/All_DRKG_TransE_l2_entity.npy
#          ../01-model/ckpts/TransE_l2_All_DRKG_0/All_DRKG_TransE_l2_relation.npy

# Alzheimer's disease Drug Repurposing via disease-compounds relations

这个例子展示了如何使用我们的预训练模型 (TransE_l2) 针对 Alzheimer's disease (AD) 进行药物重定位.

## Collecting Alzheimer's disease

一开始我们需要收集 DRKG 中的 AD 列表. 我们能够使用 DRKG 中的疾病 ID 编码疾病. 下面我们将全部的 AD 疾病作为目标. 

In [2]:
# DRKG entity names of the Alzheimer's disease nodes used as repurposing targets.
AD_disease_list = [
    'Disease::DOID:10652',
    'Disease::MESH:C536599',
    'Disease::MESH:D000544',
]

In [3]:
len(AD_disease_list)

3

In [4]:
AD_disease_list

['Disease::DOID:10652', 'Disease::MESH:C536599', 'Disease::MESH:D000544']

## Candidate drugs

现在我们使用 Drugbank 中的 FDA 批准的药物作为候选药物.（我们排除分子量 < 250 的药物）药物清单在 infer\_drug.tsv 中.

In [5]:
import csv

# Read the candidate-drug table. The file is tab-separated and the column names
# are supplied explicitly, so the first line is treated as data; only the
# 'drug' column (the first one) is kept.
with open("./prerequisites/infer_drug.tsv", newline='', encoding='utf-8') as tsv_handle:
    drug_rows = csv.DictReader(tsv_handle, delimiter='\t', fieldnames=['drug', 'ids'])
    drug_list = [record['drug'] for record in drug_rows]

In [6]:
len(drug_list)

8104

In [7]:
drug_list[:3]

['Compound::DB00605', 'Compound::DB00983', 'Compound::DB01240']

## Treatment relation

In [8]:
# Names of the Compound -> Disease relations that are scored as "treatment" edges.
treatment_list = [
    'DRUGBANK::treats::Compound:Disease',
    'GNBR::T::Compound:Disease',
    'Hetionet::CtD::Compound:Disease',
]

In [9]:
treatment_list

['DRUGBANK::treats::Compound:Disease',
 'GNBR::T::Compound:Disease',
 'Hetionet::CtD::Compound:Disease']

## DRKG 中 AD 的治疗药物

In [10]:
# Load the compound names listed in ad_drugs.txt, one per line.
# Only the line terminator is removed: slicing off the last character
# unconditionally would truncate the final identifier whenever the file does
# not end with a newline.
with open("./prerequisites/ad_drugs.txt", encoding='utf-8') as f:
    ad_drugs = [line.rstrip('\n') for line in f]

In [11]:
len(ad_drugs)

126

In [12]:
ad_drugs

['Compound::DB13588',
 'Compound::DB09130',
 'Compound::DB00244',
 'Compound::DB04815',
 'Compound::DB03994',
 'Compound::DB00712',
 'Compound::DB00158',
 'Compound::DB00472',
 'Compound::DB00331',
 'Compound::DB01593',
 'Compound::DB11100',
 'Compound::DB03128',
 'Compound::DB06712',
 'Compound::DB00993',
 'Compound::DB00877',
 'Compound::DB11094',
 'Compound::DB04115',
 'Compound::DB11805',
 'Compound::DB00843',
 'Compound::DB00928',
 'Compound::DB11068',
 'Compound::DB00694',
 'Compound::DB11748',
 'Compound::DB14500',
 'Compound::DB00382',
 'Compound::DB00393',
 'Compound::DB00490',
 'Compound::DB00215',
 'Compound::DB03843',
 'Compound::DB06756',
 'Compound::DB05289',
 'Compound::DB00564',
 'Compound::DB05381',
 'Compound::DB01235',
 'Compound::DB11780',
 'Compound::DB00945',
 'Compound::DB00571',
 'Compound::DB00674',
 'Compound::DB01234',
 'Compound::DB00134',
 'Compound::DB08842',
 'Compound::DB00136',
 'Compound::DB00368',
 'Compound::DB12310',
 'Compound::DB00682',
 'Compound

## Get pretrained model

我们能直接使用我们的预训练模型 (TransE_l2)做药物重定位.

In [13]:
# Tab-separated files that map DRKG entity / relation names to integer ids;
# they are parsed in the next cell.
entity_id_file = '../../data/drkg/entities.tsv'
relation_id_file = '../../data/drkg/relations.tsv'

## Get embeddings for diseases and drugs

In [14]:
# Build the lookup tables between DRKG names and integer ids.
# Both TSV files are read as (id, name) columns without a header row.
entity_map = {}      # entity name -> integer id
entity_id_map = {}   # integer id  -> entity name
relation_map = {}    # relation name -> integer id

with open(entity_id_file, newline='', encoding='utf-8') as entity_tsv:
    for record in csv.DictReader(entity_tsv, delimiter='\t', fieldnames=['id', 'name']):
        ent_id, ent_name = int(record['id']), record['name']
        entity_map[ent_name] = ent_id
        entity_id_map[ent_id] = ent_name

with open(relation_id_file, newline='', encoding='utf-8') as relation_tsv:
    for record in csv.DictReader(relation_tsv, delimiter='\t', fieldnames=['id', 'name']):
        relation_map[record['name']] = int(record['id'])

# Translate candidate drugs, target diseases and treatment relations into ids.
# A name missing from the mapping raises KeyError.
drug_ids = [entity_map[name] for name in drug_list]
disease_ids = [entity_map[name] for name in AD_disease_list]
treatment_ids = [relation_map[name] for name in treatment_list]

In [15]:
len(disease_ids),len(drug_ids),len(treatment_ids)

(3, 8104, 3)

In [16]:
disease_ids, drug_ids[:3], treatment_ids

([83488, 45467, 25842], [9475, 11010, 7486], [24, 35, 68])

In [17]:
# Load embeddings
import torch
import numpy as np

# Pretrained TransE_l2 embedding matrices; below they are indexed with the
# integer ids obtained from the entity / relation mappings.
entity_emb = np.load('../01-model/ckpts/TransE_l2_All_DRKG_0/All_DRKG_TransE_l2_entity.npy')
rel_emb = np.load('../01-model/ckpts/TransE_l2_All_DRKG_0/All_DRKG_TransE_l2_relation.npy')

# torch.as_tensor accepts both the Python lists built by the mapping cell and
# the tensors left behind by an earlier run of this cell, so re-executing the
# cell does not trigger torch.tensor's copy-construct warning. All three id
# collections get the same explicit int64 dtype.
drug_ids = torch.as_tensor(drug_ids, dtype=torch.long)
disease_ids = torch.as_tensor(disease_ids, dtype=torch.long)
treatment_ids = torch.as_tensor(treatment_ids, dtype=torch.long)

# Index the NumPy matrices with NumPy / int indices, then copy into tensors.
drug_emb = torch.tensor(entity_emb[drug_ids.numpy()])
treatment_embs = [torch.tensor(rel_emb[int(r_id)]) for r_id in treatment_ids]

In [18]:
disease_ids, drug_ids[:3], treatment_ids

(tensor([83488, 45467, 25842]),
 tensor([ 9475, 11010,  7486]),
 tensor([24, 35, 68]))

In [19]:
drug_emb.shape

torch.Size([8104, 400])

## Drug Repurposing Based on Edge Score

我们使用下面的公式来计算边得分 (edge score). 注意, 这里我们用 logsigmoid 函数使全部得分 < 0. 得分越大, 三元组 $(h, r, t)$ 越可能成立.

$d = \gamma - \|\mathbf{h}+\mathbf{r}-\mathbf{t}\|_{2}$

$\mathrm{score} = \log\left(\frac{1}{1+\exp(-d)}\right)$

在进行药物重定位时，我们只使用与治疗相关的关系.

In [20]:
import torch.nn.functional as fn

gamma = 12.0
def transE_l2(head, rel, tail):
    """Return the TransE (L2) edge score ``gamma - ||head + rel - tail||_2``.

    The three operands are combined element-wise (with the usual broadcasting)
    and the Euclidean norm is taken over the last dimension, so the result has
    the broadcast shape without that dimension. Reads the module-level
    ``gamma``.

    ``torch.linalg.vector_norm`` is used because ``torch.norm`` is documented
    as deprecated.
    """
    residual = head + rel - tail
    return gamma - torch.linalg.vector_norm(residual, ord=2, dim=-1)

# Score every candidate drug against every (treatment relation, AD disease)
# pair. Each pass appends one score vector (one entry per candidate drug) and
# the matching drug ids, so that after concatenation scores[k] belongs to
# drugs[k].
scores_per_disease = []
drugs = []
for treatment_emb in treatment_embs:
    for disease_id in disease_ids:
        # entity_emb is a NumPy array: convert the selected row to a tensor
        # explicitly instead of mixing ndarray and tensor operands in the
        # scoring arithmetic.
        disease_emb = torch.tensor(entity_emb[int(disease_id)])
        score = fn.logsigmoid(transE_l2(drug_emb, treatment_emb, disease_emb))
        scores_per_disease.append(score)
        drugs.append(drug_ids)
scores = torch.cat(scores_per_disease)
drugs = torch.cat(drugs)

In [21]:
# Sanity check: one score per (relation, disease, drug) combination. The
# expected count is derived from the actual inputs instead of being hard-coded.
scores.shape, drugs.shape, len(treatment_embs) * len(disease_ids) * len(drug_ids)

(torch.Size([72936]), torch.Size([72936]), 72936)

In [22]:
# Sort scores in descending order: argsort yields ascending positions, and
# flipping along dimension 0 reverses them. The same permutation is applied to
# the drug ids so both stay aligned; both are then converted to NumPy arrays.
descending_order = torch.argsort(scores).flip(0)
scores, drugs = scores[descending_order].numpy(), drugs[descending_order].numpy()

scores.shape, drugs.shape, 3*3*8104

((72936,), (72936,), 72936)

### Now we output proposed treatments

In [23]:
# np.unique(..., return_index=True) reports the position of the first
# occurrence of every drug id. Assuming `drugs` was sorted best score first by
# the previous cell, that position is the drug's best-scoring entry. Sorting
# the positions restores ranking order; the leading `topk` are the proposals.
topk = 50
first_positions = np.unique(drugs, return_index=True)[1]
topk_indices = np.sort(first_positions)[:topk]
proposed_drugs, proposed_scores = drugs[topk_indices], scores[topk_indices]

In [24]:
top50_list = []

# Set membership is O(1); the list lookup was repeated for every proposal.
known_ad_drugs = set(ad_drugs)

# Iterate over the proposals that actually exist rather than range(topk):
# with fewer than `topk` unique drugs, indexing by range(topk) would raise
# IndexError. The printed rank is the 1-based position in the ranking, so
# gaps mark drugs skipped because they are already listed in ad_drugs.
for rank, (drug_id, score) in enumerate(zip(proposed_drugs, proposed_scores), start=1):
    drug_name = entity_id_map[int(drug_id)]
    if drug_name in known_ad_drugs:
        continue
    top50_list.append("[{}],{},{}\n".format(rank, drug_name, score))
    print("[{}]\t{}\t{}".format(rank, drug_name, score))

[6]	Compound::DB04540	-0.122935451567173
[7]	Compound::DB09341	-0.12512117624282837
[8]	Compound::DB00143	-0.12561222910881042
[10]	Compound::DB00515	-0.1341310441493988
[14]	Compound::DB00997	-0.14285585284233093
[15]	Compound::DB00171	-0.1469261348247528
[16]	Compound::DB01229	-0.14909186959266663
[17]	Compound::DB00477	-0.15069124102592468
[19]	Compound::DB00755	-0.15511256456375122
[24]	Compound::DB00502	-0.16119839251041412
[26]	Compound::DB00783	-0.1685544103384018
[27]	Compound::DB00295	-0.16982416808605194
[28]	Compound::DB00661	-0.17014111578464508
[34]	Compound::DB00675	-0.17270652949810028
[35]	Compound::DB00624	-0.17352452874183655
[36]	Compound::DB00363	-0.17481908202171326
[37]	Compound::DB12153	-0.17557311058044434
[39]	Compound::DB01708	-0.17726033926010132
[40]	Compound::DB00541	-0.17793041467666626
[41]	Compound::DB00959	-0.179716095328331
[44]	Compound::DB00396	-0.18199630081653595
[45]	Compound::DB00907	-0.18305420875549316
[46]	Compound::DB04216	-0.1852144002914428

In [25]:
# Write the results to a file.
import os

results_path = './results/02_transE_l2_top50.csv'

# Create the output directory if it is missing so a fresh checkout does not
# fail with FileNotFoundError.
os.makedirs(os.path.dirname(results_path), exist_ok=True)

# The context manager closes the file even if a write raises; the encoding is
# stated explicitly, matching how the input files are opened.
with open(results_path, 'w', encoding='utf-8') as f:
    f.writelines(top50_list)