# Results analysis of the DistMult model on MIND-CtD
* The goal of this notebook is to create a sample analysis of the model's results
* We will also show how to use some of the functions in `score_utils`
* Finally, we will calculate the MRR and Hits@k, and extract the filtered top-k results

In [1]:
import os
import pandas as pd
import score_utils as scu
os.chdir('..')

In [2]:
!bash run.sh train DistMult MIND_CtD 0 megha 200 100 225 48.0 1.0 0.009 1000000 16

1.12.1+cu102
Start Training......
2023-04-06 09:53:03,893 INFO     Model: DistMult
2023-04-06 09:53:03,893 INFO     Data Path: data/MIND_CtD
2023-04-06 09:53:03,893 INFO     #entity: 249605
2023-04-06 09:53:03,894 INFO     #relation: 83
2023-04-06 09:53:11,355 INFO     #train: 9651042
2023-04-06 09:53:11,356 INFO     #valid: 537
2023-04-06 09:53:11,357 INFO     #test: 537
2023-04-06 09:53:12,128 INFO     Model Parameter Configuration:
2023-04-06 09:53:12,128 INFO     Parameter gamma: torch.Size([1]), require_grad = False
2023-04-06 09:53:12,128 INFO     Parameter embedding_range: torch.Size([1]), require_grad = False
2023-04-06 09:53:12,128 INFO     Parameter entity_embedding: torch.Size([249605, 225]), require_grad = True
2023-04-06 09:53:12,128 INFO     Parameter relation_embedding: torch.Size([83, 225]), require_grad = True
2023-04-06 09:54:21,520 INFO     Ramdomly Initializing DistMult Model...
2023-04-06 09:54:21,520 INFO     Start Training...
2023-04-06 09:54:21,521 INFO     init
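The `run.sh` call above passes thirteen positional arguments without names. Assuming this repository keeps the argument order of the upstream KnowledgeGraphEmbedding `run.sh` (an assumption; check the script itself), the sketch below labels them. The values line up with the log: `HIDDEN_DIM` = 225 matches the `[249605, 225]` entity embedding, and the assumed save-path pattern reproduces the `models/DistMult_MIND_CtD_megha` checkpoint loaded in the next step.

```python
# ASSUMPTION: argument order copied from the upstream KnowledgeGraphEmbedding run.sh;
# verify against this repository's run.sh before relying on it.
RUN_SH_ARGS = [
    "MODE", "MODEL", "DATASET", "GPU_DEVICE", "SAVE_ID",
    "BATCH_SIZE", "NEGATIVE_SAMPLE_SIZE", "HIDDEN_DIM", "GAMMA",
    "ALPHA", "LEARNING_RATE", "MAX_STEPS", "TEST_BATCH_SIZE",
]


def label_run_args(cmd: str) -> dict:
    """Hypothetical helper: pair each positional run.sh value with its assumed name."""
    tokens = cmd.split()  # split() also absorbs repeated spaces
    if tokens[:2] != ["bash", "run.sh"]:
        raise ValueError("expected a 'bash run.sh ...' command")
    return dict(zip(RUN_SH_ARGS, tokens[2:]))


args = label_run_args(
    "bash run.sh train DistMult MIND_CtD 0 megha 200 100 225 48.0 1.0 0.009 1000000 16"
)
for name, value in args.items():
    print(f"{name:>20}: {value}")

# Assumed save-path pattern: models/<MODEL>_<DATASET>_<SAVE_ID>
print("models/{MODEL}_{DATASET}_{SAVE_ID}".format(**args))  # → models/DistMult_MIND_CtD_megha
```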

## Get predictions for the test file
* Run the predictions on `test.txt`. Predictions are exported when the `--do_predict` flag is set.
* With both `--do_test` and `--do_predict`, the output file is `test_scores.tsv` in the model directory

In [3]:
!python -u codes/run.py --do_predict --do_test -init models/DistMult_MIND_CtD_megha #--cuda

2023-04-06 17:17:27,597 INFO     Model: DistMult
2023-04-06 17:17:27,597 INFO     Data Path: data/MIND_CtD
2023-04-06 17:17:27,597 INFO     #entity: 249605
2023-04-06 17:17:27,597 INFO     #relation: 83
2023-04-06 17:17:32,773 INFO     #train: 9651042
2023-04-06 17:17:32,775 INFO     #valid: 537
2023-04-06 17:17:32,776 INFO     #test: 537
2023-04-06 17:17:33,344 INFO     Model Parameter Configuration:
2023-04-06 17:17:33,344 INFO     Parameter gamma: torch.Size([1]), require_grad = False
2023-04-06 17:17:33,344 INFO     Parameter embedding_range: torch.Size([1]), require_grad = False
2023-04-06 17:17:33,344 INFO     Parameter entity_embedding: torch.Size([249605, 225]), require_grad = True
2023-04-06 17:17:33,344 INFO     Parameter relation_embedding: torch.Size([83, 225]), require_grad = True
2023-04-06 17:17:33,344 INFO     Loading checkpoint models/DistMult_MIND_CtD_megha...
2023-04-06 17:17:41,785 INFO     Start Training...
2023-04-06 17:17:41,785 INFO     init_step = 999999
2023-0

## Load the scores in tail-batch mode
* I should have written the function to drop all "head-batch" rows when the mode is "tail-batch", and all "tail-batch" rows when the mode is "head-batch". For now, both are kept and the unwanted rows are filtered out later.
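A minimal sketch of the mode filtering the note describes, written as a standalone pandas helper. `keep_mode` is a hypothetical name, not part of `score_utils`, and the toy frame only mimics the columns of `raw.df`. It does the same job as the `.query('batch=="tail-batch"')` calls used later in this notebook.

```python
import pandas as pd


def keep_mode(df: pd.DataFrame, mode: str) -> pd.DataFrame:
    """Hypothetical helper: keep only the rows scored in `mode`."""
    if mode not in ("head-batch", "tail-batch"):
        raise ValueError(f"unknown mode: {mode!r}")
    # .copy() so later column assignments do not act on a view of the original frame
    return df[df["batch"] == mode].copy()


# Toy frame with the same columns as raw.df (values invented for illustration)
toy = pd.DataFrame({
    "h": [1, 1, 2],
    "r": [0, 0, 0],
    "t": [5, 5, 6],
    "preds": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    "batch": ["head-batch", "tail-batch", "tail-batch"],
})
print(keep_mode(toy, "tail-batch"))
```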

In [4]:
raw = scu.ProcessOutput(data_dir = './data/MIND_CtD/', scores_outfile = './models/DistMult_MIND_CtD_megha/test_scores.tsv', mode = 'tail-batch')

In [5]:
raw.df.head()

Unnamed: 0,h,r,t,preds,batch
0,71951,61,183664,"[-16.166667938232422, -15.291213035583496, -18...",head-batch
1,184021,61,183664,"[-16.166667938232422, -15.291213035583496, -18...",head-batch
2,117007,61,27517,"[-20.326316833496094, -6.151558876037598, -26....",head-batch
3,234163,61,26686,"[-63.32097244262695, -60.7429084777832, -68.17...",head-batch
4,163877,61,46731,"[-17.171836853027344, -25.85724639892578, -30....",head-batch


## Generate the true answers for tail-batch
* The true answers for an (h, r) pair are every t that appears with that h and r in the graph
* The inverse (every h for a given (r, t) pair) gives the true answers for head-batch
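The grouping described above can be sketched with a pandas `groupby`. This is an illustration of the idea, not the actual implementation of `get_true_targets()`; the toy triples reuse a few identifiers from the output below but are a hand-picked subset.

```python
import pandas as pd


def true_tails(triples: pd.DataFrame) -> pd.DataFrame:
    """Tail-batch answers: every t observed for each (h, r) pair, collected into a list."""
    return triples.groupby(["h", "r"], sort=True)["t"].agg(list).reset_index()


def true_heads(triples: pd.DataFrame) -> pd.DataFrame:
    """Head-batch answers (the inverse): every h observed for each (r, t) pair."""
    return triples.groupby(["r", "t"], sort=True)["h"].agg(list).reset_index()


# Toy subset of graph triples
triples = pd.DataFrame({
    "h": ["CHEBI:100", "CHEBI:100", "CHEBI:100", "CHEBI:10001"],
    "r": ["treats_CtD", "treats_CtD", "activates_CaG", "palliates_CplD"],
    "t": ["DOID:883", "DOID:11476", "NCBIGene:2100", "DOID:11968"],
})
print(true_tails(triples))
```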

In [6]:
raw.get_true_targets()

Unnamed: 0,h,r,t
0,CHEBI:100,activates_CaG,[NCBIGene:2100]
1,CHEBI:100,affects_CafG,[NCBIGene:2100]
2,CHEBI:100,part_of_CpoBP,"[GO:0009701, GO:0046289, GO:0046290]"
3,CHEBI:100,treats_CtD,"[DOID:883, DOID:11476, DOID:8692]"
4,CHEBI:10001,palliates_CplD,[DOID:11968]
...,...,...,...
612847,WP:WP75,associated_with_PWawD,"[DOID:0070315, DOID:0080169, DOID:11950, DOID:..."
612848,WP:WP75,associated_with_PWawP,[HP:0010774]
612849,WP:WP78,associated_with_PWawD,"[DOID:0050575, DOID:0050771, DOID:0050773, DOI..."
612850,WP:WP80,associated_with_PWawD,[DOID:0110705]


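## Convert the raw scores to ranked entity IDs
* The initial scores dataframe holds one raw score per candidate entity; the scores are unbounded and can be negative or positive.
* `format_raw_scores_to_df()` uses torch's `argsort()` to sort the scores from high to low and replaces each score list with the entity IDs in that order: the first element is the ID of the highest-scoring entity, the second is the next highest, and so on.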
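A small sketch of the descending argsort on one toy score vector. NumPy stands in for torch so it runs without a GPU stack; in torch the same step would be `torch.argsort(scores, descending=True)`. The scores are invented, and the rank lookup at the end is an illustration of how a target's position can be read off the sorted IDs.

```python
import numpy as np

# Toy raw scores for one query: index = entity ID, value = model score
scores = np.array([-16.2, -15.3, -18.0, -6.2])

# Entity IDs ordered from highest to lowest score (descending argsort)
ranked_ids = np.argsort(-scores, kind="stable")
print(ranked_ids.tolist())  # → [3, 1, 0, 2]

# The 1-based rank of a target entity is its position in that list plus one
target = 0
rank = int(np.where(ranked_ids == target)[0][0]) + 1
print(rank)  # → 3
```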

In [7]:
# updates raw.df in place
raw.format_raw_scores_to_df()

Unnamed: 0,h,r,t,preds,batch
0,71951,61,183664,"[247187, 35633, 227056, 77526, 191050, 199450,...",head-batch
1,184021,61,183664,"[247187, 35633, 227056, 77526, 191050, 199450,...",head-batch
2,117007,61,27517,"[43365, 92437, 185115, 1792, 10640, 54887, 200...",head-batch
3,234163,61,26686,"[195629, 146253, 42118, 11108, 18304, 97634, 5...",head-batch
4,163877,61,46731,"[93283, 184692, 31387, 166433, 27280, 78770, 8...",head-batch
...,...,...,...,...,...
1069,196421,61,125397,"[72162, 183213, 3894, 217034, 65036, 206627, 1...",tail-batch
1070,42491,61,125397,"[103888, 144258, 118855, 132433, 83316, 237300...",tail-batch
1071,136067,61,219821,"[173455, 135066, 240535, 13869, 235715, 4847, ...",tail-batch
1072,26499,61,125397,"[1208, 172657, 36769, 227085, 31898, 2981, 118...",tail-batch


## Translate the entity IDs back to names
* The conversion between IDs and names is done in place.
* The method takes a `direction` argument, either "from" or "to". The default, "to", converts names to IDs; "from" converts IDs back to names, which is what we want here.
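A minimal sketch of the "from" direction, assuming the ID-to-name mappings come from tab-separated `id<TAB>name` lines, as in the upstream KnowledgeGraphEmbedding `entities.dict` and `relations.dict` files (an assumption about this dataset's layout). `load_id2name` and `ids_to_names` are hypothetical helpers, and unlike `translate_embeddings` they return a copy instead of working in place.

```python
import pandas as pd


def load_id2name(lines):
    """Parse 'id<TAB>name' lines (assumed entities.dict / relations.dict layout) into a dict."""
    id2name = {}
    for line in lines:
        idx, name = line.rstrip("\n").split("\t")
        id2name[int(idx)] = name
    return id2name


def ids_to_names(df, ent_map, rel_map):
    """Return a copy of df with h, r, t and every ID inside preds replaced by its name."""
    out = df.copy()
    out["h"] = out["h"].map(ent_map)
    out["t"] = out["t"].map(ent_map)
    out["r"] = out["r"].map(rel_map)
    out["preds"] = out["preds"].apply(lambda ids: [ent_map[i] for i in ids])
    return out


# Toy mappings and a one-row frame (IDs invented for illustration)
ent_map = load_id2name(["0\tCHEBI:135735", "1\tDOID:10763", "2\tCHEBI:77655"])
rel_map = load_id2name(["0\tindication"])
df = pd.DataFrame({"h": [0], "r": [0], "t": [1], "preds": [[2, 0, 1]], "batch": ["head-batch"]})
print(ids_to_names(df, ent_map, rel_map))
```

The "to" direction only needs the inverted dictionary, for example `{name: idx for idx, name in ent_map.items()}`.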

In [8]:
# in place
raw.translate_embeddings(direction = "from")

Unnamed: 0,h,r,t,preds,batch
0,CHEBI:135735,indication,DOID:10763,"[CHEBI:77655, CHEBI:31556, IKEY:FEJVSJIALLTFRP...",head-batch
1,CHEBI:135738,indication,DOID:10763,"[CHEBI:77655, CHEBI:31556, IKEY:FEJVSJIALLTFRP...",head-batch
2,CHEBI:135876,indication,DOID:11054,"[CHEBI:3237, CHEBI:379896, CHEBI:135899, PCID:...",head-batch
3,CHEBI:135923,indication,DOID:14499,"[MESH:D007291, CHEBI:28783, CHEBI:28285, CHEBI...",head-batch
4,CHEBI:135925,indication,DOID:1094,"[CHEBI:2611, CHEBI:94777, IKEY:YRVIKLBSVVNSHF-...",head-batch
...,...,...,...,...,...
1069,CHEBI:9667,indication,NCBIGene:367,"[DOID:9115, DOID:0060501, MONDO:0024419, DOID:...",tail-batch
1070,CHEBI:41423,indication,NCBIGene:367,"[WD:Q26721126, DOID:0070081, DOID:849, DOID:67...",tail-batch
1071,CHEBI:9168,indication,NCBIGene:7490,"[WD:Q54086296, DOID:4007, DOID:4472, DOID:4840...",tail-batch
1072,CHEBI:41879,indication,NCBIGene:367,"[WD:Q3966966, DOID:0050615, DOID:0060032, MESH...",tail-batch


In [9]:
# The mode passed to ProcessOutput matters: re-create `raw` with mode="tail-batch" to get accurate tail-batch predictions
raw.calculate_mrr()

Unnamed: 0,h,r,target,preds,batch,true_t,mrr
0,CHEBI:135735,indication,DOID:10763,"[CHEBI:77655, CHEBI:31556, IKEY:FEJVSJIALLTFRP...",head-batch,"[DOID:10591, DOID:10824, DOID:10825, DOID:1113...",0.000009
1,CHEBI:135738,indication,DOID:10763,"[CHEBI:77655, CHEBI:31556, IKEY:FEJVSJIALLTFRP...",head-batch,"[DOID:10591, DOID:10824, DOID:10825, DOID:1113...",0.000009
2,CHEBI:135876,indication,DOID:11054,"[CHEBI:3237, CHEBI:379896, CHEBI:135899, PCID:...",head-batch,"[DOID:11593, DOID:11811, DOID:11812, DOID:1181...",0.000009
3,CHEBI:135923,indication,DOID:14499,"[MESH:D007291, CHEBI:28783, CHEBI:28285, CHEBI...",head-batch,[DOID:14499],0.000009
4,CHEBI:135925,indication,DOID:1094,"[CHEBI:2611, CHEBI:94777, IKEY:YRVIKLBSVVNSHF-...",head-batch,"[MESH:D056912, DOID:1094]",0.000069
...,...,...,...,...,...,...,...
1069,CHEBI:9667,indication,NCBIGene:367,"[DOID:9115, DOID:0060501, MONDO:0024419, DOID:...",tail-batch,"[DOID:417, DOID:9074, HP:0010562, DOID:4481, D...",0.007221
1070,CHEBI:41423,indication,NCBIGene:367,"[WD:Q26721126, DOID:0070081, DOID:849, DOID:67...",tail-batch,"[DOID:7147, DOID:7148, DOID:8398, WD:Q3281303,...",0.004273
1071,CHEBI:9168,indication,NCBIGene:7490,"[WD:Q54086296, DOID:4007, DOID:4472, DOID:4840...",tail-batch,"[DOID:3963, DOID:4450, DOID:4451, DOID:4454, D...",0.007109
1072,CHEBI:41879,indication,NCBIGene:367,"[WD:Q3966966, DOID:0050615, DOID:0060032, MESH...",tail-batch,"[DOID:0050745, DOID:0060058, DOID:0060060, DOI...",0.003861
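The `true_t` column above makes a filtered reciprocal rank possible: other known true answers are skipped so they do not push the target down the ranking. Whether `calculate_mrr()` applies exactly this filtering is not visible from the output; the sketch below assumes the standard filtered setting, and `filtered_reciprocal_rank` is a hypothetical helper with toy lists.

```python
def filtered_reciprocal_rank(preds, target, true_answers):
    """Reciprocal rank of `target` in the ranked list `preds`, skipping other known true answers."""
    others = set(true_answers) - {target}
    rank = 0
    for candidate in preds:
        if candidate in others:
            continue  # filtered out: another correct answer for the same query
        rank += 1
        if candidate == target:
            return 1.0 / rank
    return 0.0  # target not in preds (e.g. a truncated list)


# Toy ranked list and true answers for one (h, r) query
preds = ["DOID:9654", "DOID:10591", "DOID:10763", "DOID:10824"]
true_t = ["DOID:10591", "DOID:10763"]
print(filtered_reciprocal_rank(preds, "DOID:10763", true_t))  # → 0.5
```

Without filtering (`true_answers=[]`) the same target sits at rank 3, giving 1/3 instead of 1/2.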


In [14]:
tail_batch = raw.calculate_mrr()

In [15]:
tail_batch.head()

Unnamed: 0,h,r,target,preds,batch,true_t,mrr
0,CHEBI:135735,indication,DOID:10763,"[CHEBI:77655, CHEBI:31556, IKEY:FEJVSJIALLTFRP...",head-batch,"[DOID:10591, DOID:10824, DOID:10825, DOID:1113...",9e-06
1,CHEBI:135738,indication,DOID:10763,"[CHEBI:77655, CHEBI:31556, IKEY:FEJVSJIALLTFRP...",head-batch,"[DOID:10591, DOID:10824, DOID:10825, DOID:1113...",9e-06
2,CHEBI:135876,indication,DOID:11054,"[CHEBI:3237, CHEBI:379896, CHEBI:135899, PCID:...",head-batch,"[DOID:11593, DOID:11811, DOID:11812, DOID:1181...",9e-06
3,CHEBI:135923,indication,DOID:14499,"[MESH:D007291, CHEBI:28783, CHEBI:28285, CHEBI...",head-batch,[DOID:14499],9e-06
4,CHEBI:135925,indication,DOID:1094,"[CHEBI:2611, CHEBI:94777, IKEY:YRVIKLBSVVNSHF-...",head-batch,"[MESH:D056912, DOID:1094]",6.9e-05


In [16]:
tail_batch = tail_batch.query('batch=="tail-batch"')

In [17]:
tail_batch.shape

(537, 7)

In [18]:
# Overall MRR over the tail-batch rows only (the head-batch rows were scored in the 'wrong' mode)
sum(tail_batch['mrr'])/len(tail_batch['mrr'])

0.040851145953473285

In [19]:
# Per-query Hits@k for k = 1, 3, 10, keeping only the tail-batch rows
tb_hits = raw.calculate_individual_hits_k(hits = [1,3,10]).query('batch=="tail-batch"')
tb_hits.head()

Unnamed: 0,h,r,target,preds,batch,true_t,position,ind_rank,hits_1,hits_3,hits_10
537,CHEBI:135735,indication,DOID:10763,"[WD:Q25303605, UMLS:C0221155, DOID:9654, DOID:...",tail-batch,"[DOID:10591, DOID:10824, DOID:10825, DOID:1113...","[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, ...",8,False,False,True
538,CHEBI:135738,indication,DOID:10763,"[WD:Q25303605, UMLS:C0221155, DOID:9654, DOID:...",tail-batch,"[DOID:10591, DOID:10824, DOID:10825, DOID:1113...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",24,False,False,False
539,CHEBI:135876,indication,DOID:11054,"[DOID:11817, DOID:11820, DOID:6225, DOID:11809...",tail-batch,"[DOID:11593, DOID:11811, DOID:11812, DOID:1181...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",74,False,False,False
540,CHEBI:135923,indication,DOID:14499,"[DOID:2785, DOID:676, DOID:1607, DOID:9352, HP...",tail-batch,[DOID:14499],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",4217,False,False,False
541,CHEBI:135925,indication,DOID:1094,"[DOID:14332, MESH:D002375, DOID:0060893, DOID:...",tail-batch,"[MESH:D056912, DOID:1094]","[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, ...",9,False,False,True


In [20]:
# Overall Hits@k: fraction of tail-batch queries whose target is ranked within the top k
print(f"""hits@1: {sum(tb_hits['hits_1'])/len(tb_hits['hits_1']):.4f}
hits@3: {sum(tb_hits['hits_3'])/len(tb_hits['hits_3']):.4f}
hits@10: {sum(tb_hits['hits_10'])/len(tb_hits['hits_10']):.4f}
""")

hits@1: 0.0112
hits@3: 0.0317
hits@10: 0.0950
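Both summary metrics reduce to simple averages over per-query ranks: MRR is the mean of 1/rank, and Hits@k is the fraction of queries with rank ≤ k. The sketch below assumes `ind_rank` is the 1-based rank of the target (consistent with `hits_10` being True only for ranks 8 and 9 in the table above); `summarize` is a hypothetical helper, and its MRR may not match the notebook's `mrr` column if that column is computed differently.

```python
def summarize(ranks, ks=(1, 3, 10)):
    """MRR and Hits@k from 1-based ranks, one rank per test query."""
    n = len(ranks)
    metrics = {"MRR": sum(1.0 / r for r in ranks) / n}
    for k in ks:
        metrics[f"hits@{k}"] = sum(r <= k for r in ranks) / n
    return metrics


# The ind_rank values of the five tail-batch rows shown above
print(summarize([8, 24, 74, 4217, 9]))
```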



In [19]:
# Keep only the top 1000 predictions per query
tb_hits['preds'] = tb_hits['preds'].apply(lambda x: x[:1000])

In [21]:
tb_hits.shape

(511, 8)

In [20]:
# Save the top-1000 predictions per tail-batch query
tb_hits.to_csv('top1000preds_distmult.tsv', sep='\t', header=True, index=False)

## Generate the top n filtered results

In [49]:
raw.filter_predictions(top = 100).query('batch=="tail-batch"').head()

Unnamed: 0,h,r,preds,batch,true_t,filt_preds
527,CHEBI:34385,indication_CiD,"[DOID:2377, CHEBI:34385, DOID:8869, DOID:2378,...",tail-batch,"[DOID:2378, DOID:2377, DOID:0050784, DOID:0050...","[CHEBI:34385, DOID:8869, DOID:3393, DOID:1824,..."
528,CHEBI:34829,indication_CiD,"[DOID:2474, DOID:2457, DOID:11204, DOID:13452,...",tail-batch,"[DOID:11204, DOID:2457, DOID:2474]","[DOID:13452, CHEBI:34829, DOID:8881, DOID:9383..."
529,CHEBI:3441,indication_CiD,"[DOID:3393, DOID:10763, DOID:4248, DOID:6000, ...",tail-batch,"[DOID:10591, DOID:6432, DOID:11130, DOID:10824...","[DOID:3393, DOID:4248, DOID:6000, DOID:5844, D..."
530,CHEBI:34730,indication_CiD,"[DOID:11729, DOID:3482, DOID:0040083, DOID:004...",tail-batch,"[DOID:11104, DOID:13034, DOID:13035, DOID:1327...","[DOID:11729, DOID:3482, DOID:0040083, DOID:004..."
531,CHEBI:3437,indication_CiD,"[DOID:916, MESH:D056486, MONDO:0005043, MESH:D...",tail-batch,"[DOID:6432, DOID:13544, DOID:11130, DOID:10825...","[DOID:916, MESH:D056486, MONDO:0005043, MESH:D..."
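Judging from the `CHEBI:34385` row above, `filter_predictions()` appears to drop the known true tails (`DOID:2377`, `DOID:2378`) from `preds` and keep the head entity and other candidates in order. A minimal sketch of that behaviour, assuming it is a plain "remove known answers, then truncate" step; `filter_top_n` is a hypothetical helper and the toy list is a shortened version of that row.

```python
def filter_top_n(preds, true_answers, top=100):
    """Drop known true answers from a ranked list and keep the first `top` remaining candidates."""
    known = set(true_answers)
    return [p for p in preds if p not in known][:top]


# Shortened toy version of the CHEBI:34385 row
preds = ["DOID:2377", "CHEBI:34385", "DOID:8869", "DOID:2378", "DOID:3393"]
true_t = ["DOID:2378", "DOID:2377"]
print(filter_top_n(preds, true_t, top=3))  # → ['CHEBI:34385', 'DOID:8869', 'DOID:3393']
```

Filtering before truncating guarantees `top` candidates whenever enough unknown entities exist; truncating first would return fewer than `top` novel predictions for queries with many known answers.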
