**Table definitions and load packages**

In [None]:
# OMOP source tables (all live in the same production schema).
_OMOP_SCHEMA = "victr_sd.sd_omop_prod"
drug_table = f"{_OMOP_SCHEMA}.drug_exposure"
cond_table = f"{_OMOP_SCHEMA}.condition_occurrence"
visit_table = f"{_OMOP_SCHEMA}.visit_occurrence"
person_table = f"{_OMOP_SCHEMA}.person"

# Project tables: the ICD code list and the matched control/case cohorts.
_PROJECT_SCHEMA = "workspace_sdphenotypecore.statin_mental_conditions"
icd_table = f"{_PROJECT_SCHEMA}.icd_condition"
control_cohort = f"{_PROJECT_SCHEMA}.control_patients_1"
case_cohort = f"{_PROJECT_SCHEMA}.matched_cases_1"

# Lower-case names compared against drug_source_value in the statin queries.
statin_list = [
    "statin",
    "atorvastatin",
    "simvastatin",
    "rosuvastatin",
    "pitavastatin",
    "fluvastatin",
    "lovastatin",
    "pravastatin",
    "lipitor",
    "zocor",
    "crestor",
    "livalo",
    "lescol",
    "mevacor",
    "pravachol",
]

In [None]:
#import python packages
from pyspark.sql.functions import expr
from pyspark.sql.functions import when, col
from pyspark.sql.functions import min, max
import numpy as np
import pandas as pd

In [None]:
%r
install.packages("MatchIt")

In [None]:
%r
#import R packages
library(SparkR)
library(sparklyr)
library(dplyr)
library(MatchIt)

# 1. Create database

In [None]:
%sql
--create database statin_mental_conditions--

# 2. Search for patients exposed to statins in the drug_exposure table

In [None]:

# Quoted, lower-cased statin names, rendered once and reused in both IN (...) lists.
statin_in_clause = ", ".join(f"'{name.lower()}'" for name in statin_list)

# One row per statin-exposed patient (gender F/M, race W/B, age >= 18):
#   drug_date  - earliest statin drug_exposure_start_date
#   age        - years between birth_datetime and the latest visit_start_date
#   ehr_length - years between the earliest and latest visit_start_date
# NOTE(review): drug_source_value is compared for exact equality with each
# name in statin_list — confirm the source values are bare drug names.
# NOTE(review): the correlated subquery in HAVING recomputes the same minimum
# under the same statin filter, so it looks redundant — confirm before removing.
sql_drug = f'''
SELECT 
    drug.person_id,
    MIN(drug.drug_exposure_start_date) AS drug_date,
    person.gender_source_value AS gender,
    person.race_source_value AS race,
    DATEDIFF(year, person.birth_datetime, MAX(visit.visit_start_date)) AS age,
    DATEDIFF(year, MIN(visit.visit_start_date), MAX(visit.visit_start_date)) AS ehr_length
FROM  {drug_table} AS drug 
INNER JOIN  {visit_table} AS visit ON visit.person_id = drug.person_id
INNER JOIN {person_table} AS person ON person.person_id = drug.person_id
WHERE LOWER(drug_source_value) IN ({statin_in_clause}) AND 
      person.gender_source_value IN ("F","M") AND person.race_source_value IN ("W","B")
GROUP BY  
    drug.person_id, person.gender_source_value, person.race_source_value, person.birth_datetime
HAVING 
    MIN(drug.drug_exposure_start_date) = (
        SELECT MIN(drug_exposure_start_date)
        FROM {drug_table}
        WHERE person_id = drug.person_id
        AND LOWER(drug_source_value) IN ({statin_in_clause})
    )
    AND DATEDIFF(year, person.birth_datetime, MAX(visit.visit_start_date)) >= 18    
'''

df_drug = spark.sql(sql_drug)
df_drug.display()

In [None]:
# Number of distinct patients in df_drug, i.e. patients exposed to a statin.
distinct_statin_count = df_drug.select("person_id").distinct().count()
# The label previously said "depressed", but df_drug holds statin exposure only.
print("Number of distinct statin-exposed person_id :", distinct_statin_count)

In [None]:
# Total number of rows in df_drug (the statin-exposure result set).
row_count = df_drug.count()

print(f"Number of rows in df_drug: {row_count}")

In [None]:
# Youngest and oldest age among statin-exposed patients.
# min/max are the pyspark.sql.functions aggregates, not the Python builtins.
age_bounds = [min("age").alias("min_age"), max("age").alias("max_age")]
min_max_age = df_drug.agg(*age_bounds)

# Print the single-row summary.
min_max_age.show()

# 3. Search for ICD codes in the condition_occurrence table

In [None]:
# One row per patient (gender F/M, race W/B, age >= 18) with at least one
# condition whose condition_source_value appears in the icd_condition table.
#   ICD_date   - earliest matching condition_start_date for the patient
#                (minimum per code first, then minimum across codes)
#   age        - years between birth_datetime and the latest visit_start_date
#   ehr_length - years between the earliest and latest visit_start_date
sql_icd = f'''
SELECT 
    ICD.person_id,
    ICD.ICD_date,
    person.gender_source_value AS gender,
    person.race_source_value AS race,
    DATEDIFF(year, person.birth_datetime, MAX(visit.visit_start_date)) AS age,
    DATEDIFF(year, MIN(visit.visit_start_date), MAX(visit.visit_start_date)) AS ehr_length
FROM  ( SELECT person_id, MIN(ICD_date) AS ICD_date
        FROM (
            SELECT 
                condition.person_id AS person_id, 
                condition.condition_source_value AS ICD_code, 
                MIN(condition.condition_start_date) AS ICD_date
            FROM 
                {cond_table} AS condition
            INNER JOIN 
                {icd_table} AS icd
            ON 
                condition.condition_source_value = icd.code
            GROUP BY
                condition.person_id, condition.condition_source_value
           ) AS subquery
       GROUP BY person_id) AS ICD
INNER JOIN  {visit_table} AS visit ON visit.person_id = ICD.person_id
INNER JOIN {person_table} AS person ON person.person_id = ICD.person_id
WHERE person.gender_source_value IN ("F","M") AND person.race_source_value IN ("W","B")
GROUP BY  
    ICD.person_id, ICD.ICD_date, person.gender_source_value, person.race_source_value, person.birth_datetime
HAVING DATEDIFF(year, person.birth_datetime, MAX(visit.visit_start_date)) >= 18
'''

df_icd = spark.sql(sql_icd)
df_icd.display()

In [None]:
# Number of distinct patients with a qualifying ICD condition (df_icd).
distinct_icd_count = df_icd.select("person_id").distinct().count()
print("Number of distinct ICD person_id :", distinct_icd_count)

In [None]:
# Total number of rows in df_icd (the ICD-condition result set).
row_count = df_icd.count()

# The label previously named df_drug although df_icd is what gets counted.
print(f"Number of rows in df_icd: {row_count}")

In [None]:
# Youngest and oldest age among patients with a qualifying ICD condition.
# min/max are the pyspark.sql.functions aggregates, not the Python builtins.
age_bounds = [min("age").alias("min_age"), max("age").alias("max_age")]
min_max_age = df_icd.agg(*age_bounds)

# Print the single-row summary.
min_max_age.show()

# 4. Find the case population (statin-exposed patients)

In [None]:
# Statin-exposed patients who also have a qualifying ICD condition.
# NOTE(review): joining on an expression keeps the columns of both inputs, so
# person_id, gender, race, age and ehr_length each appear twice in df_depress.
df_depress = df_drug.join(
    df_icd,
    on=df_drug.person_id == df_icd.person_id,
    how="inner",
)

# The case population is taken from df_drug (every statin-exposed patient),
# reduced to the demographic columns.
columns = ["person_id", "gender", "race", "age", "ehr_length"]
result_df = df_drug.select(columns)

result_df.display()

In [None]:
# Size of the case population: distinct person_id values in result_df.
distinct_person_count = result_df.select("person_id").distinct().count()

print("Number of distinct person_id :", distinct_person_count)

In [None]:
# Total number of rows in result_df (the case population).
row_count = result_df.count()

print(f"Number of rows in result_df: {row_count}")

In [None]:
# Show the joined statin-exposure / ICD-condition rows.
df_depress.display()

In [None]:
# Number of distinct patients in df_depress (statin-exposed with an ICD condition).
# NOTE(review): "drug.person_id" depends on a "drug" qualifier being resolvable
# on the joined DataFrame — confirm it picks the intended person_id column.
distinct_depress_count = df_depress.select("drug.person_id").distinct().count()
print("Number of distinct depressed person_id :", distinct_depress_count)

In [None]:
# Total number of rows in df_depress (statin exposure joined to ICD conditions).
row_count = df_depress.count()

# The label previously named df_drug although df_depress is what gets counted.
print(f"Number of rows in df_depress: {row_count}")

In [None]:
# Share of the case population (distinct_person_count) that also has a
# qualifying ICD condition (distinct_depress_count).
depress_rate = distinct_depress_count/distinct_person_count

print(f"Depress rate in Case: {depress_rate}")

# 5. write result_df to table

In [None]:
# Persist the case population as a Parquet table, replacing any earlier version.
case_writer = result_df.write.format("parquet").mode("overwrite")
case_writer.saveAsTable("statin_mental_conditions.case_cohort")

# 6. quality check

In [None]:
%sql
-- Youngest and oldest age in the case cohort
SELECT min(age), max(age)
FROM statin_mental_conditions.case_cohort


In [None]:
%sql
-- Switch to the hive_metastore catalog, then list every row of the case cohort
use catalog `hive_metastore`; select * from `statin_mental_conditions`.`case_cohort` 

In [None]:
%sql
-- Number of case-cohort rows per gender
SELECT gender, count(gender)
FROM statin_mental_conditions.case_cohort
GROUP BY gender

In [None]:
%sql
-- Number of case-cohort rows per race, most frequent first
SELECT race, count(race)
FROM statin_mental_conditions.case_cohort
GROUP BY race
ORDER BY count(race) DESC

# 7. Get cohort of patients who are not exposed to statin drugs (to sample control)

In [None]:
# Candidate controls: one row per patient (gender F/M, race W/B, age >= 18)
# who has at least one visit and is NOT in the saved case_cohort table.
#   age        - years between birth_datetime and the latest visit_start_date
#   ehr_length - years between the earliest and latest visit_start_date
# NOTE(review): NOT IN yields no rows if the subquery returns a NULL
# person_id — confirm case_cohort.person_id is never NULL.
sql_cohort = f'''
    SELECT 
        visit.person_id,
        person.gender_source_value AS gender,
        person.race_source_value AS race,
        DATEDIFF(year, person.birth_datetime, MAX(visit.visit_start_date)) AS age,
        DATEDIFF(year, MIN(visit.visit_start_date), MAX(visit.visit_start_date)) AS ehr_length
    FROM  {visit_table} AS visit 
    INNER JOIN {person_table} AS person ON person.person_id = visit.person_id
    WHERE visit.person_id not in (select person_id from `statin_mental_conditions`.`case_cohort`)
          AND person.gender_source_value IN ("F","M") AND person.race_source_value IN ("W","B")
    GROUP BY
        visit.person_id, person.gender_source_value, person.race_source_value, person.birth_datetime
    HAVING 
        DATEDIFF(year, person.birth_datetime, MAX(visit.visit_start_date)) >= 18
'''
df_cohort = spark.sql(sql_cohort)
df_cohort.display()

In [None]:
# Number of distinct patients in the candidate control cohort (df_cohort).
distinct_cohort_count = df_cohort.select("person_id").distinct().count()
print("Number of distinct cohort person_id :", distinct_cohort_count)


In [None]:
# Total number of rows in df_cohort (the candidate control cohort).
row_count = df_cohort.count()

# The label previously named df_drug although df_cohort is what gets counted.
print(f"Number of rows in df_cohort: {row_count}")

In [None]:
# Youngest and oldest age in the candidate control cohort.
# This cell sits in the control-cohort section, right before df_cohort is
# written out, but it previously aggregated df_drug (the statin-exposed
# patients) and so repeated the case age range instead of the control one.
# min/max are the pyspark.sql.functions aggregates, not the Python builtins.
min_max_age = df_cohort.agg(
    min("age").alias("min_age"),
    max("age").alias("max_age")
)

# Show the result
min_max_age.show()

In [None]:
# Persist the candidate control cohort as a Parquet table, replacing any earlier version.
cohort_writer = df_cohort.write.format("parquet").mode("overwrite")
cohort_writer.saveAsTable("statin_mental_conditions.control_cohort")

In [None]:
%sql
-- Number of control-cohort rows per gender
SELECT gender, count(gender)
FROM statin_mental_conditions.control_cohort
GROUP BY gender

In [None]:
%sql
-- Number of control-cohort rows per race, most frequent first
SELECT race, count(race)
FROM statin_mental_conditions.control_cohort
GROUP BY race
ORDER BY count(race) DESC

# 8. Using MatchIt to sample control based on case

In [None]:
%r
# Open a sparklyr connection using the "databricks" method; sc is reused by the cells below.
sc <- spark_connect(method = "databricks")

In [None]:
%r
# Pull the whole case_cohort table into a local R data frame and print it.
case <- sdf_sql(sc, "SELECT * FROM statin_mental_conditions.case_cohort") %>%
  collect() %>%
  as.data.frame()
case

In [None]:
%r
# Pull the whole control_cohort table into a local R data frame and print it.
cohort <- sdf_sql(sc, "SELECT * FROM statin_mental_conditions.control_cohort") %>%
  collect() %>%
  as.data.frame()
cohort

In [None]:
%r
# Treatment indicator used by the matching formula: 1 = case, 0 = candidate control.
case$treatment <- 1
cohort$treatment <- 0

In [None]:
%r
# Stack cases and candidate controls into one data frame for matching.
combined_data <- rbind(case, cohort)

In [None]:
%r
# Covariates used to pair each case with a control.
formula <- treatment ~ age + gender + race + ehr_length

# 1:1 nearest-neighbour matching of cases (treatment = 1) to controls (treatment = 0).
match_fit <- matchit(formula, data = combined_data, method = "nearest", ratio = 1)

# Extract the matched sample from the fitted matchit object.
matched_data <- match.data(match_fit)


In [None]:
%r
# Print the matched sample.
matched_data

In [None]:
%r
# Number of rows in the matched sample (cases plus controls).
nrow(matched_data)

In [None]:
%r
# Per-column summary of the matched sample.
summary(matched_data)

In [None]:
%r
# Matched controls are the rows flagged treatment == 0.
matched_controls <- dplyr::filter(matched_data, treatment == 0)

# How many controls were matched.
nrow(matched_controls)

In [None]:
%r
# Print the matched controls.
matched_controls

In [None]:
%r
# Matched cases are the rows flagged treatment == 1.
matched_cases <- matched_data %>%
  filter(treatment == 1) # Select cases

nrow(matched_cases)

In [None]:
%r
# Print the matched cases.
matched_cases

In [None]:
%r
# TRUE if any person_id occurs more than once among the matched controls.
any(duplicated(matched_controls$person_id))

In [None]:
%r
# Keep only the demographic columns of the matched controls, then print them.
matched_controls <- matched_controls[, c("person_id", "gender", "race", "age", "ehr_length")]
matched_controls

In [None]:
%r
# Upload the matched controls to Spark and persist them as a table.
# overwrite = TRUE makes the cell re-runnable: without it copy_to() stops with
# an error when a Spark table named "matched_controls" already exists.
sdf <- copy_to(sc, matched_controls, overwrite = TRUE)
spark_write_table(sdf, name = "statin_mental_conditions.control_patients_1", mode = "overwrite")

In [None]:
%r
# Keep only the demographic columns of the matched cases, then print them.
matched_cases <- matched_cases[, c("person_id", "gender", "race", "age", "ehr_length")]
matched_cases

In [None]:
%r
# Upload the matched cases to Spark and persist them as a table.
# overwrite = TRUE makes the cell re-runnable: without it copy_to() stops with
# an error when a Spark table named "matched_cases" already exists.
sdf_case <- copy_to(sc, matched_cases, overwrite = TRUE)
spark_write_table(sdf_case, name = "statin_mental_conditions.matched_cases_1", mode = "overwrite")

In [None]:
%sql
-- Sanity check: rows vs distinct person_id when the matched controls are joined back to the control cohort
SELECT count(*), count(DISTINCT control.person_id)
FROM statin_mental_conditions.control_patients_1 as control
INNER JOIN statin_mental_conditions.control_cohort as cohort
ON cohort.person_id = control.person_id

In [None]:
# Matched controls that have a condition listed in the icd_condition table.
# The query does no de-duplication, so a control appears once per matching
# condition record; distinct patients are counted separately.
sql_control_depress = f'''
  SELECT control.person_id, control.gender, control.race,  control.age,  control.ehr_length
  FROM (
        SELECT 
            condition.person_id AS person_id, 
            condition.condition_source_value AS ICD_code
        FROM 
            {cond_table} AS condition
        INNER JOIN 
            {icd_table} AS icd
        ON 
            condition.condition_source_value = icd.code) As icd
  INNER JOIN {control_cohort} As control
  ON control.person_id = icd.person_id
'''
df_control_depress = spark.sql(sql_control_depress)
df_control_depress.display()

In [None]:
# Number of distinct matched controls with a qualifying ICD condition.
distinct_con_depress_count = df_control_depress.select("person_id").distinct().count()
print("Number of distinct control depress person_id :", distinct_con_depress_count)

In [None]:
# Total number of rows in df_control_depress (one per matching condition record).
row_count = df_control_depress.count()

print(f"Number of rows in df_control_depress: {row_count}")

In [None]:
# Every row of the matched control table named by control_cohort.
sql_control = f'''
  SELECT *
  FROM {control_cohort}
'''
df_control = spark.sql(sql_control)
df_control.display()

In [None]:
# Number of distinct patients in the matched control table.
distinct_con_count = df_control.select("person_id").distinct().count()
print("Number of distinct control person_id :", distinct_con_count)

In [None]:
# Total number of rows in df_control (the matched control table).
row_count = df_control.count()

print(f"Number of rows in df_control: {row_count}")

In [None]:
# Share of matched controls (distinct_con_count) that have a qualifying ICD
# condition (distinct_con_depress_count).
depress_rate_control = distinct_con_depress_count/distinct_con_count

print(f"Depress rate in control: {depress_rate_control}")

In [None]:
# Matched cases that have a condition listed in the icd_condition table.
# The query does no de-duplication, so a case appears once per matching
# condition record; distinct patients are counted separately.
sql_case_depress = f'''
  SELECT case.person_id, case.gender, case.race, case.age, case.ehr_length
  FROM (
        SELECT 
            condition.person_id AS person_id, 
            condition.condition_source_value AS ICD_code
        FROM 
            {cond_table} AS condition
        INNER JOIN 
            {icd_table} AS icd
        ON 
            condition.condition_source_value = icd.code) As icd
  INNER JOIN {case_cohort} As case
  ON case.person_id = icd.person_id
'''
df_case_depress = spark.sql(sql_case_depress)
df_case_depress.display()

In [None]:
# Quoted, lower-cased statin names, rendered once and reused in both IN (...) lists.
statin_in_clause = ", ".join(f"'{name.lower()}'" for name in statin_list)

# Matched cases whose earliest qualifying ICD date is later than their
# earliest statin exposure date (ICD_date > drug_date).
# NOTE(review): the name says "before" while the filter keeps ICD dates that
# come AFTER the first statin date — confirm which ordering is intended.
sql_case_depress_before = f'''
SELECT 
    case.person_id, 
    case.gender, 
    case.race, 
    case.age, 
    case.ehr_length,
    depress.ICD_date,
    statin.drug_date
FROM (
    SELECT 
        ICD.person_id,
        ICD.ICD_date
    FROM (
        SELECT 
            condition.person_id AS person_id, 
            MIN(condition.condition_start_date) AS ICD_date
        FROM 
            {cond_table} AS condition
        INNER JOIN 
            {icd_table} AS icd
            ON condition.condition_source_value = icd.code
        GROUP BY
            condition.person_id
        ) AS ICD
) AS depress
INNER JOIN (
    SELECT 
        drug.person_id,
        MIN(drug.drug_exposure_start_date) AS drug_date
    FROM {drug_table} AS drug 
    WHERE LOWER(drug.drug_source_value) IN ({statin_in_clause})
    GROUP BY  
        drug.person_id
    HAVING 
        MIN(drug.drug_exposure_start_date) = (
            SELECT MIN(drug_exposure_start_date)
            FROM {drug_table}
            WHERE person_id = drug.person_id
            AND LOWER(drug_source_value) IN ({statin_in_clause}))
) AS statin
ON statin.person_id = depress.person_id
INNER JOIN {case_cohort} AS case
ON case.person_id = depress.person_id
WHERE depress.ICD_date > statin.drug_date
'''

df_case_depress_before = spark.sql(sql_case_depress_before)
df_case_depress_before.display()

In [None]:
# Number of distinct matched cases with a qualifying ICD condition.
distinct_case_depress_count = df_case_depress.select("person_id").distinct().count()
print("Number of distinct case depress person_id :", distinct_case_depress_count)

In [None]:
# Total number of rows in df_case_depress (one per matching condition record).
row_count = df_case_depress.count()

# The label previously named df_control_depress although df_case_depress is counted.
print(f"Number of rows in df_case_depress: {row_count}")

In [None]:
# Number of distinct matched cases in df_case_depress_before, i.e. cases whose
# first ICD date is later than their first statin date.
distinct_case_before_count = df_case_depress_before.select("person_id").distinct().count()
# The label previously matched the one printed for df_case_depress, so the two
# counts could not be told apart in the output.
print("Number of distinct case depress person_id (ICD date after first statin date) :", distinct_case_before_count)

In [None]:
# Total number of rows in df_case_depress_before.
row_count = df_case_depress_before.count()

# The label previously named df_control_depress although df_case_depress_before is counted.
print(f"Number of rows in df_case_depress_before: {row_count}")

In [None]:
# Every row of the matched case table.
# The table is read through case_cohort, the same name the case ICD queries
# join against, so the rate's numerator and denominator come from one table.
# This also mirrors how the control table is read through control_cohort.
# The table name was previously hard-coded without a catalog.
sql_case = f'''
  SELECT *
  FROM {case_cohort}
'''
df_case = spark.sql(sql_case)
df_case.display()

In [None]:
# Number of distinct patients in the matched case table.
distinct_case_count = df_case.select("person_id").distinct().count()
print("Number of distinct case person_id :", distinct_case_count)

In [None]:
# Share of matched cases (distinct_case_count) that have a qualifying ICD
# condition (distinct_case_depress_count).
depress_rate_case = distinct_case_depress_count/distinct_case_count

print(f"Depress rate in case: {depress_rate_case}")

In [None]:
# Share of matched cases (distinct_case_count) whose first ICD date is later
# than their first statin date (distinct_case_before_count).
depress_rate_case_before = distinct_case_before_count/distinct_case_count

# The label previously matched the one printed for depress_rate_case, so the
# two rates could not be told apart in the output.
print(f"Depress rate in case (ICD date after first statin date): {depress_rate_case_before}")