In [1]:

from Pylette import extract_colors
from matplotlib import pyplot as plt
import numpy as np
from pyemd import emd
import os
from collections import Counter
from colorsys import rgb_to_hsv
import cv2


In [23]:

def edge_density(img):
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray, 100, 200)
    return np.sum(edges > 0) / edges.size

def shannon_entropy(img):
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    hist = cv2.calcHist([gray], [0], None, [256], [0, 256])
    hist_norm = hist.ravel() / hist.sum()
    hist_norm = hist_norm[hist_norm > 0]
    entropy = -np.sum(hist_norm * np.log2(hist_norm))
    return entropy

def brightness(img):
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    return np.mean(hsv[:,:,2])

def extract_palette_rgb(image_path, num_colors=10, mode="MedianCut"):
    # Read image
    img = cv2.imread(image_path)

    # Image-level features
    img_brightness = brightness(img)
    
    # Extract colors
    colors = extract_colors(image=img, palette_size=num_colors, mode=mode)
    
    # Convert RGB to HSV
    rgb_array = np.array([color.rgb for color in colors], dtype=np.float64)
    
    
    # Augment with brightness and edge density
    augmented_array = np.hstack([
        rgb_array,  # shape Nx3
        np.full((len(rgb_array), 1), img_brightness),  # shape Nx1
    ])  # final shape Nx5

    # Frequencies
    freq_array = np.array([color.freq for color in colors], dtype=np.float64)
    freq_array /= freq_array.sum()

    # Padding if fewer colors found
    n = len(augmented_array)
    if n < num_colors:
        pad_len = num_colors - n
        augmented_array = np.vstack([augmented_array, np.zeros((pad_len, 5))])
        freq_array = np.hstack([freq_array, np.zeros(pad_len)])

    return augmented_array, freq_array


def visualize_palette(rgb_array, freq_array):
    # Normalize RGB to [0, 1]
    rgb_array = rgb_array / 255.0

    # Create an image strip using frequency as width proportions
    total_width = 800
    height = 100
    img = np.zeros((height, total_width, 3))

    start = 0
    for color, freq in zip(rgb_array, freq_array):
        width = int(freq * total_width)
        end = start + width
        img[:, start:end, :] = color
        start = end

    plt.figure(figsize=(8, 2))
    plt.imshow(img)
    plt.axis("off")
    plt.show()



In [24]:
print(extract_palette_rgb('train_data/dark_forest/2036_2035_6102.png'))


(array([[ 11.        ,  39.        ,  21.        , 115.25852045],
       [255.        , 183.        , 145.        , 115.25852045],
       [ 22.        ,  75.        ,  42.        , 115.25852045],
       [253.        , 190.        , 155.        , 115.25852045],
       [ 27.        ,  92.        ,  51.        , 115.25852045],
       [ 30.        , 103.        ,  55.        , 115.25852045],
       [ 33.        , 118.        ,  62.        , 115.25852045],
       [ 38.        , 107.        ,  66.        , 115.25852045],
       [ 84.        ,  94.        ,  68.        , 115.25852045],
       [152.        , 137.        , 115.        , 115.25852045]]), array([0.5       , 0.25      , 0.125     , 0.0625    , 0.03125   ,
       0.015625  , 0.0078125 , 0.00390625, 0.00195312, 0.00195312]))


In [25]:
# ----------------------------
# Step 2: Compute EMD distance
# ----------------------------
def palette_emd(palette1, palette2, frequency1, frequency2):
    """
    Compute order-independent EMD between two RGB palettes.
    Assumes both palettes have same number of colors (10 here).
    """
    n = len(palette1)
    # Distance matrix (Euclidean in RGB)
    dist_matrix = np.linalg.norm(palette1[:, None, :] - palette2[None, :, :], axis=2).astype(np.float64)
    return emd(frequency1, frequency2, dist_matrix)

# ----------------------------
# Step 3: KNN classifier
# ----------------------------
class EMDBasedKNN:
    def __init__(self, n_neighbors=3):
        self.n_neighbors = n_neighbors
        self.palettes = []
        self.labels = []

    def fit(self, palettes, labels,):
        self.palettes = palettes
        self.labels = labels

    def predict(self, new_palette):
        new_rgb, new_freq = new_palette
        # Compute distances to all training palettes
        distances = [palette_emd(new_rgb, p[0], new_freq, p[1]) for p in self.palettes]
        
        # Get indices of k nearest neighbors
        sorted_idx = np.argsort(distances)[:self.n_neighbors]
        neighbor_labels = [self.labels[i] for i in sorted_idx]
        neighbor_distances = [distances[i] for i in sorted_idx]
        
        # Weighted vote: closer neighbors count more
        counter = Counter()
        for label, dist in zip(neighbor_labels, neighbor_distances):
            counter[label] += 1 / (dist + 1e-6)  # avoid division by zero

        # Return the label with highest weighted vote
        return counter.most_common(1)[0][0]
    
    def score(self, X, y):
        """
        Compute accuracy on a dataset.
        X: list of palettes (rgb_array, freq_array)
        y: list of true labels
        """
        correct = 0
        for palette, true_label in zip(X, y):
            pred = self.predict(palette)
            if pred == true_label:
                correct += 1
        return correct / len(y)

In [26]:
def load_training_data(train_folder):
    palettes = []
    labels = []
    for biome_name in os.listdir(train_folder):
        biome_path = os.path.join(train_folder, biome_name)
        if not os.path.isdir(biome_path):
            continue
        for img_file in os.listdir(biome_path):
            if img_file.lower().endswith(('.png', '.jpg', '.jpeg')):
                img_path = os.path.join(biome_path, img_file)
                palette = extract_palette_rgb(img_path)
                palettes.append(palette)
                labels.append(biome_name)
    return palettes, labels

In [27]:
from sklearn.model_selection import train_test_split

In [28]:
X, y = load_training_data('train_data')

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2) # you can change test_size
KNN = EMDBasedKNN(n_neighbors=4)
KNN.fit(X_train, y_train)


In [29]:
print("Train Accuracy:", KNN.score(X_train, y_train))
print("Test Accuracy:", KNN.score(X_test, y_test))

Train Accuracy: 0.14839797639123103
Test Accuracy: 0.1111111111111111
