# Results analysis of the ComplEx model on MIND-CtD
* The goal of this notebook is to walk through a sample analysis of the model's results
* We also highlight how to use some of the functions in `score_utils`
* Finally, we calculate the MRR and Hits@k, and extract the filtered top-k results

In [1]:
import os
import pandas as pd
import score_utils as scu
# work from the repository root so relative paths (run.sh, codes/, data/, models/) resolve
os.chdir('..')

In [2]:
!bash run.sh train ComplEx MIND_CtD 0 megha 88 108 125 48.0 1.0 0.0027 1000000 16 -dr -de 

1.12.1+cu102
Start Training......
2023-04-06 17:23:07,430 INFO     Model: ComplEx
2023-04-06 17:23:07,430 INFO     Data Path: data/MIND_CtD
2023-04-06 17:23:07,430 INFO     #entity: 249605
2023-04-06 17:23:07,430 INFO     #relation: 83
2023-04-06 17:23:12,512 INFO     #train: 9651042
2023-04-06 17:23:12,513 INFO     #valid: 537
2023-04-06 17:23:12,514 INFO     #test: 537
2023-04-06 17:23:13,159 INFO     Model Parameter Configuration:
2023-04-06 17:23:13,160 INFO     Parameter gamma: torch.Size([1]), require_grad = False
2023-04-06 17:23:13,160 INFO     Parameter embedding_range: torch.Size([1]), require_grad = False
2023-04-06 17:23:13,160 INFO     Parameter entity_embedding: torch.Size([249605, 250]), require_grad = True
2023-04-06 17:23:13,160 INFO     Parameter relation_embedding: torch.Size([83, 250]), require_grad = True
2023-04-06 17:24:03,378 INFO     Ramdomly Initializing ComplEx Model...
2023-04-06 17:24:03,378 INFO     Start Training...
2023-04-06 17:24:03,378 INFO     init_s

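## Get predictions for the test file
* Run the predictions on `test.txt`. Results are exported only when the `--do_predict` flag is passed.
* With both the `--do_test` and `--do_predict` flags, the output file is `test_scores.tsv`, written to the model directory (here `models/ComplEx_MIND_CtD_megha/`).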

In [3]:
!python -u codes/run.py --do_predict --do_test -init models/ComplEx_MIND_CtD_megha #--cuda

2023-04-07 09:08:35,782 INFO     Model: ComplEx
2023-04-07 09:08:35,782 INFO     Data Path: data/MIND_CtD
2023-04-07 09:08:35,782 INFO     #entity: 249605
2023-04-07 09:08:35,782 INFO     #relation: 83
2023-04-07 09:08:43,473 INFO     #train: 9651042
2023-04-07 09:08:43,475 INFO     #valid: 537
2023-04-07 09:08:43,476 INFO     #test: 537
2023-04-07 09:08:44,348 INFO     Model Parameter Configuration:
2023-04-07 09:08:44,348 INFO     Parameter gamma: torch.Size([1]), require_grad = False
2023-04-07 09:08:44,349 INFO     Parameter embedding_range: torch.Size([1]), require_grad = False
2023-04-07 09:08:44,349 INFO     Parameter entity_embedding: torch.Size([249605, 250]), require_grad = True
2023-04-07 09:08:44,349 INFO     Parameter relation_embedding: torch.Size([83, 250]), require_grad = True
2023-04-07 09:08:44,349 INFO     Loading checkpoint models/ComplEx_MIND_CtD_megha...
2023-04-07 09:08:55,102 INFO     Start Training...
2023-04-07 09:08:55,103 INFO     init_step = 999999
2023-04-

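## Create the score input in tail-batch mode
* Note: `ProcessOutput` does not yet drop rows from the other mode. Ideally, choosing "tail-batch" would remove all "head-batch" rows, and choosing "head-batch" would remove all "tail-batch" rows. As the output below shows, both modes are still present, so the tail-batch rows are selected later with `query()`.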
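A plain-Python sketch of what such a mode filter might look like. `filter_by_mode` is a hypothetical helper, not part of `score_utils`; it assumes each row carries a `batch` field as in the output below. On a DataFrame the equivalent is `df.query('batch=="tail-batch"')`, which this notebook applies later.

```python
from typing import Dict, List

VALID_MODES = ("head-batch", "tail-batch")


def filter_by_mode(rows: List[Dict[str, object]], mode: str) -> List[Dict[str, object]]:
    """Keep only the rows whose 'batch' field matches `mode` (hypothetical helper)."""
    if mode not in VALID_MODES:
        raise ValueError(f"mode must be one of {VALID_MODES}, got {mode!r}")
    return [row for row in rows if row["batch"] == mode]


# Two rows taken from the example output (embedding IDs, before translation)
rows = [
    {"h": 71951, "r": 61, "t": 183664, "batch": "head-batch"},
    {"h": 196421, "r": 61, "t": 125397, "batch": "tail-batch"},
]
print(filter_by_mode(rows, "tail-batch"))
```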

In [4]:
raw = scu.ProcessOutput(data_dir = './data/MIND_CtD/', scores_outfile = './models/ComplEx_MIND_CtD_megha/test_scores.tsv', mode = 'tail-batch')

In [5]:
raw.df.head()

|   | h | r | t | preds | batch |
|---|---|---|---|---|---|
| 0 | 71951 | 61 | 183664 | [-3.143256425857544, -3.401657819747925, -2.03... | head-batch |
| 1 | 184021 | 61 | 183664 | [-3.143256425857544, -3.401657819747925, -2.03... | head-batch |
| 2 | 117007 | 61 | 27517 | [-5.423618316650391, -2.1479499340057373, -3.9... | head-batch |
| 3 | 234163 | 61 | 26686 | [-6.002433776855469, -4.713563919067383, -6.56... | head-batch |
| 4 | 163877 | 61 | 46731 | [-5.292157173156738, -1.3759181499481201, -5.4... | head-batch |


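## Generate the true answers for tail-batch
* A true answer is any entity that appears as a tail "t" for a given head-relation ("h-r") combination in the graph
* The inverse can be done for head-batch: every head "h" that appears with a given relation-tail ("r-t") combination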
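A minimal plain-Python sketch of how true answers can be grouped from a list of triples. It is not the `get_true_targets()` implementation; the function names are hypothetical, and the triples are the edges behind the CHEBI:10033 and CHEBI:10034 rows in the table below. In pandas, a similar grouping could be written as `df.groupby(['h', 'r'])['t'].agg(list)`.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple

Triple = Tuple[str, str, str]
Key = Tuple[str, str]


def _group(pairs: Iterable[Tuple[Key, str]]) -> Dict[Key, List[str]]:
    """Group answers under their key, dropping duplicates but keeping first-seen order."""
    grouped: Dict[Key, Dict[str, None]] = defaultdict(dict)
    for key, answer in pairs:
        grouped[key][answer] = None  # dict keys act as an insertion-ordered set
    return {key: list(answers) for key, answers in grouped.items()}


def true_tails(triples: Iterable[Triple]) -> Dict[Key, List[str]]:
    """tail-batch: every tail t seen with each (h, r) pair."""
    return _group(((h, r), t) for h, r, t in triples)


def true_heads(triples: Iterable[Triple]) -> Dict[Key, List[str]]:
    """head-batch (the inverse): every head h seen with each (r, t) pair."""
    return _group(((r, t), h) for h, r, t in triples)


triples = [
    ("CHEBI:10033", "indication", "DOID:0060903"),
    ("CHEBI:10033", "indication", "HP:0001907"),
    ("CHEBI:10034", "indication", "HP:0001907"),
]
tails = true_tails(triples)
heads = true_heads(triples)
print(tails[("CHEBI:10033", "indication")])  # → ['DOID:0060903', 'HP:0001907']
print(heads[("indication", "HP:0001907")])   # → ['CHEBI:10033', 'CHEBI:10034']
```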

In [6]:
raw.get_true_targets()

|   | h | r | t |
|---|---|---|---|
| 37 | CHEBI:10023 | indication | [DOID:0050289, MESH:D055744, MESH:C536777, DOI... |
| 48 | CHEBI:100241 | indication | [DOID:12385, DOID:13258, DOID:13622, DOID:1465... |
| 60 | CHEBI:100246 | indication | [DOID:1679, DOID:9700, HP:0002740, WD:Q1563040... |
| 72 | CHEBI:10033 | indication | [DOID:0060903, HP:0001907] |
| 84 | CHEBI:10034 | indication | [HP:0001907] |
| ... | ... | ... | ... |
| 469139 | UNII:T4H8FMA7IM | indication | [DOID:0111157] |
| 469160 | UNII:TT6HN20MVF | indication | [DOID:6193] |
| 469185 | UNII:UMD07X179E | indication | [DOID:0060189, DOID:0060190, DOID:0060192, DOI... |
| 469241 | UNII:XRO4566Q4R | indication | [DOID:2377, DOID:2378, DOID:0050783, DOID:0050... |


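## Format the raw scores as ranked embedded values
* The initial scores dataframe holds raw model scores, which can be any real value, negative or positive.
* The method uses torch's `argsort()` to sort the entities from highest to lowest score, so each `preds` list becomes entity IDs in rank order: the highest-scoring entity comes first, the next highest second, and so on down to the n-th.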
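A minimal sketch of the descending argsort on one row of toy scores, in plain Python so it runs without torch. On a score tensor the torch equivalent would be `torch.argsort(scores, dim=-1, descending=True)`; the function name `rank_entities` is hypothetical.

```python
from typing import List, Sequence


def rank_entities(scores: Sequence[float]) -> List[int]:
    """Return entity IDs (list positions) ordered from highest to lowest score."""
    # sorted() is stable, so tied scores keep their original ID order
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)


toy_scores = [0.5, -1.2, 2.3, 0.0]  # raw scores for entity IDs 0..3
print(rank_entities(toy_scores))  # → [2, 0, 3, 1]
```

Entity 2 has the highest score, so it comes first; the output holds IDs, not scores, which is why the `preds` column below contains integers such as 168134.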

In [7]:
# the result is applied in place
raw.format_raw_scores_to_df()

|   | h | r | t | preds | batch |
|---|---|---|---|---|---|
| 0 | 71951 | 61 | 183664 | [168134, 247187, 249192, 82959, 140052, 248163... | head-batch |
| 1 | 184021 | 61 | 183664 | [168134, 247187, 249192, 82959, 140052, 248163... | head-batch |
| 2 | 117007 | 61 | 27517 | [14332, 145700, 136455, 6140, 174513, 97710, 7... | head-batch |
| 3 | 234163 | 61 | 26686 | [56096, 157, 108783, 193718, 74279, 126159, 12... | head-batch |
| 4 | 163877 | 61 | 46731 | [234694, 106298, 133799, 23811, 20668, 108145,... | head-batch |
| ... | ... | ... | ... | ... | ... |
| 1069 | 196421 | 61 | 125397 | [145230, 160149, 30319, 187208, 46805, 117921,... | tail-batch |
| 1070 | 42491 | 61 | 125397 | [11280, 206388, 107924, 66023, 39715, 120305, ... | tail-batch |
| 1071 | 136067 | 61 | 219821 | [12096, 16633, 36539, 55259, 91361, 13869, 319... | tail-batch |
| 1072 | 26499 | 61 | 125397 | [187236, 31431, 204355, 123043, 158966, 27279,... | tail-batch |


## Now that we have our embedded values, can we get the actual names?
* The conversion from embedding IDs to names is done in place
* Note that the method takes a `direction` argument, which can be "from" or "to". The default is "to", meaning name TO embedding ID; "from" converts embedding IDs back to names, as used here.
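A minimal sketch of a two-way ID/name lookup, assuming the dictionary files store one `id<TAB>name` pair per line (an assumption about the data format, not confirmed here). `load_id_map` and `translate` are hypothetical names; the two pairs come from the first row before and after translation (71951 → CHEBI:135735, 183664 → DOID:10763).

```python
from typing import Dict, Iterable, List, Tuple


def load_id_map(lines: Iterable[str]) -> Tuple[Dict[int, str], Dict[str, int]]:
    """Parse assumed 'id<TAB>name' lines into id->name and name->id dictionaries."""
    id2name: Dict[int, str] = {}
    name2id: Dict[str, int] = {}
    for line in lines:
        idx, name = line.rstrip("\n").split("\t")
        id2name[int(idx)] = name
        name2id[name] = int(idx)
    return id2name, name2id


def translate(values: List, direction: str,
              id2name: Dict[int, str], name2id: Dict[str, int]) -> List:
    """'from': embedding IDs -> names; 'to': names -> embedding IDs."""
    if direction == "from":
        return [id2name[v] for v in values]
    if direction == "to":
        return [name2id[v] for v in values]
    raise ValueError(f"direction must be 'from' or 'to', got {direction!r}")


# A real dictionary file would list every entity; two are enough to illustrate
id2name, name2id = load_id_map(["71951\tCHEBI:135735\n", "183664\tDOID:10763\n"])
print(translate([71951, 183664], "from", id2name, name2id))  # → ['CHEBI:135735', 'DOID:10763']
print(translate(["DOID:10763"], "to", id2name, name2id))     # → [183664]
```

The same lookup would apply to relations (for example, relation 61 becomes `indication` in the output below).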

In [8]:
# in place
raw.translate_embeddings(direction = "from")

|   | h | r | t | preds | batch |
|---|---|---|---|---|---|
| 0 | CHEBI:135735 | indication | DOID:10763 | [CHEBI:95129, CHEBI:77655, MESH:C500528, CHEBI... | head-batch |
| 1 | CHEBI:135738 | indication | DOID:10763 | [CHEBI:95129, CHEBI:77655, MESH:C500528, CHEBI... | head-batch |
| 2 | CHEBI:135876 | indication | DOID:11054 | [MESH:D043823, CHEBI:3962, CHEBI:59809, CHEBI:... | head-batch |
| 3 | CHEBI:135923 | indication | DOID:14499 | [NCBIGene:7371, KEGG:hsa04914, CHEBI:48811, GO... | head-batch |
| 4 | CHEBI:135925 | indication | DOID:1094 | [CHEBI:39462, CHEBI:4636, CHEBI:86990, CHEBI:8... | head-batch |
| ... | ... | ... | ... | ... | ... |
| 1069 | CHEBI:9667 | indication | NCBIGene:367 | [DOID:184, DOID:12306, DOID:813, DOID:1526, DO... | tail-batch |
| 1070 | CHEBI:41423 | indication | NCBIGene:367 | [DOID:205, DOID:14284, DOID:14276, DOID:12662,... | tail-batch |
| 1071 | CHEBI:9168 | indication | NCBIGene:7490 | [DOID:5427, DOID:9261, DOID:6727, DOID:0060463... | tail-batch |
| 1072 | CHEBI:41879 | indication | NCBIGene:367 | [DOID:8577, DOID:1612, DOID:3459, MESH:D015452... | tail-batch |


In [9]:
tail_batch = raw.calculate_mrr()

In [10]:
tail_batch.head()

|   | h | r | target | preds | batch | true_t | mrr |
|---|---|---|---|---|---|---|---|
| 0 | CHEBI:135735 | indication | DOID:10763 | [CHEBI:95129, CHEBI:77655, MESH:C500528, CHEBI... | head-batch | [DOID:10591, DOID:10824, DOID:10825, DOID:1113... | 0.000508 |
| 1 | CHEBI:135738 | indication | DOID:10763 | [CHEBI:95129, CHEBI:77655, MESH:C500528, CHEBI... | head-batch | [DOID:10591, DOID:10824, DOID:10825, DOID:1113... | 0.000508 |
| 2 | CHEBI:135876 | indication | DOID:11054 | [MESH:D043823, CHEBI:3962, CHEBI:59809, CHEBI:... | head-batch | [DOID:11593, DOID:11811, DOID:11812, DOID:1181... | 7.8e-05 |
| 3 | CHEBI:135923 | indication | DOID:14499 | [NCBIGene:7371, KEGG:hsa04914, CHEBI:48811, GO... | head-batch | [DOID:14499] | 0.02439 |
| 4 | CHEBI:135925 | indication | DOID:1094 | [CHEBI:39462, CHEBI:4636, CHEBI:86990, CHEBI:8... | head-batch | [MESH:D056912, DOID:1094] | 0.000686 |
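The `mrr` column appears to hold a per-row reciprocal rank (for example, 0.02439 is 1/41), and the mean reciprocal rank (MRR) is the average of those values; on this frame, `tail_batch['mrr'].mean()` would give it. Below is a minimal sketch of raw (unfiltered) reciprocal ranks. Whether `calculate_mrr()` also filters out other true targets ranked above the target is not shown here, so treat that as an open assumption. The two toy rows are truncated tail-batch prediction lists from this notebook's hits table (target ranks 1 and 2).

```python
from typing import List, Tuple


def reciprocal_rank(preds: List[str], target: str) -> float:
    """1 / (1-based position of target in the ranked predictions); 0.0 if absent."""
    try:
        return 1.0 / (preds.index(target) + 1)
    except ValueError:
        return 0.0


def mean_reciprocal_rank(rows: List[Tuple[List[str], str]]) -> float:
    """Average the per-row reciprocal ranks over all (preds, target) rows."""
    return sum(reciprocal_rank(preds, target) for preds, target in rows) / len(rows)


rows = [
    (["DOID:1094", "DOID:8670", "MESH:D056912"], "DOID:1094"),      # rank 1
    (["WD:Q25303605", "DOID:10763", "DOID:10824"], "DOID:10763"),   # rank 2
]
print(mean_reciprocal_rank(rows))  # → 0.75
```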


In [11]:
tail_batch = tail_batch.query('batch=="tail-batch"')

In [12]:
tail_batch.shape

(537, 7)

In [13]:
# Get hits at K
tb_hits = raw.calculate_individual_hits_k(hits = [1,3,10]).query('batch=="tail-batch"')
tb_hits.head()

|   | h | r | target | preds | batch | true_t | position | ind_rank | hits_1 | hits_3 | hits_10 |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 537 | CHEBI:135735 | indication | DOID:10763 | [WD:Q25303605, DOID:10763, DOID:10824, DOID:64... | tail-batch | [DOID:10591, DOID:10824, DOID:10825, DOID:1113... | [0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, ... | 2 | False | True | True |
| 538 | CHEBI:135738 | indication | DOID:10763 | [WD:Q25303605, UMLS:C0221155, DOID:10763, NCBI... | tail-batch | [DOID:10591, DOID:10824, DOID:10825, DOID:1113... | [0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, ... | 3 | False | True | True |
| 539 | CHEBI:135876 | indication | DOID:11054 | [DOID:11813, DOID:5432, DOID:11820, DOID:11809... | tail-batch | [DOID:11593, DOID:11811, DOID:11812, DOID:1181... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | 17 | False | False | False |
| 540 | CHEBI:135923 | indication | DOID:14499 | [HP:0001945, DOID:9993, DOID:17, DOID:162, MES... | tail-batch | [DOID:14499] | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | 231 | False | False | False |
| 541 | CHEBI:135925 | indication | DOID:1094 | [DOID:1094, DOID:8670, MESH:D056912, DOID:8986... | tail-batch | [MESH:D056912, DOID:1094] | [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | 1 | True | True | True |


In [14]:
# calculate overall hits at K (the mean of a boolean column is the fraction of True values)
for k in (1, 3, 10):
    print(f"hits@{k}: {tb_hits[f'hits_{k}'].mean():.4f}")

hits@1: 0.0428
hits@3: 0.1136
hits@10: 0.2570
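Hits@k is the fraction of test triples whose target rank is at most k; the per-row `hits_k` columns above are consistent with comparing `ind_rank` against k (rank 2 gives `hits_1` False and `hits_3` True). A minimal plain-Python sketch, with the hypothetical name `hits_at_k`, applied to the five `ind_rank` values shown in the table above. These five rows are only an illustration; the figures printed above cover all 537 tail-batch test triples.

```python
from typing import Dict, Iterable, Sequence


def hits_at_k(ranks: Sequence[int], ks: Iterable[int] = (1, 3, 10)) -> Dict[int, float]:
    """For each k, the fraction of 1-based ranks that are <= k."""
    n = len(ranks)
    return {k: sum(rank <= k for rank in ranks) / n for k in ks}


print(hits_at_k([2, 3, 17, 231, 1]))  # → {1: 0.2, 3: 0.6, 10: 0.6}
```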



In [15]:
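# keep only the top 1000 ranked predictions per row before export;
# assign() returns a new frame, avoiding a SettingWithCopyWarning on the query() result
tb_hits = tb_hits.assign(preds=tb_hits.preds.apply(lambda x: x[:1000]))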

In [16]:
tb_hits.shape

(537, 11)

In [17]:
os.getcwd()

'/home/rogertu/projects/KnowledgeGraphEmbedding'

In [18]:
# written to the current working directory shown above
tb_hits.to_csv('top1000preds_complex.tsv', sep='\t', header=True, index=False)

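## Generate the top-n filtered results
* Filtering removes entities already known to be true targets (`true_t`) from `preds`, leaving the new candidates in `filt_preds`; `top` sets n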

In [49]:
raw.filter_predictions(top = 100).query('batch=="tail-batch"').head()

|   | h | r | preds | batch | true_t | filt_preds |
|---|---|---|---|---|---|---|
| 527 | CHEBI:34385 | indication_CiD | [DOID:2377, CHEBI:34385, DOID:8869, DOID:2378,... | tail-batch | [DOID:2378, DOID:2377, DOID:0050784, DOID:0050... | [CHEBI:34385, DOID:8869, DOID:3393, DOID:1824,... |
| 528 | CHEBI:34829 | indication_CiD | [DOID:2474, DOID:2457, DOID:11204, DOID:13452,... | tail-batch | [DOID:11204, DOID:2457, DOID:2474] | [DOID:13452, CHEBI:34829, DOID:8881, DOID:9383... |
| 529 | CHEBI:3441 | indication_CiD | [DOID:3393, DOID:10763, DOID:4248, DOID:6000, ... | tail-batch | [DOID:10591, DOID:6432, DOID:11130, DOID:10824... | [DOID:3393, DOID:4248, DOID:6000, DOID:5844, D... |
| 530 | CHEBI:34730 | indication_CiD | [DOID:11729, DOID:3482, DOID:0040083, DOID:004... | tail-batch | [DOID:11104, DOID:13034, DOID:13035, DOID:1327... | [DOID:11729, DOID:3482, DOID:0040083, DOID:004... |
| 531 | CHEBI:3437 | indication_CiD | [DOID:916, MESH:D056486, MONDO:0005043, MESH:D... | tail-batch | [DOID:6432, DOID:13544, DOID:11130, DOID:10825... | [DOID:916, MESH:D056486, MONDO:0005043, MESH:D... |
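Judging from the output, `filt_preds` is `preds` with the known true targets dropped. A plain-Python sketch of that step, under two assumptions not confirmed by the source: known targets are removed before the `top` cutoff is applied, and the remaining predictions keep their rank order. `filter_known` is a hypothetical name, not the `score_utils` implementation. The example uses the truncated CHEBI:34829 row above, where DOID:13452 is the first prediction that is not already a known target.

```python
from typing import Iterable, List


def filter_known(preds: List[str], known: Iterable[str], top: int) -> List[str]:
    """Drop predictions already among the known true targets, then keep the first `top`."""
    known_set = set(known)  # set lookups keep the filter linear in len(preds)
    return [p for p in preds if p not in known_set][:top]


preds = ["DOID:2474", "DOID:2457", "DOID:11204", "DOID:13452"]
true_t = ["DOID:11204", "DOID:2457", "DOID:2474"]
print(filter_known(preds, true_t, top=100))  # → ['DOID:13452']
```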
