# Project: Performance of phenotype algorithms for the identification of opioid-exposed infants, Andrew D. Wiese et al. Hospital Pediatrics 2024
# Title: Project Set Up
# Summary: 
## This is called by other notebooks to import necessary libraries, define table names, and define common functions
# List of defined functions:
- Change Column name case (lower or upper)
- Dataframe inspection (unique patient and total records; local and global)
- EGA calculation to get standardized EGA days
- Get mom_baby records with code list
- Union dataframes
- Output df to parquet

In [0]:
pip install openpyxl

In [0]:
import warnings #ignore warnings
warnings.filterwarnings('ignore')
import pandas as pd # for saving results from SQL queries to pandas DFs
import numpy as np # for math operations
import re # for regular expression
import os
import shutil
import datetime as dt
import matplotlib.pyplot as plt
import matplotlib.ticker as tick


from datetime import datetime, timedelta
from functools import reduce
from pyspark.sql import DataFrame
from pyspark.sql.functions import *
import pyspark.sql.functions as func
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import year
from pyspark.sql.functions import to_date
from pyspark.sql.types import *
from pyspark.sql.functions import countDistinct

# Session-level Spark setting for the number of shuffle partitions.
# NOTE(review): "auto" is presumably the Databricks auto-optimized shuffle value;
# confirm it is supported by the runtime this notebook is attached to.
spark.conf.set("spark.sql.shuffle.partitions", "auto")

# Read by df_inspection(): 'df' shows a 10-row sample, 'count' shows record/patient counts, 'both' shows both.
inspect_df_flag = 'count'
# Read by register_parquet_global_view(): when True the parquet file is rewritten even if it already exists.
output_overwrite = True

# Table definitions

In [0]:
### OMOP related tables
proc_table = "rd_omop_prod.procedure_occurrence"
cond_table = "rd_omop_prod.condition_occurrence"
obs_table = "rd_omop_prod.observation"
person_table = "rd_omop_prod.person"
drug_exp_table = "rd_omop_prod.drug_exposure"
meas_table = "rd_omop_prod.measurement"
note_table = "rd_omop_prod.note"
visit_table = "rd_omop_prod.visit_occurrence"
concept_table = "rd_omop_prod.concept"
fact_real_table = "rd_omop_prod.fact_relationship"

### Own defined table for drug exposure
# This is a VUMC specific table w/ additional drug exposure information.
# (Defined once here; it used to be assigned twice with the same value.)
drug_exp_table_extra = "rd_omop_prod.x_drug_exposure"

### code lists
# String entries are names of tables holding the codes; list entries are inline code values.
lbc_code_list = "phenotyping.mprint_live_birth_code_sheet_1"
critical4cpt_code_list = ['99468', '99469', '99291', '99292']
fetal_anomalies_code_list = "phenotyping.mprint_fetal_anomalies_code"
respiratory_code_list = [
    '96.04', '96.70', '96.71', '96.72', '93.90',
    '0BH17EZ', '0BH18EZ',
    '5A1935Z', '5A1945Z', '5A1955Z',
    '5A09357', '5A09457', '5A09557',
]
infant_tox_lab_list = "phenotyping.mprint_infant_tox_lab_name_sheet_4"
mom_drug_search_term_list = "phenotyping.mprint_mom_drug_search_term_v2"
nows_code_list = ['779.5', 'P96.1', '760.72', 'P04.14']
mom_oud_code_list = "phenotyping.mprint_mom_oud_code_sheet_5"
opioid_toxicology_code_list = "phenotyping.mprint_sheet_14_mat_opioid_toxicology"

# Function Definitions

## Change Column name case (lower or upper)


In [0]:
def change_colname_case(df, case_type):
    """Return `df` with every column renamed to upper or lower case.

    Args:
        df: object exposing `.columns` and `.withColumnRenamed(old, new)`
            (a Spark DataFrame in this notebook).
        case_type: 'upper' or 'lower'. Any other value leaves `df` untouched.

    Returns:
        The renamed DataFrame, or `df` itself when `case_type` is not recognised.
    """
    if case_type == 'upper':
        convert = str.upper
    elif case_type == 'lower':
        convert = str.lower
    else:
        # Unknown case type: nothing to rename.
        return df

    renamed = df
    # Column names are taken from the incoming frame once, then renamed one by one.
    for current_name in df.columns:
        renamed = renamed.withColumnRenamed(current_name, convert(current_name))
    return renamed

## Dataframe inspection (counting unique patient and total records; can be used on local and global dataframes)

In [0]:
def df_inspection(df_name, sub_group):
    """Display a sample and/or summary counts for a table or view.

    What is shown is driven by the notebook-level `inspect_df_flag`:
    'df' shows the first 10 rows, 'count' shows the counts, 'both' shows both.

    Args:
        df_name: name of the table/view to query (interpolated into the SQL).
        sub_group: 'all' to count total rows plus distinct `mom_person_id` and
            `baby_person_id`; otherwise a prefix such that
            `{sub_group}_person_id` is the column whose distinct values are counted.
    """
    show_sample = inspect_df_flag in ('df', 'both')
    show_counts = inspect_df_flag in ('count', 'both')

    if show_sample:
        sample_rows = spark.sql(f"select * from {df_name} limit 10;")
        display(sample_rows)

    if show_counts:
        if sub_group == 'all':
            count_query = f"select count(*) AS total, count(distinct mom_person_id) AS unique_mom, count(distinct baby_person_id) AS unique_baby from {df_name};"
        else:
            count_query = f"select count(*) AS total ,count(distinct {sub_group}_person_id) AS unique_{sub_group} from {df_name};"
        display(spark.sql(count_query))
       

## EGA calculation to get standardized EGA date information

In [0]:
def _parse_ega_text(ega_text):
    """Parse one free-text EGA value into (weeks, days, standardized label).

    Examples of handled inputs: "39.6 weeks", "39 weeks", "39 wks 4days",
    "39 weeks 1/7", "303", "4017", "33-36 weeks".
    The label is None when no weeks/days could be extracted.
    """
    ega_week = 0
    ega_day = 0
    # Raw string so that "\d" is a regex class, not an invalid string escape.
    found = re.findall(r'(\d+\.?\d*)', ega_text)  # "39 weeks 1/7" -> ['39', '1', '7']

    ### normal cases
    if len(found) >= 2:
        ega_week = float(found[0])
        ega_day = float(found[1])
    elif len(found) == 1:
        whole_part, dot, fraction_part = found[0].partition(".")
        if dot:  ### 39.6 weeks -> 39 weeks and 6 days
            ega_week = int(whole_part)
            if fraction_part.isdigit():
                ega_day = int(fraction_part)
        else:
            ega_week = float(found[0])

    ### some exceptions
    if ega_week > 99:  # 303 -> 30 weeks and 3 days. 4017 -> 40 weeks 1 day. 40 1/7 -> 4017.
        # Slice the matched digits themselves rather than repr(float), which
        # switches to exponent notation for very long numbers and then fails.
        digits = found[0].partition(".")[0].lstrip("0")
        ega_week = float(digits[:2])
        ega_day = float(digits[2:3])
    elif ega_day > 7:  # 33-36 weeks -> 33 weeks and 6 days
        ega_day = 6

    if ega_week > 0 and ega_day > 0:
        label = str(int(ega_week)) + "w " + str(int(ega_day)) + "d"
    elif ega_week > 0:
        label = str(int(ega_week)) + "w "
    elif ega_day > 0:
        # Zero weeks but some days (e.g. "0 weeks 3 days"). Previously no label
        # was produced for this case, leaving the label list shorter than the
        # frame and making the column assignment fail.
        label = "0w " + str(int(ega_day)) + "d"
    else:
        label = None
    return ega_week, ega_day, label


def cal_ega(SparkDF, ega_col_name):
    """Add standardized estimated-gestational-age columns to a Spark DataFrame.

    The frame is collected to pandas, the text in `ega_col_name` is parsed row
    by row, and the result is converted back to a Spark DataFrame.

    Args:
        SparkDF: Spark DataFrame containing the raw EGA column.
        ega_col_name: name of the column holding the free-text EGA value.

    Returns:
        A Spark DataFrame with the original columns plus `ega_week`, `ega_day`,
        `total_ega_days` (weeks * 7 + days) and `standardized_ega`
        (e.g. "39w 6d", "39w ", or null when nothing could be parsed).
    """
    ega_df = SparkDF.toPandas()  ### use pandas function to process (faster)

    # Every record is parsed; values are stringified first so nulls simply yield no match.
    parsed = [_parse_ega_text(str(raw_value)) for raw_value in ega_df[ega_col_name]]

    ega_df['ega_week'] = [week for week, _, _ in parsed]
    ega_df['ega_day'] = [day for _, day, _ in parsed]
    ega_df['total_ega_days'] = [(week * 7) + day for week, day, _ in parsed]
    ega_df['standardized_ega'] = [label for _, _, label in parsed]
    return spark.createDataFrame(ega_df)

## Get mom_baby records with code list

In [0]:
def get_table_records(table_name, code_list, wildcard_sel):
    """Select the rows of an OMOP table whose source value matches a code list.

    Args:
        table_name: one of `proc_table`, `cond_table` or `obs_table`; it decides
            which `*_source_value` column is filtered.
        code_list: when `wildcard_sel == 1`, an iterable of SQL LIKE patterns.
            Otherwise either a ready-made SQL fragment (string) placed inside
            `in (...)`, or an iterable of codes that is quoted and joined here.
        wildcard_sel: 1 to match with `like` (patterns OR-ed together); any
            other value to match with `in (...)`.

    Returns:
        The Spark DataFrame produced by the query. The total row count is
        printed as a side effect (this triggers a Spark job).

    Raises:
        ValueError: if `table_name` is not one of the supported tables.
    """
    source_columns = {
        proc_table: "procedure_source_value",
        cond_table: "condition_source_value",
        obs_table: "observation_source_value",
    }
    if table_name not in source_columns:
        # Previously this fell through and failed later with an unbound variable.
        raise ValueError(
            f"Unsupported table {table_name!r}; expected one of {list(source_columns)}"
        )
    col_name = source_columns[table_name]

    if wildcard_sel == 1:
        code_str = ' or '.join(f"{col_name} like '{code}'" for code in code_list)
        sql = f"select * from {table_name} where {code_str};"
    else:
        if isinstance(code_list, str):
            # Already a SQL fragment (quoted list or sub-query): use as-is.
            in_clause = code_list
        else:
            # A Python list/tuple would otherwise be interpolated as its repr,
            # which is not valid SQL; quote and join it instead.
            in_clause = ' , '.join(f"'{code}'" for code in code_list)
        sql = f"select * from {table_name} where {col_name} in ({in_clause});"

    table_output = spark.sql(sql)
    print("Total record:", table_output.count())
    return table_output

In [0]:
def mon_baby_records(term, lbc_df_name, subject_id):
    """Link coded records to mom-baby pairs within 30 days of the birth date.

    Joins `lbc_df_name` (alias a) to `global_temp.mom_baby_step1` (alias b) on
    `a.person_id = b.{subject_id}` and keeps rows whose code date lies strictly
    between 30 days before and 30 days after `date(birth_datetime)`.

    Args:
        term: 'condition', 'procedure' or 'observation'; selects the
            `{term}_source_value` column and the matching date columns.
        lbc_df_name: name of the table/view holding the coded records.
        subject_id: column of `mom_baby_step1` to join the record's person_id on.

    Returns:
        Spark DataFrame with mom/baby ids, baby demographics, and the code with
        its date. The record count is printed as a side effect.

    Raises:
        ValueError: if `term` is not one of the supported values.
    """
    # term -> (date column, datetime column). Condition and observation reuse
    # the date column for code_datetime, as in the original query.
    date_columns = {
        "condition": ("condition_start_date", "condition_start_date"),
        "procedure": ("procedure_date", "procedure_datetime"),
        "observation": ("observation_date", "observation_date"),
    }
    if term not in date_columns:
        # Previously an unknown term failed later with an unbound variable.
        raise ValueError(
            f"Unsupported term {term!r}; expected one of {list(date_columns)}"
        )
    date_str, datetime_str = date_columns[term]

    sql = f"""
         select FACT_ID_1 as mom_person_id,FACT_ID_2 as baby_person_id,
         person_source_value as baby_person_source_value,
         birth_datetime,gender_source_value as baby_gender,
         race_source_value as baby_race, a.person_id as code_person_id,
         {term}_source_value as code,{date_str} as code_date,{datetime_str} as code_datetime from
         {lbc_df_name} a inner join global_temp.mom_baby_step1 b

         on a.person_id = b.{subject_id}
         and date(birth_datetime) - 30 < {date_str}
         and date(birth_datetime) + 30 > {date_str}
       """

    df = spark.sql(sql)
    print("term:", term, ",  record count:", df.count())
    return df

In [0]:
def mon_baby_records_code_list(term, subject_id, code_list):
    """Link records matching an inline code list to mom-baby pairs.

    Filters the OMOP table for `term` to rows whose `{term}_source_value` is in
    `code_list`, then joins them to `global_temp.mom_baby_step1` on
    `a.person_id = b.{subject_id}`. No date window is applied here.

    Args:
        term: 'condition', 'procedure' or 'observation'; selects the source
            table, the `{term}_source_value` column and the date columns.
        subject_id: column of `mom_baby_step1` to join the record's person_id on.
        code_list: iterable of code strings; each is single-quoted and placed
            in the SQL `in (...)` list.

    Returns:
        Spark DataFrame with mom/baby ids, baby demographics, and the code with
        its date. The record count is printed as a side effect.

    Raises:
        ValueError: if `term` is not one of the supported values.
    """
    # term -> (date column, datetime column, source table). Condition and
    # observation reuse the date column for code_datetime, as in the original query.
    term_settings = {
        "condition": ("condition_start_date", "condition_start_date", cond_table),
        "procedure": ("procedure_date", "procedure_datetime", proc_table),
        "observation": ("observation_date", "observation_date", obs_table),
    }
    if term not in term_settings:
        # Previously an unknown term failed later with an unbound variable.
        raise ValueError(
            f"Unsupported term {term!r}; expected one of {list(term_settings)}"
        )
    date_str, datetime_str, table_source = term_settings[term]

    code_list_str = ' , '.join("'" + code + "'" for code in code_list)

    sql = f"""
         select FACT_ID_1 as mom_person_id,FACT_ID_2 as baby_person_id,
         person_source_value as baby_person_source_value,
         birth_datetime,gender_source_value as baby_gender,
         race_source_value as baby_race, a.person_id as code_person_id,
         {term}_source_value as code,{date_str} as code_date,{datetime_str} as code_datetime from

         (select * from {table_source} where {term}_source_value in ({code_list_str})) a

         inner join global_temp.mom_baby_step1 b
         on a.person_id = b.{subject_id}
        """
    df = spark.sql(sql)
    print("term:", term, ", record count:", df.count())
    return df

## Union dataframes

In [0]:
def union_dataframes(df_list):
    """Stack a list of Spark DataFrames and drop duplicate rows.

    Args:
        df_list: non-empty iterable of DataFrames; they are combined pairwise
            with `DataFrame.unionAll`, left to right.

    Returns:
        The combined DataFrame after `.distinct()`.
    """
    stacked = reduce(lambda merged, following: DataFrame.unionAll(merged, following), df_list)
    deduplicated = stacked.distinct()
    return deduplicated

In [0]:
##### Output df to parquet

In [0]:
 def file_exists(path):
     """Report whether `path` can be listed with `dbutils.fs.ls`.

     Returns True when the listing succeeds and False when it fails with an
     error whose text mentions `java.io.FileNotFoundException`. Any other
     error is re-raised unchanged.
     """
     try:
         dbutils.fs.ls(path)
     except Exception as err:
         # Only a missing path counts as "does not exist"; everything else propagates.
         if 'java.io.FileNotFoundException' not in str(err):
             raise
         return False
     return True

In [0]:
#delete parquet file
#shutil.rmtree('/dbfs/FileStore/tables/mom_drug_info_tmp.parquet')

def register_parquet_global_view(df):
    """Persist `df` as parquet and expose it as a global temp view.

    The file is written to `dbfs:/FileStore/tables/<df.name>.parquet` when the
    notebook-level `output_overwrite` flag is set or the file does not exist
    yet; otherwise the existing file is reused. The parquet file is then read
    back and registered as a global temp view named `df.name`.

    Args:
        df: Spark DataFrame carrying a `name` attribute.
            NOTE(review): presumably callers assign `df.name` before calling,
            since this function only reads it; confirm at the call sites.
    """
    view_name = df.name
    target_path = f"dbfs:/FileStore/tables/{view_name}.parquet"

    # Reuse the file only when overwriting is off AND it is already there.
    reuse_existing = (not output_overwrite) and file_exists(target_path)
    if reuse_existing:
        print("Parquet file is ready.")
    else:
        print("Start producing/replacing parquet file/view...")
        writer = df.write.mode('overwrite')
        writer.parquet(target_path)
        print("Done!")

    stored = spark.read.parquet(target_path)
    stored.createOrReplaceGlobalTempView(view_name)
    print("Parquet view is ready!")