In [1]:
from typing import List

import numpy as np
import pandas as pd
from pandas.tseries import offsets
from pandas.tseries.frequencies import to_offset


class TimeFeature:
    def __init__(self):
        pass

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        pass

    def __repr__(self):
        return self.__class__.__name__ + "()"


class SecondOfMinute(TimeFeature):
    """Minute of hour encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.second / 59.0 - 0.5


class MinuteOfHour(TimeFeature):
    """Minute of hour encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.minute / 59.0 - 0.5


class HourOfDay(TimeFeature):
    """Hour of day encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.hour / 23.0 - 0.5


class DayOfWeek(TimeFeature):
    """Hour of day encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.dayofweek / 6.0 - 0.5


class DayOfMonth(TimeFeature):
    """Day of month encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.day - 1) / 30.0 - 0.5


class DayOfYear(TimeFeature):
    """Day of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.dayofyear - 1) / 365.0 - 0.5


class MonthOfYear(TimeFeature):
    """Month of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.month - 1) / 11.0 - 0.5


class WeekOfYear(TimeFeature):
    """Week of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.isocalendar().week - 1) / 52.0 - 0.5


def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
    """
    Returns a list of time features that will be appropriate for the given frequency string.
    Parameters
    ----------
    freq_str
        Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc.
    """

    features_by_offsets = {
        offsets.YearEnd: [],
        offsets.QuarterEnd: [MonthOfYear],
        offsets.MonthEnd: [MonthOfYear],
        offsets.Week: [DayOfMonth, WeekOfYear],
        offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Minute: [
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
        offsets.Second: [
            SecondOfMinute,
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
    }

    offset = to_offset(freq_str)

    for offset_type, feature_classes in features_by_offsets.items():
        if isinstance(offset, offset_type):
            return [cls() for cls in feature_classes]

    supported_freq_msg = f"""
    Unsupported frequency {freq_str}
    The following frequencies are supported:
        Y   - yearly
            alias: A
        M   - monthly
        W   - weekly
        D   - daily
        B   - business days
        H   - hourly
        T   - minutely
            alias: min
        S   - secondly
    """
    raise RuntimeError(supported_freq_msg)


def time_features(dates, freq='h'):
    return np.vstack([feat(dates) for feat in time_features_from_frequency_str(freq)])

In [2]:
import os
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler


data_path = './processed_data/'
# data_path = os.path.join(data_path, 'format_mimic4_ehr.csv')


def process_subset(data_path, origin_filename='format_mimic4_ehr.csv', y_feature='Outcome', flag="train", save=True, shuffle_patients=False):
    df_raw = pd.read_csv(os.path.join(data_path, origin_filename))
    # x_features = ["Sex","Age"]+["Diastolic blood pressure", "Fraction inspired oxygen", "Glucose",
    # "Heart Rate", "Height", "Mean blood pressure", "Oxygen saturation",
    # "Respiratory rate", "Systolic blood pressure", "Temperature", "Weight", "pH"] + ["Capillary refill rate", "Glascow coma scale eye opening",
    # "Glascow coma scale motor response", "Glascow coma scale total", "Glascow coma scale verbal response"]
    # x_features = df_raw.columns[df_raw.columns.str.contains('|'.join(x_features))].tolist()

    basic_records = ['PatientID', 'RecordTime',
                     'AdmissionTime', 'DischargeTime']
    target_features = ['Outcome', 'LOS', 'Readmission']
    # Sex and ICUType are binary features, others are continuous features
    demographic_features = ['Sex', 'Age']
    labtest_features = ['Capillary refill rate->0.0', 'Capillary refill rate->1.0',
                        'Glascow coma scale eye opening->To Pain',
                        'Glascow coma scale eye opening->3 To speech',
                        'Glascow coma scale eye opening->1 No Response',
                        'Glascow coma scale eye opening->4 Spontaneously',
                        'Glascow coma scale eye opening->None',
                        'Glascow coma scale eye opening->To Speech',
                        'Glascow coma scale eye opening->Spontaneously',
                        'Glascow coma scale eye opening->2 To pain',
                        'Glascow coma scale motor response->1 No Response',
                        'Glascow coma scale motor response->3 Abnorm flexion',
                        'Glascow coma scale motor response->Abnormal extension',
                        'Glascow coma scale motor response->No response',
                        'Glascow coma scale motor response->4 Flex-withdraws',
                        'Glascow coma scale motor response->Localizes Pain',
                        'Glascow coma scale motor response->Flex-withdraws',
                        'Glascow coma scale motor response->Obeys Commands',
                        'Glascow coma scale motor response->Abnormal Flexion',
                        'Glascow coma scale motor response->6 Obeys Commands',
                        'Glascow coma scale motor response->5 Localizes Pain',
                        'Glascow coma scale motor response->2 Abnorm extensn',
                        'Glascow coma scale total->11', 'Glascow coma scale total->10',
                        'Glascow coma scale total->13', 'Glascow coma scale total->12',
                        'Glascow coma scale total->15', 'Glascow coma scale total->14',
                        'Glascow coma scale total->3', 'Glascow coma scale total->5',
                        'Glascow coma scale total->4', 'Glascow coma scale total->7',
                        'Glascow coma scale total->6', 'Glascow coma scale total->9',
                        'Glascow coma scale total->8',
                        'Glascow coma scale verbal response->1 No Response',
                        'Glascow coma scale verbal response->No Response',
                        'Glascow coma scale verbal response->Confused',
                        'Glascow coma scale verbal response->Inappropriate Words',
                        'Glascow coma scale verbal response->Oriented',
                        'Glascow coma scale verbal response->No Response-ETT',
                        'Glascow coma scale verbal response->5 Oriented',
                        'Glascow coma scale verbal response->Incomprehensible sounds',
                        'Glascow coma scale verbal response->1.0 ET/Trach',
                        'Glascow coma scale verbal response->4 Confused',
                        'Glascow coma scale verbal response->2 Incomp sounds',
                        'Glascow coma scale verbal response->3 Inapprop words',
                        'Diastolic blood pressure', 'Fraction inspired oxygen', 'Glucose',
                        'Heart Rate', 'Height', 'Mean blood pressure', 'Oxygen saturation',
                        'Respiratory rate', 'Systolic blood pressure', 'Temperature', 'Weight',
                        'pH']

    # set the features
    x_features = demographic_features + labtest_features
    normalize_features = ['Age'] + ['Diastolic blood pressure', 'Fraction inspired oxygen', 'Glucose',
                                    'Heart Rate', 'Height', 'Mean blood pressure', 'Oxygen saturation',
                                    'Respiratory rate', 'Systolic blood pressure', 'Temperature', 'Weight',
                                    'pH'] + ['LOS']
    if y_feature != 'LOS':
        normalize_features.remove('LOS')
    nonormalize_features = list(set(x_features) - set(normalize_features))
    y_feature = [y_feature]
    time_feature_origin = ['DischargeTime']
    id_feature = ['PatientID']
    time_feature = ['Time']
    df_raw = df_raw[x_features + y_feature + time_feature_origin + id_feature]

    # group by patient
    df_raw['Patient'] = df_raw['PatientID'].apply(lambda x: x.split('_')[0])
    patients = df_raw['Patient'].unique()
    patients_filename = os.path.join(data_path, 'patients.txt')
    has_patients_file = os.path.exists(patients_filename)
    if shuffle_patients or (not has_patients_file):
        import random
        random.shuffle(patients)
        with open(patients_filename, 'w') as f:
            for item in patients:
                f.write("%s\n" % item)
        print(f'shuffle patients and save to file')
    else:
        with open(patients_filename, 'r') as f:
            patients = f.read().splitlines()
        print(f'read patients from file, len(patients): {len(patients)}')

    # seperate train, val, test 7, 1, 2
    border1s = [0, int(len(patients) * 0.7), int(len(patients) * 0.8)]
    border2s = [int(len(patients) * 0.7),
                int(len(patients) * 0.8), len(patients)]
    flag_id = {'train': 0, 'valid': 1, 'test': 2}
    border1 = border1s[flag_id[flag]]
    border2 = border2s[flag_id[flag]]
    patients = patients[border1:border2]

    train_data = df_raw[df_raw['Patient'].isin(patients)]
    train_data = train_data[x_features]
    scaler = StandardScaler()
    train_data_need_normalize = train_data[normalize_features].values
    scaler.fit(train_data_need_normalize)

    # group by PatientID
    df_data = df_raw[df_raw['Patient'].isin(patients)]
    datas = df_data.groupby('PatientID')
    # add time feature
    xs = []
    times = []
    ys = []
    lens = []
    masks = []
    for i, data in datas:
        time_0 = data[time_feature_origin[0]].values[-1]
        # times len=data_len and freq=1h
        time_add = pd.Timestamp(
            time_0) + pd.to_timedelta(np.arange(48) * 1, 'h')
        time_add = time_features(
            time_add, freq='h').transpose(1, 0)  # seq_len, 4
        # scale and fillna
        x_need_normalize = data[normalize_features]
        x_no_normalize = data[nonormalize_features]
        x_need_normalize = x_need_normalize.fillna(
            method='ffill').fillna(method='bfill').fillna(0).values
        x_no_normalize = x_no_normalize.fillna(
            method='ffill').fillna(method='bfill').fillna(0).values
        x_need_normalize = scaler.transform(x_need_normalize)
        x = np.concatenate([x_need_normalize, x_no_normalize], axis=1)
        # get the first 48 hours
        len_x_origin = len(x)
        x = x[:48]
        if len(x) < 48:
            x = np.concatenate([x, np.zeros((48-len(x), len(x[0])))], axis=0)
            mask = np.concatenate([np.ones(len(x)), np.zeros(48-len(x))], axis=0)
        else:
            mask = np.ones(48)
        y = data[y_feature].values[0]
        xs.append(x)
        ys.append(np.array([y]))
        lens.append(min(len_x_origin, 48))
        masks.append(mask)
        times.append(time_add)

    if save:
        import pickle
        pickle.dump(xs, open(os.path.join(data_path, f'{flag}_xs.pkl'), 'wb'))
        pickle.dump(ys, open(os.path.join(data_path, f'{flag}_ys.pkl'), 'wb'))
        pickle.dump(times, open(os.path.join(
            data_path, f'{flag}_times.pkl'), 'wb'))
        pickle.dump(lens, open(os.path.join(
            data_path, f'{flag}_lens.pkl'), 'wb'))
        pickle.dump(scaler, open(os.path.join(
            data_path, f'{flag}_scaler.pkl'), 'wb'))
        pickle.dump(masks, open(os.path.join(
            data_path, f'{flag}_masks.pkl'), 'wb'))
        
        print(f'save {flag} data')
    return xs, ys, times, lens, masks


res_train = process_subset(data_path, flag='train', save=True,
                           y_feature='Outcome', origin_filename='format_mimic4_ehr.csv')
res_val = process_subset(data_path, flag='valid', save=True,
                         y_feature='Outcome', origin_filename='format_mimic4_ehr.csv')
res_test = process_subset(data_path, flag='test', save=True,
                          y_feature='Outcome', origin_filename='format_mimic4_ehr.csv')

read patients from file, len(patients): 17612
save train data
read patients from file, len(patients): 17612
save valid data
read patients from file, len(patients): 17612
save test data


In [3]:
xs_test, ys_test, times_test, lens_test, masks_test = res_test

In [4]:
len(masks_test[0])
ys_test

[array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[1.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[1.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[1.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([[0.]]),
 array([