In [1]:
import os
import io
import logging
from kafka import KafkaProducer
from PIL import Image
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, struct, udf, pandas_udf, PandasUDFType
from pyspark.sql.types import ArrayType, FloatType, StringType
import segmentation_models_pytorch
import torch
from torchvision import transforms
import cv2
import numpy as np
import pandas as pd
from pyspark.ml.functions import predict_batch_udf
from pymongo import MongoClient
import gridfs

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Set up logging (Driver-Side)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

python_version = 'D:/Anaconda/envs/cuda/python.exe'
os.environ['PYSPARK_PYTHON'] = python_version
os.environ['PYSPARK_DRIVER_PYTHON'] = python_version

In [3]:
# Kafka and Spark configuration (Driver-Side)
kafka_server = 'localhost:9092'
topic_name = 'RandomImage'
scala_version = '2.12'
spark_version = '3.5.1'
packages = [
    f'org.apache.spark:spark-sql-kafka-0-10_{scala_version}:{spark_version}', 
    'org.apache.kafka:kafka-clients:3.7.0',
    'org.mongodb.spark:mongo-spark-connector_2.12:2.4.1'  # Use an older, stable version
]

# Spark session setup (Driver-Side)
logger.info("Initializing Spark session on driver.")
spark = SparkSession.builder.master("local[*]") \
                    .appName("kafka-example") \
                    .config("spark.executor.memory", "32g") \
                    .config("spark.driver.memory", "32g") \
                    .config("spark.executor.cores", "8") \
                    .config("spark.jars.packages", ",".join(packages)) \
                    .getOrCreate()

INFO:__main__:Initializing Spark session on driver.


In [4]:

# Model and transformation setup (Driver-Side)

MODEL_LOAD_PATH = f'./models/supreme/abhi_sudo_full_2-pretrained_preproc_1-deeplabv3+-ep_8-0.001-30.pt'


In [5]:
# ####### DEEPLAB V3+ CONFIG #######
ENCODER_DEPTH=5
DECODER_CHANNELS=256
BATCH_SIZE= 16
EPOCHS= 100
LR = [1e-3]
ENCODER_NAME= 'resnet50'

CLASSES={'Background':0,'Building-flooded':1,'Building-non-flooded':2,'Road-flooded':3,'Road-non-flooded':4,
         'Water':5,'Tree':6,'Vehicle':7,'Pool':8,'Grass':9}
IMG_DIM= 512

In [6]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [7]:
model = segmentation_models_pytorch.DeepLabV3Plus(encoder_name=ENCODER_NAME, encoder_depth=ENCODER_DEPTH,
                                              decoder_channels=DECODER_CHANNELS, classes=len(CLASSES))


model = model.to(device)
model.load_state_dict(torch.load(MODEL_LOAD_PATH,map_location=torch.device('cuda')))
# model = model.to("cpu")
model.eval()

DeepLabV3Plus(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequentia

In [8]:

val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Broadcast the model and transformation needed on the workers (Driver-Side)
broadcast_model = spark.sparkContext.broadcast(model)
broadcast_val_transform = spark.sparkContext.broadcast(val_transform)

In [9]:

# Functions for image transformation and mask processing (Worker-Side)
def reverse_transform_mask(inp):
    inp = inp.transpose((1, 2, 0))
    t_mask = np.argmax(inp, axis=2).astype('float32')
    t_mask = cv2.resize(t_mask, dsize=(4000, 3000))
    kernel = np.ones((3, 3), np.uint8)
    t_mask = cv2.erode(t_mask, kernel, iterations=1)
    return t_mask

# Pandas UDF to transform images (Worker-Side)
@pandas_udf(ArrayType(FloatType()), PandasUDFType.SCALAR)
def transform_image_pandas_udf(image_bytes_series: pd.Series) -> pd.Series:
    logger.info("Starting image transformation with Pandas UDF")
    def transform_image_single(image_bytes):
        img = Image.open(io.BytesIO(image_bytes))
        img = img.resize((IMG_DIM, IMG_DIM))
        img = np.array(img)
        kernel = np.ones((2, 2), np.uint8)
        img = cv2.bilateralFilter(img, 5, 75, 75)
        img = cv2.erode(cv2.dilate(img, kernel, iterations=2), kernel, iterations=1)
        transformed_img = broadcast_val_transform.value(img)
        return transformed_img.numpy().flatten().astype(np.float32)
    
    return image_bytes_series.apply(transform_image_single)

# Predict function for batch processing (Worker-Side)
def predict_batch_fn():
    model = broadcast_model.value
    model.eval()
    def predict(inputs: np.ndarray):
        inputs = inputs.reshape(-1, 3, IMG_DIM, IMG_DIM)
        inputs_tensor = torch.from_numpy(inputs).to(device).float()
        with torch.no_grad():
            preds = model(inputs_tensor)
            preds = torch.sigmoid(preds).cpu().numpy()
        return preds.reshape(inputs.shape[0], -1)
    return predict

image_predict_udf = predict_batch_udf(
    predict_batch_fn,
    return_type=ArrayType(FloatType()),
    batch_size=BATCH_SIZE,
    input_tensor_shapes=[[3 * IMG_DIM * IMG_DIM]]
)

# UDF for transforming the mask and saving to MongoDB (Worker-Side)
def save_to_mongo(key, predictions):
    logger.info(f"Saving to MongoDB for key: {key}")
    predictions = np.array(predictions).reshape(len(CLASSES), IMG_DIM, IMG_DIM)
    f_mask = reverse_transform_mask(predictions)
    
    # Convert mask to bytes and save to GridFS (Worker-Side)
    mask_bytes = io.BytesIO()
    np.save(mask_bytes, f_mask)
    mask_bytes.seek(0)

    # Setup MongoDB connection inside the worker function (Worker-Side)
    client = MongoClient('localhost', 27017)
    db = client['image_db']
    fs = gridfs.GridFS(db)
    collection = db['image_masks']

    mask_id = fs.put(mask_bytes, filename=f"{key}_mask.npy")

    # Save to MongoDB (Worker-Side)
    image_doc = {
        "image_id": key,
        "mask_gridfs_id": mask_id
    }
    collection.insert_one(image_doc)

    logger.info(f"Successfully saved to MongoDB for key: {key}")
    return key



In [10]:

save_to_mongo_udf = udf(save_to_mongo, StringType())

In [11]:

# Batch processing function (Driver-Side)
def foreach_batch_function(df, epoch_id):
    logger.info(f"Foreach Batch Function called on Epoch ID: {epoch_id}.")
    if df.count() > 0:
        logger.info(f"Processing {df.count()} records in the batch.")

        # Transform the image (Worker-Side)
        df = df.withColumn("transformed_image_bytes", transform_image_pandas_udf(col("image_bytes")))

        # Make predictions (Worker-Side)
        df = df.withColumn("predictions", image_predict_udf("transformed_image_bytes"))

        # Save predictions to MongoDB (Worker-Side)
        df = df.withColumn("saved_to_mongo", save_to_mongo_udf(col("key"), col("predictions")))

        df.show()

    else:
        logger.info("Empty DataFrame received in foreachBatch")


In [12]:

# Clean the checkpoint directory (Driver-Side)
import shutil
shutil.rmtree("checkpoint_dir", ignore_errors=True)


In [13]:

# Create DataFrame from Kafka stream (Driver-Side)
logger.info("Setting up Kafka stream on driver.")
streamRawDf = spark.readStream.format("kafka") \
    .option("kafka.bootstrap.servers", kafka_server) \
    .option("subscribe", topic_name) \
    .option("startingOffsets", "latest") \
    .load()

streamDF = streamRawDf.select(col("key").cast("string"), col("value").alias("image_bytes"))

INFO:__main__:Setting up Kafka stream on driver.


In [14]:

# Start the streaming query (Driver-Side)
logger.info("Starting the streaming query on driver.")
query = streamDF.writeStream \
    .foreachBatch(foreach_batch_function) \
    .option("checkpointLocation", "checkpoint_dir") \
    .start()

# Wait for the termination of the query (Driver-Side)
query.awaitTermination()

INFO:__main__:Starting the streaming query on driver.
INFO:py4j.java_gateway:Callback Server Starting
INFO:py4j.java_gateway:Socket listening on ('127.0.0.1', 57105)
INFO:py4j.clientserver:Python Server ready to receive messages
INFO:py4j.clientserver:Received command c on object id p0
INFO:__main__:Foreach Batch Function called on Epoch ID: 0.
INFO:__main__:Empty DataFrame received in foreachBatch
INFO:py4j.clientserver:Received command c on object id p0
INFO:__main__:Foreach Batch Function called on Epoch ID: 1.
INFO:__main__:Processing 1 records in the batch.


+-----+--------------------+-----------------------+--------------------+--------------+
|  key|         image_bytes|transformed_image_bytes|         predictions|saved_to_mongo|
+-----+--------------------+-----------------------+--------------------+--------------+
|10163|[FF D8 FF E0 00 1...|   [-0.2170563, -0.2...|[0.050401937, 0.0...|         10163|
+-----+--------------------+-----------------------+--------------------+--------------+



INFO:py4j.clientserver:Received command c on object id p0
INFO:__main__:Foreach Batch Function called on Epoch ID: 2.
INFO:__main__:Processing 2 records in the batch.


+-----+--------------------+-----------------------+--------------------+--------------+
|  key|         image_bytes|transformed_image_bytes|         predictions|saved_to_mongo|
+-----+--------------------+-----------------------+--------------------+--------------+
|10164|[FF D8 FF E0 00 1...|   [-1.6384109, -1.6...|[0.056595802, 0.0...|         10164|
|10167|[FF D8 FF E0 00 1...|   [-0.6280504, -0.6...|[0.041314233, 0.0...|         10167|
+-----+--------------------+-----------------------+--------------------+--------------+



INFO:py4j.clientserver:Received command c on object id p0
INFO:__main__:Foreach Batch Function called on Epoch ID: 3.
INFO:__main__:Processing 4 records in the batch.


+-----+--------------------+-----------------------+--------------------+--------------+
|  key|         image_bytes|transformed_image_bytes|         predictions|saved_to_mongo|
+-----+--------------------+-----------------------+--------------------+--------------+
|10174|[FF D8 FF E0 00 1...|   [0.3651854, 0.365...|[0.068720914, 0.0...|         10174|
|10183|[FF D8 FF E0 00 1...|   [-1.004795, -1.00...|[0.041440956, 0.0...|         10183|
|10808|[FF D8 FF E0 00 1...|   [0.056939743, 0.0...|[0.056107644, 0.0...|         10808|
|10812|[FF D8 FF E0 00 1...|   [-0.06293353, -0....|[0.048254177, 0.0...|         10812|
+-----+--------------------+-----------------------+--------------------+--------------+



INFO:py4j.clientserver:Received command c on object id p0
INFO:__main__:Foreach Batch Function called on Epoch ID: 4.
INFO:__main__:Processing 4 records in the batch.


+-----+--------------------+-----------------------+--------------------+--------------+
|  key|         image_bytes|transformed_image_bytes|         predictions|saved_to_mongo|
+-----+--------------------+-----------------------+--------------------+--------------+
|10813|[FF D8 FF E0 00 1...|   [-0.2170563, -0.2...|[0.040808864, 0.0...|         10813|
|10814|[FF D8 FF E0 00 1...|   [0.005565486, 0.0...|[0.06838085, 0.06...|         10814|
|10823|[FF D8 FF E0 00 1...|   [0.70768046, 0.70...|[0.060613256, 0.0...|         10823|
|10829|[FF D8 FF E0 00 1...|   [-0.02868402, -0....|[0.04760438, 0.05...|         10829|
+-----+--------------------+-----------------------+--------------------+--------------+



INFO:py4j.clientserver:Received command c on object id p0
INFO:__main__:Foreach Batch Function called on Epoch ID: 5.
INFO:__main__:Processing 4 records in the batch.


+-----+--------------------+-----------------------+--------------------+--------------+
|  key|         image_bytes|transformed_image_bytes|         predictions|saved_to_mongo|
+-----+--------------------+-----------------------+--------------------+--------------+
|10838|[FF D8 FF E0 00 1...|   [-0.7307989, -0.7...|[0.03899097, 0.03...|         10838|
|10839|[FF D8 FF E0 00 1...|   [1.1357993, 1.135...|[0.06983887, 0.06...|         10839|
|10843|[FF D8 FF E0 00 1...|   [-1.4157891, -1.4...|[0.03605754, 0.03...|         10843|
|11483|[FF D8 FF E0 00 1...|   [-0.7650484, -0.7...|[0.05621586, 0.04...|         11483|
+-----+--------------------+-----------------------+--------------------+--------------+



INFO:py4j.clientserver:Received command c on object id p0
INFO:__main__:Foreach Batch Function called on Epoch ID: 6.
INFO:__main__:Processing 4 records in the batch.


+----+--------------------+-----------------------+--------------------+--------------+
| key|         image_bytes|transformed_image_bytes|         predictions|saved_to_mongo|
+----+--------------------+-----------------------+--------------------+--------------+
|6336|[FF D8 FF E0 00 1...|   [-0.6280504, -0.6...|[0.032775726, 0.0...|          6336|
|6342|[FF D8 FF E0 00 1...|   [-0.49105233, -0....|[0.03148911, 0.02...|          6342|
|6353|[FF D8 FF E0 00 1...|   [-0.06293353, -0....|[0.03539434, 0.03...|          6353|
|6362|[FF D8 FF E0 00 1...|   [-0.11430778, -0....|[0.05369853, 0.04...|          6362|
+----+--------------------+-----------------------+--------------------+--------------+



INFO:py4j.clientserver:Received command c on object id p0
INFO:__main__:Foreach Batch Function called on Epoch ID: 7.
INFO:__main__:Processing 4 records in the batch.


+----+--------------------+-----------------------+--------------------+--------------+
| key|         image_bytes|transformed_image_bytes|         predictions|saved_to_mongo|
+----+--------------------+-----------------------+--------------------+--------------+
|6371|[FF D8 FF E0 00 1...|   [-0.6451751, -0.6...|[0.019666988, 0.0...|          6371|
|6377|[FF D8 FF E0 00 1...|   [-1.4671633, -1.4...|[0.027309624, 0.0...|          6377|
|6383|[FF D8 FF E0 00 1...|   [-0.19993155, -0....|[0.06380591, 0.06...|          6383|
|6389|[FF D8 FF E0 00 1...|   [1.8721637, 1.872...|[0.059904538, 0.0...|          6389|
+----+--------------------+-----------------------+--------------------+--------------+



INFO:py4j.clientserver:Received command c on object id p0
INFO:__main__:Foreach Batch Function called on Epoch ID: 8.
INFO:__main__:Processing 4 records in the batch.


+----+--------------------+-----------------------+--------------------+--------------+
| key|         image_bytes|transformed_image_bytes|         predictions|saved_to_mongo|
+----+--------------------+-----------------------+--------------------+--------------+
|6391|[FF D8 FF E0 00 1...|   [0.24531215, 0.24...|[0.05138604, 0.04...|          6391|
|6394|[FF D8 FF E0 00 1...|   [0.810429, 0.8104...|[0.044096038, 0.0...|          6394|
|6405|[FF D8 FF E0 00 1...|   [0.34806067, 0.34...|[0.054290712, 0.0...|          6405|
|6412|[FF D8 FF E0 00 1...|   [0.41655967, 0.41...|[0.07816719, 0.07...|          6412|
+----+--------------------+-----------------------+--------------------+--------------+



INFO:py4j.clientserver:Received command c on object id p0
INFO:__main__:Foreach Batch Function called on Epoch ID: 9.
INFO:__main__:Processing 4 records in the batch.


+----+--------------------+-----------------------+--------------------+--------------+
| key|         image_bytes|transformed_image_bytes|         predictions|saved_to_mongo|
+----+--------------------+-----------------------+--------------------+--------------+
|6417|[FF D8 FF E0 00 1...|   [0.6220567, 0.622...|[0.05837118, 0.06...|          6417|
|6419|[FF D8 FF E0 00 1...|   [0.63918144, 0.63...|[0.07819656, 0.07...|          6419|
|6420|[FF D8 FF E0 00 1...|   [0.6220567, 0.622...|[0.08028237, 0.07...|          6420|
|6445|[FF D8 FF E0 00 1...|   [-0.91917115, -0....|[0.04051236, 0.03...|          6445|
+----+--------------------+-----------------------+--------------------+--------------+



INFO:py4j.clientserver:Received command c on object id p0
INFO:__main__:Foreach Batch Function called on Epoch ID: 10.
INFO:__main__:Processing 4 records in the batch.


+----+--------------------+-----------------------+--------------------+--------------+
| key|         image_bytes|transformed_image_bytes|         predictions|saved_to_mongo|
+----+--------------------+-----------------------+--------------------+--------------+
|6449|[FF D8 FF E0 00 1...|   [0.5535577, 0.553...|[0.06774624, 0.06...|          6449|
|6452|[FF D8 FF E0 00 1...|   [-0.06293353, -0....|[0.043316275, 0.0...|          6452|
|6467|[FF D8 FF E0 00 1...|   [-0.14855729, -0....|[0.082743004, 0.0...|          6467|
|6468|[FF D8 FF E0 00 1...|   [0.005565486, 0.0...|[0.08045129, 0.07...|          6468|
+----+--------------------+-----------------------+--------------------+--------------+



INFO:py4j.clientserver:Received command c on object id p0
INFO:__main__:Foreach Batch Function called on Epoch ID: 11.
INFO:__main__:Processing 4 records in the batch.


+----+--------------------+-----------------------+--------------------+--------------+
| key|         image_bytes|transformed_image_bytes|         predictions|saved_to_mongo|
+----+--------------------+-----------------------+--------------------+--------------+
|6476|[FF D8 FF E0 00 1...|   [0.46793392, 0.46...|[0.07418583, 0.06...|          6476|
|6488|[FF D8 FF E0 00 1...|   [-1.0219197, -1.0...|[0.046415146, 0.0...|          6488|
|6514|[FF D8 FF E0 00 1...|   [-1.3986644, -1.3...|[0.06097493, 0.05...|          6514|
|6517|[FF D8 FF E0 00 1...|   [-1.0904187, -1.0...|[0.061262812, 0.0...|          6517|
+----+--------------------+-----------------------+--------------------+--------------+



INFO:py4j.clientserver:Received command c on object id p0
INFO:__main__:Foreach Batch Function called on Epoch ID: 12.
INFO:__main__:Processing 4 records in the batch.


+----+--------------------+-----------------------+--------------------+--------------+
| key|         image_bytes|transformed_image_bytes|         predictions|saved_to_mongo|
+----+--------------------+-----------------------+--------------------+--------------+
|6536|[FF D8 FF E0 00 1...|   [-1.1075436, -1.1...|[0.028384877, 0.0...|          6536|
|6545|[FF D8 FF E0 00 1...|   [-1.2102921, -1.2...|[0.045309566, 0.0...|          6545|
|6550|[FF D8 FF E0 00 1...|   [0.24531215, 0.24...|[0.066584244, 0.0...|          6550|
|6553|[FF D8 FF E0 00 1...|   [0.74193, 0.74193...|[0.06966449, 0.06...|          6553|
+----+--------------------+-----------------------+--------------------+--------------+



INFO:py4j.clientserver:Received command c on object id p0
INFO:__main__:Foreach Batch Function called on Epoch ID: 13.
INFO:__main__:Processing 4 records in the batch.
