In [1]:
import json
import sys

# sys.path entries are searched as directories (or zip archives); a file name
# such as "relation.py" is ignored by the import system. The file was
# referenced relative to the working directory, so that directory is added.
sys.path.append(".")
from relation import get_relations, get_all_relations

In [2]:
# Base directory that the statement files are read from.
# NOTE: machine-specific absolute path; change it when running elsewhere.
nephqa_root = '/Users/arvin/dev/GreaseLM/data_kg-dd-db-10text-sciel-1zero_q-sciel-noex/nephqa'

In [3]:
# load entity linking results
basefile = f'{nephqa_root}/statement/test.statement.umls_linked.jsonl'

# Explicit UTF-8 keeps decoding independent of the platform's default locale.
with open(basefile, encoding='utf-8') as f:
    lines = f.readlines()

# One JSON document per line. Blank lines (e.g. a trailing empty line at the
# end of the file) are skipped because json.loads would raise on them.
json_lines = [json.loads(line) for line in lines if line.strip()]

In [4]:
# Gather, for every question, the 'Concept ID' of each linking result found
# on the entities of its stem. Entities that carry no 'linking_results' key
# contribute nothing.
question_cuis = []
for json_line in json_lines:
    cuis = [
        ent_match['Concept ID']
        for stem_ent in json_line['question']['stem_ents']
        if 'linking_results' in stem_ent
        for ent_match in stem_ent['linking_results']
    ]
    question_cuis.append(cuis)

In [5]:
# Gather the 'Concept ID' values per answer choice, grouped per question:
# answer_cuis[question][choice] is a flat list over all linking results of
# that choice's text entities.
answer_cuis = []
for json_line in json_lines:
    choice_cuis = []
    for choice in json_line['question']['choices']:
        cuis = []
        for ent_results in choice['text_ents']:
            cuis.extend(ent_match['Concept ID']
                        for ent_match in ent_results['linking_results'])
        choice_cuis.append(cuis)
    answer_cuis.append(choice_cuis)

In [6]:
# Memo of successful lookups, keyed by CUI. Values are tuples so the stored
# data cannot be altered through a list handed back to a caller.
_RELATED_CUIS_CACHE = {}


def get_all_related_cuis(CUI):
    """Return the concepts directly related to ``CUI``.

    Calls ``get_relations(CUI)``. When the response contains a ``"result"``
    key, every entry is turned into a tuple
    ``(related_cui, relationLabel, additionalRelationLabel)`` where
    ``related_cui`` is the last ``/``-separated segment of the entry's
    ``"relatedId"``.

    Successful responses are remembered per CUI, so asking for the same
    concept again does not call ``get_relations`` a second time. Responses
    without ``"result"`` are not remembered and are looked up again on the
    next call.

    Args:
        CUI: concept identifier passed to ``get_relations``; must be hashable
            because it is used as the cache key.

    Returns:
        A new list of 3-tuples on every call; empty when the response has no
        ``"result"`` key.
    """
    cached = _RELATED_CUIS_CACHE.get(CUI)
    if cached is not None:
        return list(cached)

    json_response = get_relations(CUI)
    if "result" not in json_response:
        return []

    related = tuple(
        (rel["relatedId"].split('/')[-1],  # identifier is the trailing path segment
         rel["relationLabel"],
         rel["additionalRelationLabel"])
        for rel in json_response["result"]
    )
    _RELATED_CUIS_CACHE[CUI] = related
    return list(related)


def get_two_hop_paths(source_cuis, dest_cuis, verbose=True):
    """Find every 2-hop path from a source CUI to a destination CUI.

    For each source CUI its related concepts are fetched with
    ``get_all_related_cuis``; each of those intermediate concepts is expanded
    the same way, and a path is kept when the second hop lands on a CUI that
    is in ``dest_cuis``.

    Args:
        source_cuis: iterable of starting CUIs.
        dest_cuis: iterable of CUIs a path must end in (hashable items).
        verbose: when true (the default) print a progress line per source CUI
            (``"  i"``) and per intermediate concept (``"  i.j"``).

    Returns:
        A list of paths, each a two-element list
        ``[(source_cui, int_cui, rel1, rel2), (int_cui, target_cui, rel1, rel2)]``
        in traversal order.
    """
    # Built once: hash lookup instead of scanning the destination list for
    # every second-hop candidate.
    dest_lookup = set(dest_cuis)

    two_hop_paths = []
    for i, source_cui in enumerate(source_cuis):
        if verbose:
            print(f"  {i}")
        int_cuis_rels = get_all_related_cuis(source_cui)
        for j, (int_cui, int_rel1, int_rel2) in enumerate(int_cuis_rels):
            if verbose:
                print(f"  {i}.{j}")
            first_hop = (source_cui, int_cui, int_rel1, int_rel2)
            for target_cui, target_rel1, target_rel2 in get_all_related_cuis(int_cui):
                if target_cui in dest_lookup:
                    two_hop_paths.append(
                        [first_hop, (int_cui, target_cui, target_rel1, target_rel2)]
                    )
    return two_hop_paths

In [None]:
# Collect the 2-hop paths that link each question's CUIs to the CUIs of each
# of its answer choices: one entry per (question, choice) pair, in order.
subgraphs = []

paired_cuis = zip(question_cuis, answer_cuis)
for i, (q_cuis, a_choices_cuis) in enumerate(paired_cuis):
    for j, a_choice_cuis in enumerate(a_choices_cuis):
        print(f"{i}.{j}")  # progress marker: question index, choice index
        subgraphs.append(get_two_hop_paths(q_cuis, a_choice_cuis))

0.0
  0
  0.0
  0.1
  0.2
  0.3
  0.4
  0.5
  0.6
  1
  1.0
  1.1
  1.2
  1.3
  1.4
  2
  2.0
  2.1
  2.2
  2.3
  2.4
  2.5
  2.6
  2.7
  2.8
  2.9
  2.10
  2.11
  2.12
  2.13
  2.14
  2.15
  2.16
  2.17
  2.18
  2.19
  2.20
  2.21
  2.22
  2.23
  2.24
  3
  3.0
  3.1
  3.2
  4
  4.0
  5
  5.0
  5.1
  6
  6.0
  6.1
  6.2
  6.3
  6.4
  6.5
  6.6
  6.7
  6.8
  6.9
  6.10
  6.11
  6.12
  6.13
  6.14
  6.15
  6.16
  6.17
  6.18
  6.19
  6.20
  6.21
  6.22
  6.23
  6.24
  7
  7.0
  7.1
  7.2
  7.3
  7.4
  7.5
  7.6
  7.7
  7.8
  8
  8.0
  9
  9.0
  9.1
  9.2
  10
  10.0
  10.1
  11
  11.0
  11.1
  11.2
  11.3
  11.4
  11.5
  11.6
  11.7
  12
  13
  13.0
  13.1
  13.2
  13.3
  13.4
  13.5
  13.6
  13.7
  13.8
  13.9
  13.10
  13.11
  13.12
  13.13
  13.14
  13.15
  13.16
  13.17
  13.18
  13.19
  13.20
  13.21
  13.22
  13.23
  13.24
  14
  14.0
  14.1
  14.2
  15
  16
  17
  17.0
  17.1
  17.2
  18
  18.0
  18.1
  18.2
  18.3
  18.4
  18.5
  18.6
  18.7
  18.8
  18.9
  19
  19.0
  19.1
  19

  13.7
  13.8
  13.9
  13.10
  13.11
  13.12
  13.13
  13.14
  13.15
  13.16
  13.17
  13.18
  13.19
  13.20
  13.21
  13.22
  13.23
  13.24
  14
  14.0
  14.1
  14.2
  15
  16
  17
  17.0
  17.1
  17.2
  18
  18.0
  18.1
  18.2
  18.3
  18.4
  18.5
  18.6
  18.7
  18.8
  18.9
  19
  19.0
  19.1
  19.2
  20
  20.0
  20.1
  20.2
  20.3
  20.4
  20.5
  20.6
  20.7
  20.8
  20.9
  20.10
  20.11
  20.12
  20.13
  20.14
  20.15
  20.16
  20.17
  20.18
  20.19
  20.20
  20.21
  20.22
  20.23
  20.24
  21
  21.0
  21.1
  21.2
  22
  22.0
  23
  23.0
  23.1
  23.2
  23.3
  23.4
  23.5
  24
  25
  25.0
  25.1
  25.2
  25.3
  25.4
  25.5
  25.6
  25.7
  25.8
  25.9
  25.10
  25.11
  25.12
  25.13
  25.14
  25.15
  25.16
  25.17
  25.18
  25.19
  25.20
  25.21
  25.22
  25.23
  25.24
  26
  26.0
  26.1
  26.2
  27
  27.0
  28
  28.0
  28.1
  29
  29.0
  29.1
  29.2
  29.3
  29.4
  29.5
  29.6
  29.7
  29.8
  29.9
  29.10
  29.11
  29.12
  29.13
  29.14
  29.15
  29.16
  29.17
  29.18
  29.19
  29

  2
  2.0
  2.1
  2.2
  2.3
  2.4
  2.5
  2.6
  2.7
  2.8
  2.9
  2.10
  2.11
  2.12
  2.13
  2.14
  2.15
  2.16
  2.17
  2.18
  2.19
  2.20
  2.21
  2.22
  2.23
  2.24
  3
  3.0
  3.1
  3.2
  3.3
  4
  4.0
  4.1
  4.2
  4.3
  5
  5.0
  5.1
  5.2
  5.3
  5.4
  6
  6.0
  6.1
  6.2
  6.3
  6.4
  6.5
  6.6
  7
  7.0
  8
  8.0
  8.1
  8.2
  8.3
  8.4
  9
  9.0
  9.1
  9.2
  9.3
  9.4
  9.5
  9.6
  9.7
  9.8
  9.9
  9.10
  9.11
  9.12
  9.13
  9.14
  9.15
  9.16
  9.17
  9.18
  9.19
  9.20
  9.21
  9.22
  9.23
  10
  10.0
  11
  11.0
  11.1
  12
  12.0
  12.1
  12.2
  12.3
  12.4
  12.5
  13
  13.0
  13.1
  13.2
  13.3
  13.4
  13.5
  13.6
  13.7
  13.8
  13.9
  13.10
  13.11
  13.12
  13.13
  13.14
  13.15
  13.16
  13.17
  13.18
  13.19
  13.20
  14
  14.0
  14.1
  14.2
  14.3
  14.4
  14.5
  14.6
  14.7
  14.8
  14.9
  14.10
  14.11
  14.12
  14.13
  14.14
  14.15
  14.16
  14.17
  14.18
  14.19
  14.20
  15
  15.0
  15.1
  15.2
  15.3
  15.4
  16
  17
  17.0
  17.1
  17.2
  18
  19
  19

  10
  10.0
  10.1
  10.2
  10.3
  10.4
  10.5
  10.6
  10.7
  10.8
  10.9
  10.10
  10.11
  10.12
  10.13
  10.14
  10.15
  10.16
  10.17
  11
  11.0
  11.1
  11.2
  11.3
  11.4
  11.5
  11.6
  11.7
  11.8
  11.9
  11.10
  11.11
  11.12
  12
  12.0
  12.1
  12.2
  12.3
  12.4
  12.5
  13
  14
  15
  15.0
  15.1
  15.2
  15.3
  15.4
  15.5
  15.6
  15.7
  15.8
  15.9
  15.10
  15.11
  15.12
  15.13
  15.14
  15.15
  15.16
  15.17
  15.18
  15.19
  15.20
  15.21
  15.22
  15.23
  15.24
  16
  16.0
  16.1
  16.2
  16.3
  16.4
  16.5
  17
  17.0
  17.1
  17.2
  17.3
  17.4
  17.5
  17.6
  17.7
  17.8
  17.9
  17.10
  17.11
  17.12
  17.13
  17.14
  17.15
  17.16
  17.17
  17.18
  17.19
  17.20
  17.21
  17.22
  17.23
  17.24
  18
  18.0
  19
  20
  20.0
  20.1
  21
  22
  22.0
  22.1
  22.2
  22.3
  22.4
  22.5
  22.6
  22.7
  22.8
  22.9
  22.10
  22.11
  22.12
  23
  24
  24.0
  24.1
  24.2
  24.3
  24.4
  24.5
  24.6
  25
  25.0
  25.1
  26
  26.0
  27
  27.0
  28
  28.0
  28.1
  28.2


  56.23
  56.24
  57
  57.0
  58
  58.0
  59
  59.0
  59.1
  59.2
  59.3
  59.4
  59.5
  59.6
  59.7
  59.8
  59.9
  59.10
  59.11
  59.12
  59.13
  59.14
  59.15
  59.16
  59.17
  59.18
  59.19
  59.20
  59.21
  59.22
  59.23
  60
  60.0
  60.1
  60.2
  60.3
  61
  61.0
  61.1
  61.2
  61.3
  61.4
  61.5
  61.6
  61.7
  61.8
  61.9
  61.10
  61.11
  61.12
  61.13
  61.14
  61.15
  61.16
  61.17
  61.18
  61.19
  61.20
  61.21
  61.22
  61.23
  61.24
  62
  62.0
  63
  63.0
  64
  65
  65.0
  66
  67
  68
  69
  69.0
  69.1
  69.2
  69.3
  69.4
  69.5
  69.6
  69.7
  70
  70.0
  70.1
  70.2
  70.3
  70.4
  70.5
  70.6
  70.7
  70.8
  70.9
  70.10
  70.11
  70.12
  70.13
  70.14
  70.15
  70.16
  70.17
  70.18
  70.19
  70.20
  70.21
  70.22
  70.23
  70.24
  71
  71.0
  71.1
  71.2
  71.3
  72
  72.0
  72.1
  72.2
  72.3
  72.4
  72.5
  72.6
  72.7
  72.8
  73
  73.0
  73.1
  73.2
  74
  74.0
  74.1
  74.2
  74.3
  74.4
  74.5
  74.6
  74.7
  74.8
  74.9
  74.10
  74.11
  74.12
  74.13