In [1]:
# AD_drug_repurposing
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on March 4, 2023
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on March 4, 2023
#
# 该脚本展示了如何使用 DRKG 的预训练模型 (RotatE) 进行药物重定位 (Alzheimer's disease).
#
# 需要的包:
#          csv
#          numpy
#          torch
#
# 需要的文件:
#          ./prerequisites/infer_drug.tsv -> Drugbank 中的 FDA 批准的药物
#          ./prerequisites/ad_drugs.txt
#          ../../data/drkg/entities.tsv
#          ../../data/drkg/relations.tsv
#          ../01-model/ckpts/RotatE_All_DRKG_0/All_DRKG_RotatE_entity.npy
#          ../01-model/ckpts/RotatE_All_DRKG_0/All_DRKG_RotatE_relation.npy

# Alzheimer's disease Drug Repurposing via disease-compounds relations
这个例子展示了如何使用 **DRKG** 的预训练模型进行药物重定位.

## Collecting Alzheimer's disease

一开始我们需要收集 DRKG 中的 AD 列表. 我们能够使用 DRKG 中的疾病 ID 编码疾病. 下面我们将全部的 AD 疾病作为目标. 

In [2]:
# DRKG entity names of the Alzheimer's disease nodes used as repurposing targets.
AD_disease_list = [
    'Disease::DOID:10652',
    'Disease::MESH:C536599',
    'Disease::MESH:D000544',
]

In [3]:
len(AD_disease_list)

3

In [4]:
AD_disease_list

['Disease::DOID:10652', 'Disease::MESH:C536599', 'Disease::MESH:D000544']

## Candidate drugs

现在我们使用 Drugbank 中的 FDA 批准的药物作为候选药物.（我们排除分子量 < 250 的药物）药物清单在 infer\_drug.tsv 中.

In [5]:
import csv

# Read the candidate compounds from the tab-separated drug file.
# Only the first column (labelled 'drug' here) is kept; the second ('ids') is not used.
with open("./prerequisites/infer_drug.tsv", newline='', encoding='utf-8') as drug_tsv:
    drug_list = [
        record['drug']
        for record in csv.DictReader(drug_tsv, delimiter='\t', fieldnames=['drug', 'ids'])
    ]

In [6]:
len(drug_list)

8104

In [7]:
drug_list[:3]

['Compound::DB00605', 'Compound::DB00983', 'Compound::DB01240']

## Treatment relation

In [8]:
# DRKG relation names treated as "compound treats disease" edges when scoring.
treatment_list = [
    'DRUGBANK::treats::Compound:Disease',
    'GNBR::T::Compound:Disease',
    'Hetionet::CtD::Compound:Disease',
]

In [9]:
treatment_list

['DRUGBANK::treats::Compound:Disease',
 'GNBR::T::Compound:Disease',
 'Hetionet::CtD::Compound:Disease']

## DRKG 中 AD 的治疗药物

In [10]:
# Load the known AD drugs, one compound identifier per line of ad_drugs.txt.
# Each line is stripped of surrounding whitespace instead of being sliced with [:-1]:
# slicing would cut the last character off the final entry whenever the file does
# not end with a newline. Blank lines are skipped rather than stored as '' entries.
with open("./prerequisites/ad_drugs.txt", encoding='utf-8') as f:
    ad_drugs = [line.strip() for line in f if line.strip()]

In [11]:
len(ad_drugs)

126

In [12]:
ad_drugs

['Compound::DB13588',
 'Compound::DB09130',
 'Compound::DB00244',
 'Compound::DB04815',
 'Compound::DB03994',
 'Compound::DB00712',
 'Compound::DB00158',
 'Compound::DB00472',
 'Compound::DB00331',
 'Compound::DB01593',
 'Compound::DB11100',
 'Compound::DB03128',
 'Compound::DB06712',
 'Compound::DB00993',
 'Compound::DB00877',
 'Compound::DB11094',
 'Compound::DB04115',
 'Compound::DB11805',
 'Compound::DB00843',
 'Compound::DB00928',
 'Compound::DB11068',
 'Compound::DB00694',
 'Compound::DB11748',
 'Compound::DB14500',
 'Compound::DB00382',
 'Compound::DB00393',
 'Compound::DB00490',
 'Compound::DB00215',
 'Compound::DB03843',
 'Compound::DB06756',
 'Compound::DB05289',
 'Compound::DB00564',
 'Compound::DB05381',
 'Compound::DB01235',
 'Compound::DB11780',
 'Compound::DB00945',
 'Compound::DB00571',
 'Compound::DB00674',
 'Compound::DB01234',
 'Compound::DB00134',
 'Compound::DB08842',
 'Compound::DB00136',
 'Compound::DB00368',
 'Compound::DB12310',
 'Compound::DB00682',
 'Compound

## Get pretrained model

我们能直接使用预训练模型做药物重定位.

In [13]:
# Locations of the DRKG id-mapping tables (entities and relations) read further below.
drkg_data_dir = '../../data/drkg'
entity_id_file = drkg_data_dir + '/entities.tsv'
relation_id_file = drkg_data_dir + '/relations.tsv'

## Get embeddings for diseases and drugs

In [14]:
# Build the name <-> integer-id lookup tables from the two tab-separated mapping files.
# Every row is read as (id, name), following the fieldnames handed to DictReader.
with open(entity_id_file, newline='', encoding='utf-8') as entity_tsv:
    entity_rows = [
        (int(record['id']), record['name'])
        for record in csv.DictReader(entity_tsv, delimiter='\t', fieldnames=['id', 'name'])
    ]
entity_map = {name: idx for idx, name in entity_rows}     # entity name -> integer id
entity_id_map = {idx: name for idx, name in entity_rows}  # integer id -> entity name

with open(relation_id_file, newline='', encoding='utf-8') as relation_tsv:
    # relation name -> integer id
    relation_map = {
        record['name']: int(record['id'])
        for record in csv.DictReader(relation_tsv, delimiter='\t', fieldnames=['id', 'name'])
    }

# Translate the candidate drugs, target diseases and treatment relations to integer ids.
# A name missing from the mapping files raises KeyError here.
drug_ids = [entity_map[drug] for drug in drug_list]
disease_ids = [entity_map[disease] for disease in AD_disease_list]
treatment_ids = [relation_map[treat] for treat in treatment_list]

In [15]:
len(disease_ids),len(drug_ids),len(treatment_ids)

(3, 8104, 3)

In [16]:
disease_ids, drug_ids[:3], treatment_ids

([83488, 45467, 25842], [9475, 11010, 7486], [24, 35, 68])

In [17]:
# Load embeddings
import torch
import numpy as np
# Pretrained RotatE embedding tables stored as .npy arrays; rows are selected below
# with the integer ids obtained from the entity / relation mapping files.
entity_emb = np.load('../01-model/ckpts/RotatE_All_DRKG_0/All_DRKG_RotatE_entity.npy')
rel_emb = np.load('../01-model/ckpts/RotatE_All_DRKG_0/All_DRKG_RotatE_relation.npy')

# torch.as_tensor accepts both the Python lists built earlier and the tensors left
# behind by a previous run of this cell, so re-running it no longer triggers
# torch.tensor's copy-construct UserWarning. All three id vectors get the same
# int64 dtype (the original converted treatment_ids without .long()).
drug_ids = torch.as_tensor(drug_ids, dtype=torch.long)
disease_ids = torch.as_tensor(disease_ids, dtype=torch.long)
treatment_ids = torch.as_tensor(treatment_ids, dtype=torch.long)

# Index the NumPy tables with NumPy / plain-int indices rather than torch tensors.
drug_emb = torch.tensor(entity_emb[drug_ids.numpy()])
treatment_embs = [torch.tensor(rel_emb[r_id]) for r_id in treatment_ids.tolist()]

In [18]:
disease_ids, drug_ids[:3], treatment_ids

(tensor([83488, 45467, 25842]),
 tensor([ 9475, 11010,  7486]),
 tensor([24, 35, 68]))

In [19]:
drug_emb.shape

torch.Size([8104, 400])

## Drug Repurposing Based on Edge Score

我们使用下面的算法来计算边分数 (edge score). 注意, 这里我们用 log(Sigmoid) 函数使全部分数 < 0. 分数越大, $(h, r, t)$ 越可能成立.

$\mathbf{d} = \gamma - \sum_{i}\left|\mathbf{h}_i \mathbf{r}_i-\mathbf{t}_i\right|$ (其中 $|\cdot|$ 为复数的模, 即对每一维求模后再求和, 与下面代码的实现一致)

$\mathbf{score} = \log\left(\frac{1}{1+\exp(\mathbf{-d})}\right)$

在进行药物重定位时，我们只使用与治疗相关的关系.

DGL-KE 官网实现 RotatE 评分函数的代码在:

- https://github.com/awslabs/dgl-ke/blob/master/python/dglke/models/ke_model.py 940 - 955 行

- https://github.com/awslabs/dgl-ke/blob/master/python/dglke/models/pytorch/score_fun.py 460 - 472 行

OpenKE 实现 RotatE 评分函数的代码在:

- https://github.com/thunlp/OpenKE/blob/OpenKE-PyTorch/openke/module/model/RotatE.py

其中两者的实现代码基本一样.

epsilon 两者都假定值为 2.

- https://github.com/awslabs/dgl-ke/blob/master/python/dglke/models/ke_model.py 53 行

- https://github.com/thunlp/OpenKE/blob/OpenKE-PyTorch/examples/train_rotate_WN18RR_adv.py 29 行

In [20]:
import torch.nn.functional as fn

# Hyper-parameters read by RotatE_score below.
# NOTE(review): presumably these match the settings the checkpoint was trained with — confirm.
gamma = 18.0       # margin the summed distance is subtracted from
epsilon = 2.0
hidden_dim = 200
# Scale used to turn relation embedding values into phase angles.
emb_init = (gamma + epsilon) / hidden_dim


def RotatE_score(head, rel, tail):
    """Return ``gamma`` minus the RotatE distance between rotated heads and tails.

    Args:
        head: tensor whose last dimension is split in half into a real part and an
            imaginary part.
        rel: tensor of relation values; divided by ``emb_init / pi`` to obtain the
            rotation phase. It must broadcast against the half-size parts of ``head``.
        tail: tensor split into real / imaginary halves exactly like ``head``.

    Returns:
        Tensor with the last dimension reduced: ``gamma`` minus the sum, over that
        dimension, of the complex modulus of ``head * rotation - tail``.

    Reads the module-level ``gamma`` and ``emb_init``.
    """
    head_re, head_im = head.chunk(2, dim=-1)
    tail_re, tail_im = tail.chunk(2, dim=-1)

    # The relation acts as a unit-modulus complex rotation: cos(phase) + i*sin(phase).
    phase = rel / (emb_init / np.pi)
    rot_re = torch.cos(phase)
    rot_im = torch.sin(phase)

    # Complex product head * rotation, minus tail, kept as separate real/imaginary parts.
    diff_re = (head_re * rot_re - head_im * rot_im) - tail_re
    diff_im = (head_re * rot_im + head_im * rot_re) - tail_im

    # Modulus per embedding dimension, then summed across dimensions.
    modulus = torch.stack((diff_re, diff_im), dim=0).norm(dim=0)
    return gamma - modulus.sum(-1)

In [21]:
# Score every candidate drug against each (treatment relation, AD disease) pair.
# One score vector (one entry per candidate drug) is produced per pair.
scores_per_disease = []
for treatment_emb in treatment_embs:
    for disease_id in disease_ids:
        disease_emb = torch.tensor(entity_emb[disease_id])
        edge_score = RotatE_score(drug_emb, treatment_emb, disease_emb)
        scores_per_disease.append(fn.logsigmoid(edge_score))

# Flatten to one long score vector, with the drug id vector repeated once per pair so
# that position k of `drugs` names the drug scored at position k of `scores`.
scores = torch.cat(scores_per_disease)
drugs = torch.cat([drug_ids] * len(scores_per_disease))

In [22]:
scores.shape, drugs.shape, 3*3*8104

(torch.Size([72936]), torch.Size([72936]), 72936)

In [23]:
# Sort scores in descending order (best candidates first).
# torch.argsort returns ascending order; torch.flip reverses it along the given dim.
# torch.as_tensor makes the cell safe to re-run: after a first run `scores` and `drugs`
# are NumPy arrays, which torch.argsort would reject with a TypeError.
scores_t = torch.as_tensor(scores)
drugs_t = torch.as_tensor(drugs)
idx = torch.flip(torch.argsort(scores_t), dims=[0])
scores = scores_t[idx].numpy()
drugs = drugs_t[idx].numpy()

scores.shape, drugs.shape, 3*3*8104

((72936,), (72936,), 72936)

### Now we output proposed treatments

In [24]:
# Number of distinct drugs to propose.
topk = 100

# np.unique(..., return_index=True) gives, for each distinct drug id, the index of its
# first occurrence in `drugs`. Sorting those indices and keeping the first `topk`
# selects the drugs that appear earliest in the score-ordered arrays.
unique_indices = np.unique(drugs, return_index=True)[1]
topk_indices = np.sort(unique_indices)[:topk]
proposed_drugs, proposed_scores = drugs[topk_indices], scores[topk_indices]

In [25]:
# Print the proposed drugs that are not already listed as AD drugs.
# The bracketed number is the drug's 1-based position in the top-k selection; drugs
# found in ad_drugs keep their position but are not printed, so the numbers have gaps.
known_ad_drugs = set(ad_drugs)  # set membership instead of scanning the list each time

# Iterate over what was actually selected instead of range(topk), so having fewer than
# `topk` unique drugs no longer raises IndexError.
for rank, (drug_id, score) in enumerate(zip(proposed_drugs, proposed_scores), start=1):
    drug_name = entity_id_map[int(drug_id)]
    if drug_name not in known_ad_drugs:
        print("[{}]\t{}\t{}".format(rank, drug_name, score))

[9]	Compound::DB00143	-0.1855190545320511
[11]	Compound::DB00502	-0.19982466101646423
[13]	Compound::DB06774	-0.2012828290462494
[16]	Compound::DB04216	-0.20350724458694458
[17]	Compound::DB00783	-0.2035258710384369
[18]	Compound::DB09341	-0.20749205350875854
[20]	Compound::DB00822	-0.21294096112251282
[21]	Compound::DB00640	-0.22600767016410828
[23]	Compound::DB01105	-0.23136001825332642
[29]	Compound::DB00715	-0.26189476251602173
[31]	Compound::DB00907	-0.2680314779281616
[39]	Compound::DB01229	-0.28694769740104675
[41]	Compound::DB04540	-0.2887122929096222
[43]	Compound::DB01016	-0.29204732179641724
[44]	Compound::DB02010	-0.2982581555843353
[46]	Compound::DB14681	-0.29944315552711487
[48]	Compound::DB00321	-0.30107274651527405
[55]	Compound::DB00714	-0.3295489549636841
[56]	Compound::DB00297	-0.33073094487190247
[57]	Compound::DB00806	-0.33394524455070496
[61]	Compound::DB00704	-0.34633323550224304
[62]	Compound::DB00159	-0.34660911560058594
[63]	Compound::DB00541	-0.34707534313201