In [1]:
from matplotlib.image import imread
import numpy as np
import matplotlib.pyplot as plt

In [2]:
class DistanceCluster:
    def __init__(self, img) -> None:
        m, n, _ = img.shape
        self.D = np.zeros((m, n))
        self.L = -np.ones((m, n))
        self.D[:, :] = np.inf

class Superpixels:
    def __init__(self, img, n_sp) -> None:
        self.img = img
        n_tp = self.img.shape[0] * self.img.shape[1]  # * img.shape[2]
        self.s = int((n_tp / n_sp) ** 0.5)
        self.DL = DistanceCluster(self.img)
        self.threshold = 10

    # Step 1

    def _calculate_lowest_gradient_position(self, image, x, y):
        x_start = max(x - 1, 0)
        x_end = min(x + 1, image.shape[1] - 1)
        y_start = max(y - 1, 0)
        y_end = min(y + 1, image.shape[0] - 1)

        neighborhood = image[y_start:y_end + 1, x_start:x_end + 1]

        pixel1 = np.expand_dims(image[y, x], axis=0)
        differences = np.sum(np.abs(neighborhood - pixel1), axis=2)
        min_y, min_x = np.unravel_index(np.argmin(differences), differences.shape)
        min_x += x_start
        min_y += y_start

        return min_x, min_y


    def _calculate_gradient(self, image, x1, y1, x2, y2):
        # Get the RGB values of the two pixels
        pixel1 = image[y1, x1]
        pixel2 = image[y2, x2]

        # Calculate the gradient as the sum of absolute differences of RGB values
        gradient = np.sum(np.abs(pixel1 - pixel2))

        return gradient

    def _initialize_algo(self):
        # The function creates the centers list
        m_list = []
        # finding the pixels with s distance
        x = [l % self.img.shape[0] for l in range(self.s, self.img.shape[0], self.s)]
        y = [l % self.img.shape[1] for l in range(self.s, self.img.shape[1], self.s)]
        for i in range(len(x)):
            for j in range(len(y)):
                x1, y1 = self._calculate_lowest_gradient_position(self.img, x[i], y[j])
                r, g, b = self.img[x1, y1, :3]
                m_list.append(np.array([r, g, b, x1, y1]).T)
        return m_list

    # Step 2
    def _calculate_distance(self, v1, v2, c=10):
        dc = ((v1[0] - v2[0]) ** 2 + (v1[1] - v2[1]) ** 2 + (v1[2] - v2[2]) ** 2) ** 0.5
        ds = ((v1[3] - v2[3]) ** 2 + (v1[4] - v2[4]) ** 2) ** 0.5
        return ((dc / c) ** 2 + (ds / self.s) ** 2) ** 0.5
    
    def _set_neigbors_dist(self, mi, cluster):
        x_start = max(0, int(mi[3] - 2 * self.s))
        y_start = max(0, int(mi[4] - 2 * self.s))    
        x_stop = min(int(mi[3] + 2 * self.s), self.img.shape[0])
        y_stop = min(int(mi[4] + 2 * self.s), self.img.shape[1])     
        for i in range(x_start, x_stop):
            for j in range(y_start, y_stop):
                r, g, b = self.img[i, j, :3]
                v = [r,g,b,i,j]
                d = self._calculate_distance(mi, v, self.s)
                if d < self.DL.D[i, j]:
                    self.DL.D[i, j] = d
                    self.DL.L[i, j] = cluster   


    def _assign_clusters(self):
        for i, v in enumerate(self.m_i):
           self._set_neigbors_dist(v, i)
        print("STEP 2 finished")


    # Step 3
    def _find_cluster_coordinates(self, CLUSTER):
        # Find the indices where the matrix equals the CLUSTER value
        indices = np.argwhere(self.DL.L == CLUSTER)
        
        # Extract the x, y coordinates from the indices
        coordinates = [(index[0], index[1]) for index in indices]
        
        return coordinates

    def _update_mi(self):
        test_convergence = []
        for i, v in enumerate(self.m_i): 
            relevent_cords = self._find_cluster_coordinates(i)
            c_i = len(relevent_cords)
            t = np.copy(v)
            t[:] = 0
            for z in relevent_cords:
                t[:3] += self.img[z[0], z[1], :3]
                t[3:] += z
            t /= c_i
            test_convergence.append(np.linalg.norm(v-t))
        print("STEP 3 finished")
        self.test_convergence = np.array(test_convergence) 

    def __call__(self):
        self.m_i = self._initialize_algo()
        # self._assign_clusters()
        # self._update_mi()
        # while np.sum(self.test_convergence) < self.threshold:
        #     self._assign_clusters()
        #     self._update_mi()
        # return self.DL.L
        

In [3]:
img = imread("totem.png")
sup = Superpixels(img, 500)
sup()

In [4]:
sup._assign_clusters()

In [10]:
sup._update_mi()
while np.sum(sup.test_convergence) > sup.threshold:
    sup._assign_clusters()
    sup._update_mi()

STEP 3 finished


KeyboardInterrupt: 

10