In [1]:
from tqdm.auto import tqdm
import itertools
import random
import csv
from random import shuffle
from collections import defaultdict
import pandas as pd
import re
from pathlib import Path

## Load MRCONSO.RRF (and some basic preprocessing)

In [2]:
# Folder holding the extracted UMLS Metathesaurus RRF files (MRCONSO.RRF, MRREL.RRF).
# NOTE(review): absolute, machine-specific path — point this at your own UMLS download.
UMLS_DIR = Path("C:/Users/bxcai/Downloads/umls-2021AB-metathesaurus/2021AB/META")

In [3]:
# Pull every MRCONSO.RRF record (one pipe-delimited line each) into memory.
with open(UMLS_DIR / "MRCONSO.RRF", mode="r", encoding="utf-8") as f:
    lines = list(f)
print(len(lines))

16543671


### use only English names

In [4]:
# Positions of the fields read from each pipe-delimited MRCONSO.RRF record:
# concept id (CUI), language (LAT) and the name string (STR).
CUI_COL, LANG_COL, STR_COL = 0, 1, 14

# Keep one "CUI||lowercased name" row for every English record.
cleaned = []
for l in tqdm(lines):
    lst = l.rstrip("\n").split("|")
    cui, lang, synonym = lst[CUI_COL], lst[LANG_COL], lst[STR_COL]
    if lang != "ENG":  # comment this check out if you need all languages
        continue
    cleaned.append(cui + "||" + synonym.lower())
print(len(cleaned))

HBox(children=(FloatProgress(value=0.0, max=16543671.0), HTML(value='')))


11755677


### remove duplicates

In [5]:
# Drop identical "CUI||name" rows; going through a set leaves the rows in arbitrary order.
print(len(cleaned))
cleaned = [*set(cleaned)]
print(len(cleaned))

11755677
9580357


In [6]:
# Show the first few rows to check the "CUI||name" format.
cleaned[:3]

['C5049734||macrosiphum sp. bioug09174-g05',
 'C5005512||aprostocetus sp. bioug32162-g08',
 'C4742770||benzyl pivalate']

## Add tradenames (optional)

Treat drug tradenames/brand names found in the relation file (MRREL.RRF) as additional synonyms. This slightly boosts SapBERT's performance on some biomedical entity linking datasets (e.g. COMETA). MRREL.RRF can be extracted from the full UMLS release file: https://www.nlm.nih.gov/research/umls/licensedcontent/umlsarchives04.html#2020AA. (The link points to the 2020AA archive; this notebook reads the 2021AB release.)

In [7]:
# Load MRREL.RRF (the relation file). This rebinds `lines`, so the MRCONSO
# records read earlier are no longer held under that name.
with open(UMLS_DIR/"MRREL.RRF", "r", encoding="utf-8") as f:
    lines = f.readlines()
print(len(lines))

54660298


In [8]:
# Build the CUI -> list-of-names lookup from the "CUI||name" rows.
umls_dict = {}
for line in tqdm(cleaned):
    cui, name = line.split("||")
    umls_dict.setdefault(cui, []).append(name)

HBox(children=(FloatProgress(value=0.0, max=9580357.0), HTML(value='')))




In [9]:
# For every MRREL record mentioning a tradename relation, map the first CUI
# (field 0) to the list of names of the second CUI (field 4).
# NOTE(review): a CUI that appears in several tradename records keeps only the
# names from the last record seen (each assignment overwrites the previous one)
# — confirm this is intended.
tradename_mappings = {}
for l in tqdm(lines):
    if "has_tradename" in l or "tradename_of" in l:
        cells = l.split("|")
        head, tail = cells[0], cells[4]
        # Explicit lookup instead of a bare `except:`, which would also swallow
        # KeyboardInterrupt and hide unrelated errors.
        sfs = umls_dict.get(tail)
        if sfs is None:  # the related CUI has no name in umls_dict
            continue
        tradename_mappings[head] = sfs
print(len(tradename_mappings))

HBox(children=(FloatProgress(value=0.0, max=54660298.0), HTML(value='')))


131088


In [10]:
# add tradenames
print(len(cleaned))
for cui,synonyms in tradename_mappings.items():
    for s in synonyms:
        row = cui+"||"+ s.lower()
        cleaned.append(row)
print(len(cleaned))

9580357
10542116


### Remove duplicates, again

In [11]:
# Second de-duplication pass: tradename rows can repeat rows that already exist.
print(len(cleaned))
cleaned_do_dup = [*set(cleaned)]
print(len(cleaned_do_dup))  # already unique, so its length is the de-duplicated count

10542116
10541661


## positive pairs generation

In [12]:
# Rebuild the CUI -> list-of-names lookup, now including the tradename rows.
umls_dict = {}
for line in tqdm(cleaned_do_dup):
    cui, name = line.split("||")
    try:
        umls_dict[cui].append(name)
    except KeyError:
        umls_dict[cui] = [name]

HBox(children=(FloatProgress(value=0.0, max=10541661.0), HTML(value='')))




## Load Trees

In [13]:
class TREE(object):
    """A code hierarchy plus the cleaned text names attached to each code.

    Attributes filled in by ``load``:
        parent: dict mapping a code to its parent code.
        children: defaultdict(set) mapping a code to its direct child codes.
        grandchildren: defaultdict(set) mapping a code to the children of its children.
        text: dict mapping a code to a tuple of its cleaned names.
    """

    def __init__(self, tree_path, map_path, encoding="utf-8"):
        """Store the two CSV locations and load them immediately.

        Args:
            tree_path: CSV whose first two columns are (parent code, code); the
                first row is skipped as a header.
            map_path: CSV whose first two columns are (code, name); the first
                row is skipped as a header.
            encoding: text encoding used for both files. Defaults to UTF-8, like
                the other file reads in this notebook, instead of the platform
                default (which differs between operating systems).
        """
        self.tree_path = tree_path
        self.map_path = map_path
        self.encoding = encoding
        self.load()

    def clean(self, term, lower=True, clean_NOS=True, clean_bracket=True, clean_dash=True):
        """Normalise one name string.

        Args:
            term: the raw name.
            lower: lowercase the name.
            clean_NOS: drop the stand-alone word "NOS"/"nos" (only when it is
                delimited by spaces or the ends of the string).
            clean_bracket: remove every "(...)" group (shortest match).
            clean_dash: replace "-" with a space.

        Returns:
            The name with the selected clean-ups applied and all whitespace runs
            collapsed to single spaces.
        """
        # Pad with spaces so a leading/trailing "nos" is matched as a whole word.
        term = " " + term + " "
        if lower:
            term = term.lower()
        if clean_NOS:
            term = term.replace(" NOS ", " ").replace(" nos ", " ")
        if clean_bracket:
            term = re.sub(r"\(.*?\)", "", term)
        if clean_dash:
            term = term.replace("-", " ")
        # str.split() never yields empty strings, so this also trims the padding.
        return " ".join(term.split())

    def _read_rows(self, path):
        """Yield the data rows of a CSV (header skipped) that have at least two columns."""
        # newline="" is what the csv module expects from the file object.
        with open(path, newline="", encoding=self.encoding) as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header; an empty file simply yields nothing
            for row in reader:
                if len(row) >= 2:  # blank or truncated lines carry no usable pair
                    yield row

    def load(self):
        """Read both CSV files and (re)build parent/children/grandchildren/text."""
        self.parent = {}
        self.children = defaultdict(set)
        self.grandchildren = defaultdict(set)
        names = defaultdict(set)

        for row in self._read_rows(self.tree_path):
            parent, current = row[0], row[1]
            if parent and current:  # ignore links with a missing end
                self.parent[current] = parent
                self.children[parent].add(current)

        for row in self._read_rows(self.map_path):
            names[row[0]].add(self.clean(row[1]))

        for current in list(self.children):
            for child in self.children[current]:
                # .get() avoids creating empty `children` entries for leaf codes.
                self.grandchildren[current].update(self.children.get(child, ()))

        # Sort each name set so the tuple order does not depend on string hashing.
        self.text = self.clean_set_dict({code: sorted(vals) for code, vals in names.items()})

    def clean_set_dict(self, d):
        """Return a plain dict of ``d``'s non-empty values, each converted to a tuple."""
        return {key: tuple(values) for key, values in d.items() if len(values) > 0}

    def __len__(self):
        """Number of codes that have at least one name."""
        return len(self.text)

In [14]:
# Root folder with one sub-folder per tree; each sub-folder provides
# hierarchy.csv and code2string.csv.
TREE_DIR = Path("D:/Projects/CODER/Hierarchical-CODER/data/cleaned/all")

TREE_SUBDIRS = list(filter(Path.is_dir, TREE_DIR.iterdir()))
trees = {}
for tree_subdir in TREE_SUBDIRS:
    tree_name = tree_subdir.name
    print(tree_name)
    trees[tree_name] = TREE(tree_subdir / "hierarchy.csv", tree_subdir / "code2string.csv")

cpt
loinc
phecode
rxnorm


## Sample eval dataset

In [15]:
# with open('./phecode_eval.txt', 'w') as f:
#     for code in random.sample(list(set([i.split(".")[0] for i in trees["phecode"].text.keys()])), 100):
#         f.write("%s\n" % code)

In [16]:
# EVAL_DIR = Path("D:/Projects/CODER/deps")
# with open(EVAL_DIR/"COMPREHENSIVE_CUI_CUI_PAIRS_UMLS2021AB.csv", "r", encoding="utf-8") as f:
#     lines = f.readlines()
# print(len(lines))

# with open('./cui_cui_pairs_eval.txt', 'w') as f:
#     for line in random.sample(lines[1:], 200000):
#         f.write("%s" % line)

## load eval data

In [17]:
# Read the held-out CUI-CUI evaluation pairs, one comma-separated record per line.
with open("./cui_cui_pairs_eval.txt", mode="r", encoding="utf-8") as f:
    lines = list(f)
print(len(lines))

200000


In [18]:
# Collect the two CUIs of every evaluation record (comma-separated fields 0 and 2).
eval_cuis = []
eval_tree_codes = []
for l in tqdm(lines):
    lst = l.rstrip("\n").split(",")
    eval_cuis.extend((lst[0], lst[2]))
print(len(eval_cuis))

HBox(children=(FloatProgress(value=0.0, max=200000.0), HTML(value='')))


400000


In [19]:
# Keep each evaluation CUI once.
print(len(eval_cuis))
eval_cuis = [*set(eval_cuis)]
print(len(eval_cuis))

400000
216245


In [20]:
eval_terms = []
for i in eval_cuis:
    if i in umls_dict:
        eval_terms += umls_dict[i]

In [21]:
# Keep each evaluation name once.
print(len(eval_terms))
eval_terms = [*set(eval_terms)]
print(len(eval_terms))

1272625
1160287


In [22]:
# Read the held-out phecodes, one code per line.
EVAL_DIR = Path("D:/Projects/CODER/Hierarchical-CODER/sapbert_hierarchical/training_data")
with open(EVAL_DIR / "phecode_eval.txt", mode="r", encoding="utf-8") as f:
    lines = list(f)
print(len(lines))

100


In [23]:
# Strip the trailing newline from each line to get the bare code.
eval_phecodes = [l.rstrip("\n") for l in tqdm(lines)]
print(len(eval_phecodes))

HBox(children=(FloatProgress(value=0.0), HTML(value='')))


100


## Filter Data

In [24]:
# Sizes before the evaluation data is removed.
print(len(umls_dict))
for tree, loaded_tree in trees.items():
    print(tree, len(loaded_tree.text))

4536277
cpt 14056
loinc 171191
phecode 1828
rxnorm 192683


In [25]:
# Remove every evaluation CUI from the training dictionary.
for i in eval_cuis:
    umls_dict.pop(i, None)

# A phecode is held out when the part before its "." is one of the evaluation codes.
all_eval_phecodes = [
    phecode
    for phecode in trees["phecode"].text
    if phecode.split(".")[0] in eval_phecodes
]

# Remove the names of the held-out phecodes from the phecode tree (in place).
for phecode in all_eval_phecodes:
    trees["phecode"].text.pop(phecode, None)

In [26]:
# Report the number of CUIs and, per tree, the number of codes that still have text.
print(len(umls_dict))
for tree in trees:
    print(tree, len(trees[tree].text))

4320036
cpt 14056
loinc 171191
phecode 1506
rxnorm 192683


### generate!

In [27]:
def gen_pairs(input_list):
    """Return every unordered pair of items of ``input_list`` as a list of 2-tuples.

    Pairs come out in input order: (x0, x1), (x0, x2), ..., (x1, x2), ...
    """
    items = tuple(input_list)
    return [
        (first, second)
        for position, first in enumerate(items)
        for second in items[position + 1:]
    ]

In [28]:
# One "CUI||name1||name2" line per synonym pair, with at most 50 pairs per CUI.
pos_pairs = []
for k, v in tqdm(umls_dict.items()):
    pairs = gen_pairs(v)
    if len(pairs) > 50:  # too many pairs for this CUI: keep a random 50
        pairs = random.sample(pairs, 50)
    pos_pairs.extend(str(k) + "||" + p[0] + "||" + p[1] for p in pairs)

HBox(children=(FloatProgress(value=0.0, max=4320036.0), HTML(value='')))




In [29]:
# Total number of positive pair lines.
print(len(pos_pairs))

10555030


In [30]:
# Show the first few pair lines to check the "CUI||name1||name2" format.
pos_pairs[:3]

['C0469458||hollister karaya-5 16"length 1"/3272 transparent ileostomy bags x30||hollister 3272 ileobags x30',
 'C0469458||hollister karaya-5 16"length 1"/3272 transparent ileostomy bags x30||hollister karaya-5 16"lngth 1"/3272 trans ileo bags x30',
 'C0469458||hollister 3272 ileobags x30||hollister karaya-5 16"lngth 1"/3272 trans ileo bags x30']

### save the pairwise positive training file

In [117]:
# with open('./training_files/umls.txt', 'w', encoding="utf-8") as f:
#     for line in pos_pairs:
#         f.write("%s\n" % line)

## Generate Tree Data

In [31]:
# How many times the codes of a tree are repeated in the anchor list; trees not
# listed here appear once.
TREE_REPEATS = {"phecode": 70, "cpt": 10}

# obj_list holds one (code, tree name) entry per anchor to process.
obj_list = []
for tree in trees:
    anchors = [(i, tree) for i in trees[tree].text.keys()]
    obj_list += anchors * TREE_REPEATS.get(tree, 1)

# Derived from the list itself so the count cannot drift from it (it used to be
# computed with separate multipliers, 100 and 15, that did not match 70 and 10).
obj_len = len(obj_list)

In [34]:
# Build labelled term pairs for every tree. Each output line is
# "label||term1||term2", where term1 is a name of the anchor code and the label is:
#   0     -> term2 is another name of the anchor itself
#   1, 2  -> term2 is a name of a code placed in neighbour bucket 1 or 2 (see below)
#   3     -> term2 is a name drawn at random from the whole tree
# At most MAX_SAMPLES pairs are kept per label and anchor.
MAX_SAMPLES = 5

tree_pairs = {}   # tree name -> list of output lines
tree_terms = {}   # tree name -> flat list of every name in that tree's text dict
for tree in trees:
    tree_pairs[tree] = []
    tree_terms[tree] = [i for j in list(trees[tree].text.values()) for i in j]


for anchor_id, tree in tqdm(obj_list):

    # Anchors without text (e.g. removed from the text dict) produce nothing.
    if anchor_id not in trees[tree].text:
        continue



    # samples[d]: codes in bucket d; samples_text[d]: the names of those codes.
    samples = {}
    samples_text = {}
    for i in range(3):
        samples[i] = []
        samples_text[i] = []



    if tree == 'phecode':
        # level = 3 minus the number of characters after the "." (3 when there is no ".").
        if "." not in anchor_id:
            level = 3
        else:
            level = 3 - len(anchor_id.split(".")[1])

        # Children go to bucket level-1; the parent and the siblings go to bucket level.
        samples[level - 1] += [i for i in trees[tree].children[anchor_id]]

        # NOTE(review): for a code without "." level is 3, so samples[level] would
        # raise KeyError if such a code had a parent — presumably those codes are
        # roots; confirm against hierarchy.csv.
        if anchor_id in trees[tree].parent:
            parent = trees[tree].parent[anchor_id]
            samples[level] += [parent]
            samples[level] += [i for i in trees[tree].children[parent] if i != anchor_id]

    else:
        # Other trees: children and siblings -> bucket 1, parent -> bucket 2.
        samples[1] += [i for i in trees[tree].children[anchor_id]]
        if anchor_id in trees[tree].parent:
            parent = trees[tree].parent[anchor_id]
            samples[2] += [parent]
            samples[1] += [i for i in trees[tree].children[parent] if i != anchor_id]

    # Bucket 0 must contain only the anchor. The assert fails if a phecode with
    # level 1 (two characters after the ".") has children, since those land in bucket 0.
    samples[0] += [anchor_id]
    assert len(samples[0]) == 1


    # Resolve codes to names; codes without text are skipped.
    for i in range(3):
        for j in samples[i]:
            if j in trees[tree].text:
                samples_text[i] += trees[tree].text[j]


    # Cap every name pool at 2*MAX_SAMPLES before forming pairs.
    for i in samples_text:
        if len(samples_text[i]) > 2*MAX_SAMPLES:
            samples_text[i] = random.sample(samples_text[i], 2*MAX_SAMPLES)


    # Label 0: pairs among the anchor's own names.
    pairs = list(itertools.combinations(samples_text[0], 2))
    if len(pairs) > MAX_SAMPLES:
        pairs = random.sample(pairs, MAX_SAMPLES)
    for p in pairs:
        line = str(0) + "||" + p[0] + "||" + p[1]
        tree_pairs[tree].append(line)    


    # Labels 1 and 2: anchor names crossed with the names in bucket 1 / bucket 2.
    for i in range(1, 3):
        pairs = list(itertools.product(samples_text[0], samples_text[i]))
        if len(pairs) > MAX_SAMPLES:
            pairs = random.sample(pairs, MAX_SAMPLES)
        for p in pairs:
            line = str(i) + "||" + p[0] + "||" + p[1]
            tree_pairs[tree].append(line)    


    # Label 3: anchor names crossed with names drawn from the whole tree.
    # The draw does not exclude names of the anchor or of its neighbours, and
    # random.sample raises ValueError if the tree has fewer than 2*MAX_SAMPLES names.
    random_samples = random.sample(tree_terms[tree], 2*MAX_SAMPLES)
    pairs = list(itertools.product(samples_text[0], random_samples))
    if len(pairs) > MAX_SAMPLES:
        pairs = random.sample(pairs, MAX_SAMPLES)
    for p in pairs:
        line = str(3) + "||" + p[0] + "||" + p[1]
        tree_pairs[tree].append(line)
        

HBox(children=(FloatProgress(value=0.0, max=609854.0), HTML(value='')))




In [32]:
# Number of generated pair lines per tree.
for tree in tree_pairs:
    n_lines = len(tree_pairs[tree])
    print(tree, n_lines)

cpt 1219150
loinc 1540454
phecode 1607130
rxnorm 1860910


In [30]:
# Number of generated pair lines per tree.
# NOTE(review): this cell repeats the previous one, and its saved output shows
# different counts, so it appears to come from another run — consider deleting it.
for tree in tree_pairs:
        print(tree, len(tree_pairs[tree]))

cpt 4751900
loinc 4948254
phecode 5924240
rxnorm 5771353


In [36]:
# Write one "<tree>.txt" file per tree, one "label||term1||term2" line per pair.
out_dir = Path('./training_files')
# open(..., 'w') does not create missing folders, so make sure the target exists.
out_dir.mkdir(parents=True, exist_ok=True)
for tree in tree_pairs:
    with open(out_dir / (tree + '.txt'), 'w', encoding="utf-8") as f:
        for line in tree_pairs[tree]:
            f.write("%s\n" % line)
        print(tree, len(tree_pairs[tree]))

cpt 1219150
loinc 1540454
phecode 1394890
rxnorm 1860910


## Generate Phecode eval terms 

In [37]:
# Load the trees again from disk so that every code has its text available.
TREE_DIR = Path("D:/Projects/CODER/Hierarchical-CODER/data/cleaned/all")

TREE_SUBDIRS = [entry for entry in TREE_DIR.iterdir() if entry.is_dir()]
trees = {}
for tree_subdir in TREE_SUBDIRS:
    print(tree_subdir.name)
    hierarchy_csv = tree_subdir / "hierarchy.csv"
    names_csv = tree_subdir / "code2string.csv"
    trees[tree_subdir.name] = TREE(hierarchy_csv, names_csv)

cpt
loinc
phecode
rxnorm


In [38]:
# Read the held-out phecodes from the working directory, one code per line.
with open('./phecode_eval.txt', mode='r') as f:
    lines = list(f)
print(len(lines))

100


In [39]:
# Codes listed in the file just read (first comma-separated field of each line).
eval_tree_codes = []
for l in tqdm(lines):
    lst = l.rstrip("\n").split(",")
    eval_tree_codes.append(lst[0])

# Select every phecode whose part before the "." is one of those codes, together
# with its names. The filter uses the codes parsed above rather than
# `eval_phecodes`, which only exists if an earlier cell of the notebook has been run.
eval_code_set = set(eval_tree_codes)
all_eval_phecodes = []
all_eval_phecodes_text = {}
for phecode in trees["phecode"].text:
    if phecode.split(".")[0] in eval_code_set:
        all_eval_phecodes.append(phecode)
        all_eval_phecodes_text[phecode] = trees["phecode"].text[phecode]

print(len(all_eval_phecodes))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))


322


In [40]:
# Build labelled evaluation pairs for the held-out phecodes. Each output line is
# "label||term1||term2", where term1 is a name of the anchor code and the label is:
#   0     -> term2 is another name of the anchor itself
#   1, 2  -> term2 is a name of a code placed in neighbour bucket 1 or 2 (see below)
#   3     -> term2 is a name drawn at random from all held-out phecode names
# At most MAX_SAMPLES pairs are kept per label and anchor.
MAX_SAMPLES = 5

phecode_pairs = []   # output lines
phecode_terms = [i for j in list(all_eval_phecodes_text.values()) for i in j]   # every held-out name


for anchor_id in tqdm(all_eval_phecodes):

    assert anchor_id in all_eval_phecodes_text


    # samples[d]: codes in bucket d; samples_text[d]: the names of those codes.
    samples = {}
    samples_text = {}
    for i in range(3):
        samples[i] = []
        samples_text[i] = []



    # level = 3 minus the number of characters after the "." (3 when there is no ".").
    if "." not in anchor_id:
        level = 3
    else:
        level = 3 - len(anchor_id.split(".")[1])

    # Children go to bucket level-1; the parent and the siblings go to bucket level.
    samples[level - 1] += [i for i in trees["phecode"].children[anchor_id]]

    # NOTE(review): for a code without "." level is 3, so samples[level] would
    # raise KeyError if such a code had a parent — presumably those codes are
    # roots; confirm against hierarchy.csv.
    if anchor_id in trees["phecode"].parent:
        parent = trees["phecode"].parent[anchor_id]
        samples[level] += [parent]
        samples[level] += [i for i in trees["phecode"].children[parent] if i != anchor_id]


    # Bucket 0 must contain only the anchor. The assert fails if a code with
    # level 1 (two characters after the ".") has children, since those land in bucket 0.
    samples[0] += [anchor_id]
    assert len(samples[0]) == 1


    # Resolve codes to names; only held-out codes (keys of all_eval_phecodes_text) count.
    for i in range(3):
        for j in samples[i]:
            if j in all_eval_phecodes_text:
                samples_text[i] += all_eval_phecodes_text[j]


    # Cap every name pool at 2*MAX_SAMPLES before forming pairs.
    for i in samples_text:
        if len(samples_text[i]) > 2*MAX_SAMPLES:
            samples_text[i] = random.sample(samples_text[i], 2*MAX_SAMPLES)


    # Label 0: pairs among the anchor's own names.
    pairs = list(itertools.combinations(samples_text[0], 2))
    if len(pairs) > MAX_SAMPLES:
        pairs = random.sample(pairs, MAX_SAMPLES)
    for p in pairs:
        line = str(0) + "||" + p[0] + "||" + p[1]
        phecode_pairs.append(line)    


    # Labels 1 and 2: anchor names crossed with the names in bucket 1 / bucket 2.
    for i in range(1, 3):
        pairs = list(itertools.product(samples_text[0], samples_text[i]))
        if len(pairs) > MAX_SAMPLES:
            pairs = random.sample(pairs, MAX_SAMPLES)
        for p in pairs:
            line = str(i) + "||" + p[0] + "||" + p[1]
            phecode_pairs.append(line)    


    # Label 3: anchor names crossed with names drawn from all held-out names.
    # The draw does not exclude names of the anchor or of its neighbours, and
    # random.sample raises ValueError if there are fewer than 2*MAX_SAMPLES names.
    random_samples = random.sample(phecode_terms, 2*MAX_SAMPLES)
    pairs = list(itertools.product(samples_text[0], random_samples))
    if len(pairs) > MAX_SAMPLES:
        pairs = random.sample(pairs, MAX_SAMPLES)
    for p in pairs:
        line = str(3) + "||" + p[0] + "||" + p[1]
        phecode_pairs.append(line)
        

HBox(children=(FloatProgress(value=0.0, max=322.0), HTML(value='')))




In [42]:
# Save the evaluation pairs, one "label||term1||term2" line each.
with open('../../sapbert_hierarchical_eval/data/phecode_eval.txt', mode='w', encoding="utf-8") as f:
    f.writelines(pair_line + "\n" for pair_line in phecode_pairs)
    print(len(phecode_pairs))

4325


In [48]:
from collections import Counter
# Number of evaluation pairs per label (the text before the first "||").
Counter(pair_line.split('||', 1)[0] for pair_line in phecode_pairs)

Counter({'1': 354, '2': 1180, '3': 1610, '0': 1181})

In [4]:
import pandas as pd

In [1]:
# Read the saved evaluation pairs back. The file was written as UTF-8, so it is
# read as UTF-8 too instead of relying on the platform default encoding.
with open('../../sapbert_hierarchical_eval/data/phecode_eval.txt', 'r', encoding="utf-8") as f:
    lines = f.readlines()

# Split every "label||term1||term2" line into its fields.
eval_data = []
for l in lines:
    lst = l.rstrip("\n").split("||")
    eval_data.append(lst)

In [10]:
# Display the parsed pairs as a table with one column per "||"-separated field.
pd.DataFrame(eval_data, columns=["dist", "term1", "term2"])

Unnamed: 0,dist,term1,term2
0,1,salmonella osteomyelitis,acute osteomyelitis involving ankle and foot
1,1,salmonella osteomyelitis,acute osteomyelitis
2,1,salmonella osteomyelitis,unspecified osteomyelitis involving hand
3,1,salmonella osteomyelitis,acute osteomyelitis involving multiple sites
4,1,salmonella osteomyelitis,acute osteomyelitis involving hand
...,...,...,...
4320,3,presence of cardiac and vascular implant and g...,other obstetrical trauma
4321,3,fitting and adjustment of other cardiac device,other obstetrical trauma
4322,3,unspecified cardiac device in situ,"arthrosis, unspecified"
4323,3,presence of cardiac and vascular implant and g...,heterophyiasis
