In [None]:
import os
import re
import imageio
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.patches as mpatches
from skimage.metrics import adapted_rand_error, variation_of_information, structural_similarity, mean_squared_error
from skimage import data, img_as_float
from skimage.filters import threshold_otsu, unsharp_mask, sobel, try_all_threshold, difference_of_gaussians, gaussian, rank
from skimage.segmentation import clear_border, chan_vese, watershed
from skimage.measure import label, regionprops, approximate_polygon
from skimage.morphology import closing, binary_closing, binary_opening, square, binary_dilation, binary_erosion, area_closing, area_opening, reconstruction, ball, disk, convex_hull_image
from skimage.color import label2rgb, rgb2gray
from skimage import exposure
from skimage.util import img_as_ubyte
from skimage.feature import peak_local_max
from scipy import ndimage as ndi
from scipy import signal
import scipy.fft as fft
from scipy.fft import fft2, fftshift

from tqdm import tqdm
from collections.abc import Iterable
import functools



In [None]:
# https://github.com/scikit-image/scikit-image/blob/main/skimage/filters/_fft_based.py

# Maps a dtype's one-character code to the floating/complex type used for the
# FFT-based filters.  Codes missing from this table are handled by the caller.
new_float_type = dict(
    # types that are kept as they are
    [
        (np.dtype(kept).char, kept)
        for kept in (np.float32, np.float64, np.complex64, np.complex128)
    ]
    # types that are remapped to a supported precision
    + [
        (np.dtype(np.float16).char, np.float32),
        ('g', np.float64),      # np.float128 ; doesn't exist on windows
        ('G', np.complex128),   # np.complex256 ; doesn't exist on windows
    ]
)

def _supported_float_type(input_dtype, allow_complex=False):
    """Pick the floating-point dtype to compute with for a given input dtype.

    The mapping comes from ``new_float_type``: float32, float64, complex64
    and complex128 are kept, float16 becomes float32, the extended-precision
    codes ``'g'``/``'G'`` become float64/complex128, and anything else falls
    back to float64.

    Parameters
    ----------
    input_dtype : np.dtype or Iterable of np.dtype
        Dtype to convert. For a (non-string) iterable, every element is
        converted on its own and the results are combined with
        `np.result_type`.
    allow_complex : bool, optional
        When False, a complex input dtype raises ``ValueError``. The flag is
        not forwarded to the per-element conversion of an iterable.

    Returns
    -------
    float_type : dtype
        Floating-point dtype for the image.

    Raises
    ------
    ValueError
        If the dtype is complex and `allow_complex` is False.
    """
    is_dtype_sequence = (
        isinstance(input_dtype, Iterable) and not isinstance(input_dtype, str)
    )
    if is_dtype_sequence:
        converted = [_supported_float_type(item) for item in input_dtype]
        return np.result_type(*converted)

    dtype = np.dtype(input_dtype)
    if dtype.kind == 'c' and not allow_complex:
        raise ValueError("complex valued input is not supported")
    try:
        return new_float_type[dtype.char]
    except KeyError:
        # every dtype without an explicit entry is computed in float64
        return np.float64

def _get_nd_butterworth_filter(shape, factor, order, high_pass, real,
                               dtype=np.float64, squared_butterworth=True):
    """Build the N-dimensional Butterworth mask applied to an FFT.

    Parameters
    ----------
    shape : tuple of int
        Shape of the n-dimensional FFT and of the mask.
    factor : float
        Fraction of the mask dimensions at which the cutoff sits.
    order : float
        Exponent controlling the slope in the cutoff region.
    high_pass : bool
        True attenuates low frequencies, False attenuates high frequencies.
    real : bool
        True when the FFT comes from a real image (last axis is halved),
        False for a complex image.
    dtype : dtype, optional
        Type the squared-distance grid is cast to before the mask is
        computed.
    squared_butterworth : bool, optional
        When True the squared Butterworth response is returned, otherwise
        its square root.

    Returns
    -------
    wfilt : ndarray
        The FFT mask.
    """
    squared_axes = []
    for size in shape:
        # symmetric grid so the centre of the mask matches the FFT centre
        start = -(size - 1) // 2
        stop = (size - 1) // 2 + 1
        coords = np.arange(start, stop) / (size * factor)
        squared_axes.append(fft.ifftshift(coords ** 2))
    if real:
        # FFT of a real image: keep half of the last axis only
        # (`size` still holds the length of that last axis here)
        squared_axes[-1] = squared_axes[-1][:size // 2 + 1]

    # squared Euclidean distance to the origin, built by broadcasting
    grids = np.meshgrid(*squared_axes, indexing="ij", sparse=True)
    q2 = functools.reduce(np.add, grids).astype(dtype)
    q2 = np.power(q2, order)

    wfilt = 1 / (1 + q2)
    if high_pass:
        wfilt *= q2
    if not squared_butterworth:
        np.sqrt(wfilt, out=wfilt)
    return wfilt

def butterworth(
    image,
    cutoff_frequency_ratio=0.005,
    high_pass=True,
    order=2.0,
    channel_axis=None,
    *,
    squared_butterworth=True,
    npad=0,
):
    """Filter an image with a Butterworth response defined in Fourier space.

    Parameters
    ----------
    image : (M[, N[, ..., P]][, C]) ndarray
        Input image.
    cutoff_frequency_ratio : float, optional
        Cut-off position relative to the shape of the FFT; must lie in
        [0, 0.5].
    high_pass : bool, optional
        True applies a high-pass filter, False a low-pass filter.
    order : float, optional
        Filter order. A higher order gives a steeper slope near the cut-off
        in frequency space.
    channel_axis : int, optional
        Index of the channel dimension, if there is one. With None (the
        default) every axis is treated as spatial.
    squared_butterworth : bool, optional
        When True the square of the traditional Butterworth filter is used
        (see Notes).
    npad : int, optional
        Number of pixels added to each edge of the image with `numpy.pad`
        in ``mode='edge'`` before filtering; the padding is cropped again
        from the result.

    Returns
    -------
    result : ndarray
        The Butterworth-filtered image.

    Raises
    ------
    ValueError
        If `npad` is negative or `cutoff_frequency_ratio` is outside
        [0, 0.5].

    Notes
    -----
    Combining a high-pass and a low-pass call gives a band-pass filter.
    Increase `npad` when boundary artifacts show up.

    Image-processing textbooks [1]_, [2]_ usually call "Butterworth filter"
    the square of the traditional filter described in [3]_, [4]_. With
    ``squared_butterworth=True`` the low-pass response is

        H_low(f) = 1 / (1 + (f / (c * f_s)) ** (2 * n))

    and the high-pass response is

        H_hi(f) = 1 - H_low(f)

    where ``f`` is the absolute value of the spatial frequency (Euclidean
    norm over the spatial axes), ``f_s`` the sampling frequency, ``c`` the
    `cutoff_frequency_ratio` and ``n`` the `order` [1]_. With
    ``squared_butterworth=False`` the square root of these expressions is
    used.

    `cutoff_frequency_ratio` is expressed relative to ``f_s``. The FFT
    spectrum spans the Nyquist range ``[-f_s/2, f_s/2]``, hence the
    [0, 0.5] limits. The gain at the cut-off is 0.5 for the squared filter
    and 1/sqrt(2) otherwise.

    Examples
    --------
    High-pass on a grayscale image and low-pass on a color image:

    >>> from skimage.data import camera, astronaut
    >>> from skimage.filters import butterworth
    >>> high_pass = butterworth(camera(), 0.07, True, 8)
    >>> low_pass = butterworth(astronaut(), 0.01, False, 4, channel_axis=-1)

    References
    ----------
    .. [1] Russ, John C., et al. The Image Processing Handbook, 3rd. Ed.
           1999, CRC Press, LLC.
    .. [2] Birchfield, Stan. Image Processing and Analysis. 2018. Cengage
           Learning.
    .. [3] Butterworth, Stephen. "On the theory of filter amplifiers."
           Wireless Engineer 7.6 (1930): 536-541.
    .. [4] https://en.wikipedia.org/wiki/Butterworth_filter
    """
    if npad < 0:
        raise ValueError("npad must be >= 0")
    padded = npad > 0
    if padded:
        # remember where the unpadded image sits so it can be cropped back
        center_slice = tuple(
            slice(npad, npad + extent) for extent in image.shape
        )
        image = np.pad(image, npad, mode='edge')

    if channel_axis is None:
        fft_shape = image.shape
    else:
        fft_shape = np.delete(image.shape, channel_axis)
    is_real = np.isrealobj(image)
    float_dtype = _supported_float_type(image.dtype, allow_complex=True)

    if cutoff_frequency_ratio > 0.5 or cutoff_frequency_ratio < 0:
        raise ValueError(
            "cutoff_frequency_ratio should be in the range [0, 0.5]"
        )

    wfilt = _get_nd_butterworth_filter(
        fft_shape, cutoff_frequency_ratio, order, high_pass, is_real,
        float_dtype, squared_butterworth
    )

    axes = np.arange(image.ndim)
    if channel_axis is not None:
        # transform the spatial axes only and let the mask broadcast over
        # the channel axis
        axes = np.delete(axes, channel_axis)
        abs_channel = channel_axis % image.ndim
        index = [slice(None)] * image.ndim
        index[abs_channel] = np.newaxis
        wfilt = wfilt[tuple(index)]

    if is_real:
        forward, inverse = fft.rfftn, fft.irfftn
    else:
        forward, inverse = fft.fftn, fft.ifftn
    butterfilt = inverse(wfilt * forward(image, axes=axes),
                         s=fft_shape, axes=axes)

    if padded:
        butterfilt = butterfilt[center_slice]
    return butterfilt

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive/ so the dataset folders can be read.
drive.mount("/content/drive/")  # setting up the working directory

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [None]:
# Project root on the mounted Drive; each dataset lives in a sub-folder of it.
PATH_EP_ROOT = "/content/drive/MyDrive/EP - MAC0417 5768"
(
    PATH_DATASET_GRAY,
    PATH_DATASET_AUG,
    PATH_DATASET_GROUND,
    PATH_DATASET_NORMALIZED,
) = (
    os.path.join(PATH_EP_ROOT, folder)
    for folder in (
        "dataset_gray",
        "dataset_augmented",
        "dataset_ground_truth",
        "dataset_normalized",
    )
)
# Object classes; each name is expected to appear inside the image file names.
classNames = [
    "tesoura",
    "garrafa",
    "chave",
    "prato",
    "livro",
    "sapato",
    "chinelo",
    "celular",
    "portacopo",
    "caneca",
]

In [None]:
def getImagesDict(path):
  """Read every jpg/png/jpeg file found below `path`, walking sub-folders.

  Returns a dict mapping the bare file name to the array returned by
  `imageio.imread`. Files sharing a name in different folders overwrite each
  other (the last one read wins). A tqdm progress bar labelled with the
  folder name is shown for each folder.
  """
  acceptedExtensions = ("jpg", "png", "jpeg")
  loaded = {}
  for folder, _subfolders, folderFiles in os.walk(path):
    folderLabel = folder.split("/")[-1]
    for name in tqdm(folderFiles, desc=f"{folderLabel}  -> "):
      # the text after the last dot decides whether the file is an image
      if name.split(".")[-1] not in acceptedExtensions:
        continue
      loaded[name] = imageio.imread(os.path.join(folder, name))
  return loaded

In [None]:
# Load every image below PATH_DATASET_GRAY, keyed by file name.
grayImagesDict = getImagesDict(PATH_DATASET_GRAY)

dataset_gray  -> : 0it [00:00, ?it/s]
celular  -> : 100%|██████████| 180/180 [00:02<00:00, 69.84it/s] 
sapato  -> : 100%|██████████| 216/216 [00:02<00:00, 83.82it/s] 
chinelo  -> : 100%|██████████| 144/144 [00:01<00:00, 75.85it/s] 
caneca  -> : 100%|██████████| 144/144 [00:01<00:00, 78.24it/s] 
tesoura  -> : 100%|██████████| 144/144 [00:02<00:00, 62.20it/s] 
livro  -> : 100%|██████████| 144/144 [00:01<00:00, 76.80it/s] 
portacopo  -> : 100%|██████████| 144/144 [00:02<00:00, 54.79it/s]
garrafa  -> : 100%|██████████| 144/144 [00:01<00:00, 73.31it/s] 
prato  -> : 100%|██████████| 144/144 [00:02<00:00, 68.48it/s] 
chave  -> : 100%|██████████| 144/144 [00:01<00:00, 85.58it/s] 


In [None]:
# Load every image below PATH_DATASET_AUG (augmented dataset), keyed by file name.
imagesDict = getImagesDict(PATH_DATASET_AUG)

dataset_augmented  -> : 0it [00:00, ?it/s]
celular  -> : 100%|██████████| 675/675 [00:13<00:00, 48.25it/s] 
sapato  -> : 100%|██████████| 720/720 [00:10<00:00, 69.91it/s] 
chinelo  -> : 100%|██████████| 720/720 [00:14<00:00, 48.60it/s] 
caneca  -> : 100%|██████████| 720/720 [00:13<00:00, 53.54it/s] 
tesoura  -> : 100%|██████████| 720/720 [00:12<00:00, 56.03it/s] 
livro  -> : 100%|██████████| 720/720 [00:11<00:00, 61.66it/s] 
portacopo  -> : 100%|██████████| 720/720 [00:12<00:00, 56.93it/s] 
garrafa  -> : 100%|██████████| 720/720 [00:14<00:00, 50.15it/s] 
prato  -> : 100%|██████████| 720/720 [00:11<00:00, 64.29it/s] 
chave  -> : 100%|██████████| 720/720 [00:09<00:00, 74.57it/s] 


In [None]:
# Load every ground-truth image below PATH_DATASET_GROUND, keyed by file name.
groundTruthImagesDict = getImagesDict(PATH_DATASET_GROUND)

dataset_ground_truth  -> : 0it [00:00, ?it/s]
garrafa  -> : 100%|██████████| 22/22 [00:06<00:00,  3.25it/s]
portacopo  -> : 100%|██████████| 22/22 [00:06<00:00,  3.28it/s]
prato  -> : 100%|██████████| 21/21 [00:06<00:00,  3.30it/s]
caneca  -> : 100%|██████████| 27/27 [00:08<00:00,  3.22it/s]
livro  -> : 100%|██████████| 27/27 [00:07<00:00,  3.45it/s]
sapato  -> : 100%|██████████| 26/26 [00:08<00:00,  3.13it/s]
celular  -> : 100%|██████████| 24/24 [00:07<00:00,  3.28it/s]
chinelo  -> : 100%|██████████| 24/24 [00:07<00:00,  3.34it/s]
tesoura  -> : 100%|██████████| 21/21 [00:05<00:00,  3.54it/s]
chave  -> : 100%|██████████| 22/22 [00:06<00:00,  3.21it/s]


In [None]:
# Load every image below PATH_DATASET_NORMALIZED, keyed by file name.
normalizedImagesDict = getImagesDict(PATH_DATASET_NORMALIZED)

dataset_normalized  -> : 0it [00:00, ?it/s]
prato  -> : 100%|██████████| 720/720 [00:15<00:00, 45.93it/s] 
celular  -> : 100%|██████████| 675/675 [00:11<00:00, 60.85it/s] 
sapato  -> : 100%|██████████| 720/720 [00:12<00:00, 57.76it/s] 
chinelo  -> : 100%|██████████| 720/720 [00:12<00:00, 56.37it/s] 
caneca  -> : 100%|██████████| 720/720 [00:13<00:00, 55.22it/s] 
tesoura  -> : 100%|██████████| 720/720 [00:10<00:00, 68.50it/s] 
livro  -> : 100%|██████████| 720/720 [00:13<00:00, 52.89it/s] 
portacopo  -> : 100%|██████████| 720/720 [00:11<00:00, 61.08it/s] 
garrafa  -> : 100%|██████████| 720/720 [00:10<00:00, 71.33it/s] 
chave  -> : 100%|██████████| 720/720 [00:14<00:00, 51.37it/s] 


In [None]:
def getOriginalGroundTruthDict(imagesDict, groundTruthImageDict):
  """Select the images of `imagesDict` that have a ground-truth counterpart.

  Each key of `groundTruthImageDict` is turned into the name of its source
  image by replacing every "_<letters>.png" occurrence with ".png"; keys
  without such a suffix are used unchanged. Names missing from `imagesDict`
  are skipped.

  Returns a dict mapping source file name -> image taken from `imagesDict`.
  """
  # raw string: "\." in a normal string literal is an invalid escape sequence
  suffixPattern = re.compile(r"_[a-zA-Z]+\.png")
  selected = {}
  for key in groundTruthImageDict:
    filename = suffixPattern.sub(".png", key)
    if filename in imagesDict:
      selected[filename] = imagesDict[filename]
  return selected

# Gray images that have a matching ground-truth entry, as a dict keyed by
# file name and as a plain list of the same images.
originalGroundTruthImagesDict = getOriginalGroundTruthDict(grayImagesDict, groundTruthImagesDict)
originalGroundTruthImagesList = list(originalGroundTruthImagesDict.values())

In [None]:
def plot(image):
  """Display `image` alone on a 5x3 inch figure with the axes hidden."""
  figure, axis = plt.subplots(figsize=(5, 3))
  axis.imshow(image)
  axis.set_axis_off()
  figure.tight_layout()
  plt.show()

# def applyFourier(image):
#   return np.fft.fftshift(np.fft.fft2(image))

# def undoFourier(image):
#   return abs(np.fft.ifft2(image))

def isImageBorderTrue(image, radius=10):
  """Tell whether the border band of `image` is mostly non-zero.

  The border is the frame of `radius` pixels along the four edges of the
  first two axes. Every border pixel is counted exactly once, including the
  corners shared by two bands.

  Parameters:
    image: array with at least two dimensions.
    radius: thickness of the border band in pixels (default 10).

  Returns 1 when the rounded fraction of non-zero border values is 1 and 0
  otherwise (`round` is used, so a fraction of exactly 0.5 yields 0). An
  image without pixels yields 0.

  Raises ValueError if `radius` is smaller than 1.
  """
  if radius < 1:
    raise ValueError("radius must be >= 1")
  image = np.asarray(image)
  borderMask = np.zeros(image.shape[:2], dtype=bool)
  borderMask[:radius, :] = True    # top band
  borderMask[-radius:, :] = True   # bottom band
  borderMask[:, :radius] = True    # left band
  borderMask[:, -radius:] = True   # right band
  border = image[borderMask]
  if border.size == 0:
    return 0
  return round(np.count_nonzero(border) / border.size)

def isImageMostlyTrue(image):
  """Return the rounded fraction (0 or 1) of non-zero values in `image`.

  `round` is used, so a fraction of exactly 0.5 yields 0.
  """
  nonzeroFraction = np.count_nonzero(image) / image.size
  return round(nonzeroFraction)

# Square structuring elements of side 3, 5, 10 and 50, built once with
# skimage.morphology.square for the morphological operations in this notebook.
square3 = square(3)
square5 = square(5)
square10 = square(10)
square50 = square(50)


def plot_label(label_image, image, regions=None, min_area=625):
    """Overlay `label_image` on `image` and box the large enough regions.

    Parameters
    ----------
    label_image : ndarray
        Labelled image; label 0 is treated as background.
    image : ndarray
        Image drawn underneath the coloured labels.
    regions : iterable of region properties, optional
        Regions to outline. When None, they are computed with
        `regionprops(label_image)`.
    min_area : number, optional
        Regions with an area below this value get no bounding box.
    """
    # to make the background transparent, pass the value of `bg_label`,
    # and leave `bg_color` as `None` and `kind` as `overlay`
    image_label_overlay = label2rgb(label_image, image=image, bg_label=0)
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.imshow(image_label_overlay)

    if regions is None:
        regions = regionprops(label_image)
    for region in regions:
        if region.area < min_area:
            continue
        # red rectangle around the bounding box of the region
        minr, minc, maxr, maxc = region.bbox
        rect = mpatches.Rectangle((minc, minr), maxc - minc, maxr - minr,
                                  fill=False, edgecolor='red', linewidth=2)
        ax.add_patch(rect)
    ax.set_axis_off()
    plt.tight_layout()

def reconstruct(image):
  """Run a morphological reconstruction by erosion of `image`.

  The marker equals `image` on its outermost rows and columns and the image
  maximum everywhere else; `image` itself is used as the mask. The caller
  uses this step to fill holes.
  """
  marker = np.copy(image)
  marker[1:-1, 1:-1] = image.max()
  return reconstruction(marker, image, method='erosion')


In [None]:
def segmentation(original, plotSteps = True, sigma1 = 4, sigma2 = 4.1):
  """Segment the main object of a 2-D gray image.

  Steps: difference of Gaussians -> Otsu threshold -> inversion when most
  pixels are foreground -> 3x3 closing -> area opening (512) -> 10x10
  closing -> hole filling -> labelling -> keep the regions whose score
  (area / squared distance of the centroid to the image centre) is at least
  the mean score -> convex hull of the kept regions.

  Parameters:
    original: 2-D image to segment.
    plotSteps: when True, every intermediate image is printed and plotted.
    sigma1, sigma2: sigmas passed to `difference_of_gaussians`.

  Returns:
    (image, label_image): the boolean mask of the segmented object (all
    False when no region was found) and the labelled image from which the
    regions were selected.
  """
  def show(text, current):
    if plotSteps:
      print(text)
      plot(current)

  show("original", original)

  image = difference_of_gaussians(original, sigma1, sigma2)
  show("DoG", image)

  # apply threshold
  thresh = threshold_otsu(image)
  if plotSteps:
    try_all_threshold(image, figsize=(10, 6), verbose=False)
    plt.show()
  image = image > thresh
  show("Threshold", image)

  # the object is expected to be the minority class: flip when it is not
  if isImageMostlyTrue(image):
    image = np.logical_not(image)
  show("Corrected", image)

  image = binary_closing(image, square3)
  show("Closing", image)

  image = area_opening(image, area_threshold = 512)
  show("Area Opening", image)

  image = binary_closing(image, square10)
  show("Closing", image)

  image = reconstruct(image)
  show("Fill holes", image)

  # label image regions
  label_image = label(image)
  props = regionprops(label_image)

  # TODO: use only the biggest blob? use morphological snakes to
  # approximate the contour?
  m, n = image.shape
  centerx = m / 2
  centery = n / 2

  def getValue(region):
    # large regions close to the image centre get the highest score
    cx, cy = region.centroid
    dx = cx - centerx
    dy = cy - centery
    distanceToCenter = np.sqrt(dx * dx + dy * dy)
    if distanceToCenter == 0:
      distanceToCenter = 0.1
    return 1/(distanceToCenter ** 2) * region.area

  # score every region once; np.mean of an empty list would return nan and
  # emit a RuntimeWarning, so the no-region case is handled explicitly
  values = [getValue(region) for region in props]
  if values:
    valueThreshold = np.mean(values)
    chosenRegions = [
      region for region, value in zip(props, values)
      if value >= valueThreshold
    ]
  else:
    chosenRegions = []

  image = np.zeros(label_image.shape, dtype=bool)
  for region in chosenRegions:
    image |= label_image == region.label
  show("Chosen regions", image)

  if image.any():
    image = convex_hull_image(image)
  show("Convex hull", image)

  if plotSteps:
    plot_label(label_image, original, chosenRegions)

  return image, label_image

# Run the DoG-based segmentation, plotting every step, on one reference
# picture ("<class>_obj1_p1_di_f1.png") of each class from the gray dataset.
dictToSegment = grayImagesDict
for className in classNames:
  filename = f"{className}_obj1_p1_di_f1.png"
  if (filename not in dictToSegment):
    print("Not found: ", filename)
    continue
  image, label_image = segmentation(dictToSegment[filename], plotSteps = True, sigma1 = 4, sigma2 = 4.1)
  # Parameter sweeps kept for reference (disabled):
  # for sigma in [3, 4, 5, 6, 7, 8, 10]:
  #   sigma2 = sigma + 0.1
  #   print(sigma, sigma2)
  #   image, label_image = segmentation(dictToSegment[filename], plotSteps = False, sigma1 = sigma, sigma2 = sigma2)
  #   plot(image)
  #   plot_label(label_image, dictToSegment[filename])
  # for step in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]:
  #   sigma = 4
  #   sigma2 = sigma + step
  #   print(sigma, sigma2)
  #   image, label_image = segmentation(dictToSegment[filename], plotSteps = False, sigma1 = sigma, sigma2 = sigma2)
  #   plot(image)
  #   plot_label(label_image, dictToSegment[filename])

Output hidden; open in https://colab.research.google.com to view.

In [None]:
def test_chan_vese(original):
  """Try a grid of Chan-Vese parameters on `original` and plot each result.

  For every combination of mu, lambda1, lambda2 and dt (checkerboard initial
  level set, tol=1e-3) the parameters are printed and a 2x2 figure shows the
  input image, the segmentation, the final level set and the energy curve.
  """
  init_level_set = 'checkerboard'
  for mu in [0.25, 0.50, 0.75]:
    for lambda1 in [1, 2, 3]:
      for lambda2 in [1, 2, 3]:
        for dt in [0.25, 0.5, 0.75]:
          print(mu, lambda1, lambda2, dt, init_level_set)
          segmented, levelSet, energies = chan_vese(
              original, mu=mu, lambda1=lambda1, lambda2=lambda2, tol=1e-3,
              dt=dt, init_level_set=init_level_set, extended_output=True)

          fig, axes = plt.subplots(2, 2, figsize=(8, 8))
          ax = axes.flatten()

          # show the image that was actually segmented, not a global one
          ax[0].imshow(original, cmap="gray")
          ax[0].set_axis_off()
          ax[0].set_title("Original Image", fontsize=12)

          ax[1].imshow(segmented, cmap="gray")
          ax[1].set_axis_off()
          title = f'Chan-Vese segmentation - {len(energies)} iterations'
          ax[1].set_title(title, fontsize=12)

          ax[2].imshow(levelSet, cmap="gray")
          ax[2].set_axis_off()
          ax[2].set_title("Final Level Set", fontsize=12)

          ax[3].plot(energies)
          ax[3].set_title("Evolution of energy over iterations", fontsize=12)

          fig.tight_layout()
          plt.show()

def segmentByChanVese(original, plotSteps = False):
  """Segment `original` with Chan-Vese followed by a blur and a threshold.

  Steps: Chan-Vese (mu=0.2, lambda1=lambda2=1, dt=0.5, tol=1e-3,
  checkerboard start) -> inversion when most pixels are foreground ->
  Gaussian blur (sigma 5) -> threshold at 80% of the blurred maximum.
  The threshold value is always printed. With `plotSteps` every
  intermediate image is printed and plotted, and the labelled result is
  drawn over `original`.

  Returns the boolean mask obtained after the threshold.
  """
  def show(caption, current):
    if plotSteps:
      print(caption)
      plot(current)

  show("Original", original)

  chanVeseOptions = dict(
    mu=0.2,
    lambda1=1,
    lambda2=1,
    tol=1e-3,
    dt=0.5,
    init_level_set="checkerboard",
    extended_output=True,
  )
  segmented, _levelSet, _energies = chan_vese(original, **chanVeseOptions)
  show("ChanVese", segmented)

  if isImageMostlyTrue(segmented):
    segmented = np.logical_not(segmented)
  show("Corrected", segmented)

  blurred = gaussian(segmented, 5)
  show("Gaussian", blurred)
  if plotSteps:
    try_all_threshold(blurred, figsize=(10, 6), verbose=False)
    plt.show()

  cutoff = np.max(blurred) * 0.8
  print(cutoff)
  mask = blurred > cutoff
  show("Threshold", mask)

  labelled = label(mask)
  if plotSteps:
    plot_label(labelled, original)

  return mask
# Run the Chan-Vese based segmentation, plotting every step, on the same
# reference picture of each class taken from `dictToSegment`.
for className in classNames:
  filename = f"{className}_obj1_p1_di_f1.png"
  if (filename not in dictToSegment):
    print("Not found: ", filename)
    continue
  image = segmentByChanVese(dictToSegment[filename], plotSteps = True)
# Disabled: parameter grid search on a single picture.
#test_chan_vese(dictToSegment['sapato_obj1_p1_di_f3.png'])

Output hidden; open in https://colab.research.google.com to view.

In [None]:
from sklearn.metrics import jaccard_score

def jaccard_index(img_a: np.ndarray, img_b: np.ndarray) -> float:
  """Return the Jaccard score between two binary images.

  Both images are flattened and a pixel counts as foreground only when its
  value equals 1.
  NOTE(review): masks stored as 0/255 would be read as all background here;
  confirm the value range used by the callers.

  Raises ValueError when the images do not hold the same number of elements.
  """
  if img_a.size != img_b.size:
    raise ValueError('Image sizes differ.')
  # the comparison already allocates new arrays, no explicit copy is needed
  img_a_flat = np.ravel(img_a) == 1
  img_b_flat = np.ravel(img_b) == 1
  return jaccard_score(img_a_flat, img_b_flat)

def getDictOfClass(dictIndexedByFilename, nameOfClass):
  """Return the entries whose key contains `nameOfClass` as a substring."""
  return {
    filename: entry
    for filename, entry in dictIndexedByFilename.items()
    if nameOfClass in filename
  }

def getAllDictsOfClass(dictIndexedByFilename):
  """Split a dict keyed by file name into one sub-dict per name in `classNames`."""
  return {
    name: getDictOfClass(dictIndexedByFilename, name)
    for name in classNames
  }

def testAccuracyOfSegmentation(groundTruthImagesDict, originalGroundTruthImagesDict, segmentationFunction, filterFilename=None, plotSteps=False):
  """Score a segmentation function against the ground-truth masks.

  Parameters:
    groundTruthImagesDict: file name -> ground-truth image.
    originalGroundTruthImagesDict: file name -> image to segment.
    segmentationFunction: callable taking an image and returning its mask.
    filterFilename: when set, only file names containing it are evaluated.
    plotSteps: when True, each mask, its ground truth and its metrics are
      shown.

  Images without ground truth are reported and skipped; images whose
  segmentation is empty are skipped silently. The per-image, overall mean
  and per-class mean metrics are printed.

  Returns:
    (errorDict, allDictsOfClass, meanIndexes, classMeanDict): metrics per
    file, metrics per file grouped by class, mean of every metric, and mean
    of every metric per class.
  """
  def getMean(errors):
    if not errors:
      return {}
    return {key: np.mean([e[key] for e in errors]) for key in errors[0]}

  errorDict = {}
  for filename, image in originalGroundTruthImagesDict.items():
    if filterFilename and filterFilename not in filename:
      continue
    if filename not in groundTruthImagesDict:
      print("Ground truth image not found:", filename)
      continue
    groundTruth = groundTruthImagesDict[filename]
    imageToBeLabeled = segmentationFunction(image)
    if np.count_nonzero(imageToBeLabeled) == 0:
      continue
    error, precision, recall = adapted_rand_error(groundTruth, imageToBeLabeled)
    splits, merges = variation_of_information(groundTruth, imageToBeLabeled)
    # MSE and SSIM are intensity based: bring both images to the same float
    # scale first, otherwise an integer ground truth (e.g. 0..255) is
    # compared with a 0/1 mask and `data_range=1` does not hold
    groundTruthFloat = img_as_float(groundTruth)
    predictionFloat = img_as_float(imageToBeLabeled)
    mse = mean_squared_error(groundTruthFloat, predictionFloat)
    ssi = structural_similarity(groundTruthFloat, predictionFloat, data_range=1)
    currentErrors = {
        "mean_squared_error" : mse,
        "structural_similarity" : ssi,
        "error" : error,
        "precision" : precision,
        "recall" : recall,
        "splits" : splits,
        "merges" : merges
    }
    if plotSteps:
      print(filename)
      plot(imageToBeLabeled)
      plot(groundTruth)
      print(currentErrors)
    errorDict[filename] = currentErrors

  print("Errors", errorDict)
  meanIndexes = getMean(list(errorDict.values()))
  print("Mean errors\n", meanIndexes)
  allDictsOfClass = getAllDictsOfClass(errorDict)
  print(allDictsOfClass)
  classMeanDict = {
    className: getMean(list(classErrors.values()))
    for className, classErrors in allDictsOfClass.items()
  }
  print("Errors by class\n", classMeanDict)
  return errorDict, allDictsOfClass, meanIndexes, classMeanDict

# Evaluate the DoG-based segmentation separately on the file names containing
# 'f1', 'f2' and 'f3'. `error` is overwritten by each call, so it ends up
# holding only the result of the last one.
error = testAccuracyOfSegmentation(groundTruthImagesDict, originalGroundTruthImagesDict, lambda image : segmentation(image, plotSteps = False)[0], filterFilename='f1')
error = testAccuracyOfSegmentation(groundTruthImagesDict, originalGroundTruthImagesDict, lambda image : segmentation(image, plotSteps = False)[0], filterFilename='f2')
error = testAccuracyOfSegmentation(groundTruthImagesDict, originalGroundTruthImagesDict, lambda image : segmentation(image, plotSteps = False)[0], filterFilename='f3')
# Disabled: same evaluation for the Chan-Vese based segmentation.
# error = testAccuracyOfSegmentation(groundTruthImagesDict, originalGroundTruthImagesDict, lambda image : segmentByChanVese(image, plotSteps = False))

  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)
  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


Errors {'garrafa_obj1_p2_de_f1.png': {'mean_squared_error': 9562.177154541016, 'structural_similarity': 0.8012295935396861, 'error': 0.10511283723163678, 'precision': 0.8491724548926355, 'recall': 0.9458039724943059, 'splits': 0.09851508394531436, 'merges': 0.2573666061930155}, 'garrafa_obj1_p2_ni_f1.png': {'mean_squared_error': 8016.714004516602, 'structural_similarity': 0.8211312813826878, 'error': 0.05671897461085462, 'precision': 0.9525548169289546, 'recall': 0.9341860665618726, 'splits': 0.10906342936690958, 'merges': 0.21621229970715228}, 'garrafa_obj1_p3_de_f1.png': {'mean_squared_error': 9486.23063659668, 'structural_similarity': 0.8017635532591062, 'error': 0.07004663208482209, 'precision': 0.9101821162280148, 'recall': 0.9506026466095221, 'splits': 0.08288171424172106, 'merges': 0.23488464036504803}, 'garrafa_obj2_p2_de_f1.png': {'mean_squared_error': 7926.325256347656, 'structural_similarity': 0.835861166323417, 'error': 0.13534372861051092, 'precision': 0.8074198377152592, 

  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)
  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)
  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)
  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)
  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


Errors {'garrafa_obj1_p1_ne_f2.png': {'mean_squared_error': 6634.256988525391, 'structural_similarity': 0.8672495857440881, 'error': 0.23885956543122655, 'precision': 0.6708881431131878, 'recall': 0.8794498158155948, 'splits': 0.07383530459465713, 'merges': 0.4826160694827116}, 'garrafa_obj1_p2_di_f2.png': {'mean_squared_error': 6605.22346496582, 'structural_similarity': 0.7419066366350212, 'error': 0.08229778157798817, 'precision': 0.9548108013852732, 'recall': 0.8833701653890444, 'splits': 0.4943446478525184, 'merges': 0.34509975282983446}, 'garrafa_obj1_p3_ni_f2.png': {'mean_squared_error': 7288.2144775390625, 'structural_similarity': 0.8458099599871197, 'error': 0.2376659767842354, 'precision': 0.6992877942738774, 'recall': 0.8378749623443992, 'splits': 0.07511114031846243, 'merges': 0.5619861543153579}, 'garrafa_obj2_p2_di_f2.png': {'mean_squared_error': 6609.116470336914, 'structural_similarity': 0.8634002582655048, 'error': 0.29516567287470796, 'precision': 0.6036646309922288, '