In [119]:
# This notebook is an attempt to port the code from https://github.com/MostHumble/Clinical-GAN/blob/master/process_data.py
# to a more recent MIMIC dataset version,
# while adding suitable updates that weren't taken into account: for now mainly scheduled to work on stratification

# Done porting: 

In [1]:
import argparse
import csv
import gzip
import os
import pickle
import sys
from collections import Counter
from collections import defaultdict
from datetime import datetime
from itertools import chain

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch

#parser = argparse.ArgumentParser()

In [2]:
!ls mimic-iv-2.2/hosp

admissions.csv.gz	 emar_detail.csv.gz	    poe_detail.csv.gz
d_hcpcs.csv.gz		 hcpcsevents.csv.gz	    prescriptions.csv.gz
diagnoses_icd.csv.gz	 labevents.csv.gz	    procedures_icd.csv.gz
d_icd_diagnoses.csv.gz	 microbiologyevents.csv.gz  provider.csv.gz
d_icd_procedures.csv.gz  omr.csv.gz		    services.csv.gz
d_labitems.csv.gz	 patients.csv.gz	    transfers.csv.gz
drgcodes.csv.gz		 pharmacy.csv.gz
emar.csv.gz		 poe.csv.gz


In [3]:
# Input locations. Every path is relative to the notebook's working directory.
# Directory holding the gzipped MIMIC-IV "hosp" tables.
mimic_iv_path = 'mimic-iv-2.2/hosp'
# CSV read by create_CCS_CCSR_mapping for its 'ICD-10-CM CODE' / 'CCSR CATEGORY n' columns.
CCSRDX_file = 'DXCCSR_v2021-2/DXCCSR_v2021-2.csv'
# CSV read by create_CCS_CCSR_mapping for its 'ICD-10-PCS' / 'PRCCSR' columns.
CCSRPCS_file = 'PRCCSR_v2021-1/PRCCSR_v2021-1.csv'
#os.path.join(mimic_iv_path, 'ADMISSIONS.csv')
# Gzipped tables consumed by load_mimic_data and its helpers.
admissionFile = os.path.join(mimic_iv_path, 'admissions.csv.gz')
diagnosisFile = os.path.join(mimic_iv_path, 'diagnoses_icd.csv.gz')
procedureFile = os.path.join(mimic_iv_path, 'procedures_icd.csv.gz')
#patientsAge = os.path.join(mimic_iv_path, 'patientsAge.csv')
prescriptionFile = os.path.join(mimic_iv_path, 'prescriptions.csv.gz')
#diagnosisFrequencyFile = os.path.join(mimic_iv_path, 'WITHOUT_IF_CODE_COUNT.csv')
#outFile = 'data'

In [4]:
adm = pd.read_csv(admissionFile)

In [5]:
adm.head()

Unnamed: 0,subject_id,hadm_id,admittime,dischtime,deathtime,admission_type,admit_provider_id,admission_location,discharge_location,insurance,language,marital_status,race,edregtime,edouttime,hospital_expire_flag
0,10000032,22595853,2180-05-06 22:23:00,2180-05-07 17:15:00,,URGENT,P874LG,TRANSFER FROM HOSPITAL,HOME,Other,ENGLISH,WIDOWED,WHITE,2180-05-06 19:17:00,2180-05-06 23:30:00,0
1,10000032,22841357,2180-06-26 18:27:00,2180-06-27 18:49:00,,EW EMER.,P09Q6Y,EMERGENCY ROOM,HOME,Medicaid,ENGLISH,WIDOWED,WHITE,2180-06-26 15:54:00,2180-06-26 21:31:00,0
2,10000032,25742920,2180-08-05 23:44:00,2180-08-07 17:50:00,,EW EMER.,P60CC5,EMERGENCY ROOM,HOSPICE,Medicaid,ENGLISH,WIDOWED,WHITE,2180-08-05 20:58:00,2180-08-06 01:44:00,0
3,10000032,29079034,2180-07-23 12:35:00,2180-07-25 17:55:00,,EW EMER.,P30KEH,EMERGENCY ROOM,HOME,Medicaid,ENGLISH,WIDOWED,WHITE,2180-07-23 05:54:00,2180-07-23 14:00:00,0
4,10000068,25022803,2160-03-03 23:16:00,2160-03-04 06:26:00,,EU OBSERVATION,P51VDL,EMERGENCY ROOM,,Other,ENGLISH,SINGLE,WHITE,2160-03-03 21:55:00,2160-03-04 06:26:00,0


In [6]:
def reformat_icd(code: str, version: int, is_diag: bool) -> str:
    """
    Insert the decimal point into an ICD code according to its ICD revision.

    Delegates to ``reformat_icd9`` when ``version`` is 9 and to
    ``reformat_icd10`` when it is 10.

    Raises:
        ValueError: if ``version`` is neither 9 nor 10.
    """
    if version == 9:
        return reformat_icd9(code, is_diag)
    if version == 10:
        return reformat_icd10(code, is_diag)
    raise ValueError("version must be 9 or 10")

def reformat_icd10(code: str, is_diag: bool) -> str:
    """
    Put the decimal point back into an ICD-10 code stored without one.

    Any dots already present are removed first. Procedure codes
    (``is_diag`` false) are returned undotted. Diagnosis codes get a dot
    after the third character, but only when characters follow it, so a bare
    three-character code such as "I10" is returned unchanged instead of
    with a dangling dot ("I10."). This mirrors the length guard used by
    ``reformat_icd9``.
    """
    code = code.replace(".", "")
    if not is_diag or len(code) <= 3:
        return code
    return code[:3] + "." + code[3:]


def reformat_icd9(code: str, is_diag: bool) -> str:
    """
    Put the decimal point back into an ICD-9 code stored without one.

    Existing dots are stripped, then a single dot is inserted after:
      * the 2nd character for procedure codes,
      * the 4th character for diagnosis codes starting with "E",
      * the 3rd character for every other diagnosis code.
    Codes that are not longer than that prefix are returned undotted.
    """
    bare = code.replace(".", "")

    if not is_diag:
        prefix_len = 2
    elif bare.startswith("E"):
        prefix_len = 4
    else:
        prefix_len = 3

    if len(bare) <= prefix_len:
        return bare
    return bare[:prefix_len] + "." + bare[prefix_len:]


In [7]:
def get_ICDs_from_mimic_file(fileName ,isdiagnosis=True):
    """
    Build an admission -> ICD code list mapping from a gzipped MIMIC ICD table.

    inputs:
    fileName : path to the gzipped CSV (diagnoses or procedures table)
    isdiagnosis : True to read the code/version from columns 3/4 and prefix
                  codes with 'D'; False to read them from columns 4/5 and
                  prefix codes with 'P'

    outputs:
    mapping : dict that maps hadm_id (int, column 1) to a list of codes of the
              form '<D|P><icd_version>_<icd_code>', in file order

    Rows whose code field is empty are skipped and counted; the counts are
    printed per ICD version once the file has been read.
    """
    # Column positions of the code and of its ICD version for each table kind.
    code_col, version_col = (3, 4) if isdiagnosis else (4, 5)
    # Diagnosis and procedure codes overlap, so a prefix disambiguates them.
    prefix = 'D' if isdiagnosis else 'P'

    mapping = {}
    number_of_null_ICD9_codes = 0
    number_of_null_ICD10_codes = 0

    # The context manager closes the file even if a malformed row raises.
    with gzip.open(fileName, 'r') as mimicFile:
        mimicFile.readline()  # skip the header row
        for line in mimicFile:
            tokens = line.decode('utf-8').strip().split(',')
            hadm_id = int(tokens[1])
            ICD_code = tokens[code_col]
            version = tokens[version_col]

            if len(ICD_code) == 0:
                # Nothing to record for this row; only keep statistics.
                if version == '9':
                    number_of_null_ICD9_codes += 1
                else:
                    number_of_null_ICD10_codes += 1
                continue

            if '"' in ICD_code:
                ICD_code = ICD_code[1:-1].strip()  # toss off quotes and proceed

            full_code = (prefix + version + '_' + ICD_code).strip()
            mapping.setdefault(hadm_id, []).append(full_code)

    print ('-Number of null ICD9 codes in file ' + fileName + ': ' + str(number_of_null_ICD9_codes))
    print ('-Number of null ICD10 codes in file ' + fileName + ': ' + str(number_of_null_ICD10_codes))
    return mapping

In [8]:
def get_drugs_from_mimic_file(fileName, choice ='ndc'):
    """
    This creates a hospital admission to list of drugs mapping and a drug to
    description ( name ) map

    inputs:
    fileName : path to the gzipped prescriptions CSV
    choice : drug codification to choose from; 'ndc' reads column 12, any
             other value reads column 11 (gsn)

    outputs:
    drugDescription : dict that maps 'DR_<code>' to the drug name (column 9)
                      of the first row in which that code appears
    mapping : dict that maps hadm_id (int, column 1) to the list of
              'DR_<code>' drug codes, in file order, repetitions kept

    The file is parsed with the csv module so that quoted fields containing
    commas (e.g. drug names) do not shift the column positions. A row that
    cannot be parsed (non-integer hadm_id, too few columns) is skipped and
    counted instead of aborting the whole load.
    """
    code_col = 12 if choice == 'ndc' else 11
    name_col = 9

    mapping = {}
    drugDescription = {}
    skipped_rows = 0

    # newline='' is what the csv module expects from a text-mode file object.
    with gzip.open(fileName, 'rt', encoding='utf-8', newline='') as mimicFile:
        reader = csv.reader(mimicFile)
        next(reader, None)  # skip the header row
        for tokens in reader:
            try:
                hadm_id = int(tokens[1])
                drug_code = 'DR' + '_' + tokens[code_col].strip()
                drug_name = tokens[name_col]
            except (ValueError, IndexError):
                skipped_rows += 1
                continue

            mapping.setdefault(hadm_id, []).append(drug_code)
            # Keep the first name seen for each code.
            drugDescription.setdefault(drug_code, drug_name)

    if skipped_rows:
        print(f'-Number of malformed rows skipped in file {fileName}: {skipped_rows}')
    return drugDescription, mapping

In [9]:
def load_mimic_data(choice ='ndc'):
    """
    Load the admission, diagnosis, procedure and prescription tables.

    inputs:
    choice : the type of drug codification to choose ('ndc' or gsn)

    output:
    subject_idAdmMap : dict that maps subject ids to the list of their
                       hospital admission ids, ordered by admission time
    admDxMap : dict that maps hospital admission ids to a list of ICD-9/ICD-10 diagnosis codes
    admPxMap : dict that maps hospital admission ids to a list of ICD-9/ICD-10 procedure codes
    admDrugMap : dict that maps hospital admissions to a list of ndc/gsn drug codes
    drugDescription : dict that maps drug codes to their name

    The paths are taken from the module-level admissionFile, diagnosisFile,
    procedureFile and prescriptionFile variables.
    """
    print ('Building subject_id-admission mapping, admission-date mapping')
    # subject_id -> {hadm_id: admittime}; the inner dict de-duplicates admissions.
    admissions_by_subject = defaultdict(dict)
    with gzip.open(admissionFile, 'rt', encoding='utf-8') as infd:
        infd.readline()  # skip the header row
        for line in infd:
            tokens = line.strip().split(',')
            subject_id = int(tokens[0])
            hadm_id = int(tokens[1])
            admissions_by_subject[subject_id][hadm_id] = tokens[2]  # admittime

    # The admissions file is not chronological per patient, and a set has no
    # meaningful order, so sort each patient's admissions explicitly. The
    # 'YYYY-MM-DD HH:MM:SS' timestamps sort correctly as plain strings;
    # hadm_id breaks ties deterministically.
    subject_idAdmMap = {
        subject_id: sorted(adm_times, key=lambda hadm_id: (adm_times[hadm_id], hadm_id))
        for subject_id, adm_times in admissions_by_subject.items()
    }

    print ('Building admission-diagnosis mapping')
    admDxMap = get_ICDs_from_mimic_file(diagnosisFile)

    print ('Building admission-procedure mapping')
    admPxMap = get_ICDs_from_mimic_file(procedureFile, isdiagnosis=False)

    print ('Building admission-drug mapping')
    drugDescription, admDrugMap = get_drugs_from_mimic_file(prescriptionFile, choice)
    return subject_idAdmMap,admDxMap,admPxMap,admDrugMap,drugDescription


In [10]:
def updateAdmCodeList(subject_idAdmMap,admDxMap,admPxMap,admDrugMap):
    """
    Restrict the three admission -> codes maps to the admissions still listed
    in subject_idAdmMap.

    inputs:
    subject_idAdmMap : dict that maps subject ids to hospital admission ids
    admDxMap : dict that maps hospital admission ids to diagnosis codes
    admPxMap : dict that maps hospital admission ids to procedure codes
    admDrugMap : dict that maps hospital admission ids to drug codes

    outputs:
    (diagnosis, procedure, drug) dicts holding only the retained admissions.
    An admission missing from any of the three input maps raises KeyError.
    """
    kept_dx, kept_px, kept_drug = {}, {}, {}
    for admissions in subject_idAdmMap.values():
        for adm_id in admissions:
            kept_dx[adm_id] = admDxMap[adm_id]
            kept_px[adm_id] = admPxMap[adm_id]
            kept_drug[adm_id] = admDrugMap[adm_id]
    return kept_dx, kept_px, kept_drug

In [11]:
def ListAvgVisit(dic):
    """Return the mean length of the values of ``dic`` (ZeroDivisionError if it is empty)."""
    total_items = sum(len(entries) for entries in dic.values())
    return total_items / len(dic)

In [12]:
# New
def countCodes(*dicts):
    """Return how many distinct codes appear across the value lists of all given dicts."""
    distinct_codes = set()
    for dic in dicts:
        for code_list in dic.values():
            distinct_codes.update(code_list)
    return len(distinct_codes)

In [13]:
def display(pidAdmMap,admDxMap,admPxMap,admDrugMap):
    """
    Print summary statistics for a cohort: patient and admission counts,
    distinct code counts and average codes per visit.

    Each figure is computed right before it is printed, so the lines emitted
    before a failing computation still appear.
    """
    report = (
        (" Total Number of patients ", lambda: len(pidAdmMap)),
        (" Total Number of admissions ", lambda: len(admDxMap)),
        (" Average number of admissions per patient ", lambda: ListAvgVisit(pidAdmMap)),
        (" Total Number of diagnosis code ", lambda: countCodes(admDxMap)),
        (" Total Number of procedure code ", lambda: countCodes(admPxMap)),
        (" Total Number of drug code ", lambda: countCodes(admDrugMap)),
        (" Total Number of codes ",
         lambda: countCodes(admPxMap) + countCodes(admDxMap) + countCodes(admDrugMap)),
        (" average Number of procedure code per visit ", lambda: ListAvgVisit(admPxMap)),
        (" average Number of diagnosis code per visit ", lambda: ListAvgVisit(admDxMap)),
        (" average Number of Drug code per visit ", lambda: ListAvgVisit(admDrugMap)),
    )
    for label, compute in report:
        print(f"{label}{compute()}")

In [14]:
def clean_data(subject_idAdmMap,admDxMap,admPxMap,admDrugMap, min_admissions_threshold = 2):
    """
    Drop patients with incomplete admissions, then patients with too few admissions.

    A patient is removed when any of their admissions is missing from the
    diagnosis, procedure or drug map, or when they have fewer than
    min_admissions_threshold admissions. subject_idAdmMap is modified in
    place (and also returned); the three code maps are rebuilt so that they
    only contain the surviving admissions. Summary statistics are printed
    through ``display``.

    outputs:
    subject_idAdmMap, adDx, adPx, adDrug : the filtered maps
    """
    print("Cleaning data...")

    print("Removing patient records who does not have all three medical codes for an admission")
    incomplete_subjects = {
        subject_id
        for subject_id, hadm_ids in subject_idAdmMap.items()
        for hadm_id in hadm_ids
        if hadm_id not in admDxMap or hadm_id not in admPxMap or hadm_id not in admDrugMap
    }
    for subject_id in incomplete_subjects:
        del subject_idAdmMap[subject_id]

    adDx, adPx, adDrug = updateAdmCodeList(subject_idAdmMap, admDxMap, admPxMap, admDrugMap)

    print(f"Removing patient who made less than {min_admissions_threshold} admissions")
    # Collect first, delete afterwards: a dict cannot shrink while it is iterated.
    infrequent_subjects = [
        subject_id
        for subject_id, hadm_ids in subject_idAdmMap.items()
        if len(hadm_ids) < min_admissions_threshold
    ]
    for subject_id in infrequent_subjects:
        del subject_idAdmMap[subject_id]

    adDx, adPx, adDrug = updateAdmCodeList(subject_idAdmMap, adDx, adPx, adDrug)
    display(subject_idAdmMap, adDx, adPx, adDrug)
    return subject_idAdmMap, adDx, adPx, adDrug

In [15]:
def create_CCS_CCSR_mapping(CCSRDX_file,CCSRPCS_file,CCSDX_file,CCSPX_file, dump = True):
    """
    Build the ICD -> CCS/CCSR category mapping and the CCS description table.

    inputs:
    CCSRDX_file : CSV with the quoted columns 'ICD-10-CM CODE' and
                  'CCSR CATEGORY 1'..'CCSR CATEGORY 6'
    CCSRPCS_file : CSV with the quoted columns 'ICD-10-PCS' and 'PRCCSR'
    CCSDX_file : ICD-9 diagnosis CCS file; its first three lines are skipped
                 and the first three comma-separated fields of every other
                 line are read as code, CCS category, CCS description
    CCSPX_file : same layout as CCSDX_file, for ICD-9 procedures
    dump : when True, pickle the ICD mapping to 'ICD_9_10_to_CSS' and the
           description table to 'ccs_to_description_dictionary' in the
           current working directory

    outputs:
    ccsTOdescription_Map : dict mapping 'D9_<ccs>' / 'P9_<ccs>' to its description

    The ICD mapping itself is only written to disk (not returned). Its keys
    are 'D10_', 'P10_', 'D9_' or 'P9_' prefixed ICD codes; ICD-10 values are
    lists of prefixed CCSR categories, ICD-9 values are a single prefixed
    CCS category string.
    """
    def _unquote(frame):
        # Every cell in the CCSR CSVs is wrapped in single quotes; drop the
        # first and last character of each cell's string form.
        return frame.apply(lambda column: column.map(lambda cell: str(cell)[1:-1]))

    def _ccsr_table(frame, key_column, prefix):
        # {prefixed ICD-10 code: [prefixed, non-blank CCSR categories]}
        raw = _unquote(frame).set_index(key_column).T.to_dict('list')
        return {
            prefix + icd_code: [prefix + category for category in categories if category.strip()]
            for icd_code, categories in raw.items()
        }

    b = {}
    ccsTOdescription_Map = {}

    def _load_ccs(path, prefix):
        # ICD-9 single-level CCS file: three preamble lines (note, header, null row).
        with open(path, 'r') as ccs_file:
            for _ in range(3):
                ccs_file.readline()
            for line in ccs_file:
                if not line.strip():
                    continue  # tolerate blank lines, e.g. at the end of the file
                tokens = line.strip().split(',')
                # [1:-1] removes the surrounding quotes of each field.
                icd_code = prefix + tokens[0][1:-1].strip()
                ccs_code = prefix + tokens[1][1:-1].strip()
                b[icd_code] = ccs_code
                ccsTOdescription_Map[ccs_code] = tokens[2][1:-1].strip()

    # ICD-10 diagnoses and procedures -> CCSR categories.
    diag_columns = ["'ICD-10-CM CODE'"] + [f"'CCSR CATEGORY {i}'" for i in range(1, 7)]
    b.update(_ccsr_table(pd.read_csv(CCSRDX_file)[diag_columns], "'ICD-10-CM CODE'", 'D10_'))

    proc_frame = pd.read_csv(CCSRPCS_file, on_bad_lines = 'skip' )
    b.update(_ccsr_table(proc_frame[["'ICD-10-PCS'", "'PRCCSR'"]], "'ICD-10-PCS'", 'P10_'))

    # ICD-9 diagnoses and procedures -> CCS categories. Diagnosis and procedure
    # ICD-9 codes intersect, hence the distinct prefixes.
    _load_ccs(CCSDX_file, 'D9_')
    _load_ccs(CCSPX_file, 'P9_')

    if dump:
        with open('ICD_9_10_to_CSS', 'wb') as out_file:
            pickle.dump(b, out_file, -1)
        with open('ccs_to_description_dictionary', 'wb') as out_file:
            pickle.dump(ccsTOdescription_Map, out_file, -1)
    print ('Total ICD to ccs entries: ' + str(len(b)))
    print( 'Total ccs codes/descriptions: ' + str(len(ccsTOdescription_Map)))

    # ICD-9 targets are plain strings while ICD-10 targets are lists; iterating
    # a string would count its characters instead of the category itself.
    unique_targets = set()
    for target in b.values():
        if isinstance(target, str):
            unique_targets.add(target)
        else:
            unique_targets.update(target)
    print("total number of unqiue codes(DIag + proc):", len(unique_targets))

    return ccsTOdescription_Map

In [16]:
def map_ccsr_description(filename, cat = 'Diag'):
    """
    Read CCSR category descriptions from the "CCSR_Categories" sheet of an
    Excel reference file.

    inputs:
    filename : path to the Excel file (its first row is skipped)
    cat : 'Diag' to prefix categories with 'D10_'; any other value prefixes
          them with 'P10_' and drops the last row of the sheet

    outputs:
    dict mapping '<prefix><CCSR Category>' to its 'CCSR Category Description'
    """
    padStr = 'D10_' if cat == 'Diag' else 'P10_'
    df = pd.read_excel(filename, sheet_name="CCSR_Categories", skiprows=1)
    # Compare the `cat` argument, not the `type` builtin, which never equals
    # a string and therefore dropped the last row for diagnoses as well.
    # NOTE(review): the last procedure row is presumably a footer - confirm.
    if cat != 'Diag':
        df = df[:-1]
    return {
        padStr + str(category): str(description)
        for category, description in zip(df["CCSR Category"], df["CCSR Category Description"])
    }

In [17]:
def convValuestoList(codeDic):
    """Wrap every value of ``codeDic`` in a one-element list, in place, and return the same dict."""
    codeDic.update({key: [value] for key, value in codeDic.items()})
    return codeDic

In [18]:
def map_ICD_to_CCSR(mapping):
    """
    Translate the ICD codes of every admission into CCS/CCSR categories.

    The ICD -> category table is loaded from the 'ICD_9_10_to_CSS' pickle in
    the current working directory (written by create_CCS_CCSR_mapping).

    inputs:
    mapping : dict that maps hadm_id to a list of prefixed ICD codes
              ('D9_', 'D10_', 'P9_' or 'P10_')

    outputs:
    CodesToInternalMap : dict that maps hadm_id to the list of categories of
                         its codes; admissions with no mappable code are absent
    missingCodes : list of the ICD codes (with repetitions) that have no
                   entry in the table
    set_of_used_codes : set of the ICD codes that were found in the table
    """
    # NOTE: pickle can execute arbitrary code; only load a file this notebook produced.
    with open('ICD_9_10_to_CSS', 'rb') as table_file:
        icdTOCCS_Map = pickle.load(table_file)

    CodesToInternalMap = {}
    missingCodes = []
    set_of_used_codes = set()
    countICD9 = 0
    countICD10 = 0

    for (hadm_id, ICDs_List) in mapping.items():
        for ICD in ICDs_List:
            if ICD.startswith(('D10_', 'P10_')):
                is_icd10 = True
            elif ICD.startswith(('D9_', 'P9_')):
                is_icd10 = False
            else:
                is_icd10 = None
                print("Wrong coding format")

            if ICD not in icdTOCCS_Map:
                missingCodes.append(ICD)
                continue

            CCS_code = icdTOCCS_Map[ICD]
            # ICD-9 entries map to one category string, ICD-10 entries to a list.
            categories = [CCS_code] if isinstance(CCS_code, str) else list(CCS_code)
            if categories:
                CodesToInternalMap.setdefault(hadm_id, []).extend(categories)
            set_of_used_codes.add(ICD)

            # Count mapped code occurrences per ICD revision for the report below.
            if is_icd10 is True:
                countICD10 += 1
            elif is_icd10 is False:
                countICD9 += 1

    number_of_codes_missing = len(missingCodes)
    print(f"total number of ICD9 codes used {countICD9} and ICD10 codes: {countICD10}")
    print ('-Total number (complete set) of ICD9+ICD10 codes (diag + proc): ' + str(len(set(icdTOCCS_Map.keys()))))
    print ('-Total number of ICD codes actually used: ' + str(len(set_of_used_codes)))
    print ('-Total number of ICD codes missing in the admissions list: ' , number_of_codes_missing)

    return CodesToInternalMap,missingCodes,set_of_used_codes

In [19]:
def displayCodeStats(adDx,adPx,adDrug):
    """
    Print code statistics for the three admission -> codes maps: distinct
    code counts, average codes per visit and min/max codes per admission.

    Each figure is computed right before it is printed, so the lines emitted
    before a failing computation still appear.
    """
    report = (
        (" Total Number of diagnosis code ", lambda: countCodes(adDx)),
        (" Total Number of procedure code ", lambda: countCodes(adPx)),
        (" Total Number of drug code ", lambda: countCodes(adDrug)),
        (" Total Number of unique  D,P codes ", lambda: countCodes(adDx, adPx)),
        (" Total Number of all codes ", lambda: countCodes(adDx, adPx, adDrug)),
        (" average Number of procedure code per visit ", lambda: ListAvgVisit(adPx)),
        (" average Number of diagnosis code per visit ", lambda: ListAvgVisit(adDx)),
        (" average Number of drug code per visit ", lambda: ListAvgVisit(adDrug)),
        (" Min. and max. Number of diagnosis code per admission ", lambda: minMaxCodes(adDx)),
        (" Min. and max. Number of procedure code  per admission", lambda: minMaxCodes(adPx)),
        (" Min. and max. Number of drug code  per admission ", lambda: minMaxCodes(adDrug)),
    )
    for label, compute in report:
        print(f"{label}{compute()}")

In [20]:
def minMaxCodes(dic):
    """Return (smallest, largest) length among the values of ``dic``; ValueError if it is empty."""
    sizes = [len(code_list) for code_list in dic.values()]
    return min(sizes), max(sizes)

In [21]:
def icd_mapping(CCSRDX_file,CCSRPCS_file,CCSDX_file,CCSPX_file,D_CCSR_Ref_file,P_CCSR_Ref_file,adDx,adPx,adDrug,drugDescription):
    """
    Map the diagnosis and procedure codes of every admission to CCS/CCSR
    categories and assemble the description table for all tokens.

    inputs:
    CCSRDX_file, CCSRPCS_file, CCSDX_file, CCSPX_file : mapping files passed
        to create_CCS_CCSR_mapping
    D_CCSR_Ref_file, P_CCSR_Ref_file : Excel reference files passed to
        map_ccsr_description for diagnoses and procedures
    adDx, adPx, adDrug : dicts mapping admission ids to code lists
    drugDescription : dict that maps drug codes to their name

    outputs:
    adDx : admission -> mapped diagnosis categories
    adPx : admission -> mapped procedure categories
    codeDescription : dict with a description for CCSR categories, CCS
        categories (each wrapped in a one-element list by convValuestoList),
        drug codes and the special sequence tokens
    """
    # creating mapping between all ICD codes to CCS and CCSR mapping
    ccsTOdescription_Map = create_CCS_CCSR_mapping(CCSRDX_file,CCSRPCS_file,CCSDX_file,CCSPX_file)
    # getting the description of all codes; later entries win on key clashes
    codeDescription = {
        **map_ccsr_description(D_CCSR_Ref_file),
        **map_ccsr_description(P_CCSR_Ref_file, cat = 'Proc'),
        **convValuestoList(ccsTOdescription_Map),
        **drugDescription,
    }
    # mapping diagnosis codes, then procedure codes
    adDx, _missing_dx, _used_dx = map_ICD_to_CCSR(adDx)
    adPx, _missing_px, _used_px = map_ICD_to_CCSR(adPx)

    # Special tokens. 'BOH' is the beginning-of-history token that removeCode
    # puts in the vocabulary; 'SOH' is kept for existing consumers.
    codeDescription['SOH'] = 'Start of history'
    codeDescription['BOH'] = 'Beginning of history'
    codeDescription['EOH'] = 'End of history'
    codeDescription['BOV'] = 'Beginning of visit'
    codeDescription['EOV'] = 'End of visit'
    codeDescription['BOS'] = 'Beginning of sequence'
    codeDescription['PAD'] = 'Padding'
    displayCodeStats(adDx,adPx,adDrug)
    return adDx,adPx,codeDescription

In [22]:
def trim(adDx, adPx, adDrug, min_dxm, min_px, min_drg):
    """
    Truncate the code lists of every admission, in place.

    inputs:
    adDx, adPx, adDrug : dicts mapping admission ids to code lists
    min_dxm : number of leading diagnosis codes kept per admission
    min_px : number of leading procedure codes kept per admission
    min_drg : number of leading drug codes kept per admission

    outputs:
    the same three dicts, after truncation; code statistics are printed
    through displayCodeStats.
    """
    print("Trimming the diagnosis, procedure, and medication codes for each visit")

    # Replacing the value of an existing key while iterating is safe: the
    # set of keys does not change.
    for code_map, limit in ((adDx, min_dxm), (adPx, min_px), (adDrug, min_drg)):
        for admission, codes in code_map.items():
            code_map[admission] = codes[:limit]

    displayCodeStats(adDx, adPx, adDrug)
    return adDx, adPx, adDrug


In [23]:
def buildData(subject_idAdmMap,adDx,adPx,adDrug, minVisits = 2):
    """
    Turn per-admission code maps into integer-encoded visit sequences.

    Patients with fewer than minVisits admissions are skipped. For every
    remaining patient, each admission (in the order given by
    subject_idAdmMap) becomes one visit: its diagnosis, procedure and drug
    codes concatenated in that order with duplicates removed (first
    occurrence kept). An admission missing from a map contributes no codes
    from it. Codes are then replaced by integers assigned in order of first
    appearance.

    outputs:
    newSeqs : list (one entry per kept patient) of visits, each a list of ints
    types : dict mapping each code string to its integer id
    """
    print (f'Building admission-Visits mapping & filtering patients with less than {minVisits} ')
    skipped = 0
    patients = []
    for admIdList in subject_idAdmMap.values():
        if len(admIdList) < minVisits:
            skipped += 1
            continue
        patients.append([
            (adDx.get(admId, []), adPx.get(admId, []), adDrug.get(admId, []))
            for admId in admIdList
        ])

    print(f'{skipped} subjects were removed')
    print ('Building subject-id, diagnosis,procedure,drugs mapping')
    # dict.fromkeys acts as an insertion-ordered set.
    seqs = [
        [list(dict.fromkeys(chain.from_iterable(visit))) for visit in visits]
        for visits in patients
    ]

    print ('Converting Strings Codes into unique integer, and making types')
    types = {}
    newSeqs = [
        [[types.setdefault(code, len(types)) for code in visit] for visit in patient]
        for patient in seqs
    ]
    return newSeqs,types

In [24]:
def ListAvgVisitForRemoveCode(dic):
    """Return the mean length of the items of the iterable ``dic`` (ZeroDivisionError if it is empty)."""
    sizes = list(map(len, dic))
    return sum(sizes) / len(sizes)

In [25]:
def removeCode(currentSeqs, types, threshold=5):
    """
    Drop infrequent codes from the sequences and re-number the vocabulary.

    inputs:
    currentSeqs : list of patients, each a list of visits, each a list of
                  integer codes
    types : dict mapping code strings to the integers used in currentSeqs
    threshold : codes occurring at most this many times in total are removed

    outputs:
    updatedSeqs : the sequences without the removed codes, re-encoded
    types : new dict mapping code strings to integers; ids 0-5 are reserved
            for PAD, BOH, BOS, BOV, EOV, EOH and the remaining ids are
            assigned in order of first appearance in the kept data
    reverseTypes : inverse of the new types dict
    """
    print(ListAvgVisitForRemoveCode(currentSeqs))

    countCode = Counter(
        code for visits in currentSeqs for visit in visits for code in visit
    )
    # A set keeps the membership test in the innermost loop below O(1).
    rare_codes = {code for code, occurrences in countCode.items() if occurrences <= threshold}

    print(f" Total number of codes removed: {len(rare_codes)}  ")
    print(f" Total number of  unique codes : {len(countCode)}  ")

    old_reverse = {index: name for name, index in types.items()}

    new_types = {"PAD": 0,"BOH":1 ,"BOS": 2, "BOV": 3, "EOV": 4, "EOH": 5}

    # Recreates a new mapping while taking into consideration the removed tokens
    updatedSeqs = []
    for patient in currentSeqs:
        new_patient = []
        for visit in patient:
            new_visit = []
            for code in visit:
                if code in rare_codes:
                    continue
                # setdefault evaluates len(new_types) before inserting, so a
                # new code receives the next free id.
                new_visit.append(new_types.setdefault(old_reverse[code], len(new_types)))
            new_patient.append(new_visit)
        updatedSeqs.append(new_patient)

    reverseTypes = {index: name for name, index in new_types.items()}

    return updatedSeqs, new_types, reverseTypes

In [26]:
def saveFiles(updatedSeqs,types,codeDescription,outpath = 'outputData/originalData'):
    """
    Pickle the sequences, the vocabulary and the code descriptions.

    inputs:
    updatedSeqs, types, codeDescription : objects to persist
    outpath : file prefix; the objects are written to '<outpath>.seqs',
              '<outpath>.types' and '<outpath>.description'. The parent
              directory of the prefix is created when it does not exist.
    """
    parent_dir = os.path.dirname(outpath)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    payloads = (
        ('.seqs', updatedSeqs),
        ('.types', types),
        ('.description', codeDescription),
    )
    for suffix, payload in payloads:
        with open(outpath + suffix, 'wb') as out_file:
            pickle.dump(payload, out_file, -1)

In [27]:
def generateCodeTypes(outFile,reverseTypes):
    """
    Classify every vocabulary entry as diagnosis, procedure, drug or special token.

    inputs:
    outFile : file prefix; the result is pickled to '<outFile>.codeType'
    reverseTypes : dict mapping integer ids to code strings

    outputs:
    codeType : dict mapping each id to 'DR' (code starts with 'DR_'),
               'T' (PAD, BOH, BOS, BOV, EOV, EOH), or 'D' / 'P' depending on
               whether the code is a category of a 'D...' or 'P...' key of
               the 'ICD_9_10_to_CSS' pickle. Ids whose code matches none of
               these are printed and left out.

    The four counters (D, P, DR, T) are printed before saving.
    """
    # NOTE: pickle can execute arbitrary code; only load a file this notebook produced.
    with open('ICD_9_10_to_CSS','rb') as table_file:
        ICD_9_10_to_CSS = pickle.load(table_file)

    special_tokens = {'PAD', 'BOH', 'BOS', 'BOV', 'EOV', 'EOH'}

    # Index category -> 'D'/'P' once instead of scanning the whole table for
    # every code. Exact matching also avoids treating a code as found just
    # because it is a substring of an ICD-9 category string.
    category_kind = {}
    for icd_code, categories in ICD_9_10_to_CSS.items():
        if icd_code.startswith('D'):
            kind = 'D'
        elif icd_code.startswith('P'):
            kind = 'P'
        else:
            continue
        for category in ([categories] if isinstance(categories, str) else categories):
            category_kind.setdefault(category, kind)  # first ICD code decides

    codeType = {}
    counts = {'D': 0, 'P': 0, 'DR': 0, 'T': 0}
    for keys,values in reverseTypes.items():
        if values.startswith('DR_'):
            kind = 'DR'
        elif values in special_tokens:
            kind = 'T'
        else:
            kind = category_kind.get(values)

        if kind is None:
            print(keys,values)  # code that could not be classified
            continue
        codeType[keys] = kind
        counts[kind] += 1

    print(counts['D'],counts['P'],counts['DR'],counts['T'])
    with open(outFile+'.codeType', 'wb') as out_file:
        pickle.dump(codeType, out_file, -1)

    return codeType

In [28]:
def load_data(outFile):
    """
    Reload the artefacts saved under the file prefix ``outFile``.

    Reads '<outFile>.seqs', '<outFile>.types' and '<outFile>.codeType'.

    outputs:
    seqs, types, codeType : the unpickled objects
    reverseTypes : inverse of ``types`` (value -> key)
    """
    def _unpickle(suffix):
        # NOTE: pickle can execute arbitrary code; only load files this notebook produced.
        with open(outFile + suffix, 'rb') as in_file:
            return pickle.load(in_file)

    seqs = _unpickle('.seqs')
    types = _unpickle('.types')
    codeType = _unpickle('.codeType')
    reverseTypes = {v:k for k,v in types.items()}
    return seqs,types,codeType,reverseTypes

In [29]:
def PrepareForTF(sequence):
    """
    Build (history, future) splits of a visit sequence and pair them with ``pairing1``.

    For every split point between the first and last element, the input is
    everything before the split and the target is everything from the split
    to the end of the sequence.
    """
    split_points = range(1, len(sequence))
    X = [sequence[:cut] for cut in split_points]
    y = [sequence[cut:] for cut in split_points]
    return pairing1(X, y)

In [30]:
def PrepareForSDP(sequence):
    """
    Build (history, next element) splits of a visit sequence and pair them with ``pairing1``.

    For every split point between the first and last element, the input is
    everything before the split and the target is a one-element list holding
    the element right after it.
    """
    split_points = range(1, len(sequence))
    X = [sequence[:cut] for cut in split_points]
    y = [[sequence[cut]] for cut in split_points]
    return pairing1(X, y)

In [31]:
def PrepareForSDPclean(sequence):
    """
    Pair every prefix of ``sequence`` with the single element that follows it.

    The prefixes and the one-element targets are handed to ``pairing1``,
    whose result is returned.
    """
    prefixes = []
    targets = []
    upcoming = 1
    while upcoming < len(sequence):
        prefixes.append(sequence[:upcoming])
        targets.append([sequence[upcoming]])
        upcoming += 1
    return pairing1(prefixes, targets)

In [32]:
def PrepareForDAI(sequence, n_steps):
    """
    Collect suffixes of ``sequence`` and pair them with ``pairing2``.

    A suffix starting at index i is kept for i = 0, 1, ... while both
    i < n_steps and i + n_steps <= len(sequence) - 1 hold.
    """
    # Number of start indices satisfying both bounds; a non-positive value
    # yields an empty range.
    usable_starts = min(n_steps, len(sequence) - n_steps)
    X = [sequence[start:] for start in range(usable_starts)]
    return pairing2(X)

In [34]:
def removePairs(newPairs, mn = 600):
    """Drop every pair in which either member is longer than ``mn``.

    Prints the number of pairs before and after filtering and the number
    removed.

    Args:
        newPairs: iterable-with-length of pairs; ``pair[0]`` and ``pair[1]``
            must support ``len``.
        mn: maximum allowed length for each member of a pair.

    Returns:
        A new list holding the pairs that were kept, in their original order.
    """
    total_before = len(newPairs)
    print(f"\n  Total no of pairs before removing :{total_before}")

    kept = [
        pair for pair in newPairs
        if max(len(pair[0]), len(pair[1])) <= mn
    ]

    print(f"\n  Total no of pairs after removing :{len(kept)}")
    print(f"\n  Total no of pairs removed :{total_before-len(kept)}")
    return kept

In [35]:
def formatData(originalSeqs, dataFormat = 'TF', mn = 400):
    """Turn per-patient sequences into flat, token-delimited (input, output) pairs.

    Steps:
      1. Each entry of ``originalSeqs`` is expanded into raw pairs with
         ``PrepareForTF`` / ``PrepareForSDP`` / ``PrepareForDAI`` depending on
         ``dataFormat`` ('TF', 'SDP' or 'DAI'). Any other value prints
         "Wrong Format" once per entry and contributes no pairs.
      2. Each side of a raw pair (a list of lists) is flattened into
         ``[1] + (inner list + [2] for every inner list) + [3]``.
         Presumably 1/2/3 are start / end-of-visit / end tokens - confirm
         against the model code.
      3. If ``stats`` reports that any side exceeds ``mn``, the offending
         pairs are dropped via ``removePairs``.

    Args:
        originalSeqs: indexable collection of sequences.
        dataFormat: which pairing scheme to use.
        mn: maximum allowed flattened length of either side of a pair.

    Returns:
        List of ``(input_tokens, output_tokens)`` tuples.
    """
    def _flatten(inner_lists):
        # ``inner + [2]`` builds a fresh list, so the caller's data is untouched.
        return [1] + [token for inner in inner_lists for token in inner + [2]] + [3]

    raw_pairs = []
    for idx in range(len(originalSeqs)):
        current = originalSeqs[idx]
        if dataFormat == 'TF':
            # Trajectory forecasting (TF): predict until the end of EOH
            raw_pairs.extend(PrepareForTF(current))
        elif dataFormat == 'SDP':
            # Sequential disease prediction (SDP): predict until the next visit
            raw_pairs.extend(PrepareForSDP(current))
        elif dataFormat == 'DAI':
            raw_pairs.extend(PrepareForDAI(current, 1))
        else:
            print("Wrong Format")

    newPairs = [(_flatten(raw[0]), _flatten(raw[1])) for raw in raw_pairs]

    needs_trimming = stats(newPairs, mn = mn)
    if needs_trimming:
        print(f"\n\n\nRemoving pairs greater than  {mn} seq length")
        newPairs = removePairs(newPairs, mn=mn)
        # Result is discarded; kept as in the original call sequence.
        stats(newPairs)
    return newPairs

In [36]:
def pairing1(x,y):
    """Zip ``x`` and ``y`` element-wise into a list of ``(x_i, y_i)`` tuples.

    The result is as long as the shorter of the two inputs.
    """
    return list(zip(x, y))

In [37]:
def pairing2(x):
    """Pair each sequence with itself shifted by one position.

    For every ``seq`` in ``x`` the result holds
    ``(seq[:-1], seq[1:])`` - all but the last item as input, all but the
    first item as target.
    """
    return [(seq[:-1], seq[1:]) for seq in x]

In [38]:
def pairing3(x):
    """Pair each sequence's head with its final item.

    For every ``seq`` in ``x`` the result holds ``(seq[:-1], [seq[-1]])``.
    An empty ``seq`` raises ``IndexError``.
    """
    return [(seq[:-1], [seq[-1]]) for seq in x]

In [39]:
def stats(newPairs,mn =600):
    """Report whether any pair has a member longer than ``mn``.

    Args:
        newPairs: iterable of pairs; ``pair[0]`` and ``pair[1]`` must support
            ``len``.
        mn: length threshold.

    Returns:
        ``True`` if at least one side of at least one pair exceeds ``mn``,
        otherwise ``False``.
    """
    # Measure every pair up front so all of them are inspected, as before.
    lengths = [(len(pair[0]), len(pair[1])) for pair in newPairs]
    return any(max(in_len, out_len) > mn for in_len, out_len in lengths)

In [47]:
def resetIntegerOutput(updSeqs,isall =1):
    """Renumber the codes on the output side of each pair into a dense range.

    Codes 0-3 map to themselves; every other code receives the next free
    integer (4, 5, ...) in order of first appearance across all outputs.
    The input side of each pair is passed through untouched.

    Args:
        updSeqs: iterable of pairs ``(input, output)`` where ``output`` is an
            iterable of hashable codes.
        isall: unused; kept for interface compatibility.

    Returns:
        ``(pairs, outTypes)`` - the re-encoded pairs and the
        ``{previous code: new code}`` mapping.
    """
    outTypes = {token: token for token in range(4)}
    recoded_pairs = []
    for pair in updSeqs:
        # setdefault evaluates len(outTypes) before inserting, so an unseen
        # code gets the current size of the mapping as its new id.
        recoded = [outTypes.setdefault(code, len(outTypes)) for code in pair[1]]
        recoded_pairs.append((pair[0], recoded))
    return recoded_pairs, outTypes

In [45]:
def updateOutput(newPairs, codeType, diagnosis=0, procedure=0, drugs =0, all = 0):
    """Filter the output side of each pair by code type.

    Modes (selected by the flags):
      * ``procedure == 1 and drugs == 1``: keep only codes whose type is
        'D' or 'T'.
      * ``drugs == 1 and procedure == 0``: keep every code whose type is not
        'DR'.
      * ``all`` truthy: return a shallow copy of ``newPairs`` (this overrides
        the result of the modes above).
    In the two filtering modes a pair is kept only if its filtered output
    still has at least 4 codes. With no matching mode the result is empty.

    Args:
        newPairs: list of ``(input, output)`` pairs.
        codeType: mapping from code to its type string.
        diagnosis: unused; kept for interface compatibility.
        procedure, drugs, all: mode flags, see above.

    Returns:
        List of ``(input, filtered_output)`` pairs.
    """
    def _filter_outputs(keep):
        # Apply ``keep`` to the type of every output code; drop short results.
        survivors = []
        for pair in newPairs:
            trimmed = [code for code in pair[1] if keep(code)]
            if len(trimmed) >= 4:
                survivors.append((pair[0], trimmed))
        return survivors

    updSeqs = []
    if procedure == 1 and drugs == 1:
        print("\n Removing drug and procedure codes from output for forecasting diagnosis code only")
        updSeqs = _filter_outputs(
            lambda code: codeType[code] == 'D' or codeType[code] == 'T'
        )
    if drugs == 1 and procedure == 0:
        print("\n Removing only drug codes from output for forecasting diagnosis and procedure code only")
        updSeqs = _filter_outputs(lambda code: not (codeType[code] == 'DR'))
    if all:
        print("\n keeping all codes")
        updSeqs = newPairs.copy()

    return updSeqs

In [52]:
def storeFiles(pair,outTypes,codeType,types,reverseTypes,outFile):
    """Pickle one task's artefacts next to the ``outFile`` path prefix.

    Files written (highest pickle protocol):
      * ``outFile + '.seqs'``            - ``pair``
      * ``outFile + '.outTypes'``        - ``outTypes``
      * ``outFile + '.codeType'``        - ``codeType``
      * ``outFile + '.types'``           - ``types``
      * ``outFile + '.reverseTypes'``    - ``reverseTypes``
      * ``outFile + '.reverseOutTypes'`` - ``outTypes`` inverted (value -> key)

    Previously the inverted ``outTypes`` mapping was dumped to
    ``'.reverseTypes'`` as well, overwriting ``reverseTypes`` so that it was
    never persisted. It now goes to its own ``'.reverseOutTypes'`` file.

    Args:
        pair: sequence pairs to store.
        outTypes: mapping ``{previous code: new code}`` for the output side.
        codeType: mapping from code to type string.
        types: code vocabulary mapping.
        reverseTypes: inverse of ``types``.
        outFile: path prefix; a directory with this exact name is also created
            if missing, which as a side effect creates the parent directories
            the files are written into.
    """
    if not os.path.exists(outFile):
        os.makedirs(outFile)

    reverseOutTypes = {v: k for k, v in outTypes.items()}
    artefacts = (
        ('.seqs', pair),
        ('.outTypes', outTypes),
        ('.codeType', codeType),
        ('.types', types),
        ('.reverseTypes', reverseTypes),
        ('.reverseOutTypes', reverseOutTypes),
    )
    for suffix, payload in artefacts:
        # Context manager flushes and closes each file before the next write.
        with open(outFile + suffix, 'wb') as handle:
            pickle.dump(payload, handle, -1)

In [40]:
# Directory holding the CCS / CCSR lookup files; every path below is built
# from it and later passed to icd_mapping().
CCS_DIR = './CSS/'

# CCSR mapping tables (these rebind the names set in the path-setup cell).
CCSRDX_file, CCSRPCS_file = (
    os.path.join(CCS_DIR, fname)
    for fname in ('DXCCSR_v2021-2.csv', 'PRCCSR_v2021-1.CSV')
)

# 2015 CCS reference tables.
CCSDX_file, CCSPX_file = (
    os.path.join(CCS_DIR, fname)
    for fname in ('$dxref 2015.csv', '$prref 2015.csv')
)

# CCSR reference workbooks.
D_CCSR_Ref_file, P_CCSR_Ref_file = (
    os.path.join(CCS_DIR, fname)
    for fname in ('DXCCSR-Reference-File-v2021-2.xlsx', 'PRCCSR-Reference-File-v2021-1.xlsx')
)

In [54]:
# End-to-end preprocessing driver: load -> clean -> map ICD codes -> trim ->
# build sequences -> drop rare codes -> save -> format for each task -> store.
# The TF formatting step and the three output variants are each computed once;
# earlier copy-pasted repeats of those statements recomputed identical values.

# stage 1
print("Loading the data...")
subject_idAdmMap,admDxMap,admPxMap,admDrugMap,drugDescription = load_mimic_data()
print("\n Completed...")
#stage 2 and 3
print("\n Cleaning data...")
subject_idAdmMap,adDx,adPx,adDrug = clean_data(subject_idAdmMap,admDxMap,admPxMap,admDrugMap)
print("\n Completed...")
#stage 4
print("\nMapping ICD data to CCS and CCSR...")
adDx,adPx,codeDescription = icd_mapping(CCSRDX_file,CCSRPCS_file,CCSDX_file,CCSPX_file,D_CCSR_Ref_file,P_CCSR_Ref_file,adDx,adPx,adDrug,drugDescription)
print("\n Completed...")
#stage 5
print("\n Trimming the codes assigned per visit based on a threshold...")
min_dx, min_px, min_drg = 80, 80, 80 
adDx, adPx, adDrug= trim(adDx, adPx, adDrug, min_dx, min_px, min_drg)
print("\n Completed...")
print("\n Building the data..")
newSeqs,types=buildData(subject_idAdmMap,adDx,adPx,adDrug)
#stage 6
threshold = 5
print(f"\n removing the code whose occurence is less than a certain threshold: {threshold}")
updatedSeqs ,types ,reverseTypes  = removeCode(newSeqs,types,threshold=threshold)
# outFile - is a folder path in the working directory where the data is going to get stored
outFile = os.path.join('outputData','originalData')
print("\n Save the data before formmating based on the task")
saveFiles(updatedSeqs,dict(types),codeDescription)
codeType = generateCodeTypes(outFile,reverseTypes)
# Reload from disk so the remaining steps work on exactly what was persisted.
seqs,types,codeType,reverseTypes = load_data(outFile)
print("\n Completed...")

# ---- Trajectory Forecasting (TF) ----
print("\n Preparing data for Trajectory Forecasting....")
# sequence length threshold  -mn
seqLength = 500
newPairs = formatData(seqs,dataFormat = 'TF', mn = seqLength)
diagnosisOutputFile = os.path.join('outputData','TF','Inp_d_p_dr_out_d')
diagnosisProcedureOutputFile = os.path.join('outputData','TF','Inp_d_p_dr_out_d_p')
AllOutputFile = os.path.join('outputData','TF','Inp_d_p_dr_out_d_p_dr')

print(f"\n Remove certain codes from output for different data formats")
# S3: all codes kept in the output; S1: diagnosis only; S2: diagnosis + procedure.
AllUpdPair,AllOutTypes= resetIntegerOutput(updateOutput(newPairs.copy(),codeType,diagnosis=0,procedure=0,drugs =0,all =1))
diagnosisUpdPair,diagnosisOutTypes= resetIntegerOutput(updateOutput(newPairs.copy(),codeType,diagnosis=0,procedure=1,drugs =1,all =0))
diagnosisProcedureUpdPair,diagnosisProcedureOutTypes= resetIntegerOutput(updateOutput(newPairs.copy(),codeType,diagnosis=0,procedure=0,drugs =1,all =0))

print(f"\n total # S1 records : {len(diagnosisUpdPair)}\n total # S2 records :{len(diagnosisProcedureUpdPair)}\n total # S3 records :{len(AllUpdPair)}")
print(f"\n total Dx codes:{len(diagnosisOutTypes)} \n  total Dx,Px codes:{len(diagnosisProcedureOutTypes)} \n total Dx,Px,Rx codes:{len(AllOutTypes)}")
print("\n Storing all the information related to Trajectory Forecasting...")


storeFiles(diagnosisUpdPair,diagnosisOutTypes,codeType,types,reverseTypes,diagnosisOutputFile)
storeFiles(diagnosisProcedureUpdPair,diagnosisProcedureOutTypes,codeType,types,reverseTypes,diagnosisProcedureOutputFile)
storeFiles(AllUpdPair,AllOutTypes,codeType,types,reverseTypes,AllOutputFile)
print("\n Completed...")

# ---- Sequential disease prediction (SDP) ----
print("\nPreparing data for Sequential disease prediction....")
newPairs = formatData(seqs,dataFormat = 'SDP',mn =seqLength)
diagnosisOutputFile = os.path.join('outputData','SDP','Inp_d_p_dr_out_d')

print(f"\n\n Remove certain codes from output for different data formats")
diagnosisUpdPair,diagnosisOutTypes= resetIntegerOutput(updateOutput(newPairs.copy(),codeType,diagnosis=0,procedure=1,drugs =1,all =0))

print(f"\n total # records: {len(diagnosisUpdPair)} \n total # of codes: {len(diagnosisOutTypes)}")

print("\n Storing all the information related to TSequential disease prediction...")
storeFiles(diagnosisUpdPair,diagnosisOutTypes,codeType,types,reverseTypes,diagnosisOutputFile)
print("\n Completed...")
print("\n All the preprocessing step has been completed, Now use the data in the outputData folder to build the model...")

Loading the data...
Building subject_id-admission mapping, admission-date mapping
Building admission-diagnosis mapping
-Number of null ICD9 codes in file mimic-iv-2.2/hosp/diagnoses_icd.csv.gz: 0
-Number of null ICD10 codes in file mimic-iv-2.2/hosp/diagnoses_icd.csv.gz: 0
Building admission-procedure mapping
-Number of null ICD9 codes in file mimic-iv-2.2/hosp/procedures_icd.csv.gz: 0
-Number of null ICD10 codes in file mimic-iv-2.2/hosp/procedures_icd.csv.gz: 0
Building admission-drug mapping

 Completed...

 Cleaning data...
Cleaning data...
Removing patient records who does not have all three medical codes for an admission
Removing patient who made less than 2 admissions
 Total Number of patients 19834
 Total Number of admissions 52957
 Average number of admissions per patient 2.6700110920641325
 Total Number of diagnosis code 14119
 Total Number of procedure code 7511
 Total Number of drug code 4891
 Total Number of codes 26521
 average Number of procedure code per visit 3.0819910