In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, datediff, to_date, regexp_replace, when
from pyspark.sql.types import IntegerType

spark = SparkSession.builder \
    .appName("mimic-iii") \
    .getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/01/27 19:31:25 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
hdfs_patients_path = "hdfs://namenode:9000/mimic/PATIENTS.csv"
hdfs_admissions_path = "hdfs://namenode:9000/mimic/ADMISSIONS.csv"
hdfs_diagnoses_path = "hdfs://namenode:9000/mimic/DIAGNOSES_ICD.csv"

df_patients = spark.read.csv(hdfs_patients_path, header=True, inferSchema=True,multiLine=True)
df_admissions = spark.read.csv(hdfs_admissions_path, header=True, inferSchema=True,multiLine=True)
df_diagnoses_icd = spark.read.csv(hdfs_diagnoses_path, header=True, inferSchema=True,multiLine=True)


df_admissions =  df_admissions.select("ROW_ID",
                                      "SUBJECT_ID",
                                      "HADM_ID",
                                      "ADMITTIME",
                                      "DISCHTIME",
                                      "DEATHTIME",
                                      "ADMISSION_TYPE",
                                      "ADMISSION_LOCATION")
# df_pa

                                                                                

In [3]:
df_patients.show()

+------+----------+------+---------------+---------------+---------------+--------------+-----------+
|ROW_ID|SUBJECT_ID|GENDER|            DOB|            DOD|       DOD_HOSP|       DOD_SSN|EXPIRE_FLAG|
+------+----------+------+---------------+---------------+---------------+--------------+-----------+
|   234|       249|     F| 3/13/2075 0:00|           null|           null|          null|          0|
|   235|       250|     F|12/27/2164 0:00|11/22/2188 0:00|11/22/2188 0:00|          null|          1|
|   236|       251|     M| 3/15/2090 0:00|           null|           null|          null|          0|
|   237|       252|     M|  3/6/2078 0:00|           null|           null|          null|          0|
|   238|       253|     F|11/26/2089 0:00|           null|           null|          null|          0|
|   239|       255|     M|  8/5/2109 0:00|           null|           null|          null|          0|
|   240|       256|     M| 7/31/2086 0:00|           null|           null|        

In [4]:
df_admissions.show()

+------+----------+-------+----------------+----------------+---------------+--------------+--------------------+
|ROW_ID|SUBJECT_ID|HADM_ID|       ADMITTIME|       DISCHTIME|      DEATHTIME|ADMISSION_TYPE|  ADMISSION_LOCATION|
+------+----------+-------+----------------+----------------+---------------+--------------+--------------------+
|    21|        22| 165315|  4/9/2196 12:26| 4/10/2196 15:54|           null|     EMERGENCY|EMERGENCY ROOM ADMIT|
|    22|        23| 152223|   9/3/2153 7:15|  9/8/2153 19:10|           null|      ELECTIVE|PHYS REFERRAL/NOR...|
|    23|        23| 124321|10/18/2157 19:34|10/25/2157 14:00|           null|     EMERGENCY|TRANSFER FROM HOS...|
|    24|        24| 161859|  6/6/2139 16:14|  6/9/2139 12:48|           null|     EMERGENCY|TRANSFER FROM HOS...|
|    25|        25| 129635|  11/2/2160 2:06| 11/5/2160 14:55|           null|     EMERGENCY|EMERGENCY ROOM ADMIT|
|    26|        26| 197661|  5/6/2126 15:16| 5/13/2126 15:00|           null|     EMERGE

In [5]:
df_diagnoses_icd.show()

+------+----------+-------+-------+---------+
|ROW_ID|SUBJECT_ID|HADM_ID|SEQ_NUM|ICD9_CODE|
+------+----------+-------+-------+---------+
|  1297|       109| 172335|      1|    40301|
|  1298|       109| 172335|      2|      486|
|  1299|       109| 172335|      3|    58281|
|  1300|       109| 172335|      4|     5855|
|  1301|       109| 172335|      5|     4254|
|  1302|       109| 172335|      6|     2762|
|  1303|       109| 172335|      7|     7100|
|  1304|       109| 172335|      8|     2767|
|  1305|       109| 172335|      9|     7243|
|  1306|       109| 172335|     10|    45829|
|  1307|       109| 172335|     11|     2875|
|  1308|       109| 172335|     12|    28521|
|  1309|       109| 172335|     13|    28529|
|  1310|       109| 172335|     14|    27541|
|  1311|       109| 173633|      1|    40301|
|  1312|       109| 173633|      2|     5856|
|  1313|       109| 173633|      3|    58381|
|  1314|       109| 173633|      4|     7100|
|  1315|       109| 173633|      5

In [6]:
unique_admission_types = df_admissions.select("admission_type").distinct()
unique_admission_types.show()

[Stage 9:>                                                          (0 + 1) / 1]

+--------------+
|admission_type|
+--------------+
|       NEWBORN|
|      ELECTIVE|
|     EMERGENCY|
|        URGENT|
+--------------+



                                                                                

In [7]:
df_emergency_admissions = df_admissions.filter(col("admission_type") == "EMERGENCY")
df_emergency_admissions.show()

+------+----------+-------+----------------+----------------+---------------+--------------+--------------------+
|ROW_ID|SUBJECT_ID|HADM_ID|       ADMITTIME|       DISCHTIME|      DEATHTIME|ADMISSION_TYPE|  ADMISSION_LOCATION|
+------+----------+-------+----------------+----------------+---------------+--------------+--------------------+
|    21|        22| 165315|  4/9/2196 12:26| 4/10/2196 15:54|           null|     EMERGENCY|EMERGENCY ROOM ADMIT|
|    23|        23| 124321|10/18/2157 19:34|10/25/2157 14:00|           null|     EMERGENCY|TRANSFER FROM HOS...|
|    24|        24| 161859|  6/6/2139 16:14|  6/9/2139 12:48|           null|     EMERGENCY|TRANSFER FROM HOS...|
|    25|        25| 129635|  11/2/2160 2:06| 11/5/2160 14:55|           null|     EMERGENCY|EMERGENCY ROOM ADMIT|
|    26|        26| 197661|  5/6/2126 15:16| 5/13/2126 15:00|           null|     EMERGENCY|TRANSFER FROM HOS...|
|    30|        31| 128652| 8/22/2108 23:27| 8/30/2108 15:00|8/30/2108 15:00|     EMERGE

In [8]:
df_emergency_admissions = df_admissions.filter(col("admission_type") == "NEWBORN")
df_emergency_admissions.show()

+------+----------+-------+----------------+----------------+---------+--------------+--------------------+
|ROW_ID|SUBJECT_ID|HADM_ID|       ADMITTIME|       DISCHTIME|DEATHTIME|ADMISSION_TYPE|  ADMISSION_LOCATION|
+------+----------+-------+----------------+----------------+---------+--------------+--------------------+
|    27|        27| 134931|11/30/2191 22:16| 12/3/2191 14:45|     null|       NEWBORN|PHYS REFERRAL/NOR...|
|    41|        39| 106266|11/29/2114 21:04| 12/9/2114 15:10|     null|       NEWBORN|PHYS REFERRAL/NOR...|
|   461|       358| 110872|10/24/2168 23:48| 10/29/2168 3:23|     null|       NEWBORN|PHYS REFERRAL/NOR...|
|   468|       363| 196503|  3/1/2176 15:26|  3/3/2176 14:04|     null|       NEWBORN|CLINIC REFERRAL/P...|
|    49|        50| 132761| 6/23/2112 19:40| 6/26/2112 10:15|     null|       NEWBORN|PHYS REFERRAL/NOR...|
|    50|        51| 196010|11/30/2128 10:28| 12/2/2128 12:35|     null|       NEWBORN|PHYS REFERRAL/NOR...|
|    53|        54| 138795| 

In [9]:
df_emergency_admissions = df_admissions.filter(col("admission_type") == "URGENT")
df_emergency_admissions.show()

+------+----------+-------+----------------+----------------+---------------+--------------+--------------------+
|ROW_ID|SUBJECT_ID|HADM_ID|       ADMITTIME|       DISCHTIME|      DEATHTIME|ADMISSION_TYPE|  ADMISSION_LOCATION|
+------+----------+-------+----------------+----------------+---------------+--------------+--------------------+
|    29|        30| 104557|10/14/2172 14:17|10/19/2172 14:37|           null|        URGENT|TRANSFER FROM HOS...|
|    67|        67| 186474| 2/25/2155 12:45|  3/6/2155 15:00|           null|        URGENT|PHYS REFERRAL/NOR...|
|    84|        83| 158569|  4/1/2142 12:34|  4/8/2142 14:46|           null|        URGENT|TRANSFER FROM HOS...|
|   196|       146| 190707|12/19/2119 12:15| 1/10/2120 13:08|           null|        URGENT|TRANSFER FROM HOS...|
|   205|       154| 162891|  4/5/2118 18:11| 4/11/2118 14:21|           null|        URGENT|PHYS REFERRAL/NOR...|
|   447|       353| 159730| 6/15/2148 11:04|  7/4/2148 17:48|           null|        URG

In [10]:
df_emergency_admissions = df_admissions.filter(col("admission_type") == "ELECTIVE")
df_emergency_admissions.show()

+------+----------+-------+----------------+----------------+---------+--------------+--------------------+
|ROW_ID|SUBJECT_ID|HADM_ID|       ADMITTIME|       DISCHTIME|DEATHTIME|ADMISSION_TYPE|  ADMISSION_LOCATION|
+------+----------+-------+----------------+----------------+---------+--------------+--------------------+
|    22|        23| 152223|   9/3/2153 7:15|  9/8/2153 19:10|     null|      ELECTIVE|PHYS REFERRAL/NOR...|
|    28|        28| 162569|   9/1/2177 7:15|  9/6/2177 16:00|     null|      ELECTIVE|PHYS REFERRAL/NOR...|
|    31|        32| 175413|   4/4/2170 8:00| 4/23/2170 12:45|     null|      ELECTIVE|PHYS REFERRAL/NOR...|
|    35|        35| 166707| 2/10/2122 11:15| 2/20/2122 15:30|     null|      ELECTIVE|PHYS REFERRAL/NOR...|
|    38|        36| 165660| 5/10/2134 11:30| 5/20/2134 13:16|     null|      ELECTIVE|PHYS REFERRAL/NOR...|
|    42|        41| 101757|12/31/2132 10:30| 1/27/2133 15:45|     null|      ELECTIVE|PHYS REFERRAL/NOR...|
|   475|       369| 145787| 

In [11]:
df_admissions.show()

+------+----------+-------+----------------+----------------+---------------+--------------+--------------------+
|ROW_ID|SUBJECT_ID|HADM_ID|       ADMITTIME|       DISCHTIME|      DEATHTIME|ADMISSION_TYPE|  ADMISSION_LOCATION|
+------+----------+-------+----------------+----------------+---------------+--------------+--------------------+
|    21|        22| 165315|  4/9/2196 12:26| 4/10/2196 15:54|           null|     EMERGENCY|EMERGENCY ROOM ADMIT|
|    22|        23| 152223|   9/3/2153 7:15|  9/8/2153 19:10|           null|      ELECTIVE|PHYS REFERRAL/NOR...|
|    23|        23| 124321|10/18/2157 19:34|10/25/2157 14:00|           null|     EMERGENCY|TRANSFER FROM HOS...|
|    24|        24| 161859|  6/6/2139 16:14|  6/9/2139 12:48|           null|     EMERGENCY|TRANSFER FROM HOS...|
|    25|        25| 129635|  11/2/2160 2:06| 11/5/2160 14:55|           null|     EMERGENCY|EMERGENCY ROOM ADMIT|
|    26|        26| 197661|  5/6/2126 15:16| 5/13/2126 15:00|           null|     EMERGE

In [12]:
from pyspark.sql.functions import to_timestamp, datediff, col, expr

df_admissions_new = (
    df_admissions
    .withColumnRenamed("DISCHTIME", "admission_DISCHTIME")
    .withColumnRenamed("ADMITTIME", "admission_ADMITTIME")
    .withColumnRenamed("SUBJECT_ID", "admission_SUBJECT_ID")
    .withColumn(
        "admission_DISCHTIME",
        to_timestamp(col("admission_DISCHTIME"), "M/d/yyyy H:mm")
    )
    .withColumn(
        "admission_ADMITTIME",
        to_timestamp(col("admission_ADMITTIME"), "M/d/yyyy H:mm")
    )
)

df_joined_for_los = df_admissions_new.join(df_diagnoses_icd, on="HADM_ID", how="inner")

df_joined_for_los = (
    df_joined_for_los
    .withColumn(
        "length_of_stay",
        datediff(col("admission_DISCHTIME"), col("admission_ADMITTIME"))
    )
    .filter(col("length_of_stay").isNotNull() & (col("length_of_stay") >= 0))
)

df_joined_for_los.show()


                                                                                

+-------+------+--------------------+-------------------+-------------------+---------+--------------+--------------------+------+----------+-------+---------+--------------+
|HADM_ID|ROW_ID|admission_SUBJECT_ID|admission_ADMITTIME|admission_DISCHTIME|DEATHTIME|ADMISSION_TYPE|  ADMISSION_LOCATION|ROW_ID|SUBJECT_ID|SEQ_NUM|ICD9_CODE|length_of_stay|
+-------+------+--------------------+-------------------+-------------------+---------+--------------+--------------------+------+----------+-------+---------+--------------+
| 172335|   128|                 109|2141-09-18 10:32:00|2141-09-24 13:53:00|     null|     EMERGENCY|EMERGENCY ROOM ADMIT|  1297|       109|      1|    40301|             6|
| 172335|   128|                 109|2141-09-18 10:32:00|2141-09-24 13:53:00|     null|     EMERGENCY|EMERGENCY ROOM ADMIT|  1298|       109|      2|      486|             6|
| 172335|   128|                 109|2141-09-18 10:32:00|2141-09-24 13:53:00|     null|     EMERGENCY|EMERGENCY ROOM ADMIT|  

In [13]:
from pyspark.sql.functions import expr

df_length_of_stay = df_joined_for_los.select(expr("percentile_approx(length_of_stay, 0.5) AS median_los"))
df_length_of_stay.show()


[Stage 20:>                                                         (0 + 1) / 1]

+----------+
|median_los|
+----------+
|         8|
+----------+



                                                                                

In [14]:
df_admissions_new = (
    df_admissions
    .withColumnRenamed("DISCHTIME", "admission_DISCHTIME")
    .withColumnRenamed("ADMITTIME", "admission_ADMITTIME")
    .withColumn(
        "admission_DISCHTIME",
        to_timestamp(col("admission_DISCHTIME"), "M/d/yyyy H:mm")
    )
    .withColumn(
        "admission_ADMITTIME",
        to_timestamp(col("admission_ADMITTIME"), "M/d/yyyy H:mm")
    )
)

df_joined_patients = (
    df_admissions_new.alias("adm")
    .join(
        df_patients.alias("pat"),
        on="SUBJECT_ID",
        how="inner"
    )
)
df_joined_patients.show()

+----------+------+-------+-------------------+-------------------+---------------+--------------+--------------------+------+------+-------------------+--------------+--------------+--------------+-----------+
|SUBJECT_ID|ROW_ID|HADM_ID|admission_ADMITTIME|admission_DISCHTIME|      DEATHTIME|ADMISSION_TYPE|  ADMISSION_LOCATION|ROW_ID|GENDER|                DOB|           DOD|      DOD_HOSP|       DOD_SSN|EXPIRE_FLAG|
+----------+------+-------+-------------------+-------------------+---------------+--------------+--------------------+------+------+-------------------+--------------+--------------+--------------+-----------+
|        22|    21| 165315|2196-04-09 12:26:00|2196-04-10 15:54:00|           null|     EMERGENCY|EMERGENCY ROOM ADMIT|    19|     F|      5/7/2131 0:00|          null|          null|          null|          0|
|        23|    22| 152223|2153-09-03 07:15:00|2153-09-08 19:10:00|           null|      ELECTIVE|PHYS REFERRAL/NOR...|    20|     M|     7/17/2082 0:00|   

In [15]:
from pyspark.sql.functions import datediff, floor

df_joined_patients_los = df_joined_patients.withColumn(
    "length_of_stay",
    datediff(col("admission_DISCHTIME"), col("admission_ADMITTIME"))
)

df_joined_patients_dob = (
    df_joined_patients_los
    .withColumn(
        "DOB_ts",
        to_timestamp(col("DOB"), "M/d/yyyy H:mm")
    )
)
# df_joined_patients_dob.select("DOB_ts").show()
df_joined_approx_age = df_joined_patients_dob.withColumn(
    "approx_age",
    (col("admission_ADMITTIME").cast("long") - col("DOB_ts").cast("long")) / (60*60*24*365.25)
)

df_joined_approx_age.select("admission_ADMITTIME", "pat.DOB", "approx_age").show(10, False)
print(df_joined_approx_age.columns)

+-------------------+-------------------+--------------------+
|admission_ADMITTIME|DOB                |approx_age          |
+-------------------+-------------------+--------------------+
|2196-04-09 12:26:00|5/7/2131 0:00      |64.92681192486121   |
|2153-09-03 07:15:00|7/17/2082 0:00     |71.13019050878394   |
|2157-10-18 19:34:00|7/17/2082 0:00     |75.25479884401855   |
|2139-06-06 16:14:00|5/31/2100 0:00     |39.016225568484295  |
|2160-11-02 02:06:00|11/21/2101 0:00    |58.94890485968515   |
|2126-05-06 15:16:00|5/4/2054 0:00      |72.00447942809339   |
|2191-11-30 22:16:00|11/30/2191 0:00    |0.002540117119172561|
|2177-09-01 07:15:00|4/15/2103 0:00     |74.38275724389688   |
|2172-10-14 14:17:00|1872-10-14 00:00:00|null                |
|2108-08-22 23:27:00|5/17/2036 0:00     |72.26550878393795   |
+-------------------+-------------------+--------------------+
only showing top 10 rows

['SUBJECT_ID', 'ROW_ID', 'HADM_ID', 'admission_ADMITTIME', 'admission_DISCHTIME', 'DEATHTIME

# Building ML Model for readmission prediction

In [16]:
from pyspark.sql.window import Window
from pyspark.sql.functions import lead, unix_timestamp

w = Window.partitionBy("SUBJECT_ID").orderBy("admission_ADMITTIME")

df_joined = df_joined_approx_age.withColumn(
    "next_admission_time",
    lead("admission_ADMITTIME").over(w)
)

df_joined = df_joined.withColumn("admission_DISCHTIME_unix", unix_timestamp("admission_DISCHTIME"))
df_joined = df_joined.withColumn("next_admission_time_unix", unix_timestamp("next_admission_time"))

# Time difference in days between discharge and next admission
df_joined = df_joined.withColumn(
    "days_to_next_admit",
    (col("next_admission_time_unix") - col("admission_DISCHTIME_unix")) / (60*60*24)
)

df_joined = df_joined.withColumn(
    "readmission_label",
    when((col("days_to_next_admit") <= 30) & (col("days_to_next_admit") >= 0), 1).otherwise(0)
)


In [17]:
df_joined.show()

[Stage 28:>                                                         (0 + 1) / 1]

+----------+------+-------+-------------------+-------------------+---------------+--------------+--------------------+------+------+-------------------+---------------+---------------+--------------+-----------+--------------+-------------------+--------------------+-------------------+------------------------+------------------------+------------------+-----------------+
|SUBJECT_ID|ROW_ID|HADM_ID|admission_ADMITTIME|admission_DISCHTIME|      DEATHTIME|ADMISSION_TYPE|  ADMISSION_LOCATION|ROW_ID|GENDER|                DOB|            DOD|       DOD_HOSP|       DOD_SSN|EXPIRE_FLAG|length_of_stay|             DOB_ts|          approx_age|next_admission_time|admission_DISCHTIME_unix|next_admission_time_unix|days_to_next_admit|readmission_label|
+----------+------+-------+-------------------+-------------------+---------------+--------------+--------------------+------+------+-------------------+---------------+---------------+--------------+-----------+--------------+-------------------+-

                                                                                

In [18]:
from pyspark.sql.functions import when
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.ml import Pipeline

gender_indexer = StringIndexer(inputCol="GENDER", outputCol="gender_index")

admission_type_indexer = StringIndexer(inputCol="ADMISSION_TYPE", outputCol="admission_type_index")

assembler = VectorAssembler(
    inputCols=["approx_age", "gender_index", "admission_type_index"],
    outputCol="features"
)


In [19]:
from pyspark.ml.classification import LogisticRegression

print(df_joined.columns)
df_readmission = df_joined.filter(col("approx_age").isNotNull())
df_readmission = df_readmission.select("readmission_label", "GENDER", "ADMISSION_TYPE", "approx_age")
train_df, test_df = df_readmission.randomSplit([0.8, 0.2], seed=42)
train_df.show()

['SUBJECT_ID', 'ROW_ID', 'HADM_ID', 'admission_ADMITTIME', 'admission_DISCHTIME', 'DEATHTIME', 'ADMISSION_TYPE', 'ADMISSION_LOCATION', 'ROW_ID', 'GENDER', 'DOB', 'DOD', 'DOD_HOSP', 'DOD_SSN', 'EXPIRE_FLAG', 'length_of_stay', 'DOB_ts', 'approx_age', 'next_admission_time', 'admission_DISCHTIME_unix', 'next_admission_time_unix', 'days_to_next_admit', 'readmission_label']


[Stage 34:>                                                         (0 + 1) / 1]

+-----------------+------+--------------+--------------------+
|readmission_label|GENDER|ADMISSION_TYPE|          approx_age|
+-----------------+------+--------------+--------------------+
|                0|     F|      ELECTIVE|3.441326336603544E-4|
|                0|     F|      ELECTIVE|5.456688721575785E-4|
|                0|     F|      ELECTIVE|  17.421971252566735|
|                0|     F|      ELECTIVE|   18.93583162217659|
|                0|     F|      ELECTIVE|  19.302760666210357|
|                0|     F|      ELECTIVE|  20.206536618754278|
|                0|     F|      ELECTIVE|   20.70993041295916|
|                0|     F|      ELECTIVE|   20.78857517681953|
|                0|     F|      ELECTIVE|  21.151684538748192|
|                0|     F|      ELECTIVE|   21.31234314396532|
|                0|     F|      ELECTIVE|   21.88227241615332|
|                0|     F|      ELECTIVE|   22.15280629705681|
|                0|     F|      ELECTIVE|  23.015200775

                                                                                

In [20]:
lr = LogisticRegression(
    featuresCol="features",
    labelCol="readmission_label",
    maxIter=10
)

pipeline_readmission = Pipeline(stages=[gender_indexer, admission_type_indexer, assembler, lr])

readmission_model = pipeline_readmission.fit(train_df)

predictions = readmission_model.transform(test_df)
predictions.select("readmission_label", "prediction", "probability").show(10, truncate=False)

25/01/27 19:33:37 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
[Stage 76:>                                                         (0 + 1) / 1]

+-----------------+----------+-----------------------------------------+
|readmission_label|prediction|probability                              |
+-----------------+----------+-----------------------------------------+
|0                |0.0       |[0.972831467083865,0.02716853291613497]  |
|0                |0.0       |[0.9695058173166078,0.030494182683392212]|
|0                |0.0       |[0.9693011292222456,0.03069887077775435] |
|0                |0.0       |[0.9691216713027423,0.030878328697257706]|
|0                |0.0       |[0.968764757627207,0.03123524237279296]  |
|0                |0.0       |[0.9686901526002006,0.03130984739979936] |
|0                |0.0       |[0.9684883168853065,0.031511683114693545]|
|0                |0.0       |[0.9682038211244751,0.03179617887552488] |
|0                |0.0       |[0.9679662601467691,0.03203373985323088] |
|0                |0.0       |[0.9679205159188007,0.032079484081199316]|
+-----------------+----------+---------------------

                                                                                

In [21]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

evaluator = BinaryClassificationEvaluator(
    labelCol="readmission_label",
    rawPredictionCol="rawPrediction",
    metricName="areaUnderROC"
)

auc = evaluator.evaluate(predictions)
print("Area under ROC:", auc)

                                                                                

Area under ROC: 0.5444723964090548
