In [1]:
import os
os.sys.path.append("../")
from scripts.merchant_fraud import *


In [2]:
spark = (
    SparkSession.builder.appName("Merchant Fraud Model")
    .config("spark.sql.repl.eagerEval.enabled", True)
    .config("spark.sql.parquet.cacheMetadata", "true")
    .config("spark.sql.session.timeZone", "Etc/UTC")
    .config("spark.driver.memory", "4g")
    .config("spark.executor.memory", "2g")
    .getOrCreate()
)

24/09/30 15:40:54 WARN Utils: Your hostname, DESKTOP-H6V94HM resolves to a loopback address: 127.0.1.1; using 192.168.0.236 instead (on interface wifi0)
24/09/30 15:40:54 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/09/30 15:40:54 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


First, let's load the data that we're going to use

In [3]:
path = "../data/curated"

# Read in transactions dataset
transactions = spark.read.parquet(f"{path}/transactions.parquet")

                                                                                

## Feature engineering

We decide to introduce feature that will flag a transaction with a dollar value that significantly deviates from the mean dollar value. We going to keep it simple and assume that the underlying distribution of the dollar value for all merchants' transactions is normal. if the error, difference between dollar value and average dollar value, is greater than 2 standard deviation then the transaction will be flagged as suspicious.

We recognise that this may not be the best practice as the underlying transaction's dollar value distribution is normal for all merchants. Regardless, we believe this simple approach suffices for flagging transactions.

In [4]:
# Flag unusual transactions that deviate greatly from a merchant's usual dollar value
# Calculate average and standard deviation of dollar_value per merchant
transaction_stats = transactions.groupBy("merchant_abn").agg(
    F.avg("dollar_value").alias("avg_dollar_value"),
    F.stddev("dollar_value").alias("std_dollar_value")
)

# Join the stats back to the original dataset
transaction_df = transactions.join(transaction_stats, on="merchant_abn", how="left")

We noticed that there are certain merchant with `NULL` values for dollar value standard deviation. This is because that these merchants only have one transaction across the entirety time range of the data. Though these records seems suspcious as it seems unreasonable for a business to have only 1 transaction from Febuary 2021 to August 2022, we still going to keep it in the data.

In [5]:
# Calculate how many standard deviations away each transaction is, in other words, we're normalising the dollar value
# May need extra caution to interpret this feature as it can be POSITIVE OR NEGATIVE
transaction_df = transaction_df.withColumn(
    "std_diff_dollar_value", 
    F.when(
        F.col("std_dollar_value").isNotNull() & (F.col("std_dollar_value") != 0),  # Are there cases where std of dollar value is 0?
        (F.col("dollar_value") - F.col("avg_dollar_value")) / F.col("std_dollar_value")
    ).otherwise(0) 
)
transaction_df = transaction_df.drop("consumer_id", "consumer_fp", "name", "category", "avg_dollar_value", "std_dollar_value")

In [6]:
# Flag unusual monthly transaction volumes that deviate from a merchant's usual monthly volume
# Extract month and year from order_datetime
transaction_df = transaction_df.withColumn("order_month", F.date_format(F.col("order_datetime"), "yyyy-MM"))

# Calculate number of transactions per merchant per month
transaction_records_monthly = transaction_df.groupBy("merchant_abn", "order_month").agg(
    F.count("order_id").alias("monthly_order_volume")
)

# Calculate the average and standard deviation of monthly transactions per merchant
transaction_stats = transaction_records_monthly.groupBy("merchant_abn").agg(
    F.avg("monthly_order_volume").alias("avg_monthly_order_volume"),
    F.stddev("monthly_order_volume").alias("stddev_monthly_order_volume")
)

# Join the monthly volume feature back with the original dataset
transaction_records_final = transaction_df.join(transaction_records_monthly, on=["merchant_abn", "order_month"], how="left"
)

# Join the transaction statistics back to the original dataset 
transaction_records_final = transaction_records_final.join(transaction_stats, on="merchant_abn", how="left")

# Calculate how many standard deviations away each monthly volume is
transaction_records_final = transaction_records_final.withColumn(
    "std_diff_order_volume", 
    F.when(F.col("stddev_monthly_order_volume").isNotNull() & (F.col("stddev_monthly_order_volume") != 0),
           (F.col("monthly_order_volume") - F.col("avg_monthly_order_volume")) / F.col("stddev_monthly_order_volume"))
    .otherwise(0)
)

transaction_records_final.show(5)

                                                                                

+------------+-----------+--------------+------------------+--------------------+-----------+-------------+---------+---------------------+--------------------+------------------------+---------------------------+---------------------+
|merchant_abn|order_month|order_datetime|      dollar_value|            order_id|merchant_fp|revenue_level|take_rate|std_diff_dollar_value|monthly_order_volume|avg_monthly_order_volume|stddev_monthly_order_volume|std_diff_order_volume|
+------------+-----------+--------------+------------------+--------------------+-----------+-------------+---------+---------------------+--------------------+------------------------+---------------------------+---------------------+
| 96161947306|    2021-08|    2021-08-19| 63.60772275481862|e7da0886-4c01-4f1...|       NULL|            b|     4.52|  -0.5164591992066518|                 668|       655.3333333333334|         121.39483466092345|  0.10434271525676357|
| 92779316513|    2021-08|    2021-08-16|10.081895520137

In [7]:
transaction_records_final = transaction_records_final.drop("avg_monthly_order_volume", "stddev_monthly_order_volume")

We suspect that there is an underlying relationship between month, day of the week, and if the order date is a weekend with the fraud probability of the merchants. Thus, we will create 3 features that capture this temporal effect. We will later encode these feature to feed into our model

In [8]:
# Extract the weekday (1 = Sunday, 7 = Saturday)
transaction_records_final = transaction_records_final.withColumn("weekday", F.dayofweek("order_datetime"))

# Add a column to flag weekends (Saturday = 7, Sunday = 1)
transaction_records_final = transaction_records_final.withColumn(
    "is_weekend", 
    F.when((F.col("weekday") == 7) | (F.col("weekday") == 1), 1).otherwise(0)
)

# Extract year and month from 'order_month' and create new columns
transaction_records_final = transaction_records_final.withColumns(
    {"year":  F.split(F.col("order_month"), "-")[0].cast("integer"),
    "month": F.split(F.col("order_month"), "-")[1].cast("integer")}
)

# Modelling

Splitting the data to train and test set to train and fine-tune our model.

In [9]:
train_data = transaction_records_final.filter(F.col("merchant_fp").isNotNull())
test_data = transaction_records_final.filter(F.col("merchant_fp").isNull())

We will be using Random Forest Regression and Linear Regression to predict the fraud probability. We will aso perform a cross-validated grid search to find the "better" hyperparameters for the model.

For now, we will use the default hyperparameters of the model to see how each model perform.

In [10]:
# Assemble the data
assembled_train, _ = assemble_data(train_data)
train_set, validation_set = assembled_train.randomSplit([0.8,0.2], seed=123)

                                                                                

In [11]:
rfr_model = unoptimal_model(RandomForestRegressor(labelCol='merchant_fp', featuresCol='features'),
                            train_set, validation_set)

                                                                                

Root Mean Squared Error (RMSE) on validation data = 2.696487561469904
R2 (Coefficient of Determination) on validation data: 0.8172803798199015


In [12]:
# lr_model = unoptimal_model(LinearRegression(labelCol="merchant_fp", featuresCol="features"),
#                            train_set, validation_set)

We can see that the LR performance is quite terrible with RMSE that's almost double than that of RFR. It's $R^2$ is only 0.08 which indicates that the model failed to explain a large portion of variation in the data. Thus, we will use RFR as our main model and perform cross-validated grid search.

In [13]:
# Parameter grid
rfr_paramGrid = ParamGridBuilder() \
    .addGrid(RandomForestRegressor(labelCol='merchant_fp', featuresCol='features').numTrees, [10, 20, 40]) \
    .addGrid(RandomForestRegressor(labelCol='merchant_fp', featuresCol='features').maxDepth, [5, 10, 12]) \
    .build()

evaluator = RegressionEvaluator(labelCol="merchant_fp", predictionCol="prediction")

crossval = CrossValidator(estimator=RandomForestRegressor(labelCol='merchant_fp', featuresCol='features'),
                          estimatorParamMaps=rfr_paramGrid,
                          evaluator=evaluator,
                          numFolds=2)

cv_model = crossval.fit(train_set)
cv_predictions = cv_model.transform(validation_set)
cv_rmse = evaluator.evaluate(cv_predictions)

24/09/30 15:42:23 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
                                                                                

In [14]:
best_model = cv_model.bestModel
print(f"Best Model RMSE on test data = {cv_rmse}")
print(f"Best number of trees: {best_model.getNumTrees}") 
print(f"Best max depth: {best_model.getMaxDepth()}") 
print(f"Best max bins: {best_model.getMaxBins()}")

Best Model RMSE on test data = 2.696487561469904
Best number of trees: 20
Best max depth: 5
Best max bins: 32


In [15]:
rfr_better = RandomForestRegressor(labelCol='merchant_fp', featuresCol='features',
                                         numTrees = 20, maxBins = 32, maxDepth=5)

rfr_better_model = rfr_better.fit(train_set)

                                                                                

In [16]:
assembled_test, assembler = assemble_data(test_data)

                                                                                

Making preditions on the test data

In [17]:
predictions = rfr_better_model.transform(assembled_test)

predictions.write.parquet(f"../data/curated/transactions_predicted_merchant_fp.parquet", mode = "overwrite")

                                                                                

## Feature importances

In [18]:
import pandas as pd

rfr_feature_importances = rfr_better_model.featureImportances
feature_names = assembler.getInputCols()

rf_importances_df = pd.DataFrame({
    "Feature": feature_names,
    "Importance": rfr_feature_importances.toArray()
}).sort_values(by="Importance", ascending=False)

print(rf_importances_df)

                      Feature  Importance
7   norm_monthly_order_volume    0.263439
8  norm_std_diff_order_volume    0.181058
0               revenue_index    0.142353
2                 month_index    0.137357
9                   take_rate    0.093114
5           norm_dollar_value    0.089446
3               weekday_index    0.067100
6  norm_std_diff_dollar_value    0.022867
1                  year_index    0.002671
4           is_weekend_vector    0.000595


24/09/30 18:02:49 WARN HeartbeatReceiver: Removing executor driver with no recent heartbeats: 7139033 ms exceeds timeout 120000 ms
24/09/30 18:02:49 WARN SparkContext: Killing executors is not supported by current scheduler.
24/09/30 18:02:49 ERROR Inbox: Ignoring error
org.apache.spark.SparkException: Exception thrown in awaitResult: 
	at org.apache.spark.util.SparkThreadUtils$.awaitResult(SparkThreadUtils.scala:56)
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:310)
	at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75)
	at org.apache.spark.rpc.RpcEnv.setupEndpointRefByURI(RpcEnv.scala:102)
	at org.apache.spark.rpc.RpcEnv.setupEndpointRef(RpcEnv.scala:110)
	at org.apache.spark.util.RpcUtils$.makeDriverRef(RpcUtils.scala:36)
	at org.apache.spark.storage.BlockManagerMasterEndpoint.driverEndpoint$lzycompute(BlockManagerMasterEndpoint.scala:124)
	at org.apache.spark.storage.BlockManagerMasterEndpoint.org$apache$spark$storage$BlockManagerMasterEndpoint$