# Extracting patients' information from Cerner database - Normal and ESKD groups*

### * This code belongs to the paper "Early Prediction of End Stage Kidney Disease Based on Cumulative Estimated Glomerular Filtration Rate Velocity"

In [None]:
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.types import *
import time
import pandas as pd
import numpy as np
import pyspark.sql.functions as f
from pyspark.sql.window import Window
import socket    
# Look up the address of the machine running this notebook; it is passed to
# Spark as the driver host below. Change it if the lookup does not match your
# host: open a terminal and run "cat /etc/hosts" to see the host ip and name.
hostname = socket.gethostname()
IPAddr = socket.gethostbyname(hostname)

# Resource settings applied to the Spark application.
spark_settings = [
    ("spark.executor.instances", '5'),
    ('spark.executor.memory', '8g'),
    ('spark.executor.cores', '5'),
    ('spark.driver.memory', '4g'),
    ('spark.sql.broadcastTimeout', '3000'),
]

conf = (
    SparkConf()
    .setAll(spark_settings)
    .setMaster('yarn')
    .setAppName('spark-yarn-2')
    .set("spark.driver.host", IPAddr)
)

In [None]:
# Start (or reuse) the Spark session with Hive support, using the settings in conf.
spark = (
    SparkSession.builder
    .config(conf=conf)
    .enableHiveSupport()
    .getOrCreate()
)

In [None]:
# Pulling out the groups from the Cerner database

## Adult patient full data (with sCr level): one row per lab result, joined to
## its encounter (age at the encounter) and its patient (race, gender).
# Only non-null results in [0, 1000] with a non-null draw time are kept.
# NOTE(review): detail_lab_procedure_id '13.0' is presumably the serum
# creatinine procedure -- confirm against the lab-procedure dimension table.
# NOTE(review): age_in_years is compared with the string literal '18'; if the
# column is stored as a string the comparison is lexicographic -- confirm the
# column type.
data_pool = spark.sql("select P.patient_sk, L.lab_drawn_dt_tm as Date, P.race as Race, P.gender as Gender, E.age_in_years as Age, L.numeric_result as sCr_level\
                                      from cerner.orc_hf_d_patient P \
                                      join cerner.orc_hf_f_encounter E on E.patient_id = P.patient_id\
                                      join cerner.orc_hf_f_lab_procedure L on L.encounter_id = E.encounter_id\
                                      where L.detail_lab_procedure_id ='13.0' and L.numeric_result >= 0 and L.numeric_result <= 1000\
                                      and L.numeric_result is not null\
                                      and E.age_in_years >= '18'\
                                      and L.lab_drawn_dt_tm is not null")

# The query result is reused by every later step, so keep it in Spark's cache.
# persist() alone is enough: a second cache() call on an already persisted
# DataFrame does nothing useful.
data_pool.persist()

#---------------------------------------------------------------------------------------------------------------------

# eGFR from serum creatinine, age, gender and race:
#
#   eGFR = 141 x min(SCr/k, 1)^alpha x max(SCr/k, 1)^-1.209 x 0.993^Age
#          x 1.018 [if female] x 1.159 [if Race is 'Black']
#
# data_pool must carry the sCr_level, Gender, Age and Race columns. The result
# is stored in eGFR_EPI; the rescaled creatinine stays available as new_sCr and
# the raw sCr_level column is removed.

k_female = 0.7
k_male = 0.9
alpha_male = -0.411
alpha_female = -0.329
alpha_fixed = -1.209
age_factor = 0.993

# Readings of 60 or more are rescaled by 0.01132; smaller ones are used as is.
# NOTE(review): presumably this converts umol/L readings to mg/dL -- confirm.
raw_scr = f.col('sCr_level')
data_pool = data_pool.withColumn(
    "new_sCr",
    f.when(raw_scr >= f.lit(60), raw_scr * f.lit(0.01132)).otherwise(raw_scr),
)

# Any row whose Gender is not exactly 'Female' (missing values included) takes
# the male coefficients; the same applies to Race and 'Black'.
is_female = f.col('Gender') == 'Female'
is_black = f.col('Race') == 'Black'
scr = f.col('new_sCr')

# SCr/k and alpha, both chosen by gender
scr_over_k = f.when(is_female, scr / k_female).otherwise(scr / k_male)
alpha = f.when(is_female, f.lit(alpha_female)).otherwise(f.lit(alpha_male))

# min(SCr/k, 1)^alpha, max(SCr/k, 1)^-1.209 and 0.993^Age
min_term = f.pow(f.least(scr_over_k, f.lit(1)), alpha)
max_term = f.pow(f.greatest(scr_over_k, f.lit(1)), f.lit(alpha_fixed))
age_term = f.pow(f.lit(age_factor), f.col('Age'))

# Factors are multiplied left to right: 141, min term, max term, age term,
# then the female factor and finally the race factor.
egfr_base = 141 * min_term * max_term * age_term
egfr_gender = f.when(is_female, egfr_base * 1.018).otherwise(egfr_base)
egfr_final = f.when(is_black, egfr_gender * 1.159).otherwise(egfr_gender)

data_pool = data_pool.withColumn("eGFR_EPI", egfr_final).drop("sCr_level")

#---------------------------------------------------------------------------------------------------------------------

# Label every observation with the eGFR band it falls in. Values below 15 get
# the Stage 5 label, the half-open intervals below get their own label, and
# anything else (missing eGFR, or eGFR of 500 and above) is left as null.
# Other steps compare against these label strings, so they must stay exactly
# as written here.
stage_bins = [
    (15, 30, "[15, 30)-Stage4"),
    (30, 45, "[30,45)-Stage3B"),
    (45, 60, "[45, 60)-Stage3A"),
    (60, 90, "[60, 90)-Stage2"),
    (90, 500, "[90, ..)-Normal"),
]

egfr = f.col('eGFR_EPI')
bin_expr = f.when(egfr < 15, "[0, 15)-Stage5")
for lower, upper, label in stage_bins:
    bin_expr = bin_expr.when((egfr >= lower) & (egfr < upper), label)

data_pool = data_pool.withColumn("eGFR_bin", bin_expr.otherwise(None))

#---------------------------------------------------------------------------------------------------------------------

# Criterion 1: 9 or more observations with an eGFR between 15 and 90.
#
# The labels below must match the strings stored in the eGFR_bin column exactly.
# The earlier version tested "[45,60)-Stage3B", a label that is never produced
# (the [45, 60) bin is written as "[45, 60)-Stage3A"), so those observations
# were silently left out; it also listed the Stage 4 label twice.
ckd_range_labels = [
    "[15, 30)-Stage4",
    "[30,45)-Stage3B",
    "[45, 60)-Stage3A",
    "[60, 90)-Stage2",
]
in_ckd_range = f.col('eGFR_bin').isin(ckd_range_labels)

# Count the in-range observations per patient under an explicit column name
# instead of relying on Spark's auto-generated aggregate name.
patients_9more = (
    data_pool.groupBy('patient_sk')
    .agg(f.count(f.when(in_ckd_range, True)).alias('n_ckd_range_obs'))
    .filter(f.col('n_ckd_range_obs') >= 9)
    .drop('n_ckd_range_obs')
)

#---------------------------------------------------------------------------------------------------------------------

# All observations (any eGFR band) of the patients that passed criterion 1.
data_pool_9more = data_pool.join(patients_9more, on=['patient_sk'], how='inner')
data_pool_9more.cache()

#---------------------------------------------------------------------------------------------------------------------

### Criterion 2: making sure this is not acute kidney disease -- the in-range
### observations of a patient must span more than 90 days.

data_test = data_pool_9more.filter(in_ckd_range)

patients_max_date = data_test.groupBy(f.col('patient_sk')).agg(f.max(f.col('Date')))
patients_min_date = data_test.groupBy(f.col('patient_sk')).agg(f.min(f.col('Date')))

# Spark datetime pattern: yyyy = year, MM = month, dd = day, HH = hour,
# mm = minute, ss = second. The previous "YY-mm-dd HH:MM:SS" mixed up the
# month/minute letters and used week-year and fraction-of-second fields.
timeFmt = "yyyy-MM-dd HH:mm:ss"

# Duration is the time between first and last in-range observation, in seconds.
patients_not_acute = patients_max_date.join(patients_min_date, on=['patient_sk']).withColumn(
    "Duration",
    f.unix_timestamp('max(Date)', format=timeFmt) - f.unix_timestamp('min(Date)', format=timeFmt),
)
# 7776000 s = 90 days
patients_not_acute = patients_not_acute.filter(f.col('Duration') > 7776000)

#---------------------------------------------------------------------------------------------------------------------

# Only patient_sk and Duration are needed from here on, so the first/last
# observation dates are removed from patients_not_acute.
patients_not_acute = (
    patients_not_acute
    .drop(f.col('max(Date)'))
    .drop(f.col('min(Date)'))
)

# Keep every observation of the patients whose data passed the duration
# filter (not acute), cache it, and order it by patient and date.
chronic_patient_ids = patients_not_acute.select(patients_not_acute['patient_sk'])
data_pool_9more_chronic = data_pool_9more.join(chronic_patient_ids, on=['patient_sk'], how='inner')
data_pool_9more_chronic.cache()

data_pool_9more_chronic_sorted = data_pool_9more_chronic.sort('patient_sk', 'Date')

#---------------------------------------------------------------------------------------------------------------------

# Keep the patients that have at least 3 observations in the [60, 90) eGFR bin.
patients_60to90 = data_pool_9more_chronic_sorted.groupBy(f.col('patient_sk')).agg(
    f.count(f.when(f.col('eGFR_bin') == "[60, 90)-Stage2", True)).alias('count')
)
patients_60to90 = patients_60to90.filter(patients_60to90['count'] >= 3)
# DataFrames are immutable: drop() returns a new DataFrame, so its result has
# to be assigned (the bare call used before had no effect).
patients_60to90 = patients_60to90.drop('count')

# All observations of the qualifying patients.
data_pool_9more_chronic_sorted_60to90 = data_pool_9more_chronic_sorted.join(patients_60to90.select('patient_sk'), on=['patient_sk'], how='inner')


#---------------------------------------------------------------------------------------------------------------------

# ICD-9 + ICD-10
# Adult patients with an ESRD diagnosis: any encounter at age 18 or above that
# carries diagnosis code 585.6 or N18.6.
# NOTE(review): the original heading also mentioned dialysis and stage 5 CKD,
# but only these two codes are queried -- confirm the intended code list.
esrd_query = "select distinct P.patient_sk\
                          from cerner.orc_hf_d_diagnosis Dd \
                          join cerner.orc_hf_f_diagnosis Df on Dd.diagnosis_id = Df.diagnosis_id\
                          join cerner.orc_hf_f_encounter E on Df.encounter_id = E.encounter_id\
                          join cerner.orc_hf_d_patient P on E.patient_id = P.patient_id\
                          where Dd.diagnosis_code in ('585.6', 'N18.6')\
                          and E.age_in_years >= '18'"

# cache() hands back the same DataFrame, so the query result is marked for
# caching first and the de-duplicated frame is derived from it.
Patients_diagnosed_ESRD = spark.sql(esrd_query).cache().dropDuplicates()

#---------------------------------------------------------------------------------------------------------------------

# ICD-9 + ICD-10
# Adult patients with ANY diagnosis regarding CKD (different stages or
# transplant) or kidney-related death (death caused by kidney disease:
# nephritis, nephrotic syndrome, nephrosis).
# NOTE(review): codes are matched exactly, so a three-character entry such as
# 'N17' presumably does not match sub-codes like 'N17.9' -- confirm how
# diagnosis_code is stored.
any_kidney_dx_query = "select distinct P.patient_sk\
                          from cerner.orc_hf_d_diagnosis Dd \
                          join cerner.orc_hf_f_diagnosis Df on Dd.diagnosis_id = Df.diagnosis_id\
                          join cerner.orc_hf_f_encounter E on Df.encounter_id = E.encounter_id\
                          join cerner.orc_hf_d_patient P on E.patient_id = P.patient_id\
                          where Dd.diagnosis_code in ('585.6', '585.5', 'V42.0', 'N18.6', 'Z94.0', 'N18.1', 'N18.2', 'N18.3', 'N18.4', 'N18.5', 'N00', 'N01', 'N02', 'N03', 'N04', 'N05','N06','N07', 'N17', 'N18', 'N19', 'N25', 'N26', 'N27')\
                          and E.age_in_years >= '18'"

# cache() hands back the same DataFrame, so the query result is marked for
# caching first and the de-duplicated frame is derived from it.
Patients_diagnosed_any = spark.sql(any_kidney_dx_query).cache().dropDuplicates()

#---------------------------------------------------------------------------------------------------------------------

# Split the cohort into two groups by diagnosis history.

# Normal group: cohort patients that appear in none of the kidney-related
# diagnoses (anti-join against Patients_diagnosed_any).
any_kidney_dx_ids = Patients_diagnosed_any.select('patient_sk')
Normal_group_done = data_pool_9more_chronic_sorted_60to90.join(
    any_kidney_dx_ids, on='patient_sk', how="leftanti"
)
Normal_group_done.cache()

# ESRD group: cohort patients that do have an ESRD diagnosis.
esrd_dx_ids = Patients_diagnosed_ESRD.select('patient_sk')
ESRD_group_done = data_pool_9more_chronic_sorted_60to90.join(
    esrd_dx_ids, on='patient_sk', how="inner"
)
ESRD_group_done.cache()

#---------------------------------------------------------------------------------------------------------------------

# ESRD group: keep only the patients whose FIRST observation (earliest Date)
# has an eGFR of at least 60.

patients_ESRD_eGFR_min_date = ESRD_group_done.groupBy('patient_sk').agg(f.min(f.col('Date')).alias('min_eGFR_date'))
ESRD_group_done = ESRD_group_done.join(patients_ESRD_eGFR_min_date, on=['patient_sk'], how='left')

# first_eGFR carries the eGFR only on rows taken at the patient's earliest
# date; every other row is left null (instead of a magic 10000 sentinel), so
# it can never satisfy the comparison below.
ESRD_group_done_first_eGFR = ESRD_group_done.withColumn(
    'first_eGFR',
    f.when(f.col('min_eGFR_date') == f.col('Date'), f.col('eGFR_EPI')),
)
# The upper bound of 1000 is kept from the original filter, so first
# observations with an eGFR of 1000 or more are still excluded.
patients_ESRD_group_done_above_60 = ESRD_group_done_first_eGFR.filter(
    (f.col('first_eGFR') >= 60) & (f.col('first_eGFR') < 1000)
)
# A patient can have several rows at the earliest date; join on the distinct
# ids so the group's rows are not multiplied.
ESRD_group_done_above_60 = ESRD_group_done.join(
    patients_ESRD_group_done_above_60.select('patient_sk').distinct(),
    on=['patient_sk'],
    how="inner",
)

# Normal group: keep only the patients whose LOWEST eGFR is at least 60,
# i.e. every observation of the patient is at 60 or above.

patients_Normal_group_done_above_60 = Normal_group_done.groupBy('patient_sk').agg(f.min(f.col('eGFR_EPI')).alias('min_eGFR'))
patients_Normal_group_done_above_60 = patients_Normal_group_done_above_60.filter(f.col('min_eGFR') >= 60)
Normal_group_done_above_60 = Normal_group_done.join(patients_Normal_group_done_above_60.select('patient_sk'), on=['patient_sk'], how="inner")

#---------------------------------------------------------------------------------------------------------------------

# Remove the Normal-group patients whose observations are too condensed: a
# patient stays only when the smallest gap between an observation and the one
# before it is larger than 86400 seconds (one day). timeFmt, the timestamp
# pattern used for the conversion to seconds, is defined earlier in the script.

by_patient_in_time = Window.partitionBy('patient_sk').orderBy('Date')
patients_min_date = Normal_group_done_above_60.groupBy('patient_sk').agg(f.min('Date').alias('mindate'))

# Previous observation date ('lag') and the patient's earliest date, per row.
obs_with_gaps = (
    Normal_group_done_above_60
    .withColumn('lag', f.lag('Date').over(by_patient_in_time))
    .join(patients_min_date, on=['patient_sk'], how='left')
)

# Rows taken at the patient's earliest date get no predecessor, so they do not
# contribute a gap.
obs_with_gaps = obs_with_gaps.withColumn(
    'lag_set',
    f.when(obs_with_gaps['Date'] > obs_with_gaps['mindate'], f.col('lag')).otherwise(None),
)
seconds_now = f.unix_timestamp('Date', format=timeFmt)
seconds_before = f.unix_timestamp('lag_set', format=timeFmt)
obs_with_gaps = obs_with_gaps.withColumn("Duration", f.abs(seconds_now - seconds_before))

# Smallest gap per patient; only patients above one day are kept.
patient_Normal_group_done_sparse = (
    obs_with_gaps.groupBy('patient_sk')
    .agg(f.min(f.col('Duration')).alias('minn'))
    .filter(f.col('minn') > 86400)
    .drop('minn')
)

Normal_group_done_sparse_above_60 = Normal_group_done_above_60.join(
    patient_Normal_group_done_sparse.select('patient_sk'),
    on=['patient_sk'],
    how='inner',
)

#---------------------------------------------------------------------------------------------------------------------
## In order to rerun the code, you may want to delete this part. Here we
## double-check the data extraction against the patient lists we expect.

# Normal group: remove the patients listed in dropped_normal_group.csv.
# Ids are converted to strings before being compared with patient_sk.
dropped_normal_patients = [str(i) for i in pd.read_csv('dropped_normal_group.csv').patient_sk]
patients_Normal_group_done = Normal_group_done_sparse_above_60.where(f.col("patient_sk").isin(dropped_normal_patients))
Normal_group_done_sparse_above_60_droped = Normal_group_done_sparse_above_60.join(patients_Normal_group_done.select('patient_sk'), on = ['patient_sk'] , how = 'leftanti')


# ESRD group: keep only the patients listed in kept_ESRD_group.csv.
kept_ESRD_group = [str(i) for i in pd.read_csv('kept_ESRD_group.csv').patient_sk]
patients_ESRD_group_done = ESRD_group_done_above_60.where(f.col("patient_sk").isin(kept_ESRD_group))
# A semi join keeps exactly the rows whose patient is in the kept list. The
# plain 'left' join used before kept every patient (so the list had no effect)
# and repeated each kept patient's rows once per matching row.
ESRD_group_done_above_60_droped = ESRD_group_done_above_60.join(patients_ESRD_group_done.select('patient_sk'), on = ['patient_sk'] , how = 'leftsemi')

#---------------------------------------------------------------------------------------------------------------------

# Final clean-up: collapse rows that are identical across every column.
ESRD_group_done = ESRD_group_done_above_60_droped.distinct()
Normal_group_done = Normal_group_done_sparse_above_60_droped.distinct()

In [None]:
# Convert the final ESRD-group Spark DataFrame into a pandas DataFrame.
ESRD_group_done_pandas = ESRD_group_done.toPandas()

In [None]:
# Save the ESRD group as a CSV file (relative path, so it lands in the working directory).
ESRD_group_done_pandas.to_csv('Final_ESRD_group_done_pandas.csv')

In [None]:
# Convert the final Normal-group Spark DataFrame into a pandas DataFrame.
Normal_group_done_pandas = Normal_group_done.toPandas()

In [None]:
# Save the Normal group as a CSV file (relative path, so it lands in the working directory).
Normal_group_done_pandas.to_csv('Final_Normal_group_done_pandas.csv')

### For further information please contact rzz5164@psu.edu