In [1]:
from pyspark.sql import SparkSession, functions as F
import pandas as pd
from pathlib import PurePosixPath
import os

In [2]:
# ---------- Notebook display helpers ----------
N_SHOW = 10  # how many rows to preview at each step

def peek(df, name, n=N_SHOW):
    print(f"\n===== {name}: schema =====")
    df.printSchema()
    try:
        cnt = df.count()
    except Exception:
        cnt = None
    print(f"===== {name}: count ===== {cnt if cnt is not None else 'N/A'}")
    print(f"===== {name}: sample {n} rows =====")
    df.show(n, truncate=80)

In [3]:
# ---------- Spark session ----------
spark = (SparkSession.builder
         .appName("ENM_LOS")
         .config("spark.sql.broadcastTimeout", "360000")
         .enableHiveSupport()
         .getOrCreate())
sc = spark.sparkContext

In [4]:
# ---------- Paths ----------
FACT_PATH = "/mldev/unrestricted/ml_oi/data/pca_analytics/factline_enm_los"


In [5]:
# Read ICD Descriptions from the SAME directory where this notebook runs:
# final local path: ./pooja.kabadi/ICD Descriptions.xlsx
LOCAL_DIR       = os.getcwd()                          # notebook working dir
ICD_DIR_LOCAL   = os.path.join(LOCAL_DIR, "pooja.kabadi")
ICD_XLSX_LOCAL  = os.path.join(ICD_DIR_LOCAL, "ICD Descriptions.xlsx")
ICD_CSV_LOCAL   = os.path.join(ICD_DIR_LOCAL, "ICD Descriptions.csv")  # optional CSV fallback


In [6]:
# ---------- Recalibration thresholds (tune as needed) ----------
MIN_CLAIMS       = 200       # min cohort volume
MIN_PROVIDERS    = 20        # min cohort breadth
MIN_APV_PCT      = -0.05     # <= -3% implies strong negative adjustments vs GPV
MIN_ADJ_RATE     = 0.20      # >= 20% edited lines get adjusted
MIN_BGT_RATE     = 0.60      # >= 60% of claims billed > allowed

In [7]:
# 1) Load fact table
# ============================================================
df = spark.read.parquet(FACT_PATH)
peek(df, "STEP 1 – Loaded factline_enm_los")



===== STEP 1 – Loaded factline_enm_los: schema =====
root
 |-- payer_short: string (nullable = true)
 |-- t_received_date: string (nullable = true)
 |-- paid: double (nullable = true)
 |-- dp_key: integer (nullable = true)
 |-- dp_desc: string (nullable = true)
 |-- solution: string (nullable = true)
 |-- fl_sub_rule_key: integer (nullable = true)
 |-- t_provider_id: string (nullable = true)
 |-- t_taxgroup_id: string (nullable = true)
 |-- t_group_id: string (nullable = true)
 |-- t_subspec_id: string (nullable = true)
 |-- t_subspec_code: string (nullable = true)
 |-- t_insurance_key: integer (nullable = true)
 |-- insurance_desc: string (nullable = true)
 |-- t_pos_code: string (nullable = true)
 |-- icd_cd1: string (nullable = true)
 |-- icd_cd1_version: string (nullable = true)
 |-- icd_cd2: string (nullable = true)
 |-- icd_cd2_version: string (nullable = true)
 |-- icd_cd3: string (nullable = true)
 |-- icd_cd3_version: string (nullable = true)
 |-- icd_cd4: string (nullable = 

In [8]:
# 2) Load ICD Descriptions from local notebook folder and prep for join
# ============================================================
df_icd = None
try:
    if os.path.exists(ICD_XLSX_LOCAL):
        # Prefer .xlsx with openpyxl
        try:
            icd_pdf = pd.read_excel(ICD_XLSX_LOCAL, sheet_name=0, engine="openpyxl")
        except Exception:
            # If openpyxl not available, try default engine (may still work)
            icd_pdf = pd.read_excel(ICD_XLSX_LOCAL, sheet_name=0)
        print(f"STEP 2 – Loaded ICD Excel from: {ICD_XLSX_LOCAL}")
    elif os.path.exists(ICD_CSV_LOCAL):
        # Fallback to CSV if provided
        icd_pdf = pd.read_csv(ICD_CSV_LOCAL)
        print(f"STEP 2 – Loaded ICD CSV from: {ICD_CSV_LOCAL}")
    else:
        raise FileNotFoundError(f"Neither Excel nor CSV found at:\n  {ICD_XLSX_LOCAL}\n  {ICD_CSV_LOCAL}")

    icd_pdf.columns = [c.strip() for c in icd_pdf.columns]
    required = {"ICD_code","ICD_Description"}
    if not required.issubset(set(icd_pdf.columns)):
        missing = required - set(icd_pdf.columns)
        raise ValueError(f"ICD file missing columns: {missing}")

    df_icd = (spark.createDataFrame(icd_pdf)
                .withColumn("ICD_code_norm", F.upper(F.trim(F.col("ICD_code"))))
                .dropDuplicates(["ICD_code_norm"])
                .select("ICD_code_norm","ICD_Description"))
    peek(df_icd, "STEP 2 – ICD Descriptions (local)")
except Exception as e:
    print(f"STEP 2 – ICD Descriptions: READ FAILED (skipping join). Reason: {e}")


STEP 2 – ICD Descriptions: READ FAILED (skipping join). Reason: Neither Excel nor CSV found at:
  /mltools/jupyterhub/notebooks/pooja.kabadi/pooja.kabadi/ICD Descriptions.xlsx
  /mltools/jupyterhub/notebooks/pooja.kabadi/pooja.kabadi/ICD Descriptions.csv


In [9]:
# 3) Core derivations: Primary_ICD, ICD_root, numeric LOS, flags
# ============================================================
df = (df
      .withColumn("Primary_ICD", F.coalesce(F.col("icd_cd1"), F.col("icd1")))
      .withColumn("Claims", F.col("Primary_ICD"))
      .withColumn(
          "ICD_root",
          F.upper(
              F.regexp_replace(
                  F.element_at(F.split(F.col("Primary_ICD"), "\\."), 1),  # token before '.'
                  r"^0+",""
              )
          )
      )
      .withColumn("ICD_root", F.when(F.length("ICD_root")==0, None).otherwise(F.col("ICD_root")))
      .withColumn("allowed_los_int", F.col("allowed_los").cast("int"))
      .withColumn("billed_los_int",  F.col("billed_los").cast("int"))
      .withColumn("los_agree",        (F.col("billed_los_int")==F.col("allowed_los_int")).cast("int"))
      .withColumn("billed_gt_allowed",(F.col("billed_los_int")>F.col("allowed_los_int")).cast("int"))
      .withColumn("billed_lt_allowed",(F.col("billed_los_int")<F.col("allowed_los_int")).cast("int"))
)

# Attach ICD Description if available
if df_icd is not None:
    df = (df.join(df_icd,
                  F.upper(F.trim(F.col("Primary_ICD"))) == F.col("ICD_code_norm"),
                  "left")
            .drop("ICD_code_norm"))

peek(df.select("Primary_ICD","ICD_root","allowed_los","billed_los",
               "allowed_los_int","billed_los_int","los_agree",
               "billed_gt_allowed","billed_lt_allowed",
               *(["ICD_Description"] if "ICD_Description" in df.columns else [])),
     "STEP 3 – After core derivations (preview key fields)")





===== STEP 3 – After core derivations (preview key fields): schema =====
root
 |-- Primary_ICD: string (nullable = true)
 |-- ICD_root: string (nullable = true)
 |-- allowed_los: string (nullable = true)
 |-- billed_los: string (nullable = true)
 |-- allowed_los_int: integer (nullable = true)
 |-- billed_los_int: integer (nullable = true)
 |-- los_agree: integer (nullable = true)
 |-- billed_gt_allowed: integer (nullable = true)
 |-- billed_lt_allowed: integer (nullable = true)

===== STEP 3 – After core derivations (preview key fields): count ===== 7094526
===== STEP 3 – After core derivations (preview key fields): sample 10 rows =====
+-----------+--------+-----------+----------+---------------+--------------+---------+-----------------+-----------------+
|Primary_ICD|ICD_root|allowed_los|billed_los|allowed_los_int|billed_los_int|los_agree|billed_gt_allowed|billed_lt_allowed|
+-----------+--------+-----------+----------+---------------+--------------+---------+-----------------+----

In [10]:
# 4) Null-safety for finance/line counters + drop columns
# ============================================================
for c in ["apv","gpv","paid","adjusted_lines","edited_lines","total_lines"]:
    if c in df.columns:
        df = df.withColumn(c, F.coalesce(F.col(c), F.lit(0.0)))

cols_to_drop = [
    "t_group_id",
    "t_sub_mod1","t_sub_mod2","t_sub_mod3","t_sub_mod4",
    "t_other_reduction",
    "initial_mid_rule1_key","initial_rule1_version","initial_medical1_policy",
    "initial_mid_rule2_key","initial_rule2_version","initial_medical2_policy",
    "initial_mid_rule3_key","initial_rule3_version","initial_medical3_policy",
    "initial_mid_rule4_key","initial_rule4_version","initial_medical4_policy",
    "initial_mid_rule5_key","initial_rule5_version","initial_medical5_policy",
    "latest_mid_rule2_key","latest_rule2_version","latest_medical2_policy",
    "latest_mid_rule3_key","latest_rule3_version","latest_medical3_policy",
    "latest_mid_rule4_key","latest_rule4_version","latest_medical4_policy",
    "latest_mid_rule5_key","latest_rule5_version","latest_medical5_policy",
    "capitated_flag",
    "cdf_text_1","cdf_text_2","cdf_text_3",
    "pccv_indicator",
    "pos_bucket","spec_bucket"  # drop if present
]
df = df.drop(*[c for c in cols_to_drop if c in df.columns])

df = df.withColumn("APV_percent", F.when(F.col("gpv") != 0, F.col("apv")/F.col("gpv")))
peek(df.select("Primary_ICD","ICD_root","cpt_group","allowed_los_int","billed_los_int",
               "apv","gpv","APV_percent","adjusted_lines","edited_lines"),
     "STEP 4 – After drops + APV_percent")


===== STEP 4 – After drops + APV_percent: schema =====
root
 |-- Primary_ICD: string (nullable = true)
 |-- ICD_root: string (nullable = true)
 |-- cpt_group: string (nullable = true)
 |-- allowed_los_int: integer (nullable = true)
 |-- billed_los_int: integer (nullable = true)
 |-- apv: double (nullable = false)
 |-- gpv: double (nullable = false)
 |-- APV_percent: double (nullable = true)
 |-- adjusted_lines: double (nullable = false)
 |-- edited_lines: double (nullable = false)

===== STEP 4 – After drops + APV_percent: count ===== 7094526
===== STEP 4 – After drops + APV_percent: sample 10 rows =====
+-----------+--------+---------+---------------+--------------+---+------------------+-----------+--------------+------------+
|Primary_ICD|ICD_root|cpt_group|allowed_los_int|billed_los_int|apv|               gpv|APV_percent|adjusted_lines|edited_lines|
+-----------+--------+---------+---------------+--------------+---+------------------+-----------+--------------+------------+
|     

In [11]:
# 5) Cohort metrics (ICD_root × cpt_group) for recalibration
# ============================================================
for k in [1,2,3,4,5]:
    df = df.withColumn(f"billed_n{k}",  (F.col("billed_los_int")==F.lit(k)).cast("int"))
    df = df.withColumn(f"allowed_n{k}", (F.col("allowed_los_int")==F.lit(k)).cast("int"))

cohort = (df.groupBy("ICD_root","cpt_group")
            .agg(
                F.count("*").alias("claims"),
                F.countDistinct("t_provider_id").alias("providers"),
                F.sum("paid").alias("paid_sum"),
                F.sum("apv").alias("APV_sum"),
                F.sum("gpv").alias("GPV_sum"),
                F.sum("adjusted_lines").alias("adjusted_lines_sum"),
                F.sum("edited_lines").alias("edited_lines_sum"),
                *[F.sum(f"billed_n{k}").alias(f"billed_n{k}") for k in [1,2,3,4,5]],
                *[F.sum(f"allowed_n{k}").alias(f"allowed_n{k}") for k in [1,2,3,4,5]],
                F.sum("los_agree").alias("los_agree_count"),
                F.sum("billed_gt_allowed").alias("billed_gt_allowed_count"),
                F.sum("billed_lt_allowed").alias("billed_lt_allowed_count")
            )
            .withColumn("APV_percent", F.when(F.col("GPV_sum")!=0, F.col("APV_sum")/F.col("GPV_sum")))
            .withColumn("adj_rate", F.when(F.col("edited_lines_sum")!=0, F.col("adjusted_lines_sum")/F.col("edited_lines_sum")))
            .withColumn("los_agreement_rate", F.col("los_agree_count")/F.col("claims"))
            .withColumn("billed_gt_allowed_rate", F.col("billed_gt_allowed_count")/F.col("claims"))
            .withColumn("billed_lt_allowed_rate", F.col("billed_lt_allowed_count")/F.col("claims"))
         )

peek(cohort.select("ICD_root","cpt_group","claims","providers","APV_sum","GPV_sum",
                   "APV_percent","adj_rate","billed_gt_allowed_rate","los_agreement_rate"),
     "STEP 5 – Cohort metrics (preview)")



===== STEP 5 – Cohort metrics (preview): schema =====
root
 |-- ICD_root: string (nullable = true)
 |-- cpt_group: string (nullable = true)
 |-- claims: long (nullable = false)
 |-- providers: long (nullable = false)
 |-- APV_sum: double (nullable = true)
 |-- GPV_sum: double (nullable = true)
 |-- APV_percent: double (nullable = true)
 |-- adj_rate: double (nullable = true)
 |-- billed_gt_allowed_rate: double (nullable = true)
 |-- los_agreement_rate: double (nullable = true)

===== STEP 5 – Cohort metrics (preview): count ===== 7379
===== STEP 5 – Cohort metrics (preview): sample 10 rows =====
+--------+---------+------+---------+-------------------+------------------+----------------------+--------------------+----------------------+------------------+
|ICD_root|cpt_group|claims|providers|            APV_sum|           GPV_sum|           APV_percent|            adj_rate|billed_gt_allowed_rate|los_agreement_rate|
+--------+---------+------+---------+-------------------+-------------

In [12]:
# Cohort modes (billed/allowed)
billed_arr  = F.array(*[F.struct(F.col(f"billed_n{k}").alias("cnt"),  F.lit(k).alias("los")) for k in [1,2,3,4,5]])
allowed_arr = F.array(*[F.struct(F.col(f"allowed_n{k}").alias("cnt"), F.lit(k).alias("los")) for k in [1,2,3,4,5]])

cohort = (cohort
          .withColumn("billed_mode_struct",  F.element_at(F.array_sort(billed_arr),  -1))
          .withColumn("allowed_mode_struct", F.element_at(F.array_sort(allowed_arr), -1))
          .withColumn("billed_mode_los",  F.col("billed_mode_struct.los"))
          .withColumn("allowed_mode_los", F.col("allowed_mode_struct.los"))
          .drop("billed_mode_struct","allowed_mode_struct"))

peek(cohort.select("ICD_root","cpt_group","billed_mode_los","allowed_mode_los"),
     "STEP 5b – Cohort modes (billed & allowed)")


===== STEP 5b – Cohort modes (billed & allowed): schema =====
root
 |-- ICD_root: string (nullable = true)
 |-- cpt_group: string (nullable = true)
 |-- billed_mode_los: integer (nullable = true)
 |-- allowed_mode_los: integer (nullable = true)

===== STEP 5b – Cohort modes (billed & allowed): count ===== 7379
===== STEP 5b – Cohort modes (billed & allowed): sample 10 rows =====
+--------+---------+---------------+----------------+
|ICD_root|cpt_group|billed_mode_los|allowed_mode_los|
+--------+---------+---------------+----------------+
|     L85|    ESTOV|              4|               3|
|     L20|    IPERV|              4|               3|
|     S86|    ESTOV|              4|               3|
|     R46|    IPERV|              5|               3|
|     J13|    ESTOV|              4|               3|
|     S63|    NEWOV|              4|               3|
|     G43|    EYEES|              4|               2|
|     R51|    IPFUV|              3|               2|
|     N52|    NEWOV|   

In [13]:
# Recalibration rule → proposed_los (weak label)
cohort = (cohort
    .withColumn(
        "proposed_los",
        F.when(
            (F.col("claims")      >= F.lit(MIN_CLAIMS)) &
            (F.col("providers")   >= F.lit(MIN_PROVIDERS)) &
            (F.col("APV_percent") <= F.lit(MIN_APV_PCT)) &        # strong negative adjustments
            (F.col("adj_rate")    >= F.lit(MIN_ADJ_RATE)) &       # high adjusted-lines burden
            (F.col("billed_gt_allowed_rate") >= F.lit(MIN_BGT_RATE)),
            F.col("billed_mode_los")                              # adopt billed mode
        ).otherwise(F.col("allowed_mode_los"))                    # else keep baseline
    )
)
peek(cohort.select("ICD_root","cpt_group","claims","APV_percent","adj_rate",
                   "billed_gt_allowed_rate","allowed_mode_los","billed_mode_los","proposed_los"),
     "STEP 5c – Proposed LOS (recalibrated)")


===== STEP 5c – Proposed LOS (recalibrated): schema =====
root
 |-- ICD_root: string (nullable = true)
 |-- cpt_group: string (nullable = true)
 |-- claims: long (nullable = false)
 |-- APV_percent: double (nullable = true)
 |-- adj_rate: double (nullable = true)
 |-- billed_gt_allowed_rate: double (nullable = true)
 |-- allowed_mode_los: integer (nullable = true)
 |-- billed_mode_los: integer (nullable = true)
 |-- proposed_los: integer (nullable = true)

===== STEP 5c – Proposed LOS (recalibrated): count ===== 7379
===== STEP 5c – Proposed LOS (recalibrated): sample 10 rows =====
+--------+---------+------+----------------------+--------------------+----------------------+----------------+---------------+------------+
|ICD_root|cpt_group|claims|           APV_percent|            adj_rate|billed_gt_allowed_rate|allowed_mode_los|billed_mode_los|proposed_los|
+--------+---------+------+----------------------+--------------------+----------------------+----------------+---------------+-

In [14]:
# 6) Join proposed_los back to claims (build final DR dataset)
# ============================================================
df_out = (df.join(
            cohort.select("ICD_root","cpt_group","proposed_los"),
            ["ICD_root","cpt_group"],
            "left")
          .withColumn("target_los_allowed", F.col("allowed_los_int").cast("int"))
          .withColumn("target_los_billed",  F.col("billed_los_int").cast("int"))
          .withColumn("target_los_recalib", F.col("proposed_los").cast("int"))
)

peek(df_out.select("Primary_ICD","ICD_root","cpt_group",
                   "allowed_los_int","billed_los_int",
                   "target_los_allowed","target_los_billed","target_los_recalib",
                   "APV_percent","apv","gpv","adjusted_lines","edited_lines").limit(50),
     "STEP 6 – Final claim-level dataset (preview key targets)")

# Keep rows with at least one valid target 1..5
df_out = df_out.filter(
    (F.col("target_los_allowed").between(1,5)) |
    (F.col("target_los_billed").between(1,5))  |
    (F.col("target_los_recalib").between(1,5))
)
peek(df_out, "STEP 6b – Filtered to valid targets")


===== STEP 6 – Final claim-level dataset (preview key targets): schema =====
root
 |-- Primary_ICD: string (nullable = true)
 |-- ICD_root: string (nullable = true)
 |-- cpt_group: string (nullable = true)
 |-- allowed_los_int: integer (nullable = true)
 |-- billed_los_int: integer (nullable = true)
 |-- target_los_allowed: integer (nullable = true)
 |-- target_los_billed: integer (nullable = true)
 |-- target_los_recalib: integer (nullable = true)
 |-- APV_percent: double (nullable = true)
 |-- apv: double (nullable = false)
 |-- gpv: double (nullable = false)
 |-- adjusted_lines: double (nullable = false)
 |-- edited_lines: double (nullable = false)

===== STEP 6 – Final claim-level dataset (preview key targets): count ===== 50
===== STEP 6 – Final claim-level dataset (preview key targets): sample 10 rows =====
+-----------+--------+---------+---------------+--------------+------------------+-----------------+------------------+-----------+------+------+--------------+------------+


In [15]:
# ============================================================
# 7) Save ONE CSV to /mldev/restricted/AWSS3/ml_oi_aa/
# ============================================================
OUT_PARENT = "/mldev/restricted/AWSS3/ml_oi_aa"
TMP_DIR    = f"{OUT_PARENT}/_tmp_enm_los_datarobot_csv"
FINAL_CSV  = f"{OUT_PARENT}/enm_los_datarobot_ready.csv"

def save_single_csv(df_export, out_dir, final_csv_path):
    # write to a temp folder as a single part
    (df_export.coalesce(1)
              .write
              .mode("overwrite")
              .option("header", True)
              .csv(out_dir))

    fs = sc._jvm.org.apache.hadoop.fs.FileSystem.get(sc._jsc.hadoopConfiguration())
    Path = sc._jvm.org.apache.hadoop.fs.Path

    # ensure parent exists (mkdir -p)
    fs.mkdirs(Path(OUT_PARENT))

    outPath  = Path(out_dir)
    files    = fs.listStatus(outPath)

    # find the part file Spark wrote
    part = None
    for f in files:
        name = f.getPath().getName()
        if name.startswith("part-") and (name.endswith(".csv") or "." not in name):
            part = f.getPath()
            break
    if part is None:
        raise RuntimeError(f"No part-* file found in {out_dir}")

    finalPath = Path(final_csv_path)
    # overwrite if exists
    if fs.exists(finalPath):
        fs.delete(finalPath, True)
    fs.rename(part, finalPath)

    # cleanup temp dir
    for f in fs.listStatus(outPath):
        fs.delete(f.getPath(), True)
    fs.delete(outPath, True)

save_single_csv(df_out, TMP_DIR, FINAL_CSV)
print(f"\n✅ Wrote DataRobot CSV: {FINAL_CSV}")

# Optional: quick read-back sanity check
csv_head = spark.read.option("header", True).csv(FINAL_CSV)
peek(csv_head, "STEP 7 – Read-back of final CSV (sanity check)", n=5)



✅ Wrote DataRobot CSV: /mldev/restricted/AWSS3/ml_oi_aa/enm_los_datarobot_ready.csv

===== STEP 7 – Read-back of final CSV (sanity check): schema =====
root
 |-- ICD_root: string (nullable = true)
 |-- cpt_group: string (nullable = true)
 |-- payer_short: string (nullable = true)
 |-- t_received_date: string (nullable = true)
 |-- paid: string (nullable = true)
 |-- dp_key: string (nullable = true)
 |-- dp_desc: string (nullable = true)
 |-- solution: string (nullable = true)
 |-- fl_sub_rule_key: string (nullable = true)
 |-- t_provider_id: string (nullable = true)
 |-- t_taxgroup_id: string (nullable = true)
 |-- t_subspec_id: string (nullable = true)
 |-- t_subspec_code: string (nullable = true)
 |-- t_insurance_key: string (nullable = true)
 |-- insurance_desc: string (nullable = true)
 |-- t_pos_code: string (nullable = true)
 |-- icd_cd1: string (nullable = true)
 |-- icd_cd1_version: string (nullable = true)
 |-- icd_cd2: string (nullable = true)
 |-- icd_cd2_version: string (n

In [16]:
# ... after you've computed p_final ...

# ---- robust permission helper ----
jvm = sc._jvm
fs  = jvm.org.apache.hadoop.fs.FileSystem.get(sc._jsc.hadoopConfiguration())

def set_perm_safe(hdfs_path_str, mode="644"):
    """
    mode can be '644', '0644', or 'rw-r--r--'.
    Silently skips if the FS disallows chmod.
    """
    try:
        p = jvm.org.apache.hadoop.fs.Path(hdfs_path_str) if isinstance(hdfs_path_str, str) else hdfs_path_str
        FPerm = jvm.org.apache.hadoop.fs.permission.FsPermission  # <-- correct package
        # accept octal or symbolic
        if all(ch.isdigit() for ch in mode):
            if not mode.startswith("0"):
                mode = "0" + mode
            perm = FPerm.valueOf(mode)           # e.g., "0644"
        else:
            perm = FPerm.valueOf(mode)           # e.g., "rw-r--r--"
        fs.setPermission(p, perm)
    except Exception as e:
        print(f"[warn] setPermission skipped: {e}")

# make world-readable
set_perm_safe(p_final, "0644")
print("Single CSV written to:", p_final.toString())


NameError: name 'p_final' is not defined

In [None]:
parent_dir = str(PurePosixPath(FACT_PATH).parent)
TMP_DIR   = f"{parent_dir}/_tmp_enm_los_datarobot_csv"
FINAL_CSV = f"{parent_dir}/enm_los_datarobot_ready"

def save_single_csv(df_export, out_dir, final_csv_path):
    (df_export.coalesce(1)
              .write
              .mode("overwrite")
              .option("header", True)
              .csv(out_dir))

    fs = sc._jvm.org.apache.hadoop.fs.FileSystem.get(sc._jsc.hadoopConfiguration())
    outPath  = sc._jvm.org.apache.hadoop.fs.Path(out_dir)
    files    = fs.listStatus(outPath)

    part = None
    for f in files:
        name = f.getPath().getName()
        if name.startswith("part-") and (name.endswith(".csv") or "." not in name):
            part = f.getPath()
            break
    if part is None:
        raise RuntimeError(f"No part-* file found in {out_dir}")

    finalPath = sc._jvm.org.apache.hadoop.fs.Path(final_csv_path)
    fs.delete(finalPath, True)
    fs.rename(part, finalPath)

    # cleanup temp dir
    for f in fs.listStatus(outPath):
        fs.delete(f.getPath(), True)
    fs.delete(outPath, True)

save_single_csv(df_out, TMP_DIR, FINAL_CSV)
print(f"\n✅ Wrote DataRobot CSV: {FINAL_CSV}")

# Small tail peek (read back the header + a few rows) — optional
csv_head = spark.read.option("header", True).csv(FINAL_CSV)
peek(csv_head, "STEP 7 – Read-back of final CSV (sanity check)", n=5)