In [1]:
clc; clear; close all;

% -----------------------------------------------------------
% 1. Load MATLAB's built-in test image as doubles in [0, 1]
% -----------------------------------------------------------
img = im2double(imread('peppers.png'));
[rows, cols, ch] = size(img);

% Pixel coordinate grids: x holds the column index, y the row index
[x, y] = meshgrid(1:cols, 1:rows);

% -----------------------------------------------------------
% 2. Per-pixel 5-D feature matrix, one row per pixel: [x y R G B]
% -----------------------------------------------------------
% x(:), y(:) and reshape(img, [], 3) all walk the image in the same
% column-major order, so row k of features describes linear pixel k.
features = [x(:), y(:), reshape(img, [], 3)];
N = rows * cols;   % number of pixels (= number of rows of features)

% -----------------------------------------------------------
% 3. Mean Shift Parameters
% -----------------------------------------------------------
hs        = 10;     % spatial bandwidth, in pixels
hr        = 0.1;    % color bandwidth, Euclidean RGB distance (0-1 scale)
max_iters = 15;     % upper bound on the number of mean-shift sweeps
epsilon   = 1e-3;   % convergence threshold on a mode's per-step shift

% Every mode starts at its own pixel's feature vector
modes = features;

fprintf("Running 5D Mean Shift Segmentation...\n");

Running 5D Mean Shift Segmentation...

In [2]:

% -----------------------------------------------------------
% 4. Mean Shift Iterations
% -----------------------------------------------------------
% Each mode is repeatedly moved to the mean of the ORIGINAL feature
% vectors inside its joint window: spatial distance < hs AND color
% distance < hr (flat kernel). A mode whose shift falls below epsilon is
% marked converged and is not updated again; the sweep loop stops early
% once every mode has converged.
%
% Only pixels inside the axis-aligned square of half-width hs around a
% mode can satisfy the spatial test, so neighbors are searched in that
% square only. This yields exactly the same neighbor set as scanning all
% N pixels, at O(hs^2) cost per mode instead of O(N).
converged = false(N, 1);

for iter = 1:max_iters
    fprintf("Iteration %d/%d\n", iter, max_iters);

    % Row vector so the for-loop visits one index at a time
    for i = find(~converged).'

        % Current mode (x, y, R, G, B)
        cur_point = modes(i, :);

        % Candidate window, clipped to the image. It is slightly larger
        % than needed; the strict-inequality tests below discard extras.
        x_lo = max(1,    floor(cur_point(1) - hs));
        x_hi = min(cols, ceil(cur_point(1) + hs));
        y_lo = max(1,    floor(cur_point(2) - hs));
        y_hi = min(rows, ceil(cur_point(2) + hs));

        % Pixel (r, c) is row r + (c-1)*rows of features, because
        % features was built column-major from x(:) and y(:).
        [rr, cc] = ndgrid(y_lo:y_hi, x_lo:x_hi);
        cand = features(rr(:) + (cc(:) - 1) * rows, :);

        % Spatial and color distances to the candidates
        spatial_dist = sqrt((cand(:,1) - cur_point(1)).^2 + ...
                            (cand(:,2) - cur_point(2)).^2);
        color_dist = sqrt(sum((cand(:,3:5) - cur_point(3:5)).^2, 2));

        % Neighbors within spatial AND color bandwidths
        mask = (spatial_dist < hs) & (color_dist < hr);
        neighbors = cand(mask, :);

        % Avoid empty neighborhood
        if ~isempty(neighbors)
            new_point = mean(neighbors, 1);
        else
            new_point = cur_point;
        end

        % Shift magnitude, then update the mode
        shift = norm(new_point - cur_point);
        modes(i, :) = new_point;

        % Freeze modes that have stopped moving
        if shift < epsilon
            converged(i) = true;
        end
    end

    if all(converged)
        fprintf("All modes converged after %d iterations\n", iter);
        break;
    end
end

Iteration 1/15

In [3]:

% -----------------------------------------------------------
% 5. Cluster pixels whose modes converge close together
% -----------------------------------------------------------
% Greedy grouping: each pixel's mode joins the first existing cluster
% whose representative mode is closer than merge_hs spatially AND closer
% than merge_hr in color; otherwise it starts a new cluster.
% Spatial (pixel units) and color (RGB on a 0-1 scale) distances are
% thresholded separately. A single 5-D Euclidean threshold would let
% modes with very different colors merge whenever they are spatially
% close. This follows the joint-domain mode-merging rule of
% Comaniciu & Meer (2002).
fprintf("Clustering modes...\n");

merge_hs = hs;   % spatial merge threshold (pixels)
merge_hr = hr;   % color merge threshold (RGB distance)

cluster_labels = zeros(N, 1);
cluster_count = 0;
modes_final = zeros(N, 5);   % preallocated upper bound; trimmed below

for i = 1:N
    cur_mode = modes(i, :);
    reps = modes_final(1:cluster_count, :);

    % Distances from this mode to every cluster representative at once
    d_spatial = sqrt(sum((reps(:,1:2) - cur_mode(1:2)).^2, 2));
    d_color   = sqrt(sum((reps(:,3:5) - cur_mode(3:5)).^2, 2));

    j = find((d_spatial < merge_hs) & (d_color < merge_hr), 1, 'first');

    if isempty(j)
        % No representative is close enough: start a new cluster
        cluster_count = cluster_count + 1;
        modes_final(cluster_count, :) = cur_mode;
        cluster_labels(i) = cluster_count;
    else
        cluster_labels(i) = j;
    end
end

% Keep only the representatives that were actually created
modes_final = modes_final(1:cluster_count, :);

fprintf("Total clusters found: %d\n", cluster_count);

% -----------------------------------------------------------
% 6. Construct segmented image from cluster modes
% -----------------------------------------------------------
% Paint every pixel with the RGB part (columns 3:5) of its cluster's
% representative mode. Clustering assigns every pixel a label in
% 1..cluster_count, so the labels index modes_final directly. This is a
% single O(N) gather instead of one N-element comparison per cluster.
seg_img = modes_final(cluster_labels, 3:5); % RGB only

% Back to image layout (rows x cols x 3), same column-major pixel order
seg_img = reshape(seg_img, rows, cols, 3);

% -----------------------------------------------------------
% 7. Display Original and Segmented Images
% -----------------------------------------------------------
% New figure with two side-by-side panels: input left, result right.
figure;

subplot(1, 2, 1); imshow(img);     title('Original Image');

subplot(1, 2, 2); imshow(seg_img); title('Mean Shift Segmented Image (Color + Spatial)');