In [None]:
import argparse
import os
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np

# used for scikit comparison
from skimage import io
from skimage.segmentation import slic, mark_boundaries

In [None]:
def parse_args(args=None):
    """Parse the notebook's command-line style options.

    Parameters
    ----------
    args : list of str or None
        Argument tokens to parse. ``None`` means "use only the defaults"
        (an empty token list), so the notebook does not accidentally parse
        the Jupyter kernel's own ``sys.argv``.

    Returns
    -------
    argparse.Namespace
        With attributes ``clusters``, ``epochs``, ``img_folder`` and ``img_name``.
    """
    parser = argparse.ArgumentParser(description="SLIC")

    # (flag, type, default) for every supported option.
    # USED KAGGLE INPUT DATA
    # image options: chumscastle.jpg, brandeis.jfif, admissionscenter.jpg
    # in order of fastest to compute -> slowest
    option_table = (
        ('--clusters', int, 10),
        ('--epochs', int, 20),
        ('--img_folder', str, "/images/"),
        ('--img_name', str, "chumscastle.jpg"),
    )
    for flag, converter, fallback in option_table:
        parser.add_argument(flag, type=converter, default=fallback)

    tokens = [] if args is None else args
    return parser.parse_args(tokens)

In [None]:
# Notebook parameters: parse an empty token list so only the defaults from
# parse_args() apply (the kernel's own sys.argv is never consulted).
args = parse_args([])

# Implementation of Simple Linear Iterative Clustering 

## Initialization
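1. Initialize cluster centers $C_k = [l_k, a_k, b_k, x_k, y_k]^T$ by sampling pixels at regular grid steps $S$.
2. Perturb cluster centers in an $n \times n$ neighborhood, to the lowest gradient position.

> **Note:** this notebook simplifies both steps. `initialize_centroids` samples $k$ random pixels and uses the image's raw color channels (not CIELAB) as color features. No gradient-based perturbation is applied.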

In [None]:
# initialize centroids randomly within region of datapoints
def initialize_centroids(img, k):
    """Pick ``k`` random pixels of ``img`` as the starting centroids.

    Each centroid is a 5-D feature ``[c0, c1, c2, x, y]``. Here ``c0..c2`` are the
    first three channel values of the sampled pixel, ``x`` is its column and ``y``
    is its row.

    Uses the global ``random`` module: for every centroid the row is drawn first,
    then the column, so seeding ``random`` makes the result reproducible.

    Parameters
    ----------
    img : array indexed as ``img[row, col]`` with at least 3 channels
    k : int
        Number of centroids to draw (duplicates are possible).

    Returns
    -------
    numpy.ndarray of shape (k, 5)
    """
    height, width = img.shape[0], img.shape[1]

    def draw_one():
        # Row before column: this order is what a seeded run depends on.
        row = random.randint(0, height - 1)
        col = random.randint(0, width - 1)
        pixel = img[row, col]
        return [pixel[0], pixel[1], pixel[2], col, row]

    return np.array([draw_one() for _ in range(k)])

## Update Centroids
3. **REPEAT**: 
  1. **for** each cluster center $C_k$ do:
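  2. assign best matching pixels from the 2S x 2S square neighborhood around the cluster center according to the distance measure (eq. 1). (This notebook compares every pixel with every center instead of restricting the search to that window.)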
  3. **end for**:
4. compute new cluster centers and the residual error $E$ (the L1 distance between the previous and the recomputed centers; this notebook's `converged` helper sums the per-center L2 distances instead).
5. **until** E <= threshold.

In [None]:
# pt1, pt2: each pt has format of [c0, c1, c2, x, y] (three color channels + column/row)
# S: sqrt(N/K), where N = # pixels in img, K = desired number of superpixels
# m: compactness weight balancing color vs. spatial proximity (Achanta et al.)
def get_distance(pt1, pt2, S, m=100):
    """SLIC distance between two 5-D points.

    ``D = d_color + (m / S) * d_spatial``, where both terms are Euclidean
    distances (over the three color channels and over ``(x, y)`` respectively).

    Parameters
    ----------
    pt1, pt2 : sequence of 5 numbers
        ``[c0, c1, c2, x, y]``. Values may be numpy scalars (e.g. uint8 pixels)
        or floats (e.g. averaged centroids).
    S : float
        Grid interval used to normalize the spatial term.
    m : float, default 100
        Compactness; larger values favor spatially compact clusters.

    Returns
    -------
    float
    """
    # Convert to float (not int): it avoids uint8 wrap-around on subtraction
    # and keeps the fractional part of averaged centroid colors, which int()
    # used to truncate.
    color_sq = sum((float(pt1[c]) - float(pt2[c])) ** 2 for c in range(3))
    spatial_sq = (float(pt1[3]) - float(pt2[3])) ** 2 + (float(pt1[4]) - float(pt2[4])) ** 2
    return color_sq ** 0.5 + m / S * spatial_sq ** 0.5

In [None]:
# get labels for data points
# S: sqrt(N/K), where N = # pixels in img, K = desired number of superpixels
def get_labels(img, centroids, S):
    """Assign every pixel to its nearest centroid under ``get_distance``.

    Every pixel is compared with every centroid (no 2S x 2S search window), so the
    cost is ``H * W * k`` calls to ``get_distance``. Ties go to the lowest
    centroid index.

    Parameters
    ----------
    img : array indexed as ``img[row, col]`` with at least 3 channels
    centroids : sequence of 5-D points ``[c0, c1, c2, x, y]``
    S : float
        Spatial normalization passed through to ``get_distance``.

    Returns
    -------
    numpy.ndarray of int, shape (H, W)
        Index into ``centroids`` for every pixel.

    Raises
    ------
    ValueError
        If ``centroids`` is empty (no pixel could be labelled).
    """
    if len(centroids) == 0:
        raise ValueError("get_labels needs at least one centroid")
    height, width = img.shape[0], img.shape[1]
    labels = np.zeros((height, width), dtype=int)
    for y in range(height):
        for x in range(width):
            pixel = img[y, x]
            # The pixel's 5-D feature does not depend on the centroid, so build it
            # once per pixel instead of once per (pixel, centroid) pair.
            img_pt = [pixel[0], pixel[1], pixel[2], x, y]
            min_dist = float('inf')
            label = None
            for idx, center in enumerate(centroids):
                dist = get_distance(img_pt, center, S)
                if dist < min_dist:  # strict '<' keeps the first minimum on ties
                    min_dist = dist
                    label = idx
            labels[y, x] = label
    return labels

In [None]:
# Function definition to update centroids based on assigned labels and image data
def update_centroids(img, labels, centroids):
    """Move each centroid to the mean feature of the pixels assigned to it.

    Parameters
    ----------
    img : array of shape (H, W, C) with C >= 3
        Only the first three channels are used as color features.
    labels : int array of shape (H, W)
        Centroid index of every pixel, as produced by ``get_labels``.
    centroids : array-like of shape (k, 5)
        Current centroids ``[c0, c1, c2, x, y]`` (x = column, y = row).

    Returns
    -------
    numpy.ndarray of float, shape (k, 5)
        Updated centroids. A centroid that received no pixels keeps its previous
        value instead of collapsing to all zeros, which would mean black at
        pixel (0, 0). The input is not modified.

    Raises
    ------
    ValueError
        If ``labels`` does not match the image's height and width.
    """
    previous = np.asarray(centroids, dtype=float)
    k = len(previous)
    height, width = img.shape[0], img.shape[1]
    if np.shape(labels) != (height, width):
        raise ValueError(
            f"labels shape {np.shape(labels)} does not match image shape {(height, width)}"
        )
    flat_labels = np.asarray(labels, dtype=int).ravel()

    rows, cols = np.indices((height, width))
    # Per-pixel features in centroid order: three color channels, then x (column), y (row).
    feature_planes = [img[:, :, c] for c in range(3)] + [cols, rows]
    # bincount with weights sums each feature per label in one vectorized pass.
    sums = np.stack(
        [np.bincount(flat_labels, weights=np.asarray(plane, dtype=float).ravel(), minlength=k)
         for plane in feature_planes],
        axis=1,
    )
    counts = np.bincount(flat_labels, minlength=k)

    new_centroids = previous.copy()
    assigned = counts > 0
    new_centroids[assigned] = sums[assigned] / counts[assigned, np.newaxis]
    return new_centroids

In [None]:
# define stopping criteria to exit loop
def converged(old_centroids, new_centroids, threshold=1e-5):
    """Report whether the centroids have effectively stopped moving.

    The residual is the sum, over all centroids, of the Euclidean (L2) length of
    each centroid's displacement.

    Parameters
    ----------
    old_centroids, new_centroids : numpy arrays of shape (k, 5)
    threshold : float, default 1e-5
        Convergence is declared when the residual is strictly below this value.

    Returns
    -------
    numpy.bool_
    """
    displacement = old_centroids - new_centroids
    shift_per_centroid = np.linalg.norm(displacement, axis=1)
    residual = shift_per_centroid.sum()
    return residual < threshold

## SLIC Algorithm
Achanta et al.'s SLIC superpixel algorithm:
1. Initialize cluster centers $C_k = [l_k, a_k, b_k, x_k, y_k]^T$ by sampling pixels at regular grid steps S.
2. Perturb cluster centers in an n x n neighborhood, to lowest gradient position.
3. **REPEAT**: 
  1. **for** each cluster center $C_k$ do:
  2. assign best matching pixels from 2S x 2S square neighborhood around the cluster center according to distance measure (eq. 1).
  3. **end for**:
4. compute new cluster centers and the residual error $E$ (the L1 distance between the previous and the recomputed centers).
5. **until** E <= threshold.
6. Enforce connectivity.

In [None]:
def SLIC(image, k, epochs=None, tol=None):
    """Segment ``image`` into at most ``k`` superpixels with a simplified SLIC.

    Alternates ``get_labels`` (assignment) and ``update_centroids`` (update),
    starting from ``initialize_centroids``.

    Parameters
    ----------
    image : array of shape (H, W, C)
        The image to segment (in this notebook, the BGR array from ``cv2.imread``).
    k : int
        Desired number of superpixels; must be >= 1.
    epochs : int or None
        Maximum number of assignment/update iterations. ``None`` falls back to
        the notebook-level ``args.epochs``; it must be >= 1.
    tol : float or None
        If given, stop early once ``converged(old, new, tol)`` reports that the
        centroids stopped moving. ``None`` (default) always runs all epochs.
        For large images this rarely triggers before the epoch limit.

    Returns
    -------
    numpy.ndarray of int, shape (H, W)
        Cluster index per pixel, from the last assignment step.

    Raises
    ------
    ValueError
        If ``k`` or ``epochs`` is smaller than 1.
    """
    if epochs is None:
        epochs = args.epochs  # notebook-level default from parse_args()
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    height, width = image.shape[0], image.shape[1]
    # Grid interval S = sqrt(N / K) with N = number of *pixels*. Using
    # image.size counted the color channels too, inflating S by sqrt(channels)
    # and weakening the spatial term of the distance.
    S = np.sqrt(height * width / k)

    centroids = initialize_centroids(image, k)
    for _ in range(epochs):
        old_centers = centroids.copy()
        # Assignment step
        labels = get_labels(image, centroids, S)
        # Update step
        centroids = update_centroids(image, labels, centroids)
        if tol is not None and converged(old_centers, centroids, tol):
            break

    return labels

# Using SLIC on Images

## Visualizing Segments
This helper overlays each cluster on the photo in a random color, as an alternative to drawing only the cluster borders.

In [None]:
def visualize_segments(image, labels):
    """Blend a random color per cluster over ``image`` at 50% opacity.

    Parameters
    ----------
    image : array of shape (H, W, 3)
        Image to draw on (presumably uint8 as returned by ``cv2.imread`` /
        ``cv2.cvtColor`` — TODO confirm for other inputs; ``cv2.addWeighted``
        needs both operands to have the same size and type).
    labels : non-negative int array of shape (H, W)
        Cluster index per pixel.

    Returns
    -------
    numpy.ndarray
        The overlaid image, same shape and dtype as ``image``.
    """
    # randint's upper bound is exclusive, so 256 lets every channel reach 255.
    colors = np.random.randint(0, 256, (np.max(labels) + 1, 3), dtype=np.uint8)

    # Fancy indexing paints every pixel with its cluster color in one step;
    # casting keeps the dtype identical to `image`, as addWeighted requires.
    segmented_image = colors[labels].astype(image.dtype)

    # Overlay segmented image on the original image
    overlaid_image = cv2.addWeighted(image, 0.5, segmented_image, 0.5, 0)

    return overlaid_image

## Running SLIC and visualizing the result

In [None]:
# Full path of the input image. os.path.join works whether or not
# --img_folder ends with a separator (plain concatenation broke without one).
img = os.path.join(args.img_folder, args.img_name)
print(img)

In [None]:
# Load image. OpenCV returns BGR channel order, and it returns None
# (instead of raising) when the file is missing or unreadable.
image_bgr = cv2.imread(img)
if image_bgr is None:
    raise FileNotFoundError(f"cv2.imread could not read {img!r}; check --img_folder/--img_name")
# Apply SLIC algorithm to the raw BGR pixels
labels = SLIC(image=image_bgr, k=args.clusters)
# Convert to RGB so matplotlib shows the true colors
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
# Visualize segments: boundaries by default; for a colored overlay use instead
# result = visualize_segments(image_rgb, labels)
result = mark_boundaries(image=image_rgb, label_img=labels.astype(int))

Display the original and clustered images side by side.

In [None]:
# Two panels in one row: the untouched photo, then our SLIC boundaries.
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

panels = [
    (image_rgb, 'Original Image'),
    (result, 'SLIC Result'),
]
for ax, (picture, title) in zip(axes, panels):
    ax.imshow(picture)
    ax.set_title(title)
    ax.axis('off')

plt.show()

# Comparison with Scikit Image
Generate boundaries with scikit-image's `slic` function, using the same number of clusters.

In [None]:
# Reference segmentation: scikit-image's SLIC with the same cluster count,
# converting to CIELAB internally and without connectivity enforcement.
scikit_source = io.imread(img)
segments = slic(
    image=scikit_source,
    n_segments=args.clusters,
    convert2lab=True,
    enforce_connectivity=False,
)
scikit_img = mark_boundaries(image=scikit_source, label_img=segments)

## Displaying the Results Side by Side

In [None]:
# Two panels in one row: scikit-image's SLIC boundaries, then ours.
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

comparison = [
    (scikit_img, 'Scikit Image SLIC'),  # reference implementation
    (result, 'My SLIC Result'),         # this notebook's implementation
]
for ax, (picture, title) in zip(axes, comparison):
    ax.imshow(picture)
    ax.set_title(title)
    ax.axis('off')

plt.show()