In [93]:
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import TimestampType

from datetime import datetime

In [94]:
# Create the notebook's Spark session, or reuse one that is already running.
spark = (
    SparkSession.builder
    # Host name the driver advertises/binds to.
    .config(key="spark.driver.host", value="localhost")
    # Amount of memory requested for the driver process.
    .config(key="spark.driver.memory", value="8g")
    .getOrCreate()
)

In [95]:
def partitionBy_tx_n_days(
    df: DataFrame,
    col_name: str,
    partition_by: str = "CUSTOMER_ID",
    n_days: int = 1,
    func=F.count,
    shift: int = 0,
    analysis_field: str = "TX_AMOUNT",
    datetime_col: str = "TX_DATETIME",
) -> DataFrame:
    """Add a rolling, time-based window aggregate as a new column.

    For every row, the aggregate is evaluated over the rows that

    1. share the same value of ``partition_by``, and
    2. have a ``datetime_col`` value inside the interval
       ``[t - (n_days + shift) days, t - shift days]``, where ``t`` is the
       row's own ``datetime_col`` value.

    Both interval bounds are inclusive. With ``shift=0`` the current row (and
    any other row of the partition with the same timestamp, to the second) is
    therefore part of its own window.

    Parameters
    ----------
    df : spark.DataFrame
        Input DataFrame.
    col_name : str
        Name of the column in which the result is stored.
    partition_by : str, optional
        Column to partition the window by, by default "CUSTOMER_ID".
    n_days : int, optional
        Length of the window in days, by default 1.
    func : callable, optional
        Aggregate from ``pyspark.sql.functions`` that accepts a column name
        and returns a Column (e.g. ``F.count``, ``F.sum``, ``F.mean``),
        by default ``F.count``.
    shift : int, optional
        Number of days by which the whole window is moved into the past,
        by default 0.
    analysis_field : str, optional
        Column the aggregate is applied to, by default "TX_AMOUNT".
    datetime_col : str, optional
        Column holding the event time used to order rows and to measure the
        window, by default "TX_DATETIME".

    Returns
    -------
    spark.DataFrame
        ``df`` plus the new column ``col_name``.
    """
    seconds_per_day = 86400

    # A range frame needs a numeric ordering key, so the event time is
    # converted to epoch seconds and the frame bounds are given in seconds.
    event_time_seconds = F.col(datetime_col).cast("timestamp").cast("long")

    window_spec = (
        Window.partitionBy(partition_by)
        .orderBy(event_time_seconds)
        .rangeBetween(
            -(n_days + shift) * seconds_per_day,
            -shift * seconds_per_day,
        )
    )

    return df.withColumn(col_name, func(analysis_field).over(window_spec))

In [96]:
def fraud_risk(
    df: DataFrame, col_name: str, delay: int = 7, n_days: int = 1
) -> DataFrame:
    """Compute the per-terminal fraud risk over a delayed time window.

    For each transaction the risk is the sum of ``TX_FRAUD`` divided by the
    count of ``TX_FRAUD`` over the transactions of the same ``TERMINAL_ID``
    that fall in the window ``[t - (n_days + delay) days, t - delay days]``,
    where ``t`` is the transaction's ``TX_DATETIME``. Rows whose window
    contains no transaction get a risk of 0.

    The feature is computed directly on ``df``; no join back onto the input
    is performed, so the result does not depend on ``TRANSACTION_ID`` being
    present, unique or non-null, and no extra shuffle is introduced.

    Parameters
    ----------
    df : spark.DataFrame
        Input DataFrame; must contain ``TERMINAL_ID``, ``TX_DATETIME`` and
        ``TX_FRAUD``.
    col_name : str
        Name of the column in which the risk is stored.
    delay : int, optional
        Number of days the window is shifted into the past (time needed to
        detect a fraudulent transaction based on business logic), by default 7.
    n_days : int, optional
        Length of the window in days, by default 1.

    Returns
    -------
    spark.DataFrame
        ``df`` plus the new column ``col_name``.
    """
    # Temporary column names are derived from col_name so that they do not
    # overwrite (and later drop) a column the caller already has.
    fraud_col = f"__{col_name}_fraud_in_window"
    tx_col = f"__{col_name}_tx_in_window"

    with_counts = partitionBy_tx_n_days(
        df, fraud_col, "TERMINAL_ID", n_days, F.sum, delay, "TX_FRAUD"
    )
    with_counts = partitionBy_tx_n_days(
        with_counts, tx_col, "TERMINAL_ID", n_days, F.count, delay, "TX_FRAUD"
    )

    return (
        with_counts.withColumn(col_name, F.col(fraud_col) / F.col(tx_col))
        # An empty window gives a null sum and hence a null ratio; report 0.
        .fillna(0, subset=[col_name])
        .drop(fraud_col, tx_col)
    )

In [97]:
def get_train_test_set(
    df, start_date_training, delta_train=7, delta_delay=1, delta_test=7
):
    """Split transactions into a training set and a delayed test set.

    The training set holds the rows whose ``TX_DATETIME`` is at or after
    ``start_date_training`` and before the date ``delta_train`` days later.
    The test set is built one day at a time from ``TX_TIME_DAYS``: test day
    ``i`` (0-based) is ``first_training_day + delta_train + delta_delay + i``.
    Transactions of customers already known to be defrauded are removed from
    each test day. That pool starts with the customers that have a fraud in
    the training set and, for test day ``i``, is extended with the customers
    that have a fraud on day ``first_training_day + delta_train + i - 1``
    (this offset does not depend on ``delta_delay``).

    Parameters
    ----------
    df : spark.DataFrame
        Transactions; must contain ``TX_DATETIME``, ``TX_TIME_DAYS``,
        ``TX_FRAUD``, ``CUSTOMER_ID`` and ``TRANSACTION_ID``.
    start_date_training
        Start of the training period; any value accepted by
        ``F.to_timestamp(F.lit(...))``.
    delta_train : int, optional
        Length of the training period in days, by default 7.
    delta_delay : int, optional
        Number of days between the end of the training period and the first
        test day, by default 1.
    delta_test : int, optional
        Number of test days, by default 7. If it is 0 or negative the test
        set is an empty DataFrame with the schema of ``df``.

    Returns
    -------
    tuple of (spark.DataFrame, spark.DataFrame)
        ``(train_df, test_df)``, both ordered by ``TRANSACTION_ID``.

    Raises
    ------
    ValueError
        If no transaction falls inside the training period, because the test
        days cannot be located in that case.
    """
    # Convert start_date_training to a timestamp
    start_ts = F.to_timestamp(F.lit(start_date_training))

    # Get the training set data.
    # NOTE: F.date_add returns a DATE, so the exclusive upper bound is
    # midnight of the calendar day delta_train days after the start date.
    train_df = df.filter(
        (F.col("TX_DATETIME") >= start_ts)
        & (F.col("TX_DATETIME") < F.date_add(start_ts, delta_train))
    )

    is_fraud = F.col("TX_FRAUD") == 1

    # Get known defrauded customers from the training set
    known_defrauded_customers = (
        train_df.filter(is_fraud).select("CUSTOMER_ID").distinct()
    )

    # Get the relative starting day of training set (this triggers a Spark job)
    start_tx_time_days_training = train_df.select(F.min("TX_TIME_DAYS")).first()[0]
    if start_tx_time_days_training is None:
        raise ValueError(
            "No transactions found in the training period starting at "
            f"{start_date_training!r} (delta_train={delta_train})."
        )

    first_test_day = start_tx_time_days_training + delta_train + delta_delay
    first_feedback_day = start_tx_time_days_training + delta_train - 1

    # Then, for each day of the test set
    test_df = None
    for day in range(delta_test):
        # Get test data for that day
        test_df_day = df.filter(F.col("TX_TIME_DAYS") == first_test_day + day)

        # Customers with a fraud on the matching feedback day are added to
        # the pool of known defrauded customers
        new_defrauded_customers = (
            df.filter((F.col("TX_TIME_DAYS") == first_feedback_day + day) & is_fraud)
            .select("CUSTOMER_ID")
            .distinct()
        )
        known_defrauded_customers = known_defrauded_customers.union(
            new_defrauded_customers
        ).distinct()

        # Keep only the transactions of customers not in the pool
        test_df_day = test_df_day.join(
            known_defrauded_customers, on="CUSTOMER_ID", how="left_anti"
        )

        # Concatenate the per-day test data into a single DataFrame
        test_df = test_df_day if test_df is None else test_df.union(test_df_day)

    if test_df is None:
        # No test days requested: return an empty frame with df's schema
        test_df = df.limit(0)

    # Sort data sets by ascending order of transaction ID
    train_df = train_df.orderBy("TRANSACTION_ID")
    test_df = test_df.orderBy("TRANSACTION_ID")

    return (train_df, test_df)

In [98]:
# Relative path of the directory the raw input file is read from.
raw_path = "../data/raw/"
# Relative path of the directory for preprocessed data (presumably where the
# engineered features are meant to be written -- confirm against later cells).
preprocessed_path = "../data/preprocessed/"

In [99]:
# Load the raw transactions and make sure TX_DATETIME is a timestamp column.
raw = spark.read.parquet(raw_path + "card_fraud.parquet.gzip").withColumn(
    "TX_DATETIME", F.col("TX_DATETIME").cast(TimestampType())
)

# Preview the first rows of the raw data.
raw.show()

+--------------+-------------------+-----------+-----------+---------+---------------+------------+--------+-----------------+
|TRANSACTION_ID|        TX_DATETIME|CUSTOMER_ID|TERMINAL_ID|TX_AMOUNT|TX_TIME_SECONDS|TX_TIME_DAYS|TX_FRAUD|TX_FRAUD_SCENARIO|
+--------------+-------------------+-----------+-----------+---------+---------------+------------+--------+-----------------+
|             0|2018-04-01 00:00:31|        596|       3156|    57.16|             31|           0|       0|                0|
|             1|2018-04-01 00:02:10|       4961|       3412|    81.51|            130|           0|       0|                0|
|             2|2018-04-01 00:07:56|          2|       1365|    146.0|            476|           0|       0|                0|
|             3|2018-04-01 00:09:29|       4128|       8737|    64.49|            569|           0|       0|                0|
|             4|2018-04-01 00:10:34|        927|       9906|    50.99|            634|           0|       0|   

In [100]:
preprocessed = raw.select("*")
# Create 'is_weekend' column.
# Spark's dayofweek() numbers the days 1 = Sunday ... 7 = Saturday, so the
# weekend is {1, 7}; a ">= 6" test would flag Friday and Saturday instead.
preprocessed = preprocessed.withColumn(
    "IS_WKD", F.dayofweek("TX_DATETIME").isin(1, 7)
)
# Create 'is_night' column (hour of day between 0 and 6 inclusive)
preprocessed = preprocessed.withColumn("IS_NIGHT", (F.hour("TX_DATETIME") <= 6))

# Window lengths, in days, for the rolling features below
n_days = [1, 7, 30]

# Per-customer number of transactions and average amount over each window
for f, fname in zip([F.count, F.mean], ["COUNT", "AVG"]):
    for d in n_days:
        field_name = f"CUSTOMER_TX_{fname}_{d}_DAY"
        preprocessed = partitionBy_tx_n_days(
            preprocessed, field_name, "CUSTOMER_ID", d, f
        )

# Per-terminal number of transactions over each window.
# NOTE(review): "TERNINAL" is a typo for "TERMINAL"; the name is kept as is
# because renaming the column would change the output schema.
for d in n_days:
    field_name = f"TERNINAL_TX_COUNT_{d}_DAY"
    preprocessed = partitionBy_tx_n_days(
        preprocessed, field_name, "TERMINAL_ID", d, F.count
    )

# Per-terminal fraud risk over each window, shifted 7 days into the past
for d in n_days:
    field_name = f"TERMINAL_FRAUD_RISK_{d}_DAY"
    preprocessed = fraud_risk(preprocessed, field_name, delay=7, n_days=d)

In [101]:
# Preview the first rows of the DataFrame with the engineered feature columns.
preprocessed.show()

+--------------+-------------------+-----------+-----------+---------+---------------+------------+--------+-----------------+------+--------+-----------------------+-----------------------+------------------------+---------------------+---------------------+----------------------+-----------------------+-----------------------+------------------------+-------------------------+-------------------------+--------------------------+
|TRANSACTION_ID|        TX_DATETIME|CUSTOMER_ID|TERMINAL_ID|TX_AMOUNT|TX_TIME_SECONDS|TX_TIME_DAYS|TX_FRAUD|TX_FRAUD_SCENARIO|IS_WKD|IS_NIGHT|CUSTOMER_TX_COUNT_1_DAY|CUSTOMER_TX_COUNT_7_DAY|CUSTOMER_TX_COUNT_30_DAY|CUSTOMER_TX_AVG_1_DAY|CUSTOMER_TX_AVG_7_DAY|CUSTOMER_TX_AVG_30_DAY|TERNINAL_TX_COUNT_1_DAY|TERNINAL_TX_COUNT_7_DAY|TERNINAL_TX_COUNT_30_DAY|TERMINAL_FRAUD_RISK_1_DAY|TERMINAL_FRAUD_RISK_7_DAY|TERMINAL_FRAUD_RISK_30_DAY|
+--------------+-------------------+-----------+-----------+---------+---------------+------------+--------+-----------------+----

In [102]:
# Preview only the transactions labelled as fraudulent.
preprocessed.where(F.col("TX_FRAUD") == 1).show()

+--------------+-------------------+-----------+-----------+------------------+---------------+------------+--------+-----------------+------+--------+-----------------------+-----------------------+------------------------+---------------------+---------------------+----------------------+-----------------------+-----------------------+------------------------+-------------------------+-------------------------+--------------------------+
|TRANSACTION_ID|        TX_DATETIME|CUSTOMER_ID|TERMINAL_ID|         TX_AMOUNT|TX_TIME_SECONDS|TX_TIME_DAYS|TX_FRAUD|TX_FRAUD_SCENARIO|IS_WKD|IS_NIGHT|CUSTOMER_TX_COUNT_1_DAY|CUSTOMER_TX_COUNT_7_DAY|CUSTOMER_TX_COUNT_30_DAY|CUSTOMER_TX_AVG_1_DAY|CUSTOMER_TX_AVG_7_DAY|CUSTOMER_TX_AVG_30_DAY|TERNINAL_TX_COUNT_1_DAY|TERNINAL_TX_COUNT_7_DAY|TERNINAL_TX_COUNT_30_DAY|TERMINAL_FRAUD_RISK_1_DAY|TERMINAL_FRAUD_RISK_7_DAY|TERMINAL_FRAUD_RISK_30_DAY|
+--------------+-------------------+-----------+-----------+------------------+---------------+------------+----

In [112]:
# Preview the transactions with the highest 30-day terminal fraud risk first.
preprocessed.orderBy(F.col("TERMINAL_FRAUD_RISK_30_DAY").desc()).show()

+--------------+-------------------+-----------+-----------+---------+---------------+------------+--------+-----------------+------+--------+-----------------------+-----------------------+------------------------+---------------------+---------------------+----------------------+-----------------------+-----------------------+------------------------+-------------------------+-------------------------+--------------------------+
|TRANSACTION_ID|        TX_DATETIME|CUSTOMER_ID|TERMINAL_ID|TX_AMOUNT|TX_TIME_SECONDS|TX_TIME_DAYS|TX_FRAUD|TX_FRAUD_SCENARIO|IS_WKD|IS_NIGHT|CUSTOMER_TX_COUNT_1_DAY|CUSTOMER_TX_COUNT_7_DAY|CUSTOMER_TX_COUNT_30_DAY|CUSTOMER_TX_AVG_1_DAY|CUSTOMER_TX_AVG_7_DAY|CUSTOMER_TX_AVG_30_DAY|TERNINAL_TX_COUNT_1_DAY|TERNINAL_TX_COUNT_7_DAY|TERNINAL_TX_COUNT_30_DAY|TERMINAL_FRAUD_RISK_1_DAY|TERMINAL_FRAUD_RISK_7_DAY|TERMINAL_FRAUD_RISK_30_DAY|
+--------------+-------------------+-----------+-----------+---------+---------------+------------+--------+-----------------+----

In [None]:
# Last cell: shut Spark down once the analysis is finished.
spark.stop()  # Stops the session's underlying SparkContext