In [None]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import to_timestamp
from typing import Union, List


def convert_to_timestamp(
    df: DataFrame,
    columns: Union[str, List[str]],
    format: str = "yyyy-MM-dd'T'HH:mm:ss.SSSZ",
) -> DataFrame:
    """
    Parse one or more columns of a PySpark DataFrame into timestamps.

    Each requested column is replaced in place (same column name) by the result
    of ``to_timestamp`` applied with the given pattern.

    Args:
        df (DataFrame): Input PySpark DataFrame.
        columns (Union[str, List[str]]): A single column name or a list of names.
        format (str, optional): Pattern handed to ``to_timestamp``.
            Defaults to "yyyy-MM-dd'T'HH:mm:ss.SSSZ".

    Returns:
        DataFrame: The DataFrame with the requested column(s) converted.

    Raises:
        ValueError: If any requested column is not listed in ``df.columns``.

    Examples:
        >>> from pyspark.sql import SparkSession
        >>> spark = SparkSession.builder.getOrCreate()
        >>> data = [("2023-07-27T13:18:12.039+0000",), ("2023-07-28T13:20:30.039+0000",)]
        >>> df = spark.createDataFrame(data, ["ZEITPUNKT"])
        >>> result_df = convert_to_timestamp(df, "ZEITPUNKT")
        >>> result_df.printSchema()
        root
         |-- ZEITPUNKT: timestamp (nullable = true)
    """
    # A bare string means "one column"; wrap it so the code below sees one shape.
    targets = [columns] if isinstance(columns, str) else columns

    # Fail before touching any column if a requested name does not exist.
    unknown = set(targets).difference(df.columns)
    if unknown:
        raise ValueError(
            f"The following columns are not present in the DataFrame: {unknown}"
        )

    # Overwrite each requested column with its parsed value, one at a time.
    converted = df
    for name in targets:
        converted = converted.withColumn(name, to_timestamp(converted[name], format))

    return converted

In [None]:
from pyspark.sql import DataFrame
from typing import Optional


def standardize_process_mining_column_names(
    df: DataFrame,
    case_column: str,
    activity_column: str,
    timestamp_column: str,
    standardized_case_name: str = "_CASE_KEY",
    standardized_activity_name: str = "ACTIVITY",
    standardized_timestamp_name: str = "EVENTTIME",
) -> DataFrame:
    """
    Standardize column names for process mining by renaming case, activity, and timestamp columns.

    All three renames are applied simultaneously, so a source column that already
    carries another role's standardized name (e.g. the activity column is called
    "_CASE_KEY") is renamed correctly instead of colliding half-way through.
    Columns that are not one of the three roles keep their name and position.

    Args:
        df (DataFrame): Input PySpark DataFrame.
        case_column (str): Name of the column containing case IDs.
        activity_column (str): Name of the column containing activities.
        timestamp_column (str): Name of the column containing timestamps.
        standardized_case_name (str, optional): Standardized name for the case column. Defaults to "_CASE_KEY".
        standardized_activity_name (str, optional): Standardized name for the activity column. Defaults to "ACTIVITY".
        standardized_timestamp_name (str, optional): Standardized name for the timestamp column. Defaults to "EVENTTIME".

    Returns:
        DataFrame: DataFrame with standardized column names for process mining.

    Raises:
        ValueError: If any of the specified columns are not present in the input DataFrame.
        ValueError: If the same column is passed for more than one role, or the
            standardized names are not distinct from each other.
        ValueError: If any of the standardized names are already present in the DataFrame but don't match the columns to be renamed.

    Example:
        >>> data = [("A1", "Start", "2023-01-01"), ("A2", "End", "2023-01-02")]
        >>> df = spark.createDataFrame(data, ["ID", "Action", "Date"])
        >>> standardized_df = standardize_process_mining_column_names(df, "ID", "Action", "Date")
        >>> standardized_df.show()
        +---------+--------+----------+
        |_CASE_KEY|ACTIVITY| EVENTTIME|
        +---------+--------+----------+
        |       A1|   Start|2023-01-01|
        |       A2|     End|2023-01-02|
        +---------+--------+----------+
    """
    source_columns = [case_column, activity_column, timestamp_column]
    target_names = [
        standardized_case_name,
        standardized_activity_name,
        standardized_timestamp_name,
    ]

    # Check if all specified columns are present in the DataFrame
    missing_columns = set(source_columns) - set(df.columns)
    if missing_columns:
        raise ValueError(
            f"The following columns are missing from the DataFrame: {missing_columns}"
        )

    # One column cannot play two roles: a dict keyed by column name would
    # silently keep only the last role and drop the others.
    if len(set(source_columns)) != len(source_columns):
        raise ValueError(
            f"The case, activity, and timestamp columns must be three different columns, got: {source_columns}"
        )

    # Two roles sharing one standardized name would yield duplicate column names.
    if len(set(target_names)) != len(target_names):
        raise ValueError(
            f"The standardized column names must be distinct, got: {target_names}"
        )

    # Create a mapping of original column names to standardized names
    column_mapping = dict(zip(source_columns, target_names))

    # Check if any of the standardized names already exist in the DataFrame
    # on a column that is NOT itself about to be renamed.
    existing_standard_names = set(target_names) & set(df.columns)
    conflicting_names = existing_standard_names - set(source_columns)
    if conflicting_names:
        raise ValueError(
            f"The following standardized names already exist in the DataFrame and don't match the columns to be renamed: {conflicting_names}"
        )

    # Rename everything in one step. Chained withColumnRenamed calls would, for a
    # mapping such as X -> _CASE_KEY, _CASE_KEY -> ACTIVITY, first create two
    # "_CASE_KEY" columns and then rename both of them to "ACTIVITY".
    return df.toDF(*[column_mapping.get(name, name) for name in df.columns])

In [None]:
from pyspark.sql import DataFrame
from typing import List, Callable, Dict, Any


def process_mining_preprocessing_pipeline(
    df: DataFrame,
    pipeline_steps: List[Callable[[DataFrame, Dict[str, Any]], DataFrame]],
    config: Dict[str, Any],
) -> DataFrame:
    """
    Run a configurable sequence of preprocessing steps over a DataFrame.

    The steps are applied in list order. Each step is invoked as
    ``step(current_df, config)`` through ``DataFrame.transform``, and its result
    becomes the input of the next step.

    Args:
        df (DataFrame): Input DataFrame.
        pipeline_steps (List[Callable]): Functions taking ``(DataFrame, config)``
            and returning a DataFrame.
        config (Dict[str, Any]): Configuration dictionary passed unchanged to
            every step.

    Returns:
        DataFrame: The result of the last step (the input itself if the list is empty).

    Raises:
        ValueError: Not raised by this function itself; the original contract
            states that failing steps signal errors this way, and whatever a
            step raises propagates to the caller.

    Example:
        >>> # Each step must accept (df, config); my_step_a / my_step_b are
        >>> # placeholders for such functions.
        >>> config = {"case_column": "ID", "activity_column": "Action", "timestamp_column": "Date"}
        >>> processed_df = process_mining_preprocessing_pipeline(df, [my_step_a, my_step_b], config)
        >>> processed_df.show()
    """
    current = df
    for step in pipeline_steps:

        # Bind the step as a default argument so the adapter does not depend on
        # the loop variable after this iteration.
        def _run_step(frame, _step=step):
            return _step(frame, config)

        current = current.transform(_run_step)

    return current


# Usage example
# Small end-to-end demo of process_mining_preprocessing_pipeline; it only runs
# when this code executes with __name__ == "__main__".
# NOTE(review): trim_all_strings, convert_empty_string_to_null and
# standardize_process_mining_names are not defined in the visible part of this
# file, and the visible convert_to_timestamp / check_process_mining_conditions
# take column-name arguments rather than a config dict. Confirm that
# (df, config)-style versions of every step are in scope before running.
if __name__ == "__main__":
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Sample data
    # Row 2 has trailing spaces in the activity and row 3 an empty activity --
    # presumably to exercise the trimming and empty-string steps (verify).
    data = [
        ("A1", "Start", "2023-01-01 10:00:00"),
        ("A2", "End  ", "2023-01-02 11:00:00"),
        ("A3", "", "2023-01-03 12:00:00"),
    ]
    df = spark.createDataFrame(data, ["ID", "Action", "Date"])

    # Configuration
    # The same dict is handed to every pipeline step.
    config = {
        "case_column": "ID",
        "activity_column": "Action",
        "timestamp_column": "Date",
        "additional_string_columns": [],
        "show_examples": True,  # for check_process_mining_conditions
    }

    # Define pipeline steps
    # Order matters: each step receives the previous step's output.
    pipeline_steps = [
        trim_all_strings,
        convert_empty_string_to_null,
        standardize_process_mining_names,
        convert_to_timestamp,
        check_process_mining_conditions,
    ]

    # Only ValueError is reported as a preprocessing failure; any other
    # exception propagates.
    try:
        processed_df = process_mining_preprocessing_pipeline(df, pipeline_steps, config)
        print("Preprocessing successful!")
        processed_df.show()
    except ValueError as e:
        print(f"Preprocessing failed: {e}")

In [None]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, count, when, year, expr, sum as spark_sum
from pyspark.sql.types import StringType, TimestampType
from typing import Dict, Optional


def check_process_mining_conditions(
    df: DataFrame,
    case_column: str = "CASE_KEY",
    activity_column: str = "ACTIVITY",
    timestamp_column: str = "TIMESTAMP",
) -> Optional[str]:
    """
    Check whether a DataFrame satisfies the basic conditions for process mining.

    The schema is validated first (column presence and types). All row-level
    checks are then expressed as one projection plus one aggregation, so only a
    single Spark action (``collect``) is triggered.

    Checks performed:
        * the case and activity columns are strings, the timestamp column is a timestamp
        * no nulls in the case, activity, or timestamp column
        * no timestamps with a year before 1970 or after 2100
        * no two rows share the same (case, activity, timestamp) triple

    Args:
        df (DataFrame): Input PySpark DataFrame.
        case_column (str, optional): Name of the case column. Defaults to "CASE_KEY".
        activity_column (str, optional): Name of the activity column. Defaults to "ACTIVITY".
        timestamp_column (str, optional): Name of the timestamp column. Defaults to "TIMESTAMP".

    Returns:
        Optional[str]: A success message when no problems are found. Problems are
        reported by raising, so ``None`` is not returned in practice. An empty
        DataFrame has no violating rows and therefore passes.

    Raises:
        ValueError: If a column is missing or has the wrong type, or if any
            row-level condition is violated (the message lists every violation,
            one per line).

    Example:
        >>> from pyspark.sql import SparkSession
        >>> from pyspark.sql.functions import to_timestamp
        >>> spark = SparkSession.builder.getOrCreate()
        >>> data = [("A1", "Start", "2023-01-01 10:00:00"),
        ...         ("A1", "Process", "2023-01-01 11:00:00"),
        ...         ("A1", "End", "2023-01-01 12:00:00"),
        ...         ("A2", "Start", "2023-01-02 09:00:00"),
        ...         ("A2", "End", "2023-01-02 10:00:00")]
        >>> df = spark.createDataFrame(data, ["CASE_KEY", "ACTIVITY", "TIMESTAMP"])
        >>> df = df.withColumn("TIMESTAMP", to_timestamp("TIMESTAMP"))
        >>> result = check_process_mining_conditions(df)
        >>> print(result)
        🎉 Process mining data perfection achieved! Your data is ready to uncover insights.
    """
    # Imported locally because this cell's import block does not provide Window.
    from pyspark.sql.window import Window

    # Check column presence and types
    for column, expected_type in [
        (case_column, StringType),
        (activity_column, StringType),
        (timestamp_column, TimestampType),
    ]:
        # Report a missing column as ValueError (like every other failure here)
        # instead of letting df.schema[...] raise KeyError.
        if column not in df.columns:
            raise ValueError(f"Column '{column}' not found in the DataFrame.")
        actual_type = df.schema[column].dataType
        if not isinstance(actual_type, expected_type):
            raise ValueError(
                f"📊 Column '{column}' should be a {expected_type.__name__}, but it's a {actual_type}."
            )

    # Rows sharing the same (case, activity, timestamp) triple land in the same
    # partition. Built with the Window API rather than SQL text so column names
    # are never spliced into a query string.
    event_window = Window.partitionBy(case_column, activity_column, timestamp_column)

    # Prepare all per-row flags in a single projection
    checks_df = df.select(
        when(col(case_column).isNull(), 1).otherwise(0).alias("null_cases"),
        when(col(activity_column).isNull(), 1).otherwise(0).alias("null_activities"),
        when(col(timestamp_column).isNull(), 1).otherwise(0).alias("null_timestamps"),
        when(year(col(timestamp_column)) < 1970, 1)
        .otherwise(0)
        .alias("early_timestamps"),
        when(year(col(timestamp_column)) > 2100, 1)
        .otherwise(0)
        .alias("future_timestamps"),
        count("*").over(event_window).alias("event_count"),
    )
    # NOTE: an earlier per-case window count compared against 0 was dropped. A
    # window count always includes the current row, so it could never be 0 and
    # the "cases with no activities" branch was unreachable.

    # Collect all error counts in a single action
    error_counts = checks_df.agg(
        spark_sum("null_cases").alias("null_cases"),
        spark_sum("null_activities").alias("null_activities"),
        spark_sum("null_timestamps").alias("null_timestamps"),
        spark_sum("early_timestamps").alias("early_timestamps"),
        spark_sum("future_timestamps").alias("future_timestamps"),
        # Counts every row that shares its key triple with at least one other row.
        spark_sum(when(col("event_count") > 1, 1).otherwise(0)).alias(
            "duplicate_events"
        ),
    ).collect()[0]

    # On an empty DataFrame every sum is NULL (None in Python); treat that as 0
    # so the comparisons below do not raise TypeError.
    error_dict: Dict[str, int] = {
        name: total or 0 for name, total in error_counts.asDict().items()
    }

    # Prepare error messages
    error_messages = []
    if error_dict["null_cases"] > 0:
        error_messages.append(
            f"🚫 Found {error_dict['null_cases']} null values in '{case_column}' column."
        )
    if error_dict["null_activities"] > 0:
        error_messages.append(
            f"🚫 Found {error_dict['null_activities']} null values in '{activity_column}' column."
        )
    if error_dict["null_timestamps"] > 0:
        error_messages.append(
            f"🚫 Found {error_dict['null_timestamps']} null values in '{timestamp_column}' column."
        )
    if error_dict["early_timestamps"] > 0:
        error_messages.append(
            f"⏳ Found {error_dict['early_timestamps']} timestamps before 1970."
        )
    if error_dict["future_timestamps"] > 0:
        error_messages.append(
            f"🔮 Found {error_dict['future_timestamps']} timestamps after 2100."
        )
    if error_dict["duplicate_events"] > 0:
        error_messages.append(
            f"👯 Found {error_dict['duplicate_events']} duplicate events. Each combination of {case_column}, {activity_column}, and {timestamp_column} should be unique!"
        )

    if error_messages:
        raise ValueError("\n".join(error_messages))

    return "🎉 Process mining data perfection achieved! Your data is ready to uncover insights."

In [None]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, count
from pyspark.sql.window import Window
from pyspark.sql.types import StringType, TimestampType


def find_duplicate_case_activity_eventtime(
    df: DataFrame,
    case_column: str = "_CASE_KEY",
    activity_column: str = "ACTIVITY",
    eventtime_column: str = "EVENTTIME",
) -> DataFrame:
    """
    List every case/activity/eventtime combination that occurs more than once.

    Two rows count as duplicates when they agree on all three key columns.
    The result holds one row per duplicated combination (not one per offending
    input row), together with how often that combination appears.

    Args:
        df (DataFrame): Input PySpark DataFrame containing process mining data.
        case_column (str, optional): Name of the case column. Defaults to "_CASE_KEY".
        activity_column (str, optional): Name of the activity column. Defaults to "ACTIVITY".
        eventtime_column (str, optional): Name of the eventtime column. Defaults to "EVENTTIME".

    Returns:
        DataFrame: The three key columns plus a ``count`` column, restricted to
        combinations with ``count > 1`` and sorted by the key columns.

    Raises:
        ValueError: If a key column is missing, or if the case/activity columns
            are not strings or the eventtime column is not a timestamp.

    Example:
        >>> from pyspark.sql import SparkSession
        >>> from pyspark.sql.functions import to_timestamp
        >>> spark = SparkSession.builder.getOrCreate()
        >>> data = [
        ...     ("A1", "Start", "2023-01-01 10:00:00"),
        ...     ("A1", "Start", "2023-01-01 10:00:00"),  # Duplicate
        ...     ("A1", "Process", "2023-01-01 11:00:00"),
        ...     ("A2", "Start", "2023-01-02 09:00:00"),
        ...     ("A2", "Start", "2023-01-02 09:00:00"),  # Duplicate
        ...     ("A2", "End", "2023-01-02 10:00:00")
        ... ]
        >>> df = spark.createDataFrame(data, ["_CASE_KEY", "ACTIVITY", "EVENTTIME"])
        >>> df = df.withColumn("EVENTTIME", to_timestamp("EVENTTIME"))
        >>> duplicates = find_duplicate_case_activity_eventtime(df)
        >>> duplicates.show()
        +---------+--------+-------------------+-----+
        |_CASE_KEY|ACTIVITY|          EVENTTIME|count|
        +---------+--------+-------------------+-----+
        |       A1|   Start|2023-01-01 10:00:00|    2|
        |       A2|   Start|2023-01-02 09:00:00|    2|
        +---------+--------+-------------------+-----+
    """
    key_columns = [case_column, activity_column, eventtime_column]
    required_types = [StringType, StringType, TimestampType]

    # Validate presence first, then type, for each key column in turn.
    for name, spark_type in zip(key_columns, required_types):
        if name not in df.columns:
            raise ValueError(f"Column '{name}' not found in the DataFrame.")
        actual_type = df.schema[name].dataType
        if not isinstance(actual_type, spark_type):
            raise ValueError(
                f"Column '{name}' should be of type {spark_type.__name__}, but is {actual_type}"
            )

    # Attach to every row the size of its (case, activity, eventtime) group.
    occurrences = count("*").over(Window.partitionBy(*key_columns))
    counted = df.withColumn("count", occurrences)

    # Keep only groups that appear more than once, then collapse each group to
    # a single output row.
    repeated = counted.filter(col("count") > 1)
    collapsed = repeated.select(*key_columns, "count").distinct()

    return collapsed.orderBy(*key_columns)

In [None]:
from pyspark.sql import DataFrame, Window
from pyspark.sql import functions as F
from pyspark.sql.functions import col, when, lit, round
from typing import Union, List


def calculate_duration_from_start_to_target(
    df: DataFrame,
    target_activities: Union[str, List[str]],
    case_column: str = "_CASE_KEY",
    activity_column: str = "ACTIVITY",
    timestamp_column: str = "EVENTTIME",
) -> DataFrame:
    """
    Measure, per case, the time from the case's first event to its first target activity.

    The "start" of a case is its earliest row by timestamp, whatever the
    activity. The "target" is the earliest row whose activity is one of
    ``target_activities``. Cases with no target row are left out (inner join).

    Args:
        df (DataFrame): Input DataFrame with process mining data.
        target_activities (Union[str, List[str]]): One target activity or a list of them.
        case_column (str, optional): Name of the case column. Defaults to "_CASE_KEY".
        activity_column (str, optional): Name of the activity column. Defaults to "ACTIVITY".
        timestamp_column (str, optional): Name of the timestamp column. Defaults to "EVENTTIME".

    Returns:
        DataFrame: One row per case that reaches a target, with the columns
            - case_column: The case identifier
            - start_activity / start_timestamp: The case's first event
            - target_activity / target_timestamp: The first matching target event
            - duration_minutes, duration_hours, duration_days: Elapsed time
              between the two, each rounded to 2 decimal places

    Example:
        >>> from pyspark.sql import SparkSession
        >>> spark = SparkSession.builder.getOrCreate()
        >>> data = [
        ...     ("case1", "Start", "2023-01-01 10:00:00"),
        ...     ("case1", "Middle", "2023-01-02 11:00:00"),
        ...     ("case1", "End", "2023-01-03 12:00:00"),
        ...     ("case2", "Start", "2023-01-01 09:00:00"),
        ...     ("case2", "End", "2023-01-01 17:00:00"),
        ...     ("case3", "Start", "2023-01-01 08:00:00"),
        ...     ("case3", "Other", "2023-01-02 08:00:00")
        ... ]
        >>> df = spark.createDataFrame(data, ["_CASE_KEY", "ACTIVITY", "EVENTTIME"])
        >>> df = df.withColumn("EVENTTIME", F.to_timestamp("EVENTTIME"))
        >>> result = calculate_duration_from_start_to_target(df, ["Middle", "End"])
        >>> result.show(truncate=False)
    """
    # Accept a single activity name as shorthand for a one-element list.
    targets = (
        [target_activities] if isinstance(target_activities, str) else target_activities
    )

    # Events of one case, oldest first.
    chronological = Window.partitionBy(case_column).orderBy(timestamp_column)

    def _earliest_event(events, activity_alias, timestamp_alias):
        """Keep each case's first row in `events` and rename its activity/timestamp."""
        ranked = events.withColumn("row_number", F.row_number().over(chronological))
        return ranked.filter(col("row_number") == 1).select(
            col(case_column),
            col(activity_column).alias(activity_alias),
            col(timestamp_column).alias(timestamp_alias),
        )

    case_starts = _earliest_event(df, "start_activity", "start_timestamp")
    first_targets = _earliest_event(
        df.filter(col(activity_column).isin(targets)),
        "target_activity",
        "target_timestamp",
    )

    # Inner join: a case without any target activity produces no output row.
    paired = case_starts.join(first_targets, on=case_column, how="inner")

    # Elapsed whole seconds between the two timestamps; the coarser units below
    # are all derived from this one value.
    paired = paired.withColumn(
        "duration_seconds",
        F.unix_timestamp("target_timestamp") - F.unix_timestamp("start_timestamp"),
    )
    for unit_column, seconds_per_unit in (
        ("duration_minutes", 60),
        ("duration_hours", 3600),
        ("duration_days", 86400),
    ):
        paired = paired.withColumn(
            unit_column, round(col("duration_seconds") / seconds_per_unit, 2)
        )
    paired = paired.drop("duration_seconds")

    # Fix the output column order.
    return paired.select(
        [
            case_column,
            "start_activity",
            "start_timestamp",
            "target_activity",
            "target_timestamp",
            "duration_minutes",
            "duration_hours",
            "duration_days",
        ]
    )

In [None]:
from pyspark.sql import DataFrame, Window
from pyspark.sql import functions as F
from pyspark.sql.functions import col, lit, round
from typing import Union, List


def calculate_duration_between_activities(
    df: DataFrame,
    start_activities: Union[str, List[str]],
    target_activities: Union[str, List[str]],
    case_column: str = "_CASE_KEY",
    activity_column: str = "ACTIVITY",
    timestamp_column: str = "EVENTTIME",
) -> DataFrame:
    """
    Measure, per case, the time between its first start activity and its first target activity.

    For each case the earliest row (by timestamp) whose activity is in
    ``start_activities`` is paired with the earliest row whose activity is in
    ``target_activities``. Cases lacking either one are left out (inner join).
    The two rows are picked independently, so the duration is negative when the
    target event precedes the start event.

    Args:
        df (DataFrame): Input DataFrame with process mining data.
        start_activities (Union[str, List[str]]): One start activity or a list of them.
        target_activities (Union[str, List[str]]): One target activity or a list of them.
        case_column (str, optional): Name of the case column. Defaults to "_CASE_KEY".
        activity_column (str, optional): Name of the activity column. Defaults to "ACTIVITY".
        timestamp_column (str, optional): Name of the timestamp column. Defaults to "EVENTTIME".

    Returns:
        DataFrame: One row per case that has both events, with the columns
            - case_column: The case identifier
            - start_activity / start_timestamp: The first matching start event
            - target_activity / target_timestamp: The first matching target event
            - duration_minutes, duration_hours, duration_days: Elapsed time
              between the two, each rounded to 2 decimal places

    Example:
        >>> from pyspark.sql import SparkSession
        >>> spark = SparkSession.builder.getOrCreate()
        >>> data = [
        ...     ("case1", "Start", "2023-01-01 10:00:00"),
        ...     ("case1", "Middle", "2023-01-02 11:00:00"),
        ...     ("case1", "End", "2023-01-03 12:00:00"),
        ...     ("case2", "Begin", "2023-01-01 09:00:00"),
        ...     ("case2", "End", "2023-01-01 17:00:00"),
        ...     ("case3", "Start", "2023-01-01 08:00:00"),
        ...     ("case3", "Other", "2023-01-02 08:00:00")
        ... ]
        >>> df = spark.createDataFrame(data, ["_CASE_KEY", "ACTIVITY", "EVENTTIME"])
        >>> df = df.withColumn("EVENTTIME", F.to_timestamp("EVENTTIME"))
        >>> result = calculate_duration_between_activities(df, ["Start", "Begin"], ["Middle", "End"])
        >>> result.show(truncate=False)
    """
    # A single activity name is shorthand for a one-element list.
    start_names, target_names = (
        [names] if isinstance(names, str) else names
        for names in (start_activities, target_activities)
    )

    # Events of one case, oldest first.
    by_case_in_time_order = Window.partitionBy(case_column).orderBy(timestamp_column)

    def _first_matching(activity_names, activity_alias, timestamp_alias):
        """Return each case's earliest row whose activity is in `activity_names`."""
        matching = df.filter(col(activity_column).isin(activity_names))
        ranked = matching.withColumn(
            "row_number", F.row_number().over(by_case_in_time_order)
        )
        first_rows = ranked.filter(col("row_number") == 1)
        return first_rows.select(
            col(case_column),
            col(activity_column).alias(activity_alias),
            col(timestamp_column).alias(timestamp_alias),
        )

    first_starts = _first_matching(start_names, "start_activity", "start_timestamp")
    first_targets = _first_matching(target_names, "target_activity", "target_timestamp")

    # Only cases present on both sides survive the join.
    joined = first_starts.join(first_targets, on=case_column, how="inner")

    # Elapsed whole seconds, kept as a temporary column while the unit
    # conversions are derived from it.
    joined = joined.withColumn(
        "duration_seconds",
        F.unix_timestamp("target_timestamp") - F.unix_timestamp("start_timestamp"),
    )
    elapsed = col("duration_seconds")
    with_units = (
        joined.withColumn("duration_minutes", round(elapsed / 60, 2))
        .withColumn("duration_hours", round(elapsed / 3600, 2))
        .withColumn("duration_days", round(elapsed / 86400, 2))
        .drop("duration_seconds")
    )

    # Fix the output column order.
    output_columns = [
        case_column,
        "start_activity",
        "start_timestamp",
        "target_activity",
        "target_timestamp",
        "duration_minutes",
        "duration_hours",
        "duration_days",
    ]
    return with_units.select(output_columns)


# Usage example
# Demo of calculate_duration_between_activities on a three-case event log; it
# only runs when this code executes with __name__ == "__main__".
if __name__ == "__main__":
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Sample data
    # case1 and case3 begin with "Start", case2 with "Begin"; case3 has neither
    # "Middle" nor "End".
    data = [
        ("case1", "Start", "2023-01-01 10:00:00"),
        ("case1", "Middle", "2023-01-02 11:00:00"),
        ("case1", "End", "2023-01-03 12:00:00"),
        ("case2", "Begin", "2023-01-01 09:00:00"),
        ("case2", "End", "2023-01-01 17:00:00"),
        ("case3", "Start", "2023-01-01 08:00:00"),
        ("case3", "Other", "2023-01-02 08:00:00"),
    ]
    df = spark.createDataFrame(data, ["_CASE_KEY", "ACTIVITY", "EVENTTIME"])
    # The timestamps are created as strings; convert them before measuring durations.
    df = df.withColumn("EVENTTIME", F.to_timestamp("EVENTTIME"))

    # Calculate duration from the first occurrence of either "Start" or "Begin" to the first occurrence of either "Middle" or "End"
    result = calculate_duration_between_activities(
        df, ["Start", "Begin"], ["Middle", "End"]
    )
    result.show(truncate=False)

    # Calculate duration from "Start" to "End" only
    # Single activity names are passed as plain strings here instead of lists.
    result_start_to_end = calculate_duration_between_activities(df, "Start", "End")
    result_start_to_end.show(truncate=False)

In [None]:
from pyspark.sql import DataFrame, Window
from pyspark.sql import functions as F
from pyspark.sql.functions import col, lit, round, min as spark_min, max as spark_max
from typing import Union, List


def calculate_throughput_time(
    df: DataFrame,
    start_activities: Union[str, List[str]],
    end_activities: Union[str, List[str]],
    case_column: str = "_CASE_KEY",
    activity_column: str = "ACTIVITY",
    timestamp_column: str = "EVENTTIME",
) -> DataFrame:
    """
    Calculate the throughput time between specified sets of start and end activities for each case.
    The throughput time spans from the earliest occurrence of any start activity
    to the latest occurrence of any end activity.

    Only cases that contain at least one start activity AND at least one end activity
    are returned (the two per-case summaries are inner-joined). If the latest end
    activity happens before the earliest start activity, the throughput time is negative
    and activity_count is 0.

    Args:
        df (DataFrame): Input DataFrame with process mining data.
        start_activities (Union[str, List[str]]): The start activity or list of start activities.
        end_activities (Union[str, List[str]]): The end activity or list of end activities.
        case_column (str, optional): Name of the case column. Defaults to "_CASE_KEY".
        activity_column (str, optional): Name of the activity column. Defaults to "ACTIVITY".
        timestamp_column (str, optional): Name of the timestamp column. Defaults to "EVENTTIME".

    Returns:
        DataFrame: A DataFrame with the following columns:
            - case_column: The case identifier
            - first_start_activity: The start activity with the earliest timestamp
              (ties on the timestamp resolve to the alphabetically smallest activity)
            - first_start_timestamp: The timestamp of that start activity
            - last_end_activity: The end activity with the latest timestamp
              (ties on the timestamp resolve to the alphabetically largest activity)
            - last_end_timestamp: The timestamp of that end activity
            - throughput_time_minutes: Throughput time in minutes (rounded to 2 decimal places)
            - throughput_time_hours: Throughput time in hours (rounded to 2 decimal places)
            - throughput_time_days: Throughput time in days (rounded to 2 decimal places)
            - activity_count: The number of events of the case whose timestamp lies
              between first_start_timestamp and last_end_timestamp (inclusive)

    Example:
        >>> from pyspark.sql import SparkSession
        >>> spark = SparkSession.builder.getOrCreate()
        >>> data = [
        ...     ("case1", "Start", "2023-01-01 10:00:00"),
        ...     ("case1", "Middle1", "2023-01-02 11:00:00"),
        ...     ("case1", "Middle2", "2023-01-03 09:00:00"),
        ...     ("case1", "End", "2023-01-03 12:00:00"),
        ...     ("case2", "Begin", "2023-01-01 09:00:00"),
        ...     ("case2", "Process", "2023-01-01 14:00:00"),
        ...     ("case2", "Finish", "2023-01-01 17:00:00"),
        ...     ("case3", "Start", "2023-01-01 08:00:00"),
        ...     ("case3", "Other", "2023-01-02 08:00:00")
        ... ]
        >>> df = spark.createDataFrame(data, ["_CASE_KEY", "ACTIVITY", "EVENTTIME"])
        >>> df = df.withColumn("EVENTTIME", F.to_timestamp("EVENTTIME"))
        >>> result = calculate_throughput_time(df, ["Start", "Begin"], ["End", "Finish"])
        >>> result.show(truncate=False)
    """
    # Accept a single activity name as well as a list of names
    if isinstance(start_activities, str):
        start_activities = [start_activities]
    if isinstance(end_activities, str):
        end_activities = [end_activities]

    timestamp = col(timestamp_column)
    activity = col(activity_column)

    # Pair every timestamp with its activity. min()/max() over the struct compare
    # the timestamp first, so the activity that comes out belongs to the very same
    # event as the extreme timestamp (F.first/F.last inside a groupBy give no such
    # guarantee). Rows without a timestamp become a NULL struct and are skipped by
    # min()/max(), exactly like min()/max() on the bare timestamp column.
    timed_event = F.when(
        timestamp.isNotNull(),
        F.struct(timestamp.alias("ts"), activity.alias("activity")),
    )

    # Earliest start event per case
    first_start = (
        df.filter(activity.isin(start_activities))
        .groupBy(case_column)
        .agg(F.min(timed_event).alias("_first_start"))
        .select(
            case_column,
            col("_first_start.ts").alias("first_start_timestamp"),
            col("_first_start.activity").alias("first_start_activity"),
        )
    )

    # Latest end event per case
    last_end = (
        df.filter(activity.isin(end_activities))
        .groupBy(case_column)
        .agg(F.max(timed_event).alias("_last_end"))
        .select(
            case_column,
            col("_last_end.ts").alias("last_end_timestamp"),
            col("_last_end.activity").alias("last_end_activity"),
        )
    )

    # Keep only cases that have both a start and an end event
    result_df = first_start.join(last_end, case_column, "inner")

    # Throughput time in seconds, then converted to minutes, hours and days
    throughput_seconds = F.unix_timestamp("last_end_timestamp") - F.unix_timestamp(
        "first_start_timestamp"
    )
    result_df = (
        result_df.withColumn("throughput_time_seconds", throughput_seconds)
        .withColumn(
            "throughput_time_minutes", F.round(col("throughput_time_seconds") / 60, 2)
        )
        .withColumn(
            "throughput_time_hours", F.round(col("throughput_time_seconds") / 3600, 2)
        )
        .withColumn(
            "throughput_time_days", F.round(col("throughput_time_seconds") / 86400, 2)
        )
        .drop("throughput_time_seconds")
    )

    # Count the events inside [first_start_timestamp, last_end_timestamp].
    # The boundaries live in result_df, not in df, so they have to be joined onto
    # the events before they can be compared with the event timestamp.
    span_bounds = result_df.select(
        case_column, "first_start_timestamp", "last_end_timestamp"
    )
    inside_span = (timestamp >= col("first_start_timestamp")) & (
        timestamp <= col("last_end_timestamp")
    )
    activity_count = (
        df.join(span_bounds, case_column, "inner")
        .groupBy(case_column)
        # count() only counts non-null values, i.e. the events inside the span
        .agg(F.count(F.when(inside_span, True)).alias("activity_count"))
    )

    # Attach the activity count to the per-case result
    result_df = result_df.join(activity_count, case_column, "left_outer")

    # Select and order the final columns
    final_columns = [
        case_column,
        "first_start_activity",
        "first_start_timestamp",
        "last_end_activity",
        "last_end_timestamp",
        "throughput_time_minutes",
        "throughput_time_hours",
        "throughput_time_days",
        "activity_count",
    ]

    return result_df.select(final_columns)


# Demo: run calculate_throughput_time on a small in-memory event log
if __name__ == "__main__":
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Toy event log: (case id, activity, event time as string)
    data = [
        ("case1", "Start", "2023-01-01 10:00:00"),
        ("case1", "Middle1", "2023-01-02 11:00:00"),
        ("case1", "Middle2", "2023-01-03 09:00:00"),
        ("case1", "End", "2023-01-03 12:00:00"),
        ("case2", "Begin", "2023-01-01 09:00:00"),
        ("case2", "Process", "2023-01-01 14:00:00"),
        ("case2", "Finish", "2023-01-01 17:00:00"),
        ("case3", "Start", "2023-01-01 08:00:00"),
        ("case3", "Other", "2023-01-02 08:00:00"),
    ]
    # Build the DataFrame and parse the string column into a real timestamp
    df = spark.createDataFrame(
        data, ["_CASE_KEY", "ACTIVITY", "EVENTTIME"]
    ).withColumn("EVENTTIME", F.to_timestamp("EVENTTIME"))

    # Start candidates are "Start"/"Begin", end candidates are "End"/"Finish"
    result = calculate_throughput_time(df, ["Start", "Begin"], ["End", "Finish"])
    result.show(truncate=False)

In [None]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import (
    col,
    lower,
    first,
    last,
    collect_set,
    array_contains,
    lit,
)
from pyspark.sql.window import Window
from typing import Union, List, Optional


def filter_process_cases(
    df: DataFrame,
    case_flows_through: Optional[Union[str, List[str]]] = None,
    case_does_not_flow_through: Optional[Union[str, List[str]]] = None,
    case_starts_with: Optional[Union[str, List[str]]] = None,
    case_ends_with: Optional[Union[str, List[str]]] = None,
    case_column: str = "_CASE_KEY",
    activity_column: str = "ACTIVITY",
    timestamp_column: str = "EVENTTIME",
) -> DataFrame:
    """
    Filter process cases based on specified flow conditions (case-insensitive).

    All conditions that are supplied must hold for a case to be kept. Each case is
    summarised once (set of activities, earliest activity, latest activity) and the
    surviving case keys are joined back onto the input, so the returned rows keep
    their original activity spelling.

    Args:
        df (DataFrame): Input DataFrame with process mining data.
        case_flows_through (Optional[Union[str, List[str]]]): Activity or list of activities that cases must flow through
            (every listed activity must occur).
        case_does_not_flow_through (Optional[Union[str, List[str]]]): Activity or list of activities that cases must not flow through
            (none of the listed activities may occur).
        case_starts_with (Optional[Union[str, List[str]]]): Activity or list of activities that cases must start with
            (the earliest event must be one of them).
        case_ends_with (Optional[Union[str, List[str]]]): Activity or list of activities that cases must end with
            (the latest event must be one of them).
        case_column (str): Name of the case column. Defaults to "_CASE_KEY".
        activity_column (str): Name of the activity column. Defaults to "ACTIVITY".
        timestamp_column (str): Name of the timestamp column. Defaults to "EVENTTIME".

    Returns:
        DataFrame: Filtered DataFrame containing only the cases that meet all specified conditions.
        If no condition is given, the input DataFrame is returned unchanged.

    Example:
        >>> filtered_df = filter_process_cases(
        ...     df,
        ...     case_flows_through=["Middle", "Review"],
        ...     case_does_not_flow_through="Reject",
        ...     case_starts_with="Start",
        ...     case_ends_with=["End", "Complete"]
        ... )
    """

    # Normalise None / a single name / any iterable of names to a lowercase list
    def to_lower_list(x):
        if x is None:
            return []
        names = [x] if isinstance(x, str) else list(x)
        return [name.lower() for name in names]

    flows_through = to_lower_list(case_flows_through)
    not_flows_through = to_lower_list(case_does_not_flow_through)
    starts_with = to_lower_list(case_starts_with)
    ends_with = to_lower_list(case_ends_with)

    # If no filtering conditions are provided, return the original DataFrame
    if not any([flows_through, not_flows_through, starts_with, ends_with]):
        return df

    # Imported here because these names are not part of this cell's imports
    # (and `min`/`max` must not be confused with the Python builtins).
    from pyspark.sql.functions import max as spark_max
    from pyspark.sql.functions import min as spark_min
    from pyspark.sql.functions import struct, when

    activity_lc = lower(col(activity_column))

    # Pair each timestamp with its activity: min()/max() over the struct compare the
    # timestamp first, so they yield the activity of the earliest / latest event.
    # first()/last() inside a groupBy do not honour any ordering. Events without a
    # timestamp become NULL and are ignored when determining the first/last activity.
    timed_activity = when(
        col(timestamp_column).isNotNull(),
        struct(col(timestamp_column).alias("ts"), activity_lc.alias("activity")),
    )

    # One summary row per case
    case_summary = (
        df.groupBy(case_column)
        .agg(
            collect_set(activity_lc).alias("activities"),
            spark_min(timed_activity).alias("first_event"),
            spark_max(timed_activity).alias("last_event"),
        )
        .select(
            case_column,
            "activities",
            col("first_event.activity").alias("first_activity"),
            col("last_event.activity").alias("last_activity"),
        )
    )

    # Chained filters are AND-ed together. Python's builtin all()/any() cannot be
    # used here because a Column has no truth value.
    for activity in flows_through:
        case_summary = case_summary.filter(array_contains(col("activities"), activity))

    for activity in not_flows_through:
        case_summary = case_summary.filter(
            ~array_contains(col("activities"), activity)
        )

    if starts_with:
        case_summary = case_summary.filter(col("first_activity").isin(starts_with))

    if ends_with:
        case_summary = case_summary.filter(col("last_activity").isin(ends_with))

    # Keys of the cases that meet all conditions
    valid_cases = case_summary.select(case_column)

    # Keep only the events of those cases
    return df.join(valid_cases, case_column, "inner")


# Demo: run filter_process_cases on a small in-memory event log
if __name__ == "__main__":
    from pyspark.sql import SparkSession
    import pyspark.sql.functions as F

    spark = SparkSession.builder.getOrCreate()

    # Toy event log with mixed-case activity names: (case id, activity, event time)
    data = [
        ("case1", "Start", "2023-01-01 10:00:00"),
        ("case1", "Middle", "2023-01-02 11:00:00"),
        ("case1", "End", "2023-01-03 12:00:00"),
        ("case2", "start", "2023-01-01 09:00:00"),
        ("case2", "Other", "2023-01-01 10:00:00"),
        ("case2", "END", "2023-01-01 17:00:00"),
        ("case3", "Begin", "2023-01-01 08:00:00"),
        ("case3", "middle", "2023-01-02 08:00:00"),
        ("case3", "Finish", "2023-01-03 08:00:00"),
    ]
    # Build the DataFrame and parse the string column into a real timestamp
    df = spark.createDataFrame(
        data, ["_CASE_KEY", "ACTIVITY", "EVENTTIME"]
    ).withColumn("EVENTTIME", F.to_timestamp("EVENTTIME"))

    # Combine all four kinds of conditions in a single call
    filtered_df = filter_process_cases(
        df,
        case_flows_through="Middle",
        case_does_not_flow_through="Other",
        case_starts_with=["Start", "Begin"],
        case_ends_with=["End", "Finish"],
    )

    # Show input and output side by side for comparison
    print("Original DataFrame:")
    df.show(truncate=False)
    print("\nFiltered DataFrame:")
    filtered_df.show(truncate=False)

In [None]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, lower, lead, when, lit, collect_list
from pyspark.sql.window import Window
from typing import Literal


def filter_process_flow(
    df: DataFrame,
    first_activity: str,
    second_activity: str,
    relationship: Literal[
        "directly_followed",
        "followed_anytime_by",
        "not_directly_followed",
        "never_followed_by",
    ],
    case_column: str = "_CASE_KEY",
    activity_column: str = "ACTIVITY",
    timestamp_column: str = "EVENTTIME",
) -> DataFrame:
    """
    Filter the process flow based on the relationship between two specified activities.

    Activity names are matched case-insensitively. The returned rows are the
    unmodified rows of ``df`` (original activity spelling is kept).

    Relationship semantics:
        - "directly_followed": some occurrence of first_activity has second_activity
          as its immediate successor (events ordered by timestamp within the case).
        - "not_directly_followed": some occurrence of first_activity has an immediate
          successor other than second_activity, or is the last event of the case.
        - "followed_anytime_by": some occurrence of first_activity has a strictly
          earlier timestamp than some occurrence of second_activity.
        - "never_followed_by": first_activity occurs and no occurrence of
          second_activity has a timestamp later than the earliest first_activity.

    Events without a timestamp are ignored by the two "anytime" relationships; for
    the two "directly" relationships the order of events sharing a timestamp is
    not defined.

    Args:
        df (DataFrame): Input DataFrame with process mining data.
        first_activity (str): The first activity in the relationship.
        second_activity (str): The second activity in the relationship.
        relationship (Literal["directly_followed", "followed_anytime_by", "not_directly_followed", "never_followed_by"]):
            The type of relationship between the two activities.
        case_column (str): Name of the case column. Defaults to "_CASE_KEY".
        activity_column (str): Name of the activity column. Defaults to "ACTIVITY".
        timestamp_column (str): Name of the timestamp column. Defaults to "EVENTTIME".

    Returns:
        DataFrame: Filtered DataFrame containing only the cases that meet the specified relationship.

    Raises:
        ValueError: If an invalid relationship is provided.

    Example:
        >>> filtered_df = filter_process_flow(
        ...     df,
        ...     first_activity="Start",
        ...     second_activity="Review",
        ...     relationship="directly_followed"
        ... )
    """
    # Validate the relationship parameter
    valid_relationships = [
        "directly_followed",
        "followed_anytime_by",
        "not_directly_followed",
        "never_followed_by",
    ]
    if relationship not in valid_relationships:
        raise ValueError(
            f"Invalid relationship. Must be one of: {', '.join(valid_relationships)}"
        )

    # Case-insensitive matching: compare lowercase against lowercase, but never
    # overwrite the activity column so the caller gets the original values back.
    first_activity = first_activity.lower()
    second_activity = second_activity.lower()
    activity_lc = lower(col(activity_column))

    if relationship in ("directly_followed", "not_directly_followed"):
        # Look at each event together with its immediate successor in the case
        case_window = Window.partitionBy(case_column).orderBy(timestamp_column)
        events_with_next = df.select(
            col(case_column),
            activity_lc.alias("_current_activity"),
            lead(activity_lc).over(case_window).alias("_next_activity"),
        )

        is_first = col("_current_activity") == first_activity
        # Null-safe comparison: a missing successor (last event of the case)
        # counts as "not second_activity" instead of evaluating to NULL.
        next_is_second = col("_next_activity").eqNullSafe(second_activity)

        if relationship == "directly_followed":
            condition = is_first & next_is_second
        else:
            condition = is_first & ~next_is_second

        filtered_cases = (
            events_with_next.filter(condition).select(case_column).distinct()
        )

    else:
        # Imported here because `min`/`max` are not part of this cell's imports
        # (and must not be confused with the Python builtins).
        from pyspark.sql.functions import max as spark_max
        from pyspark.sql.functions import min as spark_min

        # "A is followed at some point by B" holds exactly when the earliest A is
        # strictly before the latest B, so two per-case timestamps are sufficient.
        case_bounds = df.groupBy(case_column).agg(
            spark_min(
                when(activity_lc == first_activity, col(timestamp_column))
            ).alias("_earliest_first"),
            spark_max(
                when(activity_lc == second_activity, col(timestamp_column))
            ).alias("_latest_second"),
        )

        if relationship == "followed_anytime_by":
            # NULL (one of the activities is missing) filters the case out
            condition = col("_earliest_first") < col("_latest_second")
        else:  # never_followed_by
            condition = col("_earliest_first").isNotNull() & (
                col("_latest_second").isNull()
                | (col("_latest_second") <= col("_earliest_first"))
            )

        filtered_cases = case_bounds.filter(condition).select(case_column)

    # Keep only the events of the selected cases
    return df.join(filtered_cases, case_column, "inner")


# Demo: run filter_process_flow once per supported relationship
if __name__ == "__main__":
    from pyspark.sql import SparkSession
    import pyspark.sql.functions as F

    spark = SparkSession.builder.getOrCreate()

    # Toy event log: (case id, activity, event time as string)
    data = [
        ("case1", "Start", "2023-01-01 10:00:00"),
        ("case1", "Middle", "2023-01-02 11:00:00"),
        ("case1", "End", "2023-01-03 12:00:00"),
        ("case2", "Start", "2023-01-01 09:00:00"),
        ("case2", "Review", "2023-01-01 10:00:00"),
        ("case2", "End", "2023-01-01 17:00:00"),
        ("case3", "Start", "2023-01-01 08:00:00"),
        ("case3", "Middle", "2023-01-02 08:00:00"),
        ("case3", "Review", "2023-01-03 08:00:00"),
    ]
    # Build the DataFrame and parse the string column into a real timestamp
    df = spark.createDataFrame(
        data, ["_CASE_KEY", "ACTIVITY", "EVENTTIME"]
    ).withColumn("EVENTTIME", F.to_timestamp("EVENTTIME"))

    print("Original DataFrame:")
    df.show(truncate=False)

    # Every relationship keyword accepted by filter_process_flow
    relationships = [
        "directly_followed",
        "followed_anytime_by",
        "not_directly_followed",
        "never_followed_by",
    ]

    # Apply the "Start" -> "Review" filter under each relationship and print it
    for rel in relationships:
        filtered_df = filter_process_flow(df, "Start", "Review", rel)
        print(f"\nFiltered DataFrame ({rel}):")
        filtered_df.show(truncate=False)

In [None]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, lower, count, when
from pyspark.sql.window import Window
from typing import Tuple, Optional


def filter_rework_cases(
    df: DataFrame,
    activity: str,
    occurs_between: Optional[Tuple[int, int]] = None,
    case_column: str = "_CASE_KEY",
    activity_column: str = "ACTIVITY",
    timestamp_column: str = "EVENTTIME",
) -> DataFrame:
    """
    Keep the cases in which an activity occurs a number of times within a given range.

    The activity name is compared case-insensitively. Both ends of the range are
    inclusive; an upper bound of float("inf") means "no upper limit".

    Args:
        df (DataFrame): Input DataFrame with process mining data.
        activity (str): The activity whose occurrences are counted per case.
        occurs_between (Optional[Tuple[int, int]]): (min, max) number of occurrences.
                                                    None is treated as (0, float('inf')).
        case_column (str): Name of the case column. Defaults to "_CASE_KEY".
        activity_column (str): Name of the activity column. Defaults to "ACTIVITY".
        timestamp_column (str): Name of the timestamp column. Defaults to "EVENTTIME".
                                Not used by this function.

    Returns:
        DataFrame: All rows of the cases whose occurrence count lies within the range.

    Raises:
        ValueError: If min is negative or max is smaller than min.

    Example:
        >>> filtered_df = filter_rework_cases(
        ...     df,
        ...     activity="Review",
        ...     occurs_between=(2, 4)
        ... )
    """
    unbounded = float("inf")
    lower_bound, upper_bound = (
        (0, unbounded) if occurs_between is None else occurs_between
    )

    # Reject impossible ranges before touching the data
    if lower_bound < 0 or upper_bound < lower_bound:
        raise ValueError("Invalid occurrence range. Ensure min >= 0 and max >= min.")

    wanted = activity.lower()

    # when() without otherwise() yields NULL for non-matching rows and count()
    # skips NULLs, so this counts the matching events of each case.
    is_wanted = when(lower(col(activity_column)) == wanted, True)
    per_case_counts = df.groupBy(case_column).agg(
        count(is_wanted).alias("activity_count")
    )

    # Build the range predicate; the upper bound is only applied when it is finite.
    # A (0, 1) range needs no special handling: the count is never negative, so it
    # reduces to "at most one occurrence".
    in_range = col("activity_count") >= lower_bound
    if upper_bound != unbounded:
        in_range = in_range & (col("activity_count") <= upper_bound)

    matching_case_keys = per_case_counts.filter(in_range).select(case_column)

    # Keep only the events of the matching cases
    return df.join(matching_case_keys, case_column, "inner")


# Demo: run filter_rework_cases for several occurrence ranges
if __name__ == "__main__":
    from pyspark.sql import SparkSession
    import pyspark.sql.functions as F

    spark = SparkSession.builder.getOrCreate()

    # Toy event log: "Review" occurs twice in case1, once in case2, three times in case3
    data = [
        ("case1", "Start", "2023-01-01 10:00:00"),
        ("case1", "Review", "2023-01-02 11:00:00"),
        ("case1", "Review", "2023-01-03 12:00:00"),
        ("case2", "Start", "2023-01-01 09:00:00"),
        ("case2", "Review", "2023-01-01 10:00:00"),
        ("case2", "End", "2023-01-01 17:00:00"),
        ("case3", "Start", "2023-01-01 08:00:00"),
        ("case3", "Review", "2023-01-02 08:00:00"),
        ("case3", "Review", "2023-01-03 08:00:00"),
        ("case3", "Review", "2023-01-04 08:00:00"),
    ]
    # Build the DataFrame and parse the string column into a real timestamp
    df = spark.createDataFrame(
        data, ["_CASE_KEY", "ACTIVITY", "EVENTTIME"]
    ).withColumn("EVENTTIME", F.to_timestamp("EVENTTIME"))

    print("Original DataFrame:")
    df.show(truncate=False)

    # None exercises the default range; the tuples are explicit (min, max) ranges
    ranges = [None, (0, 1), (1, 2), (2, float("inf"))]

    for range_ in ranges:
        filtered_df = filter_rework_cases(df, "Review", range_)
        # Label used in the heading printed above each result
        if range_:
            range_str = f"({range_[0]}, {range_[1]})"
        else:
            range_str = "Default (0, inf)"
        print(f"\nFiltered DataFrame (Review occurs between {range_str}):")
        filtered_df.show(truncate=False)

In [None]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, first, last, when, lit, sum, isnull
from pyspark.sql.window import Window
from typing import Union


def filter_attribute_cases(
    df: DataFrame,
    attribute_column: str,
    attribute_value: Union[str, int, float],
    once_occurred: bool = False,
    first_occurrence: bool = False,
    last_occurrence: bool = False,
    case_column: str = "_CASE_KEY",
    timestamp_column: str = "EVENTTIME",
) -> DataFrame:
    """
    Filter cases based on the occurrence of a specific attribute value in the process,
    ignoring null values when considering first or last occurrences.

    When several occurrence types are enabled, a case is kept if it satisfies
    ANY of them (the per-type selections are unioned).

    Args:
        df (DataFrame): Input DataFrame with process mining data.
        attribute_column (str): Name of the column containing the attribute to filter on.
        attribute_value (Union[str, int, float]): The value of the attribute to filter for.
        once_occurred (bool): If True, select cases where the attribute took the value at least once. Defaults to False.
        first_occurrence (bool): If True, select cases whose first non-null attribute value
            (by timestamp) equals the value. Defaults to False.
        last_occurrence (bool): If True, select cases whose last non-null attribute value
            (by timestamp) equals the value. Defaults to False.
        case_column (str): Name of the case column. Defaults to "_CASE_KEY".
        timestamp_column (str): Name of the timestamp column. Defaults to "EVENTTIME".

    Returns:
        DataFrame: Filtered DataFrame containing only the cases that meet the specified attribute criteria.

    Raises:
        ValueError: If no occurrence type is specified or if the attribute column doesn't exist.

    Example:
        >>> filtered_df = filter_attribute_cases(
        ...     df,
        ...     attribute_column="PRIORITY",
        ...     attribute_value="High",
        ...     first_occurrence=True
        ... )
    """
    # Validate input
    if not any([once_occurred, first_occurrence, last_occurrence]):
        raise ValueError(
            "At least one occurrence type (once_occurred, first_occurrence, or last_occurrence) must be True."
        )

    if attribute_column not in df.columns:
        raise ValueError(
            f"Column '{attribute_column}' does not exist in the DataFrame."
        )

    # Window over the WHOLE case, ordered by timestamp. The explicit frame matters:
    # an ordered window defaults to "start of partition up to the current row", under
    # which last() would merely return the current row's own value.
    whole_case_window = (
        Window.partitionBy(case_column)
        .orderBy(timestamp_column)
        .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
    )

    # 1 = attribute equals the value, 0 = some other non-null value,
    # NULL = attribute is null (no otherwise() branch), so that
    # first()/last() with ignorenulls=True really skip the null rows.
    attribute = col(attribute_column)
    df_with_flag = df.withColumn(
        "attribute_flag",
        when(attribute == attribute_value, 1).when(attribute.isNotNull(), 0),
    )

    # Distinct keys of the cases for which the given indicator expression equals 1.
    # Window expressions cannot be used inside filter() directly, hence the column.
    def cases_where(indicator):
        return (
            df_with_flag.withColumn("_indicator", indicator)
            .filter(col("_indicator") == 1)
            .select(case_column)
            .distinct()
        )

    selections = []

    if once_occurred:
        # Any row carrying the value qualifies the case
        selections.append(cases_where(col("attribute_flag")))

    if first_occurrence:
        # First non-null attribute value of the case equals the value
        selections.append(
            cases_where(
                first("attribute_flag", ignorenulls=True).over(whole_case_window)
            )
        )

    if last_occurrence:
        # Last non-null attribute value of the case equals the value
        selections.append(
            cases_where(
                last("attribute_flag", ignorenulls=True).over(whole_case_window)
            )
        )

    # Union the per-type selections and drop duplicate case keys
    filtered_cases = selections[0]
    for extra_selection in selections[1:]:
        filtered_cases = filtered_cases.union(extra_selection)
    filtered_cases = filtered_cases.distinct()

    # Join the filtered cases back to the original DataFrame
    return df.join(filtered_cases, case_column, "inner")


# Demo: run filter_attribute_cases for each occurrence type
if __name__ == "__main__":
    from pyspark.sql import SparkSession
    import pyspark.sql.functions as F

    spark = SparkSession.builder.getOrCreate()

    # Toy event log with a PRIORITY attribute that is null on some events
    data = [
        ("case1", "Start", "2023-01-01 10:00:00", None),
        ("case1", "Middle", "2023-01-02 11:00:00", "High"),
        ("case1", "End", "2023-01-03 12:00:00", "Low"),
        ("case2", "Start", "2023-01-01 09:00:00", "Medium"),
        ("case2", "Middle", "2023-01-01 10:00:00", "High"),
        ("case2", "End", "2023-01-01 17:00:00", "High"),
        ("case3", "Start", "2023-01-01 08:00:00", None),
        ("case3", "Middle", "2023-01-02 08:00:00", None),
        ("case3", "End", "2023-01-03 08:00:00", "High"),
        ("case4", "Start", "2023-01-01 08:00:00", None),
        ("case4", "Middle", "2023-01-02 08:00:00", "Medium"),
        ("case4", "End", "2023-01-03 08:00:00", None),
    ]
    # Build the DataFrame and parse the string column into a real timestamp
    df = spark.createDataFrame(
        data, ["_CASE_KEY", "ACTIVITY", "EVENTTIME", "PRIORITY"]
    ).withColumn("EVENTTIME", F.to_timestamp("EVENTTIME"))

    print("Original DataFrame:")
    df.show(truncate=False)

    # Keyword-argument sets: each flag on its own, then all three together
    occurrence_types = [
        {"once_occurred": True},
        {"first_occurrence": True},
        {"last_occurrence": True},
        {"once_occurred": True, "first_occurrence": True, "last_occurrence": True},
    ]

    # Filter for PRIORITY == "High" under every flag combination and print it
    for occurrence in occurrence_types:
        filtered_df = filter_attribute_cases(
            df, attribute_column="PRIORITY", attribute_value="High", **occurrence
        )
        print(f"\nFiltered DataFrame (PRIORITY = 'High', {occurrence}):")
        filtered_df.show(truncate=False)

In [None]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, first, last, when
from pyspark.sql.window import Window
from typing import Union, List


def standardize_attribute_values(
    df: DataFrame,
    attribute_columns: Union[str, List[str]],
    use_first_occurrence: bool = True,
    case_column: str = "_CASE_KEY",
    timestamp_column: str = "EVENTTIME",
) -> DataFrame:
    """
    Standardize attribute values for each case based on the first or last non-null occurrence.
    Can handle either a single attribute or multiple attributes.

    Every row of a case receives the same value: the case's first (or last) non-null
    value of the attribute in timestamp order. If a case has no non-null value for an
    attribute, that attribute stays null for the case.

    Args:
        df (DataFrame): Input DataFrame with process mining data.
        attribute_columns (Union[str, List[str]]): Name of the column(s) containing the attribute(s) to standardize.
        use_first_occurrence (bool): If True, use the first non-null occurrence; if False, use the last. Defaults to True.
        case_column (str): Name of the case column. Defaults to "_CASE_KEY".
        timestamp_column (str): Name of the timestamp column. Defaults to "EVENTTIME".

    Returns:
        DataFrame: DataFrame with standardized attribute values for each case.
        The input is returned unchanged when no attribute column is given.

    Raises:
        ValueError: If any of the attribute columns don't exist in the DataFrame.

    Example:
        >>> # For a single attribute
        >>> standardized_df = standardize_attribute_values(
        ...     df,
        ...     attribute_columns="PRIORITY",
        ...     use_first_occurrence=True
        ... )
        >>> # For multiple attributes
        >>> standardized_df = standardize_attribute_values(
        ...     df,
        ...     attribute_columns=["PRIORITY", "CATEGORY", "DEPARTMENT"],
        ...     use_first_occurrence=False
        ... )
    """
    # Convert single attribute to list for uniform processing
    if isinstance(attribute_columns, str):
        attribute_columns = [attribute_columns]

    # Nothing to standardize
    if not attribute_columns:
        return df

    # Validate input
    missing_columns = [name for name in attribute_columns if name not in df.columns]
    if missing_columns:
        raise ValueError(
            f"The following columns do not exist in the DataFrame: {', '.join(missing_columns)}"
        )

    # Window over the WHOLE case, ordered by timestamp. The explicit frame matters:
    # an ordered window defaults to "start of partition up to the current row". With
    # that frame last() would only forward-fill, and first() would leave the rows
    # before the first non-null value untouched.
    whole_case_window = (
        Window.partitionBy(case_column)
        .orderBy(timestamp_column)
        .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
    )

    # first() or last() depending on which occurrence defines the standard value
    pick_value = first if use_first_occurrence else last

    # Overwrite each attribute in place; withColumn keeps the column position
    for attribute_column in attribute_columns:
        df = df.withColumn(
            attribute_column,
            pick_value(col(attribute_column), ignorenulls=True).over(
                whole_case_window
            ),
        )

    return df


# Demo: run standardize_attribute_values for one and for several attributes
if __name__ == "__main__":
    from pyspark.sql import SparkSession
    import pyspark.sql.functions as F

    spark = SparkSession.builder.getOrCreate()

    # Toy event log with three attributes (PRIORITY, DEPARTMENT, CATEGORY),
    # each of which is null on some events
    data = [
        ("case1", "Start", "2023-01-01 10:00:00", None, "Dept1", None),
        ("case1", "Middle", "2023-01-02 11:00:00", "High", "Dept2", "Cat1"),
        ("case1", "End", "2023-01-03 12:00:00", "Low", "Dept1", "Cat2"),
        ("case2", "Start", "2023-01-01 09:00:00", "Medium", "Dept3", "Cat1"),
        ("case2", "Middle", "2023-01-01 10:00:00", "High", "Dept3", "Cat2"),
        ("case2", "End", "2023-01-01 17:00:00", "High", "Dept2", "Cat2"),
        ("case3", "Start", "2023-01-01 08:00:00", None, None, "Cat3"),
        ("case3", "Middle", "2023-01-02 08:00:00", None, "Dept1", None),
        ("case3", "End", "2023-01-03 08:00:00", "High", "Dept2", "Cat1"),
    ]
    # Build the DataFrame and parse the string column into a real timestamp
    df = spark.createDataFrame(
        data,
        ["_CASE_KEY", "ACTIVITY", "EVENTTIME", "PRIORITY", "DEPARTMENT", "CATEGORY"],
    ).withColumn("EVENTTIME", F.to_timestamp("EVENTTIME"))

    print("Original DataFrame:")
    df.show(truncate=False)

    # One attribute, passed as a plain string, first-occurrence mode
    df_single = standardize_attribute_values(df, "PRIORITY", use_first_occurrence=True)
    print("\nStandardized DataFrame (Single Attribute, First Occurrence):")
    df_single.show(truncate=False)

    # Several attributes, passed as a list, last-occurrence mode
    df_multiple = standardize_attribute_values(
        df, ["PRIORITY", "DEPARTMENT", "CATEGORY"], use_first_occurrence=False
    )
    print("\nStandardized DataFrame (Multiple Attributes, Last Occurrence):")
    df_multiple.show(truncate=False)

In [None]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import (
    col,
    min,
    max,
    avg,
    percentile_approx,
    collect_set,
    round as spark_round,
)
from pyspark.sql.types import DoubleType
from typing import Union, List, Optional, Literal


def calculate_case_duration_stats(
    df: DataFrame,
    case_column: str = "_CASE_KEY",
    timestamp_column: str = "EVENTTIME",
    attribute_columns: Optional[Union[str, List[str]]] = None,
    percentiles: List[float] = [0.25, 0.5, 0.75, 0.9],
    time_unit: Literal["seconds", "minutes", "hours", "days"] = "days",
) -> DataFrame:
    """
    Calculate case duration statistics, optionally aggregated by specified attributes, in the specified time unit.

    The duration of a case is the difference between its latest and earliest
    timestamp (both cast to whole epoch seconds), converted to ``time_unit``.
    The statistics (average, minimum, maximum and the requested approximate
    percentiles) are computed over these per-case durations and rounded to
    2 decimal places.

    When ``attribute_columns`` is given, every case is assigned the attribute
    values found on its earliest event (smallest timestamp), and the statistics
    are computed per distinct combination of those values. The attribute
    columns appear in the result as the grouping keys.

    Args:
        df (DataFrame): Input DataFrame with process mining data.
        case_column (str): Name of the case column. Defaults to "_CASE_KEY".
        timestamp_column (str): Name of the timestamp column. Defaults to "EVENTTIME".
        attribute_columns (Optional[Union[str, List[str]]]): Column(s) to aggregate by. Can be None, a single column name, or a list of column names.
        percentiles (List[float]): List of percentiles (each between 0 and 1) to calculate. Defaults to [0.25, 0.5, 0.75, 0.9].
        time_unit (Literal["seconds", "minutes", "hours", "days"]): The time unit for duration calculations. Defaults to "days".

    Returns:
        DataFrame: One row per attribute combination (or a single row when no
        attributes are given) with the columns ``avg_duration_<unit>``,
        ``min_duration_<unit>``, ``max_duration_<unit>`` and one
        ``p<NN>_duration_<unit>`` column per requested percentile.

    Raises:
        ValueError: If any specified attribute column doesn't exist in the DataFrame,
            if an invalid time unit is provided, or if a percentile lies outside [0, 1].

    Example:
        >>> stats_df = calculate_case_duration_stats(
        ...     df,
        ...     attribute_columns=["PRIORITY", "DEPARTMENT"],
        ...     percentiles=[0.5, 0.75, 0.9],
        ...     time_unit="hours"
        ... )
    """
    # Imported locally: `struct` is not part of this cell's import list.
    from pyspark.sql.functions import struct

    # Validate input
    group_columns: List[str] = []
    if attribute_columns:
        if isinstance(attribute_columns, str):
            group_columns = [attribute_columns]
        else:
            group_columns = list(attribute_columns)
        missing_columns = [name for name in group_columns if name not in df.columns]
        if missing_columns:
            raise ValueError(
                f"The following columns do not exist in the DataFrame: {', '.join(missing_columns)}"
            )

    # Validate and set time unit conversion factor
    time_unit_factors = {
        "seconds": 1,
        "minutes": 1 / 60,
        "hours": 1 / 3600,
        "days": 1 / 86400,
    }
    if time_unit not in time_unit_factors:
        raise ValueError(
            f"Invalid time unit. Must be one of: {', '.join(time_unit_factors.keys())}"
        )
    time_factor = time_unit_factors[time_unit]

    # Fail early with a readable message instead of an opaque Spark analysis error.
    percentiles = list(percentiles)
    out_of_range = [p for p in percentiles if not 0 <= p <= 1]
    if out_of_range:
        raise ValueError(f"Percentiles must be between 0 and 1, got: {out_of_range}")

    # NOTE: `min` / `max` below are the Spark aggregate functions imported by
    # this cell, not the Python builtins.
    event_time = col(timestamp_column)
    duration_col = f"duration_{time_unit}"
    per_case_aggs = [
        (
            (max(event_time).cast("long") - min(event_time).cast("long"))
            * time_factor
        ).alias(duration_col)
    ]
    if group_columns:
        # Structs compare field by field, so the minimum of (timestamp, attrs...)
        # is the earliest event of the case together with its attribute values.
        # Positional field names avoid clashes with the user's column names.
        first_event = struct(
            event_time.alias("_ts"),
            *[col(name).alias(f"_attr_{i}") for i, name in enumerate(group_columns)],
        )
        per_case_aggs.append(min(first_event).alias("_first_event"))

    # Calculate case durations (one row per case)
    case_durations = df.groupBy(case_column).agg(*per_case_aggs)
    if group_columns:
        # Unpack the case-level attribute values so they can be used as group keys.
        case_durations = case_durations.select(
            col(duration_col),
            *[
                col("_first_event").getField(f"_attr_{i}").alias(name)
                for i, name in enumerate(group_columns)
            ],
        )

    # Build the statistics, rounding every value to 2 decimal places.
    stat_columns = [
        spark_round(avg(duration_col), 2).alias(f"avg_duration_{time_unit}"),
        spark_round(min(duration_col), 2).alias(f"min_duration_{time_unit}"),
        spark_round(max(duration_col), 2).alias(f"max_duration_{time_unit}"),
    ]
    for p in percentiles:
        # The small epsilon absorbs float noise (0.29 * 100 == 28.999999999999996)
        # so that 0.29 is labelled "p29" rather than "p28".
        label = f"p{int(p * 100 + 1e-9)}_duration_{time_unit}"
        stat_columns.append(
            spark_round(percentile_approx(duration_col, p), 2).alias(label)
        )

    # groupBy() without columns is a global aggregation, so a single code path
    # covers both the grouped and the overall statistics.
    return case_durations.groupBy(*group_columns).agg(*stat_columns)


# Usage example
if __name__ == "__main__":
    import pyspark.sql.functions as F
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Demo event log: three cases with three events each
    data = [
        ("case1", "Start", "2023-01-01 10:00:00", "High", "Dept1"),
        ("case1", "Middle", "2023-01-02 11:00:00", "High", "Dept2"),
        ("case1", "End", "2023-01-03 12:00:00", "High", "Dept1"),
        ("case2", "Start", "2023-01-01 09:00:00", "Medium", "Dept3"),
        ("case2", "Middle", "2023-01-01 10:00:00", "High", "Dept3"),
        ("case2", "End", "2023-01-01 17:00:00", "High", "Dept2"),
        ("case3", "Start", "2023-01-01 08:00:00", "Low", "Dept1"),
        ("case3", "Middle", "2023-01-02 08:00:00", "Medium", "Dept1"),
        ("case3", "End", "2023-01-05 08:00:00", "High", "Dept2"),
    ]
    # Build the frame and parse the event time string into a timestamp in one go.
    df = spark.createDataFrame(
        data,
        schema=["_CASE_KEY", "ACTIVITY", "EVENTTIME", "PRIORITY", "DEPARTMENT"],
    ).withColumn("EVENTTIME", F.to_timestamp(F.col("EVENTTIME")))

    print("Original DataFrame:")
    df.show(truncate=False)

    # Overall statistics, once per supported time unit
    for unit in ("seconds", "minutes", "hours", "days"):
        stats = calculate_case_duration_stats(df=df, time_unit=unit)
        print(f"\nOverall Case Duration Statistics (in {unit}):")
        stats.show(truncate=False)

    # Statistics broken down by a single attribute (PRIORITY), in hours
    priority_stats = calculate_case_duration_stats(
        df=df,
        attribute_columns="PRIORITY",
        time_unit="hours",
    )
    print("\nCase Duration Statistics by PRIORITY (in hours):")
    priority_stats.show(truncate=False)

    # Statistics broken down by two attributes, in days
    multi_attr_stats = calculate_case_duration_stats(
        df=df,
        attribute_columns=["PRIORITY", "DEPARTMENT"],
        time_unit="days",
    )
    print("\nCase Duration Statistics by PRIORITY and DEPARTMENT (in days):")
    multi_attr_stats.show(truncate=False)

In [None]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import (
    col,
    concat_ws,
    count,
    desc,
    percent_rank,
    collect_list,
    struct,
    lit,
)
from pyspark.sql.window import Window
from typing import List, Optional


def process_variant_analysis(
    df: DataFrame,
    case_column: str = "_CASE_KEY",
    activity_column: str = "ACTIVITY",
    timestamp_column: str = "EVENTTIME",
    attribute_columns: Optional[List[str]] = None,
    top_n_variants: int = 10,
) -> DataFrame:
    """
    Perform process variant analysis, including unique variants count, frequency, and "happy path" identification.

    A variant is the sequence of a case's activities ordered by timestamp and
    joined with "->". When ``attribute_columns`` is given, every step is
    rendered as "activity:attr1:attr2...", so cases that differ only in those
    attributes count as different variants. The "happy path" is the most
    frequent variant.

    Args:
        df (DataFrame): Input DataFrame with process mining data.
        case_column (str): Name of the case column. Defaults to "_CASE_KEY".
        activity_column (str): Name of the activity column. Defaults to "ACTIVITY".
        timestamp_column (str): Name of the timestamp column. Defaults to "EVENTTIME".
        attribute_columns (Optional[List[str]]): List of attribute columns to distinguish variants
            (a single column name is also accepted). Defaults to None.
        top_n_variants (int): Number of top variants to include in the detailed output. Defaults to 10.

    Returns:
        DataFrame: DataFrame with the columns ``category``, ``key`` and ``value``,
        where ``value`` is a struct with the fields ``frequency`` (long) and
        ``percentage`` (double):

        * ``category == "Summary"``: ``key`` is the metric name. The count
          metrics ("Total Variants", "Total Cases", "Happy Path Frequency")
          are stored in ``value.frequency``; "Happy Path Percentage" is
          stored in ``value.percentage``. The other field is null.
        * ``category == "Top Variants"``: ``key`` is the variant string and
          ``value`` holds its case count and its share of all cases in percent.

    Raises:
        ValueError: If any specified attribute column doesn't exist in the DataFrame.

    Example:
        >>> variant_analysis = process_variant_analysis(
        ...     df,
        ...     attribute_columns=["PRIORITY", "DEPARTMENT"],
        ...     top_n_variants=5
        ... )
    """
    # Imported locally so the function obtains its own session instead of
    # relying on a notebook-level `spark` global.
    from pyspark.sql import SparkSession

    # Validate input
    if attribute_columns:
        if isinstance(attribute_columns, str):
            attribute_columns = [attribute_columns]
        missing_columns = [name for name in attribute_columns if name not in df.columns]
        if missing_columns:
            raise ValueError(
                f"The following columns do not exist in the DataFrame: {', '.join(missing_columns)}"
            )

    # Window over the whole case, ordered by time. The explicit unbounded frame
    # matters: with only orderBy() the default frame ends at the current row,
    # so each event would see just a prefix of its case's trace.
    case_window = (
        Window.partitionBy(case_column)
        .orderBy(timestamp_column)
        .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
    )

    # Label of a single step: the activity, optionally qualified by attributes.
    if attribute_columns:
        step_label = concat_ws(
            ":", *[col(c) for c in [activity_column] + list(attribute_columns)]
        )
    else:
        step_label = col(activity_column)

    # Every event of a case now carries the same full trace, so distinct()
    # leaves exactly one (case, variant) row per case.
    case_variants = df.select(
        col(case_column),
        concat_ws("->", collect_list(step_label).over(case_window)).alias("variant"),
    ).distinct()

    # Calculate variant frequencies
    variant_freq = case_variants.groupBy("variant").agg(count("*").alias("frequency"))

    total_cases = case_variants.count()
    total_variants = variant_freq.count()

    # Calculate percentage for each variant
    variant_freq = variant_freq.withColumn(
        "percentage", (col("frequency") / lit(total_cases)) * 100
    )

    # Most frequent first; the variant string breaks ties deterministically.
    ranked_variants = variant_freq.orderBy(desc("frequency"), "variant")

    # Identify the "happy path"; an empty input has none.
    top_row = ranked_variants.first()
    happy_path_frequency = top_row["frequency"] if top_row is not None else 0
    happy_path_percentage = (
        (happy_path_frequency / total_cases) * 100 if total_cases else 0.0
    )

    # Prepare the summary DataFrame. The schema is explicit because counts and
    # the percentage have different types and must line up with the struct
    # used for the variant rows, otherwise the union below is rejected.
    spark_session = SparkSession.builder.getOrCreate()
    summary = spark_session.createDataFrame(
        [
            ("Total Variants", total_variants, None),
            ("Total Cases", total_cases, None),
            ("Happy Path Frequency", happy_path_frequency, None),
            ("Happy Path Percentage", None, happy_path_percentage),
        ],
        "Metric string, frequency long, percentage double",
    )

    # Prepare the top N variants DataFrame
    top_variants = ranked_variants.limit(top_n_variants)

    # Combine summary and top variants into a single DataFrame
    summary_rows = summary.select(
        lit("Summary").alias("category"),
        col("Metric").alias("key"),
        struct(col("frequency"), col("percentage")).alias("value"),
    )
    variant_rows = top_variants.select(
        lit("Top Variants").alias("category"),
        col("variant").alias("key"),
        struct(col("frequency"), col("percentage")).alias("value"),
    )

    return summary_rows.union(variant_rows)


# Usage example
if __name__ == "__main__":
    import pyspark.sql.functions as F
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Demo event log: four cases; case1 and case4 follow the same trace
    data = [
        ("case1", "Start", "2023-01-01 10:00:00", "High", "Dept1"),
        ("case1", "Middle", "2023-01-02 11:00:00", "High", "Dept2"),
        ("case1", "End", "2023-01-03 12:00:00", "High", "Dept1"),
        ("case2", "Start", "2023-01-01 09:00:00", "Medium", "Dept3"),
        ("case2", "Middle", "2023-01-01 10:00:00", "High", "Dept3"),
        ("case2", "End", "2023-01-01 17:00:00", "High", "Dept2"),
        ("case3", "Start", "2023-01-01 08:00:00", "Low", "Dept1"),
        ("case3", "Middle", "2023-01-02 08:00:00", "Medium", "Dept1"),
        ("case3", "End", "2023-01-05 08:00:00", "High", "Dept2"),
        ("case4", "Start", "2023-01-01 10:00:00", "High", "Dept1"),
        ("case4", "Middle", "2023-01-02 11:00:00", "High", "Dept2"),
        ("case4", "End", "2023-01-03 12:00:00", "High", "Dept1"),
    ]
    # Build the frame and parse the event time string into a timestamp in one go.
    df = spark.createDataFrame(
        data,
        schema=["_CASE_KEY", "ACTIVITY", "EVENTTIME", "PRIORITY", "DEPARTMENT"],
    ).withColumn("EVENTTIME", F.to_timestamp(F.col("EVENTTIME")))

    print("Original DataFrame:")
    df.show(truncate=False)

    # Variants defined by the activity sequence alone
    variant_analysis = process_variant_analysis(df=df)
    print("\nProcess Variant Analysis (without attributes):")
    variant_analysis.show(truncate=False)

    # Variants that additionally distinguish PRIORITY and DEPARTMENT per step
    variant_analysis_with_attrs = process_variant_analysis(
        df=df,
        attribute_columns=["PRIORITY", "DEPARTMENT"],
    )
    print("\nProcess Variant Analysis (with PRIORITY and DEPARTMENT attributes):")
    variant_analysis_with_attrs.show(truncate=False)