In [None]:
import os, json, argparse, random, math, re
from json import JSONEncoder
import numpy as np
import pickle
from itertools import chain
from preprocess import save_sparse, save_data
from preprocess.parse_csv import Mimic3Parser, Mimic4Parser, EICUParser, Mimic4NoteParser
from preprocess.encode import encode_code
from preprocess.build_dataset import split_patients, build_code_xy, build_heart_failure_y
from preprocess.auxiliary import generate_code_code_adjacent, generate_neighbors, normalize_adj, divide_middle, generate_code_levels
from utils import DateTimeEncoder
import pandas as pd
from tqdm import tqdm
import simple_icd_10_cm as cm
from statistics import mean, median
from sentence_transformers import SentenceTransformer, util

In [None]:
# Prediction task to evaluate. Each later cell has one branch per supported
# value: 'target_diagnoses', 'target_procedures', 'target_laborders',
# 'target_prescriptions'.
target_task = 'target_diagnoses'

# Directory holding the parsed MIMIC-IV metadata JSON files and the score caches.
save_path_parsed = 'data/mimic4/parsed'

# When True, intermediate values are printed for every scored admission.
debug_mode = False

In [None]:
def _read_json(*path_parts):
    """Open the file at os.path.join(*path_parts) and return its decoded JSON content."""
    with open(os.path.join(*path_parts)) as handle:
        return json.load(handle)


# Code -> description tables (used to build the candidate pool) and ontology
# metadata (used to walk code hierarchies), loaded in a fixed order.
diagcode_longtitle = _read_json(save_path_parsed, 'diagcode_longtitle.json')
procedurecode_longtitle = _read_json(save_path_parsed, 'procedurecode_longtitle.json')
labitem_labels = _read_json(save_path_parsed, 'labitem_labels.json')
loinc_metadata = _read_json(save_path_parsed, 'loinc_metadata.json')
ndc_metadata = _read_json('code_sys/NDC', 'ndc_metadata.json')
ndc_names = _read_json(save_path_parsed, 'ndc_names.json')

# Sentence encoder used to match free-text predictions to candidate descriptions.
sen_model = SentenceTransformer("all-MiniLM-L6-v2")

In [None]:
# Candidate pool used to map free-text predictions back to codes: three parallel
# lists (code, display string, normalized match key), cached on disk per task.
_pool_cache_paths = {
    pool_name: os.path.join(save_path_parsed, f'score_cache_{target_task}_{pool_name}.json')
    for pool_name in ('code_pool', 'candidate_str_pool', 'candidate_char_pool')
}


def _build_code_pool(task):
    """Build the parallel candidate lists (codes, display strings, match keys) for `task`.

    Index i of the three returned lists describes the same candidate. The match
    key is the lower-cased string with spaces removed, which `nl_to_code` uses
    for an exact-match lookup before falling back to embedding similarity.

    Raises:
        NotImplementedError: for an unsupported task.
    """
    codes, strings, char_keys = [], [], []
    if task == 'target_diagnoses':
        for code, title in tqdm(diagcode_longtitle.items()):
            code_simple = code.replace('ICD-10_', '')
            # Skip keys that simple_icd_10_cm does not accept as valid items.
            if not cm.is_valid_item(code_simple):
                continue
            codes.append(cm.add_dot(code_simple))
            strings.append(title)
            char_keys.append(title.lower().replace(' ', ''))
    elif task == 'target_procedures':
        for code, title in tqdm(procedurecode_longtitle.items()):
            # code_simple is like: 10D00Z0
            codes.append(code.replace('ICD-10_', ''))
            strings.append(title)
            char_keys.append(title.lower().replace(' ', ''))
    elif task == 'target_laborders':
        for code, title in tqdm(labitem_labels.items()):
            codes.append(str(code))
            strings.append(title)
            # The match key ignores everything from the first '(' onwards.
            char_keys.append(title.split('(')[0].lower().replace(' ', ''))
    elif task == 'target_prescriptions':
        # One code can have several names; each name becomes its own candidate.
        for code, name_list in tqdm(ndc_names.items()):
            for name in name_list:
                codes.append(str(code))
                strings.append(name)
                char_keys.append(name.lower().replace(' ', ''))
    else:
        raise NotImplementedError
    return codes, strings, char_keys


# All three cache files are required: a partially written cache would otherwise
# crash on the missing file (or mix lists from different builds).
if all(os.path.exists(path) for path in _pool_cache_paths.values()):
    with open(_pool_cache_paths['code_pool']) as f:
        code_pool = json.load(f)
    with open(_pool_cache_paths['candidate_str_pool']) as f:
        candidate_str_pool = json.load(f)
    with open(_pool_cache_paths['candidate_char_pool']) as f:
        candidate_char_pool = json.load(f)
    pool_rebuilt = False
    print('Loaded code pool from cache')
else:
    code_pool, candidate_str_pool, candidate_char_pool = _build_code_pool(target_task)
    for pool_name, pool_values in (('code_pool', code_pool),
                                   ('candidate_str_pool', candidate_str_pool),
                                   ('candidate_char_pool', candidate_char_pool)):
        with open(_pool_cache_paths[pool_name], 'w') as f:
            json.dump(pool_values, f)
    pool_rebuilt = True
    print('Saved code pool to cache')

# Row i of candidate_embeddings must embed candidate_str_pool[i] (nl_to_code
# indexes code_pool with the argmax row), so the embedding cache is only reused
# when the pool itself came from cache and the row count still matches.
sen_embs_cache_path = os.path.join(save_path_parsed, f'score_cache_{target_task}_embs.npy')
candidate_embeddings = None
if not pool_rebuilt and os.path.exists(sen_embs_cache_path):
    candidate_embeddings = np.load(sen_embs_cache_path)
    if len(candidate_embeddings) == len(candidate_str_pool):
        print(f'Loaded sentence embeddings from {sen_embs_cache_path}')
    else:
        print(f'Ignoring stale sentence embeddings in {sen_embs_cache_path} '
              f'({len(candidate_embeddings)} rows for {len(candidate_str_pool)} candidates)')
        candidate_embeddings = None
if candidate_embeddings is None:
    candidate_embeddings = sen_model.encode(candidate_str_pool)
    np.save(sen_embs_cache_path, candidate_embeddings)
    print(f'Saved sentence embeddings to cache path {sen_embs_cache_path}')

In [None]:
# LLM whose generations are scored; keep exactly one `llm_name` assignment active.

# GPT family
# llm_name = 'gpt-3.5-turbo-0125'
# llm_name = 'gpt-4-0125-preview'

# LLaMA3 family
llm_name = 'meta-llama/Meta-Llama-3-8B-Instruct'
# llm_name = 'meta-llama/Meta-Llama-3-8B'
# llm_name = 'meta-llama/Meta-Llama-3-70B-Instruct'

# LLaMA2 family
# llm_name = 'NousResearch/Llama-2-7b-chat-hf'
# llm_name = 'meta-llama/Llama-2-7b-hf'
# llm_name = 'NousResearch/Llama-2-13b-chat-hf' # (secondary)
# llm_name = 'NousResearch/Llama-2-70b-chat-hf' # (secondary)

# Mistral family
# llm_name = 'mistralai/Mistral-7B-Instruct-v0.2'
# llm_name = 'mistralai/Mixtral-8x7B-Instruct-v0.1'
# llm_name = 'mistralai/Mixtral-8x22B-Instruct-v0.1'

# Biomed LLM
# llm_name = 'aaditya/Llama3-OpenBioLLM-8B'
# llm_name = 'aaditya/Llama3-OpenBioLLM-70B'
# llm_name = 'BioMistral/BioMistral-7B-DARE'
# llm_name = 'epfl-llm/meditron-7b'
# llm_name = 'epfl-llm/meditron-70b'

# Biomed LLM (secondary)
# llm_name = 'medalpaca/medalpaca-13b'
# llm_name = 'axiong/PMC_LLaMA_13B'

# Optional LoRA adapter checkpoint; an empty string scores the base model.
lora_path = ''
# lora_path = '/home/ubuntu/derek-240318/clinical-event-pred/alignment-handbook/data/llama3-8b-instruct-sft-qlora-codes-diagnoses-full/checkpoint-5000'

# Output-file suffix: the last two path components joined with '-', prefixed
# with '_' when non-empty so it can be appended directly to the model name.
lora_name = '-'.join(lora_path.split('/')[-2:])
lora_name = '_' + lora_name if lora_name else ''

# Output-file stem: the part of `llm_name` after the last '/'.
save_name = llm_name.rsplit('/', 1)[-1]
# save_name += '_wo-labs' # for ablation study only

In [None]:
# Get some intuition about how to set the nl_match_null_threshold
# sampled_min_sims = []
# sampled_max_sims = []
# for i in range(100):
#     samples = random.sample(candidate_str_pool, 1)
#     nl_emb = sen_model.encode(samples[0])
#     cos_sims = util.cos_sim(nl_emb, candidate_embeddings)[0].tolist()
#     sampled_min_sims.append(min(cos_sims))
#     sampled_max_sims.append(max(cos_sims))
# print(f'similarities for smallest randomly sampled pairs: min {min(sampled_min_sims)}, mean {mean(sampled_min_sims)}, max {max(sampled_min_sims)}')
# print(f'similarities for largest randomly sampled pairs: min {min(sampled_max_sims)}, mean {mean(sampled_max_sims)}, max {max(sampled_max_sims)}')
# # nl_match_null_threshold = min(sampled_min_sims)
# nl_match_null_threshold = 0

In [None]:
# Evaluation levels per task: (indices into the ancestor chain returned by
# `universal_get_ancestors`, matching level names). Chains start at the code
# itself, so 0 is the leaf and negative indices count from the coarsest level.
_GRANULARITY_BY_TASK = {
    # e.g. ['E08.3293', 'E08.329', 'E08.32', 'E08.3', 'E08', 'E08-E13', '4']
    'target_diagnoses': (
        [-1, -2, -3, -4, 0],
        ['l1_chapter', 'l2_category-groups', 'l3_category', 'l4_sub-category', 'l5_leaf'],
    ),
    'target_procedures': ([-1, -2, -3, 0], ['l1', 'l2', 'l3', 'l4_leaf']),
    'target_laborders': ([-1, -2, -3, 0], ['l1', 'l2', 'l3', 'l4_leaf']),
    # A full ATC code has 7 chars, but the mapped ATC codes in our data only have
    # 5, so only 4 levels are used. 5-level variant:
    # ([-1, -2, -3, -4, 0], ['l1', 'l2', 'l3', 'l4', 'l5_leaf'])
    'target_prescriptions': ([-1, -2, -3, 0], ['l1', 'l2', 'l3', 'l4_leaf']),
}

if target_task not in _GRANULARITY_BY_TASK:
    raise NotImplementedError
granularity_index, granularity_name = _GRANULARITY_BY_TASK[target_task]

In [None]:
# Paths for this model/adapter run. Every per-run artifact (predictions, parsed
# output, scores) is keyed by `{save_name}{lora_name}` so that a LoRA run never
# overwrites the base model's files.
file_pred = os.path.join('data', 'mimic4', f'{target_task}_output', f'{save_name}{lora_name}.json')
file_gold = os.path.join('data', 'mimic4', target_task, 'test.pkl')
result_save_dir = os.path.join('data', 'mimic4', f'{target_task}_result')
# makedirs also creates missing parents and tolerates an existing directory.
os.makedirs(result_save_dir, exist_ok=True)
result_save_path = os.path.join(result_save_dir, f'{save_name}{lora_name}.json')
output_parsed_save_path = os.path.join('data', 'mimic4', f'{target_task}_output',
                                       f'{save_name}{lora_name}_parsed.json')

with open(file_pred, 'r') as f:
    data_pred = json.load(f)
# NOTE: pickle.load can execute arbitrary code; only load trusted files here.
with open(file_gold, 'rb') as f:
    data_gold = pickle.load(f)

# Keep the first generated output per admission; admissions with no output are
# dropped (a later duplicate hadm_id overrides an earlier one).
pred_dict = {pred['hadm_id']: pred['output'][0] for pred in data_pred if len(pred['output']) > 0}
print(f'Loaded {len(pred_dict)} predictions from {file_pred}')

# Score only admissions that have a prediction.
data_gold = [dp for dp in data_gold if dp['hadm_id'] in pred_dict]
print(f'Cut data_gold to {len(data_gold)}')

In [None]:
# Ontology helpers: map a code to its ancestor at a requested hierarchy level.

# Prefix lengths of the five ATC levels, e.g. A, A10, A10B, A10BA, A10BA02.
_ATC_LEVEL_LENGTHS = (1, 3, 4, 5, 7)


def _atc_ancestor_chain(atc_id):
    """Return the ATC hierarchy of `atc_id`, most specific level first.

    Example: 'A10BA' -> ['A10BA', 'A10B', 'A10', 'A']. The code itself is the
    first element (index 0 = leaf, index -1 = top level), matching the chain
    layout used for the other tasks. A length that is not a real ATC level
    length is reported, and only the complete levels it contains are kept.
    """
    if len(atc_id) not in _ATC_LEVEL_LENGTHS:
        print('Cannot find ancestors for ATC code with current implementation', atc_id)
    return [atc_id[:n] for n in reversed(_ATC_LEVEL_LENGTHS) if len(atc_id) >= n]


def _pick_level(chain, level_idx, fallback):
    """Return chain[level_idx], or `fallback` when the chain is too short."""
    try:
        return chain[level_idx]
    except IndexError:
        return fallback


def universal_get_ancestors(code, level_idx):
    """Return the code(s) found at position `level_idx` of `code`'s ancestor chain.

    Each task builds a chain that starts with the code itself and moves towards
    the coarsest level, so `level_idx=0` selects the leaf and negative indices
    count from the top of the hierarchy (-1 = coarsest). When a chain is too
    short for `level_idx`, its most specific entry is used instead.

    Args:
        code: A gold or predicted code for the current `target_task`.
        level_idx: Index into the ancestor chain (see `granularity_index`).

    Returns:
        A one-element list for diagnoses, procedures and lab orders. For
        prescriptions, the distinct ATC codes at that level over all ATC codes
        mapped to `code` (empty if `code` has no ATC mapping).

    Raises:
        NotImplementedError: for an unsupported `target_task`.
    """
    if target_task == 'target_prescriptions':
        chains = []
        if code in ndc_metadata:
            atc_groups = ndc_metadata[code].get('atc')
            if atc_groups:
                # Only the first ATC group of the entry is used.
                for atc_this in atc_groups[0]:
                    atc_chain = _atc_ancestor_chain(atc_this['id'])
                    if atc_chain:
                        chains.append(atc_chain)
        return list(set([_pick_level(c, level_idx, c[0]) for c in chains]))

    if target_task == 'target_diagnoses':
        # e.g. ['E08.3293', 'E08.329', 'E08.32', 'E08.3', 'E08', 'E08-E13', '4']
        ancs = [code] + cm.get_ancestors(code)
    elif target_task == 'target_procedures':
        # Ancestors are the 3-, 2- and 1-character prefixes of the code.
        ancs = [code, code[:3], code[:2], code[:1]]
    elif target_task == 'target_laborders':
        # NOTE(review): assumes the `ancestors` lists in loinc_metadata are ordered
        # leaf-first like the other chains - confirm. Codes missing from the
        # metadata (or without ancestors) map to themselves.
        ancs = loinc_metadata.get(str(code), {}).get('ancestors', [code])
    else:
        raise NotImplementedError
    return [_pick_level(ancs, level_idx, code)]
    

In [None]:
def nl_to_code(nl, null_threshold=0):
    """Map one free-text prediction to the best-matching candidate code.

    First tries an exact match of `nl` (lower-cased, spaces removed) against
    `candidate_char_pool`. Otherwise `nl` is embedded with `sen_model`, and the
    candidate whose embedding has the highest cosine similarity is chosen.

    Args:
        nl: Natural-language text of one predicted item.
        null_threshold: If the best cosine similarity is below this value, no
            code is returned. Defaults to 0.

    Returns:
        A one-element list with the matched entry of `code_pool`, or [] when the
        best similarity is below `null_threshold` (or there are no candidates).
    """
    key = nl.lower().replace(' ', '')
    try:
        # list.index does the membership test and the lookup in one scan.
        code_idx = candidate_char_pool.index(key)
    except ValueError:
        nl_emb = sen_model.encode(nl)
        cos_sim = util.cos_sim(nl_emb, candidate_embeddings)[0].tolist()
        if not cos_sim:
            return []
        best_sim = max(cos_sim)
        if best_sim < null_threshold:
            # Similarity too low: do not map to any code.
            return []
        code_idx = cos_sim.index(best_sim)
    return [code_pool[code_idx]]

# def nl_to_code(nl):
#     return ['I70.219']

# Dotted code mentioned in the text, e.g. 'E11.9' or 'S27.1XXA' (codes without a dot are not matched).
_DOTTED_CODE_RE = re.compile(r'[A-Z]\d+\.[A-Za-z0-9]+')

# Label fragments stripped from the NL text once the codes have been removed.
_CODE_LABEL_FRAGMENTS = (
    '<sep>',
    '(ICD-10-CM Code: )',
    '(ICD-10-CM code: )',
    '(ICD-10-CM: )',
    'ICD-10-CM Code:',
    'ICD-10-CM code:',
    'ICD-10-CM:',
)

# Non-digit first characters that mark the text as starting directly with the list.
_BULLET_CHARS = ('•', '-', '*', 'o')


def _split_points(seq, hadm_id):
    """Split a generation into bullet points by its line structure; may return []."""
    points = []
    try:
        n_blank_lines = len(re.findall(r'\n\n', seq))
        if '\n\n' in seq:
            if n_blank_lines <= 5:
                first_char = seq.strip()[0]
                if first_char.isdigit() or first_char in _BULLET_CHARS:
                    # The text starts with the list: the first paragraph is the prediction list.
                    seq = seq.split('\n\n')[0]
                else:
                    # The text does not start with the list: use the second paragraph.
                    seq = seq.split('\n\n')[1]
                points = seq.split('\n')
            else:
                points = seq.split('\n\n')
        elif '\n' in seq:
            points = seq.split('\n')
        elif '<br>' in seq:
            points = seq.split('<br>')
        else:
            print(f'Did not find points in {hadm_id}: {seq}')
    except Exception as e:
        print(f'Error in segmenting {hadm_id}: {e}')
        points = []
    return points


def _parse_point(point):
    """Split one bullet point into (cleaned NL text, dotted codes mentioned in it)."""
    if point and point[0].isdigit():
        match = re.search(r'\d+\.\s*(.+)', point)
        nl = match.group(1) if match else ''
    else:
        nl = point
    codes = _DOTTED_CODE_RE.findall(point)
    # Remove each code (bracketed forms first) from the NL text.
    for c in codes:
        for wrapped in (f"({c})", f"[{c}]", c):
            if wrapped in nl:
                nl = nl.replace(wrapped, '').strip()
        nl = nl.strip()
    for fragment in _CODE_LABEL_FRAGMENTS:
        nl = nl.replace(fragment, '')
    return nl.strip(), codes


def segment_seq(seq, hadm_id):
    """Parse one LLM generation into predicted codes.

    The generation is split into bullet points. For each point, the dotted codes
    it mentions and the remaining natural-language text are extracted. Valid
    explicit codes are the point's prediction; otherwise the text is mapped to a
    code with `nl_to_code`.

    Args:
        seq: Generated text for one admission.
        hadm_id: Admission id, used only in log messages.

    Returns:
        (codes_pred_all, extracted, extracted_nl, extracted_co):
        codes_pred_all -- de-duplicated flat list of predicted codes;
        extracted -- one dict of parsing details per bullet point;
        extracted_nl -- de-duplicated NL phrases, one per point;
        extracted_co -- de-duplicated codes explicitly mentioned in the text
        (re-extracted from the whole text when the point-wise parse found no
        codes or fewer than 5 points).
    """
    points = _split_points(seq, hadm_id)
    if len(points) == 0:
        # Single-line numbered list such as '1. Foo (E11.9) 2. Bar'. A new item
        # must be preceded by whitespace, so the '11.' inside 'E11.9' is not
        # taken as the start of the next item.
        points = re.findall(r'\d+\.\s(.+?)(?=\s+\d+\.|\Z)', seq, flags=re.DOTALL)

    # Blank or whitespace-only points carry no prediction.
    points = [p for p in points if p.strip()]

    extracted, extracted_nl, extracted_co, codes_pred_all = [], [], [], []
    for point in points:
        nl_this, co_this = _parse_point(point)
        extracted_nl.append(nl_this)
        extracted_co.extend(co_this)
        co_this_valid = [c for c in co_this if validate_code(c)]

        # An empty text (e.g. the point held only a number or a code) would be
        # matched to an arbitrary candidate, so it yields no NL-based code.
        codes_from_nl = nl_to_code(nl_this) if nl_this else []

        # Prefer valid codes stated explicitly in the point; otherwise use the NL match.
        if co_this_valid:
            codes_pred, source_from_code = co_this_valid, 1
        else:
            codes_pred, source_from_code = codes_from_nl, 0
        codes_pred_all.extend(codes_pred)

        extracted.append({'extracted_nl': nl_this,
                          'extracted_co': co_this,
                          'codes_from_nl': codes_from_nl,
                          'source_from_code': source_from_code,
                          'codes_pred': codes_pred
                          })

    if len(extracted_co) == 0 or len(extracted_nl) < 5:
        # Point-wise extraction looks unreliable: take every dotted code in the text.
        extracted_co = _DOTTED_CODE_RE.findall(seq)

    return (list(set(codes_pred_all)), extracted,
            list(set(extracted_nl)), list(set(extracted_co)))

def validate_code(code):
    """Return True if `code` is a valid code for the current `target_task`.

    Diagnoses are checked with simple_icd_10_cm. Procedures must appear in
    `code_pool` as-is. Lab orders and prescriptions must appear in `code_pool`
    after conversion to str.

    Raises:
        NotImplementedError: for an unsupported `target_task`.
    """
    if target_task == 'target_diagnoses':
        return bool(cm.is_valid_item(code))
    if target_task == 'target_procedures':
        return code in code_pool
    if target_task in ('target_laborders', 'target_prescriptions'):
        return str(code) in code_pool
    # Fail loudly instead of silently reporting every code as invalid.
    raise NotImplementedError

In [None]:
# Per-level accumulators: summed counts (for micro scores) and per-admission F1
# scores (for the macro score), keyed by level name in granularity order.
aggregated_result_all_levels = {
    level_name: {'count_true': 0, 'count_gold': 0, 'count_pred': 0}
    for level_name, _level_idx in zip(granularity_name, granularity_index)
}
f1_list_all_levels = {
    level_name: [] for level_name, _level_idx in zip(granularity_name, granularity_index)
}
latex_text = ''

# Parsing statistics collected across admissions.
source_from_code_flags = []
code_count_invalid = 0
code_count_all = 0

def _gold_codes_for(dp):
    """Return the de-duplicated ground-truth codes of admission record `dp`.

    Raises:
        NotImplementedError: for an unsupported `target_task`.
    """
    if target_task == 'target_diagnoses':
        # Items look like ['S271XXA', 10, 1]; codes are stored without the dot.
        undotted = list(set([item[0] for item in dp['target_diagnoses']]))
        return [cm.add_dot(c) for c in undotted]
    if target_task in ('target_procedures', 'target_laborders'):
        # The first field of each item is the code (v2 data format for lab orders).
        return list(set([item[0] for item in dp[target_task]]))
    if target_task == 'target_prescriptions':
        # Items look like [id, start, stop, 'MAIN', 'Acetaminophen', '004489',
        # '00904198261', '325mg Tablet']; field 6 is used as the code and values
        # of 5 characters or fewer are skipped.
        return list(set([item[6] for item in dp['target_prescriptions'] if len(item[6]) > 5]))
    raise NotImplementedError


for dp in tqdm(data_gold):
    hadm_id = dp['hadm_id']
    if hadm_id not in pred_dict:
        continue
    pred_seq = pred_dict[hadm_id]
    codes_gold = _gold_codes_for(dp)

    if debug_mode:
        print('-----', hadm_id)
    # Parse the generation; the per-point details are stored on the gold record.
    codes_pred, extracted, extracted_nl, extracted_co = segment_seq(pred_seq, hadm_id)
    dp['extracted_details'] = extracted

    # Nothing extracted probably means the text parsing failed, so this
    # admission is excluded from all scores.
    if not codes_pred:
        print(f"Did not extract any codes for admission {hadm_id}")
        continue

    # Count (without dropping) explicitly mentioned codes that fail validation.
    if extracted_co:
        n_valid_co = sum(1 for c in extracted_co if validate_code(c))
        code_count_invalid += len(extracted_co) - n_valid_co
        code_count_all += len(extracted_co)

    source_from_code_flags.extend(point['source_from_code'] for point in extracted)

    if debug_mode:
        print(extracted_nl)
        print(codes_gold)
        print(codes_pred)

    for level_name, level_idx in zip(granularity_name, granularity_index):
        # universal_get_ancestors returns a list per code (possibly several codes
        # for prescriptions), so the results are flattened before comparison.
        gold_at_level = set(chain.from_iterable(universal_get_ancestors(c, level_idx) for c in codes_gold))
        pred_at_level = set(chain.from_iterable(universal_get_ancestors(c, level_idx) for c in codes_pred))

        n_true = len(gold_at_level & pred_at_level)
        n_gold = len(gold_at_level)
        n_pred = len(pred_at_level)
        f1_this = 2 * n_true / (n_gold + n_pred) if n_gold + n_pred > 0 else 0

        if debug_mode:
            print(n_true, n_gold, n_pred)

        level_counts = aggregated_result_all_levels[level_name]
        level_counts["count_true"] += n_true
        level_counts["count_gold"] += n_gold
        level_counts["count_pred"] += n_pred
        f1_list_all_levels[level_name].append(f1_this)

# Persist the gold records, now annotated with the per-point parsing details.
with open(output_parsed_save_path, 'w') as f:
    json.dump(data_gold, f, indent=4, cls=DateTimeEncoder)


def _mean_or_zero(values):
    """Return statistics.mean(values), or 0 for an empty sequence (mean would raise)."""
    return mean(values) if len(values) > 0 else 0


# Per level: micro precision/recall/F1 from the summed counts, and macro F1 as
# the unweighted mean of the per-admission F1 scores.
prec_of_levels = []
reca_of_levels = []
f1_of_levels = []
latex_texts = []

for level_name, _level_idx in zip(granularity_name, granularity_index):
    level_result = aggregated_result_all_levels[level_name]
    prec = level_result['count_true'] / level_result['count_pred'] if level_result['count_pred'] > 0 else 0
    reca = level_result['count_true'] / level_result['count_gold'] if level_result['count_gold'] > 0 else 0
    f1 = 2 * prec * reca / (prec + reca) if prec + reca > 0 else 0
    level_result['precision'] = prec
    level_result['recall'] = reca
    level_result['f1'] = f1
    # The list is empty when every admission was skipped; report 0 instead of
    # raising StatisticsError after the whole run.
    level_result['f1_macro'] = _mean_or_zero(f1_list_all_levels[level_name])
    prec_of_levels.append(prec)
    reca_of_levels.append(reca)
    f1_of_levels.append(f1)
    latex_texts.append(f" & {prec * 100:.2f} & {reca * 100:.2f} & {f1 * 100:.2f}")

# Unweighted averages over levels, plus parsing statistics.
aggregated_result_all_levels['avg_precision'] = _mean_or_zero(prec_of_levels)
aggregated_result_all_levels['avg_recall'] = _mean_or_zero(reca_of_levels)
aggregated_result_all_levels['avg_f1'] = _mean_or_zero(f1_of_levels)
latex_texts.append(f" & {aggregated_result_all_levels['avg_precision'] * 100:.2f} & {aggregated_result_all_levels['avg_recall'] * 100:.2f} & {aggregated_result_all_levels['avg_f1'] * 100:.2f}")
aggregated_result_all_levels['code_count_invalid'] = code_count_invalid
aggregated_result_all_levels['code_count_all'] = code_count_all
aggregated_result_all_levels['source_from_code_count'] = sum(source_from_code_flags)
aggregated_result_all_levels['latex'] = latex_texts
with open(result_save_path, 'w') as f:
    json.dump(aggregated_result_all_levels, f, indent=4)
print(json.dumps(aggregated_result_all_levels, indent=4))
print(f'Scores are saved to {result_save_path}')
print(f'Scores are saved to {result_save_path}')