In [52]:
# Standard Library Imports
import os
import glob
import logging
import time
import re
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Literal
from functools import reduce

# PySpark Core Imports
from pyspark.sql import DataFrame, functions as F
from pyspark.sql.column import Column
from pyspark.sql.types import (
    IntegerType,
    StringType,
    StructField,
    StructType,
    DataType,
)
from pyspark.sql.window import Window


# PySpark MLlib Imports
from pyspark.ml import Pipeline
from pyspark.ml.feature import Imputer, OneHotEncoder, StringIndexer

# Additional Third-Party Imports
import pandas as pd
import numpy as np

# Constants
# Log level applied to the root logger below.
LOGGING_LEVEL = logging.INFO

# Logging Configuration
logging.basicConfig(level=LOGGING_LEVEL)
# Module-level logger used by the helper functions in this notebook.
logger = logging.getLogger(__name__)

# Commonly Used Functions and Aliases
# Short module-level aliases for pyspark.sql.functions.
# WARNING: `min` and `max` below rebind Python's built-in min()/max() for the
# rest of this module; after these lines `min(a, b)` on plain Python values
# calls the Spark aggregate instead. Use `builtins.min` / `builtins.max` when
# the Python builtins are needed.
col = F.col
lit = F.lit
count = F.count
mean = F.mean
stddev = F.stddev
min = F.min
max = F.max
when = F.when
regexp_replace = F.regexp_replace
trim = F.trim
lower = F.lower
row_number = F.row_number
broadcast = F.broadcast

# Custom Type Aliases
# Typing-only aliases; they carry no runtime behavior.
NumericCol = Union[int, float]
StringCol = str
ColumnList = List[Column]

In [15]:
def change_column_types(
    df: DataFrame, type_mapping: Dict[str, Union[str, DataType]]
) -> DataFrame:
    """
    Change the data types of specified columns in a PySpark DataFrame.

    All casts are applied in a single ``select``. Columns not mentioned in
    ``type_mapping`` are passed through untouched, so their type and metadata
    are preserved. Column names are backtick-quoted internally, so names that
    contain dots (e.g. ``"a.b"``) or backticks are handled as literal names
    rather than as struct-field paths.

    Args:
        df (DataFrame): The input PySpark DataFrame.
        type_mapping (Dict[str, Union[str, DataType]]): A dictionary mapping column names
            to their desired data types. The data types can be specified as strings
            (e.g., 'int', 'float', 'timestamp') or PySpark DataType objects.

    Returns:
        DataFrame: A new DataFrame with updated column data types, in the
        original column order.

    Raises:
        ValueError: If ``df`` is not a DataFrame, ``type_mapping`` is not a dict,
            or any key of ``type_mapping`` is not a column of ``df``.

    Example:
        >>> data = [("1", "2.0", "true", "2023-01-01")]
        >>> df = spark.createDataFrame(data, ["A", "B", "C", "D"])
        >>> type_mapping = {"A": "int", "B": "float", "C": "boolean", "D": "timestamp"}
        >>> result_df = change_column_types(df, type_mapping)
        >>> result_df.printSchema()
        root
         |-- A: integer (nullable = true)
         |-- B: float (nullable = true)
         |-- C: boolean (nullable = true)
         |-- D: timestamp (nullable = true)
    """
    if not isinstance(df, DataFrame):
        raise ValueError("Input must be a PySpark DataFrame")

    if not isinstance(type_mapping, dict):
        raise ValueError("Type mapping must be a dictionary")

    # Validate column names (sorted so the message is deterministic)
    invalid_columns = set(type_mapping) - set(df.columns)
    if invalid_columns:
        raise ValueError(
            f"Columns not found in the DataFrame: {', '.join(sorted(invalid_columns))}"
        )

    def _quoted(name: str) -> Column:
        # Backticks make Spark treat the whole string as one column name;
        # a literal backtick inside the name is escaped by doubling it.
        return F.col("`" + name.replace("`", "``") + "`")

    # Apply type changes in a single pass; unmapped columns are not re-cast.
    return df.select(
        *[
            _quoted(c).cast(type_mapping[c]).alias(c) if c in type_mapping else _quoted(c)
            for c in df.columns
        ]
    )

In [None]:
def rename_columns(df: DataFrame, columns_map: Dict[str, str]) -> DataFrame:
    """
    Return a copy of ``df`` whose columns are renamed according to ``columns_map``.

    Renaming is positional (via ``DataFrame.toDF``), so all columns are renamed
    in one operation. Columns absent from ``columns_map`` keep their name, and
    the column order is unchanged.

    Args:
        df (DataFrame): The input PySpark DataFrame.
        columns_map (Dict[str, str]): Mapping of existing column name -> new column name.

    Returns:
        DataFrame: A new DataFrame with renamed columns.

    Raises:
        ValueError: If ``df`` is not a DataFrame or ``columns_map`` is not a dict.
        KeyError: If any key of ``columns_map`` is not a column of ``df``.

    Example:
        >>> data = [("John", "Doe", 30)]
        >>> df = spark.createDataFrame(data, ["first_name", "last_name", "age"])
        >>> columns_map = {"first_name": "name", "last_name": "surname"}
        >>> rename_columns(df, columns_map).show()
        +----+-------+---+
        |name|surname|age|
        +----+-------+---+
        |John|    Doe| 30|
        +----+-------+---+
    """
    # Guard clauses for argument types
    if not isinstance(df, DataFrame):
        raise ValueError("Input must be a PySpark DataFrame")
    if not isinstance(columns_map, dict):
        raise ValueError("columns_map must be a dictionary")

    # Every source name in the mapping must exist in the frame
    unknown = set(columns_map.keys()) - set(df.columns)
    if unknown:
        raise KeyError(f"Columns not found in the DataFrame: {', '.join(unknown)}")

    renamed: List[str] = []
    for name in df.columns:
        renamed.append(columns_map[name] if name in columns_map else name)

    return df.toDF(*renamed)

In [44]:
def clean_column_names(df: DataFrame, max_length: int = 64) -> DataFrame:
    """
    Thoroughly cleans and standardizes column names of a PySpark DataFrame.

    Each name is processed in this order:
    - Trims leading and trailing whitespace and converts to uppercase.
    - Replaces every non-word character (punctuation, symbols and any
      whitespace, including tabs/newlines) with an underscore.
    - Collapses runs of underscores into one and removes leading/trailing
      underscores.
    - Prefixes 'COL_' if the result starts with a digit.
    - Truncates to ``max_length`` characters, dropping any trailing underscore
      the cut exposes.
    Duplicate results are then disambiguated with a numeric suffix (_1, _2, ...).
    The base name is shortened when needed so that the suffix still fits within
    ``max_length``.

    Args:
        df (DataFrame): The input PySpark DataFrame.
        max_length (int): Maximum allowed length for column names. Defaults to 64.

    Returns:
        DataFrame: A new DataFrame with cleaned, unique column names, in the
        original column order.

    Raises:
        ValueError: If the input is not a PySpark DataFrame or if max_length is
            not a positive integer.
        RuntimeError: If cleaning results in an empty column name, or unique
            names cannot be produced within ``max_length``.

    Example:
        >>> data = [("Alice", 34, "NY"), ("Bob", 45, "CA")]
        >>> df = spark.createDataFrame(data, ["Name ", "1age", "St@te!"])
        >>> df_clean = clean_column_names(df)
        >>> df_clean.printSchema()
        root
         |-- NAME: string (nullable = true)
         |-- COL_1AGE: long (nullable = true)
         |-- ST_TE: string (nullable = true)
    """
    if not isinstance(df, DataFrame):
        raise ValueError("Input must be a PySpark DataFrame")
    if not isinstance(max_length, int) or max_length <= 0:
        raise ValueError("max_length must be a positive integer")

    def clean_name(name: str) -> str:
        name = name.strip().upper()
        # \W covers punctuation AND all whitespace, so tabs/newlines are replaced too
        name = re.sub(r"\W", "_", name)
        name = re.sub(r"_+", "_", name).strip("_")
        # Check the first char only after stripping underscores ("_1age" -> "COL_1AGE").
        # The `name and` guard avoids IndexError for names that clean to "".
        if name and name[0].isdigit():
            name = f"COL_{name}"
        # Truncation can expose an underscore at the end; drop it
        return name[:max_length].rstrip("_")

    cleaned = [clean_name(original) for original in df.columns]

    if "" in cleaned:
        raise RuntimeError("Cleaning resulted in empty column name(s)")

    # Ensure uniqueness against every name already emitted, including names
    # that happen to look like generated suffixes (e.g. "A", "A", "A_1").
    taken: set = set()
    unique_names: List[str] = []
    for base in cleaned:
        candidate = base
        suffix = 0
        while candidate in taken:
            suffix += 1
            tag = f"_{suffix}"
            room = max_length - len(tag)
            stem = base[:room].rstrip("_") if room > 0 else ""
            if not stem:
                raise RuntimeError(
                    "Unable to generate unique column names within the specified length"
                )
            # Shorten the base rather than the suffix so the result stays unique
            candidate = f"{stem}{tag}"
        taken.add(candidate)
        unique_names.append(candidate)

    # Positional rename; unlike selectExpr this is safe for names containing backticks
    return df.toDF(*unique_names)

                                                                                

+-------+---+----+
|   name|age|stte|
+-------+---+----+
|  Alice| 34|  NY|
|    Bob| 45|  CA|
|Charlie| 29|  TX|
+-------+---+----+



In [54]:
def convert_column_names(
    df: DataFrame, case: Literal["upper", "lower", "title"] = "lower"
) -> DataFrame:
    """
    Converts the column names of a PySpark DataFrame to the specified case.

    Args:
        df (DataFrame): The input PySpark DataFrame.
        case (Literal["upper", "lower", "title"]): The case to convert the column names to.
            'upper': Converts all column names to uppercase (``str.upper``).
            'lower': Converts all column names to lowercase (``str.lower``).
            'title': Applies ``str.title``, which capitalises the first letter of
                each word and lowercases the rest (e.g. "first_NAME" -> "First_Name").
            Defaults to "lower".

    Returns:
        DataFrame: A new DataFrame with converted column names, in the
        original column order.

    Raises:
        ValueError: If the input is not a PySpark DataFrame or if the case is invalid.

    Example:
        >>> data = [("Alice", 34, "NY"), ("Bob", 45, "CA"), ("Charlie", 29, "TX")]
        >>> df = spark.createDataFrame(data, ["Name", "Age", "State"])
        >>> convert_column_names(df, 'upper').columns
        ['NAME', 'AGE', 'STATE']
        >>> convert_column_names(df, 'title').columns
        ['Name', 'Age', 'State']
    """
    if not isinstance(df, DataFrame):
        raise ValueError("Input must be a PySpark DataFrame")

    case_functions = {"upper": str.upper, "lower": str.lower, "title": str.title}

    if case not in case_functions:
        raise ValueError(
            "The 'case' parameter must be one of 'upper', 'lower', or 'title'."
        )

    convert = case_functions[case]
    new_names = [convert(name) for name in df.columns]

    # Case folding can merge distinct names (e.g. "id" and "ID"); the rename
    # still succeeds, but later references to such names become ambiguous.
    seen: set = set()
    collisions: set = set()
    for name in new_names:
        if name in seen:
            collisions.add(name)
        seen.add(name)
    if collisions:
        logger.warning(
            "convert_column_names produced duplicate column names: %s",
            ", ".join(sorted(collisions)),
        )

    # Positional rename of all columns in one operation
    return df.toDF(*new_names)

In [None]:
def trim_string_columns(df: DataFrame) -> DataFrame:
    """
    Trims leading and trailing spaces from all string columns in a PySpark DataFrame.

    Every column of StringType is wrapped in Spark's ``trim``. Note that Spark's
    ``trim`` strips the space character only, not tabs or newlines. Non-string
    columns are passed through unchanged, and column order is preserved.
    Column names are backtick-quoted internally, so names containing dots are
    treated as literal names rather than as struct-field paths.

    Args:
        df (DataFrame): Input PySpark DataFrame.

    Returns:
        DataFrame: A new DataFrame with all string columns trimmed.

    Raises:
        ValueError: If the input is not a PySpark DataFrame.

    Example:
        >>> data = [(" John ", "  Doe  ", 30), ("Alice  ", " Smith ", 25)]
        >>> df = spark.createDataFrame(data, ["first_name", "last_name", "age"])
        >>> trimmed_df = trim_string_columns(df)
        >>> trimmed_df.show()
        +----------+---------+---+
        |first_name|last_name|age|
        +----------+---------+---+
        |      John|      Doe| 30|
        |     Alice|    Smith| 25|
        +----------+---------+---+
    """
    if not isinstance(df, DataFrame):
        raise ValueError("Input must be a PySpark DataFrame")

    def _quoted(name: str) -> Column:
        # Backticks make Spark treat the whole string as a single column name
        return F.col("`" + name.replace("`", "``") + "`")

    # Walk the schema once instead of doing one df.schema[name] lookup per column
    trim_exprs = [
        (
            F.trim(_quoted(field.name)).alias(field.name)
            if isinstance(field.dataType, StringType)
            else _quoted(field.name)
        )
        for field in df.schema.fields
    ]

    return df.select(*trim_exprs)

In [28]:
def add_mapped_column(
    df: DataFrame,
    target_column: str,
    mapping_list: List[Tuple[Column, Union[str, Callable[[DataFrame], Column]]]],
    default_value: str = "Unknown",
) -> DataFrame:
    """
    Add (or replace) ``target_column`` with a value chosen by a list of conditions.

    Each entry of ``mapping_list`` is ``(condition, value)``. ``value`` is either
    a literal or a callable that receives ``df`` and returns a Column. Each entry
    wraps the expression built so far as ``when(condition, value).otherwise(...)``.
    The LAST entry therefore forms the outermost test: when several conditions
    hold for a row, the entry appearing latest in ``mapping_list`` wins. Rows
    matching no condition get ``default_value``.

    Args:
        df (DataFrame): Input DataFrame.
        target_column (str): Name of the column to create.
        mapping_list (List[Tuple[Column, Union[str, Callable[[DataFrame], Column]]]]):
            ``(condition, value_or_function)`` pairs; later pairs take precedence.
        default_value (str): Value for rows not matching any condition.

    Returns:
        DataFrame: ``df`` with the mapped column added.

    Example:
        >>> df = spark.createDataFrame([(15, "John"), (None, "Eve")], ["Age", "Name"])
        >>> mapping = [
        ...     (F.col("Age") < 18, "Minor"),
        ...     (F.col("Age").isNull(), "No Age"),
        ...     (F.col("Age") > 10, lambda d: F.concat_ws(" - ", d["Age"], d["Name"])),
        ... ]
        >>> add_mapped_column(df, "Group", mapping).show()
    """

    def _wrap(accumulated: Column, rule) -> Column:
        condition, value = rule
        branch = value(df) if callable(value) else value
        return F.when(condition, branch).otherwise(accumulated)

    # Left fold: the default is the innermost fallback, the last rule is outermost.
    mapped_expr = reduce(_wrap, mapping_list, F.lit(default_value))

    return df.withColumn(target_column, mapped_expr)

+----+-----+--------------------+
| Age| Name|               Group|
+----+-----+--------------------+
|  15| John|   Name contains 'o'|
|  25| Jane|Age is 25, 35, or 55|
|  35|  Doe|            35 - Doe|
|  55|Alice|          55 - Alice|
|  75|  Bob|            75 - Bob|
|null|  Eve|            Has Name|
|  30| null|               Adult|
+----+-----+--------------------+



In [39]:
def assert_unique_combination(
    df: DataFrame, columns: Union[str, List[str]]
) -> None:
    """
    Assert that specified columns or their combination have unique values in the DataFrame.

    The frame is grouped directly by the given column(s), so a combination is a
    true tuple of values. NULLs are treated as a value (two NULL keys count as a
    duplicate), and no separator-based string concatenation is involved, which
    previously caused distinct rows to collide. A single Spark job fetches at
    most three duplicate keys as examples.

    Args:
        df (DataFrame): Input DataFrame to check.
        columns (Union[str, List[str]]): Column name, or list of column names,
            whose (combined) values must be unique.

    Raises:
        ValueError: If no column is given, or if duplicate values are found.

    Example:
        >>> df = spark.createDataFrame([(1, 'A'), (2, 'B'), (1, 'C')], ['id', 'value'])
        >>> assert_unique_combination(df, ['id'])  # Raises ValueError
        >>> assert_unique_combination(df, ['id', 'value'])  # No error
    """
    # Accept a bare column name instead of splitting it into characters
    if isinstance(columns, str):
        columns = [columns]
    if not columns:
        raise ValueError("At least one column must be specified.")

    duplicate_rows = (
        df.groupBy(*columns)
        .agg(F.count(F.lit(1)).alias("__dup_count__"))
        .filter(F.col("__dup_count__") > 1)
        .limit(3)
        .collect()
    )

    if duplicate_rows:
        if len(columns) == 1:
            examples = [row[0] for row in duplicate_rows]
            error_message = f"🚨 Column {columns[0]} has duplicate values."
        else:
            # Grouped rows are (key columns..., count); keep only the key part
            examples = [tuple(row)[: len(columns)] for row in duplicate_rows]
            error_message = f"🚨 Combination of {', '.join(columns)} is not unique."

        error_message += f"\n💡 Examples: {', '.join(map(str, examples))}"
        raise ValueError(error_message)

    print(f"✅ Uniqueness check passed for: {', '.join(columns)}")

                                                                                

🚨 Column id has duplicate values.
💡 Examples: 1
✅ Uniqueness check passed for: id, value
✅ Uniqueness check passed for: id, value, category


In [None]:
def replace_missing_values(
    df: DataFrame,
    columns: Optional[Union[str, List[str]]] = None,
    missing_patterns: Optional[Dict[str, List[str]]] = None,
) -> DataFrame:
    """
    Replace values indicating missing data with None in specified columns of a PySpark DataFrame.

    Each processed value is lower-cased and trimmed, then matched (regex
    ``find``, case-insensitive) against:
    - values consisting only of whitespace / special characters (incl. empty strings)
    - the word "null" optionally wrapped in non-word characters, e.g. "null",
      "(null)", "<NULL>" (values merely *containing* "null", such as
      "annulled", are kept)
    - custom patterns provided by the user for that column

    Custom patterns are matched case-insensitively, so ``r"N/A"`` matches "n/a"
    and "N/A". Array, map and struct columns cannot be matched as text; they
    are passed through unchanged and a warning is logged.

    Args:
        df (DataFrame): Input PySpark DataFrame.
        columns (Optional[Union[str, List[str]]]): Column(s) to process. If None, all columns are processed.
        missing_patterns (Optional[Dict[str, List[str]]]): Custom patterns to identify missing values.
            Keys are column names, values are lists of (Java) regex patterns.

    Returns:
        DataFrame: DataFrame with missing values replaced by None; column types are preserved.

    Raises:
        ValueError: If the input DataFrame is empty or if specified columns don't exist.

    Example:
        >>> data = [("John", " "), ("Jane", "(null)"), ("Bob", "@#$"), ("Alice", "N/A")]
        >>> df = spark.createDataFrame(data, ["name", "value"])
        >>> custom_patterns = {"value": [r"N/A"]}
        >>> result_df = replace_missing_values(df, "value", custom_patterns)
        >>> result_df.show()
        +-----+-----+
        | name|value|
        +-----+-----+
        | John| NULL|
        | Jane| NULL|
        |  Bob| NULL|
        |Alice| NULL|
        +-----+-----+
    """
    from pyspark.sql.types import ArrayType, MapType

    # head(1) fetches at most one row; df.rdd.isEmpty() converts the plan to an RDD first
    if not df.head(1):
        raise ValueError("Input DataFrame is empty.")

    if columns is None:
        columns = df.columns
    elif isinstance(columns, str):
        columns = [columns]

    non_existent_cols = set(columns) - set(df.columns)
    if non_existent_cols:
        raise ValueError(f"Columns not found in DataFrame: {non_existent_cols}")

    missing_patterns = missing_patterns or {}

    logger.info(f"Processing {len(columns)} columns")

    base_patterns = [
        r"^[\s\W]*$",  # Empty, whitespace-only or special-characters-only
        r"^[\W_]*null[\W_]*$",  # "null" alone or wrapped in punctuation, e.g. "(null)"
    ]

    def quoted(name: str) -> Column:
        # Backticks keep dotted names from being parsed as struct-field paths
        return F.col("`" + name.replace("`", "``") + "`")

    def missing_value_expression(name: str, data_type: DataType) -> Column:
        all_patterns = base_patterns + list(missing_patterns.get(name, []))
        # (?i) makes user patterns with upper-case letters match the lower-cased input
        regex = "(?i)" + "|".join(f"({pattern})" for pattern in all_patterns)
        value = quoted(name)
        return (
            F.when(F.trim(F.lower(value)).rlike(regex), F.lit(None))
            .otherwise(value)
            .cast(data_type)  # keep the original column type
        )

    targets = set(columns)
    expressions = []
    for field in df.schema.fields:
        if field.name not in targets:
            expressions.append(quoted(field.name))
        elif isinstance(field.dataType, (ArrayType, MapType, StructType)):
            # lower()/rlike() are not defined for complex types; leave them untouched
            logger.warning(
                "Skipping complex-typed column %s (%s)",
                field.name,
                field.dataType.simpleString(),
            )
            expressions.append(quoted(field.name))
        else:
            expressions.append(
                missing_value_expression(field.name, field.dataType).alias(field.name)
            )

    # Process all columns in a single select operation
    result_df = df.select(*expressions)

    logger.info("Missing value replacement completed")
    return result_df

In [None]:
def get_csv_data(
    base_path: str,
    folder_name: Optional[str] = None,
    file_pattern: Optional[str] = None,
    encoding: str = "UTF-8",
    sep: str = ";",
    header: bool = True,
    include_subdirectories: bool = True,
    cache_result: bool = False,
) -> DataFrame:
    """
    Retrieve and combine CSV data from a specified path and optionally from a specific folder.

    Files are discovered with ``dbutils.fs.ls``, descending into sub-folders when
    ``include_subdirectories`` is True. A file is selected when its path ends
    with ".csv" (case-insensitive) and, if ``file_pattern`` is given, its name
    contains ``file_pattern``. The schema is inferred from the FIRST matching
    file only, using the same ``encoding`` as the data, and is then applied to
    every file. Each row gets a ``source_file`` column holding the file's base
    name. Row count and numeric column statistics are printed.

    Args:
        base_path (str): Base path where the CSV data is located.
        folder_name (Optional[str]): Specific folder name within the base path. If None, base_path is used.
        file_pattern (Optional[str]): Substring to match in file names. If None, all CSV files are considered.
        encoding (str): Encoding of the CSV files. Defaults to "UTF-8".
        sep (str): Separator used in the CSV files. Defaults to ";".
        header (bool): Whether the CSV files have a header row. Defaults to True.
        include_subdirectories (bool): Whether to include subdirectories in the search. Defaults to True.
        cache_result (bool): Whether to cache the resulting DataFrame. Defaults to False.

    Returns:
        DataFrame: Combined DataFrame of all relevant CSV files.

    Raises:
        ValueError: If no CSV files are found in the specified location.
        Exception: Whatever ``dbutils.fs.ls`` raises when the path does not exist.
            NOTE(review): on Databricks this is a wrapped Java exception, not
            Python's FileNotFoundError — confirm before catching it by type.
    """
    start_time = time.perf_counter()

    # Construct the full path
    full_path = f"{base_path}/{folder_name}" if folder_name else base_path

    def list_csv_files(directory: str) -> List[str]:
        # Depth-first walk; recurse only when include_subdirectories is set
        matches: List[str] = []
        for entry in dbutils.fs.ls(directory):
            if entry.isDir():
                if include_subdirectories:
                    matches.extend(list_csv_files(entry.path))
            elif entry.path.lower().endswith(".csv") and (
                file_pattern is None or file_pattern in entry.name
            ):
                matches.append(entry.path)
        return matches

    csv_files = list_csv_files(full_path)

    if not csv_files:
        raise ValueError(
            f"🔍 No CSV files found matching the pattern: {file_pattern} in {full_path} 🔍"
        )

    print(f"📊 Found {len(csv_files)} CSV files matching the criteria")

    # Infer the schema from the first file, decoding it with the data's encoding
    # so header-derived column names match what the data read produces.
    inferred_schema = spark.read.csv(
        csv_files[0], sep=sep, header=header, encoding=encoding, inferSchema=True
    ).schema

    def read_csv(file_path: str) -> DataFrame:
        return spark.read.csv(
            file_path,
            schema=inferred_schema,
            header=header,
            sep=sep,
            encoding=encoding,
            ignoreLeadingWhiteSpace=True,
            ignoreTrailingWhiteSpace=True,
        ).withColumn("source_file", F.lit(file_path.split("/")[-1]))

    # DataFrames must be built on the driver: `spark` is not usable inside RDD
    # tasks, so the previous parallelize(...).map(read_csv) could not work.
    dfs = [read_csv(path) for path in csv_files]

    # No checkpoint(): it requires SparkContext.setCheckpointDir() and otherwise
    # fails, and a union of file scans has shallow lineage anyway.
    result_df = reduce(
        lambda left, right: left.unionByName(right, allowMissingColumns=True), dfs
    )

    if cache_result:
        result_df = result_df.cache()

    # Calculate and log statistics
    total_rows = result_df.count()
    column_stats = calculate_column_stats(result_df)

    duration = time.perf_counter() - start_time

    print(
        f"🎉 Successfully combined {len(dfs)} CSV files into a DataFrame with {total_rows} rows in {duration:.2f} seconds! 🎉"
    )
    print("Column Statistics:")
    for column_name, stats in column_stats.items():
        print(f"  {column_name}: {stats}")

    return result_df


def infer_schema(
    file_path: str, sep: str, header: bool, encoding: str = "UTF-8"
) -> StructType:
    """
    Infer the schema of a CSV file using Spark's ``inferSchema`` option.

    Args:
        file_path (str): Path of the CSV file to sample.
        sep (str): Field separator.
        header (bool): Whether the first line holds column names.
        encoding (str): Character encoding used to decode the file, including
            its header row. Defaults to "UTF-8", which is also Spark's CSV
            default. Pass the same encoding used to read the data so
            header-derived column names are decoded identically.

    Returns:
        StructType: The inferred schema.
    """
    return spark.read.csv(
        file_path, sep=sep, header=header, inferSchema=True, encoding=encoding
    ).schema


def calculate_column_stats(df: DataFrame) -> dict:
    """
    Calculate count, mean, stddev, min and max for every numeric column.

    All numeric types (byte, short, integer, long, float, double, decimal) are
    included. The previous ``typeName()`` check compared against "int", which
    is never produced (IntegerType reports "integer"), so integer columns were
    silently skipped. All statistics are computed in ONE aggregation job
    instead of one job per column.

    Args:
        df (DataFrame): DataFrame to summarise.

    Returns:
        dict: ``{column_name: {"count", "mean", "stddev", "min", "max"}}``;
        empty if ``df`` has no numeric columns.
    """
    from pyspark.sql.types import NumericType

    numeric_columns = [
        field.name for field in df.schema.fields if isinstance(field.dataType, NumericType)
    ]
    if not numeric_columns:
        return {}

    stat_functions = {
        "count": F.count,
        "mean": F.mean,
        "stddev": F.stddev,
        "min": F.min,
        "max": F.max,
    }

    # Index-based aliases stay unique even for odd or duplicated column names;
    # backticks keep dotted names from being parsed as struct-field paths.
    aggregations = [
        fn(F.col("`" + name.replace("`", "``") + "`")).alias(f"{idx}__{stat}")
        for idx, name in enumerate(numeric_columns)
        for stat, fn in stat_functions.items()
    ]
    row = df.agg(*aggregations).collect()[0]

    return {
        name: {stat: row[f"{idx}__{stat}"] for stat in stat_functions}
        for idx, name in enumerate(numeric_columns)
    }

In [None]:
def advanced_join(
    left_df: DataFrame,
    right_df: DataFrame,
    on: Union[str, List[str]],
    how: str = "inner",
    suffix_left: str = "_left",
    suffix_right: str = "_right",
) -> DataFrame:
    """
    Join two DataFrames, suffixing non-key columns that exist on both sides.

    Columns present in both frames (other than the join keys) are renamed to
    ``<name><suffix_left>`` / ``<name><suffix_right>`` before joining, so the
    result has no ambiguous column names. For semi and anti joins only left
    columns are returned, so no renaming is done.

    Args:
        left_df (DataFrame): The left DataFrame for the join operation.
        right_df (DataFrame): The right DataFrame for the join operation.
        on (Union[str, List[str]]): Column(s) to join on (equi-join on equally
            named columns).
        how (str, optional): Join type, as accepted by Spark, matched
            case-insensitively with underscores ignored: inner, cross,
            outer/full/fullouter, left/leftouter, right/rightouter,
            semi/leftsemi, anti/leftanti. Defaults to "inner".
        suffix_left (str, optional): Suffix for clashing columns from the left side.
        suffix_right (str, optional): Suffix for clashing columns from the right side.

    Returns:
        DataFrame: The joined DataFrame.

    Raises:
        ValueError: If an invalid join type is specified, or if the suffixes
            would still produce duplicate column names.

    Example:
        >>> df1 = spark.createDataFrame([(1, "A", 100), (2, "B", 200), (3, "C", 300)], ["id", "name", "value"])
        >>> df2 = spark.createDataFrame([(1, "X", 1000), (2, "Y", 2000), (4, "Z", 4000)], ["id", "category", "value"])
        >>> advanced_join(df1, df2, on="id", how="left").show()
        +---+----+----------+--------+-----------+
        | id|name|value_left|category|value_right|
        +---+----+----------+--------+-----------+
        |  1|   A|       100|       X|       1000|
        |  2|   B|       200|       Y|       2000|
        |  3|   C|       300|    null|       null|
        +---+----+----------+--------+-----------+
    """
    supported = {
        "inner", "cross",
        "outer", "full", "fullouter",
        "left", "leftouter",
        "right", "rightouter",
        "semi", "leftsemi",
        "anti", "leftanti",
    }
    # Spark itself normalises join types this way (lower-case, "_" removed)
    normalized_how = how.lower().replace("_", "") if isinstance(how, str) else None
    if normalized_how not in supported:
        raise ValueError(
            f"Invalid join type: {how}. Supported types are: inner, cross, outer/full/fullouter, "
            "left/leftouter, right/rightouter, semi/leftsemi, anti/leftanti "
            "(case-insensitive, underscores allowed)"
        )

    join_columns = [on] if isinstance(on, str) else list(on)

    # Semi/anti joins output only left columns: nothing can clash, so don't rename.
    if normalized_how in {"semi", "leftsemi", "anti", "leftanti"}:
        return left_df.join(right_df, on=join_columns, how=how)

    join_set = set(join_columns)
    common_columns = (set(left_df.columns) - join_set) & (set(right_df.columns) - join_set)

    def _quoted(name: str) -> Column:
        # Backticks keep dotted names from being parsed as struct-field paths
        return F.col("`" + name.replace("`", "``") + "`")

    def _with_suffix(frame: DataFrame, suffix: str) -> DataFrame:
        return frame.select(
            *[
                _quoted(c).alias(f"{c}{suffix}") if c in common_columns else _quoted(c)
                for c in frame.columns
            ]
        )

    left_renamed = _with_suffix(left_df, suffix_left)
    right_renamed = _with_suffix(right_df, suffix_right)

    # e.g. equal suffixes, or a suffixed name that already exists on the other side
    still_clashing = (set(left_renamed.columns) - join_set) & (
        set(right_renamed.columns) - join_set
    )
    if still_clashing:
        raise ValueError(
            f"Suffixes do not resolve duplicate columns: {', '.join(sorted(still_clashing))}"
        )

    return left_renamed.join(right_renamed, on=join_columns, how=how)

In [None]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import to_timestamp
from typing import Union, List


def convert_to_timestamp(
    df: DataFrame,
    columns: Union[str, List[str]],
    format: str = "yyyy-MM-dd'T'HH:mm:ss.SSSZ",
) -> DataFrame:
    """
    Parse one or more string columns into TimestampType using ``to_timestamp``.

    Each listed column is replaced in place, one after another, with
    ``to_timestamp(column, format)``. Values that do not fit ``format`` become
    null. The rendered timestamps depend on the session time zone
    (``spark.sql.session.timeZone``).

    Args:
        df (DataFrame): Input PySpark DataFrame.
        columns (Union[str, List[str]]): A column name or a list of column names.
        format (str, optional): Spark datetime pattern.
            Defaults to "yyyy-MM-dd'T'HH:mm:ss.SSSZ".

    Returns:
        DataFrame: ``df`` with the given columns converted to timestamps.

    Raises:
        ValueError: If any requested column is missing from ``df``.

    Examples:
        >>> data = [("2023-07-27T13:18:12.039+0000",), ("2023-07-28T13:20:30.039+0000",)]
        >>> df = spark.createDataFrame(data, ["ZEITPUNKT"])
        >>> convert_to_timestamp(df, "ZEITPUNKT").printSchema()
        root
         |-- ZEITPUNKT: timestamp (nullable = true)
    """
    target_columns = [columns] if isinstance(columns, str) else columns

    missing = set(target_columns) - set(df.columns)
    if missing:
        raise ValueError(
            f"The following columns are not present in the DataFrame: {missing}"
        )

    # Fold over the names; each step reads the column from the previous step's frame.
    return reduce(
        lambda frame, name: frame.withColumn(name, to_timestamp(frame[name], format)),
        target_columns,
        df,
    )

In [None]:
def get_most_current_data(
    df: DataFrame,
    partition_columns: Union[str, List[str]],
    timestamp_column: str,
    ascending: bool = False,
) -> DataFrame:
    """
    Returns the most current dataset from a DataFrame containing historical data.

    The data is partitioned by the given column(s) and, within each partition, exactly
    one record is kept: the one with the latest (or, with ``ascending=True``, the
    earliest) value of ``timestamp_column``.

    Args:
        df (DataFrame): Input DataFrame containing historical data.
        partition_columns (Union[str, List[str]]): Column(s) to partition the data.
            Can be a string for a single column or a list/tuple of column names.
        timestamp_column (str): Name of the column containing timestamp information.
        ascending (bool, optional): If True, selects the earliest record.
            If False (default), selects the latest record.

    Returns:
        DataFrame: A new DataFrame with one row per partition and the same columns as ``df``.

    Raises:
        ValueError: If the specified columns are not present in the DataFrame.

    Notes:
        - Null timestamps are ordered last in both directions, so a row with a null
          timestamp is only returned when every row in its partition has one.
        - When several rows share the winning timestamp, which of them is kept is not
          deterministic.

    Example:
        >>> from pyspark.sql.functions import to_timestamp
        >>> data = [
        ...     (1, "A", "2023-01-01 10:00:00", 100),
        ...     (1, "A", "2023-01-02 11:00:00", 150),
        ...     (2, "B", "2023-01-01 09:00:00", 200),
        ...     (2, "B", "2023-01-03 12:00:00", 250),
        ...     (3, "C", "2023-01-02 08:00:00", 300),
        ... ]
        >>> df = spark.createDataFrame(data, ["id", "category", "timestamp", "value"])
        >>> df = df.withColumn("timestamp", to_timestamp("timestamp"))
        >>> get_most_current_data(df, ["id", "category"], "timestamp").show()
        +---+--------+-------------------+-----+
        | id|category|          timestamp|value|
        +---+--------+-------------------+-----+
        |  1|       A|2023-01-02 11:00:00|  150|
        |  2|       B|2023-01-03 12:00:00|  250|
        |  3|       C|2023-01-02 08:00:00|  300|
        +---+--------+-------------------+-----+
    """
    if isinstance(partition_columns, str):
        partition_columns = [partition_columns]
    else:
        # Copy so tuples are accepted and the caller's list is never touched.
        partition_columns = list(partition_columns)

    all_columns = df.columns
    invalid_columns = set(partition_columns + [timestamp_column]) - set(all_columns)
    if invalid_columns:
        raise ValueError(
            f"The following columns are not present in the DataFrame: {invalid_columns}"
        )

    # asc() defaults to NULLS FIRST, which would pick a null timestamp as "earliest";
    # force nulls last in both directions.
    timestamp = F.col(timestamp_column)
    order_expr = timestamp.asc_nulls_last() if ascending else timestamp.desc_nulls_last()
    window_spec = Window.partitionBy(*partition_columns).orderBy(order_expr)

    # Pick a helper column name that cannot clobber (and later drop) a real column.
    # Spark resolves names case-insensitively by default, so compare lowercased.
    existing = {name.lower() for name in all_columns}
    rank_column = "row_number"
    while rank_column.lower() in existing:
        rank_column = f"_{rank_column}"

    return (
        df.withColumn(rank_column, F.row_number().over(window_spec))
        .filter(F.col(rank_column) == 1)
        .drop(rank_column)
    )

In [41]:
def apply_groupby_aggregations(
    df: DataFrame,
    group_cols: List[str],
    agg_expressions: Dict[str, Union[str, F.Column]],
) -> DataFrame:
    """
    Group a PySpark DataFrame and compute named aggregations in one pass.

    Each entry of ``agg_expressions`` becomes one output column. A string value is
    parsed with ``F.expr``; a Column value is used as-is. Both are aliased to their key.

    Args:
        df (DataFrame): Input PySpark DataFrame.
        group_cols (List[str]): Column names to group by (must exist in ``df``).
        agg_expressions (Dict[str, Union[str, F.Column]]): Output column name ->
            SQL aggregation string (e.g. ``"sum(value1)"``) or Column (e.g. ``F.avg("value2")``).

    Returns:
        DataFrame: One row per group with the group columns followed by the aggregations.

    Raises:
        ValueError: If ``df`` is not a DataFrame, ``group_cols`` or ``agg_expressions``
            is empty, a group column is missing, or an expression has an unsupported type.

    Example:
        >>> data = [("A", 1, 100), ("B", 2, 200), ("A", 3, 300), ("B", 4, 400)]
        >>> df = spark.createDataFrame(data, ["group", "value1", "value2"])
        >>> agg_expressions = {
        ...     "sum_value1": "sum(value1)",
        ...     "avg_value2": F.avg("value2"),
        ...     "count": F.count("*"),
        ... }
        >>> apply_groupby_aggregations(df, ["group"], agg_expressions).show()
        +-----+----------+----------+-----+
        |group|sum_value1|avg_value2|count|
        +-----+----------+----------+-----+
        |    B|         6|     300.0|    2|
        |    A|         4|     200.0|    2|
        +-----+----------+----------+-----+
    """
    # Guard clauses: reject unusable inputs before building any Spark plan.
    if not isinstance(df, DataFrame):
        raise ValueError("Input 'df' must be a PySpark DataFrame")
    if not group_cols:
        raise ValueError("'group_cols' must not be empty")
    if not agg_expressions:
        raise ValueError("'agg_expressions' must not be empty")

    missing_cols = set(group_cols) - set(df.columns)
    if missing_cols:
        raise ValueError(
            f"Group columns not found in DataFrame: {', '.join(missing_cols)}"
        )

    def _named(alias, spec):
        # Normalise one aggregation spec to an aliased Column.
        if isinstance(spec, str):
            return F.expr(spec).alias(alias)
        if isinstance(spec, F.Column):
            return spec.alias(alias)
        raise ValueError(
            f"Invalid aggregation expression for '{alias}'. "
            "Must be a string or Column object."
        )

    aggregations = [_named(alias, spec) for alias, spec in agg_expressions.items()]
    return df.groupBy(*group_cols).agg(*aggregations)

                                                                                

+--------+-----------+---------+-----+----+-------------+
|category|total_value|avg_value|count|rank|running_total|
+--------+-----------+---------+-----+----+-------------+
|       A|        250|    125.0|    2|   1|          250|
|       B|        450|    225.0|    2|   1|          450|
|       C|        650|    325.0|    2|   1|          650|
+--------+-----------+---------+-----+----+-------------+



In [None]:
def add_window_calculations(
    df: DataFrame,
    partition_cols: List[str],
    order_cols: List[str],
    window_expressions: Dict[str, Union[str, F.Column, Callable]],
    ascending: Union[bool, List[bool]] = True,
) -> DataFrame:
    """
    Add window function calculations to a PySpark DataFrame.

    A single window is built from ``partition_cols`` and ``order_cols`` and every entry
    of ``window_expressions`` is evaluated over it, each producing one new column.

    Args:
        df (DataFrame): Input PySpark DataFrame.
        partition_cols (List[str]): Column names to partition by.
        order_cols (List[str]): Column names to order by within each partition.
        window_expressions (Dict[str, Union[str, F.Column, Callable]]): Output column
            name -> expression. Values can be:
            - String: a SQL window function (e.g. ``"rank()"``), wrapped with ``.over(window)``.
            - Column: a PySpark Column (e.g. ``F.row_number()``), wrapped with ``.over(window)``.
            - Callable: a function that takes the window spec and returns a Column
              (e.g. ``F.sum("value").over``).
        ascending (Union[bool, List[bool]]): Ordering direction. A single bool applies
            to all order columns; a list (or tuple) must have one entry per order column.
            Defaults to True.

    Returns:
        DataFrame: ``df`` with one additional column per window expression.

    Raises:
        ValueError: If inputs are empty, columns are missing, ``ascending`` has the
            wrong type/length, or an expression has an unsupported type.

    Example:
        >>> data = [("A", 1, 100), ("A", 2, 150), ("B", 1, 200), ("B", 2, 250)]
        >>> df = spark.createDataFrame(data, ["group", "subgroup", "value"])
        >>> window_expressions = {
        ...     "row_number": F.row_number(),
        ...     "rank": "rank()",
        ...     "running_total": F.sum("value").over,
        ...     "lag_value": F.lag("value", 1).over,
        ... }
        >>> add_window_calculations(
        ...     df,
        ...     partition_cols=["group"],
        ...     order_cols=["subgroup", "value"],
        ...     window_expressions=window_expressions,
        ...     ascending=[True, False],
        ... ).orderBy("group", "subgroup").show()
        +-----+--------+-----+----------+----+-------------+---------+
        |group|subgroup|value|row_number|rank|running_total|lag_value|
        +-----+--------+-----+----------+----+-------------+---------+
        |    A|       1|  100|         1|   1|          100|     null|
        |    A|       2|  150|         2|   2|          250|      100|
        |    B|       1|  200|         1|   1|          200|     null|
        |    B|       2|  250|         2|   2|          450|      200|
        +-----+--------+-----+----------+----+-------------+---------+
    """
    # Input validation
    if not isinstance(df, DataFrame):
        raise ValueError("Input 'df' must be a PySpark DataFrame")
    if not partition_cols:
        raise ValueError("'partition_cols' must not be empty")
    if not order_cols:
        raise ValueError("'order_cols' must not be empty")
    if not window_expressions:
        raise ValueError("'window_expressions' must not be empty")

    # Validate columns exist in DataFrame
    all_cols = set(partition_cols) | set(order_cols)
    missing_cols = all_cols - set(df.columns)
    if missing_cols:
        raise ValueError(f"Columns not found in DataFrame: {', '.join(missing_cols)}")

    # Normalise the ordering direction to one flag per order column.
    if isinstance(ascending, bool):
        ascending_flags = [ascending] * len(order_cols)
    elif isinstance(ascending, (list, tuple)):
        ascending_flags = list(ascending)
        if len(ascending_flags) != len(order_cols):
            raise ValueError("Length of 'ascending' list must match 'order_cols'")
    else:
        raise ValueError("'ascending' must be a boolean or a list of booleans")

    # Column names are plain strings; wrap them in F.col before choosing a direction
    # (calling .desc() directly on a str raises AttributeError).
    order_exprs = [
        F.col(name).asc() if is_asc else F.col(name).desc()
        for name, is_asc in zip(order_cols, ascending_flags)
    ]
    window_spec = Window.partitionBy(*partition_cols).orderBy(*order_exprs)

    # Apply window expressions
    for column, expr in window_expressions.items():
        if isinstance(expr, str):
            df = df.withColumn(column, F.expr(expr).over(window_spec))
        elif isinstance(expr, F.Column):
            df = df.withColumn(column, expr.over(window_spec))
        elif callable(expr):
            df = df.withColumn(column, expr(window_spec))
        else:
            raise ValueError(
                f"Invalid window expression for '{column}'. "
                "Must be a string, Column object, or callable."
            )

    return df

In [None]:
def one_hot_encode(df: DataFrame, categorical_columns: list) -> DataFrame:
    """
    One-hot encode categorical columns with a StringIndexer + OneHotEncoder pipeline.

    For every column ``c`` two columns are appended: ``c_index`` (the StringIndexer
    output) and ``c_encoded`` (the OneHotEncoder vector). Input columns are kept.
    Both stages use their Spark defaults (e.g. invalid/unseen handling).

    Args:
        df (DataFrame): Input DataFrame.
        categorical_columns (list): Names of the categorical columns to encode.

    Returns:
        DataFrame: ``df`` with the index and encoded columns added.
    """
    # Two stages per column, in column order: index first, then encode the index.
    stages = [
        stage
        for name in categorical_columns
        for stage in (
            StringIndexer(inputCol=name, outputCol=f"{name}_index"),
            OneHotEncoder(inputCols=[f"{name}_index"], outputCols=[f"{name}_encoded"]),
        )
    ]
    fitted = Pipeline(stages=stages).fit(df)
    return fitted.transform(df)

In [None]:
def remove_duplicates(df: DataFrame, subset: list = None) -> DataFrame:
    """
    Drop duplicate rows, optionally considering only a subset of columns.

    Args:
        df (DataFrame): Input DataFrame.
        subset (list, optional): Columns that define a duplicate. ``None`` or an empty
            list means all columns are compared.

    Returns:
        DataFrame: ``df`` without duplicate rows.
    """
    # A falsy subset (None or []) falls back to whole-row deduplication.
    return df.dropDuplicates(subset=subset) if subset else df.dropDuplicates()

In [None]:
def apply_sql_query(
    df: DataFrame, query: str, view_name: str = "temp_view"
) -> DataFrame:
    """
    Apply a custom SQL query to the DataFrame via a temporary view.

    ``df`` is registered as ``view_name`` (replacing any existing temp view with that
    name), ``query`` is run against the session, and the view is dropped afterwards —
    also when the query fails.

    Args:
        df (DataFrame): Input DataFrame.
        query (str): SQL query that references ``view_name``.
        view_name (str): Name for the temporary view (default: "temp_view").

    Returns:
        DataFrame: Result of the SQL query.

    Raises:
        Any exception raised by ``SparkSession.sql`` (e.g. a parse/analysis error) is
        propagated after the temporary view has been dropped.

    Notes:
        The view is dropped as soon as ``sql()`` returns. This relies on the query being
        analysed at that point (classic PySpark behaviour).
        NOTE(review): with Spark Connect, analysis may be deferred; verify before relying
        on this there.
    """
    session = df.sparkSession
    df.createOrReplaceTempView(view_name)
    try:
        return session.sql(query)
    finally:
        # Always clean up, so a failing query does not leak the view into the session.
        session.catalog.dropTempView(view_name)

In [None]:

def summarize_numeric_columns(df: DataFrame) -> DataFrame:
    """
    Compute summary statistics for all numeric columns in a DataFrame.

    For every numeric column ``c`` (byte/short/int/long/float/double/decimal) five
    aggregate columns are produced: ``c_count``, ``c_mean``, ``c_stddev``, ``c_min``
    and ``c_max``. The result is a single row.

    Args:
        df (DataFrame): Input DataFrame.

    Returns:
        DataFrame: One-row DataFrame with the statistics for each numeric column. If
        ``df`` has no numeric columns, this is ``df.select()`` of nothing (no columns).
    """
    from pyspark.sql.types import NumericType

    # isinstance on NumericType also catches bigint/smallint/tinyint/decimal, which the
    # simpleString() names ('bigint', 'decimal(p,s)', ...) would have missed.
    numeric_columns = [
        field.name
        for field in df.schema.fields
        if isinstance(field.dataType, NumericType)
    ]

    def _ref(name: str):
        # Backtick-quote so names containing dots/spaces resolve as a single column.
        return F.col("`" + name.replace("`", "``") + "`")

    # Flatten to a single list of Columns (five aggregates per numeric column).
    summary_exprs = [
        expr
        for c in numeric_columns
        for expr in (
            F.count(_ref(c)).alias(f"{c}_count"),
            F.mean(_ref(c)).alias(f"{c}_mean"),
            F.stddev(_ref(c)).alias(f"{c}_stddev"),
            F.min(_ref(c)).alias(f"{c}_min"),
            F.max(_ref(c)).alias(f"{c}_max"),
        )
    ]

    return df.select(summary_exprs)

def identify_outliers(
    df: DataFrame,
    column: str,
    lower_quantile: float = 0.25,
    upper_quantile: float = 0.75,
    iqr_multiplier: float = 1.5,
    relative_error: float = 0.01,
) -> DataFrame:
    """
    Identify outliers in a specified column using the Interquartile Range (IQR) method.

    ``q1``/``q3`` are approximate quantiles of ``column``. A value is flagged when it lies
    outside ``[q1 - k*IQR, q3 + k*IQR]`` with ``IQR = q3 - q1`` and ``k = iqr_multiplier``.

    Args:
        df (DataFrame): Input DataFrame.
        column (str): Name of the numeric column to check for outliers.
        lower_quantile (float): Lower quantile for IQR calculation (default: 0.25).
        upper_quantile (float): Upper quantile for IQR calculation (default: 0.75).
        iqr_multiplier (float): Multiplier for IQR to determine outlier boundaries (default: 1.5).
        relative_error (float): ``relativeError`` passed to ``approxQuantile``
            (default: 0.01; 0 computes exact quantiles at higher cost).

    Returns:
        DataFrame: ``df`` with an additional boolean column ``<column>_is_outlier``.
        The flag is null for null values, and null for every row when no quantiles
        could be computed (empty DataFrame or column without non-null values).

    Raises:
        ValueError: If the quantiles are not ordered within [0, 1].
    """
    if not 0.0 <= lower_quantile <= upper_quantile <= 1.0:
        raise ValueError(
            "Quantiles must satisfy 0 <= lower_quantile <= upper_quantile <= 1"
        )

    flag_name = f"{column}_is_outlier"
    quantiles = df.approxQuantile(
        column, [lower_quantile, upper_quantile], relative_error
    )
    # approxQuantile returns an empty list when there is nothing to compute on;
    # indexing into it would raise IndexError.
    if len(quantiles) < 2:
        return df.withColumn(flag_name, F.lit(None).cast("boolean"))

    q1, q3 = quantiles[0], quantiles[1]
    iqr = q3 - q1
    lower_bound = q1 - iqr_multiplier * iqr
    upper_bound = q3 + iqr_multiplier * iqr

    return df.withColumn(flag_name, ~F.col(column).between(lower_bound, upper_bound))

def pivot_and_unpivot(df: DataFrame, id_cols: List[str], pivot_col: str, value_col: str) -> Dict[str, DataFrame]:
    """
    Perform both pivot and unpivot operations on a DataFrame.

    The pivot groups by ``id_cols``, turns every distinct value of ``pivot_col`` into a
    column and fills it with ``first(value_col)``. The unpivot stacks those columns back
    into ``(pivot_col, value_col)`` pairs. Pivot labels come back as strings, and rows
    whose pivoted cell was null appear in the unpivoted result with a null value.

    Args:
        df (DataFrame): Input DataFrame.
        id_cols (List[str]): List of columns to use as identifiers.
        pivot_col (str): Column to use for pivoting.
        value_col (str): Column containing the values to be pivoted.

    Returns:
        Dict[str, DataFrame]: ``{"pivoted": ..., "unpivoted": ...}``.
    """

    def _ident(name: str) -> str:
        # Quote a SQL identifier so pivot values like "New York" or "a.b" stay one column.
        return "`" + name.replace("`", "``") + "`"

    def _literal(text: str) -> str:
        # Quote a SQL string literal, escaping backslashes and single quotes.
        return "'" + text.replace("\\", "\\\\").replace("'", "\\'") + "'"

    # Pivot operation
    pivot_df = df.groupBy(id_cols).pivot(pivot_col).agg(F.first(value_col))

    # Unpivot operation
    value_columns = [c for c in pivot_df.columns if c not in id_cols]
    if not value_columns:
        # Nothing was pivoted (e.g. empty input); stack(0, ...) is invalid, so return an
        # empty frame with the unpivoted layout instead.
        unpivot_df = df.select(
            *id_cols,
            F.col(pivot_col).cast("string").alias(pivot_col),
            F.col(value_col),
        ).limit(0)
    else:
        pairs = ", ".join(f"{_literal(c)}, {_ident(c)}" for c in value_columns)
        stack_sql = (
            f"stack({len(value_columns)}, {pairs}) "
            f"as ({_ident(pivot_col)}, {_ident(value_col)})"
        )
        unpivot_df = pivot_df.select(*id_cols, F.expr(stack_sql))

    return {"pivoted": pivot_df, "unpivoted": unpivot_df}


def compare_dataframes(df1: DataFrame, df2: DataFrame, join_columns: List[str]) -> Dict[str, Union[DataFrame, int]]:
    """
    Compare two DataFrames and identify differences.

    Rows are matched on ``join_columns`` (full outer join) and compared on every other
    column present in both frames, using null-safe equality. A row of ``df1`` counts as
    "not in df2" when it has no match in ``df2`` or any compared value differs, and
    vice versa. Columns present in only one of the frames are ignored. Keys are assumed
    to be unique within each frame; duplicates multiply the matched rows.

    Args:
        df1 (DataFrame): First DataFrame.
        df2 (DataFrame): Second DataFrame.
        join_columns (List[str]): Columns to use for joining the DataFrames.

    Returns:
        Dict[str, Union[DataFrame, int]]:
            - ``in_df1_not_df2``: differing/unmatched rows as they appear in ``df1``.
            - ``in_df2_not_df1``: differing/unmatched rows as they appear in ``df2``.
            - ``total_differences``: sum of the two row counts (triggers two Spark jobs).

    Raises:
        ValueError: If ``join_columns`` is empty or a join column is missing from either frame.
    """
    if not join_columns:
        raise ValueError("'join_columns' must not be empty")
    missing_keys = [
        k for k in join_columns if k not in df1.columns or k not in df2.columns
    ]
    if missing_keys:
        raise ValueError(
            f"Join columns missing from one or both DataFrames: {missing_keys}"
        )

    def _q(name: str) -> str:
        # Backtick-quote so names with dots/spaces are treated as one column.
        return "`" + name.replace("`", "``") + "`"

    # Compare only the columns both frames share (selecting the union would fail).
    df2_columns = set(df2.columns)
    columns = sorted(c for c in df1.columns if c in df2_columns)
    value_columns = [c for c in columns if c not in join_columns]

    # A non-null marker per side tells "row missing" apart from "row with null keys/values".
    existing = {c.lower() for c in columns}
    marker = "__present__"
    while marker.lower() in existing:
        marker = "_" + marker

    left = df1.select(*[F.col(_q(c)) for c in columns], F.lit(True).alias(marker)).alias("_l")
    right = df2.select(*[F.col(_q(c)) for c in columns], F.lit(True).alias(marker)).alias("_r")

    join_condition = [F.col(f"_l.{_q(k)}") == F.col(f"_r.{_q(k)}") for k in join_columns]
    joined = left.join(right, join_condition, "full_outer")

    in_left = F.col(f"_l.{_q(marker)}").isNotNull()
    in_right = F.col(f"_r.{_q(marker)}").isNotNull()

    # Any value column differing (null-safe) marks the matched pair as different.
    differs = F.lit(False)
    for c in value_columns:
        differs = differs | ~F.col(f"_l.{_q(c)}").eqNullSafe(F.col(f"_r.{_q(c)}"))

    in_df1_not_df2 = joined.filter(in_left & (~in_right | differs)).select(
        *[F.col(f"_l.{_q(c)}").alias(c) for c in columns]
    )
    in_df2_not_df1 = joined.filter(in_right & (~in_left | differs)).select(
        *[F.col(f"_r.{_q(c)}").alias(c) for c in columns]
    )

    return {
        "in_df1_not_df2": in_df1_not_df2,
        "in_df2_not_df1": in_df2_not_df1,
        "total_differences": in_df1_not_df2.count() + in_df2_not_df1.count(),
    }



In [None]:
def add_date_features(df: DataFrame, date_column: str) -> DataFrame:
    """
    Append calendar features derived from a date/timestamp column.

    Adds (or overwrites) the columns ``year``, ``month``, ``day``, ``day_of_week``,
    ``day_of_year``, ``week_of_year`` and ``quarter``, in that order, each computed with
    the matching ``pyspark.sql.functions`` extractor applied to ``date_column``.

    Args:
        df (DataFrame): Input DataFrame.
        date_column (str): Name of the column holding the date information.

    Returns:
        DataFrame: ``df`` with the date feature columns added.
    """
    source = F.col(date_column)
    # (output column, extractor) pairs, applied in this order.
    derived_features = (
        ("year", F.year),
        ("month", F.month),
        ("day", F.dayofmonth),
        ("day_of_week", F.dayofweek),
        ("day_of_year", F.dayofyear),
        ("week_of_year", F.weekofyear),
        ("quarter", F.quarter),
    )

    enriched = df
    for feature_name, extractor in derived_features:
        enriched = enriched.withColumn(feature_name, extractor(source))
    return enriched

In [None]:

def profile_data_quality(df: DataFrame, sample_size: int = 1000) -> DataFrame:
    """
    Profile the data quality of a DataFrame, one output row per input column.

    Null, non-null and exact distinct counts are computed over the whole DataFrame in a
    single aggregation. Sample values are the distinct non-null values (cast to string)
    found in the first ``sample_size`` rows returned by ``df.limit``, joined with ", "
    and truncated to 100 characters.

    Args:
        df (DataFrame): Input DataFrame.
        sample_size (int): Number of rows scanned for sample values (default: 1000).

    Returns:
        DataFrame: Columns ``column_name``, ``data_type``, ``null_count``,
        ``non_null_count``, ``distinct_count`` and ``sample_values``.

    Notes:
        Column types that Spark cannot count distinct values of (e.g. maps) will make
        the aggregation fail. Which rows ``limit`` returns is not deterministic.
    """
    result_schema = (
        "column_name string, data_type string, null_count long, "
        "non_null_count long, distinct_count long, sample_values string"
    )
    session = df.sparkSession
    fields = df.schema.fields
    if not fields:
        return session.createDataFrame([], result_schema)

    def _ref(name: str):
        # Backtick-quote so names with dots/spaces resolve as a single column.
        return F.col("`" + name.replace("`", "``") + "`")

    # One pass over the full data for all counts (aliases are index-based to avoid
    # clashes with arbitrary column names).
    stat_exprs = []
    for idx, field in enumerate(fields):
        ref = _ref(field.name)
        stat_exprs.extend(
            [
                F.count(F.when(ref.isNull(), 1)).alias(f"null_{idx}"),
                F.count(ref).alias(f"non_null_{idx}"),
                F.countDistinct(ref).alias(f"distinct_{idx}"),
            ]
        )
    stats = df.agg(*stat_exprs).first()

    # Sample values from a bounded number of rows to keep collect_set small.
    sample_exprs = [
        F.collect_set(_ref(field.name).cast("string")).alias(f"sample_{idx}")
        for idx, field in enumerate(fields)
    ]
    samples = df.limit(sample_size).agg(*sample_exprs).first()

    rows = [
        (
            field.name,
            str(field.dataType),
            stats[f"null_{idx}"],
            stats[f"non_null_{idx}"],
            stats[f"distinct_{idx}"],
            ", ".join(samples[f"sample_{idx}"] or [])[:100],
        )
        for idx, field in enumerate(fields)
    ]
    return session.createDataFrame(rows, result_schema)


def binning(df: DataFrame, column: str, num_bins: int, strategy: str = "equal_range") -> DataFrame:
    """
    Perform binning on a numeric column.

    Adds ``<column>_binned`` holding a 1-based integer bin id (null for null values).

    - ``equal_range``: splits ``[min, max]`` of the column into ``num_bins`` intervals of
      equal width; the maximum falls into the last bin. A constant column maps to bin 1.
    - ``equal_frequency``: uses approximate quantiles (relative error 0.01) as cut points,
      so each bin holds roughly the same number of rows. Duplicate cut points are merged,
      which can yield fewer than ``num_bins`` bins on skewed data.

    Args:
        df (DataFrame): Input DataFrame.
        column (str): Numeric column to bin.
        num_bins (int): Number of bins (>= 1).
        strategy (str): Binning strategy ('equal_range' or 'equal_frequency').

    Returns:
        DataFrame: DataFrame with binned column added.

    Raises:
        ValueError: If ``strategy`` is unknown or ``num_bins`` < 1.

    Notes:
        NaN values are not specially handled.
    """
    if strategy not in ("equal_range", "equal_frequency"):
        raise ValueError("Strategy must be either 'equal_range' or 'equal_frequency'")
    if num_bins < 1:
        raise ValueError("num_bins must be a positive integer")

    value = F.col(column)
    output_name = f"{column}_binned"

    if strategy == "equal_range":
        bounds = df.agg(F.min(value).alias("lo"), F.max(value).alias("hi")).first()
        lo, hi = bounds["lo"], bounds["hi"]
        if lo is None:
            # Empty DataFrame or only nulls: nothing to bin.
            return df.withColumn(output_name, F.lit(None).cast("int"))
        if hi == lo:
            bin_expr = F.lit(1)
        else:
            width = (float(hi) - float(lo)) / num_bins
            # Clamp so the maximum value lands in the last bin rather than num_bins + 1.
            bin_expr = F.least(
                F.floor((value - F.lit(float(lo))) / F.lit(width)) + 1,
                F.lit(num_bins),
            )
    else:
        probabilities = [i / num_bins for i in range(1, num_bins)]
        cut_points = sorted(set(df.approxQuantile(column, probabilities, 0.01)))
        # bin id = 1 + number of cut points the value has reached.
        bin_expr = F.lit(1)
        for cut in cut_points:
            bin_expr = bin_expr + F.when(value >= F.lit(cut), 1).otherwise(0)

    return df.withColumn(output_name, F.when(value.isNotNull(), bin_expr).cast("int"))

def detect_data_drift(df1: DataFrame, df2: DataFrame, columns: List[str]) -> DataFrame:
    """
    Detect data drift between two DataFrames for specified columns.

    For each column, the mean, standard deviation, kurtosis and skewness are computed on
    both frames, and a two-sample Kolmogorov-Smirnov test (``scipy.stats.ks_2samp``) is
    run on the non-null, non-NaN values.

    Args:
        df1 (DataFrame): First DataFrame (e.g., training data).
        df2 (DataFrame): Second DataFrame (e.g., new data).
        columns (List[str]): Numeric columns to check for drift.

    Returns:
        DataFrame: One row per column with ``column``, ``ks_statistic``, ``p_value``,
        ``mean1``, ``mean2``, ``std1``, ``std2``, ``kurtosis1``, ``kurtosis2``,
        ``skewness1``, ``skewness2`` (all doubles except ``column``). The KS fields are
        null when either side has no usable values.

    Notes:
        The KS test collects every value of the column from both frames to the driver;
        sample large frames before calling this.
    """
    from pyspark.sql.functions import kurtosis, skewness, mean, stddev
    from scipy import stats

    # Explicit schema: avoids inference failures on empty input or all-None fields.
    result_schema = (
        "`column` string, ks_statistic double, p_value double, "
        "mean1 double, mean2 double, std1 double, std2 double, "
        "kurtosis1 double, kurtosis2 double, skewness1 double, skewness2 double"
    )

    def _to_float(value):
        # Spark may return Decimal (e.g. mean of a decimal column) or None.
        return None if value is None else float(value)

    def _moments(frame: DataFrame, name: str) -> Dict[str, Optional[float]]:
        row = frame.select(
            mean(name).alias("mean"),
            stddev(name).alias("std"),
            kurtosis(name).alias("kurtosis"),
            skewness(name).alias("skewness"),
        ).collect()[0]
        return {key: _to_float(row[key]) for key in ("mean", "std", "kurtosis", "skewness")}

    def _values(frame: DataFrame, name: str) -> List[float]:
        # dropna removes null and NaN, which ks_2samp cannot handle.
        return [float(r[0]) for r in frame.select(name).dropna().collect()]

    drift_results = []
    for column in columns:
        stats1 = _moments(df1, column)
        stats2 = _moments(df2, column)

        values1 = _values(df1, column)
        values2 = _values(df2, column)
        if values1 and values2:
            ks_result = stats.ks_2samp(values1, values2)
            ks_statistic = float(ks_result.statistic)
            p_value = float(ks_result.pvalue)
        else:
            ks_statistic, p_value = None, None

        drift_results.append(
            (
                column, ks_statistic, p_value,
                stats1["mean"], stats2["mean"],
                stats1["std"], stats2["std"],
                stats1["kurtosis"], stats2["kurtosis"],
                stats1["skewness"], stats2["skewness"],
            )
        )

    # Use the input's session instead of relying on a global `spark` variable.
    return df1.sparkSession.createDataFrame(drift_results, result_schema)

def generate_summary_report(df: DataFrame) -> str:
    """
    Generate a summary report of the DataFrame in Markdown format.

    Numeric columns (tinyint/smallint/int/bigint/float/double/decimal) get min, max, mean
    and stddev. Every other column is treated as categorical and gets its five most
    frequent values. Ties between equally frequent values are broken arbitrarily.

    Args:
        df (DataFrame): Input DataFrame.

    Returns:
        str: Markdown-formatted summary report.
    """
    numeric_type_names = {"tinyint", "smallint", "int", "bigint", "float", "double"}

    def _is_numeric(dtype: str) -> bool:
        # df.dtypes reports decimals as e.g. "decimal(10,2)".
        return dtype in numeric_type_names or dtype.startswith("decimal")

    def _ref(name: str):
        # Backtick-quote so names with dots/spaces resolve as a single column.
        return F.col("`" + name.replace("`", "``") + "`")

    num_rows = df.count()
    num_columns = len(df.columns)
    numeric_columns = [name for name, dtype in df.dtypes if _is_numeric(dtype)]
    categorical_columns = [name for name, dtype in df.dtypes if not _is_numeric(dtype)]

    parts = [
        "# DataFrame Summary Report\n\n",
        "## Basic Information\n",
        f"- Number of rows: {num_rows}\n",
        f"- Number of columns: {num_columns}\n",
        f"- Numeric columns: {', '.join(numeric_columns)}\n",
        f"- Categorical columns: {', '.join(categorical_columns)}\n\n",
        "## Numeric Column Statistics\n",
    ]

    if numeric_columns:
        # Compute only the four statistics needed, once, and look them up by name:
        # summary() emits rows in its own order, not the order the labels are printed in.
        summary_rows = (
            df.select(*[_ref(c) for c in numeric_columns])
            .summary("min", "max", "mean", "stddev")
            .collect()
        )
        by_stat = {row[0]: row for row in summary_rows}
        # Position 0 of each summary row is the statistic name; columns follow in order.
        for position, name in enumerate(numeric_columns, start=1):
            parts.append(f"### {name}\n")
            parts.append(f"- Min: {by_stat['min'][position]}\n")
            parts.append(f"- Max: {by_stat['max'][position]}\n")
            parts.append(f"- Mean: {by_stat['mean'][position]}\n")
            parts.append(f"- StdDev: {by_stat['stddev'][position]}\n\n")

    parts.append("## Categorical Column Statistics\n")
    for name in categorical_columns:
        # A dedicated count alias avoids clashing with a user column named "count".
        top_values = (
            df.groupBy(_ref(name))
            .agg(F.count(F.lit(1)).alias("__top_value_count__"))
            .orderBy(F.desc("__top_value_count__"))
            .limit(5)
            .collect()
        )
        parts.append(f"### {name}\n")
        parts.append("Top 5 values:\n")
        parts.extend(f"- {row[0]}: {row[1]}\n" for row in top_values)
        parts.append("\n")

    return "".join(parts)





In [None]:
from typing import List, Dict, Any
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, when, lit, count, abs, concat_ws
from pyspark.sql.types import StructType, DataType


def identify_schema_differences(df1: DataFrame, df2: DataFrame) -> Dict[str, Any]:
    """
    Identify differences in schema between two DataFrames.

    Args:
        df1 (DataFrame): First DataFrame.
        df2 (DataFrame): Second DataFrame.

    Returns:
        Dict[str, Any]: A dictionary with
            - ``columns_only_in_df1``: sorted list of column names missing from ``df2``.
            - ``columns_only_in_df2``: sorted list of column names missing from ``df1``.
            - ``type_mismatches``: ``(name, type_in_df1, type_in_df2)`` tuples for shared
              columns whose data types differ, sorted by column name. Types are rendered
              with ``simpleString()``, which omits nested nullability, so two entries can
              look identical while still differing in nullability.
    """

    def schema_to_dict(schema: StructType) -> Dict[str, DataType]:
        return {field.name: field.dataType for field in schema.fields}

    schema1, schema2 = schema_to_dict(df1.schema), schema_to_dict(df2.schema)
    names1, names2 = set(schema1), set(schema2)

    # Sorting makes the output deterministic (set iteration order is not).
    return {
        "columns_only_in_df1": sorted(names1 - names2),
        "columns_only_in_df2": sorted(names2 - names1),
        "type_mismatches": [
            (name, schema1[name].simpleString(), schema2[name].simpleString())
            for name in sorted(names1 & names2)
            if schema1[name] != schema2[name]
        ],
    }


def compare_row_counts(df1: DataFrame, df2: DataFrame) -> Dict[str, int]:
    """
    Count the rows of both DataFrames and report the difference.

    Args:
        df1 (DataFrame): First DataFrame (counted first).
        df2 (DataFrame): Second DataFrame.

    Returns:
        Dict[str, int]: ``rows_in_df1``, ``rows_in_df2`` and ``row_difference``
        (``rows_in_df2 - rows_in_df1``).
    """
    rows_df1 = df1.count()
    rows_df2 = df2.count()
    return dict(
        rows_in_df1=rows_df1,
        rows_in_df2=rows_df2,
        row_difference=rows_df2 - rows_df1,
    )


def find_key_column_differences(
    df1: DataFrame, df2: DataFrame, key_columns: List[str]
) -> Dict[str, int]:
    """
    Count distinct key combinations that appear in only one of two DataFrames.

    Args:
        df1 (DataFrame): First DataFrame.
        df2 (DataFrame): Second DataFrame.
        key_columns (List[str]): Columns that uniquely identify rows.

    Returns:
        Dict[str, int]: ``keys_only_in_df1`` and ``keys_only_in_df2`` — the number of
        distinct key tuples present in one frame but absent from the other.
    """
    left_keys, right_keys = (
        frame.select(key_columns).distinct() for frame in (df1, df2)
    )

    only_left = left_keys.subtract(right_keys).count()
    only_right = right_keys.subtract(left_keys).count()
    return {"keys_only_in_df1": only_left, "keys_only_in_df2": only_right}


def analyze_timestamp_differences(
    df1: DataFrame, df2: DataFrame, key_columns: List[str], timestamp_column: str
) -> DataFrame:
    """
    Analyze differences in timestamp column between two DataFrames.

    Rows are matched on ``key_columns`` with a full outer join. ``timestamp_comparison``
    is one of:
        - "Both timestamps missing": both sides are null (for example, null timestamps
          on matching rows).
        - "Missing in df1": df1 has no matching row or a null timestamp.
        - "Missing in df2": df2 has no matching row or a null timestamp.
        - "df1 more recent", "df2 more recent" or "Same timestamp".

    Args:
        df1 (DataFrame): First DataFrame.
        df2 (DataFrame): Second DataFrame.
        key_columns (List[str]): Columns that uniquely identify rows.
        timestamp_column (str): Name of the timestamp column.

    Returns:
        DataFrame: The following columns, in this order:
            - ``<key>_df1`` for each key.
            - ``<ts>_df1`` and ``<ts>_df2``.
            - ``timestamp_comparison``.
            - ``<key>_df2`` for each key. These carry the keys of rows that exist only
              in ``df2``.
    """

    def _q(name: str) -> str:
        # Backtick-quote so names with dots/spaces resolve as a single column.
        return "`" + name.replace("`", "``") + "`"

    # Build the condition on the aliases, not on df1[...]/df2[...], which becomes
    # ambiguous when both frames derive from the same source.
    join_condition = [col(f"df1.{_q(c)}") == col(f"df2.{_q(c)}") for c in key_columns]

    ts1 = col(f"df1.{_q(timestamp_column)}")
    ts2 = col(f"df2.{_q(timestamp_column)}")

    # Null checks must come first: comparisons against null are null, which would
    # otherwise fall through to "Same timestamp".
    comparison = (
        when(ts1.isNull() & ts2.isNull(), "Both timestamps missing")
        .when(ts1.isNull(), "Missing in df1")
        .when(ts2.isNull(), "Missing in df2")
        .when(ts1 > ts2, "df1 more recent")
        .when(ts1 < ts2, "df2 more recent")
        .otherwise("Same timestamp")
    )

    return (
        df1.alias("df1")
        .join(df2.alias("df2"), join_condition, "full_outer")
        .select(
            *[col(f"df1.{_q(c)}").alias(f"{c}_df1") for c in key_columns],
            ts1.alias(f"{timestamp_column}_df1"),
            ts2.alias(f"{timestamp_column}_df2"),
            comparison.alias("timestamp_comparison"),
            *[col(f"df2.{_q(c)}").alias(f"{c}_df2") for c in key_columns],
        )
    )


def compare_column_values(
    df1: DataFrame,
    df2: DataFrame,
    key_columns: List[str],
    exclude_columns: List[str] = None,
    numeric_tolerance: float = 1e-6,
) -> DataFrame:
    """
    Compare values in corresponding columns between two DataFrames.

    Rows are matched on ``key_columns`` with a full outer join. Only columns present in
    both frames (minus keys and ``exclude_columns``) are compared.
    - Numeric columns (per ``df1``'s schema) differ when exactly one side is null or
      ``|a - b| > numeric_tolerance``.
    - Other columns differ when they are not null-safe equal.

    Args:
        df1 (DataFrame): First DataFrame.
        df2 (DataFrame): Second DataFrame.
        key_columns (List[str]): Columns that uniquely identify rows.
        exclude_columns (List[str], optional): Columns to exclude from comparison.
        numeric_tolerance (float, optional): Tolerance for numeric comparisons.

    Returns:
        DataFrame: The key columns (taken from whichever side has the row) plus one
        ``<col>_diff`` column per compared column, rendered as ``"old -> new"`` with
        ``null`` shown literally. Only rows with at least one difference are kept.
    """
    from pyspark.sql.types import NumericType

    if exclude_columns is None:
        exclude_columns = []

    def _side(alias: str, name: str):
        # Backtick-quote so names with dots/spaces resolve as a single column.
        return col(f"{alias}.`" + name.replace("`", "``") + "`")

    def _render(value):
        # concat_ws silently skips nulls; show them explicitly instead.
        return F.coalesce(value.cast("string"), lit("null"))

    skipped = set(key_columns) | set(exclude_columns)
    df2_columns = set(df2.columns)
    comparison_columns = [
        c for c in df1.columns if c not in skipped and c in df2_columns
    ]

    # Condition on the aliases to avoid self-join ambiguity when frames share lineage.
    join_condition = [_side("df1", k) == _side("df2", k) for k in key_columns]
    diff_df = df1.alias("df1").join(df2.alias("df2"), join_condition, "full_outer")

    diff_exprs = []
    for col_name in comparison_columns:
        left, right = _side("df1", col_name), _side("df2", col_name)
        if isinstance(df1.schema[col_name].dataType, NumericType):
            # Exactly one side null counts as a difference. eqNullSafe also treats
            # NaN == NaN as equal.
            differs = ~left.eqNullSafe(right) & (
                (left.isNull() != right.isNull())
                | (F.abs(left - right) > numeric_tolerance)
            )
        else:
            differs = ~left.eqNullSafe(right)
        diff_exprs.append(
            when(differs, concat_ws(" -> ", _render(left), _render(right))).alias(
                f"{col_name}_diff"
            )
        )

    # Bare key names are ambiguous after the join; take whichever side is present.
    key_exprs = [
        F.coalesce(_side("df1", k), _side("df2", k)).alias(k) for k in key_columns
    ]
    result = diff_df.select(*key_exprs, *diff_exprs)

    any_diff = lit(False)
    for col_name in comparison_columns:
        any_diff = any_diff | col("`" + f"{col_name}_diff".replace("`", "``") + "`").isNotNull()
    return result.where(any_diff)


def analyze_dataframe_differences(
    df1: DataFrame,
    df2: DataFrame,
    key_columns: List[str],
    timestamp_column: str = None,
    exclude_columns: List[str] = None,
    numeric_tolerance: float = 1e-6,
) -> Dict[str, Any]:
    """
    Comprehensive analysis of differences between two DataFrames.

    All inputs are validated before any Spark job runs. The row-count and key
    comparisons trigger ``count()`` jobs; the value and timestamp analyses are
    returned as lazy DataFrames.

    Args:
        df1 (DataFrame): First DataFrame.
        df2 (DataFrame): Second DataFrame.
        key_columns (List[str]): Columns that uniquely identify rows.
        timestamp_column (str, optional): Name of the timestamp column.
        exclude_columns (List[str], optional): Columns to exclude from comparison.
        numeric_tolerance (float, optional): Tolerance for numeric comparisons.

    Returns:
        Dict[str, Any]: The following keys:
            - ``schema_differences``
            - ``row_count_comparison``
            - ``key_column_differences``
            - ``value_differences``
            - ``timestamp_analysis``, present only when ``timestamp_column`` is given.

    Raises:
        ValueError: If key columns (or the timestamp column) are not present in both
            DataFrames.
    """
    key_set = set(key_columns)
    if not key_set.issubset(df1.columns) or not key_set.issubset(df2.columns):
        raise ValueError("Key columns must be present in both DataFrames.")

    # Validate up front so a bad timestamp column fails before the count() jobs below.
    if timestamp_column and (
        timestamp_column not in df1.columns or timestamp_column not in df2.columns
    ):
        raise ValueError(
            f"Timestamp column '{timestamp_column}' not found in one or both DataFrames."
        )

    analysis = {
        "schema_differences": identify_schema_differences(df1, df2),
        "row_count_comparison": compare_row_counts(df1, df2),
        "key_column_differences": find_key_column_differences(df1, df2, key_columns),
        "value_differences": compare_column_values(
            df1, df2, key_columns, exclude_columns, numeric_tolerance
        ),
    }

    if timestamp_column:
        analysis["timestamp_analysis"] = analyze_timestamp_differences(
            df1, df2, key_columns, timestamp_column
        )

    return analysis