## Functions for generating medication feature DataFrames based on RxCUI, possible routes, drug name, and the number of days prior to (or after) the index date

In [0]:
%run
"/Users/jennifer.hadlock2@providence.org/Jenn - Hadlock Lab Shared/Clinical_Concepts_New/Load_all_dictionaries"

In [0]:
# Window functions are used to add a row_number, ordered by a selected datetime column, within each pat_id partition
# Use: the row_number helps select the needed row(s) to create the next sub_df
from pyspark.sql.window import Window
from pyspark.sql.functions import col, lower, mean, bround, when, unix_timestamp, datediff, months_between, to_date, lit, dense_rank, regexp_replace, split, isnan, desc, row_number

import numpy as np
import pandas as pd
from pyspark.sql import SparkSession, Column, DataFrameNaFunctions, DataFrameStatFunctions, Row, GroupedData
from pyspark.sql import functions as F  
from pyspark.sql.types import *
from pyspark.ml.feature import Imputer
import os
import sys

# Pandas display options for notebook output: show every row and never truncate cell text.
# pd.set_option('max_columns', None)
pd.set_option('display.max_rows', None, 'display.max_colwidth', None)

##################################################################
## Previous omop table versions:
## 2020-08-05, 2022-02-11
## Input:
## code: an RxNorm concept code stored in the dictionary for the current keyword
## omop_table_version: suffix of the rdp_phi_sandbox.omop_concept_* tables to query
## Output:
## A dataframe of the code and all of its descendant codes, with concept names
##################################################################

## Notice!! Please always confirm the version of the OMOP table you are using, check the data panel to see if there is a newer version

print("Please always confirm the version of the OMOP table you are using, check the data panel to see if there is a newer version!!!")

def get_all_descendant_RxNorm_codes(code, omop_table_version):
    """Return a Spark DataFrame with every OMOP descendant of one RxNorm concept code.

    Columns: ancestor_code, ancestor_concept_name, descendant_code, descendant_concept_name.
    Rows are ordered by min_levels_of_separation, then by descendant concept name.

    code: the RxNorm concept code (str or int). It is sent to SQL as a quoted string literal.
        concept_code holds text (some codes start with "OMOP"), so an unquoted number would
        force Spark to cast every concept_code to a number. That fails under ANSI mode and
        silently yields NULL otherwise.
    omop_table_version: suffix appended to omop_concept_ / omop_concept_ancestor_.
    """
    # Integral floats (e.g. codes read through pandas) must not become '1191.0'
    if isinstance(code, float) and code.is_integer():
        code = int(code)
    # Escape backslashes and single quotes so the code is always a well-formed SQL string literal
    code_literal = str(code).strip().replace("\\", "\\\\").replace("'", "\\'")

    descendant_codes_df = spark.sql(
    """
    SELECT
      oc1.concept_code as ancestor_code,
      oc1.concept_name as ancestor_concept_name,
      oc2.concept_code as descendant_code,
      oc2.concept_name as descendant_concept_name
    FROM (
      SELECT * FROM rdp_phi_sandbox.omop_concept_{version}
      WHERE
        concept_code = '{concept_code}' AND
        vocabulary_id = 'RxNorm') as oc1
    JOIN rdp_phi_sandbox.omop_concept_ancestor_{version} as oca
    ON oc1.concept_id = oca.ancestor_concept_id
    JOIN rdp_phi_sandbox.omop_concept_{version} as oc2
    ON oca.descendant_concept_id = oc2.concept_id
    ORDER BY oca.min_levels_of_separation, oc2.concept_name
    """.format(concept_code=code_literal, version=omop_table_version))
    return descendant_codes_df

In [0]:
## Load rdp_phi.medicationorders
medorder_field = "pat_id, instance, pat_enc_csn_id, medication_id, order_med_id, orderdescription, orderclass, ordermode, orderstatus, orderingdatetime, start_date, end_date, sig, dosage, quantity, route, department_id"
medorder_query = "SELECT %s FROM rdp_phi.medicationorders" % medorder_field
## If there is an issue with the medicationorders table, use the zz_old_medicationorders table instead
# medorder_query = "SELECT %s FROM rdp_phi.zz_old_medicationorders WHERE instance ='1000'" %medorder_field

df_medorder = spark.sql(medorder_query)

###################################################################################################
## Left join with the encounter table on pat_id, pat_enc_csn_id and instance to get the CONTACT_DATE.
## The encounter key columns are renamed first so they can be dropped unambiguously after the join.
encounter_df = spark.sql(""" select distinct pat_id, instance, PAT_ENC_CSN_ID, CONTACT_DATE from rdp_phi.encounter""")
encounter_df = (encounter_df
                .withColumnRenamed('pat_id', 'pat_id2')
                .withColumnRenamed('PAT_ENC_CSN_ID', 'PAT_ENC_CSN_ID2')
                .withColumnRenamed('instance', 'instance2'))
cond = [df_medorder.pat_id == encounter_df.pat_id2,
        df_medorder.pat_enc_csn_id == encounter_df.PAT_ENC_CSN_ID2,
        df_medorder.instance == encounter_df.instance2]

df_medorder = df_medorder.join(encounter_df, cond, "left")
## Drop the renamed encounter keys, which duplicate the medication-order key columns
df_medorder = df_medorder.drop("pat_id2", "instance2", "PAT_ENC_CSN_ID2").dropDuplicates()

#####################################################################################
## Keep one row per medication order (the one with the earliest CONTACT_DATE), then use that
## CONTACT_DATE to fill in as many null start_date values as possible:
## 1. rank the rows of each (pat_id, instance, pat_enc_csn_id, medication_id, order_med_id)
##    partition by contact_date and keep only the first row
## 2. use CONTACT_DATE as the start date when start_date is null and end_date is either
##    unknown or on/after CONTACT_DATE; otherwise keep start_date
##
## Notes: partitionBy accepts a list of columns. Spark's ascending sort puts NULLs first by
## default, so NULLs are explicitly pushed last. Otherwise a row with a null CONTACT_DATE
## could be kept instead of the real earliest contact date.
partition_cols = ["pat_id", "instance", "pat_enc_csn_id", "medication_id", "order_med_id"]
w2 = Window.partitionBy(partition_cols).orderBy(col("contact_date").asc_nulls_last())

df_medorder = (df_medorder
               .withColumn("row", row_number().over(w2))
               .filter(col("row") == 1)
               .drop("row"))

df_medorder = df_medorder.withColumn(
    'new_start_date',
    when(df_medorder.start_date.isNull()
         & (df_medorder.end_date.isNull() | (df_medorder.end_date >= df_medorder.CONTACT_DATE)),
         col('CONTACT_DATE'))
    .otherwise(col('start_date')))

## (Disabled) drop orders that have no orderingdatetime
# df_medorder = df_medorder.where(F.col("orderingdatetime").isNotNull()) 

# Keep only the medication-order columns needed downstream.
# medord_startdate is the start date filled from CONTACT_DATE (new_start_date), not the raw start_date.
# The "medord_endate" spelling is relied on by the functions below, so it is kept as-is.
df_medorder = df_medorder.select(
  col("pat_id").alias("medord_pat_id"), col("medication_id").alias("medord_medication_id"), col("order_med_id").alias("medord_order_med_id"), col("instance"),
  col("orderdescription").alias("medord_description"), col("orderclass"), col("ordermode"), col("orderstatus"),
  col("orderingdatetime").alias("medord_date"), col("new_start_date").alias("medord_startdate"), col("end_date").alias("medord_endate"),
  col("sig"), col("dosage"), col("quantity"), col("route"))

# rdp_phi.medication: medication names, matched to orders on (medication_id, instance)
med_field = "medication_id, instance, name, shortname"
med_query = "SELECT %s FROM rdp_phi.medication" %med_field
df_med = spark.sql(med_query) 

## select required columns only (shortname is loaded but not kept)
df_med = df_med.select(col("medication_id").alias("med_medication_id"), col("instance"), col("name").alias("medication_name"))

# Left join keeps orders without a matching medication row (medication_name is then null).
# The right-hand key columns are dropped so only one "instance" column remains.
meds_join1 = df_medorder.join(df_med, 
                           (df_medorder.medord_medication_id == df_med.med_medication_id) & 
                           (df_medorder.instance == df_med.instance), how = "left").drop(df_med.med_medication_id).drop(df_med.instance)

## Load rdp_phi.medicationrxnorm: RxNorm codes mapped to each (medication_id, instance)
medrxnorm_field = "medication_id, instance, rxnormcode, codelevel, termtype"
medrxnorm_query = "SELECT %s FROM rdp_phi.medicationrxnorm" %medrxnorm_field
df_medrxnorm = spark.sql(medrxnorm_query)  

df_medrxnorm = df_medrxnorm.select(col("medication_id").alias("medrx_medication_id"), col("instance"), col("rxnormcode"), col("codelevel"), col("termtype"))

# Left join: if a medication has several medicationrxnorm rows, the order is repeated once per
# matching row; medications with no mapping keep a null rxnormcode.
meds_join2 = meds_join1.join(df_medrxnorm, 
                           (meds_join1.medord_medication_id == df_medrxnorm.medrx_medication_id) & 
                           (meds_join1.instance == df_medrxnorm.instance), how = "left").drop(df_medrxnorm.medrx_medication_id).drop(df_medrxnorm.instance)

## Load rdp_phi.medicationadministration (department_id is loaded but not kept)
medadmins_field = "pat_id, order_med_id, instance, administrationdatetime, department_id"
## Optional restriction: WHERE instance = '1000'
medadmins_query = "SELECT %s FROM rdp_phi.medicationadministration" %medadmins_field
df_medadmins = spark.sql(medadmins_query)

df_medadmins  = df_medadmins .select(col("pat_id").alias("medadmin_pat_id"), col("order_med_id").alias("medadmin_order_med_id"), col("instance"), col("administrationdatetime").alias("medadmin_date"))

## Join with meds_join2 on (order_med_id, instance): one row per matching administration record,
## or a single row with a null medadmin_date when the order has no administration record
meds_join3 = meds_join2.join(df_medadmins, 
                           (meds_join2.medord_order_med_id == df_medadmins.medadmin_order_med_id) & 
                           (meds_join2.instance == df_medadmins.instance), how = "left").drop(df_medadmins.medadmin_order_med_id).drop(df_medadmins.instance)

# Final projection: medord_pat_id becomes pat_id (medadmin_pat_id is not kept)
meds_join3 = meds_join3.select(col("medord_pat_id").alias("pat_id"), col("instance"), col("medadmin_date"), col("medord_date"), col("medord_startdate"), col("medord_endate"), col("rxnormcode"), col("medication_name"), col("orderclass"), col("ordermode"), col("orderstatus"), col("route"), col("sig"), col("quantity"), col("dosage"), col("codelevel"), col("termtype"))
## Open question: whether to exclude records without actual start/end dates (currently kept)
# meds_join3 = meds_join3.where(col("medord_startdate").isNotNull()) # remove NaN values
# meds_join3 = meds_join3.where(col("medord_endate").isNotNull()) # remove NaN values

meds_join3 = meds_join3.dropDuplicates()
# print('Total records: {}'.format(meds_join3.count()))

In [0]:
# Expose the combined medication table to SQL cells under the name qw_med_join_table
meds_join3.createOrReplaceTempView(name='qw_med_join_table')

In [0]:
######################################################################################################################################################################################################
## Function for searching and storing medication information of prior usage
######################################################################################################################################################################################################
## Inputs:
## pts_df: data frame for the target patient cohort; must contain pat_id and instance
## med_pts_df: medication records for the cohort; must contain pat_id, instance, rxnormcode, medication_name,
##             orderclass, ordermode, route, medord_startdate, medord_endate and the reference date
##             named "decided_index_date"
## keyword_list: keywords used to look up the parent RxNorm codes in medications_cc_registry
##               Check details of accepted keywords in this dictionary: https://adb-3942176072937821.1.azuredatabricks.net/?o=3942176072937821#notebook/1150766708331471/command/1150766708331472
## possible_routes: accepted administration routes (compared case-insensitively)
## med_name: name of the current medication; only used to build column names
## number_days_prior: how many days before the index date an order still counts as prior usage
## only_instance_1k: a boolean; if True, keep only medication records with instance = 1000
## omop_table_version: the version of the OMOP table, check the data panel to see if there is a newer version
#######################################################################################################################################################################################################
## Output:
## pts_df with one added 0/1 column "prior_{number_days_prior}_days_{med_name}_logic"
#######################################################################################################################################################################################################
def add_med_prior_usage_with_drugNames_possibleRoutes_medName(pts_df, med_pts_df, keyword_list, possible_routes, med_name, number_days_prior, only_instance_1k, omop_table_version):
    """Add a 0/1 flag for usage of a medication shortly before the index date.

    The flag column ``prior_{number_days_prior}_days_{med_name}_logic`` is 1 when all of these hold:
    - the patient has an order whose RxNorm code descends from the keyword codes;
    - that order started and ended strictly before the index date;
    - the latest start date or the latest end date is fewer than ``number_days_prior`` days
      before the index date;
    - the aggregated route is null or one of ``possible_routes``.

    Patients without such an order get 0. Only the new column is zero-filled; other columns of
    ``pts_df`` are left untouched. Relies on names defined outside this function:
    ``medications_cc_registry``, ``patches_cc_registry`` and ``get_all_descendant_RxNorm_codes``.
    """
    from datetime import datetime

    ## Acquire the start time to report the run time at the end
    now1 = datetime.now()

    ## 0. Collect every descendant RxNorm code of the parent codes registered for each keyword
    rxnorm_codes = set()
    for item in keyword_list:
        for code in medications_cc_registry[item]:
            temp_df = get_all_descendant_RxNorm_codes(code, omop_table_version)
            # toPandas is faster than flatMap for pulling a single column to the driver
            rxnorm_codes.update(temp_df.select('descendant_code').toPandas()['descendant_code'])

        if item in patches_cc_registry:
            ## Exclude the codes listed for this keyword from everything collected so far
            rxnorm_codes.difference_update(patches_cc_registry[item])
            print("The special treatment to remove unwanted codes from {} is working!".format(item))
        else:
            print("No special treatment to remove codes from {} now, please contact Jenn or the doctor you worked with to confirm.".format(item))

    ## Drop missing codes and the non-standard codes that start with "OMOP"
    new_RxNorm_codes = [c for c in rxnorm_codes if c is not None and not str(c).startswith("OMOP")]

    start_col = "{}_medord_startdate".format(med_name)
    end_col = "{}_medord_endate".format(med_name)

    ## 1. Keep only the records whose RxNorm code belongs to the medication
    med_df = med_pts_df.where(lower(col('rxnormcode')).isin(new_RxNorm_codes))

    ## 2. Convert the index, start and end timestamps to dates
    med_df = (med_df
              .withColumn("index_date", to_date(col("decided_index_date"), "yyyy-MM-dd"))
              .withColumn(start_col, to_date(col("medord_startdate"), "yyyy-MM-dd")).drop(col("medord_startdate"))
              .withColumn(end_col, to_date(col("medord_endate"), "yyyy-MM-dd")).drop(col("medord_endate"))
              .dropDuplicates())

    ## 3. Keep the key columns only
    med_df = med_df.select('pat_id', 'instance', 'rxnormcode', 'medication_name', 'orderclass',
                           'ordermode', 'route', 'index_date', start_col, end_col)

    ## 4. Optionally restrict to instance = 1000
    if only_instance_1k:
        print("Filter to include only instance = 1000 records.")
        med_df = med_df.filter(col("instance") == 1000)
    else:
        print("Include all instance numbers.")

    ## 5. Keep only orders that started AND ended strictly before the index date.
    ## Rows with a null start or end date are dropped by these comparisons as well.
    ## NOTE(review): orders still active on the index date (end date on/after it) are excluded
    ## here too; confirm this is the intended definition of "prior usage".
    med_df_sub = (med_df.dropDuplicates()
                  .filter(col(start_col) < col("index_date"))
                  .filter(col(end_col) < col("index_date")))

    ## 6. Day differences between the order dates and the index date (negative = before the index date).
    ## datediff already returns whole days, so no rounding is needed.
    med_df_sub = (med_df_sub
                  .withColumn("DDays_medstartDt_indexDt", datediff(col(start_col), col("index_date")))
                  .withColumn("DDays_medendDt_indexDt", datediff(col(end_col), col("index_date"))))

    ## 7. Aggregate per patient / index date and derive the flags
    possible_routes = [each.lower() for each in possible_routes]
    flag_col = "prior_{0}_days_{1}".format(number_days_prior, med_name)
    logic_col = "prior_{0}_days_{1}_logic".format(number_days_prior, med_name)
    route_col = "last_{}_route".format(med_name)

    ## NOTE(review): without an explicit ordering, F.last returns an arbitrary non-null value per
    ## group, so the route/orderclass kept here is not guaranteed to come from the latest order.
    med_df_agg = (med_df_sub
                  .groupBy("pat_id", "instance", "index_date")
                  .agg(F.last("orderclass", ignorenulls=True).alias("last_{}_orderclass".format(med_name)),
                       F.last("route", ignorenulls=True).alias(route_col),
                       F.max(start_col).alias("latest_{}_mo_startDt".format(med_name)),
                       F.max("DDays_medstartDt_indexDt").alias("max_DDays_medstartDt_indexDt"),
                       F.max(end_col).alias("latest_{}_mo_endDt".format(med_name)),
                       F.max("DDays_medendDt_indexDt").alias("max_DDays_medendDt_indexDt"))
                  .withColumn(flag_col,
                              when((col("max_DDays_medstartDt_indexDt") > -1 * number_days_prior)
                                   | (col("max_DDays_medendDt_indexDt") > -1 * number_days_prior), F.lit(1))
                              .otherwise(F.lit(0)))
                  .withColumn(logic_col,
                              when((col(flag_col) == 1)
                                   & (col(route_col).isNull() | lower(col(route_col)).isin(possible_routes)), F.lit(1))
                              .otherwise(F.lit(0))))

    ## 8. Attach the flag to the cohort. Patients without a matching order get 0.
    ## Only the new column is filled, so NULLs in unrelated pts_df columns are preserved.
    ## NOTE(review): med_df_agg has one row per (pat_id, instance, index_date); if a patient can
    ## have several index dates, this join on (pat_id, instance) duplicates rows -- confirm.
    tmp = med_df_agg.select("pat_id", "instance", logic_col)
    pts_df = pts_df.join(tmp, on=['pat_id', 'instance'], how='left').fillna(0, subset=[logic_col])

    ## coalesce for optimization
    pts_df = pts_df.coalesce(10)

    td_mins = int(round((datetime.now() - now1).total_seconds() / 60))
    print("{0} finished, used approx. {1} minutes".format(med_name, td_mins))

    return pts_df

In [0]:
######################################################################################################################################################################################################
## Function for searching and storing medication information of usage after the index date
######################################################################################################################################################################################################
## Inputs:
## pts_df: data frame for the target patient cohort; must contain pat_id and instance
## med_pts_df: medication records for the cohort; must contain pat_id, instance, rxnormcode, medication_name,
##             orderclass, ordermode, route, medord_startdate, medord_endate and the reference date
##             named "decided_index_date"
## keyword_list: keywords used to look up the parent RxNorm codes in medications_cc_registry
##               Check details of accepted keywords in this dictionary: https://adb-3942176072937821.1.azuredatabricks.net/?o=3942176072937821#notebook/1150766708331471/command/1150766708331472
## possible_routes: accepted administration routes (compared case-insensitively)
## med_name: name of the current medication; only used to build column names
## number_days_after: how many days after the index date an order still counts as usage
## only_instance_1k: a boolean; if True, keep only medication records with instance = 1000
## omop_table_version: the version of the OMOP table, check the data panel to see if there is a newer version
#######################################################################################################################################################################################################
## Output:
## pts_df with one added 0/1 column "after_{number_days_after}_days_{med_name}_logic"
#######################################################################################################################################################################################################
def add_med_after_usage_with_drugNames_possibleRoutes_medName(pts_df, med_pts_df, keyword_list, possible_routes, med_name, number_days_after, only_instance_1k, omop_table_version):
    """Add a 0/1 flag for usage of a medication shortly after the index date.

    The flag column ``after_{number_days_after}_days_{med_name}_logic`` is 1 when all of these hold:
    - the patient has an order whose RxNorm code descends from the keyword codes;
    - that order started and ended on or after the index date;
    - the earliest start date or the earliest end date is at most ``number_days_after`` days
      after the index date;
    - the aggregated route is null or one of ``possible_routes``.

    Patients without such an order get 0. Only the new column is zero-filled; other columns of
    ``pts_df`` are left untouched. Relies on names defined outside this function:
    ``medications_cc_registry``, ``patches_cc_registry`` and ``get_all_descendant_RxNorm_codes``.
    """
    from datetime import datetime

    ## Acquire the start time to report the run time at the end
    now1 = datetime.now()

    ## 0. Collect every descendant RxNorm code of the parent codes registered for each keyword
    rxnorm_codes = set()
    for item in keyword_list:
        for code in medications_cc_registry[item]:
            temp_df = get_all_descendant_RxNorm_codes(code, omop_table_version)
            # toPandas is faster than flatMap for pulling a single column to the driver
            rxnorm_codes.update(temp_df.select('descendant_code').toPandas()['descendant_code'])

        if item in patches_cc_registry:
            ## Exclude the codes listed for this keyword from everything collected so far
            rxnorm_codes.difference_update(patches_cc_registry[item])
            print("The special treatment to remove unwanted codes from {} is working!".format(item))
        else:
            print("No special treatment to remove codes from {} now, please contact Jenn or the doctor you worked with to confirm.".format(item))

    ## Drop missing codes and the non-standard codes that start with "OMOP"
    new_RxNorm_codes = [c for c in rxnorm_codes if c is not None and not str(c).startswith("OMOP")]

    start_col = "{}_medord_startdate".format(med_name)
    end_col = "{}_medord_endate".format(med_name)

    ## 1. Keep only the records whose RxNorm code belongs to the medication
    med_df = med_pts_df.where(lower(col('rxnormcode')).isin(new_RxNorm_codes))

    ## 2. Convert the index, start and end timestamps to dates
    med_df = (med_df
              .withColumn("index_date", to_date(col("decided_index_date"), "yyyy-MM-dd"))
              .withColumn(start_col, to_date(col("medord_startdate"), "yyyy-MM-dd")).drop(col("medord_startdate"))
              .withColumn(end_col, to_date(col("medord_endate"), "yyyy-MM-dd")).drop(col("medord_endate"))
              .dropDuplicates())

    ## 3. Keep the key columns only
    med_df = med_df.select('pat_id', 'instance', 'rxnormcode', 'medication_name', 'orderclass',
                           'ordermode', 'route', 'index_date', start_col, end_col)

    ## 4. Optionally restrict to instance = 1000
    if only_instance_1k:
        print("Filter to include only instance = 1000 records.")
        med_df = med_df.filter(col("instance") == 1000)
    else:
        print("Include all instance numbers.")

    ## 5. Keep only orders that started AND ended on or after the index date.
    ## Rows with a null start or end date are dropped by these comparisons as well.
    med_df_sub = (med_df.dropDuplicates()
                  .filter(col(start_col) >= col("index_date"))
                  .filter(col(end_col) >= col("index_date")))

    ## 6. Day differences between the order dates and the index date (>= 0 after step 5).
    ## datediff already returns whole days, so no rounding is needed.
    med_df_sub = (med_df_sub
                  .withColumn("DDays_medstartDt_indexDt", datediff(col(start_col), col("index_date")))
                  .withColumn("DDays_medendDt_indexDt", datediff(col(end_col), col("index_date"))))

    ## 7. Aggregate per patient / index date and derive the flags.
    ## The minimum (closest to the index date) is used because only orders after the index date remain.
    ## The aggregates are aliased explicitly so the flag refers to the columns that actually exist
    ## (previously it referenced non-existent "max(...)" columns and failed).
    possible_routes = [each.lower() for each in possible_routes]
    flag_col = "after_{0}_days_{1}".format(number_days_after, med_name)
    logic_col = "after_{0}_days_{1}_logic".format(number_days_after, med_name)
    route_col = "closest_{}_route".format(med_name)

    ## NOTE(review): without an explicit ordering, F.first returns an arbitrary non-null value per
    ## group, so the route/orderclass kept here is not guaranteed to come from the earliest order.
    med_df_agg = (med_df_sub
                  .groupBy("pat_id", "instance", "index_date")
                  .agg(F.first("orderclass", ignorenulls=True).alias("closest_{}_orderclass".format(med_name)),
                       F.first("route", ignorenulls=True).alias(route_col),
                       F.min(start_col).alias("closest_{}_mo_startDt".format(med_name)),
                       F.min("DDays_medstartDt_indexDt").alias("min_DDays_medstartDt_indexDt"),
                       F.min(end_col).alias("closest_{}_mo_endDt".format(med_name)),
                       F.min("DDays_medendDt_indexDt").alias("min_DDays_medendDt_indexDt"))
                  .withColumn(flag_col,
                              when((col("min_DDays_medstartDt_indexDt") <= 1 * number_days_after)
                                   | (col("min_DDays_medendDt_indexDt") <= 1 * number_days_after), F.lit(1))
                              .otherwise(F.lit(0)))
                  .withColumn(logic_col,
                              when((col(flag_col) == 1)
                                   & (col(route_col).isNull() | lower(col(route_col)).isin(possible_routes)), F.lit(1))
                              .otherwise(F.lit(0))))

    ## 8. Attach the flag to the cohort. Patients without a matching order get 0.
    ## Only the new column is filled, so NULLs in unrelated pts_df columns are preserved.
    ## NOTE(review): med_df_agg has one row per (pat_id, instance, index_date); if a patient can
    ## have several index dates, this join on (pat_id, instance) duplicates rows -- confirm.
    tmp = med_df_agg.select("pat_id", "instance", logic_col)
    pts_df = pts_df.join(tmp, on=['pat_id', 'instance'], how='left').fillna(0, subset=[logic_col])

    ## coalesce for optimization
    pts_df = pts_df.coalesce(10)

    td_mins = int(round((datetime.now() - now1).total_seconds() / 60))
    print("{0} finished, used approx. {1} minutes".format(med_name, td_mins))

    return pts_df