In [None]:
import io
import os
from collections import Counter
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import requests
from numpy import matlib
from PIL import Image
from plotly.subplots import make_subplots
from sklearn.cluster import KMeans

from environment.settings import config

# Locations are read from the project's settings mapping rather than hard-coded.
# dataset_dir is the base location of Artwork.csv and Artist.csv, which are loaded further down.
dataset_dir = config['DATASET_DIR']
# NOTE(review): database_dir is not referenced by any of the cells visible here — confirm it is still needed.
database_dir = config['DATABASE_DIR']

# Number of colors to be extracted
# (used as the number of KMeans clusters and quoted in the figure title)
NUM_COLORS = 10

Convenience functions

In [2]:
# Scratch check: repeat '10px ' four times and strip the trailing space to get a space-separated string.
# NOTE(review): the result is not assigned or used by the other cells shown here — looks like a leftover experiment.
('10px ' * 4).rstrip()

'10px 10px 10px 10px'

In [None]:
def hex2rgb(hex: str) -> Tuple[int, int, int]:
    """Converts a hex color to rgb

    Accepts the six-digit form ('#rrggbb') as well as the three-digit
    shorthand ('#rgb'), in which every digit stands for a doubled pair
    ('#1af' == '#11aaff'). The leading '#' is optional.

    Returns a (red, green, blue) tuple of ints in the range 0-255.
    Raises ValueError when the string cannot be parsed as hex digits.
    """
    digits = hex.lstrip('#')
    if len(digits) == 3:
        # Shorthand notation: expand 'rgb' to 'rrggbb' before slicing into pairs
        digits = ''.join(digit * 2 for digit in digits)
    return tuple(int(digits[i:i+2], 16) for i in (0, 2, 4))

In [None]:
def rgb2hex(rgb: np.ndarray):
    ''' Turns each row of an N X 3 array of RGB values into a '#rrggbb' string and returns them as a list '''
    # One two-digit, zero-padded lowercase hex field per channel
    return ['#%02x%02x%02x' % tuple(channels) for channels in rgb]

def url2img(url: str, timeout: float = 30):
    ''' Gets a URL and returns a PIL image in RGB mode

    url:     address of the image to download
    timeout: seconds to wait for the server before giving up; without it
             a stalled connection would block the notebook indefinitely

    Raises requests.HTTPError for 4xx/5xx responses, so an error page is
    reported as a download failure instead of an undecodable image.
    '''
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    img = Image.open(io.BytesIO(response.content))
    # Normalise grayscale / palette / RGBA / CMYK images to three channels
    if img.mode != 'RGB':
        img = img.convert('RGB')
    return img

def url2array(url: str):
    ''' Gets a URL and returns an image as a numpy array

    The download and decoding are delegated to url2img, so the result is
    always an RGB image: an array with three channels on the last axis.
    Decoding directly would hand RGBA images (four channels) or palette
    images (a 2D array of palette indices) to the clustering code, which
    expects exactly three colour channels per pixel.
    '''
    return np.array(url2img(url))

def reshape_image(img: np.ndarray) -> np.ndarray:
    ''' Flattens an image into a (rows*cols) x 3 array of pixels as input for the KMeans model '''
    if len(img.shape) != 3:
        # No channel axis: use the same plane for all three channels
        img = np.stack((img, img, img), axis=2)

    n_rows, n_cols = img.shape[0], img.shape[1]
    return img.reshape((n_rows * n_cols, 3))

def extract_colors(model: KMeans, img: np.ndarray) -> Tuple[np.ndarray, Counter, np.ndarray]:
    ''' Extract the X most common colors from an image with a KMeans model

    model: clustering model; it is (re)fitted on img by this call
    img:   N x 3 array of pixels (one row per pixel)

    Returns a tuple of
      - the cluster label of every pixel, as uint8 (labels above 255 would wrap),
      - a Counter mapping cluster label -> number of pixels in that cluster,
      - the cluster centers as uint8 RGB colors, one row per cluster.
    '''
    cluster_labels = model.fit_predict(img)
    # Count plain Python ints: much faster than hashing one numpy scalar per pixel
    cluster_counts = Counter(cluster_labels.tolist())
    # Round to the nearest level before the cast; astype alone truncates, which
    # darkens every channel by up to one level. Clip keeps the cast in range.
    rgb_colors = model.cluster_centers_.round().clip(0, 255).astype(np.uint8)
    return cluster_labels.astype(np.uint8), cluster_counts, rgb_colors

Read the data

In [None]:
# Build the paths with os.path.join so DATASET_DIR works with or without a trailing separator
Artwork = pd.read_csv(os.path.join(dataset_dir, 'Artwork.csv'))
Artist = pd.read_csv(os.path.join(dataset_dir, 'Artist.csv'))
# Download a single artwork as a quick check that the image URLs resolve
url = Artwork.image_url[123]
img = url2array(url)
# An example of an artist's images (Monet)
# Every image of the artist is downloaded here, one request per row
image_list = Artwork[Artwork.artist == 0].image_url.apply(url2array).reset_index(drop=True)

Clustering

In [None]:
# A fixed random_state makes the centroid initialisation, and therefore the
# extracted palette, identical on every run of the notebook
kmeans_model = KMeans(n_clusters=NUM_COLORS, n_init='auto', random_state=0)
# Pick one image of the artist as the worked example
img = image_list[20]

# Cluster the pixels of the image into NUM_COLORS colors
img2d = reshape_image(img)
cluster_labels, cluster_counts, rgb_colors = extract_colors(kmeans_model, img2d)
# Replace every pixel by the color of its cluster and restore the image layout
img_quant = np.reshape(rgb_colors[cluster_labels], (img.shape[0], img.shape[1], 3))

"Super" Clustering

In [None]:
# Summarise every image as 100 "pixels" (one per percent of its area) painted in
# the image's own cluster colors, then cluster those summaries across all images.
color_percentage_list = []
# The loop uses its own names so that `img` / `img2d` from the single-image
# clustering keep pointing at the image that `img_quant` was computed from.
for artist_img in image_list:
    artist_img2d = reshape_image(artist_img)
    _labels, img_counts, img_colors = extract_colors(kmeans_model, artist_img2d)

    # Pixel count per cluster center, in center order. A center without pixels
    # counts 0, so counts and colors stay aligned index by index.
    values = [img_counts[label] for label in range(len(img_colors))]

    # share of the image covered by each color, in percent
    total = sum(values)
    percentages = [value / total * 100 for value in values]
    # cumulative sum of the percentages, rounded to whole percents
    cum_percentages = np.cumsum(percentages).round(0)
    # shift the percentages right by 1 and make the first element 0 (from 100)
    rolled_cp = np.roll(cum_percentages, 1)
    rolled_cp[0] = 0
    # subtract the shifted percentages from the cumulative percentages
    # the result is the whole-percent size of each color
    # rounding the cumulative sums (not the shares) makes them add up to exactly 100
    percentages = (cum_percentages - rolled_cp).astype(int)
    # 100 x 3 block: each color repeated once per percent it covers
    color_percentage_list.append(np.repeat(img_colors, percentages, axis=0))

# Cluster the per-image summaries to get the dominant colors of the whole set
artist_colors = np.vstack(color_percentage_list)
cluster_labels, cluster_counts, rgb_colors = extract_colors(kmeans_model, artist_colors)

In [None]:
# Inspect the per-cluster pixel counts; after the loop above, `values` holds those of the last image processed
values

In [None]:
# The hex strings serve both as slice labels and as slice colors
labels = rgb2hex(rgb_colors)
# Number of pixels per cluster, ordered by cluster label
values = [count for _, count in sorted(cluster_counts.items())]

# Donut chart of the clustered colors
pie = go.Pie(labels=labels, values=values, hole=.2, marker={'colors': labels})
fig = go.Figure(data=[pie])
fig.show()

In [None]:
# Two-panel figure: the original image stacked above its color-quantized version (left)
# and a pie chart of the clustered colors (right).
# NOTE(review): this cell reads the notebook globals `img`, `img_quant`, `rgb_colors` and `cluster_counts`
# as they were last bound. `rgb_colors` / `cluster_counts` are reassigned by the clustering of
# `artist_colors`, so under Run All the pie shows those colors rather than the palette of the single
# image `img_quant` was built from — confirm which one is intended.
fig = make_subplots(rows=1, cols=2, subplot_titles=['Original (Top) & Clustered Image (Bottom)', f'{NUM_COLORS} Most Common Colors'],
                    specs=[[{'type': 'xy'}, {'type': 'domain'}]], column_widths=[0.8, 0.2])

# Stack the images together to display them
# np.vstack needs both arrays to share width and channel count, so `img` must still be
# the image that `img_quant` was computed from
stacked_images = np.vstack((img, img_quant))

# Extract the hex values from the RGB colors
labels = rgb2hex(rgb_colors)
# Get the number of occurrences for each color (ordered by cluster label)
values = list(map(lambda x: x[1], sorted(cluster_counts.items())))
# Pie chart with the colors
pie = go.Pie(labels=labels, values=values, hole=.2, marker=dict(colors=labels))

# Image goes into the cartesian ('xy') cell, the pie into the 'domain' cell
fig.add_trace(go.Image(z=stacked_images), row=1, col=1)
fig.add_trace(pie, row=1, col=2)

# Last expression of the cell: the returned figure is what the notebook displays
fig.update_layout(width=1024, height=1280)