In [41]:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col,upper,udf,element_at,explode,regexp_replace,size
import pyspark.sql.functions as F
from pyspark.sql.types import StringType, DateType

from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler
from pyspark.ml import Pipeline
from pyspark.ml.regression import LinearRegression

import os

# Obtain (or reuse) the Spark session that every cell below works with.
spark = (
    SparkSession.builder
    .appName("FHIR Analytics with Python")
    .getOrCreate()
)
# Catalog-qualified keyspace prefix used when reading the source tables.
keyspace = "myCatalog.hfs_data"


In [2]:
# Register the Cassandra connector catalog under the name that `keyspace` refers to.
spark.conf.set("spark.sql.catalog.myCatalog", "com.datastax.spark.connector.datasource.CassandraCatalog")
# Target amount of data per Spark partition. The setting is expressed in
# megabytes (see its name), so 64 MiB is written as "64"; the previous value
# "67108864" was 64 MiB spelled in bytes, i.e. a split size of ~64 TiB.
spark.conf.set("spark.cassandra.input.split.sizeInMB", "64")

In [3]:
def getMaritalStatus(ms):
    """Return the ``text_`` attribute of a marital-status value.

    Returns None when ``ms`` is None (Spark hands NULL column values to a
    Python UDF as None), instead of raising AttributeError and failing the job.
    """
    if ms is None:
        return None
    return ms.text_
# Spark UDF wrapper around getMaritalStatus, producing a string column.
gms = udf(f=getMaritalStatus, returnType=StringType())
def getBirthDate(bd):
    """Return the first element of a birth-date value.

    Returns None when ``bd`` is None or empty, instead of raising
    TypeError / IndexError inside the UDF.
    """
    if bd is None or len(bd) == 0:
        return None
    return bd[0]
# Spark UDF wrapper around getBirthDate, producing a date column.
gbd = udf(f=getBirthDate, returnType=DateType())

def getAvgForLoinc(loinc_code, agg_column, df):
    """Average ``ValueQuantity`` per ``Subject`` for observations matching a code pattern.

    loinc_code: SQL LIKE pattern applied to the ``LoincCode`` column.
    agg_column: name given to the resulting average column.
    df: frame exposing ``Subject``, ``LoincCode`` and ``ValueQuantity`` columns.

    Returns a frame with one row per ``Subject`` and the average, rounded to
    three decimals, in a column named ``agg_column``.
    """
    matching_rows = df.filter(col("LoincCode").like(loinc_code))
    values_by_subject = (
        matching_rows
        .withColumnRenamed("ValueQuantity", agg_column)
        .select(col("Subject"), col(agg_column))
    )
    rounded_mean = F.round(F.avg(col(agg_column)), 3).alias(agg_column)
    return values_by_subject.groupBy("Subject").agg(rounded_mean)

In [4]:
# Load the four source tables through the Cassandra catalog.
# (The same tables can also be read with
#  spark.read.format("org.apache.spark.sql.cassandra") and table/keyspace options.)
rawPatient, rawReference, rawObservation, rawEncounter = [
    spark.read.table(f"{keyspace}.{table_name}")
    for table_name in ("patient", "reference", "observation", "encounter")
]

In [5]:
# Flatten the patient table into the demographic columns used downstream.
patient_columns = rawPatient.select(
    col("id").alias("PatientId"),
    # Native struct-field access instead of the `gms` Python UDF: it yields
    # NULL for patients without a marital status (the UDF would raise on None)
    # and avoids shipping every row through a Python worker.
    col("maritalstatus")["text_"].alias("Marital Status"),
    # Element "0" of the birthdate value (presumably the date part of a tuple; verify against the table schema).
    col("birthdate")["0"].alias("birthdate"),
    upper(col("gender")).alias("Gender"),
)

# Age in years, approximated as elapsed days / 365 and rounded to one decimal.
patientDataFrame = patient_columns.withColumn(
    "Age",
    F.round(F.datediff(F.current_date(), F.to_date(col("birthdate"))) / 365, 1),
)

In [6]:
# Keep only the references that point at a Patient and strip the decorations
# from both columns so they can serve as join keys for the other tables.

patient_references = rawReference.where(col("reference").like("Patient%"))
referenceDataFrame = patient_references.select(
    regexp_replace(col("id"), "#hidden", "").alias("id"),
    regexp_replace(col("reference"), "Patient/", "").alias("reference"),
)

In [7]:
# Mark the frame as cached *before* counting so this action also fills the
# cache; counting first would scan the source table once here and again on the
# next action, because cache() is lazy.
referenceDataFrame.cache().count()

321587

In [8]:
# Cache the patient references: they are joined against both the observation
# and the encounter tables further down.
referenceDataFrame.cache()

DataFrame[id: string, reference: string]

In [9]:
# Rename the observation columns, then resolve each observation's subject
# reference to a real patient id through the reference table.

observation_columns = rawObservation.select(
    col("id").alias("ObservationId"),
    col("code"),
    col("component"),
    col("valuequantity"),
    col("subject").alias("PatientReferenceId"),
)

# Both reference columns are already strings (they come out of regexp_replace),
# so the matched reference only needs renaming before the join keys are dropped.
observationDataFrame = (
    referenceDataFrame
    .join(observation_columns, referenceDataFrame.id == observation_columns.PatientReferenceId)
    .withColumnRenamed("reference", "PatientId")
    .drop("id", "PatientReferenceId")
)




In [10]:
# Project the first coding's code and the numeric value out of every observation.

loinc_projection = observationDataFrame.select(
    col("PatientId").alias("Subject"),
    col("ObservationId"),
    col("code").coding[0].code.alias("LoincCode"),
    col("valuequantity").value.alias("ValueQuantity"),
)

# Codes retained for the analysis. "88294-4" is listed because the EGFR
# aggregation later in this notebook filters on it; without it here those
# observations are discarded before that step can ever see them.
loinc_code_list = [
    "8480-6", "8462-4", "29463-7", "8302-2",
    "33914-3", "88294-4",
    "2571-8", "2085-9", "18262-6", "2093-3", "39156-5", "55284-4",
    "195967001", "233678006",
]

# Keep only the listed codes and discard rows with any NULL column.
observationDataFrame_loinc = (
    loinc_projection
    .filter(col("LoincCode").isin(loinc_code_list))
    .na.drop()
)

In [11]:
# Row count of the observations after resolving their patient references.
observationDataFrame.count()

130760

In [12]:
# Cache the code-filtered observations: every per-metric average computed
# below is derived from this frame.
observationDataFrame_loinc.cache()

DataFrame[Subject: string, ObservationId: string, LoincCode: string, ValueQuantity: decimal(38,18)]

In [13]:
# Body Weight readings (code 29463-7) per subject, from the pre-filtered observations

is_body_weight = col("LoincCode").like("%29463-7%")

body_weight_df = (
    observationDataFrame_loinc
    .filter(is_body_weight)
    .withColumnRenamed("ValueQuantity", "Body Weight")
    .select(col("Subject"), col("Body Weight"))
    .na.drop()
    .dropDuplicates()
)

In [14]:
# Calculate diastolic, systolic and combined BP per observation, then average per patient

# NOTE(review): component[0] is taken to be the diastolic reading and
# component[1] the systolic one purely by position; DBPCode / SBPCode are
# selected but never checked. Confirm the ordering holds for all source data.
bp_component = col("component")

blood_pressure_df = observationDataFrame.select(
    col("PatientId").alias("Subject"),
    "ObservationId",
    bp_component[0].code.coding[0].code.alias("DBPCode"),
    bp_component[0].valuequantity.value.alias("Diastolic Blood Pressure"),
    bp_component[1].code.coding[0].code.alias("SBPCode"),
    bp_component[1].valuequantity.value.alias("Systolic Blood Pressure"),
).na.drop()  # observations without two populated components are discarded

# Combined value: diastolic plus one third of the pulse pressure, rounded to a whole number.
diastolic = col("Diastolic Blood Pressure")
systolic = col("Systolic Blood Pressure")
blood_pressure_df = blood_pressure_df.withColumn(
    "Blood Pressure",
    F.round(diastolic + (systolic - diastolic) / 3),
)

# Per-patient means, each rounded to three decimals.
blood_pressure_avg_df = blood_pressure_df.groupBy(col("Subject")).agg(
    F.round(F.avg("Diastolic Blood Pressure"), 3).alias("Diastolic BP"),
    F.round(F.avg("Systolic Blood Pressure"), 3).alias("Systolic BP"),
    F.round(F.avg("Blood Pressure"), 3).alias("BP"),
)


In [15]:
# Flag encounters whose first reason code is an asthma code, then reduce to one flag per patient

# Only the first coding of the first reason code is inspected.
first_reason_code = col("reasoncode")[0].coding[0]["code"]
asthma_flag = F.when(first_reason_code.isin(["195967001", "233678006"]), F.lit(1)).otherwise(F.lit(0))

encounter_flags = (
    rawEncounter
    .filter(F.size(col("reasoncode")) > 0)
    .select(col("subject"), asthma_flag.alias("Asthma"))
)

# Swap the encounter's subject reference for the resolved patient id.
encounterDataFrame = (
    encounter_flags
    .join(referenceDataFrame, encounter_flags.subject == referenceDataFrame.id)
    .drop("subject", "id")
    .withColumnRenamed("reference", "subject")
)

# A patient counts as asthmatic when any of their encounters is flagged.
encounterDataFrame_asthma = encounterDataFrame.groupBy("subject").agg(F.max(col("Asthma")).alias("Asthma"))

In [16]:
# Build one row per patient: demographics, body weight, blood pressure, asthma
# flag and the per-patient laboratory averages. All joins are inner joins, so a
# patient missing any one of the inputs drops out of the result.

def _join_patient_metric(patients_df, metric_df, dedupe=False):
    """Inner-join a per-Subject metric frame onto the patient frame.

    patients_df must expose a ``PatientId`` column and metric_df a ``Subject``
    column. When ``dedupe`` is true, exact duplicate rows are removed before
    the redundant ``Subject`` key is dropped.
    """
    joined = patients_df.join(metric_df, patients_df.PatientId == metric_df.Subject)
    if dedupe:
        joined = joined.dropDuplicates()
    return joined.drop(col("Subject"))


# Average body weight per patient, joined with the patient demographics
body_weight_avg_df = body_weight_df.groupBy("Subject").agg(
    F.round(F.avg("Body Weight"), 3).alias("Body Weight")
)
patient_calc_df = (
    body_weight_avg_df
    .join(patientDataFrame, body_weight_avg_df.Subject == patientDataFrame.PatientId)
    .drop("Subject")
)

# Add literal demo info to patients
patient_calc_df = (
    patient_calc_df
    .dropDuplicates()
    .withColumn("Disease", F.array(F.lit("0")))
    .withColumn("PostalCode", F.array(F.lit("0")))
)

# Blood-pressure averages go on the left so their columns lead the result
patient_calc_df = (
    blood_pressure_avg_df
    .join(patient_calc_df, blood_pressure_avg_df.Subject == patient_calc_df.PatientId)
    .drop("Subject")
)

# Asthma flag per patient (again on the left, making it the first column)
patient_calc_df = (
    encounterDataFrame_asthma
    .join(patient_calc_df, encounterDataFrame_asthma.subject == patient_calc_df.PatientId)
    .dropDuplicates()
    .drop(col("subject"))
)

# Per-patient laboratory / vitals averages
triglycerides_df = getAvgForLoinc("%2571-8%", "Triglycerides", observationDataFrame_loinc)

# EGFR accepts two codes, so it is aggregated inline; note that this average is
# rounded to a whole number, unlike the three-decimal averages from getAvgForLoinc.
egfrLoincCode = ["88294-4", "33914-3"]
agg_column = "Estimated Glomerular Filtration Rate"
egfr_df = (
    observationDataFrame_loinc
    .select(col("Subject"), col("LoincCode"), col("ValueQuantity"))
    .filter(col("LoincCode").isin(egfrLoincCode))
    .withColumn(agg_column, col("ValueQuantity"))
    .groupBy("Subject")
    .agg(F.round(F.avg(col(agg_column))).alias(agg_column))
)

ldl_df = getAvgForLoinc("%18262-6%", "Low Density Lipoprotein", observationDataFrame_loinc)
hdl_df = getAvgForLoinc("%2085-9%", "High Density Lipoprotein Cholesterol", observationDataFrame_loinc)
height_df = getAvgForLoinc("%8302-2%", "Body Height", observationDataFrame_loinc)
bmi_df = getAvgForLoinc("%39156-5%", "BMI", observationDataFrame_loinc)
cholesterol_df = getAvgForLoinc("%2093-3%", "Total Cholesterol", observationDataFrame_loinc)

# Join order fixes the column order of the final frame; only the LDL step
# de-duplicates rows.
for metric_df, dedupe_rows in [
    (triglycerides_df, False),
    (egfr_df, False),
    (ldl_df, True),
    (hdl_df, False),
    (height_df, False),
    (bmi_df, False),
    (cholesterol_df, False),
]:
    patient_calc_df = _join_patient_metric(patient_calc_df, metric_df, dedupe=dedupe_rows)


In [44]:
# Remove identifier / placeholder columns, then keep only fully populated rows.
# "Diagnosed Date" is not created by any cell shown in this notebook.
non_feature_columns = ["PatientId", "Disease", "PostalCode", "birthdate", "Diagnosed Date"]
asthma_dataset = patient_calc_df.drop(*non_feature_columns).na.drop()

In [45]:
# Column names of the modelling dataset, in order.
asthma_dataset.columns

['Asthma',
 'Diastolic BP',
 'Systolic BP',
 'BP',
 'Body Weight',
 'Marital Status',
 'Gender',
 'Age',
 'Triglycerides',
 'Estimated Glomerular Filtration Rate',
 'Low Density Lipoprotein',
 'High Density Lipoprotein Cholesterol',
 'Body Height',
 'BMI',
 'Total Cholesterol']

In [19]:
# Cache the modelling dataset: it is displayed, counted and split in the cells that follow.
asthma_dataset.cache()

DataFrame[Asthma: int, Diastolic BP: decimal(38,3), Systolic BP: decimal(38,3), BP: decimal(38,3), Body Weight: decimal(38,3), Marital Status: string, birthdate: date, Gender: string, Age: double, Triglycerides: decimal(38,3), Estimated Glomerular Filtration Rate: decimal(38,0), Low Density Lipoprotein: decimal(38,3), High Density Lipoprotein Cholesterol: decimal(38,3), Body Height: decimal(38,3), BMI: decimal(38,3), Total Cholesterol: decimal(38,3)]

In [20]:
# Preview the first 20 rows with long cell values truncated.
asthma_dataset.show(n=20, truncate=True)

+------+------------+-----------+-------+-----------+--------------+----------+------+-----+-------------+------------------------------------+-----------------------+------------------------------------+-----------+------+-----------------+
|Asthma|Diastolic BP|Systolic BP|     BP|Body Weight|Marital Status| birthdate|Gender|  Age|Triglycerides|Estimated Glomerular Filtration Rate|Low Density Lipoprotein|High Density Lipoprotein Cholesterol|Body Height|   BMI|Total Cholesterol|
+------+------------+-----------+-------+-----------+--------------+----------+------+-----+-------------+------------------------------------+-----------------------+------------------------------------+-----------+------+-----------------+
|     1|      73.693|    119.458| 89.000|     77.300|             M|1943-07-09|  MALE| 77.6|      126.036|                                  32|                112.191|                              67.250|    166.300|27.950|          201.869|
|     0|      80.417|    112.417

In [21]:
# Number of complete patient rows available for modelling.
asthma_dataset.count()

125

In [23]:
# Number of patients flagged as asthmatic (the positive class).
asthma_dataset.where(col("Asthma") == 1).count()

5

In [46]:
# 80/20 train/test split with a fixed seed.
trainDF, testDF = asthma_dataset.randomSplit(weights=[.8, .2], seed=42)
train_rows, test_rows = trainDF.count(), testDF.count()
print(f"""There are {train_rows} rows in the training set and {test_rows} in the test set""")

There are 103 rows in the training set and 22 in the test set


In [47]:
# Partition the columns by Spark dtype: string columns are categorical, every
# other column except the "Asthma" label is treated as numeric.
categoricalCols, numericCols = [], []
for field, dataType in trainDF.dtypes:
    if dataType == "string":
        categoricalCols.append(field)
    elif field != "Asthma":
        numericCols.append(field)

# Each categorical column is indexed and then one-hot encoded.
indexOutputCols = [f"{name}Index" for name in categoricalCols]
oheOutputCols = [f"{name}OHE" for name in categoricalCols]

stringIndexer = StringIndexer(
    inputCols=categoricalCols,
    outputCols=indexOutputCols,
    handleInvalid="skip",
)
oheEncoder = OneHotEncoder(inputCols=indexOutputCols, outputCols=oheOutputCols)

# Feature vector layout: encoded categoricals first, numeric columns after.
assemblerInputs = oheOutputCols + numericCols
vecAssembler = VectorAssembler(inputCols=assemblerInputs, outputCol="features")

In [48]:
# NOTE(review): "Asthma" is a 0/1 label but is fitted with ordinary linear
# regression, so predictions are unbounded scores rather than probabilities.
# Confirm this is intended before relying on the output.
lr = LinearRegression(featuresCol="features", labelCol="Asthma")

# index -> one-hot encode -> assemble -> regress
pipeline_stages = [stringIndexer, oheEncoder, vecAssembler, lr]
pipeline = Pipeline(stages=pipeline_stages)

pipelineModel = pipeline.fit(trainDF)

In [53]:
# Score the held-out rows and show features, label and prediction side by side.
predDF = pipelineModel.transform(testDF)
predDF.select(["features", "Asthma", "prediction"]).show(n=22)

+--------------------+------+--------------------+
|            features|Asthma|          prediction|
+--------------------+------+--------------------+
|[1.0,0.0,80.417,1...|     0|0.029171343606225975|
|[1.0,1.0,80.429,1...|     0|-0.04980751139079653|
|[1.0,1.0,76.778,1...|     0| 0.11803101024322626|
|[1.0,1.0,71.433,1...|     0| 0.10857709677112193|
|[1.0,1.0,79.938,1...|     0| 0.07748964013421866|
|[1.0,1.0,78.688,1...|     0| 0.15204345040755762|
|[0.0,0.0,87.5,120...|     0|-0.21552369335368415|
|[1.0,1.0,71.05,10...|     0| 0.13885747464551357|
|[0.0,0.0,80.727,1...|     0| -0.1648409311161081|
|[1.0,0.0,100.034,...|     0|  0.0767006289489165|
|[1.0,0.0,57.444,8...|     0|-0.00500094701028...|
|[1.0,0.0,90.23,13...|     0| 0.12223990978552046|
|[0.0,1.0,78.652,1...|     0| 0.03548234089319435|
|[1.0,0.0,75.787,1...|     0|  0.0683026717273798|
|[1.0,0.0,74.667,1...|     0| 0.14782819626331922|
|[1.0,1.0,82.417,1...|     0| 0.10487360760423137|
|[1.0,0.0,77.977,1...|     0| 0

In [54]:
# Show which columns were classified as categorical (string-typed) above.
categoricalCols

['Marital Status', 'Gender']