Skip to content
211 changes: 211 additions & 0 deletions experiments/visualizing_segments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
from loguru import logger
from PIL import Image
import numpy as np
from openadapt import vision, adapters
import cv2
from skimage.metrics import structural_similarity as ssim


def extract_difference_image(
    new_image: Image.Image,
    old_image: Image.Image,
    tolerance: float = 0.05,
) -> Image.Image:
    """Extract the portion of the new image that is different from the old image.

    Args:
        new_image: The new image as a PIL Image object.
        old_image: The old image as a PIL Image object.
        tolerance: Fraction of the full intensity range (0-255) above which a
            pixel is considered different (default is 0.05).

    Returns:
        A PIL Image object containing only the changed pixels of new_image;
        unchanged regions are black.
    """
    new_image_np = np.array(new_image)
    old_image_np = np.array(old_image)

    # Cast to a signed dtype before subtracting: uint8 subtraction wraps
    # around (e.g. 10 - 20 == 246), which would falsely flag unchanged
    # darker pixels as different.
    diff = np.abs(
        new_image_np.astype(np.int16) - old_image_np.astype(np.int16)
    )

    # Mask of pixels whose difference exceeds the tolerance in any channel.
    mask = np.any(diff > (255 * tolerance), axis=-1)

    # Start from a black canvas and copy over only the changed pixels.
    segmented_image_np = np.zeros_like(new_image_np)
    segmented_image_np[mask] = new_image_np[mask]

    # Convert the numpy array back to an image.
    return Image.fromarray(segmented_image_np)


def combine_images_with_masks(
    image_1: Image.Image,
    difference_image: Image.Image,
    old_masks: list[np.ndarray],
    new_masks: list[np.ndarray],
) -> Image.Image:
    """Merge image_1 and difference_image according to the given masks.

    Args:
        image_1: The original image as a PIL Image object.
        difference_image: The difference image as a PIL Image object.
        old_masks: List of boolean numpy masks from the original image.
        new_masks: List of boolean numpy masks from the difference image.

    Returns:
        A PIL Image object representing the combined image.
    """
    base = np.array(image_1)
    diff = np.array(difference_image)

    # Start from an all-black canvas matching the original image.
    canvas = np.zeros_like(base)

    def _overlaps(a, b):
        """Return True when the two boolean masks share any pixel."""
        return bool(np.logical_and(a, b).any())

    # Copy original pixels for every old mask that touches no new mask.
    for mask in old_masks:
        if all(not _overlaps(mask, other) for other in new_masks):
            canvas[mask] = base[mask]

    # New masks always win: paint their difference pixels on top.
    for mask in new_masks:
        canvas[mask] = diff[mask]

    # Any pixel still entirely black falls back to the original image.
    unfilled = (canvas == 0).all(axis=-1)
    canvas[unfilled] = base[unfilled]

    # Convert the numpy array back to an image.
    return Image.fromarray(canvas)


def find_matching_sections_ssim(
    image_1: Image.Image,
    image_2: Image.Image,
    block_size: int = 50,
    threshold: float = 0.9,
):
    """Find and visualize matching sections between two images using SSIM.

    Args:
        image_1: The first image as a PIL Image object.
        image_2: The second image as a PIL Image object.
        block_size: Side length of the square blocks compared via SSIM. Default is 50.
        threshold: Minimum SSIM score for a block to count as matching. Default is 0.9.

    Returns:
        A PIL Image object showing image_1's pixels in matching regions and
        black elsewhere (transparent alpha for RGBA input).
    """
    # Convert images to grayscale for the SSIM comparison.
    image_1_gray = np.array(image_1.convert("L"))
    image_2_gray = np.array(image_2.convert("L"))

    # Dimensions of the first image.
    height, width = image_1_gray.shape

    # Per-pixel match map: 255 where the enclosing block matched, else 0.
    matching_image = np.zeros_like(image_1_gray)

    # Iterate over the image in blocks.
    for y in range(0, height, block_size):
        for x in range(0, width, block_size):
            # Define the block region.
            block_1 = image_1_gray[y : y + block_size, x : x + block_size]
            block_2 = image_2_gray[y : y + block_size, x : x + block_size]

            # Edge blocks of differently-sized images may mismatch; skip those.
            if block_1.shape != block_2.shape:
                continue

            # NOTE(review): ssim uses a default win_size of 7, so trailing
            # edge blocks smaller than 7x7 will raise -- confirm inputs are
            # block-aligned or pass an explicit win_size.
            score, _ = ssim(block_1, block_2, full=True)

            # Highlight matching sections.
            if score >= threshold:
                matching_image[y : y + block_size, x : x + block_size] = 255

    # Hoisted loop invariant: the original recomputed np.array(image_1)
    # inside the per-channel loop below.
    image_1_np = np.array(image_1)

    # Create an overlay to highlight matching regions on the original image.
    overlay = np.zeros_like(image_1_np, dtype=np.uint8)

    # Copy the original pixels wherever the block matched.
    for c in range(0, 3):  # For each color channel
        overlay[:, :, c] = np.where(
            matching_image == 255, image_1_np[:, :, c], 0
        )

    # For RGBA images, set the alpha channel to fully opaque for matches.
    if image_1.mode == "RGBA":
        overlay[:, :, 3] = np.where(matching_image == 255, 255, 0)

    # Convert back to PIL Image.
    matching_image_pil = Image.fromarray(overlay)

    return matching_image_pil


def visualize(image_1: Image, image_2: Image):
    """Visualize matching, difference, and combined segmentations of two images.

    Args:
        image_1: The first image as a PIL Image object.
        image_2: The second image as a PIL Image object.

    Returns:
        None
    """
    try:
        matching_image = find_matching_sections_ssim(image_1, image_2)
        difference_image = extract_difference_image(image_2, image_1, tolerance=0.05)

        old_masks = vision.get_masks_from_segmented_image(image_1)
        new_masks = vision.get_masks_from_segmented_image(difference_image)
        combined_image = combine_images_with_masks(
            image_1, difference_image, old_masks, new_masks
        )

        adapter = adapters.get_default_segmentation_adapter()

        # Pair each source image with its segmented counterpart, in display
        # order: originals, match map, difference, then the combination.
        images = [
            image_1,
            adapter.fetch_segmented_image(image_1),
            image_2,
            adapter.fetch_segmented_image(image_2),
            matching_image,
            adapter.fetch_segmented_image(matching_image),
            difference_image,
            adapter.fetch_segmented_image(difference_image),
            combined_image,
            adapter.fetch_segmented_image(combined_image),
        ]

        for image in images:
            image.show()

    except Exception as e:
        logger.error(f"An error occurred: {e}")


# Example usage: compare an old and a new screenshot of the Windows
# Calculator and display every intermediate visualization. Guarded so
# that importing this module does not open files or display windows.
if __name__ == "__main__":
    img_2 = Image.open("../experiments/winCalNew.png")
    img_1 = Image.open("../experiments/winCalOld.png")
    visualize(img_1, img_2)
Binary file added experiments/winCalNew.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added experiments/winCalOld.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
121 changes: 120 additions & 1 deletion openadapt/strategies/visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,102 @@ def find_similar_image_segmentation(
return similar_segmentation, similar_segmentation_diff


Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove extra newline

def combine_segmentations(
    difference_image: Image.Image,
    previous_segmentation: Segmentation,
    new_descriptions: list[str],
    new_masked_images: list[Image.Image],
    new_masks: list[np.ndarray],
) -> Segmentation:
    """Combine the previous segmentation with the new segmentation of the differences.

    Args:
        difference_image: The difference image found in similar segmentation.
        previous_segmentation: The previous segmentation containing unchanged segments.
        new_descriptions: Descriptions of the new segments from the difference image.
        new_masked_images: Masked images of the new segments from the difference image.
        new_masks: Masks of the new segments.

    Returns:
        Segmentation: A new segmentation combining both previous and new segments.
    """
    image_1_np = np.array(previous_segmentation.image)
    difference_image_np = np.array(difference_image)

    # Create an empty canvas with the same dimensions and mode as image_1.
    combined_image_np = np.zeros_like(image_1_np)

    # Ensure difference_image_np has 3 channels.
    if difference_image_np.ndim == 2:  # Grayscale image
        difference_image_np = np.stack((difference_image_np,) * 3, axis=-1)

    def masks_overlap(mask1, mask2):
        """Check if two masks overlap."""
        return np.any(np.logical_and(mask1, mask2))

    # Calculate the bounding boxes and centroids for the new segments.
    new_bounding_boxes, new_centroids = vision.calculate_bounding_boxes(new_masks)

    segmentation_adapter = adapters.get_default_segmentation_adapter()
    segmented_prev_image = segmentation_adapter.fetch_segmented_image(
        previous_segmentation.image
    )
    previous_masks = vision.get_masks_from_segmented_image(segmented_prev_image)

    # Keep only previous segments that do not overlap any new segment.
    filtered_previous_masked_images = []
    filtered_previous_bounding_boxes = []
    filtered_previous_centroids = []
    for idx, prev_mask in enumerate(previous_masks):
        if not any(masks_overlap(prev_mask, new_mask) for new_mask in new_masks):
            # Apply the previous mask to the combined image where there is
            # no overlap with new masks.
            combined_image_np[prev_mask] = image_1_np[prev_mask]
            filtered_previous_masked_images.append(
                previous_segmentation.masked_images[idx]
            )
            filtered_previous_bounding_boxes.append(
                previous_segmentation.bounding_boxes[idx]
            )
            filtered_previous_centroids.append(previous_segmentation.centroids[idx])

    # Apply new masks to the combined image.
    for new_mask in new_masks:
        combined_image_np[new_mask] = difference_image_np[new_mask]

    # Fill in remaining pixels from image_1 where there are no masks.
    unfilled = (combined_image_np == 0).all(axis=-1)
    combined_image_np[unfilled] = image_1_np[unfilled]

    # Combine filtered previous segments with new segments.
    combined_masked_images = filtered_previous_masked_images + new_masked_images
    combined_bounding_boxes = filtered_previous_bounding_boxes + new_bounding_boxes
    combined_centroids = filtered_previous_centroids + new_centroids

    # Convert the numpy array back to an image.
    new_image = Image.fromarray(combined_image_np)

    marked_image = plotting.get_marked_image(
        new_image,
        new_masks,
    )

    # NOTE(review): only the NEW segments' descriptions are returned, so
    # combined_masked_images and the descriptions list have different
    # lengths when any previous segment survives the filter -- confirm
    # whether filtered previous descriptions should be concatenated here.
    return Segmentation(
        new_image,
        marked_image,
        combined_masked_images,
        new_descriptions,
        combined_bounding_boxes,
        combined_centroids,
    )


def get_window_segmentation(
action_event: models.ActionEvent,
exceptions: list[Exception] | None = None,
Expand Down Expand Up @@ -402,7 +498,30 @@ def get_window_segmentation(
# TODO XXX: create copy of similar_segmentation, but overwrite with segments of
# regions of new image where segments of similar_segmentation overlap non-zero
# regions of similar_segmentation_diff
return similar_segmentation
logger.info(f"Found similar_segmentation")
similar_segmentation_diff_image = Image.fromarray(similar_segmentation_diff)
segmentation_adapter = adapters.get_default_segmentation_adapter()
segmented_diff_image = segmentation_adapter.fetch_segmented_image(
similar_segmentation_diff_image
)
new_masks = vision.get_masks_from_segmented_image(segmented_diff_image)
new_masked_images = vision.extract_masked_images(
similar_segmentation_diff_image, new_masks
)
new_descriptions = prompt_for_descriptions(
similar_segmentation_diff_image,
new_masked_images,
action_event.active_segment_description,
exceptions,
)
updated_segmentation = combine_segmentations(
similar_segmentation_diff_image,
similar_segmentation,
new_descriptions,
new_masked_images,
new_masks,
)
return updated_segmentation

segmentation_adapter = adapters.get_default_segmentation_adapter()
segmented_image = segmentation_adapter.fetch_segmented_image(original_image)
Expand Down
15 changes: 11 additions & 4 deletions openadapt/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,18 @@ def extract_masked_images(
cropped_mask = mask[rmin : rmax + 1, cmin : cmax + 1]
cropped_image = original_image_np[rmin : rmax + 1, cmin : cmax + 1]

# Ensure the cropped image has the correct shape
if cropped_image.ndim == 2: # Grayscale image
cropped_image = cropped_image[:, :, None]
elif cropped_image.shape[2] != 1: # Color image
cropped_image = cropped_image[:, :, :3] # Keep RGB channels only

# Ensure the mask has the correct shape
reshaped_mask = cropped_mask[:, :, None]

# Apply the mask
masked_image = np.where(cropped_mask[:, :, None], cropped_image, 0).astype(
np.uint8
)
masked_images.append(Image.fromarray(masked_image))
masked_image = np.where(reshaped_mask, cropped_image, 0).astype(np.uint8)
masked_images.append(Image.fromarray(masked_image.squeeze()))

logger.info(f"{len(masked_images)=}")
return masked_images
Expand Down
2 changes: 1 addition & 1 deletion openadapt/window/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def get_active_window_data(
"""
state = get_active_window_state(include_window_data)
if not state:
return None
return {}
title = state["title"]
left = state["left"]
top = state["top"]
Expand Down