In [1]:
! pip install pyspark

Collecting pyspark
  Downloading pyspark-3.5.3.tar.gz (317.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 317.3/317.3 MB 4.2 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... done
  Created wheel for pyspark: filename=pyspark-3.5.3-py2.py3-none-any.whl size=317840625 sha256=df278bb5a977006f14ba34e4c2503d6c3f55305305f09f97e3cf4eef78026af0
  Stored in directory: /root/.cache/pip/wheels/1b/3a/92/28b93e2fbfdbb07509ca4d6f50c5e407f48dce4ddbda69a4ab
Successfully built pyspark
Installing collected packages: pyspark
Successfully installed pyspark-3.5.3


In [8]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import VectorAssembler
from pyspark.sql.types import StructType, StructField, IntegerType, DoubleType, StringType, TimestampType

# --- Spark session & raw data load -----------------------------------------
# getOrCreate() reuses an already-running session when the cell is re-executed.
spark = (
    SparkSession.builder
    .appName("FraudDetectionModel")
    .getOrCreate()
)

# Explicit schema so every CSV column is parsed with a fixed type; all columns
# are declared nullable.
schema = StructType(
    [
        StructField(col_name, col_type, True)
        for col_name, col_type in [
            ("transaction_id", IntegerType()),
            ("user_id", IntegerType()),
            ("amount", DoubleType()),
            ("time_diff", DoubleType()),  # time difference between transactions
            ("location_diff", DoubleType()),
            ("fraud_label", IntegerType()),
            ("transaction_time", TimestampType()),
        ]
    ]
)

# Historical transactions; the first line of the file is a header row.
transaction_df = spark.read.csv(
    "/content/sample_data/fin_transaction_data.csv",
    schema=schema,
    header=True,
)

# Quick visual check of the parsed schema and the leading rows.
transaction_df.printSchema()
transaction_df.show()

# --- Data cleaning -----------------------------------------------------------
# Drop rows missing any of the numeric inputs or the label, then discard
# negative amounts (zero is kept).
transaction_df = (
    transaction_df
    .dropna(subset=["amount", "time_diff", "location_diff", "fraud_label"])
    .where(F.col("amount") >= 0)
)

# --- Time-based features -----------------------------------------------------
# Calendar components of transaction_time; each is null when the timestamp is null.
# NOTE(review): only transaction_hour is fed to the VectorAssembler further down;
# transaction_day and transaction_month are currently not used by the model.
transaction_df = (
    transaction_df
    .withColumn("transaction_hour", F.hour("transaction_time"))
    .withColumn("transaction_day", F.dayofweek("transaction_time"))
    .withColumn("transaction_month", F.month("transaction_time"))
)

# --- Per-user transaction count ----------------------------------------------
# Number of (cleaned) transactions per user, attached to each of that user's rows.
# Rows with a null user_id form their own group but cannot match it in the
# equality join, so they end up with a null count.
user_transaction_count = (
    transaction_df
    .groupBy("user_id")
    .agg(F.count(F.lit(1)).alias("user_transaction_count"))
)
transaction_df = transaction_df.join(user_transaction_count, "user_id", "left")

# --- Per-user amount statistics ----------------------------------------------
# Mean and sample standard deviation of each user's amounts, attached to every
# row of that user; these feed the per-transaction amount z-score.
user_stats = (
    transaction_df
    .groupBy("user_id")
    .agg(
        F.avg("amount").alias("user_mean_amount"),
        F.stddev("amount").alias("user_stddev_amount"),
    )
)
transaction_df = transaction_df.join(user_stats, "user_id", "left")

# Calculate the Z-score of each amount relative to the user's own history.
#
# F.stddev is the *sample* standard deviation: it is null for a user with a single
# transaction and 0 when all of a user's amounts are identical. A plain division
# then yields null (or raises DIVIDE_BY_ZERO under ANSI mode), and the
# VectorAssembler's handleInvalid="skip" would silently drop those rows from
# training and prediction. In both cases the amount equals the user's mean, so the
# z-score is defined as 0.0. Rows with no user statistics at all (null user_id)
# keep a null z-score, as before.
# NOTE(review): the user's mean/stddev include this transaction itself (and any
# later ones), which leaks information if this feature is used for evaluation.
transaction_df = transaction_df.withColumn(
    "amount_z_score",
    F.when(
        F.col("user_stddev_amount") > 0,
        (F.col("amount") - F.col("user_mean_amount")) / F.col("user_stddev_amount"),
    ).when(F.col("user_mean_amount").isNotNull(), F.lit(0.0)),
)

# --- Feature vector ----------------------------------------------------------
# Pack the model inputs into the single vector column that Spark ML estimators read.
# handleInvalid="skip": rows with an invalid (null/NaN) input value are filtered out
# rather than failing the job, so they are absent from training and prediction.
assembler = VectorAssembler(
    outputCol="features",
    handleInvalid="skip",
    inputCols=[
        "amount",
        "location_diff",
        "transaction_hour",
        "amount_z_score",
    ],
)
transaction_df = assembler.transform(transaction_df)

# Check for valid data before training.
# VectorAssembler(handleInvalid="skip") filters out invalid rows instead of emitting
# a null "features" value, so the only remaining failure mode is that no rows are
# left. DataFrame.isEmpty() (PySpark >= 3.3) checks that without a DataFrame -> RDD
# conversion and without the extra full-scan filter/count the old check ran.
if transaction_df.isEmpty():
    print("Error: No valid data available for model training.")
else:
    # Train a logistic regression model for fraud detection.
    lr = LogisticRegression(featuresCol="features", labelCol="fraud_label")
    model = lr.fit(transaction_df)

    # Make predictions on the same historical transaction data.
    # NOTE(review): this is in-sample only (the model has seen these rows); use a
    # held-out split (e.g. randomSplit with a fixed seed) to estimate real performance.
    predictions = model.transform(transaction_df)

    # Show the predictions alongside the key transaction details.
    predictions.select(
        "transaction_id", "user_id", "amount", "location_diff", "fraud_label", "prediction"
    ).show(truncate=False)


root
 |-- transaction_id: integer (nullable = true)
 |-- user_id: integer (nullable = true)
 |-- amount: double (nullable = true)
 |-- time_diff: double (nullable = true)
 |-- location_diff: double (nullable = true)
 |-- fraud_label: integer (nullable = true)
 |-- transaction_time: timestamp (nullable = true)

+--------------+-------+------------------+------------------+------------------+-----------+--------------------+
|transaction_id|user_id|            amount|         time_diff|     location_diff|fraud_label|    transaction_time|
+--------------+-------+------------------+------------------+------------------+-----------+--------------------+
|             1|      7| 278.3481840251138|1860.3725726843804|  98.1840888310531|          1|2023-11-03 05:34:...|
|             2|     15| 563.4372087184627|  939.724200214642| 83.89335020693633|          1|2024-02-17 05:49:...|
|             3|     11| 386.0122403800209|3586.5170654287695| 86.04046183116752|          1|2023-12-31 05:05:...