In [2]:
import numpy as np
from matplotlib import image
from sklearn.cluster import KMeans
from util.constants import Topic, THUMBNAIL_HEIGHT, THUMBNAIL_WIDTH, THUMBNAIL_PATH
from tqdm import tqdm
from PIL import Image
import numpy as np
import json
import os
import matplotlib.pyplot as plt
# %matplotlib

def rgb_to_hex(rgb):
    return '#%02x%02x%02x' % rgb

def make_image_grid_1_row(imgs):
    w, h = imgs[0].size
    cols = len(imgs)

    grid = Image.new('RGB', size=(w * cols, h))
    
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols*w, i//cols*h))
    return np.array(grid)


def make_image_grid(imgs, rows, cols):
    assert len(imgs) == rows*cols

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i%cols*w, i//cols*h))
    return np.array(grid)


def find_dom_colours_with_percentages(img, clusters):
    _, _, C = img.shape
        
    # resize so kmeans is faster
    img = np.resize(img,(100, 200, C))

    # remove dark colours from image
    img = img[img[:,:,0] + img[:,:,1] + img[:,:,2] > 50]
    
    flat_img = np.reshape(img,(-1,3))

    kmeans = KMeans(n_clusters=clusters, random_state=0)
    kmeans.fit(flat_img)

    dominant_colors = np.array(kmeans.cluster_centers_, dtype='uint8')
    dominant_colors = tuple(map(tuple, dominant_colors))
    dominant_colors = [rgb_to_hex(dom_color) for dom_color in dominant_colors]

    percentages = (np.unique(kmeans.labels_, return_counts=True)[1]) / flat_img.shape[0]
    percentages = [round(perc, 3) for perc in percentages]

    c_and_p_dict = {"colours" : dominant_colors, "perc" : percentages}

    _, ax1 = plt.subplots()
    _, _, m = ax1.pie(percentages, autopct='%1.1f%%', colors=dominant_colors, shadow=True, startangle=90)
    [m[i].set_color('white') for i in range(len(m))]
    ax1.axis('equal') 

    plt.show()

    # dc_and_p = zip(dominant_colors, percentages)
    # dc_and_p = sorted(dc_and_p, reverse = True)

    # dc_and_p_dict = {}
    # for k, v in dc_and_p:
    #     dc_and_p_dict[k] = v


    return c_and_p_dict


def extract_category_dom_colours(category: Topic, clusters):
    """
    Function to extract dominant colours per creator and category for the given category.

    args:
        - category: the given category
    """

    with open(os.path.join("..", "data", "info_videos", f"videos-info_{category}.json"), "r") as f:
        print("Loading creator's videos")
        videos_info = json.load(f)
    print("Finished loading creator's videos\n")

    print("Calculating dominant colours for all creators in category: " + category + "\n")
    all_category_thumbnails = []
    for creator in tqdm(list(videos_info.keys())[:10]):
        print(creator)
        all_creator_thumbnails = []
        
        try:
            creator_thumbnails = []
            for vid_dict in videos_info[creator]:
                img = Image.open(os.path.join(THUMBNAIL_PATH, vid_dict['id'] + "_high.jpg"))
                w, h = img.size
                if h == THUMBNAIL_HEIGHT and w == THUMBNAIL_WIDTH:
                    # we exclude shorts by skipping the videos that were not cropped
                    continue
                creator_thumbnails.append(img)
            if len(creator_thumbnails) < 5:
                continue

            all_creator_thumbnails.extend(creator_thumbnails)

            all_category_thumbnails.extend(all_creator_thumbnails)

        except FileNotFoundError:
            continue

        img_grid = make_image_grid_1_row(all_creator_thumbnails[:9])

        plt.imshow(img_grid)
        plt.show()

        dom_colours = find_dom_colours_with_percentages(img_grid, clusters)

        with open(os.path.join(CREATORS_PATH, creator + ".json"), 'w') as f:
            json.dump(dom_colours, f)
        
    print("Calculating dominant colours for category: " + category + "\n")

    img_grid = make_image_grid_1_row(all_category_thumbnails[:2])
    dom_colours = find_dom_colours_with_percentages(img_grid, clusters)

    plt.imshow(img_grid)
    plt.show()

    with open(os.path.join(CATEGORY_PATH, category + ".json"), 'w') as f:
        json.dump(dom_colours, f)

    print("Finished calculating dominant colours per creator and for category\n")


if __name__ == '__main__':

    CATEGORY_PATH = "../data/thumbnail-dom-colours/categories"
    CREATORS_PATH = "../data/thumbnail-dom-colours/creators"

    if not os.path.exists(CATEGORY_PATH):
        os.makedirs(CATEGORY_PATH)

    if not os.path.exists(CREATORS_PATH):
        os.makedirs(CREATORS_PATH)

    # choose how many colours to extract
    clusters = 5

    # categories = ["education", "science", "animals", "autos", "blogs", "comedy", "entertainment", "howto"]
    # for category in categories:
    # category = "science"
    extract_category_dom_colours("animals", clusters)
        # only do gaming for now

  0%|          | 0/10 [00:00<?, ?it/s]

Loading creator's videos
Finished loading creator's videos

Calculating dominant colours for all creators in category: animals

kiyo-c50d0d69-dcc3-47cb-95b3-8bba51ed5e3d
fabinho-filho-do-sertao
0f2ad495-4683-4bce-a210-7bf3f9c3f014
eu-nao-sabia
unal-vlog
6e8cd5b7-9409-478d-9366-50124f0ffc3a
ntdaptv
mako0mako0


100%|██████████| 10/10 [00:00<00:00, 117.66it/s]

this-is-bailey
rnickeymouse
Calculating dominant colours for category: animals






IndexError: list index out of range

In [None]:
CATEGORY_PATH = "../data/thumbnail-dom-colours/categories"
CREATORS_PATH = "../data/thumbnail-dom-colours/creators"

import matplotlib.pyplot as plt
import json
import os

with open(os.path.join(CATEGORY_PATH, "gaming.json"), "r") as f:
        print("Loading creator's videos")
        d_and_c = json.load(f)

colours = d_and_c["colours"]
p = d_and_c["perc"]

print(colours)
print(sum(p))

# Pie chart, where the slices will be ordered and plotted counter-clockwise:
# labels = 'Frogs', 'Hogs', 'Dogs', 'Logs'
sizes = p
# explode = (0, 0.1, 0, 0)  # only "explode" the 2nd slice (i.e. 'Hogs')

fig1, ax1 = plt.subplots()
_, _ ,m = ax1.pie(sizes, autopct='%1.1f%%', colors=colours, shadow=True, startangle=90)
[m[i].set_color('white') for i in range(len(m))]

ax1.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle.

plt.show()