In [20]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import os
import time

# Build a ResNet18 backbone without pretrained weights. The stock classifier
# head is swapped for Dropout + Linear so the layer layout matches the saved
# 2-class checkpoint loaded below.
model = models.resnet18(pretrained=False)
fc_in_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(p=0.5),
    nn.Linear(fc_in_features, 2),  # two output logits, matching the saved model
)

# Restore the trained parameters onto the CPU, then switch to evaluation mode
# so the Dropout layer is inactive during inference.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
state_dict = torch.load("fraud_type_classification_resnet18_10.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# Inference-time preprocessing (no augmentation): fixed-size resize, tensor
# conversion, then per-channel normalization.
# NOTE(review): mean/std look like the usual ImageNet statistics -- confirm
# they match what was used during training.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]
inference_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
])

In [21]:
def infer_from_path(image_path: str) -> int:
    """Classify a single image file with the module-level ``model``.

    The image is opened, converted to RGB, run through ``inference_transform``
    and scored by ``model`` with gradient tracking disabled.

    Args:
        image_path: Filesystem path of the image to classify.

    Returns:
        The index of the largest model output (argmax over dim 1), or ``-1``
        if anything fails while opening, preprocessing or scoring the image.
    """
    try:
        # Image.open is lazy and keeps the underlying file open; the context
        # manager guarantees the handle is released even if decoding fails.
        # convert() loads the pixel data and returns a new image, so the
        # result remains usable after the file has been closed.
        with Image.open(image_path) as source_image:
            rgb_image = source_image.convert("RGB")
        batch = inference_transform(rgb_image).unsqueeze(0)  # shape: [1, 3, 224, 224]
        with torch.no_grad():
            logits = model(batch)
            prediction = torch.argmax(logits, dim=1).item()
        return prediction
    except Exception as e:
        # Deliberately best-effort: one unreadable image must not abort the
        # whole batch, so the failure is reported and flagged with -1.
        print(f"Failed on {image_path}: {e}")
        return -1  # error sentinel

In [22]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType, IntegerType

# Reuse the active Spark session if one exists, otherwise start a new one.
spark = SparkSession.builder.appName("FraudImagePrediction").getOrCreate()

# Folder containing the .jpg/.jpeg/.png images to classify
image_folder = "/home/itewari1/DISML/Query Execution/heavy_test_query2/"

# Collect the image file paths (case-insensitive extension match).
# os.listdir returns entries in arbitrary order, so sort for reproducible runs.
image_files = sorted(
    os.path.join(image_folder, f)
    for f in os.listdir(image_folder)
    if f.lower().endswith(('.png', '.jpg', '.jpeg'))
)

# Load image paths into a single-column Spark DataFrame.
# The schema is given explicitly because inferring it from the data raises
# when the list is empty (e.g. the folder contains no matching images).
df = spark.createDataFrame([(f,) for f in image_files], schema="image_path string")

In [23]:
# Count number of input images (this runs a Spark job, so it is done before
# the timer starts and does not pollute the measurement)
num_images = df.count()

# Register the Python function as a Spark UDF producing an integer column
infer_fraud_type_udf = udf(infer_from_path, IntegerType())

# Start timer. perf_counter() is a monotonic, high-resolution clock intended
# for measuring intervals; time.time() can jump if the system clock changes.
start_time = time.perf_counter()

# Add prediction column. withColumn is lazy: nothing is computed until the
# groupBy/count/show action below, which is what the timer actually measures.
df_predicted = df.withColumn("predicted_class", infer_fraud_type_udf(df["image_path"]))
df_predicted.groupBy("predicted_class").count().show()

# End timer
end_time = time.perf_counter()

# Calculate latency (guard against an empty input set)
total_time = end_time - start_time
latency_per_image = total_time / num_images if num_images > 0 else 0

# Print metrics
print(f"\n Total images: {num_images}")
print(f" Total inference time: {total_time:.2f} seconds")
print(f" Avg latency per image: {latency_per_image:.4f} seconds/image")



+---------------+-----+
|predicted_class|count|
+---------------+-----+
|              1| 2656|
|              0| 2344|
+---------------+-----+


 Total images: 5000
 Total inference time: 40.74 seconds
 Avg latency per image: 0.0081 seconds/image


                                                                                