In [495]:
import findspark
findspark.init()

In [496]:
import pathlib

# Can remove this import since it's no longer used
# import pandas as pd
import pyspark

from pyspark.sql import functions as sf
from pyspark.sql import Window

import pyarrow.parquet as pq

In [497]:
spark = pyspark.sql.SparkSession.builder.master("local").getOrCreate()

In [498]:

# This can be rewritten to use Spark-native functions.
# This removes a dependency we don't use otherwise
# and avoids loading the whole df into memory
def read_csv(path, date_cols=None):
    # Spark requires path as a str
    path = str(path)

    df = (
        spark.read
        .option("delimiter", "|")
        .option("header", "true")
        .csv(path)
    )

    if date_cols:
        for date_col in date_cols:
            df.withColumn(date_col, sf.col(date_col).cast("Date"))
            
    parquet_path = str(path.split(".")[0]+".parquet")

    df.write.mode("overwrite").parquet(parquet_path)

    df = spark.read.parquet(parquet_path)
    
    return df

In [499]:
def rename_columns(sdf, name_map):
    for from_col, to_col in name_map.items():
        sdf = sdf.withColumnRenamed(from_col, to_col)
    return sdf

In [500]:
def transform_encounters(enc_sdf):
    sdf2 = (
        enc_sdf
        .withColumn('ROW_ID', sf.monotonically_increasing_id())
        .withColumn('PRIORITY', sf.row_number().over(Window.partitionBy('ENC_ID').orderBy('ROW_ID')))
        .drop('ROW_ID')
        .groupBy(['ENC_ID', 'DATE', 'MRN'])
        .pivot('PRIORITY')
        .agg(sf.first('ICD_CODE'))
    )
    return rename_columns(sdf2, {i: f'ICD_CODE_{i}' for i in sdf2.columns[3:]})



In [501]:
def age_bucket(age):
    if age > 100:
        return 100
    elif age > 90:
        return 90
    elif age > 80:
        return 90
    elif age > 70:
        return 70
    elif age > 60:
        return 60
    elif age > 50:
        return 50
    elif age > 40:
        return 40
    else:
        # ignore all patients under 40
        return 0

In [502]:
# change this to data_2, data_3 or data_4 to use a different dataset
data_folder = pathlib.Path('data_4')

code_groups = pathlib.Path('code_groups.csv')
demographics = data_folder.joinpath('demographics.csv')
encounters = data_folder.joinpath('encounters.csv')

In [503]:
code_groups_sdf = read_csv(code_groups)
demo_sdf = read_csv(demographics, date_cols=['DATE_OF_BIRTH'])
enc_sdf = read_csv(encounters)
enc_sdf.show()



+----------+------+---------+--------+
|      DATE|ENC_ID|      MRN|ICD_CODE|
+----------+------+---------+--------+
|2019-11-27|100020|C24633824| S93.321|
|2019-11-27|100020|C24633824|S42.102D|
|2019-11-27|100020|C24633824|S72.392D|
|2019-11-27|100020|C24633824|V31.4XXA|
|2020-01-14|100039|B04778727|S74.8X1A|
|2020-01-14|100039|B04778727|  M48.13|
|2020-01-14|100039|B04778727|T81.10XA|
|2020-10-01|100080|B67938940|S22.068G|
|2020-10-01|100080|B67938940|  M33.09|
|2020-10-01|100080|B67938940|  S97.12|
|2019-07-05|100093|B85223500|T54.2X3D|
|2019-07-05|100093|B85223500|S12.110B|
|2019-07-05|100093|B85223500| T71.194|
|2019-07-05|100093|B85223500| M67.261|
|2018-01-09|100116|C48169493|T82.128D|
|2020-12-11|100118|C24687879|S83.004S|
|2020-12-11|100118|C24687879|S07.8XXA|
|2022-04-03|100124|C48946153|S06.9X1D|
|2022-04-03|100124|C48946153|S92.236K|
|2022-04-03|100124|C48946153|S82.001K|
+----------+------+---------+--------+
only showing top 20 rows



                                                                                

In [504]:
enc_sdf = transform_encounters(enc_sdf)
enc_sdf.show()



+--------+----------+---------+----------+----------+----------+----------+
|  ENC_ID|      DATE|      MRN|ICD_CODE_1|ICD_CODE_2|ICD_CODE_3|ICD_CODE_4|
+--------+----------+---------+----------+----------+----------+----------+
|10005493|2017-10-19|A07177112|  T49.2X6D|   M93.949|       M00|      null|
|10005659|2020-04-03|B96471583|       B30|   Z01.818|      null|      null|
|10006023|2017-04-30|A10685652|   M20.029|  M84.444P|       C88|  S82.231E|
|10019218|2018-07-29|A73021675|  V22.5XXS|  S60.465A|   S21.012|   M24.611|
|10019957|2018-04-01|C41104627|    V96.05|     V87.1|   S06.1X7|    S37.92|
|  100263|2018-06-12|A77568628|  S61.304A|  S62.601P|      null|      null|
|10026422|2018-08-19|A83158064|    H18.54|    A92.39|  S63.111D|      null|
| 1002883|2018-09-09|B32970140|  O31.31X9|      null|      null|      null|
| 1003202|2019-10-10|C55076900|  S89.299S|       M35|      null|      null|
|10037766|2019-02-10|C99145363|  S43.111A|  S65.299S|       R19|      null|
|10042652|20

                                                                                

In [505]:
demo_sdf = demo_sdf.withColumn("FULL_MRN", sf.concat("CLINIC", "MRN"))
demo_sdf.show()

+------+--------+----------+---------+-------------+---------+
|CLINIC|     MRN|FIRST_NAME|LAST_NAME|DATE_OF_BIRTH| FULL_MRN|
+------+--------+----------+---------+-------------+---------+
|     A|63450226|   Rebecca|   Parker|   1918-08-25|A63450226|
|     B|62213719|      Ryan|Robertson|   1914-05-10|B62213719|
|     A|19731324|  Danielle|    Lopez|   1922-07-25|A19731324|
|     C|44649328|    Ashley|   Morgan|   1932-11-17|C44649328|
|     C|33848268|      Juan|   Romero|   1950-11-22|C33848268|
|     C|27893296|   Maurice|   Atkins|   1945-07-06|C27893296|
|     B|63621992|     Paige| Hamilton|   2014-04-03|B63621992|
|     A|44486390|     Kelly|   Morris|   1948-08-28|A44486390|
|     A|59817580|      Mark| Thompson|   1973-02-23|A59817580|
|     A|52765016|    Amanda|    Woods|   1996-11-23|A52765016|
|     C|33480965|     Erica|  Swanson|   1976-12-31|C33480965|
|     A|48204670|  Lawrence|       Yu|   1927-11-10|A48204670|
|     B|43088820|    Brandy|    Ramos|   1968-09-03|B43

In [510]:
enc_sdf_added_cols = (
    enc_sdf
    .join(demo_sdf, enc_sdf.MRN == demo_sdf.FULL_MRN)
    .withColumn('AGE', sf.datediff('DATE', 'DATE_OF_BIRTH') / 365)
    .withColumn('AGE_BUCKET', sf.udf(age_bucket)(sf.col('AGE')))
    .withColumn('ICD_CODE', sf.expr("stack(4, ICD_CODE_1, ICD_CODE_2, ICD_CODE_3, ICD_CODE_4)"))
)
results = (
    enc_sdf_added_cols
    .join(code_groups_sdf, enc_sdf_added_cols.ICD_CODE==code_groups_sdf.ICD_CODE)
    .groupBy('AGE_BUCKET', 'GROUP').count()
    .withColumn(
        '_row',
        sf.row_number().over(Window().partitionBy(['AGE_BUCKET']).orderBy(sf.desc('count'))))
    .filter(sf.col('_row') == 1)
    .drop('_row')
    .filter(sf.col('AGE_BUCKET') != 0)
    .orderBy('AGE_BUCKET')
)

In [511]:
results.show()



+----------+-----+------+
|AGE_BUCKET|GROUP| count|
+----------+-----+------+
|       100|    8|143884|
|        40|    8|107114|
|        50|    8|106948|
|        60|    8|106857|
|        70|    8|107118|
|        90|    8|213930|
+----------+-----+------+



                                                                                