In [117]:
import findspark
findspark.init()

In [118]:
from enum import Enum
import pathlib

from pyspark.sql import (
                        DataFrame,
                        functions as sf,
                        SparkSession,
                        types,
                        Window
)


In [119]:
OutputMode = Enum("OutputMode", "results_only stats verbose")
ReadMode = Enum("ReadMode", "csv_only parquet_only csv_to_parquet")

# Config constants, changing these
# will change the flow, runtime, and
# output of the notebook
DATA_FOLDER = "data_test"
OUTPUT_MODE = OutputMode.results_only
READ_MODE = ReadMode.parquet_only


In [120]:
spark = SparkSession.builder.master("local").getOrCreate()

In [121]:
# Limited number of schemas with few columns,
# so let's define them here. Makes data types more
# reliable and speeds up DF creation, since we skip
# initial pass through the file to infer schema.

CODE_GROUPS_SCHEMA = types.StructType([
    types.StructField("ICD_CODE",types.StringType(),nullable=True),
    types.StructField("GROUP",types.StringType(),nullable=True),
])

DEMOGRAPHICS_SCHEMA = types.StructType([
    types.StructField("CLINIC",types.StringType(),nullable=True),
    types.StructField("MRN",types.StringType(),nullable=True),
    types.StructField("FIRST_NAME",types.StringType(),nullable=True),
    types.StructField("LAST_NAME",types.StringType(),nullable=True),
    types.StructField("DATE_OF_BIRTH",types.DateType(),nullable=True),
])

ENCOUNTERS_SCHEMA = types.StructType([
    types.StructField("DATE",types.DateType(),nullable=True),
    types.StructField("ENC_ID",types.StringType(),nullable=True),
    types.StructField("MRN",types.StringType(),nullable=True),
    types.StructField("ICD_CODE",types.StringType(),nullable=True),
])

In [122]:
def csv_to_parquet(path: str) -> None:
    """
    Export csv file as a parquet file
    with the appropriate schema.

    :param path: File path to the csv
    """

    filename = path.split("/")[-1].split(".")[0]

    if filename.casefold() == "code_groups":
        schema = CODE_GROUPS_SCHEMA
    elif filename.casefold() == "demographics":
        schema = DEMOGRAPHICS_SCHEMA
    elif filename.casefold() == "encounters":
        schema = ENCOUNTERS_SCHEMA

    df = (
        spark.read
        .option("delimiter", "|")
        .option("header", "true")
        .schema(schema)
        .csv(path)
    )
            
    parquet_path = path.split(".")[0]+".parquet"

    df.write.mode("overwrite").parquet(parquet_path)

In [123]:
def read_file(path: str) -> DataFrame:
    """
    Read file and convert to a Spark DataFrame.
    Supported filetypes are csv & parquet.

    :param path: File path of the file to process
    """

    filename, extension = path.split("/")[-1].split(".")

    if filename.casefold() == "code_groups":
        schema = CODE_GROUPS_SCHEMA
    elif filename.casefold() == "demographics":
        schema = DEMOGRAPHICS_SCHEMA
    elif filename.casefold() == "encounters":
        schema = ENCOUNTERS_SCHEMA

    if extension == "csv":
        df = (
            spark.read
            .option("delimiter", "|")
            .option("header", "true")
            .schema(schema)
            .csv(path)
        )

    elif extension == "parquet":
         df = (
            spark.read
            .schema(schema)
            .parquet(path)
         )
    else:
        raise TypeError(f"{path} is not a csv or parquet file!")

    return df


In [124]:
def rename_columns(sdf, name_map):
    for from_col, to_col in name_map.items():
        sdf = sdf.withColumnRenamed(from_col, to_col)
    return sdf

In [125]:
def age_bucket(age):
    if age > 100:
        return 100
    elif age > 90:
        return 90
    elif age > 80:
        return 80
    elif age > 70:
        return 70
    elif age > 60:
        return 60
    elif age > 50:
        return 50
    elif age > 40:
        return 40
    else:
        # ignore all patients under 40
        return 0

In [126]:
# Make any changes in constants above, not here
data_folder = pathlib.Path(DATA_FOLDER)

# spark.read requires path to be a str,
# converting them all here in one place
# instead of handling inside funcs
if READ_MODE == Mode.csv_only or READ_MODE == Mode.csv_to_parquet:
    code_groups = str(pathlib.Path('code_groups.csv'))
    demographics = str(data_folder.joinpath('demographics.csv'))
    encounters = str(data_folder.joinpath('encounters.csv'))

if READ_MODE == Mode.csv_to_parquet:
    for csv in [code_groups, demographics, encounters]:
        csv_to_parquet(csv)

if READ_MODE == Mode.parquet_only or READ_MODE == csv_to_parquet:
    code_groups = str(pathlib.Path('code_groups.parquet'))
    demographics = str(data_folder.joinpath('demographics.parquet'))
    encounters = str(data_folder.joinpath('encounters.parquet'))

code_groups_sdf = read_file(code_groups)
demo_sdf = read_file(demographics)
enc_sdf = read_file(encounters)

In [127]:
# Data quality checks

def check_dupes(df):
    return (
        df.groupBy(df.columns)
        .count()
        .where(sf.col("count") > 1)
        .select(sf.coalesce(sf.sum("count"), sf.lit(0)))
        .withColumnRenamed("coalesce(sum(count), 0)", f"Duplicate Count")
        .show()
    )

def check_nulls(df):
    print("Null Counts")
    return df.select([sf.count(sf.when(sf.col(c).isNull(), c)).alias(c) for c in df.columns]
   ).show()

def run_single_df_checks(df):
    check_dupes(df)
    check_nulls(df)
    df.describe().show()

def run_all_checks():
    print("-----CODE GROUPS-----")
    run_single_df_checks(code_groups_sdf)
    print("-----DEMO-----")
    run_single_df_checks(demo_sdf)
    print("-----ENC-----")
    run_single_df_checks(enc_sdf)

if OUTPUT_MODE != OutputMode.results_only:
    run_all_checks()

-----CODE GROUPS-----


                                                                                

+---------------+
|Duplicate Count|
+---------------+
|              0|
+---------------+

Null Counts
+--------+-----+
|ICD_CODE|GROUP|
+--------+-----+
|       0|    0|
+--------+-----+

+-------+--------+------------------+
|summary|ICD_CODE|             GROUP|
+-------+--------+------------------+
|  count|    4995|              4995|
|   mean|    null| 5.531131131131131|
| stddev|    null|2.8903991264796347|
|    min|   A07.0|                 1|
|    max| Z98.818|                 9|
+-------+--------+------------------+

-----DEMO-----


                                                                                

+---------------+
|Duplicate Count|
+---------------+
|              0|
+---------------+

Null Counts
+------+---+----------+---------+-------------+
|CLINIC|MRN|FIRST_NAME|LAST_NAME|DATE_OF_BIRTH|
+------+---+----------+---------+-------------+
|     0|  0|         0|        0|            2|
+------+---+----------+---------+-------------+

+-------+------+--------------------+----------+---------+
|summary|CLINIC|                 MRN|FIRST_NAME|LAST_NAME|
+-------+------+--------------------+----------+---------+
|  count|    11|                  11|        11|       11|
|   mean|  null| 4.454543181818182E7|      null|     null|
| stddev|  null|3.2153053946735825E7|      null|     null|
|    min|     A|            11100000|   Another|        A|
|    max|     C|            88776655|        迅|       鲁|
+-------+------+--------------------+----------+---------+

-----ENC-----


                                                                                

+---------------+
|Duplicate Count|
+---------------+
|              2|
+---------------+

Null Counts
+----+------+---+--------+
|DATE|ENC_ID|MRN|ICD_CODE|
+----+------+---+--------+
|   0|     0|  0|       0|
+----+------+---+--------+

+-------+------------------+---------+--------+
|summary|            ENC_ID|      MRN|ICD_CODE|
+-------+------------------+---------+--------+
|  count|                21|       21|      21|
|   mean|121298.57142857143|     null|    null|
| stddev| 9075.238038593967|     null|    null|
|    min|            102542|A12345678| H25.813|
|    max|            133273|C11111111|W94.32XD|
+-------+------------------+---------+--------+



In [128]:
def clean_data(df, type):
    # Pulled from https://www.johndcook.com/blog/2019/05/05/regex_icd_codes/
    N = "\d{3}\.?\d{0,2}"
    E = "E\d{3}\.?\d?"
    V = "V\d{2}\.?\d{0,2}"
    icd9_regex = "|".join([N, E, V])
    icd10_regex = "[A-TV-Z][0-9][0-9AB]\.?[0-9A-TV-Z]{0,4}"

    # Allows weird dates like 0000-99-99 but good enough for our purposes
    date_regex = "[0-9]{4}-[0-9]{2}-[0-9]{2}"
    #name_regex = r"[A-Za-z]+"

    if type == "code_groups":
        df = (df.where(
            (sf.col("ICD_CODE").isNotNull())
        &   (sf.col("GROUP").isNotNull())
        &   (sf.col("ICD_CODE").rlike(icd9_regex) | sf.col("ICD_CODE").rlike(icd10_regex))
        ))
    elif type == "demographics":
        df = (df.where(
            (sf.col("DATE_OF_BIRTH").isNotNull())
        &   (sf.col("MRN").isNotNull())
        &   (sf.col("DATE_OF_BIRTH").rlike(date_regex))
        &   (sf.col("DATE_OF_BIRTH") > sf.date_sub(sf.current_date(), (120 * 365)))
        #&  (sf.col("FIRST_NAME").rlike(name_regex))
        #&  (sf.col("LAST_NAME").rlike(name_regex))
        ))
    elif type == "encounters":
        df = (df.where(
            (sf.col("DATE").isNotNull())
        &   (sf.col("MRN").isNotNull())
        &   (sf.col("ICD_CODE").isNotNull())
        &   (sf.col("DATE").rlike(date_regex))
        &   (sf.col("DATE") > sf.date_sub(sf.current_date(), (15 * 365)))
        &   (sf.col("ICD_CODE").rlike(icd9_regex) | sf.col("ICD_CODE").rlike(icd10_regex))
        ))

    df = df.dropDuplicates()

    return df


In [129]:
code_groups_sdf = clean_data(code_groups_sdf, "code_groups")
demo_sdf = clean_data(demo_sdf, "demographics")
enc_sdf = clean_data(enc_sdf, "encounters")

if OUTPUT_MODE != OutputMode.results_only:
    run_all_checks()

-----CODE GROUPS-----


                                                                                

+---------------+
|Duplicate Count|
+---------------+
|              0|
+---------------+

Null Counts


                                                                                

+--------+-----+
|ICD_CODE|GROUP|
+--------+-----+
|       0|    0|
+--------+-----+



                                                                                

+-------+--------+------------------+
|summary|ICD_CODE|             GROUP|
+-------+--------+------------------+
|  count|    4995|              4995|
|   mean|    null| 5.531131131131131|
| stddev|    null|2.8903991264796374|
|    min|   A07.0|                 1|
|    max| Z98.818|                 9|
+-------+--------+------------------+

-----DEMO-----


                                                                                

+---------------+
|Duplicate Count|
+---------------+
|              0|
+---------------+

Null Counts


                                                                                

+------+---+----------+---------+-------------+
|CLINIC|MRN|FIRST_NAME|LAST_NAME|DATE_OF_BIRTH|
+------+---+----------+---------+-------------+
|     0|  0|         0|        0|            0|
+------+---+----------+---------+-------------+



                                                                                

+-------+------+--------------------+----------+---------+
|summary|CLINIC|                 MRN|FIRST_NAME|LAST_NAME|
+-------+------+--------------------+----------+---------+
|  count|     7|                   7|         7|        7|
|   mean|  null| 4.285553928571428E7|      null|     null|
| stddev|  null|3.2499443156393424E7|      null|     null|
|    min|     A|            11100000|   Another|        A|
|    max|     C|            87654321|     Sonic|   Wiggum|
+-------+------+--------------------+----------+---------+

-----ENC-----


                                                                                

+---------------+
|Duplicate Count|
+---------------+
|              0|
+---------------+

Null Counts


                                                                                

+----+------+---+--------+
|DATE|ENC_ID|MRN|ICD_CODE|
+----+------+---+--------+
|   0|     0|  0|       0|
+----+------+---+--------+





+-------+-----------------+---------+--------+
|summary|           ENC_ID|      MRN|ICD_CODE|
+-------+-----------------+---------+--------+
|  count|               20|       20|      20|
|   mean|         122236.4|     null|    null|
| stddev|8200.768541648045|     null|    null|
|    min|           102542|A12345678| H25.813|
|    max|           133273|C11111111|W94.32XD|
+-------+-----------------+---------+--------+



                                                                                

In [130]:
def transform_encounters(enc_sdf):
    sdf2 = (
        enc_sdf
        .withColumn('ROW_ID', sf.monotonically_increasing_id())
        .withColumn('PRIORITY', sf.row_number().over(Window.partitionBy('ENC_ID').orderBy('ROW_ID')))
        .drop('ROW_ID')
        .groupBy(['ENC_ID', 'DATE', 'MRN'])
        .pivot('PRIORITY')
        .agg(sf.first('ICD_CODE'))
    )
    return rename_columns(sdf2, {i: f'ICD_CODE_{i}' for i in sdf2.columns[3:]})

In [131]:
enc_sdf_transformed = transform_encounters(enc_sdf)

if OUTPUT_MODE == OutputMode.verbose:
    enc_sdf.show()
    enc_sdf_transformed.show()

                                                                                

+----------+------+---------+--------+
|      DATE|ENC_ID|      MRN|ICD_CODE|
+----------+------+---------+--------+
|2022-02-27|126308|C11111111| H25.813|
|2021-07-08|133273|A44444333|M84.332A|
|2022-02-27|126308|C11111111|S40.822D|
|2020-07-22|116422|A22222222|S72.043G|
|2020-11-13|102542|A12345678|V90.27XS|
|2018-06-06|118129|B66666666|S52.551A|
|2021-07-08|133273|A44444333|T53.5X3A|
|2021-05-27|114503|C11111111|T83.021S|
|2018-06-06|118129|B66666666|S53.31XS|
|2020-07-22|116422|A22222222|T37.2X1A|
|2020-05-23|130887|C11100000|O36.8233|
|2021-07-21|114493|B87654321|W94.32XD|
|2018-06-06|118129|B66666666|S61.219D|
|2021-03-06|120022|A12345678|S12.01XD|
|2018-07-14|124619|A12345678|S53.004A|
|2022-02-27|126308|C11111111|  O08.89|
|2021-07-08|133273|A44444333|S72.023Q|
|2020-05-23|130887|C11100000|T46.7X4D|
|2021-07-21|114493|B87654321| O36.891|
|2022-02-27|126308|C11111111| M87.337|
+----------+------+---------+--------+



                                                                                

+------+----------+---------+----------+----------+----------+----------+
|ENC_ID|      DATE|      MRN|ICD_CODE_1|ICD_CODE_2|ICD_CODE_3|ICD_CODE_4|
+------+----------+---------+----------+----------+----------+----------+
|120022|2021-03-06|A12345678|  S12.01XD|      null|      null|      null|
|116422|2020-07-22|A22222222|  S72.043G|  T37.2X1A|      null|      null|
|114503|2021-05-27|C11111111|  T83.021S|      null|      null|      null|
|133273|2021-07-08|A44444333|  M84.332A|  T53.5X3A|  S72.023Q|      null|
|124619|2018-07-14|A12345678|  S53.004A|      null|      null|      null|
|126308|2022-02-27|C11111111|   H25.813|  S40.822D|    O08.89|   M87.337|
|114493|2021-07-21|B87654321|  W94.32XD|   O36.891|      null|      null|
|118129|2018-06-06|B66666666|  S52.551A|  S53.31XS|  S61.219D|      null|
|130887|2020-05-23|C11100000|  O36.8233|  T46.7X4D|      null|      null|
|102542|2020-11-13|A12345678|  V90.27XS|      null|      null|      null|
+------+----------+---------+---------

                                                                                

In [132]:
demo_sdf_mrn = demo_sdf.withColumn("FULL_MRN", sf.concat("CLINIC", "MRN"))

if OUTPUT_MODE == OutputMode.verbose:
    demo_sdf.show()
    demo_sdf_mrn.show()

+------+--------+-----------------+-------------+-------------+
|CLINIC|     MRN|       FIRST_NAME|    LAST_NAME|DATE_OF_BIRTH|
+------+--------+-----------------+-------------+-------------+
|     A|44444333|        Duowuliao|            A|   1970-04-04|
|     C|11111111|            Sonic|The Hedgehog!|   1990-05-15|
|     B|87654321|            Ralph|       Wiggum|   1950-11-13|
|     A|12345678|             Bill|      Edwards|   1989-01-20|
|     B|77777777|          Another|         Bene|   1955-03-03|
|     A|55555555|             Some|  Beneficiary|   1951-01-01|
|     C|11100000|Nizhenzaikanzhege|           Ma|   1990-05-05|
+------+--------+-----------------+-------------+-------------+

+------+--------+-----------------+-------------+-------------+---------+
|CLINIC|     MRN|       FIRST_NAME|    LAST_NAME|DATE_OF_BIRTH| FULL_MRN|
+------+--------+-----------------+-------------+-------------+---------+
|     A|44444333|        Duowuliao|            A|   1970-04-04|A44444333|

In [133]:
enc_sdf_added_cols = (
    enc_sdf_transformed
    .join(demo_sdf_mrn, enc_sdf_transformed.MRN == demo_sdf_mrn.FULL_MRN)
    .withColumn('AGE', sf.datediff('DATE', 'DATE_OF_BIRTH') / 365)
    .withColumn('AGE_BUCKET', sf.udf(age_bucket)(sf.col('AGE')))
    .withColumn('ICD_CODE', sf.expr("stack(4, ICD_CODE_1, ICD_CODE_2, ICD_CODE_3, ICD_CODE_4)"))
)

if OUTPUT_MODE == OutputMode.verbose:
    enc_sdf_added_cols.show()

results = (
    enc_sdf_added_cols
    .join(code_groups_sdf, enc_sdf_added_cols.ICD_CODE==code_groups_sdf.ICD_CODE)
    .groupBy('AGE_BUCKET', 'GROUP').count()
    .groupBy('AGE_BUCKET', 'count').agg(sf.collect_set('GROUP'))
    .withColumnRenamed('collect_set(GROUP)', 'GROUPS')
    .withColumn(
        '_row',
        sf.row_number().over(Window().partitionBy(['AGE_BUCKET']).orderBy(sf.desc('count'))))
    .filter(sf.col('_row') == 1)
    .drop('_row')
    .filter(sf.col('AGE_BUCKET') != 0)
    .orderBy('AGE_BUCKET')
)

                                                                                

+------+----------+---------+----------+----------+----------+----------+------+--------+----------+-------------+-------------+---------+------------------+----------+--------+
|ENC_ID|      DATE|      MRN|ICD_CODE_1|ICD_CODE_2|ICD_CODE_3|ICD_CODE_4|CLINIC|     MRN|FIRST_NAME|    LAST_NAME|DATE_OF_BIRTH| FULL_MRN|               AGE|AGE_BUCKET|ICD_CODE|
+------+----------+---------+----------+----------+----------+----------+------+--------+----------+-------------+-------------+---------+------------------+----------+--------+
|120022|2021-03-06|A12345678|  S12.01XD|      null|      null|      null|     A|12345678|      Bill|      Edwards|   1989-01-20|A12345678| 32.14520547945205|         0|S12.01XD|
|120022|2021-03-06|A12345678|  S12.01XD|      null|      null|      null|     A|12345678|      Bill|      Edwards|   1989-01-20|A12345678| 32.14520547945205|         0|    null|
|120022|2021-03-06|A12345678|  S12.01XD|      null|      null|      null|     A|12345678|      Bill|      Edwa

In [134]:
results.show()



+----------+-----+---------+
|AGE_BUCKET|count|   GROUPS|
+----------+-----+---------+
|        50|    1|[1, 2, 7]|
|        70|    1|   [2, 9]|
+----------+-----+---------+



                                                                                