In [38]:
import findspark
findspark.init()

In [41]:
import pathlib

# Can remove this import since it's no longer used
# import pandas as pd
import pyspark

from pyspark.sql import functions as sf
from pyspark.sql import Window

In [42]:
spark = pyspark.sql.SparkSession.builder.master("local").getOrCreate()

In [118]:

# This can be rewritten to use Spark-native functions.
# This removes a dependency we don't use otherwise
# and avoids loading the whole df into memory
def read_csv(path, date_cols=None):
    # Spark requires path as a str
    path = str(path)

    df = (
        spark.read
        .option("delimiter", "|")
        .option("header", "true")
        .csv(path)
    )

    if date_cols:
        for date_col in date_cols:
            df.withColumn(date_col, sf.col(date_col).cast("Date"))
            
    return df

In [119]:
def rename_columns(sdf, name_map):
    for from_col, to_col in name_map.items():
        sdf = sdf.withColumnRenamed(from_col, to_col)
    return sdf

In [120]:
def transform_encounters(enc_sdf):
    sdf2 = (
        enc_sdf
        .withColumn('ROW_ID', sf.monotonically_increasing_id())
        .withColumn('PRIORITY', sf.row_number().over(Window.partitionBy('ENC_ID').orderBy('ROW_ID')))
        .drop('ROW_ID')
        .groupBy(['ENC_ID', 'DATE', 'MRN'])
        .pivot('PRIORITY')
        .agg(sf.first('ICD_CODE'))
    )
    return rename_columns(sdf2, {i: f'ICD_CODE_{i}' for i in sdf2.columns[3:]})

In [121]:
def age_bucket(age):
    if age > 100:
        return 100
    elif age > 90:
        return 90
    elif age > 80:
        return 90
    elif age > 70:
        return 70
    elif age > 60:
        return 60
    elif age > 50:
        return 50
    elif age > 40:
        return 40
    else:
        # ignore all patients under 40
        return 0

In [122]:
# change this to data_2, data_3 or data_4 to use a different dataset
data_folder = pathlib.Path('data_1')

code_groups = pathlib.Path('code_groups.csv')
demographics = data_folder.joinpath('demographics.csv')
encounters = data_folder.joinpath('encounters.csv')

In [123]:
code_groups_sdf = read_csv(code_groups)
demo_sdf = read_csv(demographics, date_cols=['DATE_OF_BIRTH'])
enc_sdf = read_csv(encounters)

In [124]:
enc_sdf = transform_encounters(enc_sdf)

                                                                                

In [125]:
demo_sdf = demo_sdf.withColumn("FULL_MRN", sf.concat("CLINIC", "MRN"))

In [126]:
results = (
    enc_sdf
    .join(demo_sdf, enc_sdf.MRN == demo_sdf.FULL_MRN)
    .withColumn('AGE', sf.datediff('DATE', 'DATE_OF_BIRTH') / 365)
    .withColumn('AGE_BUCKET', sf.udf(age_bucket)(sf.col('AGE')))
    .join(code_groups_sdf, enc_sdf.ICD_CODE_1==code_groups_sdf.ICD_CODE)
    .groupBy('AGE_BUCKET', 'GROUP').count()
    .withColumn(
        '_row',
        sf.row_number().over(Window().partitionBy(['AGE_BUCKET']).orderBy(sf.desc('count'))))
    .filter(sf.col('_row') == 1)
    .drop('_row')
    .filter(sf.col('AGE_BUCKET') != 0)
    .orderBy('AGE_BUCKET')
)

In [127]:
results.show()

                                                                                

+----------+-----+-----+
|AGE_BUCKET|GROUP|count|
+----------+-----+-----+
|       100|   10|    7|
|        40|    3|   10|
|        50|    3|    1|
|        60|    4|    8|
|        70|    7|    4|
|        90|   10|   18|
+----------+-----+-----+

