# In [None]:
import os
import torch
import torch.nn as nn
import pydicom as dicom # to read DCM images
from PIL import Image  # to read PNGs
import matplotlib.pylab as plt
import numpy as np
from collections import deque  # Import deque for efficient queue operations

class imageProcessor():
    """Seeded region-growing segmentation of a single 2-D image.

    The pipeline run by :meth:`process` is:

    1. histogram-equalize the normalized image,
    2. compute a Sobel gradient magnitude, normalize and equalize it,
    3. grow a 4-connected region from the seed over low-gradient pixels,
       raising the gradient threshold until the region is large enough,
    4. extend the region over pixels whose intensity is close to the
       region's mean intensity.

    The result is stored in ``self.mask`` (1 = inside the region, 0 = outside).
    """

    # process() keeps raising the gradient threshold until the mask holds at
    # least this many pixels (or the whole image, if the image is smaller).
    MIN_REGION_PIXELS = 50
    # Amount added to the threshold coefficient after each growth attempt.
    THRESH_STEP = 0.001
    # Hard cap on growth attempts so process() always terminates.
    MAX_GROWTH_STEPS = 10000
    # Half-width of the intensity band, in standard deviations of the region.
    INTENSITY_SIGMA = 1.5
    # Row/column offsets of the 4-connected neighbourhood.
    _NEIGHBOR_OFFSETS = ((-1, 0), (1, 0), (0, -1), (0, 1))

    def __init__(self, image, x_pix, y_pix):
        """Store the image, normalize it and record the seed pixel.

        Args:
            image: 2-D array-like of pixel values.
            x_pix: Seed column index.
            y_pix: Seed row index.

        Raises:
            ValueError: If ``image`` is not two-dimensional.
            IndexError: If the seed lies outside the image.
        """
        self.ds = np.asarray(image)
        if self.ds.ndim != 2:
            raise ValueError(
                "image must be 2-D, got %d dimension(s)" % self.ds.ndim
            )
        self.row, self.col = self.ds.shape  # (rows, columns)

        # Reject bad seeds up front: a negative index would silently wrap
        # around and an oversized one would only fail after preprocessing.
        if not (0 <= y_pix < self.row and 0 <= x_pix < self.col):
            raise IndexError(
                "seed (x=%r, y=%r) is outside an image of %d rows x %d columns"
                % (x_pix, y_pix, self.row, self.col)
            )

        # Scale pixel values by the maximum. A zero maximum cannot be used as
        # a divisor (0/0 would fill the image with NaN), so in that case the
        # values are kept unscaled.
        peak = self.ds.max()
        if peak == 0:
            self.img_intent = self.ds.astype(float)
        else:
            self.img_intent = self.ds / peak

        # Region mask and gradient magnitude, both the same shape as the image.
        self.mask = np.zeros((self.row, self.col))
        self.gradient = np.zeros((self.row, self.col))

        # Seed pixel (starting point of the region growing).
        self.seed_row = y_pix
        self.seed_col = x_pix

    def showMask(self):
        """Display the mask with matplotlib using a gray colormap."""
        plt.imshow(self.mask, 'gray')

    def showImage(self):
        """Display the normalized image with matplotlib using a gray colormap."""
        plt.imshow(self.img_intent, 'gray')

    def sobelFilter(self):
        """Write the Sobel gradient magnitude of the image into ``self.gradient``.

        Only interior pixels are written; the one-pixel border of
        ``self.gradient`` keeps its previous values (zeros after construction).
        The kernels applied are::

            x: [[-1, 0, 1],      y: [[-1, -2, -1],
                [-2, 0, 2],          [ 0,  0,  0],
                [-1, 0, 1]]          [ 1,  2,  1]]
        """
        img = self.img_intent
        # Row-shifted views: the rows above, at, and below each interior pixel.
        above, level, below = img[:-2], img[1:-1], img[2:]

        # Horizontal derivative: right column minus left column, weighted 1-2-1.
        gx = ((above[:, 2:] - above[:, :-2])
              + 2 * (level[:, 2:] - level[:, :-2])
              + (below[:, 2:] - below[:, :-2]))
        # Vertical derivative: bottom row minus top row, weighted 1-2-1.
        gy = ((below[:, :-2] - above[:, :-2])
              + 2 * (below[:, 1:-1] - above[:, 1:-1])
              + (below[:, 2:] - above[:, 2:]))

        self.gradient[1:-1, 1:-1] = np.sqrt(gx**2 + gy**2)

    @staticmethod
    def _equalize(values, bins=256):
        """Return ``values`` histogram-equalized over the range [0, 1].

        Each value is mapped through the normalized cumulative histogram,
        interpolated at the left bin edges. Values outside [0, 1] do not
        contribute to the histogram and are clamped by the interpolation.
        If no value falls inside [0, 1] the input is returned unchanged
        (as a float copy) instead of dividing by zero.
        """
        flat = values.ravel()
        counts, edges = np.histogram(flat, bins=bins, range=(0, 1))
        cdf = counts.cumsum()
        total = cdf[-1]
        if total == 0:
            return values.astype(float)
        return np.interp(flat, edges[:-1], cdf / total).reshape(values.shape)

    def equalizeImageHistogram(self):
        """Replace ``self.img_intent`` with its histogram-equalized version."""
        self.img_intent = self._equalize(self.img_intent)

    def equalizeGradientHistogram(self):
        """Replace ``self.gradient`` with its histogram-equalized version."""
        self.gradient = self._equalize(self.gradient)

    def preprocess(self):
        """Equalize the image, then compute, normalize and equalize its gradient."""
        self.equalizeImageHistogram()
        self.sobelFilter()
        # A flat image has an all-zero gradient; dividing by that maximum
        # would turn the whole gradient into NaN, so skip the scaling then.
        peak = self.gradient.max()
        if peak > 0:
            self.gradient = self.gradient / peak
        self.equalizeGradientHistogram()

    def _seedWindow(self, radius):
        """Return the gradient patch within ``radius`` pixels of the seed.

        The start indices are clamped at 0 because a negative slice start
        would wrap around and yield an empty patch for seeds near the
        top/left border.
        """
        top = max(self.seed_row - radius, 0)
        left = max(self.seed_col - radius, 0)
        return self.gradient[top:self.seed_row + radius + 1,
                             left:self.seed_col + radius + 1]

    def edgeDetection(self, thresh_coeff):
        """Grow the mask from the seed over 4-connected low-gradient pixels.

        A pixel joins the region when its gradient is at most
        ``mean(5x5 seed window) + thresh_coeff * std(9x9 seed window)``.
        Pixels are only ever added to ``self.mask``; existing entries are kept.

        Args:
            thresh_coeff: Multiplier of the local gradient standard deviation.
        """
        thresh = (self._seedWindow(2).mean()
                  + thresh_coeff * self._seedWindow(4).std())

        # Track visited pixels separately from the mask: reusing the mask
        # would stop a later call (with a higher threshold) from walking
        # through pixels accepted by an earlier call.
        visited = np.zeros((self.row, self.col), dtype=bool)
        visited[self.seed_row, self.seed_col] = True
        self.mask[self.seed_row, self.seed_col] = 1

        queue = deque([(self.seed_row, self.seed_col)])
        while queue:
            current_row, current_col = queue.popleft()
            for delta_row, delta_col in self._NEIGHBOR_OFFSETS:
                neighbor_row = current_row + delta_row
                neighbor_col = current_col + delta_col
                if not (0 <= neighbor_row < self.row
                        and 0 <= neighbor_col < self.col):
                    continue
                if visited[neighbor_row, neighbor_col]:
                    continue
                visited[neighbor_row, neighbor_col] = True
                if self.gradient[neighbor_row, neighbor_col] <= thresh:
                    self.mask[neighbor_row, neighbor_col] = 1
                    queue.append((neighbor_row, neighbor_col))

    def intensityThreshold(self):
        """Extend the mask over pixels with intensity close to the region's.

        Starting at the seed, 4-connected pixels are added while their
        intensity lies strictly between ``mean - 1.5 * std`` and
        ``mean + 1.5 * std`` of the pixels currently in the mask.
        """
        region = self.img_intent[self.mask == 1]
        if region.size == 0:
            # No masked pixels: there are no statistics to build a band from.
            return
        average_intensity = np.mean(region)
        intensity_thresh = self.INTENSITY_SIGMA * np.std(region)
        lower = average_intensity - intensity_thresh
        upper = average_intensity + intensity_thresh

        visited = np.zeros((self.row, self.col), dtype=bool)
        visited[self.seed_row, self.seed_col] = True
        intensity_queue = deque([(self.seed_row, self.seed_col)])

        while intensity_queue:
            current_row, current_col = intensity_queue.popleft()
            for delta_row, delta_col in self._NEIGHBOR_OFFSETS:
                neighbor_row = current_row + delta_row
                neighbor_col = current_col + delta_col
                if not (0 <= neighbor_row < self.row
                        and 0 <= neighbor_col < self.col):
                    continue
                if visited[neighbor_row, neighbor_col]:
                    continue
                # The band is fixed for the whole call, so a rejected pixel
                # never needs to be examined again.
                visited[neighbor_row, neighbor_col] = True
                if lower < self.img_intent[neighbor_row, neighbor_col] < upper:
                    intensity_queue.append((neighbor_row, neighbor_col))
                    self.mask[neighbor_row, neighbor_col] = 1  # part of the region

    def process(self):
        """Run the full segmentation pipeline and leave the result in ``self.mask``.

        The gradient threshold coefficient starts at ``THRESH_STEP`` and is
        raised by the same amount after every attempt until the mask holds at
        least ``MIN_REGION_PIXELS`` pixels (or every pixel of a smaller
        image). After ``MAX_GROWTH_STEPS`` attempts the loop gives up and the
        intensity step runs on whatever region was found.
        """
        self.preprocess()
        # An image with fewer pixels than the target can never reach it.
        target = min(self.MIN_REGION_PIXELS, self.mask.size)
        thresh_coeff = self.THRESH_STEP
        attempts = 0
        while self.mask.sum() < target and attempts < self.MAX_GROWTH_STEPS:
            self.edgeDetection(thresh_coeff)
            thresh_coeff += self.THRESH_STEP
            attempts += 1
        self.intensityThreshold()

def process(input_img,x,y):
    """Segment ``input_img`` starting from a seed pixel and return the mask.

    Builds an ``imageProcessor`` for the image with ``x`` and ``y`` as the
    seed's x/y pixel coordinates, runs its ``process()`` pipeline, and hands
    back the resulting ``mask`` array.
    """
    processor = imageProcessor(input_img, x, y)
    processor.process()
    segmented = processor.mask
    return segmented
    