In [1]:
# !pip install SQLAlchemy
import ast
import numpy as np
import pandas as pd
import plotly.graph_objects as go

from collections import Counter 
# from environment.settings import config
from numpy import matlib
from sklearn.cluster import KMeans
from sqlalchemy import select, text, and_, or_
from sqlalchemy.sql import Select
from typing import Tuple
from utils import connections
from utils import database

from tqdm.notebook import tqdm
from copy import deepcopy
import warnings

# database_dir = config['DATABASE_DIR']
# dataset_dir = config['DATASET_DIR']

dataset_dir = '../data/'

NUM_COLORS = 30

# kmeans_model = KMeans(n_clusters=NUM_COLORS, n_init='auto')

Functions

In [2]:
def rgb2hex(rgb: np.ndarray):
    ''' Converts an N X 3 numpy array of RGB values into a list of hex strings'''
    hex_list = list(map(lambda x: '#%02x%02x%02x' % tuple(x), rgb))
    return hex_list

def extract_colors(model: KMeans, img: np.ndarray) -> Tuple[np.ndarray, Counter, np.ndarray]:
    ''' Extract the X most common colors from an image with a KMeans model '''       
    cluster_labels = model.fit_predict(img)
    return cluster_labels.astype(np.uint8), Counter(cluster_labels), model.cluster_centers_.astype(np.uint8)

We need to group the artworks by:
- Artist
- Movement
- Century
- Country

In [3]:
# Artist = pd.read_csv(dataset_dir+'Artist.csv')
# ArtistMovement = pd.read_csv(dataset_dir+'ArtistMovements.csv')
# Artwork = pd.read_csv(dataset_dir+'Artwork.csv')
# Movement = pd.read_csv(dataset_dir+'Movement.csv')
# ? Read Stratos' file
movement_artist_artwork = pd.read_csv(dataset_dir+'movement_artist_artwork.csv', index_col=0)
movement_artist_artwork.dropna(subset=['artwork_id'], inplace=True)
movement_artist_artwork.artist_id = movement_artist_artwork.artist_id.astype(int)
movement_artist_artwork.artwork_id = movement_artist_artwork.artwork_id.astype(int)

artwork_century = pd.read_csv(dataset_dir+'artwork_century.csv', index_col=0)
artwork_century.dropna(subset=['century'], inplace=True)
artwork_century.century = artwork_century.century.astype(int)
artwork_century.artwork_id = artwork_century.artwork_id.astype(int)

place_artwork = pd.read_csv(dataset_dir+'place_artwork.csv', index_col=0)
place_artwork.dropna(subset=['place_id','artwork_id'], inplace=True)
place_artwork.place_id = place_artwork.place_id.astype(int)
place_artwork.artwork_id = place_artwork.artwork_id.astype(int)

In [4]:
movement_artist_artwork

Unnamed: 0,artist_id,artwork_id,artist_name,painting_name,movement_id,name
0,0,0,Vincent Van Gogh,Cafe Terrace on the Place du Forum,43.0,Post-Impressionism
1,0,1,Vincent Van Gogh,Starry Night,43.0,Post-Impressionism
2,0,2,Vincent Van Gogh,A Digger,43.0,Post-Impressionism
3,0,3,Vincent Van Gogh,A Group of Cottages,43.0,Post-Impressionism
4,0,4,Vincent Van Gogh,A Pair of Shoes,43.0,Post-Impressionism
...,...,...,...,...,...,...
11929,615,10515,Otto van Veen,Distribution of Herring and White Bread during...,,
11930,615,10516,Otto van Veen,"The Artist Painting, Surrounded by his Family",,
11931,615,10517,Otto van Veen,Hercules seated at the foot of a tree in a lan...,,
11932,615,10518,Otto van Veen,St. Matthew Bringing Back To Life The Son Of T...,,


Cluster by artist

In [5]:
warnings.filterwarnings("error")
# dictionary to hold the cluster counts and rgb colors for each artist 
superclusters_per_category={
    'artist':{},
    'movement':{},
    'century':{},
    'country':{}
}

tmp= {
    'id': [],
    'cluster_counts': [],
    'rgb_colors': [],
    'name':[]
}

dataframes={
    'artist':movement_artist_artwork.drop(['name'], axis=1)
            .rename(columns={"artist_id": "id", "artist_name":"name"}),
    'movement':movement_artist_artwork.rename(columns={"movement_id": "id"}).dropna(subset=['id']),
    'century':artwork_century.rename(columns={"century": "id"}).assign(name=artwork_century['century']),
    'country':place_artwork.rename(columns={"place_id": "id", "place_name":"name"})
}

# superclusters_per_category['century']['century']=superclusters_per_category['century'].pop('id')
# superclusters_per_category['country']['country']=superclusters_per_category['country'].pop('id')

for key, value in tqdm(superclusters_per_category.items()):
    superclusters_per_category[key]=deepcopy(tmp)
    # Iterate over all artists
    for ids in tqdm(dataframes[key]['id'].unique(),desc=key):
        artworks = dataframes[key][dataframes[key]['id'] == ids]['artwork_id'].to_list()
        query = select(database.Colors).filter(database.Colors.id.in_(artworks))

        with connections.session_db() as con:
            df_colors = pd.read_sql_query(sql=query, con=con.connection())
            df_colors.rgb_colors = df_colors.rgb_colors.apply(lambda x: ast.literal_eval(x))
            df_colors.cluster_counts = df_colors.cluster_counts.apply(lambda x: ast.literal_eval(x))

        # ? Check that the category has artworks
        if df_colors.empty:
            print(key)
            print(ids, artworks)
            continue

        # ? Iterate over each artist's artworks
        color_percentage_list = []
        for i, row in df_colors.iterrows():
            cluster_counts, rgb_colors = row.cluster_counts, row.rgb_colors

            labels = rgb2hex(rgb_colors)
            values = list(map(lambda x: x[1], sorted(cluster_counts.items())))
            # make sure that the percentages add up to 100
            percentages = list(map(lambda x: x / sum(values) * 100, values))
            # cumulative sum of the percentages
            cum_percentages = np.cumsum(percentages).round(0)
            # shift the percentages right by 1 and make the first element 0 (from 100)
            rolled_cp = np.roll(cum_percentages, 1)
            rolled_cp[0] = 0
            # substract the shifted percentages from the cumulative percentages
            # the result is the size of percentage of each color
            # we do this to ensure that percentages always add up to 100
            percentages = (cum_percentages - rolled_cp).astype(int)
            # get a 100 * 3 array of the colors
            color_percentages = [matlib.repmat(a=color, m=percentage, n=1) for color, percentage in zip(rgb_colors, percentages)] 
            color_percentage_list.append(np.vstack(color_percentages))

        category_colors = np.vstack(color_percentage_list)
        try:
            _, cluster_counts, rgb_colors = extract_colors(KMeans(n_clusters=NUM_COLORS, n_init='auto'), category_colors)
        except Warning as e:
            print('warning at id',ids)
            print(e)

        # ? Append the results to the dictionary
        superclusters_per_category[key]['id'].append(ids)
        superclusters_per_category[key]['cluster_counts'].append(cluster_counts)
        superclusters_per_category[key]['rgb_colors'].append(rgb_colors)
        superclusters_per_category[key]['name'].append(dataframes[key][dataframes[key]['id'] == ids]['name'].to_list()[0])
    
warnings.resetwarnings()

  0%|          | 0/4 [00:00<?, ?it/s]

artist:   0%|          | 0/610 [00:00<?, ?it/s]

Number of distinct clusters (27) found smaller than n_clusters (30). Possibly due to duplicate points in X.
Number of distinct clusters (29) found smaller than n_clusters (30). Possibly due to duplicate points in X.
Number of distinct clusters (29) found smaller than n_clusters (30). Possibly due to duplicate points in X.
Number of distinct clusters (26) found smaller than n_clusters (30). Possibly due to duplicate points in X.
Number of distinct clusters (28) found smaller than n_clusters (30). Possibly due to duplicate points in X.
Number of distinct clusters (29) found smaller than n_clusters (30). Possibly due to duplicate points in X.
Number of distinct clusters (27) found smaller than n_clusters (30). Possibly due to duplicate points in X.
Number of distinct clusters (25) found smaller than n_clusters (30). Possibly due to duplicate points in X.
Number of distinct clusters (29) found smaller than n_clusters (30). Possibly due to duplicate points in X.
Number of distinct clusters 

movement:   0%|          | 0/54 [00:00<?, ?it/s]

century:   0%|          | 0/7 [00:00<?, ?it/s]

country:   0%|          | 0/145 [00:00<?, ?it/s]

Number of distinct clusters (28) found smaller than n_clusters (30). Possibly due to duplicate points in X.
Number of distinct clusters (29) found smaller than n_clusters (30). Possibly due to duplicate points in X.
Number of distinct clusters (29) found smaller than n_clusters (30). Possibly due to duplicate points in X.


In [6]:
from IPython.display import display

for key, value in superclusters_per_category.items():
    df = pd.DataFrame(value)  
    df.cluster_counts = df.cluster_counts.astype(str).str.lstrip('Counter(').str.rstrip(')')
    df.rgb_colors = df.rgb_colors.apply(lambda x: str([list(a) for a in x]))
    df.to_csv('./'+key+'_supercluster.csv')