In [0]:
import dlt
from pyspark.sql import functions as F
from pyspark.sql.window import Window
import re
from functools import reduce

# --------------------------------------------------------------------
# CONFIG
# --------------------------------------------------------------------
# Directory the bronze readers load the `Chicago_*.tsv` and `Dallas_*.tsv`
# source files from.
RAW_PATH = "/Volumes/workspace/food_inspection/raw"

# --------------------------------------------------------------------
# HELPER: CLEAN COLUMN NAMES (for all bronze tables)
# --------------------------------------------------------------------
# Per-character rewrite applied to every (lower-cased) column name:
# separators become "_", punctuation is dropped (mapped to None).
_COLUMN_NAME_CHAR_MAP = str.maketrans({
    " ": "_",
    "/": "_",
    "\\": "_",
    "-": "_",
    "#": None,
    "(": None,
    ")": None,
    ".": None,
    ",": None,
})


def clean_column_names(df):
    r"""
    Return ``df`` with Delta-friendly column names.

    Each column name is:
    - lower-cased
    - space, "/", "\" and "-" are each replaced by "_"
    - "#", "(", ")", "." and "," are removed

    For example ``"License #"`` becomes ``"license_"`` and
    ``"Violation Description - 1"`` becomes ``"violation_description___1"``.

    All columns are renamed in a single ``toDF`` call rather than one
    ``withColumnRenamed`` per column, so wide inputs do not accumulate one
    projection per renamed column. Column order and count are unchanged.
    If no name needs cleaning, ``df`` itself is returned.

    Two distinct source names may clean to the same name; that is not
    detected here.
    """
    original_names = list(df.columns)
    cleaned_names = [
        name.lower().translate(_COLUMN_NAME_CHAR_MAP) for name in original_names
    ]
    if cleaned_names == original_names:
        return df
    return df.toDF(*cleaned_names)


# --------------------------------------------------------------------
# BRONZE LAYER
# --------------------------------------------------------------------

@dlt.table(
    name="bronze_chicago",
    comment="Raw Chicago food inspections from TSV files, with cleaned column names"
)
def bronze_chicago():
    """
    Bronze table for Chicago.

    Reads every file matching ``Chicago_*.tsv`` under RAW_PATH as
    tab-separated text with a header row, tags each row with
    ``city_code = "CHI"`` and normalizes the column names.
    """
    tsv_reader = spark.read.format("csv").options(header="true", sep="\t")
    raw = tsv_reader.load(f"{RAW_PATH}/Chicago_*.tsv")
    tagged = raw.withColumn("city_code", F.lit("CHI"))
    return clean_column_names(tagged)


@dlt.table(
    name="bronze_dallas",
    comment="Raw Dallas food inspections from TSV files, with cleaned column names and inspection_id"
)
def bronze_dallas():
    """
    Bronze table for Dallas.

    Reads every file matching ``Dallas_*.tsv`` under RAW_PATH as
    tab-separated text with a header row, normalizes the column names, and
    adds:

    - ``city_code``: the literal ``"DAL"``
    - ``inspection_id``: a string id derived from the row's own content

    The id is the SHA-256 of the row serialized as JSON, so the same source
    row gets the same id on every pipeline refresh.
    ``monotonically_increasing_id()`` was used here before; its values depend
    on how the input is partitioned, so they were not stable across runs,
    and being small integers they could coincide with numeric ids from
    other sources.

    NOTE: rows that are identical in every column receive the same
    inspection_id.
    """
    raw = (
        spark.read.format("csv")
        .option("header", "true")
        .option("sep", "\t")
        .load(f"{RAW_PATH}/Dallas_*.tsv")
    )

    # clean raw column names first so the hash is computed over the same
    # column names downstream tables see
    cleaned = clean_column_names(raw)

    return (
        cleaned
        .withColumn("city_code", F.lit("DAL"))
        # to_json keeps field names, so a NULL in one column cannot be
        # confused with a NULL in a neighbouring column
        .withColumn("inspection_id", F.sha2(F.to_json(F.struct("*")), 256))
    )


# --------------------------------------------------------------------
# HELPER TRANSFORMS: BRONZE -> SILVER SCHEMA
# (use the CLEANED column names)
# --------------------------------------------------------------------
def transform_chicago_to_silver(df):
    """
    Project a bronze_chicago DataFrame onto the unified silver schema.

    Columns read from ``df`` (cleaned bronze names): inspection_id,
    city_code, dba_name, aka_name, license_, address, city, state, zip,
    inspection_date, inspection_type, results, facility_type, risk,
    latitude, longitude.

    Returns a DataFrame with exactly these columns, in this order:
    inspection_id, city_code, business_name, aka_name, license_number,
    street_address, city, state, zip, inspection_date, inspection_year,
    inspection_month, inspection_type, inspection_result, inspection_score,
    facility_type, risk_category, latitude, longitude.
    """
    outcome = F.col("results")
    inspected_on = F.to_date("inspection_date")

    # Map the textual outcome to a numeric score; any other outcome
    # (or a missing one) yields a NULL score.
    score_from_outcome = (
        F.when(outcome == "Pass", F.lit(90))
         .when(outcome == "Pass w/ Conditions", F.lit(80))
         .when(outcome == "Fail", F.lit(70))
         .when(outcome == "No Entry", F.lit(0))
         .otherwise(F.lit(None).cast("int"))
    )

    return df.select(
        F.col("inspection_id"),
        F.col("city_code"),
        F.upper(F.col("dba_name")).alias("business_name"),
        F.upper(F.col("aka_name")).alias("aka_name"),
        # "License #" in the raw header becomes "license_" after cleaning
        F.col("license_").cast("string").alias("license_number"),
        F.col("address").alias("street_address"),
        F.col("city"),
        F.col("state"),
        # left-pad to 5 characters with zeros
        F.lpad(F.col("zip").cast("string"), 5, "0").alias("zip"),
        inspected_on.alias("inspection_date"),
        F.year(inspected_on).alias("inspection_year"),
        F.month(inspected_on).alias("inspection_month"),
        F.col("inspection_type"),
        outcome.alias("inspection_result"),
        score_from_outcome.alias("inspection_score"),
        F.col("facility_type"),
        F.col("risk").alias("risk_category"),
        # try_cast yields NULL instead of failing on non-numeric text
        F.expr("try_cast(latitude as double)").alias("latitude"),
        F.expr("try_cast(longitude as double)").alias("longitude"),
    )

def transform_dallas_to_silver(df):
    """
    Project a bronze_dallas DataFrame onto the unified silver schema.

    Columns read from ``df`` (cleaned bronze names): inspection_id,
    city_code, restaurant_name, street_address, zip_code, inspection_date,
    inspection_type, inspection_score, lat_long_location.

    Returns a DataFrame with exactly these columns, in this order:
    inspection_id, city_code, business_name, aka_name, license_number,
    street_address, city, state, zip, inspection_date, inspection_year,
    inspection_month, inspection_type, inspection_result, inspection_score,
    facility_type, risk_category, latitude, longitude.

    Derivations:
    - ``inspection_score`` is parsed leniently (text -> double -> int);
      anything unparseable becomes NULL rather than raising.
    - ``inspection_result`` is "Pass" for score >= 90, "Pass w/ Conditions"
      for score >= 80, "Fail" below that, and NULL when there is no usable
      score (a missing score is not reported as a failure).
    - ``zip`` is the first run of five digits in ``zip_code``, or NULL when
      there is none, so the downstream NOT NULL rule on zip is meaningful.
    - aka_name, license_number, facility_type and risk_category are not
      derived from any Dallas column and are emitted as NULL strings.
    - city and state are the constants "DALLAS" and "TX".
    """
    # Going through double first accepts values such as "95.0"; the outer
    # try_cast keeps the expression from raising on out-of-range values.
    score = F.expr("try_cast(try_cast(inspection_score as double) as int)")

    result_from_score = (
        F.when(score.isNull(), F.lit(None).cast("string"))
         .when(score >= 90, F.lit("Pass"))
         .when(score >= 80, F.lit("Pass w/ Conditions"))
         .otherwise(F.lit("Fail"))
    )

    # regexp_extract returns "" (not NULL) when nothing matches; F.when
    # without otherwise() turns that into NULL.
    zip_digits = F.regexp_extract(F.col("zip_code"), r"(\d{5})", 1)
    zip5 = F.when(zip_digits != "", zip_digits)

    inspected_on = F.to_date("inspection_date")
    null_string = F.lit(None).cast("string")

    return df.select(
        # reuse the inspection_id assigned in bronze_dallas
        F.col("inspection_id"),
        F.col("city_code"),
        F.upper(F.col("restaurant_name")).alias("business_name"),
        null_string.alias("aka_name"),
        null_string.alias("license_number"),
        F.col("street_address"),
        F.lit("DALLAS").alias("city"),
        F.lit("TX").alias("state"),
        zip5.alias("zip"),
        inspected_on.alias("inspection_date"),
        F.year(inspected_on).alias("inspection_year"),
        F.month(inspected_on).alias("inspection_month"),
        F.col("inspection_type"),
        result_from_score.alias("inspection_result"),
        score.alias("inspection_score"),
        null_string.alias("facility_type"),
        null_string.alias("risk_category"),
        # lat_long_location is split on "," into latitude (first part) and
        # longitude (second part); unparseable parts become NULL
        F.expr(
            "try_cast(get(split(lat_long_location, ','), 0) as double)"
        ).alias("latitude"),
        F.expr(
            "try_cast(get(split(lat_long_location, ','), 1) as double)"
        ).alias("longitude"),
    )




# --------------------------------------------------------------------
# SILVER LAYER (city-level)
# --------------------------------------------------------------------

@dlt.table(
    name="silver_chicago_clean",
    comment="Chicago inspections in unified silver schema"
)
@dlt.expect_all_or_drop({
    "business_name_not_null": "business_name IS NOT NULL",
    "inspection_date_not_null": "inspection_date IS NOT NULL",
    "inspection_type_not_null": "inspection_type IS NOT NULL",
    "zip_not_null": "zip IS NOT NULL",
})
def silver_chicago_clean():
    """
    Chicago inspections in the unified silver schema.

    Rows with a NULL business_name, inspection_date, inspection_type or zip
    are dropped. The rules are declared as pipeline expectations instead of
    a manual filter so the pipeline also records how many rows each rule
    rejected; the rows that reach the table are the same.
    """
    return transform_chicago_to_silver(dlt.read("bronze_chicago"))


@dlt.table(
    name="silver_dallas_clean",
    comment="Dallas inspections in unified silver schema"
)
@dlt.expect_all_or_drop({
    "business_name_not_null": "business_name IS NOT NULL",
    "inspection_date_not_null": "inspection_date IS NOT NULL",
    "inspection_type_not_null": "inspection_type IS NOT NULL",
    "zip_not_null": "zip IS NOT NULL",
})
def silver_dallas_clean():
    """
    Dallas inspections in the unified silver schema.

    Rows with a NULL business_name, inspection_date, inspection_type or zip
    are dropped. The rules are declared as pipeline expectations instead of
    a manual filter so the pipeline also records how many rows each rule
    rejected; the rows that reach the table are the same.
    """
    return transform_dallas_to_silver(dlt.read("bronze_dallas"))


# --------------------------------------------------------------------
# SILVER LAYER (unified inspections)
# --------------------------------------------------------------------

@dlt.table(
    name="silver_inspection",
    comment="Unified inspections from Chicago and Dallas"
)
def silver_inspection():
    """
    Stack the per-city silver tables into one inspections table.

    Columns are matched by name, Chicago rows first.
    """
    chicago, dallas = (
        dlt.read(table_name)
        for table_name in ("silver_chicago_clean", "silver_dallas_clean")
    )
    return chicago.unionByName(dallas)


In [0]:
def _trailing_index(column_name):
    """Return the integer formed by the digits ending ``column_name``, or None."""
    match = re.search(r"(\d+)$", column_name)
    return int(match.group(1)) if match else None


def _columns_by_trailing_index(columns, prefix):
    """Map trailing index -> first column in ``columns`` starting with ``prefix``."""
    by_index = {}
    for name in columns:
        if name.startswith(prefix):
            by_index.setdefault(_trailing_index(name), name)
    return by_index


@dlt.table(
    comment="Normalized violations from Chicago and Dallas at inspection level"
)
def silver_violation():
    """
    One row per violation per inspection, for both cities.

    Output columns: inspection_id, city_code, violation_seq, violation_code,
    violation_description, violation_points, violation_detail,
    violation_memo.

    Chicago: the single ``violations`` text column is split on "|" into
    entries. ``violation_code`` is the leading digits of an entry ("" when
    there are none) and ``violation_description`` is the entry with a
    leading "NN. " removed. Sequence, points, detail and memo are NULL.

    Dallas: violations arrive as repeated column groups
    (violation_description*, violation_points*, violation_detail*,
    violation_memo*) distinguished by a trailing number. Each description
    column is paired with the points/detail/memo column carrying exactly
    the same trailing number, and that number becomes ``violation_seq``.
    ``violation_code`` is NULL.

    Pairing is done on the parsed number rather than with
    ``str.endswith``: a suffix of "1" must not match "11" or "21", and a
    description column without a numeric suffix is handled (NULL
    violation_seq) instead of raising.
    """
    output_columns = [
        "inspection_id",
        "city_code",
        "violation_seq",
        "violation_code",
        "violation_description",
        "violation_points",
        "violation_detail",
        "violation_memo",
    ]

    # ----------------- CHICAGO -----------------
    chi = (
        dlt.read("bronze_chicago")
        .select("inspection_id", "city_code", "violations")
        .where(F.col("violations").isNotNull())
    )

    chi_exploded = (
        chi
        # split on "|" between violations
        .withColumn("violation_entry", F.explode(F.split(F.col("violations"), r"\|\s*")))
        .withColumn("violation_entry", F.trim("violation_entry"))
        .where(F.col("violation_entry") != "")
        # code like "10." at the start
        .withColumn("violation_code", F.regexp_extract("violation_entry", r"^(\d+)", 1))
        # description after the "NN. "
        .withColumn(
            "violation_description",
            F.regexp_replace("violation_entry", r"^\s*\d+\.\s*", "")
        )
        .drop("violations")
        .dropDuplicates(["inspection_id", "violation_code", "violation_description"])
        .withColumn("violation_seq", F.lit(None).cast("int"))
        .withColumn("violation_points", F.lit(None).cast("int"))
        .withColumn("violation_detail", F.lit(None).cast("string"))
        .withColumn("violation_memo", F.lit(None).cast("string"))
    )

    # ----------------- DALLAS -----------------
    dal_bronze = dlt.read("bronze_dallas")
    dal_columns = dal_bronze.columns

    # dynamically find Dallas violation columns after cleaning
    desc_cols = [c for c in dal_columns if c.startswith("violation_description")]
    points_by_idx = _columns_by_trailing_index(dal_columns, "violation_points")
    detail_by_idx = _columns_by_trailing_index(dal_columns, "violation_detail")
    memo_by_idx = _columns_by_trailing_index(dal_columns, "violation_memo")

    def optional_column(column_name, dtype):
        # A violation slot may lack one of its companion columns; emit a
        # typed NULL in that case so every per-slot frame has one schema.
        source = F.col(column_name) if column_name else F.lit(None)
        return source.cast(dtype)

    dallas_violation_dfs = []
    for desc_col in desc_cols:
        idx = _trailing_index(desc_col)
        dallas_violation_dfs.append(
            dal_bronze
            .where(F.col(desc_col).isNotNull())
            .select(
                "inspection_id",
                "city_code",
                F.lit(idx).cast("int").alias("violation_seq"),
                F.col(desc_col).alias("violation_description"),
                optional_column(points_by_idx.get(idx), "int").alias("violation_points"),
                optional_column(detail_by_idx.get(idx), "string").alias("violation_detail"),
                optional_column(memo_by_idx.get(idx), "string").alias("violation_memo"),
            )
        )

    if dallas_violation_dfs:
        dal_union = reduce(lambda a, b: a.unionByName(b), dallas_violation_dfs)
    else:
        # No violation columns at all: produce an empty frame with the
        # expected schema so the union below still works.
        dal_union = dal_bronze.selectExpr(
            "inspection_id", "city_code",
            "cast(null as int) as violation_seq",
            "cast(null as string) as violation_description",
            "cast(null as int) as violation_points",
            "cast(null as string) as violation_detail",
            "cast(null as string) as violation_memo",
        ).limit(0)

    dal_final = (
        dal_union
        .withColumn("violation_code", F.lit(None).cast("string"))
        .dropDuplicates(["inspection_id", "violation_seq", "violation_description"])
    )

    # ----------------- UNION BOTH CITIES -----------------
    return (
        chi_exploded.select(*output_columns)
        .unionByName(dal_final.select(*output_columns))
    )


In [0]:
@dlt.table(
    comment="Date dimension for inspections"
)
def dim_date():
    """
    One row per distinct non-NULL inspection_date in silver_inspection.

    Columns: inspection_date, date_key (the date as a yyyyMMdd integer),
    year, month, day (day of month) and weekday (formatted with "EEEE").
    """
    distinct_dates = (
        dlt.read("silver_inspection")
        .select("inspection_date")
        .where("inspection_date IS NOT NULL")
        .distinct()
    )
    inspected_on = F.col("inspection_date")
    return distinct_dates.select(
        inspected_on,
        F.date_format(inspected_on, "yyyyMMdd").cast("int").alias("date_key"),
        F.year(inspected_on).alias("year"),
        F.month(inspected_on).alias("month"),
        F.dayofmonth(inspected_on).alias("day"),
        F.date_format(inspected_on, "EEEE").alias("weekday"),
    )


In [0]:
@dlt.table(
    comment="Location dimension"
)
def dim_location():
    """
    One row per distinct (zip, city, state, latitude, longitude) found in
    silver_inspection, restricted to rows with a non-NULL zip.

    ``location_key`` is assigned with ``monotonically_increasing_id()``.
    """
    attributes = ["zip", "city", "state", "latitude", "longitude"]

    locations = (
        dlt.read("silver_inspection")
        .select(*attributes)
        .where("zip IS NOT NULL")
        .distinct()
    )
    keyed = locations.withColumn("location_key", F.monotonically_increasing_id())

    # surrogate key first, then the descriptive attributes
    return keyed.select("location_key", *attributes)


In [0]:
@dlt.table(
    comment="Restaurant dimension (snapshot)"
)
def dim_restaurant_history():
    """
    One row per distinct combination of business_name, aka_name,
    license_number, street_address, city, state and zip found in
    silver_inspection, restricted to rows with a non-NULL business_name.

    ``restaurant_key`` is assigned with ``monotonically_increasing_id()``.
    """
    attributes = [
        "business_name",
        "aka_name",
        "license_number",
        "street_address",
        "city",
        "state",
        "zip",
    ]

    restaurants = (
        dlt.read("silver_inspection")
        .select(*attributes)
        .where("business_name IS NOT NULL")
        .distinct()
    )
    keyed = restaurants.withColumn("restaurant_key", F.monotonically_increasing_id())

    # surrogate key first, then the descriptive attributes
    return keyed.select("restaurant_key", *attributes)


In [0]:
@dlt.table(
    comment="Violation dimension (unique descriptions)"
)
def dim_violation():
    """
    One row per distinct (violation_code, violation_description) found in
    silver_violation, restricted to rows with a non-NULL description.

    ``violation_key`` is assigned with ``monotonically_increasing_id()``.
    """
    attributes = ["violation_code", "violation_description"]

    violations = (
        dlt.read("silver_violation")
        .select(*attributes)
        .where("violation_description IS NOT NULL")
        .distinct()
    )
    keyed = violations.withColumn("violation_key", F.monotonically_increasing_id())

    # surrogate key first, then the descriptive attributes
    return keyed.select("violation_key", *attributes)


In [0]:
@dlt.table(
    comment="Fact table for inspections"
)
def fact_inspection():
    """
    One row per silver_inspection row, carrying the surrogate keys of the
    date, location and restaurant dimensions.

    Output columns: inspection_id, city_code, date_key, location_key,
    restaurant_key, inspection_type, inspection_result, inspection_score,
    risk_category, facility_type.

    Dimension lookups match on every column that defines a dimension row
    and compare with NULL-safe equality:

    - dim_location rows are distinct over (zip, city, state, latitude,
      longitude); latitude/longitude can be NULL, and a plain equality
      join would leave those inspections without a location_key.
    - dim_restaurant_history rows are distinct over (business_name,
      aka_name, license_number, street_address, city, state, zip).
      aka_name and license_number are NULL for every Dallas row, so plain
      equality would leave all of them without a restaurant_key; and
      joining on only part of that column set could match several
      dimension rows and duplicate the inspection.
    """
    def attach_surrogate_key(facts, dimension, natural_key, surrogate_key):
        # Rename the dimension's natural-key columns so the join condition
        # and the result have no ambiguous column names.
        lookup = dimension.select(
            *[F.col(name).alias(f"_dim_{name}") for name in natural_key],
            surrogate_key,
        )
        condition = None
        for name in natural_key:
            # eqNullSafe treats NULL as equal to NULL
            clause = F.col(name).eqNullSafe(F.col(f"_dim_{name}"))
            condition = clause if condition is None else condition & clause
        joined = facts.join(lookup, condition, "left")
        return joined.drop(*[f"_dim_{name}" for name in natural_key])

    inspections = dlt.read("silver_inspection")
    date_keys = dlt.read("dim_date").select("inspection_date", "date_key")

    fact = inspections.join(date_keys, "inspection_date", "left")
    fact = attach_surrogate_key(
        fact,
        dlt.read("dim_location"),
        ["zip", "city", "state", "latitude", "longitude"],
        "location_key",
    )
    fact = attach_surrogate_key(
        fact,
        dlt.read("dim_restaurant_history"),
        [
            "business_name",
            "aka_name",
            "license_number",
            "street_address",
            "city",
            "state",
            "zip",
        ],
        "restaurant_key",
    )

    return fact.select(
        "inspection_id",
        "city_code",
        "date_key",
        "location_key",
        "restaurant_key",
        "inspection_type",
        "inspection_result",
        "inspection_score",
        "risk_category",
        "facility_type",
    )


In [0]:
@dlt.table(
    comment="Fact table linking inspections and violations"
)
def fact_inspection_violation():
    """
    One row per violation per inspection, carrying the violation and date
    surrogate keys.

    Output columns: inspection_id, city_code, violation_key, date_key,
    violation_seq, violation_code, violation_points, violation_description.

    Join details:
    - The inspection date is looked up by (inspection_id, city_code);
      inspection_id on its own is not guaranteed unique across cities.
      At most one date is kept per inspection so the lookup cannot
      multiply violation rows.
    - violation_key is looked up by (violation_code,
      violation_description) with NULL-safe equality. Dallas violations
      have a NULL violation_code, and a plain equality join would never
      match them.
    - Violations whose inspection is absent from silver_inspection keep a
      NULL date_key.
    """
    violations = dlt.read("silver_violation")

    # Get date for each inspection
    inspection_dates = (
        dlt.read("silver_inspection")
        .select("inspection_id", "city_code", "inspection_date")
        .dropDuplicates(["inspection_id", "city_code"])
    )

    # Dimension columns are renamed so the NULL-safe join condition and
    # the joined result have no ambiguous column names.
    violation_keys = dlt.read("dim_violation").select(
        F.col("violation_code").alias("_dim_violation_code"),
        F.col("violation_description").alias("_dim_violation_description"),
        "violation_key",
    )
    same_violation = (
        F.col("violation_code").eqNullSafe(F.col("_dim_violation_code"))
        & F.col("violation_description").eqNullSafe(F.col("_dim_violation_description"))
    )

    date_keys = dlt.read("dim_date").select("inspection_date", "date_key")

    fact = (
        violations
        .join(inspection_dates, ["inspection_id", "city_code"], "left")
        # join to violation dim (for violation_key)
        .join(violation_keys, same_violation, "left")
        # join to date dim (for date_key)
        .join(date_keys, "inspection_date", "left")
    )

    # Ensure exactly one row per violation per inspection
    fact = fact.dropDuplicates(
        ["inspection_id", "city_code", "violation_seq", "violation_description"]
    )

    return fact.select(
        "inspection_id",
        "city_code",
        "violation_key",
        "date_key",
        "violation_seq",
        "violation_code",
        "violation_points",
        "violation_description",
    )
