In [1]:
!pip install pyspark
!pip install torch torchvision
!pip install pillow

Defaulting to user installation because normal site-packages is not writeable
Collecting pyspark
  Downloading pyspark-3.5.5.tar.gz (317.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 317.2/317.2 MB 66.8 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Collecting py4j==0.10.9.7 (from pyspark)
  Downloading py4j-0.10.9.7-py2.py3-none-any.whl.metadata (1.5 kB)
Downloading py4j-0.10.9.7-py2.py3-none-any.whl (200 kB)
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... done
  Created wheel for pyspark: filename=pyspark-3.5.5-py2.py3-none-any.whl size=317747923 sha256=9046619a5811f08f8b7397bdf17efcf357576cfa5e7bcdb06294b28097cfb97c
  Stored in directory: /home/itewari1/.cache/pip/wheels/8f/cb/c0/cc57eb1bf0f9dc87cdaf2b0dbac49e58a210ff68d21d6fc709
Successfully built pyspark
Installing collected packages: py4j, pyspark
Successfully installed py4j-0.10.9.7 pyspark-3.5.5

In [17]:
from pyspark.sql import SparkSession

# Obtain the Spark session used by every cell below, under the application
# name "FraudDetection".
spark = (
    SparkSession.builder
    .appName("FraudDetection")
    .getOrCreate()
)

# df = spark.read.option("header", True).csv("idimage.csv")

In [25]:
# Read the test-set CSV into a Spark DataFrame.
#   header    - first line holds the column names
#   multiLine - a quoted field may span several physical lines
#   quote / escape - both set to the double quote character
#   mode / columnNameOfCorruptRecord - PERMISSIVE parsing with the corrupt
#       record column named "_corrupt_record"
df = (
    spark.read
    .option("header", True)
    .option("multiLine", True)
    .option("quote", "\"")
    .option("escape", "\"")
    .option("mode", "PERMISSIVE")
    .option("columnNameOfCorruptRecord", "_corrupt_record")
    .csv("heavy_test_query1.csv")
)

In [27]:
import pandas as pd
# Load the same CSV with pandas and print how many rows it parses.
df1 = pd.read_csv("heavy_test_query1.csv")
print(df1.shape[0])

6000


In [28]:
# Row count as seen by Spark, then a peek at the first few values of the
# image column (long values are truncated for display).
spark_row_count = df.count()
print(f"Row count in original DataFrame: {spark_row_count}")
df.select("imageData").show(n=5, truncate=True)

[Stage 33:>                                                         (0 + 1) / 1]

Row count in original DataFrame: 6000
+--------------------+
|           imageData|
+--------------------+
|/9j/4AAQSkZJRgABA...|
|/9j/4AAQSkZJRgABA...|
|/9j/4AAQSkZJRgABA...|
|/9j/4AAQSkZJRgABA...|
|/9j/4AAQSkZJRgABA...|
+--------------------+
only showing top 5 rows



                                                                                

In [31]:
import base64
import io
from PIL import Image
import torch
import torch.nn as nn
from torchvision import transforms, models
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType
import time

# Step 1: Build the MobileNetV3-Small architecture without pretrained weights;
# the trained weights are loaded from disk in Step 3. `weights=None` is the
# current spelling of the deprecated `pretrained=False` flag
# (requires torchvision >= 0.13).
model = models.mobilenet_v3_small(weights=None)

# Step 2: Replace the last classifier layer with Dropout + a 2-output Linear
# head so the module layout matches the keys of the saved state_dict.
model.classifier[3] = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(in_features=model.classifier[3].in_features, out_features=2)
)

# Step 3: Load the trained state_dict onto the CPU. `weights_only=True`
# restricts unpickling to tensors and plain containers, so a tampered
# checkpoint file cannot execute arbitrary code on load.
state_dict = torch.load(
    "mobileNetV3_fraud_model_3.pth",
    map_location="cpu",
    weights_only=True,
)
model.load_state_dict(state_dict)
model.eval()  # evaluation mode: dropout is disabled for inference

# Step 4: Define the inference transform (no augmentation): resize to 224x224
# and convert the PIL image to a tensor.
# NOTE(review): there is no Normalize step here - confirm this matches the
# preprocessing used when the model was trained.
inference_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Step 5: Define the fraud prediction function
def infer_fraud(image_base64: str) -> int:
    """Classify one base64-encoded image with the loaded model.

    Args:
        image_base64: Base64 text of an encoded image file.

    Returns:
        The index of the highest-scoring model output
        (0 = Genuine, 1 = Fraud), or -1 if decoding, image loading or
        inference raised any exception. The exception message is printed.
    """
    try:
        raw_bytes = base64.b64decode(image_base64)
        rgb_image = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        # The model expects a batch dimension, so add one of size 1.
        batch = inference_transform(rgb_image).unsqueeze(0)
        with torch.no_grad():
            logits = model(batch)
            return torch.argmax(logits, dim=1).item()
    except Exception as err:
        print(f"Inference error: {err}")
        return -1

# Step 6: Wrap infer_fraud as a Spark UDF that yields an integer column
infer_fraud_udf = udf(f=infer_fraud, returnType=IntegerType())

In [32]:
from pyspark.sql.functions import when

# Step 1: Count the input records. This runs before the timer starts, so the
# extra pass over the CSV is not part of the measurement.
num_images = df.count()

# Step 2: Start the timer. perf_counter() is a monotonic, high-resolution
# clock; time.time() is wall-clock time and can jump if the system clock is
# adjusted, which would corrupt the measured duration.
start_time = time.perf_counter()

# Step 3: Build the inference pipeline. The 'imageData' column of 'df' holds
# the base64-encoded images passed to the UDF.
df_with_preds = df.withColumn("prediction", infer_fraud_udf(df["imageData"]))

# Map numeric predictions to labels; anything other than 0 or 1 (the UDF
# returns -1 on failure) is reported as "Error".
df_labeled = df_with_preds.withColumn(
    "label",
    when(df_with_preds["prediction"] == 0, "Genuine")
    .when(df_with_preds["prediction"] == 1, "Fraud")
    .otherwise("Error")
)

# Group by label and count. show() is the action that actually executes the
# pipeline, so the timed interval covers reading, inference and aggregation.
result = df_labeled.groupBy("label").count()

result.show()

# Step 4: Stop the timer
end_time = time.perf_counter()

# Step 5: Compute metrics (guard against division by zero on empty input)
total_time = end_time - start_time
latency_per_image = total_time / num_images if num_images > 0 else 0

# Step 6: Print results
print(f"\nTotal images: {num_images}")
print(f"Total end-to-end processing time: {total_time:.2f} seconds")
print(f"Avg latency per image: {latency_per_image:.4f} seconds/image")

25/05/03 21:56:18 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/05/03 21:56:18 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
[Stage 43:>                                                         (0 + 1) / 1]

+-------+-----+
|  label|count|
+-------+-----+
|Genuine| 4822|
|  Fraud| 1178|
+-------+-----+


Total images: 6000
Total end-to-end processing time: 61.68 seconds
Avg latency per image: 0.0103 seconds/image


                                                                                