In [None]:
import os
import random
import textwrap
import webcolors
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.pyplot import imread

def _css3_named_colours():
    """Return ``(name, (r, g, b))`` pairs for every CSS3 named colour.

    Older webcolors releases expose the ``CSS3_HEX_TO_NAMES`` mapping; newer
    releases (24.6+) removed it in favour of ``webcolors.names()``. The legacy
    mapping is preferred when present so older installs see the same
    candidate set as before.
    """
    hex_to_names = getattr(webcolors, "CSS3_HEX_TO_NAMES", None)
    if hex_to_names is not None:
        return [(name, webcolors.hex_to_rgb(hex_value))
                for hex_value, name in hex_to_names.items()]
    return [(name, webcolors.name_to_rgb(name)) for name in webcolors.names("css3")]


def get_color_name(requested_colour):
    """Return the CSS3 colour name nearest to ``requested_colour``.

    Parameters
    ----------
    requested_colour : indexable sequence of numbers
        Items ``[0]``, ``[1]``, ``[2]`` are read as red, green, blue on the
        0-255 scale; any further items (e.g. alpha) are ignored.

    Returns
    -------
    str
        The CSS3 colour name whose RGB value has the smallest squared
        Euclidean distance to the request. On an exact tie, the first
        candidate in webcolors' iteration order wins.
    """
    red, green, blue = requested_colour[0], requested_colour[1], requested_colour[2]

    def squared_distance(candidate):
        r_c, g_c, b_c = candidate[1]
        return (r_c - red) ** 2 + (g_c - green) ** 2 + (b_c - blue) ** 2

    # min() with a key replaces the old "dict keyed by distance" trick, which
    # silently discarded every name tied at the same distance but the last.
    closest_name, _ = min(_css3_named_colours(), key=squared_distance)
    return closest_name

def show_color_selection(requested_colour):
    """Draw an RGB colour side by side with its closest CSS3 named colour.

    Parameters
    ----------
    requested_colour : indexable sequence of numbers
        Red, green, blue on the 0-255 scale, e.g. ``(244, 37, 20)``. Lists,
        tuples and other indexable sequences (such as numpy arrays) are all
        accepted; only the first three items are used.

    Creates a new matplotlib figure with two swatches titled "input" and
    "match" and a suptitle naming the matched colour. The figure is neither
    shown nor returned here.
    """
    closest_name = get_color_name(requested_colour)
    # name_to_hex is stable across webcolors versions, unlike the
    # CSS3_NAMES_TO_HEX mapping (removed in newer releases).
    similar_color = webcolors.name_to_hex(closest_name)

    # Matplotlib wants float RGB in [0, 1]. Scale any 3+-item input, not just
    # lists/tuples, so e.g. numpy arrays are not passed on the 0-255 scale.
    input_rgb = tuple(component / 255 for component in requested_colour[:3])

    fig, (ax_input, ax_match) = plt.subplots(1, 2, figsize=(6, 3))
    for ax, face, title in ((ax_input, input_rgb, "input"),
                            (ax_match, similar_color, "match")):
        ax.add_patch(patches.Rectangle((0, 0), 1, 1, facecolor=face))
        ax.axis('off')
        ax.set_title(title)
    fig.suptitle(closest_name)
 

# Example RGB triples (0-255 per channel); one comparison figure is drawn for
# each. The loop variable keeps the name `requested_colour`, so it ends bound
# to the last sample exactly as the original sequence of assignments did.
SAMPLE_COLOURS = [
    (244, 37, 20),
    (104, 53, 140),
    (133, 133, 0),
    (200, 250, 0),
]
for requested_colour in SAMPLE_COLOURS:
    show_color_selection(requested_colour)

In [None]:

# Preview a few randomly chosen shape images next to their text prompts.
# prompts.csv has no header row: column 0 is the image filename, column 1 the
# prompt text.
df = pd.read_csv("prompts.csv", header=None)
df = df.rename(columns={0: "filename", 1: "prompt"})

image_dir = "./shapes"  # renamed from `dir`, which shadowed the builtin
N_PREVIEWS = 5
SEED = None  # set to an int to make the random selection reproducible

filenames = os.listdir(image_dir)

# One prompt per filename, keeping the first CSV row for duplicated names
# (the same row the original `sub_df.at[0, ...]` lookup picked).
prompt_by_file = df.drop_duplicates("filename").set_index("filename")["prompt"]

# Only files that have a prompt can be previewed; any other directory entry
# (e.g. a hidden file) used to crash the lookup with KeyError. Sorting removes
# the arbitrary os.listdir order so a fixed SEED gives a fixed selection.
captioned = sorted(name for name in filenames if name in prompt_by_file.index)

# random.randint(0, len(...)) is inclusive at both ends and could index past
# the end; sampling directly avoids that, avoids repeats, and copes with fewer
# than N_PREVIEWS captioned files (including none).
rng = random.Random(SEED)
for filename in rng.sample(captioned, k=min(N_PREVIEWS, len(captioned))):
    prompt = prompt_by_file[filename]
    image = imread(os.path.join(image_dir, filename))

    fig, axs = plt.subplots(1, 2)
    axs[0].imshow(image)
    axs[0].axis('off')  # Hide axes on the image subplot

    prompt = textwrap.fill(prompt, 40)
    axs[1].text(0.5, 0.5, prompt, size=9, ha='center', va='center')
    axs[1].axis('off')  # Hide axes on the text subplot
    plt.show()