In [None]:
import os

# Widgets
import ipywidgets as widgets
from IPython.display import display

# Display image
import cv2
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from skimage.io import imread
from skimage.color import gray2rgb

from typing import List, Tuple

from demo_utils import *

In [None]:
# Interactive Tk backend: the viewer at the end of the notebook relies on
# matplotlib mouse/key events (mpl_connect), which need an interactive backend.
%matplotlib tk
# Default size for figures created without an explicit figsize
matplotlib.rcParams["figure.figsize"] = (12, 10)

In [None]:
# Create the folder where saved images go. exist_ok=True avoids the
# check-then-create race and keeps the cell safe to re-run.
os.makedirs(folder_out_path, exist_ok=True)

In [None]:
image_original_path = os.path.join(folder_path, image_original_filename)

# Read image with opencv (channels are returned in BGR order).
# cv2.imread does not raise on a missing/unreadable file: it returns None,
# which would only fail much later (e.g. on `img_ori.shape`). Fail early instead.
img_ori = cv2.imread(image_original_path)
if img_ori is None:
    raise FileNotFoundError(f"Could not read image: {image_original_path}")

In [None]:
image_rgb_path = os.path.join(folder_path, image_rgb_filename)

# Read the label (colour-coded) image with opencv.
# NOTE(review): cv2.imread returns channels in BGR order while the colours are
# compared against the label dict values; confirm both use the same order.
# cv2.imread returns None instead of raising when the file cannot be read.
img_rgb = cv2.imread(image_rgb_path)
if img_rgb is None:
    raise FileNotFoundError(f"Could not read image: {image_rgb_path}")

In [None]:
# Read superposition

def get_superposition(superposition_filename: str) -> dict:
    """
    Read the superposition file into a dictionary grouped by label.

    Each non-empty line is a comma-separated list of integers whose last
    value is the label; the preceding values are the point coordinates
    (consumed elsewhere in the notebook as ``x, y``).

    :param superposition_filename: Name of the superposition file, resolved
        relative to ``folder_path``
    :return: A dictionary mapping each label to the list of its points, in
        file order; each point is the list of ints preceding the label
    :raises ValueError: If a field is not an integer
    """
    superposition = {}
    with open(os.path.join(folder_path, superposition_filename), 'r') as f:
        for raw_line in f:
            raw_line = raw_line.strip()
            if not raw_line:
                continue  # skip blank lines (e.g. a trailing newline)
            *point, label = [int(value) for value in raw_line.split(',')]
            superposition.setdefault(label, []).append(point)
    return superposition

def get_superposition_map(superposition: dict) -> np.ndarray:
    """
    Count, for every pixel of the original image, the superposition entries on it.

    :param superposition: Mapping label -> list of points, where point[0] is
        the column (x) and point[1] the row (y)
    :return: A float array with the height/width of ``img_ori`` holding, at
        each pixel, the number of entries (over all labels) located there
    """
    counts = np.zeros(img_ori.shape[:2])
    for point_list in superposition.values():
        for point in point_list:
            col, row = point[0], point[1]
            counts[row, col] += 1
    return counts

# Superposition points grouped by label (label -> list of points)
superposition = get_superposition(superposition_filename)
# Per-pixel count of superposition entries over all labels (height/width of img_ori)
superposition_map = get_superposition_map(superposition)

In [None]:
# Read label dict

def get_label_rgb_dict(label_dict_filename : str) -> dict:
    """
    Load the label -> colour table.

    Every non-empty line holds comma-separated integers: the label first,
    then its colour components. A label appearing twice keeps its last colour.

    :param label_dict_filename: Name of the label dict file, resolved
        relative to ``folder_path``
    :return: A dictionary mapping each label (int) to a numpy array of its
        colour components
    """
    dict_path = os.path.join(folder_path, label_dict_filename)
    with open(dict_path, 'r') as handle:
        rows = [
            [int(token) for token in raw.strip().split(',')]
            for raw in handle
            if raw.strip()
        ]
    return {row[0]: np.array(row[1:]) for row in rows}

label_rgb_dict = get_label_rgb_dict(label_dict_filename)
# Reverse lookup used by the click handler: colour tuple -> label value
rgb_label_dict = dict((tuple(rgb), label) for label, rgb in label_rgb_dict.items())

In [None]:
def highlight(mask_rgb_value_to_keep : np.ndarray)-> np.ndarray:
    """
    Paint a copy of ``img_rgb`` according to a boolean selection mask.

    Selected pixels become ``[255, 0, 0]``. Every other pixel that is not
    pure-white background ``[255, 255, 255]`` becomes black. White
    background outside the selection is kept as is.

    :param mask_rgb_value_to_keep: Boolean mask with the height/width of
        ``img_rgb`` marking the pixels to keep; any array-like of booleans
        (numpy array, nested lists, 0/1 ints) is accepted
    :return: The painted copy; ``img_rgb`` itself is not modified
    """
    # Coerce to a real boolean array: `~` on a list raises, and on an int
    # array it is a bitwise not that yields a wrong (all-truthy) mask.
    keep = np.asarray(mask_rgb_value_to_keep, dtype=bool)
    img_copy = img_rgb.copy()

    is_background = np.all(img_copy == [255, 255, 255], axis=2)
    img_copy[~keep & ~is_background, :] = [0, 0, 0]
    img_copy[keep, :] = [255, 0, 0]

    return img_copy

def build_mask_rgb(label_value_to_keep : List[int], highlight : bool) -> np.ndarray:
    """
    Build the boolean pixel mask of the requested labels.

    A pixel is selected when its colour in ``img_rgb`` equals the colour of
    one of the requested labels in ``label_rgb_dict``. Superposition points
    of the requested labels (from ``superposition``) are added as well:
    always when ``highlight`` is True. Otherwise a point is added only once
    all superposition entries on that pixel (counted in
    ``superposition_map``) belong to requested labels. This way, deleting
    one label does not remove a pixel that another label still covers.

    :param label_value_to_keep: The label values to keep
    :param highlight: If the superposition should be highlighted
    :return: A boolean mask with the height/width of ``img_rgb``; all False
        when ``label_value_to_keep`` is empty
    :raises KeyError: If a label is missing from ``label_rgb_dict``
    """
    # Pixels whose colour matches one of the requested labels
    mask = np.zeros(img_rgb.shape[:2], dtype=bool)
    for label in label_value_to_keep:
        mask |= np.all(img_rgb == label_rgb_dict[label], axis=2)

    # Superposition points belonging to the requested labels
    points = [
        point
        for label in label_value_to_keep
        for point in superposition.get(label, [])
    ]

    # Work on a copy: the global count map must stay intact between calls
    remaining = superposition_map.copy()
    for x, y in points:
        remaining[y, x] -= 1
        if highlight or remaining[y, x] == 0:
            mask[y, x] = True

    return mask

def build_image(label_value_to_keep : List[int], highlight:bool=True, display_mode:str="ori"):
    """
    Build the image to display for the selected labels.

    :param label_value_to_keep: The label values to keep; when empty the base
        image is returned unchanged
    :param highlight: If True the selected pixels are painted ``[255, 0, 0]``;
        otherwise they are removed (painted white, or inpainted in "ori" mode)
    :param display_mode: The display mode:
        - "ori":   ``img_ori`` with the selection painted, or inpainted away
                   when ``highlight`` is False
        - "label": ``img_rgb`` with the selection painted
        - "bin":   ``img_rgb`` where every other non-white pixel is black and
                   the selection is painted
        Any other value returns an unmodified copy of ``img_rgb``.
    :return: The image as a new array (``img_ori``/``img_rgb`` are not modified)
    """
    base = img_ori if display_mode == "ori" else img_rgb
    img_copy = base.copy()
    if not label_value_to_keep:
        return img_copy

    mask_keep = build_mask_rgb(label_value_to_keep, highlight)
    color = [255, 0, 0] if highlight else [255, 255, 255]

    if display_mode == 'label':
        img_copy[mask_keep, :] = color
    elif display_mode == 'ori':
        if highlight:
            img_copy[mask_keep, :] = color
        else:
            # cv2.inpaint expects an 8-bit single-channel mask (non-zero = pixels to fill)
            inpaint_mask = mask_keep.astype(np.uint8) * 255
            img_copy = cv2.inpaint(img_copy, inpaint_mask, 3, cv2.INPAINT_NS)
    elif display_mode == 'bin':
        # Only meaningful on the label image: blacken every non-background
        # pixel outside the selection, then paint the selection.
        not_white = ~np.all(img_copy == [255, 255, 255], axis=2)
        img_copy[~mask_keep & not_white, :] = [0, 0, 0]
        img_copy[mask_keep, :] = color

    return img_copy

In [None]:
# Slider: pick a single label and show the image with that label highlighted.
# NOTE(review): assumes the labels to browse are the consecutive integers
# 2 .. len(label_rgb_dict) - 1; confirm against the label dict file.
slider = widgets.IntSlider(min=2, max=len(label_rgb_dict.keys())-1, step=1, value=2)

# Dedicated output widget. The interactive viewer cell rebinds the global name
# `w_out`; a callback looking `w_out` up at call time would then write into
# the viewer's widget instead of this one.
slider_out = widgets.Output()

def on_value_change(change):
    """
    Display the image highlighting the line selected by the slider
    :param change: The traitlets change notification; ``change['new']`` is
        the new slider value, used as the label to highlight
    """
    slider_out.clear_output()
    with slider_out:
        label = change['new']
        img_copy = build_image([label])
        # NOTE(review): plt.clf() acts on the *current* figure, which becomes
        # the viewer figure once the last cell has run; consider a dedicated figure.
        plt.clf()
        plt.imshow(img_copy)
        plt.show()

slider.observe(on_value_change, names='value')

display(slider)
display(slider_out)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Interactive viewer: click a line to toggle its selection; keys:
# 'd' toggles delete mode, 'r' cycles the display mode, 'v' saves the image.
fig = plt.figure(figsize=(9,14))

delete_mode = False        # False: highlight the selection; True: remove it
# Must be one of the modes build_image handles ("ori", "label", "bin"): with
# the previous value False nothing was ever drawn until 'r' was pressed.
display_mode = "ori"
current_label_values = []  # labels currently selected by clicking
img_displayed = None       # last rendered image, written to disk on 'v'

w_out = widgets.Output()

def search_nearest_active(x : int, y:int) -> Tuple[int, int]:
    """
    Search the nearest active (non-background) pixel of ``img_rgb``.

    A pixel is active when it is not pure white ``[255, 255, 255]``, the same
    background definition used when building the masks. The search starts
    with ``(x, y)`` itself, so clicking directly on a line selects that line.
    It then visits square rings of growing radius around it.

    :param x: The x (column) coordinate of the pixel
    :param y: The y (row) coordinate of the pixel
    :return: The x and y coordinate of the nearest active pixel (nearest in
        the square-ring / Chebyshev sense, not Euclidean)
    :raises ValueError: If ``img_rgb`` contains no active pixel (the original
        loop never terminated in that case)
    """
    height, width = img_rgb.shape[:2]

    def is_active(px: int, py: int) -> bool:
        # In bounds and not pure-white background
        return 0 <= py < height and 0 <= px < width and not np.all(img_rgb[py, px] == 255)

    if is_active(x, y):
        return x, y

    radius = 1
    while True:
        for i in range(-radius, radius + 1):
            for j in range(-radius, radius + 1):
                # Only the border of the square: the interior was visited at smaller radii
                if abs(i) == radius or abs(j) == radius:
                    if is_active(x + i, y + j):
                        return x + i, y + j
        # Once the square covers the whole image every pixel has been checked
        if x - radius <= 0 and y - radius <= 0 and x + radius >= width - 1 and y + radius >= height - 1:
            raise ValueError("img_rgb contains no active (non-white) pixel")
        radius += 1

def display_image() -> None:
    """
    Render the current selection into the viewer figure.

    Reads the module-level state (selected labels, delete mode, display
    mode) and builds the image with ``build_image``. The result is shown and
    stored in ``img_displayed`` so the 'v' key can save it.
    """
    global img_displayed

    rendered = build_image(
        current_label_values,
        highlight=not delete_mode,
        display_mode=display_mode,
    )

    with w_out:
        w_out.clear_output()
        plt.clf()
        plt.imshow(rendered)
        img_displayed = rendered
        plt.show()

def onclick(event):
    """
    Toggle the selection of the line closest to the clicked point.

    The click is snapped to the nearest active pixel of ``img_rgb``. The
    label of that pixel's colour is added to ``current_label_values`` if
    absent and removed otherwise; then the image is redrawn.

    :param event: The matplotlib mouse event
    """
    # Clicks outside the axes carry no data coordinates
    if event.xdata is None or event.ydata is None:
        return

    height, width = img_rgb.shape[:2]
    # imshow centres pixel k on coordinate k (it spans k-0.5..k+0.5), so round
    # instead of truncating, and clamp onto the image.
    cx = min(max(int(round(event.xdata)), 0), width - 1)
    cy = min(max(int(round(event.ydata)), 0), height - 1)

    nx, ny = search_nearest_active(cx, cy)
    label_value = rgb_label_dict[tuple(img_rgb[ny, nx])]

    # The list is mutated in place, so no `global` declaration is needed
    if label_value in current_label_values:
        current_label_values.remove(label_value)
    else:
        current_label_values.append(label_value)

    display_image()

def on_key(event):
    """
    Keyboard shortcuts of the interactive viewer.

    - ``d``: toggle delete mode and redraw
    - ``r``: cycle the display mode label -> ori -> bin -> label (any other
      value jumps to "ori") and redraw
    - ``v``: write the last displayed image into ``folder_out_path``

    :param event: The matplotlib key event
    """
    global display_mode
    global delete_mode

    # (current mode, next mode), checked in this order
    mode_cycle = (("label", "ori"), ("bin", "label"), ("ori", "bin"))

    key = event.key
    if key == 'd':
        delete_mode = not delete_mode
        display_image()
    elif key == 'r':
        display_mode = next((after for before, after in mode_cycle if display_mode == before), "ori")
        display_image()
    elif key == 'v':
        path_out = os.path.join(folder_out_path, f'{image_original_filename}_{display_mode=}_{delete_mode=}.png')
        # NOTE(review): cv2.imwrite expects BGR; confirm the saved colours match the on-screen ones.
        cv2.imwrite(path_out, img_displayed)

# Route the figure's mouse and keyboard events to their handlers,
# then draw the initial (empty) selection.
for event_name, handler in (('button_press_event', onclick),
                            ('key_press_event', on_key)):
    fig.canvas.mpl_connect(event_name, handler)

display_image()