In [None]:
# findspark must run before `import pyspark` (next cell): it locates the local
# Spark installation and puts its Python package on sys.path.
import findspark
findspark.init()

In [None]:
import pathlib

import pandas as pd
import pyspark

from pyspark.sql import functions as sf
from pyspark.sql import Window

In [None]:
def read_csv(path, date_cols=None, **kwargs):
    """Read a pipe-delimited CSV with pandas and convert it to a Spark DataFrame.

    Parameters
    ----------
    path : path-like
        File to read.
    date_cols : list of str, optional
        Columns pandas should parse as dates (forwarded as ``parse_dates``).
    **kwargs
        Extra ``pandas.read_csv`` options. They take precedence over the
        defaults set here, including ``sep`` and ``parse_dates``.

    Returns
    -------
    pyspark.sql.DataFrame
        Created through the module-level ``spark`` session, which must
        exist by the time this function is called.
    """
    options = dict(sep='|', parse_dates=date_cols)
    options.update(kwargs)  # caller-supplied options win over the defaults
    pandas_df = pd.read_csv(path, **options)
    return spark.createDataFrame(pandas_df)

In [None]:
def rename_columns(sdf, name_map):
    """Return ``sdf`` with columns renamed according to ``name_map``.

    ``name_map`` maps existing column names to new names. Renames are applied
    one at a time in the mapping's iteration order via ``withColumnRenamed``.
    """
    renamed = sdf
    for old_name in name_map:
        renamed = renamed.withColumnRenamed(old_name, name_map[old_name])
    return renamed

In [None]:
def transform_encounters(enc_sdf):
    """Pivot encounter rows into one wide row per (ENC_ID, DATE, MRN).

    Within each ENC_ID, rows are numbered 1, 2, ... in the order given by
    ``monotonically_increasing_id`` (presumably the input row order; this
    assumes createDataFrame kept the pandas row order — TODO confirm). Each
    row's ICD_CODE then lands in column ICD_CODE_<n>.
    """
    key_cols = ['ENC_ID', 'DATE', 'MRN']
    within_encounter = Window.partitionBy('ENC_ID').orderBy('ROW_ID')

    numbered = (
        enc_sdf
        .withColumn('ROW_ID', sf.monotonically_increasing_id())
        .withColumn('PRIORITY', sf.row_number().over(within_encounter))
        .drop('ROW_ID')
    )
    wide = numbered.groupBy(key_cols).pivot('PRIORITY').agg(sf.first('ICD_CODE'))

    # Pivot columns follow the key columns and are named '1', '2', ...
    pivot_cols = wide.columns[len(key_cols):]
    return rename_columns(wide, {name: f'ICD_CODE_{name}' for name in pivot_cols})

In [None]:
def age_bucket(age):
    """Map an age in years to the lower bound of its 10-year bucket.

    Ages in [40, 100) map to 40, 50, ..., 90. Every age of 100 or more maps
    to 100. Ages below 40 return 0, which marks the patient for exclusion.

    A missing age returns None: either None (how Spark passes a null to a
    Python UDF) or NaN. The previous version raised TypeError on None.

    Fixes relative to the earlier if/elif ladder:
    * ages in (80, 90] were returned as 90 instead of 80;
    * strict ``>`` put exact decade boundaries (e.g. 40.0, 80.0) in the
      bucket below, so a 40-year-old was dropped despite "under 40".
    """
    if age is None or age != age:  # NaN is the only value not equal to itself
        return None
    if age < 40:
        # ignore all patients under 40
        return 0
    # Cap before flooring so very large (even infinite) ages land in 100.
    return int(min(age, 100) // 10) * 10

In [None]:
# Local Spark session with a single worker thread ("local"); read_csv and the
# cells below use this global.
spark = (
    pyspark.sql.SparkSession
    .builder
    .master("local")
    .getOrCreate()
)

In [None]:
# Dataset selection: set data_folder to data_2, data_3 or data_4 to switch datasets.
data_folder = pathlib.Path('data_1')

# code_groups.csv is shared and read from the working directory;
# the other two tables come from the selected dataset folder.
code_groups = pathlib.Path('code_groups.csv')
demographics = data_folder / 'demographics.csv'
encounters = data_folder / 'encounters.csv'

In [None]:
# Load the three source tables. DATE_OF_BIRTH is parsed as a date here; the
# encounter DATE column is left as pandas reads it.
code_groups_sdf = read_csv(path=code_groups)
demo_sdf = read_csv(path=demographics, date_cols=['DATE_OF_BIRTH'])
enc_sdf = read_csv(path=encounters)

In [None]:
# NOTE(review): this rebinds enc_sdf to the pivoted frame, which no longer has
# an ICD_CODE column. Re-running this cell alone therefore fails; re-run the
# load cell first.
enc_sdf = transform_encounters(enc_sdf=enc_sdf)

In [None]:
# Encounter MRNs are joined against CLINIC + MRN from demographics.
demo_sdf = demo_sdf.withColumn("FULL_MRN", sf.concat(sf.col("CLINIC"), sf.col("MRN")))

In [None]:
# Most frequent diagnosis GROUP (by the primary code, ICD_CODE_1) per age bucket.
#
# The UDF needs an explicit return type. Without one, PySpark defaults to
# StringType, and orderBy('AGE_BUCKET') would sort lexicographically
# ("100" before "40").
age_bucket_udf = sf.udf(age_bucket, 'int')

# Break count ties by GROUP so the reported top group is deterministic.
top_group_window = (
    Window.partitionBy('AGE_BUCKET')
    .orderBy(sf.desc('count'), sf.asc('GROUP'))
)

results = (
    enc_sdf
    .join(demo_sdf, enc_sdf.MRN == demo_sdf.FULL_MRN)
    # months_between accounts for leap years; datediff / 365 drifts ~1 day per 4 years.
    .withColumn('AGE', sf.months_between('DATE', 'DATE_OF_BIRTH') / 12)
    .withColumn('AGE_BUCKET', age_bucket_udf(sf.col('AGE')))
    # Bucket 0 marks excluded patients; a null bucket is also dropped here.
    # Filtering before the aggregation yields the same result with less work.
    .filter(sf.col('AGE_BUCKET') != 0)
    .join(code_groups_sdf, enc_sdf.ICD_CODE_1 == code_groups_sdf.ICD_CODE)
    .groupBy('AGE_BUCKET', 'GROUP').count()
    .withColumn('_row', sf.row_number().over(top_group_window))
    .filter(sf.col('_row') == 1)
    .drop('_row')
    .orderBy('AGE_BUCKET')
)

In [None]:
# Spark's defaults made explicit: at most 20 rows, long cell values truncated.
results.show(n=20, truncate=True)