In [33]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import BooleanType
from pyspark.sql.functions import udf
from pyspark.sql.types import BooleanType
from pyspark.sql.functions import pandas_udf

from models.cnn import CNN

import base64
import io
from PIL import Image
import numpy as np
import pandas as pd
import torch

[Stage 12:>                                                         (0 + 4) / 4]

In [1]:
MODEL_PATH_PYTORCH = "models/fraud_cnn.pt"
MODEL_PATH_TENSORFLOW = "models/efficientNet.h5"

## Create Spark session and load data

In [15]:
spark = SparkSession.builder.appName("FraudModelInference_Tables").getOrCreate()

In [16]:
idimage_df = spark.read.option("header", True).csv("data/idimage_fixed.csv")
idlabel_df = spark.read.option("header", True).csv("data/idlabel.csv")
idmeta_df = spark.read.option("header", True).csv("data/idmeta.csv")  
#spark action

In [17]:
#change isfraud column datatype to Bool
idlabel_df = idlabel_df.withColumn("isfraud", col("isfraud").cast(BooleanType()))   #Transformation


# Register all tables

In [18]:
idimage_df.createOrReplaceTempView("idimage")
idlabel_df.createOrReplaceTempView("idlabel")
idmeta_df.createOrReplaceTempView("idmeta")

# Load model from saved weights (CNN(100), EfficientNet (81%))

In [19]:
print("Select which model you want to load:")
print("1. PyTorch CNN (.pt file)")
print("2. Keras Model (.h5 file)")

model_choice = input("Enter 1 or 2: ").strip()

model = None  

if model_choice == '1':
    try:
        model = CNN()
        model.load_state_dict(torch.load(MODEL_PATH_PYTORCH, map_location=torch.device('cpu')))
        model.eval()
        print("PyTorch CNN model loaded successfully.")
    except Exception as e:
        print(f"Error loading PyTorch model: {e}")
        model = None

elif model_choice == '2':
    try:
        model = load_model(MODEL_PATH_KERAS)
        print("Keras model loaded successfully.")
    except Exception as e:
        print(f"Error loading Keras model: {e}")
        model = None

else:
    print("Invalid input. Please enter 1 or 2.")
    model = None

Select which model you want to load:
1. PyTorch CNN (.pt file)
2. Keras Model (.h5 file)


Enter 1 or 2:  1


PyTorch CNN model loaded successfully.


Function to preprocess imagepath before sending it to model for input

In [21]:
def preprocess_image(base64_str):
    try:
        image_data = base64.b64decode(base64_str)
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        image = image.resize((128, 128))  
        image_array = np.array(image) / 255.0  # Normalize
        image_array = np.expand_dims(image_array, axis=0)  # Add batch dim
        image_array = image_array.transpose((0, 3, 1, 2)) 
        image_tensor = torch.from_numpy(image_array).float()
        
        return image_tensor

    except Exception as e:
        print(f"Preprocessing failed: {e}")
        return None


DEFINING an REGISTERING UDF in SPARK

In [24]:
def fraud_detector(base64_str):
    if model is None:
        print("Model not loaded, cannot perform prediction.")
        return False 

    image_tensor = preprocess_image(base64_str)
    if image_tensor is None:
        return False
    try:
        with torch.no_grad():
            prediction = model(image_tensor)
        
        return bool(prediction[0][0] > 0.5)  
    except Exception as e:
        print(f"Prediction failed: {e}")
        return False




In [27]:
fraud_udf = udf(fraud_detector, BooleanType())
spark.udf.register("cnn_fraud_udf", fraud_detector, BooleanType())

<function __main__.fraud_detector(base64_str)>

# Schema of tables

idimage

In [28]:
idimage_schema = spark.sql("""
    SELECT * FROM idimage LIMIT 0;
""")

idimage_schema.printSchema()


root
 |-- name: string (nullable = true)
 |-- imageData: string (nullable = true)



idlabel

In [29]:
idlabel_schema = spark.sql("""
    SELECT * FROM idlabel LIMIT 0;
""")
idlabel_schema.printSchema()

root
 |-- id: string (nullable = true)
 |-- isfraud: boolean (nullable = true)
 |-- fraudpattern: string (nullable = true)
 |-- srcvalue: string (nullable = true)
 |-- srcfontstyle: string (nullable = true)
 |-- srcfontsize: string (nullable = true)
 |-- srcfontcolor: string (nullable = true)
 |-- srcbbox: string (nullable = true)
 |-- desvalue: string (nullable = true)
 |-- desfontstyle: string (nullable = true)
 |-- desfontsize: string (nullable = true)
 |-- desfontcolor: string (nullable = true)
 |-- desbbox: string (nullable = true)
 |-- srcname: string (nullable = true)
 |-- srcregionvalue: string (nullable = true)
 |-- srcregionfontstyle: string (nullable = true)
 |-- srcregionfontsize: string (nullable = true)
 |-- srcregionfontcolor: string (nullable = true)
 |-- srcregionbbox: string (nullable = true)
 |-- srcshift: string (nullable = true)
 |-- desname: string (nullable = true)
 |-- desregionvalue: string (nullable = true)
 |-- desregionfontstyle: string (nullable = true)
 |-

idmeta

In [30]:
idmeta_schema = spark.sql("""
    SELECT * FROM idmeta LIMIT 0;
""")
idmeta_schema.printSchema()

root
 |-- id: string (nullable = true)
 |-- name: string (nullable = true)
 |-- address: string (nullable = true)
 |-- birthday: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- ethnicity: string (nullable = true)
 |-- class: string (nullable = true)
 |-- issue_date: string (nullable = true)
 |-- expire_date: string (nullable = true)
 |-- height: string (nullable = true)
 |-- weight: string (nullable = true)
 |-- eye_color: string (nullable = true)
 |-- hair_color: string (nullable = true)
 |-- is_donor: string (nullable = true)
 |-- is_veteran: string (nullable = true)
 |-- license_number: string (nullable = true)



# SQL QUEREIS TO GET INSIGHTS FROM DATASETS

Total IDs and Predicted Fraud Percentage


In [32]:
spark.sql("""
SELECT 
    COUNT(*) AS total_ids,
    SUM(CASE WHEN fraud_udf(imageData) THEN 1 ELSE 0 END) AS fraud_predicted,
    (SUM(CASE WHEN fraud_udf(imageData) THEN 1 ELSE 0 END) * 100.0) / COUNT(*) AS fraud_rate_percentage
FROM idimage
""").show()

AnalysisException: [UNRESOLVED_ROUTINE] Cannot resolve function `fraud_udf` on search path [`system`.`builtin`, `system`.`session`, `spark_catalog`.`default`].; line 4 pos 18