In [2]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import BooleanType
from pyspark.sql.functions import udf
from pyspark.sql.types import BooleanType
from pyspark.sql.functions import pandas_udf

from models.cnn import CNN

import base64
import io
from PIL import Image
import numpy as np
import pandas as pd
import torch

In [3]:
# Locations of the saved model weights, relative to the notebook's working directory.
MODEL_PATH_PYTORCH = "models/fraud_cnn.pt"
MODEL_PATH_TENSORFLOW = "models/efficientNet.h5"
# Alias: the same .h5 file under the name MODEL_PATH_KERAS, so either name resolves.
MODEL_PATH_KERAS = MODEL_PATH_TENSORFLOW

## Create Spark session and load data

In [4]:
# Obtain the notebook's Spark session: getOrCreate() returns an already-active
# session if there is one, otherwise it starts a new one under this app name.
spark = (
    SparkSession.builder
    .appName("FraudModelInference_Tables")
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/04/29 00:10:42 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [5]:
# Load the three CSV tables. The first row of each file supplies the column names;
# no schema is supplied or inferred, so every column is read as a string.
idimage_df = spark.read.csv("data/idimage_fixed.csv", header=True)
idlabel_df = spark.read.csv("data/idlabel.csv", header=True)
idmeta_df = spark.read.csv("data/idmeta.csv", header=True)

In [6]:
# The label column is loaded as a string; replace it with a real boolean column.
# withColumn is a lazy transformation - nothing is computed until an action runs.
isfraud_as_bool = col("isfraud").cast(BooleanType())
idlabel_df = idlabel_df.withColumn("isfraud", isfraud_as_bool)


# Register all tables

In [7]:
# Expose each DataFrame to Spark SQL under a short table name.
for view_name, frame in (
    ("idimage", idimage_df),
    ("idlabel", idlabel_df),
    ("idmeta", idmeta_df),
):
    frame.createOrReplaceTempView(view_name)

25/04/29 00:10:49 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


# Load model from saved weights (CNN(100), EfficientNet (81%))

In [30]:
# Interactive model selection.
#   "1" -> PyTorch CNN restored from MODEL_PATH_PYTORCH and broadcast to the executors.
#   "2" -> Keras model restored from MODEL_PATH_TENSORFLOW (driver-side only).
# After this cell `model` holds the loaded model (or None on failure / invalid input) and
# `broadcast_model` is ALWAYS bound to a Spark broadcast, so cnn_fraud_detector can read
# `broadcast_model.value` regardless of which branch ran.
print("Select which model you want to load:")
print("1. PyTorch CNN (.pt file)")
print("2. Keras Model (.h5 file)")

model_choice = input("Enter 1 or 2: ").strip()

model = None
# Rebind on every run so a stale broadcast from an earlier choice is never reused.
broadcast_model = None

if model_choice == '1':
    try:
        model = CNN()
        # map_location keeps the weights on CPU even if they were saved from a GPU run.
        model.load_state_dict(torch.load(MODEL_PATH_PYTORCH, map_location=torch.device('cpu')))
        model.eval()
        broadcast_model = spark.sparkContext.broadcast(model)
        print("PyTorch CNN model loaded successfully.")
    except Exception as e:
        print(f"Error loading PyTorch model: {e}")
        model = None
        broadcast_model = None

elif model_choice == '2':
    try:
        # Imported lazily so the PyTorch path does not require TensorFlow to be installed.
        # NOTE(review): assumes the .h5 file was saved with tf.keras - confirm.
        from tensorflow.keras.models import load_model

        model = load_model(MODEL_PATH_TENSORFLOW)
        print("Keras model loaded successfully.")
    except Exception as e:
        print(f"Error loading Keras model: {e}")
        model = None

else:
    print("Invalid input. Please enter 1 or 2.")
    model = None

if broadcast_model is None:
    # cnn_fraud_detector feeds torch tensors to the broadcast model, so only the PyTorch
    # CNN is ever broadcast. In every other case broadcast None, which the UDF treats as
    # "no model available" and answers False for every row.
    broadcast_model = spark.sparkContext.broadcast(None)

Select which model you want to load:
1. PyTorch CNN (.pt file)
2. Keras Model (.h5 file)


Enter 1 or 2:  1


PyTorch CNN model loaded successfully.


Function to preprocess a base64-encoded image before passing it to the model as input

In [31]:
def preprocess_image(base64_str):
    """Turn a base64-encoded image into a model-ready float tensor.

    The image is decoded, forced to RGB, resized to 128x128, scaled to [0, 1]
    and laid out as (1, 3, 128, 128).

    Args:
        base64_str: Base64 text (or bytes) holding an encoded image file.

    Returns:
        A float32 ``torch.Tensor`` of shape (1, 3, 128, 128), or ``None`` if any
        step fails (the failure reason is printed).
    """
    try:
        raw_bytes = base64.b64decode(base64_str)
        rgb = Image.open(io.BytesIO(raw_bytes)).convert("RGB").resize((128, 128))

        # uint8 pixels -> float64 in [0, 1], still height x width x channel
        pixels = np.array(rgb) / 255.0
        # channels first, then a leading batch axis of size 1
        batch = pixels.transpose((2, 0, 1))[np.newaxis, ...]

        return torch.from_numpy(batch).float()
    except Exception as err:
        print(f"Preprocessing failed: {err}")
        return None


DEFINING AND REGISTERING THE UDF IN SPARK

In [32]:
# def fraud_detector(base64_str):
#     if model is None:
#         print("Model not loaded, cannot perform prediction.")
#         return False 

#     image_tensor = preprocess_image(base64_str)
#     if image_tensor is None:
#         return False
#     try:
#         with torch.no_grad():
#             prediction = model(image_tensor)
        
#         return bool(prediction[0][0] > 0.5)  
#     except Exception as e:
#         print(f"Prediction failed: {e}")
#         return False


@pandas_udf(BooleanType())
def cnn_fraud_detector(image_col: pd.Series) -> pd.Series:
    """Classify a batch of base64-encoded images as fraudulent (True) or not (False).

    Args:
        image_col: One batch of base64 image strings.

    Returns:
        A boolean Series with one entry per input row. An entry is True when the
        model's first output for that image is greater than 0.5. It is False when
        no model was broadcast, or when the image could not be preprocessed or scored.
    """
    mdl = broadcast_model.value

    # Without a model nothing can be scored: answer False for the whole batch at once
    # instead of re-checking on every row.
    if mdl is None:
        return pd.Series(False, index=image_col.index, dtype=bool)

    results = []
    with torch.no_grad():
        for base64_str in image_col:
            try:
                # Shared helper: returns None (after printing why) on a bad image.
                image_tensor = preprocess_image(base64_str)
                if image_tensor is None:
                    results.append(False)
                    continue

                prediction = mdl(image_tensor)
                results.append(bool(prediction[0][0] > 0.5))
            except Exception as e:
                # Best effort: one bad row must not fail the whole Spark task.
                print(f"Prediction failed: {e}")
                results.append(False)

    # Explicit dtype keeps an empty batch boolean rather than an untyped Series.
    return pd.Series(results, index=image_col.index, dtype=bool)


In [33]:
# cnn_fraud_detector is already a Spark UDF (created by @pandas_udf). Wrapping it in
# udf() again would build a row-at-a-time, string-returning UDF around it, which cannot
# run on the executors. Keep `fraud_udf` as a plain alias for DataFrame-API use.
fraud_udf = cnn_fraud_detector

# Make the detector callable from Spark SQL as cnn_fraud_udf(<column>).
spark.udf.register("cnn_fraud_udf", cnn_fraud_detector)

25/04/29 00:18:28 WARN SimpleFunctionRegistry: The function cnn_fraud_udf replaced a previously registered function.


<pyspark.sql.udf.UserDefinedFunction at 0x155458899fd0>

# Schema of tables

idimage

In [34]:
# LIMIT 0 yields an empty result that still carries the view's columns,
# which is all printSchema() needs.
idimage_schema_sql = """
    SELECT * FROM idimage LIMIT 0;
"""
idimage_schema = spark.sql(idimage_schema_sql)
idimage_schema.printSchema()


root
 |-- name: string (nullable = true)
 |-- imageData: string (nullable = true)



idlabel

In [35]:
# LIMIT 0 yields an empty result that still carries the view's columns,
# which is all printSchema() needs.
idlabel_schema_sql = """
    SELECT * FROM idlabel LIMIT 0;
"""
idlabel_schema = spark.sql(idlabel_schema_sql)
idlabel_schema.printSchema()

root
 |-- id: string (nullable = true)
 |-- isfraud: boolean (nullable = true)
 |-- fraudpattern: string (nullable = true)
 |-- srcvalue: string (nullable = true)
 |-- srcfontstyle: string (nullable = true)
 |-- srcfontsize: string (nullable = true)
 |-- srcfontcolor: string (nullable = true)
 |-- srcbbox: string (nullable = true)
 |-- desvalue: string (nullable = true)
 |-- desfontstyle: string (nullable = true)
 |-- desfontsize: string (nullable = true)
 |-- desfontcolor: string (nullable = true)
 |-- desbbox: string (nullable = true)
 |-- srcname: string (nullable = true)
 |-- srcregionvalue: string (nullable = true)
 |-- srcregionfontstyle: string (nullable = true)
 |-- srcregionfontsize: string (nullable = true)
 |-- srcregionfontcolor: string (nullable = true)
 |-- srcregionbbox: string (nullable = true)
 |-- srcshift: string (nullable = true)
 |-- desname: string (nullable = true)
 |-- desregionvalue: string (nullable = true)
 |-- desregionfontstyle: string (nullable = true)
 |-

idmeta

In [36]:
# LIMIT 0 yields an empty result that still carries the view's columns,
# which is all printSchema() needs.
idmeta_schema_sql = """
    SELECT * FROM idmeta LIMIT 0;
"""
idmeta_schema = spark.sql(idmeta_schema_sql)
idmeta_schema.printSchema()

root
 |-- id: string (nullable = true)
 |-- name: string (nullable = true)
 |-- address: string (nullable = true)
 |-- birthday: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- ethnicity: string (nullable = true)
 |-- class: string (nullable = true)
 |-- issue_date: string (nullable = true)
 |-- expire_date: string (nullable = true)
 |-- height: string (nullable = true)
 |-- weight: string (nullable = true)
 |-- eye_color: string (nullable = true)
 |-- hair_color: string (nullable = true)
 |-- is_donor: string (nullable = true)
 |-- is_veteran: string (nullable = true)
 |-- license_number: string (nullable = true)



# SQL QUERIES TO GET INSIGHTS FROM DATASETS

Total IDs and Predicted Fraud Percentage

In [37]:
# THIS TAKES A LOT OF TIME TO RUN AND MEMORY AS WELL BECAUSE FOR EACH SUM IT RUNS MODEL INFERENCE
# spark.sql("""
# SELECT 
#     COUNT(*) AS total_ids,
#     SUM(CASE WHEN fraud_udf(imageData) THEN 1 ELSE 0 END) AS fraud_predicted,
#     (SUM(CASE WHEN fraud_udf(imageData) THEN 1 ELSE 0 END) * 100.0) / COUNT(*) AS fraud_rate_percentage
# FROM idimage
# """).show()

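ADDING A PREDICTION COLUMN (COMPUTED BY THE UDF) TO THE TABLE FOR FASTER ANSWER RETRIEVAL. Note that the column is lazy: predictions are only reused across queries if the resulting DataFrame is cached.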

In [38]:
# Attach the model's verdict to every image row.
# withColumn is lazy, so on its own each later query against this frame would run model
# inference again. cache() keeps the computed rows after the first action so that
# subsequent queries reuse the stored predictions.
idimage_with_pred = idimage_df.withColumn(
    "predicted_fraud",
    cnn_fraud_detector(col("imageData"))
).cache()

In [39]:
# Publish the prediction-augmented frame to Spark SQL as the view "idimage_pred".
idimage_with_pred.createOrReplaceTempView(name="idimage_pred")

In [40]:
# Row count, number of rows the model flagged, and that share as a percentage.
print("================ Total IDs and Predicted Fraud Percentage ================")
predicted_fraud_rate_sql = """
SELECT 
    COUNT(*) AS total_ids,
    SUM(CASE WHEN predicted_fraud THEN 1 ELSE 0 END) AS fraud_predicted,
    (SUM(CASE WHEN predicted_fraud THEN 1 ELSE 0 END) * 100.0) / COUNT(*) AS fraud_rate_percentage
FROM idimage_pred
"""
spark.sql(predicted_fraud_rate_sql).show()



[Stage 16:>                                                         (0 + 4) / 4]

+---------+---------------+---------------------+
|total_ids|fraud_predicted|fraud_rate_percentage|
+---------+---------------+---------------------+
|      998|            199|    19.93987975951904|
+---------+---------------+---------------------+



                                                                                

In [41]:

# Label-table totals: all rows, rows labelled fraud, rows labelled not fraud,
# and the labelled-fraud share as a percentage.
print("================ Total Fraudulent vs Non-Fraudulent IDs (Ground Truth) ================")
ground_truth_sql = """
    SELECT 
        COUNT(*) AS total_ids,
        SUM(CASE WHEN isfraud THEN 1 ELSE 0 END) AS total_fraud,
        SUM(CASE WHEN NOT isfraud THEN 1 ELSE 0 END) AS total_nonfraud,
        (SUM(CASE WHEN isfraud THEN 1 ELSE 0 END) * 100.0) / COUNT(*) AS fraud_percentage
    FROM idlabel
"""
spark.sql(ground_truth_sql).show()

+---------+-----------+--------------+------------------+
|total_ids|total_fraud|total_nonfraud|  fraud_percentage|
+---------+-----------+--------------+------------------+
|      199|        199|             0|100.00000000000000|
+---------+-----------+--------------+------------------+



In [42]:

# Count labelled-fraud rows per fraud pattern, most frequent pattern first.
print("================ Most Common Fraud Patterns ================")
fraud_pattern_sql = """
    SELECT 
        fraudpattern, 
        COUNT(*) AS pattern_count
    FROM idlabel
    WHERE isfraud = TRUE
    GROUP BY fraudpattern
    ORDER BY pattern_count DESC"""
spark.sql(fraud_pattern_sql).show()

+--------------------+-------------+
|        fraudpattern|pattern_count|
+--------------------+-------------+
|Fraud6_crop_and_r...|          100|
|Fraud5_inpaint_an...|           99|
+--------------------+-------------+



In [48]:
# giving empty table
# print("""================  Fraud Rate Gender Wise ================""")
# spark.sql("""
#     SELECT 
#         m.gender,
#         COUNT(*) AS total,
#         SUM(CASE WHEN l.isfraud THEN 1 ELSE 0 END) AS fraud_count,
#     FROM idmeta m
#     JOIN idlabel l ON m.id = l.id
#     GROUP BY m.gender
#     ORDER BY fraud_rate DESC
# """).show()

# +------+-----+-----------+----------+
# |gender|total|fraud_count|fraud_rate|
# +------+-----+-----------+----------+
# +------+-----+-----------+----------+


In [49]:
# print("""================  Ground Truth v/s Prediction ================""")
# spark.sql("""
#     SELECT 
#         m.id, m.name, l.isfraud, p.predicted_fraud
#     FROM idmeta m
#     JOIN idlabel l ON m.id = l.id
#     JOIN idimage_pred p ON m.name = p.name
#     WHERE l.isfraud <> p.predicted_fraud
# """).show()

# getting empty output
# """
# +---+----+-------+---------------+
# | id|name|isfraud|predicted_fraud|
# +---+----+-------+---------------+
# +---+----+-------+---------------+
# """