In [1]:
#Load Libraries and initialize session
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql import SparkSession
from pyspark import SparkConf
from pyspark.ml.feature import StringIndexer
from pyspark.sql.functions import year, month, dayofmonth
from pyspark.sql.functions import broadcast

# Build (or attach to) the Spark session used for the feature engineering run.
# NOTE(review): the cell output reports that the kernel had already started a
# Spark application, so launch-time settings (executor memory, serializer,
# Kryo buffer) may not actually be re-applied by getOrCreate() -- confirm on
# the cluster before relying on them.
spark_builder = SparkSession.builder.appName("Fraud Detection Feature Engineering")
for conf_key, conf_value in (
    # Kryo serialization with a raised buffer ceiling
    ("spark.serializer", "org.apache.spark.serializer.KryoSerializer"),
    ("spark.kryoserializer.buffer.max", "2047m"),
    ("spark.executor.memory", "80G"),
    # Compress shuffle output and spilled shuffle data
    ("spark.shuffle.compress", "true"),
    ("spark.shuffle.spill.compress", "true"),
    ("spark.sql.shuffle.partitions", "20000"),
):
    spark_builder = spark_builder.config(conf_key, conf_value)
spark = spark_builder.getOrCreate()

# Read back the session's configuration values and print them for the run log.
# The three optional keys fall back to the string "Not Set"; the two keys set
# in the builder above are read without a fallback.
network_timeout, max_result_size, heartbeat_interval = (
    spark.conf.get(optional_key, "Not Set")
    for optional_key in (
        "spark.network.timeout",
        "spark.driver.maxResultSize",
        "spark.executor.heartbeatInterval",
    )
)

buffer_max = spark.conf.get("spark.kryoserializer.buffer.max")
print(f"The value of spark.kryoserializer.buffer.max is: {buffer_max}")

executor_memory = spark.conf.get("spark.executor.memory")
print(f"Current spark.executor.memory: {executor_memory}")

print(f"spark.network.timeout: {network_timeout}")
print(f"spark.driver.maxResultSize: {max_result_size}")
print(f"spark.executor.heartbeatInterval: {heartbeat_interval}")

VBox()

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,User,Current session?
0,application_1729426168720_0001,pyspark,idle,Link,Link,,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

The value of spark.kryoserializer.buffer.max is: 2047m
Current spark.executor.memory: 80G
spark.network.timeout: 1600s
spark.driver.maxResultSize: Not Set
spark.executor.heartbeatInterval: Not Set

In [2]:
# Load the three Parquet datasets from S3 (the schema is read from the files)
customers_path = "s3://nvidia-aws-fraud-detection-demo-training-data/customers_parquet/"
terminals_path = "s3://nvidia-aws-fraud-detection-demo-training-data/terminals_parquet/"
transactions_path = "s3://nvidia-aws-fraud-detection-demo-training-data/transactions_parquet/"

# customers and transactions are explicitly repartitioned after reading;
# terminals keeps the partitioning it was read with.
customers_df = spark.read.parquet(customers_path).repartition(300)
terminals_df = spark.read.parquet(terminals_path)
transactions_df = spark.read.parquet(transactions_path).repartition(1000)

# Show schema of each dataset to understand their structure.
# (A stray bare `_` expression that used to sit here was removed: it relied on
# leftover interactive state and does nothing useful on a fresh run.)
print("Customers Schema:")
customers_df.printSchema()

print("Terminals Schema:")
terminals_df.printSchema()

print("Transactions Schema:")
transactions_df.printSchema()

# Row counts are left disabled; uncomment to inspect dataset sizes
# (each count() runs a Spark job over the dataset).
#print(f"Number of rows in customers: {customers_df.count()}")
#print(f"Number of rows in terminals: {terminals_df.count()}")
#print(f"Number of rows in transactions: {transactions_df.count()}")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Customers Schema:
root
 |-- CUSTOMER_ID: string (nullable = true)
 |-- customer_name: string (nullable = true)
 |-- billing_street: string (nullable = true)
 |-- billing_city: string (nullable = true)
 |-- billing_state: string (nullable = true)
 |-- billing_zip: string (nullable = true)
 |-- customer_job: string (nullable = true)
 |-- customer_email: string (nullable = true)
 |-- phone: string (nullable = true)
 |-- x_customer_id: double (nullable = true)
 |-- y_customer_id: double (nullable = true)
 |-- mean_amount: double (nullable = true)
 |-- std_amount: double (nullable = true)
 |-- mean_nb_tx_per_day: double (nullable = true)
 |-- std_dev_nb_tx_per_day: double (nullable = true)
 |-- available_terminals: array (nullable = true)
 |    |-- element: string (containsNull = true)

Terminals Schema:
root
 |-- TERMINAL_ID: string (nullable = true)
 |-- x_terminal_id: double (nullable = true)
 |-- y_terminal_id: double (nullable = true)
 |-- merchant: string (nullable = true)

Transactio

In [3]:
# Mark the terminals table as small enough to be used in broadcast joins
terminals_df = F.broadcast(terminals_df)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [4]:
# Cast TX_DATETIME to a timestamp, then derive the calendar parts
# (yyyy = year, mm = month, dd = day of month) from the cast column.
transactions_df = (
    transactions_df
    .withColumn("TX_DATETIME", F.col("TX_DATETIME").cast("timestamp"))
    .withColumn("yyyy", F.year(F.col("TX_DATETIME")))
    .withColumn("mm", F.month(F.col("TX_DATETIME")))
    .withColumn("dd", F.dayofmonth(F.col("TX_DATETIME")))
)

# Look-back horizons for the rolling features: window label -> length in seconds.
# Insertion order (minutes first, then days) is the order features are generated in.
SECONDS_PER_MINUTE = 60
SECONDS_PER_DAY = 24 * 60 * SECONDS_PER_MINUTE
time_windows = {
    **{f"{minutes}min": minutes * SECONDS_PER_MINUTE for minutes in (15, 30, 60)},
    **{f"{days}day": days * SECONDS_PER_DAY for days in (1, 7, 15, 30)},
}

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [5]:
# Define a function to add window features efficiently
def add_window_features(transactions_df, time_windows, entity_id_col, prefix):
    """Append rolling transaction-count and mean-amount features per entity.

    For every ``name -> duration`` entry in ``time_windows`` two columns are
    added, in this order:

    * ``{prefix}_nb_txns_{name}_window`` -- number of rows in the window
    * ``{prefix}_avg_amt_{name}_window`` -- average ``TX_AMOUNT`` in the window

    A window covers the rows sharing the same ``entity_id_col`` value whose
    ``TX_DATETIME`` (cast to long) lies in ``[t - duration, t]`` relative to
    the current row's value ``t``; both bounds are inclusive, so the current
    row is always counted.

    Args:
        transactions_df: DataFrame with ``TX_DATETIME``, ``TX_AMOUNT`` and
            ``entity_id_col`` columns.
        time_windows: Mapping of window label to window length, in the same
            unit as ``TX_DATETIME`` cast to long (the caller passes seconds).
        entity_id_col: Column to partition the windows by.
        prefix: Prefix for the generated column names.

    Returns:
        A DataFrame with the feature columns appended after the existing
        ones, or ``transactions_df`` itself when ``time_windows`` is empty.
    """
    # The partitioning and ordering are identical for every horizon; only the
    # range frame changes, so build the shared part once.
    entity_timeline = Window.partitionBy(entity_id_col).orderBy(
        F.col("TX_DATETIME").cast("long"))

    features = {}
    for window_name, window_duration in time_windows.items():
        window_spec = entity_timeline.rangeBetween(-window_duration, 0)

        # Number of transactions in the time window
        features[f"{prefix}_nb_txns_{window_name}_window"] = (
            F.count("*").over(window_spec))

        # Average transaction amount in the time window
        features[f"{prefix}_avg_amt_{window_name}_window"] = (
            F.avg("TX_AMOUNT").over(window_spec))

    if not features:
        return transactions_df

    # Add every feature in a single projection instead of one withColumn per
    # feature (repeated withColumn calls in a loop grow the query plan).
    # Same-named columns are dropped first so a repeated call still replaces
    # them rather than producing duplicate column names.
    return transactions_df.drop(*features).select(
        "*", *[expr.alias(name) for name, expr in features.items()])


# Add the rolling features twice: first keyed by customer, then by terminal.
for entity_column, feature_prefix in (
    ("CUSTOMER_ID", "customer_id"),   # customer-related features
    ("TERMINAL_ID", "terminal_id"),   # terminal-related features
):
    transactions_df = add_window_features(
        transactions_df, time_windows, entity_column, feature_prefix)

# Ordinal encoding of the ID columns with StringIndexer.
# Each ID indexer is fitted on transactions_df and then reused on the matching
# dimension table, so both sides carry the same *_index column.

# CUSTOMER_ID -> CUSTOMER_ID_index (transactions_df and customers_df)
customer_indexer = StringIndexer(
    inputCol="CUSTOMER_ID",
    outputCol="CUSTOMER_ID_index",
    handleInvalid="keep",
).fit(transactions_df)
transactions_df = customer_indexer.transform(transactions_df)
customers_df = customer_indexer.transform(customers_df)

# Remaining customer text attributes, each fitted on customers_df itself;
# columns that are not present are skipped.
columns_to_encode_customers = ['customer_name', 'customer_email', 'phone']
for column in columns_to_encode_customers:
    if column not in customers_df.columns:
        continue
    indexer = StringIndexer(
        inputCol=column,
        outputCol=f"{column}_index",
        handleInvalid="keep",
    ).fit(customers_df)
    customers_df = indexer.transform(customers_df)

# TERMINAL_ID -> TERMINAL_ID_index (transactions_df and terminals_df)
terminal_indexer = StringIndexer(
    inputCol="TERMINAL_ID",
    outputCol="TERMINAL_ID_index",
    handleInvalid="keep",
).fit(transactions_df)
transactions_df = terminal_indexer.transform(transactions_df)
terminals_df = terminal_indexer.transform(terminals_df)

# Ordinal encoding for merchant in both transactions_df and terminals_df.
# Each table is only touched when it actually has a 'merchant' column, and
# each gets its own indexer fitted on that table.
# NOTE(review): if both tables ended up with a merchant_index column, the
# merchant_index selected after the join would be ambiguous -- confirm that
# 'merchant' only exists in terminals_df.
if 'merchant' in transactions_df.columns:
    merchant_indexer = StringIndexer(inputCol='merchant',
                                     outputCol='merchant_index',
                                     handleInvalid="keep").fit(transactions_df)
    transactions_df = merchant_indexer.transform(transactions_df)

if 'merchant' in terminals_df.columns:
    merchant_indexer_terminals = StringIndexer(
        inputCol='merchant', outputCol='merchant_index',
        handleInvalid="keep").fit(terminals_df)
    terminals_df = merchant_indexer_terminals.transform(terminals_df)

# Apply StringIndexer to additional categorical columns in transactions_df and
# drop the raw column afterwards.
columns_to_encode_transactions = ['merchant']  # Already handled 'merchant'
for column in columns_to_encode_transactions:
    if column not in transactions_df.columns:
        continue
    # Only index columns that have not been indexed yet: fitting a second
    # StringIndexer whose output column already exists is rejected by Spark,
    # which is what happened for 'merchant' after the block above had run.
    if f"{column}_index" not in transactions_df.columns:
        indexer = StringIndexer(inputCol=column,
                                outputCol=f"{column}_index",
                                handleInvalid="keep").fit(transactions_df)
        transactions_df = indexer.transform(transactions_df)
    transactions_df = transactions_df.drop(column)

# One-hot encode TX_FRAUD into two 0/1 indicator columns, then drop the
# original label together with the raw TX_DATETIME column (its year, month and
# day parts were extracted earlier).
transactions_df = (
    transactions_df
    .withColumn("TX_FRAUD_0", (F.col("TX_FRAUD") == 0).cast("int"))
    .withColumn("TX_FRAUD_1", (F.col("TX_FRAUD") == 1).cast("int"))
    .drop("TX_FRAUD", "TX_DATETIME")
)

# Apply StringIndexer for billing_city and billing_state in customers_df.
# handleInvalid="keep" matches every other indexer in this notebook: both
# columns are nullable in the printed schema, and with the default ("error")
# a single null value would fail the job when the transform is executed.
billing_city_indexer = StringIndexer(inputCol="billing_city",
                                     outputCol="billing_city_index",
                                     handleInvalid="keep").fit(customers_df)
customers_df = billing_city_indexer.transform(customers_df)

billing_state_indexer = StringIndexer(inputCol="billing_state",
                                      outputCol="billing_state_index",
                                      handleInvalid="keep").fit(customers_df)
customers_df = billing_state_indexer.transform(customers_df)

# Drop the original columns after encoding
customers_df = customers_df.drop("billing_city", "billing_state")

# Join the enriched transactions with customer details, then terminal details.
# Both joins are left joins keyed on the ordinal-encoded IDs, so every
# transaction row is kept even when no matching customer/terminal exists.
# (A right-join variant and row-count prints for both variants were explored
# earlier and are left out of the pipeline.)
transactions_with_customers = transactions_df.join(
    customers_df, on="CUSTOMER_ID_index", how="left")
final_df = transactions_with_customers.join(
    terminals_df, on="TERMINAL_ID_index", how="left")


# Select the final features and customer/terminal details.

# Customer attributes, transaction fields, terminal attributes and date parts.
detail_columns = [
    "CUSTOMER_ID_index",
    "customer_name_index",
    "customer_email_index",
    "phone_index",
    "billing_zip",
    "billing_city_index",   # Ordinal encoded billing_city
    "billing_state_index",  # Ordinal encoded billing_state
    "x_customer_id",
    "y_customer_id",
    "TX_AMOUNT",
    "TX_FRAUD_0",  # One-hot encoded column
    "TX_FRAUD_1",  # One-hot encoded column
    "TERMINAL_ID_index",
    "merchant_index",  # Ensure 'merchant_index' is present
    "yyyy",
    "mm",
    "dd",
]

# Rolling-window features: for each entity, all transaction counts first and
# then all average amounts, each in the order of the time_windows entries.
window_feature_columns = [
    f"{entity_prefix}_{statistic}_{window_name}_window"
    for entity_prefix in ("customer_id", "terminal_id")
    for statistic in ("nb_txns", "avg_amt")
    for window_name in time_windows
]

final_columns = detail_columns + window_feature_columns

final_df = final_df.select(final_columns).repartition(10000)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [None]:
# Session-level SQL settings put in place before the write job is triggered
for sql_conf_key, sql_conf_value in (
    ("spark.sql.files.maxPartitionBytes", "128M"),
    ("spark.sql.autoBroadcastJoinThreshold", "500M"),
):
    spark.conf.set(sql_conf_key, sql_conf_value)

# Save the result to S3 as Parquet, replacing any previous output at that path
final_output_path = "s3://nvidia-aws-fraud-detection-demo/output121/"
final_df.write.parquet(final_output_path, mode="overwrite")
print(f"Data successfully written to {final_output_path}")

# Stop the Spark session
spark.stop()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…