In [1]:
import os
import numpy as np
from PIL import Image
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt

In [6]:
# Show what the generation data directory contains (bare last expression
# so the notebook renders the resulting list).
generation_root = "../GenerationData"
os.listdir(generation_root)

['Backgrounds', 'Clustered', 'Objects', 'prompts.txt']

In [7]:
# Paths
objects_folder = "../GenerationData/Objects"

# Constants
# Reference area (in pixels) that object coverage is measured against.
# NOTE(review): this is a fixed 512x512 area, not each image's own size --
# confirm every object image is 512x512 or that a fixed reference is intended.
IMAGE_SIZE = 512 * 512
THRESHOLD = 0.05  # 5%

# Delete every object image whose visible (alpha > 0) pixel count exceeds
# THRESHOLD of the reference area.
# WARNING: this cell is destructive -- matching files are permanently removed
# from objects_folder, so re-running it never restores them.
# sorted() keeps the processing/log order independent of the filesystem.
for filename in sorted(os.listdir(objects_folder)):
    if not filename.lower().endswith((".png", ".jpg", ".jpeg", ".webp")):
        continue

    path = os.path.join(objects_folder, filename)

    try:
        # The context manager guarantees the file handle is released before
        # the file is (possibly) deleted below.
        with Image.open(path) as img:
            alpha = np.asarray(img.convert("RGBA").split()[-1])  # alpha channel

        # Count non-transparent pixels in one vectorised pass
        non_transparent_pixels = int(np.count_nonzero(alpha))

        ratio = non_transparent_pixels / IMAGE_SIZE

        if ratio > THRESHOLD:
            print(f"Removing {filename} ({ratio*100:.2f}% of area)")
            os.remove(path)

    except Exception as e:
        # Best effort: report the problem and carry on with the next file.
        print(f"Error processing {filename}: {e}")

Removing 100_1.png (7.61% of area)
Removing 100_5.png (6.15% of area)
Removing 100_8.png (16.44% of area)
Removing 102_51.png (5.81% of area)
Removing 102_52.png (6.16% of area)
Removing 102_53.png (5.70% of area)
Removing 102_56.png (6.12% of area)
Removing 102_61.png (6.84% of area)
Removing 102_82.png (13.93% of area)
Removing 102_87.png (7.34% of area)
Removing 102_89.png (5.54% of area)
Removing 102_91.png (6.87% of area)
Removing 105_8.png (10.51% of area)
Removing 109_29.png (5.11% of area)
Removing 109_35.png (5.26% of area)
Removing 109_39.png (5.84% of area)
Removing 109_41.png (9.90% of area)
Removing 109_58.png (5.81% of area)
Removing 109_70.png (5.67% of area)
Removing 109_86.png (5.96% of area)
Removing 111_0.png (11.78% of area)
Removing 111_1.png (9.83% of area)
Removing 111_12.png (16.18% of area)
Removing 111_2.png (8.45% of area)
Removing 111_36.png (6.39% of area)
Removing 111_43.png (5.83% of area)
Removing 111_7.png (5.68% of area)
Removing 116_11.png (5.16% of a

The following cell groups the object images into clusters by color similarity. <hr>

In [8]:
import os
import numpy as np
from PIL import Image
from sklearn.cluster import KMeans
import shutil  # for copying files safely

# Clustering settings: number of groups to form and the size every image
# is resized to before its colour histogram is computed.
num_clusters = 3
image_size = (128, 128)

# Input images are read from objects_folder; the per-cluster copies are
# written underneath output_folder, which is created if it is missing.
objects_folder = "../GenerationData/Objects"
output_folder = "../GenerationData/Clustered"
os.makedirs(output_folder, exist_ok=True)

def extract_color_features(img_path, size=None, bins=8):
    """Return a normalised RGB colour histogram for one image file.

    The image is converted to RGB, resized, and its pixels are binned into a
    3-D histogram with ``bins`` bins per channel over the range [0, 256).
    The histogram is flattened and scaled so its entries sum to 1.

    Parameters
    ----------
    img_path : str or path-like
        Path of the image to read.
    size : tuple of int, optional
        Size passed to ``Image.resize``. Defaults to the module-level
        ``image_size``.
    bins : int, optional
        Number of histogram bins per colour channel (default 8, giving a
        512-element feature vector).

    Returns
    -------
    numpy.ndarray or None
        1-D array of length ``bins ** 3``, or ``None`` if the image could not
        be processed (the error is printed rather than raised).
    """
    if size is None:
        size = image_size
    try:
        # Context manager releases the file handle even if decoding fails.
        # NOTE(review): convert("RGB") discards the alpha channel, so fully
        # transparent pixels still contribute their stored RGB values to the
        # histogram -- confirm this is intended for cut-out object images.
        with Image.open(img_path) as img:
            pixels = np.asarray(img.convert("RGB").resize(size))

        hist, _ = np.histogramdd(
            pixels.reshape(-1, 3),
            bins=(bins, bins, bins),
            range=((0, 256), (0, 256), (0, 256)),
        )
        hist = hist.ravel()
        return hist / hist.sum()
    except Exception as e:
        # Best effort: the caller skips images for which None is returned.
        print(f"Error reading {img_path}: {e}")
        return None

features = []
filenames = []

# sorted() fixes the row order of the feature matrix. Without it the order
# depends on the filesystem, so the seeded KMeans initialisation (and hence
# the cluster numbering) could differ between machines or runs.
for filename in sorted(os.listdir(objects_folder)):
    if not filename.lower().endswith((".png", ".jpg", ".jpeg", ".webp")):
        continue

    path = os.path.join(objects_folder, filename)
    feat = extract_color_features(path)
    if feat is not None:  # unreadable images are reported and skipped
        features.append(feat)
        filenames.append(filename)

if not features:
    raise ValueError("No valid image features extracted. Check your image folder or formats.")

features = np.vstack(features)  # ensure 2D array for KMeans

# Never ask for more clusters than there are images.
num_clusters = min(num_clusters, len(features))
# n_init is pinned explicitly because its default differs between
# scikit-learn releases; with random_state this keeps results reproducible.
kmeans = KMeans(n_clusters=num_clusters, n_init=10, random_state=42)
labels = kmeans.fit_predict(features)

# Create output cluster folders
# NOTE: existing folders are reused as-is; copies left over from earlier
# runs are not removed by this cell.
for i in range(num_clusters):
    cluster_dir = os.path.join(output_folder, f"cluster_{i}")
    os.makedirs(cluster_dir, exist_ok=True)

# Copy files instead of moving
for fname, label in zip(filenames, labels):
    src = os.path.join(objects_folder, fname)
    dst = os.path.join(output_folder, f"cluster_{label}", fname)
    shutil.copy2(src, dst)  # copy file, keep original

print("Clustering complete!")
print("Images grouped into:", [f"cluster_{i}" for i in range(num_clusters)])

Clustering complete!
Images grouped into: ['cluster_0', 'cluster_1', 'cluster_2']
