In [101]:
from tqdm.auto import tqdm
import itertools
import random
import csv
from random import shuffle
from collections import defaultdict
import pandas as pd
import re
from pathlib import Path

## load MRCONSO.RFF (and some basic preprocessing)

In [102]:
UMLS_DIR = Path("C:/Users/bxcai/Downloads/umls-2021AB-metathesaurus/2021AB/META")

In [103]:
with open(UMLS_DIR/"MRCONSO.RRF", "r", encoding="utf-8") as f:
    lines = f.readlines()
print(len(lines))

16543671


### use only English names

In [104]:
cleaned = []
count = 0
for l in tqdm(lines):
    lst = l.rstrip("\n").split("|")
    cui, lang, synonym = lst[0], lst[1], lst[14]
    if lang != "ENG": continue # comment this out if you need all languages
    row = cui+"||"+synonym.lower()
    cleaned.append(row)
print(len(cleaned))

HBox(children=(FloatProgress(value=0.0, max=16543671.0), HTML(value='')))


11755677


### remove duplicates

In [105]:
print(len(cleaned))
cleaned = list(set(cleaned)) 
print(len(cleaned))

11755677
9580357


In [106]:
cleaned[:3]

['C1415688||hox1d',
 'C3131260||haenydra madronensis',
 'C3876881||meniscus screw, bioabsorbable (physical object)']

## add tradeneames (optional) 

Regard drug tradenames/brandnames from the relation file as synonym relations. This slightly boosts SapBERT's performance on some biomedical entity linking datasets (e.g. COMETA). MRREL.RRF can be extracted from the full UMLS release file: https://www.nlm.nih.gov/research/umls/licensedcontent/umlsarchives04.html#2020AA.

In [107]:
# load MRCONSO.RFF
with open(UMLS_DIR/"MRREL.RRF", "r", encoding="utf-8") as f:
    lines = f.readlines()
print(len(lines))

54660298


In [108]:
umls_dict = {} # constrauct cui to list of name dict
for line in tqdm(cleaned):
    cui, name = line.split("||")
    if cui in umls_dict:
        umls_dict[cui].append(name)
    else:
        umls_dict[cui] = [name]

HBox(children=(FloatProgress(value=0.0, max=9580357.0), HTML(value='')))




In [109]:
tradename_mappings = {}
for l in tqdm(lines):
    if "has_tradename" in l or "tradename_of" in l:
        cells =l.split("|")
        head, tail = cells[0], cells[4]
        try: # if in CUI
            sfs = umls_dict[tail]
            tradename_mappings[head] = sfs
        except:
            continue
print(len(tradename_mappings))

HBox(children=(FloatProgress(value=0.0, max=54660298.0), HTML(value='')))


131088


In [110]:
# add tradenames
print(len(cleaned))
for cui,synonyms in tradename_mappings.items():
    for s in synonyms:
        row = cui+"||"+ s.lower()
        cleaned.append(row)
print(len(cleaned))

9580357
10542116


### remove duplications, again

In [111]:
print(len(cleaned))
cleaned_do_dup = list(set(cleaned))
print(len(list(set(cleaned_do_dup))))

10542116
10541661


## positive pairs generation

In [112]:
umls_dict = {} # constrauct cui to list of name dict, again
for line in tqdm(cleaned_do_dup):
    cui, name = line.split("||")
    if cui in umls_dict:
        umls_dict[cui].append(name)
    else:
        umls_dict[cui] = [name]

HBox(children=(FloatProgress(value=0.0, max=10541661.0), HTML(value='')))




## Load Trees

In [13]:
class TREE(object):
    def __init__(self, tree_path, map_path):
        self.tree_path = tree_path
        self.map_path = map_path
        self.load()
    
    def clean(self, term, lower=True, clean_NOS=True, clean_bracket=True, clean_dash=True):
        term = " " + term + " "
        if lower:
            term = term.lower()
        if clean_NOS:
            term = term.replace(" NOS ", " ").replace(" nos ", " ")
        if clean_bracket:
            term = re.sub(u"\\(.*?\\)", "", term)
        if clean_dash:
            term = term.replace("-", " ")
        term = " ".join([w for w in term.split() if w])
        return term

    def load(self):
        self.parent = {}
        self.children = defaultdict(set)
        self.grandchildren = defaultdict(set)
        self.text = defaultdict(set)
        with open(self.tree_path) as f:
            reader = csv.reader(f)
            reader.__next__()
            for row in reader:
                parent = row[0]
                current = row[1]
                if len(parent) > 0 and len(current) > 0:
                    self.parent[current] = parent
                    self.children[parent].add(current)

        with open(self.map_path) as f:
            reader = csv.reader(f)
            reader.__next__()
            for row in reader:
                current = row[0]
                text = row[1]
                self.text[current].add(self.clean(text))

        for current in list(self.children):
            for child in self.children[current]:
                self.grandchildren[current] = self.grandchildren[current].union(self.children[child])

        self.text = self.clean_set_dict(self.text)

    def clean_set_dict(self, d):
        out = {}
        for i in d:
            if len(d[i]) > 0:
                out[i] = tuple(d[i])
        return out


    def __len__(self):
        return len(self.text)

In [14]:
TREE_DIR = Path("D:/Projects/CODER/Hierarchical-CODER/data/cleaned/all")

TREE_SUBDIRS = [f for f in TREE_DIR.iterdir() if f.is_dir()]
trees = {}
for tree_subdir in TREE_SUBDIRS:
    print(tree_subdir.name)
    trees[tree_subdir.name] = TREE(tree_subdir/"hierarchy.csv", tree_subdir/"code2string.csv")

cpt
loinc
phecode
rxnorm


## Sample eval dataset

In [15]:
# with open('./phecode_eval.txt', 'w') as f:
#     for code in random.sample(list(set([i.split(".")[0] for i in trees["phecode"].text.keys()])), 100):
#         f.write("%s\n" % code)

In [85]:
# EVAL_DIR = Path("D:/Projects/CODER/deps")
# with open(EVAL_DIR/"COMPREHENSIVE_CUI_CUI_PAIRS_UMLS2021AB.csv", "r", encoding="utf-8") as f:
#     lines = f.readlines()
# print(len(lines))

# with open('./cui_cui_pairs_eval.txt', 'w') as f:
#     for line in random.sample(lines[1:], 200000):
#         f.write("%s" % line)

46748145


## load eval data

In [87]:
with open("./cui_cui_pairs_eval.txt", "r", encoding="utf-8") as f:
    lines = f.readlines()
print(len(lines))

200000


In [88]:
eval_cuis = []
eval_tree_codes = []
for l in tqdm(lines):
    lst = l.rstrip("\n").split(",")
    eval_cuis.append(lst[0])
    eval_cuis.append(lst[2])
print(len(eval_cuis))

HBox(children=(FloatProgress(value=0.0, max=200000.0), HTML(value='')))


400000


In [89]:
print(len(eval_cuis))
eval_cuis = list(set(eval_cuis)) 
print(len(eval_cuis))

400000
216245


In [90]:
eval_terms = []
for i in eval_cuis:
    if i in umls_dict:
        eval_terms += umls_dict[i]

In [91]:
print(len(eval_terms))
eval_terms = list(set(eval_terms)) 
print(len(eval_terms))

1272625
1160287


In [92]:
EVAL_DIR = Path("D:/Projects/CODER/Hierarchical-CODER/sapbert_hierarchical/training_data")
with open(EVAL_DIR/"phecode_eval.txt", "r", encoding="utf-8") as f:
    lines = f.readlines()
print(len(lines))

100


In [93]:
eval_phecodes = []
for l in tqdm(lines):
    lst = l.rstrip("\n")
    eval_phecodes.append(lst)
print(len(eval_phecodes))

HBox(children=(FloatProgress(value=0.0), HTML(value='')))


100


## Filter Data

In [94]:
print(len(umls_dict))
for tree in trees:
    print(tree, len(trees[tree].text))

4536277
cpt 14056
loinc 171191
phecode 1728
rxnorm 192683


In [95]:
for i in eval_cuis:
    if i in umls_dict:
        del umls_dict[i]
for phecode in eval_phecodes:
    if phecode in trees["phecode"].text:
        del trees["phecode"].text[phecode]

In [96]:
print(len(umls_dict))
for tree in trees:
    print(tree, len(trees[tree].text))

4320036
cpt 14056
loinc 171191
phecode 1728
rxnorm 192683


### generate!

In [113]:
def gen_pairs(input_list):
    return list(itertools.combinations(input_list, r=2))

In [114]:
pos_pairs = []
for k,v in tqdm(umls_dict.items()):
    pairs = gen_pairs(v)
    if len(pairs)>50: # if >50 pairs, then trim to 50 pairs
        pairs = random.sample(pairs, 50)
    for p in pairs:
        line = str(k) + "||" + p[0] + "||" + p[1]
        pos_pairs.append(line)

HBox(children=(FloatProgress(value=0.0, max=4536277.0), HTML(value='')))




In [115]:
print(len(pos_pairs))

13170830


In [116]:
pos_pairs[:3]

['C1415688||hox1d||hoxd3',
 'C1415688||hox1d||homeo box d3',
 'C1415688||hox1d||hox4']

### save the pairwise positive training file

In [117]:
with open('./training_files/umls.txt', 'w', encoding="utf-8") as f:
    for line in pos_pairs:
        f.write("%s\n" % line)

## Generate Tree Data

In [58]:
obj_list = []
obj_len = 0
for tree in trees:
    if tree == 'phecode':
        obj_list += [(i, tree) for i in trees[tree].text.keys()] * 70
        obj_len += len(trees[tree]) * 100
    elif tree == 'cpt':
        obj_list += [(i, tree) for i in trees[tree].text.keys()] * 10
        obj_len += len(trees[tree]) * 15
    else:
        obj_list += [(i, tree) for i in trees[tree].text.keys()]
        obj_len += len(trees[tree])

In [59]:
tree_pairs = {}
tree_terms = {}
for tree in trees:
    tree_pairs[tree] = []
    tree_terms[tree] = [i for j in list(trees[tree].text.values()) for i in j]


for anchor_id, tree in tqdm(obj_list):

    if anchor_id not in trees[tree].text:
        continue



    samples = {}
    samples_text = {}
    for i in range(3):
        samples[i] = []
        samples_text[i] = []
    
    

    if tree == 'phecode':
        if "." not in anchor_id:
            level = 3
        else:
            level = 3 - len(anchor_id.split(".")[1])
            
        samples[level - 1] += [i for i in trees[tree].children[anchor_id]]
        
        if anchor_id in trees[tree].parent:
            parent = trees[tree].parent[anchor_id]
            samples[level] += [parent]
            samples[level] += [i for i in trees[tree].children[parent] if i != anchor_id]
        
    else:
        samples[1] += [i for i in trees[tree].children[anchor_id]]
        if anchor_id in trees[tree].parent:
            parent = trees[tree].parent[anchor_id]
            samples[2] += [parent]
            samples[1] += [i for i in trees[tree].children[parent] if i != anchor_id]

    samples[0] += [anchor_id]
    assert len(samples[0]) == 1

    
    for i in range(3):
        for j in samples[i]:
            if j in trees[tree].text:
                samples_text[i] += trees[tree].text[j]

    
    for i in samples_text:
        if len(samples_text[i]) > 40:
            samples_text[i] = random.sample(samples_text[i], 40)
                
    
    pairs = list(itertools.combinations(samples_text[0], 2))
    if len(pairs) > 20:
        pairs = random.sample(pairs, 20)
    for p in pairs:
        line = str(0) + "||" + p[0] + "||" + p[1]
        tree_pairs[tree].append(line)    
    
    
    for i in range(1, 3):
        pairs = list(itertools.product(samples_text[0], samples_text[i]))
        if len(pairs) > 20:
            pairs = random.sample(pairs, 20)
        for p in pairs:
            line = str(i) + "||" + p[0] + "||" + p[1]
            tree_pairs[tree].append(line)    

            
    random_samples = random.sample(tree_terms[tree], 40)
    pairs = list(itertools.product(samples_text[0], random_samples))
    if len(pairs) > 20:
        pairs = random.sample(pairs, 20)
    for p in pairs:
        line = str(3) + "||" + p[0] + "||" + p[1]
        tree_pairs[tree].append(line)
        

HBox(children=(FloatProgress(value=0.0, max=625394.0), HTML(value='')))




In [66]:
for tree in tree_pairs:
    with open('./training_files/' + tree + '.txt', 'w', encoding="utf-8") as f:
        for line in tree_pairs[tree]:
            f.write("%s\n" % line)
        print(tree, len(tree_pairs[tree]))

cpt 4751900
loinc 4948254
phecode 5924240
rxnorm 5771353
