In [None]:
from pyspark.sql import DataFrame
from typing import Optional, List, Union


def save_dataframe_to_csv(
    df: DataFrame,
    folder_path: str,
    file_name: str,
    partition_by: Optional[Union[str, List[str]]] = None,
    overwrite: bool = False,
    header: bool = True,
    delimiter: str = ",",
    encoding: str = "utf-8",
) -> None:
    """
    Save a PySpark DataFrame as CSV under ``<folder_path>/<file_name>`` in Databricks.

    Spark writes a *directory* at ``<folder_path>/<file_name>`` rather than a bare
    file. The DataFrame is coalesced to one partition first, so an unpartitioned
    write yields a single ``part-*.csv`` data file inside that directory; with
    ``partition_by``, each partition sub-directory gets its own data file.

    After writing, the output is read back to report the row count and the
    number of data files produced.

    Args:
        df (DataFrame): The PySpark DataFrame to be saved.
        folder_path (str): The parent directory for the output.
        file_name (str): Name of the output directory created under ``folder_path``.
            No ``.csv`` extension is appended.
        partition_by (Optional[Union[str, List[str]]]): Column(s) to partition the output by.
            Note: Using this will result in multiple files. Default is None.
        overwrite (bool): Whether to replace existing output at the target path. When
            False, the write fails if the path already exists. Default is False.
        header (bool): Whether to include column names as the first row. Default is True.
        delimiter (str): The delimiter to use in the CSV file. Default is ",".
        encoding (str): The character encoding to use. Default is "utf-8".

    Raises:
        ValueError: If the input validation fails.
        Exception: If there's an error during the save operation; the original
            error is chained as ``__cause__``.

    Example:
        >>> df = spark.createDataFrame([(1, "Alice"), (2, "Bob")], ["id", "name"])
        >>> save_dataframe_to_csv(df, "dbfs:/path/to/folder", "users")
    """
    # Input validation
    if not isinstance(df, DataFrame):
        raise ValueError("❌ Input 'df' must be a PySpark DataFrame")
    if not isinstance(folder_path, str) or not folder_path.strip():
        raise ValueError("❌ 'folder_path' must be a non-empty string")
    if not isinstance(file_name, str) or not file_name.strip():
        raise ValueError("❌ 'file_name' must be a non-empty string")
    if partition_by is not None and not (
        isinstance(partition_by, str)
        or (
            isinstance(partition_by, list)
            and all(isinstance(column, str) for column in partition_by)
        )
    ):
        raise ValueError("❌ 'partition_by' must be a string, list of strings, or None")

    try:
        # Construct the full output path (a directory, see docstring)
        full_path = f"{folder_path}/{file_name}"

        # Prepare write options
        write_options = {
            "header": str(header).lower(),
            "delimiter": delimiter,
            "encoding": encoding,
        }

        # coalesce(1) funnels the data through a single task so each output
        # directory holds one data file; this is a bottleneck for very large data.
        write_op = (
            df.coalesce(1)
            .write.options(**write_options)
            .mode("overwrite" if overwrite else "error")
        )

        # Apply partitioning if specified
        if partition_by:
            write_op = write_op.partitionBy(partition_by)
            print("⚠️ Note: Using partition_by will result in multiple output files.")

        # Save the DataFrame as CSV
        write_op.csv(full_path)

        print(f"✅ DataFrame successfully saved as CSV: {full_path}")

        # Read the output back as a sanity check. Type inference is not needed
        # just to count rows, so skip the extra pass it would cost.
        saved_df = spark.read.csv(full_path, header=header)
        row_count = saved_df.count()
        # inputFiles() lists the actual data files; getNumPartitions() would
        # report read-side splits, which need not match the number of files.
        file_count = len(saved_df.inputFiles())
        print(f"📊 Saved {row_count} rows in {file_count} file(s)")

    except Exception as e:
        error_message = f"❌ Error saving DataFrame to CSV: {str(e)}"
        print(error_message)
        raise Exception(error_message) from e

In [None]:
from pyspark.sql import DataFrame
from pyspark.sql import functions as F
from pyspark.sql.types import StructType
from typing import List, Optional, Union
import re


def _list_csv_files(folder_path: str, recursive: bool, ignored: set) -> List[str]:
    """Return sorted paths of ``.csv`` files under ``folder_path`` whose name is not in ``ignored``.

    Listing uses ``dbutils.fs.ls``; sub-directories are walked only when
    ``recursive`` is True.
    """
    paths: List[str] = []
    to_visit = [folder_path]
    while to_visit:
        for info in dbutils.fs.ls(to_visit.pop()):
            if info.name.endswith("/"):
                # dbutils.fs.ls marks directories with a trailing "/" in their name.
                if recursive:
                    to_visit.append(info.path)
            elif info.path.endswith(".csv") and info.name not in ignored:
                paths.append(info.path)
    return sorted(paths)


def get_dataframe(
    folder_path: str,
    schema: Optional[Union[str, StructType]] = None,
    encoding: str = "utf-8",
    header: bool = True,
    delimiter: str = ",",
    ignore_files: Optional[List[str]] = None,
    recursive: bool = False,
    infer_schema: bool = False,
    cache: bool = False,
    repartition: Optional[int] = None,
) -> DataFrame:
    """
    Read all CSV files in a folder into a single DataFrame in Databricks.

    All matching files are read with a single ``spark.read.csv`` call, and a
    ``source_file`` column is added recording the file each row came from.

    Args:
        folder_path (str): The directory path containing the CSV file(s).
        schema (Optional[Union[str, StructType]]): Schema for the DataFrame, either a
            DDL string (e.g. ``"id INT, name STRING"``) or a StructType. When None or
            empty, columns are read as strings unless ``infer_schema`` is True.
            Default is None.
        encoding (str): The character encoding of the CSV file(s). Default is "utf-8".
        header (bool): Whether the CSV file(s) have a header row. Default is True.
        delimiter (str): The delimiter used in the CSV file(s). Default is ",".
        ignore_files (Optional[List[str]]): Exact file names (not paths) to skip.
            Default is None (skip nothing).
        recursive (bool): Whether to also read CSV files in sub-folders at any depth.
            Default is False.
        infer_schema (bool): Whether Spark should infer column types from the data.
            Ignored when ``schema`` is given. Default is False.
        cache (bool): Whether to cache the resulting DataFrame. Default is False.
        repartition (Optional[int]): Number of partitions for the final DataFrame.
            If None, no repartitioning is done.

    Returns:
        DataFrame: A PySpark DataFrame containing the combined data from all CSV files.

    Raises:
        ValueError: If input validation fails, the folder is not accessible, the
            schema string is invalid, or no CSV files remain after applying
            ``ignore_files``.
        Exception: If there's an error during the read operation; the original
            error is chained as ``__cause__``.

    Example:
        >>> df = get_dataframe("dbfs:/path/to/csv/folder", header=True, recursive=True, cache=True)
        >>> df.show()
    """
    # Input validation
    if not isinstance(folder_path, str) or not folder_path.strip():
        raise ValueError("❌ 'folder_path' must be a non-empty string")
    if schema is not None and not isinstance(schema, (str, StructType)):
        raise ValueError("❌ 'schema' must be a DDL string, a StructType, or None")

    try:
        dbutils.fs.ls(folder_path)
    except Exception as e:
        raise ValueError(
            f"❌ The specified folder does not exist or is not accessible: {folder_path}"
        ) from e

    try:
        # Read options
        read_options = {
            "encoding": encoding,
            "header": str(header).lower(),
            "delimiter": delimiter,
            "ignoreLeadingWhiteSpace": "true",
            "ignoreTrailingWhiteSpace": "true",
        }
        reader = spark.read.options(**read_options)

        if schema:
            # A schema must go through .schema(); passing it via .options() is
            # silently ignored by the CSV source. A DDL string is parsed here.
            try:
                reader = reader.schema(schema)
            except Exception as e:
                raise ValueError(f"❌ Invalid schema string: {str(e)}") from e
        elif infer_schema:
            reader = reader.option("inferSchema", "true")

        if ignore_files:
            # List files ourselves so exact names can be excluded (honouring `recursive`).
            files = _list_csv_files(folder_path, recursive, set(ignore_files))
            if not files:
                raise ValueError(
                    f"❌ No CSV files found in the specified folder after applying ignore list: {folder_path}"
                )
            df = reader.csv(files)
        elif recursive:
            # Hadoop globs have no recursive "**"; let Spark walk sub-folders itself.
            df = (
                reader.option("recursiveFileLookup", "true")
                .option("pathGlobFilter", "*.csv")
                .csv(folder_path)
            )
        else:
            df = reader.csv(f"{folder_path}/*.csv")

        # Add metadata column
        df = df.withColumn("source_file", F.input_file_name())

        # Repartition if specified
        if repartition:
            df = df.repartition(repartition)

        # Cache if specified
        if cache:
            df = df.cache()

        print(f"✅ Successfully read CSV file(s) from: {folder_path}")
        return df

    except ValueError:
        # Validation failures are part of the documented contract; don't wrap them.
        raise
    except Exception as e:
        error_message = f"❌ Error reading CSV file(s): {str(e)}"
        print(error_message)
        raise Exception(error_message) from e

In [None]:
from pyspark.sql import DataFrame
from delta.tables import DeltaTable
from pyspark.sql.utils import AnalysisException


def save_dataframe_to_delta(
    df: DataFrame,
    table_name: str,
    overwrite_schema: bool = False,
    merge_schema: bool = False,
    optimize_write: bool = True,
    partition_by: list = None,
) -> None:
    """
    Save a PySpark DataFrame as a Delta table in Databricks.

    The DataFrame is always written with ``mode("overwrite")`` through
    ``saveAsTable``, replacing the table's data. How an existing table's schema
    is treated depends on the flags:

    * ``overwrite_schema=True``: the table schema is replaced by the DataFrame's
      schema (Delta ``overwriteSchema`` option).
    * ``merge_schema=True`` (with ``overwrite_schema=False``): the DataFrame's
      schema is merged into the table schema (Delta ``mergeSchema`` option).
    * neither: data is replaced and the existing schema is kept; Delta's schema
      enforcement rejects the write if the DataFrame's schema is incompatible.

    Args:
        df (DataFrame): The PySpark DataFrame to be saved.
        table_name (str): The name of the Delta table to be created or overwritten.
        overwrite_schema (bool): Whether to replace the schema if the table already exists.
                                 Default is False.
        merge_schema (bool): Whether to merge the new schema with the existing one.
                             Only applicable when overwrite_schema is False. Default is False.
        optimize_write (bool): Whether to set the ``optimizeWrite`` write option. Default is True.
        partition_by (list): List of column names to partition the table by. Default is None.

    Raises:
        ValueError: If the input validation fails.
        Exception: If there's an error during the save operation; the original
            error is chained as ``__cause__``.

    Example:
        >>> df = spark.createDataFrame([(1, "Alice"), (2, "Bob")], ["id", "name"])
        >>> save_dataframe_to_delta(df, "users_table", merge_schema=True, partition_by=["id"])
    """
    # Input validation
    if not isinstance(df, DataFrame):
        raise ValueError("❌ Input 'df' must be a PySpark DataFrame")
    if not isinstance(table_name, str) or not table_name.strip():
        raise ValueError("❌ 'table_name' must be a non-empty string")
    if partition_by and not (
        isinstance(partition_by, list)
        and all(isinstance(column, str) for column in partition_by)
    ):
        raise ValueError("❌ 'partition_by' must be a list of column names or None")

    try:
        # Check if the table already exists (uses the JVM catalog handle).
        table_exists = spark.catalog._jcatalog.tableExists(table_name)

        # Prepare the write operation
        write_op = df.write.format("delta")

        if optimize_write:
            write_op = write_op.option("optimizeWrite", "true")

        if partition_by:
            write_op = write_op.partitionBy(partition_by)

        if table_exists:
            if overwrite_schema:
                # Without this option Delta keeps the old schema and rejects
                # an incompatible DataFrame, so the schema would never change.
                write_op = write_op.option("overwriteSchema", "true")
                print(
                    f"ℹ️ Overwriting existing table '{table_name}' including its schema."
                )
            elif merge_schema:
                write_op = write_op.option("mergeSchema", "true")
                print(f"ℹ️ Merging new schema with existing table '{table_name}'.")
            else:
                print(
                    f"ℹ️ Overwriting data in existing table '{table_name}' without changing schema."
                )
        else:
            print(f"ℹ️ Creating new table '{table_name}'.")

        # Every branch replaces the table's data.
        write_op = write_op.mode("overwrite")

        # Save the DataFrame as a Delta table
        write_op.saveAsTable(table_name)

        print(f"✅ DataFrame successfully saved as Delta table: {table_name}")

        # Get and print some statistics about the saved data
        saved_df = spark.table(table_name)
        row_count = saved_df.count()
        column_count = len(saved_df.columns)
        print(
            f"📊 Saved {row_count} rows with {column_count} columns in Delta table '{table_name}'"
        )

        # Print Delta table details
        delta_table = DeltaTable.forName(spark, table_name)
        print(f"ℹ️ Delta table information:")
        print(
            f"   - Location: {delta_table.detail().select('location').collect()[0][0]}"
        )
        print(
            f"   - Version: {delta_table.history(1).select('version').collect()[0][0]}"
        )

    except AnalysisException as ae:
        error_message = f"❌ Error saving DataFrame to Delta table: {str(ae)}"
        print(error_message)
        if "already exists" in str(ae):
            print(
                "ℹ️ Consider using overwrite_schema=True if you want to replace the existing table schema."
            )
        raise Exception(error_message) from ae
    except Exception as e:
        error_message = f"❌ Error saving DataFrame to Delta table: {str(e)}"
        print(error_message)
        raise Exception(error_message) from e

In [None]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import expr
import re
from typing import List, Dict, Tuple


def clean_single_name(name: str, to_upper: bool = False) -> str:
    """
    Clean a single column name.

    The name is case-converted, each run of whitespace or the punctuation
    characters ``- . [ ] ( ) { } , ; :`` is collapsed into one underscore, and
    any remaining character outside ``[A-Za-z0-9_]`` is dropped. If the result
    is empty or does not start with a letter, a ``col_`` prefix (``COL_`` when
    ``to_upper`` is True) is prepended so the name always starts with a letter.

    Args:
        name (str): The column name to clean
        to_upper (bool): If True, converts name to uppercase. Default is False (lowercase).

    Returns:
        str: The cleaned column name
    """
    name = name.upper() if to_upper else name.lower()
    name = re.sub(r"[\s\-\.\[\]\(\)\{\}\,\;\:]+", "_", name)
    name = re.sub(r"[^a-zA-Z0-9_]", "", name)
    # Check for an empty result first: names like "!!!" clean down to "".
    if name and name[0].isalpha():
        return name
    # Keep the prefix in the same case as the rest of the name.
    prefix = "COL_" if to_upper else "col_"
    return prefix + name


def get_unique_names(names: List[str]) -> List[str]:
    """
    Ensure uniqueness of column names.

    The first occurrence of each name is kept as-is. Later duplicates get a
    numeric suffix (``name_1``, ``name_2``, ...), skipping any suffixed name that
    is already used or that appears elsewhere in the input, so the result never
    contains duplicates.

    Args:
        names (List[str]): List of column names

    Returns:
        List[str]: List of unique column names, in the same order as the input
    """
    # Original names are reserved so a generated suffix never steals one that
    # appears later in the list (e.g. ["a", "a", "a_1"]).
    reserved = set(names)
    used = set()
    last_suffix: Dict[str, int] = {}
    unique_names: List[str] = []
    for name in names:
        if name not in used:
            used.add(name)
            unique_names.append(name)
            continue
        suffix = last_suffix.get(name, 0)
        while True:
            suffix += 1
            candidate = f"{name}_{suffix}"
            if candidate not in used and candidate not in reserved:
                break
        last_suffix[name] = suffix
        used.add(candidate)
        unique_names.append(candidate)
    return unique_names


def generate_rename_expressions(
    old_names: List[str], new_names: List[str]
) -> List[str]:
    """
    Generate Spark SQL ``selectExpr`` expressions for renaming columns.

    Names are paired positionally; only pairs that differ produce an
    expression. Both identifiers are back-quoted, with embedded backticks
    doubled, so arbitrary column names yield valid SQL.

    Args:
        old_names (List[str]): Original column names
        new_names (List[str]): New column names

    Returns:
        List[str]: List of expressions of the form ``\`old\` AS \`new\```
    """

    def quote(identifier: str) -> str:
        # Spark SQL escapes a backtick inside a quoted identifier by doubling it.
        return "`" + identifier.replace("`", "``") + "`"

    return [
        f"{quote(old)} AS {quote(new)}"
        for old, new in zip(old_names, new_names)
        if old != new
    ]


def log_column_changes(old_names: List[str], new_names: List[str]) -> None:
    """
    Print a summary of which column names were changed.

    Names are paired positionally and only differing pairs are listed, one per
    line as ``'old' -> 'new'``. When nothing changed, a single confirmation
    line is printed instead.

    Args:
        old_names (List[str]): Original column names
        new_names (List[str]): New column names
    """
    renamed = [pair for pair in zip(old_names, new_names) if pair[0] != pair[1]]
    if not renamed:
        print("✅ No column renaming was necessary.")
        return

    report = ["🔄 Column renaming summary:"]
    report.extend(f"  '{before}' -> '{after}'" for before, after in renamed)
    print("\n".join(report))


def clean_column_names(df: DataFrame, to_upper: bool = False) -> DataFrame:
    """
    Clean and standardize the column names of a PySpark DataFrame.

    Each name is cleaned with ``clean_single_name`` and the results are made
    unique with ``get_unique_names``. The DataFrame is then renamed
    positionally, so every column is kept and column order is unchanged.
    A summary of the renames is printed.

    Args:
        df (DataFrame): The input PySpark DataFrame
        to_upper (bool): If True, converts names to uppercase. Default is False (lowercase).

    Returns:
        DataFrame: A new DataFrame with cleaned column names (the input
        DataFrame is returned unchanged if no renaming is needed)

    Raises:
        ValueError: If the input is not a PySpark DataFrame
        RuntimeError: If column cleaning or renaming fails (original error chained)

    Example:
        >>> df = spark.createDataFrame([(1, 'a'), (2, 'b')], ['ID', 'Column With Spaces!'])
        >>> cleaned_df = clean_column_names(df)
        >>> cleaned_df.printSchema()
        root
         |-- id: long (nullable = true)
         |-- column_with_spaces: string (nullable = true)

    Note:
        Renaming is a single ``toDF`` projection, which works by position and
        therefore also handles duplicate or backtick-containing original names.
    """
    if not isinstance(df, DataFrame):
        raise ValueError("❌ Input must be a PySpark DataFrame")

    try:
        old_names = df.columns
        new_names = get_unique_names(
            [clean_single_name(name, to_upper) for name in old_names]
        )

        # Rename all columns at once by position; unchanged columns are kept
        # in place instead of being dropped or reordered.
        if new_names != old_names:
            df = df.toDF(*new_names)

        # Compare against the ORIGINAL names, not the already-renamed frame.
        log_column_changes(old_names, new_names)

        return df

    except Exception as e:
        raise RuntimeError(f"❌ Error during column name cleaning: {str(e)}") from e

In [None]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import count, when, col
from typing import List


def assert_dataframe_equality(df1: DataFrame, df2: DataFrame) -> None:
    """
    Assert that two PySpark DataFrames have equal content (as multisets of rows).

    Both DataFrames must have identical column lists (same names, same order).
    Row order is ignored, but the number of occurrences of each distinct row
    must match, so differing duplicate counts are reported as mismatches.

    Args:
        df1 (DataFrame): The first DataFrame to compare
        df2 (DataFrame): The second DataFrame to compare

    Raises:
        AssertionError: If the inputs are not DataFrames, the column names differ,
            the contents differ (up to five differing row values are shown), or the
            comparison itself fails (original error chained).

    Example:
        >>> df1 = spark.createDataFrame([(1, 'a'), (2, 'b')], ['id', 'value'])
        >>> df2 = spark.createDataFrame([(1, 'a'), (2, 'b')], ['id', 'value'])
        >>> assert_dataframe_equality(df1, df2)  # This will pass
        >>>
        >>> df3 = spark.createDataFrame([(1, 'a'), (3, 'c')], ['id', 'value'])
        >>> assert_dataframe_equality(df1, df3)  # This will raise an AssertionError
    """

    # Check if both inputs are DataFrames
    if not isinstance(df1, DataFrame) or not isinstance(df2, DataFrame):
        raise AssertionError(
            "🚫 Oops! Both inputs must be PySpark DataFrames. Let's stick to the script!"
        )

    # Check if the column names are the same
    if df1.columns != df2.columns:
        raise AssertionError(
            f"🔤 Column name mismatch! We've got:\n{df1.columns}\nvs\n{df2.columns}"
        )

    try:
        # Tag each row with its origin (1 = df1, 2 = df2) and stack both frames.
        # unionAll matches columns by position, which is safe because the
        # column lists were checked to be identical above.
        df_combined = df1.selectExpr("*", "1 as df1_marker").unionAll(
            df2.selectExpr("*", "2 as df1_marker")
        )

        # Back-quote names so columns containing dots (or backticks) are
        # resolved as plain columns rather than struct-field paths.
        group_cols = [
            col("`" + name.replace("`", "``") + "`") for name in df1.columns
        ]

        # For every distinct row value, count its occurrences in each frame;
        # any value whose counts differ is a mismatch.
        diff_stats = (
            df_combined.groupBy(*group_cols)
            .agg(
                count(when(col("df1_marker") == 1, True)).alias("count_df1"),
                count(when(col("df1_marker") == 2, True)).alias("count_df2"),
            )
            .where("count_df1 != count_df2")
        )

        # Number of distinct row values whose occurrence counts differ.
        diff_count = diff_stats.count()

        if diff_count > 0:
            sample_diff = diff_stats.limit(5).collect()

            error_message = f"📊 DataFrames are like apples and oranges! Found {diff_count} mismatched rows.\n"
            error_message += "🔍 Here's a sneak peek at the differences:\n"
            for row in sample_diff:
                error_message += f"   {row.asDict()}\n"
            error_message += "💡 Tip: Check your data sources or transformations!"

            raise AssertionError(error_message)

        # If we've made it this far, the DataFrames are equal
        print("🎉 Jackpot! The DataFrames are identical twins.")

    except AssertionError:
        raise
    except Exception as e:
        raise AssertionError(
            f"❌ An unexpected error occurred while comparing DataFrames: {str(e)}"
        ) from e

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.types import (
    StructType,
    StructField,
    IntegerType,
    StringType,
    DoubleType,
)

import os

# Initialize Spark Session (getOrCreate reuses an already-running session)
spark = SparkSession.builder.appName("DataFrameEqualityTest").getOrCreate()

# Define schema
schema = StructType(
    [
        StructField("id", IntegerType(), False),
        StructField("name", StringType(), False),
        StructField("score", DoubleType(), True),
    ]
)

# Create two identical DataFrames
data1 = [(1, "Alice", 85.5), (2, "Bob", 92.0), (3, "Charlie", None), (4, "David", 78.5)]

df1 = spark.createDataFrame(data1, schema)
df2 = spark.createDataFrame(data1, schema)

# Create a slightly different DataFrame
data3 = [
    (1, "Alice", 85.5),
    (2, "Bob", 92.0),
    (3, "Charlie", None),
    (4, "David", 79.0),  # Changed score from 78.5 to 79.0
]

df3 = spark.createDataFrame(data3, schema)

# Test the assert_dataframe_equality function
print("Test 1: Comparing identical DataFrames")
try:
    assert_dataframe_equality(df1, df2)
except AssertionError as e:
    print(f"Unexpected Error: {str(e)}")

print("\nTest 2: Comparing DataFrames with a difference")
try:
    assert_dataframe_equality(df1, df3)
except AssertionError as e:
    print(f"Expected Error: {str(e)}")
else:
    # Without this branch a broken comparison would look like a passing test.
    print("Unexpected: no AssertionError was raised for differing DataFrames")

# Clean up: only stop the session outside Databricks. On a Databricks cluster
# (which sets DATABRICKS_RUNTIME_VERSION) the session is shared by the whole
# notebook, including the helper functions above, so stopping it would break
# every later cell.
if "DATABRICKS_RUNTIME_VERSION" not in os.environ:
    spark.stop()