Skip to content

Commit

Permalink
Merge pull request #29 from MortenTabaka/Add_possibility_to_display_S…
Browse files Browse the repository at this point in the history
…P_borders_on_predicion_mask

Save raw image to .cache with SuperPixels marked borders
  • Loading branch information
MortenTabaka committed May 6, 2023
2 parents 2b9d0bd + 6cd5e2e commit ff2f8c4
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 50 deletions.
4 changes: 3 additions & 1 deletion models/scripts/run_prediction_on_folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@ def main(
sp_count: int = typer.Option(200, min=0),
sp_compactness: float = typer.Option(10, min=0),
sp_thresh: float = typer.Option(0.7, min=0, max=1),
sp_class_balance: bool = typer.Option(False),
border_sp: bool = typer.Option(
True, help="If should post-process tile boundaries with SuperPixels algorithm"
),
border_sp_count: int = typer.Option(
50, min=0, help="Will be multiplied by number of borders in single strip"
75, min=0, help="Will be multiplied by number of borders in single strip"
),
border_compactness: float = typer.Option(10, min=0),
border_sp_thresh: float = typer.Option(0.3, min=0, max=1),
Expand Down Expand Up @@ -73,6 +74,7 @@ def main(
number_of_superpixels=sp_count,
compactness=sp_compactness,
superpixel_threshold=sp_thresh,
sp_class_balance=sp_class_balance,
border_sp=border_sp,
border_sp_count=border_sp_count,
border_compactness=border_compactness,
Expand Down
30 changes: 17 additions & 13 deletions notebooks/testing/Test_superpixels.ipynb

Large diffs are not rendered by default.

35 changes: 19 additions & 16 deletions src/data/image_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import os.path
from enum import Enum
from pathlib import Path
from typing import List, Union
from typing import Any, List, Union, Tuple

import cv2
import numpy as np
import tensorflow as tf
from skimage.segmentation import slic
from skimage.segmentation import mark_boundaries, slic
from tqdm import tqdm

from src.data.image_preprocessing import ImagePreprocessor
Expand Down Expand Up @@ -150,9 +150,9 @@ def get_concatenated_filename_and_image(
] = single_tile
elif self.data_mode == DataMode.NUMPY_TENSOR:
full_sized_tensor[
:,
v * tile_shape[1]: (v + 1) * tile_shape[1],
h * tile_shape[2]: (h + 1) * tile_shape[2],
:,
v * tile_shape[1] : (v + 1) * tile_shape[1],
h * tile_shape[2] : (h + 1) * tile_shape[2],
] = single_tile

k += 1
Expand Down Expand Up @@ -243,7 +243,7 @@ def get_updated_prediction_with_postprocessor_superpixels(
not_decoded_prediction: tf.Tensor,
threshold: float,
should_class_balance: bool = False,
):
) -> Tuple[tf.Tensor, Any]:
"""
Update the prediction for each superpixel segment in a not-decoded predicted tile.
Expand All @@ -264,7 +264,7 @@ def get_updated_prediction_with_postprocessor_superpixels(
contains integer values representing the updated predicted classes for each pixel
in the tile, after considering the most frequent class within each superpixel segment.
"""
superpixel_segments = self.get_superpixel_segments()
superpixel_segments, raw_segments = self.get_superpixel_segments()
num_of_segments = self.get_number_of_segments(superpixel_segments)
class_balance = get_normalized_class_balance_of_the_landcover_dataset()

Expand All @@ -280,7 +280,7 @@ def get_updated_prediction_with_postprocessor_superpixels(
if should_class_balance:
counts = tf.cast(counts, dtype=tf.float32)
num_classes = len(counts)
counts = counts / class_balance[num_classes - 1]
counts = [counts[i] / class_balance[i] for i in range(num_classes)]

# Find the index of the most often repeated value
most_frequent_value_index = tf.math.argmax(counts)
Expand All @@ -290,22 +290,25 @@ def get_updated_prediction_with_postprocessor_superpixels(
most_frequent_count = counts[most_frequent_value_index].numpy()
ratio = most_frequent_count / number_of_all_pixels_in_segment

road_class_pixel_count = counts[-1]

if ratio >= threshold:
# Get the most often repeated value
most_frequent_class_in_tile_segment = most_frequent_value_index.numpy()
# Create a tensor of ones with the shape of indices
ones = tf.ones((tf.shape(indices)[0],), dtype=tf.uint8)
ones = tf.ones((tf.shape(indices)[0],), dtype=tf.int64)
# Multiply the ones tensor by max_value
updates = ones * most_frequent_class_in_tile_segment
# Update the not_decoded_prediction tensor
not_decoded_prediction = tf.tensor_scatter_nd_update(
not_decoded_prediction, indices, updates
)
return not_decoded_prediction

def get_superpixel_segments(self) -> tf.Tensor:
raw_image_with_marked_boundaries = mark_boundaries(
self.raw_image, raw_segments, color=(0, 1, 1)
)

return not_decoded_prediction, raw_image_with_marked_boundaries

def get_superpixel_segments(self) -> Tuple:
"""
Generates superpixel segments for an input image using the Simple Linear
Iterative Clustering (SLIC) algorithm.
Expand All @@ -316,16 +319,16 @@ def get_superpixel_segments(self) -> tf.Tensor:
tensor contains integer values representing the segment labels (superpixel
indices) assigned to each pixel in the image.
"""
segments = slic(
raw_segments = slic(
self.raw_image,
**self.params_of_superpixels_postprocessing,
)
segments = tf.convert_to_tensor(segments)
segments = tf.convert_to_tensor(raw_segments)
segments = tf.reshape(
segments,
(1, segments.shape[0], segments.shape[1]),
)
return segments
return segments, raw_segments

@staticmethod
def get_number_of_segments(superpixel_segments: tf.Tensor) -> int:
Expand Down
70 changes: 50 additions & 20 deletions src/pipelines/prediction_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from shutil import rmtree
from typing import List, Union

import PIL.Image
from PIL import Image
import numpy as np
import tensorflow as tf
from cv2 import imread, imwrite
Expand Down Expand Up @@ -58,6 +58,7 @@ def __init__(
number_of_superpixels: int = None,
compactness: float = None,
superpixel_threshold: float = None,
sp_class_balance: bool = False,
border_sp: bool = True,
border_sp_count: int = None,
border_compactness: float = None,
Expand Down Expand Up @@ -86,6 +87,7 @@ def __init__(
self.number_of_superpixels = number_of_superpixels
self.compactness = compactness
self.superpixel_threshold = superpixel_threshold
self.sp_class_balance = sp_class_balance

self.border_sp = border_sp
self.border_compactness = border_compactness
Expand Down Expand Up @@ -161,10 +163,10 @@ def __get_superpixel_post_processed_tile_prediction(
self, tile: str, prediction: tf.Tensor
) -> tf.Tensor:
image = imread(tile)
prediction = SuperpixelsProcessor(
prediction, raw_image_with_marked_superpixels = SuperpixelsProcessor(
image, self.get_slic_parameters
).get_updated_prediction_with_postprocessor_superpixels(
prediction, self.superpixel_threshold
prediction, self.superpixel_threshold, self.sp_class_balance
)
return prediction

Expand Down Expand Up @@ -205,15 +207,15 @@ def __postprocess_tiles_borders_in_concatenated_prediction(
num_vertical_borders = int(width / self.tile_height) - 1
num_horizontal_borders = int(height / self.tile_width) - 1

raw_mask = self.__process_single_oriented_borders(
raw_mask, raw_image_with_boundaries = self.__process_single_oriented_borders(
raw_image,
raw_mask,
"vertical",
num_vertical_borders,
self.border_sp_pixel_range,
)

raw_mask = self.__process_single_oriented_borders(
raw_mask, raw_image_with_boundaries = self.__process_single_oriented_borders(
raw_image,
raw_mask,
"horizontal",
Expand All @@ -227,7 +229,20 @@ def __postprocess_tiles_borders_in_concatenated_prediction(

filename = self.__generate_filename_with_sp_params(base_name)
filepath = os.path.join(self.output_folder, filename)
decoded_prediction.save(filepath)
decoded_prediction.save(filepath, quality=100)

cached_marked_borders = os.path.join(
self.output_folder, f".cache/full_raw_images_with_boundaries"
)

os.makedirs(cached_marked_borders, exist_ok=False)
imwrite(
os.path.join(
cached_marked_borders,
f"{base_name}.tiff",
),
raw_image_with_boundaries,
)

def __process_single_oriented_borders(
self,
Expand All @@ -245,24 +260,38 @@ def __process_single_oriented_borders(
raw_border_area, raw_border_prediction = self.__get_border_area(
raw_image, raw_prediction, orientation, top_or_right, bottom_or_left
)
post_processed_border = SuperpixelsProcessor(
(
post_processed_border,
raw_image_with_marked_superpixels,
) = SuperpixelsProcessor(
raw_border_area, slic_params_for_border
).get_updated_prediction_with_postprocessor_superpixels(
raw_border_prediction,
threshold=self.border_sp_thresh,
should_class_balance=self.border_sp_class_balance,
)
if orientation == "horizontal":
if orientation == "vertical":
raw_prediction[
:, bottom_or_left:top_or_right, :
:, :, bottom_or_left:top_or_right
] = post_processed_border
elif orientation == "vertical":

raw_image[:, bottom_or_left:top_or_right, :] = (
raw_image_with_marked_superpixels * 255
)

elif orientation == "horizontal":
raw_prediction[
:, :, bottom_or_left:top_or_right
:, bottom_or_left:top_or_right, :
] = post_processed_border

raw_image[bottom_or_left:top_or_right, :, :] = (
raw_image_with_marked_superpixels * 255
)

else:
raise ValueError("Pick correct which border to process.")
return raw_prediction

return raw_prediction, raw_image

def __clear_cache(self, paths=None):
if paths is None:
Expand All @@ -277,22 +306,23 @@ def __generate_filename_with_sp_params(self, base_name: str):
filename = f"{base_name}".replace(".jpg", "")
if self.tiles_superpixel_postprocessing:
filename += (
f"__SpTiles_spCount{self.number_of_superpixels}_spThresh{self.superpixel_threshold}"
f"__spCompactness{self.compactness}"
f"--SpTiles_spCount{self.number_of_superpixels}-spThresh{self.superpixel_threshold}"
f"--spCompactness{self.compactness}--spCB{self.sp_class_balance}"
)
else:
filename = f"{filename}_NoTilesSuperPixelsProcessing"
filename = f"{filename}-NoTilesSuperPixelsProcessing"

if self.border_sp:
filename += (
f"__SpBorders_spBorderCount{self.border_sp_count}_spBorderThresh{self.border_sp_thresh}"
f"_spBorderCompactness{self.border_compactness}_spBorderCB{self.border_sp_class_balance}"
f"_spBorderRange{self.border_sp_pixel_range}"
f"--SpBorders_spBorderCount{self.border_sp_count}-spBorderThresh{self.border_sp_thresh}"
f"-spBorderCompactness{self.border_compactness}-spBorderCB{self.border_sp_class_balance}"
f"-spBorderRange{self.border_sp_pixel_range}"
)
else:
filename = f"{filename}__NoBordersSuperPixelsProcessing"
filename = f"{filename}--NoBordersSuperPixelsProcessing"

return f"{filename}.jpg"
filename = f"{filename}".replace(".", "_")
return f"{filename}.tiff"

@property
def get_slic_parameters(self):
Expand Down

0 comments on commit ff2f8c4

Please sign in to comment.