# Results analysis of the RotatE model on MIND-CtD
* The goal of this notebook is to provide a sample analysis of the model's results
* Along the way, we highlight how to use some of the functions in `score_utils`
* Finally, we calculate the MRR and Hits@k, and extract the filtered top-k results
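
As a reference point for the metrics named above, here is a minimal sketch of MRR and Hits@k computed from a list of ranks. It assumes each test triple has already been reduced to the 1-based rank of its correct answer; the function names and the example ranks are illustrative and are not part of `score_utils`.

```python
def mean_reciprocal_rank(ranks):
    """Average of 1/rank over all queries; ranks are 1-based."""
    return sum(1.0 / r for r in ranks) / len(ranks)


def hits_at_k(ranks, k):
    """Fraction of queries whose correct answer is ranked in the top k."""
    return sum(r <= k for r in ranks) / len(ranks)


# Illustrative ranks only, not taken from the model output.
example_ranks = [1, 2, 4, 10]

mrr = mean_reciprocal_rank(example_ranks)
hits = {k: hits_at_k(example_ranks, k) for k in (1, 3, 10)}
print(f"MRR: {mrr:.4f}")
print(hits)
```

Both metrics depend only on where the correct answer lands in the sorted predictions, which is why the later cells first turn raw scores into an ordering.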

In [1]:
import os
import pandas as pd
import score_utils as scu
os.chdir('..')

In [2]:
!bash run.sh train RotatE MIND_CtD 0 megha 244 108 250 48.0 1.0 0.0001 1000000 16 -de

1.12.1+cu102
Start Training......
2023-04-03 15:24:05,327 INFO     Model: RotatE
2023-04-03 15:24:05,327 INFO     Data Path: data/MIND_CtD
2023-04-03 15:24:05,327 INFO     #entity: 249605
2023-04-03 15:24:05,327 INFO     #relation: 83
2023-04-03 15:24:10,443 INFO     #train: 9651042
2023-04-03 15:24:10,445 INFO     #valid: 537
2023-04-03 15:24:10,445 INFO     #test: 537
2023-04-03 15:24:11,355 INFO     Model Parameter Configuration:
2023-04-03 15:24:11,355 INFO     Parameter gamma: torch.Size([1]), require_grad = False
2023-04-03 15:24:11,356 INFO     Parameter embedding_range: torch.Size([1]), require_grad = False
2023-04-03 15:24:11,356 INFO     Parameter entity_embedding: torch.Size([249605, 500]), require_grad = True
2023-04-03 15:24:11,356 INFO     Parameter relation_embedding: torch.Size([83, 250]), require_grad = True
2023-04-03 15:25:01,884 INFO     Ramdomly Initializing RotatE Model...
2023-04-03 15:25:01,885 INFO     Start Training...
2023-04-03 15:25:01,885 INFO     init_ste

#### Get predictions for the test file
* Run the predictions on `test.txt`. Results should be exported when the `--do_predict` flag is set.
* With both the `--do_test` and `--do_predict` flags set, the output file is `test_scores.tsv`.

In [3]:
!python -u codes/run.py --do_predict --do_test -init models/RotatE_MIND_CtD_megha #--cuda

2023-04-04 09:31:36,975 INFO     Model: RotatE
2023-04-04 09:31:36,975 INFO     Data Path: data/MIND_CtD
2023-04-04 09:31:36,975 INFO     #entity: 249605
2023-04-04 09:31:36,975 INFO     #relation: 83
2023-04-04 09:31:42,069 INFO     #train: 9651042
2023-04-04 09:31:42,071 INFO     #valid: 537
2023-04-04 09:31:42,071 INFO     #test: 537
2023-04-04 09:31:43,002 INFO     Model Parameter Configuration:
2023-04-04 09:31:43,003 INFO     Parameter gamma: torch.Size([1]), require_grad = False
2023-04-04 09:31:43,003 INFO     Parameter embedding_range: torch.Size([1]), require_grad = False
2023-04-04 09:31:43,003 INFO     Parameter entity_embedding: torch.Size([249605, 500]), require_grad = True
2023-04-04 09:31:43,003 INFO     Parameter relation_embedding: torch.Size([83, 250]), require_grad = True
2023-04-04 09:31:43,003 INFO     Loading checkpoint models/RotatE_MIND_CtD_megha...
2023-04-04 09:31:59,026 INFO     Start Training...
2023-04-04 09:31:59,026 INFO     init_step = 999999
2023-04-04

## Create the score input as tail-batch
* I should have written the function to remove all "head-batch" entries when the mode is "tail-batch", and all "tail-batch" entries when the mode is "head-batch". Since it does not, the later cells filter on `batch` explicitly.
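
The filtering described above could be sketched as a small helper. `keep_batch` is a hypothetical name, not a function in `score_utils`, and the toy frame only mimics the `batch` column of `raw.df`.

```python
import pandas as pd


def keep_batch(df: pd.DataFrame, mode: str) -> pd.DataFrame:
    """Return a copy holding only the rows whose `batch` equals `mode`."""
    if mode not in ("head-batch", "tail-batch"):
        raise ValueError(f"unknown mode: {mode!r}")
    return df[df["batch"] == mode].copy()


# Toy data with the same column layout as the scores dataframe.
toy = pd.DataFrame(
    {
        "h": [1, 2, 1, 2],
        "r": [61, 61, 61, 61],
        "t": [7, 8, 7, 8],
        "batch": ["head-batch", "head-batch", "tail-batch", "tail-batch"],
    }
)

tail_only = keep_batch(toy, "tail-batch")
print(tail_only)
```

Returning a copy means later column assignments on the filtered frame do not touch the original.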

In [4]:
raw = scu.ProcessOutput(data_dir = './data/MIND_CtD/', scores_outfile = './models/RotatE_MIND_CtD_megha/test_scores.tsv', mode = 'tail-batch')

In [5]:
raw.df.head()

Unnamed: 0,h,r,t,preds,batch
0,71951,61,183664,"[-8.605365753173828, -6.895137786865234, -8.77...",head-batch
1,184021,61,183664,"[-8.605365753173828, -6.895137786865234, -8.77...",head-batch
2,117007,61,27517,"[-8.586658477783203, -5.259132385253906, -9.13...",head-batch
3,234163,61,26686,"[-9.565330505371094, -3.085784912109375, -11.0...",head-batch
4,163877,61,46731,"[-11.458580017089844, -5.571342468261719, -11....",head-batch


## Generate the true answers for tail-batch
* A true answer is any entity that appears as a "t" for a given "h-r" combination in the graph
* The inverse can be done for head-batch
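
The grouping described above can be sketched with a pandas `groupby`. This is an illustration on toy triples with made-up identifiers, and it is not necessarily how `get_true_targets()` is implemented.

```python
import pandas as pd

# Toy triples; identifiers are placeholders, not entries from MIND.
triples = pd.DataFrame(
    [
        ("drug:A", "treats", "disease:X"),
        ("drug:A", "treats", "disease:Y"),
        ("drug:A", "binds", "gene:G"),
        ("drug:B", "treats", "disease:X"),
    ],
    columns=["h", "r", "t"],
)

# For tail-batch: collect every tail seen for each (head, relation) pair.
true_targets = triples.groupby(["h", "r"])["t"].agg(list).reset_index()
print(true_targets)

# For head-batch the same idea applies with the roles swapped.
true_heads = triples.groupby(["r", "t"])["h"].agg(list).reset_index()
```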

In [6]:
raw.get_true_targets()

Unnamed: 0,h,r,t
0,CHEBI:100,activates_CaG,[NCBIGene:2100]
1,CHEBI:100,affects_CafG,[NCBIGene:2100]
2,CHEBI:100,part_of_CpoBP,"[GO:0009701, GO:0046289, GO:0046290]"
3,CHEBI:100,treats_CtD,"[DOID:883, DOID:11476, DOID:8692]"
4,CHEBI:10001,palliates_CplD,[DOID:11968]
...,...,...,...
612847,WP:WP75,associated_with_PWawD,"[DOID:0070315, DOID:0080169, DOID:11950, DOID:..."
612848,WP:WP75,associated_with_PWawP,[HP:0010774]
612849,WP:WP78,associated_with_PWawD,"[DOID:0050575, DOID:0050771, DOID:0050773, DOI..."
612850,WP:WP80,associated_with_PWawD,[DOID:0110705]


## Format the raw scores to embedded values
* The initial scores dataframe holds raw values that range from negative to positive.
* Uses the torch function `argsort()` to sort from high to low: the highest score is ranked 1, the next highest 2, and so on down to the n-th highest.
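
A small sketch of the sorting step, using NumPy as a stand-in so it runs without torch (`torch.argsort(scores, descending=True)` is the torch counterpart). The scores are made up. It also shows the difference between the sorted indices, which is what the `preds` column appears to hold after formatting, and the rank of each candidate.

```python
import numpy as np

# Made-up scores for four candidate entities (index = embedding id).
scores = np.array([-8.6, -6.9, -8.8, -3.1])

# Indices of the candidates ordered from highest to lowest score.
order = np.argsort(-scores)

# Rank of each candidate: 1 for the highest score, n for the lowest.
ranks = np.empty_like(order)
ranks[order] = np.arange(1, len(scores) + 1)

print(order.tolist())
print(ranks.tolist())
```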

In [7]:
# the result is stored in place
raw.format_raw_scores_to_df()

Unnamed: 0,h,r,t,preds,batch
0,71951,61,183664,"[35633, 59815, 227221, 141546, 224330, 128485,...",head-batch
1,184021,61,183664,"[35633, 59815, 227221, 141546, 224330, 128485,...",head-batch
2,117007,61,27517,"[42404, 90983, 86089, 99426, 136492, 136972, 7...",head-batch
3,234163,61,26686,"[99705, 185976, 179500, 173413, 226408, 182300...",head-batch
4,163877,61,46731,"[156059, 27283, 128485, 58030, 124226, 98284, ...",head-batch
...,...,...,...,...,...
1069,196421,61,125397,"[178458, 127022, 173467, 218398, 90254, 215027...",tail-batch
1070,42491,61,125397,"[71476, 96052, 201633, 45059, 43223, 123436, 2...",tail-batch
1071,136067,61,219821,"[229633, 63762, 218968, 113398, 233910, 106190...",tail-batch
1072,26499,61,125397,"[51624, 58633, 2930, 15192, 200745, 199198, 84...",tail-batch


## Now that we have our embedded values, can we get the actual names?
* The conversion from embeddings to values is done "in place".
* Note that the method takes a `direction` argument, which can be "from" or "to". The default is "to", meaning value TO embedding.
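
The two directions can be pictured as a pair of lookup tables. The sketch below is an assumption about the idea, not the body of `translate_embeddings`; the `translate` helper and the mapping entries are hypothetical.

```python
# Hypothetical mapping; in practice it would come from the dataset files.
id_to_name = {0: "drug:A", 1: "disease:X", 2: "disease:Y"}
name_to_id = {name: idx for idx, name in id_to_name.items()}


def translate(values, direction="to"):
    """Map names TO embedding ids, or embedding ids back FROM ids to names."""
    if direction == "to":
        table = name_to_id
    elif direction == "from":
        table = id_to_name
    else:
        raise ValueError(f"direction must be 'to' or 'from', got {direction!r}")
    return [table[v] for v in values]


names = translate([2, 0], direction="from")
ids = translate(names, direction="to")
print(names, ids)
```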

In [8]:
# in place
raw.translate_embeddings(direction = "from")

Unnamed: 0,h,r,t,preds,batch
0,CHEBI:135735,indication,DOID:10763,"[CHEBI:31556, MESH:D010936, CHEBI:8229, MESH:D...",head-batch
1,CHEBI:135738,indication,DOID:10763,"[CHEBI:31556, MESH:D010936, CHEBI:8229, MESH:D...",head-batch
2,CHEBI:135876,indication,DOID:11054,"[CHEBI:31522, CHEBI:30616, CHEBI:50172, NCBITa...",head-batch
3,CHEBI:135923,indication,DOID:14499,"[GO:0003676, GO:0005525, GO:0003723, GO:000028...",head-batch
4,CHEBI:135925,indication,DOID:1094,"[MESH:C544151, CHEBI:33216, CHEBI:29081, CHEBI...",head-batch
...,...,...,...,...,...
1069,CHEBI:9667,indication,NCBIGene:367,"[DOID:1205, DOID:0060056, DOID:50185, DOID:605...",tail-batch
1070,CHEBI:41423,indication,NCBIGene:367,"[MESH:D056486, DOID:3042, DOID:37, KEGG:hsa052...",tail-batch
1071,CHEBI:9168,indication,NCBIGene:7490,"[DOID:60164, DOID:4556, MESH:D002471, DOID:368...",tail-batch
1072,CHEBI:41879,indication,NCBIGene:367,"[DOID:706, DOID:2841, DOID:1555, DOID:1686, DO...",tail-batch


In [9]:
tail_batch = raw.calculate_mrr()

In [10]:
tail_batch.head()

Unnamed: 0,h,r,target,preds,batch,true_t,mrr
0,CHEBI:135735,indication,DOID:10763,"[CHEBI:31556, MESH:D010936, CHEBI:8229, MESH:D...",head-batch,"[DOID:10591, DOID:10824, DOID:10825, DOID:1113...",0.01721
1,CHEBI:135738,indication,DOID:10763,"[CHEBI:31556, MESH:D010936, CHEBI:8229, MESH:D...",head-batch,"[DOID:10591, DOID:10824, DOID:10825, DOID:1113...",0.017158
2,CHEBI:135876,indication,DOID:11054,"[CHEBI:31522, CHEBI:30616, CHEBI:50172, NCBITa...",head-batch,"[DOID:11593, DOID:11811, DOID:11812, DOID:1181...",0.006566
3,CHEBI:135923,indication,DOID:14499,"[GO:0003676, GO:0005525, GO:0003723, GO:000028...",head-batch,[DOID:14499],0.018868
4,CHEBI:135925,indication,DOID:1094,"[MESH:C544151, CHEBI:33216, CHEBI:29081, CHEBI...",head-batch,"[MESH:D056912, DOID:1094]",0.010904


In [11]:
tail_batch = tail_batch.query('batch=="tail-batch"')

In [12]:
tail_batch.shape

(537, 7)

In [13]:
# Get hits at K
tb_hits = raw.calculate_individual_hits_k(hits = [1,3,10]).query('batch=="tail-batch"')
tb_hits.head()

Unnamed: 0,h,r,target,preds,batch,true_t,position,ind_rank,hits_1,hits_3,hits_10
537,CHEBI:135735,indication,DOID:10763,"[WD:Q25303605, UMLS:C0221155, DOID:10763, DOID...",tail-batch,"[DOID:10591, DOID:10824, DOID:10825, DOID:1113...","[0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, ...",3,False,True,True
538,CHEBI:135738,indication,DOID:10763,"[DOID:10763, DOID:60164, DOID:11130, DOID:6432...",tail-batch,"[DOID:10591, DOID:10824, DOID:10825, DOID:1113...","[1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, ...",1,True,True,True
539,CHEBI:135876,indication,DOID:11054,"[DOID:11817, DOID:5432, DOID:11820, DOID:5429,...",tail-batch,"[DOID:11593, DOID:11811, DOID:11812, DOID:1181...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",16,False,False,False
540,CHEBI:135923,indication,DOID:14499,"[DOID:11476, DOID:10763, DOID:0060343, MESH:D0...",tail-batch,[DOID:14499],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",1842,False,False,False
541,CHEBI:135925,indication,DOID:1094,"[DOID:12858, DOID:10935, DOID:10825, HP:000218...",tail-batch,"[MESH:D056912, DOID:1094]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",778,False,False,False
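
One plausible reading of the `position`, `ind_rank` and `hits_k` columns above is sketched below. It assumes that `position` flags every prediction that is a known true answer and that `ind_rank` is the 1-based index of the first flagged prediction; the rows shown are consistent with that reading, but `score_utils` may compute these columns differently. All names and values here are illustrative.

```python
def first_hit_rank(preds, true_targets):
    """Flag predictions that are true answers and find the first one."""
    truth = set(true_targets)
    position = [1 if p in truth else 0 for p in preds]
    rank = position.index(1) + 1 if 1 in position else None
    return position, rank


def hit_flags(rank, ks=(1, 3, 10)):
    """True for each k where a true answer appears within the top k."""
    return {f"hits_{k}": rank is not None and rank <= k for k in ks}


# Placeholder predictions and true answers.
preds = ["x", "y", "d1", "z", "d2"]
position, rank = first_hit_rank(preds, ["d1", "d2"])
flags = hit_flags(rank)
print(position, rank, flags)
```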


In [14]:
# calculate overall hits at K (the mean of a boolean column is the hit rate)
print(f"""hits@1: {tb_hits['hits_1'].mean():.4f}
hits@3: {tb_hits['hits_3'].mean():.4f}
hits@10: {tb_hits['hits_10'].mean():.4f}
""")

hits@1: 0.1322
hits@3: 0.2142
hits@10: 0.3426



In [15]:
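# keep only the top 1000 predictions per row; assign() returns a new frame
# instead of writing into the result of the earlier query()
tb_hits = tb_hits.assign(preds=tb_hits['preds'].apply(lambda x: x[:1000]))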

In [16]:
tb_hits.shape

(537, 11)

In [17]:
tb_hits.to_csv('top1000preds_rotate.tsv',sep='\t',header = True, index = False)

## Generate the top n filtered results

In [49]:
raw.filter_predictions(top = 100).query('batch=="tail-batch"').head()

Unnamed: 0,h,r,preds,batch,true_t,filt_preds
527,CHEBI:34385,indication_CiD,"[DOID:2377, CHEBI:34385, DOID:8869, DOID:2378,...",tail-batch,"[DOID:2378, DOID:2377, DOID:0050784, DOID:0050...","[CHEBI:34385, DOID:8869, DOID:3393, DOID:1824,..."
528,CHEBI:34829,indication_CiD,"[DOID:2474, DOID:2457, DOID:11204, DOID:13452,...",tail-batch,"[DOID:11204, DOID:2457, DOID:2474]","[DOID:13452, CHEBI:34829, DOID:8881, DOID:9383..."
529,CHEBI:3441,indication_CiD,"[DOID:3393, DOID:10763, DOID:4248, DOID:6000, ...",tail-batch,"[DOID:10591, DOID:6432, DOID:11130, DOID:10824...","[DOID:3393, DOID:4248, DOID:6000, DOID:5844, D..."
530,CHEBI:34730,indication_CiD,"[DOID:11729, DOID:3482, DOID:0040083, DOID:004...",tail-batch,"[DOID:11104, DOID:13034, DOID:13035, DOID:1327...","[DOID:11729, DOID:3482, DOID:0040083, DOID:004..."
531,CHEBI:3437,indication_CiD,"[DOID:916, MESH:D056486, MONDO:0005043, MESH:D...",tail-batch,"[DOID:6432, DOID:13544, DOID:11130, DOID:10825...","[DOID:916, MESH:D056486, MONDO:0005043, MESH:D..."
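
The `filt_preds` column above looks consistent with dropping every prediction that is already a known true answer and keeping the first `top` of what remains. A minimal sketch under that assumption follows; `filter_known` is a hypothetical helper, not the `filter_predictions` method itself.

```python
def filter_known(preds, known, top=100):
    """Drop predictions already known to be true, then keep the first `top`."""
    known = set(known)
    return [p for p in preds if p not in known][:top]


# Placeholder identifiers, ordered from best to worst score.
preds = ["d1", "c1", "d3", "d2", "d4"]
known_true = ["d2", "d1"]

filtered = filter_known(preds, known_true, top=2)
print(filtered)
```

Because the list comprehension preserves order, the filtered list keeps the model's ranking among the remaining candidates.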
