# Data Comparison Notebook

This notebook compares two datasets (tables or S3 paths) and generates row-count, null-rate, distinct-count, date-level, and missing-record diagnostics.


## Environment Setup

Run this notebook on either local Spark or Databricks.

### Local Spark
- Install dependencies: `pyspark`, `pandas`, `boto3`, `openpyxl`.
- Run with a local Python/Jupyter kernel.
- Use a local output path such as `./outputs`.

### Databricks or Databricks Connect
- On Databricks clusters, Spark is already available.
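- For Databricks Connect, configure credentials first. Legacy releases use `databricks-connect configure`; releases that expose `DatabricksSession` generally read a Databricks configuration profile or environment variables instead.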
- For DBFS output, set `output_base_path` to `/dbfs/FileStore/<folder>` or `dbfs:/FileStore/<folder>`.
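
The two DBFS spellings refer to the same location: the configuration code in this notebook rewrites a leading `dbfs:` scheme to the `/dbfs` mount prefix before creating the output folder. Here is a minimal standalone sketch of that rewrite. The helper name `normalize_dbfs_path` and the sample folder are hypothetical and not part of the notebook.

```python
def normalize_dbfs_path(path: str) -> str:
    """Rewrite a leading 'dbfs:' scheme to the '/dbfs' local mount prefix."""
    if path.startswith("dbfs:"):
        # Replace only the first occurrence so the rest of the path is untouched.
        return path.replace("dbfs:", "/dbfs", 1)
    return path


print(normalize_dbfs_path("dbfs:/FileStore/reports"))
print(normalize_dbfs_path("./outputs"))
```

Paths that do not start with `dbfs:` pass through unchanged, so the same setting works for local runs.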


In [1]:
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple

import boto3
import os
import pandas as pd
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import coalesce, col, count, countDistinct, lit, to_date, when
from pyspark.storagelevel import StorageLevel

try:
    # Available when running with Databricks Connect.
    from databricks.connect.session import DatabricksSession
except Exception:
    DatabricksSession = None


def get_spark_session() -> SparkSession:
    """Return an active Spark session for local Spark or Databricks Connect."""
    if DatabricksSession is not None:
        return DatabricksSession.builder.getOrCreate()
    return SparkSession.builder.getOrCreate()


## Configuration Model

`DataComparisonConfig` controls dataset selection, filtering, authentication, and exports.

### Required fields
- `source_1_table`, `source_2_table`: Source table names or S3 paths.
- `primary_key`: Join key used for missing-record analysis.
- `domain_name`: Label used in logs and output filenames.

### Common optional fields
- `count_key`: Column used for grouped counts (defaults to `primary_key`).
- `datetime_columns`: One or more datetime/date columns for time-based analysis.
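- `filter_date_start`, `filter_date_end`: Inclusive date window in `YYYYMMDD` format. Both must be set for any date filtering to occur.
- `skip_date_filter`: Defaults to `True`, which disables column-based date filtering; set it to `False` to apply the window above.
- `partition_column`: Optional S3 partition key for partition-aware reads; used only when both filter dates are set.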
- `aws_*`: Optional explicit AWS credentials (otherwise default credential chain is used).
- `output_base_path`: Output directory for Excel/CSV exports.
- `enable_persist`: Enables Spark persistence where supported.
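
To make the date-window and partition fields concrete, here is a minimal standalone sketch of how a `YYYYMMDD` window can be expanded into one partition folder per day, similar in spirit to what the analyzer does for partition-aware S3 reads. The function name `partition_paths`, the bucket path, and the column name are hypothetical, and the `column=value/` folder layout is an assumption about how the data is stored.

```python
from datetime import datetime, timedelta
from typing import List


def partition_paths(base_path: str, partition_column: str,
                    start: str, end: str, date_format: str = "%Y%m%d") -> List[str]:
    """List one 'column=value/' folder per day in the inclusive [start, end] window."""
    base = base_path if base_path.endswith("/") else f"{base_path}/"
    current = datetime.strptime(start, date_format)
    last = datetime.strptime(end, date_format)
    paths = []
    while current <= last:
        paths.append(f"{base}{partition_column}={current.strftime(date_format)}/")
        current += timedelta(days=1)
    return paths


# Hypothetical bucket and partition column, for illustration only.
for p in partition_paths("s3a://example-bucket/orders", "process_dt", "20240130", "20240201"):
    print(p)
```

Both endpoints are included, so a window that starts and ends on the same day yields exactly one folder.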


In [2]:
def _is_databricks_runtime() -> bool:
    """Return True when running inside a Databricks runtime."""
    return bool(os.getenv("DATABRICKS_RUNTIME_VERSION") or os.getenv("DB_HOME"))


def _default_output_base_path(domain_name: str) -> str:
    """Resolve a default output folder based on execution environment."""
    if _is_databricks_runtime() and os.path.isdir("/dbfs"):
        return os.path.join("/dbfs", "FileStore", domain_name)
    return os.path.join(".", "outputs", domain_name)


@dataclass
class DataComparisonConfig:
    """Configuration for comparing two datasets and exporting diagnostics."""

    source_1_table: str
    source_2_table: str
    primary_key: str
    domain_name: str
    count_key: Optional[str] = None

    # Optional date filtering on dataframe columns
    datetime_columns: Optional[List[str]] = field(default_factory=list)
    filter_date_start: Optional[str] = None  # format: YYYYMMDD
    filter_date_end: Optional[str] = None    # format: YYYYMMDD
    date_format: str = "%Y%m%d"
    skip_date_filter: bool = True

    # Optional partition-aware S3 loading
    partition_column: Optional[str] = None

    # Output settings
    output_base_path: Optional[str] = None
    excel_output: Optional[str] = None
    csv_output: Optional[str] = None

    # Persist behavior (serverless-safe)
    enable_persist: bool = True

    # Standard AWS S3 authentication inputs
    aws_region: str = "us-east-1"
    aws_access_key_id: Optional[str] = None
    aws_secret_access_key: Optional[str] = None
    aws_session_token: Optional[str] = None

    def __post_init__(self):
        """Normalize defaults, output paths, and optional datetime settings."""
        if not self.count_key:
            self.count_key = self.primary_key

        if self.datetime_columns is None:
            self.datetime_columns = []
        elif isinstance(self.datetime_columns, str):
            self.datetime_columns = [self.datetime_columns]

        if self.output_base_path is None:
            self.output_base_path = _default_output_base_path(self.domain_name)
        if self.output_base_path.startswith("dbfs:"):
            self.output_base_path = self.output_base_path.replace("dbfs:", "/dbfs", 1)
        if self.excel_output is None:
            self.excel_output = f"{self.domain_name}_comparison.xlsx"
        if self.csv_output is None:
            self.csv_output = f"{self.domain_name}_comparison.csv"

        os.makedirs(self.output_base_path, exist_ok=True)

    @staticmethod
    def infer_datetime_type(column_name: str) -> str:
        """Guess a column's date encoding ("iso", "yyyymmdd", or "timestamp") from its name."""
        column_lower = column_name.lower()
        if "create" in column_lower and ("dt" in column_lower or "date" in column_lower):
            return "iso"
        if any(token in column_lower for token in ["processdate", "process_dt", "eventdate", "event_date"]):
            return "yyyymmdd"
        return "timestamp"

    def get_standardized_date_column_name(self, datetime_column: str) -> str:
        return f"{datetime_column}_date"

    @property
    def excel_path(self) -> str:
        return os.path.join(self.output_base_path, self.excel_output)

    @property
    def csv_path(self) -> str:
        return os.path.join(self.output_base_path, self.csv_output)


def create_configs_from_pairs(config_rows: List[Dict]) -> List[DataComparisonConfig]:
    """Build configs from a list of dictionaries for batch execution."""
    return [DataComparisonConfig(**row) for row in config_rows]


def _normalize_datetime_columns(value) -> List[str]:
    """Normalize datetime column input into a list of column names."""
    if value is None:
        return []
    if isinstance(value, str):
        return [value]
    return list(value)


def create_configs_from_nested_dict(nested_config: Dict) -> List[DataComparisonConfig]:
    """Create configs from grouped source tables.

    Expected shape:
    {
        "group": (
            [source_1_table_desc, ...],
            [source_2_table_desc, ...],
        )
    }

    Each table description can be:
    - str: table/path
    - tuple/list: (table, primary_key, count_key, domain_name, datetime_columns)
    - dict: {table, primary_key, count_key, domain_name, datetime_columns}
    """
    configs: List[DataComparisonConfig] = []

    for _, pair in nested_config.items():
        if len(pair) != 2:
            raise ValueError("Each nested config entry must contain (source_1_tables, source_2_tables)")

        source_1_tables, source_2_tables = pair
        if len(source_1_tables) != len(source_2_tables):
            raise ValueError("source_1/source_2 table list lengths must match")

        for source_1_desc, source_2_desc in zip(source_1_tables, source_2_tables):
            if isinstance(source_1_desc, str):
                source_1_table = source_1_desc
                primary_key = "id"
                count_key = "id"
                domain_name = source_1_table.rstrip("/").split("/")[-1]
                datetime_columns = []
            elif isinstance(source_1_desc, (tuple, list)):
                source_1_table = source_1_desc[0]
                primary_key = source_1_desc[1] if len(source_1_desc) > 1 else "id"
                count_key = source_1_desc[2] if len(source_1_desc) > 2 else primary_key
                domain_name = source_1_desc[3] if len(source_1_desc) > 3 else source_1_table.rstrip("/").split("/")[-1]
                datetime_columns = _normalize_datetime_columns(source_1_desc[4]) if len(source_1_desc) > 4 else []
            elif isinstance(source_1_desc, dict):
                source_1_table = source_1_desc["table"]
                primary_key = source_1_desc.get("primary_key", "id")
                count_key = source_1_desc.get("count_key", primary_key)
                domain_name = source_1_desc.get("domain_name", source_1_table.rstrip("/").split("/")[-1])
                datetime_columns = _normalize_datetime_columns(source_1_desc.get("datetime_columns", []))
            else:
                raise ValueError(f"Unsupported source_1 descriptor type: {type(source_1_desc)}")

            if isinstance(source_2_desc, str):
                source_2_table = source_2_desc
            elif isinstance(source_2_desc, (tuple, list)):
                source_2_table = source_2_desc[0]
            elif isinstance(source_2_desc, dict):
                source_2_table = source_2_desc["table"]
            else:
                raise ValueError(f"Unsupported source_2 descriptor type: {type(source_2_desc)}")

            configs.append(
                DataComparisonConfig(
                    source_1_table=source_1_table,
                    source_2_table=source_2_table,
                    primary_key=primary_key,
                    count_key=count_key,
                    domain_name=domain_name,
                    datetime_columns=datetime_columns,
                )
            )

    return configs
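
The three descriptor forms accepted above can be illustrated with a small standalone sketch. It mirrors the defaulting rules applied to `source_1` descriptors (`primary_key` falls back to `"id"`, `count_key` to the primary key, `domain_name` to the last path segment), but it is not the notebook's function: `normalize_descriptor` and the sample table names are hypothetical.

```python
from typing import Any, Dict, Union

Descriptor = Union[str, tuple, list, dict]


def normalize_descriptor(desc: Descriptor) -> Dict[str, Any]:
    """Expand a str / tuple / dict table descriptor into a uniform dict."""
    if isinstance(desc, str):
        desc = {"table": desc}
    elif isinstance(desc, (tuple, list)):
        keys = ["table", "primary_key", "count_key", "domain_name", "datetime_columns"]
        desc = dict(zip(keys, desc))  # zip stops at the shorter input
    elif not isinstance(desc, dict):
        raise ValueError(f"Unsupported descriptor type: {type(desc)}")

    table = desc["table"]
    primary_key = desc.get("primary_key", "id")
    dt_cols = desc.get("datetime_columns") or []
    if isinstance(dt_cols, str):
        dt_cols = [dt_cols]  # a single column name becomes a one-item list
    return {
        "table": table,
        "primary_key": primary_key,
        "count_key": desc.get("count_key", primary_key),
        "domain_name": desc.get("domain_name", table.rstrip("/").split("/")[-1]),
        "datetime_columns": list(dt_cols),
    }


print(normalize_descriptor("s3://example-bucket/sales/orders/"))
print(normalize_descriptor(("db.orders", "order_id")))
```

Only the first descriptor of each pair supplies keys and labels; for the second source, the notebook reads just the table name.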


## Analyzer Workflow

`DataComparisonAnalyzer` executes the end-to-end comparison pipeline.

1. Load `source_1` and `source_2` (table or S3 parquet).
2. Apply optional date filters.
3. Compare total counts.
4. Compute per-column null and distinct metrics.
5. Identify key-level missing records in each source.
6. Produce optional date-wise and enhanced schema diagnostics.
7. Export reports to Excel/CSV.
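
As a rough illustration of steps 3 to 5, the sketch below computes the same kinds of metrics on two small in-memory row lists using only the standard library. It is not how the analyzer runs (the notebook uses Spark aggregations and anti-joins); the function names and sample rows are hypothetical.

```python
from typing import Any, Dict, List, Set

Row = Dict[str, Any]


def column_metrics(rows: List[Row], column: str) -> Dict[str, float]:
    """Null count, null percentage, and distinct non-null count for one column."""
    values = [row.get(column) for row in rows]
    nulls = sum(v is None for v in values)
    total = max(len(values), 1)  # avoid division by zero on empty input
    return {
        "null_count": nulls,
        "null_pct": round(nulls / total * 100, 2),
        "distinct_count": len({v for v in values if v is not None}),
    }


def missing_keys(left: List[Row], right: List[Row], key: str) -> Set[Any]:
    """Keys present in `left` but absent from `right` (an anti-join on `key`)."""
    return {row[key] for row in left} - {row[key] for row in right}


source_1 = [{"id": 1, "city": "Oslo"}, {"id": 2, "city": None}, {"id": 3, "city": "Oslo"}]
source_2 = [{"id": 2, "city": "Rome"}, {"id": 3, "city": None}, {"id": 4, "city": None}]

print("count difference:", len(source_1) - len(source_2))
print("source_1 city:", column_metrics(source_1, "city"))
print("missing in source_2:", missing_keys(source_1, source_2, "id"))
print("missing in source_1:", missing_keys(source_2, source_1, "id"))
```

Running the check in both directions matters: equal row counts, as in this sample, can still hide keys that exist on only one side.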


In [3]:
class DataComparisonAnalyzer:
    """Run dataset-level comparison, quality checks, and exports."""

    def __init__(self, config: DataComparisonConfig):
        self.config = config
        self.spark = get_spark_session()

        self.df_source_1: Optional[DataFrame] = None
        self.df_source_2: Optional[DataFrame] = None
        self.results: Dict = {}

        self._s3_configured = False

    def _is_serverless(self) -> bool:
        try:
            return self.spark.conf.get(
                "spark.databricks.clusterUsageTags.serverless", "false"
            ).lower() == "true"
        except Exception:
            return False

    def _safe_persist(self, df: DataFrame) -> DataFrame:
        if not self.config.enable_persist:
            return df
        if self._is_serverless():
            print("Persist skipped: Databricks serverless does not support persist().")
            return df
        try:
            # persist() marks the DataFrame for caching and returns it.
            return df.persist(StorageLevel.MEMORY_AND_DISK)
        except Exception as exc:
            print(f"Persist skipped: {exc}")
            return df

    def _configure_s3_authentication(self):
        """Configure Spark S3 access using explicit or default AWS credentials."""
        if self._s3_configured:
            return

        access_key = self.config.aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
        secret_key = self.config.aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY")
        session_token = self.config.aws_session_token or os.getenv("AWS_SESSION_TOKEN")
        region = self.config.aws_region or os.getenv("AWS_REGION") or "us-east-1"

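        # Classic Spark only: sparkContext/_jsc is generally not exposed by
        # Spark Connect sessions (for example Databricks Connect).
        hconf = self.spark.sparkContext._jsc.hadoopConfiguration()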
        hconf.set("fs.s3a.endpoint.region", region)

        if access_key and secret_key:
            hconf.set("fs.s3a.access.key", access_key)
            hconf.set("fs.s3a.secret.key", secret_key)
            if session_token:
                hconf.set("fs.s3a.session.token", session_token)
                hconf.set(
                    "fs.s3a.aws.credentials.provider",
                    "org.apache.hadoop.fs.s3a.TemporaryAWSCredentialsProvider",
                )
            else:
                hconf.set(
                    "fs.s3a.aws.credentials.provider",
                    "org.apache.hadoop.fs.s3a.SimpleAWSCredentialsProvider",
                )
            boto3.Session(
                aws_access_key_id=access_key,
                aws_secret_access_key=secret_key,
                aws_session_token=session_token,
                region_name=region,
            )
            print("S3 authentication: explicit AWS credentials")
        else:
            hconf.set(
                "fs.s3a.aws.credentials.provider",
                "com.amazonaws.auth.DefaultAWSCredentialsProviderChain",
            )
            boto3.Session(region_name=region)
            print("S3 authentication: default AWS credential chain")

        self._s3_configured = True

    @staticmethod
    def _is_s3_path(path: str) -> bool:
        return path.startswith("s3://") or path.startswith("s3a://")

    def _iter_partition_paths(self, base_path: str) -> List[str]:
        if not self.config.partition_column:
            return []
        if not self.config.filter_date_start or not self.config.filter_date_end:
            return []

        start = datetime.strptime(self.config.filter_date_start, self.config.date_format)
        end = datetime.strptime(self.config.filter_date_end, self.config.date_format)

        paths = []
        current = start
        while current <= end:
            date_str = current.strftime(self.config.date_format)
            paths.append(f"{base_path}{self.config.partition_column}={date_str}/")
            current += timedelta(days=1)
        return paths

    def _load_from_s3(self, path: str) -> DataFrame:
        self._configure_s3_authentication()
        base_path = path if path.endswith("/") else f"{path}/"

        partition_paths = self._iter_partition_paths(base_path)
        if partition_paths:
            existing_paths = []
            for p in partition_paths:
                try:
                    probe = self.spark.read.format("parquet").load(p)
                    if probe.head(1):
                        existing_paths.append(p)
                except Exception:
                    pass

            if existing_paths:
                return self.spark.read.format("parquet").option("basePath", base_path).load(*existing_paths)

        return self.spark.read.format("parquet").load(base_path)

    def _apply_datetime_filters(self, df: DataFrame, datetime_col: str) -> DataFrame:
        if self.config.skip_date_filter:
            return df
        if not self.config.filter_date_start or not self.config.filter_date_end:
            return df
        if datetime_col not in df.columns:
            return df

        col_type = self.config.infer_datetime_type(datetime_col)
        start_dt = datetime.strptime(self.config.filter_date_start, self.config.date_format)
        end_dt = datetime.strptime(self.config.filter_date_end, self.config.date_format) + timedelta(days=1)

        if col_type == "timestamp":
            return df.filter((col(datetime_col) >= lit(start_dt)) & (col(datetime_col) < lit(end_dt)))
        if col_type == "iso":
            return df.filter(
                (col(datetime_col) >= lit(start_dt.isoformat())) &
                (col(datetime_col) < lit(end_dt.isoformat()))
            )
        return df.filter(
            (col(datetime_col) >= lit(self.config.filter_date_start)) &
            (col(datetime_col) <= lit(self.config.filter_date_end))
        )

    def _apply_date_filters(self, df: DataFrame) -> DataFrame:
        if not self.config.datetime_columns:
            return df
        if not self.config.filter_date_start or not self.config.filter_date_end:
            return df
        if self.config.skip_date_filter:
            return df

        filtered = df
        for dt_col in self.config.datetime_columns:
            if dt_col in filtered.columns:
                filtered = self._apply_datetime_filters(filtered, dt_col)
        return filtered

    def _load_source(self, table_or_path: str) -> DataFrame:
        """Load a source from S3 parquet path or Spark catalog table, then filter."""
        if self._is_s3_path(table_or_path):
            df = self._load_from_s3(table_or_path)
        else:
            df = self.spark.table(table_or_path)

        df = self._apply_date_filters(df)
        return df

    def load_source_1_data(self) -> DataFrame:
        print(f"Loading source_1: {self.config.source_1_table}")
        self.df_source_1 = self._load_source(self.config.source_1_table)
        if self.config.enable_persist:
            self.df_source_1 = self._safe_persist(self.df_source_1)
        print(f"source_1 rows: {self.df_source_1.count():,}")
        return self.df_source_1

    def load_source_2_data(self) -> DataFrame:
        print(f"Loading source_2: {self.config.source_2_table}")
        self.df_source_2 = self._load_source(self.config.source_2_table)
        if self.config.enable_persist:
            self.df_source_2 = self._safe_persist(self.df_source_2)
        print(f"source_2 rows: {self.df_source_2.count():,}")
        return self.df_source_2

    def _require_loaded_data(self) -> Tuple[DataFrame, DataFrame]:
        if self.df_source_1 is None or self.df_source_2 is None:
            raise ValueError("Source data is not loaded. Call load_source_1_data() and load_source_2_data() first.")
        return self.df_source_1, self.df_source_2

    def compare_counts(self) -> Dict[str, int]:
        df_source_1, df_source_2 = self._require_loaded_data()
        source_1_count = df_source_1.count()
        source_2_count = df_source_2.count()
        return {
            "source_1_count": source_1_count,
            "source_2_count": source_2_count,
            "count_difference": source_1_count - source_2_count,
        }

    def analyze_columns(self, df_source_1: DataFrame, df_source_2: DataFrame) -> pd.DataFrame:
        common_cols = sorted(set(df_source_1.columns).intersection(set(df_source_2.columns)))

        records = []
        for column_name in common_cols:
            stats_1 = df_source_1.select(
                count(lit(1)).alias("total_count"),
                count(when(col(column_name).isNull(), 1)).alias("null_count"),
                countDistinct(col(column_name)).alias("distinct_count"),
            ).collect()[0]

            stats_2 = df_source_2.select(
                count(lit(1)).alias("total_count"),
                count(when(col(column_name).isNull(), 1)).alias("null_count"),
                countDistinct(col(column_name)).alias("distinct_count"),
            ).collect()[0]

            total_1 = max(int(stats_1["total_count"]), 1)
            total_2 = max(int(stats_2["total_count"]), 1)

            records.append({
                "column_name": column_name,
                "null_count_source_1": int(stats_1["null_count"]),
                "null_count_source_2": int(stats_2["null_count"]),
                "distinct_count_source_1": int(stats_1["distinct_count"]),
                "distinct_count_source_2": int(stats_2["distinct_count"]),
                "total_count_source_1": int(stats_1["total_count"]),
                "total_count_source_2": int(stats_2["total_count"]),
                "null_pct_source_1": round((int(stats_1["null_count"]) / total_1) * 100, 2),
                "null_pct_source_2": round((int(stats_2["null_count"]) / total_2) * 100, 2),
            })

        return pd.DataFrame(records)

    def find_missing_records(self, key_column: str):
        """Return anti-join results for records missing in each source by key."""
        df_source_1, df_source_2 = self._require_loaded_data()

        if key_column not in df_source_1.columns or key_column not in df_source_2.columns:
            raise ValueError(f"Key column '{key_column}' must exist in both sources")

        missing_in_source_2 = df_source_1.join(
            df_source_2.select(key_column).distinct(),
            on=[key_column],
            how="left_anti",
        )

        missing_in_source_1 = df_source_2.join(
            df_source_1.select(key_column).distinct(),
            on=[key_column],
            how="left_anti",
        )

        return missing_in_source_2, missing_in_source_1

    def _convert_datetime_to_date(self, df: DataFrame, datetime_col: str) -> DataFrame:
        if datetime_col not in df.columns:
            return df

        datetime_type = self.config.infer_datetime_type(datetime_col)
        date_col = self.config.get_standardized_date_column_name(datetime_col)

        if datetime_type == "yyyymmdd":
            return df.withColumn(date_col, to_date(col(datetime_col).cast("string"), "yyyyMMdd"))
        if datetime_type == "iso":
            return df.withColumn(date_col, to_date(col(datetime_col)))
        return df.withColumn(date_col, to_date(col(datetime_col).cast("timestamp")))

    def _filter_by_date_col(self, df: DataFrame, date_col: str) -> DataFrame:
        if self.config.skip_date_filter:
            return df
        if not self.config.filter_date_start or not self.config.filter_date_end:
            return df
        if date_col not in df.columns:
            return df

        start_date = datetime.strptime(self.config.filter_date_start, self.config.date_format).date()
        end_date = datetime.strptime(self.config.filter_date_end, self.config.date_format).date()

        filtered = df.filter(
            (col(date_col).isNotNull()) &
            (col(date_col) >= lit(start_date)) &
            (col(date_col) <= lit(end_date))
        )
        if filtered.count() == 0 and df.count() > 0:
            print(f"Warning: no valid dates found in {date_col}; using unfiltered data.")
            return df
        return filtered

    def analyze_by_date(
        self, df_source_1: DataFrame, df_source_2: DataFrame, key_column: str, datetime_col: Optional[str] = None
    ) -> pd.DataFrame:
        if datetime_col is None:
            if not self.config.datetime_columns:
                return pd.DataFrame(columns=["date", "source_1_count", "source_2_count", "difference"])
            datetime_col = self.config.datetime_columns[0]

        df_source_1 = self._convert_datetime_to_date(df_source_1, datetime_col)
        df_source_2 = self._convert_datetime_to_date(df_source_2, datetime_col)
        date_col = self.config.get_standardized_date_column_name(datetime_col)

        if date_col not in df_source_1.columns or date_col not in df_source_2.columns:
            return pd.DataFrame(columns=["date", "source_1_count", "source_2_count", "difference"])

        df_source_1 = self._filter_by_date_col(df_source_1, date_col)
        df_source_2 = self._filter_by_date_col(df_source_2, date_col)

        source_1_by_date = df_source_1.groupBy(date_col).agg(count(key_column).alias("source_1_count"))
        source_2_by_date = df_source_2.groupBy(date_col).agg(count(key_column).alias("source_2_count"))

        joined = source_1_by_date.join(source_2_by_date, on=[date_col], how="outer").fillna(0)
        joined = joined.withColumn("difference", col("source_1_count") - col("source_2_count"))
        joined = joined.withColumn(date_col, col(date_col).cast("string"))
        return joined.orderBy(date_col).toPandas().rename(columns={date_col: "date"})

    def recheck_missing_records(self, key_column: Optional[str] = None) -> Dict:
        key_column = key_column or self.config.primary_key
        df_source_1, df_source_2 = self._require_loaded_data()

        if "missing_in_source_2" not in self.results or "missing_in_source_1" not in self.results:
            missing_in_source_2, missing_in_source_1 = self.find_missing_records(key_column)
        else:
            missing_in_source_2 = self.results["missing_in_source_2"]
            missing_in_source_1 = self.results["missing_in_source_1"]

        source_2_keys = df_source_2.select(key_column).distinct()
        source_1_keys = df_source_1.select(key_column).distinct()

        still_missing_in_source_2 = missing_in_source_2.join(source_2_keys, on=[key_column], how="left_anti")
        still_missing_in_source_1 = missing_in_source_1.join(source_1_keys, on=[key_column], how="left_anti")

        # Count each DataFrame once; every count() call triggers a separate Spark job.
        original_missing_2 = missing_in_source_2.count()
        original_missing_1 = missing_in_source_1.count()
        still_missing_2 = still_missing_in_source_2.count()
        still_missing_1 = still_missing_in_source_1.count()

        result = {
            "missing_in_source_2_original": original_missing_2,
            "missing_in_source_1_original": original_missing_1,
            "still_missing_in_source_2": still_missing_2,
            "still_missing_in_source_1": still_missing_1,
            "found_in_source_2_on_recheck": original_missing_2 - still_missing_2,
            "found_in_source_1_on_recheck": original_missing_1 - still_missing_1,
            "still_missing_source_2_df": still_missing_in_source_2,
            "still_missing_source_1_df": still_missing_in_source_1,
        }
        self.results["recheck_results"] = result
        return result

    def print_recheck_results(self):
        if "recheck_results" not in self.results:
            print("No recheck results found. Run recheck_missing_records() first.")
            return

        r = self.results["recheck_results"]
        print("=" * 80)
        print("Recheck Summary")
        print("=" * 80)
        print(f"Original missing in source_2: {r['missing_in_source_2_original']:,}")
        print(f"Found in source_2 on recheck: {r['found_in_source_2_on_recheck']:,}")
        print(f"Still missing in source_2: {r['still_missing_in_source_2']:,}")
        print(f"Original missing in source_1: {r['missing_in_source_1_original']:,}")
        print(f"Found in source_1 on recheck: {r['found_in_source_1_on_recheck']:,}")
        print(f"Still missing in source_1: {r['still_missing_in_source_1']:,}")

    def _analyze_missing_by_date_impl(self, missing_df: DataFrame, source_label: str, datetime_col: str, key_column: str) -> pd.DataFrame:
        if missing_df is None:
            return pd.DataFrame(columns=["date", "missing_count", "source"])

        df = self._convert_datetime_to_date(missing_df, datetime_col)
        date_col = self.config.get_standardized_date_column_name(datetime_col)
        if date_col not in df.columns:
            return pd.DataFrame(columns=["date", "missing_count", "source"])

        df = self._filter_by_date_col(df, date_col)
        out = df.groupBy(date_col).agg(count(key_column).alias("missing_count")).orderBy(date_col)
        out = out.withColumn(date_col, col(date_col).cast("string"))
        out = out.toPandas().rename(columns={date_col: "date"})
        out["source"] = source_label
        return out

    def analyze_missing_by_date(
        self, key_column: Optional[str] = None, datetime_col: Optional[str] = None
    ) -> Dict[str, pd.DataFrame]:
        key_column = key_column or self.config.primary_key
        if datetime_col is None:
            if not self.config.datetime_columns:
                empty = pd.DataFrame(columns=["date", "missing_count", "source"])
                return {"missing_in_source_2_by_date": empty, "missing_in_source_1_by_date": empty}
            datetime_col = self.config.datetime_columns[0]

        if "missing_in_source_2" not in self.results or "missing_in_source_1" not in self.results:
            self.results["missing_in_source_2"], self.results["missing_in_source_1"] = self.find_missing_records(key_column)

        source_2_pdf = self._analyze_missing_by_date_impl(
            self.results["missing_in_source_2"], "source_2", datetime_col, key_column
        )
        source_1_pdf = self._analyze_missing_by_date_impl(
            self.results["missing_in_source_1"], "source_1", datetime_col, key_column
        )

        self.results["missing_in_source_2_by_date"] = source_2_pdf
        self.results["missing_in_source_1_by_date"] = source_1_pdf
        return {
            "missing_in_source_2_by_date": source_2_pdf,
            "missing_in_source_1_by_date": source_1_pdf,
        }

    def analyze_nulls_by_date(
        self, df: DataFrame, key_column: str, datetime_col: str, source_label: Optional[str] = None
    ) -> pd.DataFrame:
        df = self._convert_datetime_to_date(df, datetime_col)
        date_col = self.config.get_standardized_date_column_name(datetime_col)
        if date_col not in df.columns:
            return pd.DataFrame(columns=["date", "total_count", "key_null_count", "key_null_pct", "source"])

        df = self._filter_by_date_col(df, date_col)
        out = df.groupBy(date_col).agg(
            count(lit(1)).alias("total_count"),
            count(when(col(key_column).isNull(), 1)).alias("key_null_count"),
        )
        out = out.withColumn("key_null_pct", (col("key_null_count") * lit(100.0) / col("total_count")))
        out = out.withColumn(date_col, col(date_col).cast("string"))
        pdf = out.orderBy(date_col).toPandas().rename(columns={date_col: "date"})
        if source_label:
            pdf["source"] = source_label
        return pdf

    def analyze_columns_with_datatypes(self, df: DataFrame, source_name: str) -> pd.DataFrame:
        rows = []
        total_count = df.count()
        for schema_field in df.schema.fields:  # avoid shadowing dataclasses.field
            column_name = schema_field.name
            null_count = df.select(count(when(col(column_name).isNull(), 1)).alias("n")).collect()[0]["n"]
            distinct_count = df.select(countDistinct(col(column_name)).alias("d")).collect()[0]["d"]
            rows.append({
                "source": source_name,
                "column_name": column_name,
                "data_type": str(schema_field.dataType),
                "null_count": int(null_count),
                "distinct_count": int(distinct_count),
                "total_count": int(total_count),
                "null_pct": round((int(null_count) / max(total_count, 1)) * 100, 2),
            })
        return pd.DataFrame(rows)

    def compare_column_schemas(self) -> Dict[str, List[str]]:
        df_source_1, df_source_2 = self._require_loaded_data()
        source_1_cols = set(df_source_1.columns)
        source_2_cols = set(df_source_2.columns)
        return {
            "only_in_source_1": sorted(source_1_cols - source_2_cols),
            "only_in_source_2": sorted(source_2_cols - source_1_cols),
            "common": sorted(source_1_cols & source_2_cols),
        }

    def highlight_high_null_columns(self, threshold_pct: float = 80.0) -> Dict[str, pd.DataFrame]:
        if "column_analysis" not in self.results:
            df_source_1, df_source_2 = self._require_loaded_data()
            self.results["column_analysis"] = self.analyze_columns(df_source_1, df_source_2)

        column_analysis = self.results["column_analysis"]
        source_1_high = column_analysis[column_analysis["null_pct_source_1"] >= threshold_pct][[
            "column_name", "null_pct_source_1", "null_count_source_1", "total_count_source_1"
        ]]
        source_2_high = column_analysis[column_analysis["null_pct_source_2"] >= threshold_pct][[
            "column_name", "null_pct_source_2", "null_count_source_2", "total_count_source_2"
        ]]

        result = {
            "source_1_high_nulls": source_1_high,
            "source_2_high_nulls": source_2_high,
        }
        self.results["high_null_analysis"] = result
        return result

    def run_full_analysis(self) -> Dict:
        """Execute baseline comparison workflow and store all generated artifacts."""
        self.load_source_1_data()
        self.load_source_2_data()

        # Baseline count and column-level health metrics.
        self.results.update(self.compare_counts())
        df_source_1, df_source_2 = self._require_loaded_data()
        self.results["column_analysis"] = self.analyze_columns(df_source_1, df_source_2)

        # Date-level record volume comparison for each configured datetime column.
        for datetime_col in self.config.datetime_columns:
            date_analysis = self.analyze_by_date(
                df_source_1, df_source_2, self.config.count_key, datetime_col
            )
            self.results[f"date_analysis_{datetime_col}"] = date_analysis
        if self.config.datetime_columns:
            first_datetime = self.config.datetime_columns[0]
            self.results["date_analysis"] = self.results.get(f"date_analysis_{first_datetime}")

        # Key-level anti-join checks to find records absent in each source.
        missing_in_source_2, missing_in_source_1 = self.find_missing_records(self.config.primary_key)
        self.results["missing_in_source_2"] = missing_in_source_2
        self.results["missing_in_source_1"] = missing_in_source_1

        # Date-level breakdown of the missing keys for each configured datetime column.
        for datetime_col in self.config.datetime_columns:
            missing_by_date = self.analyze_missing_by_date(self.config.primary_key, datetime_col)
            self.results[f"missing_by_date_{datetime_col}"] = missing_by_date
        if self.config.datetime_columns:
            first_datetime = self.config.datetime_columns[0]
            self.results["missing_by_date"] = self.results.get(f"missing_by_date_{first_datetime}")

        # Per-source null analysis by date for each configured datetime column.
        for datetime_col in self.config.datetime_columns:
            nulls_source_1 = self.analyze_nulls_by_date(
                df_source_1, self.config.primary_key, datetime_col, "source_1"
            )
            nulls_source_2 = self.analyze_nulls_by_date(
                df_source_2, self.config.primary_key, datetime_col, "source_2"
            )
            self.results[f"nulls_by_date_{datetime_col}"] = {
                "source_1": nulls_source_1,
                "source_2": nulls_source_2,
            }

        return self.results

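    def display_summary(self):
        """Print headline counts; requires results from run_full_analysis()."""
        required_keys = (
            "source_1_count", "source_2_count", "count_difference",
            "missing_in_source_2", "missing_in_source_1",
        )
        # Guard against partial results (for example, only enhanced analysis was run).
        if any(key not in self.results for key in required_keys):
            print("No results found. Run full analysis first.")
            return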

        print("=" * 80)
        print(f"Summary: {self.config.domain_name}")
        print("=" * 80)
        print(f"source_1 table: {self.config.source_1_table}")
        print(f"source_2 table: {self.config.source_2_table}")
        print(f"source_1 count: {self.results['source_1_count']:,}")
        print(f"source_2 count: {self.results['source_2_count']:,}")
        print(f"difference: {self.results['count_difference']:,}")
        print(f"missing in source_2: {self.results['missing_in_source_2'].count():,}")
        print(f"missing in source_1: {self.results['missing_in_source_1'].count():,}")

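    def _cast_datetime_columns_to_string(self, df: DataFrame) -> DataFrame:
        """Cast date and timestamp columns to strings before export."""
        # Named schema_field to avoid shadowing the imported dataclasses.field.
        for schema_field in df.schema.fields:
            dtype = str(schema_field.dataType).lower()
            if "date" in dtype or "timestamp" in dtype:
                df = df.withColumn(schema_field.name, col(schema_field.name).cast("string"))
        return df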

    def export_results(self):
        """Export the column analysis to Excel and CSV, plus both missing-record sets to Excel."""
        if "column_analysis" not in self.results:
            raise ValueError("No analysis results available. Run full analysis first.")

        self.results["column_analysis"].to_excel(self.config.excel_path, index=False)
        self.results["column_analysis"].to_csv(self.config.csv_path, index=False)

        missing_source_2_path = os.path.join(
            self.config.output_base_path,
            f"{self.config.domain_name}_missing_in_source_2.xlsx",
        )
        missing_source_1_path = os.path.join(
            self.config.output_base_path,
            f"{self.config.domain_name}_missing_in_source_1.xlsx",
        )

        missing_source_2 = self._cast_datetime_columns_to_string(self.results["missing_in_source_2"])
        missing_source_1 = self._cast_datetime_columns_to_string(self.results["missing_in_source_1"])
        missing_source_2.toPandas().to_excel(missing_source_2_path, index=False)
        missing_source_1.toPandas().to_excel(missing_source_1_path, index=False)

        print(f"Saved: {self.config.excel_path}")
        print(f"Saved: {self.config.csv_path}")
        print(f"Saved: {missing_source_2_path}")
        print(f"Saved: {missing_source_1_path}")

    def run_enhanced_analysis(self) -> Dict:
        """Run datatype-aware column analysis, schema comparison, and high-null detection."""
        df_source_1, df_source_2 = self._require_loaded_data()

        source_1_enhanced = self.analyze_columns_with_datatypes(df_source_1, "source_1")
        source_2_enhanced = self.analyze_columns_with_datatypes(df_source_2, "source_2")
        schema_comparison = self.compare_column_schemas()
        high_null_analysis = self.highlight_high_null_columns()

        enhanced = {
            "source_1_enhanced_analysis": source_1_enhanced,
            "source_2_enhanced_analysis": source_2_enhanced,
            "schema_comparison": schema_comparison,
            "high_null_analysis": high_null_analysis,
        }
        self.results.update(enhanced)
        return enhanced

    def export_enhanced_to_excel(self, output_path: Optional[str] = None):
        if "schema_comparison" not in self.results or "high_null_analysis" not in self.results:
            raise ValueError("Enhanced analysis not found. Run run_enhanced_analysis() first.")

        output_path = output_path or os.path.join(
            self.config.output_base_path, f"{self.config.domain_name}_enhanced.xlsx"
        )
        with pd.ExcelWriter(output_path, engine="openpyxl") as writer:
            self.results["source_1_enhanced_analysis"].to_excel(
                writer, sheet_name="Source1Columns", index=False
            )
            self.results["source_2_enhanced_analysis"].to_excel(
                writer, sheet_name="Source2Columns", index=False
            )
            pd.DataFrame(self.results["schema_comparison"]["only_in_source_1"], columns=["only_in_source_1"]).to_excel(
                writer, sheet_name="OnlyInSource1", index=False
            )
            pd.DataFrame(self.results["schema_comparison"]["only_in_source_2"], columns=["only_in_source_2"]).to_excel(
                writer, sheet_name="OnlyInSource2", index=False
            )
            pd.DataFrame(self.results["schema_comparison"]["common"], columns=["common"]).to_excel(
                writer, sheet_name="CommonColumns", index=False
            )
            self.results["high_null_analysis"]["source_1_high_nulls"].to_excel(
                writer, sheet_name="HighNullsSource1", index=False
            )
            self.results["high_null_analysis"]["source_2_high_nulls"].to_excel(
                writer, sheet_name="HighNullsSource2", index=False
            )

        print(f"Saved: {output_path}")

    def export_to_excel(self):
        """Export baseline and date-based outputs into a multi-sheet Excel workbook."""
        if "column_analysis" not in self.results:
            raise ValueError("No analysis results available. Run full analysis first.")

        with pd.ExcelWriter(self.config.excel_path, engine="openpyxl") as writer:
            self.results["column_analysis"].to_excel(writer, sheet_name="ColumnAnalysis", index=False)

            for datetime_col in self.config.datetime_columns:
                date_key = f"date_analysis_{datetime_col}"
                if date_key in self.results and not self.results[date_key].empty:
                    sheet = f"Dates_{datetime_col}"[:31]
                    self.results[date_key].to_excel(writer, sheet_name=sheet, index=False)

                missing_key = f"missing_by_date_{datetime_col}"
                if missing_key in self.results:
                    missing_by_date = self.results[missing_key]
                    if not missing_by_date["missing_in_source_2_by_date"].empty:
                        sheet = f"MissingSrc2_{datetime_col}"[:31]
                        missing_by_date["missing_in_source_2_by_date"].to_excel(writer, sheet_name=sheet, index=False)
                    if not missing_by_date["missing_in_source_1_by_date"].empty:
                        sheet = f"MissingSrc1_{datetime_col}"[:31]
                        missing_by_date["missing_in_source_1_by_date"].to_excel(writer, sheet_name=sheet, index=False)

                nulls_key = f"nulls_by_date_{datetime_col}"
                if nulls_key in self.results:
                    nulls_by_date = self.results[nulls_key]
                    if not nulls_by_date["source_1"].empty:
                        sheet = f"NullsSrc1_{datetime_col}"[:31]
                        nulls_by_date["source_1"].to_excel(writer, sheet_name=sheet, index=False)
                    if not nulls_by_date["source_2"].empty:
                        sheet = f"NullsSrc2_{datetime_col}"[:31]
                        nulls_by_date["source_2"].to_excel(writer, sheet_name=sheet, index=False)

        print(f"Saved: {self.config.excel_path}")

    def export_to_csv(self):
        if "column_analysis" not in self.results:
            raise ValueError("No analysis results available. Run full analysis first.")
        self.results["column_analysis"].to_csv(self.config.csv_path, index=False)
        print(f"Saved: {self.config.csv_path}")

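    def export_missing_records(self):
        """Export both missing-record sets to Excel.

        Both Spark DataFrames are collected to the driver with toPandas(), so very
        large result sets can exhaust driver memory.
        """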
        if "missing_in_source_2" not in self.results or "missing_in_source_1" not in self.results:
            raise ValueError("Missing-record results not found. Run full analysis first.")

        missing_source_2_path = os.path.join(
            self.config.output_base_path, f"{self.config.domain_name}_missing_in_source_2.xlsx"
        )
        missing_source_1_path = os.path.join(
            self.config.output_base_path, f"{self.config.domain_name}_missing_in_source_1.xlsx"
        )

        missing_source_2 = self._cast_datetime_columns_to_string(self.results["missing_in_source_2"])
        missing_source_1 = self._cast_datetime_columns_to_string(self.results["missing_in_source_1"])
        missing_source_2.toPandas().to_excel(missing_source_2_path, index=False)
        missing_source_1.toPandas().to_excel(missing_source_1_path, index=False)

        print(f"Saved: {missing_source_2_path}")
        print(f"Saved: {missing_source_1_path}")


## Batch Execution

Use `run_batch_comparison(...)` to process multiple `DataComparisonConfig` objects in one run.

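- Continues to the next dataset if one comparison fails; the failure is recorded under that `domain_name` as `{'error': ...}`.
- Optionally runs the enhanced schema and high-null analysis (`run_enhanced_analysis=True`; off by default).
- Optionally rechecks missing records after initial detection (`run_recheck`; on by default).
- Writes per-domain outputs under each config's `output_base_path` when `export_results=True` (the default).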
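
As a rough sketch of how the returned dictionary might be inspected, the helper below splits domains into completed and failed runs. It assumes the result shape built by `run_batch_comparison` in this notebook (an `'error'` key for failed domains, `'analyzer'` and `'results'` keys otherwise). `summarize_batch_results` is a hypothetical helper, and the sample dictionary is made-up illustration data, not real output.

```python
from typing import Dict, List, Tuple


def summarize_batch_results(all_results: Dict[str, dict]) -> Tuple[List[str], Dict[str, str]]:
    """Split batch output into completed domain names and failed-domain error messages."""
    completed = sorted(name for name, entry in all_results.items() if "error" not in entry)
    failed = {name: entry["error"] for name, entry in all_results.items() if "error" in entry}
    return completed, failed


# Illustrative stand-in for the dictionary returned by run_batch_comparison(...).
sample_results = {
    "orders": {"analyzer": None, "results": {"source_1_count": 4, "source_2_count": 3}},
    "customers": {"error": "Table or view not found"},
}

completed, failed = summarize_batch_results(sample_results)
print("Completed:", completed)
for domain, message in failed.items():
    print(f"Failed: {domain} -> {message}")
```

Because failures are stored as plain strings, a check like this is one way to decide whether a batch needs a rerun before the exported files are trusted.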


In [4]:
def run_batch_comparison(
    table_configs: List[DataComparisonConfig],
    export_results: bool = True,
    run_enhanced_analysis: bool = False,
    run_recheck: bool = True,
) -> Dict:
    """Run comparisons for multiple configs and return per-domain results."""
    all_results = {}

    print("#" * 80)
    print(f"Running batch for {len(table_configs)} table pairs")
    print("#" * 80)

    for index, config in enumerate(table_configs, start=1):
        print("=" * 80)
        print(f"[{index}/{len(table_configs)}] {config.domain_name}")
        print("=" * 80)

        try:
            analyzer = DataComparisonAnalyzer(config)
            results = analyzer.run_full_analysis()

            if run_recheck:
                analyzer.recheck_missing_records(config.primary_key)

            if run_enhanced_analysis:
                enhanced_results = analyzer.run_enhanced_analysis()
                results['enhanced'] = enhanced_results

            analyzer.display_summary()

            if export_results:
                analyzer.export_to_excel()
                analyzer.export_to_csv()
                analyzer.export_missing_records()
                if run_enhanced_analysis:
                    analyzer.export_enhanced_to_excel()

            all_results[config.domain_name] = {
                'analyzer': analyzer,
                'results': results,
            }
            print(f"Completed: {config.domain_name}")
        except Exception as exc:
            # Keep batch execution resilient: capture error and continue.
            all_results[config.domain_name] = {'error': str(exc)}
            print(f"Failed: {config.domain_name} -> {exc}")

    return all_results


## Example Usage

Use the template below for real sources (S3 or catalog tables).

`run_full_analysis()` computes metrics in memory. To save files under `output_base_path`, call the export methods.

```python
config = DataComparisonConfig(
    source_1_table="s3a://your-bucket/source_1/path/",
    source_2_table="s3a://your-bucket/source_2/path/",
    primary_key="id",
    domain_name="example_dataset",
    count_key="id",
    datetime_columns=["your_datetime_column"],
    filter_date_start="20250101",
    filter_date_end="20250131",
    partition_column=None,
    aws_region="us-east-1",
    aws_access_key_id=None,
    aws_secret_access_key=None,
    aws_session_token=None,
    output_base_path="./outputs",
)

analyzer = DataComparisonAnalyzer(config)
results = analyzer.run_full_analysis()
analyzer.display_summary()
analyzer.export_to_csv()
analyzer.export_to_excel()
analyzer.export_missing_records()
```
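
The export methods write with pandas, which does not create missing parent directories. If the configuration does not already create `output_base_path` (that part of the config is not shown here), a small guard such as the hypothetical `ensure_output_dir` helper below could be called before exporting. This is a sketch for local filesystem paths only; `dbfs:/` URIs would need different handling.

```python
import os
import tempfile


def ensure_output_dir(output_base_path: str) -> str:
    """Create the output directory if needed and return its absolute path."""
    os.makedirs(output_base_path, exist_ok=True)  # No error if it already exists.
    return os.path.abspath(output_base_path)


# Demonstrated against a temporary directory so the example leaves no files behind.
with tempfile.TemporaryDirectory() as tmp_root:
    target = ensure_output_dir(os.path.join(tmp_root, "outputs"))
    created = os.path.isdir(target)
    # Calling it a second time is safe because of exist_ok=True.
    ensure_output_dir(target)

print("Directory created:", created)
```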


In [5]:
# Uncomment and adjust for a real dataset run.
# config = DataComparisonConfig(
#     source_1_table="s3a://your-bucket/source_1/path/",
#     source_2_table="s3a://your-bucket/source_2/path/",
#     primary_key="id",
#     domain_name="example_dataset",
#     count_key="id",
#     datetime_columns=["your_datetime_column"],
#     filter_date_start="20250101",
#     filter_date_end="20250131",
#     partition_column=None,
#     aws_region="us-east-1",
#     aws_access_key_id=None,
#     aws_secret_access_key=None,
#     aws_session_token=None,
#     output_base_path="./outputs",
# )
#
# analyzer = DataComparisonAnalyzer(config)
# results = analyzer.run_full_analysis()
# analyzer.display_summary()
# analyzer.export_to_excel()
# analyzer.export_to_csv()
# analyzer.export_missing_records()


In [8]:
# Local smoke test with in-memory sample data.
# This verifies the analyzer workflow without external dependencies.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("TestApp")
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .config("spark.connect.enabled", "false")  # Avoid Spark Connect behavior in local mode
    .getOrCreate()
)

source_1_df = spark.createDataFrame([
    (1, "2025-01-10", "A"),
    (2, "2025-01-11", "B"),
    (3, "2025-01-12", None),
    (4, "2025-01-13", "D"),
], ["id", "event_ts", "val"])

source_2_df = spark.createDataFrame([
    (1, "2025-01-10", "A"),
    (2, "2025-01-11", None),
    (5, "2025-01-14", "E"),
], ["id", "event_ts", "val"])

source_1_df.createOrReplaceTempView("source_1_test")
source_2_df.createOrReplaceTempView("source_2_test")

config = DataComparisonConfig(
    source_1_table="source_1_test",
    source_2_table="source_2_test",
    primary_key="id",
    count_key="id",
    domain_name="smoke_test",
    datetime_columns=["event_ts"],
    filter_date_start="20250101",
    filter_date_end="20250131",
    output_base_path="./outputs",
    enable_persist=False,  # Keep disabled in local-only smoke tests
)

analyzer = DataComparisonAnalyzer(config)
results = analyzer.run_full_analysis()
analyzer.display_summary()

# Persist outputs to disk.
analyzer.export_to_csv()
analyzer.export_to_excel()
analyzer.export_missing_records()

results["column_analysis"].head()


Loading source_1: source_1_test
source_1 rows: 4
Loading source_2: source_2_test
source_2 rows: 3
Summary: smoke_test
source_1 table: source_1_test
source_2 table: source_2_test
source_1 count: 4
source_2 count: 3
difference: 1
missing in source_2: 2
missing in source_1: 1
Saved: ./outputs\smoke_test_comparison.csv
Saved: ./outputs\smoke_test_comparison.xlsx
Saved: ./outputs\smoke_test_missing_in_source_2.xlsx
Saved: ./outputs\smoke_test_missing_in_source_1.xlsx


|   | column_name | null_count_source_1 | null_count_source_2 | distinct_count_source_1 | distinct_count_source_2 | total_count_source_1 | total_count_source_2 | null_pct_source_1 | null_pct_source_2 |
|---|---|---|---|---|---|---|---|---|---|
| 0 | event_ts | 0 | 0 | 4 | 3 | 4 | 3 | 0.0 | 0.0 |
| 1 | id | 0 | 0 | 4 | 3 | 4 | 3 | 0.0 | 0.0 |
| 2 | val | 1 | 1 | 3 | 2 | 4 | 3 | 25.0 | 33.33 |
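
The missing-record counts in the summary can be cross-checked by hand. The sketch below uses plain Python sets on the `id` values from the two sample DataFrames. It only illustrates the idea behind the key-level anti-join checks and is not the analyzer's actual Spark implementation.

```python
# id values taken from the smoke-test sample rows above.
source_1_ids = {1, 2, 3, 4}
source_2_ids = {1, 2, 5}

# Keys present in one source but absent from the other (the anti-join idea).
missing_in_source_2 = sorted(source_1_ids - source_2_ids)
missing_in_source_1 = sorted(source_2_ids - source_1_ids)

print(f"missing in source_2: {len(missing_in_source_2)} {missing_in_source_2}")
print(f"missing in source_1: {len(missing_in_source_1)} {missing_in_source_1}")
```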


In [None]:
%conda install openpyxl
