##### Project: Opioid Exposed Infant Covariates
##### Investigator: Stephen Patrick, Sarah Loch
##### Programmers: Sander Su, Chris Guardo
##### Date Created: 01/17/23
##### Last Modified: 09/30/25
##### Notes:
THIS NOTEBOOK REQUIRES ALTERATIONS. Specifically, the user must complete the "Table definitions" cell, which sets the schema locations of the OMOP data and of the reference tables. Reference tables are available in the GitHub repository.

###### This file includes MPRINT related libraries:

- Change Column name case (lower or upper)
- Dataframe inspection (unique patient and total records; local and global)
- EGA calculation to get standardized EGA days
- Read CSV file
- Remove widget
- Get view information
- Aggregation Summary (total_count, min, max, mean and median)
- Get mom_baby records with code list
- Get phenotype_table
- Combine two column values into a list
- Union dataframes
- Cohort definitions (phenotype,baby and mom)
- Output df to parquet
- Get drug concept id and name
- DB Summary
- Get search term list
- Rename column
- Time Travel for different table versions

In [0]:
import warnings #ignore warnings
warnings.filterwarnings('ignore')
import pandas as pd # for saving results from SQL queries to pandas DFs
import numpy as np # for math operations
import re # for regular expression
import os
import shutil
import datetime as dt
import matplotlib.pyplot as plt
import matplotlib.ticker as tick
import functools

from datetime import datetime, timedelta
from functools import reduce
from pyspark.sql import DataFrame
from pyspark.sql.functions import *
import pyspark.sql.functions as func
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import year
from pyspark.sql.functions import to_date
from pyspark.sql.types import *
from pyspark.sql.functions import countDistinct

# Let Spark pick the number of shuffle partitions.
# NOTE(review): "auto" is a Databricks-specific value — confirm the target runtime supports it.
spark.conf.set("spark.sql.shuffle.partitions", "auto")

# Controls what df_inspection() displays: 'df' (10-row preview), 'count' (row / distinct person counts) or 'both'.
inspect_df_flag = 'count'
# When True, register_parquet_global_view() rewrites the parquet file even if it already exists.
output_overwrite = True
# spark.conf.get("spark.databricks.io.cache.enabled")

In [0]:
###### Table definitions
# Catalog/schema that holds the OMOP tables referenced below; the user must fill this in before running.
OMOP_database_location = "***FILL THESE IN***"
# Catalog/schema that holds the project reference tables (e.g. hci_pat_result); the user must fill this in.
# NOTE: the variable name is misspelled ("refernce") but other cells use it, so it is kept as-is.
refernce_table_database = "***FILL THESE IN***"

In [0]:
# Fully-qualified names of the tables queried throughout this notebook.
# All but the last live in OMOP_database_location.
proc_table=f"{OMOP_database_location}.procedure_occurrence"
cond_table=f"{OMOP_database_location}.condition_occurrence"
obs_table=f"{OMOP_database_location}.observation"
person_table=f"{OMOP_database_location}.person"
drug_exp_table=f"{OMOP_database_location}.drug_exposure"
# x_-prefixed tables look like site-specific extensions rather than standard OMOP tables — NOTE(review): confirm.
drug_exp_table_extra=f"{OMOP_database_location}.x_drug_exposure"
meas_table=f"{OMOP_database_location}.measurement"
note_table=f"{OMOP_database_location}.note"
visit_table=f"{OMOP_database_location}.visit_occurrence"
death_table=f"{OMOP_database_location}.death"
concept_table=f"{OMOP_database_location}.concept"
location_table = f"{OMOP_database_location}.location"
location_history_table = f"{OMOP_database_location}.location_history"
clin_doc_form_table = f"{OMOP_database_location}.x_clin_doc_form_fields"

fact_real_table=f"{OMOP_database_location}.fact_relationship"
insurance_table=f"{OMOP_database_location}.x_enc_insurance"
smoking_status_table=f"{OMOP_database_location}.x_smoking_status"


# Flowsheet results come from the reference database, not the OMOP schema.
flowsheet_table=f"{refernce_table_database}.hci_pat_result"

In [0]:
##### Change Column name case (lower or upper)


In [0]:
def change_colname_case(df, case_type):
    """Return *df* with every column name converted to upper or lower case.

    Args:
        df: DataFrame to rename (anything exposing ``columns`` and
            ``withColumnRenamed``).
        case_type: 'upper' or 'lower'. Any other value leaves the names unchanged.

    Returns:
        The DataFrame with renamed columns.
    """
    converters = {'upper': str.upper, 'lower': str.lower}
    convert = converters.get(case_type)
    if convert is None:
        return df

    renamed = df
    # Iterate over the original column list; each rename yields a new DataFrame.
    for name in df.columns:
        renamed = renamed.withColumnRenamed(name, convert(name))
    return renamed

In [0]:
##### Dataframe inspection (unique patient and total records; local and global)


In [0]:
def df_inspection(df_name, sub_group):
    """Display a preview and/or row and distinct-person counts for a table or view.

    The notebook-level ``inspect_df_flag`` selects the output:
    'df' shows the first 10 rows, 'count' shows the counts, 'both' shows both
    (preview first).

    Args:
        df_name: Name of a table/view that ``spark.sql`` can query.
        sub_group: 'all' to count distinct ``mom_person_id`` and ``baby_person_id``.
            Otherwise a prefix such that ``<sub_group>_person_id`` is a column
            (e.g. 'mom' or 'baby').
    """
    show_preview = inspect_df_flag in ('df', 'both')
    show_counts = inspect_df_flag in ('count', 'both')

    if show_preview:
        display(spark.sql(f"select * from {df_name} limit 10;"))

    if show_counts:
        if sub_group == 'all':
            count_sql = f"select count(*) AS total, count(distinct mom_person_id) AS unique_mom, count(distinct baby_person_id) AS unique_baby from {df_name};"
        else:
            count_sql = f"select count(*) AS total ,count(distinct {sub_group}_person_id) AS unique_{sub_group} from {df_name};"
        display(spark.sql(count_sql))
       

In [0]:
##### EGA calculation to get standardized EGA days

In [0]:
def cal_ega(SparkDF,ega_col_name):
   ega_df = SparkDF.toPandas() ### use pandas function to process (faster)
   ega_weeks=[]
   ega_days=[]
   standardized_ega=[]
   total_ega_days=[]

    
   for index, row in ega_df.iterrows(): ### processed all records
       ega=str(row[ega_col_name])
       ega_week=0
       ega_day=0
       found = re.findall('(\d+\.?\d*)', ega) ## 39.6 weeks, 39 weeks, 39 wks 4days. 39 weeks 1/7 -> [39,1,7]
       ### normal cases
       if (len(found)>=2 ):
          ega_week=float(found[0])
          ega_day=float(found[1])
       elif (len(found)==1 ):
         if ("." in found[0]): ### 39.6 weeks -> 39 weeks and 6 days
          ega_week=int(found[0].split(".")[0])
          if (found[0].split(".")[1].isdigit()):
             ega_day=int(found[0].split(".")[1])
         else: 
          ega_week=float(found[0])
    
       ### some exceptions
       if (ega_week>99): #303-> 30 weeks and 3 days. 4017-> 40 weeks 1 day. 40 1/7-> 4017. 
           fstr = repr(ega_week).split('.')[0]
           ega_week = float(fstr[:2])
           ega_day = float(fstr[2:3])
       elif (ega_day>7): # 33-36 weeks -> 33 weeks and 6 days ### Henry: pick mid-point
           ega_day=6
        
       ega_weeks.append(ega_week)
       ega_days.append(ega_day)
       total_ega_days.append((ega_week*7)+ega_day)
       if ega_week > 0 and ega_day > 0:
          standardized_ega.append(str(int(ega_week))+"w "+str(int(ega_day))+"d")
       elif ega_week > 0 and ega_day == 0: 
          standardized_ega.append(str(int(ega_week))+"w ")
       elif ega_week == 0 and ega_day == 0:
          standardized_ega.append(None)

   ega_df['ega_week']=ega_weeks
   ega_df['ega_day']=ega_days
   ega_df['total_ega_days']=total_ega_days
   ega_df['standardized_ega']=standardized_ega
   return spark.createDataFrame(ega_df)

In [0]:
##### Read CSV file


In [0]:
def read_csv_file(DataFilePath):
    """Load a CSV file into a Spark DataFrame.

    The first row is treated as the header and column types are inferred from
    the data.

    Args:
        DataFilePath: Path Spark can read (e.g. a ``dbfs:/`` URI).

    Returns:
        Spark DataFrame with the file's contents.
    """
    csv_options = {"header": "true", "inferSchema": "true"}
    reader = spark.read.format("csv").options(**csv_options)
    return reader.load(DataFilePath)

In [0]:
##### Read CPT/ICD LBC codes


In [0]:
# def lbc_codes():
#    lbc=read_csv_file("dbfs:/FileStore/LBC_list_10_20_22.csv")
#    lbc_CPT=lbc[lbc.VOCABULARY_ID=='CPT4'] ### CPT
#    lbc_ICD= lbc[lbc.VOCABULARY_ID!='CPT4'] ### ICD (ALL)

#    lbc_CPT_list=lbc_CPT.select('CONCEPT_CODE').distinct()
#    lbc_ICD_list=lbc_ICD.select('CONCEPT_CODE').distinct()
#    print("Total CPT code count:",lbc_CPT_list.count())
#    print("Total ICD code count:",lbc_ICD_list.count())

#    lbc_CPT_str=lbc_CPT.agg(F.concat_ws(",",F.collect_list(F.concat(F.lit('"'),F.col('CONCEPT_CODE'),F.lit('"'))))).first()[0]    
#    lbc_ICD_str=lbc_ICD.agg(F.concat_ws(",",F.collect_list(F.concat(F.lit('"'),F.col('CONCEPT_CODE'),F.lit('"'))))).first()[0]
#    return lbc_CPT_str,lbc_ICD_str

In [0]:
##### Get view information

In [0]:
def get_global_temp_output():
    """Return the names of all views in ``global_temp``, followed by the literal 'N'.

    NOTE(review): the trailing 'N' looks like a "none" choice for a selection
    widget — confirm with the callers.
    """
    views_df = spark.sql("SHOW VIEWS FROM global_temp;")
    names = [row['viewName'] for row in views_df.select('viewName').collect()]
    return names + ['N']

def temp_output(df_name):
    """Return the global temporary view ``global_temp.<df_name>`` as a Spark DataFrame."""
    return spark.sql(f"SELECT * FROM global_temp.{df_name};")

In [0]:
##### Aggregation Summary (total_count, min, max, mean and median)

In [0]:
def df_agg_summary(sum_df, agg_cols, target_col):
    """Display count/min/mean/median/max of *target_col* for each group of *agg_cols*.

    The Spark aggregates are referenced through ``F.`` explicitly. The bare
    names ``min``/``max``/``round`` only resolve to Spark functions because of
    the notebook's ``from pyspark.sql.functions import *``, which shadows the
    Python builtins.

    Args:
        sum_df: Spark DataFrame to summarize.
        agg_cols: Column name or list of column names to group and sort by.
        target_col: Numeric column to summarize. The median is approximate
            (``percentile_approx``); mean and median are rounded to 2 decimals.

    Returns:
        None. The summary is rendered with Databricks ``display()``.
    """
    summary = (
        sum_df.groupBy(agg_cols)
        .agg(
            F.count(target_col).alias("TOTAL_RECORD"),
            F.min(target_col).alias("MIN_VAL"),
            F.round(F.avg(target_col), 2).alias("AVG_VAL"),
            F.round(F.percentile_approx(target_col, 0.5), 2).alias("MEDIAN_VAL"),
            F.max(target_col).alias("MAX_VAL"),
        )
        .sort(agg_cols, ascending=True)
    )
    summary.display()
    
def df_agg_summary_simple(sum_df, target_col):
    """Display max/min/mean/approximate median of *target_col* over the whole DataFrame.

    Uses the explicit ``F.`` namespace instead of relying on the star import
    shadowing the builtins ``min``/``max``/``round``. Mean and median are
    rounded to 2 decimals.

    Returns:
        None. The summary is rendered with Databricks ``display()``.
    """
    summary = sum_df.select(
        F.max(target_col).alias("MAX_VAL"),
        F.min(target_col).alias("MIN_VAL"),
        F.round(F.mean(target_col), 2).alias("MEAN_VAL"),
        F.round(F.percentile_approx(target_col, 0.5), 2).alias("MEDIAN_VAL"),
    )
    summary.display()
     

In [0]:
##### Get mom_baby records with code list

In [0]:
def get_table_records(table_name,col_name,code_list, wildcard_sel):
    if (wildcard_sel==1):
        code_str=' or '.join(map(lambda x: col_name+" like '" + x + "'", code_list))
        sql=f"select * from {table_name} where {code_str};"
    else:
        sql=f"select * from {table_name} where {col_name} in ({code_list});"
     
    table_output = spark.sql(sql)
    print("Total record:",table_output.count())
    return table_output

In [0]:
def mon_baby_records(term, lbc_df_name, subject_id):
    """Join code records to mom/baby pairs, keeping codes dated within 30 days of birth.

    The window is exclusive: a record is kept when its date falls strictly
    between ``date(birth_datetime) - 30`` and ``date(birth_datetime) + 30``.

    Args:
        term: 'condition', 'procedure' or 'observation'. Selects the
            ``<term>_source_value`` code column and the date/datetime columns.
        lbc_df_name: Table/view with the candidate code records. It must have
            ``person_id`` and the columns implied by *term*.
        subject_id: Column of ``global_temp.mom_baby_step1`` matched against
            ``person_id`` (presumably FACT_ID_1 for the mom or FACT_ID_2 for
            the baby — NOTE(review): confirm).

    Returns:
        Spark DataFrame with mom/baby ids, baby demographics and the matching
        code and dates. The row count is printed.

    Raises:
        ValueError: if *term* is not one of the supported values.
    """
    # term -> (date column, datetime column); condition and observation reuse the date column.
    date_columns = {
        "condition": ("condition_start_date", "condition_start_date"),
        "procedure": ("procedure_date", "procedure_datetime"),
        "observation": ("observation_date", "observation_date"),
    }
    if term not in date_columns:
        raise ValueError(
            f"Unsupported term {term!r}; expected one of {sorted(date_columns)}"
        )
    date_str, datetime_str = date_columns[term]

    sql=f"""
         select FACT_ID_1 as mom_person_id,FACT_ID_2 as baby_person_id,
         person_source_value as baby_person_source_value,
         birth_datetime,gender_source_value as baby_gender,
         race_source_value as baby_race, a.person_id as code_person_id,
         {term}_source_value as code,{date_str} as code_date,{datetime_str} as code_datetime from
         {lbc_df_name} a inner join global_temp.mom_baby_step1 b
         
         on a.person_id = b.{subject_id}
         and date(birth_datetime) - 30 < {date_str}
         and date(birth_datetime) + 30 > {date_str}
       """

    df = spark.sql(sql)
    print("term:", term, ", record count:", df.count())
    return df

In [0]:
def mon_baby_records_code_list(term, subject_id, code_list):
    """Join mom/baby pairs to every record whose source code is in *code_list*.

    Unlike ``mon_baby_records`` there is no date window: every matching record
    of the linked person is returned.

    Args:
        term: 'condition', 'procedure' or 'observation'. Selects the source
            table (cond_table / proc_table / obs_table), the
            ``<term>_source_value`` code column and the date/datetime columns.
        subject_id: Column of ``global_temp.mom_baby_step1`` matched against
            ``person_id`` (presumably FACT_ID_1 for the mom or FACT_ID_2 for
            the baby — NOTE(review): confirm).
        code_list: Iterable of source codes to match exactly. An empty list
            matches nothing.

    Returns:
        Spark DataFrame with mom/baby ids, baby demographics and the matching
        code and dates. The row count is printed.

    Raises:
        ValueError: if *term* is not one of the supported values.
    """
    # term -> (date column, datetime column); condition and observation reuse the date column.
    date_columns = {
        "condition": ("condition_start_date", "condition_start_date"),
        "procedure": ("procedure_date", "procedure_datetime"),
        "observation": ("observation_date", "observation_date"),
    }
    if term not in date_columns:
        raise ValueError(
            f"Unsupported term {term!r}; expected one of {sorted(date_columns)}"
        )
    date_str, datetime_str = date_columns[term]
    table_source = {
        "condition": cond_table,
        "procedure": proc_table,
        "observation": obs_table,
    }[term]

    code_list_str = ' , '.join("'" + str(code) + "'" for code in code_list)
    # An empty "IN ()" list is not valid SQL, so match nothing instead.
    code_filter = f"{term}_source_value in ({code_list_str})" if code_list_str else "false"

    sql=f"""
         select FACT_ID_1 as mom_person_id,FACT_ID_2 as baby_person_id,
         person_source_value as baby_person_source_value,
         birth_datetime,gender_source_value as baby_gender,
         race_source_value as baby_race, a.person_id as code_person_id,
         {term}_source_value as code,{date_str} as code_date,{datetime_str} as code_datetime from
         
         (select * from {table_source} where {code_filter}) a 
         
         inner join global_temp.mom_baby_step1 b
         on a.person_id = b.{subject_id}
        """
    df = spark.sql(sql)
    print("term:", term, ", record count:", df.count())
    return df

In [0]:
##### Get phenotype_table

In [0]:
def get_table(version):
    """Return the fully-qualified phenotype table name for a named *version*.

    Args:
        version: 'V1_05232022', 'V2_11042022' or 'Current_version'.

    Returns:
        Table/view name to query.

    Raises:
        ValueError: if *version* is not one of the known keys.
    """
    phenotype_tables = {
        'V1_05232022': "workspace_rdmprintp2.phenotyping.mom_baby_2010_mombabypair_all_version_1_05232022",
        # NOTE(review): the trailing "_0523202" looks like it may be truncated — confirm the table name.
        'V2_11042022': "workspace_rdmprintp2.phenotyping.mom_baby_2010_mombabypair_all_version_1_11042022_0523202",
        'Current_version': "global_temp.mom_baby_2010_mombabypair_all_current",
    }
    try:
        return phenotype_tables[version]
    except KeyError:
        raise ValueError(
            f"Unknown phenotype table version {version!r}; expected one of {list(phenotype_tables)}"
        ) from None

In [0]:
##### Combine two column values into a list

In [0]:
def combine_column_value(sheet_name, col_name1, col_name2):
    """Pool the distinct, lower-cased values of two columns into a single list.

    Args:
        sheet_name: Table/view read with ``spark.sql``.
        col_name1: First column whose values are collected.
        col_name2: Second column whose values are collected.

    Returns:
        Tuple ``(name_list, df)``:
          - name_list: the distinct non-null values from both columns,
            lower-cased. Its length is printed.
          - df: the full DataFrame read from *sheet_name*.
    """
    source_df = spark.sql(f"select * from {sheet_name};")

    lowered = source_df.select([lower(col_name1), lower(col_name2)])
    values = lowered.rdd.flatMap(lambda row: row).distinct().collect()
    name_list = [value for value in values if value is not None]
    print("total record count in list:", len(name_list))

    return name_list, source_df

In [0]:
##### Union dataframes

In [0]:
def union_dataframes(df_list):
    """Stack DataFrames by column position and drop duplicate rows.

    Args:
        df_list: Non-empty iterable of DataFrames with the same column layout.
            Columns are matched by position, not by name (``unionAll``
            semantics).

    Returns:
        The de-duplicated union.
    """
    merged = functools.reduce(lambda left, right: left.unionAll(right), df_list)
    return merged.distinct()

In [0]:
##### Cohort definitions (phenotype,baby and mom)

In [0]:
def get_whole_phenotype_cohort(phenotype_table):
    """Return every row of *phenotype_table*, with no cohort filters applied."""
    query = f"""
        select * from {phenotype_table};
      """
    return spark.sql(query)

def get_phenotype_cohort(phenotype_table):
    """Return the base infant cohort from *phenotype_table*.

    Keeps rows where gestational_age_w33_or_uncertain = 1 and
    live_birth_code = 1, and critical_illness_4cpt,
    respiratory_procedure_code and fetal_anomalies_code are all 0.
    """
    query = f"""
        select * from {phenotype_table}
        where gestational_age_w33_or_uncertain = 1 and live_birth_code=1 and critical_illness_4cpt = 0 and 
        respiratory_procedure_code = 0 and fetal_anomalies_code =0;
      """
    return spark.sql(query)

def get_baby_exposed_opioid(phenotype_table):
    """Return base-cohort infants flagged by nows_baby_code = 1 or infant_tox_lab = 1.

    The base-cohort criteria are the same as in ``get_phenotype_cohort``.
    """
    query = f"""
       select * from {phenotype_table} where gestational_age_w33_or_uncertain = 1 and live_birth_code=1 and 
       critical_illness_4cpt = 0 and respiratory_procedure_code = 0 and fetal_anomalies_code =0 
       and (nows_baby_code =1 or infant_tox_lab =1);
    """
    return spark.sql(query)

def get_baby_bh_cohort(phenotype_table, exposed_opioid):
    """Return ``global_temp.mom_baby_step1_baby1stvisit`` rows for the chosen infant cohort.

    Args:
        phenotype_table: Phenotype table/view to build the cohort from.
        exposed_opioid: 'opioid_exposed' uses ``get_baby_exposed_opioid``.
            Any other value uses ``get_phenotype_cohort``.

    Side effects:
        Creates or replaces the session temp view ``infant_cohort``.
    """
    build_cohort = get_baby_exposed_opioid if exposed_opioid == 'opioid_exposed' else get_phenotype_cohort
    build_cohort(phenotype_table).createOrReplaceTempView("infant_cohort")

    query = """
         select * from global_temp.mom_baby_step1_baby1stvisit where baby_person_id in (select baby_person_id from infant_cohort);
       """
    return spark.sql(query)

def get_mom_cohort(phenotype_table):
    """Return base-cohort rows whose mom has an opioid flag.

    The flags tested are mom_oud_inpatient, mom_oud_outpatient, mom_drug and
    mom_drug_in_note (any = 1).

    NOTE(review): unlike get_baby_brestfeed_cohort, this does not test
    ``mom_oud`` — confirm that is intended.
    """
    query = f"""
          select *
          from {phenotype_table}
          where gestational_age_w33_or_uncertain = 1 and live_birth_code=1 and critical_illness_4cpt = 0 
          and respiratory_procedure_code = 0 and fetal_anomalies_code =0 
          and (mom_oud_inpatient = 1 or mom_oud_outpatient=1 or mom_drug = 1 or mom_drug_in_note = 1);
        """
    return spark.sql(query)

def get_baby_brestfeed_cohort(phenotype_table):
    """Return selected mom/baby columns for base-cohort rows with any infant or mom opioid flag.

    A row qualifies when any of these is 1:
      - infant: nows_baby_code, infant_tox_lab
      - mom: mom_oud, mom_oud_inpatient, mom_oud_outpatient, mom_drug,
        mom_drug_in_note

    NOTE(review): MOM_PERSON_ID is not in the select list (only
    MOM_PERSON_SOURCE_VALUE) — confirm callers do not need it.
    """
    query = f"""
          select MOM_PERSON_SOURCE_VALUE,AGE_AT_DELIVERY,MOM_BIRTH_DATETIME,MOM_GENDER,MOM_RACE,BABY_PERSON_ID,
          BABY_PERSON_SOURCE_VALUE,BABY_BIRTH_DATETIME,BABY_GENDER,BABY_RACE,LIVE_BIRTH_CODE,PREGNANCY_CODE,
          GESTATIONAL_AGE_W33_OR_UNCERTAIN,CRITICAL_ILLNESS_4CPT,RESPIRATORY_PROCEDURE_CODE,
          FETAL_ANOMALIES_CODE,GESTATIONAL_AGE_UNCERTAIN,NOWS_BABY_CODE,INFANT_TOX_LAB,MOM_OUD,
          MOM_OUD_INPATIENT,MOM_OUD_OUTPATIENT,MOM_DRUG,MOM_DRUG_IN_NOTE,MOM_OPIOID_TOX from {phenotype_table} where
       
          gestational_age_w33_or_uncertain = 1 and live_birth_code=1 and critical_illness_4cpt = 0 and 
          respiratory_procedure_code = 0 and fetal_anomalies_code =0 and (nows_baby_code = 1 or 
          infant_tox_lab =1 or mom_oud = 1 or mom_oud_inpatient = 1 or mom_oud_outpatient = 1 or mom_drug=1 or mom_drug_in_note =1)
         """
    return spark.sql(query)

In [0]:
##### Output df to parquet

In [0]:
 def file_exists(path):
  try:
    dbutils.fs.ls(path)
    return True
  except Exception as e:
    if 'java.io.FileNotFoundException' in str(e):
      return False
    else:
      raise

In [0]:
#delete parquet file
# shutil.rmtree('/dbfs/FileStore/tables/mom_drug_info_tmp.parquet')

def register_parquet_global_view(df):
    """Persist *df* as parquet on DBFS and expose it as a global temporary view.

    *df* must carry a ``name`` attribute (set by the caller). It names both:
      - the parquet file ``dbfs:/FileStore/tables/<name>.parquet``
      - the view ``global_temp.<name>``

    The file is (re)written when the notebook-level ``output_overwrite`` flag is
    True or the file does not exist yet. The view is always rebuilt from the
    parquet file.

    NOTE(review): overwriting a parquet path that *df* itself was read from may
    fail in Spark — confirm callers never pass a DataFrame loaded from its own
    output.
    """
    view_name = df.name
    parquet_path = "dbfs:/FileStore/tables/" + view_name + ".parquet"

    must_write = output_overwrite or not file_exists(parquet_path)
    if not must_write:
        print("Parquet file is ready.")
    else:
        print("Start producing/replacing parquet file/view...")
        df.write.mode('overwrite').parquet(parquet_path)
        print("Done!")

    spark.read.parquet(parquet_path).createOrReplaceGlobalTempView(view_name)
    print("Parquet view is ready!")

In [0]:
##### Get drug concept id and name

In [0]:
def get_drug_concept_id_name(sheet_name, col_name):
    """Look up concepts whose name contains any search term listed in *sheet_name*.

    Each lower-cased value of *col_name* becomes a
    ``lower(concept_name) LIKE "%term%"`` test against ``concept_table``, and
    the tests are OR-ed together.

    Args:
        sheet_name: Table/view holding the search terms.
        col_name: Column of *sheet_name* containing the terms.

    Returns:
        Spark DataFrame of (concept_id, concept_name). When *sheet_name* yields
        no non-null terms the filter becomes ``false`` and the result is empty,
        instead of the query failing to parse on ``where ;``.

    NOTE: terms are pasted into the SQL unescaped. A term containing '"' can
    break the query, and '%'/'_' inside a term act as LIKE wildcards.
    """
    search_terms = spark.sql(
        f"select lower({col_name}) as {col_name} from {sheet_name} order by {col_name} asc;"
    )
    like_test = F.concat(F.lit('lower(concept_name) LIKE "%'), F.col(col_name), F.lit('%"'))
    drug_str = search_terms.agg(F.concat_ws(" or ", F.collect_list(like_test))).first()[0]
    if not drug_str:
        drug_str = "false"

    return spark.sql(f"select concept_id,concept_name from {concept_table} where {drug_str};")

In [0]:
##### DB Summary

In [0]:
def db_summary():
    """Display the row count of every table in every database listed by ``SHOW DATABASES``.

    Tables are addressed with fully-qualified, back-quoted names instead of
    issuing ``USE <db>``, so the session's current database is left unchanged.
    Rows that ``SHOW TABLES`` marks as temporary are skipped: session temp
    views have no database to qualify them with and would otherwise be
    repeated for every database.

    The result is sorted by database name, then by descending row count, and
    rendered with ``display()``.
    """
    final_list = []
    db_names = [row[0] for row in spark.sql("show databases").select("databaseName").collect()]
    for database_name in db_names:
        tables = spark.sql(f"show tables from `{database_name}`").collect()
        for row in tables:
            if row.asDict().get("isTemporary", False):
                continue
            table_name = row["tableName"]
            table_count = spark.sql(
                f"select count(*) as tableCount from `{database_name}`.`{table_name}`"
            ).collect()[0][0]
            final_list.append([database_name, table_name, table_count])

    # An explicit schema lets an empty result still build a DataFrame.
    summary = spark.createDataFrame(
        final_list, "DatabaseName string, TableName string, TableCount long"
    )
    summary = summary.sort(summary.DatabaseName.asc(), summary.TableCount.desc())

    display(summary)

In [0]:
#db_summary()

In [0]:
##### Get search term list

In [0]:
def get_search_term(sheet_name):
    """Return distinct (lab, search_term) pairs from a lab name sheet.

    Both ``short_name`` and ``long_name`` are lower-cased and used as search
    terms. ``union`` removes duplicate pairs, and null terms are dropped.
    """
    query = f"""
       select lab,search_term from (
          select lab,lower(short_name) as search_term from {sheet_name}
          union
          select lab,lower(long_name) as search_term from {sheet_name}
       ) where search_term is not null
    """
    return spark.sql(query)

In [0]:
##### Rename column

In [0]:
 def rename_column(source_df,name_dict):
    for key, value in name_dict.items():
      source_df= source_df.withColumnRenamed(key,value)
    return source_df

In [0]:
##### Time Travel for different table versions

In [0]:
def version_diff(tracking_table):
    """Return rows present in the newest version of a Delta table but not in the previous one.

    The two most recent versions are read from ``DESCRIBE HISTORY``, ordered
    explicitly rather than relying on the subquery's row order. The function
    then runs ``<newest> EXCEPT ALL <previous>`` using ``table@v<N>`` time
    travel. Only rows added or changed by the newest version are returned;
    rows it deleted are not.

    Args:
        tracking_table: Fully-qualified name of a Delta table.

    Returns:
        The difference DataFrame (its row count is printed), or None after
        printing a message when fewer than two versions exist.
    """
    history_sql = f"""
        select version from (
           DESCRIBE HISTORY {tracking_table}
        ) order by version desc limit 2;
        """
    versions = [row[0] for row in spark.sql(history_sql).collect()]
    if len(versions) < 2:
        print("This is the only version")
        return None

    new_version, old_version = versions
    diff_sql = f"""
              select * from {tracking_table}@v{new_version}
              except all
              select * from
              {tracking_table}@v{old_version}
        """

    diff_df = spark.sql(diff_sql)
    print("Total diff record count:", diff_df.count())
    return diff_df