# In [None]:
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
#                                                                                                   #
#        Finds the most important RGB values in the image and recolors using those found            #
#                                                                                                   #
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from skimage import io, img_as_float
import os
from glob import glob # For finding image files
from PIL import Image # For image resizing (alternative to skimage for some)

# --- Configuration ---
# Root folder of the dataset; it is searched recursively for image files below.
DATASET_PATH = 'path/to/your/image_dataset' # <--- IMPORTANT: Change this to your dataset path!
# Passed to the image loader as the resize target, so every image contributes
# the same number of pixels to the clustering input.
IMAGE_SIZE = (128, 128) # Resize images to a common size (width, height)
# Passed to KMeans as n_clusters; each cluster centre is one colour of the
# palette used to recolor the example image.
N_CLUSTERS = 5 # Number of clusters for K-Means

# --- Helper function to load and preprocess a single image ---
def load_and_preprocess_image(image_path, target_size):
    """Load an image file and return it as a resized float RGB array.

    Args:
        image_path: Path of the image file, read with ``skimage.io.imread``.
        target_size: ``(width, height)`` tuple handed to ``PIL.Image.resize``.

    Returns:
        A float array of shape ``(height, width, 3)`` with values in [0, 1],
        or ``None`` when the file cannot be read or converted. Failures are
        printed instead of raised so one bad file does not abort a batch.
    """
    try:
        img = img_as_float(io.imread(image_path)) # Convert to float for K-Means

        # Normalise the channel layout to exactly three colour channels.
        if img.ndim == 2: # Grayscale image
            img = np.stack([img, img, img], axis=-1)
        elif img.ndim == 3 and img.shape[-1] in (1, 2):
            # Single channel, or two channels (treated as gray + alpha):
            # replicate the first channel and drop the rest.
            img = np.repeat(img[..., :1], 3, axis=-1)
        elif img.ndim == 3 and img.shape[-1] == 4: # RGBA image
            img = img[..., :3] # Discard alpha channel

        if img.ndim != 3 or img.shape[-1] != 3:
            raise ValueError(f"unsupported image shape {img.shape}")

        # PIL needs uint8 data. Clip first because img_as_float maps signed
        # integer inputs to [-1, 1] and leaves float inputs unscaled, and
        # out-of-range values would wrap around in the uint8 cast. Round
        # instead of truncating so float round-off (e.g. 126.9999...) does
        # not darken a pixel by one level.
        img_u8 = np.rint(np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)

        # Resize with PIL, then convert back to float [0, 1]
        img_pil = Image.fromarray(img_u8).resize(target_size)
        return np.array(img_pil) / 255.0
    except Exception as e:
        # Deliberate best-effort: report the problem and let the caller skip the file.
        print(f"Error loading/processing image {image_path}: {e}")
        return None

# --- Main Logic for loading dataset and applying K-Means ---

# One (height * width, channels) pixel matrix per successfully loaded image
all_pixel_data = []
# Shape of each preprocessed (already resized) image, for reshaping back
original_image_shapes = []
# Paths of the images that actually loaded, index-aligned with all_pixel_data.
# image_files cannot be used for this because files that fail to load are skipped.
loaded_image_files = []

print(f"Loading images from: {DATASET_PATH}")

# SCENARIO 1 & 2: Images in folders (either flat or by class)
# Walk the dataset recursively and keep files whose extension matches,
# ignoring case so that e.g. 'photo.JPG' is found on case-sensitive filesystems.
# Sorting makes the processing order (and thus the "first image") reproducible,
# since glob returns entries in filesystem order.
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'} # Add other extensions if needed
image_files = sorted(
    path
    for path in glob(os.path.join(DATASET_PATH, '**', '*'), recursive=True)
    if os.path.isfile(path)
    and os.path.splitext(path)[1].lower() in image_extensions
)

if not image_files:
    print(f"No images found in {DATASET_PATH}. Please check the path and file types.")
else:
    print(f"Found {len(image_files)} images.")
    for i, img_path in enumerate(image_files):
        print(f"Processing image {i+1}/{len(image_files)}: {os.path.basename(img_path)}")
        processed_img = load_and_preprocess_image(img_path, IMAGE_SIZE)
        if processed_img is None:
            continue # The loader already reported the failure
        height, width, channels = processed_img.shape
        original_image_shapes.append((height, width, channels))
        loaded_image_files.append(img_path)
        # Reshape image to (pixels, channels) and add to list
        all_pixel_data.append(processed_img.reshape((height * width, channels)))

# Combine all pixel data into a single NumPy array
if all_pixel_data:
    X = np.vstack(all_pixel_data)
    print(f"\nTotal pixels for K-Means: {X.shape[0]}, with {X.shape[1]} features (channels).")

    # 3. Apply K-Means Clustering
    print(f"Applying K-Means with {N_CLUSTERS} clusters...")
    kmeans = KMeans(n_clusters=N_CLUSTERS, random_state=42, n_init='auto')
    kmeans.fit(X)

    # Get the cluster labels for each pixel
    labels = kmeans.labels_
    # Get the centroid (average color) for each cluster
    centers = kmeans.cluster_centers_

    # 4. Reshape Back and Visualize one example
    print("\nVisualizing one example segmented image:")
    # We will pick the first loaded image to demonstrate the segmentation
    # You could iterate through all images if you want to save them
    first_image_pixels = all_pixel_data[0]
    # X stacks the images in load order, so the first block of labels belongs to the first image
    first_image_labels = labels[:first_image_pixels.shape[0]]
    first_image_original_shape = original_image_shapes[0]

    # Replace every pixel by the centroid colour of its cluster
    segmented_first_image = centers[first_image_labels].reshape(first_image_original_shape)

    # Plotting the groups with labels
    plt.figure(figsize=(12, 6))

    plt.subplot(1, 2, 1)
    # Reload the file that produced all_pixel_data[0] for true comparison
    # (before resizing/channel adjustments). This must come from
    # loaded_image_files: image_files[0] may be a file that failed to load.
    original_first_image_display = io.imread(loaded_image_files[0])
    plt.imshow(original_first_image_display)
    plt.title('Original First Image')
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(segmented_first_image)
    plt.title(f'K-Means Segmented First Image (K={N_CLUSTERS})')
    plt.axis('off')

    plt.tight_layout()
    plt.show()

    # Centroids are in the same float [0, 1] scale as the clustered pixels
    print("\nCluster Centroid Colors (RGB values):")
    for i, center in enumerate(centers):
        print(f"Cluster {i}: R={center[0]:.2f}, G={center[1]:.2f}, B={center[2]:.2f}")

else:
    print("No valid images were loaded for K-Means clustering.")

# --- SCENARIO 3: CSV/JSON with image paths (more advanced, often for labeled datasets) ---
# If you have a CSV file, you'd typically read it with pandas:
# import pandas as pd
# df = pd.read_csv('path/to/your/image_info.csv')
#
# Then iterate through rows, load images using df['image_path']
# and apply the same preprocessing and K-Means logic.
# If your CSV has 'label' columns, you can use those for evaluation after K-Means,
# but not for the K-Means training itself as it's unsupervised.