## Imports and Notebook Runs

In [None]:
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any
from datetime import datetime
from pyspark.sql import DataFrame, SparkSession, Column
from pyspark.sql import functions as F
from pandas import DataFrame as pd_DataFrame
# Frequently used functions are also imported by name (in addition to `F`) for convenience
from pyspark.sql.functions import col, lit 

StatementMeta(, 099b5ced-d358-4fcb-b3a5-5180b5324e70, 3, Finished, Available, Finished)

## **Core Test Suite Logic**

### Helper Classes and Models

In [2]:
class TestResult:
    """Record of a single test outcome.

    Not enforced immutable; treat instances as read-only once built.

    Attributes:
        test_name / table_name: identify which check ran against which table.
        passed: overall verdict.
        error_count / total_count: failing vs. checked units. Both are -1 when the
            test crashed (sentinel used by TestResultBuilder.exception and TestRunner).
        success_rate: (total - errors) / total; 1.0 when nothing was checked and
            0.0 when the counts are the negative "crashed" sentinel.
        details: human-readable summary.
        execution_time_seconds: wall-clock duration of the test.
        timestamp: creation time unless one is supplied.
        sample_mode: whether the input DataFrame was a property sample.
    """
    def __init__(self, test_name: str, table_name: str, passed: bool,
                 error_count: int = 0, total_count: int = 0,
                 details: str = "", execution_time: float = 0.0,
                 sample_mode: bool = False, timestamp: Optional[datetime] = None, example_failures: Optional[pd_DataFrame] = None):
        self.test_name = test_name
        self.table_name = table_name
        self.passed = passed
        self.error_count = error_count
        self.total_count = total_count
        if total_count > 0:
            self.success_rate = (total_count - error_count) / total_count
        elif total_count == 0:
            # Nothing was checked, so nothing failed.
            self.success_rate = 1.0
        else:
            # Negative counts mean the test crashed; that must not read as 100% success.
            self.success_rate = 0.0
        self.details = details
        self.execution_time_seconds = execution_time
        self.timestamp = timestamp or datetime.now()
        self.sample_mode = sample_mode
        # Small pandas sample of failing rows (see TestResultBuilder.cache_failure), or None.
        self._example_failures = example_failures

    def to_dict(self) -> Dict:
        """Return a flat dict of the result, suitable for building a Spark/pandas row."""
        return {
            'test_name': self.test_name,
            'table_name': self.table_name,
            'passed': self.passed,
            'error_count': self.error_count,
            'total_count': self.total_count,
            'success_rate': self.success_rate,
            'details': self.details,
            'execution_time_seconds': self.execution_time_seconds,
            'timestamp': self.timestamp,
            'sample_mode': self.sample_mode,
            # serialize to records for Spark/JSON friendliness
            "failures": (
                self._example_failures.to_dict(orient="records")
                if self._example_failures is not None else None
            ),
        }

class TestResultBuilder:
    """Builder for creating TestResult objects incrementally.

    Typical flow: ``TestResultBuilder.for_test(...)`` (which also starts the timer),
    record the outcome with ``auto_pass_fail`` / ``success`` / ``failure`` /
    ``exception``, optionally attach failing rows with ``cache_failure``, then
    ``end_timing().build()`` (or ``auto_build`` for the common case).
    """

    def __init__(self):
        self._test_name: Optional[str] = None
        self._table_name: Optional[str] = None
        self._passed: Optional[bool] = None
        self._error_count: int = 0
        self._total_count: int = 0
        self._details: str = ""
        self._execution_time: float = 0.0
        self._sample_mode: bool = False
        self._start_time: Optional[datetime] = None
        self._example_failures: Optional[pd_DataFrame] = None

    def test_name(self, name: str) -> 'TestResultBuilder':
        """Set the test name"""
        self._test_name = name
        return self

    def table_name(self, name: str) -> 'TestResultBuilder':
        """Set the table name"""
        self._table_name = name
        return self

    def start_timing(self) -> 'TestResultBuilder':
        """Start timing the test execution"""
        self._start_time = datetime.now()
        return self

    def end_timing(self) -> 'TestResultBuilder':
        """End timing and calculate execution time (no-op if timing was never started)"""
        if self._start_time:
            self._execution_time = (datetime.now() - self._start_time).total_seconds()
        return self

    def counts(self, total: int, errors: int = 0) -> 'TestResultBuilder':
        """Set total and error counts"""
        self._total_count = total
        self._error_count = errors
        return self

    def passed(self, is_passed: bool) -> 'TestResultBuilder':
        """Set whether the test passed"""
        self._passed = is_passed
        return self

    def details(self, details: str) -> 'TestResultBuilder':
        """Set test details/description"""
        self._details = details
        return self

    def sample_mode(self, is_sample: bool) -> 'TestResultBuilder':
        """Set whether test was run in sample mode"""
        self._sample_mode = is_sample
        return self

    def cache_failure(
        self,
        failures_df: DataFrame,
        limit_rows: int = 5,
        sample_cols: Optional[list[str]] = None,
        max_bytes: int = 256_000,  # ~250 KB guard
    ) -> "TestResultBuilder":
        """Materialize a tiny, driver-safe sample of failing rows as pandas.

        Args:
            failures_df: Spark DataFrame containing the failing rows.
            limit_rows: maximum number of rows pulled to the driver.
            sample_cols: optional subset of columns to keep. Names not present in
                ``failures_df`` are ignored; if none are present, all columns are kept.
            max_bytes: approximate cap on the pandas sample's memory usage. Columns,
                then rows, are halved until the sample fits or cannot shrink further.
        """
        df = failures_df
        if sample_cols:
            # only keep requested columns if present
            keep = [c for c in sample_cols if c in df.columns]
            if keep:
                df = df.select(*keep)
        pdf = df.limit(limit_rows).toPandas()

        def _size(frame: pd_DataFrame) -> int:
            return int(frame.memory_usage(index=True, deep=True).sum())

        # Halving once is not guaranteed to get under the cap, so keep trimming:
        # first columns (down to one), then rows (down to one).
        while _size(pdf) > max_bytes and len(pdf.columns) > 1:
            pdf = pdf.iloc[:, : max(1, len(pdf.columns) // 2)]
        while _size(pdf) > max_bytes and len(pdf) > 1:
            pdf = pdf.iloc[: max(1, len(pdf) // 2)]

        self._example_failures = pdf
        return self

    def success(self, total: int, details: str = "") -> 'TestResultBuilder':
        """Convenience method for successful test"""
        return self.counts(total, 0).passed(True).details(details)

    def failure(self, total: int, errors: int, details: str = "") -> 'TestResultBuilder':
        """Convenience method for failed test"""
        return self.counts(total, errors).passed(False).details(details)

    def exception(self, error: Exception, details: str = "") -> 'TestResultBuilder':
        """Convenience method for test that threw exception.

        Marks the test failed and sets both counts to the -1 "crashed" sentinel.
        """
        error_details = f"{details} Exception: {str(error)}" if details else f"Exception: {str(error)}"
        return self.counts(-1, -1).passed(False).details(error_details)

    def auto_pass_fail(self, total: int, errors: int, details: str = "") -> 'TestResultBuilder':
        """Automatically set pass/fail based on error count (passes only when errors == 0)"""
        is_passed = errors == 0
        return self.counts(total, errors).passed(is_passed).details(details)

    def auto_build(self, total: int, errors: int, details: str = "") -> TestResult:
        """Automatically set pass/fail, stop the timer and build the test result"""
        return self.auto_pass_fail(total, errors, details).end_timing().build()

    @classmethod
    def for_test(cls, test_name: str, table_name: str, sample_mode: bool = False) -> 'TestResultBuilder':
        """Factory method: create a builder with name/table/sample mode set and the timer started"""
        return (cls()
                .test_name(test_name)
                .table_name(table_name)
                .sample_mode(sample_mode)
                .start_timing())

    def build(self) -> TestResult:
        """Build the final test result object.

        Raises:
            ValueError: if the test name, table name, or pass/fail verdict was never set.
        """
        if self._test_name is None:
            raise ValueError("Test name is required")
        if self._table_name is None:
            raise ValueError("Table name is required")
        if self._passed is None:
            raise ValueError("Test result (passed/failed) is required")

        return TestResult(
            test_name=self._test_name,
            table_name=self._table_name,
            passed=self._passed,
            error_count=self._error_count,
            total_count=self._total_count,
            details=self._details,
            execution_time=self._execution_time,
            sample_mode=self._sample_mode,
            example_failures=self._example_failures
        )

StatementMeta(, 099b5ced-d358-4fcb-b3a5-5180b5324e70, 4, Finished, Available, Finished)

In [3]:
class TestStrategy(ABC):
    """Abstract base class for a single data-quality check.

    How the pieces interact:
      * ``TestRunner.run_all_tests`` loads each queued table through ``DataProvider``
        and calls ``execute(df, table_name=..., sample_mode=..., **kwargs)``.
      * A concrete strategy creates a ``TestResultBuilder`` (usually via
        ``TestResultBuilder.for_test``), records counts/details on it, and returns
        ``builder.build()``, which produces a ``TestResult``.
      * ``TestRunner`` collects the ``TestResult`` objects and prints a summary.
    """

    @property
    @abstractmethod
    def test_name(self) -> str:
        """Human-readable name of this test, used in runner output and results."""
        ...

    @abstractmethod
    def execute(self, df: DataFrame, *args, **kwargs) -> TestResult:
        """Run the check against ``df`` and report the outcome as a ``TestResult``."""
        ...

class DataProvider:
    """Handles data access and filtering logic.

    Reads tables from ``lakehouse_name``. When sample mode is enabled,
    ``get_dataframe`` restricts each table to the randomly chosen property UUIDs
    held in ``sample_properties``.
    """

    def __init__(self, spark: SparkSession, lakehouse_name: str):
        self.spark = spark
        self.lakehouse_name = lakehouse_name
        self.sample_mode = False
        self.sample_properties = []

    def enable_sample_mode(self, sample_size: int = 10, random_seed: Optional[int] = None) -> List[str]:
        """Enable sample mode by randomly selecting properties from DIM_PROPERTY.

        Args:
            sample_size: number of properties to sample; must be >= 1. If it is at
                least the number of rows in DIM_PROPERTY, every property is used.
            random_seed: optional seed for a reproducible sample.

        Returns:
            The sampled Property_UUID values (also stored in ``self.sample_properties``).

        Raises:
            ValueError: if ``sample_size`` is less than 1.
        """
        if sample_size < 1:
            # An empty sample makes get_dataframe return full tables while the
            # runner still reports "SAMPLE" mode, so reject it up front.
            raise ValueError(f"sample_size must be a positive integer, got {sample_size}")

        print(f'Enabling sample mode with {sample_size} properties...')

        property_table = f'{self.lakehouse_name}.DIM_PROPERTY'
        total_properties = self.spark.sql(f'SELECT COUNT(*) as cnt FROM {property_table}').collect()[0]['cnt']
        properties_df = self.spark.sql(f'SELECT Property_UUID FROM {property_table}')

        if sample_size >= total_properties:
            print(f'Requested sample size ({sample_size}) >= total properties ({total_properties}). Using all properties.')
        else:
            # Shuffle and take exactly sample_size rows; F.rand(None) is unseeded,
            # so one call covers both the seeded and unseeded cases.
            properties_df = properties_df.orderBy(F.rand(random_seed)).limit(sample_size)

        self.sample_properties = [row['Property_UUID'] for row in properties_df.collect()]
        self.sample_mode = True

        print(f'Sample mode enabled with {len(self.sample_properties)} properties')
        return self.sample_properties

    def disable_sample_mode(self):
        """Disables sample mode and forgets the sampled properties"""
        self.sample_mode = False
        self.sample_properties = []
        print("Sample mode disabled - tests will run on full dataset")

    def get_dataframe(self, table_name: str, filter_column: str = "PROPERTY_UUID", disable_sampling: bool = False) -> DataFrame:
        """Load ``lakehouse_name.table_name``, filtered to the sampled properties when sampling applies.

        The full table is returned when ``disable_sampling`` is set, sample mode is
        off, no properties were sampled, or the table has no ``filter_column``.
        The column lookup prefers an exact match and otherwise falls back to a
        case-insensitive match, in line with Spark's default case-insensitive name resolution.
        """
        df = self.spark.table(f"{self.lakehouse_name}.{table_name}")
        if disable_sampling or not self.sample_mode or not self.sample_properties:
            return df

        if filter_column in df.columns:
            resolved_column = filter_column
        else:
            # Tables here mix casings (e.g. "Property_UUID" vs the default "PROPERTY_UUID").
            wanted = filter_column.lower()
            resolved_column = next((c for c in df.columns if c.lower() == wanted), None)

        if resolved_column is None:
            print(f"⚠️  Sampling disabled for {table_name}: no '{filter_column}' column.")
            return df
        return df.filter(F.col(resolved_column).isin(self.sample_properties))

StatementMeta(, 099b5ced-d358-4fcb-b3a5-5180b5324e70, 5, Finished, Available, Finished)

### `TestRunner` object: manages the test queue and pretty-prints results

In [None]:
class TestRunner:
    """Orchestrates test execution and result collection.

    Tests are queued with ``add_test`` and executed by ``run_all_tests``. That method
    loads each table through the DataProvider (so sample mode applies), collects
    every strategy's TestResult, and prints a summary.
    """

    def __init__(self, data_provider: DataProvider):
        self.data_provider = data_provider
        self.test_results: List[TestResult] = []
        self.queued_tests: List[Dict[str, Any]] = []

    def add_test(self, test_strategy: TestStrategy, table_name: str,
                 filter_column: str = "PROPERTY_UUID", disable_sampling: bool = False, **test_kwargs):
        """Add a test to the execution queue.

        Args:
            test_strategy: the check to run.
            table_name: table (within the data provider's lakehouse) to run it on.
            filter_column: column used to restrict the table to sampled properties.
            disable_sampling: always run this test on the full table.
            **test_kwargs: extra keyword arguments forwarded to ``test_strategy.execute``.
        """
        test_config = {
            'strategy': test_strategy,
            'table_name': table_name,
            'filter_column': filter_column,
            'disable_sampling': disable_sampling,
            'kwargs': test_kwargs
        }
        self.queued_tests.append(test_config)
        print(f'Added test: {test_strategy.test_name} for {table_name} to queue')

    def clear_queue(self):
        """Clear the test queue and any results from a previous run"""
        self.queued_tests = []
        self.test_results = []
        print('Test queue cleared')

    def run_all_tests(self) -> List[TestResult]:
        """Execute all queued tests in order and return their results.

        An exception escaping a strategy (e.g. a missing required column) is
        recorded as a failed result with -1 counts, so one broken test does not
        stop the rest of the suite.
        """
        print("="*60)
        print("STARTING LAKEHOUSE DATA TESTING SUITE")
        print(f"Mode: {'SAMPLE' if self.data_provider.sample_mode else 'FULL DATASET'}")
        if self.data_provider.sample_mode:
            print(f"Sample Properties: {len(self.data_provider.sample_properties)}")
        print(f"Queued Tests: {len(self.queued_tests)}")
        print("="*60)

        self.test_results = []

        for test_config in self.queued_tests:
            # Read these outside the try so the error path always refers to this test.
            strategy = test_config['strategy']
            table_name = test_config['table_name']
            started_at = datetime.now()
            try:
                print(f"\n🔄 Running {strategy.test_name} on {table_name}...")

                df = self.data_provider.get_dataframe(
                    table_name, test_config['filter_column'], test_config['disable_sampling']
                )
                result = strategy.execute(
                    df,
                    table_name=table_name,
                    sample_mode=self.data_provider.sample_mode,
                    **test_config['kwargs']
                )

                # Read .passed before appending so an invalid return value is
                # reported as a critical error instead of corrupting test_results.
                status = "✅ PASSED" if result.passed else "❌ FAILED"
                self.test_results.append(result)
                print(f"{status} - {result.details}")

            except Exception as e:
                elapsed = (datetime.now() - started_at).total_seconds()
                print(f"❌ Critical error running {strategy.test_name}: {str(e)}")
                self.test_results.append(TestResult(
                    test_name=strategy.test_name,
                    table_name=table_name,
                    passed=False,
                    error_count=-1,
                    total_count=-1,
                    details=f"Critical error: {str(e)}",
                    execution_time=elapsed,
                    sample_mode=self.data_provider.sample_mode,
                ))

        self._print_summary()
        return self.test_results

    def _print_summary(self):
        """Print test execution summary"""
        print("\n" + "="*60)
        print("TEST SUMMARY")
        print("="*60)

        if not self.test_results:
            print("No tests were run.")
            return

        total_tests = len(self.test_results)
        passed_tests = sum(1 for result in self.test_results if result.passed)
        failed_tests = total_tests - passed_tests

        print(f"Total Tests: {total_tests}")
        print(f"Passed: {passed_tests}")
        print(f"Failed: {failed_tests}")
        print(f"Success Rate: {(passed_tests/total_tests)*100:.1f}%")
        print()

        for result in self.test_results:
            status = "✅ PASS" if result.passed else "❌ FAIL"
            print(f"{status:8} | {result.execution_time_seconds:6.2f}s | {result.test_name:40} | {result.details}")

        if failed_tests > 0:
            print(f"\n⚠️  {failed_tests} test(s) failed. Review details above.")
        else:
            print("\n🎉 All tests passed!")

def create_test_suite(spark_session: SparkSession, lakehouse_name: str) -> TestRunner:
    """Create a TestRunner backed by a DataProvider for ``lakehouse_name``.

    Args:
        spark_session: Spark session used to read lakehouse tables.
        lakehouse_name: lakehouse prefix used when building table names.

    Returns:
        A TestRunner with an empty queue whose DataProvider starts with sample mode off.
    """
    # Use the session passed in rather than the notebook-global `spark`.
    data_provider = DataProvider(spark_session, lakehouse_name)
    return TestRunner(data_provider)

StatementMeta(, 099b5ced-d358-4fcb-b3a5-5180b5324e70, 6, Finished, Available, Finished)

### **Modifiable**: Individual Test Strategy Definitions

In [9]:
def require_columns(df: DataFrame, cols: list[str | Column], where: str):
    """
    Verify that all required columns exist in df.

    Args:
        df: Spark DataFrame
        cols: list of column names or Column objects. Column objects are reduced to
            their expression text, so pass plain column references (e.g. F.col("x")).
        where: string describing where the check is happening (for error context)

    Raises:
        TypeError if an entry is neither a str nor a Column
        ValueError if any required columns are missing
    """
    repr_prefix, repr_suffix = "Column<'", "'>"

    def _column_name(c) -> str:
        """Normalize a str or Column into the name to look up in df.columns."""
        if isinstance(c, str):
            return c
        if isinstance(c, Column):
            if hasattr(c, "_jc"):
                return c._jc.toString()
            # Without _jc the repr looks like "Column<'name'>", which never matches
            # df.columns; unwrap it to the bare expression text.
            text = str(c)
            if text.startswith(repr_prefix) and text.endswith(repr_suffix):
                return text[len(repr_prefix):-len(repr_suffix)]
            return text
        raise TypeError(f"Unsupported column type {type(c)} in require_columns")

    normalized = [_column_name(c) for c in cols]

    missing = [c for c in normalized if c not in df.columns]
    if missing:
        raise ValueError(
            f"{where}: missing columns {missing}. "
            f"Available: {df.columns}"
        )

class EffectiveDateValidityTest(TestStrategy):
    """Test that a given start date is less than or equal to a given end date.

    Used when intending to explode a sequence of date intervals. Rows where either
    date is NULL are not counted as invalid, because a comparison with NULL is never true.
    """

    def __init__(self, id_col: str, start_col: str, end_col: str):
        self.id_col = id_col
        self.start_col = start_col
        self.end_col = end_col

    @property
    def test_name(self) -> str:
        return f"Date Validity ({self.start_col} <= {self.end_col})"

    def execute(self, df: DataFrame, table_name: str, **kwargs) -> TestResult:
        """Execute an effective date validity test.

        Args:
            df: table to check; must contain id_col, start_col and end_col.
            table_name (str): The name of the table in the lakehouse.
            **kwargs: ``sample_mode`` (bool, default False) records whether ``df``
                is a random property sample.

        Returns:
            TestResult that passes when no row has start > end. Up to 5 offending rows
            are shown and attached to the result.

        Raises:
            ValueError: if a required column is missing (raised before timing starts).
        """
        sample_mode = kwargs.get('sample_mode', False)
        require_columns(df, [self.id_col, self.start_col, self.end_col], where = self.test_name)

        builder = TestResultBuilder.for_test(self.test_name, table_name, sample_mode)

        try:
            total_count = df.count()
            invalid_df = df.filter(F.col(self.start_col) > F.col(self.end_col))
            error_count = invalid_df.count()

            details = f'Found {error_count} records with start_date > end_date out of {total_count} total records'

            if error_count > 0:
                sample_df = invalid_df.select(self.id_col, self.start_col, self.end_col)
                print("Sample of invalid records:")
                sample_df.show(5)
                # Attach the failing rows to the result, as UniquenessTest does.
                builder.cache_failure(sample_df)

            return builder.auto_pass_fail(total_count, error_count, details).end_timing().build()

        except Exception as e:
            return builder.exception(e, "Failed to validate effective dates").end_timing().build()

class UniquenessTest(TestStrategy):
    """Test whether each combination of id_cols is unique"""
    def __init__(self, id_cols: list[str]):
        if not id_cols:
            # groupBy() with no keys collapses the whole table into one group, which
            # would flag any table with more than one row as a "duplicate".
            raise ValueError("UniquenessTest requires at least one id column")
        self.id_cols = id_cols

    @property
    def test_name(self) -> str:
        return f"Unique Columns, Subset: {self.id_cols}"

    def execute(self, df: DataFrame, table_name: str, **kwargs) -> TestResult:
        """Count key combinations of ``id_cols`` that occur more than once.

        Args:
            df: table to check; must contain every column in ``id_cols``.
            table_name: name of the table in the lakehouse (used in the result).
            **kwargs: ``sample_mode`` (bool, default False).

        Returns:
            TestResult whose error_count is the number of duplicated key combinations
            (not duplicated rows) and whose total_count is the number of rows.

        Raises:
            ValueError: if a required column is missing (raised before timing starts).
        """
        sample_mode = kwargs.get('sample_mode', False)
        require_columns(df, self.id_cols, where = self.test_name)

        builder = TestResultBuilder.for_test(self.test_name, table_name, sample_mode)

        try:
            total_count = df.count()
            duplicates = df.groupBy(self.id_cols).count().filter(col('count') > 1)
            error_count = duplicates.count()

            details = f'Found {error_count} duplicated {self.id_cols} combinations out of {total_count} total rows'

            if error_count > 0:
                print(f'Samples with duplicate {self.id_cols}')
                duplicates.show(5)
                builder.cache_failure(duplicates)

            return builder.auto_pass_fail(total_count, error_count, details).end_timing().build()

        except Exception as e:
            return builder.exception(e, f"Failed to test {self.id_cols} uniqueness").end_timing().build()

class DateContinuityTest(TestStrategy):
    """Tests that dates are contiguous with no gaps, per ID.

    ``id_col`` may be a column name or a Column expression; pass e.g. ``lit(1)`` to
    treat the whole table as a single group.
    """
    def __init__(self, id_col: str | Column, date_col: str):
        self.id_col = id_col
        self.date_col = date_col

    @property
    def test_name(self) -> str:
        return f"Date Column Continuity {self.date_col} by {self.id_col}"

    def execute(self, df: DataFrame, table_name: str, **kwargs) -> TestResult:
        """Count IDs whose distinct dates do not cover every day from min to max date.

        Args:
            df: table to check; must contain ``date_col``.
            table_name: name of the table in the lakehouse (used in the result).
            **kwargs: ``sample_mode`` (bool, default False).

        Returns:
            TestResult with total_count = number of distinct IDs and error_count =
            IDs with at least one missing day. Up to 5 gap rows are shown and attached.

        Raises:
            ValueError: if ``date_col`` is missing (raised before timing starts).
        """
        sample_mode = kwargs.get('sample_mode', False)
        # ID column isn't required in the df, since we may choose to group by a literal
        require_columns(df, [self.date_col], where = self.test_name)

        builder = TestResultBuilder.for_test(self.test_name, table_name, sample_mode)

        try:
            id_alias = '_id'
            id_column = (F.col(self.id_col) if isinstance(self.id_col, str) else self.id_col).alias(id_alias)
            date_column = F.col(self.date_col) if isinstance(self.date_col, str) else self.date_col
            # Readable label for messages; Column expressions would print as "Column<'1'>".
            id_label = self.id_col if isinstance(self.id_col, str) else 'group'

            total_by_id = df.select(id_column).distinct().count()

            # NOTE(review): datediff counts calendar days, so this assumes date_col holds
            # dates rather than timestamps with time-of-day — confirm for new tables.
            unit_date_ranges = (df.groupBy(id_column)
                               .agg(
                                   F.min(date_column).alias('min_date'),
                                   F.max(date_column).alias('max_date'),
                                   # Distinct, so a repeated date cannot mask a missing one.
                                   F.countDistinct(date_column).alias('actual_count')
                               ))

            unit_expected = unit_date_ranges.withColumn(
                "expected_count",
                F.datediff(F.col('max_date'), F.col('min_date')) + 1
            )

            gaps_df = unit_expected.filter(F.col('actual_count') < F.col('expected_count'))
            error_count = gaps_df.count()

            details = f'Found {error_count} {id_label}s with date gaps out of {total_by_id} {id_label}s'

            if error_count > 0:
                print(f"Sample {id_label}s with gaps:")
                gaps_df.show(5)
                builder.cache_failure(gaps_df)

            return builder.auto_build(total_by_id, error_count, details)
        except Exception as e:
            return builder.exception(e, f"Failed to test date continuity for {self.date_col}").end_timing().build()


StatementMeta(, 099b5ced-d358-4fcb-b3a5-5180b5324e70, 11, Finished, Available, Finished)

## **Usage Examples**

In [None]:
# `spark` is supplied by the Fabric notebook environment.
test_runner = create_test_suite(spark, 'TRINITY_BI')
test_prop_ids = test_runner.data_provider.enable_sample_mode(10)

# --- Test strategy definitions ---
# Each unit event's start date must not be after its end date
effective_date_valid = EffectiveDateValidityTest(
    id_col='UNIT_SPACE_UUID',
    start_col='EFFECTIVE_START_DATE',
    end_col='EFFECTIVE_END_DATE',
)
# One daily row per unit per date
ued_date_uniquess = UniquenessTest(id_cols=['UNIT_SPACE_UUID', 'DATE'])
# Each unit's daily rows cover every date between its first and last date
ued_date_continuity = DateContinuityTest(id_col='UNIT_SPACE_UUID', date_col='DATE')
# lit(1) puts the whole date dimension in one group, so no ID column is required
dim_date_continuity = DateContinuityTest(id_col=lit(1), date_col='DATE_KEY')

# --- Queue the tests (run later with test_runner.run_all_tests()) ---
test_runner.add_test(effective_date_valid, table_name='FACT_UNIT_EVENT')
test_runner.add_test(ued_date_uniquess, table_name='FACT_UNIT_EVENT_DAILY')
test_runner.add_test(ued_date_continuity, table_name='FACT_UNIT_EVENT_DAILY')
# DIM_DATE has no property column, so always test the full table
test_runner.add_test(dim_date_continuity, table_name='DIM_DATE', disable_sampling=True)

StatementMeta(, 099b5ced-d358-4fcb-b3a5-5180b5324e70, 12, Finished, Available, Finished)

Enabling sample mode with 10 properties...
Sample mode enabled with 10 properties
Added test: Date Validity (EFFECTIVE_START_DATE <= EFFECTIVE_END_DATE) for FACT_UNIT_EVENT to queue
Added test: Unique Columns, Subset: ['UNIT_SPACE_UUID', 'DATE'] for FACT_UNIT_EVENT_DAILY to queue
Added test: Date Column Continuity DATE by UNIT_SPACE_UUID for FACT_UNIT_EVENT_DAILY to queue
Added test: Date Column Continuity DATE_KEY by Column<'1'> for DIM_DATE to queue


In [11]:
# NOTE: Tests are only queued here so they can be run from other notebooks.
# To run all queued tests, uncomment the lines below:
# results = test_runner.run_all_tests()
#resultsDf = spark.createDataFrame([r.to_dict() for r in results])

StatementMeta(, 099b5ced-d358-4fcb-b3a5-5180b5324e70, 13, Finished, Available, Finished)

Test queue cleared
Added test: Date Column Continuity DATE_KEY by Column<'1'> for DIM_DATE to queue
STARTING LAKEHOUSE DATA TESTING SUITE
Mode: SAMPLE
Sample Properties: 10
Queued Tests: 1

🔄 Running Date Column Continuity DATE_KEY by Column<'1'> on DIM_DATE...
⚠️  Sampling disabled for DIM_DATE: no 'PROPERTY_UUID' column.
✅ PASSED - Found 0 _ids with date gaps out of 1 Column<'1'>s

TEST SUMMARY
Total Tests: 1
Passed: 1
Failed: 0
Success Rate: 100.0%

✅ PASS   |   3.63s | Date Column Continuity DATE_KEY by Column<'1'> | Found 0 _ids with date gaps out of 1 Column<'1'>s

🎉 All tests passed!
