In [217]:
import findspark
findspark.init()

In [218]:
import pathlib

# Can remove this import since it's no longer used
# import pandas as pd
import pyspark

from pyspark.sql import functions as sf, types, Window


In [219]:
spark = pyspark.sql.SparkSession.builder.master("local").getOrCreate()

In [220]:
CODE_GROUPS_SCHEMA = types.StructType([
    types.StructField("ICD_CODE",types.StringType(),nullable=False),
    types.StructField("GROUP",types.StringType(),nullable=False),
])

DEMOGRAPHICS_SCHEMA = types.StructType([
    types.StructField("CLINIC",types.StringType(),nullable=True),
    types.StructField("MRN",types.StringType(),nullable=False),
    types.StructField("FIRST_NAME",types.StringType(),nullable=True),
    types.StructField("LAST_NAME",types.StringType(),nullable=True),
    types.StructField("DATE_OF_BIRTH",types.DateType(),nullable=False),
])

ENCOUNTERS_SCHEMA = types.StructType([
    types.StructField("DATE",types.DateType(),nullable=True),
    types.StructField("ENC_ID",types.StringType(),nullable=True),
    types.StructField("MRN",types.StringType(),nullable=False),
    types.StructField("ICD_CODE",types.StringType(),nullable=True),
])

In [221]:
def csv_to_parquet(path):
    # Spark requires path as a str
    path = str(path)

    filename = path.split("/")[-1].split(".")[0]

    if filename.casefold() == "code_groups":
        schema = CODE_GROUPS_SCHEMA
    elif filename.casefold() == "demographics":
        schema = DEMOGRAPHICS_SCHEMA
    elif filename.casefold() == "encounters":
        schema = ENCOUNTERS_SCHEMA

    df = (
        spark.read
        .option("delimiter", "|")
        .option("header", "true")
        .schema(schema)
        .csv(path)
    )
            
    parquet_path = path.split(".")[0]+".parquet"

    df.write.mode("overwrite").parquet(parquet_path)

In [222]:
def read_csv(path):
    path = str(path)

    filename = path.split("/")[-1].split(".")[0]

    if filename.casefold() == "code_groups":
        schema = CODE_GROUPS_SCHEMA
    elif filename.casefold() == "demographics":
        schema = DEMOGRAPHICS_SCHEMA
    elif filename.casefold() == "encounters":
        schema = ENCOUNTERS_SCHEMA

    df = (
        spark.read
        .option("delimiter", "|")
        .option("header", "true")
        .schema(schema)
        .csv(path)
    )

    return df


In [223]:

# This can be rewritten to use Spark-native functions.
# This removes a dependency we don't use otherwise
def read_parquet(path):
    path = str(path)

    df = spark.read.parquet(path)
    
    return df

In [224]:
def rename_columns(sdf, name_map):
    for from_col, to_col in name_map.items():
        sdf = sdf.withColumnRenamed(from_col, to_col)
    return sdf

In [225]:
def transform_encounters(enc_sdf):
    sdf2 = (
        enc_sdf
        .withColumn('ROW_ID', sf.monotonically_increasing_id())
        .withColumn('PRIORITY', sf.row_number().over(Window.partitionBy('ENC_ID').orderBy('ROW_ID')))
        .drop('ROW_ID')
        .groupBy(['ENC_ID', 'DATE', 'MRN'])
        .pivot('PRIORITY')
        .agg(sf.first('ICD_CODE'))
    )
    return rename_columns(sdf2, {i: f'ICD_CODE_{i}' for i in sdf2.columns[3:]})

In [226]:
def age_bucket(age):
    if age > 100:
        return 100
    elif age > 90:
        return 90
    elif age > 80:
        return 90
    elif age > 70:
        return 70
    elif age > 60:
        return 60
    elif age > 50:
        return 50
    elif age > 40:
        return 40
    else:
        # ignore all patients under 40
        return 0

In [227]:
# change this to data_2, data_3 or data_4 to use a different dataset
data_folder = pathlib.Path('data_4')

code_groups = pathlib.Path('code_groups.csv')
demographics = data_folder.joinpath('demographics.csv')
encounters = data_folder.joinpath('encounters.csv')

#for csv in [code_groups, demographics, encounters]:
#    csv_to_parquet(csv)

code_groups = pathlib.Path('code_groups.parquet')
demographics = data_folder.joinpath('demographics.parquet')
encounters = data_folder.joinpath('encounters.parquet')

In [228]:
code_groups_sdf = read_parquet(code_groups)
demo_sdf = read_parquet(demographics)
enc_sdf = read_parquet(encounters)

In [229]:
# Data quality checks

def check_dupes(df):
    return (
        df.groupBy(df.columns)
        .count()
        .where(sf.col("count") > 1)
        .select(sf.coalesce(sf.sum("count"), sf.lit(0)))
        .withColumnRenamed("coalesce(sum(count), 0)", f"Duplicate Count")
        .show()
    )

def check_nulls(df):
    print("Null Counts")
    return df.select([sf.count(sf.when(sf.col(c).isNull(), c)).alias(c) for c in df.columns]
   ).show()

def check_all(df):
    check_dupes(df)
    check_nulls(df)
    #df.describe().show()

print("DEMO")
check_all(demo_sdf)
print("ENC")
check_all(enc_sdf)

DEMO


                                                                                

+---------------+
|Duplicate Count|
+---------------+
|              0|
+---------------+

Null Counts


                                                                                

+------+---+----------+---------+-------------+
|CLINIC|MRN|FIRST_NAME|LAST_NAME|DATE_OF_BIRTH|
+------+---+----------+---------+-------------+
|     0|  0|         0|        0|            0|
+------+---+----------+---------+-------------+

ENC


22/04/18 21:25:00 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/04/18 21:25:00 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/04/18 21:25:11 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/04/18 21:25:11 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/04/18 21:25:21 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/04/18 21:25:21 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
                                                                                

+---------------+
|Duplicate Count|
+---------------+
|              0|
+---------------+

Null Counts




+----+------+---+--------+
|DATE|ENC_ID|MRN|ICD_CODE|
+----+------+---+--------+
|   0|     0|  0|       0|
+----+------+---+--------+



                                                                                

In [230]:
enc_sdf = transform_encounters(enc_sdf)

                                                                                

In [231]:
demo_sdf = demo_sdf.withColumn("FULL_MRN", sf.concat("CLINIC", "MRN"))

In [232]:
enc_sdf_added_cols = (
    enc_sdf
    .join(demo_sdf, enc_sdf.MRN == demo_sdf.FULL_MRN)
    .withColumn('AGE', sf.datediff('DATE', 'DATE_OF_BIRTH') / 365)
    .withColumn('AGE_BUCKET', sf.udf(age_bucket)(sf.col('AGE')))
    .withColumn('ICD_CODE', sf.expr("stack(4, ICD_CODE_1, ICD_CODE_2, ICD_CODE_3, ICD_CODE_4)"))
)
results = (
    enc_sdf_added_cols
    .join(code_groups_sdf, enc_sdf_added_cols.ICD_CODE==code_groups_sdf.ICD_CODE)
    .groupBy('AGE_BUCKET', 'GROUP').count()
    .withColumn(
        '_row',
        sf.row_number().over(Window().partitionBy(['AGE_BUCKET']).orderBy(sf.desc('count'))))
    .filter(sf.col('_row') == 1)
    .drop('_row')
    .filter(sf.col('AGE_BUCKET') != 0)
    .orderBy('AGE_BUCKET')
)

In [233]:
results.show()



+----------+-----+------+
|AGE_BUCKET|GROUP| count|
+----------+-----+------+
|       100|   10|149125|
|        40|   10|111182|
|        50|   10|111101|
|        60|   10|111529|
|        70|   10|111641|
|        90|   10|221787|
+----------+-----+------+



                                                                                