# Data Preprocessing for MIMIC III

In [2]:
import torch
import numpy as np
from helpers.loader import DataBaseLoader

loader = DataBaseLoader(user="mt361", password="tian01050417", dbname="mimic", schema="mimiciii")
SQL_SCRIPT = """
with first_table as (
SELECT
    ce.subject_id
    , ce.icustay_id
    , extract(epoch from (ce.charttime - FIRST_VALUE(ce.charttime) OVER(PARTITION BY ce.subject_id, ce.icustay_id ORDER BY ce.charttime)))/3600 AS time
	, ce.charttime
    , AVG(CASE WHEN itemid IN (220045)
            AND valuenum > 0
            AND valuenum < 300
            THEN valuenum END
    ) AS heart_rate
    , AVG(CASE WHEN itemid IN (220179, 220050, 225309)
            AND valuenum > 0
            AND valuenum < 400
            THEN valuenum END
    ) AS sbp
    , AVG(CASE WHEN itemid IN (220180, 220051, 225310)
                AND valuenum > 0
                AND valuenum < 300
                THEN valuenum END
    ) AS dbp
    , AVG(CASE WHEN itemid IN (220052, 220181, 225312)
                AND valuenum > 0
                AND valuenum < 300
                THEN valuenum END
    ) AS mbp
    , AVG(CASE WHEN itemid = 220179
                AND valuenum > 0
                AND valuenum < 400
                THEN valuenum END
    ) AS sbp_ni
    , AVG(CASE WHEN itemid = 220180
                AND valuenum > 0
                AND valuenum < 300
                THEN valuenum END
    ) AS dbp_ni
    , AVG(CASE WHEN itemid = 220181
                AND valuenum > 0
                AND valuenum < 300
                THEN valuenum END
    ) AS mbp_ni
    , AVG(CASE WHEN itemid IN (220210, 224690)
                AND valuenum > 0
                AND valuenum < 70
                THEN valuenum END
    ) AS resp_rate
    , ROUND(CAST(
            AVG(CASE
                -- converted to degC in valuenum call
                WHEN itemid IN (223761)
                    AND valuenum > 70
                    AND valuenum < 120
                    THEN (valuenum - 32) / 1.8
                -- already in degC, no conversion necessary
                WHEN itemid IN (223762)
                    AND valuenum > 10
                    AND valuenum < 50
                    THEN valuenum END)
            AS NUMERIC), 2) AS temperature
    , MAX(CASE WHEN itemid = 224642 THEN value END
    ) AS temperature_site
    , AVG(CASE WHEN itemid IN (220277)
                AND valuenum > 0
                AND valuenum <= 100
                THEN valuenum END
    ) AS spo2
    , AVG(CASE WHEN itemid IN (225664, 220621, 226537)
                AND valuenum > 0
                THEN valuenum END
    ) AS glucose
FROM chartevents ce
WHERE ce.icustay_id IS NOT NULL
    AND ce.itemid IN
    (
        220045 -- Heart Rate
        , 225309 -- ART BP Systolic
        , 225310 -- ART BP Diastolic
        , 225312 -- ART BP Mean
        , 220050 -- Arterial Blood Pressure systolic
        , 220051 -- Arterial Blood Pressure diastolic
        , 220052 -- Arterial Blood Pressure mean
        , 220179 -- Non Invasive Blood Pressure systolic
        , 220180 -- Non Invasive Blood Pressure diastolic
        , 220181 -- Non Invasive Blood Pressure mean
        , 220210 -- Respiratory Rate
        , 224690 -- Respiratory Rate (Total)
        , 220277 -- SPO2, peripheral
        -- GLUCOSE, both lab and fingerstick
        , 225664 -- Glucose finger stick
        , 220621 -- Glucose (serum)
        , 226537 -- Glucose (whole blood)
        -- TEMPERATURE
        -- 226329 -- Blood Temperature CCO (C)
        , 223762 -- "Temperature Celsius"
        , 223761  -- "Temperature Fahrenheit"
        , 224642 -- Temperature Site
    )
GROUP BY ce.subject_id, ce.icustay_id, ce.charttime
),
time_diffs AS (
    SELECT 
        *,
        ROW_NUMBER() OVER(PARTITION BY subject_id, icustay_id, ROUND(time) ORDER BY ABS(time - ROUND(time))) AS rn		-- find the measurement to the nearest hour
    FROM first_table
),
hours AS (
    SELECT generate_series(0,24) AS hr
),
all_hours AS (
    SELECT distinct icustay_id, hr
    FROM time_diffs, hours
)
SELECT 
    a.icustay_id,
    a.hr as time,
    t.heart_rate,
    t.sbp,
    t.dbp,
    t.mbp,
    t.resp_rate as resprate,
    t.temperature as temp,
    t.spo2
FROM all_hours a
LEFT JOIN time_diffs t ON a.icustay_id = t.icustay_id AND a.hr = ROUND(t.time)
WHERE t.rn = 1 OR t.rn IS NULL
"""

In [3]:
vitals_table = loader.query(SQL_SCRIPT)
vitals_table.head()

Unnamed: 0,icustay_id,time,heart_rate,sbp,dbp,mbp,resprate,temp,spo2
0,200001,0,115.0,,,,,,
1,200001,1,113.0,110.0,65.0,76.0,20.0,,97.0
2,200001,2,108.0,113.0,68.0,79.0,18.0,,98.0
3,200001,3,110.0,116.0,68.0,79.0,27.0,,98.0
4,200001,4,102.0,102.0,61.0,71.0,21.0,37.67,96.0


In [4]:
adm = loader["admissions"]
adm.head()

Unnamed: 0,row_id,subject_id,hadm_id,admittime,dischtime,deathtime,admission_type,admission_location,discharge_location,insurance,language,religion,marital_status,ethnicity,edregtime,edouttime,diagnosis,hospital_expire_flag,has_chartevents_data
0,21,22,165315,2196-04-09 12:26:00,2196-04-10 15:54:00,NaT,EMERGENCY,EMERGENCY ROOM ADMIT,DISC-TRAN CANCER/CHLDRN H,Private,,UNOBTAINABLE,MARRIED,WHITE,2196-04-09 10:06:00,2196-04-09 13:24:00,BENZODIAZEPINE OVERDOSE,0,1
1,22,23,152223,2153-09-03 07:15:00,2153-09-08 19:10:00,NaT,ELECTIVE,PHYS REFERRAL/NORMAL DELI,HOME HEALTH CARE,Medicare,,CATHOLIC,MARRIED,WHITE,NaT,NaT,CORONARY ARTERY DISEASE\CORONARY ARTERY BYPASS...,0,1
2,23,23,124321,2157-10-18 19:34:00,2157-10-25 14:00:00,NaT,EMERGENCY,TRANSFER FROM HOSP/EXTRAM,HOME HEALTH CARE,Medicare,ENGL,CATHOLIC,MARRIED,WHITE,NaT,NaT,BRAIN MASS,0,1
3,24,24,161859,2139-06-06 16:14:00,2139-06-09 12:48:00,NaT,EMERGENCY,TRANSFER FROM HOSP/EXTRAM,HOME,Private,,PROTESTANT QUAKER,SINGLE,WHITE,NaT,NaT,INTERIOR MYOCARDIAL INFARCTION,0,1
4,25,25,129635,2160-11-02 02:06:00,2160-11-05 14:55:00,NaT,EMERGENCY,EMERGENCY ROOM ADMIT,HOME,Private,,UNOBTAINABLE,MARRIED,WHITE,2160-11-02 01:01:00,2160-11-02 04:27:00,ACUTE CORONARY SYNDROME,0,1


In [5]:
unique_stay_id = vitals_table["icustay_id"].unique()
len(unique_stay_id)

26150

In [6]:
for id in unique_stay_id:
    if len(vitals_table[vitals_table["icustay_id"] == id]) != 25:
        print(f"actual length is {len(vitals_table[vitals_table['icustay_id'] == id])}")

In [8]:
from tqdm import tqdm
def turn_to_ndarray(values_df, ids, adm_df, icu_df):
    
    def add_mortality_flag(tmp_df, adm_df, id):
        tmp_df = tmp_df.copy()
        subject_id = icu_df[icu_df["icustay_id"] == id]["subject_id"].values[0]         # given icustay_id, find subject_id
        flag = adm_df[adm_df["subject_id"] == subject_id]["hospital_expire_flag"].values[0]
        flag_vec = np.ones((tmp_df.shape[0], 1)) * flag
        tmp_df["mortality"] = flag_vec
        return tmp_df
    
    def add_cumulative_nan_flag(tmp_array):
        result = []
        for idx in range(tmp_array.shape[1]):
            if idx <= tmp_array.shape[1] - 2:
                result.append(tmp_array[:, idx])
                result.append(np.cumsum(~np.isnan(tmp_array[:, idx]), axis=0))
            else:
                result.append(tmp_array[:, idx])
        return np.asarray(result)
        
    result = []
    for id in tqdm(ids, desc="Extracting..."):
        tmp_df = values_df[values_df["icustay_id"] == id]
        tmp_df = add_mortality_flag(tmp_df, adm_df, id)
        tmp_array = add_cumulative_nan_flag(tmp_df.drop(["icustay_id", "time"], axis=1).values)
        result.append(tmp_array)
    
    return np.asarray(result)

icu_df = loader["icustays"]
vitals_ndarray = turn_to_ndarray(vitals_table, unique_stay_id, adm, icu_df)
vitals_ndarray.shape

Extracting...: 100%|██████████| 26150/26150 [00:48<00:00, 536.99it/s]


(26150, 15, 25)

train test split

In [9]:
def train_test_split(array, train_percent=0.8):
    indices = np.arange(len(array))
    np.random.shuffle(indices)
    train, test = array[indices[:int(len(array) * train_percent)]], array[indices[int(len(array) * train_percent):]]
    return train, test

def check_class_imbalance(array):
    return np.sum(array[:,-1,:]) / (array.shape[0] * array.shape[1])

train, test = train_test_split(vitals_ndarray)
print(f"train shape: {train.shape}, test shape: {test.shape}")

train shape: (20920, 15, 25), test shape: (5230, 15, 25)


In [10]:
check_class_imbalance(train), check_class_imbalance(test)

(0.13304652644996814, 0.12938177182919056)

save data

In [11]:
import torch
np.save("data_pack/mimic-iii/train.npy", train)
train = torch.tensor(train)
torch.save(train, "data_pack/mimic-iii/train.pt")
np.save("data_pack/mimic-iii/test.npy", test)
test = torch.tensor(test)
torch.save(test, "data_pack/mimic-iii/test.pt")