# Notebook for K-Means Clustering

### By Austin Houston

For AReMS, Spring 2025

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AustinHouston/AReMS_2025/blob/main/1_KMeans.ipynb)

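The standard K-means algorithm is Lloyd's algorithm, from S. P. Lloyd, "Least squares quantization in PCM", *IEEE Transactions on Information Theory*. The paper was published in 1982,

but it had already been circulated internally at Bell Labs in 1957: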

https://doi.org/10.1109/TIT.1982.1056489

### Load in the necessary libraries

In [None]:
!pip install ipympl

In [None]:
import os
import sys
import cv2
import numpy as np

# plotting
import matplotlib.pyplot as plt
import matplotlib.colors as colors
# Importing Axes3D registers the '3d' projection used by
# fig.add_subplot(..., projection='3d') later in the notebook
# (needed on older Matplotlib releases; harmless on newer ones).
from mpl_toolkits.mplot3d import Axes3D
# On Google Colab, switch to the interactive ipympl (widget) backend and
# enable Colab's custom widget manager so the widget figures can render.
# Outside Colab the default backend is left unchanged.
if 'google.colab' in sys.modules:
    %matplotlib widget
    from google.colab import output
    output.enable_custom_widget_manager()

# the star of the show:
from sklearn.cluster import KMeans

# clone the github repository
# It contains the example images read from ./AReMS_2025/images_for_KMeans/.
# The existence check skips the clone when the folder is already present,
# so re-running this cell does not fail on an existing checkout.
repo_url = 'https://github.com/AustinHouston/AReMS_2025.git'
repo_name = 'AReMS_2025'
if not os.path.exists(repo_name):
    !git clone {repo_url}


### Load in data

In [None]:
path = './AReMS_2025/images_for_KMeans/'

# os.listdir returns entries in arbitrary, platform-dependent order, and the
# folder may contain non-image files (e.g. .DS_Store, README). Keep only image
# files and sort them so that files[0] below is the same image on every run.
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
files = sorted(f for f in os.listdir(path) if f.lower().endswith(image_extensions))
print(files)

You can put any .jpg file in the `images_for_KMeans` folder listed above and rerun the same example we are about to do together.

### Visualize an image (change the index in `files[0]` to pick your favorite one!)

In [None]:
selected_file = files[0]

print(selected_file)

# Load the image. OpenCV reads colour channels in BGR order, so convert to RGB
# for matplotlib and for the Red/Green/Blue analysis below.
image_path = os.path.join(path, selected_file)
image = cv2.imread(image_path)
if image is None:
    # cv2.imread does not raise on a missing or unreadable file; it returns None,
    # which would otherwise only surface as an obscure error inside cvtColor.
    raise FileNotFoundError(f'Could not read image file: {image_path}')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Plot the image
fig, ax = plt.subplots(1, 1)
ax.imshow(image)
ax.set_title(selected_file)
ax.axis('off')
plt.show()

Each pixel has 3 values (red, green, blue), each an integer from 0 to 255. Let's look at each channel separately:

In [None]:
# Show each colour channel of the image in its own panel. Each panel uses a
# matching colormap and the same fixed 0-255 range, so intensities are
# directly comparable across the three panels.
channel_specs = [
    ('Reds', 'Red Channel'),
    ('Greens', 'Green Channel'),
    ('Blues', 'Blue Channel'),
]

fig, axs = plt.subplots(1, 3, figsize=(15, 5))
for channel_idx, (panel, (cmap_name, panel_title)) in enumerate(zip(axs, channel_specs)):
    panel.imshow(image[:, :, channel_idx], cmap=cmap_name, vmin=0, vmax=255)
    panel.set_title(panel_title)
    panel.axis('off')
fig.tight_layout()


### Scatter the RGB values in 3D space

In [None]:
# Plotting every pixel in 3D is slow, so scatter only a random 0.1 % sample.
sparsity_factor = 0.001
n_pixels = image.shape[0] * image.shape[1]
# Keep at least one point so very small images still produce a non-empty plot.
num_points = max(1, int(n_pixels * sparsity_factor))
# Seeded generator: the same pixels are sampled on every Restart & Run All.
# `indices` is reused by the cluster plot further down.
rng = np.random.default_rng(42)
indices = rng.choice(n_pixels, num_points, replace=False)

# One row per pixel, one column per channel: shape (n_pixels, 3).
pixels = image.reshape(-1, 3)

# Create a 3D scatter plot
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Colour each point by its own RGB value (matplotlib expects floats in [0, 1]).
ax.scatter(pixels[indices, 0], pixels[indices, 1], pixels[indices, 2], c=pixels[indices] / 255.0, alpha=0.8)
ax.set_xlabel('Red')
ax.set_ylabel('Green')
ax.set_zlabel('Blue')
ax.set_title('RGB values of sampled pixels')

### Now, let's apply KMeans clustering

In [None]:
# KMeans expects a 2-D feature matrix of shape (n_samples, n_features):
# here one row per pixel and one column per colour channel (R, G, B).
# Deriving it from `image` directly (instead of re-reshaping the `pixels`
# left over from an earlier cell) keeps this cell self-contained and idempotent.
pixels = image.reshape(-1, 3)
assert pixels.shape == (image.shape[0] * image.shape[1], 3)
print(pixels.shape)

In [None]:
n_clusters = 3

# Fit K-means on the (n_pixels, 3) RGB matrix. random_state fixes the
# centroid initialisation so the clustering is reproducible.
kmeans = KMeans(n_clusters=n_clusters, random_state=42)
kmeans.fit(pixels)

# Each row of cluster_centers_ is one centre's (R, G, B) value on the 0-255 scale.
centers = kmeans.cluster_centers_
r, g, b = centers.T
# The same centres rescaled to [0, 1], as (r, g, b) tuples usable as matplotlib colours.
color_vals = [tuple(rgb) for rgb in centers / 255]

# Index of the nearest cluster centre for every pixel, in pixel order.
labels = kmeans.labels_

In [None]:
# 3D scatter of the sampled pixels with the fitted cluster centres on top.
fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')

# Scatter plot of the sparse data points, each coloured by its own RGB value.
ax.scatter(pixels[indices, 0], pixels[indices, 1], pixels[indices, 2],
           c=pixels[indices] / 255.0, alpha=0.8, label='sampled pixels')

# Scatter plot of cluster centers. The axes are (Red, Green, Blue), so the
# coordinates must be passed in r, g, b order (previously g and b were swapped).
ax.scatter(r, g, b, color=color_vals, s=300, edgecolor='black', label='cluster centers')

# Set labels
ax.set_xlabel('Red')
ax.set_ylabel('Green')
ax.set_zlabel('Blue')
ax.set_title(f'3D scatter plot of clusters for {n_clusters} clusters (Sparse Plot)')
# Both scatters carry a label, so the legend has entries to show.
ax.legend()


### Histogram visualization

In [None]:
# Shift bins by -0.5 so that bars are centered at integer ticks
bins = np.arange(n_clusters + 1) - 0.5

plt.figure()
n, bins, patches = plt.hist(labels, bins = bins, edgecolor = 'white')
for bin in range(0, n_clusters):
    patches[bin].set_facecolor(color_vals[bin])

ticklabels = [i for i in range(n_clusters)]
plt.xticks(ticklabels)
plt.xlabel('Cluster Center')
plt.ylabel('Counts')
plt.title(f'Pixels per Cluster: {n_clusters} clusters')

### Reconstruct the image using cluster centers

In [None]:
# Replace every pixel with the colour of its assigned cluster centre.
# Indexing the (n_clusters, 3) centre array with the per-pixel labels does this
# in a single vectorised step (instead of one full-image mask per cluster), and
# reshape restores the (height, width, 3) image layout.
reconstructed = kmeans.cluster_centers_[labels].reshape(image.shape)

# Centres are float means. Round them, rather than truncating as a direct
# assignment into the 8-bit image array would, and clip to the valid 0-255
# range before casting back to the image dtype.
reduced_image = np.clip(np.rint(reconstructed), 0, 255).astype(image.dtype)

# Plot the reduced image
plt.figure(figsize=(10, 10))
plt.imshow(reduced_image)
plt.title(f'Image reconstructed with {n_clusters} colours')
plt.axis('off')