In [247]:
import findspark
findspark.init()

In [248]:
import pathlib

# Can remove this import since it's no longer used
# import pandas as pd
import pyspark

from pyspark.sql import functions as sf
from pyspark.sql import Window

import pyarrow.parquet as pq

In [249]:
spark = pyspark.sql.SparkSession.builder.master("local").getOrCreate()

In [250]:

# This can be rewritten to use Spark-native functions.
# This removes a dependency we don't use otherwise
# and avoids loading the whole df into memory
def read_csv(path, date_cols=None):
    # Spark requires path as a str
    path = str(path)

    df = (
        spark.read
        .option("delimiter", "|")
        .option("header", "true")
        .csv(path)
    )

    if date_cols:
        for date_col in date_cols:
            df.withColumn(date_col, sf.col(date_col).cast("Date"))
            
    parquet_path = str(path.split(".")[0]+".parquet")

    df.write.mode("overwrite").parquet(parquet_path)

    df = spark.read.parquet(parquet_path)
    
    return df

In [251]:
def rename_columns(sdf, name_map):
    for from_col, to_col in name_map.items():
        sdf = sdf.withColumnRenamed(from_col, to_col)
    return sdf

In [252]:
def transform_encounters(enc_sdf):
    sdf2 = (
        enc_sdf
        .withColumn('ROW_ID', sf.monotonically_increasing_id())
        .withColumn('PRIORITY', sf.row_number().over(Window.partitionBy('ENC_ID').orderBy('ROW_ID')))
        .drop('ROW_ID')
        .groupBy(['ENC_ID', 'DATE', 'MRN'])
        .pivot('PRIORITY')
        .agg(sf.first('ICD_CODE'))
    )
    return rename_columns(sdf2, {i: f'ICD_CODE_{i}' for i in sdf2.columns[3:]})

In [253]:
def age_bucket(age):
    if age > 100:
        return 100
    elif age > 90:
        return 90
    elif age > 80:
        return 90
    elif age > 70:
        return 70
    elif age > 60:
        return 60
    elif age > 50:
        return 50
    elif age > 40:
        return 40
    else:
        # ignore all patients under 40
        return 0

In [261]:
# change this to data_2, data_3 or data_4 to use a different dataset
data_folder = pathlib.Path('data_1')

code_groups = pathlib.Path('code_groups.csv')
demographics = data_folder.joinpath('demographics.csv')
encounters = data_folder.joinpath('encounters.csv')

In [262]:
code_groups_sdf = read_csv(code_groups)
demo_sdf = read_csv(demographics, date_cols=['DATE_OF_BIRTH'])
enc_sdf = read_csv(encounters)

22/04/13 21:47:19 ERROR RetryingBlockFetcher: Exception while beginning fetch of 1 outstanding blocks 
java.io.IOException: Connecting to /10.1.2.151:37803 timed out (120000 ms)
	at org.apache.spark.network.client.TransportClientFactory.createClient(TransportClientFactory.java:251)
	at org.apache.spark.network.client.TransportClientFactory.createClient(TransportClientFactory.java:195)
	at org.apache.spark.network.netty.NettyBlockTransferService$$anon$2.createAndStart(NettyBlockTransferService.scala:122)
	at org.apache.spark.network.shuffle.RetryingBlockFetcher.fetchAllOutstanding(RetryingBlockFetcher.java:141)
	at org.apache.spark.network.shuffle.RetryingBlockFetcher.start(RetryingBlockFetcher.java:121)
	at org.apache.spark.network.netty.NettyBlockTransferService.fetchBlocks(NettyBlockTransferService.scala:143)
	at org.apache.spark.network.BlockTransferService.fetchBlockSync(BlockTransferService.scala:103)
	at org.apache.spark.storage.BlockManager.fetchRemoteManagedBuffer(BlockManager.

KeyboardInterrupt: 

In [None]:
enc_sdf = transform_encounters(enc_sdf)

                                                                                

In [None]:
demo_sdf = demo_sdf.withColumn("FULL_MRN", sf.concat("CLINIC", "MRN"))

In [None]:
results = (
    enc_sdf
    .join(demo_sdf, enc_sdf.MRN == demo_sdf.FULL_MRN)
    .withColumn('AGE', sf.datediff('DATE', 'DATE_OF_BIRTH') / 365)
    .withColumn('AGE_BUCKET', sf.udf(age_bucket)(sf.col('AGE')))
    .join(code_groups_sdf, enc_sdf.ICD_CODE_1==code_groups_sdf.ICD_CODE)
    .groupBy('AGE_BUCKET', 'GROUP').count()
    .withColumn(
        '_row',
        sf.row_number().over(Window().partitionBy(['AGE_BUCKET']).orderBy(sf.desc('count'))))
    .filter(sf.col('_row') == 1)
    .drop('_row')
    .filter(sf.col('AGE_BUCKET') != 0)
    .orderBy('AGE_BUCKET')
)

In [None]:
results.show(5)

22/04/13 21:41:02 ERROR BroadcastExchangeExec: Could not execute broadcast in 300 secs.
java.util.concurrent.TimeoutException
	at java.base/java.util.concurrent.FutureTask.get(FutureTask.java:204)
	at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.doExecuteBroadcast(BroadcastExchangeExec.scala:195)
	at org.apache.spark.sql.execution.InputAdapter.doExecuteBroadcast(WholeStageCodegenExec.scala:515)
	at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeBroadcast$1(SparkPlan.scala:193)
	at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:218)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:215)
	at org.apache.spark.sql.execution.SparkPlan.executeBroadcast(SparkPlan.scala:189)
	at org.apache.spark.sql.execution.joins.BroadcastHashJoinExec.prepareBroadcast(BroadcastHashJoinExec.scala:116)
	at org.apache.spark.sql.execution.joins.Br

KeyboardInterrupt: 

22/04/13 21:42:46 ERROR BroadcastExchangeExec: Could not execute broadcast in 300 secs.
java.util.concurrent.TimeoutException
	at java.base/java.util.concurrent.FutureTask.get(FutureTask.java:204)
	at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.doExecuteBroadcast(BroadcastExchangeExec.scala:195)
	at org.apache.spark.sql.execution.InputAdapter.doExecuteBroadcast(WholeStageCodegenExec.scala:515)
	at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeBroadcast$1(SparkPlan.scala:193)
	at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:218)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:215)
	at org.apache.spark.sql.execution.SparkPlan.executeBroadcast(SparkPlan.scala:189)
	at org.apache.spark.sql.execution.joins.BroadcastHashJoinExec.prepareBroadcast(BroadcastHashJoinExec.scala:116)
	at org.apache.spark.sql.execution.joins.Br