In [1]:
import ast
import os
from collections import Counter
from typing import Tuple

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from numpy import matlib
from sklearn.cluster import KMeans
from sqlalchemy import select, text, and_, or_
from sqlalchemy.sql import Select

from environment.settings import config
from utils import connections
from utils import database

database_dir = config['DATABASE_DIR']
dataset_dir = config['DATASET_DIR']

NUM_COLORS = 10  # number of KMeans clusters, i.e. colors in every extracted palette
SEED = 42        # fixed so KMeans (and hence every palette) is reproducible on Restart & Run All

kmeans_model = KMeans(n_clusters=NUM_COLORS, n_init='auto', random_state=SEED)

Functions

In [2]:
def rgb2hex(rgb: np.ndarray):
    ''' Format every RGB triple (rows of an N x 3 array or nested sequence of ints)
    as a '#rrggbb' hex string, returning them as a list in row order. '''
    return ['#%02x%02x%02x' % tuple(triple) for triple in rgb]

def extract_colors(model: 'KMeans', img: np.ndarray) -> Tuple[np.ndarray, Counter, np.ndarray]:
    ''' Cluster the rows of ``img`` with ``model`` and return its color palette.

    Parameters
    ----------
    model : KMeans (or any estimator with ``fit_predict`` and ``cluster_centers_``)
        Re-fitted on ``img`` on every call.
    img : np.ndarray
        Samples to cluster, one RGB color per row.

    Returns
    -------
    labels : np.ndarray of uint8
        Cluster label of every row (assumes at most 256 clusters).
    counts : Counter
        Number of rows per cluster label; keys are built-in ``int`` so ``str(counts)``
        stays a valid Python literal (NumPy 2 scalars repr as e.g. ``np.int32(3)``).
    centers : np.ndarray of uint8
        Cluster centres, rounded to the nearest integer and clipped to [0, 255].
    '''
    cluster_labels = model.fit_predict(img)
    counts = Counter(np.asarray(cluster_labels).tolist())
    # Round instead of truncating (astype alone maps 254.9 -> 254) and clip so that
    # out-of-range values saturate instead of wrapping around in uint8.
    centers = np.clip(np.rint(model.cluster_centers_), 0, 255).astype(np.uint8)
    return cluster_labels.astype(np.uint8), counts, centers

Group the artworks by each of the following attributes and extract one representative color palette per group:
- Artist
- Movement
- Century
- Country

In [3]:
# Load the pre-joined movement/artist/artwork table (file prepared by Stratos). It combines
# Artist.csv, ArtistMovements.csv, Artwork.csv and Movement.csv, so those are not read here.
# Rows without an artwork cannot be matched to color data and are dropped. The id columns
# are then cast to int (presumably parsed as float because of the missing values — TODO confirm).
df = (
    pd.read_csv(os.path.join(dataset_dir, 'movement_artist_artwork.csv'), index_col=0)
    .dropna(subset=['artwork_id'])
    .astype({'artist_id': int, 'artwork_id': int})
)

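Cluster by artist

For each artist, every artwork's stored palette is expanded into 100 color samples weighted by cluster size. The pooled samples are then re-clustered with KMeans into a single `NUM_COLORS`-color palette for the artist.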

In [17]:
MAX_ARTISTS = None  # set to a small int for a quick debug run; None processes every artist


def _load_artwork_colors(artwork_ids: list) -> pd.DataFrame:
    '''Fetch the stored per-artwork palettes for ``artwork_ids`` from the Colors table.

    ``rgb_colors`` and ``cluster_counts`` are stored as Python-literal strings and parsed
    back with ``ast.literal_eval``, which only evaluates literals, never arbitrary code.
    '''
    query = select(database.Colors).filter(database.Colors.id.in_(artwork_ids))
    with connections.session_db() as con:
        colors = pd.read_sql_query(sql=query, con=con.connection())
    colors['rgb_colors'] = colors['rgb_colors'].apply(ast.literal_eval)
    colors['cluster_counts'] = colors['cluster_counts'].apply(ast.literal_eval)
    return colors


def _palette_to_samples(cluster_counts: dict, rgb_colors, n_samples: int = 100) -> np.ndarray:
    '''Expand one artwork's palette into ``n_samples`` RGB rows, repeating each color in
    proportion to its cluster size.

    The count of cluster ``k`` is paired with ``rgb_colors[k]``. Clusters missing from
    ``cluster_counts`` get weight 0, so counts and colors can never be misaligned.
    Rounding is applied to the cumulative share, so the row counts always add up to
    exactly ``n_samples``.
    '''
    palette = np.asarray(rgb_colors).reshape(-1, 3)
    counts = {int(label): count for label, count in cluster_counts.items()}
    weights = np.array([counts.get(label, 0) for label in range(len(palette))], dtype=float)
    total = weights.sum()
    if total <= 0:
        return palette[:0]  # nothing recorded for this artwork: contribute no samples
    cumulative = np.round(np.cumsum(weights) / total * n_samples)
    rows_per_color = np.diff(cumulative, prepend=0).astype(int)
    return np.repeat(palette, rows_per_color, axis=0)


# Per-artist results, turned into a DataFrame in the next cell.
artist_colors_dict = {
    'artist_id': [],
    'cluster_counts': [],
    'rgb_colors': []
}
# Artists whose artworks have no usable color data in the database.
artists_without_colors = []

# groupby(sort=False) visits artists in first-appearance order (like df.artist_id.unique())
# without re-scanning the whole frame for every artist.
artist_groups = df.groupby('artist_id', sort=False)['artwork_id']
for position, (artist, artwork_ids) in enumerate(artist_groups):
    if MAX_ARTISTS is not None and position >= MAX_ARTISTS:
        break

    df_colors = _load_artwork_colors(artwork_ids.to_list())
    # Pool every artwork's palette into one sample matrix (100 rows per artwork).
    samples = [
        _palette_to_samples(row.cluster_counts, row.rgb_colors)
        for row in df_colors.itertuples(index=False)
    ]
    artist_colors = np.vstack(samples) if samples else np.empty((0, 3))
    # KMeans needs at least n_clusters samples; skip the artist instead of crashing.
    if len(artist_colors) < kmeans_model.n_clusters:
        artists_without_colors.append(artist)
        continue

    _, cluster_counts, rgb_colors = extract_colors(kmeans_model, artist_colors)

    artist_colors_dict['artist_id'].append(artist)
    artist_colors_dict['cluster_counts'].append(cluster_counts)
    artist_colors_dict['rgb_colors'].append(rgb_colors)

print(f'Clustered {len(artist_colors_dict["artist_id"])} artist(s); '
      f'skipped {len(artists_without_colors)} without color data')

In [31]:
# Serialize the per-artist results to plain Python-literal strings, e.g.
# "{0: 412, 3: 97, ...}" and "[[12, 34, 56], ...]", so they can be stored in text columns
# and read back with ast.literal_eval. Keys and values are converted to built-in ints first,
# because under NumPy 2 the repr of a NumPy scalar (e.g. "np.uint8(12)") is not a literal.
# Each row is converted independently so that every artist keeps its own palette.
df_cluster = pd.DataFrame(artist_colors_dict).assign(
    cluster_counts=lambda frame: frame['cluster_counts'].map(
        lambda counts: str({int(label): int(n) for label, n in sorted(counts.items())})
    ),
    rgb_colors=lambda frame: frame['rgb_colors'].map(
        lambda rgb: str(np.asarray(rgb).tolist())
    ),
)