Skip to content

Commit

Permalink
feat: improved line detection
Browse files Browse the repository at this point in the history
  • Loading branch information
OBrink committed Dec 7, 2023
1 parent 3ec8239 commit 8143111
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 22 deletions.
126 changes: 106 additions & 20 deletions decimer_segmentation/complete_structure.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import cv2
import math
import numpy as np
import matplotlib.pyplot as plt
import itertools
from skimage.color import rgb2gray
from skimage.filters import threshold_otsu
from skimage.morphology import binary_erosion
from skimage.morphology import binary_erosion, binary_dilation
from typing import List, Tuple
from scipy.ndimage import label

Expand Down Expand Up @@ -98,8 +99,7 @@ def detect_horizontal_and_vertical_lines(
) -> np.ndarray:
"""
This function takes an image and returns a binary mask that labels the pixels that
are part of long horizontal or vertical lines. [Definition of long: 1/5 of the
width/height of the image].
are part of long horizontal or vertical lines.
Args:
image (np.ndarray): binarised image (np.array; type bool) as it is returned by
Expand Down Expand Up @@ -130,11 +130,86 @@ def detect_horizontal_and_vertical_lines(
return horizontal_mask + vertical_mask


def find_equidistant_points(
    x1: int,
    y1: int,
    x2: int,
    y2: int,
    num_points: int = 5
) -> List[Tuple[float, float]]:
    """
    Returns evenly spaced points on the line segment from (x1, y1) to (x2, y2).

    The segment is divided into num_points equal intervals, so the result holds
    num_points + 1 points: both end points plus the num_points - 1 points in
    between. Coordinates are linearly interpolated and are therefore floats;
    callers that use them as array indices have to cast them to int.

    Args:
        x1 (int): x coordinate of first point
        y1 (int): y coordinate of first point
        x2 (int): x coordinate of second point
        y2 (int): y coordinate of second point
        num_points (int, optional): Number of equal intervals the segment is
            divided into (expected to be >= 1). Defaults to 5.

    Returns:
        List[Tuple[float, float]]: num_points + 1 points (x, y), ordered from
            (x1, y1) to (x2, y2). Empty if num_points is negative.

    Raises:
        ZeroDivisionError: If num_points is 0.
    """
    # Interpolation parameter t runs from 0.0 (first point) to 1.0 (second point)
    fractions = (step / num_points for step in range(num_points + 1))
    return [
        (x1 * (1 - t) + x2 * t, y1 * (1 - t) + y2 * t)
        for t in fractions
    ]


def detect_lines(
    image: np.ndarray,
    max_depiction_size: Tuple[int, int],
    segmentation_mask: np.ndarray
) -> np.ndarray:
    """
    This function takes an image and returns a mask that labels the pixels that
    are part of long straight lines (of any orientation) which do not run through
    a chemical structure depiction, e.g. table lines (the original note also
    lists "arrays" - presumably arrows are meant).

    Lines are detected with the probabilistic Hough transform. A detected line is
    discarded if any of the sample points between its end points lies on the
    segmentation mask; all remaining lines are drawn into the returned mask.

    Args:
        image (np.ndarray): binarised image (np.array; type bool) as it is returned by
            binary_erosion() in complete_structure_mask()
        max_depiction_size (Tuple[int, int]): height, width; a quarter of the
            smaller value is used as the minimal line length
        segmentation_mask (np.ndarray): Indicates whether or not a pixel is part of a
            chemical structure depiction (shape: (height, width))

    Returns:
        np.ndarray: Exclusion mask (dtype uint8, same shape as image) with the
            value 255 on pixels of accepted lines (drawn with thickness 2) and 0
            everywhere else
    """
    # HoughLinesP works on an 8-bit single-channel image and treats non-zero
    # pixels as line candidates, so the boolean image is inverted and scaled
    # to 0/255 (presumably the drawn content is False in the input - confirm)
    image = ~image * 255
    image = image.astype("uint8")
    # Detect lines using the Hough Transform
    lines = cv2.HoughLinesP(
        image,
        1,
        np.pi / 180,
        threshold=5,
        minLineLength=int(min(max_depiction_size) / 4),
        maxLineGap=10,
    )
    # Generate exclusion mask based on detected lines
    exclusion_mask = np.zeros_like(image)
    if lines is None:
        return exclusion_mask
    for line in lines:
        # Cast the NumPy scalars to plain ints so that the drawing call below
        # always receives native Python integers as coordinates
        x1, y1, x2, y2 = (int(value) for value in line[0])
        # Only the points between the end points are checked; a line that merely
        # touches a depiction with one of its ends is still excluded
        points = find_equidistant_points(x1, y1, x2, y2, num_points=7)
        if any(segmentation_mask[int(y), int(x)] for x, y in points[1:-1]):
            # The line runs through a chemical structure depiction -> keep it
            continue
        cv2.line(exclusion_mask, (x1, y1), (x2, y2), 255, 2)
    return exclusion_mask


def expand_masks(
image_array: np.array,
seed_pixels: List[Tuple[int, int]],
mask_array: np.array,
exclusion_mask: np.array,
) -> np.array:
"""
This function generates a mask array where the given masks have been
Expand All @@ -144,20 +219,15 @@ def expand_masks(
image_array (np.array): array that represents an image (float values)
seed_pixels (List[Tuple[int, int]]): [(x, y), ...]
mask_array (np.array): MRCNN output; shape: (y, x, mask_index)
exclusion_mask (np.array]: indicates whether or not a pixel is excluded from
expansion
contour_expansion (bool, optional): Indicates whether or not to expand
from contours. Defaults to False.
Returns:
np.array: Expanded masks
"""
image_with_exclusion = np.invert(image_array) * np.invert(exclusion_mask)
labeled_array, _ = label(image_with_exclusion)
labeled_array, _ = label(image_array)
mask_array = np.zeros_like(image_array)
for seed_pixel in seed_pixels:
x, y = seed_pixel
if mask_array[y, x] or exclusion_mask[y, x]:
if mask_array[y, x]:
continue
label_value = labeled_array[y, x]
if label_value > 0:
Expand All @@ -176,7 +246,7 @@ def expansion_coordination(
seed_pixels = get_seeds(image_array,
mask_array,
exclusion_mask)
mask_array = expand_masks(image_array, seed_pixels, mask_array, exclusion_mask)
mask_array = expand_masks(image_array, seed_pixels, mask_array)
return mask_array


Expand Down Expand Up @@ -213,29 +283,45 @@ def complete_structure_mask(
blur_factor = (
int(image_array.shape[1] / 185) if image_array.shape[1] / 185 >= 2 else 2
)
if debug:
plot_it(binarized_image_array)
# Define kernel and apply
kernel = np.ones((blur_factor, blur_factor))
blurred_image_array = binary_erosion(binarized_image_array, footprint=kernel)
if debug:
plot_it(blurred_image_array)
debug = True
if debug:
plot_it(binarized_image_array)
# Slice mask array along third dimension into single masks
split_mask_arrays = np.array(
[mask_array[:, :, index] for index in range(mask_array.shape[2])]
)
exclusion_mask = detect_horizontal_and_vertical_lines(
# Detect horizontal and vertical lines
horizontal_vertical_lines = detect_horizontal_and_vertical_lines(
blurred_image_array, max_depiction_size
)
# Run expansion the expansion
image_repeat = itertools.repeat(blurred_image_array, mask_array.shape[2])
exclusion_mask_repeat = itertools.repeat(exclusion_mask, mask_array.shape[2])

hough_lines = detect_lines(
binarized_image_array,
max_depiction_size,
segmentation_mask=np.any(mask_array, axis=2).astype(np.bool)
)
hough_lines = binary_dilation(hough_lines, footprint=kernel)
exclusion_mask = horizontal_vertical_lines + hough_lines
if debug:
plot_it(horizontal_vertical_lines)
plot_it(hough_lines)
plot_it(exclusion_mask)
plot_it(np.invert(binarized_image_array) * np.invert(exclusion_mask))
image_with_exclusion = np.invert(blurred_image_array) * np.invert(exclusion_mask)

# Run expansion
image_repeat = itertools.repeat(image_with_exclusion, mask_array.shape[2])
exclusion_repeat = itertools.repeat(exclusion_mask, mask_array.shape[2])
# Faster with map function
expanded_split_mask_arrays = map(
expansion_coordination,
split_mask_arrays,
image_repeat,
exclusion_mask_repeat,
exclusion_repeat,
)
# Stack mask arrays to give the desired output format
mask_array = np.stack(expanded_split_mask_arrays, -1)
Expand Down
5 changes: 3 additions & 2 deletions decimer_segmentation/decimer_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class InferenceConfig(moldetect.MolDetectConfig):
# Run detection on one image at a time
GPU_COUNT = 1
IMAGES_PER_GPU = 1
DETECTION_MIN_CONFIDENCE = 0.7


def segment_chemical_structures_from_file(
Expand Down Expand Up @@ -142,8 +143,8 @@ def determine_depiction_size_with_buffer(
width = bbox[3] - bbox[1]
heights.append(height)
widths.append(width)
height = int(1.1 * np.max(heights))
width = int(1.1 * np.max(widths))
height = int(1.05 * np.max(heights))
width = int(1.05 * np.max(widths))
return height, width


Expand Down

0 comments on commit 8143111

Please sign in to comment.