# Snowpark: SMM & Chronic Flags — Regex ICD Patterns Added

In [None]:

from snowflake.snowpark.context import get_active_session
# Bind the notebook's active Snowpark session and print the context it runs in.
session = get_active_session()
_context_report = (
    ("ROLE     :", session.get_current_role),
    ("WAREHOUSE:", session.get_current_warehouse),
    ("DATABASE :", session.get_current_database),
    ("SCHEMA   :", session.get_current_schema),
)
for _label, _getter in _context_report:
    print(_label, _getter())


In [None]:

# Fully qualified names of the two source tables used throughout the notebook.
MM_TABLE = "EASE_DW_PROD.EASE_DUA_MAP_PMO_MSM.MSM_BIRTHS"
SMM_TABLE = "EASE_DW_PROD.EASE_DUA_MAP_PMO_MSM.SEVERE_MATERNAL_MORBIDITY_VARIABLES"

# Lazy table handles, followed by a one-row preview of each.
mm_tbl, smm_tbl = (session.table(_name) for _name in (MM_TABLE, SMM_TABLE))
for _tbl in (mm_tbl, smm_tbl):
    _tbl.limit(1).show()


In [None]:

import re
# Column names of the births table; searched below for numbered diagnosis columns.
mm_cols = [_field.name for _field in mm_tbl.schema.fields]

def numbered_cols(df_cols, prefix: str):
    """Return the names in ``df_cols`` shaped like ``<prefix>_<digits>``.

    Matching is case-insensitive and ``prefix`` is taken literally (it is
    regex-escaped). The result is ordered by the integer value of the numeric
    suffix; names with equal suffix values keep their input order.
    """
    suffix_re = re.compile(rf"^{re.escape(prefix)}_(\d+)$", re.IGNORECASE)
    candidates = ((suffix_re.match(name), name) for name in df_cols)
    numbered = [(int(hit.group(1)), name) for hit, name in candidates if hit]
    # sorted() is stable, so ties on the numeric suffix preserve input order.
    return [name for _, name in sorted(numbered, key=lambda pair: pair[0])]

# Discover the numbered diagnosis code / description columns on the births table.
code_cols = numbered_cols(mm_cols, "DIAGNOSIS_CODE")
desc_cols = numbered_cols(mm_cols, "DIAGNOSIS_DESCRIPTION_SHORT")

print("Detected DIAG cols:", code_cols)
print("Detected DESC cols:", desc_cols)
# Raise explicitly instead of using `assert`: asserts are stripped when Python
# runs with -O, and the later cells cannot work without any code columns.
if not code_cols:
    raise ValueError("No DIAGNOSIS_CODE_# columns found.")


In [None]:

from snowflake.snowpark.functions import col, lit, call_function

def object_construct_expr(cols):
    """Build an ``OBJECT_CONSTRUCT`` call pairing each column name with its value.

    The arguments alternate ``lit(name), col(name)`` for every entry of
    ``cols``, in order.
    """
    key_value_args = [expr for name in cols for expr in (lit(name), col(name))]
    return call_function("OBJECT_CONSTRUCT", *key_value_args)


In [None]:

from snowflake.snowpark.functions import col, lit, regexp_replace, call_function

# Reshape the wide DIAGNOSIS_CODE_n columns into one row per (member, slot).
# The columns are packed into an object keyed by column name and exploded with
# DataFrame.flatten, which adds SEQ, KEY, PATH, INDEX, VALUE and THIS columns.
# (join_table_function only accepts a function name, a list of name parts or a
# TableFunctionCall -- not a Column built with call_function -- and
# "LATERAL_FLATTEN" is not a Snowflake function.)
codes_obj = object_construct_expr(code_cols)
mm_codes_long = (
    mm_tbl
    .select(col("DIM_MEMBER_KEY"), codes_obj.alias("CODES"))
    .flatten(col("CODES"))
    .select(col("DIM_MEMBER_KEY"),
            col("KEY").alias("CODE_SLOT"),
            col("VALUE").cast("string").alias("ICD_CODE"))
)
if desc_cols:
    # Same reshape for the description columns.
    descs_obj = object_construct_expr(desc_cols)
    mm_desc_long = (
        mm_tbl
        .select(col("DIM_MEMBER_KEY"), descs_obj.alias("DESCS"))
        .flatten(col("DESCS"))
        .select(col("DIM_MEMBER_KEY"),
                col("KEY").alias("DESC_SLOT"),
                col("VALUE").cast("string").alias("ICD_DESC"))
    )
    # Pair each code with its description through the numeric slot suffix
    # (everything that is not a digit is stripped from the column name).
    mm_codes_long = mm_codes_long.with_column("SLOT_NUM", regexp_replace(col("CODE_SLOT"), r"[^0-9]", ""))
    mm_desc_long  = mm_desc_long.with_column("SLOT_NUM", regexp_replace(col("DESC_SLOT"),  r"[^0-9]", ""))
    # NOTE(review): the join key is (member, slot). If a member has several rows
    # in the births table, codes and descriptions from different rows are
    # cross-paired -- confirm whether a per-row key should be added.
    mm_codes_long = mm_codes_long.join(mm_desc_long, on=["DIM_MEMBER_KEY", "SLOT_NUM"], how="left")
else:
    # No description columns: keep the schema stable with a NULL string column.
    mm_codes_long = mm_codes_long.with_column("ICD_DESC", lit(None).cast("string"))
mm_codes_long.show(5)


In [None]:

from snowflake.snowpark.functions import upper, trim, coalesce, col, lit
# Add upper-cased, trimmed copies of code and description; NULL becomes "".
mm_codes_norm = mm_codes_long
for _raw_name, _norm_name in (("ICD_CODE", "ICD_CODE_NORM"), ("ICD_DESC", "ICD_DESC_NORM")):
    _cleaned = upper(trim(coalesce(col(_raw_name), lit(""))))
    mm_codes_norm = mm_codes_norm.with_column(_norm_name, _cleaned)
mm_codes_norm.show(5)


In [None]:

def resolve(colnames):
    """Return the first SMM-table column matching a candidate name, ignoring case.

    Candidates are tried in the order given; for each one the columns of
    ``smm_tbl`` are scanned in schema order. Returns ``None`` when no
    candidate matches any column.
    """
    pairs = (
        (wanted, actual)
        for wanted in colnames
        for actual in smm_tbl.schema.names
    )
    return next(
        (actual for wanted, actual in pairs if actual.upper() == wanted.upper()),
        None,
    )

# Locate the three SMM reference columns under any of their accepted spellings.
diag_col, icd_col, ind_col = (
    resolve(_candidates)
    for _candidates in (
        ["DIAGNOSIS", "DIAGNOSIS_PROCEDURE", "DIAGNOSIS_OR_PROCEDURE"],
        ["ICD_10", "ICD-10", "ICD"],
        ["SEVERE_MATERNAL_MORBIDITY_INDICATOR", "SMM_INDICATOR", "SEVERE_MATERNAL_MORBIDITY"],
    )
)
if not all((diag_col, icd_col, ind_col)):
    raise ValueError(f"Could not resolve SMM columns. diag={diag_col} icd={icd_col} indicator={ind_col}")


def _blank_safe_upper(name):
    """Upper-cased, trimmed value of column ``name`` with NULL mapped to ""."""
    return upper(trim(coalesce(col(name), lit(""))))


# Rename to canonical column names, then normalise the two text columns.
_smm_selected = smm_tbl.select(
    col(diag_col).alias("DIAGNOSIS"),
    col(icd_col).alias("ICD_10"),
    col(ind_col).alias("SMM_INDICATOR_RAW"),
)
_smm_normalized = (
    _smm_selected
    .with_column("DIAGNOSIS", _blank_safe_upper("DIAGNOSIS"))
    .with_column("ICD_10", _blank_safe_upper("ICD_10"))
)

# Keep only rows whose DIAGNOSIS value is "DX" and that carry a non-empty code.
_is_dx_row = col("DIAGNOSIS") == lit("DX")
_has_code = col("ICD_10").is_not_null() & (col("ICD_10") != lit(""))
smm_ref = _smm_normalized.filter(_is_dx_row).filter(_has_code)
smm_ref.show(5)


In [None]:

# Left join: every diagnosis row is kept, and rows whose normalised code equals
# an SMM reference code also receive the reference columns.
_code_matches_smm = mm_codes_norm["ICD_CODE_NORM"] == smm_ref["ICD_10"]
smm_join = mm_codes_norm.join(smm_ref, on=_code_matches_smm, how="left")
smm_join.show(5)


In [None]:

from snowflake.snowpark.functions import when, col, lit, call_function

# ICD-10 category prefixes for each chronic-condition flag.
# The separator dot is optional ("[.]?") so both dotted ("E11.9") and dotless
# ("E119") encodings match, and three-character codes with no decimal part
# (e.g. "I10", essential hypertension) are no longer missed.
icd_regex = {
    "has_diabetes":      r"^E(0[89]|1[0-3])",
    "has_hypertension":  r"^I1[0-6]",
    "has_asthma":        r"^J45",
    "has_copd":          r"^J44",
    "has_depression":    r"^F3[23]",
    "has_heart_failure": r"^I50|^I11[.]?0|^I13[.]?[02]",
    "has_ischemic":      r"^I2[0-5]",
    "has_stroke":        r"^I6[0-9]",
}

# Description keywords that also set the flag (substring match on the
# upper-cased description).
desc_keywords = {
    "has_diabetes":      ["DIABETES"],
    "has_hypertension":  ["HYPERTENSION"],
    "has_asthma":        ["ASTHMA"],
    "has_copd":          ["COPD", "CHRONIC OBSTRUCTIVE"],
    "has_depression":    ["DEPRESSION", "DEPRESSIVE"],
    "has_heart_failure": ["HEART FAILURE"],
    "has_ischemic":      ["ISCHEMIC", "ISCHAEMIC", "IHD"],
    "has_stroke":        ["STROKE", "CVA", "CEREBROVASCULAR"],
}

df_flags = smm_join
for flag, pattern in icd_regex.items():
    # REGEXP_LIKE implicitly anchors the pattern at BOTH ends, so a prefix
    # pattern would never match a full code. REGEXP_COUNT searches within the
    # string, which gives the intended "starts with" behaviour via "^".
    re_cond = call_function("REGEXP_COUNT", col("ICD_CODE_NORM"), lit(pattern)) > lit(0)
    kw_cond = None
    for kw in desc_keywords.get(flag, []):
        # Wrap the keyword in lit() so it is always treated as a string
        # literal and never as a column name.
        cond = col("ICD_DESC_NORM").contains(lit(kw))
        kw_cond = cond if kw_cond is None else (kw_cond | cond)
    # A flag without keywords falls back to the code pattern alone.
    any_cond = re_cond if kw_cond is None else (re_cond | kw_cond)
    df_flags = df_flags.with_column(flag, when(any_cond, lit(1)).otherwise(lit(0)))

df_flags.select("DIM_MEMBER_KEY", "ICD_CODE_NORM", "ICD_DESC_NORM", *icd_regex.keys()).show(5)


In [None]:

from snowflake.snowpark.functions import max as sp_max, lower, trim, coalesce, col, lit

# Normalise the SMM indicator (lower-case, trimmed, NULL -> "") and mark rows
# that carry any indicator at all.
df_flags = df_flags.with_column(
    "SMM_INDICATOR_LC",
    lower(trim(coalesce(col("SMM_INDICATOR_RAW"), lit(""))))
).with_column("HAS_ANY_SMM", (col("SMM_INDICATOR_LC") != lit("")).cast("int"))

# Distinct non-empty indicator values, sorted so the column order is stable.
indicator_values = sorted(
    r[0] for r in df_flags.select("SMM_INDICATOR_LC").distinct().collect() if r[0]
)
if indicator_values:
    # Map each output column name to the indicator value(s) feeding it.
    # Snowflake names pivot output columns with the single-quoted value
    # (e.g. "'hemorrhage'"), so looking the bare value up in schema.names and
    # renaming it never takes effect. Build one conditional aggregate per
    # indicator instead, under a sanitised, upper-cased FLAG_SMM_* name.
    smm_flag_values = {}
    for v in indicator_values:
        safe = "".join(ch if (ch.isascii() and ch.isalnum()) else "_" for ch in v)
        smm_flag_values.setdefault(f"FLAG_SMM_{safe.upper()}", []).append(v)

    smm_aggs = []
    for flag_name, values in smm_flag_values.items():
        hit = None
        for v in values:
            is_value = col("SMM_INDICATOR_LC") == lit(v)
            hit = is_value if hit is None else (hit | is_value)
        # 1 when the member has at least one row with this indicator, else 0.
        smm_aggs.append(sp_max(hit.cast("int")).alias(flag_name))

    smm_pivot = (
        df_flags
        .filter(col("SMM_INDICATOR_LC") != lit(""))
        .group_by(col("DIM_MEMBER_KEY"))
        .agg(*smm_aggs)
    )
else:
    # No indicator values at all: expose a single member-level "any" flag.
    smm_pivot = df_flags.group_by("DIM_MEMBER_KEY").agg(sp_max(col("HAS_ANY_SMM")).alias("FLAG_SMM_any"))


In [None]:

from snowflake.snowpark.functions import max as sp_max, col, coalesce, lit, call_function

def _bare_upper(name):
    """Column name without surrounding double quotes, upper-cased for comparison."""
    return name.strip('"').upper()


# Snowpark upper-cases unquoted identifiers, so flags created as "has_*" come
# back from the schema as "HAS_*"; compare case-insensitively. HAS_ANY_SMM is
# aggregated separately below and must not be listed twice.
chronic_cols = [
    c for c in df_flags.schema.names
    if _bare_upper(c).startswith("HAS_") and _bare_upper(c) != "HAS_ANY_SMM"
]
agg_exprs = [sp_max(col(c)).alias(c) for c in chronic_cols] + [sp_max(col("HAS_ANY_SMM")).alias("HAS_ANY_SMM")]
# One row per member: a flag is 1 if any of the member's diagnosis rows set it.
chronic_wide = df_flags.group_by("DIM_MEMBER_KEY").agg(*agg_exprs)

# Demographic columns are optional; use only those present on the births table.
demo_cols_candidates = ["DELIVERY_DATE", "MEMBER_RACE", "MEMBER_AGE", "MEMBER_ZIP_CODE_4", "COUNTY_NAME"]
mm_names = set(mm_tbl.schema.names)
demo_cols = [c for c in demo_cols_candidates if c in mm_names]
# NOTE(review): MAX is taken per column independently, so for members with
# several rows the values may come from different rows -- confirm acceptable.
demo_exprs = [sp_max(col(c)).alias(c) for c in demo_cols]
if demo_exprs:
    member_demo = mm_tbl.group_by(col("DIM_MEMBER_KEY")).agg(*demo_exprs)
else:
    # Nothing to aggregate: fall back to the distinct member keys.
    member_demo = mm_tbl.select(col("DIM_MEMBER_KEY")).distinct()

final_wide = (
    member_demo
    .join(chronic_wide, on="DIM_MEMBER_KEY", how="left")
    .join(smm_pivot,    on="DIM_MEMBER_KEY", how="left")
)

# Members missing from either flag table get NULLs from the left joins;
# turn those into 0 and force an integer type on every flag column.
for c in list(final_wide.schema.names):
    if _bare_upper(c).startswith(("HAS_", "FLAG_SMM_")):
        final_wide = final_wide.with_column(c, coalesce(col(c), lit(0)).cast("int"))

final_wide.limit(10).show()
