# Construct K-means clusters in color space

In [5]:
import os
import cv2
import pickle
import numpy as np
from glob import glob 
from sklearn.cluster import MiniBatchKMeans, KMeans
from tqdm.notebook import tqdm
import plotly.graph_objs as go
import matplotlib.pyplot as plt

In [6]:
# Function to obtain all colors

def get_all_colors(root_dir, n_colors=10, max_per_folder=50, random_state=None):
    """Extract the dominant colors of the images found under ``root_dir``.

    Every entry of ``root_dir`` is treated as a folder of ``*.jpg`` files.
    Each image is read with OpenCV, converted from BGR to RGB, flattened to a
    list of pixels and clustered with ``MiniBatchKMeans``. The rounded cluster
    centers are taken as the colors of that image.

    Parameters
    ----------
    root_dir : str
        Directory whose entries are scanned for ``*.jpg`` files.
    n_colors : int, default 10
        Number of clusters (colors) extracted per image.
    max_per_folder : int or None, default 50
        Maximum number of images processed per folder. ``None`` removes the
        cap. The default matches the limit that used to be hard-coded.
    random_state : int or None, default None
        Forwarded to ``MiniBatchKMeans``. ``None`` keeps the previous
        unseeded behaviour; pass an int for reproducible colors.

    Returns
    -------
    all_colors : np.ndarray
        Rounded cluster centers of every processed image stacked row-wise,
        one RGB triple per row. Shape ``(0, 3)`` when nothing was processed.
    file_color_map : dict
        Maps each processed image path to its array of rounded centers.
    """

    all_colors = []
    file_color_map = {}
    for folder in os.listdir(root_dir):
        processed = 0
        for path in tqdm(glob(root_dir+'/'+folder+'/*.jpg')):
            if max_per_folder is not None and processed >= max_per_folder:
                break

            img = cv2.imread(path)
            if img is None:
                # cv2.imread signals an unreadable/corrupt file with None
                # instead of raising; skip it rather than crash in cvtColor.
                continue

            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # Flatten (height, width, channels) into one row per pixel.
            pixels = img.reshape(img.shape[0]*img.shape[1], -1)

            clt = MiniBatchKMeans(n_clusters=n_colors, random_state=random_state)
            clt.fit(pixels)

            # Round once and reuse for both the global list and the per-file map.
            centers = clt.cluster_centers_.round()
            all_colors.extend(centers)
            file_color_map[path] = centers
            processed += 1

    # reshape keeps the result two-dimensional even when no image was
    # processed, so callers can always slice columns with all_colors[:, k].
    return np.array(all_colors, dtype=float).reshape(-1, 3), file_color_map


# Extract 5 colors per image from every folder under ../../data
all_colors, file_color_map = get_all_colors('../../data', n_colors=5)

# Create the output folder up front: open(..., 'wb') does not create missing
# directories, and failing here would throw away the extraction done above.
os.makedirs('../../saved_data/13 Jun', exist_ok=True)

# Persist the stacked color array
with open('../../saved_data/13 Jun/all_colors.pkl', 'wb') as f:
    pickle.dump(all_colors, f)

# Persist the path -> colors mapping
with open('../../saved_data/13 Jun/file_color_map.pkl', 'wb') as f:
    pickle.dump(file_color_map, f)

HBox(children=(IntProgress(value=0, max=189), HTML(value='')))

HBox(children=(IntProgress(value=0, max=54), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=0, max=442), HTML(value='')))

HBox(children=(IntProgress(value=0, max=268), HTML(value='')))

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=0, max=47), HTML(value='')))




HBox(children=(IntProgress(value=0, max=2), HTML(value='')))




HBox(children=(IntProgress(value=0, max=371), HTML(value='')))

HBox(children=(IntProgress(value=0, max=452), HTML(value='')))

In [8]:
# Reload the color array and the path -> colors mapping that the extraction
# cell pickled, so the cells below can run without redoing the extraction.

with open('../../saved_data/13 Jun/all_colors.pkl', 'rb') as colors_file:
    all_colors = pickle.loads(colors_file.read())

with open('../../saved_data/13 Jun/file_color_map.pkl', 'rb') as mapping_file:
    file_color_map = pickle.loads(mapping_file.read())

In [9]:
# Function to convert rgb colors to hex values
def rgb2hex(c):
    """Format the first three components of ``c`` as a ``#rrggbb`` string.

    Each component is truncated to ``int`` and written as two-digit lowercase
    hexadecimal.
    """
    red, green, blue = int(c[0]), int(c[1]), int(c[2])
    return '#%02x%02x%02x' % (red, green, blue)

# Scatter every collected color in 3d, using its three components as the
# coordinates and painting each marker with the color it represents.
hex_colors = list(map(rgb2hex, all_colors))
x = all_colors[:, 0]
y = all_colors[:, 1]
z = all_colors[:, 2]

fig = go.Figure()
fig.add_trace(
    go.Scatter3d(
        x=x,
        y=y,
        z=z,
        mode='markers',
        marker={'color': hex_colors},
    )
)
fig.update_layout(title='All colors in 3d', autosize=False, width=600, height=600)
fig.show()

In [10]:
# K-means clustering on the collected colors with 5 clusters.
# Assign a cluster label to each color.
# random_state is fixed so the cluster ids (which later cells store in
# file_color_map and in the pickled model) stay the same across re-runs.

clt = KMeans(n_clusters=5, random_state=0)
clt.fit(all_colors)
centers, labels = clt.cluster_centers_, clt.labels_

# Same 3d scatter as before, but markers are colored by cluster label.
fig = go.Figure()
fig.add_trace(go.Scatter3d(x=x, y=y, z=z, mode='markers', marker=dict(color=labels)))
fig.update_layout(autosize=False, width=600, height=600, title='All colors in 3d')
fig.show()

In [11]:
# Searching mechanism

def color_uid(rgb):
    """Return a string key that uniquely identifies a color.

    Each component is truncated to ``int`` and the components are joined with
    ``'_'``. The separator matters: concatenating the digits directly maps
    different colors to the same key, e.g. (1, 23, 4) and (12, 3, 4) would
    both become ``'1234'``.
    """
    return '_'.join(str(int(component)) for component in rgb)

# Build one uid per collected color, then pair every uid with the cluster
# label at the same position (a repeated uid keeps its last label).
uids = list(map(color_uid, all_colors))
color_label_map = {uid: label for uid, label in zip(uids, labels)}

In [12]:
# Replace every color in file_color_map with its cluster label.
# The per-file list gets its own name (file_labels): assigning it to `labels`
# would overwrite the KMeans label array that color_label_map is built from,
# so re-running the uid-map cell afterwards would produce a wrong mapping.

for path, colors in file_color_map.items():
    file_labels = [color_label_map[color_uid(color)] for color in colors]
    # Overwriting the value of an existing key is safe while iterating.
    file_color_map[path] = file_labels

In [13]:
# Persist the relabelled file_color_map and the fitted clustering object.
# NOTE(review): file_color_map.pkl is the same path the extraction cell wrote
# the raw colors to, so after this cell it holds labels instead of colors.

for target_file, payload in (
    ('../../saved_data/13 Jun/file_color_map.pkl', file_color_map),
    ('../../saved_data/13 Jun/kmeans_clt.pkl', clt),
):
    with open(target_file, 'wb') as out_file:
        pickle.dump(payload, out_file)