In [0]:
from pyspark.sql import functions as F
from delta.tables import DeltaTable

In [0]:
%run /Workspace/Users/evansavo@gmail.com/Population_Health_&_Readmission_Risk/1.setup/utilities

In [0]:
# Declare the notebook parameters as text widgets: (name, default, label).
for widget_name, widget_default, widget_label in (
    ('catalog', 'phr', 'Catalog'),
    ('data_source', '', 'Data Source'),
):
    dbutils.widgets.text(widget_name, widget_default, widget_label)

In [0]:
# Resolve notebook parameters. data_source names the table inside the CMS SynPUF
# external schema and is also reused to name the bronze/silver output tables.
catalog = dbutils.widgets.get('catalog').strip()
data_source = dbutils.widgets.get('data_source').strip()

# The data_source widget defaults to '' which would produce a table path ending in
# '.', failing only later at spark.read.table; fail fast with a clear message instead.
if not catalog:
    raise ValueError("Widget 'catalog' is empty; set it to the target catalog name.")
if not data_source:
    raise ValueError(
        "Widget 'data_source' is empty; set it to a table in "
        "databricks_cms_synthetic_public_use_files_synpuf.cms_synpuf_ext."
    )

base_path = f'databricks_cms_synthetic_public_use_files_synpuf.cms_synpuf_ext.{data_source}'

print(base_path)

# Bronze layer

In [0]:
# Read the raw source table and stamp each row with the time it was read.
df = (
    spark.read
    .table(base_path)
    .withColumn('read_timestamp', F.current_timestamp())
)

# Quick preview of the ingested data
display(df.limit(2))

In [0]:
# Number of rows read from the source table (shown as the cell output).
raw_row_count = df.count()
raw_row_count

In [0]:
# Persist the raw snapshot to the bronze schema as a Delta table, replacing any
# previous contents and requesting the change data feed table property.
bronze_writer = (
    df.write
    .mode('overwrite')
    .format('delta')
    .option('delta.enableChangeDataFeed', 'true')
)
bronze_writer.saveAsTable(f'{catalog}.{bronze_schema}.{data_source}')

# Silver layer

In [0]:
# Start the silver stage from the bronze table just written.
silver_df = spark.read.table(f'{catalog}.{bronze_schema}.{data_source}')

In [0]:
# Profile missing values: null count and percentage of rows for every column.
# All per-column null counts come from ONE aggregation (a single Spark job)
# instead of a separate filter+count scan per column.
total_rows = silver_df.count()

null_counts = silver_df.select([
    # count() skips nulls, so counting when(isNull, 1) counts the null rows.
    F.count(F.when(F.col(c).isNull(), 1)).alias(c)
    for c in silver_df.columns
]).first().asDict()

for col_name in silver_df.columns:
    null_count = null_counts[col_name]
    # Guard against an empty table, which would otherwise raise ZeroDivisionError.
    null_ratio = null_count * 100 / total_rows if total_rows else 0.0
    print(f"Nulls in {col_name}: {null_count}, {null_ratio}")

In [0]:
# PHYSICIANS: replace the operating / other physician NPI columns with 0/1 flags
# that record only whether an NPI was present on the claim.
operated_on_flag = F.col("OP_PHYSN_NPI").isNotNull().cast("int")
other_physn_flag = F.col("OT_PHYSN_NPI").isNotNull().cast("int")

silver_df = (
    silver_df
    .withColumn("OPERATED_ON", operated_on_flag)
    .withColumn("SEEN_OTHER_PHYSN", other_physn_flag)
    .drop("OP_PHYSN_NPI", "OT_PHYSN_NPI")
)

In [0]:
# Remove the HCPCS_CD_1 .. HCPCS_CD_45 columns from the silver frame.
hcpcs_drop = list(map('HCPCS_CD_{}'.format, range(1, 46)))
silver_df = silver_df.drop(*hcpcs_drop)

In [0]:
def collapse_columns_to_array(df, target_col, source_cols):
    """Pack ``source_cols`` into a single array column ``target_col``.

    Null elements are removed from the array and the source columns are
    dropped, so one array column replaces the numbered columns.

    Args:
        df: Input DataFrame containing every column in ``source_cols``.
        target_col: Name of the array column to create.
        source_cols: Ordered column names to pack; element order is preserved.

    Returns:
        A new DataFrame with ``target_col`` added and ``source_cols`` removed.
    """
    return (
        df
        .withColumn(target_col, F.array(*source_cols))
        # SQL higher-order filter keeps only the non-null array elements.
        .withColumn(target_col, F.expr(f"filter(`{target_col}`, x -> x is not null)"))
        .drop(*source_cols)
    )


# Diagnosis codes ICD9_DGNS_CD_1..10 -> array column ICD9_DGNS_CD
diag_cols = [f"ICD9_DGNS_CD_{i}" for i in range(1, 11)]
silver_df = collapse_columns_to_array(silver_df, "ICD9_DGNS_CD", diag_cols)

# Procedure codes ICD9_PRCDR_CD_1..6 -> array column ICD9_PRCDR_CD
diag_cols_prd = [f"ICD9_PRCDR_CD_{i}" for i in range(1, 7)]
silver_df = collapse_columns_to_array(silver_df, "ICD9_PRCDR_CD", diag_cols_prd)


In [0]:
# Normalise column types: yyyyMMdd-formatted dates -> date, payment amounts ->
# double, utilization day count -> int.
date_columns = ['CLM_FROM_DT', 'CLM_THRU_DT', 'CLM_ADMSN_DT', 'NCH_BENE_DSCHRG_DT']
amount_columns = [
    'CLM_PMT_AMT',
    'NCH_PRMRY_PYR_CLM_PD_AMT',
    'CLM_PASS_THRU_PER_DIEM_AMT',
    'NCH_BENE_IP_DDCTBL_AMT',
    'NCH_BENE_PTA_COINSRNC_LBLTY_AM',
    'NCH_BENE_BLOOD_DDCTBL_LBLTY_AM',
]

for date_col in date_columns:
    silver_df = silver_df.withColumn(date_col, F.to_date(F.col(date_col), 'yyyyMMdd'))

for amount_col in amount_columns:
    silver_df = silver_df.withColumn(amount_col, F.col(amount_col).cast('double'))

silver_df = silver_df.withColumn(
    'CLM_UTLZTN_DAY_CNT', F.col('CLM_UTLZTN_DAY_CNT').cast('int')
)

In [0]:
# Day spans: claim from->thru date, and admission->discharge date.
claim_days = F.datediff(F.col('CLM_THRU_DT'), F.col('CLM_FROM_DT'))
admission_days = F.datediff(F.col('NCH_BENE_DSCHRG_DT'), F.col('CLM_ADMSN_DT'))

silver_df = silver_df.withColumn('CLAIM_DURATION', claim_days)
silver_df = silver_df.withColumn('ADMISSION_DURATION', admission_days)

In [0]:
# Keep only rows with no null in ANY column. This is strict: a single null in
# any remaining column removes the whole claim row.
silver_df = silver_df.na.drop(how='any')

In [0]:
# The bronze ingestion timestamp is not carried into the silver table.
silver_df = silver_df.drop(*["read_timestamp"])

In [0]:
# Enrich each claim with the ben_sum table from the silver schema.
# The table is resolved from the configured catalog/schema rather than the former
# hardcoded `phr`.`02_silver`, so the `catalog` widget is honored here as well.
# NOTE(review): assumes silver_schema (from the utilities notebook) is the schema
# that holds ben_sum (previously hardcoded as `02_silver`) — confirm.
ben_sum = (
    spark.read.table(f'{catalog}.{silver_schema}.ben_sum')
    .drop("read_timestamp")  # ingestion timestamp, not carried into silver
)

# Inner join: claims without a matching DESYNPUF_ID in ben_sum are dropped.
# NOTE(review): assumes ben_sum has at most one row per DESYNPUF_ID; if it keeps
# several rows per beneficiary (e.g. one per year), this join duplicates claims.
silver_df = silver_df.join(ben_sum, on='DESYNPUF_ID', how='inner')

In [0]:
# Write the enriched claims to the silver layer as a Delta table, replacing any
# previous contents. Table properties passed as writer options need the `delta.`
# prefix, matching the bronze write. The bare 'enableChangeDataFeed' key used
# previously is not a Delta table property name, so it likely left change data
# feed disabled.
(silver_df.write
 .format('delta')
 .mode('overwrite')
 .option('delta.enableChangeDataFeed', 'true')
 .saveAsTable(f'{catalog}.{silver_schema}.{data_source}_enriched')
)

In [0]:
# Optional: null-count profile of the enriched (joined) frame. Left disabled —
# as written it launches one full Spark count job per column, plus one for
# total_rows, which is slow on wide tables. Uncomment to inspect.
# # count nulls in each column
# total_rows = silver_df.count()

# for col_name in silver_df.columns:
#     null_count = silver_df.filter(F.col(col_name).isNull()).count()
#     null_ratio = null_count * 100 / total_rows
#     print(f"Nulls in {col_name}: {null_count}, {null_ratio}")