In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType
from pyspark.sql.functions import col

In [1]:
spark = SparkSession.builder.appName("Change Data Type").getOrCreate()

# Sample DataFrame
df = spark.createDataFrame(
    [(1, "100", "Tom", "Brady"), (2, "200", "Tomson", "Lary")],
    ["id", "value", "Name", "SURNAME"],
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/06/22 19:16:58 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [15]:
def change_dtypes(df, dict_mapping):
    """
    This function changes the data types of specified columns in a DataFrame.

    It iterates over a dictionary where the keys are column names and the values are the desired data types.
    For each key-value pair, it changes the data type of the corresponding column in the DataFrame.

    Parameters:
    df (pyspark.sql.DataFrame): The DataFrame whose column data types are to be changed.
    dict_mapping (dict): A dictionary mapping column names (keys) to their desired data types (values).

    Returns:
    df (pyspark.sql.DataFrame): The DataFrame with the changed data types.

    Example:
    Suppose we have a DataFrame 'df' with columns 'A', 'B', and 'C' of type string, integer, and string respectively.
    If we want to change 'A' to integer and 'C' to float, we would use the function as follows:

    dict_mapping_dtypes = {'A': 'int', 'C': 'float'}
    df = change_dtypes(df, dict_mapping_dtypes)
    """
    for col_name, col_dtype in dict_mapping.items():
        df = df.withColumn(col_name, df[col_name].cast(col_dtype))
    return df

In [None]:
def rename_columns(df, columns_map):
    """
    This function renames the columns of a DataFrame based on a provided mapping.

    Parameters:
    df (DataFrame): The DataFrame whose columns are to be renamed.
    columns_map (dict): A dictionary mapping old column names (keys) to new column names (values).

    Returns:
    DataFrame: The DataFrame with renamed columns.

    Example:
    Suppose we have a DataFrame 'df' with columns 'first_name' and 'last_name'.
    If we want to rename 'first_name' to 'name' and 'last_name' to 'surname', we would use the function as follows:

    columns_map = {"first_name": "name", "last_name": "surname"}
    df = rename_columns(df, columns_map)
    """
    for old_name, new_name in columns_map.items():
        df = df.withColumnRenamed(old_name, new_name)
    # Return the DataFrame with renamed columns
    return df


columns_map = {"first_name": "name", "last_name": "surname"}
# Use the rename_columns function to rename the columns in the DataFrame
df = rename_columns(df, columns_map)

In [16]:
# Dictionary for replacement
replacement_dict = {
    "Tom": "Benn",
}

df_replaced = df.replace(replacement_dict, subset=["Name"])
df_replaced.show()

+---+-----+------+-------+
| id|value|  Name|SURNAME|
+---+-----+------+-------+
|  1|  100|  Benn|  Brady|
|  2|  200|Tomson|   Lary|
+---+-----+------+-------+



In [5]:
df.show()

                                                                                

+---+---+--------+----------+
| ID|Age|  Salary|Experience|
+---+---+--------+----------+
|  1| 25| 30000.0|         2|
|  2| 35| 50000.0|         7|
|  3| 45| 75000.0|        15|
|  4| 55|100000.0|        25|
|  5| 65| 80000.0|        35|
|  6| 18| 20000.0|         0|
|  7| 75| 60000.0|        40|
|  8| 30| 45000.0|         5|
|  9| 40| 70000.0|        12|
| 10| 50| 90000.0|        20|
+---+---+--------+----------+



In [44]:
import re
from pyspark.sql import DataFrame


def clean_column_names(df: DataFrame) -> DataFrame:
    """
    Cleans the column names of a PySpark DataFrame to ensure consistency:
    - Trims leading and trailing spaces.
    - Replaces spaces with underscores.
    - Removes any non-alphanumeric characters except underscores.
    - Ensures the column name does not start with a digit by prefixing with 'col_' if necessary.
    - Ensures the column name does not end with an underscore.

    Args:
        df (DataFrame): The input PySpark DataFrame.

    Returns:
        DataFrame: A new DataFrame with cleaned column names.

    Example:
        >>> from pyspark.sql import SparkSession
        >>> spark = SparkSession.builder.appName("CleanColumnNames").getOrCreate()
        >>> data = [("Alice", 34, "NY"), ("Bob", 45, "CA"), ("Charlie", 29, "TX")]
        >>> df = spark.createDataFrame(data, [" name ", " age ", " st@te "])
        >>> df_clean = clean_column_names(df)
        >>> df_clean.show()
        +-------+---+-----+
        |   name|age|state|
        +-------+---+-----+
        |  Alice| 34|   NY|
        |    Bob| 45|   CA|
        |Charlie| 29|   TX|
        +-------+---+-----+
        >>> spark.stop()
    """
    new_column_names = []
    for column in df.columns:
        # Trim leading and trailing spaces
        new_col = column.strip()
        # Replace spaces with underscores
        new_col = new_col.replace(" ", "_")
        # Remove any non-alphanumeric characters except underscores
        new_col = re.sub(r"[^\w]", "", new_col)
        # Ensure the column does not start with a digit
        if new_col[0].isdigit():
            new_col = f"col_{new_col}"
        # Ensure the column does not end with an underscore
        if new_col.endswith("_"):
            new_col = new_col.rstrip("_")
        new_column_names.append(new_col)

    for old_col, new_col in zip(df.columns, new_column_names):
        df = df.withColumnRenamed(old_col, new_col)

    return df


# Example usage
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("CleanColumnNames").getOrCreate()

data = [("Alice", 34, "NY"), ("Bob", 45, "CA"), ("Charlie", 29, "TX")]
df = spark.createDataFrame(data, [" name ", " age ", " st@te "])

df_clean = clean_column_names(df)
df_clean.show()

# Clean up
spark.stop()

                                                                                

+-------+---+----+
|   name|age|stte|
+-------+---+----+
|  Alice| 34|  NY|
|    Bob| 45|  CA|
|Charlie| 29|  TX|
+-------+---+----+



In [45]:
from pyspark.sql import DataFrame


def convert_column_names(df: DataFrame, case: str = "lower") -> DataFrame:
    """
    Converts the column names of a PySpark DataFrame to the specified case:
    - 'upper': Converts all column names to uppercase.
    - 'lower': Converts all column names to lowercase.
    - 'title': Converts the first letter of each word in column names to uppercase.

    Args:
        df (DataFrame): The input PySpark DataFrame.
        case (str): The case to convert the column names to ('upper', 'lower', 'title').

    Returns:
        DataFrame: A new DataFrame with converted column names.

    Example:
        >>> from pyspark.sql import SparkSession
        >>> spark = SparkSession.builder.appName("ConvertColumnNames").getOrCreate()
        >>> data = [("Alice", 34, "NY"), ("Bob", 45, "CA"), ("Charlie", 29, "TX")]
        >>> df = spark.createDataFrame(data, ["name", "age", "state"])
        >>> df_upper = convert_column_names(df, 'upper')
        >>> df_upper.show()
        +-------+---+-----+
        |   NAME|AGE|STATE|
        +-------+---+-----+
        |  Alice| 34|   NY|
        |    Bob| 45|   CA|
        |Charlie| 29|   TX|
        +-------+---+-----+
        >>> df_title = convert_column_names(df, 'title')
        >>> df_title.show()
        +-------+---+-----+
        |   Name|Age|State|
        +-------+---+-----+
        |  Alice| 34|   NY|
        |    Bob| 45|   CA|
        |Charlie| 29|   TX|
        +-------+---+-----+
        >>> spark.stop()
    """
    if case not in ["upper", "lower", "title"]:
        raise ValueError(
            "The 'case' parameter must be one of 'upper', 'lower', or 'title'."
        )

    new_column_names = []
    for column in df.columns:
        if case == "upper":
            new_col = column.upper()
        elif case == "lower":
            new_col = column.lower()
        elif case == "title":
            new_col = column.title()
        new_column_names.append(new_col)

    for old_col, new_col in zip(df.columns, new_column_names):
        df = df.withColumnRenamed(old_col, new_col)

    return df


# Example usage
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("ConvertColumnNames").getOrCreate()

data = [("Alice", 34, "NY"), ("Bob", 45, "CA"), ("Charlie", 29, "TX")]
df = spark.createDataFrame(data, ["name", "age", "state"])

df_upper = convert_column_names(df, "upper")
df_upper.show()

df_lower = convert_column_names(df, "lower")
df_lower.show()

df_title = convert_column_names(df, "title")
df_title.show()

# Clean up
spark.stop()

                                                                                

+-------+---+-----+
|   NAME|AGE|STATE|
+-------+---+-----+
|  Alice| 34|   NY|
|    Bob| 45|   CA|
|Charlie| 29|   TX|
+-------+---+-----+

+-------+---+-----+
|   name|age|state|
+-------+---+-----+
|  Alice| 34|   NY|
|    Bob| 45|   CA|
|Charlie| 29|   TX|
+-------+---+-----+

+-------+---+-----+
|   Name|Age|State|
+-------+---+-----+
|  Alice| 34|   NY|
|    Bob| 45|   CA|
|Charlie| 29|   TX|
+-------+---+-----+



# Complex Filtering

In [28]:
from pyspark.sql import DataFrame, SparkSession, functions as F
from pyspark.sql.column import Column
from typing import List, Tuple, Union, Callable

# Initialize Spark session
spark = SparkSession.builder.appName("example").getOrCreate()


def add_mapped_column(
    df: DataFrame,
    target_column: str,
    mapping_list: List[Tuple[Column, Union[str, Callable[[DataFrame], Column]]]],
    default_value: str = "Unknown",
) -> DataFrame:
    """
    Add a new column to the DataFrame based on mappings defined in a list of tuples.

    Args:
        df (DataFrame): Input DataFrame.
        target_column (str): Name of the new column to create.
        mapping_list (List[Tuple[Column, Union[str, Callable[[DataFrame], Column]]]]): List with conditions as elements and mapped values or functions as values.
        default_value (str): Default value for rows not matching any condition.

    Returns:
        DataFrame: DataFrame with the new mapped column.

    Example:
        >>> from pyspark.sql import SparkSession, functions as F
        >>> spark = SparkSession.builder.appName("example").getOrCreate()
        >>> data = [
        >>>     (15, "John"),
        >>>     (25, "Jane"),
        >>>     (35, "Doe"),
        >>>     (55, "Alice"),
        >>>     (75, "Bob"),
        >>>     (None, "Eve"),
        >>>     (30, None),
        >>> ]
        >>> columns = ["Age", "Name"]
        >>> df = spark.createDataFrame(data, columns)
        >>> age_group_mapping = [
        >>>     (F.col("Age") < 18, "Minor"),
        >>>     (F.col("Age").isNull(), "No Age"),
        >>>     ((F.col("Age") >= 30) & (F.col("Age") < 50), "Adult"),
        >>>     (F.col("Name").startswith("A"), "Name with A"),
        >>>     (F.col("Name").rlike("^[Jj]"), "Starts with J or j"),
        >>>     (F.col("Name").contains("a"), "Contains 'a'"),
        >>>     ((F.col("Age") == 25) | (F.col("Age") == 35), "Specific Age 25 or 35"),
        >>>     (~F.col("Age").between(18, 65), "Not between 18 and 65"),
        >>>     (F.col("Name").isNotNull(), "Has Name"),
        >>>     (F.col("Age").isin(25, 35, 55), "Age is 25, 35, or 55"),
        >>>     (F.col("Name").like("%o%"), "Name contains 'o'"),
        >>>     (F.col("Age") > 10, lambda df: F.concat_ws(" - ", df["Age"], df["Name"])),
        >>> ]
        >>> df_with_age_group = add_mapped_column(df, "Group", age_group_mapping)
        >>> df_with_age_group.show()
    """
    # Start with the default value
    when_expr: Column = F.lit(default_value)

    # Iterate over the mapping list and build the condition expression
    for condition, value in mapping_list:
        if callable(value):
            when_expr = F.when(condition, value(df)).otherwise(when_expr)
        else:
            when_expr = F.when(condition, value).otherwise(when_expr)

    # Add the new column to the DataFrame
    df_with_new_column = df.withColumn(target_column, when_expr)

    return df_with_new_column


data = [
    (15, "John"),
    (25, "Jane"),
    (35, "Doe"),
    (55, "Alice"),
    (75, "Bob"),
    (None, "Eve"),
    (30, None),
]
columns = ["Age", "Name"]
df = spark.createDataFrame(data, columns)

age_group_mapping = [
    (F.col("Age") < 18, "Minor"),
    (F.col("Age").isNull(), "No Age"),
    ((F.col("Age") >= 30) & (F.col("Age") < 50), "Adult"),
    (F.col("Name").startswith("A"), "Name with A"),
    (F.col("Name").rlike("^[Jj]"), "Starts with J or j"),
    (F.col("Name").contains("a"), "Contains 'a'"),
    ((F.col("Age") == 25) | (F.col("Age") == 35), "Specific Age 25 or 35"),
    (~F.col("Age").between(18, 65), "Not between 18 and 65"),
    (F.col("Name").isNotNull(), "Has Name"),
    (F.col("Age").isin(25, 35, 55), "Age is 25, 35, or 55"),
    (F.col("Name").like("%o%"), "Name contains 'o'"),
    (F.col("Age") > 30, F.concat_ws(" - ", df["Age"], df["Name"])),
]

df_with_age_group = add_mapped_column(df, "Group", age_group_mapping)
df_with_age_group.show()

+----+-----+--------------------+
| Age| Name|               Group|
+----+-----+--------------------+
|  15| John|   Name contains 'o'|
|  25| Jane|Age is 25, 35, or 55|
|  35|  Doe|            35 - Doe|
|  55|Alice|          55 - Alice|
|  75|  Bob|            75 - Bob|
|null|  Eve|            Has Name|
|  30| null|               Adult|
+----+-----+--------------------+



In [39]:
from pyspark.sql import DataFrame, functions as F
from typing import List


def assert_unique_combination(df: DataFrame, columns: List[str]) -> None:
    """
    Assert that specified columns or their combination have unique values in the DataFrame.

    This function checks for uniqueness in the specified columns. If a single column
    is provided, it checks that column for uniqueness. If multiple columns are provided,
    it checks the combination of these columns for uniqueness.

    Args:
        df (DataFrame): Input DataFrame to check.
        columns (List[str]): List of column names to check for uniqueness.

    Raises:
        ValueError: If duplicate values are found in the specified column(s).

    Example:
        >>> df = spark.createDataFrame([(1, 'A'), (2, 'B'), (1, 'C')], ['id', 'value'])
        >>> assert_unique_combination(df, ['id'])  # Raises ValueError
        >>> assert_unique_combination(df, ['id', 'value'])  # No error
    """
    if not columns:
        raise ValueError("At least one column must be specified.")

    if len(columns) == 1:
        check_column = columns[0]
    else:
        check_column = "unique_combo_" + "_".join(columns)
        concat_expr = F.concat_ws("||", *columns)
        df = df.withColumn(check_column, concat_expr)

    duplicates = df.groupBy(check_column).count().filter(F.col("count") > 1)

    if duplicates.count() > 0:
        dup_values = duplicates.select(check_column).limit(3).collect()
        dup_list = [row[check_column] for row in dup_values]

        if len(columns) == 1:
            error_message = f"🚨 Column {columns[0]} has duplicate values."
        else:
            error_message = f"🚨 Combination of {', '.join(columns)} is not unique."

        error_message += f"\n💡 Examples: {', '.join(map(str, dup_list))}"
        raise ValueError(error_message)

    if len(columns) > 1:
        df = df.drop(check_column)

    print(f"✅ Uniqueness check passed for: {', '.join(columns)}")


# Initialize Spark session
spark = SparkSession.builder.appName("UniqueComboCheck").getOrCreate()

# Create a sample DataFrame
data = [
    (1, "A", "X"),
    (2, "B", "Y"),
    (3, "C", "Z"),
    (1, "D", "W"),  # Duplicate in first column
    (4, "B", "V"),  # Duplicate in second column
]
df = spark.createDataFrame(data, ["id", "value", "category"])

# Check single column
try:
    assert_unique_combination(df, ["id"])
except ValueError as e:
    print(e)

# Check multiple columns
try:
    assert_unique_combination(df, ["id", "value"])
except ValueError as e:
    print(e)

# Check columns with no duplicates
assert_unique_combination(df, ["id", "value", "category"])

# Stop Spark session
spark.stop()

                                                                                

🚨 Column id has duplicate values.
💡 Examples: 1
✅ Uniqueness check passed for: id, value
✅ Uniqueness check passed for: id, value, category


Key points:

It returns a DataFrame representing the table.
The table must exist in the Hive metastore or Spark catalog.
It's database-agnostic, meaning it can read from any supported database type that's configured in your Spark environment.

df = spark.table("database_name.table_name")

try:
    df = spark.table("database_name.table_name")
except Exception as e:
    print(f"Error reading table: {e}")


df = spark.table("database_name.table_name").select("column1", "column2")

sales_db = spark.conf.get("my_app.sales_database")
customer_data = spark.table(f"{sales_db}.customer_info")

df = spark.table("database_name.table_name").filter(F.col("date") > "2023-01-01")

In [41]:
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from typing import List, Dict, Any

# Initialize Spark Session
spark = SparkSession.builder.appName("ModularGroupBy").getOrCreate()

# Sample data
data = [
    ("A", 1, 100),
    ("A", 2, 150),
    ("B", 1, 200),
    ("B", 2, 250),
    ("C", 1, 300),
    ("C", 2, 350),
]
df = spark.createDataFrame(data, ["category", "subcategory", "value"])


def apply_groupby_aggregations(
    df: DataFrame, group_cols: List[str], agg_expressions: Dict[str, Any]
) -> DataFrame:
    """
    Apply groupBy and aggregations on a DataFrame.

    Args:
    df (DataFrame): Input DataFrame
    group_cols (List[str]): Columns to group by
    agg_expressions (Dict[str, Any]): Aggregation expressions

    Returns:
    DataFrame: Grouped and aggregated DataFrame
    """
    return df.groupBy(*group_cols).agg(
        *[F.expr(expr).alias(column) for column, expr in agg_expressions.items()]
    )


def add_window_calculations(
    df: DataFrame,
    partition_cols: List[str],
    order_cols: List[str],
    window_expressions: Dict[str, Any],
) -> DataFrame:
    """
    Add window function calculations to a DataFrame.

    Args:
    df (DataFrame): Input DataFrame
    partition_cols (List[str]): Columns to partition by
    order_cols (List[str]): Columns to order by
    window_expressions (Dict[str, Any]): Window function expressions

    Returns:
    DataFrame: DataFrame with additional window calculations
    """
    window_spec = Window.partitionBy(*partition_cols).orderBy(*order_cols)
    for column, expr in window_expressions.items():
        df = df.withColumn(column, expr.over(window_spec))
    return df


# Usage
def main():
    # GroupBy aggregations
    agg_result = apply_groupby_aggregations(
        df,
        group_cols=["category"],
        agg_expressions={
            "total_value": "sum(value)",
            "avg_value": "avg(value)",
            "count": "count(*)",
        },
    )

    # Window calculations
    window_result = add_window_calculations(
        agg_result,
        partition_cols=["category"],
        order_cols=["total_value"],
        window_expressions={"rank": F.rank(), "running_total": F.sum("total_value")},
    )

    window_result.show()


if __name__ == "__main__":
    main()

# Clean up
spark.stop()

                                                                                

+--------+-----------+---------+-----+----+-------------+
|category|total_value|avg_value|count|rank|running_total|
+--------+-----------+---------+-----+----+-------------+
|       A|        250|    125.0|    2|   1|          250|
|       B|        450|    225.0|    2|   1|          450|
|       C|        650|    325.0|    2|   1|          650|
+--------+-----------+---------+-----+----+-------------+



In [42]:
from pyspark.sql import DataFrame
from pyspark.sql import functions as F
from pyspark.sql.types import StringType


def trim_all_columns(df: DataFrame) -> DataFrame:
    """
    Trim leading and trailing whitespace from all string columns in a DataFrame.

    Args:
    df (DataFrame): Input DataFrame

    Returns:
    DataFrame: DataFrame with all string columns trimmed
    """
    # Get the schema of the DataFrame
    schema = df.schema

    # Create a list of column expressions
    trim_exprs = [
        (
            F.trim(F.col(field.name)).alias(field.name)
            if isinstance(field.dataType, StringType)
            else F.col(field.name)
        )
        for field in schema.fields
    ]

    # Apply the trim expressions to the DataFrame
    trimmed_df = df.select(*trim_exprs)

    return trimmed_df


# Example usage
from pyspark.sql import SparkSession

# Initialize Spark Session
spark = SparkSession.builder.appName("TrimColumns").getOrCreate()

# Create a sample DataFrame
data = [("  John  ", "  Doe  ", 30), ("Jane ", " Smith", 25), (" Bob", "Johnson ", 35)]
df = spark.createDataFrame(data, ["first_name", "last_name", "age"])

# Apply the trim function
trimmed_df = trim_all_columns(df)

# Show the results
print("Original DataFrame:")
df.show(truncate=False)
print("\nTrimmed DataFrame:")
trimmed_df.show(truncate=False)

# Clean up
spark.stop()

Original DataFrame:


                                                                                

+----------+---------+---+
|first_name|last_name|age|
+----------+---------+---+
|  John    |  Doe    |30 |
|Jane      | Smith   |25 |
| Bob      |Johnson  |35 |
+----------+---------+---+


Trimmed DataFrame:
+----------+---------+---+
|first_name|last_name|age|
+----------+---------+---+
|John      |Doe      |30 |
|Jane      |Smith    |25 |
|Bob       |Johnson  |35 |
+----------+---------+---+



In [None]:
# List all tables in a database
spark.catalog.listTables("default").show()

# Get details of a specific table
spark.catalog.getTable("default", "my_table").show()

catalog = spark.catalog

databases = catalog.listDatabases()
databases.show()

tables = catalog.listTables("default")  # "default" is the database name
tables.show()

columns = catalog.listColumns("my_table", "default")
columns.show()

functions = catalog.listFunctions()
functions.show()

table_details = catalog.getTable("default", "my_table")
print(table_details)

current_db = catalog.currentDatabase()
print(current_db)

table = catalog.getTable("default", "my_table")
properties = table.properties
print(properties)

In [None]:
from pyspark.sql import DataFrame
from pyspark.ml.feature import OneHotEncoder, StringIndexer
from pyspark.ml import Pipeline


def one_hot_encode(df: DataFrame, categorical_columns: list) -> DataFrame:
    """
    Perform one-hot encoding on specified categorical columns.

    Args:
    df (DataFrame): Input DataFrame
    categorical_columns (list): List of categorical column names to encode

    Returns:
    DataFrame: DataFrame with one-hot encoded columns
    """
    stages = []
    for cat_col in categorical_columns:
        string_indexer = StringIndexer(inputCol=cat_col, outputCol=f"{cat_col}_index")
        encoder = OneHotEncoder(
            inputCols=[f"{cat_col}_index"], outputCols=[f"{cat_col}_encoded"]
        )
        stages += [string_indexer, encoder]

    pipeline = Pipeline(stages=stages)
    model = pipeline.fit(df)
    encoded_df = model.transform(df)

    return encoded_df


# Usage
encoded_df = one_hot_encode(df, ["category", "color"])

In [None]:
from pyspark.sql import DataFrame


def remove_duplicates(df: DataFrame, subset: list = None) -> DataFrame:
    """
    Remove duplicate rows from the DataFrame.

    Args:
    df (DataFrame): Input DataFrame
    subset (list): List of columns to consider for duplicates (default: None, uses all columns)

    Returns:
    DataFrame: DataFrame with duplicates removed
    """
    if subset:
        return df.dropDuplicates(subset=subset)
    else:
        return df.dropDuplicates()


# Usage
deduped_df = remove_duplicates(df, ["id", "timestamp"])

In [None]:
from pyspark.sql import DataFrame


def apply_sql_query(
    df: DataFrame, query: str, view_name: str = "temp_view"
) -> DataFrame:
    """
    Apply a custom SQL query to the DataFrame.

    Args:
    df (DataFrame): Input DataFrame
    query (str): SQL query to apply
    view_name (str): Name for the temporary view (default: "temp_view")

    Returns:
    DataFrame: Result of the SQL query
    """
    df.createOrReplaceTempView(view_name)
    result = df.sparkSession.sql(query)
    df.sparkSession.catalog.dropTempView(view_name)
    return result


# Usage
query = """
    SELECT category, AVG(price) as avg_price
    FROM temp_view
    GROUP BY category
    HAVING AVG(price) > 100
"""
result_df = apply_sql_query(df, query)

In [None]:
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from typing import List, Dict, Any, Union
from pyspark.sql.types import StructType

def summarize_numeric_columns(df: DataFrame) -> DataFrame:
    """
    Compute summary statistics for all numeric columns in a DataFrame.

    Args:
    df (DataFrame): Input DataFrame

    Returns:
    DataFrame: Summary statistics including count, mean, stddev, min, and max for each numeric column
    """
    numeric_columns = [f.name for f in df.schema.fields if f.dataType.simpleString() in ('double', 'float', 'int', 'long')]
    
    summary = df.select([
        F.count(F.col(c)).alias(f"{c}_count"),
        F.mean(F.col(c)).alias(f"{c}_mean"),
        F.stddev(F.col(c)).alias(f"{c}_stddev"),
        F.min(F.col(c)).alias(f"{c}_min"),
        F.max(F.col(c)).alias(f"{c}_max")
    for c in numeric_columns])
    
    return summary

def identify_outliers(df: DataFrame, column: str, lower_quantile: float = 0.25, upper_quantile: float = 0.75, iqr_multiplier: float = 1.5) -> DataFrame:
    """
    Identify outliers in a specified column using the Interquartile Range (IQR) method.

    Args:
    df (DataFrame): Input DataFrame
    column (str): Name of the column to check for outliers
    lower_quantile (float): Lower quantile for IQR calculation (default: 0.25)
    upper_quantile (float): Upper quantile for IQR calculation (default: 0.75)
    iqr_multiplier (float): Multiplier for IQR to determine outlier boundaries (default: 1.5)

    Returns:
    DataFrame: Original DataFrame with an additional boolean column indicating outliers
    """
    quantiles = df.approxQuantile(column, [lower_quantile, upper_quantile], 0.01)
    q1, q3 = quantiles[0], quantiles[1]
    iqr = q3 - q1
    lower_bound = q1 - iqr_multiplier * iqr
    upper_bound = q3 + iqr_multiplier * iqr
    
    return df.withColumn(f"{column}_is_outlier", 
                         ~F.col(column).between(lower_bound, upper_bound))

def pivot_and_unpivot(df: DataFrame, id_cols: List[str], pivot_col: str, value_col: str) -> Dict[str, DataFrame]:
    """
    Perform both pivot and unpivot operations on a DataFrame.

    Args:
    df (DataFrame): Input DataFrame
    id_cols (List[str]): List of columns to use as identifiers
    pivot_col (str): Column to use for pivoting
    value_col (str): Column containing the values to be pivoted

    Returns:
    Dict[str, DataFrame]: Dictionary containing both pivoted and unpivoted DataFrames
    """
    # Pivot operation
    pivot_df = df.groupBy(id_cols).pivot(pivot_col).agg(F.first(value_col))
    
    # Unpivot operation
    value_columns = [c for c in pivot_df.columns if c not in id_cols]
    unpivot_expr = [F.expr(f"stack({len(value_columns)}, {', '.join([f''''{c}', {c}''' for c in value_columns])}) as ({pivot_col}, {value_col})")]
    unpivot_df = pivot_df.select(*id_cols, *unpivot_expr)
    
    return {"pivoted": pivot_df, "unpivoted": unpivot_df}

def add_date_features(df: DataFrame, date_column: str) -> DataFrame:
    """
    Add various date-related features to a DataFrame based on a specified date column.

    Args:
    df (DataFrame): Input DataFrame
    date_column (str): Name of the column containing date information

    Returns:
    DataFrame: DataFrame with additional date-related columns
    """
    return df.withColumn("year", F.year(F.col(date_column))) \
             .withColumn("month", F.month(F.col(date_column))) \
             .withColumn("day", F.dayofmonth(F.col(date_column))) \
             .withColumn("day_of_week", F.dayofweek(F.col(date_column))) \
             .withColumn("day_of_year", F.dayofyear(F.col(date_column))) \
             .withColumn("week_of_year", F.weekofyear(F.col(date_column))) \
             .withColumn("quarter", F.quarter(F.col(date_column)))

def compare_dataframes(df1: DataFrame, df2: DataFrame, join_columns: List[str]) -> Dict[str, Union[DataFrame, int]]:
    """
    Compare two DataFrames and identify differences.

    Args:
    df1 (DataFrame): First DataFrame
    df2 (DataFrame): Second DataFrame
    join_columns (List[str]): Columns to use for joining the DataFrames

    Returns:
    Dict[str, Union[DataFrame, int]]: Dictionary containing DataFrames with differences and counts
    """
    # Ensure both DataFrames have the same columns
    columns = sorted(set(df1.columns + df2.columns))
    df1 = df1.select(columns)
    df2 = df2.select(columns)
    
    # Perform full outer join
    joined = df1.join(df2, join_columns, "full_outer")
    
    # Identify rows present in df1 but not in df2
    in_df1_not_df2 = joined.filter(' AND '.join([f"(df2.`{col}` IS NULL OR df1.`{col}` != df2.`{col}`)" for col in columns if col not in join_columns]))
    
    # Identify rows present in df2 but not in df1
    in_df2_not_df1 = joined.filter(' AND '.join([f"(df1.`{col}` IS NULL OR df1.`{col}` != df2.`{col}`)" for col in columns if col not in join_columns]))
    
    return {
        "in_df1_not_df2": in_df1_not_df1,
        "in_df2_not_df1": in_df2_not_df1,
        "total_differences": in_df1_not_df2.count() + in_df2_not_df1.count()
    }

def validate_schema(df: DataFrame, expected_schema: StructType) -> List[str]:
    """
    Validate the schema of a DataFrame against an expected schema.

    Args:
    df (DataFrame): DataFrame to validate
    expected_schema (StructType): Expected schema

    Returns:
    List[str]: List of discrepancies between actual and expected schema
    """
    actual_fields = df.schema.fields
    expected_fields = expected_schema.fields
    discrepancies = []

    for actual, expected in zip(actual_fields, expected_fields):
        if actual.name != expected.name:
            discrepancies.append(f"Column name mismatch: {actual.name} != {expected.name}")
        elif actual.dataType != expected.dataType:
            discrepancies.append(f"Data type mismatch for column {actual.name}: {actual.dataType} != {expected.dataType}")
        elif actual.nullable != expected.nullable:
            discrepancies.append(f"Nullability mismatch for column {actual.name}: {actual.nullable} != {expected.nullable}")

    if len(actual_fields) != len(expected_fields):
        discrepancies.append(f"Number of columns mismatch: {len(actual_fields)} != {len(expected_fields)}")

    return discrepancies

# Usage examples:
spark = SparkSession.builder.appName("DataAnalysisFunctions").getOrCreate()

# Create a sample DataFrame
data = [
    (1, "2023-06-01", 100, "A"),
    (2, "2023-06-02", 200, "B"),
    (3, "2023-06-03", 300, "A"),
    (4, "2023-06-04", 400, "B"),
    (5, "2023-06-05", 500, "C")
]
df = spark.createDataFrame(data, ["id", "date", "value", "category"])

# Example usage of summarize_numeric_columns
summary = summarize_numeric_columns(df)
summary.show()

# Example usage of identify_outliers
df_with_outliers = identify_outliers(df, "value")
df_with_outliers.show()

# Example usage of pivot_and_unpivot
pivot_results = pivot_and_unpivot(df, ["id"], "category", "value")
pivot_results["pivoted"].show()
pivot_results["unpivoted"].show()

# Example usage of add_date_features
df_with_date_features = add_date_features(df, "date")
df_with_date_features.show()

# Clean up
spark.stop()

In [None]:
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from typing import List, Dict, Any, Union
from pyspark.ml.feature import Imputer
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

def profile_data_quality(df: DataFrame) -> DataFrame:
    """
    Profile the data quality of a DataFrame, including null counts, distinct counts, and data types.

    Args:
    df (DataFrame): Input DataFrame

    Returns:
    DataFrame: Data quality profile including column name, data type, null count, distinct count, and sample values
    """
    def sample_values(col_name):
        return F.concat_ws(", ", F.collect_set(F.col(col_name)).over(Window.partitionBy(F.lit(1)).rowsBetween(-1000, 1000)))

    profile = df.select([
        F.lit(c).alias("column_name"),
        F.lit(str(df.schema[c].dataType)).alias("data_type"),
        F.count(F.when(F.col(c).isNull(), c)).alias("null_count"),
        F.count(F.when(F.col(c).isNotNull(), c)).alias("non_null_count"),
        F.countDistinct(c).alias("distinct_count"),
        sample_values(c).alias("sample_values")
    for c in df.columns])

    return profile.select("column_name", "data_type", "null_count", "non_null_count", "distinct_count", 
                          F.expr("substring(sample_values, 1, 100)").alias("sample_values"))

def handle_missing_values(df: DataFrame, strategy: str = "mean", columns: List[str] = None) -> DataFrame:
    """
    Handle missing values in specified columns using various imputation strategies.

    Args:
    df (DataFrame): Input DataFrame
    strategy (str): Imputation strategy ('mean', 'median', 'mode', or 'constant')
    columns (List[str]): List of columns to impute (default: None, which means all numeric columns)

    Returns:
    DataFrame: DataFrame with imputed values
    """
    if columns is None:
        columns = [c for c, dtype in df.dtypes if dtype in ('int', 'double', 'float')]

    imputer = Imputer(inputCols=columns, outputCols=columns)
    imputer.setStrategy(strategy)

    if strategy == "constant":
        imputer.setMissingValue(0)  # You can change this to any constant value

    imputed_df = imputer.fit(df).transform(df)
    return imputed_df

def calculate_correlation_matrix(df: DataFrame, method: str = "pearson") -> DataFrame:
    """
    Calculate the correlation matrix for numeric columns in a DataFrame.

    Args:
    df (DataFrame): Input DataFrame
    method (str): Correlation method ('pearson' or 'spearman')

    Returns:
    DataFrame: Correlation matrix
    """
    numeric_columns = [c for c, dtype in df.dtypes if dtype in ('int', 'double', 'float')]
    
    if method == "pearson":
        correlation = df.select(numeric_columns).corr()
    elif method == "spearman":
        correlation = df.select(numeric_columns).corr(method="spearman")
    else:
        raise ValueError("Method must be either 'pearson' or 'spearman'")
    
    return correlation

def binning(df: DataFrame, column: str, num_bins: int, strategy: str = "equal_range") -> DataFrame:
    """
    Perform binning on a numeric column.

    Args:
    df (DataFrame): Input DataFrame
    column (str): Column to bin
    num_bins (int): Number of bins
    strategy (str): Binning strategy ('equal_range' or 'equal_frequency')

    Returns:
    DataFrame: DataFrame with binned column added
    """
    if strategy == "equal_range":
        return df.withColumn(f"{column}_binned", F.ntile(num_bins).over(Window.orderBy(column)))
    elif strategy == "equal_frequency":
        quantiles = df.approxQuantile(column, [i/num_bins for i in range(1, num_bins)], 0.01)
        return df.withColumn(f"{column}_binned", F.bucketizer(F.col(column), [-float("inf")] + quantiles + [float("inf")]))
    else:
        raise ValueError("Strategy must be either 'equal_range' or 'equal_frequency'")

def detect_data_drift(df1: DataFrame, df2: DataFrame, columns: List[str]) -> DataFrame:
    """
    Detect data drift between two DataFrames for specified columns.

    Args:
    df1 (DataFrame): First DataFrame (e.g., training data)
    df2 (DataFrame): Second DataFrame (e.g., new data)
    columns (List[str]): List of columns to check for drift

    Returns:
    DataFrame: Data drift statistics including column name, drift metric, and p-value
    """
    from pyspark.sql.functions import kurtosis, skewness, mean, stddev
    from scipy import stats

    drift_results = []

    for column in columns:
        stats1 = df1.select(mean(column).alias("mean"), 
                            stddev(column).alias("std"),
                            kurtosis(column).alias("kurtosis"),
                            skewness(column).alias("skewness")).collect()[0]
        
        stats2 = df2.select(mean(column).alias("mean"), 
                            stddev(column).alias("std"),
                            kurtosis(column).alias("kurtosis"),
                            skewness(column).alias("skewness")).collect()[0]

        # Perform Kolmogorov-Smirnov test
        ks_statistic, p_value = stats.ks_2samp(df1.select(column).rdd.flatMap(lambda x: x).collect(),
                                               df2.select(column).rdd.flatMap(lambda x: x).collect())

        drift_results.append((column, ks_statistic, p_value, 
                              stats1.mean, stats2.mean, 
                              stats1.std, stats2.std,
                              stats1.kurtosis, stats2.kurtosis,
                              stats1.skewness, stats2.skewness))

    drift_df = spark.createDataFrame(drift_results, 
                                     ["column", "ks_statistic", "p_value", 
                                      "mean1", "mean2", "std1", "std2", 
                                      "kurtosis1", "kurtosis2", "skewness1", "skewness2"])
    
    return drift_df

def generate_summary_report(df: DataFrame) -> str:
    """
    Generate a summary report of the DataFrame in Markdown format.

    Args:
    df (DataFrame): Input DataFrame

    Returns:
    str: Markdown-formatted summary report
    """
    num_rows = df.count()
    num_columns = len(df.columns)
    
    numeric_columns = [c for c, dtype in df.dtypes if dtype in ('int', 'double', 'float')]
    categorical_columns = [c for c, dtype in df.dtypes if dtype not in ('int', 'double', 'float')]
    
    summary_stats = df.summary()
    
    report = f"# DataFrame Summary Report\n\n"
    report += f"## Basic Information\n"
    report += f"- Number of rows: {num_rows}\n"
    report += f"- Number of columns: {num_columns}\n"
    report += f"- Numeric columns: {', '.join(numeric_columns)}\n"
    report += f"- Categorical columns: {', '.join(categorical_columns)}\n\n"
    
    report += f"## Numeric Column Statistics\n"
    for col in numeric_columns:
        stats = summary_stats.filter(F.col("summary").isin("min", "max", "mean", "stddev")).select(col).collect()
        report += f"### {col}\n"
        report += f"- Min: {stats[0][0]}\n"
        report += f"- Max: {stats[1][0]}\n"
        report += f"- Mean: {stats[2][0]}\n"
        report += f"- StdDev: {stats[3][0]}\n\n"
    
    report += f"## Categorical Column Statistics\n"
    for col in categorical_columns:
        top_values = df.groupBy(col).count().orderBy(F.desc("count")).limit(5)
        report += f"### {col}\n"
        report += f"Top 5 values:\n"
        for row in top_values.collect():
            report += f"- {row[0]}: {row[1]}\n"
        report += "\n"
    
    return report

# Example usage
spark = SparkSession.builder.appName("AdvancedDataAnalysisFunctions").getOrCreate()

# Create a sample DataFrame
data = [
    (1, "2023-06-01", 100, "A", 0.5),
    (2, "2023-06-02", 200, "B", 1.5),
    (3, "2023-06-03", None, "A", 2.5),
    (4, "2023-06-04", 400, "B", 3.5),
    (5, "2023-06-05", 500, "C", None)
]
df = spark.createDataFrame(data, ["id", "date", "value", "category", "score"])

# Profile data quality
quality_profile = profile_data_quality(df)
quality_profile.show(truncate=False)

# Handle missing values
df_imputed = handle_missing_values(df, strategy="mean")
df_imputed.show()

# Calculate correlation matrix
corr_matrix = calculate_correlation_matrix(df_imputed)
corr_matrix.show()

# Perform binning
df_binned = binning(df_imputed, "value", 3, strategy="equal_range")
df_binned.show()

# Detect data drift (using the same DataFrame for demonstration)
drift_stats = detect_data_drift(df_imputed, df_imputed, ["value", "score"])
drift_stats.show()

# Generate summary report
report = generate_summary_report(df_imputed)
print(report)

# Clean up
spark.stop()

# Process Mining

In [None]:
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import (
    StructType,
    StructField,
    StringType,
    TimestampType,
    IntegerType,
)
from typing import List, Dict, Any, Union


def create_event_log_schema() -> StructType:
    """
    Create a standard schema for event logs in process mining.

    Returns:
    StructType: Schema for event logs
    """
    return StructType(
        [
            StructField("case_id", StringType(), False),
            StructField("activity", StringType(), False),
            StructField("timestamp", TimestampType(), False),
            StructField("resource", StringType(), True),
            StructField("cost", IntegerType(), True),
            # Add more fields as needed for your specific process mining project
        ]
    )


def apply_schema_to_df(df: DataFrame, schema: StructType) -> DataFrame:
    """
    Apply a given schema to a DataFrame, casting columns to the specified types.

    Args:
    df (DataFrame): Input DataFrame
    schema (StructType): Schema to apply

    Returns:
    DataFrame: DataFrame with the applied schema
    """
    for field in schema.fields:
        df = df.withColumn(field.name, F.col(field.name).cast(field.dataType))
    return df


def calculate_case_duration(
    df: DataFrame, case_id_col: str, timestamp_col: str
) -> DataFrame:
    """
    Calculate the duration of each case in the event log.

    Args:
    df (DataFrame): Event log DataFrame
    case_id_col (str): Name of the case ID column
    timestamp_col (str): Name of the timestamp column

    Returns:
    DataFrame: DataFrame with case duration added
    """
    window_spec = Window.partitionBy(case_id_col).orderBy(timestamp_col)

    df_with_duration = (
        df.withColumn("case_start", F.first(timestamp_col).over(window_spec))
        .withColumn("case_end", F.last(timestamp_col).over(window_spec))
        .withColumn(
            "case_duration_seconds",
            F.unix_timestamp("case_end") - F.unix_timestamp("case_start"),
        )
    )

    return df_with_duration.select(case_id_col, "case_duration_seconds").distinct()


def identify_happy_path(
    df: DataFrame, case_id_col: str, activity_col: str
) -> List[str]:
    """
    Identify the most common sequence of activities (happy path) in the event log.

    Args:
    df (DataFrame): Event log DataFrame
    case_id_col (str): Name of the case ID column
    activity_col (str): Name of the activity column

    Returns:
    List[str]: List of activities in the happy path
    """
    activity_sequences = df.groupBy(case_id_col).agg(
        F.collect_list(activity_col).alias("activities")
    )

    most_common_sequence = (
        activity_sequences.groupBy("activities")
        .count()
        .orderBy(F.desc("count"))
        .first()["activities"]
    )

    return most_common_sequence


def calculate_activity_frequency(df: DataFrame, activity_col: str) -> DataFrame:
    """
    Calculate the frequency of each activity in the event log.

    Args:
    df (DataFrame): Event log DataFrame
    activity_col (str): Name of the activity column

    Returns:
    DataFrame: DataFrame with activity frequencies
    """
    return (
        df.groupBy(activity_col)
        .count()
        .withColumnRenamed("count", "frequency")
        .orderBy(F.desc("frequency"))
    )


def detect_bottlenecks(
    df: DataFrame, case_id_col: str, activity_col: str, timestamp_col: str
) -> DataFrame:
    """
    Detect potential bottlenecks by calculating the average duration between activities.

    Args:
    df (DataFrame): Event log DataFrame
    case_id_col (str): Name of the case ID column
    activity_col (str): Name of the activity column
    timestamp_col (str): Name of the timestamp column

    Returns:
    DataFrame: DataFrame with average durations between activities
    """
    window_spec = Window.partitionBy(case_id_col).orderBy(timestamp_col)

    df_with_next = df.withColumn(
        "next_activity", F.lead(activity_col).over(window_spec)
    ).withColumn("next_timestamp", F.lead(timestamp_col).over(window_spec))

    df_durations = df_with_next.withColumn(
        "duration_seconds",
        F.unix_timestamp("next_timestamp") - F.unix_timestamp(timestamp_col),
    )

    return (
        df_durations.groupBy(activity_col, "next_activity")
        .agg(F.avg("duration_seconds").alias("avg_duration_seconds"))
        .orderBy(F.desc("avg_duration_seconds"))
    )


def create_process_variants(
    df: DataFrame, case_id_col: str, activity_col: str
) -> DataFrame:
    """
    Create process variants by grouping similar sequences of activities.

    Args:
    df (DataFrame): Event log DataFrame
    case_id_col (str): Name of the case ID column
    activity_col (str): Name of the activity column

    Returns:
    DataFrame: DataFrame with process variants
    """
    variants = df.groupBy(case_id_col).agg(
        F.concat_ws("->", F.collect_list(activity_col)).alias("variant")
    )

    return (
        variants.groupBy("variant")
        .count()
        .withColumnRenamed("count", "variant_frequency")
        .orderBy(F.desc("variant_frequency"))
    )


# Example usage
spark = SparkSession.builder.appName("ProcessMiningFunctions").getOrCreate()

# Create sample data
data = [
    ("case1", "Start", "2023-06-01 09:00:00", "John", 10),
    ("case1", "Activity A", "2023-06-01 10:00:00", "Alice", 20),
    ("case1", "Activity B", "2023-06-01 11:00:00", "Bob", 30),
    ("case1", "End", "2023-06-01 12:00:00", "John", 10),
    ("case2", "Start", "2023-06-01 09:30:00", "Alice", 10),
    ("case2", "Activity A", "2023-06-01 10:30:00", "Bob", 20),
    ("case2", "Activity C", "2023-06-01 11:30:00", "John", 40),
    ("case2", "End", "2023-06-01 12:30:00", "Alice", 10),
]

# Create DataFrame and apply schema
schema = create_event_log_schema()
df = spark.createDataFrame(
    data, ["case_id", "activity", "timestamp", "resource", "cost"]
)
df = apply_schema_to_df(df, schema)

# Calculate case duration
case_durations = calculate_case_duration(df, "case_id", "timestamp")
case_durations.show()

# Identify happy path
happy_path = identify_happy_path(df, "case_id", "activity")
print("Happy Path:", happy_path)

# Calculate activity frequency
activity_freq = calculate_activity_frequency(df, "activity")
activity_freq.show()

# Detect bottlenecks
bottlenecks = detect_bottlenecks(df, "case_id", "activity", "timestamp")
bottlenecks.show()

# Create process variants
variants = create_process_variants(df, "case_id", "activity")
variants.show(truncate=False)

# Clean up
spark.stop()

In [None]:
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import (
    StructType,
    StructField,
    StringType,
    TimestampType,
    IntegerType,
)
from typing import List, Dict, Any, Union


def remove_incomplete_cases(
    df: DataFrame,
    case_id_col: str,
    activity_col: str,
    start_activity: str,
    end_activity: str,
) -> DataFrame:
    """
    Remove cases that don't start with the specified start activity or don't end with the specified end activity.

    Args:
    df (DataFrame): Event log DataFrame
    case_id_col (str): Name of the case ID column
    activity_col (str): Name of the activity column
    start_activity (str): Expected start activity
    end_activity (str): Expected end activity

    Returns:
    DataFrame: DataFrame with incomplete cases removed
    """
    window_spec = Window.partitionBy(case_id_col).orderBy(F.col("timestamp"))

    df_with_start_end = df.withColumn(
        "first_activity", F.first(activity_col).over(window_spec)
    ).withColumn("last_activity", F.last(activity_col).over(window_spec))

    return df_with_start_end.filter(
        (F.col("first_activity") == start_activity)
        & (F.col("last_activity") == end_activity)
    ).drop("first_activity", "last_activity")


def remove_outlier_cases(
    df: DataFrame,
    case_id_col: str,
    timestamp_col: str,
    lower_quantile: float = 0.05,
    upper_quantile: float = 0.95,
) -> DataFrame:
    """
    Remove outlier cases based on case duration.

    Args:
    df (DataFrame): Event log DataFrame
    case_id_col (str): Name of the case ID column
    timestamp_col (str): Name of the timestamp column
    lower_quantile (float): Lower quantile for outlier removal (default: 0.05)
    upper_quantile (float): Upper quantile for outlier removal (default: 0.95)

    Returns:
    DataFrame: DataFrame with outlier cases removed
    """
    case_durations = calculate_case_duration(df, case_id_col, timestamp_col)

    quantiles = case_durations.approxQuantile(
        "case_duration_seconds", [lower_quantile, upper_quantile], 0.01
    )
    lower_bound, upper_bound = quantiles[0], quantiles[1]

    valid_cases = case_durations.filter(
        (F.col("case_duration_seconds") >= lower_bound)
        & (F.col("case_duration_seconds") <= upper_bound)
    ).select(case_id_col)

    return df.join(valid_cases, on=case_id_col, how="inner")


def standardize_activity_names(
    df: DataFrame, activity_col: str, mapping: Dict[str, str]
) -> DataFrame:
    """
    Standardize activity names based on a provided mapping.

    Args:
    df (DataFrame): Event log DataFrame
    activity_col (str): Name of the activity column
    mapping (Dict[str, str]): Dictionary mapping original activity names to standardized names

    Returns:
    DataFrame: DataFrame with standardized activity names
    """
    return df.replace(mapping, subset=[activity_col])


def add_time_attributes(df: DataFrame, timestamp_col: str) -> DataFrame:
    """
    Add time-related attributes to the event log.

    Args:
    df (DataFrame): Event log DataFrame
    timestamp_col (str): Name of the timestamp column

    Returns:
    DataFrame: DataFrame with additional time attributes
    """
    return (
        df.withColumn("hour", F.hour(F.col(timestamp_col)))
        .withColumn("day_of_week", F.dayofweek(F.col(timestamp_col)))
        .withColumn(
            "is_weekend",
            F.when(F.dayofweek(F.col(timestamp_col)).isin([1, 7]), 1).otherwise(0),
        )
        .withColumn("month", F.month(F.col(timestamp_col)))
        .withColumn("quarter", F.quarter(F.col(timestamp_col)))
    )


def add_event_index(df: DataFrame, case_id_col: str, timestamp_col: str) -> DataFrame:
    """
    Add an event index to each event within a case.

    Args:
    df (DataFrame): Event log DataFrame
    case_id_col (str): Name of the case ID column
    timestamp_col (str): Name of the timestamp column

    Returns:
    DataFrame: DataFrame with event index added
    """
    window_spec = Window.partitionBy(case_id_col).orderBy(timestamp_col)
    return df.withColumn("event_index", F.row_number().over(window_spec))


def add_next_activity(
    df: DataFrame, case_id_col: str, activity_col: str, timestamp_col: str
) -> DataFrame:
    """
    Add the next activity for each event within a case.

    Args:
    df (DataFrame): Event log DataFrame
    case_id_col (str): Name of the case ID column
    activity_col (str): Name of the activity column
    timestamp_col (str): Name of the timestamp column

    Returns:
    DataFrame: DataFrame with next activity added
    """
    window_spec = Window.partitionBy(case_id_col).orderBy(timestamp_col)
    return df.withColumn("next_activity", F.lead(activity_col).over(window_spec))


def filter_time_range(
    df: DataFrame, timestamp_col: str, start_date: str, end_date: str
) -> DataFrame:
    """
    Filter events within a specific time range.

    Args:
    df (DataFrame): Event log DataFrame
    timestamp_col (str): Name of the timestamp column
    start_date (str): Start date in format 'YYYY-MM-DD'
    end_date (str): End date in format 'YYYY-MM-DD'

    Returns:
    DataFrame: DataFrame with events filtered to the specified time range
    """
    return df.filter(
        (F.col(timestamp_col) >= start_date) & (F.col(timestamp_col) <= end_date)
    )


# Example usage
spark = SparkSession.builder.appName("ProcessMiningPreprocessing").getOrCreate()

# Create sample data
data = [
    ("case1", "Start Process", "2023-06-01 09:00:00", "John"),
    ("case1", "Activity A", "2023-06-01 10:00:00", "Alice"),
    ("case1", "Activity B", "2023-06-01 11:00:00", "Bob"),
    ("case1", "End Process", "2023-06-01 12:00:00", "John"),
    ("case2", "Start Process", "2023-06-01 09:30:00", "Alice"),
    ("case2", "Activity A", "2023-06-01 10:30:00", "Bob"),
    ("case2", "Activity C", "2023-06-01 11:30:00", "John"),
    ("case3", "Start Process", "2023-06-02 09:00:00", "Bob"),
    ("case3", "Activity B", "2023-06-02 10:00:00", "Alice"),
]

df = spark.createDataFrame(data, ["case_id", "activity", "timestamp", "resource"])

# Remove incomplete cases
df_complete = remove_incomplete_cases(
    df, "case_id", "activity", "Start Process", "End Process"
)
print("Complete cases:")
df_complete.show()

# Standardize activity names
activity_mapping = {"Start Process": "Start", "End Process": "End"}
df_standardized = standardize_activity_names(df, "activity", activity_mapping)
print("Standardized activity names:")
df_standardized.show()

# Add time attributes
df_with_time = add_time_attributes(df, "timestamp")
print("With time attributes:")
df_with_time.show()

# Add event index
df_with_index = add_event_index(df, "case_id", "timestamp")
print("With event index:")
df_with_index.show()

# Add next activity
df_with_next = add_next_activity(df, "case_id", "activity", "timestamp")
print("With next activity:")
df_with_next.show()

# Filter time range
df_filtered = filter_time_range(df, "timestamp", "2023-06-01", "2023-06-01")
print("Filtered by time range:")
df_filtered.show()

# Clean up
spark.stop()