##### Project: Opioid Exposed Infant Covariates
##### Investigator: Stephen Patrick, Sarah Loch
##### Programmers: Sander Su, Chris Guardo
##### Date Created: 01/17/23
##### Last Modified: 09/30/25


##### Summary: 

- Variable: 10. Mental health diagnoses (anxiety, depression, bipolar, schizophrenia, psychosis, mood disorders, personality disorder)
- timeframe: 1) prenatal or 2) postpartum
- cohort: including maternal (ip/op)
- ICD list: Sheet 13 mat mental health dx

In [0]:
%run "../Project_modules"

In [0]:
# Reuse the active Spark session if one exists; otherwise create one
# registered under the application name 'Spark_Session'.
spark_session = (
    SparkSession.builder
    .appName('Spark_Session')
    .getOrCreate()
)


In [0]:
# Fully qualified name of the phenotype table -- placeholder, must be filled in before running.
phenotype_table_location = " ***Insert file location*** "
# Lookup sheet that maps ICD codes to diagnosis groups (DX_group).
sheet_name = "phenotyping.sheet_13_mat_mental_health_dx"

phenotype_table = spark.sql(f"SELECT * FROM {phenotype_table_location}")

# Build the cohort with the project helper and expose it to SQL cells.
# The temp view is registered BEFORE the column names are lower-cased, so the
# view keeps the helper's original column-name casing.
phenotype_cohort = get_phenotype_cohort(phenotype_table)
phenotype_cohort.createOrReplaceTempView("phenotype_cohort")
phenotype_cohort = change_colname_case(phenotype_cohort, "lower")

df_inspection("phenotype_cohort", "all")

##### Look up the ICD codes of interest in the condition table

In [0]:
# Codes that end in 'xx' in the lookup sheet are wildcards: they are turned into
# prefix patterns (condition_source_value like "<prefix>%") further down.
# The code is trimmed BEFORE the suffix is stripped (the WHERE clause already
# matches on the trimmed value), and only the trailing 'xx' is removed so an
# 'xx' elsewhere in the code is left untouched.
sql=f"""
     select trim(DX_group) as DX_group,
            regexp_replace(trim(disease_code),'xx$','') as disease_code
     from {sheet_name}
     where trim(disease_code) like '%xx';
     """
disease_partial = spark.sql(sql)
disease_partial.createOrReplaceTempView("disease_partial")

# OR-joined list of LIKE predicates, spliced into the condition-table queries below.
disease_partial_str = disease_partial.agg(
    F.concat_ws(
        " or ",
        F.collect_list(
            F.concat(F.lit('condition_source_value like "'), F.col('disease_code'), F.lit('%"'))
        ),
    )
).first()[0]
# With no wildcard codes the aggregate is an empty string, which would render
# as "or ()" (a SQL syntax error) downstream; fall back to a neutral predicate.
disease_partial_str = disease_partial_str or "false"

# Codes without the 'xx' suffix must match condition_source_value exactly.
sql=f"select trim(DX_group) as DX_group,trim(disease_code) as disease_code from {sheet_name} where trim(disease_code)  not like '%xx';"
disease_exact = spark.sql(sql)
disease_exact.createOrReplaceTempView("disease_exact")

In [0]:
def update_df(source_df):
    """Rename the condition columns to code-oriented names and fix the column order.

    Args:
        source_df: DataFrame holding the condition_* columns together with the
            mom/baby ids, DX_group and the visit columns selected below.

    Returns:
        DataFrame restricted to, and ordered as: mom_person_id, baby_person_id,
        DX_group, icd_code, code_date, code_datetime, visit_occurrence_id,
        visit_concept_id.
    """
    column_aliases = {
        'condition_start_date': 'code_date',
        'condition_source_value': 'icd_code',
        'condition_start_datetime': 'code_datetime',
    }
    output_columns = [
        'mom_person_id', 'baby_person_id', 'DX_group', 'icd_code',
        'code_date', 'code_datetime', 'visit_occurrence_id', 'visit_concept_id',
    ]
    renamed_df = rename_column(source_df, column_aliases)
    return renamed_df.select(*output_columns)

##### Timeframe: prenatal

In [0]:
# Prenatal diagnoses.
# Inner query (a): condition rows whose condition_source_value is either listed
#   in disease_exact or matches one of the wildcard prefixes in disease_partial_str.
# (c): the same rows with the concept name of the source concept attached.
# (d): gestation window per mom/baby pair, read from the global temp view
#   ega_w33_or_uncertain_gestation_date (created outside this notebook section).
# Final filter keeps codes dated in [start_gestation_date, date(baby_dob)).
sql=f"""
    
    select * from 
     (
     select a.*, b.concept_name from
     (select * from {cond_table} where (condition_source_value in (select disease_code from disease_exact) or ({disease_partial_str}) ) ) a
     left join {concept_table} b
     on a.condition_source_concept_id = b.concept_id
     ) c
   
     inner join  
     (select mom_person_id, baby_person_id,baby_dob,start_gestation_date,length_of_gestation from 
     global_temp.ega_w33_or_uncertain_gestation_date) d
   
     on c.person_id = d.mom_person_id 
     where condition_start_date >= start_gestation_date and condition_start_date <  date(baby_dob); --baby_dob=delivery date

    """

mental_health_dx_prenatal= spark.sql(sql)
# Normalise column names to lower case before the DataFrame-API joins that follow.
mental_health_dx_prenatal=change_colname_case(mental_health_dx_prenatal,"lower")
mental_health_dx_prenatal.createOrReplaceTempView("mental_health_dx_prenatal")

##### Used to get the matched mental health dx cohorts and phenotype maternal cohorts

In [0]:
def get_matched_mentaldx_phenotype_cohorts(mental_health_dx_cohort):
    """Attach phenotype-cohort columns and the diagnosis group to each diagnosis row.

    Args:
        mental_health_dx_cohort: DataFrame of diagnosis rows with person_id,
            baby_person_id and condition_source_value columns.

    Returns:
        Union of the rows whose code equals an entry of ``disease_exact`` and the
        rows whose code starts with a prefix in ``disease_partial``; each row
        carries the matching DX_group/disease_code and a ``mom_mrn`` column
        copied from mom_person_source_value.

    Uses the notebook-level ``phenotype_cohort``, ``disease_exact`` and
    ``disease_partial`` DataFrames.
    """
    # Keep only diagnoses that belong to a mom/baby pair of the phenotype cohort.
    same_pair = (
        (mental_health_dx_cohort.person_id == phenotype_cohort.mom_person_id)
        & (mental_health_dx_cohort.baby_person_id == phenotype_cohort.baby_person_id)
    )
    tmp_cohort = (
        mental_health_dx_cohort.join(phenotype_cohort, same_pair)
        .drop(phenotype_cohort.mom_person_id)
        .drop(phenotype_cohort.baby_person_id)
    )

    # Exact code match.
    mental_health_dx_phenotype_cohort_exact = tmp_cohort.join(
        disease_exact,
        tmp_cohort.condition_source_value == disease_exact.disease_code,
    ).withColumn("mom_mrn", tmp_cohort.mom_person_source_value)

    # Wildcard ('xx') codes are prefixes: the SQL that selects the diagnosis rows
    # filters with like '<prefix>%', so the group assignment must use a prefix
    # test as well. A substring test (contains) could attach a DX_group to a code
    # that merely has the prefix somewhere in the middle.
    mental_health_dx_phenotype_cohort_partial = tmp_cohort.join(
        disease_partial,
        tmp_cohort.condition_source_value.startswith(disease_partial.disease_code),
    ).withColumn("mom_mrn", tmp_cohort.mom_person_source_value)

    # Combine both cohorts.
    return union_dataframes([
        mental_health_dx_phenotype_cohort_exact,
        mental_health_dx_phenotype_cohort_partial,
    ])



In [0]:
# Attach phenotype-cohort columns and DX_group to the prenatal diagnoses.
mental_health_dx_prenatal_and_cohort_tmp = get_matched_mentaldx_phenotype_cohorts(
    mental_health_dx_prenatal
)

# Visit lookup (id -> visit concept / start time); the postpartum cell reuses visit_df.
visit_df = spark.sql(
    f"select visit_occurrence_id,visit_concept_id,visit_start_datetime from {visit_table}"
)

# Left join so diagnoses without a visit_occurrence_id are kept (with null visit columns).
mental_health_dx_prenatal_and_cohort = mental_health_dx_prenatal_and_cohort_tmp.join(
    visit_df, on=['visit_occurrence_id'], how='left'
)
mental_health_dx_prenatal_and_cohort = update_df(mental_health_dx_prenatal_and_cohort)
mental_health_dx_prenatal_and_cohort = mental_health_dx_prenatal_and_cohort.withColumn(
    "period", lit('prenatal')
)

# The .name attribute is set before handing the frame to the project helper.
mental_health_dx_prenatal_and_cohort.name = 'mental_health_dx_prenatal_and_cohort'
register_parquet_global_view(mental_health_dx_prenatal_and_cohort)

##### Validation

In [0]:
# Validation: run the project inspection helper on the registered prenatal global view.
df_inspection("global_temp.mental_health_dx_prenatal_and_cohort","all")

##### Timeframe: postpartum

In [0]:
# Postpartum diagnoses.
# (c): condition rows matching an exact code or a wildcard prefix, with the
#      concept name of the source concept attached.
# (d): the phenotype cohort; joined on the mother's person_id.
# Final filter keeps codes dated from the day of birth through birth + 365 days (inclusive).
sql=f"""

     select mom_person_source_value as mom_mrn,* from
     (
      select a.*, b.concept_name from
      (select * from {cond_table} where (condition_source_value in (select disease_code from disease_exact) or ({disease_partial_str}))) a
      left join {concept_table} b
      on a.condition_source_concept_id = b.concept_id
     ) c
     
     inner join 
     
     (select * from phenotype_cohort) d
     on c.person_id = d.mom_person_id 
     
     where condition_start_date >= date(baby_birth_datetime) and condition_start_date <=  date_add(baby_birth_datetime, 365) ;
     
    """


tmp_cohort= spark.sql(sql)

### partial code match
# Wildcard ('xx') codes are prefixes: the SQL above selects them with
# like '<prefix>%', so the DX_group assignment must be a prefix test too.
# A substring test (contains) could attach a group to a code that only has the
# prefix somewhere in the middle.
mental_health_dx_postpartum_and_cohort_partial = tmp_cohort.join(
    disease_partial,
    tmp_cohort.condition_source_value.startswith(disease_partial.disease_code),
)

### exact code match
mental_health_dx_postpartum_and_cohort_exact = tmp_cohort.join(
    disease_exact,
    tmp_cohort.condition_source_value == disease_exact.disease_code,
)

### combine both cohorts
mental_health_dx_postpartum_and_cohort_tmp = union_dataframes(
    [mental_health_dx_postpartum_and_cohort_exact, mental_health_dx_postpartum_and_cohort_partial]
)

### get visit information
# visit_df is the visit lookup built in the prenatal cell; left join keeps
# diagnoses that have no visit_occurrence_id.
mental_health_dx_postpartum_and_cohort = mental_health_dx_postpartum_and_cohort_tmp.join(
    visit_df, on=['visit_occurrence_id'], how='left'
)

mental_health_dx_postpartum_and_cohort = update_df(mental_health_dx_postpartum_and_cohort)
mental_health_dx_postpartum_and_cohort = mental_health_dx_postpartum_and_cohort.withColumn(
    "period", lit('postpartum')
)
mental_health_dx_postpartum_and_cohort.name = 'mental_health_dx_postpartum_and_cohort'
register_parquet_global_view(mental_health_dx_postpartum_and_cohort)

In [0]:
%sql
-- Display every row of the registered postpartum global view.
select * from global_temp.mental_health_dx_postpartum_and_cohort

##### Validation

In [0]:
# %sql
# select * from global_temp.mental_health_dx_postpartum_and_cohort;

In [0]:
# Validation: run the project inspection helper on the registered postpartum global view.
df_inspection("global_temp.mental_health_dx_postpartum_and_cohort","all")

##### Combine prenatal and postpartum data

In [0]:
# Stack the prenatal and postpartum views into one cohort.
# `union` (not `union all`) removes rows that are identical in every column.
# `select *` unions by column position, so both views must expose the same
# columns in the same order (both are built via update_df plus a `period` column).
sql="""
        select * from global_temp.mental_health_dx_prenatal_and_cohort
        union
        select * from global_temp.mental_health_dx_postpartum_and_cohort
    """
combined_cohort=spark.sql(sql)
combined_cohort.createOrReplaceTempView("combined_cohort")
display(combined_cohort)

In [0]:
# Validation: run the project inspection helper on the combined_cohort temp view.
df_inspection("combined_cohort","all")

##### Used to get IP(9201) / OP(9202) cohorts

In [0]:
def ip_op_cohort(cohort_name,ip_op_code):
    """Return the (mom, baby, DX_group) triples that meet the visit-count rule.

    Args:
        cohort_name: name of a table/view with mom_person_id, baby_person_id,
            DX_group and visit_concept_id columns.
        ip_op_code: visit concept id, '9201' (inpatient) or '9202' (outpatient);
            an int is accepted as well.

    Returns:
        DataFrame with columns mom_person_id, baby_person_id, DX_group: one row
        per triple having more than 0 inpatient rows ('9201') or more than 1
        outpatient row ('9202').

    Raises:
        ValueError: if ip_op_code is neither 9201 nor 9202. Previously this
            surfaced as an UnboundLocalError on ``count``.
    """
    # Exclusive lower bound on the number of rows per triple, per visit type.
    min_rows_exclusive = {'9201': 0, '9202': 1}

    code = str(ip_op_code).strip()
    if code not in min_rows_exclusive:
        raise ValueError(
            f"ip_op_code must be '9201' (inpatient) or '9202' (outpatient), got {ip_op_code!r}"
        )
    count = min_rows_exclusive[code]

    # The former `order by` inside the sub-query was dropped: ordering inside a
    # derived table does not guarantee the order of the outer result.
    sql=f"""
        select mom_person_id,baby_person_id,DX_group
        from {cohort_name}
        where visit_concept_id = {code}
        group by mom_person_id,baby_person_id,DX_group
        having count(*) > {count}
       """

    ip_op_df=spark.sql(sql)
    return ip_op_df

In [0]:
def ip_op_cohort_detail(cohort_name,ip_op_code,mom_baby_df_name):
    """Return the raw code rows behind the qualifying (mom, baby, DX_group) triples.

    Args:
        cohort_name: table/view with the code-level rows (mom_person_id,
            baby_person_id, DX_group, icd_code, code_datetime,
            visit_concept_id, period).
        ip_op_code: visit concept id used to filter ``cohort_name``.
        mom_baby_df_name: table/view listing the qualifying mom_person_id,
            baby_person_id, DX_group triples.

    Returns:
        DataFrame with mom_person_id, baby_person_id, DX_group, icd_code,
        code_datetime, IP_OP ('IP' when visit_concept_id is 9201, else 'OP')
        and period.
    """
    # A left semi join on the three key columns replaces the former
    # CONCAT(...) IN (select CONCAT(...)) test. Concatenating ids without a
    # delimiter is ambiguous (mom 12 + baby 3 and mom 1 + baby 23 both give
    # '123'), so unrelated pairs could be matched.
    sql=f"""
        select a.mom_person_id,a.baby_person_id,a.DX_group,a.icd_code,a.code_datetime,
        case when a.visit_concept_id= 9201 then 'IP' else 'OP' end as IP_OP,a.period
        from {cohort_name} a
        left semi join {mom_baby_df_name} b
        on a.mom_person_id = b.mom_person_id
        and a.baby_person_id = b.baby_person_id
        and a.DX_group = b.DX_group
        where a.visit_concept_id = {ip_op_code}
       """

    ip_op_detail_df=spark.sql(sql)
    return ip_op_detail_df

##### select mom and baby pairs from previous steps, group by id and having count(*) 
##### 1 ICD code count for inpatient (> 0) 
##### 2 ICD code count for outpatient ( > 1)

In [0]:
# --- Prenatal period ---
# Triples meeting the inpatient rule ('9201') and the outpatient rule ('9202').
mom_mental_prenatal_ip1 = ip_op_cohort(
    'global_temp.mental_health_dx_prenatal_and_cohort', '9201'
)
mom_mental_prenatal_op2 = ip_op_cohort(
    'global_temp.mental_health_dx_prenatal_and_cohort', '9202'
)

# Either rule qualifies: stack both results and expose them as a temp view.
mom_mental_prenatal_ip1_op2 = union_dataframes(
    [mom_mental_prenatal_ip1, mom_mental_prenatal_op2]
)
mom_mental_prenatal_ip1_op2.createOrReplaceTempView("mom_mental_prenatal_ip1_op2")

# --- Postpartum period ---
mom_mental_postpartum_ip1 = ip_op_cohort(
    'global_temp.mental_health_dx_postpartum_and_cohort', '9201'
)
mom_mental_postpartum_op2 = ip_op_cohort(
    'global_temp.mental_health_dx_postpartum_and_cohort', '9202'
)

# Either rule qualifies: stack both results and expose them as a temp view.
mom_mental_postpartum_ip1_op2 = union_dataframes(
    [mom_mental_postpartum_ip1, mom_mental_postpartum_op2]
)
mom_mental_postpartum_ip1_op2.createOrReplaceTempView("mom_mental_postpartum_ip1_op2")

##### 1 inpatient OR (any) 2 outpatient ICD-9/ICD-10 codes in the prenatal or 1-year postpartum period 

In [0]:
# Qualifying triples from either period, stacked and exposed as a temp view.
mom_mental_ip1_op2 = union_dataframes(
    [mom_mental_prenatal_ip1_op2, mom_mental_postpartum_ip1_op2]
)
mom_mental_ip1_op2.createOrReplaceTempView("mom_mental_ip1_op2")

df_inspection("mom_mental_ip1_op2", "all")


##### Get the mom/baby IDs that have records without a visit_occurrence_id and that are not in the ip1/op2 groups

In [0]:
%sql
-- Display every row of the combined (prenatal + postpartum) cohort temp view.
select * from combined_cohort

In [0]:
# Mom/baby pairs that have at least one code row without a visit_occurrence_id
# and that do not already qualify through the ip1/op2 rule.
#
# A left anti join on both id columns replaces the former
# CONCAT(mom, baby) NOT IN (select CONCAT(mom, baby) ...) test:
#   * concatenating ids without a delimiter is ambiguous (12 + 3 == 1 + 23),
#     so an unrelated qualifying pair could wrongly exclude a pair here;
#   * NOT IN yields no rows at all as soon as the sub-query returns a NULL.
# `distinct` keeps one row per pair, so the follow-up join back to
# combined_cohort does not multiply rows.
sql="""
       select distinct a.mom_person_id,a.baby_person_id
       from combined_cohort a
       left anti join mom_mental_ip1_op2 b
       on a.mom_person_id = b.mom_person_id
       and a.baby_person_id = b.baby_person_id
       where a.visit_occurrence_id is null
    """

mom_mental_no_visit_id_pairs= spark.sql(sql)
mom_mental_no_visit_id_pairs.createOrReplaceTempView("mom_mental_no_visit_id_pairs")
#df_inspection("mom_mental_no_visit_id_pairs","all")

##### Find records in the condition table that have no visit_occurrence_id. Then use other information (e.g., condition_start_date, visit_start_date, visit_end_date and visit_concept_id) to find matching visit records around the same time period

##### Create a view of the records that have no visit_occurrence_id information

In [0]:
# Code-level rows without a visit_occurrence_id, restricted to the mom/baby
# pairs listed in mom_mental_no_visit_id_pairs.
# Only columns of combined_cohort (a.*) are returned.
sql="""
       select a.* from combined_cohort a
       inner join mom_mental_no_visit_id_pairs b
       on a.mom_person_id = b.mom_person_id and 
       a.baby_person_id = b.baby_person_id
       where visit_occurrence_id is null
       order by mom_person_id,baby_person_id;
      
    """
mom_mental_no_visitoccurrence_id= spark.sql(sql)
mom_mental_no_visitoccurrence_id.createOrReplaceTempView("mom_mental_no_visitoccurrence_id")
display(mom_mental_no_visitoccurrence_id)


##### condition1:  had visit_end_date


In [0]:
# Condition 1: recover a visit for codes that lack a visit_occurrence_id by
# matching the mother's visits that HAVE a visit_end_date and whose
# [visit_start_date, visit_end_date) window contains the code date.
# `distinct` collapses duplicate code/visit combinations.
# NOTE(review): the upper bound is exclusive (code_date < visit_end_date), so a
# visit whose visit_start_date equals its visit_end_date can never match here --
# confirm that excluding same-day visits is intended.
sql=f"""
       select distinct a.mom_person_id,a.baby_person_id,a.DX_group,a.icd_code,
       a.code_date,a.code_datetime,a.period,b.visit_occurrence_id,b.visit_concept_id
       
       from mom_mental_no_visitoccurrence_id a
       
       inner join {visit_table} b 
       on a.mom_person_id = b.person_id and visit_end_date is not null
       and visit_start_date <= code_date and code_date < visit_end_date
       order by mom_person_id,baby_person_id;

    """
mom_mental_no_visitoccurrence_id_cond1= spark.sql(sql)
mom_mental_no_visitoccurrence_id_cond1.createOrReplaceTempView("mom_mental_no_visitoccurrence_id_cond1")
display(mom_mental_no_visitoccurrence_id_cond1)

In [0]:
%sql
-- Ad-hoc spot check: look up one visit record by a hard-coded visit_occurrence_id.
-- NOTE(review): reads rd_omop_prod.visit_occurrence directly rather than the visit table variable used by the Python cells -- confirm both refer to the same table.
select * from rd_omop_prod.visit_occurrence where visit_occurrence_id=118040762

##### condition2:  did not have visit_end_date (use visit_start_date+30), and visit concept is '9201' (inpatient)

In [0]:
# Condition 2: recover a visit for codes that lack a visit_occurrence_id by
# matching the mother's INPATIENT visits (concept 9201) that have NO
# visit_end_date, using [visit_start_date, visit_start_date + 30 days) as the
# visit window.
#
# Written as an explicit inner join: the former `left join` combined with WHERE
# predicates on the visit table already discarded every unmatched row, so the
# result is the same but the intent is now visible and consistent with
# condition 1. Columns are qualified to make their origin unambiguous.
sql=f"""
       select distinct a.mom_person_id,a.baby_person_id,a.DX_group,a.icd_code,
       a.code_date,a.code_datetime,a.period,b.visit_occurrence_id,b.visit_concept_id

       from mom_mental_no_visitoccurrence_id a

       inner join {visit_table} b
       on a.mom_person_id = b.person_id
       and b.visit_end_date is null
       and b.visit_concept_id = '9201'
       and b.visit_start_date <= a.code_date
       and a.code_date < date_add(b.visit_start_date, 30)
       order by mom_person_id,baby_person_id;

    """
mom_mental_no_visitoccurrence_id_cond2= spark.sql(sql)
mom_mental_no_visitoccurrence_id_cond2.createOrReplaceTempView("mom_mental_no_visitoccurrence_id_cond2")
display(mom_mental_no_visitoccurrence_id_cond2)

In [0]:
%sql
-- Ad-hoc spot check: look up one visit record by a hard-coded visit_occurrence_id.
-- NOTE(review): reads rd_omop_prod.visit_occurrence directly rather than the visit table variable used by the Python cells -- confirm both refer to the same table.
select * from rd_omop_prod.visit_occurrence where visit_occurrence_id=239111954

##### condition3: did not have visit_end_date (use visit_start_date+3) , and visit concept is '9202' (outpatient)

In [0]:
# Condition 3: recover a visit for codes that lack a visit_occurrence_id by
# matching the mother's OUTPATIENT visits (concept 9202) that have NO
# visit_end_date, using [visit_start_date, visit_start_date + 3 days) as the
# visit window.
#
# Written as an explicit inner join: the former `left join` combined with WHERE
# predicates on the visit table already discarded every unmatched row, so the
# result is the same but the intent is now visible and consistent with
# condition 1. Columns are qualified to make their origin unambiguous.
sql=f"""
       select distinct a.mom_person_id,a.baby_person_id,a.DX_group,a.icd_code,
       a.code_date,a.code_datetime,a.period,b.visit_occurrence_id,b.visit_concept_id

       from mom_mental_no_visitoccurrence_id a

       inner join {visit_table} b
       on a.mom_person_id = b.person_id
       and b.visit_end_date is null
       and b.visit_concept_id = '9202'
       and b.visit_start_date <= a.code_date
       and a.code_date < date_add(b.visit_start_date, 3)
       order by mom_person_id,baby_person_id;

    """
mom_mental_no_visitoccurrence_id_cond3= spark.sql(sql)
mom_mental_no_visitoccurrence_id_cond3.createOrReplaceTempView("mom_mental_no_visitoccurrence_id_cond3")
display(mom_mental_no_visitoccurrence_id_cond3)

In [0]:
%sql
-- Ad-hoc spot check: look up one visit record by a hard-coded visit_occurrence_id.
-- NOTE(review): reads rd_omop_prod.visit_occurrence directly rather than the visit table variable used by the Python cells -- confirm both refer to the same table.
select * from rd_omop_prod.visit_occurrence where visit_occurrence_id=248501435

##### Merge combined_cohort and cohorts from all 3 conditions

In [0]:
# Select the same columns, in the same order, from each source so the
# positional unions line up.
# combined_cohort supplies the original rows (including those whose visit
# columns are null); cond1-cond3 add copies of the visit-less rows with a visit
# recovered by date matching. `union` drops rows identical in every column.

sql="""
       select mom_person_id,baby_person_id,DX_group,icd_code,code_date,code_datetime,visit_occurrence_id,visit_concept_id,
       period from combined_cohort
       union 
       select mom_person_id,baby_person_id,DX_group,icd_code,code_date,code_datetime,visit_occurrence_id,visit_concept_id,period 
       from mom_mental_no_visitoccurrence_id_cond1
       union 
       select mom_person_id,baby_person_id,DX_group,icd_code,code_date,code_datetime,visit_occurrence_id,visit_concept_id,period 
       from mom_mental_no_visitoccurrence_id_cond2
       union 
       select mom_person_id,baby_person_id,DX_group,icd_code,code_date,code_datetime,visit_occurrence_id,visit_concept_id,period 
       from mom_mental_no_visitoccurrence_id_cond3;

    """
mom_mental_update= spark.sql(sql)
mom_mental_update.createOrReplaceTempView("mom_mental_update")
display(mom_mental_update)

In [0]:
# Validation: run the project inspection helper on the merged mom_mental_update temp view.
df_inspection("mom_mental_update","all")

##### Redefine 1x inpatient and 2x outpatient

In [0]:
# ---------------------------------------------------------------------------
# Prenatal period
# ---------------------------------------------------------------------------
updated_prenatal_df = spark.sql("select * from mom_mental_update where period='prenatal'")
updated_prenatal_df.createOrReplaceTempView("updated_prenatal_df")

# Triples meeting the inpatient ('9201') / outpatient ('9202') rule; each result
# is registered as a temp view because ip_op_cohort_detail refers to it by name.
mom_mental_prenatal_ip1_update = ip_op_cohort('updated_prenatal_df', '9201')
mom_mental_prenatal_ip1_update.createOrReplaceTempView("mom_mental_prenatal_ip1_update")
mom_mental_prenatal_op2_update = ip_op_cohort('updated_prenatal_df', '9202')
mom_mental_prenatal_op2_update.createOrReplaceTempView("mom_mental_prenatal_op2_update")

# Raw code rows behind the qualifying prenatal triples.
mom_mental_prenatal_ip1_update_detail = ip_op_cohort_detail(
    'updated_prenatal_df', '9201', "mom_mental_prenatal_ip1_update"
)
mom_mental_prenatal_op2_update_detail = ip_op_cohort_detail(
    'updated_prenatal_df', '9202', "mom_mental_prenatal_op2_update"
)
mom_mental_prenatal_ip1_op2_update_detail = union_dataframes(
    [mom_mental_prenatal_ip1_update_detail, mom_mental_prenatal_op2_update_detail]
)
mom_mental_prenatal_ip1_op2_update_detail.createOrReplaceTempView(
    "mom_mental_prenatal_ip1_op2_update_detail"
)

# Qualifying prenatal triples (inpatient rule OR outpatient rule).
mom_mental_prenatal_ip1_op2_update = union_dataframes(
    [mom_mental_prenatal_ip1_update, mom_mental_prenatal_op2_update]
)
mom_mental_prenatal_ip1_op2_update.createOrReplaceTempView("mom_mental_prenatal_ip1_op2_update")

# ---------------------------------------------------------------------------
# Postpartum period
# ---------------------------------------------------------------------------
updated_postpartum_df = spark.sql("select * from mom_mental_update where period='postpartum'")
updated_postpartum_df.createOrReplaceTempView("updated_postpartum_df")

mom_mental_postpartum_ip1_update = ip_op_cohort('updated_postpartum_df', '9201')
mom_mental_postpartum_ip1_update.createOrReplaceTempView("mom_mental_postpartum_ip1_update")
mom_mental_postpartum_op2_update = ip_op_cohort('updated_postpartum_df', '9202')
mom_mental_postpartum_op2_update.createOrReplaceTempView("mom_mental_postpartum_op2_update")

# Qualifying postpartum triples (inpatient rule OR outpatient rule).
mom_mental_postpartum_ip1_op2_update = union_dataframes(
    [mom_mental_postpartum_ip1_update, mom_mental_postpartum_op2_update]
)
mom_mental_postpartum_ip1_op2_update.createOrReplaceTempView("mom_mental_postpartum_ip1_op2_update")

# Raw code rows behind the qualifying postpartum triples.
mom_mental_postpartum_ip1_update_detail = ip_op_cohort_detail(
    'updated_postpartum_df', '9201', "mom_mental_postpartum_ip1_update"
)
mom_mental_postpartum_op2_update_detail = ip_op_cohort_detail(
    'updated_postpartum_df', '9202', "mom_mental_postpartum_op2_update"
)
mom_mental_postpartum_ip1_op2_update_detail = union_dataframes(
    [mom_mental_postpartum_ip1_update_detail, mom_mental_postpartum_op2_update_detail]
)
mom_mental_postpartum_ip1_op2_update_detail.createOrReplaceTempView(
    "mom_mental_postpartum_ip1_op2_update_detail"
)

# ---------------------------------------------------------------------------
# Both periods: raw rows, exposed as a temp view and as a global view
# ---------------------------------------------------------------------------
mom_mental_ip1_op2_update_detail = union_dataframes(
    [mom_mental_prenatal_ip1_op2_update_detail, mom_mental_postpartum_ip1_op2_update_detail]
)
mom_mental_ip1_op2_update_detail.createOrReplaceTempView("mom_mental_ip1_op2_update_detail")

mom_mental_ip1_op2_update_detail.name = 'mom_mental_ip1_op2_update_detail'
register_parquet_global_view(mom_mental_ip1_op2_update_detail)

In [0]:
# Final raw output: join the qualifying code rows back to the phenotype cohort
# on the mom/baby pair to expose the source identifiers (as MOM_MRN / BABY_MRN)
# next to the diagnosis group, ICD code, code datetime, IP/OP flag and period.
sql="""
       select mom_person_source_value as MOM_MRN,baby_person_source_value as BABY_MRN,
       DX_group,icd_code,code_datetime,IP_OP,period 
       from phenotype_cohort a inner join mom_mental_ip1_op2_update_detail b
       on a.mom_person_id=b.mom_person_id and a.baby_person_id=b.baby_person_id
       order by MOM_MRN,BABY_MRN;
    """
mat_mentl_raw_output=spark.sql(sql)
# The .name attribute is set before handing the frame to the project helper.
mat_mentl_raw_output.name='mat_mentl_raw_output'
register_parquet_global_view(mat_mentl_raw_output)
display(mat_mentl_raw_output)
#mat_mentl_raw_output.write.mode("overwrite").saveAsTable("phenotype_output.cover_mat_mental_raw_data")

### Save Output for future use

In [0]:
# Persist the raw output as a managed table, replacing any previous version.
(
    mat_mentl_raw_output.write
    .mode("overwrite")
    .saveAsTable("covariate_output.cover_mat_mental_raw_data")
)