In [1]:
import cv2
from skimage.segmentation import slic, mark_boundaries
from skimage.restoration import denoise_bilateral
from skimage.io import imread
from skimage.transform import rescale, resize

import numpy as np
import matplotlib.pyplot as plt

In [63]:
# Input image. Other test images that were tried (swap IMAGE_PATH to use one):
#   /home/ubuntu/531763067-0c995e2e-649c-4a87-9b89-31f2eed67255.jpeg
#   /home/ubuntu/animal_facts-e1396431549968.jpg
#   /home/ubuntu/534166831-ece1f0ab-13f7-4bcb-b9c6-2697082bb49b.png
#   /home/ubuntu/534167110-a77f9c2c-b1fb-4349-8b45-0ffb8beb2898.png
#   /home/ubuntu/531764276-e3d4588d-2764-41a2-a4b8-53d476d443bb.jpeg
IMAGE_PATH = "/home/ubuntu/bigstock-Sailboats-At-Sea-2767411-1.png"
image = imread(IMAGE_PATH)

In [None]:
# shape of the loaded image
np.shape(image)

In [65]:
# Optionally downscale "big" images: the recursive DFS flood fill hits Python's recursion limit on large regions (at least in a Jupyter notebook).
# image = rescale(image, 0.125, channel_axis=-1, anti_aliasing=True)

In [None]:
# shape after the optional rescale
np.shape(image)

In [None]:
# Show the input image with a title so the figure stands on its own.
fig, ax = plt.subplots()
ax.imshow(image)
ax.set_title("Input image")
plt.show()

In [68]:
# Edge-preserving (bilateral) smoothing of the image before the k-means
# colour quantisation below.
bilateral_kwargs = dict(
    sigma_color=0.05,
    sigma_spatial=3.0,
    channel_axis=-1,
    mode='edge',
)
img_filt = denoise_bilateral(np.asarray(image), **bilateral_kwargs)

In [69]:
from sklearn.cluster import KMeans

In [None]:
# Quantise the filtered image to `ncolors` colours. Every pixel is one sample
# whose features are its channel values. The channel count is read from the
# image (assumes channels-last, matching channel_axis=-1 above) so RGBA PNGs
# work as well as RGB; a hard-coded 3 fails on 4-channel input.
ncolors = 16
pixels = img_filt.reshape(-1, img_filt.shape[-1])
kmeans = KMeans(n_clusters=ncolors, random_state=0)  # seeded for reproducibility
# fit_predict is fit(...) followed by predict(...): each pixel's nearest
# centre, i.e. the argmin over the distances returned by `transform`.
img_kmean_labels = kmeans.fit_predict(pixels).reshape(img_filt.shape[:2])

In [None]:
plt.imshow(img_kmean_labels)
plt.title(f"K-means colour labels ({ncolors} clusters)")
plt.show()

In [72]:
import sys
# DFS below recurses once per pixel of a region, so large regions need a much
# higher limit than the default. NOTE: very deep recursion can still overflow
# the C stack and crash the kernel; the iterative `floodfill` avoids this.
sys.setrecursionlimit(100000)

from collections import deque

def DFS(image_array, label_array, x, y, target_value, label_value, count=0, coords=None):
    """Recursively flood-fill the 4-connected region of `target_value` containing (x, y).

    Each reached pixel that is still unlabeled (-1 in `label_array`) gets
    `label_value` written into `label_array` (in place).

    Args:
        image_array: 2-D array of per-pixel values, indexed [y, x].
        label_array: array of the same shape; -1 marks unlabeled pixels. Mutated.
        x, y: column / row of the pixel to visit.
        target_value: value a pixel must have in `image_array` to join the region.
        label_value: label written for every pixel of the region.
        count: running pixel count, threaded through the recursion.
        coords: running list of visited (y, x) tuples; a fresh list is
            created when omitted.

    Returns:
        (count, coords): updated pixel count and list of (y, x) coordinates.
    """
    # A mutable default ([]) would be shared by every call that omits
    # `coords`, so later fills would inherit earlier fills' coordinates.
    if coords is None:
        coords = []
    if (
        x < 0
        or x >= image_array.shape[1]
        or y < 0
        or y >= image_array.shape[0]
        or image_array[y, x] != target_value
        or label_array[y, x] != -1  # stop if already labeled
    ):
        return count, coords
    label_array[y, x] = label_value
    coords.append((y, x))
    count += 1
    # right, down, up, left
    for nx, ny in ((x + 1, y), (x, y + 1), (x, y - 1), (x - 1, y)):
        count, coords = DFS(image_array, label_array, nx, ny, target_value, label_value, count, coords)
    return count, coords

def floodfill(image_array, label_array, x, y, target_value, label_value, color_array):
    """Iteratively (breadth-first) flood-fill the 4-connected region containing (x, y).

    The seed pixel is labeled unconditionally. A neighbour joins the region
    when it equals `target_value` in `image_array` and is still -1 in
    `label_array`.

    Args:
        image_array: 2-D array of per-pixel values, indexed [y, x].
        label_array: array of the same shape; -1 = unlabeled. Mutated in place.
        x, y: column / row of the seed pixel.
        target_value: value that defines region membership.
        label_value: label written for every pixel of the region.
        color_array: array indexed [y, x] giving each pixel's colour.

    Returns:
        (count, coords, colors): number of pixels, their (y, x) coordinates and
        the matching `color_array[y, x]` entries, in visiting order.
    """
    n_rows, n_cols = image_array.shape[:2]
    queue = deque([(y, x)])

    label_array[y, x] = label_value
    coords = [(y, x)]
    colors = [color_array[y, x]]

    while queue:
        cy, cx = queue.popleft()
        # left, right, up, down
        for dy, dx in ((0, -1), (0, 1), (-1, 0), (1, 0)):
            ny, nx = cy + dy, cx + dx
            if (
                0 <= nx < n_cols
                and 0 <= ny < n_rows
                and image_array[ny, nx] == target_value
                and label_array[ny, nx] == -1  # stop if already labeled
            ):
                label_array[ny, nx] = label_value
                coords.append((ny, nx))
                # Colour of the pixel being added, not of its parent.
                colors.append(color_array[ny, nx])
                queue.append((ny, nx))

    return len(coords), coords, colors

# Label connected components of equal k-means label.
#   region_labels[y, x] -> component id (-1 while unvisited)
#   area_counts / region_pixels / region_colors -> per component (indexed by id):
#   pixel count, list of (y, x) coordinates, list of pixel colours.
region_labels = np.full(img_kmean_labels.shape, -1, dtype=int)

c_lbl = -1  # id of the most recently created component
area_counts = []
region_pixels = []
region_colors = []
for i, j in np.ndindex(*img_kmean_labels.shape[:2]):
    if region_labels[i, j] != -1:
        continue  # already part of an earlier component
    c_lbl += 1
    # Breadth-first fill from (row i, col j). Recursive alternative:
    # DFS(img_kmean_labels, region_labels, j, i, img_kmean_labels[i, j], c_lbl, 0, [])
    count, coords, colors = floodfill(
        img_kmean_labels, region_labels, j, i, img_kmean_labels[i, j], c_lbl, img_filt
    )
    area_counts.append(count)
    region_pixels.append(coords)
    region_colors.append(colors)


In [None]:
# Peek at the colours recorded for the last region (first few pixels only),
# without relying on the `colors` variable leaked from the labelling loop.
region_colors[-1][:5]

In [None]:
# Mean (y, x) position of each region, drawn as a single scatter-like line
# instead of one plot call per region.
region_centroids = np.array(
    [np.mean(np.array(pix_coords), 0) for pix_coords in region_pixels]
).reshape(-1, 2)
plt.imshow(region_labels)
plt.plot(region_centroids[:, 1], region_centroids[:, 0], '.r')
plt.title("Region centroids")
plt.show()

In [75]:
from copy import deepcopy

In [None]:
# merge small areas: regions with fewer than MIN_AREA_FRACTION of the image's
# pixels are merged into a neighbour further below.
MIN_AREA_FRACTION = 0.001
minArea = int(MIN_AREA_FRACTION * img_kmean_labels.shape[0] * img_kmean_labels.shape[1])
print(minArea)
img_kmean_labels2 = np.copy(img_kmean_labels)
region_labels2 = np.copy(region_labels)  # rebuilt from the merged graph later


In [77]:
class Node:
    """One region of the region-adjacency graph.

    Attributes:
        id: region id (the value used in the region label image).
        area: number of pixels in the region.
        coords: list of (y, x) pixel coordinates.
        colors: array built from the per-pixel colours, one row per pixel
            (presumably (n_pixels, n_channels) -- depends on the caller).
        edges: set of neighbouring Node objects; neighbours are matched by id.
    """

    def __init__(self, id, area, coords, colors):
        self.id = id
        self.area = area
        self.coords = deepcopy(coords)  # own copy: merging extends it in place
        self.colors = np.array(colors)  # np.array already copies its input
        self.edges = set()

    def __repr__(self):
        return f"Node {self.id} with area {self.area} and color {self.color}"

    @property
    def centroid(self):
        """Mean (y, x) coordinate of the region."""
        return np.mean(np.array(self.coords), 0)

    @property
    def color(self):
        """Mean colour of the region (NaN if it has no colour samples)."""
        return np.mean(self.colors, 0)

    def add_colors(self, new_colors):
        """Append rows of `new_colors` to this node's colour samples.

        Empty input is a no-op; if this node has no samples yet, the new
        samples are adopted directly (concatenating with an empty 1-D
        array would fail on a rank mismatch).
        """
        new_colors = np.asarray(new_colors)
        if new_colors.size == 0:
            return
        if np.size(self.colors) == 0:
            self.colors = np.array(new_colors)
            return
        self.colors = np.concatenate((self.colors, new_colors), 0)

    def add_edge(self, node):
        """Add `node` as a neighbour unless a neighbour with its id exists."""
        if self.find_edge(node) == -1:
            self.edges.add(node)

    def find_edge(self, node):
        """Position (in iteration order of `edges`) of the neighbour with `node.id`, or -1."""
        return next((i for i, n in enumerate(self.edges) if n.id == node.id), -1)

    def remove_all_edges(self):
        """Drop every neighbour."""
        self.edges = set()

    def remove_edge(self, node):
        """Remove the neighbour(s) whose id equals `node.id`.

        Uses the same id-based matching as add_edge/find_edge; the set is
        modified in place.
        """
        for n in [n for n in self.edges if n.id == node.id]:
            self.edges.discard(n)

In [78]:
# One Node per connected component; the node id is the component id.
nodes = [
    Node(region_id, area, pix_coords, pix_colors)
    for region_id, (area, pix_coords, pix_colors) in enumerate(
        zip(area_counts, region_pixels, region_colors)
    )
]

In [None]:
# exactly one node per labelled component
assert len(nodes) == len(area_counts) == c_lbl + 1
len(nodes)

In [None]:
# Mean colour of an arbitrary sample region (index 10: needs at least 11 regions).
sample_node = nodes[10]
sample_node.color

In [81]:
class Graph:
    """Region-adjacency graph over Node objects.

    The graph works on deep copies of the nodes it is built from. A node that
    has been merged away keeps its slot in `self.nodes` (area 0, no edges)
    until `clear_unconnected_nodes()` drops it.
    """

    def __init__(self, nodes):
        """Initialize from a sequence of nodes (they are deep-copied)."""
        self.nodes = deepcopy(nodes)
        # node id -> index into self.nodes
        self.node_ids = {n.id: j for j, n in enumerate(self.nodes)}

    def find_node_by_id(self, id):
        """Index of the first node whose id equals `id`, or -1."""
        for i, n in enumerate(self.nodes):
            if n.id == id:
                return i
        return -1

    def find_node(self, node):
        """Index of the first node with the same id as `node`, or -1."""
        return self.find_node_by_id(node.id)

    def merge_nodes(self, node_to_keep, node_to_remove):
        """Absorb `node_to_remove` into `node_to_keep`.

        Area, pixel coordinates and colour samples are added to
        `node_to_keep`. Every neighbour of `node_to_remove` becomes a
        neighbour of `node_to_keep`, and all edges pointing at
        `node_to_remove` are removed, so no live node keeps a stale edge to
        it. `node_to_remove` stays in `self.nodes` with area 0, no
        coordinates, no colour samples and no edges.

        Prints a message and does nothing if either node id is not in the
        graph. Merging a node with itself is a no-op.
        """
        if node_to_keep.id not in self.node_ids:
            print(f"{node_to_keep} isn't in graph")
            return
        if node_to_remove.id not in self.node_ids:
            print(f"{node_to_remove} isn't in graph")
            return
        if node_to_keep.id == node_to_remove.id:
            return  # would otherwise double, then zero, the node

        # Re-wire neighbours. Iterate over a snapshot because edge sets are
        # modified inside the loop.
        for n in list(node_to_remove.edges):
            n.remove_edge(node_to_remove)  # drop the back-edge to the absorbed node
            if n.id != node_to_keep.id:
                n.add_edge(node_to_keep)
                node_to_keep.add_edge(n)
        node_to_keep.remove_edge(node_to_remove)
        node_to_remove.remove_all_edges()

        # combine areas, coordinates and colour samples
        node_to_keep.area = node_to_keep.area + node_to_remove.area
        node_to_keep.coords.extend(node_to_remove.coords)
        node_to_keep.add_colors(node_to_remove.colors)

        # Leave the absorbed node empty (it is dropped by
        # clear_unconnected_nodes). `colors` keeps its container type, e.g.
        # an empty array with the same trailing shape.
        node_to_remove.area = 0
        node_to_remove.coords = []
        node_to_remove.colors = node_to_remove.colors[:0].copy()

    def add_edge(self, node_id1, node_id2):
        """Connect the nodes with ids `node_id1` and `node_id2` (both directions)."""
        idx1 = self.node_ids.get(node_id1, -1)
        idx2 = self.node_ids.get(node_id2, -1)
        if idx1 == -1:
            print(f"Node {node_id1} isn't in graph")
            return
        if idx2 == -1:
            print(f"Node {node_id2} isn't in graph")
            return

        self.nodes[idx1].add_edge(self.nodes[idx2])
        self.nodes[idx2].add_edge(self.nodes[idx1])

    def clear_unconnected_nodes(self):
        """Drop nodes with zero area (merged away) and rebuild the id index."""
        self.nodes = [n for n in self.nodes if n.area > 0]
        self.node_ids = {n.id: j for j, n in enumerate(self.nodes)}


In [82]:
# Discover node edges: every pair of horizontally or vertically adjacent
# pixels whose region ids differ connects those two regions (4-connectivity).
G = Graph(nodes)

horizontal_pairs = np.stack(
    (region_labels[:, :-1].ravel(), region_labels[:, 1:].ravel()), axis=1
)
vertical_pairs = np.stack(
    (region_labels[:-1, :].ravel(), region_labels[1:, :].ravel()), axis=1
)
adjacent_pairs = np.concatenate((horizontal_pairs, vertical_pairs), axis=0)
adjacent_pairs = adjacent_pairs[adjacent_pairs[:, 0] != adjacent_pairs[:, 1]]
# Graph.add_edge connects both directions, so each unordered pair is needed once.
adjacent_pairs = np.unique(np.sort(adjacent_pairs, axis=1), axis=0)
for rid_a, rid_b in adjacent_pairs.tolist():
    G.add_edge(rid_a, rid_b)
                


In [83]:
# sorted(G.nodes[1].edges, key=lambda x: x.area)

In [None]:
# before merging, the graph holds a copy of every region node
assert len(G.nodes) == len(nodes)
len(G.nodes)

In [None]:
# ids of the neighbours of the sample region, sorted for a compact display
sorted(n.id for n in G.nodes[10].edges)

## Verify that the graph captures all pixels in the image
The total node area, the total length of the node coordinate lists, and the image size (height × width) should all be equal.

In [None]:
# total area over all nodes
sum(n.area for n in G.nodes)

In [None]:
# total number of stored pixel coordinates over all nodes
sum(len(n.coords) for n in G.nodes)

In [None]:
# every pixel is covered exactly once by the graph
n_pixels = int(np.prod(img_kmean_labels.shape))
assert sum(n.area for n in G.nodes) == n_pixels
assert sum(len(n.coords) for n in G.nodes) == n_pixels
n_pixels

In [None]:
# Run merging until no region is smaller than minArea.
# A small region is paired with the live neighbour minimising
#     area + COLOR_WEIGHT * ||mean colour difference||,
# i.e. small, similarly coloured neighbours are preferred; the smaller of the
# two regions is then absorbed into the larger one.
COLOR_WEIGHT = 100.0

while any(n.area < minArea for n in G.nodes):
    merged_any = False
    for node in list(G.nodes):
        if not (0 < node.area < minArea):
            continue  # large enough, or already absorbed earlier in this pass
        # Absorbed (area 0) neighbours have no colour samples, so their mean
        # colour is NaN; NaN keys would corrupt the ordering. Ignore them.
        candidates = [n for n in node.edges if n.area > 0]
        if not candidates:
            continue
        node_c = node.color
        best = min(
            candidates,
            key=lambda x: x.area + COLOR_WEIGHT * np.sqrt(np.sum(np.square(x.color - node_c))),
        )
        if best.area >= node.area:
            G.merge_nodes(best, node)
        else:
            G.merge_nodes(node, best)
        merged_any = True

    G.clear_unconnected_nodes()
    if not merged_any:
        # No small region has a live neighbour left: stop instead of looping forever.
        break


In [None]:
# number of regions that survived merging
sum(1 for n in G.nodes if n.area > 0)

In [91]:
# G.nodes

In [None]:
# Merged region graph drawn over the original (pre-merge) region labels.
plt.imshow(region_labels)

for node1 in G.nodes:
    if node1.area == 0:
        continue
    c1 = node1.centroid
    plt.plot(c1[1], c1[0], '.r')
    for node2 in node1.edges:
        # Skip absorbed nodes; edges are symmetric, so draw each one once.
        if node2.area == 0 or node2.id < node1.id:
            continue
        c2 = node2.centroid
        plt.plot((c1[1], c2[1]), (c1[0], c2[0]), '--r', linewidth=1)

plt.title("Region graph")
plt.show()

In [93]:
# relabel: rebuild the region label image from the merged graph
# (-1 marks pixels that no surviving node claims).
region_labels2 = np.full_like(region_labels, -1)
for n in G.nodes:
    if n.area == 0:
        continue
    node_coords = np.asarray(n.coords)  # (k, 2) array of (y, x)
    region_labels2[node_coords[:, 0], node_coords[:, 1]] = n.id

In [None]:
# Merged regions with their centroids (edges are drawn in the region-graph plot).
plt.imshow(region_labels2)
live_nodes = [n for n in G.nodes if n.area > 0]
centroids = np.array([n.centroid for n in live_nodes]).reshape(-1, 2)
plt.plot(centroids[:, 1], centroids[:, 0], '.r')
plt.title("Merged regions and centroids")
plt.show()

In [None]:
# verify no regions left out: every pixel must have a merged region id
unlabeled = region_labels2 == -1
assert not unlabeled.any(), f"{int(unlabeled.sum())} pixels were not assigned a region"
plt.imshow(unlabeled)
plt.title("Unlabeled pixels (should be empty)")
plt.show()

# RGB version of merged regions

In [96]:
# unique_labels = np.unique(region_labels2).tolist()
# img_kmean_rgbs = np.zeros_like(img_filt)
# for lab in unique_labels:
#     m = np.expand_dims(region_labels2 == lab, -1)
#     mean_c = np.sum(img_filt * m, (0, 1)) / np.sum(m, (0,1))
#     img_kmean_rgbs += mean_c * m

In [97]:
# Paint every merged region with its mean colour.
img_kmean_rgbs = np.zeros_like(img_filt)
for n in G.nodes:
    if n.area == 0:
        continue  # absorbed node: no pixels and a NaN mean colour
    node_coords = np.asarray(n.coords)  # (k, 2) array of (y, x)
    img_kmean_rgbs[node_coords[:, 0], node_coords[:, 1]] = n.color

In [None]:
plt.imshow(img_kmean_rgbs)
plt.title("Merged regions (mean colour)")
plt.show()

# Extract contours

In [99]:
from collections import defaultdict

In [100]:
# Collect, for every region id, the set of (row, col) pixels on its boundary,
# then convert each set to a list.
# NOTE(review): unique_labels is not used anywhere in this cell.
unique_labels = np.unique(region_labels2).tolist()
# Working copy of the labels. Pixels already classified as contour points are
# overwritten with -1 below, so later pixels ignore them. This thins the
# boundary and makes the result depend on the row-major scan order.
region_labels3 = deepcopy(region_labels2)
# collect contours pixels per region
contours = defaultdict(set)
for i in range(region_labels2.shape[0]):
    for j in range(region_labels2.shape[1]):

        # 3x3 window around (i, j): the 8 neighbours plus the pixel itself.
        # Indices are clamped at the image border, so out-of-range neighbours
        # repeat an in-range pixel.
        neighbors = [
            region_labels3[max(0, i - 1), max(0, j - 1)],
            region_labels3[max(0, i - 1), j],
            region_labels3[max(0, i - 1), min(j + 1, region_labels2.shape[1] - 1)],

            region_labels3[min(i + 1, region_labels2.shape[0] - 1), max(0, j - 1)],
            region_labels3[min(i + 1, region_labels2.shape[0] - 1), j],
            region_labels3[min(i + 1, region_labels2.shape[0] - 1), min(j + 1, region_labels2.shape[1] - 1)],

            region_labels3[i, max(0, j - 1)],
            region_labels3[i, j],
            region_labels3[i, min(j + 1, region_labels2.shape[1] - 1)]
        ]
        # print(neighbors)
        # distinct region ids in the window, ignoring already-consumed contour pixels (-1)
        neighbors = set(neighbors)
        neighbors.discard(-1)

        # if all neighbors are equal then this is not a contour pixel
        # otherwise this is a contour point.
        # Pixels on the image border are always treated as contour points.
        if (
            (len(neighbors) > 1) 
            or (i == 0) 
            or (i == region_labels2.shape[0] - 1)
            or (j == 0)
            or (j == region_labels2.shape[1] - 1)
        ):
            # The pixel is added to the contour of every region in the window,
            # so a pixel between two regions belongs to both contours.
            for n in neighbors:
                contours[n].add((i,j))

            region_labels3[i, j] = -1 # don't consider this contour point anymore

# sets -> lists (point order is arbitrary here).
# NOTE: the loop variable `id` shadows the builtin id() at module level.
for id, contour in contours.items():
    contours[id] = list(contour)

In [None]:
def _nearest_neighbor_order(points):
    """Order `points` into a path by greedy nearest-neighbour chaining.

    Starts at points[0] and repeatedly appends the closest remaining point
    (squared Euclidean distance on (y, x) tuples). Each step is one linear
    `min` scan rather than a full re-sort of the remaining points, so the
    whole ordering is O(n^2). Ties go to the point that comes first in
    `points`. Returns a new list; `points` is not modified.
    """
    remaining = list(points)
    if not remaining:
        return []
    ordered = [remaining.pop(0)]
    while remaining:
        cy, cx = ordered[-1]
        nearest = min(
            range(len(remaining)),
            key=lambda k: (remaining[k][0] - cy) ** 2 + (remaining[k][1] - cx) ** 2,
        )
        ordered.append(remaining.pop(nearest))
    return ordered


# Sort each region's contour so consecutive points are (approximately) adjacent.
for region_id, contour in contours.items():
    print(region_id, len(contour))
    contours[region_id] = _nearest_neighbor_order(contour)


In [None]:
# Overlay every region contour, each in a random but reproducible colour.
rng = np.random.default_rng(0)
fig = plt.figure(figsize=(19.20, 10.80))
plt.imshow(img_kmean_rgbs)
for region_id, contour in contours.items():
    c = np.array(contour)
    plt.plot(c[:, 1], c[:, 0], '-', color=tuple(rng.random(3).tolist()), linewidth=1)
plt.title("Region contours")
plt.show()