In [95]:
import numpy as np
import pandas as pd
import requests
from PIL import Image, ImageColor
from sklearn.cluster import KMeans, MiniBatchKMeans

In [79]:
# Relative path of the local sample image that is loaded and clustered below.
path = r'./ramadan_temp_comp.png'

In [3]:
def online_image_loader(url, timeout=10):

    """
    Load an image from a given URL.

    Args:
        url: A string representing the URL of the image to be loaded.
        timeout: Number of seconds to wait for the server before giving up.
            Passed straight to `requests.get`; defaults to 10 so a stalled
            server cannot block the caller indefinitely.

    Returns:
        An Image object representing the loaded image.

    Raises:
        requests.RequestException: If the request itself fails (connection
            error, timeout, invalid URL).
        Exception: If the server answers with an HTTP error status or the
            response body cannot be decoded as an image.

    Example:
        >>> image = online_image_loader('https://example.com/image.jpg')
    """

    from io import BytesIO

    # Network-level failures propagate unchanged, as they did before.
    response = requests.get(url, timeout=timeout)
    try:
        # A 4xx/5xx reply carries an error page, not an image: fail early.
        response.raise_for_status()
        # `response.content` is the fully downloaded, content-decoded body;
        # wrapping it in BytesIO gives PIL a seekable file object and lets
        # requests release the connection.
        image = Image.open(BytesIO(response.content))
    except (requests.RequestException, OSError, ValueError) as err:
        raise Exception('Image not found, try another URL') from err
    return image

In [80]:
def local_image_loader(path):
    
    """
    Load an image from a given file path.

    Args:
        path: A string representing the file path of the image to be loaded.

    Returns:
        An Image object representing the loaded image.

    Raises:
        Exception: If the image could not be loaded from the given file path.
            The underlying error is attached as `__cause__`.

    Example:
        >>> image = local_image_loader('/path/to/image.jpg')
    """

    try:
        image = Image.open(path)
    # Catch only load failures (missing/unreadable/unidentifiable file, bad
    # arguments) so KeyboardInterrupt and SystemExit are no longer swallowed.
    except (OSError, ValueError) as err:
        raise Exception('Image not found, try another path') from err
    return image

In [81]:
def flatten_image(image):

    """
    Turn an image into a table with one row per pixel.

    Args:
        image: An Image object exposing `getdata()` and `getbands()`.

    Returns:
        A pandas DataFrame built from the sequence returned by
        `image.getdata()`, with one column per band name reported by
        `image.getbands()`.

    Example:
        >>> image = Image.open('/path/to/image.jpg')
        >>> flattened_image = flatten_image(image)
    """

    channel_names = list(image.getbands())
    pixel_rows = list(image.getdata())
    return pd.DataFrame(pixel_rows, columns=channel_names)

In [82]:
# Load the sample image from disk.
img_local = local_image_loader(path)
# Alternative input (disabled): ask for a URL and download the image instead.
# url = input('Enter image url: ')
# img_online = online_image_loader(url)

In [83]:
# Flatten the loaded image into a DataFrame with one column per band.
flat_local = flatten_image(img_local)
# Same step for the downloaded image (disabled together with the loader above).
# flat_online = flatten_image(img_online)

In [8]:
class Clusterer():
    
    """
    A class for clustering RGB color values using different clustering algorithms.

    Attributes:
        df: A pandas DataFrame containing the RGB values of each color to be clustered.
        models: A dictionary storing the fitted models, keyed by algorithm
            name ('KMeans', 'MiniBatchKMeans'). Refitting an algorithm
            replaces its previous entry.

    Methods:
        KM: Fit a KMeans clustering model and store it in the models dictionary.
        MBKM: Fit a MiniBatchKMeans clustering model and store it in the models dictionary.

    Example:
        >>> clusterer = Clusterer(df)
        >>> km_model = clusterer.KM()
        >>> mbkm_model = clusterer.MBKM(n_clusters=5, random_state=0)
    """

    def __init__(self, df):

        """
        Initialize a new Clusterer object.

        Args:
            df: A pandas DataFrame containing the RGB values of each color to be clustered.
        """

        self.df = df
        self.models = {}

    def _fit(self, key, estimator):

        """
        Fit `estimator` on `self.df`, store it under `key` and return it.
        """

        self.models[key] = estimator.fit(self.df)
        return self.models[key]

    def KM(self, n_clusters=8, random_state=None):

        """
        Fit a KMeans clustering model and store it in the models dictionary.

        Args:
            n_clusters: Number of clusters (palette colors) to find. Default is 8.
            random_state: Seed forwarded to KMeans; pass an int to make the
                result reproducible. Default is None.

        Returns:
            The fitted KMeans clustering model.

        Example:
            >>> km_model = clusterer.KM()
        """

        estimator = KMeans(n_clusters=n_clusters, n_init='auto',
                           random_state=random_state)
        return self._fit('KMeans', estimator)
    
    def MBKM(self, n_clusters=8, random_state=None):

        """
        Fit a MiniBatchKMeans clustering model and store it in the models dictionary.

        Args:
            n_clusters: Number of clusters (palette colors) to find. Default is 8.
            random_state: Seed forwarded to MiniBatchKMeans; pass an int to
                make the result reproducible. Default is None.

        Returns:
            The fitted MiniBatchKMeans clustering model.

        Example:
            >>> mbkm_model = clusterer.MBKM()
        """

        estimator = MiniBatchKMeans(n_clusters=n_clusters, n_init='auto',
                                    random_state=random_state)
        return self._fit('MiniBatchKMeans', estimator)

In [103]:
def _sorted_centers(model, max_colors=8):

    """
    Return the model's cluster centers as integer RGB rows, darkest first.

    Only the first three columns of `cluster_centers_` are used; they are
    assumed to be R, G, B (extra columns such as alpha are ignored).
    """

    if max_colors is not None and max_colors < 0:
        raise ValueError('max_colors must be None or a non-negative integer')

    centers = np.asarray(model.cluster_centers_, dtype=float)
    if centers.ndim != 2 or centers.shape[1] < 3:
        raise ValueError('cluster_centers_ must have shape (n_clusters, 3 or more)')

    # Round to the nearest integer *before* casting: a bare astype(int)
    # truncates, so a center of 254.9 would otherwise become 254.
    rgb = np.rint(centers[:, :3]).astype(int)

    # Rec. 709 luminance weights; a stable sort keeps ties in input order.
    luminance = rgb @ np.array([0.2126, 0.7152, 0.0722])
    order = np.argsort(luminance, kind='stable')
    return rgb[order][:max_colors]

def generate_rgb_palette(model, norm=False, max_colors=8):
    
    """
    Generate a list of RGB color tuples based on the cluster centers of a given model.

    Args:
        model: A clustering model with a `cluster_centers_` attribute containing
            the RGB values of the cluster centers. This should be a NumPy array
            with shape (n_clusters, 3); additional columns are ignored.
        norm (bool): Whether to normalize the RGB values to the range [0, 1] by
            dividing each value by 255. Default is False.
        max_colors (int or None): Maximum number of colors to return, counted
            from the darkest. None returns every cluster center. Default is 8.

    Returns:
        A list of RGB color tuples sorted from lowest to highest luminance.
        Each tuple holds three Python ints between 0 and 255 (or three Python
        floats between 0 and 1 when `norm` is True).

    Raises:
        ValueError: If the cluster centers have fewer than three columns or
            `max_colors` is negative.

    Example:
        >>> from sklearn.cluster import KMeans
        >>> model = KMeans(n_clusters=5, random_state=0).fit(X)
        >>> generate_rgb_palette(model)
        [(0, 102, 51), (255, 51, 0), (51, 153, 255), (255, 153, 204), (255, 204, 0)]
    """
    
    palette = _sorted_centers(model, max_colors)
    if norm:
        palette = palette / 255.0

    # tolist() converts NumPy scalars to plain Python numbers.
    return [tuple(color) for color in palette.tolist()]

def rgb_to_hex(rgb_value):

    """
    Convert an RGB color tuple to its hexadecimal representation.

    Args:
        rgb_value (tuple): A sequence whose first three items are integers
            between 0 and 255 representing the red, green, and blue
            components of a color. Further items are ignored.

    Returns:
        A string containing the hexadecimal representation of the color. The
        string will start with a '#' character followed by six hexadecimal digits
        representing the red, green, and blue components of the color.

    Raises:
        ValueError: If a component lies outside 0-255 (it would not fit in
            two hexadecimal digits).

    Example:
        >>> rgb_to_hex((255, 0, 0))
        '#ff0000'
        >>> rgb_to_hex((0, 128, 255))
        '#0080ff'
    """

    r, g, b = rgb_value[0], rgb_value[1], rgb_value[2]
    if not all(0 <= channel <= 255 for channel in (r, g, b)):
        raise ValueError('RGB components must be between 0 and 255')
    return '#{:02x}{:02x}{:02x}'.format(r, g, b)

def generate_hex_palette(model, max_colors=8):

    """
    Generate a list of hexadecimal color codes based on the cluster centers of a given model.

    Args:
        model: A clustering model with a `cluster_centers_` attribute containing
            the RGB values of the cluster centers. This should be a NumPy array
            with shape (n_clusters, 3); additional columns are ignored.
        max_colors (int or None): Maximum number of colors to return, counted
            from the darkest. None returns every cluster center. Default is 8.

    Returns:
        A list of hexadecimal color codes sorted from lowest to highest
        luminance. Each code is a string in the format "#RRGGBB", where RR,
        GG, and BB are two-digit hexadecimal numbers representing the red,
        green, and blue components of the color.

    Example:
        >>> from sklearn.cluster import KMeans
        >>> model = KMeans(n_clusters=5, random_state=0).fit(X)
        >>> generate_hex_palette(model)
        ['#006633', '#ff3300', '#3399ff', '#ff99cc', '#ffcc00']
    """
    
    # Reuse the RGB palette so both representations always agree.
    rgb_palette = generate_rgb_palette(model, max_colors=max_colors)
    return [rgb_to_hex(color) for color in rgb_palette]

def hex_to_rgb(hex_value):

    """
    Convert a hexadecimal color code to a normalized RGB color tuple.

    Args:
        hex_value: A string representing a hexadecimal color in the format
            "#RRGGBB" or the shorthand "#RGB" (each digit is doubled, so
            "#f30" equals "#ff3300"). The leading '#' is optional. Only the
            first six digits of longer strings are used.

    Returns:
        A tuple of three floats between 0 and 1 representing the red, green, and
        blue components of the color.

    Raises:
        ValueError: If the string is too short or contains non-hexadecimal
            characters.

    Example:
        >>> hex_to_rgb('#ff3300')
        (1.0, 0.2, 0.0)
        >>> hex_to_rgb('#f30')
        (1.0, 0.2, 0.0)
    """

    digits = hex_value.lstrip('#')
    if len(digits) == 3:
        # Expand shorthand: 'f30' -> 'ff3300'.
        digits = ''.join(ch * 2 for ch in digits)
    if len(digits) < 6:
        raise ValueError('invalid hex color: {!r}'.format(hex_value))
    return tuple(int(digits[i:i + 2], 16) / 255.0 for i in (0, 2, 4))

In [129]:
# Fit one KMeans and one MiniBatchKMeans model on the flattened pixels.
# Each model is fitted through its own Clusterer instance.
models = [Clusterer(flat_local).KM(), Clusterer(flat_local).MBKM()]

In [145]:
def ensemble_palettes(models):

    """ 
    Generate a palette of colors by averaging the palettes of multiple models.

    Each model's palette comes from `generate_rgb_palette`; the palettes are
    then averaged position by position (first color with first color, and so
    on) and rounded to the nearest integer.

    Args:
        models: A non-empty list of clustering models with a `cluster_centers_`
        attribute containing the RGB values of the cluster centers. All models
        must yield palettes of the same length.

    Returns:
        A list of RGB color tuples and a list of hexadecimal color codes.

    Raises:
        ValueError: If `models` is empty or the per-model palettes differ in length.
    """
    
    palettes = [generate_rgb_palette(model) for model in models]
    if not palettes:
        raise ValueError('models must contain at least one fitted model')
    if len({len(palette) for palette in palettes}) != 1:
        raise ValueError('all models must produce palettes of the same length')

    # Round instead of truncating so averaging does not bias colors darker.
    mean_palette = np.rint(np.mean(palettes, axis=0)).astype(int)

    # tolist() yields plain Python ints rather than NumPy scalars.
    rgb_palette = [tuple(color) for color in mean_palette.tolist()]
    hex_palette = [rgb_to_hex(color) for color in rgb_palette]
    return rgb_palette, hex_palette

In [146]:
# Combine the fitted models into one palette, as RGB tuples and as hex codes.
rgb_palette, hex_palette = ensemble_palettes(models)

In [147]:
# Print each ensemble color as an RGB tuple, one per line.
for color in rgb_palette:
    print(color)

(9, 11, 13)
(65, 40, 34)
(15, 81, 83)
(101, 66, 43)
(164, 115, 65)
(225, 171, 84)
(239, 207, 136)
(248, 235, 207)


In [155]:
# Bare expression: the notebook displays the list of hex codes as cell output.
hex_palette

['#090b0d',
 '#412822',
 '#0f5153',
 '#65422b',
 '#a47341',
 '#e1ab54',
 '#efcf88',
 '#f8ebcf']

In [148]:
# Print each ensemble color as a hex code, one per line.
for color in hex_palette:
    print(color)

#090b0d
#412822
#0f5153
#65422b
#a47341
#e1ab54
#efcf88
#f8ebcf


In [149]:
def visualize_palette_on_image(palette, image, mode='RGB'):

    """
    Visualizes the given color palette as a strip below the input image.

    Args:
    - palette (list of tuples or list of strings): a non-empty list of colors in RGB tuples (0-255) or hex string format (#RRGGBB)
    - image (PIL.Image): the input image the palette strip is appended to; images in other modes are converted to 'RGB' first
    - mode (str, optional): the color mode of the palette, either 'RGB' or 'HEX'. Default is 'RGB'.

    Returns:
    - new_image (PIL.Image): a new 'RGB' image with the input image on top and the palette strip underneath.
      The strip is as wide as the image and one eighth of that width tall (at least 1 pixel).

    Raises:
    - ValueError: if the mode is not 'RGB' or 'HEX', or if the palette is empty
    """
    
    if mode not in ('RGB', 'HEX'):
        raise ValueError("mode must be either 'RGB' or 'HEX'")
    if len(palette) == 0:
        raise ValueError('palette must contain at least one color')

    # Paint one pixel per color, then stretch the row to the strip size.
    swatches = Image.new('RGB', (len(palette), 1), color='white')
    pixels = swatches.load()
    for i, color in enumerate(palette):
        if mode == 'RGB':
            pixels[i, 0] = color
        else:
            pixels[i, 0] = ImageColor.getrgb(color)

    width, height = image.size
    # max(1, ...) keeps the strip visible for images narrower than 8 pixels.
    strip_height = max(1, width // 8)
    # Nearest-neighbour resampling keeps the swatch edges sharp.
    strip = swatches.resize((width, strip_height),
                            resample=Image.Resampling.NEAREST)

    # The strip is RGB, so bring the image to the same mode; this replaces
    # the old channel-count assertion that rejected RGBA/grayscale input.
    base = image if image.mode == 'RGB' else image.convert('RGB')

    new_image = Image.new('RGB', (width, height + strip_height))
    new_image.paste(base, (0, 0))
    new_image.paste(strip, (0, height))

    return new_image

In [157]:
def visualize_palette(palette, mode='RGB', size=(1024, 128)):

    """ 
    Takes in a list of RGB tuples or HEX strings representing a color palette, 
    and returns an image object that visualizes the palette.

    Args:

    - palette (list): A non-empty list of RGB tuples or HEX strings representing the colors in the palette.
    - mode (str): The mode in which to interpret the color values. Valid values are 'RGB' and 'HEX'. Default is 'RGB'.
    - size (tuple): The (width, height) in pixels of the returned image. Default is (1024, 128).
    
    Returns:
    palette_image (PIL.Image.Image): An 'RGB' image of the requested size showing the palette colors side by side, left to right.
    The palette is first drawn one pixel per color and then resized with nearest-neighbour resampling.
    If mode is 'RGB', the color values are interpreted as RGB tuples. If mode is 'HEX', the color values are interpreted as HEX strings. 

    Raises:
    - ValueError: if the mode is not 'RGB' or 'HEX', or if the palette is empty
    """
    
    if mode not in ('RGB', 'HEX'):
        raise ValueError("mode must be either 'RGB' or 'HEX'")
    if len(palette) == 0:
        raise ValueError('palette must contain at least one color')

    # Paint one pixel per color, then stretch the row to the output size.
    swatches = Image.new('RGB', (len(palette), 1), color='white')
    pixels = swatches.load()
    for i, color in enumerate(palette):
        if mode == 'RGB':
            pixels[i, 0] = color
        else:
            pixels[i, 0] = ImageColor.getrgb(color)

    width, height = size
    # Nearest-neighbour resampling keeps the swatch edges sharp.
    palette_image = swatches.resize((width, height),
                                    resample=Image.Resampling.NEAREST)

    return palette_image

In [156]:
# Attach the hex palette to the loaded image and display the combined picture.
result = visualize_palette_on_image(hex_palette, img_local,mode='HEX')
result.show()

In [152]:
# Render the hex palette on its own and display it.
pal_image = visualize_palette(hex_palette, mode='HEX')
pal_image.show()