In [4]:
# Table to clean. It is passed to DataCleaningSpark in the run cell, which reads it from the
# source database and uses it to pick the matching rows of the metadata table.
table_name = 'promotions'

StatementMeta(, b27a3f6d-e08b-438c-8dcf-579032102320, 3, Finished, Available, Finished)

In [2]:
from pathlib import Path
import json, logging
from typing import Optional, Dict, Any
from pyspark.sql import SparkSession, DataFrame, Window
from pyspark.sql import functions as F
from pyspark.sql.types import *
from pyspark.sql.functions import col,to_date, lower, trim, regexp_replace, explode, split, current_date, lit, udf,round as spark_round
import re
from datetime import datetime

StatementMeta(, b27a3f6d-e08b-438c-8dcf-579032102320, 4, Finished, Available, Finished)

In [6]:

def parse_int_safe(val):
    """Convert ``val`` to an ``int``, returning ``None`` when that is not possible.

    Accepts anything ``int()`` or ``float()`` can read: integer strings, decimal strings
    (truncated toward zero), numbers. ``None``, unparsable text, NaN and infinities all
    yield ``None``.
    """
    if val is None:
        return None

    try:
        # Direct conversion first: a detour through float() silently rounds
        # integers above 2**53 (e.g. "9007199254740993" -> ...992).
        return int(val)
    except (TypeError, ValueError, OverflowError):
        pass

    try:
        # Fallback for decimal / scientific notation such as "12.7" or "1e3".
        return int(float(val))
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures are swallowed; KeyboardInterrupt etc. propagate.
        return None

def parse_float_safe(val):
    """Convert ``val`` to a ``float``, returning ``None`` when that is not possible.

    ``None`` and values ``float()`` rejects (unparsable text, unsupported types, integers
    too large for a double) yield ``None``.
    """
    if val is None:
        return None
    try:
        return float(val)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures are swallowed; KeyboardInterrupt etc. propagate.
        return None

# Spark UDF wrappers around the two safe parsers above; the second argument is the
# Spark type each UDF is declared to return.
int_udf = udf(parse_int_safe, LongType())
float_udf = udf(parse_float_safe, DoubleType())


class DataCleaningSpark:

    def __init__(self, spark: SparkSession, source_db: str, table_name: str, meta_data_table: str, target_db: str = "lh_transformation"):
        """Load the source table together with the metadata rows that describe it.

        Args:
            spark: session used for every query issued by this object.
            source_db: database the raw table is read from.
            table_name: table to clean; also the value matched against the metadata
                ``source_table`` column.
            meta_data_table: metadata table name, read from the ``main`` database.
            target_db: stored on the instance as ``target_db``.

        Raises:
            ValueError: when the metadata table has no row for ``table_name``.
        """
        self.spark = spark
        self.table_name = table_name
        self.target_db = target_db
        self.normalized_tables = {}

        self.df = spark.sql(f"SELECT * FROM {source_db}.{table_name}")

        all_meta = spark.sql(f"SELECT * FROM main.{meta_data_table}")
        table_meta = all_meta.filter(F.col("source_table") == table_name)
        # Metadata column names are lower-cased and whitespace runs become underscores.
        normalized_name = F.lower(F.regexp_replace(F.col("column"), r"\s+", "_"))
        self.meta = table_meta.withColumn("column", normalized_name)

        has_metadata = self.meta.count() != 0
        if not has_metadata:
            raise ValueError(f"No metadata for table '{table_name}' in {meta_data_table}")

    def _get_cols(self, flag: str, value: int = 1):
        """Return the metadata ``column`` values of the rows whose ``flag`` column equals ``value``."""
        flagged = self.meta.filter(F.col(flag) == value).select("column")
        # One selected column per row, so the first field of each Row is the column name.
        return [row[0] for row in flagged.collect()]



    def _clean_promotions(self):
        """Trim and de-duplicate promotions of the same game whose date ranges collide.

        Steps, in order:
          1. Self-join to find pairs of distinct promotions of one game whose date ranges
             intersect, with ``a`` starting strictly before ``b``.
          2. Per pair, the promo with the lower ``percentage`` is trimmed: if ``a`` is at
             least as high, ``b`` starts the day after ``a`` ends; otherwise ``a`` ends
             the day before ``b`` starts.
          3. Promos whose start ends up after their end are dropped.
          4. Among promos of a game with identical (start, end), only the highest
             ``percentage`` is kept (ties broken by lowest ``promotion_id``).
          5. Among promos of a game sharing a start date, all but the top-ranked one are
             moved to start the day after the top-ranked one ends; steps 3 and 4 are
             then applied again.

        Replaces ``self.df`` with the result and returns ``self``.
        """
        promos = self.df
        a = promos.alias("a")
        b = promos.alias("b")

        # All (a, b) pairs of one game that overlap in time, with `a` the earlier starter.
        pairs = (
            a.join(
                b,
                (F.col("a.game_id") == F.col("b.game_id")) &
                (F.col("a.promotion_id") != F.col("b.promotion_id")) &
                (F.col("a.start_date") <= F.col("b.end_date")) &      # date ranges intersect
                (F.col("a.end_date")   >= F.col("b.start_date")) &
                (F.col("a.start_date") <  F.col("b.start_date"))      # a is the older promo
            )
        )

        # Older promo has the higher (or equal) percentage: the later promo `b` must start after `a` ends.
        trim_later = pairs.where(F.col("a.percentage") >= F.col("b.percentage")) \
            .select(
                F.col("b.promotion_id").alias("promotion_id"),
                (F.col("a.end_date") + F.expr("INTERVAL 1 DAY")).alias("new_start")
            )

        # Later promo has the higher percentage: the older promo `a` must end before `b` starts.
        trim_older = pairs.where(F.col("a.percentage") < F.col("b.percentage")) \
            .select(
                F.col("a.promotion_id").alias("promotion_id"),
                (F.col("b.start_date") - F.expr("INTERVAL 1 DAY")).alias("new_end")
            )

        # A promo can appear in several pairs; keep the most restrictive bound for each.
        new_starts = (trim_later
                    .groupBy("promotion_id")
                    .agg(F.max("new_start").alias("new_start")))     # later promo moves *forward*

        new_ends = (trim_older
                    .groupBy("promotion_id")
                    .agg(F.min("new_end").alias("new_end")))         # older promo moves *backward*

        # Apply the new bounds; promos without a computed bound keep their original dates.
        trimmed = (
            promos
            .join(new_starts, on="promotion_id", how="left")
            .join(new_ends,   on="promotion_id", how="left")
            .withColumn(
                "start_date",
                F.coalesce("new_start", "start_date")
            )
            .withColumn(
                "end_date",
                F.coalesce("new_end", "end_date")
            )
            .drop("new_start", "new_end")
            .filter("start_date <= end_date")  # drop promos whose start now falls after their end (start == end is kept)
        )

        # Ranking used to pick one winner among promos of a game with the same interval.
        dedupe_win = (
            Window
            .partitionBy("game_id", "start_date", "end_date")      # exact same interval
            .orderBy(F.col("percentage").desc(), F.col("promotion_id").asc())
        )

        trimmed = (
            trimmed
            .withColumn("rn", F.row_number().over(dedupe_win))     # rank within identical windows
            .filter("rn = 1")                                      # keep the winner
            .drop("rn")
        )

        # Handle duplicate start dates by pushing lower discount start date after higher discount's end date
        start_date_win = (
            Window
            .partitionBy("game_id", "start_date")
            .orderBy(F.col("percentage").desc(), F.col("promotion_id").asc())
        )

        # Identify promotions to adjust
        adjust_starts = (
            trimmed
            .withColumn("rn", F.row_number().over(start_date_win))
            .withColumn(
                "max_discount_end",
                # end_date of the rank-1 promo, broadcast to every row sharing (game_id, start_date)
                F.max(
                    F.when(F.col("rn") == 1, F.col("end_date")).otherwise(F.lit(None))
                ).over(Window.partitionBy("game_id", "start_date"))
            )
            .filter(F.col("rn") > 1)  # Lower discount promos
            .select(
                F.col("promotion_id"),
                (F.col("max_discount_end") + F.expr("INTERVAL 1 DAY")).alias("new_start")
            )
        )

        # Apply start date adjustments
        trimmed = (
            trimmed
            .join(adjust_starts, on="promotion_id", how="left")
            .withColumn(
                "start_date",
                F.coalesce("new_start", "start_date")
            )
            .drop("new_start")
            .filter("start_date <= end_date")  # drop promos whose start now falls after their end
        )

        # Reapply deduplication to handle any new overlaps caused by adjustments
        trimmed = (
            trimmed
            .withColumn("rn", F.row_number().over(dedupe_win))
            .filter("rn = 1")
            .drop("rn")
        )

        self.df = trimmed
        return self

    def validate_primary_keys(self):
        """Drop rows with a NULL primary-key column, then keep one row per key.

        Primary-key columns are the metadata rows flagged ``PK``. When none are flagged the
        DataFrame is left untouched. Which of several rows sharing a key survives
        ``dropDuplicates`` is decided by Spark.

        Returns:
            ``self``, for chaining.
        """
        pk_cols = self._get_cols("PK")
        if not pk_cols:
            return self

        # Build the predicate from Column objects instead of a SQL string: a name such
        # as "a-b" would otherwise be parsed as an expression rather than a column.
        all_present = F.col(pk_cols[0]).isNotNull()
        for name in pk_cols[1:]:
            all_present = all_present & F.col(name).isNotNull()

        self.df = self.df.filter(all_present).dropDuplicates(pk_cols)
        return self

    def validate_non_nulls(self):
        """Drop every row that has a NULL in a column flagged ``NON_NULLABLE`` in the metadata.

        When no column is flagged the DataFrame is left untouched.

        Returns:
            ``self``, for chaining.
        """
        nn_cols = self._get_cols("NON_NULLABLE")
        if not nn_cols:
            return self

        # Build the predicate from Column objects instead of a SQL string: a name such
        # as "a-b" would otherwise be parsed as an expression rather than a column.
        all_present = F.col(nn_cols[0]).isNotNull()
        for name in nn_cols[1:]:
            all_present = all_present & F.col(name).isNotNull()

        self.df = self.df.filter(all_present)
        return self

    def validate_unique(self):
        """Keep only rows whose combination of ``UNIQUE``-flagged columns occurs exactly once.

        All flagged columns are treated as one composite key; every row of a key that
        occurs more than once is removed (no representative is kept). Because the
        surviving keys are joined back with an inner join, rows with a NULL in any of
        the key columns are removed as well, and the key columns come first in the
        resulting column order.

        Returns:
            ``self``, for chaining.
        """
        uniq_cols = self._get_cols("UNIQUE")
        if uniq_cols:
            key_counts = self.df.groupBy(uniq_cols).count()
            singleton_keys = key_counts.filter("count = 1").drop("count")
            self.df = self.df.join(singleton_keys, on=uniq_cols, how="inner")
        return self

    def _parse_dtype(self, dtype_str):
        """Map a metadata dtype string to a Spark type instance by its prefix.

        Matching is case-insensitive and looks only at how the string starts:
        ``int*`` -> ``LongType``, ``float*`` / ``double*`` -> ``DoubleType``,
        ``date*`` (which includes ``datetime*``) -> ``DateType``, anything else -> ``StringType``.
        """
        lowered = dtype_str.lower()
        prefix_table = (
            (("int",), LongType),
            (("float", "double"), DoubleType),
            (("date", "datetime"), DateType),
        )
        for prefixes, spark_type in prefix_table:
            if lowered.startswith(prefixes):
                return spark_type()
        return StringType()

    def validate_datatype(self):
        """Coerce every column described in the metadata to its declared type.

        ``DTYPE`` is read as ``<type>`` or ``<type>|<option>``:
          * date types: the option is the ``to_date`` pattern; without one the column
            is converted with a plain ``cast``.
          * float/double types: the option is the number of decimals to round to
            (2 when absent); values are parsed with ``float_udf``.
          * int types: values are parsed with ``int_udf``.
          * everything else: plain ``cast`` to the type chosen by ``_parse_dtype``.

        Metadata rows naming a column the DataFrame does not have are skipped.

        Returns:
            ``self``, for chaining.
        """
        # withColumn only replaces existing columns here, so the set stays valid throughout.
        present = set(self.df.columns)

        targets = []            # (column name, Spark type) in metadata order
        date_formats = {}
        float_precision = {}
        integer_columns = set()

        for row in self.meta.collect():
            col_name = row["column"]
            if col_name not in present:
                continue

            dtype = str(row["DTYPE"])
            parts = dtype.split("|")
            base_type = parts[0].lower()
            option = parts[1] if len(parts) > 1 else ""

            if "date" in base_type:
                # A date type without a pattern must not crash: leave it out of
                # date_formats so it falls through to the plain cast below.
                if option:
                    date_formats[col_name] = option
            elif "float" in base_type or "double" in base_type:
                float_precision[col_name] = int(option) if option else 2
            elif "int" in base_type:
                integer_columns.add(col_name)

            targets.append((col_name, self._parse_dtype(dtype)))

        for name, spark_type in targets:
            source = col(name)
            if isinstance(spark_type, DateType) and name in date_formats:
                converted = to_date(source, date_formats[name])
            elif isinstance(spark_type, DoubleType) and name in float_precision:
                converted = spark_round(float_udf(source), float_precision[name])
            elif isinstance(spark_type, LongType) and name in integer_columns:
                converted = int_udf(source)
            else:
                converted = source.cast(spark_type)
            self.df = self.df.withColumn(name, converted)

        return self

    def _clean_developers_column(self, col_name: str):
        """Normalise company names in ``col_name``.

        Strips the legal suffixes Inc / INC / LLC / Ltd / LTD (with an optional leading
        comma and trailing dot) and the platform tags "(Mac)" / "(Linux)", rewrites two
        known spellings, and trims surrounding whitespace.

        Suffixes are matched as whole words only, so names that merely contain the
        letters (e.g. "Incuvo") are left intact.

        Returns:
            ``self``, for chaining.
        """
        # Applied in this order. \b keeps a suffix from matching inside a longer word.
        substitutions = [
            (r",?\s*\bInc\b\.?", ""),
            (r"\s*\(Mac\)", ""),
            (r"\s*\(Linux\)", ""),
            (r",?\s*\bLLC\b\.?", ""),
            (r",?\s*\bLtd\b\.?", ""),
            (r",?\s*\bLTD\b\.?", ""),
            (r",?\s*\bINC\b\.?", ""),
            (r"BANDAI NAMCO", "Bandai Namco"),
            # Escaped dot: only the literal "CAPCOM CO." is rewritten, not e.g. "CAPCOM COMPANY".
            (r"CAPCOM CO\.", "CAPCOM Co."),
        ]

        cleaned = F.col(col_name)
        for pattern, replacement in substitutions:
            cleaned = F.regexp_replace(cleaned, pattern, replacement)

        self.df = self.df.withColumn(col_name, F.trim(cleaned))
        return self


    def _clean_publishers_column(self, col_name: str):
        """Clean a publisher-name column.

        Delegates to ``_clean_developers_column``, so publisher names receive exactly
        the same substitutions as developer names. Returns whatever that method returns.
        """
        return self._clean_developers_column(col_name)

    def apply_rules(self):
        """Apply the comma-separated ``RULES`` of every metadata row to its column.

        Rules are matched case-insensitively. Supported rules:
          * ``range|<min>|<max>``: keep rows with ``min <= value <= max``; ``-inf`` / ``inf``
            leave that side unbounded. Rows with a NULL value are dropped.
          * ``developer_clean`` / ``publisher_clean``: normalise company names.
          * ``email``: keep rows matching a basic e-mail pattern.
          * ``date_in_past`` / ``date_in_past_or_today``: compare against the current date.
          * ``date_after[<column>]``: keep rows where the value is later than ``<column>``.
          * ``normalize_alpha``: lower-case, keep letters and spaces only; rows whose value
            contains a digit, is blank or NULL are dropped.
          * ``promotions_clean``: run ``_clean_promotions``.

        Malformed range rules, unknown rules and missing reference columns are reported
        with ``print`` and skipped. Metadata rows without rules, or naming a column the
        DataFrame does not have, are ignored.

        Returns:
            ``self``, for chaining.
        """
        for row in self.meta.collect():
            col_name = row["column"]
            rules = row["RULES"]

            if not rules or col_name not in self.df.columns:
                continue

            for rule in rules.split(","):
                rule = rule.strip().lower()
                if not rule:
                    # Stray or trailing comma in the metadata; nothing to apply.
                    continue

                if rule.startswith("range|"):
                    try:
                        _, min_val, max_val = rule.split("|")
                        # float() understands "inf" and "-inf" natively.
                        lower, upper = float(min_val), float(max_val)
                    except ValueError:
                        print(f"❌ Invalid range rule format: {rule}")
                        continue

                    # Only parsing is guarded above; errors Spark raises while building
                    # the filter are real failures and must not be reported as a bad format.
                    value = col(col_name)
                    in_range = value.isNotNull()
                    if lower != float("-inf"):
                        in_range = in_range & (value >= lower)
                    if upper != float("inf"):
                        in_range = in_range & (value <= upper)
                    self.df = self.df.filter(in_range)

                elif rule == "developer_clean":
                    self._clean_developers_column(col_name)

                elif rule == "publisher_clean":
                    self._clean_publishers_column(col_name)

                elif rule == "email":
                    self.df = self.df.filter(col(col_name).rlike(r"^[\w\.-]+@[\w\.-]+\.\w{2,}$"))

                elif rule == "date_in_past":
                    self.df = self.df.filter(col(col_name) < current_date())

                elif rule == "date_in_past_or_today":
                    self.df = self.df.filter(col(col_name) <= current_date())

                elif rule.startswith("date_after[") and rule.endswith("]"):
                    ref_col = rule[len("date_after["):-1]
                    # The rule text was lower-cased above, so the reference column has
                    # to be looked up case-insensitively.
                    matches = [c for c in self.df.columns if c.lower() == ref_col]
                    if matches:
                        self.df = self.df.filter(col(col_name) > col(matches[0]))
                    else:
                        print(f"❌ Reference column '{ref_col}' not found for rule '{rule}'.")

                elif rule == "normalize_alpha":
                    as_text = F.col(col_name).cast("string")
                    self.df = (
                        self.df
                        .withColumn(
                            col_name,
                            F.when(
                                as_text.rlike("[0-9]") |  # Contains digits
                                as_text.rlike("^[ ]*$"),  # Contains only spaces
                                F.lit(None)  # Replace with null
                            ).otherwise(
                                F.lower(F.regexp_replace(F.col(col_name), "[^a-zA-Z ]", ""))  # Keep letters and spaces, lowercase
                            )
                        )
                        .filter(
                            F.col(col_name).isNotNull() &  # Remove nulls
                            (F.col(col_name) != "")  # Remove empty strings
                        )
                    )

                elif rule == "promotions_clean":
                    self._clean_promotions()

                else:
                    print(f"⚠️ Unknown rule '{rule}' ignored.")

        return self


    def normalize_2nf_df(self):
        """Replace each ``NORMALIZE2NF`` column by a surrogate id into a lookup table.

        For every flagged column ``<c>``:
          * a lookup DataFrame with columns ``id`` and ``<c>`` (one row per distinct
            value) is stored in ``self.normalized_tables["<table>_<c>"]``;
          * in the main DataFrame ``<c>`` is dropped and ``<c>_id`` is added. Rows whose
            ``<c>`` is NULL get a NULL id, because NULL never matches in the join.

        Ids are assigned in sorted value order, so they are reproducible for the same data.

        Returns:
            ``self``, for chaining.
        """
        for name in self._get_cols("NORMALIZE2NF"):
            lookup_name = f"{self.table_name}_{name}"
            id_col = f"{name}_id"

            distinct_df = self.df.select(name).distinct()
            # The lookup is materialised on the driver so that the main table and the
            # saved lookup table are guaranteed to use the same id assignment.
            distinct_vals = [r[0] for r in distinct_df.orderBy(name).collect()]

            # Explicit schema: inference fails on an empty or all-NULL column.
            lookup_schema = StructType([
                StructField("id", LongType(), True),
                StructField(name, distinct_df.schema.fields[0].dataType, True),
            ])
            lookup_df = self.spark.createDataFrame(
                list(enumerate(distinct_vals)), schema=lookup_schema
            )
            self.normalized_tables[lookup_name] = lookup_df

            # Rename the id before joining so it cannot collide with an existing
            # "id" column of the main table.
            self.df = (
                self.df
                .join(lookup_df.withColumnRenamed("id", id_col), on=name, how="left")
                .drop(name)
            )

        return self
    def normalize_3nf_df(self):
        """Move each comma-separated ``NORMALIZE3NF`` column into a lookup and a bridge table.

        For every flagged column ``<c>``:
          * the values are split on ",", trimmed and exploded to one row per item;
          * a lookup DataFrame (``id``, ``<c>``) with one row per distinct item is stored
            in ``self.normalized_tables["<table>_<c>"]``;
          * a bridge DataFrame (primary-key columns, ``<c>_id``) is stored in
            ``self.normalized_tables["<table>_<c>_bridge"]``;
          * ``<c>`` is dropped from the main DataFrame.

        Ids are assigned in sorted value order, so they are reproducible for the same data.

        Returns:
            ``self``, for chaining.
        """
        norm_cols = self._get_cols("NORMALIZE3NF")
        key_cols = self._get_cols("PK")

        for name in norm_cols:
            lookup_name = f"{self.table_name}_{name}"
            id_col = f"{name}_id"

            exploded_df = (
                self.df.select(*key_cols, name)
                .withColumn(name, F.explode(F.split(F.col(name), ",")))
                .withColumn(name, F.trim(F.col(name)))
            )

            distinct_df = exploded_df.select(name).distinct()
            # The lookup is materialised on the driver so that the bridge and the saved
            # lookup table are guaranteed to use the same id assignment.
            distinct_vals = [r[0] for r in distinct_df.orderBy(name).collect()]

            # Explicit schema: inference fails when there are no values at all.
            lookup_schema = StructType([
                StructField("id", LongType(), True),
                StructField(name, distinct_df.schema.fields[0].dataType, True),
            ])
            lookup_df = self.spark.createDataFrame(
                list(enumerate(distinct_vals)), schema=lookup_schema
            )
            self.normalized_tables[lookup_name] = lookup_df

            # Rename the id before joining: a primary key that is itself called "id"
            # would otherwise make the select below ambiguous.
            bridge_df = (
                exploded_df
                .join(lookup_df.withColumnRenamed("id", id_col), on=name, how="left")
                .select(*key_cols, id_col)
            )
            self.normalized_tables[f"{self.table_name}_{name}_bridge"] = bridge_df

            self.df = self.df.drop(name)

        return self

    def save(self, mode: str = "overwrite"):
        self.df.write.mode(mode).option("overwriteSchema", "true").format("delta").saveAsTable(self.table_name)
        print(f"✔️  {self.target_db}.{self.table_name} written ({mode})")

        for table_name, df in self.normalized_tables.items():
            df.write.mode(mode).option("overwriteSchema", "true").format("delta").saveAsTable(table_name)
            print(f"✔️  {self.target_db}.{table_name} written ({mode})")

        return f"{self.target_db}.{self.table_name}"


StatementMeta(, b27a3f6d-e08b-438c-8dcf-579032102320, 5, Finished, Available, Finished)

In [8]:
# Run the cleaning pipeline for `table_name` and exit the notebook with a JSON summary.
MAX_ERROR_CHARS = 1000  # keep the exit payload small; Spark error messages can be very long

# Defaults are set before the try block so the exit payload can always be built,
# even when the run fails (or is interrupted) before a count is taken.
status = "Failed"
rows_in = None
rows_out = None
error = None

try:
    cleaner = DataCleaningSpark(
        spark,
        source_db="lh_bronze",
        table_name=table_name,
        meta_data_table="meta_data_silver",
        target_db="lh_transformation"
    )
    rows_in = cleaner.df.count()

    cleaner = cleaner.validate_datatype()
    print(f"After validate_datatype: {cleaner.df.count()} rows")

    cleaner = cleaner.apply_rules()
    print(f"After applying rules: {cleaner.df.count()} rows")

    cleaner = cleaner.validate_primary_keys()
    print(f"After validate_primary_keys: {cleaner.df.count()} rows")

    cleaner = cleaner.validate_non_nulls()
    print(f"After validate_non_nulls: {cleaner.df.count()} rows")

    cleaner = cleaner.validate_unique()
    print(f"After validate_unique: {cleaner.df.count()} rows")

    cleaner = cleaner.normalize_2nf_df()
    cleaner = cleaner.normalize_3nf_df()

    rows_out = cleaner.df.count()
    cleaner.save(mode="overwrite")

    # Only reached when every step above completed.
    status = 'Succeeded'

except Exception as exc:
    error = f"{type(exc).__name__}: {exc}"[:MAX_ERROR_CHARS]
    logging.exception("Cleaning run for table '%s' failed", table_name)

finally:
    output = {
        "rows_in": rows_in,
        "rows_out": rows_out,
        "status": status,
        "run_ts": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }
    if error is not None:
        # Extra key, present on failure only, so the caller can see why the run failed.
        output["error"] = error

    notebookutils.notebook.exit(json.dumps(output))

StatementMeta(, b27a3f6d-e08b-438c-8dcf-579032102320, 6, Finished, Available, Finished)

After validate_datatype: 190389 rows


After applying rules: 161941 rows


After validate_primary_keys: 161941 rows


After validate_non_nulls: 161941 rows


After validate_unique: 161941 rows


✔️  lh_transformation.promotions written (overwrite)
ExitValue: {"rows_in": 190389, "rows_out": 161941, "status": "Succeeded", "run_ts": "2025-07-27 09:29:53"}