In [1]:
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
from matplotlib.colors import rgb2hex

In [2]:

def compute_centroids(X, y):
    """
    Compute the centroid (per-feature mean) of each class in the dataset.

    Parameters:
    - X: feature matrix; rows are samples (array-like, e.g. list of lists or ndarray)
    - y: target values, one label per row of X (a column vector is accepted)

    Returns:
    - Dictionary mapping each label in ``np.unique(y)`` (sorted order) to the
      mean of that class's rows of X, taken along axis 0

    Raises:
    - ValueError: if X and y do not contain the same number of samples
    """
    # Convert to arrays so boolean masking works for plain Python lists too:
    # `list == scalar` evaluates to a single bool, which would silently index
    # row 0 or 1 instead of selecting the class's samples.
    X = np.asarray(X)
    # Flatten so a (n_samples, 1) column vector yields a 1-D row mask.
    y = np.asarray(y).ravel()
    if len(X) != len(y):
        raise ValueError(
            f"X and y must have the same number of samples, got {len(X)} and {len(y)}"
        )

    return {cls: X[y == cls].mean(axis=0) for cls in np.unique(y)}

def project_to_rgb(centroids):
    """
    Project class centroids into RGB color space.

    Centroids with more than 3 features are reduced with PCA. Each resulting
    dimension is min-max scaled to [0, 1] and used as one color channel.
    Channels that cannot be filled are set to 0. This happens when there are
    3 or fewer features, or when there are too few classes for PCA to produce
    3 informative directions.

    Parameters:
    - centroids: Dictionary mapping class labels to centroid vectors; every
      centroid must contain the same number of elements (each is flattened)

    Returns:
    - Dictionary mapping class labels to length-3 float arrays with values in
      [0, 1]; an empty dict when ``centroids`` is empty
    """
    labels = list(centroids.keys())
    if not labels:
        return {}

    # One row per class; reshape turns scalar centroids into a single column
    # and flattens multi-dimensional ones.
    centroid_matrix = np.array(
        [centroids[label] for label in labels], dtype=float
    ).reshape(len(labels), -1)
    n_classes, n_features = centroid_matrix.shape

    if n_features > 3:
        # k centered centroids span at most k - 1 directions. Requesting more
        # components either raises (k < 3) or yields pure rounding noise (k == 3)
        # that the scaler would stretch into a full color channel.
        n_components = min(3, n_classes - 1)
        if n_components > 0:
            # The matrix is tiny (one row per class), so an exact, deterministic
            # SVD is cheap and keeps colors reproducible across runs.
            pca = PCA(n_components=n_components, svd_solver="full")
            reduced_centroids = pca.fit_transform(centroid_matrix)
        else:
            # A single class has no spread to project.
            reduced_centroids = np.empty((n_classes, 0))
    else:
        reduced_centroids = centroid_matrix

    if reduced_centroids.shape[1] > 0:
        scaler = MinMaxScaler(feature_range=(0, 1))
        rgb_values = scaler.fit_transform(reduced_centroids)
    else:
        rgb_values = reduced_centroids

    # Float rounding in the scaler can land a hair outside [0, 1], which
    # color converters such as matplotlib's rgb2hex reject.
    rgb_values = np.clip(rgb_values, 0.0, 1.0)

    # Pad missing channels with 0 so every color is a complete RGB triple.
    missing_channels = 3 - rgb_values.shape[1]
    if missing_channels > 0:
        rgb_values = np.hstack([rgb_values, np.zeros((n_classes, missing_channels))])

    return {label: rgb_values[i] for i, label in enumerate(labels)}

In [3]:
def get_class_colors(X, y):
    """
    Compute class centroids and derive one RGB color per class from them.

    The colors are based on centroid positions: centroids are computed with
    ``compute_centroids`` and mapped to color space with ``project_to_rgb``.

    Parameters:
    - X: feature matrix; rows are samples
    - y: target values, one class label per row of X

    Returns:
    - Tuple ``(colors, centroids)``:
      - colors: dictionary mapping class labels to RGB colors (as returned by
        ``project_to_rgb``)
      - centroids: dictionary mapping class labels to centroid vectors (as
        returned by ``compute_centroids``)
    """
    # Compute centroids
    centroids = compute_centroids(X, y)

    # Project centroids to RGB space
    colors = project_to_rgb(centroids)

    return colors, centroids

In [4]:
from sklearn.datasets import load_iris

# Load dataset
data = load_iris()
X = data.data
y = data.target
target_names = data.target_names

# Get colors
colors, centroids = get_class_colors(X, y)

# Print RGB values for each class.
# `.tolist()` converts NumPy scalars to Python floats so the tuple prints as
# "(1.0, 0.205, 0.0)"; on NumPy >= 2 the scalar repr would otherwise show
# "(np.float64(1.0), ...)".
print("\nAssigned RGB Colors:")
for label, color in colors.items():
    rgb = tuple(np.round(color, 3).tolist())
    print(f"{target_names[label]}: RGB{rgb} (Hex: {rgb2hex(color)})")


Assigned RGB Colors:
setosa: RGB(1.0, 0.205, 0.0) (Hex: #ff3400)
versicolor: RGB(0.329, 1.0, 0.0) (Hex: #54ff00)
virginica: RGB(0.0, 0.0, 0.0) (Hex: #000000)


In [5]:
from tensorflow.keras.datasets import mnist

# Load MNIST dataset (only the training split is used here)
(X_train, y_train), (_, _) = mnist.load_data()

# Flatten images so each row is one sample's feature vector
X_train = X_train.reshape(X_train.shape[0], -1)

colors, centroids = get_class_colors(X_train, y_train)

# Print RGB values for each class.
# `.tolist()` converts NumPy scalars to Python floats so the tuple prints as
# "(1.0, 0.253, 0.208)"; on NumPy >= 2 the scalar repr would otherwise show
# "(np.float64(1.0), ...)".
print("\nAssigned RGB Colors:")
for label, color in colors.items():
    rgb = tuple(np.round(color, 3).tolist())
    print(f"{label}: RGB{rgb} (Hex: {rgb2hex(color)})")


Assigned RGB Colors:
0: RGB(1.0, 0.253, 0.208) (Hex: #ff4135)
1: RGB(0.0, 1.0, 0.307) (Hex: #00ff4e)
2: RGB(0.437, 0.718, 0.725) (Hex: #70b7b9)
3: RGB(0.451, 0.766, 0.0) (Hex: #73c300)
4: RGB(0.19, 0.0, 0.567) (Hex: #300091)
5: RGB(0.47, 0.477, 0.127) (Hex: #787a20)
6: RGB(0.458, 0.429, 1.0) (Hex: #756dff)
7: RGB(0.115, 0.017, 0.087) (Hex: #1d0416)
8: RGB(0.353, 0.608, 0.257) (Hex: #5a9b42)
9: RGB(0.146, 0.023, 0.298) (Hex: #25064c)
