In [11]:
# Notebook parameters.
# NOTE(review): presumably overridden by an injected pipeline parameter cell
# when run from a Fabric pipeline — confirm.
stage = "lh_silver"   # prefix used to qualify table_name when reading the source
table_name = "games"  # table whose columns are profiled
run_id = None         # identifier handed to the profiler as run_id

StatementMeta(, 85777a49-84bd-4518-9c9c-e0cadac89343, 15, Finished, Available, Finished)

In [12]:
from datetime import datetime, timezone
from functools import reduce
from typing import List, Optional, Tuple

from pyspark.sql import DataFrame
from pyspark.sql import functions as F
from pyspark.sql import types as T

class RawTableProfiler:
    """
    Column-wise profiler for Spark DataFrames (written for Microsoft Fabric / PySpark).

    Produces:
      - metrics_df: one row per column with completeness, approximate distinct
                    counts, min/max, numeric/datetime percentiles, string
                    lengths, boolean counts, etc.
      - topk_df:    the top-k most frequent (stringified) values per column.

    All rows produced by one instance share the same ``run_id``, ``table_name``
    and ``profiled_ts`` (captured once, at construction time).
    """

    # Output columns of metrics_df as (name, SQL type). Used both to build the
    # DataFrame and as the DDL of the metrics table, so the two cannot drift.
    _METRICS_SCHEMA = (
        ("run_id", "STRING"),
        ("table_name", "STRING"),
        ("column_name", "STRING"),
        ("data_type", "STRING"),
        ("total_rows", "BIGINT"),
        ("non_null_rows", "BIGINT"),
        ("null_rows", "BIGINT"),
        ("completeness_ratio", "DOUBLE"),
        ("distinct_count", "BIGINT"),
        ("distinct_ratio", "DOUBLE"),
        ("min_value", "STRING"),
        ("max_value", "STRING"),
        ("mean", "DOUBLE"),
        ("stddev", "DOUBLE"),
        ("p05", "DOUBLE"),
        ("p50", "DOUBLE"),
        ("p95", "DOUBLE"),
        ("min_length", "INT"),
        ("avg_length", "DOUBLE"),
        ("max_length", "INT"),
        ("empty_string_count", "BIGINT"),
        ("true_count", "BIGINT"),
        ("false_count", "BIGINT"),
        ("inferred_semantic", "STRING"),
        ("profiled_ts", "TIMESTAMP"),
    )

    # Output columns of topk_df as (name, SQL type); also the top-k table DDL.
    _TOPK_SCHEMA = (
        ("run_id", "STRING"),
        ("table_name", "STRING"),
        ("column_name", "STRING"),
        ("value", "STRING"),
        ("cnt", "BIGINT"),
        ("profiled_ts", "TIMESTAMP"),
    )

    # Quantiles reported as p05 / p50 / p95, and the percentile_approx accuracy.
    _PERCENTILES = (0.05, 0.5, 0.95)
    _PERCENTILE_ACCURACY = 10000

    def __init__(self, spark, run_id: Optional[str], table_name: str):
        """
        Args:
            spark: SparkSession used for DDL, writes and empty results.
            run_id: identifier stamped on every output row; None is stored as NULL.
            table_name: name stamped on every output row as ``table_name``.
        """
        self.spark = spark
        self.run_id = run_id
        self.table_name = table_name
        # Timezone-aware UTC: datetime.utcnow() is naive (and deprecated since
        # Python 3.12), and PySpark converts naive datetimes using the driver's
        # local timezone, which would shift the stored instant.
        self.profiled_ts = datetime.now(timezone.utc)

    # ---------- type helpers

    @staticmethod
    def _is_numeric(dt: T.DataType) -> bool:
        """Return True for integral, floating-point and decimal types."""
        return isinstance(dt, (T.ByteType, T.ShortType, T.IntegerType, T.LongType,
                               T.FloatType, T.DoubleType, T.DecimalType))

    @staticmethod
    def _is_string(dt: T.DataType) -> bool:
        """Return True for STRING columns."""
        return isinstance(dt, T.StringType)

    @staticmethod
    def _is_bool(dt: T.DataType) -> bool:
        """Return True for BOOLEAN columns."""
        return isinstance(dt, T.BooleanType)

    @staticmethod
    def _is_datetime(dt: T.DataType) -> bool:
        """Return True for DATE, TIMESTAMP and (Spark 3.4+) TIMESTAMP_NTZ columns."""
        ntz = getattr(T, "TimestampNTZType", None)  # absent before Spark 3.4
        kinds = (T.TimestampType, T.DateType) + ((ntz,) if ntz is not None else ())
        return isinstance(dt, kinds)

    @staticmethod
    def _safe_alias(expr, name):
        """Alias ``expr`` as ``name``."""
        return expr.alias(name)

    @staticmethod
    def _quote_name(name: str) -> str:
        """Backtick-quote a column name so dots/spaces are not parsed as struct access."""
        return "`" + name.replace("`", "``") + "`"

    @staticmethod
    def _ddl(schema) -> str:
        """Render (name, SQL type) pairs as a comma-separated DDL column list."""
        return ", ".join(f"{name} {sql_type}" for name, sql_type in schema)

    # ---------- main

    def profile(self, df: DataFrame, topk: int = 10) -> Tuple[DataFrame, DataFrame]:
        """
        Compute metrics (and top-k value frequencies) for every column of ``df``.

        Args:
            df: DataFrame to profile.
            topk: number of most frequent values kept per column; None or a
                value <= 0 yields an empty top-k result.

        Returns:
            (metrics_df, topk_df). Only ``df.count()`` runs eagerly here; each
            column adds one aggregation (and one group-by for top-k) over
            ``df`` when the results are evaluated, so caching ``df`` may help
            for wide tables. A DataFrame with no columns yields empty results
            with the expected schemas.
        """
        total_rows = df.count()
        want_topk = topk is not None and topk > 0

        metrics_rows: List[DataFrame] = []
        topk_rows: List[DataFrame] = []
        for field in df.schema.fields:
            exprs = self._column_metric_exprs(
                field.name, field.dataType, field.dataType.simpleString(), total_rows
            )
            metrics_rows.append(df.agg(*exprs))
            if want_topk:
                topk_rows.append(self._column_topk(df, field.name, topk))

        metrics_df = self._union_all(metrics_rows, self._METRICS_SCHEMA)
        topk_df_all = self._union_all(topk_rows, self._TOPK_SCHEMA)
        return metrics_df, topk_df_all

    def _column_metric_exprs(self, c: str, dtype_obj: T.DataType,
                             dtype_name: str, total_rows: int) -> list:
        """Build the aggregate expressions yielding one metrics row for column ``c``."""
        col = F.col(self._quote_name(c))

        non_null = F.count(col)
        approx_distinct = F.approx_count_distinct(col)
        if total_rows > 0:
            completeness = non_null / F.lit(total_rows)
            # Denominator is all rows, so nulls lower the distinct ratio.
            distinct_ratio = approx_distinct / F.lit(total_rows)
        else:
            # Empty table: ratios are undefined; avoid 0/0 (raises in ANSI mode).
            completeness = F.lit(None)
            distinct_ratio = F.lit(None)

        # Start from typed NULLs for every output column, then fill in what applies.
        m = {name: F.lit(None) for name, _ in self._METRICS_SCHEMA}
        m.update(
            run_id=F.lit(self.run_id),
            table_name=F.lit(self.table_name),
            column_name=F.lit(c),
            data_type=F.lit(dtype_name),
            total_rows=F.lit(total_rows),
            non_null_rows=non_null,
            null_rows=F.lit(total_rows) - non_null,
            completeness_ratio=completeness,
            distinct_count=approx_distinct,
            distinct_ratio=distinct_ratio,
            inferred_semantic=F.lit("unknown"),
            profiled_ts=F.lit(self.profiled_ts),
        )

        if self._is_numeric(dtype_obj):
            # percentile_approx returns values of the input type (e.g. BIGINT or
            # DECIMAL), not DOUBLE; the final schema cast converts them.
            pcts = F.percentile_approx(col, list(self._PERCENTILES), self._PERCENTILE_ACCURACY)
            m.update(
                inferred_semantic=F.lit("numeric"),
                min_value=F.min(col),
                max_value=F.max(col),
                mean=F.avg(col),
                stddev=F.stddev_samp(col),
                p05=F.element_at(pcts, 1),
                p50=F.element_at(pcts, 2),
                p95=F.element_at(pcts, 3),
            )
        elif self._is_datetime(dtype_obj):
            # Explicit TimestampType (not the "timestamp" keyword, which follows
            # spark.sql.timestampType) so the cast to epoch seconds is valid.
            ts = col.cast(T.TimestampType())
            pcts = F.percentile_approx(ts.cast("long"), list(self._PERCENTILES),
                                       self._PERCENTILE_ACCURACY)
            m.update(
                inferred_semantic=F.lit("datetime"),
                min_value=F.min(ts),
                max_value=F.max(ts),
                p05=F.element_at(pcts, 1),
                p50=F.element_at(pcts, 2),
                p95=F.element_at(pcts, 3),
            )
        elif self._is_bool(dtype_obj):
            # count(when(...)) ignores NULLs and yields 0 (not NULL) on empty input.
            m.update(
                inferred_semantic=F.lit("boolean"),
                true_count=F.count(F.when(col, 1)),
                false_count=F.count(F.when(~col, 1)),
                min_value=F.min(col.cast("string")),
                max_value=F.max(col.cast("string")),
            )
        elif self._is_string(dtype_obj):
            length_col = F.length(col)
            m.update(
                inferred_semantic=F.lit("string"),
                min_length=F.min(length_col),
                avg_length=F.avg(length_col),
                max_length=F.max(length_col),
                empty_string_count=F.count(F.when(col == "", 1)),
                # Lexicographic min/max.
                min_value=F.min(col),
                max_value=F.max(col),
            )

        # Cast every column to its declared table type so the result always
        # matches the Delta DDL (e.g. a None run_id becomes STRING, not VOID).
        return [
            self._safe_alias(m[name].cast(sql_type), name)
            for name, sql_type in self._METRICS_SCHEMA
        ]

    def _column_topk(self, df: DataFrame, c: str, topk: int) -> DataFrame:
        """Return the ``topk`` most frequent values of column ``c``, stringified."""
        counts = (
            df.groupBy(F.col(self._quote_name(c)).cast("string").alias("value"))
              .agg(F.count(F.lit(1)).alias("cnt"))
              # Secondary sort on value makes ties deterministic across runs.
              .orderBy(F.desc("cnt"), F.asc_nulls_last("value"))
              .limit(topk)
        )
        values = {
            "run_id": F.lit(self.run_id),
            "table_name": F.lit(self.table_name),
            "column_name": F.lit(c),
            "value": F.col("value"),
            "cnt": F.col("cnt"),
            "profiled_ts": F.lit(self.profiled_ts),
        }
        return counts.select(*[
            self._safe_alias(values[name].cast(sql_type), name)
            for name, sql_type in self._TOPK_SCHEMA
        ])

    def _union_all(self, frames: List[DataFrame], schema) -> DataFrame:
        """Union frames by column name; an empty list yields an empty frame with ``schema``."""
        if not frames:
            return self.spark.createDataFrame([], self._ddl(schema))
        return reduce(DataFrame.unionByName, frames)

    def ensure_tables(self,
                      metrics_table: str = "monitoring_profile_metrics",
                      topk_table: str = "monitoring_profile_topk") -> None:
        """
        Create the Delta result tables if they do not exist yet.

        Table names are interpolated verbatim into SQL; pass trusted identifiers only.
        """
        self.spark.sql(
            f"CREATE TABLE IF NOT EXISTS {metrics_table} "
            f"({self._ddl(self._METRICS_SCHEMA)}) USING DELTA"
        )
        self.spark.sql(
            f"CREATE TABLE IF NOT EXISTS {topk_table} "
            f"({self._ddl(self._TOPK_SCHEMA)}) USING DELTA"
        )

    def write(self,
              metrics_df: DataFrame,
              topk_df: DataFrame,
              metrics_table: str = "monitoring_profile_metrics",
              topk_table: str = "monitoring_profile_topk") -> None:
        """Append the metrics and top-k results to their tables via ``saveAsTable``."""
        metrics_df.write.mode("append").saveAsTable(metrics_table)
        topk_df.write.mode("append").saveAsTable(topk_table)

StatementMeta(, 85777a49-84bd-4518-9c9c-e0cadac89343, 16, Finished, Available, Finished)

In [13]:
# Fully qualified name of the table being profiled; also recorded as table_name
# on every profiling row.
source_table = f"{stage}.{table_name}"
df = spark.sql(f"SELECT * FROM {source_table}")

# Profile every column (keeping the 10 most frequent values per column), make
# sure the monitoring tables exist, then append this run's results to them.
profiler = RawTableProfiler(spark=spark, run_id=run_id, table_name=source_table)
metrics_df, topk_df = profiler.profile(df, topk=10)

profiler.ensure_tables()
profiler.write(metrics_df, topk_df)

StatementMeta(, 85777a49-84bd-4518-9c9c-e0cadac89343, 17, Finished, Available, Finished)