In [1]:
import io
import os
from collections import Counter
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import requests
from numba import njit
from numpy import matlib
from PIL import Image
from plotly.subplots import make_subplots
from sklearn.cluster import KMeans
from tqdm import tqdm

from environment.settings import config

# Data locations, looked up in the project settings mapping (`config`) imported above.
dataset_dir = config['DATASET_DIR']
database_dir = config['DATABASE_DIR']

# Number of colors to be extracted from each image; used below as the
# KMeans cluster count.
NUM_COLORS = 10

In [26]:
# Load the artwork and artist tables. os.path.join is used so the paths are
# correct whether or not `dataset_dir` ends with a path separator.
Artwork = pd.read_csv(os.path.join(dataset_dir, 'Artwork.csv'))
Artist = pd.read_csv(os.path.join(dataset_dir, 'Artist.csv'))
# Copy of Artwork with the row index exposed as an explicit `artwork_id` column.
df_temp = Artwork.reset_index().rename(columns={'index':'artwork_id'})

In [35]:
# Load the two pickled partial results and stack them into one frame.
# Paths are built with os.path.join, like the other cells, so a trailing
# separator on `dataset_dir` (or its absence) does not matter.
df1 = pd.read_pickle(os.path.join(dataset_dir, 'cluster_data_small.pkl'))
df2 = pd.read_pickle(os.path.join(dataset_dir, 'cluster_data_small2.pkl'))

df = pd.concat([df1, df2], ignore_index=True)
# Disabled one-off export of the combined frame to CSV, kept for reference.
# NOTE(review): the trailing `[0]` on the rgb_colors line selects only the first
# row's string, which would then be assigned to every row if re-enabled - confirm intent.
# df['id'] = df.merge(df_temp, left_on='id', right_on='artwork_id').id_y
# df.cluster_counts = df.cluster_counts.astype(str).str.lstrip('Counter(').str.rstrip(')')
# df.rgb_colors = df.rgb_colors.apply(lambda x: str([list(a) for a in x]))[0]
# df.to_csv(f'{dataset_dir}/cluster_data_artworks.csv', index=False)

In [None]:
def url2array(url: str, timeout: float = 30.0) -> np.ndarray:
    ''' Downloads the image at `url` and returns it as a numpy array in RGB mode.

    Parameters
    ----------
    url : str
        Address of the image, fetched with an HTTP GET.
    timeout : float, optional
        Seconds to wait for the server before giving up (default 30).
        Without a timeout `requests` can block indefinitely on a stalled
        connection.

    Returns
    -------
    np.ndarray
        Array built from the decoded image after conversion to RGB
        (images already in RGB mode are used as they are).

    Raises
    ------
    requests.HTTPError
        If the server answers with an error status (4xx/5xx).
    requests.Timeout
        If the server does not answer within `timeout` seconds.
    '''
    response = requests.get(url, timeout=timeout)
    # Surface HTTP failures here; otherwise the body of an error page would be
    # handed to PIL and fail later with an unrelated decoding error.
    response.raise_for_status()
    img = Image.open(io.BytesIO(response.content))
    if img.mode != 'RGB':
        img = img.convert('RGB')
    return np.array(img)

@njit
def reshape_image(img):
    ''' Flattens an image into a (rows*cols, 3) array of pixels as input for the KMeans model.

    A 3-dimensional input is reshaped directly. Any other input is handled as a
    2-dimensional single-channel image: its values are copied into three
    identical channels before flattening.

    The function is compiled with numba's `njit`.
    '''
    if len(img.shape) == 3:
        # assumes the last axis holds exactly 3 channels - TODO confirm for non-RGB inputs
        img2d = img.reshape((img.shape[0]*img.shape[1], 3))
    else:
        # Stack three copies along a new last axis, then flatten to one row per pixel
        img2d = np.stack((img, img, img), axis=2).reshape(img.shape[0]*img.shape[1], 3)
    
    return img2d

def extract_colors(model, img):
    ''' Clusters the pixels of an image and summarises the resulting colors.

    Parameters
    ----------
    model
        Object exposing `fit_predict(img)` and, once fitted, a
        `cluster_centers_` array (a `KMeans` instance in this notebook).
    img
        Pixel data, passed unchanged to `model.fit_predict`.

    Returns
    -------
    tuple
        `(labels, counts, colors)` where `labels` is the per-sample cluster
        label array cast to uint8, `counts` is a `Counter` mapping each label
        (as a plain `int`) to the number of samples assigned to it, and
        `colors` is `model.cluster_centers_` cast to uint8 (the cast truncates
        fractional values).
    '''
    cluster_labels = model.fit_predict(img)
    # Count inside numpy instead of letting Counter iterate the label array
    # element by element; this also yields plain-int keys rather than numpy scalars.
    labels, counts = np.unique(cluster_labels, return_counts=True)
    cluster_counts = Counter(dict(zip(labels.tolist(), counts.tolist())))
    return cluster_labels.astype(np.uint8), cluster_counts, model.cluster_centers_.astype(np.uint8)

In [None]:
# Position in Artwork.image_url of the first image processed by this cell;
# rows before it are skipped (presumably covered by an earlier run - confirm
# before changing).
START_INDEX = 6056
# Fixed seed so that re-running the cell gives the same clusters for an image.
RANDOM_SEED = 0

Artwork = pd.read_csv(os.path.join(dataset_dir, 'Artwork.csv'))
kmeans_model = KMeans(n_clusters=NUM_COLORS, n_init='auto', random_state=RANDOM_SEED)

# data_dict = {'id': [], 'cluster_labels': [], 
#              'cluster_counts': [], 'rgb_colors': []}

data_dict = {'id': [], 'cluster_counts': [], 'rgb_colors': []}
urls = Artwork.image_url
# tqdm wraps the sliced Series (not the enumerate object) so it can read the
# length and show a total; enumerate's `start` keeps `i` equal to the row position.
for i, url in enumerate(tqdm(urls.iloc[START_INDEX:]), start=START_INDEX):
    img = url2array(url)
    img2d = reshape_image(img)
    cluster_labels, cluster_counts, rgb_colors = extract_colors(kmeans_model, img2d)
    # Add data to the dictionary (per-pixel labels are deliberately not stored)
    data_dict['id'].append(i)
    # data_dict['cluster_labels'].append(cluster_labels)
    data_dict['cluster_counts'].append(cluster_counts)
    data_dict['rgb_colors'].append(rgb_colors)

In [None]:
# Persist everything collected in `data_dict` as a pickled DataFrame.
# os.path.join keeps the path valid with or without a trailing separator on `dataset_dir`.
pd.DataFrame(data_dict).to_pickle(os.path.join(dataset_dir, 'cluster_data.pkl'))

In [None]:
# Persist the id / counts / colors columns of this run's results.
# NOTE: this rebinds `df`, the same name the pickle-loading cell uses for the
# concatenated frame.
df = pd.DataFrame(data_dict)
df[['id', 'cluster_counts', 'rgb_colors']].to_pickle(os.path.join(dataset_dir, 'cluster_data_small2.pkl'))