# In [None]:
import pytest
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, to_date, current_date

# Build (or reuse) the SparkSession used by the checks below
def get_spark_session():
    """Return a SparkSession whose application name is "DataQualityChecks".

    The session is obtained through ``SparkSession.builder.getOrCreate()``.
    """
    session_builder = SparkSession.builder.appName("DataQualityChecks")
    return session_builder.getOrCreate()

# Function to validate schema
def validate_schema(df, expected_schema):
    """Check that every expected ``(column, type)`` pair is present in ``df``.

    Args:
        df: Object exposing ``dtypes`` as an iterable of ``(name, type)``
            pairs.
        expected_schema: Iterable of ``(column, expected_type)`` pairs that
            must all appear in ``df.dtypes``.

    Raises:
        AssertionError: For the first expected pair that is not found in
            ``df.dtypes`` (column absent, or present with another type).
    """
    # Build the lookup once; membership in a set avoids rescanning the dtype
    # list for every expected column.
    actual_schema = set(df.dtypes)
    for column, expected_type in expected_schema:
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped when Python runs with ``-O``.
        if (column, expected_type) not in actual_schema:
            raise AssertionError(
                f"Column {column} with type {expected_type} is missing or mismatched."
            )
    print("Schema validation passed.")

# Function to check null values in critical columns
def check_nulls(df, columns):
    """Fail if any of the given columns contains a null value.

    Each column is checked in turn by filtering ``df`` on ``isNull()`` and
    counting the matching rows.

    Args:
        df: DataFrame to inspect.
        columns: Iterable of column names to check.

    Raises:
        AssertionError: For the first column whose null count is non-zero.
    """
    for col_name in columns:
        null_count = df.filter(df[col_name].isNull()).count()
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped when Python runs with ``-O``.
        if null_count != 0:
            raise AssertionError(f"Null values found in column: {col_name}")
    print("Null check passed for all specified columns.")

# Function to validate data types
def validate_data_types(df, column_data_types):
    """Check that each listed column exists in ``df`` with the expected type.

    Args:
        df: Object exposing ``dtypes`` as an iterable of ``(name, type)``
            pairs.
        column_data_types: Mapping of column name to expected type string.

    Raises:
        AssertionError: For the first column that is absent from
            ``df.dtypes`` or whose type differs from the expected one.
    """
    # Build the name -> type lookup once rather than on every iteration.
    actual_types = dict(df.dtypes)
    for column, expected_type in column_data_types.items():
        # A missing column is reported as a validation failure (like
        # validate_schema does) instead of leaking a bare KeyError.
        if column not in actual_types:
            raise AssertionError(
                f"Column {column} is missing, expected type {expected_type}."
            )
        actual_type = actual_types[column]
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped when Python runs with ``-O``.
        if actual_type != expected_type:
            raise AssertionError(
                f"Column {column} has type {actual_type}, expected {expected_type}."
            )
    print("Data type validation passed.")

# Function to check for duplicates in critical columns
def check_duplicates(df, column):
    """Fail if any value of ``column`` occurs in more than one row.

    Rows are grouped by ``column``; the number of groups whose row count
    exceeds one is the duplicate count.

    Args:
        df: DataFrame to inspect.
        column: Column to group by (passed straight to ``df.groupBy``).

    Raises:
        AssertionError: If at least one group has more than one row.
    """
    duplicate_count = df.groupBy(column).count().filter("count > 1").count()
    # Raised explicitly instead of via ``assert`` so the check is not
    # stripped when Python runs with ``-O``.
    if duplicate_count != 0:
        raise AssertionError(f"Duplicates found in column: {column}")
    print(f"No duplicates found in column: {column}")

# Function to validate start date (no future dates)
def validate_start_date(df, column="Entity Start Date", date_format="yyyy-MM-dd"):
    """Fail if any row has a start date later than the current date.

    Args:
        df: DataFrame to inspect.
        column: Name of the column holding the start date. Defaults to
            ``"Entity Start Date"``.
        date_format: Pattern handed to ``to_date`` when parsing ``column``.
            Defaults to ``"yyyy-MM-dd"``.

    Raises:
        AssertionError: If at least one parsed date is after
            ``current_date()``.
    """
    # NOTE(review): values that do not parse with ``date_format`` are
    # presumably not counted as "future" and so pass this check — confirm
    # whether malformed dates need a separate validation.
    future_count = df.filter(to_date(df[column], date_format) > current_date()).count()
    # Raised explicitly instead of via ``assert`` so the check is not
    # stripped when Python runs with ``-O``.
    if future_count != 0:
        raise AssertionError("Some companies have a start date in the future.")
    print("Start date validation passed.")

# Test: the ABN CSV exposes the expected string columns
def test_schema_validation(spark):
    """Read the ABN CSV (header row enabled) and validate its schema.

    ``spark`` is expected to be provided by a fixture defined outside this
    view.
    """
    required_columns = ("ABN", "Entity Name", "Entity Type")
    source = spark.read.option("header", "true").csv("/path/to/abn_data.csv")
    validate_schema(source, [(name, "string") for name in required_columns])

# Test: critical ABN columns contain no nulls
def test_null_check(spark):
    """Read the ABN CSV (header row enabled) and run the null check.

    ``spark`` is expected to be provided by a fixture defined outside this
    view.
    """
    reader = spark.read.option("header", "true")
    check_nulls(
        reader.csv("/path/to/abn_data.csv"),
        ["ABN", "Entity Name", "Entity Type", "Entity Start Date"],
    )

# Test: the ABN column holds no duplicate values
def test_duplicates(spark):
    """Read the ABN CSV (header row enabled) and run the duplicate check on "ABN".

    ``spark`` is expected to be provided by a fixture defined outside this
    view.
    """
    csv_reader = spark.read.option("header", "true")
    records = csv_reader.csv("/path/to/abn_data.csv")
    check_duplicates(records, "ABN")

# Test: no ABN record has a start date in the future
def test_start_date(spark):
    """Read the ABN CSV (header row enabled) and run the start-date check.

    ``spark`` is expected to be provided by a fixture defined outside this
    view.
    """
    validate_start_date(
        spark.read.option("header", "true").csv("/path/to/abn_data.csv")
    )


# Main pipeline execution function
def run_pipeline(
    spark,
    abn_path="/path/to/abn_data.csv",
    common_crawl_path="/path/to/common_crawl_data.csv",
):
    """Load the input CSVs and run the data quality checks on the ABN data.

    The ABN data goes through, in order: schema validation, null check on
    the schema's columns, duplicate check on "ABN", and start-date
    validation.

    Args:
        spark: Session used to read the CSV files.
        abn_path: Location of the ABN CSV. The default is the original
            placeholder path.
        common_crawl_path: Location of the Common Crawl CSV. The default is
            the original placeholder path.

    Raises:
        AssertionError: Propagated from whichever check fails first.
    """
    # Both files are read with the first row treated as the header.
    abn_df = spark.read.option("header", "true").csv(abn_path)
    # NOTE(review): this dataset is loaded but none of the checks below use
    # it — confirm whether validation for it is still to be added.
    common_crawl_df = spark.read.option("header", "true").csv(common_crawl_path)

    # Validate schemas and data quality
    expected_schema = [
        ("ABN", "string"),
        ("Entity Name", "string"),
        ("Entity Type", "string"),
        ("Entity Start Date", "string"),
    ]
    validate_schema(abn_df, expected_schema)
    # The null check covers exactly the columns named in the expected schema,
    # so derive the list instead of repeating it.
    check_nulls(abn_df, [name for name, _ in expected_schema])
    check_duplicates(abn_df, "ABN")
    validate_start_date(abn_df)
    
# Running the pipeline (for testing): when this module is executed directly,
# create the Spark session and run every check once. `spark` is bound at
# module level here, not inside a function.
if __name__ == "__main__":
    spark = get_spark_session()
    run_pipeline(spark)
