In [1]:
import os
import io
import logging
from kafka import KafkaProducer
from PIL import Image
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, struct, udf, pandas_udf, PandasUDFType
from pyspark.sql.types import ArrayType, FloatType, StringType
import segmentation_models_pytorch
import torch
from torchvision import transforms
import cv2
import numpy as np
import pandas as pd
from pyspark.ml.functions import predict_batch_udf
from pymongo import MongoClient
import gridfs

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Set up logging (Driver-Side)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

python_version = 'D:/Anaconda/envs/cuda/python.exe'
os.environ['PYSPARK_PYTHON'] = python_version
os.environ['PYSPARK_DRIVER_PYTHON'] = python_version

In [3]:
# Kafka and Spark configuration (Driver-Side)
kafka_server = 'localhost:9092'
topic_name = 'RandomImage'
scala_version = '2.12'
spark_version = '3.5.1'
packages = [
    f'org.apache.spark:spark-sql-kafka-0-10_{scala_version}:{spark_version}', 
    'org.apache.kafka:kafka-clients:3.7.0',
    'org.mongodb.spark:mongo-spark-connector_2.12:2.4.1'  # Use an older, stable version
]

# Spark session setup (Driver-Side)
logger.info("Initializing Spark session on driver.")
spark = SparkSession.builder.master("local[*]") \
                    .appName("kafka-example") \
                    .config("spark.executor.memory", "32g") \
                    .config("spark.driver.memory", "32g") \
                    .config("spark.executor.cores", "8") \
                    .config("spark.jars.packages", ",".join(packages)) \
                    .getOrCreate()

INFO:__main__:Initializing Spark session on driver.


In [4]:

# Model and transformation setup (Driver-Side)
MODEL_LOAD_PATH = './models/abhi_sudo_full_2-pretrained_preproc_1-unet+-ep_13-0.01-24.pt'


In [5]:

ENCODER_DEPTH = 5
DECODER_CHANNELS = (256, 128, 64, 32, 16)
BATCH_SIZE = 8
LR = 1e-2
EPOCHS = 50
ENCODER_NAME = 'resnet34'
CLASSES = {'Background': 0, 'Building-flooded': 1, 'Building-non-flooded': 2, 'Road-flooded': 3, 'Road-non-flooded': 4,
           'Water': 5, 'Tree': 6, 'Vehicle': 7, 'Pool': 8, 'Grass': 9}
IMG_DIM = 512

In [6]:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = segmentation_models_pytorch.Unet(encoder_name=ENCODER_NAME, encoder_depth=ENCODER_DEPTH,
                                         decoder_channels=DECODER_CHANNELS, classes=len(CLASSES))
model = model.to(device)
model.load_state_dict(torch.load(MODEL_LOAD_PATH, map_location=torch.device('cuda')))
model.eval()

Unet(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

In [7]:

val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Broadcast the model and transformation needed on the workers (Driver-Side)
broadcast_model = spark.sparkContext.broadcast(model)
broadcast_val_transform = spark.sparkContext.broadcast(val_transform)

In [8]:

# Functions for image transformation and mask processing (Worker-Side)
def reverse_transform_mask(inp):
    inp = inp.transpose((1, 2, 0))
    t_mask = np.argmax(inp, axis=2).astype('float32')
    t_mask = cv2.resize(t_mask, dsize=(4000, 3000))
    kernel = np.ones((3, 3), np.uint8)
    t_mask = cv2.erode(t_mask, kernel, iterations=1)
    return t_mask

# Pandas UDF to transform images (Worker-Side)
@pandas_udf(ArrayType(FloatType()), PandasUDFType.SCALAR)
def transform_image_pandas_udf(image_bytes_series: pd.Series) -> pd.Series:
    logger.info("Starting image transformation with Pandas UDF")
    def transform_image_single(image_bytes):
        img = Image.open(io.BytesIO(image_bytes))
        img = img.resize((IMG_DIM, IMG_DIM))
        img = np.array(img)
        kernel = np.ones((2, 2), np.uint8)
        img = cv2.bilateralFilter(img, 5, 75, 75)
        img = cv2.erode(cv2.dilate(img, kernel, iterations=2), kernel, iterations=1)
        transformed_img = broadcast_val_transform.value(img)
        return transformed_img.numpy().flatten().astype(np.float32)
    
    return image_bytes_series.apply(transform_image_single)

# Predict function for batch processing (Worker-Side)
def predict_batch_fn():
    model = broadcast_model.value
    model.eval()
    def predict(inputs: np.ndarray):
        inputs = inputs.reshape(-1, 3, IMG_DIM, IMG_DIM)
        inputs_tensor = torch.from_numpy(inputs).to(device).float()
        with torch.no_grad():
            preds = model(inputs_tensor)
            preds = torch.sigmoid(preds).cpu().numpy()
        return preds.reshape(inputs.shape[0], -1)
    return predict

image_predict_udf = predict_batch_udf(
    predict_batch_fn,
    return_type=ArrayType(FloatType()),
    batch_size=BATCH_SIZE,
    input_tensor_shapes=[[3 * IMG_DIM * IMG_DIM]]
)

# UDF for transforming the mask and saving to MongoDB (Worker-Side)
def save_to_mongo(key, predictions):
    logger.info(f"Saving to MongoDB for key: {key}")
    predictions = np.array(predictions).reshape(len(CLASSES), IMG_DIM, IMG_DIM)
    f_mask = reverse_transform_mask(predictions)
    
    # Convert mask to bytes and save to GridFS (Worker-Side)
    mask_bytes = io.BytesIO()
    np.save(mask_bytes, f_mask)
    mask_bytes.seek(0)

    # Setup MongoDB connection inside the worker function (Worker-Side)
    client = MongoClient('localhost', 27017)
    db = client['image_db']
    fs = gridfs.GridFS(db)
    collection = db['image_masks']

    mask_id = fs.put(mask_bytes, filename=f"{key}_mask.npy")

    # Save to MongoDB (Worker-Side)
    image_doc = {
        "image_id": key,
        "mask_gridfs_id": mask_id
    }
    collection.insert_one(image_doc)

    logger.info(f"Successfully saved to MongoDB for key: {key}")
    return key



In [9]:

save_to_mongo_udf = udf(save_to_mongo, StringType())

In [10]:

# Batch processing function (Driver-Side)
def foreach_batch_function(df, epoch_id):
    logger.info(f"Foreach Batch Function called on Epoch ID: {epoch_id}.")
    if df.count() > 0:
        logger.info(f"Processing {df.count()} records in the batch.")

        # Transform the image (Worker-Side)
        df = df.withColumn("transformed_image_bytes", transform_image_pandas_udf(col("image_bytes")))

        # Make predictions (Worker-Side)
        df = df.withColumn("predictions", image_predict_udf("transformed_image_bytes"))

        # Save predictions to MongoDB (Worker-Side)
        df = df.withColumn("saved_to_mongo", save_to_mongo_udf(col("key"), col("predictions")))

        df.show()

    else:
        logger.info("Empty DataFrame received in foreachBatch")


In [11]:

# Clean the checkpoint directory (Driver-Side)
import shutil
shutil.rmtree("checkpoint_dir", ignore_errors=True)


In [12]:

# Create DataFrame from Kafka stream (Driver-Side)
logger.info("Setting up Kafka stream on driver.")
streamRawDf = spark.readStream.format("kafka") \
    .option("kafka.bootstrap.servers", kafka_server) \
    .option("subscribe", topic_name) \
    .option("startingOffsets", "latest") \
    .load()

streamDF = streamRawDf.select(col("key").cast("string"), col("value").alias("image_bytes"))

INFO:__main__:Setting up Kafka stream on driver.


In [13]:

# Start the streaming query (Driver-Side)
logger.info("Starting the streaming query on driver.")
query = streamDF.writeStream \
    .foreachBatch(foreach_batch_function) \
    .option("checkpointLocation", "checkpoint_dir") \
    .start()

# Wait for the termination of the query (Driver-Side)
query.awaitTermination()

INFO:__main__:Starting the streaming query on driver.
INFO:py4j.java_gateway:Callback Server Starting
INFO:py4j.java_gateway:Socket listening on ('127.0.0.1', 53097)
INFO:py4j.clientserver:Python Server ready to receive messages
INFO:py4j.clientserver:Received command c on object id p0
INFO:__main__:Foreach Batch Function called on Epoch ID: 0.
INFO:__main__:Empty DataFrame received in foreachBatch
INFO:py4j.clientserver:Received command c on object id p0
INFO:__main__:Foreach Batch Function called on Epoch ID: 1.
INFO:__main__:Processing 1 records in the batch.


+----+--------------------+-----------------------+--------------------+--------------+
| key|         image_bytes|transformed_image_bytes|         predictions|saved_to_mongo|
+----+--------------------+-----------------------+--------------------+--------------+
|7413|[FF D8 FF E0 00 1...|   [-1.2616663, -1.2...|[0.066137396, 0.0...|          7413|
+----+--------------------+-----------------------+--------------------+--------------+



INFO:py4j.clientserver:Received command c on object id p0
INFO:__main__:Foreach Batch Function called on Epoch ID: 2.
INFO:__main__:Processing 3 records in the batch.


+----+--------------------+-----------------------+--------------------+--------------+
| key|         image_bytes|transformed_image_bytes|         predictions|saved_to_mongo|
+----+--------------------+-----------------------+--------------------+--------------+
|7415|[FF D8 FF E0 00 1...|   [-1.1246684, -1.1...|[0.055927064, 0.0...|          7415|
|7420|[FF D8 FF E0 00 1...|   [-0.5253019, -0.5...|[0.067052126, 0.0...|          7420|
|7423|[FF D8 FF E0 00 1...|   [-1.3815396, -1.3...|[0.08092851, 0.05...|          7423|
+----+--------------------+-----------------------+--------------------+--------------+



INFO:py4j.clientserver:Received command c on object id p0
INFO:__main__:Foreach Batch Function called on Epoch ID: 3.
INFO:__main__:Processing 4 records in the batch.


+----+--------------------+-----------------------+--------------------+--------------+
| key|         image_bytes|transformed_image_bytes|         predictions|saved_to_mongo|
+----+--------------------+-----------------------+--------------------+--------------+
|7431|[FF D8 FF E0 00 1...|   [-0.5253019, -0.5...|[0.08983604, 0.05...|          7431|
|7450|[FF D8 FF E0 00 1...|   [0.31381115, 0.31...|[0.12111328, 0.09...|          7450|
|7457|[FF D8 FF E0 00 1...|   [-1.0219197, -1.0...|[0.047142897, 0.0...|          7457|
|7461|[FF D8 FF E0 00 1...|   [0.15968838, 0.15...|[0.08039708, 0.04...|          7461|
+----+--------------------+-----------------------+--------------------+--------------+



INFO:py4j.clientserver:Received command c on object id p0
INFO:__main__:Foreach Batch Function called on Epoch ID: 4.
INFO:__main__:Processing 4 records in the batch.


+----+--------------------+-----------------------+--------------------+--------------+
| key|         image_bytes|transformed_image_bytes|         predictions|saved_to_mongo|
+----+--------------------+-----------------------+--------------------+--------------+
|7464|[FF D8 FF E0 00 1...|   [0.60493195, 0.60...|[0.06619186, 0.03...|          7464|
|7476|[FF D8 FF E0 00 1...|   [-1.6041614, -1.6...|[0.057231933, 0.0...|          7476|
|7486|[FF D8 FF E0 00 1...|   [0.15968838, 0.15...|[0.07117865, 0.04...|          7486|
|7541|[FF D8 FF E0 00 1...|   [-1.0219197, -1.0...|[0.06371592, 0.03...|          7541|
+----+--------------------+-----------------------+--------------------+--------------+



INFO:py4j.clientserver:Received command c on object id p0
INFO:__main__:Foreach Batch Function called on Epoch ID: 5.
INFO:__main__:Processing 4 records in the batch.


+----+--------------------+-----------------------+--------------------+--------------+
| key|         image_bytes|transformed_image_bytes|         predictions|saved_to_mongo|
+----+--------------------+-----------------------+--------------------+--------------+
|7543|[FF D8 FF E0 00 1...|   [-0.5424266, -0.5...|[0.05732034, 0.02...|          7543|
|7577|[FF D8 FF E0 00 1...|   [-0.38830382, -0....|[0.086897716, 0.0...|          7577|
|7581|[FF D8 FF E0 00 1...|   [-0.6622999, -0.6...|[0.04989769, 0.02...|          7581|
|7583|[FF D8 FF E0 00 1...|   [1.0844251, 1.084...|[0.072445035, 0.0...|          7583|
+----+--------------------+-----------------------+--------------------+--------------+



INFO:py4j.clientserver:Received command c on object id p0
INFO:__main__:Foreach Batch Function called on Epoch ID: 6.
INFO:__main__:Processing 4 records in the batch.
