In [None]:
# from pyspark.sql import SparkSession
# spark = SparkSession.builder()
# spark = SparkSession.newSession()

In [1]:
import os

import pandas as pd
# Show every row of displayed DataFrames (the summary tables below are meant to be read in full).
pd.options.display.max_rows = None

In [None]:
# where clause for identifing pregnancy events using ICD codes
def cd_where_clause(cd_type, cd_lst:list, alias=''):
    '''
    sql where clause for cohort selection based on ICD code list
    '''
    cd_field_map = {
        'dx':'conditioncode',
        'px':'procedurecode'
    }
    if alias is not '':
        alias += '.'
    cd_quote = []
    for code in cd_lst:
        cd_quote.append("'"+ code +"'")
    cd_quote_str = ",".join(cd_quote)
    sql_str = "substring_index(upper("+ alias + cd_field_map[cd_type] +".standard.id),'.',1) in (" + cd_quote_str +")"
    return sql_str 

# where clause for identifing pregnancy events using ICD codes
preg_icd = [str(x) for x in list(range(630,679,1))]
preg_icd.extend(['V22','V23','V28','Z33','Z34','Z35','Z36','Z37','Z38','Z3A','O9A'])
preg_icd.extend(['O0' + str(x) for x in list(range(0,9,1))])
preg_icd.extend(['O' + str(x) for x in list(range(10,99,1))])
# Render the pregnancy-code predicate against the `cond` (condition table) alias
# and display it for inspection.
preg_where_clause = cd_where_clause('dx', preg_icd, alias="cond")
preg_where_clause

In [None]:
# Cohort pull from the Sep-2022 release. It selects condition records on inpatient ('I')
# encounters of female patients whose event date (condition date, else arrival
# date) falls in 2010-2022 and whose rounded age at arrival is 12-55.
# `spark` is assumed to be supplied by the notebook environment (it is not created here).
# NOTE(review): this query does not apply `preg_where_clause`, so it is not restricted
# to pregnancy codes. The `preg_pt_init` view it registers is replaced by the next cell
# (Jun-2022 release). Confirm whether this cell is still needed.
# NOTE(review): `select distinct` is a no-op here because (personid, rn) is unique per row.
spark.sql("USE real_world_data_sep_2022")
init_sep_sql = '''
    select distinct
           cond.personid,
           to_date(demo.birthdate) as dob,
           to_date(demo.dateofdeath) as dod,
           coalesce(demo.deceased,false) as deceased,
           demo.gender.standard.id as gender,
           coalesce(demo.races[0].standard.primaryDisplay,'NI') as race,
           coalesce(demo.ethnicities[0].standard.primaryDisplay,'NI') as ethnicity,
           cond.encounterid,
           cond.conditioncode,
           to_date(cond.effectivedate) as dxdate,
           cast(round(datediff(to_date(cond.effectivedate),to_date(demo.birthdate))/365.25) as int) as ageatdx,
           coalesce(enc.classification.standard.id,'NI') as enctype,
           to_date(enc.actualarrivaldate) as admitdate,
           date_add(to_date(enc.actualarrivaldate),365) as admitdate_add1yr,
           row_number() over (partition by cond.personid order by enc.actualarrivaldate,cond.effectivedate) as rn,
           cast(round(datediff(to_date(enc.actualarrivaldate),to_date(demo.birthdate))/365.25) as int) as ageatenc,
           enc.readmission,
           to_date(enc.dischargedate) as dischargedate
    from condition cond
    join encounter enc on cond.encounterid = enc.encounterid and enc.classification.standard.id = 'I'
    join demographics demo on cond.personid = demo.personid and demo.gender.standard.id = 'F'
    where coalesce(to_date(cond.effectivedate),to_date(enc.actualarrivaldate)) between to_date('2010-01-01') and to_date('2022-12-31') and
          cast(round(datediff(to_date(enc.actualarrivaldate),to_date(demo.birthdate))/365.25) as int) between 12 and 55
'''
preg_pt_init = spark.sql(init_sep_sql).cache()
preg_pt_init.createOrReplaceTempView("preg_pt_init")
preg_pt_init.head()

In [None]:
# Initial cohort from the Jun-2022 release. It takes every condition record matching
# the pregnancy-code predicate and joins it to its encounter and the patient's
# demographics. It also left-joins the tenant attributes (bed size, speciality,
# segment) of the encounter's tenant.
# Date and age restrictions are applied later, when building preg_pt_12to55.
# NOTE(review): unlike the Sep-2022 query, there is no gender or inpatient filter,
# and deceased/race/ethnicity/enctype are not coalesced, so they may be NULL.
spark.sql("USE real_world_data_jun_2022")
init_jun_sql = '''
    select distinct
           cond.personid,
           to_date(demo.birthdate) as dob,
           to_date(demo.dateofdeath) as dod,
           demo.deceased as deceased,
           demo.gender.standard.id as gender,
           demo.races[0].standard.primaryDisplay as race,
           demo.ethnicities[0].standard.primaryDisplay as ethnicity,
           cond.encounterid,
           to_date(cond.effectivedate) as dxdate,
           cast(round(datediff(to_date(cond.effectivedate),to_date(demo.birthdate))/365.25) as int) as ageatdx,
           enc.classification.standard.id as enctype,
           to_date(enc.actualarrivaldate) as admitdate,
           cast(round(datediff(to_date(enc.actualarrivaldate),to_date(demo.birthdate))/365.25) as int) as ageatenc,
           enc.readmission,
           to_date(enc.dischargedate) as dischargedate,
           enc.tenant as enctenant,
           tnt.bed_size as bedsize,
           tnt.speciality,
           tnt.segment
    from condition cond
    join encounter enc on cond.encounterid = enc.encounterid
    join demographics demo on cond.personid = demo.personid
    left join tenant_attributes tnt on enc.tenant = tnt.tenant 
    where ''' + preg_where_clause
preg_pt_init = spark.sql(init_jun_sql).cache()
preg_pt_init.createOrReplaceTempView("preg_pt_init")
preg_pt_init.head()

In [None]:
import itertools
def pt_freq_qry(df,stratified_by,n_way=1):
    '''
    Build one Spark SQL query counting distinct patients overall and per stratum.

    Parameters
    ----------
    df : str
        Name of a table / temp view that has a ``personid`` column and every
        column listed in ``stratified_by``.
    stratified_by : list of str
        Column names to stratify by. Repeated names are ignored (first
        occurrence kept), so a variable is never summarized twice or paired
        with itself.
    n_way : int, default 1
        Highest-order combination of variables to summarize. 1 gives one-way
        counts only, 2 adds every pair, and so on.

    Returns
    -------
    str
        A ``union all`` query with columns ``(summ_var, summ_cat, pat_cnt)``:

        * one ``('total', 'N', count)`` row;
        * one row per value of each variable (``summ_cat`` is the value cast
          to string; a NULL value stays NULL);
        * for each combination of 2..n_way variables, one row per observed value
          combination, labelled ``'by_<a>_<b>'``. Its values are joined with
          ``'|'`` and NULL components are shown as ``'NI'``.
    '''
    strat_vars = list(dict.fromkeys(stratified_by))
    count_from = "count(distinct personid) as pat_cnt from " + df

    # overall count
    sql_str_lst = ["select 'total' as summ_var,'N' as summ_cat, " + count_from]

    # 1-way summary
    for var_str in strat_vars:
        sql_str_lst.append(
            "select '" + var_str + "' as summ_var,"
            + "cast(" + var_str + " as string) as summ_cat,"
            + count_from + " group by " + var_str
        )

    # n-way summary. Each component is cast and null-coalesced before concatenation
    # because `||` returns NULL if any operand is NULL. Otherwise distinct combinations
    # such as ('A', NULL) and ('B', NULL) would share one NULL label.
    for n_vars in range(2, n_way + 1):
        for var_comb in itertools.combinations(strat_vars, n_vars):
            cat_expr = " || '|' || ".join(
                "coalesce(cast(" + v + " as string),'NI')" for v in var_comb
            )
            sql_str_lst.append(
                "select 'by_" + "_".join(var_comb) + "' as summ_var,"
                + cat_expr + " as summ_cat,"
                + count_from + " group by " + ",".join(var_comb)
            )

    # Use UNION ALL because each branch already groups on its own key. Plain UNION adds
    # a global de-duplication pass and could merge two identical result rows.
    return " union all ".join(sql_str_lst)

# Columns of the cohort views to stratify patient counts by. Every name must be a
# column of preg_pt_init. The mortality flag is selected as `deceased` there, and
# neither cohort query defines a `deceased_ind` column, so that name made the
# summary query fail.
stratified_by = [
    'gender',
    'race',
    'ethnicity',
    'segment',
    'speciality',
    'bedsize',
    'deceased'
]
# Preview the generated summary SQL (1- and 2-way strata) for the initial cohort view.
get_pt_summ = pt_freq_qry(df='preg_pt_init', stratified_by=stratified_by, n_way=2)
get_pt_summ

In [None]:
# Run the stratified patient counts on the initial cohort and save the long-format table.
get_pt_summ = pt_freq_qry('preg_pt_init', stratified_by, n_way=2)
summ_stat_long = spark.sql(get_pt_summ).toPandas()
# DataFrame.to_csv does not create missing parent directories.
os.makedirs('./data', exist_ok=True)
summ_stat_long.to_csv('./data/summ_stat.csv', index=False)
summ_stat_long

In [None]:
# Restrict the initial cohort to events dated 2010-01-01 .. 2022-12-31 (diagnosis date,
# else admit date). Keep patients aged 12-55 (age at diagnosis, else at encounter), and
# add the calendar year of death as a stratification variable.
# NOTE(review): no gender restriction is applied here (the Sep-2022 query filtered
# gender = 'F'). Gender is reported as a stratum instead; confirm this is intended.
spark.sql("USE real_world_data_jun_2022")
cohort_12to55_sql = '''
    select a.*, extract(year from a.dod) as dod_calyr
    from preg_pt_init a
    where coalesce(a.dxdate,a.admitdate) between to_date('2010-01-01') and to_date('2022-12-31') and
          coalesce(a.ageatdx,a.ageatenc) between 12 and 55
'''
preg_pt_12to55 = spark.sql(cohort_12to55_sql).cache()
preg_pt_12to55.createOrReplaceTempView("preg_pt_12to55")
preg_pt_12to55.head()

In [None]:
# Stratify the 12-55 cohort by the same variables plus calendar year of death.
# Build a new list instead of appending to `stratified_by`. Appending made the cell
# non-idempotent: each re-run added another 'dod_calyr', duplicating query branches.
# It also leaked the column into earlier cells, where preg_pt_init has no dod_calyr.
stratified_by_12to55 = stratified_by + ['dod_calyr']
get_pt_summ = pt_freq_qry('preg_pt_12to55', stratified_by_12to55, n_way=2)
summ_stat_long = spark.sql(get_pt_summ).toPandas()
# DataFrame.to_csv does not create missing parent directories.
os.makedirs('./data', exist_ok=True)
summ_stat_long.to_csv('./data/summ_stat_12to55.csv', index=False)
summ_stat_long