In [506]:
import findspark
findspark.init()

In [507]:
from enum import Enum
import pathlib

from pyspark.sql import (
                        DataFrame,
                        functions as sf,
                        SparkSession,
                        types,
                        Window
)


In [508]:
class OutputMode(Enum):
    """How much the notebook prints besides the final result."""
    results_only = 1
    stats = 2
    verbose = 3


class ReadMode(Enum):
    """Which on-disk format the input tables are loaded from."""
    csv_only = 1
    parquet_only = 2
    csv_to_parquet = 3


# Notebook-wide settings. Editing any of these alters
# which files are read, how long the run takes, and
# how much gets printed.
DATA_FOLDER = "data_test"
OUTPUT_MODE = OutputMode.results_only
READ_MODE = ReadMode.parquet_only


In [509]:
# Spark session used by every cell below, running against a "local" master.
# getOrCreate() hands back an already-active session if one exists.
spark = SparkSession.builder.master("local").getOrCreate()

In [510]:
# Only three small tables are read, so their schemas are
# written out by hand. Explicit schemas give dependable
# column types and spare Spark the extra pass over each
# file that schema inference would need.

CODE_GROUPS_SCHEMA = types.StructType([
    types.StructField(column, dtype, nullable=True)
    for column, dtype in [
        ("ICD_CODE", types.StringType()),
        ("GROUP", types.StringType()),
    ]
])

DEMOGRAPHICS_SCHEMA = types.StructType([
    types.StructField(column, dtype, nullable=True)
    for column, dtype in [
        ("CLINIC", types.StringType()),
        ("MRN", types.StringType()),
        ("FIRST_NAME", types.StringType()),
        ("LAST_NAME", types.StringType()),
        ("DATE_OF_BIRTH", types.DateType()),
    ]
])

ENCOUNTERS_SCHEMA = types.StructType([
    types.StructField(column, dtype, nullable=True)
    for column, dtype in [
        ("DATE", types.DateType()),
        ("ENC_ID", types.StringType()),
        ("MRN", types.StringType()),
        ("ICD_CODE", types.StringType()),
    ]
])

In [511]:
def csv_to_parquet(path: str) -> None:
    """
    Export csv file as a parquet file
    with the appropriate schema.

    The schema is chosen from the file name (case-insensitive, text
    before its first dot): ``code_groups``, ``demographics`` or
    ``encounters``. The csv is read as pipe-delimited with a header row.
    The parquet output is written beside the csv, with only the trailing
    extension swapped for ``.parquet``; existing output is overwritten.

    :param path: File path to the csv
    :raises ValueError: If the file name does not map to a known schema
    """

    schemas = {
        "code_groups": CODE_GROUPS_SCHEMA,
        "demographics": DEMOGRAPHICS_SCHEMA,
        "encounters": ENCOUNTERS_SCHEMA,
    }

    # PurePath isolates the last path component using the platform's
    # separators, so dots in directory names cannot leak into the file name.
    csv_path = pathlib.PurePath(path)
    filename = csv_path.name.split(".")[0].casefold()

    schema = schemas.get(filename)
    if schema is None:
        raise ValueError(
            f"{path} does not match a known schema; "
            f"expected one of {sorted(schemas)}"
        )

    df = (
        spark.read
        .option("delimiter", "|")
        .option("header", "true")
        .schema(schema)
        .csv(path)
    )

    # Strip only the final extension from the original string, leaving
    # any directory part (including ones containing dots) untouched.
    suffix = csv_path.suffix
    base = path[:-len(suffix)] if suffix and path.endswith(suffix) else path
    parquet_path = base + ".parquet"

    df.write.mode("overwrite").parquet(parquet_path)

In [512]:
def read_file(path: str) -> DataFrame:
    """
    Read file and convert to a Spark DataFrame.
    Supported filetypes are csv & parquet.

    The schema is chosen from the file name (case-insensitive, text
    before its first dot): ``code_groups``, ``demographics`` or
    ``encounters``. The file type is taken from the final extension.
    csv files are read as pipe-delimited with a header row.

    :param path: File path of the file to process
    :return: DataFrame with the schema matching the file name
    :raises TypeError: If the extension is neither csv nor parquet
    :raises ValueError: If the file name does not map to a known schema
    """

    schemas = {
        "code_groups": CODE_GROUPS_SCHEMA,
        "demographics": DEMOGRAPHICS_SCHEMA,
        "encounters": ENCOUNTERS_SCHEMA,
    }

    # PurePath isolates the last path component using the platform's
    # separators and copes with names that have zero or several dots.
    file_path = pathlib.PurePath(path)
    filename = file_path.name.split(".")[0].casefold()
    extension = file_path.suffix[1:].casefold()

    # Reject unsupported file types first, as callers may rely on TypeError
    if extension not in ("csv", "parquet"):
        raise TypeError(f"{path} is not a csv or parquet file!")

    schema = schemas.get(filename)
    if schema is None:
        raise ValueError(
            f"{path} does not match a known schema; "
            f"expected one of {sorted(schemas)}"
        )

    if extension == "csv":
        return (
            spark.read
            .option("delimiter", "|")
            .option("header", "true")
            .schema(schema)
            .csv(path)
        )

    return (
        spark.read
        .schema(schema)
        .parquet(path)
    )


In [513]:
def rename_columns(sdf: DataFrame, name_map: dict) -> DataFrame:
    """
    Rename several columns of a DataFrame in one call.

    :param sdf: DataFrame whose columns should be renamed
    :param name_map: Mapping of current column name -> new column name,
                     applied in the mapping's iteration order
    :return: DataFrame with every rename applied
    """
    renamed = sdf
    for old_name in name_map:
        renamed = renamed.withColumnRenamed(old_name, name_map[old_name])
    return renamed

In [514]:
def age_bucket(age: int) -> int:
    """
    Map an age in years to the lower bound of its decade bucket.

    Buckets are 40, 50, ..., 90, with everything from 100 upwards
    collapsed into 100. Each bucket includes its lower bound, so an
    age of exactly 50 lands in bucket 50 and exactly 40 in bucket 40.

    :param age: Age in years; fractional values are accepted
    :return: Bucket lower bound, or 0 for ages under 40 and for
             missing (None / NaN) ages
    """
    # `not age >= 40` rather than `age < 40` so that NaN is also ignored
    if age is None or not age >= 40:
        # ignore all patients under 40
        return 0
    if age >= 100:
        return 100
    # Round down to the start of the decade
    return int(age // 10) * 10

In [515]:
# Edit the config constants above rather than anything in this cell
data_folder = pathlib.Path(DATA_FOLDER)

# spark.read only accepts str paths, so every Path is
# turned into a str here, once, rather than inside each
# helper. Note that code_groups is resolved relative to the
# working directory, while the other two live in data_folder.
if READ_MODE in (ReadMode.csv_only, ReadMode.csv_to_parquet):
    code_groups = str(pathlib.Path('code_groups.csv'))
    demographics = str(data_folder / 'demographics.csv')
    encounters = str(data_folder / 'encounters.csv')

# Convert the csv inputs before switching over to the parquet paths
if READ_MODE is ReadMode.csv_to_parquet:
    for csv in (code_groups, demographics, encounters):
        csv_to_parquet(csv)

if READ_MODE in (ReadMode.parquet_only, ReadMode.csv_to_parquet):
    code_groups = str(pathlib.Path('code_groups.parquet'))
    demographics = str(data_folder / 'demographics.parquet')
    encounters = str(data_folder / 'encounters.parquet')

code_groups_sdf = read_file(code_groups)
demo_sdf = read_file(demographics)
enc_sdf = read_file(encounters)

In [516]:
# Data quality checks

def check_dupes(df: DataFrame) -> None:
    """
    Print how many rows of ``df`` belong to a group of fully identical rows.

    Rows are grouped on every column; the sizes of all groups with more
    than one row are summed and shown as a one-row table. Shows 0 when
    there are no duplicates.

    :param df: DataFrame to inspect
    """
    (
        df.groupBy(df.columns)
        .count()
        .where(sf.col("count") > 1)
        # Alias directly instead of renaming Spark's auto-generated
        # column name, which would silently do nothing if it differed.
        .select(sf.coalesce(sf.sum("count"), sf.lit(0)).alias("Duplicate Count"))
        .show()
    )

def check_nulls(df: DataFrame) -> None:
    """
    Print, for each column of ``df``, the number of rows where it is null.

    :param df: DataFrame to inspect
    """
    print("Null Counts")
    # when() without otherwise() is null for non-matching rows,
    # and count() skips nulls, so only the null rows are counted.
    null_counters = [
        sf.count(sf.when(sf.col(name).isNull(), name)).alias(name)
        for name in df.columns
    ]
    df.select(null_counters).show()

def run_single_df_checks(df: DataFrame) -> None:
    """
    Run every data quality check on one DataFrame and print the results:
    duplicate count, per-column null counts, then describe() statistics.

    :param df: DataFrame to inspect
    """
    for check in (check_dupes, check_nulls):
        check(df)
    summary = df.describe()
    summary.show()

def run_all_checks():
    """
    Run the data quality checks on the three notebook-level DataFrames
    (code groups, demographics, encounters), printing a banner before each.
    """
    labelled_frames = (
        ("-----CODE GROUPS-----", code_groups_sdf),
        ("-----DEMO-----", demo_sdf),
        ("-----ENC-----", enc_sdf),
    )
    for banner, frame in labelled_frames:
        print(banner)
        run_single_df_checks(frame)

# The checks are skipped when only the final results are wanted
if OUTPUT_MODE is not OutputMode.results_only:
    run_all_checks()

In [517]:
ResourceType = Enum("ResourceType", "code_groups demographics encounters")

def clean_data(df: DataFrame, type: ResourceType) -> DataFrame:
    """
    Drop rows that fail the validity checks for the given resource,
    then remove fully duplicated rows.

    Checks applied per resource:

    * code_groups: ICD_CODE and GROUP present, ICD_CODE is a well-formed
      ICD-9 or ICD-10 code.
    * demographics: DATE_OF_BIRTH and MRN present, DATE_OF_BIRTH looks
      like a date and is less than 120 * 365 days in the past.
    * encounters: DATE, MRN and ICD_CODE present, DATE looks like a date
      and is less than 15 * 365 days in the past, ICD_CODE is a
      well-formed ICD-9 or ICD-10 code.

    :param df: DataFrame to clean
    :param type: Which resource ``df`` holds. Any other value skips the
                 filtering and only de-duplicates.
    :return: The filtered, de-duplicated DataFrame
    """
    # Pulled from https://www.johndcook.com/blog/2019/05/05/regex_icd_codes/
    # Raw strings, so "\d" and "\." reach the regex engine unchanged rather
    # than being parsed as (invalid) Python escape sequences.
    N = r"\d{3}\.?\d{0,2}"
    E = r"E\d{3}\.?\d?"
    V = r"V\d{2}\.?\d{0,2}"
    # rlike() succeeds on a partial match, so the patterns are anchored;
    # otherwise any value merely *containing* a code-like substring passes.
    icd9_regex = "^(" + "|".join([N, E, V]) + ")$"
    icd10_regex = r"^[A-TV-Z][0-9][0-9AB]\.?[0-9A-TV-Z]{0,4}$"

    # Allows weird dates like 0000-99-99 but good enough for our purposes
    date_regex = "[0-9]{4}-[0-9]{2}-[0-9]{2}"
    #name_regex = r"[A-Za-z]+"

    valid_icd_code = (
        sf.col("ICD_CODE").rlike(icd9_regex) | sf.col("ICD_CODE").rlike(icd10_regex)
    )

    if type == ResourceType.code_groups:
        df = df.where(
            sf.col("ICD_CODE").isNotNull()
            & sf.col("GROUP").isNotNull()
            & valid_icd_code
        )
    elif type == ResourceType.demographics:
        df = df.where(
            sf.col("DATE_OF_BIRTH").isNotNull()
            & sf.col("MRN").isNotNull()
            & sf.col("DATE_OF_BIRTH").rlike(date_regex)
            & (sf.col("DATE_OF_BIRTH") > sf.date_sub(sf.current_date(), (120 * 365)))
            #& sf.col("FIRST_NAME").rlike(name_regex)
            #& sf.col("LAST_NAME").rlike(name_regex)
        )
    elif type == ResourceType.encounters:
        df = df.where(
            sf.col("DATE").isNotNull()
            & sf.col("MRN").isNotNull()
            & sf.col("ICD_CODE").isNotNull()
            & sf.col("DATE").rlike(date_regex)
            & (sf.col("DATE") > sf.date_sub(sf.current_date(), (15 * 365)))
            & valid_icd_code
        )

    df = df.dropDuplicates()

    return df


In [518]:
# Replace each loaded DataFrame with its cleaned version
# (invalid rows filtered out, exact duplicates dropped).
code_groups_sdf = clean_data(code_groups_sdf, ResourceType.code_groups)
demo_sdf = clean_data(demo_sdf, ResourceType.demographics)
enc_sdf = clean_data(enc_sdf, ResourceType.encounters)

# Re-run the quality checks on the cleaned data unless only results are wanted
if OUTPUT_MODE != OutputMode.results_only:
    run_all_checks()

In [519]:
def transform_encounters(enc_sdf):
    """
    Reshape encounters from one row per diagnosis code into one row per
    (ENC_ID, DATE, MRN), with the codes spread over ICD_CODE_<n> columns.

    Codes are numbered within each ENC_ID by a generated row id, and
    that number becomes the column suffix.

    :param enc_sdf: Encounters DataFrame with ENC_ID, DATE, MRN, ICD_CODE
    :return: Pivoted DataFrame: ENC_ID, DATE, MRN, ICD_CODE_1, ICD_CODE_2, ...
    """
    # Rank each code inside its encounter by the generated row id
    per_encounter = Window.partitionBy('ENC_ID').orderBy('ROW_ID')
    numbered = (
        enc_sdf
        .withColumn('ROW_ID', sf.monotonically_increasing_id())
        .withColumn('PRIORITY', sf.row_number().over(per_encounter))
        .drop('ROW_ID')
    )

    pivoted = (
        numbered
        .groupBy(['ENC_ID', 'DATE', 'MRN'])
        .pivot('PRIORITY')
        .agg(sf.first('ICD_CODE'))
    )

    # Everything after the three grouping columns is a pivoted priority
    priority_columns = pivoted.columns[3:]
    return rename_columns(
        pivoted,
        {name: f'ICD_CODE_{name}' for name in priority_columns},
    )

In [520]:
# Pivot the cleaned encounters so each (ENC_ID, DATE, MRN)
# becomes a single row with one ICD_CODE_<n> column per code.
enc_sdf_transformed = transform_encounters(enc_sdf)

# In verbose mode, show the encounters before and after the reshape
if OUTPUT_MODE == OutputMode.verbose:
    enc_sdf.show()
    enc_sdf_transformed.show()

                                                                                

In [521]:
# FULL_MRN is CLINIC immediately followed by MRN. concat() yields
# null whenever either input is null.
demo_sdf_mrn = demo_sdf.withColumn(
    "FULL_MRN",
    sf.concat(sf.col("CLINIC"), sf.col("MRN")),
)

# In verbose mode, show demographics before and after adding FULL_MRN
if OUTPUT_MODE is OutputMode.verbose:
    demo_sdf.show()
    demo_sdf_mrn.show()

In [522]:
# Attach demographics to each encounter, derive the patient's age and
# age bucket at the encounter date, then unpivot the ICD_CODE_<n>
# columns back to one row per code in a single ICD_CODE column.
enc_sdf_added_cols = (
    enc_sdf_transformed
    # Encounter MRNs are matched against the clinic-prefixed FULL_MRN
    .join(demo_sdf_mrn, enc_sdf_transformed.MRN == demo_sdf_mrn.FULL_MRN)
    # Approximate age in years (leap days are ignored)
    .withColumn('AGE', sf.datediff('DATE', 'DATE_OF_BIRTH') / 365)
    # Declare an integer return type: udf() defaults to string, which
    # would make AGE_BUCKET sort lexicographically ("100" before "40").
    .withColumn('AGE_BUCKET', sf.udf(age_bucket, types.IntegerType())(sf.col('AGE')))
    # Unpivot whatever ICD_CODE_<n> columns the pivot produced instead of
    # assuming exactly four: fewer would fail on a missing column, more
    # would silently drop codes. Null slots are kept, as with stack().
    .withColumn(
        'ICD_CODE',
        sf.explode(sf.array(*[
            sf.col(name)
            for name in enc_sdf_transformed.columns
            if name.startswith('ICD_CODE_')
        ])),
    )
)

# Attach the code group to every encounter code. join() defaults to an
# inner join, so rows whose ICD_CODE has no match in code_groups_sdf
# (including null ICD_CODE values) are dropped.
joined = (
    enc_sdf_added_cols
    .join(code_groups_sdf, enc_sdf_added_cols.ICD_CODE==code_groups_sdf.ICD_CODE)
)

# In verbose mode, show the fully joined rows
if OUTPUT_MODE == OutputMode.verbose:
    joined.show()

# For each age bucket, find the most frequent code group(s):
#   1. count rows per (AGE_BUCKET, GROUP)
#   2. gather the groups that share a count, so ties stay together
#   3. keep the highest count per bucket
#   4. drop bucket 0 (patients under 40) and order by bucket
results = (
    joined
    .groupBy('AGE_BUCKET', 'GROUP').count()
    .groupBy('AGE_BUCKET', 'count')
    # Alias directly rather than renaming Spark's auto-generated column
    # name. collect_set() has no guaranteed order, so sort the groups
    # to make the displayed output reproducible between runs.
    .agg(sf.sort_array(sf.collect_set('GROUP')).alias('GROUPS'))
    .withColumn(
        '_row',
        sf.row_number().over(
            Window.partitionBy('AGE_BUCKET').orderBy(sf.desc('count'))
        ),
    )
    .filter(sf.col('_row') == 1)
    .drop('_row')
    .filter(sf.col('AGE_BUCKET') != 0)
    .orderBy('AGE_BUCKET')
)

In [523]:
# Final table: per age bucket, the top count and the group(s) that reach it
results.show()



+----------+-----+---------+
|AGE_BUCKET|count|   GROUPS|
+----------+-----+---------+
|        50|    1|[1, 2, 7]|
|        70|    1|   [2, 9]|
+----------+-----+---------+



                                                                                