# Imports
This notebook tags a folder of images with detected objects and dominant colors, and stores the results as per-image metadata. It uses the following libraries:

- import pickle: serializes and deserializes Python objects; one of the metadata storage formats.
- import os: builds file paths, creates folders, and walks directory trees.
- import sqlite3: reads and writes the SQLite database that can hold the metadata.
- import json: encodes and decodes JSON, another metadata storage format.
- import cv2: OpenCV's Python bindings, used to load, convert, and resize images for the color analysis.
- from sklearn.cluster import MiniBatchKMeans: a k-means variant from scikit-learn that updates the cluster centers on small random batches, which keeps clustering fast on the many pixels of an image.
- from tqdm import tqdm: wraps loops to display progress bars.
- from transformers import DetrImageProcessor, DetrForObjectDetection: the preprocessing and model classes from Hugging Face Transformers used for object detection with DETR (DEtection TRansformer).
- import torch: PyTorch, the deep learning framework that runs the DETR model.
- from PIL import Image: the Image class from Pillow (the maintained fork of PIL), used to open and resize images before object detection.
- import numpy as np: array utilities, used to count how many pixels fall in each color cluster.

Together, these libraries cover the three stages of the notebook: object detection, dominant-color extraction, and metadata storage in pickle, JSON, or SQLite.

In [None]:
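# pickle, os, sqlite3 and json ship with Python; OpenCV is installed as opencv-python,
# and timm provides the ResNet backbone used by the DETR checkpoint
!pip install opencv-python scikit-learn tqdm transformers timm torch Pillow numpy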

In [None]:
import pickle
import os
import sqlite3
import json
import cv2
import numpy as np
from sklearn.cluster import MiniBatchKMeans
from tqdm import tqdm
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image

# Setting base variables and paths
This cell defines the output folder structure used throughout the notebook and the extension of the metadata files to read.

In [None]:
# Set the base folder path for the project
output_path = "../output"
images_path = os.path.join(output_path, "images")
metadata_path = os.path.join(output_path, "metadata")
config_path = os.path.join(output_path, "config")

list_of_paths = [output_path, images_path, metadata_path, config_path]

# Metadata file extension read by update_tags when save_format is not 'sqlite' ('pickle' or 'json')
metadata_extension = "sqlite"

# Create folder structure
The code creates the folder structure for the project. The folder structure is as follows:
- output
    - images
    - metadata
    - config

The create_folder function creates a folder at the given path if it does not already exist, and prints a message saying whether the folder was created or already existed.
Keeping data and resources in dedicated folders keeps the working directory tidy and makes files easier to locate. Because an existing folder is reported instead of raising an error, the notebook can be re-run without failing at this step, and the files already in the folder are left untouched.
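A minimal alternative sketch, not the notebook's own code: os.makedirs with exist_ok=True creates a folder together with any missing parent folders and silently accepts one that already exists. The helper name ensure_folder is hypothetical, and unlike create_folder it returns a flag instead of printing.

```python
import os
import tempfile


def ensure_folder(path):
    """Create path (and missing parents); return True if it was newly created."""
    # Hypothetical helper: remember whether the folder existed before the call
    existed = os.path.isdir(path)
    # exist_ok=True suppresses the error when the directory already exists
    os.makedirs(path, exist_ok=True)
    return not existed


with tempfile.TemporaryDirectory() as base:
    nested = os.path.join(base, "output", "images")
    print(ensure_folder(nested))  # True: "output" and "output/images" are created
    print(ensure_folder(nested))  # False: already there, no error raised
```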

In [None]:
def create_folder(path):
    """
    This function creates a folder at the specified path.
    If the folder already exists, it will print a message saying so.
    If there is an error creating the folder, it will print the error message.

    Parameters:
        :param path (str): The path of the folder to be created.

    Returns:
    None
    """
    try:
        # Use os.mkdir to create the folder at the specified path
        os.mkdir(path)
        print(f"Folder {path} created")
    except FileExistsError:
        # If the folder already exists, print a message saying so
        print(f"Folder {path} already exists")
    except Exception as e:
        # If there is an error creating the folder, print the error message
        print(f"Error creating folder {path}: {e}")

# Initialize the folders
The init_folder function calls create_folder for each path in the list, so that all required folders exist before the rest of the notebook runs.
Missing folders are created, existing ones are reported, and any other error is printed.
Because os.mkdir creates only the last component of a path, the list must put each parent folder before its children (output before output/images, and so on), which list_of_paths does.
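The ordering requirement can be checked in a temporary directory: os.mkdir raises FileNotFoundError when the parent folder is missing, which is why output comes first in list_of_paths. The folder names below mirror the notebook's layout but live under a throwaway temporary directory.

```python
import os
import tempfile

with tempfile.TemporaryDirectory() as base:
    output_dir = os.path.join(base, "output")
    images_dir = os.path.join(output_dir, "images")

    # Child before parent: os.mkdir does not create intermediate folders
    try:
        os.mkdir(images_dir)
        child_first_ok = True
    except FileNotFoundError:
        child_first_ok = False

    # Parent first, then child: both calls succeed
    for path in [output_dir, images_dir]:
        os.mkdir(path)
    parent_first_ok = os.path.isdir(images_dir)

print(child_first_ok, parent_first_ok)  # False True
```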

In [None]:
def init_folder(folder_names: list):
    """Create every folder in folder_names, in order (parents before children)."""
    for folder_name in folder_names:
        create_folder(folder_name)

In [None]:
init_folder(list_of_paths)

# Get all image paths
The get_all_images function returns the full paths of all images under the given directory. It uses os.walk to traverse every subdirectory, keeps the files whose names end with .png or .jpg, and builds each full path by joining the directory being walked with the file name. By default, os.walk ignores errors such as a missing directory and simply yields nothing, so a wrong path produces an empty list without any message; only other, unexpected errors reach the except branch, which prints the error and returns an empty list.
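To see which files the os.walk filter keeps, here is a self-contained sketch that builds a small temporary folder tree. The function list_images is a hypothetical stand-in for get_all_images: it lower-cases each name before the suffix test, so .PNG and .JPG files match as well, while .jpeg files are still excluded. It also shows that a missing directory yields an empty list.

```python
import os
import tempfile


def list_images(path, extensions=(".png", ".jpg")):
    """Return sorted full paths of files under path whose extension is in extensions."""
    return sorted(
        os.path.join(root, name)
        for root, dirs, files in os.walk(path)
        for name in files
        # Compare in lower case so "CAT.JPG" matches ".jpg"
        if name.lower().endswith(extensions)
    )


with tempfile.TemporaryDirectory() as base:
    os.makedirs(os.path.join(base, "sub"))
    for rel in ["a.png", "b.txt", os.path.join("sub", "CAT.JPG"), os.path.join("sub", "c.jpeg")]:
        open(os.path.join(base, rel), "w").close()
    found = [os.path.relpath(p, base) for p in list_images(base)]
    missing = list_images(os.path.join(base, "does-not-exist"))

print(found)    # ['a.png', 'sub/CAT.JPG'] on POSIX systems
print(missing)  # []
```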

In [None]:
def get_all_images(path):
    """Get all images from the given path.

    Args:
    param: image_path (str): path to the directory containing the images.

    Returns:
    - list: a list of full path to all the images with png or jpg extensions.
    - empty list: an empty list if an error occurred while fetching images.
    """
    try:
        # use os.walk to traverse all the subdirectories and get all images
        return [os.path.join(root, name)
                for root, dirs, files in os.walk(path)
                for name in files
                if name.endswith((".png", ".jpg"))]
    except Exception as e:
        # return an empty list and log the error message if an error occurred
        print(f"An error occurred while fetching images: {e}")
        return []

# Facebook DETR model (detr-resnet-101)

The detect_with_transformers function takes an image file path as an input, then uses a pre-trained model called DEtection TRansformer (DETR) to detect objects within the image.

The function receives an image that has already been opened with the Python Imaging Library (PIL); update_tags opens it with Image.open before calling it. It uses two components of the DETR model, both loaded from the facebook/detr-resnet-101 checkpoint: a DetrImageProcessor, which resizes and normalizes the image into the tensor format the model expects, and a DetrForObjectDetection model, which takes the processed image and predicts a class label and a bounding box for each detected object.

Once the model has made its predictions, the function calls the processor's post_process_object_detection method to convert the raw outputs (class logits and normalized bounding boxes) into, for each image, a dictionary of confidence scores, label ids, and boxes in absolute pixel coordinates (x_min, y_min, x_max, y_max). The target_sizes argument gives the (height, width) of each image, used to rescale the boxes. The label ids refer to the COCO (Common Objects in Context) categories the model was trained on, and the model's config.id2label maps them to class names.
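The coordinate conversion done by post_process_object_detection can be written out by hand. DETR predicts each box as (center_x, center_y, width, height) normalized to [0, 1]; the post-processing turns that into corner coordinates scaled by the image size. Since PIL's Image.size is (width, height) while target_sizes expects (height, width), the notebook passes image.size[::-1]. The function below is a plain-Python illustration of that arithmetic, not the transformers implementation.

```python
def to_absolute_corners(box, target_size):
    """Convert a normalized (cx, cy, w, h) box to absolute (x_min, y_min, x_max, y_max)."""
    cx, cy, w, h = box
    height, width = target_size  # same (height, width) order as target_sizes
    return (
        (cx - w / 2) * width,
        (cy - h / 2) * height,
        (cx + w / 2) * width,
        (cy + h / 2) * height,
    )


pil_size = (640, 480)          # PIL reports (width, height)
target_size = pil_size[::-1]   # (480, 640), i.e. (height, width)
print(to_absolute_corners((0.5, 0.5, 0.25, 0.5), target_size))
# (240.0, 120.0, 400.0, 360.0)
```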

The threshold argument (0.9 here) makes post_process_object_detection keep only the detections whose confidence score exceeds it. The function then maps each remaining label id to its class name. Finally, it prints a message for each detected object with its class name, confidence score, and box, and returns the list of class names (one entry per detection, so a name repeats if several objects of that class are found).
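The scoring step can also be sketched in plain Python. DETR outputs, for each of its object queries, logits over the COCO classes plus a final "no object" class; the post-processing applies a softmax, ignores that last class, takes the highest remaining probability as the score, and keeps the query only if the score exceeds the threshold. The logits and the three-entry id2label mapping below are made up for illustration (the real mapping is the model's config.id2label).

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def keep_detections(query_logits, id2label, threshold=0.9):
    """Return (label, score) for queries whose best real-class probability exceeds threshold."""
    kept = []
    for logits in query_logits:
        probs = softmax(logits)[:-1]  # drop the final "no object" class
        score = max(probs)
        label_id = probs.index(score)
        if score > threshold:
            kept.append((id2label[label_id], round(score, 3)))
    return kept


# Hypothetical mapping and logits: 3 classes + "no object"
id2label = {0: "cat", 1: "dog", 2: "person"}
query_logits = [
    [5.0, 0.0, 0.0, 0.0],  # confident "cat"
    [0.0, 0.0, 0.0, 5.0],  # "no object" wins, so the best real class scores low
    [0.0, 0.0, 6.0, 0.0],  # confident "person"
]
print(keep_detections(query_logits, id2label))
```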

In [None]:
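# The processor and model are loaded on the first call and then reused,
# so the DETR weights are not reloaded for every image.
_detr_processor = None
_detr_model = None


def detect_with_transformers(image, threshold=0.9):
    """
    Detect objects in an image using Facebook's DETR (DEtection TRansformer) model.

    Args:
        image (PIL.Image.Image): The image to be processed, already opened with PIL.
        threshold (float): Minimum confidence score for a detection to be kept.

    Returns:
        list: The class labels of the detected objects (a label repeats if it is detected more than once).
    """
    global _detr_processor, _detr_model
    if _detr_model is None:
        _detr_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-101")
        _detr_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-101")
        _detr_model.eval()

    # DETR expects 3-channel input; PNG files may be RGBA or grayscale
    image = image.convert("RGB")
    inputs = _detr_processor(images=image, return_tensors="pt")
    # Inference only: no gradients needed
    with torch.no_grad():
        outputs = _detr_model(**inputs)

    # Convert outputs (class logits and normalized boxes) to scores, label ids and
    # absolute (x_min, y_min, x_max, y_max) boxes; PIL's size is (width, height),
    # while target_sizes expects (height, width)
    target_sizes = torch.tensor([image.size[::-1]])
    results = _detr_processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]

    labels = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        box = [round(i, 2) for i in box.tolist()]
        label_name = _detr_model.config.id2label[label.item()]
        labels.append(label_name)
        print(f"Detected {label_name} with confidence {round(score.item(), 3)} at location {box}")
    return labels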

# Save metadata

The save_metadata function saves the metadata of an image in pickle, JSON, or SQLite format. It takes four parameters:

- metadata: a dictionary with the image's metadata.
- img_name: the file name (or path) of the image.
- metadata_path: the directory where the metadata is saved.
- save_format: 'pickle' (the default), 'json', or 'sqlite'.

With 'pickle' or 'json', the metadata is written to a file in metadata_path named after the image without its extension (photo.jpg becomes photo.pickle or photo.json). With 'sqlite', each key/value pair becomes a row (filename, key, value) in the metadata table of metadata_path/metadata.db, where filename is the image's base name and both key and value are converted to strings; if a row for the same filename and key already exists, its value is updated instead of adding a duplicate.
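The choice of format affects what comes back when the metadata is read. A quick in-memory round trip (the metadata dictionary is made up) shows the difference: pickle restores Python objects such as tuples exactly, JSON turns tuples into lists, and the SQLite branch keeps only the str() text of each value.

```python
import json
import pickle

# Hypothetical metadata for one image
metadata = {"tags": ["person", "dog"], "dominant_color": [("#ff0000", 0.75)]}

from_pickle = pickle.loads(pickle.dumps(metadata))
from_json = json.loads(json.dumps(metadata))
as_sqlite_value = str(metadata["dominant_color"])  # the text stored in SQLite mode

print(from_pickle["dominant_color"])  # [('#ff0000', 0.75)]  (tuples preserved)
print(from_json["dominant_color"])    # [['#ff0000', 0.75]]  (tuples become lists)
print(as_sqlite_value)                # [('#ff0000', 0.75)]  (a plain string)
```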

If an error occurs while saving the metadata, including an unsupported save_format, the function prints an error message with the image name and the error.

The function does not return any value.
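The SQLite branch emulates an upsert with a SELECT followed by an UPDATE or an INSERT. An alternative, sketched here on an in-memory database, is to declare (filename, key) unique and let INSERT OR REPLACE do both in one statement. This schema differs from the notebook's metadata table, which has no constraint, so it is an assumption that would have to be adopted when the table is first created, not a drop-in change.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Same three columns as the notebook's table, plus a uniqueness constraint
conn.execute(
    "CREATE TABLE metadata (filename TEXT, key TEXT, value TEXT, UNIQUE(filename, key))"
)

rows = [("a.jpg", "tags", "['cat']"), ("a.jpg", "tags", "['cat', 'dog']")]
for row in rows:
    # On a (filename, key) conflict, the old row is deleted and the new one inserted
    conn.execute("INSERT OR REPLACE INTO metadata VALUES (?, ?, ?)", row)
conn.commit()

result = conn.execute("SELECT * FROM metadata").fetchall()
conn.close()
print(result)  # [('a.jpg', 'tags', "['cat', 'dog']")]
```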

In [None]:
def save_metadata(metadata, img_name, metadata_path, save_format='pickle'):
    """
    Save the metadata of an image in pickle, JSON, or SQLite format.

    Parameters:
        metadata (dict): The metadata of the image.
        img_name (str): The file name (or path) of the image.
        metadata_path (str): The directory where the metadata will be saved.
        save_format (str): 'pickle' (default), 'json', or 'sqlite'.

    Returns:
        None
    """
    try:
        # File name of the image without its extension, e.g. "photo.jpg" -> "photo"
        stem = os.path.splitext(os.path.basename(img_name))[0]
        if save_format == 'pickle':
            # Save the metadata in pickle format
            with open(os.path.join(metadata_path, stem + '.pickle'), 'wb') as f:
                pickle.dump(metadata, f)
        elif save_format == 'json':
            # Save the metadata in JSON format
            with open(os.path.join(metadata_path, stem + '.json'), 'w') as f:
                json.dump(metadata, f)
        elif save_format == 'sqlite':
            # Rows are keyed by the image file name, including its extension
            img_name = os.path.basename(img_name)
            conn = sqlite3.connect(os.path.join(metadata_path, 'metadata.db'))
            try:
                c = conn.cursor()
                # Create the table if it doesn't exist: filename, key, value
                c.execute('''CREATE TABLE IF NOT EXISTS metadata (filename text, key text, value text)''')
                for key, value in metadata.items():
                    # Keys and values are stored as strings
                    key = str(key)
                    value = str(value)
                    # Update the value if the key already exists for this file, otherwise insert it
                    c.execute("SELECT 1 FROM metadata WHERE filename=? AND key=?", (img_name, key))
                    if c.fetchone():
                        c.execute("UPDATE metadata SET value=? WHERE filename=? AND key=?", (value, img_name, key))
                    else:
                        c.execute("INSERT INTO metadata VALUES (?, ?, ?)", (img_name, key, value))
                # Commit all changes for this image at once
                conn.commit()
            finally:
                # Close the connection even if an error occurred
                conn.close()
        else:
            raise ValueError(f"Invalid save format: {save_format}")
    except Exception as e:
        # Print an error message if an error occurs
        print(f"An error occurred while saving metadata for {img_name}: {e}")

# Read SQLite metadata

The read_sqlite function reads the metadata of one image from the SQLite database. It takes two parameters: metadata_path, the directory where the metadata is saved, and filename, the image file name whose metadata is retrieved.

It connects to metadata_path/metadata.db with sqlite3.connect (which creates the database file if it does not exist yet), creates a cursor, and creates the metadata table if it is missing, so the query never fails on a fresh database.

It then selects every row whose filename column equals the filename parameter and builds a dictionary mapping the key column to the value column. Because save_metadata stores every value with str(), the values come back as strings: a tags list, for example, is returned as the string "['person', 'dog']", not as a list.

Finally, it closes the connection and returns the dictionary, which is empty if the image has no metadata yet.
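Since read_sqlite returns strings, a caller that needs the original Python objects has to parse them. A sketch using ast.literal_eval, which safely evaluates Python literals such as lists, tuples, numbers, and strings (the double-quoted color strings written by get_all_colors are valid literals too). The helper name parse_metadata_values is hypothetical, and values that are not valid literals are kept as plain strings.

```python
import ast


def parse_metadata_values(metadata):
    """Convert str()-encoded values back to Python objects where possible."""
    parsed = {}
    for key, value in metadata.items():
        try:
            parsed[key] = ast.literal_eval(value)
        except (ValueError, SyntaxError):
            # Not a Python literal (e.g. a bare word): keep the original string
            parsed[key] = value
    return parsed


# Values as read_sqlite would return them (hypothetical example)
row_values = {
    "tags": "['person', 'dog']",
    "source": "camera",
    "dominant_color": '[("#ff0000", 0.75)]',
}
print(parse_metadata_values(row_values))
```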

In [None]:
def read_sqlite(metadata_path, filename):
    """Return the metadata of filename stored in metadata_path/metadata.db as a dict of strings."""
    # Open a connection to the database
    conn = sqlite3.connect(os.path.join(metadata_path, 'metadata.db'))
    try:
        c = conn.cursor()
        # Create the table if it doesn't exist: filename, key, value
        c.execute('''CREATE TABLE IF NOT EXISTS metadata (filename text, key text, value text)''')
        # Select all rows stored for this file
        c.execute("SELECT * FROM metadata WHERE filename=?", (filename,))
        # Build the metadata dictionary: row = (filename, key, value)
        metadata = {row[1]: row[2] for row in c.fetchall()}
    finally:
        # Close the connection
        conn.close()
    return metadata

# Set tags in metadata
This function "update_tags" is used to run the YOLOv3 algorithm on a set of images, update the metadata of each image with the detected labels (tags) and save the updated metadata.

The function takes 3 parameters:

images: a list of file paths for the images that need to be processed.
metadata_path: a file path to the directory where the metadata files are stored.
save_format: the format of the metadata files. Can be either 'pickle' or 'sqlite'.
The function uses the tqdm library to display a progress bar for the image processing. For each image, the function tries to retrieve its metadata based on the save_format. If the metadata file format is 'sqlite', the function calls the read_sqlite function to retrieve the metadata. If the metadata file format is 'pickle', the function reads the metadata file directly.

If the metadata already contains a "tags" key, it means that the image has already been processed and its metadata has been updated with the labels, so the function skips that image.

The function then calls the detect function to run the YOLOv3 algorithm on the image and retrieve the labels (tags). The labels are added to the metadata under the "tags" key.

Finally, the function calls the save_metadata function to save the updated metadata. If an error occurs while processing an image (e.g. the metadata file is not found), the function prints an error message and continues processing the next image.
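The control flow of update_tags (read the metadata, skip if already tagged, detect, deduplicate, save) can be exercised without a model. In this sketch the detector is a hypothetical stub and an in-memory dictionary stands in for the metadata store; store.get(name, {}) mirrors the SQLite mode, where an image without metadata yields an empty dictionary.

```python
def tag_images(image_names, store, detect):
    """Add deduplicated tags to each image's metadata; return the names actually processed."""
    processed = []
    for name in image_names:
        metadata = store.get(name, {})
        # Already tagged on a previous run: skip
        if "tags" in metadata:
            continue
        # Remove duplicate labels (sorted here only to make the output deterministic)
        metadata["tags"] = sorted(set(detect(name)))
        store[name] = metadata
        processed.append(name)
    return processed


def fake_detect(name):
    # Hypothetical stub standing in for detect_with_transformers
    return ["person", "dog", "person"]


store = {"b.jpg": {"tags": ["cat"]}}
print(tag_images(["a.jpg", "b.jpg"], store, fake_detect))  # ['a.jpg']
print(store["a.jpg"])                                      # {'tags': ['dog', 'person']}
print(tag_images(["a.jpg", "b.jpg"], store, fake_detect))  # [] (everything is tagged)
```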

In [None]:
def update_tags(images, metadata_path, save_format='pickle'):
    # Run the DETR object detector on each image and store the labels as tags
    for image_path in tqdm(images, desc="Updating tags"):
        # Split "photo.jpg" into "photo" and ".jpg" (also safe for names with several dots)
        base_name = os.path.basename(image_path)
        file_name, ext = os.path.splitext(base_name)
        try:
            if save_format == 'sqlite':
                metadata = read_sqlite(metadata_path, base_name)
            else:
                # Read metadata_path/<file_name>.<metadata_extension>
                metadata_file = os.path.join(metadata_path, file_name + "." + metadata_extension)
                if metadata_extension == "json":
                    with open(metadata_file, "r") as f:
                        metadata = json.load(f)
                else:
                    with open(metadata_file, "rb") as f:
                        metadata = pickle.load(f)

            # Skip images that were already tagged
            if "tags" in metadata:
                continue

            # Open the image and resize it to 416x416
            with Image.open(image_path) as img:
                pil_image = img.resize((416, 416))
            labels = detect_with_transformers(pil_image)

            # Remove duplicate labels and add them to the metadata
            metadata["tags"] = list(set(labels))
            save_metadata(metadata, base_name, metadata_path, save_format)
        except FileNotFoundError:
            print("File not found: ", file_name)
            continue
        except Exception as e:
            print(f"Error while processing {base_name}: {e}")
            continue

In [None]:
# Get the list of images (.png and .jpg files, including subfolders)
images = get_all_images(images_path)

update_tags(images, metadata_path, save_format='sqlite')

# Find dominant colors in the images
The functions rgb_to_hex and find_dominant_colors are used to find the dominant colors in an image.

The function rgb_to_hex takes a sequence of three RGB values (0 to 255) and returns the color as a hexadecimal string such as #ff0080. Each value is converted with int() first, so the float values of cluster centers are truncated. Hexadecimal codes are a standard way to write colors in web development and many other tools.
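A quick check of the format string, plus a hypothetical inverse hex_to_rgb that is handy when comparing stored colors. '%02x' writes each channel as two lowercase hexadecimal digits, and int() truncates rather than rounds.

```python
def rgb_to_hex(rgb):
    return '#%02x%02x%02x' % (int(rgb[0]), int(rgb[1]), int(rgb[2]))


def hex_to_rgb(hex_color):
    """Hypothetical inverse: '#ff0080' -> (255, 0, 128)."""
    hex_color = hex_color.lstrip('#')
    return tuple(int(hex_color[i:i + 2], 16) for i in (0, 2, 4))


print(rgb_to_hex((255, 0, 128)))         # #ff0080
print(rgb_to_hex((12.9, 200.2, 99.99)))  # #0cc863  (values truncated, not rounded)
print(hex_to_rgb('#ff0080'))             # (255, 0, 128)
```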

The function find_dominant_colors takes the path of an image and three optional parameters: k, the number of dominant colors to return (default 4); downsample, the factor by which the width and height are first divided (default 2); and resize, the final (width, height) to which the image is resized before clustering (default (200, 200), or None to skip this step). Shrinking the image reduces the number of pixels to cluster and speeds up processing.

The image is loaded with OpenCV, converted from BGR (OpenCV's channel order) to RGB, shrunk, and flattened into a list of pixels. MiniBatchKMeans groups the pixels into k clusters, and each cluster center is the average color of its pixels. The function counts the pixels assigned to each cluster to get the fraction of the image covered by each color, converts the centers to hexadecimal, and returns a list of (hex color, fraction) pairs, one per cluster.
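To make the clustering step concrete without OpenCV or scikit-learn, here is a deliberately small pure-Python version of Lloyd's k-means, the standard algorithm that MiniBatchKMeans approximates with random mini-batches. It initializes the centroids with the first k distinct pixels, a simplification of scikit-learn's k-means++ initialization, and is only meant for toy inputs.

```python
from collections import Counter


def squared_distance(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest(pixel, centroids):
    """Index of the centroid closest to pixel."""
    return min(range(len(centroids)), key=lambda i: squared_distance(pixel, centroids[i]))


def dominant_colors(pixels, k, iterations=10):
    """Return [(centroid, fraction_of_pixels), ...], most common first."""
    # Simplified initialization: the first k distinct pixels
    centroids = []
    for p in pixels:
        if p not in centroids:
            centroids.append(p)
        if len(centroids) == k:
            break
    centroids = [tuple(float(c) for c in p) for p in centroids]

    for _ in range(max(1, iterations)):
        # Assignment step: each pixel joins its nearest centroid
        labels = [nearest(p, centroids) for p in pixels]
        # Update step: each centroid moves to the mean of its pixels
        new_centroids = []
        for i, old in enumerate(centroids):
            members = [p for p, label in zip(pixels, labels) if label == i]
            if members:
                new_centroids.append(tuple(sum(ch) / len(members) for ch in zip(*members)))
            else:
                new_centroids.append(old)  # keep an empty cluster where it was
        if new_centroids == centroids:
            break  # converged
        centroids = new_centroids

    # Fraction of pixels per cluster, like np.bincount(labels) / len(labels)
    counts = Counter(labels)
    result = [(centroids[i], counts[i] / len(pixels)) for i in range(len(centroids))]
    return sorted(result, key=lambda item: item[1], reverse=True)


# Toy "image": 75 red pixels, 20 dark-blue and 5 blue pixels
pixels = [(255, 0, 0)] * 75 + [(0, 0, 200)] * 20 + [(0, 0, 255)] * 5
print(dominant_colors(pixels, k=2))
# [((255.0, 0.0, 0.0), 0.75), ((0.0, 0.0, 211.0), 0.25)]
```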

In [None]:
def rgb_to_hex(rgb):
    return '#%02x%02x%02x' % (int(rgb[0]), int(rgb[1]), int(rgb[2]))

In [None]:
def find_dominant_colors(image_path, k=4, downsample=2, resize=(200, 200)):
    # Load the image (OpenCV uses BGR order) and convert it to RGB
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Downsample the image (keeping at least 1 pixel per side)
    image = cv2.resize(image, (max(1, image.shape[1] // downsample), max(1, image.shape[0] // downsample)))

    # Resize the image if requested
    if resize is not None:
        image = cv2.resize(image, resize)

    # Flatten the image into a list of RGB pixels
    image_flat = image.reshape((image.shape[0] * image.shape[1], 3))

    # Cluster the pixels with MiniBatchKMeans
    clt = MiniBatchKMeans(n_clusters=k, n_init=10, batch_size=100, random_state=42)
    labels = clt.fit_predict(image_flat)

    # Count the pixels assigned to each cluster (minlength keeps empty clusters)
    counts = np.bincount(labels, minlength=k)

    # Fraction of the image covered by each cluster
    percentages = counts / len(labels)

    # The cluster centers are the dominant colors; convert them to hexadecimal
    dominant_colors_hex = [rgb_to_hex(color) for color in clt.cluster_centers_]

    # Pair each color with its coverage as a plain float, most common color first
    result = [(hex_color, float(p)) for hex_color, p in zip(dominant_colors_hex, percentages)]
    result.sort(key=lambda item: item[1], reverse=True)

    return result


The gen_sql_requests function takes a list of image file names and a list of color strings of the same length, and generates one SQL request per (filename, color) pair to insert the colors into the database.

It loops over both lists together with zip, wrapped in tqdm to show a progress bar. For each pair, it builds an INSERT statement as a string that stores the color under the key dominant_color, matching the (filename, key, value) layout of the metadata table, and adds it to the list.

After processing all pairs, the function returns the list of SQL requests.
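Building SQL by string formatting means any single quote inside a value has to be escaped by hand. The sqlite3 module's ? placeholders avoid this: the values are passed separately and are never parsed as SQL. A sketch on an in-memory database, with a made-up file name that contains an apostrophe:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE metadata (filename TEXT, key TEXT, value TEXT)")

# Hypothetical rows: the apostrophe in the file name needs no escaping with placeholders
rows = [
    ("dad's photo.jpg", "dominant_color", '[("#ff0000", 0.75)]'),
    ("b.png", "dominant_color", '[("#00ff00", 1.0)]'),
]
# executemany runs the statement once per tuple, binding each value to a ?
conn.executemany("INSERT INTO metadata VALUES (?, ?, ?)", rows)
conn.commit()

stored = conn.execute("SELECT filename FROM metadata ORDER BY filename").fetchall()
conn.close()
print(stored)  # [('b.png',), ("dad's photo.jpg",)]
```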

In [None]:
def gen_sql_requests(filenames, colors):
    """
    Generate a list of SQL requests that insert dominant colors into the metadata table.

    Parameters:
        filenames (list): Image file names.
        colors (list): Color strings, one per file name, in the same order.

    Returns:
        list: A list of SQL INSERT statements.
    """
    # Create a list to store SQL requests
    sql_requests = []

    # Loop over the (filename, color) pairs
    for filename, color in tqdm(zip(filenames, colors), desc="Generating SQL requests"):
        # Escape single quotes by doubling them (SQL string literal syntax),
        # so names such as "dad's photo.jpg" do not break the statement
        filename = str(filename).replace("'", "''")
        color = str(color).replace("'", "''")
        # One row per image: (filename, 'dominant_color', color)
        sql_request = f"INSERT INTO metadata VALUES ('{filename}', 'dominant_color', '{color}')"
        sql_requests.append(sql_request)

    # Return the list of SQL requests
    return sql_requests

The execute_query function runs a single SQL statement on the SQLite database at metadata_path/metadata.db. Its only parameter, query, must be a complete SQL statement that SQLite accepts.

It opens a connection with sqlite3.connect, runs the statement with the connection's execute() method, commits the change, and closes the connection.

It can be used to run the INSERT statements produced by gen_sql_requests one at a time, at the cost of opening a new connection for every statement. sqlite3.connect creates the database file if it does not exist, but the metadata table must already exist (save_metadata and read_sqlite create it).
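A variant of execute_query, sketched with contextlib.closing so the connection is closed even if the statement fails, and with the connection used as a context manager so the change is committed on success and rolled back on error. The function name run_sql and its db_path and params parameters are hypothetical additions, chosen so the sketch does not shadow the notebook's execute_query.

```python
import os
import sqlite3
import tempfile
from contextlib import closing


def run_sql(db_path, query, params=()):
    """Run one SQL statement with optional ? parameters, then commit and close."""
    # closing() guarantees conn.close(); "with conn" commits or rolls back
    with closing(sqlite3.connect(db_path)) as conn:
        with conn:
            conn.execute(query, params)


with tempfile.TemporaryDirectory() as folder:
    db = os.path.join(folder, "metadata.db")
    run_sql(db, "CREATE TABLE IF NOT EXISTS metadata (filename text, key text, value text)")
    run_sql(db, "INSERT INTO metadata VALUES (?, ?, ?)", ("a.jpg", "tags", "['cat']"))
    with closing(sqlite3.connect(db)) as conn:
        rows = conn.execute("SELECT * FROM metadata").fetchall()

print(rows)  # [('a.jpg', 'tags', "['cat']")]
```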

In [None]:
# Execute a SQL query on the metadata database
def execute_query(query):
    conn = sqlite3.connect(os.path.join(metadata_path, 'metadata.db'))
    try:
        # Run the statement (e.g. an INSERT from gen_sql_requests)
        conn.execute(query)
        # Commit the changes
        conn.commit()
    finally:
        # Close the connection, even if the query failed
        conn.close()

The get_all_colors function computes the dominant colors of every image and stores them in the SQLite database. It first collects all image paths under the given folder with get_all_images.

For each image, it calls find_dominant_colors with downsample=2 and resize=(100, 100), which loads the image with OpenCV and clusters its pixels. If an error occurs, for example because a file cannot be read, the error is printed and the image is skipped. Each result, a list of (hex color, fraction) pairs, is converted to a string with its single quotes replaced by double quotes.

The function then uses gen_sql_requests to build one INSERT statement per image, storing the color string under the key dominant_color, and executes all statements over a single connection to metadata_path/metadata.db. It does not check whether an image already has a dominant color, so every image is processed on each run.
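When some images fail and are skipped, the file names and the colors must stay paired: collecting them in two independent lists and zipping them later would shift every color after the first failure onto the wrong file. A sketch of the safe pattern, with a hypothetical compute_color stub that fails for one file:

```python
def compute_color(name):
    # Hypothetical stub standing in for find_dominant_colors
    if name == "broken.jpg":
        raise ValueError("Could not read image")
    return [("#ff0000", 1.0)] if name == "red.jpg" else [("#0000ff", 1.0)]


def collect_colors(names):
    """Return (filename, color string) pairs for the images that succeeded."""
    pairs = []
    for name in names:
        try:
            color = compute_color(name)
        except ValueError as error:
            print("Error:", error)
            continue
        # Append the name and its color together so they cannot drift apart
        pairs.append((name, str(color).replace("'", '"')))
    return pairs


print(collect_colors(["red.jpg", "broken.jpg", "blue.jpg"]))
# prints "Error: Could not read image", then
# [('red.jpg', '[("#ff0000", 1.0)]'), ('blue.jpg', '[("#0000ff", 1.0)]')]
```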

In [None]:
import numpy as np

def get_all_colors(image_path):
    """
    Extract the dominant colors of all images in a directory and save them in the database.

    Parameters:
        image_path (str): The path to the directory where the images are stored.

    Returns:
        None
    """
    # Get a list of all images in the directory
    img_files = get_all_images(image_path)
    # File names and colors are collected together so they stay aligned
    # even when some images fail and are skipped
    processed_files = []
    colors = []

    # Progress bar over all images
    for img in tqdm(img_files, desc="Processing images (approx. 25 minutes)"):
        try:
            color = find_dominant_colors(img, downsample=2, resize=(100, 100))
        except Exception as e:
            print("Error: ", e)
            continue

        if color:
            # Plain floats keep the string free of NumPy reprs such as np.float64(...)
            color = str([(hex_color, float(share)) for hex_color, share in color])
            # Replace single quotes with double quotes
            color = color.replace("'", '"')
            processed_files.append(os.path.basename(img))
            colors.append(color)

    queries = gen_sql_requests(processed_files, colors)

    conn = sqlite3.connect(os.path.join(metadata_path, 'metadata.db'))
    try:
        # Make sure the table exists, and drop colors from previous runs to avoid duplicate rows
        conn.execute('''CREATE TABLE IF NOT EXISTS metadata (filename text, key text, value text)''')
        conn.executemany("DELETE FROM metadata WHERE filename=? AND key='dominant_color'",
                         [(name,) for name in processed_files])
        # Insert the colors into the database
        for query in tqdm(queries, desc="Inserting colors"):
            conn.execute(query)
        # Commit all inserts at once
        conn.commit()
    finally:
        # Close the connection
        conn.close()

In [None]:
get_all_colors(images_path)