In [None]:
# from pyspark.sql import SparkSession
# spark = SparkSession.builder()
# spark = SparkSession.newSession()

In [1]:
import pandas as pd
import json
import urllib.request as urlreq
pd.set_option('display.max_rows', None)

In [None]:
# Build the SQL WHERE clause that identifies pregnancy-related events by ICD category code
def dx_where_clause(icd_lst: list, alias=''):
    '''
    Build a Spark SQL predicate matching condition codes against ICD category codes.

    The stored code is upper-cased and truncated at its first '.', so each entry of
    ``icd_lst`` is compared at category level (e.g. a stored 'O09.81' matches 'O09').

    Parameters
    ----------
    icd_lst : list of str
        Upper-case ICD category codes (the part before any '.'), e.g. ['630', 'O09'].
    alias : str, optional
        Table alias WITHOUT a trailing dot (e.g. 'cond'); '' means no alias.

    Returns
    -------
    str
        "substring_index(upper(<alias>.conditioncode.standard.id),'.',1) in ('A','B',...)"

    Raises
    ------
    ValueError
        If ``icd_lst`` is empty -- an empty ``in ()`` list is not valid SQL.
    '''
    if not icd_lst:
        raise ValueError("icd_lst must contain at least one ICD code")
    # truthiness test instead of the old identity test `alias is not ''`, which depended
    # on string interning and raises a SyntaxWarning on Python 3.8+
    prefix = alias + '.' if alias else ''
    icd_quote_str = ",".join("'" + code + "'" for code in icd_lst)
    sql_str = "substring_index(upper(" + prefix + "conditioncode.standard.id),'.',1) in (" + icd_quote_str + ")"
    return sql_str

# ICD-9-CM 630-679 and ICD-10-CM O00-O99/O9A, plus pregnancy-related V/Z codes.
# range() excludes its stop value, so each stop is one past the last wanted category
# (679, O09 and O99 were previously left out).
preg_icd = [str(x) for x in range(630, 680)]
preg_icd.extend(['V22','V23','V28','Z33','Z34','Z35','Z36','Z37','Z38','Z3A','O9A'])
preg_icd.extend(['O0' + str(x) for x in range(0, 10)])
preg_icd.extend(['O' + str(x) for x in range(10, 100)])
preg_where_clause = dx_where_clause(preg_icd, "cond")
preg_where_clause

In [None]:
spark.sql("USE real_world_data_sep_2022")
# Initial cohort: condition rows of female patients matching preg_where_clause, joined to
# their encounter, demographics and (optionally) tenant/facility attributes.
# Ages are completed years (floor), not rounded: rounding turned e.g. an 11.6-year-old
# into 12 (passing the later 12-55 filter) and a 19.6-year-old into age group 20-25.
# rn numbers each patient's rows by encounter arrival date, then condition date.
preg_pt_init = spark.sql('''
    select distinct
           cond.personid,
           to_date(demo.birthdate) as dob,
           to_date(demo.dateofdeath) as dod,
           coalesce(demo.deceased,false) as deceased,
           demo.gender.standard.id as gender,
           coalesce(demo.races[0].standard.primaryDisplay,'NI') as race,
           coalesce(demo.ethnicities[0].standard.primaryDisplay,'NI') as ethnicity,
           cond.encounterid,
           cond.conditioncode,
           to_date(cond.effectivedate) as dxdate,
           cast(floor(datediff(to_date(cond.effectivedate),to_date(demo.birthdate))/365.25) as int) as ageatdx,
           coalesce(enc.classification.standard.id,'NI') as enctype,
           to_date(enc.actualarrivaldate) as admitdate,
           date_add(to_date(enc.actualarrivaldate),365) as admitdate_add1yr,
           row_number() over (partition by cond.personid order by enc.actualarrivaldate,cond.effectivedate) as rn,
           cast(floor(datediff(to_date(enc.actualarrivaldate),to_date(demo.birthdate))/365.25) as int) as ageatenc,
           enc.readmission,
           to_date(enc.dischargedate) as dischargedate,
           tnt.zip_code as zipcode,
           tnt.bed_size as bedsize,
           tnt.speciality,
           tnt.segment
    from condition cond
    join encounter enc on cond.encounterid = enc.encounterid
    join demographics demo on cond.personid = demo.personid and demo.gender.standard.id = 'F'
    left join tenant_attributes tnt on enc.tenant = tnt.tenant 
    where (''' + preg_where_clause + ''')'''
).cache()
preg_pt_init.createOrReplaceTempView("preg_pt_init")
preg_pt_init.head()

In [None]:
spark.sql("USE real_world_data_sep_2022")
# Keep events dated 2010-2022 in patients aged 12-55 and bucket age into groups.
# The age filter falls back to the age at encounter when the diagnosis-based age is
# missing; age_eff applies that same fallback to the grouping, so rows admitted through
# the fallback no longer get a NULL ageatdxgrp.
preg_pt_12to55 = spark.sql('''
    select a.personid,
           a.deceased,
           a.race,
           a.ethnicity,
           case when a.age_eff >= 12 and a.age_eff < 20 then 'agegrp1'
                when a.age_eff >= 20 and a.age_eff < 25 then 'agegrp2'
                when a.age_eff >= 25 and a.age_eff < 30 then 'agegrp3'
                when a.age_eff >= 30 and a.age_eff < 35 then 'agegrp4'
                when a.age_eff >= 35 and a.age_eff < 40 then 'agegrp5'
                when a.age_eff >= 40 and a.age_eff < 45 then 'agegrp6'
                when a.age_eff >= 45 and a.age_eff < 50 then 'agegrp7'
                when a.age_eff >= 50 and a.age_eff <= 55 then 'agegrp8'
           end as ageatdxgrp,
           extract(year from a.dod) as dodcalyr,
           a.encounterid,
           coalesce(a.zipcode,'NI') as zipcode,
           coalesce(a.bedsize,'NI') as bedsize,
           coalesce(a.speciality,'NI') as speciality,
           coalesce(a.segment,'NI') as segment,
           a.conditioncode,
           a.enctype,
           a.admitdate,
           a.admitdate_add1yr,
           a.rn
    from (select *, coalesce(ageatdx, ageatenc) as age_eff from preg_pt_init) a
    where coalesce(a.dxdate,a.admitdate) between to_date('2010-01-01') and to_date('2022-12-31') and
          a.age_eff between 12 and 55
''').cache()
preg_pt_12to55.createOrReplaceTempView("preg_pt_12to55")
preg_pt_12to55.head()

In [None]:
import itertools
def pt_freq_qry(df, stratified_by, n_way=1, missing_label='NI'):
    '''
    Build a Spark SQL query counting distinct patients overall and within every
    combination of up to ``n_way`` stratification variables.

    Parameters
    ----------
    df : str
        Table / temp-view name to summarise; must have a ``personid`` column.
    stratified_by : list of str
        Column names to stratify by.
    n_way : int, default 1
        Largest number of variables crossed together (one-way counts are always produced).
    missing_label : str, default 'NI'
        Label written to ``summ_cat`` for NULL values (the 'NI' placeholder used
        elsewhere in this notebook).

    Returns
    -------
    str
        A ``union all`` of SELECTs each yielding (summ_var, summ_cat, pat_cnt); for
        crossed variables summ_var joins the names with '_' and summ_cat joins the
        values with '|'.
    '''
    def _cat_expr(var):
        # label NULLs explicitly: with `||` a single NULL made the whole combined
        # category NULL, so distinct strata could not be told apart
        return "coalesce(cast(" + var + " as string),'" + missing_label + "')"

    def _summ_select(var_comb):
        # one SELECT counting distinct patients per value combination of var_comb
        return (
            "select '" + "_".join(var_comb) + "' as summ_var,"
            + "concat_ws('|'," + ",".join(_cat_expr(v) for v in var_comb) + ") as summ_cat,"
            + "count(distinct personid) as pat_cnt "
            + "from " + df + " group by " + ",".join(var_comb)
        )

    # overall count
    sql_str_lst = ["select 'total' as summ_var,'N' as summ_cat, count(distinct personid) as pat_cnt from " + df]
    # 1-way, then 2-way ... n-way summaries
    for k in range(1, max(n_way, 1) + 1):
        for var_comb in itertools.combinations(stratified_by, k):
            sql_str_lst.append(_summ_select(var_comb))

    # `union all`: summ_var already differs between SELECTs and each SELECT is grouped,
    # so plain `union` only added a costly dedup (that could merge colliding rows)
    return " union all ".join(sql_str_lst)

# Preview the generated 3-way summary SQL for the 12-55 cohort ('dodcalyr' left out)
stratified_by = ['race', 'ethnicity', 'segment', 'speciality',
                 'bedsize', 'zipcode', 'deceased']
get_pt_summ = pt_freq_qry('preg_pt_12to55', stratified_by, n_way=3)
get_pt_summ

In [None]:
# Pairwise patient counts for the 12-55 cohort (no age group, no 'dodcalyr'), saved to CSV
stratified_by = ['race', 'ethnicity', 'segment', 'speciality',
                 'bedsize', 'zipcode', 'deceased']
get_pt_summ = pt_freq_qry('preg_pt_12to55', stratified_by, n_way=2)
summ_stat_long = spark.sql(get_pt_summ).toPandas()
summ_stat_long.to_csv('./data/summ_stat_12to55.csv', index=False)
summ_stat_long

In [None]:
# Pairwise patient counts incl. age group and death year, saved to CSV
stratified_by = ['race', 'ethnicity', 'ageatdxgrp', 'segment', 'speciality',
                 'bedsize', 'zipcode', 'dodcalyr', 'deceased']
get_pt_summ = pt_freq_qry('preg_pt_12to55', stratified_by, n_way=2)
summ_stat_long = spark.sql(get_pt_summ).toPandas()
summ_stat_long.to_csv('./data/summ_stat_12to55_2way.csv', index=False)
summ_stat_long

In [None]:
# Up-to-3-way counts by demographics, segment and death year (speciality/bedsize/zipcode excluded)
stratified_by = ['race', 'ethnicity', 'ageatdxgrp', 'segment', 'dodcalyr', 'deceased']
get_pt_summ = pt_freq_qry('preg_pt_12to55', stratified_by, n_way=3)
summ_stat_long = spark.sql(get_pt_summ).toPandas()
summ_stat_long.to_csv('./data/summ_stat_12to55_segment.csv', index=False)
summ_stat_long

In [None]:
# Up-to-3-way counts by demographics, segment and zipcode
stratified_by = ['race', 'ethnicity', 'ageatdxgrp', 'segment', 'zipcode', 'deceased']
get_pt_summ = pt_freq_qry('preg_pt_12to55', stratified_by, n_way=3)
summ_stat_long = spark.sql(get_pt_summ).toPandas()
summ_stat_long.to_csv('./data/summ_stat_12to55_segment_zip.csv', index=False)
summ_stat_long

In [None]:
# Up-to-3-way counts by demographics, segment and bed size
stratified_by = ['race', 'ethnicity', 'ageatdxgrp', 'segment', 'bedsize', 'deceased']
get_pt_summ = pt_freq_qry('preg_pt_12to55', stratified_by, n_way=3)
summ_stat_long = spark.sql(get_pt_summ).toPandas()
summ_stat_long.to_csv('./data/summ_stat_12to55_segment_bedsize.csv', index=False)
summ_stat_long

In [None]:
_phemap_cache = {}

def _load_phemap(src_json_url):
    '''
    Download and parse the value-set JSON at ``src_json_url``.

    Memoised per URL in ``_phemap_cache`` so repeated ``sqlgen_phe_spark`` calls fetch
    the file once; re-running this cell resets the cache.
    '''
    if src_json_url not in _phemap_cache:
        with urlreq.urlopen(src_json_url) as url:
            _phemap_cache[src_json_url] = json.loads(url.read().decode())
    return _phemap_cache[src_json_url]


def _expand_code_ranges(rg_lst):
    '''
    Expand 'lo-hi' numeric range strings into every code lo..hi (inclusive), zero-padded
    to the width of lo, e.g. ['001-003'] -> ['001', '002', '003'].
    '''
    codes = []
    for rg in rg_lst:
        rg_pos = rg.split('-')
        lo, hi = rg_pos[0].strip(), rg_pos[1].strip()
        codes.extend(str(i).zfill(len(lo)) for i in range(int(lo), int(hi) + 1))
    return codes


def sqlgen_phe_spark(
    src_json_url,      # url of the phenotype value-set JSON, e.g. https://raw.githubusercontent.com/RWD2E/phecdm/main/res/valueset_curated/vs-mmm.json
    which_phenotype,   # key of the phenotype (code set) to use; a single string
    cd_field_map,      # code type -> column holding the code
    cdtype_field_map,  # code type -> column holding the coding-system id
    cdtype_value_map,  # code type -> list of accepted coding-system ids
    alias=''           # table alias prefix INCLUDING the trailing dot (e.g. 'a.'), or ''
):
    '''
    Build a Spark SQL predicate selecting records whose codes belong to one phenotype.

    The JSON is read as ``{phenotype: {code_type: {precision: [codes, ...]}}}``. Each
    non-empty (code_type, precision) entry becomes one group
    ``<coding-system col> in (...) AND <code match>``; groups are wrapped in parentheses
    and OR-ed together. Code match by precision:

    - 'lev0':  code truncated at its first '.' is in the list;
    - 'lev1':  the first ``len(codes[0]) + 1`` characters are in the list;
    - 'range': every 'lo-hi' entry is expanded to lo..hi (inclusive, zero-padded to the
      width of lo), then matched exactly;
    - other:   exact upper-cased match.

    The result is an OR of groups: wrap it in parentheses before AND-ing it with
    other conditions.

    Raises
    ------
    ValueError
        If the phenotype yields no codes (an empty predicate is not valid SQL).
    '''
    phemap = _load_phemap(src_json_url)

    sql_str_lst = []
    for cdt, prec_map in phemap[which_phenotype].items():
        for prec, cdlst in prec_map.items():
            if not cdlst:
                continue
            # restrict to the coding system(s) of this code type
            cdtype_quote_str = ",".join("'" + cs + "'" for cs in cdtype_value_map[cdt])
            sql_str = alias + cdtype_field_map[cdt] + " in (" + cdtype_quote_str + ") AND "
            code_col = alias + cd_field_map[cdt]
            if prec == "range":
                # every range contributes its codes (the old loop kept only the last range)
                cdlst = _expand_code_ranges(cdlst)
            cd_quote_str = ",".join("'" + str(code) + "'" for code in cdlst)
            if prec == 'lev0':
                sql_str += "substring_index(upper(" + code_col + "),'.',1) in (" + cd_quote_str + ")"
            elif prec == 'lev1':
                # NOTE(review): prefix length is taken from the first code only (+1, presumably
                # for the '.' separator) -- assumes all lev1 codes share that length
                str_len = len(cdlst[0])
                sql_str += "substring(upper(" + code_col + "),1," + str(str_len + 1) + ") in (" + cd_quote_str + ")"
            else:  # lev2, range, exact
                sql_str += "upper(" + code_col + ") in (" + cd_quote_str + ")"
            sql_str_lst.append(sql_str)

    if not sql_str_lst:
        raise ValueError("no codes found for phenotype '" + which_phenotype + "'")
    return "(" + ") OR (".join(sql_str_lst) + ")"

# Per code type: the column holding the code and the column holding its coding-system id
cd_field_map = {}
cdtype_field_map = {}
for _cdt, _domain in [
    ('icd9-cm', 'conditioncode'),
    ('icd10-cm', 'conditioncode'),
    ('icd9-pr', 'procedurecode'),
    ('icd10-pcs', 'procedurecode'),
    ('hcpcs', 'procedurecode'),
    ('drg', 'conditioncode'),
]:
    cd_field_map[_cdt] = _domain + '.standard.id'
    cdtype_field_map[_cdt] = _domain + '.standard.codingSystemId'
del _cdt, _domain

# Accepted coding-system identifiers (OIDs / Cerner coding-system URNs) per code type
cdtype_value_map = {
    'icd9-cm': ['2.16.840.1.113883.6.103'],
    'icd10-cm': ['2.16.840.1.113883.6.90'],
    'icd9-pr': ['2.16.840.1.113883.6.104'],
    'icd10-pcs': ['2.16.840.1.113883.6.4'],
    'hcpcs': ['2.16.840.1.113883.6.14'],
    'drg': ['urn:cerner:codingsystem:drg:' + kind for kind in ('aprdrg', 'apdrg', 'irdrg', 'msdrg')],
}

In [None]:
# Predicate for vaginal-delivery codes on the table aliased 'a'
where_delivery_vaginal = sqlgen_phe_spark(
    'https://raw.githubusercontent.com/RWD2E/phecdm/main/res/valueset_curated/vs-mmm.json',
    'delivery-vaginal',
    cd_field_map,
    cdtype_field_map,
    cdtype_value_map,
    alias='a.',
)
where_delivery_vaginal

In [None]:
# Predicate for caesarean-section delivery codes on the table aliased 'a'
where_delivery_csection = sqlgen_phe_spark(
    'https://raw.githubusercontent.com/RWD2E/phecdm/main/res/valueset_curated/vs-mmm.json',
    'delivery-csection',
    cd_field_map,
    cdtype_field_map,
    cdtype_value_map,
    alias='a.',
)
where_delivery_csection

In [None]:
# Predicate for miscarriage codes on the table aliased 'a'
# (the misspelled phenotype key must match the key in vs-mmm.json)
where_delivery_mischarriage = sqlgen_phe_spark(
    'https://raw.githubusercontent.com/RWD2E/phecdm/main/res/valueset_curated/vs-mmm.json',
    'delivery-mischarriage',
    cd_field_map,
    cdtype_field_map,
    cdtype_value_map,
    alias='a.',
)
where_delivery_mischarriage

In [None]:
# Predicate for ectopic-pregnancy codes on the table aliased 'a'
where_delivery_ectopic = sqlgen_phe_spark(
    'https://raw.githubusercontent.com/RWD2E/phecdm/main/res/valueset_curated/vs-mmm.json',
    'delivery-ectopic',
    cd_field_map,
    cdtype_field_map,
    cdtype_value_map,
    alias='a.',
)
where_delivery_ectopic

In [None]:
# Predicate for molar-pregnancy codes on the table aliased 'a'
where_delivery_molar = sqlgen_phe_spark(
    'https://raw.githubusercontent.com/RWD2E/phecdm/main/res/valueset_curated/vs-mmm.json',
    'delivery-molar',
    cd_field_map,
    cdtype_field_map,
    cdtype_value_map,
    alias='a.',
)
where_delivery_molar

In [None]:
spark.sql("USE real_world_data_sep_2022")
# Inpatient ('I') rows of each pregnancy-outcome type, stacked and de-duplicated.
# Each where_delivery_* predicate is an OR of code groups, so it is parenthesised before
# `and a.enctype = 'I'`; without the parentheses AND binds tighter than OR and the
# inpatient restriction applied only to the last code group.
# NOTE(review): preg_pt_12to55 only carries conditioncode; if a delivery phenotype also
# lists procedure code types, its predicate references a.procedurecode and will not
# resolve -- confirm against vs-mmm.json.
delivery_where = [
    ('vaginal', where_delivery_vaginal),
    ('csection', where_delivery_csection),
    ('miscarriage', where_delivery_mischarriage),
    ('ectopic', where_delivery_ectopic),
    ('molar', where_delivery_molar),
]
delivery_stk_sql = "\n        union\n".join(
    "        select '" + delivery_type + "' AS delivery_type, a.personid, a.encounterid, a.dodcalyr, a.deceased, a.race, a.ethnicity, a.ageatdxgrp,\n"
    "            a.admitdate, a.admitdate_add1yr, a.rn, a.zipcode, a.bedsize, a.speciality, a.segment\n"
    "            from preg_pt_12to55 a where (" + where_clause + ") and a.enctype = 'I'"
    for delivery_type, where_clause in delivery_where
)
preg_pt_init_delivery = spark.sql('''
    with delivery_stk AS (
''' + delivery_stk_sql + '''
    )
    select distinct
           personid,
           delivery_type,
           dodcalyr,
           deceased,
           race,
           ethnicity,
           ageatdxgrp,
           encounterid,
           admitdate,
           zipcode,
           bedsize,
           speciality,
           segment
    from delivery_stk
''').cache()
preg_pt_init_delivery.createOrReplaceTempView("preg_pt_12to55_delivery")
preg_pt_init_delivery.head()

In [None]:
# Up-to-3-way counts of delivery patients by outcome type, demographics and segment
stratified_by = ['delivery_type', 'race', 'ethnicity', 'ageatdxgrp', 'segment', 'deceased']
get_pt_summ = pt_freq_qry('preg_pt_12to55_delivery', stratified_by, n_way=3)
summ_stat_long = spark.sql(get_pt_summ).toPandas()
summ_stat_long.to_csv('./data/summ_stat_12to55_delivery_segment.csv', index=False)
summ_stat_long

In [None]:
# Severe maternal morbidity (SMM) phenotype keys in vs-mmm.json -> short names.
# Indicators identified by diagnosis codes:
smm_dict_cm = dict([
    ('acute-myocardial-infarction', 'ami'),
    ('aneurysm', 'ane'),
    ('acute-renal-failure', 'arf'),
    ('adult-respiratory-distress-syndrome', 'ards'),
    ('amniotic-fluid-embolism', 'afe'),
    ('cardiac-arrest-ventricular-fibrillation', 'cavf'),
    ('disseminated-intravascular-coagulation', 'dic'),
    ('eclampsia', 'ecl'),
    ('heart-failure-arrest-during-surgery', 'hf'),
    ('puerperal-cerebrovascular-disorders', 'pcd'),
    ('pulmonary-edema-acute-heart-failure', 'pe'),
    ('severe-anesthesia-complications', 'sac'),
    ('sepsis', 'sep'),
    ('shock', 'sho'),
    ('air-and-thrombotic-embolism', 'ate'),
    ('sickle-cell-disease', 'scd'),
])
# Indicators identified by procedure codes:
smm_dict_pr = dict([
    ('conversion-cardiac-rhythm', 'ccr'),
    ('blood-products-transfusion', 'bpt'),
    ('hysterectomy', 'hyst'),
    ('temporary-tracheostomy', 'tt'),
    ('ventilation', 'vent'),
])

In [None]:
spark.sql("USE real_world_data_sep_2022")
# Flag severe maternal morbidity (SMM) on each delivery row: any SMM diagnosis (condition
# table) or procedure (procedure table) dated from the admit date through 42 days later.
# Yields one row per (delivery row, smm_grp).
VS_MMM_URL = 'https://raw.githubusercontent.com/RWD2E/phecdm/main/res/valueset_curated/vs-mmm.json'
# where_smm is an OR of code groups, so it is parenthesised: previously the 42-day window
# was AND-ed onto the first group only, and later groups matched events at any time.
# The event dates appear to be timestamps (the file wraps them in to_date elsewhere);
# to_date() compares whole days so events later on day 42 are kept.
SMM_SQL_TEMPLATE = '''
        select distinct a.*, '{smm}' AS smm_grp
        from preg_pt_12to55_delivery a
        join {tbl} {tbl_alias} on a.personid = {tbl_alias}.personid
        join encounter enc on {tbl_alias}.encounterid = enc.encounterid
        where to_date(coalesce({tbl_alias}.{date_col},enc.actualarrivaldate)) between a.admitdate and date_add(a.admitdate,42)
          and ({where_smm})
    '''
smm_sources = [
    # (phenotype -> short name, table, table alias, event-date column)
    (smm_dict_cm, 'condition', 'cond', 'effectivedate'),   # defined by diagnosis codes
    (smm_dict_pr, 'procedure', 'pr', 'servicestartdate'),  # defined by procedure codes
]
sql_union_statmts = []
for smm_dict, tbl, tbl_alias, date_col in smm_sources:
    for smm in smm_dict:
        where_smm = sqlgen_phe_spark(
            src_json_url = VS_MMM_URL,
            which_phenotype = smm,
            cd_field_map = cd_field_map,
            cdtype_field_map = cdtype_field_map,
            cdtype_value_map = cdtype_value_map,
            alias = tbl_alias + '.'
        )
        sql_union_statmts.append(SMM_SQL_TEMPLATE.format(
            smm=smm, tbl=tbl, tbl_alias=tbl_alias, date_col=date_col, where_smm=where_smm
        ))

sql_union_statmt = ' union '.join(sql_union_statmts)
preg_pt_init_delivery_smm = spark.sql(sql_union_statmt).cache()
preg_pt_init_delivery_smm.createOrReplaceTempView("preg_pt_init_delivery_smm")
preg_pt_init_delivery_smm.head()

In [None]:
# Up-to-3-way counts of SMM patients by demographics, segment, death year and SMM group
stratified_by = ['race', 'ethnicity', 'ageatdxgrp', 'segment', 'dodcalyr', 'smm_grp']
get_pt_summ = pt_freq_qry('preg_pt_init_delivery_smm', stratified_by, n_way=3)
summ_stat_long = spark.sql(get_pt_summ).toPandas()
summ_stat_long.to_csv('./data/summ_stat_12to55_delivery_smm_seg.csv', index=False)
summ_stat_long

In [None]:
# Up-to-3-way counts of SMM patients by demographics, segment, zipcode and SMM group
stratified_by = ['race', 'ethnicity', 'ageatdxgrp', 'segment', 'zipcode', 'smm_grp']
get_pt_summ = pt_freq_qry('preg_pt_init_delivery_smm', stratified_by, n_way=3)
summ_stat_long = spark.sql(get_pt_summ).toPandas()
summ_stat_long.to_csv('./data/summ_stat_12to55_delivery_smm_seg_zip.csv', index=False)
summ_stat_long