## Functions for generating a condition DataFrame based on SNOMED codes and decided index dates

In [0]:
%run
"/Users/jennifer.hadlock2@providence.org/Jenn - Hadlock Lab Shared/Clinical_Concepts_New/Load_all_dictionaries"

In [0]:
## Import necessary packages and functions
from pyspark.sql.window import Window
from pyspark.sql.functions import lower, col, lit, when, unix_timestamp, months_between, expr, countDistinct, count, row_number, concat
from datetime import datetime
import pyspark.sql.functions as F
from pyspark.sql.types import StructType, FloatType

##################################################################
## Previous omop table versions:
## 2020-08-05, 2022-02-11
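## Input:
## code: a SNOMED concept code taken from the dictionary entry for the current keyword
## omop_table_version: the version suffix of the OMOP tables to query
## Output:
## A dataframe pairing the given (ancestor) code and name with each of its descendant codes and names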
##################################################################

## Notice!! Always confirm which version of the OMOP table you are using; check the data panel to see if there is a newer version.
## The reminder below is printed whenever this cell is executed.

print("Please always confirm the version of the OMOP table you are using, check the data panel to see if there is a newer version!!!")

def get_all_descendant_snomed_codes(code, omop_table_version):
  """Look up a SNOMED concept and the OMOP descendant concepts recorded for it.

  Inputs:
    code: SNOMED concept code of the ancestor concept (string or integer).
    omop_table_version: suffix appended to the rdp_phi_sandbox.omop_concept_ and
      rdp_phi_sandbox.omop_concept_ancestor_ table names.

  Output:
    A Spark dataframe with the columns ancestor_snomed_code, ancestor_concept_name,
    descendant_snomed_code and descendant_concept_name, ordered by
    min_levels_of_separation and then by descendant concept name.
  """
  ## The code is embedded as a quoted string literal: left unquoted, the string column
  ## concept_code is compared through an implicit numeric cast, and a non-numeric
  ## code would be parsed as a column name and make the query fail.
  ## Surrounding whitespace is trimmed because it now ends up inside the quotes.
  snomed_code_literal = str(code).strip()

  descendant_snomed_codes_df = spark.sql(
  """
  SELECT
    oc1.concept_code as ancestor_snomed_code,
    oc1.concept_name as ancestor_concept_name,
    oc2.concept_code as descendant_snomed_code,
    oc2.concept_name as descendant_concept_name
  FROM (
    SELECT * FROM rdp_phi_sandbox.omop_concept_{version}
    WHERE
      concept_code = '{snomed_code}' AND
      vocabulary_id = 'SNOMED') as oc1
  JOIN rdp_phi_sandbox.omop_concept_ancestor_{version} as oca
  ON oc1.concept_id = oca.ancestor_concept_id
  JOIN rdp_phi_sandbox.omop_concept_{version} as oc2
  ON oca.descendant_concept_id = oc2.concept_id
  ORDER BY min_levels_of_separation, oc2.concept_name
  """.format(snomed_code=snomed_code_literal, version=omop_table_version))
  return descendant_snomed_codes_df

In [0]:
####################################################################################################################################################
## Load the needed problem-list columns; null NOTED_DATE values are back-filled from DATE_OF_ENTRY below
problemlist_df = spark.sql(""" select distinct pat_id, dx_id, instance, NOTED_DATE, RESOLVED_DATE, DATE_OF_ENTRY from rdp_phi.problemlist""")

## DATE_OF_ENTRY replaces a null NOTED_DATE only when RESOLVED_DATE is also null or is on/after DATE_OF_ENTRY
noted_date_missing = col('NOTED_DATE').isNull()
entry_not_after_resolution = col('RESOLVED_DATE').isNull() | (col('RESOLVED_DATE') >= col('DATE_OF_ENTRY'))
backfilled_noted_date = when(noted_date_missing & entry_not_after_resolution, col('DATE_OF_ENTRY')).otherwise(col('NOTED_DATE'))

## Keep only the needed columns, de-duplicate, and expose the back-filled value as "noted_date" so the following code keeps working
select_cols = ("pat_id", "dx_id", "instance", "RESOLVED_DATE", "new_noted_date")
problemlist_df = (
  problemlist_df
  .withColumn('new_noted_date', backfilled_noted_date)
  .select(*select_cols)
  .dropDuplicates()
  .withColumnRenamed('new_noted_date', 'noted_date')
)

####################################################################################################################################################
## Load the needed columns from the encounter diagnosis table
encounterdiagnosis_df = spark.sql(""" select distinct pat_id, instance, dx_id, PAT_ENC_CSN_ID, DIAGNOSISNAME from rdp_phi.encounterdiagnosis""")

################################################################################
## Full outer join with the problem list to get a more detailed diagnosis table
## The encounter-side key columns get a "2" suffix so both sides stay addressable after the join
for key_col in ('pat_id', 'dx_id', 'instance'):
  encounterdiagnosis_df = encounterdiagnosis_df.withColumnRenamed(key_col, key_col + '2')

key_matches = [problemlist_df[key_col] == encounterdiagnosis_df[key_col + '2'] for key_col in ('pat_id', 'dx_id', 'instance')]
prob_enc_diagnosis_df = problemlist_df.join(encounterdiagnosis_df, on=key_matches, how="fullouter")

## Merge each key pair into one "<key>_merge" column: take the encounter-side value when the problem-list side is null
for key_col in ('pat_id', 'dx_id', 'instance'):
  prob_enc_diagnosis_df = prob_enc_diagnosis_df.withColumn(
    key_col + '_merge',
    when(col(key_col).isNull(), col(key_col + '2')).otherwise(col(key_col)))

## Keep only the needed columns and drop duplicate rows
select_cols = ("pat_id_merge", "dx_id_merge", "instance_merge", "PAT_ENC_CSN_ID", "NOTED_DATE", "RESOLVED_DATE", "DIAGNOSISNAME")
prob_enc_diagnosis_df = prob_enc_diagnosis_df.select(*select_cols).dropDuplicates()

###################################################################################################
## Left join with the encounter table on patient, encounter CSN and instance to get the contact_date
encounter_df = (
  spark.sql(""" select distinct pat_id, instance, PAT_ENC_CSN_ID, CONTACT_DATE from rdp_phi.encounter""")
  .withColumnRenamed('pat_id', 'pat_id2')
  .withColumnRenamed('PAT_ENC_CSN_ID', 'PAT_ENC_CSN_ID2')
)

## All three equalities must hold for an encounter row to be attached to a diagnosis row
same_patient = prob_enc_diagnosis_df['pat_id_merge'] == encounter_df['pat_id2']
same_encounter = prob_enc_diagnosis_df['PAT_ENC_CSN_ID'] == encounter_df['PAT_ENC_CSN_ID2']
same_instance = prob_enc_diagnosis_df['instance_merge'] == encounter_df['instance']
cond = [same_patient, same_encounter, same_instance]

prob_enc_diag_enc_df = prob_enc_diagnosis_df.join(encounter_df, on=cond, how="left")

## Keep only the needed columns and drop duplicate rows
select_cols = ("pat_id_merge", "dx_id_merge", "instance_merge", "CONTACT_DATE", "NOTED_DATE", "RESOLVED_DATE", "DIAGNOSISNAME")
prob_enc_diag_enc_df = prob_enc_diag_enc_df.select(*select_cols).dropDuplicates()

#########################################################################################################
## Left join with the external concept mapping table to attach the concept mapped to each diagnosis
externalmapping_df = spark.sql(""" select distinct value, name, instance, concept from rdp_phi.externalconceptmapping""")

## A mapping row applies when its value equals the diagnosis id and it belongs to the same instance
cond = [prob_enc_diag_enc_df['dx_id_merge'] == externalmapping_df['value'],
        prob_enc_diag_enc_df['instance_merge'] == externalmapping_df['instance']]

## Join, keep only the needed columns, and drop duplicate rows in one chain
select_cols = ("pat_id_merge", "dx_id_merge", "instance_merge", "CONTACT_DATE", "NOTED_DATE", "RESOLVED_DATE", "DIAGNOSISNAME", "name", "concept")
prob_enc_diag_enc_concept_df = (
  prob_enc_diag_enc_df
  .join(externalmapping_df, on=cond, how="left")
  .select(*select_cols)
  .dropDuplicates()
)

####################################################################################
## Concatenate the diagnosisname and name columns into full_diagnosis_name:
## diagnosisname is used alone when it already contains name, otherwise the two are joined with "_"
## NOTE(review): concat presumably yields null when either input is null, so rows lacking
## diagnosisname or name would get a null full_diagnosis_name -- confirm this is intended
encounter_name = col('diagnosisname')
mapping_name = col('name')
combined_name = when(encounter_name.contains(mapping_name), encounter_name)\
                .otherwise(concat(encounter_name, lit('_'), mapping_name))
df_full_diagnosis_name = prob_enc_diag_enc_concept_df.withColumn('full_diagnosis_name', combined_name)

## The two source name columns are no longer needed
df_full_diagnosis_name = df_full_diagnosis_name.drop('diagnosisname', 'name')

#####################################################################################
## Use contact_date to fill in as many null rows as possible in the noted_date column
## The logic is to:
## 1. use a window to keep, for each combination of pat_id, dx_id, instance and diagnosis name,
##    the row with the earliest known contact date
## 2. use that contact date to fill in noted date, if resolved date is also unknown or >= contact_date
##
## Notes: partitionBy can take a list of cols as input. orderBy is ascending by default, but an
## ascending sort in Spark places NULLs first, so NULL contact dates are explicitly sorted last;
## otherwise a row without any contact date would be kept instead of the earliest real date.
## NOTE(review): only one row per partition survives and the partition key does not include
## concept, NOTED_DATE or RESOLVED_DATE -- confirm that discarding the other rows is intended.
partition_cols = ["pat_id_merge", "dx_id_merge", "instance_merge", "full_diagnosis_name"]
w2 = Window.partitionBy(partition_cols).orderBy(col("contact_date").asc_nulls_last())

df_full_diagnosis_name = df_full_diagnosis_name.withColumn("row", row_number().over(w2)) \
                          .filter(col("row") == 1).drop("row")

## CONTACT_DATE replaces a null NOTED_DATE only when RESOLVED_DATE is also null or is on/after CONTACT_DATE
noted_date_missing = col('NOTED_DATE').isNull()
contact_not_after_resolution = col('RESOLVED_DATE').isNull() | (col('RESOLVED_DATE') >= col('CONTACT_DATE'))
df_full_diagnosis_name = df_full_diagnosis_name.withColumn(
  'new_noted_date',
  when(noted_date_missing & contact_not_after_resolution, col('CONTACT_DATE')).otherwise(col('NOTED_DATE')))

###############################################################################################
## Rename pat_id_merge, dx_id_merge, instance_merge, full_diagnosis_name for the code that follows
df_full_diagnosis_name = df_full_diagnosis_name.withColumnRenamed('pat_id_merge', 'pat_id').withColumnRenamed('dx_id_merge', 'dx_id').withColumnRenamed('instance_merge', 'instance')\
                                                .withColumnRenamed('full_diagnosis_name', 'diagnosis_name')

## Select only the needed columns, then rename the filled column back to "noted_date" so the following code keeps working
select_cols = ("pat_id", "dx_id", "instance", "diagnosis_name", "new_noted_date", "RESOLVED_DATE", "concept")
df_full_diagnosis_name = df_full_diagnosis_name.select(*select_cols).dropDuplicates()
diag_snomed_df = df_full_diagnosis_name.withColumnRenamed('new_noted_date', 'noted_date')

In [0]:
## Keep only the rows whose concept text contains "snomed" (case-insensitive)
diag_snomed_df = diag_snomed_df.filter(lower(col('concept')).contains('snomed'))

## Remove the prefix: SNOMED_ids is the concept text from the 8th character onward
## NOTE(review): assumes the concept prefix is exactly 7 characters long -- confirm against the data
diag_snomed_df1 = (
  diag_snomed_df
  .withColumn("SNOMED_ids", expr("substring(concept, 8, length(concept))"))
  ## expose RESOLVED_DATE under the lowercase name resolved_date
  .withColumnRenamed("RESOLVED_DATE", "resolved_date")
)

## Register the result as the temp view qw_diag_snomed so it can be queried with spark.sql
diag_snomed_df1.createOrReplaceTempView('qw_diag_snomed')

In [0]:
######################################################################################################################################################################################################
## Function to find the presence of risk factors based on the SNOMED codes.
######################################################################################################################################################################################################

def add_risk_factors_active_at_decided_index_date(df, list_of_risk_factors, only_instance_1k, omop_table_version):
  """Add one binary column per risk factor, flagging whether it is active at decided_index_date.

  For every risk factor keyword the parent SNOMED codes are read from
  conditions_cc_registry, expanded with get_all_descendant_snomed_codes, reduced by any
  codes listed for that keyword in patches_cc_registry, and matched against the
  qw_diag_snomed view. A diagnosis record counts as active for a cohort row when its
  noted_date is null or on/before decided_index_date and its resolved_date is null or
  on/after decided_index_date.

  Inputs:
    df: data frame for the target patient cohort. It is joined on pat_id and instance,
        and the reference date must be provided in a column named "decided_index_date".
    list_of_risk_factors: keyword list for finding risk factors' parent SNOMED codes.
        Check details of accepted keywords in this dictionary:
        https://adb-3942176072937821.1.azuredatabricks.net/?o=3942176072937821#notebook/1150766708331471
    only_instance_1k: a boolean controlling whether the diagnosis records are limited
        to instance = 1000.
    omop_table_version: the version of the OMOP table; check the data panel to see if
        there is a newer version.

  Output:
    The cohort data frame with each risk factor added as a binary feature
    (1 = active at decided_index_date, 0 = otherwise). Every input row appears exactly
    once in the result.
  """
  ## Imported locally because only this function needs it
  from pyspark.sql.functions import monotonically_increasing_id

  def is_not_omop_code(snomed_code):
    ## Codes starting with "OMOP" are removed from the code list
    return not snomed_code.startswith("OMOP")

  ## Tag every cohort row with a temporary id. The left join below emits one output row
  ## per matching diagnosis record, so a patient with several active records would be
  ## duplicated (and the duplication would compound across risk factors); the id lets
  ## us collapse those copies without touching rows that were distinct on input.
  row_id_col = "__risk_factor_row_id"
  df = df.withColumn(row_id_col, monotonically_increasing_id())

  for risk_factor in list_of_risk_factors:
    ## Acquire current time, and convert to the right format
    now1 = datetime.now()
    start_time = now1.strftime("%H:%M:%S")

    print("The current risk factor in processing is {0}, start at {1}".format(risk_factor, start_time))
    codes = conditions_cc_registry[risk_factor]

    ## Gather the descendants of every parent code; a set removes duplicates as we go
    descendant_codes = set()
    for code in codes:
      temp_df = get_all_descendant_snomed_codes(code, omop_table_version)
      ## toPandas converts the pyspark column into local values (faster than using flatmap)
      descendant_codes.update(temp_df.select('descendant_snomed_code').toPandas()['descendant_snomed_code'])
    diagnosis_noDuplicate_id_list = list(descendant_codes)

    if risk_factor in patches_cc_registry:
      list_to_ignore = patches_cc_registry[risk_factor]
      ## Exclude snomed codes from the list_to_ignore
      diagnosis_noDuplicate_id_list = [ele for ele in diagnosis_noDuplicate_id_list if ele not in list_to_ignore]
      print("The special treatment to remove unwanted codes from {} is working!".format(risk_factor))
    else:
      print("No special treatment to remove codes from {} now, please contact Jenn or the doctor you worked with to confirm.".format(risk_factor))

    ## Remove all strange strings start with OMOP; sorting keeps the generated SQL reproducible
    new_diagnosis_codes = sorted(filter(is_not_omop_code, diagnosis_noDuplicate_id_list))

    diagnosis_ids = "','".join(new_diagnosis_codes)

    ## One row per (pat_id, instance, noted_date, resolved_date) that carries a matching code;
    ## the inner joins keep only dx_ids that are present in diagnosismapping and diagnosis
    tmp =  spark.sql(
    """
    SELECT DISTINCT qw_diag_snomed.pat_id,qw_diag_snomed.instance, qw_diag_snomed.noted_date, qw_diag_snomed.resolved_date, IF(COUNT(diagnosis.dx_id) > 0, 1, 0) AS """ + risk_factor + """ 
    FROM ((qw_diag_snomed
      INNER JOIN rdp_phi.diagnosismapping ON qw_diag_snomed.dx_id = diagnosismapping.dx_id)
      INNER JOIN rdp_phi.diagnosis ON rdp_phi.diagnosismapping.dx_id = rdp_phi.diagnosis.dx_id)
    WHERE qw_diag_snomed.SNOMED_ids in ('""" + diagnosis_ids + """')
    GROUP BY qw_diag_snomed.pat_id,qw_diag_snomed.instance, qw_diag_snomed.noted_date, qw_diag_snomed.resolved_date
    """
    )

    if only_instance_1k:
      print("Filter to include only instance = 1000 records.")
      tmp = tmp.filter(tmp.instance == 1000)
    else:
      print("Include all instance numbers.")

    ## Same patient and instance, and the index date falls inside the [noted_date, resolved_date]
    ## interval, where a null bound is treated as open
    cond = (df.pat_id == tmp.pat_id) & (df.instance == tmp.instance) &\
           ( ( (df.decided_index_date >= tmp.noted_date) | (tmp.noted_date.isNull()) ) &\
           ( (df.decided_index_date <= tmp.resolved_date) | (tmp.resolved_date.isNull()) ) )

    ## Every copy of a cohort row produced by the join carries the same flag value, so
    ## keeping one row per temporary id restores the cohort without losing information
    df = df.join(tmp, cond, how='left')\
           .drop(tmp.pat_id).drop(tmp.instance).drop(tmp.noted_date).drop(tmp.resolved_date)\
           .fillna({risk_factor: 0})\
           .dropDuplicates([row_id_col])

    ## Acquire current time, and report the elapsed time in minutes
    now2 = datetime.now()
    td = now2 - now1
    td_mins = int(round(td.total_seconds() / 60))

    print("{0} finished, used approx. {1} minutes".format(risk_factor, td_mins))

  ## The temporary id is an implementation detail and is not returned to the caller
  return df.drop(row_id_col)

In [0]:
# # pts_concat_spark_df = pts_concat_spark_df.withColumn("decided_index_date", pts_concat_spark_df.PsC_date)

# # ## Add back instance column, if you already included instance then skip this part
# # pts_concat_spark_df = pts_concat_spark_df.withColumn('instance', lit(1000))

# only_instance_1k = True
# ##omop_table_version
# omop_table_version = "2022_06_27"

# list_of_risk_factors = [
#                         ## Comorbidities
#                         'diabetes_type1and2']

# condition_df = add_risk_factors_active_at_decided_index_date(pts_concat_spark_df, list_of_risk_factors, only_instance_1k, omop_table_version)