In [1]:
!pip install pyspark

Collecting pyspark
  Downloading pyspark-3.5.3.tar.gz (317.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m317.3/317.3 MB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... [?25l[?25hdone
  Created wheel for pyspark: filename=pyspark-3.5.3-py2.py3-none-any.whl size=317840625 sha256=c688e48f42c718fc3c436b277fa4d1a4dc839d469579e25447a3d2c2b34f5bb5
  Stored in directory: /root/.cache/pip/wheels/1b/3a/92/28b93e2fbfdbb07509ca4d6f50c5e407f48dce4ddbda69a4ab
Successfully built pyspark
Installing collected packages: pyspark
Successfully installed pyspark-3.5.3


In [8]:
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import VectorAssembler
from pyspark.sql import SparkSession
from pyspark.sql.functions import avg, col, hour, stddev, when
from pyspark.sql.types import DoubleType, IntegerType, StringType, StructField, StructType, TimestampType

# Location of the historical transaction CSV (Colab sample_data folder).
# Change this single constant to run the notebook against another copy of the data.
DATA_PATH = "/content/sample_data/ecom_transaction_data.csv"

# Initialize (or reuse) the SparkSession for this notebook
spark = SparkSession.builder.appName("FraudDetectionModelECommerce").getOrCreate()

# Explicit schema so every column is read with its intended type.
# All fields are declared nullable, so later cells must tolerate nulls.
# NOTE(review): under Spark's default PERMISSIVE CSV mode, values that fail to
# parse into these types are loaded as null rather than raising — confirm that is acceptable.
schema = StructType([
    StructField("transaction_id", IntegerType(), True),
    StructField("user_id", IntegerType(), True),
    StructField("amount", DoubleType(), True),
    StructField("location_diff", DoubleType(), True),
    StructField("fraud_label", IntegerType(), True),
    StructField("transaction_time", TimestampType(), True)
])

# Load historical transaction data (first line of the file is a header row)
transaction_df = spark.read.csv(DATA_PATH, header=True, schema=schema)

# Display schema and initial data for verification
transaction_df.printSchema()
transaction_df.show()

# Feature Engineering
# Extract the hour of day from the transaction timestamp.
# NOTE(review): hour() is evaluated in the Spark session time zone — confirm that matches the data.
transaction_df = transaction_df.withColumn("transaction_hour", hour(col("transaction_time")))

# User-level statistics: mean and (sample) standard deviation of the amount per user.
# The sample stddev is null for a user with a single transaction and 0 when all
# of a user's amounts are identical — both cases are handled below.
user_stats = transaction_df.groupBy("user_id").agg(
    avg("amount").alias("avg_transaction_amount"),
    stddev("amount").alias("stddev_transaction_amount")
)

# Attach each user's statistics to every one of that user's transactions
transaction_df = transaction_df.join(user_stats, on="user_id", how="left")

# Transaction amount deviation from the user's average, as a z-score.
# Dividing directly by a null or zero stddev yields null (or may raise when ANSI
# mode is on), and the VectorAssembler's handleInvalid="skip" would then silently
# drop those transactions. Instead:
#   * stddev > 0              -> regular z-score
#   * user stats exist but no spread (single txn / identical amounts) -> 0.0
#   * no user stats at all (e.g. no join match) -> null, so the row is still skipped
transaction_df = transaction_df.withColumn(
    "amount_deviation",
    when(
        col("stddev_transaction_amount") > 0,
        (col("amount") - col("avg_transaction_amount")) / col("stddev_transaction_amount"),
    ).when(col("avg_transaction_amount").isNotNull(), 0.0)
)

# Assemble all the features into a single vector.
# handleInvalid="skip" drops any row with a null/NaN in one of the input columns,
# so the resulting "features" column is never null.
assembler = VectorAssembler(
    inputCols=["amount", "transaction_hour", "location_diff", "amount_deviation"],
    outputCol="features",
    handleInvalid="skip"
)

# Apply assembler and cache the result: it is scanned by the emptiness check,
# by model fitting and again for scoring, and each scan would otherwise re-run
# the CSV read, the per-user aggregation and the join.
transaction_df_assembled = assembler.transform(transaction_df).cache()

# Only rows with a known label can be used for training; unlabeled rows are
# still scored below.
training_df = transaction_df_assembled.filter(col("fraud_label").isNotNull())

# A null "features" check is unnecessary (see handleInvalid above); the only
# remaining failure mode is having no usable training rows at all.
if training_df.isEmpty():
    print("Error: No valid data available for model training.")
else:
    # Train a logistic regression model to predict fraudulent transactions
    lr = LogisticRegression(featuresCol="features", labelCol="fraud_label")
    model = lr.fit(training_df)

    # Score every assembled transaction, including any without a label.
    # NOTE(review): predictions are made on the same rows the model was trained on,
    # and the per-user statistics include each row's own amount, so these results
    # overstate real-world accuracy — consider a held-out split (e.g. randomSplit
    # with a fixed seed) before drawing conclusions.
    predictions = model.transform(transaction_df_assembled)

    # Show predictions
    predictions.select(
        "transaction_id", "user_id", "amount", "location_diff", "fraud_label", "prediction"
    ).show(truncate=False)

root
 |-- transaction_id: integer (nullable = true)
 |-- user_id: integer (nullable = true)
 |-- amount: double (nullable = true)
 |-- location_diff: double (nullable = true)
 |-- fraud_label: integer (nullable = true)
 |-- transaction_time: timestamp (nullable = true)

+--------------+-------+------+-------------+-----------+-------------------+
|transaction_id|user_id|amount|location_diff|fraud_label|   transaction_time|
+--------------+-------+------+-------------+-----------+-------------------+
|             1|    100| 250.0|          0.5|          0|2024-09-01 10:00:00|
|             2|    101| 500.0|          1.0|          0|2024-09-01 11:30:00|
|             3|    102|1000.0|          0.3|          1|2024-09-01 12:00:00|
|             4|    100| 150.0|          0.6|          0|2024-09-01 13:00:00|
|             5|    101| 750.0|          1.2|          1|2024-09-01 14:00:00|
|             6|    102| 200.0|          0.8|          0|2024-09-01 15:00:00|
|             7|    100| 30