Skip to content

Commit

Permalink
feat: adaptive line kernel based on structure size
Browse files Browse the repository at this point in the history
  • Loading branch information
OBrink committed Sep 18, 2023
1 parent 70a5b2b commit 2f2fe8e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 12 deletions.
39 changes: 30 additions & 9 deletions decimer_segmentation/complete_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,10 @@ def get_neighbour_pixels(
return neighbour_pixels


def detect_horizontal_and_vertical_lines(image: np.ndarray) -> np.ndarray:
def detect_horizontal_and_vertical_lines(
image: np.ndarray,
average_depiction_size: Tuple[int, int]
) -> np.ndarray:
"""
This function takes an image and returns a binary mask that labels the pixels that
are part of long horizontal or vertical lines. [Definition of long: 1/5 of the
Expand All @@ -382,19 +385,19 @@ def detect_horizontal_and_vertical_lines(image: np.ndarray) -> np.ndarray:
"""
binarised_im = ~image * 255
binarised_im = binarised_im.astype("uint8")

structure_height, structure_width = average_depiction_size

horizontal_kernel_size = int(binarised_im.shape[1] / 7)
horizontal_kernel = cv2.getStructuringElement(
cv2.MORPH_RECT, (horizontal_kernel_size, 1)
cv2.MORPH_RECT, (structure_width, 1)
)
horizontal_mask = cv2.morphologyEx(
binarised_im, cv2.MORPH_OPEN, horizontal_kernel, iterations=2
)
horizontal_mask = horizontal_mask == 255

vertical_kernel_size = int(binarised_im.shape[0] / 7)
vertical_kernel = cv2.getStructuringElement(
cv2.MORPH_RECT, (1, vertical_kernel_size)
cv2.MORPH_RECT, (1, structure_height)
)
vertical_mask = cv2.morphologyEx(
binarised_im, cv2.MORPH_OPEN, vertical_kernel, iterations=2
Expand Down Expand Up @@ -472,13 +475,30 @@ def expansion_coordination(


def complete_structure_mask(
image_array: np.array, mask_array: np.array, debug=False
image_array: np.array,
mask_array: np.array,
average_depiction_size: Tuple[int, int],
debug=False
) -> np.array:
"""
This funtion takes an image (array) and an array containing the masks (shape:
This funtion takes an image (np.array) and an array containing the masks (shape:
x,y,n where n is the amount of masks and x and y are the pixel coordinates).
Additionally, it takes the average depiction size of the structures in the image
which is used to define the kernel size for the vertical and horizontal line
detection for the exclusion masks. The exclusion mask is used to exclude pixels
from the mask expansion to avoid including whole tables.
It detects objects on the contours of the mask and expands it until it frames the
complete object in the image. It returns the expanded mask array"""
complete object in the image. It returns the expanded mask array
Args:
image_array (np.array): input image
mask_array (np.array): shape: y, x, n where n is the amount of masks
average_depiction_size (Tuple[int, int]): height, width
debug (bool, optional): More verbose if True. Defaults to False.
Returns:
np.array: expanded mask array
"""

if mask_array.size != 0:
# Binarization of input image
Expand All @@ -498,7 +518,8 @@ def complete_structure_mask(
split_mask_arrays = np.array(
[mask_array[:, :, index] for index in range(mask_array.shape[2])]
)
exclusion_mask = detect_horizontal_and_vertical_lines(blurred_image_array)
exclusion_mask = detect_horizontal_and_vertical_lines(blurred_image_array,
average_depiction_size)
# Run expansion the expansion
image_repeat = itertools.repeat(blurred_image_array, mask_array.shape[2])
exclusion_mask_repeat = itertools.repeat(exclusion_mask, mask_array.shape[2])
Expand Down
8 changes: 5 additions & 3 deletions decimer_segmentation/decimer_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def segment_chemical_structures(
if not expand:
masks, bboxes, _ = get_mrcnn_results(image)
else:
average_height, average_width = determine_average_depiction_size(bboxes)
masks = get_expanded_masks(image)

segments, bboxes = apply_masks(image, masks)
Expand Down Expand Up @@ -227,9 +226,12 @@ def get_expanded_masks(image: np.array) -> np.array:
np.array: expanded masks (shape: (h, w, num_masks))
"""
# Structure detection with MRCNN
masks, _, _ = get_mrcnn_results(image)
masks, bboxes, _ = get_mrcnn_results(image)
size = determine_average_depiction_size(bboxes)
# Mask expansion
expanded_masks = complete_structure_mask(image_array=image, mask_array=masks)
expanded_masks = complete_structure_mask(image_array=image,
mask_array=masks,
average_depiction_size=size,)
return expanded_masks


Expand Down

0 comments on commit 2f2fe8e

Please sign in to comment.