In [None]:
import pandas as pd
# Show every row when a DataFrame is displayed (no truncation of long summary tables).
pd.options.display.max_rows = None

In [None]:
# where clause for identifying pregnancy events using ICD codes
def dx_where_clause(icd_lst:list,alias=''):
    '''
    sql where clause for cohort selection based on ICD code list

    Builds a predicate that compares the part of the upper-cased
    `conditioncode.standard.id` before its first '.' against the given codes.

    Parameters
    ----------
    icd_lst : list
        ICD category codes (the part before the decimal point), e.g. ['630','V22'].
        Each item is converted with str(), so integer codes are accepted as well.
        Codes are inserted verbatim between single quotes; they are expected to be
        trusted, hard-coded values (no escaping is performed).
    alias : str, optional
        Table alias to prefix to the column reference. '' (default) or None
        means the column is referenced without a prefix.

    Returns
    -------
    str
        A SQL boolean expression of the form
        "substring_index(upper(<alias.>conditioncode.standard.id),'.',1) in ('c1','c2',...)".
    '''
    # The previous `alias is not ''` compared object identity with a literal
    # (SyntaxWarning on Python 3.8+); truthiness is the intended
    # "was an alias supplied?" check and also tolerates None.
    prefix = alias + '.' if alias else ''
    icd_quote_str = ",".join("'" + str(code) + "'" for code in icd_lst)
    return (
        "substring_index(upper(" + prefix + "conditioncode.standard.id),'.',1) in ("
        + icd_quote_str + ")"
    )

# ICD category codes (the part before the decimal point) used to flag pregnancy events.
# range() excludes its stop value, so every stop below is one past the last code wanted;
# the earlier stops (679, 9, 99) silently dropped 679, O09 and O99.
preg_icd = [str(x) for x in range(630,680)]              # ICD-9 630-679
preg_icd.extend(['V22','V23','V28','Z33','Z34','Z35','Z36','Z37','Z38','Z3A','O9A'])
preg_icd.extend(['O0' + str(x) for x in range(0,10)])    # ICD-10 O00-O09
preg_icd.extend(['O' + str(x) for x in range(10,100)])   # ICD-10 O10-O99
# Predicate on the `cond` alias, reused by the cohort query.
preg_where_clause = dx_where_clause(preg_icd,"cond")
preg_where_clause

In [None]:
# Select the database that the unqualified table names below resolve against.
# Assumes a SparkSession named `spark` is provided by the notebook environment.
spark.sql("USE real_world_data_jun_2022")
# Initial pregnancy cohort: distinct rows of condition records whose code matches
# `preg_where_clause`, enriched with demographics, encounter and tenant attributes.
#  - encounter and demographics are inner-joined, so condition rows without a match are dropped;
#  - tenant_attributes is left-joined, so bed_size/speciality/segment may be null;
#  - only the first entry of `races` / `ethnicities` is kept.
# NOTE(review): age_at_dx / age_at_enc use round(days/365.25), i.e. age to the nearest
# year rather than completed years — confirm this is intended given the later 12-55 filter.
preg_pt_init = spark.sql('''
    select distinct
           cond.personid,
           to_date(demo.birthdate) as dob,
           to_date(demo.dateofdeath) as dod,
           demo.deceased as deceased_ind,
           demo.gender.standard.id as gender,
           demo.races[0].standard.primaryDisplay as race,
           demo.ethnicities[0].standard.primaryDisplay as ethnicity,
           cond.encounterid,
           to_date(cond.effectivedate) as dx_date,
           cast(round(datediff(to_date(cond.effectivedate),to_date(demo.birthdate))/365.25) as int) as age_at_dx,
           enc.classification.standard.id as enc_type,
           to_date(enc.actualarrivaldate) as admit_date,
           cast(round(datediff(to_date(enc.actualarrivaldate),to_date(demo.birthdate))/365.25) as int) as age_at_enc,
           enc.readmission,
           to_date(enc.dischargedate) as discharge_date,
           enc.tenant as enc_tenant,
           tnt.bed_size,
           tnt.speciality,
           tnt.segment
    from condition cond
    join encounter enc on cond.encounterid = enc.encounterid
    join demographics demo on cond.personid = demo.personid
    left join tenant_attributes tnt on enc.tenant = tnt.tenant 
    where ''' + preg_where_clause
).cache()
# Register the cohort as a temp view so later cells can query it by name.
preg_pt_init.createOrReplaceTempView("preg_pt_init")
# Preview the first row.
preg_pt_init.head()

In [None]:
import itertools
def pt_freq_qry(df,stratified_by,n_way=1,where_clause=''):
    '''
    generate total patient counts for each stratified variables

    Builds (but does not run) one SQL statement that unions:
      * an overall distinct-patient count (summ_var='total', summ_cat='N'),
      * one grouped count per variable in `stratified_by`,
      * when n_way > 1, one grouped count per combination of 2..n_way variables.

    Parameters
    ----------
    df : str
        Name of the table/view to query; it must expose a `personid` column.
    stratified_by : list of str
        Column names to stratify by.
    n_way : int, optional
        Largest number of variables to cross; 1 (default) gives 1-way summaries only.
    where_clause : str, optional
        Row filter applied to every sub-query, given with or without the leading
        `where` keyword. Empty (default) means no filter.

    Returns
    -------
    str
        The SQL text, with the sub-queries joined by " union ".
    '''
    # `where_clause` used to be read from an undefined global, which raised
    # NameError on a fresh kernel; it is now an explicit optional argument.
    filter_sql = (where_clause or '').strip()
    if filter_sql:
        if filter_sql.split(None, 1)[0].lower() != 'where':
            filter_sql = 'where ' + filter_sql
        filter_sql = ' ' + filter_sql

    # materialise once so a one-shot iterable can be combined repeatedly
    stratified_by = list(stratified_by)
    count_expr = "count(distinct personid) as pat_cnt "
    source = "from " + df + filter_sql

    # overall count
    sql_str_lst = [
        "select 'total' as summ_var,'N' as summ_cat, " + count_expr + source
    ]

    # 1-way summaries first, then 2-way .. n-way; combinations of size 1 are
    # the single variables in their original order.
    # NOTE(review): for n-way rows `||` joins the values without a separator and
    # presumably yields null when any component is null — confirm whether a
    # delimiter / coalesce is wanted.
    for size in range(1, max(n_way, 1) + 1):
        for var_comb in itertools.combinations(stratified_by, size):
            sql_str_lst.append(
                "select 'by_" + "_".join(var_comb) + "' as summ_var,"
                + "cast(" + "||".join(var_comb) + " as string) as summ_cat,"
                + count_expr
                + source + " group by " + ",".join(var_comb)
            )

    # union everything
    return " union ".join(sql_str_lst)

# Columns of the cohort views used as stratifiers in the summary tables.
stratified_by = [
    "gender", "race", "ethnicity",
    "segment", "speciality", "bed_size",
    "deceased_ind",
]
# Preview the generated SQL including all pairwise (2-way) breakdowns.
get_pt_summ = pt_freq_qry(df="preg_pt_init", stratified_by=stratified_by, n_way=2)
get_pt_summ

In [None]:
# Patient counts (overall and per single stratifier) for the initial cohort.
get_pt_summ = pt_freq_qry(df="preg_pt_init", stratified_by=stratified_by, n_way=1)
# Run the summary in Spark, bring it into pandas and keep a CSV copy.
summ_stat_long = spark.sql(get_pt_summ).toPandas()
summ_stat_long.to_csv(path_or_buf="./data/summ_stat.csv")
summ_stat_long

In [None]:
# Re-select the database (same one the initial cohort query uses).
spark.sql("USE real_world_data_jun_2022")
# Restrict the initial cohort to events dated 2010-01-01..2022-12-31 and ages 12-55
# (both bounds inclusive). The diagnosis date/age is used when present; otherwise
# the encounter arrival date/age is used as a fallback.
preg_pt_12to55 = spark.sql('''
    select * from preg_pt_init
    where coalesce(dx_date,admit_date) between to_date('2010-01-01') and to_date('2022-12-31') and
          coalesce(age_at_dx,age_at_enc) between 12 and 55
''').cache()
# Register the filtered cohort as a temp view so later cells can query it by name.
preg_pt_12to55.createOrReplaceTempView("preg_pt_12to55")
# Preview the first row.
preg_pt_12to55.head()

In [None]:
# Patient counts (overall and per single stratifier) for the age 12-55 cohort.
get_pt_summ = pt_freq_qry(df="preg_pt_12to55", stratified_by=stratified_by, n_way=1)
# Run the summary in Spark, bring it into pandas and keep a CSV copy.
summ_stat_long = spark.sql(get_pt_summ).toPandas()
summ_stat_long.to_csv(path_or_buf="./data/summ_stat_12to55.csv")
summ_stat_long