In [0]:
%run ./config

In [0]:
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import pyspark.sql.types as t
from py4j.protocol import Py4JJavaError
from pyspark.sql import DataFrame, Window
from pyspark.sql import functions as f

# f_oneway() function takes the group data as input and returns F-statistic and P-value
from scipy.stats import f_oneway

In [0]:
# Create a table from the DataFrame.
# This table will persist across cluster restarts as well as allow different notebooks to query this data.
def save_table(*, df: DataFrame, file_path: str) -> None:
    """Write ``df`` to ``file_path`` after clearing the matching S3 objects.

    The writer format is taken from the extension of ``file_path``; the stem of
    the path is used as the table name in the printed status messages.

    :param df: A pyspark.sql.dataframe.DataFrame object to persist.
    :param file_path: Destination path; its suffix selects the writer format.
    :return: None - No returned value.
    """
    destination = Path(file_path)
    table_name = destination.stem
    file_type = destination.suffix.split(".")[-1]

    # Remove table if it exists.
    # ACCESS_KEY, SECRET_KEY and S3_BUCKET are not defined in this cell;
    # presumably they come from the `%run ./config` cell - confirm.
    import boto3

    bucket = boto3.resource(
        "s3", aws_access_key_id=ACCESS_KEY, aws_secret_access_key=SECRET_KEY
    ).Bucket(S3_BUCKET)
    # Everything after "/<bucket>" in the path, without surrounding slashes.
    key_prefix = file_path.split(f"/{S3_BUCKET}")[-1].strip("/")
    bucket.objects.filter(Prefix=key_prefix).delete()

    # Save table. A JVM-side failure is printed instead of being raised.
    try:
        writer = df.write.format(file_type)
        writer = writer.option("mergeSchema", "true")
        writer = writer.mode("overwrite")
        writer.save(file_path)
        print(f"The table: {table_name} was saved.   File path: {file_path}")
    except Py4JJavaError as err:
        print(f"Py4JJavaError. Could not save the table: {table_name}\n{err}")

In [0]:
def count_by_col(*, df: DataFrame, col_name: str) -> DataFrame:
    """Build a frequency table for ``col_name``.

    The result holds one row per distinct value with its ``count``, its share
    of all rows in percent (rounded to 2 decimals) and the cumulative sum of
    those rounded shares, accumulated in descending order of share (also
    rounded to 2 decimals). Rows are sorted by ``count``, highest first.

    :param df: A pyspark.sql.dataframe.DataFrame object to run the function on
    :param col_name: The column whose values are counted.
    :return: pyspark.sql.dataframe.DataFrame object of the column and its calculations.
    """
    pct = "Percentage (%)"
    cumulative = "Comulative pct (%)"
    n_rows = df.count()

    # Un-partitioned window over the whole result, largest share first.
    by_share_desc = Window.orderBy(f.col(pct).desc())

    frequencies = df.groupBy(col_name).count()
    frequencies = frequencies.withColumn(
        pct, f.round(f.col("count") / n_rows * 100, 2)
    )
    frequencies = frequencies.withColumn(
        cumulative, f.round(f.sum(pct).over(by_share_desc), 2)
    )
    return frequencies.orderBy("count", ascending=False)

In [0]:
# Performing the ANOVA test
def run_anova_test(
    *, df: DataFrame, col_name: str, target_col: str = "ConvertedCompYearly"
) -> None:
    """Run a one-way ANOVA between a feature column and the target and print the verdict.

    The target values are grouped by the categories of ``col_name`` and passed
    to ``scipy.stats.f_oneway``. The p-value and the resulting decision are
    printed (with ANSI colour codes); nothing is returned.

    :param df: A pyspark.sql.dataframe.DataFrame
    :param col_name: The relevant (categorical) column to run the function on.
    :param target_col: The target column.
    :return: None - No returned value.
    """
    # Only the two columns the test needs are collected to the driver.
    samples = df.select(col_name, target_col).toPandas()
    # A missing target value would turn the statistic (and the p-value) into
    # NaN, which silently falls into the "reject" branch below.
    samples = samples.dropna(subset=[target_col])
    # Finds out the target data for each category as a list
    categories_vals = samples.groupby(col_name)[target_col].apply(list)
    # We accept the Assumption(H0) only when P-Value > 0.05
    anova_results = f_oneway(*categories_vals)
    p_value = anova_results[1]
    print("\x1b[0;31;47m" + f"P-Value for Anova is: {p_value:.4f}")
    if p_value > 0.05:
        print(
            "\x1b[0;31;47m"
            + f"**** We accept the Assumption(H0).  {col_name} and {target_col} are NOT correlated ****\n"
            + "=" * 100
        )
    else:
        print(
            "\x1b[0;32;47m"
            + f"**** We reject the Assumption(H0).  {col_name} and {target_col} are correlated ****\n"
            + "=" * 100
        )
    return None

In [0]:
def run_multiple_anova_tests(
    *, df: DataFrame, cols: Optional[List[str]] = None
) -> None:
    """Run the one-way ANOVA test against the target for each column in ``cols``.

    Every column is passed to ``run_anova_test`` (with its default target
    column), which prints the result; nothing is returned.

    :param df: A pyspark.sql.dataframe.DataFrame object.
    :param cols: A list of columns names. Defaults to
        ``["Country", "Currency", "Ethnicity"]`` when omitted.
    :return: None - No returned value.
    """
    # Build the default per call rather than sharing one mutable list object
    # across all calls.
    if cols is None:
        cols = ["Country", "Currency", "Ethnicity"]

    for col_name in cols:
        run_anova_test(df=df, col_name=col_name)

    return None

In [0]:
def floor_and_cap_outliers(
    *, df: DataFrame, quant: float = 0.1, target_col: str = "ConvertedCompYearly"
) -> DataFrame:
    """Quantile based flooring and capping of the outliers in ``target_col``.

    Values below the ``quant`` quantile are raised to it and values above the
    ``1 - quant`` quantile are lowered to it. The column is cast to double.

    :param df: A pyspark.sql.dataframe.DataFrame
    :param quant: The floor quantile; the cap quantile is ``1.0 - quant``.
    :param target_col: The target column.
    :return: A pyspark.sql.dataframe.DataFrame with ``target_col`` clipped.
    """
    value = f.col(target_col)

    # A single approxQuantile call (relative error 0, i.e. exact) returns both
    # bounds, instead of scanning the data once per bound.
    bounds = df.approxQuantile(target_col, [quant, 1.0 - quant], 0)
    if len(bounds) < 2:
        # No quantiles available (e.g. no non-null values): there is nothing
        # to clip, so only apply the same cast the normal path performs.
        return df.withColumn(target_col, value.cast(t.DoubleType()))
    floor_quantile, cap_quantile = bounds

    # Floor first, then cap the floored value - the same order as applying the
    # two steps one after the other.
    floored = (
        f.when(value < floor_quantile, floor_quantile)
        .otherwise(value)
        .cast(t.DoubleType())
    )
    capped = (
        f.when(floored > cap_quantile, cap_quantile)
        .otherwise(floored)
        .cast(t.DoubleType())
    )
    return df.withColumn(target_col, capped)

In [0]:
def extract_cols_names_by_regex(
    *, df: DataFrame, pattern: str = "(^Total|_rank$)"
) -> List[str]:
    """Return the names of the columns of ``df`` that match ``pattern``.

    Names are matched with ``re.search``. Each name appears once in the result
    and the result follows the column order of ``df``, so repeated calls on
    the same DataFrame always yield the same list.

    :param df: A pyspark.sql.dataframe.DataFrame object to run the function on
    :param pattern: A regex to extract by.
    :return: A list of the matching column names, without duplicates.
    """
    regex = re.compile(pattern)
    matching = (col_name for col_name in df.columns if regex.search(col_name))
    # dict.fromkeys drops duplicates but, unlike set(), keeps first-seen order.
    return list(dict.fromkeys(matching))

In [0]:
def sum_cols_group_by_prefix(*, df: DataFrame, col_prefix: str) -> DataFrame:
    """Add a ``Total_<col_prefix>`` column summing all columns that share the prefix.

    :param df: A pyspark.sql.dataframe.DataFrame object to run the function on
    :param col_prefix: The prefix to search for
    :return: A pyspark.sql.dataframe.DataFrame with an additional Total column
        holding the row-wise sum of the columns whose name starts with ``col_prefix``.
    """
    # Start from 0 and add each matching column in DataFrame column order.
    running_total = 0
    for name in df.columns:
        if name.startswith(col_prefix):
            running_total = running_total + df[name]
    return df.withColumn(f"Total_{col_prefix}", running_total)

In [0]:
def replace_na_cols_to_null(*, df: DataFrame) -> DataFrame:
    """Replace the literal string "NA" with null in every string column of ``df``.

    Non-string columns are returned untouched: comparing them with "NA" would
    force a cast of "NA" to the column type, which yields null (or an error
    under ANSI mode) and would wipe every value of the column.

    :param df: A pyspark.sql.dataframe.DataFrame
    :return df: A pyspark.sql.dataframe.DataFrame object with the same columns,
        in the same order.
    """
    projected = []
    for col_name, dtype in df.dtypes:
        column = f.col(col_name)
        if dtype == "string":
            # A null input makes the condition null, so it stays null as well.
            column = (
                f.when(column != "NA", column).otherwise(f.lit(None)).alias(col_name)
            )
        projected.append(column)
    # One projection instead of one withColumn call per column keeps the
    # query plan small on wide DataFrames.
    return df.select(projected)

In [0]:
def get_cols_by_dtypes(*, df: DataFrame) -> Dict[str, List[str]]:
    """Split the columns of ``df`` into categorical, binary and numeric groups.

    - ``cat_cols``: columns whose dtype is ``string``.
    - ``num_cols``: columns whose name matches the default pattern of
      ``extract_cols_names_by_regex`` (the dtype is not checked).
    - ``bin_cols``: ``double`` / ``int`` columns that are not in ``num_cols``;
      the remaining numeric columns are presumably the split 0/1 indicator
      columns - confirm against the upstream transformation.

    :param df: A pyspark.sql.dataframe.DataFrame
    :return: A dictionary with the group name as the key, and a list of
        relevant columns names as the value.
    """
    str_cols = []
    numeric_cols = []
    for name, dtype in df.dtypes:
        if dtype == "string":
            str_cols.append(name)
        elif dtype in ("double", "int"):
            numeric_cols.append(name)

    num_cols = extract_cols_names_by_regex(df=df)
    # Whatever numeric column is not matched by the regex is treated as binary.
    bin_cols = []
    for name in numeric_cols:
        if name not in num_cols:
            bin_cols.append(name)

    result = {}
    result["cat_cols"] = str_cols
    result["bin_cols"] = bin_cols
    result["num_cols"] = num_cols
    return result

In [0]:
# NOTE: disabled helper; re-enabling it requires `from pyspark.sql import Column`.
# def isinf(*, f_col: Column) -> Column:
#   """Check if a column contains infinity values.
#   :param f_col: The pyspark.sql.column.Column to check for infinity values.
#   :return: A boolean pyspark.sql.column.Column indicating whether each value in `f_col` is infinity.
#   """
#   return (f_col == float("inf")) | (f_col == float("-inf"))

In [0]:
def get_cols_with_null_values(*, df: DataFrame) -> List[str]:
    """Return the names of the columns of ``df`` that contain at least one null.

    Only nulls are considered; NaN values are not counted. The result follows
    the column order of ``df``.

    :param df: A pyspark.sql.dataframe.DataFrame
    :return: A list of the names of the columns holding null values.
    """
    # Filtering on `isNull | isnan` and then again on `isNull` keeps exactly
    # the null rows, so the per-column null counts give the same answer
    # without evaluating isnan on columns of every dtype.
    return [
        column for column, n_nulls in count_nulls(df=df).items() if n_nulls > 0
    ]

In [0]:
def count_nulls(*, df: DataFrame) -> Dict[str, int]:
    """
    Counts the number of null values in each column of a Spark DataFrame.
    :param df: Spark DataFrame to count null values for
    :return: Dictionary of column names and corresponding null value counts
    """
    columns = df.columns
    if not columns:
        return {}
    # count() ignores nulls, so counting a marker that is only produced for
    # null cells gives the number of nulls. All columns are aggregated in one
    # pass instead of launching one filter + count job per column.
    null_markers = [
        f.count(f.when(f.col(column).isNull(), 1)).alias(column)
        for column in columns
    ]
    totals = df.select(null_markers).first()
    # Pair by position so the result does not depend on the Row field names.
    return dict(zip(columns, totals))