In [42]:
import pyspark
print(pyspark.__version__)
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.utils import AnalysisException
from time import sleep
import threading
import os
import io
import torch
import torch.nn.functional as F
from torchvision import transforms
from kafka import KafkaProducer
from PIL import Image
import cv2
import numpy as np
import segmentation_models_pytorch
from pymongo import MongoClient
import gridfs


3.5.1


In [43]:
# MongoDB Configuration
# Masks are stored as files in GridFS; `collection` receives one document per
# processed image linking the image id to the GridFS file id.
mongo_client = MongoClient(host='localhost', port=27017)
db = mongo_client.get_database('bigdata')
collection = db.get_collection('processed_images')
fs = gridfs.GridFS(db)

In [44]:
from pyspark.sql import SparkSession

scala_version = '2.12'
spark_version = '3.5.1'
# Kafka connector built for this Scala/Spark version, plus the Kafka client library.
packages = [
    f'org.apache.spark:spark-sql-kafka-0-10_{scala_version}:{spark_version}',
    'org.apache.kafka:kafka-clients:3.7.0',
]

# NOTE(review): with master "local" the driver runs the tasks itself, so the
# spark.executor.* settings likely have no effect. Arrow also requires PyArrow,
# which the run output reports as missing (Spark falls back to non-Arrow).
spark_settings = {
    "spark.driver.memory": "16g",
    "spark.executor.memory": "16g",
    "spark.executor.cores": "8",
    "spark.sql.shuffle.partitions": "1000",
    "spark.sql.execution.arrow.pyspark.enabled": "true",
    "spark.jars.packages": ",".join(packages),
}

spark_builder = SparkSession.builder.master("local").appName("kafka-example")
for setting, value in spark_settings.items():
    spark_builder = spark_builder.config(setting, value)
spark = spark_builder.getOrCreate()

spark

In [45]:
# Kafka connection settings for the incoming image stream.
kafka_server = 'localhost:9092'
topic_name = 'RandomImage'



In [46]:
import os
import io
from time import sleep
from kafka import KafkaProducer
from PIL import Image
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
import segmentation_models_pytorch
import torch
import torch.nn.functional as F
from torchvision import transforms
import cv2
import numpy as np
import matplotlib.pyplot as plt

In [47]:
# Trained FPN checkpoint (a state dict, loaded with model.load_state_dict below).
# Plain strings: these paths contain no interpolation, so the former f-prefix was a no-op.
MODEL_LOAD_PATH = './models/supreme/abhi_sudo_full_2-pretrained_preproc_1-FPN+-ep_9-0.01-24.pt'
# Output directory for predictions.
# NOTE(review): nothing in the visible notebook writes here (masks go to
# MongoDB/GridFS); kept so existing references keep working.
SAVE_PATH = './predictionsFPN'

os.makedirs(SAVE_PATH, exist_ok=True)

# Prediction-generation load size. NOTE(review): not referenced in the visible code.
LOAD_SIZE = 8

In [48]:
### FPN ####
# Architecture hyper-parameters passed to segmentation_models_pytorch.FPN.
ENCODER_NAME = 'resnet34'
ENCODER_DEPTH = 5
DPC = 256  # decoder_pyramid_channels
DSC = 128  # decoder_segmentation_channels

# Training hyper-parameters; not referenced elsewhere in the visible notebook.
BATCH_SIZE = [8]
LR = [1e-2]
EPOCHS = 50

# Label name -> class index. len(CLASSES) sets the model's output channel count;
# presumably the indices also match the channel order — confirm against training.
CLASSES = {
    name: index
    for index, name in enumerate([
        'Background', 'Building-flooded', 'Building-non-flooded',
        'Road-flooded', 'Road-non-flooded', 'Water', 'Tree',
        'Vehicle', 'Pool', 'Grass',
    ])
}

# Side length (pixels) that images are resized to before inference.
IMG_DIM = 512

In [49]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [50]:
# Build the FPN architecture with the hyper-parameters from the config cell.
model = segmentation_models_pytorch.FPN(
    encoder_name=ENCODER_NAME,
    encoder_depth=ENCODER_DEPTH,
    decoder_pyramid_channels=DPC,
    decoder_segmentation_channels=DSC,
    classes=len(CLASSES),
)

model = model.to(device)
# Map the checkpoint tensors onto the selected `device`. A hard-coded 'cuda'
# location fails on a CPU-only machine even though `device` already falls back to CPU.
model.load_state_dict(torch.load(MODEL_LOAD_PATH, map_location=device))
# Inference mode: BatchNorm uses running statistics, dropout is disabled.
model.eval()

FPN(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_

In [51]:


# Inference preprocessing:
# 1. ToTensor: HWC image -> CHW float tensor scaled to [0, 1].
# 2. Normalize: per-channel standardization with the usual ImageNet mean/std.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
val_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)

In [52]:
def reverse_transform_mask(inp, out_size=(4000, 3000), erode_kernel_size=3, erode_iterations=1):
    """Turn per-class model scores into a full-resolution label mask.

    Args:
        inp: Per-class scores for one image in channel-first layout
            (classes, H, W); the layout is implied by the transpose below.
        out_size: Output size as ``(width, height)``, the ``dsize`` order
            cv2.resize expects. Defaults to 4000x3000.
        erode_kernel_size: Side length of the square erosion kernel.
        erode_iterations: Number of erosion passes.

    Returns:
        float32 array of shape (height, width) holding class indices.
    """
    # (C, H, W) -> (H, W, C) so that argmax over the last axis picks one class per pixel.
    scores = inp.transpose((1, 2, 0))
    t_mask = np.argmax(scores, axis=2).astype('float32')
    # Nearest-neighbour keeps every pixel on a valid class id. cv2's default
    # bilinear interpolation blends neighbouring ids at region borders into
    # unrelated classes (e.g. halfway between 1 and 5 gives 3).
    t_mask = cv2.resize(t_mask, dsize=out_size, interpolation=cv2.INTER_NEAREST)
    kernel = np.ones((erode_kernel_size, erode_kernel_size), np.uint8)
    # Grey-level erosion: each pixel takes the minimum class id in its neighbourhood.
    t_mask = cv2.erode(t_mask, kernel, iterations=erode_iterations)
    return t_mask

In [53]:
# Lower-case alias kept for backward compatibility. It is derived from IMG_DIM so
# the two can't drift apart; process_image resizes with IMG_DIM.
img_dim = IMG_DIM

In [55]:
# Define the processing function
def process_image(image_bytes, image_count):
    """Segment one encoded image and persist the resulting mask to MongoDB.

    The image is decoded, forced to 3-channel RGB, resized to IMG_DIM x IMG_DIM,
    normalized with ``val_transform`` and run through ``model``. The class
    scores are turned into a label mask by ``reverse_transform_mask``. That mask
    is serialized with ``np.save`` and stored in GridFS. Finally, a document
    linking ``image_count`` to the GridFS file id is inserted into ``collection``.

    Args:
        image_bytes: Encoded image bytes in any format Pillow can open. In this
            notebook this is the Kafka message value.
        image_count: Identifier for the image, used in the GridFS filename and
            as ``image_id`` in MongoDB. In this notebook this is the Kafka
            message key cast to string.

    Any error is caught and printed, so one bad message does not abort the
    streaming micro-batch (best-effort by design).
    """
    try:
        print(f"Processing image {image_count}")
        # Close the decoder handle deterministically. Convert to RGB because
        # val_transform normalizes exactly 3 channels, so RGBA/L/P images would fail.
        with Image.open(io.BytesIO(image_bytes)) as decoded:
            resized = decoded.convert("RGB").resize((IMG_DIM, IMG_DIM))
        img = np.array(resized)

        transformed_img = val_transform(img)
        inputs = transformed_img.unsqueeze(0).to(device)

        with torch.no_grad():
            # Only the per-pixel argmax of these scores is used downstream, so raw
            # logits are passed through. A sigmoid adds nothing and can saturate
            # to 1.0 in float32, creating ties between classes.
            preds = model(inputs).cpu().numpy()

        f_mask = reverse_transform_mask(preds[0])

        # Convert mask to bytes
        mask_bytes = io.BytesIO()
        np.save(mask_bytes, f_mask)
        mask_bytes.seek(0)

        # Save the mask bytes to GridFS
        mask_id = fs.put(mask_bytes, filename=f"{image_count}_mask.npy")

        # Save the mask ID and image ID to MongoDB. If that fails, remove the
        # GridFS file so it is not left orphaned with no referencing document.
        image_doc = {
            "image_id": image_count,
            "mask_gridfs_id": mask_id
        }
        try:
            collection.insert_one(image_doc)
        except Exception:
            fs.delete(mask_id)
            raise
        print(f"Processed and saved mask for image {image_count} to MongoDB")
    except Exception as e:
        print(f"Error processing image {image_count}: {e}")


In [56]:


# NOTE(review): neither name is referenced elsewhere in the visible notebook.
view_created_lock = threading.Lock()
view_created = False

In [57]:
def foreach_batch_function(df, epoch_id):
    """foreachBatch sink: run ``process_image`` on every record of one micro-batch.

    Args:
        df: The micro-batch DataFrame with ``key`` (string) and
            ``image_bytes`` columns, as selected for the stream.
        epoch_id: Spark's micro-batch id, printed for tracing.
    """
    print(f"Foreach Batch Function called on Epoch ID: {epoch_id}.")
    # Materialize the batch exactly once. Every Spark action (count, toPandas, ...)
    # re-executes the batch's query plan, and this function previously ran three.
    pandas_df = df.toPandas()
    if pandas_df.empty:
        print("Empty DataFrame received in foreachBatch")
        return

    print(f"Processing {len(pandas_df)} records in the batch.")
    for index, row in pandas_df.iterrows():
        print(f"Processing row {index}")
        process_image(row['image_bytes'], row['key'])

In [58]:
# Clean the checkpoint directory so the streaming query below starts fresh.
# Spark resumes a query from the offsets stored in its checkpoint and applies
# startingOffsets only when there is no checkpoint.
import shutil
shutil.rmtree(path="checkpoint_dir", ignore_errors=True)

In [59]:
# Create DataFrame from Kafka stream, reading only messages that arrive after
# the query starts ("latest").
streamRawDf = (
    spark.readStream
    .format("kafka")
    .option("kafka.bootstrap.servers", kafka_server)
    .option("subscribe", topic_name)
    .option("startingOffsets", "latest")
    .load()
)
# Keep the message key (cast to string, used as the image id) and the raw
# message value (the encoded image bytes).
streamDF = streamRawDf.select(
    col("key").cast("string"),
    col("value").alias("image_bytes"),
)


In [61]:
# Start the streaming query. Each micro-batch is handed to
# foreach_batch_function, and progress is checkpointed under "checkpoint_dir".
query = (
    streamDF.writeStream
    .foreachBatch(foreach_batch_function)
    .option("checkpointLocation", "checkpoint_dir")
    .start()
)

# Block this cell until the query is stopped or fails.
query.awaitTermination()

Foreach Batch Function called on Epoch ID: 0.
Empty DataFrame received in foreachBatch
Foreach Batch Function called on Epoch ID: 1.
Processing 1 records in the batch.


  PyArrow >= 4.0.0 must be installed; however, it was not found.
Attempting non-optimization as 'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to true.
  warn(msg)


Processing row 0
Processing image 7333
Processed and saved mask for image 7333 to MongoDB
Foreach Batch Function called on Epoch ID: 2.
Processing 1 records in the batch.
Processing row 0
Processing image 7413
Processed and saved mask for image 7413 to MongoDB
Foreach Batch Function called on Epoch ID: 3.
Processing 1 records in the batch.
Processing row 0
Processing image 7415
Processed and saved mask for image 7415 to MongoDB
Foreach Batch Function called on Epoch ID: 4.
Processing 1 records in the batch.
Processing row 0
Processing image 7420
Processed and saved mask for image 7420 to MongoDB
