In [0]:
# pipeline.py  -- Delta Live Tables (DLT) pipeline for healthcare claims fraud detection
#
# Notes:
# - Bronze = raw ingestion
# - Silver = deduplication, validation, enrichment
# - Gold = business-ready fraud scoring and alerts
# - Uses DLT expectations (dlt.expect, which records violations without dropping rows) as quality gates

import dlt
from pyspark.sql import functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, DoubleType, TimestampType, BooleanType, DateType
)

# -------------------------
# Configuration
# -------------------------
# Root folder that holds every input file read by this pipeline.
SOURCE_BASE = "/Volumes/workspace/default/vol_in"

# Input files, each resolved under SOURCE_BASE.
CLAIMS_BATCH = SOURCE_BASE + "/claims_batch.csv"      # claims delivered as CSV
CLAIMS_STREAM = SOURCE_BASE + "/claims_stream.json"   # claims delivered as JSON
MEMBERS_FILE = SOURCE_BASE + "/members.csv"           # members (CSV)
PROVIDERS_FILE = SOURCE_BASE + "/providers.json"      # providers (JSON)
DIAG_REF_FILE = SOURCE_BASE + "/diagnosis_ref.csv"    # diagnosis code reference (CSV)

# Business thresholds
HIGH_AMOUNT_THRESHOLD = 1000.0              # an Amount strictly above this counts as "high amount"
FALLBACK_HIGH_RISK_DIAG = ["D123", "X999"]  # diagnosis codes treated as high risk by code alone

# -------------------------
# Helpers
# -------------------------
def _exists(path: str) -> bool:
    """Return True when *path* can be listed and the listing is non-empty.

    Relies on ``dbutils.fs.ls`` (presumably the Databricks notebook global; it
    is not imported in this file). Any failure while listing, including a
    missing path or ``dbutils`` being unavailable, is reported as False
    instead of being raised.
    """
    try:
        listing = dbutils.fs.ls(path)
        has_entries = len(listing) != 0
    except Exception:  # best-effort probe: every listing error means "absent"
        has_entries = False
    return has_entries

def _empty_claims_df():
    """Return a zero-row claims DataFrame with the claims column layout.

    Meant as a stand-in when the claims file is missing. All columns are
    nullable; ``ingest_ts`` is included so the frame lines up with the bronze
    claims tables. Nothing in the visible pipeline calls it yet.
    """
    claim_columns = [
        ("ClaimID", StringType()),
        ("MemberID", StringType()),
        ("ProviderID", StringType()),
        ("Amount", DoubleType()),
        ("ICD10Codes", StringType()),
        ("ingest_ts", TimestampType()),
    ]
    schema = StructType([StructField(name, dtype, True) for name, dtype in claim_columns])
    return spark.createDataFrame([], schema=schema)

def _empty_members_df():
    """Return a zero-row members DataFrame (MemberID, Name, DOB), all nullable.

    Fallback stand-in for a missing members file; unlike the bronze members
    table it carries no ``ingest_ts`` column.
    """
    member_schema = (
        StructType()
        .add("MemberID", StringType(), True)
        .add("Name", StringType(), True)
        .add("DOB", DateType(), True)
    )
    return spark.createDataFrame([], schema=member_schema)

def _empty_providers_df():
    """Return a zero-row providers DataFrame (ProviderID, Name, Locations, IsActive).

    Fallback stand-in for a missing providers file; all columns nullable.

    NOTE(review): ``Locations`` is typed as a plain string here. Code that reads
    ``City``/``State`` from individual locations needs an array of structs, so
    this fallback schema may not line up with real providers.json data.
    Confirm before wiring it in.
    """
    provider_schema = (
        StructType()
        .add("ProviderID", StringType(), True)
        .add("Name", StringType(), True)
        .add("Locations", StringType(), True)
        .add("IsActive", BooleanType(), True)
    )
    return spark.createDataFrame([], schema=provider_schema)

def _empty_diag_ref_df():
    """Return a zero-row diagnosis reference DataFrame (Code, Description), all nullable."""
    diag_schema = (
        StructType()
        .add("Code", StringType(), True)
        .add("Description", StringType(), True)
    )
    return spark.createDataFrame([], schema=diag_schema)

# ------------------------------------------------
# BRONZE: Raw ingestion
# ------------------------------------------------

@dlt.table(
    name="bronze_claims_raw",
    comment="Raw claims ingested from claims_batch.csv",
)
def bronze_claims_raw():
    """Read claims_batch.csv as-is (header row, inferred types) and add ``ingest_ts``.

    ``ingest_ts`` is ``current_timestamp()``. Spark evaluates it once per
    query, so every row of one refresh carries the same value.
    """
    csv_reader = spark.read.options(header="true", inferSchema="true")
    raw_claims = csv_reader.csv(CLAIMS_BATCH)
    return raw_claims.withColumn("ingest_ts", F.current_timestamp())

@dlt.table(
    name="bronze_claims_stream",
    comment="Streaming claims ingested from claims_stream.json"
)
def bronze_claims_stream():
    """Load claims_stream.json and stamp every row with ``ingest_ts``.

    NOTE: despite the table name, this is a batch read (``spark.read``), not
    a Structured Streaming source.

    The JSON reader always infers its schema. The CSV-only ``header`` and
    ``inferSchema`` options that used to be passed here were silently ignored,
    so they were dropped. This matches how ``bronze_providers_raw`` reads JSON.
    """
    df = spark.read.json(CLAIMS_STREAM)
    return df.withColumn("ingest_ts", F.current_timestamp())

@dlt.table(
    name="bronze_members_raw",
    comment="Raw members ingested from members.csv",
)
def bronze_members_raw():
    """Read members.csv as-is (header row, inferred types) and add ``ingest_ts``."""
    members = spark.read.csv(MEMBERS_FILE, header=True, inferSchema=True)
    return members.withColumn("ingest_ts", F.current_timestamp())


@dlt.table(
    name="bronze_providers_raw",
    comment="Raw providers ingested from providers.json",
)
def bronze_providers_raw():
    """Read providers.json with an inferred schema and add ``ingest_ts``."""
    providers = spark.read.format("json").load(PROVIDERS_FILE)
    stamped = providers.withColumn("ingest_ts", F.current_timestamp())
    return stamped


@dlt.table(
    name="bronze_diag_ref_raw",
    comment="Diagnosis reference ingested from diagnosis_ref.csv",
)
def bronze_diag_ref_raw():
    """Read diagnosis_ref.csv as-is (header row, inferred types) and add ``ingest_ts``."""
    diag_ref = (
        spark.read.format("csv")
        .option("header", "true")
        .option("inferSchema", "true")
        .load(DIAG_REF_FILE)
    )
    return diag_ref.withColumn("ingest_ts", F.current_timestamp())


# ------------------------------------------------
# SILVER: Deduplication, Validation, Enrichment
# ------------------------------------------------

@dlt.table
@dlt.expect("silver_claim_amount_present", "Amount IS NOT NULL")
def silver_claims_dedup():
    """
    Combine batch + stream claims, deduplicate by ClaimID (keep latest ingest_ts).
    Produces canonical set of claims for enrichment.

    Rows whose ClaimID is NULL are passed through unchanged instead of being
    deduplicated. Without this, the NULL window partition would collapse every
    such row into a single survivor, and the rest would never reach the
    ``missing_claim_id`` audit in ``silver_invalid_claims``.

    The ``Amount IS NOT NULL`` expectation uses ``dlt.expect``, so violating
    rows are kept and only counted in pipeline quality metrics.
    """
    from pyspark.sql.window import Window

    batch = dlt.read("bronze_claims_raw")
    stream = dlt.read("bronze_claims_stream")   # defined above in this pipeline
    # The CSV and JSON sources may not share an identical column set; missing
    # columns are filled with NULL.
    combined = batch.unionByName(stream, allowMissingColumns=True)

    # Defensive: both bronze claim tables currently add ingest_ts themselves.
    if "ingest_ts" not in combined.columns:
        combined = combined.withColumn("ingest_ts", F.current_timestamp())

    # NOTE(review): current_timestamp() is fixed per query, so duplicates that
    # come from the same bronze table tie on ingest_ts and row_number() picks
    # one arbitrarily. Add a deterministic tiebreaker if the choice matters.
    w = Window.partitionBy("ClaimID").orderBy(F.col("ingest_ts").desc_nulls_last())
    return (combined
            .withColumn("rn", F.row_number().over(w))
            .filter(F.col("ClaimID").isNull() | (F.col("rn") == 1))
            .drop("rn"))


def _normalize_id(col):
    """Canonicalize an ID column: remove non-alphanumerics, trim, upper-case.

    For example ``" m-001 "`` becomes ``"M001"``; NULL stays NULL.
    """
    return F.upper(F.trim(F.regexp_replace(col, "[^A-Za-z0-9]", "")))


def _provider_lookup(providers):
    """Project providers to the join columns, one output row per provider row.

    ``Locations`` is deliberately not exploded. Claims carry no location, so
    exploding would duplicate every matching claim once per provider location.
    When ``Locations`` is an array, its first element supplies
    ``provider_city``/``provider_state``; otherwise those columns are omitted.
    Expects ``norm_provider_id`` to be present on *providers*.
    """
    columns = [
        F.col("ProviderID"),
        F.col("norm_provider_id"),
        F.col("Name").alias("provider_name"),
    ]
    if dict(providers.dtypes).get("Locations", "").startswith("array"):
        # Guarded index: an empty or NULL array yields NULL instead of raising
        # an out-of-bounds error under ANSI mode.
        first_location = F.when(F.size("Locations") > 0, F.col("Locations").getItem(0))
        columns += [
            first_location.getField("City").alias("provider_city"),
            first_location.getField("State").alias("provider_state"),
        ]
    return providers.select(*columns)


@dlt.table
def silver_claims_enriched_v1():
    """
    Enrich deduped claims with members, providers, and diagnosis reference.
    Always ensure diagnosis_high_risk exists.

    Steps:
      * Diagnosis_Code: the first token of ICD10Codes, splitting on ``,``,
        ``;`` or whitespace. It is NULL when ICD10Codes is NULL or has no token.
      * Members and providers are left-joined on normalized IDs (see
        ``_normalize_id``), so case, dashes or stray spaces do not cause false
        "not found" results.
      * member_exists / provider_exists record whether each join matched;
        is_valid is their conjunction.
      * diagnosis_description comes from the reference table when it has a
        ``Code`` column; otherwise it is a typed NULL string.
      * diagnosis_high_risk flags Diagnosis_Code values in FALLBACK_HIGH_RISK_DIAG.

    NOTE(review): the output has one row per claim only if member, provider and
    diagnosis-reference keys are unique (after normalization). Duplicates there
    still fan out, so confirm upstream uniqueness.
    """
    claims = dlt.read("silver_claims_dedup")
    members = dlt.read("bronze_members_raw")
    providers = dlt.read("bronze_providers_raw")
    diag_ref = dlt.read("bronze_diag_ref_raw")

    # Rename each source's ingest_ts so the joined frame has no ambiguous column.
    claims = claims.withColumnRenamed("ingest_ts", "claims_ingest_ts")
    members = members.withColumnRenamed("ingest_ts", "members_ingest_ts")
    providers = providers.withColumnRenamed("ingest_ts", "providers_ingest_ts")

    # First non-separator token. Unlike split(...)[0], a leading separator
    # does not produce an empty code.
    first_code = F.regexp_extract(F.col("ICD10Codes"), "[^,;\\s]+", 0)
    claims = (claims
              .withColumn("Diagnosis_Code", F.when(first_code != "", first_code))
              .withColumn("norm_member_id", _normalize_id(F.col("MemberID")))
              .withColumn("norm_provider_id", _normalize_id(F.col("ProviderID"))))
    members = members.withColumn("norm_member_id", _normalize_id(F.col("MemberID")))
    providers_norm = _provider_lookup(
        providers.withColumn("norm_provider_id", _normalize_id(F.col("ProviderID")))
    )

    joined = (claims.alias("c")
              .join(members.alias("m"),
                    F.col("c.norm_member_id") == F.col("m.norm_member_id"), how="left")
              .join(providers_norm.alias("p"),
                    F.col("c.norm_provider_id") == F.col("p.norm_provider_id"), how="left")
              .withColumn("member_exists", F.col("m.MemberID").isNotNull())
              .withColumn("provider_exists", F.col("p.ProviderID").isNotNull()))

    # --- Diagnosis reference handling ---
    if "Code" in diag_ref.columns:
        dx_sel = diag_ref.select(
            F.col("Code").alias("dx_code"),
            F.col("Description").alias("diagnosis_description"),
        )
        joined = joined.join(dx_sel, F.col("c.Diagnosis_Code") == F.col("dx_code"), how="left")
    else:
        # Typed NULLs: an untyped lit(None) is NullType, which Delta may drop
        # or reject when the table is written.
        joined = (joined
                  .withColumn("dx_code", F.lit(None).cast("string"))
                  .withColumn("diagnosis_description", F.lit(None).cast("string")))

    # --- High-risk flag ---
    joined = joined.withColumn(
        "diagnosis_high_risk",
        F.col("Diagnosis_Code").isin(FALLBACK_HIGH_RISK_DIAG),
    )

    # --- Final select: only the claims ingest_ts is kept ---
    select_cols = [
        F.col("c.ClaimID").alias("ClaimID"),
        F.col("c.MemberID").alias("MemberID"),
        F.col("c.ProviderID").alias("ProviderID"),
        F.col("c.Amount").cast("double").alias("Amount"),
        F.col("Diagnosis_Code"),
        F.col("diagnosis_description"),
        F.col("claims_ingest_ts").alias("ingest_ts"),
        F.col("member_exists"),
        F.col("provider_exists"),
        F.col("diagnosis_high_risk"),
    ]
    # Provider attributes appear only when the provider source supplied them.
    select_cols += [
        F.col(name)
        for name in ("provider_name", "provider_city", "provider_state")
        if name in providers_norm.columns
    ]

    enriched = joined.select(*select_cols)
    return enriched.withColumn("is_valid", F.col("member_exists") & F.col("provider_exists"))


@dlt.table
def silver_invalid_claims():
    """
    Collect enriched claims that fail validation, tagged with an ``error_reason``.

    A row is kept when is_valid is false or ClaimID / MemberID is NULL. The
    reason reports the first failing check in this order: missing_claim_id,
    missing_member_id, fk_not_found (member or provider not matched).
    """
    rows = dlt.read("silver_claims_enriched_v1")

    missing_claim = F.col("ClaimID").isNull()
    missing_member = F.col("MemberID").isNull()
    fk_broken = ~F.col("is_valid")

    reason = (F.when(missing_claim, "missing_claim_id")
               .when(missing_member, "missing_member_id")
               .when(fk_broken, "fk_not_found")
               .otherwise("other"))

    return rows.filter(fk_broken | missing_claim | missing_member).withColumn("error_reason", reason)


# ------------------------------------------------
# GOLD: Fraud Scoring + Alerts
# ------------------------------------------------

@dlt.table(
    name="gold_claims_fraud_scores",
    comment="Business-facing Gold table for fraud scoring",
    table_properties={
        "quality": "gold",
        # NOTE(review): risk_bucket has only three values, so Z-ordering on it
        # gives little data skipping; a higher-cardinality column may serve better.
        "pipelines.autoOptimize.zOrderCols": "risk_bucket"
    }
)
@dlt.expect("valid_fraud_score", "fraud_score >= 0")
def gold_claims_fraud_scores():
    """
    Multi-signal fraud scoring:
      - Amount    : +0.6 when Amount > HIGH_AMOUNT_THRESHOLD
      - Diagnosis : +0.3 when the diagnosis is high risk
      - Provider  : +0.4 when the provider is flagged (IsActive is false)

    risk_bucket is "high" for fraud_score >= 0.7, "medium" for >= 0.3,
    and "low" otherwise.

    Diagnosis risk is decided per row. The description regex
    (CANCER|MALIGNANT|CRITICAL) is used when a description is present. When the
    description is NULL (no reference match, or no description column), the
    FALLBACK_HIGH_RISK_DIAG code list applies instead. Previously the fallback
    was chosen per table rather than per row and was unreachable, because the
    enriched table always has a diagnosis_description column.
    """
    df = dlt.read("silver_claims_enriched_v1").withColumn("Amount", F.col("Amount").cast("double"))

    # Provider flagged
    # NOTE(review): silver_claims_enriched_v1 as written does not output IsActive,
    # so the else-branch is taken and the provider signal currently never fires.
    if "IsActive" in df.columns:
        df = df.withColumn("provider_flagged",
                           F.when(F.col("IsActive") == False, True).otherwise(False))  # noqa: E712 (Column comparison)
    else:
        df = df.withColumn("provider_flagged", F.lit(False))

    # Diagnosis high risk: description regex first, code list as per-row fallback.
    by_code = F.col("Diagnosis_Code").isin(FALLBACK_HIGH_RISK_DIAG)
    if "diagnosis_description" in df.columns:
        by_description = F.upper(F.col("diagnosis_description")).rlike("CANCER|MALIGNANT|CRITICAL")
        df = df.withColumn("diagnosis_high_risk", F.coalesce(by_description, by_code))
    else:
        df = df.withColumn("diagnosis_high_risk", by_code)

    # Score computation (NULL flags fall through to 0.0 via otherwise()).
    scored = (df
              .withColumn("score_amount", F.when(F.col("Amount") > HIGH_AMOUNT_THRESHOLD, 0.6).otherwise(0.0))
              .withColumn("score_diag", F.when(F.col("diagnosis_high_risk"), 0.3).otherwise(0.0))
              .withColumn("score_provider", F.when(F.col("provider_flagged"), 0.4).otherwise(0.0))
              .withColumn("fraud_score", F.col("score_amount") + F.col("score_diag") + F.col("score_provider"))
              .withColumn("risk_bucket",
                          F.when(F.col("fraud_score") >= 0.7, "high")
                           .when(F.col("fraud_score") >= 0.3, "medium")
                           .otherwise("low"))
             )

    return scored.select("ClaimID", "MemberID", "ProviderID", "Amount", "Diagnosis_Code",
                         "diagnosis_description", "fraud_score", "risk_bucket", "ingest_ts")


@dlt.table(
    name="gold_fraud_alerts",
    comment="High-risk claims alerts with enriched diagnosis details"
)
def gold_fraud_alerts():
    """
    Alerts for high-risk claims.
    Downstream systems can read this frequently for investigations.

    Keeps only rows of gold_claims_fraud_scores whose risk_bucket is "high".
    diagnosis_description is published next to Diagnosis_Code so the table
    carries the "enriched diagnosis details" its comment advertises. Previously
    only the raw code was exposed.
    """
    alert_columns = [
        "ClaimID",
        "MemberID",
        "ProviderID",
        "Amount",
        "Diagnosis_Code",
        "diagnosis_description",
        "fraud_score",
        "risk_bucket",
        "ingest_ts",
    ]
    return (dlt.read("gold_claims_fraud_scores")
            .filter(F.col("risk_bucket") == "high")
            .select(*alert_columns))