In [468]:
import findspark
findspark.init()

In [469]:
import pathlib

# Can remove this import since it's no longer used
# import pandas as pd
import pyspark

from pyspark.sql import functions as sf, types, Window


In [470]:
spark = pyspark.sql.SparkSession.builder.master("local").getOrCreate()

In [471]:
CODE_GROUPS_SCHEMA = types.StructType([
    types.StructField("ICD_CODE",types.StringType(),nullable=False),
    types.StructField("GROUP",types.StringType(),nullable=False),
])

DEMOGRAPHICS_SCHEMA = types.StructType([
    types.StructField("CLINIC",types.StringType(),nullable=True),
    types.StructField("MRN",types.StringType(),nullable=False),
    types.StructField("FIRST_NAME",types.StringType(),nullable=True),
    types.StructField("LAST_NAME",types.StringType(),nullable=True),
    types.StructField("DATE_OF_BIRTH",types.DateType(),nullable=False),
])

ENCOUNTERS_SCHEMA = types.StructType([
    types.StructField("DATE",types.DateType(),nullable=True),
    types.StructField("ENC_ID",types.StringType(),nullable=True),
    types.StructField("MRN",types.StringType(),nullable=False),
    types.StructField("ICD_CODE",types.StringType(),nullable=True),
])

In [472]:
def csv_to_parquet(path):
    # Spark requires path as a str
    path = str(path)

    filename = path.split("/")[-1].split(".")[0]

    if filename.casefold() == "code_groups":
        schema = CODE_GROUPS_SCHEMA
    elif filename.casefold() == "demographics":
        schema = DEMOGRAPHICS_SCHEMA
    elif filename.casefold() == "encounters":
        schema = ENCOUNTERS_SCHEMA

    df = (
        spark.read
        .option("delimiter", "|")
        .option("header", "true")
        .schema(schema)
        .csv(path)
    )
            
    parquet_path = path.split(".")[0]+".parquet"

    df.write.mode("overwrite").parquet(parquet_path)

In [473]:
def read_csv(path):
    path = str(path)

    filename = path.split("/")[-1].split(".")[0]

    if filename.casefold() == "code_groups":
        schema = CODE_GROUPS_SCHEMA
    elif filename.casefold() == "demographics":
        schema = DEMOGRAPHICS_SCHEMA
    elif filename.casefold() == "encounters":
        schema = ENCOUNTERS_SCHEMA

    df = (
        spark.read
        .option("delimiter", "|")
        .option("header", "true")
        .schema(schema)
        .csv(path)
    )

    return df


In [474]:

# This can be rewritten to use Spark-native functions.
# This removes a dependency we don't use otherwise
def read_parquet(path):
    path = str(path)

    df = spark.read.parquet(path)
    
    return df

In [475]:
def rename_columns(sdf, name_map):
    for from_col, to_col in name_map.items():
        sdf = sdf.withColumnRenamed(from_col, to_col)
    return sdf

In [476]:
def transform_encounters(enc_sdf):
    sdf2 = (
        enc_sdf
        .withColumn('ROW_ID', sf.monotonically_increasing_id())
        .withColumn('PRIORITY', sf.row_number().over(Window.partitionBy('ENC_ID').orderBy('ROW_ID')))
        .drop('ROW_ID')
        .groupBy(['ENC_ID', 'DATE', 'MRN'])
        .pivot('PRIORITY')
        .agg(sf.first('ICD_CODE'))
    )
    return rename_columns(sdf2, {i: f'ICD_CODE_{i}' for i in sdf2.columns[3:]})

In [477]:
def age_bucket(age):
    if age > 100:
        return 100
    elif age > 90:
        return 90
    elif age > 80:
        return 90
    elif age > 70:
        return 70
    elif age > 60:
        return 60
    elif age > 50:
        return 50
    elif age > 40:
        return 40
    else:
        # ignore all patients under 40
        return 0

In [478]:
# change this to data_2, data_3 or data_4 to use a different dataset
data_folder = pathlib.Path('data_3')

code_groups = pathlib.Path('code_groups.csv')
demographics = data_folder.joinpath('demographics.csv')
encounters = data_folder.joinpath('encounters.csv')

#for csv in [code_groups, demographics, encounters]:
#    csv_to_parquet(csv)

code_groups = pathlib.Path('code_groups.parquet')
demographics = data_folder.joinpath('demographics.parquet')
encounters = data_folder.joinpath('encounters.parquet')

In [479]:
code_groups_sdf = read_parquet(code_groups)
demo_sdf = read_parquet(demographics)
enc_sdf = read_parquet(encounters)

In [480]:
# Data quality checks

def check_dupes(df):
    return (
        df.groupBy(df.columns)
        .count()
        .where(sf.col("count") > 1)
        .select(sf.coalesce(sf.sum("count"), sf.lit(0)))
        .withColumnRenamed("coalesce(sum(count), 0)", f"Duplicate Count")
        .show()
    )

def check_nulls(df):
    print("Null Counts")
    return df.select([sf.count(sf.when(sf.col(c).isNull(), c)).alias(c) for c in df.columns]
   ).show()

def run_single_df_checks(df):
    check_dupes(df)
    check_nulls(df)
    df.describe().show()

def run_all_checks():
    print("-----CODE GROUPS-----")
    run_single_df_checks(code_groups_sdf)
    print("-----DEMO-----")
    run_single_df_checks(demo_sdf)
    print("-----ENC-----")
    run_single_df_checks(enc_sdf)

run_all_checks()

-----CODE GROUPS-----


                                                                                

+---------------+
|Duplicate Count|
+---------------+
|              0|
+---------------+

Null Counts
+--------+-----+
|ICD_CODE|GROUP|
+--------+-----+
|       0|    0|
+--------+-----+

+-------+--------+------------------+
|summary|ICD_CODE|             GROUP|
+-------+--------+------------------+
|  count|    4995|              4995|
|   mean|    null| 5.531131131131131|
| stddev|    null|2.8903991264796347|
|    min|   A07.0|                 1|
|    max| Z98.818|                 9|
+-------+--------+------------------+

-----DEMO-----


                                                                                

+---------------+
|Duplicate Count|
+---------------+
|              0|
+---------------+

Null Counts
+------+---+----------+---------+-------------+
|CLINIC|MRN|FIRST_NAME|LAST_NAME|DATE_OF_BIRTH|
+------+---+----------+---------+-------------+
|     0|  0|         0|        0|            0|
+------+---+----------+---------+-------------+

+-------+------+--------------------+----------+---------+-------------+
|summary|CLINIC|                 MRN|FIRST_NAME|LAST_NAME|DATE_OF_BIRTH|
+-------+------+--------------------+----------+---------+-------------+
|  count| 10000|               10000|     10000|    10000|        10000|
|   mean|  null|     5.07491028763E7|      null|     null|         null|
| stddev|  null|2.8727825630081754E7|      null|     null|         null|
|    min|     A|            01015674|     Aaron|   Abbott|   1859-11-23|
|    max|     C|            99998883|       Zoe|�the Kid�|   2022-04-06|
+-------+------+--------------------+----------+---------+-------------+

                                                                                

+---------------+
|Duplicate Count|
+---------------+
|              0|
+---------------+

Null Counts
+----+------+---+--------+
|DATE|ENC_ID|MRN|ICD_CODE|
+----+------+---+--------+
|   0|     0|  0|       0|
+----+------+---+--------+



[Stage 952:>                                                        (0 + 1) / 1]

+-------+----------+------------------+---------+--------+
|summary|      DATE|            ENC_ID|      MRN|ICD_CODE|
+-------+----------+------------------+---------+--------+
|  count|    131447|            131447|   131447|  131447|
|   mean|      null|4.98692226390104E7|     null|    null|
| stddev|      null|2.87755318476725E7|     null|    null|
|    min|1881-07-14|          10000905|A01126333|   A07.0|
|    max|2022-04-11|          99999335|C99975199| Z98.818|
+-------+----------+------------------+---------+--------+



                                                                                

In [481]:
def clean_data(df, type):
    # Pulled from https://www.johndcook.com/blog/2019/05/05/regex_icd_codes/
    N = r"\d{3}\.?\d{0,2}"
    E = r"E\d{3}\.?\d?"
    V = r"V\d{2}\.?\d{0,2}"
    icd9_regex = "|".join([N, E, V])
    icd10_regex = r"[A-TV-Z][0-9][0-9AB]\.?[0-9A-TV-Z]{0,4}"
    name_regex = r"[A-Za-z]+"

    if type == "code_groups":
        df = (df.where(
            (sf.col("ICD_CODE").rlike(icd9_regex) | sf.col("ICD_CODE").rlike(icd10_regex))
        ))
    elif type == "demographics":
        df = (df.where(
            (sf.col("DATE_OF_BIRTH") > sf.date_sub(sf.current_date(), (120 * 365))) &
            (sf.col("FIRST_NAME").rlike(name_regex)) &
            (sf.col("LAST_NAME").rlike(name_regex))
        ))
    elif type == "encounters":
        df = (df.where(
            (sf.col("DATE") > sf.date_sub(sf.current_date(), (15 * 365))) &
            (sf.col("ICD_CODE").rlike(icd9_regex) | sf.col("ICD_CODE").rlike(icd10_regex))
        ))

    return df

code_groups_sdf = clean_data(code_groups_sdf, "code_groups")
demo_sdf = clean_data(demo_sdf, "demographics")
enc_sdf = clean_data(enc_sdf, "encounters")


In [482]:
run_all_checks()

-----CODE GROUPS-----


                                                                                

+---------------+
|Duplicate Count|
+---------------+
|              0|
+---------------+

Null Counts
+--------+-----+
|ICD_CODE|GROUP|
+--------+-----+
|       0|    0|
+--------+-----+

+-------+--------+------------------+
|summary|ICD_CODE|             GROUP|
+-------+--------+------------------+
|  count|    4995|              4995|
|   mean|    null| 5.531131131131131|
| stddev|    null|2.8903991264796347|
|    min|   A07.0|                 1|
|    max| Z98.818|                 9|
+-------+--------+------------------+

-----DEMO-----


                                                                                

+---------------+
|Duplicate Count|
+---------------+
|              0|
+---------------+

Null Counts
+------+---+----------+---------+-------------+
|CLINIC|MRN|FIRST_NAME|LAST_NAME|DATE_OF_BIRTH|
+------+---+----------+---------+-------------+
|     0|  0|         0|        0|            0|
+------+---+----------+---------+-------------+

+-------+------+-------------------+----------+---------+-------------+
|summary|CLINIC|                MRN|FIRST_NAME|LAST_NAME|DATE_OF_BIRTH|
+-------+------+-------------------+----------+---------+-------------+
|  count|  9999|               9999|      9999|     9999|         9999|
|   mean|  null|5.075196392889289E7|      null|     null|         null|
| stddev|  null|2.872783748094048E7|      null|     null|         null|
|    min|     A|           01015674|     Aaron|   Abbott|   1906-04-17|
|    max|     C|           99998883|       Zoe|   Zuniga|   2022-04-06|
+-------+------+-------------------+----------+---------+-------------+

-----EN

                                                                                

+---------------+
|Duplicate Count|
+---------------+
|              0|
+---------------+

Null Counts
+----+------+---+--------+
|DATE|ENC_ID|MRN|ICD_CODE|
+----+------+---+--------+
|   0|     0|  0|       0|
+----+------+---+--------+



[Stage 973:>                                                        (0 + 1) / 1]

+-------+----------+-------------------+---------+--------+
|summary|      DATE|             ENC_ID|      MRN|ICD_CODE|
+-------+----------+-------------------+---------+--------+
|  count|    131445|             131445|   131445|  131445|
|   mean|      null|4.986884302184183E7|     null|    null|
| stddev|      null|2.877558619176992E7|     null|    null|
|    min|2017-01-01|           10000905|A01126333|   A07.0|
|    max|2022-04-11|           99999335|C99975199| Z98.818|
+-------+----------+-------------------+---------+--------+



                                                                                

In [483]:
enc_sdf = transform_encounters(enc_sdf)

                                                                                

In [484]:
demo_sdf = demo_sdf.withColumn("FULL_MRN", sf.concat("CLINIC", "MRN"))

In [485]:
enc_sdf_added_cols = (
    enc_sdf
    .join(demo_sdf, enc_sdf.MRN == demo_sdf.FULL_MRN)
    .withColumn('AGE', sf.datediff('DATE', 'DATE_OF_BIRTH') / 365)
    .withColumn('AGE_BUCKET', sf.udf(age_bucket)(sf.col('AGE')))
    .withColumn('ICD_CODE', sf.expr("stack(4, ICD_CODE_1, ICD_CODE_2, ICD_CODE_3, ICD_CODE_4)"))
)
results = (
    enc_sdf_added_cols
    .join(code_groups_sdf, enc_sdf_added_cols.ICD_CODE==code_groups_sdf.ICD_CODE)
    .groupBy('AGE_BUCKET', 'GROUP').count()
    .withColumn(
        '_row',
        sf.row_number().over(Window().partitionBy(['AGE_BUCKET']).orderBy(sf.desc('count'))))
    .filter(sf.col('_row') == 1)
    .drop('_row')
    .filter(sf.col('AGE_BUCKET') != 0)
    .orderBy('AGE_BUCKET')
)

In [486]:
results.show()



+----------+-----+-----+
|AGE_BUCKET|GROUP|count|
+----------+-----+-----+
|       100|   10| 1448|
|        40|   10| 1087|
|        50|   10| 1222|
|        60|   10| 1139|
|        70|   10| 1066|
|        90|   10| 2205|
+----------+-----+-----+



                                                                                