## Initiate session and imports

In [None]:
# Echo the SparkSession that already exists in the notebook namespace
# (presumably provided by the kernel/platform) to confirm Spark is available.
spark

In [None]:
from pyspark.sql import SparkSession

# Build (or reuse) the SparkSession, requesting a larger driver result size so
# the toPandas()/collect-style calls later in the notebook are less likely to
# hit the limit.
# NOTE(review): the previous cell shows a session already exists; getOrCreate()
# then returns that session and non-runtime settings such as
# spark.driver.maxResultSize may not take effect — confirm via spark.conf.
session_builder = SparkSession.builder.appName("YourAppName")
session_builder = session_builder.config("spark.driver.maxResultSize", "4g")
spark = session_builder.getOrCreate()

In [None]:
# Set max columns, rows and column width in pandas so DataFrame previews
# are not truncated.
import pandas as pd
pd.set_option('display.max_colwidth', 250)  # None = unlimited (negative values are deprecated since pandas 1.0)
pd.set_option('display.max_columns', None)  # None = show every column
pd.set_option('display.max_rows', None)  # None = show every row; keep toPandas() previews small

# Widen the notebook cell container to 92% of the window width.
# IPython.display is the public home of display/HTML; importing them from
# IPython.core.display is deprecated in recent IPython releases.
from IPython.display import display, HTML
display(HTML("<style>.container { width:92% !important; }</style>"))

## Imports (Spark column helpers used throughout the notebook)
from pyspark.sql.functions import col, current_date, datediff, to_date, when

## MA Cohort Selection

In [None]:
# Make real_world_data_jun_2023 the current database so the unqualified table
# names queried below (condition, problem_list, procedure, ...) resolve to it.
spark.sql("use real_world_data_jun_2023")

In [None]:
# Select personid (plus standard code and display text) from the condition
# table for every row whose standard code is in the MA code list
# (ICD-9, ICD-10 and SNOMED CT codes).
# The same code list is repeated in the problem_list query; keep both in sync.

condition_sdf = spark.sql("""
    select personid, conditioncode.standard.id as stdid,
    conditioncode.standard.primaryDisplay as display
    from condition
    where conditioncode.standard.id in ('Q51.0','752.31', '204844007', '783231007', '783230008', '17142008', 
	'Q51.1','Q51.10', 'Q51.11','752.2', '360422007', '10835661000119100', '21346009', '722431007', 
	'Q51.2', '752.35',  '1230025003','22504001', 'Q51.3', '752.34', '31401003', '237223005', '237224004', '237225003', '237221007', '237220008', '237219002',
	'Q51.818', '752.39','Q51.811', '752.32', '253832006', 'Q51.4', '752.33', '1372004', 'Q51.810', '752.36', '38437003', 
	'Q52.11', '752.46', '142191000119104', 'Q52.12', 'Q52.120', 'Q52.121', 'Q52.122', 'Q52.123', 'Q52.124', 'Q52.129', '752.47', '142201000119101')
""")

# Cached because it is counted here and unioned with the problem-list rows later.
condition_sdf.cache()

print(condition_sdf.select('personid').distinct().count())

In [None]:
# Pull personid, standard code and display text from the problem_list table
# for rows whose standard code is one of the MA codes
# (ICD-9, ICD-10 and SNOMED CT). The cached result is unioned with the
# condition rows later on; the distinct person count is printed.

problem_sdf = spark.sql("""
    select personid, problemlistcode.standard.id as stdid,
    problemlistcode.standard.primaryDisplay as display
    from problem_list
    where problemlistcode.standard.id in ('Q51.0','752.31', '204844007', '783231007', '783230008', '17142008', 
	'Q51.1','Q51.10', 'Q51.11','752.2', '360422007', '10835661000119100', '21346009', '722431007', 
	'Q51.2', '752.35',  '1230025003','22504001', 'Q51.3', '752.34', '31401003', '237223005', '237224004', '237225003', '237221007', '237220008', '237219002',
	'Q51.818', '752.39','Q51.811', '752.32', '253832006', 'Q51.4', '752.33', '1372004', 'Q51.810', '752.36', '38437003', 
	'Q52.11', '752.46', '142191000119104', 'Q52.12', 'Q52.120', 'Q52.121', 'Q52.122', 'Q52.123', 'Q52.124', 'Q52.129', '752.47', '142201000119101')
""").cache()

n_problem_persons = problem_sdf.select('personid').distinct().count()
print(n_problem_persons)

In [None]:
## Concatenate the MA-coded rows from condition and problem_list into one
## DataFrame. union() matches columns by position (both queries select
## personid, stdid, display in that order) and keeps duplicates, so
## distinct() is applied before the cohort rows and person list are derived.

all_codes = condition_sdf.union(problem_sdf)

concat_count = all_codes.select('personid').distinct().count()
print(concat_count)

# Only a cell's last expression is rendered, so render the preview explicitly.
display(all_codes.limit(10).toPandas())

CUA_patients_codes = all_codes.distinct()

# Distinct cohort person IDs collected to the driver; later cells filter the
# full tables with .isin(CUA_personlist).
CUA_personlist = list(CUA_patients_codes.select('personid').distinct().toPandas()['personid'])
print(len(CUA_personlist))

In [None]:
# Build a one-column DataFrame of the distinct cohort person IDs
# (it is written to CUA_db in the next cell).

personid_table = all_codes.select('personid').distinct()

# Only a cell's last expression is rendered, so render the preview explicitly.
display(personid_table.limit(5).toPandas())

print(personid_table.count())


In [None]:
## Save the cohort tables to the pre-existing database CUA_db:
#   personid_table     -> a single column of distinct person IDs
#   CUA_patients_codes -> personid + MA diagnostic code and display text
# NOTE(review): saveAsTable uses the default save mode, which raises if the
# table already exists, so re-running requires dropping these tables first.
for df_to_save, target_table in (
    (personid_table, 'CUA_db.personid_table'),
    (CUA_patients_codes, 'CUA_db.CUA_patient_codes'),
):
    df_to_save.write.saveAsTable(target_table)

## Create Table for MA cohort with ALL condition codes

### Temporarily pull full tables with ALL condition codes / problem-list entries

In [None]:
# Pull EVERY condition row (personid, standard code, display text) and cache
# it; it is filtered down to the cohort in a later cell.
condition = spark.sql("""
    select personid, conditioncode.standard.id as stdid, conditioncode.standard.primaryDisplay as display
    from condition
""").cache()
condition

In [None]:
# Pull EVERY problem-list row (personid, standard code, display text) and
# cache it; it is filtered down to the cohort in a later cell.
problem = spark.sql("""
    select personid, problemlistcode.standard.id as stdid,
    problemlistcode.standard.primaryDisplay as display
    from problem_list
""").cache()
problem

### Filter condition and problem tables for MA population and combine

In [None]:
# Keep only problem-list rows for cohort persons and report how many
# distinct persons remain.
CUA_prob = problem.where(problem["personid"].isin(CUA_personlist))
n_prob_persons = CUA_prob.select('personid').distinct().count()
print(n_prob_persons)

In [None]:
# Keep only condition rows for cohort persons and report how many
# distinct persons remain.
CUA_cond = condition.where(condition["personid"].isin(CUA_personlist))
n_cond_persons = CUA_cond.select('personid').distinct().count()
print(n_cond_persons)

In [None]:
# Stack the cohort's condition rows on top of its problem-list rows (same
# column order), print the distinct person count and preview 20 rows.
CUA_concat = CUA_cond.union(CUA_prob)

distinct_people = CUA_concat.select('personid').distinct()
CUA_concat_count = distinct_people.count()
print(CUA_concat_count)

preview_rows = 20
CUA_concat.limit(preview_rows).toPandas()

In [None]:
# Persist the cohort's combined condition + problem-list rows as CUA_db.CUA_concat.
# NOTE(review): default save mode raises if the table already exists.
CUA_concat.write.saveAsTable('CUA_db.CUA_concat')

## Add Procedure Table

In [None]:
# Pull every procedure row: person, standard procedure code and display text,
# and the service start/end dates; echo the resulting DataFrame.
procedure_query = """
    SELECT 
    personid,
    procedurecode.standard.id as procedure_code,
    procedurecode.standard.primaryDisplay as procedure_display,
    servicestartdate as startdate,
    serviceenddate as enddate
    FROM procedure
"""
procedure = spark.sql(procedure_query)
procedure

In [None]:
# Restrict procedures to the cohort persons. The person list built earlier in
# the notebook is named CUA_personlist (CUAperson_list does not exist and
# raised NameError).
CUA_procedure = procedure.filter(col("personid").isin(CUA_personlist))

In [None]:
# Preview the first 5 cohort procedure rows.
CUA_procedure.limit(5).toPandas()

In [None]:
# Number of distinct cohort persons with at least one procedure row.
CUA_procedure.select('personid').distinct().count()

In [None]:
# Persist the cohort procedures as CUA_db.procedure_table.
# NOTE(review): default save mode raises if the table already exists.
CUA_procedure.write.saveAsTable('CUA_db.procedure_table')

## Add demographics

In [None]:
# Pull demographics for persons whose gender display text is "Female",
# flattening race, ethnicity, zip code and gender struct fields into columns;
# echo the resulting DataFrame.
demographics_query = """
    SELECT personid, birthdate, deceased, dateofdeath, races.standard.id as race_id, 
    races.standard.primaryDisplay as race_display, 
    ethnicities.standard.id as ethnic_id, 
    ethnicities.standard.codingSystemId as ethnic_id1,
    ethnicities.standard.primaryDisplay as ethnic_display, testpatientflag, 
    zipcodes.zipcode1 as zipcode,
    zipcodes.begineffectiveyear as zip_year,
    gender.standard.primaryDisplay as gender
    FROM demographics
    WHERE gender.standard.primaryDisplay== "Female"
    """
demographics_sdf = spark.sql(demographics_query)
demographics_sdf

In [None]:
from pyspark.sql.functions import lit

# Get the age of patients (in years) as of a fixed reference date, so cohort
# ages are reproducible no matter when the notebook is re-run.
AGE_REFERENCE_DATE = "2023-09-25"  # Sept 25, 2023

# NOTE(review): ethnic_id1 from the demographics query is not in this select
# list, so it is dropped here; age is also computed for deceased persons
# (not capped at dateofdeath). Confirm both are intended.
dem_age = (
    demographics_sdf
    .select('personid', 'birthdate', 'deceased', 'dateofdeath', 'race_id', 'race_display',
            'ethnic_id', 'ethnic_display', 'testpatientflag', 'zipcode', 'zip_year', 'gender')
    # datediff(end, start) gives whole days; 365.25 averages in leap years.
    .withColumn('age', datediff(to_date(lit(AGE_REFERENCE_DATE)), to_date(col('birthdate'))) / 365.25)
    .drop('birthdate')
)
dem_age

In [None]:
# Filter out test patients: keep rows whose testpatientflag equals "False".
# NOTE(review): this compares against the string "False"; if testpatientflag
# is a boolean column it relies on Spark's implicit cast — confirm the type.
true_patients = dem_age.where(col('testpatientflag') == "False")

# Only a cell's last expression is rendered, so render the preview explicitly.
display(true_patients.limit(5).toPandas())

print(true_patients.select('personid').distinct().count())

In [None]:
# Keep the non-test, female demographic rows that belong to the cohort and
# echo how many distinct persons they cover.
demo_CUA = true_patients.where(true_patients["personid"].isin(CUA_personlist))
demo_CUA.select('personid').distinct().count()

In [None]:
# Persist the cohort demographics as CUA_db.demo_CUA. The DataFrame built in
# the previous cell is named demo_CUA (CUA_demo is undefined -> NameError).
# NOTE(review): default save mode raises if the table already exists.
demo_CUA.write.saveAsTable('CUA_db.demo_CUA')

## Incorporate BMI

In [None]:
# Pull measurement rows whose standard code is '39156-5' (presumably the
# LOINC BMI code), exposing the numeric value as BMI_ratio plus date fields;
# echo the resulting DataFrame.
bmi_query = """
    SELECT personid, measurementcode.standard.id as id, measurementcode.standard.primaryDisplay as display, 
    typedvalue.numericValue.value as BMI_ratio, servicedate, typedvalue.dateValue.date 
    FROM measurement
    WHERE measurementcode.standard.id=='39156-5'
"""
BMI_pull = spark.sql(bmi_query)
BMI_pull

In [None]:
# Drop the code, display text and date columns, leaving personid and
# BMI_ratio; preview 10 rows.
unused_bmi_columns = ("id", "display", "servicedate", "date")
BMI = BMI_pull.drop(*unused_bmi_columns)
BMI.limit(10).toPandas()

In [None]:
# FloatType is not imported by the notebook's import cell; without this
# local import the cast below raises NameError.
from pyspark.sql.types import FloatType

# Cast BMI_ratio to a float column before the min/max and range filters.
BMI_numeric = BMI.withColumn("BMI_ratio", col("BMI_ratio").cast(FloatType()))

In [None]:
# Keep only BMI readings for cohort persons and report how many distinct
# persons have at least one reading.
BMI_CUA = BMI_numeric.where(BMI_numeric["personid"].isin(CUA_personlist))
n_bmi_persons = BMI_CUA.select('personid').distinct().count()
print(n_bmi_persons)

In [None]:
from pyspark.sql import functions as F

# Largest BMI_ratio across the cohort (a single-row DataFrame).
# Use Spark's aggregate F.max: the bare name `max` is Python's builtin, which
# returns the largest character of "BMI_ratio" ('t'), so `.alias` would raise
# AttributeError.
max_bmi = BMI_CUA.agg(F.max("BMI_ratio").alias("max_BMI_ratio"))
max_bmi.limit(20).toPandas()

In [None]:
from pyspark.sql import functions as F

# Smallest BMI_ratio across the cohort (a single-row DataFrame).
# Use Spark's aggregate F.min: the bare name `min` is Python's builtin, which
# returns the smallest character of "BMI_ratio" ('B'), so `.alias` would
# raise AttributeError.
min_bmi = BMI_CUA.agg(F.min("BMI_ratio").alias("min_BMI_ratio"))
min_bmi.show()

In [None]:
# Keep readings with 10 <= BMI_ratio <= 204 (both bounds inclusive),
# presumably to drop implausible values.
# NOTE(review): an upper bound of 204 is unusually high — confirm intent.
filtered_BMI = BMI_CUA.filter(col("BMI_ratio").between(10, 204))

In [None]:
# Highest in-range BMI per person; Spark names the result "max(BMI_ratio)".
bmi_by_person = filtered_BMI.groupBy('personid')
BMI_max = bmi_by_person.max('BMI_ratio')

In [None]:
# Replace the generated aggregate column name with a clean identifier.
BMI_max1 = BMI_max.withColumnRenamed(existing="max(BMI_ratio)", new="max_BMI_ratio")

## Combine demo & BMI tables

In [None]:
# Make CUA_db the current database so the next query can read the tables
# saved there by unqualified name.
spark.sql("use CUA_db")

In [None]:
# Read the cohort person IDs back from the demographics table, which is saved
# as CUA_db.demo_CUA (this query runs after `use CUA_db`; there is no
# CUA_demo table).
# NOTE(review): only personid is selected, so the joined demo_BMI table will
# carry no demographic columns despite the section title — confirm intent.
CUA_demo1 = spark.sql("""
    select personid
    from demo_CUA
    """)
CUA_demo1

In [None]:
# Left-join so every cohort person is kept; persons without an in-range BMI
# reading get a null max_BMI_ratio.
demo_BMI = CUA_demo1.join(BMI_max1, ['personid'], how='left')

# Only a cell's last expression is rendered, so render the preview explicitly.
display(demo_BMI.limit(15).toPandas())

print(demo_BMI.select('personid').distinct().count())

In [None]:
# Persist the person + max-BMI table as CUA_db.demo_BMI.
# NOTE(review): default save mode raises if the table already exists.
demo_BMI.write.saveAsTable('CUA_db.demo_BMI')

## Check for newly created tables

In [None]:
# Make sure CUA_db is the current database before listing its tables
# (it was already selected earlier in the notebook; repeating is harmless).
spark.sql("use CUA_db")

In [None]:
## Check for newly created tables: list the tables in the current database
## as a pandas DataFrame.
spark.sql("show tables").toPandas()