# ML Task - Predicting LOS using window and PySpark

Start with the necessary imports...

In [6]:
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import lag, col, avg
import os

Define necessary PySpark env variables.

In [2]:
#os.environ['JAVA_HOME'] = '/usr/lib/jvm/java-17-openjdk-amd64'
os.environ['JAVA_HOME'] = '/usr/lib/jvm/java-21-openjdk-21.0.3.0.9-1.fc40.x86_64/'
os.environ['SPARK_LOCAL_IP'] = '127.0.0.1'
os.environ['SPARK_MASTER_HOST'] = 'localhost'

Building the spark session...

In [3]:
spark = SparkSession.builder \
    .appName("Setup") \
    .master("local[*]") \
    .config("spark.driver.memory", "4g") \
    .config("spark.executor.memory", "4g") \
    .config("spark.driver.host","127.0.0.1") \
    .config("spark.driver.bindAddress","127.0.0.1") \
    .getOrCreate()
sc = spark.sparkContext

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/06/16 11:40:43 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/06/16 11:40:43 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


Read the previously processed data.

In [4]:
file_path = "dataset/DS.parquet"

ds = spark.read.format("parquet") \
    .option("header", "true") \
    .option("inferSchema", "true") \
    .load(file_path)

ds.printSchema()
ds.show(5)

root
 |-- HADM_ID: integer (nullable = true)
 |-- SUBJECT_ID: integer (nullable = true)
 |-- ICUSTAY_ID: integer (nullable = true)
 |-- LOS: double (nullable = true)
 |-- SEQ_NUM: integer (nullable = true)
 |-- total_events: long (nullable = true)
 |-- GENDER: string (nullable = true)
 |-- MULTIPLE_ADMISSIONS: boolean (nullable = true)
 |-- MULTIPLE_ICU_STAYS: boolean (nullable = true)
 |-- ADMISSION_TYPE: string (nullable = true)
 |-- AGE_ATE_ADMISSION: integer (nullable = true)
 |-- ADMITTIME: timestamp (nullable = true)

+-------+----------+----------+------+-------+------------+------+-------------------+------------------+--------------+-----------------+-------------------+
|HADM_ID|SUBJECT_ID|ICUSTAY_ID|   LOS|SEQ_NUM|total_events|GENDER|MULTIPLE_ADMISSIONS|MULTIPLE_ICU_STAYS|ADMISSION_TYPE|AGE_ATE_ADMISSION|          ADMITTIME|
+-------+----------+----------+------+-------+------------+------+-------------------+------------------+--------------+-----------------+--------------

Lets define a window, and we´ll order the window by admission time.

In [7]:
window_spec = Window.orderBy(col('ADMITTIME')).rowsBetween(-7, 0)

We'll use the lag function to look at previous records of LOS, as this can potentially provide some context.

In [14]:
#WIP - 
avg_los_past_3 = lag('LOS', 2).over(window_spec).cast('double').alias('avg_los_past_3')

print(avg_los_past_3)


Column<'CAST(lag(LOS, 2, NULL) OVER (ORDER BY ADMITTIME ASC NULLS FIRST ROWS BETWEEN -7 FOLLOWING AND CURRENT ROW) AS DOUBLE) AS avg_los_past_3'>
