In [1]:
import pandas as pd


DISPLAY_ALL_TEXT = False

pd.set_option("display.max_colwidth", 0 if DISPLAY_ALL_TEXT else 50)

In [2]:
### Dataset contains text and disease and drug mentions which were extracted from the DailyMed drug lables

### Dataset labelled by experts to be used as test data

In [3]:
experts = pd.read_csv('../data/input/expert_resolved_all.csv')
experts.sample()

Unnamed: 0,#,context,disease_name,drug_name,workers_answers,medical_expert1,medical_expert2,medical_expert3,do_id,drug_brand_name,drug_id,label_id,sheet,expert_consensus,inter_agree_experts,medical_expert4,Final \n(closest option)
27,30.0,Nadolol Tablets USP are indicated for the long...,HYPERTENSION,NADOLOL,indication_treatment\nindication_treatment\nin...,Indication: Treatment,Indication: Treatment,IDK,DOID_10763,,DB01203,af5152c5-3641-4622-9295-02e9ac51b0d5.xml,Indication_treatment,Indication: Treatment,2,Indication: Treatment,


In [4]:
experts = experts.rename(columns={'drug_name':'Drug_name',
                        'disease_name':'disease', 'context':'Context',
                        'do_id':'DO_ID','drug_id':'DB_ID',
                        'label_id':'Label_ID', 'expert_consensus':'relation'})
#experts[['DB_ID', 'DO_ID','disease', 'Label_ID', 'Set_ID','relation', 'Drug_name', 'Context','Section']]


In [5]:
experts.DB_ID='Compound::'+experts.DB_ID
experts.DO_ID= 'Disease::DOID:'+experts.DO_ID.str[5:]

In [6]:
experts.head()

Unnamed: 0,#,Context,disease,Drug_name,workers_answers,medical_expert1,medical_expert2,medical_expert3,DO_ID,drug_brand_name,DB_ID,Label_ID,sheet,relation,inter_agree_experts,medical_expert4,Final \n(closest option)
0,1.0,Nisoldipine extended-release tablets are indic...,HYPERTENSION,NISOLDIPINE,indication_treatment\nindication_treatment\nin...,Indication: Treatment,Indication: Treatment,Indication: Treatment,Disease::DOID:10763,,Compound::DB00401,ce733b39-7857-4538-92d6-1c68a4e4eb75.xml,Indication_treatment,Indication: Treatment,3,,
1,2.0,"[None, 'Bisoprolol fumarate tablets, USP are i...",HYPERTENSION,BISOPROLOL,indication_treatment\nindication_treatment\nin...,Indication: Treatment,Indication: Treatment,Indication: Treatment,Disease::DOID:10763,,Compound::DB00612,c818c38d-ee37-4ceb-b078-611483b4f743.xml,Indication_treatment,Indication: Treatment,3,,
2,3.0,Pindolol tablets are indicated in the manageme...,HYPERTENSION,PINDOLOL,indication_treatment\nindication_treatment\nin...,Indication: Treatment,Indication: Treatment,Indication: Treatment,Disease::DOID:10763,,Compound::DB00960,d4078b63-30ad-435b-a081-4d90562962b5.xml,Indication_treatment,Indication: Treatment,3,,
3,4.0,Eprosartan mesylate tablets are indicated for ...,HYPERTENSION,EPROSARTAN,indication_treatment\nindication_treatment\nin...,Indication: Treatment,Indication: Treatment,Indication: Treatment,Disease::DOID:10763,,Compound::DB00876,ab531d28-345b-4113-9fde-00cd63ad5b1a.xml,Indication_treatment,Indication: Treatment,3,,
4,5.0,"[None, 'Labetalol hydrochloride tablets, USP a...",HYPERTENSION,LABETALOL,indication_treatment\nindication_treatment\nin...,Indication: Treatment,Indication: Treatment,Indication: Treatment,Disease::DOID:10763,,Compound::DB00598,78765cc3-27dc-0373-e053-2991aa0a7fe1.xml,Indication_treatment,Indication: Treatment,3,,


In [7]:
experts.relation = experts.relation =='Indication: Treatment'

In [8]:
experts.relation.sum()

109

In [9]:
df_het = pd.read_csv('../data/input/hetionet_ground_truth.csv')

In [10]:
#df_het= df_het.append(experts[['DB_ID','DO_ID']])
df_het.DB_ID='Compound::'+df_het.DB_ID
df_het.DO_ID= 'Disease::'+df_het.DO_ID
df_het.head()

Unnamed: 0,DB_ID,DO_ID
0,Compound::DB08906,Disease::DOID:2841
1,Compound::DB00282,Disease::DOID:11476
2,Compound::DB00958,Disease::DOID:184
3,Compound::DB00970,Disease::DOID:363
4,Compound::DB00530,Disease::DOID:1324


In [11]:
from sklearn.model_selection import train_test_split

In [12]:
df_het_train, df_het_test =train_test_split(df_het, shuffle=False, test_size=0.2)

In [13]:
df_het_train.shape

(604, 2)

In [14]:
indications_het =set()
experts_dict = {}
for i,row in df_het_train.iterrows():
    #print (row['a'], row['b'])
    indications_het.add((str(row['DB_ID']),str(row['DO_ID'])))
    #break
len(indications_het)

604

In [15]:
('Compound::DB08906', 'Disease::DOID:2841') in indications_het

True

In [16]:
whole_dr_di = []
for dr in df_het.DB_ID.unique():
    for di in df_het.DO_ID.unique():
        if (dr, di) in indications_het:
            #print ((dr, di))
            whole_dr_di.append([dr, di, 1])
        else:
            whole_dr_di.append([dr, di, 0])

In [17]:
df_train = pd.DataFrame(whole_dr_di, columns=['DB_ID', 'DO_ID','label'])

### Define KG Rules as Labeling Functions in Snorkel
Snorkel is data labeling framework where multiple heuristic and programatic rules can be combined to assign a label to training data 

In [18]:
# Labels to be assigned 
POSITIVE = 1 # positive that the drug indication is recommendable
NEGATIVE = 0 # negative that the drug indication is recommendable
ABSTAIN = -1 # not sure whether the drug indication is recommendable or not


### Embedding

### TransE Embeddings (Pre-Trained)

In [19]:
import numpy as np
import csv
from sklearn.metrics.pairwise import cosine_similarity

import torch.nn.functional as fn
import torch as th

node_emb = np.load('../embed/DRKG_TransE_l2_entity.npy')
relation_emb = np.load('../embed/DRKG_TransE_l2_relation.npy')

print(node_emb.shape)
print(relation_emb.shape)

(97238, 400)
(107, 400)


In [20]:
entity2id = {}
id2entity = {}
with open("../embed/entities.tsv", newline='', encoding='utf-8') as csvfile:
    reader = csv.DictReader(csvfile, delimiter='\t', fieldnames=['entity','id'])
    for row_val in reader:
        id = row_val['id']
        entity = row_val['entity']

        entity2id[entity] = int(id)
        id2entity[int(id)] = entity

print("Number of entities: {}".format(len(entity2id)))

Number of entities: 97238


In [21]:
relation2id = {}
id2relation = {}
with open("../embed/relations.tsv", newline='', encoding='utf-8') as csvfile:
    reader = csv.DictReader(csvfile, delimiter='\t', fieldnames=['relation','id'])
    for row_val in reader:
        id = row_val['id']
        relation = row_val['relation']
        #print (relation, id)
        relation2id[relation] = int(id)
        id2relation[int(id)] = relation

print("Number of relations: {}".format(len(relation2id)))

Number of relations: 107


In [22]:
node_emb[entity2id['Compound::DB00970']][:10]

array([-0.40771222, -0.5394161 , -0.05413219, -0.4154125 , -0.2909808 ,
       -0.37861708,  0.5271472 , -0.4764343 , -0.68072844,  0.80214214],
      dtype=float32)

In [23]:
i =0
for e in  entity2id:
    if 'Compound::DB' in e or 'Disease::DOID:' in e:
        i+=1
print (i)

with open('../embed/DRGK_entity_embedding.txt','w') as fw:
    fw.write(str(i)+' '+str(node_emb.shape[1])+'\n')
    for e in  entity2id:
        if 'Compound::DB' in e or 'Disease::DOID:' in e:
            fw.write(e+' '+" ".join( [str(v) for v in  node_emb[entity2id[e]]] )+'\n')

10678


### RDF2VEC Emebdding

In [24]:
from gensim.models import KeyedVectors

In [25]:
word_vectors = KeyedVectors.load_word2vec_format('../embed/DRGK_entity_embedding.txt', binary=False)

In [26]:
word_vectors.most_similar('Disease::DOID:2841')

[('Disease::DOID:4481', 0.7617197632789612),
 ('Disease::DOID:3083', 0.688909113407135),
 ('Disease::DOID:3310', 0.63420170545578),
 ('Disease::DOID:7148', 0.5728758573532104),
 ('Disease::DOID:8893', 0.5591062307357788),
 ('Disease::DOID:9074', 0.5585140585899353),
 ('Disease::DOID:8577', 0.552219033241272),
 ('Disease::DOID:8778', 0.5235368609428406),
 ('Disease::DOID:418', 0.5231932401657104),
 ('Disease::DOID:9008', 0.5188432931900024)]

In [27]:
word_vectors.similar_by_vector(word_vectors['Compound::DB00970'])

[('Compound::DB00970', 1.0),
 ('Compound::DB00694', 0.8634293079376221),
 ('Compound::DB01204', 0.8496915102005005),
 ('Compound::DB00997', 0.8147120475769043),
 ('Compound::DB01042', 0.7971892356872559),
 ('Compound::DB00305', 0.7745863199234009),
 ('Compound::DB00987', 0.7593553066253662),
 ('Compound::DB01005', 0.7480059862136841),
 ('Compound::DB00445', 0.7352684140205383),
 ('Compound::DB01030', 0.7341449856758118)]

In [28]:
# Use Embedding model in the labeling function  

### A probablistic rule is a rule that contains uncertainty about reasoning
Based on the guilt-by-association principle (e.g. Barabasi et al., 2011;
Chiang and Butte, 2009), new disease–drug relationships can be inferred
through existing relationships between similar diseases and similar drugs.
#### Example: If drug ${dr}$ treats disease ${ds}$ and disease ${ds}$ is similar to disease ${ds'}$ then drug ${dr}$ treats disease ${ds'}$
*  ${?dr}$ CtD ${?ds}$   ^  ${?ds}$  DrD ${?ds'}$  => ${?dr}$ CtD ${?ds'}$

In [29]:
from snorkel.labeling import labeling_function
def rule1_transe_sim(x):
    drug= x.DB_ID
    disease = x.DO_ID
    #print (drug, disease)
    # search all indications treating the same disease
    # if the drug that is similar to the one in the known indication,
    # return 'INDICATION' 
    score = 0.0
    if drug in entity2id and disease in entity2id:
        dr_idx = entity2id[drug]
        ds_idx = entity2id[disease]
        r1 = relation_emb[relation2id['Hetionet::CtD::Compound:Disease']]
        r2 = relation_emb[relation2id['Hetionet::DrD::Disease:Disease']]
        a = node_emb[dr_idx]
        b = node_emb[ds_idx]
        similarEntity = word_vectors.similar_by_vector(a + r1 + r2)
        for en,sim in similarEntity:
            if en == disease:
                return POSITIVE
    return ABSTAIN 



### Translation based probablistic rule as labeling function
$if (h, r, t) => h + r = t  $, vector norm of difference between head entity plus relation and tail entity in embedding space should be zero
* we can simply assess the correctness of a triple (h, r, t) by
checking the plausibility score of the embedding vectors of h, r and t.

In [30]:
@labeling_function()
def transe_sim(x):
    drug= x.DB_ID
    disease = x.DO_ID
    #print (drug, disease)
    # search all indications treating the same disease
    # if the drug that is similar to the one in the known indication,
    # return 'INDICATION' 
    score = 0.0
    if drug in entity2id and disease in entity2id:
        dr_idx = entity2id[drug]
        ds_idx = entity2id[disease]
        r = relation_emb[relation2id['Hetionet::CtD::Compound:Disease']]
        similarEntity = word_vectors.similar_by_vector((node_emb[dr_idx] + r), topn=10)
        for en,sim in similarEntity:
            if en == disease:
                return POSITIVE
    return ABSTAIN 

In [31]:
relation2id

{'bioarx::HumGenHumGen:Gene:Gene': 0,
 'bioarx::VirGenHumGen:Gene:Gene': 1,
 'bioarx::DrugVirGen:Compound:Gene': 2,
 'bioarx::DrugHumGen:Compound:Gene': 3,
 'bioarx::Covid2_acc_host_gene::Disease:Gene': 4,
 'bioarx::Coronavirus_ass_host_gene::Disease:Gene': 5,
 'DGIDB::INHIBITOR::Gene:Compound': 6,
 'DGIDB::ANTAGONIST::Gene:Compound': 7,
 'DGIDB::OTHER::Gene:Compound': 8,
 'DGIDB::AGONIST::Gene:Compound': 9,
 'DGIDB::BINDER::Gene:Compound': 10,
 'DGIDB::MODULATOR::Gene:Compound': 11,
 'DGIDB::BLOCKER::Gene:Compound': 12,
 'DGIDB::CHANNEL BLOCKER::Gene:Compound': 13,
 'DGIDB::ANTIBODY::Gene:Compound': 14,
 'DGIDB::POSITIVE ALLOSTERIC MODULATOR::Gene:Compound': 15,
 'DGIDB::ALLOSTERIC MODULATOR::Gene:Compound': 16,
 'DGIDB::ACTIVATOR::Gene:Compound': 17,
 'DGIDB::PARTIAL AGONIST::Gene:Compound': 18,
 'DRUGBANK::x-atc::Compound:Atc': 19,
 'DRUGBANK::ddi-interactor-in::Compound:Compound': 20,
 'DRUGBANK::target::Compound:Gene': 21,
 'DRUGBANK::enzyme::Compound:Gene': 22,
 'DRUGBANK::carrie

### Probablistic Rule using TransE Embedding:
#### Rule1: 
If drug $?dr$ treats disease $?f$ and disease $?f$ resembles disease $?ds$ then drug $?dr$  treats  $?ds$  

$$dr + r1 = f,   f + r2 = ds  => dr + r1 + r2 = ds, $$
$$r1= treats , r2= resembles$$
$$ sim = cosine (dr, ds) $$



In [32]:
@labeling_function()
def rule1_transe_sim(x):
    drug= x.DB_ID
    disease = x.DO_ID
    #print (drug, disease)
    # search all indications treating the same disease
    # if the drug that is similar to the one in the known indication,
    # return 'INDICATION' 
    score = 0.0
    if drug in entity2id and disease in entity2id:
        dr_idx = entity2id[drug]
        ds_idx = entity2id[disease]
        r1 = relation_emb[relation2id['Hetionet::CtD::Compound:Disease']]
        r2 = relation_emb[relation2id['Hetionet::DrD::Disease:Disease']]
        a = node_emb[dr_idx]
        b = node_emb[ds_idx]
        similarEntity = word_vectors.similar_by_vector(a + r1 + r2)
        for en,sim in similarEntity:
            if en == disease:
                return POSITIVE
    return ABSTAIN 

#### Rule 2:
If drug $?dr$ binds gene $?g$ and disease $?ds$ associates gene $?g$ then drug $?dr$  treats  $?ds$  
 Compound–binds–Gene–associates–Disease => Compound–treats–Disease

In [33]:
@labeling_function()
def rule2_transe_sim(x):
    drug= x.DB_ID
    disease = x.DO_ID
    #print (drug, disease)
    # search all indications treating the same disease
    # if the drug that is similar to the one in the known indication,
    # return 'INDICATION' 
    score = 0.0
    if drug in entity2id and disease in entity2id:
        dr_idx = entity2id[drug]
        ds_idx = entity2id[disease]
        r1 = relation_emb[relation2id['Hetionet::CbG::Compound:Gene']]
        r2 = relation_emb[relation2id['Hetionet::DaG::Disease:Gene']]
        a = node_emb[dr_idx]
        similarEntity = word_vectors.similar_by_vector(a + r1 -r2)
        for en,sim in similarEntity:
            if en == disease:
                return POSITIVE
    return ABSTAIN 

#### Rule 3: 
 Compound–resemble–Compound–Treat–Disease => Compound–treats–Disease

In [34]:
@labeling_function()
def rule3_transe_sim(x):
    drug= x.DB_ID
    disease = x.DO_ID
    #print (drug, disease)
    # search all indications treating the same disease
    # if the drug that is similar to the one in the known indication,
    # return 'INDICATION' 
    score = 0.0
    if drug in entity2id and disease in entity2id:
        dr_idx = entity2id[drug]
        ds_idx = entity2id[disease]
        r1 = relation_emb[relation2id['Hetionet::CtD::Compound:Disease']]
        r2 = relation_emb[relation2id['Hetionet::CrC::Compound:Compound']]
        a = node_emb[dr_idx]
        b = node_emb[ds_idx]
        similarEntity = word_vectors.similar_by_vector(a + r2 + r1)
        for en,sim in similarEntity:
            if en == disease:
                return POSITIVE
    return ABSTAIN 

#### Rule 4
Compound-include(-1)-Pharmacologic Class-include-Compound--treat-disease

In [35]:

@labeling_function()
def rule4_transe_sim(x):
    drug= x.DB_ID
    disease = x.DO_ID
    score = 0.0
    if drug in entity2id and disease in entity2id:
        dr_idx = entity2id[drug]
        ds_idx = entity2id[disease]
        r1 = relation_emb[relation2id['Hetionet::PCiC::Pharmacologic Class:Compound']]
        r2 = relation_emb[relation2id['Hetionet::PCiC::Pharmacologic Class:Compound']]
        r3 = relation_emb[relation2id['Hetionet::CtD::Compound:Disease']]
        similarEntity = word_vectors.similar_by_vector((node_emb[dr_idx] -r1 + r2 +r3))
        for en,sim in similarEntity:
            if en == disease:
                return POSITIVE
    return ABSTAIN 


#### Rule 5
Compound-resembles-Compound-resembles-Compound-treats-Disease

In [36]:
@labeling_function()
def rule5_transe_sim(x):
    drug= x.DB_ID
    disease = x.DO_ID
    score = 0.0
    if drug in entity2id and disease in entity2id:
        dr_idx = entity2id[drug]
        ds_idx = entity2id[disease]
        r1 = relation_emb[relation2id['Hetionet::CrC::Compound:Compound']]
        r2 = relation_emb[relation2id['Hetionet::CrC::Compound:Compound']]
        r3 = relation_emb[relation2id['Hetionet::CtD::Compound:Disease']]
        similarEntity = word_vectors.similar_by_vector((node_emb[dr_idx] + r1 + r2 +r3))
        for en,sim in similarEntity:
            if en == disease:
                return POSITIVE
    return ABSTAIN 

#### Rule 6
Compound-palliates-Disease-palliates(-1)-Compound-treats-Disease


In [37]:
@labeling_function()
def rule6_transe_sim(x):
    drug= x.DB_ID
    disease = x.DO_ID
    score = 0.0
    if drug in entity2id and disease in entity2id:
        dr_idx = entity2id[drug]
        ds_idx = entity2id[disease]
        r1 = relation_emb[relation2id['Hetionet::CpD::Compound:Disease']]
        r2 = relation_emb[relation2id['Hetionet::CpD::Compound:Disease']]
        r3 = relation_emb[relation2id['Hetionet::CtD::Compound:Disease']]
        similarEntity = word_vectors.similar_by_vector((node_emb[dr_idx] + r1 - r2 +r3))
        for en,sim in similarEntity:
            if en == disease:
                return POSITIVE
    return ABSTAIN 

#### Rule 7
Compound-binds-Gene-binds(-1)-Compound-treats-Disease

In [38]:
@labeling_function()
def rule7_transe_sim(x):
    drug= x.DB_ID
    disease = x.DO_ID
    score = 0.0
    if drug in entity2id and disease in entity2id:
        dr_idx = entity2id[drug]
        ds_idx = entity2id[disease]
        r1 = relation_emb[relation2id['Hetionet::CbG::Compound:Gene']]
        r2 = relation_emb[relation2id['Hetionet::CbG::Compound:Gene']]
        r3 = relation_emb[relation2id['Hetionet::CtD::Compound:Disease']]
        similarEntity = word_vectors.similar_by_vector((node_emb[dr_idx] + r1 - r2 +r3))
        for en,sim in similarEntity:
            if en == disease:
                return POSITIVE
    return ABSTAIN 

#### Rule 8
Compound-causes-Side Effect-causes(1)-Compound-treats-Disease


In [39]:
@labeling_function()
def rule8_transe_sim(x):
    drug= x.DB_ID
    disease = x.DO_ID
    score = 0.0
    if drug in entity2id and disease in entity2id:
        dr_idx = entity2id[drug]
        ds_idx = entity2id[disease]
        r1 = relation_emb[relation2id['Hetionet::CcSE::Compound:Side Effect']] 
        r2 = relation_emb[relation2id['Hetionet::CcSE::Compound:Side Effect']] 
        r3 = relation_emb[relation2id['Hetionet::CtD::Compound:Disease']]
        similarEntity = word_vectors.similar_by_vector((node_emb[dr_idx] + r1 - r2 +r3))
        for en,sim in similarEntity:
            if en == disease:
                return POSITIVE
    return ABSTAIN 

#### Rule 9 
Compound-resembles-Compound-binds-Gene-associates(-1)-Disease



In [40]:

@labeling_function()
def rule9_transe_sim(x):
    drug= x.DB_ID
    disease = x.DO_ID
    score = 0.0
    if drug in entity2id and disease in entity2id:
        dr_idx = entity2id[drug]
        ds_idx = entity2id[disease]
        r1 = relation_emb[relation2id['Hetionet::CrC::Compound:Compound']] 
        r2 = relation_emb[relation2id['Hetionet::CbG::Compound:Gene']] 
        r3 = relation_emb[relation2id['Hetionet::DaG::Disease:Gene']]
        similarEntity = word_vectors.similar_by_vector((node_emb[dr_idx] + r1 + r2 - r3))
        for en,sim in similarEntity:
            if en == disease:
                return POSITIVE
    return ABSTAIN 

#### Rule 10
Compound-binds-Gene-expresses(1)-Anatomy-localizes(1)-Disease

cosine_similarity(dr+r1-r2-r3, ds)

In [41]:
@labeling_function()
def rule10_transe_sim(x):
    drug= x.DB_ID
    disease = x.DO_ID
    score = 0.0
    if drug in entity2id and disease in entity2id:
        dr_idx = entity2id[drug]
        ds_idx = entity2id[disease]
        r1 = relation_emb[relation2id['Hetionet::CbG::Compound:Gene']] 
        r2 = relation_emb[relation2id['Hetionet::AeG::Anatomy:Gene']] 
        r3 = relation_emb[relation2id['Hetionet::DlA::Disease:Anatomy']]
        similarEntity = word_vectors.similar_by_vector((node_emb[dr_idx] + r1 - r2 - r3))
        for en,sim in similarEntity:
            if en == disease:
                return POSITIVE
    return ABSTAIN 

In [42]:
@labeling_function()
def transe_sim(x):
    drug= x.DB_ID
    disease = x.DO_ID
    #print (drug, disease)
    # search all indications treating the same disease
    # if the drug that is similar to the one in the known indication,
    # return 'INDICATION' 
    score = 0.0
    if drug in entity2id and disease in entity2id:
        dr_idx = entity2id[drug]
        ds_idx = entity2id[disease]
        r = relation_emb[relation2id['Hetionet::CtD::Compound:Disease']]
        similarEntity = word_vectors.similar_by_vector((node_emb[dr_idx] + r), topn=10)
        for en,sim in similarEntity:
            if en == disease:
                return POSITIVE
    return ABSTAIN 

In [None]:
from snorkel.labeling import PandasLFApplier, LFAnalysis

lfs = [transe_sim, rule1_transe_sim, rule2_transe_sim,rule3_transe_sim,rule4_transe_sim, rule5_transe_sim, rule6_transe_sim,rule7_transe_sim,rule8_transe_sim,rule9_transe_sim, rule10_transe_sim]


applier = PandasLFApplier(lfs=lfs)
L_train = applier.apply(df=df_train)

  from pandas import Panel
 74%|███████▍  | 22022/29799 [03:37<01:00, 127.57it/s]

In [None]:
from snorkel.labeling.model import LabelModel

label_model = LabelModel(cardinality=2, verbose=True)
label_model.fit(L_train=L_train, n_epochs=500, log_freq=100, seed=123)

In [None]:
preds_train = label_model.predict(L=L_train)

In [None]:
probs_train =label_model.predict_proba(L=L_train)

In [None]:
probs_train[:, 1]

In [None]:
sum(probs_train[:, 1] > 0.5)

In [None]:
df_train['prob'] = probs_train[:, 1]

In [None]:
df_het_test['label'] =1

In [None]:
from snorkel.labeling import PandasLFApplier, LFAnalysis
applier = PandasLFApplier(lfs)
L_dev = applier.apply(df_het_test)
LFAnalysis(L_dev, lfs).lf_summary(df_het_test.label.values)

In [None]:
df_rank = {}
df_train.index = df_train.DB_ID
for i,drug in enumerate(df_train.DB_ID.unique()):
    x =  df_train[df_train.DB_ID==drug]
    x['Rank'] = x['prob'].rank(method='first',ascending=False)
    x.index = x.DO_ID
    df_rank[drug]= x['Rank'].to_dict()
    

In [None]:
pos_pair_dict = {}
for i,row in df_het_test.iterrows():
    if row['DB_ID'] not in pos_pair_dict:
        pos_pair_dict[row['DB_ID']] =[]
    pos_pair_dict[row['DB_ID']].append( row['DO_ID'])

In [None]:
list(pos_pair_dict.items())[:10]

In [None]:
def mrr_map_new(pos_pair_dict, df_rank):
    rank_scores = []
    ave_p = []
    in_top_5 = 0
    in_top_10 = 0
    print('number of positives: ',len(pos_pair_dict.keys()))

    for fromt in pos_pair_dict.keys():
        relevant_ranks = []
        min_rank = 9999999999999999
        for asdf in pos_pair_dict[fromt]:
            rank = df_rank[fromt][asdf]
            relevant_ranks.append(rank)
            if rank < min_rank:
                min_rank = rank
        rank = min_rank

        rank_scores.append(1.0/rank)
        if min_rank <= 5:
            in_top_5 += 1
        if min_rank <= 10:
            in_top_10 += 1

        precisions = []
        for rank in relevant_ranks:
            good_docs = len([r for r in relevant_ranks if r <= rank])
            precisions.append(good_docs/rank)
        if len(precisions) == 0:
            precisions = [0]
        ave_p.append(np.mean(precisions))

    print('mean average precision: ', np.mean(ave_p))
    print('mean reciprocal rank: ', np.mean(rank_scores))
    print('recallrate at 5: ', in_top_5/len(pos_pair_dict.keys()))
    print('recallrate at 10: ', in_top_10/len(pos_pair_dict.keys()))
    print('formatted')
    print(round(np.mean(ave_p), 3),' & ',round(np.mean(rank_scores), 3),' & ',round(in_top_5/len(pos_pair_dict.keys()), 3),' & ',round(in_top_10/len(pos_pair_dict.keys()), 3))
    return round(np.mean(ave_p), 3)

In [None]:
mrr_map_new(pos_pair_dict, df_rank)