In [1]:
from skimage.morphology import thin
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import torch


cmap = mpl.colors.ListedColormap(torch.rand(256**2, 3).numpy())

def plot_segmentation_boundaries(image_np, output_mask, figsize=(15, 15)):
    """
    Plot the boundaries of the segmentation over the original image.
    Uses pixel comparison to detect boundaries and applies thinning.
    """
    output_mask_np = output_mask[0].cpu().numpy()

    # Find boundaries by comparing neighboring pixels
    boundaries = np.zeros_like(output_mask_np)
    boundaries[1:, :] = np.logical_or(boundaries[1:, :], output_mask_np[1:, :] != output_mask_np[:-1, :])  # Compare vertically
    boundaries[:, 1:] = np.logical_or(boundaries[:, 1:], output_mask_np[:, 1:] != output_mask_np[:, :-1])  # Compare horizontally

    # Apply thinning to ensure boundaries are only 1 pixel wide
    boundaries = thin(boundaries)

    fig, ax = plt.subplots(1, 1, figsize=figsize)

    ax.imshow(image_np)

    ax.contour(boundaries, colors='red', linewidths=0.7)

    ax.set_title('Image with Segmentation Boundaries (Thinned)')
    ax.axis('off')

    plt.show()
    
    
def plot_mask(mask):
    plt.figure(figsize=(10, 10))
    plt.imshow(mask[0].cpu().numpy(), cmap=cmap)
    plt.colorbar()
    plt.title('Superpixel Mask')
    plt.show()
    
def plot_gradient_map(grad_map):
    plt.figure(figsize=(10, 10))
    plt.imshow(grad_map[0, 0].cpu().numpy(), cmap='gray')
    plt.colorbar()
    plt.title('Gradient Map')
    plt.show()



In [2]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
import scipy.io as sio
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

class BSDS500Dataset(Dataset):
    def __init__(self, root_dir, split='train', transform=None):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.images_dir = os.path.join(root_dir, 'images', split)
        self.ground_truth_dir = os.path.join(root_dir, 'ground_truth', split)
        self.image_files = [f for f in os.listdir(self.images_dir) if f.endswith('.jpg')]
    
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        img_name = os.path.join(self.images_dir, self.image_files[idx])
        image = Image.open(img_name).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        
        gt_name = os.path.join(self.ground_truth_dir, self.image_files[idx].replace('.jpg', '.mat'))
        gt_data = sio.loadmat(gt_name)
        ground_truth = gt_data['groundTruth'][0][0][0][0][1]

        
        #print(ground_truth)
        # print(ground_truth[0, 0])
        # print(ground_truth[0, 0]['Segmentation'])
        segmentation = ground_truth
        
        if isinstance(segmentation, np.ndarray) and segmentation.shape == (1, 1):
            segmentation = segmentation[0, 0]
        
        segmentation = Image.fromarray(segmentation)
        segmentation = segmentation.resize((224, 224), Image.NEAREST)
        
        segmentation = np.array(segmentation, dtype=np.int64)

        segmentation = torch.tensor(segmentation, dtype=torch.long)
        
        return image, segmentation

transform = transforms.Compose([
    transforms.Resize((224, 224)), 
    transforms.ToTensor(), 
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize with ImageNet mean and std
])

dataset_train = BSDS500Dataset(root_dir=r'D:\Data\BSDS500\data', split='train', transform=transform)

dataloader = DataLoader(dataset_train, batch_size=10, shuffle=True, num_workers=0)

import random
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import segmentation_models_pytorch as smp
import math


class VoronoiPropagation(nn.Module):
    def __init__(self, num_clusters=64, n_channels=3, height=224, width=224, device='cpu'):
        """
        Args:
            num_clusters (int): Number of clusters (centroids) to initialize.
            height (int): Height of the input image.
            width (int): Width of the input image.
            device (str): Device to run the model ('cpu' or 'cuda').
        """
        super(VoronoiPropagation, self).__init__()
        
        self.C = num_clusters
        self.H = height
        self.W = width
        self.device = torch.device(device)
        
        self.unet = smp.Unet(encoder_name="efficientnet-b0",
                             encoder_weights=None,  
                             in_channels=n_channels,               
                             classes=n_channels)   
        
        # Set bandwidth / sigma for kernel
        self.std = self.C / (self.H * self.W)**0.5
        
        self.convert_to_greyscale = torchvision.transforms.Grayscale(num_output_channels=1)

    def compute_gradient_map(self, x):
        # Sobel kernels for single-channel input
        sobel_x = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], device=x.device, dtype=x.dtype)
        sobel_y = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]], device=x.device, dtype=x.dtype)
        
        # Apply Sobel filters
        grad_x = F.conv2d(x, sobel_x, padding=1)
        grad_y = F.conv2d(x, sobel_y, padding=1)
        
        # Compute gradient magnitude
        grad_map = torch.sqrt(grad_x.pow(2) + grad_y.pow(2))
        return grad_map

    def place_centroids_on_grid(self, batch_size):
        num_cols = int(math.sqrt(self.C * self.W / self.H))
        num_rows = int(math.ceil(self.C / num_cols))

        grid_spacing_y = self.H / num_rows
        grid_spacing_x = self.W / num_cols

        centroids = []
        for i in range(num_rows):
            for j in range(num_cols):
                if len(centroids) >= self.C:
                    break
                y = int((i + 0.5) * grid_spacing_y)
                x = int((j + 0.5) * grid_spacing_x)
                centroids.append([y, x])
            if len(centroids) >= self.C:
                break

        centroids = torch.tensor(centroids, device=self.device).float()
        return centroids.unsqueeze(0).repeat(batch_size, 1, 1)

    def find_nearest_minima(self, centroids, grad_map, neighborhood_size=10):
        updated_centroids = []
        B, _, _ = centroids.shape
        
        for batch_idx in range(B):
            updated_centroids_batch = []
            occupied_positions = set()
            for centroid in centroids[batch_idx]:
                y, x = centroid
                y_min = max(0, int(y) - neighborhood_size)
                y_max = min(self.H, int(y) + neighborhood_size)
                x_min = max(0, int(x) - neighborhood_size)
                x_max = min(self.W, int(x) + neighborhood_size)
                
                neighborhood = grad_map[batch_idx, 0, y_min:y_max, x_min:x_max]
                min_val = torch.min(neighborhood)
                min_coords = torch.nonzero(neighborhood == min_val, as_tuple=False)
                
                # Iterate over all minima to find an unoccupied one
                found = False
                for coord in min_coords:
                    new_y = y_min + coord[0].item()
                    new_x = x_min + coord[1].item()
                    position = (new_y, new_x)
                    if position not in occupied_positions:
                        occupied_positions.add(position)
                        updated_centroids_batch.append([new_y, new_x])
                        found = True
                        break
                if not found:
                    # If all minima are occupied, keep the original position
                    updated_centroids_batch.append([y.item(), x.item()])
            
            updated_centroids.append(torch.tensor(updated_centroids_batch, device=self.device))
        
        return torch.stack(updated_centroids, dim=0)

    def distance_weighted_propagation(self, centroids, grad_map, color_map, num_iters=50, gradient_weight=10.0, color_weight=10.0, edge_exponent=4.0): # gradient weight, color weight and edge exponent are all tuneable parameters 
        """
        Perform Voronoi-like propagation from centroids, guided by both the gradient map and color similarity.
        
        Args:
            centroids (Tensor): Initial centroid positions.
            grad_map (Tensor): Gradient magnitude map.
            color_map (Tensor): Input image for color similarity.
            num_iters (int): Number of iterations to perform propagation.
            gradient_weight (float): Weight for the gradient penalty.
            color_weight (float): Weight for the color similarity penalty.
            edge_exponent (float): Exponent to amplify edge gradients.
        
        Returns:
            Tensor: Final segmentation mask.
        """
        B, _, H, W = grad_map.shape
        mask = torch.full((B, H, W), fill_value=-1, device=grad_map.device)  # Label mask
        dist_map = torch.full((B, H, W), fill_value=float('inf'), device=grad_map.device)  # Distance map
        
        for batch_idx in range(B):
            for idx, (cy, cx) in enumerate(centroids[batch_idx]):
                mask[batch_idx, int(cy), int(cx)] = idx
                dist_map[batch_idx, int(cy), int(cx)] = 0  # Distance from centroid is 0 initially
        
        # 4-connected neighbors (dy, dx)
        directions = [(-1, 0), (1, 0), (0, -1), (0, 1)]
        
        # Amplify the impact of the gradient map by multiplying it with a weight and applying a non-linear transformation
        weighted_grad_map = (grad_map ** edge_exponent) * gradient_weight

        # Perform propagation with both gradient penalties and color similarity
        for _ in range(num_iters):
            for dy, dx in directions:
                # Shift the distance map in each direction
                shifted_dist = torch.roll(dist_map, shifts=(dy, dx), dims=(1, 2))
                shifted_mask = torch.roll(mask, shifts=(dy, dx), dims=(1, 2))
                
                # Calculate color distance between current pixel and centroid it is being propagated from
                color_diff = torch.abs(color_map - torch.roll(color_map, shifts=(dy, dx), dims=(2, 3))).sum(dim=1)  # Sum over color channels

                # Add the gradient map value as a weighted penalty to the distance
                weighted_dist = shifted_dist + weighted_grad_map[:, 0, :, :] + color_diff * color_weight
                
                # Update the mask and distance map where the new combined distance is smaller
                update_mask = weighted_dist < dist_map
                dist_map[update_mask] = weighted_dist[update_mask]
                mask[update_mask] = shifted_mask[update_mask]
        
        return mask
        
    def forward(self, x):
        B, C_in, H, W = x.shape
        
        if C_in == 3:
            grayscale_image = self.convert_to_greyscale(x)
        else:
            grayscale_image = x
        
        # Compute the gradient map from grayscale image
        grad_map = self.compute_gradient_map(grayscale_image)
        
        # Place centroids on a grid
        centroids = self.place_centroids_on_grid(B)
        
        # Move centroids to nearest local minima
        centroids = self.find_nearest_minima(centroids, grad_map)
        
        # Use the color map (the original image) to guide propagation
        spixel_features = self.unet(x)
        
        # Perform distance-weighted propagation with both gradient and color guidance
        mask = self.distance_weighted_propagation(centroids, grad_map, spixel_features)
        
        # return grad_map, centroids, mask, spixel_features
        return grad_map, centroids, mask, spixel_features

In [4]:
class BoundaryPathFinder2(nn.Module):
    def __init__(self, num_segments_row=8, num_segments_col=8, height=224, width=224, device='cpu'):
        super(BoundaryPathFinder2, self).__init__()
        
        self.num_segments_row = num_segments_row
        self.num_segments_col = num_segments_col
        self.H = height
        self.W = width
        self.device = device
        
        self.convert_to_grayscale = torchvision.transforms.Grayscale(num_output_channels=1)
        
        # Sobel kernels
        self.sobel_x = torch.tensor([[[[-1, 0, 1], 
                                  [-2, 0, 2], 
                                  [-1, 0, 1]]]], device=device, dtype=torch.float32)
        self.sobel_y = torch.tensor([[[[-1, -2, -1], 
                                  [0, 0, 0], 
                                  [1, 2, 1]]]], device=device, dtype=torch.float32)
    
    def compute_gradient_map(self, x):
        # x: (B, C, H, W)
        if x.shape[1] == 3:
            x = self.convert_to_grayscale(x)
        
        # Apply Sobel filters
        grad_x = F.conv2d(x, self.sobel_x, padding=1)
        grad_y = F.conv2d(x, self.sobel_y, padding=1)

        # Compute gradient magnitude
        grad_map = torch.sqrt(grad_x ** 2 + grad_y ** 2 + 1e-8)
        return grad_map  # Shape: (B, 1, H, W)
    
    def initialize_grid(self, batch_size):
        # Create grid labels
        rows = torch.arange(self.H, device=self.device).unsqueeze(1)
        cols = torch.arange(self.W, device=self.device).unsqueeze(0)

        row_labels = rows // (self.H // self.num_segments_row)
        col_labels = cols // (self.W // self.num_segments_col)

        labels = (row_labels * self.num_segments_col + col_labels).to(torch.int32)
        labels = labels.expand(batch_size, -1, -1)  # Shape: (B, H, W)

        return labels
    
    def adjust_boundaries(self, grad_map, segmentation_mask, band_width=5):
        """
        Adjust boundary lines to align with the highest gradients while keeping the number of segments constant.
        """
        B, H, W = segmentation_mask.shape
        device = grad_map.device

        # Prepare indices
        y_indices = torch.arange(H, device=device)
        x_indices = torch.arange(W, device=device)

        # Initialize boundary masks
        boundary_masks = torch.zeros((B, H, W), dtype=torch.bool, device=device)

        # Process each image in the batch
        for b in range(B):
            grad_map_b = grad_map[b, 0]  # Shape: (H, W)

            # Vertical boundaries
            vertical_paths = []
            for i in range(1, self.num_segments_col):
                x_init = i * (W // self.num_segments_col)
                x_init = min(x_init, W - 1)
                path = self.find_optimal_vertical_path(grad_map_b, x_init, band_width)
                vertical_paths.append(path)

            # Mark vertical boundaries
            for path in vertical_paths:
                boundary_masks[b, y_indices, path] = True

            # Horizontal boundaries
            horizontal_paths = []
            for i in range(1, self.num_segments_row):
                y_init = i * (H // self.num_segments_row)
                y_init = min(y_init, H - 1)
                path = self.find_optimal_horizontal_path(grad_map_b, y_init, band_width)
                horizontal_paths.append(path)

            # Mark horizontal boundaries
            for path in horizontal_paths:
                boundary_masks[b, path, x_indices] = True

        # Use connected components labeling and reassign labels based on majority voting
        from skimage.measure import label as skimage_label

        new_segmentation_masks = []
        for b in range(B):
            boundary_mask_np = boundary_masks[b].cpu().numpy()
            regions = ~boundary_mask_np
            labeled_array = skimage_label(regions, connectivity=1)
            initial_labels = segmentation_mask[b].cpu().numpy()
            new_segmentation_mask = np.zeros_like(labeled_array, dtype=initial_labels.dtype)
            for region_label in np.unique(labeled_array):
                if region_label == 0:
                    continue  # Skip background if any
                region_mask = (labeled_array == region_label)
                labels_in_region = initial_labels[region_mask]
                if labels_in_region.size == 0:
                    continue
                # Majority voting to find the most common label
                majority_label = np.bincount(labels_in_region).argmax()
                new_segmentation_mask[region_mask] = majority_label
            new_segmentation_masks.append(torch.from_numpy(new_segmentation_mask).to(device))

        new_segmentation_masks = torch.stack(new_segmentation_masks, dim=0).to(torch.int32)

        return new_segmentation_masks  # Shape: (B, H, W)
    
    def find_optimal_vertical_path(self, grad_map, x_init, band_width):
        """
        Find the optimal vertical path around the initial x position using dynamic programming.
        """
        H, W = grad_map.shape
        device = grad_map.device

        # Define band around x_init
        x_indices = x_init + torch.arange(-band_width, band_width + 1, device=device)
        x_indices = x_indices.clamp(0, W - 1).long()
        num_positions = x_indices.size(0)

        # Initialize cost and path matrices
        cost = torch.full((H, num_positions), float('inf'), device=device)
        path = torch.zeros((H, num_positions), dtype=torch.long, device=device)

        # First row
        cost[0] = -grad_map[0, x_indices]

        # Precompute position indices
        positions = torch.arange(num_positions, device=device)

        # Dynamic programming
        for y in range(1, H):
            # Pad the previous cost for easy indexing
            padded_prev_cost = torch.cat([
                torch.full((1,), float('inf'), device=device),
                cost[y - 1],
                torch.full((1,), float('inf'), device=device)
            ])

            # Indices for possible moves: left (-1), stay (0), right (+1)
            move_offsets = torch.tensor([-1, 0, 1], device=device)
            neighbor_indices = positions.unsqueeze(1) + move_offsets  # Shape: (num_positions, 3)
            neighbor_indices = neighbor_indices.clamp(0, num_positions - 1)

            # Gather costs for possible moves
            prev_costs = padded_prev_cost[neighbor_indices + 1]  # Adjust for padding
            min_prev_costs, min_indices = prev_costs.min(dim=1)

            # Update cost and path
            cost[y] = min_prev_costs - grad_map[y, x_indices]
            path[y] = neighbor_indices[positions, min_indices]

        # Backtracking to find the optimal path
        idx = cost[-1].argmin().item()
        optimal_path = []
        for y in reversed(range(H)):
            optimal_path.append(x_indices[idx])
            idx = path[y, idx].item()
        optimal_path.reverse()
        optimal_path = torch.stack(optimal_path)
        return optimal_path  # Shape: (H,)
        
    def find_optimal_horizontal_path(self, grad_map, y_init, band_width):
        """
        Find the optimal horizontal path around the initial y position using dynamic programming.
        """
        H, W = grad_map.shape
        device = grad_map.device

        # Define band around y_init
        y_indices = y_init + torch.arange(-band_width, band_width + 1, device=device)
        y_indices = y_indices.clamp(0, H - 1).long()
        num_positions = y_indices.size(0)

        # Initialize cost and path matrices
        cost = torch.full((W, num_positions), float('inf'), device=device)
        path = torch.zeros((W, num_positions), dtype=torch.long, device=device)

        # First column
        cost[0] = -grad_map[y_indices, 0]

        # Precompute position indices
        positions = torch.arange(num_positions, device=device)
        
        # Dynamic programming
        for x in range(1, W):
            # Pad the previous cost for easy indexing
            padded_prev_cost = torch.cat([
                torch.full((1,), float('inf'), device=device),
                cost[x - 1],
                torch.full((1,), float('inf'), device=device)
            ])

            # Indices for possible moves: up (-1), stay (0), down (+1)
            move_offsets = torch.tensor([-1, 0, 1], device=device)
            neighbor_indices = positions.unsqueeze(1) + move_offsets  # Shape: (num_positions, 3)
            neighbor_indices = neighbor_indices.clamp(0, num_positions - 1)

            # Gather costs for possible moves
            prev_costs = padded_prev_cost[neighbor_indices + 1]  # Adjust for padding
            min_prev_costs, min_indices = prev_costs.min(dim=1)

            # Update cost and path
            cost[x] = min_prev_costs - grad_map[y_indices, x]
            path[x] = neighbor_indices[positions, min_indices]

        # Backtracking to find the optimal path
        idx = cost[-1].argmin().item()
        optimal_path = []
        for x in reversed(range(W)):
            optimal_path.append(y_indices[idx])
            idx = path[x, idx].item()
        optimal_path.reverse()
        optimal_path = torch.stack(optimal_path)
        return optimal_path  # Shape: (W,)
    
    def forward(self, x):
        B, C, H, W = x.shape
        if H != self.H or W != self.W:
            raise ValueError(f"Input image size must match initialized size: ({self.H}, {self.W})")

        # Compute gradient map
        grad_map = self.compute_gradient_map(x)  # Shape: (B, 1, H, W)

        # Initialize grid segmentation
        segmentation_mask = self.initialize_grid(B)  # Shape: (B, H, W)

        # Adjust boundaries
        new_segmentation_mask = self.adjust_boundaries(grad_map, segmentation_mask)

        return grad_map, segmentation_mask, new_segmentation_mask

In [5]:
voronoi_model = VoronoiPropagation(num_clusters=256)
pathFinder_model = BoundaryPathFinder2(num_segments_col=16,num_segments_row=16)

In [6]:
def explained_variance_batch(image_batch, superpixel_labels_batch):
    batch_size, num_channels, height, width = image_batch.shape
    explained_variance_scores = []

    for i in range(batch_size):
        image = image_batch[i]  # Shape: (C, H, W)
        superpixel_labels = superpixel_labels_batch[i]  # Shape: (H, W)

        # Ensure superpixel_labels is in shape (H, W)
        superpixel_labels = superpixel_labels.squeeze().to(image.device)

        # Flatten image and labels for computation
        image_flat = image.view(num_channels, height * width)
        labels_flat = superpixel_labels.view(height * width)

        # Compute total variance of the image across all channels
        total_variance = image_flat.var(dim=1, unbiased=False).mean().item()

        # Proceed to compute within-superpixel variance (homogeneity_score)
        unique_labels = superpixel_labels.unique()
        num_superpixels = unique_labels.size(0)

        pixel_sums = torch.zeros((num_superpixels, num_channels), device=image.device)
        pixel_squares = torch.zeros((num_superpixels, num_channels), device=image.device)
        pixel_counts = torch.zeros(num_superpixels, device=image.device)

        for j, label in enumerate(unique_labels):
            mask = (labels_flat == label)
            pixel_sums[j] = image_flat[:, mask].sum(dim=1)
            pixel_squares[j] = (image_flat[:, mask] ** 2).sum(dim=1)
            pixel_counts[j] = mask.sum()

        pixel_means = pixel_sums / pixel_counts.unsqueeze(1)
        pixel_variances = (pixel_squares / pixel_counts.unsqueeze(1)) - (pixel_means ** 2)
        within_variance = pixel_variances.mean().item()

        # Compute explained variance
        explained_variance = 1 - (within_variance / total_variance)
        explained_variance_scores.append(explained_variance)

    return explained_variance_scores

In [7]:
import time

explained_variance_scores_voronoi = []
explained_variance_scores_pathfinder = []

for (image, labels) in dataloader:
    start = time.time()
    voronor_grad_map, voronor_centroids, voronor_mask, voronor_spixel_features = voronoi_model(image)
    end = time.time()
    print("Voronoi took", end-start)
    start = time.time()
    pathfinder_grad_map, pathfinder_segmentation_mask, pathfinder_new_segmentation_mask = pathFinder_model(image)
    end = time.time()
    print("Pathfinder took", end-start)
    
    explained_variance_scores_voronoi.append(explained_variance_batch(image, voronor_mask))
    explained_variance_scores_pathfinder.append(explained_variance_batch(image, pathfinder_new_segmentation_mask))
    

print(np.mean(np.array(explained_variance_scores_voronoi)))
print(np.mean(np.array(explained_variance_scores_pathfinder)))

Voronoi took 3.409597873687744
Pathfinder took 13.611455917358398
Voronoi took 3.3326168060302734
Pathfinder took 13.238955020904541
Voronoi took 3.3081629276275635


KeyboardInterrupt: 