In [1]:
import os
import sys

In [2]:
# Point this kernel at the local Spark / Java installation and make the PySpark
# sources bundled with Spark importable.
# Raw strings keep the Windows backslashes literal; sequences such as "\S" or "\P"
# are invalid escapes in a normal string literal and trigger a warning.
os.environ["PYSPARK_PYTHON"] = sys.executable  # must name a Python executable, not a directory
os.environ["JAVA_HOME"] = r"C:\Program Files\Java\jdk-18.0.2.1"
os.environ["SPARK_HOME"] = r"C:\Spark\spark-3.3.2-bin-hadoop2"
os.environ["PYLIB"] = os.environ["SPARK_HOME"] + "/python/lib"
# Inserted at the front of sys.path so these archives are found before any other copy.
sys.path.insert(0, os.environ["PYLIB"] + "/py4j-0.10.9.5-src.zip")
sys.path.insert(0, os.environ["PYLIB"] + "/pyspark.zip")

In [3]:
from pyspark.sql import SparkSession

# Create the Spark session for this notebook, or reuse one that is already running.
session_builder = SparkSession.builder.appName("r_eda")
spark = session_builder.getOrCreate()
spark

In [4]:
# Load the raw dataset. Only the header option is set (no schema inference),
# which is why the schema below reports every column as string.
raw_csv_path = "../data/r_hosp_dataset.csv"
data = spark.read.csv(raw_csv_path, header=True)
data.printSchema()

root
 |-- subject_id: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- dod: string (nullable = true)
 |-- hadm_id: string (nullable = true)
 |-- age: string (nullable = true)
 |-- so2: string (nullable = true)
 |-- po2: string (nullable = true)
 |-- pco2: string (nullable = true)
 |-- fio2: string (nullable = true)
 |-- aado2: string (nullable = true)
 |-- ph: string (nullable = true)
 |-- baseexcess: string (nullable = true)
 |-- bg2_bicarbonate: string (nullable = true)
 |-- totalco2: string (nullable = true)
 |-- bg2_hematocrit: string (nullable = true)
 |-- bg2_haemoglobin: string (nullable = true)
 |-- carboxyhemoglobin: string (nullable = true)
 |-- methemoglobin: string (nullable = true)
 |-- bg2_chloride: string (nullable = true)
 |-- bg2_calcium: string (nullable = true)
 |-- temperature: string (nullable = true)
 |-- bg2_potassium: string (nullable = true)
 |-- bg2_sodium: string (nullable = true)
 |-- bg2_lactate: string (nullable = true)
 |-- bg2_glucose:

In [5]:
# Total number of rows in the raw dataset.
data.count()

431231

In [6]:
# Summary statistics per column, shown as a pandas frame.
# The columns are still strings at this point, so min/max compare text, not numbers.
data.describe().toPandas()

Unnamed: 0,summary,subject_id,gender,dod,hadm_id,age,so2,po2,pco2,fio2,...,amylase,bilirubin_total,bilirubin_direct,bilirubin_indirect,ck_cpk,ck_mb,ggt,ld_ldh,crp,charlson_comorbidity_index
0,count,431231.0,431231,106218,431231.0,431231.0,27945.0,62060.0,62050.0,16541.0,...,21768.0,148620.0,13409.0,13897.0,74281.0,65427.0,3717.0,92776.0,20994.0,431231.0
1,mean,15007664.969607936,,,25003138.264725868,59.26944843426296,92.0697298264448,197.9319529487593,48.15066881547139,64.99111299195937,...,109.67851892686512,1.4621228636791646,3.0113729584607327,1.732935165863137,727.7982525814139,11.967495070842311,250.6031746031746,389.837005259981,66.98933552443553,3.587021805018656
2,stddev,2877497.806271308,,,2888180.402945388,19.207987665245223,12.191322751514052,136.2990225302963,14.732684122524285,24.739972885601127,...,216.2017369997382,3.557465027266691,5.138052210796493,2.325996291676031,6729.51889196154,38.24575537849562,386.639868567945,1562.3742701386582,74.66665896488824,3.031824233136227
3,min,10000032.0,F,2110-01-25,20000019.0,100.0098843901973,100.0,0.0,10.0,100.0,...,10.0,0.1,0.1,0.1,10.0,1.0,10.0,100.0,0.1,0.0
4,max,19999987.0,M,2212-01-22,29999928.0,99.96622041763482,99.9,99.0,99.0,99.0,...,996.0,9.9,9.9,9.9,999.0,99.0,999.0,9990.0,99.9,9.0


- All rows have a subject id, gender, age, admission id (hadm_id) and charlson_comorbidity_index.


To Do
- Check for valid gender values and remove invalid genders (those other than 'M' and 'F').
- Check for valid age values.
- Impute missing values with the average of the previous and next readings for the same subject.
- Remove columns with more than 50% missing values after imputation.
- Remove rows which have data in less than 25% of the columns.

#### Check for valid gender values and remove invalid genders (those other than 'M' and 'F').

In [7]:
# Count rows per distinct gender value to spot anything other than 'M' / 'F'.
gender_counts = data.groupBy("gender").count()
gender_counts.show()

+------+------+
|gender| count|
+------+------+
|     F|224990|
|     M|206241|
+------+------+



Gender column values look good.

#### type conversion for the columns data

Except for gender, all other columns are treated as numeric: each one is cast to double and rounded to 2 decimal places.

In [8]:
# Cast every column except gender to double, rounded to 2 decimal places.
# gender stays a string and is moved to the last position.
import pyspark.sql.functions as F

cols = data.columns
cols.remove("gender")
# NOTE(review): dod holds date strings in the raw data (describe() above reports
# min '2110-01-25'), so casting it to double leaves no usable values in that
# column (count 0 in the summary below) - confirm this is acceptable.
data = data.select(*(F.round(F.col(c).cast("double"), 2).alias(c) for c in cols), "gender")
print(data.columns)
# printSchema() prints by itself and returns None; wrapping it in print() added a stray "None".
data.printSchema()
data.describe().toPandas()

['subject_id', 'dod', 'hadm_id', 'age', 'so2', 'po2', 'pco2', 'fio2', 'aado2', 'ph', 'baseexcess', 'bg2_bicarbonate', 'totalco2', 'bg2_hematocrit', 'bg2_haemoglobin', 'carboxyhemoglobin', 'methemoglobin', 'bg2_chloride', 'bg2_calcium', 'temperature', 'bg2_potassium', 'bg2_sodium', 'bg2_lactate', 'bg2_glucose', 'db_wbc', 'basophils_abs', 'eosinophils_abs', 'lymphocytes_abs', 'monocytes_abs', 'neutrophils_abs', 'basophils', 'eosinophils', 'lymphocytes', 'monocytes', 'neutrophils', 'atypical_lymphocytes', 'bands', 'immature_granulocytes', 'metamyelocytes', 'nrbc', 'troponin_t', 'cm_ck_mb', 'ntprobnp', 'albumin', 'globulin', 'total_protein', 'aniongap', 'bicarbonate', 'bun', 'calcium', 'chloride', 'creatinine', 'glucose', 'sodium', 'potassium', 'd_dimer', 'fibrinogen', 'thrombin', 'inr', 'pt', 'ptt', 'hematocrit', 'hemoglobin', 'mch', 'mchc', 'mcv', 'platelet', 'rbc', 'rdw', 'rdwsd', 'wbc', 'scr_min', 'ckd', 'mdrd_est', 'scr_baseline', 'alt', 'alp', 'ast', 'amylase', 'bilirubin_total', 'bi

Unnamed: 0,summary,subject_id,dod,hadm_id,age,so2,po2,pco2,fio2,aado2,...,bilirubin_total,bilirubin_direct,bilirubin_indirect,ck_cpk,ck_mb,ggt,ld_ldh,crp,charlson_comorbidity_index,gender
0,count,431231.0,0.0,431231.0,431231.0,27945.0,62060.0,62050.0,16541.0,5437.0,...,148620.0,13409.0,13897.0,74281.0,65427.0,3717.0,92776.0,20994.0,431231.0,431231
1,mean,15007664.969607936,,25003138.264725868,59.2694490655819,92.0697298264448,197.9319529487593,48.15066881547139,64.99111299195937,492.7265256575317,...,1.4621228636791646,3.0113729584607327,1.732935165863137,727.7982525814139,11.967495070842311,250.6031746031746,389.837005259981,66.98933552443553,3.587021805018656,
2,stddev,2877497.806271308,,2888180.402945388,19.20798714444107,12.191322751514052,136.2990225302963,14.732684122524285,24.739972885601127,122.97508532137482,...,3.557465027266691,5.138052210796493,2.325996291676031,6729.51889196154,38.24575537849562,386.639868567945,1562.3742701386582,74.66665896488824,3.031824233136227,
3,min,10000032.0,,20000019.0,18.0,7.0,0.0,10.0,21.0,0.12,...,0.1,0.1,0.1,4.0,1.0,3.0,31.0,0.1,0.0,F
4,max,19999987.0,,29999928.0,103.17,100.0,4242.0,246.0,100.0,743.0,...,87.2,68.0,29.6,591950.0,673.0,7380.0,377000.0,608.1,20.0,M


- The min and max age values look good.

#### Impute missing values - Calculating the average of the values between the prev & next admission reading for the same subject.

In [9]:
# Spot-check one subject before imputation, ordered by age.
sample_columns = ["subject_id", "age", "basophils_abs", "db_wbc", "platelet", "ast"]
sample_rows = data.filter(data["subject_id"] == 10040025).select(*sample_columns)
sample_rows.orderBy("age").show()

+-----------+-----+-------------+------+--------+----+
| subject_id|  age|basophils_abs|db_wbc|platelet| ast|
+-----------+-----+-------------+------+--------+----+
|1.0040025E7|64.21|         null|  null|   202.0|null|
|1.0040025E7|66.51|         null|  null|    null|null|
|1.0040025E7|66.56|         null|   8.9|   275.0|null|
|1.0040025E7| 66.8|         null|   6.2|   258.0|null|
|1.0040025E7|68.46|         null|  12.6|   281.0|null|
|1.0040025E7|68.59|         0.01|  11.3|   307.0|null|
|1.0040025E7|68.86|         null|   9.8|   275.0|null|
|1.0040025E7|68.93|         null|   8.6|   323.0|null|
|1.0040025E7|68.99|         null|  12.6|   366.0|null|
|1.0040025E7|69.06|         0.11|  16.7|   349.0|16.0|
+-----------+-----+-------------+------+--------+----+



In [10]:
# Snapshot of the data before imputation.
# toPandas() brings the full DataFrame into pandas before it is written as a single CSV file.
data.toPandas().to_csv("../data/EDA/before_imputation_all.csv")

In [11]:
# Impute missing values per subject with window functions.
# https://sqlrelease.com/get-the-first-non-null-value-per-group-spark-dataframe
#
# For every non-key column a null is replaced by:
#   - the average of the nearest earlier and nearest later non-null reading, or
#   - whichever of the two exists when only one side has a reading;
#   - it stays null when the subject has no reading at all for that column.
#
# All columns are rewritten in a single select(). Chaining withColumn() once per
# column produces a very deep query plan, which Spark's documentation warns against.
from pyspark.sql.window import Window

# Both windows look only at rows strictly after the current one in their own
# sort order, so the current row never contributes to its own imputation.
# Sorted by age descending, the following rows are the subject's earlier readings.
subject_win_prev = Window.partitionBy("subject_id").orderBy(F.desc("age")).rowsBetween(Window.currentRow + 1, Window.unboundedFollowing)
# Sorted by age ascending, the following rows are the subject's later readings.
subject_win_next = Window.partitionBy("subject_id").orderBy("age").rowsBetween(Window.currentRow + 1, Window.unboundedFollowing)

# Columns that are never imputed.
_impute_key_cols = ("subject_id", "age", "gender", "hadm_id", "charlson_comorbidity_index")


def _imputed(c):
    """Return column `c` with nulls filled from the subject's neighbouring readings."""
    prev_val = F.first(c, ignorenulls=True).over(subject_win_prev)
    next_val = F.first(c, ignorenulls=True).over(subject_win_next)
    # (prev + next) / 2 is null unless both neighbours exist, so coalesce falls
    # through to whichever single neighbour is available.
    return F.coalesce(F.col(c), (prev_val + next_val) / 2, prev_val, next_val).alias(c)


# Column order is preserved; key columns pass through unchanged.
data = data.select(*[F.col(c) if c in _impute_key_cols else _imputed(c) for c in data.columns])

In [13]:
# Same subject as before imputation, to compare the filled-in values.
check_columns = ("subject_id", "age", "basophils_abs", "db_wbc", "platelet", "ast")
(
    data.filter(data["subject_id"] == 10040025)
    .select(*check_columns)
    .orderBy("age")
    .show()
)

+-----------+-----+-------------+------+--------+----+
| subject_id|  age|basophils_abs|db_wbc|platelet| ast|
+-----------+-----+-------------+------+--------+----+
|1.0040025E7|64.21|         0.01|   8.9|   202.0|16.0|
|1.0040025E7|66.51|         0.01|   8.9|   238.5|16.0|
|1.0040025E7|66.56|         0.01|   8.9|   275.0|16.0|
|1.0040025E7| 66.8|         0.01|   6.2|   258.0|16.0|
|1.0040025E7|68.46|         0.01|  12.6|   281.0|16.0|
|1.0040025E7|68.59|         0.01|  11.3|   307.0|16.0|
|1.0040025E7|68.86|         0.06|   9.8|   275.0|16.0|
|1.0040025E7|68.93|         0.06|   8.6|   323.0|16.0|
|1.0040025E7|68.99|         0.06|  12.6|   366.0|16.0|
|1.0040025E7|69.06|         0.11|  16.7|   349.0|16.0|
+-----------+-----+-------------+------+--------+----+



In [13]:
# Snapshot of the data after imputation, written as a single CSV via pandas.
# NOTE(review): the recorded run of this cell failed inside collectToPython with
# "SparkContext was shut down" (see the traceback below); the cause is not
# visible here - presumably the driver ran out of resources, confirm.
data.toPandas().to_csv("../data/EDA/after_imputation_all.csv")

Py4JJavaError: An error occurred while calling o3388.collectToPython.
: org.apache.spark.SparkException: Job 12 cancelled because SparkContext was shut down
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$cleanUpAfterSchedulerStop$1(DAGScheduler.scala:1188)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$cleanUpAfterSchedulerStop$1$adapted(DAGScheduler.scala:1186)
	at scala.collection.mutable.HashSet.foreach(HashSet.scala:79)
	at org.apache.spark.scheduler.DAGScheduler.cleanUpAfterSchedulerStop(DAGScheduler.scala:1186)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onStop(DAGScheduler.scala:2887)
	at org.apache.spark.util.EventLoop.stop(EventLoop.scala:84)
	at org.apache.spark.scheduler.DAGScheduler.stop(DAGScheduler.scala:2784)
	at org.apache.spark.SparkContext.$anonfun$stop$11(SparkContext.scala:2105)
	at org.apache.spark.util.Utils$.tryLogNonFatalError(Utils.scala:1484)
	at org.apache.spark.SparkContext.stop(SparkContext.scala:2105)
	at org.apache.spark.SparkContext.$anonfun$new$35(SparkContext.scala:670)
	at org.apache.spark.util.SparkShutdownHook.run(ShutdownHookManager.scala:214)
	at org.apache.spark.util.SparkShutdownHookManager.$anonfun$runAll$2(ShutdownHookManager.scala:188)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:2066)
	at org.apache.spark.util.SparkShutdownHookManager.$anonfun$runAll$1(ShutdownHookManager.scala:188)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at scala.util.Try$.apply(Try.scala:213)
	at org.apache.spark.util.SparkShutdownHookManager.runAll(ShutdownHookManager.scala:188)
	at org.apache.spark.util.SparkShutdownHookManager$$anon$2.run(ShutdownHookManager.scala:178)
	at org.apache.hadoop.util.ShutdownHookManager$1.run(ShutdownHookManager.java:54)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:952)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2238)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2259)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2278)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2303)
	at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1021)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:406)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:1020)
	at org.apache.spark.sql.execution.SparkPlan.executeCollect(SparkPlan.scala:424)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.$anonfun$executeCollect$1(AdaptiveSparkPlanExec.scala:348)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.withFinalPlanUpdate(AdaptiveSparkPlanExec.scala:376)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.executeCollect(AdaptiveSparkPlanExec.scala:348)
	at org.apache.spark.sql.Dataset.$anonfun$collectToPython$1(Dataset.scala:3688)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$2(Dataset.scala:3858)
	at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:510)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:3856)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:109)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:169)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:95)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:779)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3856)
	at org.apache.spark.sql.Dataset.collectToPython(Dataset.scala:3685)
	at java.base/jdk.internal.reflect.DirectMethodHandleAccessor.invoke(DirectMethodHandleAccessor.java:104)
	at java.base/java.lang.reflect.Method.invoke(Method.java:577)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:833)


#### Remove columns with more than 50% missing values

In [None]:
# Summary statistics after imputation; the 'count' row gives the non-null count per column.
dt = data.describe().toPandas()
dt

In [None]:
# Reshape the describe() summary so each data column becomes one row, compute
# the share of missing values per column, and keep columns with < 50% missing.
dt_t = dt.set_index("summary").T  # rows: data columns; columns: count/mean/stddev/min/max
dt_t["count"] = dt_t["count"].astype(int)
total_rows = data.count()
dt_t["missing_percentage"] = 100 - (dt_t["count"] / total_rows) * 100
print(dt_t)
dt_t = dt_t[dt_t["missing_percentage"] < 50]
print(dt_t)
print(dt_t.shape)

After removing all columns with a missing-value percentage >= 50, 45 feature fields remain; they are listed below.

In [None]:
# Shape of the filtered summary and the names of the columns that were kept.
print(dt_t.shape)
dt_t.index

In [None]:
# Filter out only the identified column data from the data
data = data.select(dt_t.index.values.tolist())

In [None]:
# Inspect the schema and summary statistics after the column filter.
# printSchema() prints by itself and returns None, so it is not wrapped in print().
data.printSchema()
data.describe().toPandas()

No columns removed.

#### Retain rows with missing values in less than 25% of the columns.

In [None]:
# Calculate, for every row, the percentage of feature columns that are null.
# Adds one 0/1 indicator column 'missing_<col>' per feature column plus the
# 'missing_percentage' column.
from operator import add
from functools import reduce

# Identifier / always-populated columns are not counted as features.
key_cols = ("subject_id", "age", "gender", "hadm_id", "charlson_comorbidity_index")
# Ignore helper columns left over from a previous run of this cell so it can be re-run.
base_cols = [c for c in data.columns if "missing_" not in c]
feature_cols = [c for c in base_cols if c not in key_cols]
if not feature_cols:
    raise ValueError("no feature columns available to compute missing_percentage")

# One select() adds all indicator columns at once instead of one withColumn() per column.
data = data.select(
    *base_cols,
    *[F.when(F.col(c).isNull(), 1).otherwise(0).alias("missing_" + c) for c in feature_cols],
)
# Divide by the number of feature columns only; the indicator columns themselves
# must not be part of the denominator.
data = data.withColumn(
    "missing_percentage",
    (reduce(add, [F.col("missing_" + c) for c in feature_cols]) / len(feature_cols)) * 100,
)
# Preview a few rows instead of collecting the whole DataFrame.
data.limit(10).toPandas()

In [None]:
# Number of rows for each distinct missing_percentage value.
data.groupBy("missing_percentage").count().toPandas()

In [None]:
# Keep only rows where fewer than 25% of the feature columns are missing.
data3 = data.where("missing_percentage < 25")

In [None]:
# Row counts before and after the row filter.
for frame in (data, data3):
    print(frame.count())


None of the rows has missing % >= 25

In [None]:
# Drop every helper column whose name contains "missing_" and show the
# column count before and after.
print(len(data3.columns))
helper_columns = [name for name in data3.columns if "missing_" in name]
data3 = data3.drop(*helper_columns)
len(data3.columns)

In [None]:
# Persist the cleaned data as a single CSV via pandas.
# NOTE(review): to_csv() writes the pandas row index as an unnamed first column by
# default; presumably that is the `_c0` column dropped further down - confirm.
data3.toPandas().to_csv("../data/EDA/after_eda_all.csv")

In [None]:
# Reload the cleaned data from disk. Only the header option is set, so no
# schema is inferred on read.
eda_csv_path = "../data/EDA/after_eda_all.csv"
data4 = spark.read.csv(eda_csv_path, header=True)
data4.printSchema()

In [None]:
# Get the readings of each subject's first admission (lowest age). The
# co-morbidity index is dropped here because it is taken from the last
# admission in a separate step.
from pyspark.sql.window import Window
import pyspark.sql.functions as F

# data4 was read from CSV with only the header option, so age is a string column.
# Cast it so rows are ordered numerically rather than lexicographically (as text,
# "100.2" sorts before "99.5").
subject_win = Window.partitionBy("subject_id").orderBy(F.col("age").cast("double"))
base_data = (
    data4.withColumn("row", F.row_number().over(subject_win))
    .filter(F.col("row") == 1)
    .drop("row", "charlson_comorbidity_index")
)

In [None]:
# Number of rows kept after selecting one row per subject.
base_data.count()

In [None]:
# Take the co-morbidity index from each subject's last admission (highest age).
# data4 was read from CSV with only the header option, so age is a string column;
# cast it so the descending order is numeric rather than lexicographic.
subject_win_predict = Window.partitionBy("subject_id").orderBy(F.col("age").cast("double").desc())
base_data_predict = (
    data4.withColumn("row", F.row_number().over(subject_win_predict))
    .filter(F.col("row") == 1)
    .select("subject_id", "charlson_comorbidity_index")
)
base_data_predict.count()

In [None]:
# Column names of the feature frame and of the prediction-target frame.
for frame in (base_data, base_data_predict):
    print(frame.columns)

In [None]:
# Display the first-admission feature frame (collects every row into pandas).
base_data.toPandas()

In [None]:
# Remove the `_c0` column from the feature frame.
# NOTE(review): presumably this is the pandas index written by the earlier to_csv() call - confirm.
base_data = base_data.drop("_c0")

In [None]:
# Display the feature frame again after dropping `_c0`.
base_data.toPandas()

In [None]:
# Display the per-subject co-morbidity index frame.
base_data_predict.toPandas()

In [None]:
# Write the two final datasets as single CSV files via pandas:
# the first-admission features and the per-subject co-morbidity index.
base_data.toPandas().to_csv("../data/EDA/clustering_data_all.csv")
base_data_predict.toPandas().to_csv("../data/EDA/prediction_value_all.csv")